self-review round

Igor Loskutov
2025-12-17 13:05:08 -05:00
parent 7a29c742c5
commit 6ae8f1d870
10 changed files with 79 additions and 86 deletions

View File

@@ -104,7 +104,7 @@ services:
       SERVER_GRPC_PORT: "7077"
       SERVER_URL: http://localhost:8889
       SERVER_AUTH_SET_EMAIL_VERIFIED: "t"
-      SERVER_DEFAULT_ENGINE_VERSION: "V1"
+      # SERVER_DEFAULT_ENGINE_VERSION: "V1" # default
       SERVER_INTERNAL_CLIENT_INTERNAL_GRPC_BROADCAST_ADDRESS: hatchet:7077
     volumes:
       - ./data/hatchet-config:/config

View File

@@ -57,9 +57,29 @@ uv run /app/requeue_uploaded_file.py TRANSCRIPT_ID
 After resetting the Hatchet database:
 
+### Option A: Automatic (CLI)
+
+```bash
+# Get default tenant ID and create token in one command
+TENANT_ID=$(docker compose exec -T postgres psql -U reflector -d hatchet -t -c \
+  "SELECT id FROM \"Tenant\" WHERE slug = 'default';" | tr -d ' \n') && \
+TOKEN=$(docker compose exec -T hatchet /hatchet-admin token create \
+  --config /config --tenant-id "$TENANT_ID" 2>/dev/null | tr -d '\n') && \
+echo "HATCHET_CLIENT_TOKEN=$TOKEN"
+```
+
+Copy the output to `server/.env`.
+
+### Option B: Manual (UI)
+
 1. Create API token at http://localhost:8889 → Settings → API Tokens
 2. Update `server/.env`: `HATCHET_CLIENT_TOKEN=<new-token>`
-3. Restart: `docker compose restart server hatchet-worker`
+
+### Then restart workers
+
+```bash
+docker compose restart server hatchet-worker
+```
 
 Workflows register automatically when hatchet-worker starts.

View File

@@ -135,8 +135,5 @@ select = [
"reflector/processors/summary/summary_builder.py" = ["E501"] "reflector/processors/summary/summary_builder.py" = ["E501"]
"gpu/modal_deployments/**.py" = ["PLC0415"] "gpu/modal_deployments/**.py" = ["PLC0415"]
"reflector/tools/**.py" = ["PLC0415"] "reflector/tools/**.py" = ["PLC0415"]
"reflector/hatchet/run_workers.py" = ["PLC0415"]
"reflector/hatchet/workflows/**.py" = ["PLC0415"]
"reflector/views/hatchet.py" = ["PLC0415"]
"migrations/versions/**.py" = ["PLC0415"] "migrations/versions/**.py" = ["PLC0415"]
"tests/**.py" = ["PLC0415"] "tests/**.py" = ["PLC0415"]

View File

@@ -1,6 +1,6 @@
"""Hatchet workflow orchestration for Reflector.""" """Hatchet workflow orchestration for Reflector."""
from reflector.hatchet.client import HatchetClientManager from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.progress import emit_progress, emit_progress_async from reflector.hatchet.progress import emit_progress_async
__all__ = ["HatchetClientManager", "emit_progress", "emit_progress_async"] __all__ = ["HatchetClientManager", "emit_progress_async"]

View File

@@ -20,11 +20,7 @@ from reflector.settings import settings
 class HatchetClientManager:
     """Singleton manager for Hatchet client connections.
 
-    Singleton pattern is used because Hatchet SDK maintains persistent gRPC
-    connections for workflow registration, and multiple clients would conflict.
-
-    For testing, use the `reset()` method or the `reset_hatchet_client` fixture
-    to ensure test isolation.
+    See module docstring for rationale. For test isolation, use `reset()`.
     """
 
     _instance: Hatchet | None = None
@@ -68,25 +64,21 @@ class HatchetClientManager:
             input_data,
             additional_metadata=additional_metadata,
         )
-        # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
         return result.run.metadata.id
 
     @classmethod
     async def get_workflow_run_status(cls, workflow_run_id: str) -> V1TaskStatus:
-        """Get workflow run status."""
         client = cls.get_client()
         return await client.runs.aio_get_status(workflow_run_id)
 
     @classmethod
     async def cancel_workflow(cls, workflow_run_id: str) -> None:
-        """Cancel a workflow."""
         client = cls.get_client()
         await client.runs.aio_cancel(workflow_run_id)
         logger.info("[Hatchet] Cancelled workflow", workflow_run_id=workflow_run_id)
 
     @classmethod
     async def replay_workflow(cls, workflow_run_id: str) -> None:
-        """Replay a failed workflow."""
         client = cls.get_client()
         await client.runs.aio_replay(workflow_run_id)
         logger.info("[Hatchet] Replaying workflow", workflow_run_id=workflow_run_id)

View File

@@ -1,13 +1,12 @@
"""Progress event emission for Hatchet workers.""" """Progress event emission for Hatchet workers."""
import asyncio
from typing import Literal from typing import Literal
from reflector.db.transcripts import PipelineProgressData from reflector.db.transcripts import PipelineProgressData
from reflector.logger import logger from reflector.logger import logger
from reflector.ws_manager import get_ws_manager from reflector.ws_manager import get_ws_manager
# Step mapping for progress tracking (matches Conductor pipeline) # Step mapping for progress tracking
PIPELINE_STEPS = { PIPELINE_STEPS = {
"get_recording": 1, "get_recording": 1,
"get_participants": 2, "get_participants": 2,
@@ -63,45 +62,6 @@ async def _emit_progress_async(
     )
 
 
-def emit_progress(
-    transcript_id: str,
-    step: str,
-    status: Literal["pending", "in_progress", "completed", "failed"],
-    workflow_id: str | None = None,
-) -> None:
-    """Emit a pipeline progress event (sync wrapper for Hatchet workers).
-
-    Args:
-        transcript_id: The transcript ID to emit progress for
-        step: The current step name (e.g., "transcribe_track")
-        status: The step status
-        workflow_id: Optional workflow run ID
-    """
-    try:
-        # Get or create event loop for sync context
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = None
-
-        if loop is not None and loop.is_running():
-            # Already in async context, schedule the coroutine
-            asyncio.create_task(
-                _emit_progress_async(transcript_id, step, status, workflow_id)
-            )
-        else:
-            # Not in async context, run synchronously
-            asyncio.run(_emit_progress_async(transcript_id, step, status, workflow_id))
-    except Exception as e:
-        # Progress emission should never break the pipeline
-        logger.warning(
-            "[Hatchet Progress] Failed to emit progress event",
-            error=str(e),
-            transcript_id=transcript_id,
-            step=step,
-        )
-
-
 async def emit_progress_async(
     transcript_id: str,
     step: str,
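
With the sync wrapper removed, callers are expected to be async Hatchet tasks that await the emitter directly. A minimal usage sketch (the task name and parameters are illustrative; the argument order matches the calls visible elsewhere in this commit):

```python
from reflector.hatchet.progress import emit_progress_async


async def example_task(input, ctx):  # illustrative task, not from this commit
    await emit_progress_async(
        input.transcript_id, "transcribe_track", "in_progress", ctx.workflow_run_id
    )
```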

View File

@@ -1,5 +1,6 @@
""" """
Run Hatchet workers for the diarization pipeline. Run Hatchet workers for the diarization pipeline.
Runs as a separate process, just like Celery workers.
Usage: Usage:
uv run -m reflector.hatchet.run_workers uv run -m reflector.hatchet.run_workers
@@ -30,8 +31,9 @@ def main() -> None:
         debug=settings.HATCHET_DEBUG,
     )
 
-    # Import here (not top-level) - workflow imports trigger HatchetClientManager.get_client()
-    # which requires HATCHET_CLIENT_TOKEN; must validate settings first
+    # Import here (not top-level) - workflow modules call HatchetClientManager.get_client()
+    # at module level because Hatchet SDK decorators (@workflow.task) bind at import time.
+    # Can't use lazy init: decorators need the client object when function is defined.
     from reflector.hatchet.client import HatchetClientManager  # noqa: PLC0415
     from reflector.hatchet.workflows import (  # noqa: PLC0415
         diarization_pipeline,
@@ -40,13 +42,11 @@ def main() -> None:
     hatchet = HatchetClientManager.get_client()
 
-    # Create worker with both workflows
     worker = hatchet.worker(
         "reflector-diarization-worker",
         workflows=[diarization_pipeline, track_workflow],
     )
 
-    # Handle graceful shutdown
     def shutdown_handler(signum: int, frame) -> None:
         logger.info("Received shutdown signal, stopping workers...")
         # Worker cleanup happens automatically on exit
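
The import-ordering comment above is the crux of this file: each workflow module needs a client object the moment it is imported, because the task decorator is evaluated at module load. A hedged illustration of that constraint (module, workflow, and task names below are made up; only the `hatchet.workflow(...)` / `@...task()` shape is taken from this commit):

```python
from reflector.hatchet.client import HatchetClientManager

# Evaluated at import time, so HATCHET_CLIENT_TOKEN must already be validated
hatchet = HatchetClientManager.get_client()
example_workflow = hatchet.workflow(name="ExampleWorkflow")


@example_workflow.task()  # the decorator binds to the client here, not at first run
async def example_step(input, ctx):
    ...
```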

View File

@@ -3,6 +3,10 @@ Hatchet main workflow: DiarizationPipeline
 Multitrack diarization pipeline for Daily.co recordings.
 Orchestrates the full processing flow from recording metadata to final transcript.
+
+Note: This file uses deferred imports (inside functions/tasks) intentionally.
+Hatchet workers run in forked processes; fresh imports per task ensure DB connections
+are not shared across forks, avoiding connection pooling issues.
 """
 
 import asyncio
@@ -92,9 +96,9 @@ diarization_pipeline = hatchet.workflow(
 @asynccontextmanager
 async def fresh_db_connection():
     """Context manager for database connections in Hatchet workers."""
-    import databases
-    from reflector.db import _database_context
+    import databases  # noqa: PLC0415
+    from reflector.db import _database_context  # noqa: PLC0415
 
     _database_context.set(None)
     db = databases.Database(settings.DATABASE_URL)
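
The hunk only shows the top of `fresh_db_connection`; given the fork-safety note added in this commit, the rest presumably connects a brand-new `databases.Database` and tears it down afterwards. A hedged sketch of that shape, not the actual body:

```python
import databases
from contextlib import asynccontextmanager

from reflector.db import _database_context
from reflector.settings import settings


@asynccontextmanager
async def fresh_db_connection_sketch():
    _database_context.set(None)  # discard any connection inherited from the parent process
    db = databases.Database(settings.DATABASE_URL)  # fresh pool for this task
    await db.connect()
    _database_context.set(db)
    try:
        yield
    finally:
        await db.disconnect()
        _database_context.set(None)
```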
@@ -116,7 +120,7 @@ async def set_workflow_error_status(transcript_id: str) -> bool:
""" """
try: try:
async with fresh_db_connection(): async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
await transcripts_controller.set_status(transcript_id, "error") await transcripts_controller.set_status(transcript_id, "error")
logger.info( logger.info(
@@ -205,7 +209,7 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
     # Set transcript status to "processing" at workflow start
     async with fresh_db_connection():
-        from reflector.db.transcripts import transcripts_controller
+        from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
@@ -273,7 +277,7 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
     # Get transcript and reset events/topics/participants
     async with fresh_db_connection():
-        from reflector.db.transcripts import (
+        from reflector.db.transcripts import (  # noqa: PLC0415
             TranscriptParticipant,
             transcripts_controller,
         )
@@ -651,7 +655,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     # Update transcript with audio_location
     async with fresh_db_connection():
-        from reflector.db.transcripts import transcripts_controller
+        from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
@@ -689,7 +693,10 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id
) )
from reflector.db.transcripts import TranscriptWaveform, transcripts_controller from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptWaveform,
transcripts_controller,
)
# Cleanup temporary padded S3 files (deferred until after mixdown) # Cleanup temporary padded S3 files (deferred until after mixdown)
track_data = _to_dict(ctx.task_output(process_tracks)) track_data = _to_dict(ctx.task_output(process_tracks))
@@ -776,8 +783,11 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     words = track_data.get("all_words", [])
     target_language = track_data.get("target_language", "en")
 
-    from reflector.db.transcripts import TranscriptTopic, transcripts_controller
-    from reflector.processors.types import (
+    from reflector.db.transcripts import (  # noqa: PLC0415
+        TranscriptTopic,
+        transcripts_controller,
+    )
+    from reflector.processors.types import (  # noqa: PLC0415
         TitleSummaryWithId as TitleSummaryWithIdProcessorType,
     )
@@ -841,7 +851,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     topics_data = _to_dict(ctx.task_output(detect_topics))
     topics = topics_data.get("topics", [])
 
-    from reflector.db.transcripts import (
+    from reflector.db.transcripts import (  # noqa: PLC0415
         TranscriptFinalTitle,
         transcripts_controller,
     )
@@ -901,7 +911,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     topics_data = _to_dict(ctx.task_output(detect_topics))
     topics = topics_data.get("topics", [])
 
-    from reflector.db.transcripts import (
+    from reflector.db.transcripts import (  # noqa: PLC0415
         TranscriptFinalLongSummary,
         TranscriptFinalShortSummary,
         transcripts_controller,
@@ -995,9 +1005,14 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
     all_words = track_data.get("all_words", [])
 
     async with fresh_db_connection():
-        from reflector.db.transcripts import TranscriptText, transcripts_controller
-        from reflector.processors.types import Transcript as TranscriptType
-        from reflector.processors.types import Word
+        from reflector.db.transcripts import (  # noqa: PLC0415
+            TranscriptText,
+            transcripts_controller,
+        )
+        from reflector.processors.types import (  # noqa: PLC0415
+            Transcript as TranscriptType,
+        )
+        from reflector.processors.types import Word  # noqa: PLC0415
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript is None:
@@ -1057,8 +1072,8 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
     )
 
     async with fresh_db_connection():
-        from reflector.db.meetings import meetings_controller
-        from reflector.db.transcripts import transcripts_controller
+        from reflector.db.meetings import meetings_controller  # noqa: PLC0415
+        from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript and transcript.meeting_id:
@@ -1099,7 +1114,7 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
         return ZulipResult(zulip_message_id=None, skipped=True)
 
     async with fresh_db_connection():
-        from reflector.db.transcripts import transcripts_controller
+        from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
@@ -1135,8 +1150,8 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
         return WebhookResult(webhook_sent=False, skipped=True)
 
     async with fresh_db_connection():
-        from reflector.db.rooms import rooms_controller
-        from reflector.db.transcripts import transcripts_controller
+        from reflector.db.rooms import rooms_controller  # noqa: PLC0415
+        from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
 
         room = await rooms_controller.get_by_id(input.room_id)
         transcript = await transcripts_controller.get_by_id(input.transcript_id)

View File

@@ -3,6 +3,15 @@ Hatchet child workflow: TrackProcessing
 Handles individual audio track processing: padding and transcription.
 Spawned dynamically by the main diarization pipeline for each track.
+
+Architecture note: This is a separate workflow (not inline tasks in DiarizationPipeline)
+because Hatchet workflow DAGs are defined statically, but the number of tracks varies
+at runtime. Child workflow spawning via `aio_run()` + `asyncio.gather()` is the
+standard pattern for dynamic fan-out. See `process_tracks` in diarization_pipeline.py.
+
+Note: This file uses deferred imports (inside tasks) intentionally.
+Hatchet workers run in forked processes; fresh imports per task ensure
+storage/DB connections are not shared across forks.
 """
 
 import math
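
The fan-out the architecture note describes boils down to spawning one child run per track and gathering the results concurrently. A rough sketch of what `process_tracks` is expected to do (the input payload shape and the import path for `track_workflow` are assumptions; `aio_run()` is the spawning call named in the note):

```python
import asyncio

from reflector.hatchet.workflows import track_workflow  # assumed import path


async def fan_out_tracks(track_inputs: list) -> list:
    # One TrackProcessing child run per track, awaited concurrently
    return await asyncio.gather(
        *(track_workflow.aio_run(track_input) for track_input in track_inputs)
    )
```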
@@ -190,8 +199,8 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
     try:
         # Create fresh storage instance to avoid aioboto3 fork issues
-        from reflector.settings import settings
-        from reflector.storage.storage_aws import AwsStorage
+        from reflector.settings import settings  # noqa: PLC0415
+        from reflector.storage.storage_aws import AwsStorage  # noqa: PLC0415
 
         storage = AwsStorage(
             aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
@@ -312,8 +321,8 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe
         raise ValueError("Missing padded_key from pad_track")
 
     # Presign URL on demand (avoids stale URLs on workflow replay)
-    from reflector.settings import settings
-    from reflector.storage.storage_aws import AwsStorage
+    from reflector.settings import settings  # noqa: PLC0415
+    from reflector.storage.storage_aws import AwsStorage  # noqa: PLC0415
 
     storage = AwsStorage(
         aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
@@ -329,7 +338,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe
         bucket=bucket_name,
     )
 
-    from reflector.pipelines.transcription_helpers import (
+    from reflector.pipelines.transcription_helpers import (  # noqa: PLC0415
         transcribe_file_with_processor,
     )
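
The "presign on demand" comment carries the design point here: the durable artifact passed between tasks is the S3 key, and the short-lived URL is derived inside the task so a replayed run never sees an expired link. Illustrated with plain boto3 below (the project's `AwsStorage` wrapper is assumed to do something equivalent; function name is hypothetical):

```python
import boto3


def presigned_get_url(bucket: str, key: str, expires_in: int = 3600) -> str:
    # Derive a fresh, short-lived URL from the durable key at task runtime
    s3 = boto3.client("s3")
    return s3.generate_presigned_url(
        "get_object",
        Params={"Bucket": bucket, "Key": key},
        ExpiresIn=expires_in,
    )
```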

View File

@@ -188,7 +188,7 @@ async def dispatch_transcript_processing(
     room = await rooms_controller.get_by_id(config.room_id)
     room_forces_hatchet = room.use_hatchet if room else False
 
-    # Start durable workflow if enabled (Hatchet or Conductor)
+    # Start durable workflow if enabled (Hatchet)
     # or if room has use_hatchet=True
     use_hatchet = settings.HATCHET_ENABLED or room_forces_hatchet