mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: refactor with clearer pipeline instantiation and linked to model
This commit is contained in:
@@ -1,32 +1,13 @@
|
|||||||
import databases
|
import databases
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
database = databases.Database(settings.DATABASE_URL)
|
database = databases.Database(settings.DATABASE_URL)
|
||||||
metadata = sqlalchemy.MetaData()
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
# import models
|
||||||
transcripts = sqlalchemy.Table(
|
import reflector.db.transcripts # noqa
|
||||||
"transcript",
|
|
||||||
metadata,
|
|
||||||
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
|
||||||
sqlalchemy.Column("name", sqlalchemy.String),
|
|
||||||
sqlalchemy.Column("status", sqlalchemy.String),
|
|
||||||
sqlalchemy.Column("locked", sqlalchemy.Boolean),
|
|
||||||
sqlalchemy.Column("duration", sqlalchemy.Integer),
|
|
||||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
|
|
||||||
sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
|
|
||||||
sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
|
|
||||||
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
|
|
||||||
sqlalchemy.Column("topics", sqlalchemy.JSON),
|
|
||||||
sqlalchemy.Column("events", sqlalchemy.JSON),
|
|
||||||
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
|
|
||||||
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
|
|
||||||
# with user attached, optional
|
|
||||||
sqlalchemy.Column("user_id", sqlalchemy.String),
|
|
||||||
)
|
|
||||||
|
|
||||||
engine = sqlalchemy.create_engine(
|
engine = sqlalchemy.create_engine(
|
||||||
settings.DATABASE_URL, connect_args={"check_same_thread": False}
|
settings.DATABASE_URL, connect_args={"check_same_thread": False}
|
||||||
|
|||||||
284
server/reflector/db/transcripts.py
Normal file
284
server/reflector/db/transcripts.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
import json
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from reflector.db import database, metadata
|
||||||
|
from reflector.processors.types import Word as ProcessorWord
|
||||||
|
from reflector.settings import settings
|
||||||
|
from reflector.utils.audio_waveform import get_audio_waveform
|
||||||
|
|
||||||
|
# Schema of the "transcript" table, registered on the shared `metadata`.
# Each entry is (column name, column type, extra Column keyword arguments).
transcripts = sqlalchemy.Table(
    "transcript",
    metadata,
    *[
        sqlalchemy.Column(column_name, column_type, **column_options)
        for column_name, column_type, column_options in (
            ("id", sqlalchemy.String, {"primary_key": True}),
            ("name", sqlalchemy.String, {}),
            ("status", sqlalchemy.String, {}),
            ("locked", sqlalchemy.Boolean, {}),
            ("duration", sqlalchemy.Integer, {}),
            ("created_at", sqlalchemy.DateTime, {}),
            ("title", sqlalchemy.String, {"nullable": True}),
            ("short_summary", sqlalchemy.String, {"nullable": True}),
            ("long_summary", sqlalchemy.String, {"nullable": True}),
            ("topics", sqlalchemy.JSON, {}),
            ("events", sqlalchemy.JSON, {}),
            ("source_language", sqlalchemy.String, {"nullable": True}),
            ("target_language", sqlalchemy.String, {"nullable": True}),
            # with user attached, optional
            ("user_id", sqlalchemy.String, {}),
        )
    ],
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_uuid4():
    """Return a newly generated random (version 4) UUID as a string."""
    new_id = uuid4()
    return str(new_id)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_transcript_name():
    """Return a default transcript name stamped with the current UTC time."""
    timestamp = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
    return f"Transcript {timestamp}"
|
||||||
|
|
||||||
|
|
||||||
|
class AudioWaveform(BaseModel):
    """Waveform of an audio recording, as a flat list of float samples."""

    # waveform values, in order
    data: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptText(BaseModel):
    """A piece of transcribed text with its optional translation."""

    # text as transcribed
    text: str
    # translated text; must be passed explicitly (no default), may be None
    translation: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptSegmentTopic(BaseModel):
    """A segment of text attributed to a speaker at a given timestamp."""

    # numeric identifier of the speaker
    speaker: int
    # text of the segment
    text: str
    # position of the segment (unit not visible here -- presumably seconds)
    timestamp: float
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptTopic(BaseModel):
    """A topic of a transcript: title, summary and the text/words it covers."""

    # unique identifier, generated automatically when not provided
    id: str = Field(default_factory=generate_uuid4)
    title: str
    summary: str
    # position of the topic (unit not visible here -- presumably seconds)
    timestamp: float
    # full text of the topic, when available
    text: str | None = None
    # words making up the topic, as produced by the processors
    words: list[ProcessorWord] = []
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptFinalShortSummary(BaseModel):
    """Event payload carrying the final short summary of a transcript."""

    short_summary: str
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptFinalLongSummary(BaseModel):
    """Event payload carrying the final long summary of a transcript."""

    long_summary: str
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptFinalTitle(BaseModel):
    """Event payload carrying the final title of a transcript."""

    title: str
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptEvent(BaseModel):
    """A named event recorded on a transcript, with its payload as a dict."""

    # event name, e.g. "TRANSCRIPT" or "TOPIC"
    event: str
    # dumped payload of the event
    data: dict
|
||||||
|
|
||||||
|
|
||||||
|
class Transcript(BaseModel):
    """
    Pydantic model of a transcript.

    Besides the stored fields, it offers helpers to record events and topics
    and to access the files kept under the transcript data directory
    (`settings.DATA_DIR/<id>`): the mp3 audio and its JSON waveform.
    """

    id: str = Field(default_factory=generate_uuid4)
    user_id: str | None = None
    name: str = Field(default_factory=generate_transcript_name)
    status: str = "idle"
    locked: bool = False
    duration: float = 0
    created_at: datetime = Field(default_factory=datetime.utcnow)
    title: str | None = None
    short_summary: str | None = None
    long_summary: str | None = None
    topics: list[TranscriptTopic] = []
    events: list[TranscriptEvent] = []
    source_language: str = "en"
    target_language: str = "en"

    def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
        """
        Append an event named `event` whose payload is `data.model_dump()`.

        Returns the created TranscriptEvent.
        """
        ev = TranscriptEvent(event=event, data=data.model_dump())
        self.events.append(ev)
        return ev

    def upsert_topic(self, topic: TranscriptTopic):
        """
        Add `topic`, or replace the stored topic that has the same id.
        """
        for index, existing_topic in enumerate(self.topics):
            if existing_topic.id == topic.id:
                # TranscriptTopic defines no `update_from()`, so the stored
                # entry is replaced instead of being updated in place
                self.topics[index] = topic
                return
        self.topics.append(topic)

    def events_dump(self, mode="json"):
        """Return every event dumped with `model_dump(mode=mode)`."""
        return [event.model_dump(mode=mode) for event in self.events]

    def topics_dump(self, mode="json"):
        """Return every topic dumped with `model_dump(mode=mode)`."""
        return [topic.model_dump(mode=mode) for topic in self.topics]

    def convert_audio_to_waveform(self, segments_count=256):
        """
        Compute the waveform of the mp3 audio and store it as JSON.

        Nothing is done (and None is returned) when the waveform file already
        exists; otherwise the computed waveform is returned.
        """
        fn = self.audio_waveform_filename
        if fn.exists():
            return
        waveform = get_audio_waveform(
            path=self.audio_mp3_filename, segments_count=segments_count
        )
        try:
            with open(fn, "w") as fd:
                json.dump(waveform, fd)
        except Exception:
            # remove the file if anything happens during the write, so that a
            # partial file is never mistaken for a valid waveform
            fn.unlink(missing_ok=True)
            raise
        return waveform

    def unlink(self):
        """
        Remove the data stored on disk for this transcript.

        `data_path` is the directory holding the audio files, so it is removed
        recursively; a missing path is ignored.
        """
        import shutil

        target = self.data_path
        if target.is_dir():
            # Path.unlink() cannot remove a directory
            shutil.rmtree(target)
        else:
            target.unlink(missing_ok=True)

    @property
    def data_path(self):
        """Directory of this transcript: `settings.DATA_DIR/<id>`."""
        return Path(settings.DATA_DIR) / self.id

    @property
    def audio_mp3_filename(self):
        """Path of the mp3 audio file inside the data directory."""
        return self.data_path / "audio.mp3"

    @property
    def audio_waveform_filename(self):
        """Path of the JSON waveform file inside the data directory."""
        return self.data_path / "audio.json"

    @property
    def audio_waveform(self):
        """
        Load the stored waveform as an AudioWaveform.

        Returns None (and deletes the file) when the file is not valid JSON.
        A missing file is not handled here: the error raised by `open()`
        propagates to the caller.
        """
        try:
            with open(self.audio_waveform_filename) as fd:
                data = json.load(fd)
        except json.JSONDecodeError:
            # unlink file if it's corrupted
            self.audio_waveform_filename.unlink(missing_ok=True)
            return None

        return AudioWaveform(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptController:
    """
    Database access layer for transcripts, built on the `transcripts` table
    and the shared `database` connection.
    """

    async def get_all(
        self,
        user_id: str | None = None,
        order_by: str | None = None,
        filter_empty: bool | None = True,
        filter_recording: bool | None = True,
    ) -> list[Transcript]:
        """
        Get all transcripts

        If `user_id` is specified, only return transcripts that belong to the user.
        Otherwise, return all anonymous transcripts.

        Parameters:
        - `order_by`: field to order by, optionally prefixed with "-" for a
          descending order, e.g. "-created_at"
        - `filter_empty`: filter out empty transcripts
        - `filter_recording`: filter out transcripts that are currently recording

        Raises AttributeError if `order_by` does not name a column of the table.

        NOTE(review): the rows are returned as fetched from the database, they
        are not converted to `Transcript` instances.
        """
        query = transcripts.select().where(transcripts.c.user_id == user_id)

        if order_by is not None:
            # only strip the leading "-" when it is present, otherwise an
            # ascending "created_at" would be looked up as "reated_at"
            descending = order_by.startswith("-")
            field_name = order_by[1:] if descending else order_by
            field = getattr(transcripts.c, field_name)
            if descending:
                field = field.desc()
            query = query.order_by(field)

        if filter_empty:
            query = query.filter(transcripts.c.status != "idle")

        if filter_recording:
            query = query.filter(transcripts.c.status != "recording")

        results = await database.fetch_all(query)
        return results

    async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
        """
        Get a transcript by id

        When a `user_id` keyword is given (even as None), the transcript must
        also match that user id. Returns None when nothing matches.
        """
        query = transcripts.select().where(transcripts.c.id == transcript_id)
        if "user_id" in kwargs:
            query = query.where(transcripts.c.user_id == kwargs["user_id"])
        result = await database.fetch_one(query)
        if not result:
            return None
        return Transcript(**result)

    async def add(
        self,
        name: str,
        source_language: str = "en",
        target_language: str = "en",
        user_id: str | None = None,
    ):
        """
        Add a new transcript

        Builds a Transcript with the given values, inserts it in the database
        and returns it.
        """
        transcript = Transcript(
            name=name,
            source_language=source_language,
            target_language=target_language,
            user_id=user_id,
        )
        query = transcripts.insert().values(**transcript.model_dump())
        await database.execute(query)
        return transcript

    async def update(self, transcript: Transcript, values: dict):
        """
        Update a transcript fields with key/values in values

        The database row is updated first, then the same values are set on the
        `transcript` instance.
        """
        query = (
            transcripts.update()
            .where(transcripts.c.id == transcript.id)
            .values(**values)
        )
        await database.execute(query)
        for key, value in values.items():
            setattr(transcript, key, value)

    async def remove_by_id(
        self,
        transcript_id: str,
        user_id: str | None = None,
    ) -> None:
        """
        Remove a transcript by id

        The transcript is looked up with `user_id` as an additional filter;
        nothing happens when it is not found. Its files are removed before the
        database row is deleted.
        """
        transcript = await self.get_by_id(transcript_id, user_id=user_id)
        if not transcript:
            return
        if user_id is not None and transcript.user_id != user_id:
            return
        transcript.unlink()
        query = transcripts.delete().where(transcripts.c.id == transcript_id)
        await database.execute(query)

    @asynccontextmanager
    async def transaction(self):
        """
        A context manager for database transaction
        """
        async with database.transaction():
            yield

    async def append_event(
        self,
        transcript: Transcript,
        event: str,
        data: Any,
    ) -> TranscriptEvent:
        """
        Append an event to a transcript

        The event is added to the instance, then the whole dumped event list
        is written to the database. Returns the created event.
        """
        resp = transcript.add_event(event=event, data=data)
        await self.update(transcript, {"events": transcript.events_dump()})
        return resp


# module-level controller instance shared by importers
transcripts_controller = TranscriptController()
|
||||||
230
server/reflector/pipelines/main_live_pipeline.py
Normal file
230
server/reflector/pipelines/main_live_pipeline.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""
|
||||||
|
Main reflector pipeline for live streaming
|
||||||
|
==========================================
|
||||||
|
|
||||||
|
This is the default pipeline used in the API.
|
||||||
|
|
||||||
|
It is split into two stages:
|
||||||
|
- PipelineMainLive: performs limited processing during the live session
|
||||||
|
- PipelineMainPost: does the heavy lifting after the live session has ended
|
||||||
|
|
||||||
|
It is directly linked to our data model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from reflector.db.transcripts import (
|
||||||
|
Transcript,
|
||||||
|
TranscriptFinalLongSummary,
|
||||||
|
TranscriptFinalShortSummary,
|
||||||
|
TranscriptFinalTitle,
|
||||||
|
TranscriptText,
|
||||||
|
TranscriptTopic,
|
||||||
|
transcripts_controller,
|
||||||
|
)
|
||||||
|
from reflector.pipelines.runner import PipelineRunner
|
||||||
|
from reflector.processors import (
|
||||||
|
AudioChunkerProcessor,
|
||||||
|
AudioFileWriterProcessor,
|
||||||
|
AudioMergeProcessor,
|
||||||
|
AudioTranscriptAutoProcessor,
|
||||||
|
BroadcastProcessor,
|
||||||
|
Pipeline,
|
||||||
|
TranscriptFinalLongSummaryProcessor,
|
||||||
|
TranscriptFinalShortSummaryProcessor,
|
||||||
|
TranscriptFinalTitleProcessor,
|
||||||
|
TranscriptLinerProcessor,
|
||||||
|
TranscriptTopicDetectorProcessor,
|
||||||
|
TranscriptTranslatorProcessor,
|
||||||
|
)
|
||||||
|
from reflector.tasks.worker import celery
|
||||||
|
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_to_socket(func):
    """
    Decorator to broadcast transcript event to websockets
    concerning this transcript

    The decorated coroutine method is awaited first. When it returns a value,
    `value.model_dump(mode="json")` is sent to the room `self.ws_room_id`
    through `self.ws_manager`; when it returns None, nothing is sent.
    The wrapper itself always returns None.
    """
    import functools

    # keep the name and docstring of the wrapped callback
    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        resp = await func(self, *args, **kwargs)
        if resp is None:
            return
        await self.ws_manager.send_json(
            room_id=self.ws_room_id,
            message=resp.model_dump(mode="json"),
        )

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineMainBase(PipelineRunner):
    """
    Common base of the main pipelines: ties a runner to one transcript and to
    the websocket room used to broadcast its events.
    """

    transcript_id: str
    ws_room_id: str | None = None
    ws_manager: WebsocketManager | None = None

    def prepare(self):
        """Set up the websocket room id and manager for this transcript."""
        # prepare websocket
        room_id = f"ts:{self.transcript_id}"
        self.ws_room_id = room_id
        self.ws_manager = get_ws_manager()

    async def get_transcript(self) -> Transcript:
        """
        Fetch the transcript this pipeline works on.

        Raises an Exception when the transcript cannot be found.
        """
        transcript = await transcripts_controller.get_by_id(
            transcript_id=self.transcript_id
        )
        if transcript:
            return transcript
        raise Exception("Transcript not found")
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineMainLive(PipelineMainBase):
    """
    Live part of the main pipeline.

    Writes the audio to the transcript mp3 file, transcribes and translates
    it, detects topics, and records/broadcasts the TRANSCRIPT and TOPIC
    events. When the runner ends, the post pipeline task is scheduled.
    """

    audio_filename: Path | None = None
    source_language: str = "en"
    target_language: str = "en"

    @broadcast_to_socket
    async def on_transcript(self, data):
        """Record a TRANSCRIPT event built from `data.text` / `data.translation`."""
        async with transcripts_controller.transaction():
            transcript = await self.get_transcript()
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="TRANSCRIPT",
                data=TranscriptText(text=data.text, translation=data.translation),
            )

    @broadcast_to_socket
    async def on_topic(self, data):
        """Record a TOPIC event built from the detected topic in `data`."""
        topic = TranscriptTopic(
            title=data.title,
            summary=data.summary,
            timestamp=data.timestamp,
            text=data.transcript.text,
            words=data.transcript.words,
        )
        async with transcripts_controller.transaction():
            transcript = await self.get_transcript()
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="TOPIC",
                data=topic,
            )

    async def create(self) -> Pipeline:
        """
        Build the live pipeline for the transcript and return it.

        Raises an Exception (from `get_transcript`) if the transcript does not
        exist.
        """
        # create a context for the whole rtc transaction
        # add a customised logger to the context
        self.prepare()
        transcript = await self.get_transcript()

        processors = [
            AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
            AudioChunkerProcessor(),
            AudioMergeProcessor(),
            AudioTranscriptAutoProcessor.as_threaded(),
            TranscriptLinerProcessor(),
            TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
            TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
        ]
        pipeline = Pipeline(*processors)
        pipeline.options = self
        pipeline.set_pref("audio:source_language", transcript.source_language)
        pipeline.set_pref("audio:target_language", transcript.target_language)

        # when the pipeline ends, connect to the post pipeline
        async def on_ended():
            task_pipeline_main_post.delay(transcript_id=self.transcript_id)

        # The runner awaits `self.on_ended()` once it has finished, so the
        # callback is registered on the runner. The previous code assigned the
        # runner itself (`pipeline.on_ended = self`) and never used `on_ended`.
        # NOTE(review): confirm Pipeline has no `on_ended` hook of its own.
        self.on_ended = on_ended

        return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineMainPost(PipelineMainBase):
    """
    Implement the rest of the main pipeline, triggered after PipelineMainLive ended.

    Computes the final title and the final long/short summaries, stores them
    on the transcript and records/broadcasts the matching FINAL_* events.
    """

    @broadcast_to_socket
    async def on_final_title(self, data):
        """Store the final title (unless one is already set) and record FINAL_TITLE."""
        final_title = TranscriptFinalTitle(title=data.title)
        async with transcripts_controller.transaction():
            transcript = await self.get_transcript()
            if not transcript.title:
                # update() is a coroutine: it must be awaited, and it must be
                # given the transcript fetched above (the runner has no
                # `transcript` attribute)
                await transcripts_controller.update(
                    transcript,
                    {
                        "title": final_title.title,
                    },
                )
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="FINAL_TITLE",
                data=final_title,
            )

    @broadcast_to_socket
    async def on_final_long_summary(self, data):
        """Store the final long summary and record FINAL_LONG_SUMMARY."""
        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
        async with transcripts_controller.transaction():
            transcript = await self.get_transcript()
            await transcripts_controller.update(
                transcript,
                {
                    "long_summary": final_long_summary.long_summary,
                },
            )
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="FINAL_LONG_SUMMARY",
                data=final_long_summary,
            )

    @broadcast_to_socket
    async def on_final_short_summary(self, data):
        """Store the final short summary and record FINAL_SHORT_SUMMARY."""
        final_short_summary = TranscriptFinalShortSummary(
            short_summary=data.short_summary
        )
        async with transcripts_controller.transaction():
            transcript = await self.get_transcript()
            await transcripts_controller.update(
                transcript,
                {
                    "short_summary": final_short_summary.short_summary,
                },
            )
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="FINAL_SHORT_SUMMARY",
                data=final_short_summary,
            )

    async def create(self) -> Pipeline:
        """Build the post pipeline (final title and summaries) and return it."""
        # create a context for the whole rtc transaction
        # add a customised logger to the context
        self.prepare()
        processors = [
            # add diarization
            BroadcastProcessor(
                processors=[
                    TranscriptFinalTitleProcessor.as_threaded(
                        callback=self.on_final_title
                    ),
                    TranscriptFinalLongSummaryProcessor.as_threaded(
                        callback=self.on_final_long_summary
                    ),
                    TranscriptFinalShortSummaryProcessor.as_threaded(
                        callback=self.on_final_short_summary
                    ),
                ]
            ),
        ]
        pipeline = Pipeline(*processors)
        pipeline.options = self

        return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
