diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 83b57949..b182f421 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -332,12 +332,6 @@ class PipelineMainLive(PipelineMainBase): TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), - # XXX move as a task - AudioWaveformProcessor.as_threaded( - audio_path=transcript.audio_mp3_filename, - waveform_path=transcript.audio_waveform_filename, - on_waveform=self.on_waveform, - ), ] pipeline = Pipeline(*processors) pipeline.options = self @@ -406,12 +400,14 @@ class PipelineMainFromTopics(PipelineMainBase): async def create(self) -> Pipeline: self.prepare() + + # get transcript + self._transcript = transcript = await self.get_transcript() + + # create pipeline processors = self.get_processors() pipeline = Pipeline(*processors) pipeline.options = self - - # get transcript - transcript = await self.get_transcript() pipeline.logger.bind(transcript_id=transcript.id) pipeline.logger.info(f"{self.__class__.__name__} pipeline created") @@ -463,6 +459,29 @@ class PipelineMainFinalSummaries(PipelineMainFromTopics): ] +class PipelineMainWaveform(PipelineMainFromTopics): + """ + Generate waveform + """ + + def get_processors(self) -> list: + return [ + AudioWaveformProcessor.as_threaded( + audio_path=self._transcript.audio_wav_filename, + waveform_path=self._transcript.audio_waveform_filename, + on_waveform=self.on_waveform, + ), + ] + + +@get_transcript +async def pipeline_waveform(transcript: Transcript, logger: Logger): + logger.info("Starting waveform") + runner = PipelineMainWaveform(transcript_id=transcript.id) + await runner.run() + logger.info("Waveform done") + + @get_transcript async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger): logger.info("Starting convert to mp3") @@ -541,6 +560,12 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger): # =================================================================== +@shared_task +@asynctask +async def task_pipeline_waveform(*, transcript_id: str): + await pipeline_waveform(transcript_id=transcript_id) + + @shared_task @asynctask async def task_pipeline_convert_to_mp3(*, transcript_id: str): @@ -576,7 +601,8 @@ def pipeline_post(*, transcript_id: str): Run the post pipeline """ chain_mp3_and_diarize = ( - task_pipeline_convert_to_mp3.si(transcript_id=transcript_id) + task_pipeline_waveform.si(transcript_id=transcript_id) + | task_pipeline_convert_to_mp3.si(transcript_id=transcript_id) | task_pipeline_upload_mp3.si(transcript_id=transcript_id) | task_pipeline_diarization.si(transcript_id=transcript_id) ) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 7496b26c..125aa311 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,12 +1,14 @@ from datetime import datetime, timedelta from typing import Annotated, Optional +import httpx import reflector.auth as auth from fastapi import ( APIRouter, Depends, HTTPException, Request, + Response, WebSocket, WebSocketDisconnect, status, @@ -234,10 +236,22 @@ async def transcript_get_audio_mp3( raise HTTPException(status_code=404, detail="Transcript not found") if transcript.audio_location == "storage": - url = transcript.get_audio_url() - from fastapi.responses import RedirectResponse + # proxy S3 file, to prevent issue with CORS + url = await transcript.get_audio_url() + headers = {} - return RedirectResponse(url=url, status_code=status.HTTP_302_FOUND) + copy_headers = ["range", "accept-encoding"] + for header in copy_headers: + if header in request.headers: + headers[header] = request.headers[header] + + async with httpx.AsyncClient() as client: + resp = await client.request(request.method, url, headers=headers) + return Response( + content=resp.content, + status_code=resp.status_code, + headers=resp.headers, + ) if not transcript.audio_mp3_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") @@ -263,7 +277,7 @@ async def transcript_get_audio_waveform( if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - if not transcript.audio_mp3_filename.exists(): + if not transcript.audio_waveform_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") return transcript.audio_waveform