server: refactor with clearer pipeline instantiation and link to the data model

2023-10-26 19:00:56 +02:00
committed by Mathieu Virbel
parent 433c0500cc
commit 1c42473da0
8 changed files with 658 additions and 616 deletions


@@ -0,0 +1,230 @@
"""
Main reflector pipeline for live streaming
==========================================
This is the default pipeline used in the API.
It is split in two:
- PipelineMainLive: lightweight processing while the live session runs
- PipelineMainPost: heavy lifting after the live session ends
It is directly linked to our data model.
"""
from pathlib import Path
from reflector.db.transcripts import (
Transcript,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
TranscriptFinalTitle,
TranscriptText,
TranscriptTopic,
transcripts_controller,
)
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
BroadcastProcessor,
Pipeline,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.tasks.worker import celery
from reflector.ws_manager import WebsocketManager, get_ws_manager
def broadcast_to_socket(func):
"""
    Decorator that broadcasts the event returned by the wrapped
    handler to the websocket room of this transcript
"""
async def wrapper(self, *args, **kwargs):
resp = await func(self, *args, **kwargs)
if resp is None:
return
await self.ws_manager.send_json(
room_id=self.ws_room_id,
message=resp.model_dump(mode="json"),
)
return wrapper
class PipelineMainBase(PipelineRunner):
transcript_id: str
ws_room_id: str | None = None
ws_manager: WebsocketManager | None = None
def prepare(self):
# prepare websocket
self.ws_room_id = f"ts:{self.transcript_id}"
self.ws_manager = get_ws_manager()
async def get_transcript(self) -> Transcript:
# fetch the transcript
result = await transcripts_controller.get_by_id(
transcript_id=self.transcript_id
)
if not result:
raise Exception("Transcript not found")
return result
class PipelineMainLive(PipelineMainBase):
audio_filename: Path | None = None
source_language: str = "en"
target_language: str = "en"
@broadcast_to_socket
async def on_transcript(self, data):
async with transcripts_controller.transaction():
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript,
event="TRANSCRIPT",
data=TranscriptText(text=data.text, translation=data.translation),
)
@broadcast_to_socket
async def on_topic(self, data):
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
text=data.transcript.text,
words=data.transcript.words,
)
async with transcripts_controller.transaction():
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript,
event="TOPIC",
data=topic,
)
async def create(self) -> Pipeline:
        # prepare the websocket room, then assemble the live pipeline
self.prepare()
transcript = await self.get_transcript()
processors = [
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
]
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
# when the pipeline ends, connect to the post pipeline
async def on_ended():
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
        pipeline.on_ended = on_ended
return pipeline
class PipelineMainPost(PipelineMainBase):
"""
Implement the rest of the main pipeline, triggered after PipelineMainLive ended.
"""
@broadcast_to_socket
async def on_final_title(self, data):
final_title = TranscriptFinalTitle(title=data.title)
async with transcripts_controller.transaction():
transcript = await self.get_transcript()
if not transcript.title:
                await transcripts_controller.update(
                    transcript,
{
"title": final_title.title,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_TITLE",
data=final_title,
)
@broadcast_to_socket
async def on_final_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with transcripts_controller.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"long_summary": final_long_summary.long_summary,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_LONG_SUMMARY",
data=final_long_summary,
)
@broadcast_to_socket
async def on_final_short_summary(self, data):
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
async with transcripts_controller.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"short_summary": final_short_summary.short_summary,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_SHORT_SUMMARY",
data=final_short_summary,
)
async def create(self) -> Pipeline:
        # prepare the websocket room, then assemble the post-processing pipeline
self.prepare()
processors = [
            # TODO: add diarization
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(
callback=self.on_final_title
),
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=self.on_final_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_final_short_summary
),
]
),
]
pipeline = Pipeline(*processors)
pipeline.options = self
return pipeline
@celery.task
def task_pipeline_main_post(transcript_id: str):
pass
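
The celery task is left as a stub in this hunk. As a minimal sketch of what could go there, written as if it lived in the same module as the classes above; the `data = ...` placeholder stands in for whatever input the post pipeline expects, which this diff does not show:

import asyncio

@celery.task
def task_pipeline_main_post(transcript_id: str):
    # Hypothetical wiring, not part of this commit.
    async def main():
        runner = PipelineMainPost(transcript_id=transcript_id)
        data = ...  # placeholder: the input contract is not visible in this diff
        await runner.push(data)
        await runner.flush()
        # run() drains the queue (PUSH, then FLUSH) and emits "ended"
        await runner.run()

    asyncio.run(main())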


@@ -0,0 +1,117 @@
"""
Pipeline Runner
===============
Pipeline runner designed to be executed in an asyncio task.
It is meant to be subclassed; subclasses implement a create() method
that returns a Pipeline instance.
During its lifecycle, it emits the following statuses:
- init: the runner is creating the pipeline
- started: the pipeline has been started
- push: the pipeline has received at least one piece of data
- flush: the pipeline is flushing
- ended: the pipeline has ended
- error: the pipeline has ended with an error
"""
import asyncio
from typing import Callable
from pydantic import BaseModel, ConfigDict
from reflector.logger import logger
from reflector.processors import Pipeline
class PipelineRunner(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
status: str = "idle"
on_status: Callable | None = None
on_ended: Callable | None = None
pipeline: Pipeline | None = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._q_cmd = asyncio.Queue()
self._ev_done = asyncio.Event()
self._is_first_push = True
    async def create(self) -> Pipeline:
        """
        Create the pipeline if none was passed at construction.
        Must be implemented by subclasses; run() awaits it.
        """
raise NotImplementedError()
def start(self):
"""
        Start the runner as an asyncio task on the current event loop
"""
asyncio.get_event_loop().create_task(self.run())
async def push(self, data):
"""
Push data to the pipeline
"""
await self._add_cmd("PUSH", data)
async def flush(self):
"""
Flush the pipeline
"""
await self._add_cmd("FLUSH", None)
async def _add_cmd(self, cmd: str, data):
"""
Enqueue a command to be executed in the runner.
Currently supported commands: PUSH, FLUSH
"""
await self._q_cmd.put([cmd, data])
async def _set_status(self, status):
print("set_status", status)
self.status = status
if self.on_status:
try:
await self.on_status(status)
except Exception as e:
logger.error("PipelineRunner status_callback error", error=e)
async def run(self):
try:
# create the pipeline if not yet done
await self._set_status("init")
self._is_first_push = True
if not self.pipeline:
self.pipeline = await self.create()
# start the loop
await self._set_status("started")
while not self._ev_done.is_set():
cmd, data = await self._q_cmd.get()
                func = getattr(self, f"cmd_{cmd.lower()}", None)
if func:
await func(data)
else:
raise Exception(f"Unknown command {cmd}")
except Exception as e:
logger.error("PipelineRunner error", error=e)
await self._set_status("error")
self._ev_done.set()
if self.on_ended:
await self.on_ended()
async def cmd_push(self, data):
if self._is_first_push:
await self._set_status("push")
self._is_first_push = False
await self.pipeline.push(data)
async def cmd_flush(self, data):
await self._set_status("flush")
await self.pipeline.flush()
await self._set_status("ended")
self._ev_done.set()
if self.on_ended:
await self.on_ended()
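
To make the runner lifecycle concrete, here is a hedged usage sketch for the live pipeline. Only PipelineMainLive and the runner API come from this commit; serve_live_session, audio_frames, and the module path are illustrative assumptions:

import asyncio

from reflector.pipelines.main_live import PipelineMainLive  # assumed path


async def serve_live_session(audio_frames, transcript_id: str):
    """Hypothetical driver: feed live audio into the pipeline runner."""
    runner = PipelineMainLive(transcript_id=transcript_id)

    async def on_status(status: str):
        # fires with "init", "started", "push", "flush", "ended" or "error"
        print("pipeline status:", status)

    runner.on_status = on_status
    runner.start()  # schedules runner.run() on the current event loop

    async for frame in audio_frames:
        await runner.push(frame)  # the first push flips the status to "push"

    await runner.flush()  # drains the processors; emits "flush" then "ended"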