hatchet no-mistake

This commit is contained in:
Igor Loskutov
2025-12-16 12:09:02 -05:00
parent c5498d26bf
commit 0f266eabdf
7 changed files with 780 additions and 202 deletions

View File

@@ -0,0 +1,28 @@
"""add workflow_run_id to transcript
Revision ID: 0f943fede0e0
Revises: a326252ac554
Create Date: 2025-12-16 01:54:13.855106
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0f943fede0e0"
down_revision: Union[str, None] = "a326252ac554"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add a nullable string column ``workflow_run_id`` to the ``transcript`` table."""
    workflow_run_id_column = sa.Column("workflow_run_id", sa.String(), nullable=True)
    # The change is applied through Alembic's batch-alter context, as in the
    # matching downgrade().
    with op.batch_alter_table("transcript", schema=None) as transcript_batch:
        transcript_batch.add_column(workflow_run_id_column)
def downgrade() -> None:
    """Drop the ``workflow_run_id`` column from the ``transcript`` table."""
    column_name = "workflow_run_id"
    # Mirror of upgrade(): the same batch-alter context, removing the column.
    with op.batch_alter_table("transcript", schema=None) as transcript_batch:
        transcript_batch.drop_column(column_name)

View File

@@ -83,6 +83,8 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean), sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
sqlalchemy.Column("room_id", sqlalchemy.String), sqlalchemy.Column("room_id", sqlalchemy.String),
sqlalchemy.Column("webvtt", sqlalchemy.Text), sqlalchemy.Column("webvtt", sqlalchemy.Text),
# Hatchet workflow run ID for resumption of failed workflows
sqlalchemy.Column("workflow_run_id", sqlalchemy.String),
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"), sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
sqlalchemy.Index("idx_transcript_user_id", "user_id"), sqlalchemy.Index("idx_transcript_user_id", "user_id"),
sqlalchemy.Index("idx_transcript_created_at", "created_at"), sqlalchemy.Index("idx_transcript_created_at", "created_at"),
@@ -227,6 +229,7 @@ class Transcript(BaseModel):
zulip_message_id: int | None = None zulip_message_id: int | None = None
audio_deleted: bool | None = None audio_deleted: bool | None = None
webvtt: str | None = None webvtt: str | None = None
workflow_run_id: str | None = None # Hatchet workflow run ID for resumption
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str: def serialize_datetime(self, dt: datetime) -> str:

View File

@@ -2,6 +2,7 @@
from hatchet_sdk import Hatchet from hatchet_sdk import Hatchet
from reflector.logger import logger
from reflector.settings import settings from reflector.settings import settings
@@ -35,9 +36,44 @@ class HatchetClientManager:
# SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
return result.run.metadata.id return result.run.metadata.id
@classmethod
async def get_workflow_run_status(cls, workflow_run_id: str) -> str:
    """Fetch the status of a workflow run and return it converted with ``str()``.

    Args:
        workflow_run_id: Identifier of the workflow run to query.

    Returns:
        The string form of whatever ``runs.aio_get_status`` returns.
    """
    return str(await cls.get_client().runs.aio_get_status(workflow_run_id))
@classmethod
async def cancel_workflow(cls, workflow_run_id: str) -> None:
    """Ask the client to cancel the given workflow run, then log the cancellation.

    Args:
        workflow_run_id: Identifier of the workflow run to cancel.
    """
    runs_api = cls.get_client().runs
    await runs_api.aio_cancel(workflow_run_id)
    logger.info("[Hatchet] Cancelled workflow", workflow_run_id=workflow_run_id)
@classmethod
async def replay_workflow(cls, workflow_run_id: str) -> None:
    """Ask the client to replay the given workflow run, then log the replay.

    Callers are expected to check eligibility first (see ``can_replay``);
    this method itself performs no status check.

    Args:
        workflow_run_id: Identifier of the workflow run to replay.
    """
    runs_api = cls.get_client().runs
    await runs_api.aio_replay(workflow_run_id)
    logger.info("[Hatchet] Replaying workflow", workflow_run_id=workflow_run_id)
@classmethod
async def can_replay(cls, workflow_run_id: str) -> bool:
    """Report whether a workflow run is replayable, i.e. its status contains "FAILED".

    This is a best-effort check: any exception raised while fetching the
    status is logged as a warning and treated as "not replayable".

    Args:
        workflow_run_id: Identifier of the workflow run to inspect.

    Returns:
        True when the status string contains "FAILED"; False otherwise, or
        when the status lookup fails.
    """
    try:
        # Substring match, because the status arrives already stringified.
        return "FAILED" in await cls.get_workflow_run_status(workflow_run_id)
    except Exception as exc:
        logger.warning(
            "[Hatchet] Failed to check replay status",
            workflow_run_id=workflow_run_id,
            error=str(exc),
        )
        return False
@classmethod @classmethod
async def get_workflow_status(cls, workflow_run_id: str) -> dict: async def get_workflow_status(cls, workflow_run_id: str) -> dict:
"""Get the current status of a workflow run.""" """Get the full workflow run details as dict."""
client = cls.get_client() client = cls.get_client()
run = await client.runs.aio_get(workflow_run_id) run = await client.runs.aio_get(workflow_run_id)
return run.to_dict() return run.to_dict()

View File

