Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2026-02-04 09:56:47 +00:00)
Split padding and transcription into separate workflow steps
- Split process_tracks into process_paddings + process_transcriptions
- Create PaddingWorkflow and TranscriptionWorkflow as separate child workflows
- Update dependency: mixdown_tracks now depends on process_paddings (not process_transcriptions)
- Performance: mixdown starts ~295s earlier (after padding completes, not after transcription)

Changes:
- New: padding_workflow.py, transcription_workflow.py
- Modified: daily_multitrack_pipeline.py (new tasks, updated dependencies)
- Modified: models.py (new ProcessPaddingsResult, ProcessTranscriptionsResult, deleted dead ProcessTracksResult)
- Modified: constants.py (new task names)
- Modified: run_workers_cpu.py, run_workers_llm.py (workflow registration)
- Deleted: track_processing.py

Code quality fixes:
- Removed redundant comments and verbose docstrings
- Added language validation in process_transcriptions
- Improved error logging with full context (transcript_id, track_index)
- Fixed log accuracy bugs (use correct counts)
- Updated worker pool documentation
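For orientation, here is a minimal sketch (not part of the commit) of the dependency change described above, written as plain Python dicts. In the real code the DAG is declared with @daily_multitrack_pipeline.task(parents=[...]) decorators, as the diff below shows; upstream tasks such as get_recording/get_participants are omitted from this sketch.

# Illustrative sketch only: changed task edges before and after this commit.
EDGES_BEFORE = {
    "process_tracks": [],                            # padding + transcription in one step
    "mixdown_tracks": ["process_tracks"],            # so mixdown also waited on transcription
    "detect_topics": ["process_tracks"],
}

EDGES_AFTER = {
    "process_paddings": [],                          # spawns PaddingWorkflow children
    "process_transcriptions": ["process_paddings"],  # spawns TranscriptionWorkflow children
    "mixdown_tracks": ["process_paddings"],          # starts once padding is done
    "detect_topics": ["process_transcriptions"],
}

if __name__ == "__main__":
    for task, parents in EDGES_AFTER.items():
        print(task, "<-", ", ".join(parents) or "(upstream tasks omitted)")

finalize now reads from both sides of the split: words come from process_transcriptions and the padded-file cleanup list comes from process_paddings, as the finalize hunk below shows.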
server/reflector/hatchet/constants.py
@@ -8,7 +8,8 @@ from enum import StrEnum
 class TaskName(StrEnum):
     GET_RECORDING = "get_recording"
     GET_PARTICIPANTS = "get_participants"
-    PROCESS_TRACKS = "process_tracks"
+    PROCESS_PADDINGS = "process_paddings"
+    PROCESS_TRANSCRIPTIONS = "process_transcriptions"
     MIXDOWN_TRACKS = "mixdown_tracks"
     GENERATE_WAVEFORM = "generate_waveform"
     DETECT_TOPICS = "detect_topics"
run_workers_cpu.py
@@ -1,9 +1,9 @@
 """
 CPU-heavy worker pool for audio processing tasks.
-Handles ONLY: mixdown_tracks
+Handles: mixdown_tracks (serialized), padding workflows (parallel child workflows)
 
 Configuration:
-- slots=1: Only mixdown (already serialized globally with max_runs=1)
+- slots=1: Mixdown serialized globally with max_runs=1
 - Worker affinity: pool=cpu-heavy
 """
 
@@ -11,6 +11,7 @@ from reflector.hatchet.client import HatchetClientManager
 from reflector.hatchet.workflows.daily_multitrack_pipeline import (
     daily_multitrack_pipeline,
 )
+from reflector.hatchet.workflows.padding_workflow import padding_workflow
 from reflector.logger import logger
 from reflector.settings import settings
 
@@ -23,7 +24,7 @@ def main():
     hatchet = HatchetClientManager.get_client()
 
     logger.info(
-        "Starting Hatchet CPU worker pool (mixdown only)",
+        "Starting Hatchet CPU worker pool (mixdown + padding)",
         worker_name="cpu-worker-pool",
         slots=1,
         labels={"pool": "cpu-heavy"},
@@ -31,11 +32,11 @@ def main():
 
     cpu_worker = hatchet.worker(
         "cpu-worker-pool",
-        slots=1,  # Only 1 mixdown at a time (already serialized globally)
+        slots=1,
         labels={
             "pool": "cpu-heavy",
         },
-        workflows=[daily_multitrack_pipeline],
+        workflows=[daily_multitrack_pipeline, padding_workflow],
     )
 
     try:
run_workers_llm.py
@@ -9,7 +9,7 @@ from reflector.hatchet.workflows.daily_multitrack_pipeline import (
 )
 from reflector.hatchet.workflows.subject_processing import subject_workflow
 from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
-from reflector.hatchet.workflows.track_processing import track_workflow
+from reflector.hatchet.workflows.transcription_workflow import transcription_workflow
 from reflector.logger import logger
 from reflector.settings import settings
 
@@ -42,7 +42,7 @@ def main():
             daily_multitrack_pipeline,
             topic_chunk_workflow,
             subject_workflow,
-            track_workflow,
+            transcription_workflow,
         ],
     )
 
server/reflector/hatchet/workflows/__init__.py
@@ -4,6 +4,10 @@ from reflector.hatchet.workflows.daily_multitrack_pipeline import (
     PipelineInput,
     daily_multitrack_pipeline,
 )
+from reflector.hatchet.workflows.padding_workflow import (
+    PaddingInput,
+    padding_workflow,
+)
 from reflector.hatchet.workflows.subject_processing import (
     SubjectInput,
     subject_workflow,
@@ -12,15 +16,20 @@ from reflector.hatchet.workflows.topic_chunk_processing import (
     TopicChunkInput,
     topic_chunk_workflow,
 )
-from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
+from reflector.hatchet.workflows.transcription_workflow import (
+    TranscriptionInput,
+    transcription_workflow,
+)
 
 __all__ = [
     "daily_multitrack_pipeline",
     "subject_workflow",
     "topic_chunk_workflow",
-    "track_workflow",
+    "padding_workflow",
+    "transcription_workflow",
     "PipelineInput",
     "SubjectInput",
     "TopicChunkInput",
-    "TrackInput",
+    "PaddingInput",
+    "TranscriptionInput",
 ]
server/reflector/hatchet/workflows/daily_multitrack_pipeline.py
@@ -54,8 +54,9 @@ from reflector.hatchet.workflows.models import (
     PadTrackResult,
     ParticipantInfo,
     ParticipantsResult,
+    ProcessPaddingsResult,
     ProcessSubjectsResult,
-    ProcessTracksResult,
+    ProcessTranscriptionsResult,
     RecapResult,
     RecordingResult,
     SubjectsResult,
@@ -68,6 +69,7 @@ from reflector.hatchet.workflows.models import (
     WebhookResult,
     ZulipResult,
 )
+from reflector.hatchet.workflows.padding_workflow import PaddingInput, padding_workflow
 from reflector.hatchet.workflows.subject_processing import (
     SubjectInput,
     subject_workflow,
@@ -76,7 +78,10 @@ from reflector.hatchet.workflows.topic_chunk_processing import (
     TopicChunkInput,
     topic_chunk_workflow,
 )
-from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
+from reflector.hatchet.workflows.transcription_workflow import (
+    TranscriptionInput,
+    transcription_workflow,
+)
 from reflector.logger import logger
 from reflector.pipelines import topic_processing
 from reflector.processors import AudioFileWriterProcessor
@@ -404,39 +409,29 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
     execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
     retries=3,
 )
-@with_error_handling(TaskName.PROCESS_TRACKS)
-async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
-    """Spawn child workflows for each track (dynamic fan-out)."""
-    ctx.log(f"process_tracks: spawning {len(input.tracks)} track workflows")
+@with_error_handling(TaskName.PROCESS_PADDINGS)
+async def process_paddings(input: PipelineInput, ctx: Context) -> ProcessPaddingsResult:
+    """Spawn child workflows for each track to apply padding (dynamic fan-out)."""
+    ctx.log(f"process_paddings: spawning {len(input.tracks)} padding workflows")
 
-    participants_result = ctx.task_output(get_participants)
-    source_language = participants_result.source_language
-
     bulk_runs = [
-        track_workflow.create_bulk_run_item(
-            input=TrackInput(
+        padding_workflow.create_bulk_run_item(
+            input=PaddingInput(
                 track_index=i,
                 s3_key=track["s3_key"],
                 bucket_name=input.bucket_name,
                 transcript_id=input.transcript_id,
-                language=source_language,
             )
         )
         for i, track in enumerate(input.tracks)
     ]
 
-    results = await track_workflow.aio_run_many(bulk_runs)
+    results = await padding_workflow.aio_run_many(bulk_runs)
 
-    target_language = participants_result.target_language
-
-    track_words: list[list[Word]] = []
     padded_tracks = []
     created_padded_files = set()
 
     for result in results:
-        transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK])
-        track_words.append(transcribe_result.words)
-
         pad_result = PadTrackResult(**result[TaskName.PAD_TRACK])
 
         # Store S3 key info (not presigned URL) - consumer tasks presign on demand
@@ -451,25 +446,75 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
         storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{pad_result.track_index}.webm"
         created_padded_files.add(storage_path)
 
-    all_words = [word for words in track_words for word in words]
-    all_words.sort(key=lambda w: w.start)
+    ctx.log(f"process_paddings complete: {len(padded_tracks)} padded tracks")
 
-    ctx.log(
-        f"process_tracks complete: {len(all_words)} words from {len(input.tracks)} tracks"
-    )
-
-    return ProcessTracksResult(
-        all_words=all_words,
+    return ProcessPaddingsResult(
         padded_tracks=padded_tracks,
-        word_count=len(all_words),
         num_tracks=len(input.tracks),
-        target_language=target_language,
         created_padded_files=list(created_padded_files),
     )
 
 
 @daily_multitrack_pipeline.task(
-    parents=[process_tracks],
+    parents=[process_paddings],
+    execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
+    retries=3,
+)
+@with_error_handling(TaskName.PROCESS_TRANSCRIPTIONS)
+async def process_transcriptions(
+    input: PipelineInput, ctx: Context
+) -> ProcessTranscriptionsResult:
+    """Spawn child workflows for each padded track to transcribe (dynamic fan-out)."""
+    participants_result = ctx.task_output(get_participants)
+    paddings_result = ctx.task_output(process_paddings)
+
+    source_language = participants_result.source_language
+    if not source_language:
+        raise ValueError("source_language is required for transcription")
+
+    target_language = participants_result.target_language
+    padded_tracks = paddings_result.padded_tracks
+
+    ctx.log(
+        f"process_transcriptions: spawning {len(padded_tracks)} transcription workflows"
+    )
+
+    bulk_runs = [
+        transcription_workflow.create_bulk_run_item(
+            input=TranscriptionInput(
+                track_index=i,
+                padded_key=padded_track.key,
+                bucket_name=padded_track.bucket_name,
+                language=source_language,
+            )
+        )
+        for i, padded_track in enumerate(padded_tracks)
+    ]
+
+    results = await transcription_workflow.aio_run_many(bulk_runs)
+
+    track_words: list[list[Word]] = []
+    for result in results:
+        transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK])
+        track_words.append(transcribe_result.words)
+
+    all_words = [word for words in track_words for word in words]
+    all_words.sort(key=lambda w: w.start)
+
+    ctx.log(
+        f"process_transcriptions complete: {len(all_words)} words from {len(padded_tracks)} tracks"
+    )
+
+    return ProcessTranscriptionsResult(
+        all_words=all_words,
+        word_count=len(all_words),
+        num_tracks=len(input.tracks),
+        target_language=target_language,
+    )
+
+
+@daily_multitrack_pipeline.task(
+    parents=[process_paddings],
     execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
     retries=3,
     desired_worker_labels={
@@ -489,12 +534,12 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
 )
 @with_error_handling(TaskName.MIXDOWN_TRACKS)
 async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
-    """Mix all padded tracks into single audio file using PyAV (same as Celery)."""
+    """Mix all padded tracks into single audio file using PyAV."""
     ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
 
-    track_result = ctx.task_output(process_tracks)
+    paddings_result = ctx.task_output(process_paddings)
     recording_result = ctx.task_output(get_recording)
-    padded_tracks = track_result.padded_tracks
+    padded_tracks = paddings_result.padded_tracks
 
     # Dynamic timeout: scales with track count and recording duration
     # Base 300s + 60s per track + 1s per 10s of recording
@@ -648,7 +693,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
 
 
 @daily_multitrack_pipeline.task(
-    parents=[process_tracks],
+    parents=[process_transcriptions],
     execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
     retries=3,
 )
@@ -657,8 +702,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     """Detect topics using parallel child workflows (one per chunk)."""
     ctx.log("detect_topics: analyzing transcript for topics")
 
-    track_result = ctx.task_output(process_tracks)
-    words = track_result.all_words
+    transcriptions_result = ctx.task_output(process_transcriptions)
+    words = transcriptions_result.all_words
 
     if not words:
         ctx.log("detect_topics: no words, returning empty topics")
@@ -1109,13 +1154,14 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
     ctx.log("finalize: saving transcript and setting status to 'ended'")
 
     mixdown_result = ctx.task_output(mixdown_tracks)
-    track_result = ctx.task_output(process_tracks)
+    transcriptions_result = ctx.task_output(process_transcriptions)
+    paddings_result = ctx.task_output(process_paddings)
 
     duration = mixdown_result.duration
-    all_words = track_result.all_words
+    all_words = transcriptions_result.all_words
 
     # Cleanup temporary padded S3 files (deferred until finalize for semantic parity with Celery)
-    created_padded_files = track_result.created_padded_files
+    created_padded_files = paddings_result.created_padded_files
     if created_padded_files:
         ctx.log(f"Cleaning up {len(created_padded_files)} temporary S3 files")
         storage = _spawn_storage()
server/reflector/hatchet/workflows/models.py
@@ -23,10 +23,8 @@ class ParticipantInfo(BaseModel):
 class PadTrackResult(BaseModel):
     """Result from pad_track task."""
 
-    padded_key: NonEmptyString  # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay
-    bucket_name: (
-        NonEmptyString | None
-    )  # None means use default transcript storage bucket
+    padded_key: NonEmptyString
+    bucket_name: NonEmptyString | None
     size: int
     track_index: int
 
@@ -59,18 +57,24 @@ class PaddedTrackInfo(BaseModel):
     """Info for a padded track - S3 key + bucket for on-demand presigning."""
 
     key: NonEmptyString
-    bucket_name: NonEmptyString | None  # None = use default storage bucket
+    bucket_name: NonEmptyString | None
 
 
-class ProcessTracksResult(BaseModel):
-    """Result from process_tracks task."""
+class ProcessPaddingsResult(BaseModel):
+    """Result from process_paddings task."""
 
+    padded_tracks: list[PaddedTrackInfo]
+    num_tracks: int
+    created_padded_files: list[NonEmptyString]
+
+
+class ProcessTranscriptionsResult(BaseModel):
+    """Result from process_transcriptions task."""
+
     all_words: list[Word]
-    padded_tracks: list[PaddedTrackInfo]  # S3 keys, not presigned URLs
     word_count: int
     num_tracks: int
     target_language: NonEmptyString
-    created_padded_files: list[NonEmptyString]
 
 
 class MixdownResult(BaseModel):
server/reflector/hatchet/workflows/padding_workflow.py (new file, 145 lines)
@@ -0,0 +1,145 @@
+"""
+Hatchet child workflow: PaddingWorkflow
+Handles individual audio track padding only.
+"""
+
+import tempfile
+from datetime import timedelta
+from pathlib import Path
+
+import av
+from hatchet_sdk import Context
+from pydantic import BaseModel
+
+from reflector.hatchet.client import HatchetClientManager
+from reflector.hatchet.constants import TIMEOUT_AUDIO
+from reflector.hatchet.workflows.models import PadTrackResult
+from reflector.logger import logger
+from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
+from reflector.utils.audio_padding import (
+    apply_audio_padding_to_file,
+    extract_stream_start_time_from_container,
+)
+
+
+class PaddingInput(BaseModel):
+    """Input for individual track padding."""
+
+    track_index: int
+    s3_key: str
+    bucket_name: str
+    transcript_id: str
+
+
+hatchet = HatchetClientManager.get_client()
+
+padding_workflow = hatchet.workflow(
+    name="PaddingWorkflow", input_validator=PaddingInput
+)
+
+
+@padding_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3)
+async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult:
+    """Pad audio track with silence based on WebM container start_time."""
+    ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}")
+    logger.info(
+        "[Hatchet] pad_track",
+        track_index=input.track_index,
+        s3_key=input.s3_key,
+        transcript_id=input.transcript_id,
+    )
+
+    try:
+        # Create fresh storage instance to avoid aioboto3 fork issues
+        from reflector.settings import settings  # noqa: PLC0415
+        from reflector.storage.storage_aws import AwsStorage  # noqa: PLC0415
+
+        storage = AwsStorage(
+            aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
+            aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
+            aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
+            aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
+        )
+
+        source_url = await storage.get_file_url(
+            input.s3_key,
+            operation="get_object",
+            expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
+            bucket=input.bucket_name,
+        )
+
+        with av.open(source_url) as in_container:
+            if in_container.duration:
+                try:
+                    duration = timedelta(seconds=in_container.duration // 1_000_000)
+                    ctx.log(
+                        f"pad_track: track {input.track_index}, duration={duration}"
+                    )
+                except Exception:
+                    ctx.log(f"pad_track: track {input.track_index}, duration=ERROR")
+
+            start_time_seconds = extract_stream_start_time_from_container(
+                in_container, input.track_index, logger=logger
+            )
+
+            if start_time_seconds <= 0:
+                logger.info(
+                    f"Track {input.track_index} requires no padding",
+                    track_index=input.track_index,
+                )
+                return PadTrackResult(
+                    padded_key=input.s3_key,
+                    bucket_name=input.bucket_name,
+                    size=0,
+                    track_index=input.track_index,
+                )
+
+            with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
+                temp_path = temp_file.name
+
+            try:
+                apply_audio_padding_to_file(
+                    in_container,
+                    temp_path,
+                    start_time_seconds,
+                    input.track_index,
+                    logger=logger,
+                )
+
+                file_size = Path(temp_path).stat().st_size
+                storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{input.track_index}.webm"
+
+                with open(temp_path, "rb") as padded_file:
+                    await storage.put_file(storage_path, padded_file)
+
+                logger.info(
+                    f"Uploaded padded track to S3",
+                    key=storage_path,
+                    size=file_size,
+                )
+            finally:
+                Path(temp_path).unlink(missing_ok=True)
+
+        ctx.log(f"pad_track complete: track {input.track_index} -> {storage_path}")
+        logger.info(
+            "[Hatchet] pad_track complete",
+            track_index=input.track_index,
+            padded_key=storage_path,
+        )
+
+        return PadTrackResult(
+            padded_key=storage_path,
+            bucket_name=None,  # None = use default transcript storage bucket
+            size=file_size,
+            track_index=input.track_index,
+        )
+
+    except Exception as e:
+        logger.error(
+            "[Hatchet] pad_track failed",
+            transcript_id=input.transcript_id,
+            track_index=input.track_index,
+            error=str(e),
+            exc_info=True,
+        )
+        raise
server/reflector/hatchet/workflows/track_processing.py (deleted, 229 lines)
@@ -1,229 +0,0 @@
-"""
-Hatchet child workflow: TrackProcessing
-
-Handles individual audio track processing: padding and transcription.
-Spawned dynamically by the main diarization pipeline for each track.
-
-Architecture note: This is a separate workflow (not inline tasks in DailyMultitrackPipeline)
-because Hatchet workflow DAGs are defined statically, but the number of tracks varies
-at runtime. Child workflow spawning via `aio_run()` + `asyncio.gather()` is the
-standard pattern for dynamic fan-out. See `process_tracks` in daily_multitrack_pipeline.py.
-
-Note: This file uses deferred imports (inside tasks) intentionally.
-Hatchet workers run in forked processes; fresh imports per task ensure
-storage/DB connections are not shared across forks.
-"""
-
-import tempfile
-from datetime import timedelta
-from pathlib import Path
-
-import av
-from hatchet_sdk import Context
-from pydantic import BaseModel
-
-from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.constants import TIMEOUT_AUDIO, TIMEOUT_HEAVY
-from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
-from reflector.logger import logger
-from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
-from reflector.utils.audio_padding import (
-    apply_audio_padding_to_file,
-    extract_stream_start_time_from_container,
-)
-
-
-class TrackInput(BaseModel):
-    """Input for individual track processing."""
-
-    track_index: int
-    s3_key: str
-    bucket_name: str
-    transcript_id: str
-    language: str = "en"
-
-
-hatchet = HatchetClientManager.get_client()
-
-track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
-
-
-@track_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3)
-async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
-    """Pad single audio track with silence for alignment.
-
-    Extracts stream.start_time from WebM container metadata and applies
-    silence padding using PyAV filter graph (adelay).
-    """
-    ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}")
-    logger.info(
-        "[Hatchet] pad_track",
-        track_index=input.track_index,
-        s3_key=input.s3_key,
-        transcript_id=input.transcript_id,
-    )
-
-    try:
-        # Create fresh storage instance to avoid aioboto3 fork issues
-        from reflector.settings import settings  # noqa: PLC0415
-        from reflector.storage.storage_aws import AwsStorage  # noqa: PLC0415
-
-        storage = AwsStorage(
-            aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
-            aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
-            aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
-            aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
-        )
-
-        source_url = await storage.get_file_url(
-            input.s3_key,
-            operation="get_object",
-            expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
-            bucket=input.bucket_name,
-        )
-
-        with av.open(source_url) as in_container:
-            if in_container.duration:
-                try:
-                    duration = timedelta(seconds=in_container.duration // 1_000_000)
-                    ctx.log(
-                        f"pad_track: track {input.track_index}, duration={duration}"
-                    )
-                except Exception:
-                    ctx.log(f"pad_track: track {input.track_index}, duration=ERROR")
-
-            start_time_seconds = extract_stream_start_time_from_container(
-                in_container, input.track_index, logger=logger
-            )
-
-            # If no padding needed, return original S3 key
-            if start_time_seconds <= 0:
-                logger.info(
-                    f"Track {input.track_index} requires no padding",
-                    track_index=input.track_index,
-                )
-                return PadTrackResult(
-                    padded_key=input.s3_key,
-                    bucket_name=input.bucket_name,
-                    size=0,
-                    track_index=input.track_index,
-                )
-
-            with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
-                temp_path = temp_file.name
-
-            try:
-                apply_audio_padding_to_file(
-                    in_container,
-                    temp_path,
-                    start_time_seconds,
-                    input.track_index,
-                    logger=logger,
-                )
-
-                file_size = Path(temp_path).stat().st_size
-                storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{input.track_index}.webm"
-
-                logger.info(
-                    f"About to upload padded track",
-                    key=storage_path,
-                    size=file_size,
-                )
-
-                with open(temp_path, "rb") as padded_file:
-                    await storage.put_file(storage_path, padded_file)
-
-                logger.info(
-                    f"Uploaded padded track to S3",
-                    key=storage_path,
-                    size=file_size,
-                )
-            finally:
-                Path(temp_path).unlink(missing_ok=True)
-
-        ctx.log(f"pad_track complete: track {input.track_index} -> {storage_path}")
-        logger.info(
-            "[Hatchet] pad_track complete",
-            track_index=input.track_index,
-            padded_key=storage_path,
-        )
-
-        # Return S3 key (not presigned URL) - consumer tasks presign on demand
-        # This avoids stale URLs when workflow is replayed
-        return PadTrackResult(
-            padded_key=storage_path,
-            bucket_name=None,  # None = use default transcript storage bucket
-            size=file_size,
-            track_index=input.track_index,
-        )
-
-    except Exception as e:
-        logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
-        raise
-
-
-@track_workflow.task(
-    parents=[pad_track], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3
-)
-async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
-    """Transcribe audio track using GPU (Modal.com) or local Whisper."""
-    ctx.log(f"transcribe_track: track {input.track_index}, language={input.language}")
-    logger.info(
-        "[Hatchet] transcribe_track",
-        track_index=input.track_index,
-        language=input.language,
-    )
-
-    try:
-        pad_result = ctx.task_output(pad_track)
-        padded_key = pad_result.padded_key
-        bucket_name = pad_result.bucket_name
-
-        if not padded_key:
-            raise ValueError("Missing padded_key from pad_track")
-
-        # Presign URL on demand (avoids stale URLs on workflow replay)
-        from reflector.settings import settings  # noqa: PLC0415
-        from reflector.storage.storage_aws import AwsStorage  # noqa: PLC0415
-
-        storage = AwsStorage(
-            aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
-            aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
-            aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
-            aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
-        )
-
-        audio_url = await storage.get_file_url(
-            padded_key,
-            operation="get_object",
-            expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
-            bucket=bucket_name,
-        )
-
-        from reflector.pipelines.transcription_helpers import (  # noqa: PLC0415
-            transcribe_file_with_processor,
-        )
-
-        transcript = await transcribe_file_with_processor(audio_url, input.language)
-
-        # Tag all words with speaker index
-        for word in transcript.words:
-            word.speaker = input.track_index
-
-        ctx.log(
-            f"transcribe_track complete: track {input.track_index}, {len(transcript.words)} words"
-        )
-        logger.info(
-            "[Hatchet] transcribe_track complete",
-            track_index=input.track_index,
-            word_count=len(transcript.words),
-        )
-
-        return TranscribeTrackResult(
-            words=transcript.words,
-            track_index=input.track_index,
-        )
-
-    except Exception as e:
-        logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)
-        raise
server/reflector/hatchet/workflows/transcription_workflow.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+"""
+Hatchet child workflow: TranscriptionWorkflow
+Handles individual audio track transcription only.
+"""
+
+from datetime import timedelta
+
+from hatchet_sdk import Context
+from pydantic import BaseModel
+
+from reflector.hatchet.client import HatchetClientManager
+from reflector.hatchet.constants import TIMEOUT_HEAVY
+from reflector.hatchet.workflows.models import TranscribeTrackResult
+from reflector.logger import logger
+from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
+
+
+class TranscriptionInput(BaseModel):
+    """Input for individual track transcription."""
+
+    track_index: int
+    padded_key: str  # S3 key from padding step
+    bucket_name: str | None  # None = use default bucket
+    language: str = "en"
+
+
+hatchet = HatchetClientManager.get_client()
+
+transcription_workflow = hatchet.workflow(
+    name="TranscriptionWorkflow", input_validator=TranscriptionInput
+)
+
+
+@transcription_workflow.task(
+    execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3
+)
+async def transcribe_track(
+    input: TranscriptionInput, ctx: Context
+) -> TranscribeTrackResult:
+    """Transcribe audio track using GPU (Modal.com) or local Whisper."""
+    ctx.log(f"transcribe_track: track {input.track_index}, language={input.language}")
+    logger.info(
+        "[Hatchet] transcribe_track",
+        track_index=input.track_index,
+        language=input.language,
+    )
+
+    try:
+        from reflector.settings import settings  # noqa: PLC0415
+        from reflector.storage.storage_aws import AwsStorage  # noqa: PLC0415
+
+        storage = AwsStorage(
+            aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
+            aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
+            aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
+            aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
+        )
+
+        audio_url = await storage.get_file_url(
+            input.padded_key,
+            operation="get_object",
+            expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
+            bucket=input.bucket_name,
+        )
+
+        from reflector.pipelines.transcription_helpers import (  # noqa: PLC0415
+            transcribe_file_with_processor,
+        )
+
+        transcript = await transcribe_file_with_processor(audio_url, input.language)
+
+        for word in transcript.words:
+            word.speaker = input.track_index
+
+        ctx.log(
+            f"transcribe_track complete: track {input.track_index}, {len(transcript.words)} words"
+        )
+        logger.info(
+            "[Hatchet] transcribe_track complete",
+            track_index=input.track_index,
+            word_count=len(transcript.words),
+        )
+
+        return TranscribeTrackResult(
+            words=transcript.words,
+            track_index=input.track_index,
+        )
+
+    except Exception as e:
+        logger.error(
+            "[Hatchet] transcribe_track failed",
+            track_index=input.track_index,
+            padded_key=input.padded_key,
+            language=input.language,
+            error=str(e),
+            exc_info=True,
+        )
+        raise