hatchet no-mistake

This commit is contained in:
Igor Loskutov
2025-12-16 00:48:30 -05:00
parent 243ff2177c
commit c5498d26bf
18 changed files with 2189 additions and 1952 deletions

View File

@@ -14,6 +14,7 @@ from reflector.metrics import metrics_init
from reflector.settings import settings
from reflector.views.conductor import router as conductor_router
from reflector.views.daily import router as daily_router
from reflector.views.hatchet import router as hatchet_router
from reflector.views.meetings import router as meetings_router
from reflector.views.rooms import router as rooms_router
from reflector.views.rtc_offer import router as rtc_offer_router
@@ -100,6 +101,7 @@ app.include_router(zulip_router, prefix="/v1")
app.include_router(whereby_router, prefix="/v1")
app.include_router(daily_router, prefix="/v1/daily")
app.include_router(conductor_router, prefix="/v1")
app.include_router(hatchet_router, prefix="/v1")
add_pagination(app)
# prepare celery

View File

@@ -0,0 +1,6 @@
"""Hatchet workflow orchestration for Reflector."""
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress, emit_progress_async
__all__ = ["HatchetClientManager", "emit_progress", "emit_progress_async"]

View File

@@ -0,0 +1,48 @@
"""Hatchet Python client wrapper."""
from hatchet_sdk import Hatchet
from reflector.settings import settings
class HatchetClientManager:
"""Singleton manager for Hatchet client connections."""
_instance: Hatchet | None = None
@classmethod
def get_client(cls) -> Hatchet:
"""Get or create the Hatchet client."""
if cls._instance is None:
if not settings.HATCHET_CLIENT_TOKEN:
raise ValueError("HATCHET_CLIENT_TOKEN must be set")
cls._instance = Hatchet(
debug=settings.HATCHET_DEBUG,
)
return cls._instance
@classmethod
async def start_workflow(
cls, workflow_name: str, input_data: dict, key: str | None = None
) -> str:
"""Start a workflow and return the workflow run ID."""
client = cls.get_client()
result = await client.runs.aio_create(
workflow_name,
input_data,
)
# SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
return result.run.metadata.id
@classmethod
async def get_workflow_status(cls, workflow_run_id: str) -> dict:
"""Get the current status of a workflow run."""
client = cls.get_client()
run = await client.runs.aio_get(workflow_run_id)
return run.to_dict()
@classmethod
def reset(cls) -> None:
"""Reset the client instance (for testing)."""
cls._instance = None
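# --- Illustrative usage sketch (editorial, not part of this commit) ---
# A minimal example of how a caller might drive this manager, assuming the
# DiarizationPipeline workflow defined later in this commit is registered and
# HATCHET_CLIENT_TOKEN is available in the environment. Input keys mirror the
# PipelineInput model; the function name is hypothetical.
async def _example_start_diarization(
    transcript_id: str, bucket_name: str, track_keys: list[str]
) -> dict:
    run_id = await HatchetClientManager.start_workflow(
        workflow_name="DiarizationPipeline",
        input_data={
            "recording_id": None,  # reprocess path: no Daily.co recording id
            "room_name": None,
            "tracks": [{"s3_key": key} for key in track_keys],
            "bucket_name": bucket_name,
            "transcript_id": transcript_id,
            "room_id": None,
        },
    )
    # One-shot status check; callers could also poll /v1/hatchet/workflow/{run_id}
    return await HatchetClientManager.get_workflow_status(run_id)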

View File

@@ -0,0 +1,120 @@
"""Progress event emission for Hatchet workers."""
import asyncio
from typing import Literal
from reflector.db.transcripts import PipelineProgressData
from reflector.logger import logger
from reflector.ws_manager import get_ws_manager
# Step mapping for progress tracking (matches Conductor pipeline)
PIPELINE_STEPS = {
"get_recording": 1,
"get_participants": 2,
"pad_track": 3, # Fork tasks share same step
"mixdown_tracks": 4,
"generate_waveform": 5,
"transcribe_track": 6, # Fork tasks share same step
"merge_transcripts": 7,
"detect_topics": 8,
"generate_title": 9, # Fork tasks share same step
"generate_summary": 9, # Fork tasks share same step
"finalize": 10,
"cleanup_consent": 11,
"post_zulip": 12,
"send_webhook": 13,
}
TOTAL_STEPS = 13
async def _emit_progress_async(
transcript_id: str,
step: str,
status: Literal["pending", "in_progress", "completed", "failed"],
workflow_id: str | None = None,
) -> None:
"""Async implementation of progress emission."""
ws_manager = get_ws_manager()
step_index = PIPELINE_STEPS.get(step, 0)
data = PipelineProgressData(
workflow_id=workflow_id,
current_step=step,
step_index=step_index,
total_steps=TOTAL_STEPS,
step_status=status,
)
await ws_manager.send_json(
room_id=f"ts:{transcript_id}",
message={
"event": "PIPELINE_PROGRESS",
"data": data.model_dump(),
},
)
logger.debug(
"[Hatchet Progress] Emitted",
transcript_id=transcript_id,
step=step,
status=status,
step_index=step_index,
)
def emit_progress(
transcript_id: str,
step: str,
status: Literal["pending", "in_progress", "completed", "failed"],
workflow_id: str | None = None,
) -> None:
"""Emit a pipeline progress event (sync wrapper for Hatchet workers).
Args:
transcript_id: The transcript ID to emit progress for
step: The current step name (e.g., "transcribe_track")
status: The step status
workflow_id: Optional workflow run ID
"""
try:
# Detect whether we're already inside a running event loop
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
# Already in async context, schedule the coroutine
asyncio.create_task(
_emit_progress_async(transcript_id, step, status, workflow_id)
)
else:
# Not in async context, run synchronously
asyncio.run(_emit_progress_async(transcript_id, step, status, workflow_id))
except Exception as e:
# Progress emission should never break the pipeline
logger.warning(
"[Hatchet Progress] Failed to emit progress event",
error=str(e),
transcript_id=transcript_id,
step=step,
)
async def emit_progress_async(
transcript_id: str,
step: str,
status: Literal["pending", "in_progress", "completed", "failed"],
workflow_id: str | None = None,
) -> None:
"""Async version of emit_progress for use in async Hatchet tasks."""
try:
await _emit_progress_async(transcript_id, step, status, workflow_id)
except Exception as e:
logger.warning(
"[Hatchet Progress] Failed to emit progress event",
error=str(e),
transcript_id=transcript_id,
step=step,
)
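# --- Illustrative sketch (editorial, not part of this commit) ---
# Roughly what a websocket subscriber of room "ts:<transcript_id>" receives after e.g.
#   await emit_progress_async("tr_123", "transcribe_track", "in_progress", "wf_run_abc")
# Field names follow the PipelineProgressData construction above; any additional
# fields on that model would also appear under "data".
_example_progress_message = {
    "event": "PIPELINE_PROGRESS",
    "data": {
        "workflow_id": "wf_run_abc",         # Hatchet workflow run id (placeholder)
        "current_step": "transcribe_track",  # step name from PIPELINE_STEPS
        "step_index": 6,                     # PIPELINE_STEPS["transcribe_track"]
        "total_steps": 13,                   # TOTAL_STEPS
        "step_status": "in_progress",
    },
}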

View File

@@ -0,0 +1,59 @@
"""
Run Hatchet workers for the diarization pipeline.
Usage:
uv run -m reflector.hatchet.run_workers
# Or via docker:
docker compose exec server uv run -m reflector.hatchet.run_workers
"""
import signal
import sys
from reflector.logger import logger
from reflector.settings import settings
def main() -> None:
"""Start Hatchet worker polling."""
if not settings.HATCHET_ENABLED:
logger.error("HATCHET_ENABLED is False, not starting workers")
sys.exit(1)
if not settings.HATCHET_CLIENT_TOKEN:
logger.error("HATCHET_CLIENT_TOKEN is not set")
sys.exit(1)
logger.info(
"Starting Hatchet workers",
debug=settings.HATCHET_DEBUG,
)
# Import workflows to register them
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.workflows import diarization_pipeline, track_workflow
hatchet = HatchetClientManager.get_client()
# Create worker with both workflows
worker = hatchet.worker(
"reflector-diarization-worker",
workflows=[diarization_pipeline, track_workflow],
)
# Handle graceful shutdown
def shutdown_handler(signum: int, frame) -> None:
logger.info("Received shutdown signal, stopping workers...")
# Worker cleanup happens automatically on exit
sys.exit(0)
signal.signal(signal.SIGINT, shutdown_handler)
signal.signal(signal.SIGTERM, shutdown_handler)
logger.info("Starting Hatchet worker polling...")
worker.start()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,14 @@
"""Hatchet workflow definitions."""
from reflector.hatchet.workflows.diarization_pipeline import (
PipelineInput,
diarization_pipeline,
)
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
__all__ = [
"diarization_pipeline",
"track_workflow",
"PipelineInput",
"TrackInput",
]

View File

@@ -0,0 +1,808 @@
"""
Hatchet main workflow: DiarizationPipeline
Multitrack diarization pipeline for Daily.co recordings.
Orchestrates the full processing flow from recording metadata to final transcript.
"""
import asyncio
import tempfile
from datetime import timedelta
from pathlib import Path
import av
from hatchet_sdk import Context
from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress_async
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
from reflector.logger import logger
# Audio constants
OPUS_STANDARD_SAMPLE_RATE = 48000
OPUS_DEFAULT_BIT_RATE = 64000
PRESIGNED_URL_EXPIRATION_SECONDS = 7200
class PipelineInput(BaseModel):
"""Input to trigger the diarization pipeline."""
recording_id: str | None
room_name: str | None
tracks: list[dict] # List of {"s3_key": str}
bucket_name: str
transcript_id: str
room_id: str | None = None
# Get hatchet client and define workflow
hatchet = HatchetClientManager.get_client()
diarization_pipeline = hatchet.workflow(
name="DiarizationPipeline", input_validator=PipelineInput
)
# ============================================================================
# Helper Functions
# ============================================================================
async def _get_fresh_db_connection():
"""Create fresh database connection for subprocess."""
import databases
from reflector.db import _database_context
from reflector.settings import settings
_database_context.set(None)
db = databases.Database(settings.DATABASE_URL)
_database_context.set(db)
await db.connect()
return db
async def _close_db_connection(db):
"""Close database connection."""
from reflector.db import _database_context
await db.disconnect()
_database_context.set(None)
def _get_storage():
"""Create fresh storage instance."""
from reflector.settings import settings
from reflector.storage.storage_aws import AwsStorage
return AwsStorage(
aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
)
# ============================================================================
# Pipeline Tasks
# ============================================================================
@diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
async def get_recording(input: PipelineInput, ctx: Context) -> dict:
"""Fetch recording metadata from Daily.co API."""
logger.info("[Hatchet] get_recording", recording_id=input.recording_id)
await emit_progress_async(
input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
)
try:
from reflector.dailyco_api.client import DailyApiClient
from reflector.settings import settings
if not input.recording_id:
# No recording_id in reprocess path - return minimal data
await emit_progress_async(
input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
)
return {
"id": None,
"mtg_session_id": None,
"room_name": input.room_name,
"duration": 0,
}
if not settings.DAILY_API_KEY:
raise ValueError("DAILY_API_KEY not configured")
async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
recording = await client.get_recording(input.recording_id)
logger.info(
"[Hatchet] get_recording complete",
recording_id=input.recording_id,
room_name=recording.room_name,
duration=recording.duration,
)
await emit_progress_async(
input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
)
return {
"id": recording.id,
"mtg_session_id": recording.mtgSessionId,
"room_name": recording.room_name,
"duration": recording.duration,
}
except Exception as e:
logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "get_recording", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task(
parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3
)
async def get_participants(input: PipelineInput, ctx: Context) -> dict:
"""Fetch participant list from Daily.co API."""
logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "get_participants", "in_progress", ctx.workflow_run_id
)
try:
recording_data = ctx.task_output(get_recording)
mtg_session_id = recording_data.get("mtg_session_id")
from reflector.dailyco_api.client import DailyApiClient
from reflector.settings import settings
if not mtg_session_id or not settings.DAILY_API_KEY:
# Return empty participants if no session ID
await emit_progress_async(
input.transcript_id,
"get_participants",
"completed",
ctx.workflow_run_id,
)
return {"participants": [], "num_tracks": len(input.tracks)}
async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
participants = await client.get_meeting_participants(mtg_session_id)
participants_list = [
{"participant_id": p.participant_id, "user_name": p.user_name}
for p in participants.data
]
logger.info(
"[Hatchet] get_participants complete",
participant_count=len(participants_list),
)
await emit_progress_async(
input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
)
return {"participants": participants_list, "num_tracks": len(input.tracks)}
except Exception as e:
logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "get_participants", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task(
parents=[get_participants], execution_timeout=timedelta(seconds=600), retries=3
)
async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
"""Spawn child workflows for each track (dynamic fan-out).
Runs pad_track and transcribe_track for each audio track in parallel.
"""
logger.info(
"[Hatchet] process_tracks",
num_tracks=len(input.tracks),
transcript_id=input.transcript_id,
)
# Spawn child workflows for each track
child_coroutines = [
track_workflow.aio_run(
TrackInput(
track_index=i,
s3_key=track["s3_key"],
bucket_name=input.bucket_name,
transcript_id=input.transcript_id,
)
)
for i, track in enumerate(input.tracks)
]
# Wait for all child workflows to complete
results = await asyncio.gather(*child_coroutines)
# Collect all track results
all_words = []
padded_urls = []
for result in results:
transcribe_result = result.get("transcribe_track", {})
all_words.extend(transcribe_result.get("words", []))
pad_result = result.get("pad_track", {})
padded_urls.append(pad_result.get("padded_url"))
# Sort words by start time
all_words.sort(key=lambda w: w.get("start", 0))
logger.info(
"[Hatchet] process_tracks complete",
num_tracks=len(input.tracks),
total_words=len(all_words),
)
return {
"all_words": all_words,
"padded_urls": padded_urls,
"word_count": len(all_words),
"num_tracks": len(input.tracks),
}
@diarization_pipeline.task(
parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3
)
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
"""Mix all padded tracks into single audio file."""
logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "mixdown_tracks", "in_progress", ctx.workflow_run_id
)
try:
track_data = ctx.task_output(process_tracks)
padded_urls = track_data.get("padded_urls", [])
if not padded_urls:
raise ValueError("No padded tracks to mixdown")
storage = _get_storage()
# Download all tracks and mix
temp_inputs = []
try:
for i, url in enumerate(padded_urls):
if not url:
continue
temp_input = tempfile.NamedTemporaryFile(suffix=".webm", delete=False)
temp_inputs.append(temp_input.name)
# Download track
import httpx
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
with open(temp_input.name, "wb") as f:
f.write(response.content)
# Mix using PyAV amix filter
if len(temp_inputs) == 0:
raise ValueError("No valid tracks to mixdown")
output_path = tempfile.mktemp(suffix=".mp3")
try:
# Probe track durations with PyAV; the actual mixing happens via an ffmpeg subprocess below
containers = [av.open(path) for path in temp_inputs]
# Get the longest duration
max_duration = 0.0
for container in containers:
if container.duration:
duration = float(container.duration * av.time_base)
max_duration = max(max_duration, duration)
# Close containers; only the duration metadata was needed
for container in containers:
container.close()
# Use subprocess for mixing (simpler than complex PyAV graph)
import subprocess
# Build ffmpeg command
cmd = ["ffmpeg", "-y"]
for path in temp_inputs:
cmd.extend(["-i", path])
# Build filter for N inputs
n = len(temp_inputs)
filter_str = f"amix=inputs={n}:duration=longest:normalize=0"
cmd.extend(["-filter_complex", filter_str])
cmd.extend(["-ac", "2", "-ar", "48000", "-b:a", "128k", output_path])
subprocess.run(cmd, check=True, capture_output=True)
# Upload mixed file
file_size = Path(output_path).stat().st_size
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/mixed.mp3"
with open(output_path, "rb") as mixed_file:
await storage.put_file(storage_path, mixed_file)
logger.info(
"[Hatchet] mixdown_tracks uploaded",
key=storage_path,
size=file_size,
)
finally:
Path(output_path).unlink(missing_ok=True)
finally:
for path in temp_inputs:
Path(path).unlink(missing_ok=True)
await emit_progress_async(
input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
)
return {
"audio_key": storage_path,
"duration": max_duration,
"tracks_mixed": len(temp_inputs),
}
except Exception as e:
logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task(
parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3
)
async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
"""Generate audio waveform visualization."""
logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id
)
try:
mixdown_data = ctx.task_output(mixdown_tracks)
audio_key = mixdown_data.get("audio_key")
storage = _get_storage()
audio_url = await storage.get_file_url(
audio_key,
operation="get_object",
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
)
from reflector.pipelines.waveform_helpers import generate_waveform_data
waveform = await generate_waveform_data(audio_url)
# Store waveform
waveform_key = f"file_pipeline_hatchet/{input.transcript_id}/waveform.json"
import json
waveform_bytes = json.dumps(waveform).encode()
import io
await storage.put_file(waveform_key, io.BytesIO(waveform_bytes))
logger.info("[Hatchet] generate_waveform complete")
await emit_progress_async(
input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
)
return {"waveform_key": waveform_key}
except Exception as e:
logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task(
parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3
)
async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
"""Detect topics using LLM."""
logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "detect_topics", "in_progress", ctx.workflow_run_id
)
try:
track_data = ctx.task_output(process_tracks)
words = track_data.get("all_words", [])
from reflector.pipelines import topic_processing
from reflector.processors.types import Transcript as TranscriptType
from reflector.processors.types import Word
# Convert word dicts to Word objects
word_objects = [Word(**w) for w in words]
transcript = TranscriptType(words=word_objects)
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
async def noop_callback(t):
pass
topics = await topic_processing.detect_topics(
transcript,
"en", # target_language
on_topic_callback=noop_callback,
empty_pipeline=empty_pipeline,
)
topics_list = [t.model_dump() for t in topics]
logger.info("[Hatchet] detect_topics complete", topic_count=len(topics_list))
await emit_progress_async(
input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id
)
return {"topics": topics_list}
except Exception as e:
logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task(
parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3
)
async def generate_title(input: PipelineInput, ctx: Context) -> dict:
"""Generate meeting title using LLM."""
logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "generate_title", "in_progress", ctx.workflow_run_id
)
try:
topics_data = ctx.task_output(detect_topics)
topics = topics_data.get("topics", [])
from reflector.pipelines import topic_processing
from reflector.processors.types import Topic
topic_objects = [Topic(**t) for t in topics]
title = await topic_processing.generate_title(topic_objects)
logger.info("[Hatchet] generate_title complete", title=title)
await emit_progress_async(
input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
)
return {"title": title}
except Exception as e:
logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "generate_title", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task(
parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3
)
async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
"""Generate meeting summary using LLM."""
logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "generate_summary", "in_progress", ctx.workflow_run_id
)
try:
track_data = ctx.task_output(process_tracks)
topics_data = ctx.task_output(detect_topics)
words = track_data.get("all_words", [])
topics = topics_data.get("topics", [])
from reflector.pipelines import topic_processing
from reflector.processors.types import Topic, Word
from reflector.processors.types import Transcript as TranscriptType
word_objects = [Word(**w) for w in words]
transcript = TranscriptType(words=word_objects)
topic_objects = [Topic(**t) for t in topics]
summary, short_summary = await topic_processing.generate_summary(
transcript, topic_objects
)
logger.info("[Hatchet] generate_summary complete")
await emit_progress_async(
input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
)
return {"summary": summary, "short_summary": short_summary}
except Exception as e:
logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task(
parents=[generate_waveform, generate_title, generate_summary],
execution_timeout=timedelta(seconds=60),
retries=3,
)
async def finalize(input: PipelineInput, ctx: Context) -> dict:
"""Finalize transcript status and update database."""
logger.info("[Hatchet] finalize", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "finalize", "in_progress", ctx.workflow_run_id
)
try:
title_data = ctx.task_output(generate_title)
summary_data = ctx.task_output(generate_summary)
mixdown_data = ctx.task_output(mixdown_tracks)
track_data = ctx.task_output(process_tracks)
title = title_data.get("title", "")
summary = summary_data.get("summary", "")
short_summary = summary_data.get("short_summary", "")
duration = mixdown_data.get("duration", 0)
all_words = track_data.get("all_words", [])
db = await _get_fresh_db_connection()
try:
from reflector.db.transcripts import transcripts_controller
from reflector.processors.types import Word
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript is None:
raise ValueError(
f"Transcript {input.transcript_id} not found in database"
)
# Convert words back to Word objects for storage
word_objects = [Word(**w) for w in all_words]
await transcripts_controller.update(
transcript,
{
"status": "ended",
"title": title,
"long_summary": summary,
"short_summary": short_summary,
"duration": duration,
"words": word_objects,
},
)
logger.info(
"[Hatchet] finalize complete", transcript_id=input.transcript_id
)
finally:
await _close_db_connection(db)
await emit_progress_async(
input.transcript_id, "finalize", "completed", ctx.workflow_run_id
)
return {"status": "COMPLETED"}
except Exception as e:
logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "finalize", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task(
parents=[finalize], execution_timeout=timedelta(seconds=60), retries=3
)
async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
"""Check and handle consent requirements."""
logger.info("[Hatchet] cleanup_consent", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "cleanup_consent", "in_progress", ctx.workflow_run_id
)
try:
db = await _get_fresh_db_connection()
try:
from reflector.db.meetings import meetings_controller
from reflector.db.transcripts import transcripts_controller
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript and transcript.meeting_id:
meeting = await meetings_controller.get_by_id(transcript.meeting_id)
if meeting:
# Check consent logic here
# For now just mark as checked
pass
logger.info(
"[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id
)
finally:
await _close_db_connection(db)
await emit_progress_async(
input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id
)
return {"consent_checked": True}
except Exception as e:
logger.error("[Hatchet] cleanup_consent failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "cleanup_consent", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task(
parents=[cleanup_consent], execution_timeout=timedelta(seconds=60), retries=5
)
async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
"""Post notification to Zulip."""
logger.info("[Hatchet] post_zulip", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "post_zulip", "in_progress", ctx.workflow_run_id
)
try:
from reflector.settings import settings
if not settings.ZULIP_REALM:
logger.info("[Hatchet] post_zulip skipped (Zulip not configured)")
await emit_progress_async(
input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
)
return {"zulip_message_id": None, "skipped": True}
from reflector.zulip import post_transcript_notification
db = await _get_fresh_db_connection()
try:
from reflector.db.transcripts import transcripts_controller
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
message_id = await post_transcript_notification(transcript)
logger.info(
"[Hatchet] post_zulip complete", zulip_message_id=message_id
)
else:
message_id = None
finally:
await _close_db_connection(db)
await emit_progress_async(
input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
)
return {"zulip_message_id": message_id}
except Exception as e:
logger.error("[Hatchet] post_zulip failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "post_zulip", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task(
parents=[post_zulip], execution_timeout=timedelta(seconds=120), retries=30
)
async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
"""Send completion webhook to external service."""
logger.info("[Hatchet] send_webhook", transcript_id=input.transcript_id)
await emit_progress_async(
input.transcript_id, "send_webhook", "in_progress", ctx.workflow_run_id
)
try:
if not input.room_id:
logger.info("[Hatchet] send_webhook skipped (no room_id)")
await emit_progress_async(
input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
)
return {"webhook_sent": False, "skipped": True}
db = await _get_fresh_db_connection()
try:
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import transcripts_controller
room = await rooms_controller.get_by_id(input.room_id)
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if room and room.webhook_url and transcript:
import httpx
webhook_payload = {
"event": "transcript.completed",
"transcript_id": input.transcript_id,
"title": transcript.title,
"duration": transcript.duration,
}
async with httpx.AsyncClient() as client:
response = await client.post(
room.webhook_url, json=webhook_payload, timeout=30
)
response.raise_for_status()
logger.info(
"[Hatchet] send_webhook complete", status_code=response.status_code
)
await emit_progress_async(
input.transcript_id,
"send_webhook",
"completed",
ctx.workflow_run_id,
)
return {"webhook_sent": True, "response_code": response.status_code}
finally:
await _close_db_connection(db)
await emit_progress_async(
input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
)
return {"webhook_sent": False, "skipped": True}
except Exception as e:
logger.error("[Hatchet] send_webhook failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "send_webhook", "failed", ctx.workflow_run_id
)
raise
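# --- Editorial note (not part of this commit): task DAG implied by the parents= declarations ---
#
#   get_recording -> get_participants -> process_tracks -> mixdown_tracks
#   mixdown_tracks -> generate_waveform
#   mixdown_tracks -> detect_topics -> generate_title
#                     detect_topics -> generate_summary
#   generate_waveform + generate_title + generate_summary -> finalize
#   finalize -> cleanup_consent -> post_zulip -> send_webhook
#
#   process_tracks fans out one TrackProcessing child workflow per audio track
#   (pad_track -> transcribe_track, defined in the next file).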

View File

@@ -0,0 +1,337 @@
"""
Hatchet child workflow: TrackProcessing
Handles individual audio track processing: padding and transcription.
Spawned dynamically by the main diarization pipeline for each track.
"""
import math
import tempfile
from datetime import timedelta
from fractions import Fraction
from pathlib import Path
import av
from av.audio.resampler import AudioResampler
from hatchet_sdk import Context
from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress_async
from reflector.logger import logger
# Audio constants matching existing pipeline
OPUS_STANDARD_SAMPLE_RATE = 48000
OPUS_DEFAULT_BIT_RATE = 64000
PRESIGNED_URL_EXPIRATION_SECONDS = 7200
class TrackInput(BaseModel):
"""Input for individual track processing."""
track_index: int
s3_key: str
bucket_name: str
transcript_id: str
language: str = "en"
# Get hatchet client and define workflow
hatchet = HatchetClientManager.get_client()
track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
def _extract_stream_start_time_from_container(container, track_idx: int) -> float:
"""Extract meeting-relative start time from WebM stream metadata.
Uses PyAV to read stream.start_time from WebM container.
More accurate than filename timestamps by ~209ms due to network/encoding delays.
"""
start_time_seconds = 0.0
try:
audio_streams = [s for s in container.streams if s.type == "audio"]
stream = audio_streams[0] if audio_streams else container.streams[0]
# 1) Try stream-level start_time (most reliable for Daily.co tracks)
if stream.start_time is not None and stream.time_base is not None:
start_time_seconds = float(stream.start_time * stream.time_base)
# 2) Fallback to container-level start_time
if (start_time_seconds <= 0) and (container.start_time is not None):
start_time_seconds = float(container.start_time * av.time_base)
# 3) Fallback to first packet DTS
if start_time_seconds <= 0:
for packet in container.demux(stream):
if packet.dts is not None:
start_time_seconds = float(packet.dts * stream.time_base)
break
except Exception as e:
logger.warning(
"PyAV metadata read failed; assuming 0 start_time",
track_idx=track_idx,
error=str(e),
)
start_time_seconds = 0.0
logger.info(
f"Track {track_idx} stream metadata: start_time={start_time_seconds:.3f}s",
track_idx=track_idx,
)
return start_time_seconds
def _apply_audio_padding_to_file(
in_container,
output_path: str,
start_time_seconds: float,
track_idx: int,
) -> None:
"""Apply silence padding to audio track using PyAV filter graph."""
delay_ms = math.floor(start_time_seconds * 1000)
logger.info(
f"Padding track {track_idx} with {delay_ms}ms delay using PyAV",
track_idx=track_idx,
delay_ms=delay_ms,
)
with av.open(output_path, "w", format="webm") as out_container:
in_stream = next((s for s in in_container.streams if s.type == "audio"), None)
if in_stream is None:
raise Exception("No audio stream in input")
out_stream = out_container.add_stream("libopus", rate=OPUS_STANDARD_SAMPLE_RATE)
out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE
graph = av.filter.Graph()
abuf_args = (
f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:"
f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:"
f"sample_fmt=s16:"
f"channel_layout=stereo"
)
src = graph.add("abuffer", args=abuf_args, name="src")
aresample_f = graph.add("aresample", args="async=1", name="ares")
delays_arg = f"{delay_ms}|{delay_ms}"
adelay_f = graph.add("adelay", args=f"delays={delays_arg}:all=1", name="delay")
sink = graph.add("abuffersink", name="sink")
src.link_to(aresample_f)
aresample_f.link_to(adelay_f)
adelay_f.link_to(sink)
graph.configure()
resampler = AudioResampler(
format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE
)
for frame in in_container.decode(in_stream):
out_frames = resampler.resample(frame) or []
for rframe in out_frames:
rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE
rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
src.push(rframe)
while True:
try:
f_out = sink.pull()
except Exception:
break
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
for packet in out_stream.encode(f_out):
out_container.mux(packet)
# Flush remaining frames
src.push(None)
while True:
try:
f_out = sink.pull()
except Exception:
break
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
for packet in out_stream.encode(f_out):
out_container.mux(packet)
for packet in out_stream.encode(None):
out_container.mux(packet)
@track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
async def pad_track(input: TrackInput, ctx: Context) -> dict:
"""Pad single audio track with silence for alignment.
Extracts stream.start_time from WebM container metadata and applies
silence padding using PyAV filter graph (adelay).
"""
logger.info(
"[Hatchet] pad_track",
track_index=input.track_index,
s3_key=input.s3_key,
transcript_id=input.transcript_id,
)
await emit_progress_async(
input.transcript_id, "pad_track", "in_progress", ctx.workflow_run_id
)
try:
# Create fresh storage instance to avoid aioboto3 fork issues
from reflector.settings import settings
from reflector.storage.storage_aws import AwsStorage
storage = AwsStorage(
aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
)
# Get presigned URL for source file
source_url = await storage.get_file_url(
input.s3_key,
operation="get_object",
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
bucket=input.bucket_name,
)
# Open container and extract start time
with av.open(source_url) as in_container:
start_time_seconds = _extract_stream_start_time_from_container(
in_container, input.track_index
)
# If no padding needed, return original URL
if start_time_seconds <= 0:
logger.info(
f"Track {input.track_index} requires no padding",
track_index=input.track_index,
)
await emit_progress_async(
input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
)
return {
"padded_url": source_url,
"size": 0,
"track_index": input.track_index,
}
# Create temp file for padded output
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
temp_path = temp_file.name
try:
_apply_audio_padding_to_file(
in_container, temp_path, start_time_seconds, input.track_index
)
file_size = Path(temp_path).stat().st_size
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{input.track_index}.webm"
logger.info(
f"About to upload padded track",
key=storage_path,
size=file_size,
)
with open(temp_path, "rb") as padded_file:
await storage.put_file(storage_path, padded_file)
logger.info(
f"Uploaded padded track to S3",
key=storage_path,
size=file_size,
)
finally:
Path(temp_path).unlink(missing_ok=True)
# Get presigned URL for padded file
padded_url = await storage.get_file_url(
storage_path,
operation="get_object",
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
)
logger.info(
"[Hatchet] pad_track complete",
track_index=input.track_index,
padded_url=padded_url[:50] + "...",
)
await emit_progress_async(
input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
)
return {
"padded_url": padded_url,
"size": file_size,
"track_index": input.track_index,
}
except Exception as e:
logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "pad_track", "failed", ctx.workflow_run_id
)
raise
@track_workflow.task(
parents=[pad_track], execution_timeout=timedelta(seconds=600), retries=3
)
async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
"""Transcribe audio track using GPU (Modal.com) or local Whisper."""
logger.info(
"[Hatchet] transcribe_track",
track_index=input.track_index,
language=input.language,
)
await emit_progress_async(
input.transcript_id, "transcribe_track", "in_progress", ctx.workflow_run_id
)
try:
pad_result = ctx.task_output(pad_track)
audio_url = pad_result.get("padded_url")
if not audio_url:
raise ValueError("Missing padded_url from pad_track")
from reflector.pipelines.transcription_helpers import (
transcribe_file_with_processor,
)
transcript = await transcribe_file_with_processor(audio_url, input.language)
# Tag all words with speaker index
words = []
for word in transcript.words:
word_dict = word.model_dump()
word_dict["speaker"] = input.track_index
words.append(word_dict)
logger.info(
"[Hatchet] transcribe_track complete",
track_index=input.track_index,
word_count=len(words),
)
await emit_progress_async(
input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
)
return {
"words": words,
"track_index": input.track_index,
}
except Exception as e:
logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)
await emit_progress_async(
input.transcript_id, "transcribe_track", "failed", ctx.workflow_run_id
)
raise
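# --- Editorial note (not part of this commit) ---
# process_tracks in the parent workflow awaits track_workflow.aio_run(...) and looks
# results up by task name, so each child run is expected to resolve to roughly this
# shape (keys taken from the return values above; word fields other than "start" and
# "speaker" come from Word.model_dump() and are omitted here):
_example_track_result = {
    "pad_track": {
        "padded_url": "https://bucket.s3.amazonaws.com/...presigned...",  # placeholder
        "size": 1234567,
        "track_index": 0,
    },
    "transcribe_track": {
        "words": [{"start": 0.42, "speaker": 0}],
        "track_index": 0,
    },
}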

View File

@@ -15,6 +15,7 @@ from celery.result import AsyncResult
from reflector.conductor.client import ConductorClientManager
from reflector.db.recordings import recordings_controller
from reflector.db.transcripts import Transcript
from reflector.hatchet.client import HatchetClientManager
from reflector.logger import logger
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_multitrack_pipeline import (
@@ -156,8 +157,47 @@ async def prepare_transcript_processing(
def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | None:
if isinstance(config, MultitrackProcessingConfig):
# Start Conductor workflow if enabled
if settings.CONDUCTOR_ENABLED:
# Start durable workflow if enabled (Hatchet or Conductor)
durable_started = False
if settings.HATCHET_ENABLED:
import asyncio
async def _start_hatchet():
return await HatchetClientManager.start_workflow(
workflow_name="DiarizationPipeline",
input_data={
"recording_id": config.recording_id,
"room_name": None, # Not available in reprocess path
"tracks": [{"s3_key": k} for k in config.track_keys],
"bucket_name": config.bucket_name,
"transcript_id": config.transcript_id,
"room_id": config.room_id,
},
)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
# Already in async context
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
workflow_id = pool.submit(asyncio.run, _start_hatchet()).result()
else:
workflow_id = asyncio.run(_start_hatchet())
logger.info(
"Started Hatchet workflow (reprocess)",
workflow_id=workflow_id,
transcript_id=config.transcript_id,
)
durable_started = True
elif settings.CONDUCTOR_ENABLED:
workflow_id = ConductorClientManager.start_workflow(
name="diarization_pipeline",
version=1,
@@ -175,11 +215,13 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | No
workflow_id=workflow_id,
transcript_id=config.transcript_id,
)
durable_started = True
if not settings.CONDUCTOR_SHADOW_MODE:
return None # Conductor-only, no Celery result
# If durable workflow started and not in shadow mode, skip Celery
if durable_started and not settings.DURABLE_WORKFLOW_SHADOW_MODE:
return None
# Celery pipeline (shadow mode or Conductor disabled)
# Celery pipeline (shadow mode or durable workflows disabled)
return task_pipeline_multitrack_process.delay(
transcript_id=config.transcript_id,
bucket_name=config.bucket_name,

View File

@@ -150,11 +150,34 @@ class Settings(BaseSettings):
ZULIP_API_KEY: str | None = None
ZULIP_BOT_EMAIL: str | None = None
# Durable workflow orchestration
# Provider: "hatchet" or "conductor" (or "none" to disable)
DURABLE_WORKFLOW_PROVIDER: str = "none"
DURABLE_WORKFLOW_SHADOW_MODE: bool = False # Run both provider + Celery
# Conductor workflow orchestration
CONDUCTOR_SERVER_URL: str = "http://conductor:8080/api"
CONDUCTOR_DEBUG: bool = False
CONDUCTOR_ENABLED: bool = False
CONDUCTOR_SHADOW_MODE: bool = False
# Hatchet workflow orchestration
HATCHET_CLIENT_TOKEN: str | None = None
HATCHET_CLIENT_TLS_STRATEGY: str = "none" # none, tls, mtls
HATCHET_DEBUG: bool = False
@property
def CONDUCTOR_ENABLED(self) -> bool:
"""Legacy compatibility: True if Conductor is the active provider."""
return self.DURABLE_WORKFLOW_PROVIDER == "conductor"
@property
def HATCHET_ENABLED(self) -> bool:
"""True if Hatchet is the active provider."""
return self.DURABLE_WORKFLOW_PROVIDER == "hatchet"
@property
def CONDUCTOR_SHADOW_MODE(self) -> bool:
"""Legacy compatibility for shadow mode."""
return self.DURABLE_WORKFLOW_SHADOW_MODE and self.CONDUCTOR_ENABLED
settings = Settings()
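# --- Illustrative sketch (editorial, not part of this commit) ---
# How the provider switch is expected to resolve. Values would normally come from
# the environment / .env rather than constructor kwargs, and any other required
# settings are assumed to have defaults or be present in the environment.
def _example_provider_switch() -> None:
    s = Settings(
        DURABLE_WORKFLOW_PROVIDER="hatchet",
        HATCHET_CLIENT_TOKEN="<token issued by the Hatchet instance>",  # placeholder
    )
    assert s.HATCHET_ENABLED is True
    assert s.CONDUCTOR_ENABLED is False      # legacy property, now derived from the provider
    assert s.CONDUCTOR_SHADOW_MODE is False  # False because Conductor is not the active provider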

View File

@@ -0,0 +1,57 @@
"""Hatchet health and status endpoints."""
from fastapi import APIRouter
from reflector.settings import settings
router = APIRouter(prefix="/hatchet", tags=["hatchet"])
@router.get("/health")
async def hatchet_health():
"""Check Hatchet connectivity and status."""
if not settings.HATCHET_ENABLED:
return {"status": "disabled", "connected": False}
if not settings.HATCHET_CLIENT_TOKEN:
return {
"status": "unhealthy",
"connected": False,
"error": "HATCHET_CLIENT_TOKEN not configured",
}
try:
from reflector.hatchet.client import HatchetClientManager
# Get client to verify token is valid
client = HatchetClientManager.get_client()
# Try to get the client's gRPC connection status
# The SDK doesn't have a simple health check, so we just verify we can create the client
if client is not None:
return {"status": "healthy", "connected": True}
else:
return {
"status": "unhealthy",
"connected": False,
"error": "Failed to create client",
}
except ValueError as e:
return {"status": "unhealthy", "connected": False, "error": str(e)}
except Exception as e:
return {"status": "unhealthy", "connected": False, "error": str(e)}
@router.get("/workflow/{workflow_run_id}")
async def get_workflow_status(workflow_run_id: str):
"""Get the status of a workflow run."""
if not settings.HATCHET_ENABLED:
return {"error": "Hatchet is disabled"}
try:
from reflector.hatchet.client import HatchetClientManager
status = await HatchetClientManager.get_workflow_status(workflow_run_id)
return status
except Exception as e:
return {"error": str(e)}

View File

@@ -286,8 +286,34 @@ async def _process_multitrack_recording_inner(
room_id=room.id,
)
# Start Conductor workflow if enabled
if settings.CONDUCTOR_ENABLED:
# Start durable workflow if enabled (Hatchet or Conductor)
durable_started = False
if settings.HATCHET_ENABLED:
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
workflow_id = await HatchetClientManager.start_workflow(
workflow_name="DiarizationPipeline",
input_data={
"recording_id": recording_id,
"room_name": daily_room_name,
"tracks": [{"s3_key": k} for k in filter_cam_audio_tracks(track_keys)],
"bucket_name": bucket_name,
"transcript_id": transcript.id,
"room_id": room.id,
},
)
logger.info(
"Started Hatchet workflow",
workflow_id=workflow_id,
transcript_id=transcript.id,
)
# Store workflow_id on recording for status tracking
await recordings_controller.update(recording, {"workflow_id": workflow_id})
durable_started = True
elif settings.CONDUCTOR_ENABLED:
from reflector.conductor.client import ConductorClientManager # noqa: PLC0415
workflow_id = ConductorClientManager.start_workflow(
@@ -310,11 +336,13 @@ async def _process_multitrack_recording_inner(
# Store workflow_id on recording for status tracking
await recordings_controller.update(recording, {"workflow_id": workflow_id})
durable_started = True
if not settings.CONDUCTOR_SHADOW_MODE:
return # Don't trigger Celery
# If durable workflow started and not in shadow mode, skip Celery
if durable_started and not settings.DURABLE_WORKFLOW_SHADOW_MODE:
return
# Celery pipeline (runs when Conductor disabled OR in shadow mode)
# Celery pipeline (runs when durable workflows disabled OR in shadow mode)
task_pipeline_multitrack_process.delay(
transcript_id=transcript.id,
bucket_name=bucket_name,