From aecc3a0c3bc7b4c6daf9474f8ef95a21726c22fc Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 15 Nov 2023 21:24:21 +0100
Subject: [PATCH 1/7] server: first attempt at splitting the post pipeline
 into individual celery tasks

---
 server/reflector/db/transcripts.py            |  19 ++
 .../reflector/pipelines/main_live_pipeline.py | 239 +++++++++++++++---
 server/reflector/storage/base.py              |  11 +-
 server/reflector/storage/storage_aws.py       |  20 +-
 4 files changed, 241 insertions(+), 48 deletions(-)

diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index f0dbc277..c0e8984b 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -106,6 +106,7 @@ class Transcript(BaseModel):
     events: list[TranscriptEvent] = []
     source_language: str = "en"
     target_language: str = "en"
+    audio_location: str = "local"
 
     def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
         ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -140,6 +141,10 @@ class Transcript(BaseModel):
     def audio_waveform_filename(self):
         return self.data_path / "audio.json"
 
+    @property
+    def storage_audio_path(self):
+        return f"{self.id}/audio.mp3"
+
     @property
     def audio_waveform(self):
         try:
@@ -283,5 +288,19 @@ class TranscriptController:
             transcript.upsert_topic(topic)
         await self.update(transcript, {"topics": transcript.topics_dump()})
 
