diff --git a/server/migrations/versions/0f943fede0e0_add_workflow_run_id_to_transcript.py b/server/migrations/versions/0f943fede0e0_add_workflow_run_id_to_transcript.py
new file mode 100644
index 00000000..cd1857c1
--- /dev/null
+++ b/server/migrations/versions/0f943fede0e0_add_workflow_run_id_to_transcript.py
@@ -0,0 +1,28 @@
+"""add workflow_run_id to transcript
+
+Revision ID: 0f943fede0e0
+Revises: a326252ac554
+Create Date: 2025-12-16 01:54:13.855106
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "0f943fede0e0"
+down_revision: Union[str, None] = "a326252ac554"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("transcript", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("workflow_run_id", sa.String(), nullable=True))
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("transcript", schema=None) as batch_op:
+        batch_op.drop_column("workflow_run_id")
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index fd1a7a5e..0f73fb1f 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -83,6 +83,8 @@ transcripts = sqlalchemy.Table(
     sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
     sqlalchemy.Column("room_id", sqlalchemy.String),
     sqlalchemy.Column("webvtt", sqlalchemy.Text),
+    # Hatchet workflow run ID for resumption of failed workflows
+    sqlalchemy.Column("workflow_run_id", sqlalchemy.String),
     sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
     sqlalchemy.Index("idx_transcript_user_id", "user_id"),
     sqlalchemy.Index("idx_transcript_created_at", "created_at"),
@@ -227,6 +229,7 @@ class Transcript(BaseModel):
     zulip_message_id: int | None = None
     audio_deleted: bool | None = None
     webvtt: str | None = None
+    workflow_run_id: str | None = None  # Hatchet workflow run ID for resumption
 
     @field_serializer("created_at", when_used="json")
     def serialize_datetime(self, dt: datetime) -> str:
diff --git a/server/reflector/hatchet/client.py b/server/reflector/hatchet/client.py
index bc3a63f0..2d48bb12 100644
--- a/server/reflector/hatchet/client.py
+++ b/server/reflector/hatchet/client.py
@@ -2,6 +2,7 @@
 
 from hatchet_sdk import Hatchet
 
+from reflector.logger import logger
 from reflector.settings import settings
 
 
@@ -35,9 +36,44 @@ class HatchetClientManager:
         # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
         return result.run.metadata.id
 
+    @classmethod
+    async def get_workflow_run_status(cls, workflow_run_id: str) -> str:
+        """Get workflow run status."""
+        client = cls.get_client()
+        status = await client.runs.aio_get_status(workflow_run_id)
+        return str(status)
+
+    @classmethod
+    async def cancel_workflow(cls, workflow_run_id: str) -> None:
+        """Cancel a workflow."""
+        client = cls.get_client()
+        await client.runs.aio_cancel(workflow_run_id)
+        logger.info("[Hatchet] Cancelled workflow", workflow_run_id=workflow_run_id)
+
+    @classmethod
+    async def replay_workflow(cls, workflow_run_id: str) -> None:
+        """Replay a failed workflow."""
+        client = cls.get_client()
+        await client.runs.aio_replay(workflow_run_id)
+        logger.info("[Hatchet] Replaying workflow", workflow_run_id=workflow_run_id)
+
+    @classmethod
+    async def can_replay(cls, workflow_run_id: str) -> bool:
+        """Check if workflow can be replayed (is FAILED)."""
+        try:
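+            # get_workflow_run_status() stringifies the SDK status enum, so a
+            # substring check avoids depending on its exact repr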
+            status = await cls.get_workflow_run_status(workflow_run_id)
+            return "FAILED" in status
+        except Exception as e:
+            logger.warning(
+                "[Hatchet] Failed to check replay status",
+                workflow_run_id=workflow_run_id,
+                error=str(e),
+            )
+            return False
+
     @classmethod
     async def get_workflow_status(cls, workflow_run_id: str) -> dict:
-        """Get the current status of a workflow run."""
+        """Get the full workflow run details as dict."""
         client = cls.get_client()
         run = await client.runs.aio_get(workflow_run_id)
         return run.to_dict()
diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py
index 4bbae444..94c31242 100644
--- a/server/reflector/hatchet/workflows/diarization_pipeline.py
+++ b/server/reflector/hatchet/workflows/diarization_pipeline.py
@@ -71,6 +71,28 @@ async def _close_db_connection(db):
     _database_context.set(None)
 
 
+async def _set_error_status(transcript_id: str):
+    """Set transcript status to 'error' on workflow failure (matches Celery line 790)."""
+    try:
+        db = await _get_fresh_db_connection()
+        try:
+            from reflector.db.transcripts import transcripts_controller
+
+            await transcripts_controller.set_status(transcript_id, "error")
+            logger.info(
+                "[Hatchet] Set transcript status to error",
+                transcript_id=transcript_id,
+            )
+        finally:
+            await _close_db_connection(db)
+    except Exception as e:
+        logger.error(
+            "[Hatchet] Failed to set error status",
+            transcript_id=transcript_id,
+            error=str(e),
+        )
+
+
 def _get_storage():
     """Create fresh storage instance."""
     from reflector.settings import settings
@@ -98,6 +120,21 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
     )
 
+    # Set transcript status to "processing" at workflow start (matches Celery behavior)
+    db = await _get_fresh_db_connection()
+    try:
+        from reflector.db.transcripts import transcripts_controller
+
+        transcript = await transcripts_controller.get_by_id(input.transcript_id)
+        if transcript:
+            await transcripts_controller.set_status(input.transcript_id, "processing")
+            logger.info(
+                "[Hatchet] Set transcript status to processing",
+                transcript_id=input.transcript_id,
+            )
+    finally:
+        await _close_db_connection(db)
+
     try:
         from reflector.dailyco_api.client import DailyApiClient
         from reflector.settings import settings
@@ -140,6 +177,7 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
 
     except Exception as e:
         logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "get_recording", "failed", ctx.workflow_run_id
         )
@@ -150,7 +188,10 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
     parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3
 )
 async def get_participants(input: PipelineInput, ctx: Context) -> dict:
-    """Fetch participant list from Daily.co API."""
+    """Fetch participant list from Daily.co API and update transcript in database.
+
+    Matches Celery's update_participants_from_daily() behavior.
+    """
     logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -163,38 +204,118 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
     try:
         from reflector.dailyco_api.client import DailyApiClient
         from reflector.settings import settings
-
-        if not mtg_session_id or not settings.DAILY_API_KEY:
-            # Return empty participants if no session ID
-            await emit_progress_async(
-                input.transcript_id,
-                "get_participants",
-                "completed",
-                ctx.workflow_run_id,
-            )
-            return {"participants": [], "num_tracks": len(input.tracks)}
-
-        async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
-            participants = await client.get_meeting_participants(mtg_session_id)
-
-        participants_list = [
-            {"participant_id": p.participant_id, "user_name": p.user_name}
-            for p in participants.data
-        ]
-
-        logger.info(
-            "[Hatchet] get_participants complete",
-            participant_count=len(participants_list),
+        from reflector.utils.daily import (
+            filter_cam_audio_tracks,
+            parse_daily_recording_filename,
         )
 
+        # Get transcript and reset events/topics/participants (matches Celery lines 599-607)
+        db = await _get_fresh_db_connection()
+        try:
+            from reflector.db.transcripts import (
+                TranscriptParticipant,
+                transcripts_controller,
+            )
+
+            transcript = await transcripts_controller.get_by_id(input.transcript_id)
+            if transcript:
+                # Note: title NOT cleared - Celery preserves existing titles
+                await transcripts_controller.update(
+                    transcript,
+                    {
+                        "events": [],
+                        "topics": [],
+                        "participants": [],
+                    },
+                )
+
+            if not mtg_session_id or not settings.DAILY_API_KEY:
+                await emit_progress_async(
+                    input.transcript_id,
+                    "get_participants",
+                    "completed",
+                    ctx.workflow_run_id,
+                )
+                return {
+                    "participants": [],
+                    "num_tracks": len(input.tracks),
+                    "source_language": transcript.source_language
+                    if transcript
+                    else "en",
+                    "target_language": transcript.target_language
+                    if transcript
+                    else "en",
+                }
+
+            # Fetch participants from Daily API
+            async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
+                participants = await client.get_meeting_participants(mtg_session_id)
+
+            id_to_name = {}
+            id_to_user_id = {}
+            for p in participants.data:
+                if p.user_name:
+                    id_to_name[p.participant_id] = p.user_name
+                if p.user_id:
+                    id_to_user_id[p.participant_id] = p.user_id
+
+            # Get track keys and filter for cam-audio tracks
+            track_keys = [t["s3_key"] for t in input.tracks]
+            cam_audio_keys = filter_cam_audio_tracks(track_keys)
+
+            # Update participants in database (matches Celery lines 568-590)
+            participants_list = []
+            for idx, key in enumerate(cam_audio_keys):
+                try:
+                    parsed = parse_daily_recording_filename(key)
+                    participant_id = parsed.participant_id
+                except ValueError as e:
+                    logger.error(
+                        "Failed to parse Daily recording filename",
+                        error=str(e),
+                        key=key,
+                    )
+                    continue
+
+                default_name = f"Speaker {idx}"
+                name = id_to_name.get(participant_id, default_name)
+                user_id = id_to_user_id.get(participant_id)
+
+                participant = TranscriptParticipant(
+                    id=participant_id, speaker=idx, name=name, user_id=user_id
+                )
+                await transcripts_controller.upsert_participant(transcript, participant)
+                participants_list.append(
+                    {
+                        "participant_id": participant_id,
+                        "user_name": name,
+                        "speaker": idx,
+                    }
+                )
+
+            logger.info(
+                "[Hatchet] get_participants complete",
+                participant_count=len(participants_list),
+            )
+
+        finally:
+            await _close_db_connection(db)
+
         await emit_progress_async(
             input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
         )
 
-        return {"participants": participants_list, "num_tracks": len(input.tracks)}
+        return {
+            "participants": participants_list,
+            "num_tracks": len(input.tracks),
+            "source_language": transcript.source_language if transcript else "en",
+            "target_language": transcript.target_language if transcript else "en",
+        }
 
     except Exception as e:
         logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "get_participants", "failed", ctx.workflow_run_id
         )
@@ -215,55 +336,87 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
         transcript_id=input.transcript_id,
     )
 
-    # Spawn child workflows for each track
-    child_coroutines = [
-        track_workflow.aio_run(
-            TrackInput(
-                track_index=i,
-                s3_key=track["s3_key"],
-                bucket_name=input.bucket_name,
-                transcript_id=input.transcript_id,
-            )
-        )
-        for i, track in enumerate(input.tracks)
-    ]
-
-    # Wait for all child workflows to complete
-    results = await asyncio.gather(*child_coroutines)
-
-    # Collect all track results
-    all_words = []
-    padded_urls = []
-
-    for result in results:
-        transcribe_result = result.get("transcribe_track", {})
-        all_words.extend(transcribe_result.get("words", []))
-
-        pad_result = result.get("pad_track", {})
-        padded_urls.append(pad_result.get("padded_url"))
-
-    # Sort words by start time
-    all_words.sort(key=lambda w: w.get("start", 0))
-
-    logger.info(
-        "[Hatchet] process_tracks complete",
-        num_tracks=len(input.tracks),
-        total_words=len(all_words),
-    )
-
-    return {
-        "all_words": all_words,
-        "padded_urls": padded_urls,
-        "word_count": len(all_words),
-        "num_tracks": len(input.tracks),
-    }
+    try:
+        # Get source_language from get_participants (matches Celery: uses transcript.source_language)
+        participants_data = ctx.task_output(get_participants)
+        source_language = participants_data.get("source_language", "en")
+
+        # Spawn child workflows for each track with correct language
+        child_coroutines = [
+            track_workflow.aio_run(
+                TrackInput(
+                    track_index=i,
+                    s3_key=track["s3_key"],
+                    bucket_name=input.bucket_name,
+                    transcript_id=input.transcript_id,
+                    language=source_language,
+                )
+            )
+            for i, track in enumerate(input.tracks)
+        ]
+
+        # Wait for all child workflows to complete
+        results = await asyncio.gather(*child_coroutines)
+
+        # Get target_language for later use in detect_topics
+        target_language = participants_data.get("target_language", "en")
+
+        # Collect all track results
+        all_words = []
+        padded_urls = []
+        created_padded_files = set()
+
+        for result in results:
+            transcribe_result = result.get("transcribe_track", {})
+            all_words.extend(transcribe_result.get("words", []))
+
+            pad_result = result.get("pad_track", {})
+            padded_urls.append(pad_result.get("padded_url"))
+
+            # Track padded files for cleanup (matches Celery lines 636-637)
+            track_index = pad_result.get("track_index")
+            if pad_result.get("size", 0) > 0 and track_index is not None:
+                # File was created (size > 0 means padding was applied)
+                storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
+                created_padded_files.add(storage_path)
+
+        # Sort words by start time
+        all_words.sort(key=lambda w: w.get("start", 0))
+
+        # NOTE: Cleanup of padded S3 files moved to generate_waveform (after mixdown completes)
+        # Mixdown needs the padded files, so we can't delete them here
+
+        logger.info(
+            "[Hatchet] process_tracks complete",
+            num_tracks=len(input.tracks),
+            total_words=len(all_words),
+        )
+
+        return {
+            "all_words": all_words,
+            "padded_urls": padded_urls,
+            "word_count": len(all_words),
+            "num_tracks": len(input.tracks),
+            "target_language": target_language,
+            "created_padded_files": list(created_padded_files),  # For cleanup after mixdown
+        }
+
+    except Exception as e:
+        logger.error("[Hatchet] process_tracks failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
+        await emit_progress_async(
+            input.transcript_id, "process_tracks", "failed", ctx.workflow_run_id
+        )
+        raise
 
 
 @diarization_pipeline.task(
     parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3
 )
 async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
-    """Mix all padded tracks into single audio file."""
+    """Mix all padded tracks into a single audio file using PyAV (same as Celery)."""
     logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -279,80 +432,182 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
 
         storage = _get_storage()
 
-        # Download all tracks and mix
-        temp_inputs = []
-        try:
-            for i, url in enumerate(padded_urls):
-                if not url:
-                    continue
-                temp_input = tempfile.NamedTemporaryFile(suffix=".webm", delete=False)
-                temp_inputs.append(temp_input.name)
-
-                # Download track
-                import httpx
-
-                async with httpx.AsyncClient() as client:
-                    response = await client.get(url)
-                    response.raise_for_status()
-                    with open(temp_input.name, "wb") as f:
-                        f.write(response.content)
-
-            # Mix using PyAV amix filter
-            if len(temp_inputs) == 0:
-                raise ValueError("No valid tracks to mixdown")
-
-            output_path = tempfile.mktemp(suffix=".mp3")
-
-            try:
-                # Use ffmpeg-style mixing via PyAV
-                containers = [av.open(path) for path in temp_inputs]
-
-                # Get the longest duration
-                max_duration = 0.0
-                for container in containers:
-                    if container.duration:
-                        duration = float(container.duration * av.time_base)
-                        max_duration = max(max_duration, duration)
-
-                # Close containers for now
-                for container in containers:
-                    container.close()
-
-                # Use subprocess for mixing (simpler than complex PyAV graph)
-                import subprocess
-
-                # Build ffmpeg command
-                cmd = ["ffmpeg", "-y"]
-                for path in temp_inputs:
-                    cmd.extend(["-i", path])
-
-                # Build filter for N inputs
-                n = len(temp_inputs)
-                filter_str = f"amix=inputs={n}:duration=longest:normalize=0"
-                cmd.extend(["-filter_complex", filter_str])
-                cmd.extend(["-ac", "2", "-ar", "48000", "-b:a", "128k", output_path])
-
-                subprocess.run(cmd, check=True, capture_output=True)
-
-                # Upload mixed file
-                file_size = Path(output_path).stat().st_size
-                storage_path = f"file_pipeline_hatchet/{input.transcript_id}/mixed.mp3"
-
-                with open(output_path, "rb") as mixed_file:
-                    await storage.put_file(storage_path, mixed_file)
-
-                logger.info(
-                    "[Hatchet] mixdown_tracks uploaded",
-                    key=storage_path,
-                    size=file_size,
-                )
-
-            finally:
-                Path(output_path).unlink(missing_ok=True)
-
-        finally:
-            for path in temp_inputs:
-                Path(path).unlink(missing_ok=True)
+        # Mirror PipelineMainMultitrack.mixdown_tracks, which uses a PyAV filter graph
+        from fractions import Fraction
+
+        from av.audio.resampler import AudioResampler
+
+        from reflector.processors import AudioFileWriterProcessor
+
+        valid_urls = [url for url in padded_urls if url]
+        if not valid_urls:
+            raise ValueError("No valid padded tracks to mixdown")
+
+        # Determine target sample rate from first track
+        target_sample_rate = None
+        for url in valid_urls:
+            try:
+                container = av.open(url)
+                for frame in container.decode(audio=0):
+                    target_sample_rate = frame.sample_rate
+                    break
+                container.close()
+                if target_sample_rate:
+                    break
+            except Exception:
+                continue
+
+        if not target_sample_rate:
+            raise ValueError("No decodable audio frames in any track")
+
+        # Build PyAV filter graph: N abuffer -> amix -> aformat -> sink
+        graph = av.filter.Graph()
+        inputs = []
+
+        for idx, url in enumerate(valid_urls):
+            args = (
+                f"time_base=1/{target_sample_rate}:"
+                f"sample_rate={target_sample_rate}:"
+                f"sample_fmt=s32:"
+                f"channel_layout=stereo"
+            )
+            in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
+            inputs.append(in_ctx)
+
+        mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
+        fmt = graph.add(
+            "aformat",
+            args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
+            name="fmt",
+        )
+        sink = graph.add("abuffersink", name="out")
+
+        for idx, in_ctx in enumerate(inputs):
+            in_ctx.link_to(mixer, 0, idx)
+        mixer.link_to(fmt)
+        fmt.link_to(sink)
+        graph.configure()
+
+        # Create temp output file
+        output_path = tempfile.mktemp(suffix=".mp3")
+        containers = []
+
+        try:
+            # Open all containers
+            for url in valid_urls:
+                try:
+                    c = av.open(
+                        url,
+                        options={
+                            "reconnect": "1",
+                            "reconnect_streamed": "1",
+                            "reconnect_delay_max": "5",
+                        },
+                    )
+                    containers.append(c)
+                except Exception as e:
+                    logger.warning(
+                        "[Hatchet] mixdown: failed to open container",
+                        url=url,
+                        error=str(e),
+                    )
+
+            if not containers:
+                raise ValueError("Could not open any track containers")
+
+            # Create AudioFileWriterProcessor for MP3 output with duration capture
+            duration_ms = [0.0]  # Mutable container for callback capture
+
+            async def capture_duration(d):
+                duration_ms[0] = d
+
+            writer = AudioFileWriterProcessor(
+                path=output_path, on_duration=capture_duration
+            )
+
+            decoders = [c.decode(audio=0) for c in containers]
+            active = [True] * len(decoders)
+            resamplers = [
+                AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
+                for _ in decoders
+            ]
+
+            while any(active):
+                for i, (dec, is_active) in enumerate(zip(decoders, active)):
+                    if not is_active:
+                        continue
+                    try:
+                        frame = next(dec)
+                    except StopIteration:
+                        active[i] = False
+                        inputs[i].push(None)
+                        continue
+
+                    # Drop frames whose sample rate differs from the target;
+                    # only sample format/layout are resampled below
+                    if frame.sample_rate != target_sample_rate:
+                        continue
+                    out_frames = resamplers[i].resample(frame) or []
+                    for rf in out_frames:
+                        rf.sample_rate = target_sample_rate
+                        rf.time_base = Fraction(1, target_sample_rate)
+                        inputs[i].push(rf)
+
+                while True:
+                    try:
+                        mixed = sink.pull()
+                    except Exception:
+                        break
+                    mixed.sample_rate = target_sample_rate
+                    mixed.time_base = Fraction(1, target_sample_rate)
+                    await writer.push(mixed)
+
+            # Flush remaining frames
+            while True:
+                try:
+                    mixed = sink.pull()
+                except Exception:
+                    break
+                mixed.sample_rate = target_sample_rate
+                mixed.time_base = Fraction(1, target_sample_rate)
+                await writer.push(mixed)
+
+            await writer.flush()
+
+            # Duration is captured via callback in milliseconds (from AudioFileWriterProcessor)
+        finally:
+            for c in containers:
+                try:
+                    c.close()
+                except Exception:
+                    pass
+
+        # Upload mixed file to correct path (matches Celery: {transcript.id}/audio.mp3)
+        file_size = Path(output_path).stat().st_size
+        storage_path = f"{input.transcript_id}/audio.mp3"
+
+        with open(output_path, "rb") as mixed_file:
+            await storage.put_file(storage_path, mixed_file)
+
+        Path(output_path).unlink(missing_ok=True)
+
+        # Update transcript with audio_location (matches Celery line 661)
+        db = await _get_fresh_db_connection()
+        try:
+            from reflector.db.transcripts import transcripts_controller
+
+            transcript = await transcripts_controller.get_by_id(input.transcript_id)
+            if transcript:
+                await transcripts_controller.update(
+                    transcript, {"audio_location": "storage"}
+                )
+        finally:
+            await _close_db_connection(db)
+
+        logger.info(
+            "[Hatchet] mixdown_tracks uploaded",
+            key=storage_path,
+            size=file_size,
+        )
 
         await emit_progress_async(
             input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
        )
 
         return {
             "audio_key": storage_path,
-            "duration": max_duration,
-            "tracks_mixed": len(temp_inputs),
+            "duration": duration_ms[0],  # milliseconds, from AudioFileWriterProcessor
+            "tracks_mixed": len(valid_urls),
         }
 
     except Exception as e:
         logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id
         )
@@ -376,7 +634,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
     parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3
 )
 async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
-    """Generate audio waveform visualization."""
+    """Generate audio waveform visualization via get_audio_waveform (matches Celery)."""
     logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -384,6 +642,35 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
     )
 
     try:
+        import httpx
+
+        from reflector.db.transcripts import TranscriptWaveform, transcripts_controller
+        from reflector.utils.audio_waveform import get_audio_waveform
+
+        # Cleanup temporary padded S3 files (matches Celery lines 710-725)
+        # Moved here from process_tracks because mixdown_tracks needs the padded files
+        track_data = ctx.task_output(process_tracks)
+        created_padded_files = track_data.get("created_padded_files", [])
+        if created_padded_files:
+            logger.info(
+                f"[Hatchet] Cleaning up {len(created_padded_files)} temporary S3 files"
+            )
+            storage = _get_storage()
+            cleanup_tasks = []
+            for storage_path in created_padded_files:
+                cleanup_tasks.append(storage.delete_file(storage_path))
+
+            cleanup_results = await asyncio.gather(
+                *cleanup_tasks, return_exceptions=True
+            )
+            for storage_path, result in zip(created_padded_files, cleanup_results):
+                if isinstance(result, Exception):
+                    logger.warning(
+                        "[Hatchet] Failed to cleanup temporary padded track",
+                        storage_path=storage_path,
+                        error=str(result),
+                    )
+
         mixdown_data = ctx.task_output(mixdown_tracks)
         audio_key = mixdown_data.get("audio_key")
 
@@ -394,18 +681,34 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
             expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
         )
 
-        from reflector.pipelines.waveform_helpers import generate_waveform_data
+        # Download MP3 to temp file (get_audio_waveform needs a local file)
+        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
+            temp_path = temp_file.name
 
-        waveform = await generate_waveform_data(audio_url)
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.get(audio_url, timeout=120)
+                response.raise_for_status()
+                with open(temp_path, "wb") as f:
+                    f.write(response.content)
 
-        # Store waveform
-        waveform_key = f"file_pipeline_hatchet/{input.transcript_id}/waveform.json"
-        import json
+            # Generate waveform (matches Celery: get_audio_waveform with 255 segments)
+            waveform = get_audio_waveform(path=Path(temp_path), segments_count=255)
 
-        waveform_bytes = json.dumps(waveform).encode()
-        import io
+            # Save waveform to database via event (matches Celery on_waveform callback)
+            db = await _get_fresh_db_connection()
+            try:
+                transcript = await transcripts_controller.get_by_id(input.transcript_id)
+                if transcript:
+                    waveform_data = TranscriptWaveform(waveform=waveform)
+                    await transcripts_controller.append_event(
+                        transcript=transcript, event="WAVEFORM", data=waveform_data
+                    )
+            finally:
+                await _close_db_connection(db)
 
-        await storage.put_file(waveform_key, io.BytesIO(waveform_bytes))
+        finally:
+            Path(temp_path).unlink(missing_ok=True)
 
         logger.info("[Hatchet] generate_waveform complete")
 
         await emit_progress_async(
@@ -413,10 +716,11 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
             input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
         )
 
-        return {"waveform_key": waveform_key}
+        return {"waveform_generated": True}
 
     except Exception as e:
         logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id
         )
@@ -427,7 +731,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
     parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3
 )
 async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
-    """Detect topics using LLM."""
+    """Detect topics using LLM and save to database (matches Celery on_topic callback)."""
     logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -437,26 +741,52 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
     try:
         track_data = ctx.task_output(process_tracks)
         words = track_data.get("all_words", [])
+        target_language = track_data.get("target_language", "en")
 
+        from reflector.db.transcripts import TranscriptTopic, transcripts_controller
         from reflector.pipelines import topic_processing
+        from reflector.processors.types import (
+            TitleSummaryWithId as TitleSummaryWithIdProcessorType,
+        )
         from reflector.processors.types import Transcript as TranscriptType
         from reflector.processors.types import Word
 
         # Convert word dicts to Word objects
         word_objects = [Word(**w) for w in words]
-        transcript = TranscriptType(words=word_objects)
+        transcript_type = TranscriptType(words=word_objects)
 
         empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
 
-        async def noop_callback(t):
-            pass
+        # Get DB connection for callbacks
+        db = await _get_fresh_db_connection()
 
-        topics = await topic_processing.detect_topics(
-            transcript,
-            "en",  # target_language
-            on_topic_callback=noop_callback,
-            empty_pipeline=empty_pipeline,
-        )
+        try:
+            transcript = await transcripts_controller.get_by_id(input.transcript_id)
+
+            # Callback that upserts topics to DB (matches Celery on_topic)
+            async def on_topic_callback(data):
+                topic = TranscriptTopic(
+                    title=data.title,
+                    summary=data.summary,
+                    timestamp=data.timestamp,
+                    transcript=data.transcript.text,
+                    words=data.transcript.words,
+                )
+                if isinstance(data, TitleSummaryWithIdProcessorType):
+                    topic.id = data.id
+                await transcripts_controller.upsert_topic(transcript, topic)
+                await transcripts_controller.append_event(
+                    transcript=transcript, event="TOPIC", data=topic
+                )
+
+            topics = await topic_processing.detect_topics(
+                transcript_type,
+                target_language,
+                on_topic_callback=on_topic_callback,
+                empty_pipeline=empty_pipeline,
+            )
+        finally:
+            await _close_db_connection(db)
 
         topics_list = [t.model_dump() for t in topics]
 
@@ -470,6 +800,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
 
     except Exception as e:
         logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id
         )
@@ -480,7 +811,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
     parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3
 )
 async def generate_title(input: PipelineInput, ctx: Context) -> dict:
-    """Generate meeting title using LLM."""
+    """Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
     logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -491,23 +822,56 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
         topics_data = ctx.task_output(detect_topics)
         topics = topics_data.get("topics", [])
 
+        from reflector.db.transcripts import (
+            TranscriptFinalTitle,
+            transcripts_controller,
+        )
         from reflector.pipelines import topic_processing
-        from reflector.processors.types import Topic
+        from reflector.processors.types import TitleSummary
 
-        topic_objects = [Topic(**t) for t in topics]
+        topic_objects = [TitleSummary(**t) for t in topics]
 
-        title = await topic_processing.generate_title(topic_objects)
+        empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
+        title_result = None
 
-        logger.info("[Hatchet] generate_title complete", title=title)
+        db = await _get_fresh_db_connection()
+        try:
+            transcript = await transcripts_controller.get_by_id(input.transcript_id)
+
+            # Callback that updates title in DB (matches Celery on_title)
+            async def on_title_callback(data):
+                nonlocal title_result
+                title_result = data.title
+                final_title = TranscriptFinalTitle(title=data.title)
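+                # Only fill in the title if none exists yet; existing titles
+                # are preserved (see the reset note in get_participants)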
+                if not transcript.title:
+                    await transcripts_controller.update(
+                        transcript,
+                        {"title": final_title.title},
+                    )
+                await transcripts_controller.append_event(
+                    transcript=transcript, event="FINAL_TITLE", data=final_title
+                )
+
+            await topic_processing.generate_title(
+                topic_objects,
+                on_title_callback=on_title_callback,
+                empty_pipeline=empty_pipeline,
+                logger=logger,
+            )
+        finally:
+            await _close_db_connection(db)
+
+        logger.info("[Hatchet] generate_title complete", title=title_result)
 
         await emit_progress_async(
             input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
         )
 
-        return {"title": title}
+        return {"title": title_result}
 
     except Exception as e:
         logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "generate_title", "failed", ctx.workflow_run_id
         )
@@ -518,7 +882,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
     parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3
 )
 async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
-    """Generate meeting summary using LLM."""
+    """Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
     logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -526,23 +890,71 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
     )
 
     try:
-        track_data = ctx.task_output(process_tracks)
         topics_data = ctx.task_output(detect_topics)
-
-        words = track_data.get("all_words", [])
         topics = topics_data.get("topics", [])
 
-        from reflector.pipelines import topic_processing
-        from reflector.processors.types import Topic, Word
-        from reflector.processors.types import Transcript as TranscriptType
-
-        word_objects = [Word(**w) for w in words]
-        transcript = TranscriptType(words=word_objects)
-        topic_objects = [Topic(**t) for t in topics]
-
-        summary, short_summary = await topic_processing.generate_summary(
-            transcript, topic_objects
+        from reflector.db.transcripts import (
+            TranscriptFinalLongSummary,
+            TranscriptFinalShortSummary,
+            transcripts_controller,
         )
+        from reflector.pipelines import topic_processing
+        from reflector.processors.types import TitleSummary
+
+        topic_objects = [TitleSummary(**t) for t in topics]
+
+        empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
+        summary_result = None
+        short_summary_result = None
+
+        db = await _get_fresh_db_connection()
+        try:
+            transcript = await transcripts_controller.get_by_id(input.transcript_id)
+
+            # Callback that updates long_summary in DB (matches Celery on_long_summary)
+            async def on_long_summary_callback(data):
+                nonlocal summary_result
+                summary_result = data.long_summary
+                final_long_summary = TranscriptFinalLongSummary(
+                    long_summary=data.long_summary
+                )
+                await transcripts_controller.update(
+                    transcript,
+                    {"long_summary": final_long_summary.long_summary},
+                )
+                await transcripts_controller.append_event(
+                    transcript=transcript,
+                    event="FINAL_LONG_SUMMARY",
+                    data=final_long_summary,
+                )
+
+            # Callback that updates short_summary in DB (matches Celery on_short_summary)
+            async def on_short_summary_callback(data):
+                nonlocal short_summary_result
+                short_summary_result = data.short_summary
+                final_short_summary = TranscriptFinalShortSummary(
+                    short_summary=data.short_summary
+                )
+                await transcripts_controller.update(
+                    transcript,
+                    {"short_summary": final_short_summary.short_summary},
+                )
+                await transcripts_controller.append_event(
+                    transcript=transcript,
+                    event="FINAL_SHORT_SUMMARY",
+                    data=final_short_summary,
+                )
+
+            await topic_processing.generate_summaries(
+                topic_objects,
+                transcript,  # DB transcript for context
+                on_long_summary_callback=on_long_summary_callback,
+                on_short_summary_callback=on_short_summary_callback,
+                empty_pipeline=empty_pipeline,
+                logger=logger,
+            )
+        finally:
+            await _close_db_connection(db)
 
         logger.info("[Hatchet] generate_summary complete")
 
@@ -550,10 +962,11 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
             input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
         )
 
-        return {"summary": summary, "short_summary": short_summary}
+        return {"summary": summary_result, "short_summary": short_summary_result}
 
     except Exception as e:
         logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id
         )
@@ -566,7 +979,11 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
     retries=3,
 )
 async def finalize(input: PipelineInput, ctx: Context) -> dict:
-    """Finalize transcript status and update database."""
+    """Finalize transcript: save words, emit TRANSCRIPT event, set status to 'ended'.
+
+    Matches Celery's on_transcript + set_status behavior.
+    Note: Title and summaries are already saved by their respective task callbacks.
+    """
     logger.info("[Hatchet] finalize", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -574,21 +991,17 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
     )
 
     try:
-        title_data = ctx.task_output(generate_title)
-        summary_data = ctx.task_output(generate_summary)
         mixdown_data = ctx.task_output(mixdown_tracks)
         track_data = ctx.task_output(process_tracks)
 
-        title = title_data.get("title", "")
-        summary = summary_data.get("summary", "")
-        short_summary = summary_data.get("short_summary", "")
         duration = mixdown_data.get("duration", 0)
         all_words = track_data.get("all_words", [])
 
         db = await _get_fresh_db_connection()
         try:
-            from reflector.db.transcripts import transcripts_controller
+            from reflector.db.transcripts import TranscriptText, transcripts_controller
+            from reflector.processors.types import Transcript as TranscriptType
             from reflector.processors.types import Word
 
             transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -600,18 +1013,32 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
             # Convert words back to Word objects for storage
             word_objects = [Word(**w) for w in all_words]
 
+            # Create merged transcript for TRANSCRIPT event (matches Celery lines 734-736)
+            merged_transcript = TranscriptType(words=word_objects, translation=None)
+
+            # Emit TRANSCRIPT event (matches Celery on_transcript callback)
+            await transcripts_controller.append_event(
+                transcript=transcript,
+                event="TRANSCRIPT",
+                data=TranscriptText(
+                    text=merged_transcript.text,
+                    translation=merged_transcript.translation,
+                ),
+            )
+
+            # Save duration and clear workflow_run_id (workflow completed successfully)
+            # Note: title/long_summary/short_summary already saved by their callbacks
             await transcripts_controller.update(
                 transcript,
                 {
-                    "status": "ended",
-                    "title": title,
-                    "long_summary": summary,
-                    "short_summary": short_summary,
                     "duration": duration,
-                    "words": word_objects,
+                    "workflow_run_id": None,  # Clear on success - no need to resume
                 },
             )
 
+            # Set status to "ended" (matches Celery line 745)
+            await transcripts_controller.set_status(input.transcript_id, "ended")
+
             logger.info(
                 "[Hatchet] finalize complete", transcript_id=input.transcript_id
             )
@@ -627,6 +1054,7 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
 
     except Exception as e:
         logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "finalize", "failed", ctx.workflow_run_id
         )
diff --git a/server/reflector/processors/summary/summary_builder.py b/server/reflector/processors/summary/summary_builder.py
index df348093..eb76d94f 100644
--- a/server/reflector/processors/summary/summary_builder.py
+++ b/server/reflector/processors/summary/summary_builder.py
@@ -166,6 +166,7 @@ class SummaryBuilder:
         self.model_name: str = llm.model_name
         self.logger = logger or structlog.get_logger()
         self.participant_instructions: str | None = None
+        self._logged_participant_instructions: bool = False
 
         if filename:
             self.read_transcript_from_file(filename)
@@ -208,7 +209,9 @@ class SummaryBuilder:
     def _enhance_prompt_with_participants(self, prompt: str) -> str:
         """Add participant instructions to any prompt if participants are known."""
         if self.participant_instructions:
-            self.logger.debug("Adding participant instructions to prompt")
+            if not self._logged_participant_instructions:
+                self.logger.debug("Adding participant instructions to prompts")
+                self._logged_participant_instructions = True
             return f"{prompt}\n\n{self.participant_instructions}"
         return prompt
diff --git a/server/reflector/services/transcript_process.py b/server/reflector/services/transcript_process.py
index 379d5aae..1c386a86 100644
--- a/server/reflector/services/transcript_process.py
+++ b/server/reflector/services/transcript_process.py
@@ -102,6 +102,7 @@ async def validate_transcript_for_processing(
     if transcript.status == "idle":
         return ValidationNotReady(detail="Recording is not ready for processing")
 
+    # Check Celery tasks
     if task_is_scheduled_or_active(
         "reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
         transcript_id=transcript.id,
@@ -111,6 +112,23 @@ async def validate_transcript_for_processing(
     ):
         return ValidationAlreadyScheduled(detail="already running")
 
+    # Check Hatchet workflows (if enabled)
+    if settings.HATCHET_ENABLED and transcript.workflow_run_id:
+        from reflector.hatchet.client import HatchetClientManager
+
+        try:
+            status = await HatchetClientManager.get_workflow_run_status(
+                transcript.workflow_run_id
+            )
+            # If workflow is running or queued, don't allow new processing
+            if "RUNNING" in status or "QUEUED" in status:
+                return ValidationAlreadyScheduled(
+                    detail="Hatchet workflow already running"
+                )
+        except Exception:
+            # If we can't get status, allow processing (workflow might be gone)
+            pass
+
     return ValidationOk(
         recording_id=transcript.recording_id, transcript_id=transcript.id
     )
@@ -155,7 +173,9 @@ async def prepare_transcript_processing(
     )
 
 
-def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | None:
+def dispatch_transcript_processing(
+    config: ProcessingConfig, force: bool = False
+) -> AsyncResult | None:
     if isinstance(config, MultitrackProcessingConfig):
         # Start durable workflow if enabled (Hatchet or Conductor)
         durable_started = False
@@ -163,18 +183,69 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | No
         if settings.HATCHET_ENABLED:
             import asyncio
 
-            async def _start_hatchet():
-                return await HatchetClientManager.start_workflow(
-                    workflow_name="DiarizationPipeline",
-                    input_data={
-                        "recording_id": config.recording_id,
-                        "room_name": None,  # Not available in reprocess path
-                        "tracks": [{"s3_key": k} for k in config.track_keys],
-                        "bucket_name": config.bucket_name,
-                        "transcript_id": config.transcript_id,
-                        "room_id": config.room_id,
-                    },
-                )
+            import databases
+
+            from reflector.db import _database_context
+            from reflector.db.transcripts import transcripts_controller
+
+            async def _handle_hatchet():
+                db = databases.Database(settings.DATABASE_URL)
+                _database_context.set(db)
+                await db.connect()
+
+                try:
+                    transcript = await transcripts_controller.get_by_id(
+                        config.transcript_id
+                    )
+
+                    if transcript and transcript.workflow_run_id and not force:
+                        can_replay = await HatchetClientManager.can_replay(
+                            transcript.workflow_run_id
+                        )
+                        if can_replay:
+                            await HatchetClientManager.replay_workflow(
+                                transcript.workflow_run_id
+                            )
+                            logger.info(
+                                "Replaying Hatchet workflow",
+                                workflow_id=transcript.workflow_run_id,
+                            )
+                            return transcript.workflow_run_id
+
+                    # Force: cancel old workflow if exists
+                    if force and transcript and transcript.workflow_run_id:
+                        await HatchetClientManager.cancel_workflow(
+                            transcript.workflow_run_id
+                        )
+                        logger.info(
+                            "Cancelled old workflow (--force)",
+                            workflow_id=transcript.workflow_run_id,
+                        )
+                        await transcripts_controller.update(
+                            transcript, {"workflow_run_id": None}
+                        )
+
+                    workflow_id = await HatchetClientManager.start_workflow(
+                        workflow_name="DiarizationPipeline",
+                        input_data={
+                            "recording_id": config.recording_id,
+                            "room_name": None,
+                            "tracks": [{"s3_key": k} for k in config.track_keys],
+                            "bucket_name": config.bucket_name,
+                            "transcript_id": config.transcript_id,
+                            "room_id": config.room_id,
+                        },
+                    )
+
+                    if transcript:
+                        await transcripts_controller.update(
+                            transcript, {"workflow_run_id": workflow_id}
+                        )
+
+                    return workflow_id
+                finally:
+                    await db.disconnect()
+                    _database_context.set(None)
 
             try:
                 loop = asyncio.get_running_loop()
@@ -182,19 +253,14 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | No
                 loop = None
 
             if loop and loop.is_running():
-                # Already in async context
                 import concurrent.futures
 
                 with concurrent.futures.ThreadPoolExecutor() as pool:
-                    workflow_id = pool.submit(asyncio.run, _start_hatchet()).result()
+                    workflow_id = pool.submit(asyncio.run, _handle_hatchet()).result()
             else:
-                workflow_id = asyncio.run(_start_hatchet())
+                workflow_id = asyncio.run(_handle_hatchet())
 
-            logger.info(
-                "Started Hatchet workflow (reprocess)",
-                workflow_id=workflow_id,
-                transcript_id=config.transcript_id,
-            )
+            logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
 
             durable_started = True
 
         elif settings.CONDUCTOR_ENABLED:
diff --git a/server/reflector/tools/process_transcript.py b/server/reflector/tools/process_transcript.py
index cb8ade76..98db0307 100644
--- a/server/reflector/tools/process_transcript.py
+++ b/server/reflector/tools/process_transcript.py
@@ -34,21 +34,25 @@ async def process_transcript_inner(
     transcript: Transcript,
     on_validation: Callable[[ValidationResult], None],
     on_preprocess: Callable[[PrepareResult], None],
+    force: bool = False,
 ) -> AsyncResult:
     validation = await validate_transcript_for_processing(transcript)
     on_validation(validation)
     config = await prepare_transcript_processing(validation, room_id=transcript.room_id)
     on_preprocess(config)
-    return dispatch_transcript_processing(config)
+    return dispatch_transcript_processing(config, force=force)
 
 
-async def process_transcript(transcript_id: str, sync: bool = False) -> None:
+async def process_transcript(
+    transcript_id: str, sync: bool = False, force: bool = False
+) -> None:
     """
     Process a transcript by ID, auto-detecting multitrack vs file pipeline.
 
     Args:
         transcript_id: The transcript UUID
         sync: If True, wait for task completion. If False, dispatch and exit.
+        force: If True, cancel any old workflow and start a new one (runs the
+            latest code). If False, replay the failed workflow when one exists.
     """
     from reflector.db import get_database
 
@@ -82,7 +86,10 @@ async def process_transcript(transcript_id: str, sync: bool = False) -> None:
         print(f"Dispatching file pipeline", file=sys.stderr)
 
     result = await process_transcript_inner(
-        transcript, on_validation=on_validation, on_preprocess=on_preprocess
+        transcript,
+        on_validation=on_validation,
+        on_preprocess=on_preprocess,
+        force=force,
     )
 
     if sync:
@@ -118,9 +125,16 @@ def main():
         action="store_true",
         help="Wait for task completion instead of just dispatching",
     )
+    parser.add_argument(
+        "--force",
+        action="store_true",
+        help="Cancel old workflow and start a new one (uses latest code instead of replaying)",
+    )
     args = parser.parse_args()
 
-    asyncio.run(process_transcript(args.transcript_id, sync=args.sync))
+    asyncio.run(
+        process_transcript(args.transcript_id, sync=args.sync, force=args.force)
+    )
 
 
 if __name__ == "__main__":
     main()