"""
|
|
Transcript processing service - shared logic for HTTP endpoints and Celery tasks.
|
|
|
|
This module provides result-based error handling that works in both contexts:
|
|
- HTTP endpoint: converts errors to HTTPException
|
|
- Celery task: converts errors to Exception
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Literal, Union, assert_never
|
|
|
|
import celery
|
|
from celery.result import AsyncResult
|
|
from hatchet_sdk.clients.rest.exceptions import ApiException
|
|
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
|
|
|
from reflector.db.recordings import recordings_controller
|
|
from reflector.db.rooms import rooms_controller
|
|
from reflector.db.transcripts import Transcript, transcripts_controller
|
|
from reflector.hatchet.client import HatchetClientManager
|
|
from reflector.logger import logger
|
|
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
|
from reflector.pipelines.main_multitrack_pipeline import (
|
|
task_pipeline_multitrack_process,
|
|
)
|
|
from reflector.settings import settings
|
|
from reflector.utils.string import NonEmptyString
|
|
|
|
|
|
@dataclass
|
|
class ProcessError:
|
|
detail: NonEmptyString
|
|
|
|
|
|
@dataclass
|
|
class FileProcessingConfig:
|
|
transcript_id: NonEmptyString
|
|
mode: Literal["file"] = "file"
|
|
|
|
|
|
@dataclass
|
|
class MultitrackProcessingConfig:
|
|
transcript_id: NonEmptyString
|
|
bucket_name: NonEmptyString
|
|
track_keys: list[str]
|
|
recording_id: NonEmptyString | None = None
|
|
room_id: NonEmptyString | None = None
|
|
mode: Literal["multitrack"] = "multitrack"
|
|
|
|
|
|
ProcessingConfig = Union[FileProcessingConfig, MultitrackProcessingConfig]
|
|
PrepareResult = Union[ProcessingConfig, ProcessError]
|
|
|
|
|
|
@dataclass
class ValidationOk:
    # transcript currently doesn't always have recording_id
    recording_id: NonEmptyString | None
    transcript_id: NonEmptyString
    room_id: NonEmptyString | None = None


@dataclass
class ValidationLocked:
    detail: NonEmptyString


@dataclass
class ValidationNotReady:
    detail: NonEmptyString


@dataclass
class ValidationAlreadyScheduled:
    detail: NonEmptyString


ValidationError = Union[
    ValidationNotReady, ValidationLocked, ValidationAlreadyScheduled
]
ValidationResult = Union[ValidationOk, ValidationError]


@dataclass
class DispatchOk:
    status: Literal["ok"] = "ok"


@dataclass
class DispatchAlreadyRunning:
    status: Literal["already_running"] = "already_running"


DispatchResult = Union[
    DispatchOk, DispatchAlreadyRunning, ProcessError, ValidationError
]
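

# Illustrative sketch (not part of the original module): one way a caller could
# map a ValidationError onto the per-context exceptions described in the module
# docstring. The helper name and the HTTP status codes (423 Locked, 409 Conflict)
# are assumptions made for this example, not values taken from reflector's
# endpoints.
def _example_validation_error_to_exception(
    error: ValidationError, context: Literal["http", "celery"]
) -> Exception:
    if context == "http":
        from fastapi import HTTPException  # assumed available in the HTTP context

        status_code = 423 if isinstance(error, ValidationLocked) else 409
        return HTTPException(status_code=status_code, detail=error.detail)
    return Exception(error.detail)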


async def validate_transcript_for_processing(
    transcript: Transcript,
) -> ValidationResult:
    if transcript.locked:
        return ValidationLocked(detail="Recording is locked")

    if transcript.status == "idle":
        return ValidationNotReady(detail="Recording is not ready for processing")

    # Check Celery tasks
    if task_is_scheduled_or_active(
        "reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
        transcript_id=transcript.id,
    ) or task_is_scheduled_or_active(
        "reflector.pipelines.main_multitrack_pipeline.task_pipeline_multitrack_process",
        transcript_id=transcript.id,
    ):
        return ValidationAlreadyScheduled(detail="already running")

    # Check Hatchet workflows (if enabled)
    if settings.HATCHET_ENABLED and transcript.workflow_run_id:
        try:
            status = await HatchetClientManager.get_workflow_run_status(
                transcript.workflow_run_id
            )
            # If workflow is running or queued, don't allow new processing
            if status in (V1TaskStatus.RUNNING, V1TaskStatus.QUEUED):
                return ValidationAlreadyScheduled(
                    detail="Hatchet workflow already running"
                )
        except ApiException:
            # Workflow might be gone (404) or API issue - allow processing
            pass

    return ValidationOk(
        recording_id=transcript.recording_id,
        transcript_id=transcript.id,
        room_id=transcript.room_id,
    )


async def prepare_transcript_processing(validation: ValidationOk) -> PrepareResult:
    """
    Determine processing mode from transcript/recording data.
    """
    bucket_name: str | None = None
    track_keys: list[str] | None = None
    recording_id: str | None = validation.recording_id

    if validation.recording_id:
        recording = await recordings_controller.get_by_id(validation.recording_id)
        if recording:
            bucket_name = recording.bucket_name
            track_keys = recording.track_keys

    if track_keys is not None and len(track_keys) == 0:
        return ProcessError(
            detail="No track keys found, must be either > 0 or None",
        )
    if track_keys is not None and not bucket_name:
        return ProcessError(
            detail="Bucket name must be specified",
        )

    if track_keys:
        return MultitrackProcessingConfig(
            bucket_name=bucket_name,  # type: ignore (validated above)
            track_keys=track_keys,
            transcript_id=validation.transcript_id,
            recording_id=recording_id,
            room_id=validation.room_id,
        )

    return FileProcessingConfig(
        transcript_id=validation.transcript_id,
    )
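

# Illustrative sketch (not part of the original module): a narrow helper showing
# how the prepared config's `mode` discriminator can drive branching without
# isinstance checks. The helper name is an assumption made for this example.
def _example_describe_config(config: ProcessingConfig) -> str:
    if config.mode == "multitrack":
        return f"multitrack: {len(config.track_keys)} tracks in {config.bucket_name}"
    return f"file: {config.transcript_id}"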


async def dispatch_transcript_processing(
    config: ProcessingConfig, force: bool = False
) -> AsyncResult | None:
    """Dispatch transcript processing to the appropriate backend (Hatchet or Celery).

    Returns AsyncResult for Celery tasks, None for Hatchet workflows.
    """
    if isinstance(config, MultitrackProcessingConfig):
        # Check if the room has use_hatchet=True (overrides env vars)
        room_forces_hatchet = False
        if config.room_id:
            room = await rooms_controller.get_by_id(config.room_id)
            room_forces_hatchet = room.use_hatchet if room else False

        # Start a durable workflow if enabled (Hatchet)
        # or if the room has use_hatchet=True
        use_hatchet = settings.HATCHET_ENABLED or room_forces_hatchet

        if room_forces_hatchet:
            logger.info(
                "Room forces Hatchet workflow",
                room_id=config.room_id,
                transcript_id=config.transcript_id,
            )

        if use_hatchet:
            # First check if we can replay (outside a transaction since it's read-only)
            transcript = await transcripts_controller.get_by_id(config.transcript_id)
            if transcript and transcript.workflow_run_id and not force:
                can_replay = await HatchetClientManager.can_replay(
                    transcript.workflow_run_id
                )
                if can_replay:
                    await HatchetClientManager.replay_workflow(
                        transcript.workflow_run_id
                    )
                    logger.info(
                        "Replaying Hatchet workflow",
                        workflow_id=transcript.workflow_run_id,
                    )
                    return None

            # Force: cancel the old workflow if it exists
            if force and transcript and transcript.workflow_run_id:
                await HatchetClientManager.cancel_workflow(transcript.workflow_run_id)
                logger.info(
                    "Cancelled old workflow (--force)",
                    workflow_id=transcript.workflow_run_id,
                )
                await transcripts_controller.update(
                    transcript, {"workflow_run_id": None}
                )

            # Re-fetch and check for concurrent dispatch (optimistic approach).
            # No database lock - worst case is a duplicate dispatch, but Hatchet
            # workflows are idempotent so this is acceptable.
            transcript = await transcripts_controller.get_by_id(config.transcript_id)
            if transcript and transcript.workflow_run_id:
                # Another process started a workflow between validation and now
                try:
                    status = await HatchetClientManager.get_workflow_run_status(
                        transcript.workflow_run_id
                    )
                    if status in (V1TaskStatus.RUNNING, V1TaskStatus.QUEUED):
                        logger.info(
                            "Concurrent workflow detected, skipping dispatch",
                            workflow_id=transcript.workflow_run_id,
                        )
                        return None
                except ApiException:
                    # Workflow might be gone (404) or API issue - proceed with a new workflow
                    pass

            workflow_id = await HatchetClientManager.start_workflow(
                workflow_name="DiarizationPipeline",
                input_data={
                    "recording_id": config.recording_id,
                    "room_name": None,
                    "tracks": [{"s3_key": k} for k in config.track_keys],
                    "bucket_name": config.bucket_name,
                    "transcript_id": config.transcript_id,
                    "room_id": config.room_id,
                },
                additional_metadata={
                    "transcript_id": config.transcript_id,
                    "recording_id": config.recording_id,
                    "daily_recording_id": config.recording_id,
                },
            )

            if transcript:
                await transcripts_controller.update(
                    transcript, {"workflow_run_id": workflow_id}
                )

            logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
            return None

        # Celery pipeline (durable workflows disabled)
        return task_pipeline_multitrack_process.delay(
            transcript_id=config.transcript_id,
            bucket_name=config.bucket_name,
            track_keys=config.track_keys,
        )
    elif isinstance(config, FileProcessingConfig):
        return task_pipeline_file_process.delay(transcript_id=config.transcript_id)
    else:
        assert_never(config)
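

# Illustrative sketch (not part of the original module): the intended call order
# when composing these helpers from an endpoint or a Celery task. Raising a plain
# Exception on error is the Celery-style handling from the module docstring; an
# HTTP endpoint would raise HTTPException instead.
async def _example_process_transcript(transcript: Transcript) -> None:
    validation = await validate_transcript_for_processing(transcript)
    if not isinstance(validation, ValidationOk):
        raise Exception(validation.detail)

    config = await prepare_transcript_processing(validation)
    if isinstance(config, ProcessError):
        raise Exception(config.detail)

    await dispatch_transcript_processing(config)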


def task_is_scheduled_or_active(task_name: str, **kwargs):
    inspect = celery.current_app.control.inspect()

    scheduled = inspect.scheduled() or {}
    active = inspect.active() or {}
    # Merge per worker: a plain dict union would drop a worker's scheduled
    # tasks whenever the same worker also reports active ones.
    for worker in scheduled.keys() | active.keys():
        for task in scheduled.get(worker, []) + active.get(worker, []):
            if task["name"] == task_name and task["kwargs"] == kwargs:
                return True

    return False
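

# Illustrative sketch (not part of the original module): the check above matches
# on the fully qualified task name plus an exact kwargs comparison, so callers
# must pass the same kwargs the task was dispatched with. The helper name is an
# assumption made for this example.
def _example_file_task_pending(transcript_id: str) -> bool:
    return task_is_scheduled_or_active(
        "reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
        transcript_id=transcript_id,
    )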