self-review round
@@ -1,17 +0,0 @@
-"""
-Hatchet workflow utilities.
-
-Shared helpers for Hatchet task implementations.
-"""
-
-
-def to_dict(output) -> dict:
-    """Convert task output to dict, handling both dict and Pydantic model returns.
-
-    Hatchet SDK can return task outputs as either raw dicts or Pydantic models
-    depending on serialization context. This normalizes the output for consistent
-    downstream processing.
-    """
-    if isinstance(output, dict):
-        return output
-    return output.model_dump()
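Review note: the deleted to_dict() helper existed because Hatchet could hand a task output back either as a raw dict or as a parsed Pydantic model; the rest of this commit removes that normalization and uses attribute access directly on ctx.task_output(...). A minimal, self-contained sketch of what the helper normalized (RecordingResult is a hypothetical stand-in, not a model from this repo):

    # Sketch only: reproduces the removed helper against a made-up output model.
    from pydantic import BaseModel

    class RecordingResult(BaseModel):  # hypothetical stand-in for a task output model
        mtg_session_id: str | None = None

    def to_dict(output) -> dict:
        if isinstance(output, dict):
            return output
        return output.model_dump()

    assert to_dict({"mtg_session_id": "abc"}) == {"mtg_session_id": "abc"}
    assert to_dict(RecordingResult(mtg_session_id="abc")) == {"mtg_session_id": "abc"}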
@@ -27,7 +27,6 @@ from reflector.hatchet.broadcast import (
     set_status_and_broadcast,
 )
 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.utils import to_dict
 from reflector.hatchet.workflows.models import (
     ConsentResult,
     FinalizeResult,
@@ -69,7 +68,7 @@ from reflector.utils.daily import (
     filter_cam_audio_tracks,
     parse_daily_recording_filename,
 )
-from reflector.utils.string import NonEmptyString
+from reflector.utils.string import NonEmptyString, assert_non_none_and_non_empty
 from reflector.zulip import post_transcript_notification
 
 
@@ -212,8 +211,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
     """Fetch participant list from Daily.co API and update transcript in database."""
     ctx.log(f"get_participants: transcript_id={input.transcript_id}")
 
-    recording_data = to_dict(ctx.task_output(get_recording))
-    mtg_session_id = recording_data.get("mtg_session_id")
+    recording = ctx.task_output(get_recording)
+    mtg_session_id = recording.mtg_session_id
 
     async with fresh_db_connection():
         from reflector.db.transcripts import (  # noqa: PLC0415
@@ -233,15 +232,14 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
             },
         )
 
-    if not mtg_session_id or not settings.DAILY_API_KEY:
-        return ParticipantsResult(
-            participants=[],
-            num_tracks=len(input.tracks),
-            source_language=transcript.source_language if transcript else "en",
-            target_language=transcript.target_language if transcript else "en",
-        )
+    mtg_session_id = assert_non_none_and_non_empty(
+        mtg_session_id, "mtg_session_id is required"
+    )
+    daily_api_key = assert_non_none_and_non_empty(
+        settings.DAILY_API_KEY, "DAILY_API_KEY is required"
+    )
 
-    async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
+    async with DailyApiClient(api_key=daily_api_key) as client:
         participants = await client.get_meeting_participants(mtg_session_id)
 
     id_to_name = {}
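Review note: the early return of an empty ParticipantsResult is replaced with fail-fast validation. A simplified stand-in for assert_non_none_and_non_empty (the real helper, added at the bottom of this diff, delegates to parse_non_empty_string and assert_not_none, so the exact exception type may differ):

    # Hedged sketch of the fail-fast behavior; ValueError is an assumption here.
    def assert_non_none_and_non_empty(value: str | None, error: str | None = None) -> str:
        if value is None or value == "":
            raise ValueError(error or "Value is None")
        return value

    try:
        assert_non_none_and_non_empty(None, "DAILY_API_KEY is required")
    except ValueError as exc:
        print(exc)  # DAILY_API_KEY is required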
@@ -302,8 +300,8 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
     """Spawn child workflows for each track (dynamic fan-out)."""
    ctx.log(f"process_tracks: spawning {len(input.tracks)} track workflows")
 
-    participants_data = to_dict(ctx.task_output(get_participants))
-    source_language = participants_data.get("source_language", "en")
+    participants_result = ctx.task_output(get_participants)
+    source_language = participants_result.source_language
 
     child_coroutines = [
         track_workflow.aio_run(
@@ -320,7 +318,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
 
     results = await asyncio.gather(*child_coroutines)
 
-    target_language = participants_data.get("target_language", "en")
+    target_language = participants_result.target_language
 
     # Collect results from each track (don't mutate lists while iterating)
     track_words = []
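Review note: process_tracks keeps the same dynamic fan-out shape, only the output access changes. A self-contained sketch of the fan-out/fan-in pattern with asyncio.gather (plain coroutines stand in for track_workflow.aio_run and the Hatchet Context):

    import asyncio

    async def process_one_track(track_id: str, source_language: str) -> dict:
        await asyncio.sleep(0)  # placeholder for per-track transcription work
        return {"track_id": track_id, "language": source_language}

    async def process_tracks(track_ids: list[str], source_language: str) -> list[dict]:
        # one coroutine per track, awaited concurrently; results keep input order
        coroutines = [process_one_track(t, source_language) for t in track_ids]
        return await asyncio.gather(*coroutines)

    print(asyncio.run(process_tracks(["cam-1", "cam-2"], "en")))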
@@ -371,31 +369,23 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     """Mix all padded tracks into single audio file using PyAV (same as Celery)."""
     ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
 
-    track_data = to_dict(ctx.task_output(process_tracks))
-    padded_tracks_data = track_data.get("padded_tracks", [])
+    track_result = ctx.task_output(process_tracks)
+    padded_tracks = track_result.padded_tracks
 
-    if not padded_tracks_data:
+    if not padded_tracks:
         raise ValueError("No padded tracks to mixdown")
 
     storage = _spawn_storage()
 
     # Presign URLs on demand (avoids stale URLs on workflow replay)
     padded_urls = []
-    for track_info in padded_tracks_data:
-        # Handle both dict (from to_dict) and PaddedTrackInfo
-        if isinstance(track_info, dict):
-            key = track_info.get("key")
-            bucket = track_info.get("bucket_name")
-        else:
-            key = track_info.key
-            bucket = track_info.bucket_name
-
-        if key:
+    for track_info in padded_tracks:
+        if track_info.key:
             url = await storage.get_file_url(
-                key,
+                track_info.key,
                 operation="get_object",
                 expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
-                bucket=bucket,
+                bucket=track_info.bucket_name,
             )
             padded_urls.append(url)
 
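Review note: the "presign on demand" comment is the key design point here — URLs are minted when mixdown actually needs them rather than stored in earlier task output, so a replayed step never sees an expired link. A hedged boto3 sketch of the same idea (not the project's storage layer; bucket and key names are placeholders):

    import boto3

    def presign_get_url(bucket: str, key: str, expires_in: int = 3600) -> str:
        # generate_presigned_url mints a short-lived GET URL at call time
        s3 = boto3.client("s3")
        return s3.generate_presigned_url(
            ClientMethod="get_object",
            Params={"Bucket": bucket, "Key": key},
            ExpiresIn=expires_in,
        )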
@@ -465,8 +455,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
     )
 
     # Cleanup temporary padded S3 files (deferred until after mixdown)
-    track_data = to_dict(ctx.task_output(process_tracks))
-    created_padded_files = track_data.get("created_padded_files", [])
+    track_result = ctx.task_output(process_tracks)
+    created_padded_files = track_result.created_padded_files
     if created_padded_files:
         ctx.log(f"Cleaning up {len(created_padded_files)} temporary S3 files")
         storage = _spawn_storage()
@@ -483,8 +473,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
                 error=str(result),
             )
 
-    mixdown_data = to_dict(ctx.task_output(mixdown_tracks))
-    audio_key = mixdown_data.get("audio_key")
+    mixdown_result = ctx.task_output(mixdown_tracks)
+    audio_key = mixdown_result.audio_key
 
     storage = _spawn_storage()
     audio_url = await storage.get_file_url(
@@ -532,9 +522,9 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
     """Detect topics using LLM and save to database (matches Celery on_topic callback)."""
     ctx.log("detect_topics: analyzing transcript for topics")
 
-    track_data = to_dict(ctx.task_output(process_tracks))
-    words = track_data.get("all_words", [])
-    target_language = track_data.get("target_language", "en")
+    track_result = ctx.task_output(process_tracks)
+    words = track_result.all_words
+    target_language = track_result.target_language
 
     from reflector.db.transcripts import (  # noqa: PLC0415
         TranscriptTopic,
@@ -589,8 +579,8 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
     """Generate meeting title using LLM and save to database (matches Celery on_title callback)."""
     ctx.log("generate_title: generating title from topics")
 
-    topics_data = to_dict(ctx.task_output(detect_topics))
-    topics = topics_data.get("topics", [])
+    topics_result = ctx.task_output(detect_topics)
+    topics = topics_result.topics
 
     from reflector.db.transcripts import (  # noqa: PLC0415
         TranscriptFinalTitle,
@@ -638,8 +628,8 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
     """Generate meeting summary using LLM and save to database (matches Celery callbacks)."""
     ctx.log("generate_summary: generating long and short summaries")
 
-    topics_data = to_dict(ctx.task_output(detect_topics))
-    topics = topics_data.get("topics", [])
+    topics_result = ctx.task_output(detect_topics)
+    topics = topics_result.topics
 
     from reflector.db.transcripts import (  # noqa: PLC0415
         TranscriptFinalLongSummary,
@@ -718,11 +708,11 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
     """
     ctx.log("finalize: saving transcript and setting status to 'ended'")
 
-    mixdown_data = to_dict(ctx.task_output(mixdown_tracks))
-    track_data = to_dict(ctx.task_output(process_tracks))
+    mixdown_result = ctx.task_output(mixdown_tracks)
+    track_result = ctx.task_output(process_tracks)
 
-    duration = mixdown_data.get("duration", 0)
-    all_words = track_data.get("all_words", [])
+    duration = mixdown_result.duration
+    all_words = track_result.all_words
 
     async with fresh_db_connection():
         from reflector.db.transcripts import (  # noqa: PLC0415
@@ -23,7 +23,6 @@ from hatchet_sdk import Context
 from pydantic import BaseModel
 
 from reflector.hatchet.client import HatchetClientManager
-from reflector.hatchet.utils import to_dict
 from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
 from reflector.logger import logger
 from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
@@ -166,9 +165,9 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe
     )
 
     try:
-        pad_result = to_dict(ctx.task_output(pad_track))
-        padded_key = pad_result.get("padded_key")
-        bucket_name = pad_result.get("bucket_name")
+        pad_result = ctx.task_output(pad_track)
+        padded_key = pad_result.padded_key
+        bucket_name = pad_result.bucket_name
 
         if not padded_key:
             raise ValueError("Missing padded_key from pad_track")
@@ -2,6 +2,8 @@ from typing import Annotated, TypeVar
 
 from pydantic import Field, TypeAdapter, constr
 
+from reflector.utils.common import assert_not_none
+
 NonEmptyStringBase = constr(min_length=1, strip_whitespace=False)
 NonEmptyString = Annotated[
     NonEmptyStringBase,
@@ -30,3 +32,11 @@ def assert_equal[T](s1: T, s2: T) -> T:
     if s1 != s2:
         raise ValueError(f"assert_equal: {s1} != {s2}")
     return s1
+
+
+def assert_non_none_and_non_empty(
+    value: str | None, error: str | None = None
+) -> NonEmptyString:
+    return parse_non_empty_string(
+        assert_not_none(value, error or "Value is None"), error
+    )
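Review note: usage sketch for the new helper. parse_non_empty_string and assert_not_none live elsewhere in reflector.utils and are not shown in this diff, so simplified stand-ins are used to keep the example self-contained:

    def assert_not_none(value, error: str):
        if value is None:
            raise ValueError(error)
        return value

    def parse_non_empty_string(value: str, error: str | None = None) -> str:
        if value == "":
            raise ValueError(error or "Empty string")
        return value

    def assert_non_none_and_non_empty(value: str | None, error: str | None = None) -> str:
        return parse_non_empty_string(
            assert_not_none(value, error or "Value is None"), error
        )

    print(assert_non_none_and_non_empty("session-123"))  # session-123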