server: refactor with clearer pipeline instantiation, linked to the data model

2023-10-26 19:00:56 +02:00
committed by Mathieu Virbel
parent 433c0500cc
commit 1c42473da0
8 changed files with 658 additions and 616 deletions

View File

@@ -1,32 +1,13 @@
import databases
import sqlalchemy
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings
database = databases.Database(settings.DATABASE_URL)
metadata = sqlalchemy.MetaData()
transcripts = sqlalchemy.Table(
"transcript",
metadata,
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String),
sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Integer),
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
# optional owning user
sqlalchemy.Column("user_id", sqlalchemy.String),
)
# import models
import reflector.db.transcripts # noqa
engine = sqlalchemy.create_engine(
settings.DATABASE_URL, connect_args={"check_same_thread": False}
)

View File

@@ -0,0 +1,284 @@
import json
import shutil
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import Any
from uuid import uuid4
import sqlalchemy
from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
transcripts = sqlalchemy.Table(
"transcript",
metadata,
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String),
sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Integer),
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
# optional owning user
sqlalchemy.Column("user_id", sqlalchemy.String),
)
def generate_uuid4():
return str(uuid4())
def generate_transcript_name():
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class AudioWaveform(BaseModel):
data: list[float]
class TranscriptText(BaseModel):
text: str
translation: str | None
class TranscriptSegmentTopic(BaseModel):
speaker: int
text: str
timestamp: float
class TranscriptTopic(BaseModel):
id: str = Field(default_factory=generate_uuid4)
title: str
summary: str
timestamp: float
text: str | None = None
words: list[ProcessorWord] = []
class TranscriptFinalShortSummary(BaseModel):
short_summary: str
class TranscriptFinalLongSummary(BaseModel):
long_summary: str
class TranscriptFinalTitle(BaseModel):
title: str
class TranscriptEvent(BaseModel):
event: str
data: dict
class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=datetime.utcnow)
title: str | None = None
short_summary: str | None = None
long_summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
self.events.append(ev)
return ev
def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
if existing_topic:
existing_topic.update_from(topic)
else:
self.topics.append(topic)
def events_dump(self, mode="json"):
return [event.model_dump(mode=mode) for event in self.events]
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_waveform(self, segments_count=256):
fn = self.audio_waveform_filename
if fn.exists():
return
waveform = get_audio_waveform(
path=self.audio_mp3_filename, segments_count=segments_count
)
try:
with open(fn, "w") as fd:
json.dump(waveform, fd)
except Exception:
# remove the file if anything happens during the write
fn.unlink(missing_ok=True)
raise
return waveform
def unlink(self):
    # data_path is a directory holding audio.mp3/audio.json; remove it entirely
    shutil.rmtree(self.data_path, ignore_errors=True)
@property
def data_path(self):
return Path(settings.DATA_DIR) / self.id
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@property
def audio_waveform_filename(self):
return self.data_path / "audio.json"
@property
def audio_waveform(self):
try:
    with open(self.audio_waveform_filename) as fd:
        data = json.load(fd)
except FileNotFoundError:
    return None
except json.JSONDecodeError:
    # unlink the file if it's corrupted
    self.audio_waveform_filename.unlink(missing_ok=True)
    return None
return AudioWaveform(data=data)
class TranscriptController:
async def get_all(
self,
user_id: str | None = None,
order_by: str | None = None,
filter_empty: bool | None = True,
filter_recording: bool | None = True,
) -> list[Transcript]:
"""
Get all transcripts
If `user_id` is specified, only return transcripts that belong to the user.
Otherwise, return all anonymous transcripts.
Parameters:
- `order_by`: field to order by, e.g. "-created_at"
- `filter_empty`: filter out empty transcripts
- `filter_recording`: filter out transcripts that are currently recording
"""
query = transcripts.select().where(transcripts.c.user_id == user_id)
if order_by is not None:
    if order_by.startswith("-"):
        field = getattr(transcripts.c, order_by[1:]).desc()
    else:
        field = getattr(transcripts.c, order_by)
    query = query.order_by(field)
if filter_empty:
query = query.filter(transcripts.c.status != "idle")
if filter_recording:
query = query.filter(transcripts.c.status != "recording")
results = await database.fetch_all(query)
return [Transcript(**result) for result in results]
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
"""
Get a transcript by id
"""
query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
if not result:
return None
return Transcript(**result)
async def add(
self,
name: str,
source_language: str = "en",
target_language: str = "en",
user_id: str | None = None,
):
"""
Add a new transcript
"""
transcript = Transcript(
name=name,
source_language=source_language,
target_language=target_language,
user_id=user_id,
)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
return transcript
async def update(self, transcript: Transcript, values: dict):
"""
Update a transcript fields with key/values in values
"""
query = (
transcripts.update()
.where(transcripts.c.id == transcript.id)
.values(**values)
)
await database.execute(query)
for key, value in values.items():
setattr(transcript, key, value)
async def remove_by_id(
self,
transcript_id: str,
user_id: str | None = None,
) -> None:
"""
Remove a transcript by id
"""
transcript = await self.get_by_id(transcript_id, user_id=user_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
@asynccontextmanager
async def transaction(self):
"""
A context manager for database transaction
"""
async with database.transaction():
yield
async def append_event(
self,
transcript: Transcript,
event: str,
data: Any,
) -> TranscriptEvent:
"""
Append an event to a transcript
"""
resp = transcript.add_event(event=event, data=data)
await self.update(transcript, {"events": transcript.events_dump()})
return resp
transcripts_controller = TranscriptController()
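A usage sketch of the new controller (illustrative only, not part of this commit; it assumes DATABASE_URL points at a reachable database and that the startup wiring is done elsewhere):

import asyncio

from reflector.db import database
from reflector.db.transcripts import transcripts_controller


async def demo():
    # the controller is async and needs a connected database
    await database.connect()
    transcript = await transcripts_controller.add(
        name="Weekly sync", source_language="en", target_language="fr"
    )
    # update() persists the values and mutates the in-memory model
    await transcripts_controller.update(transcript, {"status": "ended"})
    # the "-" prefix selects descending order, here newest first
    recent = await transcripts_controller.get_all(order_by="-created_at")
    print([t.name for t in recent])
    await database.disconnect()


asyncio.run(demo())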

View File

@@ -0,0 +1,230 @@
"""
Main reflector pipeline for live streaming
==========================================
This is the default pipeline used in the API.
It is split in two:
- PipelineMainLive: limited processing during the live session
- PipelineMainPost: heavy lifting after the live session
It is directly linked to our data model.
"""
from pathlib import Path
from reflector.db.transcripts import (
Transcript,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
TranscriptFinalTitle,
TranscriptText,
TranscriptTopic,
transcripts_controller,
)
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
BroadcastProcessor,
Pipeline,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.tasks.worker import celery
from reflector.ws_manager import WebsocketManager, get_ws_manager
def broadcast_to_socket(func):
"""
Decorator to broadcast transcript event to websockets
concerning this transcript
"""
async def wrapper(self, *args, **kwargs):
resp = await func(self, *args, **kwargs)
if resp is None:
return
await self.ws_manager.send_json(
room_id=self.ws_room_id,
message=resp.model_dump(mode="json"),
)
return wrapper
class PipelineMainBase(PipelineRunner):
transcript_id: str
ws_room_id: str | None = None
ws_manager: WebsocketManager | None = None
def prepare(self):
# prepare websocket
self.ws_room_id = f"ts:{self.transcript_id}"
self.ws_manager = get_ws_manager()
async def get_transcript(self) -> Transcript:
# fetch the transcript
result = await transcripts_controller.get_by_id(
transcript_id=self.transcript_id
)
if not result:
raise Exception("Transcript not found")
return result
class PipelineMainLive(PipelineMainBase):
audio_filename: Path | None = None
source_language: str = "en"
target_language: str = "en"
@broadcast_to_socket
async def on_transcript(self, data):
async with transcripts_controller.transaction():
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript,
event="TRANSCRIPT",
data=TranscriptText(text=data.text, translation=data.translation),
)
@broadcast_to_socket
async def on_topic(self, data):
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
text=data.transcript.text,
words=data.transcript.words,
)
async with transcripts_controller.transaction():
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript,
event="TOPIC",
data=topic,
)
async def create(self) -> Pipeline:
# prepare the websocket room and build the live pipeline
self.prepare()
transcript = await self.get_transcript()
processors = [
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
]
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
# when the pipeline ends, connect to the post pipeline
async def on_ended():
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
pipeline.on_ended = on_ended
return pipeline
class PipelineMainPost(PipelineMainBase):
"""
Implement the rest of the main pipeline, triggered after PipelineMainLive has ended.
"""
@broadcast_to_socket
async def on_final_title(self, data):
final_title = TranscriptFinalTitle(title=data.title)
async with transcripts_controller.transaction():
transcript = await self.get_transcript()
if not transcript.title:
    await transcripts_controller.update(
        transcript,
        {
            "title": final_title.title,
        },
    )
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_TITLE",
data=final_title,
)
@broadcast_to_socket
async def on_final_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with transcripts_controller.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"long_summary": final_long_summary.long_summary,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_LONG_SUMMARY",
data=final_long_summary,
)
@broadcast_to_socket
async def on_final_short_summary(self, data):
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
async with transcripts_controller.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"short_summary": final_short_summary.short_summary,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_SHORT_SUMMARY",
data=final_short_summary,
)
async def create(self) -> Pipeline:
# prepare the websocket room and build the post-processing pipeline
self.prepare()
processors = [
# add diarization
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(
callback=self.on_final_title
),
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=self.on_final_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_final_short_summary
),
]
),
]
pipeline = Pipeline(*processors)
pipeline.options = self
return pipeline
@celery.task
def task_pipeline_main_post(transcript_id: str):
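# placeholder: the post-processing pipeline (PipelineMainPost) is not wired up yet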
pass
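As a hedged sketch of the intended handoff (the frames iterator and its contents are assumptions; the RTC layer is expected to produce decoded audio frames):

from reflector.pipelines.main_live_pipeline import PipelineMainLive


async def stream_audio(frames, transcript_id: str):
    # `frames` is assumed to be an async iterator of audio frames
    # (e.g. av.AudioFrame objects decoded from an RTC track);
    # the transcript must already exist and the database be connected
    runner = PipelineMainLive(transcript_id=transcript_id)
    runner.start()  # spawns run() as a task on the current event loop
    async for frame in frames:
        await runner.push(frame)
    # flushing ends the live pipeline; its on_ended hook then enqueues
    # task_pipeline_main_post for the heavy post-processing
    await runner.flush()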

View File

@@ -0,0 +1,117 @@
"""
Pipeline Runner
===============
Pipeline runner designed to be executed in an asyncio task.
It is meant to be subclassed, implementing a create() method
that returns a Pipeline instance.
During its lifecycle, it will emit the following statuses:
- init: the pipeline is being created
- started: the pipeline has been started
- push: the pipeline received at least one piece of data
- flush: the pipeline is flushing
- ended: the pipeline has ended
- error: the pipeline has ended with an error
"""
import asyncio
from typing import Callable
from pydantic import BaseModel, ConfigDict
from reflector.logger import logger
from reflector.processors import Pipeline
class PipelineRunner(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
status: str = "idle"
on_status: Callable | None = None
on_ended: Callable | None = None
pipeline: Pipeline | None = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._q_cmd = asyncio.Queue()
self._ev_done = asyncio.Event()
self._is_first_push = True
def create(self) -> Pipeline:
"""
Create the pipeline if not specified earlier.
Should be implemented in a subclass
"""
raise NotImplementedError()
def start(self):
"""
Start the pipeline as a coroutine task
"""
asyncio.get_event_loop().create_task(self.run())
async def push(self, data):
"""
Push data to the pipeline
"""
await self._add_cmd("PUSH", data)
async def flush(self):
"""
Flush the pipeline
"""
await self._add_cmd("FLUSH", None)
async def _add_cmd(self, cmd: str, data):
"""
Enqueue a command to be executed in the runner.
Currently supported commands: PUSH, FLUSH
"""
await self._q_cmd.put([cmd, data])
async def _set_status(self, status):
print("set_status", status)
self.status = status
if self.on_status:
try:
await self.on_status(status)
except Exception as e:
logger.error("PipelineRunner status_callback error", error=e)
async def run(self):
try:
# create the pipeline if not yet done
await self._set_status("init")
self._is_first_push = True
if not self.pipeline:
self.pipeline = await self.create()
# start the loop
await self._set_status("started")
while not self._ev_done.is_set():
cmd, data = await self._q_cmd.get()
func = getattr(self, f"cmd_{cmd.lower()}", None)
if func:
await func(data)
else:
raise Exception(f"Unknown command {cmd}")
except Exception as e:
logger.error("PipelineRunner error", error=e)
await self._set_status("error")
self._ev_done.set()
if self.on_ended:
await self.on_ended()
async def cmd_push(self, data):
if self._is_first_push:
await self._set_status("push")
self._is_first_push = False
await self.pipeline.push(data)
async def cmd_flush(self, data):
await self._set_status("flush")
await self.pipeline.flush()
await self._set_status("ended")
self._ev_done.set()
if self.on_ended:
await self.on_ended()
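A minimal lifecycle sketch (EchoRunner and the empty Pipeline are assumptions for illustration; real subclasses assemble processors in create()):

import asyncio

from reflector.pipelines.runner import PipelineRunner
from reflector.processors import Pipeline


class EchoRunner(PipelineRunner):
    async def create(self) -> Pipeline:
        # a real implementation would pass processors to Pipeline(...);
        # this assumes an empty pipeline is valid and flushes cleanly
        return Pipeline()


async def main():
    async def report(status: str):
        print("status:", status)  # init, started, push, flush, ended

    runner = EchoRunner(on_status=report)
    runner.start()
    await runner.push(b"chunk")
    await runner.flush()
    # crude: give the runner task time to drain before the loop closes
    await asyncio.sleep(0.1)


asyncio.run(main())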

View File

@@ -3,7 +3,13 @@ from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor # noqa: F401
from .base import ( # noqa: F401
BroadcastProcessor,
Pipeline,
PipelineEvent,
Processor,
ThreadedProcessor,
)
from .transcript_final_long_summary import ( # noqa: F401
TranscriptFinalLongSummaryProcessor,
)

View File

@@ -1,8 +1,6 @@
import asyncio
from enum import StrEnum
from json import dumps, loads
from pathlib import Path
from typing import Callable
from json import loads
import av
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
@@ -11,25 +9,7 @@ from prometheus_client import Gauge
from pydantic import BaseModel
from reflector.events import subscribers_shutdown
from reflector.logger import logger
from reflector.processors import (
AudioChunkerProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
FinalLongSummary,
FinalShortSummary,
Pipeline,
TitleSummary,
Transcript,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.base import BroadcastProcessor
from reflector.processors.types import FinalTitle
from reflector.pipelines.runner import PipelineRunner
sessions = []
router = APIRouter()
@@ -85,121 +65,10 @@ class PipelineEvent(StrEnum):
FINAL_TITLE = "FINAL_TITLE"
class PipelineOptions(BaseModel):
audio_filename: Path | None = None
source_language: str = "en"
target_language: str = "en"
on_transcript: Callable | None = None
on_topic: Callable | None = None
on_final_title: Callable | None = None
on_final_short_summary: Callable | None = None
on_final_long_summary: Callable | None = None
class PipelineRunner(object):
"""
Pipeline runner designed to be executed in an asyncio task
"""
def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None):
self.pipeline = pipeline
self.q_cmd = asyncio.Queue()
self.ev_done = asyncio.Event()
self.status = "idle"
self.status_callback = status_callback
async def update_status(self, status):
print("update_status", status)
self.status = status
if self.status_callback:
try:
await self.status_callback(status)
except Exception as e:
logger.error("PipelineRunner status_callback error", error=e)
async def add_cmd(self, cmd: str, data):
await self.q_cmd.put([cmd, data])
async def push(self, data):
await self.add_cmd("PUSH", data)
async def flush(self):
await self.add_cmd("FLUSH", None)
async def run(self):
try:
await self.update_status("running")
while not self.ev_done.is_set():
cmd, data = await self.q_cmd.get()
func = getattr(self, f"cmd_{cmd.lower()}")
if func:
await func(data)
else:
raise Exception(f"Unknown command {cmd}")
except Exception as e:
await self.update_status("error")
logger.error("PipelineRunner error", error=e)
async def cmd_push(self, data):
if self.status == "idle":
await self.update_status("recording")
await self.pipeline.push(data)
async def cmd_flush(self, data):
await self.update_status("processing")
await self.pipeline.flush()
await self.update_status("ended")
self.ev_done.set()
def start(self):
print("start task")
asyncio.get_event_loop().create_task(self.run())
async def pipeline_live_create(options: PipelineOptions):
# create a context for the whole rtc transaction
# add a customised logger to the context
processors = []
if options.audio_filename is not None:
processors += [AudioFileWriterProcessor(path=options.audio_filename)]
processors += [
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(
callback=options.on_final_title
),
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=options.on_final_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=options.on_final_short_summary
),
]
),
]
pipeline = Pipeline(*processors)
pipeline.options = options
pipeline.set_pref("audio:source_language", options.source_language)
pipeline.set_pref("audio:target_language", options.target_language)
return pipeline
async def rtc_offer_base(
params: RtcOffer,
request: Request,
event_callback=None,
event_callback_args=None,
audio_filename: Path | None = None,
source_language: str = "en",
target_language: str = "en",
pipeline_runner: PipelineRunner,
):
# build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -209,132 +78,9 @@ async def rtc_offer_base(
clientid = f"{peername[0]}:{peername[1]}"
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
async def update_status(status: str):
changed = ctx.status != status
if changed:
ctx.status = status
if event_callback:
await event_callback(
event=PipelineEvent.STATUS,
args=event_callback_args,
data=StrValue(value=status),
)
# build pipeline callback
async def on_transcript(transcript: Transcript):
ctx.logger.info("Transcript", transcript=transcript)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TRANSCRIPT,
args=event_callback_args,
data=transcript,
)
async def on_topic(topic: TitleSummary):
# FIXME: make it incremental with the frontend, not send everything
ctx.logger.info("Topic", topic=topic)
ctx.topics.append(
{
"title": topic.title,
"timestamp": topic.timestamp,
"transcript": topic.transcript.text,
"desc": topic.summary,
}
)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
)
async def on_final_short_summary(summary: FinalShortSummary):
ctx.logger.info("FinalShortSummary", final_short_summary=summary)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
"summary": summary.short_summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_SHORT_SUMMARY,
args=event_callback_args,
data=summary,
)
async def on_final_long_summary(summary: FinalLongSummary):
ctx.logger.info("FinalLongSummary", final_summary=summary)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_LONG_SUMMARY",
"summary": summary.long_summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_LONG_SUMMARY,
args=event_callback_args,
data=summary,
)
async def on_final_title(title: FinalTitle):
ctx.logger.info("FinalTitle", final_title=title)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_TITLE,
args=event_callback_args,
data=title,
)
# handle RTC peer connection
pc = RTCPeerConnection()
# create pipeline
options = PipelineOptions(
audio_filename=audio_filename,
source_language=source_language,
target_language=target_language,
on_transcript=on_transcript,
on_topic=on_topic,
on_final_short_summary=on_final_short_summary,
on_final_long_summary=on_final_long_summary,
on_final_title=on_final_title,
)
pipeline = await pipeline_live_create(options)
ctx.pipeline_runner = PipelineRunner(pipeline, update_status)
ctx.pipeline_runner = pipeline_runner
ctx.pipeline_runner.start()
async def flush_pipeline_and_quit(close=True):
@@ -400,8 +146,3 @@ async def rtc_clean_sessions(_):
logger.debug(f"Closing session {pc}")
await pc.close()
sessions.clear()
@router.post("/offer")
async def rtc_offer(params: RtcOffer, request: Request):
return await rtc_offer_base(params, request)

View File

@@ -1,8 +1,5 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Annotated, Optional
from uuid import uuid4
import reflector.auth as auth
from fastapi import (
@@ -15,12 +12,13 @@ from fastapi import (
)
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.db.transcripts import (
AudioWaveform,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
@@ -30,216 +28,6 @@ from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
router = APIRouter()
ws_manager = get_ws_manager()
# ==============================================================
# Models to move to a database, but required for the API to work
# ==============================================================
def generate_uuid4():
return str(uuid4())
def generate_transcript_name():
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class AudioWaveform(BaseModel):
data: list[float]
class TranscriptText(BaseModel):
text: str
translation: str | None
class TranscriptSegmentTopic(BaseModel):
speaker: int
text: str
timestamp: float
class TranscriptTopic(BaseModel):
id: str = Field(default_factory=generate_uuid4)
title: str
summary: str
timestamp: float
text: str | None = None
words: list[ProcessorWord] = []
class TranscriptFinalShortSummary(BaseModel):
short_summary: str
class TranscriptFinalLongSummary(BaseModel):
long_summary: str
class TranscriptFinalTitle(BaseModel):
title: str
class TranscriptEvent(BaseModel):
event: str
data: dict
class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=datetime.utcnow)
title: str | None = None
short_summary: str | None = None
long_summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
self.events.append(ev)
return ev
def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
if existing_topic:
existing_topic.update_from(topic)
else:
self.topics.append(topic)
def events_dump(self, mode="json"):
return [event.model_dump(mode=mode) for event in self.events]
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_waveform(self, segments_count=256):
fn = self.audio_waveform_filename
if fn.exists():
return
waveform = get_audio_waveform(
path=self.audio_mp3_filename, segments_count=segments_count
)
try:
with open(fn, "w") as fd:
json.dump(waveform, fd)
except Exception:
# remove the file if anything happens during the write
fn.unlink(missing_ok=True)
raise
return waveform
def unlink(self):
self.data_path.unlink(missing_ok=True)
@property
def data_path(self):
return Path(settings.DATA_DIR) / self.id
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@property
def audio_waveform_filename(self):
return self.data_path / "audio.json"
@property
def audio_waveform(self):
try:
with open(self.audio_waveform_filename) as fd:
data = json.load(fd)
except json.JSONDecodeError:
# unlink file if it's corrupted
self.audio_waveform_filename.unlink(missing_ok=True)
return None
return AudioWaveform(data=data)
class TranscriptController:
async def get_all(
self,
user_id: str | None = None,
order_by: str | None = None,
filter_empty: bool | None = False,
filter_recording: bool | None = False,
) -> list[Transcript]:
query = transcripts.select().where(transcripts.c.user_id == user_id)
if order_by is not None:
field = getattr(transcripts.c, order_by[1:])
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
if filter_empty:
query = query.filter(transcripts.c.status != "idle")
if filter_recording:
query = query.filter(transcripts.c.status != "recording")
results = await database.fetch_all(query)
return results
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
if not result:
return None
return Transcript(**result)
async def add(
self,
name: str,
source_language: str = "en",
target_language: str = "en",
user_id: str | None = None,
):
transcript = Transcript(
name=name,
source_language=source_language,
target_language=target_language,
user_id=user_id,
)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
return transcript
async def update(self, transcript: Transcript, values: dict):
query = (
transcripts.update()
.where(transcripts.c.id == transcript.id)
.values(**values)
)
await database.execute(query)
for key, value in values.items():
setattr(transcript, key, value)
async def remove_by_id(
self, transcript_id: str, user_id: str | None = None
) -> None:
transcript = await self.get_by_id(transcript_id, user_id=user_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
transcripts_controller = TranscriptController()
# ==============================================================
# Transcripts list
# ==============================================================
@@ -537,114 +325,6 @@ async def transcript_events_websocket(
# ==============================================================
async def handle_rtc_event(event: PipelineEvent, args, data):
try:
return await handle_rtc_event_once(event, args, data)
except Exception:
logger.exception("Error handling RTC event")
async def handle_rtc_event_once(event: PipelineEvent, args, data):
# Of course the current implementation is not good,
# but it's just a POC before persistence. The final version won't
# query the transcript from the database for each event.
# print(f"Event: {event}", args, data)
transcript_id = args
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
return
# The event sent to websocket clients may not be the same as the event
# received from the pipeline. For example, the pipeline will send
# a TRANSCRIPT event with all words, but this is not what we want
# to send to the websocket client.
# FIXME don't do copy
if event == PipelineEvent.TRANSCRIPT:
resp = transcript.add_event(
event=event,
data=TranscriptText(text=data.text, translation=data.translation),
)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
},
)
elif event == PipelineEvent.TOPIC:
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
text=data.transcript.text,
words=data.transcript.words,
)
resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"topics": transcript.topics_dump(),
},
)
elif event == PipelineEvent.FINAL_TITLE:
final_title = TranscriptFinalTitle(title=data.title)
resp = transcript.add_event(event=event, data=final_title)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"title": final_title.title,
},
)
elif event == PipelineEvent.FINAL_LONG_SUMMARY:
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
resp = transcript.add_event(event=event, data=final_long_summary)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"long_summary": final_long_summary.long_summary,
},
)
elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
resp = transcript.add_event(event=event, data=final_short_summary)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"short_summary": final_short_summary.short_summary,
},
)
elif event == PipelineEvent.STATUS:
resp = transcript.add_event(event=event, data=data)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"status": data.value,
},
)
else:
logger.warning(f"Unknown event: {event}")
return
# transmit to websocket clients
room_id = f"ts:{transcript_id}"
await ws_manager.send_json(room_id, resp.model_dump(mode="json"))
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str,
@@ -660,13 +340,14 @@ async def transcript_record_webrtc(
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# create a pipeline runner
from reflector.pipelines.main_live_pipeline import PipelineMainLive
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
# FIXME do not allow multiple recordings at the same time
return await rtc_offer_base(
params,
request,
event_callback=handle_rtc_event,
event_callback_args=transcript_id,
audio_filename=transcript.audio_mp3_filename,
source_language=transcript.source_language,
target_language=transcript.target_language,
pipeline_runner=pipeline_runner,
)

View File

@@ -14,6 +14,7 @@ import json
import redis.asyncio as redis
from fastapi import WebSocket
from reflector.settings import settings
ws_manager = None
@@ -114,7 +115,8 @@ def get_ws_manager() -> WebsocketManager:
RedisConnectionError: If there is an error connecting to the Redis server.
"""
global ws_manager
from reflector.settings import settings
if ws_manager:
return ws_manager
pubsub_client = RedisPubSubManager(
host=settings.REDIS_HOST,