mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-21 22:56:47 +00:00
feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine (#911)
* feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine * fix: always force reprocessing * fix: ci tests with live pipelines * fix: ci tests with live pipelines
This commit is contained in:
committed by
GitHub
parent
72dca7cacc
commit
37a1f01850
@@ -51,6 +51,9 @@ services:
|
|||||||
HF_TOKEN: ${HF_TOKEN:-}
|
HF_TOKEN: ${HF_TOKEN:-}
|
||||||
# WebRTC: fixed UDP port range for ICE candidates (mapped above)
|
# WebRTC: fixed UDP port range for ICE candidates (mapped above)
|
||||||
WEBRTC_PORT_RANGE: "51000-51100"
|
WEBRTC_PORT_RANGE: "51000-51100"
|
||||||
|
# Hatchet workflow engine (always-on for processing pipelines)
|
||||||
|
HATCHET_CLIENT_SERVER_URL: ${HATCHET_CLIENT_SERVER_URL:-http://hatchet:8888}
|
||||||
|
HATCHET_CLIENT_HOST_PORT: ${HATCHET_CLIENT_HOST_PORT:-hatchet:7077}
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
@@ -75,6 +78,9 @@ services:
|
|||||||
CELERY_RESULT_BACKEND: redis://redis:6379/1
|
CELERY_RESULT_BACKEND: redis://redis:6379/1
|
||||||
# ML backend config comes from env_file (server/.env), set per-mode by setup script
|
# ML backend config comes from env_file (server/.env), set per-mode by setup script
|
||||||
HF_TOKEN: ${HF_TOKEN:-}
|
HF_TOKEN: ${HF_TOKEN:-}
|
||||||
|
# Hatchet workflow engine (always-on for processing pipelines)
|
||||||
|
HATCHET_CLIENT_SERVER_URL: ${HATCHET_CLIENT_SERVER_URL:-http://hatchet:8888}
|
||||||
|
HATCHET_CLIENT_HOST_PORT: ${HATCHET_CLIENT_HOST_PORT:-hatchet:7077}
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
@@ -126,6 +132,8 @@ services:
|
|||||||
redis:
|
redis:
|
||||||
image: redis:7.2-alpine
|
image: redis:7.2-alpine
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
- "6379:6379"
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
@@ -301,20 +309,20 @@ services:
|
|||||||
- server
|
- server
|
||||||
|
|
||||||
# ===========================================================
|
# ===========================================================
|
||||||
# Hatchet + Daily.co workers (optional — for Daily.co multitrack processing)
|
# Hatchet workflow engine + workers
|
||||||
# Auto-enabled when DAILY_API_KEY is configured in server/r
|
# Required for all processing pipelines (file, live, Daily.co multitrack).
|
||||||
|
# Always-on — every selfhosted deployment needs Hatchet.
|
||||||
# ===========================================================
|
# ===========================================================
|
||||||
|
|
||||||
hatchet:
|
hatchet:
|
||||||
image: ghcr.io/hatchet-dev/hatchet/hatchet-lite:latest
|
image: ghcr.io/hatchet-dev/hatchet/hatchet-lite:latest
|
||||||
profiles: [dailyco]
|
|
||||||
restart: on-failure
|
restart: on-failure
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
ports:
|
ports:
|
||||||
- "8888:8888"
|
- "127.0.0.1:8888:8888"
|
||||||
- "7078:7077"
|
- "127.0.0.1:7078:7077"
|
||||||
env_file:
|
env_file:
|
||||||
- ./.env.hatchet
|
- ./.env.hatchet
|
||||||
environment:
|
environment:
|
||||||
@@ -363,7 +371,6 @@ services:
|
|||||||
context: ./server
|
context: ./server
|
||||||
dockerfile: Dockerfile
|
dockerfile: Dockerfile
|
||||||
image: monadicalsas/reflector-backend:latest
|
image: monadicalsas/reflector-backend:latest
|
||||||
profiles: [dailyco]
|
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
env_file:
|
env_file:
|
||||||
- ./server/.env
|
- ./server/.env
|
||||||
|
|||||||
@@ -261,9 +261,11 @@ if [[ -z "$MODEL_MODE" ]]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Build profiles list — one profile per feature
|
# Build profiles list — one profile per feature
|
||||||
# Only --gpu needs a compose profile; --cpu and --hosted use in-process/remote backends
|
# Hatchet + hatchet-worker-llm are always-on (no profile needed).
|
||||||
|
# gpu/cpu profiles only control the ML container (transcription service).
|
||||||
COMPOSE_PROFILES=()
|
COMPOSE_PROFILES=()
|
||||||
[[ "$MODEL_MODE" == "gpu" ]] && COMPOSE_PROFILES+=("gpu")
|
[[ "$MODEL_MODE" == "gpu" ]] && COMPOSE_PROFILES+=("gpu")
|
||||||
|
[[ "$MODEL_MODE" == "cpu" ]] && COMPOSE_PROFILES+=("cpu")
|
||||||
[[ -n "$OLLAMA_MODE" ]] && COMPOSE_PROFILES+=("$OLLAMA_MODE")
|
[[ -n "$OLLAMA_MODE" ]] && COMPOSE_PROFILES+=("$OLLAMA_MODE")
|
||||||
[[ "$USE_GARAGE" == "true" ]] && COMPOSE_PROFILES+=("garage")
|
[[ "$USE_GARAGE" == "true" ]] && COMPOSE_PROFILES+=("garage")
|
||||||
[[ "$USE_CADDY" == "true" ]] && COMPOSE_PROFILES+=("caddy")
|
[[ "$USE_CADDY" == "true" ]] && COMPOSE_PROFILES+=("caddy")
|
||||||
@@ -557,12 +559,10 @@ step_server_env() {
|
|||||||
ok "CPU mode — file processing timeouts set to 3600s (1 hour)"
|
ok "CPU mode — file processing timeouts set to 3600s (1 hour)"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# If Daily.co is manually configured, ensure Hatchet connectivity vars are set
|
# Hatchet is always required (file, live, and multitrack pipelines all use it)
|
||||||
if env_has_key "$SERVER_ENV" "DAILY_API_KEY" && [[ -n "$(env_get "$SERVER_ENV" "DAILY_API_KEY")" ]]; then
|
|
||||||
env_set "$SERVER_ENV" "HATCHET_CLIENT_SERVER_URL" "http://hatchet:8888"
|
env_set "$SERVER_ENV" "HATCHET_CLIENT_SERVER_URL" "http://hatchet:8888"
|
||||||
env_set "$SERVER_ENV" "HATCHET_CLIENT_HOST_PORT" "hatchet:7077"
|
env_set "$SERVER_ENV" "HATCHET_CLIENT_HOST_PORT" "hatchet:7077"
|
||||||
ok "Daily.co detected — Hatchet connectivity configured"
|
ok "Hatchet connectivity configured (workflow engine for processing pipelines)"
|
||||||
fi
|
|
||||||
|
|
||||||
ok "server/.env ready"
|
ok "server/.env ready"
|
||||||
}
|
}
|
||||||
@@ -886,15 +886,22 @@ step_services() {
|
|||||||
compose_cmd pull server web || warn "Pull failed — using cached images"
|
compose_cmd pull server web || warn "Pull failed — using cached images"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Build hatchet workers if Daily.co is configured (same backend image)
|
# Hatchet is always needed (all processing pipelines use it)
|
||||||
if [[ "$DAILY_DETECTED" == "true" ]] && [[ "$BUILD_IMAGES" == "true" ]]; then
|
local NEEDS_HATCHET=true
|
||||||
|
|
||||||
|
# Build hatchet workers if Hatchet is needed (same backend image)
|
||||||
|
if [[ "$NEEDS_HATCHET" == "true" ]] && [[ "$BUILD_IMAGES" == "true" ]]; then
|
||||||
info "Building Hatchet worker images..."
|
info "Building Hatchet worker images..."
|
||||||
|
if [[ "$DAILY_DETECTED" == "true" ]]; then
|
||||||
compose_cmd build hatchet-worker-cpu hatchet-worker-llm
|
compose_cmd build hatchet-worker-cpu hatchet-worker-llm
|
||||||
|
else
|
||||||
|
compose_cmd build hatchet-worker-llm
|
||||||
|
fi
|
||||||
ok "Hatchet worker images built"
|
ok "Hatchet worker images built"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Ensure hatchet database exists before starting hatchet (init-hatchet-db.sql only runs on fresh postgres volumes)
|
# Ensure hatchet database exists before starting hatchet (init-hatchet-db.sql only runs on fresh postgres volumes)
|
||||||
if [[ "$DAILY_DETECTED" == "true" ]]; then
|
if [[ "$NEEDS_HATCHET" == "true" ]]; then
|
||||||
info "Ensuring postgres is running for Hatchet database setup..."
|
info "Ensuring postgres is running for Hatchet database setup..."
|
||||||
compose_cmd up -d postgres
|
compose_cmd up -d postgres
|
||||||
local pg_ready=false
|
local pg_ready=false
|
||||||
@@ -1049,8 +1056,7 @@ step_health() {
|
|||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Hatchet (if Daily.co detected)
|
# Hatchet (always-on)
|
||||||
if [[ "$DAILY_DETECTED" == "true" ]]; then
|
|
||||||
info "Waiting for Hatchet workflow engine..."
|
info "Waiting for Hatchet workflow engine..."
|
||||||
local hatchet_ok=false
|
local hatchet_ok=false
|
||||||
for i in $(seq 1 60); do
|
for i in $(seq 1 60); do
|
||||||
@@ -1067,7 +1073,6 @@ step_health() {
|
|||||||
else
|
else
|
||||||
warn "Hatchet not ready yet. Check: docker compose logs hatchet"
|
warn "Hatchet not ready yet. Check: docker compose logs hatchet"
|
||||||
fi
|
fi
|
||||||
fi
|
|
||||||
|
|
||||||
# LLM warning for non-Ollama modes
|
# LLM warning for non-Ollama modes
|
||||||
if [[ "$USES_OLLAMA" == "false" ]]; then
|
if [[ "$USES_OLLAMA" == "false" ]]; then
|
||||||
@@ -1087,12 +1092,10 @@ step_health() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# =========================================================
|
# =========================================================
|
||||||
# Step 8: Hatchet token generation (Daily.co only)
|
# Step 8: Hatchet token generation (gpu/cpu/Daily.co)
|
||||||
# =========================================================
|
# =========================================================
|
||||||
step_hatchet_token() {
|
step_hatchet_token() {
|
||||||
if [[ "$DAILY_DETECTED" != "true" ]]; then
|
# Hatchet is always required — no gating needed
|
||||||
return
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Skip if token already set
|
# Skip if token already set
|
||||||
if env_has_key "$SERVER_ENV" "HATCHET_CLIENT_TOKEN" && [[ -n "$(env_get "$SERVER_ENV" "HATCHET_CLIENT_TOKEN")" ]]; then
|
if env_has_key "$SERVER_ENV" "HATCHET_CLIENT_TOKEN" && [[ -n "$(env_get "$SERVER_ENV" "HATCHET_CLIENT_TOKEN")" ]]; then
|
||||||
@@ -1147,7 +1150,9 @@ step_hatchet_token() {
|
|||||||
|
|
||||||
# Restart services that need the token
|
# Restart services that need the token
|
||||||
info "Restarting services with new Hatchet token..."
|
info "Restarting services with new Hatchet token..."
|
||||||
compose_cmd restart server worker hatchet-worker-cpu hatchet-worker-llm
|
local restart_services="server worker hatchet-worker-llm"
|
||||||
|
[[ "$DAILY_DETECTED" == "true" ]] && restart_services="$restart_services hatchet-worker-cpu"
|
||||||
|
compose_cmd restart $restart_services
|
||||||
ok "Services restarted with Hatchet token"
|
ok "Services restarted with Hatchet token"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1216,8 +1221,7 @@ main() {
|
|||||||
ok "Daily.co detected — enabling Hatchet workflow services"
|
ok "Daily.co detected — enabling Hatchet workflow services"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Generate .env.hatchet for hatchet dashboard config
|
# Generate .env.hatchet for hatchet dashboard config (always needed)
|
||||||
if [[ "$DAILY_DETECTED" == "true" ]]; then
|
|
||||||
local hatchet_server_url hatchet_cookie_domain
|
local hatchet_server_url hatchet_cookie_domain
|
||||||
if [[ -n "$CUSTOM_DOMAIN" ]]; then
|
if [[ -n "$CUSTOM_DOMAIN" ]]; then
|
||||||
hatchet_server_url="https://${CUSTOM_DOMAIN}:8888"
|
hatchet_server_url="https://${CUSTOM_DOMAIN}:8888"
|
||||||
@@ -1234,10 +1238,6 @@ SERVER_URL=$hatchet_server_url
|
|||||||
SERVER_AUTH_COOKIE_DOMAIN=$hatchet_cookie_domain
|
SERVER_AUTH_COOKIE_DOMAIN=$hatchet_cookie_domain
|
||||||
EOF
|
EOF
|
||||||
ok "Generated .env.hatchet (dashboard URL=$hatchet_server_url)"
|
ok "Generated .env.hatchet (dashboard URL=$hatchet_server_url)"
|
||||||
else
|
|
||||||
# Create empty .env.hatchet so compose doesn't fail if dailyco profile is ever activated manually
|
|
||||||
touch "$ROOT_DIR/.env.hatchet"
|
|
||||||
fi
|
|
||||||
|
|
||||||
step_www_env
|
step_www_env
|
||||||
echo ""
|
echo ""
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ source = ["reflector"]
|
|||||||
ENVIRONMENT = "pytest"
|
ENVIRONMENT = "pytest"
|
||||||
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
|
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
|
||||||
AUTH_BACKEND = "jwt"
|
AUTH_BACKEND = "jwt"
|
||||||
|
HATCHET_CLIENT_TOKEN = "test-dummy-token"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
||||||
|
|||||||
@@ -26,6 +26,21 @@ class TaskName(StrEnum):
|
|||||||
DETECT_CHUNK_TOPIC = "detect_chunk_topic"
|
DETECT_CHUNK_TOPIC = "detect_chunk_topic"
|
||||||
GENERATE_DETAILED_SUMMARY = "generate_detailed_summary"
|
GENERATE_DETAILED_SUMMARY = "generate_detailed_summary"
|
||||||
|
|
||||||
|
# File pipeline tasks
|
||||||
|
EXTRACT_AUDIO = "extract_audio"
|
||||||
|
UPLOAD_AUDIO = "upload_audio"
|
||||||
|
TRANSCRIBE = "transcribe"
|
||||||
|
DIARIZE = "diarize"
|
||||||
|
ASSEMBLE_TRANSCRIPT = "assemble_transcript"
|
||||||
|
GENERATE_SUMMARIES = "generate_summaries"
|
||||||
|
|
||||||
|
# Live post-processing pipeline tasks
|
||||||
|
WAVEFORM = "waveform"
|
||||||
|
CONVERT_MP3 = "convert_mp3"
|
||||||
|
UPLOAD_MP3 = "upload_mp3"
|
||||||
|
REMOVE_UPLOAD = "remove_upload"
|
||||||
|
FINAL_SUMMARIES = "final_summaries"
|
||||||
|
|
||||||
|
|
||||||
# Rate limit key for LLM API calls (shared across all LLM-calling tasks)
|
# Rate limit key for LLM API calls (shared across all LLM-calling tasks)
|
||||||
LLM_RATE_LIMIT_KEY = "llm"
|
LLM_RATE_LIMIT_KEY = "llm"
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from reflector.hatchet.client import HatchetClientManager
|
|||||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
daily_multitrack_pipeline,
|
daily_multitrack_pipeline,
|
||||||
)
|
)
|
||||||
|
from reflector.hatchet.workflows.file_pipeline import file_pipeline
|
||||||
|
from reflector.hatchet.workflows.live_post_pipeline import live_post_pipeline
|
||||||
from reflector.hatchet.workflows.subject_processing import subject_workflow
|
from reflector.hatchet.workflows.subject_processing import subject_workflow
|
||||||
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
|
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
|
||||||
from reflector.hatchet.workflows.track_processing import track_workflow
|
from reflector.hatchet.workflows.track_processing import track_workflow
|
||||||
@@ -47,6 +49,8 @@ def main():
|
|||||||
},
|
},
|
||||||
workflows=[
|
workflows=[
|
||||||
daily_multitrack_pipeline,
|
daily_multitrack_pipeline,
|
||||||
|
file_pipeline,
|
||||||
|
live_post_pipeline,
|
||||||
topic_chunk_workflow,
|
topic_chunk_workflow,
|
||||||
subject_workflow,
|
subject_workflow,
|
||||||
track_workflow,
|
track_workflow,
|
||||||
|
|||||||
885
server/reflector/hatchet/workflows/file_pipeline.py
Normal file
885
server/reflector/hatchet/workflows/file_pipeline.py
Normal file
@@ -0,0 +1,885 @@
|
|||||||
|
"""
|
||||||
|
Hatchet workflow: FilePipeline
|
||||||
|
|
||||||
|
Processing pipeline for file uploads and Whereby recordings.
|
||||||
|
Orchestrates: extract audio → upload → transcribe/diarize/waveform (parallel)
|
||||||
|
→ assemble → detect topics → title/summaries (parallel) → finalize
|
||||||
|
→ cleanup consent → post zulip / send webhook.
|
||||||
|
|
||||||
|
Note: This file uses deferred imports (inside functions/tasks) intentionally.
|
||||||
|
Hatchet workers run in forked processes; fresh imports per task ensure DB connections
|
||||||
|
are not shared across forks, avoiding connection pooling issues.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from hatchet_sdk import Context
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from reflector.hatchet.broadcast import (
|
||||||
|
append_event_and_broadcast,
|
||||||
|
set_status_and_broadcast,
|
||||||
|
)
|
||||||
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
|
from reflector.hatchet.constants import (
|
||||||
|
TIMEOUT_HEAVY,
|
||||||
|
TIMEOUT_MEDIUM,
|
||||||
|
TIMEOUT_SHORT,
|
||||||
|
TIMEOUT_TITLE,
|
||||||
|
TaskName,
|
||||||
|
)
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
fresh_db_connection,
|
||||||
|
set_workflow_error_status,
|
||||||
|
with_error_handling,
|
||||||
|
)
|
||||||
|
from reflector.hatchet.workflows.models import (
|
||||||
|
ConsentResult,
|
||||||
|
TitleResult,
|
||||||
|
TopicsResult,
|
||||||
|
WaveformResult,
|
||||||
|
WebhookResult,
|
||||||
|
ZulipResult,
|
||||||
|
)
|
||||||
|
from reflector.logger import logger
|
||||||
|
from reflector.pipelines import topic_processing
|
||||||
|
from reflector.settings import settings
|
||||||
|
from reflector.utils.audio_constants import WAVEFORM_SEGMENTS
|
||||||
|
from reflector.utils.audio_waveform import get_audio_waveform
|
||||||
|
|
||||||
|
|
||||||
|
class FilePipelineInput(BaseModel):
|
||||||
|
transcript_id: str
|
||||||
|
room_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Result models specific to file pipeline ---
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractAudioResult(BaseModel):
|
||||||
|
audio_path: str
|
||||||
|
duration_ms: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class UploadAudioResult(BaseModel):
|
||||||
|
audio_url: str
|
||||||
|
audio_path: str
|
||||||
|
|
||||||
|
|
||||||
|
class TranscribeResult(BaseModel):
|
||||||
|
words: list[dict]
|
||||||
|
translation: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DiarizeResult(BaseModel):
|
||||||
|
diarization: list[dict] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AssembleTranscriptResult(BaseModel):
|
||||||
|
assembled: bool
|
||||||
|
|
||||||
|
|
||||||
|
class SummariesResult(BaseModel):
|
||||||
|
generated: bool
|
||||||
|
|
||||||
|
|
||||||
|
class FinalizeResult(BaseModel):
|
||||||
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
hatchet = HatchetClientManager.get_client()
|
||||||
|
|
||||||
|
file_pipeline = hatchet.workflow(name="FilePipeline", input_validator=FilePipelineInput)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
|
||||||
|
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=10,
|
||||||
|
)
|
||||||
|
@with_error_handling(TaskName.EXTRACT_AUDIO)
|
||||||
|
async def extract_audio(input: FilePipelineInput, ctx: Context) -> ExtractAudioResult:
|
||||||
|
"""Extract audio from upload file, convert to MP3."""
|
||||||
|
ctx.log(f"extract_audio: starting for transcript_id={input.transcript_id}")
|
||||||
|
|
||||||
|
async with fresh_db_connection():
|
||||||
|
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||||
|
|
||||||
|
await set_status_and_broadcast(input.transcript_id, "processing", logger=logger)
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||||
|
if not transcript:
|
||||||
|
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||||
|
|
||||||
|
# Clear transcript as we're going to regenerate everything
|
||||||
|
await transcripts_controller.update(
|
||||||
|
transcript,
|
||||||
|
{
|
||||||
|
"events": [],
|
||||||
|
"topics": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find upload file
|
||||||
|
audio_file = next(transcript.data_path.glob("upload.*"), None)
|
||||||
|
if not audio_file:
|
||||||
|
audio_file = next(transcript.data_path.glob("audio.*"), None)
|
||||||
|
if not audio_file:
|
||||||
|
raise ValueError("No audio file found to process")
|
||||||
|
|
||||||
|
ctx.log(f"extract_audio: processing {audio_file}")
|
||||||
|
|
||||||
|
# Extract audio and write as MP3
|
||||||
|
import av # noqa: PLC0415
|
||||||
|
|
||||||
|
from reflector.processors import AudioFileWriterProcessor # noqa: PLC0415
|
||||||
|
|
||||||
|
duration_ms_container = [0.0]
|
||||||
|
|
||||||
|
async def capture_duration(d):
|
||||||
|
duration_ms_container[0] = d
|
||||||
|
|
||||||
|
mp3_writer = AudioFileWriterProcessor(
|
||||||
|
path=transcript.audio_mp3_filename,
|
||||||
|
on_duration=capture_duration,
|
||||||
|
)
|
||||||
|
input_container = av.open(str(audio_file))
|
||||||
|
for frame in input_container.decode(audio=0):
|
||||||
|
await mp3_writer.push(frame)
|
||||||
|
await mp3_writer.flush()
|
||||||
|
input_container.close()
|
||||||
|
|
||||||
|
duration_ms = duration_ms_container[0]
|
||||||
|
audio_path = str(transcript.audio_mp3_filename)
|
||||||
|
|
||||||
|
# Persist duration to database and broadcast to websocket clients
|
||||||
|
from reflector.db.transcripts import TranscriptDuration # noqa: PLC0415
|
||||||
|
from reflector.db.transcripts import transcripts_controller as tc
|
||||||
|
|
||||||
|
await tc.update(transcript, {"duration": duration_ms})
|
||||||
|
await append_event_and_broadcast(
|
||||||
|
input.transcript_id,
|
||||||
|
transcript,
|
||||||
|
"DURATION",
|
||||||
|
TranscriptDuration(duration=duration_ms),
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx.log(f"extract_audio complete: {audio_path}, duration={duration_ms}ms")
|
||||||
|
return ExtractAudioResult(audio_path=audio_path, duration_ms=duration_ms)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
|
||||||
|
parents=[extract_audio],
|
||||||
|
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=10,
|
||||||
|
)
|
||||||
|
@with_error_handling(TaskName.UPLOAD_AUDIO)
|
||||||
|
async def upload_audio(input: FilePipelineInput, ctx: Context) -> UploadAudioResult:
|
||||||
|
"""Upload audio to S3/storage, return audio_url."""
|
||||||
|
ctx.log(f"upload_audio: starting for transcript_id={input.transcript_id}")
|
||||||
|
|
||||||
|
extract_result = ctx.task_output(extract_audio)
|
||||||
|
audio_path = extract_result.audio_path
|
||||||
|
|
||||||
|
from reflector.storage import get_transcripts_storage # noqa: PLC0415
|
||||||
|
|
||||||
|
storage = get_transcripts_storage()
|
||||||
|
if not storage:
|
||||||
|
raise ValueError(
|
||||||
|
"Storage backend required for file processing. "
|
||||||
|
"Configure TRANSCRIPT_STORAGE_* settings."
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(audio_path, "rb") as f:
|
||||||
|
audio_data = f.read()
|
||||||
|
|
||||||
|
storage_path = f"file_pipeline/{input.transcript_id}/audio.mp3"
|
||||||
|
await storage.put_file(storage_path, audio_data)
|
||||||
|
audio_url = await storage.get_file_url(storage_path)
|
||||||
|
|
||||||
|
ctx.log(f"upload_audio complete: {audio_url}")
|
||||||
|
return UploadAudioResult(audio_url=audio_url, audio_path=audio_path)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
|
||||||
|
parents=[upload_audio],
|
||||||
|
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=30,
|
||||||
|
)
|
||||||
|
@with_error_handling(TaskName.TRANSCRIBE)
|
||||||
|
async def transcribe(input: FilePipelineInput, ctx: Context) -> TranscribeResult:
|
||||||
|
"""Transcribe the audio file using the configured backend."""
|
||||||
|
ctx.log(f"transcribe: starting for transcript_id={input.transcript_id}")
|
||||||
|
|
||||||
|
upload_result = ctx.task_output(upload_audio)
|
||||||
|
audio_url = upload_result.audio_url
|
||||||
|
|
||||||
|
async with fresh_db_connection():
|
||||||
|
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||||
|
if not transcript:
|
||||||
|
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||||
|
source_language = transcript.source_language
|
||||||
|
|
||||||
|
from reflector.pipelines.transcription_helpers import ( # noqa: PLC0415
|
||||||
|
transcribe_file_with_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await transcribe_file_with_processor(audio_url, source_language)
|
||||||
|
|
||||||
|
ctx.log(f"transcribe complete: {len(result.words)} words")
|
||||||
|
return TranscribeResult(
|
||||||
|
words=[w.model_dump() for w in result.words],
|
||||||
|
translation=result.translation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
|
||||||
|
parents=[upload_audio],
|
||||||
|
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=30,
|
||||||
|
)
|
||||||
|
@with_error_handling(TaskName.DIARIZE)
|
||||||
|
async def diarize(input: FilePipelineInput, ctx: Context) -> DiarizeResult:
|
||||||
|
"""Diarize the audio file (speaker identification)."""
|
||||||
|
ctx.log(f"diarize: starting for transcript_id={input.transcript_id}")
|
||||||
|
|
||||||
|
if not settings.DIARIZATION_BACKEND:
|
||||||
|
ctx.log("diarize: diarization disabled, skipping")
|
||||||
|
return DiarizeResult(diarization=None)
|
||||||
|
|
||||||
|
upload_result = ctx.task_output(upload_audio)
|
||||||
|
audio_url = upload_result.audio_url
|
||||||
|
|
||||||
|
from reflector.processors.file_diarization import ( # noqa: PLC0415
|
||||||
|
FileDiarizationInput,
|
||||||
|
)
|
||||||
|
from reflector.processors.file_diarization_auto import ( # noqa: PLC0415
|
||||||
|
FileDiarizationAutoProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
processor = FileDiarizationAutoProcessor()
|
||||||
|
input_data = FileDiarizationInput(audio_url=audio_url)
|
||||||
|
|
||||||
|
result = None
|
||||||
|
|
||||||
|
async def capture_result(diarization_output):
|
||||||
|
nonlocal result
|
||||||
|
result = diarization_output.diarization
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor.on(capture_result)
|
||||||
|
await processor.push(input_data)
|
||||||
|
await processor.flush()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Diarization failed: {e}")
|
||||||
|
return DiarizeResult(diarization=None)
|
||||||
|
|
||||||
|
ctx.log(f"diarize complete: {len(result) if result else 0} segments")
|
||||||
|
return DiarizeResult(diarization=list(result) if result else None)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
|
||||||
|
parents=[upload_audio],
|
||||||
|
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=10,
|
||||||
|
)
|
||||||
|
@with_error_handling(TaskName.GENERATE_WAVEFORM)
|
||||||
|
async def generate_waveform(input: FilePipelineInput, ctx: Context) -> WaveformResult:
|
||||||
|
"""Generate audio waveform visualization."""
|
||||||
|
ctx.log(f"generate_waveform: starting for transcript_id={input.transcript_id}")
|
||||||
|
|
||||||
|
upload_result = ctx.task_output(upload_audio)
|
||||||
|
audio_path = upload_result.audio_path
|
||||||
|
|
||||||
|
from reflector.db.transcripts import ( # noqa: PLC0415
|
||||||
|
TranscriptWaveform,
|
||||||
|
transcripts_controller,
|
||||||
|
)
|
||||||
|
|
||||||
|
waveform = get_audio_waveform(
|
||||||
|
path=Path(audio_path), segments_count=WAVEFORM_SEGMENTS
|
||||||
|
)
|
||||||
|
|
||||||
|
async with fresh_db_connection():
|
||||||
|
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||||
|
if transcript:
|
||||||
|
transcript.data_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(transcript.audio_waveform_filename, "w") as f:
|
||||||
|
json.dump(waveform, f)
|
||||||
|
|
||||||
|
waveform_data = TranscriptWaveform(waveform=waveform)
|
||||||
|
await append_event_and_broadcast(
|
||||||
|
input.transcript_id,
|
||||||
|
transcript,
|
||||||
|
"WAVEFORM",
|
||||||
|
waveform_data,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx.log("generate_waveform complete")
|
||||||
|
return WaveformResult(waveform_generated=True)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
|
||||||
|
parents=[transcribe, diarize, generate_waveform],
|
||||||
|
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=10,
|
||||||
|
)
|
||||||
|
@with_error_handling(TaskName.ASSEMBLE_TRANSCRIPT)
|
||||||
|
async def assemble_transcript(
|
||||||
|
input: FilePipelineInput, ctx: Context
|
||||||
|
) -> AssembleTranscriptResult:
|
||||||
|
"""Merge transcription + diarization results."""
|
||||||
|
ctx.log(f"assemble_transcript: starting for transcript_id={input.transcript_id}")
|
||||||
|
|
||||||
|
transcribe_result = ctx.task_output(transcribe)
|
||||||
|
diarize_result = ctx.task_output(diarize)
|
||||||
|
|
||||||
|
from reflector.processors.transcript_diarization_assembler import ( # noqa: PLC0415
|
||||||
|
TranscriptDiarizationAssemblerInput,
|
||||||
|
TranscriptDiarizationAssemblerProcessor,
|
||||||
|
)
|
||||||
|
from reflector.processors.types import ( # noqa: PLC0415
|
||||||
|
DiarizationSegment,
|
||||||
|
Word,
|
||||||
|
)
|
||||||
|
from reflector.processors.types import ( # noqa: PLC0415
|
||||||
|
Transcript as TranscriptType,
|
||||||
|
)
|
||||||
|
|
||||||
|
words = [Word(**w) for w in transcribe_result.words]
|
||||||
|
transcript_data = TranscriptType(
|
||||||
|
words=words, translation=transcribe_result.translation
|
||||||
|
)
|
||||||
|
|
||||||
|
diarization = None
|
||||||
|
if diarize_result.diarization:
|
||||||
|
diarization = [DiarizationSegment(**s) for s in diarize_result.diarization]
|
||||||
|
|
||||||
|
processor = TranscriptDiarizationAssemblerProcessor()
|
||||||
|
assembler_input = TranscriptDiarizationAssemblerInput(
|
||||||
|
transcript=transcript_data, diarization=diarization or []
|
||||||
|
)
|
||||||
|
|
||||||
|
diarized_transcript = None
|
||||||
|
|
||||||
|
async def capture_result(transcript):
|
||||||
|
nonlocal diarized_transcript
|
||||||
|
diarized_transcript = transcript
|
||||||
|
|
||||||
|
processor.on(capture_result)
|
||||||
|
await processor.push(assembler_input)
|
||||||
|
await processor.flush()
|
||||||
|
|
||||||
|
if not diarized_transcript:
|
||||||
|
raise ValueError("No diarized transcript captured")
|
||||||
|
|
||||||
|
# Save the assembled transcript events to the database
|
||||||
|
async with fresh_db_connection():
|
||||||
|
from reflector.db.transcripts import ( # noqa: PLC0415
|
||||||
|
TranscriptText,
|
||||||
|
transcripts_controller,
|
||||||
|
)
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||||
|
if transcript:
|
||||||
|
assembled_text = diarized_transcript.text if diarized_transcript else ""
|
||||||
|
assembled_translation = (
|
||||||
|
diarized_transcript.translation if diarized_transcript else None
|
||||||
|
)
|
||||||
|
await append_event_and_broadcast(
|
||||||
|
input.transcript_id,
|
||||||
|
transcript,
|
||||||
|
"TRANSCRIPT",
|
||||||
|
TranscriptText(text=assembled_text, translation=assembled_translation),
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx.log("assemble_transcript complete")
|
||||||
|
return AssembleTranscriptResult(assembled=True)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
    parents=[assemble_transcript],
    execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=30,
)
@with_error_handling(TaskName.DETECT_TOPICS)
async def detect_topics(input: FilePipelineInput, ctx: Context) -> TopicsResult:
    """Detect topics from the assembled transcript.

    Rebuilds the diarized transcript from the `transcribe` and `diarize` task
    outputs (task outputs are plain dicts, so they are re-hydrated into
    processor types), then runs topic detection and persists each detected
    topic to the database, broadcasting a TOPIC event per topic.

    Raises:
        ValueError: If no diarized transcript could be assembled, or the
            transcript row is missing from the database.
    """
    ctx.log(f"detect_topics: starting for transcript_id={input.transcript_id}")

    # Re-read the transcript to get the diarized words
    transcribe_result = ctx.task_output(transcribe)
    diarize_result = ctx.task_output(diarize)

    # Deferred imports: workers run in forked processes; importing per task
    # avoids sharing DB connections across forks.
    from reflector.db.transcripts import (  # noqa: PLC0415
        TranscriptTopic,
        transcripts_controller,
    )
    from reflector.processors.transcript_diarization_assembler import (  # noqa: PLC0415
        TranscriptDiarizationAssemblerInput,
        TranscriptDiarizationAssemblerProcessor,
    )
    from reflector.processors.types import (  # noqa: PLC0415
        DiarizationSegment,
        Word,
    )
    from reflector.processors.types import (  # noqa: PLC0415
        Transcript as TranscriptType,
    )

    # Re-hydrate serialized task outputs back into processor types.
    words = [Word(**w) for w in transcribe_result.words]
    transcript_data = TranscriptType(
        words=words, translation=transcribe_result.translation
    )

    diarization = None
    if diarize_result.diarization:
        diarization = [DiarizationSegment(**s) for s in diarize_result.diarization]

    # Re-assemble to get the diarized transcript for topic detection
    processor = TranscriptDiarizationAssemblerProcessor()
    assembler_input = TranscriptDiarizationAssemblerInput(
        transcript=transcript_data, diarization=diarization or []
    )

    # The processor emits its result via callback; capture it here.
    diarized_transcript = None

    async def capture_result(transcript):
        nonlocal diarized_transcript
        diarized_transcript = transcript

    processor.on(capture_result)
    await processor.push(assembler_input)
    await processor.flush()

    if not diarized_transcript:
        raise ValueError("No diarized transcript for topic detection")

    async with fresh_db_connection():
        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if not transcript:
            raise ValueError(f"Transcript {input.transcript_id} not found")
        target_language = transcript.target_language

        empty_pipeline = topic_processing.EmptyPipeline(logger=logger)

        async def on_topic_callback(data):
            # Persist and broadcast each topic as soon as it is detected.
            topic = TranscriptTopic(
                title=data.title,
                summary=data.summary,
                timestamp=data.timestamp,
                transcript=data.transcript.text
                if hasattr(data.transcript, "text")
                else "",
                words=data.transcript.words
                if hasattr(data.transcript, "words")
                else [],
            )
            await transcripts_controller.upsert_topic(transcript, topic)
            await append_event_and_broadcast(
                input.transcript_id, transcript, "TOPIC", topic, logger=logger
            )

        topics = await topic_processing.detect_topics(
            diarized_transcript,
            target_language,
            on_topic_callback=on_topic_callback,
            empty_pipeline=empty_pipeline,
        )

    ctx.log(f"detect_topics complete: {len(topics)} topics")
    return TopicsResult(topics=topics)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
    parents=[detect_topics],
    execution_timeout=timedelta(seconds=TIMEOUT_TITLE),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=15,
)
@with_error_handling(TaskName.GENERATE_TITLE)
async def generate_title(input: FilePipelineInput, ctx: Context) -> TitleResult:
    """Generate meeting title using LLM.

    Feeds the detected topics into title generation; the resulting title is
    captured via callback, stored on the transcript only if no title exists
    yet, and broadcast as a FINAL_TITLE event.

    Raises:
        ValueError: If the transcript row is missing from the database.
    """
    ctx.log(f"generate_title: starting for transcript_id={input.transcript_id}")

    topics_result = ctx.task_output(detect_topics)
    topics = topics_result.topics

    # Deferred import: fresh imports per forked worker process.
    from reflector.db.transcripts import (  # noqa: PLC0415
        TranscriptFinalTitle,
        transcripts_controller,
    )

    empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
    # Populated by the callback below; stays None if no title was produced.
    title_result = None

    async with fresh_db_connection():
        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if not transcript:
            raise ValueError(f"Transcript {input.transcript_id} not found")

        async def on_title_callback(data):
            nonlocal title_result
            title_result = data.title
            final_title = TranscriptFinalTitle(title=data.title)
            # Never overwrite a user-provided / pre-existing title.
            if not transcript.title:
                await transcripts_controller.update(
                    transcript, {"title": final_title.title}
                )
            await append_event_and_broadcast(
                input.transcript_id,
                transcript,
                "FINAL_TITLE",
                final_title,
                logger=logger,
            )

        await topic_processing.generate_title(
            topics,
            on_title_callback=on_title_callback,
            empty_pipeline=empty_pipeline,
            logger=logger,
        )

    ctx.log(f"generate_title complete: '{title_result}'")
    return TitleResult(title=title_result)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
    parents=[detect_topics],
    execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=30,
)
@with_error_handling(TaskName.GENERATE_SUMMARIES)
async def generate_summaries(input: FilePipelineInput, ctx: Context) -> SummariesResult:
    """Generate long/short summaries and action items.

    Each artifact is delivered via its own callback, which persists the value
    on the transcript row and broadcasts a matching event
    (FINAL_LONG_SUMMARY / FINAL_SHORT_SUMMARY / ACTION_ITEMS).

    Raises:
        ValueError: If the transcript row is missing from the database.
    """
    ctx.log(f"generate_summaries: starting for transcript_id={input.transcript_id}")

    topics_result = ctx.task_output(detect_topics)
    topics = topics_result.topics

    # Deferred import: fresh imports per forked worker process.
    from reflector.db.transcripts import (  # noqa: PLC0415
        TranscriptActionItems,
        TranscriptFinalLongSummary,
        TranscriptFinalShortSummary,
        transcripts_controller,
    )

    empty_pipeline = topic_processing.EmptyPipeline(logger=logger)

    async with fresh_db_connection():
        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if not transcript:
            raise ValueError(f"Transcript {input.transcript_id} not found")

        async def on_long_summary_callback(data):
            final_long = TranscriptFinalLongSummary(long_summary=data.long_summary)
            await transcripts_controller.update(
                transcript, {"long_summary": final_long.long_summary}
            )
            await append_event_and_broadcast(
                input.transcript_id,
                transcript,
                "FINAL_LONG_SUMMARY",
                final_long,
                logger=logger,
            )

        async def on_short_summary_callback(data):
            final_short = TranscriptFinalShortSummary(short_summary=data.short_summary)
            await transcripts_controller.update(
                transcript, {"short_summary": final_short.short_summary}
            )
            await append_event_and_broadcast(
                input.transcript_id,
                transcript,
                "FINAL_SHORT_SUMMARY",
                final_short,
                logger=logger,
            )

        async def on_action_items_callback(data):
            action_items = TranscriptActionItems(action_items=data.action_items)
            await transcripts_controller.update(
                transcript, {"action_items": action_items.action_items}
            )
            await append_event_and_broadcast(
                input.transcript_id,
                transcript,
                "ACTION_ITEMS",
                action_items,
                logger=logger,
            )

        await topic_processing.generate_summaries(
            topics,
            transcript,
            on_long_summary_callback=on_long_summary_callback,
            on_short_summary_callback=on_short_summary_callback,
            on_action_items_callback=on_action_items_callback,
            empty_pipeline=empty_pipeline,
            logger=logger,
        )

    ctx.log("generate_summaries complete")
    return SummariesResult(generated=True)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
    parents=[generate_title, generate_summaries],
    execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=5,
)
@with_error_handling(TaskName.FINALIZE)
async def finalize(input: FilePipelineInput, ctx: Context) -> FinalizeResult:
    """Set transcript status to 'ended' and broadcast.

    Fan-in task: runs only after both title and summaries have completed.
    """
    ctx.log("finalize: setting status to 'ended'")

    async with fresh_db_connection():
        await set_status_and_broadcast(input.transcript_id, "ended", logger=logger)

    ctx.log("finalize complete")
    return FinalizeResult(status="COMPLETED")
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
    parents=[finalize],
    execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=10,
)
@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
async def cleanup_consent(input: FilePipelineInput, ctx: Context) -> ConsentResult:
    """Check consent and delete audio files if any participant denied.

    Consent is resolved via transcript -> recording -> meeting. When any
    participant denied, deletes (best-effort, errors collected rather than
    raised): raw recording objects in the recording bucket, the processed
    audio in transcript storage, and local MP3/WAV files. The
    `audio_deleted` flag is set only when every deletion succeeded, so a
    retry can attempt the failed deletions again.

    Note: set_error_status=False — failures here must not mark an already
    completed transcript as errored.
    """
    ctx.log(f"cleanup_consent: transcript_id={input.transcript_id}")

    async with fresh_db_connection():
        # Deferred imports: fresh imports per forked worker process.
        from reflector.db.meetings import (  # noqa: PLC0415
            meeting_consent_controller,
            meetings_controller,
        )
        from reflector.db.recordings import recordings_controller  # noqa: PLC0415
        from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
        from reflector.storage import get_transcripts_storage  # noqa: PLC0415

        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if not transcript:
            ctx.log("cleanup_consent: transcript not found")
            return ConsentResult()

        # Walk transcript -> recording -> meeting to find any consent denial.
        consent_denied = False
        recording = None
        if transcript.recording_id:
            recording = await recordings_controller.get_by_id(transcript.recording_id)
            if recording and recording.meeting_id:
                meeting = await meetings_controller.get_by_id(recording.meeting_id)
                if meeting:
                    consent_denied = await meeting_consent_controller.has_any_denial(
                        meeting.id
                    )

        if not consent_denied:
            ctx.log("cleanup_consent: consent approved, keeping all files")
            return ConsentResult()

        ctx.log("cleanup_consent: consent denied, deleting audio files")

        # Best-effort deletion: collect errors instead of aborting mid-way.
        deletion_errors = []
        if recording and recording.bucket_name:
            # Multi-track recordings store one key per track; single-track
            # recordings store one object key.
            keys_to_delete = []
            if recording.track_keys:
                keys_to_delete = recording.track_keys
            elif recording.object_key:
                keys_to_delete = [recording.object_key]

            master_storage = get_transcripts_storage()
            for key in keys_to_delete:
                try:
                    await master_storage.delete_file(key, bucket=recording.bucket_name)
                    ctx.log(f"Deleted recording file: {recording.bucket_name}/{key}")
                except Exception as e:
                    error_msg = f"Failed to delete {key}: {e}"
                    logger.error(error_msg, exc_info=True)
                    deletion_errors.append(error_msg)

        # Processed audio stored in transcript storage (if any).
        if transcript.audio_location == "storage":
            storage = get_transcripts_storage()
            try:
                await storage.delete_file(transcript.storage_audio_path)
                ctx.log(f"Deleted processed audio: {transcript.storage_audio_path}")
            except Exception as e:
                error_msg = f"Failed to delete processed audio: {e}"
                logger.error(error_msg, exc_info=True)
                deletion_errors.append(error_msg)

        # Local audio files on the worker filesystem (Path-like attributes).
        try:
            if (
                hasattr(transcript, "audio_mp3_filename")
                and transcript.audio_mp3_filename
            ):
                transcript.audio_mp3_filename.unlink(missing_ok=True)
            if (
                hasattr(transcript, "audio_wav_filename")
                and transcript.audio_wav_filename
            ):
                transcript.audio_wav_filename.unlink(missing_ok=True)
        except Exception as e:
            error_msg = f"Failed to delete local audio files: {e}"
            logger.error(error_msg, exc_info=True)
            deletion_errors.append(error_msg)

        if deletion_errors:
            # Do not set audio_deleted: a retry should re-attempt deletions.
            logger.warning(
                "[Hatchet] cleanup_consent completed with errors",
                transcript_id=input.transcript_id,
                error_count=len(deletion_errors),
            )
        else:
            await transcripts_controller.update(transcript, {"audio_deleted": True})
            ctx.log("cleanup_consent: all audio deleted successfully")

    return ConsentResult()
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
    parents=[cleanup_consent],
    execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
    retries=5,
    backoff_factor=2.0,
    backoff_max_seconds=15,
)
@with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
async def post_zulip(input: FilePipelineInput, ctx: Context) -> ZulipResult:
    """Post notification to Zulip.

    No-op (skipped=True) when Zulip is not configured. Notification failures
    must not fail the pipeline, hence set_error_status=False.
    """
    ctx.log(f"post_zulip: transcript_id={input.transcript_id}")

    # ZULIP_REALM doubles as the "Zulip is configured" flag.
    if not settings.ZULIP_REALM:
        ctx.log("post_zulip skipped (Zulip not configured)")
        return ZulipResult(zulip_message_id=None, skipped=True)

    async with fresh_db_connection():
        # Deferred imports: fresh imports per forked worker process.
        from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
        from reflector.zulip import post_transcript_notification  # noqa: PLC0415

        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if transcript:
            message_id = await post_transcript_notification(transcript)
            ctx.log(f"post_zulip complete: zulip_message_id={message_id}")
        else:
            # Transcript vanished; nothing to notify about.
            message_id = None

    return ZulipResult(zulip_message_id=message_id)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.task(
    parents=[cleanup_consent],
    execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
    retries=5,
    backoff_factor=2.0,
    backoff_max_seconds=15,
)
@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
async def send_webhook(input: FilePipelineInput, ctx: Context) -> WebhookResult:
    """Send completion webhook to external service.

    Skipped when there is no room, the room has no webhook_url, or the
    payload cannot be built. All delivery failures are logged and swallowed
    (webhook_sent=False) so they never fail the pipeline.
    """
    ctx.log(f"send_webhook: transcript_id={input.transcript_id}")

    if not input.room_id:
        ctx.log("send_webhook skipped (no room_id)")
        return WebhookResult(webhook_sent=False, skipped=True)

    async with fresh_db_connection():
        # Deferred imports: fresh imports per forked worker process.
        from reflector.db.rooms import rooms_controller  # noqa: PLC0415
        from reflector.utils.webhook import (  # noqa: PLC0415
            fetch_transcript_webhook_payload,
            send_webhook_request,
        )

        room = await rooms_controller.get_by_id(input.room_id)
        if not room or not room.webhook_url:
            ctx.log("send_webhook skipped (no webhook_url configured)")
            return WebhookResult(webhook_sent=False, skipped=True)

        payload = await fetch_transcript_webhook_payload(
            transcript_id=input.transcript_id,
            room_id=input.room_id,
        )

        # A str return signals a build failure with a human-readable reason.
        if isinstance(payload, str):
            ctx.log(f"send_webhook skipped (could not build payload): {payload}")
            return WebhookResult(webhook_sent=False, skipped=True)

        import httpx  # noqa: PLC0415

        try:
            response = await send_webhook_request(
                url=room.webhook_url,
                payload=payload,
                event_type="transcript.completed",
                webhook_secret=room.webhook_secret,
                timeout=30.0,
            )
            ctx.log(f"send_webhook complete: status_code={response.status_code}")
            return WebhookResult(webhook_sent=True, response_code=response.status_code)
        except httpx.HTTPStatusError as e:
            # Remote rejected the delivery; report the status code but continue.
            ctx.log(f"send_webhook failed (HTTP {e.response.status_code}), continuing")
            return WebhookResult(
                webhook_sent=False, response_code=e.response.status_code
            )
        except (httpx.ConnectError, httpx.TimeoutException) as e:
            ctx.log(f"send_webhook failed ({e}), continuing")
            return WebhookResult(webhook_sent=False)
        except Exception as e:
            # Last-resort catch: webhook delivery must never fail the pipeline.
            ctx.log(f"send_webhook unexpected error: {e}")
            return WebhookResult(webhook_sent=False)
|
||||||
|
|
||||||
|
|
||||||
|
# --- On failure handler ---
|
||||||
|
|
||||||
|
|
||||||
|
async def on_workflow_failure(input: FilePipelineInput, ctx: Context) -> None:
    """Set transcript status to 'error' only if not already 'ended'.

    A post-finalize task (zulip/webhook/cleanup) may fail after the
    transcript successfully ended; in that case the 'ended' status is kept.
    """
    async with fresh_db_connection():
        from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415

        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if transcript and transcript.status == "ended":
            logger.info(
                "[Hatchet] FilePipeline on_workflow_failure: transcript already ended, skipping error status",
                transcript_id=input.transcript_id,
            )
            ctx.log(
                "on_workflow_failure: transcript already ended, skipping error status"
            )
            return
        await set_workflow_error_status(input.transcript_id)
|
||||||
|
|
||||||
|
|
||||||
|
@file_pipeline.on_failure_task()
async def _register_on_workflow_failure(input: FilePipelineInput, ctx: Context) -> None:
    """Hatchet on-failure hook: delegate to the shared failure handler."""
    await on_workflow_failure(input, ctx)
|
||||||
389
server/reflector/hatchet/workflows/live_post_pipeline.py
Normal file
389
server/reflector/hatchet/workflows/live_post_pipeline.py
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
"""
|
||||||
|
Hatchet workflow: LivePostProcessingPipeline
|
||||||
|
|
||||||
|
Post-processing pipeline for live WebRTC meetings.
|
||||||
|
Triggered after a live meeting ends. Orchestrates:
|
||||||
|
Left branch: waveform → convert_mp3 → upload_mp3 → remove_upload → diarize → cleanup_consent
|
||||||
|
Right branch: generate_title (parallel with left branch)
|
||||||
|
Fan-in: final_summaries → post_zulip → send_webhook
|
||||||
|
|
||||||
|
Note: This file uses deferred imports (inside functions/tasks) intentionally.
|
||||||
|
Hatchet workers run in forked processes; fresh imports per task ensure DB connections
|
||||||
|
are not shared across forks, avoiding connection pooling issues.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from hatchet_sdk import Context
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
|
from reflector.hatchet.constants import (
|
||||||
|
TIMEOUT_HEAVY,
|
||||||
|
TIMEOUT_MEDIUM,
|
||||||
|
TIMEOUT_SHORT,
|
||||||
|
TIMEOUT_TITLE,
|
||||||
|
TaskName,
|
||||||
|
)
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
fresh_db_connection,
|
||||||
|
set_workflow_error_status,
|
||||||
|
with_error_handling,
|
||||||
|
)
|
||||||
|
from reflector.hatchet.workflows.models import (
|
||||||
|
ConsentResult,
|
||||||
|
TitleResult,
|
||||||
|
WaveformResult,
|
||||||
|
WebhookResult,
|
||||||
|
ZulipResult,
|
||||||
|
)
|
||||||
|
from reflector.logger import logger
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class LivePostPipelineInput(BaseModel):
    """Input payload for the live post-processing workflow."""

    # ID of the transcript produced by the live meeting.
    transcript_id: str
    # Optional room ID; required only for the webhook step.
    room_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Result models specific to live post pipeline ---
|
||||||
|
|
||||||
|
|
||||||
|
class ConvertMp3Result(BaseModel):
    """Result of the convert_mp3 task."""

    converted: bool
|
||||||
|
|
||||||
|
|
||||||
|
class UploadMp3Result(BaseModel):
    """Result of the upload_mp3 task."""

    uploaded: bool
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveUploadResult(BaseModel):
    """Result of the remove_upload task."""

    removed: bool
|
||||||
|
|
||||||
|
|
||||||
|
class DiarizeResult(BaseModel):
    """Result of the diarize task."""

    diarized: bool
|
||||||
|
|
||||||
|
|
||||||
|
class FinalSummariesResult(BaseModel):
    """Result of the final_summaries fan-in task."""

    generated: bool
|
||||||
|
|
||||||
|
|
||||||
|
# Shared Hatchet client for this module.
hatchet = HatchetClientManager.get_client()

# Workflow definition; tasks below register into it via @live_post_pipeline.task.
live_post_pipeline = hatchet.workflow(
    name="LivePostProcessingPipeline", input_validator=LivePostPipelineInput
)
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.task(
    execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=10,
)
@with_error_handling(TaskName.WAVEFORM)
async def waveform(input: LivePostPipelineInput, ctx: Context) -> WaveformResult:
    """Generate waveform visualization from recorded audio.

    Root task of the audio chain (no parents).

    Raises:
        ValueError: If the transcript row is missing from the database.
    """
    ctx.log(f"waveform: starting for transcript_id={input.transcript_id}")

    async with fresh_db_connection():
        # Deferred imports: fresh imports per forked worker process.
        from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
        from reflector.pipelines.main_live_pipeline import (  # noqa: PLC0415
            PipelineMainWaveform,
        )

        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if not transcript:
            raise ValueError(f"Transcript {input.transcript_id} not found")

        runner = PipelineMainWaveform(transcript_id=transcript.id)
        await runner.run()

    ctx.log("waveform complete")
    return WaveformResult(waveform_generated=True)
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.task(
    execution_timeout=timedelta(seconds=TIMEOUT_TITLE),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=15,
)
@with_error_handling(TaskName.GENERATE_TITLE)
async def generate_title(input: LivePostPipelineInput, ctx: Context) -> TitleResult:
    """Generate meeting title from topics (runs in parallel with audio chain)."""
    ctx.log(f"generate_title: starting for transcript_id={input.transcript_id}")

    async with fresh_db_connection():
        # Deferred import: fresh imports per forked worker process.
        from reflector.pipelines.main_live_pipeline import (  # noqa: PLC0415
            PipelineMainTitle,
        )

        runner = PipelineMainTitle(transcript_id=input.transcript_id)
        await runner.run()

    ctx.log("generate_title complete")
    # The runner persists the title itself; no value is passed downstream.
    return TitleResult(title=None)
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.task(
    parents=[waveform],
    execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=10,
)
@with_error_handling(TaskName.CONVERT_MP3)
async def convert_mp3(input: LivePostPipelineInput, ctx: Context) -> ConvertMp3Result:
    """Convert WAV recording to MP3."""
    ctx.log(f"convert_mp3: starting for transcript_id={input.transcript_id}")

    async with fresh_db_connection():
        # Deferred import: fresh imports per forked worker process.
        from reflector.pipelines.main_live_pipeline import (  # noqa: PLC0415
            pipeline_convert_to_mp3,
        )

        await pipeline_convert_to_mp3(transcript_id=input.transcript_id)

    ctx.log("convert_mp3 complete")
    return ConvertMp3Result(converted=True)
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.task(
    parents=[convert_mp3],
    execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=10,
)
@with_error_handling(TaskName.UPLOAD_MP3)
async def upload_mp3(input: LivePostPipelineInput, ctx: Context) -> UploadMp3Result:
    """Upload MP3 to external storage."""
    ctx.log(f"upload_mp3: starting for transcript_id={input.transcript_id}")

    async with fresh_db_connection():
        # Deferred import: fresh imports per forked worker process.
        from reflector.pipelines.main_live_pipeline import (  # noqa: PLC0415
            pipeline_upload_mp3,
        )

        await pipeline_upload_mp3(transcript_id=input.transcript_id)

    ctx.log("upload_mp3 complete")
    return UploadMp3Result(uploaded=True)
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.task(
    parents=[upload_mp3],
    execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=5,
)
@with_error_handling(TaskName.REMOVE_UPLOAD)
async def remove_upload(
    input: LivePostPipelineInput, ctx: Context
) -> RemoveUploadResult:
    """Remove the original upload file."""
    ctx.log(f"remove_upload: starting for transcript_id={input.transcript_id}")

    async with fresh_db_connection():
        # Deferred import: fresh imports per forked worker process.
        from reflector.pipelines.main_live_pipeline import (  # noqa: PLC0415
            pipeline_remove_upload,
        )

        await pipeline_remove_upload(transcript_id=input.transcript_id)

    ctx.log("remove_upload complete")
    return RemoveUploadResult(removed=True)
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.task(
    parents=[remove_upload],
    execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=30,
)
@with_error_handling(TaskName.DIARIZE)
async def diarize(input: LivePostPipelineInput, ctx: Context) -> DiarizeResult:
    """Run diarization on the recorded audio."""
    ctx.log(f"diarize: starting for transcript_id={input.transcript_id}")

    async with fresh_db_connection():
        # Deferred import: fresh imports per forked worker process.
        from reflector.pipelines.main_live_pipeline import (  # noqa: PLC0415
            pipeline_diarization,
        )

        await pipeline_diarization(transcript_id=input.transcript_id)

    ctx.log("diarize complete")
    return DiarizeResult(diarized=True)
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.task(
    parents=[diarize],
    execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=10,
)
@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
async def cleanup_consent(input: LivePostPipelineInput, ctx: Context) -> ConsentResult:
    """Check consent and delete audio files if any participant denied.

    Delegates to the live pipeline's cleanup_consent helper; failures here
    must not mark the transcript as errored (set_error_status=False).
    """
    ctx.log(f"cleanup_consent: transcript_id={input.transcript_id}")

    async with fresh_db_connection():
        # Aliased to avoid shadowing this task's own name.
        from reflector.pipelines.main_live_pipeline import (  # noqa: PLC0415
            cleanup_consent as _cleanup_consent,
        )

        await _cleanup_consent(transcript_id=input.transcript_id)

    ctx.log("cleanup_consent complete")
    return ConsentResult()
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.task(
    parents=[cleanup_consent, generate_title],
    execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
    retries=3,
    backoff_factor=2.0,
    backoff_max_seconds=30,
)
@with_error_handling(TaskName.FINAL_SUMMARIES)
async def final_summaries(
    input: LivePostPipelineInput, ctx: Context
) -> FinalSummariesResult:
    """Generate final summaries (fan-in after audio chain + title)."""
    ctx.log(f"final_summaries: starting for transcript_id={input.transcript_id}")

    async with fresh_db_connection():
        # Deferred import: fresh imports per forked worker process.
        from reflector.pipelines.main_live_pipeline import (  # noqa: PLC0415
            pipeline_summaries,
        )

        await pipeline_summaries(transcript_id=input.transcript_id)

    ctx.log("final_summaries complete")
    return FinalSummariesResult(generated=True)
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.task(
    parents=[final_summaries],
    execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
    retries=5,
    backoff_factor=2.0,
    backoff_max_seconds=15,
)
@with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
async def post_zulip(input: LivePostPipelineInput, ctx: Context) -> ZulipResult:
    """Post notification to Zulip.

    No-op (skipped=True) when Zulip is not configured. Notification failures
    must not fail the pipeline, hence set_error_status=False.
    """
    ctx.log(f"post_zulip: transcript_id={input.transcript_id}")

    # ZULIP_REALM doubles as the "Zulip is configured" flag.
    if not settings.ZULIP_REALM:
        ctx.log("post_zulip skipped (Zulip not configured)")
        return ZulipResult(zulip_message_id=None, skipped=True)

    async with fresh_db_connection():
        # Deferred imports: fresh imports per forked worker process.
        from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
        from reflector.zulip import post_transcript_notification  # noqa: PLC0415

        transcript = await transcripts_controller.get_by_id(input.transcript_id)
        if transcript:
            message_id = await post_transcript_notification(transcript)
            ctx.log(f"post_zulip complete: zulip_message_id={message_id}")
        else:
            # Transcript vanished; nothing to notify about.
            message_id = None

    return ZulipResult(zulip_message_id=message_id)
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.task(
    parents=[final_summaries],
    execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
    retries=5,
    backoff_factor=2.0,
    backoff_max_seconds=15,
)
@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
async def send_webhook(input: LivePostPipelineInput, ctx: Context) -> WebhookResult:
    """Send completion webhook to external service.

    Skipped when there is no room, the room has no webhook_url, or the
    payload cannot be built. All delivery failures are logged and swallowed
    (webhook_sent=False) so they never fail the pipeline.
    """
    ctx.log(f"send_webhook: transcript_id={input.transcript_id}")

    if not input.room_id:
        ctx.log("send_webhook skipped (no room_id)")
        return WebhookResult(webhook_sent=False, skipped=True)

    async with fresh_db_connection():
        # Deferred imports: fresh imports per forked worker process.
        from reflector.db.rooms import rooms_controller  # noqa: PLC0415
        from reflector.utils.webhook import (  # noqa: PLC0415
            fetch_transcript_webhook_payload,
            send_webhook_request,
        )

        room = await rooms_controller.get_by_id(input.room_id)
        if not room or not room.webhook_url:
            ctx.log("send_webhook skipped (no webhook_url configured)")
            return WebhookResult(webhook_sent=False, skipped=True)

        payload = await fetch_transcript_webhook_payload(
            transcript_id=input.transcript_id,
            room_id=input.room_id,
        )

        # A str return signals a build failure with a human-readable reason.
        if isinstance(payload, str):
            ctx.log(f"send_webhook skipped (could not build payload): {payload}")
            return WebhookResult(webhook_sent=False, skipped=True)

        import httpx  # noqa: PLC0415

        try:
            response = await send_webhook_request(
                url=room.webhook_url,
                payload=payload,
                event_type="transcript.completed",
                webhook_secret=room.webhook_secret,
                timeout=30.0,
            )
            ctx.log(f"send_webhook complete: status_code={response.status_code}")
            return WebhookResult(webhook_sent=True, response_code=response.status_code)
        except httpx.HTTPStatusError as e:
            # Remote rejected the delivery; report the status code but continue.
            ctx.log(f"send_webhook failed (HTTP {e.response.status_code}), continuing")
            return WebhookResult(
                webhook_sent=False, response_code=e.response.status_code
            )
        except (httpx.ConnectError, httpx.TimeoutException) as e:
            ctx.log(f"send_webhook failed ({e}), continuing")
            return WebhookResult(webhook_sent=False)
        except Exception as e:
            # Last-resort catch: webhook delivery must never fail the pipeline.
            ctx.log(f"send_webhook unexpected error: {e}")
            return WebhookResult(webhook_sent=False)
|
||||||
|
|
||||||
|
|
||||||
|
# --- On failure handler ---
|
||||||
|
|
||||||
|
|
||||||
|
async def on_workflow_failure(input: LivePostPipelineInput, ctx: Context) -> None:
|
||||||
|
"""Set transcript status to 'error' only if not already 'ended'."""
|
||||||
|
async with fresh_db_connection():
|
||||||
|
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||||
|
if transcript and transcript.status == "ended":
|
||||||
|
logger.info(
|
||||||
|
"[Hatchet] LivePostProcessingPipeline on_workflow_failure: transcript already ended",
|
||||||
|
transcript_id=input.transcript_id,
|
||||||
|
)
|
||||||
|
ctx.log(
|
||||||
|
"on_workflow_failure: transcript already ended, skipping error status"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
await set_workflow_error_status(input.transcript_id)
|
||||||
|
|
||||||
|
|
||||||
|
@live_post_pipeline.on_failure_task()
|
||||||
|
async def _register_on_workflow_failure(
|
||||||
|
input: LivePostPipelineInput, ctx: Context
|
||||||
|
) -> None:
|
||||||
|
await on_workflow_failure(input, ctx)
|
||||||
@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
|
|||||||
from typing import Generic
|
from typing import Generic
|
||||||
|
|
||||||
import av
|
import av
|
||||||
from celery import chord, current_task, group, shared_task
|
from celery import current_task, shared_task
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from structlog import BoundLogger as Logger
|
from structlog import BoundLogger as Logger
|
||||||
|
|
||||||
@@ -397,7 +397,9 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
# when the pipeline ends, connect to the post pipeline
|
# when the pipeline ends, connect to the post pipeline
|
||||||
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
|
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
|
||||||
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
|
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
|
||||||
pipeline_post(transcript_id=self.transcript_id)
|
transcript = await transcripts_controller.get_by_id(self.transcript_id)
|
||||||
|
room_id = transcript.room_id if transcript else None
|
||||||
|
await pipeline_post(transcript_id=self.transcript_id, room_id=room_id)
|
||||||
|
|
||||||
|
|
||||||
class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
||||||
@@ -792,29 +794,20 @@ async def task_pipeline_post_to_zulip(*, transcript_id: str):
|
|||||||
await pipeline_post_to_zulip(transcript_id=transcript_id)
|
await pipeline_post_to_zulip(transcript_id=transcript_id)
|
||||||
|
|
||||||
|
|
||||||
def pipeline_post(*, transcript_id: str):
|
async def pipeline_post(*, transcript_id: str, room_id: str | None = None):
|
||||||
"""
|
"""
|
||||||
Run the post pipeline
|
Run the post pipeline via Hatchet.
|
||||||
"""
|
"""
|
||||||
chain_mp3_and_diarize = (
|
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
|
||||||
task_pipeline_waveform.si(transcript_id=transcript_id)
|
|
||||||
| task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
|
|
||||||
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
|
|
||||||
| task_pipeline_remove_upload.si(transcript_id=transcript_id)
|
|
||||||
| task_pipeline_diarization.si(transcript_id=transcript_id)
|
|
||||||
| task_cleanup_consent.si(transcript_id=transcript_id)
|
|
||||||
)
|
|
||||||
chain_title_preview = task_pipeline_title.si(transcript_id=transcript_id)
|
|
||||||
chain_final_summaries = task_pipeline_final_summaries.si(
|
|
||||||
transcript_id=transcript_id
|
|
||||||
)
|
|
||||||
|
|
||||||
chain = chord(
|
await HatchetClientManager.start_workflow(
|
||||||
group(chain_mp3_and_diarize, chain_title_preview),
|
"LivePostProcessingPipeline",
|
||||||
chain_final_summaries,
|
{
|
||||||
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
|
"transcript_id": str(transcript_id),
|
||||||
|
"room_id": str(room_id) if room_id else None,
|
||||||
return chain.delay()
|
},
|
||||||
|
additional_metadata={"transcript_id": str(transcript_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
@get_transcript
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from dataclasses import dataclass
|
|||||||
from typing import Literal, Union, assert_never
|
from typing import Literal, Union, assert_never
|
||||||
|
|
||||||
import celery
|
import celery
|
||||||
from celery.result import AsyncResult
|
|
||||||
from hatchet_sdk.clients.rest.exceptions import ApiException, NotFoundException
|
from hatchet_sdk.clients.rest.exceptions import ApiException, NotFoundException
|
||||||
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
||||||
|
|
||||||
@@ -18,7 +17,6 @@ from reflector.db.recordings import recordings_controller
|
|||||||
from reflector.db.transcripts import Transcript, transcripts_controller
|
from reflector.db.transcripts import Transcript, transcripts_controller
|
||||||
from reflector.hatchet.client import HatchetClientManager
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
|
||||||
from reflector.utils.string import NonEmptyString
|
from reflector.utils.string import NonEmptyString
|
||||||
|
|
||||||
|
|
||||||
@@ -105,11 +103,8 @@ async def validate_transcript_for_processing(
|
|||||||
):
|
):
|
||||||
return ValidationNotReady(detail="Recording is not ready for processing")
|
return ValidationNotReady(detail="Recording is not ready for processing")
|
||||||
|
|
||||||
# Check Celery tasks
|
# Check Celery tasks (multitrack still uses Celery for some paths)
|
||||||
if task_is_scheduled_or_active(
|
if task_is_scheduled_or_active(
|
||||||
"reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
|
|
||||||
transcript_id=transcript.id,
|
|
||||||
) or task_is_scheduled_or_active(
|
|
||||||
"reflector.pipelines.main_multitrack_pipeline.task_pipeline_multitrack_process",
|
"reflector.pipelines.main_multitrack_pipeline.task_pipeline_multitrack_process",
|
||||||
transcript_id=transcript.id,
|
transcript_id=transcript.id,
|
||||||
):
|
):
|
||||||
@@ -175,11 +170,8 @@ async def prepare_transcript_processing(validation: ValidationOk) -> PrepareResu
|
|||||||
|
|
||||||
async def dispatch_transcript_processing(
|
async def dispatch_transcript_processing(
|
||||||
config: ProcessingConfig, force: bool = False
|
config: ProcessingConfig, force: bool = False
|
||||||
) -> AsyncResult | None:
|
) -> None:
|
||||||
"""Dispatch transcript processing to appropriate backend (Hatchet or Celery).
|
"""Dispatch transcript processing to Hatchet workflow engine."""
|
||||||
|
|
||||||
Returns AsyncResult for Celery tasks, None for Hatchet workflows.
|
|
||||||
"""
|
|
||||||
if isinstance(config, MultitrackProcessingConfig):
|
if isinstance(config, MultitrackProcessingConfig):
|
||||||
# Multitrack processing always uses Hatchet (no Celery fallback)
|
# Multitrack processing always uses Hatchet (no Celery fallback)
|
||||||
# First check if we can replay (outside transaction since it's read-only)
|
# First check if we can replay (outside transaction since it's read-only)
|
||||||
@@ -275,7 +267,21 @@ async def dispatch_transcript_processing(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
elif isinstance(config, FileProcessingConfig):
|
elif isinstance(config, FileProcessingConfig):
|
||||||
return task_pipeline_file_process.delay(transcript_id=config.transcript_id)
|
# File processing uses Hatchet workflow
|
||||||
|
workflow_id = await HatchetClientManager.start_workflow(
|
||||||
|
workflow_name="FilePipeline",
|
||||||
|
input_data={"transcript_id": config.transcript_id},
|
||||||
|
additional_metadata={"transcript_id": config.transcript_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(config.transcript_id)
|
||||||
|
if transcript:
|
||||||
|
await transcripts_controller.update(
|
||||||
|
transcript, {"workflow_run_id": workflow_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("File pipeline dispatched via Hatchet", workflow_id=workflow_id)
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
assert_never(config)
|
assert_never(config)
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal, Tuple
|
from typing import Any, Dict, List, Literal, Tuple
|
||||||
from urllib.parse import unquote, urlparse
|
from urllib.parse import unquote, urlparse
|
||||||
@@ -15,10 +14,8 @@ from urllib.parse import unquote, urlparse
|
|||||||
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
|
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
|
||||||
|
|
||||||
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
||||||
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.pipelines.main_file_pipeline import (
|
|
||||||
task_pipeline_file_process as task_pipeline_file_process,
|
|
||||||
)
|
|
||||||
from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipeline_post
|
from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipeline_post
|
||||||
from reflector.pipelines.main_live_pipeline import (
|
from reflector.pipelines.main_live_pipeline import (
|
||||||
pipeline_process as live_pipeline_process,
|
pipeline_process as live_pipeline_process,
|
||||||
@@ -237,29 +234,22 @@ async def process_live_pipeline(
|
|||||||
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
|
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
|
||||||
assert pre_final_transcript.status != "ended"
|
assert pre_final_transcript.status != "ended"
|
||||||
|
|
||||||
# at this point, diarization is running but we have no access to it. run diarization in parallel - one will hopefully win after polling
|
# Trigger post-processing via Hatchet (fire-and-forget)
|
||||||
result = live_pipeline_post(transcript_id=transcript_id)
|
await live_pipeline_post(transcript_id=transcript_id)
|
||||||
|
print("Live post-processing pipeline triggered via Hatchet", file=sys.stderr)
|
||||||
# result.ready() blocks even without await; it mutates result also
|
|
||||||
while not result.ready():
|
|
||||||
print(f"Status: {result.state}")
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
|
|
||||||
async def process_file_pipeline(
|
async def process_file_pipeline(
|
||||||
transcript_id: TranscriptId,
|
transcript_id: TranscriptId,
|
||||||
):
|
):
|
||||||
"""Process audio/video file using the optimized file pipeline"""
|
"""Process audio/video file using the optimized file pipeline via Hatchet"""
|
||||||
|
|
||||||
# task_pipeline_file_process is a Celery task, need to use .delay() for async execution
|
await HatchetClientManager.start_workflow(
|
||||||
result = task_pipeline_file_process.delay(transcript_id=transcript_id)
|
"FilePipeline",
|
||||||
|
{"transcript_id": str(transcript_id)},
|
||||||
# Wait for the Celery task to complete
|
additional_metadata={"transcript_id": str(transcript_id)},
|
||||||
while not result.ready():
|
)
|
||||||
print(f"File pipeline status: {result.state}", file=sys.stderr)
|
print("File pipeline triggered via Hatchet", file=sys.stderr)
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
logger.info("File pipeline processing complete")
|
|
||||||
|
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
@@ -293,6 +283,15 @@ async def process(
|
|||||||
|
|
||||||
await handler(transcript_id)
|
await handler(transcript_id)
|
||||||
|
|
||||||
|
if pipeline == "file":
|
||||||
|
# File pipeline is async via Hatchet — results not available immediately.
|
||||||
|
# Use reflector.tools.process_transcript with --sync for polling.
|
||||||
|
print(
|
||||||
|
f"File pipeline dispatched for transcript {transcript_id}. "
|
||||||
|
f"Results will be available once the Hatchet workflow completes.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
else:
|
||||||
await extract_result_from_entry(transcript_id, output_path)
|
await extract_result_from_entry(transcript_id, output_path)
|
||||||
finally:
|
finally:
|
||||||
await database.disconnect()
|
await database.disconnect()
|
||||||
|
|||||||
@@ -11,10 +11,8 @@ Usage:
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from celery.result import AsyncResult
|
|
||||||
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
||||||
|
|
||||||
import reflector._warnings_filter # noqa: F401 -- side effect: suppress pydantic validate_default warning
|
import reflector._warnings_filter # noqa: F401 -- side effect: suppress pydantic validate_default warning
|
||||||
@@ -39,7 +37,7 @@ async def process_transcript_inner(
|
|||||||
on_validation: Callable[[ValidationResult], None],
|
on_validation: Callable[[ValidationResult], None],
|
||||||
on_preprocess: Callable[[PrepareResult], None],
|
on_preprocess: Callable[[PrepareResult], None],
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> AsyncResult | None:
|
) -> None:
|
||||||
validation = await validate_transcript_for_processing(transcript)
|
validation = await validate_transcript_for_processing(transcript)
|
||||||
on_validation(validation)
|
on_validation(validation)
|
||||||
config = await prepare_transcript_processing(validation)
|
config = await prepare_transcript_processing(validation)
|
||||||
@@ -87,15 +85,13 @@ async def process_transcript(
|
|||||||
elif isinstance(config, FileProcessingConfig):
|
elif isinstance(config, FileProcessingConfig):
|
||||||
print(f"Dispatching file pipeline", file=sys.stderr)
|
print(f"Dispatching file pipeline", file=sys.stderr)
|
||||||
|
|
||||||
result = await process_transcript_inner(
|
await process_transcript_inner(
|
||||||
transcript,
|
transcript,
|
||||||
on_validation=on_validation,
|
on_validation=on_validation,
|
||||||
on_preprocess=on_preprocess,
|
on_preprocess=on_preprocess,
|
||||||
force=force,
|
force=force,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result is None:
|
|
||||||
# Hatchet workflow dispatched
|
|
||||||
if sync:
|
if sync:
|
||||||
# Re-fetch transcript to get workflow_run_id
|
# Re-fetch transcript to get workflow_run_id
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||||
@@ -123,21 +119,6 @@ async def process_transcript(
|
|||||||
"Task dispatched (use --sync to wait for completion)",
|
"Task dispatched (use --sync to wait for completion)",
|
||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
elif sync:
|
|
||||||
print("Waiting for task completion...", file=sys.stderr)
|
|
||||||
while not result.ready():
|
|
||||||
print(f" Status: {result.state}", file=sys.stderr)
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
if result.successful():
|
|
||||||
print("Task completed successfully", file=sys.stderr)
|
|
||||||
else:
|
|
||||||
print(f"Task failed: {result.result}", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
"Task dispatched (use --sync to wait for completion)", file=sys.stderr
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await database.disconnect()
|
await database.disconnect()
|
||||||
|
|||||||
@@ -52,8 +52,5 @@ async def transcript_process(
|
|||||||
if isinstance(config, ProcessError):
|
if isinstance(config, ProcessError):
|
||||||
raise HTTPException(status_code=500, detail=config.detail)
|
raise HTTPException(status_code=500, detail=config.detail)
|
||||||
else:
|
else:
|
||||||
# When transcript is in error state, force a new workflow instead of replaying
|
await dispatch_transcript_processing(config, force=True)
|
||||||
# (replay would re-run from failure point with same conditions and likely fail again)
|
|
||||||
force = transcript.status == "error"
|
|
||||||
await dispatch_transcript_processing(config, force=force)
|
|
||||||
return ProcessStatus(status="ok")
|
return ProcessStatus(status="ok")
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -95,7 +95,14 @@ async def transcript_record_upload(
|
|||||||
transcript, {"status": "uploaded", "source_kind": SourceKind.FILE}
|
transcript, {"status": "uploaded", "source_kind": SourceKind.FILE}
|
||||||
)
|
)
|
||||||
|
|
||||||
# launch a background task to process the file
|
# launch Hatchet workflow to process the file
|
||||||
task_pipeline_file_process.delay(transcript_id=transcript_id)
|
workflow_id = await HatchetClientManager.start_workflow(
|
||||||
|
"FilePipeline",
|
||||||
|
{"transcript_id": str(transcript_id)},
|
||||||
|
additional_metadata={"transcript_id": str(transcript_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save workflow_run_id for duplicate detection and status polling
|
||||||
|
await transcripts_controller.update(transcript, {"workflow_run_id": workflow_id})
|
||||||
|
|
||||||
return UploadStatus(status="ok")
|
return UploadStatus(status="ok")
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from reflector.db.transcripts import (
|
|||||||
transcripts_controller,
|
transcripts_controller,
|
||||||
)
|
)
|
||||||
from reflector.hatchet.client import HatchetClientManager
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
|
||||||
from reflector.pipelines.main_live_pipeline import asynctask
|
from reflector.pipelines.main_live_pipeline import asynctask
|
||||||
from reflector.pipelines.topic_processing import EmptyPipeline
|
from reflector.pipelines.topic_processing import EmptyPipeline
|
||||||
from reflector.processors import AudioFileWriterProcessor
|
from reflector.processors import AudioFileWriterProcessor
|
||||||
@@ -163,7 +162,14 @@ async def process_recording(bucket_name: str, object_key: str):
|
|||||||
|
|
||||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||||
|
|
||||||
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
await HatchetClientManager.start_workflow(
|
||||||
|
"FilePipeline",
|
||||||
|
{
|
||||||
|
"transcript_id": str(transcript.id),
|
||||||
|
"room_id": str(room.id) if room else None,
|
||||||
|
},
|
||||||
|
additional_metadata={"transcript_id": str(transcript.id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -538,18 +538,59 @@ def fake_mp3_upload():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def reset_hatchet_client():
|
def mock_hatchet_client():
|
||||||
"""Reset HatchetClientManager singleton before and after each test.
|
"""Mock HatchetClientManager for all tests.
|
||||||
|
|
||||||
This ensures test isolation - each test starts with a fresh client state.
|
Prevents tests from connecting to a real Hatchet server. The dummy token
|
||||||
The fixture is autouse=True so it applies to all tests automatically.
|
in [tool.pytest_env] prevents the import-time ValueError, but the SDK
|
||||||
|
would still try to connect when get_client() is called. This fixture
|
||||||
|
mocks get_client to return a MagicMock and start_workflow to return a
|
||||||
|
dummy workflow ID.
|
||||||
"""
|
"""
|
||||||
from reflector.hatchet.client import HatchetClientManager
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
|
|
||||||
# Reset before test
|
|
||||||
HatchetClientManager.reset()
|
HatchetClientManager.reset()
|
||||||
yield
|
|
||||||
# Reset after test to clean up
|
mock_client = MagicMock()
|
||||||
|
mock_client.workflow.return_value = MagicMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
HatchetClientManager,
|
||||||
|
"get_client",
|
||||||
|
return_value=mock_client,
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
HatchetClientManager,
|
||||||
|
"start_workflow",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value="mock-workflow-id",
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
HatchetClientManager,
|
||||||
|
"get_workflow_run_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
HatchetClientManager,
|
||||||
|
"can_replay",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=False,
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
HatchetClientManager,
|
||||||
|
"cancel_workflow",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
HatchetClientManager,
|
||||||
|
"replay_workflow",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
yield mock_client
|
||||||
|
|
||||||
HatchetClientManager.reset()
|
HatchetClientManager.reset()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -37,18 +37,3 @@ async def test_hatchet_client_can_replay_handles_exception():
|
|||||||
|
|
||||||
# Should return False on error (workflow might be gone)
|
# Should return False on error (workflow might be gone)
|
||||||
assert can_replay is False
|
assert can_replay is False
|
||||||
|
|
||||||
|
|
||||||
def test_hatchet_client_raises_without_token():
|
|
||||||
"""Test that get_client raises ValueError without token.
|
|
||||||
|
|
||||||
Useful: Catches if someone removes the token validation,
|
|
||||||
which would cause cryptic errors later.
|
|
||||||
"""
|
|
||||||
from reflector.hatchet.client import HatchetClientManager
|
|
||||||
|
|
||||||
with patch("reflector.hatchet.client.settings") as mock_settings:
|
|
||||||
mock_settings.HATCHET_CLIENT_TOKEN = None
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"):
|
|
||||||
HatchetClientManager.get_client()
|
|
||||||
|
|||||||
233
server/tests/test_hatchet_file_pipeline.py
Normal file
233
server/tests/test_hatchet_file_pipeline.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""
|
||||||
|
Tests for the FilePipeline Hatchet workflow.
|
||||||
|
|
||||||
|
Tests verify:
|
||||||
|
1. with_error_handling behavior for file pipeline input model
|
||||||
|
2. on_workflow_failure logic (don't overwrite 'ended' status)
|
||||||
|
3. Input model validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from hatchet_sdk import NonRetryableException
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _noop_db_context():
|
||||||
|
"""Async context manager that yields without touching the DB."""
|
||||||
|
yield None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def file_pipeline_module():
|
||||||
|
"""Import file_pipeline with Hatchet client mocked."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.workflow.return_value = MagicMock()
|
||||||
|
with patch(
|
||||||
|
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||||
|
return_value=mock_client,
|
||||||
|
):
|
||||||
|
from reflector.hatchet.workflows import file_pipeline
|
||||||
|
|
||||||
|
return file_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_file_input():
|
||||||
|
"""Minimal FilePipelineInput for tests."""
|
||||||
|
from reflector.hatchet.workflows.file_pipeline import FilePipelineInput
|
||||||
|
|
||||||
|
return FilePipelineInput(
|
||||||
|
transcript_id="ts-file-123",
|
||||||
|
room_id="room-456",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ctx():
|
||||||
|
"""Minimal Context-like object."""
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.log = MagicMock()
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_pipeline_input_model():
|
||||||
|
"""Test FilePipelineInput validation."""
|
||||||
|
from reflector.hatchet.workflows.file_pipeline import FilePipelineInput
|
||||||
|
|
||||||
|
# Valid input with room_id
|
||||||
|
input_with_room = FilePipelineInput(transcript_id="ts-123", room_id="room-456")
|
||||||
|
assert input_with_room.transcript_id == "ts-123"
|
||||||
|
assert input_with_room.room_id == "room-456"
|
||||||
|
|
||||||
|
# Valid input without room_id
|
||||||
|
input_no_room = FilePipelineInput(transcript_id="ts-123")
|
||||||
|
assert input_no_room.room_id is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_pipeline_error_handling_transient(
|
||||||
|
file_pipeline_module, mock_file_input, mock_ctx
|
||||||
|
):
|
||||||
|
"""Transient exception must NOT set error status."""
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
TaskName,
|
||||||
|
with_error_handling,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def failing_task(input, ctx):
|
||||||
|
raise httpx.TimeoutException("timed out")
|
||||||
|
|
||||||
|
wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_set_error:
|
||||||
|
with pytest.raises(httpx.TimeoutException):
|
||||||
|
await wrapped(mock_file_input, mock_ctx)
|
||||||
|
|
||||||
|
mock_set_error.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_pipeline_error_handling_hard_fail(
|
||||||
|
file_pipeline_module, mock_file_input, mock_ctx
|
||||||
|
):
|
||||||
|
"""Hard-fail (ValueError) must set error status and raise NonRetryableException."""
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
TaskName,
|
||||||
|
with_error_handling,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def failing_task(input, ctx):
|
||||||
|
raise ValueError("No audio file found")
|
||||||
|
|
||||||
|
wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_set_error:
|
||||||
|
with pytest.raises(NonRetryableException) as exc_info:
|
||||||
|
await wrapped(mock_file_input, mock_ctx)
|
||||||
|
|
||||||
|
assert "No audio file found" in str(exc_info.value)
|
||||||
|
mock_set_error.assert_called_once_with("ts-file-123")
|
||||||
|
|
||||||
|
|
||||||
|
def test_diarize_result_uses_plain_dicts():
|
||||||
|
"""DiarizationSegment is a TypedDict (plain dict), not a Pydantic model.
|
||||||
|
|
||||||
|
The diarize task must serialize segments as plain dicts (not call .model_dump()),
|
||||||
|
and assemble_transcript must be able to reconstruct them with DiarizationSegment(**s).
|
||||||
|
This was a real bug: 'dict' object has no attribute 'model_dump'.
|
||||||
|
"""
|
||||||
|
from reflector.hatchet.workflows.file_pipeline import DiarizeResult
|
||||||
|
from reflector.processors.types import DiarizationSegment
|
||||||
|
|
||||||
|
# DiarizationSegment is a TypedDict — instances are plain dicts
|
||||||
|
segments = [
|
||||||
|
DiarizationSegment(start=0.0, end=1.5, speaker=0),
|
||||||
|
DiarizationSegment(start=1.5, end=3.0, speaker=1),
|
||||||
|
]
|
||||||
|
assert isinstance(segments[0], dict), "DiarizationSegment should be a plain dict"
|
||||||
|
|
||||||
|
# DiarizeResult should accept list[dict] directly (no model_dump needed)
|
||||||
|
result = DiarizeResult(diarization=segments)
|
||||||
|
assert result.diarization is not None
|
||||||
|
assert len(result.diarization) == 2
|
||||||
|
|
||||||
|
# Consumer (assemble_transcript) reconstructs via DiarizationSegment(**s)
|
||||||
|
reconstructed = [DiarizationSegment(**s) for s in result.diarization]
|
||||||
|
assert reconstructed[0]["start"] == 0.0
|
||||||
|
assert reconstructed[0]["speaker"] == 0
|
||||||
|
assert reconstructed[1]["end"] == 3.0
|
||||||
|
assert reconstructed[1]["speaker"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_diarize_result_handles_none():
|
||||||
|
"""DiarizeResult with no diarization data (diarization disabled)."""
|
||||||
|
from reflector.hatchet.workflows.file_pipeline import DiarizeResult
|
||||||
|
|
||||||
|
result = DiarizeResult(diarization=None)
|
||||||
|
assert result.diarization is None
|
||||||
|
|
||||||
|
result_default = DiarizeResult()
|
||||||
|
assert result_default.diarization is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcribe_result_words_are_pydantic():
|
||||||
|
"""TranscribeResult words come from Pydantic Word.model_dump() — verify roundtrip."""
|
||||||
|
from reflector.hatchet.workflows.file_pipeline import TranscribeResult
|
||||||
|
from reflector.processors.types import Word
|
||||||
|
|
||||||
|
words = [
|
||||||
|
Word(text="hello", start=0.0, end=0.5),
|
||||||
|
Word(text="world", start=0.5, end=1.0),
|
||||||
|
]
|
||||||
|
# Words are Pydantic models, so model_dump() works
|
||||||
|
word_dicts = [w.model_dump() for w in words]
|
||||||
|
result = TranscribeResult(words=word_dicts)
|
||||||
|
|
||||||
|
# Consumer reconstructs via Word(**w)
|
||||||
|
reconstructed = [Word(**w) for w in result.words]
|
||||||
|
assert reconstructed[0].text == "hello"
|
||||||
|
assert reconstructed[1].start == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_file_pipeline_on_failure_sets_error_status(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """on_workflow_failure sets error status when transcript is still processing."""
    from reflector.hatchet.workflows.file_pipeline import on_workflow_failure

    in_flight_transcript = MagicMock()
    in_flight_transcript.status = "processing"

    # Stub out the DB connection, the transcript lookup, and the status setter
    # in one flat context so only the failure hook's decision logic runs.
    with (
        patch(
            "reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
            _noop_db_context,
        ),
        patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=in_flight_transcript,
        ),
        patch(
            "reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
            new_callable=AsyncMock,
        ) as error_status_mock,
    ):
        await on_workflow_failure(mock_file_input, mock_ctx)

    error_status_mock.assert_called_once_with(mock_file_input.transcript_id)
||||||
|
@pytest.mark.asyncio
async def test_file_pipeline_on_failure_does_not_overwrite_ended(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """on_workflow_failure must NOT overwrite a transcript already 'ended'."""
    from reflector.hatchet.workflows.file_pipeline import on_workflow_failure

    finished_transcript = MagicMock()
    finished_transcript.status = "ended"

    # Same stubbing as the happy-path failure test, but the transcript has
    # already completed — the hook must leave its status untouched.
    with (
        patch(
            "reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
            _noop_db_context,
        ),
        patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=finished_transcript,
        ),
        patch(
            "reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
            new_callable=AsyncMock,
        ) as error_status_mock,
    ):
        await on_workflow_failure(mock_file_input, mock_ctx)

    error_status_mock.assert_not_called()
218
server/tests/test_hatchet_live_post_pipeline.py
Normal file
218
server/tests/test_hatchet_live_post_pipeline.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
"""
|
||||||
|
Tests for the LivePostProcessingPipeline Hatchet workflow.
|
||||||
|
|
||||||
|
Tests verify:
|
||||||
|
1. with_error_handling behavior for live post pipeline input model
|
||||||
|
2. on_workflow_failure logic (don't overwrite 'ended' status)
|
||||||
|
3. Input model validation
|
||||||
|
4. pipeline_post() now triggers Hatchet instead of Celery chord
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from hatchet_sdk import NonRetryableException
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def _noop_db_context():
    """No-op stand-in for a DB connection context: yields without touching the DB."""
    # Tests only need the async context-manager protocol, not a real connection.
    yield
||||||
|
@pytest.fixture(scope="module")
def live_pipeline_module():
    """Import live_post_pipeline once per module, with the Hatchet client mocked.

    The workflow module registers itself with the Hatchet client at import
    time, so the client must be stubbed before the import happens.
    """
    fake_client = MagicMock()
    fake_client.workflow.return_value = MagicMock()

    with patch(
        "reflector.hatchet.client.HatchetClientManager.get_client",
        return_value=fake_client,
    ):
        from reflector.hatchet.workflows import live_post_pipeline

    return live_post_pipeline
|
||||||
|
@pytest.fixture
def mock_live_input():
    """Minimal LivePostPipelineInput used across the tests below."""
    from reflector.hatchet.workflows.live_post_pipeline import LivePostPipelineInput

    return LivePostPipelineInput(transcript_id="ts-live-789", room_id="room-abc")
||||||
|
@pytest.fixture
def mock_ctx():
    """Context stand-in exposing only the `.log` attribute tasks use."""
    fake_ctx = MagicMock()
    fake_ctx.log = MagicMock()
    return fake_ctx
||||||
|
def test_live_post_pipeline_input_model():
    """LivePostPipelineInput requires transcript_id; room_id is optional."""
    from reflector.hatchet.workflows.live_post_pipeline import LivePostPipelineInput

    # Fully populated input.
    populated = LivePostPipelineInput(transcript_id="ts-123", room_id="room-456")
    assert populated.transcript_id == "ts-123"
    assert populated.room_id == "room-456"

    # room_id omitted -> defaults to None.
    bare = LivePostPipelineInput(transcript_id="ts-123")
    assert bare.room_id is None
||||||
|
@pytest.mark.asyncio
async def test_live_pipeline_error_handling_transient(
    live_pipeline_module, mock_live_input, mock_ctx
):
    """A transient exception propagates for retry and must NOT set error status."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    async def raise_timeout(input, ctx):
        # httpx timeouts are retryable, so the wrapper must re-raise as-is.
        raise httpx.TimeoutException("timed out")

    guarded_task = with_error_handling(TaskName.WAVEFORM)(raise_timeout)

    with (
        patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
            new_callable=AsyncMock,
        ) as error_status_mock,
        pytest.raises(httpx.TimeoutException),
    ):
        await guarded_task(mock_live_input, mock_ctx)

    error_status_mock.assert_not_called()
||||||
|
@pytest.mark.asyncio
async def test_live_pipeline_error_handling_hard_fail(
    live_pipeline_module, mock_live_input, mock_ctx
):
    """A hard failure sets error status and surfaces as NonRetryableException."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    async def raise_fatal(input, ctx):
        # ValueError is treated as non-recoverable by the wrapper.
        raise ValueError("Transcript not found")

    guarded_task = with_error_handling(TaskName.WAVEFORM)(raise_fatal)

    with (
        patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
            new_callable=AsyncMock,
        ) as error_status_mock,
        pytest.raises(NonRetryableException) as excinfo,
    ):
        await guarded_task(mock_live_input, mock_ctx)

    assert "Transcript not found" in str(excinfo.value)
    error_status_mock.assert_called_once_with("ts-live-789")
||||||
|
@pytest.mark.asyncio
async def test_live_pipeline_on_failure_sets_error_status(
    live_pipeline_module, mock_live_input, mock_ctx
):
    """on_workflow_failure sets error status when transcript is still processing."""
    from reflector.hatchet.workflows.live_post_pipeline import on_workflow_failure

    in_flight_transcript = MagicMock()
    in_flight_transcript.status = "processing"

    # Flat multi-context patch: no real DB, canned transcript, spied setter.
    with (
        patch(
            "reflector.hatchet.workflows.live_post_pipeline.fresh_db_connection",
            _noop_db_context,
        ),
        patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=in_flight_transcript,
        ),
        patch(
            "reflector.hatchet.workflows.live_post_pipeline.set_workflow_error_status",
            new_callable=AsyncMock,
        ) as error_status_mock,
    ):
        await on_workflow_failure(mock_live_input, mock_ctx)

    error_status_mock.assert_called_once_with(mock_live_input.transcript_id)
||||||
|
@pytest.mark.asyncio
async def test_live_pipeline_on_failure_does_not_overwrite_ended(
    live_pipeline_module, mock_live_input, mock_ctx
):
    """on_workflow_failure must NOT overwrite a transcript already 'ended'."""
    from reflector.hatchet.workflows.live_post_pipeline import on_workflow_failure

    finished_transcript = MagicMock()
    finished_transcript.status = "ended"

    # Transcript already completed — the failure hook must be a no-op.
    with (
        patch(
            "reflector.hatchet.workflows.live_post_pipeline.fresh_db_connection",
            _noop_db_context,
        ),
        patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=finished_transcript,
        ),
        patch(
            "reflector.hatchet.workflows.live_post_pipeline.set_workflow_error_status",
            new_callable=AsyncMock,
        ) as error_status_mock,
    ):
        await on_workflow_failure(mock_live_input, mock_ctx)

    error_status_mock.assert_not_called()
||||||
|
@pytest.mark.asyncio
async def test_pipeline_post_triggers_hatchet():
    """pipeline_post() should trigger the Hatchet LivePostProcessingPipeline workflow."""
    expected_payload = {
        "transcript_id": "ts-test-123",
        "room_id": "room-test",
    }

    with patch(
        "reflector.hatchet.client.HatchetClientManager.start_workflow",
        new_callable=AsyncMock,
        return_value="workflow-run-id",
    ) as start_workflow_mock:
        from reflector.pipelines.main_live_pipeline import pipeline_post

        await pipeline_post(transcript_id="ts-test-123", room_id="room-test")

        start_workflow_mock.assert_called_once_with(
            "LivePostProcessingPipeline",
            expected_payload,
            additional_metadata={"transcript_id": "ts-test-123"},
        )
||||||
|
@pytest.mark.asyncio
async def test_pipeline_post_triggers_hatchet_without_room_id():
    """pipeline_post() should pass room_id=None through when no room is given."""
    expected_payload = {
        "transcript_id": "ts-test-456",
        "room_id": None,
    }

    with patch(
        "reflector.hatchet.client.HatchetClientManager.start_workflow",
        new_callable=AsyncMock,
        return_value="workflow-run-id",
    ) as start_workflow_mock:
        from reflector.pipelines.main_live_pipeline import pipeline_post

        await pipeline_post(transcript_id="ts-test-456")

        start_workflow_mock.assert_called_once_with(
            "LivePostProcessingPipeline",
            expected_payload,
            additional_metadata={"transcript_id": "ts-test-456"},
        )
90
server/tests/test_hatchet_trigger_migration.py
Normal file
90
server/tests/test_hatchet_trigger_migration.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""
|
||||||
|
Tests verifying Celery-to-Hatchet trigger migration.
|
||||||
|
|
||||||
|
Ensures that:
|
||||||
|
1. process_recording triggers FilePipeline via Hatchet (not Celery)
|
||||||
|
2. transcript_record_upload triggers FilePipeline via Hatchet (not Celery)
|
||||||
|
3. Old Celery task references are no longer in active call sites
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_recording_does_not_import_celery_file_task():
    """process.py must no longer reference the old Celery task_pipeline_file_process."""
    import inspect

    from reflector.worker import process

    # Source-level check: the Celery task name should be gone entirely.
    assert "task_pipeline_file_process" not in inspect.getsource(process)
||||||
|
def test_transcripts_upload_does_not_import_celery_file_task():
    """transcripts_upload.py must no longer reference task_pipeline_file_process."""
    import inspect

    from reflector.views import transcripts_upload

    # Source-level check: the Celery task name should be gone entirely.
    assert "task_pipeline_file_process" not in inspect.getsource(transcripts_upload)
||||||
|
def test_transcripts_upload_imports_hatchet():
    """transcripts_upload.py should dispatch via HatchetClientManager."""
    import inspect

    from reflector.views import transcripts_upload

    assert "HatchetClientManager" in inspect.getsource(transcripts_upload)
||||||
|
def test_pipeline_post_is_async():
    """Verify pipeline_post is now a coroutine function (Hatchet trigger)."""
    import inspect

    from reflector.pipelines.main_live_pipeline import pipeline_post

    # Use inspect.iscoroutinefunction: asyncio.iscoroutinefunction is
    # deprecated since Python 3.14, and `import inspect` matches the
    # convention used by the other tests in this module.
    assert inspect.iscoroutinefunction(pipeline_post)
||||||
|
def test_transcript_process_service_does_not_import_celery_file_task():
    """transcript_process.py service must no longer reference task_pipeline_file_process."""
    import inspect

    from reflector.services import transcript_process

    assert "task_pipeline_file_process" not in inspect.getsource(transcript_process)
||||||
|
def test_transcript_process_service_dispatch_uses_hatchet():
    """dispatch_transcript_processing must route file processing through Hatchet."""
    import inspect

    from reflector.services import transcript_process

    dispatch_source = inspect.getsource(
        transcript_process.dispatch_transcript_processing
    )
    # Both the client manager and the target workflow name must appear.
    assert "HatchetClientManager" in dispatch_source
    assert "FilePipeline" in dispatch_source
||||||
|
def test_new_task_names_exist():
    """New TaskName constants exist for the file and live post-processing pipelines."""
    from reflector.hatchet.constants import TaskName

    # File pipeline tasks.
    file_pipeline_names = {
        "EXTRACT_AUDIO": "extract_audio",
        "UPLOAD_AUDIO": "upload_audio",
        "TRANSCRIBE": "transcribe",
        "DIARIZE": "diarize",
        "ASSEMBLE_TRANSCRIPT": "assemble_transcript",
        "GENERATE_SUMMARIES": "generate_summaries",
    }
    # Live post-processing pipeline tasks.
    live_pipeline_names = {
        "WAVEFORM": "waveform",
        "CONVERT_MP3": "convert_mp3",
        "UPLOAD_MP3": "upload_mp3",
        "REMOVE_UPLOAD": "remove_upload",
        "FINAL_SUMMARIES": "final_summaries",
    }

    for attr, expected in {**file_pipeline_names, **live_pipeline_names}.items():
        assert getattr(TaskName, attr) == expected
||||||
@@ -1,5 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -27,8 +25,6 @@ async def client(app_lifespan):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_database")
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@pytest.mark.usefixtures("celery_session_app")
|
|
||||||
@pytest.mark.usefixtures("celery_session_worker")
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_process(
|
async def test_transcript_process(
|
||||||
tmpdir,
|
tmpdir,
|
||||||
@@ -39,8 +35,13 @@ async def test_transcript_process(
|
|||||||
dummy_storage,
|
dummy_storage,
|
||||||
client,
|
client,
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
|
mock_hatchet_client,
|
||||||
):
|
):
|
||||||
# public mode: this test uses an anonymous client; allow anonymous transcript creation
|
"""Test upload + process dispatch via Hatchet.
|
||||||
|
|
||||||
|
The file pipeline is now dispatched to Hatchet (fire-and-forget),
|
||||||
|
so we verify the workflow was triggered rather than polling for completion.
|
||||||
|
"""
|
||||||
monkeypatch.setattr(settings, "PUBLIC_MODE", True)
|
monkeypatch.setattr(settings, "PUBLIC_MODE", True)
|
||||||
|
|
||||||
# create a transcript
|
# create a transcript
|
||||||
@@ -63,51 +64,43 @@ async def test_transcript_process(
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["status"] == "ok"
|
assert response.json()["status"] == "ok"
|
||||||
|
|
||||||
# wait for processing to finish (max 1 minute)
|
# Verify Hatchet workflow was dispatched (from upload endpoint)
|
||||||
timeout_seconds = 60
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
start_time = time.monotonic()
|
|
||||||
while (time.monotonic() - start_time) < timeout_seconds:
|
HatchetClientManager.start_workflow.assert_called_once_with(
|
||||||
# fetch the transcript and check if it is ended
|
"FilePipeline",
|
||||||
|
{"transcript_id": tid},
|
||||||
|
additional_metadata={"transcript_id": tid},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify transcript status was set to "uploaded"
|
||||||
resp = await client.get(f"/transcripts/{tid}")
|
resp = await client.get(f"/transcripts/{tid}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
if resp.json()["status"] in ("ended", "error"):
|
assert resp.json()["status"] == "uploaded"
|
||||||
break
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
else:
|
|
||||||
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
|
|
||||||
|
|
||||||
# restart the processing
|
# Reset mock for reprocess test
|
||||||
response = await client.post(
|
HatchetClientManager.start_workflow.reset_mock()
|
||||||
f"/transcripts/{tid}/process",
|
|
||||||
)
|
# Clear workflow_run_id so /process endpoint can dispatch again
|
||||||
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(tid)
|
||||||
|
await transcripts_controller.update(transcript, {"workflow_run_id": None})
|
||||||
|
|
||||||
|
# Reprocess via /process endpoint
|
||||||
|
with patch(
|
||||||
|
"reflector.services.transcript_process.task_is_scheduled_or_active",
|
||||||
|
return_value=False,
|
||||||
|
):
|
||||||
|
response = await client.post(f"/transcripts/{tid}/process")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["status"] == "ok"
|
assert response.json()["status"] == "ok"
|
||||||
await asyncio.sleep(2)
|
|
||||||
|
|
||||||
# wait for processing to finish (max 1 minute)
|
# Verify second Hatchet dispatch (from /process endpoint)
|
||||||
timeout_seconds = 60
|
HatchetClientManager.start_workflow.assert_called_once()
|
||||||
start_time = time.monotonic()
|
call_kwargs = HatchetClientManager.start_workflow.call_args.kwargs
|
||||||
while (time.monotonic() - start_time) < timeout_seconds:
|
assert call_kwargs["workflow_name"] == "FilePipeline"
|
||||||
# fetch the transcript and check if it is ended
|
assert call_kwargs["input_data"]["transcript_id"] == tid
|
||||||
resp = await client.get(f"/transcripts/{tid}")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
if resp.json()["status"] in ("ended", "error"):
|
|
||||||
break
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
else:
|
|
||||||
pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")
|
|
||||||
|
|
||||||
# check the transcript is ended
|
|
||||||
transcript = resp.json()
|
|
||||||
assert transcript["status"] == "ended"
|
|
||||||
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
|
|
||||||
assert transcript["title"] == "Llm Title"
|
|
||||||
|
|
||||||
# check topics and transcript
|
|
||||||
response = await client.get(f"/transcripts/{tid}/topics")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert len(response.json()) == 1
|
|
||||||
assert "Hello world. How are you today?" in response.json()[0]["transcript"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_database")
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@@ -150,20 +143,25 @@ async def test_whereby_recording_uses_file_pipeline(monkeypatch, client):
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"reflector.services.transcript_process.task_pipeline_file_process"
|
"reflector.services.transcript_process.task_is_scheduled_or_active",
|
||||||
) as mock_file_pipeline,
|
return_value=False,
|
||||||
|
),
|
||||||
patch(
|
patch(
|
||||||
"reflector.services.transcript_process.HatchetClientManager"
|
"reflector.services.transcript_process.HatchetClientManager"
|
||||||
) as mock_hatchet,
|
) as mock_hatchet,
|
||||||
):
|
):
|
||||||
|
mock_hatchet.start_workflow = AsyncMock(return_value="test-workflow-id")
|
||||||
|
|
||||||
response = await client.post(f"/transcripts/{transcript.id}/process")
|
response = await client.post(f"/transcripts/{transcript.id}/process")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["status"] == "ok"
|
assert response.json()["status"] == "ok"
|
||||||
|
|
||||||
# Whereby recordings should use file pipeline, not Hatchet
|
# Whereby recordings should use Hatchet FilePipeline
|
||||||
mock_file_pipeline.delay.assert_called_once_with(transcript_id=transcript.id)
|
mock_hatchet.start_workflow.assert_called_once()
|
||||||
mock_hatchet.start_workflow.assert_not_called()
|
call_kwargs = mock_hatchet.start_workflow.call_args.kwargs
|
||||||
|
assert call_kwargs["workflow_name"] == "FilePipeline"
|
||||||
|
assert call_kwargs["input_data"]["transcript_id"] == transcript.id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_database")
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@@ -224,8 +222,9 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"reflector.services.transcript_process.task_pipeline_file_process"
|
"reflector.services.transcript_process.task_is_scheduled_or_active",
|
||||||
) as mock_file_pipeline,
|
return_value=False,
|
||||||
|
),
|
||||||
patch(
|
patch(
|
||||||
"reflector.services.transcript_process.HatchetClientManager"
|
"reflector.services.transcript_process.HatchetClientManager"
|
||||||
) as mock_hatchet,
|
) as mock_hatchet,
|
||||||
@@ -237,7 +236,7 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["status"] == "ok"
|
assert response.json()["status"] == "ok"
|
||||||
|
|
||||||
# Daily.co multitrack recordings should use Hatchet workflow
|
# Daily.co multitrack recordings should use Hatchet DiarizationPipeline
|
||||||
mock_hatchet.start_workflow.assert_called_once()
|
mock_hatchet.start_workflow.assert_called_once()
|
||||||
call_kwargs = mock_hatchet.start_workflow.call_args.kwargs
|
call_kwargs = mock_hatchet.start_workflow.call_args.kwargs
|
||||||
assert call_kwargs["workflow_name"] == "DiarizationPipeline"
|
assert call_kwargs["workflow_name"] == "DiarizationPipeline"
|
||||||
@@ -246,7 +245,6 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
|
|||||||
assert call_kwargs["input_data"]["tracks"] == [
|
assert call_kwargs["input_data"]["tracks"] == [
|
||||||
{"s3_key": k} for k in track_keys
|
{"s3_key": k} for k in track_keys
|
||||||
]
|
]
|
||||||
mock_file_pipeline.delay.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_database")
|
@pytest.mark.usefixtures("setup_database")
|
||||||
|
|||||||
@@ -2,6 +2,10 @@
|
|||||||
# FIXME test status of transcript
|
# FIXME test status of transcript
|
||||||
# FIXME test websocket connection after RTC is finished still send the full events
|
# FIXME test websocket connection after RTC is finished still send the full events
|
||||||
# FIXME try with locked session, RTC should not work
|
# FIXME try with locked session, RTC should not work
|
||||||
|
# TODO: add integration tests for post-processing (LivePostPipeline) with a real
|
||||||
|
# Hatchet instance. These tests currently only cover the live pipeline.
|
||||||
|
# Post-processing events (WAVEFORM, FINAL_*, DURATION, STATUS=ended, mp3)
|
||||||
|
# are now dispatched via Hatchet and tested in test_hatchet_live_post_pipeline.py.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
@@ -49,7 +53,7 @@ class ThreadedUvicorn:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
|
def appserver(tmpdir, setup_database):
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from reflector.app import app
|
from reflector.app import app
|
||||||
@@ -119,8 +123,6 @@ def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker)
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_database")
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@pytest.mark.usefixtures("celery_session_app")
|
|
||||||
@pytest.mark.usefixtures("celery_session_worker")
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_rtc_and_websocket(
|
async def test_transcript_rtc_and_websocket(
|
||||||
tmpdir,
|
tmpdir,
|
||||||
@@ -134,6 +136,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
appserver,
|
appserver,
|
||||||
client,
|
client,
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
|
mock_hatchet_client,
|
||||||
):
|
):
|
||||||
# goal: start the server, exchange RTC, receive websocket events
|
# goal: start the server, exchange RTC, receive websocket events
|
||||||
# because of that, we need to start the server in a thread
|
# because of that, we need to start the server in a thread
|
||||||
@@ -208,35 +211,30 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
stream_client.channel.send(json.dumps({"cmd": "STOP"}))
|
stream_client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||||
await stream_client.stop()
|
await stream_client.stop()
|
||||||
|
|
||||||
# wait the processing to finish
|
# Wait for live pipeline to flush (it dispatches post-processing to Hatchet)
|
||||||
timeout = 120
|
timeout = 30
|
||||||
while True:
|
while True:
|
||||||
# fetch the transcript and check if it is ended
|
|
||||||
resp = await client.get(f"/transcripts/{tid}")
|
resp = await client.get(f"/transcripts/{tid}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
if resp.json()["status"] in ("ended", "error"):
|
if resp.json()["status"] in ("processing", "ended", "error"):
|
||||||
break
|
break
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
timeout -= 1
|
timeout -= 1
|
||||||
if timeout < 0:
|
if timeout < 0:
|
||||||
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
raise TimeoutError("Timeout waiting for live pipeline to finish")
|
||||||
|
|
||||||
if resp.json()["status"] != "ended":
|
|
||||||
raise TimeoutError("Transcript processing failed")
|
|
||||||
|
|
||||||
# stop websocket task
|
# stop websocket task
|
||||||
websocket_task.cancel()
|
websocket_task.cancel()
|
||||||
|
|
||||||
# check events
|
# check live pipeline events
|
||||||
assert len(events) > 0
|
assert len(events) > 0
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
pprint(events)
|
pprint(events)
|
||||||
|
|
||||||
# get events list
|
|
||||||
eventnames = [e["event"] for e in events]
|
eventnames = [e["event"] for e in events]
|
||||||
|
|
||||||
# check events
|
# Live pipeline produces TRANSCRIPT and TOPIC events during RTC
|
||||||
assert "TRANSCRIPT" in eventnames
|
assert "TRANSCRIPT" in eventnames
|
||||||
ev = events[eventnames.index("TRANSCRIPT")]
|
ev = events[eventnames.index("TRANSCRIPT")]
|
||||||
assert ev["data"]["text"].startswith("Hello world.")
|
assert ev["data"]["text"].startswith("Hello world.")
|
||||||
@@ -249,50 +247,18 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
assert ev["data"]["transcript"].startswith("Hello world.")
|
assert ev["data"]["transcript"].startswith("Hello world.")
|
||||||
assert ev["data"]["timestamp"] == 0.0
|
assert ev["data"]["timestamp"] == 0.0
|
||||||
|
|
||||||
assert "FINAL_LONG_SUMMARY" in eventnames
|
# Live pipeline status progression
|
||||||
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
|
|
||||||
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
|
|
||||||
|
|
||||||
assert "FINAL_SHORT_SUMMARY" in eventnames
|
|
||||||
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
|
|
||||||
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
|
|
||||||
|
|
||||||
assert "FINAL_TITLE" in eventnames
|
|
||||||
ev = events[eventnames.index("FINAL_TITLE")]
|
|
||||||
assert ev["data"]["title"] == "Llm Title"
|
|
||||||
|
|
||||||
assert "WAVEFORM" in eventnames
|
|
||||||
ev = events[eventnames.index("WAVEFORM")]
|
|
||||||
assert isinstance(ev["data"]["waveform"], list)
|
|
||||||
assert len(ev["data"]["waveform"]) >= 250
|
|
||||||
waveform_resp = await client.get(f"/transcripts/{tid}/audio/waveform")
|
|
||||||
assert waveform_resp.status_code == 200
|
|
||||||
assert waveform_resp.headers["content-type"] == "application/json"
|
|
||||||
assert isinstance(waveform_resp.json()["data"], list)
|
|
||||||
assert len(waveform_resp.json()["data"]) >= 250
|
|
||||||
|
|
||||||
# check status order
|
|
||||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||||
|
assert "recording" in statuses
|
||||||
|
assert "processing" in statuses
|
||||||
assert statuses.index("recording") < statuses.index("processing")
|
assert statuses.index("recording") < statuses.index("processing")
|
||||||
assert statuses.index("processing") < statuses.index("ended")
|
|
||||||
|
|
||||||
# ensure the last event received is ended
|
# Post-processing (WAVEFORM, FINAL_*, DURATION, mp3, STATUS=ended) is now
|
||||||
assert events[-1]["event"] == "STATUS"
|
# dispatched to Hatchet via LivePostPipeline — not tested here.
|
||||||
assert events[-1]["data"]["value"] == "ended"
|
# See test_hatchet_live_post_pipeline.py for post-processing tests.
|
||||||
|
|
||||||
# check on the latest response that the audio duration is > 0
|
|
||||||
assert resp.json()["duration"] > 0
|
|
||||||
assert "DURATION" in eventnames
|
|
||||||
|
|
||||||
# check that audio/mp3 is available
|
|
||||||
audio_resp = await client.get(f"/transcripts/{tid}/audio/mp3")
|
|
||||||
assert audio_resp.status_code == 200
|
|
||||||
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_database")
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@pytest.mark.usefixtures("celery_session_app")
|
|
||||||
@pytest.mark.usefixtures("celery_session_worker")
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_rtc_and_websocket_and_fr(
|
async def test_transcript_rtc_and_websocket_and_fr(
|
||||||
tmpdir,
|
tmpdir,
|
||||||
@@ -306,6 +272,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
appserver,
|
appserver,
|
||||||
client,
|
client,
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
|
mock_hatchet_client,
|
||||||
):
|
):
|
||||||
# goal: start the server, exchange RTC, receive websocket events
|
# goal: start the server, exchange RTC, receive websocket events
|
||||||
# because of that, we need to start the server in a thread
|
# because of that, we need to start the server in a thread
|
||||||
@@ -382,42 +349,34 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
# instead of waiting a long time, we just send a STOP
|
# instead of waiting a long time, we just send a STOP
|
||||||
stream_client.channel.send(json.dumps({"cmd": "STOP"}))
|
stream_client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||||
|
|
||||||
# wait the processing to finish
|
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
await stream_client.stop()
|
await stream_client.stop()
|
||||||
|
|
||||||
# wait the processing to finish
|
# Wait for live pipeline to flush
|
||||||
timeout = 120
|
timeout = 30
|
||||||
while True:
|
while True:
|
||||||
# fetch the transcript and check if it is ended
|
|
||||||
resp = await client.get(f"/transcripts/{tid}")
|
resp = await client.get(f"/transcripts/{tid}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
if resp.json()["status"] == "ended":
|
if resp.json()["status"] in ("processing", "ended", "error"):
|
||||||
break
|
break
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
timeout -= 1
|
timeout -= 1
|
||||||
if timeout < 0:
|
if timeout < 0:
|
||||||
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
raise TimeoutError("Timeout waiting for live pipeline to finish")
|
||||||
|
|
||||||
if resp.json()["status"] != "ended":
|
|
||||||
raise TimeoutError("Transcript processing failed")
|
|
||||||
|
|
||||||
await asyncio.sleep(2)
|
|
||||||
|
|
||||||
# stop websocket task
|
# stop websocket task
|
||||||
websocket_task.cancel()
|
websocket_task.cancel()
|
||||||
|
|
||||||
# check events
|
# check live pipeline events
|
||||||
assert len(events) > 0
|
assert len(events) > 0
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
pprint(events)
|
pprint(events)
|
||||||
|
|
||||||
# get events list
|
|
||||||
eventnames = [e["event"] for e in events]
|
eventnames = [e["event"] for e in events]
|
||||||
|
|
||||||
# check events
|
# Live pipeline produces TRANSCRIPT with translation
|
||||||
assert "TRANSCRIPT" in eventnames
|
assert "TRANSCRIPT" in eventnames
|
||||||
ev = events[eventnames.index("TRANSCRIPT")]
|
ev = events[eventnames.index("TRANSCRIPT")]
|
||||||
assert ev["data"]["text"].startswith("Hello world.")
|
assert ev["data"]["text"].startswith("Hello world.")
|
||||||
@@ -430,23 +389,11 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
assert ev["data"]["transcript"].startswith("Hello world.")
|
assert ev["data"]["transcript"].startswith("Hello world.")
|
||||||
assert ev["data"]["timestamp"] == 0.0
|
assert ev["data"]["timestamp"] == 0.0
|
||||||
|
|
||||||
assert "FINAL_LONG_SUMMARY" in eventnames
|
# Live pipeline status progression
|
||||||
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
|
|
||||||
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
|
|
||||||
|
|
||||||
assert "FINAL_SHORT_SUMMARY" in eventnames
|
|
||||||
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
|
|
||||||
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
|
|
||||||
|
|
||||||
assert "FINAL_TITLE" in eventnames
|
|
||||||
ev = events[eventnames.index("FINAL_TITLE")]
|
|
||||||
assert ev["data"]["title"] == "Llm Title"
|
|
||||||
|
|
||||||
# check status order
|
|
||||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||||
|
assert "recording" in statuses
|
||||||
|
assert "processing" in statuses
|
||||||
assert statuses.index("recording") < statuses.index("processing")
|
assert statuses.index("recording") < statuses.index("processing")
|
||||||
assert statuses.index("processing") < statuses.index("ended")
|
|
||||||
|
|
||||||
# ensure the last event received is ended
|
# Post-processing (FINAL_*, STATUS=ended) is now dispatched to Hatchet
|
||||||
assert events[-1]["event"] == "STATUS"
|
# via LivePostPipeline — not tested here.
|
||||||
assert events[-1]["data"]["value"] == "ended"
|
|
||||||
|
|||||||
@@ -1,12 +1,7 @@
|
|||||||
import asyncio
|
|
||||||
import time
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_database")
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@pytest.mark.usefixtures("celery_session_app")
|
|
||||||
@pytest.mark.usefixtures("celery_session_worker")
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_upload_file(
|
async def test_transcript_upload_file(
|
||||||
tmpdir,
|
tmpdir,
|
||||||
@@ -17,6 +12,7 @@ async def test_transcript_upload_file(
|
|||||||
dummy_storage,
|
dummy_storage,
|
||||||
client,
|
client,
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
|
mock_hatchet_client,
|
||||||
):
|
):
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
@@ -43,27 +39,16 @@ async def test_transcript_upload_file(
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["status"] == "ok"
|
assert response.json()["status"] == "ok"
|
||||||
|
|
||||||
# wait the processing to finish (max 1 minute)
|
# Verify Hatchet workflow was dispatched for file processing
|
||||||
timeout_seconds = 60
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
start_time = time.monotonic()
|
|
||||||
while (time.monotonic() - start_time) < timeout_seconds:
|
HatchetClientManager.start_workflow.assert_called_once_with(
|
||||||
# fetch the transcript and check if it is ended
|
"FilePipeline",
|
||||||
|
{"transcript_id": tid},
|
||||||
|
additional_metadata={"transcript_id": tid},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify transcript status was updated to "uploaded"
|
||||||
resp = await client.get(f"/transcripts/{tid}")
|
resp = await client.get(f"/transcripts/{tid}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
if resp.json()["status"] in ("ended", "error"):
|
assert resp.json()["status"] == "uploaded"
|
||||||
break
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
else:
|
|
||||||
return pytest.fail(f"Processing timed out after {timeout_seconds} seconds")
|
|
||||||
|
|
||||||
# check the transcript is ended
|
|
||||||
transcript = resp.json()
|
|
||||||
assert transcript["status"] == "ended"
|
|
||||||
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
|
|
||||||
assert transcript["title"] == "Llm Title"
|
|
||||||
|
|
||||||
# check topics and transcript
|
|
||||||
response = await client.get(f"/transcripts/{tid}/topics")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert len(response.json()) == 1
|
|
||||||
assert "Hello world. How are you today?" in response.json()[0]["transcript"]
|
|
||||||
|
|||||||
Reference in New Issue
Block a user