feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine (#911)

* feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine

* fix: always force reprocessing

* fix: ci tests with live pipelines

* fix: ci tests with live pipelines
This commit is contained in:
Juan Diego García
2026-03-16 16:07:16 -05:00
committed by GitHub
parent 72dca7cacc
commit 37a1f01850
22 changed files with 2140 additions and 353 deletions

View File

@@ -51,6 +51,9 @@ services:
HF_TOKEN: ${HF_TOKEN:-} HF_TOKEN: ${HF_TOKEN:-}
# WebRTC: fixed UDP port range for ICE candidates (mapped above) # WebRTC: fixed UDP port range for ICE candidates (mapped above)
WEBRTC_PORT_RANGE: "51000-51100" WEBRTC_PORT_RANGE: "51000-51100"
# Hatchet workflow engine (always-on for processing pipelines)
HATCHET_CLIENT_SERVER_URL: ${HATCHET_CLIENT_SERVER_URL:-http://hatchet:8888}
HATCHET_CLIENT_HOST_PORT: ${HATCHET_CLIENT_HOST_PORT:-hatchet:7077}
depends_on: depends_on:
postgres: postgres:
condition: service_healthy condition: service_healthy
@@ -75,6 +78,9 @@ services:
CELERY_RESULT_BACKEND: redis://redis:6379/1 CELERY_RESULT_BACKEND: redis://redis:6379/1
# ML backend config comes from env_file (server/.env), set per-mode by setup script # ML backend config comes from env_file (server/.env), set per-mode by setup script
HF_TOKEN: ${HF_TOKEN:-} HF_TOKEN: ${HF_TOKEN:-}
# Hatchet workflow engine (always-on for processing pipelines)
HATCHET_CLIENT_SERVER_URL: ${HATCHET_CLIENT_SERVER_URL:-http://hatchet:8888}
HATCHET_CLIENT_HOST_PORT: ${HATCHET_CLIENT_HOST_PORT:-hatchet:7077}
depends_on: depends_on:
postgres: postgres:
condition: service_healthy condition: service_healthy
@@ -126,6 +132,8 @@ services:
redis: redis:
image: redis:7.2-alpine image: redis:7.2-alpine
restart: unless-stopped restart: unless-stopped
ports:
- "6379:6379"
healthcheck: healthcheck:
test: ["CMD", "redis-cli", "ping"] test: ["CMD", "redis-cli", "ping"]
interval: 30s interval: 30s
@@ -301,20 +309,20 @@ services:
- server - server
# =========================================================== # ===========================================================
# Hatchet + Daily.co workers (optional — for Daily.co multitrack processing) # Hatchet workflow engine + workers
# Auto-enabled when DAILY_API_KEY is configured in server/r # Required for all processing pipelines (file, live, Daily.co multitrack).
# Always-on — every selfhosted deployment needs Hatchet.
# =========================================================== # ===========================================================
hatchet: hatchet:
image: ghcr.io/hatchet-dev/hatchet/hatchet-lite:latest image: ghcr.io/hatchet-dev/hatchet/hatchet-lite:latest
profiles: [dailyco]
restart: on-failure restart: on-failure
depends_on: depends_on:
postgres: postgres:
condition: service_healthy condition: service_healthy
ports: ports:
- "8888:8888" - "127.0.0.1:8888:8888"
- "7078:7077" - "127.0.0.1:7078:7077"
env_file: env_file:
- ./.env.hatchet - ./.env.hatchet
environment: environment:
@@ -363,7 +371,6 @@ services:
context: ./server context: ./server
dockerfile: Dockerfile dockerfile: Dockerfile
image: monadicalsas/reflector-backend:latest image: monadicalsas/reflector-backend:latest
profiles: [dailyco]
restart: unless-stopped restart: unless-stopped
env_file: env_file:
- ./server/.env - ./server/.env

View File

@@ -261,9 +261,11 @@ if [[ -z "$MODEL_MODE" ]]; then
fi fi
# Build profiles list — one profile per feature # Build profiles list — one profile per feature
# Only --gpu needs a compose profile; --cpu and --hosted use in-process/remote backends # Hatchet + hatchet-worker-llm are always-on (no profile needed).
# gpu/cpu profiles only control the ML container (transcription service).
COMPOSE_PROFILES=() COMPOSE_PROFILES=()
[[ "$MODEL_MODE" == "gpu" ]] && COMPOSE_PROFILES+=("gpu") [[ "$MODEL_MODE" == "gpu" ]] && COMPOSE_PROFILES+=("gpu")
[[ "$MODEL_MODE" == "cpu" ]] && COMPOSE_PROFILES+=("cpu")
[[ -n "$OLLAMA_MODE" ]] && COMPOSE_PROFILES+=("$OLLAMA_MODE") [[ -n "$OLLAMA_MODE" ]] && COMPOSE_PROFILES+=("$OLLAMA_MODE")
[[ "$USE_GARAGE" == "true" ]] && COMPOSE_PROFILES+=("garage") [[ "$USE_GARAGE" == "true" ]] && COMPOSE_PROFILES+=("garage")
[[ "$USE_CADDY" == "true" ]] && COMPOSE_PROFILES+=("caddy") [[ "$USE_CADDY" == "true" ]] && COMPOSE_PROFILES+=("caddy")
@@ -557,12 +559,10 @@ step_server_env() {
ok "CPU mode — file processing timeouts set to 3600s (1 hour)" ok "CPU mode — file processing timeouts set to 3600s (1 hour)"
fi fi
# If Daily.co is manually configured, ensure Hatchet connectivity vars are set # Hatchet is always required (file, live, and multitrack pipelines all use it)
if env_has_key "$SERVER_ENV" "DAILY_API_KEY" && [[ -n "$(env_get "$SERVER_ENV" "DAILY_API_KEY")" ]]; then
env_set "$SERVER_ENV" "HATCHET_CLIENT_SERVER_URL" "http://hatchet:8888" env_set "$SERVER_ENV" "HATCHET_CLIENT_SERVER_URL" "http://hatchet:8888"
env_set "$SERVER_ENV" "HATCHET_CLIENT_HOST_PORT" "hatchet:7077" env_set "$SERVER_ENV" "HATCHET_CLIENT_HOST_PORT" "hatchet:7077"
ok "Daily.co detected — Hatchet connectivity configured" ok "Hatchet connectivity configured (workflow engine for processing pipelines)"
fi
ok "server/.env ready" ok "server/.env ready"
} }
@@ -886,15 +886,22 @@ step_services() {
compose_cmd pull server web || warn "Pull failed — using cached images" compose_cmd pull server web || warn "Pull failed — using cached images"
fi fi
# Build hatchet workers if Daily.co is configured (same backend image) # Hatchet is always needed (all processing pipelines use it)
if [[ "$DAILY_DETECTED" == "true" ]] && [[ "$BUILD_IMAGES" == "true" ]]; then local NEEDS_HATCHET=true
# Build hatchet workers if Hatchet is needed (same backend image)
if [[ "$NEEDS_HATCHET" == "true" ]] && [[ "$BUILD_IMAGES" == "true" ]]; then
info "Building Hatchet worker images..." info "Building Hatchet worker images..."
if [[ "$DAILY_DETECTED" == "true" ]]; then
compose_cmd build hatchet-worker-cpu hatchet-worker-llm compose_cmd build hatchet-worker-cpu hatchet-worker-llm
else
compose_cmd build hatchet-worker-llm
fi
ok "Hatchet worker images built" ok "Hatchet worker images built"
fi fi
# Ensure hatchet database exists before starting hatchet (init-hatchet-db.sql only runs on fresh postgres volumes) # Ensure hatchet database exists before starting hatchet (init-hatchet-db.sql only runs on fresh postgres volumes)
if [[ "$DAILY_DETECTED" == "true" ]]; then if [[ "$NEEDS_HATCHET" == "true" ]]; then
info "Ensuring postgres is running for Hatchet database setup..." info "Ensuring postgres is running for Hatchet database setup..."
compose_cmd up -d postgres compose_cmd up -d postgres
local pg_ready=false local pg_ready=false
@@ -1049,8 +1056,7 @@ step_health() {
fi fi
fi fi
# Hatchet (if Daily.co detected) # Hatchet (always-on)
if [[ "$DAILY_DETECTED" == "true" ]]; then
info "Waiting for Hatchet workflow engine..." info "Waiting for Hatchet workflow engine..."
local hatchet_ok=false local hatchet_ok=false
for i in $(seq 1 60); do for i in $(seq 1 60); do
@@ -1067,7 +1073,6 @@ step_health() {
else else
warn "Hatchet not ready yet. Check: docker compose logs hatchet" warn "Hatchet not ready yet. Check: docker compose logs hatchet"
fi fi
fi
# LLM warning for non-Ollama modes # LLM warning for non-Ollama modes
if [[ "$USES_OLLAMA" == "false" ]]; then if [[ "$USES_OLLAMA" == "false" ]]; then
@@ -1087,12 +1092,10 @@ step_health() {
} }
# ========================================================= # =========================================================
# Step 8: Hatchet token generation (Daily.co only) # Step 8: Hatchet token generation (gpu/cpu/Daily.co)
# ========================================================= # =========================================================
step_hatchet_token() { step_hatchet_token() {
if [[ "$DAILY_DETECTED" != "true" ]]; then # Hatchet is always required — no gating needed
return
fi
# Skip if token already set # Skip if token already set
if env_has_key "$SERVER_ENV" "HATCHET_CLIENT_TOKEN" && [[ -n "$(env_get "$SERVER_ENV" "HATCHET_CLIENT_TOKEN")" ]]; then if env_has_key "$SERVER_ENV" "HATCHET_CLIENT_TOKEN" && [[ -n "$(env_get "$SERVER_ENV" "HATCHET_CLIENT_TOKEN")" ]]; then
@@ -1147,7 +1150,9 @@ step_hatchet_token() {
# Restart services that need the token # Restart services that need the token
info "Restarting services with new Hatchet token..." info "Restarting services with new Hatchet token..."
compose_cmd restart server worker hatchet-worker-cpu hatchet-worker-llm local restart_services="server worker hatchet-worker-llm"
[[ "$DAILY_DETECTED" == "true" ]] && restart_services="$restart_services hatchet-worker-cpu"
compose_cmd restart $restart_services
ok "Services restarted with Hatchet token" ok "Services restarted with Hatchet token"
} }
@@ -1216,8 +1221,7 @@ main() {
ok "Daily.co detected — enabling Hatchet workflow services" ok "Daily.co detected — enabling Hatchet workflow services"
fi fi
# Generate .env.hatchet for hatchet dashboard config # Generate .env.hatchet for hatchet dashboard config (always needed)
if [[ "$DAILY_DETECTED" == "true" ]]; then
local hatchet_server_url hatchet_cookie_domain local hatchet_server_url hatchet_cookie_domain
if [[ -n "$CUSTOM_DOMAIN" ]]; then if [[ -n "$CUSTOM_DOMAIN" ]]; then
hatchet_server_url="https://${CUSTOM_DOMAIN}:8888" hatchet_server_url="https://${CUSTOM_DOMAIN}:8888"
@@ -1234,10 +1238,6 @@ SERVER_URL=$hatchet_server_url
SERVER_AUTH_COOKIE_DOMAIN=$hatchet_cookie_domain SERVER_AUTH_COOKIE_DOMAIN=$hatchet_cookie_domain
EOF EOF
ok "Generated .env.hatchet (dashboard URL=$hatchet_server_url)" ok "Generated .env.hatchet (dashboard URL=$hatchet_server_url)"
else
# Create empty .env.hatchet so compose doesn't fail if dailyco profile is ever activated manually
touch "$ROOT_DIR/.env.hatchet"
fi
step_www_env step_www_env
echo "" echo ""

View File

@@ -116,6 +116,7 @@ source = ["reflector"]
ENVIRONMENT = "pytest" ENVIRONMENT = "pytest"
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test" DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
AUTH_BACKEND = "jwt" AUTH_BACKEND = "jwt"
HATCHET_CLIENT_TOKEN = "test-dummy-token"
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v" addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"

View File

@@ -26,6 +26,21 @@ class TaskName(StrEnum):
DETECT_CHUNK_TOPIC = "detect_chunk_topic" DETECT_CHUNK_TOPIC = "detect_chunk_topic"
GENERATE_DETAILED_SUMMARY = "generate_detailed_summary" GENERATE_DETAILED_SUMMARY = "generate_detailed_summary"
# File pipeline tasks
EXTRACT_AUDIO = "extract_audio"
UPLOAD_AUDIO = "upload_audio"
TRANSCRIBE = "transcribe"
DIARIZE = "diarize"
ASSEMBLE_TRANSCRIPT = "assemble_transcript"
GENERATE_SUMMARIES = "generate_summaries"
# Live post-processing pipeline tasks
WAVEFORM = "waveform"
CONVERT_MP3 = "convert_mp3"
UPLOAD_MP3 = "upload_mp3"
REMOVE_UPLOAD = "remove_upload"
FINAL_SUMMARIES = "final_summaries"
# Rate limit key for LLM API calls (shared across all LLM-calling tasks) # Rate limit key for LLM API calls (shared across all LLM-calling tasks)
LLM_RATE_LIMIT_KEY = "llm" LLM_RATE_LIMIT_KEY = "llm"

View File

@@ -10,6 +10,8 @@ from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.workflows.daily_multitrack_pipeline import ( from reflector.hatchet.workflows.daily_multitrack_pipeline import (
daily_multitrack_pipeline, daily_multitrack_pipeline,
) )
from reflector.hatchet.workflows.file_pipeline import file_pipeline
from reflector.hatchet.workflows.live_post_pipeline import live_post_pipeline
from reflector.hatchet.workflows.subject_processing import subject_workflow from reflector.hatchet.workflows.subject_processing import subject_workflow
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
from reflector.hatchet.workflows.track_processing import track_workflow from reflector.hatchet.workflows.track_processing import track_workflow
@@ -47,6 +49,8 @@ def main():
}, },
workflows=[ workflows=[
daily_multitrack_pipeline, daily_multitrack_pipeline,
file_pipeline,
live_post_pipeline,
topic_chunk_workflow, topic_chunk_workflow,
subject_workflow, subject_workflow,
track_workflow, track_workflow,

View File

@@ -0,0 +1,885 @@
"""
Hatchet workflow: FilePipeline
Processing pipeline for file uploads and Whereby recordings.
Orchestrates: extract audio → upload → transcribe/diarize/waveform (parallel)
→ assemble → detect topics → title/summaries (parallel) → finalize
→ cleanup consent → post zulip / send webhook.
Note: This file uses deferred imports (inside functions/tasks) intentionally.
Hatchet workers run in forked processes; fresh imports per task ensure DB connections
are not shared across forks, avoiding connection pooling issues.
"""
import json
from datetime import timedelta
from pathlib import Path
from hatchet_sdk import Context
from pydantic import BaseModel
from reflector.hatchet.broadcast import (
append_event_and_broadcast,
set_status_and_broadcast,
)
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.constants import (
TIMEOUT_HEAVY,
TIMEOUT_MEDIUM,
TIMEOUT_SHORT,
TIMEOUT_TITLE,
TaskName,
)
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
fresh_db_connection,
set_workflow_error_status,
with_error_handling,
)
from reflector.hatchet.workflows.models import (
ConsentResult,
TitleResult,
TopicsResult,
WaveformResult,
WebhookResult,
ZulipResult,
)
from reflector.logger import logger
from reflector.pipelines import topic_processing
from reflector.settings import settings
from reflector.utils.audio_constants import WAVEFORM_SEGMENTS
from reflector.utils.audio_waveform import get_audio_waveform
class FilePipelineInput(BaseModel):
transcript_id: str
room_id: str | None = None
# --- Result models specific to file pipeline ---
class ExtractAudioResult(BaseModel):
audio_path: str
duration_ms: float = 0.0
class UploadAudioResult(BaseModel):
audio_url: str
audio_path: str
class TranscribeResult(BaseModel):
words: list[dict]
translation: str | None = None
class DiarizeResult(BaseModel):
diarization: list[dict] | None = None
class AssembleTranscriptResult(BaseModel):
assembled: bool
class SummariesResult(BaseModel):
generated: bool
class FinalizeResult(BaseModel):
status: str
hatchet = HatchetClientManager.get_client()
file_pipeline = hatchet.workflow(name="FilePipeline", input_validator=FilePipelineInput)
@file_pipeline.task(
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
)
@with_error_handling(TaskName.EXTRACT_AUDIO)
async def extract_audio(input: FilePipelineInput, ctx: Context) -> ExtractAudioResult:
"""Extract audio from upload file, convert to MP3."""
ctx.log(f"extract_audio: starting for transcript_id={input.transcript_id}")
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
await set_status_and_broadcast(input.transcript_id, "processing", logger=logger)
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
# Clear transcript as we're going to regenerate everything
await transcripts_controller.update(
transcript,
{
"events": [],
"topics": [],
},
)
# Find upload file
audio_file = next(transcript.data_path.glob("upload.*"), None)
if not audio_file:
audio_file = next(transcript.data_path.glob("audio.*"), None)
if not audio_file:
raise ValueError("No audio file found to process")
ctx.log(f"extract_audio: processing {audio_file}")
# Extract audio and write as MP3
import av # noqa: PLC0415
from reflector.processors import AudioFileWriterProcessor # noqa: PLC0415
duration_ms_container = [0.0]
async def capture_duration(d):
duration_ms_container[0] = d
mp3_writer = AudioFileWriterProcessor(
path=transcript.audio_mp3_filename,
on_duration=capture_duration,
)
input_container = av.open(str(audio_file))
for frame in input_container.decode(audio=0):
await mp3_writer.push(frame)
await mp3_writer.flush()
input_container.close()
duration_ms = duration_ms_container[0]
audio_path = str(transcript.audio_mp3_filename)
# Persist duration to database and broadcast to websocket clients
from reflector.db.transcripts import TranscriptDuration # noqa: PLC0415
from reflector.db.transcripts import transcripts_controller as tc
await tc.update(transcript, {"duration": duration_ms})
await append_event_and_broadcast(
input.transcript_id,
transcript,
"DURATION",
TranscriptDuration(duration=duration_ms),
logger=logger,
)
ctx.log(f"extract_audio complete: {audio_path}, duration={duration_ms}ms")
return ExtractAudioResult(audio_path=audio_path, duration_ms=duration_ms)
@file_pipeline.task(
parents=[extract_audio],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
)
@with_error_handling(TaskName.UPLOAD_AUDIO)
async def upload_audio(input: FilePipelineInput, ctx: Context) -> UploadAudioResult:
"""Upload audio to S3/storage, return audio_url."""
ctx.log(f"upload_audio: starting for transcript_id={input.transcript_id}")
extract_result = ctx.task_output(extract_audio)
audio_path = extract_result.audio_path
from reflector.storage import get_transcripts_storage # noqa: PLC0415
storage = get_transcripts_storage()
if not storage:
raise ValueError(
"Storage backend required for file processing. "
"Configure TRANSCRIPT_STORAGE_* settings."
)
with open(audio_path, "rb") as f:
audio_data = f.read()
storage_path = f"file_pipeline/{input.transcript_id}/audio.mp3"
await storage.put_file(storage_path, audio_data)
audio_url = await storage.get_file_url(storage_path)
ctx.log(f"upload_audio complete: {audio_url}")
return UploadAudioResult(audio_url=audio_url, audio_path=audio_path)
@file_pipeline.task(
parents=[upload_audio],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
)
@with_error_handling(TaskName.TRANSCRIBE)
async def transcribe(input: FilePipelineInput, ctx: Context) -> TranscribeResult:
"""Transcribe the audio file using the configured backend."""
ctx.log(f"transcribe: starting for transcript_id={input.transcript_id}")
upload_result = ctx.task_output(upload_audio)
audio_url = upload_result.audio_url
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
source_language = transcript.source_language
from reflector.pipelines.transcription_helpers import ( # noqa: PLC0415
transcribe_file_with_processor,
)
result = await transcribe_file_with_processor(audio_url, source_language)
ctx.log(f"transcribe complete: {len(result.words)} words")
return TranscribeResult(
words=[w.model_dump() for w in result.words],
translation=result.translation,
)
@file_pipeline.task(
parents=[upload_audio],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
)
@with_error_handling(TaskName.DIARIZE)
async def diarize(input: FilePipelineInput, ctx: Context) -> DiarizeResult:
"""Diarize the audio file (speaker identification)."""
ctx.log(f"diarize: starting for transcript_id={input.transcript_id}")
if not settings.DIARIZATION_BACKEND:
ctx.log("diarize: diarization disabled, skipping")
return DiarizeResult(diarization=None)
upload_result = ctx.task_output(upload_audio)
audio_url = upload_result.audio_url
from reflector.processors.file_diarization import ( # noqa: PLC0415
FileDiarizationInput,
)
from reflector.processors.file_diarization_auto import ( # noqa: PLC0415
FileDiarizationAutoProcessor,
)
processor = FileDiarizationAutoProcessor()
input_data = FileDiarizationInput(audio_url=audio_url)
result = None
async def capture_result(diarization_output):
nonlocal result
result = diarization_output.diarization
try:
processor.on(capture_result)
await processor.push(input_data)
await processor.flush()
except Exception as e:
logger.error(f"Diarization failed: {e}")
return DiarizeResult(diarization=None)
ctx.log(f"diarize complete: {len(result) if result else 0} segments")
return DiarizeResult(diarization=list(result) if result else None)
@file_pipeline.task(
parents=[upload_audio],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
)
@with_error_handling(TaskName.GENERATE_WAVEFORM)
async def generate_waveform(input: FilePipelineInput, ctx: Context) -> WaveformResult:
"""Generate audio waveform visualization."""
ctx.log(f"generate_waveform: starting for transcript_id={input.transcript_id}")
upload_result = ctx.task_output(upload_audio)
audio_path = upload_result.audio_path
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptWaveform,
transcripts_controller,
)
waveform = get_audio_waveform(
path=Path(audio_path), segments_count=WAVEFORM_SEGMENTS
)
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
transcript.data_path.mkdir(parents=True, exist_ok=True)
with open(transcript.audio_waveform_filename, "w") as f:
json.dump(waveform, f)
waveform_data = TranscriptWaveform(waveform=waveform)
await append_event_and_broadcast(
input.transcript_id,
transcript,
"WAVEFORM",
waveform_data,
logger=logger,
)
ctx.log("generate_waveform complete")
return WaveformResult(waveform_generated=True)
@file_pipeline.task(
parents=[transcribe, diarize, generate_waveform],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
)
@with_error_handling(TaskName.ASSEMBLE_TRANSCRIPT)
async def assemble_transcript(
input: FilePipelineInput, ctx: Context
) -> AssembleTranscriptResult:
"""Merge transcription + diarization results."""
ctx.log(f"assemble_transcript: starting for transcript_id={input.transcript_id}")
transcribe_result = ctx.task_output(transcribe)
diarize_result = ctx.task_output(diarize)
from reflector.processors.transcript_diarization_assembler import ( # noqa: PLC0415
TranscriptDiarizationAssemblerInput,
TranscriptDiarizationAssemblerProcessor,
)
from reflector.processors.types import ( # noqa: PLC0415
DiarizationSegment,
Word,
)
from reflector.processors.types import ( # noqa: PLC0415
Transcript as TranscriptType,
)
words = [Word(**w) for w in transcribe_result.words]
transcript_data = TranscriptType(
words=words, translation=transcribe_result.translation
)
diarization = None
if diarize_result.diarization:
diarization = [DiarizationSegment(**s) for s in diarize_result.diarization]
processor = TranscriptDiarizationAssemblerProcessor()
assembler_input = TranscriptDiarizationAssemblerInput(
transcript=transcript_data, diarization=diarization or []
)
diarized_transcript = None
async def capture_result(transcript):
nonlocal diarized_transcript
diarized_transcript = transcript
processor.on(capture_result)
await processor.push(assembler_input)
await processor.flush()
if not diarized_transcript:
raise ValueError("No diarized transcript captured")
# Save the assembled transcript events to the database
async with fresh_db_connection():
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptText,
transcripts_controller,
)
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
assembled_text = diarized_transcript.text if diarized_transcript else ""
assembled_translation = (
diarized_transcript.translation if diarized_transcript else None
)
await append_event_and_broadcast(
input.transcript_id,
transcript,
"TRANSCRIPT",
TranscriptText(text=assembled_text, translation=assembled_translation),
logger=logger,
)
ctx.log("assemble_transcript complete")
return AssembleTranscriptResult(assembled=True)
@file_pipeline.task(
parents=[assemble_transcript],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
)
@with_error_handling(TaskName.DETECT_TOPICS)
async def detect_topics(input: FilePipelineInput, ctx: Context) -> TopicsResult:
"""Detect topics from the assembled transcript."""
ctx.log(f"detect_topics: starting for transcript_id={input.transcript_id}")
# Re-read the transcript to get the diarized words
transcribe_result = ctx.task_output(transcribe)
diarize_result = ctx.task_output(diarize)
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.transcript_diarization_assembler import ( # noqa: PLC0415
TranscriptDiarizationAssemblerInput,
TranscriptDiarizationAssemblerProcessor,
)
from reflector.processors.types import ( # noqa: PLC0415
DiarizationSegment,
Word,
)
from reflector.processors.types import ( # noqa: PLC0415
Transcript as TranscriptType,
)
words = [Word(**w) for w in transcribe_result.words]
transcript_data = TranscriptType(
words=words, translation=transcribe_result.translation
)
diarization = None
if diarize_result.diarization:
diarization = [DiarizationSegment(**s) for s in diarize_result.diarization]
# Re-assemble to get the diarized transcript for topic detection
processor = TranscriptDiarizationAssemblerProcessor()
assembler_input = TranscriptDiarizationAssemblerInput(
transcript=transcript_data, diarization=diarization or []
)
diarized_transcript = None
async def capture_result(transcript):
nonlocal diarized_transcript
diarized_transcript = transcript
processor.on(capture_result)
await processor.push(assembler_input)
await processor.flush()
if not diarized_transcript:
raise ValueError("No diarized transcript for topic detection")
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
target_language = transcript.target_language
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
async def on_topic_callback(data):
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
transcript=data.transcript.text
if hasattr(data.transcript, "text")
else "",
words=data.transcript.words
if hasattr(data.transcript, "words")
else [],
)
await transcripts_controller.upsert_topic(transcript, topic)
await append_event_and_broadcast(
input.transcript_id, transcript, "TOPIC", topic, logger=logger
)
topics = await topic_processing.detect_topics(
diarized_transcript,
target_language,
on_topic_callback=on_topic_callback,
empty_pipeline=empty_pipeline,
)
ctx.log(f"detect_topics complete: {len(topics)} topics")
return TopicsResult(topics=topics)
@file_pipeline.task(
parents=[detect_topics],
execution_timeout=timedelta(seconds=TIMEOUT_TITLE),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=15,
)
@with_error_handling(TaskName.GENERATE_TITLE)
async def generate_title(input: FilePipelineInput, ctx: Context) -> TitleResult:
"""Generate meeting title using LLM."""
ctx.log(f"generate_title: starting for transcript_id={input.transcript_id}")
topics_result = ctx.task_output(detect_topics)
topics = topics_result.topics
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptFinalTitle,
transcripts_controller,
)
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
title_result = None
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
async def on_title_callback(data):
nonlocal title_result
title_result = data.title
final_title = TranscriptFinalTitle(title=data.title)
if not transcript.title:
await transcripts_controller.update(
transcript, {"title": final_title.title}
)
await append_event_and_broadcast(
input.transcript_id,
transcript,
"FINAL_TITLE",
final_title,
logger=logger,
)
await topic_processing.generate_title(
topics,
on_title_callback=on_title_callback,
empty_pipeline=empty_pipeline,
logger=logger,
)
ctx.log(f"generate_title complete: '{title_result}'")
return TitleResult(title=title_result)
@file_pipeline.task(
parents=[detect_topics],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
)
@with_error_handling(TaskName.GENERATE_SUMMARIES)
async def generate_summaries(input: FilePipelineInput, ctx: Context) -> SummariesResult:
"""Generate long/short summaries and action items."""
ctx.log(f"generate_summaries: starting for transcript_id={input.transcript_id}")
topics_result = ctx.task_output(detect_topics)
topics = topics_result.topics
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptActionItems,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
transcripts_controller,
)
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
async def on_long_summary_callback(data):
final_long = TranscriptFinalLongSummary(long_summary=data.long_summary)
await transcripts_controller.update(
transcript, {"long_summary": final_long.long_summary}
)
await append_event_and_broadcast(
input.transcript_id,
transcript,
"FINAL_LONG_SUMMARY",
final_long,
logger=logger,
)
async def on_short_summary_callback(data):
final_short = TranscriptFinalShortSummary(short_summary=data.short_summary)
await transcripts_controller.update(
transcript, {"short_summary": final_short.short_summary}
)
await append_event_and_broadcast(
input.transcript_id,
transcript,
"FINAL_SHORT_SUMMARY",
final_short,
logger=logger,
)
async def on_action_items_callback(data):
action_items = TranscriptActionItems(action_items=data.action_items)
await transcripts_controller.update(
transcript, {"action_items": action_items.action_items}
)
await append_event_and_broadcast(
input.transcript_id,
transcript,
"ACTION_ITEMS",
action_items,
logger=logger,
)
await topic_processing.generate_summaries(
topics,
transcript,
on_long_summary_callback=on_long_summary_callback,
on_short_summary_callback=on_short_summary_callback,
on_action_items_callback=on_action_items_callback,
empty_pipeline=empty_pipeline,
logger=logger,
)
ctx.log("generate_summaries complete")
return SummariesResult(generated=True)
@file_pipeline.task(
parents=[generate_title, generate_summaries],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=5,
)
@with_error_handling(TaskName.FINALIZE)
async def finalize(input: FilePipelineInput, ctx: Context) -> FinalizeResult:
"""Set transcript status to 'ended' and broadcast."""
ctx.log("finalize: setting status to 'ended'")
async with fresh_db_connection():
await set_status_and_broadcast(input.transcript_id, "ended", logger=logger)
ctx.log("finalize complete")
return FinalizeResult(status="COMPLETED")
@file_pipeline.task(
parents=[finalize],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
)
@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
async def cleanup_consent(input: FilePipelineInput, ctx: Context) -> ConsentResult:
"""Check consent and delete audio files if any participant denied."""
ctx.log(f"cleanup_consent: transcript_id={input.transcript_id}")
async with fresh_db_connection():
from reflector.db.meetings import ( # noqa: PLC0415
meeting_consent_controller,
meetings_controller,
)
from reflector.db.recordings import recordings_controller # noqa: PLC0415
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
from reflector.storage import get_transcripts_storage # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
ctx.log("cleanup_consent: transcript not found")
return ConsentResult()
consent_denied = False
recording = None
if transcript.recording_id:
recording = await recordings_controller.get_by_id(transcript.recording_id)
if recording and recording.meeting_id:
meeting = await meetings_controller.get_by_id(recording.meeting_id)
if meeting:
consent_denied = await meeting_consent_controller.has_any_denial(
meeting.id
)
if not consent_denied:
ctx.log("cleanup_consent: consent approved, keeping all files")
return ConsentResult()
ctx.log("cleanup_consent: consent denied, deleting audio files")
deletion_errors = []
if recording and recording.bucket_name:
keys_to_delete = []
if recording.track_keys:
keys_to_delete = recording.track_keys
elif recording.object_key:
keys_to_delete = [recording.object_key]
master_storage = get_transcripts_storage()
for key in keys_to_delete:
try:
await master_storage.delete_file(key, bucket=recording.bucket_name)
ctx.log(f"Deleted recording file: {recording.bucket_name}/{key}")
except Exception as e:
error_msg = f"Failed to delete {key}: {e}"
logger.error(error_msg, exc_info=True)
deletion_errors.append(error_msg)
if transcript.audio_location == "storage":
storage = get_transcripts_storage()
try:
await storage.delete_file(transcript.storage_audio_path)
ctx.log(f"Deleted processed audio: {transcript.storage_audio_path}")
except Exception as e:
error_msg = f"Failed to delete processed audio: {e}"
logger.error(error_msg, exc_info=True)
deletion_errors.append(error_msg)
try:
if (
hasattr(transcript, "audio_mp3_filename")
and transcript.audio_mp3_filename
):
transcript.audio_mp3_filename.unlink(missing_ok=True)
if (
hasattr(transcript, "audio_wav_filename")
and transcript.audio_wav_filename
):
transcript.audio_wav_filename.unlink(missing_ok=True)
except Exception as e:
error_msg = f"Failed to delete local audio files: {e}"
logger.error(error_msg, exc_info=True)
deletion_errors.append(error_msg)
if deletion_errors:
logger.warning(
"[Hatchet] cleanup_consent completed with errors",
transcript_id=input.transcript_id,
error_count=len(deletion_errors),
)
else:
await transcripts_controller.update(transcript, {"audio_deleted": True})
ctx.log("cleanup_consent: all audio deleted successfully")
return ConsentResult()
@file_pipeline.task(
parents=[cleanup_consent],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=5,
backoff_factor=2.0,
backoff_max_seconds=15,
)
@with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
async def post_zulip(input: FilePipelineInput, ctx: Context) -> ZulipResult:
"""Post notification to Zulip."""
ctx.log(f"post_zulip: transcript_id={input.transcript_id}")
if not settings.ZULIP_REALM:
ctx.log("post_zulip skipped (Zulip not configured)")
return ZulipResult(zulip_message_id=None, skipped=True)
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
from reflector.zulip import post_transcript_notification # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
message_id = await post_transcript_notification(transcript)
ctx.log(f"post_zulip complete: zulip_message_id={message_id}")
else:
message_id = None
return ZulipResult(zulip_message_id=message_id)
@file_pipeline.task(
parents=[cleanup_consent],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=5,
backoff_factor=2.0,
backoff_max_seconds=15,
)
@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
async def send_webhook(input: FilePipelineInput, ctx: Context) -> WebhookResult:
"""Send completion webhook to external service."""
ctx.log(f"send_webhook: transcript_id={input.transcript_id}")
if not input.room_id:
ctx.log("send_webhook skipped (no room_id)")
return WebhookResult(webhook_sent=False, skipped=True)
async with fresh_db_connection():
from reflector.db.rooms import rooms_controller # noqa: PLC0415
from reflector.utils.webhook import ( # noqa: PLC0415
fetch_transcript_webhook_payload,
send_webhook_request,
)
room = await rooms_controller.get_by_id(input.room_id)
if not room or not room.webhook_url:
ctx.log("send_webhook skipped (no webhook_url configured)")
return WebhookResult(webhook_sent=False, skipped=True)
payload = await fetch_transcript_webhook_payload(
transcript_id=input.transcript_id,
room_id=input.room_id,
)
if isinstance(payload, str):
ctx.log(f"send_webhook skipped (could not build payload): {payload}")
return WebhookResult(webhook_sent=False, skipped=True)
import httpx # noqa: PLC0415
try:
response = await send_webhook_request(
url=room.webhook_url,
payload=payload,
event_type="transcript.completed",
webhook_secret=room.webhook_secret,
timeout=30.0,
)
ctx.log(f"send_webhook complete: status_code={response.status_code}")
return WebhookResult(webhook_sent=True, response_code=response.status_code)
except httpx.HTTPStatusError as e:
ctx.log(f"send_webhook failed (HTTP {e.response.status_code}), continuing")
return WebhookResult(
webhook_sent=False, response_code=e.response.status_code
)
except (httpx.ConnectError, httpx.TimeoutException) as e:
ctx.log(f"send_webhook failed ({e}), continuing")
return WebhookResult(webhook_sent=False)
except Exception as e:
ctx.log(f"send_webhook unexpected error: {e}")
return WebhookResult(webhook_sent=False)
# --- On failure handler ---
async def on_workflow_failure(input: FilePipelineInput, ctx: Context) -> None:
"""Set transcript status to 'error' only if not already 'ended'."""
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript and transcript.status == "ended":
logger.info(
"[Hatchet] FilePipeline on_workflow_failure: transcript already ended, skipping error status",
transcript_id=input.transcript_id,
)
ctx.log(
"on_workflow_failure: transcript already ended, skipping error status"
)
return
await set_workflow_error_status(input.transcript_id)
@file_pipeline.on_failure_task()
async def _register_on_workflow_failure(input: FilePipelineInput, ctx: Context) -> None:
await on_workflow_failure(input, ctx)

