mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: add dummy diarization and fixes instanciation
This commit is contained in:
@@ -28,6 +28,7 @@ from reflector.db.transcripts import (
|
|||||||
TranscriptTopic,
|
TranscriptTopic,
|
||||||
transcripts_controller,
|
transcripts_controller,
|
||||||
)
|
)
|
||||||
|
from reflector.logger import logger
|
||||||
from reflector.pipelines.runner import PipelineRunner
|
from reflector.pipelines.runner import PipelineRunner
|
||||||
from reflector.processors import (
|
from reflector.processors import (
|
||||||
AudioChunkerProcessor,
|
AudioChunkerProcessor,
|
||||||
@@ -241,7 +242,7 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
|
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
|
||||||
AudioChunkerProcessor(),
|
AudioChunkerProcessor(),
|
||||||
AudioMergeProcessor(),
|
AudioMergeProcessor(),
|
||||||
AudioTranscriptAutoProcessor.get_instance().as_threaded(),
|
AudioTranscriptAutoProcessor.as_threaded(),
|
||||||
TranscriptLinerProcessor(),
|
TranscriptLinerProcessor(),
|
||||||
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
||||||
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
||||||
@@ -255,11 +256,18 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
pipeline.options = self
|
pipeline.options = self
|
||||||
pipeline.set_pref("audio:source_language", transcript.source_language)
|
pipeline.set_pref("audio:source_language", transcript.source_language)
|
||||||
pipeline.set_pref("audio:target_language", transcript.target_language)
|
pipeline.set_pref("audio:target_language", transcript.target_language)
|
||||||
|
pipeline.logger.bind(transcript_id=transcript.id)
|
||||||
|
pipeline.logger.info(
|
||||||
|
"Pipeline main live created",
|
||||||
|
transcript_id=self.transcript_id,
|
||||||
|
)
|
||||||
|
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
async def on_ended(self):
|
async def on_ended(self):
|
||||||
# when the pipeline ends, connect to the post pipeline
|
# when the pipeline ends, connect to the post pipeline
|
||||||
|
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
|
||||||
|
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
|
||||||
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
|
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -274,7 +282,7 @@ class PipelineMainDiarization(PipelineMainBase):
|
|||||||
# add a customised logger to the context
|
# add a customised logger to the context
|
||||||
self.prepare()
|
self.prepare()
|
||||||
processors = [
|
processors = [
|
||||||
AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic),
|
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||||
BroadcastProcessor(
|
BroadcastProcessor(
|
||||||
processors=[
|
processors=[
|
||||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||||
@@ -313,7 +321,10 @@ class PipelineMainDiarization(PipelineMainBase):
|
|||||||
{"sub": transcript.user_id},
|
{"sub": transcript.user_id},
|
||||||
expires_delta=timedelta(minutes=15),
|
expires_delta=timedelta(minutes=15),
|
||||||
)
|
)
|
||||||
path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id)
|
path = app.url_path_for(
|
||||||
|
"transcript_get_audio_mp3",
|
||||||
|
transcript_id=transcript.id,
|
||||||
|
)
|
||||||
url = f"{settings.BASE_URL}{path}?token={token}"
|
url = f"{settings.BASE_URL}{path}?token={token}"
|
||||||
audio_diarization_input = AudioDiarizationInput(
|
audio_diarization_input = AudioDiarizationInput(
|
||||||
audio_url=url,
|
audio_url=url,
|
||||||
@@ -322,6 +333,10 @@ class PipelineMainDiarization(PipelineMainBase):
|
|||||||
|
|
||||||
# as tempting to use pipeline.push, prefer to use the runner
|
# as tempting to use pipeline.push, prefer to use the runner
|
||||||
# to let the start just do one job.
|
# to let the start just do one job.
|
||||||
|
pipeline.logger.bind(transcript_id=transcript.id)
|
||||||
|
pipeline.logger.info(
|
||||||
|
"Pipeline main post created", transcript_id=self.transcript_id
|
||||||
|
)
|
||||||
self.push(audio_diarization_input)
|
self.push(audio_diarization_input)
|
||||||
self.flush()
|
self.flush()
|
||||||
|
|
||||||
@@ -330,5 +345,9 @@ class PipelineMainDiarization(PipelineMainBase):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
def task_pipeline_main_post(transcript_id: str):
|
def task_pipeline_main_post(transcript_id: str):
|
||||||
|
logger.info(
|
||||||
|
"Starting main post pipeline",
|
||||||
|
transcript_id=transcript_id,
|
||||||
|
)
|
||||||
runner = PipelineMainDiarization(transcript_id=transcript_id)
|
runner = PipelineMainDiarization(transcript_id=transcript_id)
|
||||||
runner.start_sync()
|
runner.start_sync()
|
||||||
|
|||||||
@@ -55,11 +55,8 @@ class PipelineRunner(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Start the pipeline synchronously (for non-asyncio apps)
|
Start the pipeline synchronously (for non-asyncio apps)
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_event_loop()
|
coro = self.run()
|
||||||
if not loop:
|
asyncio.run(coro)
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
loop.run_until_complete(self.run())
|
|
||||||
|
|
||||||
def push(self, data):
|
def push(self, data):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from reflector.processors.base import Processor
|
|||||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||||
|
|
||||||
|
|
||||||
class AudioDiarizationBaseProcessor(Processor):
|
class AudioDiarizationProcessor(Processor):
|
||||||
INPUT_TYPE = AudioDiarizationInput
|
INPUT_TYPE = AudioDiarizationInput
|
||||||
OUTPUT_TYPE = TitleSummary
|
OUTPUT_TYPE = TitleSummary
|
||||||
|
|
||||||
@@ -1,18 +1,17 @@
|
|||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
from reflector.processors.base import Processor
|
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
class AudioDiarizationAutoProcessor(Processor):
|
class AudioDiarizationAutoProcessor(AudioDiarizationProcessor):
|
||||||
_registry = {}
|
_registry = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, name, kclass):
|
def register(cls, name, kclass):
|
||||||
cls._registry[name] = kclass
|
cls._registry[name] = kclass
|
||||||
|
|
||||||
@classmethod
|
def __new__(cls, name: str | None = None, **kwargs):
|
||||||
def get_instance(cls, name: str | None = None, **kwargs):
|
|
||||||
if name is None:
|
if name is None:
|
||||||
name = settings.DIARIZATION_BACKEND
|
name = settings.DIARIZATION_BACKEND
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import httpx
|
import httpx
|
||||||
|
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||||
from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor
|
|
||||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor):
|
class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
|
||||||
INPUT_TYPE = AudioDiarizationInput
|
INPUT_TYPE = AudioDiarizationInput
|
||||||
OUTPUT_TYPE = TitleSummary
|
OUTPUT_TYPE = TitleSummary
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
|||||||
def register(cls, name, kclass):
|
def register(cls, name, kclass):
|
||||||
cls._registry[name] = kclass
|
cls._registry[name] = kclass
|
||||||
|
|
||||||
@classmethod
|
def __new__(cls, name: str | None = None, **kwargs):
|
||||||
def get_instance(cls, name: str | None = None, **kwargs):
|
|
||||||
if name is None:
|
if name is None:
|
||||||
name = settings.TRANSCRIPT_BACKEND
|
name = settings.TRANSCRIPT_BACKEND
|
||||||
if name not in cls._registry:
|
if name not in cls._registry:
|
||||||
|
|||||||
@@ -60,12 +60,35 @@ async def dummy_transcript():
|
|||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"reflector.processors.audio_transcript_auto"
|
"reflector.processors.audio_transcript_auto"
|
||||||
".AudioTranscriptAutoProcessor.get_instance"
|
".AudioTranscriptAutoProcessor.__new__"
|
||||||
) as mock_audio:
|
) as mock_audio:
|
||||||
mock_audio.return_value = TestAudioTranscriptProcessor()
|
mock_audio.return_value = TestAudioTranscriptProcessor()
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def dummy_diarization():
|
||||||
|
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||||
|
|
||||||
|
class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
|
||||||
|
_time_idx = 0
|
||||||
|
|
||||||
|
async def _diarize(self, data):
|
||||||
|
i = self._time_idx
|
||||||
|
self._time_idx += 2
|
||||||
|
return [
|
||||||
|
{"start": i, "end": i + 1, "speaker": 0},
|
||||||
|
{"start": i + 1, "end": i + 2, "speaker": 1},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.processors.audio_diarization_auto"
|
||||||
|
".AudioDiarizationAutoProcessor.__new__"
|
||||||
|
) as mock_audio:
|
||||||
|
mock_audio.return_value = TestAudioDiarizationProcessor()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def dummy_llm():
|
async def dummy_llm():
|
||||||
from reflector.llm.base import LLM
|
from reflector.llm.base import LLM
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
dummy_llm,
|
dummy_llm,
|
||||||
dummy_transcript,
|
dummy_transcript,
|
||||||
dummy_processors,
|
dummy_processors,
|
||||||
|
dummy_diarization,
|
||||||
ensure_casing,
|
ensure_casing,
|
||||||
appserver,
|
appserver,
|
||||||
sentence_tokenize,
|
sentence_tokenize,
|
||||||
@@ -204,6 +205,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
dummy_llm,
|
dummy_llm,
|
||||||
dummy_transcript,
|
dummy_transcript,
|
||||||
dummy_processors,
|
dummy_processors,
|
||||||
|
dummy_diarization,
|
||||||
ensure_casing,
|
ensure_casing,
|
||||||
appserver,
|
appserver,
|
||||||
sentence_tokenize,
|
sentence_tokenize,
|
||||||
|
|||||||
Reference in New Issue
Block a user