diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 6ac2e32a..f0dbc277 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -10,7 +10,6 @@ from pydantic import BaseModel, Field from reflector.db import database, metadata from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings -from reflector.utils.audio_waveform import get_audio_waveform transcripts = sqlalchemy.Table( "transcript", @@ -79,6 +78,14 @@ class TranscriptFinalTitle(BaseModel): title: str +class TranscriptDuration(BaseModel): + duration: float + + +class TranscriptWaveform(BaseModel): + waveform: list[float] + + class TranscriptEvent(BaseModel): event: str data: dict @@ -118,22 +125,6 @@ class Transcript(BaseModel): def topics_dump(self, mode="json"): return [topic.model_dump(mode=mode) for topic in self.topics] - def convert_audio_to_waveform(self, segments_count=256): - fn = self.audio_waveform_filename - if fn.exists(): - return - waveform = get_audio_waveform( - path=self.audio_mp3_filename, segments_count=segments_count - ) - try: - with open(fn, "w") as fd: - json.dump(waveform, fd) - except Exception: - # remove file if anything happen during the write - fn.unlink(missing_ok=True) - raise - return waveform - def unlink(self): self.data_path.unlink(missing_ok=True) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 8c78c48f..b0576b92 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -21,11 +21,13 @@ from pydantic import BaseModel from reflector.app import app from reflector.db.transcripts import ( Transcript, + TranscriptDuration, TranscriptFinalLongSummary, TranscriptFinalShortSummary, TranscriptFinalTitle, TranscriptText, TranscriptTopic, + TranscriptWaveform, transcripts_controller, ) from reflector.logger import logger @@ -45,6 +47,7 @@ from reflector.processors import ( TranscriptTopicDetectorProcessor, TranscriptTranslatorProcessor, ) +from reflector.processors.audio_waveform_processor import AudioWaveformProcessor from reflector.processors.types import AudioDiarizationInput from reflector.processors.types import ( TitleSummaryWithId as TitleSummaryWithIdProcessorType, @@ -230,15 +233,29 @@ class PipelineMainBase(PipelineRunner): data=final_short_summary, ) - async def on_duration(self, duration: float): + async def on_duration(self, data): async with self.transaction(): + duration = TranscriptDuration(duration=data) + transcript = await self.get_transcript() await transcripts_controller.update( transcript, { - "duration": duration, + "duration": duration.duration, }, ) + return await transcripts_controller.append_event( + transcript=transcript, event="DURATION", data=duration + ) + + async def on_waveform(self, data): + waveform = TranscriptWaveform(waveform=data) + + transcript = await self.get_transcript() + + return await transcripts_controller.append_event( + transcript=transcript, event="WAVEFORM", data=waveform + ) class PipelineMainLive(PipelineMainBase): @@ -266,6 +283,11 @@ class PipelineMainLive(PipelineMainBase): BroadcastProcessor( processors=[ TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), + AudioWaveformProcessor( + audio_path=transcript.audio_mp3_filename, + waveform_path=transcript.audio_waveform_filename, + on_waveform=self.on_waveform, + ), ] ), ] diff --git a/server/reflector/processors/audio_waveform_processor.py b/server/reflector/processors/audio_waveform_processor.py new file mode 100644 index 00000000..acce904a --- /dev/null +++ b/server/reflector/processors/audio_waveform_processor.py @@ -0,0 +1,33 @@ +import json +from pathlib import Path + +from reflector.processors.base import Processor +from reflector.processors.types import TitleSummary +from reflector.utils.audio_waveform import get_audio_waveform + + +class AudioWaveformProcessor(Processor): + """ + Write the waveform for the final audio + """ + + INPUT_TYPE = TitleSummary + + def __init__(self, audio_path: Path | str, waveform_path: str, **kwargs): + super().__init__(**kwargs) + if isinstance(audio_path, str): + audio_path = Path(audio_path) + if audio_path.suffix not in (".mp3", ".wav"): + raise ValueError("Only mp3 and wav files are supported") + self.audio_path = audio_path + self.waveform_path = waveform_path + + async def _push(self, _data): + self.waveform_path.parent.mkdir(parents=True, exist_ok=True) + self.logger.info("Waveform Processing Started") + waveform = get_audio_waveform(path=self.audio_path, segments_count=255) + + with open(self.waveform_path, "w") as fd: + json.dump(waveform, fd) + self.logger.info("Waveform Processing Finished") + await self.emit(waveform, name="waveform") diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 5de9ced3..77e4b0b7 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -22,7 +22,6 @@ from reflector.db.transcripts import ( from reflector.processors.types import Transcript as ProcessorTranscript from reflector.settings import settings from reflector.ws_manager import get_ws_manager -from starlette.concurrency import run_in_threadpool from ._range_requests_response import range_requests_response from .rtc_offer import RtcOffer, rtc_offer_base @@ -261,7 +260,7 @@ async def transcript_get_audio_waveform( if not transcript.audio_mp3_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") - await run_in_threadpool(transcript.convert_audio_to_waveform) + # await run_in_threadpool(transcript.convert_audio_to_waveform) return transcript.audio_waveform diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index cf2ea304..65660a5e 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -182,6 +182,10 @@ async def test_transcript_rtc_and_websocket( ev = events[eventnames.index("FINAL_TITLE")] assert ev["data"]["title"] == "LLM TITLE" + assert "WAVEFORM" in eventnames + ev = events[eventnames.index("FINAL_TITLE")] + assert ev["data"]["title"] == "LLM TITLE" + # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] assert statuses.index("recording") < statuses.index("processing") diff --git a/www/app/lib/edgeConfig.ts b/www/app/lib/edgeConfig.ts index 5527121a..1140e555 100644 --- a/www/app/lib/edgeConfig.ts +++ b/www/app/lib/edgeConfig.ts @@ -3,9 +3,9 @@ import { isDevelopment } from "./utils"; const localConfig = { features: { - requireLogin: true, + requireLogin: false, privacy: true, - browse: true, + browse: false, }, api_url: "http://127.0.0.1:1250", websocket_url: "ws://127.0.0.1:1250",