mirror of https://github.com/Monadical-SAS/reflector.git
server: first attempts to split post pipeline as single celery tasks
@@ -12,6 +12,7 @@ It is directly linked to our data model.
 """
 
 import asyncio
+import functools
 from contextlib import asynccontextmanager
 from datetime import timedelta
 from pathlib import Path
@@ -55,6 +56,22 @@ from reflector.processors.types import (
 from reflector.processors.types import Transcript as TranscriptProcessorType
 from reflector.settings import settings
 from reflector.ws_manager import WebsocketManager, get_ws_manager
+from structlog import Logger
+
+
+def asynctask(f):
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        coro = f(*args, **kwargs)
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+        if loop and loop.is_running():
+            return loop.run_until_complete(coro)
+        return asyncio.run(coro)
+
+    return wrapper
 
 
 def broadcast_to_sockets(func):
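
The `asynctask` decorator bridges Celery's synchronous task execution and the async pipeline functions: it builds the coroutine up front, then drives it to completion with `asyncio.run` when no event loop is active. A minimal usage sketch (the function name is invented for illustration); note that `loop.run_until_complete` raises `RuntimeError` on a loop that is already running, so the decorator is effectively only safe to call from synchronous contexts such as a Celery worker process:

```python
@asynctask
async def fetch_status() -> str:
    await asyncio.sleep(0)  # stand-in for real async I/O
    return "ok"

# plain synchronous call, e.g. from inside a Celery worker
assert fetch_status() == "ok"
```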
@@ -75,6 +92,22 @@ def broadcast_to_sockets(func):
     return wrapper
 
 
+def get_transcript(func):
+    """
+    Decorator that fetches the transcript named by the first argument
+    (a transcript id) from the database
+    """
+
+    @functools.wraps(func)
+    async def wrapper(transcript_id: str, **kwargs):
+        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
+        if not transcript:
+            raise Exception(f"Transcript {transcript_id} not found")
+        tlogger = logger.bind(transcript_id=transcript.id)
+        return await func(transcript=transcript, logger=tlogger, **kwargs)
+
+    return wrapper
 
 
 class StrValue(BaseModel):
     value: str
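
A sketch of how the decorator is meant to be driven (the wrapped function body is hypothetical): callers pass only a transcript id, and the wrapped coroutine receives the loaded `Transcript` plus a logger pre-bound with `transcript_id`.

```python
@get_transcript
async def pipeline_example(transcript: Transcript, logger: Logger):
    # hypothetical body: the transcript row is already loaded here
    logger.info("Processing transcript", topics=len(transcript.topics))

# elsewhere, only the id is needed:
#     await pipeline_example(transcript_id)
```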
@@ -99,6 +132,19 @@ class PipelineMainBase(PipelineRunner):
             raise Exception("Transcript not found")
         return result
 
+    def get_transcript_topics(
+        self, transcript: Transcript
+    ) -> list[TitleSummaryWithIdProcessorType]:
+        return [
+            TitleSummaryWithIdProcessorType(
+                id=topic.id,
+                title=topic.title,
+                summary=topic.summary,
+                timestamp=topic.timestamp,
+                duration=topic.duration,
+                transcript=TranscriptProcessorType(words=topic.words),
+            )
+            for topic in transcript.topics
+        ]
+
     @asynccontextmanager
     async def transaction(self):
         async with self._lock:
@@ -299,10 +345,7 @@ class PipelineMainLive(PipelineMainBase):
         pipeline.set_pref("audio:source_language", transcript.source_language)
         pipeline.set_pref("audio:target_language", transcript.target_language)
         pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main live created",
-            transcript_id=self.transcript_id,
-        )
+        pipeline.logger.info("Pipeline main live created")
 
         return pipeline
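
One caveat, assuming `pipeline.logger` is a structlog bound logger: `bind()` returns a new logger rather than mutating the receiver, so calling it without using the result discards the binding. For `transcript_id` to appear on subsequent log lines, the result would need to be reassigned, along these lines:

```python
# structlog bound loggers are immutable; bind() returns a new instance
pipeline.logger = pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info("Pipeline main live created")
```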
@@ -310,55 +353,28 @@
         # when the pipeline ends, connect to the post pipeline
-        logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
-        task_pipeline_main_post.delay(transcript_id=self.transcript_id)
+        logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
+        pipeline_post(transcript_id=self.transcript_id)
 
 
 class PipelineMainDiarization(PipelineMainBase):
     """
     Diarization is a long time process, so we do it in a separate pipeline
-    When done, adjust the short and final summary
+    Diarize the audio and update topics
     """
 
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
         self.prepare()
-        processors = []
-        if settings.DIARIZATION_ENABLED:
-            processors += [
-                AudioDiarizationAutoProcessor(callback=self.on_topic),
-            ]
-
-        processors += [
-            BroadcastProcessor(
-                processors=[
-                    TranscriptFinalLongSummaryProcessor.as_threaded(
-                        callback=self.on_long_summary
-                    ),
-                    TranscriptFinalShortSummaryProcessor.as_threaded(
-                        callback=self.on_short_summary
-                    ),
-                ]
-            ),
-        ]
-        pipeline = Pipeline(*processors)
+        pipeline = Pipeline(
+            AudioDiarizationAutoProcessor(callback=self.on_topic),
+        )
         pipeline.options = self
 
         # now let's start the pipeline by pushing information to the
-        # first processor
+        # diarization processor
         # XXX translation is lost when converting our data model to the processor model
         transcript = await self.get_transcript()
-        topics = [
-            TitleSummaryWithIdProcessorType(
-                id=topic.id,
-                title=topic.title,
-                summary=topic.summary,
-                timestamp=topic.timestamp,
-                duration=topic.duration,
-                transcript=TranscriptProcessorType(words=topic.words),
-            )
-            for topic in transcript.topics
-        ]
+        topics = self.get_transcript_topics(transcript)
 
         # we need to create an url to be used for diarization
         # we can't use the audio_mp3_filename because it's not accessible
@@ -386,15 +402,49 @@
         # as tempting to use pipeline.push, prefer to use the runner
         # to let the start just do one job.
         pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main post created", transcript_id=self.transcript_id
-        )
+        pipeline.logger.info("Diarization pipeline created")
         self.push(audio_diarization_input)
         self.flush()
 
         return pipeline
 
 
+class PipelineMainSummaries(PipelineMainBase):
+    """
+    Generate summaries from the topics
+    """
+
+    async def create(self) -> Pipeline:
+        self.prepare()
+        pipeline = Pipeline(
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalLongSummaryProcessor.as_threaded(
+                        callback=self.on_long_summary
+                    ),
+                    TranscriptFinalShortSummaryProcessor.as_threaded(
+                        callback=self.on_short_summary
+                    ),
+                ]
+            ),
+        )
+        pipeline.options = self
+
+        # get transcript
+        transcript = await self.get_transcript()
+        pipeline.logger.bind(transcript_id=transcript.id)
+        pipeline.logger.info("Summaries pipeline created")
+
+        # push topics
+        topics = self.get_transcript_topics(transcript)
+        for topic in topics:
+            self.push(topic)
+
+        self.flush()
+
+        return pipeline
+
+
 @shared_task
 def task_pipeline_main_post(transcript_id: str):
     logger.info(
@@ -403,3 +453,112 @@ def task_pipeline_main_post(transcript_id: str):
     )
     runner = PipelineMainDiarization(transcript_id=transcript_id)
     runner.start_sync()
+
+
+@get_transcript
+async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
+    logger.info("Starting convert to mp3")
+
+    # If the audio wav is not available, just skip
+    wav_filename = transcript.audio_wav_filename
+    if not wav_filename.exists():
+        logger.warning("Wav file not found, may be already converted")
+        return
+
+    # Convert to mp3
+    mp3_filename = transcript.audio_mp3_filename
+
+    import av
+
+    input_container = av.open(str(wav_filename))
+    output_container = av.open(str(mp3_filename), "w")
+    input_audio_stream = input_container.streams.audio[0]
+    # carry the input sample rate over to the mp3 encoder
+    output_audio_stream = output_container.add_stream(
+        "mp3", rate=input_audio_stream.codec_context.sample_rate
+    )
+    for packet in input_container.demux(input_audio_stream):
+        for frame in packet.decode():
+            # decoded frames must be re-encoded into mp3 packets before muxing
+            for out_packet in output_audio_stream.encode(frame):
+                output_container.mux(out_packet)
+    # flush any packets still buffered in the encoder
+    for out_packet in output_audio_stream.encode(None):
+        output_container.mux(out_packet)
+    input_container.close()
+    output_container.close()
+
+    logger.info("Convert to mp3 done")
+
+
+@get_transcript
+async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+    logger.info("Starting upload mp3")
+
+    # If the audio mp3 is not available, just skip
+    mp3_filename = transcript.audio_mp3_filename
+    if not mp3_filename.exists():
+        logger.warning("Mp3 file not found, may be already uploaded")
+        return
+
+    # Upload to external storage and delete the local file
+    await transcripts_controller.move_to_storage(transcript)
+    await transcripts_controller.unlink_mp3(transcript)
+
+    logger.info("Upload mp3 done")
+
+
+@get_transcript
+async def pipeline_diarization(transcript: Transcript, logger: Logger):
+    logger.info("Starting diarization")
+    runner = PipelineMainDiarization(transcript_id=transcript.id)
+    await runner.start()
+    logger.info("Diarization done")
+
+
+@get_transcript
+async def pipeline_summaries(transcript: Transcript, logger: Logger):
+    logger.info("Starting summaries")
+    runner = PipelineMainSummaries(transcript_id=transcript.id)
+    await runner.start()
+    logger.info("Summaries done")
+
+
+# ===================================================================
+# Celery tasks that can be called from the API
+# ===================================================================
+
+
+@shared_task
+@asynctask
+async def task_pipeline_convert_to_mp3(transcript_id: str):
+    await pipeline_convert_to_mp3(transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_upload_mp3(transcript_id: str):
+    await pipeline_upload_mp3(transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_diarization(transcript_id: str):
+    await pipeline_diarization(transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_summaries(transcript_id: str):
+    await pipeline_summaries(transcript_id)
+
+
+def pipeline_post(transcript_id: str):
+    """
+    Run the post pipeline
+    """
+    chain_mp3_and_diarize = (
+        task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_diarization.si(transcript_id=transcript_id)
+    )
+    chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id)
+    chain = chain_mp3_and_diarize | chain_summary
+    chain.delay()
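
`pipeline_post` composes the post pipeline out of Celery signatures: `.si()` builds an immutable signature, freezing `transcript_id` and ignoring the previous task's return value (unlike `.s()`, which prepends it), and `|` chains the signatures so each task is enqueued only after its predecessor succeeds. A self-contained sketch of the same pattern, with a hypothetical `step` task and broker URL:

```python
from celery import Celery

app = Celery(__name__, broker="redis://localhost:6379/0")  # hypothetical broker


@app.task
def step(name: str) -> str:
    print(f"running {name}")
    return name


# .si() freezes the arguments and ignores the parent result, so each step
# only ever sees its own argument; | composes the signatures into a chain.
workflow = (
    step.si("convert_to_mp3")
    | step.si("upload_mp3")
    | step.si("diarization")
    | step.si("summaries")
)
workflow.delay()  # enqueue the first step; each success triggers the next
```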