self-review round

Igor Loskutov
2025-12-17 15:25:29 -05:00
parent f7f2957fc9
commit cb41e9e779
6 changed files with 15 additions and 53 deletions

View File

@@ -27,7 +27,6 @@ async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None:
     try:
         ws_manager = get_ws_manager()
-        # Broadcast to transcript room
         await ws_manager.send_json(
             room_id=f"ts:{transcript_id}",
             message=event.model_dump(mode="json"),
@@ -38,7 +37,6 @@ async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None:
             event_type=event.event,
         )
-        # Also broadcast to user room for certain events
         if event.event in USER_ROOM_EVENTS:
             transcript = await transcripts_controller.get_by_id(transcript_id)
             if transcript and transcript.user_id:

View File

@@ -90,11 +90,6 @@ diarization_pipeline = hatchet.workflow(
 )
-# ============================================================================
-# Helper Functions
-# ============================================================================
 @asynccontextmanager
 async def fresh_db_connection():
     """Context manager for database connections in Hatchet workers."""
@@ -177,11 +172,6 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
     return decorator
-# ============================================================================
-# Pipeline Tasks
-# ============================================================================
 @diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
 @with_error_handling("get_recording")
 async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
@@ -350,7 +340,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
     results = await asyncio.gather(*child_coroutines)
-    # Get target_language for later use in detect_topics
     target_language = participants_data.get("target_language", "en")
     # Collect results from each track (don't mutate lists while iterating)
@@ -440,13 +429,11 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     if not valid_urls:
         raise ValueError("No valid padded tracks to mixdown")
-    # Detect sample rate from tracks
     target_sample_rate = detect_sample_rate_from_tracks(valid_urls, logger=logger)
     if not target_sample_rate:
         logger.error("Mixdown failed - no decodable audio frames found")
         raise ValueError("No decodable audio frames in any track")
-    # Create temp file and writer for MP3 output
     output_path = tempfile.mktemp(suffix=".mp3")
     duration_ms = [0.0]  # Mutable container for callback capture
@@ -455,7 +442,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration)
-    # Run mixdown using shared utility
     await mixdown_tracks_pyav(
         valid_urls,
         writer,
@@ -465,7 +451,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     )
     await writer.flush()
-    # Upload to storage
     file_size = Path(output_path).stat().st_size
     storage_path = f"{input.transcript_id}/audio.mp3"
@@ -474,7 +459,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     Path(output_path).unlink(missing_ok=True)
-    # Update DB with audio location
     async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
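A note on the helper kept in this file: `fresh_db_connection` appears here only as its decorator and docstring. A minimal sketch of the usual `asynccontextmanager` shape such a helper takes (the `_Database` stand-in below is hypothetical; the project's real DB layer is not part of this commit):

from contextlib import asynccontextmanager

class _Database:
    """Stand-in connection object; the project's real DB layer is not shown in this diff."""
    async def connect(self) -> None: ...
    async def disconnect(self) -> None: ...

@asynccontextmanager
async def fresh_db_connection():
    """Context manager for database connections in Hatchet workers."""
    database = _Database()  # real code would build or reuse the project's engine here
    await database.connect()
    try:
        yield database
    finally:
        await database.disconnect()

Each task then wraps its DB work in `async with fresh_db_connection():`, as the mixdown hunk above does, so every worker invocation opens and closes its own connection.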

View File

@@ -9,16 +9,16 @@ from typing import Any
 from pydantic import BaseModel
-# ============================================================================
-# Track Processing Results (track_processing.py)
-# ============================================================================
+from reflector.utils.string import NonEmptyString
 class PadTrackResult(BaseModel):
     """Result from pad_track task."""
-    padded_key: str  # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay
-    bucket_name: str | None  # None means use default transcript storage bucket
+    padded_key: NonEmptyString  # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay
+    bucket_name: (
+        NonEmptyString | None
+    )  # None means use default transcript storage bucket
     size: int
     track_index: int
@@ -30,16 +30,11 @@ class TranscribeTrackResult(BaseModel):
     track_index: int
-# ============================================================================
-# Diarization Pipeline Results (diarization_pipeline.py)
-# ============================================================================
 class RecordingResult(BaseModel):
     """Result from get_recording task."""
-    id: str | None
-    mtg_session_id: str | None
+    id: NonEmptyString | None
+    mtg_session_id: NonEmptyString | None
     duration: float
@@ -48,15 +43,15 @@ class ParticipantsResult(BaseModel):
     participants: list[dict[str, Any]]
     num_tracks: int
-    source_language: str
-    target_language: str
+    source_language: NonEmptyString
+    target_language: NonEmptyString
 class PaddedTrackInfo(BaseModel):
     """Info for a padded track - S3 key + bucket for on-demand presigning."""
-    key: str
-    bucket_name: str | None  # None = use default storage bucket
+    key: NonEmptyString
+    bucket_name: NonEmptyString | None  # None = use default storage bucket
 class ProcessTracksResult(BaseModel):
@@ -66,14 +61,14 @@ class ProcessTracksResult(BaseModel):
     padded_tracks: list[PaddedTrackInfo]  # S3 keys, not presigned URLs
     word_count: int
     num_tracks: int
-    target_language: str
-    created_padded_files: list[str]
+    target_language: NonEmptyString
+    created_padded_files: list[NonEmptyString]
 class MixdownResult(BaseModel):
     """Result from mixdown_tracks task."""
-    audio_key: str
+    audio_key: NonEmptyString
     duration: float
     tracks_mixed: int
@@ -106,7 +101,7 @@ class SummaryResult(BaseModel):
 class FinalizeResult(BaseModel):
     """Result from finalize task."""
-    status: str
+    status: NonEmptyString
 class ConsentResult(BaseModel):
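The repeated change in this file is `str` → `NonEmptyString`, imported from `reflector.utils.string`. That alias's definition is not part of this commit; a minimal sketch of what such a constrained type usually looks like under Pydantic v2 (an assumption, not the project's actual code):

from typing import Annotated

from pydantic import BaseModel, StringConstraints

# Hypothetical definition - the real one lives in reflector.utils.string and is not shown here
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]

class MixdownResult(BaseModel):
    """Result from mixdown_tracks task."""
    audio_key: NonEmptyString
    duration: float
    tracks_mixed: int

MixdownResult(audio_key="abc123/audio.mp3", duration=12.3, tracks_mixed=2)  # validates
# MixdownResult(audio_key="", duration=12.3, tracks_mixed=2)  # would raise ValidationError

With plain `str`, an empty S3 key would pass model validation and only fail later when the key is presigned or fetched; the constrained alias moves that failure to the point where the task result is constructed.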

View File

@@ -43,7 +43,6 @@ class TrackInput(BaseModel):
     language: str = "en"
-# Get hatchet client and define workflow
 hatchet = HatchetClientManager.get_client()
 track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
@@ -76,7 +75,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
     )
-    # Get presigned URL for source file
     source_url = await storage.get_file_url(
         input.s3_key,
         operation="get_object",
@@ -84,7 +82,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         bucket=input.bucket_name,
     )
-    # Open container and extract start time
     with av.open(source_url) as in_container:
         start_time_seconds = extract_stream_start_time_from_container(
             in_container, input.track_index, logger=logger
@@ -103,7 +100,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         track_index=input.track_index,
     )
-    # Create temp file for padded output
     with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
         temp_path = temp_file.name
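The comments elsewhere in this commit about storing S3 keys instead of presigned URLs ("presign on demand to avoid stale URLs on replay") pair with the `storage.get_file_url(...)` call kept above. A short sketch of presigning a stored key only at the moment it is needed, reusing the keyword arguments visible in this file (the `storage` object, the `PaddedTrackInfo` model from the results module, and any expiry defaults are assumptions rather than confirmed details):

async def presign_padded_track(storage, info: "PaddedTrackInfo") -> str:
    # Presign just before use, so a retried or replayed task never reads an expired URL
    return await storage.get_file_url(
        info.key,
        operation="get_object",
        bucket=info.bucket_name,  # None falls back to the default transcript storage bucket
    )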

View File

@@ -97,13 +97,8 @@ class PipelineMainFile(PipelineMainBase):
             },
         )
-        # Extract audio and write to transcript location
         audio_path = await self.extract_and_write_audio(file_path, transcript)
-        # Upload for processing
         audio_url = await self.upload_audio(audio_path, transcript)
-        # Run parallel processing
         await self.run_parallel_processing(
             audio_path,
             audio_url,
@@ -197,7 +192,6 @@ class PipelineMainFile(PipelineMainBase):
         transcript_result = results[0]
         diarization_result = results[1]
-        # Handle errors - raise any exception that occurred
         self._handle_gather_exceptions(results, "parallel processing")
         for result in results:
             if isinstance(result, Exception):
@@ -212,7 +206,6 @@ class PipelineMainFile(PipelineMainBase):
             transcript=transcript_result, diarization=diarization_result or []
         )
-        # Store result for retrieval
         diarized_transcript: Transcript | None = None
         async def capture_result(transcript):
@@ -348,7 +341,6 @@ async def task_pipeline_file_process(*, transcript_id: str):
     try:
         await pipeline.set_status(transcript_id, "processing")
-        # Find the file to process
         audio_file = next(transcript.data_path.glob("upload.*"), None)
         if not audio_file:
             audio_file = next(transcript.data_path.glob("audio.*"), None)
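`_handle_gather_exceptions` is referenced above but not shown in this commit; together with the `isinstance(result, Exception)` loop it points at the standard `asyncio.gather(..., return_exceptions=True)` pattern, where failures come back as values and must be re-raised explicitly. A generic sketch of that pattern (helper name and exact behaviour are assumptions, not the project's code):

import asyncio

def raise_first_exception(results: list, context: str) -> None:
    """Re-raise the first exception returned by gather(..., return_exceptions=True)."""
    for result in results:
        if isinstance(result, BaseException):
            raise RuntimeError(f"{context} failed") from result

async def run_parallel(transcribe_coro, diarize_coro):
    results = await asyncio.gather(transcribe_coro, diarize_coro, return_exceptions=True)
    raise_first_exception(list(results), "parallel processing")
    return results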

View File

@@ -159,7 +159,6 @@ class PipelineMainMultitrack(PipelineMainBase):
             with open(temp_path, "rb") as padded_file:
                 await storage.put_file(storage_path, padded_file)
         finally:
-            # Clean up temp file
             Path(temp_path).unlink(missing_ok=True)
         padded_url = await storage.get_file_url(
@@ -196,7 +195,6 @@ class PipelineMainMultitrack(PipelineMainBase):
         offsets_seconds: list[float] | None = None,
     ) -> None:
         """Multi-track mixdown using PyAV filter graph (amix), reading from S3 presigned URLs."""
-        # Detect sample rate from tracks
         target_sample_rate = detect_sample_rate_from_tracks(
             track_urls, logger=self.logger
         )
@@ -204,7 +202,6 @@ class PipelineMainMultitrack(PipelineMainBase):
             self.logger.error("Mixdown failed - no decodable audio frames found")
             raise Exception("Mixdown failed: No decodable audio frames in any track")
-        # Run mixdown using shared utility
         await mixdown_tracks_pyav(
             track_urls,
             writer,
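`detect_sample_rate_from_tracks` (used both here and in the Hatchet mixdown task earlier in this commit) is not shown; its "no decodable audio frames" failure mode suggests it probes the tracks by decoding. A rough sketch of such a probe with PyAV (function name and error handling are illustrative, not the project's implementation):

import av

def detect_sample_rate(track_urls: list[str]) -> int | None:
    """Return the sample rate of the first decodable audio frame, or None if nothing decodes."""
    for url in track_urls:
        try:
            with av.open(url) as container:
                for frame in container.decode(audio=0):
                    return frame.sample_rate
        except Exception:  # broad catch for the sketch; real code would narrow this
            continue
    return None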