feat: reduce Hatchet payload size by removing words from topic chunk workflows

Remove ~6.5MB of redundant Word data from Hatchet task boundaries:
- Remove words from TopicChunkInput/TopicChunkResult (child workflow I/O)
- detect_topics maps words from local chunks by chunk_index instead
- TopicsResult carries empty transcript words (persisted to DB already)
- extract_subjects refetches topics from DB instead of task output
- Clear topics at detect_topics start for retry idempotency
This commit is contained in:
Igor Loskutov
2026-02-10 18:59:41 -05:00
parent b468427f1b
commit 26f1e5f6dd
4 changed files with 204 additions and 13 deletions

View File

@@ -720,7 +720,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
chunk_text=chunk["text"], chunk_text=chunk["text"],
timestamp=chunk["timestamp"], timestamp=chunk["timestamp"],
duration=chunk["duration"], duration=chunk["duration"],
words=chunk["words"],
) )
) )
for chunk in chunks for chunk in chunks
@@ -732,31 +731,41 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]) for result in results TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]) for result in results
] ]
# Build index-to-words map from local chunks (words not in child workflow results)
chunks_by_index = {chunk["index"]: chunk["words"] for chunk in chunks}
async with fresh_db_connection(): async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id) transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript: if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found") raise ValueError(f"Transcript {input.transcript_id} not found")
# Clear topics for idempotency on retry (each topic gets a fresh UUID,
# so upsert_topic would append duplicates without this)
await transcripts_controller.update(transcript, {"topics": []})
for chunk in topic_chunks: for chunk in topic_chunks:
chunk_words = chunks_by_index[chunk.chunk_index]
topic = TranscriptTopic( topic = TranscriptTopic(
title=chunk.title, title=chunk.title,
summary=chunk.summary, summary=chunk.summary,
timestamp=chunk.timestamp, timestamp=chunk.timestamp,
transcript=" ".join(w.text for w in chunk.words), transcript=" ".join(w.text for w in chunk_words),
words=chunk.words, words=chunk_words,
) )
await transcripts_controller.upsert_topic(transcript, topic) await transcripts_controller.upsert_topic(transcript, topic)
await append_event_and_broadcast( await append_event_and_broadcast(
input.transcript_id, transcript, "TOPIC", topic, logger=logger input.transcript_id, transcript, "TOPIC", topic, logger=logger
) )
# Words omitted from TopicsResult — already persisted to DB above.
# Downstream tasks that need words refetch from DB.
topics_list = [ topics_list = [
TitleSummary( TitleSummary(
title=chunk.title, title=chunk.title,
summary=chunk.summary, summary=chunk.summary,
timestamp=chunk.timestamp, timestamp=chunk.timestamp,
duration=chunk.duration, duration=chunk.duration,
transcript=TranscriptType(words=chunk.words), transcript=TranscriptType(words=[]),
) )
for chunk in topic_chunks for chunk in topic_chunks
] ]
@@ -842,9 +851,8 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
ctx.log(f"extract_subjects: starting for transcript_id={input.transcript_id}") ctx.log(f"extract_subjects: starting for transcript_id={input.transcript_id}")
topics_result = ctx.task_output(detect_topics) topics_result = ctx.task_output(detect_topics)
topics = topics_result.topics
if not topics: if not topics_result.topics:
ctx.log("extract_subjects: no topics, returning empty subjects") ctx.log("extract_subjects: no topics, returning empty subjects")
return SubjectsResult( return SubjectsResult(
subjects=[], subjects=[],
@@ -857,11 +865,13 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
# sharing DB connections and LLM HTTP pools across forks # sharing DB connections and LLM HTTP pools across forks
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
from reflector.llm import LLM # noqa: PLC0415 from reflector.llm import LLM # noqa: PLC0415
from reflector.processors.types import words_to_segments # noqa: PLC0415
async with fresh_db_connection(): async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id) transcript = await transcripts_controller.get_by_id(input.transcript_id)
# Build transcript text from topics (same logic as TranscriptFinalSummaryProcessor) # Build transcript text from DB topics (words omitted from task output
# to reduce Hatchet payload size — refetch from DB where they were persisted)
speakermap = {} speakermap = {}
if transcript and transcript.participants: if transcript and transcript.participants:
speakermap = { speakermap = {
@@ -871,8 +881,8 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
} }
text_lines = [] text_lines = []
for topic in topics: for db_topic in transcript.topics:
for segment in topic.transcript.as_segments(): for segment in words_to_segments(db_topic.words):
name = speakermap.get(segment.speaker, f"Speaker {segment.speaker}") name = speakermap.get(segment.speaker, f"Speaker {segment.speaker}")
text_lines.append(f"{name}: {segment.text}") text_lines.append(f"{name}: {segment.text}")

View File

@@ -95,7 +95,6 @@ class TopicChunkResult(BaseModel):
summary: str summary: str
timestamp: float timestamp: float
duration: float duration: float
words: list[Word]
class TopicsResult(BaseModel): class TopicsResult(BaseModel):

View File

@@ -20,7 +20,6 @@ from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_MEDIUM
from reflector.hatchet.workflows.models import TopicChunkResult from reflector.hatchet.workflows.models import TopicChunkResult
from reflector.logger import logger from reflector.logger import logger
from reflector.processors.prompts import TOPIC_PROMPT from reflector.processors.prompts import TOPIC_PROMPT
from reflector.processors.types import Word
class TopicChunkInput(BaseModel): class TopicChunkInput(BaseModel):
@@ -30,7 +29,6 @@ class TopicChunkInput(BaseModel):
chunk_text: str chunk_text: str
timestamp: float timestamp: float
duration: float duration: float
words: list[Word]
hatchet = HatchetClientManager.get_client() hatchet = HatchetClientManager.get_client()
@@ -99,5 +97,4 @@ async def detect_chunk_topic(input: TopicChunkInput, ctx: Context) -> TopicChunk
summary=response.summary, summary=response.summary,
timestamp=input.timestamp, timestamp=input.timestamp,
duration=input.duration, duration=input.duration,
words=input.words,
) )

