self-review round

Igor Loskutov
2025-12-18 12:46:05 -05:00
parent 61e2b3211e
commit acad80df50
4 changed files with 48 additions and 66 deletions

View File

@@ -1,17 +0,0 @@
-"""
-Hatchet workflow utilities.
-
-Shared helpers for Hatchet task implementations.
-"""
-
-
-def to_dict(output) -> dict:
-    """Convert task output to dict, handling both dict and Pydantic model returns.
-
-    Hatchet SDK can return task outputs as either raw dicts or Pydantic models
-    depending on serialization context. This normalizes the output for consistent
-    downstream processing.
-    """
-    if isinstance(output, dict):
-        return output
-    return output.model_dump()
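With task outputs now consumed as typed Pydantic result models, the dict-normalization helper above becomes dead code. A minimal sketch of the access pattern the rest of this commit switches to, using an illustrative model name (the real result models live in reflector.hatchet.workflows.models, and ctx.task_output(...) is assumed to return such a model rather than a raw dict):

from pydantic import BaseModel


class ExampleResult(BaseModel):
    # Illustrative stand-in for a task result model such as MixdownResult.
    audio_key: str | None = None
    duration: float = 0.0


def consume(output: ExampleResult) -> str | None:
    # Before: to_dict(output).get("audio_key") tolerated both dicts and models.
    # After: attribute access on the typed model; a misspelled field name now
    # fails loudly instead of silently returning None.
    return output.audio_key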

View File

@@ -27,7 +27,6 @@ from reflector.hatchet.broadcast import (
     set_status_and_broadcast,
 )
 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.utils import to_dict
 from reflector.hatchet.workflows.models import (
     ConsentResult,
     FinalizeResult,
@@ -69,7 +68,7 @@ from reflector.utils.daily import (
     filter_cam_audio_tracks,
     parse_daily_recording_filename,
 )
-from reflector.utils.string import NonEmptyString
+from reflector.utils.string import NonEmptyString, assert_non_none_and_non_empty
 from reflector.zulip import post_transcript_notification
@@ -212,8 +211,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
     """Fetch participant list from Daily.co API and update transcript in database."""
     ctx.log(f"get_participants: transcript_id={input.transcript_id}")

-    recording_data = to_dict(ctx.task_output(get_recording))
-    mtg_session_id = recording_data.get("mtg_session_id")
+    recording = ctx.task_output(get_recording)
+    mtg_session_id = recording.mtg_session_id

     async with fresh_db_connection():
         from reflector.db.transcripts import (  # noqa: PLC0415
@@ -233,15 +232,14 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
         },
     )

-    if not mtg_session_id or not settings.DAILY_API_KEY:
-        return ParticipantsResult(
-            participants=[],
-            num_tracks=len(input.tracks),
-            source_language=transcript.source_language if transcript else "en",
-            target_language=transcript.target_language if transcript else "en",
-        )
+    mtg_session_id = assert_non_none_and_non_empty(
+        mtg_session_id, "mtg_session_id is required"
+    )
+    daily_api_key = assert_non_none_and_non_empty(
+        settings.DAILY_API_KEY, "DAILY_API_KEY is required"
+    )

-    async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
+    async with DailyApiClient(api_key=daily_api_key) as client:
         participants = await client.get_meeting_participants(mtg_session_id)

         id_to_name = {}
@@ -302,8 +300,8 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
     """Spawn child workflows for each track (dynamic fan-out)."""
     ctx.log(f"process_tracks: spawning {len(input.tracks)} track workflows")

-    participants_data = to_dict(ctx.task_output(get_participants))
-    source_language = participants_data.get("source_language", "en")
+    participants_result = ctx.task_output(get_participants)
+    source_language = participants_result.source_language

     child_coroutines = [
         track_workflow.aio_run(
@@ -320,7 +318,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
     results = await asyncio.gather(*child_coroutines)

-    target_language = participants_data.get("target_language", "en")
+    target_language = participants_result.target_language

     # Collect results from each track (don't mutate lists while iterating)
     track_words = []
@@ -371,31 +369,23 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     """Mix all padded tracks into single audio file using PyAV (same as Celery)."""
     ctx.log("mixdown_tracks: mixing padded tracks into single audio file")

-    track_data = to_dict(ctx.task_output(process_tracks))
-    padded_tracks_data = track_data.get("padded_tracks", [])
-    if not padded_tracks_data:
+    track_result = ctx.task_output(process_tracks)
+    padded_tracks = track_result.padded_tracks
+    if not padded_tracks:
         raise ValueError("No padded tracks to mixdown")

     storage = _spawn_storage()

     # Presign URLs on demand (avoids stale URLs on workflow replay)
     padded_urls = []
-    for track_info in padded_tracks_data:
-        # Handle both dict (from to_dict) and PaddedTrackInfo
-        if isinstance(track_info, dict):
-            key = track_info.get("key")
-            bucket = track_info.get("bucket_name")
-        else:
-            key = track_info.key
-            bucket = track_info.bucket_name
-
-        if key:
+    for track_info in padded_tracks:
+        if track_info.key:
             url = await storage.get_file_url(
-                key,
+                track_info.key,
                 operation="get_object",
                 expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
-                bucket=bucket,
+                bucket=track_info.bucket_name,
             )
             padded_urls.append(url)
@@ -465,8 +455,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
     )

     # Cleanup temporary padded S3 files (deferred until after mixdown)
-    track_data = to_dict(ctx.task_output(process_tracks))
-    created_padded_files = track_data.get("created_padded_files", [])
+    track_result = ctx.task_output(process_tracks)
+    created_padded_files = track_result.created_padded_files
     if created_padded_files:
         ctx.log(f"Cleaning up {len(created_padded_files)} temporary S3 files")
         storage = _spawn_storage()
@@ -483,8 +473,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
                     error=str(result),
                 )

-    mixdown_data = to_dict(ctx.task_output(mixdown_tracks))
-    audio_key = mixdown_data.get("audio_key")
+    mixdown_result = ctx.task_output(mixdown_tracks)
+    audio_key = mixdown_result.audio_key

     storage = _spawn_storage()
     audio_url = await storage.get_file_url(
@@ -532,9 +522,9 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     """Detect topics using LLM and save to database (matches Celery on_topic callback)."""
    ctx.log("detect_topics: analyzing transcript for topics")

-    track_data = to_dict(ctx.task_output(process_tracks))
-    words = track_data.get("all_words", [])
-    target_language = track_data.get("target_language", "en")
+    track_result = ctx.task_output(process_tracks)
+    words = track_result.all_words
+    target_language = track_result.target_language

     from reflector.db.transcripts import (  # noqa: PLC0415
         TranscriptTopic,
@@ -589,8 +579,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     """Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
     ctx.log("generate_title: generating title from topics")

-    topics_data = to_dict(ctx.task_output(detect_topics))
-    topics = topics_data.get("topics", [])
+    topics_result = ctx.task_output(detect_topics)
+    topics = topics_result.topics

     from reflector.db.transcripts import (  # noqa: PLC0415
         TranscriptFinalTitle,
@@ -638,8 +628,8 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     """Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
     ctx.log("generate_summary: generating long and short summaries")

-    topics_data = to_dict(ctx.task_output(detect_topics))
-    topics = topics_data.get("topics", [])
+    topics_result = ctx.task_output(detect_topics)
+    topics = topics_result.topics

     from reflector.db.transcripts import (  # noqa: PLC0415
         TranscriptFinalLongSummary,
@@ -718,11 +708,11 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
     """
     ctx.log("finalize: saving transcript and setting status to 'ended'")

-    mixdown_data = to_dict(ctx.task_output(mixdown_tracks))
-    track_data = to_dict(ctx.task_output(process_tracks))
-    duration = mixdown_data.get("duration", 0)
-    all_words = track_data.get("all_words", [])
+    mixdown_result = ctx.task_output(mixdown_tracks)
+    track_result = ctx.task_output(process_tracks)
+    duration = mixdown_result.duration
+    all_words = track_result.all_words

     async with fresh_db_connection():
         from reflector.db.transcripts import (  # noqa: PLC0415
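The attribute access above implies result models roughly shaped like the following. This is a sketch inferred from the fields this file reads; the actual definitions live in reflector.hatchet.workflows.models and may differ in names, types, and defaults.

from pydantic import BaseModel, Field


class PaddedTrackInfo(BaseModel):
    # Assumed fields, based on track_info.key / track_info.bucket_name above.
    key: str | None = None
    bucket_name: str | None = None


class ProcessTracksResult(BaseModel):
    # Assumed fields, based on track_result.* access in mixdown_tracks,
    # generate_waveform, detect_topics, and finalize.
    padded_tracks: list[PaddedTrackInfo] = Field(default_factory=list)
    created_padded_files: list[str] = Field(default_factory=list)
    all_words: list[dict] = Field(default_factory=list)
    target_language: str = "en"


class MixdownResult(BaseModel):
    # Assumed fields, based on mixdown_result.audio_key / .duration above.
    audio_key: str | None = None
    duration: float = 0.0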

View File

@@ -23,7 +23,6 @@ from hatchet_sdk import Context
 from pydantic import BaseModel

 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.utils import to_dict
 from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
 from reflector.logger import logger
 from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
@@ -166,9 +165,9 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe
     )

     try:
-        pad_result = to_dict(ctx.task_output(pad_track))
-        padded_key = pad_result.get("padded_key")
-        bucket_name = pad_result.get("bucket_name")
+        pad_result = ctx.task_output(pad_track)
+        padded_key = pad_result.padded_key
+        bucket_name = pad_result.bucket_name

         if not padded_key:
             raise ValueError("Missing padded_key from pad_track")

View File

@@ -2,6 +2,8 @@ from typing import Annotated, TypeVar
 from pydantic import Field, TypeAdapter, constr

+from reflector.utils.common import assert_not_none
+
 NonEmptyStringBase = constr(min_length=1, strip_whitespace=False)

 NonEmptyString = Annotated[
     NonEmptyStringBase,
@@ -30,3 +32,11 @@ def assert_equal[T](s1: T, s2: T) -> T:
     if s1 != s2:
         raise ValueError(f"assert_equal: {s1} != {s2}")
     return s1
+
+
+def assert_non_none_and_non_empty(
+    value: str | None, error: str | None = None
+) -> NonEmptyString:
+    return parse_non_empty_string(
+        assert_not_none(value, error or "Value is None"), error
+    )
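The new helper chains assert_not_none with parse_non_empty_string, so a caller either gets a NonEmptyString back or an error at the call site. A usage sketch mirroring the get_participants change earlier in this commit (assuming both underlying helpers raise on None or empty input, as their names suggest):

# Hypothetical call site; `recording` stands in for a typed task output.
mtg_session_id = assert_non_none_and_non_empty(
    recording.mtg_session_id, "mtg_session_id is required"
)
# From here on, mtg_session_id is a NonEmptyString, so downstream code can drop
# its own None/empty checks (the early-return fallback removed in get_participants).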