+    async def move_mp3_to_storage(self, transcript: Transcript):
+        """
+        Move mp3 file to storage
+        """
+        from reflector.storage import Storage
+
+        storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE)
+        await storage.put_file(
+            transcript.storage_audio_path,
+            self.audio_mp3_filename.read_bytes(),
+        )
+
+        await self.update(transcript, {"audio_location": "storage"})
+
 
 transcripts_controller = TranscriptController()
 
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 3a9d1868..e2f305c4 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -12,6 +12,7 @@ It is directly linked to our data model.
""" import asyncio +import functools from contextlib import asynccontextmanager from datetime import timedelta from pathlib import Path @@ -55,6 +56,22 @@ from reflector.processors.types import ( from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.settings import settings from reflector.ws_manager import WebsocketManager, get_ws_manager +from structlog import Logger + + +def asynctask(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + coro = f(*args, **kwargs) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + return loop.run_until_complete(coro) + return asyncio.run(coro) + + return wrapper def broadcast_to_sockets(func): @@ -75,6 +92,22 @@ def broadcast_to_sockets(func): return wrapper +def get_transcript(func): + """ + Decorator to fetch the transcript from the database from the first argument + """ + + async def wrapper(self, **kwargs): + transcript_id = kwargs.pop("transcript_id") + transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id) + if not transcript: + raise Exception("Transcript {transcript_id} not found") + tlogger = logger.bind(transcript_id=transcript.id) + return await func(self, transcript=transcript, logger=tlogger, **kwargs) + + return wrapper + + class StrValue(BaseModel): value: str @@ -99,6 +132,19 @@ class PipelineMainBase(PipelineRunner): raise Exception("Transcript not found") return result + def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]: + return [ + TitleSummaryWithIdProcessorType( + id=topic.id, + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + duration=topic.duration, + transcript=TranscriptProcessorType(words=topic.words), + ) + for topic in transcript.topics + ] + @asynccontextmanager async def transaction(self): async with self._lock: @@ -299,10 +345,7 @@ class PipelineMainLive(PipelineMainBase): pipeline.set_pref("audio:source_language", transcript.source_language) pipeline.set_pref("audio:target_language", transcript.target_language) pipeline.logger.bind(transcript_id=transcript.id) - pipeline.logger.info( - "Pipeline main live created", - transcript_id=self.transcript_id, - ) + pipeline.logger.info("Pipeline main live created") return pipeline @@ -310,55 +353,28 @@ class PipelineMainLive(PipelineMainBase): # when the pipeline ends, connect to the post pipeline logger.info("Pipeline main live ended", transcript_id=self.transcript_id) logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id) - task_pipeline_main_post.delay(transcript_id=self.transcript_id) + pipeline_post(transcript_id=self.transcript_id) class PipelineMainDiarization(PipelineMainBase): """ - Diarization is a long time process, so we do it in a separate pipeline - When done, adjust the short and final summary + Diarize the audio and update topics """ async def create(self) -> Pipeline: # create a context for the whole rtc transaction # add a customised logger to the context self.prepare() - processors = [] - if settings.DIARIZATION_ENABLED: - processors += [ - AudioDiarizationAutoProcessor(callback=self.on_topic), - ] - - processors += [ - BroadcastProcessor( - processors=[ - TranscriptFinalLongSummaryProcessor.as_threaded( - callback=self.on_long_summary - ), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=self.on_short_summary - ), - ] - ), - ] - pipeline = Pipeline(*processors) + pipeline = Pipeline( + 
AudioDiarizationAutoProcessor(callback=self.on_topic), + ) pipeline.options = self # now let's start the pipeline by pushing information to the # first processor diarization processor # XXX translation is lost when converting our data model to the processor model transcript = await self.get_transcript() - topics = [ - TitleSummaryWithIdProcessorType( - id=topic.id, - title=topic.title, - summary=topic.summary, - timestamp=topic.timestamp, - duration=topic.duration, - transcript=TranscriptProcessorType(words=topic.words), - ) - for topic in transcript.topics - ] + topics = self.get_transcript_topics(transcript) # we need to create an url to be used for diarization # we can't use the audio_mp3_filename because it's not accessible @@ -386,15 +402,49 @@ class PipelineMainDiarization(PipelineMainBase): # as tempting to use pipeline.push, prefer to use the runner # to let the start just do one job. pipeline.logger.bind(transcript_id=transcript.id) - pipeline.logger.info( - "Pipeline main post created", transcript_id=self.transcript_id - ) + pipeline.logger.info("Diarization pipeline created") self.push(audio_diarization_input) self.flush() return pipeline +class PipelineMainSummaries(PipelineMainBase): + """ + Generate summaries from the topics + """ + + async def create(self) -> Pipeline: + self.prepare() + pipeline = Pipeline( + BroadcastProcessor( + processors=[ + TranscriptFinalLongSummaryProcessor.as_threaded( + callback=self.on_long_summary + ), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=self.on_short_summary + ), + ] + ), + ) + pipeline.options = self + + # get transcript + transcript = await self.get_transcript() + pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info("Summaries pipeline created") + + # push topics + topics = await self.get_transcript_topics(transcript) + for topic in topics: + self.push(topic) + + self.flush() + + return pipeline + + @shared_task def task_pipeline_main_post(transcript_id: str): logger.info( @@ -403,3 +453,112 @@ def task_pipeline_main_post(transcript_id: str): ) runner = PipelineMainDiarization(transcript_id=transcript_id) runner.start_sync() + + +@get_transcript +async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger): + logger.info("Starting convert to mp3") + + # If the audio wav is not available, just skip + wav_filename = transcript.audio_wav_filename + if not wav_filename.exists(): + logger.warning("Wav file not found, may be already converted") + return + + # Convert to mp3 + mp3_filename = transcript.audio_mp3_filename + + import av + + input_container = av.open(wav_filename) + output_container = av.open(mp3_filename, "w") + input_audio_stream = input_container.streams.audio[0] + output_audio_stream = output_container.add_stream("mp3") + output_audio_stream.codec_context.set_parameters( + input_audio_stream.codec_context.parameters + ) + for packet in input_container.demux(input_audio_stream): + for frame in packet.decode(): + output_container.mux(frame) + input_container.close() + output_container.close() + + logger.info("Convert to mp3 done") + + +@get_transcript +async def pipeline_upload_mp3(transcript: Transcript, logger: Logger): + logger.info("Starting upload mp3") + + # If the audio mp3 is not available, just skip + mp3_filename = transcript.audio_mp3_filename + if not mp3_filename.exists(): + logger.warning("Mp3 file not found, may be already uploaded") + return + + # Upload to external storage and delete the file + await transcripts_controller.move_to_storage(transcript) + await 
transcripts_controller.unlink_mp3(transcript) + + logger.info("Upload mp3 done") + + +@get_transcript +@asynctask +async def pipeline_diarization(transcript: Transcript, logger: Logger): + logger.info("Starting diarization") + runner = PipelineMainDiarization(transcript_id=transcript.id) + await runner.start() + logger.info("Diarization done") + + +@get_transcript +@asynctask +async def pipeline_summaries(transcript: Transcript, logger: Logger): + logger.info("Starting summaries") + runner = PipelineMainSummaries(transcript_id=transcript.id) + await runner.start() + logger.info("Summaries done") + + +# =================================================================== +# Celery tasks that can be called from the API +# =================================================================== + + +@shared_task +@asynctask +async def task_pipeline_convert_to_mp3(transcript_id: str): + await pipeline_convert_to_mp3(transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_upload_mp3(transcript_id: str): + await pipeline_upload_mp3(transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_diarization(transcript_id: str): + await pipeline_diarization(transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_summaries(transcript_id: str): + await pipeline_summaries(transcript_id) + + +def pipeline_post(transcript_id: str): + """ + Run the post pipeline + """ + chain_mp3_and_diarize = ( + task_pipeline_convert_to_mp3.si(transcript_id=transcript_id) + | task_pipeline_upload_mp3.si(transcript_id=transcript_id) + | task_pipeline_diarization.si(transcript_id=transcript_id) + ) + chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id) + chain = chain_mp3_and_diarize | chain_summary + chain.delay() diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py index 5cdafdbf..7c44ff4d 100644 --- a/server/reflector/storage/base.py +++ b/server/reflector/storage/base.py @@ -1,6 +1,7 @@ +import importlib + from pydantic import BaseModel from reflector.settings import settings -import importlib class FileResult(BaseModel): @@ -17,14 +18,14 @@ class Storage: cls._registry[name] = kclass @classmethod - def get_instance(cls, name, settings_prefix=""): + def get_instance(cls, name: str, settings_prefix: str = "", folder: str = ""): if name not in cls._registry: module_name = f"reflector.storage.storage_{name}" importlib.import_module(module_name) # gather specific configuration for the processor # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy` - config = {} + config = {"folder": folder} name_upper = name.upper() config_prefix = f"{settings_prefix}{name_upper}_" for key, value in settings: @@ -34,6 +35,10 @@ class Storage: return cls._registry[name](**config) + def __init__(self): + self.folder = "" + super().__init__() + async def put_file(self, filename: str, data: bytes) -> FileResult: return await self._put_file(filename, data) diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py index 09a9c383..5ab02903 100644 --- a/server/reflector/storage/storage_aws.py +++ b/server/reflector/storage/storage_aws.py @@ -1,6 +1,6 @@ import aioboto3 -from reflector.storage.base import Storage, FileResult from reflector.logger import logger +from reflector.storage.base import FileResult, Storage class AwsStorage(Storage): @@ -22,9 +22,14 @@ class AwsStorage(Storage): super().__init__() self.aws_bucket_name = aws_bucket_name - self.aws_folder = "" + folder = "" if "/" in aws_bucket_name: - 
self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1) + self.aws_bucket_name, folder = aws_bucket_name.split("/", 1) + if folder: + if not self.folder: + self.folder = folder + else: + self.folder = f"{self.folder}/{folder}" self.session = aioboto3.Session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, @@ -34,7 +39,7 @@ class AwsStorage(Storage): async def _put_file(self, filename: str, data: bytes) -> FileResult: bucket = self.aws_bucket_name - folder = self.aws_folder + folder = self.folder logger.info(f"Uploading {filename} to S3 {bucket}/{folder}") s3filename = f"{folder}/{filename}" if folder else filename async with self.session.client("s3") as client: @@ -44,6 +49,11 @@ class AwsStorage(Storage): Body=data, ) + async def get_file_url(self, filename: str) -> FileResult: + bucket = self.aws_bucket_name + folder = self.folder + s3filename = f"{folder}/{filename}" if folder else filename + async with self.session.client("s3") as client: presigned_url = await client.generate_presigned_url( "get_object", Params={"Bucket": bucket, "Key": s3filename}, @@ -57,7 +67,7 @@ class AwsStorage(Storage): async def _delete_file(self, filename: str): bucket = self.aws_bucket_name - folder = self.aws_folder + folder = self.folder logger.info(f"Deleting {filename} from S3 {bucket}/{folder}") s3filename = f"{folder}/{filename}" if folder else filename async with self.session.client("s3") as client: From 88f443e2c25cd8bf9158b1c83bb8c76a3813d843 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 16 Nov 2023 14:32:18 +0100 Subject: [PATCH 2/7] server: revert change on storage folder --- server/reflector/storage/base.py | 14 ++++++++------ server/reflector/storage/storage_aws.py | 22 +++++++--------------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py index 7c44ff4d..a457ddf8 100644 --- a/server/reflector/storage/base.py +++ b/server/reflector/storage/base.py @@ -18,14 +18,14 @@ class Storage: cls._registry[name] = kclass @classmethod - def get_instance(cls, name: str, settings_prefix: str = "", folder: str = ""): + def get_instance(cls, name: str, settings_prefix: str = ""): if name not in cls._registry: module_name = f"reflector.storage.storage_{name}" importlib.import_module(module_name) # gather specific configuration for the processor # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy` - config = {"folder": folder} + config = {} name_upper = name.upper() config_prefix = f"{settings_prefix}{name_upper}_" for key, value in settings: @@ -35,10 +35,6 @@ class Storage: return cls._registry[name](**config) - def __init__(self): - self.folder = "" - super().__init__() - async def put_file(self, filename: str, data: bytes) -> FileResult: return await self._put_file(filename, data) @@ -50,3 +46,9 @@ class Storage: async def _delete_file(self, filename: str): raise NotImplementedError + + async def get_file_url(self, filename: str) -> str: + return await self._get_file_url(filename) + + async def _get_file_url(self, filename: str) -> str: + raise NotImplementedError diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py index 5ab02903..d2313293 100644 --- a/server/reflector/storage/storage_aws.py +++ b/server/reflector/storage/storage_aws.py @@ -22,14 +22,9 @@ class AwsStorage(Storage): super().__init__() self.aws_bucket_name = aws_bucket_name - folder = "" + self.aws_folder = "" if "/" in 
aws_bucket_name:
-            self.aws_bucket_name, folder = aws_bucket_name.split("/", 1)
-        if folder:
-            if not self.folder:
-                self.folder = folder
-            else:
-                self.folder = f"{self.folder}/{folder}"
+            self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
         self.session = aioboto3.Session(
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
@@ -39,7 +34,7 @@ class AwsStorage(Storage):
 
     async def _put_file(self, filename: str, data: bytes) -> FileResult:
         bucket = self.aws_bucket_name
-        folder = self.folder
+        folder = self.aws_folder
         logger.info(f"Uploading {filename} to S3 {bucket}/{folder}")
         s3filename = f"{folder}/{filename}" if folder else filename
         async with self.session.client("s3") as client:
@@ -49,9 +44,9 @@ class AwsStorage(Storage):
                 Body=data,
             )
 
-    async def get_file_url(self, filename: str) -> FileResult:
+    async def _get_file_url(self, filename: str) -> str:
         bucket = self.aws_bucket_name
-        folder = self.folder
+        folder = self.aws_folder
         s3filename = f"{folder}/{filename}" if folder else filename
         async with self.session.client("s3") as client:
             presigned_url = await client.generate_presigned_url(
@@ -60,14 +55,11 @@ class AwsStorage(Storage):
                 ExpiresIn=3600,
             )
 
-        return FileResult(
-            filename=filename,
-            url=presigned_url,
-        )
+        return presigned_url
 
     async def _delete_file(self, filename: str):
         bucket = self.aws_bucket_name
-        folder = self.folder
+        folder = self.aws_folder
         logger.info(f"Deleting {filename} from S3 {bucket}/{folder}")
         s3filename = f"{folder}/{filename}" if folder else filename
         async with self.session.client("s3") as client:

From 06b29d9bd4a1d1b7f6265756be57c5602888673a Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 16 Nov 2023 14:34:33 +0100
Subject: [PATCH 3/7] server: add audio_location and move to external storage
 if possible

---
 .../versions/f819277e5169_audio_location.py   |  43 ++++
 server/reflector/db/transcripts.py            |  67 ++++-
 .../reflector/pipelines/main_live_pipeline.py | 234 ++++++++++--------
 server/reflector/pipelines/runner.py          |   3 +-
 server/reflector/settings.py                  |   5 +-
 5 files changed, 238 insertions(+), 114 deletions(-)
 create mode 100644 server/migrations/versions/f819277e5169_audio_location.py

diff --git a/server/migrations/versions/f819277e5169_audio_location.py b/server/migrations/versions/f819277e5169_audio_location.py
new file mode 100644
index 00000000..576b02bd
--- /dev/null
+++ b/server/migrations/versions/f819277e5169_audio_location.py
@@ -0,0 +1,43 @@
+"""audio_location
+
+Revision ID: f819277e5169
+Revises: 4814901632bc
+Create Date: 2023-11-16 10:29:09.351664
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "f819277e5169"
+down_revision: Union[str, None] = "4814901632bc"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "transcript",
+        sa.Column(
+            "audio_location", sa.String(), server_default="local", nullable=False
+        ),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "transcript",
+        sa.Column(
+            "share_mode",
+            sa.VARCHAR(),
+            server_default=sa.text("'private'"),
+            nullable=False,
+        ),
+    )
+    # ### end Alembic commands ###
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index c0e8984b..44a6d56b 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
 from reflector.db import database, metadata
 from reflector.processors.types import Word as ProcessorWord
 from reflector.settings import settings
