diff --git a/server/reflector/hatchet/broadcast.py b/server/reflector/hatchet/broadcast.py
index 58b05f11..6b42ddbd 100644
--- a/server/reflector/hatchet/broadcast.py
+++ b/server/reflector/hatchet/broadcast.py
@@ -81,7 +81,8 @@ async def set_status_and_broadcast(
 async def append_event_and_broadcast(
     transcript_id: NonEmptyString,
     transcript: Transcript,
-    event_name: str,
+    event_name: NonEmptyString,
+    # TODO proper dictionary event => type
     data: Any,
     logger: structlog.BoundLogger,
 ) -> TranscriptEvent:
diff --git a/server/reflector/hatchet/constants.py b/server/reflector/hatchet/constants.py
index 0f16d6f5..fbe6d25b 100644
--- a/server/reflector/hatchet/constants.py
+++ b/server/reflector/hatchet/constants.py
@@ -2,6 +2,31 @@
 Hatchet workflow constants.
 """

+from enum import StrEnum
+
+
+class TaskName(StrEnum):
+    GET_RECORDING = "get_recording"
+    GET_PARTICIPANTS = "get_participants"
+    PROCESS_TRACKS = "process_tracks"
+    MIXDOWN_TRACKS = "mixdown_tracks"
+    GENERATE_WAVEFORM = "generate_waveform"
+    DETECT_TOPICS = "detect_topics"
+    GENERATE_TITLE = "generate_title"
+    EXTRACT_SUBJECTS = "extract_subjects"
+    PROCESS_SUBJECTS = "process_subjects"
+    GENERATE_RECAP = "generate_recap"
+    IDENTIFY_ACTION_ITEMS = "identify_action_items"
+    FINALIZE = "finalize"
+    CLEANUP_CONSENT = "cleanup_consent"
+    POST_ZULIP = "post_zulip"
+    SEND_WEBHOOK = "send_webhook"
+    PAD_TRACK = "pad_track"
+    TRANSCRIBE_TRACK = "transcribe_track"
+    DETECT_CHUNK_TOPIC = "detect_chunk_topic"
+    GENERATE_DETAILED_SUMMARY = "generate_detailed_summary"
+
+
 # Rate limit key for LLM API calls (shared across all LLM-calling tasks)
 LLM_RATE_LIMIT_KEY = "llm"

diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py
index b7b0559e..8819e21b 100644
--- a/server/reflector/hatchet/workflows/diarization_pipeline.py
+++ b/server/reflector/hatchet/workflows/diarization_pipeline.py
@@ -13,10 +13,11 @@ import asyncio
 import functools
 import json
 import tempfile
+import time
 from contextlib import asynccontextmanager
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, Coroutine, TypeVar
+from typing import Any, Callable, Coroutine, Protocol, TypeVar

 import httpx
 from hatchet_sdk import Context
@@ -34,6 +35,7 @@ from reflector.hatchet.constants import (
     TIMEOUT_LONG,
     TIMEOUT_MEDIUM,
     TIMEOUT_SHORT,
+    TaskName,
 )
 from reflector.hatchet.workflows.models import (
     ActionItemsResult,
@@ -70,6 +72,13 @@ from reflector.hatchet.workflows.track_processing import TrackInput, track_workf
 from reflector.logger import logger
 from reflector.pipelines import topic_processing
 from reflector.processors import AudioFileWriterProcessor
+from reflector.processors.summary.models import ActionItemsResponse
+from reflector.processors.summary.prompts import (
+    RECAP_PROMPT,
+    build_participant_instructions,
+    build_summary_markdown,
+)
+from reflector.processors.summary.summary_builder import SummaryBuilder
 from reflector.processors.types import TitleSummary, Word
 from reflector.processors.types import Transcript as TranscriptType
 from reflector.settings import settings
@@ -162,11 +171,50 @@ def _spawn_storage():
     )


+class Loggable(Protocol):
+    """Protocol for objects with a log method."""
+
+    def log(self, message: str) -> None: ...
+
+
+def make_audio_progress_logger(
+    ctx: Loggable, task_name: TaskName, interval: float = 5.0
+) -> Callable[[float | None, float], None]:
+    """Create a throttled progress logger callback for audio processing.
+
+    Args:
+        ctx: Object with .log() method (e.g., Hatchet Context).
+        task_name: Name to prefix in log messages.
+        interval: Minimum seconds between log messages.
+
+    Returns:
+        Callback(progress_pct, audio_position) that logs at most once every `interval` seconds.
+    """
+    start_time = time.monotonic()
+    last_log_time = [start_time]
+
+    def callback(progress_pct: float | None, audio_position: float) -> None:
+        now = time.monotonic()
+        if now - last_log_time[0] >= interval:
+            elapsed = now - start_time
+            if progress_pct is not None:
+                ctx.log(
+                    f"{task_name} progress: {progress_pct:.1f}% @ {audio_position:.1f}s (elapsed: {elapsed:.1f}s)"
+                )
+            else:
+                ctx.log(
+                    f"{task_name} progress: @ {audio_position:.1f}s (elapsed: {elapsed:.1f}s)"
+                )
+            last_log_time[0] = now
+
+    return callback
+
+
 R = TypeVar("R")


 def with_error_handling(
-    step_name: str, set_error_status: bool = True
+    step_name: TaskName, set_error_status: bool = True
 ) -> Callable[
     [Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]],
     Callable[[PipelineInput, Context], Coroutine[Any, Any, R]],
@@ -204,7 +252,7 @@
 @diarization_pipeline.task(
     execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3
 )
-@with_error_handling("get_recording")
+@with_error_handling(TaskName.GET_RECORDING)
 async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
     """Fetch recording metadata from Daily.co API."""
     ctx.log(f"get_recording: starting for recording_id={input.recording_id}")
@@ -259,7 +307,7 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
     execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
     retries=3,
 )
-@with_error_handling("get_participants")
+@with_error_handling(TaskName.GET_PARTICIPANTS)
 async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
     """Fetch participant list from Daily.co API and update transcript in database."""
     ctx.log(f"get_participants: transcript_id={input.transcript_id}")
@@ -350,7 +398,7 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
     execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
     retries=3,
 )
-@with_error_handling("process_tracks")
+@with_error_handling(TaskName.PROCESS_TRACKS)
 async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
     """Spawn child workflows for each track (dynamic fan-out)."""
     ctx.log(f"process_tracks: spawning {len(input.tracks)} track workflows")
@@ -380,10 +428,10 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
     created_padded_files = set()

     for result in results:
-        transcribe_result = TranscribeTrackResult(**result["transcribe_track"])
+        transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK])
         track_words.append(transcribe_result.words)

-        pad_result = PadTrackResult(**result["pad_track"])
+        pad_result = PadTrackResult(**result[TaskName.PAD_TRACK])

         # Store S3 key info (not presigned URL) - consumer tasks presign on demand
         if pad_result.padded_key:
@@ -419,7 +467,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
     execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
     retries=3,
 )
-@with_error_handling("mixdown_tracks")
+@with_error_handling(TaskName.MIXDOWN_TRACKS)
 async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     """Mix all padded tracks into single audio file using PyAV (same as Celery)."""
     ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
@@ -480,6 +528,8 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
         target_sample_rate,
         offsets_seconds=None,
         logger=logger,
+        progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
+        expected_duration_sec=recording_duration if recording_duration > 0 else None,
     )

     await writer.flush()
@@ -514,7 +564,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
     retries=3,
 )
-@with_error_handling("generate_waveform")
+@with_error_handling(TaskName.GENERATE_WAVEFORM)
 async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
     """Generate audio waveform visualization using AudioWaveformProcessor (matches Celery)."""
     ctx.log(f"generate_waveform: transcript_id={input.transcript_id}")
@@ -582,7 +632,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
     execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
     retries=3,
 )
-@with_error_handling("detect_topics")
+@with_error_handling(TaskName.DETECT_TOPICS)
 async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     """Detect topics using parallel child workflows (one per chunk)."""
     ctx.log("detect_topics: analyzing transcript for topics")
@@ -645,7 +695,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     results = await topic_chunk_workflow.aio_run_many(bulk_runs)

     topic_chunks = [
-        TopicChunkResult(**result["detect_chunk_topic"]) for result in results
+        TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]) for result in results
     ]

     async with fresh_db_connection():
@@ -687,7 +737,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
     retries=3,
 )
-@with_error_handling("generate_title")
+@with_error_handling(TaskName.GENERATE_TITLE)
 async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     """Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
     ctx.log(f"generate_title: starting for transcript_id={input.transcript_id}")
@@ -752,7 +802,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
     retries=3,
 )
-@with_error_handling("extract_subjects")
+@with_error_handling(TaskName.EXTRACT_SUBJECTS)
 async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult:
     """Extract main subjects/topics from transcript for parallel processing."""
     ctx.log(f"extract_subjects: starting for transcript_id={input.transcript_id}")
@@ -773,9 +823,6 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
     # sharing DB connections and LLM HTTP pools across forks
     from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
     from reflector.llm import LLM  # noqa: PLC0415
-    from reflector.processors.summary.summary_builder import (  # noqa: PLC0415
-        SummaryBuilder,
-    )

     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -833,7 +880,7 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
     execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
     retries=3,
 )
-@with_error_handling("process_subjects")
+@with_error_handling(TaskName.PROCESS_SUBJECTS)
 async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubjectsResult:
     """Spawn child workflows for each subject (dynamic fan-out, parallel LLM calls)."""
     subjects_result = ctx.task_output(extract_subjects)
@@ -861,7 +908,7 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject
     results = await subject_workflow.aio_run_many(bulk_runs)

     subject_summaries = [
-        SubjectSummaryResult(**result["generate_detailed_summary"])
+        SubjectSummaryResult(**result[TaskName.GENERATE_DETAILED_SUMMARY])
         for result in results
     ]

@@ -875,7 +922,7 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject
     execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
     retries=3,
 )
-@with_error_handling("generate_recap")
+@with_error_handling(TaskName.GENERATE_RECAP)
 async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
     """Generate recap and long summary from subject summaries, save to database."""
     ctx.log(f"generate_recap: starting for transcript_id={input.transcript_id}")
@@ -891,11 +938,6 @@ async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
         transcripts_controller,
     )
     from reflector.llm import LLM  # noqa: PLC0415
-    from reflector.processors.summary.prompts import (  # noqa: PLC0415
-        RECAP_PROMPT,
-        build_participant_instructions,
-        build_summary_markdown,
-    )

     subject_summaries = process_result.subject_summaries

@@ -969,7 +1011,7 @@ async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
     execution_timeout=timedelta(seconds=TIMEOUT_LONG),
     retries=3,
 )
-@with_error_handling("identify_action_items")
+@with_error_handling(TaskName.IDENTIFY_ACTION_ITEMS)
 async def identify_action_items(
     input: PipelineInput, ctx: Context
 ) -> ActionItemsResult:
@@ -980,7 +1022,7 @@ async def identify_action_items(

     if not subjects_result.transcript_text:
         ctx.log("identify_action_items: no transcript text, returning empty")
-        return ActionItemsResult(action_items={"decisions": [], "next_steps": []})
+        return ActionItemsResult(action_items=ActionItemsResponse())

     # Deferred imports: Hatchet workers fork processes, fresh imports avoid
     # sharing DB connections and LLM HTTP pools across forks
@@ -989,9 +1031,6 @@ async def identify_action_items(
         transcripts_controller,
     )
     from reflector.llm import LLM  # noqa: PLC0415
-    from reflector.processors.summary.summary_builder import (  # noqa: PLC0415
-        SummaryBuilder,
-    )

     # TODO: refactor SummaryBuilder methods into standalone functions
     llm = LLM(settings=settings)
@@ -1010,11 +1049,11 @@ async def identify_action_items(
     if action_items_response is None:
         raise RuntimeError("Failed to identify action items - LLM call failed")

-    action_items_dict = action_items_response.model_dump()
-
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
+            # Serialize to dict for DB storage and WebSocket broadcast
+            action_items_dict = action_items_response.model_dump()
             action_items = TranscriptActionItems(action_items=action_items_dict)
             await transcripts_controller.update(
                 transcript, {"action_items": action_items.action_items}
@@ -1028,11 +1067,11 @@ async def identify_action_items(
             )

     ctx.log(
-        f"identify_action_items complete: {len(action_items_dict.get('decisions', []))} decisions, "
-        f"{len(action_items_dict.get('next_steps', []))} next steps"
+        f"identify_action_items complete: {len(action_items_response.decisions)} decisions, "
+        f"{len(action_items_response.next_steps)} next steps"
     )

-    return ActionItemsResult(action_items=action_items_dict)
+    return ActionItemsResult(action_items=action_items_response)


 @diarization_pipeline.task(
@@ -1040,7 +1079,7 @@
     execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
     retries=3,
 )
-@with_error_handling("finalize")
+@with_error_handling(TaskName.FINALIZE)
 async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
     """Finalize transcript: save words, emit TRANSCRIPT event, set status to 'ended'.

@@ -1123,7 +1162,7 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
 @diarization_pipeline.task(
     parents=[finalize], execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3
 )
-@with_error_handling("cleanup_consent", set_error_status=False)
+@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
 async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
     """Check consent and delete audio files if any participant denied."""
     ctx.log(f"cleanup_consent: transcript_id={input.transcript_id}")
@@ -1225,7 +1264,7 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
     execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
     retries=5,
 )
-@with_error_handling("post_zulip", set_error_status=False)
+@with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
 async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
     """Post notification to Zulip."""
     ctx.log(f"post_zulip: transcript_id={input.transcript_id}")
@@ -1252,7 +1291,7 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
     execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
     retries=30,
 )
-@with_error_handling("send_webhook", set_error_status=False)
+@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
 async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
     """Send completion webhook to external service."""
     ctx.log(f"send_webhook: transcript_id={input.transcript_id}")
diff --git a/server/reflector/hatchet/workflows/models.py b/server/reflector/hatchet/workflows/models.py
index c2a98901..1bad1f4a 100644
--- a/server/reflector/hatchet/workflows/models.py
+++ b/server/reflector/hatchet/workflows/models.py
@@ -7,6 +7,7 @@ and better IDE support.

 from pydantic import BaseModel

+from reflector.processors.summary.models import ActionItemsResponse
 from reflector.processors.types import TitleSummary, Word
 from reflector.utils.string import NonEmptyString

@@ -143,7 +144,7 @@ class RecapResult(BaseModel):
 class ActionItemsResult(BaseModel):
     """Result from identify_action_items task."""

-    action_items: dict  # ActionItemsResponse as dict (may have empty lists)
+    action_items: ActionItemsResponse


 class FinalizeResult(BaseModel):
diff --git a/server/reflector/processors/summary/models.py b/server/reflector/processors/summary/models.py
new file mode 100644
index 00000000..6f65d7f0
--- /dev/null
+++ b/server/reflector/processors/summary/models.py
@@ -0,0 +1,50 @@
+"""Pydantic models for summary processing."""
+
+from pydantic import BaseModel, Field
+
+
+class ActionItem(BaseModel):
+    """A single action item from the meeting"""
+
+    task: str = Field(description="The task or action item to be completed")
+    assigned_to: str | None = Field(
+        default=None, description="Person or team assigned to this task (name)"
+    )
+    assigned_to_participant_id: str | None = Field(
+        default=None, description="Participant ID if assigned_to matches a participant"
+    )
+    deadline: str | None = Field(
+        default=None, description="Deadline or timeframe mentioned for this task"
+    )
+    context: str | None = Field(
+        default=None, description="Additional context or notes about this task"
+    )
+
+
+class Decision(BaseModel):
+    """A decision made during the meeting"""
+
+    decision: str = Field(description="What was decided")
+    rationale: str | None = Field(
+        default=None,
+        description="Reasoning or key factors that influenced this decision",
+    )
+    decided_by: str | None = Field(
+        default=None, description="Person or group who made the decision (name)"
+    )
+    decided_by_participant_id: str | None = Field(
+        default=None, description="Participant ID if decided_by matches a participant"
+    )
+
+
+class ActionItemsResponse(BaseModel):
+    """Pydantic model for identified action items"""
+
+    decisions: list[Decision] = Field(
+        default_factory=list,
+        description="List of decisions made during the meeting",
+    )
+    next_steps: list[ActionItem] = Field(
+        default_factory=list,
+        description="List of action items and next steps to be taken",
+    )
diff --git a/server/reflector/processors/summary/summary_builder.py b/server/reflector/processors/summary/summary_builder.py
index f89e3730..fadcfa23 100644
--- a/server/reflector/processors/summary/summary_builder.py
+++ b/server/reflector/processors/summary/summary_builder.py
@@ -15,6 +15,7 @@ import structlog
 from pydantic import BaseModel, Field

 from reflector.llm import LLM
+from reflector.processors.summary.models import ActionItemsResponse
 from reflector.processors.summary.prompts import (
     DETAILED_SUBJECT_PROMPT_TEMPLATE,
     PARAGRAPH_SUMMARY_PROMPT,
@@ -148,53 +149,6 @@ class SubjectsResponse(BaseModel):
     )


-class ActionItem(BaseModel):
-    """A single action item from the meeting"""
-
-    task: str = Field(description="The task or action item to be completed")
-    assigned_to: str | None = Field(
-        default=None, description="Person or team assigned to this task (name)"
-    )
-    assigned_to_participant_id: str | None = Field(
-        default=None, description="Participant ID if assigned_to matches a participant"
-    )
-    deadline: str | None = Field(
-        default=None, description="Deadline or timeframe mentioned for this task"
-    )
-    context: str | None = Field(
-        default=None, description="Additional context or notes about this task"
-    )
-
-
-class Decision(BaseModel):
-    """A decision made during the meeting"""
-
-    decision: str = Field(description="What was decided")
-    rationale: str | None = Field(
-        default=None,
-        description="Reasoning or key factors that influenced this decision",
-    )
-    decided_by: str | None = Field(
-        default=None, description="Person or group who made the decision (name)"
-    )
-    decided_by_participant_id: str | None = Field(
-        default=None, description="Participant ID if decided_by matches a participant"
-    )
-
-
-class ActionItemsResponse(BaseModel):
-    """Pydantic model for identified action items"""
-
-    decisions: list[Decision] = Field(
-        default_factory=list,
-        description="List of decisions made during the meeting",
-    )
-    next_steps: list[ActionItem] = Field(
-        default_factory=list,
-        description="List of action items and next steps to be taken",
-    )
-
-
 class SummaryBuilder:
     def __init__(self, llm: LLM, filename: str | None = None, logger=None) -> None:
         self.transcript: str | None = None
diff --git a/server/reflector/utils/audio_mixdown.py b/server/reflector/utils/audio_mixdown.py
index 61654bb3..8e72cf1c 100644
--- a/server/reflector/utils/audio_mixdown.py
+++ b/server/reflector/utils/audio_mixdown.py
@@ -43,6 +43,8 @@ async def mixdown_tracks_pyav(
     target_sample_rate: int,
     offsets_seconds: list[float] | None = None,
     logger=None,
+    progress_callback=None,
+    expected_duration_sec: float | None = None,
 ) -> None:
     """Multi-track mixdown using PyAV filter graph (amix).

@@ -57,6 +59,10 @@ async def mixdown_tracks_pyav(
             If provided, must have same length as track_urls.
             Delays are relative to the minimum offset (earliest track has delay=0).
         logger: Optional logger instance
+        progress_callback: Optional callback(progress_pct: float | None, audio_position: float)
+            called on progress updates. progress_pct is 0-100 if duration known, None otherwise.
+            audio_position is current position in seconds.
+        expected_duration_sec: Optional fallback duration if container metadata unavailable.

     Raises:
         ValueError: If offsets_seconds length doesn't match track_urls,
@@ -171,6 +177,17 @@ async def mixdown_tracks_pyav(
         logger.error("Mixdown failed - no valid containers opened")
         raise ValueError("Mixdown failed: Could not open any track containers")

+    # Calculate total duration for progress reporting.
+    # Try container metadata first, fall back to expected_duration_sec if provided.
+    max_duration_sec = 0.0
+    for c in containers:
+        if c.duration is not None:
+            dur_sec = c.duration / av.time_base
+            max_duration_sec = max(max_duration_sec, dur_sec)
+    if max_duration_sec == 0.0 and expected_duration_sec:
+        max_duration_sec = expected_duration_sec
+    current_max_time = 0.0
+
     decoders = [c.decode(audio=0) for c in containers]
     active = [True] * len(decoders)
     resamplers = [
@@ -192,6 +209,18 @@ async def mixdown_tracks_pyav(
                 if frame.sample_rate != target_sample_rate:
                     continue
+
+                # Update progress based on frame timestamp
+                if progress_callback and frame.time is not None:
+                    current_max_time = max(current_max_time, frame.time)
+                    if max_duration_sec > 0:
+                        progress_pct = min(
+                            100.0, (current_max_time / max_duration_sec) * 100
+                        )
+                    else:
+                        progress_pct = None  # Duration unavailable
+                    progress_callback(progress_pct, current_max_time)
+

                 out_frames = resamplers[i].resample(frame) or []
                 for rf in out_frames:
                     rf.sample_rate = target_sample_rate
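
Reviewer notes: the sketches below are illustrative only and are not part of the patch.

The TaskName StrEnum now serves both as the @with_error_handling argument and as the key into child-workflow result dicts (e.g. result[TaskName.TRANSCRIBE_TRACK]). That lookup keeps working because StrEnum members are str instances, so they hash and compare equal to the plain-string keys Hatchet returns. A standalone sketch mirroring the enum in constants.py (requires Python 3.11+):

    from enum import StrEnum

    class TaskName(StrEnum):
        TRANSCRIBE_TRACK = "transcribe_track"
        PAD_TRACK = "pad_track"

    # Child-workflow outputs arrive keyed by plain task-name strings.
    result = {"transcribe_track": {"words": []}, "pad_track": {"padded_key": "k"}}

    assert TaskName.TRANSCRIBE_TRACK == "transcribe_track"
    assert result[TaskName.TRANSCRIBE_TRACK] == {"words": []}
    assert f"{TaskName.PAD_TRACK}" == "pad_track"  # StrEnum formats as its value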
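A usage sketch for make_audio_progress_logger (import paths as in this diff; StubCtx is a hypothetical stand-in, since any object with a .log() method satisfies the Loggable protocol and no real Hatchet Context is needed; importing the pipeline module may require worker configuration in a real environment):

    import time

    from reflector.hatchet.constants import TaskName
    from reflector.hatchet.workflows.diarization_pipeline import (
        make_audio_progress_logger,
    )

    class StubCtx:  # hypothetical stand-in for hatchet_sdk.Context
        def log(self, message: str) -> None:
            print(message)

    progress = make_audio_progress_logger(StubCtx(), TaskName.MIXDOWN_TRACKS, interval=1.0)

    # Fire the callback far more often than the interval; only a few lines
    # print, because throttling lives inside the returned callback.
    for i in range(300):
        progress(i / 3.0, float(i))  # (progress_pct, audio_position)
        time.sleep(0.01)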
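The progress math in mixdown_tracks_pyav, in isolation: a percentage is only reported when a total duration is known (container metadata via c.duration / av.time_base, else the new expected_duration_sec fallback); otherwise the callback receives None and the caller logs position only. A minimal sketch of that branch (the helper name is hypothetical):

    def mixdown_progress(current_max_time: float, max_duration_sec: float) -> float | None:
        # Mirrors the per-frame computation in the decode loop.
        if max_duration_sec > 0:
            return min(100.0, (current_max_time / max_duration_sec) * 100)
        return None  # duration unavailable

    assert mixdown_progress(30.0, 120.0) == 25.0
    assert mixdown_progress(150.0, 120.0) == 100.0  # clamped when past the estimate
    assert mixdown_progress(30.0, 0.0) is None      # falls back to position-only logging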
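With ActionItemsResult.action_items now typed as ActionItemsResponse, tasks pass the Pydantic model around and serialize exactly once, at the DB/WebSocket boundary. A sketch of that round trip (field values are made up):

    from reflector.processors.summary.models import ActionItem, ActionItemsResponse

    # Empty-input path used by identify_action_items:
    empty = ActionItemsResponse()
    assert empty.decisions == [] and empty.next_steps == []

    response = ActionItemsResponse(
        next_steps=[ActionItem(task="Wire mixdown progress into the UI")],
    )

    # model_dump() happens once, for DB storage and the WebSocket broadcast:
    action_items_dict = response.model_dump()
    assert action_items_dict["decisions"] == []
    assert action_items_dict["next_steps"][0]["task"].startswith("Wire")
    assert action_items_dict["next_steps"][0]["assigned_to"] is None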