server: refactor with diarization, logic works

2023-10-27 15:59:27 +02:00
committed by Mathieu Virbel
parent 1c42473da0
commit 07c4d080c2
17 changed files with 387 additions and 169 deletions

View File

@@ -11,8 +11,12 @@ It is decoupled to:
It is directly linked to our data model.
"""
import asyncio
from contextlib import asynccontextmanager
from pathlib import Path
from celery import shared_task
from pydantic import BaseModel
from reflector.db.transcripts import (
Transcript,
TranscriptFinalLongSummary,
@@ -25,6 +29,7 @@ from reflector.db.transcripts import (
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
AudioDiarizationProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
@@ -37,11 +42,13 @@ from reflector.processors import (
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
-from reflector.tasks.worker import celery
from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.ws_manager import WebsocketManager, get_ws_manager
-def broadcast_to_socket(func):
+def broadcast_to_sockets(func):
"""
Decorator to broadcast a transcript event to the websockets
subscribed to this transcript
@@ -59,6 +66,10 @@ def broadcast_to_socket(func):
return wrapper
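Only the tail of the wrapper survives in this hunk. As a rough sketch of what the renamed decorator plausibly does (the broadcast call and the WebsocketManager method name are assumptions, not shown in this diff):

```python
import functools

def broadcast_to_sockets(func):
    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        # run the handler first: it returns the appended event (or None)
        event = await func(self, *args, **kwargs)
        if event is not None and self.ws_room_id:
            # hypothetical broadcast; the real WebsocketManager API
            # is not visible in this diff
            await self.ws_manager.broadcast(self.ws_room_id, event)
        return event
    return wrapper
```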
class StrValue(BaseModel):
value: str
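StrValue simply wraps a bare string so that append_event always receives a pydantic model, like the richer event payloads. Assuming pydantic v2 (the runner below uses ConfigDict):

```python
StrValue(value="processing").model_dump()  # -> {"value": "processing"}
```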
class PipelineMainBase(PipelineRunner):
transcript_id: str
ws_room_id: str | None = None
@@ -66,6 +77,7 @@ class PipelineMainBase(PipelineRunner):
def prepare(self):
# prepare websocket
self._lock = asyncio.Lock()
self.ws_room_id = f"ts:{self.transcript_id}"
self.ws_manager = get_ws_manager()
@@ -78,15 +90,59 @@ class PipelineMainBase(PipelineRunner):
raise Exception("Transcript not found")
return result
@asynccontextmanager
async def transaction(self):
async with self._lock:
async with transcripts_controller.transaction():
yield
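transaction() stacks the runner's lock around the database transaction: the lock serialises callbacks arriving from the threaded processors, while the inner transaction keeps the read-modify-write of the transcript atomic. Every handler below then follows the same shape:

```python
async with self.transaction():
    transcript = await self.get_transcript()
    # ...update fields and append the matching event atomically...
```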
-class PipelineMainLive(PipelineMainBase):
-audio_filename: Path | None = None
-source_language: str = "en"
-target_language: str = "en"
@broadcast_to_sockets
async def on_status(self, status):
# the live pipeline is only the first stage: update the transcript
# status, but do not set the ended status yet.
if isinstance(self, PipelineMainLive):
status_mapping = {
"started": "recording",
"push": "recording",
"flush": "processing",
"error": "error",
}
elif isinstance(self, PipelineMainDiarization):
status_mapping = {
"push": "processing",
"flush": "processing",
"error": "error",
"ended": "ended",
}
else:
raise Exception(f"Runner {self.__class__} is missing status mapping")
-@broadcast_to_socket
# map the pipeline status onto the transcript model status
status = status_mapping.get(status)
if not status:
return
# when the status of the pipeline changes, update the transcript
async with self.transaction():
transcript = await self.get_transcript()
if status == transcript.status:
return
resp = await transcripts_controller.append_event(
transcript=transcript,
event="STATUS",
data=StrValue(value=status),
)
await transcripts_controller.update(
transcript,
{
"status": status,
},
)
return resp
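Each runner thus translates pipeline statuses into the transcript's vocabulary, and the omissions are deliberate. Reproducing the live mapping standalone:

```python
live_mapping = {
    "started": "recording",
    "push": "recording",
    "flush": "processing",
    "error": "error",
}
# no "ended" entry: the transcript only becomes "ended" once the
# diarization pipeline (the second stage) finishes
assert live_mapping.get("ended") is None
```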
@broadcast_to_sockets
async def on_transcript(self, data):
-async with transcripts_controller.transaction():
+async with self.transaction():
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript,
@@ -94,7 +150,7 @@ class PipelineMainLive(PipelineMainBase):
data=TranscriptText(text=data.text, translation=data.translation),
)
-@broadcast_to_socket
+@broadcast_to_sockets
async def on_topic(self, data):
topic = TranscriptTopic(
title=data.title,
@@ -103,14 +159,75 @@ class PipelineMainLive(PipelineMainBase):
text=data.transcript.text,
words=data.transcript.words,
)
-async with transcripts_controller.transaction():
+async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.upsert_topic(transcript, topic)
return await transcripts_controller.append_event(
transcript=transcript,
event="TOPIC",
data=topic,
)
@broadcast_to_sockets
async def on_title(self, data):
final_title = TranscriptFinalTitle(title=data.title)
async with self.transaction():
transcript = await self.get_transcript()
if not transcript.title:
await transcripts_controller.update(
transcript,
{
"title": final_title.title,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_TITLE",
data=final_title,
)
@broadcast_to_sockets
async def on_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"long_summary": final_long_summary.long_summary,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_LONG_SUMMARY",
data=final_long_summary,
)
@broadcast_to_sockets
async def on_short_summary(self, data):
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"short_summary": final_short_summary.short_summary,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_SHORT_SUMMARY",
data=final_short_summary,
)
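The three handlers above share one shape: open the transaction, update a transcript field, append the matching event (on_title adds a don't-overwrite check). A hypothetical helper, not in the codebase, makes the pattern explicit using the module's transcripts_controller:

```python
async def _save_final(self, event: str, fields: dict, data):
    # hypothetical consolidation of on_title / on_long_summary /
    # on_short_summary
    async with self.transaction():
        transcript = await self.get_transcript()
        await transcripts_controller.update(transcript, fields)
        return await transcripts_controller.append_event(
            transcript=transcript, event=event, data=data
        )
```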
class PipelineMainLive(PipelineMainBase):
audio_filename: Path | None = None
source_language: str = "en"
target_language: str = "en"
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
@@ -125,96 +242,49 @@ class PipelineMainLive(PipelineMainBase):
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=self.on_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_short_summary
),
]
),
]
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
-# when the pipeline ends, connect to the post pipeline
-async def on_ended():
-task_pipeline_main_post.delay(transcript_id=self.transcript_id)
-pipeline.on_ended = self
return pipeline
async def on_ended(self):
# when the pipeline ends, connect to the post pipeline
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
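The final processors are wrapped in a BroadcastProcessor so that a single upstream value fans out to the title and both summaries in parallel. A minimal sketch of the idea (the real class in reflector.processors may differ):

```python
class BroadcastSketch:
    """Forward each pushed value to every child processor."""

    def __init__(self, processors):
        self.processors = processors

    async def push(self, data):
        for processor in self.processors:
            await processor.push(data)
```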
-class PipelineMainPost(PipelineMainBase):
+class PipelineMainDiarization(PipelineMainBase):
"""
Implement the rest of the main pipeline, triggered after PipelineMainLive ends.
Diarization is a long-running process, so we run it in a separate pipeline.
When done, regenerate the final title and summaries.
"""
-@broadcast_to_socket
-async def on_final_title(self, data):
-final_title = TranscriptFinalTitle(title=data.title)
-async with transcripts_controller.transaction():
-transcript = await self.get_transcript()
-if not transcript.title:
-transcripts_controller.update(
-self.transcript,
-{
-"title": final_title.title,
-},
-)
-return await transcripts_controller.append_event(
-transcript=transcript,
-event="FINAL_TITLE",
-data=final_title,
-)
-@broadcast_to_socket
-async def on_final_long_summary(self, data):
-final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
-async with transcripts_controller.transaction():
-transcript = await self.get_transcript()
-await transcripts_controller.update(
-transcript,
-{
-"long_summary": final_long_summary.long_summary,
-},
-)
-return await transcripts_controller.append_event(
-transcript=transcript,
-event="FINAL_LONG_SUMMARY",
-data=final_long_summary,
-)
-@broadcast_to_socket
-async def on_final_short_summary(self, data):
-final_short_summary = TranscriptFinalShortSummary(
-short_summary=data.short_summary
-)
-async with transcripts_controller.transaction():
-transcript = await self.get_transcript()
-await transcripts_controller.update(
-transcript,
-{
-"short_summary": final_short_summary.short_summary,
-},
-)
-return await transcripts_controller.append_event(
-transcript=transcript,
-event="FINAL_SHORT_SUMMARY",
-data=final_short_summary,
-)
async def create(self) -> Pipeline:
# create a context for the whole diarization run
# add a customised logger to the context
self.prepare()
processors = [
# add diarization
AudioDiarizationProcessor(),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(
callback=self.on_final_title
),
TranscriptFinalLongSummaryProcessor.as_threaded(
-callback=self.on_final_long_summary
+callback=self.on_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
-callback=self.on_final_short_summary
+callback=self.on_short_summary
),
]
),
@@ -222,9 +292,35 @@ class PipelineMainPost(PipelineMainBase):
pipeline = Pipeline(*processors)
pipeline.options = self
# now start the pipeline by pushing information to the
# first processor, the diarization processor
# XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript()
topics = [
TitleSummaryProcessorType(
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
duration=topic.duration,
transcript=TranscriptProcessorType(words=topic.words),
)
for topic in transcript.topics
]
audio_diarization_input = AudioDiarizationInput(
audio_filename=transcript.audio_mp3_filename,
topics=topics,
)
# as tempting as pipeline.push is, go through the runner
# so that start() does just one job.
self.push(audio_diarization_input)
self.flush()
return pipeline
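push and flush only enqueue commands (see the runner changes in the next file), so calling them inside create() is safe: the PUSH/FLUSH pair waits in the queue until the runner's run() loop starts draining it. The effective start-up order is roughly:

```python
# sketch of the assumed start-up order for the diarization runner
runner = PipelineMainDiarization(transcript_id="...")
# start_sync() -> run() -> create(), which queues PUSH + FLUSH,
# then the run loop drains the queue through the pipeline
runner.start_sync()
```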
-@celery.task
+@shared_task
def task_pipeline_main_post(transcript_id: str):
-pass
+runner = PipelineMainDiarization(transcript_id=transcript_id)
+runner.start_sync()
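start_sync is what lets a plain synchronous Celery task drive an asyncio pipeline: asyncio.run spins up a fresh event loop, runs the runner to completion, and tears the loop down. A self-contained sketch of the same bridge, assuming a prefork worker with no loop already running:

```python
import asyncio
from celery import shared_task

class SketchRunner:
    async def run(self):
        await asyncio.sleep(0)  # stand-in for the real pipeline work

    def start_sync(self):
        # safe in a prefork worker: no event loop is running yet
        asyncio.run(self.run())

@shared_task
def task_sketch():
    SketchRunner().start_sync()
```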

View File

@@ -16,7 +16,6 @@ During its lifecycle, it will emit the following status:
"""
import asyncio
-from typing import Callable
from pydantic import BaseModel, ConfigDict
from reflector.logger import logger
@@ -27,8 +26,6 @@ class PipelineRunner(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
status: str = "idle"
-on_status: Callable | None = None
-on_ended: Callable | None = None
pipeline: Pipeline | None = None
def __init__(self, **kwargs):
@@ -36,6 +33,10 @@ class PipelineRunner(BaseModel):
self._q_cmd = asyncio.Queue()
self._ev_done = asyncio.Event()
self._is_first_push = True
self._logger = logger.bind(
runner=id(self),
runner_cls=self.__class__.__name__,
)
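Binding the runner id and class name once means every later log line carries that context automatically. Assuming reflector.logger wraps a structlog-style logger:

```python
import structlog

logger = structlog.get_logger()
log = logger.bind(runner=4242, runner_cls="PipelineMainDiarization")
# each line from `log` now includes runner= and runner_cls=
log.debug("Runner status updated", status="push")
```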
def create(self) -> Pipeline:
"""
@@ -50,33 +51,51 @@ class PipelineRunner(BaseModel):
"""
asyncio.get_event_loop().create_task(self.run())
-async def push(self, data):
def start_sync(self):
"""
Start the pipeline synchronously (for non-asyncio apps)
"""
asyncio.run(self.run())
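The two entry points target the two halves of the system: start() schedules run() on an already-running loop (the web/RTC side), while start_sync() owns its own loop (the Celery side). Usage, with `runner` standing in for any PipelineRunner subclass:

```python
# inside an async app (a running event loop is required):
runner.start()
# inside a sync context such as a celery worker (blocks until done):
runner.start_sync()
```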
def push(self, data):
"""
Push data to the pipeline
"""
await self._add_cmd("PUSH", data)
self._add_cmd("PUSH", data)
-async def flush(self):
+def flush(self):
"""
Flush the pipeline
"""
await self._add_cmd("FLUSH", None)
self._add_cmd("FLUSH", None)
-async def _add_cmd(self, cmd: str, data):
async def on_status(self, status):
"""
Called when the status of the pipeline changes
"""
pass
async def on_ended(self):
"""
Called when the pipeline ends
"""
pass
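Replacing the on_status/on_ended Callable fields with overridable no-op coroutines turns the runner into a template-method base class: subclasses hook in by overriding, as PipelineMainBase does for on_status. A minimal hypothetical subclass:

```python
class LoggingRunner(PipelineRunner):
    # hypothetical subclass, for illustration only
    async def on_status(self, status):
        print(f"status -> {status}")

    async def on_ended(self):
        print("pipeline finished")
```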
def _add_cmd(self, cmd: str, data):
"""
Enqueue a command to be executed in the runner.
Currently supported commands: PUSH, FLUSH
"""
-await self._q_cmd.put([cmd, data])
+self._q_cmd.put_nowait([cmd, data])
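Because _add_cmd now uses put_nowait on an unbounded asyncio.Queue, the producers (push/flush) are plain synchronous methods while the consumer (the run() loop) stays async. The asymmetry in isolation:

```python
import asyncio

q: asyncio.Queue = asyncio.Queue()

def enqueue(cmd: str, data=None) -> None:
    # synchronous producer: put_nowait never blocks on an
    # unbounded queue
    q.put_nowait([cmd, data])

async def consume():
    # async consumer: only the run() loop awaits the queue
    cmd, data = await q.get()
    return cmd, data
```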
async def _set_status(self, status):
print("set_status", status)
self._logger.debug("Runner status updated", status=status)
self.status = status
if self.on_status:
try:
await self.on_status(status)
-except Exception as e:
-logger.error("PipelineRunner status_callback error", error=e)
+except Exception:
+self._logger.exception("Runner error while setting status")
async def run(self):
try:
@@ -95,8 +114,8 @@ class PipelineRunner(BaseModel):
await func(data)
else:
raise Exception(f"Unknown command {cmd}")
-except Exception as e:
-logger.error("PipelineRunner error", error=e)
+except Exception:
+self._logger.exception("Runner error")
await self._set_status("error")
self._ev_done.set()
if self.on_ended: