diff --git a/server/reflector/models.py b/server/reflector/models.py deleted file mode 100644 index d1aaaa1e..00000000 --- a/server/reflector/models.py +++ /dev/null @@ -1,211 +0,0 @@ -""" -Collection of data classes for streamlining and rigidly structuring -the input and output parameters of functions -""" - -import datetime -from dataclasses import dataclass -from typing import List -from sortedcontainers import SortedDict - -import av - - -@dataclass -class TitleSummaryInput: - """ - Data class for the input to generate title and summaries. - The outcome will be used to send query to the LLM for processing. - """ - - input_text = str - transcribed_time = float - prompt = str - data = dict - - def __init__(self, transcribed_time, input_text=""): - self.input_text = input_text - self.transcribed_time = transcribed_time - self.prompt = f""" - ### Human: - Create a JSON object as response.The JSON object must have 2 fields: - i) title and ii) summary.For the title field,generate a short title - for the given text. For the summary field, summarize the given text - in three sentences. - - {self.input_text} - - ### Assistant: - """ - self.data = {"prompt": self.prompt} - self.headers = {"Content-Type": "application/json"} - - -@dataclass -class IncrementalResult: - """ - Data class for the result of generating one title and summaries. - Defines how a single "topic" looks like. - """ - - title = str - description = str - transcript = str - timestamp = str - - def __init__(self, title, desc, transcript, timestamp): - self.title = title - self.description = desc - self.transcript = transcript - self.timestamp = timestamp - - -@dataclass -class TitleSummaryOutput: - """ - Data class for the result of all generated titles and summaries. - The result will be sent back to the client - """ - - cmd = str - topics = List[IncrementalResult] - - def __init__(self, inc_responses): - self.topics = inc_responses - self.cmd = "UPDATE_TOPICS" - - def get_result(self) -> dict: - """ - Return the result dict for displaying the transcription - :return: - """ - return {"cmd": self.cmd, "topics": self.topics} - - -@dataclass -class ParseLLMResult: - """ - Data class to parse the result returned by the LLM while generating title - and summaries. The result will be sent back to the client. - """ - - title = str - description = str - transcript = str - timestamp = str - - def __init__(self, param: TitleSummaryInput, output: dict): - self.title = output["title"] - self.transcript = param.input_text - self.description = output.pop("summary") - self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time))) - - def get_result(self) -> dict: - """ - Return the result dict after parsing the response from LLM - :return: - """ - return { - "title": self.title, - "description": self.description, - "transcript": self.transcript, - "timestamp": self.timestamp, - } - - -@dataclass -class TranscriptionInput: - """ - Data class to define the input to the transcription function - AudioFrames -> input - """ - - frames = List[av.audio.frame.AudioFrame] - - def __init__(self, frames): - self.frames = frames - - -@dataclass -class TranscriptionOutput: - """ - Dataclass to define the result of the transcription function. 
- The result will be sent back to the client - """ - - cmd = str - result_text = str - - def __init__(self, result_text): - self.cmd = "SHOW_TRANSCRIPTION" - self.result_text = result_text - - def get_result(self) -> dict: - """ - Return the result dict for displaying the transcription - :return: - """ - return {"cmd": self.cmd, "text": self.result_text} - - -@dataclass -class FinalSummaryResult: - """ - Dataclass to define the result of the final summary function. - The result will be sent back to the client. - """ - - cmd = str - final_summary = str - duration = str - - def __init__(self, final_summary, time): - self.duration = str(datetime.timedelta(seconds=round(time))) - self.final_summary = final_summary - self.cmd = "DISPLAY_FINAL_SUMMARY" - - def get_result(self) -> dict: - """ - Return the result dict for displaying the final summary - :return: - """ - return { - "cmd": self.cmd, - "duration": self.duration, - "summary": self.final_summary, - } - - -class BlackListedMessages: - """ - Class to hold the blacklisted messages. These messages should be filtered - out and not sent back to the client as part of the transcription. - """ - - messages = [ - " Thank you.", - " See you next time!", - " Thank you for watching!", - " Bye!", - " And that's what I'm talking about.", - ] - - -@dataclass -class TranscriptionContext: - transcription_text: str - last_transcribed_time: float - incremental_responses: List[IncrementalResult] - sorted_transcripts: dict - data_channel: None # FIXME - logger: None - status: str - - def __init__(self, logger): - self.transcription_text = "" - self.last_transcribed_time = 0.0 - self.incremental_responses = [] - self.data_channel = None - self.sorted_transcripts = SortedDict() - self.status = "idle" - self.logger = logger diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py index da890513..8a926f30 100644 --- a/server/reflector/processors/__init__.py +++ b/server/reflector/processors/__init__.py @@ -1,5 +1,6 @@ from .base import Processor, ThreadedProcessor, Pipeline # noqa: F401 from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary # noqa: F401 +from .audio_file_writer import AudioFileWriterProcessor # noqa: F401 from .audio_chunker import AudioChunkerProcessor # noqa: F401 from .audio_merge import AudioMergeProcessor # noqa: F401 from .audio_transcript import AudioTranscriptProcessor # noqa: F401 diff --git a/server/reflector/processors/audio_file_writer.py b/server/reflector/processors/audio_file_writer.py new file mode 100644 index 00000000..d67db65e --- /dev/null +++ b/server/reflector/processors/audio_file_writer.py @@ -0,0 +1,39 @@ +from reflector.processors.base import Processor +import av +from pathlib import Path + + +class AudioFileWriterProcessor(Processor): + """ + Write audio frames to a file. 
+ """ + + INPUT_TYPE = av.AudioFrame + OUTPUT_TYPE = av.AudioFrame + + def __init__(self, path: Path | str): + super().__init__() + if isinstance(path, str): + path = Path(path) + self.path = path + self.out_container = None + self.out_stream = None + + async def _push(self, data: av.AudioFrame): + if not self.out_container: + self.path.parent.mkdir(parents=True, exist_ok=True) + self.out_container = av.open(self.path.as_posix(), "w", format="wav") + self.out_stream = self.out_container.add_stream( + "pcm_s16le", rate=data.sample_rate + ) + for packet in self.out_stream.encode(data): + self.out_container.mux(packet) + await self.emit(data) + + async def _flush(self): + if self.out_container: + for packet in self.out_stream.encode(None): + self.out_container.mux(packet) + self.out_container.close() + self.out_container = None + self.out_stream = None diff --git a/server/reflector/processors/audio_merge.py b/server/reflector/processors/audio_merge.py index 37734a53..34c1741e 100644 --- a/server/reflector/processors/audio_merge.py +++ b/server/reflector/processors/audio_merge.py @@ -3,7 +3,6 @@ from reflector.processors.types import AudioFile from time import monotonic_ns from uuid import uuid4 import io -import wave import av @@ -28,12 +27,16 @@ class AudioMergeProcessor(Processor): # create audio file uu = uuid4().hex fd = io.BytesIO() - with wave.open(fd, "wb") as wf: - wf.setnchannels(channels) - wf.setsampwidth(sample_width) - wf.setframerate(sample_rate) - for frame in data: - wf.writeframes(frame.to_ndarray().tobytes()) + + out_container = av.open(fd, "w", format="wav") + out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate) + for frame in data: + for packet in out_stream.encode(frame): + out_container.mux(packet) + for packet in out_stream.encode(None): + out_container.mux(packet) + out_container.close() + fd.seek(0) # emit audio file audiofile = AudioFile( @@ -44,4 +47,5 @@ class AudioMergeProcessor(Processor): sample_width=sample_width, timestamp=data[0].pts * data[0].time_base, ) + await self.emit(audiofile) diff --git a/server/reflector/server.py b/server/reflector/server.py deleted file mode 100644 index 3b09efe4..00000000 --- a/server/reflector/server.py +++ /dev/null @@ -1,381 +0,0 @@ -import argparse -import asyncio -import datetime -import json -import os -import wave -import uuid -from concurrent.futures import ThreadPoolExecutor -from typing import NoReturn, Union - -import aiohttp_cors -import av -import requests -from aiohttp import web -from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription -from aiortc.contrib.media import MediaRelay -from faster_whisper import WhisperModel - -from reflector.models import ( - BlackListedMessages, - FinalSummaryResult, - ParseLLMResult, - TitleSummaryInput, - TitleSummaryOutput, - TranscriptionInput, - TranscriptionOutput, - TranscriptionContext, -) -from reflector.logger import logger -from reflector.utils.run_utils import run_in_executor -from reflector.settings import settings - -# WebRTC components -pcs = set() -relay = MediaRelay() -executor = ThreadPoolExecutor() - -# Transcription model -model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12) - -# LLM -LLM_URL = settings.LLM_URL -if not LLM_URL: - assert settings.LLM_BACKEND == "oobabooga" - LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate" -logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}") - - -def parse_llm_output( - param: TitleSummaryInput, response: requests.Response -) -> 
Union[None, ParseLLMResult]: - """ - Function to parse the LLM response - :param param: - :param response: - :return: - """ - try: - output = json.loads(response.json()["results"][0]["text"]) - return ParseLLMResult(param, output) - except Exception: - logger.exception("Exception while parsing LLM output") - return None - - -def get_title_and_summary( - ctx: TranscriptionContext, param: TitleSummaryInput -) -> Union[None, TitleSummaryOutput]: - """ - From the input provided (transcript), query the LLM to generate - topics and summaries - :param param: - :return: - """ - logger.info("Generating title and summary") - - # TODO : Handle unexpected output formats from the model - try: - response = requests.post(LLM_URL, headers=param.headers, json=param.data) - output = parse_llm_output(param, response) - if output: - result = output.get_result() - ctx.incremental_responses.append(result) - return TitleSummaryOutput(ctx.incremental_responses) - except Exception: - logger.exception("Exception while generating title and summary") - return None - - -def channel_send(channel, message: str) -> NoReturn: - """ - Send text messages via the data channel - :param channel: - :param message: - :return: - """ - if channel: - channel.send(message) - - -def channel_send_increment( - channel, param: Union[FinalSummaryResult, TitleSummaryOutput] -) -> NoReturn: - """ - Send the incremental topics and summaries via the data channel - :param channel: - :param param: - :return: - """ - if channel and param: - message = param.get_result() - channel.send(json.dumps(message)) - - -def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn: - """ - Send the transcription result via the data channel - :param channel: - :return: - """ - if not ctx.data_channel: - return - try: - least_time = next(iter(ctx.sorted_transcripts)) - message = ctx.sorted_transcripts[least_time].get_result() - if message: - del ctx.sorted_transcripts[least_time] - if message["text"] not in BlackListedMessages.messages: - ctx.data_channel.send(json.dumps(message)) - # Due to exceptions if one of the earlier batches can't return - # a transcript, we don't want to be stuck waiting for the result - # With the threshold size of 3, we pop the first(lost) element - else: - if len(ctx.sorted_transcripts) >= 3: - del ctx.sorted_transcripts[least_time] - except Exception: - logger.exception("Exception while sending transcript") - - -def get_transcription( - ctx: TranscriptionContext, input_frames: TranscriptionInput -) -> Union[None, TranscriptionOutput]: - """ - From the collected audio frames create transcription by inferring from - the chosen transcription model - :param input_frames: - :return: - """ - ctx.logger.info("Transcribing..") - ctx.sorted_transcripts[input_frames.frames[0].time] = None - - # TODO: Find cleaner way, watch "no transcription" issue below - # Passing IO objects instead of temporary files throws an error - # Passing ndarray (type casted with float) does not give any - # transcription. 
Refer issue, - # https://github.com/guillaumekln/faster-whisper/issues/369 - audio_file = "test" + str(datetime.datetime.now()) - wf = wave.open(audio_file, "wb") - wf.setnchannels(settings.AUDIO_CHANNELS) - wf.setframerate(settings.AUDIO_SAMPLING_RATE) - wf.setsampwidth(settings.AUDIO_SAMPLING_WIDTH) - - for frame in input_frames.frames: - wf.writeframes(b"".join(frame.to_ndarray())) - wf.close() - - result_text = "" - - try: - segments, _ = model.transcribe( - audio_file, - language="en", - beam_size=5, - vad_filter=True, - vad_parameters={"min_silence_duration_ms": 500}, - ) - os.remove(audio_file) - segments = list(segments) - result_text = "" - duration = 0.0 - for segment in segments: - result_text += segment.text - start_time = segment.start - end_time = segment.end - if not segment.start: - start_time = 0.0 - if not segment.end: - end_time = 5.5 - duration += end_time - start_time - - ctx.last_transcribed_time += duration - ctx.transcription_text += result_text - - except Exception: - logger.exception("Exception while transcribing") - - result = TranscriptionOutput(result_text) - ctx.sorted_transcripts[input_frames.frames[0].time] = result - return result - - -def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult: - """ - Collate the incremental summaries generated so far and return as the final - summary - :return: - """ - final_summary = "" - - # Collate inc summaries - for topic in ctx.incremental_responses: - final_summary += topic["description"] - - response = FinalSummaryResult(final_summary, ctx.last_transcribed_time) - - with open( - "./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8" - ) as file: - file.write(json.dumps(ctx.incremental_responses)) - - return response - - -class AudioStreamTrack(MediaStreamTrack): - """ - An audio stream track. 
- """ - - kind = "audio" - - def __init__(self, ctx: TranscriptionContext, track): - super().__init__() - self.ctx = ctx - self.track = track - self.audio_buffer = av.AudioFifo() - - async def recv(self) -> av.audio.frame.AudioFrame: - ctx = self.ctx - frame = await self.track.recv() - self.audio_buffer.write(frame) - - if local_frames := self.audio_buffer.read_many( - settings.AUDIO_BUFFER_SIZE, partial=False - ): - whisper_result = run_in_executor( - get_transcription, - ctx, - TranscriptionInput(local_frames), - executor=executor, - ) - whisper_result.add_done_callback( - lambda f: channel_send_transcript(ctx) if f.result() else None - ) - - if len(ctx.transcription_text) > 25: - llm_input_text = ctx.transcription_text - ctx.transcription_text = "" - param = TitleSummaryInput( - input_text=llm_input_text, transcribed_time=ctx.last_transcribed_time - ) - llm_result = run_in_executor( - get_title_and_summary, ctx, param, executor=executor - ) - llm_result.add_done_callback( - lambda f: channel_send_increment(ctx.data_channel, llm_result.result()) - if f.result() - else None - ) - return frame - - -async def offer(request: requests.Request) -> web.Response: - """ - Establish the WebRTC connection with the client - :param request: - :return: - """ - params = await request.json() - offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) - - # client identification - peername = request.transport.get_extra_info("peername") - if peername is not None: - clientid = f"{peername[0]}:{peername[1]}" - else: - clientid = uuid.uuid4() - - # create a context for the whole rtc transaction - # add a customised logger to the context - ctx = TranscriptionContext(logger=logger.bind(client=clientid)) - - # handle RTC peer connection - pc = RTCPeerConnection() - pcs.add(pc) - - @pc.on("datachannel") - def on_datachannel(channel) -> NoReturn: - ctx.data_channel = channel - ctx.logger = ctx.logger.bind(channel=channel.label) - ctx.logger.info("Channel created by remote party") - - @channel.on("message") - def on_message(message: str) -> NoReturn: - ctx.logger.info(f"Message: {message}") - if json.loads(message)["cmd"] == "STOP": - # Placeholder final summary - response = get_final_summary_response() - channel_send_increment(channel, response) - # To-do Add code to stop connection from server side here - # But have to handshake with client once - - if isinstance(message, str) and message.startswith("ping"): - channel_send(channel, "pong" + message[4:]) - - @pc.on("connectionstatechange") - async def on_connectionstatechange() -> NoReturn: - ctx.logger.info(f"Connection state changed: {pc.connectionState}") - if pc.connectionState == "failed": - await pc.close() - pcs.discard(pc) - - @pc.on("track") - def on_track(track) -> NoReturn: - ctx.logger.info(f"Track {track.kind} received") - pc.addTrack(AudioStreamTrack(ctx, relay.subscribe(track))) - - await pc.setRemoteDescription(offer) - - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - return web.Response( - content_type="application/json", - text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} - ), - ) - - -async def on_shutdown(application: web.Application) -> NoReturn: - """ - On shutdown, the coroutines that shutdown client connections are - executed - :param application: - :return: - """ - coroutines = [pc.close() for pc in pcs] - await asyncio.gather(*coroutines) - pcs.clear() - - -def create_app() -> web.Application: - """ - Create the web application - """ - app = web.Application() - 
cors = aiohttp_cors.setup( -        app, -        defaults={ -            "*": aiohttp_cors.ResourceOptions( -                allow_credentials=True, expose_headers="*", allow_headers="*" -            ) -        }, -    ) - -    offer_resource = cors.add(app.router.add_resource("/offer")) -    cors.add(offer_resource.add_route("POST", offer)) -    app.on_shutdown.append(on_shutdown) -    return app - - -if __name__ == "__main__": -    parser = argparse.ArgumentParser(description="WebRTC based server for Reflector") -    parser.add_argument( -        "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)" -    ) -    parser.add_argument( -        "--port", type=int, default=1250, help="Server port (def: 1250)" -    ) -    args = parser.parse_args() -    app = create_app() -    web.run_app(app, access_log=None, host=args.host, port=args.port) diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 81f817da..957acb0b 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -9,6 +9,9 @@ class Settings(BaseSettings): # Database DATABASE_URL: str = "sqlite:///./reflector.sqlite3" +    # local data directory (audio for now) +    DATA_DIR: str = "./data" + # Whisper WHISPER_MODEL_SIZE: str = "tiny" WHISPER_REAL_TIME_MODEL_SIZE: str = "tiny" diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index aef00580..f28eb021 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -2,17 +2,18 @@ import asyncio from fastapi import Request, APIRouter from reflector.events import subscribers_shutdown from pydantic import BaseModel -from reflector.models import TranscriptionContext from reflector.logger import logger from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack from json import loads, dumps from enum import StrEnum +from pathlib import Path import av from reflector.processors import ( Pipeline, AudioChunkerProcessor, AudioMergeProcessor, AudioTranscriptAutoProcessor, +    AudioFileWriterProcessor, TranscriptLinerProcessor, TranscriptTopicDetectorProcessor, TranscriptFinalSummaryProcessor, @@ -25,6 +26,15 @@ sessions = [] router = APIRouter() +class TranscriptionContext(object): +    def __init__(self, logger): +        self.logger = logger +        self.pipeline = None +        self.data_channel = None +        self.status = "idle" +        self.topics = [] + + class AudioStreamTrack(MediaStreamTrack): """ An audio stream track.
@@ -64,7 +74,11 @@ async def rtc_offer_base( -    params: RtcOffer, request: Request, event_callback=None, event_callback_args=None +    params: RtcOffer, +    request: Request, +    event_callback=None, +    event_callback_args=None, +    audio_filename: Path | None = None, ): # build an rtc session offer = RTCSessionDescription(sdp=params.sdp, type=params.type) @@ -73,7 +87,6 @@ async def rtc_offer_base( peername = request.client clientid = f"{peername[0]}:{peername[1]}" ctx = TranscriptionContext(logger=logger.bind(client=clientid)) -    ctx.topics = [] async def update_status(status: str): changed = ctx.status != status @@ -151,14 +164,18 @@ async def rtc_offer_base( # create a context for the whole rtc transaction # add a customised logger to the context -    ctx.pipeline = Pipeline( +    processors = [] +    if audio_filename is not None: +        processors += [AudioFileWriterProcessor(path=audio_filename)] +    processors += [ AudioChunkerProcessor(), AudioMergeProcessor(), AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript), TranscriptLinerProcessor(), TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic), TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary), -    ) +    ] +    ctx.pipeline = Pipeline(*processors) # FIXME: warmup is not working well yet # await ctx.pipeline.warmup() diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index f2a8425e..6f952938 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -5,14 +5,21 @@ from fastapi import ( WebSocket, WebSocketDisconnect, ) +from fastapi.responses import FileResponse +from starlette.concurrency import run_in_threadpool from pydantic import BaseModel, Field from uuid import uuid4 from datetime import datetime from fastapi_pagination import Page, paginate from reflector.logger import logger from reflector.db import database, transcripts +from reflector.settings import settings from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent from typing import Optional +from pathlib import Path +from tempfile import NamedTemporaryFile +import av +import shutil router = APIRouter() @@ -81,6 +87,44 @@ class Transcript(BaseModel): def topics_dump(self, mode="json"): return [topic.model_dump(mode=mode) for topic in self.topics] +    def convert_audio_to_mp3(self): +        fn = self.audio_mp3_filename +        if fn.exists(): +            return + +        logger.info(f"Converting audio to mp3: {self.audio_filename}") +        inp = av.open(self.audio_filename.as_posix(), "r") + +        # create temporary file for mp3 +        with NamedTemporaryFile(suffix=".mp3", delete=False) as tmp: +            out = av.open(tmp.name, "w") +            stream = out.add_stream("mp3") +            for frame in inp.decode(audio=0): +                frame.pts = None +                for packet in stream.encode(frame): +                    out.mux(packet) +            for packet in stream.encode(None): +                out.mux(packet) +            out.close() + +            # move temporary file to final location +            Path(tmp.name).rename(fn) + +    def unlink(self): +        shutil.rmtree(self.data_path, ignore_errors=True)  # data_path is a directory + +    @property +    def data_path(self): +        return Path(settings.DATA_DIR) / self.id + +    @property +    def audio_filename(self): +        return self.data_path / "audio.wav" + +    @property +    def audio_mp3_filename(self): +        return self.data_path / "audio.mp3" + class TranscriptController: async def get_all(self) -> list[Transcript]: @@ -112,6 +156,10 @@ class TranscriptController: setattr(transcript, key, value) async def remove_by_id(self, transcript_id: str) -> None: +        transcript = await self.get_by_id(transcript_id) +        if not transcript: +            return +
transcript.unlink() query = transcripts.delete().where(transcripts.c.id == transcript_id) await database.execute(query) @@ -202,8 +250,24 @@ async def transcript_get_audio(transcript_id: str): if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - # TODO: Implement audio generation - return HTTPException(status_code=500, detail="Not implemented") + if not transcript.audio_filename.exists(): + raise HTTPException(status_code=404, detail="Audio not found") + + return FileResponse(transcript.audio_filename, media_type="audio/wav") + + +@router.get("/transcripts/{transcript_id}/audio/mp3") +async def transcript_get_audio_mp3(transcript_id: str): + transcript = await transcripts_controller.get_by_id(transcript_id) + if not transcript: + raise HTTPException(status_code=404, detail="Transcript not found") + + if not transcript.audio_filename.exists(): + raise HTTPException(status_code=404, detail="Audio not found") + + await run_in_threadpool(transcript.convert_audio_to_mp3) + + return FileResponse(transcript.audio_mp3_filename, media_type="audio/mp3") @router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) @@ -371,4 +435,5 @@ async def transcript_record_webrtc( request, event_callback=handle_rtc_event, event_callback_args=transcript_id, + audio_filename=transcript.audio_filename, ) diff --git a/server/tests/test_basic_rtc.py b/server/tests/test_basic_rtc.py deleted file mode 100644 index 93f33648..00000000 --- a/server/tests/test_basic_rtc.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -from unittest.mock import patch - - -@pytest.mark.asyncio -async def test_basic_rtc_server(aiohttp_server, event_loop): - # goal is to start the server, and send rtc audio to it - # validate the events received - import argparse - import json - from pathlib import Path - from reflector.server import create_app - from reflector.stream_client import StreamClient - from reflector.models import TitleSummaryOutput - from aiortc.contrib.signaling import add_signaling_arguments, create_signaling - - # customize settings to have a mock LLM server - with patch("reflector.server.get_title_and_summary") as mock_llm: - # any response from mock_llm will be test topic - mock_llm.return_value = TitleSummaryOutput(["topic_test"]) - - # create the server - app = create_app() - server = await aiohttp_server(app) - url = f"http://{server.host}:{server.port}/offer" - - # create signaling - parser = argparse.ArgumentParser() - add_signaling_arguments(parser) - args = parser.parse_args(["-s", "tcp-socket"]) - signaling = create_signaling(args) - - # create the client - path = Path(__file__).parent / "records" / "test_mathieu_hello.wav" - client = StreamClient(signaling, url=url, play_from=path.as_posix()) - await client.start() - - # we just want the first transcription - # and topic update messages - - marks = { - "SHOW_TRANSCRIPTION": False, - "UPDATE_TOPICS": False, - } - - async for rawmsg in client.get_reader(): - msg = json.loads(rawmsg) - cmd = msg["cmd"] - if cmd == "SHOW_TRANSCRIPTION": - assert "text" in msg - assert "want to share my incredible experience" in msg["text"] - elif cmd == "UPDATE_TOPICS": - assert "topics" in msg - assert "topic_test" in msg["topics"] - marks[cmd] = True - - # break if we have all the events we need - if all(marks.values()): - break - - # stop the server - await server.close() - await client.stop() diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 23c7813f..5555d195 100644 --- 
a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -71,11 +71,15 @@ async def dummy_llm(): @pytest.mark.asyncio -async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): +async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm): # goal: start the server, exchange RTC, receive websocket events # because of that, we need to start the server in a thread # to be able to connect with aiortc + from reflector.settings import settings + + settings.DATA_DIR = Path(tmpdir) + # start server host = "127.0.0.1" port = 1255 @@ -189,3 +193,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): resp = await ac.get(f"/transcripts/{tid}") assert resp.status_code == 200 assert resp.json()["status"] == "ended" + + # check that audio is available + resp = await ac.get(f"/transcripts/{tid}/audio") + assert resp.status_code == 200 + assert resp.headers["Content-Type"] == "audio/wav" + + # check that audio/mp3 is available + resp = await ac.get(f"/transcripts/{tid}/audio/mp3") + assert resp.status_code == 200 + assert resp.headers["Content-Type"] == "audio/mp3"
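Illustrative only, not part of the diff: once a recording has passed through the pipeline above, the audio written by AudioFileWriterProcessor can be fetched back through the two routes added in transcripts.py. This minimal client sketch assumes the API is reachable at a local base URL and that tid is an existing transcript id; both values are placeholders. The mp3 variant is produced lazily by convert_audio_to_mp3 on the first request and cached next to the wav file.

import httpx

BASE = "http://127.0.0.1:1250"  # placeholder base URL, adjust to your deployment
tid = "<transcript-id>"         # placeholder transcript id

# raw WAV recorded by AudioFileWriterProcessor during the RTC session
wav = httpx.get(f"{BASE}/transcripts/{tid}/audio")
assert wav.status_code == 200
assert wav.headers["content-type"] == "audio/wav"

# mp3 is converted from the wav on first request, then served from disk
mp3 = httpx.get(f"{BASE}/transcripts/{tid}/audio/mp3")
assert mp3.status_code == 200
assert mp3.headers["content-type"] == "audio/mp3"

with open("meeting.mp3", "wb") as fd:
    fd.write(mp3.content)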