From d8a842f099091ad1ad19b934a4ff1fadb3003a95 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 27 Oct 2023 20:00:07 +0200
Subject: [PATCH] server: full diarization processor implementation based on
 gokul app

---
 server/reflector/app.py                          |  3 +
 .../reflector/pipelines/main_live_pipeline.py    | 30 +++++----
 server/reflector/pipelines/runner.py             |  6 +-
 server/reflector/processors/__init__.py          |  2 +-
 .../reflector/processors/audio_diarization.py    | 65 -------------------
 .../processors/audio_diarization_auto.py         | 34 ++++++++++
 .../processors/audio_diarization_base.py         | 28 ++++++++
 .../processors/audio_diarization_modal.py        | 36 ++++++++++
 .../processors/audio_transcript_auto.py          | 34 ++--------
 server/reflector/processors/types.py             |  2 +-
 server/reflector/settings.py                     | 10 +++
 .../tools/start_post_main_live_pipeline.py       | 14 ++++
 server/reflector/views/transcripts.py            | 29 ++++++++-
 server/reflector/worker/app.py                   |  1 +
 server/tests/test_transcripts_rtc_ws.py          |  2 +
 15 files changed, 186 insertions(+), 110 deletions(-)
 delete mode 100644 server/reflector/processors/audio_diarization.py
 create mode 100644 server/reflector/processors/audio_diarization_auto.py
 create mode 100644 server/reflector/processors/audio_diarization_base.py
 create mode 100644 server/reflector/processors/audio_diarization_modal.py
 create mode 100644 server/reflector/tools/start_post_main_live_pipeline.py

diff --git a/server/reflector/app.py b/server/reflector/app.py
index 758faf69..c2e3bf7e 100644
--- a/server/reflector/app.py
+++ b/server/reflector/app.py
@@ -64,6 +64,9 @@ app.include_router(transcripts_router, prefix="/v1")
 app.include_router(user_router, prefix="/v1")
 add_pagination(app)
 
+# prepare celery
+from reflector.worker import app as celery_app  # noqa
+
 
 # simpler openapi id
 def use_route_names_as_operation_ids(app: FastAPI) -> None:
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 4159c889..87e2ff46 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -13,10 +13,12 @@ It is directly linked to our data model.
 import asyncio
 from contextlib import asynccontextmanager
+from datetime import timedelta
 from pathlib import Path
 
 from celery import shared_task
 from pydantic import BaseModel
 