@@ -71,6 +71,28 @@ async def _close_db_connection(db):
_database_context.set(None) _database_context.set(None)
async def _set_error_status(transcript_id: str):
    """Mark a transcript as "error" after a workflow failure, never raising.

    Opens a fresh DB connection, sets the transcript status to "error", and
    always closes the connection afterwards. Any exception along the way
    (including while connecting or closing) is logged and swallowed, so this
    can be called from an ``except`` handler without masking the original
    failure.

    Args:
        transcript_id: Identifier of the transcript to flag.
    """
    try:
        connection = await _get_fresh_db_connection()
        try:
            # Imported at call time, as the other workflow steps do.
            from reflector.db.transcripts import transcripts_controller

            await transcripts_controller.set_status(transcript_id, "error")
            logger.info(
                "[Hatchet] Set transcript status to error",
                transcript_id=transcript_id,
            )
        finally:
            await _close_db_connection(connection)
    except Exception as exc:
        logger.error(
            "[Hatchet] Failed to set error status",
            transcript_id=transcript_id,
            error=str(exc),
        )
def _get_storage(): def _get_storage():
"""Create fresh storage instance.""" """Create fresh storage instance."""
from reflector.settings import settings from reflector.settings import settings
@@ -98,6 +120,21 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
) )
# Set transcript status to "processing" at workflow start (matches Celery behavior)
db = await _get_fresh_db_connection()
try:
from reflector.db.transcripts import transcripts_controller
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
await transcripts_controller.set_status(input.transcript_id, "processing")
logger.info(
"[Hatchet] Set transcript status to processing",
transcript_id=input.transcript_id,
)
finally:
await _close_db_connection(db)
try: try:
from reflector.dailyco_api.client import DailyApiClient from reflector.dailyco_api.client import DailyApiClient
from reflector.settings import settings from reflector.settings import settings
@@ -140,6 +177,7 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
except Exception as e: except Exception as e:
logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True) logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "get_recording", "failed", ctx.workflow_run_id input.transcript_id, "get_recording", "failed", ctx.workflow_run_id
) )
@@ -150,7 +188,10 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3 parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3
) )
async def get_participants(input: PipelineInput, ctx: Context) -> dict: async def get_participants(input: PipelineInput, ctx: Context) -> dict:
"""Fetch participant list from Daily.co API.""" """Fetch participant list from Daily.co API and update transcript in database.
Matches Celery's update_participants_from_daily() behavior.
"""
logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id) logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)
await emit_progress_async( await emit_progress_async(
@@ -163,38 +204,118 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
from reflector.dailyco_api.client import DailyApiClient from reflector.dailyco_api.client import DailyApiClient
from reflector.settings import settings from reflector.settings import settings
from reflector.utils.daily import (
filter_cam_audio_tracks,
parse_daily_recording_filename,
)
# Get transcript and reset events/topics/participants (matches Celery line 599-607)
db = await _get_fresh_db_connection()
try:
from reflector.db.transcripts import (
TranscriptParticipant,
transcripts_controller,
)
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
# Reset events/topics/participants (matches Celery line 599-607)
# Note: title NOT cleared - Celery preserves existing titles
await transcripts_controller.update(
transcript,
{
"events": [],
"topics": [],
"participants": [],
},
)
if not mtg_session_id or not settings.DAILY_API_KEY: if not mtg_session_id or not settings.DAILY_API_KEY:
# Return empty participants if no session ID
await emit_progress_async( await emit_progress_async(
input.transcript_id, input.transcript_id,
"get_participants", "get_participants",
"completed", "completed",
ctx.workflow_run_id, ctx.workflow_run_id,
) )
return {"participants": [], "num_tracks": len(input.tracks)} return {
"participants": [],
"num_tracks": len(input.tracks),
"source_language": transcript.source_language
if transcript
else "en",
"target_language": transcript.target_language
if transcript
else "en",
}
# Fetch participants from Daily API
async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client: async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
participants = await client.get_meeting_participants(mtg_session_id) participants = await client.get_meeting_participants(mtg_session_id)
participants_list = [ id_to_name = {}
{"participant_id": p.participant_id, "user_name": p.user_name} id_to_user_id = {}
for p in participants.data for p in participants.data:
] if p.user_name:
id_to_name[p.participant_id] = p.user_name
if p.user_id:
id_to_user_id[p.participant_id] = p.user_id
# Get track keys and filter for cam-audio tracks
track_keys = [t["s3_key"] for t in input.tracks]
cam_audio_keys = filter_cam_audio_tracks(track_keys)
# Update participants in database (matches Celery lines 568-590)
participants_list = []
for idx, key in enumerate(cam_audio_keys):
try:
parsed = parse_daily_recording_filename(key)
participant_id = parsed.participant_id
except ValueError as e:
logger.error(
"Failed to parse Daily recording filename",
error=str(e),
key=key,
)
continue
default_name = f"Speaker {idx}"
name = id_to_name.get(participant_id, default_name)
user_id = id_to_user_id.get(participant_id)
participant = TranscriptParticipant(
id=participant_id, speaker=idx, name=name, user_id=user_id
)
await transcripts_controller.upsert_participant(transcript, participant)
participants_list.append(
{
"participant_id": participant_id,
"user_name": name,
"speaker": idx,
}
)
logger.info( logger.info(
"[Hatchet] get_participants complete", "[Hatchet] get_participants complete",
participant_count=len(participants_list), participant_count=len(participants_list),
) )
finally:
await _close_db_connection(db)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "get_participants", "completed", ctx.workflow_run_id input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
) )
return {"participants": participants_list, "num_tracks": len(input.tracks)} return {
"participants": participants_list,
"num_tracks": len(input.tracks),
"source_language": transcript.source_language if transcript else "en",
"target_language": transcript.target_language if transcript else "en",
}
except Exception as e: except Exception as e:
logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True) logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "get_participants", "failed", ctx.workflow_run_id input.transcript_id, "get_participants", "failed", ctx.workflow_run_id
) )
@@ -215,7 +336,12 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
transcript_id=input.transcript_id, transcript_id=input.transcript_id,
) )
# Spawn child workflows for each track try:
# Get source_language from get_participants (matches Celery: uses transcript.source_language)
participants_data = ctx.task_output(get_participants)
source_language = participants_data.get("source_language", "en")
# Spawn child workflows for each track with correct language
child_coroutines = [ child_coroutines = [
track_workflow.aio_run( track_workflow.aio_run(
TrackInput( TrackInput(
@@ -223,6 +349,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
s3_key=track["s3_key"], s3_key=track["s3_key"],
bucket_name=input.bucket_name, bucket_name=input.bucket_name,
transcript_id=input.transcript_id, transcript_id=input.transcript_id,
language=source_language,
) )
) )
for i, track in enumerate(input.tracks) for i, track in enumerate(input.tracks)
@@ -231,9 +358,13 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
# Wait for all child workflows to complete # Wait for all child workflows to complete
results = await asyncio.gather(*child_coroutines) results = await asyncio.gather(*child_coroutines)
# Get target_language for later use in detect_topics
target_language = participants_data.get("target_language", "en")
# Collect all track results # Collect all track results
all_words = [] all_words = []
padded_urls = [] padded_urls = []
created_padded_files = set()
for result in results: for result in results:
transcribe_result = result.get("transcribe_track", {}) transcribe_result = result.get("transcribe_track", {})
@@ -242,9 +373,19 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
pad_result = result.get("pad_track", {}) pad_result = result.get("pad_track", {})
padded_urls.append(pad_result.get("padded_url")) padded_urls.append(pad_result.get("padded_url"))
# Track padded files for cleanup (matches Celery line 636-637)
track_index = pad_result.get("track_index")
if pad_result.get("size", 0) > 0 and track_index is not None:
# File was created (size > 0 means padding was applied)
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
created_padded_files.add(storage_path)
# Sort words by start time # Sort words by start time
all_words.sort(key=lambda w: w.get("start", 0)) all_words.sort(key=lambda w: w.get("start", 0))
# NOTE: Cleanup of padded S3 files moved to generate_waveform (after mixdown completes)
# Mixdown needs the padded files, so we can't delete them here
logger.info( logger.info(
"[Hatchet] process_tracks complete", "[Hatchet] process_tracks complete",
num_tracks=len(input.tracks), num_tracks=len(input.tracks),
@@ -256,14 +397,26 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
"padded_urls": padded_urls, "padded_urls": padded_urls,
"word_count": len(all_words), "word_count": len(all_words),
"num_tracks": len(input.tracks), "num_tracks": len(input.tracks),
"target_language": target_language,
"created_padded_files": list(
created_padded_files
), # For cleanup after mixdown
} }
except Exception as e:
logger.error("[Hatchet] process_tracks failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async(
input.transcript_id, "process_tracks", "failed", ctx.workflow_run_id
)
raise
@diarization_pipeline.task( @diarization_pipeline.task(
parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3 parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3
) )
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict: async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
"""Mix all padded tracks into single audio file.""" """Mix all padded tracks into single audio file using PyAV (same as Celery)."""
logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id) logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)
await emit_progress_async( await emit_progress_async(
@@ -279,93 +432,198 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
storage = _get_storage() storage = _get_storage()
# Download all tracks and mix # Use PipelineMainMultitrack.mixdown_tracks which uses PyAV filter graph
temp_inputs = [] from fractions import Fraction
from av.audio.resampler import AudioResampler
from reflector.processors import AudioFileWriterProcessor
valid_urls = [url for url in padded_urls if url]
if not valid_urls:
raise ValueError("No valid padded tracks to mixdown")
# Determine target sample rate from first track
target_sample_rate = None
for url in valid_urls:
try: try:
for i, url in enumerate(padded_urls): container = av.open(url)
if not url: for frame in container.decode(audio=0):
continue target_sample_rate = frame.sample_rate
temp_input = tempfile.NamedTemporaryFile(suffix=".webm", delete=False) break
temp_inputs.append(temp_input.name)
# Download track
import httpx
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
with open(temp_input.name, "wb") as f:
f.write(response.content)
# Mix using PyAV amix filter
if len(temp_inputs) == 0:
raise ValueError("No valid tracks to mixdown")
output_path = tempfile.mktemp(suffix=".mp3")
try:
# Use ffmpeg-style mixing via PyAV
containers = [av.open(path) for path in temp_inputs]
# Get the longest duration
max_duration = 0.0
for container in containers:
if container.duration:
duration = float(container.duration * av.time_base)
max_duration = max(max_duration, duration)
# Close containers for now
for container in containers:
container.close() container.close()
if target_sample_rate:
break
except Exception:
continue
# Use subprocess for mixing (simpler than complex PyAV graph) if not target_sample_rate:
import subprocess raise ValueError("No decodable audio frames in any track")
# Build ffmpeg command # Build PyAV filter graph: N abuffer -> amix -> aformat -> sink
cmd = ["ffmpeg", "-y"] graph = av.filter.Graph()
for path in temp_inputs: inputs = []
cmd.extend(["-i", path])
# Build filter for N inputs for idx, url in enumerate(valid_urls):
n = len(temp_inputs) args = (
filter_str = f"amix=inputs={n}:duration=longest:normalize=0" f"time_base=1/{target_sample_rate}:"
cmd.extend(["-filter_complex", filter_str]) f"sample_rate={target_sample_rate}:"
cmd.extend(["-ac", "2", "-ar", "48000", "-b:a", "128k", output_path]) f"sample_fmt=s32:"
f"channel_layout=stereo"
)
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
inputs.append(in_ctx)
subprocess.run(cmd, check=True, capture_output=True) mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
fmt = graph.add(
"aformat",
args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
name="fmt",
)
sink = graph.add("abuffersink", name="out")
# Upload mixed file for idx, in_ctx in enumerate(inputs):
in_ctx.link_to(mixer, 0, idx)
mixer.link_to(fmt)
fmt.link_to(sink)
graph.configure()
# Create temp output file
output_path = tempfile.mktemp(suffix=".mp3")
containers = []
try:
# Open all containers
for url in valid_urls:
try:
c = av.open(
url,
options={
"reconnect": "1",
"reconnect_streamed": "1",
"reconnect_delay_max": "5",
},
)
containers.append(c)
except Exception as e:
logger.warning(
"[Hatchet] mixdown: failed to open container",
url=url,
error=str(e),
)
if not containers:
raise ValueError("Could not open any track containers")
# Create AudioFileWriterProcessor for MP3 output with duration capture
duration_ms = [0.0] # Mutable container for callback capture
async def capture_duration(d):
duration_ms[0] = d
writer = AudioFileWriterProcessor(
path=output_path, on_duration=capture_duration
)
decoders = [c.decode(audio=0) for c in containers]
active = [True] * len(decoders)
resamplers = [
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
for _ in decoders
]
while any(active):
for i, (dec, is_active) in enumerate(zip(decoders, active)):
if not is_active:
continue
try:
frame = next(dec)
except StopIteration:
active[i] = False
inputs[i].push(None)
continue
if frame.sample_rate != target_sample_rate:
continue
out_frames = resamplers[i].resample(frame) or []
for rf in out_frames:
rf.sample_rate = target_sample_rate
rf.time_base = Fraction(1, target_sample_rate)
inputs[i].push(rf)
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
await writer.push(mixed)
# Flush remaining frames
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
await writer.push(mixed)
await writer.flush()
# Duration is captured via callback in milliseconds (from AudioFileWriterProcessor)
finally:
for c in containers:
try:
c.close()
except Exception:
pass
# Upload mixed file to correct path (matches Celery: {transcript.id}/audio.mp3)
file_size = Path(output_path).stat().st_size file_size = Path(output_path).stat().st_size
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/mixed.mp3" storage_path = f"{input.transcript_id}/audio.mp3"
with open(output_path, "rb") as mixed_file: with open(output_path, "rb") as mixed_file:
await storage.put_file(storage_path, mixed_file) await storage.put_file(storage_path, mixed_file)
Path(output_path).unlink(missing_ok=True)
# Update transcript with audio_location (matches Celery line 661)
db = await _get_fresh_db_connection()
try:
from reflector.db.transcripts import transcripts_controller
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
await transcripts_controller.update(
transcript, {"audio_location": "storage"}
)
finally:
await _close_db_connection(db)
logger.info( logger.info(
"[Hatchet] mixdown_tracks uploaded", "[Hatchet] mixdown_tracks uploaded",
key=storage_path, key=storage_path,
size=file_size, size=file_size,
) )
finally:
Path(output_path).unlink(missing_ok=True)
finally:
for path in temp_inputs:
Path(path).unlink(missing_ok=True)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
) )
return { return {
"audio_key": storage_path, "audio_key": storage_path,
"duration": max_duration, "duration": duration_ms[
"tracks_mixed": len(temp_inputs), 0
], # Duration in milliseconds from AudioFileWriterProcessor
"tracks_mixed": len(valid_urls),
} }
except Exception as e: except Exception as e:
logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True) logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id
) )
@@ -376,7 +634,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3 parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3
) )
async def generate_waveform(input: PipelineInput, ctx: Context) -> dict: async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
"""Generate audio waveform visualization.""" """Generate audio waveform visualization using AudioWaveformProcessor (matches Celery)."""
logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id) logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)
await emit_progress_async( await emit_progress_async(
@@ -384,6 +642,35 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
) )
try: try:
import httpx
from reflector.db.transcripts import TranscriptWaveform, transcripts_controller
from reflector.utils.audio_waveform import get_audio_waveform
# Cleanup temporary padded S3 files (matches Celery lines 710-725)
# Moved here from process_tracks because mixdown_tracks needs the padded files
track_data = ctx.task_output(process_tracks)
created_padded_files = track_data.get("created_padded_files", [])
if created_padded_files:
logger.info(
f"[Hatchet] Cleaning up {len(created_padded_files)} temporary S3 files"
)
storage = _get_storage()
cleanup_tasks = []
for storage_path in created_padded_files:
cleanup_tasks.append(storage.delete_file(storage_path))
cleanup_results = await asyncio.gather(
*cleanup_tasks, return_exceptions=True
)
for storage_path, result in zip(created_padded_files, cleanup_results):
if isinstance(result, Exception):
logger.warning(
"[Hatchet] Failed to cleanup temporary padded track",
storage_path=storage_path,
error=str(result),
)
mixdown_data = ctx.task_output(mixdown_tracks) mixdown_data = ctx.task_output(mixdown_tracks)
audio_key = mixdown_data.get("audio_key") audio_key = mixdown_data.get("audio_key")
@@ -394,18 +681,34 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS, expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
) )
from reflector.pipelines.waveform_helpers import generate_waveform_data # Download MP3 to temp file (AudioWaveformProcessor needs local file)
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
temp_path = temp_file.name
waveform = await generate_waveform_data(audio_url) try:
async with httpx.AsyncClient() as client:
response = await client.get(audio_url, timeout=120)
response.raise_for_status()
with open(temp_path, "wb") as f:
f.write(response.content)
# Store waveform # Generate waveform (matches Celery: get_audio_waveform with 255 segments)
waveform_key = f"file_pipeline_hatchet/{input.transcript_id}/waveform.json" waveform = get_audio_waveform(path=Path(temp_path), segments_count=255)
import json
waveform_bytes = json.dumps(waveform).encode() # Save waveform to database via event (matches Celery on_waveform callback)
import io db = await _get_fresh_db_connection()
try:
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if transcript:
waveform_data = TranscriptWaveform(waveform=waveform)
await transcripts_controller.append_event(
transcript=transcript, event="WAVEFORM", data=waveform_data
)
finally:
await _close_db_connection(db)
await storage.put_file(waveform_key, io.BytesIO(waveform_bytes)) finally:
Path(temp_path).unlink(missing_ok=True)
logger.info("[Hatchet] generate_waveform complete") logger.info("[Hatchet] generate_waveform complete")
@@ -413,10 +716,11 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
) )
return {"waveform_key": waveform_key} return {"waveform_generated": True}
except Exception as e: except Exception as e:
logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True) logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id
) )
@@ -427,7 +731,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3 parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3
) )
async def detect_topics(input: PipelineInput, ctx: Context) -> dict: async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
"""Detect topics using LLM.""" """Detect topics using LLM and save to database (matches Celery on_topic callback)."""
logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id) logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)
await emit_progress_async( await emit_progress_async(
@@ -437,26 +741,52 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
try: try:
track_data = ctx.task_output(process_tracks) track_data = ctx.task_output(process_tracks)
words = track_data.get("all_words", []) words = track_data.get("all_words", [])
target_language = track_data.get("target_language", "en")
from reflector.db.transcripts import TranscriptTopic, transcripts_controller
from reflector.pipelines import topic_processing from reflector.pipelines import topic_processing
from reflector.processors.types import (
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
)
from reflector.processors.types import Transcript as TranscriptType from reflector.processors.types import Transcript as TranscriptType
from reflector.processors.types import Word from reflector.processors.types import Word
# Convert word dicts to Word objects # Convert word dicts to Word objects
word_objects = [Word(**w) for w in words] word_objects = [Word(**w) for w in words]
transcript = TranscriptType(words=word_objects) transcript_type = TranscriptType(words=word_objects)
empty_pipeline = topic_processing.EmptyPipeline(logger=logger) empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
async def noop_callback(t): # Get DB connection for callbacks
pass db = await _get_fresh_db_connection()
try:
transcript = await transcripts_controller.get_by_id(input.transcript_id)
# Callback that upserts topics to DB (matches Celery on_topic)
async def on_topic_callback(data):
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
transcript=data.transcript.text,
words=data.transcript.words,
)
if isinstance(data, TitleSummaryWithIdProcessorType):
topic.id = data.id
await transcripts_controller.upsert_topic(transcript, topic)
await transcripts_controller.append_event(
transcript=transcript, event="TOPIC", data=topic
)
topics = await topic_processing.detect_topics( topics = await topic_processing.detect_topics(
transcript, transcript_type,
"en", # target_language target_language,
on_topic_callback=noop_callback, on_topic_callback=on_topic_callback,
empty_pipeline=empty_pipeline, empty_pipeline=empty_pipeline,
) )
finally:
await _close_db_connection(db)
topics_list = [t.model_dump() for t in topics] topics_list = [t.model_dump() for t in topics]
@@ -470,6 +800,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
except Exception as e: except Exception as e:
logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True) logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id
) )
@@ -480,7 +811,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3 parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3
) )
async def generate_title(input: PipelineInput, ctx: Context) -> dict: async def generate_title(input: PipelineInput, ctx: Context) -> dict:
"""Generate meeting title using LLM.""" """Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id) logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)
await emit_progress_async( await emit_progress_async(
@@ -491,23 +822,56 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
topics_data = ctx.task_output(detect_topics) topics_data = ctx.task_output(detect_topics)
topics = topics_data.get("topics", []) topics = topics_data.get("topics", [])
from reflector.db.transcripts import (
TranscriptFinalTitle,
transcripts_controller,
)
from reflector.pipelines import topic_processing from reflector.pipelines import topic_processing
from reflector.processors.types import Topic from reflector.processors.types import TitleSummary
topic_objects = [Topic(**t) for t in topics] topic_objects = [TitleSummary(**t) for t in topics]
title = await topic_processing.generate_title(topic_objects) empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
title_result = None
logger.info("[Hatchet] generate_title complete", title=title) db = await _get_fresh_db_connection()
try:
transcript = await transcripts_controller.get_by_id(input.transcript_id)
# Callback that updates title in DB (matches Celery on_title)
# It is handed to topic_processing.generate_title as `on_title_callback` and
# receives an object exposing a `title` attribute.
async def on_title_callback(data):
# Record the generated title in the enclosing task's `title_result`, which the
# task later logs and returns as {"title": title_result}.
nonlocal title_result
title_result = data.title
final_title = TranscriptFinalTitle(title=data.title)
# Only write the title column when the transcript fetched by the enclosing
# task has no title yet, so an existing title is not overwritten.
if not transcript.title:
await transcripts_controller.update(
transcript,
{"title": final_title.title},
)
# NOTE(review): indentation is not preserved in this rendering — confirm
# whether the FINAL_TITLE event is appended unconditionally or only inside
# the branch above.
await transcripts_controller.append_event(
transcript=transcript, event="FINAL_TITLE", data=final_title
)
await topic_processing.generate_title(
topic_objects,
on_title_callback=on_title_callback,
empty_pipeline=empty_pipeline,
logger=logger,
)
finally:
await _close_db_connection(db)
logger.info("[Hatchet] generate_title complete", title=title_result)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "generate_title", "completed", ctx.workflow_run_id input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
) )
return {"title": title} return {"title": title_result}
except Exception as e: except Exception as e:
logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True) logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "generate_title", "failed", ctx.workflow_run_id input.transcript_id, "generate_title", "failed", ctx.workflow_run_id
) )
@@ -518,7 +882,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3 parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3
) )
async def generate_summary(input: PipelineInput, ctx: Context) -> dict: async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
"""Generate meeting summary using LLM.""" """Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id) logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)
await emit_progress_async( await emit_progress_async(
@@ -526,23 +890,71 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
) )
try: try:
track_data = ctx.task_output(process_tracks)
topics_data = ctx.task_output(detect_topics) topics_data = ctx.task_output(detect_topics)
words = track_data.get("all_words", [])
topics = topics_data.get("topics", []) topics = topics_data.get("topics", [])
from reflector.pipelines import topic_processing from reflector.db.transcripts import (
from reflector.processors.types import Topic, Word TranscriptFinalLongSummary,
from reflector.processors.types import Transcript as TranscriptType TranscriptFinalShortSummary,
transcripts_controller,
word_objects = [Word(**w) for w in words]
transcript = TranscriptType(words=word_objects)
topic_objects = [Topic(**t) for t in topics]
summary, short_summary = await topic_processing.generate_summary(
transcript, topic_objects
) )
from reflector.pipelines import topic_processing
from reflector.processors.types import TitleSummary
topic_objects = [TitleSummary(**t) for t in topics]
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
summary_result = None
short_summary_result = None
db = await _get_fresh_db_connection()
try:
transcript = await transcripts_controller.get_by_id(input.transcript_id)
# Callback that updates long_summary in DB (matches Celery on_long_summary)
async def on_long_summary_callback(data):
nonlocal summary_result
summary_result = data.long_summary
final_long_summary = TranscriptFinalLongSummary(
long_summary=data.long_summary
)
await transcripts_controller.update(
transcript,
{"long_summary": final_long_summary.long_summary},
)
await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_LONG_SUMMARY",
data=final_long_summary,
)
# Callback that updates short_summary in DB (matches Celery on_short_summary)
async def on_short_summary_callback(data):
nonlocal short_summary_result
short_summary_result = data.short_summary
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
await transcripts_controller.update(
transcript,
{"short_summary": final_short_summary.short_summary},
)
await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_SHORT_SUMMARY",
data=final_short_summary,
)
await topic_processing.generate_summaries(
topic_objects,
transcript, # DB transcript for context
on_long_summary_callback=on_long_summary_callback,
on_short_summary_callback=on_short_summary_callback,
empty_pipeline=empty_pipeline,
logger=logger,
)
finally:
await _close_db_connection(db)
logger.info("[Hatchet] generate_summary complete") logger.info("[Hatchet] generate_summary complete")
@@ -550,10 +962,11 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
) )
return {"summary": summary, "short_summary": short_summary} return {"summary": summary_result, "short_summary": short_summary_result}
except Exception as e: except Exception as e:
logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True) logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id
) )
@@ -566,7 +979,11 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
retries=3, retries=3,
) )
async def finalize(input: PipelineInput, ctx: Context) -> dict: async def finalize(input: PipelineInput, ctx: Context) -> dict:
"""Finalize transcript status and update database.""" """Finalize transcript: save words, emit TRANSCRIPT event, set status to 'ended'.
Matches Celery's on_transcript + set_status behavior.
Note: Title and summaries are already saved by their respective task callbacks.
"""
logger.info("[Hatchet] finalize", transcript_id=input.transcript_id) logger.info("[Hatchet] finalize", transcript_id=input.transcript_id)
await emit_progress_async( await emit_progress_async(
@@ -574,21 +991,17 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
) )
try: try:
title_data = ctx.task_output(generate_title)
summary_data = ctx.task_output(generate_summary)
mixdown_data = ctx.task_output(mixdown_tracks) mixdown_data = ctx.task_output(mixdown_tracks)
track_data = ctx.task_output(process_tracks) track_data = ctx.task_output(process_tracks)
title = title_data.get("title", "")
summary = summary_data.get("summary", "")
short_summary = summary_data.get("short_summary", "")
duration = mixdown_data.get("duration", 0) duration = mixdown_data.get("duration", 0)
all_words = track_data.get("all_words", []) all_words = track_data.get("all_words", [])
db = await _get_fresh_db_connection() db = await _get_fresh_db_connection()
try: try:
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import TranscriptText, transcripts_controller
from reflector.processors.types import Transcript as TranscriptType
from reflector.processors.types import Word from reflector.processors.types import Word
transcript = await transcripts_controller.get_by_id(input.transcript_id) transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -600,18 +1013,32 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
# Convert words back to Word objects for storage # Convert words back to Word objects for storage
word_objects = [Word(**w) for w in all_words] word_objects = [Word(**w) for w in all_words]
# Create merged transcript for TRANSCRIPT event (matches Celery line 734-736)
merged_transcript = TranscriptType(words=word_objects, translation=None)
# Emit TRANSCRIPT event (matches Celery on_transcript callback)
await transcripts_controller.append_event(
transcript=transcript,
event="TRANSCRIPT",
data=TranscriptText(
text=merged_transcript.text,
translation=merged_transcript.translation,
),
)
# Save duration and clear workflow_run_id (workflow completed successfully)
# Note: title/long_summary/short_summary already saved by their callbacks
await transcripts_controller.update( await transcripts_controller.update(
transcript, transcript,
{ {
"status": "ended",
"title": title,
"long_summary": summary,
"short_summary": short_summary,
"duration": duration, "duration": duration,
"words": word_objects, "workflow_run_id": None, # Clear on success - no need to resume
}, },
) )
# Set status to "ended" (matches Celery line 745)
await transcripts_controller.set_status(input.transcript_id, "ended")
logger.info( logger.info(
"[Hatchet] finalize complete", transcript_id=input.transcript_id "[Hatchet] finalize complete", transcript_id=input.transcript_id
) )
@@ -627,6 +1054,7 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
except Exception as e: except Exception as e:
logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True) logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True)
await _set_error_status(input.transcript_id)
await emit_progress_async( await emit_progress_async(
input.transcript_id, "finalize", "failed", ctx.workflow_run_id input.transcript_id, "finalize", "failed", ctx.workflow_run_id
) )

