fix: improve hatchet workflow reliability (#900)

* Increase max connections

* Classify hard and transient hatchet errors

* Fan out partial success

* Force reprocessing of error transcripts

* Stop retrying on 402 payment required

* Avoid httpx/hatchet timeout race

* Add retry wrapper to get_response for transient errors

* Add retry backoff

* Return falsy results so get_response won't retry on empty string

* Skip error status in on_workflow_failure when transcript already ended

* Fix precommit issues

* Fail step on first fan-out failure instead of skipping
This commit is contained in:
Sergey Mankovsky
2026-03-06 17:07:26 +01:00
committed by GitHub
parent a682846645
commit c155f66982
17 changed files with 717 additions and 38 deletions

View File

@@ -137,6 +137,7 @@ services:
postgres: postgres:
image: postgres:17-alpine image: postgres:17-alpine
restart: unless-stopped restart: unless-stopped
command: ["postgres", "-c", "max_connections=200"]
environment: environment:
POSTGRES_USER: reflector POSTGRES_USER: reflector
POSTGRES_PASSWORD: reflector POSTGRES_PASSWORD: reflector

View File

@@ -39,5 +39,12 @@ TIMEOUT_MEDIUM = (
300 # Single LLM calls, waveform generation (5m for slow LLM responses) 300 # Single LLM calls, waveform generation (5m for slow LLM responses)
) )
TIMEOUT_LONG = 180 # Action items (larger context LLM) TIMEOUT_LONG = 180 # Action items (larger context LLM)
TIMEOUT_AUDIO = 720 # Audio processing: padding, mixdown TIMEOUT_TITLE = 300 # generate_title (single LLM call; doc: reduce from 600s)
TIMEOUT_HEAVY = 600 # Transcription, fan-out LLM tasks TIMEOUT_AUDIO = 720 # Audio processing: padding, mixdown (Hatchet execution_timeout)
TIMEOUT_AUDIO_HTTP = (
660 # httpx timeout for pad_track — below 720 so Hatchet doesn't race
)
TIMEOUT_HEAVY = 600 # Transcription, fan-out LLM tasks (Hatchet execution_timeout)
TIMEOUT_HEAVY_HTTP = (
540 # httpx timeout for transcribe_track — below 600 so Hatchet doesn't race
)

View File

@@ -0,0 +1,74 @@
"""Classify exceptions as non-retryable for Hatchet workflows.
When a task raises NonRetryableException (or an exception classified as
non-retryable and re-raised as such), Hatchet stops immediately — no further
retries. Used by with_error_handling to avoid wasting retries on config errors,
auth failures, corrupt data, etc.
"""
# Optional dependencies: only classify if the exception type is available.
# This avoids hard dependency on openai/av/botocore for code paths that don't use them.
try:
import openai
except ImportError:
openai = None # type: ignore[assignment]
try:
import av
except ImportError:
av = None # type: ignore[assignment]
try:
from botocore.exceptions import ClientError as BotoClientError
except ImportError:
BotoClientError = None # type: ignore[misc, assignment]
from hatchet_sdk import NonRetryableException
from httpx import HTTPStatusError
from reflector.llm import LLMParseError
# HTTP status codes that won't change on retry (auth, not found, payment, payload)
NON_RETRYABLE_HTTP_STATUSES = {401, 402, 403, 404, 413}
# botocore "Error.Code" values that indicate a permanent permission or
# existence problem rather than a transient S3 outage.
NON_RETRYABLE_S3_CODES = {"AccessDenied", "NoSuchBucket", "NoSuchKey"}
def is_non_retryable(e: BaseException) -> bool:
    """Decide whether Hatchet should stop retrying after this exception.

    Returns True for hard failures (config errors, auth failures, missing
    resources, corrupt data) where a retry cannot succeed. Returns False
    for transient conditions (timeouts, 5xx, 429, connection errors),
    which remain retryable.
    """
    # Already classified upstream, or a plain config/input error.
    if isinstance(e, NonRetryableException):
        return True
    if isinstance(e, (ValueError, TypeError)):
        return True

    # HTTP: only a fixed set of status codes counts as permanent.
    if isinstance(e, HTTPStatusError):
        return e.response.status_code in NON_RETRYABLE_HTTP_STATUSES

    # OpenAI credentials are wrong — retrying cannot help.
    if openai is not None and isinstance(e, openai.AuthenticationError):
        return True

    # The LLM layer already retried parsing internally; give up here.
    if isinstance(e, LLMParseError):
        return True

    # S3: permission/existence problems are permanent.
    if BotoClientError is not None and isinstance(e, BotoClientError):
        error_code = e.response.get("Error", {}).get("Code", "")
        return error_code in NON_RETRYABLE_S3_CODES

    # Corrupt audio. The PyAV exception class moved between versions, so
    # probe av.AVError first and fall back to av.error.InvalidDataError.
    if av is not None:
        corrupt_audio_exc = getattr(av, "AVError", None) or getattr(
            getattr(av, "error", None), "InvalidDataError", None
        )
        if corrupt_audio_exc is not None and isinstance(e, corrupt_audio_exc):
            return True

    return False

