mirror of https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 20:59:05 +00:00

Commit: self-review (no-mistakes)
server/reflector/hatchet/client.py
@@ -1,37 +1,71 @@
-"""Hatchet Python client wrapper."""
+"""Hatchet Python client wrapper.
+
+Uses singleton pattern because:
+1. Hatchet client maintains persistent gRPC connections for workflow registration
+2. Creating multiple clients would cause registration conflicts and resource leaks
+3. The SDK is designed for a single client instance per process
+4. Tests use `HatchetClientManager.reset()` to isolate state between tests
+"""
 
-from hatchet_sdk import Hatchet
+import logging
+
+from hatchet_sdk import ClientConfig, Hatchet
 
 from reflector.logger import logger
 from reflector.settings import settings
 
 
 class HatchetClientManager:
-    """Singleton manager for Hatchet client connections."""
+    """Singleton manager for Hatchet client connections.
+
+    Singleton pattern is used because Hatchet SDK maintains persistent gRPC
+    connections for workflow registration, and multiple clients would conflict.
+
+    For testing, use the `reset()` method or the `reset_hatchet_client` fixture
+    to ensure test isolation.
+    """
 
     _instance: Hatchet | None = None
 
     @classmethod
     def get_client(cls) -> Hatchet:
-        """Get or create the Hatchet client."""
+        """Get or create the Hatchet client.
+
+        Configures root logger so all logger.info() calls in workflows
+        appear in the Hatchet dashboard logs.
+        """
         if cls._instance is None:
             if not settings.HATCHET_CLIENT_TOKEN:
                 raise ValueError("HATCHET_CLIENT_TOKEN must be set")
+
+            # Pass root logger to Hatchet so workflow logs appear in dashboard
+            root_logger = logging.getLogger()
+            cls._instance = Hatchet(
+                debug=settings.HATCHET_DEBUG,
+                config=ClientConfig(logger=root_logger),
+            )
         return cls._instance
 
     @classmethod
     async def start_workflow(
-        cls, workflow_name: str, input_data: dict, key: str | None = None
+        cls,
+        workflow_name: str,
+        input_data: dict,
+        additional_metadata: dict | None = None,
     ) -> str:
-        """Start a workflow and return the workflow run ID."""
+        """Start a workflow and return the workflow run ID.
+
+        Args:
+            workflow_name: Name of the workflow to trigger.
+            input_data: Input data for the workflow run.
+            additional_metadata: Optional metadata for filtering in dashboard
+                (e.g., transcript_id, recording_id).
+        """
         client = cls.get_client()
+        result = await client.runs.aio_create(
+            workflow_name,
+            input_data,
+            additional_metadata=additional_metadata,
+        )
+        # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
+        return result.run.metadata.id
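
Usage sketch for the manager above (identifiers taken from this diff; the
`reset()` classmethod is referenced in the docstrings but not shown in this
hunk, and the IDs are made up):

    from reflector.hatchet.client import HatchetClientManager

    # Repeated calls reuse the same client and its gRPC connections
    client_a = HatchetClientManager.get_client()
    client_b = HatchetClientManager.get_client()
    assert client_a is client_b

    # Inside an async context: trigger a run with dashboard-filterable metadata
    run_id = await HatchetClientManager.start_workflow(
        workflow_name="DiarizationPipeline",
        input_data={"transcript_id": "t-123"},
        additional_metadata={"transcript_id": "t-123"},
    )

    # Tests call reset() so each test starts from a fresh singleton
    HatchetClientManager.reset()
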
File diff suppressed because it is too large.

server/reflector/hatchet/workflows/models.py (new file, +123)
@@ -0,0 +1,123 @@
+"""
+Pydantic models for Hatchet workflow task return types.
+
+Provides static typing for all task outputs, enabling type checking
+and better IDE support.
+"""
+
+from typing import Any
+
+from pydantic import BaseModel
+
+# ============================================================================
+# Track Processing Results (track_processing.py)
+# ============================================================================
+
+
+class PadTrackResult(BaseModel):
+    """Result from pad_track task."""
+
+    padded_url: str
+    size: int
+    track_index: int
+
+
+class TranscribeTrackResult(BaseModel):
+    """Result from transcribe_track task."""
+
+    words: list[dict[str, Any]]
+    track_index: int
+
+
+# ============================================================================
+# Diarization Pipeline Results (diarization_pipeline.py)
+# ============================================================================
+
+
+class RecordingResult(BaseModel):
+    """Result from get_recording task."""
+
+    id: str | None
+    mtg_session_id: str | None
+    room_name: str | None
+    duration: float
+
+
+class ParticipantsResult(BaseModel):
+    """Result from get_participants task."""
+
+    participants: list[dict[str, Any]]
+    num_tracks: int
+    source_language: str
+    target_language: str
+
+
+class ProcessTracksResult(BaseModel):
+    """Result from process_tracks task."""
+
+    all_words: list[dict[str, Any]]
+    padded_urls: list[str | None]
+    word_count: int
+    num_tracks: int
+    target_language: str
+    created_padded_files: list[str]
+
+
+class MixdownResult(BaseModel):
+    """Result from mixdown_tracks task."""
+
+    audio_key: str
+    duration: float
+    tracks_mixed: int
+
+
+class WaveformResult(BaseModel):
+    """Result from generate_waveform task."""
+
+    waveform_generated: bool
+
+
+class TopicsResult(BaseModel):
+    """Result from detect_topics task."""
+
+    topics: list[dict[str, Any]]
+
+
+class TitleResult(BaseModel):
+    """Result from generate_title task."""
+
+    title: str | None
+
+
+class SummaryResult(BaseModel):
+    """Result from generate_summary task."""
+
+    summary: str | None
+    short_summary: str | None
+
+
+class FinalizeResult(BaseModel):
+    """Result from finalize task."""
+
+    status: str
+
+
+class ConsentResult(BaseModel):
+    """Result from cleanup_consent task."""
+
+    consent_checked: bool
+
+
+class ZulipResult(BaseModel):
+    """Result from post_zulip task."""
+
+    zulip_message_id: int | None = None
+    skipped: bool = False
+
+
+class WebhookResult(BaseModel):
+    """Result from send_webhook task."""
+
+    webhook_sent: bool
+    skipped: bool = False
+    response_code: int | None = None
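
A quick illustration of how these models round-trip through plain dicts
(standard Pydantic v2 behavior; the values are made up):

    result = PadTrackResult(padded_url="s3://bucket/padded-0.webm", size=2048, track_index=0)
    payload = result.model_dump()  # {'padded_url': 's3://...', 'size': 2048, 'track_index': 0}
    assert PadTrackResult(**payload) == result
    # A malformed payload fails fast with a ValidationError instead of
    # propagating bad data to downstream tasks.
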
server/reflector/hatchet/workflows/track_processing.py
@@ -18,8 +18,17 @@ from pydantic import BaseModel
 
 from reflector.hatchet.client import HatchetClientManager
 from reflector.hatchet.progress import emit_progress_async
+from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
 from reflector.logger import logger
 
 
+def _to_dict(output) -> dict:
+    """Convert task output to dict, handling both dict and Pydantic model returns."""
+    if isinstance(output, dict):
+        return output
+    return output.model_dump()
+
+
 # Audio constants matching existing pipeline
 OPUS_STANDARD_SAMPLE_RATE = 48000
 OPUS_DEFAULT_BIT_RATE = 64000
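
The `_to_dict` helper keeps dict-returning and model-returning tasks
interchangeable for downstream consumers; a small demonstration (made-up
values):

    pad = PadTrackResult(padded_url="s3://bucket/padded-0.webm", size=0, track_index=0)
    assert _to_dict(pad)["padded_url"] == "s3://bucket/padded-0.webm"
    assert _to_dict({"padded_url": "u", "size": 0, "track_index": 0})["size"] == 0
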
@@ -161,7 +170,7 @@ def _apply_audio_padding_to_file(
 
 
 @track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
-async def pad_track(input: TrackInput, ctx: Context) -> dict:
+async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
     """Pad single audio track with silence for alignment.
 
     Extracts stream.start_time from WebM container metadata and applies
@@ -213,11 +222,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
             await emit_progress_async(
                 input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
             )
-            return {
-                "padded_url": source_url,
-                "size": 0,
-                "track_index": input.track_index,
-            }
+            return PadTrackResult(
+                padded_url=source_url,
+                size=0,
+                track_index=input.track_index,
+            )
 
         # Create temp file for padded output
         with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
@@ -265,11 +274,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
             input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
         )
 
-        return {
-            "padded_url": padded_url,
-            "size": file_size,
-            "track_index": input.track_index,
-        }
+        return PadTrackResult(
+            padded_url=padded_url,
+            size=file_size,
+            track_index=input.track_index,
+        )
 
     except Exception as e:
         logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
@@ -282,7 +291,7 @@ async def pad_track(input: TrackInput, ctx: Context) -> dict:
 @track_workflow.task(
     parents=[pad_track], execution_timeout=timedelta(seconds=600), retries=3
 )
-async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
+async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
     """Transcribe audio track using GPU (Modal.com) or local Whisper."""
     logger.info(
         "[Hatchet] transcribe_track",
@@ -295,7 +304,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
     )
 
     try:
-        pad_result = ctx.task_output(pad_track)
+        pad_result = _to_dict(ctx.task_output(pad_track))
         audio_url = pad_result.get("padded_url")
 
         if not audio_url:
@@ -324,10 +333,10 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> dict:
             input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
         )
 
-        return {
-            "words": words,
-            "track_index": input.track_index,
-        }
+        return TranscribeTrackResult(
+            words=words,
+            track_index=input.track_index,
+        )
 
     except Exception as e:
         logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)

@@ -224,6 +224,26 @@ def dispatch_transcript_processing(
                 transcript, {"workflow_run_id": None}
             )
 
+        # Re-fetch transcript to check for concurrent dispatch (TOCTOU protection)
+        transcript = await transcripts_controller.get_by_id(
+            config.transcript_id
+        )
+        if transcript and transcript.workflow_run_id:
+            # Another process started a workflow between validation and now
+            try:
+                status = await HatchetClientManager.get_workflow_run_status(
+                    transcript.workflow_run_id
+                )
+                if "RUNNING" in status or "QUEUED" in status:
+                    logger.info(
+                        "Concurrent workflow detected, skipping dispatch",
+                        workflow_id=transcript.workflow_run_id,
+                    )
+                    return transcript.workflow_run_id
+            except Exception:
+                # If we can't get status, proceed with new workflow
+                pass
+
         workflow_id = await HatchetClientManager.start_workflow(
             workflow_name="DiarizationPipeline",
             input_data={
@@ -234,6 +254,11 @@ def dispatch_transcript_processing(
                 "transcript_id": config.transcript_id,
                 "room_id": config.room_id,
             },
+            additional_metadata={
+                "transcript_id": config.transcript_id,
+                "recording_id": config.recording_id,
+                "daily_recording_id": config.recording_id,
+            },
         )
 
     if transcript:
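
A minimal sketch of the re-fetch guard above, with hypothetical injected
callables standing in for the controller and status lookup (not the real
APIs):

    async def dispatch_once(fetch, get_status, start):
        """fetch/get_status/start are injected async callables."""
        transcript = await fetch()  # re-read the latest persisted state
        if transcript and transcript.get("workflow_run_id"):
            try:
                status = await get_status(transcript["workflow_run_id"])
                if "RUNNING" in status or "QUEUED" in status:
                    # A concurrent dispatcher already started this workflow
                    return transcript["workflow_run_id"]
            except Exception:
                pass  # status unknown: fall through and start a new run
        return await start()

This narrows (but cannot fully close) the window between the initial
validation and the workflow start; the status check makes a duplicate
dispatch harmless while the earlier run is still live.
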
@@ -302,6 +302,11 @@ async def _process_multitrack_recording_inner(
             "transcript_id": transcript.id,
             "room_id": room.id,
         },
+        additional_metadata={
+            "transcript_id": transcript.id,
+            "recording_id": recording_id,
+            "daily_recording_id": recording_id,
+        },
     )
     logger.info(
         "Started Hatchet workflow",

conftest.py
@@ -527,6 +527,22 @@ def fake_mp3_upload():
     yield
 
 
+@pytest.fixture(autouse=True)
+def reset_hatchet_client():
+    """Reset HatchetClientManager singleton before and after each test.
+
+    This ensures test isolation - each test starts with a fresh client state.
+    The fixture is autouse=True so it applies to all tests automatically.
+    """
+    from reflector.hatchet.client import HatchetClientManager
+
+    # Reset before test
+    HatchetClientManager.reset()
+    yield
+    # Reset after test to clean up
+    HatchetClientManager.reset()
+
+
 @pytest.fixture
 async def fake_transcript_with_topics(tmpdir, client):
     import shutil

@@ -2,6 +2,9 @@
 Tests for HatchetClientManager error handling and validation.
 
 Only tests that catch real bugs - not mock verification tests.
+
+Note: The `reset_hatchet_client` fixture (autouse=True in conftest.py)
+automatically resets the singleton before and after each test.
 """
 
 from unittest.mock import AsyncMock, MagicMock, patch
@@ -18,8 +21,6 @@ async def test_hatchet_client_can_replay_handles_exception():
     """
     from reflector.hatchet.client import HatchetClientManager
 
-    HatchetClientManager._instance = None
-
     with patch("reflector.hatchet.client.settings") as mock_settings:
         mock_settings.HATCHET_CLIENT_TOKEN = "test-token"
         mock_settings.HATCHET_DEBUG = False
@@ -37,8 +38,6 @@ async def test_hatchet_client_can_replay_handles_exception():
         # Should return False on error (workflow might be gone)
         assert can_replay is False
 
-    HatchetClientManager._instance = None
-
 
 def test_hatchet_client_raises_without_token():
     """Test that get_client raises ValueError without token.
@@ -48,12 +47,8 @@ def test_hatchet_client_raises_without_token():
     """
     from reflector.hatchet.client import HatchetClientManager
 
-    HatchetClientManager._instance = None
-
     with patch("reflector.hatchet.client.settings") as mock_settings:
         mock_settings.HATCHET_CLIENT_TOKEN = None
 
         with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"):
             HatchetClientManager.get_client()
-
-    HatchetClientManager._instance = None