View File

@@ -166,6 +166,7 @@ class SummaryBuilder:
self.model_name: str = llm.model_name self.model_name: str = llm.model_name
self.logger = logger or structlog.get_logger() self.logger = logger or structlog.get_logger()
self.participant_instructions: str | None = None self.participant_instructions: str | None = None
self._logged_participant_instructions: bool = False
if filename: if filename:
self.read_transcript_from_file(filename) self.read_transcript_from_file(filename)
@@ -208,7 +209,9 @@ class SummaryBuilder:
def _enhance_prompt_with_participants(self, prompt: str) -> str:
    """Append the participant instructions to *prompt* when they are set.

    Args:
        prompt: The prompt text to enhance.

    Returns:
        ``prompt`` followed by a blank line and
        ``self.participant_instructions`` when those instructions are
        truthy; otherwise ``prompt`` unchanged.

    The debug message is emitted only the first time instructions are
    appended; ``self._logged_participant_instructions`` records that it has
    already been logged.
    """
    if not self.participant_instructions:
        return prompt

    # Log once instead of on every prompt that gets enhanced.
    if not self._logged_participant_instructions:
        self.logger.debug("Adding participant instructions to prompts")
        self._logged_participant_instructions = True

    return f"{prompt}\n\n{self.participant_instructions}"