+from reflector.app import app
 from reflector.db.transcripts import (
     Transcript,
     TranscriptFinalLongSummary,
@@ -29,7 +31,7 @@ from reflector.db.transcripts import (
 from reflector.pipelines.runner import PipelineRunner
 from reflector.processors import (
     AudioChunkerProcessor,
-    AudioDiarizationProcessor,
+    AudioDiarizationAutoProcessor,
     AudioFileWriterProcessor,
     AudioMergeProcessor,
     AudioTranscriptAutoProcessor,
@@ -45,6 +47,7 @@ from reflector.processors import (
 from reflector.processors.types import AudioDiarizationInput
 from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
 from reflector.processors.types import Transcript as TranscriptProcessorType
+from reflector.settings import settings
 from reflector.ws_manager import WebsocketManager, get_ws_manager
 
 
@@ -174,7 +177,7 @@ class PipelineMainBase(PipelineRunner):
         async with self.transaction():
             transcript = await self.get_transcript()
             if not transcript.title:
-                transcripts_controller.update(
+                await transcripts_controller.update(
                     transcript,
                     {
                         "title": final_title.title,
@@ -238,19 +241,13 @@ class PipelineMainLive(PipelineMainBase):
             AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
             AudioChunkerProcessor(),
             AudioMergeProcessor(),
-            AudioTranscriptAutoProcessor.as_threaded(),
+            AudioTranscriptAutoProcessor.get_instance().as_threaded(),
             TranscriptLinerProcessor(),
             TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
             TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
             BroadcastProcessor(
                 processors=[
                     TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
-                    TranscriptFinalLongSummaryProcessor.as_threaded(
-                        callback=self.on_long_summary
-                    ),
-                    TranscriptFinalShortSummaryProcessor.as_threaded(
-                        callback=self.on_short_summary
-                    ),
                 ]
             ),
         ]
@@ -277,7 +274,7 @@ class PipelineMainDiarization(PipelineMainBase):
         # add a customised logger to the context
         self.prepare()
         processors = [
-            AudioDiarizationProcessor(),
+            AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic),
             BroadcastProcessor(
                 processors=[
                     TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -307,8 +304,19 @@ class PipelineMainDiarization(PipelineMainBase):
             for topic in transcript.topics
         ]
 
+        # we need to create an url to be used for diarization
+        # we can't use the audio_mp3_filename because it's not accessible
+        # from the diarization processor
+        from reflector.views.transcripts import create_access_token
+
+        token = create_access_token(
+            {"sub": transcript.user_id},
+            expires_delta=timedelta(minutes=15),
+        )
+        path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id)
+        url = f"{settings.BASE_URL}{path}?token={token}"
         audio_diarization_input = AudioDiarizationInput(
-            audio_filename=transcript.audio_mp3_filename,
+            audio_url=url,
             topics=topics,
         )
 
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index 0575cf96..583cdcb6 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -55,7 +55,11 @@ class PipelineRunner(BaseModel):
         """
         Start the pipeline synchronously (for non-asyncio apps)
         """
-        asyncio.run(self.run())
+        loop = asyncio.get_event_loop()
+        if not loop:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+        loop.run_until_complete(self.run())
 
     def push(self, data):
         """
diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py
index 01a3a174..1c88d6c5 100644
--- a/server/reflector/processors/__init__.py
+++ b/server/reflector/processors/__init__.py
@@ -1,5 +1,5 @@
 from .audio_chunker import AudioChunkerProcessor  # noqa: F401
-from .audio_diarization import AudioDiarizationProcessor  # noqa: F401
+from .audio_diarization_auto import AudioDiarizationAutoProcessor  # noqa: F401
 from .audio_file_writer import AudioFileWriterProcessor  # noqa: F401
 from .audio_merge import AudioMergeProcessor  # noqa: F401
 from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py
deleted file mode 100644
index 8db8e8e5..00000000
--- a/server/reflector/processors/audio_diarization.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import AudioDiarizationInput, TitleSummary
-
-
-class AudioDiarizationProcessor(Processor):
-    INPUT_TYPE = AudioDiarizationInput
-    OUTPUT_TYPE = TitleSummary
-
-    async def _push(self, data: AudioDiarizationInput):
-        # Gather diarization data
-        diarization = [
-            {"start": 0.0, "stop": 4.9, "speaker": 2},
-            {"start": 5.6, "stop": 6.7, "speaker": 2},
-            {"start": 7.3, "stop": 8.9, "speaker": 2},
-            {"start": 7.3, "stop": 7.9, "speaker": 0},
-            {"start": 9.4, "stop": 11.2, "speaker": 2},
-            {"start": 9.7, "stop": 10.0, "speaker": 0},
-            {"start": 10.0, "stop": 10.1, "speaker": 0},
-            {"start": 11.7, "stop": 16.1, "speaker": 2},
-            {"start": 11.8, "stop": 12.1, "speaker": 1},
-            {"start": 16.4, "stop": 21.0, "speaker": 2},
-            {"start": 21.1, "stop": 22.6, "speaker": 2},
-            {"start": 24.7, "stop": 31.9, "speaker": 2},
-            {"start": 32.0, "stop": 32.8, "speaker": 1},
-            {"start": 33.4, "stop": 37.8, "speaker": 2},
-            {"start": 37.9, "stop": 40.3, "speaker": 0},
-            {"start": 39.2, "stop": 40.4, "speaker": 2},
-            {"start": 40.7, "stop": 41.4, "speaker": 0},
-            {"start": 41.6, "stop": 45.7, "speaker": 2},
-            {"start": 46.4, "stop": 53.1, "speaker": 2},
-            {"start": 53.6, "stop": 56.5, "speaker": 2},
-            {"start": 54.9, "stop": 75.4, "speaker": 1},
-            {"start": 57.3, "stop": 58.0, "speaker": 2},
-            {"start": 65.7, "stop": 66.0, "speaker": 2},
-            {"start": 75.8, "stop": 78.8, "speaker": 1},
-            {"start": 79.0, "stop": 82.6, "speaker": 1},
-            {"start": 83.2, "stop": 83.3, "speaker": 1},
-            {"start": 84.5, "stop": 94.3, "speaker": 1},
-            {"start": 95.1, "stop": 100.7, "speaker": 1},
-            {"start": 100.7, "stop": 102.0, "speaker": 0},
-            {"start": 100.7, "stop": 101.8, "speaker": 1},
-            {"start": 102.0, "stop": 103.0, "speaker": 1},
-            {"start": 103.0, "stop": 103.7, "speaker": 0},
-            {"start": 103.7, "stop": 103.8, "speaker": 1},
-            {"start": 103.8, "stop": 113.9, "speaker": 0},
-            {"start": 114.7, "stop": 117.0, "speaker": 0},
-            {"start": 117.0, "stop": 117.4, "speaker": 1},
-        ]
-
-        # now reapply speaker to topics (if any)
-        # topics is a list[BaseModel] with an attribute words
-        # words is a list[BaseModel] with text, start and speaker attribute
-
-        print("IN DIARIZATION PROCESSOR", data)
-
-        # mutate in place
-        for topic in data.topics:
-            for word in topic.transcript.words:
-                for d in diarization:
-                    if d["start"] <= word.start <= d["stop"]:
-                        word.speaker = d["speaker"]
-
-        # emit them
-        for topic in data.topics:
-            await self.emit(topic)
diff --git a/server/reflector/processors/audio_diarization_auto.py b/server/reflector/processors/audio_diarization_auto.py
new file mode 100644
index 00000000..1de19b45
--- /dev/null
+++ b/server/reflector/processors/audio_diarization_auto.py
@@ -0,0 +1,34 @@
+import importlib
+
+from reflector.processors.base import Processor
+from reflector.settings import settings
+
+
+class AudioDiarizationAutoProcessor(Processor):
+    _registry = {}
+
+    @classmethod
+    def register(cls, name, kclass):
+        cls._registry[name] = kclass
+
+    @classmethod
+    def get_instance(cls, name: str | None = None, **kwargs):
+        if name is None:
+            name = settings.DIARIZATION_BACKEND
+
+        if name not in cls._registry:
+            module_name = f"reflector.processors.audio_diarization_{name}"
+            importlib.import_module(module_name)
+
+        # gather specific configuration for the processor
+        # search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
+        config = {}
+        name_upper = name.upper()
+        settings_prefix = "DIARIZATION_"
+        config_prefix = f"{settings_prefix}{name_upper}_"
+        for key, value in settings:
+            if key.startswith(config_prefix):
+                config_name = key[len(settings_prefix) :].lower()
+                config[config_name] = value
+
+        return cls._registry[name](**config | kwargs)
diff --git a/server/reflector/processors/audio_diarization_base.py b/server/reflector/processors/audio_diarization_base.py
new file mode 100644
index 00000000..2ad7e4bf
--- /dev/null
+++ b/server/reflector/processors/audio_diarization_base.py
@@ -0,0 +1,28 @@
+from reflector.processors.base import Processor
+from reflector.processors.types import AudioDiarizationInput, TitleSummary
+
+
+class AudioDiarizationBaseProcessor(Processor):
+    INPUT_TYPE = AudioDiarizationInput
+    OUTPUT_TYPE = TitleSummary
+
+    async def _push(self, data: AudioDiarizationInput):
+        diarization = await self._diarize(data)
+
+        # now reapply speaker to topics (if any)
+        # topics is a list[BaseModel] with an attribute words
+        # words is a list[BaseModel] with text, start and speaker attribute
+
+        # mutate in place
+        for topic in data.topics:
+            for word in topic.transcript.words:
+                for d in diarization:
+                    if d["start"] <= word.start <= d["end"]:
+                        word.speaker = d["speaker"]
+
+        # emit them
+        for topic in data.topics:
+            await self.emit(topic)
+
+    async def _diarize(self, data: AudioDiarizationInput):
+        raise NotImplementedError
diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py
new file mode 100644
index 00000000..b71dbcc9
--- /dev/null
+++ b/server/reflector/processors/audio_diarization_modal.py
@@ -0,0 +1,36 @@
+import httpx
+from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
+from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor
+from reflector.processors.types import AudioDiarizationInput, TitleSummary
+from reflector.settings import settings
+
+
+class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor):
+    INPUT_TYPE = AudioDiarizationInput
+    OUTPUT_TYPE = TitleSummary
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.diarization_url = settings.DIARIZATION_URL + "/diarize"
+        self.headers = {
+            "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
+        }
+
+    async def _diarize(self, data: AudioDiarizationInput):
+        # Gather diarization data
+        params = {
+            "audio_file_url": data.audio_url,
+            "timestamp": 0,
+        }
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                self.diarization_url,
+                headers=self.headers,
+                params=params,
+                timeout=None,
+            )
+            response.raise_for_status()
+            return response.json()["text"]
+
+
+AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)
diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py
index f223a52d..fc1f0b5e 100644
--- a/server/reflector/processors/audio_transcript_auto.py
+++ b/server/reflector/processors/audio_transcript_auto.py
@@ -1,8 +1,6 @@
 import importlib
 
 from reflector.processors.audio_transcript import AudioTranscriptProcessor
-from reflector.processors.base import Pipeline, Processor
-from reflector.processors.types import AudioFile
 from reflector.settings import settings
 
 
@@ -14,7 +12,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
         cls._registry[name] = kclass
 
     @classmethod
-    def get_instance(cls, name):
+    def get_instance(cls, name: str | None = None, **kwargs):
+        if name is None:
+            name = settings.TRANSCRIPT_BACKEND
         if name not in cls._registry:
             module_name = f"reflector.processors.audio_transcript_{name}"
             importlib.import_module(module_name)
@@ -30,30 +30,4 @@
                 config_name = key[len(settings_prefix) :].lower()
                 config[config_name] = value
 
-        return cls._registry[name](**config)
-
-    def __init__(self, **kwargs):
-        self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
-        super().__init__(**kwargs)
-
-    def set_pipeline(self, pipeline: Pipeline):
-        super().set_pipeline(pipeline)
-        self.processor.set_pipeline(pipeline)
-
-    def connect(self, processor: Processor):
-        self.processor.connect(processor)
-
-    def disconnect(self, processor: Processor):
-        self.processor.disconnect(processor)
-
-    def on(self, callback):
-        self.processor.on(callback)
-
-    def off(self, callback):
-        self.processor.off(callback)
-
-    async def _push(self, data: AudioFile):
-        return await self.processor._push(data)
-
-    async def _flush(self):
-        return await self.processor._flush()
+        return cls._registry[name](**config | kwargs)
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index 3ec21491..b67f84b9 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -385,5 +385,5 @@ class TranslationLanguages(BaseModel):
 
 
 class AudioDiarizationInput(BaseModel):
-    audio_filename: Path
+    audio_url: str
     topics: list[TitleSummary]
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index d7cc2c33..021d509f 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -89,6 +89,10 @@ class Settings(BaseSettings):
     # LLM Modal configuration
     LLM_MODAL_API_KEY: str | None = None
 
+    # Diarization
+    DIARIZATION_BACKEND: str = "modal"
+    DIARIZATION_URL: str | None = None
+
     # Sentry
     SENTRY_DSN: str | None = None
 
@@ -121,5 +125,11 @@ class Settings(BaseSettings):
     REDIS_HOST: str = "localhost"
     REDIS_PORT: int = 6379
 
+    # Secret key
+    SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
+
+    # Current hosting/domain
+    BASE_URL: str = "http://localhost:1250"
+
 
 settings = Settings()
diff --git a/server/reflector/tools/start_post_main_live_pipeline.py b/server/reflector/tools/start_post_main_live_pipeline.py
new file mode 100644
index 00000000..859f03a4
--- /dev/null
+++ b/server/reflector/tools/start_post_main_live_pipeline.py
@@ -0,0 +1,14 @@
+import argparse
+
+from reflector.app import celery_app  # noqa
+from reflector.pipelines.main_live_pipeline import task_pipeline_main_post
+
+parser = argparse.ArgumentParser()
+parser.add_argument("transcript_id", type=str)
+parser.add_argument("--delay", action="store_true")
+args = parser.parse_args()
+
+if args.delay:
+    task_pipeline_main_post.delay(args.transcript_id)
+else:
+    task_pipeline_main_post(args.transcript_id)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 31cbe28e..f83bc6de 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -1,4 +1,4 @@
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import Annotated, Optional
 
 import reflector.auth as auth
@@ -9,8 +9,10 @@ from fastapi import (
     Request,
     WebSocket,
     WebSocketDisconnect,
+    status,
 )
 from fastapi_pagination import Page, paginate
+from jose import jwt
 from pydantic import BaseModel, Field
 from reflector.db.transcripts import (
     AudioWaveform,
@@ -27,6 +29,18 @@ from .rtc_offer import RtcOffer, rtc_offer_base
 
 router = APIRouter()
 
+ALGORITHM = "HS256"
+DOWNLOAD_EXPIRE_MINUTES = 60
+
+
+def create_access_token(data: dict, expires_delta: timedelta):
+    to_encode = data.copy()
+    expire = datetime.utcnow() + expires_delta
+    to_encode.update({"exp": expire})
+    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
+    return encoded_jwt
+
+
 # ==============================================================
 # Transcripts list
 # ==============================================================
@@ -198,8 +212,21 @@ async def transcript_get_audio_mp3(
     request: Request,
     transcript_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+    token: str | None = None,
 ):
     user_id = user["sub"] if user else None
+    if not user_id and token:
+        unauthorized_exception = HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid or expired token",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+        try:
+            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
+            user_id: str = payload.get("sub")
+        except jwt.JWTError:
+            raise unauthorized_exception
+
     transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")
diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py
index 3714a64d..e1000364 100644
--- a/server/reflector/worker/app.py
+++ b/server/reflector/worker/app.py
@@ -4,6 +4,7 @@ from reflector.settings import settings
 app = Celery(__name__)
 app.conf.broker_url = settings.CELERY_BROKER_URL
 app.conf.result_backend = settings.CELERY_RESULT_BACKEND
+app.conf.broker_connection_retry_on_startup = True
 app.autodiscover_tasks(
     [
         "reflector.pipelines.main_live_pipeline",
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index e2bfee32..5a9a404b 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -102,6 +102,7 @@ async def test_transcript_rtc_and_websocket(
         print("Test websocket: DISCONNECTED")
 
     websocket_task = asyncio.get_event_loop().create_task(websocket_task())
+    print("Test websocket: TASK CREATED", websocket_task)
 
     # create stream client
     import argparse
@@ -243,6 +244,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
         print("Test websocket: DISCONNECTED")
 
     websocket_task = asyncio.get_event_loop().create_task(websocket_task())
+    print("Test websocket: TASK CREATED", websocket_task)
 
     # create stream client
     import argparse