feat: parallelize hatchet (#804)

* parallelize hatchet (no-mistakes)

* dry (no-mistakes) (minimal)

* comments

* self-review

* self-review

* self-review

* self-review

* pr comments

* pr comments

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
committed by GitHub on 2025-12-23 11:03:36 -05:00
parent 7c2d0698ed
commit 594bcc09e0
15 changed files with 849 additions and 287 deletions
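
For reviewers, the core mechanical change: per-child `aio_run()` coroutines gathered with `asyncio.gather` are replaced by Hatchet's bulk-run API. A minimal sketch of the pattern, assuming the Hatchet v1 Python SDK (only the `TrackInput` fields visible in this diff are shown):

```python
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow


async def fan_out(tracks: list[dict], bucket_name: str) -> list[dict]:
    # Before: one coroutine per child workflow, awaited manually.
    # child_coroutines = [track_workflow.aio_run(TrackInput(...)) for ...]
    # results = await asyncio.gather(*child_coroutines)

    # After: one bulk submission; the Hatchet engine schedules the
    # children in parallel and returns one output dict per child run.
    bulk_runs = [
        track_workflow.create_bulk_run_item(
            input=TrackInput(
                track_index=i,
                s3_key=track["s3_key"],
                bucket_name=bucket_name,
            )
        )
        for i, track in enumerate(tracks)
    ]
    return await track_workflow.aio_run_many(bulk_runs)
```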

View File

@@ -4,11 +4,23 @@ from reflector.hatchet.workflows.diarization_pipeline import (
PipelineInput,
diarization_pipeline,
)
from reflector.hatchet.workflows.subject_processing import (
SubjectInput,
subject_workflow,
)
from reflector.hatchet.workflows.topic_chunk_processing import (
TopicChunkInput,
topic_chunk_workflow,
)
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
__all__ = [
"diarization_pipeline",
"subject_workflow",
"topic_chunk_workflow",
"track_workflow",
"PipelineInput",
"SubjectInput",
"TopicChunkInput",
"TrackInput",
]
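
Exporting all four workflows matters because each must be registered on the worker that executes the pipeline. A hypothetical bootstrap, assuming the standard `hatchet.worker(...)` registration from the Hatchet SDK docs (worker name and entrypoint are illustrative, not from this PR):

```python
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.workflows import (
    diarization_pipeline,
    subject_workflow,
    topic_chunk_workflow,
    track_workflow,
)

hatchet = HatchetClientManager.get_client()


def main() -> None:
    # Parent pipeline plus all dynamically spawned child workflows.
    worker = hatchet.worker(
        "reflector-pipeline-worker",
        workflows=[
            diarization_pipeline,
            track_workflow,
            subject_workflow,
            topic_chunk_workflow,
        ],
    )
    worker.start()


if __name__ == "__main__":
    main()
```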

View File

@@ -28,33 +28,50 @@ from reflector.hatchet.broadcast import (
set_status_and_broadcast,
)
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.constants import (
TIMEOUT_AUDIO,
TIMEOUT_HEAVY,
TIMEOUT_LONG,
TIMEOUT_MEDIUM,
TIMEOUT_SHORT,
)
from reflector.hatchet.workflows.models import (
ActionItemsResult,
ConsentResult,
FinalizeResult,
MixdownResult,
PaddedTrackInfo,
PadTrackResult,
ParticipantInfo,
ParticipantsResult,
ProcessSubjectsResult,
ProcessTracksResult,
RecapResult,
RecordingResult,
SummaryResult,
SubjectsResult,
SubjectSummaryResult,
TitleResult,
TopicChunkResult,
TopicsResult,
TranscribeTrackResult,
WaveformResult,
WebhookResult,
ZulipResult,
)
from reflector.hatchet.workflows.subject_processing import (
SubjectInput,
subject_workflow,
)
from reflector.hatchet.workflows.topic_chunk_processing import (
TopicChunkInput,
topic_chunk_workflow,
)
from reflector.hatchet.workflows.track_processing import TrackInput, track_workflow
from reflector.logger import logger
from reflector.pipelines import topic_processing
from reflector.processors import AudioFileWriterProcessor
from reflector.processors.types import (
TitleSummary,
TitleSummaryWithId,
Word,
)
from reflector.processors.types import (
Transcript as TranscriptType,
)
from reflector.processors.types import TitleSummary, Word
from reflector.processors.types import Transcript as TranscriptType
from reflector.settings import settings
from reflector.storage.storage_aws import AwsStorage
from reflector.utils.audio_constants import (
@@ -71,6 +88,7 @@ from reflector.utils.daily import (
parse_daily_recording_filename,
)
from reflector.utils.string import NonEmptyString, assert_non_none_and_non_empty
from reflector.utils.transcript_constants import TOPIC_CHUNK_WORD_COUNT
from reflector.zulip import post_transcript_notification
@@ -173,7 +191,9 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
return decorator
@diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
@diarization_pipeline.task(
execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3
)
@with_error_handling("get_recording")
async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
"""Fetch recording metadata from Daily.co API."""
@@ -225,7 +245,9 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
@diarization_pipeline.task(
parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3
parents=[get_recording],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3,
)
@with_error_handling("get_participants")
async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
@@ -274,7 +296,7 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
track_keys = [t["s3_key"] for t in input.tracks]
cam_audio_keys = filter_cam_audio_tracks(track_keys)
participants_list = []
participants_list: list[ParticipantInfo] = []
for idx, key in enumerate(cam_audio_keys):
try:
parsed = parse_daily_recording_filename(key)
@@ -296,11 +318,11 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
)
await transcripts_controller.upsert_participant(transcript, participant)
participants_list.append(
{
"participant_id": participant_id,
"user_name": name,
"speaker": idx,
}
ParticipantInfo(
participant_id=participant_id,
user_name=name,
speaker=idx,
)
)
ctx.log(f"get_participants complete: {len(participants_list)} participants")
@@ -314,7 +336,9 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
@diarization_pipeline.task(
parents=[get_participants], execution_timeout=timedelta(seconds=600), retries=3
parents=[get_participants],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
)
@with_error_handling("process_tracks")
async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
@@ -324,9 +348,9 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
participants_result = ctx.task_output(get_participants)
source_language = participants_result.source_language
child_coroutines = [
track_workflow.aio_run(
TrackInput(
bulk_runs = [
track_workflow.create_bulk_run_item(
input=TrackInput(
track_index=i,
s3_key=track["s3_key"],
bucket_name=input.bucket_name,
@@ -337,35 +361,34 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
for i, track in enumerate(input.tracks)
]
results = await asyncio.gather(*child_coroutines)
results = await track_workflow.aio_run_many(bulk_runs)
target_language = participants_result.target_language
track_words = []
track_words: list[list[Word]] = []
padded_tracks = []
created_padded_files = set()
for result in results:
transcribe_result = result.get("transcribe_track", {})
track_words.append(transcribe_result.get("words", []))
transcribe_result = TranscribeTrackResult(**result["transcribe_track"])
track_words.append(transcribe_result.words)
pad_result = result.get("pad_track", {})
padded_key = pad_result.get("padded_key")
bucket_name = pad_result.get("bucket_name")
pad_result = PadTrackResult(**result["pad_track"])
# Store S3 key info (not presigned URL) - consumer tasks presign on demand
if padded_key:
if pad_result.padded_key:
padded_tracks.append(
PaddedTrackInfo(key=padded_key, bucket_name=bucket_name)
PaddedTrackInfo(
key=pad_result.padded_key, bucket_name=pad_result.bucket_name
)
)
track_index = pad_result.get("track_index")
if pad_result.get("size", 0) > 0 and track_index is not None:
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
if pad_result.size > 0:
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{pad_result.track_index}.webm"
created_padded_files.add(storage_path)
all_words = [word for words in track_words for word in words]
all_words.sort(key=lambda w: w.get("start", 0))
all_words.sort(key=lambda w: w.start)
ctx.log(
f"process_tracks complete: {len(all_words)} words from {len(input.tracks)} tracks"
@@ -382,7 +405,9 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
@diarization_pipeline.task(
parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3
parents=[process_tracks],
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
retries=3,
)
@with_error_handling("mixdown_tracks")
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
@@ -463,7 +488,9 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
@diarization_pipeline.task(
parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3
parents=[mixdown_tracks],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
)
@with_error_handling("generate_waveform")
async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
@@ -529,55 +556,102 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
@diarization_pipeline.task(
parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3
parents=[mixdown_tracks],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
)
@with_error_handling("detect_topics")
async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
"""Detect topics using LLM and save to database (matches Celery on_topic callback)."""
"""Detect topics using parallel child workflows (one per chunk)."""
ctx.log("detect_topics: analyzing transcript for topics")
track_result = ctx.task_output(process_tracks)
words = track_result.all_words
target_language = track_result.target_language
if not words:
ctx.log("detect_topics: no words, returning empty topics")
return TopicsResult(topics=[])
# Deferred imports: Hatchet workers fork processes
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptTopic,
transcripts_controller,
)
word_objects = [Word(**w) for w in words]
transcript_type = TranscriptType(words=word_objects)
chunk_size = TOPIC_CHUNK_WORD_COUNT
chunks = []
for i in range(0, len(words), chunk_size):
chunk_words = words[i : i + chunk_size]
if not chunk_words:
continue
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
first_word = chunk_words[0]
last_word = chunk_words[-1]
timestamp = first_word.start
duration = last_word.end - timestamp
chunk_text = " ".join(w.text for w in chunk_words)
chunks.append(
{
"index": len(chunks),
"text": chunk_text,
"timestamp": timestamp,
"duration": duration,
"words": chunk_words,
}
)
if not chunks:
ctx.log("detect_topics: no chunks generated, returning empty topics")
return TopicsResult(topics=[])
ctx.log(f"detect_topics: spawning {len(chunks)} topic chunk workflows in parallel")
bulk_runs = [
topic_chunk_workflow.create_bulk_run_item(
input=TopicChunkInput(
chunk_index=chunk["index"],
chunk_text=chunk["text"],
timestamp=chunk["timestamp"],
duration=chunk["duration"],
words=chunk["words"],
)
)
for chunk in chunks
]
results = await topic_chunk_workflow.aio_run_many(bulk_runs)
topic_chunks = [
TopicChunkResult(**result["detect_chunk_topic"]) for result in results
]
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
async def on_topic_callback(data):
for chunk in topic_chunks:
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
transcript=data.transcript.text,
words=data.transcript.words,
title=chunk.title,
summary=chunk.summary,
timestamp=chunk.timestamp,
transcript=" ".join(w.text for w in chunk.words),
words=[w.model_dump() for w in chunk.words],
)
if isinstance(
data, TitleSummaryWithId
): # Celery parity: main_live_pipeline.py
topic.id = data.id
await transcripts_controller.upsert_topic(transcript, topic)
await append_event_and_broadcast(
input.transcript_id, transcript, "TOPIC", topic, logger=logger
)
topics = await topic_processing.detect_topics(
transcript_type,
target_language,
on_topic_callback=on_topic_callback,
empty_pipeline=empty_pipeline,
topics_list = [
TitleSummary(
title=chunk.title,
summary=chunk.summary,
timestamp=chunk.timestamp,
duration=chunk.duration,
transcript=TranscriptType(words=chunk.words),
)
topics_list = [t.model_dump() for t in topics]
for chunk in topic_chunks
]
ctx.log(f"detect_topics complete: found {len(topics_list)} topics")
@@ -585,7 +659,9 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
@diarization_pipeline.task(
parents=[detect_topics], execution_timeout=timedelta(seconds=600), retries=3
parents=[detect_topics],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
)
@with_error_handling("generate_title")
async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
@@ -601,8 +677,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
transcripts_controller,
)
topic_objects = [TitleSummary(**t) for t in topics]
ctx.log(f"generate_title: created {len(topic_objects)} TitleSummary objects")
ctx.log(f"generate_title: received {len(topics)} TitleSummary objects")
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
title_result = None
@@ -634,7 +709,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
ctx.log("generate_title: calling topic_processing.generate_title (LLM call)...")
await topic_processing.generate_title(
topic_objects,
topics,
on_title_callback=on_title_callback,
empty_pipeline=empty_pipeline,
logger=logger,
@@ -647,97 +722,277 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
@diarization_pipeline.task(
parents=[detect_topics], execution_timeout=timedelta(seconds=600), retries=3
parents=[detect_topics],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
)
@with_error_handling("generate_summary")
async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
"""Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
ctx.log(f"generate_summary: starting for transcript_id={input.transcript_id}")
@with_error_handling("extract_subjects")
async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult:
"""Extract main subjects/topics from transcript for parallel processing."""
ctx.log(f"extract_subjects: starting for transcript_id={input.transcript_id}")
topics_result = ctx.task_output(detect_topics)
topics = topics_result.topics
ctx.log(f"generate_summary: received {len(topics)} topics from detect_topics")
if not topics:
ctx.log("extract_subjects: no topics, returning empty subjects")
return SubjectsResult(
subjects=[],
transcript_text="",
participant_names=[],
participant_name_to_id={},
)
# Deferred imports: Hatchet workers fork processes, fresh imports avoid
# sharing DB connections and LLM HTTP pools across forks
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
from reflector.llm import LLM # noqa: PLC0415
from reflector.processors.summary.summary_builder import ( # noqa: PLC0415
SummaryBuilder,
)
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
# Build transcript text from topics (same logic as TranscriptFinalSummaryProcessor)
speakermap = {}
if transcript and transcript.participants:
speakermap = {
p.speaker: p.name
for p in transcript.participants
if p.speaker is not None and p.name
}
text_lines = []
for topic in topics:
for segment in topic.transcript.as_segments():
name = speakermap.get(segment.speaker, f"Speaker {segment.speaker}")
text_lines.append(f"{name}: {segment.text}")
transcript_text = "\n".join(text_lines)
participant_names = []
participant_name_to_id = {}
if transcript and transcript.participants:
participant_names = [p.name for p in transcript.participants if p.name]
participant_name_to_id = {
p.name: p.id for p in transcript.participants if p.name and p.id
}
# TODO: refactor SummaryBuilder methods into standalone functions
llm = LLM(settings=settings)
builder = SummaryBuilder(llm, logger=logger)
builder.set_transcript(transcript_text)
if participant_names:
builder.set_known_participants(
participant_names, participant_name_to_id=participant_name_to_id
)
ctx.log("extract_subjects: calling LLM to extract subjects")
await builder.extract_subjects()
ctx.log(f"extract_subjects complete: {len(builder.subjects)} subjects")
return SubjectsResult(
subjects=builder.subjects,
transcript_text=transcript_text,
participant_names=participant_names,
participant_name_to_id=participant_name_to_id,
)
@diarization_pipeline.task(
parents=[extract_subjects],
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
retries=3,
)
@with_error_handling("process_subjects")
async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubjectsResult:
"""Spawn child workflows for each subject (dynamic fan-out, parallel LLM calls)."""
subjects_result = ctx.task_output(extract_subjects)
subjects = subjects_result.subjects
if not subjects:
ctx.log("process_subjects: no subjects to process")
return ProcessSubjectsResult(subject_summaries=[])
ctx.log(f"process_subjects: spawning {len(subjects)} subject workflows in parallel")
bulk_runs = [
subject_workflow.create_bulk_run_item(
input=SubjectInput(
subject=subject,
subject_index=i,
transcript_text=subjects_result.transcript_text,
participant_names=subjects_result.participant_names,
participant_name_to_id=subjects_result.participant_name_to_id,
)
)
for i, subject in enumerate(subjects)
]
results = await subject_workflow.aio_run_many(bulk_runs)
subject_summaries = [
SubjectSummaryResult(**result["generate_detailed_summary"])
for result in results
]
ctx.log(f"process_subjects complete: {len(subject_summaries)} summaries")
return ProcessSubjectsResult(subject_summaries=subject_summaries)
@diarization_pipeline.task(
parents=[process_subjects],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
)
@with_error_handling("generate_recap")
async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
"""Generate recap and long summary from subject summaries, save to database."""
ctx.log(f"generate_recap: starting for transcript_id={input.transcript_id}")
subjects_result = ctx.task_output(extract_subjects)
process_result = ctx.task_output(process_subjects)
# Deferred imports: Hatchet workers fork processes, fresh imports avoid
# sharing DB connections and LLM HTTP pools across forks
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptActionItems,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
transcripts_controller,
)
from reflector.llm import LLM # noqa: PLC0415
from reflector.processors.summary.prompts import ( # noqa: PLC0415
RECAP_PROMPT,
build_participant_instructions,
build_summary_markdown,
)
topic_objects = [TitleSummary(**t) for t in topics]
ctx.log(f"generate_summary: created {len(topic_objects)} TitleSummary objects")
subject_summaries = process_result.subject_summaries
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
summary_result = None
short_summary_result = None
action_items_result = None
if not subject_summaries:
ctx.log("generate_recap: no subject summaries, returning empty")
return RecapResult(short_summary="", long_summary="")
summaries = [
{"subject": s.subject, "summary": s.paragraph_summary}
for s in subject_summaries
]
summaries_text = "\n\n".join([f"{s['subject']}: {s['summary']}" for s in summaries])
llm = LLM(settings=settings)
participant_instructions = build_participant_instructions(
subjects_result.participant_names
)
recap_prompt = RECAP_PROMPT
if participant_instructions:
recap_prompt = f"{recap_prompt}\n\n{participant_instructions}"
ctx.log("generate_recap: calling LLM for recap")
recap_response = await llm.get_response(
recap_prompt,
[summaries_text],
tone_name="Recap summarizer",
)
short_summary = str(recap_response)
long_summary = build_summary_markdown(short_summary, summaries)
async with fresh_db_connection():
ctx.log("generate_summary: DB connection established")
transcript = await transcripts_controller.get_by_id(input.transcript_id)
ctx.log(
f"generate_summary: fetched transcript, exists={transcript is not None}"
)
async def on_long_summary_callback(data):
nonlocal summary_result
ctx.log(
f"generate_summary: on_long_summary_callback received ({len(data.long_summary)} chars)"
)
summary_result = data.long_summary
final_long_summary = TranscriptFinalLongSummary(
long_summary=data.long_summary
)
if transcript:
await transcripts_controller.update(
transcript,
{"long_summary": final_long_summary.long_summary},
{
"short_summary": short_summary,
"long_summary": long_summary,
},
)
ctx.log("generate_summary: saved long_summary to DB")
await append_event_and_broadcast(
input.transcript_id,
transcript,
"FINAL_LONG_SUMMARY",
final_long_summary,
logger=logger,
)
ctx.log("generate_summary: broadcasted FINAL_LONG_SUMMARY event")
async def on_short_summary_callback(data):
nonlocal short_summary_result
ctx.log(
f"generate_summary: on_short_summary_callback received ({len(data.short_summary)} chars)"
)
short_summary_result = data.short_summary
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
await transcripts_controller.update(
transcript,
{"short_summary": final_short_summary.short_summary},
)
ctx.log("generate_summary: saved short_summary to DB")
final_short = TranscriptFinalShortSummary(short_summary=short_summary)
await append_event_and_broadcast(
input.transcript_id,
transcript,
"FINAL_SHORT_SUMMARY",
final_short_summary,
final_short,
logger=logger,
)
ctx.log("generate_summary: broadcasted FINAL_SHORT_SUMMARY event")
async def on_action_items_callback(data):
nonlocal action_items_result
ctx.log(
f"generate_summary: on_action_items_callback received ({len(data.action_items)} items)"
)
action_items_result = data.action_items
action_items = TranscriptActionItems(action_items=data.action_items)
await transcripts_controller.update(
final_long = TranscriptFinalLongSummary(long_summary=long_summary)
await append_event_and_broadcast(
input.transcript_id,
transcript,
{"action_items": action_items.action_items},
"FINAL_LONG_SUMMARY",
final_long,
logger=logger,
)
ctx.log("generate_recap complete")
return RecapResult(short_summary=short_summary, long_summary=long_summary)
@diarization_pipeline.task(
parents=[extract_subjects],
execution_timeout=timedelta(seconds=TIMEOUT_LONG),
retries=3,
)
@with_error_handling("identify_action_items")
async def identify_action_items(
input: PipelineInput, ctx: Context
) -> ActionItemsResult:
"""Identify action items from transcript (parallel with subject processing)."""
ctx.log(f"identify_action_items: starting for transcript_id={input.transcript_id}")
subjects_result = ctx.task_output(extract_subjects)
if not subjects_result.transcript_text:
ctx.log("identify_action_items: no transcript text, returning empty")
return ActionItemsResult(action_items={"decisions": [], "next_steps": []})
# Deferred imports: Hatchet workers fork processes, fresh imports avoid
# sharing DB connections and LLM HTTP pools across forks
from reflector.db.transcripts import ( # noqa: PLC0415
TranscriptActionItems,
transcripts_controller,
)
from reflector.llm import LLM # noqa: PLC0415
from reflector.processors.summary.summary_builder import ( # noqa: PLC0415
SummaryBuilder,
)
# TODO: refactor SummaryBuilder methods into standalone functions
llm = LLM(settings=settings)
builder = SummaryBuilder(llm, logger=logger)
builder.set_transcript(subjects_result.transcript_text)
if subjects_result.participant_names:
builder.set_known_participants(
subjects_result.participant_names,
participant_name_to_id=subjects_result.participant_name_to_id,
)
ctx.log("identify_action_items: calling LLM")
action_items_response = await builder.identify_action_items()
if action_items_response is None:
raise RuntimeError("Failed to identify action items - LLM call failed")
action_items_dict = action_items_response.model_dump()
async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
action_items = TranscriptActionItems(action_items=action_items_dict)
await transcripts_controller.update(
transcript, {"action_items": action_items.action_items}
)
ctx.log("generate_summary: saved action_items to DB")
await append_event_and_broadcast(
input.transcript_id,
transcript,
@@ -745,34 +1000,18 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
action_items,
logger=logger,
)
ctx.log("generate_summary: broadcasted ACTION_ITEMS event")
ctx.log(
"generate_summary: calling topic_processing.generate_summaries (LLM calls)..."
)
await topic_processing.generate_summaries(
topic_objects,
transcript,
on_long_summary_callback=on_long_summary_callback,
on_short_summary_callback=on_short_summary_callback,
on_action_items_callback=on_action_items_callback,
empty_pipeline=empty_pipeline,
logger=logger,
)
ctx.log("generate_summary: topic_processing.generate_summaries returned")
ctx.log("generate_summary complete")
return SummaryResult(
summary=summary_result,
short_summary=short_summary_result,
action_items=action_items_result,
ctx.log(
f"identify_action_items complete: {len(action_items_dict.get('decisions', []))} decisions, "
f"{len(action_items_dict.get('next_steps', []))} next steps"
)
return ActionItemsResult(action_items=action_items_dict)
@diarization_pipeline.task(
parents=[generate_waveform, generate_title, generate_summary],
execution_timeout=timedelta(seconds=60),
parents=[generate_waveform, generate_title, generate_recap, identify_action_items],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3,
)
@with_error_handling("finalize")
@@ -818,8 +1057,7 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
if transcript is None:
raise ValueError(f"Transcript {input.transcript_id} not found in database")
word_objects = [Word(**w) for w in all_words]
merged_transcript = TranscriptType(words=word_objects, translation=None)
merged_transcript = TranscriptType(words=all_words, translation=None)
await append_event_and_broadcast(
input.transcript_id,
@@ -857,7 +1095,7 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
@diarization_pipeline.task(
parents=[finalize], execution_timeout=timedelta(seconds=60), retries=3
parents=[finalize], execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3
)
@with_error_handling("cleanup_consent", set_error_status=False)
async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
@@ -957,7 +1195,9 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
@diarization_pipeline.task(
parents=[cleanup_consent], execution_timeout=timedelta(seconds=60), retries=5
parents=[cleanup_consent],
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=5,
)
@with_error_handling("post_zulip", set_error_status=False)
async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
@@ -982,7 +1222,9 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
@diarization_pipeline.task(
parents=[post_zulip], execution_timeout=timedelta(seconds=120), retries=30
parents=[post_zulip],
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=30,
)
@with_error_handling("send_webhook", set_error_status=False)
async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
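
Taken together, the `parents=` declarations in this file give the reshaped task graph below (reconstructed from this diff; the three fan-out tasks spawn child workflows via `aio_run_many`):

```python
# Task graph as declared by parents= in this file; [] means a root task.
PIPELINE_DAG: dict[str, list[str]] = {
    "get_recording": [],
    "get_participants": ["get_recording"],
    "process_tracks": ["get_participants"],    # fans out TrackProcessing
    "mixdown_tracks": ["process_tracks"],
    "generate_waveform": ["mixdown_tracks"],
    "detect_topics": ["mixdown_tracks"],       # fans out TopicChunkProcessing
    "generate_title": ["detect_topics"],
    "extract_subjects": ["detect_topics"],
    "process_subjects": ["extract_subjects"],  # fans out SubjectProcessing
    "generate_recap": ["process_subjects"],
    "identify_action_items": ["extract_subjects"],
    "finalize": [
        "generate_waveform",
        "generate_title",
        "generate_recap",
        "identify_action_items",
    ],
    "cleanup_consent": ["finalize"],
    "post_zulip": ["cleanup_consent"],
    "send_webhook": ["post_zulip"],
}
```

The old serial `generate_summary` is gone: subject extraction, per-subject summaries, recap, and action items now run as separate branches that only rejoin at `finalize`.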

View File

@@ -5,13 +5,20 @@ Provides static typing for all task outputs, enabling type checking
and better IDE support.
"""
from typing import Any
from pydantic import BaseModel
from reflector.processors.types import TitleSummary, Word
from reflector.utils.string import NonEmptyString
class ParticipantInfo(BaseModel):
"""Participant info with speaker index for workflow result."""
participant_id: NonEmptyString
user_name: NonEmptyString
speaker: int
class PadTrackResult(BaseModel):
"""Result from pad_track task."""
@@ -26,7 +33,7 @@ class PadTrackResult(BaseModel):
class TranscribeTrackResult(BaseModel):
"""Result from transcribe_track task."""
words: list[dict[str, Any]]
words: list[Word]
track_index: int
@@ -41,7 +48,7 @@ class RecordingResult(BaseModel):
class ParticipantsResult(BaseModel):
"""Result from get_participants task."""
participants: list[dict[str, Any]]
participants: list[ParticipantInfo]
num_tracks: int
source_language: NonEmptyString
target_language: NonEmptyString
@@ -57,7 +64,7 @@ class PaddedTrackInfo(BaseModel):
class ProcessTracksResult(BaseModel):
"""Result from process_tracks task."""
all_words: list[dict[str, Any]]
all_words: list[Word]
padded_tracks: list[PaddedTrackInfo] # S3 keys, not presigned URLs
word_count: int
num_tracks: int
@@ -79,10 +86,21 @@ class WaveformResult(BaseModel):
waveform_generated: bool
class TopicChunkResult(BaseModel):
"""Result from topic chunk child workflow."""
chunk_index: int
title: str
summary: str
timestamp: float
duration: float
words: list[Word]
class TopicsResult(BaseModel):
"""Result from detect_topics task."""
topics: list[dict[str, Any]]
topics: list[TitleSummary]
class TitleResult(BaseModel):
@@ -91,12 +109,41 @@ class TitleResult(BaseModel):
title: str | None
class SummaryResult(BaseModel):
"""Result from generate_summary task."""
class SubjectsResult(BaseModel):
"""Result from extract_subjects task."""
summary: str | None
short_summary: str | None
action_items: dict | None = None
subjects: list[str]
transcript_text: str # Formatted transcript for LLM consumption
participant_names: list[str]
participant_name_to_id: dict[str, str]
class SubjectSummaryResult(BaseModel):
"""Result from subject summary child workflow."""
subject: str
subject_index: int
detailed_summary: str
paragraph_summary: str
class ProcessSubjectsResult(BaseModel):
"""Result from process_subjects fan-out task."""
subject_summaries: list[SubjectSummaryResult]
class RecapResult(BaseModel):
"""Result from generate_recap task."""
short_summary: str # Recap paragraph
long_summary: str # Full markdown summary
class ActionItemsResult(BaseModel):
"""Result from identify_action_items task."""
action_items: dict # ActionItemsResponse as dict (may have empty lists)
class FinalizeResult(BaseModel):
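
The shift from `list[dict[str, Any]]` to `list[Word]` means Pydantic rehydrates word dicts into `Word` objects wherever a result model is parsed, replacing the scattered `[Word(**w) for w in words]` conversions. A stand-alone illustration (the `Word` stand-in approximates `reflector.processors.types.Word`):

```python
from pydantic import BaseModel


class Word(BaseModel):  # approximates reflector.processors.types.Word
    text: str
    start: float
    end: float
    speaker: int = 0


class TranscribeTrackResult(BaseModel):
    words: list[Word]
    track_index: int


# Task outputs cross process boundaries as JSON; validation restores types.
raw = {"track_index": 1, "words": [{"text": "hi", "start": 0.0, "end": 0.4}]}
parsed = TranscribeTrackResult(**raw)
assert isinstance(parsed.words[0], Word)  # dicts validated into models
assert parsed.words[0].start == 0.0       # attribute access, type-checked
```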

View File

@@ -0,0 +1,107 @@
"""
Hatchet child workflow: SubjectProcessing
Handles individual subject/topic summary generation.
Spawned dynamically by the main diarization pipeline for each extracted subject
via aio_run_many() for parallel processing.
"""
from datetime import timedelta
from hatchet_sdk import Context
from hatchet_sdk.rate_limit import RateLimit
from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_MEDIUM
from reflector.hatchet.workflows.models import SubjectSummaryResult
from reflector.logger import logger
from reflector.processors.summary.prompts import (
DETAILED_SUBJECT_PROMPT_TEMPLATE,
PARAGRAPH_SUMMARY_PROMPT,
build_participant_instructions,
)
class SubjectInput(BaseModel):
"""Input for individual subject processing."""
subject: str
subject_index: int
transcript_text: str
participant_names: list[str]
participant_name_to_id: dict[str, str]
hatchet = HatchetClientManager.get_client()
subject_workflow = hatchet.workflow(
name="SubjectProcessing", input_validator=SubjectInput
)
@subject_workflow.task(
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
retries=3,
rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=2)],
)
async def generate_detailed_summary(
input: SubjectInput, ctx: Context
) -> SubjectSummaryResult:
"""Generate detailed analysis for a single subject, then condense to paragraph."""
ctx.log(
f"generate_detailed_summary: subject '{input.subject}' (index {input.subject_index})"
)
logger.info(
"[Hatchet] generate_detailed_summary",
subject=input.subject,
subject_index=input.subject_index,
)
# Deferred imports: Hatchet workers fork processes, fresh imports ensure
# LLM HTTP connection pools aren't shared across forks
from reflector.llm import LLM # noqa: PLC0415
from reflector.settings import settings # noqa: PLC0415
llm = LLM(settings=settings)
participant_instructions = build_participant_instructions(input.participant_names)
detailed_prompt = DETAILED_SUBJECT_PROMPT_TEMPLATE.format(subject=input.subject)
if participant_instructions:
detailed_prompt = f"{detailed_prompt}\n\n{participant_instructions}"
ctx.log("generate_detailed_summary: calling LLM for detailed analysis")
detailed_response = await llm.get_response(
detailed_prompt,
[input.transcript_text],
tone_name="Topic assistant",
)
detailed_summary = str(detailed_response)
paragraph_prompt = PARAGRAPH_SUMMARY_PROMPT
if participant_instructions:
paragraph_prompt = f"{paragraph_prompt}\n\n{participant_instructions}"
ctx.log("generate_detailed_summary: calling LLM for paragraph summary")
paragraph_response = await llm.get_response(
paragraph_prompt,
[detailed_summary],
tone_name="Topic summarizer",
)
paragraph_summary = str(paragraph_response)
ctx.log(f"generate_detailed_summary complete: subject '{input.subject}'")
logger.info(
"[Hatchet] generate_detailed_summary complete",
subject=input.subject,
subject_index=input.subject_index,
detailed_len=len(detailed_summary),
paragraph_len=len(paragraph_summary),
)
return SubjectSummaryResult(
subject=input.subject,
subject_index=input.subject_index,
detailed_summary=detailed_summary,
paragraph_summary=paragraph_summary,
)
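
Both child workflows meter LLM fan-out through a shared static-key rate limit; this task reserves 2 units because it makes two LLM calls, while the topic-chunk task reserves 1. The limit itself must be declared server-side; a hypothetical declaration (the `put_rate_limit` call, limit, and window below are illustrative and not part of this PR):

```python
from hatchet_sdk.rate_limit import RateLimitDuration

from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY

hatchet = HatchetClientManager.get_client()

# Assumed cap: at most 10 units of LLM work per second across all child runs.
hatchet.admin.put_rate_limit(
    key=LLM_RATE_LIMIT_KEY,
    limit=10,
    duration=RateLimitDuration.SECOND,
)
```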

View File

@@ -0,0 +1,91 @@
"""
Hatchet child workflow: TopicChunkProcessing
Handles topic detection for individual transcript chunks.
Spawned dynamically by detect_topics via aio_run_many() for parallel processing.
"""
from datetime import timedelta
from hatchet_sdk import Context
from hatchet_sdk.rate_limit import RateLimit
from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_SHORT
from reflector.hatchet.workflows.models import TopicChunkResult
from reflector.logger import logger
from reflector.processors.prompts import TOPIC_PROMPT
from reflector.processors.types import Word
class TopicChunkInput(BaseModel):
"""Input for individual topic chunk processing."""
chunk_index: int
chunk_text: str
timestamp: float
duration: float
words: list[Word]
hatchet = HatchetClientManager.get_client()
topic_chunk_workflow = hatchet.workflow(
name="TopicChunkProcessing", input_validator=TopicChunkInput
)
@topic_chunk_workflow.task(
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
retries=3,
rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=1)],
)
async def detect_chunk_topic(input: TopicChunkInput, ctx: Context) -> TopicChunkResult:
"""Detect topic for a single transcript chunk."""
ctx.log(f"detect_chunk_topic: chunk {input.chunk_index}")
logger.info(
"[Hatchet] detect_chunk_topic",
chunk_index=input.chunk_index,
text_length=len(input.chunk_text),
)
# Deferred imports: Hatchet workers fork processes, fresh imports avoid
# sharing LLM HTTP connection pools across forks
from reflector.llm import LLM # noqa: PLC0415
from reflector.processors.transcript_topic_detector import ( # noqa: PLC0415
TopicResponse,
)
from reflector.settings import settings # noqa: PLC0415
from reflector.utils.text import clean_title # noqa: PLC0415
llm = LLM(settings=settings, temperature=0.9, max_tokens=500)
prompt = TOPIC_PROMPT.format(text=input.chunk_text)
response = await llm.get_structured_response(
prompt,
[input.chunk_text],
TopicResponse,
tone_name="Topic analyzer",
timeout=settings.LLM_STRUCTURED_RESPONSE_TIMEOUT,
)
title = clean_title(response.title)
ctx.log(
f"detect_chunk_topic complete: chunk {input.chunk_index}, title='{title[:50]}'"
)
logger.info(
"[Hatchet] detect_chunk_topic complete",
chunk_index=input.chunk_index,
title=title[:50],
)
return TopicChunkResult(
chunk_index=input.chunk_index,
title=title,
summary=response.summary,
timestamp=input.timestamp,
duration=input.duration,
words=input.words,
)
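
For local debugging, a child workflow can also be invoked on its own; a hypothetical one-off run with illustrative inputs:

```python
import asyncio

from reflector.hatchet.workflows.topic_chunk_processing import (
    TopicChunkInput,
    topic_chunk_workflow,
)


async def main() -> None:
    result = await topic_chunk_workflow.aio_run(
        TopicChunkInput(
            chunk_index=0,
            chunk_text="We discussed the Q3 roadmap and the hiring plan.",
            timestamp=0.0,
            duration=4.2,
            words=[],  # Word objects would normally be threaded through
        )
    )
    # Output is keyed by task name, as in the parent pipeline.
    print(result["detect_chunk_topic"]["title"])


if __name__ == "__main__":
    asyncio.run(main())
```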

View File

@@ -23,6 +23,7 @@ from hatchet_sdk import Context
from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.constants import TIMEOUT_AUDIO, TIMEOUT_HEAVY
from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
from reflector.logger import logger
from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
@@ -47,7 +48,7 @@ hatchet = HatchetClientManager.get_client()
track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
@track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
@track_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3)
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
"""Pad single audio track with silence for alignment.
@@ -153,7 +154,7 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
@track_workflow.task(
parents=[pad_track], execution_timeout=timedelta(seconds=600), retries=3
parents=[pad_track], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3
)
async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
"""Transcribe audio track using GPU (Modal.com) or local Whisper."""
@@ -197,23 +198,20 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe
transcript = await transcribe_file_with_processor(audio_url, input.language)
# Tag all words with speaker index
words = []
for word in transcript.words:
word_dict = word.model_dump()
word_dict["speaker"] = input.track_index
words.append(word_dict)
word.speaker = input.track_index
ctx.log(
f"transcribe_track complete: track {input.track_index}, {len(words)} words"
f"transcribe_track complete: track {input.track_index}, {len(transcript.words)} words"
)
logger.info(
"[Hatchet] transcribe_track complete",
track_index=input.track_index,
word_count=len(words),
word_count=len(transcript.words),
)
return TranscribeTrackResult(
words=words,
words=transcript.words,
track_index=input.track_index,
)