server: add dummy diarization and fixes instanciation

This commit is contained in:
2023-11-01 11:55:46 +01:00
committed by Mathieu Virbel
parent d0057ae2c4
commit 4da890b95f
8 changed files with 57 additions and 18 deletions

View File

@@ -28,6 +28,7 @@ from reflector.db.transcripts import (
TranscriptTopic, TranscriptTopic,
transcripts_controller, transcripts_controller,
) )
from reflector.logger import logger
from reflector.pipelines.runner import PipelineRunner from reflector.pipelines.runner import PipelineRunner
from reflector.processors import ( from reflector.processors import (
AudioChunkerProcessor, AudioChunkerProcessor,
@@ -241,7 +242,7 @@ class PipelineMainLive(PipelineMainBase):
AudioFileWriterProcessor(path=transcript.audio_mp3_filename), AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioChunkerProcessor(), AudioChunkerProcessor(),
AudioMergeProcessor(), AudioMergeProcessor(),
AudioTranscriptAutoProcessor.get_instance().as_threaded(), AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(), TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
@@ -255,11 +256,18 @@ class PipelineMainLive(PipelineMainBase):
pipeline.options = self pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language) pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language) pipeline.set_pref("audio:target_language", transcript.target_language)
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(
"Pipeline main live created",
transcript_id=self.transcript_id,
)
return pipeline return pipeline
async def on_ended(self): async def on_ended(self):
# when the pipeline ends, connect to the post pipeline # when the pipeline ends, connect to the post pipeline
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
task_pipeline_main_post.delay(transcript_id=self.transcript_id) task_pipeline_main_post.delay(transcript_id=self.transcript_id)
@@ -274,7 +282,7 @@ class PipelineMainDiarization(PipelineMainBase):
# add a customised logger to the context # add a customised logger to the context
self.prepare() self.prepare()
processors = [ processors = [
AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic), AudioDiarizationAutoProcessor(callback=self.on_topic),
BroadcastProcessor( BroadcastProcessor(
processors=[ processors=[
TranscriptFinalLongSummaryProcessor.as_threaded( TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -313,7 +321,10 @@ class PipelineMainDiarization(PipelineMainBase):
{"sub": transcript.user_id}, {"sub": transcript.user_id},
expires_delta=timedelta(minutes=15), expires_delta=timedelta(minutes=15),
) )
path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id) path = app.url_path_for(
"transcript_get_audio_mp3",
transcript_id=transcript.id,
)
url = f"{settings.BASE_URL}{path}?token={token}" url = f"{settings.BASE_URL}{path}?token={token}"
audio_diarization_input = AudioDiarizationInput( audio_diarization_input = AudioDiarizationInput(
audio_url=url, audio_url=url,
@@ -322,6 +333,10 @@ class PipelineMainDiarization(PipelineMainBase):
# as tempting to use pipeline.push, prefer to use the runner # as tempting to use pipeline.push, prefer to use the runner
# to let the start just do one job. # to let the start just do one job.
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(
"Pipeline main post created", transcript_id=self.transcript_id
)
self.push(audio_diarization_input) self.push(audio_diarization_input)
self.flush() self.flush()
@@ -330,5 +345,9 @@ class PipelineMainDiarization(PipelineMainBase):
@shared_task @shared_task
def task_pipeline_main_post(transcript_id: str): def task_pipeline_main_post(transcript_id: str):
logger.info(
"Starting main post pipeline",
transcript_id=transcript_id,
)
runner = PipelineMainDiarization(transcript_id=transcript_id) runner = PipelineMainDiarization(transcript_id=transcript_id)
runner.start_sync() runner.start_sync()

View File

@@ -55,11 +55,8 @@ class PipelineRunner(BaseModel):
""" """
Start the pipeline synchronously (for non-asyncio apps) Start the pipeline synchronously (for non-asyncio apps)
""" """
loop = asyncio.get_event_loop() coro = self.run()
if not loop: asyncio.run(coro)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.run())
def push(self, data): def push(self, data):
""" """

View File

@@ -2,7 +2,7 @@ from reflector.processors.base import Processor
from reflector.processors.types import AudioDiarizationInput, TitleSummary from reflector.processors.types import AudioDiarizationInput, TitleSummary
class AudioDiarizationBaseProcessor(Processor): class AudioDiarizationProcessor(Processor):
INPUT_TYPE = AudioDiarizationInput INPUT_TYPE = AudioDiarizationInput
OUTPUT_TYPE = TitleSummary OUTPUT_TYPE = TitleSummary

View File

@@ -1,18 +1,17 @@
import importlib import importlib
from reflector.processors.base import Processor from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.settings import settings from reflector.settings import settings
class AudioDiarizationAutoProcessor(Processor): class AudioDiarizationAutoProcessor(AudioDiarizationProcessor):
_registry = {} _registry = {}
@classmethod @classmethod
def register(cls, name, kclass): def register(cls, name, kclass):
cls._registry[name] = kclass cls._registry[name] = kclass
@classmethod def __new__(cls, name: str | None = None, **kwargs):
def get_instance(cls, name: str | None = None, **kwargs):
if name is None: if name is None:
name = settings.DIARIZATION_BACKEND name = settings.DIARIZATION_BACKEND

View File

@@ -1,11 +1,11 @@
import httpx import httpx
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor
from reflector.processors.types import AudioDiarizationInput, TitleSummary from reflector.processors.types import AudioDiarizationInput, TitleSummary
from reflector.settings import settings from reflector.settings import settings
class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor): class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
INPUT_TYPE = AudioDiarizationInput INPUT_TYPE = AudioDiarizationInput
OUTPUT_TYPE = TitleSummary OUTPUT_TYPE = TitleSummary

View File

@@ -11,8 +11,7 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
def register(cls, name, kclass): def register(cls, name, kclass):
cls._registry[name] = kclass cls._registry[name] = kclass
@classmethod def __new__(cls, name: str | None = None, **kwargs):
def get_instance(cls, name: str | None = None, **kwargs):
if name is None: if name is None:
name = settings.TRANSCRIPT_BACKEND name = settings.TRANSCRIPT_BACKEND
if name not in cls._registry: if name not in cls._registry:

View File

@@ -60,12 +60,35 @@ async def dummy_transcript():
with patch( with patch(
"reflector.processors.audio_transcript_auto" "reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.get_instance" ".AudioTranscriptAutoProcessor.__new__"
) as mock_audio: ) as mock_audio:
mock_audio.return_value = TestAudioTranscriptProcessor() mock_audio.return_value = TestAudioTranscriptProcessor()
yield yield
@pytest.fixture
async def dummy_diarization():
from reflector.processors.audio_diarization import AudioDiarizationProcessor
class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
_time_idx = 0
async def _diarize(self, data):
i = self._time_idx
self._time_idx += 2
return [
{"start": i, "end": i + 1, "speaker": 0},
{"start": i + 1, "end": i + 2, "speaker": 1},
]
with patch(
"reflector.processors.audio_diarization_auto"
".AudioDiarizationAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioDiarizationProcessor()
yield
@pytest.fixture @pytest.fixture
async def dummy_llm(): async def dummy_llm():
from reflector.llm.base import LLM from reflector.llm.base import LLM

View File

@@ -65,6 +65,7 @@ async def test_transcript_rtc_and_websocket(
dummy_llm, dummy_llm,
dummy_transcript, dummy_transcript,
dummy_processors, dummy_processors,
dummy_diarization,
ensure_casing, ensure_casing,
appserver, appserver,
sentence_tokenize, sentence_tokenize,
@@ -204,6 +205,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_llm, dummy_llm,
dummy_transcript, dummy_transcript,
dummy_processors, dummy_processors,
dummy_diarization,
ensure_casing, ensure_casing,
appserver, appserver,
sentence_tokenize, sentence_tokenize,