self-review (no-mistakes)

Igor Loskutov
2025-12-16 16:04:52 -05:00
parent e81e0cb5c3
commit fce0945564
9 changed files with 1034 additions and 1041 deletions

View File

@@ -1,37 +1,71 @@
"""Hatchet Python client wrapper."""
"""Hatchet Python client wrapper.
from hatchet_sdk import Hatchet
Uses singleton pattern because:
1. Hatchet client maintains persistent gRPC connections for workflow registration
2. Creating multiple clients would cause registration conflicts and resource leaks
3. The SDK is designed for a single client instance per process
4. Tests use `HatchetClientManager.reset()` to isolate state between tests
"""
import logging
from hatchet_sdk import ClientConfig, Hatchet
from reflector.logger import logger
from reflector.settings import settings
class HatchetClientManager:
"""Singleton manager for Hatchet client connections."""
"""Singleton manager for Hatchet client connections.
The singleton pattern is used because the Hatchet SDK maintains persistent gRPC
connections for workflow registration, and multiple clients would conflict.
For testing, use the `reset()` method or the `reset_hatchet_client` fixture
to ensure test isolation.
"""
_instance: Hatchet | None = None
@classmethod
def get_client(cls) -> Hatchet:
"""Get or create the Hatchet client."""
"""Get or create the Hatchet client.
Configures root logger so all logger.info() calls in workflows
appear in the Hatchet dashboard logs.
"""
if cls._instance is None:
if not settings.HATCHET_CLIENT_TOKEN:
raise ValueError("HATCHET_CLIENT_TOKEN must be set")
# Pass root logger to Hatchet so workflow logs appear in dashboard
root_logger = logging.getLogger()
cls._instance = Hatchet(
debug=settings.HATCHET_DEBUG,
config=ClientConfig(logger=root_logger),
)
return cls._instance
@classmethod
async def start_workflow(
cls, workflow_name: str, input_data: dict, key: str | None = None
cls,
workflow_name: str,
input_data: dict,
additional_metadata: dict | None = None,
) -> str:
"""Start a workflow and return the workflow run ID."""
"""Start a workflow and return the workflow run ID.
Args:
workflow_name: Name of the workflow to trigger.
input_data: Input data for the workflow run.
additional_metadata: Optional metadata for filtering in dashboard
(e.g., transcript_id, recording_id).
"""
client = cls.get_client()
result = await client.runs.aio_create(
workflow_name,
input_data,
additional_metadata=additional_metadata,
)
# SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
return result.run.metadata.id
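
A minimal caller sketch of the new start_workflow signature, built only from what this hunk shows. The workflow name matches the dispatch code further down; the transcript ID and payload shape are made up for illustration.

from reflector.hatchet.client import HatchetClientManager

async def dispatch_sketch() -> str:
    # Requires settings.HATCHET_CLIENT_TOKEN; get_client() raises ValueError otherwise.
    run_id = await HatchetClientManager.start_workflow(
        workflow_name="DiarizationPipeline",
        input_data={"transcript_id": "tr_123"},           # illustrative payload
        additional_metadata={"transcript_id": "tr_123"},  # filterable in the dashboard
    )
    return run_id  # the ID string read from result.run.metadata.id (SDK v1.21+)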

File diff suppressed because it is too large.

View File

@@ -0,0 +1,123 @@
"""
Pydantic models for Hatchet workflow task return types.
Provides static typing for all task outputs, enabling type checking
and better IDE support.
"""
from typing import Any
from pydantic import BaseModel
# ============================================================================
# Track Processing Results (track_processing.py)
# ============================================================================
class PadTrackResult(BaseModel):
"""Result from pad_track task."""
padded_url: str
size: int
track_index: int
class TranscribeTrackResult(BaseModel):
"""Result from transcribe_track task."""
words: list[dict[str, Any]]
track_index: int
# ============================================================================
# Diarization Pipeline Results (diarization_pipeline.py)
# ============================================================================
class RecordingResult(BaseModel):
"""Result from get_recording task."""
id: str | None
mtg_session_id: str | None
room_name: str | None
duration: float
class ParticipantsResult(BaseModel):
"""Result from get_participants task."""
participants: list[dict[str, Any]]
num_tracks: int
source_language: str
target_language: str
class ProcessTracksResult(BaseModel):
"""Result from process_tracks task."""
all_words: list[dict[str, Any]]
padded_urls: list[str | None]
word_count: int
num_tracks: int
target_language: str
created_padded_files: list[str]
class MixdownResult(BaseModel):
"""Result from mixdown_tracks task."""
audio_key: str
duration: float
tracks_mixed: int
class WaveformResult(BaseModel):
"""Result from generate_waveform task."""
waveform_generated: bool
class TopicsResult(BaseModel):
"""Result from detect_topics task."""
topics: list[dict[str, Any]]
class TitleResult(BaseModel):
"""Result from generate_title task."""
title: str | None
class SummaryResult(BaseModel):
"""Result from generate_summary task."""
summary: str | None
short_summary: str | None
class FinalizeResult(BaseModel):
"""Result from finalize task."""
status: str
class ConsentResult(BaseModel):
"""Result from cleanup_consent task."""
consent_checked: bool
class ZulipResult(BaseModel):
"""Result from post_zulip task."""
zulip_message_id: int | None = None
skipped: bool = False
class WebhookResult(BaseModel):
"""Result from send_webhook task."""
webhook_sent: bool
skipped: bool = False
response_code: int | None = None
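
A short sketch of how these models are meant to be used: a task constructs one instead of a raw dict, attribute access is type-checked, and model_dump() (standard Pydantic v2, already used elsewhere in this commit) turns it back into a plain dict where one is needed. The URL and size below are invented.

from reflector.hatchet.workflows.models import PadTrackResult

result = PadTrackResult(padded_url="https://example.test/padded.webm", size=1024, track_index=0)
assert result.track_index == 0   # typed attribute access
payload = result.model_dump()    # back to a plain dict for dict-based consumers
assert payload["padded_url"].endswith(".webm")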

View File

@@ -18,8 +18,17 @@ from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress_async
from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
from reflector.logger import logger
def _to_dict(output) -> dict:
"""Convert task output to dict, handling both dict and Pydantic model returns."""
if isinstance(output, dict):
return output
return output.model_dump()
# Audio constants matching existing pipeline
OPUS_STANDARD_SAMPLE_RATE = 48000
OPUS_DEFAULT_BIT_RATE = 64000
@@ -161,7 +170,7 @@ def _apply_audio_padding_to_file(
@track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
async def pad_track(input: TrackInput, ctx: Context) -> dict:
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
"""Pad single audio track with silence for alignment.
Extracts stream.start_time from WebM container metadata and applies
@@ -213,11 +222,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
await emit_progress_async(
input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
)
return {
"padded_url": source_url,
"size": 0,
"track_index": input.track_index,
}
return PadTrackResult(
padded_url=source_url,
size=0,
track_index=input.track_index,
)
# Create temp file for padded output
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
@@ -265,11 +274,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
)
return {
"padded_url": padded_url,
"size": file_size,
"track_index": input.track_index,
}
return PadTrackResult(
padded_url=padded_url,
size=file_size,
track_index=input.track_index,
)
except Exception as e:
logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
@@ -282,7 +291,7 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
@track_workflow.task(
parents=[pad_track], execution_timeout=timedelta(seconds=600), retries=3
)
async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
"""Transcribe audio track using GPU (Modal.com) or local Whisper."""
logger.info(
"[Hatchet] transcribe_track",
@@ -295,7 +304,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
)
try:
pad_result = ctx.task_output(pad_track)
pad_result = _to_dict(ctx.task_output(pad_track))
audio_url = pad_result.get("padded_url")
if not audio_url:
@@ -324,10 +333,10 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
)
return {
"words": words,
"track_index": input.track_index,
}
return TranscribeTrackResult(
words=words,
track_index=input.track_index,
)
except Exception as e:
logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)
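
Why _to_dict exists: judging by the helper's docstring, ctx.task_output() can hand a downstream task either the Pydantic model the upstream task returned or an already-serialized dict, so consumers normalize before calling .get(). A minimal illustration, written as if it ran inside this module where both names are in scope; the URL is invented.

model_output = PadTrackResult(padded_url="https://example.test/padded.webm", size=10, track_index=1)
plain_output = {"padded_url": "https://example.test/padded.webm", "size": 10, "track_index": 1}
# Both shapes normalize to the same mapping, so .get("padded_url") works either way.
assert _to_dict(model_output) == _to_dict(plain_output)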

View File

@@ -224,6 +224,26 @@ def dispatch_transcript_processing(
transcript, {"workflow_run_id": None}
)
# Re-fetch transcript to check for concurrent dispatch (TOCTOU protection)
transcript = await transcripts_controller.get_by_id(
config.transcript_id
)
if transcript and transcript.workflow_run_id:
# Another process started a workflow between validation and now
try:
status = await HatchetClientManager.get_workflow_run_status(
transcript.workflow_run_id
)
if "RUNNING" in status or "QUEUED" in status:
logger.info(
"Concurrent workflow detected, skipping dispatch",
workflow_id=transcript.workflow_run_id,
)
return transcript.workflow_run_id
except Exception:
# If we can't get status, proceed with new workflow
pass
workflow_id = await HatchetClientManager.start_workflow(
workflow_name="DiarizationPipeline",
input_data={
@@ -234,6 +254,11 @@ def dispatch_transcript_processing(
"transcript_id": config.transcript_id,
"room_id": config.room_id,
},
additional_metadata={
"transcript_id": config.transcript_id,
"recording_id": config.recording_id,
"daily_recording_id": config.recording_id,
},
)
if transcript:

View File

@@ -302,6 +302,11 @@ async def _process_multitrack_recording_inner(
"transcript_id": transcript.id,
"room_id": room.id,
},
additional_metadata={
"transcript_id": transcript.id,
"recording_id": recording_id,
"daily_recording_id": recording_id,
},
)
logger.info(
"Started Hatchet workflow",

View File

@@ -527,6 +527,22 @@ def fake_mp3_upload():
yield
@pytest.fixture(autouse=True)
def reset_hatchet_client():
"""Reset HatchetClientManager singleton before and after each test.
This ensures test isolation - each test starts with a fresh client state.
The fixture is autouse=True so it applies to all tests automatically.
"""
from reflector.hatchet.client import HatchetClientManager
# Reset before test
HatchetClientManager.reset()
yield
# Reset after test to clean up
HatchetClientManager.reset()
@pytest.fixture
async def fake_transcript_with_topics(tmpdir, client):
import shutil
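
The reset() classmethod this fixture relies on is not shown in the diff. Assuming it only needs to drop the cached singleton so the next get_client() call rebuilds it, a sketch would be roughly:

# Hypothetical sketch of HatchetClientManager.reset(); the real method is not in this hunk.
@classmethod
def reset(cls) -> None:
    """Drop the cached client so each test starts from a clean state."""
    cls._instance = None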

View File

@@ -2,6 +2,9 @@
Tests for HatchetClientManager error handling and validation.
Only tests that catch real bugs - not mock verification tests.
Note: The `reset_hatchet_client` fixture (autouse=True in conftest.py)
automatically resets the singleton before and after each test.
"""
from unittest.mock import AsyncMock, MagicMock, patch
@@ -18,8 +21,6 @@ async def test_hatchet_client_can_replay_handles_exception():
"""
from reflector.hatchet.client import HatchetClientManager
HatchetClientManager._instance = None
with patch("reflector.hatchet.client.settings") as mock_settings:
mock_settings.HATCHET_CLIENT_TOKEN = "test-token"
mock_settings.HATCHET_DEBUG = False
@@ -37,8 +38,6 @@ async def test_hatchet_client_can_replay_handles_exception():
# Should return False on error (workflow might be gone)
assert can_replay is False
HatchetClientManager._instance = None
def test_hatchet_client_raises_without_token():
"""Test that get_client raises ValueError without token.
@@ -48,12 +47,8 @@ def test_hatchet_client_raises_without_token():
"""
from reflector.hatchet.client import HatchetClientManager
HatchetClientManager._instance = None
with patch("reflector.hatchet.client.settings") as mock_settings:
mock_settings.HATCHET_CLIENT_TOKEN = None
with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"):
HatchetClientManager.get_client()
HatchetClientManager._instance = None