mirror of https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00

Commit: self-review (no-mistakes)

Diffstat: TASKS.md (115 lines deleted), server/reflector/hatchet/client.py, server/reflector/hatchet/workflows/diarization_pipeline.py
TASKS.md @@ -1,115 +0,0 @@ (file deleted; former contents below)

# Durable Workflow Migration Tasks

This document defines atomic, isolated work items for migrating the Daily.co multitrack diarization pipeline from Celery to durable workflow orchestration using **Hatchet**.

---

## Provider Selection

```bash
# .env
DURABLE_WORKFLOW_PROVIDER=none     # Celery only (default)
DURABLE_WORKFLOW_PROVIDER=hatchet  # Use Hatchet
DURABLE_WORKFLOW_SHADOW_MODE=true  # Run both Hatchet + Celery (for comparison)
```
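A dispatch along these lines is implied by the settings above; a minimal sketch, assuming the setting names mirror the env vars — the trigger helpers are hypothetical, not code from this repo:

```python
# Hypothetical provider dispatch; only the two settings mirror the
# documented env vars, the helper functions are illustrative.
from reflector.settings import settings

async def trigger_pipeline(recording: dict) -> None:
    if settings.DURABLE_WORKFLOW_PROVIDER == "hatchet":
        await start_hatchet_workflow(recording)   # hypothetical helper
        if settings.DURABLE_WORKFLOW_SHADOW_MODE:
            start_celery_pipeline(recording)      # shadow: run both for comparison
    else:
        start_celery_pipeline(recording)          # Celery only (default)
```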
---

## Task Index

| ID | Title | Status |
|----|-------|--------|
| INFRA-001 | Add container to docker-compose | Done |
| INFRA-002 | Create Python client wrapper | Done |
| INFRA-003 | Add environment configuration | Done |
| TASK-001 | Create workflow definition | Done |
| TASK-002 | get_recording task | Done |
| TASK-003 | get_participants task | Done |
| TASK-004 | pad_track task | Done |
| TASK-005 | mixdown_tracks task | Done |
| TASK-006 | generate_waveform task | Done |
| TASK-007 | transcribe_track task | Done |
| TASK-008 | merge_transcripts task | Done (in process_tracks) |
| TASK-009 | detect_topics task | Done |
| TASK-010 | generate_title task | Done |
| TASK-011 | generate_summary task | Done |
| TASK-012 | finalize task | Done |
| TASK-013 | cleanup_consent task | Done |
| TASK-014 | post_zulip task | Done |
| TASK-015 | send_webhook task | Done |
| EVENT-001 | Progress WebSocket events | Done |
| INTEG-001 | Pipeline trigger integration | Done |
| SHADOW-001 | Shadow mode toggle | Done |
| TEST-001 | Integration tests | Pending |
| TEST-002 | E2E workflow test | Pending |
| CUTOVER-001 | Production cutover | Pending |
| CLEANUP-001 | Remove Celery code | Pending |
---

## File Structure

```
server/reflector/hatchet/
├── client.py                       # SDK wrapper
├── progress.py                     # WebSocket progress emission
├── run_workers.py                  # Worker startup
└── workflows/
    ├── diarization_pipeline.py     # Main workflow with all tasks
    └── track_processing.py         # Child workflow (pad + transcribe)
```
---

## Remaining Work

### TEST-001: Integration Tests
- [ ] Test each task with mocked external services
- [ ] Test error handling and retries

### TEST-002: E2E Workflow Test
- [ ] Complete workflow run with real Daily.co recording
- [ ] Verify output matches Celery pipeline
- [ ] Performance comparison

### CUTOVER-001: Production Cutover
- [ ] Deploy with `DURABLE_WORKFLOW_PROVIDER=hatchet`
- [ ] Monitor for failures
- [ ] Compare results with shadow mode if needed

### CLEANUP-001: Remove Celery Code
- [ ] Remove `main_multitrack_pipeline.py`
- [ ] Remove Celery task triggers
- [ ] Update documentation

---

## Known Issues

### Hatchet
- See `HATCHET_LLM_OBSERVATIONS.md` for debugging notes
- SDK v1.21+ API changes (breaking)
- JWT token Docker networking issues
- Worker appears hung without debug mode
- Workflow replay is version-locked (use `--force` to run latest code)

---

## Quick Start

### Hatchet
```bash
# Start infrastructure
docker compose up -d hatchet hatchet-worker

# Workers auto-register on startup
```

### Trigger Workflow
```bash
# Set provider in .env
DURABLE_WORKFLOW_PROVIDER=hatchet

# Process a Daily.co recording via webhook or API;
# the pipeline trigger automatically uses the configured provider.
```
server/reflector/hatchet/client.py

```diff
@@ -1,37 +1,71 @@
-"""Hatchet Python client wrapper."""
+"""Hatchet Python client wrapper.
+
+Uses singleton pattern because:
+1. Hatchet client maintains persistent gRPC connections for workflow registration
+2. Creating multiple clients would cause registration conflicts and resource leaks
+3. The SDK is designed for a single client instance per process
+4. Tests use `HatchetClientManager.reset()` to isolate state between tests
+"""
+
+import logging

-from hatchet_sdk import Hatchet
+from hatchet_sdk import ClientConfig, Hatchet

 from reflector.logger import logger
 from reflector.settings import settings


 class HatchetClientManager:
-    """Singleton manager for Hatchet client connections."""
+    """Singleton manager for Hatchet client connections.
+
+    Singleton pattern is used because Hatchet SDK maintains persistent gRPC
+    connections for workflow registration, and multiple clients would conflict.
+
+    For testing, use the `reset()` method or the `reset_hatchet_client` fixture
+    to ensure test isolation.
+    """

     _instance: Hatchet | None = None

     @classmethod
     def get_client(cls) -> Hatchet:
-        """Get or create the Hatchet client."""
+        """Get or create the Hatchet client.
+
+        Configures root logger so all logger.info() calls in workflows
+        appear in the Hatchet dashboard logs.
+        """
         if cls._instance is None:
             if not settings.HATCHET_CLIENT_TOKEN:
                 raise ValueError("HATCHET_CLIENT_TOKEN must be set")
+
+            # Pass root logger to Hatchet so workflow logs appear in dashboard
+            root_logger = logging.getLogger()
             cls._instance = Hatchet(
                 debug=settings.HATCHET_DEBUG,
+                config=ClientConfig(logger=root_logger),
             )
         return cls._instance

     @classmethod
     async def start_workflow(
-        cls, workflow_name: str, input_data: dict, key: str | None = None
+        cls,
+        workflow_name: str,
+        input_data: dict,
+        additional_metadata: dict | None = None,
     ) -> str:
-        """Start a workflow and return the workflow run ID."""
+        """Start a workflow and return the workflow run ID.
+
+        Args:
+            workflow_name: Name of the workflow to trigger.
+            input_data: Input data for the workflow run.
+            additional_metadata: Optional metadata for filtering in dashboard
+                (e.g., transcript_id, recording_id).
+        """
         client = cls.get_client()
         result = await client.runs.aio_create(
             workflow_name,
             input_data,
+            additional_metadata=additional_metadata,
         )
         # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
         return result.run.metadata.id
```
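For reference, a usage sketch of the wrapper above — the workflow name and input keys are illustrative only, not confirmed by this diff:

```python
# Hypothetical trigger using HatchetClientManager.start_workflow as defined
# above; the workflow name and input keys are placeholders.
from reflector.hatchet.client import HatchetClientManager

async def trigger_diarization(transcript_id: str, recording_id: str) -> str:
    return await HatchetClientManager.start_workflow(
        "diarization_pipeline",  # hypothetical registered workflow name
        {"transcript_id": transcript_id, "recording_id": recording_id},
        # additional_metadata makes runs filterable in the Hatchet dashboard
        additional_metadata={"transcript_id": transcript_id},
    )
```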
server/reflector/hatchet/workflows/diarization_pipeline.py

```diff
@@ -6,9 +6,12 @@ Orchestrates the full processing flow from recording metadata to final transcript
 """

 import asyncio
+import functools
 import tempfile
+from contextlib import asynccontextmanager
 from datetime import timedelta
 from pathlib import Path
+from typing import Callable

 import av
 from hatchet_sdk import Context
@@ -16,6 +19,20 @@ from pydantic import BaseModel

 from reflector.hatchet.client import HatchetClientManager
 from reflector.hatchet.progress import emit_progress_async
+from reflector.hatchet.workflows.models import (
+    ConsentResult,
+    FinalizeResult,
+    MixdownResult,
+    ParticipantsResult,
+    ProcessTracksResult,
+    RecordingResult,
+    SummaryResult,
+    TitleResult,
+    TopicsResult,
+    WaveformResult,
+    WebhookResult,
+    ZulipResult,
+)
 from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
 from reflector.logger import logger
```
```diff
@@ -23,6 +40,7 @@ from reflector.logger import logger
 OPUS_STANDARD_SAMPLE_RATE = 48000
 OPUS_DEFAULT_BIT_RATE = 64000
 PRESIGNED_URL_EXPIRATION_SECONDS = 7200
+WAVEFORM_SEGMENTS = 255


 class PipelineInput(BaseModel):
```
```diff
@@ -49,8 +67,9 @@ diarization_pipeline = hatchet.workflow(
 # ============================================================================


-async def _get_fresh_db_connection():
-    """Create fresh database connection for subprocess."""
+@asynccontextmanager
+async def fresh_db_connection():
+    """Context manager for database connections in Hatchet workers."""
     import databases

     from reflector.db import _database_context
@@ -60,22 +79,22 @@
     db = databases.Database(settings.DATABASE_URL)
     _database_context.set(db)
     await db.connect()
-    return db
-
-
-async def _close_db_connection(db):
-    """Close database connection."""
-    from reflector.db import _database_context
-
-    await db.disconnect()
-    _database_context.set(None)
+    try:
+        yield db
+    finally:
+        await db.disconnect()
+        _database_context.set(None)


-async def _set_error_status(transcript_id: str):
-    """Set transcript status to 'error' on workflow failure (matches Celery line 790)."""
-    try:
-        db = await _get_fresh_db_connection()
-        try:
+async def set_workflow_error_status(transcript_id: str) -> bool:
+    """Set transcript status to 'error' on workflow failure.
+
+    Returns:
+        True if status was set successfully, False if failed.
+        Failure is logged as CRITICAL since it means transcript may be stuck.
+    """
+    try:
+        async with fresh_db_connection():
             from reflector.db.transcripts import transcripts_controller

             await transcripts_controller.set_status(transcript_id, "error")
@@ -83,14 +102,15 @@
             logger.info(
                 "[Hatchet] Set transcript status to error",
                 transcript_id=transcript_id,
             )
-        finally:
-            await _close_db_connection(db)
+            return True
     except Exception as e:
-        logger.error(
-            "[Hatchet] Failed to set error status",
+        logger.critical(
+            "[Hatchet] CRITICAL: Failed to set error status - transcript may be stuck in 'processing'",
             transcript_id=transcript_id,
             error=str(e),
+            exc_info=True,
         )
+        return False


 def _get_storage():
```
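The context manager replaces the paired `_get_fresh_db_connection()` / `_close_db_connection()` helpers; a usage sketch of the pattern adopted throughout the hunks below:

```python
# Sketch: disconnect and context reset are guaranteed even if the body
# raises, which the old helper pair only achieved when every caller
# remembered its own try/finally.
async with fresh_db_connection():
    from reflector.db.transcripts import transcripts_controller
    await transcripts_controller.set_status(transcript_id, "error")
```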
```diff
@@ -106,13 +126,57 @@ def _get_storage():
     )


+def _to_dict(output) -> dict:
+    """Convert task output to dict, handling both dict and Pydantic model returns.
+
+    Hatchet SDK returns Pydantic models when tasks have typed return annotations,
+    but older code expects dicts. This helper normalizes the output.
+    """
+    if isinstance(output, dict):
+        return output
+    return output.model_dump()
+
+
+def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
+    """Decorator that handles task failures uniformly.
+
+    Args:
+        step_name: Name of the step for logging and progress tracking.
+        set_error_status: Whether to set transcript status to 'error' on failure.
+    """
+
+    def decorator(func: Callable) -> Callable:
+        @functools.wraps(func)
+        async def wrapper(input: PipelineInput, ctx: Context):
+            try:
+                return await func(input, ctx)
+            except Exception as e:
+                logger.error(
+                    f"[Hatchet] {step_name} failed",
+                    transcript_id=input.transcript_id,
+                    error=str(e),
+                    exc_info=True,
+                )
+                if set_error_status:
+                    await set_workflow_error_status(input.transcript_id)
+                await emit_progress_async(
+                    input.transcript_id, step_name, "failed", ctx.workflow_run_id
+                )
+                raise
+
+        return wrapper
+
+    return decorator
+
+
 # ============================================================================
 # Pipeline Tasks
 # ============================================================================


 @diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
-async def get_recording(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("get_recording")
+async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
     """Fetch recording metadata from Daily.co API."""
     logger.info("[Hatchet] get_recording", recording_id=input.recording_id)
```
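Applied to a task, `with_error_handling` sits between Hatchet's task registration and the function, so what gets registered is the wrapped coroutine with uniform logging, optional error-status update, and a "failed" progress event; a sketch for a hypothetical new step:

```python
# Hypothetical step showing the decorator stacking used throughout this
# diff: task registration outermost, error handling applied first.
@diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
@with_error_handling("my_new_step")
async def my_new_step(input: PipelineInput, ctx: Context) -> dict:
    ...
```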
```diff
@@ -120,9 +184,8 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
     )

-    # Set transcript status to "processing" at workflow start (matches Celery behavior)
-    db = await _get_fresh_db_connection()
-    try:
+    # Set transcript status to "processing" at workflow start
+    async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller

         transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -132,10 +195,7 @@
                 "[Hatchet] Set transcript status to processing",
                 transcript_id=input.transcript_id,
             )
-    finally:
-        await _close_db_connection(db)

-    try:
     from reflector.dailyco_api.client import DailyApiClient
     from reflector.settings import settings
@@ -144,12 +204,12 @@
         await emit_progress_async(
             input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
         )
-        return {
-            "id": None,
-            "mtg_session_id": None,
-            "room_name": input.room_name,
-            "duration": 0,
-        }
+        return RecordingResult(
+            id=None,
+            mtg_session_id=None,
+            room_name=input.room_name,
+            duration=0,
+        )

     if not settings.DAILY_API_KEY:
         raise ValueError("DAILY_API_KEY not configured")
@@ -168,38 +228,27 @@
     await emit_progress_async(
         input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
     )

-        return {
-            "id": recording.id,
-            "mtg_session_id": recording.mtgSessionId,
-            "room_name": recording.room_name,
-            "duration": recording.duration,
-        }
-
-    except Exception as e:
-        logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "get_recording", "failed", ctx.workflow_run_id
-        )
-        raise
+    return RecordingResult(
+        id=recording.id,
+        mtg_session_id=recording.mtgSessionId,
+        room_name=recording.room_name,
+        duration=recording.duration,
+    )


 @diarization_pipeline.task(
     parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3
 )
-async def get_participants(input: PipelineInput, ctx: Context) -> dict:
-    """Fetch participant list from Daily.co API and update transcript in database.
-
-    Matches Celery's update_participants_from_daily() behavior.
-    """
+@with_error_handling("get_participants")
+async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
+    """Fetch participant list from Daily.co API and update transcript in database."""
     logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)

     await emit_progress_async(
         input.transcript_id, "get_participants", "in_progress", ctx.workflow_run_id
     )

-    try:
-        recording_data = ctx.task_output(get_recording)
+    recording_data = _to_dict(ctx.task_output(get_recording))
     mtg_session_id = recording_data.get("mtg_session_id")

     from reflector.dailyco_api.client import DailyApiClient
```
```diff
@@ -209,9 +258,8 @@
         parse_daily_recording_filename,
     )

-    # Get transcript and reset events/topics/participants (matches Celery line 599-607)
-    db = await _get_fresh_db_connection()
-    try:
+    # Get transcript and reset events/topics/participants
+    async with fresh_db_connection():
         from reflector.db.transcripts import (
             TranscriptParticipant,
             transcripts_controller,
@@ -219,8 +267,8 @@

         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
-            # Reset events/topics/participants (matches Celery line 599-607)
-            # Note: title NOT cleared - Celery preserves existing titles
+            # Reset events/topics/participants
+            # Note: title NOT cleared - preserves existing titles
             await transcripts_controller.update(
                 transcript,
                 {
@@ -237,16 +285,12 @@
                 "completed",
                 ctx.workflow_run_id,
             )
-            return {
-                "participants": [],
-                "num_tracks": len(input.tracks),
-                "source_language": transcript.source_language
-                if transcript
-                else "en",
-                "target_language": transcript.target_language
-                if transcript
-                else "en",
-            }
+            return ParticipantsResult(
+                participants=[],
+                num_tracks=len(input.tracks),
+                source_language=transcript.source_language if transcript else "en",
+                target_language=transcript.target_language if transcript else "en",
+            )

         # Fetch participants from Daily API
         async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
@@ -264,7 +308,7 @@
             track_keys = [t["s3_key"] for t in input.tracks]
             cam_audio_keys = filter_cam_audio_tracks(track_keys)

-            # Update participants in database (matches Celery lines 568-590)
+            # Update participants in database
             participants_list = []
             for idx, key in enumerate(cam_audio_keys):
                 try:
@@ -299,46 +343,31 @@
                 participant_count=len(participants_list),
             )

-    finally:
-        await _close_db_connection(db)
-
     await emit_progress_async(
         input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
     )

-    return {
-        "participants": participants_list,
-        "num_tracks": len(input.tracks),
-        "source_language": transcript.source_language if transcript else "en",
-        "target_language": transcript.target_language if transcript else "en",
-    }
-
-    except Exception as e:
-        logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "get_participants", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ParticipantsResult(
+        participants=participants_list,
+        num_tracks=len(input.tracks),
+        source_language=transcript.source_language if transcript else "en",
+        target_language=transcript.target_language if transcript else "en",
+    )


 @diarization_pipeline.task(
     parents=[get_participants], execution_timeout=timedelta(seconds=600), retries=3
 )
-async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
-    """Spawn child workflows for each track (dynamic fan-out).
-
-    Processes pad_track and transcribe_track for each audio track in parallel.
-    """
+@with_error_handling("process_tracks")
+async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
+    """Spawn child workflows for each track (dynamic fan-out)."""
     logger.info(
         "[Hatchet] process_tracks",
         num_tracks=len(input.tracks),
         transcript_id=input.transcript_id,
     )

-    try:
-        # Get source_language from get_participants (matches Celery: uses transcript.source_language)
-        participants_data = ctx.task_output(get_participants)
+    participants_data = _to_dict(ctx.task_output(get_participants))
     source_language = participants_data.get("source_language", "en")

     # Spawn child workflows for each track with correct language
```
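The spawning code itself falls between hunks; the shape of the fan-out, as a sketch — `aio_run` on the child workflow and the `TrackInput` fields are assumptions, not shown in this diff:

```python
# Assumed fan-out shape: one child run per track, gathered in parallel.
# Each child result is a dict keyed by its task names ("pad_track",
# "transcribe_track"), which is how the collection loop below reads it.
coros = [
    track_workflow.aio_run(  # assumption: v1 SDK child-run API
        TrackInput(  # field names are illustrative
            track_index=idx,
            s3_key=track["s3_key"],
            transcript_id=input.transcript_id,
            source_language=source_language,
        )
    )
    for idx, track in enumerate(input.tracks)
]
results = await asyncio.gather(*coros)
```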
```diff
@@ -361,61 +390,49 @@
     # Get target_language for later use in detect_topics
     target_language = participants_data.get("target_language", "en")

-    # Collect all track results
-    all_words = []
+    # Collect results from each track (don't mutate lists while iterating)
+    track_words = []
     padded_urls = []
     created_padded_files = set()

     for result in results:
         transcribe_result = result.get("transcribe_track", {})
-        all_words.extend(transcribe_result.get("words", []))
+        track_words.append(transcribe_result.get("words", []))

         pad_result = result.get("pad_track", {})
         padded_urls.append(pad_result.get("padded_url"))

-        # Track padded files for cleanup (matches Celery line 636-637)
+        # Track padded files for cleanup
         track_index = pad_result.get("track_index")
         if pad_result.get("size", 0) > 0 and track_index is not None:
-            # File was created (size > 0 means padding was applied)
             storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
             created_padded_files.add(storage_path)

-    # Sort words by start time
+    # Merge all words and sort by start time
+    all_words = [word for words in track_words for word in words]
     all_words.sort(key=lambda w: w.get("start", 0))

-    # NOTE: Cleanup of padded S3 files moved to generate_waveform (after mixdown completes)
-    # Mixdown needs the padded files, so we can't delete them here
-
     logger.info(
         "[Hatchet] process_tracks complete",
         num_tracks=len(input.tracks),
         total_words=len(all_words),
     )

-    return {
-        "all_words": all_words,
-        "padded_urls": padded_urls,
-        "word_count": len(all_words),
-        "num_tracks": len(input.tracks),
-        "target_language": target_language,
-        "created_padded_files": list(
-            created_padded_files
-        ),  # For cleanup after mixdown
-    }
-
-    except Exception as e:
-        logger.error("[Hatchet] process_tracks failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "process_tracks", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ProcessTracksResult(
+        all_words=all_words,
+        padded_urls=padded_urls,
+        word_count=len(all_words),
+        num_tracks=len(input.tracks),
+        target_language=target_language,
+        created_padded_files=list(created_padded_files),
+    )


 @diarization_pipeline.task(
     parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3
 )
-async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("mixdown_tracks")
+async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     """Mix all padded tracks into single audio file using PyAV (same as Celery)."""
     logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)
```
```diff
@@ -423,8 +440,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "mixdown_tracks", "in_progress", ctx.workflow_run_id
     )

-    try:
-        track_data = ctx.task_output(process_tracks)
+    track_data = _to_dict(ctx.task_output(process_tracks))
     padded_urls = track_data.get("padded_urls", [])

     if not padded_urls:
@@ -581,7 +597,7 @@
         except Exception:
             pass

-    # Upload mixed file to correct path (matches Celery: {transcript.id}/audio.mp3)
+    # Upload mixed file to storage
     file_size = Path(output_path).stat().st_size
     storage_path = f"{input.transcript_id}/audio.mp3"
@@ -590,9 +606,8 @@

     Path(output_path).unlink(missing_ok=True)

-    # Update transcript with audio_location (matches Celery line 661)
-    db = await _get_fresh_db_connection()
-    try:
+    # Update transcript with audio_location
+    async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller

         transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -600,8 +615,6 @@
             await transcripts_controller.update(
                 transcript, {"audio_location": "storage"}
             )
-    finally:
-        await _close_db_connection(db)

     logger.info(
         "[Hatchet] mixdown_tracks uploaded",
@@ -613,27 +626,18 @@
         input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
     )

-    return {
-        "audio_key": storage_path,
-        "duration": duration_ms[
-            0
-        ],  # Duration in milliseconds from AudioFileWriterProcessor
-        "tracks_mixed": len(valid_urls),
-    }
-
-    except Exception as e:
-        logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id
-        )
-        raise
+    return MixdownResult(
+        audio_key=storage_path,
+        duration=duration_ms[0],
+        tracks_mixed=len(valid_urls),
+    )


 @diarization_pipeline.task(
     parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3
 )
-async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("generate_waveform")
+async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
     """Generate audio waveform visualization using AudioWaveformProcessor (matches Celery)."""
     logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)
```
```diff
@@ -641,15 +645,13 @@
         input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id
     )

-    try:
-        import httpx
+    import httpx

     from reflector.db.transcripts import TranscriptWaveform, transcripts_controller
     from reflector.utils.audio_waveform import get_audio_waveform

-    # Cleanup temporary padded S3 files (matches Celery lines 710-725)
-    # Moved here from process_tracks because mixdown_tracks needs the padded files
-    track_data = ctx.task_output(process_tracks)
+    # Cleanup temporary padded S3 files (deferred until after mixdown)
+    track_data = _to_dict(ctx.task_output(process_tracks))
     created_padded_files = track_data.get("created_padded_files", [])
     if created_padded_files:
         logger.info(
@@ -660,9 +662,7 @@
         for storage_path in created_padded_files:
             cleanup_tasks.append(storage.delete_file(storage_path))

-        cleanup_results = await asyncio.gather(
-            *cleanup_tasks, return_exceptions=True
-        )
+        cleanup_results = await asyncio.gather(*cleanup_tasks, return_exceptions=True)
         for storage_path, result in zip(created_padded_files, cleanup_results):
             if isinstance(result, Exception):
                 logger.warning(
@@ -671,7 +671,7 @@
                     error=str(result),
                 )

-    mixdown_data = ctx.task_output(mixdown_tracks)
+    mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
     audio_key = mixdown_data.get("audio_key")

     storage = _get_storage()
@@ -692,20 +692,19 @@
         with open(temp_path, "wb") as f:
             f.write(response.content)

-        # Generate waveform (matches Celery: get_audio_waveform with 255 segments)
-        waveform = get_audio_waveform(path=Path(temp_path), segments_count=255)
+        # Generate waveform
+        waveform = get_audio_waveform(
+            path=Path(temp_path), segments_count=WAVEFORM_SEGMENTS
+        )

-        # Save waveform to database via event (matches Celery on_waveform callback)
-        db = await _get_fresh_db_connection()
-        try:
+        # Save waveform to database via event
+        async with fresh_db_connection():
             transcript = await transcripts_controller.get_by_id(input.transcript_id)
             if transcript:
                 waveform_data = TranscriptWaveform(waveform=waveform)
                 await transcripts_controller.append_event(
                     transcript=transcript, event="WAVEFORM", data=waveform_data
                 )
-        finally:
-            await _close_db_connection(db)

     finally:
         Path(temp_path).unlink(missing_ok=True)
@@ -716,21 +715,14 @@
         input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
     )

-    return {"waveform_generated": True}
-
-    except Exception as e:
-        logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id
-        )
-        raise
+    return WaveformResult(waveform_generated=True)


 @diarization_pipeline.task(
     parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3
 )
-async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("detect_topics")
+async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     """Detect topics using LLM and save to database (matches Celery on_topic callback)."""
     logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)
```
```diff
@@ -738,8 +730,7 @@
         input.transcript_id, "detect_topics", "in_progress", ctx.workflow_run_id
     )

-    try:
-        track_data = ctx.task_output(process_tracks)
+    track_data = _to_dict(ctx.task_output(process_tracks))
     words = track_data.get("all_words", [])
     target_language = track_data.get("target_language", "en")
@@ -757,13 +748,10 @@

     empty_pipeline = topic_processing.EmptyPipeline(logger=logger)

-    # Get DB connection for callbacks
-    db = await _get_fresh_db_connection()
-
-    try:
+    async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)

-        # Callback that upserts topics to DB (matches Celery on_topic)
+        # Callback that upserts topics to DB
         async def on_topic_callback(data):
             topic = TranscriptTopic(
                 title=data.title,
@@ -785,8 +773,6 @@
             on_topic_callback=on_topic_callback,
             empty_pipeline=empty_pipeline,
         )
-    finally:
-        await _close_db_connection(db)

     topics_list = [t.model_dump() for t in topics]

@@ -796,21 +782,14 @@
         input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id
     )

-    return {"topics": topics_list}
-
-    except Exception as e:
-        logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id
-        )
-        raise
+    return TopicsResult(topics=topics_list)


 @diarization_pipeline.task(
     parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3
 )
-async def generate_title(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("generate_title")
+async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     """Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
     logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)
```
```diff
@@ -818,8 +797,7 @@
         input.transcript_id, "generate_title", "in_progress", ctx.workflow_run_id
     )

-    try:
-        topics_data = ctx.task_output(detect_topics)
+    topics_data = _to_dict(ctx.task_output(detect_topics))
     topics = topics_data.get("topics", [])

     from reflector.db.transcripts import (
@@ -834,11 +812,10 @@
     empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
     title_result = None

-    db = await _get_fresh_db_connection()
-    try:
+    async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)

-        # Callback that updates title in DB (matches Celery on_title)
+        # Callback that updates title in DB
         async def on_title_callback(data):
             nonlocal title_result
             title_result = data.title
@@ -858,8 +835,6 @@
             empty_pipeline=empty_pipeline,
             logger=logger,
         )
-    finally:
-        await _close_db_connection(db)

     logger.info("[Hatchet] generate_title complete", title=title_result)

@@ -867,21 +842,14 @@
         input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
     )

-    return {"title": title_result}
-
-    except Exception as e:
-        logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "generate_title", "failed", ctx.workflow_run_id
-        )
-        raise
+    return TitleResult(title=title_result)


 @diarization_pipeline.task(
     parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3
 )
-async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("generate_summary")
+async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     """Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
     logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)
```
```diff
@@ -889,8 +857,7 @@
         input.transcript_id, "generate_summary", "in_progress", ctx.workflow_run_id
     )

-    try:
-        topics_data = ctx.task_output(detect_topics)
+    topics_data = _to_dict(ctx.task_output(detect_topics))
     topics = topics_data.get("topics", [])

     from reflector.db.transcripts import (
@@ -907,11 +874,10 @@
     summary_result = None
     short_summary_result = None

-    db = await _get_fresh_db_connection()
-    try:
+    async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)

-        # Callback that updates long_summary in DB (matches Celery on_long_summary)
+        # Callback that updates long_summary in DB
         async def on_long_summary_callback(data):
             nonlocal summary_result
             summary_result = data.long_summary
@@ -928,7 +894,7 @@
             data=final_long_summary,
         )

-        # Callback that updates short_summary in DB (matches Celery on_short_summary)
+        # Callback that updates short_summary in DB
         async def on_short_summary_callback(data):
             nonlocal short_summary_result
             short_summary_result = data.short_summary
@@ -953,8 +919,6 @@
             empty_pipeline=empty_pipeline,
             logger=logger,
         )
-    finally:
-        await _close_db_connection(db)

     logger.info("[Hatchet] generate_summary complete")

@@ -962,15 +926,7 @@
         input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
     )

-    return {"summary": summary_result, "short_summary": short_summary_result}
-
-    except Exception as e:
-        logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id
-        )
-        raise
+    return SummaryResult(summary=summary_result, short_summary=short_summary_result)


 @diarization_pipeline.task(
```
```diff
@@ -978,7 +934,8 @@
     execution_timeout=timedelta(seconds=60),
     retries=3,
 )
-async def finalize(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("finalize")
+async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
     """Finalize transcript: save words, emit TRANSCRIPT event, set status to 'ended'.

     Matches Celery's on_transcript + set_status behavior.
@@ -990,33 +947,28 @@
         input.transcript_id, "finalize", "in_progress", ctx.workflow_run_id
     )

-    try:
-        mixdown_data = ctx.task_output(mixdown_tracks)
-        track_data = ctx.task_output(process_tracks)
+    mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
+    track_data = _to_dict(ctx.task_output(process_tracks))

     duration = mixdown_data.get("duration", 0)
     all_words = track_data.get("all_words", [])

-    db = await _get_fresh_db_connection()
-
-    try:
+    async with fresh_db_connection():
         from reflector.db.transcripts import TranscriptText, transcripts_controller
         from reflector.processors.types import Transcript as TranscriptType
         from reflector.processors.types import Word

         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript is None:
-            raise ValueError(
-                f"Transcript {input.transcript_id} not found in database"
-            )
+            raise ValueError(f"Transcript {input.transcript_id} not found in database")

         # Convert words back to Word objects for storage
         word_objects = [Word(**w) for w in all_words]

-        # Create merged transcript for TRANSCRIPT event (matches Celery line 734-736)
+        # Create merged transcript for TRANSCRIPT event
         merged_transcript = TranscriptType(words=word_objects, translation=None)

-        # Emit TRANSCRIPT event (matches Celery on_transcript callback)
+        # Emit TRANSCRIPT event
         await transcripts_controller.append_event(
             transcript=transcript,
             event="TRANSCRIPT",
@@ -1036,35 +988,23 @@
             },
         )

-        # Set status to "ended" (matches Celery line 745)
+        # Set status to "ended"
         await transcripts_controller.set_status(input.transcript_id, "ended")

-        logger.info(
-            "[Hatchet] finalize complete", transcript_id=input.transcript_id
-        )
-
-    finally:
-        await _close_db_connection(db)
+        logger.info("[Hatchet] finalize complete", transcript_id=input.transcript_id)

     await emit_progress_async(
         input.transcript_id, "finalize", "completed", ctx.workflow_run_id
     )

-    return {"status": "COMPLETED"}
-
-    except Exception as e:
-        logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True)
-        await _set_error_status(input.transcript_id)
-        await emit_progress_async(
-            input.transcript_id, "finalize", "failed", ctx.workflow_run_id
-        )
-        raise
+    return FinalizeResult(status="COMPLETED")


 @diarization_pipeline.task(
     parents=[finalize], execution_timeout=timedelta(seconds=60), retries=3
 )
-async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("cleanup_consent", set_error_status=False)
+async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
     """Check and handle consent requirements."""
     logger.info("[Hatchet] cleanup_consent", transcript_id=input.transcript_id)
```
```diff
@@ -1072,10 +1012,7 @@
         input.transcript_id, "cleanup_consent", "in_progress", ctx.workflow_run_id
     )

-    try:
-        db = await _get_fresh_db_connection()
-
-        try:
+    async with fresh_db_connection():
         from reflector.db.meetings import meetings_controller
         from reflector.db.transcripts import transcripts_controller

@@ -1091,27 +1028,18 @@
             "[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id
         )

-        finally:
-            await _close_db_connection(db)
-
     await emit_progress_async(
         input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id
     )

-    return {"consent_checked": True}
-
-    except Exception as e:
-        logger.error("[Hatchet] cleanup_consent failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "cleanup_consent", "failed", ctx.workflow_run_id
-        )
-        raise
+    return ConsentResult(consent_checked=True)


 @diarization_pipeline.task(
     parents=[cleanup_consent], execution_timeout=timedelta(seconds=60), retries=5
 )
-async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("post_zulip", set_error_status=False)
+async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
     """Post notification to Zulip."""
     logger.info("[Hatchet] post_zulip", transcript_id=input.transcript_id)
```
@@ -1119,7 +1047,6 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "post_zulip", "in_progress", ctx.workflow_run_id
     )
 
-    try:
     from reflector.settings import settings
 
     if not settings.ZULIP_REALM:
@@ -1127,45 +1054,32 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
         await emit_progress_async(
             input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
         )
-        return {"zulip_message_id": None, "skipped": True}
+        return ZulipResult(zulip_message_id=None, skipped=True)
 
     from reflector.zulip import post_transcript_notification
 
-    db = await _get_fresh_db_connection()
-
-    try:
+    async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
             message_id = await post_transcript_notification(transcript)
-            logger.info(
-                "[Hatchet] post_zulip complete", zulip_message_id=message_id
-            )
+            logger.info("[Hatchet] post_zulip complete", zulip_message_id=message_id)
         else:
             message_id = None
 
-    finally:
-        await _close_db_connection(db)
-
     await emit_progress_async(
         input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
     )
 
-    return {"zulip_message_id": message_id}
+    return ZulipResult(zulip_message_id=message_id)
-
-    except Exception as e:
-        logger.error("[Hatchet] post_zulip failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "post_zulip", "failed", ctx.workflow_run_id
-        )
-        raise
 
 
 @diarization_pipeline.task(
     parents=[post_zulip], execution_timeout=timedelta(seconds=120), retries=30
 )
-async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
+@with_error_handling("send_webhook", set_error_status=False)
+async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
     """Send completion webhook to external service."""
     logger.info("[Hatchet] send_webhook", transcript_id=input.transcript_id)
 
@@ -1173,17 +1087,14 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "send_webhook", "in_progress", ctx.workflow_run_id
     )
 
-    try:
     if not input.room_id:
         logger.info("[Hatchet] send_webhook skipped (no room_id)")
         await emit_progress_async(
             input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
         )
-        return {"webhook_sent": False, "skipped": True}
+        return WebhookResult(webhook_sent=False, skipped=True)
 
-    db = await _get_fresh_db_connection()
-
-    try:
+    async with fresh_db_connection():
         from reflector.db.rooms import rooms_controller
         from reflector.db.transcripts import transcripts_controller
 
@@ -1217,20 +1128,10 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
                 ctx.workflow_run_id,
             )
 
-            return {"webhook_sent": True, "response_code": response.status_code}
+            return WebhookResult(webhook_sent=True, response_code=response.status_code)
 
-    finally:
-        await _close_db_connection(db)
-
     await emit_progress_async(
         input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
     )
 
-    return {"webhook_sent": False, "skipped": True}
+    return WebhookResult(webhook_sent=False, skipped=True)
-
-    except Exception as e:
-        logger.error("[Hatchet] send_webhook failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "send_webhook", "failed", ctx.workflow_run_id
-        )
-        raise
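Reviewer note on the retry budgets above: cleanup_consent and post_zulip are capped at 60s per attempt with 3-5 retries, while send_webhook gets 120s per attempt and 30 retries, so webhook delivery is clearly meant to survive long receiver outages. A rough upper bound on execution time alone (Hatchet's backoff between attempts is not shown in this diff):

```python
from datetime import timedelta

# 30 attempts x 120s execution timeout, ignoring scheduling and backoff delays
webhook_budget = 30 * timedelta(seconds=120)
assert webhook_budget == timedelta(hours=1)
```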
server/reflector/hatchet/workflows/models.py (new file, 123 lines)
@@ -0,0 +1,123 @@
+"""
+Pydantic models for Hatchet workflow task return types.
+
+Provides static typing for all task outputs, enabling type checking
+and better IDE support.
+"""
+
+from typing import Any
+
+from pydantic import BaseModel
+
+# ============================================================================
+# Track Processing Results (track_processing.py)
+# ============================================================================
+
+
+class PadTrackResult(BaseModel):
+    """Result from pad_track task."""
+
+    padded_url: str
+    size: int
+    track_index: int
+
+
+class TranscribeTrackResult(BaseModel):
+    """Result from transcribe_track task."""
+
+    words: list[dict[str, Any]]
+    track_index: int
+
+
+# ============================================================================
+# Diarization Pipeline Results (diarization_pipeline.py)
+# ============================================================================
+
+
+class RecordingResult(BaseModel):
+    """Result from get_recording task."""
+
+    id: str | None
+    mtg_session_id: str | None
+    room_name: str | None
+    duration: float
+
+
+class ParticipantsResult(BaseModel):
+    """Result from get_participants task."""
+
+    participants: list[dict[str, Any]]
+    num_tracks: int
+    source_language: str
+    target_language: str
+
+
+class ProcessTracksResult(BaseModel):
+    """Result from process_tracks task."""
+
+    all_words: list[dict[str, Any]]
+    padded_urls: list[str | None]
+    word_count: int
+    num_tracks: int
+    target_language: str
+    created_padded_files: list[str]
+
+
+class MixdownResult(BaseModel):
+    """Result from mixdown_tracks task."""
+
+    audio_key: str
+    duration: float
+    tracks_mixed: int
+
+
+class WaveformResult(BaseModel):
+    """Result from generate_waveform task."""
+
+    waveform_generated: bool
+
+
+class TopicsResult(BaseModel):
+    """Result from detect_topics task."""
+
+    topics: list[dict[str, Any]]
+
+
+class TitleResult(BaseModel):
+    """Result from generate_title task."""
+
+    title: str | None
+
+
+class SummaryResult(BaseModel):
+    """Result from generate_summary task."""
+
+    summary: str | None
+    short_summary: str | None
+
+
+class FinalizeResult(BaseModel):
+    """Result from finalize task."""
+
+    status: str
+
+
+class ConsentResult(BaseModel):
+    """Result from cleanup_consent task."""
+
+    consent_checked: bool
+
+
+class ZulipResult(BaseModel):
+    """Result from post_zulip task."""
+
+    zulip_message_id: int | None = None
+    skipped: bool = False
+
+
+class WebhookResult(BaseModel):
+    """Result from send_webhook task."""
+
+    webhook_sent: bool
+    skipped: bool = False
+    response_code: int | None = None
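Reviewer note: with every task now returning a `BaseModel`, outputs round-trip cleanly through JSON, which is what Hatchet persists between steps. A small illustration (Pydantic v2 API, matching the `model_dump()` call used by the `_to_dict` helper below):

```python
from reflector.hatchet.workflows.models import WebhookResult

result = WebhookResult(webhook_sent=True, response_code=200)

# Hatchet serializes task outputs to JSON; model_dump() gives the dict form...
payload = result.model_dump()
assert payload == {"webhook_sent": True, "skipped": False, "response_code": 200}

# ...and model_validate() reconstructs the typed object downstream.
roundtripped = WebhookResult.model_validate(payload)
assert roundtripped == result
```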
@@ -18,8 +18,17 @@ from pydantic import BaseModel
 
 from reflector.hatchet.client import HatchetClientManager
 from reflector.hatchet.progress import emit_progress_async
+from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
 from reflector.logger import logger
 
+
+def _to_dict(output) -> dict:
+    """Convert task output to dict, handling both dict and Pydantic model returns."""
+    if isinstance(output, dict):
+        return output
+    return output.model_dump()
+
+
 # Audio constants matching existing pipeline
 OPUS_STANDARD_SAMPLE_RATE = 48000
 OPUS_DEFAULT_BIT_RATE = 64000
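The `_to_dict` shim lets consumers of `ctx.task_output(...)` stay agnostic about whether the parent step handed back a plain dict or a typed model (presumably both shapes can occur, e.g. on replay of runs recorded before this change; the diff does not say):

```python
from reflector.hatchet.workflows.models import PadTrackResult

# _to_dict is the helper defined in the hunk above (track_processing.py).
as_dict = _to_dict({"padded_url": "s3://b/a.webm", "size": 10, "track_index": 0})
as_model = _to_dict(PadTrackResult(padded_url="s3://b/a.webm", size=10, track_index=0))
assert as_dict == as_model  # downstream .get("padded_url") works on either input
```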
@@ -161,7 +170,7 @@ def _apply_audio_padding_to_file(
 
 
 @track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
-async def pad_track(input: TrackInput, ctx: Context) -> dict:
+async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
     """Pad single audio track with silence for alignment.
 
     Extracts stream.start_time from WebM container metadata and applies
@@ -213,11 +222,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
             await emit_progress_async(
                 input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
             )
-            return {
-                "padded_url": source_url,
-                "size": 0,
-                "track_index": input.track_index,
-            }
+            return PadTrackResult(
+                padded_url=source_url,
+                size=0,
+                track_index=input.track_index,
+            )
 
         # Create temp file for padded output
         with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
@@ -265,11 +274,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
             input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
         )
 
-        return {
-            "padded_url": padded_url,
-            "size": file_size,
-            "track_index": input.track_index,
-        }
+        return PadTrackResult(
+            padded_url=padded_url,
+            size=file_size,
+            track_index=input.track_index,
+        )
 
     except Exception as e:
         logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
@@ -282,7 +291,7 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
 @track_workflow.task(
     parents=[pad_track], execution_timeout=timedelta(seconds=600), retries=3
 )
-async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
+async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
     """Transcribe audio track using GPU (Modal.com) or local Whisper."""
     logger.info(
         "[Hatchet] transcribe_track",
@@ -295,7 +304,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
     )
 
     try:
-        pad_result = ctx.task_output(pad_track)
+        pad_result = _to_dict(ctx.task_output(pad_track))
         audio_url = pad_result.get("padded_url")
 
         if not audio_url:
@@ -324,10 +333,10 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
             input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
         )
 
-        return {
-            "words": words,
-            "track_index": input.track_index,
-        }
+        return TranscribeTrackResult(
+            words=words,
+            track_index=input.track_index,
+        )
 
     except Exception as e:
         logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)
@@ -224,6 +224,26 @@ def dispatch_transcript_processing(
             transcript, {"workflow_run_id": None}
         )
 
+        # Re-fetch transcript to check for concurrent dispatch (TOCTOU protection)
+        transcript = await transcripts_controller.get_by_id(
+            config.transcript_id
+        )
+        if transcript and transcript.workflow_run_id:
+            # Another process started a workflow between validation and now
+            try:
+                status = await HatchetClientManager.get_workflow_run_status(
+                    transcript.workflow_run_id
+                )
+                if "RUNNING" in status or "QUEUED" in status:
+                    logger.info(
+                        "Concurrent workflow detected, skipping dispatch",
+                        workflow_id=transcript.workflow_run_id,
+                    )
+                    return transcript.workflow_run_id
+            except Exception:
+                # If we can't get status, proceed with new workflow
+                pass
+
         workflow_id = await HatchetClientManager.start_workflow(
             workflow_name="DiarizationPipeline",
             input_data={
@@ -234,6 +254,11 @@ def dispatch_transcript_processing(
                 "transcript_id": config.transcript_id,
                 "room_id": config.room_id,
             },
+            additional_metadata={
+                "transcript_id": config.transcript_id,
+                "recording_id": config.recording_id,
+                "daily_recording_id": config.recording_id,
+            },
         )
 
         if transcript:
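Reviewer note: the re-fetch narrows the TOCTOU window but does not close it; two dispatchers can still both pass the check before either records a `workflow_run_id`. A sketch of the interleaving that remains possible, plus one way it could be closed (the CAS query is an illustration, not code from this repo):

```python
# Remaining race (narrowed, not eliminated) -- two dispatchers A and B:
#
#   A: get_by_id()        -> workflow_run_id is None
#   B: get_by_id()        -> workflow_run_id is None
#   A: start_workflow()   -> writes workflow_run_id
#   B: start_workflow()   -> duplicate workflow for the same transcript
#
# Closing it fully would need an atomic compare-and-set on the row, e.g.:
#
#   UPDATE transcripts
#   SET workflow_run_id = :new_id
#   WHERE id = :transcript_id AND workflow_run_id IS NULL
#
# and dispatching only when that UPDATE affected one row.
```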
@@ -302,6 +302,11 @@ async def _process_multitrack_recording_inner(
             "transcript_id": transcript.id,
             "room_id": room.id,
         },
+        additional_metadata={
+            "transcript_id": transcript.id,
+            "recording_id": recording_id,
+            "daily_recording_id": recording_id,
+        },
     )
     logger.info(
         "Started Hatchet workflow",
@@ -527,6 +527,22 @@ def fake_mp3_upload():
     yield
 
 
+@pytest.fixture(autouse=True)
+def reset_hatchet_client():
+    """Reset HatchetClientManager singleton before and after each test.
+
+    This ensures test isolation - each test starts with a fresh client state.
+    The fixture is autouse=True so it applies to all tests automatically.
+    """
+    from reflector.hatchet.client import HatchetClientManager
+
+    # Reset before test
+    HatchetClientManager.reset()
+    yield
+    # Reset after test to clean up
+    HatchetClientManager.reset()
+
+
 @pytest.fixture
 async def fake_transcript_with_topics(tmpdir, client):
     import shutil
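Reviewer note: this fixture supersedes the ad-hoc `_instance = None` resets that the next file's hunks delete. `reset()` itself is not shown in this diff; under the assumption that the manager caches a single SDK client at class level (consistent with the `_instance` attribute the old tests poked), it is presumably no more than:

```python
class HatchetClientManager:
    _instance = None  # cached Hatchet SDK client, built lazily by get_client()

    @classmethod
    def reset(cls) -> None:
        """Drop the cached client so the next get_client() builds a fresh one."""
        cls._instance = None
```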
@@ -2,6 +2,9 @@
 Tests for HatchetClientManager error handling and validation.
 
 Only tests that catch real bugs - not mock verification tests.
+
+Note: The `reset_hatchet_client` fixture (autouse=True in conftest.py)
+automatically resets the singleton before and after each test.
 """
 
 from unittest.mock import AsyncMock, MagicMock, patch
@@ -18,8 +21,6 @@ async def test_hatchet_client_can_replay_handles_exception():
     """
     from reflector.hatchet.client import HatchetClientManager
 
-    HatchetClientManager._instance = None
-
     with patch("reflector.hatchet.client.settings") as mock_settings:
         mock_settings.HATCHET_CLIENT_TOKEN = "test-token"
         mock_settings.HATCHET_DEBUG = False
@@ -37,8 +38,6 @@ async def test_hatchet_client_can_replay_handles_exception():
     # Should return False on error (workflow might be gone)
     assert can_replay is False
 
-    HatchetClientManager._instance = None
-
 
 def test_hatchet_client_raises_without_token():
     """Test that get_client raises ValueError without token.
@@ -48,12 +47,8 @@ def test_hatchet_client_raises_without_token():
     """
     from reflector.hatchet.client import HatchetClientManager
 
-    HatchetClientManager._instance = None
-
     with patch("reflector.hatchet.client.settings") as mock_settings:
         mock_settings.HATCHET_CLIENT_TOKEN = None
 
         with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"):
             HatchetClientManager.get_client()
-
-    HatchetClientManager._instance = None