server: start implementing new api

This commit is contained in:
Mathieu Virbel
2023-08-04 16:36:25 +02:00
parent e55cfce930
commit 20767fde3f
6 changed files with 424 additions and 5 deletions

34
server/poetry.lock generated
View File

@@ -802,6 +802,38 @@ typing-extensions = ">=4.5.0"
[package.extras] [package.extras]
all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
[[package]]
name = "fastapi-pagination"
version = "0.12.6"
description = "FastAPI pagination"
optional = false
python-versions = ">=3.8,<4.0"
files = [
{file = "fastapi_pagination-0.12.6-py3-none-any.whl", hash = "sha256:d41719f2587be6c2551dd6f86554bb0cac7acd86349111644b2335dc38799e59"},
{file = "fastapi_pagination-0.12.6.tar.gz", hash = "sha256:1aae4115dd08cd20fe1436ddb30c7f40cb27c93ae6e7f574de3f8f4a43cadf80"},
]
[package.dependencies]
fastapi = ">=0.93.0"
pydantic = ">=1.9.1"
[package.extras]
all = ["SQLAlchemy (>=1.3.20)", "asyncpg (>=0.24.0)", "beanie (>=1.11.9,<2.0.0)", "bunnet (>=1.1.0,<2.0.0)", "databases (>=0.6.0)", "django (<5.0.0)", "mongoengine (>=0.23.1,<0.28.0)", "motor (>=2.5.1,<4.0.0)", "orm (>=0.3.1)", "ormar (>=0.11.2)", "piccolo (>=0.89,<0.119)", "pony (>=0.7.16,<0.8.0)", "scylla-driver (>=3.25.6,<4.0.0)", "sqlakeyset (>=2.0.1680321678,<3.0.0)", "sqlmodel (>=0.0.8,<0.0.9)", "tortoise-orm (>=0.16.18,<0.20.0)"]
asyncpg = ["SQLAlchemy (>=1.3.20)", "asyncpg (>=0.24.0)"]
beanie = ["beanie (>=1.11.9,<2.0.0)"]
bunnet = ["bunnet (>=1.1.0,<2.0.0)"]
databases = ["databases (>=0.6.0)"]
django = ["databases (>=0.6.0)", "django (<5.0.0)"]
mongoengine = ["mongoengine (>=0.23.1,<0.28.0)"]
motor = ["motor (>=2.5.1,<4.0.0)"]
orm = ["databases (>=0.6.0)", "orm (>=0.3.1)"]
ormar = ["ormar (>=0.11.2)"]
piccolo = ["piccolo (>=0.89,<0.119)"]
scylla-driver = ["scylla-driver (>=3.25.6,<4.0.0)"]
sqlalchemy = ["SQLAlchemy (>=1.3.20)", "sqlakeyset (>=2.0.1680321678,<3.0.0)"]
sqlmodel = ["sqlakeyset (>=2.0.1680321678,<3.0.0)", "sqlmodel (>=0.0.8,<0.0.9)"]
tortoise = ["tortoise-orm (>=0.16.18,<0.20.0)"]
[[package]] [[package]]
name = "faster-whisper" name = "faster-whisper"
version = "0.7.1" version = "0.7.1"
@@ -2595,4 +2627,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "1a98a080ce035b381521426c9d6f9f80e8656258beab6cdff95ea90cf6c77e85" content-hash = "a51d7d26b88683875685ede2298f0f02ab42b1f303657b47e0a5dee9be0dc9e6"

View File

@@ -22,6 +22,7 @@ uvicorn = {extras = ["standard"], version = "^0.23.1"}
fastapi = "^0.100.1" fastapi = "^0.100.1"
sentry-sdk = {extras = ["fastapi"], version = "^1.29.2"} sentry-sdk = {extras = ["fastapi"], version = "^1.29.2"}
httpx = "^0.24.1" httpx = "^0.24.1"
fastapi-pagination = "^0.12.6"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]

View File

