diff --git a/server/reflector/dailyco_api/webhook_utils.py b/server/reflector/dailyco_api/webhook_utils.py
index b10d4fa2..27d5fb4e 100644
--- a/server/reflector/dailyco_api/webhook_utils.py
+++ b/server/reflector/dailyco_api/webhook_utils.py
@@ -195,7 +195,6 @@ def parse_recording_error(event: DailyWebhookEvent) -> RecordingErrorPayload:
     return RecordingErrorPayload(**event.payload)
 
 
-# Webhook event type to parser mapping
 WEBHOOK_PARSERS = {
    "participant.joined": parse_participant_joined,
    "participant.left": parse_participant_left,
diff --git a/server/reflector/services/transcript_process.py b/server/reflector/services/transcript_process.py
new file mode 100644
index 00000000..bc48a4eb
--- /dev/null
+++ b/server/reflector/services/transcript_process.py
@@ -0,0 +1,169 @@
+"""
+Transcript processing service - shared logic for HTTP endpoints and Celery tasks.
+
+This module provides result-based error handling that works in both contexts:
+- HTTP endpoint: converts errors to HTTPException
+- Celery task: converts errors to Exception
+"""
+
+from dataclasses import dataclass
+from typing import Literal, Union
+
+import celery
+from celery.result import AsyncResult
+
+from reflector.db.recordings import recordings_controller
+from reflector.db.transcripts import Transcript
+from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
+from reflector.pipelines.main_multitrack_pipeline import (
+    task_pipeline_multitrack_process,
+)
+from reflector.utils.match import absurd
+from reflector.utils.string import NonEmptyString
+
+
+@dataclass
+class ProcessError:
+    detail: NonEmptyString
+
+
+@dataclass
+class FileProcessingConfig:
+    transcript_id: NonEmptyString
+    mode: Literal["file"] = "file"
+
+
+@dataclass
+class MultitrackProcessingConfig:
+    transcript_id: NonEmptyString
+    bucket_name: NonEmptyString
+    track_keys: list[str]
+    mode: Literal["multitrack"] = "multitrack"
+
+
+ProcessingConfig = Union[FileProcessingConfig, MultitrackProcessingConfig]
+PrepareResult = Union[ProcessingConfig, ProcessError]
+
+
+@dataclass
+class ValidationOk:
+    # a transcript doesn't always have a recording_id yet
+    recording_id: NonEmptyString | None
+    transcript_id: NonEmptyString
+
+
+@dataclass
+class ValidationLocked:
+    detail: NonEmptyString
+
+
+@dataclass
+class ValidationNotReady:
+    detail: NonEmptyString
+
+
+@dataclass
+class ValidationAlreadyScheduled:
+    detail: NonEmptyString
+
+
+ValidationError = Union[
+    ValidationNotReady, ValidationLocked, ValidationAlreadyScheduled
+]
+ValidationResult = Union[ValidationOk, ValidationError]
+
+
+@dataclass
+class DispatchOk:
+    status: Literal["ok"] = "ok"
+
+
+@dataclass
+class DispatchAlreadyRunning:
+    status: Literal["already_running"] = "already_running"
+
+
+DispatchResult = Union[
+    DispatchOk, DispatchAlreadyRunning, ProcessError, ValidationError
+]
+
+
+async def validate_transcript_for_processing(
+    transcript: Transcript,
+) -> ValidationResult:
+    if transcript.locked:
+        return ValidationLocked(detail="Transcript is locked")
+
+    if transcript.status == "idle":
+        return ValidationNotReady(detail="Recording is not ready for processing")
+
+    if task_is_scheduled_or_active(
+        "reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
+        transcript_id=transcript.id,
+    ) or task_is_scheduled_or_active(
+        "reflector.pipelines.main_multitrack_pipeline.task_pipeline_multitrack_process",
+        transcript_id=transcript.id,
+    ):
+        return ValidationAlreadyScheduled(detail="already running")
+
+    return ValidationOk(
+        recording_id=transcript.recording_id, transcript_id=transcript.id
+    )
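+
+
+# Illustrative only: the intended caller-side narrowing of these result unions
+# (it mirrors the HTTP view further down; `transcript` is assumed to be a
+# loaded Transcript):
+#
+#     validation = await validate_transcript_for_processing(transcript)
+#     if isinstance(validation, ValidationOk):
+#         config = await prepare_transcript_processing(validation)
+#         if not isinstance(config, ProcessError):
+#             dispatch_transcript_processing(config)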
+
+
+async def prepare_transcript_processing(validation: ValidationOk) -> PrepareResult:
+    """
+    Determine processing mode strictly from transcript/recording data in the
+    database (avoids S3 scans).
+    """
+    bucket_name: str | None = None
+    track_keys: list[str] | None = None
+
+    if validation.recording_id:
+        recording = await recordings_controller.get_by_id(validation.recording_id)
+        if recording:
+            bucket_name = recording.bucket_name
+            track_keys = recording.track_keys
+
+    if track_keys is not None and len(track_keys) == 0:
+        return ProcessError(
+            detail="Recording has an empty track_keys list; expected None or a non-empty list",
+        )
+    if track_keys is not None and not bucket_name:
+        return ProcessError(
+            detail="Bucket name must be specified",
+        )
+
+    if track_keys:
+        return MultitrackProcessingConfig(
+            bucket_name=bucket_name,  # type: ignore[arg-type]  # validated above
+            track_keys=track_keys,
+            transcript_id=validation.transcript_id,
+        )
+
+    return FileProcessingConfig(
+        transcript_id=validation.transcript_id,
+    )
+
+
+def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult:
+    if isinstance(config, MultitrackProcessingConfig):
+        return task_pipeline_multitrack_process.delay(
+            transcript_id=config.transcript_id,
+            bucket_name=config.bucket_name,
+            track_keys=config.track_keys,
+        )
+    elif isinstance(config, FileProcessingConfig):
+        return task_pipeline_file_process.delay(transcript_id=config.transcript_id)
+    else:
+        absurd(config)
+
+
+def task_is_scheduled_or_active(task_name: str, **kwargs) -> bool:
+    inspect = celery.current_app.control.inspect()
+
+    # inspect calls return None when no workers reply
+    scheduled = inspect.scheduled() or {}
+    active = inspect.active() or {}
+
+    for tasks in (scheduled | active).values():
+        for task in tasks:
+            if task["name"] == task_name and task["kwargs"] == kwargs:
+                return True
+
+    return False
diff --git a/server/reflector/tools/process_transcript.py b/server/reflector/tools/process_transcript.py
new file mode 100644
index 00000000..ce9efd71
--- /dev/null
+++ b/server/reflector/tools/process_transcript.py
@@ -0,0 +1,127 @@
+"""
+Process transcript by ID - auto-detects multitrack vs file pipeline.
+
+Usage:
+    uv run -m reflector.tools.process_transcript <transcript_id> [--sync]
+
+    # Or via docker:
+    docker compose exec server uv run -m reflector.tools.process_transcript <transcript_id>
+"""
+
+import argparse
+import asyncio
+import sys
+import time
+from typing import Callable
+
+from celery.result import AsyncResult
+
+from reflector.db.transcripts import Transcript, transcripts_controller
+from reflector.services.transcript_process import (
+    FileProcessingConfig,
+    MultitrackProcessingConfig,
+    PrepareResult,
+    ProcessError,
+    ValidationError,
+    ValidationResult,
+    dispatch_transcript_processing,
+    prepare_transcript_processing,
+    validate_transcript_for_processing,
+)
+
+
+async def process_transcript_inner(
+    transcript: Transcript,
+    on_validation: Callable[[ValidationResult], None],
+    on_preprocess: Callable[[PrepareResult], None],
+) -> AsyncResult:
+    # The callbacks are expected to raise or exit on error results; otherwise
+    # the subsequent steps would be handed an error value.
+    validation = await validate_transcript_for_processing(transcript)
+    on_validation(validation)
+    config = await prepare_transcript_processing(validation)
+    on_preprocess(config)
+    return dispatch_transcript_processing(config)
+
+
+async def process_transcript(transcript_id: str, sync: bool = False) -> None:
+    """
+    Process a transcript by ID, auto-detecting multitrack vs file pipeline.
+
+    Args:
+        transcript_id: The transcript UUID
+        sync: If True, wait for task completion. If False, dispatch and exit.
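+
+    Example (transcript ID is illustrative):
+        await process_transcript("3f2a9c1d", sync=True)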
+    """
+    from reflector.db import get_database
+
+    database = get_database()
+    await database.connect()
+
+    try:
+        transcript = await transcripts_controller.get_by_id(transcript_id)
+        if not transcript:
+            print(f"Error: Transcript {transcript_id} not found", file=sys.stderr)
+            sys.exit(1)
+
+        print(f"Found transcript: {transcript.title or transcript_id}", file=sys.stderr)
+        print(f"  Status: {transcript.status}", file=sys.stderr)
+        print(f"  Recording ID: {transcript.recording_id or 'None'}", file=sys.stderr)
+
+        def on_validation(validation: ValidationResult) -> None:
+            if isinstance(validation, ValidationError):
+                print(f"Error: {validation.detail}", file=sys.stderr)
+                sys.exit(1)
+
+        def on_preprocess(config: PrepareResult) -> None:
+            if isinstance(config, ProcessError):
+                print(f"Error: {config.detail}", file=sys.stderr)
+                sys.exit(1)
+            elif isinstance(config, MultitrackProcessingConfig):
+                print("Dispatching multitrack pipeline", file=sys.stderr)
+                print(f"  Bucket: {config.bucket_name}", file=sys.stderr)
+                print(f"  Tracks: {len(config.track_keys)}", file=sys.stderr)
+            elif isinstance(config, FileProcessingConfig):
+                print("Dispatching file pipeline", file=sys.stderr)
+
+        result = await process_transcript_inner(
+            transcript, on_validation=on_validation, on_preprocess=on_preprocess
+        )
+
+        if sync:
+            print("Waiting for task completion...", file=sys.stderr)
+            while not result.ready():
+                print(f"  Status: {result.state}", file=sys.stderr)
+                # a blocking sleep is acceptable in this one-off CLI coroutine
+                time.sleep(5)
+
+            if result.successful():
+                print("Task completed successfully", file=sys.stderr)
+            else:
+                print(f"Task failed: {result.result}", file=sys.stderr)
+                sys.exit(1)
+        else:
+            print(
+                "Task dispatched (use --sync to wait for completion)", file=sys.stderr
+            )
+
+    finally:
+        await database.disconnect()
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Process transcript by ID - auto-detects multitrack vs file pipeline"
+    )
+    parser.add_argument(
+        "transcript_id",
+        help="Transcript UUID to process",
+    )
+    parser.add_argument(
+        "--sync",
+        action="store_true",
+        help="Wait for task completion instead of just dispatching",
+    )
+
+    args = parser.parse_args()
+    asyncio.run(process_transcript(args.transcript_id, sync=args.sync))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/server/reflector/utils/match.py b/server/reflector/utils/match.py
new file mode 100644
index 00000000..e0f6bc53
--- /dev/null
+++ b/server/reflector/utils/match.py
@@ -0,0 +1,10 @@
+from typing import NoReturn
+
+
+def assert_exhaustiveness(x: NoReturn) -> NoReturn:
+    """Provide an assertion at type-check time that this function is never called."""
+    raise AssertionError(f"Invalid value: {x!r}")
+
+
+def absurd(x: NoReturn) -> NoReturn:
+    """Shorthand alias for assert_exhaustiveness."""
+    return assert_exhaustiveness(x)
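+
+
+# Illustrative usage (hypothetical `describe` function): with an
+# `else: absurd(value)` branch, a type checker such as mypy reports an error
+# whenever the if/elif chain misses a union member.
+#
+#     def describe(value: int | str) -> str:
+#         if isinstance(value, int):
+#             return "integer"
+#         elif isinstance(value, str):
+#             return "string"
+#         else:
+#             absurd(value)  # type error here if a member were unhandled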
diff --git a/server/reflector/views/transcripts_process.py b/server/reflector/views/transcripts_process.py
index cee1e10d..88f11e71 100644
--- a/server/reflector/views/transcripts_process.py
+++ b/server/reflector/views/transcripts_process.py
@@ -1,16 +1,21 @@
 from typing import Annotated, Optional
 
-import celery
 from fastapi import APIRouter, Depends, HTTPException
 from pydantic import BaseModel
 
 import reflector.auth as auth
-from reflector.db.recordings import recordings_controller
 from reflector.db.transcripts import transcripts_controller
-from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
-from reflector.pipelines.main_multitrack_pipeline import (
-    task_pipeline_multitrack_process,
+from reflector.services.transcript_process import (
+    ProcessError,
+    ValidationAlreadyScheduled,
+    ValidationError,
+    ValidationOk,
+    dispatch_transcript_processing,
+    prepare_transcript_processing,
+    validate_transcript_for_processing,
 )
+from reflector.utils.match import absurd
 
 router = APIRouter()
@@ -23,68 +28,28 @@ class ProcessStatus(BaseModel):
 async def transcript_process(
     transcript_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-):
+) -> ProcessStatus:
     user_id = user["sub"] if user else None
     transcript = await transcripts_controller.get_by_id_for_http(
         transcript_id, user_id=user_id
     )
 
-    if transcript.locked:
-        raise HTTPException(status_code=400, detail="Transcript is locked")
-
-    if transcript.status == "idle":
-        raise HTTPException(
-            status_code=400, detail="Recording is not ready for processing"
-        )
-
-    # avoid duplicate scheduling for either pipeline
-    if task_is_scheduled_or_active(
-        "reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
-        transcript_id=transcript_id,
-    ) or task_is_scheduled_or_active(
-        "reflector.pipelines.main_multitrack_pipeline.task_pipeline_multitrack_process",
-        transcript_id=transcript_id,
-    ):
-        return ProcessStatus(status="already running")
-
-    # Determine processing mode strictly from DB to avoid S3 scans
-    bucket_name = None
-    track_keys: list[str] = []
-
-    if transcript.recording_id:
-        recording = await recordings_controller.get_by_id(transcript.recording_id)
-        if recording:
-            bucket_name = recording.bucket_name
-            track_keys = recording.track_keys
-            if track_keys is not None and len(track_keys) == 0:
-                raise HTTPException(
-                    status_code=500,
-                    detail="No track keys found, must be either > 0 or None",
-                )
-            if track_keys is not None and not bucket_name:
-                raise HTTPException(
-                    status_code=500, detail="Bucket name must be specified"
-                )
-
-    if track_keys:
-        task_pipeline_multitrack_process.delay(
-            transcript_id=transcript_id,
-            bucket_name=bucket_name,
-            track_keys=track_keys,
-        )
+    validation = await validate_transcript_for_processing(transcript)
+    # ValidationAlreadyScheduled is part of the ValidationError union, so it
+    # must be handled before the union catch-all below, or it would surface
+    # as a 400 instead of a normal response.
+    if isinstance(validation, ValidationAlreadyScheduled):
+        return ProcessStatus(status=validation.detail)
+    elif isinstance(validation, ValidationError):
+        raise HTTPException(status_code=400, detail=validation.detail)
+    elif isinstance(validation, ValidationOk):
+        pass
     else:
-        # Default single-file pipeline
-        task_pipeline_file_process.delay(transcript_id=transcript_id)
+        absurd(validation)
 
-    return ProcessStatus(status="ok")
+    config = await prepare_transcript_processing(validation)
 
-
-def task_is_scheduled_or_active(task_name: str, **kwargs):
-    inspect = celery.current_app.control.inspect()
-
-    for worker, tasks in (inspect.scheduled() | inspect.active()).items():
-        for task in tasks:
-            if task["name"] == task_name and task["kwargs"] == kwargs:
-                return True
-
-    return False
+    if isinstance(config, ProcessError):
+        raise HTTPException(status_code=500, detail=config.detail)
+    else:
+        dispatch_transcript_processing(config)
+        return ProcessStatus(status="ok")
diff --git a/server/tests/test_transcripts_process.py b/server/tests/test_transcripts_process.py
index 3a0614c1..e3d749df 100644
--- a/server/tests/test_transcripts_process.py
+++ b/server/tests/test_transcripts_process.py
@@ -139,10 +139,10 @@ async def test_whereby_recording_uses_file_pipeline(client):
 
     with (
         patch(
-            "reflector.views.transcripts_process.task_pipeline_file_process"
"reflector.services.transcript_process.task_pipeline_file_process" ) as mock_file_pipeline, patch( - "reflector.views.transcripts_process.task_pipeline_multitrack_process" + "reflector.services.transcript_process.task_pipeline_multitrack_process" ) as mock_multitrack_pipeline, ): response = await client.post(f"/transcripts/{transcript.id}/process") @@ -194,10 +194,10 @@ async def test_dailyco_recording_uses_multitrack_pipeline(client): with ( patch( - "reflector.views.transcripts_process.task_pipeline_file_process" + "reflector.services.transcript_process.task_pipeline_file_process" ) as mock_file_pipeline, patch( - "reflector.views.transcripts_process.task_pipeline_multitrack_process" + "reflector.services.transcript_process.task_pipeline_multitrack_process" ) as mock_multitrack_pipeline, ): response = await client.post(f"/transcripts/{transcript.id}/process")