From 3ce279daa44311c80a290bae004abc0ee9524b3c Mon Sep 17 00:00:00 2001 From: Igor Loskutov Date: Wed, 21 Jan 2026 16:53:06 -0500 Subject: [PATCH] Split padding and transcription into separate workflow steps - Split process_tracks into process_paddings + process_transcriptions - Create PaddingWorkflow and TranscriptionWorkflow as separate child workflows - Update dependency: mixdown_tracks now depends on process_paddings (not process_transcriptions) - Performance: mixdown starts ~295s earlier (after padding completes, not after transcription) Changes: - New: padding_workflow.py, transcription_workflow.py - Modified: daily_multitrack_pipeline.py (new tasks, updated dependencies) - Modified: models.py (new ProcessPaddingsResult, ProcessTranscriptionsResult, deleted dead ProcessTracksResult) - Modified: constants.py (new task names) - Modified: run_workers_cpu.py, run_workers_llm.py (workflow registration) - Deleted: track_processing.py Code quality fixes: - Removed redundant comments and verbose docstrings - Added language validation in process_transcriptions - Improved error logging with full context (transcript_id, track_index) - Fixed log accuracy bugs (use correct counts) - Updated worker pool documentation --- server/reflector/hatchet/constants.py | 3 +- server/reflector/hatchet/run_workers_cpu.py | 11 +- server/reflector/hatchet/run_workers_llm.py | 4 +- .../reflector/hatchet/workflows/__init__.py | 15 +- .../workflows/daily_multitrack_pipeline.py | 124 +++++++--- server/reflector/hatchet/workflows/models.py | 22 +- .../hatchet/workflows/padding_workflow.py | 145 +++++++++++ .../hatchet/workflows/track_processing.py | 229 ------------------ .../workflows/transcription_workflow.py | 98 ++++++++ 9 files changed, 363 insertions(+), 288 deletions(-) create mode 100644 server/reflector/hatchet/workflows/padding_workflow.py delete mode 100644 server/reflector/hatchet/workflows/track_processing.py create mode 100644 server/reflector/hatchet/workflows/transcription_workflow.py diff --git a/server/reflector/hatchet/constants.py b/server/reflector/hatchet/constants.py index fbe6d25b..0f47045e 100644 --- a/server/reflector/hatchet/constants.py +++ b/server/reflector/hatchet/constants.py @@ -8,7 +8,8 @@ from enum import StrEnum class TaskName(StrEnum): GET_RECORDING = "get_recording" GET_PARTICIPANTS = "get_participants" - PROCESS_TRACKS = "process_tracks" + PROCESS_PADDINGS = "process_paddings" + PROCESS_TRANSCRIPTIONS = "process_transcriptions" MIXDOWN_TRACKS = "mixdown_tracks" GENERATE_WAVEFORM = "generate_waveform" DETECT_TOPICS = "detect_topics" diff --git a/server/reflector/hatchet/run_workers_cpu.py b/server/reflector/hatchet/run_workers_cpu.py index 3fa1106d..39de3a1b 100644 --- a/server/reflector/hatchet/run_workers_cpu.py +++ b/server/reflector/hatchet/run_workers_cpu.py @@ -1,9 +1,9 @@ """ CPU-heavy worker pool for audio processing tasks. -Handles ONLY: mixdown_tracks +Handles: mixdown_tracks (serialized), padding workflows (parallel child workflows) Configuration: -- slots=1: Only mixdown (already serialized globally with max_runs=1) +- slots=1: Mixdown serialized globally with max_runs=1 - Worker affinity: pool=cpu-heavy """ @@ -11,6 +11,7 @@ from reflector.hatchet.client import HatchetClientManager from reflector.hatchet.workflows.daily_multitrack_pipeline import ( daily_multitrack_pipeline, ) +from reflector.hatchet.workflows.padding_workflow import padding_workflow from reflector.logger import logger from reflector.settings import settings @@ -23,7 +24,7 @@ def main(): hatchet = HatchetClientManager.get_client() logger.info( - "Starting Hatchet CPU worker pool (mixdown only)", + "Starting Hatchet CPU worker pool (mixdown + padding)", worker_name="cpu-worker-pool", slots=1, labels={"pool": "cpu-heavy"}, @@ -31,11 +32,11 @@ def main(): cpu_worker = hatchet.worker( "cpu-worker-pool", - slots=1, # Only 1 mixdown at a time (already serialized globally) + slots=1, labels={ "pool": "cpu-heavy", }, - workflows=[daily_multitrack_pipeline], + workflows=[daily_multitrack_pipeline, padding_workflow], ) try: diff --git a/server/reflector/hatchet/run_workers_llm.py b/server/reflector/hatchet/run_workers_llm.py index 00c3a115..a3b44bb8 100644 --- a/server/reflector/hatchet/run_workers_llm.py +++ b/server/reflector/hatchet/run_workers_llm.py @@ -9,7 +9,7 @@ from reflector.hatchet.workflows.daily_multitrack_pipeline import ( ) from reflector.hatchet.workflows.subject_processing import subject_workflow from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow -from reflector.hatchet.workflows.track_processing import track_workflow +from reflector.hatchet.workflows.transcription_workflow import transcription_workflow from reflector.logger import logger from reflector.settings import settings @@ -42,7 +42,7 @@ def main(): daily_multitrack_pipeline, topic_chunk_workflow, subject_workflow, - track_workflow, + transcription_workflow, ], ) diff --git a/server/reflector/hatchet/workflows/__init__.py b/server/reflector/hatchet/workflows/__init__.py index ea242ad6..bf75f76d 100644 --- a/server/reflector/hatchet/workflows/__init__.py +++ b/server/reflector/hatchet/workflows/__init__.py @@ -4,6 +4,10 @@ from reflector.hatchet.workflows.daily_multitrack_pipeline import ( PipelineInput, daily_multitrack_pipeline, ) +from reflector.hatchet.workflows.padding_workflow import ( + PaddingInput, + padding_workflow, +) from reflector.hatchet.workflows.subject_processing import ( SubjectInput, subject_workflow, @@ -12,15 +16,20 @@ from reflector.hatchet.workflows.topic_chunk_processing import ( TopicChunkInput, topic_chunk_workflow, ) -from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow +from reflector.hatchet.workflows.transcription_workflow import ( + TranscriptionInput, + transcription_workflow, +) __all__ = [ "daily_multitrack_pipeline", "subject_workflow", "topic_chunk_workflow", - "track_workflow", + "padding_workflow", + "transcription_workflow", "PipelineInput", "SubjectInput", "TopicChunkInput", - "TrackInput", + "PaddingInput", + "TranscriptionInput", ] diff --git a/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py b/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py index 0726cfd6..4c1d2869 100644 --- a/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py +++ b/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py @@ -54,8 +54,9 @@ from reflector.hatchet.workflows.models import ( PadTrackResult, ParticipantInfo, ParticipantsResult, + ProcessPaddingsResult, ProcessSubjectsResult, - ProcessTracksResult, + ProcessTranscriptionsResult, RecapResult, RecordingResult, SubjectsResult, @@ -68,6 +69,7 @@ from reflector.hatchet.workflows.models import ( WebhookResult, ZulipResult, ) +from reflector.hatchet.workflows.padding_workflow import PaddingInput, padding_workflow from reflector.hatchet.workflows.subject_processing import ( SubjectInput, subject_workflow, @@ -76,7 +78,10 @@ from reflector.hatchet.workflows.topic_chunk_processing import ( TopicChunkInput, topic_chunk_workflow, ) -from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow +from reflector.hatchet.workflows.transcription_workflow import ( + TranscriptionInput, + transcription_workflow, +) from reflector.logger import logger from reflector.pipelines import topic_processing from reflector.processors import AudioFileWriterProcessor @@ -404,39 +409,29 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3, ) -@with_error_handling(TaskName.PROCESS_TRACKS) -async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult: - """Spawn child workflows for each track (dynamic fan-out).""" - ctx.log(f"process_tracks: spawning {len(input.tracks)} track workflows") - - participants_result = ctx.task_output(get_participants) - source_language = participants_result.source_language +@with_error_handling(TaskName.PROCESS_PADDINGS) +async def process_paddings(input: PipelineInput, ctx: Context) -> ProcessPaddingsResult: + """Spawn child workflows for each track to apply padding (dynamic fan-out).""" + ctx.log(f"process_paddings: spawning {len(input.tracks)} padding workflows") bulk_runs = [ - track_workflow.create_bulk_run_item( - input=TrackInput( + padding_workflow.create_bulk_run_item( + input=PaddingInput( track_index=i, s3_key=track["s3_key"], bucket_name=input.bucket_name, transcript_id=input.transcript_id, - language=source_language, ) ) for i, track in enumerate(input.tracks) ] - results = await track_workflow.aio_run_many(bulk_runs) + results = await padding_workflow.aio_run_many(bulk_runs) - target_language = participants_result.target_language - - track_words: list[list[Word]] = [] padded_tracks = [] created_padded_files = set() for result in results: - transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK]) - track_words.append(transcribe_result.words) - pad_result = PadTrackResult(**result[TaskName.PAD_TRACK]) # Store S3 key info (not presigned URL) - consumer tasks presign on demand @@ -451,25 +446,75 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{pad_result.track_index}.webm" created_padded_files.add(storage_path) - all_words = [word for words in track_words for word in words] - all_words.sort(key=lambda w: w.start) + ctx.log(f"process_paddings complete: {len(padded_tracks)} padded tracks") - ctx.log( - f"process_tracks complete: {len(all_words)} words from {len(input.tracks)} tracks" - ) - - return ProcessTracksResult( - all_words=all_words, + return ProcessPaddingsResult( padded_tracks=padded_tracks, - word_count=len(all_words), num_tracks=len(input.tracks), - target_language=target_language, created_padded_files=list(created_padded_files), ) @daily_multitrack_pipeline.task( - parents=[process_tracks], + parents=[process_paddings], + execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), + retries=3, +) +@with_error_handling(TaskName.PROCESS_TRANSCRIPTIONS) +async def process_transcriptions( + input: PipelineInput, ctx: Context +) -> ProcessTranscriptionsResult: + """Spawn child workflows for each padded track to transcribe (dynamic fan-out).""" + participants_result = ctx.task_output(get_participants) + paddings_result = ctx.task_output(process_paddings) + + source_language = participants_result.source_language + if not source_language: + raise ValueError("source_language is required for transcription") + + target_language = participants_result.target_language + padded_tracks = paddings_result.padded_tracks + + ctx.log( + f"process_transcriptions: spawning {len(padded_tracks)} transcription workflows" + ) + + bulk_runs = [ + transcription_workflow.create_bulk_run_item( + input=TranscriptionInput( + track_index=i, + padded_key=padded_track.key, + bucket_name=padded_track.bucket_name, + language=source_language, + ) + ) + for i, padded_track in enumerate(padded_tracks) + ] + + results = await transcription_workflow.aio_run_many(bulk_runs) + + track_words: list[list[Word]] = [] + for result in results: + transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK]) + track_words.append(transcribe_result.words) + + all_words = [word for words in track_words for word in words] + all_words.sort(key=lambda w: w.start) + + ctx.log( + f"process_transcriptions complete: {len(all_words)} words from {len(padded_tracks)} tracks" + ) + + return ProcessTranscriptionsResult( + all_words=all_words, + word_count=len(all_words), + num_tracks=len(input.tracks), + target_language=target_language, + ) + + +@daily_multitrack_pipeline.task( + parents=[process_paddings], execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3, desired_worker_labels={ @@ -489,12 +534,12 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes ) @with_error_handling(TaskName.MIXDOWN_TRACKS) async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: - """Mix all padded tracks into single audio file using PyAV (same as Celery).""" + """Mix all padded tracks into single audio file using PyAV.""" ctx.log("mixdown_tracks: mixing padded tracks into single audio file") - track_result = ctx.task_output(process_tracks) + paddings_result = ctx.task_output(process_paddings) recording_result = ctx.task_output(get_recording) - padded_tracks = track_result.padded_tracks + padded_tracks = paddings_result.padded_tracks # Dynamic timeout: scales with track count and recording duration # Base 300s + 60s per track + 1s per 10s of recording @@ -648,7 +693,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul @daily_multitrack_pipeline.task( - parents=[process_tracks], + parents=[process_transcriptions], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3, ) @@ -657,8 +702,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: """Detect topics using parallel child workflows (one per chunk).""" ctx.log("detect_topics: analyzing transcript for topics") - track_result = ctx.task_output(process_tracks) - words = track_result.all_words + transcriptions_result = ctx.task_output(process_transcriptions) + words = transcriptions_result.all_words if not words: ctx.log("detect_topics: no words, returning empty topics") @@ -1109,13 +1154,14 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: ctx.log("finalize: saving transcript and setting status to 'ended'") mixdown_result = ctx.task_output(mixdown_tracks) - track_result = ctx.task_output(process_tracks) + transcriptions_result = ctx.task_output(process_transcriptions) + paddings_result = ctx.task_output(process_paddings) duration = mixdown_result.duration - all_words = track_result.all_words + all_words = transcriptions_result.all_words # Cleanup temporary padded S3 files (deferred until finalize for semantic parity with Celery) - created_padded_files = track_result.created_padded_files + created_padded_files = paddings_result.created_padded_files if created_padded_files: ctx.log(f"Cleaning up {len(created_padded_files)} temporary S3 files") storage = _spawn_storage() diff --git a/server/reflector/hatchet/workflows/models.py b/server/reflector/hatchet/workflows/models.py index 1bad1f4a..b78f6e1a 100644 --- a/server/reflector/hatchet/workflows/models.py +++ b/server/reflector/hatchet/workflows/models.py @@ -23,10 +23,8 @@ class ParticipantInfo(BaseModel): class PadTrackResult(BaseModel): """Result from pad_track task.""" - padded_key: NonEmptyString # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay - bucket_name: ( - NonEmptyString | None - ) # None means use default transcript storage bucket + padded_key: NonEmptyString + bucket_name: NonEmptyString | None size: int track_index: int @@ -59,18 +57,24 @@ class PaddedTrackInfo(BaseModel): """Info for a padded track - S3 key + bucket for on-demand presigning.""" key: NonEmptyString - bucket_name: NonEmptyString | None # None = use default storage bucket + bucket_name: NonEmptyString | None -class ProcessTracksResult(BaseModel): - """Result from process_tracks task.""" +class ProcessPaddingsResult(BaseModel): + """Result from process_paddings task.""" + + padded_tracks: list[PaddedTrackInfo] + num_tracks: int + created_padded_files: list[NonEmptyString] + + +class ProcessTranscriptionsResult(BaseModel): + """Result from process_transcriptions task.""" all_words: list[Word] - padded_tracks: list[PaddedTrackInfo] # S3 keys, not presigned URLs word_count: int num_tracks: int target_language: NonEmptyString - created_padded_files: list[NonEmptyString] class MixdownResult(BaseModel): diff --git a/server/reflector/hatchet/workflows/padding_workflow.py b/server/reflector/hatchet/workflows/padding_workflow.py new file mode 100644 index 00000000..2740d499 --- /dev/null +++ b/server/reflector/hatchet/workflows/padding_workflow.py @@ -0,0 +1,145 @@ +""" +Hatchet child workflow: PaddingWorkflow +Handles individual audio track padding only. +""" + +import tempfile +from datetime import timedelta +from pathlib import Path + +import av +from hatchet_sdk import Context +from pydantic import BaseModel + +from reflector.hatchet.client import HatchetClientManager +from reflector.hatchet.constants import TIMEOUT_AUDIO +from reflector.hatchet.workflows.models import PadTrackResult +from reflector.logger import logger +from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS +from reflector.utils.audio_padding import ( + apply_audio_padding_to_file, + extract_stream_start_time_from_container, +) + + +class PaddingInput(BaseModel): + """Input for individual track padding.""" + + track_index: int + s3_key: str + bucket_name: str + transcript_id: str + + +hatchet = HatchetClientManager.get_client() + +padding_workflow = hatchet.workflow( + name="PaddingWorkflow", input_validator=PaddingInput +) + + +@padding_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3) +async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult: + """Pad audio track with silence based on WebM container start_time.""" + ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}") + logger.info( + "[Hatchet] pad_track", + track_index=input.track_index, + s3_key=input.s3_key, + transcript_id=input.transcript_id, + ) + + try: + # Create fresh storage instance to avoid aioboto3 fork issues + from reflector.settings import settings # noqa: PLC0415 + from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415 + + storage = AwsStorage( + aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME, + aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION, + aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID, + aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY, + ) + + source_url = await storage.get_file_url( + input.s3_key, + operation="get_object", + expires_in=PRESIGNED_URL_EXPIRATION_SECONDS, + bucket=input.bucket_name, + ) + + with av.open(source_url) as in_container: + if in_container.duration: + try: + duration = timedelta(seconds=in_container.duration // 1_000_000) + ctx.log( + f"pad_track: track {input.track_index}, duration={duration}" + ) + except Exception: + ctx.log(f"pad_track: track {input.track_index}, duration=ERROR") + + start_time_seconds = extract_stream_start_time_from_container( + in_container, input.track_index, logger=logger + ) + + if start_time_seconds <= 0: + logger.info( + f"Track {input.track_index} requires no padding", + track_index=input.track_index, + ) + return PadTrackResult( + padded_key=input.s3_key, + bucket_name=input.bucket_name, + size=0, + track_index=input.track_index, + ) + + with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file: + temp_path = temp_file.name + + try: + apply_audio_padding_to_file( + in_container, + temp_path, + start_time_seconds, + input.track_index, + logger=logger, + ) + + file_size = Path(temp_path).stat().st_size + storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{input.track_index}.webm" + + with open(temp_path, "rb") as padded_file: + await storage.put_file(storage_path, padded_file) + + logger.info( + f"Uploaded padded track to S3", + key=storage_path, + size=file_size, + ) + finally: + Path(temp_path).unlink(missing_ok=True) + + ctx.log(f"pad_track complete: track {input.track_index} -> {storage_path}") + logger.info( + "[Hatchet] pad_track complete", + track_index=input.track_index, + padded_key=storage_path, + ) + + return PadTrackResult( + padded_key=storage_path, + bucket_name=None, # None = use default transcript storage bucket + size=file_size, + track_index=input.track_index, + ) + + except Exception as e: + logger.error( + "[Hatchet] pad_track failed", + transcript_id=input.transcript_id, + track_index=input.track_index, + error=str(e), + exc_info=True, + ) + raise diff --git a/server/reflector/hatchet/workflows/track_processing.py b/server/reflector/hatchet/workflows/track_processing.py deleted file mode 100644 index dd3aea3a..00000000 --- a/server/reflector/hatchet/workflows/track_processing.py +++ /dev/null @@ -1,229 +0,0 @@ -""" -Hatchet child workflow: TrackProcessing - -Handles individual audio track processing: padding and transcription. -Spawned dynamically by the main diarization pipeline for each track. - -Architecture note: This is a separate workflow (not inline tasks in DailyMultitrackPipeline) -because Hatchet workflow DAGs are defined statically, but the number of tracks varies -at runtime. Child workflow spawning via `aio_run()` + `asyncio.gather()` is the -standard pattern for dynamic fan-out. See `process_tracks` in daily_multitrack_pipeline.py. - -Note: This file uses deferred imports (inside tasks) intentionally. -Hatchet workers run in forked processes; fresh imports per task ensure -storage/DB connections are not shared across forks. -""" - -import tempfile -from datetime import timedelta -from pathlib import Path - -import av -from hatchet_sdk import Context -from pydantic import BaseModel - -from reflector.hatchet.client import HatchetClientManager -from reflector.hatchet.constants import TIMEOUT_AUDIO, TIMEOUT_HEAVY -from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult -from reflector.logger import logger -from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS -from reflector.utils.audio_padding import ( - apply_audio_padding_to_file, - extract_stream_start_time_from_container, -) - - -class TrackInput(BaseModel): - """Input for individual track processing.""" - - track_index: int - s3_key: str - bucket_name: str - transcript_id: str - language: str = "en" - - -hatchet = HatchetClientManager.get_client() - -track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput) - - -@track_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3) -async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: - """Pad single audio track with silence for alignment. - - Extracts stream.start_time from WebM container metadata and applies - silence padding using PyAV filter graph (adelay). - """ - ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}") - logger.info( - "[Hatchet] pad_track", - track_index=input.track_index, - s3_key=input.s3_key, - transcript_id=input.transcript_id, - ) - - try: - # Create fresh storage instance to avoid aioboto3 fork issues - from reflector.settings import settings # noqa: PLC0415 - from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415 - - storage = AwsStorage( - aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME, - aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION, - aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID, - aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY, - ) - - source_url = await storage.get_file_url( - input.s3_key, - operation="get_object", - expires_in=PRESIGNED_URL_EXPIRATION_SECONDS, - bucket=input.bucket_name, - ) - - with av.open(source_url) as in_container: - if in_container.duration: - try: - duration = timedelta(seconds=in_container.duration // 1_000_000) - ctx.log( - f"pad_track: track {input.track_index}, duration={duration}" - ) - except Exception: - ctx.log(f"pad_track: track {input.track_index}, duration=ERROR") - - start_time_seconds = extract_stream_start_time_from_container( - in_container, input.track_index, logger=logger - ) - - # If no padding needed, return original S3 key - if start_time_seconds <= 0: - logger.info( - f"Track {input.track_index} requires no padding", - track_index=input.track_index, - ) - return PadTrackResult( - padded_key=input.s3_key, - bucket_name=input.bucket_name, - size=0, - track_index=input.track_index, - ) - - with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file: - temp_path = temp_file.name - - try: - apply_audio_padding_to_file( - in_container, - temp_path, - start_time_seconds, - input.track_index, - logger=logger, - ) - - file_size = Path(temp_path).stat().st_size - storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{input.track_index}.webm" - - logger.info( - f"About to upload padded track", - key=storage_path, - size=file_size, - ) - - with open(temp_path, "rb") as padded_file: - await storage.put_file(storage_path, padded_file) - - logger.info( - f"Uploaded padded track to S3", - key=storage_path, - size=file_size, - ) - finally: - Path(temp_path).unlink(missing_ok=True) - - ctx.log(f"pad_track complete: track {input.track_index} -> {storage_path}") - logger.info( - "[Hatchet] pad_track complete", - track_index=input.track_index, - padded_key=storage_path, - ) - - # Return S3 key (not presigned URL) - consumer tasks presign on demand - # This avoids stale URLs when workflow is replayed - return PadTrackResult( - padded_key=storage_path, - bucket_name=None, # None = use default transcript storage bucket - size=file_size, - track_index=input.track_index, - ) - - except Exception as e: - logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True) - raise - - -@track_workflow.task( - parents=[pad_track], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3 -) -async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult: - """Transcribe audio track using GPU (Modal.com) or local Whisper.""" - ctx.log(f"transcribe_track: track {input.track_index}, language={input.language}") - logger.info( - "[Hatchet] transcribe_track", - track_index=input.track_index, - language=input.language, - ) - - try: - pad_result = ctx.task_output(pad_track) - padded_key = pad_result.padded_key - bucket_name = pad_result.bucket_name - - if not padded_key: - raise ValueError("Missing padded_key from pad_track") - - # Presign URL on demand (avoids stale URLs on workflow replay) - from reflector.settings import settings # noqa: PLC0415 - from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415 - - storage = AwsStorage( - aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME, - aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION, - aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID, - aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY, - ) - - audio_url = await storage.get_file_url( - padded_key, - operation="get_object", - expires_in=PRESIGNED_URL_EXPIRATION_SECONDS, - bucket=bucket_name, - ) - - from reflector.pipelines.transcription_helpers import ( # noqa: PLC0415 - transcribe_file_with_processor, - ) - - transcript = await transcribe_file_with_processor(audio_url, input.language) - - # Tag all words with speaker index - for word in transcript.words: - word.speaker = input.track_index - - ctx.log( - f"transcribe_track complete: track {input.track_index}, {len(transcript.words)} words" - ) - logger.info( - "[Hatchet] transcribe_track complete", - track_index=input.track_index, - word_count=len(transcript.words), - ) - - return TranscribeTrackResult( - words=transcript.words, - track_index=input.track_index, - ) - - except Exception as e: - logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True) - raise diff --git a/server/reflector/hatchet/workflows/transcription_workflow.py b/server/reflector/hatchet/workflows/transcription_workflow.py new file mode 100644 index 00000000..ed279e79 --- /dev/null +++ b/server/reflector/hatchet/workflows/transcription_workflow.py @@ -0,0 +1,98 @@ +""" +Hatchet child workflow: TranscriptionWorkflow +Handles individual audio track transcription only. +""" + +from datetime import timedelta + +from hatchet_sdk import Context +from pydantic import BaseModel + +from reflector.hatchet.client import HatchetClientManager +from reflector.hatchet.constants import TIMEOUT_HEAVY +from reflector.hatchet.workflows.models import TranscribeTrackResult +from reflector.logger import logger +from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS + + +class TranscriptionInput(BaseModel): + """Input for individual track transcription.""" + + track_index: int + padded_key: str # S3 key from padding step + bucket_name: str | None # None = use default bucket + language: str = "en" + + +hatchet = HatchetClientManager.get_client() + +transcription_workflow = hatchet.workflow( + name="TranscriptionWorkflow", input_validator=TranscriptionInput +) + + +@transcription_workflow.task( + execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3 +) +async def transcribe_track( + input: TranscriptionInput, ctx: Context +) -> TranscribeTrackResult: + """Transcribe audio track using GPU (Modal.com) or local Whisper.""" + ctx.log(f"transcribe_track: track {input.track_index}, language={input.language}") + logger.info( + "[Hatchet] transcribe_track", + track_index=input.track_index, + language=input.language, + ) + + try: + from reflector.settings import settings # noqa: PLC0415 + from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415 + + storage = AwsStorage( + aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME, + aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION, + aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID, + aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY, + ) + + audio_url = await storage.get_file_url( + input.padded_key, + operation="get_object", + expires_in=PRESIGNED_URL_EXPIRATION_SECONDS, + bucket=input.bucket_name, + ) + + from reflector.pipelines.transcription_helpers import ( # noqa: PLC0415 + transcribe_file_with_processor, + ) + + transcript = await transcribe_file_with_processor(audio_url, input.language) + + for word in transcript.words: + word.speaker = input.track_index + + ctx.log( + f"transcribe_track complete: track {input.track_index}, {len(transcript.words)} words" + ) + logger.info( + "[Hatchet] transcribe_track complete", + track_index=input.track_index, + word_count=len(transcript.words), + ) + + return TranscribeTrackResult( + words=transcript.words, + track_index=input.track_index, + ) + + except Exception as e: + logger.error( + "[Hatchet] transcribe_track failed", + track_index=input.track_index, + padded_key=input.padded_key, + language=input.language, + error=str(e), + exc_info=True, + ) + raise