Try to move waveform generation into the processing pipeline

This commit is contained in:
Sara
2023-11-15 20:30:00 +01:00
parent e98f1bf4bc
commit 1fc261a669
6 changed files with 72 additions and 23 deletions

View File

@@ -10,7 +10,6 @@ from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
transcripts = sqlalchemy.Table(
"transcript",
@@ -79,6 +78,14 @@ class TranscriptFinalTitle(BaseModel):
title: str
class TranscriptDuration(BaseModel):
    """Payload for the DURATION transcript event (total audio duration)."""

    # Duration value; presumably seconds — confirm against the emitting processor.
    duration: float
class TranscriptWaveform(BaseModel):
    """Payload for the WAVEFORM transcript event (per-segment amplitude values)."""

    waveform: list[float]
class TranscriptEvent(BaseModel):
    """Generic event envelope: an event name plus an arbitrary data payload."""

    event: str
    data: dict
@@ -118,22 +125,6 @@ class Transcript(BaseModel):
def topics_dump(self, mode="json"):
    """Serialize every topic through its pydantic ``model_dump``.

    ``mode`` is forwarded unchanged to each topic's ``model_dump``.
    """
    dumped = []
    for entry in self.topics:
        dumped.append(entry.model_dump(mode=mode))
    return dumped
def convert_audio_to_waveform(self, segments_count=256):
    """Compute and cache the waveform JSON for this transcript's audio.

    Returns ``None`` when the cache file already exists; otherwise computes
    the waveform from the mp3, writes it to disk, and returns it.
    """
    target = self.audio_waveform_filename
    if target.exists():
        # Cached on disk already — nothing to do.
        return
    waveform = get_audio_waveform(
        path=self.audio_mp3_filename, segments_count=segments_count
    )
    try:
        with open(target, "w") as fd:
            json.dump(waveform, fd)
    except Exception:
        # Drop the partial file so a failed write never leaves corrupt JSON.
        target.unlink(missing_ok=True)
        raise
    return waveform
def unlink(self):
    # Remove the transcript's on-disk data file; missing_ok makes this
    # a no-op when the file was never created.
    self.data_path.unlink(missing_ok=True)

View File

@@ -21,11 +21,13 @@ from pydantic import BaseModel
from reflector.app import app
from reflector.db.transcripts import (
Transcript,
TranscriptDuration,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
TranscriptFinalTitle,
TranscriptText,
TranscriptTopic,
TranscriptWaveform,
transcripts_controller,
)
from reflector.logger import logger
@@ -45,6 +47,7 @@ from reflector.processors import (
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import (
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
@@ -230,15 +233,29 @@ class PipelineMainBase(PipelineRunner):
data=final_short_summary,
)
async def on_duration(self, data):
    """Persist the final audio duration and broadcast a DURATION event.

    ``data`` is the raw duration value (float; presumably seconds — confirm
    against the emitting processor). It is wrapped in ``TranscriptDuration``
    so the appended event carries a typed payload.
    """
    # NOTE: the stripped diff left both the old signature and the old
    # "duration": duration mapping entry in place; this is the post-commit
    # version with the superseded lines removed.
    async with self.transaction():
        payload = TranscriptDuration(duration=data)
        transcript = await self.get_transcript()
        # Store the plain float on the transcript row...
        await transcripts_controller.update(
            transcript,
            {
                "duration": payload.duration,
            },
        )
        # ...and notify listeners with the typed event payload.
        return await transcripts_controller.append_event(
            transcript=transcript, event="DURATION", data=payload
        )
async def on_waveform(self, data):
    """Broadcast a WAVEFORM event carrying the computed waveform."""
    # NOTE(review): unlike on_duration this does not run inside
    # self.transaction() — confirm that is intentional.
    transcript = await self.get_transcript()
    payload = TranscriptWaveform(waveform=data)
    return await transcripts_controller.append_event(
        transcript=transcript, event="WAVEFORM", data=payload
    )
class PipelineMainLive(PipelineMainBase):
@@ -266,6 +283,11 @@ class PipelineMainLive(PipelineMainBase):
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
AudioWaveformProcessor(
audio_path=transcript.audio_mp3_filename,
waveform_path=transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
),
]
),
]

View File

@@ -0,0 +1,33 @@
import json
from pathlib import Path
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary
from reflector.utils.audio_waveform import get_audio_waveform
class AudioWaveformProcessor(Processor):
    """
    Write the waveform for the final audio to disk and emit it downstream.
    """

    INPUT_TYPE = TitleSummary

    def __init__(self, audio_path: Path | str, waveform_path: Path | str, **kwargs):
        """
        :param audio_path: source audio file; only .mp3/.wav are accepted.
        :param waveform_path: destination JSON file for the waveform.
        :raises ValueError: when the audio file extension is unsupported.
        """
        super().__init__(**kwargs)
        if isinstance(audio_path, str):
            audio_path = Path(audio_path)
        if audio_path.suffix not in (".mp3", ".wav"):
            raise ValueError("Only mp3 and wav files are supported")
        # BUG FIX: _push calls .parent / open() on waveform_path, so a plain
        # str argument crashed on `.parent`; normalize it to Path as well.
        if isinstance(waveform_path, str):
            waveform_path = Path(waveform_path)
        self.audio_path = audio_path
        self.waveform_path = waveform_path

    async def _push(self, _data):
        # Input payload is ignored; it only signals that the audio is final.
        self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
        self.logger.info("Waveform Processing Started")
        waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
        try:
            with open(self.waveform_path, "w") as fd:
                json.dump(waveform, fd)
        except Exception:
            # Remove a partially written file so readers never see corrupt
            # JSON (mirrors Transcript.convert_audio_to_waveform).
            self.waveform_path.unlink(missing_ok=True)
            raise
        self.logger.info("Waveform Processing Finished")
        await self.emit(waveform, name="waveform")

View File

@@ -22,7 +22,6 @@ from reflector.db.transcripts import (
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings
from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import RtcOffer, rtc_offer_base
@@ -261,7 +260,7 @@ async def transcript_get_audio_waveform(
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_waveform)
# await run_in_threadpool(transcript.convert_audio_to_waveform)
return transcript.audio_waveform

View File

@@ -182,6 +182,10 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
assert "WAVEFORM" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses.index("recording") < statuses.index("processing")

View File

@@ -3,9 +3,9 @@ import { isDevelopment } from "./utils";
const localConfig = {
features: {
requireLogin: true,
requireLogin: false,
privacy: true,
browse: true,
browse: false,
},
api_url: "http://127.0.0.1:1250",
websocket_url: "ws://127.0.0.1:1250",