mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
* initial * add LLM features * update LLM logic * update llm functions: change control flow * add generation config * update return types * update processors and tests * update rtc_offer * revert new title processor change * fix unit tests * add comments and fix HTTP 500 * adjust prompt * test with reflector app * revert new event for final title * update * move onus onto processors * move onus onto processors * stash * add provision for gen config * dynamically pack the LLM input using context length * tune final summary params * update consolidated class structures * update consolidated class structures * update precommit * add broadcast processors * working baseline * Organize LLMParams * minor fixes * minor fixes * minor fixes * fix unit tests * fix unit tests * fix unit tests * update tests * update tests * edit pipeline response events * update summary return types * configure tests * alembic db migration * change LLM response flow * edit main llm functions * edit main llm functions * change llm name and gen cf * Update transcript_topic_detector.py * PR review comments * checkpoint before db event migration * update DB migration of past events * update DB migration of past events * edit LLM classes * Delete unwanted file * remove List typing * remove List typing * update oobabooga API call * topic enhancements * update UI event handling * move ensure_casing to llm base * update tests * update tests
647 lines
21 KiB
Python
647 lines
21 KiB
Python
import json
|
|
import shutil
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from tempfile import NamedTemporaryFile
|
|
from typing import Annotated, Optional
|
|
from uuid import uuid4
|
|
|
|
import av
|
|
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
HTTPException,
|
|
Request,
|
|
WebSocket,
|
|
WebSocketDisconnect,
|
|
)
|
|
from fastapi_pagination import Page, paginate
|
|
from pydantic import BaseModel, Field
|
|
from starlette.concurrency import run_in_threadpool
|
|
|
|
import reflector.auth as auth
|
|
from reflector.db import database, transcripts
|
|
from reflector.logger import logger
|
|
from reflector.settings import settings
|
|
from reflector.utils.audio_waveform import get_audio_waveform
|
|
|
|
from ._range_requests_response import range_requests_response
|
|
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
|
|
|
router = APIRouter()
|
|
|
|
# ==============================================================
|
|
# Models to move to a database, but required for the API to work
|
|
# ==============================================================
|
|
|
|
|
|
def generate_uuid4():
|
|
return str(uuid4())
|
|
|
|
|
|
def generate_transcript_name():
|
|
now = datetime.utcnow()
|
|
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
|
|
|
|
|
class AudioWaveform(BaseModel):
|
|
data: list[int]
|
|
|
|
|
|
class TranscriptText(BaseModel):
|
|
text: str
|
|
translation: str | None
|
|
|
|
|
|
class TranscriptTopic(BaseModel):
|
|
id: str = Field(default_factory=generate_uuid4)
|
|
title: str
|
|
summary: str
|
|
transcript: str
|
|
timestamp: float
|
|
|
|
|
|
class TranscriptFinalShortSummary(BaseModel):
|
|
short_summary: str
|
|
|
|
|
|
class TranscriptFinalLongSummary(BaseModel):
|
|
long_summary: str
|
|
|
|
|
|
class TranscriptFinalTitle(BaseModel):
|
|
title: str
|
|
|
|
|
|
class TranscriptEvent(BaseModel):
|
|
event: str
|
|
data: dict
|
|
|
|
|
|
class Transcript(BaseModel):
|
|
id: str = Field(default_factory=generate_uuid4)
|
|
user_id: str | None = None
|
|
name: str = Field(default_factory=generate_transcript_name)
|
|
status: str = "idle"
|
|
locked: bool = False
|
|
duration: float = 0
|
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
title: str | None = None
|
|
short_summary: str | None = None
|
|
long_summary: str | None = None
|
|
topics: list[TranscriptTopic] = []
|
|
events: list[TranscriptEvent] = []
|
|
source_language: str = "en"
|
|
target_language: str = "en"
|
|
|
|
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
|
ev = TranscriptEvent(event=event, data=data.model_dump())
|
|
self.events.append(ev)
|
|
return ev
|
|
|
|
def upsert_topic(self, topic: TranscriptTopic):
|
|
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
|
|
if existing_topic:
|
|
existing_topic.update_from(topic)
|
|
else:
|
|
self.topics.append(topic)
|
|
|
|
def events_dump(self, mode="json"):
|
|
return [event.model_dump(mode=mode) for event in self.events]
|
|
|
|
def topics_dump(self, mode="json"):
|
|
return [topic.model_dump(mode=mode) for topic in self.topics]
|
|
|
|
def convert_audio_to_mp3(self):
|
|
fn = self.audio_mp3_filename
|
|
if fn.exists():
|
|
return
|
|
|
|
logger.info(f"Converting audio to mp3: {self.audio_filename}")
|
|
inp = av.open(self.audio_filename.as_posix(), "r")
|
|
|
|
# create temporary file for mp3
|
|
with NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
|
|
out = av.open(tmp.name, "w")
|
|
stream = out.add_stream("mp3")
|
|
for frame in inp.decode(audio=0):
|
|
frame.pts = None
|
|
for packet in stream.encode(frame):
|
|
out.mux(packet)
|
|
for packet in stream.encode(None):
|
|
out.mux(packet)
|
|
out.close()
|
|
|
|
# move temporary file to final location
|
|
shutil.move(tmp.name, fn.as_posix())
|
|
|
|
def convert_audio_to_waveform(self, segments_count=256):
|
|
fn = self.audio_waveform_filename
|
|
if fn.exists():
|
|
return
|
|
waveform = get_audio_waveform(
|
|
path=self.audio_filename, segments_count=segments_count
|
|
)
|
|
try:
|
|
with open(fn, "w") as fd:
|
|
json.dump(waveform, fd)
|
|
except Exception:
|
|
# remove file if anything happen during the write
|
|
fn.unlink(missing_ok=True)
|
|
raise
|
|
return waveform
|
|
|
|
def unlink(self):
|
|
self.data_path.unlink(missing_ok=True)
|
|
|
|
@property
|
|
def data_path(self):
|
|
return Path(settings.DATA_DIR) / self.id
|
|
|
|
@property
|
|
def audio_filename(self):
|
|
return self.data_path / "audio.wav"
|
|
|
|
@property
|
|
def audio_mp3_filename(self):
|
|
return self.data_path / "audio.mp3"
|
|
|
|
@property
|
|
def audio_waveform_filename(self):
|
|
return self.data_path / "audio.json"
|
|
|
|
@property
|
|
def audio_waveform(self):
|
|
try:
|
|
with open(self.audio_waveform_filename) as fd:
|
|
data = json.load(fd)
|
|
except json.JSONDecodeError:
|
|
# unlink file if it's corrupted
|
|
self.audio_waveform_filename.unlink(missing_ok=True)
|
|
return None
|
|
|
|
return AudioWaveform(data=data)
|
|
|
|
|
|
class TranscriptController:
|
|
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
|
|
query = transcripts.select().where(transcripts.c.user_id == user_id)
|
|
results = await database.fetch_all(query)
|
|
return results
|
|
|
|
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
|
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
|
if "user_id" in kwargs:
|
|
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
|
result = await database.fetch_one(query)
|
|
if not result:
|
|
return None
|
|
return Transcript(**result)
|
|
|
|
async def add(
|
|
self,
|
|
name: str,
|
|
source_language: str = "en",
|
|
target_language: str = "en",
|
|
user_id: str | None = None,
|
|
):
|
|
transcript = Transcript(
|
|
name=name,
|
|
source_language=source_language,
|
|
target_language=target_language,
|
|
user_id=user_id,
|
|
)
|
|
query = transcripts.insert().values(**transcript.model_dump())
|
|
await database.execute(query)
|
|
return transcript
|
|
|
|
async def update(self, transcript: Transcript, values: dict):
|
|
query = (
|
|
transcripts.update()
|
|
.where(transcripts.c.id == transcript.id)
|
|
.values(**values)
|
|
)
|
|
await database.execute(query)
|
|
for key, value in values.items():
|
|
setattr(transcript, key, value)
|
|
|
|
async def remove_by_id(
|
|
self, transcript_id: str, user_id: str | None = None
|
|
) -> None:
|
|
transcript = await self.get_by_id(transcript_id, user_id=user_id)
|
|
if not transcript:
|
|
return
|
|
if user_id is not None and transcript.user_id != user_id:
|
|
return
|
|
transcript.unlink()
|
|
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
|
await database.execute(query)
|
|
|
|
|
|
transcripts_controller = TranscriptController()
|
|
|
|
|
|
# ==============================================================
|
|
# Transcripts list
|
|
# ==============================================================
|
|
|
|
|
|
class GetTranscript(BaseModel):
|
|
id: str
|
|
name: str
|
|
status: str
|
|
locked: bool
|
|
duration: int
|
|
title: str | None
|
|
short_summary: str | None
|
|
long_summary: str | None
|
|
created_at: datetime
|
|
source_language: str
|
|
target_language: str
|
|
|
|
|
|
class CreateTranscript(BaseModel):
|
|
name: str
|
|
source_language: str = Field("en")
|
|
target_language: str = Field("en")
|
|
|
|
|
|
class UpdateTranscript(BaseModel):
|
|
name: Optional[str] = Field(None)
|
|
locked: Optional[bool] = Field(None)
|
|
title: Optional[str] = Field(None)
|
|
short_summary: Optional[str] = Field(None)
|
|
long_summary: Optional[str] = Field(None)
|
|
|
|
|
|
class DeletionStatus(BaseModel):
|
|
status: str
|
|
|
|
|
|
@router.get("/transcripts", response_model=Page[GetTranscript])
|
|
async def transcripts_list(
|
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
):
|
|
if not user and not settings.PUBLIC_MODE:
|
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
|
|
|
user_id = user["sub"] if user else None
|
|
return paginate(await transcripts_controller.get_all(user_id=user_id))
|
|
|
|
|
|
@router.post("/transcripts", response_model=GetTranscript)
|
|
async def transcripts_create(
|
|
info: CreateTranscript,
|
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
):
|
|
user_id = user["sub"] if user else None
|
|
return await transcripts_controller.add(
|
|
info.name,
|
|
source_language=info.source_language,
|
|
target_language=info.target_language,
|
|
user_id=user_id,
|
|
)
|
|
|
|
|
|
# ==============================================================
|
|
# Single transcript
|
|
# ==============================================================
|
|
|
|
|
|
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
|
|
async def transcript_get(
|
|
transcript_id: str,
|
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
):
|
|
user_id = user["sub"] if user else None
|
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
|
if not transcript:
|
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
|
return transcript
|
|
|
|
|
|
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
|
|
async def transcript_update(
|
|
transcript_id: str,
|
|
info: UpdateTranscript,
|
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
):
|
|
user_id = user["sub"] if user else None
|
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
|
if not transcript:
|
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
|
values = {"events": []}
|
|
if info.name is not None:
|
|
values["name"] = info.name
|
|
if info.locked is not None:
|
|
values["locked"] = info.locked
|
|
if info.long_summary is not None:
|
|
values["long_summary"] = info.long_summary
|
|
for transcript_event in transcript.events:
|
|
if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
|
|
transcript_event["long_summary"] = info.long_summary
|
|
break
|
|
values["events"].extend(transcript.events)
|
|
if info.short_summary is not None:
|
|
values["short_summary"] = info.short_summary
|
|
for transcript_event in transcript.events:
|
|
if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
|
|
transcript_event["short_summary"] = info.short_summary
|
|
break
|
|
values["events"].extend(transcript.events)
|
|
if info.title is not None:
|
|
values["title"] = info.title
|
|
for transcript_event in transcript.events:
|
|
if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
|
|
transcript_event["title"] = info.title
|
|
break
|
|
values["events"].extend(transcript.events)
|
|
await transcripts_controller.update(transcript, values)
|
|
return transcript
|
|
|
|
|
|
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
|
|
async def transcript_delete(
|
|
transcript_id: str,
|
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
):
|
|
user_id = user["sub"] if user else None
|
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
|
if not transcript:
|
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
|
await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
|
|
return DeletionStatus(status="ok")
|
|
|
|
|
|
@router.get("/transcripts/{transcript_id}/audio")
|
|
async def transcript_get_audio(
|
|
request: Request,
|
|
transcript_id: str,
|
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
):
|
|
user_id = user["sub"] if user else None
|
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
|
if not transcript:
|
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
|
|
|
if not transcript.audio_filename.exists():
|
|
raise HTTPException(status_code=404, detail="Audio not found")
|
|
|
|
return range_requests_response(
|
|
request,
|
|
transcript.audio_filename,
|
|
content_type="audio/wav",
|
|
)
|
|
|
|
|
|
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
|
async def transcript_get_audio_mp3(
|
|
request: Request,
|
|
transcript_id: str,
|
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
):
|
|
user_id = user["sub"] if user else None
|
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
|
if not transcript:
|
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
|
|
|
if not transcript.audio_filename.exists():
|
|
raise HTTPException(status_code=404, detail="Audio not found")
|
|
|
|
await run_in_threadpool(transcript.convert_audio_to_mp3)
|
|
|
|
return range_requests_response(
|
|
request,
|
|
transcript.audio_mp3_filename,
|
|
content_type="audio/mp3",
|
|
)
|
|
|
|
|
|
@router.get("/transcripts/{transcript_id}/audio/waveform")
|
|
async def transcript_get_audio_waveform(
|
|
transcript_id: str,
|
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
) -> AudioWaveform:
|
|
user_id = user["sub"] if user else None
|
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
|
if not transcript:
|
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
|
|
|
if not transcript.audio_filename.exists():
|
|
raise HTTPException(status_code=404, detail="Audio not found")
|
|
|
|
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
|
|
|
return transcript.audio_waveform
|
|
|
|
|
|
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
|
async def transcript_get_topics(
|
|
transcript_id: str,
|
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
):
|
|
user_id = user["sub"] if user else None
|
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
|
if not transcript:
|
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
|
return transcript.topics
|
|
|
|
|
|
@router.get("/transcripts/{transcript_id}/events")
|
|
async def transcript_get_websocket_events(transcript_id: str):
|
|
pass
|
|
|
|
|
|
# ==============================================================
|
|
# Websocket Manager
|
|
# ==============================================================
|
|
|
|
|
|
class WebsocketManager:
|
|
def __init__(self):
|
|
self.active_connections = {}
|
|
|
|
async def connect(self, transcript_id: str, websocket: WebSocket):
|
|
await websocket.accept()
|
|
if transcript_id not in self.active_connections:
|
|
self.active_connections[transcript_id] = []
|
|
self.active_connections[transcript_id].append(websocket)
|
|
|
|
def disconnect(self, transcript_id: str, websocket: WebSocket):
|
|
if transcript_id not in self.active_connections:
|
|
return
|
|
self.active_connections[transcript_id].remove(websocket)
|
|
if not self.active_connections[transcript_id]:
|
|
del self.active_connections[transcript_id]
|
|
|
|
async def send_json(self, transcript_id: str, message):
|
|
if transcript_id not in self.active_connections:
|
|
return
|
|
for connection in self.active_connections[transcript_id][:]:
|
|
try:
|
|
await connection.send_json(message)
|
|
except Exception:
|
|
self.active_connections[transcript_id].remove(connection)
|
|
|
|
|
|
ws_manager = WebsocketManager()
|
|
|
|
|
|
@router.websocket("/transcripts/{transcript_id}/events")
|
|
async def transcript_events_websocket(
|
|
transcript_id: str,
|
|
websocket: WebSocket,
|
|
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
):
|
|
# user_id = user["sub"] if user else None
|
|
transcript = await transcripts_controller.get_by_id(transcript_id)
|
|
if not transcript:
|
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
|
|
|
await ws_manager.connect(transcript_id, websocket)
|
|
|
|
# on first connection, send all events
|
|
for event in transcript.events:
|
|
await websocket.send_json(event.model_dump(mode="json"))
|
|
|
|
# XXX if transcript is final (locked=True and status=ended)
|
|
# XXX send a final event to the client and close the connection
|
|
|
|
# endless loop to wait for new events
|
|
try:
|
|
while True:
|
|
await websocket.receive()
|
|
except (RuntimeError, WebSocketDisconnect):
|
|
ws_manager.disconnect(transcript_id, websocket)
|
|
|
|
|
|
# ==============================================================
|
|
# Web RTC
|
|
# ==============================================================
|
|
|
|
|
|
async def handle_rtc_event(event: PipelineEvent, args, data):
|
|
# OFC the current implementation is not good,
|
|
# but it's just a POC before persistence. It won't query the
|
|
# transcript from the database for each event.
|
|
# print(f"Event: {event}", args, data)
|
|
transcript_id = args
|
|
transcript = await transcripts_controller.get_by_id(transcript_id)
|
|
if not transcript:
|
|
return
|
|
|
|
# event send to websocket clients may not be the same as the event
|
|
# received from the pipeline. For example, the pipeline will send
|
|
# a TRANSCRIPT event with all words, but this is not what we want
|
|
# to send to the websocket client.
|
|
|
|
# FIXME don't do copy
|
|
if event == PipelineEvent.TRANSCRIPT:
|
|
resp = transcript.add_event(
|
|
event=event,
|
|
data=TranscriptText(text=data.text, translation=data.translation),
|
|
)
|
|
await transcripts_controller.update(
|
|
transcript,
|
|
{
|
|
"events": transcript.events_dump(),
|
|
},
|
|
)
|
|
|
|
elif event == PipelineEvent.TOPIC:
|
|
topic = TranscriptTopic(
|
|
title=data.title,
|
|
summary=data.summary,
|
|
transcript=data.transcript.text,
|
|
timestamp=data.timestamp,
|
|
)
|
|
resp = transcript.add_event(event=event, data=topic)
|
|
transcript.upsert_topic(topic)
|
|
|
|
await transcripts_controller.update(
|
|
transcript,
|
|
{
|
|
"events": transcript.events_dump(),
|
|
"topics": transcript.topics_dump(),
|
|
},
|
|
)
|
|
|
|
elif event == PipelineEvent.FINAL_TITLE:
|
|
final_title = TranscriptFinalTitle(title=data.title)
|
|
resp = transcript.add_event(event=event, data=final_title)
|
|
await transcripts_controller.update(
|
|
transcript,
|
|
{
|
|
"events": transcript.events_dump(),
|
|
"title": final_title.title,
|
|
},
|
|
)
|
|
|
|
elif event == PipelineEvent.FINAL_LONG_SUMMARY:
|
|
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
|
resp = transcript.add_event(event=event, data=final_long_summary)
|
|
await transcripts_controller.update(
|
|
transcript,
|
|
{
|
|
"events": transcript.events_dump(),
|
|
"long_summary": final_long_summary.long_summary,
|
|
},
|
|
)
|
|
|
|
elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
|
|
final_short_summary = TranscriptFinalShortSummary(
|
|
short_summary=data.short_summary
|
|
)
|
|
resp = transcript.add_event(event=event, data=final_short_summary)
|
|
await transcripts_controller.update(
|
|
transcript,
|
|
{
|
|
"events": transcript.events_dump(),
|
|
"short_summary": final_short_summary.short_summary,
|
|
},
|
|
)
|
|
|
|
elif event == PipelineEvent.STATUS:
|
|
resp = transcript.add_event(event=event, data=data)
|
|
await transcripts_controller.update(
|
|
transcript,
|
|
{
|
|
"events": transcript.events_dump(),
|
|
"status": data.value,
|
|
},
|
|
)
|
|
|
|
else:
|
|
logger.warning(f"Unknown event: {event}")
|
|
return
|
|
|
|
# transmit to websocket clients
|
|
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
|
|
|
|
|
|
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
|
async def transcript_record_webrtc(
|
|
transcript_id: str,
|
|
params: RtcOffer,
|
|
request: Request,
|
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
):
|
|
user_id = user["sub"] if user else None
|
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
|
if not transcript:
|
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
|
|
|
if transcript.locked:
|
|
raise HTTPException(status_code=400, detail="Transcript is locked")
|
|
|
|
# FIXME do not allow multiple recording at the same time
|
|
return await rtc_offer_base(
|
|
params,
|
|
request,
|
|
event_callback=handle_rtc_event,
|
|
event_callback_args=transcript_id,
|
|
audio_filename=transcript.audio_filename,
|
|
source_language=transcript.source_language,
|
|
target_language=transcript.target_language,
|
|
)
|