hatchet no-mistake

This commit is contained in:
Igor Loskutov
2025-12-16 12:09:02 -05:00
parent c5498d26bf
commit 0f266eabdf
7 changed files with 780 additions and 202 deletions

View File

@@ -0,0 +1,28 @@
"""add workflow_run_id to transcript
Revision ID: 0f943fede0e0
Revises: a326252ac554
Create Date: 2025-12-16 01:54:13.855106
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0f943fede0e0"
down_revision: Union[str, None] = "a326252ac554"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    with op.batch_alter_table("transcript", schema=None) as batch_op:
        batch_op.add_column(sa.Column("workflow_run_id", sa.String(), nullable=True))


def downgrade() -> None:
    with op.batch_alter_table("transcript", schema=None) as batch_op:
        batch_op.drop_column("workflow_run_id")
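
For local verification (not part of this commit), the revision can be applied programmatically with Alembic's command API; the alembic.ini path here is an assumption:

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed config location
command.upgrade(cfg, "0f943fede0e0")  # apply this revision; command.downgrade(cfg, "a326252ac554") reverts it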

View File

@@ -83,6 +83,8 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean), sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
sqlalchemy.Column("room_id", sqlalchemy.String), sqlalchemy.Column("room_id", sqlalchemy.String),
sqlalchemy.Column("webvtt", sqlalchemy.Text), sqlalchemy.Column("webvtt", sqlalchemy.Text),
# Hatchet workflow run ID for resumption of failed workflows
sqlalchemy.Column("workflow_run_id", sqlalchemy.String),
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"), sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
sqlalchemy.Index("idx_transcript_user_id", "user_id"), sqlalchemy.Index("idx_transcript_user_id", "user_id"),
sqlalchemy.Index("idx_transcript_created_at", "created_at"), sqlalchemy.Index("idx_transcript_created_at", "created_at"),
@@ -227,6 +229,7 @@ class Transcript(BaseModel):
    zulip_message_id: int | None = None
    audio_deleted: bool | None = None
    webvtt: str | None = None
    workflow_run_id: str | None = None  # Hatchet workflow run ID for resumption

    @field_serializer("created_at", when_used="json")
    def serialize_datetime(self, dt: datetime) -> str:

View File

@@ -2,6 +2,7 @@
from hatchet_sdk import Hatchet

from reflector.logger import logger
from reflector.settings import settings
@@ -35,9 +36,44 @@ class HatchetClientManager:
        # SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
        return result.run.metadata.id

    @classmethod
    async def get_workflow_run_status(cls, workflow_run_id: str) -> str:
        """Get workflow run status."""
        client = cls.get_client()
        status = await client.runs.aio_get_status(workflow_run_id)
        return str(status)

    @classmethod
    async def cancel_workflow(cls, workflow_run_id: str) -> None:
        """Cancel a workflow."""
        client = cls.get_client()
        await client.runs.aio_cancel(workflow_run_id)
        logger.info("[Hatchet] Cancelled workflow", workflow_run_id=workflow_run_id)

    @classmethod
    async def replay_workflow(cls, workflow_run_id: str) -> None:
        """Replay a failed workflow."""
        client = cls.get_client()
        await client.runs.aio_replay(workflow_run_id)
        logger.info("[Hatchet] Replaying workflow", workflow_run_id=workflow_run_id)

    @classmethod
    async def can_replay(cls, workflow_run_id: str) -> bool:
        """Check if workflow can be replayed (is FAILED)."""
        try:
            status = await cls.get_workflow_run_status(workflow_run_id)
            return "FAILED" in status
        except Exception as e:
            logger.warning(
                "[Hatchet] Failed to check replay status",
                workflow_run_id=workflow_run_id,
                error=str(e),
            )
            return False

    @classmethod
    async def get_workflow_status(cls, workflow_run_id: str) -> dict:
        """Get the full workflow run details as dict."""
        client = cls.get_client()
        run = await client.runs.aio_get(workflow_run_id)
        return run.to_dict()
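
As a usage illustration (not part of this commit), the new helpers compose like this; resume_or_report is a hypothetical caller:

async def resume_or_report(workflow_run_id: str) -> None:
    # Replay only if Hatchet reports the run as FAILED; otherwise just log the status.
    if await HatchetClientManager.can_replay(workflow_run_id):
        await HatchetClientManager.replay_workflow(workflow_run_id)
    else:
        status = await HatchetClientManager.get_workflow_run_status(workflow_run_id)
        logger.info("[Hatchet] Run not replayable", workflow_run_id=workflow_run_id, status=status)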

File diff suppressed because it is too large

View File

@@ -166,6 +166,7 @@ class SummaryBuilder:
        self.model_name: str = llm.model_name
        self.logger = logger or structlog.get_logger()
        self.participant_instructions: str | None = None
        self._logged_participant_instructions: bool = False

        if filename:
            self.read_transcript_from_file(filename)
@@ -208,7 +209,9 @@ class SummaryBuilder:
    def _enhance_prompt_with_participants(self, prompt: str) -> str:
        """Add participant instructions to any prompt if participants are known."""
        if self.participant_instructions:
            if not self._logged_participant_instructions:
                self.logger.debug("Adding participant instructions to prompts")
                self._logged_participant_instructions = True
            return f"{prompt}\n\n{self.participant_instructions}"
        return prompt

View File

@@ -102,6 +102,7 @@ async def validate_transcript_for_processing(
    if transcript.status == "idle":
        return ValidationNotReady(detail="Recording is not ready for processing")

    # Check Celery tasks
    if task_is_scheduled_or_active(
        "reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
        transcript_id=transcript.id,
@@ -111,6 +112,23 @@ async def validate_transcript_for_processing(
    ):
        return ValidationAlreadyScheduled(detail="already running")

    # Check Hatchet workflows (if enabled)
    if settings.HATCHET_ENABLED and transcript.workflow_run_id:
        from reflector.hatchet.client import HatchetClientManager

        try:
            status = await HatchetClientManager.get_workflow_run_status(
                transcript.workflow_run_id
            )
            # If workflow is running or queued, don't allow new processing
            if "RUNNING" in status or "QUEUED" in status:
                return ValidationAlreadyScheduled(
                    detail="Hatchet workflow already running"
                )
        except Exception:
            # If we can't get status, allow processing (workflow might be gone)
            pass

    return ValidationOk(
        recording_id=transcript.recording_id, transcript_id=transcript.id
    )
@@ -155,7 +173,9 @@ async def prepare_transcript_processing(
    )


def dispatch_transcript_processing(
    config: ProcessingConfig, force: bool = False
) -> AsyncResult | None:
    if isinstance(config, MultitrackProcessingConfig):
        # Start durable workflow if enabled (Hatchet or Conductor)
        durable_started = False
@@ -163,18 +183,69 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | No
        if settings.HATCHET_ENABLED:
            import asyncio

            import databases

            from reflector.db import _database_context
            from reflector.db.transcripts import transcripts_controller

            async def _handle_hatchet():
                db = databases.Database(settings.DATABASE_URL)
                _database_context.set(db)
                await db.connect()

                try:
                    transcript = await transcripts_controller.get_by_id(
                        config.transcript_id
                    )

                    if transcript and transcript.workflow_run_id and not force:
                        can_replay = await HatchetClientManager.can_replay(
                            transcript.workflow_run_id
                        )
                        if can_replay:
                            await HatchetClientManager.replay_workflow(
                                transcript.workflow_run_id
                            )
                            logger.info(
                                "Replaying Hatchet workflow",
                                workflow_id=transcript.workflow_run_id,
                            )
                            return transcript.workflow_run_id

                    # Force: cancel old workflow if exists
                    if force and transcript and transcript.workflow_run_id:
                        await HatchetClientManager.cancel_workflow(
                            transcript.workflow_run_id
                        )
                        logger.info(
                            "Cancelled old workflow (--force)",
                            workflow_id=transcript.workflow_run_id,
                        )
                        await transcripts_controller.update(
                            transcript, {"workflow_run_id": None}
                        )

                    workflow_id = await HatchetClientManager.start_workflow(
                        workflow_name="DiarizationPipeline",
                        input_data={
                            "recording_id": config.recording_id,
                            "room_name": None,
                            "tracks": [{"s3_key": k} for k in config.track_keys],
                            "bucket_name": config.bucket_name,
                            "transcript_id": config.transcript_id,
                            "room_id": config.room_id,
                        },
                    )
                    if transcript:
                        await transcripts_controller.update(
                            transcript, {"workflow_run_id": workflow_id}
                        )
                    return workflow_id
                finally:
                    await db.disconnect()
                    _database_context.set(None)

            try:
                loop = asyncio.get_running_loop()
@@ -182,19 +253,14 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | No
                loop = None

            if loop and loop.is_running():
                import concurrent.futures

                with concurrent.futures.ThreadPoolExecutor() as pool:
                    workflow_id = pool.submit(asyncio.run, _handle_hatchet()).result()
            else:
                workflow_id = asyncio.run(_handle_hatchet())

            logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
            durable_started = True

        elif settings.CONDUCTOR_ENABLED:

View File

@@ -34,21 +34,25 @@ async def process_transcript_inner(
    transcript: Transcript,
    on_validation: Callable[[ValidationResult], None],
    on_preprocess: Callable[[PrepareResult], None],
    force: bool = False,
) -> AsyncResult:
    validation = await validate_transcript_for_processing(transcript)
    on_validation(validation)
    config = await prepare_transcript_processing(validation, room_id=transcript.room_id)
    on_preprocess(config)
    return dispatch_transcript_processing(config, force=force)


async def process_transcript(
    transcript_id: str, sync: bool = False, force: bool = False
) -> None:
    """
    Process a transcript by ID, auto-detecting multitrack vs file pipeline.

    Args:
        transcript_id: The transcript UUID
        sync: If True, wait for task completion. If False, dispatch and exit.
        force: If True, cancel old workflow and start new (latest code). If False, replay failed workflow.
    """
    from reflector.db import get_database
@@ -82,7 +86,10 @@ async def process_transcript(transcript_id: str, sync: bool = False) -> None:
print(f"Dispatching file pipeline", file=sys.stderr) print(f"Dispatching file pipeline", file=sys.stderr)
result = await process_transcript_inner( result = await process_transcript_inner(
transcript, on_validation=on_validation, on_preprocess=on_preprocess transcript,
on_validation=on_validation,
on_preprocess=on_preprocess,
force=force,
) )
if sync: if sync:
@@ -118,9 +125,16 @@ def main():
action="store_true", action="store_true",
help="Wait for task completion instead of just dispatching", help="Wait for task completion instead of just dispatching",
) )
parser.add_argument(
"--force",
action="store_true",
help="Cancel old workflow and start new (uses latest code instead of replaying)",
)
args = parser.parse_args() args = parser.parse_args()
asyncio.run(process_transcript(args.transcript_id, sync=args.sync)) asyncio.run(
process_transcript(args.transcript_id, sync=args.sync, force=args.force)
)
if __name__ == "__main__": if __name__ == "__main__":
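
For reference, a minimal sketch of driving the new force path from Python rather than the CLI (the transcript ID below is a placeholder, not a value from this commit):

import asyncio

# Placeholder transcript UUID; force=True matches the new --force flag and
# cancels any stale Hatchet run before starting a fresh workflow.
asyncio.run(process_transcript("TRANSCRIPT_ID", sync=True, force=True))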