hatchet no-mistake
@@ -14,6 +14,7 @@ from reflector.metrics import metrics_init
from reflector.settings import settings
from reflector.views.conductor import router as conductor_router
from reflector.views.daily import router as daily_router
from reflector.views.hatchet import router as hatchet_router
from reflector.views.meetings import router as meetings_router
from reflector.views.rooms import router as rooms_router
from reflector.views.rtc_offer import router as rtc_offer_router
@@ -100,6 +101,7 @@ app.include_router(zulip_router, prefix="/v1")
app.include_router(whereby_router, prefix="/v1")
app.include_router(daily_router, prefix="/v1/daily")
app.include_router(conductor_router, prefix="/v1")
app.include_router(hatchet_router, prefix="/v1")
add_pagination(app)

# prepare celery

6
server/reflector/hatchet/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""Hatchet workflow orchestration for Reflector."""

from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress, emit_progress_async

__all__ = ["HatchetClientManager", "emit_progress", "emit_progress_async"]
48
server/reflector/hatchet/client.py
Normal file
@@ -0,0 +1,48 @@
"""Hatchet Python client wrapper."""

from hatchet_sdk import Hatchet

from reflector.settings import settings


class HatchetClientManager:
    """Singleton manager for Hatchet client connections."""

    _instance: Hatchet | None = None

    @classmethod
    def get_client(cls) -> Hatchet:
        """Get or create the Hatchet client."""
        if cls._instance is None:
            if not settings.HATCHET_CLIENT_TOKEN:
                raise ValueError("HATCHET_CLIENT_TOKEN must be set")

            cls._instance = Hatchet(
                debug=settings.HATCHET_DEBUG,
            )
        return cls._instance

    @classmethod
    async def start_workflow(
        cls, workflow_name: str, input_data: dict, key: str | None = None
    ) -> str:
        """Start a workflow and return the workflow run ID."""
        client = cls.get_client()
        result = await client.runs.aio_create(
            workflow_name,
            input_data,
        )
        # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
        return result.run.metadata.id

    @classmethod
    async def get_workflow_status(cls, workflow_run_id: str) -> dict:
        """Get the current status of a workflow run."""
        client = cls.get_client()
        run = await client.runs.aio_get(workflow_run_id)
        return run.to_dict()

    @classmethod
    def reset(cls) -> None:
        """Reset the client instance (for testing)."""
        cls._instance = None
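A minimal usage sketch of the manager, assuming HATCHET_CLIENT_TOKEN is set and
the DiarizationPipeline workflow defined later in this commit is registered
(IDs and input values hypothetical):

import asyncio

from reflector.hatchet.client import HatchetClientManager


async def main() -> None:
    run_id = await HatchetClientManager.start_workflow(
        workflow_name="DiarizationPipeline",
        input_data={
            "recording_id": "rec-123",
            "room_name": "daily-room",
            "tracks": [{"s3_key": "raw/track0.webm"}],
            "bucket_name": "recordings",
            "transcript_id": "tr-456",
        },
    )
    # Poll the run; the dict comes from run.to_dict() in the wrapper above
    status = await HatchetClientManager.get_workflow_status(run_id)
    print(run_id, status)


asyncio.run(main())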
120
server/reflector/hatchet/progress.py
Normal file
@@ -0,0 +1,120 @@
"""Progress event emission for Hatchet workers."""

import asyncio
from typing import Literal

from reflector.db.transcripts import PipelineProgressData
from reflector.logger import logger
from reflector.ws_manager import get_ws_manager

# Step mapping for progress tracking (matches Conductor pipeline)
PIPELINE_STEPS = {
    "get_recording": 1,
    "get_participants": 2,
    "pad_track": 3,  # Fork tasks share same step
    "mixdown_tracks": 4,
    "generate_waveform": 5,
    "transcribe_track": 6,  # Fork tasks share same step
    "merge_transcripts": 7,
    "detect_topics": 8,
    "generate_title": 9,  # Fork tasks share same step
    "generate_summary": 9,  # Fork tasks share same step
    "finalize": 10,
    "cleanup_consent": 11,
    "post_zulip": 12,
    "send_webhook": 13,
}

TOTAL_STEPS = 13


async def _emit_progress_async(
    transcript_id: str,
    step: str,
    status: Literal["pending", "in_progress", "completed", "failed"],
    workflow_id: str | None = None,
) -> None:
    """Async implementation of progress emission."""
    ws_manager = get_ws_manager()
    step_index = PIPELINE_STEPS.get(step, 0)

    data = PipelineProgressData(
        workflow_id=workflow_id,
        current_step=step,
        step_index=step_index,
        total_steps=TOTAL_STEPS,
        step_status=status,
    )

    await ws_manager.send_json(
        room_id=f"ts:{transcript_id}",
        message={
            "event": "PIPELINE_PROGRESS",
            "data": data.model_dump(),
        },
    )

    logger.debug(
        "[Hatchet Progress] Emitted",
        transcript_id=transcript_id,
        step=step,
        status=status,
        step_index=step_index,
    )


def emit_progress(
    transcript_id: str,
    step: str,
    status: Literal["pending", "in_progress", "completed", "failed"],
    workflow_id: str | None = None,
) -> None:
    """Emit a pipeline progress event (sync wrapper for Hatchet workers).

    Args:
        transcript_id: The transcript ID to emit progress for
        step: The current step name (e.g., "transcribe_track")
        status: The step status
        workflow_id: Optional workflow run ID
    """
    try:
        # Get or create event loop for sync context
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is not None and loop.is_running():
            # Already in async context, schedule the coroutine
            asyncio.create_task(
                _emit_progress_async(transcript_id, step, status, workflow_id)
            )
        else:
            # Not in async context, run synchronously
            asyncio.run(_emit_progress_async(transcript_id, step, status, workflow_id))
    except Exception as e:
        # Progress emission should never break the pipeline
        logger.warning(
            "[Hatchet Progress] Failed to emit progress event",
            error=str(e),
            transcript_id=transcript_id,
            step=step,
        )


async def emit_progress_async(
    transcript_id: str,
    step: str,
    status: Literal["pending", "in_progress", "completed", "failed"],
    workflow_id: str | None = None,
) -> None:
    """Async version of emit_progress for use in async Hatchet tasks."""
    try:
        await _emit_progress_async(transcript_id, step, status, workflow_id)
    except Exception as e:
        logger.warning(
            "[Hatchet Progress] Failed to emit progress event",
            error=str(e),
            transcript_id=transcript_id,
            step=step,
        )
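A minimal usage sketch plus the message shape a WebSocket subscriber on room
"ts:<transcript_id>" would receive (IDs hypothetical; fields taken from
PipelineProgressData as built above):

from reflector.hatchet.progress import emit_progress

# Inside a synchronous Hatchet worker step:
emit_progress("tr-456", "transcribe_track", "in_progress", workflow_id="wf-789")

# Subscribers then see roughly:
# {
#     "event": "PIPELINE_PROGRESS",
#     "data": {
#         "workflow_id": "wf-789",
#         "current_step": "transcribe_track",
#         "step_index": 6,
#         "total_steps": 13,
#         "step_status": "in_progress",
#     },
# }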
59
server/reflector/hatchet/run_workers.py
Normal file
@@ -0,0 +1,59 @@
"""
Run Hatchet workers for the diarization pipeline.

Usage:
    uv run -m reflector.hatchet.run_workers

    # Or via docker:
    docker compose exec server uv run -m reflector.hatchet.run_workers
"""

import signal
import sys

from reflector.logger import logger
from reflector.settings import settings


def main() -> None:
    """Start Hatchet worker polling."""
    if not settings.HATCHET_ENABLED:
        logger.error("HATCHET_ENABLED is False, not starting workers")
        sys.exit(1)

    if not settings.HATCHET_CLIENT_TOKEN:
        logger.error("HATCHET_CLIENT_TOKEN is not set")
        sys.exit(1)

    logger.info(
        "Starting Hatchet workers",
        debug=settings.HATCHET_DEBUG,
    )

    # Import workflows to register them
    from reflector.hatchet.client import HatchetClientManager
    from reflector.hatchet.workflows import diarization_pipeline, track_workflow

    hatchet = HatchetClientManager.get_client()

    # Create worker with both workflows
    worker = hatchet.worker(
        "reflector-diarization-worker",
        workflows=[diarization_pipeline, track_workflow],
    )

    # Handle graceful shutdown
    def shutdown_handler(signum: int, frame) -> None:
        logger.info("Received shutdown signal, stopping workers...")
        # Worker cleanup happens automatically on exit
        sys.exit(0)

    signal.signal(signal.SIGINT, shutdown_handler)
    signal.signal(signal.SIGTERM, shutdown_handler)

    logger.info("Starting Hatchet worker polling...")
    worker.start()


if __name__ == "__main__":
    main()
14
server/reflector/hatchet/workflows/__init__.py
Normal file
@@ -0,0 +1,14 @@
"""Hatchet workflow definitions."""

from reflector.hatchet.workflows.diarization_pipeline import (
    PipelineInput,
    diarization_pipeline,
)
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow

__all__ = [
    "diarization_pipeline",
    "track_workflow",
    "PipelineInput",
    "TrackInput",
]
808
server/reflector/hatchet/workflows/diarization_pipeline.py
Normal file
@@ -0,0 +1,808 @@
"""
Hatchet main workflow: DiarizationPipeline

Multitrack diarization pipeline for Daily.co recordings.
Orchestrates the full processing flow from recording metadata to final transcript.
"""

import asyncio
import tempfile
from datetime import timedelta
from pathlib import Path

import av
from hatchet_sdk import Context
from pydantic import BaseModel

from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress_async
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
from reflector.logger import logger

# Audio constants
OPUS_STANDARD_SAMPLE_RATE = 48000
OPUS_DEFAULT_BIT_RATE = 64000
PRESIGNED_URL_EXPIRATION_SECONDS = 7200


class PipelineInput(BaseModel):
    """Input to trigger the diarization pipeline."""

    recording_id: str | None
    room_name: str | None
    tracks: list[dict]  # List of {"s3_key": str}
    bucket_name: str
    transcript_id: str
    room_id: str | None = None


# Get hatchet client and define workflow
hatchet = HatchetClientManager.get_client()

diarization_pipeline = hatchet.workflow(
    name="DiarizationPipeline", input_validator=PipelineInput
)

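# Task dependency graph, read off the parents= arguments declared below
# (a reading aid, not an additional workflow definition):
#
#   get_recording -> get_participants -> process_tracks -> mixdown_tracks
#   mixdown_tracks -> generate_waveform
#   mixdown_tracks -> detect_topics -> generate_title / generate_summary
#   [generate_waveform, generate_title, generate_summary] -> finalize
#   finalize -> cleanup_consent -> post_zulip -> send_webhook
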
# ============================================================================
# Helper Functions
# ============================================================================


async def _get_fresh_db_connection():
    """Create fresh database connection for subprocess."""
    import databases

    from reflector.db import _database_context
    from reflector.settings import settings

    _database_context.set(None)
    db = databases.Database(settings.DATABASE_URL)
    _database_context.set(db)
    await db.connect()
    return db


async def _close_db_connection(db):
    """Close database connection."""
    from reflector.db import _database_context

    await db.disconnect()
    _database_context.set(None)


def _get_storage():
    """Create fresh storage instance."""
    from reflector.settings import settings
    from reflector.storage.storage_aws import AwsStorage

    return AwsStorage(
        aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
        aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
    )


# ============================================================================
# Pipeline Tasks
# ============================================================================

@diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
async def get_recording(input: PipelineInput, ctx: Context) -> dict:
    """Fetch recording metadata from Daily.co API."""
    logger.info("[Hatchet] get_recording", recording_id=input.recording_id)

    await emit_progress_async(
        input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
    )

    try:
        from reflector.dailyco_api.client import DailyApiClient
        from reflector.settings import settings

        if not input.recording_id:
            # No recording_id in reprocess path - return minimal data
            await emit_progress_async(
                input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
            )
            return {
                "id": None,
                "mtg_session_id": None,
                "room_name": input.room_name,
                "duration": 0,
            }

        if not settings.DAILY_API_KEY:
            raise ValueError("DAILY_API_KEY not configured")

        async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
            recording = await client.get_recording(input.recording_id)

        logger.info(
            "[Hatchet] get_recording complete",
            recording_id=input.recording_id,
            room_name=recording.room_name,
            duration=recording.duration,
        )

        await emit_progress_async(
            input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
        )

        return {
            "id": recording.id,
            "mtg_session_id": recording.mtgSessionId,
            "room_name": recording.room_name,
            "duration": recording.duration,
        }

    except Exception as e:
        logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "get_recording", "failed", ctx.workflow_run_id
        )
        raise

@diarization_pipeline.task(
    parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3
)
async def get_participants(input: PipelineInput, ctx: Context) -> dict:
    """Fetch participant list from Daily.co API."""
    logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "get_participants", "in_progress", ctx.workflow_run_id
    )

    try:
        recording_data = ctx.task_output(get_recording)
        mtg_session_id = recording_data.get("mtg_session_id")

        from reflector.dailyco_api.client import DailyApiClient
        from reflector.settings import settings

        if not mtg_session_id or not settings.DAILY_API_KEY:
            # Return empty participants if no session ID
            await emit_progress_async(
                input.transcript_id,
                "get_participants",
                "completed",
                ctx.workflow_run_id,
            )
            return {"participants": [], "num_tracks": len(input.tracks)}

        async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
            participants = await client.get_meeting_participants(mtg_session_id)

        participants_list = [
            {"participant_id": p.participant_id, "user_name": p.user_name}
            for p in participants.data
        ]

        logger.info(
            "[Hatchet] get_participants complete",
            participant_count=len(participants_list),
        )

        await emit_progress_async(
            input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
        )

        return {"participants": participants_list, "num_tracks": len(input.tracks)}

    except Exception as e:
        logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "get_participants", "failed", ctx.workflow_run_id
        )
        raise

@diarization_pipeline.task(
    parents=[get_participants], execution_timeout=timedelta(seconds=600), retries=3
)
async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
    """Spawn child workflows for each track (dynamic fan-out).

    Processes pad_track and transcribe_track for each audio track in parallel.
    """
    logger.info(
        "[Hatchet] process_tracks",
        num_tracks=len(input.tracks),
        transcript_id=input.transcript_id,
    )

    # Spawn child workflows for each track
    child_coroutines = [
        track_workflow.aio_run(
            TrackInput(
                track_index=i,
                s3_key=track["s3_key"],
                bucket_name=input.bucket_name,
                transcript_id=input.transcript_id,
            )
        )
        for i, track in enumerate(input.tracks)
    ]

    # Wait for all child workflows to complete
    results = await asyncio.gather(*child_coroutines)

    # Collect all track results
    all_words = []
    padded_urls = []

    for result in results:
        transcribe_result = result.get("transcribe_track", {})
        all_words.extend(transcribe_result.get("words", []))

        pad_result = result.get("pad_track", {})
        padded_urls.append(pad_result.get("padded_url"))

    # Sort words by start time
    all_words.sort(key=lambda w: w.get("start", 0))

    logger.info(
        "[Hatchet] process_tracks complete",
        num_tracks=len(input.tracks),
        total_words=len(all_words),
    )

    return {
        "all_words": all_words,
        "padded_urls": padded_urls,
        "word_count": len(all_words),
        "num_tracks": len(input.tracks),
    }

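# Shape of each child-workflow result consumed above (a sketch; keys are the
# child task names and values are the dicts those tasks return, per
# track_processing.py later in this commit):
#
#   {
#       "pad_track": {"padded_url": "https://...", "size": 123456, "track_index": 0},
#       "transcribe_track": {"words": [{"start": 0.42, "speaker": 0, ...}], "track_index": 0},
#   }
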
@diarization_pipeline.task(
    parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3
)
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
    """Mix all padded tracks into single audio file."""
    logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "mixdown_tracks", "in_progress", ctx.workflow_run_id
    )

    try:
        track_data = ctx.task_output(process_tracks)
        padded_urls = track_data.get("padded_urls", [])

        if not padded_urls:
            raise ValueError("No padded tracks to mixdown")

        storage = _get_storage()

        # Download all tracks and mix
        temp_inputs = []
        try:
            for i, url in enumerate(padded_urls):
                if not url:
                    continue
                temp_input = tempfile.NamedTemporaryFile(suffix=".webm", delete=False)
                temp_inputs.append(temp_input.name)

                # Download track
                import httpx

                async with httpx.AsyncClient() as client:
                    response = await client.get(url)
                    response.raise_for_status()
                    with open(temp_input.name, "wb") as f:
                        f.write(response.content)

            # Mix using PyAV amix filter
            if len(temp_inputs) == 0:
                raise ValueError("No valid tracks to mixdown")

            output_path = tempfile.mktemp(suffix=".mp3")

            try:
                # Use ffmpeg-style mixing via PyAV
                containers = [av.open(path) for path in temp_inputs]

                # Get the longest duration
                max_duration = 0.0
                for container in containers:
                    if container.duration:
                        duration = float(container.duration * av.time_base)
                        max_duration = max(max_duration, duration)

                # Close containers for now
                for container in containers:
                    container.close()

                # Use subprocess for mixing (simpler than complex PyAV graph)
                import subprocess

                # Build ffmpeg command
                cmd = ["ffmpeg", "-y"]
                for path in temp_inputs:
                    cmd.extend(["-i", path])

                # Build filter for N inputs
                n = len(temp_inputs)
                filter_str = f"amix=inputs={n}:duration=longest:normalize=0"
                cmd.extend(["-filter_complex", filter_str])
                cmd.extend(["-ac", "2", "-ar", "48000", "-b:a", "128k", output_path])

                subprocess.run(cmd, check=True, capture_output=True)

                # Upload mixed file
                file_size = Path(output_path).stat().st_size
                storage_path = f"file_pipeline_hatchet/{input.transcript_id}/mixed.mp3"

                with open(output_path, "rb") as mixed_file:
                    await storage.put_file(storage_path, mixed_file)

                logger.info(
                    "[Hatchet] mixdown_tracks uploaded",
                    key=storage_path,
                    size=file_size,
                )

            finally:
                Path(output_path).unlink(missing_ok=True)

        finally:
            for path in temp_inputs:
                Path(path).unlink(missing_ok=True)

        await emit_progress_async(
            input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
        )

        return {
            "audio_key": storage_path,
            "duration": max_duration,
            "tracks_mixed": len(temp_inputs),
        }

    except Exception as e:
        logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id
        )
        raise

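# The subprocess call above amounts to this ffmpeg invocation (a sketch for two
# hypothetical temp inputs):
#
#   ffmpeg -y -i t0.webm -i t1.webm \
#       -filter_complex "amix=inputs=2:duration=longest:normalize=0" \
#       -ac 2 -ar 48000 -b:a 128k mixed.mp3
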
@diarization_pipeline.task(
    parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3
)
async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
    """Generate audio waveform visualization."""
    logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id
    )

    try:
        mixdown_data = ctx.task_output(mixdown_tracks)
        audio_key = mixdown_data.get("audio_key")

        storage = _get_storage()
        audio_url = await storage.get_file_url(
            audio_key,
            operation="get_object",
            expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
        )

        from reflector.pipelines.waveform_helpers import generate_waveform_data

        waveform = await generate_waveform_data(audio_url)

        # Store waveform
        waveform_key = f"file_pipeline_hatchet/{input.transcript_id}/waveform.json"
        import json

        waveform_bytes = json.dumps(waveform).encode()
        import io

        await storage.put_file(waveform_key, io.BytesIO(waveform_bytes))

        logger.info("[Hatchet] generate_waveform complete")

        await emit_progress_async(
            input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
        )

        return {"waveform_key": waveform_key}

    except Exception as e:
        logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id
        )
        raise

@diarization_pipeline.task(
    parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3
)
async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
    """Detect topics using LLM."""
    logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "detect_topics", "in_progress", ctx.workflow_run_id
    )

    try:
        track_data = ctx.task_output(process_tracks)
        words = track_data.get("all_words", [])

        from reflector.pipelines import topic_processing
        from reflector.processors.types import Transcript as TranscriptType
        from reflector.processors.types import Word

        # Convert word dicts to Word objects
        word_objects = [Word(**w) for w in words]
        transcript = TranscriptType(words=word_objects)

        empty_pipeline = topic_processing.EmptyPipeline(logger=logger)

        async def noop_callback(t):
            pass

        topics = await topic_processing.detect_topics(
            transcript,
            "en",  # target_language
            on_topic_callback=noop_callback,
            empty_pipeline=empty_pipeline,
        )

        topics_list = [t.model_dump() for t in topics]

        logger.info("[Hatchet] detect_topics complete", topic_count=len(topics_list))

        await emit_progress_async(
            input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id
        )

        return {"topics": topics_list}

    except Exception as e:
        logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id
        )
        raise

@diarization_pipeline.task(
    parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3
)
async def generate_title(input: PipelineInput, ctx: Context) -> dict:
    """Generate meeting title using LLM."""
    logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "generate_title", "in_progress", ctx.workflow_run_id
    )

    try:
        topics_data = ctx.task_output(detect_topics)
        topics = topics_data.get("topics", [])

        from reflector.pipelines import topic_processing
        from reflector.processors.types import Topic

        topic_objects = [Topic(**t) for t in topics]

        title = await topic_processing.generate_title(topic_objects)

        logger.info("[Hatchet] generate_title complete", title=title)

        await emit_progress_async(
            input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
        )

        return {"title": title}

    except Exception as e:
        logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "generate_title", "failed", ctx.workflow_run_id
        )
        raise

@diarization_pipeline.task(
    parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3
)
async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
    """Generate meeting summary using LLM."""
    logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "generate_summary", "in_progress", ctx.workflow_run_id
    )

    try:
        track_data = ctx.task_output(process_tracks)
        topics_data = ctx.task_output(detect_topics)

        words = track_data.get("all_words", [])
        topics = topics_data.get("topics", [])

        from reflector.pipelines import topic_processing
        from reflector.processors.types import Topic, Word
        from reflector.processors.types import Transcript as TranscriptType

        word_objects = [Word(**w) for w in words]
        transcript = TranscriptType(words=word_objects)
        topic_objects = [Topic(**t) for t in topics]

        summary, short_summary = await topic_processing.generate_summary(
            transcript, topic_objects
        )

        logger.info("[Hatchet] generate_summary complete")

        await emit_progress_async(
            input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
        )

        return {"summary": summary, "short_summary": short_summary}

    except Exception as e:
        logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id
        )
        raise

@diarization_pipeline.task(
    parents=[generate_waveform, generate_title, generate_summary],
    execution_timeout=timedelta(seconds=60),
    retries=3,
)
async def finalize(input: PipelineInput, ctx: Context) -> dict:
    """Finalize transcript status and update database."""
    logger.info("[Hatchet] finalize", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "finalize", "in_progress", ctx.workflow_run_id
    )

    try:
        title_data = ctx.task_output(generate_title)
        summary_data = ctx.task_output(generate_summary)
        mixdown_data = ctx.task_output(mixdown_tracks)
        track_data = ctx.task_output(process_tracks)

        title = title_data.get("title", "")
        summary = summary_data.get("summary", "")
        short_summary = summary_data.get("short_summary", "")
        duration = mixdown_data.get("duration", 0)
        all_words = track_data.get("all_words", [])

        db = await _get_fresh_db_connection()

        try:
            from reflector.db.transcripts import transcripts_controller
            from reflector.processors.types import Word

            transcript = await transcripts_controller.get_by_id(input.transcript_id)
            if transcript is None:
                raise ValueError(
                    f"Transcript {input.transcript_id} not found in database"
                )

            # Convert words back to Word objects for storage
            word_objects = [Word(**w) for w in all_words]

            await transcripts_controller.update(
                transcript,
                {
                    "status": "ended",
                    "title": title,
                    "long_summary": summary,
                    "short_summary": short_summary,
                    "duration": duration,
                    "words": word_objects,
                },
            )

            logger.info(
                "[Hatchet] finalize complete", transcript_id=input.transcript_id
            )

        finally:
            await _close_db_connection(db)

        await emit_progress_async(
            input.transcript_id, "finalize", "completed", ctx.workflow_run_id
        )

        return {"status": "COMPLETED"}

    except Exception as e:
        logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "finalize", "failed", ctx.workflow_run_id
        )
        raise

@diarization_pipeline.task(
    parents=[finalize], execution_timeout=timedelta(seconds=60), retries=3
)
async def cleanup_consent(input: PipelineInput, ctx: Context) -> dict:
    """Check and handle consent requirements."""
    logger.info("[Hatchet] cleanup_consent", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "cleanup_consent", "in_progress", ctx.workflow_run_id
    )

    try:
        db = await _get_fresh_db_connection()

        try:
            from reflector.db.meetings import meetings_controller
            from reflector.db.transcripts import transcripts_controller

            transcript = await transcripts_controller.get_by_id(input.transcript_id)
            if transcript and transcript.meeting_id:
                meeting = await meetings_controller.get_by_id(transcript.meeting_id)
                if meeting:
                    # Check consent logic here
                    # For now just mark as checked
                    pass

            logger.info(
                "[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id
            )

        finally:
            await _close_db_connection(db)

        await emit_progress_async(
            input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id
        )

        return {"consent_checked": True}

    except Exception as e:
        logger.error("[Hatchet] cleanup_consent failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "cleanup_consent", "failed", ctx.workflow_run_id
        )
        raise

@diarization_pipeline.task(
    parents=[cleanup_consent], execution_timeout=timedelta(seconds=60), retries=5
)
async def post_zulip(input: PipelineInput, ctx: Context) -> dict:
    """Post notification to Zulip."""
    logger.info("[Hatchet] post_zulip", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "post_zulip", "in_progress", ctx.workflow_run_id
    )

    try:
        from reflector.settings import settings

        if not settings.ZULIP_REALM:
            logger.info("[Hatchet] post_zulip skipped (Zulip not configured)")
            await emit_progress_async(
                input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
            )
            return {"zulip_message_id": None, "skipped": True}

        from reflector.zulip import post_transcript_notification

        db = await _get_fresh_db_connection()

        try:
            from reflector.db.transcripts import transcripts_controller

            transcript = await transcripts_controller.get_by_id(input.transcript_id)
            if transcript:
                message_id = await post_transcript_notification(transcript)
                logger.info(
                    "[Hatchet] post_zulip complete", zulip_message_id=message_id
                )
            else:
                message_id = None

        finally:
            await _close_db_connection(db)

        await emit_progress_async(
            input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
        )

        return {"zulip_message_id": message_id}

    except Exception as e:
        logger.error("[Hatchet] post_zulip failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "post_zulip", "failed", ctx.workflow_run_id
        )
        raise

@diarization_pipeline.task(
    parents=[post_zulip], execution_timeout=timedelta(seconds=120), retries=30
)
async def send_webhook(input: PipelineInput, ctx: Context) -> dict:
    """Send completion webhook to external service."""
    logger.info("[Hatchet] send_webhook", transcript_id=input.transcript_id)

    await emit_progress_async(
        input.transcript_id, "send_webhook", "in_progress", ctx.workflow_run_id
    )

    try:
        if not input.room_id:
            logger.info("[Hatchet] send_webhook skipped (no room_id)")
            await emit_progress_async(
                input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
            )
            return {"webhook_sent": False, "skipped": True}

        db = await _get_fresh_db_connection()

        try:
            from reflector.db.rooms import rooms_controller
            from reflector.db.transcripts import transcripts_controller

            room = await rooms_controller.get_by_id(input.room_id)
            transcript = await transcripts_controller.get_by_id(input.transcript_id)

            if room and room.webhook_url and transcript:
                import httpx

                webhook_payload = {
                    "event": "transcript.completed",
                    "transcript_id": input.transcript_id,
                    "title": transcript.title,
                    "duration": transcript.duration,
                }

                async with httpx.AsyncClient() as client:
                    response = await client.post(
                        room.webhook_url, json=webhook_payload, timeout=30
                    )
                    response.raise_for_status()

                logger.info(
                    "[Hatchet] send_webhook complete", status_code=response.status_code
                )

                await emit_progress_async(
                    input.transcript_id,
                    "send_webhook",
                    "completed",
                    ctx.workflow_run_id,
                )

                return {"webhook_sent": True, "response_code": response.status_code}

        finally:
            await _close_db_connection(db)

        await emit_progress_async(
            input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
        )

        return {"webhook_sent": False, "skipped": True}

    except Exception as e:
        logger.error("[Hatchet] send_webhook failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "send_webhook", "failed", ctx.workflow_run_id
        )
        raise
337
server/reflector/hatchet/workflows/track_processing.py
Normal file
@@ -0,0 +1,337 @@
"""
Hatchet child workflow: TrackProcessing

Handles individual audio track processing: padding and transcription.
Spawned dynamically by the main diarization pipeline for each track.
"""

import math
import tempfile
from datetime import timedelta
from fractions import Fraction
from pathlib import Path

import av
from av.audio.resampler import AudioResampler
from hatchet_sdk import Context
from pydantic import BaseModel

from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress_async
from reflector.logger import logger

# Audio constants matching existing pipeline
OPUS_STANDARD_SAMPLE_RATE = 48000
OPUS_DEFAULT_BIT_RATE = 64000
PRESIGNED_URL_EXPIRATION_SECONDS = 7200


class TrackInput(BaseModel):
    """Input for individual track processing."""

    track_index: int
    s3_key: str
    bucket_name: str
    transcript_id: str
    language: str = "en"


# Get hatchet client and define workflow
hatchet = HatchetClientManager.get_client()

track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)

def _extract_stream_start_time_from_container(container, track_idx: int) -> float:
    """Extract meeting-relative start time from WebM stream metadata.

    Uses PyAV to read stream.start_time from WebM container.
    More accurate than filename timestamps by ~209ms due to network/encoding delays.
    """
    start_time_seconds = 0.0
    try:
        audio_streams = [s for s in container.streams if s.type == "audio"]
        stream = audio_streams[0] if audio_streams else container.streams[0]

        # 1) Try stream-level start_time (most reliable for Daily.co tracks)
        if stream.start_time is not None and stream.time_base is not None:
            start_time_seconds = float(stream.start_time * stream.time_base)

        # 2) Fallback to container-level start_time
        if (start_time_seconds <= 0) and (container.start_time is not None):
            start_time_seconds = float(container.start_time * av.time_base)

        # 3) Fallback to first packet DTS
        if start_time_seconds <= 0:
            for packet in container.demux(stream):
                if packet.dts is not None:
                    start_time_seconds = float(packet.dts * stream.time_base)
                    break
    except Exception as e:
        logger.warning(
            "PyAV metadata read failed; assuming 0 start_time",
            track_idx=track_idx,
            error=str(e),
        )
        start_time_seconds = 0.0

    logger.info(
        f"Track {track_idx} stream metadata: start_time={start_time_seconds:.3f}s",
        track_idx=track_idx,
    )
    return start_time_seconds

def _apply_audio_padding_to_file(
    in_container,
    output_path: str,
    start_time_seconds: float,
    track_idx: int,
) -> None:
    """Apply silence padding to audio track using PyAV filter graph."""
    delay_ms = math.floor(start_time_seconds * 1000)

    logger.info(
        f"Padding track {track_idx} with {delay_ms}ms delay using PyAV",
        track_idx=track_idx,
        delay_ms=delay_ms,
    )

    with av.open(output_path, "w", format="webm") as out_container:
        in_stream = next((s for s in in_container.streams if s.type == "audio"), None)
        if in_stream is None:
            raise Exception("No audio stream in input")

        out_stream = out_container.add_stream("libopus", rate=OPUS_STANDARD_SAMPLE_RATE)
        out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE
        graph = av.filter.Graph()

        abuf_args = (
            f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:"
            f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:"
            f"sample_fmt=s16:"
            f"channel_layout=stereo"
        )
        src = graph.add("abuffer", args=abuf_args, name="src")
        aresample_f = graph.add("aresample", args="async=1", name="ares")
        delays_arg = f"{delay_ms}|{delay_ms}"
        adelay_f = graph.add("adelay", args=f"delays={delays_arg}:all=1", name="delay")
        sink = graph.add("abuffersink", name="sink")

        src.link_to(aresample_f)
        aresample_f.link_to(adelay_f)
        adelay_f.link_to(sink)
        graph.configure()

        resampler = AudioResampler(
            format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE
        )

        for frame in in_container.decode(in_stream):
            out_frames = resampler.resample(frame) or []
            for rframe in out_frames:
                rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE
                rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
                src.push(rframe)

                while True:
                    try:
                        f_out = sink.pull()
                    except Exception:
                        break
                    f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
                    f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
                    for packet in out_stream.encode(f_out):
                        out_container.mux(packet)

        # Flush remaining frames
        src.push(None)
        while True:
            try:
                f_out = sink.pull()
            except Exception:
                break
            f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
            f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
            for packet in out_stream.encode(f_out):
                out_container.mux(packet)

        for packet in out_stream.encode(None):
            out_container.mux(packet)

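# Filter graph assembled above (sketch):
#
#   abuffer (s16 / stereo / 48 kHz) -> aresample=async=1
#       -> adelay=<delay_ms>|<delay_ms>:all=1 -> abuffersink
#
# After src.push(None) signals end-of-stream, the second pull loop drains frames
# still buffered in the graph, and out_stream.encode(None) flushes the encoder.
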
@track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
async def pad_track(input: TrackInput, ctx: Context) -> dict:
    """Pad single audio track with silence for alignment.

    Extracts stream.start_time from WebM container metadata and applies
    silence padding using PyAV filter graph (adelay).
    """
    logger.info(
        "[Hatchet] pad_track",
        track_index=input.track_index,
        s3_key=input.s3_key,
        transcript_id=input.transcript_id,
    )

    await emit_progress_async(
        input.transcript_id, "pad_track", "in_progress", ctx.workflow_run_id
    )

    try:
        # Create fresh storage instance to avoid aioboto3 fork issues
        from reflector.settings import settings
        from reflector.storage.storage_aws import AwsStorage

        storage = AwsStorage(
            aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
            aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
            aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
            aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
        )

        # Get presigned URL for source file
        source_url = await storage.get_file_url(
            input.s3_key,
            operation="get_object",
            expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
            bucket=input.bucket_name,
        )

        # Open container and extract start time
        with av.open(source_url) as in_container:
            start_time_seconds = _extract_stream_start_time_from_container(
                in_container, input.track_index
            )

            # If no padding needed, return original URL
            if start_time_seconds <= 0:
                logger.info(
                    f"Track {input.track_index} requires no padding",
                    track_index=input.track_index,
                )
                await emit_progress_async(
                    input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
                )
                return {
                    "padded_url": source_url,
                    "size": 0,
                    "track_index": input.track_index,
                }

            # Create temp file for padded output
            with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
                temp_path = temp_file.name

            try:
                _apply_audio_padding_to_file(
                    in_container, temp_path, start_time_seconds, input.track_index
                )

                file_size = Path(temp_path).stat().st_size
                storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{input.track_index}.webm"

                logger.info(
                    "About to upload padded track",
                    key=storage_path,
                    size=file_size,
                )

                with open(temp_path, "rb") as padded_file:
                    await storage.put_file(storage_path, padded_file)

                logger.info(
                    "Uploaded padded track to S3",
                    key=storage_path,
                    size=file_size,
                )
            finally:
                Path(temp_path).unlink(missing_ok=True)

        # Get presigned URL for padded file
        padded_url = await storage.get_file_url(
            storage_path,
            operation="get_object",
            expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
        )

        logger.info(
            "[Hatchet] pad_track complete",
            track_index=input.track_index,
            padded_url=padded_url[:50] + "...",
        )

        await emit_progress_async(
            input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
        )

        return {
            "padded_url": padded_url,
            "size": file_size,
            "track_index": input.track_index,
        }

    except Exception as e:
        logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "pad_track", "failed", ctx.workflow_run_id
        )
        raise

@track_workflow.task(
    parents=[pad_track], execution_timeout=timedelta(seconds=600), retries=3
)
async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
    """Transcribe audio track using GPU (Modal.com) or local Whisper."""
    logger.info(
        "[Hatchet] transcribe_track",
        track_index=input.track_index,
        language=input.language,
    )

    await emit_progress_async(
        input.transcript_id, "transcribe_track", "in_progress", ctx.workflow_run_id
    )

    try:
        pad_result = ctx.task_output(pad_track)
        audio_url = pad_result.get("padded_url")

        if not audio_url:
            raise ValueError("Missing padded_url from pad_track")

        from reflector.pipelines.transcription_helpers import (
            transcribe_file_with_processor,
        )

        transcript = await transcribe_file_with_processor(audio_url, input.language)

        # Tag all words with speaker index
        words = []
        for word in transcript.words:
            word_dict = word.model_dump()
            word_dict["speaker"] = input.track_index
            words.append(word_dict)

        logger.info(
            "[Hatchet] transcribe_track complete",
            track_index=input.track_index,
            word_count=len(words),
        )

        await emit_progress_async(
            input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
        )

        return {
            "words": words,
            "track_index": input.track_index,
        }

    except Exception as e:
        logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)
        await emit_progress_async(
            input.transcript_id, "transcribe_track", "failed", ctx.workflow_run_id
        )
        raise
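
# Each dict appended above is the transcriber's word dump plus the track index
# as the diarization speaker label; roughly (field names other than "speaker"
# are assumptions about the Word model):
#
#   {"text": "hello", "start": 0.42, "end": 0.61, "speaker": 0}
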
@@ -15,6 +15,7 @@ from celery.result import AsyncResult
from reflector.conductor.client import ConductorClientManager
from reflector.db.recordings import recordings_controller
from reflector.db.transcripts import Transcript
from reflector.hatchet.client import HatchetClientManager
from reflector.logger import logger
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_multitrack_pipeline import (
@@ -156,8 +157,47 @@ async def prepare_transcript_processing(

def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | None:
    if isinstance(config, MultitrackProcessingConfig):
        # Start Conductor workflow if enabled
        if settings.CONDUCTOR_ENABLED:
        # Start durable workflow if enabled (Hatchet or Conductor)
        durable_started = False

        if settings.HATCHET_ENABLED:
            import asyncio

            async def _start_hatchet():
                return await HatchetClientManager.start_workflow(
                    workflow_name="DiarizationPipeline",
                    input_data={
                        "recording_id": config.recording_id,
                        "room_name": None,  # Not available in reprocess path
                        "tracks": [{"s3_key": k} for k in config.track_keys],
                        "bucket_name": config.bucket_name,
                        "transcript_id": config.transcript_id,
                        "room_id": config.room_id,
                    },
                )

            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None

            if loop and loop.is_running():
                # Already in async context
                import concurrent.futures

                with concurrent.futures.ThreadPoolExecutor() as pool:
                    workflow_id = pool.submit(asyncio.run, _start_hatchet()).result()
            else:
                workflow_id = asyncio.run(_start_hatchet())

            logger.info(
                "Started Hatchet workflow (reprocess)",
                workflow_id=workflow_id,
                transcript_id=config.transcript_id,
            )
            durable_started = True

        elif settings.CONDUCTOR_ENABLED:
            workflow_id = ConductorClientManager.start_workflow(
                name="diarization_pipeline",
                version=1,
@@ -175,11 +215,13 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | No
                workflow_id=workflow_id,
                transcript_id=config.transcript_id,
            )
            durable_started = True

        if not settings.CONDUCTOR_SHADOW_MODE:
            return None  # Conductor-only, no Celery result
        # If durable workflow started and not in shadow mode, skip Celery
        if durable_started and not settings.DURABLE_WORKFLOW_SHADOW_MODE:
            return None

        # Celery pipeline (shadow mode or Conductor disabled)
        # Celery pipeline (shadow mode or durable workflows disabled)
        return task_pipeline_multitrack_process.delay(
            transcript_id=config.transcript_id,
            bucket_name=config.bucket_name,

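The thread-pool bridge above is the standard way to call async code from a sync
path that may already sit inside a running event loop; a minimal standalone
sketch (names hypothetical):

import asyncio
import concurrent.futures


async def _work() -> str:
    return "done"


def run_async_from_sync() -> str:
    # asyncio.run() raises if this thread already has a running loop, so hand
    # the coroutine to a fresh thread that owns its own loop.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop and loop.is_running():
        with concurrent.futures.ThreadPoolExecutor() as pool:
            return pool.submit(asyncio.run, _work()).result()
    return asyncio.run(_work())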
@@ -150,11 +150,34 @@ class Settings(BaseSettings):
    ZULIP_API_KEY: str | None = None
    ZULIP_BOT_EMAIL: str | None = None

    # Durable workflow orchestration
    # Provider: "hatchet" or "conductor" (or "none" to disable)
    DURABLE_WORKFLOW_PROVIDER: str = "none"
    DURABLE_WORKFLOW_SHADOW_MODE: bool = False  # Run both provider + Celery

    # Conductor workflow orchestration
    CONDUCTOR_SERVER_URL: str = "http://conductor:8080/api"
    CONDUCTOR_DEBUG: bool = False
    CONDUCTOR_ENABLED: bool = False
    CONDUCTOR_SHADOW_MODE: bool = False

    # Hatchet workflow orchestration
    HATCHET_CLIENT_TOKEN: str | None = None
    HATCHET_CLIENT_TLS_STRATEGY: str = "none"  # none, tls, mtls
    HATCHET_DEBUG: bool = False

    @property
    def CONDUCTOR_ENABLED(self) -> bool:
        """Legacy compatibility: True if Conductor is the active provider."""
        return self.DURABLE_WORKFLOW_PROVIDER == "conductor"

    @property
    def HATCHET_ENABLED(self) -> bool:
        """True if Hatchet is the active provider."""
        return self.DURABLE_WORKFLOW_PROVIDER == "hatchet"

    @property
    def CONDUCTOR_SHADOW_MODE(self) -> bool:
        """Legacy compatibility for shadow mode."""
        return self.DURABLE_WORKFLOW_SHADOW_MODE and self.CONDUCTOR_ENABLED


settings = Settings()

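A hedged sketch of how the provider switch resolves (pydantic-settings lets
constructor kwargs stand in for environment variables; values hypothetical):

from reflector.settings import Settings

s = Settings(DURABLE_WORKFLOW_PROVIDER="hatchet", HATCHET_CLIENT_TOKEN="tok")
assert s.HATCHET_ENABLED            # provider == "hatchet"
assert not s.CONDUCTOR_ENABLED      # legacy flag, now derived from the provider
assert not s.CONDUCTOR_SHADOW_MODE  # shadow mode only counts while Conductor is active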
57
server/reflector/views/hatchet.py
Normal file
@@ -0,0 +1,57 @@
"""Hatchet health and status endpoints."""

from fastapi import APIRouter

from reflector.settings import settings

router = APIRouter(prefix="/hatchet", tags=["hatchet"])


@router.get("/health")
async def hatchet_health():
    """Check Hatchet connectivity and status."""
    if not settings.HATCHET_ENABLED:
        return {"status": "disabled", "connected": False}

    if not settings.HATCHET_CLIENT_TOKEN:
        return {
            "status": "unhealthy",
            "connected": False,
            "error": "HATCHET_CLIENT_TOKEN not configured",
        }

    try:
        from reflector.hatchet.client import HatchetClientManager

        # Get client to verify token is valid
        client = HatchetClientManager.get_client()

        # Try to get the client's gRPC connection status
        # The SDK doesn't have a simple health check, so we just verify we can create the client
        if client is not None:
            return {"status": "healthy", "connected": True}
        else:
            return {
                "status": "unhealthy",
                "connected": False,
                "error": "Failed to create client",
            }
    except ValueError as e:
        return {"status": "unhealthy", "connected": False, "error": str(e)}
    except Exception as e:
        return {"status": "unhealthy", "connected": False, "error": str(e)}


@router.get("/workflow/{workflow_run_id}")
async def get_workflow_status(workflow_run_id: str):
    """Get the status of a workflow run."""
    if not settings.HATCHET_ENABLED:
        return {"error": "Hatchet is disabled"}

    try:
        from reflector.hatchet.client import HatchetClientManager

        status = await HatchetClientManager.get_workflow_status(workflow_run_id)
        return status
    except Exception as e:
        return {"error": str(e)}
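A sketch of exercising these endpoints, assuming the router is mounted under
/v1 as in the app wiring above and a server on localhost:8000 (host, port, and
run ID hypothetical):

import httpx

base = "http://localhost:8000/v1/hatchet"
print(httpx.get(f"{base}/health").json())
# e.g. {"status": "healthy", "connected": True}
print(httpx.get(f"{base}/workflow/wf-789").json())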
@@ -286,8 +286,34 @@ async def _process_multitrack_recording_inner(
        room_id=room.id,
    )

    # Start Conductor workflow if enabled
    if settings.CONDUCTOR_ENABLED:
    # Start durable workflow if enabled (Hatchet or Conductor)
    durable_started = False

    if settings.HATCHET_ENABLED:
        from reflector.hatchet.client import HatchetClientManager  # noqa: PLC0415

        workflow_id = await HatchetClientManager.start_workflow(
            workflow_name="DiarizationPipeline",
            input_data={
                "recording_id": recording_id,
                "room_name": daily_room_name,
                "tracks": [{"s3_key": k} for k in filter_cam_audio_tracks(track_keys)],
                "bucket_name": bucket_name,
                "transcript_id": transcript.id,
                "room_id": room.id,
            },
        )
        logger.info(
            "Started Hatchet workflow",
            workflow_id=workflow_id,
            transcript_id=transcript.id,
        )

        # Store workflow_id on recording for status tracking
        await recordings_controller.update(recording, {"workflow_id": workflow_id})
        durable_started = True

    elif settings.CONDUCTOR_ENABLED:
        from reflector.conductor.client import ConductorClientManager  # noqa: PLC0415

        workflow_id = ConductorClientManager.start_workflow(
@@ -310,11 +336,13 @@ async def _process_multitrack_recording_inner(

        # Store workflow_id on recording for status tracking
        await recordings_controller.update(recording, {"workflow_id": workflow_id})
        durable_started = True

    if not settings.CONDUCTOR_SHADOW_MODE:
        return  # Don't trigger Celery
    # If durable workflow started and not in shadow mode, skip Celery
    if durable_started and not settings.DURABLE_WORKFLOW_SHADOW_MODE:
        return

    # Celery pipeline (runs when Conductor disabled OR in shadow mode)
    # Celery pipeline (runs when durable workflows disabled OR in shadow mode)
    task_pipeline_multitrack_process.delay(
        transcript_id=transcript.id,
        bucket_name=bucket_name,
