Move waveform generation from the API layer into the processing pipeline

This commit is contained in:
Sara
2023-11-15 20:30:00 +01:00
parent e98f1bf4bc
commit 1fc261a669
6 changed files with 72 additions and 23 deletions

View File

@@ -10,7 +10,6 @@ from pydantic import BaseModel, Field
from reflector.db import database, metadata from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
transcripts = sqlalchemy.Table( transcripts = sqlalchemy.Table(
"transcript", "transcript",
@@ -79,6 +78,14 @@ class TranscriptFinalTitle(BaseModel):
title: str title: str
class TranscriptDuration(BaseModel):
    """Event payload carrying the final audio duration."""

    # length of the recorded audio (presumably seconds — confirm against pipeline)
    duration: float
class TranscriptWaveform(BaseModel):
    """Event payload carrying the downsampled audio waveform."""

    # per-segment amplitude values produced by get_audio_waveform
    waveform: list[float]
class TranscriptEvent(BaseModel): class TranscriptEvent(BaseModel):
event: str event: str
data: dict data: dict
@@ -118,22 +125,6 @@ class Transcript(BaseModel):
def topics_dump(self, mode="json"):
    """Serialize every topic via pydantic ``model_dump``.

    mode: forwarded to ``model_dump`` ("json" yields JSON-safe values).
    """
    return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_waveform(self, segments_count=256):
    """Compute and cache the waveform JSON for this transcript's audio.

    Returns None without recomputing when the waveform file already
    exists; otherwise writes the waveform to disk and returns it.

    segments_count: number of amplitude segments to extract.
    """
    waveform_file = self.audio_waveform_filename
    if waveform_file.exists():
        return
    waveform = get_audio_waveform(
        path=self.audio_mp3_filename, segments_count=segments_count
    )
    try:
        with open(waveform_file, "w") as fd:
            json.dump(waveform, fd)
    except Exception:
        # a partial/corrupt file must not survive a failed write
        waveform_file.unlink(missing_ok=True)
        raise
    return waveform
def unlink(self):
    """Delete this transcript's data file, ignoring a missing file."""
    self.data_path.unlink(missing_ok=True)

View File

@@ -21,11 +21,13 @@ from pydantic import BaseModel
from reflector.app import app from reflector.app import app
from reflector.db.transcripts import ( from reflector.db.transcripts import (
Transcript, Transcript,
TranscriptDuration,
TranscriptFinalLongSummary, TranscriptFinalLongSummary,
TranscriptFinalShortSummary, TranscriptFinalShortSummary,
TranscriptFinalTitle, TranscriptFinalTitle,
TranscriptText, TranscriptText,
TranscriptTopic, TranscriptTopic,
TranscriptWaveform,
transcripts_controller, transcripts_controller,
) )
from reflector.logger import logger from reflector.logger import logger
@@ -45,6 +47,7 @@ from reflector.processors import (
TranscriptTopicDetectorProcessor, TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor, TranscriptTranslatorProcessor,
) )
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.types import AudioDiarizationInput from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import ( from reflector.processors.types import (
TitleSummaryWithId as TitleSummaryWithIdProcessorType, TitleSummaryWithId as TitleSummaryWithIdProcessorType,
@@ -230,15 +233,29 @@ class PipelineMainBase(PipelineRunner):
data=final_short_summary, data=final_short_summary,
) )
async def on_duration(self, data):
    """Persist the audio duration and broadcast it as a DURATION event.

    data: raw duration value (float) from the pipeline — wrapped in
    TranscriptDuration before storage/broadcast.
    """
    async with self.transaction():
        duration = TranscriptDuration(duration=data)
        transcript = await self.get_transcript()
        await transcripts_controller.update(
            transcript,
            {
                "duration": duration.duration,
            },
        )
        return await transcripts_controller.append_event(
            transcript=transcript, event="DURATION", data=duration
        )
async def on_waveform(self, data):
    """Broadcast the computed waveform as a WAVEFORM transcript event."""
    transcript = await self.get_transcript()
    payload = TranscriptWaveform(waveform=data)
    return await transcripts_controller.append_event(
        transcript=transcript, event="WAVEFORM", data=payload
    )
class PipelineMainLive(PipelineMainBase): class PipelineMainLive(PipelineMainBase):
@@ -266,6 +283,11 @@ class PipelineMainLive(PipelineMainBase):
BroadcastProcessor( BroadcastProcessor(
processors=[ processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
AudioWaveformProcessor(
audio_path=transcript.audio_mp3_filename,
waveform_path=transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
),
] ]
), ),
] ]

