mirror of https://github.com/Monadical-SAS/reflector.git
self-review (no-mistakes)
@@ -8,8 +8,10 @@ Uses singleton pattern because:
 """
 
 import logging
+import threading
 
 from hatchet_sdk import ClientConfig, Hatchet
+from hatchet_sdk.clients.rest.models import V1TaskStatus
 
 from reflector.logger import logger
 from reflector.settings import settings
@@ -26,24 +28,23 @@ class HatchetClientManager:
     """
 
     _instance: Hatchet | None = None
+    _lock = threading.Lock()
 
     @classmethod
     def get_client(cls) -> Hatchet:
-        """Get or create the Hatchet client.
-
-        Configures root logger so all logger.info() calls in workflows
-        appear in the Hatchet dashboard logs.
-        """
+        """Get or create the Hatchet client (thread-safe singleton)."""
         if cls._instance is None:
-            if not settings.HATCHET_CLIENT_TOKEN:
-                raise ValueError("HATCHET_CLIENT_TOKEN must be set")
-
-            # Pass root logger to Hatchet so workflow logs appear in dashboard
-            root_logger = logging.getLogger()
-            cls._instance = Hatchet(
-                debug=settings.HATCHET_DEBUG,
-                config=ClientConfig(logger=root_logger),
-            )
+            with cls._lock:
+                if cls._instance is None:
+                    if not settings.HATCHET_CLIENT_TOKEN:
+                        raise ValueError("HATCHET_CLIENT_TOKEN must be set")
+
+                    # Pass root logger to Hatchet so workflow logs appear in dashboard
+                    root_logger = logging.getLogger()
+                    cls._instance = Hatchet(
+                        debug=settings.HATCHET_DEBUG,
+                        config=ClientConfig(logger=root_logger),
+                    )
         return cls._instance
 
     @classmethod
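
Review note: the new `_lock` turns `get_client()` into a double-checked locking singleton — the unlocked outer check keeps the common path cheap, and the re-check under the lock stops two threads from both constructing a Hatchet client. A minimal sketch of the shape (names here are illustrative, not part of the codebase):

    import threading


    class LazySingleton:
        _instance = None
        _lock = threading.Lock()

        @classmethod
        def get(cls):
            if cls._instance is None:  # fast path: no lock once initialized
                with cls._lock:  # slow path: serialize competing initializers
                    if cls._instance is None:  # re-check: we may have lost the race
                        cls._instance = object()  # stand-in for expensive setup
            return cls._instance
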
@@ -71,11 +72,10 @@ class HatchetClientManager:
         return result.run.metadata.id
 
     @classmethod
-    async def get_workflow_run_status(cls, workflow_run_id: str) -> str:
+    async def get_workflow_run_status(cls, workflow_run_id: str) -> V1TaskStatus:
         """Get workflow run status."""
         client = cls.get_client()
-        status = await client.runs.aio_get_status(workflow_run_id)
-        return str(status)
+        return await client.runs.aio_get_status(workflow_run_id)
 
     @classmethod
     async def cancel_workflow(cls, workflow_run_id: str) -> None:
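
Review note: returning `V1TaskStatus` instead of `str(status)` lets callers compare enum members exactly rather than doing substring checks. A small illustration of why substring matching on an enum's string form is brittle (a stand-in enum, not the real `V1TaskStatus`):

    from enum import Enum


    class Status(Enum):  # stand-in for V1TaskStatus, illustrative only
        RUNNING = "RUNNING"
        FAILED = "FAILED"


    status = Status.FAILED
    assert "FAILED" in str(status)  # works only by accident of the repr
    assert status == Status.FAILED  # exact, type-safe comparison
    assert status != Status.RUNNING
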
@@ -96,7 +96,7 @@ class HatchetClientManager:
         """Check if workflow can be replayed (is FAILED)."""
         try:
             status = await cls.get_workflow_run_status(workflow_run_id)
-            return "FAILED" in status
+            return status == V1TaskStatus.FAILED
         except Exception as e:
             logger.warning(
                 "[Hatchet] Failed to check replay status",
@@ -115,4 +115,5 @@ class HatchetClientManager:
     @classmethod
     def reset(cls) -> None:
         """Reset the client instance (for testing)."""
-        cls._instance = None
+        with cls._lock:
+            cls._instance = None
@@ -501,16 +501,19 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     # Determine target sample rate from first track
     target_sample_rate = None
     for url in valid_urls:
+        container = None
         try:
            container = av.open(url)
            for frame in container.decode(audio=0):
                target_sample_rate = frame.sample_rate
                break
-            container.close()
-            if target_sample_rate:
-                break
         except Exception:
             continue
+        finally:
+            if container is not None:
+                container.close()
+        if target_sample_rate:
+            break
 
     if not target_sample_rate:
         raise ValueError("No decodable audio frames in any track")
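
Review note: the old loop only closed the container on the success path, so a decode error leaked the handle. The `finally` fixes that; `contextlib.closing` gives an equivalent shape. A sketch assuming PyAV's `av.open` (the `first_sample_rate` helper is hypothetical):

    import contextlib

    import av  # PyAV


    def first_sample_rate(urls: list[str]) -> int | None:
        """Probe URLs until one yields a decodable audio frame."""
        for url in urls:
            try:
                # closing() guarantees container.close() even if decode() raises
                with contextlib.closing(av.open(url)) as container:
                    for frame in container.decode(audio=0):
                        return frame.sample_rate
            except Exception:
                continue
        return None
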
@@ -11,6 +11,7 @@ from typing import Literal, Union, assert_never
 
 import celery
 from celery.result import AsyncResult
+from hatchet_sdk.clients.rest.models import V1TaskStatus
 
 from reflector.db.recordings import recordings_controller
 from reflector.db.transcripts import Transcript
@@ -114,14 +115,12 @@ async def validate_transcript_for_processing(
 
     # Check Hatchet workflows (if enabled)
     if settings.HATCHET_ENABLED and transcript.workflow_run_id:
-        from reflector.hatchet.client import HatchetClientManager
-
         try:
             status = await HatchetClientManager.get_workflow_run_status(
                 transcript.workflow_run_id
             )
             # If workflow is running or queued, don't allow new processing
-            if "RUNNING" in status or "QUEUED" in status:
+            if status in (V1TaskStatus.RUNNING, V1TaskStatus.QUEUED):
                 return ValidationAlreadyScheduled(
                     detail="Hatchet workflow already running"
                 )
@@ -173,50 +172,25 @@ async def prepare_transcript_processing(validation: ValidationOk) -> PrepareResu
     )
 
 
-def dispatch_transcript_processing(
+async def dispatch_transcript_processing(
     config: ProcessingConfig, force: bool = False
 ) -> AsyncResult | None:
+    """Dispatch transcript processing to appropriate backend (Hatchet or Celery).
+
+    Returns AsyncResult for Celery tasks, None for Hatchet workflows.
+    """
+    from reflector.db.rooms import rooms_controller
+    from reflector.db.transcripts import transcripts_controller
+
     if isinstance(config, MultitrackProcessingConfig):
         # Check if room has use_hatchet=True (overrides env vars)
         room_forces_hatchet = False
         if config.room_id:
-            import asyncio
-
-            from reflector.db.rooms import rooms_controller
-
-            async def _check_room_hatchet():
-                import databases
-
-                from reflector.db import _database_context
-
-                db = databases.Database(settings.DATABASE_URL)
-                _database_context.set(db)
-                await db.connect()
-                try:
-                    room = await rooms_controller.get_by_id(config.room_id)
-                    return room.use_hatchet if room else False
-                finally:
-                    await db.disconnect()
-                    _database_context.set(None)
-
-            try:
-                loop = asyncio.get_running_loop()
-            except RuntimeError:
-                loop = None
-
-            if loop and loop.is_running():
-                import concurrent.futures
-
-                with concurrent.futures.ThreadPoolExecutor() as pool:
-                    room_forces_hatchet = pool.submit(
-                        asyncio.run, _check_room_hatchet()
-                    ).result()
-            else:
-                room_forces_hatchet = asyncio.run(_check_room_hatchet())
+            room = await rooms_controller.get_by_id(config.room_id)
+            room_forces_hatchet = room.use_hatchet if room else False
 
         # Start durable workflow if enabled (Hatchet or Conductor)
         # or if room has use_hatchet=True
-        durable_started = False
         use_hatchet = settings.HATCHET_ENABLED or room_forces_hatchet
 
         if room_forces_hatchet:
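
Review note: making `dispatch_transcript_processing` async deletes the whole asyncio.run-on-a-throwaway-thread dance (and its private database connection). The general before/after shape, with stand-in names:

    import asyncio


    async def check_flag() -> bool:
        await asyncio.sleep(0)  # stand-in for an awaited database call
        return True


    # Before: a sync function smuggling async work onto a throwaway event
    # loop (or a worker thread when a loop is already running).
    def sync_caller() -> bool:
        return asyncio.run(check_flag())


    # After: async all the way down; the caller simply awaits on the
    # already-running loop and reuses its database connection context.
    async def async_caller() -> bool:
        return await check_flag()
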
@@ -227,115 +201,76 @@ def dispatch_transcript_processing(
             )
 
         if use_hatchet:
-            import asyncio
-
-            import databases
-
-            from reflector.db import _database_context
-            from reflector.db.transcripts import transcripts_controller
-
-            async def _handle_hatchet():
-                db = databases.Database(settings.DATABASE_URL)
-                _database_context.set(db)
-                await db.connect()
-
-                try:
-                    transcript = await transcripts_controller.get_by_id(
-                        config.transcript_id
-                    )
-
-                    if transcript and transcript.workflow_run_id and not force:
-                        can_replay = await HatchetClientManager.can_replay(
-                            transcript.workflow_run_id
-                        )
-                        if can_replay:
-                            await HatchetClientManager.replay_workflow(
-                                transcript.workflow_run_id
-                            )
-                            logger.info(
-                                "Replaying Hatchet workflow",
-                                workflow_id=transcript.workflow_run_id,
-                            )
-                            return transcript.workflow_run_id
-
-                    # Force: cancel old workflow if exists
-                    if force and transcript and transcript.workflow_run_id:
-                        await HatchetClientManager.cancel_workflow(
-                            transcript.workflow_run_id
-                        )
-                        logger.info(
-                            "Cancelled old workflow (--force)",
-                            workflow_id=transcript.workflow_run_id,
-                        )
-                        await transcripts_controller.update(
-                            transcript, {"workflow_run_id": None}
-                        )
-
-                    # Re-fetch transcript to check for concurrent dispatch (TOCTOU protection)
-                    transcript = await transcripts_controller.get_by_id(
-                        config.transcript_id
-                    )
-                    if transcript and transcript.workflow_run_id:
-                        # Another process started a workflow between validation and now
-                        try:
-                            status = await HatchetClientManager.get_workflow_run_status(
-                                transcript.workflow_run_id
-                            )
-                            if "RUNNING" in status or "QUEUED" in status:
-                                logger.info(
-                                    "Concurrent workflow detected, skipping dispatch",
-                                    workflow_id=transcript.workflow_run_id,
-                                )
-                                return transcript.workflow_run_id
-                        except Exception:
-                            # If we can't get status, proceed with new workflow
-                            pass
-
-                    workflow_id = await HatchetClientManager.start_workflow(
-                        workflow_name="DiarizationPipeline",
-                        input_data={
-                            "recording_id": config.recording_id,
-                            "room_name": None,
-                            "tracks": [{"s3_key": k} for k in config.track_keys],
-                            "bucket_name": config.bucket_name,
-                            "transcript_id": config.transcript_id,
-                            "room_id": config.room_id,
-                        },
-                        additional_metadata={
-                            "transcript_id": config.transcript_id,
-                            "recording_id": config.recording_id,
-                            "daily_recording_id": config.recording_id,
-                        },
-                    )
-
-                    if transcript:
-                        await transcripts_controller.update(
-                            transcript, {"workflow_run_id": workflow_id}
-                        )
-
-                    return workflow_id
-                finally:
-                    await db.disconnect()
-                    _database_context.set(None)
-
-            try:
-                loop = asyncio.get_running_loop()
-            except RuntimeError:
-                loop = None
-
-            if loop and loop.is_running():
-                import concurrent.futures
-
-                with concurrent.futures.ThreadPoolExecutor() as pool:
-                    workflow_id = pool.submit(asyncio.run, _handle_hatchet()).result()
-            else:
-                workflow_id = asyncio.run(_handle_hatchet())
+            # First check if we can replay (outside transaction since it's read-only)
+            transcript = await transcripts_controller.get_by_id(config.transcript_id)
+            if transcript and transcript.workflow_run_id and not force:
+                can_replay = await HatchetClientManager.can_replay(
+                    transcript.workflow_run_id
+                )
+                if can_replay:
+                    await HatchetClientManager.replay_workflow(
+                        transcript.workflow_run_id
+                    )
+                    logger.info(
+                        "Replaying Hatchet workflow",
+                        workflow_id=transcript.workflow_run_id,
+                    )
+                    return None
+
+            # Force: cancel old workflow if exists
+            if force and transcript and transcript.workflow_run_id:
+                await HatchetClientManager.cancel_workflow(transcript.workflow_run_id)
+                logger.info(
+                    "Cancelled old workflow (--force)",
+                    workflow_id=transcript.workflow_run_id,
+                )
+                await transcripts_controller.update(
+                    transcript, {"workflow_run_id": None}
+                )
+
+            # Re-fetch and check for concurrent dispatch (optimistic approach).
+            # No database lock - worst case is duplicate dispatch, but Hatchet
+            # workflows are idempotent so this is acceptable.
+            transcript = await transcripts_controller.get_by_id(config.transcript_id)
+            if transcript and transcript.workflow_run_id:
+                # Another process started a workflow between validation and now
+                try:
+                    status = await HatchetClientManager.get_workflow_run_status(
+                        transcript.workflow_run_id
+                    )
+                    if status in (V1TaskStatus.RUNNING, V1TaskStatus.QUEUED):
+                        logger.info(
+                            "Concurrent workflow detected, skipping dispatch",
+                            workflow_id=transcript.workflow_run_id,
+                        )
+                        return None
+                except Exception:
+                    # If we can't get status, proceed with new workflow
+                    pass
+
+            workflow_id = await HatchetClientManager.start_workflow(
+                workflow_name="DiarizationPipeline",
+                input_data={
+                    "recording_id": config.recording_id,
+                    "room_name": None,
+                    "tracks": [{"s3_key": k} for k in config.track_keys],
+                    "bucket_name": config.bucket_name,
+                    "transcript_id": config.transcript_id,
+                    "room_id": config.room_id,
+                },
+                additional_metadata={
+                    "transcript_id": config.transcript_id,
+                    "recording_id": config.recording_id,
+                    "daily_recording_id": config.recording_id,
+                },
+            )
+
+            if transcript:
+                await transcripts_controller.update(
+                    transcript, {"workflow_run_id": workflow_id}
+                )
 
             logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
-            durable_started = True
-
-        # If durable workflow started, skip Celery
-        if durable_started:
             return None
 
         # Celery pipeline (durable workflows disabled)
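
Review note: the concurrency strategy is now stated honestly in the comment — re-fetch, check status, and accept a small duplicate-dispatch window because the workflow is idempotent. The generic shape, with stand-in callables:

    async def dispatch_once(get_current, is_active, start):
        """Optimistic check-then-act: tolerate the race instead of locking.

        Safe only when `start` is idempotent; a competitor may still slip
        in between the check and the act.
        """
        existing = await get_current()  # re-read the freshest state
        if existing is not None and await is_active(existing):
            return existing  # someone else already dispatched
        return await start()  # small duplicate window remains
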
@@ -15,6 +15,7 @@ import time
 from typing import Callable
 
 from celery.result import AsyncResult
+from hatchet_sdk.clients.rest.models import V1TaskStatus
 
 from reflector.db.transcripts import Transcript, transcripts_controller
 from reflector.services.transcript_process import (
@@ -35,12 +36,12 @@ async def process_transcript_inner(
     on_validation: Callable[[ValidationResult], None],
     on_preprocess: Callable[[PrepareResult], None],
     force: bool = False,
-) -> AsyncResult:
+) -> AsyncResult | None:
     validation = await validate_transcript_for_processing(transcript)
     on_validation(validation)
     config = await prepare_transcript_processing(validation)
     on_preprocess(config)
-    return dispatch_transcript_processing(config, force=force)
+    return await dispatch_transcript_processing(config, force=force)
 
 
 async def process_transcript(
@@ -92,7 +93,38 @@ async def process_transcript(
         force=force,
     )
 
-    if sync:
+    if result is None:
+        # Hatchet workflow dispatched
+        if sync:
+            from reflector.hatchet.client import HatchetClientManager
+
+            # Re-fetch transcript to get workflow_run_id
+            transcript = await transcripts_controller.get_by_id(transcript_id)
+            if not transcript or not transcript.workflow_run_id:
+                print("Error: workflow_run_id not found", file=sys.stderr)
+                sys.exit(1)
+
+            print("Waiting for Hatchet workflow...", file=sys.stderr)
+            while True:
+                status = await HatchetClientManager.get_workflow_run_status(
+                    transcript.workflow_run_id
+                )
+                print(f" Status: {status}", file=sys.stderr)
+
+                if status == V1TaskStatus.COMPLETED:
+                    print("Workflow completed successfully", file=sys.stderr)
+                    break
+                elif status in (V1TaskStatus.FAILED, V1TaskStatus.CANCELLED):
+                    print(f"Workflow failed: {status}", file=sys.stderr)
+                    sys.exit(1)
+
+                await asyncio.sleep(5)
+        else:
+            print(
+                "Task dispatched (use --sync to wait for completion)",
+                file=sys.stderr,
+            )
+    elif sync:
         print("Waiting for task completion...", file=sys.stderr)
         while not result.ready():
             print(f" Status: {result.state}", file=sys.stderr)
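
Review note: the new `--sync` branch polls Hatchet for a terminal state instead of Celery's `result.ready()`. A generic poller of the same shape (stand-in status strings rather than the real `V1TaskStatus` members):

    import asyncio

    TERMINAL_OK = {"COMPLETED"}
    TERMINAL_ERR = {"FAILED", "CANCELLED"}


    async def wait_for_terminal(get_status, interval: float = 5.0) -> str:
        """Poll get_status() until a terminal state is reached."""
        while True:
            status = await get_status()
            if status in TERMINAL_OK:
                return status
            if status in TERMINAL_ERR:
                raise RuntimeError(f"workflow ended in {status}")
            await asyncio.sleep(interval)  # matches the 5s cadence above
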
@@ -50,5 +50,5 @@ async def transcript_process(
     if isinstance(config, ProcessError):
         raise HTTPException(status_code=500, detail=config.detail)
     else:
-        dispatch_transcript_processing(config)
+        await dispatch_transcript_processing(config)
         return ProcessStatus(status="ok")
@@ -286,10 +286,18 @@ async def _process_multitrack_recording_inner(
         room_id=room.id,
     )
 
-    # Start durable workflow if enabled (Hatchet)
+    # Start durable workflow if enabled (Hatchet) or room overrides it
     durable_started = False
+    use_hatchet = settings.HATCHET_ENABLED or (room and room.use_hatchet)
 
-    if settings.HATCHET_ENABLED:
+    if room and room.use_hatchet and not settings.HATCHET_ENABLED:
+        logger.info(
+            "Room forces Hatchet workflow",
+            room_id=room.id,
+            transcript_id=transcript.id,
+        )
+
+    if use_hatchet:
         from reflector.hatchet.client import HatchetClientManager  # noqa: PLC0415
 
         workflow_id = await HatchetClientManager.start_workflow(
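
Review note: the room-level flag now upgrades the decision even when the env flag is off, and the `not settings.HATCHET_ENABLED` guard keeps the override log line from firing on every dispatch when Hatchet is globally enabled. The decision in isolation (a hypothetical helper, for illustration):

    def should_use_hatchet(hatchet_enabled: bool, room) -> bool:
        # bool() normalizes the short-circuit result of `room and room.use_hatchet`,
        # which is None/False-y when there is no room.
        return bool(hatchet_enabled or (room and room.use_hatchet))
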
@@ -109,29 +109,19 @@ class WebsocketManager:
         await socket.send_json(data)
 
 
+_ws_manager_instance: WebsocketManager | None = None
+_ws_manager_lock = threading.Lock()
+
+
 def get_ws_manager() -> WebsocketManager:
-    """
-    Returns the WebsocketManager instance for managing websockets.
-
-    This function initializes and returns the WebsocketManager instance,
-    which is responsible for managing websockets and handling websocket
-    connections.
-
-    Returns:
-        WebsocketManager: The initialized WebsocketManager instance.
-
-    Raises:
-        ImportError: If the 'reflector.settings' module cannot be imported.
-        RedisConnectionError: If there is an error connecting to the Redis server.
-    """
-    local = threading.local()
-    if hasattr(local, "ws_manager"):
-        return local.ws_manager
-
-    pubsub_client = RedisPubSubManager(
-        host=settings.REDIS_HOST,
-        port=settings.REDIS_PORT,
-    )
-    ws_manager = WebsocketManager(pubsub_client=pubsub_client)
-    local.ws_manager = ws_manager
-    return ws_manager
+    """Returns the WebsocketManager singleton instance."""
+    global _ws_manager_instance
+    if _ws_manager_instance is None:
+        with _ws_manager_lock:
+            if _ws_manager_instance is None:
+                pubsub_client = RedisPubSubManager(
+                    host=settings.REDIS_HOST,
+                    port=settings.REDIS_PORT,
+                )
+                _ws_manager_instance = WebsocketManager(pubsub_client=pubsub_client)
+    return _ws_manager_instance
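
Review note: the replaced body was a cache that could never hit — `threading.local()` built a brand-new object on every call, so `hasattr(local, "ws_manager")` was always False and each call constructed a fresh manager (and Redis pubsub client). A minimal reproduction of that bug:

    import threading


    def broken_cached():
        local = threading.local()  # new object every call...
        if hasattr(local, "value"):  # ...so this check never succeeds
            return local.value
        local.value = object()  # stored on an object discarded on return
        return local.value


    assert broken_cached() is not broken_cached()  # no caching at all
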
@@ -11,6 +11,7 @@ These tests verify:
 from unittest.mock import AsyncMock, patch
 
 import pytest
+from hatchet_sdk.clients.rest.models import V1TaskStatus
 
 from reflector.db.transcripts import Transcript
 
@@ -35,8 +36,12 @@ async def test_hatchet_validation_blocks_running_workflow():
     with patch("reflector.services.transcript_process.settings") as mock_settings:
         mock_settings.HATCHET_ENABLED = True
 
-        with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet:
-            mock_hatchet.get_workflow_run_status = AsyncMock(return_value="RUNNING")
+        with patch(
+            "reflector.services.transcript_process.HatchetClientManager"
+        ) as mock_hatchet:
+            mock_hatchet.get_workflow_run_status = AsyncMock(
+                return_value=V1TaskStatus.RUNNING
+            )
 
             with patch(
                 "reflector.services.transcript_process.task_is_scheduled_or_active"
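
Review note: the patch target moves from the defining module to the consuming module — `patch` must replace the name in the namespace where the code under test looks it up. Now that `transcript_process` imports `HatchetClientManager` at module level, patching the defining module would leave the already-bound name untouched. A sketch of the pattern (assumes the reflector package is importable):

    from unittest.mock import AsyncMock, patch

    # Patch the reference that is actually resolved at call time, not the
    # module where HatchetClientManager happens to be defined.
    with patch("reflector.services.transcript_process.HatchetClientManager") as m:
        m.get_workflow_run_status = AsyncMock(return_value="stub")
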
@@ -69,8 +74,12 @@ async def test_hatchet_validation_blocks_queued_workflow():
     with patch("reflector.services.transcript_process.settings") as mock_settings:
         mock_settings.HATCHET_ENABLED = True
 
-        with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet:
-            mock_hatchet.get_workflow_run_status = AsyncMock(return_value="QUEUED")
+        with patch(
+            "reflector.services.transcript_process.HatchetClientManager"
+        ) as mock_hatchet:
+            mock_hatchet.get_workflow_run_status = AsyncMock(
+                return_value=V1TaskStatus.QUEUED
+            )
 
             with patch(
                 "reflector.services.transcript_process.task_is_scheduled_or_active"
@@ -103,8 +112,12 @@ async def test_hatchet_validation_allows_failed_workflow():
     with patch("reflector.services.transcript_process.settings") as mock_settings:
         mock_settings.HATCHET_ENABLED = True
 
-        with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet:
-            mock_hatchet.get_workflow_run_status = AsyncMock(return_value="FAILED")
+        with patch(
+            "reflector.services.transcript_process.HatchetClientManager"
+        ) as mock_hatchet:
+            mock_hatchet.get_workflow_run_status = AsyncMock(
+                return_value=V1TaskStatus.FAILED
+            )
 
             with patch(
                 "reflector.services.transcript_process.task_is_scheduled_or_active"
@@ -138,8 +151,12 @@ async def test_hatchet_validation_allows_completed_workflow():
     with patch("reflector.services.transcript_process.settings") as mock_settings:
         mock_settings.HATCHET_ENABLED = True
 
-        with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet:
-            mock_hatchet.get_workflow_run_status = AsyncMock(return_value="COMPLETED")
+        with patch(
+            "reflector.services.transcript_process.HatchetClientManager"
+        ) as mock_hatchet:
+            mock_hatchet.get_workflow_run_status = AsyncMock(
+                return_value=V1TaskStatus.COMPLETED
+            )
 
             with patch(
                 "reflector.services.transcript_process.task_is_scheduled_or_active"
@@ -172,7 +189,9 @@ async def test_hatchet_validation_allows_when_status_check_fails():
     with patch("reflector.services.transcript_process.settings") as mock_settings:
         mock_settings.HATCHET_ENABLED = True
 
-        with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet:
+        with patch(
+            "reflector.services.transcript_process.HatchetClientManager"
+        ) as mock_hatchet:
             # Status check fails (workflow might be deleted)
             mock_hatchet.get_workflow_run_status = AsyncMock(
                 side_effect=Exception("Workflow not found")
@@ -210,7 +229,9 @@ async def test_hatchet_validation_skipped_when_no_workflow_id():
     with patch("reflector.services.transcript_process.settings") as mock_settings:
         mock_settings.HATCHET_ENABLED = True
 
-        with patch("reflector.hatchet.client.HatchetClientManager") as mock_hatchet:
+        with patch(
+            "reflector.services.transcript_process.HatchetClientManager"
+        ) as mock_hatchet:
             # Should not be called
             mock_hatchet.get_workflow_run_status = AsyncMock()
 
@@ -338,7 +359,7 @@ async def test_prepare_multitrack_config():
     assert result.bucket_name == "test-bucket"
     assert result.track_keys == ["track1.webm", "track2.webm"]
     assert result.transcript_id == "test-transcript-id"
-    assert result.room_id == "test-room"
+    assert result.room_id is None  # ValidationOk didn't specify room_id
 
 
 @pytest.mark.usefixtures("setup_database")