self-review (no-mistakes)

Igor Loskutov
2025-12-16 16:04:52 -05:00
parent e81e0cb5c3
commit fce0945564
9 changed files with 1034 additions and 1041 deletions

TASKS.md
View File

@@ -1,115 +0,0 @@
# Durable Workflow Migration Tasks
This document defines atomic, isolated work items for migrating the Daily.co multitrack diarization pipeline from Celery to durable workflow orchestration using **Hatchet**.
---
## Provider Selection
```bash
# .env
DURABLE_WORKFLOW_PROVIDER=none # Celery only (default)
DURABLE_WORKFLOW_PROVIDER=hatchet # Use Hatchet
DURABLE_WORKFLOW_SHADOW_MODE=true # Run both Hatchet + Celery (for comparison)
```
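A sketch of how a dispatcher might branch on these flags; illustrative only, since the real logic lives in the pipeline trigger integration (INTEG-001), and the helper names here are hypothetical:
```python
# Hypothetical sketch: branching on the provider flags above.
# start_hatchet_workflow / start_celery_pipeline are illustrative names,
# not the repo's actual API.
from reflector.settings import settings

async def dispatch_pipeline(config) -> None:
    provider = settings.DURABLE_WORKFLOW_PROVIDER
    shadow = settings.DURABLE_WORKFLOW_SHADOW_MODE

    if provider == "hatchet" or shadow:
        await start_hatchet_workflow(config)  # durable path
    if provider == "none" or shadow:
        start_celery_pipeline(config)  # legacy Celery path
```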
---
## Task Index
| ID | Title | Status |
|----|-------|--------|
| INFRA-001 | Add container to docker-compose | Done |
| INFRA-002 | Create Python client wrapper | Done |
| INFRA-003 | Add environment configuration | Done |
| TASK-001 | Create workflow definition | Done |
| TASK-002 | get_recording task | Done |
| TASK-003 | get_participants task | Done |
| TASK-004 | pad_track task | Done |
| TASK-005 | mixdown_tracks task | Done |
| TASK-006 | generate_waveform task | Done |
| TASK-007 | transcribe_track task | Done |
| TASK-008 | merge_transcripts task | Done (in process_tracks) |
| TASK-009 | detect_topics task | Done |
| TASK-010 | generate_title task | Done |
| TASK-011 | generate_summary task | Done |
| TASK-012 | finalize task | Done |
| TASK-013 | cleanup_consent task | Done |
| TASK-014 | post_zulip task | Done |
| TASK-015 | send_webhook task | Done |
| EVENT-001 | Progress WebSocket events | Done |
| INTEG-001 | Pipeline trigger integration | Done |
| SHADOW-001 | Shadow mode toggle | Done |
| TEST-001 | Integration tests | Pending |
| TEST-002 | E2E workflow test | Pending |
| CUTOVER-001 | Production cutover | Pending |
| CLEANUP-001 | Remove Celery code | Pending |
---
## File Structure
```
server/reflector/hatchet/
├── client.py # SDK wrapper
├── progress.py # WebSocket progress emission
├── run_workers.py # Worker startup
└── workflows/
├── diarization_pipeline.py # Main workflow with all tasks
└── track_processing.py # Child workflow (pad + transcribe)
```
---
## Remaining Work
### TEST-001: Integration Tests
- [ ] Test each task with mocked external services
- [ ] Test error handling and retries
### TEST-002: E2E Workflow Test
- [ ] Complete workflow run with real Daily.co recording
- [ ] Verify output matches Celery pipeline
- [ ] Performance comparison
### CUTOVER-001: Production Cutover
- [ ] Deploy with `DURABLE_WORKFLOW_PROVIDER=hatchet`
- [ ] Monitor for failures
- [ ] Compare results with shadow mode if needed
### CLEANUP-001: Remove Celery Code
- [ ] Remove `main_multitrack_pipeline.py`
- [ ] Remove Celery task triggers
- [ ] Update documentation
---
## Known Issues
### Hatchet
- See `HATCHET_LLM_OBSERVATIONS.md` for debugging notes
- Breaking API changes in SDK v1.21+
- JWT token Docker networking issues
- Worker appears hung without debug mode
- Workflow replay is version-locked (use --force to run latest code)
---
## Quick Start
### Hatchet
```bash
# Start infrastructure
docker compose up -d hatchet hatchet-worker
# Workers auto-register on startup
```
### Trigger Workflow
```bash
# Set provider in .env
DURABLE_WORKFLOW_PROVIDER=hatchet
# Process a Daily.co recording via webhook or API
# The pipeline trigger automatically uses the configured provider
```
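The same trigger is available programmatically through the client wrapper changed in this commit; a sketch with placeholder IDs (the full `input_data` key set lives in the dispatch integration):
```python
# Sketch of a programmatic trigger via HatchetClientManager.
# IDs are placeholders; remaining input_data keys are omitted here.
from reflector.hatchet.client import HatchetClientManager

run_id = await HatchetClientManager.start_workflow(
    workflow_name="DiarizationPipeline",
    input_data={
        "transcript_id": "tr_123",  # placeholder
        "room_id": None,
        # ...recording/track fields omitted
    },
    additional_metadata={"transcript_id": "tr_123"},  # for dashboard filtering
)
```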

View File

@@ -1,37 +1,71 @@
"""Hatchet Python client wrapper."""
"""Hatchet Python client wrapper.
from hatchet_sdk import Hatchet
Uses singleton pattern because:
1. Hatchet client maintains persistent gRPC connections for workflow registration
2. Creating multiple clients would cause registration conflicts and resource leaks
3. The SDK is designed for a single client instance per process
4. Tests use `HatchetClientManager.reset()` to isolate state between tests
"""
import logging
from hatchet_sdk import ClientConfig, Hatchet
from reflector.logger import logger
from reflector.settings import settings
class HatchetClientManager:
"""Singleton manager for Hatchet client connections."""
"""Singleton manager for Hatchet client connections.
Singleton pattern is used because Hatchet SDK maintains persistent gRPC
connections for workflow registration, and multiple clients would conflict.
For testing, use the `reset()` method or the `reset_hatchet_client` fixture
to ensure test isolation.
"""
_instance: Hatchet | None = None
@classmethod
def get_client(cls) -> Hatchet:
"""Get or create the Hatchet client."""
"""Get or create the Hatchet client.
Configures the root logger so that logger.info() calls in workflows
appear in the Hatchet dashboard logs.
"""
if cls._instance is None:
if not settings.HATCHET_CLIENT_TOKEN:
raise ValueError("HATCHET_CLIENT_TOKEN must be set")
# Pass root logger to Hatchet so workflow logs appear in dashboard
root_logger = logging.getLogger()
cls._instance = Hatchet(
debug=settings.HATCHET_DEBUG,
config=ClientConfig(logger=root_logger),
)
return cls._instance
@classmethod
async def start_workflow(
cls, workflow_name: str, input_data: dict, key: str | None = None
cls,
workflow_name: str,
input_data: dict,
additional_metadata: dict | None = None,
) -> str:
"""Start a workflow and return the workflow run ID."""
"""Start a workflow and return the workflow run ID.
Args:
workflow_name: Name of the workflow to trigger.
input_data: Input data for the workflow run.
additional_metadata: Optional metadata for filtering in dashboard
(e.g., transcript_id, recording_id).
"""
client = cls.get_client()
result = await client.runs.aio_create(
workflow_name,
input_data,
additional_metadata=additional_metadata,
)
# SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
return result.run.metadata.id
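The `reset()` classmethod referenced in the docstring (and by the autouse fixture added in conftest.py below) sits outside this hunk; given the `_instance` attribute, it presumably reduces to clearing it:
```python
# Presumed shape of HatchetClientManager.reset(); the method is used by
# tests in this commit but its body is not part of the diff shown here.
@classmethod
def reset(cls) -> None:
    cls._instance = None  # next get_client() call rebuilds the client
```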

View File

@@ -6,9 +6,12 @@ Orchestrates the full processing flow from recording metadata to final transcript
"""
import asyncio
import functools
import tempfile
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from typing import Callable
import av
from hatchet_sdk import Context
@@ -16,6 +19,20 @@ from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress_async
from reflector.hatchet.workflows.models import (
ConsentResult,
FinalizeResult,
MixdownResult,
ParticipantsResult,
ProcessTracksResult,
RecordingResult,
SummaryResult,
TitleResult,
TopicsResult,
WaveformResult,
WebhookResult,
ZulipResult,
)
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
from reflector.logger import logger
@@ -23,6 +40,7 @@ from reflector.logger import logger
OPUS_STANDARD_SAMPLE_RATE = 48000
OPUS_DEFAULT_BIT_RATE = 64000
PRESIGNED_URL_EXPIRATION_SECONDS = 7200
WAVEFORM_SEGMENTS = 255
class PipelineInput(BaseModel):
@@ -49,8 +67,9 @@ diarization_pipeline = hatchet.workflow(
# ============================================================================
async def _get_fresh_db_connection():
"""Create fresh database connection for subprocess."""
@asynccontextmanager
async def fresh_db_connection():
"""Context manager for database connections in Hatchet workers."""
import databases
from reflector.db import _database_context
@@ -60,22 +79,22 @@ async def _get_fresh_db_connection():
db = databases.Database(settings.DATABASE_URL)
_database_context.set(db)
await db.connect()
return db
async def _close_db_connection(db):
"""Close database connection."""
from reflector.db import _database_context
try:
yield db
finally:
await db.disconnect()
_database_context.set(None)
async def _set_error_status(transcript_id: str):
"""Set transcript status to 'error' on workflow failure (matches Celery line 790)."""
try:
db = await _get_fresh_db_connection()
async def set_workflow_error_status(transcript_id: str) -> bool:
"""Set transcript status to 'error' on workflow failure.
Returns:
True if the status was set successfully, False otherwise.
Failure is logged as CRITICAL since the transcript may be stuck in 'processing'.
"""
try:
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller
await transcripts_controller.set_status(transcript_id, "error")
@@ -83,14 +102,15 @@ async def _set_error_status(transcript_id: str):
"[Hatchet] Set transcript status to error",
transcript_id=transcript_id,
)
finally:
await _close_db_connection(db)
return True
except Exception as e:
logger.error(
"[Hatchet] Failed to set error status",
logger.critical(
"[Hatchet] CRITICAL: Failed to set error status - transcript may be stuck in 'processing'",
transcript_id=transcript_id,
error=str(e),
exc_info=True,
)
return False
def _get_storage():
@@ -106,13 +126,57 @@ def _get_storage():
)
def _to_dict(output) -> dict:
"""Convert task output to dict, handling both dict and Pydantic model returns.
Hatchet SDK returns Pydantic models when tasks have typed return annotations,
but older code expects dicts. This helper normalizes the output.
"""
if isinstance(output, dict):
return output
return output.model_dump()
def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
"""Decorator that handles task failures uniformly.
Args:
step_name: Name of the step for logging and progress tracking.
set_error_status: Whether to set transcript status to 'error' on failure.
"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(input: PipelineInput, ctx: Context):
try:
return await func(input, ctx)
except Exception as e:
logger.error(
f"[Hatchet] {step_name} failed",
transcript_id=input.transcript_id,
error=str(e),
exc_info=True,
)
if set_error_status:
await set_workflow_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, step_name, "failed", ctx.workflow_run_id
)
raise
return wrapper
return decorator
# ============================================================================
# Pipeline Tasks
# ============================================================================
@diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
async def get_recording(input: PipelineInput, ctx: Context) -> dict:
@with_error_handling("get_recording")
async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
"""Fetch recording metadata from Daily.co API."""
logger.info("[Hatchet] get_recording", recording_id=input.recording_id)
@@ -120,9 +184,8 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
)
# Set transcript status to "processing" at workflow start (matches Celery behavior)
db = await _get_fresh_db_connection()
try:
# Set transcript status to "processing" at workflow start
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller
transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -132,10 +195,7 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
"[Hatchet] Set transcript status to processing",
transcript_id=input.transcript_id,
)
finally:
await _close_db_connection(db)
try:
from reflector.dailyco_api.client import DailyApiClient
from reflector.settings import settings
@@ -144,12 +204,12 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
await emit_progress_async(
input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
)
return {
"id": None,
"mtg_session_id": None,
"room_name": input.room_name,
"duration": 0,
}
return RecordingResult(
id=None,
mtg_session_id=None,
room_name=input.room_name,
duration=0,
)
if not settings.DAILY_API_KEY:
raise ValueError("DAILY_API_KEY not configured")
@@ -168,38 +228,27 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
)
return {
"id": recording.id,
"mtg_session_id": recording.mtgSessionId,
"room_name": recording.room_name,
"duration": recording.duration,
}
except Exception as e:
logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, "get_recording", "failed", ctx.workflow_run_id
return RecordingResult(
id=recording.id,
mtg_session_id=recording.mtgSessionId,
room_name=recording.room_name,
duration=recording.duration,
)
raise
@diarization_pipeline.task(
parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3
)
async def get_participants(input: PipelineInput, ctx: Context) -> dict:
"""Fetch participant list from Daily.co API and update transcript in database.
Matches Celery's update_participants_from_daily() behavior.
"""
@with_error_handling("get_participants")
async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
"""Fetch participant list from Daily.co API and update transcript in database."""
logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "get_participants", "in_progress", ctx.workflow_run_id
)
try:
recording_data = ctx.task_output(get_recording)
recording_data = _to_dict(ctx.task_output(get_recording))
mtg_session_id = recording_data.get("mtg_session_id")
from reflector.dailyco_api.client import DailyApiClient
@@ -209,9 +258,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
parse_daily_recording_filename,
)
# Get transcript and reset events/topics/participants (matches Celery line 599-607)
db = await _get_fresh_db_connection()
try:
# Get transcript and reset events/topics/participants
async with fresh_db_connection():
from reflector.db.transcripts import (
TranscriptParticipant,
transcripts_controller,
@@ -219,8 +267,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
# Reset events/topics/participants (matches Celery line 599-607)
# Note: title NOT cleared - Celery preserves existing titles
# Reset events/topics/participants
# Note: title NOT cleared - preserves existing titles
await transcripts_controller.update(
transcript,
{
@@ -237,16 +285,12 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
"completed",
ctx.workflow_run_id,
)
return {
"participants": [],
"num_tracks": len(input.tracks),
"source_language": transcript.source_language
if transcript
else "en",
"target_language": transcript.target_language
if transcript
else "en",
}
return ParticipantsResult(
participants=[],
num_tracks=len(input.tracks),
source_language=transcript.source_language if transcript else "en",
target_language=transcript.target_language if transcript else "en",
)
# Fetch participants from Daily API
async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
@@ -264,7 +308,7 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
track_keys = [t["s3_key"] for t in input.tracks]
cam_audio_keys = filter_cam_audio_tracks(track_keys)
# Update participants in database (matches Celery lines 568-590)
# Update participants in database
participants_list = []
for idx, key in enumerate(cam_audio_keys):
try:
@@ -299,46 +343,31 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
participant_count=len(participants_list),
)
finally:
await _close_db_connection(db)
await emit_progress_async(
input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
)
return {
"participants": participants_list,
"num_tracks": len(input.tracks),
"source_language": transcript.source_language if transcript else "en",
"target_language": transcript.target_language if transcript else "en",
}
except Exception as e:
logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, "get_participants", "failed", ctx.workflow_run_id
return ParticipantsResult(
participants=participants_list,
num_tracks=len(input.tracks),
source_language=transcript.source_language if transcript else "en",
target_language=transcript.target_language if transcript else "en",
)
raise
@diarization_pipeline.task(
parents=[get_participants], execution_timeout=timedelta(seconds=600), retries=3
)
async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
"""Spawn child workflows for each track (dynamic fan-out).
Processes pad_track and transcribe_track for each audio track in parallel.
"""
@with_error_handling("process_tracks")
async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
"""Spawn child workflows for each track (dynamic fan-out)."""
logger.info(
"[Hatchet] process_tracks",
num_tracks=len(input.tracks),
transcript_id=input.transcript_id,
)
try:
# Get source_language from get_participants (matches Celery: uses transcript.source_language)
participants_data = ctx.task_output(get_participants)
participants_data = _to_dict(ctx.task_output(get_participants))
source_language = participants_data.get("source_language", "en")
# Spawn child workflows for each track with correct language
@@ -361,61 +390,49 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
# Get target_language for later use in detect_topics
target_language = participants_data.get("target_language", "en")
# Collect all track results
all_words = []
# Collect per-track results first, then merge (avoids mutating a shared list mid-loop)
track_words = []
padded_urls = []
created_padded_files = set()
for result in results:
transcribe_result = result.get("transcribe_track", {})
all_words.extend(transcribe_result.get("words", []))
track_words.append(transcribe_result.get("words", []))
pad_result = result.get("pad_track", {})
padded_urls.append(pad_result.get("padded_url"))
# Track padded files for cleanup (matches Celery line 636-637)
# Track padded files for cleanup
track_index = pad_result.get("track_index")
if pad_result.get("size", 0) > 0 and track_index is not None:
# File was created (size > 0 means padding was applied)
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
created_padded_files.add(storage_path)
# Sort words by start time
# Merge all words and sort by start time
all_words = [word for words in track_words for word in words]
all_words.sort(key=lambda w: w.get("start", 0))
# NOTE: Cleanup of padded S3 files moved to generate_waveform (after mixdown completes)
# Mixdown needs the padded files, so we can't delete them here
logger.info(
"[Hatchet] process_tracks complete",
num_tracks=len(input.tracks),
total_words=len(all_words),
)
return {
"all_words": all_words,
"padded_urls": padded_urls,
"word_count": len(all_words),
"num_tracks": len(input.tracks),
"target_language": target_language,
"created_padded_files": list(
created_padded_files
), # For cleanup after mixdown
}
except Exception as e:
logger.error("[Hatchet] process_tracks failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, "process_tracks", "failed", ctx.workflow_run_id
return ProcessTracksResult(
all_words=all_words,
padded_urls=padded_urls,
word_count=len(all_words),
num_tracks=len(input.tracks),
target_language=target_language,
created_padded_files=list(created_padded_files),
)
raise
@diarization_pipeline.task(
parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3
)
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
@with_error_handling("mixdown_tracks")
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
"""Mix all padded tracks into single audio file using PyAV (same as Celery)."""
logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)
@@ -423,8 +440,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "mixdown_tracks", "in_progress", ctx.workflow_run_id
)
try:
track_data = ctx.task_output(process_tracks)
track_data = _to_dict(ctx.task_output(process_tracks))
padded_urls = track_data.get("padded_urls", [])
if not padded_urls:
@@ -581,7 +597,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
except Exception:
pass
# Upload mixed file to correct path (matches Celery: {transcript.id}/audio.mp3)
# Upload mixed file to storage
file_size = Path(output_path).stat().st_size
storage_path = f"{input.transcript_id}/audio.mp3"
@@ -590,9 +606,8 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
Path(output_path).unlink(missing_ok=True)
# Update transcript with audio_location (matches Celery line 661)
db = await _get_fresh_db_connection()
try:
# Update transcript with audio_location
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller
transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -600,8 +615,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
await transcripts_controller.update(
transcript, {"audio_location": "storage"}
)
finally:
await _close_db_connection(db)
logger.info(
"[Hatchet] mixdown_tracks uploaded",
@@ -613,27 +626,18 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
)
return {
"audio_key": storage_path,
"duration": duration_ms[
0
], # Duration in milliseconds from AudioFileWriterProcessor
"tracks_mixed": len(valid_urls),
}
except Exception as e:
logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id
return MixdownResult(
audio_key=storage_path,
duration=duration_ms[0],
tracks_mixed=len(valid_urls),
)
raise
@diarization_pipeline.task(
parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3
)
async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
@with_error_handling("generate_waveform")
async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
"""Generate audio waveform visualization using AudioWaveformProcessor (matches Celery)."""
logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)
@@ -641,15 +645,13 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id
)
try:
import httpx
from reflector.db.transcripts import TranscriptWaveform, transcripts_controller
from reflector.utils.audio_waveform import get_audio_waveform
# Cleanup temporary padded S3 files (matches Celery lines 710-725)
# Moved here from process_tracks because mixdown_tracks needs the padded files
track_data = ctx.task_output(process_tracks)
# Cleanup temporary padded S3 files (deferred until after mixdown)
track_data = _to_dict(ctx.task_output(process_tracks))
created_padded_files = track_data.get("created_padded_files", [])
if created_padded_files:
logger.info(
@@ -660,9 +662,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
for storage_path in created_padded_files:
cleanup_tasks.append(storage.delete_file(storage_path))
cleanup_results = await asyncio.gather(
*cleanup_tasks, return_exceptions=True
)
cleanup_results = await asyncio.gather(*cleanup_tasks, return_exceptions=True)
for storage_path, result in zip(created_padded_files, cleanup_results):
if isinstance(result, Exception):
logger.warning(
@@ -671,7 +671,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
error=str(result),
)
mixdown_data = ctx.task_output(mixdown_tracks)
mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
audio_key = mixdown_data.get("audio_key")
storage = _get_storage()
@@ -692,20 +692,19 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
with open(temp_path, "wb") as f:
f.write(response.content)
# Generate waveform (matches Celery: get_audio_waveform with 255 segments)
waveform = get_audio_waveform(path=Path(temp_path), segments_count=255)
# Generate waveform
waveform = get_audio_waveform(
path=Path(temp_path), segments_count=WAVEFORM_SEGMENTS
)
# Save waveform to database via event (matches Celery on_waveform callback)
db = await _get_fresh_db_connection()
try:
# Save waveform to database via event
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
waveform_data = TranscriptWaveform(waveform=waveform)
await transcripts_controller.append_event(
transcript=transcript, event="WAVEFORM", data=waveform_data
)
finally:
await _close_db_connection(db)
finally:
Path(temp_path).unlink(missing_ok=True)
@@ -716,21 +715,14 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
)
return {"waveform_generated": True}
except Exception as e:
logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id
)
raise
return WaveformResult(waveform_generated=True)
@diarization_pipeline.task(
parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3
)
async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
@with_error_handling("detect_topics")
async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
"""Detect topics using LLM and save to database (matches Celery on_topic callback)."""
logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)
@@ -738,8 +730,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "detect_topics", "in_progress", ctx.workflow_run_id
)
try:
track_data = ctx.task_output(process_tracks)
track_data = _to_dict(ctx.task_output(process_tracks))
words = track_data.get("all_words", [])
target_language = track_data.get("target_language", "en")
@@ -757,13 +748,10 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
# Get DB connection for callbacks
db = await _get_fresh_db_connection()
try:
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
# Callback that upserts topics to DB (matches Celery on_topic)
# Callback that upserts topics to DB
async def on_topic_callback(data):
topic = TranscriptTopic(
title=data.title,
@@ -785,8 +773,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
on_topic_callback=on_topic_callback,
empty_pipeline=empty_pipeline,
)
finally:
await _close_db_connection(db)
topics_list = [t.model_dump() for t in topics]
@@ -796,21 +782,14 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id
)
return {"topics": topics_list}
except Exception as e:
logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id
)
raise
return TopicsResult(topics=topics_list)
@diarization_pipeline.task(
parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3
)
async def generate_title(input: PipelineInput, ctx: Context) -> dict:
@with_error_handling("generate_title")
async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
"""Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)
@@ -818,8 +797,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "generate_title", "in_progress", ctx.workflow_run_id
)
try:
topics_data = ctx.task_output(detect_topics)
topics_data = _to_dict(ctx.task_output(detect_topics))
topics = topics_data.get("topics", [])
from reflector.db.transcripts import (
@@ -834,11 +812,10 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
title_result = None
db = await _get_fresh_db_connection()
try:
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
# Callback that updates title in DB (matches Celery on_title)
# Callback that updates title in DB
async def on_title_callback(data):
nonlocal title_result
title_result = data.title
@@ -858,8 +835,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
empty_pipeline=empty_pipeline,
logger=logger,
)
finally:
await _close_db_connection(db)
logger.info("[Hatchet] generate_title complete", title=title_result)
@@ -867,21 +842,14 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
)
return {"title": title_result}
except Exception as e:
logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, "generate_title", "failed", ctx.workflow_run_id
)
raise
return TitleResult(title=title_result)
@diarization_pipeline.task(
parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3
)
async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
@with_error_handling("generate_summary")
async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
"""Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)
@@ -889,8 +857,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "generate_summary", "in_progress", ctx.workflow_run_id
)
try:
topics_data = ctx.task_output(detect_topics)
topics_data = _to_dict(ctx.task_output(detect_topics))
topics = topics_data.get("topics", [])
from reflector.db.transcripts import (
@@ -907,11 +874,10 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
summary_result = None
short_summary_result = None
db = await _get_fresh_db_connection()
try:
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
# Callback that updates long_summary in DB (matches Celery on_long_summary)
# Callback that updates long_summary in DB
async def on_long_summary_callback(data):
nonlocal summary_result
summary_result = data.long_summary
@@ -928,7 +894,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
data=final_long_summary,
)
# Callback that updates short_summary in DB (matches Celery on_short_summary)
# Callback that updates short_summary in DB
async def on_short_summary_callback(data):
nonlocal short_summary_result
short_summary_result = data.short_summary
@@ -953,8 +919,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
empty_pipeline=empty_pipeline,
logger=logger,
)
finally:
await _close_db_connection(db)
logger.info("[Hatchet] generate_summary complete")
@@ -962,15 +926,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
)
return {"summary": summary_result, "short_summary": short_summary_result}
except Exception as e:
logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id
)
raise
return SummaryResult(summary=summary_result, short_summary=short_summary_result)
@diarization_pipeline.task(
@@ -978,7 +934,8 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
execution_timeout=timedelta(seconds=60),
retries=3,
)
async def finalize(input: PipelineInput, ctx: Context) -> dict:
@with_error_handling("finalize")
async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
"""Finalize transcript: save words, emit TRANSCRIPT event, set status to 'ended'.
Matches Celery's on_transcript + set_status behavior.
@@ -990,33 +947,28 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "finalize", "in_progress", ctx.workflow_run_id
)
try:
mixdown_data = ctx.task_output(mixdown_tracks)
track_data = ctx.task_output(process_tracks)
mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
track_data = _to_dict(ctx.task_output(process_tracks))
duration = mixdown_data.get("duration", 0)
all_words = track_data.get("all_words", [])
db = await _get_fresh_db_connection()
try:
async with fresh_db_connection():
from reflector.db.transcripts import TranscriptText, transcripts_controller
from reflector.processors.types import Transcript as TranscriptType
from reflector.processors.types import Word
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript is None:
raise ValueError(
f"Transcript {input.transcript_id} not found in database"
)
raise ValueError(f"Transcript {input.transcript_id} not found in database")
# Convert words back to Word objects for storage
word_objects = [Word(**w) for w in all_words]
# Create merged transcript for TRANSCRIPT event (matches Celery line 734-736)
# Create merged transcript for TRANSCRIPT event
merged_transcript = TranscriptType(words=word_objects, translation=None)
# Emit TRANSCRIPT event (matches Celery on_transcript callback)
# Emit TRANSCRIPT event
await transcripts_controller.append_event(
transcript=transcript,
event="TRANSCRIPT",
@@ -1036,35 +988,23 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
},
)
# Set status to "ended" (matches Celery line 745)
# Set status to "ended"
await transcripts_controller.set_status(input.transcript_id, "ended")
logger.info(
"[Hatchet] finalize complete", transcript_id=input.transcript_id
)
finally:
await _close_db_connection(db)
logger.info("[Hatchet] finalize complete", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "finalize", "completed", ctx.workflow_run_id
)
return {"status": "COMPLETED"}
except Exception as e:
logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, "finalize", "failed", ctx.workflow_run_id
)
raise
return FinalizeResult(status="COMPLETED")
@diarization_pipeline.task(
parents=[finalize], execution_timeout=timedelta(seconds=60), retries=3
)
async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
@with_error_handling("cleanup_consent", set_error_status=False)
async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
"""Check and handle consent requirements."""
logger.info("[Hatchet] cleanup_consent", transcript_id=input.transcript_id)
@@ -1072,10 +1012,7 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "cleanup_consent", "in_progress", ctx.workflow_run_id
)
try:
db = await _get_fresh_db_connection()
try:
async with fresh_db_connection():
from reflector.db.meetings import meetings_controller
from reflector.db.transcripts import transcripts_controller
@@ -1091,27 +1028,18 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
"[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id
)
finally:
await _close_db_connection(db)
await emit_progress_async(
input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id
)
return {"consent_checked": True}
except Exception as e:
logger.error("[Hatchet] cleanup_consent failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "cleanup_consent", "failed", ctx.workflow_run_id
)
raise
return ConsentResult(consent_checked=True)
@diarization_pipeline.task(
parents=[cleanup_consent], execution_timeout=timedelta(seconds=60), retries=5
)
async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
@with_error_handling("post_zulip", set_error_status=False)
async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
"""Post notification to Zulip."""
logger.info("[Hatchet] post_zulip", transcript_id=input.transcript_id)
@@ -1119,7 +1047,6 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "post_zulip", "in_progress", ctx.workflow_run_id
)
try:
from reflector.settings import settings
if not settings.ZULIP_REALM:
@@ -1127,45 +1054,32 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
await emit_progress_async(
input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
)
return {"zulip_message_id": None, "skipped": True}
return ZulipResult(zulip_message_id=None, skipped=True)
from reflector.zulip import post_transcript_notification
db = await _get_fresh_db_connection()
try:
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
message_id = await post_transcript_notification(transcript)
logger.info(
"[Hatchet] post_zulip complete", zulip_message_id=message_id
)
logger.info("[Hatchet] post_zulip complete", zulip_message_id=message_id)
else:
message_id = None
finally:
await _close_db_connection(db)
await emit_progress_async(
input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
)
return {"zulip_message_id": message_id}
except Exception as e:
logger.error("[Hatchet] post_zulip failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "post_zulip", "failed", ctx.workflow_run_id
)
raise
return ZulipResult(zulip_message_id=message_id)
@diarization_pipeline.task(
parents=[post_zulip], execution_timeout=timedelta(seconds=120), retries=30
)
async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
@with_error_handling("send_webhook", set_error_status=False)
async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
"""Send completion webhook to external service."""
logger.info("[Hatchet] send_webhook", transcript_id=input.transcript_id)
@@ -1173,17 +1087,14 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "send_webhook", "in_progress", ctx.workflow_run_id
)
try:
if not input.room_id:
logger.info("[Hatchet] send_webhook skipped (no room_id)")
await emit_progress_async(
input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
)
return {"webhook_sent": False, "skipped": True}
return WebhookResult(webhook_sent=False, skipped=True)
db = await _get_fresh_db_connection()
try:
async with fresh_db_connection():
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import transcripts_controller
@@ -1217,20 +1128,10 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
ctx.workflow_run_id,
)
return {"webhook_sent": True, "response_code": response.status_code}
finally:
await _close_db_connection(db)
return WebhookResult(webhook_sent=True, response_code=response.status_code)
await emit_progress_async(
input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
)
return {"webhook_sent": False, "skipped": True}
except Exception as e:
logger.error("[Hatchet] send_webhook failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "send_webhook", "failed", ctx.workflow_run_id
)
raise
return WebhookResult(webhook_sent=False, skipped=True)

View File

@@ -0,0 +1,123 @@
"""
Pydantic models for Hatchet workflow task return types.
Provides static typing for all task outputs, enabling type checking
and better IDE support.
"""
from typing import Any
from pydantic import BaseModel
# ============================================================================
# Track Processing Results (track_processing.py)
# ============================================================================
class PadTrackResult(BaseModel):
"""Result from pad_track task."""
padded_url: str
size: int
track_index: int
class TranscribeTrackResult(BaseModel):
"""Result from transcribe_track task."""
words: list[dict[str, Any]]
track_index: int
# ============================================================================
# Diarization Pipeline Results (diarization_pipeline.py)
# ============================================================================
class RecordingResult(BaseModel):
"""Result from get_recording task."""
id: str | None
mtg_session_id: str | None
room_name: str | None
duration: float
class ParticipantsResult(BaseModel):
"""Result from get_participants task."""
participants: list[dict[str, Any]]
num_tracks: int
source_language: str
target_language: str
class ProcessTracksResult(BaseModel):
"""Result from process_tracks task."""
all_words: list[dict[str, Any]]
padded_urls: list[str | None]
word_count: int
num_tracks: int
target_language: str
created_padded_files: list[str]
class MixdownResult(BaseModel):
"""Result from mixdown_tracks task."""
audio_key: str
duration: float
tracks_mixed: int
class WaveformResult(BaseModel):
"""Result from generate_waveform task."""
waveform_generated: bool
class TopicsResult(BaseModel):
"""Result from detect_topics task."""
topics: list[dict[str, Any]]
class TitleResult(BaseModel):
"""Result from generate_title task."""
title: str | None
class SummaryResult(BaseModel):
"""Result from generate_summary task."""
summary: str | None
short_summary: str | None
class FinalizeResult(BaseModel):
"""Result from finalize task."""
status: str
class ConsentResult(BaseModel):
"""Result from cleanup_consent task."""
consent_checked: bool
class ZulipResult(BaseModel):
"""Result from post_zulip task."""
zulip_message_id: int | None = None
skipped: bool = False
class WebhookResult(BaseModel):
"""Result from send_webhook task."""
webhook_sent: bool
skipped: bool = False
response_code: int | None = None
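Since the Hatchet SDK hands these models back from `ctx.task_output()` when tasks carry typed return annotations, downstream tasks normalize them through the `_to_dict` helper; a small illustration with placeholder values:
```python
# Illustration of the model/dict duality these result types create.
# _to_dict is the helper defined in diarization_pipeline.py.
result = MixdownResult(audio_key="tr_123/audio.mp3", duration=1000.0, tracks_mixed=2)

as_dict = _to_dict(result)  # Pydantic model -> plain dict via model_dump()
assert as_dict["tracks_mixed"] == 2
assert _to_dict(as_dict) is as_dict  # plain dicts pass through unchanged
```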

View File

@@ -18,8 +18,17 @@ from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress_async
from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
from reflector.logger import logger
def _to_dict(output) -> dict:
"""Convert task output to dict, handling both dict and Pydantic model returns."""
if isinstance(output, dict):
return output
return output.model_dump()
# Audio constants matching existing pipeline
OPUS_STANDARD_SAMPLE_RATE = 48000
OPUS_DEFAULT_BIT_RATE = 64000
@@ -161,7 +170,7 @@ def _apply_audio_padding_to_file(
@track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
async def pad_track(input: TrackInput, ctx: Context) -> dict:
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
"""Pad single audio track with silence for alignment.
Extracts stream.start_time from WebM container metadata and applies
@@ -213,11 +222,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
await emit_progress_async(
input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
)
return {
"padded_url": source_url,
"size": 0,
"track_index": input.track_index,
}
return PadTrackResult(
padded_url=source_url,
size=0,
track_index=input.track_index,
)
# Create temp file for padded output
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
@@ -265,11 +274,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
)
return {
"padded_url": padded_url,
"size": file_size,
"track_index": input.track_index,
}
return PadTrackResult(
padded_url=padded_url,
size=file_size,
track_index=input.track_index,
)
except Exception as e:
logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
@@ -282,7 +291,7 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
@track_workflow.task(
parents=[pad_track], execution_timeout=timedelta(seconds=600), retries=3
)
async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
"""Transcribe audio track using GPU (Modal.com) or local Whisper."""
logger.info(
"[Hatchet] transcribe_track",
@@ -295,7 +304,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
)
try:
pad_result = ctx.task_output(pad_track)
pad_result = _to_dict(ctx.task_output(pad_track))
audio_url = pad_result.get("padded_url")
if not audio_url:
@@ -324,10 +333,10 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
)
return {
"words": words,
"track_index": input.track_index,
}
return TranscribeTrackResult(
words=words,
track_index=input.track_index,
)
except Exception as e:
logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)

View File

@@ -224,6 +224,26 @@ def dispatch_transcript_processing(
transcript, {"workflow_run_id": None}
)
# Re-fetch transcript to check for concurrent dispatch (TOCTOU protection)
transcript = await transcripts_controller.get_by_id(
config.transcript_id
)
if transcript and transcript.workflow_run_id:
# Another process started a workflow between validation and now
try:
status = await HatchetClientManager.get_workflow_run_status(
transcript.workflow_run_id
)
if "RUNNING" in status or "QUEUED" in status:
logger.info(
"Concurrent workflow detected, skipping dispatch",
workflow_id=transcript.workflow_run_id,
)
return transcript.workflow_run_id
except Exception:
# If we can't get status, proceed with new workflow
pass
workflow_id = await HatchetClientManager.start_workflow(
workflow_name="DiarizationPipeline",
input_data={
@@ -234,6 +254,11 @@ def dispatch_transcript_processing(
"transcript_id": config.transcript_id,
"room_id": config.room_id,
},
additional_metadata={
"transcript_id": config.transcript_id,
"recording_id": config.recording_id,
"daily_recording_id": config.recording_id,
},
)
if transcript:

View File

@@ -302,6 +302,11 @@ async def _process_multitrack_recording_inner(
"transcript_id": transcript.id,
"room_id": room.id,
},
additional_metadata={
"transcript_id": transcript.id,
"recording_id": recording_id,
"daily_recording_id": recording_id,
},
)
logger.info(
"Started Hatchet workflow",

View File

@@ -527,6 +527,22 @@ def fake_mp3_upload():
yield
@pytest.fixture(autouse=True)
def reset_hatchet_client():
"""Reset HatchetClientManager singleton before and after each test.
This ensures test isolation - each test starts with a fresh client state.
The fixture is autouse=True so it applies to all tests automatically.
"""
from reflector.hatchet.client import HatchetClientManager
# Reset before test
HatchetClientManager.reset()
yield
# Reset after test to clean up
HatchetClientManager.reset()
@pytest.fixture
async def fake_transcript_with_topics(tmpdir, client):
import shutil

View File

@@ -2,6 +2,9 @@
Tests for HatchetClientManager error handling and validation.
Only tests that catch real bugs - not mock verification tests.
Note: The `reset_hatchet_client` fixture (autouse=True in conftest.py)
automatically resets the singleton before and after each test.
"""
from unittest.mock import AsyncMock, MagicMock, patch
@@ -18,8 +21,6 @@ async def test_hatchet_client_can_replay_handles_exception():
"""
from reflector.hatchet.client import HatchetClientManager
HatchetClientManager._instance = None
with patch("reflector.hatchet.client.settings") as mock_settings:
mock_settings.HATCHET_CLIENT_TOKEN = "test-token"
mock_settings.HATCHET_DEBUG = False
@@ -37,8 +38,6 @@ async def test_hatchet_client_can_replay_handles_exception():
# Should return False on error (workflow might be gone)
assert can_replay is False
HatchetClientManager._instance = None
def test_hatchet_client_raises_without_token():
"""Test that get_client raises ValueError without token.
@@ -48,12 +47,8 @@ def test_hatchet_client_raises_without_token():
"""
from reflector.hatchet.client import HatchetClientManager
HatchetClientManager._instance = None
with patch("reflector.hatchet.client.settings") as mock_settings:
mock_settings.HATCHET_CLIENT_TOKEN = None
with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"):
HatchetClientManager.get_client()
HatchetClientManager._instance = None