From acad80df50dc64dadfb4a0e5292f8f5cbf47a309 Mon Sep 17 00:00:00 2001 From: Igor Loskutov Date: Thu, 18 Dec 2025 12:46:05 -0500 Subject: [PATCH] self-review round --- server/reflector/hatchet/utils.py | 17 ---- .../hatchet/workflows/diarization_pipeline.py | 80 ++++++++----------- .../hatchet/workflows/track_processing.py | 7 +- server/reflector/utils/string.py | 10 +++ 4 files changed, 48 insertions(+), 66 deletions(-) delete mode 100644 server/reflector/hatchet/utils.py diff --git a/server/reflector/hatchet/utils.py b/server/reflector/hatchet/utils.py deleted file mode 100644 index d98bdf59..00000000 --- a/server/reflector/hatchet/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Hatchet workflow utilities. - -Shared helpers for Hatchet task implementations. -""" - - -def to_dict(output) -> dict: - """Convert task output to dict, handling both dict and Pydantic model returns. - - Hatchet SDK can return task outputs as either raw dicts or Pydantic models - depending on serialization context. This normalizes the output for consistent - downstream processing. - """ - if isinstance(output, dict): - return output - return output.model_dump() diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py index c9cf797f..02668497 100644 --- a/server/reflector/hatchet/workflows/diarization_pipeline.py +++ b/server/reflector/hatchet/workflows/diarization_pipeline.py @@ -27,7 +27,6 @@ from reflector.hatchet.broadcast import ( set_status_and_broadcast, ) from reflector.hatchet.client import HatchetClientManager -from reflector.hatchet.utils import to_dict from reflector.hatchet.workflows.models import ( ConsentResult, FinalizeResult, @@ -69,7 +68,7 @@ from reflector.utils.daily import ( filter_cam_audio_tracks, parse_daily_recording_filename, ) -from reflector.utils.string import NonEmptyString +from reflector.utils.string import NonEmptyString, assert_non_none_and_non_empty from reflector.zulip import post_transcript_notification @@ -212,8 +211,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe """Fetch participant list from Daily.co API and update transcript in database.""" ctx.log(f"get_participants: transcript_id={input.transcript_id}") - recording_data = to_dict(ctx.task_output(get_recording)) - mtg_session_id = recording_data.get("mtg_session_id") + recording = ctx.task_output(get_recording) + mtg_session_id = recording.mtg_session_id async with fresh_db_connection(): from reflector.db.transcripts import ( # noqa: PLC0415 @@ -233,15 +232,14 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe }, ) - if not mtg_session_id or not settings.DAILY_API_KEY: - return ParticipantsResult( - participants=[], - num_tracks=len(input.tracks), - source_language=transcript.source_language if transcript else "en", - target_language=transcript.target_language if transcript else "en", - ) + mtg_session_id = assert_non_none_and_non_empty( + mtg_session_id, "mtg_session_id is required" + ) + daily_api_key = assert_non_none_and_non_empty( + settings.DAILY_API_KEY, "DAILY_API_KEY is required" + ) - async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client: + async with DailyApiClient(api_key=daily_api_key) as client: participants = await client.get_meeting_participants(mtg_session_id) id_to_name = {} @@ -302,8 +300,8 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes """Spawn child workflows for each track (dynamic fan-out).""" 
ctx.log(f"process_tracks: spawning {len(input.tracks)} track workflows") - participants_data = to_dict(ctx.task_output(get_participants)) - source_language = participants_data.get("source_language", "en") + participants_result = ctx.task_output(get_participants) + source_language = participants_result.source_language child_coroutines = [ track_workflow.aio_run( @@ -320,7 +318,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes results = await asyncio.gather(*child_coroutines) - target_language = participants_data.get("target_language", "en") + target_language = participants_result.target_language # Collect results from each track (don't mutate lists while iterating) track_words = [] @@ -371,31 +369,23 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: """Mix all padded tracks into single audio file using PyAV (same as Celery).""" ctx.log("mixdown_tracks: mixing padded tracks into single audio file") - track_data = to_dict(ctx.task_output(process_tracks)) - padded_tracks_data = track_data.get("padded_tracks", []) + track_result = ctx.task_output(process_tracks) + padded_tracks = track_result.padded_tracks - if not padded_tracks_data: + if not padded_tracks: raise ValueError("No padded tracks to mixdown") storage = _spawn_storage() # Presign URLs on demand (avoids stale URLs on workflow replay) padded_urls = [] - for track_info in padded_tracks_data: - # Handle both dict (from to_dict) and PaddedTrackInfo - if isinstance(track_info, dict): - key = track_info.get("key") - bucket = track_info.get("bucket_name") - else: - key = track_info.key - bucket = track_info.bucket_name - - if key: + for track_info in padded_tracks: + if track_info.key: url = await storage.get_file_url( - key, + track_info.key, operation="get_object", expires_in=PRESIGNED_URL_EXPIRATION_SECONDS, - bucket=bucket, + bucket=track_info.bucket_name, ) padded_urls.append(url) @@ -465,8 +455,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul ) # Cleanup temporary padded S3 files (deferred until after mixdown) - track_data = to_dict(ctx.task_output(process_tracks)) - created_padded_files = track_data.get("created_padded_files", []) + track_result = ctx.task_output(process_tracks) + created_padded_files = track_result.created_padded_files if created_padded_files: ctx.log(f"Cleaning up {len(created_padded_files)} temporary S3 files") storage = _spawn_storage() @@ -483,8 +473,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul error=str(result), ) - mixdown_data = to_dict(ctx.task_output(mixdown_tracks)) - audio_key = mixdown_data.get("audio_key") + mixdown_result = ctx.task_output(mixdown_tracks) + audio_key = mixdown_result.audio_key storage = _spawn_storage() audio_url = await storage.get_file_url( @@ -532,9 +522,9 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: """Detect topics using LLM and save to database (matches Celery on_topic callback).""" ctx.log("detect_topics: analyzing transcript for topics") - track_data = to_dict(ctx.task_output(process_tracks)) - words = track_data.get("all_words", []) - target_language = track_data.get("target_language", "en") + track_result = ctx.task_output(process_tracks) + words = track_result.all_words + target_language = track_result.target_language from reflector.db.transcripts import ( # noqa: PLC0415 TranscriptTopic, @@ -589,8 +579,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult: """Generate 
meeting title using LLM and save to database (matches Celery on_title callback).""" ctx.log("generate_title: generating title from topics") - topics_data = to_dict(ctx.task_output(detect_topics)) - topics = topics_data.get("topics", []) + topics_result = ctx.task_output(detect_topics) + topics = topics_result.topics from reflector.db.transcripts import ( # noqa: PLC0415 TranscriptFinalTitle, @@ -638,8 +628,8 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult: """Generate meeting summary using LLM and save to database (matches Celery callbacks).""" ctx.log("generate_summary: generating long and short summaries") - topics_data = to_dict(ctx.task_output(detect_topics)) - topics = topics_data.get("topics", []) + topics_result = ctx.task_output(detect_topics) + topics = topics_result.topics from reflector.db.transcripts import ( # noqa: PLC0415 TranscriptFinalLongSummary, @@ -718,11 +708,11 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: """ ctx.log("finalize: saving transcript and setting status to 'ended'") - mixdown_data = to_dict(ctx.task_output(mixdown_tracks)) - track_data = to_dict(ctx.task_output(process_tracks)) + mixdown_result = ctx.task_output(mixdown_tracks) + track_result = ctx.task_output(process_tracks) - duration = mixdown_data.get("duration", 0) - all_words = track_data.get("all_words", []) + duration = mixdown_result.duration + all_words = track_result.all_words async with fresh_db_connection(): from reflector.db.transcripts import ( # noqa: PLC0415 diff --git a/server/reflector/hatchet/workflows/track_processing.py b/server/reflector/hatchet/workflows/track_processing.py index 45873daf..6aae6eb2 100644 --- a/server/reflector/hatchet/workflows/track_processing.py +++ b/server/reflector/hatchet/workflows/track_processing.py @@ -23,7 +23,6 @@ from hatchet_sdk import Context from pydantic import BaseModel from reflector.hatchet.client import HatchetClientManager -from reflector.hatchet.utils import to_dict from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult from reflector.logger import logger from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS @@ -166,9 +165,9 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe ) try: - pad_result = to_dict(ctx.task_output(pad_track)) - padded_key = pad_result.get("padded_key") - bucket_name = pad_result.get("bucket_name") + pad_result = ctx.task_output(pad_track) + padded_key = pad_result.padded_key + bucket_name = pad_result.bucket_name if not padded_key: raise ValueError("Missing padded_key from pad_track") diff --git a/server/reflector/utils/string.py b/server/reflector/utils/string.py index ae4277c5..bfbc8c44 100644 --- a/server/reflector/utils/string.py +++ b/server/reflector/utils/string.py @@ -2,6 +2,8 @@ from typing import Annotated, TypeVar from pydantic import Field, TypeAdapter, constr +from reflector.utils.common import assert_not_none + NonEmptyStringBase = constr(min_length=1, strip_whitespace=False) NonEmptyString = Annotated[ NonEmptyStringBase, @@ -30,3 +32,11 @@ def assert_equal[T](s1: T, s2: T) -> T: if s1 != s2: raise ValueError(f"assert_equal: {s1} != {s2}") return s1 + + +def assert_non_none_and_non_empty( + value: str | None, error: str | None = None +) -> NonEmptyString: + return parse_non_empty_string( + assert_not_none(value, error or "Value is None"), error + )
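
Review note: a minimal usage sketch of the new assert_non_none_and_non_empty
helper, assuming assert_not_none (imported from reflector.utils.common above)
raises when given None and that parse_non_empty_string rejects empty strings;
both are defined outside the hunks shown here. The session_id value below is
hypothetical.

    from reflector.utils.string import assert_non_none_and_non_empty

    # Replaces the old "if not mtg_session_id or not settings.DAILY_API_KEY:
    # return ParticipantsResult(...)" fallback in get_participants with a
    # fail-fast check that narrows str | None to NonEmptyString.
    session_id: str | None = "abc123"  # hypothetical value
    session_id = assert_non_none_and_non_empty(
        session_id, "mtg_session_id is required"
    )  # returns "abc123"; raises if session_id were None or ""

Same fail-fast idea as the typed-output change above: task results are read as
Pydantic models (ctx.task_output(...).field) instead of being normalized
through the deleted to_dict helper, so a missing field surfaces as an error
rather than a silent .get() default.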