server: first attempts to split post pipeline as single celery tasks

This commit is contained in:
2023-11-15 21:24:21 +01:00
committed by Mathieu Virbel
parent 55a3a59d52
commit aecc3a0c3b
4 changed files with 241 additions and 48 deletions

View File

@@ -106,6 +106,7 @@ class Transcript(BaseModel):
events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
audio_location: str = "local"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -140,6 +141,10 @@ class Transcript(BaseModel):
def audio_waveform_filename(self):
return self.data_path / "audio.json"
@property
def storage_audio_path(self):
    """Return the storage path of the mp3 audio: ``<transcript id>/audio.mp3``."""
    return "{}/audio.mp3".format(self.id)
@property
def audio_waveform(self):
try:
@@ -283,5 +288,19 @@ class TranscriptController:
transcript.upsert_topic(topic)
await self.update(transcript, {"topics": transcript.topics_dump()})
async def move_mp3_to_storage(self, transcript: Transcript):
    """
    Move mp3 file to storage

    Uploads the transcript's local mp3 file to the storage backend selected
    by ``settings.TRANSCRIPT_STORAGE`` under ``transcript.storage_audio_path``
    and then records ``audio_location = "storage"`` on the transcript.

    The local file is not deleted here; removing it is left to the caller.
    """
    from reflector.storage import Storage

    storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE)
    await storage.put_file(
        transcript.storage_audio_path,
        # The mp3 path is a property of the transcript, not of the controller.
        transcript.audio_mp3_filename.read_bytes(),
    )

    await self.update(transcript, {"audio_location": "storage"})
# Module-level controller instance used to access and update transcripts.
transcripts_controller = TranscriptController()

View File

@@ -12,6 +12,7 @@ It is directly linked to our data model.
"""
import asyncio
import functools
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
@@ -55,6 +56,22 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
from structlog import Logger
def asynctask(f):
    """
    Decorator turning a coroutine function into a blocking, synchronous callable.

    The wrapped coroutine is always driven to completion with ``asyncio.run``
    and its result is returned (exceptions propagate to the caller):

    - when no event loop is running in the calling thread, it runs in place;
    - when a loop is already running in the calling thread, it runs in a
      short-lived worker thread with its own event loop, because
      ``loop.run_until_complete`` cannot be used on a loop that is already
      running (it raises ``RuntimeError``).

    Note that in the second case the caller's event loop is blocked until the
    coroutine finishes, as expected from a synchronous call.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        def run():
            # The coroutine object is created only where it is consumed, so a
            # failure beforehand never leaves an un-awaited coroutine behind.
            return asyncio.run(f(*args, **kwargs))

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: run in place.
            return run()

        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(run).result()

    return wrapper
