From 67420d2ec4f8aa37dbcc53b39d14b19739851d11 Mon Sep 17 00:00:00 2001 From: Igor Loskutov Date: Tue, 16 Dec 2025 22:47:09 -0500 Subject: [PATCH] self-review (no-mistakes) --- server/reflector/hatchet/client.py | 37 ++-- .../hatchet/workflows/diarization_pipeline.py | 9 +- .../reflector/services/transcript_process.py | 209 ++++++------------ server/reflector/tools/process_transcript.py | 38 +++- server/reflector/views/transcripts_process.py | 2 +- server/reflector/worker/process.py | 12 +- server/reflector/ws_manager.py | 40 ++-- server/tests/test_hatchet_dispatch.py | 43 +++- 8 files changed, 190 insertions(+), 200 deletions(-) diff --git a/server/reflector/hatchet/client.py b/server/reflector/hatchet/client.py index 76088f17..7351eda1 100644 --- a/server/reflector/hatchet/client.py +++ b/server/reflector/hatchet/client.py @@ -8,8 +8,10 @@ Uses singleton pattern because: """ import logging +import threading from hatchet_sdk import ClientConfig, Hatchet +from hatchet_sdk.clients.rest.models import V1TaskStatus from reflector.logger import logger from reflector.settings import settings @@ -26,24 +28,23 @@ class HatchetClientManager: """ _instance: Hatchet | None = None + _lock = threading.Lock() @classmethod def get_client(cls) -> Hatchet: - """Get or create the Hatchet client. - - Configures root logger so all logger.info() calls in workflows - appear in the Hatchet dashboard logs. - """ + """Get or create the Hatchet client (thread-safe singleton).""" if cls._instance is None: - if not settings.HATCHET_CLIENT_TOKEN: - raise ValueError("HATCHET_CLIENT_TOKEN must be set") + with cls._lock: + if cls._instance is None: + if not settings.HATCHET_CLIENT_TOKEN: + raise ValueError("HATCHET_CLIENT_TOKEN must be set") - # Pass root logger to Hatchet so workflow logs appear in dashboard - root_logger = logging.getLogger() - cls._instance = Hatchet( - debug=settings.HATCHET_DEBUG, - config=ClientConfig(logger=root_logger), - ) + # Pass root logger to Hatchet so workflow logs appear in dashboard + root_logger = logging.getLogger() + cls._instance = Hatchet( + debug=settings.HATCHET_DEBUG, + config=ClientConfig(logger=root_logger), + ) return cls._instance @classmethod @@ -71,11 +72,10 @@ class HatchetClientManager: return result.run.metadata.id @classmethod - async def get_workflow_run_status(cls, workflow_run_id: str) -> str: + async def get_workflow_run_status(cls, workflow_run_id: str) -> V1TaskStatus: """Get workflow run status.""" client = cls.get_client() - status = await client.runs.aio_get_status(workflow_run_id) - return str(status) + return await client.runs.aio_get_status(workflow_run_id) @classmethod async def cancel_workflow(cls, workflow_run_id: str) -> None: @@ -96,7 +96,7 @@ class HatchetClientManager: """Check if workflow can be replayed (is FAILED).""" try: status = await cls.get_workflow_run_status(workflow_run_id) - return "FAILED" in status + return status == V1TaskStatus.FAILED except Exception as e: logger.warning( "[Hatchet] Failed to check replay status", @@ -115,4 +115,5 @@ class HatchetClientManager: @classmethod def reset(cls) -> None: """Reset the client instance (for testing).""" - cls._instance = None + with cls._lock: + cls._instance = None diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py index d30e2b7d..ed0a13b8 100644 --- a/server/reflector/hatchet/workflows/diarization_pipeline.py +++ b/server/reflector/hatchet/workflows/diarization_pipeline.py @@ -501,16 +501,19 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: # Determine target sample rate from first track target_sample_rate = None for url in valid_urls: + container = None try: container = av.open(url) for frame in container.decode(audio=0): target_sample_rate = frame.sample_rate break - container.close() - if target_sample_rate: - break except Exception: continue + finally: + if container is not None: + container.close() + if target_sample_rate: + break if not target_sample_rate: raise ValueError("No decodable audio frames in any track") diff --git a/server/reflector/services/transcript_process.py b/server/reflector/services/transcript_process.py index be88c417..ad469cab 100644 --- a/server/reflector/services/transcript_process.py +++ b/server/reflector/services/transcript_process.py @@ -11,6 +11,7 @@ from typing import Literal, Union, assert_never import celery from celery.result import AsyncResult +from hatchet_sdk.clients.rest.models import V1TaskStatus from reflector.db.recordings import recordings_controller from reflector.db.transcripts import Transcript @@ -114,14 +115,12 @@ async def validate_transcript_for_processing( # Check Hatchet workflows (if enabled) if settings.HATCHET_ENABLED and transcript.workflow_run_id: - from reflector.hatchet.client import HatchetClientManager - try: status = await HatchetClientManager.get_workflow_run_status( transcript.workflow_run_id ) # If workflow is running or queued, don't allow new processing - if "RUNNING" in status or "QUEUED" in status: + if status in (V1TaskStatus.RUNNING, V1TaskStatus.QUEUED): return ValidationAlreadyScheduled( detail="Hatchet workflow already running" ) @@ -173,50 +172,25 @@ async def prepare_transcript_processing(validation: ValidationOk) -> PrepareResu ) -def dispatch_transcript_processing( +async def dispatch_transcript_processing( config: ProcessingConfig, force: bool = False ) -> AsyncResult | None: + """Dispatch transcript processing to appropriate backend (Hatchet or Celery). + + Returns AsyncResult for Celery tasks, None for Hatchet workflows. + """ + from reflector.db.rooms import rooms_controller + from reflector.db.transcripts import transcripts_controller + if isinstance(config, MultitrackProcessingConfig): # Check if room has use_hatchet=True (overrides env vars) room_forces_hatchet = False if config.room_id: - import asyncio - - from reflector.db.rooms import rooms_controller - - async def _check_room_hatchet(): - import databases - - from reflector.db import _database_context - - db = databases.Database(settings.DATABASE_URL) - _database_context.set(db) - await db.connect() - try: - room = await rooms_controller.get_by_id(config.room_id) - return room.use_hatchet if room else False - finally: - await db.disconnect() - _database_context.set(None) - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop and loop.is_running(): - import concurrent.futures - - with concurrent.futures.ThreadPoolExecutor() as pool: - room_forces_hatchet = pool.submit( - asyncio.run, _check_room_hatchet() - ).result() - else: - room_forces_hatchet = asyncio.run(_check_room_hatchet()) + room = await rooms_controller.get_by_id(config.room_id) + room_forces_hatchet = room.use_hatchet if room else False # Start durable workflow if enabled (Hatchet or Conductor) # or if room has use_hatchet=True - durable_started = False use_hatchet = settings.HATCHET_ENABLED or room_forces_hatchet if room_forces_hatchet: @@ -227,115 +201,76 @@ def dispatch_transcript_processing( ) if use_hatchet: - import asyncio - - import databases - - from reflector.db import _database_context - from reflector.db.transcripts import transcripts_controller - - async def _handle_hatchet(): - db = databases.Database(settings.DATABASE_URL) - _database_context.set(db) - await db.connect() - - try: - transcript = await transcripts_controller.get_by_id( - config.transcript_id + # First check if we can replay (outside transaction since it's read-only) + transcript = await transcripts_controller.get_by_id(config.transcript_id) + if transcript and transcript.workflow_run_id and not force: + can_replay = await HatchetClientManager.can_replay( + transcript.workflow_run_id + ) + if can_replay: + await HatchetClientManager.replay_workflow( + transcript.workflow_run_id ) + logger.info( + "Replaying Hatchet workflow", + workflow_id=transcript.workflow_run_id, + ) + return None - if transcript and transcript.workflow_run_id and not force: - can_replay = await HatchetClientManager.can_replay( - transcript.workflow_run_id - ) - if can_replay: - await HatchetClientManager.replay_workflow( - transcript.workflow_run_id - ) - logger.info( - "Replaying Hatchet workflow", - workflow_id=transcript.workflow_run_id, - ) - return transcript.workflow_run_id + # Force: cancel old workflow if exists + if force and transcript and transcript.workflow_run_id: + await HatchetClientManager.cancel_workflow(transcript.workflow_run_id) + logger.info( + "Cancelled old workflow (--force)", + workflow_id=transcript.workflow_run_id, + ) + await transcripts_controller.update( + transcript, {"workflow_run_id": None} + ) - # Force: cancel old workflow if exists - if force and transcript and transcript.workflow_run_id: - await HatchetClientManager.cancel_workflow( - transcript.workflow_run_id - ) + # Re-fetch and check for concurrent dispatch (optimistic approach). + # No database lock - worst case is duplicate dispatch, but Hatchet + # workflows are idempotent so this is acceptable. + transcript = await transcripts_controller.get_by_id(config.transcript_id) + if transcript and transcript.workflow_run_id: + # Another process started a workflow between validation and now + try: + status = await HatchetClientManager.get_workflow_run_status( + transcript.workflow_run_id + ) + if status in (V1TaskStatus.RUNNING, V1TaskStatus.QUEUED): logger.info( - "Cancelled old workflow (--force)", + "Concurrent workflow detected, skipping dispatch", workflow_id=transcript.workflow_run_id, ) - await transcripts_controller.update( - transcript, {"workflow_run_id": None} - ) + return None + except Exception: + # If we can't get status, proceed with new workflow + pass - # Re-fetch transcript to check for concurrent dispatch (TOCTOU protection) - transcript = await transcripts_controller.get_by_id( - config.transcript_id - ) - if transcript and transcript.workflow_run_id: - # Another process started a workflow between validation and now - try: - status = await HatchetClientManager.get_workflow_run_status( - transcript.workflow_run_id - ) - if "RUNNING" in status or "QUEUED" in status: - logger.info( - "Concurrent workflow detected, skipping dispatch", - workflow_id=transcript.workflow_run_id, - ) - return transcript.workflow_run_id - except Exception: - # If we can't get status, proceed with new workflow - pass + workflow_id = await HatchetClientManager.start_workflow( + workflow_name="DiarizationPipeline", + input_data={ + "recording_id": config.recording_id, + "room_name": None, + "tracks": [{"s3_key": k} for k in config.track_keys], + "bucket_name": config.bucket_name, + "transcript_id": config.transcript_id, + "room_id": config.room_id, + }, + additional_metadata={ + "transcript_id": config.transcript_id, + "recording_id": config.recording_id, + "daily_recording_id": config.recording_id, + }, + ) - workflow_id = await HatchetClientManager.start_workflow( - workflow_name="DiarizationPipeline", - input_data={ - "recording_id": config.recording_id, - "room_name": None, - "tracks": [{"s3_key": k} for k in config.track_keys], - "bucket_name": config.bucket_name, - "transcript_id": config.transcript_id, - "room_id": config.room_id, - }, - additional_metadata={ - "transcript_id": config.transcript_id, - "recording_id": config.recording_id, - "daily_recording_id": config.recording_id, - }, - ) - - if transcript: - await transcripts_controller.update( - transcript, {"workflow_run_id": workflow_id} - ) - - return workflow_id - finally: - await db.disconnect() - _database_context.set(None) - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop and loop.is_running(): - import concurrent.futures - - with concurrent.futures.ThreadPoolExecutor() as pool: - workflow_id = pool.submit(asyncio.run, _handle_hatchet()).result() - else: - workflow_id = asyncio.run(_handle_hatchet()) + if transcript: + await transcripts_controller.update( + transcript, {"workflow_run_id": workflow_id} + ) logger.info("Hatchet workflow dispatched", workflow_id=workflow_id) - durable_started = True - - # If durable workflow started, skip Celery - if durable_started: return None # Celery pipeline (durable workflows disabled) diff --git a/server/reflector/tools/process_transcript.py b/server/reflector/tools/process_transcript.py index d89e5f57..b0ff7729 100644 --- a/server/reflector/tools/process_transcript.py +++ b/server/reflector/tools/process_transcript.py @@ -15,6 +15,7 @@ import time from typing import Callable from celery.result import AsyncResult +from hatchet_sdk.clients.rest.models import V1TaskStatus from reflector.db.transcripts import Transcript, transcripts_controller from reflector.services.transcript_process import ( @@ -35,12 +36,12 @@ async def process_transcript_inner( on_validation: Callable[[ValidationResult], None], on_preprocess: Callable[[PrepareResult], None], force: bool = False, -) -> AsyncResult: +) -> AsyncResult | None: validation = await validate_transcript_for_processing(transcript) on_validation(validation) config = await prepare_transcript_processing(validation) on_preprocess(config) - return dispatch_transcript_processing(config, force=force) + return await dispatch_transcript_processing(config, force=force) async def process_transcript( @@ -92,7 +93,38 @@ async def process_transcript( force=force, ) - if sync: + if result is None: + # Hatchet workflow dispatched + if sync: + from reflector.hatchet.client import HatchetClientManager + + # Re-fetch transcript to get workflow_run_id + transcript = await transcripts_controller.get_by_id(transcript_id) + if not transcript or not transcript.workflow_run_id: + print("Error: workflow_run_id not found", file=sys.stderr) + sys.exit(1) + + print("Waiting for Hatchet workflow...", file=sys.stderr) + while True: + status = await HatchetClientManager.get_workflow_run_status( + transcript.workflow_run_id + ) + print(f" Status: {status}", file=sys.stderr) + + if status == V1TaskStatus.COMPLETED: + print("Workflow completed successfully", file=sys.stderr) + break + elif status in (V1TaskStatus.FAILED, V1TaskStatus.CANCELLED): + print(f"Workflow failed: {status}", file=sys.stderr) + sys.exit(1) + + await asyncio.sleep(5) + else: + print( + "Task dispatched (use --sync to wait for completion)", + file=sys.stderr, + ) + elif sync: print("Waiting for task completion...", file=sys.stderr) while not result.ready(): print(f" Status: {result.state}", file=sys.stderr) diff --git a/server/reflector/views/transcripts_process.py b/server/reflector/views/transcripts_process.py index 927cc8a9..325d82e7 100644 --- a/server/reflector/views/transcripts_process.py +++ b/server/reflector/views/transcripts_process.py @@ -50,5 +50,5 @@ async def transcript_process( if isinstance(config, ProcessError): raise HTTPException(status_code=500, detail=config.detail) else: - dispatch_transcript_processing(config) + await dispatch_transcript_processing(config) return ProcessStatus(status="ok") diff --git a/server/reflector/worker/process.py b/server/reflector/worker/process.py index 3f27db7b..19ef9909 100644 --- a/server/reflector/worker/process.py +++ b/server/reflector/worker/process.py @@ -286,10 +286,18 @@ async def _process_multitrack_recording_inner( room_id=room.id, ) - # Start durable workflow if enabled (Hatchet) + # Start durable workflow if enabled (Hatchet) or room overrides it durable_started = False + use_hatchet = settings.HATCHET_ENABLED or (room and room.use_hatchet) - if settings.HATCHET_ENABLED: + if room and room.use_hatchet and not settings.HATCHET_ENABLED: + logger.info( + "Room forces Hatchet workflow", + room_id=room.id, + transcript_id=transcript.id, + ) + + if use_hatchet: from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415 workflow_id = await HatchetClientManager.start_workflow( diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py index a1f620c4..1a922264 100644 --- a/server/reflector/ws_manager.py +++ b/server/reflector/ws_manager.py @@ -109,29 +109,19 @@ class WebsocketManager: await socket.send_json(data) +_ws_manager_instance: WebsocketManager | None = None +_ws_manager_lock = threading.Lock() + + def get_ws_manager() -> WebsocketManager: - """ - Returns the WebsocketManager instance for managing websockets. - - This function initializes and returns the WebsocketManager instance, - which is responsible for managing websockets and handling websocket - connections. - - Returns: - WebsocketManager: The initialized WebsocketManager instance. - - Raises: - ImportError: If the 'reflector.settings' module cannot be imported. - RedisConnectionError: If there is an error connecting to the Redis server. - """ - local = threading.local() - if hasattr(local, "ws_manager"): - return local.ws_manager - - pubsub_client = RedisPubSubManager( - host=settings.REDIS_HOST, - port=settings.REDIS_PORT, - ) - ws_manager = WebsocketManager(pubsub_client=pubsub_client) - local.ws_manager = ws_manager - return ws_manager + """Returns the WebsocketManager singleton instance.""" + global _ws_manager_instance + if _ws_manager_instance is None: + with _ws_manager_lock: + if _ws_manager_instance is None: + pubsub_client = RedisPubSubManager( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + ) + _ws_manager_instance = WebsocketManager(pubsub_client=pubsub_client) + return _ws_manager_instance diff --git a/server/tests/test_hatchet_dispatch.py b/server/tests/test_hatchet_dispatch.py index 05cbe15f..857645ec 100644 --- a/server/tests/test_hatchet_dispatch.py +++ b/server/tests/test_hatchet_dispatch.py @@ -11,6 +11,7 @@ These tests verify: from unittest.mock import AsyncMock, patch import pytest +from hatchet_sdk.clients.rest.models import V1TaskStatus from reflector.db.transcripts import Transcript @@ -35,8 +36,12 @@ async def test_hatchet_validation_blocks_running_workflow(): with patch("reflector.services.transcript_process.settings") as mock_settings: mock_settings.HATCHET_ENABLED = True - with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet: - mock_hatchet.get_workflow_run_status = AsyncMock(return_value="RUNNING") + with patch( + "reflector.services.transcript_process.HatchetClientManager" + ) as mock_hatchet: + mock_hatchet.get_workflow_run_status = AsyncMock( + return_value=V1TaskStatus.RUNNING + ) with patch( "reflector.services.transcript_process.task_is_scheduled_or_active" @@ -69,8 +74,12 @@ async def test_hatchet_validation_blocks_queued_workflow(): with patch("reflector.services.transcript_process.settings") as mock_settings: mock_settings.HATCHET_ENABLED = True - with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet: - mock_hatchet.get_workflow_run_status = AsyncMock(return_value="QUEUED") + with patch( + "reflector.services.transcript_process.HatchetClientManager" + ) as mock_hatchet: + mock_hatchet.get_workflow_run_status = AsyncMock( + return_value=V1TaskStatus.QUEUED + ) with patch( "reflector.services.transcript_process.task_is_scheduled_or_active" @@ -103,8 +112,12 @@ async def test_hatchet_validation_allows_failed_workflow(): with patch("reflector.services.transcript_process.settings") as mock_settings: mock_settings.HATCHET_ENABLED = True - with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet: - mock_hatchet.get_workflow_run_status = AsyncMock(return_value="FAILED") + with patch( + "reflector.services.transcript_process.HatchetClientManager" + ) as mock_hatchet: + mock_hatchet.get_workflow_run_status = AsyncMock( + return_value=V1TaskStatus.FAILED + ) with patch( "reflector.services.transcript_process.task_is_scheduled_or_active" @@ -138,8 +151,12 @@ async def test_hatchet_validation_allows_completed_workflow(): with patch("reflector.services.transcript_process.settings") as mock_settings: mock_settings.HATCHET_ENABLED = True - with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet: - mock_hatchet.get_workflow_run_status = AsyncMock(return_value="COMPLETED") + with patch( + "reflector.services.transcript_process.HatchetClientManager" + ) as mock_hatchet: + mock_hatchet.get_workflow_run_status = AsyncMock( + return_value=V1TaskStatus.COMPLETED + ) with patch( "reflector.services.transcript_process.task_is_scheduled_or_active" @@ -172,7 +189,9 @@ async def test_hatchet_validation_allows_when_status_check_fails(): with patch("reflector.services.transcript_process.settings") as mock_settings: mock_settings.HATCHET_ENABLED = True - with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet: + with patch( + "reflector.services.transcript_process.HatchetClientManager" + ) as mock_hatchet: # Status check fails (workflow might be deleted) mock_hatchet.get_workflow_run_status = AsyncMock( side_effect=Exception("Workflow not found") @@ -210,7 +229,9 @@ async def test_hatchet_validation_skipped_when_no_workflow_id(): with patch("reflector.services.transcript_process.settings") as mock_settings: mock_settings.HATCHET_ENABLED = True - with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet: + with patch( + "reflector.services.transcript_process.HatchetClientManager" + ) as mock_hatchet: # Should not be called mock_hatchet.get_workflow_run_status = AsyncMock() @@ -338,7 +359,7 @@ async def test_prepare_multitrack_config(): assert result.bucket_name == "test-bucket" assert result.track_keys == ["track1.webm", "track2.webm"] assert result.transcript_id == "test-transcript-id" - assert result.room_id == "test-room" + assert result.room_id is None # ValidationOk didn't specify room_id @pytest.mark.usefixtures("setup_database")