diff --git a/docker-compose.selfhosted.yml b/docker-compose.selfhosted.yml index e4308883..dc364b49 100644 --- a/docker-compose.selfhosted.yml +++ b/docker-compose.selfhosted.yml @@ -137,6 +137,7 @@ services: postgres: image: postgres:17-alpine restart: unless-stopped + command: ["postgres", "-c", "max_connections=200"] environment: POSTGRES_USER: reflector POSTGRES_PASSWORD: reflector diff --git a/server/reflector/hatchet/constants.py b/server/reflector/hatchet/constants.py index b3810ad6..8f9c5465 100644 --- a/server/reflector/hatchet/constants.py +++ b/server/reflector/hatchet/constants.py @@ -39,5 +39,12 @@ TIMEOUT_MEDIUM = ( 300 # Single LLM calls, waveform generation (5m for slow LLM responses) ) TIMEOUT_LONG = 180 # Action items (larger context LLM) -TIMEOUT_AUDIO = 720 # Audio processing: padding, mixdown -TIMEOUT_HEAVY = 600 # Transcription, fan-out LLM tasks +TIMEOUT_TITLE = 300 # generate_title (single LLM call; doc: reduce from 600s) +TIMEOUT_AUDIO = 720 # Audio processing: padding, mixdown (Hatchet execution_timeout) +TIMEOUT_AUDIO_HTTP = ( + 660 # httpx timeout for pad_track — below 720 so Hatchet doesn't race +) +TIMEOUT_HEAVY = 600 # Transcription, fan-out LLM tasks (Hatchet execution_timeout) +TIMEOUT_HEAVY_HTTP = ( + 540 # httpx timeout for transcribe_track — below 600 so Hatchet doesn't race +) diff --git a/server/reflector/hatchet/error_classification.py b/server/reflector/hatchet/error_classification.py new file mode 100644 index 00000000..5d26a8fd --- /dev/null +++ b/server/reflector/hatchet/error_classification.py @@ -0,0 +1,74 @@ +"""Classify exceptions as non-retryable for Hatchet workflows. + +When a task raises NonRetryableException (or an exception classified as +non-retryable and re-raised as such), Hatchet stops immediately — no further +retries. Used by with_error_handling to avoid wasting retries on config errors, +auth failures, corrupt data, etc. +""" + +# Optional dependencies: only classify if the exception type is available. +# This avoids hard dependency on openai/av/botocore for code paths that don't use them. +try: + import openai +except ImportError: + openai = None # type: ignore[assignment] + +try: + import av +except ImportError: + av = None # type: ignore[assignment] + +try: + from botocore.exceptions import ClientError as BotoClientError +except ImportError: + BotoClientError = None # type: ignore[misc, assignment] + +from hatchet_sdk import NonRetryableException +from httpx import HTTPStatusError + +from reflector.llm import LLMParseError + +# HTTP status codes that won't change on retry (auth, not found, payment, payload) +NON_RETRYABLE_HTTP_STATUSES = {401, 402, 403, 404, 413} +NON_RETRYABLE_S3_CODES = {"AccessDenied", "NoSuchBucket", "NoSuchKey"} + + +def is_non_retryable(e: BaseException) -> bool: + """Return True if the exception should stop Hatchet retries immediately. + + Hard failures (config, auth, missing resource, corrupt data) return True. + Transient errors (timeouts, 5xx, 429, connection) return False. + """ + if isinstance(e, NonRetryableException): + return True + + # Config/input errors + if isinstance(e, (ValueError, TypeError)): + return True + + # HTTP status codes that won't change on retry + if isinstance(e, HTTPStatusError): + return e.response.status_code in NON_RETRYABLE_HTTP_STATUSES + + # OpenAI auth errors + if openai is not None and isinstance(e, openai.AuthenticationError): + return True + + # LLM parse failures (already retried internally) + if isinstance(e, LLMParseError): + return True + + # S3 permission/existence errors + if BotoClientError is not None and isinstance(e, BotoClientError): + code = e.response.get("Error", {}).get("Code", "") + return code in NON_RETRYABLE_S3_CODES + + # Corrupt audio (PyAV) — AVError in some versions; fallback to InvalidDataError + if av is not None: + av_error = getattr(av, "AVError", None) or getattr( + getattr(av, "error", None), "InvalidDataError", None + ) + if av_error is not None and isinstance(e, av_error): + return True + + return False diff --git a/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py b/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py index 3fa725b6..9be49ca4 100644 --- a/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py +++ b/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py @@ -27,6 +27,7 @@ from hatchet_sdk import ( ConcurrencyExpression, ConcurrencyLimitStrategy, Context, + NonRetryableException, ) from hatchet_sdk.labels import DesiredWorkerLabel from pydantic import BaseModel @@ -43,8 +44,10 @@ from reflector.hatchet.constants import ( TIMEOUT_LONG, TIMEOUT_MEDIUM, TIMEOUT_SHORT, + TIMEOUT_TITLE, TaskName, ) +from reflector.hatchet.error_classification import is_non_retryable from reflector.hatchet.workflows.models import ( ActionItemsResult, ConsentResult, @@ -216,6 +219,13 @@ def make_audio_progress_logger( R = TypeVar("R") +def _successful_run_results( + results: list[dict[str, Any] | BaseException], +) -> list[dict[str, Any]]: + """Return only successful (non-exception) results from aio_run_many(return_exceptions=True).""" + return [r for r in results if not isinstance(r, BaseException)] + + def with_error_handling( step_name: TaskName, set_error_status: bool = True ) -> Callable[ @@ -243,8 +253,12 @@ def with_error_handling( error=str(e), exc_info=True, ) - if set_error_status: - await set_workflow_error_status(input.transcript_id) + if is_non_retryable(e): + # Hard fail: stop retries, set error status, fail workflow + if set_error_status: + await set_workflow_error_status(input.transcript_id) + raise NonRetryableException(str(e)) from e + # Transient: do not set error status — Hatchet will retry raise return wrapper # type: ignore[return-value] @@ -253,7 +267,10 @@ def with_error_handling( @daily_multitrack_pipeline.task( - execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3 + execution_timeout=timedelta(seconds=TIMEOUT_SHORT), + retries=3, + backoff_factor=2.0, + backoff_max_seconds=10, ) @with_error_handling(TaskName.GET_RECORDING) async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult: @@ -309,6 +326,8 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult: parents=[get_recording], execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3, + backoff_factor=2.0, + backoff_max_seconds=10, ) @with_error_handling(TaskName.GET_PARTICIPANTS) async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult: @@ -412,6 +431,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe parents=[get_participants], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3, + backoff_factor=2.0, + backoff_max_seconds=30, ) @with_error_handling(TaskName.PROCESS_TRACKS) async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult: @@ -435,7 +456,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes for i, track in enumerate(input.tracks) ] - results = await track_workflow.aio_run_many(bulk_runs) + results = await track_workflow.aio_run_many(bulk_runs, return_exceptions=True) target_language = participants_result.target_language @@ -443,7 +464,18 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes padded_tracks = [] created_padded_files = set() - for result in results: + for i, result in enumerate(results): + if isinstance(result, BaseException): + logger.error( + "[Hatchet] process_tracks: track workflow failed, failing step", + transcript_id=input.transcript_id, + track_index=i, + error=str(result), + ) + ctx.log(f"process_tracks: track {i} failed ({result}), failing step") + raise ValueError( + f"Track {i} workflow failed after retries: {result!s}" + ) from result transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK]) track_words.append(transcribe_result.words) @@ -481,7 +513,9 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes @daily_multitrack_pipeline.task( parents=[process_tracks], execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), - retries=3, + retries=2, + backoff_factor=2.0, + backoff_max_seconds=15, desired_worker_labels={ "pool": DesiredWorkerLabel( value="cpu-heavy", @@ -593,6 +627,8 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: parents=[mixdown_tracks], execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), retries=3, + backoff_factor=2.0, + backoff_max_seconds=10, ) @with_error_handling(TaskName.GENERATE_WAVEFORM) async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult: @@ -661,6 +697,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul parents=[process_tracks], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3, + backoff_factor=2.0, + backoff_max_seconds=30, ) @with_error_handling(TaskName.DETECT_TOPICS) async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: @@ -722,11 +760,22 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: for chunk in chunks ] - results = await topic_chunk_workflow.aio_run_many(bulk_runs) + results = await topic_chunk_workflow.aio_run_many(bulk_runs, return_exceptions=True) - topic_chunks = [ - TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]) for result in results - ] + topic_chunks: list[TopicChunkResult] = [] + for i, result in enumerate(results): + if isinstance(result, BaseException): + logger.error( + "[Hatchet] detect_topics: chunk workflow failed, failing step", + transcript_id=input.transcript_id, + chunk_index=i, + error=str(result), + ) + ctx.log(f"detect_topics: chunk {i} failed ({result}), failing step") + raise ValueError( + f"Topic chunk {i} workflow failed after retries: {result!s}" + ) from result + topic_chunks.append(TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC])) async with fresh_db_connection(): transcript = await transcripts_controller.get_by_id(input.transcript_id) @@ -764,8 +813,10 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: @daily_multitrack_pipeline.task( parents=[detect_topics], - execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), + execution_timeout=timedelta(seconds=TIMEOUT_TITLE), retries=3, + backoff_factor=2.0, + backoff_max_seconds=15, ) @with_error_handling(TaskName.GENERATE_TITLE) async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult: @@ -830,7 +881,9 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult: @daily_multitrack_pipeline.task( parents=[detect_topics], execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), - retries=3, + retries=5, + backoff_factor=2.0, + backoff_max_seconds=30, ) @with_error_handling(TaskName.EXTRACT_SUBJECTS) async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult: @@ -909,6 +962,8 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult parents=[extract_subjects], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3, + backoff_factor=2.0, + backoff_max_seconds=30, ) @with_error_handling(TaskName.PROCESS_SUBJECTS) async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubjectsResult: @@ -935,12 +990,24 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject for i, subject in enumerate(subjects) ] - results = await subject_workflow.aio_run_many(bulk_runs) + results = await subject_workflow.aio_run_many(bulk_runs, return_exceptions=True) - subject_summaries = [ - SubjectSummaryResult(**result[TaskName.GENERATE_DETAILED_SUMMARY]) - for result in results - ] + subject_summaries: list[SubjectSummaryResult] = [] + for i, result in enumerate(results): + if isinstance(result, BaseException): + logger.error( + "[Hatchet] process_subjects: subject workflow failed, failing step", + transcript_id=input.transcript_id, + subject_index=i, + error=str(result), + ) + ctx.log(f"process_subjects: subject {i} failed ({result}), failing step") + raise ValueError( + f"Subject {i} workflow failed after retries: {result!s}" + ) from result + subject_summaries.append( + SubjectSummaryResult(**result[TaskName.GENERATE_DETAILED_SUMMARY]) + ) ctx.log(f"process_subjects complete: {len(subject_summaries)} summaries") @@ -951,6 +1018,8 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject parents=[process_subjects], execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), retries=3, + backoff_factor=2.0, + backoff_max_seconds=15, ) @with_error_handling(TaskName.GENERATE_RECAP) async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult: @@ -1040,6 +1109,8 @@ async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult: parents=[extract_subjects], execution_timeout=timedelta(seconds=TIMEOUT_LONG), retries=3, + backoff_factor=2.0, + backoff_max_seconds=15, ) @with_error_handling(TaskName.IDENTIFY_ACTION_ITEMS) async def identify_action_items( @@ -1108,6 +1179,8 @@ async def identify_action_items( parents=[process_tracks, generate_title, generate_recap, identify_action_items], execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3, + backoff_factor=2.0, + backoff_max_seconds=5, ) @with_error_handling(TaskName.FINALIZE) async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: @@ -1177,7 +1250,11 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: @daily_multitrack_pipeline.task( - parents=[finalize], execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3 + parents=[finalize], + execution_timeout=timedelta(seconds=TIMEOUT_SHORT), + retries=3, + backoff_factor=2.0, + backoff_max_seconds=10, ) @with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False) async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult: @@ -1283,6 +1360,8 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult: parents=[cleanup_consent], execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=5, + backoff_factor=2.0, + backoff_max_seconds=15, ) @with_error_handling(TaskName.POST_ZULIP, set_error_status=False) async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult: @@ -1310,6 +1389,8 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult: parents=[cleanup_consent], execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), retries=5, + backoff_factor=2.0, + backoff_max_seconds=15, ) @with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False) async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult: @@ -1378,3 +1459,32 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult: except Exception as e: ctx.log(f"send_webhook unexpected error, continuing anyway: {e}") return WebhookResult(webhook_sent=False) + + +async def on_workflow_failure(input: PipelineInput, ctx: Context) -> None: + """Run when the workflow is truly dead (all retries exhausted). + + Sets transcript status to 'error' only if it is not already 'ended'. + Post-finalize tasks (cleanup_consent, post_zulip, send_webhook) use + set_error_status=False; if one of them fails, we must not overwrite + the 'ended' status that finalize already set. + """ + async with fresh_db_connection(): + from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 + + transcript = await transcripts_controller.get_by_id(input.transcript_id) + if transcript and transcript.status == "ended": + logger.info( + "[Hatchet] on_workflow_failure: transcript already ended, skipping error status (failure was post-finalize)", + transcript_id=input.transcript_id, + ) + ctx.log( + "on_workflow_failure: transcript already ended, skipping error status" + ) + return + await set_workflow_error_status(input.transcript_id) + + +@daily_multitrack_pipeline.on_failure_task() +async def _register_on_workflow_failure(input: PipelineInput, ctx: Context) -> None: + await on_workflow_failure(input, ctx) diff --git a/server/reflector/hatchet/workflows/padding_workflow.py b/server/reflector/hatchet/workflows/padding_workflow.py index 0e0056ed..d63125c4 100644 --- a/server/reflector/hatchet/workflows/padding_workflow.py +++ b/server/reflector/hatchet/workflows/padding_workflow.py @@ -34,7 +34,12 @@ padding_workflow = hatchet.workflow( ) -@padding_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3) +@padding_workflow.task( + execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), + retries=3, + backoff_factor=2.0, + backoff_max_seconds=30, +) async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult: """Pad audio track with silence based on WebM container start_time.""" ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}") diff --git a/server/reflector/hatchet/workflows/subject_processing.py b/server/reflector/hatchet/workflows/subject_processing.py index 1985a15c..df7d8f2f 100644 --- a/server/reflector/hatchet/workflows/subject_processing.py +++ b/server/reflector/hatchet/workflows/subject_processing.py @@ -13,7 +13,7 @@ from hatchet_sdk.rate_limit import RateLimit from pydantic import BaseModel from reflector.hatchet.client import HatchetClientManager -from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_MEDIUM +from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_HEAVY from reflector.hatchet.workflows.models import SubjectSummaryResult from reflector.logger import logger from reflector.processors.summary.prompts import ( @@ -41,8 +41,10 @@ subject_workflow = hatchet.workflow( @subject_workflow.task( - execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), - retries=3, + execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), + retries=5, + backoff_factor=2.0, + backoff_max_seconds=60, rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=2)], ) async def generate_detailed_summary( diff --git a/server/reflector/hatchet/workflows/topic_chunk_processing.py b/server/reflector/hatchet/workflows/topic_chunk_processing.py index 82b68569..e7c90252 100644 --- a/server/reflector/hatchet/workflows/topic_chunk_processing.py +++ b/server/reflector/hatchet/workflows/topic_chunk_processing.py @@ -50,7 +50,9 @@ topic_chunk_workflow = hatchet.workflow( @topic_chunk_workflow.task( execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM), - retries=3, + retries=5, + backoff_factor=2.0, + backoff_max_seconds=60, rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=1)], ) async def detect_chunk_topic(input: TopicChunkInput, ctx: Context) -> TopicChunkResult: diff --git a/server/reflector/hatchet/workflows/track_processing.py b/server/reflector/hatchet/workflows/track_processing.py index f2ca2d6b..2458ee0c 100644 --- a/server/reflector/hatchet/workflows/track_processing.py +++ b/server/reflector/hatchet/workflows/track_processing.py @@ -44,7 +44,12 @@ hatchet = HatchetClientManager.get_client() track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput) -@track_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3) +@track_workflow.task( + execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), + retries=3, + backoff_factor=2.0, + backoff_max_seconds=30, +) async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: """Pad single audio track with silence for alignment. @@ -137,7 +142,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: @track_workflow.task( - parents=[pad_track], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3 + parents=[pad_track], + execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), + retries=3, + backoff_factor=2.0, + backoff_max_seconds=30, ) async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult: """Transcribe audio track using GPU (Modal.com) or local Whisper.""" diff --git a/server/reflector/llm.py b/server/reflector/llm.py index 1ea7dd44..add13f4d 100644 --- a/server/reflector/llm.py +++ b/server/reflector/llm.py @@ -65,10 +65,25 @@ class LLM: async def get_response( self, prompt: str, texts: list[str], tone_name: str | None = None ) -> str: - """Get a text response using TreeSummarize for non-function-calling models""" - summarizer = TreeSummarize(verbose=False) - response = await summarizer.aget_response(prompt, texts, tone_name=tone_name) - return str(response).strip() + """Get a text response using TreeSummarize for non-function-calling models. + + Uses the same retry() wrapper as get_structured_response for transient + network errors (connection, timeout, OSError) with exponential backoff. + """ + + async def _call(): + summarizer = TreeSummarize(verbose=False) + response = await summarizer.aget_response( + prompt, texts, tone_name=tone_name + ) + return str(response).strip() + + return await retry(_call)( + retry_attempts=3, + retry_backoff_interval=1.0, + retry_backoff_max=30.0, + retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError), + ) async def get_structured_response( self, diff --git a/server/reflector/processors/audio_padding_modal.py b/server/reflector/processors/audio_padding_modal.py index 825dc95f..7fb33537 100644 --- a/server/reflector/processors/audio_padding_modal.py +++ b/server/reflector/processors/audio_padding_modal.py @@ -7,7 +7,7 @@ import os import httpx -from reflector.hatchet.constants import TIMEOUT_AUDIO +from reflector.hatchet.constants import TIMEOUT_AUDIO_HTTP from reflector.logger import logger from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor @@ -60,7 +60,7 @@ class AudioPaddingModalProcessor(AudioPaddingProcessor): headers["Authorization"] = f"Bearer {self.modal_api_key}" try: - async with httpx.AsyncClient(timeout=TIMEOUT_AUDIO) as client: + async with httpx.AsyncClient(timeout=TIMEOUT_AUDIO_HTTP) as client: response = await client.post( url, headers=headers, diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 3a608aef..5509d9bd 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -55,7 +55,9 @@ class Settings(BaseSettings): WHISPER_FILE_MODEL: str = "tiny" TRANSCRIPT_URL: str | None = None TRANSCRIPT_TIMEOUT: int = 90 - TRANSCRIPT_FILE_TIMEOUT: int = 600 + TRANSCRIPT_FILE_TIMEOUT: int = ( + 540 # Below Hatchet TIMEOUT_HEAVY (600) to avoid timeout race + ) # Audio Transcription: modal backend TRANSCRIPT_MODAL_API_KEY: str | None = None diff --git a/server/reflector/utils/retry.py b/server/reflector/utils/retry.py index 52a2f87b..54a3b233 100644 --- a/server/reflector/utils/retry.py +++ b/server/reflector/utils/retry.py @@ -30,6 +30,7 @@ def retry(fn): "retry_httpx_status_stop", ( 401, # auth issue + 402, # payment required / no credits — needs human action 404, # not found 413, # payload too large 418, # teapot @@ -58,8 +59,9 @@ def retry(fn): result = await fn(*args, **kwargs) if isinstance(result, Response): result.raise_for_status() - if result: - return result + # Return any result including falsy (e.g. "" from get_response); + # only retry on exception, not on empty string. + return result except HTTPStatusError as e: retry_logger.exception(e) status_code = e.response.status_code diff --git a/server/reflector/views/transcripts_process.py b/server/reflector/views/transcripts_process.py index 325d82e7..1e1f7201 100644 --- a/server/reflector/views/transcripts_process.py +++ b/server/reflector/views/transcripts_process.py @@ -50,5 +50,8 @@ async def transcript_process( if isinstance(config, ProcessError): raise HTTPException(status_code=500, detail=config.detail) else: - await dispatch_transcript_processing(config) + # When transcript is in error state, force a new workflow instead of replaying + # (replay would re-run from failure point with same conditions and likely fail again) + force = transcript.status == "error" + await dispatch_transcript_processing(config, force=force) return ProcessStatus(status="ok") diff --git a/server/tests/test_hatchet_error_handling.py b/server/tests/test_hatchet_error_handling.py new file mode 100644 index 00000000..c22ffe1e --- /dev/null +++ b/server/tests/test_hatchet_error_handling.py @@ -0,0 +1,303 @@ +""" +Tests for Hatchet error handling: NonRetryable classification and error status. + +These tests encode the desired behavior from the Hatchet Workflow Analysis doc: +- Transient exceptions: do NOT set error status (let Hatchet retry; user stays on "processing"). +- Hard-fail exceptions: set error status and re-raise as NonRetryableException (stop retries). +- on_failure_task: sets error status when workflow is truly dead. + +Run before the fix: some tests fail (reproducing the issues). +Run after the fix: all tests pass. +""" + +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from hatchet_sdk import NonRetryableException + +from reflector.hatchet.error_classification import is_non_retryable +from reflector.llm import LLMParseError + +# --- Tests for is_non_retryable() (pass once error_classification exists) --- + + +def test_is_non_retryable_returns_true_for_value_error(): + """ValueError (e.g. missing config) should stop retries.""" + assert is_non_retryable(ValueError("DAILY_API_KEY must be set")) is True + + +def test_is_non_retryable_returns_true_for_type_error(): + """TypeError (bad input) should stop retries.""" + assert is_non_retryable(TypeError("expected str")) is True + + +def test_is_non_retryable_returns_true_for_http_401(): + """HTTP 401 auth error should stop retries.""" + resp = MagicMock() + resp.status_code = 401 + err = httpx.HTTPStatusError("Unauthorized", request=MagicMock(), response=resp) + assert is_non_retryable(err) is True + + +def test_is_non_retryable_returns_true_for_http_402(): + """HTTP 402 (no credits) should stop retries.""" + resp = MagicMock() + resp.status_code = 402 + err = httpx.HTTPStatusError("Payment Required", request=MagicMock(), response=resp) + assert is_non_retryable(err) is True + + +def test_is_non_retryable_returns_true_for_http_404(): + """HTTP 404 should stop retries.""" + resp = MagicMock() + resp.status_code = 404 + err = httpx.HTTPStatusError("Not Found", request=MagicMock(), response=resp) + assert is_non_retryable(err) is True + + +def test_is_non_retryable_returns_false_for_http_503(): + """HTTP 503 is transient; retries are useful.""" + resp = MagicMock() + resp.status_code = 503 + err = httpx.HTTPStatusError( + "Service Unavailable", request=MagicMock(), response=resp + ) + assert is_non_retryable(err) is False + + +def test_is_non_retryable_returns_false_for_timeout(): + """Timeout is transient.""" + assert is_non_retryable(httpx.TimeoutException("timed out")) is False + + +def test_is_non_retryable_returns_true_for_llm_parse_error(): + """LLMParseError after internal retries should stop.""" + from pydantic import BaseModel + + class _Dummy(BaseModel): + pass + + assert is_non_retryable(LLMParseError(_Dummy, "Failed to parse", 3)) is True + + +def test_is_non_retryable_returns_true_for_non_retryable_exception(): + """Already-wrapped NonRetryableException should stay non-retryable.""" + assert is_non_retryable(NonRetryableException("custom")) is True + + +# --- Tests for with_error_handling (need pipeline module with patch) --- + + +@pytest.fixture(scope="module") +def pipeline_module(): + """Import daily_multitrack_pipeline with Hatchet client mocked.""" + with patch("reflector.hatchet.client.settings") as s: + s.HATCHET_CLIENT_TOKEN = "test-token" + s.HATCHET_DEBUG = False + mock_client = MagicMock() + mock_client.workflow.return_value = MagicMock() + with patch( + "reflector.hatchet.client.HatchetClientManager.get_client", + return_value=mock_client, + ): + from reflector.hatchet.workflows import daily_multitrack_pipeline + + return daily_multitrack_pipeline + + +@pytest.fixture +def mock_input(): + """Minimal PipelineInput for decorator tests.""" + from reflector.hatchet.workflows.daily_multitrack_pipeline import PipelineInput + + return PipelineInput( + recording_id="rec-1", + tracks=[], + bucket_name="bucket", + transcript_id="ts-123", + room_id=None, + ) + + +@pytest.fixture +def mock_ctx(): + """Minimal Context-like object.""" + ctx = MagicMock() + ctx.log = MagicMock() + return ctx + + +@pytest.mark.asyncio +async def test_with_error_handling_transient_does_not_set_error_status( + pipeline_module, mock_input, mock_ctx +): + """Transient exception must NOT set error status (so user stays on 'processing' during retries). + + Before fix: set_workflow_error_status is called on every exception → FAIL. + After fix: not called for transient → PASS. + """ + from reflector.hatchet.workflows.daily_multitrack_pipeline import ( + TaskName, + with_error_handling, + ) + + async def failing_task(input, ctx): + raise httpx.TimeoutException("timed out") + + wrapped = with_error_handling(TaskName.GET_RECORDING)(failing_task) + + with patch( + "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status", + new_callable=AsyncMock, + ) as mock_set_error: + with pytest.raises(httpx.TimeoutException): + await wrapped(mock_input, mock_ctx) + + # Desired: do NOT set error status for transient (Hatchet will retry) + mock_set_error.assert_not_called() + + +@pytest.mark.asyncio +async def test_with_error_handling_hard_fail_raises_non_retryable_and_sets_status( + pipeline_module, mock_input, mock_ctx +): + """Hard-fail (e.g. ValueError) must set error status and re-raise NonRetryableException. + + Before fix: raises ValueError, set_workflow_error_status called → test would need to expect ValueError. + After fix: raises NonRetryableException, set_workflow_error_status called → PASS. + """ + from reflector.hatchet.workflows.daily_multitrack_pipeline import ( + TaskName, + with_error_handling, + ) + + async def failing_task(input, ctx): + raise ValueError("PADDING_URL must be set") + + wrapped = with_error_handling(TaskName.GET_RECORDING)(failing_task) + + with patch( + "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status", + new_callable=AsyncMock, + ) as mock_set_error: + with pytest.raises(NonRetryableException) as exc_info: + await wrapped(mock_input, mock_ctx) + + assert "PADDING_URL" in str(exc_info.value) + mock_set_error.assert_called_once_with("ts-123") + + +@pytest.mark.asyncio +async def test_with_error_handling_set_error_status_false_never_sets_status( + pipeline_module, mock_input, mock_ctx +): + """When set_error_status=False, we must never set error status (e.g. cleanup_consent).""" + from reflector.hatchet.workflows.daily_multitrack_pipeline import ( + TaskName, + with_error_handling, + ) + + async def failing_task(input, ctx): + raise ValueError("something went wrong") + + wrapped = with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)( + failing_task + ) + + with patch( + "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status", + new_callable=AsyncMock, + ) as mock_set_error: + with pytest.raises((ValueError, NonRetryableException)): + await wrapped(mock_input, mock_ctx) + + mock_set_error.assert_not_called() + + +@asynccontextmanager +async def _noop_db_context(): + """Async context manager that yields without touching the DB (for unit tests).""" + yield None + + +@pytest.mark.asyncio +async def test_on_failure_task_sets_error_status(pipeline_module, mock_input, mock_ctx): + """When workflow fails and transcript is not yet 'ended', on_failure sets status to 'error'.""" + from reflector.hatchet.workflows.daily_multitrack_pipeline import ( + on_workflow_failure, + ) + + transcript_processing = MagicMock() + transcript_processing.status = "processing" + + with patch( + "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection", + _noop_db_context, + ): + with patch( + "reflector.db.transcripts.transcripts_controller.get_by_id", + new_callable=AsyncMock, + return_value=transcript_processing, + ): + with patch( + "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status", + new_callable=AsyncMock, + ) as mock_set_error: + await on_workflow_failure(mock_input, mock_ctx) + mock_set_error.assert_called_once_with(mock_input.transcript_id) + + +@pytest.mark.asyncio +async def test_on_failure_task_does_not_overwrite_ended( + pipeline_module, mock_input, mock_ctx +): + """When workflow fails after finalize (e.g. post_zulip), do not overwrite 'ended' with 'error'. + + cleanup_consent, post_zulip, send_webhook use set_error_status=False; if one fails, + on_workflow_failure must not set status to 'error' when transcript is already 'ended'. + """ + from reflector.hatchet.workflows.daily_multitrack_pipeline import ( + on_workflow_failure, + ) + + transcript_ended = MagicMock() + transcript_ended.status = "ended" + + with patch( + "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection", + _noop_db_context, + ): + with patch( + "reflector.db.transcripts.transcripts_controller.get_by_id", + new_callable=AsyncMock, + return_value=transcript_ended, + ): + with patch( + "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status", + new_callable=AsyncMock, + ) as mock_set_error: + await on_workflow_failure(mock_input, mock_ctx) + mock_set_error.assert_not_called() + + +# --- Tests for fan-out helper (_successful_run_results) --- + + +def test_successful_run_results_filters_exceptions(): + """_successful_run_results returns only non-exception items from aio_run_many(return_exceptions=True).""" + from reflector.hatchet.workflows.daily_multitrack_pipeline import ( + _successful_run_results, + ) + + results = [ + {"key": "ok1"}, + ValueError("child failed"), + {"key": "ok2"}, + RuntimeError("another"), + ] + successful = _successful_run_results(results) + assert len(successful) == 2 + assert successful[0] == {"key": "ok1"} + assert successful[1] == {"key": "ok2"} diff --git a/server/tests/test_llm_retry.py b/server/tests/test_llm_retry.py index 5c28ff5f..a5ce995e 100644 --- a/server/tests/test_llm_retry.py +++ b/server/tests/test_llm_retry.py @@ -1,6 +1,6 @@ """Tests for LLM structured output with astructured_predict + reflection retry""" -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import BaseModel, Field, ValidationError @@ -252,6 +252,63 @@ class TestNetworkErrorRetries: assert mock_settings.llm.astructured_predict.call_count == 3 +class TestGetResponseRetries: + """Test that get_response() uses the same retry() wrapper for transient errors.""" + + @pytest.mark.asyncio + async def test_get_response_retries_on_connection_error(self, test_settings): + """Test that get_response retries on ConnectionError and returns on success.""" + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + + mock_instance = MagicMock() + mock_instance.aget_response = AsyncMock( + side_effect=[ + ConnectionError("Connection refused"), + " Summary text ", + ] + ) + + with patch("reflector.llm.TreeSummarize", return_value=mock_instance): + result = await llm.get_response("Prompt", ["text"]) + + assert result == "Summary text" + assert mock_instance.aget_response.call_count == 2 + + @pytest.mark.asyncio + async def test_get_response_exhausts_retries(self, test_settings): + """Test that get_response raises RetryException after retry attempts exceeded.""" + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + + mock_instance = MagicMock() + mock_instance.aget_response = AsyncMock( + side_effect=ConnectionError("Connection refused") + ) + + with patch("reflector.llm.TreeSummarize", return_value=mock_instance): + with pytest.raises(RetryException, match="Retry attempts exceeded"): + await llm.get_response("Prompt", ["text"]) + + assert mock_instance.aget_response.call_count == 3 + + @pytest.mark.asyncio + async def test_get_response_returns_empty_string_without_retry(self, test_settings): + """Empty or whitespace-only LLM response must return '' and not raise RetryException. + + retry() must return falsy results (e.g. '' from get_response) instead of + treating them as 'no result' and retrying until RetryException. + """ + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + + mock_instance = MagicMock() + mock_instance.aget_response = AsyncMock(return_value=" \n ") # strip() -> "" + + with patch("reflector.llm.TreeSummarize", return_value=mock_instance): + result = await llm.get_response("Prompt", ["text"]) + + assert result == "" + assert mock_instance.aget_response.call_count == 1 + + class TestTextsInclusion: """Test that texts parameter is included in the prompt sent to astructured_predict""" diff --git a/server/tests/test_retry_decorator.py b/server/tests/test_retry_decorator.py index 9aacf727..1e863b7d 100644 --- a/server/tests/test_retry_decorator.py +++ b/server/tests/test_retry_decorator.py @@ -49,6 +49,15 @@ async def test_retry_httpx(httpx_mock): ) +@pytest.mark.asyncio +async def test_retry_402_stops_by_default(httpx_mock): + """402 (payment required / no credits) is in default retry_httpx_status_stop — do not retry.""" + httpx_mock.add_response(status_code=402, json={"error": "insufficient_credits"}) + async with httpx.AsyncClient() as client: + with pytest.raises(RetryHTTPException): + await retry(client.get)("https://test_url", retry_timeout=5) + + @pytest.mark.asyncio async def test_retry_normal(): left = 3 diff --git a/server/tests/test_transcripts_process.py b/server/tests/test_transcripts_process.py index 1adf996e..623015d3 100644 --- a/server/tests/test_transcripts_process.py +++ b/server/tests/test_transcripts_process.py @@ -231,3 +231,81 @@ async def test_dailyco_recording_uses_multitrack_pipeline(client): {"s3_key": k} for k in track_keys ] mock_file_pipeline.delay.assert_not_called() + + +@pytest.mark.usefixtures("setup_database") +@pytest.mark.asyncio +async def test_reprocess_error_transcript_passes_force(client): + """When transcript status is 'error', reprocess passes force=True to start fresh workflow.""" + from datetime import datetime, timezone + + from reflector.db.recordings import Recording, recordings_controller + from reflector.db.rooms import rooms_controller + from reflector.db.transcripts import transcripts_controller + + room = await rooms_controller.add( + name="test-room", + user_id="test-user", + zulip_auto_post=False, + zulip_stream="", + zulip_topic="", + is_locked=False, + room_mode="normal", + recording_type="cloud", + recording_trigger="automatic-2nd-participant", + is_shared=False, + ) + + transcript = await transcripts_controller.add( + "", + source_kind="room", + source_language="en", + target_language="en", + user_id="test-user", + share_mode="public", + room_id=room.id, + ) + + track_keys = ["recordings/test-room/track1.webm"] + recording = await recordings_controller.create( + Recording( + bucket_name="daily-bucket", + object_key="recordings/test-room", + meeting_id="test-meeting", + track_keys=track_keys, + recorded_at=datetime.now(timezone.utc), + ) + ) + + await transcripts_controller.update( + transcript, + { + "recording_id": recording.id, + "status": "error", + "workflow_run_id": "old-failed-run", + }, + ) + + with ( + patch( + "reflector.services.transcript_process.task_is_scheduled_or_active" + ) as mock_celery, + patch( + "reflector.services.transcript_process.HatchetClientManager" + ) as mock_hatchet, + patch( + "reflector.views.transcripts_process.dispatch_transcript_processing", + new_callable=AsyncMock, + ) as mock_dispatch, + ): + mock_celery.return_value = False + from hatchet_sdk.clients.rest.models import V1TaskStatus + + mock_hatchet.get_workflow_run_status = AsyncMock( + return_value=V1TaskStatus.FAILED + ) + response = await client.post(f"/transcripts/{transcript.id}/process") + + assert response.status_code == 200 + mock_dispatch.assert_called_once() + assert mock_dispatch.call_args.kwargs["force"] is True