mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-22 05:09:05 +00:00
server: implement data persistence with database
Using databases + sqlite/postgresql depending of what you want. Use DATABASE_URL to configure Closes #70
This commit is contained in:
@@ -6,10 +6,11 @@ from fastapi import (
|
||||
WebSocketDisconnect,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import uuid4
|
||||
from datetime import datetime
|
||||
from fastapi_pagination import Page, paginate
|
||||
from reflector.logger import logger
|
||||
from reflector.db import database, transcripts
|
||||
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
|
||||
from typing import Optional
|
||||
|
||||
@@ -21,6 +22,10 @@ router = APIRouter()
|
||||
# ==============================================================
|
||||
|
||||
|
||||
def generate_uuid4():
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def generate_transcript_name():
|
||||
now = datetime.utcnow()
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
@@ -31,7 +36,7 @@ class TranscriptText(BaseModel):
|
||||
|
||||
|
||||
class TranscriptTopic(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
title: str
|
||||
summary: str
|
||||
transcript: str
|
||||
@@ -48,7 +53,7 @@ class TranscriptEvent(BaseModel):
|
||||
|
||||
|
||||
class Transcript(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
status: str = "idle"
|
||||
locked: bool = False
|
||||
@@ -72,19 +77,37 @@ class Transcript(BaseModel):
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
transcripts: list[Transcript] = []
|
||||
async def get_all(self) -> list[Transcript]:
|
||||
query = transcripts.select()
|
||||
results = await database.fetch_all(query)
|
||||
return results
|
||||
|
||||
def get_all(self) -> list[Transcript]:
|
||||
return self.transcripts
|
||||
async def get_by_id(self, transcript_id: str) -> Transcript | None:
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
result = await database.fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
|
||||
def get_by_id(self, transcript_id: UUID) -> Transcript | None:
|
||||
return next((t for t in self.transcripts if t.id == transcript_id), None)
|
||||
async def add(self, name: str):
|
||||
transcript = Transcript(name=name)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
return transcript
|
||||
|
||||
def add(self, transcript: Transcript):
|
||||
self.transcripts.append(transcript)
|
||||
async def update(self, transcript: Transcript, values: dict):
|
||||
query = (
|
||||
transcripts.update()
|
||||
.where(transcripts.c.id == transcript.id)
|
||||
.values(**values)
|
||||
)
|
||||
await database.execute(query)
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
|
||||
def remove(self, transcript: Transcript):
|
||||
self.transcripts.remove(transcript)
|
||||
async def remove_by_id(self, transcript_id: str) -> None:
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await database.execute(query)
|
||||
|
||||
|
||||
transcripts_controller = TranscriptController()
|
||||
@@ -96,7 +119,7 @@ transcripts_controller = TranscriptController()
|
||||
|
||||
|
||||
class GetTranscript(BaseModel):
|
||||
id: UUID
|
||||
id: str
|
||||
name: str
|
||||
status: str
|
||||
locked: bool
|
||||
@@ -123,15 +146,12 @@ class DeletionStatus(BaseModel):
|
||||
|
||||
@router.get("/transcripts", response_model=Page[GetTranscript])
|
||||
async def transcripts_list():
|
||||
return paginate(transcripts_controller.get_all())
|
||||
return paginate(await transcripts_controller.get_all())
|
||||
|
||||
|
||||
@router.post("/transcripts", response_model=GetTranscript)
|
||||
async def transcripts_create(info: CreateTranscript):
|
||||
transcript = Transcript()
|
||||
transcript.name = info.name
|
||||
transcripts_controller.add(transcript)
|
||||
return transcript
|
||||
return await transcripts_controller.add(info.name)
|
||||
|
||||
|
||||
# ==============================================================
|
||||
@@ -140,36 +160,38 @@ async def transcripts_create(info: CreateTranscript):
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
async def transcript_get(transcript_id: UUID):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
async def transcript_get(transcript_id: str):
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
return transcript
|
||||
|
||||
|
||||
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
async def transcript_update(transcript_id: UUID, info: UpdateTranscript):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
async def transcript_update(transcript_id: str, info: UpdateTranscript):
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
values = {}
|
||||
if info.name is not None:
|
||||
transcript.name = info.name
|
||||
values["name"] = info.name
|
||||
if info.locked is not None:
|
||||
transcript.locked = info.locked
|
||||
values["locked"] = info.locked
|
||||
await transcripts_controller.update(transcript, values)
|
||||
return transcript
|
||||
|
||||
|
||||
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
|
||||
async def transcript_delete(transcript_id: UUID):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
async def transcript_delete(transcript_id: str):
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
transcripts_controller.remove(transcript)
|
||||
await transcripts_controller.remove_by_id(transcript.id)
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio")
|
||||
async def transcript_get_audio(transcript_id: UUID):
|
||||
async def transcript_get_audio(transcript_id: str):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
@@ -179,7 +201,7 @@ async def transcript_get_audio(transcript_id: UUID):
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
||||
async def transcript_get_topics(transcript_id: UUID):
|
||||
async def transcript_get_topics(transcript_id: str):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
@@ -187,7 +209,7 @@ async def transcript_get_topics(transcript_id: UUID):
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/events")
|
||||
async def transcript_get_websocket_events(transcript_id: UUID):
|
||||
async def transcript_get_websocket_events(transcript_id: str):
|
||||
pass
|
||||
|
||||
|
||||
@@ -200,20 +222,20 @@ class WebsocketManager:
|
||||
def __init__(self):
|
||||
self.active_connections = {}
|
||||
|
||||
async def connect(self, transcript_id: UUID, websocket: WebSocket):
|
||||
async def connect(self, transcript_id: str, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
if transcript_id not in self.active_connections:
|
||||
self.active_connections[transcript_id] = []
|
||||
self.active_connections[transcript_id].append(websocket)
|
||||
|
||||
def disconnect(self, transcript_id: UUID, websocket: WebSocket):
|
||||
def disconnect(self, transcript_id: str, websocket: WebSocket):
|
||||
if transcript_id not in self.active_connections:
|
||||
return
|
||||
self.active_connections[transcript_id].remove(websocket)
|
||||
if not self.active_connections[transcript_id]:
|
||||
del self.active_connections[transcript_id]
|
||||
|
||||
async def send_json(self, transcript_id: UUID, message):
|
||||
async def send_json(self, transcript_id: str, message):
|
||||
if transcript_id not in self.active_connections:
|
||||
return
|
||||
for connection in self.active_connections[transcript_id][:]:
|
||||
@@ -227,7 +249,7 @@ ws_manager = WebsocketManager()
|
||||
|
||||
|
||||
@router.websocket("/transcripts/{transcript_id}/events")
|
||||
async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket):
|
||||
async def transcript_events_websocket(transcript_id: str, websocket: WebSocket):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
@@ -283,14 +305,28 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
resp = transcript.add_event(event=event, data=topic)
|
||||
transcript.upsert_topic(topic)
|
||||
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events,
|
||||
"topics": transcript.topics,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.FINAL_SUMMARY:
|
||||
final_summary = TranscriptFinalSummary(summary=data.summary)
|
||||
resp = transcript.add_event(event=event, data=final_summary)
|
||||
transcript.summary = final_summary
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events,
|
||||
"summary": transcript.summary,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.STATUS:
|
||||
resp = transcript.add_event(event=event, data=data)
|
||||
transcript.status = data.value
|
||||
await transcripts_controller.update(transcript, {"status": transcript.status})
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown event: {event}")
|
||||
@@ -302,7 +338,7 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
|
||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||
async def transcript_record_webrtc(
|
||||
transcript_id: UUID, params: RtcOffer, request: Request
|
||||
transcript_id: str, params: RtcOffer, request: Request
|
||||
):
|
||||
transcript = transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
|
||||
Reference in New Issue
Block a user