def broadcast_to_sockets(func):
@@ -75,6 +92,22 @@ def broadcast_to_sockets(func):
return wrapper
def get_transcript(func):
    """
    Decorator to fetch the transcript from the database from its id

    The decorated coroutine is called with ``transcript`` (the loaded
    transcript) and ``logger`` (a logger bound to the transcript id) as
    keyword arguments instead of the ``transcript_id`` given by the caller.

    ``transcript_id`` is normally passed as a keyword argument; when the
    keyword is absent, the last positional argument is used instead. Any
    remaining positional arguments (e.g. ``self`` for a method) and keyword
    arguments are forwarded untouched.

    Raises:
        TypeError: if no transcript id is given at all.
        Exception: if no transcript exists for the given id.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        if "transcript_id" in kwargs:
            transcript_id = kwargs.pop("transcript_id")
        elif args:
            *args, transcript_id = args
        else:
            raise TypeError(
                f"{func.__name__}() missing required argument: 'transcript_id'"
            )

        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
        if not transcript:
            raise Exception(f"Transcript {transcript_id} not found")

        tlogger = logger.bind(transcript_id=transcript.id)
        return await func(*args, transcript=transcript, logger=tlogger, **kwargs)

    return wrapper
class StrValue(BaseModel):
value: str
@@ -99,6 +132,19 @@ class PipelineMainBase(PipelineRunner):
raise Exception("Transcript not found")
return result
def get_transcript_topics(
    self, transcript: Transcript
) -> list[TitleSummaryWithIdProcessorType]:
    """
    Convert the topics of a transcript to the processor data model

    Returns one ``TitleSummaryWithIdProcessorType`` per entry of
    ``transcript.topics``, in the same order, carrying over the id, title,
    summary, timestamp, duration and words of each topic.
    """
    return [
        TitleSummaryWithIdProcessorType(
            id=topic.id,
            title=topic.title,
            summary=topic.summary,
            timestamp=topic.timestamp,
            duration=topic.duration,
            transcript=TranscriptProcessorType(words=topic.words),
        )
        for topic in transcript.topics
    ]
@asynccontextmanager
async def transaction(self):
async with self._lock:
@@ -299,10 +345,7 @@ class PipelineMainLive(PipelineMainBase):
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(
"Pipeline main live created",
transcript_id=self.transcript_id,
)
pipeline.logger.info("Pipeline main live created")
return pipeline
@@ -310,55 +353,28 @@ class PipelineMainLive(PipelineMainBase):
# when the pipeline ends, connect to the post pipeline
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
pipeline_post(transcript_id=self.transcript_id)
class PipelineMainDiarization(PipelineMainBase):
"""
Diarization is a long time process, so we do it in a separate pipeline
When done, adjust the short and final summary
Diarize the audio and update topics
"""
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
processors = []
if settings.DIARIZATION_ENABLED:
processors += [
pipeline = Pipeline(
AudioDiarizationAutoProcessor(callback=self.on_topic),
]
processors += [
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=self.on_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_short_summary
),
]
),
]
pipeline = Pipeline(*processors)
)
pipeline.options = self
# now let's start the pipeline by pushing information to the
# first processor diarization processor
# XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript()
topics = [
TitleSummaryWithIdProcessorType(
id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
duration=topic.duration,
transcript=TranscriptProcessorType(words=topic.words),
)
for topic in transcript.topics
]
topics = self.get_transcript_topics(transcript)
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
@@ -386,15 +402,49 @@ class PipelineMainDiarization(PipelineMainBase):
# as tempting to use pipeline.push, prefer to use the runner
# to let the start just do one job.
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(
"Pipeline main post created", transcript_id=self.transcript_id
)
pipeline.logger.info("Diarization pipeline created")
self.push(audio_diarization_input)
self.flush()
return pipeline
class PipelineMainSummaries(PipelineMainBase):
    """
    Generate summaries from the topics
    """

    async def create(self) -> Pipeline:
        """
        Build the summaries pipeline and feed it the transcript topics

        The pipeline broadcasts every pushed topic to the long and the short
        final summary processors; their results are reported through
        ``self.on_long_summary`` and ``self.on_short_summary``.
        """
        self.prepare()
        pipeline = Pipeline(
            BroadcastProcessor(
                processors=[
                    TranscriptFinalLongSummaryProcessor.as_threaded(
                        callback=self.on_long_summary
                    ),
                    TranscriptFinalShortSummaryProcessor.as_threaded(
                        callback=self.on_short_summary
                    ),
                ]
            ),
        )
        pipeline.options = self

        # get transcript
        transcript = await self.get_transcript()
        # NOTE(review): the value returned by bind() is discarded; if this is
        # a structlog logger, the binding has no effect -- confirm.
        pipeline.logger.bind(transcript_id=transcript.id)
        pipeline.logger.info("Summaries pipeline created")

        # push topics
        # get_transcript_topics is a regular method returning a list: it must
        # not be awaited.
        topics = self.get_transcript_topics(transcript)
        for topic in topics:
            self.push(topic)

        self.flush()

        return pipeline
@shared_task
def task_pipeline_main_post(transcript_id: str):
logger.info(
@@ -403,3 +453,112 @@ def task_pipeline_main_post(transcript_id: str):
)
runner = PipelineMainDiarization(transcript_id=transcript_id)
runner.start_sync()
@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
    """
    Convert the wav audio of the transcript to mp3

    Does nothing (besides logging a warning) when the wav file does not
    exist. The conversion itself is synchronous and blocks until done.
    """
    logger.info("Starting convert to mp3")

    # If the audio wav is not available, just skip
    wav_filename = transcript.audio_wav_filename
    if not wav_filename.exists():
        logger.warning("Wav file not found, may be already converted")
        return

    # Convert to mp3
    mp3_filename = transcript.audio_mp3_filename

    import av

    # Context managers close both containers even if the conversion fails.
    with av.open(str(wav_filename)) as input_container:
        input_audio_stream = input_container.streams.audio[0]
        with av.open(str(mp3_filename), "w") as output_container:
            output_audio_stream = output_container.add_stream("mp3")
            # Decoded frames must be encoded into packets before muxing.
            for frame in input_container.decode(input_audio_stream):
                for packet in output_audio_stream.encode(frame):
                    output_container.mux(packet)
            # Flush the packets still buffered in the encoder.
            for packet in output_audio_stream.encode(None):
                output_container.mux(packet)

    logger.info("Convert to mp3 done")
@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
    """
    Upload the mp3 audio of the transcript to the external storage

    Does nothing (besides logging a warning) when the mp3 file does not
    exist. Once uploaded, the mp3 is unlinked through the controller.
    """
    logger.info("Starting upload mp3")

    # If the audio mp3 is not available, just skip
    mp3_filename = transcript.audio_mp3_filename
    if not mp3_filename.exists():
        logger.warning("Mp3 file not found, may be already uploaded")
        return

    # Upload to external storage and delete the file
    # (the controller method is named move_mp3_to_storage)
    await transcripts_controller.move_mp3_to_storage(transcript)
    await transcripts_controller.unlink_mp3(transcript)

    logger.info("Upload mp3 done")
@get_transcript
async def pipeline_diarization(transcript: Transcript, logger: Logger):
    """
    Run the diarization pipeline for the transcript

    Kept as a plain coroutine (no ``asynctask``): ``get_transcript`` awaits
    the decorated function, and the Celery task wrapper is the one that
    converts the call to a synchronous one.
    """
    logger.info("Starting diarization")
    runner = PipelineMainDiarization(transcript_id=transcript.id)
    await runner.start()
    logger.info("Diarization done")
@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger):
    """
    Run the summaries pipeline for the transcript

    Kept as a plain coroutine (no ``asynctask``): ``get_transcript`` awaits
    the decorated function, and the Celery task wrapper is the one that
    converts the call to a synchronous one.
    """
    logger.info("Starting summaries")
    runner = PipelineMainSummaries(transcript_id=transcript.id)
    await runner.start()
    logger.info("Summaries done")
# ===================================================================
# Celery tasks that can be called from the API
# ===================================================================
@shared_task
@asynctask
async def task_pipeline_convert_to_mp3(transcript_id: str):
    """Celery task: convert the wav audio of the transcript to mp3."""
    # get_transcript reads transcript_id from the keyword arguments
    await pipeline_convert_to_mp3(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_upload_mp3(transcript_id: str):
    """Celery task: upload the mp3 audio of the transcript to the storage."""
    # get_transcript reads transcript_id from the keyword arguments
    await pipeline_upload_mp3(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_diarization(transcript_id: str):
    """Celery task: run the diarization pipeline for the transcript."""
    # get_transcript reads transcript_id from the keyword arguments
    await pipeline_diarization(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_summaries(transcript_id: str):
    """Celery task: run the summaries pipeline for the transcript."""
    # get_transcript reads transcript_id from the keyword arguments
    await pipeline_summaries(transcript_id=transcript_id)
def pipeline_post(transcript_id: str):
    """
    Run the post pipeline

    Schedules, as one Celery chain of immutable signatures, the mp3
    conversion, the mp3 upload, the diarization and finally the summaries
    for the given transcript.
    """
    convert = task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
    upload = task_pipeline_upload_mp3.si(transcript_id=transcript_id)
    diarize = task_pipeline_diarization.si(transcript_id=transcript_id)
    summarize = task_pipeline_summaries.si(transcript_id=transcript_id)

    # audio preparation and diarization first, summaries last
    workflow = (convert | upload | diarize) | summarize
    workflow.delay()

View File

@@ -1,6 +1,7 @@
import importlib
from pydantic import BaseModel
from reflector.settings import settings
import importlib
class FileResult(BaseModel):
@@ -17,14 +18,14 @@ class Storage:
cls._registry[name] = kclass
@classmethod
def get_instance(cls, name, settings_prefix=""):
def get_instance(cls, name: str, settings_prefix: str = "", folder: str = ""):
if name not in cls._registry:
module_name = f"reflector.storage.storage_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
config = {}
config = {"folder": folder}
name_upper = name.upper()
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
@@ -34,6 +35,10 @@ class Storage:
return cls._registry[name](**config)
def __init__(self):
self.folder = ""
super().__init__()
async def put_file(self, filename: str, data: bytes) -> FileResult:
return await self._put_file(filename, data)

View File

@@ -1,6 +1,6 @@
import aioboto3
from reflector.storage.base import Storage, FileResult
from reflector.logger import logger
from reflector.storage.base import FileResult, Storage
class AwsStorage(Storage):
@@ -22,9 +22,14 @@ class AwsStorage(Storage):
super().__init__()
self.aws_bucket_name = aws_bucket_name
self.aws_folder = ""
folder = ""
if "/" in aws_bucket_name:
self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
self.aws_bucket_name, folder = aws_bucket_name.split("/", 1)
if folder:
if not self.folder:
self.folder = folder
else:
self.folder = f"{self.folder}/{folder}"
self.session = aioboto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
@@ -34,7 +39,7 @@ class AwsStorage(Storage):
async def _put_file(self, filename: str, data: bytes) -> FileResult:
bucket = self.aws_bucket_name
folder = self.aws_folder
folder = self.folder
logger.info(f"Uploading {filename} to S3 {bucket}/{folder}")
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
@@ -44,6 +49,11 @@ class AwsStorage(Storage):
Body=data,
)
async def get_file_url(self, filename: str) -> FileResult:
bucket = self.aws_bucket_name
folder = self.folder
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
presigned_url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": s3filename},
@@ -57,7 +67,7 @@ class AwsStorage(Storage):
async def _delete_file(self, filename: str):
bucket = self.aws_bucket_name
folder = self.aws_folder
folder = self.folder
logger.info(f"Deleting {filename} from S3 {bucket}/{folder}")
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client: