diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py
index dbbda268..d3a92906 100644
--- a/server/reflector/hatchet/workflows/diarization_pipeline.py
+++ b/server/reflector/hatchet/workflows/diarization_pipeline.py
@@ -623,6 +623,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     topics = topics_result.topics
 
     from reflector.db.transcripts import (  # noqa: PLC0415
+        TranscriptActionItems,
         TranscriptFinalLongSummary,
         TranscriptFinalShortSummary,
         transcripts_controller,
@@ -633,6 +634,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
     summary_result = None
     short_summary_result = None
+    action_items_result = None
 
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -673,18 +675,39 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
                 logger=logger,
             )
 
+        async def on_action_items_callback(data):
+            nonlocal action_items_result
+            action_items_result = data.action_items
+            action_items = TranscriptActionItems(action_items=data.action_items)
+            await transcripts_controller.update(
+                transcript,
+                {"action_items": action_items.action_items},
+            )
+            await append_event_and_broadcast(
+                input.transcript_id,
+                transcript,
+                "ACTION_ITEMS",
+                action_items,
+                logger=logger,
+            )
+
         await topic_processing.generate_summaries(
             topic_objects,
-            transcript,  # DB transcript for context
+            transcript,
             on_long_summary_callback=on_long_summary_callback,
             on_short_summary_callback=on_short_summary_callback,
+            on_action_items_callback=on_action_items_callback,
             empty_pipeline=empty_pipeline,
             logger=logger,
         )
 
     ctx.log("generate_summary complete")
 
-    return SummaryResult(summary=summary_result, short_summary=short_summary_result)
+    return SummaryResult(
+        summary=summary_result,
+        short_summary=short_summary_result,
+        action_items=action_items_result,
+    )
 
 
 @diarization_pipeline.task(
diff --git a/server/reflector/hatchet/workflows/models.py b/server/reflector/hatchet/workflows/models.py
index bc3577cf..adc3407e 100644
--- a/server/reflector/hatchet/workflows/models.py
+++ b/server/reflector/hatchet/workflows/models.py
@@ -96,6 +96,7 @@ class SummaryResult(BaseModel):
     summary: str | None
     short_summary: str | None
+    action_items: dict | None = None
 
 
 class FinalizeResult(BaseModel):
diff --git a/server/reflector/worker/process.py b/server/reflector/worker/process.py
index ae98c8f1..414092e8 100644
--- a/server/reflector/worker/process.py
+++ b/server/reflector/worker/process.py
@@ -812,6 +812,11 @@ async def reprocess_failed_daily_recordings():
                 )
                 continue
 
+            # Fetch room to check use_hatchet flag
+            room = None
+            if meeting.room_id:
+                room = await rooms_controller.get_by_id(meeting.room_id)
+
             transcript = None
             try:
                 transcript = await transcripts_controller.get_by_recording_id(
@@ -831,20 +836,62 @@ async def reprocess_failed_daily_recordings():
                     )
                     continue
 
-            logger.info(
-                "Queueing Daily recording for reprocessing",
-                recording_id=recording.id,
-                room_name=meeting.room_name,
-                track_count=len(recording.track_keys),
-                transcript_status=transcript.status if transcript else None,
-            )
+            use_hatchet = settings.HATCHET_ENABLED or (room and room.use_hatchet)
+
+            if use_hatchet:
+                # Hatchet requires a transcript for workflow_run_id tracking
+                if not transcript:
+                    logger.warning(
+                        "No transcript for Hatchet reprocessing, skipping",
+                        recording_id=recording.id,
+                    )
+                    continue
+
+                workflow_id = await HatchetClientManager.start_workflow(
+                    workflow_name="DiarizationPipeline",
+                    input_data={
+                        "recording_id": recording.id,
+                        "tracks": [
+                            {"s3_key": k}
+                            for k in filter_cam_audio_tracks(recording.track_keys)
+                        ],
+                        "bucket_name": bucket_name,
+                        "transcript_id": transcript.id,
+                        "room_id": room.id if room else None,
+                    },
+                    additional_metadata={
+                        "transcript_id": transcript.id,
+                        "recording_id": recording.id,
+                        "reprocess": True,
+                    },
+                )
+                await transcripts_controller.update(
+                    transcript, {"workflow_run_id": workflow_id}
+                )
+
+                logger.info(
+                    "Queued Daily recording for Hatchet reprocessing",
+                    recording_id=recording.id,
+                    workflow_id=workflow_id,
+                    room_name=meeting.room_name,
+                    track_count=len(recording.track_keys),
+                )
+            else:
+                logger.info(
+                    "Queueing Daily recording for Celery reprocessing",
+                    recording_id=recording.id,
+                    room_name=meeting.room_name,
+                    track_count=len(recording.track_keys),
+                    transcript_status=transcript.status if transcript else None,
+                )
+
+                process_multitrack_recording.delay(
+                    bucket_name=bucket_name,
+                    daily_room_name=meeting.room_name,
+                    recording_id=recording.id,
+                    track_keys=recording.track_keys,
+                )
 
-            process_multitrack_recording.delay(
-                bucket_name=bucket_name,
-                daily_room_name=meeting.room_name,
-                recording_id=recording.id,
-                track_keys=recording.track_keys,
-            )
             reprocessed_count += 1
 
         except Exception as e: