Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-20 20:29:06 +00:00
self-review round
@@ -27,7 +27,6 @@ async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None:
     try:
         ws_manager = get_ws_manager()

-        # Broadcast to transcript room
         await ws_manager.send_json(
             room_id=f"ts:{transcript_id}",
             message=event.model_dump(mode="json"),
@@ -38,7 +37,6 @@ async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None:
             event_type=event.event,
         )

-        # Also broadcast to user room for certain events
         if event.event in USER_ROOM_EVENTS:
             transcript = await transcripts_controller.get_by_id(transcript_id)
             if transcript and transcript.user_id:
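For orientation: broadcast_event pushes every event into the per-transcript room ("ts:{transcript_id}"), and a subset listed in USER_ROOM_EVENTS is additionally fanned out to the owning user's room. A rough sketch of that second branch as a standalone helper; the user-room naming and the contents of USER_ROOM_EVENTS below are assumptions, not taken from this diff:

# Sketch only - room naming and event names below are assumptions.
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE"}


async def broadcast_to_user_room(ws_manager, transcripts_controller, transcript_id: str, event) -> None:
    if event.event not in USER_ROOM_EVENTS:
        return
    transcript = await transcripts_controller.get_by_id(transcript_id)
    if transcript and transcript.user_id:
        await ws_manager.send_json(
            room_id=f"user:{transcript.user_id}",  # hypothetical naming; transcript rooms use "ts:{id}"
            message=event.model_dump(mode="json"),
            event_type=event.event,
        )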
@@ -90,11 +90,6 @@ diarization_pipeline = hatchet.workflow(
 )


-# ============================================================================
-# Helper Functions
-# ============================================================================
-
-
 @asynccontextmanager
 async def fresh_db_connection():
     """Context manager for database connections in Hatchet workers."""
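fresh_db_connection appears to give each Hatchet worker step its own short-lived database connection. A minimal sketch of such a context manager, assuming an async engine with connect/disconnect methods behind a get_database() accessor (neither is shown in this diff):

from contextlib import asynccontextmanager

from reflector.db import get_database  # assumed accessor, not shown in this diff


@asynccontextmanager
async def fresh_db_connection():
    """Open a dedicated DB connection for one worker step and always tear it down."""
    database = get_database()
    await database.connect()
    try:
        yield database
    finally:
        await database.disconnect()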
@@ -177,11 +172,6 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
     return decorator


-# ============================================================================
-# Pipeline Tasks
-# ============================================================================
-
-
 @diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
 @with_error_handling("get_recording")
 async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
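The signature with_error_handling(step_name, set_error_status=True) -> Callable suggests a decorator factory that logs a failing step and optionally marks the transcript as errored before re-raising. A sketch of that shape; the real body is not in the diff, so the status update is only hinted at:

import functools
import logging
from typing import Any, Callable

logger = logging.getLogger(__name__)


def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
    """Wrap an async pipeline step so failures are logged (and presumably recorded) before re-raising."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(input: Any, ctx: Any) -> Any:
            try:
                return await func(input, ctx)
            except Exception:
                logger.exception("Pipeline step %s failed", step_name)
                if set_error_status:
                    # The real decorator presumably sets the transcript status to an error state here.
                    pass
                raise

        return wrapper

    return decorator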
@@ -350,7 +340,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes

     results = await asyncio.gather(*child_coroutines)

     # Get target_language for later use in detect_topics
     target_language = participants_data.get("target_language", "en")

     # Collect results from each track (don't mutate lists while iterating)
@@ -440,13 +429,11 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     if not valid_urls:
         raise ValueError("No valid padded tracks to mixdown")

-    # Detect sample rate from tracks
     target_sample_rate = detect_sample_rate_from_tracks(valid_urls, logger=logger)
     if not target_sample_rate:
         logger.error("Mixdown failed - no decodable audio frames found")
         raise ValueError("No decodable audio frames in any track")

-    # Create temp file and writer for MP3 output
     output_path = tempfile.mktemp(suffix=".mp3")
     duration_ms = [0.0]  # Mutable container for callback capture
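The single-element list is a small closure trick: capture_duration (defined between these hunks, not shown) can write the final duration back into the enclosing scope when the writer reports it. A sketch of what that callback likely looks like:

duration_ms = [0.0]  # mutable container so the nested callback can write back to this scope


def capture_duration(value_ms: float) -> None:
    # Called by AudioFileWriterProcessor via on_duration once the output length is known.
    duration_ms[0] = value_ms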
@@ -455,7 +442,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:

     writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration)

-    # Run mixdown using shared utility
     await mixdown_tracks_pyav(
         valid_urls,
         writer,
@@ -465,7 +451,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     )
     await writer.flush()

-    # Upload to storage
     file_size = Path(output_path).stat().st_size
     storage_path = f"{input.transcript_id}/audio.mp3"
@@ -474,7 +459,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:

     Path(output_path).unlink(missing_ok=True)

-    # Update DB with audio location
     async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
@@ -9,16 +9,16 @@ from typing import Any

 from pydantic import BaseModel

-# ============================================================================
-# Track Processing Results (track_processing.py)
-# ============================================================================
+from reflector.utils.string import NonEmptyString


 class PadTrackResult(BaseModel):
     """Result from pad_track task."""

-    padded_key: str  # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay
-    bucket_name: str | None  # None means use default transcript storage bucket
+    padded_key: NonEmptyString  # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay
+    bucket_name: (
+        NonEmptyString | None
+    )  # None means use default transcript storage bucket
     size: int
     track_index: int
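NonEmptyString is imported from reflector.utils.string; its definition is not part of this diff. A type with that behavior is typically just an annotated str with a length constraint, e.g. this sketch (assumed shape, not the project's actual code):

from typing import Annotated

from pydantic import BaseModel, StringConstraints, ValidationError

# Assumed definition - the real one lives in reflector.utils.string and is not shown here.
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]


class PaddedTrackInfo(BaseModel):
    key: NonEmptyString
    bucket_name: NonEmptyString | None


PaddedTrackInfo(key="abc/padded.webm", bucket_name=None)  # accepted
try:
    PaddedTrackInfo(key="", bucket_name=None)  # an empty key now fails validation
except ValidationError:
    pass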
@@ -30,16 +30,11 @@ class TranscribeTrackResult(BaseModel):
     track_index: int


-# ============================================================================
-# Diarization Pipeline Results (diarization_pipeline.py)
-# ============================================================================
-
-
 class RecordingResult(BaseModel):
     """Result from get_recording task."""

-    id: str | None
-    mtg_session_id: str | None
+    id: NonEmptyString | None
+    mtg_session_id: NonEmptyString | None
     duration: float
@@ -48,15 +43,15 @@ class ParticipantsResult(BaseModel):

     participants: list[dict[str, Any]]
     num_tracks: int
-    source_language: str
-    target_language: str
+    source_language: NonEmptyString
+    target_language: NonEmptyString


 class PaddedTrackInfo(BaseModel):
     """Info for a padded track - S3 key + bucket for on-demand presigning."""

-    key: str
-    bucket_name: str | None  # None = use default storage bucket
+    key: NonEmptyString
+    bucket_name: NonEmptyString | None  # None = use default storage bucket


 class ProcessTracksResult(BaseModel):
@@ -66,14 +61,14 @@ class ProcessTracksResult(BaseModel):
     padded_tracks: list[PaddedTrackInfo]  # S3 keys, not presigned URLs
     word_count: int
     num_tracks: int
-    target_language: str
-    created_padded_files: list[str]
+    target_language: NonEmptyString
+    created_padded_files: list[NonEmptyString]


 class MixdownResult(BaseModel):
     """Result from mixdown_tracks task."""

-    audio_key: str
+    audio_key: NonEmptyString
     duration: float
     tracks_mixed: int
@@ -106,7 +101,7 @@ class SummaryResult(BaseModel):
 class FinalizeResult(BaseModel):
     """Result from finalize task."""

-    status: str
+    status: NonEmptyString


 class ConsentResult(BaseModel):
@@ -43,7 +43,6 @@ class TrackInput(BaseModel):
     language: str = "en"


-# Get hatchet client and define workflow
 hatchet = HatchetClientManager.get_client()

 track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
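The workflow object returned by hatchet.workflow(...) then acts as the decorator factory for individual steps, exactly as diarization_pipeline.task(...) is used above. The decorator line for pad_track sits outside the visible hunks, but under the same pattern it would look roughly like this (timeout and retry values illustrative):

from datetime import timedelta

from hatchet_sdk import Context


@track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
    # Receives the validated TrackInput plus the Hatchet context,
    # and returns a Pydantic result that later steps can consume.
    ...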
@@ -76,7 +75,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
     )

-    # Get presigned URL for source file
     source_url = await storage.get_file_url(
         input.s3_key,
         operation="get_object",
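This presign-on-demand call is also why the result models above store raw S3 keys instead of URLs: a replayed task can mint a fresh URL rather than reuse one that may have expired. A sketch of re-presigning a stored PaddedTrackInfo with the same call, using only the arguments visible in this diff (the real call may pass more):

async def presign_padded_track(storage, info: PaddedTrackInfo) -> str:
    # storage is the same transcript-storage client built above
    return await storage.get_file_url(
        info.key,
        operation="get_object",
        bucket=info.bucket_name,
    )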
@@ -84,7 +82,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         bucket=input.bucket_name,
     )

-    # Open container and extract start time
     with av.open(source_url) as in_container:
         start_time_seconds = extract_stream_start_time_from_container(
             in_container, input.track_index, logger=logger
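extract_stream_start_time_from_container is a reflector utility; conceptually it reads how late a participant's track starts so the padding step can prepend silence. A rough PyAV sketch of that idea, assuming one audio stream per per-participant file (the real helper also takes the track index and a logger, and surely handles more edge cases):

import av


def stream_start_time_seconds(url: str) -> float:
    """Best-effort start time of the first audio stream, in seconds."""
    with av.open(url) as container:
        stream = container.streams.audio[0]
        if stream.start_time is not None:
            # stream.start_time is expressed in units of stream.time_base
            return float(stream.start_time * stream.time_base)
        if container.start_time is not None:
            # container.start_time is in microseconds
            return container.start_time / 1_000_000
        return 0.0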
@@ -103,7 +100,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         track_index=input.track_index,
     )

-    # Create temp file for padded output
     with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
         temp_path = temp_file.name
@@ -97,13 +97,8 @@ class PipelineMainFile(PipelineMainBase):
             },
         )

-        # Extract audio and write to transcript location
         audio_path = await self.extract_and_write_audio(file_path, transcript)
-
-        # Upload for processing
         audio_url = await self.upload_audio(audio_path, transcript)
-
-        # Run parallel processing
         await self.run_parallel_processing(
             audio_path,
             audio_url,
@@ -197,7 +192,6 @@ class PipelineMainFile(PipelineMainBase):
         transcript_result = results[0]
         diarization_result = results[1]

-        # Handle errors - raise any exception that occurred
         self._handle_gather_exceptions(results, "parallel processing")
         for result in results:
             if isinstance(result, Exception):
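_handle_gather_exceptions itself is not in the diff; from the call site it presumably inspects the results of asyncio.gather(..., return_exceptions=True) and re-raises the first failure so the pipeline aborts. A minimal sketch of that helper:

def handle_gather_exceptions(results: list, operation: str) -> None:
    """Re-raise the first exception found in results from asyncio.gather(..., return_exceptions=True)."""
    for result in results:
        if isinstance(result, Exception):
            raise RuntimeError(f"{operation} failed") from result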
@@ -212,7 +206,6 @@ class PipelineMainFile(PipelineMainBase):
             transcript=transcript_result, diarization=diarization_result or []
         )

-        # Store result for retrieval
         diarized_transcript: Transcript | None = None

         async def capture_result(transcript):
@@ -348,7 +341,6 @@ async def task_pipeline_file_process(*, transcript_id: str):
     try:
         await pipeline.set_status(transcript_id, "processing")

-        # Find the file to process
         audio_file = next(transcript.data_path.glob("upload.*"), None)
         if not audio_file:
             audio_file = next(transcript.data_path.glob("audio.*"), None)
@@ -159,7 +159,6 @@ class PipelineMainMultitrack(PipelineMainBase):
             with open(temp_path, "rb") as padded_file:
                 await storage.put_file(storage_path, padded_file)
         finally:
-            # Clean up temp file
             Path(temp_path).unlink(missing_ok=True)

         padded_url = await storage.get_file_url(
@@ -196,7 +195,6 @@ class PipelineMainMultitrack(PipelineMainBase):
         offsets_seconds: list[float] | None = None,
     ) -> None:
         """Multi-track mixdown using PyAV filter graph (amix), reading from S3 presigned URLs."""
-        # Detect sample rate from tracks
         target_sample_rate = detect_sample_rate_from_tracks(
             track_urls, logger=self.logger
         )
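detect_sample_rate_from_tracks is the shared utility used both here and by the Hatchet mixdown task above; in essence it finds the first track that actually decodes and uses its sample rate for the mixdown graph. A PyAV sketch of that detection (the real utility also takes a logger and is presumably more careful):

import av


def detect_sample_rate_from_tracks(track_urls: list[str]) -> int | None:
    """Return the sample rate of the first decodable audio frame across the given tracks, else None."""
    for url in track_urls:
        try:
            with av.open(url) as container:
                for frame in container.decode(audio=0):
                    return frame.sample_rate
        except Exception:
            continue  # undecodable or empty track - try the next one
    return None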
@@ -204,7 +202,6 @@ class PipelineMainMultitrack(PipelineMainBase):
             self.logger.error("Mixdown failed - no decodable audio frames found")
             raise Exception("Mixdown failed: No decodable audio frames in any track")

-        # Run mixdown using shared utility
         await mixdown_tracks_pyav(
             track_urls,
             writer,