server: first attempts to split post pipeline as single celery tasks

2023-11-15 21:24:21 +01:00
committed by Mathieu Virbel
parent 55a3a59d52
commit aecc3a0c3b
4 changed files with 241 additions and 48 deletions
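The post-processing work (wav-to-mp3 conversion, upload to storage, diarization, summaries) is split into individual Celery tasks and composed with a chain of immutable signatures, so each step receives only the transcript_id rather than the previous task's return value. A minimal sketch of the composition pattern used below, with illustrative task names and assuming a configured Celery app:

from celery import shared_task

@shared_task
def convert_to_mp3(transcript_id: str):
    ...

@shared_task
def upload_mp3(transcript_id: str):
    ...

# .si() builds an immutable signature: the chain passes no results between
# steps, each task only gets the transcript_id it was bound with
chain = convert_to_mp3.si(transcript_id="abc") | upload_mp3.si(transcript_id="abc")
chain.delay()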

View File

@@ -106,6 +106,7 @@ class Transcript(BaseModel):
    events: list[TranscriptEvent] = []
    source_language: str = "en"
    target_language: str = "en"
    audio_location: str = "local"

    def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
        ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -140,6 +141,10 @@ class Transcript(BaseModel):
    def audio_waveform_filename(self):
        return self.data_path / "audio.json"

    @property
    def storage_audio_path(self):
        return f"{self.id}/audio.mp3"

    @property
    def audio_waveform(self):
        try:
@@ -283,5 +288,19 @@ class TranscriptController:
        transcript.upsert_topic(topic)
        await self.update(transcript, {"topics": transcript.topics_dump()})

    async def move_mp3_to_storage(self, transcript: Transcript):
        """
        Move the mp3 file to external storage
        """
        from reflector.storage import Storage

        storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE)
        await storage.put_file(
            transcript.storage_audio_path,
            transcript.audio_mp3_filename.read_bytes(),
        )
        await self.update(transcript, {"audio_location": "storage"})


transcripts_controller = TranscriptController()

View File

@@ -12,6 +12,7 @@ It is directly linked to our data model.
""" """
import asyncio import asyncio
import functools
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
@@ -55,6 +56,22 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
from structlog import Logger


def asynctask(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        coro = f(*args, **kwargs)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        if loop and loop.is_running():
            return loop.run_until_complete(coro)
        return asyncio.run(coro)

    return wrapper


def broadcast_to_sockets(func):
@@ -75,6 +92,22 @@ def broadcast_to_sockets(func):
    return wrapper


def get_transcript(func):
    """
    Decorator that fetches the transcript from the database using the
    transcript_id passed as first argument
    """

    @functools.wraps(func)
    async def wrapper(transcript_id: str, **kwargs):
        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
        if not transcript:
            raise Exception(f"Transcript {transcript_id} not found")
        tlogger = logger.bind(transcript_id=transcript.id)
        return await func(transcript=transcript, logger=tlogger, **kwargs)

    return wrapper


class StrValue(BaseModel):
    value: str
@@ -99,6 +132,19 @@ class PipelineMainBase(PipelineRunner):
            raise Exception("Transcript not found")
        return result

    def get_transcript_topics(
        self, transcript: Transcript
    ) -> list[TitleSummaryWithIdProcessorType]:
        return [
            TitleSummaryWithIdProcessorType(
                id=topic.id,
                title=topic.title,
                summary=topic.summary,
                timestamp=topic.timestamp,
                duration=topic.duration,
                transcript=TranscriptProcessorType(words=topic.words),
            )
            for topic in transcript.topics
        ]

    @asynccontextmanager
    async def transaction(self):
        async with self._lock:
@@ -299,10 +345,7 @@ class PipelineMainLive(PipelineMainBase):
        pipeline.set_pref("audio:source_language", transcript.source_language)
        pipeline.set_pref("audio:target_language", transcript.target_language)
        pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main live created",
-            transcript_id=self.transcript_id,
-        )
        pipeline.logger.info("Pipeline main live created")
        return pipeline
@@ -310,55 +353,28 @@ class PipelineMainLive(PipelineMainBase):
        # when the pipeline ends, connect to the post pipeline
        logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
        logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
-        task_pipeline_main_post.delay(transcript_id=self.transcript_id)
        pipeline_post(transcript_id=self.transcript_id)


class PipelineMainDiarization(PipelineMainBase):
    """
-    Diarization is a long time process, so we do it in a separate pipeline
-    When done, adjust the short and final summary
    Diarize the audio and update topics
    """

    async def create(self) -> Pipeline:
        # create a context for the whole rtc transaction
        # add a customised logger to the context
        self.prepare()
-        processors = []
-        if settings.DIARIZATION_ENABLED:
-            processors += [
-                AudioDiarizationAutoProcessor(callback=self.on_topic),
-            ]
-        processors += [
-            BroadcastProcessor(
-                processors=[
-                    TranscriptFinalLongSummaryProcessor.as_threaded(
-                        callback=self.on_long_summary
-                    ),
-                    TranscriptFinalShortSummaryProcessor.as_threaded(
-                        callback=self.on_short_summary
-                    ),
-                ]
-            ),
-        ]
-        pipeline = Pipeline(*processors)
        pipeline = Pipeline(
            AudioDiarizationAutoProcessor(callback=self.on_topic),
        )
        pipeline.options = self

        # now let's start the pipeline by pushing information to the
-        # first processor
        # diarization processor
        # XXX translation is lost when converting our data model to the processor model
        transcript = await self.get_transcript()
-        topics = [
-            TitleSummaryWithIdProcessorType(
-                id=topic.id,
-                title=topic.title,
-                summary=topic.summary,
-                timestamp=topic.timestamp,
-                duration=topic.duration,
-                transcript=TranscriptProcessorType(words=topic.words),
-            )
-            for topic in transcript.topics
-        ]
        topics = self.get_transcript_topics(transcript)

        # we need to create an url to be used for diarization
        # we can't use the audio_mp3_filename because it's not accessible
@@ -386,15 +402,49 @@ class PipelineMainDiarization(PipelineMainBase):
        # as tempting to use pipeline.push, prefer to use the runner
        # to let the start just do one job.
        pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main post created", transcript_id=self.transcript_id
-        )
        pipeline.logger.info("Diarization pipeline created")
        self.push(audio_diarization_input)
        self.flush()
        return pipeline


class PipelineMainSummaries(PipelineMainBase):
    """
    Generate summaries from the topics
    """

    async def create(self) -> Pipeline:
        self.prepare()
        pipeline = Pipeline(
            BroadcastProcessor(
                processors=[
                    TranscriptFinalLongSummaryProcessor.as_threaded(
                        callback=self.on_long_summary
                    ),
                    TranscriptFinalShortSummaryProcessor.as_threaded(
                        callback=self.on_short_summary
                    ),
                ]
            ),
        )
        pipeline.options = self

        # get transcript
        transcript = await self.get_transcript()
        pipeline.logger.bind(transcript_id=transcript.id)
        pipeline.logger.info("Summaries pipeline created")

        # push topics
        topics = self.get_transcript_topics(transcript)
        for topic in topics:
            self.push(topic)
        self.flush()
        return pipeline


@shared_task
def task_pipeline_main_post(transcript_id: str):
    logger.info(
@@ -403,3 +453,112 @@ def task_pipeline_main_post(transcript_id: str):
    )
    runner = PipelineMainDiarization(transcript_id=transcript_id)
    runner.start_sync()


@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
    logger.info("Starting convert to mp3")

    # If the audio wav is not available, just skip
    wav_filename = transcript.audio_wav_filename
    if not wav_filename.exists():
        logger.warning("Wav file not found, may be already converted")
        return

    # Convert to mp3
    mp3_filename = transcript.audio_mp3_filename
    import av

    input_container = av.open(wav_filename)
    output_container = av.open(mp3_filename, "w")
    input_audio_stream = input_container.streams.audio[0]
    output_audio_stream = output_container.add_stream("mp3")
    output_audio_stream.codec_context.set_parameters(
        input_audio_stream.codec_context.parameters
    )
    for packet in input_container.demux(input_audio_stream):
        for frame in packet.decode():
            for out_packet in output_audio_stream.encode(frame):
                output_container.mux(out_packet)
    # flush the encoder before closing the containers
    for out_packet in output_audio_stream.encode():
        output_container.mux(out_packet)
    input_container.close()
    output_container.close()
    logger.info("Convert to mp3 done")


@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
    logger.info("Starting upload mp3")

    # If the audio mp3 is not available, just skip
    mp3_filename = transcript.audio_mp3_filename
    if not mp3_filename.exists():
        logger.warning("Mp3 file not found, may be already uploaded")
        return

    # Upload to external storage and delete the local file
    await transcripts_controller.move_mp3_to_storage(transcript)
    await transcripts_controller.unlink_mp3(transcript)
    logger.info("Upload mp3 done")


@get_transcript
@asynctask
async def pipeline_diarization(transcript: Transcript, logger: Logger):
    logger.info("Starting diarization")
    runner = PipelineMainDiarization(transcript_id=transcript.id)
    await runner.start()
    logger.info("Diarization done")


@get_transcript
@asynctask
async def pipeline_summaries(transcript: Transcript, logger: Logger):
    logger.info("Starting summaries")
    runner = PipelineMainSummaries(transcript_id=transcript.id)
    await runner.start()
    logger.info("Summaries done")


# ===================================================================
# Celery tasks that can be called from the API
# ===================================================================


@shared_task
@asynctask
async def task_pipeline_convert_to_mp3(transcript_id: str):
    await pipeline_convert_to_mp3(transcript_id)


@shared_task
@asynctask
async def task_pipeline_upload_mp3(transcript_id: str):
    await pipeline_upload_mp3(transcript_id)


@shared_task
@asynctask
async def task_pipeline_diarization(transcript_id: str):
    await pipeline_diarization(transcript_id)


@shared_task
@asynctask
async def task_pipeline_summaries(transcript_id: str):
    await pipeline_summaries(transcript_id)


def pipeline_post(transcript_id: str):
    """
    Run the post pipeline
    """
    chain_mp3_and_diarize = (
        task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
        | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
        | task_pipeline_diarization.si(transcript_id=transcript_id)
    )
    chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id)
    chain = chain_mp3_and_diarize | chain_summary
    chain.delay()

View File

@@ -1,6 +1,7 @@
import importlib

from pydantic import BaseModel

from reflector.settings import settings
-import importlib


class FileResult(BaseModel):
@@ -17,14 +18,14 @@ class Storage:
        cls._registry[name] = kclass

    @classmethod
-    def get_instance(cls, name, settings_prefix=""):
    def get_instance(cls, name: str, settings_prefix: str = "", folder: str = ""):
        if name not in cls._registry:
            module_name = f"reflector.storage.storage_{name}"
            importlib.import_module(module_name)

        # gather specific configuration for the processor
        # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
-        config = {}
        config = {"folder": folder}
        name_upper = name.upper()
        config_prefix = f"{settings_prefix}{name_upper}_"
        for key, value in settings:
@@ -34,6 +35,10 @@ class Storage:
        return cls._registry[name](**config)

    def __init__(self):
        self.folder = ""
        super().__init__()

    async def put_file(self, filename: str, data: bytes) -> FileResult:
        return await self._put_file(filename, data)

View File

@@ -1,6 +1,6 @@
import aioboto3

-from reflector.storage.base import Storage, FileResult
from reflector.logger import logger
from reflector.storage.base import FileResult, Storage


class AwsStorage(Storage):
@@ -22,9 +22,14 @@ class AwsStorage(Storage):
        super().__init__()
        self.aws_bucket_name = aws_bucket_name
-        self.aws_folder = ""
        folder = ""
        if "/" in aws_bucket_name:
-            self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
            self.aws_bucket_name, folder = aws_bucket_name.split("/", 1)
        if folder:
            if not self.folder:
                self.folder = folder
            else:
                self.folder = f"{self.folder}/{folder}"
        self.session = aioboto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
@@ -34,7 +39,7 @@ class AwsStorage(Storage):
    async def _put_file(self, filename: str, data: bytes) -> FileResult:
        bucket = self.aws_bucket_name
-        folder = self.aws_folder
        folder = self.folder
        logger.info(f"Uploading {filename} to S3 {bucket}/{folder}")
        s3filename = f"{folder}/{filename}" if folder else filename
        async with self.session.client("s3") as client:
@@ -44,6 +49,11 @@ class AwsStorage(Storage):
                Body=data,
            )

    async def get_file_url(self, filename: str) -> FileResult:
        bucket = self.aws_bucket_name
        folder = self.folder
        s3filename = f"{folder}/{filename}" if folder else filename
        async with self.session.client("s3") as client:
            presigned_url = await client.generate_presigned_url(
                "get_object",
                Params={"Bucket": bucket, "Key": s3filename},
@@ -57,7 +67,7 @@ class AwsStorage(Storage):
    async def _delete_file(self, filename: str):
        bucket = self.aws_bucket_name
-        folder = self.aws_folder
        folder = self.folder
        logger.info(f"Deleting {filename} from S3 {bucket}/{folder}")
        s3filename = f"{folder}/{filename}" if folder else filename
        async with self.session.client("s3") as client: