mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-21 22:56:47 +00:00
fix: improve hatchet workflow reliability (#900)
* Increase max connections * Classify hard and transient hatchet errors * Fan out partial success * Force reprocessing of error transcripts * Stop retrying on 402 payment required * Avoid httpx/hatchet timeout race * Add retry wrapper to get_response for transient errors * Add retry backoff * Return falsy results so get_response won't retry on empty string * Skip error status in on_workflow_failure when transcript already ended * Fix precommit issues * Fail step on first fan-out failure instead of skipping
This commit is contained in:
@@ -137,6 +137,7 @@ services:
|
|||||||
postgres:
|
postgres:
|
||||||
image: postgres:17-alpine
|
image: postgres:17-alpine
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
command: ["postgres", "-c", "max_connections=200"]
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: reflector
|
POSTGRES_USER: reflector
|
||||||
POSTGRES_PASSWORD: reflector
|
POSTGRES_PASSWORD: reflector
|
||||||
|
|||||||
@@ -39,5 +39,12 @@ TIMEOUT_MEDIUM = (
|
|||||||
300 # Single LLM calls, waveform generation (5m for slow LLM responses)
|
300 # Single LLM calls, waveform generation (5m for slow LLM responses)
|
||||||
)
|
)
|
||||||
TIMEOUT_LONG = 180 # Action items (larger context LLM)
|
TIMEOUT_LONG = 180 # Action items (larger context LLM)
|
||||||
TIMEOUT_AUDIO = 720 # Audio processing: padding, mixdown
|
TIMEOUT_TITLE = 300 # generate_title (single LLM call; doc: reduce from 600s)
|
||||||
TIMEOUT_HEAVY = 600 # Transcription, fan-out LLM tasks
|
TIMEOUT_AUDIO = 720 # Audio processing: padding, mixdown (Hatchet execution_timeout)
|
||||||
|
TIMEOUT_AUDIO_HTTP = (
|
||||||
|
660 # httpx timeout for pad_track — below 720 so Hatchet doesn't race
|
||||||
|
)
|
||||||
|
TIMEOUT_HEAVY = 600 # Transcription, fan-out LLM tasks (Hatchet execution_timeout)
|
||||||
|
TIMEOUT_HEAVY_HTTP = (
|
||||||
|
540 # httpx timeout for transcribe_track — below 600 so Hatchet doesn't race
|
||||||
|
)
|
||||||
|
|||||||
74
server/reflector/hatchet/error_classification.py
Normal file
74
server/reflector/hatchet/error_classification.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Classify exceptions as non-retryable for Hatchet workflows.
|
||||||
|
|
||||||
|
When a task raises NonRetryableException (or an exception classified as
|
||||||
|
non-retryable and re-raised as such), Hatchet stops immediately — no further
|
||||||
|
retries. Used by with_error_handling to avoid wasting retries on config errors,
|
||||||
|
auth failures, corrupt data, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Optional dependencies: only classify if the exception type is available.
|
||||||
|
# This avoids hard dependency on openai/av/botocore for code paths that don't use them.
|
||||||
|
try:
|
||||||
|
import openai
|
||||||
|
except ImportError:
|
||||||
|
openai = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
try:
|
||||||
|
import av
|
||||||
|
except ImportError:
|
||||||
|
av = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
try:
|
||||||
|
from botocore.exceptions import ClientError as BotoClientError
|
||||||
|
except ImportError:
|
||||||
|
BotoClientError = None # type: ignore[misc, assignment]
|
||||||
|
|
||||||
|
from hatchet_sdk import NonRetryableException
|
||||||
|
from httpx import HTTPStatusError
|
||||||
|
|
||||||
|
from reflector.llm import LLMParseError
|
||||||
|
|
||||||
|
# HTTP status codes that won't change on retry (auth, not found, payment, payload)
|
||||||
|
NON_RETRYABLE_HTTP_STATUSES = {401, 402, 403, 404, 413}
|
||||||
|
NON_RETRYABLE_S3_CODES = {"AccessDenied", "NoSuchBucket", "NoSuchKey"}
|
||||||
|
|
||||||
|
|
||||||
|
def is_non_retryable(e: BaseException) -> bool:
    """Tell whether an exception is a hard failure that retrying cannot fix.

    Checks run in a fixed order and the first matching category decides:

    * ``NonRetryableException``, ``ValueError``, ``TypeError`` -> True.
    * ``HTTPStatusError`` -> True only when the response status is listed in
      ``NON_RETRYABLE_HTTP_STATUSES``; any other status -> False.
    * ``openai.AuthenticationError`` (when openai is importable) -> True.
    * ``LLMParseError`` -> True.
    * botocore ``ClientError`` (when botocore is importable) -> True only when
      the error code is listed in ``NON_RETRYABLE_S3_CODES``.
    * PyAV error class (when av is importable) -> True.

    Anything else is treated as transient and yields False.

    Args:
        e: The exception raised by a task.

    Returns:
        True when retries should stop immediately, False otherwise.
    """
    # Already flagged as non-retryable, or a config/input error.
    if isinstance(e, (NonRetryableException, ValueError, TypeError)):
        return True

    # HTTP errors: the status code alone decides, nothing below is consulted.
    if isinstance(e, HTTPStatusError):
        status = e.response.status_code
        return status in NON_RETRYABLE_HTTP_STATUSES

    # OpenAI auth errors (only classifiable when the package is installed).
    if openai is not None and isinstance(e, openai.AuthenticationError):
        return True

    # LLM parse failures (already retried internally).
    if isinstance(e, LLMParseError):
        return True

    # S3 errors: the error code alone decides.
    if BotoClientError is not None and isinstance(e, BotoClientError):
        error_info = e.response.get("Error", {})
        return error_info.get("Code", "") in NON_RETRYABLE_S3_CODES

    # Corrupt audio (PyAV): use av.AVError where it exists, otherwise fall
    # back to av.error.InvalidDataError.
    if av is None:
        return False
    corrupt_audio_exc = getattr(av, "AVError", None)
    if not corrupt_audio_exc:
        corrupt_audio_exc = getattr(
            getattr(av, "error", None), "InvalidDataError", None
        )
    return corrupt_audio_exc is not None and isinstance(e, corrupt_audio_exc)
|
||||||
@@ -27,6 +27,7 @@ from hatchet_sdk import (
|
|||||||
ConcurrencyExpression,
|
ConcurrencyExpression,
|
||||||
ConcurrencyLimitStrategy,
|
ConcurrencyLimitStrategy,
|
||||||
Context,
|
Context,
|
||||||
|
NonRetryableException,
|
||||||
)
|
)
|
||||||
from hatchet_sdk.labels import DesiredWorkerLabel
|
from hatchet_sdk.labels import DesiredWorkerLabel
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -43,8 +44,10 @@ from reflector.hatchet.constants import (
|
|||||||
TIMEOUT_LONG,
|
TIMEOUT_LONG,
|
||||||
TIMEOUT_MEDIUM,
|
TIMEOUT_MEDIUM,
|
||||||
TIMEOUT_SHORT,
|
TIMEOUT_SHORT,
|
||||||
|
TIMEOUT_TITLE,
|
||||||
TaskName,
|
TaskName,
|
||||||
)
|
)
|
||||||
|
from reflector.hatchet.error_classification import is_non_retryable
|
||||||
from reflector.hatchet.workflows.models import (
|
from reflector.hatchet.workflows.models import (
|
||||||
ActionItemsResult,
|
ActionItemsResult,
|
||||||
ConsentResult,
|
ConsentResult,
|
||||||
@@ -216,6 +219,13 @@ def make_audio_progress_logger(
|
|||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
def _successful_run_results(
|
||||||
|
results: list[dict[str, Any] | BaseException],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return only successful (non-exception) results from aio_run_many(return_exceptions=True)."""
|
||||||
|
return [r for r in results if not isinstance(r, BaseException)]
|
||||||
|
|
||||||
|
|
||||||
def with_error_handling(
|
def with_error_handling(
|
||||||
step_name: TaskName, set_error_status: bool = True
|
step_name: TaskName, set_error_status: bool = True
|
||||||
) -> Callable[
|
) -> Callable[
|
||||||
@@ -243,8 +253,12 @@ def with_error_handling(
|
|||||||
error=str(e),
|
error=str(e),
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
if set_error_status:
|
if is_non_retryable(e):
|
||||||
await set_workflow_error_status(input.transcript_id)
|
# Hard fail: stop retries, set error status, fail workflow
|
||||||
|
if set_error_status:
|
||||||
|
await set_workflow_error_status(input.transcript_id)
|
||||||
|
raise NonRetryableException(str(e)) from e
|
||||||
|
# Transient: do not set error status — Hatchet will retry
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return wrapper # type: ignore[return-value]
|
return wrapper # type: ignore[return-value]
|
||||||
@@ -253,7 +267,10 @@ def with_error_handling(
|
|||||||
|
|
||||||
|
|
||||||
@daily_multitrack_pipeline.task(
|
@daily_multitrack_pipeline.task(
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3
|
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=10,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.GET_RECORDING)
|
@with_error_handling(TaskName.GET_RECORDING)
|
||||||
async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
|
async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
|
||||||
@@ -309,6 +326,8 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
|
|||||||
parents=[get_recording],
|
parents=[get_recording],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||||
retries=3,
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=10,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.GET_PARTICIPANTS)
|
@with_error_handling(TaskName.GET_PARTICIPANTS)
|
||||||
async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
|
async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
|
||||||
@@ -412,6 +431,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
|
|||||||
parents=[get_participants],
|
parents=[get_participants],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||||
retries=3,
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=30,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.PROCESS_TRACKS)
|
@with_error_handling(TaskName.PROCESS_TRACKS)
|
||||||
async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
|
async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
|
||||||
@@ -435,7 +456,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
|||||||
for i, track in enumerate(input.tracks)
|
for i, track in enumerate(input.tracks)
|
||||||
]
|
]
|
||||||
|
|
||||||
results = await track_workflow.aio_run_many(bulk_runs)
|
results = await track_workflow.aio_run_many(bulk_runs, return_exceptions=True)
|
||||||
|
|
||||||
target_language = participants_result.target_language
|
target_language = participants_result.target_language
|
||||||
|
|
||||||
@@ -443,7 +464,18 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
|||||||
padded_tracks = []
|
padded_tracks = []
|
||||||
created_padded_files = set()
|
created_padded_files = set()
|
||||||
|
|
||||||
for result in results:
|
for i, result in enumerate(results):
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
logger.error(
|
||||||
|
"[Hatchet] process_tracks: track workflow failed, failing step",
|
||||||
|
transcript_id=input.transcript_id,
|
||||||
|
track_index=i,
|
||||||
|
error=str(result),
|
||||||
|
)
|
||||||
|
ctx.log(f"process_tracks: track {i} failed ({result}), failing step")
|
||||||
|
raise ValueError(
|
||||||
|
f"Track {i} workflow failed after retries: {result!s}"
|
||||||
|
) from result
|
||||||
transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK])
|
transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK])
|
||||||
track_words.append(transcribe_result.words)
|
track_words.append(transcribe_result.words)
|
||||||
|
|
||||||
@@ -481,7 +513,9 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
|||||||
@daily_multitrack_pipeline.task(
|
@daily_multitrack_pipeline.task(
|
||||||
parents=[process_tracks],
|
parents=[process_tracks],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
||||||
retries=3,
|
retries=2,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=15,
|
||||||
desired_worker_labels={
|
desired_worker_labels={
|
||||||
"pool": DesiredWorkerLabel(
|
"pool": DesiredWorkerLabel(
|
||||||
value="cpu-heavy",
|
value="cpu-heavy",
|
||||||
@@ -593,6 +627,8 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
|||||||
parents=[mixdown_tracks],
|
parents=[mixdown_tracks],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||||
retries=3,
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=10,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.GENERATE_WAVEFORM)
|
@with_error_handling(TaskName.GENERATE_WAVEFORM)
|
||||||
async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
|
async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
|
||||||
@@ -661,6 +697,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
|
|||||||
parents=[process_tracks],
|
parents=[process_tracks],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||||
retries=3,
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=30,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.DETECT_TOPICS)
|
@with_error_handling(TaskName.DETECT_TOPICS)
|
||||||
async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
||||||
@@ -722,11 +760,22 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
|||||||
for chunk in chunks
|
for chunk in chunks
|
||||||
]
|
]
|
||||||
|
|
||||||
results = await topic_chunk_workflow.aio_run_many(bulk_runs)
|
results = await topic_chunk_workflow.aio_run_many(bulk_runs, return_exceptions=True)
|
||||||
|
|
||||||
topic_chunks = [
|
topic_chunks: list[TopicChunkResult] = []
|
||||||
TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]) for result in results
|
for i, result in enumerate(results):
|
||||||
]
|
if isinstance(result, BaseException):
|
||||||
|
logger.error(
|
||||||
|
"[Hatchet] detect_topics: chunk workflow failed, failing step",
|
||||||
|
transcript_id=input.transcript_id,
|
||||||
|
chunk_index=i,
|
||||||
|
error=str(result),
|
||||||
|
)
|
||||||
|
ctx.log(f"detect_topics: chunk {i} failed ({result}), failing step")
|
||||||
|
raise ValueError(
|
||||||
|
f"Topic chunk {i} workflow failed after retries: {result!s}"
|
||||||
|
) from result
|
||||||
|
topic_chunks.append(TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]))
|
||||||
|
|
||||||
async with fresh_db_connection():
|
async with fresh_db_connection():
|
||||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||||
@@ -764,8 +813,10 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
|||||||
|
|
||||||
@daily_multitrack_pipeline.task(
|
@daily_multitrack_pipeline.task(
|
||||||
parents=[detect_topics],
|
parents=[detect_topics],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
execution_timeout=timedelta(seconds=TIMEOUT_TITLE),
|
||||||
retries=3,
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=15,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.GENERATE_TITLE)
|
@with_error_handling(TaskName.GENERATE_TITLE)
|
||||||
async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
|
async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
|
||||||
@@ -830,7 +881,9 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
|
|||||||
@daily_multitrack_pipeline.task(
|
@daily_multitrack_pipeline.task(
|
||||||
parents=[detect_topics],
|
parents=[detect_topics],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||||
retries=3,
|
retries=5,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=30,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.EXTRACT_SUBJECTS)
|
@with_error_handling(TaskName.EXTRACT_SUBJECTS)
|
||||||
async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult:
|
async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult:
|
||||||
@@ -909,6 +962,8 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
|
|||||||
parents=[extract_subjects],
|
parents=[extract_subjects],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||||
retries=3,
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=30,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.PROCESS_SUBJECTS)
|
@with_error_handling(TaskName.PROCESS_SUBJECTS)
|
||||||
async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubjectsResult:
|
async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubjectsResult:
|
||||||
@@ -935,12 +990,24 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject
|
|||||||
for i, subject in enumerate(subjects)
|
for i, subject in enumerate(subjects)
|
||||||
]
|
]
|
||||||
|
|
||||||
results = await subject_workflow.aio_run_many(bulk_runs)
|
results = await subject_workflow.aio_run_many(bulk_runs, return_exceptions=True)
|
||||||
|
|
||||||
subject_summaries = [
|
subject_summaries: list[SubjectSummaryResult] = []
|
||||||
SubjectSummaryResult(**result[TaskName.GENERATE_DETAILED_SUMMARY])
|
for i, result in enumerate(results):
|
||||||
for result in results
|
if isinstance(result, BaseException):
|
||||||
]
|
logger.error(
|
||||||
|
"[Hatchet] process_subjects: subject workflow failed, failing step",
|
||||||
|
transcript_id=input.transcript_id,
|
||||||
|
subject_index=i,
|
||||||
|
error=str(result),
|
||||||
|
)
|
||||||
|
ctx.log(f"process_subjects: subject {i} failed ({result}), failing step")
|
||||||
|
raise ValueError(
|
||||||
|
f"Subject {i} workflow failed after retries: {result!s}"
|
||||||
|
) from result
|
||||||
|
subject_summaries.append(
|
||||||
|
SubjectSummaryResult(**result[TaskName.GENERATE_DETAILED_SUMMARY])
|
||||||
|
)
|
||||||
|
|
||||||
ctx.log(f"process_subjects complete: {len(subject_summaries)} summaries")
|
ctx.log(f"process_subjects complete: {len(subject_summaries)} summaries")
|
||||||
|
|
||||||
@@ -951,6 +1018,8 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject
|
|||||||
parents=[process_subjects],
|
parents=[process_subjects],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||||
retries=3,
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=15,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.GENERATE_RECAP)
|
@with_error_handling(TaskName.GENERATE_RECAP)
|
||||||
async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
|
async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
|
||||||
@@ -1040,6 +1109,8 @@ async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
|
|||||||
parents=[extract_subjects],
|
parents=[extract_subjects],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_LONG),
|
execution_timeout=timedelta(seconds=TIMEOUT_LONG),
|
||||||
retries=3,
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=15,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.IDENTIFY_ACTION_ITEMS)
|
@with_error_handling(TaskName.IDENTIFY_ACTION_ITEMS)
|
||||||
async def identify_action_items(
|
async def identify_action_items(
|
||||||
@@ -1108,6 +1179,8 @@ async def identify_action_items(
|
|||||||
parents=[process_tracks, generate_title, generate_recap, identify_action_items],
|
parents=[process_tracks, generate_title, generate_recap, identify_action_items],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||||
retries=3,
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=5,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.FINALIZE)
|
@with_error_handling(TaskName.FINALIZE)
|
||||||
async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
|
async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
|
||||||
@@ -1177,7 +1250,11 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
|
|||||||
|
|
||||||
|
|
||||||
@daily_multitrack_pipeline.task(
|
@daily_multitrack_pipeline.task(
|
||||||
parents=[finalize], execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3
|
parents=[finalize],
|
||||||
|
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=10,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
|
@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
|
||||||
async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
|
async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
|
||||||
@@ -1283,6 +1360,8 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
|
|||||||
parents=[cleanup_consent],
|
parents=[cleanup_consent],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||||
retries=5,
|
retries=5,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=15,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
|
@with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
|
||||||
async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
|
async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
|
||||||
@@ -1310,6 +1389,8 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
|
|||||||
parents=[cleanup_consent],
|
parents=[cleanup_consent],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||||
retries=5,
|
retries=5,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=15,
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
|
@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
|
||||||
async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
|
async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
|
||||||
@@ -1378,3 +1459,32 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
ctx.log(f"send_webhook unexpected error, continuing anyway: {e}")
|
ctx.log(f"send_webhook unexpected error, continuing anyway: {e}")
|
||||||
return WebhookResult(webhook_sent=False)
|
return WebhookResult(webhook_sent=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_workflow_failure(input: PipelineInput, ctx: Context) -> None:
|
||||||
|
"""Run when the workflow is truly dead (all retries exhausted).
|
||||||
|
|
||||||
|
Sets transcript status to 'error' only if it is not already 'ended'.
|
||||||
|
Post-finalize tasks (cleanup_consent, post_zulip, send_webhook) use
|
||||||
|
set_error_status=False; if one of them fails, we must not overwrite
|
||||||
|
the 'ended' status that finalize already set.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Pipeline input; only input.transcript_id is read here.
|
||||||
|
ctx: Hatchet context; used to log when the error status is skipped.
|
||||||
|
"""
|
||||||
|
async with fresh_db_connection():
|
||||||
|
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||||
|
if transcript and transcript.status == "ended":
|
||||||
|
logger.info(
|
||||||
|
"[Hatchet] on_workflow_failure: transcript already ended, skipping error status (failure was post-finalize)",
|
||||||
|
transcript_id=input.transcript_id,
|
||||||
|
)
|
||||||
|
ctx.log(
|
||||||
|
"on_workflow_failure: transcript already ended, skipping error status"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
await set_workflow_error_status(input.transcript_id)
|
||||||
|
|
||||||
|
|
||||||
|
@daily_multitrack_pipeline.on_failure_task()
async def _register_on_workflow_failure(input: PipelineInput, ctx: Context) -> None:
    """Failure-task entry point for the pipeline; forwards to on_workflow_failure.

    The handling logic lives in the undecorated on_workflow_failure function;
    this wrapper only attaches it to the workflow via on_failure_task().
    """
    await on_workflow_failure(input=input, ctx=ctx)
|
||||||
|
|||||||
@@ -34,7 +34,12 @@ padding_workflow = hatchet.workflow(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@padding_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3)
|
@padding_workflow.task(
|
||||||
|
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=30,
|
||||||
|
)
|
||||||
async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult:
|
async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult:
|
||||||
"""Pad audio track with silence based on WebM container start_time."""
|
"""Pad audio track with silence based on WebM container start_time."""
|
||||||
ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}")
|
ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}")
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from hatchet_sdk.rate_limit import RateLimit
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from reflector.hatchet.client import HatchetClientManager
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_MEDIUM
|
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_HEAVY
|
||||||
from reflector.hatchet.workflows.models import SubjectSummaryResult
|
from reflector.hatchet.workflows.models import SubjectSummaryResult
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.processors.summary.prompts import (
|
from reflector.processors.summary.prompts import (
|
||||||
@@ -41,8 +41,10 @@ subject_workflow = hatchet.workflow(
|
|||||||
|
|
||||||
|
|
||||||
@subject_workflow.task(
|
@subject_workflow.task(
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||||
retries=3,
|
retries=5,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=60,
|
||||||
rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=2)],
|
rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=2)],
|
||||||
)
|
)
|
||||||
async def generate_detailed_summary(
|
async def generate_detailed_summary(
|
||||||
|
|||||||
@@ -50,7 +50,9 @@ topic_chunk_workflow = hatchet.workflow(
|
|||||||
|
|
||||||
@topic_chunk_workflow.task(
|
@topic_chunk_workflow.task(
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||||
retries=3,
|
retries=5,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=60,
|
||||||
rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=1)],
|
rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=1)],
|
||||||
)
|
)
|
||||||
async def detect_chunk_topic(input: TopicChunkInput, ctx: Context) -> TopicChunkResult:
|
async def detect_chunk_topic(input: TopicChunkInput, ctx: Context) -> TopicChunkResult:
|
||||||
|
|||||||
@@ -44,7 +44,12 @@ hatchet = HatchetClientManager.get_client()
|
|||||||
track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
|
track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
|
||||||
|
|
||||||
|
|
||||||
@track_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3)
|
@track_workflow.task(
|
||||||
|
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=30,
|
||||||
|
)
|
||||||
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
|
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
|
||||||
"""Pad single audio track with silence for alignment.
|
"""Pad single audio track with silence for alignment.
|
||||||
|
|
||||||
@@ -137,7 +142,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
|
|||||||
|
|
||||||
|
|
||||||
@track_workflow.task(
|
@track_workflow.task(
|
||||||
parents=[pad_track], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3
|
parents=[pad_track],
|
||||||
|
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||||
|
retries=3,
|
||||||
|
backoff_factor=2.0,
|
||||||
|
backoff_max_seconds=30,
|
||||||
)
|
)
|
||||||
async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
|
async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
|
||||||
"""Transcribe audio track using GPU (Modal.com) or local Whisper."""
|
"""Transcribe audio track using GPU (Modal.com) or local Whisper."""
|
||||||
|
|||||||
@@ -65,10 +65,25 @@ class LLM:
|
|||||||
async def get_response(
|
async def get_response(
|
||||||
self, prompt: str, texts: list[str], tone_name: str | None = None
|
self, prompt: str, texts: list[str], tone_name: str | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get a text response using TreeSummarize for non-function-calling models"""
|
"""Get a text response using TreeSummarize for non-function-calling models.
|
||||||
summarizer = TreeSummarize(verbose=False)
|
|
||||||
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
|
Uses the same retry() wrapper as get_structured_response for transient
|
||||||
return str(response).strip()
|
network errors (connection, timeout, OSError) with exponential backoff.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _call():
|
||||||
|
summarizer = TreeSummarize(verbose=False)
|
||||||
|
response = await summarizer.aget_response(
|
||||||
|
prompt, texts, tone_name=tone_name
|
||||||
|
)
|
||||||
|
return str(response).strip()
|
||||||
|
|
||||||
|
return await retry(_call)(
|
||||||
|
retry_attempts=3,
|
||||||
|
retry_backoff_interval=1.0,
|
||||||
|
retry_backoff_max=30.0,
|
||||||
|
retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError),
|
||||||
|
)
|
||||||
|
|
||||||
async def get_structured_response(
|
async def get_structured_response(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import os
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from reflector.hatchet.constants import TIMEOUT_AUDIO
|
from reflector.hatchet.constants import TIMEOUT_AUDIO_HTTP
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse
|
from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse
|
||||||
from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor
|
from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor
|
||||||
@@ -60,7 +60,7 @@ class AudioPaddingModalProcessor(AudioPaddingProcessor):
|
|||||||
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=TIMEOUT_AUDIO) as client:
|
async with httpx.AsyncClient(timeout=TIMEOUT_AUDIO_HTTP) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
url,
|
url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
|||||||
@@ -55,7 +55,9 @@ class Settings(BaseSettings):
|
|||||||
WHISPER_FILE_MODEL: str = "tiny"
|
WHISPER_FILE_MODEL: str = "tiny"
|
||||||
TRANSCRIPT_URL: str | None = None
|
TRANSCRIPT_URL: str | None = None
|
||||||
TRANSCRIPT_TIMEOUT: int = 90
|
TRANSCRIPT_TIMEOUT: int = 90
|
||||||
TRANSCRIPT_FILE_TIMEOUT: int = 600
|
TRANSCRIPT_FILE_TIMEOUT: int = (
|
||||||
|
540 # Below Hatchet TIMEOUT_HEAVY (600) to avoid timeout race
|
||||||
|
)
|
||||||
|
|
||||||
# Audio Transcription: modal backend
|
# Audio Transcription: modal backend
|
||||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ def retry(fn):
|
|||||||
"retry_httpx_status_stop",
|
"retry_httpx_status_stop",
|
||||||
(
|
(
|
||||||
401, # auth issue
|
401, # auth issue
|
||||||
|
402, # payment required / no credits — needs human action
|
||||||
404, # not found
|
404, # not found
|
||||||
413, # payload too large
|
413, # payload too large
|
||||||
418, # teapot
|
418, # teapot
|
||||||
@@ -58,8 +59,9 @@ def retry(fn):
|
|||||||
result = await fn(*args, **kwargs)
|
result = await fn(*args, **kwargs)
|
||||||
if isinstance(result, Response):
|
if isinstance(result, Response):
|
||||||
result.raise_for_status()
|
result.raise_for_status()
|
||||||
if result:
|
# Return any result including falsy (e.g. "" from get_response);
|
||||||
return result
|
# only retry on exception, not on empty string.
|
||||||
|
return result
|
||||||
except HTTPStatusError as e:
|
except HTTPStatusError as e:
|
||||||
retry_logger.exception(e)
|
retry_logger.exception(e)
|
||||||
status_code = e.response.status_code
|
status_code = e.response.status_code
|
||||||
|
|||||||
@@ -50,5 +50,8 @@ async def transcript_process(
|
|||||||
if isinstance(config, ProcessError):
|
if isinstance(config, ProcessError):
|
||||||
raise HTTPException(status_code=500, detail=config.detail)
|
raise HTTPException(status_code=500, detail=config.detail)
|
||||||
else:
|
else:
|
||||||
await dispatch_transcript_processing(config)
|
# When transcript is in error state, force a new workflow instead of replaying
|
||||||
|
# (replay would re-run from failure point with same conditions and likely fail again)
|
||||||
|
force = transcript.status == "error"
|
||||||
|
await dispatch_transcript_processing(config, force=force)
|
||||||
return ProcessStatus(status="ok")
|
return ProcessStatus(status="ok")
|
||||||
|
|||||||
303
server/tests/test_hatchet_error_handling.py
Normal file
303
server/tests/test_hatchet_error_handling.py
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
"""
|
||||||
|
Tests for Hatchet error handling: NonRetryable classification and error status.
|
||||||
|
|
||||||
|
These tests encode the desired behavior from the Hatchet Workflow Analysis doc:
|
||||||
|
- Transient exceptions: do NOT set error status (let Hatchet retry; user stays on "processing").
|
||||||
|
- Hard-fail exceptions: set error status and re-raise as NonRetryableException (stop retries).
|
||||||
|
- on_failure_task: sets error status when workflow is truly dead.
|
||||||
|
|
||||||
|
Run before the fix: some tests fail (reproducing the issues).
|
||||||
|
Run after the fix: all tests pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from hatchet_sdk import NonRetryableException
|
||||||
|
|
||||||
|
from reflector.hatchet.error_classification import is_non_retryable
|
||||||
|
from reflector.llm import LLMParseError
|
||||||
|
|
||||||
|
# --- Tests for is_non_retryable() (pass once error_classification exists) ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_non_retryable_returns_true_for_value_error():
    """ValueError (e.g. missing config) should stop retries."""
    assert is_non_retryable(ValueError("DAILY_API_KEY must be set")) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_non_retryable_returns_true_for_type_error():
    """TypeError (bad input) should stop retries."""
    assert is_non_retryable(TypeError("expected str")) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_non_retryable_returns_true_for_http_401():
    """HTTP 401 auth error should stop retries."""
    # The response is a mock; only ``status_code`` is configured on it.
    resp = MagicMock()
    resp.status_code = 401
    err = httpx.HTTPStatusError("Unauthorized", request=MagicMock(), response=resp)
    assert is_non_retryable(err) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_non_retryable_returns_true_for_http_402():
    """HTTP 402 (no credits) should stop retries."""
    # The response is a mock; only ``status_code`` is configured on it.
    resp = MagicMock()
    resp.status_code = 402
    err = httpx.HTTPStatusError("Payment Required", request=MagicMock(), response=resp)
    assert is_non_retryable(err) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_non_retryable_returns_true_for_http_404():
    """HTTP 404 should stop retries."""
    # The response is a mock; only ``status_code`` is configured on it.
    resp = MagicMock()
    resp.status_code = 404
    err = httpx.HTTPStatusError("Not Found", request=MagicMock(), response=resp)
    assert is_non_retryable(err) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_non_retryable_returns_false_for_http_503():
    """HTTP 503 is transient; retries are useful."""
    # The response is a mock; only ``status_code`` is configured on it.
    resp = MagicMock()
    resp.status_code = 503
    err = httpx.HTTPStatusError(
        "Service Unavailable", request=MagicMock(), response=resp
    )
    assert is_non_retryable(err) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_non_retryable_returns_false_for_timeout():
    """Timeout is transient."""
    assert is_non_retryable(httpx.TimeoutException("timed out")) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_non_retryable_returns_true_for_llm_parse_error():
    """LLMParseError after internal retries should stop."""
    from pydantic import BaseModel

    # Minimal model used only as the first argument to LLMParseError.
    class _Dummy(BaseModel):
        pass

    assert is_non_retryable(LLMParseError(_Dummy, "Failed to parse", 3)) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_non_retryable_returns_true_for_non_retryable_exception():
    """Already-wrapped NonRetryableException should stay non-retryable."""
    assert is_non_retryable(NonRetryableException("custom")) is True
|
||||||
|
|
||||||
|
|
||||||
|
# --- Tests for with_error_handling (need pipeline module with patch) ---
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def pipeline_module():
    """Import daily_multitrack_pipeline with Hatchet client mocked.

    The import happens while both patches are active; the patches are
    released when the fixture returns the imported module.
    """
    with patch("reflector.hatchet.client.settings") as s:
        s.HATCHET_CLIENT_TOKEN = "test-token"
        s.HATCHET_DEBUG = False
        mock_client = MagicMock()
        mock_client.workflow.return_value = MagicMock()
        with patch(
            "reflector.hatchet.client.HatchetClientManager.get_client",
            return_value=mock_client,
        ):
            from reflector.hatchet.workflows import daily_multitrack_pipeline

            return daily_multitrack_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_input():
    """Minimal PipelineInput for decorator tests."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import PipelineInput

    return PipelineInput(
        recording_id="rec-1",
        tracks=[],
        bucket_name="bucket",
        transcript_id="ts-123",
        room_id=None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_ctx():
    """Minimal Context-like object."""
    ctx = MagicMock()
    ctx.log = MagicMock()
    return ctx
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_with_error_handling_transient_does_not_set_error_status(
    pipeline_module, mock_input, mock_ctx
):
    """Transient exception must NOT set error status (so user stays on 'processing' during retries).

    Before fix: set_workflow_error_status is called on every exception → FAIL.
    After fix: not called for transient → PASS.
    """
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    # Stand-in task body that always fails with a timeout.
    async def failing_task(input, ctx):
        raise httpx.TimeoutException("timed out")

    wrapped = with_error_handling(TaskName.GET_RECORDING)(failing_task)

    with patch(
        "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    ) as mock_set_error:
        # The original exception type is expected to propagate unchanged.
        with pytest.raises(httpx.TimeoutException):
            await wrapped(mock_input, mock_ctx)

        # Desired: do NOT set error status for transient (Hatchet will retry)
        mock_set_error.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_with_error_handling_hard_fail_raises_non_retryable_and_sets_status(
    pipeline_module, mock_input, mock_ctx
):
    """Hard-fail (e.g. ValueError) must set error status and re-raise NonRetryableException.

    Before fix: raises ValueError, set_workflow_error_status called → test would need to expect ValueError.
    After fix: raises NonRetryableException, set_workflow_error_status called → PASS.
    """
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    # Stand-in task body that always fails with a ValueError.
    async def failing_task(input, ctx):
        raise ValueError("PADDING_URL must be set")

    wrapped = with_error_handling(TaskName.GET_RECORDING)(failing_task)

    with patch(
        "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    ) as mock_set_error:
        with pytest.raises(NonRetryableException) as exc_info:
            await wrapped(mock_input, mock_ctx)

        # The original message must survive the re-raise; "ts-123" is the
        # transcript_id supplied by the mock_input fixture.
        assert "PADDING_URL" in str(exc_info.value)
        mock_set_error.assert_called_once_with("ts-123")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_with_error_handling_set_error_status_false_never_sets_status(
    pipeline_module, mock_input, mock_ctx
):
    """When set_error_status=False, we must never set error status (e.g. cleanup_consent)."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    # Stand-in task body that always fails with a ValueError.
    async def failing_task(input, ctx):
        raise ValueError("something went wrong")

    wrapped = with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)(
        failing_task
    )

    with patch(
        "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    ) as mock_set_error:
        # Either exception type is accepted; this test only pins the status call.
        with pytest.raises((ValueError, NonRetryableException)):
            await wrapped(mock_input, mock_ctx)

        mock_set_error.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def _noop_db_context():
    """Async context manager that yields without touching the DB (for unit tests)."""
    yield None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_on_failure_task_sets_error_status(pipeline_module, mock_input, mock_ctx):
    """When workflow fails and transcript is not yet 'ended', on_failure sets status to 'error'."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        on_workflow_failure,
    )

    transcript_processing = MagicMock()
    transcript_processing.status = "processing"

    # One flat ``with`` replaces three nested ones; the patches are entered
    # in the same order as before.
    with (
        patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
            _noop_db_context,
        ),
        patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=transcript_processing,
        ),
        patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
            new_callable=AsyncMock,
        ) as mock_set_error,
    ):
        await on_workflow_failure(mock_input, mock_ctx)
        mock_set_error.assert_called_once_with(mock_input.transcript_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_on_failure_task_does_not_overwrite_ended(
    pipeline_module, mock_input, mock_ctx
):
    """When workflow fails after finalize (e.g. post_zulip), do not overwrite 'ended' with 'error'.

    cleanup_consent, post_zulip, send_webhook use set_error_status=False; if one fails,
    on_workflow_failure must not set status to 'error' when transcript is already 'ended'.
    """
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        on_workflow_failure,
    )

    transcript_ended = MagicMock()
    transcript_ended.status = "ended"

    # One flat ``with`` replaces three nested ones; the patches are entered
    # in the same order as before.
    with (
        patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
            _noop_db_context,
        ),
        patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=transcript_ended,
        ),
        patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
            new_callable=AsyncMock,
        ) as mock_set_error,
    ):
        await on_workflow_failure(mock_input, mock_ctx)
        mock_set_error.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# --- Tests for fan-out helper (_successful_run_results) ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_successful_run_results_filters_exceptions():
    """_successful_run_results returns only non-exception items from aio_run_many(return_exceptions=True)."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        _successful_run_results,
    )

    # Successes and exceptions are interleaved so that the order of the
    # surviving items is checked as well as their count.
    results = [
        {"key": "ok1"},
        ValueError("child failed"),
        {"key": "ok2"},
        RuntimeError("another"),
    ]
    successful = _successful_run_results(results)
    assert len(successful) == 2
    assert successful[0] == {"key": "ok1"}
    assert successful[1] == {"key": "ok2"}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Tests for LLM structured output with astructured_predict + reflection retry"""
|
"""Tests for LLM structured output with astructured_predict + reflection retry"""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
@@ -252,6 +252,63 @@ class TestNetworkErrorRetries:
|
|||||||
assert mock_settings.llm.astructured_predict.call_count == 3
|
assert mock_settings.llm.astructured_predict.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetResponseRetries:
    """Test that get_response() uses the same retry() wrapper for transient errors."""

    # NOTE(review): LLM and RetryException are not imported in the visible
    # part of this file; confirm they are in scope at module level.

    @pytest.mark.asyncio
    async def test_get_response_retries_on_connection_error(self, test_settings):
        """Test that get_response retries on ConnectionError and returns on success."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        # First call raises, second call returns text padded with whitespace.
        mock_instance = MagicMock()
        mock_instance.aget_response = AsyncMock(
            side_effect=[
                ConnectionError("Connection refused"),
                " Summary text ",
            ]
        )

        with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
            result = await llm.get_response("Prompt", ["text"])

        # The expected value has the surrounding whitespace removed.
        assert result == "Summary text"
        assert mock_instance.aget_response.call_count == 2

    @pytest.mark.asyncio
    async def test_get_response_exhausts_retries(self, test_settings):
        """Test that get_response raises RetryException after retry attempts exceeded."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        # Every call raises, so the retry wrapper can never succeed.
        mock_instance = MagicMock()
        mock_instance.aget_response = AsyncMock(
            side_effect=ConnectionError("Connection refused")
        )

        with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
            with pytest.raises(RetryException, match="Retry attempts exceeded"):
                await llm.get_response("Prompt", ["text"])

        assert mock_instance.aget_response.call_count == 3

    @pytest.mark.asyncio
    async def test_get_response_returns_empty_string_without_retry(self, test_settings):
        """Empty or whitespace-only LLM response must return '' and not raise RetryException.

        retry() must return falsy results (e.g. '' from get_response) instead of
        treating them as 'no result' and retrying until RetryException.
        """
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        mock_instance = MagicMock()
        mock_instance.aget_response = AsyncMock(return_value=" \n ")  # strip() -> ""

        with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
            result = await llm.get_response("Prompt", ["text"])

        # A single call proves the empty result was returned, not retried.
        assert result == ""
        assert mock_instance.aget_response.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
class TestTextsInclusion:
|
class TestTextsInclusion:
|
||||||
"""Test that texts parameter is included in the prompt sent to astructured_predict"""
|
"""Test that texts parameter is included in the prompt sent to astructured_predict"""
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,15 @@ async def test_retry_httpx(httpx_mock):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_retry_402_stops_by_default(httpx_mock):
    """402 (payment required / no credits) is in default retry_httpx_status_stop — do not retry."""
    # Only one mocked response is registered for the URL.
    httpx_mock.add_response(status_code=402, json={"error": "insufficient_credits"})
    async with httpx.AsyncClient() as client:
        with pytest.raises(RetryHTTPException):
            await retry(client.get)("https://test_url", retry_timeout=5)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_retry_normal():
|
async def test_retry_normal():
|
||||||
left = 3
|
left = 3
|
||||||
|
|||||||
@@ -231,3 +231,81 @@ async def test_dailyco_recording_uses_multitrack_pipeline(client):
|
|||||||
{"s3_key": k} for k in track_keys
|
{"s3_key": k} for k in track_keys
|
||||||
]
|
]
|
||||||
mock_file_pipeline.delay.assert_not_called()
|
mock_file_pipeline.delay.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_database")
@pytest.mark.asyncio
async def test_reprocess_error_transcript_passes_force(client):
    """When transcript status is 'error', reprocess passes force=True to start fresh workflow."""
    from datetime import datetime, timezone

    from reflector.db.recordings import Recording, recordings_controller
    from reflector.db.rooms import rooms_controller
    from reflector.db.transcripts import transcripts_controller

    # Arrange: a room, a transcript in that room, and a one-track recording.
    room = await rooms_controller.add(
        name="test-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    transcript = await transcripts_controller.add(
        "",
        source_kind="room",
        source_language="en",
        target_language="en",
        user_id="test-user",
        share_mode="public",
        room_id=room.id,
    )

    track_keys = ["recordings/test-room/track1.webm"]
    recording = await recordings_controller.create(
        Recording(
            bucket_name="daily-bucket",
            object_key="recordings/test-room",
            meeting_id="test-meeting",
            track_keys=track_keys,
            recorded_at=datetime.now(timezone.utc),
        )
    )

    # Put the transcript into the 'error' state with a stale workflow run id.
    await transcripts_controller.update(
        transcript,
        {
            "recording_id": recording.id,
            "status": "error",
            "workflow_run_id": "old-failed-run",
        },
    )

    with (
        patch(
            "reflector.services.transcript_process.task_is_scheduled_or_active"
        ) as mock_celery,
        patch(
            "reflector.services.transcript_process.HatchetClientManager"
        ) as mock_hatchet,
        patch(
            "reflector.views.transcripts_process.dispatch_transcript_processing",
            new_callable=AsyncMock,
        ) as mock_dispatch,
    ):
        mock_celery.return_value = False
        from hatchet_sdk.clients.rest.models import V1TaskStatus

        # The mocked Hatchet status lookup returns FAILED.
        mock_hatchet.get_workflow_run_status = AsyncMock(
            return_value=V1TaskStatus.FAILED
        )
        response = await client.post(f"/transcripts/{transcript.id}/process")

    # Dispatch is mocked, so only the call and its ``force`` flag are checked.
    assert response.status_code == 200
    mock_dispatch.assert_called_once()
    assert mock_dispatch.call_args.kwargs["force"] is True
|
||||||
|
|||||||
Reference in New Issue
Block a user