server: correctly save duration, when filewriter is finished

This commit is contained in:
2023-11-10 18:23:40 +01:00
committed by Mathieu Virbel
parent afa8010d29
commit e18a7c8d4e
6 changed files with 162 additions and 39 deletions

View File

@@ -0,0 +1,64 @@
"""fix duration
Revision ID: 4814901632bc
Revises: 38a927dcb099
Create Date: 2023-11-10 18:12:17.886522
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column
from sqlalchemy import select
# revision identifiers, used by Alembic.
revision: str = "4814901632bc"
down_revision: Union[str, None] = "38a927dcb099"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Backfill ``transcript.duration`` from each transcript's ``audio.mp3``.

    Only rows whose ``duration`` is currently 0 are considered. For each one,
    ``<DATA_DIR>/<id>/audio.mp3`` is opened with PyAV and the container
    duration (divided by ``av.time_base`` and rounded to 2 decimals) is written
    back. Rows without an audio file are skipped. Any failure while processing
    a row is printed and the migration moves on to the next row.
    """
    # Local imports, kept inside the migration function as in the original.
    from pathlib import Path
    from reflector.settings import settings
    import av

    bind = op.get_bind()
    transcript = table(
        "transcript", column("id", sa.String), column("duration", sa.Float)
    )

    # Select only the rows with duration = 0. Fetch them all up front so the
    # result set is fully consumed before UPDATEs run on the same connection.
    rows = bind.execute(
        select([transcript.c.id, transcript.c.duration]).where(
            transcript.c.duration == 0
        )
    ).fetchall()

    data_dir = Path(settings.DATA_DIR)
    for row in rows:
        transcript_id = row["id"]
        audio_path = data_dir / transcript_id / "audio.mp3"
        if not audio_path.exists():
            continue
        try:
            print(f"Processing {audio_path}")
            container = av.open(audio_path.as_posix())
            try:
                print(container.duration)
                duration = round(float(container.duration / av.time_base), 2)
            finally:
                # Always release the opened file, even if reading the
                # duration fails; otherwise one handle leaks per transcript.
                container.close()
            print(f"Duration: {duration}")
            bind.execute(
                transcript.update()
                .where(transcript.c.id == transcript_id)
                .values(duration=duration)
            )
        except Exception as e:
            # Best effort: one unreadable file must not abort the migration.
            print(f"Failed to process {audio_path}: {e}")
def downgrade() -> None:
    """Do nothing: this revision has no downgrade step."""
    return None

View File

@@ -230,6 +230,16 @@ class PipelineMainBase(PipelineRunner):
data=final_short_summary,
)
async def on_duration(self, duration: float):
    """Store ``duration`` on the current transcript.

    Inside ``self.transaction()``, the transcript is loaded with
    ``self.get_transcript()`` and its ``duration`` field is updated through
    ``transcripts_controller.update``.
    """
    changes = {"duration": duration}
    async with self.transaction():
        current = await self.get_transcript()
        await transcripts_controller.update(current, changes)
class PipelineMainLive(PipelineMainBase):
audio_filename: Path | None = None
@@ -243,7 +253,10 @@ class PipelineMainLive(PipelineMainBase):
transcript = await self.get_transcript()
processors = [
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioFileWriterProcessor(
path=transcript.audio_mp3_filename,
on_duration=self.on_duration,
),
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),

View File

@@ -12,8 +12,8 @@ class AudioFileWriterProcessor(Processor):
INPUT_TYPE = av.AudioFrame
OUTPUT_TYPE = av.AudioFrame
def __init__(self, path: Path | str):
super().__init__()
def __init__(self, path: Path | str, **kwargs):
super().__init__(**kwargs)
if isinstance(path, str):
path = Path(path)
if path.suffix not in (".mp3", ".wav"):
@@ -21,6 +21,7 @@ class AudioFileWriterProcessor(Processor):
self.path = path
self.out_container = None
self.out_stream = None
self.last_packet = None
async def _push(self, data: av.AudioFrame):
if not self.out_container:
@@ -40,12 +41,30 @@ class AudioFileWriterProcessor(Processor):
raise ValueError("Only mp3 and wav files are supported")
for packet in self.out_stream.encode(data):
self.out_container.mux(packet)
self.last_packet = packet
await self.emit(data)
async def _flush(self):
if self.out_container:
for packet in self.out_stream.encode():
self.out_container.mux(packet)
self.last_packet = packet
try:
if self.last_packet is not None:
duration = round(
float(
(self.last_packet.pts * self.last_packet.duration)
* self.last_packet.time_base
),
2,
)
except Exception:
self.logger.exception("Failed to get duration")
duration = 0
self.out_container.close()
self.out_container = None
self.out_stream = None
if duration > 0:
await self.emit(duration, name="duration")

View File

