Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)
feat: multitrack cli (#735)
* multitrack cli prd
* prd/todo (no-mistakes)
* multitrack cli (no-mistakes)
* multitrack cli (no-mistakes)
* multitrack cli (no-mistakes)
* multitrack cli (no-mistakes)
* remove multitrack tests most worthless
* useless comments away
* useless comments away

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
server/reflector/tools/cli_multitrack.py (new file, +347 lines)
@@ -0,0 +1,347 @@
import asyncio
import sys
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Protocol

import structlog
from celery.result import AsyncResult

from reflector.db import get_database
from reflector.db.transcripts import SourceKind, Transcript, transcripts_controller
from reflector.pipelines.main_multitrack_pipeline import (
    task_pipeline_multitrack_process,
)
from reflector.storage import get_transcripts_storage
from reflector.tools.process import (
    extract_result_from_entry,
    parse_s3_url,
    validate_s3_objects,
)

logger = structlog.get_logger(__name__)

DEFAULT_PROCESSING_TIMEOUT_SECONDS = 3600

MAX_ERROR_MESSAGE_LENGTH = 500

TASK_POLL_INTERVAL_SECONDS = 2


class StatusCallback(Protocol):
    def __call__(self, state: str, elapsed_seconds: int) -> None: ...


@dataclass
class MultitrackTaskResult:
    success: bool
    transcript_id: str
    error: Optional[str] = None


async def create_multitrack_transcript(
    bucket_name: str,
    track_keys: List[str],
    source_language: str,
    target_language: str,
    user_id: Optional[str] = None,
) -> Transcript:
    num_tracks = len(track_keys)
    track_word = "track" if num_tracks == 1 else "tracks"
    transcript_name = f"Multitrack ({num_tracks} {track_word})"

    transcript = await transcripts_controller.add(
        transcript_name,
        source_kind=SourceKind.FILE,
        source_language=source_language,
        target_language=target_language,
        user_id=user_id,
    )

    logger.info(
        "Created multitrack transcript",
        transcript_id=transcript.id,
        name=transcript_name,
        bucket=bucket_name,
        num_tracks=len(track_keys),
    )

    return transcript


def submit_multitrack_task(
    transcript_id: str, bucket_name: str, track_keys: List[str]
) -> AsyncResult:
    result = task_pipeline_multitrack_process.delay(
        transcript_id=transcript_id,
        bucket_name=bucket_name,
        track_keys=track_keys,
    )

    logger.info(
        "Multitrack task submitted",
        transcript_id=transcript_id,
        task_id=result.id,
        bucket=bucket_name,
        num_tracks=len(track_keys),
    )

    return result


async def wait_for_task(
    result: AsyncResult,
    transcript_id: str,
    timeout_seconds: int = DEFAULT_PROCESSING_TIMEOUT_SECONDS,
    poll_interval: int = TASK_POLL_INTERVAL_SECONDS,
    status_callback: Optional[StatusCallback] = None,
) -> MultitrackTaskResult:
    start_time = time.time()
    last_status = None

    while not result.ready():
        elapsed = time.time() - start_time
        if elapsed > timeout_seconds:
            error_msg = (
                f"Task {result.id} did not complete within {timeout_seconds}s "
                f"for transcript {transcript_id}"
            )
            logger.error(
                "Task timeout",
                task_id=result.id,
                transcript_id=transcript_id,
                elapsed_seconds=elapsed,
            )
            raise TimeoutError(error_msg)

        if result.state != last_status:
            if status_callback:
                status_callback(result.state, int(elapsed))
            last_status = result.state

        await asyncio.sleep(poll_interval)

    if result.failed():
        error_info = result.info
        traceback_info = getattr(result, "traceback", None)

        logger.error(
            "Multitrack task failed",
            transcript_id=transcript_id,
            task_id=result.id,
            error=str(error_info),
            has_traceback=bool(traceback_info),
        )

        error_detail = str(error_info)
        if traceback_info:
            error_detail += f"\nTraceback:\n{traceback_info}"

        return MultitrackTaskResult(
            success=False, transcript_id=transcript_id, error=error_detail
        )

    logger.info(
        "Multitrack task completed",
        transcript_id=transcript_id,
        task_id=result.id,
        state=result.state,
    )

    return MultitrackTaskResult(success=True, transcript_id=transcript_id)


async def update_transcript_status(
    transcript_id: str,
    status: str,
    error: Optional[str] = None,
    max_error_length: int = MAX_ERROR_MESSAGE_LENGTH,
) -> None:
    database = get_database()
    connected = False

    try:
        await database.connect()
        connected = True

        transcript = await transcripts_controller.get_by_id(transcript_id)
        if transcript:
            update_data: Dict[str, Any] = {"status": status}

            if error:
                if len(error) > max_error_length:
                    error = error[: max_error_length - 3] + "..."
                update_data["error"] = error

            await transcripts_controller.update(transcript, update_data)

            logger.info(
                "Updated transcript status",
                transcript_id=transcript_id,
                status=status,
                has_error=bool(error),
            )
    except Exception as e:
        logger.warning(
            "Failed to update transcript status",
            transcript_id=transcript_id,
            error=str(e),
        )
    finally:
        if connected:
            try:
                await database.disconnect()
            except Exception as e:
                logger.warning(f"Database disconnect failed: {e}")


async def process_multitrack(
    bucket_name: str,
    track_keys: List[str],
    source_language: str,
    target_language: str,
    user_id: Optional[str] = None,
    timeout_seconds: int = DEFAULT_PROCESSING_TIMEOUT_SECONDS,
    status_callback: Optional[StatusCallback] = None,
) -> MultitrackTaskResult:
    """High-level orchestration for multitrack processing."""
    database = get_database()
    transcript = None
    connected = False

    try:
        await database.connect()
        connected = True

        transcript = await create_multitrack_transcript(
            bucket_name=bucket_name,
            track_keys=track_keys,
            source_language=source_language,
            target_language=target_language,
            user_id=user_id,
        )

        result = submit_multitrack_task(
            transcript_id=transcript.id, bucket_name=bucket_name, track_keys=track_keys
        )

    except Exception as e:
        if transcript:
            try:
                await update_transcript_status(
                    transcript_id=transcript.id, status="failed", error=str(e)
                )
            except Exception as update_error:
                logger.error(
                    "Failed to update transcript status after error",
                    original_error=str(e),
                    update_error=str(update_error),
                    transcript_id=transcript.id,
                )
        raise
    finally:
        if connected:
            try:
                await database.disconnect()
            except Exception as e:
                logger.warning(f"Database disconnect failed: {e}")

    # Poll outside database connection
    task_result = await wait_for_task(
        result=result,
        transcript_id=transcript.id,
        timeout_seconds=timeout_seconds,
        poll_interval=2,
        status_callback=status_callback,
    )

    if not task_result.success:
        await update_transcript_status(
            transcript_id=transcript.id, status="failed", error=task_result.error
        )

    return task_result


def print_progress(message: str) -> None:
    """Print progress message to stderr for CLI visibility."""
    print(f"{message}", file=sys.stderr)


def create_status_callback() -> StatusCallback:
    """Create callback for task status updates during polling."""

    def callback(state: str, elapsed_seconds: int) -> None:
        print_progress(
            f"Multitrack pipeline status: {state} (elapsed: {elapsed_seconds}s)"
        )

    return callback


async def process_multitrack_cli(
    s3_urls: List[str],
    source_language: str,
    target_language: str,
    output_path: Optional[str] = None,
) -> None:
    if not s3_urls:
        raise ValueError("At least one track required for multitrack processing")

    bucket_keys = []
    for url in s3_urls:
        try:
            bucket, key = parse_s3_url(url)
            bucket_keys.append((bucket, key))
        except ValueError as e:
            raise ValueError(f"Invalid S3 URL '{url}': {e}") from e

    buckets = set(bucket for bucket, _ in bucket_keys)
    if len(buckets) > 1:
        raise ValueError(
            f"All tracks must be in the same S3 bucket. "
            f"Found {len(buckets)} different buckets: {sorted(buckets)}. "
            f"Please upload all files to a single bucket."
        )

    primary_bucket = bucket_keys[0][0]
    track_keys = [key for _, key in bucket_keys]

    print_progress(
        f"Starting multitrack CLI processing: "
        f"bucket={primary_bucket}, num_tracks={len(track_keys)}, "
        f"source_language={source_language}, target_language={target_language}"
    )

    storage = get_transcripts_storage()
    await validate_s3_objects(storage, bucket_keys)
    print_progress(f"S3 validation complete: {len(bucket_keys)} objects verified")

    result = await process_multitrack(
        bucket_name=primary_bucket,
        track_keys=track_keys,
        source_language=source_language,
        target_language=target_language,
        user_id=None,
        timeout_seconds=3600,
        status_callback=create_status_callback(),
    )

    if not result.success:
        error_msg = (
            f"Multitrack pipeline failed for transcript {result.transcript_id}\n"
        )
        if result.error:
            error_msg += f"Error: {result.error}\n"
        raise RuntimeError(error_msg)

    print_progress(
        f"Multitrack processing complete for transcript {result.transcript_id}"
    )

    database = get_database()
    await database.connect()
    try:
        await extract_result_from_entry(result.transcript_id, output_path)
    finally:
        await database.disconnect()
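For orientation, a minimal sketch of driving the new module programmatically rather than through the CLI; it assumes a running Celery worker and AWS credentials in the environment, and the bucket, keys, and output path are placeholders:

import asyncio

from reflector.tools.cli_multitrack import process_multitrack_cli

# Placeholder inputs: two tracks in one bucket, results written to output.jsonl
asyncio.run(
    process_multitrack_cli(
        s3_urls=[
            "s3://my-bucket/tracks/alice.webm",
            "s3://my-bucket/tracks/bob.webm",
        ],
        source_language="en",
        target_language="en",
        output_path="output.jsonl",
    )
)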
server/reflector/tools/process.py (modified)
@@ -9,7 +9,10 @@ import shutil
 import sys
 import time
 from pathlib import Path
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Tuple
+from urllib.parse import unquote, urlparse
+
+from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
 
 from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
 from reflector.logger import logger
@@ -20,10 +23,119 @@ from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipeline_post
 from reflector.pipelines.main_live_pipeline import (
     pipeline_process as live_pipeline_process,
 )
+from reflector.storage import Storage
+
+
+def validate_s3_bucket_name(bucket: str) -> None:
+    if not bucket:
+        raise ValueError("Bucket name cannot be empty")
+    if len(bucket) > 255:  # Absolute max for any region
+        raise ValueError(f"Bucket name too long: {len(bucket)} characters (max 255)")
+
+
+def validate_s3_key(key: str) -> None:
+    if not key:
+        raise ValueError("S3 key cannot be empty")
+    if len(key) > 1024:
+        raise ValueError(f"S3 key too long: {len(key)} characters (max 1024)")
+
+
+def parse_s3_url(url: str) -> Tuple[str, str]:
+    parsed = urlparse(url)
+
+    if parsed.scheme == "s3":
+        bucket = parsed.netloc
+        key = parsed.path.lstrip("/")
+        if parsed.fragment:
+            logger.debug(
+                "URL fragment ignored (not part of S3 key)",
+                url=url,
+                fragment=parsed.fragment,
+            )
+        if not bucket or not key:
+            raise ValueError(f"Invalid S3 URL: {url} (missing bucket or key)")
+        bucket = unquote(bucket)
+        key = unquote(key)
+        validate_s3_bucket_name(bucket)
+        validate_s3_key(key)
+        return bucket, key
+
+    elif parsed.scheme in ("http", "https"):
+        if ".s3." in parsed.netloc or parsed.netloc.endswith(".s3.amazonaws.com"):
+            bucket = parsed.netloc.split(".")[0]
+            key = parsed.path.lstrip("/")
+            if parsed.fragment:
+                logger.debug("URL fragment ignored", url=url, fragment=parsed.fragment)
+            if not bucket or not key:
+                raise ValueError(f"Invalid S3 URL: {url} (missing bucket or key)")
+            bucket = unquote(bucket)
+            key = unquote(key)
+            validate_s3_bucket_name(bucket)
+            validate_s3_key(key)
+            return bucket, key
+
+        elif parsed.netloc.startswith("s3.") and "amazonaws.com" in parsed.netloc:
+            path_parts = parsed.path.lstrip("/").split("/", 1)
+            if len(path_parts) != 2:
+                raise ValueError(f"Invalid S3 URL: {url} (missing bucket or key)")
+            bucket, key = path_parts
+            if parsed.fragment:
+                logger.debug("URL fragment ignored", url=url, fragment=parsed.fragment)
+            bucket = unquote(bucket)
+            key = unquote(key)
+            validate_s3_bucket_name(bucket)
+            validate_s3_key(key)
+            return bucket, key
+
+        else:
+            raise ValueError(f"Invalid S3 URL format: {url} (not recognized as S3 URL)")
+
+    else:
+        raise ValueError(f"Invalid S3 URL scheme: {url} (must be s3:// or https://)")
+
+
+async def validate_s3_objects(
+    storage: Storage, bucket_keys: List[Tuple[str, str]]
+) -> None:
+    async with storage.session.client("s3") as client:
+
+        async def check_object(bucket: str, key: str) -> None:
+            try:
+                await client.head_object(Bucket=bucket, Key=key)
+            except ClientError as e:
+                error_code = e.response["Error"]["Code"]
+                if error_code in ("404", "NoSuchKey"):
+                    raise ValueError(f"S3 object not found: s3://{bucket}/{key}") from e
+                elif error_code in ("403", "Forbidden", "AccessDenied"):
+                    raise ValueError(
+                        f"Access denied for S3 object: s3://{bucket}/{key}. "
+                        f"Check AWS credentials and permissions"
+                    ) from e
+                else:
+                    raise ValueError(
+                        f"S3 error {error_code} for s3://{bucket}/{key}: "
+                        f"{e.response['Error'].get('Message', 'Unknown error')}"
+                    ) from e
+            except NoCredentialsError as e:
+                raise ValueError(
+                    "AWS credentials not configured. Set AWS_ACCESS_KEY_ID and "
+                    "AWS_SECRET_ACCESS_KEY environment variables"
+                ) from e
+            except BotoCoreError as e:
+                raise ValueError(
+                    f"AWS service error for s3://{bucket}/{key}: {str(e)}"
+                ) from e
+            except Exception as e:
+                raise ValueError(
+                    f"Unexpected error validating s3://{bucket}/{key}: {str(e)}"
+                ) from e
+
+        await asyncio.gather(
+            *(check_object(bucket, key) for bucket, key in bucket_keys)
+        )
+
+
 def serialize_topics(topics: List[TranscriptTopic]) -> List[Dict[str, Any]]:
-    """Convert TranscriptTopic objects to JSON-serializable dicts"""
     serialized = []
     for topic in topics:
         topic_dict = topic.model_dump()
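For quick reference, the URL shapes the new parser accepts, sketched with placeholder buckets and keys (behavior mirrors the tests added below):

from reflector.tools.process import parse_s3_url

parse_s3_url("s3://my-bucket/a/b.webm")                              # ("my-bucket", "a/b.webm")
parse_s3_url("https://my-bucket.s3.us-west-2.amazonaws.com/b.webm")  # ("my-bucket", "b.webm")
parse_s3_url("https://s3.amazonaws.com/my-bucket/b.webm")            # ("my-bucket", "b.webm")
parse_s3_url("ftp://my-bucket/b.webm")                               # raises ValueError (bad scheme)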
@@ -32,7 +144,6 @@ def serialize_topics(topics: List[TranscriptTopic]) -> List[Dict[str, Any]]:
 
 
 def debug_print_speakers(serialized_topics: List[Dict[str, Any]]) -> None:
-    """Print debug info about speakers found in topics"""
     all_speakers = set()
     for topic_dict in serialized_topics:
         for word in topic_dict.get("words", []):
@@ -47,8 +158,6 @@ def debug_print_speakers(serialized_topics: List[Dict[str, Any]]) -> None:
 TranscriptId = str
 
 
-# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
-# ideally we want to get rid of it at some point
 async def prepare_entry(
     source_path: str,
     source_language: str,
@@ -65,9 +174,7 @@ async def prepare_entry(
         user_id=None,
     )
 
-    logger.info(
-        f"Created empty transcript {transcript.id} for file {file_path.name} because technically we need an empty transcript before we start transcript"
-    )
+    logger.info(f"Created transcript {transcript.id} for {file_path.name}")
 
     # pipelines expect files as upload.*
 
@@ -83,7 +190,6 @@ async def prepare_entry(
     return transcript.id
 
 
-# same reason as prepare_entry
 async def extract_result_from_entry(
     transcript_id: TranscriptId, output_path: str
 ) -> None:
@@ -193,13 +299,20 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Process audio files with speaker diarization"
     )
-    parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
+    parser.add_argument(
+        "source",
+        help="Source file (mp3, wav, mp4...) or comma-separated S3 URLs with --multitrack",
+    )
     parser.add_argument(
         "--pipeline",
-        required=True,
         choices=["live", "file"],
         help="Pipeline type to use for processing (live: streaming/incremental, file: batch/parallel)",
     )
+    parser.add_argument(
+        "--multitrack",
+        action="store_true",
+        help="Process multiple audio tracks from comma-separated S3 URLs",
+    )
     parser.add_argument(
         "--source-language", default="en", help="Source language code (default: en)"
     )
@@ -209,6 +322,34 @@ if __name__ == "__main__":
     parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
     args = parser.parse_args()
 
+    if args.multitrack:
+        if not args.source:
+            parser.error("Source URLs required for multitrack processing")
+
+        s3_urls = [url.strip() for url in args.source.split(",") if url.strip()]
+
+        if not s3_urls:
+            parser.error("At least one S3 URL required for multitrack processing")
+
+        from reflector.tools.cli_multitrack import process_multitrack_cli
+
+        asyncio.run(
+            process_multitrack_cli(
+                s3_urls,
+                args.source_language,
+                args.target_language,
+                args.output,
+            )
+        )
+    else:
+        if not args.pipeline:
+            parser.error("--pipeline is required for single-track processing")
+
+        if "," in args.source:
+            parser.error(
+                "Multiple files detected. Use --multitrack flag for multitrack processing"
+            )
+
-    asyncio.run(
-        process(
-            args.source,
+        asyncio.run(
+            process(
+                args.source,
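End to end, the new flag would be exercised roughly like this; the `python -m reflector.tools.process` entry point and the `--target-language` flag name are assumed from the surrounding code, and the bucket and keys are placeholders:

# multitrack: comma-separated S3 URLs, all in one bucket
python -m reflector.tools.process \
    "s3://my-bucket/alice.webm,s3://my-bucket/bob.webm" \
    --multitrack --source-language en --target-language en -o output.jsonl

# single-track is unchanged, except --pipeline is now enforced in the else branch rather than by argparse
python -m reflector.tools.process recording.mp4 --pipeline file -o output.jsonl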
server/tests/test_s3_url_parser.py (new file, +136 lines)
@@ -0,0 +1,136 @@
"""Tests for S3 URL parsing functionality in reflector.tools.process"""

import pytest

from reflector.tools.process import parse_s3_url


class TestParseS3URL:
    """Test cases for parse_s3_url function"""

    def test_parse_s3_protocol(self):
        """Test parsing s3:// protocol URLs"""
        bucket, key = parse_s3_url("s3://my-bucket/path/to/file.webm")
        assert bucket == "my-bucket"
        assert key == "path/to/file.webm"

    def test_parse_s3_protocol_deep_path(self):
        """Test s3:// with deeply nested paths"""
        bucket, key = parse_s3_url("s3://bucket-name/very/deep/path/to/audio.mp4")
        assert bucket == "bucket-name"
        assert key == "very/deep/path/to/audio.mp4"

    def test_parse_https_subdomain_format(self):
        """Test parsing https://bucket.s3.amazonaws.com/key format"""
        bucket, key = parse_s3_url("https://my-bucket.s3.amazonaws.com/path/file.webm")
        assert bucket == "my-bucket"
        assert key == "path/file.webm"

    def test_parse_https_regional_subdomain(self):
        """Test parsing regional endpoint with subdomain"""
        bucket, key = parse_s3_url(
            "https://my-bucket.s3.us-west-2.amazonaws.com/path/file.webm"
        )
        assert bucket == "my-bucket"
        assert key == "path/file.webm"

    def test_parse_https_path_style(self):
        """Test parsing https://s3.amazonaws.com/bucket/key format"""
        bucket, key = parse_s3_url("https://s3.amazonaws.com/my-bucket/path/file.webm")
        assert bucket == "my-bucket"
        assert key == "path/file.webm"

    def test_parse_https_regional_path_style(self):
        """Test parsing regional endpoint with path style"""
        bucket, key = parse_s3_url(
            "https://s3.us-east-1.amazonaws.com/my-bucket/path/file.webm"
        )
        assert bucket == "my-bucket"
        assert key == "path/file.webm"

    def test_parse_url_encoded_keys(self):
        """Test parsing URL-encoded keys"""
        bucket, key = parse_s3_url(
            "s3://my-bucket/path%20with%20spaces/file%2Bname.webm"
        )
        assert bucket == "my-bucket"
        assert key == "path with spaces/file+name.webm"  # Should be decoded

    def test_parse_url_encoded_https(self):
        """Test URL-encoded keys with HTTPS format"""
        bucket, key = parse_s3_url(
            "https://my-bucket.s3.amazonaws.com/file%20with%20spaces.webm"
        )
        assert bucket == "my-bucket"
        assert key == "file with spaces.webm"

    def test_invalid_url_no_scheme(self):
        """Test that URLs without scheme raise ValueError"""
        with pytest.raises(ValueError, match="Invalid S3 URL scheme"):
            parse_s3_url("my-bucket/path/file.webm")

    def test_invalid_url_wrong_scheme(self):
        """Test that non-S3 schemes raise ValueError"""
        with pytest.raises(ValueError, match="Invalid S3 URL scheme"):
            parse_s3_url("ftp://my-bucket/path/file.webm")

    def test_invalid_s3_missing_bucket(self):
        """Test s3:// URL without bucket raises ValueError"""
        with pytest.raises(ValueError, match="missing bucket or key"):
            parse_s3_url("s3:///path/file.webm")

    def test_invalid_s3_missing_key(self):
        """Test s3:// URL without key raises ValueError"""
        with pytest.raises(ValueError, match="missing bucket or key"):
            parse_s3_url("s3://my-bucket/")

    def test_invalid_s3_empty_key(self):
        """Test s3:// URL with empty key raises ValueError"""
        with pytest.raises(ValueError, match="missing bucket or key"):
            parse_s3_url("s3://my-bucket")

    def test_invalid_https_not_s3(self):
        """Test HTTPS URL that's not S3 raises ValueError"""
        with pytest.raises(ValueError, match="not recognized as S3 URL"):
            parse_s3_url("https://example.com/path/file.webm")

    def test_invalid_https_subdomain_missing_key(self):
        """Test HTTPS subdomain format without key raises ValueError"""
        with pytest.raises(ValueError, match="missing bucket or key"):
            parse_s3_url("https://my-bucket.s3.amazonaws.com/")

    def test_invalid_https_path_style_missing_parts(self):
        """Test HTTPS path style with missing bucket/key raises ValueError"""
        with pytest.raises(ValueError, match="missing bucket or key"):
            parse_s3_url("https://s3.amazonaws.com/")

    def test_bucket_with_dots(self):
        """Test parsing bucket names with dots"""
        bucket, key = parse_s3_url("s3://my.bucket.name/path/file.webm")
        assert bucket == "my.bucket.name"
        assert key == "path/file.webm"

    def test_bucket_with_hyphens(self):
        """Test parsing bucket names with hyphens"""
        bucket, key = parse_s3_url("s3://my-bucket-name-123/path/file.webm")
        assert bucket == "my-bucket-name-123"
        assert key == "path/file.webm"

    def test_key_with_special_chars(self):
        """Test keys with various special characters"""
        # Note: # is treated as URL fragment separator, not part of key
        bucket, key = parse_s3_url("s3://bucket/2024-01-01_12:00:00/file.webm")
        assert bucket == "bucket"
        assert key == "2024-01-01_12:00:00/file.webm"

    def test_fragment_handling(self):
        """Test that URL fragments are properly ignored"""
        bucket, key = parse_s3_url("s3://bucket/path/to/file.webm#fragment123")
        assert bucket == "bucket"
        assert key == "path/to/file.webm"  # Fragment not included

    def test_http_scheme_s3_url(self):
        """Test that HTTP (not HTTPS) S3 URLs are supported"""
        bucket, key = parse_s3_url("http://my-bucket.s3.amazonaws.com/path/file.webm")
        assert bucket == "my-bucket"
        assert key == "path/file.webm"
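These are pure unit tests with no AWS calls, so they should run standalone, e.g.:

cd server
pytest tests/test_s3_url_parser.py -v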