+from reflector.storage import Storage
 
 transcripts = sqlalchemy.Table(
     "transcript",
@@ -27,20 +28,33 @@ transcripts = sqlalchemy.Table(
     sqlalchemy.Column("events", sqlalchemy.JSON),
     sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
     sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column(
+        "audio_location",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="local",
+    ),
     # with user attached, optional
     sqlalchemy.Column("user_id", sqlalchemy.String),
 )
 
 
-def generate_uuid4():
+def generate_uuid4() -> str:
     return str(uuid4())
 
 
-def generate_transcript_name():
+def generate_transcript_name() -> str:
     now = datetime.utcnow()
     return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
 
 
+def get_storage() -> Storage:
+    return Storage.get_instance(
+        name=settings.TRANSCRIPT_STORAGE_BACKEND,
+        settings_prefix="TRANSCRIPT_STORAGE_",
+    )
+
+
 class AudioWaveform(BaseModel):
     data: list[float]
 
@@ -133,6 +147,10 @@ class Transcript(BaseModel):
     def data_path(self):
         return Path(settings.DATA_DIR) / self.id
 
+    @property
+    def audio_wav_filename(self):
+        return self.data_path / "audio.wav"
+
     @property
     def audio_mp3_filename(self):
         return self.data_path / "audio.mp3"
@@ -157,6 +175,40 @@ class Transcript(BaseModel):
 
         return AudioWaveform(data=data)
 
+    async def get_audio_url(self) -> str:
+        if self.audio_location == "local":
+            return self._generate_local_audio_link()
+        elif self.audio_location == "storage":
+            return await self._generate_storage_audio_link()
+        raise Exception(f"Unknown audio location {self.audio_location}")
+
+    async def _generate_storage_audio_link(self) -> str:
+        return await get_storage().get_file_url(self.storage_audio_path)
+
+    def _generate_local_audio_link(self) -> str:
+        # we need to create a URL to be used for diarization
+        # we can't use the audio_mp3_filename because it's not accessible
+        # from the diarization processor
+        from datetime import timedelta
+
+        from reflector.app import app
+        from reflector.views.transcripts import create_access_token
+
+        path = app.url_path_for(
+            "transcript_get_audio_mp3",
+            transcript_id=self.id,
+        )
+        url = f"{settings.BASE_URL}{path}"
+        if self.user_id:
+            # we pass token only if the user_id is set
+            # otherwise, the audio is public
+            token = create_access_token(
+                {"sub": self.user_id},
+                expires_delta=timedelta(minutes=15),
+            )
+            url += f"?token={token}"
+        return url
+
 
 class TranscriptController:
     async def get_all(
@@ -292,15 +344,18 @@ class TranscriptController:
         """
         Move mp3 file to storage
        """
-        from reflector.storage import Storage
 
-        storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE)
-        await storage.put_file(
+        # store the audio on external storage
+        await get_storage().put_file(
             transcript.storage_audio_path,
-            self.audio_mp3_filename.read_bytes(),
+            transcript.audio_mp3_filename.read_bytes(),
         )
 
+        # indicate on the transcript that the audio is now on storage
         await self.update(transcript, {"audio_location": 
"storage"}) + # unlink the local file + transcript.audio_mp3_filename.unlink(missing_ok=True) + transcripts_controller = TranscriptController() diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index e2f305c4..83b57949 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -14,12 +14,9 @@ It is directly linked to our data model. import asyncio import functools from contextlib import asynccontextmanager -from datetime import timedelta -from pathlib import Path -from celery import shared_task +from celery import chord, group, shared_task from pydantic import BaseModel -from reflector.app import app from reflector.db.transcripts import ( Transcript, TranscriptDuration, @@ -56,7 +53,7 @@ from reflector.processors.types import ( from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.settings import settings from reflector.ws_manager import WebsocketManager, get_ws_manager -from structlog import Logger +from structlog import BoundLogger as Logger def asynctask(f): @@ -97,13 +94,17 @@ def get_transcript(func): Decorator to fetch the transcript from the database from the first argument """ - async def wrapper(self, **kwargs): + async def wrapper(**kwargs): transcript_id = kwargs.pop("transcript_id") transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id) if not transcript: raise Exception("Transcript {transcript_id} not found") tlogger = logger.bind(transcript_id=transcript.id) - return await func(self, transcript=transcript, logger=tlogger, **kwargs) + try: + return await func(transcript=transcript, logger=tlogger, **kwargs) + except Exception as exc: + tlogger.error("Pipeline error", exc_info=exc) + raise return wrapper @@ -162,7 +163,7 @@ class PipelineMainBase(PipelineRunner): "flush": "processing", "error": "error", } - elif isinstance(self, PipelineMainDiarization): + elif isinstance(self, PipelineMainFinalSummaries): status_mapping = { "push": "processing", "flush": "processing", @@ -170,7 +171,8 @@ class PipelineMainBase(PipelineRunner): "ended": "ended", } else: - raise Exception(f"Runner {self.__class__} is missing status mapping") + # intermediate pipeline don't update status + return # mutate to model status status = status_mapping.get(status) @@ -308,9 +310,10 @@ class PipelineMainBase(PipelineRunner): class PipelineMainLive(PipelineMainBase): - audio_filename: Path | None = None - source_language: str = "en" - target_language: str = "en" + """ + Main pipeline for live streaming, attach to RTC connection + Any long post process should be done in the post pipeline + """ async def create(self) -> Pipeline: # create a context for the whole rtc transaction @@ -320,7 +323,7 @@ class PipelineMainLive(PipelineMainBase): processors = [ AudioFileWriterProcessor( - path=transcript.audio_mp3_filename, + path=transcript.audio_wav_filename, on_duration=self.on_duration, ), AudioChunkerProcessor(), @@ -329,15 +332,11 @@ class PipelineMainLive(PipelineMainBase): TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), - BroadcastProcessor( - processors=[ - TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), - AudioWaveformProcessor.as_threaded( - audio_path=transcript.audio_mp3_filename, - waveform_path=transcript.audio_waveform_filename, - on_waveform=self.on_waveform, - ), - ] + # XXX 
move as a task + AudioWaveformProcessor.as_threaded( + audio_path=transcript.audio_mp3_filename, + waveform_path=transcript.audio_waveform_filename, + on_waveform=self.on_waveform, ), ] pipeline = Pipeline(*processors) @@ -374,28 +373,16 @@ class PipelineMainDiarization(PipelineMainBase): # first processor diarization processor # XXX translation is lost when converting our data model to the processor model transcript = await self.get_transcript() + + # diarization works only if the file is uploaded to an external storage + if transcript.audio_location == "local": + pipeline.logger.info("Audio is local, skipping diarization") + return + topics = self.get_transcript_topics(transcript) - - # we need to create an url to be used for diarization - # we can't use the audio_mp3_filename because it's not accessible - # from the diarization processor - from reflector.views.transcripts import create_access_token - - path = app.url_path_for( - "transcript_get_audio_mp3", - transcript_id=transcript.id, - ) - url = f"{settings.BASE_URL}{path}" - if transcript.user_id: - # we pass token only if the user_id is set - # otherwise, the audio is public - token = create_access_token( - {"sub": transcript.user_id}, - expires_delta=timedelta(minutes=15), - ) - url += f"?token={token}" + audio_url = await transcript.get_audio_url() audio_diarization_input = AudioDiarizationInput( - audio_url=url, + audio_url=audio_url, topics=topics, ) @@ -409,14 +396,60 @@ class PipelineMainDiarization(PipelineMainBase): return pipeline -class PipelineMainSummaries(PipelineMainBase): +class PipelineMainFromTopics(PipelineMainBase): + """ + Pseudo class for generating a pipeline from topics + """ + + def get_processors(self) -> list: + raise NotImplementedError + + async def create(self) -> Pipeline: + self.prepare() + processors = self.get_processors() + pipeline = Pipeline(*processors) + pipeline.options = self + + # get transcript + transcript = await self.get_transcript() + pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info(f"{self.__class__.__name__} pipeline created") + + # push topics + topics = self.get_transcript_topics(transcript) + for topic in topics: + self.push(topic) + + self.flush() + + return pipeline + + +class PipelineMainTitleAndShortSummary(PipelineMainFromTopics): + """ + Generate title from the topics + """ + + def get_processors(self) -> list: + return [ + BroadcastProcessor( + processors=[ + TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=self.on_short_summary + ), + ] + ) + ] + + +class PipelineMainFinalSummaries(PipelineMainFromTopics): """ Generate summaries from the topics """ - async def create(self) -> Pipeline: - self.prepare() - pipeline = Pipeline( + def get_processors(self) -> list: + return [ BroadcastProcessor( processors=[ TranscriptFinalLongSummaryProcessor.as_threaded( @@ -427,32 +460,7 @@ class PipelineMainSummaries(PipelineMainBase): ), ] ), - ) - pipeline.options = self - - # get transcript - transcript = await self.get_transcript() - pipeline.logger.bind(transcript_id=transcript.id) - pipeline.logger.info("Summaries pipeline created") - - # push topics - topics = await self.get_transcript_topics(transcript) - for topic in topics: - self.push(topic) - - self.flush() - - return pipeline - - -@shared_task -def task_pipeline_main_post(transcript_id: str): - logger.info( - "Starting main post pipeline", - transcript_id=transcript_id, - ) - runner = 
PipelineMainDiarization(transcript_id=transcript_id)
-    runner.start_sync()
+        ]
 
 
 @get_transcript
