diff --git a/server/reflector/hatchet/broadcast.py b/server/reflector/hatchet/broadcast.py index 317b5dbb..e9439e61 100644 --- a/server/reflector/hatchet/broadcast.py +++ b/server/reflector/hatchet/broadcast.py @@ -27,7 +27,6 @@ async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None: try: ws_manager = get_ws_manager() - # Broadcast to transcript room await ws_manager.send_json( room_id=f"ts:{transcript_id}", message=event.model_dump(mode="json"), @@ -38,7 +37,6 @@ async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None: event_type=event.event, ) - # Also broadcast to user room for certain events if event.event in USER_ROOM_EVENTS: transcript = await transcripts_controller.get_by_id(transcript_id) if transcript and transcript.user_id: diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py index f8121901..e5aba59a 100644 --- a/server/reflector/hatchet/workflows/diarization_pipeline.py +++ b/server/reflector/hatchet/workflows/diarization_pipeline.py @@ -90,11 +90,6 @@ diarization_pipeline = hatchet.workflow( ) -# ============================================================================ -# Helper Functions -# ============================================================================ - - @asynccontextmanager async def fresh_db_connection(): """Context manager for database connections in Hatchet workers.""" @@ -177,11 +172,6 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab return decorator -# ============================================================================ -# Pipeline Tasks -# ============================================================================ - - @diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3) @with_error_handling("get_recording") async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult: @@ -350,7 +340,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes results = await asyncio.gather(*child_coroutines) - # Get target_language for later use in detect_topics target_language = participants_data.get("target_language", "en") # Collect results from each track (don't mutate lists while iterating) @@ -440,13 +429,11 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: if not valid_urls: raise ValueError("No valid padded tracks to mixdown") - # Detect sample rate from tracks target_sample_rate = detect_sample_rate_from_tracks(valid_urls, logger=logger) if not target_sample_rate: logger.error("Mixdown failed - no decodable audio frames found") raise ValueError("No decodable audio frames in any track") - # Create temp file and writer for MP3 output output_path = tempfile.mktemp(suffix=".mp3") duration_ms = [0.0] # Mutable container for callback capture @@ -455,7 +442,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration) - # Run mixdown using shared utility await mixdown_tracks_pyav( valid_urls, writer, @@ -465,7 +451,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: ) await writer.flush() - # Upload to storage file_size = Path(output_path).stat().st_size storage_path = f"{input.transcript_id}/audio.mp3" @@ -474,7 +459,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: Path(output_path).unlink(missing_ok=True) - # Update DB with audio location async with fresh_db_connection(): from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 diff --git a/server/reflector/hatchet/workflows/models.py b/server/reflector/hatchet/workflows/models.py index ad28edb4..13b34e07 100644 --- a/server/reflector/hatchet/workflows/models.py +++ b/server/reflector/hatchet/workflows/models.py @@ -9,16 +9,16 @@ from typing import Any from pydantic import BaseModel -# ============================================================================ -# Track Processing Results (track_processing.py) -# ============================================================================ +from reflector.utils.string import NonEmptyString class PadTrackResult(BaseModel): """Result from pad_track task.""" - padded_key: str # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay - bucket_name: str | None # None means use default transcript storage bucket + padded_key: NonEmptyString # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay + bucket_name: ( + NonEmptyString | None + ) # None means use default transcript storage bucket size: int track_index: int @@ -30,16 +30,11 @@ class TranscribeTrackResult(BaseModel): track_index: int -# ============================================================================ -# Diarization Pipeline Results (diarization_pipeline.py) -# ============================================================================ - - class RecordingResult(BaseModel): """Result from get_recording task.""" - id: str | None - mtg_session_id: str | None + id: NonEmptyString | None + mtg_session_id: NonEmptyString | None duration: float @@ -48,15 +43,15 @@ class ParticipantsResult(BaseModel): participants: list[dict[str, Any]] num_tracks: int - source_language: str - target_language: str + source_language: NonEmptyString + target_language: NonEmptyString class PaddedTrackInfo(BaseModel): """Info for a padded track - S3 key + bucket for on-demand presigning.""" - key: str - bucket_name: str | None # None = use default storage bucket + key: NonEmptyString + bucket_name: NonEmptyString | None # None = use default storage bucket class ProcessTracksResult(BaseModel): @@ -66,14 +61,14 @@ class ProcessTracksResult(BaseModel): padded_tracks: list[PaddedTrackInfo] # S3 keys, not presigned URLs word_count: int num_tracks: int - target_language: str - created_padded_files: list[str] + target_language: NonEmptyString + created_padded_files: list[NonEmptyString] class MixdownResult(BaseModel): """Result from mixdown_tracks task.""" - audio_key: str + audio_key: NonEmptyString duration: float tracks_mixed: int @@ -106,7 +101,7 @@ class SummaryResult(BaseModel): class FinalizeResult(BaseModel): """Result from finalize task.""" - status: str + status: NonEmptyString class ConsentResult(BaseModel): diff --git a/server/reflector/hatchet/workflows/track_processing.py b/server/reflector/hatchet/workflows/track_processing.py index b5aaa87e..45873daf 100644 --- a/server/reflector/hatchet/workflows/track_processing.py +++ b/server/reflector/hatchet/workflows/track_processing.py @@ -43,7 +43,6 @@ class TrackInput(BaseModel): language: str = "en" -# Get hatchet client and define workflow hatchet = HatchetClientManager.get_client() track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput) @@ -76,7 +75,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY, ) - # Get presigned URL for source file source_url = await storage.get_file_url( input.s3_key, operation="get_object", @@ -84,7 +82,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: bucket=input.bucket_name, ) - # Open container and extract start time with av.open(source_url) as in_container: start_time_seconds = extract_stream_start_time_from_container( in_container, input.track_index, logger=logger @@ -103,7 +100,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: track_index=input.track_index, ) - # Create temp file for padded output with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file: temp_path = temp_file.name diff --git a/server/reflector/pipelines/main_file_pipeline.py b/server/reflector/pipelines/main_file_pipeline.py index aff6e042..85719be5 100644 --- a/server/reflector/pipelines/main_file_pipeline.py +++ b/server/reflector/pipelines/main_file_pipeline.py @@ -97,13 +97,8 @@ class PipelineMainFile(PipelineMainBase): }, ) - # Extract audio and write to transcript location audio_path = await self.extract_and_write_audio(file_path, transcript) - - # Upload for processing audio_url = await self.upload_audio(audio_path, transcript) - - # Run parallel processing await self.run_parallel_processing( audio_path, audio_url, @@ -197,7 +192,6 @@ class PipelineMainFile(PipelineMainBase): transcript_result = results[0] diarization_result = results[1] - # Handle errors - raise any exception that occurred self._handle_gather_exceptions(results, "parallel processing") for result in results: if isinstance(result, Exception): @@ -212,7 +206,6 @@ class PipelineMainFile(PipelineMainBase): transcript=transcript_result, diarization=diarization_result or [] ) - # Store result for retrieval diarized_transcript: Transcript | None = None async def capture_result(transcript): @@ -348,7 +341,6 @@ async def task_pipeline_file_process(*, transcript_id: str): try: await pipeline.set_status(transcript_id, "processing") - # Find the file to process audio_file = next(transcript.data_path.glob("upload.*"), None) if not audio_file: audio_file = next(transcript.data_path.glob("audio.*"), None) diff --git a/server/reflector/pipelines/main_multitrack_pipeline.py b/server/reflector/pipelines/main_multitrack_pipeline.py index 72efbf5a..abfaac57 100644 --- a/server/reflector/pipelines/main_multitrack_pipeline.py +++ b/server/reflector/pipelines/main_multitrack_pipeline.py @@ -159,7 +159,6 @@ class PipelineMainMultitrack(PipelineMainBase): with open(temp_path, "rb") as padded_file: await storage.put_file(storage_path, padded_file) finally: - # Clean up temp file Path(temp_path).unlink(missing_ok=True) padded_url = await storage.get_file_url( @@ -196,7 +195,6 @@ class PipelineMainMultitrack(PipelineMainBase): offsets_seconds: list[float] | None = None, ) -> None: """Multi-track mixdown using PyAV filter graph (amix), reading from S3 presigned URLs.""" - # Detect sample rate from tracks target_sample_rate = detect_sample_rate_from_tracks( track_urls, logger=self.logger ) @@ -204,7 +202,6 @@ class PipelineMainMultitrack(PipelineMainBase): self.logger.error("Mixdown failed - no decodable audio frames found") raise Exception("Mixdown failed: No decodable audio frames in any track") - # Run mixdown using shared utility await mixdown_tracks_pyav( track_urls, writer,