diff --git a/server/reflector/app.py b/server/reflector/app.py
index 31509ce9..f40af489 100644
--- a/server/reflector/app.py
+++ b/server/reflector/app.py
@@ -1,9 +1,22 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from reflector.views.rtc_offer import router as rtc_offer_router
+from reflector.events import subscribers_startup, subscribers_shutdown
+from contextlib import asynccontextmanager
+
+
+# lifespan events
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    for func in subscribers_startup:
+        await func()
+    yield
+    for func in subscribers_shutdown:
+        await func()
+
 
 # build app
-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
diff --git a/server/reflector/events.py b/server/reflector/events.py
new file mode 100644
index 00000000..221ab4e5
--- /dev/null
+++ b/server/reflector/events.py
@@ -0,0 +1,2 @@
+subscribers_startup = []
+subscribers_shutdown = []
diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py
index 847db231..da890513 100644
--- a/server/reflector/processors/__init__.py
+++ b/server/reflector/processors/__init__.py
@@ -1,9 +1,9 @@
 from .base import Processor, ThreadedProcessor, Pipeline  # noqa: F401
-from .types import AudioFile, Transcript, Word, TitleSummary  # noqa: F401
+from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary  # noqa: F401
 from .audio_chunker import AudioChunkerProcessor  # noqa: F401
 from .audio_merge import AudioMergeProcessor  # noqa: F401
 from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
 from .audio_transcript_auto import AudioTranscriptAutoProcessor  # noqa: F401
 from .transcript_liner import TranscriptLinerProcessor  # noqa: F401
-from .transcript_summarizer import TranscriptSummarizerProcessor  # noqa: F401
 from .transcript_topic_detector import TranscriptTopicDetectorProcessor  # noqa: F401
+from .transcript_final_summary import TranscriptFinalSummaryProcessor  # noqa: F401
diff --git a/server/reflector/processors/transcript_final_summary.py b/server/reflector/processors/transcript_final_summary.py
new file mode 100644
index 00000000..208548f5
--- /dev/null
+++ b/server/reflector/processors/transcript_final_summary.py
@@ -0,0 +1,30 @@
+from reflector.processors.base import Processor
+from reflector.processors.types import TitleSummary, FinalSummary
+
+
+class TranscriptFinalSummaryProcessor(Processor):
+    """
+    Assemble all summary into a line-based json
+    """
+
+    INPUT_TYPE = TitleSummary
+    OUTPUT_TYPE = FinalSummary
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.chunks: list[TitleSummary] = []
+
+    async def _push(self, data: TitleSummary):
+        self.chunks.append(data)
+
+    async def _flush(self):
+        if not self.chunks:
+            self.logger.warning("No summary to output")
+            return
+
+        # FIXME improve final summary
+        result = "\n".join([chunk.summary for chunk in self.chunks])
+        last_chunk = self.chunks[-1]
+        duration = last_chunk.timestamp + last_chunk.duration
+
+        await self.emit(FinalSummary(summary=result, duration=duration))
diff --git a/server/reflector/processors/transcript_summarizer.py b/server/reflector/processors/transcript_summarizer.py
deleted file mode 100644
index e4e55e9e..00000000
--- a/server/reflector/processors/transcript_summarizer.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import TitleSummary
-from pathlib import Path
-import json
-
-
-class TranscriptSummarizerProcessor(Processor):
-    """
-    Assemble all summary into a line-based json
-    """
-
-    INPUT_TYPE = TitleSummary
-    OUTPUT_TYPE = Path
-
-    def __init__(self, filename: Path, **kwargs):
-        super().__init__(**kwargs)
-        self.filename = filename
-        self.chunkcount = 0
-
-    async def _push(self, data: TitleSummary):
-        with open(self.filename, "a", encoding="utf8") as fd:
-            fd.write(json.dumps(data))
-        self.chunkcount += 1
-
-    async def _flush(self):
-        if self.chunkcount == 0:
-            self.logger.warning("No summary to write")
-            return
-        self.logger.info(f"Writing to {self.filename}")
-        await self.emit(self.filename)
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index c4c840dd..d762c708 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -66,3 +66,9 @@ class TitleSummary:
     timestamp: float
     duration: float
     transcript: Transcript
+
+
+@dataclass
+class FinalSummary:
+    summary: str
+    duration: float
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index 77007035..c4eaddd8 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -1,10 +1,10 @@
+import asyncio
 from fastapi import Request, APIRouter
+from reflector.events import subscribers_shutdown
 from pydantic import BaseModel
 from reflector.models import (
     TranscriptionContext,
     TranscriptionOutput,
-    TitleSummaryOutput,
-    IncrementalResult,
 )
 from reflector.logger import logger
 from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
