From fce09455648d4f7217f6259f675ede7a8a92bfe9 Mon Sep 17 00:00:00 2001 From: Igor Loskutov Date: Tue, 16 Dec 2025 16:04:52 -0500 Subject: [PATCH] self-review (no-mistakes) --- TASKS.md | 115 -- server/reflector/hatchet/client.py | 46 +- .../hatchet/workflows/diarization_pipeline.py | 1691 ++++++++--------- server/reflector/hatchet/workflows/models.py | 123 ++ .../hatchet/workflows/track_processing.py | 43 +- .../reflector/services/transcript_process.py | 25 + server/reflector/worker/process.py | 5 + server/tests/conftest.py | 16 + server/tests/test_hatchet_client.py | 11 +- 9 files changed, 1034 insertions(+), 1041 deletions(-) delete mode 100644 TASKS.md create mode 100644 server/reflector/hatchet/workflows/models.py diff --git a/TASKS.md b/TASKS.md deleted file mode 100644 index 1b2d27c8..00000000 --- a/TASKS.md +++ /dev/null @@ -1,115 +0,0 @@ -# Durable Workflow Migration Tasks - -This document defines atomic, isolated work items for migrating the Daily.co multitrack diarization pipeline from Celery to durable workflow orchestration using **Hatchet**. - ---- - -## Provider Selection - -```bash -# .env -DURABLE_WORKFLOW_PROVIDER=none # Celery only (default) -DURABLE_WORKFLOW_PROVIDER=hatchet # Use Hatchet -DURABLE_WORKFLOW_SHADOW_MODE=true # Run both Hatchet + Celery (for comparison) -``` - ---- - -## Task Index - -| ID | Title | Status | -|----|-------|--------| -| INFRA-001 | Add container to docker-compose | Done | -| INFRA-002 | Create Python client wrapper | Done | -| INFRA-003 | Add environment configuration | Done | -| TASK-001 | Create workflow definition | Done | -| TASK-002 | get_recording task | Done | -| TASK-003 | get_participants task | Done | -| TASK-004 | pad_track task | Done | -| TASK-005 | mixdown_tracks task | Done | -| TASK-006 | generate_waveform task | Done | -| TASK-007 | transcribe_track task | Done | -| TASK-008 | merge_transcripts task | Done (in process_tracks) | -| TASK-009 | detect_topics task | Done | -| TASK-010 | generate_title task | Done | -| TASK-011 | generate_summary task | Done | -| TASK-012 | finalize task | Done | -| TASK-013 | cleanup_consent task | Done | -| TASK-014 | post_zulip task | Done | -| TASK-015 | send_webhook task | Done | -| EVENT-001 | Progress WebSocket events | Done | -| INTEG-001 | Pipeline trigger integration | Done | -| SHADOW-001 | Shadow mode toggle | Done | -| TEST-001 | Integration tests | Pending | -| TEST-002 | E2E workflow test | Pending | -| CUTOVER-001 | Production cutover | Pending | -| CLEANUP-001 | Remove Celery code | Pending | - ---- - -## File Structure - -``` -server/reflector/hatchet/ -├── client.py # SDK wrapper -├── progress.py # WebSocket progress emission -├── run_workers.py # Worker startup -└── workflows/ - ├── diarization_pipeline.py # Main workflow with all tasks - └── track_processing.py # Child workflow (pad + transcribe) -``` - ---- - -## Remaining Work - -### TEST-001: Integration Tests -- [ ] Test each task with mocked external services -- [ ] Test error handling and retries - -### TEST-002: E2E Workflow Test -- [ ] Complete workflow run with real Daily.co recording -- [ ] Verify output matches Celery pipeline -- [ ] Performance comparison - -### CUTOVER-001: Production Cutover -- [ ] Deploy with `DURABLE_WORKFLOW_PROVIDER=hatchet` -- [ ] Monitor for failures -- [ ] Compare results with shadow mode if needed - -### CLEANUP-001: Remove Celery Code -- [ ] Remove `main_multitrack_pipeline.py` -- [ ] Remove Celery task triggers -- [ ] Update documentation - ---- - -## Known Issues - -### Hatchet -- See 
`HATCHET_LLM_OBSERVATIONS.md` for debugging notes -- SDK v1.21+ API changes (breaking) -- JWT token Docker networking issues -- Worker appears hung without debug mode -- Workflow replay is version-locked (use --force to run latest code) - ---- - -## Quick Start - -### Hatchet -```bash -# Start infrastructure -docker compose up -d hatchet hatchet-worker - -# Workers auto-register on startup -``` - -### Trigger Workflow -```bash -# Set provider in .env -DURABLE_WORKFLOW_PROVIDER=hatchet - -# Process a Daily.co recording via webhook or API -# The pipeline trigger automatically uses the configured provider -``` diff --git a/server/reflector/hatchet/client.py b/server/reflector/hatchet/client.py index 2d48bb12..76088f17 100644 --- a/server/reflector/hatchet/client.py +++ b/server/reflector/hatchet/client.py @@ -1,37 +1,71 @@ -"""Hatchet Python client wrapper.""" +"""Hatchet Python client wrapper. -from hatchet_sdk import Hatchet +Uses singleton pattern because: +1. Hatchet client maintains persistent gRPC connections for workflow registration +2. Creating multiple clients would cause registration conflicts and resource leaks +3. The SDK is designed for a single client instance per process +4. Tests use `HatchetClientManager.reset()` to isolate state between tests +""" + +import logging + +from hatchet_sdk import ClientConfig, Hatchet from reflector.logger import logger from reflector.settings import settings class HatchetClientManager: - """Singleton manager for Hatchet client connections.""" + """Singleton manager for Hatchet client connections. + + Singleton pattern is used because Hatchet SDK maintains persistent gRPC + connections for workflow registration, and multiple clients would conflict. + + For testing, use the `reset()` method or the `reset_hatchet_client` fixture + to ensure test isolation. + """ _instance: Hatchet | None = None @classmethod def get_client(cls) -> Hatchet: - """Get or create the Hatchet client.""" + """Get or create the Hatchet client. + + Configures root logger so all logger.info() calls in workflows + appear in the Hatchet dashboard logs. + """ if cls._instance is None: if not settings.HATCHET_CLIENT_TOKEN: raise ValueError("HATCHET_CLIENT_TOKEN must be set") + # Pass root logger to Hatchet so workflow logs appear in dashboard + root_logger = logging.getLogger() cls._instance = Hatchet( debug=settings.HATCHET_DEBUG, + config=ClientConfig(logger=root_logger), ) return cls._instance @classmethod async def start_workflow( - cls, workflow_name: str, input_data: dict, key: str | None = None + cls, + workflow_name: str, + input_data: dict, + additional_metadata: dict | None = None, ) -> str: - """Start a workflow and return the workflow run ID.""" + """Start a workflow and return the workflow run ID. + + Args: + workflow_name: Name of the workflow to trigger. + input_data: Input data for the workflow run. + additional_metadata: Optional metadata for filtering in dashboard + (e.g., transcript_id, recording_id). 
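+
+        Returns:
+            The workflow run ID (``result.run.metadata.id`` on SDK v1.21+).
+
+        Example (a sketch; the workflow name and payload shape are
+        illustrative, not a fixed contract):
+
+            run_id = await HatchetClientManager.start_workflow(
+                "diarization_pipeline",
+                {"transcript_id": transcript_id, "tracks": tracks},
+                additional_metadata={"transcript_id": transcript_id},
+            )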
+ """ client = cls.get_client() result = await client.runs.aio_create( workflow_name, input_data, + additional_metadata=additional_metadata, ) # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id return result.run.metadata.id diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py index 94c31242..bf15ae46 100644 --- a/server/reflector/hatchet/workflows/diarization_pipeline.py +++ b/server/reflector/hatchet/workflows/diarization_pipeline.py @@ -6,9 +6,12 @@ Orchestrates the full processing flow from recording metadata to final transcrip """ import asyncio +import functools import tempfile +from contextlib import asynccontextmanager from datetime import timedelta from pathlib import Path +from typing import Callable import av from hatchet_sdk import Context @@ -16,6 +19,20 @@ from pydantic import BaseModel from reflector.hatchet.client import HatchetClientManager from reflector.hatchet.progress import emit_progress_async +from reflector.hatchet.workflows.models import ( + ConsentResult, + FinalizeResult, + MixdownResult, + ParticipantsResult, + ProcessTracksResult, + RecordingResult, + SummaryResult, + TitleResult, + TopicsResult, + WaveformResult, + WebhookResult, + ZulipResult, +) from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow from reflector.logger import logger @@ -23,6 +40,7 @@ from reflector.logger import logger OPUS_STANDARD_SAMPLE_RATE = 48000 OPUS_DEFAULT_BIT_RATE = 64000 PRESIGNED_URL_EXPIRATION_SECONDS = 7200 +WAVEFORM_SEGMENTS = 255 class PipelineInput(BaseModel): @@ -49,8 +67,9 @@ diarization_pipeline = hatchet.workflow( # ============================================================================ -async def _get_fresh_db_connection(): - """Create fresh database connection for subprocess.""" +@asynccontextmanager +async def fresh_db_connection(): + """Context manager for database connections in Hatchet workers.""" import databases from reflector.db import _database_context @@ -60,22 +79,22 @@ async def _get_fresh_db_connection(): db = databases.Database(settings.DATABASE_URL) _database_context.set(db) await db.connect() - return db - - -async def _close_db_connection(db): - """Close database connection.""" - from reflector.db import _database_context - - await db.disconnect() - _database_context.set(None) - - -async def _set_error_status(transcript_id: str): - """Set transcript status to 'error' on workflow failure (matches Celery line 790).""" try: - db = await _get_fresh_db_connection() - try: + yield db + finally: + await db.disconnect() + _database_context.set(None) + + +async def set_workflow_error_status(transcript_id: str) -> bool: + """Set transcript status to 'error' on workflow failure. + + Returns: + True if status was set successfully, False if failed. + Failure is logged as CRITICAL since it means transcript may be stuck. 
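+
+    Usage sketch (illustrative; the ``with_error_handling`` decorator below
+    awaits this and ignores the return value):
+
+        ok = await set_workflow_error_status(transcript_id)
+        if not ok:
+            # Already logged at CRITICAL; escalate via monitoring if needed.
+            ...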
+ """ + try: + async with fresh_db_connection(): from reflector.db.transcripts import transcripts_controller await transcripts_controller.set_status(transcript_id, "error") @@ -83,14 +102,15 @@ async def _set_error_status(transcript_id: str): "[Hatchet] Set transcript status to error", transcript_id=transcript_id, ) - finally: - await _close_db_connection(db) + return True except Exception as e: - logger.error( - "[Hatchet] Failed to set error status", + logger.critical( + "[Hatchet] CRITICAL: Failed to set error status - transcript may be stuck in 'processing'", transcript_id=transcript_id, error=str(e), + exc_info=True, ) + return False def _get_storage(): @@ -106,13 +126,57 @@ def _get_storage(): ) +def _to_dict(output) -> dict: + """Convert task output to dict, handling both dict and Pydantic model returns. + + Hatchet SDK returns Pydantic models when tasks have typed return annotations, + but older code expects dicts. This helper normalizes the output. + """ + if isinstance(output, dict): + return output + return output.model_dump() + + +def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable: + """Decorator that handles task failures uniformly. + + Args: + step_name: Name of the step for logging and progress tracking. + set_error_status: Whether to set transcript status to 'error' on failure. + """ + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + async def wrapper(input: PipelineInput, ctx: Context): + try: + return await func(input, ctx) + except Exception as e: + logger.error( + f"[Hatchet] {step_name} failed", + transcript_id=input.transcript_id, + error=str(e), + exc_info=True, + ) + if set_error_status: + await set_workflow_error_status(input.transcript_id) + await emit_progress_async( + input.transcript_id, step_name, "failed", ctx.workflow_run_id + ) + raise + + return wrapper + + return decorator + + # ============================================================================ # Pipeline Tasks # ============================================================================ @diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3) -async def get_recording(input: PipelineInput, ctx: Context) -> dict: +@with_error_handling("get_recording") +async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult: """Fetch recording metadata from Daily.co API.""" logger.info("[Hatchet] get_recording", recording_id=input.recording_id) @@ -120,9 +184,8 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict: input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id ) - # Set transcript status to "processing" at workflow start (matches Celery behavior) - db = await _get_fresh_db_connection() - try: + # Set transcript status to "processing" at workflow start + async with fresh_db_connection(): from reflector.db.transcripts import transcripts_controller transcript = await transcripts_controller.get_by_id(input.transcript_id) @@ -132,290 +195,244 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict: "[Hatchet] Set transcript status to processing", transcript_id=input.transcript_id, ) - finally: - await _close_db_connection(db) - try: - from reflector.dailyco_api.client import DailyApiClient - from reflector.settings import settings - - if not input.recording_id: - # No recording_id in reprocess path - return minimal data - await emit_progress_async( - input.transcript_id, "get_recording", "completed", ctx.workflow_run_id - ) - return { - "id": None, - 
"mtg_session_id": None, - "room_name": input.room_name, - "duration": 0, - } - - if not settings.DAILY_API_KEY: - raise ValueError("DAILY_API_KEY not configured") - - async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client: - recording = await client.get_recording(input.recording_id) - - logger.info( - "[Hatchet] get_recording complete", - recording_id=input.recording_id, - room_name=recording.room_name, - duration=recording.duration, - ) + from reflector.dailyco_api.client import DailyApiClient + from reflector.settings import settings + if not input.recording_id: + # No recording_id in reprocess path - return minimal data await emit_progress_async( input.transcript_id, "get_recording", "completed", ctx.workflow_run_id ) - - return { - "id": recording.id, - "mtg_session_id": recording.mtgSessionId, - "room_name": recording.room_name, - "duration": recording.duration, - } - - except Exception as e: - logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True) - await _set_error_status(input.transcript_id) - await emit_progress_async( - input.transcript_id, "get_recording", "failed", ctx.workflow_run_id + return RecordingResult( + id=None, + mtg_session_id=None, + room_name=input.room_name, + duration=0, ) - raise + + if not settings.DAILY_API_KEY: + raise ValueError("DAILY_API_KEY not configured") + + async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client: + recording = await client.get_recording(input.recording_id) + + logger.info( + "[Hatchet] get_recording complete", + recording_id=input.recording_id, + room_name=recording.room_name, + duration=recording.duration, + ) + + await emit_progress_async( + input.transcript_id, "get_recording", "completed", ctx.workflow_run_id + ) + + return RecordingResult( + id=recording.id, + mtg_session_id=recording.mtgSessionId, + room_name=recording.room_name, + duration=recording.duration, + ) @diarization_pipeline.task( parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3 ) -async def get_participants(input: PipelineInput, ctx: Context) -> dict: - """Fetch participant list from Daily.co API and update transcript in database. - - Matches Celery's update_participants_from_daily() behavior. 
- """ +@with_error_handling("get_participants") +async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult: + """Fetch participant list from Daily.co API and update transcript in database.""" logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id) await emit_progress_async( input.transcript_id, "get_participants", "in_progress", ctx.workflow_run_id ) - try: - recording_data = ctx.task_output(get_recording) - mtg_session_id = recording_data.get("mtg_session_id") + recording_data = _to_dict(ctx.task_output(get_recording)) + mtg_session_id = recording_data.get("mtg_session_id") - from reflector.dailyco_api.client import DailyApiClient - from reflector.settings import settings - from reflector.utils.daily import ( - filter_cam_audio_tracks, - parse_daily_recording_filename, + from reflector.dailyco_api.client import DailyApiClient + from reflector.settings import settings + from reflector.utils.daily import ( + filter_cam_audio_tracks, + parse_daily_recording_filename, + ) + + # Get transcript and reset events/topics/participants + async with fresh_db_connection(): + from reflector.db.transcripts import ( + TranscriptParticipant, + transcripts_controller, ) - # Get transcript and reset events/topics/participants (matches Celery line 599-607) - db = await _get_fresh_db_connection() - try: - from reflector.db.transcripts import ( - TranscriptParticipant, - transcripts_controller, - ) - - transcript = await transcripts_controller.get_by_id(input.transcript_id) - if transcript: - # Reset events/topics/participants (matches Celery line 599-607) - # Note: title NOT cleared - Celery preserves existing titles - await transcripts_controller.update( - transcript, - { - "events": [], - "topics": [], - "participants": [], - }, - ) - - if not mtg_session_id or not settings.DAILY_API_KEY: - await emit_progress_async( - input.transcript_id, - "get_participants", - "completed", - ctx.workflow_run_id, - ) - return { + transcript = await transcripts_controller.get_by_id(input.transcript_id) + if transcript: + # Reset events/topics/participants + # Note: title NOT cleared - preserves existing titles + await transcripts_controller.update( + transcript, + { + "events": [], + "topics": [], "participants": [], - "num_tracks": len(input.tracks), - "source_language": transcript.source_language - if transcript - else "en", - "target_language": transcript.target_language - if transcript - else "en", - } - - # Fetch participants from Daily API - async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client: - participants = await client.get_meeting_participants(mtg_session_id) - - id_to_name = {} - id_to_user_id = {} - for p in participants.data: - if p.user_name: - id_to_name[p.participant_id] = p.user_name - if p.user_id: - id_to_user_id[p.participant_id] = p.user_id - - # Get track keys and filter for cam-audio tracks - track_keys = [t["s3_key"] for t in input.tracks] - cam_audio_keys = filter_cam_audio_tracks(track_keys) - - # Update participants in database (matches Celery lines 568-590) - participants_list = [] - for idx, key in enumerate(cam_audio_keys): - try: - parsed = parse_daily_recording_filename(key) - participant_id = parsed.participant_id - except ValueError as e: - logger.error( - "Failed to parse Daily recording filename", - error=str(e), - key=key, - ) - continue - - default_name = f"Speaker {idx}" - name = id_to_name.get(participant_id, default_name) - user_id = id_to_user_id.get(participant_id) - - participant = TranscriptParticipant( - 
id=participant_id, speaker=idx, name=name, user_id=user_id - ) - await transcripts_controller.upsert_participant(transcript, participant) - participants_list.append( - { - "participant_id": participant_id, - "user_name": name, - "speaker": idx, - } - ) - - logger.info( - "[Hatchet] get_participants complete", - participant_count=len(participants_list), + }, ) - finally: - await _close_db_connection(db) + if not mtg_session_id or not settings.DAILY_API_KEY: + await emit_progress_async( + input.transcript_id, + "get_participants", + "completed", + ctx.workflow_run_id, + ) + return ParticipantsResult( + participants=[], + num_tracks=len(input.tracks), + source_language=transcript.source_language if transcript else "en", + target_language=transcript.target_language if transcript else "en", + ) - await emit_progress_async( - input.transcript_id, "get_participants", "completed", ctx.workflow_run_id + # Fetch participants from Daily API + async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client: + participants = await client.get_meeting_participants(mtg_session_id) + + id_to_name = {} + id_to_user_id = {} + for p in participants.data: + if p.user_name: + id_to_name[p.participant_id] = p.user_name + if p.user_id: + id_to_user_id[p.participant_id] = p.user_id + + # Get track keys and filter for cam-audio tracks + track_keys = [t["s3_key"] for t in input.tracks] + cam_audio_keys = filter_cam_audio_tracks(track_keys) + + # Update participants in database + participants_list = [] + for idx, key in enumerate(cam_audio_keys): + try: + parsed = parse_daily_recording_filename(key) + participant_id = parsed.participant_id + except ValueError as e: + logger.error( + "Failed to parse Daily recording filename", + error=str(e), + key=key, + ) + continue + + default_name = f"Speaker {idx}" + name = id_to_name.get(participant_id, default_name) + user_id = id_to_user_id.get(participant_id) + + participant = TranscriptParticipant( + id=participant_id, speaker=idx, name=name, user_id=user_id + ) + await transcripts_controller.upsert_participant(transcript, participant) + participants_list.append( + { + "participant_id": participant_id, + "user_name": name, + "speaker": idx, + } + ) + + logger.info( + "[Hatchet] get_participants complete", + participant_count=len(participants_list), ) - return { - "participants": participants_list, - "num_tracks": len(input.tracks), - "source_language": transcript.source_language if transcript else "en", - "target_language": transcript.target_language if transcript else "en", - } + await emit_progress_async( + input.transcript_id, "get_participants", "completed", ctx.workflow_run_id + ) - except Exception as e: - logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True) - await _set_error_status(input.transcript_id) - await emit_progress_async( - input.transcript_id, "get_participants", "failed", ctx.workflow_run_id - ) - raise + return ParticipantsResult( + participants=participants_list, + num_tracks=len(input.tracks), + source_language=transcript.source_language if transcript else "en", + target_language=transcript.target_language if transcript else "en", + ) @diarization_pipeline.task( parents=[get_participants], execution_timeout=timedelta(seconds=600), retries=3 ) -async def process_tracks(input: PipelineInput, ctx: Context) -> dict: - """Spawn child workflows for each track (dynamic fan-out). - - Processes pad_track and transcribe_track for each audio track in parallel. 
- """ +@with_error_handling("process_tracks") +async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult: + """Spawn child workflows for each track (dynamic fan-out).""" logger.info( "[Hatchet] process_tracks", num_tracks=len(input.tracks), transcript_id=input.transcript_id, ) - try: - # Get source_language from get_participants (matches Celery: uses transcript.source_language) - participants_data = ctx.task_output(get_participants) - source_language = participants_data.get("source_language", "en") + participants_data = _to_dict(ctx.task_output(get_participants)) + source_language = participants_data.get("source_language", "en") - # Spawn child workflows for each track with correct language - child_coroutines = [ - track_workflow.aio_run( - TrackInput( - track_index=i, - s3_key=track["s3_key"], - bucket_name=input.bucket_name, - transcript_id=input.transcript_id, - language=source_language, - ) + # Spawn child workflows for each track with correct language + child_coroutines = [ + track_workflow.aio_run( + TrackInput( + track_index=i, + s3_key=track["s3_key"], + bucket_name=input.bucket_name, + transcript_id=input.transcript_id, + language=source_language, ) - for i, track in enumerate(input.tracks) - ] - - # Wait for all child workflows to complete - results = await asyncio.gather(*child_coroutines) - - # Get target_language for later use in detect_topics - target_language = participants_data.get("target_language", "en") - - # Collect all track results - all_words = [] - padded_urls = [] - created_padded_files = set() - - for result in results: - transcribe_result = result.get("transcribe_track", {}) - all_words.extend(transcribe_result.get("words", [])) - - pad_result = result.get("pad_track", {}) - padded_urls.append(pad_result.get("padded_url")) - - # Track padded files for cleanup (matches Celery line 636-637) - track_index = pad_result.get("track_index") - if pad_result.get("size", 0) > 0 and track_index is not None: - # File was created (size > 0 means padding was applied) - storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm" - created_padded_files.add(storage_path) - - # Sort words by start time - all_words.sort(key=lambda w: w.get("start", 0)) - - # NOTE: Cleanup of padded S3 files moved to generate_waveform (after mixdown completes) - # Mixdown needs the padded files, so we can't delete them here - - logger.info( - "[Hatchet] process_tracks complete", - num_tracks=len(input.tracks), - total_words=len(all_words), ) + for i, track in enumerate(input.tracks) + ] - return { - "all_words": all_words, - "padded_urls": padded_urls, - "word_count": len(all_words), - "num_tracks": len(input.tracks), - "target_language": target_language, - "created_padded_files": list( - created_padded_files - ), # For cleanup after mixdown - } + # Wait for all child workflows to complete + results = await asyncio.gather(*child_coroutines) - except Exception as e: - logger.error("[Hatchet] process_tracks failed", error=str(e), exc_info=True) - await _set_error_status(input.transcript_id) - await emit_progress_async( - input.transcript_id, "process_tracks", "failed", ctx.workflow_run_id - ) - raise + # Get target_language for later use in detect_topics + target_language = participants_data.get("target_language", "en") + + # Collect results from each track (don't mutate lists while iterating) + track_words = [] + padded_urls = [] + created_padded_files = set() + + for result in results: + transcribe_result = result.get("transcribe_track", 
{}) + track_words.append(transcribe_result.get("words", [])) + + pad_result = result.get("pad_track", {}) + padded_urls.append(pad_result.get("padded_url")) + + # Track padded files for cleanup + track_index = pad_result.get("track_index") + if pad_result.get("size", 0) > 0 and track_index is not None: + storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm" + created_padded_files.add(storage_path) + + # Merge all words and sort by start time + all_words = [word for words in track_words for word in words] + all_words.sort(key=lambda w: w.get("start", 0)) + + logger.info( + "[Hatchet] process_tracks complete", + num_tracks=len(input.tracks), + total_words=len(all_words), + ) + + return ProcessTracksResult( + all_words=all_words, + padded_urls=padded_urls, + word_count=len(all_words), + num_tracks=len(input.tracks), + target_language=target_language, + created_padded_files=list(created_padded_files), + ) @diarization_pipeline.task( parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3 ) -async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict: +@with_error_handling("mixdown_tracks") +async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: """Mix all padded tracks into single audio file using PyAV (same as Celery).""" logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id) @@ -423,217 +440,204 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict: input.transcript_id, "mixdown_tracks", "in_progress", ctx.workflow_run_id ) + track_data = _to_dict(ctx.task_output(process_tracks)) + padded_urls = track_data.get("padded_urls", []) + + if not padded_urls: + raise ValueError("No padded tracks to mixdown") + + storage = _get_storage() + + # Use PipelineMainMultitrack.mixdown_tracks which uses PyAV filter graph + from fractions import Fraction + + from av.audio.resampler import AudioResampler + + from reflector.processors import AudioFileWriterProcessor + + valid_urls = [url for url in padded_urls if url] + if not valid_urls: + raise ValueError("No valid padded tracks to mixdown") + + # Determine target sample rate from first track + target_sample_rate = None + for url in valid_urls: + try: + container = av.open(url) + for frame in container.decode(audio=0): + target_sample_rate = frame.sample_rate + break + container.close() + if target_sample_rate: + break + except Exception: + continue + + if not target_sample_rate: + raise ValueError("No decodable audio frames in any track") + + # Build PyAV filter graph: N abuffer -> amix -> aformat -> sink + graph = av.filter.Graph() + inputs = [] + + for idx, url in enumerate(valid_urls): + args = ( + f"time_base=1/{target_sample_rate}:" + f"sample_rate={target_sample_rate}:" + f"sample_fmt=s32:" + f"channel_layout=stereo" + ) + in_ctx = graph.add("abuffer", args=args, name=f"in{idx}") + inputs.append(in_ctx) + + mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix") + fmt = graph.add( + "aformat", + args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}", + name="fmt", + ) + sink = graph.add("abuffersink", name="out") + + for idx, in_ctx in enumerate(inputs): + in_ctx.link_to(mixer, 0, idx) + mixer.link_to(fmt) + fmt.link_to(sink) + graph.configure() + + # Create temp output file + output_path = tempfile.mktemp(suffix=".mp3") + containers = [] + try: - track_data = ctx.task_output(process_tracks) - padded_urls = track_data.get("padded_urls", []) - - if not padded_urls: - 
raise ValueError("No padded tracks to mixdown") - - storage = _get_storage() - - # Use PipelineMainMultitrack.mixdown_tracks which uses PyAV filter graph - from fractions import Fraction - - from av.audio.resampler import AudioResampler - - from reflector.processors import AudioFileWriterProcessor - - valid_urls = [url for url in padded_urls if url] - if not valid_urls: - raise ValueError("No valid padded tracks to mixdown") - - # Determine target sample rate from first track - target_sample_rate = None + # Open all containers for url in valid_urls: try: - container = av.open(url) - for frame in container.decode(audio=0): - target_sample_rate = frame.sample_rate - break - container.close() - if target_sample_rate: - break - except Exception: - continue - - if not target_sample_rate: - raise ValueError("No decodable audio frames in any track") - - # Build PyAV filter graph: N abuffer -> amix -> aformat -> sink - graph = av.filter.Graph() - inputs = [] - - for idx, url in enumerate(valid_urls): - args = ( - f"time_base=1/{target_sample_rate}:" - f"sample_rate={target_sample_rate}:" - f"sample_fmt=s32:" - f"channel_layout=stereo" - ) - in_ctx = graph.add("abuffer", args=args, name=f"in{idx}") - inputs.append(in_ctx) - - mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix") - fmt = graph.add( - "aformat", - args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}", - name="fmt", - ) - sink = graph.add("abuffersink", name="out") - - for idx, in_ctx in enumerate(inputs): - in_ctx.link_to(mixer, 0, idx) - mixer.link_to(fmt) - fmt.link_to(sink) - graph.configure() - - # Create temp output file - output_path = tempfile.mktemp(suffix=".mp3") - containers = [] - - try: - # Open all containers - for url in valid_urls: - try: - c = av.open( - url, - options={ - "reconnect": "1", - "reconnect_streamed": "1", - "reconnect_delay_max": "5", - }, - ) - containers.append(c) - except Exception as e: - logger.warning( - "[Hatchet] mixdown: failed to open container", - url=url, - error=str(e), - ) - - if not containers: - raise ValueError("Could not open any track containers") - - # Create AudioFileWriterProcessor for MP3 output with duration capture - duration_ms = [0.0] # Mutable container for callback capture - - async def capture_duration(d): - duration_ms[0] = d - - writer = AudioFileWriterProcessor( - path=output_path, on_duration=capture_duration - ) - - decoders = [c.decode(audio=0) for c in containers] - active = [True] * len(decoders) - resamplers = [ - AudioResampler(format="s32", layout="stereo", rate=target_sample_rate) - for _ in decoders - ] - - while any(active): - for i, (dec, is_active) in enumerate(zip(decoders, active)): - if not is_active: - continue - try: - frame = next(dec) - except StopIteration: - active[i] = False - inputs[i].push(None) - continue - - if frame.sample_rate != target_sample_rate: - continue - out_frames = resamplers[i].resample(frame) or [] - for rf in out_frames: - rf.sample_rate = target_sample_rate - rf.time_base = Fraction(1, target_sample_rate) - inputs[i].push(rf) - - while True: - try: - mixed = sink.pull() - except Exception: - break - mixed.sample_rate = target_sample_rate - mixed.time_base = Fraction(1, target_sample_rate) - await writer.push(mixed) - - # Flush remaining frames - while True: - try: - mixed = sink.pull() - except Exception: - break - mixed.sample_rate = target_sample_rate - mixed.time_base = Fraction(1, target_sample_rate) - await writer.push(mixed) - - await writer.flush() - - # Duration is 
captured via callback in milliseconds (from AudioFileWriterProcessor) - - finally: - for c in containers: - try: - c.close() - except Exception: - pass - - # Upload mixed file to correct path (matches Celery: {transcript.id}/audio.mp3) - file_size = Path(output_path).stat().st_size - storage_path = f"{input.transcript_id}/audio.mp3" - - with open(output_path, "rb") as mixed_file: - await storage.put_file(storage_path, mixed_file) - - Path(output_path).unlink(missing_ok=True) - - # Update transcript with audio_location (matches Celery line 661) - db = await _get_fresh_db_connection() - try: - from reflector.db.transcripts import transcripts_controller - - transcript = await transcripts_controller.get_by_id(input.transcript_id) - if transcript: - await transcripts_controller.update( - transcript, {"audio_location": "storage"} + c = av.open( + url, + options={ + "reconnect": "1", + "reconnect_streamed": "1", + "reconnect_delay_max": "5", + }, + ) + containers.append(c) + except Exception as e: + logger.warning( + "[Hatchet] mixdown: failed to open container", + url=url, + error=str(e), ) - finally: - await _close_db_connection(db) - logger.info( - "[Hatchet] mixdown_tracks uploaded", - key=storage_path, - size=file_size, + if not containers: + raise ValueError("Could not open any track containers") + + # Create AudioFileWriterProcessor for MP3 output with duration capture + duration_ms = [0.0] # Mutable container for callback capture + + async def capture_duration(d): + duration_ms[0] = d + + writer = AudioFileWriterProcessor( + path=output_path, on_duration=capture_duration ) - await emit_progress_async( - input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id - ) + decoders = [c.decode(audio=0) for c in containers] + active = [True] * len(decoders) + resamplers = [ + AudioResampler(format="s32", layout="stereo", rate=target_sample_rate) + for _ in decoders + ] - return { - "audio_key": storage_path, - "duration": duration_ms[ - 0 - ], # Duration in milliseconds from AudioFileWriterProcessor - "tracks_mixed": len(valid_urls), - } + while any(active): + for i, (dec, is_active) in enumerate(zip(decoders, active)): + if not is_active: + continue + try: + frame = next(dec) + except StopIteration: + active[i] = False + inputs[i].push(None) + continue - except Exception as e: - logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True) - await _set_error_status(input.transcript_id) - await emit_progress_async( - input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id - ) - raise + if frame.sample_rate != target_sample_rate: + continue + out_frames = resamplers[i].resample(frame) or [] + for rf in out_frames: + rf.sample_rate = target_sample_rate + rf.time_base = Fraction(1, target_sample_rate) + inputs[i].push(rf) + + while True: + try: + mixed = sink.pull() + except Exception: + break + mixed.sample_rate = target_sample_rate + mixed.time_base = Fraction(1, target_sample_rate) + await writer.push(mixed) + + # Flush remaining frames + while True: + try: + mixed = sink.pull() + except Exception: + break + mixed.sample_rate = target_sample_rate + mixed.time_base = Fraction(1, target_sample_rate) + await writer.push(mixed) + + await writer.flush() + + # Duration is captured via callback in milliseconds (from AudioFileWriterProcessor) + + finally: + for c in containers: + try: + c.close() + except Exception: + pass + + # Upload mixed file to storage + file_size = Path(output_path).stat().st_size + storage_path = f"{input.transcript_id}/audio.mp3" + + with 
open(output_path, "rb") as mixed_file: + await storage.put_file(storage_path, mixed_file) + + Path(output_path).unlink(missing_ok=True) + + # Update transcript with audio_location + async with fresh_db_connection(): + from reflector.db.transcripts import transcripts_controller + + transcript = await transcripts_controller.get_by_id(input.transcript_id) + if transcript: + await transcripts_controller.update( + transcript, {"audio_location": "storage"} + ) + + logger.info( + "[Hatchet] mixdown_tracks uploaded", + key=storage_path, + size=file_size, + ) + + await emit_progress_async( + input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id + ) + + return MixdownResult( + audio_key=storage_path, + duration=duration_ms[0], + tracks_mixed=len(valid_urls), + ) @diarization_pipeline.task( parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3 ) -async def generate_waveform(input: PipelineInput, ctx: Context) -> dict: +@with_error_handling("generate_waveform") +async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult: """Generate audio waveform visualization using AudioWaveformProcessor (matches Celery).""" logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id) @@ -641,96 +645,84 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict: input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id ) - try: - import httpx + import httpx - from reflector.db.transcripts import TranscriptWaveform, transcripts_controller - from reflector.utils.audio_waveform import get_audio_waveform - - # Cleanup temporary padded S3 files (matches Celery lines 710-725) - # Moved here from process_tracks because mixdown_tracks needs the padded files - track_data = ctx.task_output(process_tracks) - created_padded_files = track_data.get("created_padded_files", []) - if created_padded_files: - logger.info( - f"[Hatchet] Cleaning up {len(created_padded_files)} temporary S3 files" - ) - storage = _get_storage() - cleanup_tasks = [] - for storage_path in created_padded_files: - cleanup_tasks.append(storage.delete_file(storage_path)) - - cleanup_results = await asyncio.gather( - *cleanup_tasks, return_exceptions=True - ) - for storage_path, result in zip(created_padded_files, cleanup_results): - if isinstance(result, Exception): - logger.warning( - "[Hatchet] Failed to cleanup temporary padded track", - storage_path=storage_path, - error=str(result), - ) - - mixdown_data = ctx.task_output(mixdown_tracks) - audio_key = mixdown_data.get("audio_key") + from reflector.db.transcripts import TranscriptWaveform, transcripts_controller + from reflector.utils.audio_waveform import get_audio_waveform + # Cleanup temporary padded S3 files (deferred until after mixdown) + track_data = _to_dict(ctx.task_output(process_tracks)) + created_padded_files = track_data.get("created_padded_files", []) + if created_padded_files: + logger.info( + f"[Hatchet] Cleaning up {len(created_padded_files)} temporary S3 files" + ) storage = _get_storage() - audio_url = await storage.get_file_url( - audio_key, - operation="get_object", - expires_in=PRESIGNED_URL_EXPIRATION_SECONDS, + cleanup_tasks = [] + for storage_path in created_padded_files: + cleanup_tasks.append(storage.delete_file(storage_path)) + + cleanup_results = await asyncio.gather(*cleanup_tasks, return_exceptions=True) + for storage_path, result in zip(created_padded_files, cleanup_results): + if isinstance(result, Exception): + logger.warning( + "[Hatchet] Failed to cleanup 
temporary padded track", + storage_path=storage_path, + error=str(result), + ) + + mixdown_data = _to_dict(ctx.task_output(mixdown_tracks)) + audio_key = mixdown_data.get("audio_key") + + storage = _get_storage() + audio_url = await storage.get_file_url( + audio_key, + operation="get_object", + expires_in=PRESIGNED_URL_EXPIRATION_SECONDS, + ) + + # Download MP3 to temp file (AudioWaveformProcessor needs local file) + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file: + temp_path = temp_file.name + + try: + async with httpx.AsyncClient() as client: + response = await client.get(audio_url, timeout=120) + response.raise_for_status() + with open(temp_path, "wb") as f: + f.write(response.content) + + # Generate waveform + waveform = get_audio_waveform( + path=Path(temp_path), segments_count=WAVEFORM_SEGMENTS ) - # Download MP3 to temp file (AudioWaveformProcessor needs local file) - with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file: - temp_path = temp_file.name + # Save waveform to database via event + async with fresh_db_connection(): + transcript = await transcripts_controller.get_by_id(input.transcript_id) + if transcript: + waveform_data = TranscriptWaveform(waveform=waveform) + await transcripts_controller.append_event( + transcript=transcript, event="WAVEFORM", data=waveform_data + ) - try: - async with httpx.AsyncClient() as client: - response = await client.get(audio_url, timeout=120) - response.raise_for_status() - with open(temp_path, "wb") as f: - f.write(response.content) + finally: + Path(temp_path).unlink(missing_ok=True) - # Generate waveform (matches Celery: get_audio_waveform with 255 segments) - waveform = get_audio_waveform(path=Path(temp_path), segments_count=255) + logger.info("[Hatchet] generate_waveform complete") - # Save waveform to database via event (matches Celery on_waveform callback) - db = await _get_fresh_db_connection() - try: - transcript = await transcripts_controller.get_by_id(input.transcript_id) - if transcript: - waveform_data = TranscriptWaveform(waveform=waveform) - await transcripts_controller.append_event( - transcript=transcript, event="WAVEFORM", data=waveform_data - ) - finally: - await _close_db_connection(db) + await emit_progress_async( + input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id + ) - finally: - Path(temp_path).unlink(missing_ok=True) - - logger.info("[Hatchet] generate_waveform complete") - - await emit_progress_async( - input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id - ) - - return {"waveform_generated": True} - - except Exception as e: - logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True) - await _set_error_status(input.transcript_id) - await emit_progress_async( - input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id - ) - raise + return WaveformResult(waveform_generated=True) @diarization_pipeline.task( parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3 ) -async def detect_topics(input: PipelineInput, ctx: Context) -> dict: +@with_error_handling("detect_topics") +async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: """Detect topics using LLM and save to database (matches Celery on_topic callback).""" logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id) @@ -738,79 +730,66 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict: input.transcript_id, "detect_topics", "in_progress", ctx.workflow_run_id ) - 
try: - track_data = ctx.task_output(process_tracks) - words = track_data.get("all_words", []) - target_language = track_data.get("target_language", "en") + track_data = _to_dict(ctx.task_output(process_tracks)) + words = track_data.get("all_words", []) + target_language = track_data.get("target_language", "en") - from reflector.db.transcripts import TranscriptTopic, transcripts_controller - from reflector.pipelines import topic_processing - from reflector.processors.types import ( - TitleSummaryWithId as TitleSummaryWithIdProcessorType, - ) - from reflector.processors.types import Transcript as TranscriptType - from reflector.processors.types import Word + from reflector.db.transcripts import TranscriptTopic, transcripts_controller + from reflector.pipelines import topic_processing + from reflector.processors.types import ( + TitleSummaryWithId as TitleSummaryWithIdProcessorType, + ) + from reflector.processors.types import Transcript as TranscriptType + from reflector.processors.types import Word - # Convert word dicts to Word objects - word_objects = [Word(**w) for w in words] - transcript_type = TranscriptType(words=word_objects) + # Convert word dicts to Word objects + word_objects = [Word(**w) for w in words] + transcript_type = TranscriptType(words=word_objects) - empty_pipeline = topic_processing.EmptyPipeline(logger=logger) + empty_pipeline = topic_processing.EmptyPipeline(logger=logger) - # Get DB connection for callbacks - db = await _get_fresh_db_connection() + async with fresh_db_connection(): + transcript = await transcripts_controller.get_by_id(input.transcript_id) - try: - transcript = await transcripts_controller.get_by_id(input.transcript_id) - - # Callback that upserts topics to DB (matches Celery on_topic) - async def on_topic_callback(data): - topic = TranscriptTopic( - title=data.title, - summary=data.summary, - timestamp=data.timestamp, - transcript=data.transcript.text, - words=data.transcript.words, - ) - if isinstance(data, TitleSummaryWithIdProcessorType): - topic.id = data.id - await transcripts_controller.upsert_topic(transcript, topic) - await transcripts_controller.append_event( - transcript=transcript, event="TOPIC", data=topic - ) - - topics = await topic_processing.detect_topics( - transcript_type, - target_language, - on_topic_callback=on_topic_callback, - empty_pipeline=empty_pipeline, + # Callback that upserts topics to DB + async def on_topic_callback(data): + topic = TranscriptTopic( + title=data.title, + summary=data.summary, + timestamp=data.timestamp, + transcript=data.transcript.text, + words=data.transcript.words, + ) + if isinstance(data, TitleSummaryWithIdProcessorType): + topic.id = data.id + await transcripts_controller.upsert_topic(transcript, topic) + await transcripts_controller.append_event( + transcript=transcript, event="TOPIC", data=topic ) - finally: - await _close_db_connection(db) - topics_list = [t.model_dump() for t in topics] - - logger.info("[Hatchet] detect_topics complete", topic_count=len(topics_list)) - - await emit_progress_async( - input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id + topics = await topic_processing.detect_topics( + transcript_type, + target_language, + on_topic_callback=on_topic_callback, + empty_pipeline=empty_pipeline, ) - return {"topics": topics_list} + topics_list = [t.model_dump() for t in topics] - except Exception as e: - logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True) - await _set_error_status(input.transcript_id) - await emit_progress_async( - 
input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id - ) - raise + logger.info("[Hatchet] detect_topics complete", topic_count=len(topics_list)) + + await emit_progress_async( + input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id + ) + + return TopicsResult(topics=topics_list) @diarization_pipeline.task( parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3 ) -async def generate_title(input: PipelineInput, ctx: Context) -> dict: +@with_error_handling("generate_title") +async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult: """Generate meeting title using LLM and save to database (matches Celery on_title callback).""" logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id) @@ -818,70 +797,59 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict: input.transcript_id, "generate_title", "in_progress", ctx.workflow_run_id ) - try: - topics_data = ctx.task_output(detect_topics) - topics = topics_data.get("topics", []) + topics_data = _to_dict(ctx.task_output(detect_topics)) + topics = topics_data.get("topics", []) - from reflector.db.transcripts import ( - TranscriptFinalTitle, - transcripts_controller, - ) - from reflector.pipelines import topic_processing - from reflector.processors.types import TitleSummary + from reflector.db.transcripts import ( + TranscriptFinalTitle, + transcripts_controller, + ) + from reflector.pipelines import topic_processing + from reflector.processors.types import TitleSummary - topic_objects = [TitleSummary(**t) for t in topics] + topic_objects = [TitleSummary(**t) for t in topics] - empty_pipeline = topic_processing.EmptyPipeline(logger=logger) - title_result = None + empty_pipeline = topic_processing.EmptyPipeline(logger=logger) + title_result = None - db = await _get_fresh_db_connection() - try: - transcript = await transcripts_controller.get_by_id(input.transcript_id) + async with fresh_db_connection(): + transcript = await transcripts_controller.get_by_id(input.transcript_id) - # Callback that updates title in DB (matches Celery on_title) - async def on_title_callback(data): - nonlocal title_result - title_result = data.title - final_title = TranscriptFinalTitle(title=data.title) - if not transcript.title: - await transcripts_controller.update( - transcript, - {"title": final_title.title}, - ) - await transcripts_controller.append_event( - transcript=transcript, event="FINAL_TITLE", data=final_title + # Callback that updates title in DB + async def on_title_callback(data): + nonlocal title_result + title_result = data.title + final_title = TranscriptFinalTitle(title=data.title) + if not transcript.title: + await transcripts_controller.update( + transcript, + {"title": final_title.title}, ) - - await topic_processing.generate_title( - topic_objects, - on_title_callback=on_title_callback, - empty_pipeline=empty_pipeline, - logger=logger, + await transcripts_controller.append_event( + transcript=transcript, event="FINAL_TITLE", data=final_title ) - finally: - await _close_db_connection(db) - logger.info("[Hatchet] generate_title complete", title=title_result) - - await emit_progress_async( - input.transcript_id, "generate_title", "completed", ctx.workflow_run_id + await topic_processing.generate_title( + topic_objects, + on_title_callback=on_title_callback, + empty_pipeline=empty_pipeline, + logger=logger, ) - return {"title": title_result} + logger.info("[Hatchet] generate_title complete", title=title_result) - except Exception as e: - 
logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True) - await _set_error_status(input.transcript_id) - await emit_progress_async( - input.transcript_id, "generate_title", "failed", ctx.workflow_run_id - ) - raise + await emit_progress_async( + input.transcript_id, "generate_title", "completed", ctx.workflow_run_id + ) + + return TitleResult(title=title_result) @diarization_pipeline.task( parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3 ) -async def generate_summary(input: PipelineInput, ctx: Context) -> dict: +@with_error_handling("generate_summary") +async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult: """Generate meeting summary using LLM and save to database (matches Celery callbacks).""" logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id) @@ -889,88 +857,76 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict: input.transcript_id, "generate_summary", "in_progress", ctx.workflow_run_id ) - try: - topics_data = ctx.task_output(detect_topics) - topics = topics_data.get("topics", []) + topics_data = _to_dict(ctx.task_output(detect_topics)) + topics = topics_data.get("topics", []) - from reflector.db.transcripts import ( - TranscriptFinalLongSummary, - TranscriptFinalShortSummary, - transcripts_controller, - ) - from reflector.pipelines import topic_processing - from reflector.processors.types import TitleSummary + from reflector.db.transcripts import ( + TranscriptFinalLongSummary, + TranscriptFinalShortSummary, + transcripts_controller, + ) + from reflector.pipelines import topic_processing + from reflector.processors.types import TitleSummary - topic_objects = [TitleSummary(**t) for t in topics] + topic_objects = [TitleSummary(**t) for t in topics] - empty_pipeline = topic_processing.EmptyPipeline(logger=logger) - summary_result = None - short_summary_result = None + empty_pipeline = topic_processing.EmptyPipeline(logger=logger) + summary_result = None + short_summary_result = None - db = await _get_fresh_db_connection() - try: - transcript = await transcripts_controller.get_by_id(input.transcript_id) + async with fresh_db_connection(): + transcript = await transcripts_controller.get_by_id(input.transcript_id) - # Callback that updates long_summary in DB (matches Celery on_long_summary) - async def on_long_summary_callback(data): - nonlocal summary_result - summary_result = data.long_summary - final_long_summary = TranscriptFinalLongSummary( - long_summary=data.long_summary - ) - await transcripts_controller.update( - transcript, - {"long_summary": final_long_summary.long_summary}, - ) - await transcripts_controller.append_event( - transcript=transcript, - event="FINAL_LONG_SUMMARY", - data=final_long_summary, - ) - - # Callback that updates short_summary in DB (matches Celery on_short_summary) - async def on_short_summary_callback(data): - nonlocal short_summary_result - short_summary_result = data.short_summary - final_short_summary = TranscriptFinalShortSummary( - short_summary=data.short_summary - ) - await transcripts_controller.update( - transcript, - {"short_summary": final_short_summary.short_summary}, - ) - await transcripts_controller.append_event( - transcript=transcript, - event="FINAL_SHORT_SUMMARY", - data=final_short_summary, - ) - - await topic_processing.generate_summaries( - topic_objects, - transcript, # DB transcript for context - on_long_summary_callback=on_long_summary_callback, - on_short_summary_callback=on_short_summary_callback, - 
empty_pipeline=empty_pipeline, - logger=logger, + # Callback that updates long_summary in DB + async def on_long_summary_callback(data): + nonlocal summary_result + summary_result = data.long_summary + final_long_summary = TranscriptFinalLongSummary( + long_summary=data.long_summary + ) + await transcripts_controller.update( + transcript, + {"long_summary": final_long_summary.long_summary}, + ) + await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_LONG_SUMMARY", + data=final_long_summary, ) - finally: - await _close_db_connection(db) - logger.info("[Hatchet] generate_summary complete") + # Callback that updates short_summary in DB + async def on_short_summary_callback(data): + nonlocal short_summary_result + short_summary_result = data.short_summary + final_short_summary = TranscriptFinalShortSummary( + short_summary=data.short_summary + ) + await transcripts_controller.update( + transcript, + {"short_summary": final_short_summary.short_summary}, + ) + await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_SHORT_SUMMARY", + data=final_short_summary, + ) - await emit_progress_async( - input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id + await topic_processing.generate_summaries( + topic_objects, + transcript, # DB transcript for context + on_long_summary_callback=on_long_summary_callback, + on_short_summary_callback=on_short_summary_callback, + empty_pipeline=empty_pipeline, + logger=logger, ) - return {"summary": summary_result, "short_summary": short_summary_result} + logger.info("[Hatchet] generate_summary complete") - except Exception as e: - logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True) - await _set_error_status(input.transcript_id) - await emit_progress_async( - input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id - ) - raise + await emit_progress_async( + input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id + ) + + return SummaryResult(summary=summary_result, short_summary=short_summary_result) @diarization_pipeline.task( @@ -978,7 +934,8 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict: execution_timeout=timedelta(seconds=60), retries=3, ) -async def finalize(input: PipelineInput, ctx: Context) -> dict: +@with_error_handling("finalize") +async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: """Finalize transcript: save words, emit TRANSCRIPT event, set status to 'ended'. Matches Celery's on_transcript + set_status behavior. 
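+    On success, ``duration`` is saved and ``workflow_run_id`` is cleared so
+    the completed run is not resumed.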
@@ -990,81 +947,64 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict: input.transcript_id, "finalize", "in_progress", ctx.workflow_run_id ) - try: - mixdown_data = ctx.task_output(mixdown_tracks) - track_data = ctx.task_output(process_tracks) + mixdown_data = _to_dict(ctx.task_output(mixdown_tracks)) + track_data = _to_dict(ctx.task_output(process_tracks)) - duration = mixdown_data.get("duration", 0) - all_words = track_data.get("all_words", []) + duration = mixdown_data.get("duration", 0) + all_words = track_data.get("all_words", []) - db = await _get_fresh_db_connection() + async with fresh_db_connection(): + from reflector.db.transcripts import TranscriptText, transcripts_controller + from reflector.processors.types import Transcript as TranscriptType + from reflector.processors.types import Word - try: - from reflector.db.transcripts import TranscriptText, transcripts_controller - from reflector.processors.types import Transcript as TranscriptType - from reflector.processors.types import Word + transcript = await transcripts_controller.get_by_id(input.transcript_id) + if transcript is None: + raise ValueError(f"Transcript {input.transcript_id} not found in database") - transcript = await transcripts_controller.get_by_id(input.transcript_id) - if transcript is None: - raise ValueError( - f"Transcript {input.transcript_id} not found in database" - ) + # Convert words back to Word objects for storage + word_objects = [Word(**w) for w in all_words] - # Convert words back to Word objects for storage - word_objects = [Word(**w) for w in all_words] + # Create merged transcript for TRANSCRIPT event + merged_transcript = TranscriptType(words=word_objects, translation=None) - # Create merged transcript for TRANSCRIPT event (matches Celery line 734-736) - merged_transcript = TranscriptType(words=word_objects, translation=None) - - # Emit TRANSCRIPT event (matches Celery on_transcript callback) - await transcripts_controller.append_event( - transcript=transcript, - event="TRANSCRIPT", - data=TranscriptText( - text=merged_transcript.text, - translation=merged_transcript.translation, - ), - ) - - # Save duration and clear workflow_run_id (workflow completed successfully) - # Note: title/long_summary/short_summary already saved by their callbacks - await transcripts_controller.update( - transcript, - { - "duration": duration, - "workflow_run_id": None, # Clear on success - no need to resume - }, - ) - - # Set status to "ended" (matches Celery line 745) - await transcripts_controller.set_status(input.transcript_id, "ended") - - logger.info( - "[Hatchet] finalize complete", transcript_id=input.transcript_id - ) - - finally: - await _close_db_connection(db) - - await emit_progress_async( - input.transcript_id, "finalize", "completed", ctx.workflow_run_id + # Emit TRANSCRIPT event + await transcripts_controller.append_event( + transcript=transcript, + event="TRANSCRIPT", + data=TranscriptText( + text=merged_transcript.text, + translation=merged_transcript.translation, + ), ) - return {"status": "COMPLETED"} - - except Exception as e: - logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True) - await _set_error_status(input.transcript_id) - await emit_progress_async( - input.transcript_id, "finalize", "failed", ctx.workflow_run_id + # Save duration and clear workflow_run_id (workflow completed successfully) + # Note: title/long_summary/short_summary already saved by their callbacks + await transcripts_controller.update( + transcript, + { + "duration": duration, + 
"workflow_run_id": None, # Clear on success - no need to resume + }, ) - raise + + # Set status to "ended" + await transcripts_controller.set_status(input.transcript_id, "ended") + + logger.info("[Hatchet] finalize complete", transcript_id=input.transcript_id) + + await emit_progress_async( + input.transcript_id, "finalize", "completed", ctx.workflow_run_id + ) + + return FinalizeResult(status="COMPLETED") @diarization_pipeline.task( parents=[finalize], execution_timeout=timedelta(seconds=60), retries=3 ) -async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict: +@with_error_handling("cleanup_consent", set_error_status=False) +async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult: """Check and handle consent requirements.""" logger.info("[Hatchet] cleanup_consent", transcript_id=input.transcript_id) @@ -1072,46 +1012,34 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict: input.transcript_id, "cleanup_consent", "in_progress", ctx.workflow_run_id ) - try: - db = await _get_fresh_db_connection() + async with fresh_db_connection(): + from reflector.db.meetings import meetings_controller + from reflector.db.transcripts import transcripts_controller - try: - from reflector.db.meetings import meetings_controller - from reflector.db.transcripts import transcripts_controller + transcript = await transcripts_controller.get_by_id(input.transcript_id) + if transcript and transcript.meeting_id: + meeting = await meetings_controller.get_by_id(transcript.meeting_id) + if meeting: + # Check consent logic here + # For now just mark as checked + pass - transcript = await transcripts_controller.get_by_id(input.transcript_id) - if transcript and transcript.meeting_id: - meeting = await meetings_controller.get_by_id(transcript.meeting_id) - if meeting: - # Check consent logic here - # For now just mark as checked - pass - - logger.info( - "[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id - ) - - finally: - await _close_db_connection(db) - - await emit_progress_async( - input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id + logger.info( + "[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id ) - return {"consent_checked": True} + await emit_progress_async( + input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id + ) - except Exception as e: - logger.error("[Hatchet] cleanup_consent failed", error=str(e), exc_info=True) - await emit_progress_async( - input.transcript_id, "cleanup_consent", "failed", ctx.workflow_run_id - ) - raise + return ConsentResult(consent_checked=True) @diarization_pipeline.task( parents=[cleanup_consent], execution_timeout=timedelta(seconds=60), retries=5 ) -async def post_zulip(input: PipelineInput, ctx: Context) -> dict: +@with_error_handling("post_zulip", set_error_status=False) +async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult: """Post notification to Zulip.""" logger.info("[Hatchet] post_zulip", transcript_id=input.transcript_id) @@ -1119,53 +1047,39 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> dict: input.transcript_id, "post_zulip", "in_progress", ctx.workflow_run_id ) - try: - from reflector.settings import settings - - if not settings.ZULIP_REALM: - logger.info("[Hatchet] post_zulip skipped (Zulip not configured)") - await emit_progress_async( - input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id - ) - return {"zulip_message_id": None, "skipped": True} - - from reflector.zulip import 
post_transcript_notification - - db = await _get_fresh_db_connection() - - try: - from reflector.db.transcripts import transcripts_controller - - transcript = await transcripts_controller.get_by_id(input.transcript_id) - if transcript: - message_id = await post_transcript_notification(transcript) - logger.info( - "[Hatchet] post_zulip complete", zulip_message_id=message_id - ) - else: - message_id = None - - finally: - await _close_db_connection(db) + from reflector.settings import settings + if not settings.ZULIP_REALM: + logger.info("[Hatchet] post_zulip skipped (Zulip not configured)") await emit_progress_async( input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id ) + return ZulipResult(zulip_message_id=None, skipped=True) - return {"zulip_message_id": message_id} + from reflector.zulip import post_transcript_notification - except Exception as e: - logger.error("[Hatchet] post_zulip failed", error=str(e), exc_info=True) - await emit_progress_async( - input.transcript_id, "post_zulip", "failed", ctx.workflow_run_id - ) - raise + async with fresh_db_connection(): + from reflector.db.transcripts import transcripts_controller + + transcript = await transcripts_controller.get_by_id(input.transcript_id) + if transcript: + message_id = await post_transcript_notification(transcript) + logger.info("[Hatchet] post_zulip complete", zulip_message_id=message_id) + else: + message_id = None + + await emit_progress_async( + input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id + ) + + return ZulipResult(zulip_message_id=message_id) @diarization_pipeline.task( parents=[post_zulip], execution_timeout=timedelta(seconds=120), retries=30 ) -async def send_webhook(input: PipelineInput, ctx: Context) -> dict: +@with_error_handling("send_webhook", set_error_status=False) +async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult: """Send completion webhook to external service.""" logger.info("[Hatchet] send_webhook", transcript_id=input.transcript_id) @@ -1173,64 +1087,51 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> dict: input.transcript_id, "send_webhook", "in_progress", ctx.workflow_run_id ) - try: - if not input.room_id: - logger.info("[Hatchet] send_webhook skipped (no room_id)") - await emit_progress_async( - input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id - ) - return {"webhook_sent": False, "skipped": True} - - db = await _get_fresh_db_connection() - - try: - from reflector.db.rooms import rooms_controller - from reflector.db.transcripts import transcripts_controller - - room = await rooms_controller.get_by_id(input.room_id) - transcript = await transcripts_controller.get_by_id(input.transcript_id) - - if room and room.webhook_url and transcript: - import httpx - - webhook_payload = { - "event": "transcript.completed", - "transcript_id": input.transcript_id, - "title": transcript.title, - "duration": transcript.duration, - } - - async with httpx.AsyncClient() as client: - response = await client.post( - room.webhook_url, json=webhook_payload, timeout=30 - ) - response.raise_for_status() - - logger.info( - "[Hatchet] send_webhook complete", status_code=response.status_code - ) - - await emit_progress_async( - input.transcript_id, - "send_webhook", - "completed", - ctx.workflow_run_id, - ) - - return {"webhook_sent": True, "response_code": response.status_code} - - finally: - await _close_db_connection(db) - + if not input.room_id: + logger.info("[Hatchet] send_webhook skipped (no room_id)") await 
emit_progress_async( input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id ) + return WebhookResult(webhook_sent=False, skipped=True) - return {"webhook_sent": False, "skipped": True} + async with fresh_db_connection(): + from reflector.db.rooms import rooms_controller + from reflector.db.transcripts import transcripts_controller - except Exception as e: - logger.error("[Hatchet] send_webhook failed", error=str(e), exc_info=True) - await emit_progress_async( - input.transcript_id, "send_webhook", "failed", ctx.workflow_run_id - ) - raise + room = await rooms_controller.get_by_id(input.room_id) + transcript = await transcripts_controller.get_by_id(input.transcript_id) + + if room and room.webhook_url and transcript: + import httpx + + webhook_payload = { + "event": "transcript.completed", + "transcript_id": input.transcript_id, + "title": transcript.title, + "duration": transcript.duration, + } + + async with httpx.AsyncClient() as client: + response = await client.post( + room.webhook_url, json=webhook_payload, timeout=30 + ) + response.raise_for_status() + + logger.info( + "[Hatchet] send_webhook complete", status_code=response.status_code + ) + + await emit_progress_async( + input.transcript_id, + "send_webhook", + "completed", + ctx.workflow_run_id, + ) + + return WebhookResult(webhook_sent=True, response_code=response.status_code) + + await emit_progress_async( + input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id + ) + + return WebhookResult(webhook_sent=False, skipped=True) diff --git a/server/reflector/hatchet/workflows/models.py b/server/reflector/hatchet/workflows/models.py new file mode 100644 index 00000000..9011bc86 --- /dev/null +++ b/server/reflector/hatchet/workflows/models.py @@ -0,0 +1,123 @@ +""" +Pydantic models for Hatchet workflow task return types. + +Provides static typing for all task outputs, enabling type checking +and better IDE support. 
+""" + +from typing import Any + +from pydantic import BaseModel + +# ============================================================================ +# Track Processing Results (track_processing.py) +# ============================================================================ + + +class PadTrackResult(BaseModel): + """Result from pad_track task.""" + + padded_url: str + size: int + track_index: int + + +class TranscribeTrackResult(BaseModel): + """Result from transcribe_track task.""" + + words: list[dict[str, Any]] + track_index: int + + +# ============================================================================ +# Diarization Pipeline Results (diarization_pipeline.py) +# ============================================================================ + + +class RecordingResult(BaseModel): + """Result from get_recording task.""" + + id: str | None + mtg_session_id: str | None + room_name: str | None + duration: float + + +class ParticipantsResult(BaseModel): + """Result from get_participants task.""" + + participants: list[dict[str, Any]] + num_tracks: int + source_language: str + target_language: str + + +class ProcessTracksResult(BaseModel): + """Result from process_tracks task.""" + + all_words: list[dict[str, Any]] + padded_urls: list[str | None] + word_count: int + num_tracks: int + target_language: str + created_padded_files: list[str] + + +class MixdownResult(BaseModel): + """Result from mixdown_tracks task.""" + + audio_key: str + duration: float + tracks_mixed: int + + +class WaveformResult(BaseModel): + """Result from generate_waveform task.""" + + waveform_generated: bool + + +class TopicsResult(BaseModel): + """Result from detect_topics task.""" + + topics: list[dict[str, Any]] + + +class TitleResult(BaseModel): + """Result from generate_title task.""" + + title: str | None + + +class SummaryResult(BaseModel): + """Result from generate_summary task.""" + + summary: str | None + short_summary: str | None + + +class FinalizeResult(BaseModel): + """Result from finalize task.""" + + status: str + + +class ConsentResult(BaseModel): + """Result from cleanup_consent task.""" + + consent_checked: bool + + +class ZulipResult(BaseModel): + """Result from post_zulip task.""" + + zulip_message_id: int | None = None + skipped: bool = False + + +class WebhookResult(BaseModel): + """Result from send_webhook task.""" + + webhook_sent: bool + skipped: bool = False + response_code: int | None = None diff --git a/server/reflector/hatchet/workflows/track_processing.py b/server/reflector/hatchet/workflows/track_processing.py index 304d6b37..c5f5ac4f 100644 --- a/server/reflector/hatchet/workflows/track_processing.py +++ b/server/reflector/hatchet/workflows/track_processing.py @@ -18,8 +18,17 @@ from pydantic import BaseModel from reflector.hatchet.client import HatchetClientManager from reflector.hatchet.progress import emit_progress_async +from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult from reflector.logger import logger + +def _to_dict(output) -> dict: + """Convert task output to dict, handling both dict and Pydantic model returns.""" + if isinstance(output, dict): + return output + return output.model_dump() + + # Audio constants matching existing pipeline OPUS_STANDARD_SAMPLE_RATE = 48000 OPUS_DEFAULT_BIT_RATE = 64000 @@ -161,7 +170,7 @@ def _apply_audio_padding_to_file( @track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3) -async def pad_track(input: TrackInput, ctx: Context) -> dict: +async def pad_track(input: TrackInput, ctx: 
Context) -> PadTrackResult: """Pad single audio track with silence for alignment. Extracts stream.start_time from WebM container metadata and applies @@ -213,11 +222,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict: await emit_progress_async( input.transcript_id, "pad_track", "completed", ctx.workflow_run_id ) - return { - "padded_url": source_url, - "size": 0, - "track_index": input.track_index, - } + return PadTrackResult( + padded_url=source_url, + size=0, + track_index=input.track_index, + ) # Create temp file for padded output with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file: @@ -265,11 +274,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict: input.transcript_id, "pad_track", "completed", ctx.workflow_run_id ) - return { - "padded_url": padded_url, - "size": file_size, - "track_index": input.track_index, - } + return PadTrackResult( + padded_url=padded_url, + size=file_size, + track_index=input.track_index, + ) except Exception as e: logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True) @@ -282,7 +291,7 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict: @track_workflow.task( parents=[pad_track], execution_timeout=timedelta(seconds=600), retries=3 ) -async def transcribe_track(input: TrackInput, ctx: Context) -> dict: +async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult: """Transcribe audio track using GPU (Modal.com) or local Whisper.""" logger.info( "[Hatchet] transcribe_track", @@ -295,7 +304,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict: ) try: - pad_result = ctx.task_output(pad_track) + pad_result = _to_dict(ctx.task_output(pad_track)) audio_url = pad_result.get("padded_url") if not audio_url: @@ -324,10 +333,10 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict: input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id ) - return { - "words": words, - "track_index": input.track_index, - } + return TranscribeTrackResult( + words=words, + track_index=input.track_index, + ) except Exception as e: logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True) diff --git a/server/reflector/services/transcript_process.py b/server/reflector/services/transcript_process.py index 06f2e6d6..62d0d30f 100644 --- a/server/reflector/services/transcript_process.py +++ b/server/reflector/services/transcript_process.py @@ -224,6 +224,26 @@ def dispatch_transcript_processing( transcript, {"workflow_run_id": None} ) + # Re-fetch transcript to check for concurrent dispatch (TOCTOU protection) + transcript = await transcripts_controller.get_by_id( + config.transcript_id + ) + if transcript and transcript.workflow_run_id: + # Another process started a workflow between validation and now + try: + status = await HatchetClientManager.get_workflow_run_status( + transcript.workflow_run_id + ) + if "RUNNING" in status or "QUEUED" in status: + logger.info( + "Concurrent workflow detected, skipping dispatch", + workflow_id=transcript.workflow_run_id, + ) + return transcript.workflow_run_id + except Exception: + # If we can't get status, proceed with new workflow + pass + workflow_id = await HatchetClientManager.start_workflow( workflow_name="DiarizationPipeline", input_data={ @@ -234,6 +254,11 @@ def dispatch_transcript_processing( "transcript_id": config.transcript_id, "room_id": config.room_id, }, + additional_metadata={ + "transcript_id": config.transcript_id, + "recording_id": config.recording_id, + 
"daily_recording_id": config.recording_id, + }, ) if transcript: diff --git a/server/reflector/worker/process.py b/server/reflector/worker/process.py index 4309b486..043684e5 100644 --- a/server/reflector/worker/process.py +++ b/server/reflector/worker/process.py @@ -302,6 +302,11 @@ async def _process_multitrack_recording_inner( "transcript_id": transcript.id, "room_id": room.id, }, + additional_metadata={ + "transcript_id": transcript.id, + "recording_id": recording_id, + "daily_recording_id": recording_id, + }, ) logger.info( "Started Hatchet workflow", diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 2931a0c2..24d2103f 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -527,6 +527,22 @@ def fake_mp3_upload(): yield +@pytest.fixture(autouse=True) +def reset_hatchet_client(): + """Reset HatchetClientManager singleton before and after each test. + + This ensures test isolation - each test starts with a fresh client state. + The fixture is autouse=True so it applies to all tests automatically. + """ + from reflector.hatchet.client import HatchetClientManager + + # Reset before test + HatchetClientManager.reset() + yield + # Reset after test to clean up + HatchetClientManager.reset() + + @pytest.fixture async def fake_transcript_with_topics(tmpdir, client): import shutil diff --git a/server/tests/test_hatchet_client.py b/server/tests/test_hatchet_client.py index 8336440c..0e04e36a 100644 --- a/server/tests/test_hatchet_client.py +++ b/server/tests/test_hatchet_client.py @@ -2,6 +2,9 @@ Tests for HatchetClientManager error handling and validation. Only tests that catch real bugs - not mock verification tests. + +Note: The `reset_hatchet_client` fixture (autouse=True in conftest.py) +automatically resets the singleton before and after each test. """ from unittest.mock import AsyncMock, MagicMock, patch @@ -18,8 +21,6 @@ async def test_hatchet_client_can_replay_handles_exception(): """ from reflector.hatchet.client import HatchetClientManager - HatchetClientManager._instance = None - with patch("reflector.hatchet.client.settings") as mock_settings: mock_settings.HATCHET_CLIENT_TOKEN = "test-token" mock_settings.HATCHET_DEBUG = False @@ -37,8 +38,6 @@ async def test_hatchet_client_can_replay_handles_exception(): # Should return False on error (workflow might be gone) assert can_replay is False - HatchetClientManager._instance = None - def test_hatchet_client_raises_without_token(): """Test that get_client raises ValueError without token. @@ -48,12 +47,8 @@ def test_hatchet_client_raises_without_token(): """ from reflector.hatchet.client import HatchetClientManager - HatchetClientManager._instance = None - with patch("reflector.hatchet.client.settings") as mock_settings: mock_settings.HATCHET_CLIENT_TOKEN = None with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"): HatchetClientManager.get_client() - - HatchetClientManager._instance = None