mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
server: first attempts to split post pipeline as single celery tasks
This commit is contained in:
@@ -106,6 +106,7 @@ class Transcript(BaseModel):
|
||||
events: list[TranscriptEvent] = []
|
||||
source_language: str = "en"
|
||||
target_language: str = "en"
|
||||
audio_location: str = "local"
|
||||
|
||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||
@@ -140,6 +141,10 @@ class Transcript(BaseModel):
|
||||
def audio_waveform_filename(self):
|
||||
return self.data_path / "audio.json"
|
||||
|
||||
@property
def storage_audio_path(self):
    """Return the storage key of this transcript's mp3: ``<id>/audio.mp3``."""
    key_template = "{}/audio.mp3"
    return key_template.format(self.id)
|
||||
|
||||
@property
|
||||
def audio_waveform(self):
|
||||
try:
|
||||
@@ -283,5 +288,19 @@ class TranscriptController:
|
||||
transcript.upsert_topic(topic)
|
||||
await self.update(transcript, {"topics": transcript.topics_dump()})
|
||||
|
||||
async def move_mp3_to_storage(self, transcript: Transcript):
    """
    Upload the transcript's mp3 to external storage and record the new location.

    The file at ``transcript.audio_mp3_filename`` is read fully into memory
    and stored under ``transcript.storage_audio_path`` in the storage backend
    selected by ``settings.TRANSCRIPT_STORAGE``. The transcript's
    ``audio_location`` is then updated to ``"storage"``.

    The local mp3 file is not removed by this method.
    """
    # imported inside the method, as in the original code
    from reflector.storage import Storage

    storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE)
    await storage.put_file(
        transcript.storage_audio_path,
        # the mp3 path belongs to the transcript, not to the controller
        transcript.audio_mp3_filename.read_bytes(),
    )

    # only flag the audio as stored once the upload call has returned
    await self.update(transcript, {"audio_location": "storage"})
