feat: multitrack cli (#735)

* multitrack cli prd

* prd/todo (no-mistakes)

* multitrack cli (no-mistakes)

* multitrack cli (no-mistakes)

* multitrack cli (no-mistakes)

* multitrack cli (no-mistakes)

* remove multitrack tests (most were worthless)

* useless comments away

* useless comments away

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
Igor Monadical
2025-11-24 10:35:06 -05:00
committed by GitHub
parent 4287f8b8ae
commit 11731c9d38
3 changed files with 643 additions and 19 deletions

View File: reflector/tools/cli_multitrack.py

@@ -0,0 +1,347 @@
import asyncio
import sys
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Protocol

import structlog
from celery.result import AsyncResult

from reflector.db import get_database
from reflector.db.transcripts import SourceKind, Transcript, transcripts_controller
from reflector.pipelines.main_multitrack_pipeline import (
    task_pipeline_multitrack_process,
)
from reflector.storage import get_transcripts_storage
from reflector.tools.process import (
    extract_result_from_entry,
    parse_s3_url,
    validate_s3_objects,
)

logger = structlog.get_logger(__name__)

DEFAULT_PROCESSING_TIMEOUT_SECONDS = 3600
MAX_ERROR_MESSAGE_LENGTH = 500
TASK_POLL_INTERVAL_SECONDS = 2


class StatusCallback(Protocol):
    def __call__(self, state: str, elapsed_seconds: int) -> None: ...


@dataclass
class MultitrackTaskResult:
    success: bool
    transcript_id: str
    error: Optional[str] = None


async def create_multitrack_transcript(
    bucket_name: str,
    track_keys: List[str],
    source_language: str,
    target_language: str,
    user_id: Optional[str] = None,
) -> Transcript:
    num_tracks = len(track_keys)
    track_word = "track" if num_tracks == 1 else "tracks"
    transcript_name = f"Multitrack ({num_tracks} {track_word})"
    transcript = await transcripts_controller.add(
        transcript_name,
        source_kind=SourceKind.FILE,
        source_language=source_language,
        target_language=target_language,
        user_id=user_id,
    )
    logger.info(
        "Created multitrack transcript",
        transcript_id=transcript.id,
        name=transcript_name,
        bucket=bucket_name,
        num_tracks=num_tracks,
    )
    return transcript


def submit_multitrack_task(
    transcript_id: str, bucket_name: str, track_keys: List[str]
) -> AsyncResult:
    result = task_pipeline_multitrack_process.delay(
        transcript_id=transcript_id,
        bucket_name=bucket_name,
        track_keys=track_keys,
    )
    logger.info(
        "Multitrack task submitted",
        transcript_id=transcript_id,
        task_id=result.id,
        bucket=bucket_name,
        num_tracks=len(track_keys),
    )
    return result
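
# Note: .delay() only enqueues the Celery task and returns an AsyncResult
# handle immediately; the actual processing happens on a worker.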


async def wait_for_task(
    result: AsyncResult,
    transcript_id: str,
    timeout_seconds: int = DEFAULT_PROCESSING_TIMEOUT_SECONDS,
    poll_interval: int = TASK_POLL_INTERVAL_SECONDS,
    status_callback: Optional[StatusCallback] = None,
) -> MultitrackTaskResult:
    start_time = time.time()
    last_status = None
    while not result.ready():
        elapsed = time.time() - start_time
        if elapsed > timeout_seconds:
            error_msg = (
                f"Task {result.id} did not complete within {timeout_seconds}s "
                f"for transcript {transcript_id}"
            )
            logger.error(
                "Task timeout",
                task_id=result.id,
                transcript_id=transcript_id,
                elapsed_seconds=elapsed,
            )
            raise TimeoutError(error_msg)
        if result.state != last_status:
            if status_callback:
                status_callback(result.state, int(elapsed))
            last_status = result.state
        await asyncio.sleep(poll_interval)

    if result.failed():
        error_info = result.info
        traceback_info = getattr(result, "traceback", None)
        logger.error(
            "Multitrack task failed",
            transcript_id=transcript_id,
            task_id=result.id,
            error=str(error_info),
            has_traceback=bool(traceback_info),
        )
        error_detail = str(error_info)
        if traceback_info:
            error_detail += f"\nTraceback:\n{traceback_info}"
        return MultitrackTaskResult(
            success=False, transcript_id=transcript_id, error=error_detail
        )

    logger.info(
        "Multitrack task completed",
        transcript_id=transcript_id,
        task_id=result.id,
        state=result.state,
    )
    return MultitrackTaskResult(success=True, transcript_id=transcript_id)
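
# Note: polling AsyncResult (.ready()/.state/.failed()) assumes a Celery
# result backend is configured; without one, task state cannot be observed.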


async def update_transcript_status(
    transcript_id: str,
    status: str,
    error: Optional[str] = None,
    max_error_length: int = MAX_ERROR_MESSAGE_LENGTH,
) -> None:
    database = get_database()
    connected = False
    try:
        await database.connect()
        connected = True
        transcript = await transcripts_controller.get_by_id(transcript_id)
        if transcript:
            update_data: Dict[str, Any] = {"status": status}
            if error:
                if len(error) > max_error_length:
                    error = error[: max_error_length - 3] + "..."
                update_data["error"] = error
            await transcripts_controller.update(transcript, update_data)
            logger.info(
                "Updated transcript status",
                transcript_id=transcript_id,
                status=status,
                has_error=bool(error),
            )
    except Exception as e:
        logger.warning(
            "Failed to update transcript status",
            transcript_id=transcript_id,
            error=str(e),
        )
    finally:
        if connected:
            try:
                await database.disconnect()
            except Exception as e:
                logger.warning(f"Database disconnect failed: {e}")


async def process_multitrack(
    bucket_name: str,
    track_keys: List[str],
    source_language: str,
    target_language: str,
    user_id: Optional[str] = None,
    timeout_seconds: int = DEFAULT_PROCESSING_TIMEOUT_SECONDS,
    status_callback: Optional[StatusCallback] = None,
) -> MultitrackTaskResult:
    """High-level orchestration for multitrack processing."""
    database = get_database()
    transcript = None
    connected = False
    try:
        await database.connect()
        connected = True
        transcript = await create_multitrack_transcript(
            bucket_name=bucket_name,
            track_keys=track_keys,
            source_language=source_language,
            target_language=target_language,
            user_id=user_id,
        )
        result = submit_multitrack_task(
            transcript_id=transcript.id, bucket_name=bucket_name, track_keys=track_keys
        )
    except Exception as e:
        if transcript:
            try:
                await update_transcript_status(
                    transcript_id=transcript.id, status="failed", error=str(e)
                )
            except Exception as update_error:
                logger.error(
                    "Failed to update transcript status after error",
                    original_error=str(e),
                    update_error=str(update_error),
                    transcript_id=transcript.id,
                )
        raise
    finally:
        if connected:
            try:
                await database.disconnect()
            except Exception as e:
                logger.warning(f"Database disconnect failed: {e}")

    # Poll outside database connection
    task_result = await wait_for_task(
        result=result,
        transcript_id=transcript.id,
        timeout_seconds=timeout_seconds,
        poll_interval=TASK_POLL_INTERVAL_SECONDS,
        status_callback=status_callback,
    )
    if not task_result.success:
        await update_transcript_status(
            transcript_id=transcript.id, status="failed", error=task_result.error
        )
    return task_result


def print_progress(message: str) -> None:
    """Print progress message to stderr for CLI visibility."""
    print(message, file=sys.stderr)


def create_status_callback() -> StatusCallback:
    """Create callback for task status updates during polling."""

    def callback(state: str, elapsed_seconds: int) -> None:
        print_progress(
            f"Multitrack pipeline status: {state} (elapsed: {elapsed_seconds}s)"
        )

    return callback
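
# Illustrative output: with this callback wired into wait_for_task, each state
# transition prints a stderr line such as
#   Multitrack pipeline status: STARTED (elapsed: 4s)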


async def process_multitrack_cli(
    s3_urls: List[str],
    source_language: str,
    target_language: str,
    output_path: Optional[str] = None,
) -> None:
    if not s3_urls:
        raise ValueError("At least one track required for multitrack processing")

    bucket_keys = []
    for url in s3_urls:
        try:
            bucket, key = parse_s3_url(url)
            bucket_keys.append((bucket, key))
        except ValueError as e:
            raise ValueError(f"Invalid S3 URL '{url}': {e}") from e

    buckets = {bucket for bucket, _ in bucket_keys}
    if len(buckets) > 1:
        raise ValueError(
            f"All tracks must be in the same S3 bucket. "
            f"Found {len(buckets)} different buckets: {sorted(buckets)}. "
            f"Please upload all files to a single bucket."
        )

    primary_bucket = bucket_keys[0][0]
    track_keys = [key for _, key in bucket_keys]
    print_progress(
        f"Starting multitrack CLI processing: "
        f"bucket={primary_bucket}, num_tracks={len(track_keys)}, "
        f"source_language={source_language}, target_language={target_language}"
    )

    storage = get_transcripts_storage()
    await validate_s3_objects(storage, bucket_keys)
    print_progress(f"S3 validation complete: {len(bucket_keys)} objects verified")

    result = await process_multitrack(
        bucket_name=primary_bucket,
        track_keys=track_keys,
        source_language=source_language,
        target_language=target_language,
        user_id=None,
        timeout_seconds=DEFAULT_PROCESSING_TIMEOUT_SECONDS,
        status_callback=create_status_callback(),
    )
    if not result.success:
        error_msg = f"Multitrack pipeline failed for transcript {result.transcript_id}\n"
        if result.error:
            error_msg += f"Error: {result.error}\n"
        raise RuntimeError(error_msg)

    print_progress(f"Multitrack processing complete for transcript {result.transcript_id}")

    database = get_database()
    await database.connect()
    try:
        await extract_result_from_entry(result.transcript_id, output_path)
    finally:
        await database.disconnect()
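
For orientation, a minimal programmatic driver for this module might look like the sketch below. The bucket, object keys, and languages are hypothetical; only the process_multitrack_cli signature comes from the file above.

    import asyncio

    from reflector.tools.cli_multitrack import process_multitrack_cli

    # Hypothetical inputs: one object per speaker track, all in the same bucket.
    asyncio.run(
        process_multitrack_cli(
            s3_urls=[
                "s3://my-bucket/meeting/track-1.webm",
                "s3://my-bucket/meeting/track-2.webm",
            ],
            source_language="en",
            target_language="en",
            output_path="output.jsonl",
        )
    )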

View File: reflector/tools/process.py

@@ -9,7 +9,10 @@ import shutil
import sys
import time
from pathlib import Path
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Tuple
+from urllib.parse import unquote, urlparse
+from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
from reflector.logger import logger
@@ -20,10 +23,119 @@ from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipeline_post
from reflector.pipelines.main_live_pipeline import (
    pipeline_process as live_pipeline_process,
)
from reflector.storage import Storage


def validate_s3_bucket_name(bucket: str) -> None:
    if not bucket:
        raise ValueError("Bucket name cannot be empty")
    if len(bucket) > 255:  # absolute max for any region
        raise ValueError(f"Bucket name too long: {len(bucket)} characters (max 255)")


def validate_s3_key(key: str) -> None:
    if not key:
        raise ValueError("S3 key cannot be empty")
    if len(key) > 1024:
        raise ValueError(f"S3 key too long: {len(key)} characters (max 1024)")


def parse_s3_url(url: str) -> Tuple[str, str]:
    parsed = urlparse(url)
    if parsed.scheme == "s3":
        bucket = parsed.netloc
        key = parsed.path.lstrip("/")
        if parsed.fragment:
            logger.debug(
                "URL fragment ignored (not part of S3 key)",
                url=url,
                fragment=parsed.fragment,
            )
        if not bucket or not key:
            raise ValueError(f"Invalid S3 URL: {url} (missing bucket or key)")
        bucket = unquote(bucket)
        key = unquote(key)
        validate_s3_bucket_name(bucket)
        validate_s3_key(key)
        return bucket, key
    elif parsed.scheme in ("http", "https"):
        # Virtual-hosted style: bucket.s3.amazonaws.com or bucket.s3.<region>.amazonaws.com
        if ".s3." in parsed.netloc or parsed.netloc.endswith(".s3.amazonaws.com"):
            bucket = parsed.netloc.split(".")[0]
            key = parsed.path.lstrip("/")
            if parsed.fragment:
                logger.debug("URL fragment ignored", url=url, fragment=parsed.fragment)
            if not bucket or not key:
                raise ValueError(f"Invalid S3 URL: {url} (missing bucket or key)")
            bucket = unquote(bucket)
            key = unquote(key)
            validate_s3_bucket_name(bucket)
            validate_s3_key(key)
            return bucket, key
        # Path style: s3.amazonaws.com/bucket/key or s3.<region>.amazonaws.com/bucket/key
        elif parsed.netloc.startswith("s3.") and "amazonaws.com" in parsed.netloc:
            path_parts = parsed.path.lstrip("/").split("/", 1)
            if len(path_parts) != 2:
                raise ValueError(f"Invalid S3 URL: {url} (missing bucket or key)")
            bucket, key = path_parts
            if parsed.fragment:
                logger.debug("URL fragment ignored", url=url, fragment=parsed.fragment)
            bucket = unquote(bucket)
            key = unquote(key)
            validate_s3_bucket_name(bucket)
            validate_s3_key(key)
            return bucket, key
        else:
            raise ValueError(f"Invalid S3 URL format: {url} (not recognized as S3 URL)")
    else:
        raise ValueError(f"Invalid S3 URL scheme: {url} (must be s3:// or https://)")


async def validate_s3_objects(
    storage: Storage, bucket_keys: List[Tuple[str, str]]
) -> None:
    async with storage.session.client("s3") as client:

        async def check_object(bucket: str, key: str) -> None:
            try:
                await client.head_object(Bucket=bucket, Key=key)
            except ClientError as e:
                error_code = e.response["Error"]["Code"]
                if error_code in ("404", "NoSuchKey"):
                    raise ValueError(f"S3 object not found: s3://{bucket}/{key}") from e
                elif error_code in ("403", "Forbidden", "AccessDenied"):
                    raise ValueError(
                        f"Access denied for S3 object: s3://{bucket}/{key}. "
                        f"Check AWS credentials and permissions"
                    ) from e
                else:
                    raise ValueError(
                        f"S3 error {error_code} for s3://{bucket}/{key}: "
                        f"{e.response['Error'].get('Message', 'Unknown error')}"
                    ) from e
            except NoCredentialsError as e:
                raise ValueError(
                    "AWS credentials not configured. Set AWS_ACCESS_KEY_ID and "
                    "AWS_SECRET_ACCESS_KEY environment variables"
                ) from e
            except BotoCoreError as e:
                raise ValueError(
                    f"AWS service error for s3://{bucket}/{key}: {e}"
                ) from e
            except Exception as e:
                raise ValueError(
                    f"Unexpected error validating s3://{bucket}/{key}: {e}"
                ) from e

        await asyncio.gather(
            *(check_object(bucket, key) for bucket, key in bucket_keys)
        )
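
# The HEAD checks run concurrently via asyncio.gather; the first failure
# propagates as ValueError and aborts validation.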


def serialize_topics(topics: List[TranscriptTopic]) -> List[Dict[str, Any]]:
    """Convert TranscriptTopic objects to JSON-serializable dicts"""
    serialized = []
    for topic in topics:
        topic_dict = topic.model_dump()
@@ -32,7 +144,6 @@ def serialize_topics(topics: List[TranscriptTopic]) -> List[Dict[str, Any]]:
def debug_print_speakers(serialized_topics: List[Dict[str, Any]]) -> None:
    """Print debug info about speakers found in topics"""
    all_speakers = set()
    for topic_dict in serialized_topics:
        for word in topic_dict.get("words", []):
@@ -47,8 +158,6 @@ def debug_print_speakers(serialized_topics: List[Dict[str, Any]]) -> None:
TranscriptId = str
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
# ideally we want to get rid of it at some point
async def prepare_entry(
    source_path: str,
    source_language: str,
@@ -65,9 +174,7 @@ async def prepare_entry(
        user_id=None,
    )
-    logger.info(
-        f"Created empty transcript {transcript.id} for file {file_path.name} because technically we need an empty transcript before we start transcript"
-    )
+    logger.info(f"Created transcript {transcript.id} for {file_path.name}")

    # pipelines expect files as upload.*
@@ -83,7 +190,6 @@ async def prepare_entry(
    return transcript.id

# same reason as prepare_entry
async def extract_result_from_entry(
    transcript_id: TranscriptId, output_path: str
) -> None:
@@ -193,13 +299,20 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Process audio files with speaker diarization"
    )
-    parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
+    parser.add_argument(
+        "source",
+        help="Source file (mp3, wav, mp4...) or comma-separated S3 URLs with --multitrack",
+    )
    parser.add_argument(
        "--pipeline",
        required=False,  # single-track only; enforced at runtime so --multitrack can omit it
        choices=["live", "file"],
        help="Pipeline type to use for processing (live: streaming/incremental, file: batch/parallel)",
    )
+    parser.add_argument(
+        "--multitrack",
+        action="store_true",
+        help="Process multiple audio tracks from comma-separated S3 URLs",
+    )
    parser.add_argument(
        "--source-language", default="en", help="Source language code (default: en)"
    )
@@ -209,12 +322,40 @@ if __name__ == "__main__":
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
args = parser.parse_args()
-    asyncio.run(
-        process(
-            args.source,
-            args.source_language,
-            args.target_language,
-            args.pipeline,
-            args.output,
-        )
-    )
+    if args.multitrack:
+        if not args.source:
+            parser.error("Source URLs required for multitrack processing")
+        s3_urls = [url.strip() for url in args.source.split(",") if url.strip()]
+        if not s3_urls:
+            parser.error("At least one S3 URL required for multitrack processing")
+
+        from reflector.tools.cli_multitrack import process_multitrack_cli
+
+        asyncio.run(
+            process_multitrack_cli(
+                s3_urls,
+                args.source_language,
+                args.target_language,
+                args.output,
+            )
+        )
+    else:
+        if not args.pipeline:
+            parser.error("--pipeline is required for single-track processing")
+        if "," in args.source:
+            parser.error(
+                "Multiple files detected. Use --multitrack flag for multitrack processing"
+            )
+        asyncio.run(
+            process(
+                args.source,
+                args.source_language,
+                args.target_language,
+                args.pipeline,
+                args.output,
+            )
+        )
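
For reference, the two dispatch branches above correspond to invocations like the following. The entry point, file names, and bucket are illustrative assumptions, not taken from this commit.

    # Single-track: local file, pipeline choice required.
    python -m reflector.tools.process meeting.mp4 --pipeline file -o output.jsonl

    # Multitrack: comma-separated S3 URLs in the same bucket; no --pipeline needed.
    python -m reflector.tools.process \
        "s3://my-bucket/meeting/track-1.webm,s3://my-bucket/meeting/track-2.webm" \
        --multitrack --source-language en -o output.jsonl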

View File

@@ -0,0 +1,136 @@
"""Tests for S3 URL parsing functionality in reflector.tools.process"""
import pytest
from reflector.tools.process import parse_s3_url
class TestParseS3URL:
"""Test cases for parse_s3_url function"""
def test_parse_s3_protocol(self):
"""Test parsing s3:// protocol URLs"""
bucket, key = parse_s3_url("s3://my-bucket/path/to/file.webm")
assert bucket == "my-bucket"
assert key == "path/to/file.webm"
def test_parse_s3_protocol_deep_path(self):
"""Test s3:// with deeply nested paths"""
bucket, key = parse_s3_url("s3://bucket-name/very/deep/path/to/audio.mp4")
assert bucket == "bucket-name"
assert key == "very/deep/path/to/audio.mp4"
def test_parse_https_subdomain_format(self):
"""Test parsing https://bucket.s3.amazonaws.com/key format"""
bucket, key = parse_s3_url("https://my-bucket.s3.amazonaws.com/path/file.webm")
assert bucket == "my-bucket"
assert key == "path/file.webm"
def test_parse_https_regional_subdomain(self):
"""Test parsing regional endpoint with subdomain"""
bucket, key = parse_s3_url(
"https://my-bucket.s3.us-west-2.amazonaws.com/path/file.webm"
)
assert bucket == "my-bucket"
assert key == "path/file.webm"
def test_parse_https_path_style(self):
"""Test parsing https://s3.amazonaws.com/bucket/key format"""
bucket, key = parse_s3_url("https://s3.amazonaws.com/my-bucket/path/file.webm")
assert bucket == "my-bucket"
assert key == "path/file.webm"
def test_parse_https_regional_path_style(self):
"""Test parsing regional endpoint with path style"""
bucket, key = parse_s3_url(
"https://s3.us-east-1.amazonaws.com/my-bucket/path/file.webm"
)
assert bucket == "my-bucket"
assert key == "path/file.webm"
def test_parse_url_encoded_keys(self):
"""Test parsing URL-encoded keys"""
bucket, key = parse_s3_url(
"s3://my-bucket/path%20with%20spaces/file%2Bname.webm"
)
assert bucket == "my-bucket"
assert key == "path with spaces/file+name.webm" # Should be decoded
def test_parse_url_encoded_https(self):
"""Test URL-encoded keys with HTTPS format"""
bucket, key = parse_s3_url(
"https://my-bucket.s3.amazonaws.com/file%20with%20spaces.webm"
)
assert bucket == "my-bucket"
assert key == "file with spaces.webm"
def test_invalid_url_no_scheme(self):
"""Test that URLs without scheme raise ValueError"""
with pytest.raises(ValueError, match="Invalid S3 URL scheme"):
parse_s3_url("my-bucket/path/file.webm")
def test_invalid_url_wrong_scheme(self):
"""Test that non-S3 schemes raise ValueError"""
with pytest.raises(ValueError, match="Invalid S3 URL scheme"):
parse_s3_url("ftp://my-bucket/path/file.webm")
def test_invalid_s3_missing_bucket(self):
"""Test s3:// URL without bucket raises ValueError"""
with pytest.raises(ValueError, match="missing bucket or key"):
parse_s3_url("s3:///path/file.webm")
def test_invalid_s3_missing_key(self):
"""Test s3:// URL without key raises ValueError"""
with pytest.raises(ValueError, match="missing bucket or key"):
parse_s3_url("s3://my-bucket/")
def test_invalid_s3_empty_key(self):
"""Test s3:// URL with empty key raises ValueError"""
with pytest.raises(ValueError, match="missing bucket or key"):
parse_s3_url("s3://my-bucket")
def test_invalid_https_not_s3(self):
"""Test HTTPS URL that's not S3 raises ValueError"""
with pytest.raises(ValueError, match="not recognized as S3 URL"):
parse_s3_url("https://example.com/path/file.webm")
def test_invalid_https_subdomain_missing_key(self):
"""Test HTTPS subdomain format without key raises ValueError"""
with pytest.raises(ValueError, match="missing bucket or key"):
parse_s3_url("https://my-bucket.s3.amazonaws.com/")
def test_invalid_https_path_style_missing_parts(self):
"""Test HTTPS path style with missing bucket/key raises ValueError"""
with pytest.raises(ValueError, match="missing bucket or key"):
parse_s3_url("https://s3.amazonaws.com/")
def test_bucket_with_dots(self):
"""Test parsing bucket names with dots"""
bucket, key = parse_s3_url("s3://my.bucket.name/path/file.webm")
assert bucket == "my.bucket.name"
assert key == "path/file.webm"
def test_bucket_with_hyphens(self):
"""Test parsing bucket names with hyphens"""
bucket, key = parse_s3_url("s3://my-bucket-name-123/path/file.webm")
assert bucket == "my-bucket-name-123"
assert key == "path/file.webm"
def test_key_with_special_chars(self):
"""Test keys with various special characters"""
# Note: # is treated as URL fragment separator, not part of key
bucket, key = parse_s3_url("s3://bucket/2024-01-01_12:00:00/file.webm")
assert bucket == "bucket"
assert key == "2024-01-01_12:00:00/file.webm"
def test_fragment_handling(self):
"""Test that URL fragments are properly ignored"""
bucket, key = parse_s3_url("s3://bucket/path/to/file.webm#fragment123")
assert bucket == "bucket"
assert key == "path/to/file.webm" # Fragment not included
def test_http_scheme_s3_url(self):
"""Test that HTTP (not HTTPS) S3 URLs are supported"""
bucket, key = parse_s3_url("http://my-bucket.s3.amazonaws.com/path/file.webm")
assert bucket == "my-bucket"
assert key == "path/file.webm"