diff --git a/server/migrations/versions/f819277e5169_audio_location.py b/server/migrations/versions/f819277e5169_audio_location.py
new file mode 100644
index 00000000..061abec4
--- /dev/null
+++ b/server/migrations/versions/f819277e5169_audio_location.py
@@ -0,0 +1,35 @@
+"""audio_location
+
+Revision ID: f819277e5169
+Revises: 4814901632bc
+Create Date: 2023-11-16 10:29:09.351664
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "f819277e5169"
+down_revision: Union[str, None] = "4814901632bc"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "transcript",
+        sa.Column(
+            "audio_location", sa.String(), server_default="local", nullable=False
+        ),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("transcript", "audio_location")
+    # ### end Alembic commands ###
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index c563f587..0fba82ef 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
 from reflector.db import database, metadata
 from reflector.processors.types import Word as ProcessorWord
 from reflector.settings import settings
+from reflector.storage import Storage
 
 transcripts = sqlalchemy.Table(
     "transcript",
@@ -28,6 +29,12 @@ transcripts = sqlalchemy.Table(
     sqlalchemy.Column("events", sqlalchemy.JSON),
     sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
     sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column(
+        "audio_location",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="local",
+    ),
     # with user attached, optional
     sqlalchemy.Column("user_id", sqlalchemy.String),
     sqlalchemy.Column(
@@ -39,15 +46,22 @@ transcripts = sqlalchemy.Table(
 )
 
 
-def generate_uuid4():
+def generate_uuid4() -> str:
     return str(uuid4())
 
 
-def generate_transcript_name():
+def generate_transcript_name() -> str:
     now = datetime.utcnow()
     return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
 
 
+def get_storage() -> Storage:
+    return Storage.get_instance(
+        name=settings.TRANSCRIPT_STORAGE_BACKEND,
+        settings_prefix="TRANSCRIPT_STORAGE_",
+    )
+
+
 class AudioWaveform(BaseModel):
     data: list[float]
@@ -114,6 +128,7 @@ class Transcript(BaseModel):
     source_language: str = "en"
     target_language: str = "en"
     share_mode: Literal["private", "semi-private", "public"] = "private"
+    audio_location: str = "local"
 
     def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
         ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -140,6 +155,10 @@ class Transcript(BaseModel):
     def data_path(self):
         return Path(settings.DATA_DIR) / self.id
 
+    @property
+    def audio_wav_filename(self):
+        return self.data_path / "audio.wav"
+
     @property
     def audio_mp3_filename(self):
         return self.data_path / "audio.mp3"
@@ -148,6 +167,10 @@ class Transcript(BaseModel):
     def audio_waveform_filename(self):
         return self.data_path / "audio.json"
 
+    @property
+    def storage_audio_path(self):
+        return f"{self.id}/audio.mp3"
+
     @property
     def audio_waveform(self):
         try:
@@ -160,6 +183,40 @@ class Transcript(BaseModel):
 
         return AudioWaveform(data=data)
 
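+    # audio_location is a two-state marker: "local" while the mp3 still
+    # lives under DATA_DIR, "storage" once move_mp3_to_storage() has
+    # shipped it to the external backend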
+    async def get_audio_url(self) -> str:
+        if self.audio_location == "local":
+            return self._generate_local_audio_link()
+        elif self.audio_location == "storage":
+            return await self._generate_storage_audio_link()
+        raise Exception(f"Unknown audio location {self.audio_location}")
+
+    async def _generate_storage_audio_link(self) -> str:
+        return await get_storage().get_file_url(self.storage_audio_path)
+
+    def _generate_local_audio_link(self) -> str:
+        # we need to create a URL to be used for diarization
+        # we can't use the audio_mp3_filename because it's not accessible
+        # from the diarization processor
+        from datetime import timedelta
+
+        from reflector.app import app
+        from reflector.views.transcripts import create_access_token
+
+        path = app.url_path_for(
+            "transcript_get_audio_mp3",
+            transcript_id=self.id,
+        )
+        url = f"{settings.BASE_URL}{path}"
+        if self.user_id:
+            # we pass a token only if the user_id is set
+            # otherwise, the audio is public
+            token = create_access_token(
+                {"sub": self.user_id},
+                expires_delta=timedelta(minutes=15),
+            )
+            url += f"?token={token}"
+        return url
+
 
 class TranscriptController:
     async def get_all(
@@ -336,5 +393,22 @@ class TranscriptController:
             transcript.upsert_topic(topic)
         await self.update(transcript, {"topics": transcript.topics_dump()})
 
+    async def move_mp3_to_storage(self, transcript: Transcript):
+        """
+        Move the mp3 file to external storage
+        """
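+        # the order below matters: upload first, then flip audio_location,
+        # then unlink, so an interrupted move can leave a stale local copy
+        # but never a "storage" marker without the uploaded object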
""" import asyncio +import functools from contextlib import asynccontextmanager -from datetime import timedelta -from pathlib import Path -from celery import shared_task +from celery import chord, group, shared_task from pydantic import BaseModel -from reflector.app import app from reflector.db.transcripts import ( Transcript, TranscriptDuration, @@ -55,6 +53,22 @@ from reflector.processors.types import ( from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.settings import settings from reflector.ws_manager import WebsocketManager, get_ws_manager +from structlog import BoundLogger as Logger + + +def asynctask(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + coro = f(*args, **kwargs) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + return loop.run_until_complete(coro) + return asyncio.run(coro) + + return wrapper def broadcast_to_sockets(func): @@ -75,6 +89,26 @@ def broadcast_to_sockets(func): return wrapper +def get_transcript(func): + """ + Decorator to fetch the transcript from the database from the first argument + """ + + async def wrapper(**kwargs): + transcript_id = kwargs.pop("transcript_id") + transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id) + if not transcript: + raise Exception("Transcript {transcript_id} not found") + tlogger = logger.bind(transcript_id=transcript.id) + try: + return await func(transcript=transcript, logger=tlogger, **kwargs) + except Exception as exc: + tlogger.error("Pipeline error", exc_info=exc) + raise + + return wrapper + + class StrValue(BaseModel): value: str @@ -99,6 +133,19 @@ class PipelineMainBase(PipelineRunner): raise Exception("Transcript not found") return result + def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]: + return [ + TitleSummaryWithIdProcessorType( + id=topic.id, + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + duration=topic.duration, + transcript=TranscriptProcessorType(words=topic.words), + ) + for topic in transcript.topics + ] + @asynccontextmanager async def transaction(self): async with self._lock: @@ -116,7 +163,7 @@ class PipelineMainBase(PipelineRunner): "flush": "processing", "error": "error", } - elif isinstance(self, PipelineMainDiarization): + elif isinstance(self, PipelineMainFinalSummaries): status_mapping = { "push": "processing", "flush": "processing", @@ -124,7 +171,8 @@ class PipelineMainBase(PipelineRunner): "ended": "ended", } else: - raise Exception(f"Runner {self.__class__} is missing status mapping") + # intermediate pipeline don't update status + return # mutate to model status status = status_mapping.get(status) @@ -262,9 +310,10 @@ class PipelineMainBase(PipelineRunner): class PipelineMainLive(PipelineMainBase): - audio_filename: Path | None = None - source_language: str = "en" - target_language: str = "en" + """ + Main pipeline for live streaming, attach to RTC connection + Any long post process should be done in the post pipeline + """ async def create(self) -> Pipeline: # create a context for the whole rtc transaction @@ -274,7 +323,7 @@ class PipelineMainLive(PipelineMainBase): processors = [ AudioFileWriterProcessor( - path=transcript.audio_mp3_filename, + path=transcript.audio_wav_filename, on_duration=self.on_duration, ), AudioChunkerProcessor(), @@ -283,26 +332,13 @@ class PipelineMainLive(PipelineMainBase): TranscriptLinerProcessor(), 
+def asynctask(f):
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        coro = f(*args, **kwargs)
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+        if loop and loop.is_running():
+            return loop.run_until_complete(coro)
+        return asyncio.run(coro)
+
+    return wrapper
 
 
 def broadcast_to_sockets(func):
@@ -75,6 +89,26 @@ def broadcast_to_sockets(func):
     return wrapper
 
 
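+# the decorated functions are keyword-only: they are always invoked as
+# f(transcript_id=...), and receive the resolved Transcript plus a logger
+# bound to its id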
+def get_transcript(func):
+    """
+    Decorator to fetch the transcript from the database from the
+    transcript_id keyword argument
+    """
+
+    async def wrapper(**kwargs):
+        transcript_id = kwargs.pop("transcript_id")
+        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
+        if not transcript:
+            raise Exception(f"Transcript {transcript_id} not found")
+        tlogger = logger.bind(transcript_id=transcript.id)
+        try:
+            return await func(transcript=transcript, logger=tlogger, **kwargs)
+        except Exception as exc:
+            tlogger.error("Pipeline error", exc_info=exc)
+            raise
+
+    return wrapper
+
+
 class StrValue(BaseModel):
     value: str
 
@@ -99,6 +133,19 @@ class PipelineMainBase(PipelineRunner):
             raise Exception("Transcript not found")
         return result
 
+    def get_transcript_topics(
+        self, transcript: Transcript
+    ) -> list[TitleSummaryWithIdProcessorType]:
+        return [
+            TitleSummaryWithIdProcessorType(
+                id=topic.id,
+                title=topic.title,
+                summary=topic.summary,
+                timestamp=topic.timestamp,
+                duration=topic.duration,
+                transcript=TranscriptProcessorType(words=topic.words),
+            )
+            for topic in transcript.topics
+        ]
+
     @asynccontextmanager
     async def transaction(self):
         async with self._lock:
@@ -116,7 +163,7 @@ class PipelineMainBase(PipelineRunner):
                 "flush": "processing",
                 "error": "error",
             }
-        elif isinstance(self, PipelineMainDiarization):
+        elif isinstance(self, PipelineMainFinalSummaries):
             status_mapping = {
                 "push": "processing",
                 "flush": "processing",
@@ -124,7 +171,8 @@ class PipelineMainBase(PipelineRunner):
                 "ended": "ended",
             }
         else:
-            raise Exception(f"Runner {self.__class__} is missing status mapping")
+            # intermediate pipelines don't update the status
+            return
 
         # mutate to model status
         status = status_mapping.get(status)
@@ -262,9 +310,10 @@ class PipelineMainBase(PipelineRunner):
 
 
 class PipelineMainLive(PipelineMainBase):
-    audio_filename: Path | None = None
-    source_language: str = "en"
-    target_language: str = "en"
+    """
+    Main pipeline for live streaming, attached to the RTC connection.
+    Any long post-processing should be done in the post pipeline.
+    """
 
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
@@ -274,7 +323,7 @@ class PipelineMainLive(PipelineMainBase):
 
         processors = [
             AudioFileWriterProcessor(
-                path=transcript.audio_mp3_filename,
+                path=transcript.audio_wav_filename,
                 on_duration=self.on_duration,
             ),
             AudioChunkerProcessor(),
             ...
             TranscriptLinerProcessor(),
             TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
             TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
-            BroadcastProcessor(
-                processors=[
-                    TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
-                    AudioWaveformProcessor.as_threaded(
-                        audio_path=transcript.audio_mp3_filename,
-                        waveform_path=transcript.audio_waveform_filename,
-                        on_waveform=self.on_waveform,
-                    ),
-                ]
-            ),
         ]
         pipeline = Pipeline(*processors)
         pipeline.options = self
         pipeline.set_pref("audio:source_language", transcript.source_language)
         pipeline.set_pref("audio:target_language", transcript.target_language)
         pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main live created",
-            transcript_id=self.transcript_id,
-        )
+        pipeline.logger.info("Pipeline main live created")
         return pipeline
 
@@ -310,26 +346,106 @@ class PipelineMainLive(PipelineMainBase):
         # when the pipeline ends, connect to the post pipeline
         logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
         logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
-        task_pipeline_main_post.delay(transcript_id=self.transcript_id)
+        pipeline_post(transcript_id=self.transcript_id)
 
 
 class PipelineMainDiarization(PipelineMainBase):
     """
-    Diarization is a long time process, so we do it in a separate pipeline
-    When done, adjust the short and final summary
+    Diarize the audio and update topics
     """
 
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
         self.prepare()
-        processors = []
-        if settings.DIARIZATION_ENABLED:
-            processors += [
-                AudioDiarizationAutoProcessor(callback=self.on_topic),
-            ]
+        pipeline = Pipeline(
+            AudioDiarizationAutoProcessor(callback=self.on_topic),
+        )
+        pipeline.options = self
 
-        processors += [
+        # now let's start the pipeline by pushing information to the
+        # first diarization processor
+        # XXX translation is lost when converting our data model to the processor model
+        transcript = await self.get_transcript()
+
+        # diarization works only if the file is uploaded to external storage
+        if transcript.audio_location == "local":
+            pipeline.logger.info("Audio is local, skipping diarization")
+            return
+
+        topics = self.get_transcript_topics(transcript)
+        audio_url = await transcript.get_audio_url()
+        audio_diarization_input = AudioDiarizationInput(
+            audio_url=audio_url,
+            topics=topics,
+        )
+
+        # as tempting as it is to use pipeline.push, prefer the runner,
+        # so that start() does only one job
+        pipeline.logger.bind(transcript_id=transcript.id)
+        pipeline.logger.info("Diarization pipeline created")
+        self.push(audio_diarization_input)
+        self.flush()
+
+        return pipeline
+
+
+class PipelineMainFromTopics(PipelineMainBase):
+    """
+    Base class for pipelines that are fed from the transcript topics
+    """
+
+    def get_processors(self) -> list:
+        raise NotImplementedError
+
+    async def create(self) -> Pipeline:
+        self.prepare()
+
+        # get transcript
+        self._transcript = transcript = await self.get_transcript()
+
+        # create pipeline
+        processors = self.get_processors()
+        pipeline = Pipeline(*processors)
+        pipeline.options = self
+        pipeline.logger.bind(transcript_id=transcript.id)
+        pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
+
+        # push topics
+        topics = self.get_transcript_topics(transcript)
+        for topic in topics:
+            self.push(topic)
+
+        self.flush()
+
+        return pipeline
+
+
+class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
+    """
+    Generate the title and short summary from the topics
+    """
+
+    def get_processors(self) -> list:
+        return [
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
+                    TranscriptFinalShortSummaryProcessor.as_threaded(
+                        callback=self.on_short_summary
+                    ),
+                ]
+            )
+        ]
+
+
+class PipelineMainFinalSummaries(PipelineMainFromTopics):
+    """
+    Generate the final summaries from the topics
+    """
+
+    def get_processors(self) -> list:
+        return [
             BroadcastProcessor(
                 processors=[
                     TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -341,65 +457,164 @@ class PipelineMainDiarization(PipelineMainBase):
                 ]
             ),
         ]
-        pipeline = Pipeline(*processors)
-        pipeline.options = self
 
-        # now let's start the pipeline by pushing information to the
-        # first processor diarization processor
-        # XXX translation is lost when converting our data model to the processor model
-        transcript = await self.get_transcript()
-        topics = [
-            TitleSummaryWithIdProcessorType(
-                id=topic.id,
-                title=topic.title,
-                summary=topic.summary,
-                timestamp=topic.timestamp,
-                duration=topic.duration,
-                transcript=TranscriptProcessorType(words=topic.words),
-            )
-            for topic in transcript.topics
+
+class PipelineMainWaveform(PipelineMainFromTopics):
+    """
+    Generate the audio waveform
+    """
+
+    def get_processors(self) -> list:
+        return [
+            AudioWaveformProcessor.as_threaded(
+                audio_path=self._transcript.audio_wav_filename,
+                waveform_path=self._transcript.audio_waveform_filename,
+                on_waveform=self.on_waveform,
+            ),
         ]
-        # we need to create an url to be used for diarization
-        # we can't use the audio_mp3_filename because it's not accessible
-        # from the diarization processor
-        from reflector.views.transcripts import create_access_token
 
-        path = app.url_path_for(
-            "transcript_get_audio_mp3",
-            transcript_id=transcript.id,
-        )
-        url = f"{settings.BASE_URL}{path}"
-        if transcript.user_id:
-            # we pass token only if the user_id is set
-            # otherwise, the audio is public
-            token = create_access_token(
-                {"sub": transcript.user_id},
-                expires_delta=timedelta(minutes=15),
-            )
-            url += f"?token={token}"
 
-        audio_diarization_input = AudioDiarizationInput(
-            audio_url=url,
-            topics=topics,
-        )
+@get_transcript
+async def pipeline_waveform(transcript: Transcript, logger: Logger):
+    logger.info("Starting waveform")
+    runner = PipelineMainWaveform(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Waveform done")
 
-        # as tempting to use pipeline.push, prefer to use the runner
-        # to let the start just do one job.
-        pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main post created", transcript_id=self.transcript_id
-        )
-        self.push(audio_diarization_input)
-        self.flush()
 
-        return pipeline
+@get_transcript
+async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
+    logger.info("Starting convert to mp3")
+
+    # If the audio wav is not available, just skip
+    wav_filename = transcript.audio_wav_filename
+    if not wav_filename.exists():
+        logger.warning("Wav file not found, may be already converted")
+        return
+
+    # Convert to mp3
+    mp3_filename = transcript.audio_mp3_filename
+
+    import av
+
+    with av.open(wav_filename.as_posix()) as in_container:
+        in_stream = in_container.streams.audio[0]
+        with av.open(mp3_filename.as_posix(), "w") as out_container:
+            out_stream = out_container.add_stream("mp3")
+            for frame in in_container.decode(in_stream):
+                for packet in out_stream.encode(frame):
+                    out_container.mux(packet)
+            # flush the encoder so the last buffered frames are written
+            for packet in out_stream.encode():
+                out_container.mux(packet)
+
+    # Delete the wav file
+    transcript.audio_wav_filename.unlink(missing_ok=True)
+
+    logger.info("Convert to mp3 done")
+
+
+@get_transcript
+async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+    if not settings.TRANSCRIPT_STORAGE_BACKEND:
+        logger.info("No storage backend configured, skipping mp3 upload")
+        return
+
+    logger.info("Starting upload mp3")
+
+    # If the audio mp3 is not available, just skip
+    mp3_filename = transcript.audio_mp3_filename
+    if not mp3_filename.exists():
+        logger.warning("Mp3 file not found, may be already uploaded")
+        return
+
+    # Upload to external storage and delete the local file
+    await transcripts_controller.move_mp3_to_storage(transcript)
+
+    logger.info("Upload mp3 done")
+
+
+@get_transcript
+async def pipeline_diarization(transcript: Transcript, logger: Logger):
+    logger.info("Starting diarization")
+    runner = PipelineMainDiarization(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Diarization done")
+
+
+@get_transcript
+async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
+    logger.info("Starting title and short summary")
+    runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Title and short summary done")
+
+
+@get_transcript
+async def pipeline_summaries(transcript: Transcript, logger: Logger):
+    logger.info("Starting summaries")
+    runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Summaries done")
+
+
+# ===================================================================
+# Celery tasks that can be called from the API
+# ===================================================================
 
 
 @shared_task
-def task_pipeline_main_post(transcript_id: str):
-    logger.info(
-        "Starting main post pipeline",
-        transcript_id=transcript_id,
+@asynctask
+async def task_pipeline_waveform(*, transcript_id: str):
+    await pipeline_waveform(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_convert_to_mp3(*, transcript_id: str):
+    await pipeline_convert_to_mp3(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_upload_mp3(*, transcript_id: str):
+    await pipeline_upload_mp3(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_diarization(*, transcript_id: str):
+    await pipeline_diarization(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_title_and_short_summary(*, transcript_id: str):
+    await pipeline_title_and_short_summary(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_final_summaries(*, transcript_id: str):
+    await pipeline_summaries(transcript_id=transcript_id)
+
+
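+# topology of the post pipeline: one chain does the heavy audio work
+# (waveform -> mp3 -> upload -> diarization) while title/short-summary
+# runs in parallel; the chord fires the final summaries only once both
+# branches have finished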
+def pipeline_post(*, transcript_id: str):
+    """
+    Run the post pipeline
+    """
+    chain_mp3_and_diarize = (
+        task_pipeline_waveform.si(transcript_id=transcript_id)
+        | task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_diarization.si(transcript_id=transcript_id)
     )
-    runner = PipelineMainDiarization(transcript_id=transcript_id)
-    runner.start_sync()
+    chain_title_preview = task_pipeline_title_and_short_summary.si(
+        transcript_id=transcript_id
+    )
+    chain_final_summaries = task_pipeline_final_summaries.si(
+        transcript_id=transcript_id
+    )
+
+    chain = chord(
+        group(chain_mp3_and_diarize, chain_title_preview),
+        chain_final_summaries,
+    )
+    chain.delay()
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index a1e137a7..708a4265 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -106,6 +106,14 @@ class PipelineRunner(BaseModel):
         if not self.pipeline:
             self.pipeline = await self.create()
 
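+        # create() may legitimately return None (e.g. diarization is
+        # skipped while the audio is still local); treat that as an
+        # already-completed run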
+        if not self.pipeline:
+            # no pipeline was created in create(); just finish the run
+            await self._set_status("ended")
+            self._ev_done.set()
+            if self.on_ended:
+                await self.on_ended()
+            return
+
         # start the loop
         await self._set_status("started")
         while not self._ev_done.is_set():
@@ -119,8 +127,7 @@ class PipelineRunner(BaseModel):
             self._logger.exception("Runner error")
             await self._set_status("error")
             self._ev_done.set()
-            if self.on_ended:
-                await self.on_ended()
+            raise
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 65412310..2c68c4e5 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -54,7 +54,7 @@ class Settings(BaseSettings):
     TRANSCRIPT_MODAL_API_KEY: str | None = None
 
     # Audio transcription storage
-    TRANSCRIPT_STORAGE_BACKEND: str = "aws"
+    TRANSCRIPT_STORAGE_BACKEND: str | None = None
 
     # Storage configuration for AWS
     TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
@@ -62,9 +62,6 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
 
-    # Transcript MP3 storage
-    TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
-
     # LLM
     # available backend: openai, modal, oobabooga
     LLM_BACKEND: str = "oobabooga"
diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py
index 5cdafdbf..a457ddf8 100644
--- a/server/reflector/storage/base.py
+++ b/server/reflector/storage/base.py
@@ -1,6 +1,7 @@
+import importlib
+
 from pydantic import BaseModel
 from reflector.settings import settings
-import importlib
 
 
 class FileResult(BaseModel):
@@ -17,7 +18,7 @@ class Storage:
         cls._registry[name] = kclass
 
     @classmethod
-    def get_instance(cls, name, settings_prefix=""):
+    def get_instance(cls, name: str, settings_prefix: str = ""):
         if name not in cls._registry:
             module_name = f"reflector.storage.storage_{name}"
             importlib.import_module(module_name)
@@ -45,3 +46,9 @@ class Storage:
 
     async def _delete_file(self, filename: str):
         raise NotImplementedError
+
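+    # public wrapper mirroring put_file/delete_file; backends implement
+    # _get_file_url and may return a time-limited signed URL (the AWS
+    # backend presigns for one hour)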
+    async def get_file_url(self, filename: str) -> str:
+        return await self._get_file_url(filename)
+
+    async def _get_file_url(self, filename: str) -> str:
+        raise NotImplementedError
diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py
index 09a9c383..d2313293 100644
--- a/server/reflector/storage/storage_aws.py
+++ b/server/reflector/storage/storage_aws.py
@@ -1,6 +1,6 @@
 import aioboto3
-from reflector.storage.base import Storage, FileResult
 from reflector.logger import logger
+from reflector.storage.base import FileResult, Storage
 
 
 class AwsStorage(Storage):
@@ -44,16 +44,18 @@ class AwsStorage(Storage):
                 Body=data,
             )
 
+    async def _get_file_url(self, filename: str) -> str:
+        bucket = self.aws_bucket_name
+        folder = self.aws_folder
+        s3filename = f"{folder}/{filename}" if folder else filename
         async with self.session.client("s3") as client:
             presigned_url = await client.generate_presigned_url(
                 "get_object",
                 Params={"Bucket": bucket, "Key": s3filename},
                 ExpiresIn=3600,
             )
-            return FileResult(
-                filename=filename,
-                url=presigned_url,
-            )
+            return presigned_url
 
     async def _delete_file(self, filename: str):
         bucket = self.aws_bucket_name
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 88351880..44b55629 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -1,12 +1,14 @@
 from datetime import datetime, timedelta
 from typing import Annotated, Literal, Optional
 
+import httpx
 import reflector.auth as auth
 from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
     Request,
+    Response,
     WebSocket,
     WebSocketDisconnect,
     status,
 )
@@ -245,6 +247,24 @@ async def transcript_get_audio_mp3(
         transcript_id, user_id=user_id
     )
 
+    if transcript.audio_location == "storage":
+        # proxy the S3 file, to prevent issues with CORS
+        url = await transcript.get_audio_url()
+        headers = {}
+
+        copy_headers = ["range", "accept-encoding"]
+        for header in copy_headers:
+            if header in request.headers:
+                headers[header] = request.headers[header]
+
+        async with httpx.AsyncClient() as client:
+            resp = await client.request(request.method, url, headers=headers)
+        return Response(
+            content=resp.content,
+            status_code=resp.status_code,
+            headers=resp.headers,
+        )
+
     if not transcript.audio_mp3_filename.exists():
         raise HTTPException(status_code=500, detail="Audio not found")
 
@@ -269,8 +289,8 @@ async def transcript_get_audio_waveform(
         transcript_id, user_id=user_id
    )
 
-    if not transcript.audio_mp3_filename.exists():
-        raise HTTPException(status_code=500, detail="Audio not found")
+    if not transcript.audio_waveform_filename.exists():
+        raise HTTPException(status_code=404, detail="Audio not found")
 
     return transcript.audio_waveform
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index aafca9fd..532ebff9 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -1,4 +1,5 @@
 from unittest.mock import patch
+from tempfile import NamedTemporaryFile
 
 import pytest
 
@@ -7,7 +8,6 @@ import pytest
 @pytest.mark.asyncio
 async def setup_database():
     from reflector.settings import settings
-    from tempfile import NamedTemporaryFile
 
     with NamedTemporaryFile() as f:
         settings.DATABASE_URL = f"sqlite:///{f.name}"
@@ -103,6 +103,25 @@ async def dummy_llm():
         yield
 
 
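+# in-memory stand-in for the real storage backend, so the RTC tests
+# exercise the storage code path without AWS credentials or network access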
+@pytest.fixture
+async def dummy_storage():
+    from reflector.storage.base import Storage
+
+    class DummyStorage(Storage):
+        async def _put_file(self, *args, **kwargs):
+            pass
+
+        async def _delete_file(self, *args, **kwargs):
+            pass
+
+        async def _get_file_url(self, *args, **kwargs):
+            return "http://fake_server/audio.mp3"
+
+    with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
+        mock_storage.return_value = DummyStorage()
+        yield
+
+
 @pytest.fixture
 def nltk():
     with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
@@ -133,4 +152,17 @@ def celery_enable_logging():
 
 @pytest.fixture(scope="session")
 def celery_config():
-    return {"broker_url": "memory://", "result_backend": "rpc"}
+    with NamedTemporaryFile() as f:
+        yield {
+            "broker_url": "memory://",
+            "result_backend": f"db+sqlite:///{f.name}",
+        }
+
+
+@pytest.fixture(scope="session")
+def fake_mp3_upload():
+    with patch(
+        "reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
+    ) as mock_move:
+        mock_move.return_value = True
+        yield
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index b33b1db5..8502a0d9 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket(
     dummy_transcript,
     dummy_processors,
     dummy_diarization,
+    dummy_storage,
+    fake_mp3_upload,
     ensure_casing,
     appserver,
     sentence_tokenize,
@@ -220,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
     dummy_transcript,
     dummy_processors,
     dummy_diarization,
+    dummy_storage,
+    fake_mp3_upload,
     ensure_casing,
     appserver,
     sentence_tokenize,