@celery.task
def task_pipeline_main_post(transcript_id: str):
    """Celery task for the post pipeline of a transcript (currently a no-op)."""
    return None
|
||||||
117
server/reflector/pipelines/runner.py
Normal file
117
server/reflector/pipelines/runner.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""
|
||||||
|
Pipeline Runner
|
||||||
|
===============
|
||||||
|
|
||||||
|
Pipeline runner designed to be executed in an asyncio task.
|
||||||
|
|
||||||
|
It is meant to be subclassed, and implement a create() method
|
||||||
|
that exposes/returns a Pipeline instance.
|
||||||
|
|
||||||
|
During its lifecycle, it will emit the following status:
|
||||||
|
- started: the pipeline has been started
|
||||||
|
- push: the pipeline received its first data
|
||||||
|
- flush: the pipeline is flushing
|
||||||
|
- ended: the pipeline has ended
|
||||||
|
- error: the pipeline has ended with an error
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from reflector.logger import logger
|
||||||
|
from reflector.processors import Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineRunner(BaseModel):
    """
    Run a Pipeline from an asyncio task, driven by queued commands.

    Commands (PUSH, FLUSH) are queued with `push()` / `flush()` and executed
    in order by `run()`. Status changes ("init", "started", "push", "flush",
    "ended", "error") are stored in `status` and reported to the optional
    `on_status` coroutine callback. The optional `on_ended` coroutine callback
    is awaited once the runner has finished, either after a flush or after an
    error.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    status: str = "idle"
    on_status: Callable | None = None
    on_ended: Callable | None = None
    pipeline: Pipeline | None = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._q_cmd = asyncio.Queue()
        self._ev_done = asyncio.Event()
        self._is_first_push = True
        # reference to the task created by start()
        self._task = None

    async def create(self) -> Pipeline:
        """
        Create the pipeline if not specified earlier.
        Should be implemented in a subclass

        It is a coroutine because `run()` awaits it.
        """
        raise NotImplementedError()

    def start(self):
        """
        Start the pipeline as a coroutine task
        """
        # keep a reference to the task: the event loop only holds a weak one,
        # so an unreferenced task may be garbage collected before it is done
        self._task = asyncio.get_event_loop().create_task(self.run())

    async def push(self, data):
        """
        Push data to the pipeline
        """
        await self._add_cmd("PUSH", data)

    async def flush(self):
        """
        Flush the pipeline
        """
        await self._add_cmd("FLUSH", None)

    async def _add_cmd(self, cmd: str, data):
        """
        Enqueue a command to be executed in the runner.
        Currently supported commands: PUSH, FLUSH
        """
        await self._q_cmd.put([cmd, data])

    async def _set_status(self, status):
        """
        Store the new status and report it to `on_status`, if any.

        An error raised by the callback is logged, never propagated.
        """
        logger.debug("PipelineRunner status", status=status)
        self.status = status
        if self.on_status:
            try:
                await self.on_status(status)
            except Exception as e:
                logger.error("PipelineRunner status_callback error", error=e)

    async def _notify_ended(self):
        """
        Await `on_ended`, if any.

        An error raised by the callback is logged, never propagated, so that a
        failing callback cannot push a successfully ended runner into the error
        path (which would invoke the callback a second time).
        """
        if not self.on_ended:
            return
        try:
            await self.on_ended()
        except Exception as e:
            logger.error("PipelineRunner on_ended error", error=e)

    async def run(self):
        """
        Create the pipeline if needed, then execute the queued commands until
        a flush completes or an error occurs.

        Any error is logged and reported with the "error" status.
        """
        try:
            # create the pipeline if not yet done
            await self._set_status("init")
            self._is_first_push = True
            if not self.pipeline:
                self.pipeline = await self.create()

            # start the loop
            await self._set_status("started")
            while not self._ev_done.is_set():
                cmd, data = await self._q_cmd.get()
                # default to None so that an unsupported command reaches the
                # explicit error below instead of an AttributeError
                func = getattr(self, f"cmd_{cmd.lower()}", None)
                if func is None:
                    raise Exception(f"Unknown command {cmd}")
                await func(data)
        except Exception as e:
            logger.error("PipelineRunner error", error=e)
            await self._set_status("error")
            self._ev_done.set()
            await self._notify_ended()

    async def cmd_push(self, data):
        """Push `data` to the pipeline; the first push sets the "push" status."""
        if self._is_first_push:
            await self._set_status("push")
            self._is_first_push = False
        await self.pipeline.push(data)

    async def cmd_flush(self, data):
        """Flush the pipeline, mark the runner as ended and notify `on_ended`."""
        await self._set_status("flush")
        await self.pipeline.flush()
        await self._set_status("ended")
        self._ev_done.set()
        await self._notify_ended()
|
||||||
@@ -3,7 +3,13 @@ from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
|||||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||||
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
|
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
|
||||||
from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor # noqa: F401
|
from .base import ( # noqa: F401
|
||||||
|
BroadcastProcessor,
|
||||||
|
Pipeline,
|
||||||
|
PipelineEvent,
|
||||||
|
Processor,
|
||||||
|
ThreadedProcessor,
|
||||||
|
)
|
||||||
from .transcript_final_long_summary import ( # noqa: F401
|
from .transcript_final_long_summary import ( # noqa: F401
|
||||||
TranscriptFinalLongSummaryProcessor,
|
TranscriptFinalLongSummaryProcessor,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from json import dumps, loads
|
from json import loads
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import av
|
import av
|
||||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||||
@@ -11,25 +9,7 @@ from prometheus_client import Gauge
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from reflector.events import subscribers_shutdown
|
from reflector.events import subscribers_shutdown
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.processors import (
|
from reflector.pipelines.runner import PipelineRunner
|
||||||
AudioChunkerProcessor,
|
|
||||||
AudioFileWriterProcessor,
|
|
||||||
AudioMergeProcessor,
|
|
||||||
AudioTranscriptAutoProcessor,
|
|
||||||
FinalLongSummary,
|
|
||||||
FinalShortSummary,
|
|
||||||
Pipeline,
|
|
||||||
TitleSummary,
|
|
||||||
Transcript,
|
|
||||||
TranscriptFinalLongSummaryProcessor,
|
|
||||||
TranscriptFinalShortSummaryProcessor,
|
|
||||||
TranscriptFinalTitleProcessor,
|
|
||||||
TranscriptLinerProcessor,
|
|
||||||
TranscriptTopicDetectorProcessor,
|
|
||||||
TranscriptTranslatorProcessor,
|
|
||||||
)
|
|
||||||
from reflector.processors.base import BroadcastProcessor
|
|
||||||
from reflector.processors.types import FinalTitle
|
|
||||||
|
|
||||||
sessions = []
|
sessions = []
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -85,121 +65,10 @@ class PipelineEvent(StrEnum):
|
|||||||
FINAL_TITLE = "FINAL_TITLE"
|
FINAL_TITLE = "FINAL_TITLE"
|
||||||
|
|
||||||
|
|
||||||
class PipelineOptions(BaseModel):
    """Options and callbacks used to build the live pipeline."""

    # where to write the audio; no audio file is written when None
    audio_filename: Path | None = None
    source_language: str = "en"
    target_language: str = "en"

    # optional callbacks, one per kind of pipeline output
    on_transcript: Callable | None = None
    on_topic: Callable | None = None
    on_final_title: Callable | None = None
    on_final_short_summary: Callable | None = None
    on_final_long_summary: Callable | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineRunner(object):
    """
    Pipeline runner designed to be executed in an asyncio task

    Commands (PUSH, FLUSH) are queued with `push()` / `flush()` and executed
    in order by `run()`. Status changes are reported to the optional
    `status_callback` coroutine.
    """

    def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None):
        self.pipeline = pipeline
        self.q_cmd = asyncio.Queue()
        self.ev_done = asyncio.Event()
        self.status = "idle"
        self.status_callback = status_callback

    async def update_status(self, status):
        """Store the new status and report it; callback errors are only logged."""
        print("update_status", status)
        self.status = status
        if self.status_callback:
            try:
                await self.status_callback(status)
            except Exception as e:
                logger.error("PipelineRunner status_callback error", error=e)

    async def add_cmd(self, cmd: str, data):
        """Enqueue a command (PUSH or FLUSH) with its data."""
        await self.q_cmd.put([cmd, data])

    async def push(self, data):
        """Queue `data` to be pushed to the pipeline."""
        await self.add_cmd("PUSH", data)

    async def flush(self):
        """Queue a flush of the pipeline."""
        await self.add_cmd("FLUSH", None)

    async def run(self):
        """Execute queued commands until a flush completes or an error occurs."""
        try:
            await self.update_status("running")
            while not self.ev_done.is_set():
                cmd, data = await self.q_cmd.get()
                # default to None so that an unsupported command reaches the
                # explicit error below instead of an AttributeError
                func = getattr(self, f"cmd_{cmd.lower()}", None)
                if func is None:
                    raise Exception(f"Unknown command {cmd}")
                await func(data)
        except Exception as e:
            await self.update_status("error")
            logger.error("PipelineRunner error", error=e)

    async def cmd_push(self, data):
        """Push `data` to the pipeline; switch from "idle" to "recording" first."""
        if self.status == "idle":
            await self.update_status("recording")
        await self.pipeline.push(data)

    async def cmd_flush(self, data):
        """Flush the pipeline and mark the runner as ended."""
        await self.update_status("processing")
        await self.pipeline.flush()
        await self.update_status("ended")
        self.ev_done.set()

    def start(self):
        """Start `run()` as a task on the current event loop."""
        print("start task")
        asyncio.get_event_loop().create_task(self.run())
|
|
||||||
|
|
||||||
|
|
||||||
async def pipeline_live_create(options: PipelineOptions):
    """
    Build the live pipeline described by `options` and return it.

    The audio file writer is only included when `options.audio_filename` is
    set; the callbacks of `options` are attached to the matching processors.
    """
    # create a context for the whole rtc transaction
    # add a customised logger to the context
    stages = []
    if options.audio_filename is not None:
        stages.append(AudioFileWriterProcessor(path=options.audio_filename))
    stages.extend(
        [
            AudioChunkerProcessor(),
            AudioMergeProcessor(),
            AudioTranscriptAutoProcessor.as_threaded(),
            TranscriptLinerProcessor(),
            TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript),
            TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic),
            BroadcastProcessor(
                processors=[
                    TranscriptFinalTitleProcessor.as_threaded(
                        callback=options.on_final_title
                    ),
                    TranscriptFinalLongSummaryProcessor.as_threaded(
                        callback=options.on_final_long_summary
                    ),
                    TranscriptFinalShortSummaryProcessor.as_threaded(
                        callback=options.on_final_short_summary
                    ),
                ]
            ),
        ]
    )

    live_pipeline = Pipeline(*stages)
    live_pipeline.options = options
    for pref_name, pref_value in (
        ("audio:source_language", options.source_language),
        ("audio:target_language", options.target_language),
    ):
        live_pipeline.set_pref(pref_name, pref_value)

    return live_pipeline
|
|
||||||
|
|
||||||
|
|
||||||
async def rtc_offer_base(
|
async def rtc_offer_base(
|
||||||
params: RtcOffer,
|
params: RtcOffer,
|
||||||
request: Request,
|
request: Request,
|
||||||
event_callback=None,
|
pipeline_runner: PipelineRunner,
|
||||||
event_callback_args=None,
|
|
||||||
audio_filename: Path | None = None,
|
|
||||||
source_language: str = "en",
|
|
||||||
target_language: str = "en",
|
|
||||||
):
|
):
|
||||||
# build an rtc session
|
# build an rtc session
|
||||||
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
|
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
|
||||||
@@ -209,132 +78,9 @@ async def rtc_offer_base(
|
|||||||
clientid = f"{peername[0]}:{peername[1]}"
|
clientid = f"{peername[0]}:{peername[1]}"
|
||||||
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
||||||
|
|
||||||
async def update_status(status: str):
|
|
||||||
changed = ctx.status != status
|
|
||||||
if changed:
|
|
||||||
ctx.status = status
|
|
||||||
if event_callback:
|
|
||||||
await event_callback(
|
|
||||||
event=PipelineEvent.STATUS,
|
|
||||||
args=event_callback_args,
|
|
||||||
data=StrValue(value=status),
|
|
||||||
)
|
|
||||||
|
|
||||||
# build pipeline callback
|
|
||||||
async def on_transcript(transcript: Transcript):
|
|
||||||
ctx.logger.info("Transcript", transcript=transcript)
|
|
||||||
|
|
||||||
# send to RTC
|
|
||||||
if ctx.data_channel.readyState == "open":
|
|
||||||
result = {
|
|
||||||
"cmd": "SHOW_TRANSCRIPTION",
|
|
||||||
"text": transcript.text,
|
|
||||||
}
|
|
||||||
ctx.data_channel.send(dumps(result))
|
|
||||||
|
|
||||||
# send to callback (eg. websocket)
|
|
||||||
if event_callback:
|
|
||||||
await event_callback(
|
|
||||||
event=PipelineEvent.TRANSCRIPT,
|
|
||||||
args=event_callback_args,
|
|
||||||
data=transcript,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_topic(topic: TitleSummary):
|
|
||||||
# FIXME: make it incremental with the frontend, not send everything
|
|
||||||
ctx.logger.info("Topic", topic=topic)
|
|
||||||
ctx.topics.append(
|
|
||||||
{
|
|
||||||
"title": topic.title,
|
|
||||||
"timestamp": topic.timestamp,
|
|
||||||
"transcript": topic.transcript.text,
|
|
||||||
"desc": topic.summary,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# send to RTC
|
|
||||||
if ctx.data_channel.readyState == "open":
|
|
||||||
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
|
|
||||||
ctx.data_channel.send(dumps(result))
|
|
||||||
|
|
||||||
# send to callback (eg. websocket)
|
|
||||||
if event_callback:
|
|
||||||
await event_callback(
|
|
||||||
event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_final_short_summary(summary: FinalShortSummary):
|
|
||||||
ctx.logger.info("FinalShortSummary", final_short_summary=summary)
|
|
||||||
|
|
||||||
# send to RTC
|
|
||||||
if ctx.data_channel.readyState == "open":
|
|
||||||
result = {
|
|
||||||
"cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
|
|
||||||
"summary": summary.short_summary,
|
|
||||||
"duration": summary.duration,
|
|
||||||
}
|
|
||||||
ctx.data_channel.send(dumps(result))
|
|
||||||
|
|
||||||
# send to callback (eg. websocket)
|
|
||||||
if event_callback:
|
|
||||||
await event_callback(
|
|
||||||
event=PipelineEvent.FINAL_SHORT_SUMMARY,
|
|
||||||
args=event_callback_args,
|
|
||||||
data=summary,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_final_long_summary(summary: FinalLongSummary):
|
|
||||||
ctx.logger.info("FinalLongSummary", final_summary=summary)
|
|
||||||
|
|
||||||
# send to RTC
|
|
||||||
if ctx.data_channel.readyState == "open":
|
|
||||||
result = {
|
|
||||||
"cmd": "DISPLAY_FINAL_LONG_SUMMARY",
|
|
||||||
"summary": summary.long_summary,
|
|
||||||
"duration": summary.duration,
|
|
||||||
}
|
|
||||||
ctx.data_channel.send(dumps(result))
|
|
||||||
|
|
||||||
# send to callback (eg. websocket)
|
|
||||||
if event_callback:
|
|
||||||
await event_callback(
|
|
||||||
event=PipelineEvent.FINAL_LONG_SUMMARY,
|
|
||||||
args=event_callback_args,
|
|
||||||
data=summary,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_final_title(title: FinalTitle):
|
|
||||||
ctx.logger.info("FinalTitle", final_title=title)
|
|
||||||
|
|
||||||
# send to RTC
|
|
||||||
if ctx.data_channel.readyState == "open":
|
|
||||||
result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
|
|
||||||
ctx.data_channel.send(dumps(result))
|
|
||||||
|
|
||||||
# send to callback (eg. websocket)
|
|
||||||
if event_callback:
|
|
||||||
await event_callback(
|
|
||||||
event=PipelineEvent.FINAL_TITLE,
|
|
||||||
args=event_callback_args,
|
|
||||||
data=title,
|
|
||||||
)
|
|
||||||
|
|
||||||
# handle RTC peer connection
|
# handle RTC peer connection
|
||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
|
ctx.pipeline_runner = pipeline_runner
|
||||||
# create pipeline
|
|
||||||
options = PipelineOptions(
|
|
||||||
audio_filename=audio_filename,
|
|
||||||
source_language=source_language,
|
|
||||||
target_language=target_language,
|
|
||||||
on_transcript=on_transcript,
|
|
||||||
on_topic=on_topic,
|
|
||||||
on_final_short_summary=on_final_short_summary,
|
|
||||||
on_final_long_summary=on_final_long_summary,
|
|
||||||
on_final_title=on_final_title,
|
|
||||||
)
|
|
||||||
pipeline = await pipeline_live_create(options)
|
|
||||||
ctx.pipeline_runner = PipelineRunner(pipeline, update_status)
|
|
||||||
ctx.pipeline_runner.start()
|
ctx.pipeline_runner.start()
|
||||||
|
|
||||||
async def flush_pipeline_and_quit(close=True):
|
async def flush_pipeline_and_quit(close=True):
|
||||||
@@ -400,8 +146,3 @@ async def rtc_clean_sessions(_):
|
|||||||
logger.debug(f"Closing session {pc}")
|
logger.debug(f"Closing session {pc}")
|
||||||
await pc.close()
|
await pc.close()
|
||||||
sessions.clear()
|
sessions.clear()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/offer")
async def rtc_offer(params: RtcOffer, request: Request):
    """Handle ``POST /offer`` by delegating to ``rtc_offer_base``.

    Only the offer and the request are forwarded; every other argument of
    ``rtc_offer_base`` keeps its default.
    """
    answer = await rtc_offer_base(params, request)
    return answer
|
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
import json
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
@@ -15,12 +12,13 @@ from fastapi import (
|
|||||||
)
|
)
|
||||||
from fastapi_pagination import Page, paginate
|
from fastapi_pagination import Page, paginate
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from reflector.db import database, transcripts
|
from reflector.db.transcripts import (
|
||||||
from reflector.logger import logger
|
AudioWaveform,
|
||||||
|
TranscriptTopic,
|
||||||
|
transcripts_controller,
|
||||||
|
)
|
||||||
from reflector.processors.types import Transcript as ProcessorTranscript
|
from reflector.processors.types import Transcript as ProcessorTranscript
|
||||||
from reflector.processors.types import Word as ProcessorWord
|
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.utils.audio_waveform import get_audio_waveform
|
|
||||||
from reflector.ws_manager import get_ws_manager
|
from reflector.ws_manager import get_ws_manager
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
|
||||||
@@ -30,216 +28,6 @@ from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
ws_manager = get_ws_manager()
|
ws_manager = get_ws_manager()
|
||||||
|
|
||||||
# ==============================================================
|
|
||||||
# Models to move to a database, but required for the API to work
|
|
||||||
# ==============================================================
|
|
||||||
|
|
||||||
|
|
||||||
def generate_uuid4():
    """Return a freshly generated random UUID (version 4) as a string."""
    identifier = uuid4()
    return str(identifier)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_transcript_name():
    """Return a default transcript name stamped with the current UTC time.

    The result looks like ``"Transcript 2023-10-26 14:03:59"``.
    """
    # Imported locally so this helper does not depend on the module's import
    # block; ``datetime.utcnow()`` is deprecated since Python 3.12, the aware
    # clock below formats to exactly the same string.
    from datetime import timezone

    now = datetime.now(timezone.utc)
    return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
|
||||||
|
|
||||||
|
|
||||||
class AudioWaveform(BaseModel):
    """Waveform values of a transcript's audio.

    Built by ``Transcript.audio_waveform`` from the content of the
    transcript's ``audio.json`` file.
    """

    # the waveform values, as loaded from the JSON file
    data: list[float]
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptText(BaseModel):
    """Event payload stored for a ``PipelineEvent.TRANSCRIPT`` event.

    Carries only the text and its translation, not the full pipeline
    transcript.
    """

    # transcribed text
    text: str
    # translated text; may be None
    translation: str | None
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptSegmentTopic(BaseModel):
    """A piece of text together with a speaker number and a timestamp."""

    # NOTE(review): presumably an index identifying the speaker — confirm
    speaker: int
    text: str
    # NOTE(review): unit is not visible here (presumably seconds) — confirm
    timestamp: float
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptTopic(BaseModel):
    """A topic attached to a transcript.

    Created from ``PipelineEvent.TOPIC`` events and stored in
    ``Transcript.topics``.
    """

    # unique identifier, a random UUID4 string unless one is provided
    id: str = Field(default_factory=generate_uuid4)
    title: str
    summary: str
    # NOTE(review): unit is not visible here (presumably seconds) — confirm
    timestamp: float
    # text of the transcript portion this topic was detected in
    text: str | None = None
    # words of that transcript portion
    words: list[ProcessorWord] = []
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptFinalShortSummary(BaseModel):
    """Event payload stored for a ``PipelineEvent.FINAL_SHORT_SUMMARY`` event."""

    short_summary: str
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptFinalLongSummary(BaseModel):
    """Event payload stored for a ``PipelineEvent.FINAL_LONG_SUMMARY`` event."""

    long_summary: str
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptFinalTitle(BaseModel):
    """Event payload stored for a ``PipelineEvent.FINAL_TITLE`` event."""

    title: str
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptEvent(BaseModel):
    """One entry of a transcript's event history.

    Instances are created by ``Transcript.add_event``.
    """

    # event name
    event: str
    # dumped (``model_dump()``) payload of the event
    data: dict
|
|
||||||
|
|
||||||
|
|
||||||
class Transcript(BaseModel):
    """A transcript record together with helpers for its on-disk data.

    Besides the stored fields, the model keeps the event history and the
    topic list, and knows where its data lives under ``settings.DATA_DIR``
    (one directory per transcript id, holding ``audio.mp3`` and the cached
    waveform ``audio.json``).
    """

    id: str = Field(default_factory=generate_uuid4)
    user_id: str | None = None
    name: str = Field(default_factory=generate_transcript_name)
    status: str = "idle"
    locked: bool = False
    duration: float = 0
    created_at: datetime = Field(default_factory=datetime.utcnow)
    title: str | None = None
    short_summary: str | None = None
    long_summary: str | None = None
    topics: list[TranscriptTopic] = []
    events: list[TranscriptEvent] = []
    source_language: str = "en"
    target_language: str = "en"

    def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
        """Append an event built from ``data.model_dump()`` and return it."""
        ev = TranscriptEvent(event=event, data=data.model_dump())
        self.events.append(ev)
        return ev

    def upsert_topic(self, topic: TranscriptTopic):
        """Replace the topic sharing ``topic.id``, or append ``topic``.

        The stored topic is replaced in place (same list position), because
        ``TranscriptTopic`` offers no method to update an existing instance.
        """
        for index, existing_topic in enumerate(self.topics):
            if existing_topic.id == topic.id:
                self.topics[index] = topic
                return
        self.topics.append(topic)

    def events_dump(self, mode="json"):
        """Return every event dumped with ``model_dump(mode=mode)``."""
        return [event.model_dump(mode=mode) for event in self.events]

    def topics_dump(self, mode="json"):
        """Return every topic dumped with ``model_dump(mode=mode)``."""
        return [topic.model_dump(mode=mode) for topic in self.topics]

    def convert_audio_to_waveform(self, segments_count=256):
        """Compute the waveform of the mp3 file and cache it as JSON.

        Nothing is done, and ``None`` is returned, when the waveform file
        already exists. Otherwise the waveform is computed, written to
        ``audio_waveform_filename`` and returned.
        """
        fn = self.audio_waveform_filename
        if fn.exists():
            return
        waveform = get_audio_waveform(
            path=self.audio_mp3_filename, segments_count=segments_count
        )
        try:
            with open(fn, "w") as fd:
                json.dump(waveform, fd)
        except Exception:
            # do not leave a partially written file behind
            fn.unlink(missing_ok=True)
            raise
        return waveform

    def unlink(self):
        """Delete the transcript's data from disk; a missing path is ignored.

        ``data_path`` is the directory holding ``audio.mp3`` and
        ``audio.json``, and ``Path.unlink`` cannot remove a directory, so a
        directory is removed recursively instead.
        """
        # local import: the module's import block may not provide shutil
        import shutil

        path = self.data_path
        if path.is_dir():
            shutil.rmtree(path)
        else:
            path.unlink(missing_ok=True)

    @property
    def data_path(self):
        """Location of this transcript's data: ``<DATA_DIR>/<id>``."""
        return Path(settings.DATA_DIR) / self.id

    @property
    def audio_mp3_filename(self):
        """Path of the audio file, ``audio.mp3`` under ``data_path``."""
        return self.data_path / "audio.mp3"

    @property
    def audio_waveform_filename(self):
        """Path of the cached waveform, ``audio.json`` under ``data_path``."""
        return self.data_path / "audio.json"

    @property
    def audio_waveform(self):
        """Load the cached waveform as an ``AudioWaveform``.

        Returns ``None`` when the file holds invalid JSON; the corrupted
        file is deleted in that case. Errors raised while opening the file
        (for instance when it does not exist) are not handled here.
        """
        try:
            with open(self.audio_waveform_filename) as fd:
                data = json.load(fd)
        except json.JSONDecodeError:
            # unlink the file if it is corrupted
            self.audio_waveform_filename.unlink(missing_ok=True)
            return None

        return AudioWaveform(data=data)
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptController:
    """Database access layer for the ``transcripts`` table."""

    async def get_all(
        self,
        user_id: str | None = None,
        order_by: str | None = None,
        filter_empty: bool | None = False,
        filter_recording: bool | None = False,
    ) -> list[Transcript]:
        """Fetch the transcripts whose ``user_id`` column equals ``user_id``.

        ``order_by`` is a column name, optionally prefixed with ``-`` for
        descending order (a leading ``+`` is accepted for ascending order).
        ``filter_empty`` leaves out transcripts whose status is ``"idle"``,
        ``filter_recording`` those whose status is ``"recording"``.

        The rows are returned as fetched from the database; they are not
        converted to ``Transcript`` instances.

        Raises ``AttributeError`` when ``order_by`` names an unknown column.
        """
        query = transcripts.select().where(transcripts.c.user_id == user_id)

        if order_by is not None:
            descending = order_by.startswith("-")
            # strip the direction prefix only when there is one, so that a
            # bare column name such as "created_at" is looked up unchanged
            has_prefix = order_by.startswith(("-", "+"))
            column_name = order_by[1:] if has_prefix else order_by
            field = getattr(transcripts.c, column_name)
            if descending:
                field = field.desc()
            query = query.order_by(field)

        if filter_empty:
            query = query.filter(transcripts.c.status != "idle")

        if filter_recording:
            query = query.filter(transcripts.c.status != "recording")

        results = await database.fetch_all(query)
        return results

    async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
        """Return the transcript with this id, or ``None`` if there is none.

        When a ``user_id`` keyword is passed (even as ``None``), the row must
        also match that user id.
        """
        query = transcripts.select().where(transcripts.c.id == transcript_id)
        if "user_id" in kwargs:
            query = query.where(transcripts.c.user_id == kwargs["user_id"])
        result = await database.fetch_one(query)
        if not result:
            return None
        return Transcript(**result)

    async def add(
        self,
        name: str,
        source_language: str = "en",
        target_language: str = "en",
        user_id: str | None = None,
    ):
        """Create a transcript, insert it in the database and return it."""
        transcript = Transcript(
            name=name,
            source_language=source_language,
            target_language=target_language,
            user_id=user_id,
        )
        query = transcripts.insert().values(**transcript.model_dump())
        await database.execute(query)
        return transcript

    async def update(self, transcript: Transcript, values: dict):
        """Write ``values`` to the transcript's row, then to ``transcript``.

        ``values`` maps column names to their new value; the same keys are
        set as attributes on the passed instance once the query has run.
        """
        query = (
            transcripts.update()
            .where(transcripts.c.id == transcript.id)
            .values(**values)
        )
        await database.execute(query)
        for key, value in values.items():
            setattr(transcript, key, value)

    async def remove_by_id(
        self, transcript_id: str, user_id: str | None = None
    ) -> None:
        """Delete a transcript's on-disk data and its database row.

        Nothing happens when no transcript matches both ``transcript_id``
        and ``user_id``.
        """
        transcript = await self.get_by_id(transcript_id, user_id=user_id)
        if not transcript:
            return
        if user_id is not None and transcript.user_id != user_id:
            return
        transcript.unlink()
        query = transcripts.delete().where(transcripts.c.id == transcript_id)
        await database.execute(query)