@@ -17,10 +17,15 @@ from reflector.processors import (
     AudioTranscriptAutoProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
+    TranscriptFinalSummaryProcessor,
     Transcript,
     TitleSummary,
+    FinalSummary,
 )
 
+sessions = []
+router = APIRouter()
+
 
 class AudioStreamTrack(MediaStreamTrack):
     """
@@ -49,10 +54,6 @@ class RtcOffer(BaseModel):
     type: str
 
 
-sessions = []
-router = APIRouter()
-
-
 @router.post("/offer")
 async def rtc_offer(params: RtcOffer, request: Request):
     # build an rtc session
@@ -62,40 +63,38 @@ async def rtc_offer(params: RtcOffer, request: Request):
     peername = request.client
     clientid = f"{peername[0]}:{peername[1]}"
     ctx = TranscriptionContext(logger=logger.bind(client=clientid))
+    ctx.topics = []
 
     # build pipeline callback
     async def on_transcript(transcript: Transcript):
         ctx.logger.info("Transcript", transcript=transcript)
-        cmd = TranscriptionOutput(transcript.text)
-        # FIXME: send the result to the client async way
-        ctx.data_channel.send(dumps(cmd.get_result()))
-
-    async def on_summary(summary: TitleSummary):
-        ctx.logger.info("Summary", summary=summary)
-        # XXX doesnt work as expected, IncrementalResult is not serializable
-        # and previous implementation assume output of oobagooda
-        # result = TitleSummaryOutput(
-        #     [
-        #         IncrementalResult(
-        #             title=summary.title,
-        #             desc=summary.summary,
-        #             transcript=summary.transcript.text,
-        #             timestamp=summary.timestamp,
-        #         )
-        #     ]
-        # )
-
         result = {
-            "cmd": "UPDATE_TOPICS",
-            "topics": [
-                {
-                    "title": summary.title,
-                    "timestamp": summary.timestamp,
-                    "transcript": summary.transcript.text,
-                    "desc": summary.summary,
-                }
-            ],
+            "cmd": "SHOW_TRANSCRIPTION",
+            "text": transcript.text,
         }
+        ctx.data_channel.send(dumps(result))
+
+    async def on_topic(summary: TitleSummary):
+        # FIXME: make it incremental with the frontend, not send everything
+        ctx.logger.info("Summary", summary=summary)
+        ctx.topics.append(
+            {
+                "title": summary.title,
+                "timestamp": summary.timestamp,
+                "transcript": summary.transcript.text,
+                "desc": summary.summary,
+            }
+        )
+        result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
+        ctx.data_channel.send(dumps(result))
+
+    async def on_final_summary(summary: FinalSummary):
+        ctx.logger.info("FinalSummary", final_summary=summary)
+        result = {
+            "cmd": "DISPLAY_FINAL_SUMMARY",
+            "summary": summary.summary,
+            "duration": summary.duration,
+        }
         ctx.data_channel.send(dumps(result))
 
     # create a context for the whole rtc transaction
@@ -105,15 +104,19 @@ async def rtc_offer(params: RtcOffer, request: Request):
         AudioMergeProcessor(),
         AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
         TranscriptLinerProcessor(),
-        TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
-        # FinalSummaryProcessor.as_threaded(
-        #     filename=result_fn, callback=on_final_summary
-        # ),
+        TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
+        TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
     )
 
     # handle RTC peer connection
    pc = RTCPeerConnection()
 
+    async def flush_pipeline_and_quit():
+        ctx.logger.info("Flushing pipeline")
+        await ctx.pipeline.flush()
+        ctx.logger.debug("Closing peer connection")
+        await pc.close()
+
     @pc.on("datachannel")
     def on_datachannel(channel):
         ctx.data_channel = channel
@@ -124,8 +127,8 @@ async def rtc_offer(params: RtcOffer, request: Request):
         def on_message(message: str):
             ctx.logger.info(f"Message: {message}")
             if loads(message)["cmd"] == "STOP":
-                # FIXME: flush the pipeline
-                pass
+                ctx.logger.debug("STOP command received")
+                asyncio.get_event_loop().create_task(flush_pipeline_and_quit())
 
             if isinstance(message, str) and message.startswith("ping"):
                 channel.send("pong" + message[4:])
@@ -148,3 +151,12 @@ async def rtc_offer(params: RtcOffer, request: Request):
     sessions.append(pc)
 
     return RtcOffer(sdp=pc.localDescription.sdp, type=pc.localDescription.type)
+
+
+@subscribers_shutdown.append
+async def rtc_clean_sessions():
+    logger.info("Closing all RTC sessions")
+    for pc in sessions:
+        logger.debug(f"Closing session {pc}")
+        await pc.close()
+    sessions.clear()