diff --git a/server/migrations/versions/f819277e5169_audio_location.py b/server/migrations/versions/f819277e5169_audio_location.py
new file mode 100644
index 00000000..576b02bd
--- /dev/null
+++ b/server/migrations/versions/f819277e5169_audio_location.py
@@ -0,0 +1,43 @@
+"""audio_location
+
+Revision ID: f819277e5169
+Revises: 4814901632bc
+Create Date: 2023-11-16 10:29:09.351664
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "f819277e5169"
+down_revision: Union[str, None] = "4814901632bc"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "transcript",
+        sa.Column(
+            "audio_location", sa.String(), server_default="local", nullable=False
+        ),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("transcript", "audio_location")
+    # ### end Alembic commands ###
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index c0e8984b..44a6d56b 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
 from reflector.db import database, metadata
 from reflector.processors.types import Word as ProcessorWord
 from reflector.settings import settings
+from reflector.storage import Storage
 
 transcripts = sqlalchemy.Table(
     "transcript",
@@ -27,20 +28,33 @@ transcripts = sqlalchemy.Table(
     sqlalchemy.Column("events", sqlalchemy.JSON),
     sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
     sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column(
+        "audio_location",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="local",
+    ),
     # with user attached, optional
     sqlalchemy.Column("user_id", sqlalchemy.String),
 )
 
 
-def generate_uuid4():
+def generate_uuid4() -> str:
     return str(uuid4())
 
 
-def generate_transcript_name():
+def generate_transcript_name() -> str:
     now = datetime.utcnow()
     return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
 
 
+def get_storage() -> Storage:
+    return Storage.get_instance(
+        name=settings.TRANSCRIPT_STORAGE_BACKEND,
+        settings_prefix="TRANSCRIPT_STORAGE_",
+    )
+
+
 class AudioWaveform(BaseModel):
     data: list[float]
 
@@ -133,6 +147,10 @@ class Transcript(BaseModel):
     def data_path(self):
         return Path(settings.DATA_DIR) / self.id
 
+    @property
+    def audio_wav_filename(self):
+        return self.data_path / "audio.wav"
+
     @property
     def audio_mp3_filename(self):
         return self.data_path / "audio.mp3"
@@ -157,6 +175,40 @@
 
         return AudioWaveform(data=data)
 
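+    # the audio can live either on the local filesystem (while the pipeline
+    # is still processing) or on the external storage backend once uploaded;
+    # audio_location tracks which of the two applies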
+    async def get_audio_url(self) -> str:
+        if self.audio_location == "local":
+            return self._generate_local_audio_link()
+        elif self.audio_location == "storage":
+            return await self._generate_storage_audio_link()
+        raise Exception(f"Unknown audio location {self.audio_location}")
+
+    async def _generate_storage_audio_link(self) -> str:
+        return await get_storage().get_file_url(self.storage_audio_path)
+
+    def _generate_local_audio_link(self) -> str:
+        # we need to create an url to be used for diarization
+        # we can't use the audio_mp3_filename because it's not accessible
+        # from the diarization processor
+        from datetime import timedelta
+
+        from reflector.app import app
+        from reflector.views.transcripts import create_access_token
+
+        path = app.url_path_for(
+            "transcript_get_audio_mp3",
+            transcript_id=self.id,
+        )
+        url = f"{settings.BASE_URL}{path}"
+        if self.user_id:
+            # we pass token only if the user_id is set
+            # otherwise, the audio is public
+            token = create_access_token(
+                {"sub": self.user_id},
+                expires_delta=timedelta(minutes=15),
+            )
+            url += f"?token={token}"
+        return url
+
 
 class TranscriptController:
     async def get_all(
@@ -292,15 +344,18 @@
         """
         Move mp3 file to storage
        """
-        from reflector.storage import Storage
-        storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE)
-        await storage.put_file(
+        # store the audio on external storage
+        await get_storage().put_file(
             transcript.storage_audio_path,
-            self.audio_mp3_filename.read_bytes(),
+            transcript.audio_mp3_filename.read_bytes(),
         )
+
+        # indicate on the transcript that the audio is now on storage
         await self.update(transcript, {"audio_location": "storage"})
 
+        # unlink the local file
+        transcript.audio_mp3_filename.unlink(missing_ok=True)
+
 
 transcripts_controller = TranscriptController()
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index e2f305c4..83b57949 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -14,12 +14,9 @@ It is directly linked to our data model.
 import asyncio
 import functools
 from contextlib import asynccontextmanager
-from datetime import timedelta
-from pathlib import Path
 
-from celery import shared_task
+from celery import chord, group, shared_task
 from pydantic import BaseModel
 
-from reflector.app import app
 from reflector.db.transcripts import (
     Transcript,
     TranscriptDuration,
@@ -56,7 +53,7 @@ from reflector.processors.types import (
 from reflector.processors.types import Transcript as TranscriptProcessorType
 from reflector.settings import settings
 from reflector.ws_manager import WebsocketManager, get_ws_manager
-from structlog import Logger
+from structlog import BoundLogger as Logger
 
 
 def asynctask(f):
@@ -97,13 +94,17 @@ def get_transcript(func):
     Decorator to fetch the transcript from the database from the first argument
     """
 
-    async def wrapper(self, **kwargs):
+    async def wrapper(**kwargs):
         transcript_id = kwargs.pop("transcript_id")
         transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
         if not transcript:
             raise Exception("Transcript {transcript_id} not found")
         tlogger = logger.bind(transcript_id=transcript.id)
-        return await func(self, transcript=transcript, logger=tlogger, **kwargs)
+        try:
+            return await func(transcript=transcript, logger=tlogger, **kwargs)
+        except Exception as exc:
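+            # log with the transcript id attached, then re-raise so the
+            # calling Celery task is marked as failed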
+            tlogger.error("Pipeline error", exc_info=exc)
+            raise
 
     return wrapper
@@ -162,7 +163,7 @@
                 "flush": "processing",
                 "error": "error",
             }
-        elif isinstance(self, PipelineMainDiarization):
+        elif isinstance(self, PipelineMainFinalSummaries):
             status_mapping = {
                 "push": "processing",
                 "flush": "processing",
@@ -170,7 +171,8 @@
                 "error": "error",
                 "ended": "ended",
             }
         else:
-            raise Exception(f"Runner {self.__class__} is missing status mapping")
+            # intermediate pipelines don't update the status
+            return
 
         # mutate to model status
         status = status_mapping.get(status)
@@ -308,9 +310,10 @@ class PipelineMainBase(PipelineRunner):
 
 
 class PipelineMainLive(PipelineMainBase):
-    audio_filename: Path | None = None
-    source_language: str = "en"
-    target_language: str = "en"
+    """
+    Main pipeline for live streaming, attached to the RTC connection.
+    Any long post-processing should be done in the post pipeline.
+    """
 
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
@@ -320,7 +323,7 @@ class PipelineMainLive(PipelineMainBase):
 
         processors = [
             AudioFileWriterProcessor(
-                path=transcript.audio_mp3_filename,
+                path=transcript.audio_wav_filename,
                 on_duration=self.on_duration,
             ),
             AudioChunkerProcessor(),
@@ -329,15 +332,11 @@
             TranscriptLinerProcessor(),
             TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
             TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
-            BroadcastProcessor(
-                processors=[
-                    TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
-                    AudioWaveformProcessor.as_threaded(
-                        audio_path=transcript.audio_mp3_filename,
-                        waveform_path=transcript.audio_waveform_filename,
-                        on_waveform=self.on_waveform,
-                    ),
-                ]
+            # XXX move as a task
+            AudioWaveformProcessor.as_threaded(
+                audio_path=transcript.audio_mp3_filename,
+                waveform_path=transcript.audio_waveform_filename,
+                on_waveform=self.on_waveform,
             ),
         ]
         pipeline = Pipeline(*processors)
@@ -374,28 +373,16 @@
         # first processor diarization processor
         # XXX translation is lost when converting our data model to the processor model
         transcript = await self.get_transcript()
+
+        # diarization works only if the file is uploaded to external storage
+        if transcript.audio_location == "local":
+            pipeline.logger.info("Audio is local, skipping diarization")
+            return
+
         topics = self.get_transcript_topics(transcript)
-
-        # we need to create an url to be used for diarization
-        # we can't use the audio_mp3_filename because it's not accessible
-        # from the diarization processor
-        from reflector.views.transcripts import create_access_token
-
-        path = app.url_path_for(
-            "transcript_get_audio_mp3",
-            transcript_id=transcript.id,
-        )
-        url = f"{settings.BASE_URL}{path}"
-        if transcript.user_id:
-            # we pass token only if the user_id is set
-            # otherwise, the audio is public
-            token = create_access_token(
-                {"sub": transcript.user_id},
-                expires_delta=timedelta(minutes=15),
-            )
-            url += f"?token={token}"
+        audio_url = await transcript.get_audio_url()
         audio_diarization_input = AudioDiarizationInput(
-            audio_url=url,
+            audio_url=audio_url,
             topics=topics,
         )
@@ -409,14 +396,60 @@
         return pipeline
 
 
-class PipelineMainSummaries(PipelineMainBase):
+class PipelineMainFromTopics(PipelineMainBase):
+    """
+    Base class for generating a pipeline from the topics
+    """
+
+    def get_processors(self) -> list:
+        raise NotImplementedError
+
+    async def create(self) -> Pipeline:
+        self.prepare()
+        processors = self.get_processors()
+        pipeline = Pipeline(*processors)
+        pipeline.options = self
+
+        # get transcript
+        transcript = await self.get_transcript()
+        pipeline.logger.bind(transcript_id=transcript.id)
+        pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
+
+        # push topics
+        topics = self.get_transcript_topics(transcript)
+        for topic in topics:
+            self.push(topic)
+
+        self.flush()
+
+        return pipeline
+
+
+class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
+    """
+    Generate the title and short summary from the topics
+    """
+
+    def get_processors(self) -> list:
+        return [
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
+                    TranscriptFinalShortSummaryProcessor.as_threaded(
+                        callback=self.on_short_summary
+                    ),
+                ]
+            )
+        ]
+
+
+class PipelineMainFinalSummaries(PipelineMainFromTopics):
     """
     Generate summaries from the topics
     """
 
-    async def create(self) -> Pipeline:
-        self.prepare()
-        pipeline = Pipeline(
+    def get_processors(self) -> list:
+        return [
             BroadcastProcessor(
                 processors=[
                     TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -427,32 +460,7 @@ ...
                     ),
                 ]
             ),
-        )
-        pipeline.options = self
-
-        # get transcript
-        transcript = await self.get_transcript()
-        pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info("Summaries pipeline created")
-
-        # push topics
-        topics = await self.get_transcript_topics(transcript)
-        for topic in topics:
-            self.push(topic)
-
-        self.flush()
-
-        return pipeline
-
-
-@shared_task
-def task_pipeline_main_post(transcript_id: str):
-    logger.info(
-        "Starting main post pipeline",
-        transcript_id=transcript_id,
-    )
-    runner = PipelineMainDiarization(transcript_id=transcript_id)
-    runner.start_sync()
+        ]
 
 
 @get_transcript
@@ -470,24 +478,26 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
 
     import av
 
-    input_container = av.open(wav_filename)
-    output_container = av.open(mp3_filename, "w")
-    input_audio_stream = input_container.streams.audio[0]
-    output_audio_stream = output_container.add_stream("mp3")
-    output_audio_stream.codec_context.set_parameters(
-        input_audio_stream.codec_context.parameters
-    )
-    for packet in input_container.demux(input_audio_stream):
-        for frame in packet.decode():
-            output_container.mux(frame)
-    input_container.close()
-    output_container.close()
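+    # decode the recorded wav and re-encode it frame by frame into the mp3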
+    with av.open(wav_filename.as_posix()) as in_container:
+        in_stream = in_container.streams.audio[0]
+        with av.open(mp3_filename.as_posix(), "w") as out_container:
+            out_stream = out_container.add_stream("mp3")
+            for frame in in_container.decode(in_stream):
+                for packet in out_stream.encode(frame):
+                    out_container.mux(packet)
+            # flush the encoder so any buffered frames are written out
+            for packet in out_stream.encode():
+                out_container.mux(packet)
+
+    # Delete the wav file
+    transcript.audio_wav_filename.unlink(missing_ok=True)
 
     logger.info("Convert to mp3 done")
 
 
 @get_transcript
 async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+    if not settings.TRANSCRIPT_STORAGE_BACKEND:
+        logger.info("No storage backend configured, skipping mp3 upload")
+        return
+
     logger.info("Starting upload mp3")
 
     # If the audio mp3 is not available, just skip
@@ -497,27 +507,32 @@
         return
 
     # Upload to external storage and delete the file
-    await transcripts_controller.move_to_storage(transcript)
-    await transcripts_controller.unlink_mp3(transcript)
+    await transcripts_controller.move_mp3_to_storage(transcript)
 
     logger.info("Upload mp3 done")
 
 
 @get_transcript
-@asynctask
 async def pipeline_diarization(transcript: Transcript, logger: Logger):
     logger.info("Starting diarization")
     runner = PipelineMainDiarization(transcript_id=transcript.id)
-    await runner.start()
+    await runner.run()
     logger.info("Diarization done")
 
 
 @get_transcript
-@asynctask
+async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
+    logger.info("Starting title and short summary")
+    runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Title and short summary done")
+
+
+@get_transcript
 async def pipeline_summaries(transcript: Transcript, logger: Logger):
     logger.info("Starting summaries")
-    runner = PipelineMainSummaries(transcript_id=transcript.id)
-    await runner.start()
+    runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
+    await runner.run()
     logger.info("Summaries done")
@@ -528,29 +543,35 @@
 
 @shared_task
 @asynctask
-async def task_pipeline_convert_to_mp3(transcript_id: str):
-    await pipeline_convert_to_mp3(transcript_id)
+async def task_pipeline_convert_to_mp3(*, transcript_id: str):
+    await pipeline_convert_to_mp3(transcript_id=transcript_id)
 
 
 @shared_task
 @asynctask
-async def task_pipeline_upload_mp3(transcript_id: str):
-    await pipeline_upload_mp3(transcript_id)
+async def task_pipeline_upload_mp3(*, transcript_id: str):
+    await pipeline_upload_mp3(transcript_id=transcript_id)
 
 
 @shared_task
 @asynctask
-async def task_pipeline_diarization(transcript_id: str):
-    await pipeline_diarization(transcript_id)
+async def task_pipeline_diarization(*, transcript_id: str):
+    await pipeline_diarization(transcript_id=transcript_id)
 
 
 @shared_task
 @asynctask
-async def task_pipeline_summaries(transcript_id: str):
-    await pipeline_summaries(transcript_id)
+async def task_pipeline_title_and_short_summary(*, transcript_id: str):
+    await pipeline_title_and_short_summary(transcript_id=transcript_id)
 
 
-def pipeline_post(transcript_id: str):
+@shared_task
+@asynctask
+async def task_pipeline_final_summaries(*, transcript_id: str):
+    await pipeline_summaries(transcript_id=transcript_id)
+
+
+def pipeline_post(*, transcript_id: str):
     """
     Run the post pipeline
     """
@@ -559,6 +580,15 @@
         | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
         | task_pipeline_diarization.si(transcript_id=transcript_id)
     )
-    chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id)
-    chain = chain_mp3_and_diarize | chain_summary
+    chain_title_preview = task_pipeline_title_and_short_summary.si(
+        transcript_id=transcript_id
+    )
+    chain_final_summaries = task_pipeline_final_summaries.si(
+        transcript_id=transcript_id
+    )
+
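+    # run the mp3 conversion/upload/diarization chain and the title/short
+    # summary task in parallel, then generate the final summaries once both
+    # have finished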
+    chain = chord(
+        group(chain_mp3_and_diarize, chain_title_preview),
+        chain_final_summaries,
+    )
     chain.delay()
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index a1e137a7..4105d51f 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -119,8 +119,7 @@ class PipelineRunner(BaseModel):
             self._logger.exception("Runner error")
             await self._set_status("error")
             self._ev_done.set()
-            if self.on_ended:
-                await self.on_ended()
+            raise
 
     async def cmd_push(self, data):
         if self._is_first_push:
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 65412310..2c68c4e5 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -54,7 +54,7 @@ class Settings(BaseSettings):
     TRANSCRIPT_MODAL_API_KEY: str | None = None
 
     # Audio transcription storage
-    TRANSCRIPT_STORAGE_BACKEND: str = "aws"
+    TRANSCRIPT_STORAGE_BACKEND: str | None = None
 
     # Storage configuration for AWS
     TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
@@ -62,9 +62,6 @@
     TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
 
-    # Transcript MP3 storage
-    TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
-
     # LLM
     # available backend: openai, modal, oobabooga
     LLM_BACKEND: str = "oobabooga"