api: implement first server API + tests

This commit is contained in:
Mathieu Virbel
2023-08-04 18:16:37 +02:00
parent 20767fde3f
commit 96f52c631a
8 changed files with 327 additions and 23 deletions

View File

@@ -1,9 +1,8 @@
from dataclasses import dataclass
from pydantic import BaseModel
from pathlib import Path
@dataclass
class AudioFile:
class AudioFile(BaseModel):
path: Path
sample_rate: int
channels: int
@@ -14,15 +13,13 @@ class AudioFile:
self.path.unlink()
@dataclass
class Word:
class Word(BaseModel):
text: str
start: float
end: float
@dataclass
class Transcript:
class Transcript(BaseModel):
text: str = ""
words: list[Word] = None
@@ -59,8 +56,7 @@ class Transcript:
return Transcript(text=self.text, words=words)
@dataclass
class TitleSummary:
class TitleSummary(BaseModel):
title: str
summary: str
timestamp: float
@@ -75,7 +71,6 @@ class TitleSummary:
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
@dataclass
class FinalSummary:
class FinalSummary(BaseModel):
summary: str
duration: float

View File

@@ -3,7 +3,6 @@ import time
import uuid
import httpx
import pyaudio
import stamina
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaPlayer, MediaRelay
@@ -24,7 +23,6 @@ class StreamClient:
self.server_url = url
self.play_from = play_from
self.ping_pong = ping_pong
self.paudio = pyaudio.PyAudio()
self.pc = RTCPeerConnection()
@@ -87,6 +85,7 @@ class StreamClient:
self.logger.info(f"Track {track.kind} ended")
self.pc.addTrack(audio)
self.track_audio = audio
channel = pc.createDataChannel("data-channel")
self.logger = self.logger.bind(channel=channel.label)
@@ -142,3 +141,6 @@ class StreamClient:
coro = self.run_offer(self.pc, self.signaling)
task = asyncio.create_task(coro)
await task
def is_ended(self):
    """Tell whether the client's audio track is finished.

    Returns True when no audio track is attached (``track_audio`` is
    None) or when the attached track's ``readyState`` is ``"ended"``;
    returns False otherwise.
    """
    audio = self.track_audio
    # No track attached counts as ended.
    if audio is None:
        return True
    return audio.readyState == "ended"

View File

@@ -3,7 +3,8 @@ from pydantic import BaseModel, Field
from uuid import UUID, uuid4
from datetime import datetime
from fastapi_pagination import Page, paginate
from .rtc_offer import rtc_offer, RtcOffer, PipelineEvent
from reflector.logger import logger
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
import asyncio
from typing import Optional
@@ -46,6 +47,7 @@ class Transcript(BaseModel):
def add_event(self, event: str, data):
    """Append an event to ``self.events`` and return it as a plain dict.

    The event is stored as a ``TranscriptEvent``; the returned value is a
    dict with the keys ``"event"`` and ``"data"`` holding the arguments
    as given.
    """
    record = TranscriptEvent(event=event, data=data)
    self.events.append(record)
    payload = {"event": event, "data": data}
    return payload
def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
@@ -239,14 +241,37 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
# Of course, the current implementation is not good, but it is only a
# proof of concept until persistence exists. It won't query the
# transcript from the database for each event.
print(f"Event: {event}", args, data)
# print(f"Event: {event}", args, data)
transcript_id = args
transcript = transcripts_controller.get_by_id(transcript_id)
if not transcript:
return
transcript.add_event(event=event, data=data)
if event == PipelineEvent.TOPIC:
transcript.upsert_topic(TranscriptTopic(**data))
# The event sent to websocket clients may not be the same as the event
# received from the pipeline. For example, the pipeline sends a
# TRANSCRIPT event with all the words, but that is not what we want
# to send to the websocket client.
# FIXME don't do copy
if event == PipelineEvent.TRANSCRIPT:
resp = transcript.add_event(event=event, data={
"text": data.text,
})
elif event == PipelineEvent.TOPIC:
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
transcript=data.transcript,
timestamp=data.timestamp,
)
resp = transcript.add_event(event=event, data=topic.model_dump())
transcript.upsert_topic(topic)
else:
logger.warning(f"Unknown event: {event}")
return
# transmit to websocket clients
await ws_manager.send_json(transcript_id, resp)
@router.post("/transcripts/{transcript_id}/record/webrtc")
@@ -261,9 +286,9 @@ async def transcript_record_webrtc(
raise HTTPException(status_code=400, detail="Transcript is locked")
# FIXME do not allow multiple recording at the same time
return await rtc_offer(
return await rtc_offer_base(
params,
request,
event_callback=transcript.handle_event,
event_callback=handle_rtc_event,
event_callback_args=transcript_id,
)