self-review (no-mistakes)

Igor Loskutov
2025-12-16 16:04:52 -05:00
parent e81e0cb5c3
commit fce0945564
9 changed files with 1034 additions and 1041 deletions

TASKS.md (deleted, 115 lines)

@@ -1,115 +0,0 @@
# Durable Workflow Migration Tasks
This document defines atomic, isolated work items for migrating the Daily.co multitrack diarization pipeline from Celery to durable workflow orchestration using **Hatchet**.
---
## Provider Selection
```bash
# .env
DURABLE_WORKFLOW_PROVIDER=none # Celery only (default)
DURABLE_WORKFLOW_PROVIDER=hatchet # Use Hatchet
DURABLE_WORKFLOW_SHADOW_MODE=true # Run both Hatchet + Celery (for comparison)
```
---
## Task Index
| ID | Title | Status |
|----|-------|--------|
| INFRA-001 | Add container to docker-compose | Done |
| INFRA-002 | Create Python client wrapper | Done |
| INFRA-003 | Add environment configuration | Done |
| TASK-001 | Create workflow definition | Done |
| TASK-002 | get_recording task | Done |
| TASK-003 | get_participants task | Done |
| TASK-004 | pad_track task | Done |
| TASK-005 | mixdown_tracks task | Done |
| TASK-006 | generate_waveform task | Done |
| TASK-007 | transcribe_track task | Done |
| TASK-008 | merge_transcripts task | Done (in process_tracks) |
| TASK-009 | detect_topics task | Done |
| TASK-010 | generate_title task | Done |
| TASK-011 | generate_summary task | Done |
| TASK-012 | finalize task | Done |
| TASK-013 | cleanup_consent task | Done |
| TASK-014 | post_zulip task | Done |
| TASK-015 | send_webhook task | Done |
| EVENT-001 | Progress WebSocket events | Done |
| INTEG-001 | Pipeline trigger integration | Done |
| SHADOW-001 | Shadow mode toggle | Done |
| TEST-001 | Integration tests | Pending |
| TEST-002 | E2E workflow test | Pending |
| CUTOVER-001 | Production cutover | Pending |
| CLEANUP-001 | Remove Celery code | Pending |
---
## File Structure
```
server/reflector/hatchet/
├── client.py # SDK wrapper
├── progress.py # WebSocket progress emission
├── run_workers.py # Worker startup
└── workflows/
├── diarization_pipeline.py # Main workflow with all tasks
└── track_processing.py # Child workflow (pad + transcribe)
```
---
## Remaining Work
### TEST-001: Integration Tests
- [ ] Test each task with mocked external services
- [ ] Test error handling and retries
### TEST-002: E2E Workflow Test
- [ ] Complete workflow run with real Daily.co recording
- [ ] Verify output matches Celery pipeline
- [ ] Performance comparison
### CUTOVER-001: Production Cutover
- [ ] Deploy with `DURABLE_WORKFLOW_PROVIDER=hatchet`
- [ ] Monitor for failures
- [ ] Compare results with shadow mode if needed
### CLEANUP-001: Remove Celery Code
- [ ] Remove `main_multitrack_pipeline.py`
- [ ] Remove Celery task triggers
- [ ] Update documentation
---
## Known Issues
### Hatchet
- See `HATCHET_LLM_OBSERVATIONS.md` for debugging notes
- SDK v1.21+ API changes (breaking)
- JWT token Docker networking issues
- Worker appears hung without debug mode
- Workflow replay is version-locked (use --force to run latest code)
---
## Quick Start
### Hatchet
```bash
# Start infrastructure
docker compose up -d hatchet hatchet-worker
# Workers auto-register on startup
```
### Trigger Workflow
```bash
# Set provider in .env
DURABLE_WORKFLOW_PROVIDER=hatchet
# Process a Daily.co recording via webhook or API
# The pipeline trigger automatically uses the configured provider
```

server/reflector/hatchet/client.py

@@ -1,37 +1,71 @@
"""Hatchet Python client wrapper.""" """Hatchet Python client wrapper.
from hatchet_sdk import Hatchet Uses singleton pattern because:
1. Hatchet client maintains persistent gRPC connections for workflow registration
2. Creating multiple clients would cause registration conflicts and resource leaks
3. The SDK is designed for a single client instance per process
4. Tests use `HatchetClientManager.reset()` to isolate state between tests
"""
import logging
from hatchet_sdk import ClientConfig, Hatchet
from reflector.logger import logger from reflector.logger import logger
from reflector.settings import settings from reflector.settings import settings
class HatchetClientManager: class HatchetClientManager:
"""Singleton manager for Hatchet client connections.""" """Singleton manager for Hatchet client connections.
Singleton pattern is used because Hatchet SDK maintains persistent gRPC
connections for workflow registration, and multiple clients would conflict.
For testing, use the `reset()` method or the `reset_hatchet_client` fixture
to ensure test isolation.
"""
_instance: Hatchet | None = None _instance: Hatchet | None = None
@classmethod @classmethod
def get_client(cls) -> Hatchet: def get_client(cls) -> Hatchet:
"""Get or create the Hatchet client.""" """Get or create the Hatchet client.
Configures root logger so all logger.info() calls in workflows
appear in the Hatchet dashboard logs.
"""
if cls._instance is None: if cls._instance is None:
if not settings.HATCHET_CLIENT_TOKEN: if not settings.HATCHET_CLIENT_TOKEN:
raise ValueError("HATCHET_CLIENT_TOKEN must be set") raise ValueError("HATCHET_CLIENT_TOKEN must be set")
# Pass root logger to Hatchet so workflow logs appear in dashboard
root_logger = logging.getLogger()
cls._instance = Hatchet( cls._instance = Hatchet(
debug=settings.HATCHET_DEBUG, debug=settings.HATCHET_DEBUG,
config=ClientConfig(logger=root_logger),
) )
return cls._instance return cls._instance
@classmethod @classmethod
async def start_workflow( async def start_workflow(
cls, workflow_name: str, input_data: dict, key: str | None = None cls,
workflow_name: str,
input_data: dict,
additional_metadata: dict | None = None,
) -> str: ) -> str:
"""Start a workflow and return the workflow run ID.""" """Start a workflow and return the workflow run ID.
Args:
workflow_name: Name of the workflow to trigger.
input_data: Input data for the workflow run.
additional_metadata: Optional metadata for filtering in dashboard
(e.g., transcript_id, recording_id).
"""
client = cls.get_client() client = cls.get_client()
result = await client.runs.aio_create( result = await client.runs.aio_create(
workflow_name, workflow_name,
input_data, input_data,
additional_metadata=additional_metadata,
) )
# SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
return result.run.metadata.id return result.run.metadata.id
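For reference, a minimal sketch of a call site under the new signature. This is illustrative only: the "DiarizationPipeline" name and the transcript_id metadata key come from the dispatch code later in this commit, and the input payload here is not the full schema.

```python
run_id = await HatchetClientManager.start_workflow(
    workflow_name="DiarizationPipeline",
    input_data={"transcript_id": "abc123", "tracks": []},  # illustrative payload
    # Surfaces as filterable metadata in the Hatchet dashboard
    additional_metadata={"transcript_id": "abc123"},
)
```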

server/reflector/hatchet/workflows/diarization_pipeline.py

@@ -6,9 +6,12 @@ Orchestrates the full processing flow from recording metadata to final transcript
 """
 
 import asyncio
+import functools
 import tempfile
+from contextlib import asynccontextmanager
 from datetime import timedelta
 from pathlib import Path
+from typing import Callable
 
 import av
 from hatchet_sdk import Context
@@ -16,6 +19,20 @@ from pydantic import BaseModel
 
 from reflector.hatchet.client import HatchetClientManager
 from reflector.hatchet.progress import emit_progress_async
+from reflector.hatchet.workflows.models import (
+    ConsentResult,
+    FinalizeResult,
+    MixdownResult,
+    ParticipantsResult,
+    ProcessTracksResult,
+    RecordingResult,
+    SummaryResult,
+    TitleResult,
+    TopicsResult,
+    WaveformResult,
+    WebhookResult,
+    ZulipResult,
+)
 from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
 from reflector.logger import logger
@@ -23,6 +40,7 @@ from reflector.logger import logger
 OPUS_STANDARD_SAMPLE_RATE = 48000
 OPUS_DEFAULT_BIT_RATE = 64000
 PRESIGNED_URL_EXPIRATION_SECONDS = 7200
+WAVEFORM_SEGMENTS = 255
 
 
 class PipelineInput(BaseModel):
@@ -49,8 +67,9 @@ diarization_pipeline = hatchet.workflow(
 # ============================================================================
 
 
-async def _get_fresh_db_connection():
-    """Create fresh database connection for subprocess."""
+@asynccontextmanager
+async def fresh_db_connection():
+    """Context manager for database connections in Hatchet workers."""
     import databases
 
     from reflector.db import _database_context
@@ -60,22 +79,22 @@ async def _get_fresh_db_connection():
     db = databases.Database(settings.DATABASE_URL)
     _database_context.set(db)
     await db.connect()
-    return db
-
-
-async def _close_db_connection(db):
-    """Close database connection."""
-    from reflector.db import _database_context
-
-    await db.disconnect()
-    _database_context.set(None)
+    try:
+        yield db
+    finally:
+        await db.disconnect()
+        _database_context.set(None)
 
 
-async def _set_error_status(transcript_id: str):
-    """Set transcript status to 'error' on workflow failure (matches Celery line 790)."""
+async def set_workflow_error_status(transcript_id: str) -> bool:
+    """Set transcript status to 'error' on workflow failure.
+
+    Returns:
+        True if status was set successfully, False if failed.
+        Failure is logged as CRITICAL since it means transcript may be stuck.
+    """
     try:
-        db = await _get_fresh_db_connection()
-        try:
+        async with fresh_db_connection():
             from reflector.db.transcripts import transcripts_controller
 
             await transcripts_controller.set_status(transcript_id, "error")
@@ -83,14 +102,15 @@ async def _set_error_status(transcript_id: str):
                 "[Hatchet] Set transcript status to error",
                 transcript_id=transcript_id,
             )
-        finally:
-            await _close_db_connection(db)
+        return True
     except Exception as e:
-        logger.error(
-            "[Hatchet] Failed to set error status",
+        logger.critical(
+            "[Hatchet] CRITICAL: Failed to set error status - transcript may be stuck in 'processing'",
            transcript_id=transcript_id,
            error=str(e),
+            exc_info=True,
        )
+        return False
 
 
 def _get_storage():
@@ -106,13 +126,57 @@ def _get_storage():
     )
 
 
+def _to_dict(output) -> dict:
+    """Convert task output to dict, handling both dict and Pydantic model returns.
+
+    Hatchet SDK returns Pydantic models when tasks have typed return annotations,
+    but older code expects dicts. This helper normalizes the output.
+    """
+    if isinstance(output, dict):
+        return output
+    return output.model_dump()
+
+
+def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
+    """Decorator that handles task failures uniformly.
+
+    Args:
+        step_name: Name of the step for logging and progress tracking.
+        set_error_status: Whether to set transcript status to 'error' on failure.
+    """
+
+    def decorator(func: Callable) -> Callable:
+        @functools.wraps(func)
+        async def wrapper(input: PipelineInput, ctx: Context):
+            try:
+                return await func(input, ctx)
+            except Exception as e:
+                logger.error(
+                    f"[Hatchet] {step_name} failed",
+                    transcript_id=input.transcript_id,
+                    error=str(e),
+                    exc_info=True,
+                )
+                if set_error_status:
+                    await set_workflow_error_status(input.transcript_id)
+                await emit_progress_async(
+                    input.transcript_id, step_name, "failed", ctx.workflow_run_id
+                )
+                raise
+
+        return wrapper
+
+    return decorator
+
+
 # ============================================================================
 # Pipeline Tasks
 # ============================================================================
 
 
 @diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
-async def get_recording(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("get_recording")
+async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
     """Fetch recording metadata from Daily.co API."""
     logger.info("[Hatchet] get_recording", recording_id=input.recording_id)
@@ -120,9 +184,8 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
     )
 
-    # Set transcript status to "processing" at workflow start (matches Celery behavior)
-    db = await _get_fresh_db_connection()
-    try:
+    # Set transcript status to "processing" at workflow start
+    async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -132,10 +195,7 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
                 "[Hatchet] Set transcript status to processing",
                 transcript_id=input.transcript_id,
             )
-    finally:
-        await _close_db_connection(db)
 
-    try:
     from reflector.dailyco_api.client import DailyApiClient
     from reflector.settings import settings
@@ -144,12 +204,12 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
         await emit_progress_async(
             input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
         )
-        return {
-            "id": None,
-            "mtg_session_id": None,
-            "room_name": input.room_name,
-            "duration": 0,
-        }
+        return RecordingResult(
+            id=None,
+            mtg_session_id=None,
+            room_name=input.room_name,
+            duration=0,
+        )
 
     if not settings.DAILY_API_KEY:
         raise ValueError("DAILY_API_KEY not configured")
@@ -168,38 +228,27 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
     )
 
-        return {
-            "id": recording.id,
-            "mtg_session_id": recording.mtgSessionId,
-            "room_name": recording.room_name,
-            "duration": recording.duration,
-        }
-    except Exception as e:
-        logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "get_recording", "failed", ctx.workflow_run_id
-        )
-        raise
+    return RecordingResult(
+        id=recording.id,
+        mtg_session_id=recording.mtgSessionId,
+        room_name=recording.room_name,
+        duration=recording.duration,
+    )
 
 
 @diarization_pipeline.task(
     parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3
 )
-async def get_participants(input: PipelineInput, ctx: Context) -> dict:
-    """Fetch participant list from Daily.co API and update transcript in database.
-
-    Matches Celery's update_participants_from_daily() behavior.
-    """
+@with_error_handling("get_participants")
+async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
+    """Fetch participant list from Daily.co API and update transcript in database."""
     logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)
 
     await emit_progress_async(
         input.transcript_id, "get_participants", "in_progress", ctx.workflow_run_id
     )
 
-    try:
-        recording_data = ctx.task_output(get_recording)
+    recording_data = _to_dict(ctx.task_output(get_recording))
     mtg_session_id = recording_data.get("mtg_session_id")
 
     from reflector.dailyco_api.client import DailyApiClient
@@ -209,9 +258,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
         parse_daily_recording_filename,
     )
 
-    # Get transcript and reset events/topics/participants (matches Celery line 599-607)
-    db = await _get_fresh_db_connection()
-    try:
+    # Get transcript and reset events/topics/participants
+    async with fresh_db_connection():
         from reflector.db.transcripts import (
             TranscriptParticipant,
             transcripts_controller,
@@ -219,8 +267,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
-            # Reset events/topics/participants (matches Celery line 599-607)
-            # Note: title NOT cleared - Celery preserves existing titles
+            # Reset events/topics/participants
+            # Note: title NOT cleared - preserves existing titles
             await transcripts_controller.update(
                 transcript,
                 {
@@ -237,16 +285,12 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
                 "completed",
                 ctx.workflow_run_id,
             )
-            return {
-                "participants": [],
-                "num_tracks": len(input.tracks),
-                "source_language": transcript.source_language
-                if transcript
-                else "en",
-                "target_language": transcript.target_language
-                if transcript
-                else "en",
-            }
+            return ParticipantsResult(
+                participants=[],
+                num_tracks=len(input.tracks),
+                source_language=transcript.source_language if transcript else "en",
+                target_language=transcript.target_language if transcript else "en",
+            )
 
         # Fetch participants from Daily API
         async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
@@ -264,7 +308,7 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
         track_keys = [t["s3_key"] for t in input.tracks]
         cam_audio_keys = filter_cam_audio_tracks(track_keys)
 
-        # Update participants in database (matches Celery lines 568-590)
+        # Update participants in database
         participants_list = []
         for idx, key in enumerate(cam_audio_keys):
             try:
@@ -299,46 +343,31 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
             participant_count=len(participants_list),
         )
-    finally:
-        await _close_db_connection(db)
 
     await emit_progress_async(
         input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
     )
 
-        return {
-            "participants": participants_list,
-            "num_tracks": len(input.tracks),
-            "source_language": transcript.source_language if transcript else "en",
-            "target_language": transcript.target_language if transcript else "en",
-        }
-    except Exception as e:
-        logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "get_participants", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ParticipantsResult(
+        participants=participants_list,
+        num_tracks=len(input.tracks),
+        source_language=transcript.source_language if transcript else "en",
+        target_language=transcript.target_language if transcript else "en",
+    )
 
 
 @diarization_pipeline.task(
     parents=[get_participants], execution_timeout=timedelta(seconds=600), retries=3
 )
-async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
-    """Spawn child workflows for each track (dynamic fan-out).
-
-    Processes pad_track and transcribe_track for each audio track in parallel.
-    """
+@with_error_handling("process_tracks")
+async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
+    """Spawn child workflows for each track (dynamic fan-out)."""
     logger.info(
         "[Hatchet] process_tracks",
         num_tracks=len(input.tracks),
         transcript_id=input.transcript_id,
     )
 
-    try:
-        # Get source_language from get_participants (matches Celery: uses transcript.source_language)
-        participants_data = ctx.task_output(get_participants)
+    participants_data = _to_dict(ctx.task_output(get_participants))
     source_language = participants_data.get("source_language", "en")
 
     # Spawn child workflows for each track with correct language
@@ -361,61 +390,49 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
     # Get target_language for later use in detect_topics
     target_language = participants_data.get("target_language", "en")
 
-    # Collect all track results
-    all_words = []
+    # Collect results from each track (don't mutate lists while iterating)
+    track_words = []
     padded_urls = []
     created_padded_files = set()
     for result in results:
         transcribe_result = result.get("transcribe_track", {})
-        all_words.extend(transcribe_result.get("words", []))
+        track_words.append(transcribe_result.get("words", []))
 
         pad_result = result.get("pad_track", {})
         padded_urls.append(pad_result.get("padded_url"))
 
-        # Track padded files for cleanup (matches Celery line 636-637)
+        # Track padded files for cleanup
         track_index = pad_result.get("track_index")
         if pad_result.get("size", 0) > 0 and track_index is not None:
-            # File was created (size > 0 means padding was applied)
             storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
             created_padded_files.add(storage_path)
 
-    # Sort words by start time
+    # Merge all words and sort by start time
+    all_words = [word for words in track_words for word in words]
     all_words.sort(key=lambda w: w.get("start", 0))
 
-    # NOTE: Cleanup of padded S3 files moved to generate_waveform (after mixdown completes)
-    # Mixdown needs the padded files, so we can't delete them here
-
     logger.info(
         "[Hatchet] process_tracks complete",
         num_tracks=len(input.tracks),
         total_words=len(all_words),
     )
 
-        return {
-            "all_words": all_words,
-            "padded_urls": padded_urls,
-            "word_count": len(all_words),
-            "num_tracks": len(input.tracks),
-            "target_language": target_language,
-            "created_padded_files": list(
-                created_padded_files
-            ),  # For cleanup after mixdown
-        }
-    except Exception as e:
-        logger.error("[Hatchet] process_tracks failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "process_tracks", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ProcessTracksResult(
+        all_words=all_words,
+        padded_urls=padded_urls,
+        word_count=len(all_words),
+        num_tracks=len(input.tracks),
+        target_language=target_language,
+        created_padded_files=list(created_padded_files),
+    )
 
 
 @diarization_pipeline.task(
     parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3
 )
-async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("mixdown_tracks")
+async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     """Mix all padded tracks into single audio file using PyAV (same as Celery)."""
     logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)
@@ -423,8 +440,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "mixdown_tracks", "in_progress", ctx.workflow_run_id
     )
 
-    try:
-        track_data = ctx.task_output(process_tracks)
+    track_data = _to_dict(ctx.task_output(process_tracks))
     padded_urls = track_data.get("padded_urls", [])
 
     if not padded_urls:
@@ -581,7 +597,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
             except Exception:
                 pass
 
-        # Upload mixed file to correct path (matches Celery: {transcript.id}/audio.mp3)
+        # Upload mixed file to storage
         file_size = Path(output_path).stat().st_size
         storage_path = f"{input.transcript_id}/audio.mp3"
@@ -590,9 +606,8 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
         Path(output_path).unlink(missing_ok=True)
 
-        # Update transcript with audio_location (matches Celery line 661)
-        db = await _get_fresh_db_connection()
-        try:
+        # Update transcript with audio_location
+        async with fresh_db_connection():
             from reflector.db.transcripts import transcripts_controller
 
             transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -600,8 +615,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
                 await transcripts_controller.update(
                     transcript, {"audio_location": "storage"}
                 )
-        finally:
-            await _close_db_connection(db)
 
         logger.info(
             "[Hatchet] mixdown_tracks uploaded",
@@ -613,27 +626,18 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
     )
 
-        return {
-            "audio_key": storage_path,
-            "duration": duration_ms[
-                0
-            ],  # Duration in milliseconds from AudioFileWriterProcessor
-            "tracks_mixed": len(valid_urls),
-        }
-    except Exception as e:
-        logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id
-        )
-        raise
+    return MixdownResult(
+        audio_key=storage_path,
+        duration=duration_ms[0],
+        tracks_mixed=len(valid_urls),
+    )
 
 
 @diarization_pipeline.task(
     parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3
 )
-async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("generate_waveform")
+async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
     """Generate audio waveform visualization using AudioWaveformProcessor (matches Celery)."""
     logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)
@@ -641,15 +645,13 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id
     )
 
-    try:
     import httpx
 
     from reflector.db.transcripts import TranscriptWaveform, transcripts_controller
     from reflector.utils.audio_waveform import get_audio_waveform
 
-        # Cleanup temporary padded S3 files (matches Celery lines 710-725)
-        # Moved here from process_tracks because mixdown_tracks needs the padded files
-        track_data = ctx.task_output(process_tracks)
+    # Cleanup temporary padded S3 files (deferred until after mixdown)
+    track_data = _to_dict(ctx.task_output(process_tracks))
     created_padded_files = track_data.get("created_padded_files", [])
     if created_padded_files:
         logger.info(
@@ -660,9 +662,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
         for storage_path in created_padded_files:
             cleanup_tasks.append(storage.delete_file(storage_path))
 
-        cleanup_results = await asyncio.gather(
-            *cleanup_tasks, return_exceptions=True
-        )
+        cleanup_results = await asyncio.gather(*cleanup_tasks, return_exceptions=True)
         for storage_path, result in zip(created_padded_files, cleanup_results):
             if isinstance(result, Exception):
                 logger.warning(
@@ -671,7 +671,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
                     error=str(result),
                 )
 
-    mixdown_data = ctx.task_output(mixdown_tracks)
+    mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
     audio_key = mixdown_data.get("audio_key")
 
     storage = _get_storage()
@@ -692,20 +692,19 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
         with open(temp_path, "wb") as f:
             f.write(response.content)
 
-        # Generate waveform (matches Celery: get_audio_waveform with 255 segments)
-        waveform = get_audio_waveform(path=Path(temp_path), segments_count=255)
+        # Generate waveform
+        waveform = get_audio_waveform(
+            path=Path(temp_path), segments_count=WAVEFORM_SEGMENTS
+        )
 
-        # Save waveform to database via event (matches Celery on_waveform callback)
-        db = await _get_fresh_db_connection()
-        try:
+        # Save waveform to database via event
+        async with fresh_db_connection():
             transcript = await transcripts_controller.get_by_id(input.transcript_id)
             if transcript:
                 waveform_data = TranscriptWaveform(waveform=waveform)
                 await transcripts_controller.append_event(
                     transcript=transcript, event="WAVEFORM", data=waveform_data
                 )
-        finally:
-            await _close_db_connection(db)
     finally:
         Path(temp_path).unlink(missing_ok=True)
@@ -716,21 +715,14 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
     )
 
-        return {"waveform_generated": True}
-    except Exception as e:
-        logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id
-        )
-        raise
+    return WaveformResult(waveform_generated=True)
 
 
 @diarization_pipeline.task(
     parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3
 )
-async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("detect_topics")
+async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     """Detect topics using LLM and save to database (matches Celery on_topic callback)."""
     logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)
@@ -738,8 +730,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "detect_topics", "in_progress", ctx.workflow_run_id
     )
 
-    try:
-        track_data = ctx.task_output(process_tracks)
+    track_data = _to_dict(ctx.task_output(process_tracks))
     words = track_data.get("all_words", [])
     target_language = track_data.get("target_language", "en")
@@ -757,13 +748,10 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
     empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
 
-    # Get DB connection for callbacks
-    db = await _get_fresh_db_connection()
-    try:
+    async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
 
-        # Callback that upserts topics to DB (matches Celery on_topic)
+        # Callback that upserts topics to DB
         async def on_topic_callback(data):
             topic = TranscriptTopic(
                 title=data.title,
@@ -785,8 +773,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
             on_topic_callback=on_topic_callback,
             empty_pipeline=empty_pipeline,
         )
-    finally:
-        await _close_db_connection(db)
 
     topics_list = [t.model_dump() for t in topics]
@@ -796,21 +782,14 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id
     )
 
-        return {"topics": topics_list}
-    except Exception as e:
-        logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id
-        )
-        raise
+    return TopicsResult(topics=topics_list)
 
 
 @diarization_pipeline.task(
     parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3
 )
-async def generate_title(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("generate_title")
+async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     """Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
     logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)
@@ -818,8 +797,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "generate_title", "in_progress", ctx.workflow_run_id
     )
 
-    try:
-        topics_data = ctx.task_output(detect_topics)
+    topics_data = _to_dict(ctx.task_output(detect_topics))
     topics = topics_data.get("topics", [])
 
     from reflector.db.transcripts import (
@@ -834,11 +812,10 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
     empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
 
     title_result = None
-    db = await _get_fresh_db_connection()
-    try:
+    async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
 
-        # Callback that updates title in DB (matches Celery on_title)
+        # Callback that updates title in DB
         async def on_title_callback(data):
             nonlocal title_result
             title_result = data.title
@@ -858,8 +835,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
             empty_pipeline=empty_pipeline,
             logger=logger,
         )
-    finally:
-        await _close_db_connection(db)
 
     logger.info("[Hatchet] generate_title complete", title=title_result)
@@ -867,21 +842,14 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
     )
 
-        return {"title": title_result}
-    except Exception as e:
-        logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "generate_title", "failed", ctx.workflow_run_id
-        )
-        raise
+    return TitleResult(title=title_result)
 
 
 @diarization_pipeline.task(
     parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3
 )
-async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("generate_summary")
+async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     """Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
     logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)
@@ -889,8 +857,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "generate_summary", "in_progress", ctx.workflow_run_id
     )
 
-    try:
-        topics_data = ctx.task_output(detect_topics)
+    topics_data = _to_dict(ctx.task_output(detect_topics))
     topics = topics_data.get("topics", [])
 
     from reflector.db.transcripts import (
@@ -907,11 +874,10 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
     summary_result = None
     short_summary_result = None
-    db = await _get_fresh_db_connection()
-    try:
+    async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
 
-        # Callback that updates long_summary in DB (matches Celery on_long_summary)
+        # Callback that updates long_summary in DB
         async def on_long_summary_callback(data):
             nonlocal summary_result
             summary_result = data.long_summary
@@ -928,7 +894,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
                 data=final_long_summary,
             )
 
-        # Callback that updates short_summary in DB (matches Celery on_short_summary)
+        # Callback that updates short_summary in DB
         async def on_short_summary_callback(data):
             nonlocal short_summary_result
             short_summary_result = data.short_summary
@@ -953,8 +919,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
             empty_pipeline=empty_pipeline,
             logger=logger,
         )
-    finally:
-        await _close_db_connection(db)
 
     logger.info("[Hatchet] generate_summary complete")
@@ -962,15 +926,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
    )
 
-        return {"summary": summary_result, "short_summary": short_summary_result}
-    except Exception as e:
-        logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id
-        )
-        raise
+    return SummaryResult(summary=summary_result, short_summary=short_summary_result)
 
 
 @diarization_pipeline.task(
@@ -978,7 +934,8 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
     execution_timeout=timedelta(seconds=60),
     retries=3,
 )
-async def finalize(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("finalize")
+async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
     """Finalize transcript: save words, emit TRANSCRIPT event, set status to 'ended'.
 
     Matches Celery's on_transcript + set_status behavior.
@@ -990,33 +947,28 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "finalize", "in_progress", ctx.workflow_run_id
     )
 
-    try:
-        mixdown_data = ctx.task_output(mixdown_tracks)
-        track_data = ctx.task_output(process_tracks)
+    mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
+    track_data = _to_dict(ctx.task_output(process_tracks))
 
     duration = mixdown_data.get("duration", 0)
     all_words = track_data.get("all_words", [])
 
-    db = await _get_fresh_db_connection()
-    try:
+    async with fresh_db_connection():
         from reflector.db.transcripts import TranscriptText, transcripts_controller
         from reflector.processors.types import Transcript as TranscriptType
         from reflector.processors.types import Word
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript is None:
-            raise ValueError(
-                f"Transcript {input.transcript_id} not found in database"
-            )
+            raise ValueError(f"Transcript {input.transcript_id} not found in database")
 
         # Convert words back to Word objects for storage
         word_objects = [Word(**w) for w in all_words]
 
-        # Create merged transcript for TRANSCRIPT event (matches Celery line 734-736)
+        # Create merged transcript for TRANSCRIPT event
         merged_transcript = TranscriptType(words=word_objects, translation=None)
 
-        # Emit TRANSCRIPT event (matches Celery on_transcript callback)
+        # Emit TRANSCRIPT event
         await transcripts_controller.append_event(
             transcript=transcript,
             event="TRANSCRIPT",
@@ -1036,35 +988,23 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
             },
         )
 
-        # Set status to "ended" (matches Celery line 745)
+        # Set status to "ended"
         await transcripts_controller.set_status(input.transcript_id, "ended")
 
-        logger.info(
-            "[Hatchet] finalize complete", transcript_id=input.transcript_id
-        )
-    finally:
-        await _close_db_connection(db)
+        logger.info("[Hatchet] finalize complete", transcript_id=input.transcript_id)
 
     await emit_progress_async(
         input.transcript_id, "finalize", "completed", ctx.workflow_run_id
     )
 
-        return {"status": "COMPLETED"}
-    except Exception as e:
-        logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "finalize", "failed", ctx.workflow_run_id
-        )
-        raise
+    return FinalizeResult(status="COMPLETED")
 
 
 @diarization_pipeline.task(
     parents=[finalize], execution_timeout=timedelta(seconds=60), retries=3
 )
-async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("cleanup_consent", set_error_status=False)
+async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
     """Check and handle consent requirements."""
     logger.info("[Hatchet] cleanup_consent", transcript_id=input.transcript_id)
@@ -1072,10 +1012,7 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "cleanup_consent", "in_progress", ctx.workflow_run_id
     )
 
-    try:
-        db = await _get_fresh_db_connection()
-        try:
+    async with fresh_db_connection():
         from reflector.db.meetings import meetings_controller
         from reflector.db.transcripts import transcripts_controller
@@ -1091,27 +1028,18 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
             "[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id
         )
-        finally:
-            await _close_db_connection(db)
 
     await emit_progress_async(
         input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id
    )
 
-        return {"consent_checked": True}
-    except Exception as e:
-        logger.error("[Hatchet] cleanup_consent failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "cleanup_consent", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ConsentResult(consent_checked=True)
 
 
 @diarization_pipeline.task(
     parents=[cleanup_consent], execution_timeout=timedelta(seconds=60), retries=5
 )
-async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("post_zulip", set_error_status=False)
+async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
     """Post notification to Zulip."""
     logger.info("[Hatchet] post_zulip", transcript_id=input.transcript_id)
@@ -1119,7 +1047,6 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "post_zulip", "in_progress", ctx.workflow_run_id
     )
 
-    try:
     from reflector.settings import settings
 
     if not settings.ZULIP_REALM:
@@ -1127,45 +1054,32 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
         await emit_progress_async(
             input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
         )
-        return {"zulip_message_id": None, "skipped": True}
+        return ZulipResult(zulip_message_id=None, skipped=True)
 
     from reflector.zulip import post_transcript_notification
 
-    db = await _get_fresh_db_connection()
-    try:
+    async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
             message_id = await post_transcript_notification(transcript)
-            logger.info(
-                "[Hatchet] post_zulip complete", zulip_message_id=message_id
-            )
+            logger.info("[Hatchet] post_zulip complete", zulip_message_id=message_id)
         else:
            message_id = None
-    finally:
-        await _close_db_connection(db)
 
     await emit_progress_async(
         input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
    )
 
-        return {"zulip_message_id": message_id}
-    except Exception as e:
-        logger.error("[Hatchet] post_zulip failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "post_zulip", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ZulipResult(zulip_message_id=message_id)
 
 
 @diarization_pipeline.task(
     parents=[post_zulip], execution_timeout=timedelta(seconds=120), retries=30
 )
-async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("send_webhook", set_error_status=False)
+async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
     """Send completion webhook to external service."""
     logger.info("[Hatchet] send_webhook", transcript_id=input.transcript_id)
@@ -1173,17 +1087,14 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "send_webhook", "in_progress", ctx.workflow_run_id
     )
 
-    try:
     if not input.room_id:
         logger.info("[Hatchet] send_webhook skipped (no room_id)")
         await emit_progress_async(
             input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
        )
-        return {"webhook_sent": False, "skipped": True}
+        return WebhookResult(webhook_sent=False, skipped=True)
 
-    db = await _get_fresh_db_connection()
-    try:
+    async with fresh_db_connection():
         from reflector.db.rooms import rooms_controller
         from reflector.db.transcripts import transcripts_controller
@@ -1217,20 +1128,10 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
                 ctx.workflow_run_id,
             )
 
-            return {"webhook_sent": True, "response_code": response.status_code}
-    finally:
-        await _close_db_connection(db)
+            return WebhookResult(webhook_sent=True, response_code=response.status_code)
 
     await emit_progress_async(
         input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
    )
 
-        return {"webhook_sent": False, "skipped": True}
-    except Exception as e:
-        logger.error("[Hatchet] send_webhook failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "send_webhook", "failed", ctx.workflow_run_id
-        )
-        raise
+    return WebhookResult(webhook_sent=False, skipped=True)

server/reflector/hatchet/workflows/models.py (new file)

@@ -0,0 +1,123 @@
"""
Pydantic models for Hatchet workflow task return types.
Provides static typing for all task outputs, enabling type checking
and better IDE support.
"""
from typing import Any
from pydantic import BaseModel
# ============================================================================
# Track Processing Results (track_processing.py)
# ============================================================================
class PadTrackResult(BaseModel):
"""Result from pad_track task."""
padded_url: str
size: int
track_index: int
class TranscribeTrackResult(BaseModel):
"""Result from transcribe_track task."""
words: list[dict[str, Any]]
track_index: int
# ============================================================================
# Diarization Pipeline Results (diarization_pipeline.py)
# ============================================================================
class RecordingResult(BaseModel):
"""Result from get_recording task."""
id: str | None
mtg_session_id: str | None
room_name: str | None
duration: float
class ParticipantsResult(BaseModel):
"""Result from get_participants task."""
participants: list[dict[str, Any]]
num_tracks: int
source_language: str
target_language: str
class ProcessTracksResult(BaseModel):
"""Result from process_tracks task."""
all_words: list[dict[str, Any]]
padded_urls: list[str | None]
word_count: int
num_tracks: int
target_language: str
created_padded_files: list[str]
class MixdownResult(BaseModel):
"""Result from mixdown_tracks task."""
audio_key: str
duration: float
tracks_mixed: int
class WaveformResult(BaseModel):
"""Result from generate_waveform task."""
waveform_generated: bool
class TopicsResult(BaseModel):
"""Result from detect_topics task."""
topics: list[dict[str, Any]]
class TitleResult(BaseModel):
"""Result from generate_title task."""
title: str | None
class SummaryResult(BaseModel):
"""Result from generate_summary task."""
summary: str | None
short_summary: str | None
class FinalizeResult(BaseModel):
"""Result from finalize task."""
status: str
class ConsentResult(BaseModel):
"""Result from cleanup_consent task."""
consent_checked: bool
class ZulipResult(BaseModel):
"""Result from post_zulip task."""
zulip_message_id: int | None = None
skipped: bool = False
class WebhookResult(BaseModel):
"""Result from send_webhook task."""
webhook_sent: bool
skipped: bool = False
response_code: int | None = None
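A small sketch of how these models pair with the `_to_dict()` helper added in this commit: with typed return annotations the SDK can hand parent outputs back as Pydantic models, while older callers expect dicts, and both normalize the same way (values here are illustrative).

```python
result = PadTrackResult(padded_url="s3://bucket/padded_0.webm", size=0, track_index=0)

assert _to_dict(result) == {
    "padded_url": "s3://bucket/padded_0.webm",
    "size": 0,
    "track_index": 0,
}
assert _to_dict(result.model_dump()) == result.model_dump()  # plain dicts pass through
```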

server/reflector/hatchet/workflows/track_processing.py

@@ -18,8 +18,17 @@ from pydantic import BaseModel
 from reflector.hatchet.client import HatchetClientManager
 from reflector.hatchet.progress import emit_progress_async
+from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
 from reflector.logger import logger
 
 
+def _to_dict(output) -> dict:
+    """Convert task output to dict, handling both dict and Pydantic model returns."""
+    if isinstance(output, dict):
+        return output
+    return output.model_dump()
+
+
 # Audio constants matching existing pipeline
 OPUS_STANDARD_SAMPLE_RATE = 48000
 OPUS_DEFAULT_BIT_RATE = 64000
@@ -161,7 +170,7 @@ def _apply_audio_padding_to_file(
 
 @track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
-async def pad_track(input: TrackInput, ctx: Context) -> dict:
+async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
     """Pad single audio track with silence for alignment.
 
     Extracts stream.start_time from WebM container metadata and applies
@@ -213,11 +222,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
         await emit_progress_async(
             input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
         )
-        return {
-            "padded_url": source_url,
-            "size": 0,
-            "track_index": input.track_index,
-        }
+        return PadTrackResult(
+            padded_url=source_url,
+            size=0,
+            track_index=input.track_index,
+        )
 
     # Create temp file for padded output
     with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
@@ -265,11 +274,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
             input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
         )
-        return {
-            "padded_url": padded_url,
-            "size": file_size,
-            "track_index": input.track_index,
-        }
+        return PadTrackResult(
+            padded_url=padded_url,
+            size=file_size,
+            track_index=input.track_index,
+        )
 
     except Exception as e:
         logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
@@ -282,7 +291,7 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
 @track_workflow.task(
     parents=[pad_track], execution_timeout=timedelta(seconds=600), retries=3
 )
-async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
+async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
     """Transcribe audio track using GPU (Modal.com) or local Whisper."""
     logger.info(
         "[Hatchet] transcribe_track",
@@ -295,7 +304,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
     )
     try:
-        pad_result = ctx.task_output(pad_track)
+        pad_result = _to_dict(ctx.task_output(pad_track))
         audio_url = pad_result.get("padded_url")
 
         if not audio_url:
@@ -324,10 +333,10 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
             input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
         )
-        return {
-            "words": words,
-            "track_index": input.track_index,
-        }
+        return TranscribeTrackResult(
+            words=words,
+            track_index=input.track_index,
+        )
 
     except Exception as e:
         logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)

View File

@@ -224,6 +224,26 @@ def dispatch_transcript_processing(
             transcript, {"workflow_run_id": None}
         )
 
+        # Re-fetch transcript to check for concurrent dispatch (TOCTOU protection)
+        transcript = await transcripts_controller.get_by_id(
+            config.transcript_id
+        )
+        if transcript and transcript.workflow_run_id:
+            # Another process started a workflow between validation and now
+            try:
+                status = await HatchetClientManager.get_workflow_run_status(
+                    transcript.workflow_run_id
+                )
+                if "RUNNING" in status or "QUEUED" in status:
+                    logger.info(
+                        "Concurrent workflow detected, skipping dispatch",
+                        workflow_id=transcript.workflow_run_id,
+                    )
+                    return transcript.workflow_run_id
+            except Exception:
+                # If we can't get status, proceed with new workflow
+                pass
+
         workflow_id = await HatchetClientManager.start_workflow(
             workflow_name="DiarizationPipeline",
             input_data={
@@ -234,6 +254,11 @@ def dispatch_transcript_processing(
                 "transcript_id": config.transcript_id,
                 "room_id": config.room_id,
             },
+            additional_metadata={
+                "transcript_id": config.transcript_id,
+                "recording_id": config.recording_id,
+                "daily_recording_id": config.recording_id,
+            },
         )
 
         if transcript:
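
The re-fetch narrows the check-then-act window in which a second dispatcher could start a duplicate workflow, and the `except Exception: pass` biases toward dispatching when the status lookup fails (the run may simply be gone). A minimal sketch of the same guard; the three stub functions are hypothetical stand-ins for the real DB and Hatchet calls:

```python
async def get_recorded_run_id(transcript_id: str) -> str | None:
    """Hypothetical stand-in for reading transcript.workflow_run_id from the DB."""
    ...


async def get_run_status(run_id: str) -> str:
    """Hypothetical stand-in for HatchetClientManager.get_workflow_run_status."""
    ...


async def start_new_workflow(transcript_id: str) -> str:
    """Hypothetical stand-in for HatchetClientManager.start_workflow."""
    ...


async def dispatch_once(transcript_id: str) -> str:
    # Re-check immediately before acting: another process may have started
    # a workflow between earlier validation and this point (TOCTOU).
    run_id = await get_recorded_run_id(transcript_id)
    if run_id:
        try:
            status = await get_run_status(run_id)
            if "RUNNING" in status or "QUEUED" in status:
                return run_id  # reuse the live run instead of double-dispatching
        except Exception:
            pass  # status unknown (run may be gone): fall through, start fresh
    return await start_new_workflow(transcript_id)
```

The `additional_metadata` block is independent of the race guard; presumably it is there so runs can be located in the Hatchet dashboard by transcript or recording ID (note that `recording_id` and `daily_recording_id` carry the same value here).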

View File

@@ -302,6 +302,11 @@ async def _process_multitrack_recording_inner(
             "transcript_id": transcript.id,
             "room_id": room.id,
         },
+        additional_metadata={
+            "transcript_id": transcript.id,
+            "recording_id": recording_id,
+            "daily_recording_id": recording_id,
+        },
     )
     logger.info(
         "Started Hatchet workflow",

View File

@@ -527,6 +527,22 @@ def fake_mp3_upload():
         yield
 
 
+@pytest.fixture(autouse=True)
+def reset_hatchet_client():
+    """Reset HatchetClientManager singleton before and after each test.
+
+    This ensures test isolation - each test starts with a fresh client state.
+    The fixture is autouse=True so it applies to all tests automatically.
+    """
+    from reflector.hatchet.client import HatchetClientManager
+
+    # Reset before test
+    HatchetClientManager.reset()
+    yield
+    # Reset after test to clean up
+    HatchetClientManager.reset()
+
+
 @pytest.fixture
 async def fake_transcript_with_topics(tmpdir, client):
     import shutil
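
The autouse fixture addresses the usual hazard with module-level singletons in a test suite: whichever test constructs the client first pins that instance (and whatever settings it was built under) for every later test. A toy sketch of the failure mode and the reset, assuming the `_instance`/`reset()` shape the tests below rely on:

```python
class SingletonManager:
    """Toy stand-in for HatchetClientManager."""

    _instance: object | None = None

    @classmethod
    def get_client(cls) -> object:
        if cls._instance is None:
            cls._instance = object()  # stands in for constructing Hatchet(...)
        return cls._instance

    @classmethod
    def reset(cls) -> None:
        cls._instance = None


first = SingletonManager.get_client()
assert SingletonManager.get_client() is first  # cached: first caller's state is pinned

SingletonManager.reset()
assert SingletonManager.get_client() is not first  # fresh state after reset
```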

View File

@@ -2,6 +2,9 @@
 Tests for HatchetClientManager error handling and validation.
 
 Only tests that catch real bugs - not mock verification tests.
+
+Note: The `reset_hatchet_client` fixture (autouse=True in conftest.py)
+automatically resets the singleton before and after each test.
 """
 
 from unittest.mock import AsyncMock, MagicMock, patch
@@ -18,8 +21,6 @@ async def test_hatchet_client_can_replay_handles_exception():
     """
     from reflector.hatchet.client import HatchetClientManager
 
-    HatchetClientManager._instance = None
-
     with patch("reflector.hatchet.client.settings") as mock_settings:
         mock_settings.HATCHET_CLIENT_TOKEN = "test-token"
         mock_settings.HATCHET_DEBUG = False
@@ -37,8 +38,6 @@ async def test_hatchet_client_can_replay_handles_exception():
         # Should return False on error (workflow might be gone)
         assert can_replay is False
 
-    HatchetClientManager._instance = None
-
 
 def test_hatchet_client_raises_without_token():
     """Test that get_client raises ValueError without token.
@@ -48,12 +47,8 @@ def test_hatchet_client_raises_without_token():
     """
     from reflector.hatchet.client import HatchetClientManager
 
-    HatchetClientManager._instance = None
-
     with patch("reflector.hatchet.client.settings") as mock_settings:
         mock_settings.HATCHET_CLIENT_TOKEN = None
 
         with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"):
            HatchetClientManager.get_client()
-
-    HatchetClientManager._instance = None