@@ -470,24 +478,28 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
 
     import av
 
-    input_container = av.open(wav_filename)
-    output_container = av.open(mp3_filename, "w")
-    input_audio_stream = input_container.streams.audio[0]
-    output_audio_stream = output_container.add_stream("mp3")
-    output_audio_stream.codec_context.set_parameters(
-        input_audio_stream.codec_context.parameters
-    )
-    for packet in input_container.demux(input_audio_stream):
-        for frame in packet.decode():
-            output_container.mux(frame)
-    input_container.close()
-    output_container.close()
+    with av.open(wav_filename.as_posix()) as in_container:
+        in_stream = in_container.streams.audio[0]
+        with av.open(mp3_filename.as_posix(), "w") as out_container:
+            out_stream = out_container.add_stream("mp3")
+            for frame in in_container.decode(in_stream):
+                for packet in out_stream.encode(frame):
+                    out_container.mux(packet)
+            for packet in out_stream.encode(None):  # flush the encoder
+                out_container.mux(packet)
+
+    # Delete the wav file
+    transcript.audio_wav_filename.unlink(missing_ok=True)
 
     logger.info("Convert to mp3 done")
 
 
 @get_transcript
 async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+    if not settings.TRANSCRIPT_STORAGE_BACKEND:
+        logger.info("No storage backend configured, skipping mp3 upload")
+        return
+
     logger.info("Starting upload mp3")
 
     # If the audio mp3 is not available, just skip