View File

@@ -27,6 +27,7 @@ from hatchet_sdk import (
ConcurrencyExpression, ConcurrencyExpression,
ConcurrencyLimitStrategy, ConcurrencyLimitStrategy,
Context, Context,
NonRetryableException,
) )
from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.labels import DesiredWorkerLabel
from pydantic import BaseModel from pydantic import BaseModel
@@ -43,8 +44,10 @@ from reflector.hatchet.constants import (
TIMEOUT_LONG, TIMEOUT_LONG,
TIMEOUT_MEDIUM, TIMEOUT_MEDIUM,
TIMEOUT_SHORT, TIMEOUT_SHORT,
TIMEOUT_TITLE,
TaskName, TaskName,
) )
from reflector.hatchet.error_classification import is_non_retryable
from reflector.hatchet.workflows.models import ( from reflector.hatchet.workflows.models import (
ActionItemsResult, ActionItemsResult,
ConsentResult, ConsentResult,
@@ -216,6 +219,13 @@ def make_audio_progress_logger(
R = TypeVar("R") R = TypeVar("R")
def _successful_run_results(
results: list[dict[str, Any] | BaseException],
) -> list[dict[str, Any]]:
"""Return only successful (non-exception) results from aio_run_many(return_exceptions=True)."""
return [r for r in results if not isinstance(r, BaseException)]
def with_error_handling( def with_error_handling(
step_name: TaskName, set_error_status: bool = True step_name: TaskName, set_error_status: bool = True
) -> Callable[ ) -> Callable[
@@ -243,8 +253,12 @@ def with_error_handling(
error=str(e), error=str(e),
exc_info=True, exc_info=True,
) )
if set_error_status: if is_non_retryable(e):
await set_workflow_error_status(input.transcript_id) # Hard fail: stop retries, set error status, fail workflow
if set_error_status:
await set_workflow_error_status(input.transcript_id)
raise NonRetryableException(str(e)) from e
# Transient: do not set error status — Hatchet will retry
raise raise
return wrapper # type: ignore[return-value] return wrapper # type: ignore[return-value]
@@ -253,7 +267,10 @@ def with_error_handling(
@daily_multitrack_pipeline.task( @daily_multitrack_pipeline.task(
execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3 execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
) )
@with_error_handling(TaskName.GET_RECORDING) @with_error_handling(TaskName.GET_RECORDING)
async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult: async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
@@ -309,6 +326,8 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
parents=[get_recording], parents=[get_recording],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT), execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3, retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
) )
@with_error_handling(TaskName.GET_PARTICIPANTS) @with_error_handling(TaskName.GET_PARTICIPANTS)
async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult: async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
@@ -412,6 +431,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
parents=[get_participants], parents=[get_participants],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3, retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
) )
@with_error_handling(TaskName.PROCESS_TRACKS) @with_error_handling(TaskName.PROCESS_TRACKS)
async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult: async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
@@ -435,7 +456,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
for i, track in enumerate(input.tracks) for i, track in enumerate(input.tracks)
] ]
results = await track_workflow.aio_run_many(bulk_runs) results = await track_workflow.aio_run_many(bulk_runs, return_exceptions=True)
target_language = participants_result.target_language target_language = participants_result.target_language
@@ -443,7 +464,18 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
padded_tracks = [] padded_tracks = []
created_padded_files = set() created_padded_files = set()
for result in results: for i, result in enumerate(results):
if isinstance(result, BaseException):
logger.error(
"[Hatchet] process_tracks: track workflow failed, failing step",
transcript_id=input.transcript_id,
track_index=i,
error=str(result),
)
ctx.log(f"process_tracks: track {i} failed ({result}), failing step")
raise ValueError(
f"Track {i} workflow failed after retries: {result!s}"
) from result
transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK]) transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK])
track_words.append(transcribe_result.words) track_words.append(transcribe_result.words)
@@ -481,7 +513,9 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
@daily_multitrack_pipeline.task( @daily_multitrack_pipeline.task(
parents=[process_tracks], parents=[process_tracks],
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
retries=3, retries=2,
backoff_factor=2.0,
backoff_max_seconds=15,
desired_worker_labels={ desired_worker_labels={
"pool": DesiredWorkerLabel( "pool": DesiredWorkerLabel(
value="cpu-heavy", value="cpu-heavy",
@@ -593,6 +627,8 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
parents=[mixdown_tracks], parents=[mixdown_tracks],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3, retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
) )
@with_error_handling(TaskName.GENERATE_WAVEFORM) @with_error_handling(TaskName.GENERATE_WAVEFORM)
async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult: async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
@@ -661,6 +697,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
parents=[process_tracks], parents=[process_tracks],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3, retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
) )
@with_error_handling(TaskName.DETECT_TOPICS) @with_error_handling(TaskName.DETECT_TOPICS)
async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
@@ -722,11 +760,22 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
for chunk in chunks for chunk in chunks
] ]
results = await topic_chunk_workflow.aio_run_many(bulk_runs) results = await topic_chunk_workflow.aio_run_many(bulk_runs, return_exceptions=True)
topic_chunks = [ topic_chunks: list[TopicChunkResult] = []
TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]) for result in results for i, result in enumerate(results):
] if isinstance(result, BaseException):
logger.error(
"[Hatchet] detect_topics: chunk workflow failed, failing step",
transcript_id=input.transcript_id,
chunk_index=i,
error=str(result),
)
ctx.log(f"detect_topics: chunk {i} failed ({result}), failing step")
raise ValueError(
f"Topic chunk {i} workflow failed after retries: {result!s}"
) from result
topic_chunks.append(TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]))
async with fresh_db_connection(): async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id) transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -764,8 +813,10 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
@daily_multitrack_pipeline.task( @daily_multitrack_pipeline.task(
parents=[detect_topics], parents=[detect_topics],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), execution_timeout=timedelta(seconds=TIMEOUT_TITLE),
retries=3, retries=3,
backoff_factor=2.0,
backoff_max_seconds=15,
) )
@with_error_handling(TaskName.GENERATE_TITLE) @with_error_handling(TaskName.GENERATE_TITLE)
async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult: async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
@@ -830,7 +881,9 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
@daily_multitrack_pipeline.task( @daily_multitrack_pipeline.task(
parents=[detect_topics], parents=[detect_topics],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3, retries=5,
backoff_factor=2.0,
backoff_max_seconds=30,
) )
@with_error_handling(TaskName.EXTRACT_SUBJECTS) @with_error_handling(TaskName.EXTRACT_SUBJECTS)
async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult: async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult:
@@ -909,6 +962,8 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
parents=[extract_subjects], parents=[extract_subjects],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3, retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
) )
@with_error_handling(TaskName.PROCESS_SUBJECTS) @with_error_handling(TaskName.PROCESS_SUBJECTS)
async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubjectsResult: async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubjectsResult:
@@ -935,12 +990,24 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject
for i, subject in enumerate(subjects) for i, subject in enumerate(subjects)
] ]
results = await subject_workflow.aio_run_many(bulk_runs) results = await subject_workflow.aio_run_many(bulk_runs, return_exceptions=True)
subject_summaries = [ subject_summaries: list[SubjectSummaryResult] = []
SubjectSummaryResult(**result[TaskName.GENERATE_DETAILED_SUMMARY]) for i, result in enumerate(results):
for result in results if isinstance(result, BaseException):
] logger.error(
"[Hatchet] process_subjects: subject workflow failed, failing step",
transcript_id=input.transcript_id,
subject_index=i,
error=str(result),
)
ctx.log(f"process_subjects: subject {i} failed ({result}), failing step")
raise ValueError(
f"Subject {i} workflow failed after retries: {result!s}"
) from result
subject_summaries.append(
SubjectSummaryResult(**result[TaskName.GENERATE_DETAILED_SUMMARY])
)
ctx.log(f"process_subjects complete: {len(subject_summaries)} summaries") ctx.log(f"process_subjects complete: {len(subject_summaries)} summaries")
@@ -951,6 +1018,8 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject
parents=[process_subjects], parents=[process_subjects],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3, retries=3,
backoff_factor=2.0,
backoff_max_seconds=15,
) )
@with_error_handling(TaskName.GENERATE_RECAP) @with_error_handling(TaskName.GENERATE_RECAP)
async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult: async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
@@ -1040,6 +1109,8 @@ async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
parents=[extract_subjects], parents=[extract_subjects],
execution_timeout=timedelta(seconds=TIMEOUT_LONG), execution_timeout=timedelta(seconds=TIMEOUT_LONG),
retries=3, retries=3,
backoff_factor=2.0,
backoff_max_seconds=15,
) )
@with_error_handling(TaskName.IDENTIFY_ACTION_ITEMS) @with_error_handling(TaskName.IDENTIFY_ACTION_ITEMS)
async def identify_action_items( async def identify_action_items(
@@ -1108,6 +1179,8 @@ async def identify_action_items(
parents=[process_tracks, generate_title, generate_recap, identify_action_items], parents=[process_tracks, generate_title, generate_recap, identify_action_items],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT), execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3, retries=3,
backoff_factor=2.0,
backoff_max_seconds=5,
) )
@with_error_handling(TaskName.FINALIZE) @with_error_handling(TaskName.FINALIZE)
async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
@@ -1177,7 +1250,11 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
@daily_multitrack_pipeline.task( @daily_multitrack_pipeline.task(
parents=[finalize], execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3 parents=[finalize],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=10,
) )
@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False) @with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult: async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
@@ -1283,6 +1360,8 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
parents=[cleanup_consent], parents=[cleanup_consent],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT), execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=5, retries=5,
backoff_factor=2.0,
backoff_max_seconds=15,
) )
@with_error_handling(TaskName.POST_ZULIP, set_error_status=False) @with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult: async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
@@ -1310,6 +1389,8 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
parents=[cleanup_consent], parents=[cleanup_consent],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=5, retries=5,
backoff_factor=2.0,
backoff_max_seconds=15,
) )
@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False) @with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult: async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
@@ -1378,3 +1459,32 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
except Exception as e: except Exception as e:
ctx.log(f"send_webhook unexpected error, continuing anyway: {e}") ctx.log(f"send_webhook unexpected error, continuing anyway: {e}")
return WebhookResult(webhook_sent=False) return WebhookResult(webhook_sent=False)
async def on_workflow_failure(input: PipelineInput, ctx: Context) -> None:
"""Run when the workflow is truly dead (all retries exhausted).
Sets transcript status to 'error' only if it is not already 'ended'.
Post-finalize tasks (cleanup_consent, post_zulip, send_webhook) use
set_error_status=False; if one of them fails, we must not overwrite
the 'ended' status that finalize already set.
"""
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript and transcript.status == "ended":
logger.info(
"[Hatchet] on_workflow_failure: transcript already ended, skipping error status (failure was post-finalize)",
transcript_id=input.transcript_id,
)
ctx.log(
"on_workflow_failure: transcript already ended, skipping error status"
)
return
await set_workflow_error_status(input.transcript_id)
@daily_multitrack_pipeline.on_failure_task()
async def _register_on_workflow_failure(input: PipelineInput, ctx: Context) -> None:
await on_workflow_failure(input, ctx)

