mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: fixes for tests
This commit is contained in:
@@ -75,6 +75,12 @@ class Transcript(BaseModel):
|
||||
else:
|
||||
self.topics.append(topic)
|
||||
|
||||
def events_dump(self, mode="json"):
|
||||
return [event.model_dump(mode=mode) for event in self.events]
|
||||
|
||||
def topics_dump(self, mode="json"):
|
||||
return [topic.model_dump(mode=mode) for topic in self.topics]
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
async def get_all(self) -> list[Transcript]:
|
||||
@@ -192,7 +198,7 @@ async def transcript_delete(transcript_id: str):
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio")
|
||||
async def transcript_get_audio(transcript_id: str):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
@@ -202,7 +208,7 @@ async def transcript_get_audio(transcript_id: str):
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
||||
async def transcript_get_topics(transcript_id: str):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
return transcript.topics
|
||||
@@ -250,7 +256,7 @@ ws_manager = WebsocketManager()
|
||||
|
||||
@router.websocket("/transcripts/{transcript_id}/events")
|
||||
async def transcript_events_websocket(transcript_id: str, websocket: WebSocket):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
@@ -282,7 +288,7 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
# transcript from the database for each event.
|
||||
# print(f"Event: {event}", args, data)
|
||||
transcript_id = args
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
@@ -294,6 +300,12 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
# FIXME don't do copy
|
||||
if event == PipelineEvent.TRANSCRIPT:
|
||||
resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.TOPIC:
|
||||
topic = TranscriptTopic(
|
||||
@@ -308,8 +320,8 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events,
|
||||
"topics": transcript.topics,
|
||||
"events": transcript.events_dump(),
|
||||
"topics": transcript.topics_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -319,14 +331,20 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events,
|
||||
"summary": transcript.summary,
|
||||
"events": transcript.events_dump(),
|
||||
"summary": final_summary.summary,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.STATUS:
|
||||
resp = transcript.add_event(event=event, data=data)
|
||||
await transcripts_controller.update(transcript, {"status": transcript.status})
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"status": data.value,
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown event: {event}")
|
||||
@@ -340,7 +358,7 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
async def transcript_record_webrtc(
|
||||
transcript_id: str, params: RtcOffer, request: Request
|
||||
):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user