diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 124781fc..f2a8425e 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -75,6 +75,12 @@ class Transcript(BaseModel): else: self.topics.append(topic) + def events_dump(self, mode="json"): + return [event.model_dump(mode=mode) for event in self.events] + + def topics_dump(self, mode="json"): + return [topic.model_dump(mode=mode) for topic in self.topics] + class TranscriptController: async def get_all(self) -> list[Transcript]: @@ -192,7 +198,7 @@ async def transcript_delete(transcript_id: str): @router.get("/transcripts/{transcript_id}/audio") async def transcript_get_audio(transcript_id: str): - transcript = transcripts_controller.get_by_id(transcript_id) + transcript = await transcripts_controller.get_by_id(transcript_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") @@ -202,7 +208,7 @@ async def transcript_get_audio(transcript_id: str): @router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) async def transcript_get_topics(transcript_id: str): - transcript = transcripts_controller.get_by_id(transcript_id) + transcript = await transcripts_controller.get_by_id(transcript_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") return transcript.topics @@ -250,7 +256,7 @@ ws_manager = WebsocketManager() @router.websocket("/transcripts/{transcript_id}/events") async def transcript_events_websocket(transcript_id: str, websocket: WebSocket): - transcript = transcripts_controller.get_by_id(transcript_id) + transcript = await transcripts_controller.get_by_id(transcript_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") @@ -282,7 +288,7 @@ async def handle_rtc_event(event: PipelineEvent, args, data): # transcript from the database for each event. # print(f"Event: {event}", args, data) transcript_id = args - transcript = transcripts_controller.get_by_id(transcript_id) + transcript = await transcripts_controller.get_by_id(transcript_id) if not transcript: return @@ -294,6 +300,12 @@ async def handle_rtc_event(event: PipelineEvent, args, data): # FIXME don't do copy if event == PipelineEvent.TRANSCRIPT: resp = transcript.add_event(event=event, data=TranscriptText(text=data.text)) + await transcripts_controller.update( + transcript, + { + "events": transcript.events_dump(), + }, + ) elif event == PipelineEvent.TOPIC: topic = TranscriptTopic( @@ -308,8 +320,8 @@ async def handle_rtc_event(event: PipelineEvent, args, data): await transcripts_controller.update( transcript, { - "events": transcript.events, - "topics": transcript.topics, + "events": transcript.events_dump(), + "topics": transcript.topics_dump(), }, ) @@ -319,14 +331,20 @@ async def handle_rtc_event(event: PipelineEvent, args, data): await transcripts_controller.update( transcript, { - "events": transcript.events, - "summary": transcript.summary, + "events": transcript.events_dump(), + "summary": final_summary.summary, }, ) elif event == PipelineEvent.STATUS: resp = transcript.add_event(event=event, data=data) - await transcripts_controller.update(transcript, {"status": transcript.status}) + await transcripts_controller.update( + transcript, + { + "events": transcript.events_dump(), + "status": data.value, + }, + ) else: logger.warning(f"Unknown event: {event}") @@ -340,7 +358,7 @@ async def handle_rtc_event(event: PipelineEvent, args, data): async def transcript_record_webrtc( transcript_id: str, params: RtcOffer, request: Request ): - transcript = transcripts_controller.get_by_id(transcript_id) + transcript = await transcripts_controller.get_by_id(transcript_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found")