View File

@@ -34,7 +34,12 @@ padding_workflow = hatchet.workflow(
) )
@padding_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3) @padding_workflow.task(
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
)
async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult: async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult:
"""Pad audio track with silence based on WebM container start_time.""" """Pad audio track with silence based on WebM container start_time."""
ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}") ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}")

View File

@@ -13,7 +13,7 @@ from hatchet_sdk.rate_limit import RateLimit
from pydantic import BaseModel from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_MEDIUM from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_HEAVY
from reflector.hatchet.workflows.models import SubjectSummaryResult from reflector.hatchet.workflows.models import SubjectSummaryResult
from reflector.logger import logger from reflector.logger import logger
from reflector.processors.summary.prompts import ( from reflector.processors.summary.prompts import (
@@ -41,8 +41,10 @@ subject_workflow = hatchet.workflow(
@subject_workflow.task( @subject_workflow.task(
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3, retries=5,
backoff_factor=2.0,
backoff_max_seconds=60,
rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=2)], rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=2)],
) )
async def generate_detailed_summary( async def generate_detailed_summary(

View File

@@ -50,7 +50,9 @@ topic_chunk_workflow = hatchet.workflow(
@topic_chunk_workflow.task( @topic_chunk_workflow.task(
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3, retries=5,
backoff_factor=2.0,
backoff_max_seconds=60,
rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=1)], rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=1)],
) )
async def detect_chunk_topic(input: TopicChunkInput, ctx: Context) -> TopicChunkResult: async def detect_chunk_topic(input: TopicChunkInput, ctx: Context) -> TopicChunkResult:

