server: fixes for tests

This commit is contained in:
2023-08-15 17:09:36 +02:00
committed by Mathieu Virbel
parent 857505124f
commit 044f40eb32

View File

@@ -75,6 +75,12 @@ class Transcript(BaseModel):
else: else:
self.topics.append(topic) self.topics.append(topic)
def events_dump(self, mode="json"):
return [event.model_dump(mode=mode) for event in self.events]
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
class TranscriptController: class TranscriptController:
async def get_all(self) -> list[Transcript]: async def get_all(self) -> list[Transcript]:
@@ -192,7 +198,7 @@ async def transcript_delete(transcript_id: str):
@router.get("/transcripts/{transcript_id}/audio") @router.get("/transcripts/{transcript_id}/audio")
async def transcript_get_audio(transcript_id: str): async def transcript_get_audio(transcript_id: str):
transcript = transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
@@ -202,7 +208,7 @@ async def transcript_get_audio(transcript_id: str):
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) @router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
async def transcript_get_topics(transcript_id: str): async def transcript_get_topics(transcript_id: str):
transcript = transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
return transcript.topics return transcript.topics
@@ -250,7 +256,7 @@ ws_manager = WebsocketManager()
@router.websocket("/transcripts/{transcript_id}/events") @router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(transcript_id: str, websocket: WebSocket): async def transcript_events_websocket(transcript_id: str, websocket: WebSocket):
transcript = transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
@@ -282,7 +288,7 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
# transcript from the database for each event. # transcript from the database for each event.
# print(f"Event: {event}", args, data) # print(f"Event: {event}", args, data)
transcript_id = args transcript_id = args
transcript = transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript: if not transcript:
return return
@@ -294,6 +300,12 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
# FIXME don't do copy # FIXME don't do copy
if event == PipelineEvent.TRANSCRIPT: if event == PipelineEvent.TRANSCRIPT:
resp = transcript.add_event(event=event, data=TranscriptText(text=data.text)) resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
},
)
elif event == PipelineEvent.TOPIC: elif event == PipelineEvent.TOPIC:
topic = TranscriptTopic( topic = TranscriptTopic(
@@ -308,8 +320,8 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
await transcripts_controller.update( await transcripts_controller.update(
transcript, transcript,
{ {
"events": transcript.events, "events": transcript.events_dump(),
"topics": transcript.topics, "topics": transcript.topics_dump(),
}, },
) )
@@ -319,14 +331,20 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
await transcripts_controller.update( await transcripts_controller.update(
transcript, transcript,
{ {
"events": transcript.events, "events": transcript.events_dump(),
"summary": transcript.summary, "summary": final_summary.summary,
}, },
) )
elif event == PipelineEvent.STATUS: elif event == PipelineEvent.STATUS:
resp = transcript.add_event(event=event, data=data) resp = transcript.add_event(event=event, data=data)
await transcripts_controller.update(transcript, {"status": transcript.status}) await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"status": data.value,
},
)
else: else:
logger.warning(f"Unknown event: {event}") logger.warning(f"Unknown event: {event}")
@@ -340,7 +358,7 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
async def transcript_record_webrtc( async def transcript_record_webrtc(
transcript_id: str, params: RtcOffer, request: Request transcript_id: str, params: RtcOffer, request: Request
): ):
transcript = transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")