mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
server: refactor with clearer pipeline instanciation and linked to model
This commit is contained in:
@@ -1,8 +1,5 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import reflector.auth as auth
|
||||
from fastapi import (
|
||||
@@ -15,12 +12,13 @@ from fastapi import (
|
||||
)
|
||||
from fastapi_pagination import Page, paginate
|
||||
from pydantic import BaseModel, Field
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.logger import logger
|
||||
from reflector.db.transcripts import (
|
||||
AudioWaveform,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.processors.types import Transcript as ProcessorTranscript
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
from reflector.ws_manager import get_ws_manager
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
@@ -30,216 +28,6 @@ from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||
router = APIRouter()
|
||||
ws_manager = get_ws_manager()
|
||||
|
||||
# ==============================================================
|
||||
# Models to move to a database, but required for the API to work
|
||||
# ==============================================================
|
||||
|
||||
|
||||
def generate_uuid4():
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def generate_transcript_name():
|
||||
now = datetime.utcnow()
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
|
||||
class AudioWaveform(BaseModel):
|
||||
data: list[float]
|
||||
|
||||
|
||||
class TranscriptText(BaseModel):
|
||||
text: str
|
||||
translation: str | None
|
||||
|
||||
|
||||
class TranscriptSegmentTopic(BaseModel):
|
||||
speaker: int
|
||||
text: str
|
||||
timestamp: float
|
||||
|
||||
|
||||
class TranscriptTopic(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
title: str
|
||||
summary: str
|
||||
timestamp: float
|
||||
text: str | None = None
|
||||
words: list[ProcessorWord] = []
|
||||
|
||||
|
||||
class TranscriptFinalShortSummary(BaseModel):
|
||||
short_summary: str
|
||||
|
||||
|
||||
class TranscriptFinalLongSummary(BaseModel):
|
||||
long_summary: str
|
||||
|
||||
|
||||
class TranscriptFinalTitle(BaseModel):
|
||||
title: str
|
||||
|
||||
|
||||
class TranscriptEvent(BaseModel):
|
||||
event: str
|
||||
data: dict
|
||||
|
||||
|
||||
class Transcript(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
user_id: str | None = None
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
status: str = "idle"
|
||||
locked: bool = False
|
||||
duration: float = 0
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
title: str | None = None
|
||||
short_summary: str | None = None
|
||||
long_summary: str | None = None
|
||||
topics: list[TranscriptTopic] = []
|
||||
events: list[TranscriptEvent] = []
|
||||
source_language: str = "en"
|
||||
target_language: str = "en"
|
||||
|
||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||
self.events.append(ev)
|
||||
return ev
|
||||
|
||||
def upsert_topic(self, topic: TranscriptTopic):
|
||||
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
|
||||
if existing_topic:
|
||||
existing_topic.update_from(topic)
|
||||
else:
|
||||
self.topics.append(topic)
|
||||
|
||||
def events_dump(self, mode="json"):
|
||||
return [event.model_dump(mode=mode) for event in self.events]
|
||||
|
||||
def topics_dump(self, mode="json"):
|
||||
return [topic.model_dump(mode=mode) for topic in self.topics]
|
||||
|
||||
def convert_audio_to_waveform(self, segments_count=256):
|
||||
fn = self.audio_waveform_filename
|
||||
if fn.exists():
|
||||
return
|
||||
waveform = get_audio_waveform(
|
||||
path=self.audio_mp3_filename, segments_count=segments_count
|
||||
)
|
||||
try:
|
||||
with open(fn, "w") as fd:
|
||||
json.dump(waveform, fd)
|
||||
except Exception:
|
||||
# remove file if anything happen during the write
|
||||
fn.unlink(missing_ok=True)
|
||||
raise
|
||||
return waveform
|
||||
|
||||
def unlink(self):
|
||||
self.data_path.unlink(missing_ok=True)
|
||||
|
||||
@property
|
||||
def data_path(self):
|
||||
return Path(settings.DATA_DIR) / self.id
|
||||
|
||||
@property
|
||||
def audio_mp3_filename(self):
|
||||
return self.data_path / "audio.mp3"
|
||||
|
||||
@property
|
||||
def audio_waveform_filename(self):
|
||||
return self.data_path / "audio.json"
|
||||
|
||||
@property
|
||||
def audio_waveform(self):
|
||||
try:
|
||||
with open(self.audio_waveform_filename) as fd:
|
||||
data = json.load(fd)
|
||||
except json.JSONDecodeError:
|
||||
# unlink file if it's corrupted
|
||||
self.audio_waveform_filename.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
return AudioWaveform(data=data)
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
async def get_all(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
order_by: str | None = None,
|
||||
filter_empty: bool | None = False,
|
||||
filter_recording: bool | None = False,
|
||||
) -> list[Transcript]:
|
||||
query = transcripts.select().where(transcripts.c.user_id == user_id)
|
||||
|
||||
if order_by is not None:
|
||||
field = getattr(transcripts.c, order_by[1:])
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
|
||||
if filter_empty:
|
||||
query = query.filter(transcripts.c.status != "idle")
|
||||
|
||||
if filter_recording:
|
||||
query = query.filter(transcripts.c.status != "recording")
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
return results
|
||||
|
||||
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await database.fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
|
||||
async def add(
|
||||
self,
|
||||
name: str,
|
||||
source_language: str = "en",
|
||||
target_language: str = "en",
|
||||
user_id: str | None = None,
|
||||
):
|
||||
transcript = Transcript(
|
||||
name=name,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
user_id=user_id,
|
||||
)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
return transcript
|
||||
|
||||
async def update(self, transcript: Transcript, values: dict):
|
||||
query = (
|
||||
transcripts.update()
|
||||
.where(transcripts.c.id == transcript.id)
|
||||
.values(**values)
|
||||
)
|
||||
await database.execute(query)
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
|
||||
async def remove_by_id(
|
||||
self, transcript_id: str, user_id: str | None = None
|
||||
) -> None:
|
||||
transcript = await self.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
return
|
||||
if user_id is not None and transcript.user_id != user_id:
|
||||
return
|
||||
transcript.unlink()
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await database.execute(query)
|
||||
|
||||
|
||||
transcripts_controller = TranscriptController()
|
||||
|
||||
|
||||
# ==============================================================
|
||||
# Transcripts list
|
||||
# ==============================================================
|
||||
@@ -537,114 +325,6 @@ async def transcript_events_websocket(
|
||||
# ==============================================================
|
||||
|
||||
|
||||
async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
try:
|
||||
return await handle_rtc_event_once(event, args, data)
|
||||
except Exception:
|
||||
logger.exception("Error handling RTC event")
|
||||
|
||||
|
||||
async def handle_rtc_event_once(event: PipelineEvent, args, data):
|
||||
# OFC the current implementation is not good,
|
||||
# but it's just a POC before persistence. It won't query the
|
||||
# transcript from the database for each event.
|
||||
# print(f"Event: {event}", args, data)
|
||||
transcript_id = args
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
# event send to websocket clients may not be the same as the event
|
||||
# received from the pipeline. For example, the pipeline will send
|
||||
# a TRANSCRIPT event with all words, but this is not what we want
|
||||
# to send to the websocket client.
|
||||
|
||||
# FIXME don't do copy
|
||||
if event == PipelineEvent.TRANSCRIPT:
|
||||
resp = transcript.add_event(
|
||||
event=event,
|
||||
data=TranscriptText(text=data.text, translation=data.translation),
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.TOPIC:
|
||||
topic = TranscriptTopic(
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
timestamp=data.timestamp,
|
||||
text=data.transcript.text,
|
||||
words=data.transcript.words,
|
||||
)
|
||||
resp = transcript.add_event(event=event, data=topic)
|
||||
transcript.upsert_topic(topic)
|
||||
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"topics": transcript.topics_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.FINAL_TITLE:
|
||||
final_title = TranscriptFinalTitle(title=data.title)
|
||||
resp = transcript.add_event(event=event, data=final_title)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"title": final_title.title,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.FINAL_LONG_SUMMARY:
|
||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||
resp = transcript.add_event(event=event, data=final_long_summary)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"long_summary": final_long_summary.long_summary,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
|
||||
final_short_summary = TranscriptFinalShortSummary(
|
||||
short_summary=data.short_summary
|
||||
)
|
||||
resp = transcript.add_event(event=event, data=final_short_summary)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"short_summary": final_short_summary.short_summary,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.STATUS:
|
||||
resp = transcript.add_event(event=event, data=data)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"status": data.value,
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown event: {event}")
|
||||
return
|
||||
|
||||
# transmit to websocket clients
|
||||
room_id = f"ts:{transcript_id}"
|
||||
await ws_manager.send_json(room_id, resp.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||
async def transcript_record_webrtc(
|
||||
transcript_id: str,
|
||||
@@ -660,13 +340,14 @@ async def transcript_record_webrtc(
|
||||
if transcript.locked:
|
||||
raise HTTPException(status_code=400, detail="Transcript is locked")
|
||||
|
||||
# create a pipeline runner
|
||||
from reflector.pipelines.main_live_pipeline import PipelineMainLive
|
||||
|
||||
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
|
||||
|
||||
# FIXME do not allow multiple recording at the same time
|
||||
return await rtc_offer_base(
|
||||
params,
|
||||
request,
|
||||
event_callback=handle_rtc_event,
|
||||
event_callback_args=transcript_id,
|
||||
audio_filename=transcript.audio_mp3_filename,
|
||||
source_language=transcript.source_language,
|
||||
target_language=transcript.target_language,
|
||||
pipeline_runner=pipeline_runner,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user