diff --git a/server/reflector/pipelines/main_multitrack_pipeline.py b/server/reflector/pipelines/main_multitrack_pipeline.py
new file mode 100644
index 00000000..beb681d2
--- /dev/null
+++ b/server/reflector/pipelines/main_multitrack_pipeline.py
@@ -0,0 +1,292 @@
+import asyncio
+
+import boto3
+import structlog
+from celery import chain, shared_task
+
+from reflector.asynctask import asynctask
+from reflector.db.transcripts import (
+    TranscriptStatus,
+    TranscriptText,
+    transcripts_controller,
+)
+from reflector.logger import logger
+from reflector.pipelines.main_file_pipeline import task_send_webhook_if_needed
+from reflector.pipelines.main_live_pipeline import (
+    PipelineMainBase,
+    task_cleanup_consent,
+    task_pipeline_post_to_zulip,
+)
+from reflector.processors import (
+    TranscriptFinalSummaryProcessor,
+    TranscriptFinalTitleProcessor,
+    TranscriptTopicDetectorProcessor,
+)
+from reflector.processors.file_transcript import FileTranscriptInput
+from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
+from reflector.processors.types import TitleSummary
+from reflector.processors.types import (
+    Transcript as TranscriptType,
+)
+from reflector.settings import settings
+from reflector.storage import get_transcripts_storage
+
+
+class EmptyPipeline:
+    def __init__(self, logger: structlog.BoundLogger):
+        self.logger = logger
+
+    def get_pref(self, k, d=None):
+        return d
+
+    async def emit(self, event):
+        pass
+
+
+class PipelineMainMultitrack(PipelineMainBase):
+    """Process multiple participant tracks for a transcript without mixing audio."""
+
+    def __init__(self, transcript_id: str):
+        super().__init__(transcript_id=transcript_id)
+        self.logger = logger.bind(transcript_id=self.transcript_id)
+        self.empty_pipeline = EmptyPipeline(logger=self.logger)
+
+    async def set_status(self, transcript_id: str, status: TranscriptStatus):
+        async with self.lock_transaction():
+            return await transcripts_controller.set_status(transcript_id, status)
+
+    async def _list_immediate_keys(
+        self, s3, bucket_name: str, prefix: str
+    ) -> list[str]:
+        paginator = s3.get_paginator("list_objects_v2")
+        raw_prefix = prefix.rstrip("/")
+        prefixes = [raw_prefix, raw_prefix + "/"]
+
+        keys: set[str] = set()
+        for pref in prefixes:
+            for page in paginator.paginate(Bucket=bucket_name, Prefix=pref):
+                for obj in page.get("Contents", []):
+                    key = obj["Key"]
+                    if not key.startswith(pref):
+                        continue
+                    if pref.endswith("/"):
+                        rel = key[len(pref) :]
+                        if not rel or rel.endswith("/") or "/" in rel:
+                            continue
+                    else:
+                        if key != pref:
+                            continue
+                    keys.add(key)
+        result = sorted(keys)
+        self.logger.info(
+            "S3 list immediate files",
+            prefixes=prefixes,
+            total_keys=len(result),
+            sample=result[:5],
+        )
+        return result
+
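+    # End-to-end processing of one multitrack recording: each object found
+    # directly under the prefix is treated as one participant's audio track
+    # and transcribed on its own; its words are tagged with the track index
+    # as the speaker, then all words are merged into a single transcript
+    # that feeds topic detection, title and summary generation.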
+    async def process(self, bucket_name: str, prefix: str):
+        transcript = await self.get_transcript()
+
+        s3 = boto3.client(
+            "s3",
+            region_name=settings.RECORDING_STORAGE_AWS_REGION,
+            aws_access_key_id=settings.RECORDING_STORAGE_AWS_ACCESS_KEY_ID,
+            aws_secret_access_key=settings.RECORDING_STORAGE_AWS_SECRET_ACCESS_KEY,
+        )
+
+        keys = await self._list_immediate_keys(s3, bucket_name, prefix)
+        if not keys:
+            raise Exception("No audio tracks found under prefix")
+
+        storage = get_transcripts_storage()
+
+        speaker_transcripts: list[TranscriptType] = []
+        for idx, key in enumerate(keys):
+            ext = ".mp4"
+
+            try:
+                obj = s3.get_object(Bucket=bucket_name, Key=key)
+                data = obj["Body"].read()
+            except Exception as e:
+                self.logger.warning(
+                    "Skipping track - cannot read S3 object", key=key, error=str(e)
+                )
+                continue
+
+            storage_path = f"file_pipeline/{transcript.id}/tracks/track_{idx}{ext}"
+            try:
+                await storage.put_file(storage_path, data)
+                audio_url = await storage.get_file_url(storage_path)
+            except Exception as e:
+                self.logger.warning(
+                    "Skipping track - cannot upload to storage", key=key, error=str(e)
+                )
+                continue
+
+            try:
+                t = await self.transcribe_file(audio_url, transcript.source_language)
+            except Exception as e:
+                self.logger.warning(
+                    "Transcription via default backend failed, trying local whisper",
+                    key=key,
+                    url=audio_url,
+                    error=str(e),
+                )
+                try:
+                    fallback = FileTranscriptAutoProcessor(name="whisper")
+                    result = None
+
+                    async def capture_result(r):
+                        nonlocal result
+                        result = r
+
+                    fallback.on(capture_result)
+                    await fallback.push(
+                        FileTranscriptInput(
+                            audio_url=audio_url, language=transcript.source_language
+                        )
+                    )
+                    await fallback.flush()
+                    if not result:
+                        raise Exception("No transcript captured in fallback")
+                    t = result
+                except Exception as e2:
+                    self.logger.warning(
+                        "Skipping track - transcription failed after fallback",
+                        key=key,
+                        url=audio_url,
+                        error=str(e2),
+                    )
+                    continue
+
+            if not t.words:
+                continue
+            for w in t.words:
+                w.speaker = idx
+            speaker_transcripts.append(t)
+
+        if not speaker_transcripts:
+            raise Exception("No valid track transcriptions")
+
+        merged_words = []
+        for t in speaker_transcripts:
+            merged_words.extend(t.words)
+        merged_words.sort(key=lambda w: w.start)
+
+        merged_transcript = TranscriptType(words=merged_words, translation=None)
+
+        await transcripts_controller.append_event(
+            transcript,
+            event="TRANSCRIPT",
+            data=TranscriptText(
+                text=merged_transcript.text, translation=merged_transcript.translation
+            ),
+        )
+
+        topics = await self.detect_topics(
+            merged_transcript, transcript.target_language
+        )
+        await asyncio.gather(
+            self.generate_title(topics),
+            self.generate_summaries(topics),
+            return_exceptions=False,
+        )
+
+        await self.set_status(transcript.id, "ended")
+
+    async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
+        processor = FileTranscriptAutoProcessor()
+        input_data = FileTranscriptInput(audio_url=audio_url, language=language)
+
+        result: TranscriptType | None = None
+
+        async def capture_result(transcript):
+            nonlocal result
+            result = transcript
+
+        processor.on(capture_result)
+        await processor.push(input_data)
+        await processor.flush()
+
+        if not result:
+            raise ValueError("No transcript captured")
+
+        return result
+
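+    # detect_topics feeds the merged transcript to the topic detector in
+    # fixed 300-word windows; detected topics are collected locally and also
+    # forwarded to the base pipeline through self.on_topic.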
+    async def detect_topics(
+        self, transcript: TranscriptType, target_language: str
+    ) -> list[TitleSummary]:
+        chunk_size = 300
+        topics: list[TitleSummary] = []
+
+        async def on_topic(topic: TitleSummary):
+            topics.append(topic)
+            return await self.on_topic(topic)
+
+        topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
+        topic_detector.set_pipeline(self.empty_pipeline)
+
+        for i in range(0, len(transcript.words), chunk_size):
+            chunk_words = transcript.words[i : i + chunk_size]
+            if not chunk_words:
+                continue
+
+            chunk_transcript = TranscriptType(
+                words=chunk_words, translation=transcript.translation
+            )
+            await topic_detector.push(chunk_transcript)
+
+        await topic_detector.flush()
+        return topics
+
+    async def generate_title(self, topics: list[TitleSummary]):
+        if not topics:
+            self.logger.warning("No topics for title generation")
+            return
+
+        processor = TranscriptFinalTitleProcessor(callback=self.on_title)
+        processor.set_pipeline(self.empty_pipeline)
+
+        for topic in topics:
+            await processor.push(topic)
+
+        await processor.flush()
+
+    async def generate_summaries(self, topics: list[TitleSummary]):
+        if not topics:
+            self.logger.warning("No topics for summary generation")
+            return
+
+        transcript = await self.get_transcript()
+        processor = TranscriptFinalSummaryProcessor(
+            transcript=transcript,
+            callback=self.on_long_summary,
+            on_short_summary=self.on_short_summary,
+        )
+        processor.set_pipeline(self.empty_pipeline)
+
+        for topic in topics:
+            await processor.push(topic)
+
+        await processor.flush()
+
+
+@shared_task
+@asynctask
+async def task_pipeline_multitrack_process(
+    *, transcript_id: str, bucket_name: str, prefix: str
+):
+    pipeline = PipelineMainMultitrack(transcript_id=transcript_id)
+    try:
+        await pipeline.set_status(transcript_id, "processing")
+        await pipeline.process(bucket_name, prefix)
+    except Exception:
+        await pipeline.set_status(transcript_id, "error")
+        raise
+
+    post_chain = chain(
+        task_cleanup_consent.si(transcript_id=transcript_id),
+        task_pipeline_post_to_zulip.si(transcript_id=transcript_id),
+        task_send_webhook_if_needed.si(transcript_id=transcript_id),
+    )
+    post_chain.delay()
diff --git a/server/reflector/worker/process.py b/server/reflector/worker/process.py
index f6be5b85..6ea0029d 100644
--- a/server/reflector/worker/process.py
+++ b/server/reflector/worker/process.py
@@ -1,5 +1,6 @@
 import json
 import os
+import re
 from datetime import datetime, timezone
 from urllib.parse import unquote
@@ -17,6 +18,9 @@ from reflector.db.rooms import rooms_controller
 from reflector.db.transcripts import SourceKind, transcripts_controller
 from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
 from reflector.pipelines.main_live_pipeline import asynctask
+from reflector.pipelines.main_multitrack_pipeline import (
+    task_pipeline_multitrack_process,
+)
 from reflector.redis_cache import get_redis_client
 from reflector.settings import settings
 from reflector.whereby import get_room_sessions
@@ -147,6 +151,139 @@ async def process_recording(bucket_name: str, object_key: str):
     task_pipeline_file_process.delay(transcript_id=transcript.id)
 
 
+@shared_task
+@asynctask
+async def process_multitrack_recording(bucket_name: str, prefix: str):
+    logger.info(
+        "Processing multitrack recording",
+        bucket=bucket_name,
+        prefix=prefix,
+        room_name="daily",
+    )
+
+    try:
+        effective_room_name = "/daily"
+        dir_name = prefix.rstrip("/").split("/")[-1]
+        ts_match = re.search(r"(\d{14})$", dir_name)
+        if ts_match:
+            ts = ts_match.group(1)
+            recorded_at = datetime.strptime(ts, "%Y%m%d%H%M%S").replace(
+                tzinfo=timezone.utc
+            )
+        else:
+            try:
+                recorded_at = parse_datetime_with_timezone(dir_name)
+            except Exception:
+                recorded_at = datetime.now(timezone.utc)
+    except Exception:
+        logger.warning("Could not parse recorded_at from prefix, using now()")
+        effective_room_name = "/daily"
+        recorded_at = datetime.now(timezone.utc)
+
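+    # The handler is pinned to the "/daily" room: reuse the meeting found
+    # for that room name when one exists, otherwise look the room up by name
+    # and create a placeholder meeting so the recording and transcript can
+    # be linked; if that creation fails, continue with meeting=None.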
+    meeting = await meetings_controller.get_by_room_name(effective_room_name)
+    if meeting:
+        room = await rooms_controller.get_by_id(meeting.room_id)
+    else:
+        room = await rooms_controller.get_by_name(effective_room_name.lstrip("/"))
+        if not room:
+            raise Exception(f"Room not found: {effective_room_name}")
+        start_date = recorded_at
+        end_date = recorded_at
+        try:
+            dummy = await meetings_controller.create(
+                id=room.id + "-" + recorded_at.strftime("%Y%m%d%H%M%S"),
+                room_name=effective_room_name,
+                room_url=f"{effective_room_name}",
+                host_room_url=f"{effective_room_name}",
+                start_date=start_date,
+                end_date=end_date,
+                room=room,
+            )
+            meeting = dummy
+        except Exception as e:
+            logger.warning("Failed to create dummy meeting", error=str(e))
+            meeting = None
+
+    recording = await recordings_controller.get_by_object_key(bucket_name, prefix)
+    if not recording:
+        recording = await recordings_controller.create(
+            Recording(
+                bucket_name=bucket_name,
+                object_key=prefix,
+                recorded_at=recorded_at,
+                meeting_id=meeting.id if meeting else None,
+            )
+        )
+
+    transcript = await transcripts_controller.get_by_recording_id(recording.id)
+    if transcript:
+        await transcripts_controller.update(
+            transcript,
+            {
+                "topics": [],
+            },
+        )
+    else:
+        transcript = await transcripts_controller.add(
+            "",
+            source_kind=SourceKind.ROOM,
+            source_language="en",
+            target_language="en",
+            user_id=room.user_id,
+            recording_id=recording.id,
+            share_mode="public",
+            meeting_id=meeting.id if meeting else None,
+            room_id=room.id,
+        )
+
+    s3 = boto3.client(
+        "s3",
+        region_name=settings.RECORDING_STORAGE_AWS_REGION,
+        aws_access_key_id=settings.RECORDING_STORAGE_AWS_ACCESS_KEY_ID,
+        aws_secret_access_key=settings.RECORDING_STORAGE_AWS_SECRET_ACCESS_KEY,
+    )
+
+    paginator = s3.get_paginator("list_objects_v2")
+    raw_prefix = prefix.rstrip("/")
+    prefixes = [raw_prefix, raw_prefix + "/"]
+
+    all_keys_set: set[str] = set()
+    for pref in prefixes:
+        for page in paginator.paginate(Bucket=bucket_name, Prefix=pref):
+            contents = page.get("Contents", [])
+            for obj in contents:
+                key = obj["Key"]
+                if not key.startswith(pref):
+                    continue
+                if pref.endswith("/"):
+                    rel = key[len(pref) :]
+                    if not rel or rel.endswith("/") or "/" in rel:
+                        continue
+                else:
+                    if key == pref:
+                        all_keys_set.add(key)
+                        continue
+                all_keys_set.add(key)
+
+    all_keys = sorted(all_keys_set)
+    logger.info(
+        "S3 list immediate files",
+        prefixes=prefixes,
+        total_keys=len(all_keys),
+        sample=all_keys[:5],
+    )
+
+    track_keys: list[str] = all_keys[:]
+
+    if not track_keys:
+        logger.info("No objects found under prefix", prefixes=prefixes)
+        raise Exception("No audio tracks found under prefix")
+
+    task_pipeline_multitrack_process.delay(
+        transcript_id=transcript.id, bucket_name=bucket_name, prefix=prefix
+    )
+
+
 @shared_task
 @asynctask
 async def process_meetings():
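
A minimal usage sketch of how the new worker task might be invoked manually, assuming
a Celery worker is running with the reflector app and the recording storage credentials
in settings are configured; the bucket and prefix values below are placeholders for an
existing multitrack recording.

    from reflector.worker.process import process_multitrack_recording

    # Placeholder bucket/prefix: every object directly under the prefix is
    # treated as one participant track.
    process_multitrack_recording.delay(
        bucket_name="example-recordings-bucket",
        prefix="recordings/daily/20240101120000",
    )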