View File

@@ -0,0 +1,389 @@
"""
Hatchet workflow: LivePostProcessingPipeline
Post-processing pipeline for live WebRTC meetings.
Triggered after a live meeting ends. Orchestrates:
Left branch: waveform → convert_mp3 → upload_mp3 → remove_upload → diarize → cleanup_consent
Right branch: generate_title (parallel with left branch)
Fan-in: final_summaries → post_zulip → send_webhook
Note: This file uses deferred imports (inside functions/tasks) intentionally.
Hatchet workers run in forked processes; fresh imports per task ensure DB connections
are not shared across forks, avoiding connection pooling issues.
"""
from datetime import timedelta
from hatchet_sdk import Context
from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.constants import (
TIMEOUT_HEAVY,
TIMEOUT_MEDIUM,
TIMEOUT_SHORT,
TIMEOUT_TITLE,
TaskName,
)
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
fresh_db_connection,
set_workflow_error_status,
with_error_handling,
)
from reflector.hatchet.workflows.models import (
ConsentResult,
TitleResult,
WaveformResult,
WebhookResult,
ZulipResult,
)
from reflector.logger import logger
from reflector.settings import settings
class LivePostPipelineInput(BaseModel):
transcript_id: str
room_id: str | None = None
# --- Result models specific to live post pipeline ---
class ConvertMp3Result(BaseModel):
converted: bool
class UploadMp3Result(BaseModel):
uploaded: bool
class RemoveUploadResult(BaseModel):
removed: bool
class DiarizeResult(BaseModel):
diarized: bool
class FinalSummariesResult(BaseModel):
generated: bool
hatchet = HatchetClientManager.get_client()
live_post_pipeline = hatchet.workflow(
name="LivePostProcessingPipeline", input_validator=LivePostPipelineInput
)
@live_post_pipeline.task(
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
)
@with_error_handling(TaskName.WAVEFORM)
async def waveform(input: LivePostPipelineInput, ctx: Context) -> WaveformResult:
"""Generate waveform visualization from recorded audio."""
ctx.log(f"waveform: starting for transcript_id={input.transcript_id}")
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
PipelineMainWaveform,
)
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
runner = PipelineMainWaveform(transcript_id=transcript.id)
await runner.run()
ctx.log("waveform complete")
return WaveformResult(waveform_generated=True)
@live_post_pipeline.task(
execution_timeout=timedelta(seconds=TIMEOUT_TITLE),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=15,
)
@with_error_handling(TaskName.GENERATE_TITLE)
async def generate_title(input: LivePostPipelineInput, ctx: Context) -> TitleResult:
"""Generate meeting title from topics (runs in parallel with audio chain)."""
ctx.log(f"generate_title: starting for transcript_id={input.transcript_id}")
async with fresh_db_connection():
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
PipelineMainTitle,
)
runner = PipelineMainTitle(transcript_id=input.transcript_id)
await runner.run()
ctx.log("generate_title complete")
return TitleResult(title=None)
@live_post_pipeline.task(
parents=[waveform],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
)
@with_error_handling(TaskName.CONVERT_MP3)
async def convert_mp3(input: LivePostPipelineInput, ctx: Context) -> ConvertMp3Result:
"""Convert WAV recording to MP3."""
ctx.log(f"convert_mp3: starting for transcript_id={input.transcript_id}")
async with fresh_db_connection():
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
pipeline_convert_to_mp3,
)
await pipeline_convert_to_mp3(transcript_id=input.transcript_id)
ctx.log("convert_mp3 complete")
return ConvertMp3Result(converted=True)
@live_post_pipeline.task(
parents=[convert_mp3],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
)
@with_error_handling(TaskName.UPLOAD_MP3)
async def upload_mp3(input: LivePostPipelineInput, ctx: Context) -> UploadMp3Result:
"""Upload MP3 to external storage."""
ctx.log(f"upload_mp3: starting for transcript_id={input.transcript_id}")
async with fresh_db_connection():
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
pipeline_upload_mp3,
)
await pipeline_upload_mp3(transcript_id=input.transcript_id)
ctx.log("upload_mp3 complete")
return UploadMp3Result(uploaded=True)
@live_post_pipeline.task(
parents=[upload_mp3],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=5,
)
@with_error_handling(TaskName.REMOVE_UPLOAD)
async def remove_upload(
input: LivePostPipelineInput, ctx: Context
) -> RemoveUploadResult:
"""Remove the original upload file."""
ctx.log(f"remove_upload: starting for transcript_id={input.transcript_id}")
async with fresh_db_connection():
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
pipeline_remove_upload,
)
await pipeline_remove_upload(transcript_id=input.transcript_id)
ctx.log("remove_upload complete")
return RemoveUploadResult(removed=True)
@live_post_pipeline.task(
parents=[remove_upload],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
)
@with_error_handling(TaskName.DIARIZE)
async def diarize(input: LivePostPipelineInput, ctx: Context) -> DiarizeResult:
"""Run diarization on the recorded audio."""
ctx.log(f"diarize: starting for transcript_id={input.transcript_id}")
async with fresh_db_connection():
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
pipeline_diarization,
)
await pipeline_diarization(transcript_id=input.transcript_id)
ctx.log("diarize complete")
return DiarizeResult(diarized=True)
@live_post_pipeline.task(
parents=[diarize],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
)
@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
async def cleanup_consent(input: LivePostPipelineInput, ctx: Context) -> ConsentResult:
"""Check consent and delete audio files if any participant denied."""
ctx.log(f"cleanup_consent: transcript_id={input.transcript_id}")
async with fresh_db_connection():
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
cleanup_consent as _cleanup_consent,
)
await _cleanup_consent(transcript_id=input.transcript_id)
ctx.log("cleanup_consent complete")
return ConsentResult()
@live_post_pipeline.task(
parents=[cleanup_consent, generate_title],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
)
@with_error_handling(TaskName.FINAL_SUMMARIES)
async def final_summaries(
input: LivePostPipelineInput, ctx: Context
) -> FinalSummariesResult:
"""Generate final summaries (fan-in after audio chain + title)."""
ctx.log(f"final_summaries: starting for transcript_id={input.transcript_id}")
async with fresh_db_connection():
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
pipeline_summaries,
)
await pipeline_summaries(transcript_id=input.transcript_id)
ctx.log("final_summaries complete")
return FinalSummariesResult(generated=True)
@live_post_pipeline.task(
parents=[final_summaries],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=5,
backoff_factor=2.0,
backoff_max_seconds=15,
)
@with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
async def post_zulip(input: LivePostPipelineInput, ctx: Context) -> ZulipResult:
"""Post notification to Zulip."""
ctx.log(f"post_zulip: transcript_id={input.transcript_id}")
if not settings.ZULIP_REALM:
ctx.log("post_zulip skipped (Zulip not configured)")
return ZulipResult(zulip_message_id=None, skipped=True)
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
from reflector.zulip import post_transcript_notification # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
message_id = await post_transcript_notification(transcript)
ctx.log(f"post_zulip complete: zulip_message_id={message_id}")
else:
message_id = None
return ZulipResult(zulip_message_id=message_id)
@live_post_pipeline.task(
parents=[final_summaries],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=5,
backoff_factor=2.0,
backoff_max_seconds=15,
)
@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
async def send_webhook(input: LivePostPipelineInput, ctx: Context) -> WebhookResult:
"""Send completion webhook to external service."""
ctx.log(f"send_webhook: transcript_id={input.transcript_id}")
if not input.room_id:
ctx.log("send_webhook skipped (no room_id)")
return WebhookResult(webhook_sent=False, skipped=True)
async with fresh_db_connection():
from reflector.db.rooms import rooms_controller # noqa: PLC0415
from reflector.utils.webhook import ( # noqa: PLC0415
fetch_transcript_webhook_payload,
send_webhook_request,
)
room = await rooms_controller.get_by_id(input.room_id)
if not room or not room.webhook_url:
ctx.log("send_webhook skipped (no webhook_url configured)")
return WebhookResult(webhook_sent=False, skipped=True)
payload = await fetch_transcript_webhook_payload(
transcript_id=input.transcript_id,
room_id=input.room_id,
)
if isinstance(payload, str):
ctx.log(f"send_webhook skipped (could not build payload): {payload}")
return WebhookResult(webhook_sent=False, skipped=True)
import httpx # noqa: PLC0415
try:
response = await send_webhook_request(
url=room.webhook_url,
payload=payload,
event_type="transcript.completed",
webhook_secret=room.webhook_secret,
timeout=30.0,
)
ctx.log(f"send_webhook complete: status_code={response.status_code}")
return WebhookResult(webhook_sent=True, response_code=response.status_code)
except httpx.HTTPStatusError as e:
ctx.log(f"send_webhook failed (HTTP {e.response.status_code}), continuing")
return WebhookResult(
webhook_sent=False, response_code=e.response.status_code
)
except (httpx.ConnectError, httpx.TimeoutException) as e:
ctx.log(f"send_webhook failed ({e}), continuing")
return WebhookResult(webhook_sent=False)
except Exception as e:
ctx.log(f"send_webhook unexpected error: {e}")
return WebhookResult(webhook_sent=False)
# --- On failure handler ---
async def on_workflow_failure(input: LivePostPipelineInput, ctx: Context) -> None:
"""Set transcript status to 'error' only if not already 'ended'."""
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript and transcript.status == "ended":
logger.info(
"[Hatchet] LivePostProcessingPipeline on_workflow_failure: transcript already ended",
transcript_id=input.transcript_id,
)
ctx.log(
"on_workflow_failure: transcript already ended, skipping error status"
)
return
await set_workflow_error_status(input.transcript_id)
@live_post_pipeline.on_failure_task()
async def _register_on_workflow_failure(
input: LivePostPipelineInput, ctx: Context
) -> None:
await on_workflow_failure(input, ctx)

