pipeline type fixes (#812)

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
2025-12-26 11:28:43 -05:00
committed by GitHub
parent bab1e2d537
commit 5baa6dd92e
3 changed files with 34 additions and 20 deletions

View File

@@ -5,7 +5,7 @@ import shutil
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Literal from typing import Any, Literal, Sequence
import sqlalchemy import sqlalchemy
from fastapi import HTTPException from fastapi import HTTPException
@@ -180,7 +180,7 @@ class TranscriptDuration(BaseModel):
class TranscriptWaveform(BaseModel): class TranscriptWaveform(BaseModel):
waveform: list[float] waveform: Sequence[float]
class TranscriptEvent(BaseModel): class TranscriptEvent(BaseModel):

View File

@@ -16,7 +16,7 @@ import tempfile
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Any, Callable, Coroutine, TypeVar
import httpx import httpx
from hatchet_sdk import Context from hatchet_sdk import Context
@@ -162,7 +162,15 @@ def _spawn_storage():
) )
def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable: R = TypeVar("R")
def with_error_handling(
step_name: str, set_error_status: bool = True
) -> Callable[
[Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]],
Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
]:
"""Decorator that handles task failures uniformly. """Decorator that handles task failures uniformly.
Args: Args:
@@ -170,9 +178,11 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
set_error_status: Whether to set transcript status to 'error' on failure. set_error_status: Whether to set transcript status to 'error' on failure.
""" """
def decorator(func: Callable) -> Callable: def decorator(
func: Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(input: PipelineInput, ctx: Context): async def wrapper(input: PipelineInput, ctx: Context) -> R:
try: try:
return await func(input, ctx) return await func(input, ctx)
except Exception as e: except Exception as e:
@@ -186,7 +196,7 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
await set_workflow_error_status(input.transcript_id) await set_workflow_error_status(input.transcript_id)
raise raise
return wrapper return wrapper # type: ignore[return-value]
return decorator return decorator
@@ -256,7 +266,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
recording = ctx.task_output(get_recording) recording = ctx.task_output(get_recording)
mtg_session_id = recording.mtg_session_id mtg_session_id = recording.mtg_session_id
async with fresh_db_connection(): async with fresh_db_connection():
from reflector.db.transcripts import ( # noqa: PLC0415 from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptParticipant, TranscriptParticipant,
@@ -264,16 +273,17 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
) )
transcript = await transcripts_controller.get_by_id(input.transcript_id) transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript: if not transcript:
# Note: title NOT cleared - preserves existing titles raise ValueError(f"Transcript {input.transcript_id} not found")
await transcripts_controller.update( # Note: title NOT cleared - preserves existing titles
transcript, await transcripts_controller.update(
{ transcript,
"events": [], {
"topics": [], "events": [],
"participants": [], "topics": [],
}, "participants": [],
) },
)
mtg_session_id = assert_non_none_and_non_empty( mtg_session_id = assert_non_none_and_non_empty(
mtg_session_id, "mtg_session_id is required" mtg_session_id, "mtg_session_id is required"
@@ -640,6 +650,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
async with fresh_db_connection(): async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id) transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
for chunk in topic_chunks: for chunk in topic_chunks:
topic = TranscriptTopic( topic = TranscriptTopic(
@@ -647,7 +659,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
summary=chunk.summary, summary=chunk.summary,
timestamp=chunk.timestamp, timestamp=chunk.timestamp,
transcript=" ".join(w.text for w in chunk.words), transcript=" ".join(w.text for w in chunk.words),
words=[w.model_dump() for w in chunk.words], words=chunk.words,
) )
await transcripts_controller.upsert_topic(transcript, topic) await transcripts_controller.upsert_topic(transcript, topic)
await append_event_and_broadcast( await append_event_and_broadcast(
@@ -697,6 +709,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
async with fresh_db_connection(): async with fresh_db_connection():
ctx.log("generate_title: DB connection established") ctx.log("generate_title: DB connection established")
transcript = await transcripts_controller.get_by_id(input.transcript_id) transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
ctx.log(f"generate_title: fetched transcript, exists={transcript is not None}") ctx.log(f"generate_title: fetched transcript, exists={transcript is not None}")
async def on_title_callback(data): async def on_title_callback(data):

View File

@@ -42,7 +42,7 @@ class RecordingResult(BaseModel):
id: NonEmptyString | None id: NonEmptyString | None
mtg_session_id: NonEmptyString | None mtg_session_id: NonEmptyString | None
duration: float duration: int | None
class ParticipantsResult(BaseModel): class ParticipantsResult(BaseModel):