diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py
index da890513..8a926f30 100644
--- a/server/reflector/processors/__init__.py
+++ b/server/reflector/processors/__init__.py
@@ -1,5 +1,6 @@
 from .base import Processor, ThreadedProcessor, Pipeline # noqa: F401
 from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary # noqa: F401
+from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
 from .audio_chunker import AudioChunkerProcessor # noqa: F401
 from .audio_merge import AudioMergeProcessor # noqa: F401
 from .audio_transcript import AudioTranscriptProcessor # noqa: F401
diff --git a/server/reflector/processors/audio_file_writer.py b/server/reflector/processors/audio_file_writer.py
new file mode 100644
index 00000000..c597f81d
--- /dev/null
+++ b/server/reflector/processors/audio_file_writer.py
@@ -0,0 +1,35 @@
+from reflector.processors.base import Processor
+import av
+import wave
+from pathlib import Path
+
+
+class AudioFileWriterProcessor(Processor):
+    """
+    Write audio frames to a file.
+    """
+
+    INPUT_TYPE = av.AudioFrame
+    OUTPUT_TYPE = av.AudioFrame
+
+    def __init__(self, path: Path | str):
+        super().__init__()
+        if isinstance(path, str):
+            path = Path(path)
+        self.path = path
+        self.fd = None
+
+    async def _push(self, data: av.AudioFrame):
+        if not self.fd:
+            self.path.parent.mkdir(parents=True, exist_ok=True)
+            self.fd = wave.open(self.path.as_posix(), "wb")
+            self.fd.setnchannels(len(data.layout.channels))
+            self.fd.setsampwidth(data.format.bytes)
+            self.fd.setframerate(data.sample_rate)
+        self.fd.writeframes(data.to_ndarray().tobytes())
+        await self.emit(data)
+
+    async def _flush(self):
+        if self.fd:
+            self.fd.close()
+            self.fd = None
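The processor can be exercised on its own, outside the pipeline. A minimal smoke-test sketch, assuming the `Processor` base class exposes public `push()`/`flush()` entry points that delegate to `_push()`/`_flush()` (only the underscore variants appear in this diff), and building a synthetic frame with PyAV:

```python
import asyncio

import av
import numpy as np

from reflector.processors import AudioFileWriterProcessor


async def main():
    writer = AudioFileWriterProcessor(path="/tmp/smoke.wav")
    # one second of 16 kHz mono silence, packed s16, matching what the
    # wave-based writer expects
    samples = np.zeros((1, 16000), dtype=np.int16)
    frame = av.AudioFrame.from_ndarray(samples, format="s16", layout="mono")
    frame.sample_rate = 16000
    await writer.push(frame)  # the first frame lazily opens the wav file
    await writer.flush()  # closes the file


asyncio.run(main())
```

Note that `_push` derives the wav header (channel count, sample width, rate) from the first frame it sees, so every frame pushed to one writer must share the same format.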
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index e776875b..0787b466 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -9,6 +9,9 @@ class Settings(BaseSettings):
     # Database
     DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
 
+    # local data directory (audio for now)
+    DATA_DIR: str = "./data"
+
     # Whisper
     WHISPER_MODEL_SIZE: str = "tiny"
     WHISPER_REAL_TIME_MODEL_SIZE: str = "tiny"
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index aef00580..c0944a82 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -7,12 +7,14 @@ from reflector.logger import logger
 from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
 from json import loads, dumps
 from enum import StrEnum
+from pathlib import Path
 import av
 from reflector.processors import (
     Pipeline,
     AudioChunkerProcessor,
     AudioMergeProcessor,
     AudioTranscriptAutoProcessor,
+    AudioFileWriterProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
     TranscriptFinalSummaryProcessor,
@@ -64,7 +66,11 @@ class PipelineEvent(StrEnum):
 
 
 async def rtc_offer_base(
-    params: RtcOffer, request: Request, event_callback=None, event_callback_args=None
+    params: RtcOffer,
+    request: Request,
+    event_callback=None,
+    event_callback_args=None,
+    audio_filename: Path | None = None,
 ):
     # build an rtc session
     offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -151,14 +157,18 @@
 
     # create a context for the whole rtc transaction
     # add a customised logger to the context
-    ctx.pipeline = Pipeline(
+    processors = []
+    if audio_filename is not None:
+        processors += [AudioFileWriterProcessor(path=audio_filename)]
+    processors += [
         AudioChunkerProcessor(),
         AudioMergeProcessor(),
         AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
         TranscriptLinerProcessor(),
         TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
         TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
-    )
+    ]
+    ctx.pipeline = Pipeline(*processors)
 
     # FIXME: warmup is not working well yet
     # await ctx.pipeline.warmup()
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index f2a8425e..6f952938 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -5,14 +5,20 @@ from fastapi import (
     WebSocket,
     WebSocketDisconnect,
 )
+from fastapi.responses import FileResponse
+from starlette.concurrency import run_in_threadpool
 from pydantic import BaseModel, Field
 from uuid import uuid4
 from datetime import datetime
 from fastapi_pagination import Page, paginate
 from reflector.logger import logger
 from reflector.db import database, transcripts
+from reflector.settings import settings
 from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
 from typing import Optional
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+import av
 
 
 router = APIRouter()
@@ -81,6 +87,45 @@ class Transcript(BaseModel):
     def topics_dump(self, mode="json"):
         return [topic.model_dump(mode=mode) for topic in self.topics]
 
+    def convert_audio_to_mp3(self):
+        fn = self.audio_mp3_filename
+        if fn.exists():
+            return
+
+        logger.info(f"Converting audio to mp3: {self.audio_filename}")
+        inp = av.open(self.audio_filename.as_posix(), "r")
+
+        # create temporary file for mp3
+        with NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
+            out = av.open(tmp.name, "w")
+            stream = out.add_stream("mp3")
+            for frame in inp.decode(audio=0):
+                frame.pts = None
+                for packet in stream.encode(frame):
+                    out.mux(packet)
+            for packet in stream.encode(None):
+                out.mux(packet)
+            out.close()
+
+        # move temporary file to final location
+        Path(tmp.name).rename(fn)
+
+    def unlink(self):
+        self.audio_filename.unlink(missing_ok=True)
+        self.audio_mp3_filename.unlink(missing_ok=True)
+
+    @property
+    def data_path(self):
+        return Path(settings.DATA_DIR) / self.id
+
+    @property
+    def audio_filename(self):
+        return self.data_path / "audio.wav"
+
+    @property
+    def audio_mp3_filename(self):
+        return self.data_path / "audio.mp3"
+
 
 class TranscriptController:
     async def get_all(self) -> list[Transcript]:
@@ -112,6 +157,10 @@
             setattr(transcript, key, value)
 
     async def remove_by_id(self, transcript_id: str) -> None:
+        transcript = await self.get_by_id(transcript_id)
+        if not transcript:
+            return
+        transcript.unlink()
         query = transcripts.delete().where(transcripts.c.id == transcript_id)
         await database.execute(query)
 
@@ -202,8 +251,24 @@ async def transcript_get_audio(transcript_id: str):
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")
 
-    # TODO: Implement audio generation
-    return HTTPException(status_code=500, detail="Not implemented")
+    if not transcript.audio_filename.exists():
+        raise HTTPException(status_code=404, detail="Audio not found")
+
+    return FileResponse(transcript.audio_filename, media_type="audio/wav")
+
+
+@router.get("/transcripts/{transcript_id}/audio/mp3")
+async def transcript_get_audio_mp3(transcript_id: str):
+    transcript = await transcripts_controller.get_by_id(transcript_id)
+    if not transcript:
+        raise HTTPException(status_code=404, detail="Transcript not found")
+
+    if not transcript.audio_filename.exists():
+        raise HTTPException(status_code=404, detail="Audio not found")
+
+    await run_in_threadpool(transcript.convert_audio_to_mp3)
+
+    return FileResponse(transcript.audio_mp3_filename, media_type="audio/mp3")
 
 
 @router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
@@ -371,4 +436,5 @@ async def transcript_record_webrtc(
         request,
         event_callback=handle_rtc_event,
         event_callback_args=transcript_id,
+        audio_filename=transcript.audio_filename,
     )
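Client-side, the two new endpoints look like this (a sketch using httpx; the base URL and transcript id are placeholders). The first GET on the mp3 route runs the blocking PyAV encode via `run_in_threadpool` so the event loop stays free; later GETs serve the cached `audio.mp3`:

```python
import httpx

BASE_URL = "http://localhost:8000"  # placeholder, adjust to your deployment
transcript_id = "..."  # an existing transcript id

with httpx.Client(base_url=BASE_URL) as client:
    # raw capture written by AudioFileWriterProcessor (DATA_DIR/<id>/audio.wav)
    wav = client.get(f"/transcripts/{transcript_id}/audio")
    assert wav.headers["content-type"] == "audio/wav"

    # first call converts wav -> mp3 in a threadpool, later calls hit the cache
    mp3 = client.get(f"/transcripts/{transcript_id}/audio/mp3")
    assert mp3.headers["content-type"] == "audio/mp3"
```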
found") + + await run_in_threadpool(transcript.convert_audio_to_mp3) + + return FileResponse(transcript.audio_mp3_filename, media_type="audio/mp3") @router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) @@ -371,4 +435,5 @@ async def transcript_record_webrtc( request, event_callback=handle_rtc_event, event_callback_args=transcript_id, + audio_filename=transcript.audio_filename, ) diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 70ee209b..f38728c2 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -70,11 +70,15 @@ async def dummy_llm(): @pytest.mark.asyncio -async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): +async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm): # goal: start the server, exchange RTC, receive websocket events # because of that, we need to start the server in a thread # to be able to connect with aiortc + from reflector.settings import settings + + settings.DATA_DIR = Path(tmpdir) + # start server host = "127.0.0.1" port = 1255 @@ -188,3 +192,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): resp = await ac.get(f"/transcripts/{tid}") assert resp.status_code == 200 assert resp.json()["status"] == "ended" + + # check that audio is available + resp = await ac.get(f"/transcripts/{tid}/audio") + assert resp.status_code == 200 + assert resp.headers["Content-Type"] == "audio/wav" + + # check that audio/mp3 is available + resp = await ac.get(f"/transcripts/{tid}/audio/mp3") + assert resp.status_code == 200 + assert resp.headers["Content-Type"] == "audio/mp3"