diff --git a/server/reflector/tools/cli_multitrack.py b/server/reflector/tools/cli_multitrack.py
new file mode 100644
index 00000000..aad5ab2f
--- /dev/null
+++ b/server/reflector/tools/cli_multitrack.py
@@ -0,0 +1,347 @@
+import asyncio
+import sys
+import time
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Protocol
+
+import structlog
+from celery.result import AsyncResult
+
+from reflector.db import get_database
+from reflector.db.transcripts import SourceKind, Transcript, transcripts_controller
+from reflector.pipelines.main_multitrack_pipeline import (
+    task_pipeline_multitrack_process,
+)
+from reflector.storage import get_transcripts_storage
+from reflector.tools.process import (
+    extract_result_from_entry,
+    parse_s3_url,
+    validate_s3_objects,
+)
+
+logger = structlog.get_logger(__name__)
+
+DEFAULT_PROCESSING_TIMEOUT_SECONDS = 3600
+
+MAX_ERROR_MESSAGE_LENGTH = 500
+
+TASK_POLL_INTERVAL_SECONDS = 2
+
+
+class StatusCallback(Protocol):
+    def __call__(self, state: str, elapsed_seconds: int) -> None: ...
+
+
+@dataclass
+class MultitrackTaskResult:
+    success: bool
+    transcript_id: str
+    error: Optional[str] = None
+
+
+async def create_multitrack_transcript(
+    bucket_name: str,
+    track_keys: List[str],
+    source_language: str,
+    target_language: str,
+    user_id: Optional[str] = None,
+) -> Transcript:
+    num_tracks = len(track_keys)
+    track_word = "track" if num_tracks == 1 else "tracks"
+    transcript_name = f"Multitrack ({num_tracks} {track_word})"
+
+    transcript = await transcripts_controller.add(
+        transcript_name,
+        source_kind=SourceKind.FILE,
+        source_language=source_language,
+        target_language=target_language,
+        user_id=user_id,
+    )
+
+    logger.info(
+        "Created multitrack transcript",
+        transcript_id=transcript.id,
+        name=transcript_name,
+        bucket=bucket_name,
+        num_tracks=len(track_keys),
+    )
+
+    return transcript
+
+
+def submit_multitrack_task(
+    transcript_id: str, bucket_name: str, track_keys: List[str]
+) -> AsyncResult:
+    result = task_pipeline_multitrack_process.delay(
+        transcript_id=transcript_id,
+        bucket_name=bucket_name,
+        track_keys=track_keys,
+    )
+
+    logger.info(
+        "Multitrack task submitted",
+        transcript_id=transcript_id,
+        task_id=result.id,
+        bucket=bucket_name,
+        num_tracks=len(track_keys),
+    )
+
+    return result
+
+
+async def wait_for_task(
+    result: AsyncResult,
+    transcript_id: str,
+    timeout_seconds: int = DEFAULT_PROCESSING_TIMEOUT_SECONDS,
+    poll_interval: int = TASK_POLL_INTERVAL_SECONDS,
+    status_callback: Optional[StatusCallback] = None,
+) -> MultitrackTaskResult:
+    start_time = time.time()
+    last_status = None
+
+    while not result.ready():
+        elapsed = time.time() - start_time
+        if elapsed > timeout_seconds:
+            error_msg = (
+                f"Task {result.id} did not complete within {timeout_seconds}s "
+                f"for transcript {transcript_id}"
+            )
+            logger.error(
+                "Task timeout",
+                task_id=result.id,
+                transcript_id=transcript_id,
+                elapsed_seconds=elapsed,
+            )
+            raise TimeoutError(error_msg)
+
+        if result.state != last_status:
+            if status_callback:
+                status_callback(result.state, int(elapsed))
+            last_status = result.state
+
+        await asyncio.sleep(poll_interval)
+
+    if result.failed():
+        error_info = result.info
+        traceback_info = getattr(result, "traceback", None)
+
+        logger.error(
+            "Multitrack task failed",
+            transcript_id=transcript_id,
+            task_id=result.id,
+            error=str(error_info),
+            has_traceback=bool(traceback_info),
+        )
+
+        error_detail = str(error_info)
+        if traceback_info:
+            error_detail += f"\nTraceback:\n{traceback_info}"
+
+        return MultitrackTaskResult(
+            success=False, transcript_id=transcript_id, error=error_detail
+        )
+
+    logger.info(
+        "Multitrack task completed",
+        transcript_id=transcript_id,
+        task_id=result.id,
+        state=result.state,
+    )
+
+    return MultitrackTaskResult(success=True, transcript_id=transcript_id)
+
+
+async def update_transcript_status(
+    transcript_id: str,
+    status: str,
+    error: Optional[str] = None,
+    max_error_length: int = MAX_ERROR_MESSAGE_LENGTH,
+) -> None:
+    database = get_database()
+    connected = False
+
+    try:
+        await database.connect()
+        connected = True
+
+        transcript = await transcripts_controller.get_by_id(transcript_id)
+        if transcript:
+            update_data: Dict[str, Any] = {"status": status}
+
+            if error:
+                if len(error) > max_error_length:
+                    error = error[: max_error_length - 3] + "..."
+                update_data["error"] = error
+
+            await transcripts_controller.update(transcript, update_data)
+
+            logger.info(
+                "Updated transcript status",
+                transcript_id=transcript_id,
+                status=status,
+                has_error=bool(error),
+            )
+    except Exception as e:
+        logger.warning(
+            "Failed to update transcript status",
+            transcript_id=transcript_id,
+            error=str(e),
+        )
+    finally:
+        if connected:
+            try:
+                await database.disconnect()
+            except Exception as e:
+                logger.warning(f"Database disconnect failed: {e}")
+
+
+async def process_multitrack(
+    bucket_name: str,
+    track_keys: List[str],
+    source_language: str,
+    target_language: str,
+    user_id: Optional[str] = None,
+    timeout_seconds: int = DEFAULT_PROCESSING_TIMEOUT_SECONDS,
+    status_callback: Optional[StatusCallback] = None,
+) -> MultitrackTaskResult:
+    """High-level orchestration for multitrack processing."""
+    database = get_database()
+    transcript = None
+    connected = False
+
+    try:
+        await database.connect()
+        connected = True
+
+        transcript = await create_multitrack_transcript(
+            bucket_name=bucket_name,
+            track_keys=track_keys,
+            source_language=source_language,
+            target_language=target_language,
+            user_id=user_id,
+        )
+
+        result = submit_multitrack_task(
+            transcript_id=transcript.id, bucket_name=bucket_name, track_keys=track_keys
+        )
+
+    except Exception as e:
+        if transcript:
+            try:
+                await update_transcript_status(
+                    transcript_id=transcript.id, status="failed", error=str(e)
+                )
+            except Exception as update_error:
+                logger.error(
+                    "Failed to update transcript status after error",
+                    original_error=str(e),
+                    update_error=str(update_error),
+                    transcript_id=transcript.id,
+                )
+        raise
+    finally:
+        if connected:
+            try:
+                await database.disconnect()
+            except Exception as e:
+                logger.warning(f"Database disconnect failed: {e}")
+
+    # Poll outside database connection
+    task_result = await wait_for_task(
+        result=result,
+        transcript_id=transcript.id,
+        timeout_seconds=timeout_seconds,
+        poll_interval=TASK_POLL_INTERVAL_SECONDS,
+        status_callback=status_callback,
+    )
+
+    if not task_result.success:
+        await update_transcript_status(
+            transcript_id=transcript.id, status="failed", error=task_result.error
+        )
+
+    return task_result
+
+
+def print_progress(message: str) -> None:
+    """Print progress message to stderr for CLI visibility."""
+    print(message, file=sys.stderr)
+
+
+def create_status_callback() -> StatusCallback:
+    """Create callback for task status updates during polling."""
+
+    def callback(state: str, elapsed_seconds: int) -> None:
+        print_progress(
+            f"Multitrack pipeline status: {state} (elapsed: {elapsed_seconds}s)"
+        )
+
+    return callback
+
+
+async def process_multitrack_cli(
+    s3_urls: List[str],
+    source_language: str,
+    target_language: str,
+    output_path: Optional[str] = None,
+) -> None:
+    if not s3_urls:
+        raise ValueError("At least one track required for multitrack processing")
+
+    bucket_keys = []
+    for url in s3_urls:
+        try:
+            bucket, key = parse_s3_url(url)
+            bucket_keys.append((bucket, key))
+        except ValueError as e:
+            raise ValueError(f"Invalid S3 URL '{url}': {e}") from e
+
+    buckets = set(bucket for bucket, _ in bucket_keys)
+    if len(buckets) > 1:
+        raise ValueError(
+            f"All tracks must be in the same S3 bucket. "
+            f"Found {len(buckets)} different buckets: {sorted(buckets)}. "
+            f"Please upload all files to a single bucket."
+        )
+
+    primary_bucket = bucket_keys[0][0]
+    track_keys = [key for _, key in bucket_keys]
+
+    print_progress(
+        f"Starting multitrack CLI processing: "
+        f"bucket={primary_bucket}, num_tracks={len(track_keys)}, "
+        f"source_language={source_language}, target_language={target_language}"
+    )
+
+    storage = get_transcripts_storage()
+    await validate_s3_objects(storage, bucket_keys)
+    print_progress(f"S3 validation complete: {len(bucket_keys)} objects verified")
+
+    result = await process_multitrack(
+        bucket_name=primary_bucket,
+        track_keys=track_keys,
+        source_language=source_language,
+        target_language=target_language,
+        user_id=None,
+        timeout_seconds=DEFAULT_PROCESSING_TIMEOUT_SECONDS,
+        status_callback=create_status_callback(),
+    )
+
+    if not result.success:
+        error_msg = (
+            f"Multitrack pipeline failed for transcript {result.transcript_id}\n"
+        )
+        if result.error:
+            error_msg += f"Error: {result.error}\n"
+        raise RuntimeError(error_msg)
+
+    print_progress(
+        f"Multitrack processing complete for transcript {result.transcript_id}"
+    )
+
+    database = get_database()
+    await database.connect()
+    try:
+        await extract_result_from_entry(result.transcript_id, output_path)
+    finally:
+        await database.disconnect()
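
For reference, and not part of the patch: a minimal sketch of driving the new entry point programmatically. The bucket and track keys below are hypothetical, and it assumes a running Celery worker for the multitrack pipeline, a reachable database, and AWS credentials available in the environment.

    import asyncio

    from reflector.tools.cli_multitrack import process_multitrack_cli

    # Two hypothetical per-speaker tracks from the same call, stored in one bucket.
    asyncio.run(
        process_multitrack_cli(
            ["s3://my-bucket/call/alice.webm", "s3://my-bucket/call/bob.webm"],
            source_language="en",
            target_language="en",
            output_path="output.jsonl",
        )
    )
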
diff --git a/server/reflector/tools/process.py b/server/reflector/tools/process.py
index eb770f76..a3a74138 100644
--- a/server/reflector/tools/process.py
+++ b/server/reflector/tools/process.py
@@ -9,7 +9,10 @@ import shutil
 import sys
 import time
 from pathlib import Path
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Tuple
+from urllib.parse import unquote, urlparse
+
+from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
 
 from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
 from reflector.logger import logger
@@ -20,10 +23,119 @@ from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipelin
 from reflector.pipelines.main_live_pipeline import (
     pipeline_process as live_pipeline_process,
 )
+from reflector.storage import Storage
+
+
+def validate_s3_bucket_name(bucket: str) -> None:
+    if not bucket:
+        raise ValueError("Bucket name cannot be empty")
+    if len(bucket) > 255:  # Absolute max for any region
+        raise ValueError(f"Bucket name too long: {len(bucket)} characters (max 255)")
+
+
+def validate_s3_key(key: str) -> None:
+    if not key:
+        raise ValueError("S3 key cannot be empty")
+    if len(key) > 1024:
+        raise ValueError(f"S3 key too long: {len(key)} characters (max 1024)")
+
+
+def parse_s3_url(url: str) -> Tuple[str, str]:
+    parsed = urlparse(url)
+
+    if parsed.scheme == "s3":
+        bucket = parsed.netloc
+        key = parsed.path.lstrip("/")
+        if parsed.fragment:
+            logger.debug(
+                "URL fragment ignored (not part of S3 key)",
+                url=url,
+                fragment=parsed.fragment,
+            )
+        if not bucket or not key:
+            raise ValueError(f"Invalid S3 URL: {url} (missing bucket or key)")
+        bucket = unquote(bucket)
+        key = unquote(key)
+        validate_s3_bucket_name(bucket)
+        validate_s3_key(key)
+        return bucket, key
+
+    elif parsed.scheme in ("http", "https"):
+        if ".s3." in parsed.netloc or parsed.netloc.endswith(".s3.amazonaws.com"):
+            bucket = parsed.netloc.split(".")[0]
+            key = parsed.path.lstrip("/")
+            if parsed.fragment:
+                logger.debug("URL fragment ignored", url=url, fragment=parsed.fragment)
+            if not bucket or not key:
+                raise ValueError(f"Invalid S3 URL: {url} (missing bucket or key)")
+            bucket = unquote(bucket)
+            key = unquote(key)
+            validate_s3_bucket_name(bucket)
+            validate_s3_key(key)
+            return bucket, key
+
+        elif parsed.netloc.startswith("s3.") and "amazonaws.com" in parsed.netloc:
+            path_parts = parsed.path.lstrip("/").split("/", 1)
+            if len(path_parts) != 2:
+                raise ValueError(f"Invalid S3 URL: {url} (missing bucket or key)")
+            bucket, key = path_parts
+            if parsed.fragment:
+                logger.debug("URL fragment ignored", url=url, fragment=parsed.fragment)
+            bucket = unquote(bucket)
+            key = unquote(key)
+            validate_s3_bucket_name(bucket)
+            validate_s3_key(key)
+            return bucket, key
+
+        else:
+            raise ValueError(f"Invalid S3 URL format: {url} (not recognized as S3 URL)")
+
+    else:
+        raise ValueError(f"Invalid S3 URL scheme: {url} (must be s3://, http://, or https://)")
+
+
+async def validate_s3_objects(
+    storage: Storage, bucket_keys: List[Tuple[str, str]]
+) -> None:
+    async with storage.session.client("s3") as client:
+
+        async def check_object(bucket: str, key: str) -> None:
+            try:
+                await client.head_object(Bucket=bucket, Key=key)
+            except ClientError as e:
+                error_code = e.response["Error"]["Code"]
+                if error_code in ("404", "NoSuchKey"):
+                    raise ValueError(f"S3 object not found: s3://{bucket}/{key}") from e
+                elif error_code in ("403", "Forbidden", "AccessDenied"):
+                    raise ValueError(
+                        f"Access denied for S3 object: s3://{bucket}/{key}. "
+                        f"Check AWS credentials and permissions"
+                    ) from e
+                else:
+                    raise ValueError(
+                        f"S3 error {error_code} for s3://{bucket}/{key}: "
+                        f"{e.response['Error'].get('Message', 'Unknown error')}"
+                    ) from e
+            except NoCredentialsError as e:
+                raise ValueError(
+                    "AWS credentials not configured. Set AWS_ACCESS_KEY_ID and "
+                    "AWS_SECRET_ACCESS_KEY environment variables"
+                ) from e
+            except BotoCoreError as e:
+                raise ValueError(
+                    f"AWS service error for s3://{bucket}/{key}: {str(e)}"
+                ) from e
+            except Exception as e:
+                raise ValueError(
+                    f"Unexpected error validating s3://{bucket}/{key}: {str(e)}"
+                ) from e
+
+        await asyncio.gather(
+            *(check_object(bucket, key) for bucket, key in bucket_keys)
+        )
 
 
 def serialize_topics(topics: List[TranscriptTopic]) -> List[Dict[str, Any]]:
-    """Convert TranscriptTopic objects to JSON-serializable dicts"""
     serialized = []
     for topic in topics:
         topic_dict = topic.model_dump()
@@ -32,7 +144,6 @@ def serialize_topics(topics: List[TranscriptTopic]) -> List[Dict[str, Any]]:
 
 
 def debug_print_speakers(serialized_topics: List[Dict[str, Any]]) -> None:
-    """Print debug info about speakers found in topics"""
     all_speakers = set()
     for topic_dict in serialized_topics:
         for word in topic_dict.get("words", []):
@@ -47,8 +158,6 @@ def debug_print_speakers(serialized_topics: List[Dict[str, Any]]) -> None:
 TranscriptId = str
 
 
-# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
-# ideally we want to get rid of it at some point
 async def prepare_entry(
     source_path: str,
     source_language: str,
@@ -65,9 +174,7 @@
         user_id=None,
     )
 
-    logger.info(
-        f"Created empty transcript {transcript.id} for file {file_path.name} because technically we need an empty transcript before we start transcript"
-    )
+    logger.info(f"Created transcript {transcript.id} for {file_path.name}")
 
 
     # pipelines expect files as upload.*
@@ -83,7 +190,6 @@
     return transcript.id
 
 
-# same reason as prepare_entry
 async def extract_result_from_entry(
     transcript_id: TranscriptId, output_path: str
 ) -> None:
@@ -193,13 +299,20 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Process audio files with speaker diarization"
     )
-    parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
+    parser.add_argument(
+        "source",
+        help="Source file (mp3, wav, mp4...) or comma-separated S3 URLs with --multitrack",
+    )
     parser.add_argument(
         "--pipeline",
-        required=True,
         choices=["live", "file"],
         help="Pipeline type to use for processing (live: streaming/incremental, file: batch/parallel)",
     )
+    parser.add_argument(
+        "--multitrack",
+        action="store_true",
+        help="Process multiple audio tracks from comma-separated S3 URLs",
+    )
     parser.add_argument(
         "--source-language", default="en", help="Source language code (default: en)"
     )
@@ -209,12 +322,40 @@
     parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
 
     args = parser.parse_args()
-    asyncio.run(
-        process(
-            args.source,
-            args.source_language,
-            args.target_language,
-            args.pipeline,
-            args.output,
+    if args.multitrack:
+        if not args.source:
+            parser.error("Source URLs required for multitrack processing")
+
+        s3_urls = [url.strip() for url in args.source.split(",") if url.strip()]
+
+        if not s3_urls:
+            parser.error("At least one S3 URL required for multitrack processing")
+
+        from reflector.tools.cli_multitrack import process_multitrack_cli
+
+        asyncio.run(
+            process_multitrack_cli(
+                s3_urls,
+                args.source_language,
+                args.target_language,
+                args.output,
+            )
+        )
+    else:
+        if not args.pipeline:
+            parser.error("--pipeline is required for single-track processing")
+
+        if "," in args.source:
+            parser.error(
+                "Multiple files detected. Use --multitrack flag for multitrack processing"
+            )
+
+        asyncio.run(
+            process(
+                args.source,
+                args.source_language,
+                args.target_language,
+                args.pipeline,
+                args.output,
+            )
         )
-    )
diff --git a/server/tests/test_s3_url_parser.py b/server/tests/test_s3_url_parser.py
new file mode 100644
index 00000000..638f7c29
--- /dev/null
+++ b/server/tests/test_s3_url_parser.py
@@ -0,0 +1,136 @@
+"""Tests for S3 URL parsing functionality in reflector.tools.process"""
+
+import pytest
+
+from reflector.tools.process import parse_s3_url
+
+
+class TestParseS3URL:
+    """Test cases for parse_s3_url function"""
+
+    def test_parse_s3_protocol(self):
+        """Test parsing s3:// protocol URLs"""
+        bucket, key = parse_s3_url("s3://my-bucket/path/to/file.webm")
+        assert bucket == "my-bucket"
+        assert key == "path/to/file.webm"
+
+    def test_parse_s3_protocol_deep_path(self):
+        """Test s3:// with deeply nested paths"""
+        bucket, key = parse_s3_url("s3://bucket-name/very/deep/path/to/audio.mp4")
+        assert bucket == "bucket-name"
+        assert key == "very/deep/path/to/audio.mp4"
+
+    def test_parse_https_subdomain_format(self):
+        """Test parsing https://bucket.s3.amazonaws.com/key format"""
+        bucket, key = parse_s3_url("https://my-bucket.s3.amazonaws.com/path/file.webm")
+        assert bucket == "my-bucket"
+        assert key == "path/file.webm"
+
+    def test_parse_https_regional_subdomain(self):
+        """Test parsing regional endpoint with subdomain"""
+        bucket, key = parse_s3_url(
+            "https://my-bucket.s3.us-west-2.amazonaws.com/path/file.webm"
+        )
+        assert bucket == "my-bucket"
+        assert key == "path/file.webm"
+
+    def test_parse_https_path_style(self):
+        """Test parsing https://s3.amazonaws.com/bucket/key format"""
+        bucket, key = parse_s3_url("https://s3.amazonaws.com/my-bucket/path/file.webm")
+        assert bucket == "my-bucket"
+        assert key == "path/file.webm"
+
+    def test_parse_https_regional_path_style(self):
+        """Test parsing regional endpoint with path style"""
+        bucket, key = parse_s3_url(
+            "https://s3.us-east-1.amazonaws.com/my-bucket/path/file.webm"
+        )
+        assert bucket == "my-bucket"
+        assert key == "path/file.webm"
+
+    def test_parse_url_encoded_keys(self):
+        """Test parsing URL-encoded keys"""
+        bucket, key = parse_s3_url(
+            "s3://my-bucket/path%20with%20spaces/file%2Bname.webm"
+        )
+        assert bucket == "my-bucket"
+        assert key == "path with spaces/file+name.webm"  # Should be decoded
+
+    def test_parse_url_encoded_https(self):
+        """Test URL-encoded keys with HTTPS format"""
+        bucket, key = parse_s3_url(
+            "https://my-bucket.s3.amazonaws.com/file%20with%20spaces.webm"
+        )
+        assert bucket == "my-bucket"
+        assert key == "file with spaces.webm"
+
+    def test_invalid_url_no_scheme(self):
+        """Test that URLs without scheme raise ValueError"""
+        with pytest.raises(ValueError, match="Invalid S3 URL scheme"):
+            parse_s3_url("my-bucket/path/file.webm")
+
+    def test_invalid_url_wrong_scheme(self):
+        """Test that non-S3 schemes raise ValueError"""
+        with pytest.raises(ValueError, match="Invalid S3 URL scheme"):
+            parse_s3_url("ftp://my-bucket/path/file.webm")
+
+    def test_invalid_s3_missing_bucket(self):
+        """Test s3:// URL without bucket raises ValueError"""
+        with pytest.raises(ValueError, match="missing bucket or key"):
+            parse_s3_url("s3:///path/file.webm")
+
+    def test_invalid_s3_missing_key(self):
+        """Test s3:// URL without key raises ValueError"""
+        with pytest.raises(ValueError, match="missing bucket or key"):
+            parse_s3_url("s3://my-bucket/")
+
+    def test_invalid_s3_empty_key(self):
+        """Test s3:// URL with empty key raises ValueError"""
ValueError""" + with pytest.raises(ValueError, match="missing bucket or key"): + parse_s3_url("s3://my-bucket") + + def test_invalid_https_not_s3(self): + """Test HTTPS URL that's not S3 raises ValueError""" + with pytest.raises(ValueError, match="not recognized as S3 URL"): + parse_s3_url("https://example.com/path/file.webm") + + def test_invalid_https_subdomain_missing_key(self): + """Test HTTPS subdomain format without key raises ValueError""" + with pytest.raises(ValueError, match="missing bucket or key"): + parse_s3_url("https://my-bucket.s3.amazonaws.com/") + + def test_invalid_https_path_style_missing_parts(self): + """Test HTTPS path style with missing bucket/key raises ValueError""" + with pytest.raises(ValueError, match="missing bucket or key"): + parse_s3_url("https://s3.amazonaws.com/") + + def test_bucket_with_dots(self): + """Test parsing bucket names with dots""" + bucket, key = parse_s3_url("s3://my.bucket.name/path/file.webm") + assert bucket == "my.bucket.name" + assert key == "path/file.webm" + + def test_bucket_with_hyphens(self): + """Test parsing bucket names with hyphens""" + bucket, key = parse_s3_url("s3://my-bucket-name-123/path/file.webm") + assert bucket == "my-bucket-name-123" + assert key == "path/file.webm" + + def test_key_with_special_chars(self): + """Test keys with various special characters""" + # Note: # is treated as URL fragment separator, not part of key + bucket, key = parse_s3_url("s3://bucket/2024-01-01_12:00:00/file.webm") + assert bucket == "bucket" + assert key == "2024-01-01_12:00:00/file.webm" + + def test_fragment_handling(self): + """Test that URL fragments are properly ignored""" + bucket, key = parse_s3_url("s3://bucket/path/to/file.webm#fragment123") + assert bucket == "bucket" + assert key == "path/to/file.webm" # Fragment not included + + def test_http_scheme_s3_url(self): + """Test that HTTP (not HTTPS) S3 URLs are supported""" + bucket, key = parse_s3_url("http://my-bucket.s3.amazonaws.com/path/file.webm") + assert bucket == "my-bucket" + assert key == "path/file.webm"