View File

@@ -0,0 +1,185 @@
"""
Tests for Hatchet payload thinning optimizations.
Verifies that:
1. TopicChunkInput no longer carries words
2. TopicChunkResult no longer carries words
3. words_to_segments() matches Transcript.as_segments(is_multitrack=False) — behavioral equivalence
for the extract_subjects refactoring
4. TopicsResult can be constructed with empty transcript words
"""
from reflector.hatchet.workflows.models import TopicChunkResult
from reflector.hatchet.workflows.topic_chunk_processing import TopicChunkInput
from reflector.processors.types import Word
def _make_words(speaker: int = 0, start: float = 0.0) -> list[Word]:
    """Build a two-word fixture ("Hello world.") for the given speaker and time offset."""
    texts = ["Hello", " world."]
    # Each word spans 0.5s, laid out back-to-back from `start`.
    return [
        Word(
            text=text,
            start=start + i * 0.5,
            end=start + (i + 1) * 0.5,
            speaker=speaker,
        )
        for i, text in enumerate(texts)
    ]
class TestTopicChunkInputNoWords:
    """TopicChunkInput must not have a words field."""

    def test_no_words_field(self):
        # The payload-thinning change removed `words` from the model entirely.
        assert "words" not in TopicChunkInput.model_fields

    def test_construction_without_words(self):
        payload = {
            "chunk_index": 0,
            "chunk_text": "Hello world.",
            "timestamp": 0.0,
            "duration": 1.0,
        }
        inp = TopicChunkInput(**payload)
        assert inp.chunk_index == 0
        assert inp.chunk_text == "Hello world."

    def test_rejects_words_kwarg(self):
        """Passing words= should raise a validation error (field doesn't exist)."""
        import pydantic

        try:
            TopicChunkInput(
                chunk_index=0,
                chunk_text="text",
                timestamp=0.0,
                duration=1.0,
                words=_make_words(),
            )
        except pydantic.ValidationError:
            pass  # Expected: extra field rejected
        else:
            # If pydantic is configured to ignore extra, construction won't raise.
            # Verify the field is still absent from the model in that case.
            assert "words" not in TopicChunkInput.model_fields
