self-review (no-mistakes)
Mirror of https://github.com/Monadical-SAS/reflector.git

TASKS.md | 115 (file deleted)
TASKS.md (deleted)

@@ -1,115 +0,0 @@
# Durable Workflow Migration Tasks

This document defines atomic, isolated work items for migrating the Daily.co multitrack diarization pipeline from Celery to durable workflow orchestration using **Hatchet**.

---

## Provider Selection

```bash
# .env
DURABLE_WORKFLOW_PROVIDER=none     # Celery only (default)
DURABLE_WORKFLOW_PROVIDER=hatchet  # Use Hatchet
DURABLE_WORKFLOW_SHADOW_MODE=true  # Run both Hatchet + Celery (for comparison)
```
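As a sketch of how a dispatcher might branch on this toggle (the helper names below are hypothetical; the repo's actual settings object may read these variables differently):

```python
# Hypothetical sketch of reading the provider toggle; variable names mirror
# the .env entries above, the helpers themselves are illustrative.
import os


def durable_workflow_provider() -> str:
    # "none" keeps the Celery-only path (the default)
    return os.environ.get("DURABLE_WORKFLOW_PROVIDER", "none")


def shadow_mode_enabled() -> bool:
    # Shadow mode runs Hatchet alongside Celery for output comparison
    return os.environ.get("DURABLE_WORKFLOW_SHADOW_MODE", "false").lower() == "true"


if durable_workflow_provider() == "hatchet":
    mode = "shadow" if shadow_mode_enabled() else "primary"
    print(f"dispatching via Hatchet ({mode})")
```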
---

## Task Index

| ID | Title | Status |
|----|-------|--------|
| INFRA-001 | Add container to docker-compose | Done |
| INFRA-002 | Create Python client wrapper | Done |
| INFRA-003 | Add environment configuration | Done |
| TASK-001 | Create workflow definition | Done |
| TASK-002 | get_recording task | Done |
| TASK-003 | get_participants task | Done |
| TASK-004 | pad_track task | Done |
| TASK-005 | mixdown_tracks task | Done |
| TASK-006 | generate_waveform task | Done |
| TASK-007 | transcribe_track task | Done |
| TASK-008 | merge_transcripts task | Done (in process_tracks) |
| TASK-009 | detect_topics task | Done |
| TASK-010 | generate_title task | Done |
| TASK-011 | generate_summary task | Done |
| TASK-012 | finalize task | Done |
| TASK-013 | cleanup_consent task | Done |
| TASK-014 | post_zulip task | Done |
| TASK-015 | send_webhook task | Done |
| EVENT-001 | Progress WebSocket events | Done |
| INTEG-001 | Pipeline trigger integration | Done |
| SHADOW-001 | Shadow mode toggle | Done |
| TEST-001 | Integration tests | Pending |
| TEST-002 | E2E workflow test | Pending |
| CUTOVER-001 | Production cutover | Pending |
| CLEANUP-001 | Remove Celery code | Pending |

---

## File Structure

```
server/reflector/hatchet/
├── client.py                      # SDK wrapper
├── progress.py                    # WebSocket progress emission
├── run_workers.py                 # Worker startup
└── workflows/
    ├── diarization_pipeline.py    # Main workflow with all tasks
    └── track_processing.py        # Child workflow (pad + transcribe)
```

---

## Remaining Work

### TEST-001: Integration Tests
- [ ] Test each task with mocked external services
- [ ] Test error handling and retries

### TEST-002: E2E Workflow Test
- [ ] Complete workflow run with real Daily.co recording
- [ ] Verify output matches Celery pipeline
- [ ] Performance comparison

### CUTOVER-001: Production Cutover
- [ ] Deploy with `DURABLE_WORKFLOW_PROVIDER=hatchet`
- [ ] Monitor for failures
- [ ] Compare results with shadow mode if needed

### CLEANUP-001: Remove Celery Code
- [ ] Remove `main_multitrack_pipeline.py`
- [ ] Remove Celery task triggers
- [ ] Update documentation

---

## Known Issues

### Hatchet
- See `HATCHET_LLM_OBSERVATIONS.md` for debugging notes
- SDK v1.21+ API changes (breaking)
- JWT token Docker networking issues
- Worker appears hung without debug mode
- Workflow replay is version-locked (use `--force` to run latest code)

---

## Quick Start

### Hatchet
```bash
# Start infrastructure
docker compose up -d hatchet hatchet-worker

# Workers auto-register on startup
```

### Trigger Workflow
```bash
# Set provider in .env
DURABLE_WORKFLOW_PROVIDER=hatchet

# Process a Daily.co recording via webhook or API
# The pipeline trigger automatically uses the configured provider
```
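Workflows can also be started programmatically through the client wrapper introduced in this commit. A minimal sketch, assuming the `PipelineInput` fields shown in the pipeline diff below (the concrete ids here are fake):

```python
# Illustrative programmatic trigger; HatchetClientManager.start_workflow and
# the "DiarizationPipeline" name come from the diff below, the values are fake.
import asyncio

from reflector.hatchet.client import HatchetClientManager


async def main() -> None:
    run_id = await HatchetClientManager.start_workflow(
        workflow_name="DiarizationPipeline",
        input_data={
            "recording_id": "rec-123",   # hypothetical Daily.co recording id
            "transcript_id": "tr-456",   # hypothetical transcript id
            "room_id": None,
            "room_name": "demo-room",
            "tracks": [],                # [{"s3_key": "..."}] entries in practice
        },
        additional_metadata={"transcript_id": "tr-456"},
    )
    print("workflow run id:", run_id)


asyncio.run(main())
```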
server/reflector/hatchet/client.py

@@ -1,37 +1,71 @@
-"""Hatchet Python client wrapper."""
+"""Hatchet Python client wrapper.
+
+Uses singleton pattern because:
+1. Hatchet client maintains persistent gRPC connections for workflow registration
+2. Creating multiple clients would cause registration conflicts and resource leaks
+3. The SDK is designed for a single client instance per process
+4. Tests use `HatchetClientManager.reset()` to isolate state between tests
+"""

+import logging

-from hatchet_sdk import Hatchet
+from hatchet_sdk import ClientConfig, Hatchet

from reflector.logger import logger
from reflector.settings import settings


class HatchetClientManager:
-    """Singleton manager for Hatchet client connections."""
+    """Singleton manager for Hatchet client connections.
+
+    Singleton pattern is used because Hatchet SDK maintains persistent gRPC
+    connections for workflow registration, and multiple clients would conflict.
+
+    For testing, use the `reset()` method or the `reset_hatchet_client` fixture
+    to ensure test isolation.
+    """

    _instance: Hatchet | None = None

    @classmethod
    def get_client(cls) -> Hatchet:
-        """Get or create the Hatchet client."""
+        """Get or create the Hatchet client.
+
+        Configures root logger so all logger.info() calls in workflows
+        appear in the Hatchet dashboard logs.
+        """
        if cls._instance is None:
            if not settings.HATCHET_CLIENT_TOKEN:
                raise ValueError("HATCHET_CLIENT_TOKEN must be set")

            # Pass root logger to Hatchet so workflow logs appear in dashboard
            root_logger = logging.getLogger()
            cls._instance = Hatchet(
                debug=settings.HATCHET_DEBUG,
                config=ClientConfig(logger=root_logger),
            )
        return cls._instance

    @classmethod
    async def start_workflow(
-        cls, workflow_name: str, input_data: dict, key: str | None = None
+        cls,
+        workflow_name: str,
+        input_data: dict,
+        additional_metadata: dict | None = None,
    ) -> str:
-        """Start a workflow and return the workflow run ID."""
+        """Start a workflow and return the workflow run ID.
+
+        Args:
+            workflow_name: Name of the workflow to trigger.
+            input_data: Input data for the workflow run.
+            additional_metadata: Optional metadata for filtering in dashboard
+                (e.g., transcript_id, recording_id).
+        """
        client = cls.get_client()
        result = await client.runs.aio_create(
            workflow_name,
            input_data,
            additional_metadata=additional_metadata,
        )
        # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
        return result.run.metadata.id
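A quick sketch of the singleton contract the wrapper enforces; `reset()` is the hook referenced in the docstring above and used by the test fixture later in this diff (assumes `HATCHET_CLIENT_TOKEN` is configured):

```python
# Sketch of the singleton behavior; assumes HATCHET_CLIENT_TOKEN is set so
# get_client() can construct the client.
from reflector.hatchet.client import HatchetClientManager

first = HatchetClientManager.get_client()
second = HatchetClientManager.get_client()
assert first is second  # one client (and one set of gRPC connections) per process

HatchetClientManager.reset()  # drop the cached instance, as tests do between runs
```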
server/reflector/hatchet/workflows/diarization_pipeline.py

@@ -6,9 +6,12 @@ Orchestrates the full processing flow from recording metadata to final transcript
"""

import asyncio
+import functools
import tempfile
+from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
+from typing import Callable

import av
from hatchet_sdk import Context
@@ -16,6 +19,20 @@ from pydantic import BaseModel

from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress_async
+from reflector.hatchet.workflows.models import (
+    ConsentResult,
+    FinalizeResult,
+    MixdownResult,
+    ParticipantsResult,
+    ProcessTracksResult,
+    RecordingResult,
+    SummaryResult,
+    TitleResult,
+    TopicsResult,
+    WaveformResult,
+    WebhookResult,
+    ZulipResult,
+)
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
from reflector.logger import logger
@@ -23,6 +40,7 @@ from reflector.logger import logger
OPUS_STANDARD_SAMPLE_RATE = 48000
OPUS_DEFAULT_BIT_RATE = 64000
PRESIGNED_URL_EXPIRATION_SECONDS = 7200
+WAVEFORM_SEGMENTS = 255


class PipelineInput(BaseModel):
@@ -49,8 +67,9 @@ diarization_pipeline = hatchet.workflow(
# ============================================================================


-async def _get_fresh_db_connection():
-    """Create fresh database connection for subprocess."""
+@asynccontextmanager
+async def fresh_db_connection():
+    """Context manager for database connections in Hatchet workers."""
    import databases

    from reflector.db import _database_context
@@ -60,22 +79,22 @@ async def _get_fresh_db_connection():
    db = databases.Database(settings.DATABASE_URL)
    _database_context.set(db)
    await db.connect()
-    return db
-
-
-async def _close_db_connection(db):
-    """Close database connection."""
-    from reflector.db import _database_context
-
+    try:
+        yield db
+    finally:
+        await db.disconnect()
+        _database_context.set(None)


-async def _set_error_status(transcript_id: str):
-    """Set transcript status to 'error' on workflow failure (matches Celery line 790)."""
-    try:
-        db = await _get_fresh_db_connection()
+async def set_workflow_error_status(transcript_id: str) -> bool:
+    """Set transcript status to 'error' on workflow failure.
+
+    Returns:
+        True if status was set successfully, False if failed.
+        Failure is logged as CRITICAL since it means transcript may be stuck.
+    """
+    try:
+        async with fresh_db_connection():
            from reflector.db.transcripts import transcripts_controller

            await transcripts_controller.set_status(transcript_id, "error")
@@ -83,14 +102,15 @@ async def _set_error_status(transcript_id: str):
                "[Hatchet] Set transcript status to error",
                transcript_id=transcript_id,
            )
-    finally:
-        await _close_db_connection(db)
+        return True
    except Exception as e:
-        logger.error(
-            "[Hatchet] Failed to set error status",
+        logger.critical(
+            "[Hatchet] CRITICAL: Failed to set error status - transcript may be stuck in 'processing'",
            transcript_id=transcript_id,
            error=str(e),
            exc_info=True,
        )
+        return False


def _get_storage():
@@ -106,13 +126,57 @@ def _get_storage():
    )


+def _to_dict(output) -> dict:
+    """Convert task output to dict, handling both dict and Pydantic model returns.
+
+    Hatchet SDK returns Pydantic models when tasks have typed return annotations,
+    but older code expects dicts. This helper normalizes the output.
+    """
+    if isinstance(output, dict):
+        return output
+    return output.model_dump()
+
+
+def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
+    """Decorator that handles task failures uniformly.
+
+    Args:
+        step_name: Name of the step for logging and progress tracking.
+        set_error_status: Whether to set transcript status to 'error' on failure.
+    """
+
+    def decorator(func: Callable) -> Callable:
+        @functools.wraps(func)
+        async def wrapper(input: PipelineInput, ctx: Context):
+            try:
+                return await func(input, ctx)
+            except Exception as e:
+                logger.error(
+                    f"[Hatchet] {step_name} failed",
+                    transcript_id=input.transcript_id,
+                    error=str(e),
+                    exc_info=True,
+                )
+                if set_error_status:
+                    await set_workflow_error_status(input.transcript_id)
+                await emit_progress_async(
+                    input.transcript_id, step_name, "failed", ctx.workflow_run_id
+                )
+                raise
+
+        return wrapper
+
+    return decorator


# ============================================================================
# Pipeline Tasks
# ============================================================================


@diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
-async def get_recording(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("get_recording")
+async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
    """Fetch recording metadata from Daily.co API."""
    logger.info("[Hatchet] get_recording", recording_id=input.recording_id)
@@ -120,9 +184,8 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
    )

-    # Set transcript status to "processing" at workflow start (matches Celery behavior)
-    db = await _get_fresh_db_connection()
-    try:
+    # Set transcript status to "processing" at workflow start
+    async with fresh_db_connection():
        from reflector.db.transcripts import transcripts_controller

        transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -132,10 +195,7 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
            "[Hatchet] Set transcript status to processing",
            transcript_id=input.transcript_id,
        )
-    finally:
-        await _close_db_connection(db)

-    try:
    from reflector.dailyco_api.client import DailyApiClient
    from reflector.settings import settings
@@ -144,12 +204,12 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
        await emit_progress_async(
            input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
        )
-        return {
-            "id": None,
-            "mtg_session_id": None,
-            "room_name": input.room_name,
-            "duration": 0,
-        }
+        return RecordingResult(
+            id=None,
+            mtg_session_id=None,
+            room_name=input.room_name,
+            duration=0,
+        )

    if not settings.DAILY_API_KEY:
        raise ValueError("DAILY_API_KEY not configured")
@@ -168,38 +228,27 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
    )

-        return {
-            "id": recording.id,
-            "mtg_session_id": recording.mtgSessionId,
-            "room_name": recording.room_name,
-            "duration": recording.duration,
-        }
-
-    except Exception as e:
-        logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "get_recording", "failed", ctx.workflow_run_id
-        )
-        raise
+    return RecordingResult(
+        id=recording.id,
+        mtg_session_id=recording.mtgSessionId,
+        room_name=recording.room_name,
+        duration=recording.duration,
+    )


@diarization_pipeline.task(
    parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3
)
-async def get_participants(input: PipelineInput, ctx: Context) -> dict:
-    """Fetch participant list from Daily.co API and update transcript in database.
-
-    Matches Celery's update_participants_from_daily() behavior.
-    """
+@with_error_handling("get_participants")
+async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
+    """Fetch participant list from Daily.co API and update transcript in database."""
    logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "get_participants", "in_progress", ctx.workflow_run_id
    )

-    try:
-        recording_data = ctx.task_output(get_recording)
+    recording_data = _to_dict(ctx.task_output(get_recording))
    mtg_session_id = recording_data.get("mtg_session_id")

    from reflector.dailyco_api.client import DailyApiClient
@@ -209,9 +258,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
        parse_daily_recording_filename,
    )

-        # Get transcript and reset events/topics/participants (matches Celery line 599-607)
-        db = await _get_fresh_db_connection()
-        try:
+    # Get transcript and reset events/topics/participants
+    async with fresh_db_connection():
        from reflector.db.transcripts import (
            TranscriptParticipant,
            transcripts_controller,
@@ -219,8 +267,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:

        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if transcript:
-                # Reset events/topics/participants (matches Celery line 599-607)
-                # Note: title NOT cleared - Celery preserves existing titles
+            # Reset events/topics/participants
+            # Note: title NOT cleared - preserves existing titles
            await transcripts_controller.update(
                transcript,
                {
@@ -237,16 +285,12 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
                "completed",
                ctx.workflow_run_id,
            )
-            return {
-                "participants": [],
-                "num_tracks": len(input.tracks),
-                "source_language": transcript.source_language
-                if transcript
-                else "en",
-                "target_language": transcript.target_language
-                if transcript
-                else "en",
-            }
+            return ParticipantsResult(
+                participants=[],
+                num_tracks=len(input.tracks),
+                source_language=transcript.source_language if transcript else "en",
+                target_language=transcript.target_language if transcript else "en",
+            )

        # Fetch participants from Daily API
        async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
@@ -264,7 +308,7 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
        track_keys = [t["s3_key"] for t in input.tracks]
        cam_audio_keys = filter_cam_audio_tracks(track_keys)

-        # Update participants in database (matches Celery lines 568-590)
+        # Update participants in database
        participants_list = []
        for idx, key in enumerate(cam_audio_keys):
            try:
@@ -299,46 +343,31 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
            participant_count=len(participants_list),
        )

-        finally:
-            await _close_db_connection(db)

    await emit_progress_async(
        input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
    )

-        return {
-            "participants": participants_list,
-            "num_tracks": len(input.tracks),
-            "source_language": transcript.source_language if transcript else "en",
-            "target_language": transcript.target_language if transcript else "en",
-        }
-
-    except Exception as e:
-        logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "get_participants", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ParticipantsResult(
+        participants=participants_list,
+        num_tracks=len(input.tracks),
+        source_language=transcript.source_language if transcript else "en",
+        target_language=transcript.target_language if transcript else "en",
+    )


@diarization_pipeline.task(
    parents=[get_participants], execution_timeout=timedelta(seconds=600), retries=3
)
-async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
-    """Spawn child workflows for each track (dynamic fan-out).
-
-    Processes pad_track and transcribe_track for each audio track in parallel.
-    """
+@with_error_handling("process_tracks")
+async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
+    """Spawn child workflows for each track (dynamic fan-out)."""
    logger.info(
        "[Hatchet] process_tracks",
        num_tracks=len(input.tracks),
        transcript_id=input.transcript_id,
    )

-    try:
-        # Get source_language from get_participants (matches Celery: uses transcript.source_language)
-        participants_data = ctx.task_output(get_participants)
+    # Get source_language from get_participants
+    participants_data = _to_dict(ctx.task_output(get_participants))
    source_language = participants_data.get("source_language", "en")

    # Spawn child workflows for each track with correct language
@@ -361,61 +390,49 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
    # Get target_language for later use in detect_topics
    target_language = participants_data.get("target_language", "en")

-    # Collect all track results
-    all_words = []
+    # Collect results from each track (don't mutate lists while iterating)
+    track_words = []
    padded_urls = []
    created_padded_files = set()

    for result in results:
        transcribe_result = result.get("transcribe_track", {})
-        all_words.extend(transcribe_result.get("words", []))
+        track_words.append(transcribe_result.get("words", []))

        pad_result = result.get("pad_track", {})
        padded_urls.append(pad_result.get("padded_url"))

-        # Track padded files for cleanup (matches Celery line 636-637)
+        # Track padded files for cleanup
        track_index = pad_result.get("track_index")
        if pad_result.get("size", 0) > 0 and track_index is not None:
            # File was created (size > 0 means padding was applied)
            storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
            created_padded_files.add(storage_path)

-    # Sort words by start time
+    # Merge all words and sort by start time
+    all_words = [word for words in track_words for word in words]
    all_words.sort(key=lambda w: w.get("start", 0))

    # NOTE: Cleanup of padded S3 files moved to generate_waveform (after mixdown completes)
    # Mixdown needs the padded files, so we can't delete them here

    logger.info(
        "[Hatchet] process_tracks complete",
        num_tracks=len(input.tracks),
        total_words=len(all_words),
    )

-    return {
-        "all_words": all_words,
-        "padded_urls": padded_urls,
-        "word_count": len(all_words),
-        "num_tracks": len(input.tracks),
-        "target_language": target_language,
-        "created_padded_files": list(
-            created_padded_files
-        ),  # For cleanup after mixdown
-    }
-
-    except Exception as e:
-        logger.error("[Hatchet] process_tracks failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "process_tracks", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ProcessTracksResult(
+        all_words=all_words,
+        padded_urls=padded_urls,
+        word_count=len(all_words),
+        num_tracks=len(input.tracks),
+        target_language=target_language,
+        created_padded_files=list(created_padded_files),
+    )


@diarization_pipeline.task(
    parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3
)
-async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("mixdown_tracks")
+async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
    """Mix all padded tracks into single audio file using PyAV (same as Celery)."""
    logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)
@@ -423,8 +440,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "mixdown_tracks", "in_progress", ctx.workflow_run_id
    )

-    try:
-        track_data = ctx.task_output(process_tracks)
+    track_data = _to_dict(ctx.task_output(process_tracks))
    padded_urls = track_data.get("padded_urls", [])

    if not padded_urls:
@@ -581,7 +597,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
            except Exception:
                pass

-        # Upload mixed file to correct path (matches Celery: {transcript.id}/audio.mp3)
+    # Upload mixed file to storage
    file_size = Path(output_path).stat().st_size
    storage_path = f"{input.transcript_id}/audio.mp3"
@@ -590,9 +606,8 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:

    Path(output_path).unlink(missing_ok=True)

-        # Update transcript with audio_location (matches Celery line 661)
-        db = await _get_fresh_db_connection()
-        try:
+    # Update transcript with audio_location
+    async with fresh_db_connection():
        from reflector.db.transcripts import transcripts_controller

        transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -600,8 +615,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
            await transcripts_controller.update(
                transcript, {"audio_location": "storage"}
            )
-        finally:
-            await _close_db_connection(db)

    logger.info(
        "[Hatchet] mixdown_tracks uploaded",
@@ -613,27 +626,18 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
    )

-        return {
-            "audio_key": storage_path,
-            "duration": duration_ms[
-                0
-            ],  # Duration in milliseconds from AudioFileWriterProcessor
-            "tracks_mixed": len(valid_urls),
-        }
-
-    except Exception as e:
-        logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id
-        )
-        raise
+    return MixdownResult(
+        audio_key=storage_path,
+        duration=duration_ms[0],
+        tracks_mixed=len(valid_urls),
+    )


@diarization_pipeline.task(
    parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3
)
-async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("generate_waveform")
+async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
    """Generate audio waveform visualization using AudioWaveformProcessor (matches Celery)."""
    logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)
@@ -641,15 +645,13 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id
    )

-    try:
    import httpx

    from reflector.db.transcripts import TranscriptWaveform, transcripts_controller
    from reflector.utils.audio_waveform import get_audio_waveform

-        # Cleanup temporary padded S3 files (matches Celery lines 710-725)
-        # Moved here from process_tracks because mixdown_tracks needs the padded files
-        track_data = ctx.task_output(process_tracks)
+    # Cleanup temporary padded S3 files (deferred until after mixdown)
+    track_data = _to_dict(ctx.task_output(process_tracks))
    created_padded_files = track_data.get("created_padded_files", [])
    if created_padded_files:
        logger.info(
@@ -660,9 +662,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
        for storage_path in created_padded_files:
            cleanup_tasks.append(storage.delete_file(storage_path))

-            cleanup_results = await asyncio.gather(
-                *cleanup_tasks, return_exceptions=True
-            )
+        cleanup_results = await asyncio.gather(*cleanup_tasks, return_exceptions=True)
        for storage_path, result in zip(created_padded_files, cleanup_results):
            if isinstance(result, Exception):
                logger.warning(
@@ -671,7 +671,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
                    error=str(result),
                )

-        mixdown_data = ctx.task_output(mixdown_tracks)
+    mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
    audio_key = mixdown_data.get("audio_key")

    storage = _get_storage()
@@ -692,20 +692,19 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
        with open(temp_path, "wb") as f:
            f.write(response.content)

-            # Generate waveform (matches Celery: get_audio_waveform with 255 segments)
-            waveform = get_audio_waveform(path=Path(temp_path), segments_count=255)
+        # Generate waveform
+        waveform = get_audio_waveform(
+            path=Path(temp_path), segments_count=WAVEFORM_SEGMENTS
+        )

-            # Save waveform to database via event (matches Celery on_waveform callback)
-            db = await _get_fresh_db_connection()
-            try:
+        # Save waveform to database via event
+        async with fresh_db_connection():
            transcript = await transcripts_controller.get_by_id(input.transcript_id)
            if transcript:
                waveform_data = TranscriptWaveform(waveform=waveform)
                await transcripts_controller.append_event(
                    transcript=transcript, event="WAVEFORM", data=waveform_data
                )
-            finally:
-                await _close_db_connection(db)

    finally:
        Path(temp_path).unlink(missing_ok=True)
@@ -716,21 +715,14 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
    )

-        return {"waveform_generated": True}
-
-    except Exception as e:
-        logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id
-        )
-        raise
+    return WaveformResult(waveform_generated=True)


@diarization_pipeline.task(
    parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3
)
-async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("detect_topics")
+async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
    """Detect topics using LLM and save to database (matches Celery on_topic callback)."""
    logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)
@@ -738,8 +730,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "detect_topics", "in_progress", ctx.workflow_run_id
    )

-    try:
-        track_data = ctx.task_output(process_tracks)
+    track_data = _to_dict(ctx.task_output(process_tracks))
    words = track_data.get("all_words", [])
    target_language = track_data.get("target_language", "en")
@@ -757,13 +748,10 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:

    empty_pipeline = topic_processing.EmptyPipeline(logger=logger)

-        # Get DB connection for callbacks
-        db = await _get_fresh_db_connection()
-
-        try:
+    async with fresh_db_connection():
        transcript = await transcripts_controller.get_by_id(input.transcript_id)

-            # Callback that upserts topics to DB (matches Celery on_topic)
+        # Callback that upserts topics to DB
        async def on_topic_callback(data):
            topic = TranscriptTopic(
                title=data.title,
@@ -785,8 +773,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
            on_topic_callback=on_topic_callback,
            empty_pipeline=empty_pipeline,
        )
-        finally:
-            await _close_db_connection(db)

    topics_list = [t.model_dump() for t in topics]
@@ -796,21 +782,14 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id
    )

-        return {"topics": topics_list}
-
-    except Exception as e:
-        logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id
-        )
-        raise
+    return TopicsResult(topics=topics_list)


@diarization_pipeline.task(
    parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3
)
-async def generate_title(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("generate_title")
+async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
    """Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
    logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)
@@ -818,8 +797,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "generate_title", "in_progress", ctx.workflow_run_id
    )

-    try:
-        topics_data = ctx.task_output(detect_topics)
+    topics_data = _to_dict(ctx.task_output(detect_topics))
    topics = topics_data.get("topics", [])

    from reflector.db.transcripts import (
@@ -834,11 +812,10 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
    empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
    title_result = None

-        db = await _get_fresh_db_connection()
-        try:
+    async with fresh_db_connection():
        transcript = await transcripts_controller.get_by_id(input.transcript_id)

-            # Callback that updates title in DB (matches Celery on_title)
+        # Callback that updates title in DB
        async def on_title_callback(data):
            nonlocal title_result
            title_result = data.title
@@ -858,8 +835,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
            empty_pipeline=empty_pipeline,
            logger=logger,
        )
-        finally:
-            await _close_db_connection(db)

    logger.info("[Hatchet] generate_title complete", title=title_result)
@@ -867,21 +842,14 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
    )

-        return {"title": title_result}
-
-    except Exception as e:
-        logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "generate_title", "failed", ctx.workflow_run_id
-        )
-        raise
+    return TitleResult(title=title_result)


@diarization_pipeline.task(
    parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3
)
-async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("generate_summary")
+async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
    """Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
    logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)
@@ -889,8 +857,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "generate_summary", "in_progress", ctx.workflow_run_id
    )

-    try:
-        topics_data = ctx.task_output(detect_topics)
+    topics_data = _to_dict(ctx.task_output(detect_topics))
    topics = topics_data.get("topics", [])

    from reflector.db.transcripts import (
@@ -907,11 +874,10 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
    summary_result = None
    short_summary_result = None

-        db = await _get_fresh_db_connection()
-        try:
+    async with fresh_db_connection():
        transcript = await transcripts_controller.get_by_id(input.transcript_id)

-            # Callback that updates long_summary in DB (matches Celery on_long_summary)
+        # Callback that updates long_summary in DB
        async def on_long_summary_callback(data):
            nonlocal summary_result
            summary_result = data.long_summary
@@ -928,7 +894,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
                data=final_long_summary,
            )

-            # Callback that updates short_summary in DB (matches Celery on_short_summary)
+        # Callback that updates short_summary in DB
        async def on_short_summary_callback(data):
            nonlocal short_summary_result
            short_summary_result = data.short_summary
@@ -953,8 +919,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
            empty_pipeline=empty_pipeline,
            logger=logger,
        )
-        finally:
-            await _close_db_connection(db)

    logger.info("[Hatchet] generate_summary complete")
@@ -962,15 +926,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
    )

-        return {"summary": summary_result, "short_summary": short_summary_result}
-
-    except Exception as e:
-        logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id
-        )
-        raise
+    return SummaryResult(summary=summary_result, short_summary=short_summary_result)


@diarization_pipeline.task(
@@ -978,7 +934,8 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
    execution_timeout=timedelta(seconds=60),
    retries=3,
)
-async def finalize(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("finalize")
+async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
    """Finalize transcript: save words, emit TRANSCRIPT event, set status to 'ended'.

    Matches Celery's on_transcript + set_status behavior.
@@ -990,33 +947,28 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "finalize", "in_progress", ctx.workflow_run_id
    )

-    try:
-        mixdown_data = ctx.task_output(mixdown_tracks)
-        track_data = ctx.task_output(process_tracks)
+    mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
+    track_data = _to_dict(ctx.task_output(process_tracks))

    duration = mixdown_data.get("duration", 0)
    all_words = track_data.get("all_words", [])

-        db = await _get_fresh_db_connection()
-
-        try:
+    async with fresh_db_connection():
        from reflector.db.transcripts import TranscriptText, transcripts_controller
        from reflector.processors.types import Transcript as TranscriptType
        from reflector.processors.types import Word

        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if transcript is None:
-                raise ValueError(
-                    f"Transcript {input.transcript_id} not found in database"
-                )
+            raise ValueError(f"Transcript {input.transcript_id} not found in database")

        # Convert words back to Word objects for storage
        word_objects = [Word(**w) for w in all_words]

-            # Create merged transcript for TRANSCRIPT event (matches Celery line 734-736)
+        # Create merged transcript for TRANSCRIPT event
        merged_transcript = TranscriptType(words=word_objects, translation=None)

-            # Emit TRANSCRIPT event (matches Celery on_transcript callback)
+        # Emit TRANSCRIPT event
        await transcripts_controller.append_event(
            transcript=transcript,
            event="TRANSCRIPT",
@@ -1036,35 +988,23 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
            },
        )

-            # Set status to "ended" (matches Celery line 745)
+        # Set status to "ended"
        await transcripts_controller.set_status(input.transcript_id, "ended")

-            logger.info(
-                "[Hatchet] finalize complete", transcript_id=input.transcript_id
-            )
-
-        finally:
-            await _close_db_connection(db)
+        logger.info("[Hatchet] finalize complete", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "finalize", "completed", ctx.workflow_run_id
    )

-        return {"status": "COMPLETED"}
-
-    except Exception as e:
-        logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "finalize", "failed", ctx.workflow_run_id
-        )
-        raise
+    return FinalizeResult(status="COMPLETED")


@diarization_pipeline.task(
    parents=[finalize], execution_timeout=timedelta(seconds=60), retries=3
)
-async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("cleanup_consent", set_error_status=False)
+async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
    """Check and handle consent requirements."""
    logger.info("[Hatchet] cleanup_consent", transcript_id=input.transcript_id)
@@ -1072,10 +1012,7 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "cleanup_consent", "in_progress", ctx.workflow_run_id
    )

-    try:
-        db = await _get_fresh_db_connection()
-
-        try:
+    async with fresh_db_connection():
        from reflector.db.meetings import meetings_controller
        from reflector.db.transcripts import transcripts_controller
@@ -1091,27 +1028,18 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
            "[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id
        )

-        finally:
-            await _close_db_connection(db)

    await emit_progress_async(
        input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id
    )

-        return {"consent_checked": True}
-
-    except Exception as e:
-        logger.error("[Hatchet] cleanup_consent failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "cleanup_consent", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ConsentResult(consent_checked=True)


@diarization_pipeline.task(
    parents=[cleanup_consent], execution_timeout=timedelta(seconds=60), retries=5
)
-async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("post_zulip", set_error_status=False)
+async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
    """Post notification to Zulip."""
    logger.info("[Hatchet] post_zulip", transcript_id=input.transcript_id)
@@ -1119,7 +1047,6 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "post_zulip", "in_progress", ctx.workflow_run_id
    )

-    try:
    from reflector.settings import settings

    if not settings.ZULIP_REALM:
@@ -1127,45 +1054,32 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
        await emit_progress_async(
            input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
        )
-            return {"zulip_message_id": None, "skipped": True}
+        return ZulipResult(zulip_message_id=None, skipped=True)

    from reflector.zulip import post_transcript_notification

-        db = await _get_fresh_db_connection()
-
-        try:
+    async with fresh_db_connection():
        from reflector.db.transcripts import transcripts_controller

        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if transcript:
            message_id = await post_transcript_notification(transcript)
-                logger.info(
-                    "[Hatchet] post_zulip complete", zulip_message_id=message_id
-                )
+            logger.info("[Hatchet] post_zulip complete", zulip_message_id=message_id)
        else:
            message_id = None

-        finally:
-            await _close_db_connection(db)

    await emit_progress_async(
        input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
    )

-        return {"zulip_message_id": message_id}
-
-    except Exception as e:
-        logger.error("[Hatchet] post_zulip failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "post_zulip", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ZulipResult(zulip_message_id=message_id)


@diarization_pipeline.task(
    parents=[post_zulip], execution_timeout=timedelta(seconds=120), retries=30
)
-async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("send_webhook", set_error_status=False)
+async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
    """Send completion webhook to external service."""
    logger.info("[Hatchet] send_webhook", transcript_id=input.transcript_id)
@@ -1173,17 +1087,14 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
        input.transcript_id, "send_webhook", "in_progress", ctx.workflow_run_id
    )

-    try:
    if not input.room_id:
        logger.info("[Hatchet] send_webhook skipped (no room_id)")
        await emit_progress_async(
            input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
        )
-            return {"webhook_sent": False, "skipped": True}
+        return WebhookResult(webhook_sent=False, skipped=True)

-        db = await _get_fresh_db_connection()
-
-        try:
+    async with fresh_db_connection():
        from reflector.db.rooms import rooms_controller
        from reflector.db.transcripts import transcripts_controller
@@ -1217,20 +1128,10 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
                ctx.workflow_run_id,
            )

-                return {"webhook_sent": True, "response_code": response.status_code}
-
-        finally:
-            await _close_db_connection(db)
+            return WebhookResult(webhook_sent=True, response_code=response.status_code)

    await emit_progress_async(
        input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
    )

-        return {"webhook_sent": False, "skipped": True}
-
-    except Exception as e:
-        logger.error("[Hatchet] send_webhook failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "send_webhook", "failed", ctx.workflow_run_id
-        )
-        raise
+    return WebhookResult(webhook_sent=False, skipped=True)
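For orientation, the `parents=[...]` declarations in the hunks above wire the pipeline as: get_recording → get_participants → process_tracks → mixdown_tracks, which fans out to generate_waveform and detect_topics; detect_topics feeds both generate_title and generate_summary; finalize (whose own parents line falls outside the shown hunks) is followed by cleanup_consent → post_zulip → send_webhook.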
server/reflector/hatchet/workflows/models.py | 123 (new file)
@@ -0,0 +1,123 @@
"""
Pydantic models for Hatchet workflow task return types.

Provides static typing for all task outputs, enabling type checking
and better IDE support.
"""

from typing import Any

from pydantic import BaseModel

# ============================================================================
# Track Processing Results (track_processing.py)
# ============================================================================


class PadTrackResult(BaseModel):
    """Result from pad_track task."""

    padded_url: str
    size: int
    track_index: int


class TranscribeTrackResult(BaseModel):
    """Result from transcribe_track task."""

    words: list[dict[str, Any]]
    track_index: int


# ============================================================================
# Diarization Pipeline Results (diarization_pipeline.py)
# ============================================================================


class RecordingResult(BaseModel):
    """Result from get_recording task."""

    id: str | None
    mtg_session_id: str | None
    room_name: str | None
    duration: float


class ParticipantsResult(BaseModel):
    """Result from get_participants task."""

    participants: list[dict[str, Any]]
    num_tracks: int
    source_language: str
    target_language: str


class ProcessTracksResult(BaseModel):
    """Result from process_tracks task."""

    all_words: list[dict[str, Any]]
    padded_urls: list[str | None]
    word_count: int
    num_tracks: int
    target_language: str
    created_padded_files: list[str]


class MixdownResult(BaseModel):
    """Result from mixdown_tracks task."""

    audio_key: str
    duration: float
    tracks_mixed: int


class WaveformResult(BaseModel):
    """Result from generate_waveform task."""

    waveform_generated: bool


class TopicsResult(BaseModel):
    """Result from detect_topics task."""

    topics: list[dict[str, Any]]


class TitleResult(BaseModel):
    """Result from generate_title task."""

    title: str | None


class SummaryResult(BaseModel):
    """Result from generate_summary task."""

    summary: str | None
    short_summary: str | None


class FinalizeResult(BaseModel):
    """Result from finalize task."""

    status: str


class ConsentResult(BaseModel):
    """Result from cleanup_consent task."""

    consent_checked: bool


class ZulipResult(BaseModel):
    """Result from post_zulip task."""

    zulip_message_id: int | None = None
    skipped: bool = False


class WebhookResult(BaseModel):
    """Result from send_webhook task."""

    webhook_sent: bool
    skipped: bool = False
    response_code: int | None = None
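A small illustration of why the `_to_dict()` helper pairs with these models: depending on how the SDK hands results to a dependent task, the output may arrive as a model instance or as its dict form, and the helper normalizes both (values here are illustrative):

```python
# Illustrative round-trip between a typed task result and the dict form that
# older call sites expect; _to_dict() accepts either shape.
from reflector.hatchet.workflows.models import ZulipResult

result = ZulipResult(zulip_message_id=None, skipped=True)
assert result.model_dump() == {"zulip_message_id": None, "skipped": True}

# _to_dict(result) returns the same dict; _to_dict({"skipped": True})
# passes a plain dict through unchanged.
```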
server/reflector/hatchet/workflows/track_processing.py

@@ -18,8 +18,17 @@ from pydantic import BaseModel

from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress_async
+from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
from reflector.logger import logger


+def _to_dict(output) -> dict:
+    """Convert task output to dict, handling both dict and Pydantic model returns."""
+    if isinstance(output, dict):
+        return output
+    return output.model_dump()
+
+
# Audio constants matching existing pipeline
OPUS_STANDARD_SAMPLE_RATE = 48000
OPUS_DEFAULT_BIT_RATE = 64000
@@ -161,7 +170,7 @@ def _apply_audio_padding_to_file(


@track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
-async def pad_track(input: TrackInput, ctx: Context) -> dict:
+async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
    """Pad single audio track with silence for alignment.

    Extracts stream.start_time from WebM container metadata and applies
@@ -213,11 +222,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
        await emit_progress_async(
            input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
        )
-        return {
-            "padded_url": source_url,
-            "size": 0,
-            "track_index": input.track_index,
-        }
+        return PadTrackResult(
+            padded_url=source_url,
+            size=0,
+            track_index=input.track_index,
+        )

    # Create temp file for padded output
    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
@@ -265,11 +274,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
            input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
        )

-        return {
-            "padded_url": padded_url,
-            "size": file_size,
-            "track_index": input.track_index,
-        }
+        return PadTrackResult(
+            padded_url=padded_url,
+            size=file_size,
+            track_index=input.track_index,
+        )

    except Exception as e:
        logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
@@ -282,7 +291,7 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
@track_workflow.task(
    parents=[pad_track], execution_timeout=timedelta(seconds=600), retries=3
)
-async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
+async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
    """Transcribe audio track using GPU (Modal.com) or local Whisper."""
    logger.info(
        "[Hatchet] transcribe_track",
@@ -295,7 +304,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
    )

    try:
-        pad_result = ctx.task_output(pad_track)
+        pad_result = _to_dict(ctx.task_output(pad_track))
        audio_url = pad_result.get("padded_url")

        if not audio_url:
@@ -324,10 +333,10 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
            input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
        )

-        return {
-            "words": words,
-            "track_index": input.track_index,
-        }
+        return TranscribeTrackResult(
+            words=words,
+            track_index=input.track_index,
+        )

    except Exception as e:
        logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)
@@ -224,6 +224,26 @@ def dispatch_transcript_processing(
                transcript, {"workflow_run_id": None}
            )

+            # Re-fetch transcript to check for concurrent dispatch (TOCTOU protection)
+            transcript = await transcripts_controller.get_by_id(
+                config.transcript_id
+            )
+            if transcript and transcript.workflow_run_id:
+                # Another process started a workflow between validation and now
+                try:
+                    status = await HatchetClientManager.get_workflow_run_status(
+                        transcript.workflow_run_id
+                    )
+                    if "RUNNING" in status or "QUEUED" in status:
+                        logger.info(
+                            "Concurrent workflow detected, skipping dispatch",
+                            workflow_id=transcript.workflow_run_id,
+                        )
+                        return transcript.workflow_run_id
+                except Exception:
+                    # If we can't get status, proceed with new workflow
+                    pass
+
            workflow_id = await HatchetClientManager.start_workflow(
                workflow_name="DiarizationPipeline",
                input_data={
@@ -234,6 +254,11 @@
                    "transcript_id": config.transcript_id,
                    "room_id": config.room_id,
                },
+                additional_metadata={
+                    "transcript_id": config.transcript_id,
+                    "recording_id": config.recording_id,
+                    "daily_recording_id": config.recording_id,
+                },
            )

            if transcript:

@@ -302,6 +302,11 @@ async def _process_multitrack_recording_inner(
                "transcript_id": transcript.id,
                "room_id": room.id,
            },
+            additional_metadata={
+                "transcript_id": transcript.id,
+                "recording_id": recording_id,
+                "daily_recording_id": recording_id,
+            },
        )
        logger.info(
            "Started Hatchet workflow",
@@ -527,6 +527,22 @@ def fake_mp3_upload():
    yield


+@pytest.fixture(autouse=True)
+def reset_hatchet_client():
+    """Reset HatchetClientManager singleton before and after each test.
+
+    This ensures test isolation - each test starts with a fresh client state.
+    The fixture is autouse=True so it applies to all tests automatically.
+    """
+    from reflector.hatchet.client import HatchetClientManager
+
+    # Reset before test
+    HatchetClientManager.reset()
+    yield
+    # Reset after test to clean up
+    HatchetClientManager.reset()
+
+
@pytest.fixture
async def fake_transcript_with_topics(tmpdir, client):
    import shutil

@@ -2,6 +2,9 @@
Tests for HatchetClientManager error handling and validation.

Only tests that catch real bugs - not mock verification tests.
+
+Note: The `reset_hatchet_client` fixture (autouse=True in conftest.py)
+automatically resets the singleton before and after each test.
"""

from unittest.mock import AsyncMock, MagicMock, patch
@@ -18,8 +21,6 @@ async def test_hatchet_client_can_replay_handles_exception():
    """
    from reflector.hatchet.client import HatchetClientManager

-    HatchetClientManager._instance = None
-
    with patch("reflector.hatchet.client.settings") as mock_settings:
        mock_settings.HATCHET_CLIENT_TOKEN = "test-token"
        mock_settings.HATCHET_DEBUG = False
@@ -37,8 +38,6 @@ async def test_hatchet_client_can_replay_handles_exception():
        # Should return False on error (workflow might be gone)
        assert can_replay is False

-    HatchetClientManager._instance = None
-

def test_hatchet_client_raises_without_token():
    """Test that get_client raises ValueError without token.
@@ -48,12 +47,8 @@ def test_hatchet_client_raises_without_token():
    """
    from reflector.hatchet.client import HatchetClientManager

-    HatchetClientManager._instance = None
-
    with patch("reflector.hatchet.client.settings") as mock_settings:
        mock_settings.HATCHET_CLIENT_TOKEN = None

        with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"):
            HatchetClientManager.get_client()

-    HatchetClientManager._instance = None