diff --git a/server/migrations/versions/4814901632bc_fix_duration.py b/server/migrations/versions/4814901632bc_fix_duration.py new file mode 100644 index 00000000..66628bb5 --- /dev/null +++ b/server/migrations/versions/4814901632bc_fix_duration.py @@ -0,0 +1,64 @@ +"""fix duration + +Revision ID: 4814901632bc +Revises: 38a927dcb099 +Create Date: 2023-11-10 18:12:17.886522 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, column +from sqlalchemy import select + + +# revision identifiers, used by Alembic. +revision: str = "4814901632bc" +down_revision: Union[str, None] = "38a927dcb099" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # for all the transcripts, calculate the duration from the mp3 + # and update the duration column + from pathlib import Path + from reflector.settings import settings + import av + + bind = op.get_bind() + transcript = table( + "transcript", column("id", sa.String), column("duration", sa.Float) + ) + + # select only the one with duration = 0 + results = bind.execute( + select([transcript.c.id, transcript.c.duration]).where( + transcript.c.duration == 0 + ) + ) + + data_dir = Path(settings.DATA_DIR) + for row in results: + audio_path = data_dir / row["id"] / "audio.mp3" + if not audio_path.exists(): + continue + + try: + print(f"Processing {audio_path}") + container = av.open(audio_path.as_posix()) + print(container.duration) + duration = round(float(container.duration / av.time_base), 2) + print(f"Duration: {duration}") + bind.execute( + transcript.update() + .where(transcript.c.id == row["id"]) + .values(duration=duration) + ) + except Exception as e: + print(f"Failed to process {audio_path}: {e}") + + +def downgrade() -> None: + pass diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index b2bc51ea..8c78c48f 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -230,6 +230,16 @@ class PipelineMainBase(PipelineRunner): data=final_short_summary, ) + async def on_duration(self, duration: float): + async with self.transaction(): + transcript = await self.get_transcript() + await transcripts_controller.update( + transcript, + { + "duration": duration, + }, + ) + class PipelineMainLive(PipelineMainBase): audio_filename: Path | None = None @@ -243,7 +253,10 @@ class PipelineMainLive(PipelineMainBase): transcript = await self.get_transcript() processors = [ - AudioFileWriterProcessor(path=transcript.audio_mp3_filename), + AudioFileWriterProcessor( + path=transcript.audio_mp3_filename, + on_duration=self.on_duration, + ), AudioChunkerProcessor(), AudioMergeProcessor(), AudioTranscriptAutoProcessor.as_threaded(), diff --git a/server/reflector/processors/audio_file_writer.py b/server/reflector/processors/audio_file_writer.py index d34dc3f0..36ee4263 100644 --- a/server/reflector/processors/audio_file_writer.py +++ b/server/reflector/processors/audio_file_writer.py @@ -12,8 +12,8 @@ class AudioFileWriterProcessor(Processor): INPUT_TYPE = av.AudioFrame OUTPUT_TYPE = av.AudioFrame - def __init__(self, path: Path | str): - super().__init__() + def __init__(self, path: Path | str, **kwargs): + super().__init__(**kwargs) if isinstance(path, str): path = Path(path) if path.suffix not in (".mp3", ".wav"): @@ -21,6 +21,7 @@ class AudioFileWriterProcessor(Processor): self.path = path self.out_container = None self.out_stream = None + self.last_packet = None async def _push(self, data: av.AudioFrame): if not self.out_container: @@ -40,12 +41,30 @@ class AudioFileWriterProcessor(Processor): raise ValueError("Only mp3 and wav files are supported") for packet in self.out_stream.encode(data): self.out_container.mux(packet) + self.last_packet = packet await self.emit(data) async def _flush(self): if self.out_container: for packet in self.out_stream.encode(): self.out_container.mux(packet) + self.last_packet = packet + try: + if self.last_packet is not None: + duration = round( + float( + (self.last_packet.pts * self.last_packet.duration) + * self.last_packet.time_base + ), + 2, + ) + except Exception: + self.logger.exception("Failed to get duration") + duration = 0 + self.out_container.close() self.out_container = None self.out_stream = None + + if duration > 0: + await self.emit(duration, name="duration") diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py index 46bfb4a5..00f0223b 100644 --- a/server/reflector/processors/base.py +++ b/server/reflector/processors/base.py @@ -14,7 +14,42 @@ class PipelineEvent(BaseModel): data: Any -class Processor: +class Emitter: + def __init__(self, **kwargs): + self._callbacks = {} + + # register callbacks from kwargs (on_*) + for key, value in kwargs.items(): + if key.startswith("on_"): + self.on(value, name=key[3:]) + + def on(self, callback, name="default"): + """ + Register a callback to be called when data is emitted + """ + # ensure callback is asynchronous + if not asyncio.iscoroutinefunction(callback): + raise ValueError("Callback must be a coroutine function") + if name not in self._callbacks: + self._callbacks[name] = [] + self._callbacks[name].append(callback) + + def off(self, callback, name="default"): + """ + Unregister a callback to be called when data is emitted + """ + if name not in self._callbacks: + return + self._callbacks[name].remove(callback) + + async def emit(self, data, name="default"): + if name not in self._callbacks: + return + for callback in self._callbacks[name]: + await callback(data) + + +class Processor(Emitter): INPUT_TYPE: type = None OUTPUT_TYPE: type = None @@ -59,7 +94,8 @@ class Processor: ["processor"], ) - def __init__(self, callback=None, custom_logger=None): + def __init__(self, callback=None, custom_logger=None, **kwargs): + super().__init__(**kwargs) self.name = name = self.__class__.__name__ self.m_processor = self.m_processor.labels(name) self.m_processor_call = self.m_processor_call.labels(name) @@ -70,9 +106,11 @@ class Processor: self.m_processor_flush_success = self.m_processor_flush_success.labels(name) self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name) self._processors = [] - self._callbacks = [] + + # register callbacks if callback: self.on(callback) + self.uid = uuid4().hex self.flushed = False self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__) @@ -100,21 +138,6 @@ class Processor: """ self._processors.remove(processor) - def on(self, callback): - """ - Register a callback to be called when data is emitted - """ - # ensure callback is asynchronous - if not asyncio.iscoroutinefunction(callback): - raise ValueError("Callback must be a coroutine function") - self._callbacks.append(callback) - - def off(self, callback): - """ - Unregister a callback to be called when data is emitted - """ - self._callbacks.remove(callback) - def get_pref(self, key: str, default: Any = None): """ Get a preference from the pipeline prefs @@ -123,15 +146,16 @@ class Processor: return self.pipeline.get_pref(key, default) return default - async def emit(self, data): - if self.pipeline: - await self.pipeline.emit( - PipelineEvent(processor=self.name, uid=self.uid, data=data) - ) - for callback in self._callbacks: - await callback(data) - for processor in self._processors: - await processor.push(data) + async def emit(self, data, name="default"): + if name == "default": + if self.pipeline: + await self.pipeline.emit( + PipelineEvent(processor=self.name, uid=self.uid, data=data) + ) + await super().emit(data, name=name) + if name == "default": + for processor in self._processors: + await processor.push(data) async def push(self, data): """ @@ -254,11 +278,11 @@ class ThreadedProcessor(Processor): def disconnect(self, processor: Processor): self.processor.disconnect(processor) - def on(self, callback): - self.processor.on(callback) + def on(self, callback, name="default"): + self.processor.on(callback, name=name) - def off(self, callback): - self.processor.off(callback) + def off(self, callback, name="default"): + self.processor.off(callback, name=name) def describe(self, level=0): super().describe(level) @@ -305,13 +329,13 @@ class BroadcastProcessor(Processor): for processor in self.processors: processor.disconnect(processor) - def on(self, callback): + def on(self, callback, name="default"): for processor in self.processors: - processor.on(callback) + processor.on(callback, name=name) - def off(self, callback): + def off(self, callback, name="default"): for processor in self.processors: - processor.off(callback) + processor.off(callback, name=name) def describe(self, level=0): super().describe(level) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index a202f3a1..5de9ced3 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -51,7 +51,7 @@ class GetTranscript(BaseModel): name: str status: str locked: bool - duration: int + duration: float title: str | None short_summary: str | None long_summary: str | None diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 413c8b24..cf2ea304 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -191,6 +191,9 @@ async def test_transcript_rtc_and_websocket( assert events[-1]["event"] == "STATUS" assert events[-1]["data"]["value"] == "ended" + # check on the latest response that the audio duration is > 0 + assert resp.json()["duration"] > 0 + # check that audio/mp3 is available resp = await ac.get(f"/transcripts/{tid}/audio/mp3") assert resp.status_code == 200