From 2cbc373cc3cb767c3242b557c4c14c1a20f29c47 Mon Sep 17 00:00:00 2001
From: Igor Loskutov
Date: Mon, 22 Dec 2025 18:06:59 -0500
Subject: [PATCH] self-review

---
 .../hatchet/workflows/diarization_pipeline.py | 111 ++++++++----
 server/reflector/hatchet/workflows/models.py  |  23 ++--
 .../workflows/topic_chunk_processing.py       |   4 +-
 .../hatchet/workflows/track_processing.py     |  11 +-
 4 files changed, 73 insertions(+), 76 deletions(-)

diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py
index 912aa220..e26c4e0a 100644
--- a/server/reflector/hatchet/workflows/diarization_pipeline.py
+++ b/server/reflector/hatchet/workflows/diarization_pipeline.py
@@ -34,14 +34,19 @@ from reflector.hatchet.workflows.models import (
     FinalizeResult,
     MixdownResult,
     PaddedTrackInfo,
+    PadTrackResult,
+    ParticipantInfo,
     ParticipantsResult,
     ProcessSubjectsResult,
     ProcessTracksResult,
     RecapResult,
     RecordingResult,
     SubjectsResult,
+    SubjectSummaryResult,
     TitleResult,
+    TopicChunkResult,
     TopicsResult,
+    TranscribeTrackResult,
     WaveformResult,
     WebhookResult,
     ZulipResult,
@@ -58,13 +63,8 @@ from reflector.hatchet.workflows.track_processing import TrackInput, track_workf
 from reflector.logger import logger
 from reflector.pipelines import topic_processing
 from reflector.processors import AudioFileWriterProcessor
-from reflector.processors.types import (
-    TitleSummary,
-    Word,
-)
-from reflector.processors.types import (
-    Transcript as TranscriptType,
-)
+from reflector.processors.types import TitleSummary, Word
+from reflector.processors.types import Transcript as TranscriptType
 from reflector.settings import settings
 from reflector.storage.storage_aws import AwsStorage
 from reflector.utils.audio_constants import (
@@ -285,7 +285,7 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
     track_keys = [t["s3_key"] for t in input.tracks]
     cam_audio_keys = filter_cam_audio_tracks(track_keys)
 
-    participants_list = []
+    participants_list: list[ParticipantInfo] = []
     for idx, key in enumerate(cam_audio_keys):
         try:
             parsed = parse_daily_recording_filename(key)
@@ -307,11 +307,11 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
         )
         await transcripts_controller.upsert_participant(transcript, participant)
         participants_list.append(
-            {
-                "participant_id": participant_id,
-                "user_name": name,
-                "speaker": idx,
-            }
+            ParticipantInfo(
+                participant_id=participant_id,
+                user_name=name,
+                speaker=idx,
+            )
         )
 
     ctx.log(f"get_participants complete: {len(participants_list)} participants")
@@ -352,31 +352,30 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
 
     target_language = participants_result.target_language
 
-    track_words = []
+    track_words: list[list[Word]] = []
     padded_tracks = []
     created_padded_files = set()
 
     for result in results:
-        transcribe_result = result.get("transcribe_track", {})
-        track_words.append(transcribe_result.get("words", []))
+        transcribe_result = TranscribeTrackResult(**result["transcribe_track"])
+        track_words.append(transcribe_result.words)
 
-        pad_result = result.get("pad_track", {})
-        padded_key = pad_result.get("padded_key")
-        bucket_name = pad_result.get("bucket_name")
+        pad_result = PadTrackResult(**result["pad_track"])
 
         # Store S3 key info (not presigned URL) - consumer tasks presign on demand
-        if padded_key:
+        if pad_result.padded_key:
             padded_tracks.append(
-                PaddedTrackInfo(key=padded_key, bucket_name=bucket_name)
+                PaddedTrackInfo(
+                    key=pad_result.padded_key, bucket_name=pad_result.bucket_name
+                )
             )
 
-        track_index = pad_result.get("track_index")
-        if pad_result.get("size", 0) > 0 and track_index is not None:
-            storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
+        if pad_result.size > 0:
+            storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{pad_result.track_index}.webm"
             created_padded_files.add(storage_path)
 
     all_words = [word for words in track_words for word in words]
-    all_words.sort(key=lambda w: w.get("start", 0))
+    all_words.sort(key=lambda w: w.start)
 
     ctx.log(
         f"process_tracks complete: {len(all_words)} words from {len(input.tracks)} tracks"
@@ -569,9 +568,9 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
         first_word = chunk_words[0]
         last_word = chunk_words[-1]
 
-        timestamp = first_word.get("start", 0)
-        duration = last_word.get("end", 0) - timestamp
-        chunk_text = " ".join(w.get("word", "") for w in chunk_words)
+        timestamp = first_word.start
+        duration = last_word.end - timestamp
+        chunk_text = " ".join(w.text for w in chunk_words)
 
         chunks.append(
             {
@@ -604,40 +603,37 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
 
     results = await topic_chunk_workflow.aio_run_many(bulk_runs)
 
-    topic_results = [
-        result.get("detect_chunk_topic", {})
+    topic_chunks = [
+        TopicChunkResult(**result["detect_chunk_topic"])
         for result in results
-        if result.get("detect_chunk_topic")
+        if "detect_chunk_topic" in result
     ]
 
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
-        for topic_data in topic_results:
+        for chunk in topic_chunks:
             topic = TranscriptTopic(
-                title=topic_data.get("title", ""),
-                summary=topic_data.get("summary", ""),
-                timestamp=topic_data.get("timestamp", 0),
-                transcript=" ".join(
-                    w.get("word", "") for w in topic_data.get("words", [])
-                ),
-                words=topic_data.get("words", []),
+                title=chunk.title,
+                summary=chunk.summary,
+                timestamp=chunk.timestamp,
+                transcript=" ".join(w.text for w in chunk.words),
+                words=[w.model_dump() for w in chunk.words],
             )
             await transcripts_controller.upsert_topic(transcript, topic)
             await append_event_and_broadcast(
                 input.transcript_id, transcript, "TOPIC", topic, logger=logger
            )
 
-    # Convert to TitleSummary format for downstream steps
     topics_list = [
-        {
-            "title": t.get("title", ""),
-            "summary": t.get("summary", ""),
-            "timestamp": t.get("timestamp", 0),
-            "duration": t.get("duration", 0),
-            "transcript": {"words": t.get("words", [])},
-        }
-        for t in topic_results
+        TitleSummary(
+            title=chunk.title,
+            summary=chunk.summary,
+            timestamp=chunk.timestamp,
+            duration=chunk.duration,
+            transcript=TranscriptType(words=chunk.words),
+        )
+        for chunk in topic_chunks
     ]
 
     ctx.log(f"detect_topics complete: found {len(topics_list)} topics")
 
@@ -662,8 +658,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
         transcripts_controller,
     )
 
-    topic_objects = [TitleSummary(**t) for t in topics]
-    ctx.log(f"generate_title: created {len(topic_objects)} TitleSummary objects")
+    ctx.log(f"generate_title: received {len(topics)} TitleSummary objects")
 
     empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
     title_result = None
@@ -695,7 +690,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
 
     ctx.log("generate_title: calling topic_processing.generate_title (LLM call)...")
     await topic_processing.generate_title(
-        topic_objects,
+        topics,
         on_title_callback=on_title_callback,
         empty_pipeline=empty_pipeline,
         logger=logger,
@@ -735,8 +730,6 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
         SummaryBuilder,
     )
 
-    topic_objects = [TitleSummary(**t) for t in topics]
-
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
 
@@ -750,7 +743,7 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
         }
 
         text_lines = []
-        for topic in topic_objects:
+        for topic in topics:
             for segment in topic.transcript.as_segments():
                 name = speakermap.get(segment.speaker, f"Speaker {segment.speaker}")
                 text_lines.append(f"{name}: {segment.text}")
@@ -818,7 +811,9 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject
     results = await subject_workflow.aio_run_many(bulk_runs)
 
     subject_summaries = [
-        result.get("generate_detailed_summary", {}) for result in results
+        SubjectSummaryResult(**result["generate_detailed_summary"])
+        for result in results
+        if "generate_detailed_summary" in result
     ]
 
     ctx.log(f"process_subjects complete: {len(subject_summaries)} summaries")
@@ -858,7 +853,7 @@ async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
         return RecapResult(short_summary="", long_summary="")
 
     summaries = [
-        {"subject": s.get("subject", ""), "summary": s.get("paragraph_summary", "")}
+        {"subject": s.subject, "summary": s.paragraph_summary}
         for s in subject_summaries
     ]
 
@@ -963,7 +958,6 @@ async def identify_action_items(
 
     action_items_dict = action_items_response.model_dump()
 
-    # Save to database and broadcast
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
@@ -1035,8 +1029,7 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
     if transcript is None:
         raise ValueError(f"Transcript {input.transcript_id} not found in database")
 
-    word_objects = [Word(**w) for w in all_words]
-    merged_transcript = TranscriptType(words=word_objects, translation=None)
+    merged_transcript = TranscriptType(words=all_words, translation=None)
 
     await append_event_and_broadcast(
         input.transcript_id,
diff --git a/server/reflector/hatchet/workflows/models.py b/server/reflector/hatchet/workflows/models.py
index 4b9c5cec..748361d9 100644
--- a/server/reflector/hatchet/workflows/models.py
+++ b/server/reflector/hatchet/workflows/models.py
@@ -5,13 +5,20 @@ Provides static typing for all task outputs, enabling type checking
 and better IDE support.
""" -from typing import Any - from pydantic import BaseModel +from reflector.processors.types import TitleSummary, Word from reflector.utils.string import NonEmptyString +class ParticipantInfo(BaseModel): + """Participant info with speaker index for workflow result.""" + + participant_id: NonEmptyString + user_name: NonEmptyString + speaker: int + + class PadTrackResult(BaseModel): """Result from pad_track task.""" @@ -26,7 +33,7 @@ class PadTrackResult(BaseModel): class TranscribeTrackResult(BaseModel): """Result from transcribe_track task.""" - words: list[dict[str, Any]] + words: list[Word] track_index: int @@ -41,7 +48,7 @@ class RecordingResult(BaseModel): class ParticipantsResult(BaseModel): """Result from get_participants task.""" - participants: list[dict[str, Any]] + participants: list[ParticipantInfo] num_tracks: int source_language: NonEmptyString target_language: NonEmptyString @@ -57,7 +64,7 @@ class PaddedTrackInfo(BaseModel): class ProcessTracksResult(BaseModel): """Result from process_tracks task.""" - all_words: list[dict[str, Any]] + all_words: list[Word] padded_tracks: list[PaddedTrackInfo] # S3 keys, not presigned URLs word_count: int num_tracks: int @@ -87,13 +94,13 @@ class TopicChunkResult(BaseModel): summary: str timestamp: float duration: float - words: list[dict[str, Any]] + words: list[Word] class TopicsResult(BaseModel): """Result from detect_topics task.""" - topics: list[dict[str, Any]] + topics: list[TitleSummary] class TitleResult(BaseModel): @@ -123,7 +130,7 @@ class SubjectSummaryResult(BaseModel): class ProcessSubjectsResult(BaseModel): """Result from process_subjects fan-out task.""" - subject_summaries: list[dict[str, Any]] # List of SubjectSummaryResult dicts + subject_summaries: list[SubjectSummaryResult] class RecapResult(BaseModel): diff --git a/server/reflector/hatchet/workflows/topic_chunk_processing.py b/server/reflector/hatchet/workflows/topic_chunk_processing.py index 3b5af0fc..d2a74575 100644 --- a/server/reflector/hatchet/workflows/topic_chunk_processing.py +++ b/server/reflector/hatchet/workflows/topic_chunk_processing.py @@ -6,7 +6,6 @@ Spawned dynamically by detect_topics via aio_run_many() for parallel processing. 
""" from datetime import timedelta -from typing import Any from hatchet_sdk import Context from pydantic import BaseModel @@ -15,6 +14,7 @@ from reflector.hatchet.client import HatchetClientManager from reflector.hatchet.workflows.models import TopicChunkResult from reflector.logger import logger from reflector.processors.prompts import TOPIC_PROMPT +from reflector.processors.types import Word class TopicChunkInput(BaseModel): @@ -24,7 +24,7 @@ class TopicChunkInput(BaseModel): chunk_text: str timestamp: float duration: float - words: list[dict[str, Any]] + words: list[Word] hatchet = HatchetClientManager.get_client() diff --git a/server/reflector/hatchet/workflows/track_processing.py b/server/reflector/hatchet/workflows/track_processing.py index 6aae6eb2..b2cc452e 100644 --- a/server/reflector/hatchet/workflows/track_processing.py +++ b/server/reflector/hatchet/workflows/track_processing.py @@ -197,23 +197,20 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe transcript = await transcribe_file_with_processor(audio_url, input.language) # Tag all words with speaker index - words = [] for word in transcript.words: - word_dict = word.model_dump() - word_dict["speaker"] = input.track_index - words.append(word_dict) + word.speaker = input.track_index ctx.log( - f"transcribe_track complete: track {input.track_index}, {len(words)} words" + f"transcribe_track complete: track {input.track_index}, {len(transcript.words)} words" ) logger.info( "[Hatchet] transcribe_track complete", track_index=input.track_index, - word_count=len(words), + word_count=len(transcript.words), ) return TranscribeTrackResult( - words=words, + words=transcript.words, track_index=input.track_index, )