View File

@@ -44,7 +44,12 @@ hatchet = HatchetClientManager.get_client()
track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput) track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
@track_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3) @track_workflow.task(
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
)
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
"""Pad single audio track with silence for alignment. """Pad single audio track with silence for alignment.
@@ -137,7 +142,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
@track_workflow.task( @track_workflow.task(
parents=[pad_track], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3 parents=[pad_track],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
backoff_factor=2.0,
backoff_max_seconds=30,
) )
async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult: async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
"""Transcribe audio track using GPU (Modal.com) or local Whisper.""" """Transcribe audio track using GPU (Modal.com) or local Whisper."""

View File

@@ -65,10 +65,25 @@ class LLM:
async def get_response( async def get_response(
self, prompt: str, texts: list[str], tone_name: str | None = None self, prompt: str, texts: list[str], tone_name: str | None = None
) -> str: ) -> str:
"""Get a text response using TreeSummarize for non-function-calling models""" """Get a text response using TreeSummarize for non-function-calling models.
summarizer = TreeSummarize(verbose=False)
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name) Uses the same retry() wrapper as get_structured_response for transient
return str(response).strip() network errors (connection, timeout, OSError) with exponential backoff.
"""
async def _call():
summarizer = TreeSummarize(verbose=False)
response = await summarizer.aget_response(
prompt, texts, tone_name=tone_name
)
return str(response).strip()
return await retry(_call)(
retry_attempts=3,
retry_backoff_interval=1.0,
retry_backoff_max=30.0,
retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError),
)
async def get_structured_response( async def get_structured_response(
self, self,

View File

@@ -7,7 +7,7 @@ import os
import httpx import httpx
from reflector.hatchet.constants import TIMEOUT_AUDIO from reflector.hatchet.constants import TIMEOUT_AUDIO_HTTP
from reflector.logger import logger from reflector.logger import logger
from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse
from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor
@@ -60,7 +60,7 @@ class AudioPaddingModalProcessor(AudioPaddingProcessor):
headers["Authorization"] = f"Bearer {self.modal_api_key}" headers["Authorization"] = f"Bearer {self.modal_api_key}"
try: try:
async with httpx.AsyncClient(timeout=TIMEOUT_AUDIO) as client: async with httpx.AsyncClient(timeout=TIMEOUT_AUDIO_HTTP) as client:
response = await client.post( response = await client.post(
url, url,
headers=headers, headers=headers,

View File

@@ -55,7 +55,9 @@ class Settings(BaseSettings):
WHISPER_FILE_MODEL: str = "tiny" WHISPER_FILE_MODEL: str = "tiny"
TRANSCRIPT_URL: str | None = None TRANSCRIPT_URL: str | None = None
TRANSCRIPT_TIMEOUT: int = 90 TRANSCRIPT_TIMEOUT: int = 90
TRANSCRIPT_FILE_TIMEOUT: int = 600 TRANSCRIPT_FILE_TIMEOUT: int = (
540 # Below Hatchet TIMEOUT_HEAVY (600) to avoid timeout race
)
# Audio Transcription: modal backend # Audio Transcription: modal backend
TRANSCRIPT_MODAL_API_KEY: str | None = None TRANSCRIPT_MODAL_API_KEY: str | None = None

View File

@@ -30,6 +30,7 @@ def retry(fn):
"retry_httpx_status_stop", "retry_httpx_status_stop",
( (
401, # auth issue 401, # auth issue
402, # payment required / no credits — needs human action
404, # not found 404, # not found
413, # payload too large 413, # payload too large
418, # teapot 418, # teapot
@@ -58,8 +59,9 @@ def retry(fn):
result = await fn(*args, **kwargs) result = await fn(*args, **kwargs)
if isinstance(result, Response): if isinstance(result, Response):
result.raise_for_status() result.raise_for_status()
if result: # Return any result including falsy (e.g. "" from get_response);
return result # only retry on exception, not on empty string.
return result
except HTTPStatusError as e: except HTTPStatusError as e:
retry_logger.exception(e) retry_logger.exception(e)
status_code = e.response.status_code status_code = e.response.status_code

View File

@@ -50,5 +50,8 @@ async def transcript_process(
if isinstance(config, ProcessError): if isinstance(config, ProcessError):
raise HTTPException(status_code=500, detail=config.detail) raise HTTPException(status_code=500, detail=config.detail)
else: else:
await dispatch_transcript_processing(config) # When transcript is in error state, force a new workflow instead of replaying
# (replay would re-run from failure point with same conditions and likely fail again)
force = transcript.status == "error"
await dispatch_transcript_processing(config, force=force)
return ProcessStatus(status="ok") return ProcessStatus(status="ok")

View File

@@ -0,0 +1,303 @@
"""
Tests for Hatchet error handling: NonRetryable classification and error status.
These tests encode the desired behavior from the Hatchet Workflow Analysis doc:
- Transient exceptions: do NOT set error status (let Hatchet retry; user stays on "processing").
- Hard-fail exceptions: set error status and re-raise as NonRetryableException (stop retries).
- on_failure_task: sets error status when workflow is truly dead.
Run before the fix: some tests fail (reproducing the issues).
Run after the fix: all tests pass.
"""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from hatchet_sdk import NonRetryableException
from reflector.hatchet.error_classification import is_non_retryable
from reflector.llm import LLMParseError
# --- Tests for is_non_retryable() (pass once error_classification exists) ---
def test_is_non_retryable_returns_true_for_value_error():
    """A ValueError (e.g. missing config) must stop retries immediately."""
    exc = ValueError("DAILY_API_KEY must be set")
    assert is_non_retryable(exc) is True
def test_is_non_retryable_returns_true_for_type_error():
    """A TypeError (bad input) must stop retries immediately."""
    exc = TypeError("expected str")
    assert is_non_retryable(exc) is True
def test_is_non_retryable_returns_true_for_http_401():
    """HTTP 401 (bad credentials) will not fix itself — stop retries."""
    response = MagicMock(status_code=401)
    exc = httpx.HTTPStatusError(
        "Unauthorized", request=MagicMock(), response=response
    )
    assert is_non_retryable(exc) is True
def test_is_non_retryable_returns_true_for_http_402():
    """HTTP 402 (no credits) needs human action — stop retries."""
    response = MagicMock(status_code=402)
    exc = httpx.HTTPStatusError(
        "Payment Required", request=MagicMock(), response=response
    )
    assert is_non_retryable(exc) is True
def test_is_non_retryable_returns_true_for_http_404():
    """HTTP 404 (missing resource) will not appear on retry — stop."""
    response = MagicMock(status_code=404)
    exc = httpx.HTTPStatusError(
        "Not Found", request=MagicMock(), response=response
    )
    assert is_non_retryable(exc) is True
def test_is_non_retryable_returns_false_for_http_503():
    """HTTP 503 is a transient outage — keep retrying."""
    response = MagicMock(status_code=503)
    exc = httpx.HTTPStatusError(
        "Service Unavailable", request=MagicMock(), response=response
    )
    assert is_non_retryable(exc) is False
def test_is_non_retryable_returns_false_for_timeout():
    """Timeouts are transient and must remain retryable."""
    timeout_exc = httpx.TimeoutException("timed out")
    assert is_non_retryable(timeout_exc) is False
def test_is_non_retryable_returns_true_for_llm_parse_error():
    """An LLMParseError (the LLM layer already retried internally) must hard-fail."""
    from pydantic import BaseModel

    class _Schema(BaseModel):
        pass

    parse_error = LLMParseError(_Schema, "Failed to parse", 3)
    assert is_non_retryable(parse_error) is True
def test_is_non_retryable_returns_true_for_non_retryable_exception():
    """An exception already wrapped as NonRetryableException stays non-retryable."""
    wrapped = NonRetryableException("custom")
    assert is_non_retryable(wrapped) is True
# --- Tests for with_error_handling (need pipeline module with patch) ---
@pytest.fixture(scope="module")
def pipeline_module():
    """Import daily_multitrack_pipeline with the Hatchet client mocked.

    Module scope: the import (and the workflow/task registration it
    triggers) runs once per test module. Both patches must be active
    *before* the import so module-level registration never touches a
    real Hatchet client.
    """
    with patch("reflector.hatchet.client.settings") as s:
        # Settings the client manager reads when building the client.
        s.HATCHET_CLIENT_TOKEN = "test-token"
        s.HATCHET_DEBUG = False
        mock_client = MagicMock()
        # workflow(...) is called at import time for each workflow definition.
        mock_client.workflow.return_value = MagicMock()
        with patch(
            "reflector.hatchet.client.HatchetClientManager.get_client",
            return_value=mock_client,
        ):
            # Deferred import: must execute while the patches above are live.
            from reflector.hatchet.workflows import daily_multitrack_pipeline
            return daily_multitrack_pipeline
@pytest.fixture
def mock_input():
    """Build the smallest PipelineInput the decorator tests need."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import PipelineInput

    fields = {
        "recording_id": "rec-1",
        "tracks": [],
        "bucket_name": "bucket",
        "transcript_id": "ts-123",
        "room_id": None,
    }
    return PipelineInput(**fields)
@pytest.fixture
def mock_ctx():
    """Stand-in for Hatchet's Context — only .log() is exercised."""
    context = MagicMock()
    context.log = MagicMock()
    return context
@pytest.mark.asyncio
async def test_with_error_handling_transient_does_not_set_error_status(
    pipeline_module, mock_input, mock_ctx
):
    """Transient exception must NOT set error status (so user stays on 'processing' during retries).

    Before fix: set_workflow_error_status is called on every exception → FAIL.
    After fix: not called for transient → PASS.
    """
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    async def raise_timeout(input, ctx):
        raise httpx.TimeoutException("timed out")

    decorated = with_error_handling(TaskName.GET_RECORDING)(raise_timeout)

    with patch(
        "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    ) as error_status_mock:
        with pytest.raises(httpx.TimeoutException):
            await decorated(mock_input, mock_ctx)
        # Desired: do NOT set error status for transient (Hatchet will retry)
        error_status_mock.assert_not_called()
@pytest.mark.asyncio
async def test_with_error_handling_hard_fail_raises_non_retryable_and_sets_status(
    pipeline_module, mock_input, mock_ctx
):
    """Hard-fail (e.g. ValueError) must set error status and re-raise NonRetryableException.

    Before fix: raises ValueError, set_workflow_error_status called → test would need to expect ValueError.
    After fix: raises NonRetryableException, set_workflow_error_status called → PASS.
    """
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    async def raise_hard_failure(input, ctx):
        raise ValueError("PADDING_URL must be set")

    decorated = with_error_handling(TaskName.GET_RECORDING)(raise_hard_failure)

    with patch(
        "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    ) as error_status_mock:
        with pytest.raises(NonRetryableException) as raised:
            await decorated(mock_input, mock_ctx)
        # Original message must survive the wrapping exception.
        assert "PADDING_URL" in str(raised.value)
        error_status_mock.assert_called_once_with("ts-123")
@pytest.mark.asyncio
async def test_with_error_handling_set_error_status_false_never_sets_status(
    pipeline_module, mock_input, mock_ctx
):
    """When set_error_status=False, we must never set error status (e.g. cleanup_consent)."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    async def raise_failure(input, ctx):
        raise ValueError("something went wrong")

    decorated = with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)(
        raise_failure
    )

    with patch(
        "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    ) as error_status_mock:
        # Either exception type is acceptable; only the status side effect matters.
        with pytest.raises((ValueError, NonRetryableException)):
            await decorated(mock_input, mock_ctx)
        error_status_mock.assert_not_called()
@asynccontextmanager
async def _noop_db_context():
    """Async context manager that yields without touching the DB (for unit tests).

    Used as a drop-in replacement for fresh_db_connection when patching the
    pipeline module, so on_workflow_failure can run without a database.
    """
    yield None
@pytest.mark.asyncio
async def test_on_failure_task_sets_error_status(pipeline_module, mock_input, mock_ctx):
    """When workflow fails and transcript is not yet 'ended', on_failure sets status to 'error'."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        on_workflow_failure,
    )

    # Transcript still mid-pipeline: error status must be applied.
    processing_transcript = MagicMock()
    processing_transcript.status = "processing"

    with (
        patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
            _noop_db_context,
        ),
        patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=processing_transcript,
        ),
        patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
            new_callable=AsyncMock,
        ) as error_status_mock,
    ):
        await on_workflow_failure(mock_input, mock_ctx)
        error_status_mock.assert_called_once_with(mock_input.transcript_id)
@pytest.mark.asyncio
async def test_on_failure_task_does_not_overwrite_ended(
    pipeline_module, mock_input, mock_ctx
):
    """When workflow fails after finalize (e.g. post_zulip), do not overwrite 'ended' with 'error'.

    cleanup_consent, post_zulip, send_webhook use set_error_status=False; if one fails,
    on_workflow_failure must not set status to 'error' when transcript is already 'ended'.
    """
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        on_workflow_failure,
    )

    # Transcript already finalized: status must be left untouched.
    ended_transcript = MagicMock()
    ended_transcript.status = "ended"

    with (
        patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
            _noop_db_context,
        ),
        patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=ended_transcript,
        ),
        patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
            new_callable=AsyncMock,
        ) as error_status_mock,
    ):
        await on_workflow_failure(mock_input, mock_ctx)
        error_status_mock.assert_not_called()
# --- Tests for fan-out helper (_successful_run_results) ---
def test_successful_run_results_filters_exceptions():
    """_successful_run_results returns only non-exception items from aio_run_many(return_exceptions=True)."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        _successful_run_results,
    )

    # Mixed output as produced by aio_run_many(return_exceptions=True).
    mixed_results = [
        {"key": "ok1"},
        ValueError("child failed"),
        {"key": "ok2"},
        RuntimeError("another"),
    ]

    kept = _successful_run_results(mixed_results)

    assert len(kept) == 2
    assert kept[0] == {"key": "ok1"}
    assert kept[1] == {"key": "ok2"}

View File

@@ -1,6 +1,6 @@
"""Tests for LLM structured output with astructured_predict + reflection retry""" """Tests for LLM structured output with astructured_predict + reflection retry"""
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
@@ -252,6 +252,63 @@ class TestNetworkErrorRetries:
assert mock_settings.llm.astructured_predict.call_count == 3 assert mock_settings.llm.astructured_predict.call_count == 3
class TestGetResponseRetries:
"""Test that get_response() uses the same retry() wrapper for transient errors."""
@pytest.mark.asyncio
async def test_get_response_retries_on_connection_error(self, test_settings):
"""Test that get_response retries on ConnectionError and returns on success."""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
mock_instance = MagicMock()
mock_instance.aget_response = AsyncMock(
side_effect=[
ConnectionError("Connection refused"),
" Summary text ",
]
)
with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
result = await llm.get_response("Prompt", ["text"])
assert result == "Summary text"
assert mock_instance.aget_response.call_count == 2
@pytest.mark.asyncio
async def test_get_response_exhausts_retries(self, test_settings):
"""Test that get_response raises RetryException after retry attempts exceeded."""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
mock_instance = MagicMock()
mock_instance.aget_response = AsyncMock(
side_effect=ConnectionError("Connection refused")
)
with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
with pytest.raises(RetryException, match="Retry attempts exceeded"):
await llm.get_response("Prompt", ["text"])
assert mock_instance.aget_response.call_count == 3
@pytest.mark.asyncio
async def test_get_response_returns_empty_string_without_retry(self, test_settings):
"""Empty or whitespace-only LLM response must return '' and not raise RetryException.
retry() must return falsy results (e.g. '' from get_response) instead of
treating them as 'no result' and retrying until RetryException.
"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
mock_instance = MagicMock()
mock_instance.aget_response = AsyncMock(return_value=" \n ") # strip() -> ""
with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
result = await llm.get_response("Prompt", ["text"])
assert result == ""
assert mock_instance.aget_response.call_count == 1
class TestTextsInclusion: class TestTextsInclusion:
"""Test that texts parameter is included in the prompt sent to astructured_predict""" """Test that texts parameter is included in the prompt sent to astructured_predict"""

View File

@@ -49,6 +49,15 @@ async def test_retry_httpx(httpx_mock):
) )
@pytest.mark.asyncio
async def test_retry_402_stops_by_default(httpx_mock):
"""402 (payment required / no credits) is in default retry_httpx_status_stop — do not retry."""
httpx_mock.add_response(status_code=402, json={"error": "insufficient_credits"})
async with httpx.AsyncClient() as client:
with pytest.raises(RetryHTTPException):
await retry(client.get)("https://test_url", retry_timeout=5)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_retry_normal(): async def test_retry_normal():
left = 3 left = 3