@@ -497,27 +507,32 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
         return
 
     # Upload to external storage and delete the file
-    await transcripts_controller.move_to_storage(transcript)
-    await transcripts_controller.unlink_mp3(transcript)
+    await transcripts_controller.move_mp3_to_storage(transcript)
 
     logger.info("Upload mp3 done")
 
 
 @get_transcript
-@asynctask
 async def pipeline_diarization(transcript: Transcript, logger: Logger):
     logger.info("Starting diarization")
     runner = PipelineMainDiarization(transcript_id=transcript.id)
-    await runner.start()
+    await runner.run()
     logger.info("Diarization done")
 
 
 @get_transcript
-@asynctask
+async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
+    logger.info("Starting title and short summary")
+    runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Title and short summary done")
+
+
+@get_transcript
 async def pipeline_summaries(transcript: Transcript, logger: Logger):
     logger.info("Starting summaries")
-    runner = PipelineMainSummaries(transcript_id=transcript.id)
-    await runner.start()
+    runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
+    await runner.run()
     logger.info("Summaries done")
 
@@ -528,29 +543,35 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
 
 @shared_task
 @asynctask
-async def task_pipeline_convert_to_mp3(transcript_id: str):
-    await pipeline_convert_to_mp3(transcript_id)
+async def task_pipeline_convert_to_mp3(*, transcript_id: str):
+    await pipeline_convert_to_mp3(transcript_id=transcript_id)
 
 
 @shared_task
 @asynctask