View File

@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
from typing import Generic from typing import Generic
import av import av
from celery import chord, current_task, group, shared_task from celery import current_task, shared_task
from pydantic import BaseModel from pydantic import BaseModel
from structlog import BoundLogger as Logger from structlog import BoundLogger as Logger
@@ -397,7 +397,9 @@ class PipelineMainLive(PipelineMainBase):
# when the pipeline ends, connect to the post pipeline # when the pipeline ends, connect to the post pipeline
logger.info("Pipeline main live ended", transcript_id=self.transcript_id) logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id) logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
pipeline_post(transcript_id=self.transcript_id) transcript = await transcripts_controller.get_by_id(self.transcript_id)
room_id = transcript.room_id if transcript else None
await pipeline_post(transcript_id=self.transcript_id, room_id=room_id)
class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]): class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
@@ -792,29 +794,20 @@ async def task_pipeline_post_to_zulip(*, transcript_id: str):
await pipeline_post_to_zulip(transcript_id=transcript_id) await pipeline_post_to_zulip(transcript_id=transcript_id)
def pipeline_post(*, transcript_id: str): async def pipeline_post(*, transcript_id: str, room_id: str | None = None):
""" """
Run the post pipeline Run the post pipeline via Hatchet.
""" """
chain_mp3_and_diarize = ( from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
task_pipeline_waveform.si(transcript_id=transcript_id)
| task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
| task_pipeline_remove_upload.si(transcript_id=transcript_id)
| task_pipeline_diarization.si(transcript_id=transcript_id)
| task_cleanup_consent.si(transcript_id=transcript_id)
)
chain_title_preview = task_pipeline_title.si(transcript_id=transcript_id)
chain_final_summaries = task_pipeline_final_summaries.si(
transcript_id=transcript_id
)
chain = chord( await HatchetClientManager.start_workflow(
group(chain_mp3_and_diarize, chain_title_preview), "LivePostProcessingPipeline",
chain_final_summaries, {
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id) "transcript_id": str(transcript_id),
"room_id": str(room_id) if room_id else None,
return chain.delay() },
additional_metadata={"transcript_id": str(transcript_id)},
)
@get_transcript @get_transcript

