diff --git a/server/poetry.lock b/server/poetry.lock
index 35d98382..8783625b 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -2676,6 +2676,20 @@ pytest = ">=7.0.0"
 docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
 testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]

+[[package]]
+name = "pytest-celery"
+version = "0.0.0"
+description = "pytest-celery a shim pytest plugin to enable celery.contrib.pytest"
+optional = false
+python-versions = "*"
+files = [
+    {file = "pytest-celery-0.0.0.tar.gz", hash = "sha256:cfd060fc32676afa1e4f51b2938f903f7f75d952186b8c6cf631628c4088f406"},
+    {file = "pytest_celery-0.0.0-py2.py3-none-any.whl", hash = "sha256:63dec132df3a839226ecb003ffdbb0c2cb88dd328550957e979c942766578060"},
+]
+
+[package.dependencies]
+celery = ">=4.4.0"
+
 [[package]]
 name = "pytest-cov"
 version = "4.1.0"
@@ -4064,4 +4078,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63"
+content-hash = "07e42e7512fd5d51b656207a05092c53905c15e6a5ce548e015cdc05bd1baa7d"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index ed231a4f..c8614006 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -49,6 +49,7 @@ pytest-asyncio = "^0.21.1"
 pytest = "^7.4.0"
 httpx-ws = "^0.4.1"
 pytest-httpx = "^0.23.1"
+pytest-celery = "^0.0.0"


 [tool.poetry.group.aws.dependencies]
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 2b9fc6b2..61a2c380 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -62,6 +62,7 @@ class TranscriptTopic(BaseModel):
     title: str
     summary: str
     timestamp: float
+    duration: float | None = 0
     text: str | None = None
     words: list[ProcessorWord] = []

@@ -264,7 +265,7 @@ class TranscriptController:
         """
        A context manager for database transaction
         """
-        async with database.transaction():
+        async with database.transaction(isolation="serializable"):
             yield

     async def append_event(
@@ -280,5 +281,16 @@ class TranscriptController:
         await self.update(transcript, {"events": transcript.events_dump()})
         return resp

+    async def upsert_topic(
+        self,
+        transcript: Transcript,
+        topic: TranscriptTopic,
+    ) -> None:
+        """
+        Upsert a topic in a transcript
+        """
+        transcript.upsert_topic(topic)
+        await self.update(transcript, {"topics": transcript.topics_dump()})
+

 transcripts_controller = TranscriptController()
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 30f7ead3..4159c889 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -11,8 +11,12 @@ It is decoupled to:
 It is directly linked to our data model.
""" +import asyncio +from contextlib import asynccontextmanager from pathlib import Path +from celery import shared_task +from pydantic import BaseModel from reflector.db.transcripts import ( Transcript, TranscriptFinalLongSummary, @@ -25,6 +29,7 @@ from reflector.db.transcripts import ( from reflector.pipelines.runner import PipelineRunner from reflector.processors import ( AudioChunkerProcessor, + AudioDiarizationProcessor, AudioFileWriterProcessor, AudioMergeProcessor, AudioTranscriptAutoProcessor, @@ -37,11 +42,13 @@ from reflector.processors import ( TranscriptTopicDetectorProcessor, TranscriptTranslatorProcessor, ) -from reflector.tasks.worker import celery +from reflector.processors.types import AudioDiarizationInput +from reflector.processors.types import TitleSummary as TitleSummaryProcessorType +from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.ws_manager import WebsocketManager, get_ws_manager -def broadcast_to_socket(func): +def broadcast_to_sockets(func): """ Decorator to broadcast transcript event to websockets concerning this transcript @@ -59,6 +66,10 @@ def broadcast_to_socket(func): return wrapper +class StrValue(BaseModel): + value: str + + class PipelineMainBase(PipelineRunner): transcript_id: str ws_room_id: str | None = None @@ -66,6 +77,7 @@ class PipelineMainBase(PipelineRunner): def prepare(self): # prepare websocket + self._lock = asyncio.Lock() self.ws_room_id = f"ts:{self.transcript_id}" self.ws_manager = get_ws_manager() @@ -78,15 +90,59 @@ class PipelineMainBase(PipelineRunner): raise Exception("Transcript not found") return result + @asynccontextmanager + async def transaction(self): + async with self._lock: + async with transcripts_controller.transaction(): + yield -class PipelineMainLive(PipelineMainBase): - audio_filename: Path | None = None - source_language: str = "en" - target_language: str = "en" + @broadcast_to_sockets + async def on_status(self, status): + # if it's the first part, update the status of the transcript + # but do not set the ended status yet. 
+        if isinstance(self, PipelineMainLive):
+            status_mapping = {
+                "started": "recording",
+                "push": "recording",
+                "flush": "processing",
+                "error": "error",
+            }
+        elif isinstance(self, PipelineMainDiarization):
+            status_mapping = {
+                "push": "processing",
+                "flush": "processing",
+                "error": "error",
+                "ended": "ended",
+            }
+        else:
+            raise Exception(f"Runner {self.__class__} is missing status mapping")

-    @broadcast_to_socket
+        # mutate to model status
+        status = status_mapping.get(status)
+        if not status:
+            return
+
+        # when the status of the pipeline changes, update the transcript
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            if status == transcript.status:
+                return
+            resp = await transcripts_controller.append_event(
+                transcript=transcript,
+                event="STATUS",
+                data=StrValue(value=status),
+            )
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "status": status,
+                },
+            )
+            return resp
+
+    @broadcast_to_sockets
     async def on_transcript(self, data):
-        async with transcripts_controller.transaction():
+        async with self.transaction():
             transcript = await self.get_transcript()
             return await transcripts_controller.append_event(
                 transcript=transcript,
@@ -94,7 +150,7 @@ class PipelineMainLive(PipelineMainBase):
                 data=TranscriptText(text=data.text, translation=data.translation),
             )

-    @broadcast_to_socket
+    @broadcast_to_sockets
     async def on_topic(self, data):
         topic = TranscriptTopic(
             title=data.title,
@@ -103,14 +159,75 @@ class PipelineMainLive(PipelineMainBase):
             text=data.transcript.text,
             words=data.transcript.words,
         )
-        async with transcripts_controller.transaction():
+        async with self.transaction():
             transcript = await self.get_transcript()
+            await transcripts_controller.upsert_topic(transcript, topic)
             return await transcripts_controller.append_event(
                 transcript=transcript,
                 event="TOPIC",
                 data=topic,
             )

+    @broadcast_to_sockets
+    async def on_title(self, data):
+        final_title = TranscriptFinalTitle(title=data.title)
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            if not transcript.title:
+                await transcripts_controller.update(
+                    transcript,
+                    {
+                        "title": final_title.title,
+                    },
+                )
+            return await transcripts_controller.append_event(
+                transcript=transcript,
+                event="FINAL_TITLE",
+                data=final_title,
+            )
+
+    @broadcast_to_sockets
+    async def on_long_summary(self, data):
+        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "long_summary": final_long_summary.long_summary,
+                },
+            )
+            return await transcripts_controller.append_event(
+                transcript=transcript,
+                event="FINAL_LONG_SUMMARY",
+                data=final_long_summary,
+            )
+
+    @broadcast_to_sockets
+    async def on_short_summary(self, data):
+        final_short_summary = TranscriptFinalShortSummary(
+            short_summary=data.short_summary
+        )
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "short_summary": final_short_summary.short_summary,
+                },
+            )
+            return await transcripts_controller.append_event(
+                transcript=transcript,
+                event="FINAL_SHORT_SUMMARY",
+                data=final_short_summary,
+            )
+
+
+class PipelineMainLive(PipelineMainBase):
+    audio_filename: Path | None = None
+    source_language: str = "en"
+    target_language: str = "en"
+
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
@@ -125,96
+242,49 @@ class PipelineMainLive(PipelineMainBase): TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), + BroadcastProcessor( + processors=[ + TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), + TranscriptFinalLongSummaryProcessor.as_threaded( + callback=self.on_long_summary + ), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=self.on_short_summary + ), + ] + ), ] pipeline = Pipeline(*processors) pipeline.options = self pipeline.set_pref("audio:source_language", transcript.source_language) pipeline.set_pref("audio:target_language", transcript.target_language) - # when the pipeline ends, connect to the post pipeline - async def on_ended(): - task_pipeline_main_post.delay(transcript_id=self.transcript_id) - - pipeline.on_ended = self - return pipeline + async def on_ended(self): + # when the pipeline ends, connect to the post pipeline + task_pipeline_main_post.delay(transcript_id=self.transcript_id) -class PipelineMainPost(PipelineMainBase): + +class PipelineMainDiarization(PipelineMainBase): """ - Implement the rest of the main pipeline, triggered after PipelineMainLive ended. + Diarization is a long time process, so we do it in a separate pipeline + When done, adjust the short and final summary """ - @broadcast_to_socket - async def on_final_title(self, data): - final_title = TranscriptFinalTitle(title=data.title) - async with transcripts_controller.transaction(): - transcript = await self.get_transcript() - if not transcript.title: - transcripts_controller.update( - self.transcript, - { - "title": final_title.title, - }, - ) - return await transcripts_controller.append_event( - transcript=transcript, - event="FINAL_TITLE", - data=final_title, - ) - - @broadcast_to_socket - async def on_final_long_summary(self, data): - final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) - async with transcripts_controller.transaction(): - transcript = await self.get_transcript() - await transcripts_controller.update( - transcript, - { - "long_summary": final_long_summary.long_summary, - }, - ) - return await transcripts_controller.append_event( - transcript=transcript, - event="FINAL_LONG_SUMMARY", - data=final_long_summary, - ) - - @broadcast_to_socket - async def on_final_short_summary(self, data): - final_short_summary = TranscriptFinalShortSummary( - short_summary=data.short_summary - ) - async with transcripts_controller.transaction(): - transcript = await self.get_transcript() - await transcripts_controller.update( - transcript, - { - "short_summary": final_short_summary.short_summary, - }, - ) - return await transcripts_controller.append_event( - transcript=transcript, - event="FINAL_SHORT_SUMMARY", - data=final_short_summary, - ) - async def create(self) -> Pipeline: # create a context for the whole rtc transaction # add a customised logger to the context self.prepare() processors = [ - # add diarization + AudioDiarizationProcessor(), BroadcastProcessor( processors=[ - TranscriptFinalTitleProcessor.as_threaded( - callback=self.on_final_title - ), TranscriptFinalLongSummaryProcessor.as_threaded( - callback=self.on_final_long_summary + callback=self.on_long_summary ), TranscriptFinalShortSummaryProcessor.as_threaded( - callback=self.on_final_short_summary + callback=self.on_short_summary ), ] ), @@ -222,9 +292,35 @@ class PipelineMainPost(PipelineMainBase): pipeline = Pipeline(*processors) pipeline.options = self + # now 
let's start the pipeline by pushing information to the + # first processor diarization processor + # XXX translation is lost when converting our data model to the processor model + transcript = await self.get_transcript() + topics = [ + TitleSummaryProcessorType( + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + duration=topic.duration, + transcript=TranscriptProcessorType(words=topic.words), + ) + for topic in transcript.topics + ] + + audio_diarization_input = AudioDiarizationInput( + audio_filename=transcript.audio_mp3_filename, + topics=topics, + ) + + # as tempting to use pipeline.push, prefer to use the runner + # to let the start just do one job. + self.push(audio_diarization_input) + self.flush() + return pipeline -@celery.task +@shared_task def task_pipeline_main_post(transcript_id: str): - pass + runner = PipelineMainDiarization(transcript_id=transcript_id) + runner.start_sync() diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index ce84fec4..0575cf96 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -16,7 +16,6 @@ During its lifecycle, it will emit the following status: """ import asyncio -from typing import Callable from pydantic import BaseModel, ConfigDict from reflector.logger import logger @@ -27,8 +26,6 @@ class PipelineRunner(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) status: str = "idle" - on_status: Callable | None = None - on_ended: Callable | None = None pipeline: Pipeline | None = None def __init__(self, **kwargs): @@ -36,6 +33,10 @@ class PipelineRunner(BaseModel): self._q_cmd = asyncio.Queue() self._ev_done = asyncio.Event() self._is_first_push = True + self._logger = logger.bind( + runner=id(self), + runner_cls=self.__class__.__name__, + ) def create(self) -> Pipeline: """ @@ -50,33 +51,51 @@ class PipelineRunner(BaseModel): """ asyncio.get_event_loop().create_task(self.run()) - async def push(self, data): + def start_sync(self): + """ + Start the pipeline synchronously (for non-asyncio apps) + """ + asyncio.run(self.run()) + + def push(self, data): """ Push data to the pipeline """ - await self._add_cmd("PUSH", data) + self._add_cmd("PUSH", data) - async def flush(self): + def flush(self): """ Flush the pipeline """ - await self._add_cmd("FLUSH", None) + self._add_cmd("FLUSH", None) - async def _add_cmd(self, cmd: str, data): + async def on_status(self, status): + """ + Called when the status of the pipeline changes + """ + pass + + async def on_ended(self): + """ + Called when the pipeline ends + """ + pass + + def _add_cmd(self, cmd: str, data): """ Enqueue a command to be executed in the runner. 
Currently supported commands: PUSH, FLUSH """ - await self._q_cmd.put([cmd, data]) + self._q_cmd.put_nowait([cmd, data]) async def _set_status(self, status): - print("set_status", status) + self._logger.debug("Runner status updated", status=status) self.status = status if self.on_status: try: await self.on_status(status) - except Exception as e: - logger.error("PipelineRunner status_callback error", error=e) + except Exception: + self._logger.exception("Runer error while setting status") async def run(self): try: @@ -95,8 +114,8 @@ class PipelineRunner(BaseModel): await func(data) else: raise Exception(f"Unknown command {cmd}") - except Exception as e: - logger.error("PipelineRunner error", error=e) + except Exception: + self._logger.exception("Runner error") await self._set_status("error") self._ev_done.set() if self.on_ended: diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py index 960c6a35..01a3a174 100644 --- a/server/reflector/processors/__init__.py +++ b/server/reflector/processors/__init__.py @@ -1,4 +1,5 @@ from .audio_chunker import AudioChunkerProcessor # noqa: F401 +from .audio_diarization import AudioDiarizationProcessor # noqa: F401 from .audio_file_writer import AudioFileWriterProcessor # noqa: F401 from .audio_merge import AudioMergeProcessor # noqa: F401 from .audio_transcript import AudioTranscriptProcessor # noqa: F401 diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py new file mode 100644 index 00000000..8db8e8e5 --- /dev/null +++ b/server/reflector/processors/audio_diarization.py @@ -0,0 +1,65 @@ +from reflector.processors.base import Processor +from reflector.processors.types import AudioDiarizationInput, TitleSummary + + +class AudioDiarizationProcessor(Processor): + INPUT_TYPE = AudioDiarizationInput + OUTPUT_TYPE = TitleSummary + + async def _push(self, data: AudioDiarizationInput): + # Gather diarization data + diarization = [ + {"start": 0.0, "stop": 4.9, "speaker": 2}, + {"start": 5.6, "stop": 6.7, "speaker": 2}, + {"start": 7.3, "stop": 8.9, "speaker": 2}, + {"start": 7.3, "stop": 7.9, "speaker": 0}, + {"start": 9.4, "stop": 11.2, "speaker": 2}, + {"start": 9.7, "stop": 10.0, "speaker": 0}, + {"start": 10.0, "stop": 10.1, "speaker": 0}, + {"start": 11.7, "stop": 16.1, "speaker": 2}, + {"start": 11.8, "stop": 12.1, "speaker": 1}, + {"start": 16.4, "stop": 21.0, "speaker": 2}, + {"start": 21.1, "stop": 22.6, "speaker": 2}, + {"start": 24.7, "stop": 31.9, "speaker": 2}, + {"start": 32.0, "stop": 32.8, "speaker": 1}, + {"start": 33.4, "stop": 37.8, "speaker": 2}, + {"start": 37.9, "stop": 40.3, "speaker": 0}, + {"start": 39.2, "stop": 40.4, "speaker": 2}, + {"start": 40.7, "stop": 41.4, "speaker": 0}, + {"start": 41.6, "stop": 45.7, "speaker": 2}, + {"start": 46.4, "stop": 53.1, "speaker": 2}, + {"start": 53.6, "stop": 56.5, "speaker": 2}, + {"start": 54.9, "stop": 75.4, "speaker": 1}, + {"start": 57.3, "stop": 58.0, "speaker": 2}, + {"start": 65.7, "stop": 66.0, "speaker": 2}, + {"start": 75.8, "stop": 78.8, "speaker": 1}, + {"start": 79.0, "stop": 82.6, "speaker": 1}, + {"start": 83.2, "stop": 83.3, "speaker": 1}, + {"start": 84.5, "stop": 94.3, "speaker": 1}, + {"start": 95.1, "stop": 100.7, "speaker": 1}, + {"start": 100.7, "stop": 102.0, "speaker": 0}, + {"start": 100.7, "stop": 101.8, "speaker": 1}, + {"start": 102.0, "stop": 103.0, "speaker": 1}, + {"start": 103.0, "stop": 103.7, "speaker": 0}, + {"start": 103.7, "stop": 103.8, "speaker": 1}, + {"start": 
103.8, "stop": 113.9, "speaker": 0}, + {"start": 114.7, "stop": 117.0, "speaker": 0}, + {"start": 117.0, "stop": 117.4, "speaker": 1}, + ] + + # now reapply speaker to topics (if any) + # topics is a list[BaseModel] with an attribute words + # words is a list[BaseModel] with text, start and speaker attribute + + print("IN DIARIZATION PROCESSOR", data) + + # mutate in place + for topic in data.topics: + for word in topic.transcript.words: + for d in diarization: + if d["start"] <= word.start <= d["stop"]: + word.speaker = d["speaker"] + + # emit them + for topic in data.topics: + await self.emit(topic) diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index d2c32d17..3ec21491 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -382,3 +382,8 @@ class TranslationLanguages(BaseModel): def is_supported(self, lang_id: str) -> bool: return lang_id in self.supported_languages + + +class AudioDiarizationInput(BaseModel): + audio_filename: Path + topics: list[TitleSummary] diff --git a/server/reflector/tasks/boot.py b/server/reflector/tasks/boot.py deleted file mode 100644 index 88cc2d6f..00000000 --- a/server/reflector/tasks/boot.py +++ /dev/null @@ -1,2 +0,0 @@ -import reflector.tasks.post_transcript # noqa -import reflector.tasks.worker # noqa diff --git a/server/reflector/tasks/worker.py b/server/reflector/tasks/worker.py deleted file mode 100644 index 4379a1b7..00000000 --- a/server/reflector/tasks/worker.py +++ /dev/null @@ -1,6 +0,0 @@ -from celery import Celery -from reflector.settings import settings - -celery = Celery(__name__) -celery.conf.broker_url = settings.CELERY_BROKER_URL -celery.conf.result_backend = settings.CELERY_RESULT_BACKEND diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 5d10c181..386ada9c 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -1,5 +1,4 @@ import asyncio -from enum import StrEnum from json import loads import av @@ -41,7 +40,7 @@ class AudioStreamTrack(MediaStreamTrack): ctx = self.ctx frame = await self.track.recv() try: - await ctx.pipeline_runner.push(frame) + ctx.pipeline_runner.push(frame) except Exception as e: ctx.logger.error("Pipeline error", error=e) return frame @@ -52,19 +51,6 @@ class RtcOffer(BaseModel): type: str -class StrValue(BaseModel): - value: str - - -class PipelineEvent(StrEnum): - TRANSCRIPT = "TRANSCRIPT" - TOPIC = "TOPIC" - FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY" - STATUS = "STATUS" - FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY" - FINAL_TITLE = "FINAL_TITLE" - - async def rtc_offer_base( params: RtcOffer, request: Request, @@ -90,7 +76,7 @@ async def rtc_offer_base( # - when we receive the close event, we do nothing. # 2. 
or the client close the connection # and there is nothing to do because it is already closed - await ctx.pipeline_runner.flush() + ctx.pipeline_runner.flush() if close: ctx.logger.debug("Closing peer connection") await pc.close() diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index e949d645..31cbe28e 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -23,10 +23,9 @@ from reflector.ws_manager import get_ws_manager from starlette.concurrency import run_in_threadpool from ._range_requests_response import range_requests_response -from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base +from .rtc_offer import RtcOffer, rtc_offer_base router = APIRouter() -ws_manager = get_ws_manager() # ============================================================== # Transcripts list @@ -166,32 +165,17 @@ async def transcript_update( transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - values = {"events": []} + values = {} if info.name is not None: values["name"] = info.name if info.locked is not None: values["locked"] = info.locked if info.long_summary is not None: values["long_summary"] = info.long_summary - for transcript_event in transcript.events: - if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY: - transcript_event["long_summary"] = info.long_summary - break - values["events"].extend(transcript.events) if info.short_summary is not None: values["short_summary"] = info.short_summary - for transcript_event in transcript.events: - if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY: - transcript_event["short_summary"] = info.short_summary - break - values["events"].extend(transcript.events) if info.title is not None: values["title"] = info.title - for transcript_event in transcript.events: - if transcript_event["event"] == PipelineEvent.FINAL_TITLE: - transcript_event["title"] = info.title - break - values["events"].extend(transcript.events) await transcripts_controller.update(transcript, values) return transcript @@ -295,6 +279,7 @@ async def transcript_events_websocket( # connect to websocket manager # use ts:transcript_id as room id room_id = f"ts:{transcript_id}" + ws_manager = get_ws_manager() await ws_manager.add_user_to_room(room_id, websocket) try: @@ -303,9 +288,7 @@ async def transcript_events_websocket( # for now, do not send TRANSCRIPT or STATUS options - theses are live event # not necessary to be sent to the client; but keep the rest name = event.event - if name == PipelineEvent.TRANSCRIPT: - continue - if name == PipelineEvent.STATUS: + if name in ("TRANSCRIPT", "STATUS"): continue await websocket.send_json(event.model_dump(mode="json")) diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py new file mode 100644 index 00000000..3714a64d --- /dev/null +++ b/server/reflector/worker/app.py @@ -0,0 +1,11 @@ +from celery import Celery +from reflector.settings import settings + +app = Celery(__name__) +app.conf.broker_url = settings.CELERY_BROKER_URL +app.conf.result_backend = settings.CELERY_RESULT_BACKEND +app.autodiscover_tasks( + [ + "reflector.pipelines.main_live_pipeline", + ] +) diff --git a/server/reflector/tasks/post_transcript.py b/server/reflector/worker/post_transcript.py similarity index 100% rename from server/reflector/tasks/post_transcript.py rename to server/reflector/worker/post_transcript.py diff --git 
a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py
index 7650807b..a84e3361 100644
--- a/server/reflector/ws_manager.py
+++ b/server/reflector/ws_manager.py
@@ -11,13 +11,15 @@ broadcast messages to all connected websockets.

 import asyncio
 import json
+import threading

 import redis.asyncio as redis
 from fastapi import WebSocket
 from reflector.settings import settings

-ws_manager = None
-
+# module-level so the per-thread cache in get_ws_manager() persists across calls
+_ws_local = threading.local()
+

 class RedisPubSubManager:
     def __init__(self, host="localhost", port=6379):
@@ -114,13 +116,13 @@ def get_ws_manager() -> WebsocketManager:
         ImportError: If the 'reflector.settings' module cannot be imported.
         RedisConnectionError: If there is an error connecting to the Redis server.

     """
-    global ws_manager
-    if ws_manager:
-        return ws_manager
+    if hasattr(_ws_local, "ws_manager"):
+        return _ws_local.ws_manager
     pubsub_client = RedisPubSubManager(
         host=settings.REDIS_HOST,
         port=settings.REDIS_PORT,
     )
     ws_manager = WebsocketManager(pubsub_client=pubsub_client)
+    _ws_local.ws_manager = ws_manager
     return ws_manager
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index 76b56abf..d5f5f0b9 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -45,17 +45,16 @@ async def dummy_transcript():
     from reflector.processors.types import AudioFile, Transcript, Word

     class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
-        async def _transcript(self, data: AudioFile):
-            source_language = self.get_pref("audio:source_language", "en")
-            print("transcripting", source_language)
-            print("pipeline", self.pipeline)
-            print("prefs", self.pipeline.prefs)
+        _time_idx = 0

+        async def _transcript(self, data: AudioFile):
+            i = self._time_idx
+            self._time_idx += 2
             return Transcript(
                 text="Hello world.",
                 words=[
-                    Word(start=0.0, end=1.0, text="Hello"),
-                    Word(start=1.0, end=2.0, text=" world."),
+                    Word(start=i, end=i + 1, text="Hello", speaker=0),
+                    Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
                 ],
             )

@@ -98,7 +97,17 @@ def ensure_casing():
 @pytest.fixture
 def sentence_tokenize():
     with patch(
-        "reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize"
+        "reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize"
     ) as mock_sent_tokenize:
         mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
         yield
+
+
+@pytest.fixture(scope="session")
+def celery_enable_logging():
+    return True
+
+
+@pytest.fixture(scope="session")
+def celery_config():
+    return {"broker_url": "memory://", "result_backend": "rpc"}
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 50e74231..e2bfee32 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -32,7 +32,7 @@ class ThreadedUvicorn:


 @pytest.fixture
-async def appserver(tmpdir):
+async def appserver(tmpdir, celery_session_app, celery_session_worker):
     from reflector.settings import settings
     from reflector.app import app

@@ -52,6 +52,13 @@ async def appserver(tmpdir):
     settings.DATA_DIR = DATA_DIR


+@pytest.fixture(scope="session")
+def celery_includes():
+    return ["reflector.pipelines.main_live_pipeline"]
+
+
+@pytest.mark.usefixtures("celery_session_app")
+@pytest.mark.usefixtures("celery_session_worker")
 @pytest.mark.asyncio
 async def test_transcript_rtc_and_websocket(
     tmpdir,
@@ -121,14 +128,21 @@ async def test_transcript_rtc_and_websocket(
     # XXX aiortc is long to close the connection
     # instead of waiting a long time, we just send a STOP
     client.channel.send(json.dumps({"cmd": "STOP"}))
-
-    # wait the processing to finish
-    await asyncio.sleep(2)
-
     await client.stop()

     # wait the processing to finish
-    await asyncio.sleep(2)
+    timeout = 20
+    while timeout > 0:
+        # fetch the transcript and check if it is ended
+        resp = await ac.get(f"/transcripts/{tid}")
+        assert resp.status_code == 200
+        if resp.json()["status"] in ("ended", "error"):
+            break
+        await asyncio.sleep(1)
+        timeout -= 1
+
+    if resp.json()["status"] != "ended":
+        raise TimeoutError("Timeout while waiting for transcript to be ended")

     # stop websocket task
     websocket_task.cancel()
@@ -152,7 +166,7 @@ async def test_transcript_rtc_and_websocket(
     ev = events[eventnames.index("TOPIC")]
     assert ev["data"]["id"]
     assert ev["data"]["summary"] == "LLM SUMMARY"
-    assert ev["data"]["transcript"].startswith("Hello world.")
+    assert ev["data"]["text"].startswith("Hello world.")
     assert ev["data"]["timestamp"] == 0.0

     assert "FINAL_LONG_SUMMARY" in eventnames
@@ -169,23 +183,21 @@ async def test_transcript_rtc_and_websocket(

     # check status order
     statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
-    assert statuses == ["recording", "processing", "ended"]
+    assert statuses.index("recording") < statuses.index("processing")
+    assert statuses.index("processing") < statuses.index("ended")

     # ensure the last event received is ended
     assert events[-1]["event"] == "STATUS"
     assert events[-1]["data"]["value"] == "ended"

-    # check that transcript status in model is updated
-    resp = await ac.get(f"/transcripts/{tid}")
-    assert resp.status_code == 200
-    assert resp.json()["status"] == "ended"
-
     # check that audio/mp3 is available
     resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
     assert resp.status_code == 200
     assert resp.headers["Content-Type"] == "audio/mpeg"


+@pytest.mark.usefixtures("celery_session_app")
+@pytest.mark.usefixtures("celery_session_worker")
 @pytest.mark.asyncio
 async def test_transcript_rtc_and_websocket_and_fr(
     tmpdir,
@@ -265,6 +277,19 @@ async def test_transcript_rtc_and_websocket_and_fr(
     await client.stop()

     # wait the processing to finish
+    timeout = 20
+    while timeout > 0:
+        # fetch the transcript and check if it is ended
+        resp = await ac.get(f"/transcripts/{tid}")
+        assert resp.status_code == 200
+        if resp.json()["status"] == "ended":
+            break
+        await asyncio.sleep(1)
+        timeout -= 1
+
+    if resp.json()["status"] != "ended":
+        raise TimeoutError("Timeout while waiting for transcript to be ended")
+
     await asyncio.sleep(2)

     # stop websocket task
@@ -289,7 +314,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
     ev = events[eventnames.index("TOPIC")]
     assert ev["data"]["id"]
     assert ev["data"]["summary"] == "LLM SUMMARY"
-    assert ev["data"]["transcript"].startswith("Hello world.")
+    assert ev["data"]["text"].startswith("Hello world.")
     assert ev["data"]["timestamp"] == 0.0

     assert "FINAL_LONG_SUMMARY" in eventnames
@@ -306,7 +331,8 @@ async def test_transcript_rtc_and_websocket_and_fr(

     # check status order
     statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
-    assert statuses == ["recording", "processing", "ended"]
+    assert statuses.index("recording") < statuses.index("processing")
+    assert statuses.index("processing") < statuses.index("ended")

     # ensure the last event received is ended
     assert events[-1]["event"] == "STATUS"
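
How the new Celery wiring in this change is exercised, as a minimal sketch (the broker and result backend come from settings; the transcript id below is a placeholder, not a value from this patch):

    # start a worker against the new app module (replaces reflector.tasks.worker):
    #   celery -A reflector.worker.app worker --loglevel=info
    from reflector.pipelines.main_live_pipeline import (
        PipelineMainDiarization,
        task_pipeline_main_post,
    )

    # enqueue diarization + summary refresh through the broker,
    # which is what PipelineMainLive.on_ended() does
    task_pipeline_main_post.delay(transcript_id="<transcript-id>")

    # or run it inline without a broker, the way the task body does
    PipelineMainDiarization(transcript_id="<transcript-id>").start_sync()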