class TestTopicChunkResultNoWords:
    """TopicChunkResult must not have a words field."""

    @staticmethod
    def _result() -> TopicChunkResult:
        # Shared minimal fixture; `words` deliberately absent from the model.
        return TopicChunkResult(
            chunk_index=0,
            title="Test",
            summary="Summary",
            timestamp=0.0,
            duration=1.0,
        )

    def test_no_words_field(self):
        assert "words" not in TopicChunkResult.model_fields

    def test_construction_without_words(self):
        built = self._result()
        assert built.title == "Test"
        assert built.chunk_index == 0

    def test_serialization_roundtrip(self):
        """Serialized TopicChunkResult has no words key."""
        built = self._result()
        dumped = built.model_dump()
        assert "words" not in dumped
        # Round-trip through the serialized form preserves equality.
        assert TopicChunkResult(**dumped) == built
class TestWordsToSegmentsBehavioralEquivalence:
    """words_to_segments() must produce same output as Transcript.as_segments(is_multitrack=False).

    This ensures the extract_subjects refactoring (from task output topic.transcript.as_segments()
    to words_to_segments(db_topic.words)) preserves identical behavior.
    """

    @staticmethod
    def _assert_equivalent(words) -> None:
        """Assert both code paths yield identical segments: text, speaker, start, end.

        Fix: the multi-speaker test previously checked only text/speaker, so a
        timing divergence between the two paths would have gone undetected
        despite the class claiming full behavioral equivalence.
        """
        from reflector.processors.types import Transcript as TranscriptType
        from reflector.processors.types import words_to_segments

        direct = words_to_segments(words)
        via_transcript = TranscriptType(words=words).as_segments(is_multitrack=False)
        assert len(direct) == len(via_transcript)
        for d, v in zip(direct, via_transcript):
            assert d.text == v.text
            assert d.speaker == v.speaker
            assert d.start == v.start
            assert d.end == v.end

    def test_single_speaker(self):
        self._assert_equivalent(_make_words(speaker=0))

    def test_multiple_speakers(self):
        words = [
            Word(text="Hello", start=0.0, end=0.5, speaker=0),
            Word(text=" world.", start=0.5, end=1.0, speaker=0),
            Word(text=" How", start=1.0, end=1.5, speaker=1),
            Word(text=" are", start=1.5, end=2.0, speaker=1),
            Word(text=" you?", start=2.0, end=2.5, speaker=1),
        ]
        self._assert_equivalent(words)

    def test_empty_words(self):
        from reflector.processors.types import Transcript as TranscriptType
        from reflector.processors.types import words_to_segments

        # Degenerate case: both paths agree on empty input.
        assert words_to_segments([]) == []
        assert TranscriptType(words=[]).as_segments(is_multitrack=False) == []
class TestTopicsResultEmptyWords:
    """TopicsResult can carry topics with empty transcript words."""

    @staticmethod
    def _topic(title: str, summary: str, timestamp: float, duration: float):
        # Build a TitleSummary whose transcript intentionally carries no words,
        # mirroring what detect_topics now emits to shrink Hatchet payloads.
        from reflector.processors.types import TitleSummary
        from reflector.processors.types import Transcript as TranscriptType

        return TitleSummary(
            title=title,
            summary=summary,
            timestamp=timestamp,
            duration=duration,
            transcript=TranscriptType(words=[]),
        )

    def test_construction_with_empty_words(self):
        from reflector.hatchet.workflows.models import TopicsResult

        result = TopicsResult(
            topics=[
                self._topic("Topic A", "Summary A", 0.0, 5.0),
                self._topic("Topic B", "Summary B", 5.0, 5.0),
            ]
        )
        assert len(result.topics) == 2
        for t in result.topics:
            assert t.transcript.words == []

    def test_serialization_roundtrip(self):
        from reflector.hatchet.workflows.models import TopicsResult

        result = TopicsResult(topics=[self._topic("Topic", "Summary", 0.0, 1.0)])
        rebuilt = TopicsResult(**result.model_dump())
        assert len(rebuilt.topics) == 1
        assert rebuilt.topics[0].transcript.words == []