-async def task_pipeline_upload_mp3(transcript_id: str):
-    await pipeline_upload_mp3(transcript_id)
+async def task_pipeline_upload_mp3(*, transcript_id: str):
+    await pipeline_upload_mp3(transcript_id=transcript_id)
 
 
 @shared_task
 @asynctask
-async def task_pipeline_diarization(transcript_id: str):
-    await pipeline_diarization(transcript_id)
+async def task_pipeline_diarization(*, transcript_id: str):
+    await pipeline_diarization(transcript_id=transcript_id)
 
 
 @shared_task
 @asynctask
-async def task_pipeline_summaries(transcript_id: str):
-    await pipeline_summaries(transcript_id)
+async def task_pipeline_title_and_short_summary(*, transcript_id: str):
+    await pipeline_title_and_short_summary(transcript_id=transcript_id)
 
 
-def pipeline_post(transcript_id: str):
+@shared_task
+@asynctask
+async def task_pipeline_final_summaries(*, transcript_id: str):
+    await pipeline_summaries(transcript_id=transcript_id)
+
+
+def pipeline_post(*, transcript_id: str):
     """
     Run the post pipeline
     """
@@ -559,6 +580,15 @@ def pipeline_post(transcript_id: str):
         | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
         | task_pipeline_diarization.si(transcript_id=transcript_id)
     )
