diff --git a/server/poetry.lock b/server/poetry.lock index 71206cba..b49d8df5 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -802,6 +802,38 @@ typing-extensions = ">=4.5.0" [package.extras] all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +[[package]] +name = "fastapi-pagination" +version = "0.12.6" +description = "FastAPI pagination" +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "fastapi_pagination-0.12.6-py3-none-any.whl", hash = "sha256:d41719f2587be6c2551dd6f86554bb0cac7acd86349111644b2335dc38799e59"}, + {file = "fastapi_pagination-0.12.6.tar.gz", hash = "sha256:1aae4115dd08cd20fe1436ddb30c7f40cb27c93ae6e7f574de3f8f4a43cadf80"}, +] + +[package.dependencies] +fastapi = ">=0.93.0" +pydantic = ">=1.9.1" + +[package.extras] +all = ["SQLAlchemy (>=1.3.20)", "asyncpg (>=0.24.0)", "beanie (>=1.11.9,<2.0.0)", "bunnet (>=1.1.0,<2.0.0)", "databases (>=0.6.0)", "django (<5.0.0)", "mongoengine (>=0.23.1,<0.28.0)", "motor (>=2.5.1,<4.0.0)", "orm (>=0.3.1)", "ormar (>=0.11.2)", "piccolo (>=0.89,<0.119)", "pony (>=0.7.16,<0.8.0)", "scylla-driver (>=3.25.6,<4.0.0)", "sqlakeyset (>=2.0.1680321678,<3.0.0)", "sqlmodel (>=0.0.8,<0.0.9)", "tortoise-orm (>=0.16.18,<0.20.0)"] +asyncpg = ["SQLAlchemy (>=1.3.20)", "asyncpg (>=0.24.0)"] +beanie = ["beanie (>=1.11.9,<2.0.0)"] +bunnet = ["bunnet (>=1.1.0,<2.0.0)"] +databases = ["databases (>=0.6.0)"] +django = ["databases (>=0.6.0)", "django (<5.0.0)"] +mongoengine = ["mongoengine (>=0.23.1,<0.28.0)"] +motor = ["motor (>=2.5.1,<4.0.0)"] +orm = ["databases (>=0.6.0)", "orm (>=0.3.1)"] +ormar = ["ormar (>=0.11.2)"] +piccolo = ["piccolo (>=0.89,<0.119)"] +scylla-driver = ["scylla-driver (>=3.25.6,<4.0.0)"] +sqlalchemy = 
["SQLAlchemy (>=1.3.20)", "sqlakeyset (>=2.0.1680321678,<3.0.0)"] +sqlmodel = ["sqlakeyset (>=2.0.1680321678,<3.0.0)", "sqlmodel (>=0.0.8,<0.0.9)"] +tortoise = ["tortoise-orm (>=0.16.18,<0.20.0)"] + [[package]] name = "faster-whisper" version = "0.7.1" @@ -2595,4 +2627,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "1a98a080ce035b381521426c9d6f9f80e8656258beab6cdff95ea90cf6c77e85" +content-hash = "a51d7d26b88683875685ede2298f0f02ab42b1f303657b47e0a5dee9be0dc9e6" diff --git a/server/pyproject.toml b/server/pyproject.toml index dc446796..b4eb307a 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -22,6 +22,7 @@ uvicorn = {extras = ["standard"], version = "^0.23.1"} fastapi = "^0.100.1" sentry-sdk = {extras = ["fastapi"], version = "^1.29.2"} httpx = "^0.24.1" +fastapi-pagination = "^0.12.6" [tool.poetry.group.dev.dependencies] diff --git a/server/reflector/app.py b/server/reflector/app.py index 36e86d9a..f2988498 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -1,6 +1,8 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from fastapi_pagination import add_pagination from reflector.views.rtc_offer import router as rtc_offer_router +from reflector.views.transcripts import router as transcripts_router from reflector.events import subscribers_startup, subscribers_shutdown from reflector.logger import logger from reflector.settings import settings @@ -44,6 +46,8 @@ app.add_middleware( # register views app.include_router(rtc_offer_router) +app.include_router(transcripts_router, prefix="/v1") +add_pagination(app) if __name__ == "__main__": import uvicorn diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 11c98009..288153e6 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -6,6 +6,7 @@ from reflector.models import TranscriptionContext from reflector.logger import logger 
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack from json import loads, dumps +from enum import StrEnum import av from reflector.processors import ( Pipeline, @@ -51,8 +52,15 @@ class RtcOffer(BaseModel): type: str -@router.post("/offer") -async def rtc_offer(params: RtcOffer, request: Request): +class PipelineEvent(StrEnum): + TRANSCRIPT = "TRANSCRIPT" + TOPIC = "TOPIC" + FINAL_SUMMARY = "FINAL_SUMMARY" + + +async def rtc_offer_base( + params: RtcOffer, request: Request, event_callback=None, event_callback_args=None +): # build an rtc session offer = RTCSessionDescription(sdp=params.sdp, type=params.type) @@ -71,7 +79,16 @@ async def rtc_offer(params: RtcOffer, request: Request): } ctx.data_channel.send(dumps(result)) - async def on_topic(summary: TitleSummary): + if event_callback: + await event_callback( + event=PipelineEvent.TRANSCRIPT, + args=event_callback_args, + data=transcript, + ) + + async def on_topic( + summary: TitleSummary, event_callback=event_callback, event_callback_args=event_callback_args + ): # FIXME: make it incremental with the frontend, not send everything ctx.logger.info("Summary", summary=summary) ctx.topics.append( @@ -85,7 +102,14 @@ async def rtc_offer(params: RtcOffer, request: Request): result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics} ctx.data_channel.send(dumps(result)) - async def on_final_summary(summary: FinalSummary): + if event_callback: + await event_callback( + event=PipelineEvent.TOPIC, args=event_callback_args, data=summary + ) + + async def on_final_summary( + summary: FinalSummary, event_callback=event_callback, event_callback_args=event_callback_args + ): ctx.logger.info("FinalSummary", final_summary=summary) result = { "cmd": "DISPLAY_FINAL_SUMMARY", @@ -94,6 +118,11 @@ async def rtc_offer(params: RtcOffer, request: Request): } ctx.data_channel.send(dumps(result)) + if event_callback: + await event_callback( + event=PipelineEvent.FINAL_SUMMARY, args=event_callback_args, data=summary + ) + # create a context for the whole rtc transaction # 
add a customised logger to the context ctx.pipeline = Pipeline( @@ -157,3 +186,8 @@ async def rtc_clean_sessions(): logger.debug(f"Closing session {pc}") await pc.close() sessions.clear() + + +@router.post("/offer") +async def rtc_offer(params: RtcOffer, request: Request): + return await rtc_offer_base(params, request) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py new file mode 100644 index 00000000..332a960b --- /dev/null +++ b/server/reflector/views/transcripts.py @@ -0,0 +1,269 @@ +from fastapi import APIRouter, HTTPException, Request, WebSocket, WebSocketDisconnect +from pydantic import BaseModel, Field +from uuid import UUID, uuid4 +from datetime import datetime +from fastapi_pagination import Page, paginate +from .rtc_offer import rtc_offer, RtcOffer, PipelineEvent +import asyncio +from typing import Optional + + +router = APIRouter() + +# ============================================================== +# Models to move to a database, but required for the API to work +# ============================================================== + + +def generate_transcript_name(): + now = datetime.utcnow() + return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" + + +class TranscriptTopic(BaseModel): + id: UUID = Field(default_factory=uuid4) + title: str + summary: str + transcript: str + timestamp: float + + +class TranscriptEvent(BaseModel): + event: str + data: dict + + +class Transcript(BaseModel): + id: UUID = Field(default_factory=uuid4) + name: str = Field(default_factory=generate_transcript_name) + status: str = "idle" + locked: bool = False + duration: float = 0 + created_at: datetime = Field(default_factory=datetime.utcnow) + summary: str | None = None + topics: list[TranscriptTopic] = [] + events: list[TranscriptEvent] = [] + + def add_event(self, event: str, data): + self.events.append(TranscriptEvent(event=event, data=data)) + + def upsert_topic(self, topic: TranscriptTopic): + existing_topic = next((t for t in 
self.topics if t.id == topic.id), None) + if existing_topic: + existing_topic.update_from(topic) + else: + self.topics.append(topic) + + +class TranscriptController: + transcripts: list[Transcript] = [] + + def get_all(self) -> list[Transcript]: + return self.transcripts + + def get_by_id(self, transcript_id: UUID) -> Transcript | None: + return next((t for t in self.transcripts if t.id == transcript_id), None) + + def add(self, transcript: Transcript): + self.transcripts.append(transcript) + + def remove(self, transcript: Transcript): + self.transcripts.remove(transcript) + + +transcripts_controller = TranscriptController() + + +# ============================================================== +# Transcripts list +# ============================================================== + + +class GetTranscript(BaseModel): + id: UUID + name: str + status: str + locked: bool + duration: int + created_at: datetime + + +class CreateTranscript(BaseModel): + name: str + + +class UpdateTranscript(BaseModel): + name: Optional[str] = Field(None) + locked: Optional[bool] = Field(None) + + +class TranscriptEntryCreate(BaseModel): + name: str + + +class DeletionStatus(BaseModel): + status: str + + +@router.get("/transcripts", response_model=Page[GetTranscript]) +async def transcripts_list(): + return paginate(transcripts_controller.get_all()) + + +@router.post("/transcripts", response_model=GetTranscript) +async def transcripts_create(info: CreateTranscript): + transcript = Transcript() + transcript.name = info.name + transcripts_controller.add(transcript) + return transcript + + +# ============================================================== +# Single transcript +# ============================================================== + + +@router.get("/transcripts/{transcript_id}", response_model=GetTranscript) +async def transcript_get(transcript_id: UUID): + transcript = transcripts_controller.get_by_id(transcript_id) + if not transcript: + raise HTTPException(status_code=404, 
detail="Transcript not found") + return transcript + + +@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript) +async def transcript_update(transcript_id: UUID, info: UpdateTranscript): + transcript = transcripts_controller.get_by_id(transcript_id) + if not transcript: + raise HTTPException(status_code=404, detail="Transcript not found") + if info.name is not None: + transcript.name = info.name + if info.locked is not None: + transcript.locked = info.locked + return transcript + + +@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus) +async def transcript_delete(transcript_id: UUID): + transcript = transcripts_controller.get_by_id(transcript_id) + if not transcript: + raise HTTPException(status_code=404, detail="Transcript not found") + transcripts_controller.remove(transcript) + return DeletionStatus(status="ok") + + +@router.get("/transcripts/{transcript_id}/audio") +async def transcript_get_audio(transcript_id: UUID): + transcript = transcripts_controller.get_by_id(transcript_id) + if not transcript: + raise HTTPException(status_code=404, detail="Transcript not found") + + # TODO: Implement audio generation + raise HTTPException(status_code=500, detail="Not implemented") + + +@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) +async def transcript_get_topics(transcript_id: UUID): + transcript = transcripts_controller.get_by_id(transcript_id) + if not transcript: + raise HTTPException(status_code=404, detail="Transcript not found") + return transcript.topics + + +@router.get("/transcripts/{transcript_id}/events") +async def transcript_get_websocket_events(transcript_id: UUID): + pass + + +# ============================================================== +# Websocket Manager +# ============================================================== + + +class WebsocketManager: + def __init__(self): + self.active_connections = {} + + async def connect(self, transcript_id: UUID, websocket: 
WebSocket): + await websocket.accept() + if transcript_id not in self.active_connections: + self.active_connections[transcript_id] = [] + self.active_connections[transcript_id].append(websocket) + + def disconnect(self, transcript_id: UUID, websocket: WebSocket): + if transcript_id not in self.active_connections: + return + self.active_connections[transcript_id].remove(websocket) + if not self.active_connections[transcript_id]: + del self.active_connections[transcript_id] + + async def send_json(self, transcript_id: UUID, message): + if transcript_id not in self.active_connections: + return + for connection in self.active_connections[transcript_id]: + await connection.send_json(message) + + +ws_manager = WebsocketManager() + + +@router.websocket("/transcripts/{transcript_id}/events") +async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket): + transcript = transcripts_controller.get_by_id(transcript_id) + if not transcript: + raise HTTPException(status_code=404, detail="Transcript not found") + + await ws_manager.connect(transcript_id, websocket) + + # on first connection, send all events + for event in transcript.events: + await websocket.send_json(event.model_dump()) + + # XXX if transcript is final (locked=True and status=ended) + # XXX send a final event to the client and close the connection + + # endless loop to wait for new events + try: + while True: + await asyncio.sleep(42) + except WebSocketDisconnect: + ws_manager.disconnect(transcript_id, websocket) + + +# ============================================================== +# Web RTC +# ============================================================== + + +async def handle_rtc_event(event: PipelineEvent, args, data): + # OFC the current implementation is not good, + # but it's just a POC before persistence. It won't query the + # transcript from the database for each event. 
+    print(f"Event: {event}", args, data) + transcript_id = args + transcript = transcripts_controller.get_by_id(transcript_id) + if not transcript: + return + transcript.add_event(event=event, data=data) + if event == PipelineEvent.TOPIC: + transcript.upsert_topic(TranscriptTopic(**data)) + + +@router.post("/transcripts/{transcript_id}/record/webrtc") +async def transcript_record_webrtc( + transcript_id: UUID, params: RtcOffer, request: Request +): + transcript = transcripts_controller.get_by_id(transcript_id) + if not transcript: + raise HTTPException(status_code=404, detail="Transcript not found") + + if transcript.locked: + raise HTTPException(status_code=400, detail="Transcript is locked") + + # FIXME do not allow multiple recording at the same time + # local import: the module top only imports the /offer route handler + from .rtc_offer import rtc_offer_base + return await rtc_offer_base( + params, + request, + event_callback=handle_rtc_event, + event_callback_args=transcript_id, + ) diff --git a/server/tests/test_transcripts.py b/server/tests/test_transcripts.py new file mode 100644 index 00000000..58ab8393 --- /dev/null +++ b/server/tests/test_transcripts.py @@ -0,0 +1,79 @@ +import pytest +from httpx import AsyncClient +from reflector.app import app + + +@pytest.mark.asyncio +async def test_transcript_create(): + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["name"] == "test" + assert response.json()["status"] == "idle" + assert response.json()["locked"] is False + assert response.json()["id"] is not None + assert response.json()["created_at"] is not None + + # ensure some fields are not returned + assert "topics" not in response.json() + assert "events" not in response.json() + + +@pytest.mark.asyncio +async def test_transcript_get_update_name(): + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + 
assert response.json()["name"] == "test" + + tid = response.json()["id"] + + response = await ac.get(f"/transcripts/{tid}") + assert response.status_code == 200 + assert response.json()["name"] == "test" + + response = await ac.patch(f"/transcripts/{tid}", json={"name": "test2"}) + assert response.status_code == 200 + assert response.json()["name"] == "test2" + + response = await ac.get(f"/transcripts/{tid}") + assert response.status_code == 200 + assert response.json()["name"] == "test2" + + +@pytest.mark.asyncio +async def test_transcripts_list(): + # XXX this test is a bit fragile, as it depends on the storage which + # is shared between tests + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "testxx1"}) + assert response.status_code == 200 + assert response.json()["name"] == "testxx1" + + response = await ac.post("/transcripts", json={"name": "testxx2"}) + assert response.status_code == 200 + assert response.json()["name"] == "testxx2" + + response = await ac.get("/transcripts") + assert response.status_code == 200 + assert len(response.json()["items"]) >= 2 + names = [t["name"] for t in response.json()["items"]] + assert "testxx1" in names + assert "testxx2" in names + +@pytest.mark.asyncio +async def test_transcript_delete(): + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "testdel1"}) + assert response.status_code == 200 + assert response.json()["name"] == "testdel1" + + tid = response.json()["id"] + response = await ac.delete(f"/transcripts/{tid}") + assert response.status_code == 200 + assert response.json()["status"] == "ok" + + response = await ac.get(f"/transcripts/{tid}") + assert response.status_code == 404 +