mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
fix waveform in pipeline
This commit is contained in:
@@ -233,6 +233,7 @@ class PipelineMainBase(PipelineRunner):
|
||||
data=final_short_summary,
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_duration(self, data):
|
||||
async with self.transaction():
|
||||
duration = TranscriptDuration(duration=data)
|
||||
@@ -248,14 +249,16 @@ class PipelineMainBase(PipelineRunner):
|
||||
transcript=transcript, event="DURATION", data=duration
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_waveform(self, data):
|
||||
waveform = TranscriptWaveform(waveform=data)
|
||||
async with self.transaction():
|
||||
waveform = TranscriptWaveform(waveform=data)
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript, event="WAVEFORM", data=waveform
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript, event="WAVEFORM", data=waveform
|
||||
)
|
||||
|
||||
|
||||
class PipelineMainLive(PipelineMainBase):
|
||||
@@ -283,7 +286,7 @@ class PipelineMainLive(PipelineMainBase):
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
||||
AudioWaveformProcessor(
|
||||
AudioWaveformProcessor.as_threaded(
|
||||
audio_path=transcript.audio_mp3_filename,
|
||||
waveform_path=transcript.audio_waveform_filename,
|
||||
on_waveform=self.on_waveform,
|
||||
|
||||
@@ -22,7 +22,7 @@ class AudioWaveformProcessor(Processor):
|
||||
self.audio_path = audio_path
|
||||
self.waveform_path = waveform_path
|
||||
|
||||
async def _push(self, _data):
|
||||
async def _flush(self):
|
||||
self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.logger.info("Waveform Processing Started")
|
||||
waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
|
||||
@@ -31,3 +31,6 @@ class AudioWaveformProcessor(Processor):
|
||||
json.dump(waveform, fd)
|
||||
self.logger.info("Waveform Processing Finished")
|
||||
await self.emit(waveform, name="waveform")
|
||||
|
||||
async def _push(_self, _data):
|
||||
return
|
||||
|
||||
@@ -118,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek(
|
||||
assert response.status_code == 206
|
||||
assert response.headers["content-type"] == content_type
|
||||
assert response.headers["content-range"].startswith("bytes 100-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_audio_download_waveform(fake_transcript):
|
||||
from reflector.app import app
|
||||
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "application/json"
|
||||
assert isinstance(response.json()["data"], list)
|
||||
assert len(response.json()["data"]) >= 255
|
||||
|
||||
@@ -183,8 +183,14 @@ async def test_transcript_rtc_and_websocket(
|
||||
assert ev["data"]["title"] == "LLM TITLE"
|
||||
|
||||
assert "WAVEFORM" in eventnames
|
||||
ev = events[eventnames.index("FINAL_TITLE")]
|
||||
assert ev["data"]["title"] == "LLM TITLE"
|
||||
ev = events[eventnames.index("WAVEFORM")]
|
||||
assert isinstance(ev["data"]["waveform"], list)
|
||||
assert len(ev["data"]["waveform"]) >= 250
|
||||
waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
|
||||
assert waveform_resp.status_code == 200
|
||||
assert waveform_resp.headers["content-type"] == "application/json"
|
||||
assert isinstance(waveform_resp.json()["data"], list)
|
||||
assert len(waveform_resp.json()["data"]) >= 250
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
@@ -197,11 +203,12 @@ async def test_transcript_rtc_and_websocket(
|
||||
|
||||
# check on the latest response that the audio duration is > 0
|
||||
assert resp.json()["duration"] > 0
|
||||
assert "DURATION" in eventnames
|
||||
|
||||
# check that audio/mp3 is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/mpeg"
|
||||
audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert audio_resp.status_code == 200
|
||||
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
|
||||
Reference in New Issue
Block a user