From 4da890b95fc2952e3ad8b679036e9a88632ea09a Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 1 Nov 2023 11:55:46 +0100 Subject: [PATCH] server: add dummy diarization and fixes instantiation --- .../reflector/pipelines/main_live_pipeline.py | 25 ++++++++++++++++--- server/reflector/pipelines/runner.py | 7 ++---- ...arization_base.py => audio_diarization.py} | 2 +- .../processors/audio_diarization_auto.py | 7 +++--- .../processors/audio_diarization_modal.py | 4 +-- .../processors/audio_transcript_auto.py | 3 +-- server/tests/conftest.py | 25 ++++++++++++++++++- server/tests/test_transcripts_rtc_ws.py | 2 ++ 8 files changed, 57 insertions(+), 18 deletions(-) rename server/reflector/processors/{audio_diarization_base.py => audio_diarization.py} (95%) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 87e2ff46..88e1bffd 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -28,6 +28,7 @@ from reflector.db.transcripts import ( TranscriptTopic, transcripts_controller, ) +from reflector.logger import logger from reflector.pipelines.runner import PipelineRunner from reflector.processors import ( AudioChunkerProcessor, @@ -241,7 +242,7 @@ class PipelineMainLive(PipelineMainBase): AudioFileWriterProcessor(path=transcript.audio_mp3_filename), AudioChunkerProcessor(), AudioMergeProcessor(), - AudioTranscriptAutoProcessor.get_instance().as_threaded(), + AudioTranscriptAutoProcessor.as_threaded(), TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), @@ -255,11 +256,18 @@ class PipelineMainLive(PipelineMainBase): pipeline.options = self pipeline.set_pref("audio:source_language", transcript.source_language) pipeline.set_pref("audio:target_language", transcript.target_language) + 
pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info( + "Pipeline main live created", + transcript_id=self.transcript_id, + ) return pipeline async def on_ended(self): # when the pipeline ends, connect to the post pipeline + logger.info("Pipeline main live ended", transcript_id=self.transcript_id) + logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id) task_pipeline_main_post.delay(transcript_id=self.transcript_id) @@ -274,7 +282,7 @@ class PipelineMainDiarization(PipelineMainBase): # add a customised logger to the context self.prepare() processors = [ - AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic), + AudioDiarizationAutoProcessor(callback=self.on_topic), BroadcastProcessor( processors=[ TranscriptFinalLongSummaryProcessor.as_threaded( @@ -313,7 +321,10 @@ class PipelineMainDiarization(PipelineMainBase): {"sub": transcript.user_id}, expires_delta=timedelta(minutes=15), ) - path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id) + path = app.url_path_for( + "transcript_get_audio_mp3", + transcript_id=transcript.id, + ) url = f"{settings.BASE_URL}{path}?token={token}" audio_diarization_input = AudioDiarizationInput( audio_url=url, @@ -322,6 +333,10 @@ class PipelineMainDiarization(PipelineMainBase): # as tempting to use pipeline.push, prefer to use the runner # to let the start just do one job. 
+ pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info( + "Pipeline main post created", transcript_id=self.transcript_id + ) self.push(audio_diarization_input) self.flush() @@ -330,5 +345,9 @@ class PipelineMainDiarization(PipelineMainBase): @shared_task def task_pipeline_main_post(transcript_id: str): + logger.info( + "Starting main post pipeline", + transcript_id=transcript_id, + ) runner = PipelineMainDiarization(transcript_id=transcript_id) runner.start_sync() diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index 583cdcb6..a1e137a7 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -55,11 +55,8 @@ class PipelineRunner(BaseModel): """ Start the pipeline synchronously (for non-asyncio apps) """ - loop = asyncio.get_event_loop() - if not loop: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self.run()) + coro = self.run() + asyncio.run(coro) def push(self, data): """ diff --git a/server/reflector/processors/audio_diarization_base.py b/server/reflector/processors/audio_diarization.py similarity index 95% rename from server/reflector/processors/audio_diarization_base.py rename to server/reflector/processors/audio_diarization.py index 2ad7e4bf..d69f4b80 100644 --- a/server/reflector/processors/audio_diarization_base.py +++ b/server/reflector/processors/audio_diarization.py @@ -2,7 +2,7 @@ from reflector.processors.base import Processor from reflector.processors.types import AudioDiarizationInput, TitleSummary -class AudioDiarizationBaseProcessor(Processor): +class AudioDiarizationProcessor(Processor): INPUT_TYPE = AudioDiarizationInput OUTPUT_TYPE = TitleSummary diff --git a/server/reflector/processors/audio_diarization_auto.py b/server/reflector/processors/audio_diarization_auto.py index 1de19b45..0e7bfc5c 100644 --- a/server/reflector/processors/audio_diarization_auto.py +++ 
b/server/reflector/processors/audio_diarization_auto.py @@ -1,18 +1,17 @@ import importlib -from reflector.processors.base import Processor +from reflector.processors.audio_diarization import AudioDiarizationProcessor from reflector.settings import settings -class AudioDiarizationAutoProcessor(Processor): +class AudioDiarizationAutoProcessor(AudioDiarizationProcessor): _registry = {} @classmethod def register(cls, name, kclass): cls._registry[name] = kclass - @classmethod - def get_instance(cls, name: str | None = None, **kwargs): + def __new__(cls, name: str | None = None, **kwargs): if name is None: name = settings.DIARIZATION_BACKEND diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py index b71dbcc9..52be7c5d 100644 --- a/server/reflector/processors/audio_diarization_modal.py +++ b/server/reflector/processors/audio_diarization_modal.py @@ -1,11 +1,11 @@ import httpx +from reflector.processors.audio_diarization import AudioDiarizationProcessor from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor -from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor from reflector.processors.types import AudioDiarizationInput, TitleSummary from reflector.settings import settings -class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor): +class AudioDiarizationModalProcessor(AudioDiarizationProcessor): INPUT_TYPE = AudioDiarizationInput OUTPUT_TYPE = TitleSummary diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py index fc1f0b5e..ac79ced0 100644 --- a/server/reflector/processors/audio_transcript_auto.py +++ b/server/reflector/processors/audio_transcript_auto.py @@ -11,8 +11,7 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor): def register(cls, name, kclass): cls._registry[name] = kclass - @classmethod - def get_instance(cls, name: str | None = None, 
**kwargs): + def __new__(cls, name: str | None = None, **kwargs): if name is None: name = settings.TRANSCRIPT_BACKEND if name not in cls._registry: diff --git a/server/tests/conftest.py b/server/tests/conftest.py index d5f5f0b9..aafca9fd 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -60,12 +60,35 @@ async def dummy_transcript(): with patch( "reflector.processors.audio_transcript_auto" - ".AudioTranscriptAutoProcessor.get_instance" + ".AudioTranscriptAutoProcessor.__new__" ) as mock_audio: mock_audio.return_value = TestAudioTranscriptProcessor() yield +@pytest.fixture +async def dummy_diarization(): + from reflector.processors.audio_diarization import AudioDiarizationProcessor + + class TestAudioDiarizationProcessor(AudioDiarizationProcessor): + _time_idx = 0 + + async def _diarize(self, data): + i = self._time_idx + self._time_idx += 2 + return [ + {"start": i, "end": i + 1, "speaker": 0}, + {"start": i + 1, "end": i + 2, "speaker": 1}, + ] + + with patch( + "reflector.processors.audio_diarization_auto" + ".AudioDiarizationAutoProcessor.__new__" + ) as mock_audio: + mock_audio.return_value = TestAudioDiarizationProcessor() + yield + + @pytest.fixture async def dummy_llm(): from reflector.llm.base import LLM diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 5a9a404b..8f8cac71 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -65,6 +65,7 @@ async def test_transcript_rtc_and_websocket( dummy_llm, dummy_transcript, dummy_processors, + dummy_diarization, ensure_casing, appserver, sentence_tokenize, @@ -204,6 +205,7 @@ async def test_transcript_rtc_and_websocket_and_fr( dummy_llm, dummy_transcript, dummy_processors, + dummy_diarization, ensure_casing, appserver, sentence_tokenize,