fix waveform in pipeline

This commit is contained in:
Sara
2023-11-17 13:38:32 +01:00
parent 8b8e92ceac
commit a846e38fbd
4 changed files with 25 additions and 24 deletions

View File

@@ -233,6 +233,7 @@ class PipelineMainBase(PipelineRunner):
data=final_short_summary,
)
@broadcast_to_sockets
async def on_duration(self, data):
async with self.transaction():
duration = TranscriptDuration(duration=data)
@@ -248,14 +249,16 @@ class PipelineMainBase(PipelineRunner):
transcript=transcript, event="DURATION", data=duration
)
@broadcast_to_sockets
async def on_waveform(self, data):
waveform = TranscriptWaveform(waveform=data)
async with self.transaction():
waveform = TranscriptWaveform(waveform=data)
transcript = await self.get_transcript()
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript, event="WAVEFORM", data=waveform
)
return await transcripts_controller.append_event(
transcript=transcript, event="WAVEFORM", data=waveform
)
class PipelineMainLive(PipelineMainBase):
@@ -283,7 +286,7 @@ class PipelineMainLive(PipelineMainBase):
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
AudioWaveformProcessor(
AudioWaveformProcessor.as_threaded(
audio_path=transcript.audio_mp3_filename,
waveform_path=transcript.audio_waveform_filename,
on_waveform=self.on_waveform,

View File

@@ -22,7 +22,7 @@ class AudioWaveformProcessor(Processor):
self.audio_path = audio_path
self.waveform_path = waveform_path
async def _push(self, _data):
async def _flush(self):
self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info("Waveform Processing Started")
waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
@@ -31,3 +31,6 @@ class AudioWaveformProcessor(Processor):
json.dump(waveform, fd)
self.logger.info("Waveform Processing Finished")
await self.emit(waveform, name="waveform")
async def _push(_self, _data):
return

View File

@@ -118,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek(
assert response.status_code == 206
assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 100-")
@pytest.mark.asyncio
async def test_transcript_audio_download_waveform(fake_transcript):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
assert isinstance(response.json()["data"], list)
assert len(response.json()["data"]) >= 255

View File

@@ -183,8 +183,14 @@ async def test_transcript_rtc_and_websocket(
assert ev["data"]["title"] == "LLM TITLE"
assert "WAVEFORM" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
ev = events[eventnames.index("WAVEFORM")]
assert isinstance(ev["data"]["waveform"], list)
assert len(ev["data"]["waveform"]) >= 250
waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
assert waveform_resp.status_code == 200
assert waveform_resp.headers["content-type"] == "application/json"
assert isinstance(waveform_resp.json()["data"], list)
assert len(waveform_resp.json()["data"]) >= 250
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
@@ -197,11 +203,12 @@ async def test_transcript_rtc_and_websocket(
# check on the latest response that the audio duration is > 0
assert resp.json()["duration"] > 0
assert "DURATION" in eventnames
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mpeg"
audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert audio_resp.status_code == 200
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@pytest.mark.usefixtures("celery_session_app")