mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-04 09:56:47 +00:00
Split padding and transcription into separate workflow steps
- Split process_tracks into process_paddings + process_transcriptions
- Create PaddingWorkflow and TranscriptionWorkflow as separate child workflows
- Update dependency: mixdown_tracks now depends on process_paddings (not process_transcriptions)
- Performance: mixdown starts ~295s earlier (after padding completes, not after transcription)

Changes:
- New: padding_workflow.py, transcription_workflow.py
- Modified: daily_multitrack_pipeline.py (new tasks, updated dependencies)
- Modified: models.py (new ProcessPaddingsResult, ProcessTranscriptionsResult, deleted dead ProcessTracksResult)
- Modified: constants.py (new task names)
- Modified: run_workers_cpu.py, run_workers_llm.py (workflow registration)
- Deleted: track_processing.py

Code quality fixes:
- Removed redundant comments and verbose docstrings
- Added language validation in process_transcriptions
- Improved error logging with full context (transcript_id, track_index)
- Fixed log accuracy bugs (use correct counts)
- Updated worker pool documentation
This commit is contained in:
@@ -8,7 +8,8 @@ from enum import StrEnum
|
||||
class TaskName(StrEnum):
|
||||
GET_RECORDING = "get_recording"
|
||||
GET_PARTICIPANTS = "get_participants"
|
||||
PROCESS_TRACKS = "process_tracks"
|
||||
PROCESS_PADDINGS = "process_paddings"
|
||||
PROCESS_TRANSCRIPTIONS = "process_transcriptions"
|
||||
MIXDOWN_TRACKS = "mixdown_tracks"
|
||||
GENERATE_WAVEFORM = "generate_waveform"
|
||||
DETECT_TOPICS = "detect_topics"
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""
|
||||
CPU-heavy worker pool for audio processing tasks.
|
||||
Handles ONLY: mixdown_tracks
|
||||
Handles: mixdown_tracks (serialized), padding workflows (parallel child workflows)
|
||||
|
||||
Configuration:
|
||||
- slots=1: Only mixdown (already serialized globally with max_runs=1)
|
||||
- slots=1: Mixdown serialized globally with max_runs=1
|
||||
- Worker affinity: pool=cpu-heavy
|
||||
"""
|
||||
|
||||
@@ -11,6 +11,7 @@ from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
daily_multitrack_pipeline,
|
||||
)
|
||||
from reflector.hatchet.workflows.padding_workflow import padding_workflow
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
@@ -23,7 +24,7 @@ def main():
|
||||
hatchet = HatchetClientManager.get_client()
|
||||
|
||||
logger.info(
|
||||
"Starting Hatchet CPU worker pool (mixdown only)",
|
||||
"Starting Hatchet CPU worker pool (mixdown + padding)",
|
||||
worker_name="cpu-worker-pool",
|
||||
slots=1,
|
||||
labels={"pool": "cpu-heavy"},
|
||||
@@ -31,11 +32,11 @@ def main():
|
||||
|
||||
cpu_worker = hatchet.worker(
|
||||
"cpu-worker-pool",
|
||||
slots=1, # Only 1 mixdown at a time (already serialized globally)
|
||||
slots=1,
|
||||
labels={
|
||||
"pool": "cpu-heavy",
|
||||
},
|
||||
workflows=[daily_multitrack_pipeline],
|
||||
workflows=[daily_multitrack_pipeline, padding_workflow],
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -9,7 +9,7 @@ from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
)
|
||||
from reflector.hatchet.workflows.subject_processing import subject_workflow
|
||||
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
|
||||
from reflector.hatchet.workflows.track_processing import track_workflow
|
||||
from reflector.hatchet.workflows.transcription_workflow import transcription_workflow
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
@@ -42,7 +42,7 @@ def main():
|
||||
daily_multitrack_pipeline,
|
||||
topic_chunk_workflow,
|
||||
subject_workflow,
|
||||
track_workflow,
|
||||
transcription_workflow,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,10 @@ from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
PipelineInput,
|
||||
daily_multitrack_pipeline,
|
||||
)
|
||||
from reflector.hatchet.workflows.padding_workflow import (
|
||||
PaddingInput,
|
||||
padding_workflow,
|
||||
)
|
||||
from reflector.hatchet.workflows.subject_processing import (
|
||||
SubjectInput,
|
||||
subject_workflow,
|
||||
@@ -12,15 +16,20 @@ from reflector.hatchet.workflows.topic_chunk_processing import (
|
||||
TopicChunkInput,
|
||||
topic_chunk_workflow,
|
||||
)
|
||||
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
|
||||
from reflector.hatchet.workflows.transcription_workflow import (
|
||||
TranscriptionInput,
|
||||
transcription_workflow,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"daily_multitrack_pipeline",
|
||||
"subject_workflow",
|
||||
"topic_chunk_workflow",
|
||||
"track_workflow",
|
||||
"padding_workflow",
|
||||
"transcription_workflow",
|
||||
"PipelineInput",
|
||||
"SubjectInput",
|
||||
"TopicChunkInput",
|
||||
"TrackInput",
|
||||
"PaddingInput",
|
||||
"TranscriptionInput",
|
||||
]
|
||||
|
||||
@@ -54,8 +54,9 @@ from reflector.hatchet.workflows.models import (
|
||||
PadTrackResult,
|
||||
ParticipantInfo,
|
||||
ParticipantsResult,
|
||||
ProcessPaddingsResult,
|
||||
ProcessSubjectsResult,
|
||||
ProcessTracksResult,
|
||||
ProcessTranscriptionsResult,
|
||||
RecapResult,
|
||||
RecordingResult,
|
||||
SubjectsResult,
|
||||
@@ -68,6 +69,7 @@ from reflector.hatchet.workflows.models import (
|
||||
WebhookResult,
|
||||
ZulipResult,
|
||||
)
|
||||
from reflector.hatchet.workflows.padding_workflow import PaddingInput, padding_workflow
|
||||
from reflector.hatchet.workflows.subject_processing import (
|
||||
SubjectInput,
|
||||
subject_workflow,
|
||||
@@ -76,7 +78,10 @@ from reflector.hatchet.workflows.topic_chunk_processing import (
|
||||
TopicChunkInput,
|
||||
topic_chunk_workflow,
|
||||
)
|
||||
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
|
||||
from reflector.hatchet.workflows.transcription_workflow import (
|
||||
TranscriptionInput,
|
||||
transcription_workflow,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines import topic_processing
|
||||
from reflector.processors import AudioFileWriterProcessor
|
||||
@@ -404,39 +409,29 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
)
|
||||
@with_error_handling(TaskName.PROCESS_TRACKS)
|
||||
async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
|
||||
"""Spawn child workflows for each track (dynamic fan-out)."""
|
||||
ctx.log(f"process_tracks: spawning {len(input.tracks)} track workflows")
|
||||
|
||||
participants_result = ctx.task_output(get_participants)
|
||||
source_language = participants_result.source_language
|
||||
@with_error_handling(TaskName.PROCESS_PADDINGS)
|
||||
async def process_paddings(input: PipelineInput, ctx: Context) -> ProcessPaddingsResult:
|
||||
"""Spawn child workflows for each track to apply padding (dynamic fan-out)."""
|
||||
ctx.log(f"process_paddings: spawning {len(input.tracks)} padding workflows")
|
||||
|
||||
bulk_runs = [
|
||||
track_workflow.create_bulk_run_item(
|
||||
input=TrackInput(
|
||||
padding_workflow.create_bulk_run_item(
|
||||
input=PaddingInput(
|
||||
track_index=i,
|
||||
s3_key=track["s3_key"],
|
||||
bucket_name=input.bucket_name,
|
||||
transcript_id=input.transcript_id,
|
||||
language=source_language,
|
||||
)
|
||||
)
|
||||
for i, track in enumerate(input.tracks)
|
||||
]
|
||||
|
||||
results = await track_workflow.aio_run_many(bulk_runs)
|
||||
results = await padding_workflow.aio_run_many(bulk_runs)
|
||||
|
||||
target_language = participants_result.target_language
|
||||
|
||||
track_words: list[list[Word]] = []
|
||||
padded_tracks = []
|
||||
created_padded_files = set()
|
||||
|
||||
for result in results:
|
||||
transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK])
|
||||
track_words.append(transcribe_result.words)
|
||||
|
||||
pad_result = PadTrackResult(**result[TaskName.PAD_TRACK])
|
||||
|
||||
# Store S3 key info (not presigned URL) - consumer tasks presign on demand
|
||||
@@ -451,25 +446,75 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
||||
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{pad_result.track_index}.webm"
|
||||
created_padded_files.add(storage_path)
|
||||
|
||||
all_words = [word for words in track_words for word in words]
|
||||
all_words.sort(key=lambda w: w.start)
|
||||
ctx.log(f"process_paddings complete: {len(padded_tracks)} padded tracks")
|
||||
|
||||
ctx.log(
|
||||
f"process_tracks complete: {len(all_words)} words from {len(input.tracks)} tracks"
|
||||
)
|
||||
|
||||
return ProcessTracksResult(
|
||||
all_words=all_words,
|
||||
return ProcessPaddingsResult(
|
||||
padded_tracks=padded_tracks,
|
||||
word_count=len(all_words),
|
||||
num_tracks=len(input.tracks),
|
||||
target_language=target_language,
|
||||
created_padded_files=list(created_padded_files),
|
||||
)
|
||||
|
||||
|
||||
@daily_multitrack_pipeline.task(
|
||||
parents=[process_tracks],
|
||||
parents=[process_paddings],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
)
|
||||
@with_error_handling(TaskName.PROCESS_TRANSCRIPTIONS)
async def process_transcriptions(
    input: PipelineInput, ctx: Context
) -> ProcessTranscriptionsResult:
    """Fan out one TranscriptionWorkflow run per padded track and merge the words.

    Reads the outputs of ``get_participants`` (languages) and
    ``process_paddings`` (padded track S3 locations), spawns one child
    transcription run per padded track, then flattens every track's words
    into a single list ordered by word start time.

    Raises:
        ValueError: if the participants result carries no source language.
    """
    participants = ctx.task_output(get_participants)
    paddings = ctx.task_output(process_paddings)

    source_language = participants.source_language
    if not source_language:
        raise ValueError("source_language is required for transcription")

    target_language = participants.target_language
    padded_tracks = paddings.padded_tracks

    ctx.log(
        f"process_transcriptions: spawning {len(padded_tracks)} transcription workflows"
    )

    # The position of a track in padded_tracks is what the child receives as
    # its track_index.
    bulk_runs = []
    for i, padded_track in enumerate(padded_tracks):
        child_input = TranscriptionInput(
            track_index=i,
            padded_key=padded_track.key,
            bucket_name=padded_track.bucket_name,
            language=source_language,
        )
        bulk_runs.append(transcription_workflow.create_bulk_run_item(input=child_input))

    results = await transcription_workflow.aio_run_many(bulk_runs)

    # Each child result is keyed by task name; validate it back into the model.
    track_words: list[list[Word]] = [
        TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK]).words
        for result in results
    ]

    # Flatten in track order, then order by start time (stable sort).
    all_words = sorted(
        (word for words in track_words for word in words),
        key=lambda w: w.start,
    )

    ctx.log(
        f"process_transcriptions complete: {len(all_words)} words from {len(padded_tracks)} tracks"
    )

    return ProcessTranscriptionsResult(
        all_words=all_words,
        word_count=len(all_words),
        num_tracks=len(input.tracks),
        target_language=target_language,
    )
|
||||
|
||||
|
||||
@daily_multitrack_pipeline.task(
|
||||
parents=[process_paddings],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
||||
retries=3,
|
||||
desired_worker_labels={
|
||||
@@ -489,12 +534,12 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
||||
)
|
||||
@with_error_handling(TaskName.MIXDOWN_TRACKS)
|
||||
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
||||
"""Mix all padded tracks into single audio file using PyAV (same as Celery)."""
|
||||
"""Mix all padded tracks into single audio file using PyAV."""
|
||||
ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
|
||||
|
||||
track_result = ctx.task_output(process_tracks)
|
||||
paddings_result = ctx.task_output(process_paddings)
|
||||
recording_result = ctx.task_output(get_recording)
|
||||
padded_tracks = track_result.padded_tracks
|
||||
padded_tracks = paddings_result.padded_tracks
|
||||
|
||||
# Dynamic timeout: scales with track count and recording duration
|
||||
# Base 300s + 60s per track + 1s per 10s of recording
|
||||
@@ -648,7 +693,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
|
||||
|
||||
|
||||
@daily_multitrack_pipeline.task(
|
||||
parents=[process_tracks],
|
||||
parents=[process_transcriptions],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
)
|
||||
@@ -657,8 +702,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
||||
"""Detect topics using parallel child workflows (one per chunk)."""
|
||||
ctx.log("detect_topics: analyzing transcript for topics")
|
||||
|
||||
track_result = ctx.task_output(process_tracks)
|
||||
words = track_result.all_words
|
||||
transcriptions_result = ctx.task_output(process_transcriptions)
|
||||
words = transcriptions_result.all_words
|
||||
|
||||
if not words:
|
||||
ctx.log("detect_topics: no words, returning empty topics")
|
||||
@@ -1109,13 +1154,14 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
|
||||
ctx.log("finalize: saving transcript and setting status to 'ended'")
|
||||
|
||||
mixdown_result = ctx.task_output(mixdown_tracks)
|
||||
track_result = ctx.task_output(process_tracks)
|
||||
transcriptions_result = ctx.task_output(process_transcriptions)
|
||||
paddings_result = ctx.task_output(process_paddings)
|
||||
|
||||
duration = mixdown_result.duration
|
||||
all_words = track_result.all_words
|
||||
all_words = transcriptions_result.all_words
|
||||
|
||||
# Cleanup temporary padded S3 files (deferred until finalize for semantic parity with Celery)
|
||||
created_padded_files = track_result.created_padded_files
|
||||
created_padded_files = paddings_result.created_padded_files
|
||||
if created_padded_files:
|
||||
ctx.log(f"Cleaning up {len(created_padded_files)} temporary S3 files")
|
||||
storage = _spawn_storage()
|
||||
|
||||
@@ -23,10 +23,8 @@ class ParticipantInfo(BaseModel):
|
||||
class PadTrackResult(BaseModel):
|
||||
"""Result from pad_track task."""
|
||||
|
||||
padded_key: NonEmptyString # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay
|
||||
bucket_name: (
|
||||
NonEmptyString | None
|
||||
) # None means use default transcript storage bucket
|
||||
padded_key: NonEmptyString
|
||||
bucket_name: NonEmptyString | None
|
||||
size: int
|
||||
track_index: int
|
||||
|
||||
@@ -59,18 +57,24 @@ class PaddedTrackInfo(BaseModel):
|
||||
"""Info for a padded track - S3 key + bucket for on-demand presigning."""
|
||||
|
||||
key: NonEmptyString
|
||||
bucket_name: NonEmptyString | None # None = use default storage bucket
|
||||
bucket_name: NonEmptyString | None
|
||||
|
||||
|
||||
class ProcessTracksResult(BaseModel):
|
||||
"""Result from process_tracks task."""
|
||||
class ProcessPaddingsResult(BaseModel):
|
||||
"""Result from process_paddings task."""
|
||||
|
||||
padded_tracks: list[PaddedTrackInfo]
|
||||
num_tracks: int
|
||||
created_padded_files: list[NonEmptyString]
|
||||
|
||||
|
||||
class ProcessTranscriptionsResult(BaseModel):
|
||||
"""Result from process_transcriptions task."""
|
||||
|
||||
all_words: list[Word]
|
||||
padded_tracks: list[PaddedTrackInfo] # S3 keys, not presigned URLs
|
||||
word_count: int
|
||||
num_tracks: int
|
||||
target_language: NonEmptyString
|
||||
created_padded_files: list[NonEmptyString]
|
||||
|
||||
|
||||
class MixdownResult(BaseModel):
|
||||
|
||||
145
server/reflector/hatchet/workflows/padding_workflow.py
Normal file
145
server/reflector/hatchet/workflows/padding_workflow.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Hatchet child workflow: PaddingWorkflow
|
||||
Handles individual audio track padding only.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
from hatchet_sdk import Context
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.constants import TIMEOUT_AUDIO
|
||||
from reflector.hatchet.workflows.models import PadTrackResult
|
||||
from reflector.logger import logger
|
||||
from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
|
||||
from reflector.utils.audio_padding import (
|
||||
apply_audio_padding_to_file,
|
||||
extract_stream_start_time_from_container,
|
||||
)
|
||||
|
||||
|
||||
class PaddingInput(BaseModel):
    """Payload for one PaddingWorkflow run (a single audio track)."""

    # Index of the track; used in logs and in the padded file's storage path.
    track_index: int
    # S3 key of the source (unpadded) track.
    s3_key: str
    # Bucket that holds the source track.
    bucket_name: str
    # Transcript the track belongs to; used in logs and in the storage path.
    transcript_id: str
|
||||
|
||||
|
||||
# Client used to declare the workflow below.
hatchet = HatchetClientManager.get_client()

# Child workflow whose input payload is validated against PaddingInput.
padding_workflow = hatchet.workflow(
    name="PaddingWorkflow",
    input_validator=PaddingInput,
)
|
||||
|
||||
|
||||
@padding_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3)
async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult:
    """Pad one audio track with leading silence based on its container start time.

    Presigns the source object, opens it with PyAV, and reads the stream start
    time via ``extract_stream_start_time_from_container``. When that start
    time is not positive, the original key and bucket are returned with
    ``size=0``. Otherwise the padded audio is written to a temporary ``.webm``
    file, uploaded under ``file_pipeline_hatchet/<transcript_id>/tracks/``,
    and the new key is returned with ``bucket_name=None``.

    Raises:
        Exception: any failure is logged with transcript and track context,
            then re-raised unchanged.
    """
    idx = input.track_index

    ctx.log(f"pad_track: track {idx}, s3_key={input.s3_key}")
    logger.info(
        "[Hatchet] pad_track",
        track_index=idx,
        s3_key=input.s3_key,
        transcript_id=input.transcript_id,
    )

    try:
        # Build a fresh storage instance inside the task to avoid aioboto3 fork issues.
        from reflector.settings import settings  # noqa: PLC0415
        from reflector.storage.storage_aws import AwsStorage  # noqa: PLC0415

        storage = AwsStorage(
            aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
            aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
            aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
            aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
        )

        source_url = await storage.get_file_url(
            input.s3_key,
            operation="get_object",
            expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
            bucket=input.bucket_name,
        )

        with av.open(source_url) as in_container:
            # Duration logging is best-effort: a failure here must not fail the task.
            if in_container.duration:
                try:
                    duration = timedelta(seconds=in_container.duration // 1_000_000)
                    ctx.log(f"pad_track: track {idx}, duration={duration}")
                except Exception:
                    ctx.log(f"pad_track: track {idx}, duration=ERROR")

            start_time_seconds = extract_stream_start_time_from_container(
                in_container, idx, logger=logger
            )

            # Nothing to pad: hand back the source object untouched.
            if start_time_seconds <= 0:
                logger.info(
                    f"Track {idx} requires no padding",
                    track_index=idx,
                )
                return PadTrackResult(
                    padded_key=input.s3_key,
                    bucket_name=input.bucket_name,
                    size=0,
                    track_index=idx,
                )

            # Only the name is needed; the handle is closed right away and the
            # file is removed explicitly in the finally block below.
            with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
                temp_path = temp_file.name

            local_copy = Path(temp_path)
            try:
                # NOTE(review): this call is made synchronously inside the
                # coroutine — presumably blocking while it encodes; confirm
                # that is acceptable for the worker's slot configuration.
                apply_audio_padding_to_file(
                    in_container,
                    temp_path,
                    start_time_seconds,
                    idx,
                    logger=logger,
                )

                file_size = local_copy.stat().st_size
                storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{input.track_index}.webm"

                with open(temp_path, "rb") as padded_file:
                    await storage.put_file(storage_path, padded_file)

                logger.info(
                    "Uploaded padded track to S3",
                    key=storage_path,
                    size=file_size,
                )
            finally:
                local_copy.unlink(missing_ok=True)

        ctx.log(f"pad_track complete: track {idx} -> {storage_path}")
        logger.info(
            "[Hatchet] pad_track complete",
            track_index=idx,
            padded_key=storage_path,
        )

        return PadTrackResult(
            padded_key=storage_path,
            bucket_name=None,  # None = use default transcript storage bucket
            size=file_size,
            track_index=idx,
        )

    except Exception as e:
        logger.error(
            "[Hatchet] pad_track failed",
            transcript_id=input.transcript_id,
            track_index=idx,
            error=str(e),
            exc_info=True,
        )
        raise
|
||||
@@ -1,229 +0,0 @@
|
||||
"""
|
||||
Hatchet child workflow: TrackProcessing
|
||||
|
||||
Handles individual audio track processing: padding and transcription.
|
||||
Spawned dynamically by the main diarization pipeline for each track.
|
||||
|
||||
Architecture note: This is a separate workflow (not inline tasks in DailyMultitrackPipeline)
|
||||
because Hatchet workflow DAGs are defined statically, but the number of tracks varies
|
||||
at runtime. Child workflow spawning via `aio_run()` + `asyncio.gather()` is the
|
||||
standard pattern for dynamic fan-out. See `process_tracks` in daily_multitrack_pipeline.py.
|
||||
|
||||
Note: This file uses deferred imports (inside tasks) intentionally.
|
||||
Hatchet workers run in forked processes; fresh imports per task ensure
|
||||
storage/DB connections are not shared across forks.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
from hatchet_sdk import Context
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.constants import TIMEOUT_AUDIO, TIMEOUT_HEAVY
|
||||
from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
|
||||
from reflector.logger import logger
|
||||
from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
|
||||
from reflector.utils.audio_padding import (
|
||||
apply_audio_padding_to_file,
|
||||
extract_stream_start_time_from_container,
|
||||
)
|
||||
|
||||
|
||||
class TrackInput(BaseModel):
|
||||
"""Input for individual track processing."""
|
||||
|
||||
track_index: int
|
||||
s3_key: str
|
||||
bucket_name: str
|
||||
transcript_id: str
|
||||
language: str = "en"
|
||||
|
||||
|
||||
hatchet = HatchetClientManager.get_client()
|
||||
|
||||
track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
|
||||
|
||||
|
||||
@track_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3)
|
||||
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
|
||||
"""Pad single audio track with silence for alignment.
|
||||
|
||||
Extracts stream.start_time from WebM container metadata and applies
|
||||
silence padding using PyAV filter graph (adelay).
|
||||
"""
|
||||
ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}")
|
||||
logger.info(
|
||||
"[Hatchet] pad_track",
|
||||
track_index=input.track_index,
|
||||
s3_key=input.s3_key,
|
||||
transcript_id=input.transcript_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Create fresh storage instance to avoid aioboto3 fork issues
|
||||
from reflector.settings import settings # noqa: PLC0415
|
||||
from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415
|
||||
|
||||
storage = AwsStorage(
|
||||
aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
|
||||
aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
|
||||
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||
)
|
||||
|
||||
source_url = await storage.get_file_url(
|
||||
input.s3_key,
|
||||
operation="get_object",
|
||||
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
|
||||
bucket=input.bucket_name,
|
||||
)
|
||||
|
||||
with av.open(source_url) as in_container:
|
||||
if in_container.duration:
|
||||
try:
|
||||
duration = timedelta(seconds=in_container.duration // 1_000_000)
|
||||
ctx.log(
|
||||
f"pad_track: track {input.track_index}, duration={duration}"
|
||||
)
|
||||
except Exception:
|
||||
ctx.log(f"pad_track: track {input.track_index}, duration=ERROR")
|
||||
|
||||
start_time_seconds = extract_stream_start_time_from_container(
|
||||
in_container, input.track_index, logger=logger
|
||||
)
|
||||
|
||||
# If no padding needed, return original S3 key
|
||||
if start_time_seconds <= 0:
|
||||
logger.info(
|
||||
f"Track {input.track_index} requires no padding",
|
||||
track_index=input.track_index,
|
||||
)
|
||||
return PadTrackResult(
|
||||
padded_key=input.s3_key,
|
||||
bucket_name=input.bucket_name,
|
||||
size=0,
|
||||
track_index=input.track_index,
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
try:
|
||||
apply_audio_padding_to_file(
|
||||
in_container,
|
||||
temp_path,
|
||||
start_time_seconds,
|
||||
input.track_index,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
file_size = Path(temp_path).stat().st_size
|
||||
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{input.track_index}.webm"
|
||||
|
||||
logger.info(
|
||||
f"About to upload padded track",
|
||||
key=storage_path,
|
||||
size=file_size,
|
||||
)
|
||||
|
||||
with open(temp_path, "rb") as padded_file:
|
||||
await storage.put_file(storage_path, padded_file)
|
||||
|
||||
logger.info(
|
||||
f"Uploaded padded track to S3",
|
||||
key=storage_path,
|
||||
size=file_size,
|
||||
)
|
||||
finally:
|
||||
Path(temp_path).unlink(missing_ok=True)
|
||||
|
||||
ctx.log(f"pad_track complete: track {input.track_index} -> {storage_path}")
|
||||
logger.info(
|
||||
"[Hatchet] pad_track complete",
|
||||
track_index=input.track_index,
|
||||
padded_key=storage_path,
|
||||
)
|
||||
|
||||
# Return S3 key (not presigned URL) - consumer tasks presign on demand
|
||||
# This avoids stale URLs when workflow is replayed
|
||||
return PadTrackResult(
|
||||
padded_key=storage_path,
|
||||
bucket_name=None, # None = use default transcript storage bucket
|
||||
size=file_size,
|
||||
track_index=input.track_index,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@track_workflow.task(
|
||||
parents=[pad_track], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3
|
||||
)
|
||||
async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
|
||||
"""Transcribe audio track using GPU (Modal.com) or local Whisper."""
|
||||
ctx.log(f"transcribe_track: track {input.track_index}, language={input.language}")
|
||||
logger.info(
|
||||
"[Hatchet] transcribe_track",
|
||||
track_index=input.track_index,
|
||||
language=input.language,
|
||||
)
|
||||
|
||||
try:
|
||||
pad_result = ctx.task_output(pad_track)
|
||||
padded_key = pad_result.padded_key
|
||||
bucket_name = pad_result.bucket_name
|
||||
|
||||
if not padded_key:
|
||||
raise ValueError("Missing padded_key from pad_track")
|
||||
|
||||
# Presign URL on demand (avoids stale URLs on workflow replay)
|
||||
from reflector.settings import settings # noqa: PLC0415
|
||||
from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415
|
||||
|
||||
storage = AwsStorage(
|
||||
aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
|
||||
aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
|
||||
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||
)
|
||||
|
||||
audio_url = await storage.get_file_url(
|
||||
padded_key,
|
||||
operation="get_object",
|
||||
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
|
||||
bucket=bucket_name,
|
||||
)
|
||||
|
||||
from reflector.pipelines.transcription_helpers import ( # noqa: PLC0415
|
||||
transcribe_file_with_processor,
|
||||
)
|
||||
|
||||
transcript = await transcribe_file_with_processor(audio_url, input.language)
|
||||
|
||||
# Tag all words with speaker index
|
||||
for word in transcript.words:
|
||||
word.speaker = input.track_index
|
||||
|
||||
ctx.log(
|
||||
f"transcribe_track complete: track {input.track_index}, {len(transcript.words)} words"
|
||||
)
|
||||
logger.info(
|
||||
"[Hatchet] transcribe_track complete",
|
||||
track_index=input.track_index,
|
||||
word_count=len(transcript.words),
|
||||
)
|
||||
|
||||
return TranscribeTrackResult(
|
||||
words=transcript.words,
|
||||
track_index=input.track_index,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)
|
||||
raise
|
||||
98
server/reflector/hatchet/workflows/transcription_workflow.py
Normal file
98
server/reflector/hatchet/workflows/transcription_workflow.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Hatchet child workflow: TranscriptionWorkflow
|
||||
Handles individual audio track transcription only.
|
||||
"""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from hatchet_sdk import Context
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.constants import TIMEOUT_HEAVY
|
||||
from reflector.hatchet.workflows.models import TranscribeTrackResult
|
||||
from reflector.logger import logger
|
||||
from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
|
||||
|
||||
|
||||
class TranscriptionInput(BaseModel):
    """Payload for one TranscriptionWorkflow run (a single padded track)."""

    # Index of the track; written onto every word as its speaker.
    track_index: int
    # S3 key produced by the padding step.
    padded_key: str
    # None means the default bucket is used.
    bucket_name: str | None
    # Language passed to the transcription helper.
    language: str = "en"
|
||||
|
||||
|
||||
# Client used to declare the workflow below.
hatchet = HatchetClientManager.get_client()

# Child workflow whose input payload is validated against TranscriptionInput.
transcription_workflow = hatchet.workflow(
    name="TranscriptionWorkflow",
    input_validator=TranscriptionInput,
)
|
||||
|
||||
|
||||
@transcription_workflow.task(
    execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
    retries=3,
)
async def transcribe_track(
    input: TranscriptionInput, ctx: Context
) -> TranscribeTrackResult:
    """Transcribe one padded track and tag its words with the track index.

    Presigns ``input.padded_key`` on demand, passes the URL and language to
    ``transcribe_file_with_processor``, then sets ``speaker`` on every
    returned word to ``input.track_index``.

    Raises:
        Exception: any failure is logged with track context, then re-raised
            unchanged.
    """
    idx = input.track_index
    lang = input.language

    ctx.log(f"transcribe_track: track {idx}, language={lang}")
    logger.info(
        "[Hatchet] transcribe_track",
        track_index=idx,
        language=lang,
    )

    try:
        # Deferred imports, as in the original block.
        from reflector.settings import settings  # noqa: PLC0415
        from reflector.storage.storage_aws import AwsStorage  # noqa: PLC0415

        storage = AwsStorage(
            aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
            aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
            aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
            aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
        )

        audio_url = await storage.get_file_url(
            input.padded_key,
            operation="get_object",
            expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
            bucket=input.bucket_name,
        )

        from reflector.pipelines.transcription_helpers import (  # noqa: PLC0415
            transcribe_file_with_processor,
        )

        transcript = await transcribe_file_with_processor(audio_url, lang)

        # Each track is treated as one speaker.
        for word in transcript.words:
            word.speaker = idx

        word_total = len(transcript.words)
        ctx.log(f"transcribe_track complete: track {idx}, {word_total} words")
        logger.info(
            "[Hatchet] transcribe_track complete",
            track_index=idx,
            word_count=word_total,
        )

        return TranscribeTrackResult(
            words=transcript.words,
            track_index=idx,
        )

    except Exception as e:
        logger.error(
            "[Hatchet] transcribe_track failed",
            track_index=idx,
            padded_key=input.padded_key,
            language=lang,
            error=str(e),
            exc_info=True,
        )
        raise
|
||||
Reference in New Issue
Block a user