mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
try to move waveform to pipeline
This commit is contained in:
@@ -10,7 +10,6 @@ from pydantic import BaseModel, Field
|
|||||||
from reflector.db import database, metadata
|
from reflector.db import database, metadata
|
||||||
from reflector.processors.types import Word as ProcessorWord
|
from reflector.processors.types import Word as ProcessorWord
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.utils.audio_waveform import get_audio_waveform
|
|
||||||
|
|
||||||
transcripts = sqlalchemy.Table(
|
transcripts = sqlalchemy.Table(
|
||||||
"transcript",
|
"transcript",
|
||||||
@@ -79,6 +78,14 @@ class TranscriptFinalTitle(BaseModel):
|
|||||||
title: str
|
title: str
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptDuration(BaseModel):
|
||||||
|
duration: float
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptWaveform(BaseModel):
|
||||||
|
waveform: list[float]
|
||||||
|
|
||||||
|
|
||||||
class TranscriptEvent(BaseModel):
|
class TranscriptEvent(BaseModel):
|
||||||
event: str
|
event: str
|
||||||
data: dict
|
data: dict
|
||||||
@@ -118,22 +125,6 @@ class Transcript(BaseModel):
|
|||||||
def topics_dump(self, mode="json"):
|
def topics_dump(self, mode="json"):
|
||||||
return [topic.model_dump(mode=mode) for topic in self.topics]
|
return [topic.model_dump(mode=mode) for topic in self.topics]
|
||||||
|
|
||||||
def convert_audio_to_waveform(self, segments_count=256):
|
|
||||||
fn = self.audio_waveform_filename
|
|
||||||
if fn.exists():
|
|
||||||
return
|
|
||||||
waveform = get_audio_waveform(
|
|
||||||
path=self.audio_mp3_filename, segments_count=segments_count
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
with open(fn, "w") as fd:
|
|
||||||
json.dump(waveform, fd)
|
|
||||||
except Exception:
|
|
||||||
# remove file if anything happen during the write
|
|
||||||
fn.unlink(missing_ok=True)
|
|
||||||
raise
|
|
||||||
return waveform
|
|
||||||
|
|
||||||
def unlink(self):
|
def unlink(self):
|
||||||
self.data_path.unlink(missing_ok=True)
|
self.data_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|||||||
@@ -21,11 +21,13 @@ from pydantic import BaseModel
|
|||||||
from reflector.app import app
|
from reflector.app import app
|
||||||
from reflector.db.transcripts import (
|
from reflector.db.transcripts import (
|
||||||
Transcript,
|
Transcript,
|
||||||
|
TranscriptDuration,
|
||||||
TranscriptFinalLongSummary,
|
TranscriptFinalLongSummary,
|
||||||
TranscriptFinalShortSummary,
|
TranscriptFinalShortSummary,
|
||||||
TranscriptFinalTitle,
|
TranscriptFinalTitle,
|
||||||
TranscriptText,
|
TranscriptText,
|
||||||
TranscriptTopic,
|
TranscriptTopic,
|
||||||
|
TranscriptWaveform,
|
||||||
transcripts_controller,
|
transcripts_controller,
|
||||||
)
|
)
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
@@ -45,6 +47,7 @@ from reflector.processors import (
|
|||||||
TranscriptTopicDetectorProcessor,
|
TranscriptTopicDetectorProcessor,
|
||||||
TranscriptTranslatorProcessor,
|
TranscriptTranslatorProcessor,
|
||||||
)
|
)
|
||||||
|
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
||||||
from reflector.processors.types import AudioDiarizationInput
|
from reflector.processors.types import AudioDiarizationInput
|
||||||
from reflector.processors.types import (
|
from reflector.processors.types import (
|
||||||
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
|
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
|
||||||
@@ -230,15 +233,29 @@ class PipelineMainBase(PipelineRunner):
|
|||||||
data=final_short_summary,
|
data=final_short_summary,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_duration(self, duration: float):
|
async def on_duration(self, data):
|
||||||
async with self.transaction():
|
async with self.transaction():
|
||||||
|
duration = TranscriptDuration(duration=data)
|
||||||
|
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript()
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"duration": duration,
|
"duration": duration.duration,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
return await transcripts_controller.append_event(
|
||||||
|
transcript=transcript, event="DURATION", data=duration
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_waveform(self, data):
|
||||||
|
waveform = TranscriptWaveform(waveform=data)
|
||||||
|
|
||||||
|
transcript = await self.get_transcript()
|
||||||
|
|
||||||
|
return await transcripts_controller.append_event(
|
||||||
|
transcript=transcript, event="WAVEFORM", data=waveform
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PipelineMainLive(PipelineMainBase):
|
class PipelineMainLive(PipelineMainBase):
|
||||||
@@ -266,6 +283,11 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
BroadcastProcessor(
|
BroadcastProcessor(
|
||||||
processors=[
|
processors=[
|
||||||
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
||||||
|
AudioWaveformProcessor(
|
||||||
|
audio_path=transcript.audio_mp3_filename,
|
||||||
|
waveform_path=transcript.audio_waveform_filename,
|
||||||
|
on_waveform=self.on_waveform,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|||||||
33
server/reflector/processors/audio_waveform_processor.py
Normal file
33
server/reflector/processors/audio_waveform_processor.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from reflector.processors.base import Processor
|
||||||
|
from reflector.processors.types import TitleSummary
|
||||||
|
from reflector.utils.audio_waveform import get_audio_waveform
|
||||||
|
|
||||||
|
|
||||||
|
class AudioWaveformProcessor(Processor):
|
||||||
|
"""
|
||||||
|
Write the waveform for the final audio
|
||||||
|
"""
|
||||||
|
|
||||||
|
INPUT_TYPE = TitleSummary
|
||||||
|
|
||||||
|
def __init__(self, audio_path: Path | str, waveform_path: str, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
if isinstance(audio_path, str):
|
||||||
|
audio_path = Path(audio_path)
|
||||||
|
if audio_path.suffix not in (".mp3", ".wav"):
|
||||||
|
raise ValueError("Only mp3 and wav files are supported")
|
||||||
|
self.audio_path = audio_path
|
||||||
|
self.waveform_path = waveform_path
|
||||||
|
|
||||||
|
async def _push(self, _data):
|
||||||
|
self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.logger.info("Waveform Processing Started")
|
||||||
|
waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
|
||||||
|
|
||||||
|
with open(self.waveform_path, "w") as fd:
|
||||||
|
json.dump(waveform, fd)
|
||||||
|
self.logger.info("Waveform Processing Finished")
|
||||||
|
await self.emit(waveform, name="waveform")
|
||||||
@@ -22,7 +22,6 @@ from reflector.db.transcripts import (
|
|||||||
from reflector.processors.types import Transcript as ProcessorTranscript
|
from reflector.processors.types import Transcript as ProcessorTranscript
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.ws_manager import get_ws_manager
|
from reflector.ws_manager import get_ws_manager
|
||||||
from starlette.concurrency import run_in_threadpool
|
|
||||||
|
|
||||||
from ._range_requests_response import range_requests_response
|
from ._range_requests_response import range_requests_response
|
||||||
from .rtc_offer import RtcOffer, rtc_offer_base
|
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||||
@@ -261,7 +260,7 @@ async def transcript_get_audio_waveform(
|
|||||||
if not transcript.audio_mp3_filename.exists():
|
if not transcript.audio_mp3_filename.exists():
|
||||||
raise HTTPException(status_code=404, detail="Audio not found")
|
raise HTTPException(status_code=404, detail="Audio not found")
|
||||||
|
|
||||||
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
# await run_in_threadpool(transcript.convert_audio_to_waveform)
|
||||||
|
|
||||||
return transcript.audio_waveform
|
return transcript.audio_waveform
|
||||||
|
|
||||||
|
|||||||
@@ -182,6 +182,10 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
ev = events[eventnames.index("FINAL_TITLE")]
|
ev = events[eventnames.index("FINAL_TITLE")]
|
||||||
assert ev["data"]["title"] == "LLM TITLE"
|
assert ev["data"]["title"] == "LLM TITLE"
|
||||||
|
|
||||||
|
assert "WAVEFORM" in eventnames
|
||||||
|
ev = events[eventnames.index("FINAL_TITLE")]
|
||||||
|
assert ev["data"]["title"] == "LLM TITLE"
|
||||||
|
|
||||||
# check status order
|
# check status order
|
||||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||||
assert statuses.index("recording") < statuses.index("processing")
|
assert statuses.index("recording") < statuses.index("processing")
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ import { isDevelopment } from "./utils";
|
|||||||
|
|
||||||
const localConfig = {
|
const localConfig = {
|
||||||
features: {
|
features: {
|
||||||
requireLogin: true,
|
requireLogin: false,
|
||||||
privacy: true,
|
privacy: true,
|
||||||
browse: true,
|
browse: false,
|
||||||
},
|
},
|
||||||
api_url: "http://127.0.0.1:1250",
|
api_url: "http://127.0.0.1:1250",
|
||||||
websocket_url: "ws://127.0.0.1:1250",
|
websocket_url: "ws://127.0.0.1:1250",
|
||||||
|
|||||||
Reference in New Issue
Block a user