pull from main

Gokul Mohanarangan
2023-08-17 09:38:35 +05:30
10 changed files with 158 additions and 670 deletions

View File

@@ -1,211 +0,0 @@
"""
Collection of data classes that rigidly structure the
input and output parameters of functions.
"""
import datetime
from dataclasses import dataclass
from typing import List
from sortedcontainers import SortedDict
import av
@dataclass
class TitleSummaryInput:
"""
Data class for the input used to generate titles and summaries.
The outcome is used to build the query sent to the LLM for processing.
"""
input_text: str
transcribed_time: float
prompt: str
data: dict
headers: dict
def __init__(self, transcribed_time, input_text=""):
self.input_text = input_text
self.transcribed_time = transcribed_time
self.prompt = f"""
### Human:
Create a JSON object as response. The JSON object must have 2 fields:
i) title and ii) summary. For the title field, generate a short title
for the given text. For the summary field, summarize the given text
in three sentences.
{self.input_text}
### Assistant:
"""
self.data = {"prompt": self.prompt}
self.headers = {"Content-Type": "application/json"}
@dataclass
class IncrementalResult:
"""
Data class for the result of generating one title and summary.
Defines what a single "topic" looks like.
"""
title: str
description: str
transcript: str
timestamp: str
def __init__(self, title, desc, transcript, timestamp):
self.title = title
self.description = desc
self.transcript = transcript
self.timestamp = timestamp
@dataclass
class TitleSummaryOutput:
"""
Data class for the result of all generated titles and summaries.
The result will be sent back to the client
"""
cmd: str
topics: List[IncrementalResult]
def __init__(self, inc_responses):
self.topics = inc_responses
self.cmd = "UPDATE_TOPICS"
def get_result(self) -> dict:
"""
Return the result dict for updating the topics on the client
:return:
"""
return {"cmd": self.cmd, "topics": self.topics}
@dataclass
class ParseLLMResult:
"""
Data class to parse the result returned by the LLM while generating title
and summaries. The result will be sent back to the client.
"""
title: str
description: str
transcript: str
timestamp: str
def __init__(self, param: TitleSummaryInput, output: dict):
self.title = output["title"]
self.transcript = param.input_text
self.description = output.pop("summary")
self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))
def get_result(self) -> dict:
"""
Return the result dict after parsing the response from LLM
:return:
"""
return {
"title": self.title,
"description": self.description,
"transcript": self.transcript,
"timestamp": self.timestamp,
}
@dataclass
class TranscriptionInput:
"""
Data class to define the input to the transcription function
(a list of AudioFrames).
"""
frames: List[av.audio.frame.AudioFrame]
def __init__(self, frames):
self.frames = frames
@dataclass
class TranscriptionOutput:
"""
Dataclass to define the result of the transcription function.
The result will be sent back to the client
"""
cmd: str
result_text: str
def __init__(self, result_text):
self.cmd = "SHOW_TRANSCRIPTION"
self.result_text = result_text
def get_result(self) -> dict:
"""
Return the result dict for displaying the transcription
:return:
"""
return {"cmd": self.cmd, "text": self.result_text}
@dataclass
class FinalSummaryResult:
"""
Dataclass to define the result of the final summary function.
The result will be sent back to the client.
"""
cmd: str
final_summary: str
duration: str
def __init__(self, final_summary, time):
self.duration = str(datetime.timedelta(seconds=round(time)))
self.final_summary = final_summary
self.cmd = "DISPLAY_FINAL_SUMMARY"
def get_result(self) -> dict:
"""
Return the result dict for displaying the final summary
:return:
"""
return {
"cmd": self.cmd,
"duration": self.duration,
"summary": self.final_summary,
}
class BlackListedMessages:
"""
Class to hold the blacklisted messages. These messages should be filtered
out and not sent back to the client as part of the transcription.
"""
messages = [
" Thank you.",
" See you next time!",
" Thank you for watching!",
" Bye!",
" And that's what I'm talking about.",
]
@dataclass
class TranscriptionContext:
transcription_text: str
last_transcribed_time: float
incremental_responses: List[IncrementalResult]
sorted_transcripts: dict
data_channel: None # FIXME
logger: None
status: str
def __init__(self, logger):
self.transcription_text = ""
self.last_transcribed_time = 0.0
self.incremental_responses = []
self.data_channel = None
self.sorted_transcripts = SortedDict()
self.status = "idle"
self.logger = logger
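Side note: the classes above pair @dataclass with hand-written __init__ methods, which bypasses the generated constructor. A minimal sketch of the idiomatic form, assuming the same fields (the from_seconds helper is illustrative, not part of the original code):
import datetime
from dataclasses import dataclass

@dataclass
class IncrementalResult:
    title: str
    description: str
    transcript: str
    timestamp: str

@dataclass
class FinalSummaryResult:
    final_summary: str
    duration: str
    cmd: str = "DISPLAY_FINAL_SUMMARY"

    @classmethod
    def from_seconds(cls, final_summary: str, seconds: float) -> "FinalSummaryResult":
        # illustrative helper mirroring the original __init__'s timedelta formatting
        return cls(final_summary, str(datetime.timedelta(seconds=round(seconds))))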

View File

@@ -1,5 +1,6 @@
from .base import Processor, ThreadedProcessor, Pipeline # noqa: F401
from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_chunker import AudioChunkerProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401

View File

@@ -0,0 +1,39 @@
from reflector.processors.base import Processor
import av
from pathlib import Path
class AudioFileWriterProcessor(Processor):
"""
Write audio frames to a file.
"""
INPUT_TYPE = av.AudioFrame
OUTPUT_TYPE = av.AudioFrame
def __init__(self, path: Path | str):
super().__init__()
if isinstance(path, str):
path = Path(path)
self.path = path
self.out_container = None
self.out_stream = None
async def _push(self, data: av.AudioFrame):
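# lazily open the output container on the first frame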
if not self.out_container:
self.path.parent.mkdir(parents=True, exist_ok=True)
self.out_container = av.open(self.path.as_posix(), "w", format="wav")
self.out_stream = self.out_container.add_stream(
"pcm_s16le", rate=data.sample_rate
)
for packet in self.out_stream.encode(data):
self.out_container.mux(packet)
await self.emit(data)
async def _flush(self):
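# drain the encoder, finalize the WAV container, and reset for reuse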
if self.out_container:
for packet in self.out_stream.encode(None):
self.out_container.mux(packet)
self.out_container.close()
self.out_container = None
self.out_stream = None
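A minimal standalone usage sketch for the writer above, assuming the Processor base class exposes push()/flush() wrappers around _push()/_flush() (one second of 16 kHz mono silence):
import asyncio
import av
import numpy as np
from reflector.processors import AudioFileWriterProcessor

async def demo():
    writer = AudioFileWriterProcessor(path="./data/demo.wav")
    # one second of 16 kHz mono silence as a signed 16-bit frame
    samples = np.zeros((1, 16000), dtype=np.int16)
    frame = av.AudioFrame.from_ndarray(samples, format="s16", layout="mono")
    frame.sample_rate = 16000
    await writer.push(frame)   # assumed public wrapper around _push
    await writer.flush()       # assumed public wrapper around _flush

asyncio.run(demo())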

View File

@@ -3,7 +3,6 @@ from reflector.processors.types import AudioFile
from time import monotonic_ns
from uuid import uuid4
import io
import wave
import av
@@ -28,12 +27,16 @@ class AudioMergeProcessor(Processor):
# create audio file
uu = uuid4().hex
fd = io.BytesIO()
with wave.open(fd, "wb") as wf:
wf.setnchannels(channels)
wf.setsampwidth(sample_width)
wf.setframerate(sample_rate)
for frame in data:
wf.writeframes(frame.to_ndarray().tobytes())
out_container = av.open(fd, "w", format="wav")
out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
for frame in data:
for packet in out_stream.encode(frame):
out_container.mux(packet)
for packet in out_stream.encode(None):
out_container.mux(packet)
out_container.close()
fd.seek(0)
# emit audio file
audiofile = AudioFile(
@@ -44,4 +47,5 @@ class AudioMergeProcessor(Processor):
sample_width=sample_width,
timestamp=data[0].pts * data[0].time_base,
)
await self.emit(audiofile)
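The wave-to-PyAV swap above follows PyAV's usual encode/flush pattern; the same idea in isolation, assuming 16 kHz mono s16 frames:
import io
import av
import numpy as np

fd = io.BytesIO()
out_container = av.open(fd, "w", format="wav")
out_stream = out_container.add_stream("pcm_s16le", rate=16000)

samples = np.zeros((1, 16000), dtype=np.int16)
frame = av.AudioFrame.from_ndarray(samples, format="s16", layout="mono")
frame.sample_rate = 16000

for packet in out_stream.encode(frame):
    out_container.mux(packet)
for packet in out_stream.encode(None):  # flush any buffered samples
    out_container.mux(packet)
out_container.close()
fd.seek(0)  # fd now holds a complete in-memory WAV file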

View File

@@ -1,381 +0,0 @@
import argparse
import asyncio
import datetime
import json
import os
import wave
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Union
import aiohttp_cors
import av
import requests
from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from faster_whisper import WhisperModel
from reflector.models import (
BlackListedMessages,
FinalSummaryResult,
ParseLLMResult,
TitleSummaryInput,
TitleSummaryOutput,
TranscriptionInput,
TranscriptionOutput,
TranscriptionContext,
)
from reflector.logger import logger
from reflector.utils.run_utils import run_in_executor
from reflector.settings import settings
# WebRTC components
pcs = set()
relay = MediaRelay()
executor = ThreadPoolExecutor()
# Transcription model
model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12)
# LLM
LLM_URL = settings.LLM_URL
if not LLM_URL:
assert settings.LLM_BACKEND == "oobabooga"
LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate"
logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}")
def parse_llm_output(
param: TitleSummaryInput, response: requests.Response
) -> Union[None, ParseLLMResult]:
"""
Function to parse the LLM response
:param param:
:param response:
:return:
"""
try:
output = json.loads(response.json()["results"][0]["text"])
return ParseLLMResult(param, output)
except Exception:
logger.exception("Exception while parsing LLM output")
return None
def get_title_and_summary(
ctx: TranscriptionContext, param: TitleSummaryInput
) -> Union[None, TitleSummaryOutput]:
"""
From the input provided (transcript), query the LLM to generate
topics and summaries
:param ctx:
:param param:
:return:
"""
logger.info("Generating title and summary")
# TODO : Handle unexpected output formats from the model
try:
response = requests.post(LLM_URL, headers=param.headers, json=param.data)
output = parse_llm_output(param, response)
if output:
result = output.get_result()
ctx.incremental_responses.append(result)
return TitleSummaryOutput(ctx.incremental_responses)
except Exception:
logger.exception("Exception while generating title and summary")
return None
def channel_send(channel, message: str) -> None:
"""
Send text messages via the data channel
:param channel:
:param message:
:return:
"""
if channel:
channel.send(message)
def channel_send_increment(
channel, param: Union[FinalSummaryResult, TitleSummaryOutput]
) -> None:
"""
Send the incremental topics and summaries via the data channel
:param channel:
:param param:
:return:
"""
if channel and param:
message = param.get_result()
channel.send(json.dumps(message))
def channel_send_transcript(ctx: TranscriptionContext) -> None:
"""
Send the transcription result via the data channel
:param ctx:
:return:
"""
if not ctx.data_channel:
return
try:
least_time = next(iter(ctx.sorted_transcripts))
message = ctx.sorted_transcripts[least_time].get_result()
if message:
del ctx.sorted_transcripts[least_time]
if message["text"] not in BlackListedMessages.messages:
ctx.data_channel.send(json.dumps(message))
# If, due to an exception, one of the earlier batches can't return
# a transcript, we don't want to be stuck waiting for its result.
# With the threshold size of 3, we pop the first (lost) element
else:
if len(ctx.sorted_transcripts) >= 3:
del ctx.sorted_transcripts[least_time]
except Exception:
logger.exception("Exception while sending transcript")
def get_transcription(
ctx: TranscriptionContext, input_frames: TranscriptionInput
) -> Union[None, TranscriptionOutput]:
"""
From the collected audio frames create transcription by inferring from
the chosen transcription model
:param ctx:
:param input_frames:
:return:
"""
ctx.logger.info("Transcribing..")
ctx.sorted_transcripts[input_frames.frames[0].time] = None
# TODO: Find a cleaner way; see the "no transcription" issue below.
# Passing IO objects instead of temporary files throws an error.
# Passing an ndarray (cast to float) does not give any
# transcription. See
# https://github.com/guillaumekln/faster-whisper/issues/369
audio_file = "test" + str(datetime.datetime.now())
wf = wave.open(audio_file, "wb")
wf.setnchannels(settings.AUDIO_CHANNELS)
wf.setframerate(settings.AUDIO_SAMPLING_RATE)
wf.setsampwidth(settings.AUDIO_SAMPLING_WIDTH)
for frame in input_frames.frames:
wf.writeframes(b"".join(frame.to_ndarray()))
wf.close()
result_text = ""
try:
segments, _ = model.transcribe(
audio_file,
language="en",
beam_size=5,
vad_filter=True,
vad_parameters={"min_silence_duration_ms": 500},
)
os.remove(audio_file)
segments = list(segments)
result_text = ""
duration = 0.0
for segment in segments:
result_text += segment.text
start_time = segment.start
end_time = segment.end
if not segment.start:
start_time = 0.0
if not segment.end:
end_time = 5.5
duration += end_time - start_time
ctx.last_transcribed_time += duration
ctx.transcription_text += result_text
except Exception:
logger.exception("Exception while transcribing")
result = TranscriptionOutput(result_text)
ctx.sorted_transcripts[input_frames.frames[0].time] = result
return result
def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
"""
Collate the incremental summaries generated so far and return as the final
summary
:return:
"""
final_summary = ""
# Collate the incremental summaries
for topic in ctx.incremental_responses:
final_summary += topic["description"]
response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)
with open(
"./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8"
) as file:
file.write(json.dumps(ctx.incremental_responses))
return response
class AudioStreamTrack(MediaStreamTrack):
"""
An audio stream track.
"""
kind = "audio"
def __init__(self, ctx: TranscriptionContext, track):
super().__init__()
self.ctx = ctx
self.track = track
self.audio_buffer = av.AudioFifo()
async def recv(self) -> av.audio.frame.AudioFrame:
ctx = self.ctx
frame = await self.track.recv()
self.audio_buffer.write(frame)
if local_frames := self.audio_buffer.read_many(
settings.AUDIO_BUFFER_SIZE, partial=False
):
whisper_result = run_in_executor(
get_transcription,
ctx,
TranscriptionInput(local_frames),
executor=executor,
)
whisper_result.add_done_callback(
lambda f: channel_send_transcript(ctx) if f.result() else None
)
if len(ctx.transcription_text) > 25:
llm_input_text = ctx.transcription_text
ctx.transcription_text = ""
param = TitleSummaryInput(
input_text=llm_input_text, transcribed_time=ctx.last_transcribed_time
)
llm_result = run_in_executor(
get_title_and_summary, ctx, param, executor=executor
)
llm_result.add_done_callback(
lambda f: channel_send_increment(ctx.data_channel, f.result())
if f.result()
else None
)
return frame
async def offer(request: web.Request) -> web.Response:
"""
Establish the WebRTC connection with the client
:param request:
:return:
"""
params = await request.json()
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
# client identification
peername = request.transport.get_extra_info("peername")
if peername is not None:
clientid = f"{peername[0]}:{peername[1]}"
else:
clientid = uuid.uuid4()
# create a context for the whole rtc transaction
# add a customised logger to the context
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
# handle RTC peer connection
pc = RTCPeerConnection()
pcs.add(pc)
@pc.on("datachannel")
def on_datachannel(channel) -> None:
ctx.data_channel = channel
ctx.logger = ctx.logger.bind(channel=channel.label)
ctx.logger.info("Channel created by remote party")
@channel.on("message")
def on_message(message: str) -> None:
ctx.logger.info(f"Message: {message}")
if isinstance(message, str) and message.startswith("ping"):
channel_send(channel, "pong" + message[4:])
return
if json.loads(message)["cmd"] == "STOP":
# Placeholder final summary
response = get_final_summary_response(ctx)
channel_send_increment(channel, response)
# TODO: Add code to stop the connection from the server side here,
# but we have to handshake with the client once first
@pc.on("connectionstatechange")
async def on_connectionstatechange() -> None:
ctx.logger.info(f"Connection state changed: {pc.connectionState}")
if pc.connectionState == "failed":
await pc.close()
pcs.discard(pc)
@pc.on("track")
def on_track(track) -> None:
ctx.logger.info(f"Track {track.kind} received")
pc.addTrack(AudioStreamTrack(ctx, relay.subscribe(track)))
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
return web.Response(
content_type="application/json",
text=json.dumps(
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
),
)
async def on_shutdown(application: web.Application) -> None:
"""
On shutdown, the coroutines that shut down client connections are
executed
:param application:
:return:
"""
coroutines = [pc.close() for pc in pcs]
await asyncio.gather(*coroutines)
pcs.clear()
def create_app() -> web.Application:
"""
Create the web application
"""
app = web.Application()
cors = aiohttp_cors.setup(
app,
defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True, expose_headers="*", allow_headers="*"
)
},
)
offer_resource = cors.add(app.router.add_resource("/offer"))
cors.add(offer_resource.add_route("POST", offer))
app.on_shutdown.append(on_shutdown)
return app
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="WebRTC based server for Reflector")
parser.add_argument(
"--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=1250, help="Server port (def: 1250)"
)
args = parser.parse_args()
app = create_app()
web.run_app(app, access_log=None, host=args.host, port=args.port)

View File

@@ -9,6 +9,9 @@ class Settings(BaseSettings):
# Database
DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
# local data directory (audio for now)
DATA_DIR: str = "./data"
# Whisper
WHISPER_MODEL_SIZE: str = "tiny"
WHISPER_REAL_TIME_MODEL_SIZE: str = "tiny"
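Since Settings extends pydantic's BaseSettings, the new keys can presumably be overridden through the environment; a small sketch, assuming the Settings class is importable from reflector.settings (values illustrative):
import os

# hypothetical overrides; BaseSettings reads these at instantiation time
os.environ["DATA_DIR"] = "/var/lib/reflector/data"
os.environ["WHISPER_MODEL_SIZE"] = "base"

from reflector.settings import Settings
settings = Settings()
assert settings.DATA_DIR == "/var/lib/reflector/data"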

View File

@@ -2,17 +2,18 @@ import asyncio
from fastapi import Request, APIRouter
from reflector.events import subscribers_shutdown
from pydantic import BaseModel
from reflector.models import TranscriptionContext
from reflector.logger import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
from json import loads, dumps
from enum import StrEnum
from pathlib import Path
import av
from reflector.processors import (
Pipeline,
AudioChunkerProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
AudioFileWriterProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptFinalSummaryProcessor,
@@ -25,6 +26,15 @@ sessions = []
router = APIRouter()
class TranscriptionContext:
def __init__(self, logger):
self.logger = logger
self.pipeline = None
self.data_channel = None
self.status = "idle"
self.topics = []
class AudioStreamTrack(MediaStreamTrack):
"""
An audio stream track.
@@ -64,7 +74,11 @@ class PipelineEvent(StrEnum):
async def rtc_offer_base(
params: RtcOffer, request: Request, event_callback=None, event_callback_args=None
params: RtcOffer,
request: Request,
event_callback=None,
event_callback_args=None,
audio_filename: Path | None = None,
):
# build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -73,7 +87,6 @@ async def rtc_offer_base(
peername = request.client
clientid = f"{peername[0]}:{peername[1]}"
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
ctx.topics = []
async def update_status(status: str):
changed = ctx.status != status
@@ -151,14 +164,18 @@ async def rtc_offer_base(
# create a context for the whole rtc transaction
# add a customised logger to the context
ctx.pipeline = Pipeline(
processors = []
if audio_filename is not None:
processors += [AudioFileWriterProcessor(path=audio_filename)]
processors += [
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
TranscriptLinerProcessor(),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
)
]
ctx.pipeline = Pipeline(*processors)
# FIXME: warmup is not working well yet
# await ctx.pipeline.warmup()

View File

@@ -5,14 +5,20 @@ from fastapi import (
WebSocket,
WebSocketDisconnect,
)
from fastapi.responses import FileResponse
from starlette.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from uuid import uuid4
from datetime import datetime
from fastapi_pagination import Page, paginate
from reflector.logger import logger
from reflector.db import database, transcripts
from reflector.settings import settings
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
from typing import Optional
from pathlib import Path
from tempfile import NamedTemporaryFile
import av
router = APIRouter()
@@ -81,6 +87,44 @@ class Transcript(BaseModel):
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_mp3(self):
fn = self.audio_mp3_filename
if fn.exists():
return
logger.info(f"Converting audio to mp3: {self.audio_filename}")
inp = av.open(self.audio_filename.as_posix(), "r")
# create temporary file for mp3
with NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
out = av.open(tmp.name, "w")
stream = out.add_stream("mp3")
for frame in inp.decode(audio=0):
frame.pts = None
for packet in stream.encode(frame):
out.mux(packet)
for packet in stream.encode(None):
out.mux(packet)
out.close()
inp.close()
# move temporary file to final location
Path(tmp.name).rename(fn)
def unlink(self):
self.data_path.unlink(missing_ok=True)
@property
def data_path(self):
return Path(settings.DATA_DIR) / self.id
@property
def audio_filename(self):
return self.data_path / "audio.wav"
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
class TranscriptController:
async def get_all(self) -> list[Transcript]:
@@ -112,6 +156,10 @@ class TranscriptController:
setattr(transcript, key, value)
async def remove_by_id(self, transcript_id: str) -> None:
transcript = await self.get_by_id(transcript_id)
if not transcript:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
@@ -202,8 +250,24 @@ async def transcript_get_audio(transcript_id: str):
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# TODO: Implement audio generation
return HTTPException(status_code=500, detail="Not implemented")
if not transcript.audio_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return FileResponse(transcript.audio_filename, media_type="audio/wav")
@router.get("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(transcript_id: str):
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_mp3)
return FileResponse(transcript.audio_mp3_filename, media_type="audio/mp3")
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
@@ -371,4 +435,5 @@ async def transcript_record_webrtc(
request,
event_callback=handle_rtc_event,
event_callback_args=transcript_id,
audio_filename=transcript.audio_filename,
)
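A quick client-side sanity check of the new audio routes (base URL and transcript id are assumptions; the port mirrors the one used in the tests below):
import httpx

BASE = "http://127.0.0.1:1255"  # assumed server address
tid = "some-transcript-id"      # illustrative id

wav = httpx.get(f"{BASE}/transcripts/{tid}/audio")
assert wav.status_code == 200 and wav.headers["content-type"] == "audio/wav"

mp3 = httpx.get(f"{BASE}/transcripts/{tid}/audio/mp3")
assert mp3.status_code == 200 and mp3.headers["content-type"] == "audio/mp3"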

View File

@@ -1,63 +0,0 @@
import pytest
from unittest.mock import patch
@pytest.mark.asyncio
async def test_basic_rtc_server(aiohttp_server, event_loop):
# goal is to start the server, and send rtc audio to it
# validate the events received
import argparse
import json
from pathlib import Path
from reflector.server import create_app
from reflector.stream_client import StreamClient
from reflector.models import TitleSummaryOutput
from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
# customize settings to have a mock LLM server
with patch("reflector.server.get_title_and_summary") as mock_llm:
# any response from mock_llm will be test topic
mock_llm.return_value = TitleSummaryOutput(["topic_test"])
# create the server
app = create_app()
server = await aiohttp_server(app)
url = f"http://{server.host}:{server.port}/offer"
# create signaling
parser = argparse.ArgumentParser()
add_signaling_arguments(parser)
args = parser.parse_args(["-s", "tcp-socket"])
signaling = create_signaling(args)
# create the client
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
client = StreamClient(signaling, url=url, play_from=path.as_posix())
await client.start()
# we just want the first transcription
# and topic update messages
marks = {
"SHOW_TRANSCRIPTION": False,
"UPDATE_TOPICS": False,
}
async for rawmsg in client.get_reader():
msg = json.loads(rawmsg)
cmd = msg["cmd"]
if cmd == "SHOW_TRANSCRIPTION":
assert "text" in msg
assert "want to share my incredible experience" in msg["text"]
elif cmd == "UPDATE_TOPICS":
assert "topics" in msg
assert "topic_test" in msg["topics"]
marks[cmd] = True
# break if we have all the events we need
if all(marks.values()):
break
# stop the server
await server.close()
await client.stop()

View File

@@ -71,11 +71,15 @@ async def dummy_llm():
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
# to be able to connect with aiortc
from reflector.settings import settings
settings.DATA_DIR = Path(tmpdir)
# start server
host = "127.0.0.1"
port = 1255
@@ -189,3 +193,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"
# check that audio is available
resp = await ac.get(f"/transcripts/{tid}/audio")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/wav"
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mp3"