diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 0f73fb1f..41204df6 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -183,16 +183,6 @@ class TranscriptEvent(BaseModel): data: dict -class PipelineProgressData(BaseModel): - """Data payload for PIPELINE_PROGRESS WebSocket events.""" - - workflow_id: str | None = None - current_step: str - step_index: int - total_steps: int - step_status: Literal["pending", "in_progress", "completed", "failed"] - - class TranscriptParticipant(BaseModel): model_config = ConfigDict(from_attributes=True) id: str = Field(default_factory=generate_uuid4) diff --git a/server/reflector/hatchet/__init__.py b/server/reflector/hatchet/__init__.py index 74ff6cc2..d56d559e 100644 --- a/server/reflector/hatchet/__init__.py +++ b/server/reflector/hatchet/__init__.py @@ -1,6 +1,5 @@ """Hatchet workflow orchestration for Reflector.""" from reflector.hatchet.client import HatchetClientManager -from reflector.hatchet.progress import emit_progress_async -__all__ = ["HatchetClientManager", "emit_progress_async"] +__all__ = ["HatchetClientManager"] diff --git a/server/reflector/hatchet/broadcast.py b/server/reflector/hatchet/broadcast.py new file mode 100644 index 00000000..684b8a02 --- /dev/null +++ b/server/reflector/hatchet/broadcast.py @@ -0,0 +1,82 @@ +"""WebSocket broadcasting helpers for Hatchet workflows. + +Provides WebSocket broadcasting for Hatchet that matches Celery's @broadcast_to_sockets +decorator behavior. Events are broadcast to transcript rooms and user rooms. +""" + +from reflector.db.transcripts import TranscriptEvent +from reflector.logger import logger +from reflector.ws_manager import get_ws_manager + +# Events that should also be sent to user room (matches Celery behavior) +USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION"} + + +async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None: + """Broadcast a TranscriptEvent to WebSocket subscribers. + + Fire-and-forget: errors are logged but don't interrupt workflow execution. + """ + try: + ws_manager = get_ws_manager() + + # Broadcast to transcript room + await ws_manager.send_json( + room_id=f"ts:{transcript_id}", + message=event.model_dump(mode="json"), + ) + + # Also broadcast to user room for certain events + if event.event in USER_ROOM_EVENTS: + # Deferred import to avoid circular dependency + from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 + + transcript = await transcripts_controller.get_by_id(transcript_id) + if transcript and transcript.user_id: + await ws_manager.send_json( + room_id=f"user:{transcript.user_id}", + message={ + "event": f"TRANSCRIPT_{event.event}", + "data": {"id": transcript_id, **event.data}, + }, + ) + except Exception as e: + logger.warning( + "[Hatchet Broadcast] Failed to broadcast event", + error=str(e), + transcript_id=transcript_id, + event=event.event, + ) + + +async def set_status_and_broadcast(transcript_id: str, status: str) -> None: + """Set transcript status and broadcast to WebSocket. + + Wrapper around transcripts_controller.set_status that adds WebSocket broadcasting. + """ + from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 + + event = await transcripts_controller.set_status(transcript_id, status) + if event: + await broadcast_event(transcript_id, event) + + +async def append_event_and_broadcast( + transcript_id: str, + transcript, # Transcript model + event_name: str, + data, # Pydantic model +) -> TranscriptEvent: + """Append event to transcript and broadcast to WebSocket. + + Wrapper around transcripts_controller.append_event that adds WebSocket broadcasting. + """ + from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 + + event = await transcripts_controller.append_event( + transcript=transcript, + event=event_name, + data=data, + ) + await broadcast_event(transcript_id, event) + return event diff --git a/server/reflector/hatchet/progress.py b/server/reflector/hatchet/progress.py deleted file mode 100644 index 411af9e6..00000000 --- a/server/reflector/hatchet/progress.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Progress event emission for Hatchet workers.""" - -from typing import Literal - -from reflector.db.transcripts import PipelineProgressData -from reflector.logger import logger -from reflector.ws_manager import get_ws_manager - -# Step mapping for progress tracking -PIPELINE_STEPS = { - "get_recording": 1, - "get_participants": 2, - "pad_track": 3, # Fork tasks share same step - "mixdown_tracks": 4, - "generate_waveform": 5, - "transcribe_track": 6, # Fork tasks share same step - "merge_transcripts": 7, - "detect_topics": 8, - "generate_title": 9, # Fork tasks share same step - "generate_summary": 9, # Fork tasks share same step - "finalize": 10, - "cleanup_consent": 11, - "post_zulip": 12, - "send_webhook": 13, -} - -TOTAL_STEPS = 13 - - -async def _emit_progress_async( - transcript_id: str, - step: str, - status: Literal["pending", "in_progress", "completed", "failed"], - workflow_id: str | None = None, -) -> None: - """Async implementation of progress emission.""" - ws_manager = get_ws_manager() - step_index = PIPELINE_STEPS.get(step, 0) - - data = PipelineProgressData( - workflow_id=workflow_id, - current_step=step, - step_index=step_index, - total_steps=TOTAL_STEPS, - step_status=status, - ) - - await ws_manager.send_json( - room_id=f"ts:{transcript_id}", - message={ - "event": "PIPELINE_PROGRESS", - "data": data.model_dump(), - }, - ) - - logger.debug( - "[Hatchet Progress] Emitted", - transcript_id=transcript_id, - step=step, - status=status, - step_index=step_index, - ) - - -async def emit_progress_async( - transcript_id: str, - step: str, - status: Literal["pending", "in_progress", "completed", "failed"], - workflow_id: str | None = None, -) -> None: - """Async version of emit_progress for use in async Hatchet tasks.""" - try: - await _emit_progress_async(transcript_id, step, status, workflow_id) - except Exception as e: - logger.warning( - "[Hatchet Progress] Failed to emit progress event", - error=str(e), - transcript_id=transcript_id, - step=step, - ) diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py index 8e9fda6f..ffc89d2d 100644 --- a/server/reflector/hatchet/workflows/diarization_pipeline.py +++ b/server/reflector/hatchet/workflows/diarization_pipeline.py @@ -25,8 +25,11 @@ from hatchet_sdk import Context from pydantic import BaseModel from reflector.dailyco_api.client import DailyApiClient +from reflector.hatchet.broadcast import ( + append_event_and_broadcast, + set_status_and_broadcast, +) from reflector.hatchet.client import HatchetClientManager -from reflector.hatchet.progress import emit_progress_async from reflector.hatchet.workflows.models import ( ConsentResult, FinalizeResult, @@ -55,32 +58,29 @@ from reflector.processors.types import ( ) from reflector.settings import settings from reflector.storage.storage_aws import AwsStorage +from reflector.utils.audio_constants import ( + PRESIGNED_URL_EXPIRATION_SECONDS, + WAVEFORM_SEGMENTS, +) from reflector.utils.audio_waveform import get_audio_waveform from reflector.utils.daily import ( filter_cam_audio_tracks, parse_daily_recording_filename, ) +from reflector.utils.string import NonEmptyString from reflector.zulip import post_transcript_notification -# Audio constants -OPUS_STANDARD_SAMPLE_RATE = 48000 -OPUS_DEFAULT_BIT_RATE = 64000 -PRESIGNED_URL_EXPIRATION_SECONDS = 7200 -WAVEFORM_SEGMENTS = 255 - class PipelineInput(BaseModel): """Input to trigger the diarization pipeline.""" - recording_id: str | None - room_name: str | None + recording_id: NonEmptyString tracks: list[dict] # List of {"s3_key": str} - bucket_name: str - transcript_id: str - room_id: str | None = None + bucket_name: NonEmptyString + transcript_id: NonEmptyString + room_id: NonEmptyString | None = None -# Get hatchet client and define workflow hatchet = HatchetClientManager.get_client() diarization_pipeline = hatchet.workflow( @@ -120,9 +120,7 @@ async def set_workflow_error_status(transcript_id: str) -> bool: """ try: async with fresh_db_connection(): - from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 - - await transcripts_controller.set_status(transcript_id, "error") + await set_status_and_broadcast(transcript_id, "error") logger.info( "[Hatchet] Set transcript status to error", transcript_id=transcript_id, @@ -181,9 +179,6 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab ) if set_error_status: await set_workflow_error_status(input.transcript_id) - await emit_progress_async( - input.transcript_id, step_name, "failed", ctx.workflow_run_id - ) raise return wrapper @@ -203,34 +198,18 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult: ctx.log(f"get_recording: recording_id={input.recording_id}") logger.info("[Hatchet] get_recording", recording_id=input.recording_id) - await emit_progress_async( - input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id - ) - - # Set transcript status to "processing" at workflow start + # Set transcript status to "processing" at workflow start (broadcasts to WebSocket) async with fresh_db_connection(): from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 transcript = await transcripts_controller.get_by_id(input.transcript_id) if transcript: - await transcripts_controller.set_status(input.transcript_id, "processing") + await set_status_and_broadcast(input.transcript_id, "processing") logger.info( "[Hatchet] Set transcript status to processing", transcript_id=input.transcript_id, ) - if not input.recording_id: - # No recording_id in reprocess path - return minimal data - await emit_progress_async( - input.transcript_id, "get_recording", "completed", ctx.workflow_run_id - ) - return RecordingResult( - id=None, - mtg_session_id=None, - room_name=input.room_name, - duration=0, - ) - if not settings.DAILY_API_KEY: raise ValueError("DAILY_API_KEY not configured") @@ -247,14 +226,9 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult: duration=recording.duration, ) - await emit_progress_async( - input.transcript_id, "get_recording", "completed", ctx.workflow_run_id - ) - return RecordingResult( id=recording.id, mtg_session_id=recording.mtgSessionId, - room_name=recording.room_name, duration=recording.duration, ) @@ -268,14 +242,9 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe ctx.log(f"get_participants: transcript_id={input.transcript_id}") logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "get_participants", "in_progress", ctx.workflow_run_id - ) - recording_data = _to_dict(ctx.task_output(get_recording)) mtg_session_id = recording_data.get("mtg_session_id") - # Get transcript and reset events/topics/participants async with fresh_db_connection(): from reflector.db.transcripts import ( # noqa: PLC0415 TranscriptParticipant, @@ -284,7 +253,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe transcript = await transcripts_controller.get_by_id(input.transcript_id) if transcript: - # Reset events/topics/participants # Note: title NOT cleared - preserves existing titles await transcripts_controller.update( transcript, @@ -296,12 +264,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe ) if not mtg_session_id or not settings.DAILY_API_KEY: - await emit_progress_async( - input.transcript_id, - "get_participants", - "completed", - ctx.workflow_run_id, - ) return ParticipantsResult( participants=[], num_tracks=len(input.tracks), @@ -309,7 +271,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe target_language=transcript.target_language if transcript else "en", ) - # Fetch participants from Daily API async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client: participants = await client.get_meeting_participants(mtg_session_id) @@ -321,11 +282,9 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe if p.user_id: id_to_user_id[p.participant_id] = p.user_id - # Get track keys and filter for cam-audio tracks track_keys = [t["s3_key"] for t in input.tracks] cam_audio_keys = filter_cam_audio_tracks(track_keys) - # Update participants in database participants_list = [] for idx, key in enumerate(cam_audio_keys): try: @@ -361,10 +320,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe participant_count=len(participants_list), ) - await emit_progress_async( - input.transcript_id, "get_participants", "completed", ctx.workflow_run_id - ) - return ParticipantsResult( participants=participants_list, num_tracks=len(input.tracks), @@ -389,7 +344,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes participants_data = _to_dict(ctx.task_output(get_participants)) source_language = participants_data.get("source_language", "en") - # Spawn child workflows for each track with correct language child_coroutines = [ track_workflow.aio_run( TrackInput( @@ -403,7 +357,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes for i, track in enumerate(input.tracks) ] - # Wait for all child workflows to complete results = await asyncio.gather(*child_coroutines) # Get target_language for later use in detect_topics @@ -428,13 +381,11 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes PaddedTrackInfo(key=padded_key, bucket_name=bucket_name) ) - # Track padded files for cleanup track_index = pad_result.get("track_index") if pad_result.get("size", 0) > 0 and track_index is not None: storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm" created_padded_files.add(storage_path) - # Merge all words and sort by start time all_words = [word for words in track_words for word in words] all_words.sort(key=lambda w: w.get("start", 0)) @@ -466,10 +417,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: ctx.log("mixdown_tracks: mixing padded tracks into single audio file") logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "mixdown_tracks", "in_progress", ctx.workflow_run_id - ) - track_data = _to_dict(ctx.task_output(process_tracks)) padded_tracks_data = track_data.get("padded_tracks", []) @@ -503,7 +450,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: if not valid_urls: raise ValueError("No valid padded tracks to mixdown") - # Determine target sample rate from first track target_sample_rate = None for url in valid_urls: container = None @@ -551,12 +497,10 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: fmt.link_to(sink) graph.configure() - # Create temp output file output_path = tempfile.mktemp(suffix=".mp3") containers = [] try: - # Open all containers for url in valid_urls: try: c = av.open( @@ -644,7 +588,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: except Exception: pass - # Upload mixed file to storage file_size = Path(output_path).stat().st_size storage_path = f"{input.transcript_id}/audio.mp3" @@ -653,7 +596,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: Path(output_path).unlink(missing_ok=True) - # Update transcript with audio_location async with fresh_db_connection(): from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 @@ -670,10 +612,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: size=file_size, ) - await emit_progress_async( - input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id - ) - return MixdownResult( audio_key=storage_path, duration=duration_ms[0], @@ -689,10 +627,6 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul """Generate audio waveform visualization using AudioWaveformProcessor (matches Celery).""" logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id - ) - from reflector.db.transcripts import ( # noqa: PLC0415 TranscriptWaveform, transcripts_controller, @@ -740,18 +674,16 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul with open(temp_path, "wb") as f: f.write(response.content) - # Generate waveform waveform = get_audio_waveform( path=Path(temp_path), segments_count=WAVEFORM_SEGMENTS ) - # Save waveform to database via event async with fresh_db_connection(): transcript = await transcripts_controller.get_by_id(input.transcript_id) if transcript: waveform_data = TranscriptWaveform(waveform=waveform) - await transcripts_controller.append_event( - transcript=transcript, event="WAVEFORM", data=waveform_data + await append_event_and_broadcast( + input.transcript_id, transcript, "WAVEFORM", waveform_data ) finally: @@ -759,10 +691,6 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul logger.info("[Hatchet] generate_waveform complete") - await emit_progress_async( - input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id - ) - return WaveformResult(waveform_generated=True) @@ -775,10 +703,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: ctx.log("detect_topics: analyzing transcript for topics") logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "detect_topics", "in_progress", ctx.workflow_run_id - ) - track_data = _to_dict(ctx.task_output(process_tracks)) words = track_data.get("all_words", []) target_language = track_data.get("target_language", "en") @@ -791,7 +715,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: TitleSummaryWithId as TitleSummaryWithIdProcessorType, ) - # Convert word dicts to Word objects word_objects = [Word(**w) for w in words] transcript_type = TranscriptType(words=word_objects) @@ -800,7 +723,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: async with fresh_db_connection(): transcript = await transcripts_controller.get_by_id(input.transcript_id) - # Callback that upserts topics to DB async def on_topic_callback(data): topic = TranscriptTopic( title=data.title, @@ -812,8 +734,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: if isinstance(data, TitleSummaryWithIdProcessorType): topic.id = data.id await transcripts_controller.upsert_topic(transcript, topic) - await transcripts_controller.append_event( - transcript=transcript, event="TOPIC", data=topic + await append_event_and_broadcast( + input.transcript_id, transcript, "TOPIC", topic ) topics = await topic_processing.detect_topics( @@ -828,10 +750,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: ctx.log(f"detect_topics complete: found {len(topics_list)} topics") logger.info("[Hatchet] detect_topics complete", topic_count=len(topics_list)) - await emit_progress_async( - input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id - ) - return TopicsResult(topics=topics_list) @@ -844,10 +762,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult: ctx.log("generate_title: generating title from topics") logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "generate_title", "in_progress", ctx.workflow_run_id - ) - topics_data = _to_dict(ctx.task_output(detect_topics)) topics = topics_data.get("topics", []) @@ -864,7 +778,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult: async with fresh_db_connection(): transcript = await transcripts_controller.get_by_id(input.transcript_id) - # Callback that updates title in DB async def on_title_callback(data): nonlocal title_result title_result = data.title @@ -874,8 +787,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult: transcript, {"title": final_title.title}, ) - await transcripts_controller.append_event( - transcript=transcript, event="FINAL_TITLE", data=final_title + await append_event_and_broadcast( + input.transcript_id, transcript, "FINAL_TITLE", final_title ) await topic_processing.generate_title( @@ -888,10 +801,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult: ctx.log(f"generate_title complete: '{title_result}'") logger.info("[Hatchet] generate_title complete", title=title_result) - await emit_progress_async( - input.transcript_id, "generate_title", "completed", ctx.workflow_run_id - ) - return TitleResult(title=title_result) @@ -904,10 +813,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult: ctx.log("generate_summary: generating long and short summaries") logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "generate_summary", "in_progress", ctx.workflow_run_id - ) - topics_data = _to_dict(ctx.task_output(detect_topics)) topics = topics_data.get("topics", []) @@ -926,7 +831,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult: async with fresh_db_connection(): transcript = await transcripts_controller.get_by_id(input.transcript_id) - # Callback that updates long_summary in DB async def on_long_summary_callback(data): nonlocal summary_result summary_result = data.long_summary @@ -937,13 +841,13 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult: transcript, {"long_summary": final_long_summary.long_summary}, ) - await transcripts_controller.append_event( - transcript=transcript, - event="FINAL_LONG_SUMMARY", - data=final_long_summary, + await append_event_and_broadcast( + input.transcript_id, + transcript, + "FINAL_LONG_SUMMARY", + final_long_summary, ) - # Callback that updates short_summary in DB async def on_short_summary_callback(data): nonlocal short_summary_result short_summary_result = data.short_summary @@ -954,10 +858,11 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult: transcript, {"short_summary": final_short_summary.short_summary}, ) - await transcripts_controller.append_event( - transcript=transcript, - event="FINAL_SHORT_SUMMARY", - data=final_short_summary, + await append_event_and_broadcast( + input.transcript_id, + transcript, + "FINAL_SHORT_SUMMARY", + final_short_summary, ) await topic_processing.generate_summaries( @@ -972,10 +877,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult: ctx.log("generate_summary complete") logger.info("[Hatchet] generate_summary complete") - await emit_progress_async( - input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id - ) - return SummaryResult(summary=summary_result, short_summary=short_summary_result) @@ -994,10 +895,6 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: ctx.log("finalize: saving transcript and setting status to 'ended'") logger.info("[Hatchet] finalize", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "finalize", "in_progress", ctx.workflow_run_id - ) - mixdown_data = _to_dict(ctx.task_output(mixdown_tracks)) track_data = _to_dict(ctx.task_output(process_tracks)) @@ -1006,6 +903,7 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: async with fresh_db_connection(): from reflector.db.transcripts import ( # noqa: PLC0415 + TranscriptDuration, TranscriptText, transcripts_controller, ) @@ -1018,17 +916,14 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: if transcript is None: raise ValueError(f"Transcript {input.transcript_id} not found in database") - # Convert words back to Word objects for storage word_objects = [Word(**w) for w in all_words] - - # Create merged transcript for TRANSCRIPT event merged_transcript = TranscriptType(words=word_objects, translation=None) - # Emit TRANSCRIPT event - await transcripts_controller.append_event( - transcript=transcript, - event="TRANSCRIPT", - data=TranscriptText( + await append_event_and_broadcast( + input.transcript_id, + transcript, + "TRANSCRIPT", + TranscriptText( text=merged_transcript.text, translation=merged_transcript.translation, ), @@ -1044,18 +939,18 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: }, ) - # Set status to "ended" - await transcripts_controller.set_status(input.transcript_id, "ended") + duration_data = TranscriptDuration(duration=duration) + await append_event_and_broadcast( + input.transcript_id, transcript, "DURATION", duration_data + ) + + await set_status_and_broadcast(input.transcript_id, "ended") ctx.log( f"finalize complete: transcript {input.transcript_id} status set to 'ended'" ) logger.info("[Hatchet] finalize complete", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "finalize", "completed", ctx.workflow_run_id - ) - return FinalizeResult(status="COMPLETED") @@ -1067,10 +962,6 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult: """Check and handle consent requirements.""" logger.info("[Hatchet] cleanup_consent", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "cleanup_consent", "in_progress", ctx.workflow_run_id - ) - async with fresh_db_connection(): from reflector.db.meetings import meetings_controller # noqa: PLC0415 from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 @@ -1087,10 +978,6 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult: "[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id ) - await emit_progress_async( - input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id - ) - return ConsentResult(consent_checked=True) @@ -1102,15 +989,8 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult: """Post notification to Zulip.""" logger.info("[Hatchet] post_zulip", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "post_zulip", "in_progress", ctx.workflow_run_id - ) - if not settings.ZULIP_REALM: logger.info("[Hatchet] post_zulip skipped (Zulip not configured)") - await emit_progress_async( - input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id - ) return ZulipResult(zulip_message_id=None, skipped=True) async with fresh_db_connection(): @@ -1123,10 +1003,6 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult: else: message_id = None - await emit_progress_async( - input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id - ) - return ZulipResult(zulip_message_id=message_id) @@ -1138,15 +1014,8 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult: """Send completion webhook to external service.""" logger.info("[Hatchet] send_webhook", transcript_id=input.transcript_id) - await emit_progress_async( - input.transcript_id, "send_webhook", "in_progress", ctx.workflow_run_id - ) - if not input.room_id: logger.info("[Hatchet] send_webhook skipped (no room_id)") - await emit_progress_async( - input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id - ) return WebhookResult(webhook_sent=False, skipped=True) async with fresh_db_connection(): @@ -1174,17 +1043,6 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult: "[Hatchet] send_webhook complete", status_code=response.status_code ) - await emit_progress_async( - input.transcript_id, - "send_webhook", - "completed", - ctx.workflow_run_id, - ) - return WebhookResult(webhook_sent=True, response_code=response.status_code) - await emit_progress_async( - input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id - ) - return WebhookResult(webhook_sent=False, skipped=True) diff --git a/server/reflector/hatchet/workflows/models.py b/server/reflector/hatchet/workflows/models.py index 7373e205..ad28edb4 100644 --- a/server/reflector/hatchet/workflows/models.py +++ b/server/reflector/hatchet/workflows/models.py @@ -40,7 +40,6 @@ class RecordingResult(BaseModel): id: str | None mtg_session_id: str | None - room_name: str | None duration: float diff --git a/server/reflector/hatchet/workflows/track_processing.py b/server/reflector/hatchet/workflows/track_processing.py index 434ce9a1..b709578c 100644 --- a/server/reflector/hatchet/workflows/track_processing.py +++ b/server/reflector/hatchet/workflows/track_processing.py @@ -26,9 +26,13 @@ from hatchet_sdk import Context from pydantic import BaseModel from reflector.hatchet.client import HatchetClientManager -from reflector.hatchet.progress import emit_progress_async from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult from reflector.logger import logger +from reflector.utils.audio_constants import ( + OPUS_DEFAULT_BIT_RATE, + OPUS_STANDARD_SAMPLE_RATE, + PRESIGNED_URL_EXPIRATION_SECONDS, +) def _to_dict(output) -> dict: @@ -38,12 +42,6 @@ def _to_dict(output) -> dict: return output.model_dump() -# Audio constants matching existing pipeline -OPUS_STANDARD_SAMPLE_RATE = 48000 -OPUS_DEFAULT_BIT_RATE = 64000 -PRESIGNED_URL_EXPIRATION_SECONDS = 7200 - - class TrackInput(BaseModel): """Input for individual track processing.""" @@ -193,10 +191,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: transcript_id=input.transcript_id, ) - await emit_progress_async( - input.transcript_id, "pad_track", "in_progress", ctx.workflow_run_id - ) - try: # Create fresh storage instance to avoid aioboto3 fork issues from reflector.settings import settings # noqa: PLC0415 @@ -229,9 +223,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: f"Track {input.track_index} requires no padding", track_index=input.track_index, ) - await emit_progress_async( - input.transcript_id, "pad_track", "completed", ctx.workflow_run_id - ) return PadTrackResult( padded_key=input.s3_key, bucket_name=input.bucket_name, @@ -275,10 +266,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: padded_key=storage_path, ) - await emit_progress_async( - input.transcript_id, "pad_track", "completed", ctx.workflow_run_id - ) - # Return S3 key (not presigned URL) - consumer tasks presign on demand # This avoids stale URLs when workflow is replayed return PadTrackResult( @@ -290,9 +277,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: except Exception as e: logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True) - await emit_progress_async( - input.transcript_id, "pad_track", "failed", ctx.workflow_run_id - ) raise @@ -308,10 +292,6 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe language=input.language, ) - await emit_progress_async( - input.transcript_id, "transcribe_track", "in_progress", ctx.workflow_run_id - ) - try: pad_result = _to_dict(ctx.task_output(pad_track)) padded_key = pad_result.get("padded_key") @@ -360,10 +340,6 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe word_count=len(words), ) - await emit_progress_async( - input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id - ) - return TranscribeTrackResult( words=words, track_index=input.track_index, @@ -371,7 +347,4 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe except Exception as e: logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True) - await emit_progress_async( - input.transcript_id, "transcribe_track", "failed", ctx.workflow_run_id - ) raise diff --git a/server/reflector/pipelines/main_multitrack_pipeline.py b/server/reflector/pipelines/main_multitrack_pipeline.py index 579bfbd3..26f42c4f 100644 --- a/server/reflector/pipelines/main_multitrack_pipeline.py +++ b/server/reflector/pipelines/main_multitrack_pipeline.py @@ -32,6 +32,11 @@ from reflector.processors.audio_waveform_processor import AudioWaveformProcessor from reflector.processors.types import TitleSummary from reflector.processors.types import Transcript as TranscriptType from reflector.storage import Storage, get_transcripts_storage +from reflector.utils.audio_constants import ( + OPUS_DEFAULT_BIT_RATE, + OPUS_STANDARD_SAMPLE_RATE, + PRESIGNED_URL_EXPIRATION_SECONDS, +) from reflector.utils.daily import ( filter_cam_audio_tracks, parse_daily_recording_filename, @@ -39,13 +44,6 @@ from reflector.utils.daily import ( from reflector.utils.string import NonEmptyString from reflector.video_platforms.factory import create_platform_client -# Audio encoding constants -OPUS_STANDARD_SAMPLE_RATE = 48000 -OPUS_DEFAULT_BIT_RATE = 128000 - -# Storage operation constants -PRESIGNED_URL_EXPIRATION_SECONDS = 7200 # 2 hours - class PipelineMainMultitrack(PipelineMainBase): def __init__(self, transcript_id: str): diff --git a/server/reflector/services/transcript_process.py b/server/reflector/services/transcript_process.py index f1fdec85..4c97faf0 100644 --- a/server/reflector/services/transcript_process.py +++ b/server/reflector/services/transcript_process.py @@ -251,7 +251,6 @@ async def dispatch_transcript_processing( workflow_name="DiarizationPipeline", input_data={ "recording_id": config.recording_id, - "room_name": None, "tracks": [{"s3_key": k} for k in config.track_keys], "bucket_name": config.bucket_name, "transcript_id": config.transcript_id, diff --git a/server/reflector/utils/audio_constants.py b/server/reflector/utils/audio_constants.py new file mode 100644 index 00000000..0ad964a9 --- /dev/null +++ b/server/reflector/utils/audio_constants.py @@ -0,0 +1,15 @@ +""" +Shared audio processing constants. + +Used by both Hatchet workflows and Celery pipelines for consistent audio encoding. +""" + +# Opus codec settings +OPUS_STANDARD_SAMPLE_RATE = 48000 +OPUS_DEFAULT_BIT_RATE = 128000 # 128kbps for good speech quality + +# S3 presigned URL expiration +PRESIGNED_URL_EXPIRATION_SECONDS = 7200 # 2 hours + +# Waveform visualization +WAVEFORM_SEGMENTS = 255 diff --git a/server/reflector/worker/process.py b/server/reflector/worker/process.py index 9a1ad9f6..801d5bd5 100644 --- a/server/reflector/worker/process.py +++ b/server/reflector/worker/process.py @@ -303,7 +303,6 @@ async def _process_multitrack_recording_inner( workflow_name="DiarizationPipeline", input_data={ "recording_id": recording_id, - "room_name": daily_room_name, "tracks": [{"s3_key": k} for k in filter_cam_audio_tracks(track_keys)], "bucket_name": bucket_name, "transcript_id": transcript.id, diff --git a/server/tests/test_hatchet_progress.py b/server/tests/test_hatchet_progress.py deleted file mode 100644 index 059f68e0..00000000 --- a/server/tests/test_hatchet_progress.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Tests for Hatchet progress emission. - -Only tests that catch real bugs - error handling and step completeness. -""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - - -@pytest.mark.asyncio -async def test_emit_progress_async_handles_exception(): - """Test that emit_progress_async catches exceptions gracefully. - - Critical: Progress emission must NEVER crash the pipeline. - WebSocket errors should be silently caught. - """ - from reflector.hatchet.progress import emit_progress_async - - with patch("reflector.hatchet.progress.get_ws_manager") as mock_get_ws: - mock_ws = MagicMock() - mock_ws.send_json = AsyncMock(side_effect=Exception("WebSocket error")) - mock_get_ws.return_value = mock_ws - - # Should not raise - exceptions are caught - await emit_progress_async( - transcript_id="test-transcript-123", - step="finalize", - status="completed", - ) - - -@pytest.mark.asyncio -async def test_pipeline_steps_mapping_complete(): - """Test the PIPELINE_STEPS mapping includes all expected steps. - - Useful: Catches when someone adds a new pipeline step but forgets - to add it to the progress mapping, resulting in missing UI updates. - """ - from reflector.hatchet.progress import PIPELINE_STEPS, TOTAL_STEPS - - expected_steps = [ - "get_recording", - "get_participants", - "pad_track", - "mixdown_tracks", - "generate_waveform", - "transcribe_track", - "merge_transcripts", - "detect_topics", - "generate_title", - "generate_summary", - "finalize", - "cleanup_consent", - "post_zulip", - "send_webhook", - ] - - for step in expected_steps: - assert step in PIPELINE_STEPS, f"Missing step in PIPELINE_STEPS: {step}" - assert 1 <= PIPELINE_STEPS[step] <= TOTAL_STEPS