-    chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id)
-    chain = chain_mp3_and_diarize | chain_summary
+    chain_title_preview = task_pipeline_title_and_short_summary.si(
+        transcript_id=transcript_id
+    )
+    chain_final_summaries = task_pipeline_final_summaries.si(
+        transcript_id=transcript_id
+    )
+
+    chain = chord(
+        group(chain_mp3_and_diarize, chain_title_preview),
+        chain_final_summaries,
+    )
     chain.delay()
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index a1e137a7..4105d51f 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -119,8 +119,7 @@ class PipelineRunner(BaseModel):
             self._logger.exception("Runner error")
             await self._set_status("error")
             self._ev_done.set()
-            if self.on_ended:
-                await self.on_ended()
+            raise
 
     async def cmd_push(self, data):
         if self._is_first_push:
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 65412310..2c68c4e5 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -54,7 +54,7 @@ class Settings(BaseSettings):
     TRANSCRIPT_MODAL_API_KEY: str | None = None
 
     # Audio transcription storage
-    TRANSCRIPT_STORAGE_BACKEND: str = "aws"
+    TRANSCRIPT_STORAGE_BACKEND: str | None = None
 
     # Storage configuration for AWS
     TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
@@ -62,9 +62,6 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
 
-    # Transcript MP3 storage
-    TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
-
     # LLM
     # available backend: openai, modal, oobabooga
     LLM_BACKEND: str = "oobabooga"

From 5ffa931822ff5b7c5205f942ca2332f52b9ef892 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 16 Nov 2023 14:45:40 +0100
Subject: [PATCH 4/7] server: update backend test config (the rpc result
 backend does not work with chords)

---
 .../migrations/versions/f819277e5169_audio_location.py | 10 +---------
 server/tests/conftest.py                               |  8 +++++++-
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/server/migrations/versions/f819277e5169_audio_location.py b/server/migrations/versions/f819277e5169_audio_location.py
index 576b02bd..061abec4 100644
--- a/server/migrations/versions/f819277e5169_audio_location.py
+++ b/server/migrations/versions/f819277e5169_audio_location.py
@@ -31,13 +31,5 @@ def upgrade() -> None:
 
 
 def downgrade() -> None:
     # ### commands auto generated by Alembic - please adjust! 
