mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-04 18:06:48 +00:00
pipeline type fixes (#812)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import shutil
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, Sequence
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
@@ -180,7 +180,7 @@ class TranscriptDuration(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class TranscriptWaveform(BaseModel):
|
class TranscriptWaveform(BaseModel):
|
||||||
waveform: list[float]
|
waveform: Sequence[float]
|
||||||
|
|
||||||
|
|
||||||
class TranscriptEvent(BaseModel):
|
class TranscriptEvent(BaseModel):
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import tempfile
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Any, Callable, Coroutine, TypeVar
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from hatchet_sdk import Context
|
from hatchet_sdk import Context
|
||||||
@@ -162,7 +162,15 @@ def _spawn_storage():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
def with_error_handling(
|
||||||
|
step_name: str, set_error_status: bool = True
|
||||||
|
) -> Callable[
|
||||||
|
[Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]],
|
||||||
|
Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
|
||||||
|
]:
|
||||||
"""Decorator that handles task failures uniformly.
|
"""Decorator that handles task failures uniformly.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -170,9 +178,11 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
|
|||||||
set_error_status: Whether to set transcript status to 'error' on failure.
|
set_error_status: Whether to set transcript status to 'error' on failure.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(
|
||||||
|
func: Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
|
||||||
|
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(input: PipelineInput, ctx: Context):
|
async def wrapper(input: PipelineInput, ctx: Context) -> R:
|
||||||
try:
|
try:
|
||||||
return await func(input, ctx)
|
return await func(input, ctx)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -186,7 +196,7 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
|
|||||||
await set_workflow_error_status(input.transcript_id)
|
await set_workflow_error_status(input.transcript_id)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return wrapper
|
return wrapper # type: ignore[return-value]
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@@ -256,7 +266,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
|
|||||||
|
|
||||||
recording = ctx.task_output(get_recording)
|
recording = ctx.task_output(get_recording)
|
||||||
mtg_session_id = recording.mtg_session_id
|
mtg_session_id = recording.mtg_session_id
|
||||||
|
|
||||||
async with fresh_db_connection():
|
async with fresh_db_connection():
|
||||||
from reflector.db.transcripts import ( # noqa: PLC0415
|
from reflector.db.transcripts import ( # noqa: PLC0415
|
||||||
TranscriptParticipant,
|
TranscriptParticipant,
|
||||||
@@ -264,16 +273,17 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
|
|||||||
)
|
)
|
||||||
|
|
||||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||||
if transcript:
|
if not transcript:
|
||||||
# Note: title NOT cleared - preserves existing titles
|
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||||
await transcripts_controller.update(
|
# Note: title NOT cleared - preserves existing titles
|
||||||
transcript,
|
await transcripts_controller.update(
|
||||||
{
|
transcript,
|
||||||
"events": [],
|
{
|
||||||
"topics": [],
|
"events": [],
|
||||||
"participants": [],
|
"topics": [],
|
||||||
},
|
"participants": [],
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
|
||||||
mtg_session_id = assert_non_none_and_non_empty(
|
mtg_session_id = assert_non_none_and_non_empty(
|
||||||
mtg_session_id, "mtg_session_id is required"
|
mtg_session_id, "mtg_session_id is required"
|
||||||
@@ -640,6 +650,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
|||||||
|
|
||||||
async with fresh_db_connection():
|
async with fresh_db_connection():
|
||||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||||
|
if not transcript:
|
||||||
|
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||||
|
|
||||||
for chunk in topic_chunks:
|
for chunk in topic_chunks:
|
||||||
topic = TranscriptTopic(
|
topic = TranscriptTopic(
|
||||||
@@ -647,7 +659,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
|||||||
summary=chunk.summary,
|
summary=chunk.summary,
|
||||||
timestamp=chunk.timestamp,
|
timestamp=chunk.timestamp,
|
||||||
transcript=" ".join(w.text for w in chunk.words),
|
transcript=" ".join(w.text for w in chunk.words),
|
||||||
words=[w.model_dump() for w in chunk.words],
|
words=chunk.words,
|
||||||
)
|
)
|
||||||
await transcripts_controller.upsert_topic(transcript, topic)
|
await transcripts_controller.upsert_topic(transcript, topic)
|
||||||
await append_event_and_broadcast(
|
await append_event_and_broadcast(
|
||||||
@@ -697,6 +709,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
|
|||||||
async with fresh_db_connection():
|
async with fresh_db_connection():
|
||||||
ctx.log("generate_title: DB connection established")
|
ctx.log("generate_title: DB connection established")
|
||||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||||
|
if not transcript:
|
||||||
|
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||||
ctx.log(f"generate_title: fetched transcript, exists={transcript is not None}")
|
ctx.log(f"generate_title: fetched transcript, exists={transcript is not None}")
|
||||||
|
|
||||||
async def on_title_callback(data):
|
async def on_title_callback(data):
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class RecordingResult(BaseModel):
|
|||||||
|
|
||||||
id: NonEmptyString | None
|
id: NonEmptyString | None
|
||||||
mtg_session_id: NonEmptyString | None
|
mtg_session_id: NonEmptyString | None
|
||||||
duration: float
|
duration: int | None
|
||||||
|
|
||||||
|
|
||||||
class ParticipantsResult(BaseModel):
|
class ParticipantsResult(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user