View File

@@ -102,6 +102,7 @@ async def validate_transcript_for_processing(
if transcript.status == "idle": if transcript.status == "idle":
return ValidationNotReady(detail="Recording is not ready for processing") return ValidationNotReady(detail="Recording is not ready for processing")
# Check Celery tasks
if task_is_scheduled_or_active( if task_is_scheduled_or_active(
"reflector.pipelines.main_file_pipeline.task_pipeline_file_process", "reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
transcript_id=transcript.id, transcript_id=transcript.id,
@@ -111,6 +112,23 @@ async def validate_transcript_for_processing(
): ):
return ValidationAlreadyScheduled(detail="already running") return ValidationAlreadyScheduled(detail="already running")
# Check Hatchet workflows (if enabled)
if settings.HATCHET_ENABLED and transcript.workflow_run_id:
from reflector.hatchet.client import HatchetClientManager
try:
status = await HatchetClientManager.get_workflow_run_status(
transcript.workflow_run_id
)
# If workflow is running or queued, don't allow new processing
if "RUNNING" in status or "QUEUED" in status:
return ValidationAlreadyScheduled(
detail="Hatchet workflow already running"
)
except Exception:
# If we can't get status, allow processing (workflow might be gone)
pass
return ValidationOk( return ValidationOk(
recording_id=transcript.recording_id, transcript_id=transcript.id recording_id=transcript.recording_id, transcript_id=transcript.id
) )
@@ -155,7 +173,9 @@ async def prepare_transcript_processing(
) )
def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | None: def dispatch_transcript_processing(
config: ProcessingConfig, force: bool = False
) -> AsyncResult | None:
if isinstance(config, MultitrackProcessingConfig): if isinstance(config, MultitrackProcessingConfig):
# Start durable workflow if enabled (Hatchet or Conductor) # Start durable workflow if enabled (Hatchet or Conductor)
durable_started = False durable_started = False
@@ -163,12 +183,53 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | No
if settings.HATCHET_ENABLED: if settings.HATCHET_ENABLED:
import asyncio import asyncio
async def _start_hatchet(): import databases
return await HatchetClientManager.start_workflow(
from reflector.db import _database_context
from reflector.db.transcripts import transcripts_controller
async def _handle_hatchet():
db = databases.Database(settings.DATABASE_URL)
_database_context.set(db)
await db.connect()
try:
transcript = await transcripts_controller.get_by_id(
config.transcript_id
)
if transcript and transcript.workflow_run_id and not force:
can_replay = await HatchetClientManager.can_replay(
transcript.workflow_run_id
)
if can_replay:
await HatchetClientManager.replay_workflow(
transcript.workflow_run_id
)
logger.info(
"Replaying Hatchet workflow",
workflow_id=transcript.workflow_run_id,
)
return transcript.workflow_run_id
# Force: cancel old workflow if exists
if force and transcript and transcript.workflow_run_id:
await HatchetClientManager.cancel_workflow(
transcript.workflow_run_id
)
logger.info(
"Cancelled old workflow (--force)",
workflow_id=transcript.workflow_run_id,
)
await transcripts_controller.update(
transcript, {"workflow_run_id": None}
)
workflow_id = await HatchetClientManager.start_workflow(
workflow_name="DiarizationPipeline", workflow_name="DiarizationPipeline",
input_data={ input_data={
"recording_id": config.recording_id, "recording_id": config.recording_id,
"room_name": None, # Not available in reprocess path "room_name": None,
"tracks": [{"s3_key": k} for k in config.track_keys], "tracks": [{"s3_key": k} for k in config.track_keys],
"bucket_name": config.bucket_name, "bucket_name": config.bucket_name,
"transcript_id": config.transcript_id, "transcript_id": config.transcript_id,
@@ -176,25 +237,30 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | No
}, },
) )
if transcript:
await transcripts_controller.update(
transcript, {"workflow_run_id": workflow_id}
)
return workflow_id
finally:
await db.disconnect()
_database_context.set(None)
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
except RuntimeError: except RuntimeError:
loop = None loop = None
if loop and loop.is_running(): if loop and loop.is_running():
# Already in async context
import concurrent.futures import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool: with concurrent.futures.ThreadPoolExecutor() as pool:
workflow_id = pool.submit(asyncio.run, _start_hatchet()).result() workflow_id = pool.submit(asyncio.run, _handle_hatchet()).result()
else: else:
workflow_id = asyncio.run(_start_hatchet()) workflow_id = asyncio.run(_handle_hatchet())
logger.info( logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
"Started Hatchet workflow (reprocess)",
workflow_id=workflow_id,
transcript_id=config.transcript_id,
)
durable_started = True durable_started = True
elif settings.CONDUCTOR_ENABLED: elif settings.CONDUCTOR_ENABLED:

View File

@@ -34,21 +34,25 @@ async def process_transcript_inner(
transcript: Transcript, transcript: Transcript,
on_validation: Callable[[ValidationResult], None], on_validation: Callable[[ValidationResult], None],
on_preprocess: Callable[[PrepareResult], None], on_preprocess: Callable[[PrepareResult], None],
force: bool = False,
) -> AsyncResult: ) -> AsyncResult:
validation = await validate_transcript_for_processing(transcript) validation = await validate_transcript_for_processing(transcript)
on_validation(validation) on_validation(validation)
config = await prepare_transcript_processing(validation, room_id=transcript.room_id) config = await prepare_transcript_processing(validation, room_id=transcript.room_id)
on_preprocess(config) on_preprocess(config)
return dispatch_transcript_processing(config) return dispatch_transcript_processing(config, force=force)
async def process_transcript(transcript_id: str, sync: bool = False) -> None: async def process_transcript(
transcript_id: str, sync: bool = False, force: bool = False
) -> None:
""" """
Process a transcript by ID, auto-detecting multitrack vs file pipeline. Process a transcript by ID, auto-detecting multitrack vs file pipeline.
Args: Args:
transcript_id: The transcript UUID transcript_id: The transcript UUID
sync: If True, wait for task completion. If False, dispatch and exit. sync: If True, wait for task completion. If False, dispatch and exit.
force: If True, cancel old workflow and start new (latest code). If False, replay failed workflow.
""" """
from reflector.db import get_database from reflector.db import get_database
@@ -82,7 +86,10 @@ async def process_transcript(transcript_id: str, sync: bool = False) -> None:
print(f"Dispatching file pipeline", file=sys.stderr) print(f"Dispatching file pipeline", file=sys.stderr)
result = await process_transcript_inner( result = await process_transcript_inner(
transcript, on_validation=on_validation, on_preprocess=on_preprocess transcript,
on_validation=on_validation,
on_preprocess=on_preprocess,
force=force,
) )
if sync: if sync:
@@ -118,9 +125,16 @@ def main():
action="store_true", action="store_true",
help="Wait for task completion instead of just dispatching", help="Wait for task completion instead of just dispatching",
) )
parser.add_argument(
"--force",
action="store_true",
help="Cancel old workflow and start new (uses latest code instead of replaying)",
)
args = parser.parse_args() args = parser.parse_args()
asyncio.run(process_transcript(args.transcript_id, sync=args.sync)) asyncio.run(
process_transcript(args.transcript_id, sync=args.sync, force=args.force)
)
if __name__ == "__main__": if __name__ == "__main__":