|
||||
|
||||
|
||||
transcripts_controller = TranscriptController()
|
||||
|
||||
@@ -12,6 +12,7 @@ It is directly linked to our data model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
@@ -55,6 +56,22 @@ from reflector.processors.types import (
|
||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||
from reflector.settings import settings
|
||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||
from structlog import Logger
|
||||
|
||||
|
||||
def asynctask(f):
    """
    Wrap the coroutine function ``f`` so it can be called synchronously.

    The returned wrapper runs the coroutine to completion and returns its
    result (or re-raises its exception):

    - when no event loop is running in the calling thread, the coroutine is
      executed with ``asyncio.run``;
    - when a loop is already running in the calling thread, the coroutine is
      executed with ``asyncio.run`` in a short-lived helper thread, and the
      caller blocks until it finishes. ``loop.run_until_complete`` cannot be
      used in that situation: it raises ``RuntimeError`` on a running loop.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # no loop in this thread: we can own one for the duration of the call
            return asyncio.run(f(*args, **kwargs))

        # A loop is already running here; give the coroutine its own loop in
        # another thread and wait for the outcome.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, f(*args, **kwargs)).result()

    return wrapper
|
||||
|
||||
|
||||
def broadcast_to_sockets(func):
|
||||
@@ -75,6 +92,22 @@ def broadcast_to_sockets(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_transcript(func):
    """
    Decorator that resolves a transcript id into a transcript object.

    The wrapped coroutine is called with ``transcript=<Transcript>`` and
    ``logger=<logger bound to the transcript id>`` instead of the id.

    The id is taken from the ``transcript_id`` keyword argument, or from the
    first positional argument when the keyword is absent. Any other
    positional or keyword arguments are forwarded unchanged.

    Raises:
        TypeError: if no transcript id was passed.
        Exception: if no transcript exists for the given id.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        if "transcript_id" in kwargs:
            transcript_id = kwargs.pop("transcript_id")
        elif args:
            # allow ``pipeline_xxx(transcript_id)`` as well as the keyword form
            transcript_id, *args = args
        else:
            raise TypeError(
                f"{func.__name__}() missing required argument: 'transcript_id'"
            )

        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
        if not transcript:
            raise Exception(f"Transcript {transcript_id} not found")

        tlogger = logger.bind(transcript_id=transcript.id)
        return await func(*args, transcript=transcript, logger=tlogger, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
class StrValue(BaseModel):
|
||||
value: str
|
||||
|
||||
@@ -99,6 +132,19 @@ class PipelineMainBase(PipelineRunner):
|
||||
raise Exception("Transcript not found")
|
||||
return result
|
||||
|
||||
def get_transcript_topics(
    self, transcript: Transcript
) -> list[TitleSummaryWithIdProcessorType]:
    """
    Convert the topics stored on ``transcript`` to the processor data model.

    Each topic becomes a ``TitleSummaryWithIdProcessorType`` carrying the
    topic's id, title, summary, timestamp and duration, with its words wrapped
    in a ``TranscriptProcessorType``. Only those fields are copied.
    """
    return [
        TitleSummaryWithIdProcessorType(
            id=topic.id,
            title=topic.title,
            summary=topic.summary,
            timestamp=topic.timestamp,
            duration=topic.duration,
            transcript=TranscriptProcessorType(words=topic.words),
        )
        for topic in transcript.topics
    ]
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
async with self._lock:
|
||||
@@ -299,10 +345,7 @@ class PipelineMainLive(PipelineMainBase):
|
||||
pipeline.set_pref("audio:source_language", transcript.source_language)
|
||||
pipeline.set_pref("audio:target_language", transcript.target_language)
|
||||
pipeline.logger.bind(transcript_id=transcript.id)
|
||||
pipeline.logger.info(
|
||||
"Pipeline main live created",
|
||||
transcript_id=self.transcript_id,
|
||||
)
|
||||
pipeline.logger.info("Pipeline main live created")
|
||||
|
||||
return pipeline
|
||||
|
||||
@@ -310,55 +353,28 @@ class PipelineMainLive(PipelineMainBase):
|
||||
# when the pipeline ends, connect to the post pipeline
|
||||
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
|
||||
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
|
||||
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
|
||||
pipeline_post(transcript_id=self.transcript_id)
|
||||
|
||||
|
||||
class PipelineMainDiarization(PipelineMainBase):
|
||||
"""
|
||||
Diarization is a long time process, so we do it in a separate pipeline
|
||||
When done, adjust the short and final summary
|
||||
Diarize the audio and update topics
|
||||
"""
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
processors = []
|
||||
if settings.DIARIZATION_ENABLED:
|
||||
processors += [
|
||||
pipeline = Pipeline(
|
||||
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||
]
|
||||
|
||||
processors += [
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||
callback=self.on_long_summary
|
||||
),
|
||||
TranscriptFinalShortSummaryProcessor.as_threaded(
|
||||
callback=self.on_short_summary
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
pipeline = Pipeline(*processors)
|
||||
)
|
||||
pipeline.options = self
|
||||
|
||||
# now let's start the pipeline by pushing information to the
|
||||
# first processor diarization processor
|
||||
# XXX translation is lost when converting our data model to the processor model
|
||||
transcript = await self.get_transcript()
|
||||
topics = [
|
||||
TitleSummaryWithIdProcessorType(
|
||||
id=topic.id,
|
||||
title=topic.title,
|
||||
summary=topic.summary,
|
||||
timestamp=topic.timestamp,
|
||||
duration=topic.duration,
|
||||
transcript=TranscriptProcessorType(words=topic.words),
|
||||
)
|
||||
for topic in transcript.topics
|
||||
]
|
||||
topics = self.get_transcript_topics(transcript)
|
||||
|
||||
# we need to create an url to be used for diarization
|
||||
# we can't use the audio_mp3_filename because it's not accessible
|
||||
@@ -386,15 +402,49 @@ class PipelineMainDiarization(PipelineMainBase):
|
||||
# as tempting to use pipeline.push, prefer to use the runner
|
||||
# to let the start just do one job.
|
||||
pipeline.logger.bind(transcript_id=transcript.id)
|
||||
pipeline.logger.info(
|
||||
"Pipeline main post created", transcript_id=self.transcript_id
|
||||
)
|
||||
pipeline.logger.info("Diarization pipeline created")
|
||||
self.push(audio_diarization_input)
|
||||
self.flush()
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
class PipelineMainSummaries(PipelineMainBase):
    """
    Generate summaries from the topics
    """

    async def create(self) -> Pipeline:
        """
        Build the summaries pipeline and feed it the transcript's topics.

        The pipeline is a single broadcast stage running the long and the
        short final-summary processors, whose results are reported through
        ``self.on_long_summary`` and ``self.on_short_summary``.
        """
        self.prepare()
        pipeline = Pipeline(
            BroadcastProcessor(
                processors=[
                    TranscriptFinalLongSummaryProcessor.as_threaded(
                        callback=self.on_long_summary
                    ),
                    TranscriptFinalShortSummaryProcessor.as_threaded(
                        callback=self.on_short_summary
                    ),
                ]
            ),
        )
        pipeline.options = self

        # get transcript
        transcript = await self.get_transcript()
        # NOTE(review): the value returned by ``bind`` is discarded here;
        # confirm whether the bound logger is meant to be kept.
        pipeline.logger.bind(transcript_id=transcript.id)
        pipeline.logger.info("Summaries pipeline created")

        # push topics
        # ``get_transcript_topics`` is a plain method returning a list, so it
        # must not be awaited.
        topics = self.get_transcript_topics(transcript)
        for topic in topics:
            self.push(topic)

        self.flush()

        return pipeline
|
||||
|
||||
|
||||
@shared_task
|
||||
def task_pipeline_main_post(transcript_id: str):
|
||||
logger.info(
|
||||
@@ -403,3 +453,112 @@ def task_pipeline_main_post(transcript_id: str):
|
||||
)
|
||||
runner = PipelineMainDiarization(transcript_id=transcript_id)
|
||||
runner.start_sync()
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
    """
    Convert the transcript's wav recording to mp3.

    Reads ``transcript.audio_wav_filename`` and writes
    ``transcript.audio_mp3_filename``. Does nothing, apart from logging a
    warning, when the wav file does not exist. The wav file is left in place.
    """
    logger.info("Starting convert to mp3")

    # If the audio wav is not available, just skip
    wav_filename = transcript.audio_wav_filename
    if not wav_filename.exists():
        logger.warning("Wav file not found, may be already converted")
        return

    # Convert to mp3
    mp3_filename = transcript.audio_mp3_filename

    import av

    # context managers close both containers even if decoding/encoding fails
    with av.open(str(wav_filename)) as input_container:
        input_audio_stream = input_container.streams.audio[0]
        with av.open(str(mp3_filename), "w") as output_container:
            output_audio_stream = output_container.add_stream("mp3")
            for frame in input_container.decode(input_audio_stream):
                # frames must go through the encoder; only packets can be muxed
                for packet in output_audio_stream.encode(frame):
                    output_container.mux(packet)
            # flush the frames still buffered inside the encoder
            for packet in output_audio_stream.encode(None):
                output_container.mux(packet)

    logger.info("Convert to mp3 done")
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
    """
    Upload the transcript's mp3 to external storage, then unlink the local file.

    Does nothing, apart from logging a warning, when the mp3 file does not
    exist locally.
    """
    logger.info("Starting upload mp3")

    # If the audio mp3 is not available, just skip
    mp3_filename = transcript.audio_mp3_filename
    if not mp3_filename.exists():
        logger.warning("Mp3 file not found, may be already uploaded")
        return

    # Upload to external storage and delete the file
    # (the controller method is named ``move_mp3_to_storage``)
    await transcripts_controller.move_mp3_to_storage(transcript)
    # NOTE(review): ``unlink_mp3`` is not visible in this diff; confirm that
    # the controller provides it.
    await transcripts_controller.unlink_mp3(transcript)

    logger.info("Upload mp3 done")
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_diarization(transcript: Transcript, logger: Logger):
    """
    Run the diarization pipeline for ``transcript``.

    This stays a coroutine function: ``get_transcript`` awaits the function
    it wraps, so it must not be turned into a synchronous callable.
    """
    logger.info("Starting diarization")
    runner = PipelineMainDiarization(transcript_id=transcript.id)
    await runner.start()
    logger.info("Diarization done")
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger):
    """
    Run the summaries pipeline for ``transcript``.

    This stays a coroutine function: ``get_transcript`` awaits the function
    it wraps, so it must not be turned into a synchronous callable.
    """
    logger.info("Starting summaries")
    runner = PipelineMainSummaries(transcript_id=transcript.id)
    await runner.start()
    logger.info("Summaries done")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Celery tasks that can be called from the API
|
||||
# ===================================================================
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_convert_to_mp3(transcript_id: str):
    """Celery task: convert the wav audio of the given transcript to mp3."""
    # the id is passed by keyword: ``get_transcript`` looks up ``transcript_id``
    # among the keyword arguments
    await pipeline_convert_to_mp3(transcript_id=transcript_id)
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_upload_mp3(transcript_id: str):
    """Celery task: upload the mp3 audio of the given transcript to storage."""
    # the id is passed by keyword: ``get_transcript`` looks up ``transcript_id``
    # among the keyword arguments
    await pipeline_upload_mp3(transcript_id=transcript_id)
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_diarization(transcript_id: str):
    """Celery task: run the diarization pipeline for the given transcript."""
    # the id is passed by keyword: ``get_transcript`` looks up ``transcript_id``
    # among the keyword arguments
    await pipeline_diarization(transcript_id=transcript_id)
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_summaries(transcript_id: str):
    """Celery task: run the summaries pipeline for the given transcript."""
    # the id is passed by keyword: ``get_transcript`` looks up ``transcript_id``
    # among the keyword arguments
    await pipeline_summaries(transcript_id=transcript_id)
|
||||
|
||||
|
||||
def pipeline_post(transcript_id: str):
    """
    Schedule the post pipeline for a transcript.

    The Celery tasks are chained with immutable signatures, in this order:
    convert to mp3, upload mp3, diarization, summaries.
    """
    # mp3 conversion, upload and diarization
    mp3_and_diarize = task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
    for step in (task_pipeline_upload_mp3, task_pipeline_diarization):
        mp3_and_diarize = mp3_and_diarize | step.si(transcript_id=transcript_id)

    # summaries come last
    summaries = task_pipeline_summaries.si(transcript_id=transcript_id)

    (mp3_and_diarize | summaries).delay()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import importlib
|
||||
|
||||
from pydantic import BaseModel
|
||||
from reflector.settings import settings
|
||||
import importlib
|
||||
|
||||
|
||||
class FileResult(BaseModel):
|
||||
@@ -17,14 +18,14 @@ class Storage:
|
||||
cls._registry[name] = kclass
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, name, settings_prefix=""):
|
||||
def get_instance(cls, name: str, settings_prefix: str = "", folder: str = ""):
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.storage.storage_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
config = {"folder": folder}
|
||||
name_upper = name.upper()
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
@@ -34,6 +35,10 @@ class Storage:
|
||||
|
||||
return cls._registry[name](**config)
|
||||
|
||||
def __init__(self, folder: str = ""):
    """
    Initialise the storage.

    Args:
        folder: optional folder stored on ``self.folder``. Defaults to ``""``,
            so existing ``super().__init__()`` calls behave as before.
    """
    self.folder = folder
    super().__init__()
|
||||
|
||||
async def put_file(self, filename: str, data: bytes) -> FileResult:
    """Store ``data`` under ``filename`` by delegating to ``self._put_file``."""
    result = await self._put_file(filename, data)
    return result
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import aioboto3
|
||||
from reflector.storage.base import Storage, FileResult
|
||||
from reflector.logger import logger
|
||||
from reflector.storage.base import FileResult, Storage
|
||||
|
||||
|
||||
class AwsStorage(Storage):
|
||||
@@ -22,9 +22,14 @@ class AwsStorage(Storage):
|
||||
|
||||
super().__init__()
|
||||
self.aws_bucket_name = aws_bucket_name
|
||||
self.aws_folder = ""
|
||||
folder = ""
|
||||
if "/" in aws_bucket_name:
|
||||
self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
|
||||
self.aws_bucket_name, folder = aws_bucket_name.split("/", 1)
|
||||
if folder:
|
||||
if not self.folder:
|
||||
self.folder = folder
|
||||
else:
|
||||
self.folder = f"{self.folder}/{folder}"
|
||||
self.session = aioboto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
@@ -34,7 +39,7 @@ class AwsStorage(Storage):
|
||||
|
||||
async def _put_file(self, filename: str, data: bytes) -> FileResult:
|
||||
bucket = self.aws_bucket_name
|
||||
folder = self.aws_folder
|
||||
folder = self.folder
|
||||
logger.info(f"Uploading {filename} to S3 {bucket}/{folder}")
|
||||
s3filename = f"{folder}/{filename}" if folder else filename
|
||||
async with self.session.client("s3") as client:
|
||||
@@ -44,6 +49,11 @@ class AwsStorage(Storage):
|
||||
Body=data,
|
||||
)
|
||||
|
||||
async def get_file_url(self, filename: str) -> FileResult:
|
||||
bucket = self.aws_bucket_name
|
||||
folder = self.folder
|
||||
s3filename = f"{folder}/{filename}" if folder else filename
|
||||
async with self.session.client("s3") as client:
|
||||
presigned_url = await client.generate_presigned_url(
|
||||
"get_object",
|
||||
Params={"Bucket": bucket, "Key": s3filename},
|
||||
@@ -57,7 +67,7 @@ class AwsStorage(Storage):
|
||||
|
||||
async def _delete_file(self, filename: str):
|
||||
bucket = self.aws_bucket_name
|
||||
folder = self.aws_folder
|
||||
folder = self.folder
|
||||
logger.info(f"Deleting {filename} from S3 {bucket}/{folder}")
|
||||
s3filename = f"{folder}/{filename}" if folder else filename
|
||||
async with self.session.client("s3") as client:
|
||||
|
||||
Reference in New Issue
Block a user