diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 2beef845..b3f6d49c 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -5,7 +5,7 @@ import shutil
 from contextlib import asynccontextmanager
 from datetime import datetime, timedelta, timezone
 from pathlib import Path
-from typing import Any, Literal
+from typing import Any, Literal, Sequence
 
 import sqlalchemy
 from fastapi import HTTPException
@@ -180,7 +180,7 @@ class TranscriptDuration(BaseModel):
 
 
 class TranscriptWaveform(BaseModel):
-    waveform: list[float]
+    waveform: Sequence[float]
 
 
 class TranscriptEvent(BaseModel):
diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py
index 224ebcef..b7b0559e 100644
--- a/server/reflector/hatchet/workflows/diarization_pipeline.py
+++ b/server/reflector/hatchet/workflows/diarization_pipeline.py
@@ -16,7 +16,7 @@ import tempfile
 from contextlib import asynccontextmanager
 from datetime import timedelta
 from pathlib import Path
-from typing import Callable
+from typing import Any, Callable, Coroutine, TypeVar
 
 import httpx
 from hatchet_sdk import Context
@@ -162,7 +162,15 @@ def _spawn_storage():
     )
 
 
-def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
+R = TypeVar("R")
+
+
+def with_error_handling(
+    step_name: str, set_error_status: bool = True
+) -> Callable[
+    [Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]],
+    Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
+]:
     """Decorator that handles task failures uniformly.
 
     Args:
@@ -170,9 +178,11 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
         set_error_status: Whether to set transcript status to 'error' on failure.
     """
 
-    def decorator(func: Callable) -> Callable:
+    def decorator(
+        func: Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
+    ) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
         @functools.wraps(func)
-        async def wrapper(input: PipelineInput, ctx: Context):
+        async def wrapper(input: PipelineInput, ctx: Context) -> R:
             try:
                 return await func(input, ctx)
             except Exception as e:
@@ -186,7 +196,7 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
                     await set_workflow_error_status(input.transcript_id)
                 raise
 
-        return wrapper
+        return wrapper  # type: ignore[return-value]
 
     return decorator
 
@@ -256,7 +266,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
 
     recording = ctx.task_output(get_recording)
     mtg_session_id = recording.mtg_session_id
-
    async with fresh_db_connection():
         from reflector.db.transcripts import (  # noqa: PLC0415
             TranscriptParticipant,
@@ -264,16 +273,17 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
         )
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
-        if transcript:
-            # Note: title NOT cleared - preserves existing titles
-            await transcripts_controller.update(
-                transcript,
-                {
-                    "events": [],
-                    "topics": [],
-                    "participants": [],
-                },
-            )
+        if not transcript:
+            raise ValueError(f"Transcript {input.transcript_id} not found")
+        # Note: title NOT cleared - preserves existing titles
+        await transcripts_controller.update(
+            transcript,
+            {
+                "events": [],
+                "topics": [],
+                "participants": [],
+            },
+        )
         mtg_session_id = assert_non_none_and_non_empty(
             mtg_session_id, "mtg_session_id is required"
         )
@@ -640,6 +650,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
 
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
+        if not transcript:
+            raise ValueError(f"Transcript {input.transcript_id} not found")
 
         for chunk in topic_chunks:
             topic = TranscriptTopic(
@@ -647,7 +659,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
                 summary=chunk.summary,
                 timestamp=chunk.timestamp,
                 transcript=" ".join(w.text for w in chunk.words),
-                words=[w.model_dump() for w in chunk.words],
+                words=chunk.words,
             )
             await transcripts_controller.upsert_topic(transcript, topic)
             await append_event_and_broadcast(
@@ -697,6 +709,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     async with fresh_db_connection():
         ctx.log("generate_title: DB connection established")
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
+        if not transcript:
+            raise ValueError(f"Transcript {input.transcript_id} not found")
         ctx.log(f"generate_title: fetched transcript, exists={transcript is not None}")
 
         async def on_title_callback(data):
diff --git a/server/reflector/hatchet/workflows/models.py b/server/reflector/hatchet/workflows/models.py
index 748361d9..c2a98901 100644
--- a/server/reflector/hatchet/workflows/models.py
+++ b/server/reflector/hatchet/workflows/models.py
@@ -42,7 +42,7 @@ class RecordingResult(BaseModel):
 
     id: NonEmptyString | None
     mtg_session_id: NonEmptyString | None
-    duration: float
+    duration: int | None
 
 
 class ParticipantsResult(BaseModel):