Add multitrack pipeline
server/reflector/pipelines/main_multitrack_pipeline.py (new file, 292 lines)
@@ -0,0 +1,292 @@
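"""Multitrack pipeline: transcribe per-participant audio tracks separately and
merge the words by timestamp into one speaker-labeled transcript, without
mixing the audio."""
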
import asyncio

import boto3
import structlog
from celery import chain, shared_task

from reflector.asynctask import asynctask
from reflector.db.transcripts import (
    TranscriptStatus,
    TranscriptText,
    transcripts_controller,
)
from reflector.logger import logger
from reflector.pipelines.main_file_pipeline import task_send_webhook_if_needed
from reflector.pipelines.main_live_pipeline import (
    PipelineMainBase,
    task_cleanup_consent,
    task_pipeline_post_to_zulip,
)
from reflector.processors import (
    TranscriptFinalSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptTopicDetectorProcessor,
)
from reflector.processors.file_transcript import FileTranscriptInput
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.types import TitleSummary
from reflector.processors.types import (
    Transcript as TranscriptType,
)
from reflector.settings import settings
from reflector.storage import get_transcripts_storage


# Minimal stand-in for a processor pipeline: it supplies the get_pref()/emit()
# hooks processors expect when they run outside a full pipeline.
class EmptyPipeline:
    def __init__(self, logger: structlog.BoundLogger):
        self.logger = logger

    def get_pref(self, k, d=None):
        return d

    async def emit(self, event):
        pass

class PipelineMainMultitrack(PipelineMainBase):
    """Process multiple participant tracks for a transcript without mixing audio."""

    def __init__(self, transcript_id: str):
        super().__init__(transcript_id=transcript_id)
        self.logger = logger.bind(transcript_id=self.transcript_id)
        self.empty_pipeline = EmptyPipeline(logger=self.logger)

    async def set_status(self, transcript_id: str, status: TranscriptStatus):
        async with self.lock_transaction():
            return await transcripts_controller.set_status(transcript_id, status)

    async def _list_immediate_keys(
        self, s3, bucket_name: str, prefix: str
    ) -> list[str]:
        paginator = s3.get_paginator("list_objects_v2")
        raw_prefix = prefix.rstrip("/")
        prefixes = [raw_prefix, raw_prefix + "/"]

        keys: set[str] = set()
        for pref in prefixes:
            for page in paginator.paginate(Bucket=bucket_name, Prefix=pref):
                for obj in page.get("Contents", []):
                    key = obj["Key"]
                    if not key.startswith(pref):
                        continue
                    if pref.endswith("/"):
                        rel = key[len(pref) :]
                        if not rel or rel.endswith("/") or "/" in rel:
                            continue
                    else:
                        if key != pref:
                            continue
                    keys.add(key)
        result = sorted(keys)
        self.logger.info(
            "S3 list immediate files",
            prefixes=prefixes,
            total_keys=len(result),
            sample=result[:5],
        )
        return result

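    # Example (hypothetical keys): for prefix "recordings/abc", the two passes
    # match the exact key "recordings/abc" and immediate children such as
    # "recordings/abc/track_0.mp4", while nested objects like
    # "recordings/abc/sub/x.mp4" are skipped because rel still contains "/".
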
    async def process(self, bucket_name: str, prefix: str):
        transcript = await self.get_transcript()

        s3 = boto3.client(
            "s3",
            region_name=settings.RECORDING_STORAGE_AWS_REGION,
            aws_access_key_id=settings.RECORDING_STORAGE_AWS_ACCESS_KEY_ID,
            aws_secret_access_key=settings.RECORDING_STORAGE_AWS_SECRET_ACCESS_KEY,
        )

        keys = await self._list_immediate_keys(s3, bucket_name, prefix)
        if not keys:
            raise Exception("No audio tracks found under prefix")

        storage = get_transcripts_storage()

        speaker_transcripts: list[TranscriptType] = []
        for idx, key in enumerate(keys):
            ext = ".mp4"

            try:
                obj = s3.get_object(Bucket=bucket_name, Key=key)
                data = obj["Body"].read()
            except Exception as e:
                self.logger.warning(
                    "Skipping track - cannot read S3 object", key=key, error=str(e)
                )
                continue

            storage_path = f"file_pipeline/{transcript.id}/tracks/track_{idx}{ext}"
            try:
                await storage.put_file(storage_path, data)
                audio_url = await storage.get_file_url(storage_path)
            except Exception as e:
                self.logger.warning(
                    "Skipping track - cannot upload to storage", key=key, error=str(e)
                )
                continue

            try:
                t = await self.transcribe_file(audio_url, transcript.source_language)
            except Exception as e:
                self.logger.warning(
                    "Transcription via default backend failed, trying local whisper",
                    key=key,
                    url=audio_url,
                    error=str(e),
                )
                try:
                    fallback = FileTranscriptAutoProcessor(name="whisper")
                    result = None

                    async def capture_result(r):
                        nonlocal result
                        result = r

                    fallback.on(capture_result)
                    await fallback.push(
                        FileTranscriptInput(
                            audio_url=audio_url, language=transcript.source_language
                        )
                    )
                    await fallback.flush()
                    if not result:
                        raise Exception("No transcript captured in fallback")
                    t = result
                except Exception as e2:
                    self.logger.warning(
                        "Skipping track - transcription failed after fallback",
                        key=key,
                        url=audio_url,
                        error=str(e2),
                    )
                    continue

            if not t.words:
                continue
            for w in t.words:
                w.speaker = idx
            speaker_transcripts.append(t)

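        # Each track is one participant, so the track index doubles as the
        # speaker label: per-word diarization falls out of the recording layout
        # rather than a diarization model.
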
        if not speaker_transcripts:
            raise Exception("No valid track transcriptions")

        merged_words = []
        for t in speaker_transcripts:
            merged_words.extend(t.words)
        merged_words.sort(key=lambda w: w.start)

        merged_transcript = TranscriptType(words=merged_words, translation=None)

        await transcripts_controller.append_event(
            transcript,
            event="TRANSCRIPT",
            data=TranscriptText(
                text=merged_transcript.text, translation=merged_transcript.translation
            ),
        )

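        # Title and summary generation run concurrently; return_exceptions=False
        # means the first failure propagates and fails the whole task.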
        topics = await self.detect_topics(merged_transcript, transcript.target_language)
        await asyncio.gather(
            self.generate_title(topics),
            self.generate_summaries(topics),
            return_exceptions=False,
        )

        await self.set_status(transcript.id, "ended")

    async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
        processor = FileTranscriptAutoProcessor()
        input_data = FileTranscriptInput(audio_url=audio_url, language=language)

        result: TranscriptType | None = None

        async def capture_result(transcript):
            nonlocal result
            result = transcript

        processor.on(capture_result)
        await processor.push(input_data)
        await processor.flush()

        if not result:
            raise ValueError("No transcript captured")

        return result

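    # Processors are push/flush driven and report results through callbacks, so
    # a nonlocal capture is the simplest way to turn one into a plain awaitable
    # call (the whisper fallback above uses the same pattern).
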
    async def detect_topics(
        self, transcript: TranscriptType, target_language: str
    ) -> list[TitleSummary]:
        chunk_size = 300
        topics: list[TitleSummary] = []

        async def on_topic(topic: TitleSummary):
            topics.append(topic)
            return await self.on_topic(topic)

        topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
        topic_detector.set_pipeline(self.empty_pipeline)

        for i in range(0, len(transcript.words), chunk_size):
            chunk_words = transcript.words[i : i + chunk_size]
            if not chunk_words:
                continue

            chunk_transcript = TranscriptType(
                words=chunk_words, translation=transcript.translation
            )
            await topic_detector.push(chunk_transcript)

        await topic_detector.flush()
        return topics

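    # detect_topics feeds the merged transcript to the detector in fixed
    # 300-word windows; each push may emit a topic through the on_topic
    # callback, which also forwards it to the base pipeline's handler.
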
    async def generate_title(self, topics: list[TitleSummary]):
        if not topics:
            self.logger.warning("No topics for title generation")
            return

        processor = TranscriptFinalTitleProcessor(callback=self.on_title)
        processor.set_pipeline(self.empty_pipeline)

        for topic in topics:
            await processor.push(topic)

        await processor.flush()

    async def generate_summaries(self, topics: list[TitleSummary]):
        if not topics:
            self.logger.warning("No topics for summary generation")
            return

        transcript = await self.get_transcript()
        processor = TranscriptFinalSummaryProcessor(
            transcript=transcript,
            callback=self.on_long_summary,
            on_short_summary=self.on_short_summary,
        )
        processor.set_pipeline(self.empty_pipeline)

        for topic in topics:
            await processor.push(topic)

        await processor.flush()


@shared_task
@asynctask
async def task_pipeline_multitrack_process(
    *, transcript_id: str, bucket_name: str, prefix: str
):
    pipeline = PipelineMainMultitrack(transcript_id=transcript_id)
    try:
        await pipeline.set_status(transcript_id, "processing")
        await pipeline.process(bucket_name, prefix)
    except Exception:
        await pipeline.set_status(transcript_id, "error")
        raise

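    # Post-processing is handed off as a Celery chain; .si() builds immutable
    # signatures, so each follow-up task runs with only its own kwargs and
    # ignores the previous task's return value.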
    post_chain = chain(
        task_cleanup_consent.si(transcript_id=transcript_id),
        task_pipeline_post_to_zulip.si(transcript_id=transcript_id),
        task_send_webhook_if_needed.si(transcript_id=transcript_id),
    )
    post_chain.delay()

@@ -1,5 +1,6 @@
import json
import os
import re
from datetime import datetime, timezone
from urllib.parse import unquote

@@ -17,6 +18,9 @@ from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_live_pipeline import asynctask
from reflector.pipelines.main_multitrack_pipeline import (
    task_pipeline_multitrack_process,
)
from reflector.redis_cache import get_redis_client
from reflector.settings import settings
from reflector.whereby import get_room_sessions
@@ -147,6 +151,139 @@ async def process_recording(bucket_name: str, object_key: str):
    task_pipeline_file_process.delay(transcript_id=transcript.id)


@shared_task
@asynctask
async def process_multitrack_recording(bucket_name: str, prefix: str):
    logger.info(
        "Processing multitrack recording",
        bucket=bucket_name,
        prefix=prefix,
        room_name="daily",
    )

    try:
        effective_room_name = "/daily"
        dir_name = prefix.rstrip("/").split("/")[-1]
        ts_match = re.search(r"(\d{14})$", dir_name)
        if ts_match:
            ts = ts_match.group(1)
            recorded_at = datetime.strptime(ts, "%Y%m%d%H%M%S").replace(
                tzinfo=timezone.utc
            )
        else:
            try:
                recorded_at = parse_datetime_with_timezone(dir_name)
            except Exception:
                recorded_at = datetime.now(timezone.utc)
    except Exception:
        logger.warning("Could not parse recorded_at from prefix, using now()")
        effective_room_name = "/daily"
        recorded_at = datetime.now(timezone.utc)

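    # Example (hypothetical prefix): "daily/room-20250101123000/" ends in a
    # 14-digit timestamp, so recorded_at becomes 2025-01-01 12:30:00 UTC; any
    # parse failure falls back to the current time.
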
    meeting = await meetings_controller.get_by_room_name(effective_room_name)
    if meeting:
        room = await rooms_controller.get_by_id(meeting.room_id)
    else:
        room = await rooms_controller.get_by_name(effective_room_name.lstrip("/"))
        if not room:
            raise Exception(f"Room not found: {effective_room_name}")
        start_date = recorded_at
        end_date = recorded_at
        try:
            dummy = await meetings_controller.create(
                id=room.id + "-" + recorded_at.strftime("%Y%m%d%H%M%S"),
                room_name=effective_room_name,
                room_url=f"{effective_room_name}",
                host_room_url=f"{effective_room_name}",
                start_date=start_date,
                end_date=end_date,
                room=room,
            )
            meeting = dummy
        except Exception as e:
            logger.warning("Failed to create dummy meeting", error=str(e))
            meeting = None

    recording = await recordings_controller.get_by_object_key(bucket_name, prefix)
    if not recording:
        recording = await recordings_controller.create(
            Recording(
                bucket_name=bucket_name,
                object_key=prefix,
                recorded_at=recorded_at,
                meeting_id=meeting.id if meeting else None,
            )
        )

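    # Reuse the transcript tied to this recording when reprocessing, clearing
    # previously detected topics; otherwise create a fresh one owned by the room.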
    transcript = await transcripts_controller.get_by_recording_id(recording.id)
    if transcript:
        await transcripts_controller.update(
            transcript,
            {
                "topics": [],
            },
        )
    else:
        transcript = await transcripts_controller.add(
            "",
            source_kind=SourceKind.ROOM,
            source_language="en",
            target_language="en",
            user_id=room.user_id,
            recording_id=recording.id,
            share_mode="public",
            meeting_id=meeting.id if meeting else None,
            room_id=room.id,
        )

    s3 = boto3.client(
        "s3",
        region_name=settings.RECORDING_STORAGE_AWS_REGION,
        aws_access_key_id=settings.RECORDING_STORAGE_AWS_ACCESS_KEY_ID,
        aws_secret_access_key=settings.RECORDING_STORAGE_AWS_SECRET_ACCESS_KEY,
    )

    paginator = s3.get_paginator("list_objects_v2")
    raw_prefix = prefix.rstrip("/")
    prefixes = [raw_prefix, raw_prefix + "/"]

    all_keys_set: set[str] = set()
    for pref in prefixes:
        for page in paginator.paginate(Bucket=bucket_name, Prefix=pref):
            contents = page.get("Contents", [])
            for obj in contents:
                key = obj["Key"]
                if not key.startswith(pref):
                    continue
                if pref.endswith("/"):
                    rel = key[len(pref) :]
                    if not rel or rel.endswith("/") or "/" in rel:
                        continue
                else:
                    if key == pref:
                        all_keys_set.add(key)
                    continue
                all_keys_set.add(key)

    all_keys = sorted(all_keys_set)
    logger.info(
        "S3 list immediate files",
        prefixes=prefixes,
        total_keys=len(all_keys),
        sample=all_keys[:5],
    )

    track_keys: list[str] = all_keys[:]

    if not track_keys:
        logger.info("No objects found under prefix", prefixes=prefixes)
        raise Exception("No audio tracks found under prefix")

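    # Pre-flight check only: the multitrack pipeline task re-lists the prefix
    # itself before downloading and transcribing each track.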
    task_pipeline_multitrack_process.delay(
        transcript_id=transcript.id, bucket_name=bucket_name, prefix=prefix
    )


@shared_task
@asynctask
async def process_meetings():