Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-21 04:39:06 +00:00

Commit: hatchet no-mistake
@@ -0,0 +1,28 @@
+"""add workflow_run_id to transcript
+
+Revision ID: 0f943fede0e0
+Revises: a326252ac554
+Create Date: 2025-12-16 01:54:13.855106
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "0f943fede0e0"
+down_revision: Union[str, None] = "a326252ac554"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("transcript", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("workflow_run_id", sa.String(), nullable=True))
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("transcript", schema=None) as batch_op:
+        batch_op.drop_column("workflow_run_id")
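Note: `op.batch_alter_table` is what Alembic uses for SQLite compatibility (batch mode recreates the table, since SQLite's ALTER support is limited); on PostgreSQL it degrades to a plain ALTER TABLE. A minimal sketch of applying this revision programmatically — the `alembic.ini` path is an assumption, not shown in the commit:

```python
# Sketch only: apply the migration with Alembic's command API.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed config location
command.upgrade(cfg, "0f943fede0e0")  # or "head"
```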
@@ -83,6 +83,8 @@ transcripts = sqlalchemy.Table(
     sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
     sqlalchemy.Column("room_id", sqlalchemy.String),
     sqlalchemy.Column("webvtt", sqlalchemy.Text),
+    # Hatchet workflow run ID for resumption of failed workflows
+    sqlalchemy.Column("workflow_run_id", sqlalchemy.String),
     sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
     sqlalchemy.Index("idx_transcript_user_id", "user_id"),
     sqlalchemy.Index("idx_transcript_created_at", "created_at"),
@@ -227,6 +229,7 @@ class Transcript(BaseModel):
     zulip_message_id: int | None = None
     audio_deleted: bool | None = None
     webvtt: str | None = None
+    workflow_run_id: str | None = None  # Hatchet workflow run ID for resumption
 
     @field_serializer("created_at", when_used="json")
     def serialize_datetime(self, dt: datetime) -> str:
@@ -2,6 +2,7 @@
 
 from hatchet_sdk import Hatchet
 
+from reflector.logger import logger
 from reflector.settings import settings
 
 
@@ -35,9 +36,44 @@ class HatchetClientManager:
         # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
         return result.run.metadata.id
 
+    @classmethod
+    async def get_workflow_run_status(cls, workflow_run_id: str) -> str:
+        """Get workflow run status."""
+        client = cls.get_client()
+        status = await client.runs.aio_get_status(workflow_run_id)
+        return str(status)
+
+    @classmethod
+    async def cancel_workflow(cls, workflow_run_id: str) -> None:
+        """Cancel a workflow."""
+        client = cls.get_client()
+        await client.runs.aio_cancel(workflow_run_id)
+        logger.info("[Hatchet] Cancelled workflow", workflow_run_id=workflow_run_id)
+
+    @classmethod
+    async def replay_workflow(cls, workflow_run_id: str) -> None:
+        """Replay a failed workflow."""
+        client = cls.get_client()
+        await client.runs.aio_replay(workflow_run_id)
+        logger.info("[Hatchet] Replaying workflow", workflow_run_id=workflow_run_id)
+
+    @classmethod
+    async def can_replay(cls, workflow_run_id: str) -> bool:
+        """Check if workflow can be replayed (is FAILED)."""
+        try:
+            status = await cls.get_workflow_run_status(workflow_run_id)
+            return "FAILED" in status
+        except Exception as e:
+            logger.warning(
+                "[Hatchet] Failed to check replay status",
+                workflow_run_id=workflow_run_id,
+                error=str(e),
+            )
+            return False
+
     @classmethod
     async def get_workflow_status(cls, workflow_run_id: str) -> dict:
-        """Get the current status of a workflow run."""
+        """Get the full workflow run details as dict."""
         client = cls.get_client()
         run = await client.runs.aio_get(workflow_run_id)
         return run.to_dict()
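The four new helpers compose into a resume-or-restart policy for the dispatch code later in this commit. A sketch (not in the commit) of how they fit together; `start_new` is a hypothetical factory that starts a fresh workflow and returns its run ID:

```python
# Sketch only: replay a failed run if Hatchet still knows about it,
# otherwise start fresh.
async def resume_or_restart(stored_run_id, start_new):
    if stored_run_id and await HatchetClientManager.can_replay(stored_run_id):
        await HatchetClientManager.replay_workflow(stored_run_id)
        return stored_run_id
    return await start_new()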
@@ -71,6 +71,28 @@ async def _close_db_connection(db):
     _database_context.set(None)
 
 
+async def _set_error_status(transcript_id: str):
+    """Set transcript status to 'error' on workflow failure (matches Celery line 790)."""
+    try:
+        db = await _get_fresh_db_connection()
+        try:
+            from reflector.db.transcripts import transcripts_controller
+
+            await transcripts_controller.set_status(transcript_id, "error")
+            logger.info(
+                "[Hatchet] Set transcript status to error",
+                transcript_id=transcript_id,
+            )
+        finally:
+            await _close_db_connection(db)
+    except Exception as e:
+        logger.error(
+            "[Hatchet] Failed to set error status",
+            transcript_id=transcript_id,
+            error=str(e),
+        )
+
+
 def _get_storage():
     """Create fresh storage instance."""
     from reflector.settings import settings
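Each task below opens a short-lived DB connection and closes it in a `finally`. The same idea can be written once as an async context manager — an illustrative wrapper over the commit's helpers, not part of the diff:

```python
from contextlib import asynccontextmanager

@asynccontextmanager
async def fresh_db():
    # Illustration only: connection-per-task, always closed on exit.
    db = await _get_fresh_db_connection()
    try:
        yield db
    finally:
        await _close_db_connection(db)
```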
@@ -98,6 +120,21 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
         input.transcript_id, "get_recording", "in_progress", ctx.workflow_run_id
     )
 
+    # Set transcript status to "processing" at workflow start (matches Celery behavior)
+    db = await _get_fresh_db_connection()
+    try:
+        from reflector.db.transcripts import transcripts_controller
+
+        transcript = await transcripts_controller.get_by_id(input.transcript_id)
+        if transcript:
+            await transcripts_controller.set_status(input.transcript_id, "processing")
+            logger.info(
+                "[Hatchet] Set transcript status to processing",
+                transcript_id=input.transcript_id,
+            )
+    finally:
+        await _close_db_connection(db)
+
     try:
         from reflector.dailyco_api.client import DailyApiClient
         from reflector.settings import settings
@@ -140,6 +177,7 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
 
     except Exception as e:
         logger.error("[Hatchet] get_recording failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "get_recording", "failed", ctx.workflow_run_id
         )
@@ -150,7 +188,10 @@ async def get_recording(input: PipelineInput, ctx: Context) -> dict:
     parents=[get_recording], execution_timeout=timedelta(seconds=60), retries=3
 )
 async def get_participants(input: PipelineInput, ctx: Context) -> dict:
-    """Fetch participant list from Daily.co API."""
+    """Fetch participant list from Daily.co API and update transcript in database.
+
+    Matches Celery's update_participants_from_daily() behavior.
+    """
     logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -163,38 +204,118 @@ async def get_participants(input: PipelineInput, ctx: Context) -> dict:
 
         from reflector.dailyco_api.client import DailyApiClient
         from reflector.settings import settings
+        from reflector.utils.daily import (
+            filter_cam_audio_tracks,
+            parse_daily_recording_filename,
+        )
+
+        # Get transcript and reset events/topics/participants (matches Celery line 599-607)
+        db = await _get_fresh_db_connection()
+        try:
+            from reflector.db.transcripts import (
+                TranscriptParticipant,
+                transcripts_controller,
+            )
+
+            transcript = await transcripts_controller.get_by_id(input.transcript_id)
+            if transcript:
+                # Reset events/topics/participants (matches Celery line 599-607)
+                # Note: title NOT cleared - Celery preserves existing titles
+                await transcripts_controller.update(
+                    transcript,
+                    {
+                        "events": [],
+                        "topics": [],
+                        "participants": [],
+                    },
+                )
+
             if not mtg_session_id or not settings.DAILY_API_KEY:
-                # Return empty participants if no session ID
                 await emit_progress_async(
                     input.transcript_id,
                     "get_participants",
                     "completed",
                     ctx.workflow_run_id,
                 )
-                return {"participants": [], "num_tracks": len(input.tracks)}
+                return {
+                    "participants": [],
+                    "num_tracks": len(input.tracks),
+                    "source_language": transcript.source_language
+                    if transcript
+                    else "en",
+                    "target_language": transcript.target_language
+                    if transcript
+                    else "en",
+                }
 
+            # Fetch participants from Daily API
             async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
                 participants = await client.get_meeting_participants(mtg_session_id)
 
-            participants_list = [
-                {"participant_id": p.participant_id, "user_name": p.user_name}
-                for p in participants.data
-            ]
+            id_to_name = {}
+            id_to_user_id = {}
+            for p in participants.data:
+                if p.user_name:
+                    id_to_name[p.participant_id] = p.user_name
+                if p.user_id:
+                    id_to_user_id[p.participant_id] = p.user_id
+
+            # Get track keys and filter for cam-audio tracks
+            track_keys = [t["s3_key"] for t in input.tracks]
+            cam_audio_keys = filter_cam_audio_tracks(track_keys)
+
+            # Update participants in database (matches Celery lines 568-590)
+            participants_list = []
+            for idx, key in enumerate(cam_audio_keys):
+                try:
+                    parsed = parse_daily_recording_filename(key)
+                    participant_id = parsed.participant_id
+                except ValueError as e:
+                    logger.error(
+                        "Failed to parse Daily recording filename",
+                        error=str(e),
+                        key=key,
+                    )
+                    continue
+
+                default_name = f"Speaker {idx}"
+                name = id_to_name.get(participant_id, default_name)
+                user_id = id_to_user_id.get(participant_id)
+
+                participant = TranscriptParticipant(
+                    id=participant_id, speaker=idx, name=name, user_id=user_id
+                )
+                await transcripts_controller.upsert_participant(transcript, participant)
+                participants_list.append(
+                    {
+                        "participant_id": participant_id,
+                        "user_name": name,
+                        "speaker": idx,
+                    }
+                )
 
             logger.info(
                 "[Hatchet] get_participants complete",
                 participant_count=len(participants_list),
             )
 
+        finally:
+            await _close_db_connection(db)
+
         await emit_progress_async(
             input.transcript_id, "get_participants", "completed", ctx.workflow_run_id
         )
 
-        return {"participants": participants_list, "num_tracks": len(input.tracks)}
+        return {
+            "participants": participants_list,
+            "num_tracks": len(input.tracks),
+            "source_language": transcript.source_language if transcript else "en",
+            "target_language": transcript.target_language if transcript else "en",
+        }
 
     except Exception as e:
         logger.error("[Hatchet] get_participants failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "get_participants", "failed", ctx.workflow_run_id
        )
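The naming step above keys each cam-audio track on the participant ID parsed from its filename and falls back to a positional label. The same fallback, isolated as a toy helper (illustration only, not the commit's code):

```python
# Given participant IDs in track order, prefer the name Daily reported
# and fall back to "Speaker {idx}" when the participant is unknown.
def build_participants(ids: list[str], id_to_name: dict[str, str]) -> list[dict]:
    return [
        {
            "participant_id": pid,
            "speaker": idx,
            "user_name": id_to_name.get(pid, f"Speaker {idx}"),
        }
        for idx, pid in enumerate(ids)
    ]
```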
@@ -215,7 +336,12 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
         transcript_id=input.transcript_id,
     )
 
-    # Spawn child workflows for each track
+    try:
+        # Get source_language from get_participants (matches Celery: uses transcript.source_language)
+        participants_data = ctx.task_output(get_participants)
+        source_language = participants_data.get("source_language", "en")
+
+        # Spawn child workflows for each track with correct language
         child_coroutines = [
             track_workflow.aio_run(
                 TrackInput(
@@ -223,6 +349,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
                     s3_key=track["s3_key"],
                     bucket_name=input.bucket_name,
                     transcript_id=input.transcript_id,
+                    language=source_language,
                 )
             )
             for i, track in enumerate(input.tracks)
@@ -231,9 +358,13 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
         # Wait for all child workflows to complete
         results = await asyncio.gather(*child_coroutines)
 
+        # Get target_language for later use in detect_topics
+        target_language = participants_data.get("target_language", "en")
+
         # Collect all track results
         all_words = []
         padded_urls = []
+        created_padded_files = set()
 
         for result in results:
             transcribe_result = result.get("transcribe_track", {})
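The fan-out shape above — one child workflow per track, awaited together — reduces to a few lines. A sketch (not in the commit; `make_input` is a hypothetical factory building one `TrackInput` per track):

```python
import asyncio

async def fan_out(tracks, child_workflow, make_input):
    # One aio_run(...) coroutine per track; gather preserves input order,
    # so results[i] corresponds to tracks[i].
    return await asyncio.gather(
        *(child_workflow.aio_run(make_input(i, t)) for i, t in enumerate(tracks))
    )
```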
@@ -242,9 +373,19 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
             pad_result = result.get("pad_track", {})
             padded_urls.append(pad_result.get("padded_url"))
 
+            # Track padded files for cleanup (matches Celery line 636-637)
+            track_index = pad_result.get("track_index")
+            if pad_result.get("size", 0) > 0 and track_index is not None:
+                # File was created (size > 0 means padding was applied)
+                storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{track_index}.webm"
+                created_padded_files.add(storage_path)
+
         # Sort words by start time
         all_words.sort(key=lambda w: w.get("start", 0))
 
+        # NOTE: Cleanup of padded S3 files moved to generate_waveform (after mixdown completes)
+        # Mixdown needs the padded files, so we can't delete them here
+
         logger.info(
             "[Hatchet] process_tracks complete",
             num_tracks=len(input.tracks),
@@ -256,14 +397,26 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> dict:
             "padded_urls": padded_urls,
             "word_count": len(all_words),
             "num_tracks": len(input.tracks),
+            "target_language": target_language,
+            "created_padded_files": list(
+                created_padded_files
+            ),  # For cleanup after mixdown
         }
 
+    except Exception as e:
+        logger.error("[Hatchet] process_tracks failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
+        await emit_progress_async(
+            input.transcript_id, "process_tracks", "failed", ctx.workflow_run_id
+        )
+        raise
+
 
 @diarization_pipeline.task(
     parents=[process_tracks], execution_timeout=timedelta(seconds=300), retries=3
 )
 async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
-    """Mix all padded tracks into single audio file."""
+    """Mix all padded tracks into single audio file using PyAV (same as Celery)."""
     logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -279,93 +432,198 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
 
         storage = _get_storage()
 
-        # Download all tracks and mix
-        temp_inputs = []
-        try:
-            for i, url in enumerate(padded_urls):
-                if not url:
-                    continue
-                temp_input = tempfile.NamedTemporaryFile(suffix=".webm", delete=False)
-                temp_inputs.append(temp_input.name)
-
-                # Download track
-                import httpx
-
-                async with httpx.AsyncClient() as client:
-                    response = await client.get(url)
-                    response.raise_for_status()
-                    with open(temp_input.name, "wb") as f:
-                        f.write(response.content)
-
-            # Mix using PyAV amix filter
-            if len(temp_inputs) == 0:
-                raise ValueError("No valid tracks to mixdown")
-
-            output_path = tempfile.mktemp(suffix=".mp3")
-
-            try:
-                # Use ffmpeg-style mixing via PyAV
-                containers = [av.open(path) for path in temp_inputs]
-
-                # Get the longest duration
-                max_duration = 0.0
-                for container in containers:
-                    if container.duration:
-                        duration = float(container.duration * av.time_base)
-                        max_duration = max(max_duration, duration)
-
-                # Close containers for now
-                for container in containers:
-                    container.close()
-
-                # Use subprocess for mixing (simpler than complex PyAV graph)
-                import subprocess
-
-                # Build ffmpeg command
-                cmd = ["ffmpeg", "-y"]
-                for path in temp_inputs:
-                    cmd.extend(["-i", path])
-
-                # Build filter for N inputs
-                n = len(temp_inputs)
-                filter_str = f"amix=inputs={n}:duration=longest:normalize=0"
-                cmd.extend(["-filter_complex", filter_str])
-                cmd.extend(["-ac", "2", "-ar", "48000", "-b:a", "128k", output_path])
-
-                subprocess.run(cmd, check=True, capture_output=True)
-
-                # Upload mixed file
+        # Use PipelineMainMultitrack.mixdown_tracks which uses PyAV filter graph
+        from fractions import Fraction
+
+        from av.audio.resampler import AudioResampler
+
+        from reflector.processors import AudioFileWriterProcessor
+
+        valid_urls = [url for url in padded_urls if url]
+        if not valid_urls:
+            raise ValueError("No valid padded tracks to mixdown")
+
+        # Determine target sample rate from first track
+        target_sample_rate = None
+        for url in valid_urls:
+            try:
+                container = av.open(url)
+                for frame in container.decode(audio=0):
+                    target_sample_rate = frame.sample_rate
+                    break
+                container.close()
+                if target_sample_rate:
+                    break
+            except Exception:
+                continue
+
+        if not target_sample_rate:
+            raise ValueError("No decodable audio frames in any track")
+
+        # Build PyAV filter graph: N abuffer -> amix -> aformat -> sink
+        graph = av.filter.Graph()
+        inputs = []
+
+        for idx, url in enumerate(valid_urls):
+            args = (
+                f"time_base=1/{target_sample_rate}:"
+                f"sample_rate={target_sample_rate}:"
+                f"sample_fmt=s32:"
+                f"channel_layout=stereo"
+            )
+            in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
+            inputs.append(in_ctx)
+
+        mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
+        fmt = graph.add(
+            "aformat",
+            args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
+            name="fmt",
+        )
+        sink = graph.add("abuffersink", name="out")
+
+        for idx, in_ctx in enumerate(inputs):
+            in_ctx.link_to(mixer, 0, idx)
+        mixer.link_to(fmt)
+        fmt.link_to(sink)
+        graph.configure()
+
+        # Create temp output file
+        output_path = tempfile.mktemp(suffix=".mp3")
+        containers = []
+
+        try:
+            # Open all containers
+            for url in valid_urls:
+                try:
+                    c = av.open(
+                        url,
+                        options={
+                            "reconnect": "1",
+                            "reconnect_streamed": "1",
+                            "reconnect_delay_max": "5",
+                        },
+                    )
+                    containers.append(c)
+                except Exception as e:
+                    logger.warning(
+                        "[Hatchet] mixdown: failed to open container",
+                        url=url,
+                        error=str(e),
+                    )
+
+            if not containers:
+                raise ValueError("Could not open any track containers")
+
+            # Create AudioFileWriterProcessor for MP3 output with duration capture
+            duration_ms = [0.0]  # Mutable container for callback capture
+
+            async def capture_duration(d):
+                duration_ms[0] = d
+
+            writer = AudioFileWriterProcessor(
+                path=output_path, on_duration=capture_duration
+            )
+
+            decoders = [c.decode(audio=0) for c in containers]
+            active = [True] * len(decoders)
+            resamplers = [
+                AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
+                for _ in decoders
+            ]
+
+            while any(active):
+                for i, (dec, is_active) in enumerate(zip(decoders, active)):
+                    if not is_active:
+                        continue
+                    try:
+                        frame = next(dec)
+                    except StopIteration:
+                        active[i] = False
+                        inputs[i].push(None)
+                        continue
+
+                    if frame.sample_rate != target_sample_rate:
+                        continue
+                    out_frames = resamplers[i].resample(frame) or []
+                    for rf in out_frames:
+                        rf.sample_rate = target_sample_rate
+                        rf.time_base = Fraction(1, target_sample_rate)
+                        inputs[i].push(rf)
+
+                while True:
+                    try:
+                        mixed = sink.pull()
+                    except Exception:
+                        break
+                    mixed.sample_rate = target_sample_rate
+                    mixed.time_base = Fraction(1, target_sample_rate)
+                    await writer.push(mixed)
+
+            # Flush remaining frames
+            while True:
+                try:
+                    mixed = sink.pull()
+                except Exception:
+                    break
+                mixed.sample_rate = target_sample_rate
+                mixed.time_base = Fraction(1, target_sample_rate)
+                await writer.push(mixed)
+
+            await writer.flush()
+
+            # Duration is captured via callback in milliseconds (from AudioFileWriterProcessor)
+
+        finally:
+            for c in containers:
+                try:
+                    c.close()
+                except Exception:
+                    pass
+
+        # Upload mixed file to correct path (matches Celery: {transcript.id}/audio.mp3)
         file_size = Path(output_path).stat().st_size
-        storage_path = f"file_pipeline_hatchet/{input.transcript_id}/mixed.mp3"
+        storage_path = f"{input.transcript_id}/audio.mp3"
 
         with open(output_path, "rb") as mixed_file:
             await storage.put_file(storage_path, mixed_file)
 
+        Path(output_path).unlink(missing_ok=True)
+
+        # Update transcript with audio_location (matches Celery line 661)
+        db = await _get_fresh_db_connection()
+        try:
+            from reflector.db.transcripts import transcripts_controller
+
+            transcript = await transcripts_controller.get_by_id(input.transcript_id)
+            if transcript:
+                await transcripts_controller.update(
+                    transcript, {"audio_location": "storage"}
+                )
+        finally:
+            await _close_db_connection(db)
+
         logger.info(
             "[Hatchet] mixdown_tracks uploaded",
             key=storage_path,
             size=file_size,
         )
 
-            finally:
-                Path(output_path).unlink(missing_ok=True)
-
-        finally:
-            for path in temp_inputs:
-                Path(path).unlink(missing_ok=True)
-
         await emit_progress_async(
             input.transcript_id, "mixdown_tracks", "completed", ctx.workflow_run_id
         )
 
         return {
             "audio_key": storage_path,
-            "duration": max_duration,
-            "tracks_mixed": len(temp_inputs),
+            "duration": duration_ms[
+                0
+            ],  # Duration in milliseconds from AudioFileWriterProcessor
+            "tracks_mixed": len(valid_urls),
         }
 
     except Exception as e:
         logger.error("[Hatchet] mixdown_tracks failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "mixdown_tracks", "failed", ctx.workflow_run_id
         )
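The new mixdown replaces the ffmpeg subprocess with an in-process PyAV filter graph: one `abuffer` per input, `amix`, `aformat`, `abuffersink`, with decoded frames resampled to a common rate before being pushed in. A reduced, self-contained sketch of the same graph shape — a toy version that drains each whole file before the next (fine for short clips; the task above interleaves inputs and writes through `AudioFileWriterProcessor`):

```python
import av
from fractions import Fraction
from av.audio.resampler import AudioResampler

def mix_frames(paths: list[str], rate: int = 48000):
    """Yield mixed frames from N files via abuffer -> amix -> abuffersink."""
    graph = av.filter.Graph()
    args = (
        f"time_base=1/{rate}:sample_rate={rate}:"
        "sample_fmt=s32:channel_layout=stereo"
    )
    srcs = [graph.add("abuffer", args=args, name=f"in{i}") for i in range(len(paths))]
    mix = graph.add("amix", args=f"inputs={len(paths)}:normalize=0", name="mix")
    sink = graph.add("abuffersink", name="out")
    for i, s in enumerate(srcs):
        s.link_to(mix, 0, i)
    mix.link_to(sink)
    graph.configure()

    # One stateful resampler per input, as in the task above.
    resamplers = [AudioResampler(format="s32", layout="stereo", rate=rate) for _ in paths]
    for src, res, path in zip(srcs, resamplers, paths):
        with av.open(path) as container:
            for frame in container.decode(audio=0):
                for rf in res.resample(frame):
                    rf.time_base = Fraction(1, rate)
                    src.push(rf)
        src.push(None)  # signal end-of-stream on this input

    while True:
        try:
            yield sink.pull()
        except Exception:  # PyAV raises EAGAIN/EOF-style errors when drained
            return
```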
@@ -376,7 +634,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> dict:
     parents=[mixdown_tracks], execution_timeout=timedelta(seconds=120), retries=3
 )
 async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
-    """Generate audio waveform visualization."""
+    """Generate audio waveform visualization using AudioWaveformProcessor (matches Celery)."""
     logger.info("[Hatchet] generate_waveform", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -384,6 +642,35 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
     )
 
     try:
+        import httpx
+
+        from reflector.db.transcripts import TranscriptWaveform, transcripts_controller
+        from reflector.utils.audio_waveform import get_audio_waveform
+
+        # Cleanup temporary padded S3 files (matches Celery lines 710-725)
+        # Moved here from process_tracks because mixdown_tracks needs the padded files
+        track_data = ctx.task_output(process_tracks)
+        created_padded_files = track_data.get("created_padded_files", [])
+        if created_padded_files:
+            logger.info(
+                f"[Hatchet] Cleaning up {len(created_padded_files)} temporary S3 files"
+            )
+            storage = _get_storage()
+            cleanup_tasks = []
+            for storage_path in created_padded_files:
+                cleanup_tasks.append(storage.delete_file(storage_path))
+
+            cleanup_results = await asyncio.gather(
+                *cleanup_tasks, return_exceptions=True
+            )
+            for storage_path, result in zip(created_padded_files, cleanup_results):
+                if isinstance(result, Exception):
+                    logger.warning(
+                        "[Hatchet] Failed to cleanup temporary padded track",
+                        storage_path=storage_path,
+                        error=str(result),
+                    )
+
         mixdown_data = ctx.task_output(mixdown_tracks)
         audio_key = mixdown_data.get("audio_key")
 
@@ -394,18 +681,34 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
             expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
         )
 
-        from reflector.pipelines.waveform_helpers import generate_waveform_data
-
-        waveform = await generate_waveform_data(audio_url)
-
-        # Store waveform
-        waveform_key = f"file_pipeline_hatchet/{input.transcript_id}/waveform.json"
-        import json
-
-        waveform_bytes = json.dumps(waveform).encode()
-        import io
-
-        await storage.put_file(waveform_key, io.BytesIO(waveform_bytes))
+        # Download MP3 to temp file (AudioWaveformProcessor needs local file)
+        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
+            temp_path = temp_file.name
+
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.get(audio_url, timeout=120)
+                response.raise_for_status()
+                with open(temp_path, "wb") as f:
+                    f.write(response.content)
+
+            # Generate waveform (matches Celery: get_audio_waveform with 255 segments)
+            waveform = get_audio_waveform(path=Path(temp_path), segments_count=255)
+
+            # Save waveform to database via event (matches Celery on_waveform callback)
+            db = await _get_fresh_db_connection()
+            try:
+                transcript = await transcripts_controller.get_by_id(input.transcript_id)
+                if transcript:
+                    waveform_data = TranscriptWaveform(waveform=waveform)
+                    await transcripts_controller.append_event(
+                        transcript=transcript, event="WAVEFORM", data=waveform_data
+                    )
+            finally:
+                await _close_db_connection(db)
+        finally:
+            Path(temp_path).unlink(missing_ok=True)
 
         logger.info("[Hatchet] generate_waveform complete")
 
@@ -413,10 +716,11 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
             input.transcript_id, "generate_waveform", "completed", ctx.workflow_run_id
         )
 
-        return {"waveform_key": waveform_key}
+        return {"waveform_generated": True}
 
     except Exception as e:
         logger.error("[Hatchet] generate_waveform failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "generate_waveform", "failed", ctx.workflow_run_id
         )
@@ -427,7 +731,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> dict:
     parents=[mixdown_tracks], execution_timeout=timedelta(seconds=300), retries=3
 )
 async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
-    """Detect topics using LLM."""
+    """Detect topics using LLM and save to database (matches Celery on_topic callback)."""
     logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -437,26 +741,52 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
     try:
         track_data = ctx.task_output(process_tracks)
         words = track_data.get("all_words", [])
+        target_language = track_data.get("target_language", "en")
 
+        from reflector.db.transcripts import TranscriptTopic, transcripts_controller
         from reflector.pipelines import topic_processing
+        from reflector.processors.types import (
+            TitleSummaryWithId as TitleSummaryWithIdProcessorType,
+        )
         from reflector.processors.types import Transcript as TranscriptType
         from reflector.processors.types import Word
 
         # Convert word dicts to Word objects
         word_objects = [Word(**w) for w in words]
-        transcript = TranscriptType(words=word_objects)
+        transcript_type = TranscriptType(words=word_objects)
 
         empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
 
-        async def noop_callback(t):
-            pass
+        # Get DB connection for callbacks
+        db = await _get_fresh_db_connection()
+
+        try:
+            transcript = await transcripts_controller.get_by_id(input.transcript_id)
+
+            # Callback that upserts topics to DB (matches Celery on_topic)
+            async def on_topic_callback(data):
+                topic = TranscriptTopic(
+                    title=data.title,
+                    summary=data.summary,
+                    timestamp=data.timestamp,
+                    transcript=data.transcript.text,
+                    words=data.transcript.words,
+                )
+                if isinstance(data, TitleSummaryWithIdProcessorType):
+                    topic.id = data.id
+                await transcripts_controller.upsert_topic(transcript, topic)
+                await transcripts_controller.append_event(
+                    transcript=transcript, event="TOPIC", data=topic
+                )
 
             topics = await topic_processing.detect_topics(
-                transcript,
-                "en",  # target_language
-                on_topic_callback=noop_callback,
+                transcript_type,
+                target_language,
+                on_topic_callback=on_topic_callback,
                 empty_pipeline=empty_pipeline,
             )
+        finally:
+            await _close_db_connection(db)
 
         topics_list = [t.model_dump() for t in topics]
 
@@ -470,6 +800,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
 
     except Exception as e:
         logger.error("[Hatchet] detect_topics failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "detect_topics", "failed", ctx.workflow_run_id
         )
@@ -480,7 +811,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> dict:
     parents=[detect_topics], execution_timeout=timedelta(seconds=120), retries=3
 )
 async def generate_title(input: PipelineInput, ctx: Context) -> dict:
-    """Generate meeting title using LLM."""
+    """Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
    logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -491,23 +822,56 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
         topics_data = ctx.task_output(detect_topics)
         topics = topics_data.get("topics", [])
 
+        from reflector.db.transcripts import (
+            TranscriptFinalTitle,
+            transcripts_controller,
+        )
         from reflector.pipelines import topic_processing
-        from reflector.processors.types import Topic
+        from reflector.processors.types import TitleSummary
 
-        topic_objects = [Topic(**t) for t in topics]
+        topic_objects = [TitleSummary(**t) for t in topics]
 
-        title = await topic_processing.generate_title(topic_objects)
+        empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
+        title_result = None
 
-        logger.info("[Hatchet] generate_title complete", title=title)
+        db = await _get_fresh_db_connection()
+        try:
+            transcript = await transcripts_controller.get_by_id(input.transcript_id)
+
+            # Callback that updates title in DB (matches Celery on_title)
+            async def on_title_callback(data):
+                nonlocal title_result
+                title_result = data.title
+                final_title = TranscriptFinalTitle(title=data.title)
+                if not transcript.title:
+                    await transcripts_controller.update(
+                        transcript,
+                        {"title": final_title.title},
+                    )
+                await transcripts_controller.append_event(
+                    transcript=transcript, event="FINAL_TITLE", data=final_title
+                )
+
+            await topic_processing.generate_title(
+                topic_objects,
+                on_title_callback=on_title_callback,
+                empty_pipeline=empty_pipeline,
+                logger=logger,
+            )
+        finally:
+            await _close_db_connection(db)
+
+        logger.info("[Hatchet] generate_title complete", title=title_result)
 
         await emit_progress_async(
             input.transcript_id, "generate_title", "completed", ctx.workflow_run_id
         )
 
-        return {"title": title}
+        return {"title": title_result}
 
     except Exception as e:
         logger.error("[Hatchet] generate_title failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "generate_title", "failed", ctx.workflow_run_id
         )
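The title (and, below, summary) tasks capture their callback's value through `nonlocal` so the task can both persist to the DB inside the callback and still return the value. The idiom, isolated as a sketch (not in the commit):

```python
# `generate` stands in for topic_processing.generate_title: it invokes the
# callback with a result at some point during its run.
async def run_with_capture(generate):
    captured = None

    async def on_result(data):
        nonlocal captured
        captured = data  # visible in the enclosing scope after generate() returns

    await generate(on_result)
    return captured
```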
@@ -518,7 +882,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> dict:
     parents=[detect_topics], execution_timeout=timedelta(seconds=300), retries=3
 )
 async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
-    """Generate meeting summary using LLM."""
+    """Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
     logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -526,23 +890,71 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
     )
 
     try:
-        track_data = ctx.task_output(process_tracks)
         topics_data = ctx.task_output(detect_topics)
-
-        words = track_data.get("all_words", [])
         topics = topics_data.get("topics", [])
 
-        from reflector.pipelines import topic_processing
-        from reflector.processors.types import Topic, Word
-        from reflector.processors.types import Transcript as TranscriptType
-
-        word_objects = [Word(**w) for w in words]
-        transcript = TranscriptType(words=word_objects)
-        topic_objects = [Topic(**t) for t in topics]
-
-        summary, short_summary = await topic_processing.generate_summary(
-            transcript, topic_objects
+        from reflector.db.transcripts import (
+            TranscriptFinalLongSummary,
+            TranscriptFinalShortSummary,
+            transcripts_controller,
         )
+        from reflector.pipelines import topic_processing
+        from reflector.processors.types import TitleSummary
+
+        topic_objects = [TitleSummary(**t) for t in topics]
+
+        empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
+        summary_result = None
+        short_summary_result = None
+
+        db = await _get_fresh_db_connection()
+        try:
+            transcript = await transcripts_controller.get_by_id(input.transcript_id)
+
+            # Callback that updates long_summary in DB (matches Celery on_long_summary)
+            async def on_long_summary_callback(data):
+                nonlocal summary_result
+                summary_result = data.long_summary
+                final_long_summary = TranscriptFinalLongSummary(
+                    long_summary=data.long_summary
+                )
+                await transcripts_controller.update(
+                    transcript,
+                    {"long_summary": final_long_summary.long_summary},
+                )
+                await transcripts_controller.append_event(
+                    transcript=transcript,
+                    event="FINAL_LONG_SUMMARY",
+                    data=final_long_summary,
+                )
+
+            # Callback that updates short_summary in DB (matches Celery on_short_summary)
+            async def on_short_summary_callback(data):
+                nonlocal short_summary_result
+                short_summary_result = data.short_summary
+                final_short_summary = TranscriptFinalShortSummary(
+                    short_summary=data.short_summary
+                )
+                await transcripts_controller.update(
+                    transcript,
+                    {"short_summary": final_short_summary.short_summary},
+                )
+                await transcripts_controller.append_event(
+                    transcript=transcript,
+                    event="FINAL_SHORT_SUMMARY",
+                    data=final_short_summary,
+                )
+
+            await topic_processing.generate_summaries(
+                topic_objects,
+                transcript,  # DB transcript for context
+                on_long_summary_callback=on_long_summary_callback,
+                on_short_summary_callback=on_short_summary_callback,
+                empty_pipeline=empty_pipeline,
+                logger=logger,
+            )
+        finally:
+            await _close_db_connection(db)
 
         logger.info("[Hatchet] generate_summary complete")
 
@@ -550,10 +962,11 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
             input.transcript_id, "generate_summary", "completed", ctx.workflow_run_id
         )
 
-        return {"summary": summary, "short_summary": short_summary}
+        return {"summary": summary_result, "short_summary": short_summary_result}
 
     except Exception as e:
         logger.error("[Hatchet] generate_summary failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "generate_summary", "failed", ctx.workflow_run_id
         )
@@ -566,7 +979,11 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> dict:
     retries=3,
 )
 async def finalize(input: PipelineInput, ctx: Context) -> dict:
-    """Finalize transcript status and update database."""
+    """Finalize transcript: save words, emit TRANSCRIPT event, set status to 'ended'.
+
+    Matches Celery's on_transcript + set_status behavior.
+    Note: Title and summaries are already saved by their respective task callbacks.
+    """
     logger.info("[Hatchet] finalize", transcript_id=input.transcript_id)
 
     await emit_progress_async(
@@ -574,21 +991,17 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
     )
 
     try:
-        title_data = ctx.task_output(generate_title)
-        summary_data = ctx.task_output(generate_summary)
         mixdown_data = ctx.task_output(mixdown_tracks)
         track_data = ctx.task_output(process_tracks)
 
-        title = title_data.get("title", "")
-        summary = summary_data.get("summary", "")
-        short_summary = summary_data.get("short_summary", "")
         duration = mixdown_data.get("duration", 0)
         all_words = track_data.get("all_words", [])
 
         db = await _get_fresh_db_connection()
 
         try:
-            from reflector.db.transcripts import transcripts_controller
+            from reflector.db.transcripts import TranscriptText, transcripts_controller
+            from reflector.processors.types import Transcript as TranscriptType
             from reflector.processors.types import Word
 
             transcript = await transcripts_controller.get_by_id(input.transcript_id)
@@ -600,18 +1013,32 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
             # Convert words back to Word objects for storage
             word_objects = [Word(**w) for w in all_words]
 
+            # Create merged transcript for TRANSCRIPT event (matches Celery line 734-736)
+            merged_transcript = TranscriptType(words=word_objects, translation=None)
+
+            # Emit TRANSCRIPT event (matches Celery on_transcript callback)
+            await transcripts_controller.append_event(
+                transcript=transcript,
+                event="TRANSCRIPT",
+                data=TranscriptText(
+                    text=merged_transcript.text,
+                    translation=merged_transcript.translation,
+                ),
+            )
+
+            # Save duration and clear workflow_run_id (workflow completed successfully)
+            # Note: title/long_summary/short_summary already saved by their callbacks
             await transcripts_controller.update(
                 transcript,
                 {
-                    "status": "ended",
-                    "title": title,
-                    "long_summary": summary,
-                    "short_summary": short_summary,
                     "duration": duration,
-                    "words": word_objects,
+                    "workflow_run_id": None,  # Clear on success - no need to resume
                 },
             )
+
+            # Set status to "ended" (matches Celery line 745)
+            await transcripts_controller.set_status(input.transcript_id, "ended")
+
             logger.info(
                 "[Hatchet] finalize complete", transcript_id=input.transcript_id
             )
@@ -627,6 +1054,7 @@ async def finalize(input: PipelineInput, ctx: Context) -> dict:
 
     except Exception as e:
         logger.error("[Hatchet] finalize failed", error=str(e), exc_info=True)
+        await _set_error_status(input.transcript_id)
         await emit_progress_async(
             input.transcript_id, "finalize", "failed", ctx.workflow_run_id
         )
@@ -166,6 +166,7 @@ class SummaryBuilder:
         self.model_name: str = llm.model_name
         self.logger = logger or structlog.get_logger()
         self.participant_instructions: str | None = None
+        self._logged_participant_instructions: bool = False
         if filename:
             self.read_transcript_from_file(filename)
 
@@ -208,7 +209,9 @@ class SummaryBuilder:
     def _enhance_prompt_with_participants(self, prompt: str) -> str:
         """Add participant instructions to any prompt if participants are known."""
         if self.participant_instructions:
-            self.logger.debug("Adding participant instructions to prompt")
+            if not self._logged_participant_instructions:
+                self.logger.debug("Adding participant instructions to prompts")
+                self._logged_participant_instructions = True
             return f"{prompt}\n\n{self.participant_instructions}"
         return prompt
 
@@ -102,6 +102,7 @@ async def validate_transcript_for_processing(
     if transcript.status == "idle":
         return ValidationNotReady(detail="Recording is not ready for processing")
 
+    # Check Celery tasks
     if task_is_scheduled_or_active(
         "reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
         transcript_id=transcript.id,
@@ -111,6 +112,23 @@ async def validate_transcript_for_processing(
     ):
         return ValidationAlreadyScheduled(detail="already running")
 
+    # Check Hatchet workflows (if enabled)
+    if settings.HATCHET_ENABLED and transcript.workflow_run_id:
+        from reflector.hatchet.client import HatchetClientManager
+
+        try:
+            status = await HatchetClientManager.get_workflow_run_status(
+                transcript.workflow_run_id
+            )
+            # If workflow is running or queued, don't allow new processing
+            if "RUNNING" in status or "QUEUED" in status:
+                return ValidationAlreadyScheduled(
+                    detail="Hatchet workflow already running"
+                )
+        except Exception:
+            # If we can't get status, allow processing (workflow might be gone)
+            pass
+
     return ValidationOk(
         recording_id=transcript.recording_id, transcript_id=transcript.id
     )
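The guard's essential logic, condensed into a sketch (not in the commit): status-lookup failures are deliberately treated as "not running", so a workflow Hatchet no longer knows about never blocks reprocessing.

```python
async def hatchet_run_active(run_id: str) -> bool:
    try:
        status = await HatchetClientManager.get_workflow_run_status(run_id)
    except Exception:
        return False  # workflow may be gone; allow a new dispatch
    return "RUNNING" in status or "QUEUED" in status
```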
@@ -155,7 +173,9 @@ async def prepare_transcript_processing(
     )
 
 
-def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | None:
+def dispatch_transcript_processing(
+    config: ProcessingConfig, force: bool = False
+) -> AsyncResult | None:
     if isinstance(config, MultitrackProcessingConfig):
         # Start durable workflow if enabled (Hatchet or Conductor)
         durable_started = False
@@ -163,12 +183,53 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | None
         if settings.HATCHET_ENABLED:
             import asyncio
 
-            async def _start_hatchet():
-                return await HatchetClientManager.start_workflow(
+            import databases
+
+            from reflector.db import _database_context
+            from reflector.db.transcripts import transcripts_controller
+
+            async def _handle_hatchet():
+                db = databases.Database(settings.DATABASE_URL)
+                _database_context.set(db)
+                await db.connect()
+
+                try:
+                    transcript = await transcripts_controller.get_by_id(
+                        config.transcript_id
+                    )
+
+                    if transcript and transcript.workflow_run_id and not force:
+                        can_replay = await HatchetClientManager.can_replay(
+                            transcript.workflow_run_id
+                        )
+                        if can_replay:
+                            await HatchetClientManager.replay_workflow(
+                                transcript.workflow_run_id
+                            )
+                            logger.info(
+                                "Replaying Hatchet workflow",
+                                workflow_id=transcript.workflow_run_id,
+                            )
+                            return transcript.workflow_run_id
+
+                    # Force: cancel old workflow if exists
+                    if force and transcript and transcript.workflow_run_id:
+                        await HatchetClientManager.cancel_workflow(
+                            transcript.workflow_run_id
+                        )
+                        logger.info(
+                            "Cancelled old workflow (--force)",
+                            workflow_id=transcript.workflow_run_id,
+                        )
+                        await transcripts_controller.update(
+                            transcript, {"workflow_run_id": None}
+                        )
+
+                    workflow_id = await HatchetClientManager.start_workflow(
                         workflow_name="DiarizationPipeline",
                         input_data={
                             "recording_id": config.recording_id,
-                            "room_name": None,  # Not available in reprocess path
+                            "room_name": None,
                             "tracks": [{"s3_key": k} for k in config.track_keys],
                             "bucket_name": config.bucket_name,
                             "transcript_id": config.transcript_id,
@@ -176,25 +237,30 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | None
                         },
                     )
 
+                    if transcript:
+                        await transcripts_controller.update(
+                            transcript, {"workflow_run_id": workflow_id}
+                        )
+
+                    return workflow_id
+                finally:
+                    await db.disconnect()
+                    _database_context.set(None)
+
             try:
                 loop = asyncio.get_running_loop()
             except RuntimeError:
                 loop = None
 
             if loop and loop.is_running():
-                # Already in async context
                 import concurrent.futures
 
                 with concurrent.futures.ThreadPoolExecutor() as pool:
-                    workflow_id = pool.submit(asyncio.run, _start_hatchet()).result()
+                    workflow_id = pool.submit(asyncio.run, _handle_hatchet()).result()
             else:
-                workflow_id = asyncio.run(_start_hatchet())
+                workflow_id = asyncio.run(_handle_hatchet())
 
-            logger.info(
-                "Started Hatchet workflow (reprocess)",
-                workflow_id=workflow_id,
-                transcript_id=config.transcript_id,
-            )
+            logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
             durable_started = True
 
         elif settings.CONDUCTOR_ENABLED:
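The dispatch keeps its sync signature but must run an async coroutine, so it branches on whether an event loop is already running and, if so, runs the coroutine on a fresh loop in a worker thread. That pattern, isolated as a reusable sketch (not in the commit):

```python
import asyncio
import concurrent.futures

def run_async(coro):
    """Run `coro` to completion whether or not an event loop is already running."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    if loop and loop.is_running():
        # Can't block a running loop: hand the coroutine to a fresh loop
        # on a worker thread and wait for the result.
        with concurrent.futures.ThreadPoolExecutor() as pool:
            return pool.submit(asyncio.run, coro).result()
    return asyncio.run(coro)
```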
@@ -34,21 +34,25 @@ async def process_transcript_inner(
     transcript: Transcript,
     on_validation: Callable[[ValidationResult], None],
     on_preprocess: Callable[[PrepareResult], None],
+    force: bool = False,
 ) -> AsyncResult:
     validation = await validate_transcript_for_processing(transcript)
     on_validation(validation)
     config = await prepare_transcript_processing(validation, room_id=transcript.room_id)
     on_preprocess(config)
-    return dispatch_transcript_processing(config)
+    return dispatch_transcript_processing(config, force=force)
 
 
-async def process_transcript(transcript_id: str, sync: bool = False) -> None:
+async def process_transcript(
+    transcript_id: str, sync: bool = False, force: bool = False
+) -> None:
     """
     Process a transcript by ID, auto-detecting multitrack vs file pipeline.
 
     Args:
         transcript_id: The transcript UUID
         sync: If True, wait for task completion. If False, dispatch and exit.
+        force: If True, cancel old workflow and start new (latest code). If False, replay failed workflow.
     """
     from reflector.db import get_database
 
@@ -82,7 +86,10 @@ async def process_transcript(transcript_id: str, sync: bool = False) -> None:
     print(f"Dispatching file pipeline", file=sys.stderr)
 
     result = await process_transcript_inner(
-        transcript, on_validation=on_validation, on_preprocess=on_preprocess
+        transcript,
+        on_validation=on_validation,
+        on_preprocess=on_preprocess,
+        force=force,
     )
 
     if sync:
@@ -118,9 +125,16 @@ def main():
         action="store_true",
         help="Wait for task completion instead of just dispatching",
     )
+    parser.add_argument(
+        "--force",
+        action="store_true",
+        help="Cancel old workflow and start new (uses latest code instead of replaying)",
+    )
 
     args = parser.parse_args()
-    asyncio.run(process_transcript(args.transcript_id, sync=args.sync))
+    asyncio.run(
+        process_transcript(args.transcript_id, sync=args.sync, force=args.force)
+    )
 
 
 if __name__ == "__main__":
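With the new flag, a failed multitrack transcript can be retried two ways. A hedged usage sketch (the module path for the CLI entry point is not shown in the commit, so only the Python-level call is illustrated):

```python
import asyncio
# from <this module> import process_transcript  # module path not shown above

# Default: replay the stored workflow run if Hatchet reports it as FAILED.
asyncio.run(process_transcript("TRANSCRIPT_ID", sync=True))

# --force equivalent: cancel any old run, clear workflow_run_id, start fresh
# with the latest code.
asyncio.run(process_transcript("TRANSCRIPT_ID", sync=True, force=True))
```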