Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-21 04:39:06 +00:00.
self-review round
@@ -27,7 +27,6 @@ async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None:
     try:
         ws_manager = get_ws_manager()
 
-        # Broadcast to transcript room
         await ws_manager.send_json(
             room_id=f"ts:{transcript_id}",
             message=event.model_dump(mode="json"),
@@ -38,7 +37,6 @@ async def broadcast_event(transcript_id: str, event: TranscriptEvent) -> None:
             event_type=event.event,
         )
 
-        # Also broadcast to user room for certain events
        if event.event in USER_ROOM_EVENTS:
            transcript = await transcripts_controller.get_by_id(transcript_id)
            if transcript and transcript.user_id:
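
Note on the two hunks above: every event goes to the transcript room, and events listed in USER_ROOM_EVENTS are additionally routed to the owner's user room. A minimal sketch of that fan-out; the event names and the user-room naming are illustrative assumptions, not taken from this diff:

    USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE"}  # illustrative subset, not the real list

    def target_rooms(event_name: str, transcript_id: str, user_id: str | None) -> list[str]:
        # Every event reaches the transcript room; selected events also reach the user room.
        rooms = [f"ts:{transcript_id}"]
        if event_name in USER_ROOM_EVENTS and user_id:
            rooms.append(f"user:{user_id}")  # user-room naming is assumed for illustration
        return rooms

    print(target_rooms("STATUS", "abc123", "u42"))  # ['ts:abc123', 'user:u42']
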
@@ -90,11 +90,6 @@ diarization_pipeline = hatchet.workflow(
 )
 
 
-# ============================================================================
-# Helper Functions
-# ============================================================================
-
-
 @asynccontextmanager
 async def fresh_db_connection():
     """Context manager for database connections in Hatchet workers."""
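
The body of fresh_db_connection is not part of this diff; a minimal sketch of the shape such an asynccontextmanager can take, assuming a `databases`-style async handle (dependency, names and URL are illustrative only):

    from contextlib import asynccontextmanager

    from databases import Database  # assumed dependency, for illustration only

    @asynccontextmanager
    async def fresh_db_connection(url: str = "sqlite+aiosqlite:///./reflector.db"):
        """Open a dedicated connection for a Hatchet worker task and always close it."""
        db = Database(url)
        await db.connect()
        try:
            yield db
        finally:
            await db.disconnect()
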
@@ -177,11 +172,6 @@ def with_error_handling(step_name: str, set_error_status: bool = True) -> Callab
     return decorator
 
 
-# ============================================================================
-# Pipeline Tasks
-# ============================================================================
-
-
 @diarization_pipeline.task(execution_timeout=timedelta(seconds=60), retries=3)
 @with_error_handling("get_recording")
 async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
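
with_error_handling itself is unchanged by this commit and its body sits outside the hunk; a sketch of the decorator-factory shape its signature implies (the error-status update is a stand-in, not the real call):

    import functools
    import logging
    from collections.abc import Callable

    logger = logging.getLogger(__name__)

    def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                try:
                    return await func(*args, **kwargs)
                except Exception:
                    logger.exception("pipeline step failed: %s", step_name)
                    if set_error_status:
                        pass  # mark the transcript as errored here (actual call not shown in this diff)
                    raise

            return wrapper

        return decorator
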
@@ -350,7 +340,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
 
     results = await asyncio.gather(*child_coroutines)
 
-    # Get target_language for later use in detect_topics
     target_language = participants_data.get("target_language", "en")
 
     # Collect results from each track (don't mutate lists while iterating)
@@ -440,13 +429,11 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     if not valid_urls:
         raise ValueError("No valid padded tracks to mixdown")
 
-    # Detect sample rate from tracks
     target_sample_rate = detect_sample_rate_from_tracks(valid_urls, logger=logger)
     if not target_sample_rate:
         logger.error("Mixdown failed - no decodable audio frames found")
         raise ValueError("No decodable audio frames in any track")
 
-    # Create temp file and writer for MP3 output
     output_path = tempfile.mktemp(suffix=".mp3")
     duration_ms = [0.0]  # Mutable container for callback capture
 
@@ -455,7 +442,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
 
     writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration)
 
-    # Run mixdown using shared utility
     await mixdown_tracks_pyav(
         valid_urls,
         writer,
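
The `duration_ms = [0.0]` container above exists so the writer's on_duration callback can hand a value back to the enclosing coroutine. A small sketch of that pattern; the callback signature is an assumption:

    duration_ms = [0.0]  # one-element list: mutable, so the callback can write into it

    def capture_duration(value_ms: float) -> None:
        duration_ms[0] = value_ms

    # writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration)
    capture_duration(1234.5)  # simulate the writer reporting the final duration
    print(duration_ms[0])     # 1234.5
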
@@ -465,7 +451,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     )
     await writer.flush()
 
-    # Upload to storage
     file_size = Path(output_path).stat().st_size
     storage_path = f"{input.transcript_id}/audio.mp3"
 
@@ -474,7 +459,6 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
 
     Path(output_path).unlink(missing_ok=True)
 
-    # Update DB with audio location
     async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
 
@@ -9,16 +9,16 @@ from typing import Any
 
 from pydantic import BaseModel
 
-# ============================================================================
-# Track Processing Results (track_processing.py)
-# ============================================================================
+from reflector.utils.string import NonEmptyString
 
 
 class PadTrackResult(BaseModel):
     """Result from pad_track task."""
 
-    padded_key: str  # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay
-    bucket_name: str | None  # None means use default transcript storage bucket
+    padded_key: NonEmptyString  # S3 key (not presigned URL) - presign on demand to avoid stale URLs on replay
+    bucket_name: (
+        NonEmptyString | None
+    )  # None means use default transcript storage bucket
     size: int
     track_index: int
 
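
The real NonEmptyString lives in reflector.utils.string and its definition is not shown in this diff; a typical Pydantic v2 spelling would be an Annotated constrained str, sketched here with an illustrative stand-in model:

    from typing import Annotated

    from pydantic import BaseModel, StringConstraints

    NonEmptyString = Annotated[str, StringConstraints(min_length=1)]  # assumed definition

    class PaddedTrackInfoSketch(BaseModel):  # stand-in for PaddedTrackInfo
        key: NonEmptyString
        bucket_name: NonEmptyString | None  # None = use default storage bucket

    PaddedTrackInfoSketch(key="recordings/track-0.webm", bucket_name=None)  # validates
    # PaddedTrackInfoSketch(key="", bucket_name=None) would raise ValidationError
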
@@ -30,16 +30,11 @@ class TranscribeTrackResult(BaseModel):
     track_index: int
 
 
-# ============================================================================
-# Diarization Pipeline Results (diarization_pipeline.py)
-# ============================================================================
-
-
 class RecordingResult(BaseModel):
     """Result from get_recording task."""
 
-    id: str | None
-    mtg_session_id: str | None
+    id: NonEmptyString | None
+    mtg_session_id: NonEmptyString | None
     duration: float
 
 
@@ -48,15 +43,15 @@ class ParticipantsResult(BaseModel):
 
     participants: list[dict[str, Any]]
     num_tracks: int
-    source_language: str
-    target_language: str
+    source_language: NonEmptyString
+    target_language: NonEmptyString
 
 
 class PaddedTrackInfo(BaseModel):
     """Info for a padded track - S3 key + bucket for on-demand presigning."""
 
-    key: str
-    bucket_name: str | None  # None = use default storage bucket
+    key: NonEmptyString
+    bucket_name: NonEmptyString | None  # None = use default storage bucket
 
 
 class ProcessTracksResult(BaseModel):
@@ -66,14 +61,14 @@ class ProcessTracksResult(BaseModel):
     padded_tracks: list[PaddedTrackInfo]  # S3 keys, not presigned URLs
     word_count: int
     num_tracks: int
-    target_language: str
-    created_padded_files: list[str]
+    target_language: NonEmptyString
+    created_padded_files: list[NonEmptyString]
 
 
 class MixdownResult(BaseModel):
     """Result from mixdown_tracks task."""
 
-    audio_key: str
+    audio_key: NonEmptyString
     duration: float
     tracks_mixed: int
 
@@ -106,7 +101,7 @@ class SummaryResult(BaseModel):
 class FinalizeResult(BaseModel):
     """Result from finalize task."""
 
-    status: str
+    status: NonEmptyString
 
 
 class ConsentResult(BaseModel):
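
Net effect of the str -> NonEmptyString swaps in these result models: fields that previously accepted "" now fail validation at task boundaries. A quick check under the same assumed definition as in the sketch above:

    from typing import Annotated

    from pydantic import BaseModel, StringConstraints, ValidationError

    NonEmptyString = Annotated[str, StringConstraints(min_length=1)]  # assumed definition

    class FinalizeResultSketch(BaseModel):  # stand-in for FinalizeResult
        status: NonEmptyString

    try:
        FinalizeResultSketch(status="")
    except ValidationError as exc:
        print(exc.errors()[0]["type"])  # string_too_short
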
@@ -43,7 +43,6 @@ class TrackInput(BaseModel):
     language: str = "en"
 
 
-# Get hatchet client and define workflow
 hatchet = HatchetClientManager.get_client()
 
 track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
@@ -76,7 +75,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
     )
 
-    # Get presigned URL for source file
     source_url = await storage.get_file_url(
         input.s3_key,
         operation="get_object",
@@ -84,7 +82,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         bucket=input.bucket_name,
     )
 
-    # Open container and extract start time
     with av.open(source_url) as in_container:
         start_time_seconds = extract_stream_start_time_from_container(
             in_container, input.track_index, logger=logger
@@ -103,7 +100,6 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
         track_index=input.track_index,
     )
 
-    # Create temp file for padded output
     with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
         temp_path = temp_file.name
 
@@ -97,13 +97,8 @@ class PipelineMainFile(PipelineMainBase):
             },
         )
 
-        # Extract audio and write to transcript location
         audio_path = await self.extract_and_write_audio(file_path, transcript)
-
-        # Upload for processing
         audio_url = await self.upload_audio(audio_path, transcript)
-
-        # Run parallel processing
         await self.run_parallel_processing(
             audio_path,
             audio_url,
@@ -197,7 +192,6 @@ class PipelineMainFile(PipelineMainBase):
         transcript_result = results[0]
         diarization_result = results[1]
 
-        # Handle errors - raise any exception that occurred
         self._handle_gather_exceptions(results, "parallel processing")
         for result in results:
             if isinstance(result, Exception):
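
_handle_gather_exceptions is not shown in this diff; the surrounding code implies gather(..., return_exceptions=True) followed by a re-raise pass. A sketch of that pattern with a hypothetical helper name:

    import asyncio

    def handle_gather_exceptions(results: list, label: str) -> None:
        # Re-raise the first failure so the task fails loudly instead of silently.
        for result in results:
            if isinstance(result, Exception):
                raise RuntimeError(f"{label} failed") from result

    async def main() -> None:
        async def transcribe() -> str:
            return "transcript"

        async def diarize() -> None:
            raise ValueError("diarization failed")

        results = await asyncio.gather(transcribe(), diarize(), return_exceptions=True)
        handle_gather_exceptions(results, "parallel processing")

    # asyncio.run(main())  # raises RuntimeError chained from the ValueError
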
@@ -212,7 +206,6 @@ class PipelineMainFile(PipelineMainBase):
             transcript=transcript_result, diarization=diarization_result or []
         )
 
-        # Store result for retrieval
         diarized_transcript: Transcript | None = None
 
         async def capture_result(transcript):
@@ -348,7 +341,6 @@ async def task_pipeline_file_process(*, transcript_id: str):
     try:
         await pipeline.set_status(transcript_id, "processing")
 
-        # Find the file to process
         audio_file = next(transcript.data_path.glob("upload.*"), None)
         if not audio_file:
             audio_file = next(transcript.data_path.glob("audio.*"), None)
@@ -159,7 +159,6 @@ class PipelineMainMultitrack(PipelineMainBase):
             with open(temp_path, "rb") as padded_file:
                 await storage.put_file(storage_path, padded_file)
         finally:
-            # Clean up temp file
             Path(temp_path).unlink(missing_ok=True)
 
         padded_url = await storage.get_file_url(
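
Same temp-file lifecycle as in pad_track above: create with delete=False so the path outlives the context manager, then always unlink in finally. A minimal sketch; the write is a stand-in for the remux and upload:

    import tempfile
    from pathlib import Path

    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as temp_file:
        temp_path = temp_file.name

    try:
        Path(temp_path).write_bytes(b"...")  # stand-in for the padded-audio remux + storage.put_file
    finally:
        Path(temp_path).unlink(missing_ok=True)  # temp file is removed even if the upload fails
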
@@ -196,7 +195,6 @@ class PipelineMainMultitrack(PipelineMainBase):
         offsets_seconds: list[float] | None = None,
     ) -> None:
         """Multi-track mixdown using PyAV filter graph (amix), reading from S3 presigned URLs."""
-        # Detect sample rate from tracks
         target_sample_rate = detect_sample_rate_from_tracks(
             track_urls, logger=self.logger
         )
@@ -204,7 +202,6 @@ class PipelineMainMultitrack(PipelineMainBase):
             self.logger.error("Mixdown failed - no decodable audio frames found")
             raise Exception("Mixdown failed: No decodable audio frames in any track")
 
-        # Run mixdown using shared utility
         await mixdown_tracks_pyav(
             track_urls,
             writer,