|
|
||||||
|
|
||||||
|
|
||||||
transcripts_controller = TranscriptController()
|
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
# Transcripts list
|
# Transcripts list
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
@@ -537,114 +325,6 @@ async def transcript_events_websocket(
|
|||||||
# ==============================================================
|
# ==============================================================
|
||||||
|
|
||||||
|
|
||||||
async def handle_rtc_event(event: PipelineEvent, args, data):
    """Run ``handle_rtc_event_once`` without letting its errors propagate.

    Any ``Exception`` is logged with its traceback and ``None`` is returned;
    otherwise the result of ``handle_rtc_event_once`` is returned.
    """
    try:
        outcome = await handle_rtc_event_once(event, args, data)
    except Exception:
        logger.exception("Error handling RTC event")
        return None
    return outcome
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_rtc_event_once(event: PipelineEvent, args, data):
    """Persist one pipeline event on its transcript and broadcast it.

    ``args`` is the transcript id. The transcript is loaded from the
    database; nothing happens when it cannot be found. The event is appended
    to the transcript's event list, the affected columns are saved, and the
    stored event is sent as JSON to the websocket room ``ts:<transcript_id>``.
    Unknown events are logged with a warning and otherwise ignored.
    """
    # NOTE: proof-of-concept persistence. The transcript is re-read from the
    # database and its whole event list rewritten for every single event.
    transcript_id = args
    transcript = await transcripts_controller.get_by_id(transcript_id)
    if not transcript:
        return

    # The event sent to websocket clients may differ from the one received
    # from the pipeline: the pipeline's TRANSCRIPT event carries every word,
    # which is not what the websocket clients should get.
    #
    # FIXME don't do copy
    if event == PipelineEvent.TRANSCRIPT:
        payload = TranscriptText(text=data.text, translation=data.translation)
        columns = {}
    elif event == PipelineEvent.TOPIC:
        payload = TranscriptTopic(
            title=data.title,
            summary=data.summary,
            timestamp=data.timestamp,
            text=data.transcript.text,
            words=data.transcript.words,
        )
        # "topics" is dumped further down, once the topic has been upserted
        columns = {}
    elif event == PipelineEvent.FINAL_TITLE:
        payload = TranscriptFinalTitle(title=data.title)
        columns = {"title": payload.title}
    elif event == PipelineEvent.FINAL_LONG_SUMMARY:
        payload = TranscriptFinalLongSummary(long_summary=data.long_summary)
        columns = {"long_summary": payload.long_summary}
    elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
        payload = TranscriptFinalShortSummary(short_summary=data.short_summary)
        columns = {"short_summary": payload.short_summary}
    elif event == PipelineEvent.STATUS:
        payload = data
        columns = {"status": data.value}
    else:
        logger.warning(f"Unknown event: {event}")
        return

    resp = transcript.add_event(event=event, data=payload)
    if event == PipelineEvent.TOPIC:
        transcript.upsert_topic(payload)
        columns = {"topics": transcript.topics_dump()}

    await transcripts_controller.update(
        transcript,
        {"events": transcript.events_dump(), **columns},
    )

    # transmit to websocket clients
    room_id = f"ts:{transcript_id}"
    await ws_manager.send_json(room_id, resp.model_dump(mode="json"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||||
async def transcript_record_webrtc(
|
async def transcript_record_webrtc(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
@@ -660,13 +340,14 @@ async def transcript_record_webrtc(
|
|||||||
if transcript.locked:
|
if transcript.locked:
|
||||||
raise HTTPException(status_code=400, detail="Transcript is locked")
|
raise HTTPException(status_code=400, detail="Transcript is locked")
|
||||||
|
|
||||||
|
# create a pipeline runner
|
||||||
|
from reflector.pipelines.main_live_pipeline import PipelineMainLive
|
||||||
|
|
||||||
|
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
|
||||||
|
|
||||||
# FIXME do not allow multiple recording at the same time
|
# FIXME do not allow multiple recording at the same time
|
||||||
return await rtc_offer_base(
|
return await rtc_offer_base(
|
||||||
params,
|
params,
|
||||||
request,
|
request,
|
||||||
event_callback=handle_rtc_event,
|
pipeline_runner=pipeline_runner,
|
||||||
event_callback_args=transcript_id,
|
|
||||||
audio_filename=transcript.audio_mp3_filename,
|
|
||||||
source_language=transcript.source_language,
|
|
||||||
target_language=transcript.target_language,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import json
|
|||||||
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
ws_manager = None
|
ws_manager = None
|
||||||
|
|
||||||
@@ -114,7 +115,8 @@ def get_ws_manager() -> WebsocketManager:
|
|||||||
RedisConnectionError: If there is an error connecting to the Redis server.
|
RedisConnectionError: If there is an error connecting to the Redis server.
|
||||||
"""
|
"""
|
||||||
global ws_manager
|
global ws_manager
|
||||||
from reflector.settings import settings
|
if ws_manager:
|
||||||
|
return ws_manager
|
||||||
|
|
||||||
pubsub_client = RedisPubSubManager(
|
pubsub_client = RedisPubSubManager(
|
||||||
host=settings.REDIS_HOST,
|
host=settings.REDIS_HOST,
|
||||||
|
|||||||
Reference in New Issue
Block a user