@@ -1,6 +1,8 @@
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi_pagination import add_pagination
from reflector.views.rtc_offer import router as rtc_offer_router from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router
from reflector.events import subscribers_startup, subscribers_shutdown from reflector.events import subscribers_startup, subscribers_shutdown
from reflector.logger import logger from reflector.logger import logger
from reflector.settings import settings from reflector.settings import settings
@@ -44,6 +46,8 @@ app.add_middleware(
# register views # register views
app.include_router(rtc_offer_router) app.include_router(rtc_offer_router)
app.include_router(transcripts_router, prefix="/v1")
add_pagination(app)
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn

View File

@@ -6,6 +6,7 @@ from reflector.models import TranscriptionContext
from reflector.logger import logger from reflector.logger import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
from json import loads, dumps from json import loads, dumps
from enum import StrEnum
import av import av
from reflector.processors import ( from reflector.processors import (
Pipeline, Pipeline,
@@ -51,8 +52,15 @@ class RtcOffer(BaseModel):
type: str type: str
@router.post("/offer") class PipelineEvent(StrEnum):
async def rtc_offer(params: RtcOffer, request: Request): TRANSCRIPT = "TRANSCRIPT"
TOPIC = "TOPIC"
FINAL_SUMMARY = "FINAL_SUMMARY"
async def rtc_offer_base(
params: RtcOffer, request: Request, event_callback=None, event_callback_args=None
):
# build an rtc session # build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type) offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -71,7 +79,16 @@ async def rtc_offer(params: RtcOffer, request: Request):
} }
ctx.data_channel.send(dumps(result)) ctx.data_channel.send(dumps(result))
async def on_topic(summary: TitleSummary): if event_callback:
await event_callback(
event=PipelineEvent.TRANSCRIPT,
args=event_callback_args,
data=transcript,
)
async def on_topic(
summary: TitleSummary, event_callback=None, event_callback_args=None
):
# FIXME: make it incremental with the frontend, not send everything # FIXME: make it incremental with the frontend, not send everything
ctx.logger.info("Summary", summary=summary) ctx.logger.info("Summary", summary=summary)
ctx.topics.append( ctx.topics.append(
@@ -85,7 +102,14 @@ async def rtc_offer(params: RtcOffer, request: Request):
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics} result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result)) ctx.data_channel.send(dumps(result))
async def on_final_summary(summary: FinalSummary): if event_callback:
await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
)
async def on_final_summary(
summary: FinalSummary, event_callback=None, event_callback_args=None
):
ctx.logger.info("FinalSummary", final_summary=summary) ctx.logger.info("FinalSummary", final_summary=summary)
result = { result = {
"cmd": "DISPLAY_FINAL_SUMMARY", "cmd": "DISPLAY_FINAL_SUMMARY",
@@ -94,6 +118,11 @@ async def rtc_offer(params: RtcOffer, request: Request):
} }
ctx.data_channel.send(dumps(result)) ctx.data_channel.send(dumps(result))
if event_callback:
await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
)
# create a context for the whole rtc transaction # create a context for the whole rtc transaction
# add a customised logger to the context # add a customised logger to the context
ctx.pipeline = Pipeline( ctx.pipeline = Pipeline(
@@ -157,3 +186,8 @@ async def rtc_clean_sessions():
logger.debug(f"Closing session {pc}") logger.debug(f"Closing session {pc}")
await pc.close() await pc.close()
sessions.clear() sessions.clear()
@router.post("/offer")
async def rtc_offer(params: RtcOffer, request: Request):
return await rtc_offer_base(params, request)

View File

@@ -0,0 +1,269 @@
import asyncio
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4

from fastapi import APIRouter, HTTPException, Request, WebSocket, WebSocketDisconnect
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field

from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer, rtc_offer_base
router = APIRouter()
# ==============================================================
# Models to move to a database, but required for the API to work
# ==============================================================
def generate_transcript_name():
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class TranscriptTopic(BaseModel):
id: UUID = Field(default_factory=uuid4)
title: str
summary: str
transcript: str
timestamp: float
class TranscriptEvent(BaseModel):
event: str
data: dict
class Transcript(BaseModel):
id: UUID = Field(default_factory=uuid4)
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=datetime.utcnow)
summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
def add_event(self, event: str, data):
self.events.append(TranscriptEvent(event=event, data=data))
def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
if existing_topic:
existing_topic.update_from(topic)
else:
self.topics.append(topic)
class TranscriptController:
transcripts: list[Transcript] = []
def get_all(self) -> list[Transcript]:
return self.transcripts
def get_by_id(self, transcript_id: UUID) -> Transcript | None:
return next((t for t in self.transcripts if t.id == transcript_id), None)
def add(self, transcript: Transcript):
self.transcripts.append(transcript)
def remove(self, transcript: Transcript):
self.transcripts.remove(transcript)
transcripts_controller = TranscriptController()
# ==============================================================
# Transcripts list
# ==============================================================
class GetTranscript(BaseModel):
id: UUID
name: str
status: str
locked: bool
duration: int
created_at: datetime
class CreateTranscript(BaseModel):
name: str
class UpdateTranscript(BaseModel):
name: Optional[str] = Field(None)
locked: Optional[bool] = Field(None)
class TranscriptEntryCreate(BaseModel):
name: str
class DeletionStatus(BaseModel):
status: str
@router.get("/transcripts", response_model=Page[GetTranscript])
async def transcripts_list():
return paginate(transcripts_controller.get_all())
@router.post("/transcripts", response_model=GetTranscript)
async def transcripts_create(info: CreateTranscript):
transcript = Transcript()
transcript.name = info.name
transcripts_controller.add(transcript)
return transcript
# ==============================================================
# Single transcript
# ==============================================================
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(transcript_id: UUID):
transcript = transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_update(transcript_id: UUID, info: UpdateTranscript):
transcript = transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if info.name is not None:
transcript.name = info.name
if info.locked is not None:
transcript.locked = info.locked
return transcript
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
async def transcript_delete(transcript_id: UUID):
transcript = transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
transcripts_controller.remove(transcript)
return DeletionStatus(status="ok")
@router.get("/transcripts/{transcript_id}/audio")
async def transcript_get_audio(transcript_id: UUID):
transcript = transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# TODO: Implement audio generation
return HTTPException(status_code=500, detail="Not implemented")
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
async def transcript_get_topics(transcript_id: UUID):
transcript = transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript.topics
@router.get("/transcripts/{transcript_id}/events")
async def transcript_get_websocket_events(transcript_id: UUID):
pass
# ==============================================================
# Websocket Manager
# ==============================================================
class WebsocketManager:
def __init__(self):
self.active_connections = {}
async def connect(self, transcript_id: UUID, websocket: WebSocket):
await websocket.accept()
if transcript_id not in self.active_connections:
self.active_connections[transcript_id] = []
self.active_connections[transcript_id].append(websocket)
def disconnect(self, transcript_id: UUID, websocket: WebSocket):
if transcript_id not in self.active_connections:
return
self.active_connections[transcript_id].remove(websocket)
if not self.active_connections[transcript_id]:
del self.active_connections[transcript_id]
async def send_json(self, transcript_id: UUID, message):
if transcript_id not in self.active_connections:
return
for connection in self.active_connections[transcript_id]:
await connection.send_json(message)
ws_manager = WebsocketManager()
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket):
transcript = transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
await ws_manager.connect(transcript_id, websocket)
# on first connection, send all events
for event in transcript.events:
await websocket.send_json(event.model_dump())
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
try:
while True:
await asyncio.sleep(42)
except WebSocketDisconnect:
ws_manager.disconnect(transcript_id, websocket)
# ==============================================================
# Web RTC
# ==============================================================
async def handle_rtc_event(event: PipelineEvent, args, data):
# OFC the current implementation is not good,
# but it's just a POC before persistence. It won't query the
# transcript from the database for each event.
print(f"Event: {event}", args, data)
transcript_id = args
transcript = transcripts_controller.get_by_id(transcript_id)
if not transcript:
return
transcript.add_event(event=event, data=data)
if event == PipelineEvent.TOPIC:
transcript.upsert_topic(TranscriptTopic(**data))
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: UUID, params: RtcOffer, request: Request
):
transcript = transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# FIXME do not allow multiple recording at the same time
return await rtc_offer(
params,
request,
event_callback=transcript.handle_event,
event_callback_args=transcript_id,
)