View File

@@ -231,3 +231,81 @@ async def test_dailyco_recording_uses_multitrack_pipeline(client):
{"s3_key": k} for k in track_keys {"s3_key": k} for k in track_keys
] ]
mock_file_pipeline.delay.assert_not_called() mock_file_pipeline.delay.assert_not_called()
@pytest.mark.usefixtures("setup_database")
@pytest.mark.asyncio
async def test_reprocess_error_transcript_passes_force(client):
    """When transcript status is 'error', reprocess passes force=True to start fresh workflow."""
    from datetime import datetime, timezone

    from reflector.db.recordings import Recording, recordings_controller
    from reflector.db.rooms import rooms_controller
    from reflector.db.transcripts import transcripts_controller

    # Room that hosted the recording.
    room = await rooms_controller.add(
        name="test-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )
    transcript = await transcripts_controller.add(
        "",
        source_kind="room",
        source_language="en",
        target_language="en",
        user_id="test-user",
        share_mode="public",
        room_id=room.id,
    )
    recording = await recordings_controller.create(
        Recording(
            bucket_name="daily-bucket",
            object_key="recordings/test-room",
            meeting_id="test-meeting",
            track_keys=["recordings/test-room/track1.webm"],
            recorded_at=datetime.now(timezone.utc),
        )
    )
    # Simulate a previously failed workflow run on this transcript.
    await transcripts_controller.update(
        transcript,
        {
            "recording_id": recording.id,
            "status": "error",
            "workflow_run_id": "old-failed-run",
        },
    )

    with (
        patch(
            "reflector.services.transcript_process.task_is_scheduled_or_active"
        ) as celery_check_mock,
        patch(
            "reflector.services.transcript_process.HatchetClientManager"
        ) as hatchet_manager_mock,
        patch(
            "reflector.views.transcripts_process.dispatch_transcript_processing",
            new_callable=AsyncMock,
        ) as dispatch_mock,
    ):
        # No Celery task in flight; prior Hatchet run reported as FAILED.
        celery_check_mock.return_value = False

        from hatchet_sdk.clients.rest.models import V1TaskStatus

        hatchet_manager_mock.get_workflow_run_status = AsyncMock(
            return_value=V1TaskStatus.FAILED
        )

        response = await client.post(f"/transcripts/{transcript.id}/process")

        assert response.status_code == 200
        dispatch_mock.assert_called_once()
        assert dispatch_mock.call_args.kwargs["force"] is True