@@ -14,7 +14,42 @@ class PipelineEvent(BaseModel):
data: Any
class Processor:
class Emitter:
    """Small async event emitter with named callback channels.

    Callbacks are grouped by channel name; ``"default"`` is used when no name
    is given. Constructor keyword arguments of the form ``on_<name>=callback``
    are registered on channel ``<name>``; other keyword arguments are ignored.
    """

    def __init__(self, **kwargs):
        # channel name -> list of coroutine functions, in registration order
        self._callbacks = {}

        # register callbacks from kwargs (on_*)
        for key, value in kwargs.items():
            if key.startswith("on_"):
                self.on(value, name=key[3:])

    def on(self, callback, name="default"):
        """
        Register a callback to be called when data is emitted

        Raises ValueError if ``callback`` is not a coroutine function.
        """
        # ensure callback is asynchronous
        if not asyncio.iscoroutinefunction(callback):
            raise ValueError("Callback must be a coroutine function")
        self._callbacks.setdefault(name, []).append(callback)

    def off(self, callback, name="default"):
        """
        Unregister a callback to be called when data is emitted

        Unknown channel names are ignored. If the channel exists but the
        callback is not registered on it, ``list.remove`` raises ValueError.
        """
        if name not in self._callbacks:
            return
        self._callbacks[name].remove(callback)

    async def emit(self, data, name="default"):
        """
        Await every callback registered on ``name`` with ``data``, in
        registration order. Does nothing if the channel has no callbacks.
        """
        # Iterate over a snapshot: a callback may call on()/off() while we
        # are emitting, and mutating the live list would skip callbacks.
        for callback in tuple(self._callbacks.get(name, ())):
            await callback(data)
class Processor(Emitter):
INPUT_TYPE: type = None
OUTPUT_TYPE: type = None
@@ -59,7 +94,8 @@ class Processor:
["processor"],
)
def __init__(self, callback=None, custom_logger=None):
def __init__(self, callback=None, custom_logger=None, **kwargs):
super().__init__(**kwargs)
self.name = name = self.__class__.__name__
self.m_processor = self.m_processor.labels(name)
self.m_processor_call = self.m_processor_call.labels(name)
@@ -70,9 +106,11 @@ class Processor:
self.m_processor_flush_success = self.m_processor_flush_success.labels(name)
self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name)
self._processors = []
self._callbacks = []
# register callbacks
if callback:
self.on(callback)
self.uid = uuid4().hex
self.flushed = False
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
@@ -100,21 +138,6 @@ class Processor:
"""
self._processors.remove(processor)
def on(self, callback):
"""
Register a callback to be called when data is emitted
"""
# ensure callback is asynchronous
if not asyncio.iscoroutinefunction(callback):
raise ValueError("Callback must be a coroutine function")
self._callbacks.append(callback)
def off(self, callback):
"""
Unregister a callback to be called when data is emitted
"""
self._callbacks.remove(callback)
def get_pref(self, key: str, default: Any = None):
"""
Get a preference from the pipeline prefs
@@ -123,13 +146,14 @@ class Processor:
return self.pipeline.get_pref(key, default)
return default
async def emit(self, data):
async def emit(self, data, name="default"):
if name == "default":
if self.pipeline:
await self.pipeline.emit(
PipelineEvent(processor=self.name, uid=self.uid, data=data)
)
for callback in self._callbacks:
await callback(data)
await super().emit(data, name=name)
if name == "default":
for processor in self._processors:
await processor.push(data)
@@ -254,11 +278,11 @@ class ThreadedProcessor(Processor):
def disconnect(self, processor: Processor):
self.processor.disconnect(processor)
def on(self, callback):
self.processor.on(callback)
def on(self, callback, name="default"):
    """Register ``callback`` on channel ``name`` of the wrapped processor."""
    wrapped = self.processor
    wrapped.on(callback, name=name)
def off(self, callback):
self.processor.off(callback)
def off(self, callback, name="default"):
    """Unregister ``callback`` from channel ``name`` of the wrapped processor."""
    wrapped = self.processor
    wrapped.off(callback, name=name)
def describe(self, level=0):
super().describe(level)
@@ -305,13 +329,13 @@ class BroadcastProcessor(Processor):
for processor in self.processors:
processor.disconnect(processor)
def on(self, callback):
def on(self, callback, name="default"):
for processor in self.processors:
processor.on(callback)
processor.on(callback, name=name)
def off(self, callback):
def off(self, callback, name="default"):
for processor in self.processors:
processor.off(callback)
processor.off(callback, name=name)
def describe(self, level=0):
super().describe(level)

View File

@@ -51,7 +51,7 @@ class GetTranscript(BaseModel):
name: str
status: str
locked: bool
duration: int
duration: float
title: str | None
short_summary: str | None
long_summary: str | None

View File

@@ -191,6 +191,9 @@ async def test_transcript_rtc_and_websocket(
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# check on the latest response that the audio duration is > 0
assert resp.json()["duration"] > 0
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200