### - op.add_column( - "transcript", - sa.Column( - "share_mode", - sa.VARCHAR(), - server_default=sa.text("'private'"), - nullable=False, - ), - ) + op.drop_column("transcript", "audio_location") # ### end Alembic commands ### diff --git a/server/tests/conftest.py b/server/tests/conftest.py index aafca9fd..aaf42884 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -133,4 +133,10 @@ def celery_enable_logging(): @pytest.fixture(scope="session") def celery_config(): - return {"broker_url": "memory://", "result_backend": "rpc"} + import tempfile + + with tempfile.NamedTemporaryFile() as fd: + yield { + "broker_url": "memory://", + "result_backend": "db+sqlite://" + fd.name, + } From 99b973f36f5fbbe1c193dd00dc2233fc82eddb4d Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 17 Nov 2023 14:27:53 +0100 Subject: [PATCH 5/7] server: fix tests --- server/reflector/pipelines/runner.py | 8 ++++++ server/tests/conftest.py | 36 +++++++++++++++++++++---- server/tests/test_transcripts_rtc_ws.py | 4 +++ 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index 4105d51f..708a4265 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -106,6 +106,14 @@ class PipelineRunner(BaseModel): if not self.pipeline: self.pipeline = await self.create() + if not self.pipeline: + # no pipeline created in create, just finish it then. + await self._set_status("ended") + self._ev_done.set() + if self.on_ended: + await self.on_ended() + return + # start the loop await self._set_status("started") while not self._ev_done.is_set(): diff --git a/server/tests/conftest.py b/server/tests/conftest.py index aaf42884..532ebff9 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,4 +1,5 @@ from unittest.mock import patch +from tempfile import NamedTemporaryFile import pytest @@ -7,7 +8,6 @@ import pytest @pytest.mark.asyncio async def setup_database(): from reflector.settings import settings - from tempfile import NamedTemporaryFile with NamedTemporaryFile() as f: settings.DATABASE_URL = f"sqlite:///{f.name}" @@ -103,6 +103,25 @@ async def dummy_llm(): yield +@pytest.fixture +async def dummy_storage(): + from reflector.storage.base import Storage + + class DummyStorage(Storage): + async def _put_file(self, *args, **kwargs): + pass + + async def _delete_file(self, *args, **kwargs): + pass + + async def _get_file_url(self, *args, **kwargs): + return "http://fake_server/audio.mp3" + + with patch("reflector.storage.base.Storage.get_instance") as mock_storage: + mock_storage.return_value = DummyStorage() + yield + + @pytest.fixture def nltk(): with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk: @@ -133,10 +152,17 @@ def celery_enable_logging(): @pytest.fixture(scope="session") def celery_config(): - import tempfile - - with tempfile.NamedTemporaryFile() as fd: + with NamedTemporaryFile() as f: yield { "broker_url": "memory://", - "result_backend": "db+sqlite://" + fd.name, + "result_backend": f"db+sqlite:///{f.name}", } + + +@pytest.fixture(scope="session") +def fake_mp3_upload(): + with patch( + "reflector.db.transcripts.TranscriptController.move_mp3_to_storage" + ) as mock_move: + mock_move.return_value = True + yield diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index b33b1db5..8502a0d9 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -66,6 +66,8 @@ 
async def test_transcript_rtc_and_websocket( dummy_transcript, dummy_processors, dummy_diarization, + dummy_storage, + fake_mp3_upload, ensure_casing, appserver, sentence_tokenize, @@ -220,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr( dummy_transcript, dummy_processors, dummy_diarization, + dummy_storage, + fake_mp3_upload, ensure_casing, appserver, sentence_tokenize, From 794d08c3a88d55a2ac6ea5faecd117697d838612 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 21 Nov 2023 14:46:16 +0100 Subject: [PATCH 6/7] server: redirect to storage url --- server/reflector/views/transcripts.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 6909b8ae..7496b26c 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -233,6 +233,12 @@ async def transcript_get_audio_mp3( if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") + if transcript.audio_location == "storage": + url = transcript.get_audio_url() + from fastapi.responses import RedirectResponse + + return RedirectResponse(url=url, status_code=status.HTTP_302_FOUND) + if not transcript.audio_mp3_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") From 0e5c0f66d91fa61f7bfee283a89f96f13dc57fef Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 21 Nov 2023 15:40:49 +0100 Subject: [PATCH 7/7] server: move waveform out of the live pipeline --- .../reflector/pipelines/main_live_pipeline.py | 46 +++++++++++++++---- server/reflector/views/transcripts.py | 22 +++++++-- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 83b57949..b182f421 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -332,12 +332,6 @@ class PipelineMainLive(PipelineMainBase): TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), - # XXX move as a task - AudioWaveformProcessor.as_threaded( - audio_path=transcript.audio_mp3_filename, - waveform_path=transcript.audio_waveform_filename, - on_waveform=self.on_waveform, - ), ] pipeline = Pipeline(*processors) pipeline.options = self @@ -406,12 +400,14 @@ class PipelineMainFromTopics(PipelineMainBase): async def create(self) -> Pipeline: self.prepare() + + # get transcript + self._transcript = transcript = await self.get_transcript() + + # create pipeline processors = self.get_processors() pipeline = Pipeline(*processors) pipeline.options = self - - # get transcript - transcript = await self.get_transcript() pipeline.logger.bind(transcript_id=transcript.id) pipeline.logger.info(f"{self.__class__.__name__} pipeline created") @@ -463,6 +459,29 @@ class PipelineMainFinalSummaries(PipelineMainFromTopics): ] +class PipelineMainWaveform(PipelineMainFromTopics): + """ + Generate waveform + """ + + def get_processors(self) -> list: + return [ + AudioWaveformProcessor.as_threaded( + audio_path=self._transcript.audio_wav_filename, + waveform_path=self._transcript.audio_waveform_filename, + on_waveform=self.on_waveform, + ), + ] + + +@get_transcript +async def pipeline_waveform(transcript: Transcript, logger: Logger): + logger.info("Starting waveform") + runner = PipelineMainWaveform(transcript_id=transcript.id) + await runner.run() + 
logger.info("Waveform done") + + @get_transcript async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger): logger.info("Starting convert to mp3") @@ -541,6 +560,12 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger): # =================================================================== +@shared_task +@asynctask +async def task_pipeline_waveform(*, transcript_id: str): + await pipeline_waveform(transcript_id=transcript_id) + + @shared_task @asynctask async def task_pipeline_convert_to_mp3(*, transcript_id: str): @@ -576,7 +601,8 @@ def pipeline_post(*, transcript_id: str): Run the post pipeline """ chain_mp3_and_diarize = ( - task_pipeline_convert_to_mp3.si(transcript_id=transcript_id) + task_pipeline_waveform.si(transcript_id=transcript_id) + | task_pipeline_convert_to_mp3.si(transcript_id=transcript_id) | task_pipeline_upload_mp3.si(transcript_id=transcript_id) | task_pipeline_diarization.si(transcript_id=transcript_id) ) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 7496b26c..125aa311 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,12 +1,14 @@ from datetime import datetime, timedelta from typing import Annotated, Optional +import httpx import reflector.auth as auth from fastapi import ( APIRouter, Depends, HTTPException, Request, + Response, WebSocket, WebSocketDisconnect, status, @@ -234,10 +236,22 @@ async def transcript_get_audio_mp3( raise HTTPException(status_code=404, detail="Transcript not found") if transcript.audio_location == "storage": - url = transcript.get_audio_url() - from fastapi.responses import RedirectResponse + # proxy S3 file, to prevent issue with CORS + url = await transcript.get_audio_url() + headers = {} - return RedirectResponse(url=url, status_code=status.HTTP_302_FOUND) + copy_headers = ["range", "accept-encoding"] + for header in copy_headers: + if header in request.headers: + headers[header] = request.headers[header] + + async with httpx.AsyncClient() as client: + resp = await client.request(request.method, url, headers=headers) + return Response( + content=resp.content, + status_code=resp.status_code, + headers=resp.headers, + ) if not transcript.audio_mp3_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") @@ -263,7 +277,7 @@ async def transcript_get_audio_waveform( if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - if not transcript.audio_mp3_filename.exists(): + if not transcript.audio_waveform_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") return transcript.audio_waveform