From aecc3a0c3bc7b4c6daf9474f8ef95a21726c22fc Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 15 Nov 2023 21:24:21 +0100
Subject: [PATCH] server: first attempts to split post pipeline as single
 celery tasks

---
 server/reflector/db/transcripts.py            |  19 ++
 .../reflector/pipelines/main_live_pipeline.py | 239 +++++++++++++++---
 server/reflector/storage/base.py              |  11 +-
 server/reflector/storage/storage_aws.py       |  20 +-
 4 files changed, 241 insertions(+), 48 deletions(-)

diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index f0dbc277..c0e8984b 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -106,6 +106,7 @@ class Transcript(BaseModel):
     events: list[TranscriptEvent] = []
     source_language: str = "en"
     target_language: str = "en"
+    audio_location: str = "local"
 
     def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
         ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -140,6 +141,10 @@ class Transcript(BaseModel):
     def audio_waveform_filename(self):
         return self.data_path / "audio.json"
 
+    @property
+    def storage_audio_path(self):
+        return f"{self.id}/audio.mp3"
+
     @property
     def audio_waveform(self):
         try:
@@ -283,5 +288,19 @@ class TranscriptController:
             transcript.upsert_topic(topic)
         await self.update(transcript, {"topics": transcript.topics_dump()})
 
+    async def move_mp3_to_storage(self, transcript: Transcript):
+        """
+        Move the local mp3 file to the external storage backend
+        """
+        from reflector.storage import Storage
+
+        storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE)
+        await storage.put_file(
+            transcript.storage_audio_path,
+            transcript.audio_mp3_filename.read_bytes(),
+        )
+
+        await self.update(transcript, {"audio_location": "storage"})
+
 
 transcripts_controller = TranscriptController()
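Note: a sketch of how the new audio_location field and the controller helper
above are meant to work together (names as in this patch; the async context
around the call is assumed):

    # Somewhere in an async context, after audio.mp3 exists on local disk:
    from reflector.db.transcripts import transcripts_controller

    async def archive_audio(transcript_id: str):
        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
        # uploads to "{transcript.id}/audio.mp3" on the configured backend,
        # then flips audio_location from "local" to "storage"
        await transcripts_controller.move_mp3_to_storage(transcript)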
""" import asyncio +import functools from contextlib import asynccontextmanager from datetime import timedelta from pathlib import Path @@ -55,6 +56,22 @@ from reflector.processors.types import ( from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.settings import settings from reflector.ws_manager import WebsocketManager, get_ws_manager +from structlog import Logger + + +def asynctask(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + coro = f(*args, **kwargs) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + return loop.run_until_complete(coro) + return asyncio.run(coro) + + return wrapper def broadcast_to_sockets(func): @@ -75,6 +92,22 @@ def broadcast_to_sockets(func): return wrapper +def get_transcript(func): + """ + Decorator to fetch the transcript from the database from the first argument + """ + + async def wrapper(self, **kwargs): + transcript_id = kwargs.pop("transcript_id") + transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id) + if not transcript: + raise Exception("Transcript {transcript_id} not found") + tlogger = logger.bind(transcript_id=transcript.id) + return await func(self, transcript=transcript, logger=tlogger, **kwargs) + + return wrapper + + class StrValue(BaseModel): value: str @@ -99,6 +132,19 @@ class PipelineMainBase(PipelineRunner): raise Exception("Transcript not found") return result + def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]: + return [ + TitleSummaryWithIdProcessorType( + id=topic.id, + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + duration=topic.duration, + transcript=TranscriptProcessorType(words=topic.words), + ) + for topic in transcript.topics + ] + @asynccontextmanager async def transaction(self): async with self._lock: @@ -299,10 +345,7 @@ class PipelineMainLive(PipelineMainBase): pipeline.set_pref("audio:source_language", transcript.source_language) pipeline.set_pref("audio:target_language", transcript.target_language) pipeline.logger.bind(transcript_id=transcript.id) - pipeline.logger.info( - "Pipeline main live created", - transcript_id=self.transcript_id, - ) + pipeline.logger.info("Pipeline main live created") return pipeline @@ -310,55 +353,28 @@ class PipelineMainLive(PipelineMainBase): # when the pipeline ends, connect to the post pipeline logger.info("Pipeline main live ended", transcript_id=self.transcript_id) logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id) - task_pipeline_main_post.delay(transcript_id=self.transcript_id) + pipeline_post(transcript_id=self.transcript_id) class PipelineMainDiarization(PipelineMainBase): """ - Diarization is a long time process, so we do it in a separate pipeline - When done, adjust the short and final summary + Diarize the audio and update topics """ async def create(self) -> Pipeline: # create a context for the whole rtc transaction # add a customised logger to the context self.prepare() - processors = [] - if settings.DIARIZATION_ENABLED: - processors += [ - AudioDiarizationAutoProcessor(callback=self.on_topic), - ] - - processors += [ - BroadcastProcessor( - processors=[ - TranscriptFinalLongSummaryProcessor.as_threaded( - callback=self.on_long_summary - ), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=self.on_short_summary - ), - ] - ), - ] - pipeline = Pipeline(*processors) + pipeline = Pipeline( + 
+            AudioDiarizationAutoProcessor(callback=self.on_topic),
+        )
         pipeline.options = self
 
         # now let's start the pipeline by pushing information to the
         # first processor diarization processor
         # XXX translation is lost when converting our data model to the processor model
         transcript = await self.get_transcript()
-        topics = [
-            TitleSummaryWithIdProcessorType(
-                id=topic.id,
-                title=topic.title,
-                summary=topic.summary,
-                timestamp=topic.timestamp,
-                duration=topic.duration,
-                transcript=TranscriptProcessorType(words=topic.words),
-            )
-            for topic in transcript.topics
-        ]
+        topics = self.get_transcript_topics(transcript)
 
         # we need to create an url to be used for diarization
         # we can't use the audio_mp3_filename because it's not accessible
@@ -386,15 +402,49 @@ class PipelineMainDiarization(PipelineMainBase):
         # as tempting to use pipeline.push, prefer to use the runner
         # to let the start just do one job.
         pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main post created", transcript_id=self.transcript_id
-        )
+        pipeline.logger.info("Diarization pipeline created")
 
         self.push(audio_diarization_input)
         self.flush()
 
         return pipeline
 
 
+class PipelineMainSummaries(PipelineMainBase):
+    """
+    Generate summaries from the topics
+    """
+
+    async def create(self) -> Pipeline:
+        self.prepare()
+        pipeline = Pipeline(
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalLongSummaryProcessor.as_threaded(
+                        callback=self.on_long_summary
+                    ),
+                    TranscriptFinalShortSummaryProcessor.as_threaded(
+                        callback=self.on_short_summary
+                    ),
+                ]
+            ),
+        )
+        pipeline.options = self
+
+        # get transcript
+        transcript = await self.get_transcript()
+        pipeline.logger.bind(transcript_id=transcript.id)
+        pipeline.logger.info("Summaries pipeline created")
+
+        # push topics (get_transcript_topics is synchronous)
+        topics = self.get_transcript_topics(transcript)
+        for topic in topics:
+            self.push(topic)
+
+        self.flush()
+
+        return pipeline
+
+
 @shared_task
 def task_pipeline_main_post(transcript_id: str):
@@ -403,3 +453,112 @@ def task_pipeline_main_post(transcript_id: str):
     logger.info("Task pipeline main post", transcript_id=transcript_id)
     runner = PipelineMainDiarization(transcript_id=transcript_id)
     runner.start_sync()
+
+
+@get_transcript
+async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
+    logger.info("Starting convert to mp3")
+
+    # If the audio wav is not available, just skip
+    wav_filename = transcript.audio_wav_filename
+    if not wav_filename.exists():
+        logger.warning("Wav file not found, may be already converted")
+        return
+
+    # Convert to mp3
+    mp3_filename = transcript.audio_mp3_filename
+
+    # lazy import: av is only needed in the worker
+    import av
+
+    input_container = av.open(str(wav_filename))
+    output_container = av.open(str(mp3_filename), "w")
+    input_audio_stream = input_container.streams.audio[0]
+    output_audio_stream = output_container.add_stream(
+        "mp3", rate=input_audio_stream.rate
+    )
+    for packet in input_container.demux(input_audio_stream):
+        for frame in packet.decode():
+            # frames must be encoded into packets before muxing
+            for out_packet in output_audio_stream.encode(frame):
+                output_container.mux(out_packet)
+    # flush the encoder
+    for out_packet in output_audio_stream.encode():
+        output_container.mux(out_packet)
+    input_container.close()
+    output_container.close()
+
+    logger.info("Convert to mp3 done")
+
+
+@get_transcript
+async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+    logger.info("Starting upload mp3")
+
+    # If the audio mp3 is not available, just skip
+    mp3_filename = transcript.audio_mp3_filename
+    if not mp3_filename.exists():
+        logger.warning("Mp3 file not found, may be already uploaded")
+        return
+
+    # Upload to external storage and delete the local file
+    await transcripts_controller.move_mp3_to_storage(transcript)
+    mp3_filename.unlink()
+
+    logger.info("Upload mp3 done")
+
+
+@get_transcript
+async def pipeline_diarization(transcript: Transcript, logger: Logger):
+    logger.info("Starting diarization")
+    runner = PipelineMainDiarization(transcript_id=transcript.id)
+    await runner.start()
+    logger.info("Diarization done")
+
+
+@get_transcript
+async def pipeline_summaries(transcript: Transcript, logger: Logger):
+    logger.info("Starting summaries")
+    runner = PipelineMainSummaries(transcript_id=transcript.id)
+    await runner.start()
+    logger.info("Summaries done")
+
+
+# ===================================================================
+# Celery tasks that can be called from the API
+# ===================================================================
+
+
+@shared_task
+@asynctask
+async def task_pipeline_convert_to_mp3(transcript_id: str):
+    await pipeline_convert_to_mp3(transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_upload_mp3(transcript_id: str):
+    await pipeline_upload_mp3(transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_diarization(transcript_id: str):
+    await pipeline_diarization(transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_summaries(transcript_id: str):
+    await pipeline_summaries(transcript_id)
+
+
+def pipeline_post(transcript_id: str):
+    """
+    Run the post pipeline
+    """
+    chain_mp3_and_diarize = (
+        task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_diarization.si(transcript_id=transcript_id)
+    )
+    chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id)
+    chain = chain_mp3_and_diarize | chain_summary
+    chain.delay()
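Note: pipeline_post builds the chain from immutable signatures: .si() keeps a
task from receiving the previous task's return value, so each step only ever
sees transcript_id. A minimal standalone illustration (the Celery app and the
step task here are hypothetical, not part of this patch):

    from celery import Celery

    app = Celery(__name__)

    @app.task
    def step(transcript_id: str) -> None:
        print(f"processing {transcript_id}")

    # .s() would forward step's return value to the next task in the chain;
    # .si() ignores it, which is what pipeline_post relies on.
    sig = step.si(transcript_id="t1") | step.si(transcript_id="t1")
    sig.delay()  # needs a configured broker to actually execute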
diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py
index 5cdafdbf..7c44ff4d 100644
--- a/server/reflector/storage/base.py
+++ b/server/reflector/storage/base.py
@@ -1,6 +1,7 @@
+import importlib
+
 from pydantic import BaseModel
 from reflector.settings import settings
-import importlib
 
 
 class FileResult(BaseModel):
@@ -17,14 +18,14 @@ class Storage:
         cls._registry[name] = kclass
 
     @classmethod
-    def get_instance(cls, name, settings_prefix=""):
+    def get_instance(cls, name: str, settings_prefix: str = "", folder: str = ""):
         if name not in cls._registry:
             module_name = f"reflector.storage.storage_{name}"
             importlib.import_module(module_name)
 
         # gather specific configuration for the processor
         # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
-        config = {}
+        config = {"folder": folder}
         name_upper = name.upper()
         config_prefix = f"{settings_prefix}{name_upper}_"
         for key, value in settings:
@@ -34,6 +35,10 @@ class Storage:
 
         return cls._registry[name](**config)
 
+    def __init__(self, folder: str = ""):
+        self.folder = folder
+        super().__init__()
+
     async def put_file(self, filename: str, data: bytes) -> FileResult:
         return await self._put_file(filename, data)
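Note: with the new folder argument, get_instance merges the explicit folder
into the keyword arguments built from settings, so backends receive it
alongside their backend-specific configuration. A hedged usage sketch (the
settings prefix and folder values are illustrative only):

    from reflector.storage.base import Storage

    # builds AwsStorage(**config) where config includes folder="transcripts"
    # plus whatever TRANSCRIPT_BACKEND_AWS_* settings resolve to
    storage = Storage.get_instance(
        "aws", settings_prefix="TRANSCRIPT_BACKEND_", folder="transcripts"
    )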
diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py
index 09a9c383..5ab02903 100644
--- a/server/reflector/storage/storage_aws.py
+++ b/server/reflector/storage/storage_aws.py
@@ -1,6 +1,6 @@
 import aioboto3
-from reflector.storage.base import Storage, FileResult
 from reflector.logger import logger
+from reflector.storage.base import FileResult, Storage
 
 
 class AwsStorage(Storage):
@@ -22,9 +22,14 @@ class AwsStorage(Storage):
         super().__init__()
 
         self.aws_bucket_name = aws_bucket_name
-        self.aws_folder = ""
+        folder = ""
         if "/" in aws_bucket_name:
-            self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
+            self.aws_bucket_name, folder = aws_bucket_name.split("/", 1)
+        if folder:
+            if not self.folder:
+                self.folder = folder
+            else:
+                self.folder = f"{self.folder}/{folder}"
         self.session = aioboto3.Session(
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
@@ -34,7 +39,7 @@ class AwsStorage(Storage):
 
     async def _put_file(self, filename: str, data: bytes) -> FileResult:
         bucket = self.aws_bucket_name
-        folder = self.aws_folder
+        folder = self.folder
         logger.info(f"Uploading {filename} to S3 {bucket}/{folder}")
         s3filename = f"{folder}/{filename}" if folder else filename
         async with self.session.client("s3") as client:
@@ -44,6 +49,11 @@ class AwsStorage(Storage):
             await client.put_object(
                 Bucket=bucket,
                 Key=s3filename,
                 Body=data,
             )
 
+    async def get_file_url(self, filename: str) -> FileResult:
+        bucket = self.aws_bucket_name
+        folder = self.folder
+        s3filename = f"{folder}/{filename}" if folder else filename
+        async with self.session.client("s3") as client:
             presigned_url = await client.generate_presigned_url(
                 "get_object",
                 Params={"Bucket": bucket, "Key": s3filename},
@@ -57,7 +67,7 @@ class AwsStorage(Storage):
 
     async def _delete_file(self, filename: str):
         bucket = self.aws_bucket_name
-        folder = self.aws_folder
+        folder = self.folder
         logger.info(f"Deleting {filename} from S3 {bucket}/{folder}")
         s3filename = f"{folder}/{filename}" if folder else filename
         async with self.session.client("s3") as client:
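Note: get_file_url leans on the standard presigned-URL API that boto3 and
aioboto3 expose on S3 clients, as used above. A self-contained sketch outside
the class (bucket, key and expiry are placeholder values):

    import asyncio

    import aioboto3

    async def main():
        session = aioboto3.Session()
        async with session.client("s3") as client:
            url = await client.generate_presigned_url(
                "get_object",
                Params={"Bucket": "my-bucket", "Key": "abc/audio.mp3"},
                ExpiresIn=3600,
            )
            print(url)

    asyncio.run(main())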