View File

@@ -0,0 +1,79 @@
import pytest
from httpx import AsyncClient
from reflector.app import app
@pytest.mark.asyncio
async def test_transcript_create():
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["status"] == "idle"
assert response.json()["locked"] is False
assert response.json()["id"] is not None
assert response.json()["created_at"] is not None
# ensure some fields are not returned
assert "topics" not in response.json()
assert "events" not in response.json()
@pytest.mark.asyncio
async def test_transcript_get_update_name():
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test"
response = await ac.patch(f"/transcripts/{tid}", json={"name": "test2"})
assert response.status_code == 200
assert response.json()["name"] == "test2"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test2"
@pytest.mark.asyncio
async def test_transcripts_list():
# XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testxx1"})
assert response.status_code == 200
assert response.json()["name"] == "testxx1"
response = await ac.post("/transcripts", json={"name": "testxx2"})
assert response.status_code == 200
assert response.json()["name"] == "testxx2"
response = await ac.get("/transcripts")
assert response.status_code == 200
assert len(response.json()["items"]) >= 2
names = [t["name"] for t in response.json()["items"]]
assert "testxx1" in names
assert "testxx2" in names
@pytest.mark.asyncio
async def test_transcript_delete():
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testdel1"})
assert response.status_code == 200
assert response.json()["name"] == "testdel1"
tid = response.json()["id"]
response = await ac.delete(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["status"] == "ok"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 404