mirror of https://github.com/Monadical-SAS/reflector.git
synced 2025-12-22 05:09:05 +00:00
server: refactor with clearer pipeline instantiation, linked to the data model
230 server/reflector/pipelines/main_live_pipeline.py Normal file

@@ -0,0 +1,230 @@
"""
Main reflector pipeline for live streaming
==========================================

This is the default pipeline used in the API.

It is split in two:

- PipelineMainLive: limited processing while the live session runs
- PipelineMainPost: heavy lifting after the live session ends

It is directly linked to our data model.
"""

from pathlib import Path

from reflector.db.transcripts import (
    Transcript,
    TranscriptFinalLongSummary,
    TranscriptFinalShortSummary,
    TranscriptFinalTitle,
    TranscriptText,
    TranscriptTopic,
    transcripts_controller,
)
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
    AudioChunkerProcessor,
    AudioFileWriterProcessor,
    AudioMergeProcessor,
    AudioTranscriptAutoProcessor,
    BroadcastProcessor,
    Pipeline,
    TranscriptFinalLongSummaryProcessor,
    TranscriptFinalShortSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptLinerProcessor,
    TranscriptTopicDetectorProcessor,
    TranscriptTranslatorProcessor,
)
from reflector.tasks.worker import celery
from reflector.ws_manager import WebsocketManager, get_ws_manager


def broadcast_to_socket(func):
    """
    Decorator that broadcasts the transcript event returned by the
    handler to the websocket room of this transcript
    """

    async def wrapper(self, *args, **kwargs):
        resp = await func(self, *args, **kwargs)
        if resp is None:
            return
        await self.ws_manager.send_json(
            room_id=self.ws_room_id,
            message=resp.model_dump(mode="json"),
        )

    return wrapper


class PipelineMainBase(PipelineRunner):
    transcript_id: str
    ws_room_id: str | None = None
    ws_manager: WebsocketManager | None = None

    def prepare(self):
        # prepare websocket
        self.ws_room_id = f"ts:{self.transcript_id}"
        self.ws_manager = get_ws_manager()

    async def get_transcript(self) -> Transcript:
        # fetch the transcript
        result = await transcripts_controller.get_by_id(
            transcript_id=self.transcript_id
        )
        if not result:
            raise Exception("Transcript not found")
        return result


class PipelineMainLive(PipelineMainBase):
    audio_filename: Path | None = None
    source_language: str = "en"
    target_language: str = "en"

    @broadcast_to_socket
    async def on_transcript(self, data):
        async with transcripts_controller.transaction():
            transcript = await self.get_transcript()
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="TRANSCRIPT",
                data=TranscriptText(text=data.text, translation=data.translation),
            )

    @broadcast_to_socket
    async def on_topic(self, data):
        topic = TranscriptTopic(
            title=data.title,
            summary=data.summary,
            timestamp=data.timestamp,
            text=data.transcript.text,
            words=data.transcript.words,
        )
        async with transcripts_controller.transaction():
            transcript = await self.get_transcript()
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="TOPIC",
                data=topic,
            )

    async def create(self) -> Pipeline:
        # create a context for the whole rtc transaction
        # add a customised logger to the context
        self.prepare()
        transcript = await self.get_transcript()

        processors = [
            AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
            AudioChunkerProcessor(),
            AudioMergeProcessor(),
            AudioTranscriptAutoProcessor.as_threaded(),
            TranscriptLinerProcessor(),
            TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
            TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
        ]
        pipeline = Pipeline(*processors)
        pipeline.options = self
        pipeline.set_pref("audio:source_language", transcript.source_language)
        pipeline.set_pref("audio:target_language", transcript.target_language)

        # when the pipeline ends, connect to the post pipeline
        async def on_ended():
            task_pipeline_main_post.delay(transcript_id=self.transcript_id)

        pipeline.on_ended = on_ended

        return pipeline


class PipelineMainPost(PipelineMainBase):
    """
    Implement the rest of the main pipeline, triggered after PipelineMainLive has ended.
    """

    @broadcast_to_socket
    async def on_final_title(self, data):
        final_title = TranscriptFinalTitle(title=data.title)
        async with transcripts_controller.transaction():
            transcript = await self.get_transcript()
            if not transcript.title:
                await transcripts_controller.update(
                    transcript,
                    {
                        "title": final_title.title,
                    },
                )
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="FINAL_TITLE",
                data=final_title,
            )

    @broadcast_to_socket
    async def on_final_long_summary(self, data):
        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
        async with transcripts_controller.transaction():
            transcript = await self.get_transcript()
            await transcripts_controller.update(
                transcript,
                {
                    "long_summary": final_long_summary.long_summary,
                },
            )
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="FINAL_LONG_SUMMARY",
                data=final_long_summary,
            )

    @broadcast_to_socket
    async def on_final_short_summary(self, data):
        final_short_summary = TranscriptFinalShortSummary(
            short_summary=data.short_summary
        )
        async with transcripts_controller.transaction():
            transcript = await self.get_transcript()
            await transcripts_controller.update(
                transcript,
                {
                    "short_summary": final_short_summary.short_summary,
                },
            )
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="FINAL_SHORT_SUMMARY",
                data=final_short_summary,
            )

    async def create(self) -> Pipeline:
        # create a context for the whole rtc transaction
        # add a customised logger to the context
        self.prepare()
        processors = [
            # TODO: add diarization
            BroadcastProcessor(
                processors=[
                    TranscriptFinalTitleProcessor.as_threaded(
                        callback=self.on_final_title
                    ),
                    TranscriptFinalLongSummaryProcessor.as_threaded(
                        callback=self.on_final_long_summary
                    ),
                    TranscriptFinalShortSummaryProcessor.as_threaded(
                        callback=self.on_final_short_summary
                    ),
                ]
            ),
        ]
        pipeline = Pipeline(*processors)
        pipeline.options = self

        return pipeline


@celery.task
def task_pipeline_main_post(transcript_id: str):
    pass
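
Note that task_pipeline_main_post is still a stub in this commit. The intended handoff is visible in PipelineMainLive.create(): the live runner does the lightweight work while audio arrives, and its on_ended hook enqueues the Celery task that will run PipelineMainPost. A minimal sketch of driving the live side, assuming an existing transcript id; audio_frames is a hypothetical async source of audio chunks, not part of this commit:

import asyncio

from reflector.pipelines.main_live_pipeline import PipelineMainLive


async def run_live(transcript_id: str, audio_frames):
    # audio_frames: hypothetical async iterator of audio chunks (e.g. from RTC)
    runner = PipelineMainLive(transcript_id=transcript_id)
    runner.start()  # schedules PipelineRunner.run() on the current event loop
    async for frame in audio_frames:
        await runner.push(frame)
    # flushing ends the pipeline; on_ended then enqueues
    # task_pipeline_main_post for the heavy post-processing
    await runner.flush()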

117 server/reflector/pipelines/runner.py Normal file

@@ -0,0 +1,117 @@
"""
Pipeline Runner
===============

Pipeline runner designed to be executed in an asyncio task.

It is meant to be subclassed, implementing a create() method
that returns a Pipeline instance.

During its lifecycle, it emits the following statuses:

- init: the pipeline is being created
- started: the pipeline has been started
- push: the pipeline has received its first data
- flush: the pipeline is flushing
- ended: the pipeline has ended
- error: the pipeline has ended with an error
"""

import asyncio
from typing import Callable

from pydantic import BaseModel, ConfigDict
from reflector.logger import logger
from reflector.processors import Pipeline


class PipelineRunner(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    status: str = "idle"
    on_status: Callable | None = None
    on_ended: Callable | None = None
    pipeline: Pipeline | None = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._q_cmd = asyncio.Queue()
        self._ev_done = asyncio.Event()
        self._is_first_push = True

    def create(self) -> Pipeline:
        """
        Create the pipeline if it was not provided at construction.
        Must be implemented in a subclass.
        """
        raise NotImplementedError()

    def start(self):
        """
        Start the pipeline as a coroutine task
        """
        asyncio.get_event_loop().create_task(self.run())

    async def push(self, data):
        """
        Push data to the pipeline
        """
        await self._add_cmd("PUSH", data)

    async def flush(self):
        """
        Flush the pipeline
        """
        await self._add_cmd("FLUSH", None)

    async def _add_cmd(self, cmd: str, data):
        """
        Enqueue a command to be executed in the runner.
        Currently supported commands: PUSH, FLUSH
        """
        await self._q_cmd.put([cmd, data])

    async def _set_status(self, status):
        logger.debug("PipelineRunner status changed", status=status)
        self.status = status
        if self.on_status:
            try:
                await self.on_status(status)
            except Exception as e:
                logger.error("PipelineRunner status_callback error", error=e)

    async def run(self):
        try:
            # create the pipeline if not yet done
            await self._set_status("init")
            self._is_first_push = True
            if not self.pipeline:
                self.pipeline = await self.create()

            # start the loop
            await self._set_status("started")
            while not self._ev_done.is_set():
                cmd, data = await self._q_cmd.get()
                func = getattr(self, f"cmd_{cmd.lower()}", None)
                if func:
                    await func(data)
                else:
                    raise Exception(f"Unknown command {cmd}")
        except Exception as e:
            logger.error("PipelineRunner error", error=e)
            await self._set_status("error")
            self._ev_done.set()
            if self.on_ended:
                await self.on_ended()

    async def cmd_push(self, data):
        if self._is_first_push:
            await self._set_status("push")
            self._is_first_push = False
        await self.pipeline.push(data)

    async def cmd_flush(self, data):
        await self._set_status("flush")
        await self.pipeline.flush()
        await self._set_status("ended")
        self._ev_done.set()
        if self.on_ended:
            await self.on_ended()
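
The runner is pipeline-agnostic: a subclass only has to return a Pipeline from create(), and callers drive it through start(), push() and flush(), optionally observing the lifecycle via on_status (an async callable). A minimal sketch under those assumptions; the processor choice and the pushed bytes are illustrative placeholders, not part of this commit:

import asyncio

from reflector.pipelines.runner import PipelineRunner
from reflector.processors import AudioChunkerProcessor, Pipeline


class PipelineDemo(PipelineRunner):
    async def create(self) -> Pipeline:
        # any processor chain works; AudioChunkerProcessor is one of the
        # processors the main pipeline already imports
        return Pipeline(AudioChunkerProcessor())


async def log_status(status):
    # receives: init, started, push, flush, ended (or error)
    print("pipeline status:", status)


async def main():
    runner = PipelineDemo(on_status=log_status)
    runner.start()
    await runner.push(b"\x00" * 320)  # stand-in chunk; real input type depends on the processor
    await runner.flush()
    # flush() only enqueues the command: wait for the runner's internal
    # done event before exiting, so the pipeline actually finishes
    await runner._ev_done.wait()


asyncio.run(main())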