Fix pipeline type annotations and add missing-transcript guards (#812)

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
2025-12-26 11:28:43 -05:00
committed by GitHub
parent bab1e2d537
commit 5baa6dd92e
3 changed files with 34 additions and 20 deletions

View File

@@ -5,7 +5,7 @@ import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, Sequence
import sqlalchemy
from fastapi import HTTPException
@@ -180,7 +180,7 @@ class TranscriptDuration(BaseModel):
class TranscriptWaveform(BaseModel):
waveform: list[float]
waveform: Sequence[float]
class TranscriptEvent(BaseModel):

View File

@@ -16,7 +16,7 @@ import tempfile
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from typing import Callable
from typing import Any, Callable, Coroutine, TypeVar
import httpx
from hatchet_sdk import Context
@@ -162,7 +162,15 @@ def _spawn_storage():
)
def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
R = TypeVar("R")
def with_error_handling(
step_name: str, set_error_status: bool = True
) -> Callable[
[Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]],
Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
]:
"""Decorator that handles task failures uniformly.
Args:
@@ -170,9 +178,11 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
set_error_status: Whether to set transcript status to 'error' on failure.
"""
def decorator(func: Callable) -> Callable:
def decorator(
func: Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(input: PipelineInput, ctx: Context):
async def wrapper(input: PipelineInput, ctx: Context) -> R:
try:
return await func(input, ctx)
except Exception as e:
@@ -186,7 +196,7 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
await set_workflow_error_status(input.transcript_id)
raise
return wrapper
return wrapper # type: ignore[return-value]
return decorator
@@ -256,7 +266,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
recording = ctx.task_output(get_recording)
mtg_session_id = recording.mtg_session_id
async with fresh_db_connection():
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptParticipant,
@@ -264,7 +273,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
)
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
# Note: title NOT cleared - preserves existing titles
await transcripts_controller.update(
transcript,
@@ -640,6 +650,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
for chunk in topic_chunks:
topic = TranscriptTopic(
@@ -647,7 +659,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
summary=chunk.summary,
timestamp=chunk.timestamp,
transcript=" ".join(w.text for w in chunk.words),
words=[w.model_dump() for w in chunk.words],
words=chunk.words,
)
await transcripts_controller.upsert_topic(transcript, topic)
await append_event_and_broadcast(
@@ -697,6 +709,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
async with fresh_db_connection():
ctx.log("generate_title: DB connection established")
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
ctx.log(f"generate_title: fetched transcript, exists={transcript is not None}")
async def on_title_callback(data):

View File

@@ -42,7 +42,7 @@ class RecordingResult(BaseModel):
id: NonEmptyString | None
mtg_session_id: NonEmptyString | None
duration: float
duration: int | None
class ParticipantsResult(BaseModel):