mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: correctly save duration, when filewriter is finished
This commit is contained in:
64
server/migrations/versions/4814901632bc_fix_duration.py
Normal file
64
server/migrations/versions/4814901632bc_fix_duration.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""fix duration
|
||||||
|
|
||||||
|
Revision ID: 4814901632bc
|
||||||
|
Revises: 38a927dcb099
|
||||||
|
Create Date: 2023-11-10 18:12:17.886522
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.sql import table, column
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "4814901632bc"
|
||||||
|
down_revision: Union[str, None] = "38a927dcb099"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# for all the transcripts, calculate the duration from the mp3
|
||||||
|
# and update the duration column
|
||||||
|
from pathlib import Path
|
||||||
|
from reflector.settings import settings
|
||||||
|
import av
|
||||||
|
|
||||||
|
bind = op.get_bind()
|
||||||
|
transcript = table(
|
||||||
|
"transcript", column("id", sa.String), column("duration", sa.Float)
|
||||||
|
)
|
||||||
|
|
||||||
|
# select only the one with duration = 0
|
||||||
|
results = bind.execute(
|
||||||
|
select([transcript.c.id, transcript.c.duration]).where(
|
||||||
|
transcript.c.duration == 0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
data_dir = Path(settings.DATA_DIR)
|
||||||
|
for row in results:
|
||||||
|
audio_path = data_dir / row["id"] / "audio.mp3"
|
||||||
|
if not audio_path.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"Processing {audio_path}")
|
||||||
|
container = av.open(audio_path.as_posix())
|
||||||
|
print(container.duration)
|
||||||
|
duration = round(float(container.duration / av.time_base), 2)
|
||||||
|
print(f"Duration: {duration}")
|
||||||
|
bind.execute(
|
||||||
|
transcript.update()
|
||||||
|
.where(transcript.c.id == row["id"])
|
||||||
|
.values(duration=duration)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to process {audio_path}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
pass
|
||||||
@@ -230,6 +230,16 @@ class PipelineMainBase(PipelineRunner):
|
|||||||
data=final_short_summary,
|
data=final_short_summary,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def on_duration(self, duration: float):
|
||||||
|
async with self.transaction():
|
||||||
|
transcript = await self.get_transcript()
|
||||||
|
await transcripts_controller.update(
|
||||||
|
transcript,
|
||||||
|
{
|
||||||
|
"duration": duration,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PipelineMainLive(PipelineMainBase):
|
class PipelineMainLive(PipelineMainBase):
|
||||||
audio_filename: Path | None = None
|
audio_filename: Path | None = None
|
||||||
@@ -243,7 +253,10 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript()
|
||||||
|
|
||||||
processors = [
|
processors = [
|
||||||
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
|
AudioFileWriterProcessor(
|
||||||
|
path=transcript.audio_mp3_filename,
|
||||||
|
on_duration=self.on_duration,
|
||||||
|
),
|
||||||
AudioChunkerProcessor(),
|
AudioChunkerProcessor(),
|
||||||
AudioMergeProcessor(),
|
AudioMergeProcessor(),
|
||||||
AudioTranscriptAutoProcessor.as_threaded(),
|
AudioTranscriptAutoProcessor.as_threaded(),
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ class AudioFileWriterProcessor(Processor):
|
|||||||
INPUT_TYPE = av.AudioFrame
|
INPUT_TYPE = av.AudioFrame
|
||||||
OUTPUT_TYPE = av.AudioFrame
|
OUTPUT_TYPE = av.AudioFrame
|
||||||
|
|
||||||
def __init__(self, path: Path | str):
|
def __init__(self, path: Path | str, **kwargs):
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
if isinstance(path, str):
|
if isinstance(path, str):
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
if path.suffix not in (".mp3", ".wav"):
|
if path.suffix not in (".mp3", ".wav"):
|
||||||
@@ -21,6 +21,7 @@ class AudioFileWriterProcessor(Processor):
|
|||||||
self.path = path
|
self.path = path
|
||||||
self.out_container = None
|
self.out_container = None
|
||||||
self.out_stream = None
|
self.out_stream = None
|
||||||
|
self.last_packet = None
|
||||||
|
|
||||||
async def _push(self, data: av.AudioFrame):
|
async def _push(self, data: av.AudioFrame):
|
||||||
if not self.out_container:
|
if not self.out_container:
|
||||||
@@ -40,12 +41,30 @@ class AudioFileWriterProcessor(Processor):
|
|||||||
raise ValueError("Only mp3 and wav files are supported")
|
raise ValueError("Only mp3 and wav files are supported")
|
||||||
for packet in self.out_stream.encode(data):
|
for packet in self.out_stream.encode(data):
|
||||||
self.out_container.mux(packet)
|
self.out_container.mux(packet)
|
||||||
|
self.last_packet = packet
|
||||||
await self.emit(data)
|
await self.emit(data)
|
||||||
|
|
||||||
async def _flush(self):
|
async def _flush(self):
|
||||||
if self.out_container:
|
if self.out_container:
|
||||||
for packet in self.out_stream.encode():
|
for packet in self.out_stream.encode():
|
||||||
self.out_container.mux(packet)
|
self.out_container.mux(packet)
|
||||||
|
self.last_packet = packet
|
||||||
|
try:
|
||||||
|
if self.last_packet is not None:
|
||||||
|
duration = round(
|
||||||
|
float(
|
||||||
|
(self.last_packet.pts * self.last_packet.duration)
|
||||||
|
* self.last_packet.time_base
|
||||||
|
),
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception("Failed to get duration")
|
||||||
|
duration = 0
|
||||||
|
|
||||||
self.out_container.close()
|
self.out_container.close()
|
||||||
self.out_container = None
|
self.out_container = None
|
||||||
self.out_stream = None
|
self.out_stream = None
|
||||||
|
|
||||||
|
if duration > 0:
|
||||||
|
await self.emit(duration, name="duration")
|
||||||
|
|||||||
@@ -14,7 +14,42 @@ class PipelineEvent(BaseModel):
|
|||||||
data: Any
|
data: Any
|
||||||
|
|
||||||
|
|
||||||
class Processor:
|
class Emitter:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self._callbacks = {}
|
||||||
|
|
||||||
|
# register callbacks from kwargs (on_*)
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if key.startswith("on_"):
|
||||||
|
self.on(value, name=key[3:])
|
||||||
|
|
||||||
|
def on(self, callback, name="default"):
|
||||||
|
"""
|
||||||
|
Register a callback to be called when data is emitted
|
||||||
|
"""
|
||||||
|
# ensure callback is asynchronous
|
||||||
|
if not asyncio.iscoroutinefunction(callback):
|
||||||
|
raise ValueError("Callback must be a coroutine function")
|
||||||
|
if name not in self._callbacks:
|
||||||
|
self._callbacks[name] = []
|
||||||
|
self._callbacks[name].append(callback)
|
||||||
|
|
||||||
|
def off(self, callback, name="default"):
|
||||||
|
"""
|
||||||
|
Unregister a callback to be called when data is emitted
|
||||||
|
"""
|
||||||
|
if name not in self._callbacks:
|
||||||
|
return
|
||||||
|
self._callbacks[name].remove(callback)
|
||||||
|
|
||||||
|
async def emit(self, data, name="default"):
|
||||||
|
if name not in self._callbacks:
|
||||||
|
return
|
||||||
|
for callback in self._callbacks[name]:
|
||||||
|
await callback(data)
|
||||||
|
|
||||||
|
|
||||||
|
class Processor(Emitter):
|
||||||
INPUT_TYPE: type = None
|
INPUT_TYPE: type = None
|
||||||
OUTPUT_TYPE: type = None
|
OUTPUT_TYPE: type = None
|
||||||
|
|
||||||
@@ -59,7 +94,8 @@ class Processor:
|
|||||||
["processor"],
|
["processor"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, callback=None, custom_logger=None):
|
def __init__(self, callback=None, custom_logger=None, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
self.name = name = self.__class__.__name__
|
self.name = name = self.__class__.__name__
|
||||||
self.m_processor = self.m_processor.labels(name)
|
self.m_processor = self.m_processor.labels(name)
|
||||||
self.m_processor_call = self.m_processor_call.labels(name)
|
self.m_processor_call = self.m_processor_call.labels(name)
|
||||||
@@ -70,9 +106,11 @@ class Processor:
|
|||||||
self.m_processor_flush_success = self.m_processor_flush_success.labels(name)
|
self.m_processor_flush_success = self.m_processor_flush_success.labels(name)
|
||||||
self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name)
|
self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name)
|
||||||
self._processors = []
|
self._processors = []
|
||||||
self._callbacks = []
|
|
||||||
|
# register callbacks
|
||||||
if callback:
|
if callback:
|
||||||
self.on(callback)
|
self.on(callback)
|
||||||
|
|
||||||
self.uid = uuid4().hex
|
self.uid = uuid4().hex
|
||||||
self.flushed = False
|
self.flushed = False
|
||||||
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
||||||
@@ -100,21 +138,6 @@ class Processor:
|
|||||||
"""
|
"""
|
||||||
self._processors.remove(processor)
|
self._processors.remove(processor)
|
||||||
|
|
||||||
def on(self, callback):
|
|
||||||
"""
|
|
||||||
Register a callback to be called when data is emitted
|
|
||||||
"""
|
|
||||||
# ensure callback is asynchronous
|
|
||||||
if not asyncio.iscoroutinefunction(callback):
|
|
||||||
raise ValueError("Callback must be a coroutine function")
|
|
||||||
self._callbacks.append(callback)
|
|
||||||
|
|
||||||
def off(self, callback):
|
|
||||||
"""
|
|
||||||
Unregister a callback to be called when data is emitted
|
|
||||||
"""
|
|
||||||
self._callbacks.remove(callback)
|
|
||||||
|
|
||||||
def get_pref(self, key: str, default: Any = None):
|
def get_pref(self, key: str, default: Any = None):
|
||||||
"""
|
"""
|
||||||
Get a preference from the pipeline prefs
|
Get a preference from the pipeline prefs
|
||||||
@@ -123,15 +146,16 @@ class Processor:
|
|||||||
return self.pipeline.get_pref(key, default)
|
return self.pipeline.get_pref(key, default)
|
||||||
return default
|
return default
|
||||||
|
|
||||||
async def emit(self, data):
|
async def emit(self, data, name="default"):
|
||||||
if self.pipeline:
|
if name == "default":
|
||||||
await self.pipeline.emit(
|
if self.pipeline:
|
||||||
PipelineEvent(processor=self.name, uid=self.uid, data=data)
|
await self.pipeline.emit(
|
||||||
)
|
PipelineEvent(processor=self.name, uid=self.uid, data=data)
|
||||||
for callback in self._callbacks:
|
)
|
||||||
await callback(data)
|
await super().emit(data, name=name)
|
||||||
for processor in self._processors:
|
if name == "default":
|
||||||
await processor.push(data)
|
for processor in self._processors:
|
||||||
|
await processor.push(data)
|
||||||
|
|
||||||
async def push(self, data):
|
async def push(self, data):
|
||||||
"""
|
"""
|
||||||
@@ -254,11 +278,11 @@ class ThreadedProcessor(Processor):
|
|||||||
def disconnect(self, processor: Processor):
|
def disconnect(self, processor: Processor):
|
||||||
self.processor.disconnect(processor)
|
self.processor.disconnect(processor)
|
||||||
|
|
||||||
def on(self, callback):
|
def on(self, callback, name="default"):
|
||||||
self.processor.on(callback)
|
self.processor.on(callback, name=name)
|
||||||
|
|
||||||
def off(self, callback):
|
def off(self, callback, name="default"):
|
||||||
self.processor.off(callback)
|
self.processor.off(callback, name=name)
|
||||||
|
|
||||||
def describe(self, level=0):
|
def describe(self, level=0):
|
||||||
super().describe(level)
|
super().describe(level)
|
||||||
@@ -305,13 +329,13 @@ class BroadcastProcessor(Processor):
|
|||||||
for processor in self.processors:
|
for processor in self.processors:
|
||||||
processor.disconnect(processor)
|
processor.disconnect(processor)
|
||||||
|
|
||||||
def on(self, callback):
|
def on(self, callback, name="default"):
|
||||||
for processor in self.processors:
|
for processor in self.processors:
|
||||||
processor.on(callback)
|
processor.on(callback, name=name)
|
||||||
|
|
||||||
def off(self, callback):
|
def off(self, callback, name="default"):
|
||||||
for processor in self.processors:
|
for processor in self.processors:
|
||||||
processor.off(callback)
|
processor.off(callback, name=name)
|
||||||
|
|
||||||
def describe(self, level=0):
|
def describe(self, level=0):
|
||||||
super().describe(level)
|
super().describe(level)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class GetTranscript(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
status: str
|
status: str
|
||||||
locked: bool
|
locked: bool
|
||||||
duration: int
|
duration: float
|
||||||
title: str | None
|
title: str | None
|
||||||
short_summary: str | None
|
short_summary: str | None
|
||||||
long_summary: str | None
|
long_summary: str | None
|
||||||
|
|||||||
@@ -191,6 +191,9 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
assert events[-1]["event"] == "STATUS"
|
assert events[-1]["event"] == "STATUS"
|
||||||
assert events[-1]["data"]["value"] == "ended"
|
assert events[-1]["data"]["value"] == "ended"
|
||||||
|
|
||||||
|
# check on the latest response that the audio duration is > 0
|
||||||
|
assert resp.json()["duration"] > 0
|
||||||
|
|
||||||
# check that audio/mp3 is available
|
# check that audio/mp3 is available
|
||||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|||||||
Reference in New Issue
Block a user