View File

@@ -10,7 +10,6 @@ from dataclasses import dataclass
from typing import Literal, Union, assert_never from typing import Literal, Union, assert_never
import celery import celery
from celery.result import AsyncResult
from hatchet_sdk.clients.rest.exceptions import ApiException, NotFoundException from hatchet_sdk.clients.rest.exceptions import ApiException, NotFoundException
from hatchet_sdk.clients.rest.models import V1TaskStatus from hatchet_sdk.clients.rest.models import V1TaskStatus
@@ -18,7 +17,6 @@ from reflector.db.recordings import recordings_controller
from reflector.db.transcripts import Transcript, transcripts_controller from reflector.db.transcripts import Transcript, transcripts_controller
from reflector.hatchet.client import HatchetClientManager from reflector.hatchet.client import HatchetClientManager
from reflector.logger import logger from reflector.logger import logger
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.utils.string import NonEmptyString from reflector.utils.string import NonEmptyString
@@ -105,11 +103,8 @@ async def validate_transcript_for_processing(
): ):
return ValidationNotReady(detail="Recording is not ready for processing") return ValidationNotReady(detail="Recording is not ready for processing")
# Check Celery tasks # Check Celery tasks (multitrack still uses Celery for some paths)
if task_is_scheduled_or_active( if task_is_scheduled_or_active(
"reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
transcript_id=transcript.id,
) or task_is_scheduled_or_active(
"reflector.pipelines.main_multitrack_pipeline.task_pipeline_multitrack_process", "reflector.pipelines.main_multitrack_pipeline.task_pipeline_multitrack_process",
transcript_id=transcript.id, transcript_id=transcript.id,
): ):
@@ -175,11 +170,8 @@ async def prepare_transcript_processing(validation: ValidationOk) -> PrepareResu
async def dispatch_transcript_processing( async def dispatch_transcript_processing(
config: ProcessingConfig, force: bool = False config: ProcessingConfig, force: bool = False
) -> AsyncResult | None: ) -> None:
"""Dispatch transcript processing to appropriate backend (Hatchet or Celery). """Dispatch transcript processing to Hatchet workflow engine."""
Returns AsyncResult for Celery tasks, None for Hatchet workflows.
"""
if isinstance(config, MultitrackProcessingConfig): if isinstance(config, MultitrackProcessingConfig):
# Multitrack processing always uses Hatchet (no Celery fallback) # Multitrack processing always uses Hatchet (no Celery fallback)
# First check if we can replay (outside transaction since it's read-only) # First check if we can replay (outside transaction since it's read-only)
@@ -275,7 +267,21 @@ async def dispatch_transcript_processing(
return None return None
elif isinstance(config, FileProcessingConfig): elif isinstance(config, FileProcessingConfig):
return task_pipeline_file_process.delay(transcript_id=config.transcript_id) # File processing uses Hatchet workflow
workflow_id = await HatchetClientManager.start_workflow(
workflow_name="FilePipeline",
input_data={"transcript_id": config.transcript_id},
additional_metadata={"transcript_id": config.transcript_id},
)
transcript = await transcripts_controller.get_by_id(config.transcript_id)
if transcript:
await transcripts_controller.update(
transcript, {"workflow_run_id": workflow_id}
)
logger.info("File pipeline dispatched via Hatchet", workflow_id=workflow_id)
return None
else: else:
assert_never(config) assert_never(config)

View File

@@ -7,7 +7,6 @@ import asyncio
import json import json
import shutil import shutil
import sys import sys
import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Literal, Tuple from typing import Any, Dict, List, Literal, Tuple
from urllib.parse import unquote, urlparse from urllib.parse import unquote, urlparse
@@ -15,10 +14,8 @@ from urllib.parse import unquote, urlparse
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
from reflector.hatchet.client import HatchetClientManager
from reflector.logger import logger from reflector.logger import logger
from reflector.pipelines.main_file_pipeline import (
task_pipeline_file_process as task_pipeline_file_process,
)
from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipeline_post from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipeline_post
from reflector.pipelines.main_live_pipeline import ( from reflector.pipelines.main_live_pipeline import (
pipeline_process as live_pipeline_process, pipeline_process as live_pipeline_process,
@@ -237,29 +234,22 @@ async def process_live_pipeline(
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post # assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
assert pre_final_transcript.status != "ended" assert pre_final_transcript.status != "ended"
# at this point, diarization is running but we have no access to it. run diarization in parallel - one will hopefully win after polling # Trigger post-processing via Hatchet (fire-and-forget)
result = live_pipeline_post(transcript_id=transcript_id) await live_pipeline_post(transcript_id=transcript_id)
print("Live post-processing pipeline triggered via Hatchet", file=sys.stderr)
# result.ready() blocks even without await; it mutates result also
while not result.ready():
print(f"Status: {result.state}")
time.sleep(2)
async def process_file_pipeline( async def process_file_pipeline(
transcript_id: TranscriptId, transcript_id: TranscriptId,
): ):
"""Process audio/video file using the optimized file pipeline""" """Process audio/video file using the optimized file pipeline via Hatchet"""
# task_pipeline_file_process is a Celery task, need to use .delay() for async execution await HatchetClientManager.start_workflow(
result = task_pipeline_file_process.delay(transcript_id=transcript_id) "FilePipeline",
{"transcript_id": str(transcript_id)},
# Wait for the Celery task to complete additional_metadata={"transcript_id": str(transcript_id)},
while not result.ready(): )
print(f"File pipeline status: {result.state}", file=sys.stderr) print("File pipeline triggered via Hatchet", file=sys.stderr)
time.sleep(2)
logger.info("File pipeline processing complete")
async def process( async def process(
@@ -293,6 +283,15 @@ async def process(
await handler(transcript_id) await handler(transcript_id)
if pipeline == "file":
# File pipeline is async via Hatchet — results not available immediately.
# Use reflector.tools.process_transcript with --sync for polling.
print(
f"File pipeline dispatched for transcript {transcript_id}. "
f"Results will be available once the Hatchet workflow completes.",
file=sys.stderr,
)
else:
await extract_result_from_entry(transcript_id, output_path) await extract_result_from_entry(transcript_id, output_path)
finally: finally:
await database.disconnect() await database.disconnect()

View File

@@ -11,10 +11,8 @@ Usage:
import argparse import argparse
import asyncio import asyncio
import sys import sys
import time
from typing import Callable from typing import Callable
from celery.result import AsyncResult
from hatchet_sdk.clients.rest.models import V1TaskStatus from hatchet_sdk.clients.rest.models import V1TaskStatus
import reflector._warnings_filter # noqa: F401 -- side effect: suppress pydantic validate_default warning import reflector._warnings_filter # noqa: F401 -- side effect: suppress pydantic validate_default warning
@@ -39,7 +37,7 @@ async def process_transcript_inner(
on_validation: Callable[[ValidationResult], None], on_validation: Callable[[ValidationResult], None],
on_preprocess: Callable[[PrepareResult], None], on_preprocess: Callable[[PrepareResult], None],
force: bool = False, force: bool = False,
) -> AsyncResult | None: ) -> None:
validation = await validate_transcript_for_processing(transcript) validation = await validate_transcript_for_processing(transcript)
on_validation(validation) on_validation(validation)
config = await prepare_transcript_processing(validation) config = await prepare_transcript_processing(validation)
@@ -87,15 +85,13 @@ async def process_transcript(
elif isinstance(config, FileProcessingConfig): elif isinstance(config, FileProcessingConfig):
print(f"Dispatching file pipeline", file=sys.stderr) print(f"Dispatching file pipeline", file=sys.stderr)
result = await process_transcript_inner( await process_transcript_inner(
transcript, transcript,
on_validation=on_validation, on_validation=on_validation,
on_preprocess=on_preprocess, on_preprocess=on_preprocess,
force=force, force=force,
) )
if result is None:
# Hatchet workflow dispatched
if sync: if sync:
# Re-fetch transcript to get workflow_run_id # Re-fetch transcript to get workflow_run_id
transcript = await transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(transcript_id)
@@ -123,21 +119,6 @@ async def process_transcript(
"Task dispatched (use --sync to wait for completion)", "Task dispatched (use --sync to wait for completion)",
file=sys.stderr, file=sys.stderr,
) )
elif sync:
print("Waiting for task completion...", file=sys.stderr)
while not result.ready():
print(f" Status: {result.state}", file=sys.stderr)
time.sleep(5)
if result.successful():
print("Task completed successfully", file=sys.stderr)
else:
print(f"Task failed: {result.result}", file=sys.stderr)
sys.exit(1)
else:
print(
"Task dispatched (use --sync to wait for completion)", file=sys.stderr
)
finally: finally:
await database.disconnect() await database.disconnect()

View File

@@ -52,8 +52,5 @@ async def transcript_process(
if isinstance(config, ProcessError): if isinstance(config, ProcessError):
raise HTTPException(status_code=500, detail=config.detail) raise HTTPException(status_code=500, detail=config.detail)
else: else:
# When transcript is in error state, force a new workflow instead of replaying await dispatch_transcript_processing(config, force=True)
# (replay would re-run from failure point with same conditions and likely fail again)
force = transcript.status == "error"
await dispatch_transcript_processing(config, force=force)
return ProcessStatus(status="ok") return ProcessStatus(status="ok")

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel
import reflector.auth as auth import reflector.auth as auth
from reflector.db.transcripts import SourceKind, transcripts_controller from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process from reflector.hatchet.client import HatchetClientManager
router = APIRouter() router = APIRouter()
@@ -95,7 +95,14 @@ async def transcript_record_upload(
transcript, {"status": "uploaded", "source_kind": SourceKind.FILE} transcript, {"status": "uploaded", "source_kind": SourceKind.FILE}
) )
# launch a background task to process the file # launch Hatchet workflow to process the file
task_pipeline_file_process.delay(transcript_id=transcript_id) workflow_id = await HatchetClientManager.start_workflow(
"FilePipeline",
{"transcript_id": str(transcript_id)},
additional_metadata={"transcript_id": str(transcript_id)},
)
# Save workflow_run_id for duplicate detection and status polling
await transcripts_controller.update(transcript, {"workflow_run_id": workflow_id})
return UploadStatus(status="ok") return UploadStatus(status="ok")

View File

@@ -25,7 +25,6 @@ from reflector.db.transcripts import (
transcripts_controller, transcripts_controller,
) )
from reflector.hatchet.client import HatchetClientManager from reflector.hatchet.client import HatchetClientManager
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_live_pipeline import asynctask from reflector.pipelines.main_live_pipeline import asynctask
from reflector.pipelines.topic_processing import EmptyPipeline from reflector.pipelines.topic_processing import EmptyPipeline
from reflector.processors import AudioFileWriterProcessor from reflector.processors import AudioFileWriterProcessor
@@ -163,7 +162,14 @@ async def process_recording(bucket_name: str, object_key: str):
await transcripts_controller.update(transcript, {"status": "uploaded"}) await transcripts_controller.update(transcript, {"status": "uploaded"})
task_pipeline_file_process.delay(transcript_id=transcript.id) await HatchetClientManager.start_workflow(
"FilePipeline",
{
"transcript_id": str(transcript.id),
"room_id": str(room.id) if room else None,
},
additional_metadata={"transcript_id": str(transcript.id)},
)
@shared_task @shared_task

View File

@@ -1,6 +1,6 @@
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from unittest.mock import patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@@ -538,18 +538,59 @@ def fake_mp3_upload():
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_hatchet_client(): def mock_hatchet_client():
"""Reset HatchetClientManager singleton before and after each test. """Mock HatchetClientManager for all tests.
This ensures test isolation - each test starts with a fresh client state. Prevents tests from connecting to a real Hatchet server. The dummy token
The fixture is autouse=True so it applies to all tests automatically. in [tool.pytest_env] prevents the import-time ValueError, but the SDK
would still try to connect when get_client() is called. This fixture
mocks get_client to return a MagicMock and start_workflow to return a
dummy workflow ID.
""" """
from reflector.hatchet.client import HatchetClientManager from reflector.hatchet.client import HatchetClientManager
# Reset before test
HatchetClientManager.reset() HatchetClientManager.reset()
yield
# Reset after test to clean up mock_client = MagicMock()
mock_client.workflow.return_value = MagicMock()
with (
patch.object(
HatchetClientManager,
"get_client",
return_value=mock_client,
),
patch.object(
HatchetClientManager,
"start_workflow",
new_callable=AsyncMock,
return_value="mock-workflow-id",
),
patch.object(
HatchetClientManager,
"get_workflow_run_status",
new_callable=AsyncMock,
return_value=None,
),
patch.object(
HatchetClientManager,
"can_replay",
new_callable=AsyncMock,
return_value=False,
),
patch.object(
HatchetClientManager,
"cancel_workflow",
new_callable=AsyncMock,
),
patch.object(
HatchetClientManager,
"replay_workflow",
new_callable=AsyncMock,
),
):
yield mock_client
HatchetClientManager.reset() HatchetClientManager.reset()

View File

@@ -37,18 +37,3 @@ async def test_hatchet_client_can_replay_handles_exception():
# Should return False on error (workflow might be gone) # Should return False on error (workflow might be gone)
assert can_replay is False assert can_replay is False
def test_hatchet_client_raises_without_token():
"""Test that get_client raises ValueError without token.
Useful: Catches if someone removes the token validation,
which would cause cryptic errors later.
"""
from reflector.hatchet.client import HatchetClientManager
with patch("reflector.hatchet.client.settings") as mock_settings:
mock_settings.HATCHET_CLIENT_TOKEN = None
with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"):
HatchetClientManager.get_client()

View File

@@ -0,0 +1,233 @@
"""
Tests for the FilePipeline Hatchet workflow.
Tests verify:
1. with_error_handling behavior for file pipeline input model
2. on_workflow_failure logic (don't overwrite 'ended' status)
3. Input model validation
"""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from hatchet_sdk import NonRetryableException
@asynccontextmanager
async def _noop_db_context():
"""Async context manager that yields without touching the DB."""
yield None
@pytest.fixture(scope="module")
def file_pipeline_module():
"""Import file_pipeline with Hatchet client mocked."""
mock_client = MagicMock()
mock_client.workflow.return_value = MagicMock()
with patch(
"reflector.hatchet.client.HatchetClientManager.get_client",
return_value=mock_client,
):
from reflector.hatchet.workflows import file_pipeline
return file_pipeline
@pytest.fixture
def mock_file_input():
"""Minimal FilePipelineInput for tests."""
from reflector.hatchet.workflows.file_pipeline import FilePipelineInput
return FilePipelineInput(
transcript_id="ts-file-123",
room_id="room-456",
)
@pytest.fixture
def mock_ctx():
"""Minimal Context-like object."""
ctx = MagicMock()
ctx.log = MagicMock()
return ctx
def test_file_pipeline_input_model():
"""Test FilePipelineInput validation."""
from reflector.hatchet.workflows.file_pipeline import FilePipelineInput
# Valid input with room_id
input_with_room = FilePipelineInput(transcript_id="ts-123", room_id="room-456")
assert input_with_room.transcript_id == "ts-123"
assert input_with_room.room_id == "room-456"
# Valid input without room_id
input_no_room = FilePipelineInput(transcript_id="ts-123")
assert input_no_room.room_id is None
@pytest.mark.asyncio
async def test_file_pipeline_error_handling_transient(
file_pipeline_module, mock_file_input, mock_ctx
):
"""Transient exception must NOT set error status."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise httpx.TimeoutException("timed out")
wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises(httpx.TimeoutException):
await wrapped(mock_file_input, mock_ctx)
mock_set_error.assert_not_called()
@pytest.mark.asyncio
async def test_file_pipeline_error_handling_hard_fail(
file_pipeline_module, mock_file_input, mock_ctx
):
"""Hard-fail (ValueError) must set error status and raise NonRetryableException."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise ValueError("No audio file found")
wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises(NonRetryableException) as exc_info:
await wrapped(mock_file_input, mock_ctx)
assert "No audio file found" in str(exc_info.value)
mock_set_error.assert_called_once_with("ts-file-123")
def test_diarize_result_uses_plain_dicts():
"""DiarizationSegment is a TypedDict (plain dict), not a Pydantic model.
The diarize task must serialize segments as plain dicts (not call .model_dump()),
and assemble_transcript must be able to reconstruct them with DiarizationSegment(**s).
This was a real bug: 'dict' object has no attribute 'model_dump'.
"""
from reflector.hatchet.workflows.file_pipeline import DiarizeResult
from reflector.processors.types import DiarizationSegment
# DiarizationSegment is a TypedDict — instances are plain dicts
segments = [
DiarizationSegment(start=0.0, end=1.5, speaker=0),
DiarizationSegment(start=1.5, end=3.0, speaker=1),
]
assert isinstance(segments[0], dict), "DiarizationSegment should be a plain dict"
# DiarizeResult should accept list[dict] directly (no model_dump needed)
result = DiarizeResult(diarization=segments)
assert result.diarization is not None
assert len(result.diarization) == 2
# Consumer (assemble_transcript) reconstructs via DiarizationSegment(**s)
reconstructed = [DiarizationSegment(**s) for s in result.diarization]
assert reconstructed[0]["start"] == 0.0
assert reconstructed[0]["speaker"] == 0
assert reconstructed[1]["end"] == 3.0
assert reconstructed[1]["speaker"] == 1
def test_diarize_result_handles_none():
"""DiarizeResult with no diarization data (diarization disabled)."""
from reflector.hatchet.workflows.file_pipeline import DiarizeResult
result = DiarizeResult(diarization=None)
assert result.diarization is None
result_default = DiarizeResult()
assert result_default.diarization is None
def test_transcribe_result_words_are_pydantic():
"""TranscribeResult words come from Pydantic Word.model_dump() — verify roundtrip."""
from reflector.hatchet.workflows.file_pipeline import TranscribeResult
from reflector.processors.types import Word
words = [
Word(text="hello", start=0.0, end=0.5),
Word(text="world", start=0.5, end=1.0),
]
# Words are Pydantic models, so model_dump() works
word_dicts = [w.model_dump() for w in words]
result = TranscribeResult(words=word_dicts)
# Consumer reconstructs via Word(**w)
reconstructed = [Word(**w) for w in result.words]
assert reconstructed[0].text == "hello"
assert reconstructed[1].start == 0.5
@pytest.mark.asyncio
async def test_file_pipeline_on_failure_sets_error_status(
file_pipeline_module, mock_file_input, mock_ctx
):
"""on_workflow_failure sets error status when transcript is processing."""
from reflector.hatchet.workflows.file_pipeline import on_workflow_failure
transcript_processing = MagicMock()
transcript_processing.status = "processing"
with patch(
"reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
_noop_db_context,
):
with patch(
"reflector.db.transcripts.transcripts_controller.get_by_id",
new_callable=AsyncMock,
return_value=transcript_processing,
):
with patch(
"reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
await on_workflow_failure(mock_file_input, mock_ctx)
mock_set_error.assert_called_once_with(mock_file_input.transcript_id)
@pytest.mark.asyncio
async def test_file_pipeline_on_failure_does_not_overwrite_ended(
file_pipeline_module, mock_file_input, mock_ctx
):
"""on_workflow_failure must NOT overwrite 'ended' status."""
from reflector.hatchet.workflows.file_pipeline import on_workflow_failure
transcript_ended = MagicMock()
transcript_ended.status = "ended"
with patch(
"reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
_noop_db_context,
):
with patch(
"reflector.db.transcripts.transcripts_controller.get_by_id",
new_callable=AsyncMock,
return_value=transcript_ended,
):
with patch(
"reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
await on_workflow_failure(mock_file_input, mock_ctx)
mock_set_error.assert_not_called()

View File

@@ -0,0 +1,218 @@
"""
Tests for the LivePostProcessingPipeline Hatchet workflow.
Tests verify:
1. with_error_handling behavior for live post pipeline input model
2. on_workflow_failure logic (don't overwrite 'ended' status)
3. Input model validation
4. pipeline_post() now triggers Hatchet instead of Celery chord
"""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from hatchet_sdk import NonRetryableException
@asynccontextmanager
async def _noop_db_context():
"""Async context manager that yields without touching the DB."""
yield None
@pytest.fixture(scope="module")
def live_pipeline_module():
"""Import live_post_pipeline with Hatchet client mocked."""
mock_client = MagicMock()
mock_client.workflow.return_value = MagicMock()
with patch(
"reflector.hatchet.client.HatchetClientManager.get_client",
return_value=mock_client,
):
from reflector.hatchet.workflows import live_post_pipeline
return live_post_pipeline
@pytest.fixture
def mock_live_input():
"""Minimal LivePostPipelineInput for tests."""
from reflector.hatchet.workflows.live_post_pipeline import LivePostPipelineInput
return LivePostPipelineInput(
transcript_id="ts-live-789",
room_id="room-abc",
)
@pytest.fixture
def mock_ctx():
"""Minimal Context-like object."""
ctx = MagicMock()
ctx.log = MagicMock()
return ctx
def test_live_post_pipeline_input_model():
"""Test LivePostPipelineInput validation."""
from reflector.hatchet.workflows.live_post_pipeline import LivePostPipelineInput
# Valid input with room_id
input_with_room = LivePostPipelineInput(transcript_id="ts-123", room_id="room-456")
assert input_with_room.transcript_id == "ts-123"
assert input_with_room.room_id == "room-456"
# Valid input without room_id
input_no_room = LivePostPipelineInput(transcript_id="ts-123")
assert input_no_room.room_id is None
@pytest.mark.asyncio
async def test_live_pipeline_error_handling_transient(
live_pipeline_module, mock_live_input, mock_ctx
):
"""Transient exception must NOT set error status."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise httpx.TimeoutException("timed out")
wrapped = with_error_handling(TaskName.WAVEFORM)(failing_task)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises(httpx.TimeoutException):
await wrapped(mock_live_input, mock_ctx)
mock_set_error.assert_not_called()
@pytest.mark.asyncio
async def test_live_pipeline_error_handling_hard_fail(
live_pipeline_module, mock_live_input, mock_ctx
):
"""Hard-fail must set error status and raise NonRetryableException."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise ValueError("Transcript not found")
wrapped = with_error_handling(TaskName.WAVEFORM)(failing_task)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises(NonRetryableException) as exc_info:
await wrapped(mock_live_input, mock_ctx)
assert "Transcript not found" in str(exc_info.value)
mock_set_error.assert_called_once_with("ts-live-789")
@pytest.mark.asyncio
async def test_live_pipeline_on_failure_sets_error_status(
live_pipeline_module, mock_live_input, mock_ctx
):
"""on_workflow_failure sets error status when transcript is processing."""
from reflector.hatchet.workflows.live_post_pipeline import on_workflow_failure
transcript_processing = MagicMock()
transcript_processing.status = "processing"
with patch(
"reflector.hatchet.workflows.live_post_pipeline.fresh_db_connection",
_noop_db_context,
):
with patch(
"reflector.db.transcripts.transcripts_controller.get_by_id",
new_callable=AsyncMock,
return_value=transcript_processing,
):
with patch(
"reflector.hatchet.workflows.live_post_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
await on_workflow_failure(mock_live_input, mock_ctx)
mock_set_error.assert_called_once_with(mock_live_input.transcript_id)
@pytest.mark.asyncio
async def test_live_pipeline_on_failure_does_not_overwrite_ended(
live_pipeline_module, mock_live_input, mock_ctx
):
"""on_workflow_failure must NOT overwrite 'ended' status."""
from reflector.hatchet.workflows.live_post_pipeline import on_workflow_failure
transcript_ended = MagicMock()
transcript_ended.status = "ended"
with patch(
"reflector.hatchet.workflows.live_post_pipeline.fresh_db_connection",
_noop_db_context,
):
with patch(
"reflector.db.transcripts.transcripts_controller.get_by_id",
new_callable=AsyncMock,
return_value=transcript_ended,
):
with patch(
"reflector.hatchet.workflows.live_post_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
await on_workflow_failure(mock_live_input, mock_ctx)
mock_set_error.assert_not_called()
@pytest.mark.asyncio
async def test_pipeline_post_triggers_hatchet():
"""pipeline_post() should trigger Hatchet LivePostProcessingPipeline workflow."""
with patch(
"reflector.hatchet.client.HatchetClientManager.start_workflow",
new_callable=AsyncMock,
return_value="workflow-run-id",
) as mock_start:
from reflector.pipelines.main_live_pipeline import pipeline_post
await pipeline_post(transcript_id="ts-test-123", room_id="room-test")
mock_start.assert_called_once_with(
"LivePostProcessingPipeline",
{
"transcript_id": "ts-test-123",
"room_id": "room-test",
},
additional_metadata={"transcript_id": "ts-test-123"},
)
@pytest.mark.asyncio
async def test_pipeline_post_triggers_hatchet_without_room_id():
"""pipeline_post() should handle None room_id."""
with patch(
"reflector.hatchet.client.HatchetClientManager.start_workflow",
new_callable=AsyncMock,
return_value="workflow-run-id",
) as mock_start:
from reflector.pipelines.main_live_pipeline import pipeline_post
await pipeline_post(transcript_id="ts-test-456")
mock_start.assert_called_once_with(
"LivePostProcessingPipeline",
{
"transcript_id": "ts-test-456",
"room_id": None,
},
additional_metadata={"transcript_id": "ts-test-456"},
)

View File

@@ -0,0 +1,90 @@
"""
Tests verifying Celery-to-Hatchet trigger migration.
Ensures that:
1. process_recording triggers FilePipeline via Hatchet (not Celery)
2. transcript_record_upload triggers FilePipeline via Hatchet (not Celery)
3. Old Celery task references are no longer in active call sites
"""
def test_process_recording_does_not_import_celery_file_task():
"""Verify process.py no longer imports task_pipeline_file_process."""
import inspect
from reflector.worker import process
source = inspect.getsource(process)
# Should not contain the old Celery task import
assert "task_pipeline_file_process" not in source
def test_transcripts_upload_does_not_import_celery_file_task():
"""Verify transcripts_upload.py no longer imports task_pipeline_file_process."""
import inspect
from reflector.views import transcripts_upload
source = inspect.getsource(transcripts_upload)
# Should not contain the old Celery task import
assert "task_pipeline_file_process" not in source
def test_transcripts_upload_imports_hatchet():
"""Verify transcripts_upload.py imports HatchetClientManager."""
import inspect
from reflector.views import transcripts_upload
source = inspect.getsource(transcripts_upload)
assert "HatchetClientManager" in source
def test_pipeline_post_is_async():
"""Verify pipeline_post is now async (Hatchet trigger)."""
import asyncio
from reflector.pipelines.main_live_pipeline import pipeline_post
assert asyncio.iscoroutinefunction(pipeline_post)
def test_transcript_process_service_does_not_import_celery_file_task():
"""Verify transcript_process.py service no longer imports task_pipeline_file_process."""
import inspect
from reflector.services import transcript_process
source = inspect.getsource(transcript_process)
assert "task_pipeline_file_process" not in source
def test_transcript_process_service_dispatch_uses_hatchet():
"""Verify dispatch_transcript_processing uses HatchetClientManager for file processing."""
import inspect
from reflector.services import transcript_process
source = inspect.getsource(transcript_process.dispatch_transcript_processing)
assert "HatchetClientManager" in source
assert "FilePipeline" in source
def test_new_task_names_exist():
"""Verify new TaskName constants were added for file and live pipelines."""
from reflector.hatchet.constants import TaskName
# File pipeline tasks
assert TaskName.EXTRACT_AUDIO == "extract_audio"
assert TaskName.UPLOAD_AUDIO == "upload_audio"
assert TaskName.TRANSCRIBE == "transcribe"
assert TaskName.DIARIZE == "diarize"
assert TaskName.ASSEMBLE_TRANSCRIPT == "assemble_transcript"
assert TaskName.GENERATE_SUMMARIES == "generate_summaries"
# Live post-processing pipeline tasks
assert TaskName.WAVEFORM == "waveform"
assert TaskName.CONVERT_MP3 == "convert_mp3"
assert TaskName.UPLOAD_MP3 == "upload_mp3"
assert TaskName.REMOVE_UPLOAD == "remove_upload"
assert TaskName.FINAL_SUMMARIES == "final_summaries"

View File

@@ -1,5 +1,3 @@
import asyncio
import time
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
@@ -27,8 +25,6 @@ async def client(app_lifespan):
@pytest.mark.usefixtures("setup_database") @pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_process( async def test_transcript_process(
tmpdir, tmpdir,
@@ -39,8 +35,13 @@ async def test_transcript_process(
dummy_storage, dummy_storage,
client, client,
monkeypatch, monkeypatch,
mock_hatchet_client,
): ):
# public mode: this test uses an anonymous client; allow anonymous transcript creation """Test upload + process dispatch via Hatchet.
The file pipeline is now dispatched to Hatchet (fire-and-forget),
so we verify the workflow was triggered rather than polling for completion.
"""
monkeypatch.setattr(settings, "PUBLIC_MODE", True) monkeypatch.setattr(settings, "PUBLIC_MODE", True)
# create a transcript # create a transcript
@@ -63,51 +64,43 @@ async def test_transcript_process(
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["status"] == "ok" assert response.json()["status"] == "ok"
# wait for processing to finish (max 1 minute) # Verify Hatchet workflow was dispatched (from upload endpoint)
timeout_seconds = 60 from reflector.hatchet.client import HatchetClientManager
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds: HatchetClientManager.start_workflow.assert_called_once_with(
# fetch the transcript and check if it is ended "FilePipeline",
{"transcript_id": tid},
additional_metadata={"transcript_id": tid},
)
# Verify transcript status was set to "uploaded"
resp = await client.get(f"/transcripts/{tid}") resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200 assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"): assert resp.json()["status"] == "uploaded"
break
await asyncio.sleep(1)
else:
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
# restart the processing # Reset mock for reprocess test
response = await client.post( HatchetClientManager.start_workflow.reset_mock()
f"/transcripts/{tid}/process",
) # Clear workflow_run_id so /process endpoint can dispatch again
from reflector.db.transcripts import transcripts_controller
transcript = await transcripts_controller.get_by_id(tid)
await transcripts_controller.update(transcript, {"workflow_run_id": None})
# Reprocess via /process endpoint
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active",
return_value=False,
):
response = await client.post(f"/transcripts/{tid}/process")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["status"] == "ok" assert response.json()["status"] == "ok"
await asyncio.sleep(2)
# wait for processing to finish (max 1 minute) # Verify second Hatchet dispatch (from /process endpoint)
timeout_seconds = 60 HatchetClientManager.start_workflow.assert_called_once()
start_time = time.monotonic() call_kwargs = HatchetClientManager.start_workflow.call_args.kwargs
while (time.monotonic() - start_time) < timeout_seconds: assert call_kwargs["workflow_name"] == "FilePipeline"
# fetch the transcript and check if it is ended assert call_kwargs["input_data"]["transcript_id"] == tid
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
else:
pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")
# check the transcript is ended
transcript = resp.json()
assert transcript["status"] == "ended"
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
assert transcript["title"] == "Llm Title"
# check topics and transcript
response = await client.get(f"/transcripts/{tid}/topics")
assert response.status_code == 200
assert len(response.json()) == 1
assert "Hello world. How are you today?" in response.json()[0]["transcript"]
@pytest.mark.usefixtures("setup_database") @pytest.mark.usefixtures("setup_database")
@@ -150,20 +143,25 @@ async def test_whereby_recording_uses_file_pipeline(monkeypatch, client):
with ( with (
patch( patch(
"reflector.services.transcript_process.task_pipeline_file_process" "reflector.services.transcript_process.task_is_scheduled_or_active",
) as mock_file_pipeline, return_value=False,
),
patch( patch(
"reflector.services.transcript_process.HatchetClientManager" "reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet, ) as mock_hatchet,
): ):
mock_hatchet.start_workflow = AsyncMock(return_value="test-workflow-id")
response = await client.post(f"/transcripts/{transcript.id}/process") response = await client.post(f"/transcripts/{transcript.id}/process")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["status"] == "ok" assert response.json()["status"] == "ok"
# Whereby recordings should use file pipeline, not Hatchet # Whereby recordings should use Hatchet FilePipeline
mock_file_pipeline.delay.assert_called_once_with(transcript_id=transcript.id) mock_hatchet.start_workflow.assert_called_once()
mock_hatchet.start_workflow.assert_not_called() call_kwargs = mock_hatchet.start_workflow.call_args.kwargs
assert call_kwargs["workflow_name"] == "FilePipeline"
assert call_kwargs["input_data"]["transcript_id"] == transcript.id
@pytest.mark.usefixtures("setup_database") @pytest.mark.usefixtures("setup_database")
@@ -224,8 +222,9 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
with ( with (
patch( patch(
"reflector.services.transcript_process.task_pipeline_file_process" "reflector.services.transcript_process.task_is_scheduled_or_active",
) as mock_file_pipeline, return_value=False,
),
patch( patch(
"reflector.services.transcript_process.HatchetClientManager" "reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet, ) as mock_hatchet,
@@ -237,7 +236,7 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["status"] == "ok" assert response.json()["status"] == "ok"
# Daily.co multitrack recordings should use Hatchet workflow # Daily.co multitrack recordings should use Hatchet DiarizationPipeline
mock_hatchet.start_workflow.assert_called_once() mock_hatchet.start_workflow.assert_called_once()
call_kwargs = mock_hatchet.start_workflow.call_args.kwargs call_kwargs = mock_hatchet.start_workflow.call_args.kwargs
assert call_kwargs["workflow_name"] == "DiarizationPipeline" assert call_kwargs["workflow_name"] == "DiarizationPipeline"
@@ -246,7 +245,6 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
assert call_kwargs["input_data"]["tracks"] == [ assert call_kwargs["input_data"]["tracks"] == [
{"s3_key": k} for k in track_keys {"s3_key": k} for k in track_keys
] ]
mock_file_pipeline.delay.assert_not_called()
@pytest.mark.usefixtures("setup_database") @pytest.mark.usefixtures("setup_database")

View File

@@ -2,6 +2,10 @@
# FIXME test status of transcript # FIXME test status of transcript
# FIXME test websocket connection after RTC is finished still send the full events # FIXME test websocket connection after RTC is finished still send the full events
# FIXME try with locked session, RTC should not work # FIXME try with locked session, RTC should not work
# TODO: add integration tests for post-processing (LivePostPipeline) with a real
# Hatchet instance. These tests currently only cover the live pipeline.
# Post-processing events (WAVEFORM, FINAL_*, DURATION, STATUS=ended, mp3)
# are now dispatched via Hatchet and tested in test_hatchet_live_post_pipeline.py.
import asyncio import asyncio
import json import json
@@ -49,7 +53,7 @@ class ThreadedUvicorn:
@pytest.fixture @pytest.fixture
def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker): def appserver(tmpdir, setup_database):
import threading import threading
from reflector.app import app from reflector.app import app
@@ -119,8 +123,6 @@ def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker)
@pytest.mark.usefixtures("setup_database") @pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_rtc_and_websocket( async def test_transcript_rtc_and_websocket(
tmpdir, tmpdir,
@@ -134,6 +136,7 @@ async def test_transcript_rtc_and_websocket(
appserver, appserver,
client, client,
monkeypatch, monkeypatch,
mock_hatchet_client,
): ):
# goal: start the server, exchange RTC, receive websocket events # goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread # because of that, we need to start the server in a thread
@@ -208,35 +211,30 @@ async def test_transcript_rtc_and_websocket(
stream_client.channel.send(json.dumps({"cmd": "STOP"})) stream_client.channel.send(json.dumps({"cmd": "STOP"}))
await stream_client.stop() await stream_client.stop()
# wait the processing to finish # Wait for live pipeline to flush (it dispatches post-processing to Hatchet)
timeout = 120 timeout = 30
while True: while True:
# fetch the transcript and check if it is ended
resp = await client.get(f"/transcripts/{tid}") resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200 assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"): if resp.json()["status"] in ("processing", "ended", "error"):
break break
await asyncio.sleep(1) await asyncio.sleep(1)
timeout -= 1 timeout -= 1
if timeout < 0: if timeout < 0:
raise TimeoutError("Timeout while waiting for transcript to be ended") raise TimeoutError("Timeout waiting for live pipeline to finish")
if resp.json()["status"] != "ended":
raise TimeoutError("Transcript processing failed")
# stop websocket task # stop websocket task
websocket_task.cancel() websocket_task.cancel()
# check events # check live pipeline events
assert len(events) > 0 assert len(events) > 0
from pprint import pprint from pprint import pprint
pprint(events) pprint(events)
# get events list
eventnames = [e["event"] for e in events] eventnames = [e["event"] for e in events]
# check events # Live pipeline produces TRANSCRIPT and TOPIC events during RTC
assert "TRANSCRIPT" in eventnames assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")] ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"].startswith("Hello world.") assert ev["data"]["text"].startswith("Hello world.")
@@ -249,50 +247,18 @@ async def test_transcript_rtc_and_websocket(
assert ev["data"]["transcript"].startswith("Hello world.") assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0 assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames # Live pipeline status progression
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
assert "FINAL_SHORT_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "Llm Title"
assert "WAVEFORM" in eventnames
ev = events[eventnames.index("WAVEFORM")]
assert isinstance(ev["data"]["waveform"], list)
assert len(ev["data"]["waveform"]) >= 250
waveform_resp = await client.get(f"/transcripts/{tid}/audio/waveform")
assert waveform_resp.status_code == 200
assert waveform_resp.headers["content-type"] == "application/json"
assert isinstance(waveform_resp.json()["data"], list)
assert len(waveform_resp.json()["data"]) >= 250
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert "recording" in statuses
assert "processing" in statuses
assert statuses.index("recording") < statuses.index("processing") assert statuses.index("recording") < statuses.index("processing")
assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended # Post-processing (WAVEFORM, FINAL_*, DURATION, mp3, STATUS=ended) is now
assert events[-1]["event"] == "STATUS" # dispatched to Hatchet via LivePostPipeline — not tested here.
assert events[-1]["data"]["value"] == "ended" # See test_hatchet_live_post_pipeline.py for post-processing tests.
# check on the latest response that the audio duration is > 0
assert resp.json()["duration"] > 0
assert "DURATION" in eventnames
# check that audio/mp3 is available
audio_resp = await client.get(f"/transcripts/{tid}/audio/mp3")
assert audio_resp.status_code == 200
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@pytest.mark.usefixtures("setup_database") @pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr( async def test_transcript_rtc_and_websocket_and_fr(
tmpdir, tmpdir,
@@ -306,6 +272,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
appserver, appserver,
client, client,
monkeypatch, monkeypatch,
mock_hatchet_client,
): ):
# goal: start the server, exchange RTC, receive websocket events # goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread # because of that, we need to start the server in a thread
@@ -382,42 +349,34 @@ async def test_transcript_rtc_and_websocket_and_fr(
# instead of waiting a long time, we just send a STOP # instead of waiting a long time, we just send a STOP
stream_client.channel.send(json.dumps({"cmd": "STOP"})) stream_client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2) await asyncio.sleep(2)
await stream_client.stop() await stream_client.stop()
# wait the processing to finish # Wait for live pipeline to flush
timeout = 120 timeout = 30
while True: while True:
# fetch the transcript and check if it is ended
resp = await client.get(f"/transcripts/{tid}") resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200 assert resp.status_code == 200
if resp.json()["status"] == "ended": if resp.json()["status"] in ("processing", "ended", "error"):
break break
await asyncio.sleep(1) await asyncio.sleep(1)
timeout -= 1 timeout -= 1
if timeout < 0: if timeout < 0:
raise TimeoutError("Timeout while waiting for transcript to be ended") raise TimeoutError("Timeout waiting for live pipeline to finish")
if resp.json()["status"] != "ended":
raise TimeoutError("Transcript processing failed")
await asyncio.sleep(2)
# stop websocket task # stop websocket task
websocket_task.cancel() websocket_task.cancel()
# check events # check live pipeline events
assert len(events) > 0 assert len(events) > 0
from pprint import pprint from pprint import pprint
pprint(events) pprint(events)
# get events list
eventnames = [e["event"] for e in events] eventnames = [e["event"] for e in events]
# check events # Live pipeline produces TRANSCRIPT with translation
assert "TRANSCRIPT" in eventnames assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")] ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"].startswith("Hello world.") assert ev["data"]["text"].startswith("Hello world.")
@@ -430,23 +389,11 @@ async def test_transcript_rtc_and_websocket_and_fr(
assert ev["data"]["transcript"].startswith("Hello world.") assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0 assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames # Live pipeline status progression
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
assert "FINAL_SHORT_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "Llm Title"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert "recording" in statuses
assert "processing" in statuses
assert statuses.index("recording") < statuses.index("processing") assert statuses.index("recording") < statuses.index("processing")
assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended # Post-processing (FINAL_*, STATUS=ended) is now dispatched to Hatchet
assert events[-1]["event"] == "STATUS" # via LivePostPipeline — not tested here.
assert events[-1]["data"]["value"] == "ended"

View File

@@ -1,12 +1,7 @@
import asyncio
import time
import pytest import pytest
@pytest.mark.usefixtures("setup_database") @pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_upload_file( async def test_transcript_upload_file(
tmpdir, tmpdir,
@@ -17,6 +12,7 @@ async def test_transcript_upload_file(
dummy_storage, dummy_storage,
client, client,
monkeypatch, monkeypatch,
mock_hatchet_client,
): ):
from reflector.settings import settings from reflector.settings import settings
@@ -43,27 +39,16 @@ async def test_transcript_upload_file(
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["status"] == "ok" assert response.json()["status"] == "ok"
# wait the processing to finish (max 1 minute) # Verify Hatchet workflow was dispatched for file processing
timeout_seconds = 60 from reflector.hatchet.client import HatchetClientManager
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds: HatchetClientManager.start_workflow.assert_called_once_with(
# fetch the transcript and check if it is ended "FilePipeline",
{"transcript_id": tid},
additional_metadata={"transcript_id": tid},
)
# Verify transcript status was updated to "uploaded"
resp = await client.get(f"/transcripts/{tid}") resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200 assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"): assert resp.json()["status"] == "uploaded"
break
await asyncio.sleep(1)
else:
return pytest.fail(f"Processing timed out after {timeout_seconds} seconds")
# check the transcript is ended
transcript = resp.json()
assert transcript["status"] == "ended"
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
assert transcript["title"] == "Llm Title"
# check topics and transcript
response = await client.get(f"/transcripts/{tid}/topics")
assert response.status_code == 200
assert len(response.json()) == 1
assert "Hello world. How are you today?" in response.json()[0]["transcript"]