View File

@@ -0,0 +1,33 @@
import json
from pathlib import Path
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary
from reflector.utils.audio_waveform import get_audio_waveform
class AudioWaveformProcessor(Processor):
    """
    Write the waveform for the final audio.

    On push, computes the waveform of ``audio_path``, persists it as JSON
    at ``waveform_path``, and emits it on the "waveform" channel.
    """

    INPUT_TYPE = TitleSummary

    def __init__(
        self,
        audio_path: Path | str,
        waveform_path: Path | str,
        segments_count: int = 255,
        **kwargs,
    ):
        """
        audio_path: source audio (.mp3 or .wav only).
        waveform_path: JSON output file; parent dirs created on demand.
        segments_count: number of waveform segments to extract.
        kwargs: forwarded to Processor (e.g. on_waveform callback).
        """
        super().__init__(**kwargs)
        if isinstance(audio_path, str):
            audio_path = Path(audio_path)
        # bug fix: waveform_path must also be a Path — _push dereferences
        # .parent, which would crash on a plain str
        if isinstance(waveform_path, str):
            waveform_path = Path(waveform_path)
        if audio_path.suffix not in (".mp3", ".wav"):
            raise ValueError("Only mp3 and wav files are supported")
        self.audio_path = audio_path
        self.waveform_path = waveform_path
        self.segments_count = segments_count

    async def _push(self, _data):
        """Compute, persist, and emit the waveform (input data unused)."""
        self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
        self.logger.info("Waveform Processing Started")
        waveform = get_audio_waveform(
            path=self.audio_path, segments_count=self.segments_count
        )
        with open(self.waveform_path, "w") as fd:
            json.dump(waveform, fd)
        self.logger.info("Waveform Processing Finished")
        await self.emit(waveform, name="waveform")

View File

@@ -22,7 +22,6 @@ from reflector.db.transcripts import (
from reflector.processors.types import Transcript as ProcessorTranscript from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings from reflector.settings import settings
from reflector.ws_manager import get_ws_manager from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response from ._range_requests_response import range_requests_response
from .rtc_offer import RtcOffer, rtc_offer_base from .rtc_offer import RtcOffer, rtc_offer_base
@@ -261,7 +260,7 @@ async def transcript_get_audio_waveform(
if not transcript.audio_mp3_filename.exists(): if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found") raise HTTPException(status_code=404, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_waveform) # await run_in_threadpool(transcript.convert_audio_to_waveform)
return transcript.audio_waveform return transcript.audio_waveform

View File

@@ -182,6 +182,10 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("FINAL_TITLE")] ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE" assert ev["data"]["title"] == "LLM TITLE"
assert "WAVEFORM" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
# check status order # check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses.index("recording") < statuses.index("processing") assert statuses.index("recording") < statuses.index("processing")

View File

@@ -3,9 +3,9 @@ import { isDevelopment } from "./utils";
const localConfig = { const localConfig = {
features: { features: {
requireLogin: true, requireLogin: false,
privacy: true, privacy: true,
browse: true, browse: false,
}, },
api_url: "http://127.0.0.1:1250", api_url: "http://127.0.0.1:1250",
websocket_url: "ws://127.0.0.1:1250", websocket_url: "ws://127.0.0.1:1250",