mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-04 01:46:47 +00:00
pipeline type fixes (#812)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import shutil
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Sequence
|
||||
|
||||
import sqlalchemy
|
||||
from fastapi import HTTPException
|
||||
@@ -180,7 +180,7 @@ class TranscriptDuration(BaseModel):
|
||||
|
||||
|
||||
class TranscriptWaveform(BaseModel):
|
||||
waveform: list[float]
|
||||
waveform: Sequence[float]
|
||||
|
||||
|
||||
class TranscriptEvent(BaseModel):
|
||||
|
||||
@@ -16,7 +16,7 @@ import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
import httpx
|
||||
from hatchet_sdk import Context
|
||||
@@ -162,7 +162,15 @@ def _spawn_storage():
|
||||
)
|
||||
|
||||
|
||||
def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def with_error_handling(
|
||||
step_name: str, set_error_status: bool = True
|
||||
) -> Callable[
|
||||
[Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]],
|
||||
Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
|
||||
]:
|
||||
"""Decorator that handles task failures uniformly.
|
||||
|
||||
Args:
|
||||
@@ -170,9 +178,11 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
|
||||
set_error_status: Whether to set transcript status to 'error' on failure.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
def decorator(
|
||||
func: Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
|
||||
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(input: PipelineInput, ctx: Context):
|
||||
async def wrapper(input: PipelineInput, ctx: Context) -> R:
|
||||
try:
|
||||
return await func(input, ctx)
|
||||
except Exception as e:
|
||||
@@ -186,7 +196,7 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
|
||||
await set_workflow_error_status(input.transcript_id)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -256,7 +266,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
|
||||
|
||||
recording = ctx.task_output(get_recording)
|
||||
mtg_session_id = recording.mtg_session_id
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.transcripts import ( # noqa: PLC0415
|
||||
TranscriptParticipant,
|
||||
@@ -264,16 +273,17 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
|
||||
)
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if transcript:
|
||||
# Note: title NOT cleared - preserves existing titles
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": [],
|
||||
"topics": [],
|
||||
"participants": [],
|
||||
},
|
||||
)
|
||||
if not transcript:
|
||||
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||
# Note: title NOT cleared - preserves existing titles
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": [],
|
||||
"topics": [],
|
||||
"participants": [],
|
||||
},
|
||||
)
|
||||
|
||||
mtg_session_id = assert_non_none_and_non_empty(
|
||||
mtg_session_id, "mtg_session_id is required"
|
||||
@@ -640,6 +650,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
||||
|
||||
async with fresh_db_connection():
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if not transcript:
|
||||
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||
|
||||
for chunk in topic_chunks:
|
||||
topic = TranscriptTopic(
|
||||
@@ -647,7 +659,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
||||
summary=chunk.summary,
|
||||
timestamp=chunk.timestamp,
|
||||
transcript=" ".join(w.text for w in chunk.words),
|
||||
words=[w.model_dump() for w in chunk.words],
|
||||
words=chunk.words,
|
||||
)
|
||||
await transcripts_controller.upsert_topic(transcript, topic)
|
||||
await append_event_and_broadcast(
|
||||
@@ -697,6 +709,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
|
||||
async with fresh_db_connection():
|
||||
ctx.log("generate_title: DB connection established")
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if not transcript:
|
||||
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||
ctx.log(f"generate_title: fetched transcript, exists={transcript is not None}")
|
||||
|
||||
async def on_title_callback(data):
|
||||
|
||||
@@ -42,7 +42,7 @@ class RecordingResult(BaseModel):
|
||||
|
||||
id: NonEmptyString | None
|
||||
mtg_session_id: NonEmptyString | None
|
||||
duration: float
|
||||
duration: int | None
|
||||
|
||||
|
||||
class ParticipantsResult(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user