mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
api: implement first server API + tests
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from pydantic import BaseModel
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioFile:
|
||||
class AudioFile(BaseModel):
|
||||
path: Path
|
||||
sample_rate: int
|
||||
channels: int
|
||||
@@ -14,15 +13,13 @@ class AudioFile:
|
||||
self.path.unlink()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Word:
|
||||
class Word(BaseModel):
|
||||
text: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transcript:
|
||||
class Transcript(BaseModel):
|
||||
text: str = ""
|
||||
words: list[Word] = None
|
||||
|
||||
@@ -59,8 +56,7 @@ class Transcript:
|
||||
return Transcript(text=self.text, words=words)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TitleSummary:
|
||||
class TitleSummary(BaseModel):
|
||||
title: str
|
||||
summary: str
|
||||
timestamp: float
|
||||
@@ -75,7 +71,6 @@ class TitleSummary:
|
||||
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinalSummary:
|
||||
class FinalSummary(BaseModel):
|
||||
summary: str
|
||||
duration: float
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
import pyaudio
|
||||
import stamina
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
from aiortc.contrib.media import MediaPlayer, MediaRelay
|
||||
@@ -24,7 +23,6 @@ class StreamClient:
|
||||
self.server_url = url
|
||||
self.play_from = play_from
|
||||
self.ping_pong = ping_pong
|
||||
self.paudio = pyaudio.PyAudio()
|
||||
|
||||
self.pc = RTCPeerConnection()
|
||||
|
||||
@@ -87,6 +85,7 @@ class StreamClient:
|
||||
self.logger.info(f"Track {track.kind} ended")
|
||||
|
||||
self.pc.addTrack(audio)
|
||||
self.track_audio = audio
|
||||
|
||||
channel = pc.createDataChannel("data-channel")
|
||||
self.logger = self.logger.bind(channel=channel.label)
|
||||
@@ -142,3 +141,6 @@ class StreamClient:
|
||||
coro = self.run_offer(self.pc, self.signaling)
|
||||
task = asyncio.create_task(coro)
|
||||
await task
|
||||
|
||||
def is_ended(self):
|
||||
return self.track_audio is None or self.track_audio.readyState == "ended"
|
||||
|
||||
@@ -3,7 +3,8 @@ from pydantic import BaseModel, Field
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
from fastapi_pagination import Page, paginate
|
||||
from .rtc_offer import rtc_offer, RtcOffer, PipelineEvent
|
||||
from reflector.logger import logger
|
||||
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
@@ -46,6 +47,7 @@ class Transcript(BaseModel):
|
||||
|
||||
def add_event(self, event: str, data):
|
||||
self.events.append(TranscriptEvent(event=event, data=data))
|
||||
return {"event": event, "data": data}
|
||||
|
||||
def upsert_topic(self, topic: TranscriptTopic):
|
||||
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
|
||||
@@ -239,14 +241,37 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
# OFC the current implementation is not good,
|
||||
# but it's just a POC before persistence. It won't query the
|
||||
# transcript from the database for each event.
|
||||
print(f"Event: {event}", args, data)
|
||||
# print(f"Event: {event}", args, data)
|
||||
transcript_id = args
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
transcript.add_event(event=event, data=data)
|
||||
if event == PipelineEvent.TOPIC:
|
||||
transcript.upsert_topic(TranscriptTopic(**data))
|
||||
|
||||
# event send to websocket clients may not be the same as the event
|
||||
# received from the pipeline. For example, the pipeline will send
|
||||
# a TRANSCRIPT event with all words, but this is not what we want
|
||||
# to send to the websocket client.
|
||||
|
||||
# FIXME don't do copy
|
||||
if event == PipelineEvent.TRANSCRIPT:
|
||||
resp = transcript.add_event(event=event, data={
|
||||
"text": data.text,
|
||||
})
|
||||
elif event == PipelineEvent.TOPIC:
|
||||
topic = TranscriptTopic(
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
transcript=data.transcript,
|
||||
timestamp=data.timestamp,
|
||||
)
|
||||
resp = transcript.add_event(event=event, data=topic.model_dump())
|
||||
transcript.upsert_topic(topic)
|
||||
else:
|
||||
logger.warning(f"Unknown event: {event}")
|
||||
return
|
||||
|
||||
# transmit to websocket clients
|
||||
await ws_manager.send_json(transcript_id, resp)
|
||||
|
||||
|
||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||
@@ -261,9 +286,9 @@ async def transcript_record_webrtc(
|
||||
raise HTTPException(status_code=400, detail="Transcript is locked")
|
||||
|
||||
# FIXME do not allow multiple recording at the same time
|
||||
return await rtc_offer(
|
||||
return await rtc_offer_base(
|
||||
params,
|
||||
request,
|
||||
event_callback=transcript.handle_event,
|
||||
event_callback=handle_rtc_event,
|
||||
event_callback_args=transcript_id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user