Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)

Commit: self-review round
@@ -183,16 +183,6 @@ class TranscriptEvent(BaseModel):
     data: dict
 
 
-class PipelineProgressData(BaseModel):
-    """Data payload for PIPELINE_PROGRESS WebSocket events."""
-
-    workflow_id: str | None = None
-    current_step: str
-    step_index: int
-    total_steps: int
-    step_status: Literal["pending", "in_progress", "completed", "failed"]
-
-
 class TranscriptParticipant(BaseModel):
     model_config = ConfigDict(from_attributes=True)
     id: str = Field(default_factory=generate_uuid4)
@@ -1,6 +1,5 @@
 """Hatchet workflow orchestration for Reflector."""
 
 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.progress import emit_progress_async
 
-__all__ = ["HatchetClientManager", "emit_progress_async"]
+__all__ = ["HatchetClientManager"]
server/reflector/hatchet/broadcast.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+"""WebSocket broadcasting helpers for Hatchet workflows.
+
+Provides WebSocket broadcasting for Hatchet that matches Celery's @broadcast_to_sockets
+decorator behavior. Events are broadcast to transcript rooms and user rooms.
+"""
+
+from reflector.db.transcripts import TranscriptEvent
+from reflector.logger import logger
+from reflector.ws_manager import get_ws_manager
+
+# Events that should also be sent to user room (matches Celery behavior)
+USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION"}
+
+
+async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None:
+    """Broadcast a TranscriptEvent to WebSocket subscribers.
+
+    Fire-and-forget: errors are logged but don't interrupt workflow execution.
+    """
+    try:
+        ws_manager = get_ws_manager()
+
+        # Broadcast to transcript room
+        await ws_manager.send_json(
+            room_id=f"ts:{transcript_id}",
+            message=event.model_dump(mode="json"),
+        )
+
+        # Also broadcast to user room for certain events
+        if event.event in USER_ROOM_EVENTS:
+            # Deferred import to avoid circular dependency
+            from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
+
+            transcript = await transcripts_controller.get_by_id(transcript_id)
+            if transcript and transcript.user_id:
+                await ws_manager.send_json(
+                    room_id=f"user:{transcript.user_id}",
+                    message={
+                        "event": f"TRANSCRIPT_{event.event}",
+                        "data": {"id": transcript_id, **event.data},
+                    },
+                )
+    except Exception as e:
+        logger.warning(
+            "[Hatchet Broadcast] Failed to broadcast event",
+            error=str(e),
+            transcript_id=transcript_id,
+            event=event.event,
+        )
+
+
+async def set_status_and_broadcast(transcript_id: str, status: str) -> None:
+    """Set transcript status and broadcast to WebSocket.
+
+    Wrapper around transcripts_controller.set_status that adds WebSocket broadcasting.
+    """
+    from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
+
+    event = await transcripts_controller.set_status(transcript_id, status)
+    if event:
+        await broadcast_event(transcript_id, event)
+
+
+async def append_event_and_broadcast(
+    transcript_id: str,
+    transcript,  # Transcript model
+    event_name: str,
+    data,  # Pydantic model
+) -> TranscriptEvent:
+    """Append event to transcript and broadcast to WebSocket.
+
+    Wrapper around transcripts_controller.append_event that adds WebSocket broadcasting.
+    """
+    from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
+
+    event = await transcripts_controller.append_event(
+        transcript=transcript,
+        event=event_name,
+        data=data,
+    )
+    await broadcast_event(transcript_id, event)
+    return event
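
A minimal usage sketch of the helpers above (the task body and the duration value are hypothetical; every imported name appears elsewhere in this commit):

# Sketch only: a hypothetical task body exercising the new helpers.
from reflector.db.transcripts import TranscriptDuration, transcripts_controller
from reflector.hatchet.broadcast import (
    append_event_and_broadcast,
    set_status_and_broadcast,
)


async def example_task(transcript_id: str) -> None:
    # Persists the status and notifies the "ts:<id>" room; STATUS is in
    # USER_ROOM_EVENTS, so the owner's "user:<user_id>" room is notified too.
    await set_status_and_broadcast(transcript_id, "processing")

    transcript = await transcripts_controller.get_by_id(transcript_id)
    if transcript:
        # Persist a DURATION event and broadcast it in one call.
        await append_event_and_broadcast(
            transcript_id, transcript, "DURATION", TranscriptDuration(duration=12.5)
        )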
@@ -1,80 +0,0 @@
-"""Progress event emission for Hatchet workers."""
-
-from typing import Literal
-
-from reflector.db.transcripts import PipelineProgressData
-from reflector.logger import logger
-from reflector.ws_manager import get_ws_manager
-
-# Step mapping for progress tracking
-PIPELINE_STEPS = {
-    "get_recording": 1,
-    "get_participants": 2,
-    "pad_track": 3,  # Fork tasks share same step
-    "mixdown_tracks": 4,
-    "generate_waveform": 5,
-    "transcribe_track": 6,  # Fork tasks share same step
-    "merge_transcripts": 7,
-    "detect_topics": 8,
-    "generate_title": 9,  # Fork tasks share same step
-    "generate_summary": 9,  # Fork tasks share same step
-    "finalize": 10,
-    "cleanup_consent": 11,
-    "post_zulip": 12,
-    "send_webhook": 13,
-}
-
-TOTAL_STEPS = 13
-
-
-async def _emit_progress_async(
-    transcript_id: str,
-    step: str,
-    status: Literal["pending", "in_progress", "completed", "failed"],
-    workflow_id: str | None = None,
-) -> None:
-    """Async implementation of progress emission."""
-    ws_manager = get_ws_manager()
-    step_index = PIPELINE_STEPS.get(step, 0)
-
-    data = PipelineProgressData(
-        workflow_id=workflow_id,
-        current_step=step,
-        step_index=step_index,
-        total_steps=TOTAL_STEPS,
-        step_status=status,
-    )
-
-    await ws_manager.send_json(
-        room_id=f"ts:{transcript_id}",
-        message={
-            "event": "PIPELINE_PROGRESS",
-            "data": data.model_dump(),
-        },
-    )
-
-    logger.debug(
-        "[Hatchet Progress] Emitted",
-        transcript_id=transcript_id,
-        step=step,
-        status=status,
-        step_index=step_index,
-    )
-
-
-async def emit_progress_async(
-    transcript_id: str,
-    step: str,
-    status: Literal["pending", "in_progress", "completed", "failed"],
-    workflow_id: str | None = None,
-) -> None:
-    """Async version of emit_progress for use in async Hatchet tasks."""
-    try:
-        await _emit_progress_async(transcript_id, step, status, workflow_id)
-    except Exception as e:
-        logger.warning(
-            "[Hatchet Progress] Failed to emit progress event",
-            error=str(e),
-            transcript_id=transcript_id,
-            step=step,
-        )
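
Before its removal, the step_index/total_steps pair was enough for a client to derive a progress fraction; a sketch of that arithmetic (the helper name is hypothetical):

# Hypothetical consumer-side helper. PIPELINE_STEPS.get(step, 0) yielded 0
# for unknown steps, so clamp the result into [0, 100].
def progress_percent(step_index: int, total_steps: int) -> int:
    if total_steps <= 0:
        return 0
    return max(0, min(100, round(100 * step_index / total_steps)))


assert progress_percent(7, 13) == 54  # "merge_transcripts", step 7 of 13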
@@ -25,8 +25,11 @@ from hatchet_sdk import Context
 from pydantic import BaseModel
 
 from reflector.dailyco_api.client import DailyApiClient
+from reflector.hatchet.broadcast import (
+    append_event_and_broadcast,
+    set_status_and_broadcast,
+)
 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.progress import emit_progress_async
 from reflector.hatchet.workflows.models import (
     ConsentResult,
     FinalizeResult,
@@ -55,32 +58,29 @@ from reflector.processors.types import (
 )
 from reflector.settings import settings
 from reflector.storage.storage_aws import AwsStorage
+from reflector.utils.audio_constants import (
+    PRESIGNED_URL_EXPIRATION_SECONDS,
+    WAVEFORM_SEGMENTS,
+)
 from reflector.utils.audio_waveform import get_audio_waveform
 from reflector.utils.daily import (
     filter_cam_audio_tracks,
     parse_daily_recording_filename,
 )
+from reflector.utils.string import NonEmptyString
 from reflector.zulip import post_transcript_notification
 
-# Audio constants
-OPUS_STANDARD_SAMPLE_RATE = 48000
-OPUS_DEFAULT_BIT_RATE = 64000
-PRESIGNED_URL_EXPIRATION_SECONDS = 7200
-WAVEFORM_SEGMENTS = 255
-
-
 class PipelineInput(BaseModel):
     """Input to trigger the diarization pipeline."""
 
-    recording_id: str | None
-    room_name: str | None
+    recording_id: NonEmptyString
     tracks: list[dict]  # List of {"s3_key": str}
-    bucket_name: str
-    transcript_id: str
-    room_id: str | None = None
+    bucket_name: NonEmptyString
+    transcript_id: NonEmptyString
+    room_id: NonEmptyString | None = None
 
 
 # Get hatchet client and define workflow
 hatchet = HatchetClientManager.get_client()
 
 diarization_pipeline = hatchet.workflow(
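
The switch from plain str to NonEmptyString moves validation to the workflow boundary: an empty recording_id now fails fast at trigger time instead of deep inside a task. A sketch, assuming NonEmptyString is an Annotated str with a min-length constraint (its actual definition lives in reflector.utils.string; the model name here is hypothetical):

from typing import Annotated

from pydantic import BaseModel, StringConstraints, ValidationError

NonEmptyString = Annotated[str, StringConstraints(min_length=1)]  # assumed shape


class PipelineInputSketch(BaseModel):
    recording_id: NonEmptyString
    transcript_id: NonEmptyString


try:
    PipelineInputSketch(recording_id="", transcript_id="t-1")
except ValidationError:
    print("empty recording_id rejected before any task runs")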
@@ -120,9 +120,7 @@ async def set_workflow_error_status(transcript_id: str) -> bool:
     """
     try:
         async with fresh_db_connection():
-            from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
-
-            await transcripts_controller.set_status(transcript_id, "error")
+            await set_status_and_broadcast(transcript_id, "error")
         logger.info(
             "[Hatchet] Set transcript status to error",
             transcript_id=transcript_id,
@@ -181,9 +179,6 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
                 )
                 if set_error_status:
                     await set_workflow_error_status(input.transcript_id)
-                await emit_progress_async(
-                    input.transcript_id, step_name, "failed", ctx.workflow_run_id
-                )
                 raise
 
         return wrapper
@@ -203,34 +198,18 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
     ctx.log(f"get_recording: recording_id={input.recording_id}")
     logger.info("[Hatchet] get_recording", recording_id=input.recording_id)
 
-    await emit_progress_async(
-        input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
-    )
-
-    # Set transcript status to "processing" at workflow start
+    # Set transcript status to "processing" at workflow start (broadcasts to WebSocket)
     async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
-            await transcripts_controller.set_status(input.transcript_id, "processing")
+            await set_status_and_broadcast(input.transcript_id, "processing")
             logger.info(
                 "[Hatchet] Set transcript status to processing",
                 transcript_id=input.transcript_id,
             )
 
-    if not input.recording_id:
-        # No recording_id in reprocess path - return minimal data
-        await emit_progress_async(
-            input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
-        )
-        return RecordingResult(
-            id=None,
-            mtg_session_id=None,
-            room_name=input.room_name,
-            duration=0,
-        )
-
     if not settings.DAILY_API_KEY:
         raise ValueError("DAILY_API_KEY not configured")
 
@@ -247,14 +226,9 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
         duration=recording.duration,
     )
 
-    await emit_progress_async(
-        input.transcript_id, "get_recording", "completed", ctx.workflow_run_id
-    )
-
     return RecordingResult(
         id=recording.id,
         mtg_session_id=recording.mtgSessionId,
-        room_name=recording.room_name,
         duration=recording.duration,
     )
@@ -268,14 +242,9 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
     ctx.log(f"get_participants: transcript_id={input.transcript_id}")
     logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "get_participants", "in_progress", ctx.workflow_run_id
-    )
-
     recording_data = _to_dict(ctx.task_output(get_recording))
     mtg_session_id = recording_data.get("mtg_session_id")
 
-    # Get transcript and reset events/topics/participants
     async with fresh_db_connection():
         from reflector.db.transcripts import (  # noqa: PLC0415
             TranscriptParticipant,
@@ -284,7 +253,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
 
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
         if transcript:
-            # Reset events/topics/participants
             # Note: title NOT cleared - preserves existing titles
             await transcripts_controller.update(
                 transcript,
@@ -296,12 +264,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
             )
 
     if not mtg_session_id or not settings.DAILY_API_KEY:
-        await emit_progress_async(
-            input.transcript_id,
-            "get_participants",
-            "completed",
-            ctx.workflow_run_id,
-        )
         return ParticipantsResult(
             participants=[],
             num_tracks=len(input.tracks),
@@ -309,7 +271,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
             target_language=transcript.target_language if transcript else "en",
         )
 
-    # Fetch participants from Daily API
     async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
         participants = await client.get_meeting_participants(mtg_session_id)
 
@@ -321,11 +282,9 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
         if p.user_id:
             id_to_user_id[p.participant_id] = p.user_id
 
-    # Get track keys and filter for cam-audio tracks
     track_keys = [t["s3_key"] for t in input.tracks]
     cam_audio_keys = filter_cam_audio_tracks(track_keys)
 
-    # Update participants in database
     participants_list = []
     for idx, key in enumerate(cam_audio_keys):
         try:
@@ -361,10 +320,6 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
             participant_count=len(participants_list),
         )
 
-    await emit_progress_async(
-        input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
-    )
-
     return ParticipantsResult(
         participants=participants_list,
         num_tracks=len(input.tracks),
@@ -389,7 +344,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
     participants_data = _to_dict(ctx.task_output(get_participants))
     source_language = participants_data.get("source_language", "en")
 
-    # Spawn child workflows for each track with correct language
     child_coroutines = [
         track_workflow.aio_run(
             TrackInput(
@@ -403,7 +357,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
         for i, track in enumerate(input.tracks)
     ]
 
-    # Wait for all child workflows to complete
     results = await asyncio.gather(*child_coroutines)
 
     # Get target_language for later use in detect_topics
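
For orientation, the fan-out/fan-in shape of process_tracks in isolation: one child workflow run per track, awaited concurrently so a slow track does not serialize the rest. A hypothetical distillation (track_workflow and TrackInput are the names from the track workflow module further down; the other TrackInput fields are elided here):

import asyncio


async def fan_out_tracks(tracks: list[dict]) -> list:
    # One Hatchet child workflow per track, gathered in input order.
    coros = [
        track_workflow.aio_run(TrackInput(s3_key=t["s3_key"], track_index=i))
        for i, t in enumerate(tracks)
    ]
    return await asyncio.gather(*coros)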
@@ -428,13 +381,11 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
                 PaddedTrackInfo(key=padded_key, bucket_name=bucket_name)
             )
 
-            # Track padded files for cleanup
             track_index = pad_result.get("track_index")
             if pad_result.get("size", 0) > 0 and track_index is not None:
                 storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
                 created_padded_files.add(storage_path)
 
-    # Merge all words and sort by start time
     all_words = [word for words in track_words for word in words]
     all_words.sort(key=lambda w: w.get("start", 0))
@@ -466,10 +417,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
     logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "mixdown_tracks", "in_progress", ctx.workflow_run_id
-    )
-
     track_data = _to_dict(ctx.task_output(process_tracks))
     padded_tracks_data = track_data.get("padded_tracks", [])
 
@@ -503,7 +450,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     if not valid_urls:
         raise ValueError("No valid padded tracks to mixdown")
 
-    # Determine target sample rate from first track
     target_sample_rate = None
     for url in valid_urls:
         container = None
@@ -551,12 +497,10 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     fmt.link_to(sink)
     graph.configure()
 
-    # Create temp output file
     output_path = tempfile.mktemp(suffix=".mp3")
     containers = []
 
     try:
-        # Open all containers
         for url in valid_urls:
             try:
                 c = av.open(
@@ -644,7 +588,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
             except Exception:
                 pass
 
-    # Upload mixed file to storage
     file_size = Path(output_path).stat().st_size
     storage_path = f"{input.transcript_id}/audio.mp3"
 
@@ -653,7 +596,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
 
     Path(output_path).unlink(missing_ok=True)
 
-    # Update transcript with audio_location
     async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
 
@@ -670,10 +612,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
         size=file_size,
     )
 
-    await emit_progress_async(
-        input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
-    )
-
     return MixdownResult(
         audio_key=storage_path,
         duration=duration_ms[0],
@@ -689,10 +627,6 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
     """Generate audio waveform visualization using AudioWaveformProcessor (matches Celery)."""
     logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "generate_waveform", "in_progress", ctx.workflow_run_id
-    )
-
     from reflector.db.transcripts import (  # noqa: PLC0415
         TranscriptWaveform,
         transcripts_controller,
@@ -740,18 +674,16 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
         with open(temp_path, "wb") as f:
             f.write(response.content)
 
-        # Generate waveform
         waveform = get_audio_waveform(
             path=Path(temp_path), segments_count=WAVEFORM_SEGMENTS
         )
 
-        # Save waveform to database via event
         async with fresh_db_connection():
             transcript = await transcripts_controller.get_by_id(input.transcript_id)
             if transcript:
                 waveform_data = TranscriptWaveform(waveform=waveform)
-                await transcripts_controller.append_event(
-                    transcript=transcript, event="WAVEFORM", data=waveform_data
+                await append_event_and_broadcast(
+                    input.transcript_id, transcript, "WAVEFORM", waveform_data
                 )
 
     finally:
@@ -759,10 +691,6 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
 
     logger.info("[Hatchet] generate_waveform complete")
 
-    await emit_progress_async(
-        input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
-    )
-
     return WaveformResult(waveform_generated=True)
@@ -775,10 +703,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     ctx.log("detect_topics: analyzing transcript for topics")
     logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "detect_topics", "in_progress", ctx.workflow_run_id
-    )
-
     track_data = _to_dict(ctx.task_output(process_tracks))
     words = track_data.get("all_words", [])
     target_language = track_data.get("target_language", "en")
@@ -791,7 +715,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
         TitleSummaryWithId as TitleSummaryWithIdProcessorType,
     )
 
-    # Convert word dicts to Word objects
     word_objects = [Word(**w) for w in words]
     transcript_type = TranscriptType(words=word_objects)
 
@@ -800,7 +723,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
 
-        # Callback that upserts topics to DB
         async def on_topic_callback(data):
             topic = TranscriptTopic(
                 title=data.title,
@@ -812,8 +734,8 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
             if isinstance(data, TitleSummaryWithIdProcessorType):
                 topic.id = data.id
             await transcripts_controller.upsert_topic(transcript, topic)
-            await transcripts_controller.append_event(
-                transcript=transcript, event="TOPIC", data=topic
+            await append_event_and_broadcast(
+                input.transcript_id, transcript, "TOPIC", topic
             )
 
         topics = await topic_processing.detect_topics(
@@ -828,10 +750,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     ctx.log(f"detect_topics complete: found {len(topics_list)} topics")
     logger.info("[Hatchet] detect_topics complete", topic_count=len(topics_list))
 
-    await emit_progress_async(
-        input.transcript_id, "detect_topics", "completed", ctx.workflow_run_id
-    )
-
     return TopicsResult(topics=topics_list)
@@ -844,10 +762,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     ctx.log("generate_title: generating title from topics")
     logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "generate_title", "in_progress", ctx.workflow_run_id
-    )
-
     topics_data = _to_dict(ctx.task_output(detect_topics))
     topics = topics_data.get("topics", [])
 
@@ -864,7 +778,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
 
-        # Callback that updates title in DB
         async def on_title_callback(data):
             nonlocal title_result
             title_result = data.title
@@ -874,8 +787,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
                 transcript,
                 {"title": final_title.title},
             )
-            await transcripts_controller.append_event(
-                transcript=transcript, event="FINAL_TITLE", data=final_title
+            await append_event_and_broadcast(
+                input.transcript_id, transcript, "FINAL_TITLE", final_title
             )
 
         await topic_processing.generate_title(
@@ -888,10 +801,6 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     ctx.log(f"generate_title complete: '{title_result}'")
     logger.info("[Hatchet] generate_title complete", title=title_result)
 
-    await emit_progress_async(
-        input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
-    )
-
     return TitleResult(title=title_result)
@@ -904,10 +813,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     ctx.log("generate_summary: generating long and short summaries")
     logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "generate_summary", "in_progress", ctx.workflow_run_id
-    )
-
     topics_data = _to_dict(ctx.task_output(detect_topics))
     topics = topics_data.get("topics", [])
 
@@ -926,7 +831,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     async with fresh_db_connection():
         transcript = await transcripts_controller.get_by_id(input.transcript_id)
 
-        # Callback that updates long_summary in DB
         async def on_long_summary_callback(data):
             nonlocal summary_result
             summary_result = data.long_summary
@@ -937,13 +841,13 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
                 transcript,
                 {"long_summary": final_long_summary.long_summary},
             )
-            await transcripts_controller.append_event(
-                transcript=transcript,
-                event="FINAL_LONG_SUMMARY",
-                data=final_long_summary,
+            await append_event_and_broadcast(
+                input.transcript_id,
+                transcript,
+                "FINAL_LONG_SUMMARY",
+                final_long_summary,
             )
 
-        # Callback that updates short_summary in DB
         async def on_short_summary_callback(data):
             nonlocal short_summary_result
             short_summary_result = data.short_summary
@@ -954,10 +858,11 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
                 transcript,
                 {"short_summary": final_short_summary.short_summary},
             )
-            await transcripts_controller.append_event(
-                transcript=transcript,
-                event="FINAL_SHORT_SUMMARY",
-                data=final_short_summary,
+            await append_event_and_broadcast(
+                input.transcript_id,
+                transcript,
+                "FINAL_SHORT_SUMMARY",
+                final_short_summary,
             )
 
         await topic_processing.generate_summaries(
@@ -972,10 +877,6 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     ctx.log("generate_summary complete")
     logger.info("[Hatchet] generate_summary complete")
 
-    await emit_progress_async(
-        input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
-    )
-
     return SummaryResult(summary=summary_result, short_summary=short_summary_result)
@@ -994,10 +895,6 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
     ctx.log("finalize: saving transcript and setting status to 'ended'")
     logger.info("[Hatchet] finalize", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "finalize", "in_progress", ctx.workflow_run_id
-    )
-
     mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
     track_data = _to_dict(ctx.task_output(process_tracks))
 
@@ -1006,6 +903,7 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
 
     async with fresh_db_connection():
         from reflector.db.transcripts import (  # noqa: PLC0415
+            TranscriptDuration,
             TranscriptText,
             transcripts_controller,
         )
@@ -1018,17 +916,14 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
         if transcript is None:
             raise ValueError(f"Transcript {input.transcript_id} not found in database")
 
-        # Convert words back to Word objects for storage
         word_objects = [Word(**w) for w in all_words]
 
-        # Create merged transcript for TRANSCRIPT event
         merged_transcript = TranscriptType(words=word_objects, translation=None)
 
-        # Emit TRANSCRIPT event
-        await transcripts_controller.append_event(
-            transcript=transcript,
-            event="TRANSCRIPT",
-            data=TranscriptText(
+        await append_event_and_broadcast(
+            input.transcript_id,
+            transcript,
+            "TRANSCRIPT",
+            TranscriptText(
                 text=merged_transcript.text,
                 translation=merged_transcript.translation,
             ),
@@ -1044,18 +939,18 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
             },
         )
 
-        # Set status to "ended"
-        await transcripts_controller.set_status(input.transcript_id, "ended")
+        duration_data = TranscriptDuration(duration=duration)
+        await append_event_and_broadcast(
+            input.transcript_id, transcript, "DURATION", duration_data
+        )
+
+        await set_status_and_broadcast(input.transcript_id, "ended")
 
     ctx.log(
         f"finalize complete: transcript {input.transcript_id} status set to 'ended'"
     )
     logger.info("[Hatchet] finalize complete", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "finalize", "completed", ctx.workflow_run_id
-    )
-
     return FinalizeResult(status="COMPLETED")
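
Given broadcast_event's message construction in broadcast.py, the finalize changes above should produce user-room payloads of roughly this shape; the id and data values here are assumptions, only the envelope is taken from the code:

# Assumed payload shapes, derived from broadcast_event's
# {"event": f"TRANSCRIPT_{event.event}", "data": {"id": ..., **event.data}}:
duration_msg = {
    "event": "TRANSCRIPT_DURATION",
    "data": {"id": "transcript-1", "duration": 1834.2},
}
status_msg = {
    "event": "TRANSCRIPT_STATUS",
    "data": {"id": "transcript-1", "status": "ended"},
}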
@@ -1067,10 +962,6 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
     """Check and handle consent requirements."""
     logger.info("[Hatchet] cleanup_consent", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "cleanup_consent", "in_progress", ctx.workflow_run_id
-    )
-
     async with fresh_db_connection():
         from reflector.db.meetings import meetings_controller  # noqa: PLC0415
         from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
@@ -1087,10 +978,6 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
         "[Hatchet] cleanup_consent complete", transcript_id=input.transcript_id
     )
 
-    await emit_progress_async(
-        input.transcript_id, "cleanup_consent", "completed", ctx.workflow_run_id
-    )
-
     return ConsentResult(consent_checked=True)
@@ -1102,15 +989,8 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
     """Post notification to Zulip."""
     logger.info("[Hatchet] post_zulip", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "post_zulip", "in_progress", ctx.workflow_run_id
-    )
-
     if not settings.ZULIP_REALM:
         logger.info("[Hatchet] post_zulip skipped (Zulip not configured)")
-        await emit_progress_async(
-            input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
-        )
         return ZulipResult(zulip_message_id=None, skipped=True)
 
     async with fresh_db_connection():
@@ -1123,10 +1003,6 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
     else:
         message_id = None
 
-    await emit_progress_async(
-        input.transcript_id, "post_zulip", "completed", ctx.workflow_run_id
-    )
-
     return ZulipResult(zulip_message_id=message_id)
@@ -1138,15 +1014,8 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
     """Send completion webhook to external service."""
     logger.info("[Hatchet] send_webhook", transcript_id=input.transcript_id)
 
-    await emit_progress_async(
-        input.transcript_id, "send_webhook", "in_progress", ctx.workflow_run_id
-    )
-
     if not input.room_id:
         logger.info("[Hatchet] send_webhook skipped (no room_id)")
-        await emit_progress_async(
-            input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
-        )
         return WebhookResult(webhook_sent=False, skipped=True)
 
     async with fresh_db_connection():
@@ -1174,17 +1043,6 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
             "[Hatchet] send_webhook complete", status_code=response.status_code
         )
 
-        await emit_progress_async(
-            input.transcript_id,
-            "send_webhook",
-            "completed",
-            ctx.workflow_run_id,
-        )
-
         return WebhookResult(webhook_sent=True, response_code=response.status_code)
 
-    await emit_progress_async(
-        input.transcript_id, "send_webhook", "completed", ctx.workflow_run_id
-    )
-
     return WebhookResult(webhook_sent=False, skipped=True)
@@ -40,7 +40,6 @@ class RecordingResult(BaseModel):
 
     id: str | None
     mtg_session_id: str | None
-    room_name: str | None
     duration: float
@@ -26,9 +26,13 @@ from hatchet_sdk import Context
 from pydantic import BaseModel
 
 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.progress import emit_progress_async
 from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
 from reflector.logger import logger
+from reflector.utils.audio_constants import (
+    OPUS_DEFAULT_BIT_RATE,
+    OPUS_STANDARD_SAMPLE_RATE,
+    PRESIGNED_URL_EXPIRATION_SECONDS,
+)
 
 
 def _to_dict(output) -> dict:
@@ -38,12 +42,6 @@ def _to_dict(output) -> dict:
     return output.model_dump()
 
 
-# Audio constants matching existing pipeline
-OPUS_STANDARD_SAMPLE_RATE = 48000
-OPUS_DEFAULT_BIT_RATE = 64000
-PRESIGNED_URL_EXPIRATION_SECONDS = 7200
-
-
 class TrackInput(BaseModel):
     """Input for individual track processing."""
@@ -193,10 +191,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         transcript_id=input.transcript_id,
     )
 
-    await emit_progress_async(
-        input.transcript_id, "pad_track", "in_progress", ctx.workflow_run_id
-    )
-
     try:
         # Create fresh storage instance to avoid aioboto3 fork issues
         from reflector.settings import settings  # noqa: PLC0415
@@ -229,9 +223,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
                 f"Track {input.track_index} requires no padding",
                 track_index=input.track_index,
             )
-            await emit_progress_async(
-                input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
-            )
             return PadTrackResult(
                 padded_key=input.s3_key,
                 bucket_name=input.bucket_name,
@@ -275,10 +266,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
             padded_key=storage_path,
         )
 
-        await emit_progress_async(
-            input.transcript_id, "pad_track", "completed", ctx.workflow_run_id
-        )
-
         # Return S3 key (not presigned URL) - consumer tasks presign on demand
         # This avoids stale URLs when workflow is replayed
         return PadTrackResult(
@@ -290,9 +277,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
 
     except Exception as e:
         logger.error("[Hatchet] pad_track failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "pad_track", "failed", ctx.workflow_run_id
-        )
        raise
 
 
@@ -308,10 +292,6 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
         language=input.language,
     )
 
-    await emit_progress_async(
-        input.transcript_id, "transcribe_track", "in_progress", ctx.workflow_run_id
-    )
-
     try:
         pad_result = _to_dict(ctx.task_output(pad_track))
         padded_key = pad_result.get("padded_key")
@@ -360,10 +340,6 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
             word_count=len(words),
         )
 
-        await emit_progress_async(
-            input.transcript_id, "transcribe_track", "completed", ctx.workflow_run_id
-        )
-
        return TranscribeTrackResult(
            words=words,
            track_index=input.track_index,
@@ -371,7 +347,4 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
 
     except Exception as e:
         logger.error("[Hatchet] transcribe_track failed", error=str(e), exc_info=True)
-        await emit_progress_async(
-            input.transcript_id, "transcribe_track", "failed", ctx.workflow_run_id
-        )
         raise
@@ -32,6 +32,11 @@ from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
 from reflector.processors.types import TitleSummary
 from reflector.processors.types import Transcript as TranscriptType
 from reflector.storage import Storage, get_transcripts_storage
+from reflector.utils.audio_constants import (
+    OPUS_DEFAULT_BIT_RATE,
+    OPUS_STANDARD_SAMPLE_RATE,
+    PRESIGNED_URL_EXPIRATION_SECONDS,
+)
 from reflector.utils.daily import (
     filter_cam_audio_tracks,
     parse_daily_recording_filename,
@@ -39,13 +44,6 @@ from reflector.utils.daily import (
 from reflector.utils.string import NonEmptyString
 from reflector.video_platforms.factory import create_platform_client
 
-# Audio encoding constants
-OPUS_STANDARD_SAMPLE_RATE = 48000
-OPUS_DEFAULT_BIT_RATE = 128000
-
-# Storage operation constants
-PRESIGNED_URL_EXPIRATION_SECONDS = 7200  # 2 hours
-
-
 class PipelineMainMultitrack(PipelineMainBase):
     def __init__(self, transcript_id: str):
 
@@ -251,7 +251,6 @@ async def dispatch_transcript_processing(
         workflow_name="DiarizationPipeline",
         input_data={
             "recording_id": config.recording_id,
-            "room_name": None,
             "tracks": [{"s3_key": k} for k in config.track_keys],
             "bucket_name": config.bucket_name,
             "transcript_id": config.transcript_id,
server/reflector/utils/audio_constants.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+"""
+Shared audio processing constants.
+
+Used by both Hatchet workflows and Celery pipelines for consistent audio encoding.
+"""
+
+# Opus codec settings
+OPUS_STANDARD_SAMPLE_RATE = 48000
+OPUS_DEFAULT_BIT_RATE = 128000  # 128kbps for good speech quality
+
+# S3 presigned URL expiration
+PRESIGNED_URL_EXPIRATION_SECONDS = 7200  # 2 hours
+
+# Waveform visualization
+WAVEFORM_SEGMENTS = 255
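
A sketch of how PRESIGNED_URL_EXPIRATION_SECONDS might be consumed; boto3 is shown for illustration only (bucket and key are made up), since the project routes storage through its own AwsStorage class:

import boto3

from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS

# Illustrative only: bucket and key are hypothetical.
s3 = boto3.client("s3")
url = s3.generate_presigned_url(
    "get_object",
    Params={"Bucket": "example-bucket", "Key": "transcript-1/audio.mp3"},
    ExpiresIn=PRESIGNED_URL_EXPIRATION_SECONDS,  # 7200 s = 2 hours
)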
@@ -303,7 +303,6 @@ async def _process_multitrack_recording_inner(
         workflow_name="DiarizationPipeline",
         input_data={
             "recording_id": recording_id,
-            "room_name": daily_room_name,
             "tracks": [{"s3_key": k} for k in filter_cam_audio_tracks(track_keys)],
             "bucket_name": bucket_name,
             "transcript_id": transcript.id,
@@ -1,62 +0,0 @@
-"""
-Tests for Hatchet progress emission.
-
-Only tests that catch real bugs - error handling and step completeness.
-"""
-
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-
-@pytest.mark.asyncio
-async def test_emit_progress_async_handles_exception():
-    """Test that emit_progress_async catches exceptions gracefully.
-
-    Critical: Progress emission must NEVER crash the pipeline.
-    WebSocket errors should be silently caught.
-    """
-    from reflector.hatchet.progress import emit_progress_async
-
-    with patch("reflector.hatchet.progress.get_ws_manager") as mock_get_ws:
-        mock_ws = MagicMock()
-        mock_ws.send_json = AsyncMock(side_effect=Exception("WebSocket error"))
-        mock_get_ws.return_value = mock_ws
-
-        # Should not raise - exceptions are caught
-        await emit_progress_async(
-            transcript_id="test-transcript-123",
-            step="finalize",
-            status="completed",
-        )
-
-
-@pytest.mark.asyncio
-async def test_pipeline_steps_mapping_complete():
-    """Test the PIPELINE_STEPS mapping includes all expected steps.
-
-    Useful: Catches when someone adds a new pipeline step but forgets
-    to add it to the progress mapping, resulting in missing UI updates.
-    """
-    from reflector.hatchet.progress import PIPELINE_STEPS, TOTAL_STEPS
-
-    expected_steps = [
-        "get_recording",
-        "get_participants",
-        "pad_track",
-        "mixdown_tracks",
-        "generate_waveform",
-        "transcribe_track",
-        "merge_transcripts",
-        "detect_topics",
-        "generate_title",
-        "generate_summary",
-        "finalize",
-        "cleanup_consent",
-        "post_zulip",
-        "send_webhook",
-    ]
-
-    for step in expected_steps:
-        assert step in PIPELINE_STEPS, f"Missing step in PIPELINE_STEPS: {step}"
-        assert 1 <= PIPELINE_STEPS[step] <= TOTAL_STEPS
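
The deleted error-handling test has a natural analogue against the new broadcast module; a sketch, assuming TranscriptEvent can be constructed from just event and data:

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from reflector.db.transcripts import TranscriptEvent
from reflector.hatchet.broadcast import broadcast_event


@pytest.mark.asyncio
async def test_broadcast_event_swallows_ws_errors():
    with patch("reflector.hatchet.broadcast.get_ws_manager") as mock_get_ws:
        mock_ws = MagicMock()
        mock_ws.send_json = AsyncMock(side_effect=Exception("WebSocket error"))
        mock_get_ws.return_value = mock_ws

        # Must not raise: broadcast errors are logged, never propagated.
        await broadcast_event("t-1", TranscriptEvent(event="STATUS", data={}))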