Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-21 04:39:06 +00:00)

Commit: self-review round
@@ -183,16 +183,6 @@ class TranscriptEvent(BaseModel):
     data: dict


-class PipelineProgressData(BaseModel):
-    """Data payload for PIPELINE_PROGRESS WebSocket events."""
-
-    workflow_id: str | None = None
-    current_step: str
-    step_index: int
-    total_steps: int
-    step_status: Literal["pending", "in_progress", "completed", "failed"]
-
-
 class TranscriptParticipant(BaseModel):
     model_config = ConfigDict(from_attributes=True)
     id: str = Field(default_factory=generate_uuid4)
@@ -1,6 +1,5 @@
 """Hatchet workflow orchestration for Reflector."""

 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.progress import emit_progress_async

-__all__ = ["HatchetClientManager", "emit_progress_async"]
+__all__ = ["HatchetClientManager"]
server/reflector/hatchet/broadcast.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+"""WebSocket broadcasting helpers for Hatchet workflows.
+
+Provides WebSocket broadcasting for Hatchet that matches Celery's @broadcast_to_sockets
+decorator behavior. Events are broadcast to transcript rooms and user rooms.
+"""
+
+from reflector.db.transcripts import TranscriptEvent
+from reflector.logger import logger
+from reflector.ws_manager import get_ws_manager
+
+# Events that should also be sent to user room (matches Celery behavior)
+USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION"}
+
+
+async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None:
+    """Broadcast a TranscriptEvent to WebSocket subscribers.
+
+    Fire-and-forget: errors are logged but don't interrupt workflow execution.
+    """
+    try:
+        ws_manager = get_ws_manager()
+
+        # Broadcast to transcript room
+        await ws_manager.send_json(
+            room_id=f"ts:{transcript_id}",
+            message=event.model_dump(mode="json"),
+        )
+
+        # Also broadcast to user room for certain events
+        if event.event in USER_ROOM_EVENTS:
+            # Deferred import to avoid circular dependency
+            from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
+
+            transcript = await transcripts_controller.get_by_id(transcript_id)
+            if transcript and transcript.user_id:
+                await ws_manager.send_json(
+                    room_id=f"user:{transcript.user_id}",
+                    message={
+                        "event": f"TRANSCRIPT_{event.event}",
+                        "data": {"id": transcript_id, **event.data},
+                    },
+                )
+    except Exception as e:
+        logger.warning(
+            "[Hatchet Broadcast] Failed to broadcast event",
+            error=str(e),
+            transcript_id=transcript_id,
+            event=event.event,
+        )
+
+
+async def set_status_and_broadcast(transcript_id: str, status: str) -> None:
+    """Set transcript status and broadcast to WebSocket.
+
+    Wrapper around transcripts_controller.set_status that adds WebSocket broadcasting.
+    """
+    from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
+
+    event = await transcripts_controller.set_status(transcript_id, status)
+    if event:
+        await broadcast_event(transcript_id, event)
+
+
+async def append_event_and_broadcast(
+    transcript_id: str,
+    transcript,  # Transcript model
+    event_name: str,
+    data,  # Pydantic model
+) -> TranscriptEvent:
+    """Append event to transcript and broadcast to WebSocket.
+
+    Wrapper around transcripts_controller.append_event that adds WebSocket broadcasting.
+    """
+    from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415

+    event = await transcripts_controller.append_event(
+        transcript=transcript,
+        event=event_name,
+        data=data,
+    )
+    await broadcast_event(transcript_id, event)
+    return event
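Note: the three helpers above are the replacement for per-step progress emission; the DB mutation and the WebSocket fan-out now happen in one call. A minimal sketch of the intended call pattern from a workflow task (the task and payload names here are illustrative, not from this commit):

```python
# Illustrative call pattern for the new helpers; the task body and the
# final_title payload variable are hypothetical.
from reflector.hatchet.broadcast import (
    append_event_and_broadcast,
    set_status_and_broadcast,
)


async def example_step(transcript_id: str, transcript, final_title) -> None:
    # Persists the event, then broadcasts it to the "ts:<id>" room.
    await append_event_and_broadcast(
        transcript_id, transcript, "FINAL_TITLE", final_title
    )
    # "FINAL_TITLE" is in USER_ROOM_EVENTS, so the user room also receives
    # a TRANSCRIPT_FINAL_TITLE message with the transcript id merged in.
    await set_status_and_broadcast(transcript_id, "ended")
```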
server/reflector/hatchet/progress.py (deleted, 80 lines)
@@ -1,80 +0,0 @@
-"""Progress event emission for Hatchet workers."""
-
-from typing import Literal
-
-from reflector.db.transcripts import PipelineProgressData
-from reflector.logger import logger
-from reflector.ws_manager import get_ws_manager
-
-# Step mapping for progress tracking
-PIPELINE_STEPS = {
-    "get_recording": 1,
-    "get_participants": 2,
-    "pad_track": 3,  # Fork tasks share same step
-    "mixdown_tracks": 4,
-    "generate_waveform": 5,
-    "transcribe_track": 6,  # Fork tasks share same step
-    "merge_transcripts": 7,
-    "detect_topics": 8,
-    "generate_title": 9,  # Fork tasks share same step
-    "generate_summary": 9,  # Fork tasks share same step
-    "finalize": 10,
-    "cleanup_consent": 11,
-    "post_zulip": 12,
-    "send_webhook": 13,
-}
-
-TOTAL_STEPS = 13
-
-
-async def _emit_progress_async(
-    transcript_id: str,
-    step: str,
-    status: Literal["pending", "in_progress", "completed", "failed"],
-    workflow_id: str | None = None,
-) -> None:
-    """Async implementation of progress emission."""
-    ws_manager = get_ws_manager()
-    step_index = PIPELINE_STEPS.get(step, 0)
-
-    data = PipelineProgressData(
-        workflow_id=workflow_id,
-        current_step=step,
-        step_index=step_index,
-        total_steps=TOTAL_STEPS,
-        step_status=status,
-    )
-
-    await ws_manager.send_json(
-        room_id=f"ts:{transcript_id}",
-        message={
-            "event": "PIPELINE_PROGRESS",
-            "data": data.model_dump(),
-        },
-    )
-
-    logger.debug(
-        "[Hatchet Progress] Emitted",
-        transcript_id=transcript_id,
-        step=step,
-        status=status,
-        step_index=step_index,
-    )
-
-
-async def emit_progress_async(
-    transcript_id: str,
-    step: str,
-    status: Literal["pending", "in_progress", "completed", "failed"],
-    workflow_id: str | None = None,
-) -> None:
-    """Async version of emit_progress for use in async Hatchet tasks."""
-    try:
-        await _emit_progress_async(transcript_id, step, status, workflow_id)
-    except Exception as e:
-        logger.warning(
-            "[Hatchet Progress] Failed to emit progress event",
-            error=str(e),
-            transcript_id=transcript_id,
-            step=step,
-        )
@@ -25,8 +25,11 @@ from hatchet_sdk import Context
 from pydantic import BaseModel

 from reflector.dailyco_api.client import DailyApiClient
+from reflector.hatchet.broadcast import (
+    append_event_and_broadcast,
+    set_status_and_broadcast,
+)
 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.progress import emit_progress_async
 from reflector.hatchet.workflows.models import (
     ConsentResult,
     FinalizeResult,
@@ -55,32 +58,29 @@ from reflector.processors.types import (
 )
 from reflector.settings import settings
 from reflector.storage.storage_aws import AwsStorage
+from reflector.utils.audio_constants import (
+    PRESIGNED_URL_EXPIRATION_SECONDS,
+    WAVEFORM_SEGMENTS,
+)
 from reflector.utils.audio_waveform import get_audio_waveform
 from reflector.utils.daily import (
     filter_cam_audio_tracks,
     parse_daily_recording_filename,
 )
+from reflector.utils.string import NonEmptyString
 from reflector.zulip import post_transcript_notification

-# Audio constants
-OPUS_STANDARD_SAMPLE_RATE = 48000
-OPUS_DEFAULT_BIT_RATE = 64000
-PRESIGNED_URL_EXPIRATION_SECONDS = 7200
-WAVEFORM_SEGMENTS = 255
-

 class PipelineInput(BaseModel):
     """Input to trigger the diarization pipeline."""

-    recording_id: str | None
-    room_name: str | None
+    recording_id: NonEmptyString
     tracks: list[dict]  # List of {"s3_key": str}
-    bucket_name: str
-    transcript_id: str
-    room_id: str | None = None
+    bucket_name: NonEmptyString
+    transcript_id: NonEmptyString
+    room_id: NonEmptyString | None = None


-# Get hatchet client and define workflow
 hatchet = HatchetClientManager.get_client()

 diarization_pipeline = hatchet.workflow(
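Note: `NonEmptyString` is imported from `reflector.utils.string`, whose definition is not part of this diff. Under the assumption that it is a Pydantic v2 constrained-string alias, its effect on the pipeline input can be sketched like this:

```python
# Assumed shape of reflector.utils.string.NonEmptyString; the real
# definition is not shown in this commit.
from typing import Annotated

from pydantic import BaseModel, StringConstraints, ValidationError

NonEmptyString = Annotated[str, StringConstraints(min_length=1)]


class InputSketch(BaseModel):
    transcript_id: NonEmptyString


try:
    InputSketch(transcript_id="")
except ValidationError:
    # Empty ids are now rejected at workflow-trigger time instead of
    # failing somewhere inside a task.
    pass
```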
@@ -120,9 +120,7 @@ async def set_workflow_error_status(transcript_id: str) -> bool:
     """
     try:
         async with fresh_db_connection():
-            from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
-
-            await transcripts_controller.set_status(transcript_id, "error")
+            await set_status_and_broadcast(transcript_id, "error")
             logger.info(
                 "[Hatchet] Set transcript status to error",
                 transcript_id=transcript_id,
@@ -181,9 +179,6 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
                 )
                 if set_error_status:
                     await set_workflow_error_status(input.transcript_id)
-                await emit_progress_async(
-                    input.transcript_id, step_name, "failed", ctx.workflow_run_id
-                )
                 raise

         return wrapper
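Note: only a fragment of `with_error_handling` is visible in this hunk; after the change a failed step still marks the transcript as errored but no longer emits a "failed" progress event before re-raising. The decorator shape implied by the fragment looks roughly like the sketch below (everything outside the shown lines is an assumption):

```python
# Reconstruction of the decorator's overall shape; only the
# set_error_status branch and the re-raise come from the hunk above.
from collections.abc import Callable
from functools import wraps


async def set_workflow_error_status(transcript_id: str) -> bool:
    ...  # defined earlier in the same module (see the -120,9 hunk)


def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        async def wrapper(input, ctx):
            try:
                return await fn(input, ctx)
            except Exception:
                if set_error_status:
                    await set_workflow_error_status(input.transcript_id)
                raise  # re-raise so Hatchet still marks the task failed

        return wrapper

    return decorator
```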
@@ -203,34 +198,18 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
     ctx.log(f"get_recording: recording_id={input.recording_id}")
     logger.info("[Hatchet] get_recording", recording_id=input.recording_id)

-    await emit_progress_async(
-        input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
-    )
-
-    # Set transcript status to "processing" at workflow start
+    # Set transcript status to "processing" at workflow start (broadcasts to WebSocket)
     async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415

         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
-            await transcripts_controller.set_status(input.transcript_id, "processing")
+            await set_status_and_broadcast(input.transcript_id, "processing")
             logger.info(
                 "[Hatchet] Set transcript status to processing",
                 transcript_id=input.transcript_id,
             )

-    if not input.recording_id:
-        # No recording_id in reprocess path - return minimal data
-        await emit_progress_async(
-            input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
-        )
-        return RecordingResult(
-            id=None,
-            mtg_session_id=None,
-            room_name=input.room_name,
-            duration=0,
-        )
-
     if not settings.DAILY_API_KEY:
         raise ValueError("DAILY_API_KEY not configured")

@@ -247,14 +226,9 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
         duration=recording.duration,
     )

-    await emit_progress_async(
-        input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
-    )
-
     return RecordingResult(
         id=recording.id,
         mtg_session_id=recording.mtgSessionId,
-        room_name=recording.room_name,
         duration=recording.duration,
     )
@@ -268,14 +242,9 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
     ctx.log(f"get_participants: transcript_id={input.transcript_id}")
     logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "get_participants", "in_progress", ctx.workflow_run_id
-    )
-
     recording_data = _to_dict(ctx.task_output(get_recording))
     mtg_session_id = recording_data.get("mtg_session_id")

-    # Get transcript and reset events/topics/participants
     async with fresh_db_connection():
         from reflector.db.transcripts import (  # noqa: PLC0415
             TranscriptParticipant,
@@ -284,7 +253,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:

         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
-            # Reset events/topics/participants
             # Note: title NOT cleared - preserves existing titles
             await transcripts_controller.update(
                 transcript,
@@ -296,12 +264,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
             )

     if not mtg_session_id or not settings.DAILY_API_KEY:
-        await emit_progress_async(
-            input.transcript_id,
-            "get_participants",
-            "completed",
-            ctx.workflow_run_id,
-        )
         return ParticipantsResult(
             participants=[],
             num_tracks=len(input.tracks),
@@ -309,7 +271,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
             target_language=transcript.target_language if transcript else "en",
         )

-    # Fetch participants from Daily API
     async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
         participants = await client.get_meeting_participants(mtg_session_id)

@@ -321,11 +282,9 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
         if p.user_id:
             id_to_user_id[p.participant_id] = p.user_id

-    # Get track keys and filter for cam-audio tracks
     track_keys = [t["s3_key"] for t in input.tracks]
     cam_audio_keys = filter_cam_audio_tracks(track_keys)

-    # Update participants in database
     participants_list = []
     for idx, key in enumerate(cam_audio_keys):
         try:
@@ -361,10 +320,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
         participant_count=len(participants_list),
     )

-    await emit_progress_async(
-        input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
-    )
-
     return ParticipantsResult(
         participants=participants_list,
         num_tracks=len(input.tracks),
@@ -389,7 +344,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
     participants_data = _to_dict(ctx.task_output(get_participants))
     source_language = participants_data.get("source_language", "en")

-    # Spawn child workflows for each track with correct language
     child_coroutines = [
         track_workflow.aio_run(
             TrackInput(
@@ -403,7 +357,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
         for i, track in enumerate(input.tracks)
     ]

-    # Wait for all child workflows to complete
     results = await asyncio.gather(*child_coroutines)

     # Get target_language for later use in detect_topics
@@ -428,13 +381,11 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
                 PaddedTrackInfo(key=padded_key, bucket_name=bucket_name)
             )

-            # Track padded files for cleanup
             track_index = pad_result.get("track_index")
             if pad_result.get("size", 0) > 0 and track_index is not None:
                 storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
                 created_padded_files.add(storage_path)

-    # Merge all words and sort by start time
     all_words = [word for words in track_words for word in words]
     all_words.sort(key=lambda w: w.get("start", 0))

@@ -466,10 +417,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
     logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "mixdown_tracks", "in_progress", ctx.workflow_run_id
-    )
-
     track_data = _to_dict(ctx.task_output(process_tracks))
     padded_tracks_data = track_data.get("padded_tracks", [])

@@ -503,7 +450,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     if not valid_urls:
         raise ValueError("No valid padded tracks to mixdown")

-    # Determine target sample rate from first track
     target_sample_rate = None
     for url in valid_urls:
         container = None
@@ -551,12 +497,10 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     fmt.link_to(sink)
     graph.configure()

-    # Create temp output file
     output_path = tempfile.mktemp(suffix=".mp3")
     containers = []

     try:
-        # Open all containers
         for url in valid_urls:
             try:
                 c = av.open(
@@ -644,7 +588,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
             except Exception:
                 pass

-    # Upload mixed file to storage
     file_size = Path(output_path).stat().st_size
     storage_path = f"{input.transcript_id}/audio.mp3"

@@ -653,7 +596,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:

     Path(output_path).unlink(missing_ok=True)

-    # Update transcript with audio_location
     async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415

@@ -670,10 +612,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
         size=file_size,
     )

-    await emit_progress_async(
-        input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
-    )
-
     return MixdownResult(
         audio_key=storage_path,
         duration=duration_ms[0],
@@ -689,10 +627,6 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
     """Generate audio waveform visualization using AudioWaveformProcessor (matches Celery)."""
     logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id
-    )
-
     from reflector.db.transcripts import (  # noqa: PLC0415
         TranscriptWaveform,
         transcripts_controller,
@@ -740,18 +674,16 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
         with open(temp_path, "wb") as f:
             f.write(response.content)

-        # Generate waveform
         waveform = get_audio_waveform(
             path=Path(temp_path), segments_count=WAVEFORM_SEGMENTS
         )

-        # Save waveform to database via event
         async with fresh_db_connection():
             transcript = await transcripts_controller.get_by_id(input.transcript_id)
             if transcript:
                 waveform_data = TranscriptWaveform(waveform=waveform)
-                await transcripts_controller.append_event(
-                    transcript=transcript, event="WAVEFORM", data=waveform_data
+                await append_event_and_broadcast(
+                    input.transcript_id, transcript, "WAVEFORM", waveform_data
                 )

     finally:
@@ -759,10 +691,6 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:

     logger.info("[Hatchet] generate_waveform complete")

-    await emit_progress_async(
-        input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
-    )
-
     return WaveformResult(waveform_generated=True)

@@ -775,10 +703,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     ctx.log("detect_topics: analyzing transcript for topics")
     logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "detect_topics", "in_progress", ctx.workflow_run_id
-    )
-
     track_data = _to_dict(ctx.task_output(process_tracks))
     words = track_data.get("all_words", [])
     target_language = track_data.get("target_language", "en")
@@ -791,7 +715,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
         TitleSummaryWithId as TitleSummaryWithIdProcessorType,
     )

-    # Convert word dicts to Word objects
     word_objects = [Word(**w) for w in words]
     transcript_type = TranscriptType(words=word_objects)

@@ -800,7 +723,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)

-        # Callback that upserts topics to DB
         async def on_topic_callback(data):
             topic = TranscriptTopic(
                 title=data.title,
@@ -812,8 +734,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
             if isinstance(data, TitleSummaryWithIdProcessorType):
                 topic.id = data.id
             await transcripts_controller.upsert_topic(transcript, topic)
-            await transcripts_controller.append_event(
-                transcript=transcript, event="TOPIC", data=topic
+            await append_event_and_broadcast(
+                input.transcript_id, transcript, "TOPIC", topic
             )

         topics = await topic_processing.detect_topics(
@@ -828,10 +750,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     ctx.log(f"detect_topics complete: found {len(topics_list)} topics")
     logger.info("[Hatchet] detect_topics complete", topic_count=len(topics_list))

-    await emit_progress_async(
-        input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id
-    )
-
     return TopicsResult(topics=topics_list)

@@ -844,10 +762,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     ctx.log("generate_title: generating title from topics")
     logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "generate_title", "in_progress", ctx.workflow_run_id
-    )
-
     topics_data = _to_dict(ctx.task_output(detect_topics))
     topics = topics_data.get("topics", [])

@@ -864,7 +778,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)

-        # Callback that updates title in DB
         async def on_title_callback(data):
             nonlocal title_result
             title_result = data.title
@@ -874,8 +787,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
                 transcript,
                 {"title": final_title.title},
             )
-            await transcripts_controller.append_event(
-                transcript=transcript, event="FINAL_TITLE", data=final_title
+            await append_event_and_broadcast(
+                input.transcript_id, transcript, "FINAL_TITLE", final_title
             )

         await topic_processing.generate_title(
@@ -888,10 +801,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     ctx.log(f"generate_title complete: '{title_result}'")
     logger.info("[Hatchet] generate_title complete", title=title_result)

-    await emit_progress_async(
-        input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
-    )
-
     return TitleResult(title=title_result)

@@ -904,10 +813,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     ctx.log("generate_summary: generating long and short summaries")
     logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "generate_summary", "in_progress", ctx.workflow_run_id
-    )
-
     topics_data = _to_dict(ctx.task_output(detect_topics))
     topics = topics_data.get("topics", [])

@@ -926,7 +831,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)

-        # Callback that updates long_summary in DB
         async def on_long_summary_callback(data):
             nonlocal summary_result
             summary_result = data.long_summary
@@ -937,13 +841,13 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
                 transcript,
                 {"long_summary": final_long_summary.long_summary},
             )
-            await transcripts_controller.append_event(
-                transcript=transcript,
-                event="FINAL_LONG_SUMMARY",
-                data=final_long_summary,
+            await append_event_and_broadcast(
+                input.transcript_id,
+                transcript,
+                "FINAL_LONG_SUMMARY",
+                final_long_summary,
             )

-        # Callback that updates short_summary in DB
         async def on_short_summary_callback(data):
             nonlocal short_summary_result
             short_summary_result = data.short_summary
@@ -954,10 +858,11 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
                 transcript,
                 {"short_summary": final_short_summary.short_summary},
             )
-            await transcripts_controller.append_event(
-                transcript=transcript,
-                event="FINAL_SHORT_SUMMARY",
-                data=final_short_summary,
+            await append_event_and_broadcast(
+                input.transcript_id,
+                transcript,
+                "FINAL_SHORT_SUMMARY",
+                final_short_summary,
             )

         await topic_processing.generate_summaries(
@@ -972,10 +877,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     ctx.log("generate_summary complete")
     logger.info("[Hatchet] generate_summary complete")

-    await emit_progress_async(
-        input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
-    )
-
     return SummaryResult(summary=summary_result, short_summary=short_summary_result)

@@ -994,10 +895,6 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
     ctx.log("finalize: saving transcript and setting status to 'ended'")
     logger.info("[Hatchet] finalize", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "finalize", "in_progress", ctx.workflow_run_id
-    )
-
     mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
     track_data = _to_dict(ctx.task_output(process_tracks))

@@ -1006,6 +903,7 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:

     async with fresh_db_connection():
         from reflector.db.transcripts import (  # noqa: PLC0415
+            TranscriptDuration,
             TranscriptText,
             transcripts_controller,
         )
@@ -1018,17 +916,14 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
         if transcript is None:
             raise ValueError(f"Transcript {input.transcript_id} not found in database")

-        # Convert words back to Word objects for storage
         word_objects = [Word(**w) for w in all_words]

-        # Create merged transcript for TRANSCRIPT event
         merged_transcript = TranscriptType(words=word_objects, translation=None)

-        # Emit TRANSCRIPT event
-        await transcripts_controller.append_event(
-            transcript=transcript,
-            event="TRANSCRIPT",
-            data=TranscriptText(
+        await append_event_and_broadcast(
+            input.transcript_id,
+            transcript,
+            "TRANSCRIPT",
+            TranscriptText(
                 text=merged_transcript.text,
                 translation=merged_transcript.translation,
             ),
@@ -1044,18 +939,18 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
             },
         )

-        # Set status to "ended"
-        await transcripts_controller.set_status(input.transcript_id, "ended")
+        duration_data = TranscriptDuration(duration=duration)
+        await append_event_and_broadcast(
+            input.transcript_id, transcript, "DURATION", duration_data
+        )
+
+        await set_status_and_broadcast(input.transcript_id, "ended")

     ctx.log(
         f"finalize complete: transcript {input.transcript_id} status set to 'ended'"
     )
     logger.info("[Hatchet] finalize complete", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "finalize", "completed", ctx.workflow_run_id
-    )
-
     return FinalizeResult(status="COMPLETED")

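Note: with the DURATION event added, finalize now persists and broadcasts three things in order: the merged TRANSCRIPT text, the DURATION, and the "ended" status. A subscriber on the transcript room should therefore observe roughly the tail below (event names come from this diff; that set_status yields a STATUS event is an assumption suggested by USER_ROOM_EVENTS in broadcast.py, and the capture mechanism is illustrative):

```python
# Rough order of events a "ts:<transcript_id>" subscriber sees from
# finalize after this change; STATUS as the event name for the status
# change is assumed, not shown in the diff.
EXPECTED_TAIL = ["TRANSCRIPT", "DURATION", "STATUS"]


def check_event_order(received: list[str]) -> bool:
    # received is the list of event names captured from the WebSocket
    return received[-3:] == EXPECTED_TAIL
```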
@@ -1067,10 +962,6 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
     """Check and handle consent requirements."""
     logger.info("[Hatchet] cleanup_consent", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "cleanup_consent", "in_progress", ctx.workflow_run_id
-    )
-
     async with fresh_db_connection():
         from reflector.db.meetings import meetings_controller  # noqa: PLC0415
         from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
@@ -1087,10 +978,6 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
         "[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id
     )

-    await emit_progress_async(
-        input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id
-    )
-
     return ConsentResult(consent_checked=True)

@@ -1102,15 +989,8 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
     """Post notification to Zulip."""
     logger.info("[Hatchet] post_zulip", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "post_zulip", "in_progress", ctx.workflow_run_id
-    )
-
     if not settings.ZULIP_REALM:
         logger.info("[Hatchet] post_zulip skipped (Zulip not configured)")
-        await emit_progress_async(
-            input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
-        )
         return ZulipResult(zulip_message_id=None, skipped=True)

     async with fresh_db_connection():
@@ -1123,10 +1003,6 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
     else:
         message_id = None

-    await emit_progress_async(
-        input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
-    )
-
     return ZulipResult(zulip_message_id=message_id)

@@ -1138,15 +1014,8 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
     """Send completion webhook to external service."""
     logger.info("[Hatchet] send_webhook", transcript_id=input.transcript_id)

-    await emit_progress_async(
-        input.transcript_id, "send_webhook", "in_progress", ctx.workflow_run_id
-    )
-
     if not input.room_id:
         logger.info("[Hatchet] send_webhook skipped (no room_id)")
-        await emit_progress_async(
-            input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
-        )
         return WebhookResult(webhook_sent=False, skipped=True)

     async with fresh_db_connection():
@@ -1174,17 +1043,6 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
             "[Hatchet] send_webhook complete", status_code=response.status_code
         )

-        await emit_progress_async(
-            input.transcript_id,
-            "send_webhook",
-            "completed",
-            ctx.workflow_run_id,
-        )
-
         return WebhookResult(webhook_sent=True, response_code=response.status_code)

-    await emit_progress_async(
-        input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
-    )
-
     return WebhookResult(webhook_sent=False, skipped=True)
@@ -40,7 +40,6 @@ class RecordingResult(BaseModel):

     id: str | None
     mtg_session_id: str | None
-    room_name: str | None
     duration: float

@@ -26,9 +26,13 @@ from hatchet_sdk import Context
 from pydantic import BaseModel

 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.progress import emit_progress_async
 from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
 from reflector.logger import logger
+from reflector.utils.audio_constants import (
+    OPUS_DEFAULT_BIT_RATE,
+    OPUS_STANDARD_SAMPLE_RATE,
+    PRESIGNED_URL_EXPIRATION_SECONDS,
+)


 def _to_dict(output) -> dict:
@@ -38,12 +42,6 @@ def _to_dict(output) -> dict:
     return output.model_dump()


-# Audio constants matching existing pipeline
-OPUS_STANDARD_SAMPLE_RATE = 48000
-OPUS_DEFAULT_BIT_RATE = 64000
-PRESIGNED_URL_EXPIRATION_SECONDS = 7200
-
-
 class TrackInput(BaseModel):
     """Input for individual track processing."""

@@ -193,10 +191,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         transcript_id=input.transcript_id,
     )

-    await emit_progress_async(
-        input.transcript_id, "pad_track", "in_progress", ctx.workflow_run_id
-    )
-
     try:
         # Create fresh storage instance to avoid aioboto3 fork issues
         from reflector.settings import settings  # noqa: PLC0415
@@ -229,9 +223,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
                 f"Track {input.track_index} requires no padding",
                 track_index=input.track_index,
             )
-            await emit_progress_async(
-                input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
-            )
             return PadTrackResult(
                 padded_key=input.s3_key,
                 bucket_name=input.bucket_name,
@@ -275,10 +266,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
             padded_key=storage_path,
         )

-        await emit_progress_async(
-            input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
-        )
-
         # Return S3 key (not presigned URL) - consumer tasks presign on demand
         # This avoids stale URLs when workflow is replayed
         return PadTrackResult(
@@ -290,9 +277,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:

     except Exception as e:
         logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "pad_track", "failed", ctx.workflow_run_id
-        )
         raise


@@ -308,10 +292,6 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
         language=input.language,
     )

-    await emit_progress_async(
-        input.transcript_id, "transcribe_track", "in_progress", ctx.workflow_run_id
-    )
-
     try:
         pad_result = _to_dict(ctx.task_output(pad_track))
         padded_key = pad_result.get("padded_key")
@@ -360,10 +340,6 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
             word_count=len(words),
         )

-        await emit_progress_async(
-            input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
-        )
-
         return TranscribeTrackResult(
             words=words,
             track_index=input.track_index,
@@ -371,7 +347,4 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:

     except Exception as e:
         logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "transcribe_track", "failed", ctx.workflow_run_id
-        )
         raise
@@ -32,6 +32,11 @@ from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
 from reflector.processors.types import TitleSummary
 from reflector.processors.types import Transcript as TranscriptType
 from reflector.storage import Storage, get_transcripts_storage
+from reflector.utils.audio_constants import (
+    OPUS_DEFAULT_BIT_RATE,
+    OPUS_STANDARD_SAMPLE_RATE,
+    PRESIGNED_URL_EXPIRATION_SECONDS,
+)
 from reflector.utils.daily import (
     filter_cam_audio_tracks,
     parse_daily_recording_filename,
@@ -39,13 +44,6 @@ from reflector.utils.daily import (
 from reflector.utils.string import NonEmptyString
 from reflector.video_platforms.factory import create_platform_client

-# Audio encoding constants
-OPUS_STANDARD_SAMPLE_RATE = 48000
-OPUS_DEFAULT_BIT_RATE = 128000
-
-# Storage operation constants
-PRESIGNED_URL_EXPIRATION_SECONDS = 7200  # 2 hours
-

 class PipelineMainMultitrack(PipelineMainBase):
     def __init__(self, transcript_id: str):
@@ -251,7 +251,6 @@ async def dispatch_transcript_processing(
         workflow_name="DiarizationPipeline",
         input_data={
             "recording_id": config.recording_id,
-            "room_name": None,
             "tracks": [{"s3_key": k} for k in config.track_keys],
             "bucket_name": config.bucket_name,
             "transcript_id": config.transcript_id,
server/reflector/utils/audio_constants.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+"""
+Shared audio processing constants.
+
+Used by both Hatchet workflows and Celery pipelines for consistent audio encoding.
+"""
+
+# Opus codec settings
+OPUS_STANDARD_SAMPLE_RATE = 48000
+OPUS_DEFAULT_BIT_RATE = 128000  # 128kbps for good speech quality
+
+# S3 presigned URL expiration
+PRESIGNED_URL_EXPIRATION_SECONDS = 7200  # 2 hours
+
+# Waveform visualization
+WAVEFORM_SEGMENTS = 255
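Note: this new module resolves a real inconsistency visible in the hunks above: the Hatchet workflow files hard-coded OPUS_DEFAULT_BIT_RATE = 64000 while the Celery multitrack pipeline used 128000; the shared constant settles on 128000. A consumer-side sketch (the opus_encoder_options helper is hypothetical; the encoder wiring itself is not part of this commit):

```python
# Sketch only: both pipelines now read one source of truth instead of
# re-declaring these values locally.
from reflector.utils.audio_constants import (
    OPUS_DEFAULT_BIT_RATE,
    OPUS_STANDARD_SAMPLE_RATE,
)


def opus_encoder_options() -> dict:
    return {
        "sample_rate": OPUS_STANDARD_SAMPLE_RATE,  # 48 kHz
        "bit_rate": OPUS_DEFAULT_BIT_RATE,  # 128 kbps
    }
```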
@@ -303,7 +303,6 @@ async def _process_multitrack_recording_inner(
         workflow_name="DiarizationPipeline",
         input_data={
             "recording_id": recording_id,
-            "room_name": daily_room_name,
             "tracks": [{"s3_key": k} for k in filter_cam_audio_tracks(track_keys)],
             "bucket_name": bucket_name,
             "transcript_id": transcript.id,
@@ -1,62 +0,0 @@
-"""
-Tests for Hatchet progress emission.
-
-Only tests that catch real bugs - error handling and step completeness.
-"""
-
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-
-@pytest.mark.asyncio
-async def test_emit_progress_async_handles_exception():
-    """Test that emit_progress_async catches exceptions gracefully.
-
-    Critical: Progress emission must NEVER crash the pipeline.
-    WebSocket errors should be silently caught.
-    """
-    from reflector.hatchet.progress import emit_progress_async
-
-    with patch("reflector.hatchet.progress.get_ws_manager") as mock_get_ws:
-        mock_ws = MagicMock()
-        mock_ws.send_json = AsyncMock(side_effect=Exception("WebSocket error"))
-        mock_get_ws.return_value = mock_ws
-
-        # Should not raise - exceptions are caught
-        await emit_progress_async(
-            transcript_id="test-transcript-123",
-            step="finalize",
-            status="completed",
-        )
-
-
-@pytest.mark.asyncio
-async def test_pipeline_steps_mapping_complete():
-    """Test the PIPELINE_STEPS mapping includes all expected steps.
-
-    Useful: Catches when someone adds a new pipeline step but forgets
-    to add it to the progress mapping, resulting in missing UI updates.
-    """
-    from reflector.hatchet.progress import PIPELINE_STEPS, TOTAL_STEPS
-
-    expected_steps = [
-        "get_recording",
-        "get_participants",
-        "pad_track",
-        "mixdown_tracks",
-        "generate_waveform",
-        "transcribe_track",
-        "merge_transcripts",
-        "detect_topics",
-        "generate_title",
-        "generate_summary",
-        "finalize",
-        "cleanup_consent",
-        "post_zulip",
-        "send_webhook",
-    ]
-
-    for step in expected_steps:
-        assert step in PIPELINE_STEPS, f"Missing step in PIPELINE_STEPS: {step}"
-        assert 1 <= PIPELINE_STEPS[step] <= TOTAL_STEPS
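Note: both deleted tests targeted the now-removed progress module. The error-swallowing guarantee the first one protected has moved into broadcast_event, so an equivalent guard would patch get_ws_manager in reflector.hatchet.broadcast the same way; a sketch of such a replacement test (not part of this commit):

```python
# Sketch of an equivalent guard for the new broadcast path; this test is
# not included in this commit, and the TranscriptEvent constructor args
# are assumed from the model's visible fields (event name plus data dict).
from unittest.mock import AsyncMock, MagicMock, patch

import pytest


@pytest.mark.asyncio
async def test_broadcast_event_swallows_ws_errors():
    """broadcast_event must never crash a workflow on WebSocket failure."""
    from reflector.db.transcripts import TranscriptEvent
    from reflector.hatchet.broadcast import broadcast_event

    with patch("reflector.hatchet.broadcast.get_ws_manager") as mock_get_ws:
        mock_ws = MagicMock()
        mock_ws.send_json = AsyncMock(side_effect=Exception("WebSocket error"))
        mock_get_ws.return_value = mock_ws

        # Should not raise - broadcast_event logs a warning and returns
        await broadcast_event(
            "test-transcript-123",
            TranscriptEvent(event="TRANSCRIPT", data={}),
        )
```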