self-review round

Igor Loskutov
2025-12-18 12:46:05 -05:00
parent 61e2b3211e
commit acad80df50
4 changed files with 48 additions and 66 deletions

View File

@@ -1,17 +0,0 @@
"""
Hatchet workflow utilities.
Shared helpers for Hatchet task implementations.
"""
def to_dict(output) -> dict:
"""Convert task output to dict, handling both dict and Pydantic model returns.
Hatchet SDK can return task outputs as either raw dicts or Pydantic models
depending on serialization context. This normalizes the output for consistent
downstream processing.
"""
if isinstance(output, dict):
return output
return output.model_dump()
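Note (illustration, not part of the commit): a minimal standalone sketch of why this helper can be dropped once every task output is a typed Pydantic model — attribute access replaces the dict/model normalization. RecordingResult is a hypothetical stand-in for the pipeline's real *Result models.

from pydantic import BaseModel


class RecordingResult(BaseModel):
    # Hypothetical stand-in for the diff's *Result models.
    mtg_session_id: str | None = None


def old_style(output: RecordingResult | dict) -> str | None:
    # Pre-commit: normalize to dict, then look fields up with fallbacks.
    data = output if isinstance(output, dict) else output.model_dump()
    return data.get("mtg_session_id")


def new_style(output: RecordingResult) -> str | None:
    # Post-commit: trust the typed model; a typo raises AttributeError
    # instead of silently returning None.
    return output.mtg_session_id


r = RecordingResult(mtg_session_id="abc")
assert old_style(r) == new_style(r) == "abc"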

View File

@@ -27,7 +27,6 @@ from reflector.hatchet.broadcast import (
     set_status_and_broadcast,
 )
 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.utils import to_dict
 from reflector.hatchet.workflows.models import (
     ConsentResult,
     FinalizeResult,
@@ -69,7 +68,7 @@ from reflector.utils.daily import (
     filter_cam_audio_tracks,
     parse_daily_recording_filename,
 )
-from reflector.utils.string import NonEmptyString
+from reflector.utils.string import NonEmptyString, assert_non_none_and_non_empty
 from reflector.zulip import post_transcript_notification
@@ -212,8 +211,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
     """Fetch participant list from Daily.co API and update transcript in database."""
     ctx.log(f"get_participants: transcript_id={input.transcript_id}")
 
-    recording_data = to_dict(ctx.task_output(get_recording))
-    mtg_session_id = recording_data.get("mtg_session_id")
+    recording = ctx.task_output(get_recording)
+    mtg_session_id = recording.mtg_session_id
 
     async with fresh_db_connection():
         from reflector.db.transcripts import (  # noqa: PLC0415
@@ -233,15 +232,14 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
             },
         )
 
-    if not mtg_session_id or not settings.DAILY_API_KEY:
-        return ParticipantsResult(
-            participants=[],
-            num_tracks=len(input.tracks),
-            source_language=transcript.source_language if transcript else "en",
-            target_language=transcript.target_language if transcript else "en",
-        )
+    mtg_session_id = assert_non_none_and_non_empty(
+        mtg_session_id, "mtg_session_id is required"
+    )
+    daily_api_key = assert_non_none_and_non_empty(
+        settings.DAILY_API_KEY, "DAILY_API_KEY is required"
+    )
 
-    async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
+    async with DailyApiClient(api_key=daily_api_key) as client:
         participants = await client.get_meeting_participants(mtg_session_id)
 
     id_to_name = {}
@@ -302,8 +300,8 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
     """Spawn child workflows for each track (dynamic fan-out)."""
     ctx.log(f"process_tracks: spawning {len(input.tracks)} track workflows")
 
-    participants_data = to_dict(ctx.task_output(get_participants))
-    source_language = participants_data.get("source_language", "en")
+    participants_result = ctx.task_output(get_participants)
+    source_language = participants_result.source_language
 
     child_coroutines = [
         track_workflow.aio_run(
@@ -320,7 +318,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
     results = await asyncio.gather(*child_coroutines)
 
-    target_language = participants_data.get("target_language", "en")
+    target_language = participants_result.target_language
 
     # Collect results from each track (don't mutate lists while iterating)
     track_words = []
@@ -371,31 +369,23 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
"""Mix all padded tracks into single audio file using PyAV (same as Celery)."""
ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
track_data = to_dict(ctx.task_output(process_tracks))
padded_tracks_data = track_data.get("padded_tracks", [])
track_result = ctx.task_output(process_tracks)
padded_tracks = track_result.padded_tracks
if not padded_tracks_data:
if not padded_tracks:
raise ValueError("No padded tracks to mixdown")
storage = _spawn_storage()
# Presign URLs on demand (avoids stale URLs on workflow replay)
padded_urls = []
for track_info in padded_tracks_data:
# Handle both dict (from to_dict) and PaddedTrackInfo
if isinstance(track_info, dict):
key = track_info.get("key")
bucket = track_info.get("bucket_name")
else:
key = track_info.key
bucket = track_info.bucket_name
if key:
for track_info in padded_tracks:
if track_info.key:
url = await storage.get_file_url(
key,
track_info.key,
operation="get_object",
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
bucket=bucket,
bucket=track_info.bucket_name,
)
padded_urls.append(url)
@@ -465,8 +455,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
     )
 
     # Cleanup temporary padded S3 files (deferred until after mixdown)
-    track_data = to_dict(ctx.task_output(process_tracks))
-    created_padded_files = track_data.get("created_padded_files", [])
+    track_result = ctx.task_output(process_tracks)
+    created_padded_files = track_result.created_padded_files
     if created_padded_files:
         ctx.log(f"Cleaning up {len(created_padded_files)} temporary S3 files")
         storage = _spawn_storage()
@@ -483,8 +473,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
                 error=str(result),
             )
 
-    mixdown_data = to_dict(ctx.task_output(mixdown_tracks))
-    audio_key = mixdown_data.get("audio_key")
+    mixdown_result = ctx.task_output(mixdown_tracks)
+    audio_key = mixdown_result.audio_key
 
     storage = _spawn_storage()
     audio_url = await storage.get_file_url(
@@ -532,9 +522,9 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
"""Detect topics using LLM and save to database (matches Celery on_topic callback)."""
ctx.log("detect_topics: analyzing transcript for topics")
track_data = to_dict(ctx.task_output(process_tracks))
words = track_data.get("all_words", [])
target_language = track_data.get("target_language", "en")
track_result = ctx.task_output(process_tracks)
words = track_result.all_words
target_language = track_result.target_language
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptTopic,
@@ -589,8 +579,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
"""Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
ctx.log("generate_title: generating title from topics")
topics_data = to_dict(ctx.task_output(detect_topics))
topics = topics_data.get("topics", [])
topics_result = ctx.task_output(detect_topics)
topics = topics_result.topics
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptFinalTitle,
@@ -638,8 +628,8 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
"""Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
ctx.log("generate_summary: generating long and short summaries")
topics_data = to_dict(ctx.task_output(detect_topics))
topics = topics_data.get("topics", [])
topics_result = ctx.task_output(detect_topics)
topics = topics_result.topics
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptFinalLongSummary,
@@ -718,11 +708,11 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
"""
ctx.log("finalize: saving transcript and setting status to 'ended'")
mixdown_data = to_dict(ctx.task_output(mixdown_tracks))
track_data = to_dict(ctx.task_output(process_tracks))
mixdown_result = ctx.task_output(mixdown_tracks)
track_result = ctx.task_output(process_tracks)
duration = mixdown_data.get("duration", 0)
all_words = track_data.get("all_words", [])
duration = mixdown_result.duration
all_words = track_result.all_words
async with fresh_db_connection():
from reflector.db.transcripts import ( # noqa: PLC0415
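Aside (illustration, not part of the commit): the process_tracks hunks above keep the dynamic fan-out shape — one child workflow per track, awaited together. A self-contained asyncio sketch of that fan-out/fan-in shape, with a plain coroutine standing in for track_workflow.aio_run (hypothetical stand-in, not the Hatchet API):

import asyncio


async def run_track(index: int) -> dict:
    # Stand-in for track_workflow.aio_run(...): one child run per track.
    await asyncio.sleep(0)
    return {"track": index, "words": []}


async def process_tracks(num_tracks: int) -> list[dict]:
    # Fan out one coroutine per track, then fan in with gather().
    coroutines = [run_track(i) for i in range(num_tracks)]
    return await asyncio.gather(*coroutines)


assert len(asyncio.run(process_tracks(3))) == 3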

View File

@@ -23,7 +23,6 @@ from hatchet_sdk import Context
 from pydantic import BaseModel
 
 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.utils import to_dict
 from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
 from reflector.logger import logger
 from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
@@ -166,9 +165,9 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
     )
 
     try:
-        pad_result = to_dict(ctx.task_output(pad_track))
-        padded_key = pad_result.get("padded_key")
-        bucket_name = pad_result.get("bucket_name")
+        pad_result = ctx.task_output(pad_track)
+        padded_key = pad_result.padded_key
+        bucket_name = pad_result.bucket_name
 
         if not padded_key:
             raise ValueError("Missing padded_key from pad_track")

View File

@@ -2,6 +2,8 @@ from typing import Annotated, TypeVar
 from pydantic import Field, TypeAdapter, constr
 
+from reflector.utils.common import assert_not_none
+
 NonEmptyStringBase = constr(min_length=1, strip_whitespace=False)
 
 NonEmptyString = Annotated[
     NonEmptyStringBase,
@@ -30,3 +32,11 @@ def assert_equal[T](s1: T, s2: T) -> T:
     if s1 != s2:
         raise ValueError(f"assert_equal: {s1} != {s2}")
     return s1
+
+
+def assert_non_none_and_non_empty(
+    value: str | None, error: str | None = None
+) -> NonEmptyString:
+    return parse_non_empty_string(
+        assert_not_none(value, error or "Value is None"), error
+    )
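Usage note: a self-contained sketch (simplified hypothetical re-implementation, not the actual code in reflector.utils.string / reflector.utils.common) of how the new helper narrows str | None down to a validated non-empty string, raising with the caller-supplied message otherwise:

def assert_not_none(value: str | None, error: str = "Value is None") -> str:
    # Mirrors reflector.utils.common.assert_not_none for illustration.
    if value is None:
        raise ValueError(error)
    return value


def parse_non_empty_string(value: str, error: str | None = None) -> str:
    # Simplified stand-in for the parser in reflector.utils.string.
    if len(value) < 1:
        raise ValueError(error or "String must be non-empty")
    return value


def assert_non_none_and_non_empty(value: str | None, error: str | None = None) -> str:
    return parse_non_empty_string(assert_not_none(value, error or "Value is None"), error)


assert assert_non_none_and_non_empty("room-1") == "room-1"
for bad in (None, ""):
    try:
        assert_non_none_and_non_empty(bad, "DAILY_API_KEY is required")
    except ValueError as exc:
        print(exc)  # prints "DAILY_API_KEY is required" for both None and ""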