self-review round

Igor Loskutov
2025-12-17 13:05:08 -05:00
parent 7a29c742c5
commit 6ae8f1d870
10 changed files with 79 additions and 86 deletions

View File

@@ -104,7 +104,7 @@ services:
SERVER_GRPC_PORT: "7077"
SERVER_URL: http://localhost:8889
SERVER_AUTH_SET_EMAIL_VERIFIED: "t"
SERVER_DEFAULT_ENGINE_VERSION: "V1"
# SERVER_DEFAULT_ENGINE_VERSION: "V1" # default
SERVER_INTERNAL_CLIENT_INTERNAL_GRPC_BROADCAST_ADDRESS: hatchet:7077
volumes:
- ./data/hatchet-config:/config

View File

@@ -57,9 +57,29 @@ uv run /app/requeue_uploaded_file.py TRANSCRIPT_ID
After resetting the Hatchet database:
### Option A: Automatic (CLI)
```bash
# Get default tenant ID and create token in one command
TENANT_ID=$(docker compose exec -T postgres psql -U reflector -d hatchet -t -c \
"SELECT id FROM \"Tenant\" WHERE slug = 'default';" | tr -d ' \n') && \
TOKEN=$(docker compose exec -T hatchet /hatchet-admin token create \
--config /config --tenant-id "$TENANT_ID" 2>/dev/null | tr -d '\n') && \
echo "HATCHET_CLIENT_TOKEN=$TOKEN"
```
Copy the output to `server/.env`.
### Option B: Manual (UI)
1. Create an API token at http://localhost:8889 → Settings → API Tokens
2. Update `server/.env`: `HATCHET_CLIENT_TOKEN=<new-token>`
3. Restart: `docker compose restart server hatchet-worker`
### Then restart workers
```bash
docker compose restart server hatchet-worker
```
Workflows register automatically when the `hatchet-worker` service starts.
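For context, "register automatically" means the worker process builds a Hatchet worker with both workflows and starts it. A minimal sketch, mirroring `reflector/hatchet/run_workers.py` later in this commit; the blocking `worker.start()` call is assumed to be the SDK's entry point:
```python
# Sketch only: roughly what the hatchet-worker process does at startup
# (names taken from reflector/hatchet/run_workers.py in this commit).
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.workflows import diarization_pipeline, track_workflow

hatchet = HatchetClientManager.get_client()  # requires HATCHET_CLIENT_TOKEN
worker = hatchet.worker(
    "reflector-diarization-worker",
    workflows=[diarization_pipeline, track_workflow],
)
worker.start()  # registers both workflows with the Hatchet server, then polls
```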

View File

@@ -135,8 +135,5 @@ select = [
"reflector/processors/summary/summary_builder.py" = ["E501"]
"gpu/modal_deployments/**.py" = ["PLC0415"]
"reflector/tools/**.py" = ["PLC0415"]
"reflector/hatchet/run_workers.py" = ["PLC0415"]
"reflector/hatchet/workflows/**.py" = ["PLC0415"]
"reflector/views/hatchet.py" = ["PLC0415"]
"migrations/versions/**.py" = ["PLC0415"]
"tests/**.py" = ["PLC0415"]

View File

@@ -1,6 +1,6 @@
"""Hatchet workflow orchestration for Reflector."""
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress, emit_progress_async
from reflector.hatchet.progress import emit_progress_async
__all__ = ["HatchetClientManager", "emit_progress", "emit_progress_async"]
__all__ = ["HatchetClientManager", "emit_progress_async"]

View File

@@ -20,11 +20,7 @@ from reflector.settings import settings
class HatchetClientManager:
"""Singleton manager for Hatchet client connections.
Singleton pattern is used because Hatchet SDK maintains persistent gRPC
connections for workflow registration, and multiple clients would conflict.
For testing, use the `reset()` method or the `reset_hatchet_client` fixture
to ensure test isolation.
See module docstring for rationale. For test isolation, use `reset()`.
"""
_instance: Hatchet | None = None
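The docstring summarizes the pattern; a minimal sketch of the singleton accessor it describes, with the settings-driven construction elided (the real `get_client()` also wires the token and other options from `settings`):
```python
from hatchet_sdk import Hatchet

from reflector.settings import settings


class HatchetClientManager:
    """Sketch of the singleton accessor described above; construction
    details (token, namespace, etc.) are elided."""

    _instance: Hatchet | None = None

    @classmethod
    def get_client(cls) -> Hatchet:
        if cls._instance is None:
            # Assumed constructor kwargs; the real code builds the client
            # from settings (HATCHET_CLIENT_TOKEN, HATCHET_DEBUG, ...).
            cls._instance = Hatchet(debug=settings.HATCHET_DEBUG)
        return cls._instance

    @classmethod
    def reset(cls) -> None:
        # Drop the cached client so each test starts from a clean slate.
        cls._instance = None
```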
@@ -68,25 +64,21 @@ class HatchetClientManager:
input_data,
additional_metadata=additional_metadata,
)
# SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
return result.run.metadata.id
@classmethod
async def get_workflow_run_status(cls, workflow_run_id: str) -> V1TaskStatus:
"""Get workflow run status."""
client = cls.get_client()
return await client.runs.aio_get_status(workflow_run_id)
@classmethod
async def cancel_workflow(cls, workflow_run_id: str) -> None:
"""Cancel a workflow."""
client = cls.get_client()
await client.runs.aio_cancel(workflow_run_id)
logger.info("[Hatchet] Cancelled workflow", workflow_run_id=workflow_run_id)
@classmethod
async def replay_workflow(cls, workflow_run_id: str) -> None:
"""Replay a failed workflow."""
client = cls.get_client()
await client.runs.aio_replay(workflow_run_id)
logger.info("[Hatchet] Replaying workflow", workflow_run_id=workflow_run_id)
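A hedged usage sketch of the run-management helpers above, e.g. from an API endpoint; the `V1TaskStatus` member name checked here is an assumption:
```python
# Hedged usage sketch of HatchetClientManager's run helpers.
from reflector.hatchet.client import HatchetClientManager


async def replay_if_failed(workflow_run_id: str) -> None:
    status = await HatchetClientManager.get_workflow_run_status(workflow_run_id)
    # V1TaskStatus is an SDK enum; the member name below is assumed.
    if status.name == "FAILED":
        await HatchetClientManager.replay_workflow(workflow_run_id)
    # Cancellation works the same way:
    # await HatchetClientManager.cancel_workflow(workflow_run_id)
```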

View File

@@ -1,13 +1,12 @@
"""Progress event emission for Hatchet workers."""
import asyncio
from typing import Literal
from reflector.db.transcripts import PipelineProgressData
from reflector.logger import logger
from reflector.ws_manager import get_ws_manager
# Step mapping for progress tracking (matches Conductor pipeline)
# Step mapping for progress tracking
PIPELINE_STEPS = {
"get_recording": 1,
"get_participants": 2,
@@ -63,45 +62,6 @@ async def _emit_progress_async(
)
def emit_progress(
transcript_id: str,
step: str,
status: Literal["pending", "in_progress", "completed", "failed"],
workflow_id: str | None = None,
) -> None:
"""Emit a pipeline progress event (sync wrapper for Hatchet workers).
Args:
transcript_id: The transcript ID to emit progress for
step: The current step name (e.g., "transcribe_track")
status: The step status
workflow_id: Optional workflow run ID
"""
try:
# Get or create event loop for sync context
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
# Already in async context, schedule the coroutine
asyncio.create_task(
_emit_progress_async(transcript_id, step, status, workflow_id)
)
else:
# Not in async context, run synchronously
asyncio.run(_emit_progress_async(transcript_id, step, status, workflow_id))
except Exception as e:
# Progress emission should never break the pipeline
logger.warning(
"[Hatchet Progress] Failed to emit progress event",
error=str(e),
transcript_id=transcript_id,
step=step,
)
async def emit_progress_async(
transcript_id: str,
step: str,

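Only the async emitter remains after this change; a usage sketch from inside a task body, matching the call sites in the pipeline hunks below (`some_task` is illustrative, not from the repo):
```python
# Sketch: reporting progress from a Hatchet task with the remaining async API.
from reflector.hatchet.progress import emit_progress_async


async def some_task(input, ctx):  # illustrative task for this example
    await emit_progress_async(
        input.transcript_id, "transcribe_track", "in_progress", ctx.workflow_run_id
    )
    ...  # do the actual work
    await emit_progress_async(
        input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
    )
```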
View File

@@ -1,5 +1,6 @@
"""
Run Hatchet workers for the diarization pipeline.
Runs as a separate process, just like Celery workers.
Usage:
uv run -m reflector.hatchet.run_workers
@@ -30,8 +31,9 @@ def main() -> None:
debug=settings.HATCHET_DEBUG,
)
# Import here (not top-level) - workflow imports trigger HatchetClientManager.get_client()
# which requires HATCHET_CLIENT_TOKEN; must validate settings first
# Import here (not top-level) - workflow modules call HatchetClientManager.get_client()
# at module level because Hatchet SDK decorators (@workflow.task) bind at import time.
# Can't use lazy init: decorators need the client object when function is defined.
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
from reflector.hatchet.workflows import ( # noqa: PLC0415
diarization_pipeline,
@@ -40,13 +42,11 @@ def main() -> None:
hatchet = HatchetClientManager.get_client()
# Create worker with both workflows
worker = hatchet.worker(
"reflector-diarization-worker",
workflows=[diarization_pipeline, track_workflow],
)
# Handle graceful shutdown
def shutdown_handler(signum: int, frame) -> None:
logger.info("Received shutdown signal, stopping workers...")
# Worker cleanup happens automatically on exit
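For context, the part of `main()` that this hunk truncates is essentially signal wiring plus the blocking start call; a hedged continuation sketch (the repo's exact calls may differ):
```python
# Continuation sketch of main(): wire the handler defined above, then block
# in the SDK's worker loop until a shutdown signal arrives.
import signal

signal.signal(signal.SIGINT, shutdown_handler)
signal.signal(signal.SIGTERM, shutdown_handler)

worker.start()  # assumed blocking entry point; polls Hatchet and runs tasks
```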

View File

@@ -3,6 +3,10 @@ Hatchet main workflow: DiarizationPipeline
Multitrack diarization pipeline for Daily.co recordings.
Orchestrates the full processing flow from recording metadata to final transcript.
Note: This file uses deferred imports (inside functions/tasks) intentionally.
Hatchet workers run in forked processes; fresh imports per task ensure DB connections
are not shared across forks, avoiding connection pooling issues.
"""
import asyncio
@@ -92,9 +96,9 @@ diarization_pipeline = hatchet.workflow(
@asynccontextmanager
async def fresh_db_connection():
"""Context manager for database connections in Hatchet workers."""
import databases
import databases # noqa: PLC0415
from reflector.db import _database_context
from reflector.db import _database_context # noqa: PLC0415
_database_context.set(None)
db = databases.Database(settings.DATABASE_URL)
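The hunk stops right after the connection object is created; a plausible completion of `fresh_db_connection()`, consistent with the fork-isolation note at the top of the file (the real teardown may differ):
```python
    # Plausible continuation (not shown in the hunk): bind the fresh
    # connection to the context, hand it to the task, and always tear it
    # down so nothing is shared across forked workers.
    await db.connect()
    _database_context.set(db)
    try:
        yield db
    finally:
        await db.disconnect()
        _database_context.set(None)
```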
@@ -116,7 +120,7 @@ async def set_workflow_error_status(transcript_id: str) -> bool:
"""
try:
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
await transcripts_controller.set_status(transcript_id, "error")
logger.info(
@@ -205,7 +209,7 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
# Set transcript status to "processing" at workflow start
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
@@ -273,7 +277,7 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
# Get transcript and reset events/topics/participants
async with fresh_db_connection():
from reflector.db.transcripts import (
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptParticipant,
transcripts_controller,
)
@@ -651,7 +655,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
# Update transcript with audio_location
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
@@ -689,7 +693,10 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id
)
from reflector.db.transcripts import TranscriptWaveform, transcripts_controller
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptWaveform,
transcripts_controller,
)
# Cleanup temporary padded S3 files (deferred until after mixdown)
track_data = _to_dict(ctx.task_output(process_tracks))
@@ -776,8 +783,11 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
words = track_data.get("all_words", [])
target_language = track_data.get("target_language", "en")
from reflector.db.transcripts import TranscriptTopic, transcripts_controller
from reflector.processors.types import (
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import ( # noqa: PLC0415
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
)
@@ -841,7 +851,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
topics_data = _to_dict(ctx.task_output(detect_topics))
topics = topics_data.get("topics", [])
from reflector.db.transcripts import (
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptFinalTitle,
transcripts_controller,
)
@@ -901,7 +911,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
topics_data = _to_dict(ctx.task_output(detect_topics))
topics = topics_data.get("topics", [])
from reflector.db.transcripts import (
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
transcripts_controller,
@@ -995,9 +1005,14 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
all_words = track_data.get("all_words", [])
async with fresh_db_connection():
from reflector.db.transcripts import TranscriptText, transcripts_controller
from reflector.processors.types import Transcript as TranscriptType
from reflector.processors.types import Word
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptText,
transcripts_controller,
)
from reflector.processors.types import ( # noqa: PLC0415
Transcript as TranscriptType,
)
from reflector.processors.types import Word # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript is None:
@@ -1057,8 +1072,8 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
)
async with fresh_db_connection():
from reflector.db.meetings import meetings_controller
from reflector.db.transcripts import transcripts_controller
from reflector.db.meetings import meetings_controller # noqa: PLC0415
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript and transcript.meeting_id:
@@ -1099,7 +1114,7 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
return ZulipResult(zulip_message_id=None, skipped=True)
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
@@ -1135,8 +1150,8 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
return WebhookResult(webhook_sent=False, skipped=True)
async with fresh_db_connection():
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import transcripts_controller
from reflector.db.rooms import rooms_controller # noqa: PLC0415
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
room = await rooms_controller.get_by_id(input.room_id)
transcript = await transcripts_controller.get_by_id(input.transcript_id)

View File

@@ -3,6 +3,15 @@ Hatchet child workflow: TrackProcessing
Handles individual audio track processing: padding and transcription.
Spawned dynamically by the main diarization pipeline for each track.
Architecture note: This is a separate workflow (not inline tasks in DiarizationPipeline)
because Hatchet workflow DAGs are defined statically, but the number of tracks varies
at runtime. Child workflow spawning via `aio_run()` + `asyncio.gather()` is the
standard pattern for dynamic fan-out. See `process_tracks` in diarization_pipeline.py.
Note: This file uses deferred imports (inside tasks) intentionally.
Hatchet workers run in forked processes; fresh imports per task ensure
storage/DB connections are not shared across forks.
"""
import math
@@ -190,8 +199,8 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
try:
# Create fresh storage instance to avoid aioboto3 fork issues
from reflector.settings import settings
from reflector.storage.storage_aws import AwsStorage
from reflector.settings import settings # noqa: PLC0415
from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415
storage = AwsStorage(
aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
@@ -312,8 +321,8 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe
raise ValueError("Missing padded_key from pad_track")
# Presign URL on demand (avoids stale URLs on workflow replay)
from reflector.settings import settings
from reflector.storage.storage_aws import AwsStorage
from reflector.settings import settings # noqa: PLC0415
from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415
storage = AwsStorage(
aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
@@ -329,7 +338,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe
bucket=bucket_name,
)
from reflector.pipelines.transcription_helpers import (
from reflector.pipelines.transcription_helpers import ( # noqa: PLC0415
transcribe_file_with_processor,
)
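The architecture note at the top of this file describes the dynamic fan-out; a minimal sketch of that pattern as it would appear in `process_tracks`, with `TrackInput` field names that are illustrative rather than taken from the repo:
```python
import asyncio

# Sketch of the fan-out described in the architecture note: one child
# TrackProcessing run per track, awaited together. TrackInput field names
# are illustrative.
async def run_all_tracks(track_keys: list[str], transcript_id: str) -> list:
    runs = [
        track_workflow.aio_run(
            TrackInput(track_index=i, s3_key=key, transcript_id=transcript_id)
        )
        for i, key in enumerate(track_keys)
    ]
    return await asyncio.gather(*runs)
```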

View File

@@ -188,7 +188,7 @@ async def dispatch_transcript_processing(
room = await rooms_controller.get_by_id(config.room_id)
room_forces_hatchet = room.use_hatchet if room else False
# Start durable workflow if enabled (Hatchet or Conductor)
# Start durable workflow if enabled (Hatchet)
# or if room has use_hatchet=True
use_hatchet = settings.HATCHET_ENABLED or room_forces_hatchet
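A sketch of the branch this comment guards; the condition comes from the hunk above, while the two dispatch helpers are placeholder names since the branch bodies are not shown here:
```python
# Illustrative only: start_hatchet_pipeline / start_legacy_pipeline are
# placeholder names; only the condition itself comes from the hunk above.
use_hatchet = settings.HATCHET_ENABLED or room_forces_hatchet
if use_hatchet:
    await start_hatchet_pipeline(config)   # durable Hatchet workflow
else:
    await start_legacy_pipeline(config)    # existing non-durable path
```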