diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index 4105d51f..708a4265 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -106,6 +106,14 @@ class PipelineRunner(BaseModel): if not self.pipeline: self.pipeline = await self.create() + if not self.pipeline: + # no pipeline created in create, just finish it then. + await self._set_status("ended") + self._ev_done.set() + if self.on_ended: + await self.on_ended() + return + # start the loop await self._set_status("started") while not self._ev_done.is_set(): diff --git a/server/tests/conftest.py b/server/tests/conftest.py index aaf42884..532ebff9 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,4 +1,5 @@ from unittest.mock import patch +from tempfile import NamedTemporaryFile import pytest @@ -7,7 +8,6 @@ import pytest @pytest.mark.asyncio async def setup_database(): from reflector.settings import settings - from tempfile import NamedTemporaryFile with NamedTemporaryFile() as f: settings.DATABASE_URL = f"sqlite:///{f.name}" @@ -103,6 +103,25 @@ async def dummy_llm(): yield +@pytest.fixture +async def dummy_storage(): + from reflector.storage.base import Storage + + class DummyStorage(Storage): + async def _put_file(self, *args, **kwargs): + pass + + async def _delete_file(self, *args, **kwargs): + pass + + async def _get_file_url(self, *args, **kwargs): + return "http://fake_server/audio.mp3" + + with patch("reflector.storage.base.Storage.get_instance") as mock_storage: + mock_storage.return_value = DummyStorage() + yield + + @pytest.fixture def nltk(): with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk: @@ -133,10 +152,17 @@ def celery_enable_logging(): @pytest.fixture(scope="session") def celery_config(): - import tempfile - - with tempfile.NamedTemporaryFile() as fd: + with NamedTemporaryFile() as f: yield { "broker_url": "memory://", - "result_backend": "db+sqlite://" + fd.name, + "result_backend": f"db+sqlite:///{f.name}", } + + +@pytest.fixture(scope="session") +def fake_mp3_upload(): + with patch( + "reflector.db.transcripts.TranscriptController.move_mp3_to_storage" + ) as mock_move: + mock_move.return_value = True + yield diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index b33b1db5..8502a0d9 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket( dummy_transcript, dummy_processors, dummy_diarization, + dummy_storage, + fake_mp3_upload, ensure_casing, appserver, sentence_tokenize, @@ -220,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr( dummy_transcript, dummy_processors, dummy_diarization, + dummy_storage, + fake_mp3_upload, ensure_casing, appserver, sentence_tokenize,