mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
hatchet no-mistake
This commit is contained in:
@@ -0,0 +1,28 @@
|
|||||||
|
"""add workflow_run_id to transcript
|
||||||
|
|
||||||
|
Revision ID: 0f943fede0e0
|
||||||
|
Revises: a326252ac554
|
||||||
|
Create Date: 2025-12-16 01:54:13.855106
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0f943fede0e0"
|
||||||
|
down_revision: Union[str, None] = "a326252ac554"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
with op.batch_alter_table("transcript", schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column("workflow_run_id", sa.String(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
with op.batch_alter_table("transcript", schema=None) as batch_op:
|
||||||
|
batch_op.drop_column("workflow_run_id")
|
||||||
@@ -83,6 +83,8 @@ transcripts = sqlalchemy.Table(
|
|||||||
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
|
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
|
||||||
sqlalchemy.Column("room_id", sqlalchemy.String),
|
sqlalchemy.Column("room_id", sqlalchemy.String),
|
||||||
sqlalchemy.Column("webvtt", sqlalchemy.Text),
|
sqlalchemy.Column("webvtt", sqlalchemy.Text),
|
||||||
|
# Hatchet workflow run ID for resumption of failed workflows
|
||||||
|
sqlalchemy.Column("workflow_run_id", sqlalchemy.String),
|
||||||
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
|
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
|
||||||
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
|
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
|
||||||
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
||||||
@@ -227,6 +229,7 @@ class Transcript(BaseModel):
|
|||||||
zulip_message_id: int | None = None
|
zulip_message_id: int | None = None
|
||||||
audio_deleted: bool | None = None
|
audio_deleted: bool | None = None
|
||||||
webvtt: str | None = None
|
webvtt: str | None = None
|
||||||
|
workflow_run_id: str | None = None # Hatchet workflow run ID for resumption
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def serialize_datetime(self, dt: datetime) -> str:
|
def serialize_datetime(self, dt: datetime) -> str:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from hatchet_sdk import Hatchet
|
from hatchet_sdk import Hatchet
|
||||||
|
|
||||||
|
from reflector.logger import logger
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
@@ -35,9 +36,44 @@ class HatchetClientManager:
|
|||||||
# SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
|
# SDK v1.21+ returns V1WorkflowRunDetails with run.metadata.id
|
||||||
return result.run.metadata.id
|
return result.run.metadata.id
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_workflow_run_status(cls, workflow_run_id: str) -> str:
|
||||||
|
"""Get workflow run status."""
|
||||||
|
client = cls.get_client()
|
||||||
|
status = await client.runs.aio_get_status(workflow_run_id)
|
||||||
|
return str(status)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def cancel_workflow(cls, workflow_run_id: str) -> None:
|
||||||
|
"""Cancel a workflow."""
|
||||||
|
client = cls.get_client()
|
||||||
|
await client.runs.aio_cancel(workflow_run_id)
|
||||||
|
logger.info("[Hatchet] Cancelled workflow", workflow_run_id=workflow_run_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def replay_workflow(cls, workflow_run_id: str) -> None:
|
||||||
|
"""Replay a failed workflow."""
|
||||||
|
client = cls.get_client()
|
||||||
|
await client.runs.aio_replay(workflow_run_id)
|
||||||
|
logger.info("[Hatchet] Replaying workflow", workflow_run_id=workflow_run_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def can_replay(cls, workflow_run_id: str) -> bool:
|
||||||
|
"""Check if workflow can be replayed (is FAILED)."""
|
||||||
|
try:
|
||||||
|
status = await cls.get_workflow_run_status(workflow_run_id)
|
||||||
|
return "FAILED" in status
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"[Hatchet] Failed to check replay status",
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_workflow_status(cls, workflow_run_id: str) -> dict:
|
async def get_workflow_status(cls, workflow_run_id: str) -> dict:
|
||||||
"""Get the current status of a workflow run."""
|
"""Get the full workflow run details as dict."""
|
||||||
client = cls.get_client()
|
client = cls.get_client()
|
||||||
run = await client.runs.aio_get(workflow_run_id)
|
run = await client.runs.aio_get(workflow_run_id)
|
||||||
return run.to_dict()
|
return run.to_dict()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -166,6 +166,7 @@ class SummaryBuilder:
|
|||||||
self.model_name: str = llm.model_name
|
self.model_name: str = llm.model_name
|
||||||
self.logger = logger or structlog.get_logger()
|
self.logger = logger or structlog.get_logger()
|
||||||
self.participant_instructions: str | None = None
|
self.participant_instructions: str | None = None
|
||||||
|
self._logged_participant_instructions: bool = False
|
||||||
if filename:
|
if filename:
|
||||||
self.read_transcript_from_file(filename)
|
self.read_transcript_from_file(filename)
|
||||||
|
|
||||||
@@ -208,7 +209,9 @@ class SummaryBuilder:
|
|||||||
def _enhance_prompt_with_participants(self, prompt: str) -> str:
|
def _enhance_prompt_with_participants(self, prompt: str) -> str:
|
||||||
"""Add participant instructions to any prompt if participants are known."""
|
"""Add participant instructions to any prompt if participants are known."""
|
||||||
if self.participant_instructions:
|
if self.participant_instructions:
|
||||||
self.logger.debug("Adding participant instructions to prompt")
|
if not self._logged_participant_instructions:
|
||||||
|
self.logger.debug("Adding participant instructions to prompts")
|
||||||
|
self._logged_participant_instructions = True
|
||||||
return f"{prompt}\n\n{self.participant_instructions}"
|
return f"{prompt}\n\n{self.participant_instructions}"
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ async def validate_transcript_for_processing(
|
|||||||
if transcript.status == "idle":
|
if transcript.status == "idle":
|
||||||
return ValidationNotReady(detail="Recording is not ready for processing")
|
return ValidationNotReady(detail="Recording is not ready for processing")
|
||||||
|
|
||||||
|
# Check Celery tasks
|
||||||
if task_is_scheduled_or_active(
|
if task_is_scheduled_or_active(
|
||||||
"reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
|
"reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
|
||||||
transcript_id=transcript.id,
|
transcript_id=transcript.id,
|
||||||
@@ -111,6 +112,23 @@ async def validate_transcript_for_processing(
|
|||||||
):
|
):
|
||||||
return ValidationAlreadyScheduled(detail="already running")
|
return ValidationAlreadyScheduled(detail="already running")
|
||||||
|
|
||||||
|
# Check Hatchet workflows (if enabled)
|
||||||
|
if settings.HATCHET_ENABLED and transcript.workflow_run_id:
|
||||||
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
|
|
||||||
|
try:
|
||||||
|
status = await HatchetClientManager.get_workflow_run_status(
|
||||||
|
transcript.workflow_run_id
|
||||||
|
)
|
||||||
|
# If workflow is running or queued, don't allow new processing
|
||||||
|
if "RUNNING" in status or "QUEUED" in status:
|
||||||
|
return ValidationAlreadyScheduled(
|
||||||
|
detail="Hatchet workflow already running"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# If we can't get status, allow processing (workflow might be gone)
|
||||||
|
pass
|
||||||
|
|
||||||
return ValidationOk(
|
return ValidationOk(
|
||||||
recording_id=transcript.recording_id, transcript_id=transcript.id
|
recording_id=transcript.recording_id, transcript_id=transcript.id
|
||||||
)
|
)
|
||||||
@@ -155,7 +173,9 @@ async def prepare_transcript_processing(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | None:
|
def dispatch_transcript_processing(
|
||||||
|
config: ProcessingConfig, force: bool = False
|
||||||
|
) -> AsyncResult | None:
|
||||||
if isinstance(config, MultitrackProcessingConfig):
|
if isinstance(config, MultitrackProcessingConfig):
|
||||||
# Start durable workflow if enabled (Hatchet or Conductor)
|
# Start durable workflow if enabled (Hatchet or Conductor)
|
||||||
durable_started = False
|
durable_started = False
|
||||||
@@ -163,18 +183,69 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | No
|
|||||||
if settings.HATCHET_ENABLED:
|
if settings.HATCHET_ENABLED:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
async def _start_hatchet():
|
import databases
|
||||||
return await HatchetClientManager.start_workflow(
|
|
||||||
workflow_name="DiarizationPipeline",
|
from reflector.db import _database_context
|
||||||
input_data={
|
from reflector.db.transcripts import transcripts_controller
|
||||||
"recording_id": config.recording_id,
|
|
||||||
"room_name": None, # Not available in reprocess path
|
async def _handle_hatchet():
|
||||||
"tracks": [{"s3_key": k} for k in config.track_keys],
|
db = databases.Database(settings.DATABASE_URL)
|
||||||
"bucket_name": config.bucket_name,
|
_database_context.set(db)
|
||||||
"transcript_id": config.transcript_id,
|
await db.connect()
|
||||||
"room_id": config.room_id,
|
|
||||||
},
|
try:
|
||||||
)
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
config.transcript_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if transcript and transcript.workflow_run_id and not force:
|
||||||
|
can_replay = await HatchetClientManager.can_replay(
|
||||||
|
transcript.workflow_run_id
|
||||||
|
)
|
||||||
|
if can_replay:
|
||||||
|
await HatchetClientManager.replay_workflow(
|
||||||
|
transcript.workflow_run_id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Replaying Hatchet workflow",
|
||||||
|
workflow_id=transcript.workflow_run_id,
|
||||||
|
)
|
||||||
|
return transcript.workflow_run_id
|
||||||
|
|
||||||
|
# Force: cancel old workflow if exists
|
||||||
|
if force and transcript and transcript.workflow_run_id:
|
||||||
|
await HatchetClientManager.cancel_workflow(
|
||||||
|
transcript.workflow_run_id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Cancelled old workflow (--force)",
|
||||||
|
workflow_id=transcript.workflow_run_id,
|
||||||
|
)
|
||||||
|
await transcripts_controller.update(
|
||||||
|
transcript, {"workflow_run_id": None}
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_id = await HatchetClientManager.start_workflow(
|
||||||
|
workflow_name="DiarizationPipeline",
|
||||||
|
input_data={
|
||||||
|
"recording_id": config.recording_id,
|
||||||
|
"room_name": None,
|
||||||
|
"tracks": [{"s3_key": k} for k in config.track_keys],
|
||||||
|
"bucket_name": config.bucket_name,
|
||||||
|
"transcript_id": config.transcript_id,
|
||||||
|
"room_id": config.room_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if transcript:
|
||||||
|
await transcripts_controller.update(
|
||||||
|
transcript, {"workflow_run_id": workflow_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
return workflow_id
|
||||||
|
finally:
|
||||||
|
await db.disconnect()
|
||||||
|
_database_context.set(None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
@@ -182,19 +253,14 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult | No
|
|||||||
loop = None
|
loop = None
|
||||||
|
|
||||||
if loop and loop.is_running():
|
if loop and loop.is_running():
|
||||||
# Already in async context
|
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
workflow_id = pool.submit(asyncio.run, _start_hatchet()).result()
|
workflow_id = pool.submit(asyncio.run, _handle_hatchet()).result()
|
||||||
else:
|
else:
|
||||||
workflow_id = asyncio.run(_start_hatchet())
|
workflow_id = asyncio.run(_handle_hatchet())
|
||||||
|
|
||||||
logger.info(
|
logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
|
||||||
"Started Hatchet workflow (reprocess)",
|
|
||||||
workflow_id=workflow_id,
|
|
||||||
transcript_id=config.transcript_id,
|
|
||||||
)
|
|
||||||
durable_started = True
|
durable_started = True
|
||||||
|
|
||||||
elif settings.CONDUCTOR_ENABLED:
|
elif settings.CONDUCTOR_ENABLED:
|
||||||
|
|||||||
@@ -34,21 +34,25 @@ async def process_transcript_inner(
|
|||||||
transcript: Transcript,
|
transcript: Transcript,
|
||||||
on_validation: Callable[[ValidationResult], None],
|
on_validation: Callable[[ValidationResult], None],
|
||||||
on_preprocess: Callable[[PrepareResult], None],
|
on_preprocess: Callable[[PrepareResult], None],
|
||||||
|
force: bool = False,
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
validation = await validate_transcript_for_processing(transcript)
|
validation = await validate_transcript_for_processing(transcript)
|
||||||
on_validation(validation)
|
on_validation(validation)
|
||||||
config = await prepare_transcript_processing(validation, room_id=transcript.room_id)
|
config = await prepare_transcript_processing(validation, room_id=transcript.room_id)
|
||||||
on_preprocess(config)
|
on_preprocess(config)
|
||||||
return dispatch_transcript_processing(config)
|
return dispatch_transcript_processing(config, force=force)
|
||||||
|
|
||||||
|
|
||||||
async def process_transcript(transcript_id: str, sync: bool = False) -> None:
|
async def process_transcript(
|
||||||
|
transcript_id: str, sync: bool = False, force: bool = False
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Process a transcript by ID, auto-detecting multitrack vs file pipeline.
|
Process a transcript by ID, auto-detecting multitrack vs file pipeline.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
transcript_id: The transcript UUID
|
transcript_id: The transcript UUID
|
||||||
sync: If True, wait for task completion. If False, dispatch and exit.
|
sync: If True, wait for task completion. If False, dispatch and exit.
|
||||||
|
force: If True, cancel old workflow and start new (latest code). If False, replay failed workflow.
|
||||||
"""
|
"""
|
||||||
from reflector.db import get_database
|
from reflector.db import get_database
|
||||||
|
|
||||||
@@ -82,7 +86,10 @@ async def process_transcript(transcript_id: str, sync: bool = False) -> None:
|
|||||||
print(f"Dispatching file pipeline", file=sys.stderr)
|
print(f"Dispatching file pipeline", file=sys.stderr)
|
||||||
|
|
||||||
result = await process_transcript_inner(
|
result = await process_transcript_inner(
|
||||||
transcript, on_validation=on_validation, on_preprocess=on_preprocess
|
transcript,
|
||||||
|
on_validation=on_validation,
|
||||||
|
on_preprocess=on_preprocess,
|
||||||
|
force=force,
|
||||||
)
|
)
|
||||||
|
|
||||||
if sync:
|
if sync:
|
||||||
@@ -118,9 +125,16 @@ def main():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Wait for task completion instead of just dispatching",
|
help="Wait for task completion instead of just dispatching",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--force",
|
||||||
|
action="store_true",
|
||||||
|
help="Cancel old workflow and start new (uses latest code instead of replaying)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
asyncio.run(process_transcript(args.transcript_id, sync=args.sync))
|
asyncio.run(
|
||||||
|
process_transcript(args.transcript_id, sync=args.sync, force=args.force)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user