From 1c42473da029bb2f88b090a7323ac4f68b818275 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 26 Oct 2023 19:00:56 +0200
Subject: [PATCH] server: refactor with clearer pipeline instantiation linked
 to the model

---
 server/reflector/db/__init__.py               |  23 +-
 server/reflector/db/transcripts.py            | 284 +++++++++++++++
 .../reflector/pipelines/main_live_pipeline.py | 230 ++++++++++++
 server/reflector/pipelines/runner.py          | 117 ++++++
 server/reflector/processors/__init__.py       |   8 +-
 server/reflector/views/rtc_offer.py           | 267 +-------------
 server/reflector/views/transcripts.py         | 341 +-----------------
 server/reflector/ws_manager.py                |   4 +-
 8 files changed, 658 insertions(+), 616 deletions(-)
 create mode 100644 server/reflector/db/transcripts.py
 create mode 100644 server/reflector/pipelines/main_live_pipeline.py
 create mode 100644 server/reflector/pipelines/runner.py

diff --git a/server/reflector/db/__init__.py b/server/reflector/db/__init__.py
index b68dfe20..9871c633 100644
--- a/server/reflector/db/__init__.py
+++ b/server/reflector/db/__init__.py
@@ -1,32 +1,13 @@
 import databases
 import sqlalchemy
-
 from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.settings import settings

 database = databases.Database(settings.DATABASE_URL)
 metadata = sqlalchemy.MetaData()
-
-transcripts = sqlalchemy.Table(
-    "transcript",
-    metadata,
-    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
-    sqlalchemy.Column("name", sqlalchemy.String),
-    sqlalchemy.Column("status", sqlalchemy.String),
-    sqlalchemy.Column("locked", sqlalchemy.Boolean),
-    sqlalchemy.Column("duration", sqlalchemy.Integer),
-    sqlalchemy.Column("created_at", sqlalchemy.DateTime),
-    sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
-    sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
-    sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
-    sqlalchemy.Column("topics", sqlalchemy.JSON),
-    sqlalchemy.Column("events", sqlalchemy.JSON),
-    sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
-    sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
-    # with user attached, optional
-    sqlalchemy.Column("user_id", sqlalchemy.String),
-)
+# import models
+import reflector.db.transcripts  # noqa

 engine = sqlalchemy.create_engine(
     settings.DATABASE_URL, connect_args={"check_same_thread": False}
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
new file mode 100644
index 00000000..2b9fc6b2
--- /dev/null
+++ b/server/reflector/db/transcripts.py
@@ -0,0 +1,284 @@
+import json
+from contextlib import asynccontextmanager
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+from uuid import uuid4
+
+import sqlalchemy
+from pydantic import BaseModel, Field
+from reflector.db import database, metadata
+from reflector.processors.types import Word as ProcessorWord
+from reflector.settings import settings
+from reflector.utils.audio_waveform import get_audio_waveform
+
+transcripts = sqlalchemy.Table(
+    "transcript",
+    metadata,
+    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
+    sqlalchemy.Column("name", sqlalchemy.String),
+    sqlalchemy.Column("status", sqlalchemy.String),
+    sqlalchemy.Column("locked", sqlalchemy.Boolean),
+    sqlalchemy.Column("duration", sqlalchemy.Integer),
+    sqlalchemy.Column("created_at", sqlalchemy.DateTime),
+    sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), + sqlalchemy.Column("topics", sqlalchemy.JSON), + sqlalchemy.Column("events", sqlalchemy.JSON), + sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), + sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), + # with user attached, optional + sqlalchemy.Column("user_id", sqlalchemy.String), +) + + +def generate_uuid4(): + return str(uuid4()) + + +def generate_transcript_name(): + now = datetime.utcnow() + return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" + + +class AudioWaveform(BaseModel): + data: list[float] + + +class TranscriptText(BaseModel): + text: str + translation: str | None + + +class TranscriptSegmentTopic(BaseModel): + speaker: int + text: str + timestamp: float + + +class TranscriptTopic(BaseModel): + id: str = Field(default_factory=generate_uuid4) + title: str + summary: str + timestamp: float + text: str | None = None + words: list[ProcessorWord] = [] + + +class TranscriptFinalShortSummary(BaseModel): + short_summary: str + + +class TranscriptFinalLongSummary(BaseModel): + long_summary: str + + +class TranscriptFinalTitle(BaseModel): + title: str + + +class TranscriptEvent(BaseModel): + event: str + data: dict + + +class Transcript(BaseModel): + id: str = Field(default_factory=generate_uuid4) + user_id: str | None = None + name: str = Field(default_factory=generate_transcript_name) + status: str = "idle" + locked: bool = False + duration: float = 0 + created_at: datetime = Field(default_factory=datetime.utcnow) + title: str | None = None + short_summary: str | None = None + long_summary: str | None = None + topics: list[TranscriptTopic] = [] + events: list[TranscriptEvent] = [] + source_language: str = "en" + target_language: str = "en" + + def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: + ev = TranscriptEvent(event=event, data=data.model_dump()) + self.events.append(ev) + return ev + + def upsert_topic(self, topic: TranscriptTopic): + existing_topic = next((t for t in self.topics if t.id == topic.id), None) + if existing_topic: + existing_topic.update_from(topic) + else: + self.topics.append(topic) + + def events_dump(self, mode="json"): + return [event.model_dump(mode=mode) for event in self.events] + + def topics_dump(self, mode="json"): + return [topic.model_dump(mode=mode) for topic in self.topics] + + def convert_audio_to_waveform(self, segments_count=256): + fn = self.audio_waveform_filename + if fn.exists(): + return + waveform = get_audio_waveform( + path=self.audio_mp3_filename, segments_count=segments_count + ) + try: + with open(fn, "w") as fd: + json.dump(waveform, fd) + except Exception: + # remove file if anything happen during the write + fn.unlink(missing_ok=True) + raise + return waveform + + def unlink(self): + self.data_path.unlink(missing_ok=True) + + @property + def data_path(self): + return Path(settings.DATA_DIR) / self.id + + @property + def audio_mp3_filename(self): + return self.data_path / "audio.mp3" + + @property + def audio_waveform_filename(self): + return self.data_path / "audio.json" + + @property + def audio_waveform(self): + try: + with open(self.audio_waveform_filename) as fd: + data = json.load(fd) + except json.JSONDecodeError: + # unlink file if it's corrupted + self.audio_waveform_filename.unlink(missing_ok=True) + return None + + return AudioWaveform(data=data) + + +class TranscriptController: + async def get_all( + self, + user_id: str | None = None, + order_by: str | None = None, + filter_empty: bool 
+        filter_empty: bool | None = True,
+        filter_recording: bool | None = True,
+    ) -> list[Transcript]:
+        """
+        Get all transcripts
+
+        If `user_id` is specified, only return transcripts that belong to the user.
+        Otherwise, return all anonymous transcripts.
+
+        Parameters:
+        - `order_by`: field to order by, e.g. "-created_at"
+        - `filter_empty`: filter out empty transcripts
+        - `filter_recording`: filter out transcripts that are currently recording
+        """
+        query = transcripts.select().where(transcripts.c.user_id == user_id)
+
+        if order_by is not None:
+            field = getattr(transcripts.c, order_by.lstrip("-"))
+            if order_by.startswith("-"):
+                field = field.desc()
+            query = query.order_by(field)
+
+        if filter_empty:
+            query = query.filter(transcripts.c.status != "idle")
+
+        if filter_recording:
+            query = query.filter(transcripts.c.status != "recording")
+
+        results = await database.fetch_all(query)
+        return results
+
+    async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
+        """
+        Get a transcript by id
+        """
+        query = transcripts.select().where(transcripts.c.id == transcript_id)
+        if "user_id" in kwargs:
+            query = query.where(transcripts.c.user_id == kwargs["user_id"])
+        result = await database.fetch_one(query)
+        if not result:
+            return None
+        return Transcript(**result)
+
+    async def add(
+        self,
+        name: str,
+        source_language: str = "en",
+        target_language: str = "en",
+        user_id: str | None = None,
+    ):
+        """
+        Add a new transcript
+        """
+        transcript = Transcript(
+            name=name,
+            source_language=source_language,
+            target_language=target_language,
+            user_id=user_id,
+        )
+        query = transcripts.insert().values(**transcript.model_dump())
+        await database.execute(query)
+        return transcript
+
+    async def update(self, transcript: Transcript, values: dict):
+        """
+        Update a transcript's fields with the key/values given in `values`
+        """
+        query = (
+            transcripts.update()
+            .where(transcripts.c.id == transcript.id)
+            .values(**values)
+        )
+        await database.execute(query)
+        for key, value in values.items():
+            setattr(transcript, key, value)
+
+    async def remove_by_id(
+        self,
+        transcript_id: str,
+        user_id: str | None = None,
+    ) -> None:
+        """
+        Remove a transcript by id
+        """
+        transcript = await self.get_by_id(transcript_id, user_id=user_id)
+        if not transcript:
+            return
+        if user_id is not None and transcript.user_id != user_id:
+            return
+        transcript.unlink()
+        query = transcripts.delete().where(transcripts.c.id == transcript_id)
+        await database.execute(query)
+
+    @asynccontextmanager
+    async def transaction(self):
+        """
+        A context manager for a database transaction
+        """
+        async with database.transaction():
+            yield
+
+    async def append_event(
+        self,
+        transcript: Transcript,
+        event: str,
+        data: Any,
+    ) -> TranscriptEvent:
+        """
+        Append an event to a transcript
+        """
+        resp = transcript.add_event(event=event, data=data)
+        await self.update(transcript, {"events": transcript.events_dump()})
+        return resp
+
+
+transcripts_controller = TranscriptController()
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
new file mode 100644
index 00000000..30f7ead3
--- /dev/null
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -0,0 +1,230 @@
+"""
+Main reflector pipeline for live streaming
+==========================================
+
+This is the default pipeline used in the API.
+
+It is split in two:
+- PipelineMainLive: limited processing during the live session
+- PipelineMainPost: heavy lifting after the live session has ended
+
+It is directly linked to our data model.
+""" + +from pathlib import Path + +from reflector.db.transcripts import ( + Transcript, + TranscriptFinalLongSummary, + TranscriptFinalShortSummary, + TranscriptFinalTitle, + TranscriptText, + TranscriptTopic, + transcripts_controller, +) +from reflector.pipelines.runner import PipelineRunner +from reflector.processors import ( + AudioChunkerProcessor, + AudioFileWriterProcessor, + AudioMergeProcessor, + AudioTranscriptAutoProcessor, + BroadcastProcessor, + Pipeline, + TranscriptFinalLongSummaryProcessor, + TranscriptFinalShortSummaryProcessor, + TranscriptFinalTitleProcessor, + TranscriptLinerProcessor, + TranscriptTopicDetectorProcessor, + TranscriptTranslatorProcessor, +) +from reflector.tasks.worker import celery +from reflector.ws_manager import WebsocketManager, get_ws_manager + + +def broadcast_to_socket(func): + """ + Decorator to broadcast transcript event to websockets + concerning this transcript + """ + + async def wrapper(self, *args, **kwargs): + resp = await func(self, *args, **kwargs) + if resp is None: + return + await self.ws_manager.send_json( + room_id=self.ws_room_id, + message=resp.model_dump(mode="json"), + ) + + return wrapper + + +class PipelineMainBase(PipelineRunner): + transcript_id: str + ws_room_id: str | None = None + ws_manager: WebsocketManager | None = None + + def prepare(self): + # prepare websocket + self.ws_room_id = f"ts:{self.transcript_id}" + self.ws_manager = get_ws_manager() + + async def get_transcript(self) -> Transcript: + # fetch the transcript + result = await transcripts_controller.get_by_id( + transcript_id=self.transcript_id + ) + if not result: + raise Exception("Transcript not found") + return result + + +class PipelineMainLive(PipelineMainBase): + audio_filename: Path | None = None + source_language: str = "en" + target_language: str = "en" + + @broadcast_to_socket + async def on_transcript(self, data): + async with transcripts_controller.transaction(): + transcript = await self.get_transcript() + return await transcripts_controller.append_event( + transcript=transcript, + event="TRANSCRIPT", + data=TranscriptText(text=data.text, translation=data.translation), + ) + + @broadcast_to_socket + async def on_topic(self, data): + topic = TranscriptTopic( + title=data.title, + summary=data.summary, + timestamp=data.timestamp, + text=data.transcript.text, + words=data.transcript.words, + ) + async with transcripts_controller.transaction(): + transcript = await self.get_transcript() + return await transcripts_controller.append_event( + transcript=transcript, + event="TOPIC", + data=topic, + ) + + async def create(self) -> Pipeline: + # create a context for the whole rtc transaction + # add a customised logger to the context + self.prepare() + transcript = await self.get_transcript() + + processors = [ + AudioFileWriterProcessor(path=transcript.audio_mp3_filename), + AudioChunkerProcessor(), + AudioMergeProcessor(), + AudioTranscriptAutoProcessor.as_threaded(), + TranscriptLinerProcessor(), + TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), + TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), + ] + pipeline = Pipeline(*processors) + pipeline.options = self + pipeline.set_pref("audio:source_language", transcript.source_language) + pipeline.set_pref("audio:target_language", transcript.target_language) + + # when the pipeline ends, connect to the post pipeline + async def on_ended(): + task_pipeline_main_post.delay(transcript_id=self.transcript_id) + + pipeline.on_ended = self + + return pipeline + + +class 
+class PipelineMainPost(PipelineMainBase):
+    """
+    Implements the rest of the main pipeline, triggered after PipelineMainLive
+    has ended.
+    """
+
+    @broadcast_to_socket
+    async def on_final_title(self, data):
+        final_title = TranscriptFinalTitle(title=data.title)
+        async with transcripts_controller.transaction():
+            transcript = await self.get_transcript()
+            if not transcript.title:
+                await transcripts_controller.update(
+                    transcript,
+                    {
+                        "title": final_title.title,
+                    },
+                )
+            return await transcripts_controller.append_event(
+                transcript=transcript,
+                event="FINAL_TITLE",
+                data=final_title,
+            )
+
+    @broadcast_to_socket
+    async def on_final_long_summary(self, data):
+        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
+        async with transcripts_controller.transaction():
+            transcript = await self.get_transcript()
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "long_summary": final_long_summary.long_summary,
+                },
+            )
+            return await transcripts_controller.append_event(
+                transcript=transcript,
+                event="FINAL_LONG_SUMMARY",
+                data=final_long_summary,
+            )
+
+    @broadcast_to_socket
+    async def on_final_short_summary(self, data):
+        final_short_summary = TranscriptFinalShortSummary(
+            short_summary=data.short_summary
+        )
+        async with transcripts_controller.transaction():
+            transcript = await self.get_transcript()
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "short_summary": final_short_summary.short_summary,
+                },
+            )
+            return await transcripts_controller.append_event(
+                transcript=transcript,
+                event="FINAL_SHORT_SUMMARY",
+                data=final_short_summary,
+            )
+
+    async def create(self) -> Pipeline:
+        # prepare the websocket room
+        self.prepare()
+        processors = [
+            # TODO: add diarization
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalTitleProcessor.as_threaded(
+                        callback=self.on_final_title
+                    ),
+                    TranscriptFinalLongSummaryProcessor.as_threaded(
+                        callback=self.on_final_long_summary
+                    ),
+                    TranscriptFinalShortSummaryProcessor.as_threaded(
+                        callback=self.on_final_short_summary
+                    ),
+                ]
+            ),
+        ]
+        pipeline = Pipeline(*processors)
+        pipeline.options = self
+
+        return pipeline
+
+
+@celery.task
+def task_pipeline_main_post(transcript_id: str):
+    pass
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
new file mode 100644
index 00000000..ce84fec4
--- /dev/null
+++ b/server/reflector/pipelines/runner.py
@@ -0,0 +1,117 @@
+"""
+Pipeline Runner
+===============
+
+Pipeline runner designed to be executed in an asyncio task.
+
+It is meant to be subclassed, implementing a create() method
+that exposes/returns a Pipeline instance.
+
+During its lifecycle, it will emit the following statuses:
+- init: the pipeline is being created
+- started: the pipeline has been started
+- push: the pipeline has received at least one piece of data
+- flush: the pipeline is flushing
+- ended: the pipeline has ended
+- error: the pipeline has ended with an error
+"""
+
+import asyncio
+from typing import Callable
+
+from pydantic import BaseModel, ConfigDict
+from reflector.logger import logger
+from reflector.processors import Pipeline
+
+
+class PipelineRunner(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    status: str = "idle"
+    on_status: Callable | None = None
+    on_ended: Callable | None = None
+    pipeline: Pipeline | None = None
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._q_cmd = asyncio.Queue()
+        self._ev_done = asyncio.Event()
+        self._is_first_push = True
+
+    def create(self) -> Pipeline:
+        """
+        Create the pipeline, if one was not passed at construction time.
+        Must be implemented in a subclass.
+        """
+        raise NotImplementedError()
+
+    def start(self):
+        """
+        Start the pipeline as a coroutine task
+        """
+        asyncio.get_event_loop().create_task(self.run())
+
+    async def push(self, data):
+        """
+        Push data to the pipeline
+        """
+        await self._add_cmd("PUSH", data)
+
+    async def flush(self):
+        """
+        Flush the pipeline
+        """
+        await self._add_cmd("FLUSH", None)
+
+    async def _add_cmd(self, cmd: str, data):
+        """
+        Enqueue a command to be executed in the runner.
+        Currently supported commands: PUSH, FLUSH
+        """
+        await self._q_cmd.put([cmd, data])
+
+    async def _set_status(self, status):
+        logger.debug("PipelineRunner status", status=status)
+        self.status = status
+        if self.on_status:
+            try:
+                await self.on_status(status)
+            except Exception as e:
+                logger.error("PipelineRunner status_callback error", error=e)
+
+    async def run(self):
+        try:
+            # create the pipeline if not yet done
+            await self._set_status("init")
+            self._is_first_push = True
+            if not self.pipeline:
+                self.pipeline = await self.create()
+
+            # start the loop
+            await self._set_status("started")
+            while not self._ev_done.is_set():
+                cmd, data = await self._q_cmd.get()
+                func = getattr(self, f"cmd_{cmd.lower()}", None)
+                if func:
+                    await func(data)
+                else:
+                    raise Exception(f"Unknown command {cmd}")
+        except Exception as e:
+            logger.error("PipelineRunner error", error=e)
+            await self._set_status("error")
+            self._ev_done.set()
+            if self.on_ended:
+                await self.on_ended()
+
+    async def cmd_push(self, data):
+        if self._is_first_push:
+            await self._set_status("push")
+            self._is_first_push = False
+        await self.pipeline.push(data)
+
+    async def cmd_flush(self, data):
+        await self._set_status("flush")
+        await self.pipeline.flush()
+        await self._set_status("ended")
+        self._ev_done.set()
+        if self.on_ended:
+            await self.on_ended()
diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py
index 96a3941d..960c6a35 100644
--- a/server/reflector/processors/__init__.py
+++ b/server/reflector/processors/__init__.py
@@ -3,7 +3,13 @@ from .audio_file_writer import AudioFileWriterProcessor  # noqa: F401
 from .audio_merge import AudioMergeProcessor  # noqa: F401
 from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
 from .audio_transcript_auto import AudioTranscriptAutoProcessor  # noqa: F401
-from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor  # noqa: F401
+from .base import (  # noqa: F401
+    BroadcastProcessor,
+    Pipeline,
+    PipelineEvent,
+    Processor,
+    ThreadedProcessor,
+)
 from .transcript_final_long_summary import (  # noqa: F401
     TranscriptFinalLongSummaryProcessor,
 )
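A quick sketch of how the new runner is meant to be driven, for reviewers. `PipelineMainLive`, `transcript_id`, and the `start()`/`push()`/`flush()` lifecycle all come from this patch; the frame loop itself is illustrative (in the real flow, `rtc_offer_base` pushes decoded audio frames from the aiortc track):

```python
# Minimal sketch, not part of the patch: driving a PipelineRunner subclass.
import asyncio

from reflector.pipelines.main_live_pipeline import PipelineMainLive


async def demo(transcript_id: str, frames):
    runner = PipelineMainLive(transcript_id=transcript_id)
    runner.start()  # schedules PipelineRunner.run() on the event loop
    for frame in frames:  # stand-in for decoded av audio frames
        await runner.push(frame)  # first push flips status to "push"
    await runner.flush()  # status goes "flush" -> "ended", then on_ended fires
```

When the live pipeline ends, its `on_ended` hook enqueues `task_pipeline_main_post` via Celery; note that the task body is still a stub in this patch.
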
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 48d804cc..5d10c181 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -1,8 +1,6 @@ import asyncio from enum import StrEnum -from json import dumps, loads -from pathlib import Path -from typing import Callable +from json import loads import av from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription @@ -11,25 +9,7 @@ from prometheus_client import Gauge from pydantic import BaseModel from reflector.events import subscribers_shutdown from reflector.logger import logger -from reflector.processors import ( - AudioChunkerProcessor, - AudioFileWriterProcessor, - AudioMergeProcessor, - AudioTranscriptAutoProcessor, - FinalLongSummary, - FinalShortSummary, - Pipeline, - TitleSummary, - Transcript, - TranscriptFinalLongSummaryProcessor, - TranscriptFinalShortSummaryProcessor, - TranscriptFinalTitleProcessor, - TranscriptLinerProcessor, - TranscriptTopicDetectorProcessor, - TranscriptTranslatorProcessor, -) -from reflector.processors.base import BroadcastProcessor -from reflector.processors.types import FinalTitle +from reflector.pipelines.runner import PipelineRunner sessions = [] router = APIRouter() @@ -85,121 +65,10 @@ class PipelineEvent(StrEnum): FINAL_TITLE = "FINAL_TITLE" -class PipelineOptions(BaseModel): - audio_filename: Path | None = None - source_language: str = "en" - target_language: str = "en" - - on_transcript: Callable | None = None - on_topic: Callable | None = None - on_final_title: Callable | None = None - on_final_short_summary: Callable | None = None - on_final_long_summary: Callable | None = None - - -class PipelineRunner(object): - """ - Pipeline runner designed to be executed in a asyncio task - """ - - def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None): - self.pipeline = pipeline - self.q_cmd = asyncio.Queue() - self.ev_done = asyncio.Event() - self.status = "idle" - self.status_callback = status_callback - - async def update_status(self, status): - print("update_status", status) - self.status = status - if self.status_callback: - try: - await self.status_callback(status) - except Exception as e: - logger.error("PipelineRunner status_callback error", error=e) - - async def add_cmd(self, cmd: str, data): - await self.q_cmd.put([cmd, data]) - - async def push(self, data): - await self.add_cmd("PUSH", data) - - async def flush(self): - await self.add_cmd("FLUSH", None) - - async def run(self): - try: - await self.update_status("running") - while not self.ev_done.is_set(): - cmd, data = await self.q_cmd.get() - func = getattr(self, f"cmd_{cmd.lower()}") - if func: - await func(data) - else: - raise Exception(f"Unknown command {cmd}") - except Exception as e: - await self.update_status("error") - logger.error("PipelineRunner error", error=e) - - async def cmd_push(self, data): - if self.status == "idle": - await self.update_status("recording") - await self.pipeline.push(data) - - async def cmd_flush(self, data): - await self.update_status("processing") - await self.pipeline.flush() - await self.update_status("ended") - self.ev_done.set() - - def start(self): - print("start task") - asyncio.get_event_loop().create_task(self.run()) - - -async def pipeline_live_create(options: PipelineOptions): - # create a context for the whole rtc transaction - # add a customised logger to the context - processors = [] - if options.audio_filename is not None: - processors += 
[AudioFileWriterProcessor(path=options.audio_filename)] - processors += [ - AudioChunkerProcessor(), - AudioMergeProcessor(), - AudioTranscriptAutoProcessor.as_threaded(), - TranscriptLinerProcessor(), - TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript), - TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic), - BroadcastProcessor( - processors=[ - TranscriptFinalTitleProcessor.as_threaded( - callback=options.on_final_title - ), - TranscriptFinalLongSummaryProcessor.as_threaded( - callback=options.on_final_long_summary - ), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=options.on_final_short_summary - ), - ] - ), - ] - pipeline = Pipeline(*processors) - pipeline.options = options - pipeline.set_pref("audio:source_language", options.source_language) - pipeline.set_pref("audio:target_language", options.target_language) - - return pipeline - - async def rtc_offer_base( params: RtcOffer, request: Request, - event_callback=None, - event_callback_args=None, - audio_filename: Path | None = None, - source_language: str = "en", - target_language: str = "en", + pipeline_runner: PipelineRunner, ): # build an rtc session offer = RTCSessionDescription(sdp=params.sdp, type=params.type) @@ -209,132 +78,9 @@ async def rtc_offer_base( clientid = f"{peername[0]}:{peername[1]}" ctx = TranscriptionContext(logger=logger.bind(client=clientid)) - async def update_status(status: str): - changed = ctx.status != status - if changed: - ctx.status = status - if event_callback: - await event_callback( - event=PipelineEvent.STATUS, - args=event_callback_args, - data=StrValue(value=status), - ) - - # build pipeline callback - async def on_transcript(transcript: Transcript): - ctx.logger.info("Transcript", transcript=transcript) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = { - "cmd": "SHOW_TRANSCRIPTION", - "text": transcript.text, - } - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.TRANSCRIPT, - args=event_callback_args, - data=transcript, - ) - - async def on_topic(topic: TitleSummary): - # FIXME: make it incremental with the frontend, not send everything - ctx.logger.info("Topic", topic=topic) - ctx.topics.append( - { - "title": topic.title, - "timestamp": topic.timestamp, - "transcript": topic.transcript.text, - "desc": topic.summary, - } - ) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics} - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.TOPIC, args=event_callback_args, data=topic - ) - - async def on_final_short_summary(summary: FinalShortSummary): - ctx.logger.info("FinalShortSummary", final_short_summary=summary) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = { - "cmd": "DISPLAY_FINAL_SHORT_SUMMARY", - "summary": summary.short_summary, - "duration": summary.duration, - } - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. 
websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_SHORT_SUMMARY, - args=event_callback_args, - data=summary, - ) - - async def on_final_long_summary(summary: FinalLongSummary): - ctx.logger.info("FinalLongSummary", final_summary=summary) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = { - "cmd": "DISPLAY_FINAL_LONG_SUMMARY", - "summary": summary.long_summary, - "duration": summary.duration, - } - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_LONG_SUMMARY, - args=event_callback_args, - data=summary, - ) - - async def on_final_title(title: FinalTitle): - ctx.logger.info("FinalTitle", final_title=title) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title} - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_TITLE, - args=event_callback_args, - data=title, - ) - # handle RTC peer connection pc = RTCPeerConnection() - - # create pipeline - options = PipelineOptions( - audio_filename=audio_filename, - source_language=source_language, - target_language=target_language, - on_transcript=on_transcript, - on_topic=on_topic, - on_final_short_summary=on_final_short_summary, - on_final_long_summary=on_final_long_summary, - on_final_title=on_final_title, - ) - pipeline = await pipeline_live_create(options) - ctx.pipeline_runner = PipelineRunner(pipeline, update_status) + ctx.pipeline_runner = pipeline_runner ctx.pipeline_runner.start() async def flush_pipeline_and_quit(close=True): @@ -400,8 +146,3 @@ async def rtc_clean_sessions(_): logger.debug(f"Closing session {pc}") await pc.close() sessions.clear() - - -@router.post("/offer") -async def rtc_offer(params: RtcOffer, request: Request): - return await rtc_offer_base(params, request) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 9f02eb6d..e949d645 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,8 +1,5 @@ -import json from datetime import datetime -from pathlib import Path from typing import Annotated, Optional -from uuid import uuid4 import reflector.auth as auth from fastapi import ( @@ -15,12 +12,13 @@ from fastapi import ( ) from fastapi_pagination import Page, paginate from pydantic import BaseModel, Field -from reflector.db import database, transcripts -from reflector.logger import logger +from reflector.db.transcripts import ( + AudioWaveform, + TranscriptTopic, + transcripts_controller, +) from reflector.processors.types import Transcript as ProcessorTranscript -from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings -from reflector.utils.audio_waveform import get_audio_waveform from reflector.ws_manager import get_ws_manager from starlette.concurrency import run_in_threadpool @@ -30,216 +28,6 @@ from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base router = APIRouter() ws_manager = get_ws_manager() -# ============================================================== -# Models to move to a database, but required for the API to work -# ============================================================== - - -def generate_uuid4(): - return str(uuid4()) - - -def generate_transcript_name(): - now = datetime.utcnow() - return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" - - 
-class AudioWaveform(BaseModel): - data: list[float] - - -class TranscriptText(BaseModel): - text: str - translation: str | None - - -class TranscriptSegmentTopic(BaseModel): - speaker: int - text: str - timestamp: float - - -class TranscriptTopic(BaseModel): - id: str = Field(default_factory=generate_uuid4) - title: str - summary: str - timestamp: float - text: str | None = None - words: list[ProcessorWord] = [] - - -class TranscriptFinalShortSummary(BaseModel): - short_summary: str - - -class TranscriptFinalLongSummary(BaseModel): - long_summary: str - - -class TranscriptFinalTitle(BaseModel): - title: str - - -class TranscriptEvent(BaseModel): - event: str - data: dict - - -class Transcript(BaseModel): - id: str = Field(default_factory=generate_uuid4) - user_id: str | None = None - name: str = Field(default_factory=generate_transcript_name) - status: str = "idle" - locked: bool = False - duration: float = 0 - created_at: datetime = Field(default_factory=datetime.utcnow) - title: str | None = None - short_summary: str | None = None - long_summary: str | None = None - topics: list[TranscriptTopic] = [] - events: list[TranscriptEvent] = [] - source_language: str = "en" - target_language: str = "en" - - def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: - ev = TranscriptEvent(event=event, data=data.model_dump()) - self.events.append(ev) - return ev - - def upsert_topic(self, topic: TranscriptTopic): - existing_topic = next((t for t in self.topics if t.id == topic.id), None) - if existing_topic: - existing_topic.update_from(topic) - else: - self.topics.append(topic) - - def events_dump(self, mode="json"): - return [event.model_dump(mode=mode) for event in self.events] - - def topics_dump(self, mode="json"): - return [topic.model_dump(mode=mode) for topic in self.topics] - - def convert_audio_to_waveform(self, segments_count=256): - fn = self.audio_waveform_filename - if fn.exists(): - return - waveform = get_audio_waveform( - path=self.audio_mp3_filename, segments_count=segments_count - ) - try: - with open(fn, "w") as fd: - json.dump(waveform, fd) - except Exception: - # remove file if anything happen during the write - fn.unlink(missing_ok=True) - raise - return waveform - - def unlink(self): - self.data_path.unlink(missing_ok=True) - - @property - def data_path(self): - return Path(settings.DATA_DIR) / self.id - - @property - def audio_mp3_filename(self): - return self.data_path / "audio.mp3" - - @property - def audio_waveform_filename(self): - return self.data_path / "audio.json" - - @property - def audio_waveform(self): - try: - with open(self.audio_waveform_filename) as fd: - data = json.load(fd) - except json.JSONDecodeError: - # unlink file if it's corrupted - self.audio_waveform_filename.unlink(missing_ok=True) - return None - - return AudioWaveform(data=data) - - -class TranscriptController: - async def get_all( - self, - user_id: str | None = None, - order_by: str | None = None, - filter_empty: bool | None = False, - filter_recording: bool | None = False, - ) -> list[Transcript]: - query = transcripts.select().where(transcripts.c.user_id == user_id) - - if order_by is not None: - field = getattr(transcripts.c, order_by[1:]) - if order_by.startswith("-"): - field = field.desc() - query = query.order_by(field) - - if filter_empty: - query = query.filter(transcripts.c.status != "idle") - - if filter_recording: - query = query.filter(transcripts.c.status != "recording") - - results = await database.fetch_all(query) - return results - - async def get_by_id(self, 
transcript_id: str, **kwargs) -> Transcript | None: - query = transcripts.select().where(transcripts.c.id == transcript_id) - if "user_id" in kwargs: - query = query.where(transcripts.c.user_id == kwargs["user_id"]) - result = await database.fetch_one(query) - if not result: - return None - return Transcript(**result) - - async def add( - self, - name: str, - source_language: str = "en", - target_language: str = "en", - user_id: str | None = None, - ): - transcript = Transcript( - name=name, - source_language=source_language, - target_language=target_language, - user_id=user_id, - ) - query = transcripts.insert().values(**transcript.model_dump()) - await database.execute(query) - return transcript - - async def update(self, transcript: Transcript, values: dict): - query = ( - transcripts.update() - .where(transcripts.c.id == transcript.id) - .values(**values) - ) - await database.execute(query) - for key, value in values.items(): - setattr(transcript, key, value) - - async def remove_by_id( - self, transcript_id: str, user_id: str | None = None - ) -> None: - transcript = await self.get_by_id(transcript_id, user_id=user_id) - if not transcript: - return - if user_id is not None and transcript.user_id != user_id: - return - transcript.unlink() - query = transcripts.delete().where(transcripts.c.id == transcript_id) - await database.execute(query) - - -transcripts_controller = TranscriptController() - - # ============================================================== # Transcripts list # ============================================================== @@ -537,114 +325,6 @@ async def transcript_events_websocket( # ============================================================== -async def handle_rtc_event(event: PipelineEvent, args, data): - try: - return await handle_rtc_event_once(event, args, data) - except Exception: - logger.exception("Error handling RTC event") - - -async def handle_rtc_event_once(event: PipelineEvent, args, data): - # OFC the current implementation is not good, - # but it's just a POC before persistence. It won't query the - # transcript from the database for each event. - # print(f"Event: {event}", args, data) - transcript_id = args - transcript = await transcripts_controller.get_by_id(transcript_id) - if not transcript: - return - - # event send to websocket clients may not be the same as the event - # received from the pipeline. For example, the pipeline will send - # a TRANSCRIPT event with all words, but this is not what we want - # to send to the websocket client. 
- - # FIXME don't do copy - if event == PipelineEvent.TRANSCRIPT: - resp = transcript.add_event( - event=event, - data=TranscriptText(text=data.text, translation=data.translation), - ) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - }, - ) - - elif event == PipelineEvent.TOPIC: - topic = TranscriptTopic( - title=data.title, - summary=data.summary, - timestamp=data.timestamp, - text=data.transcript.text, - words=data.transcript.words, - ) - resp = transcript.add_event(event=event, data=topic) - transcript.upsert_topic(topic) - - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "topics": transcript.topics_dump(), - }, - ) - - elif event == PipelineEvent.FINAL_TITLE: - final_title = TranscriptFinalTitle(title=data.title) - resp = transcript.add_event(event=event, data=final_title) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "title": final_title.title, - }, - ) - - elif event == PipelineEvent.FINAL_LONG_SUMMARY: - final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) - resp = transcript.add_event(event=event, data=final_long_summary) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "long_summary": final_long_summary.long_summary, - }, - ) - - elif event == PipelineEvent.FINAL_SHORT_SUMMARY: - final_short_summary = TranscriptFinalShortSummary( - short_summary=data.short_summary - ) - resp = transcript.add_event(event=event, data=final_short_summary) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "short_summary": final_short_summary.short_summary, - }, - ) - - elif event == PipelineEvent.STATUS: - resp = transcript.add_event(event=event, data=data) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "status": data.value, - }, - ) - - else: - logger.warning(f"Unknown event: {event}") - return - - # transmit to websocket clients - room_id = f"ts:{transcript_id}" - await ws_manager.send_json(room_id, resp.model_dump(mode="json")) - - @router.post("/transcripts/{transcript_id}/record/webrtc") async def transcript_record_webrtc( transcript_id: str, @@ -660,13 +340,14 @@ async def transcript_record_webrtc( if transcript.locked: raise HTTPException(status_code=400, detail="Transcript is locked") + # create a pipeline runner + from reflector.pipelines.main_live_pipeline import PipelineMainLive + + pipeline_runner = PipelineMainLive(transcript_id=transcript_id) + # FIXME do not allow multiple recording at the same time return await rtc_offer_base( params, request, - event_callback=handle_rtc_event, - event_callback_args=transcript_id, - audio_filename=transcript.audio_mp3_filename, - source_language=transcript.source_language, - target_language=transcript.target_language, + pipeline_runner=pipeline_runner, ) diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py index 1dfe9e3d..7650807b 100644 --- a/server/reflector/ws_manager.py +++ b/server/reflector/ws_manager.py @@ -14,6 +14,7 @@ import json import redis.asyncio as redis from fastapi import WebSocket +from reflector.settings import settings ws_manager = None @@ -114,7 +115,8 @@ def get_ws_manager() -> WebsocketManager: RedisConnectionError: If there is an error connecting to the Redis server. 
""" global ws_manager - from reflector.settings import settings + if ws_manager: + return ws_manager pubsub_client = RedisPubSubManager( host=settings.REDIS_HOST,