pipeline type fixes (#812)

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
2025-12-26 11:28:43 -05:00
committed by GitHub
parent bab1e2d537
commit 5baa6dd92e
3 changed files with 34 additions and 20 deletions

View File

@@ -16,7 +16,7 @@ import tempfile
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from typing import Callable
from typing import Any, Callable, Coroutine, TypeVar
import httpx
from hatchet_sdk import Context
@@ -162,7 +162,15 @@ def _spawn_storage():
)
def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
R = TypeVar("R")
def with_error_handling(
step_name: str, set_error_status: bool = True
) -> Callable[
[Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]],
Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
]:
"""Decorator that handles task failures uniformly.
Args:
@@ -170,9 +178,11 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
set_error_status: Whether to set transcript status to 'error' on failure.
"""
def decorator(func: Callable) -> Callable:
def decorator(
func: Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(input: PipelineInput, ctx: Context):
async def wrapper(input: PipelineInput, ctx: Context) -> R:
try:
return await func(input, ctx)
except Exception as e:
@@ -186,7 +196,7 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
await set_workflow_error_status(input.transcript_id)
raise
return wrapper
return wrapper # type: ignore[return-value]
return decorator
@@ -256,7 +266,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
recording = ctx.task_output(get_recording)
mtg_session_id = recording.mtg_session_id
async with fresh_db_connection():
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptParticipant,
@@ -264,16 +273,17 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
)
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
# Note: title NOT cleared - preserves existing titles
await transcripts_controller.update(
transcript,
{
"events": [],
"topics": [],
"participants": [],
},
)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
# Note: title NOT cleared - preserves existing titles
await transcripts_controller.update(
transcript,
{
"events": [],
"topics": [],
"participants": [],
},
)
mtg_session_id = assert_non_none_and_non_empty(
mtg_session_id, "mtg_session_id is required"
@@ -640,6 +650,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
for chunk in topic_chunks:
topic = TranscriptTopic(
@@ -647,7 +659,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
summary=chunk.summary,
timestamp=chunk.timestamp,
transcript=" ".join(w.text for w in chunk.words),
words=[w.model_dump() for w in chunk.words],
words=chunk.words,
)
await transcripts_controller.upsert_topic(transcript, topic)
await append_event_and_broadcast(
@@ -697,6 +709,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
async with fresh_db_connection():
ctx.log("generate_title: DB connection established")
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")
ctx.log(f"generate_title: fetched transcript, exists={transcript is not None}")
async def on_title_callback(data):

View File

@@ -42,7 +42,7 @@ class RecordingResult(BaseModel):
id: NonEmptyString | None
mtg_session_id: NonEmptyString | None
duration: float
duration: int | None
class ParticipantsResult(BaseModel):