Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)
feat: Diarization cli (#509)
* diarisation cli
* feat: s3 upload for modal diarisation cli call
* chore: cleanup
* chore: s3 cleanup improvement
* chore: lint
* chore: cleanup
* chore: cleanup
* chore: cleanup
* chore: cleanup
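A typical invocation of the new tool (a sketch only; the exact entry point and environment depend on how the server package is installed, which this commit does not spell out) is, from the server/ directory:

    python -m reflector.tools.process_with_diarization meeting.mp4 --enable-diarization --output output.jsonl

With the default (and only) modal backend, the transcripts S3 storage must be configured so the uploaded audio is reachable by the diarization service; REDIS_HOST falls back to localhost when unset.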
server/.gitignore (vendored, 1 line changed)
@@ -180,3 +180,4 @@ reflector.sqlite3
 data/
+dump.rdb
server/reflector/tools/process_with_diarization.py (new file, 314 lines)
@@ -0,0 +1,314 @@
"""
@vibe-generated
Process audio file with diarization support
===========================================

Extended version of process.py that includes speaker diarization.
This tool processes audio files locally without requiring the full server infrastructure.
"""

import asyncio
import tempfile
from pathlib import Path
from typing import List
import uuid

import av
from reflector.logger import logger
from reflector.processors import (
    AudioChunkerProcessor,
    AudioMergeProcessor,
    AudioTranscriptAutoProcessor,
    AudioFileWriterProcessor,
    Pipeline,
    PipelineEvent,
    TranscriptFinalSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptLinerProcessor,
    TranscriptTopicDetectorProcessor,
    TranscriptTranslatorProcessor,
)
from reflector.processors.base import BroadcastProcessor, Processor
from reflector.processors.types import (
    AudioDiarizationInput,
    TitleSummary,
    TitleSummaryWithId,
)


class TopicCollectorProcessor(Processor):
    """Collect topics for diarization"""

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = TitleSummary

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.topics: List[TitleSummaryWithId] = []
        self._topic_id = 0

    async def _push(self, data: TitleSummary):
        # Convert to TitleSummaryWithId and collect
        self._topic_id += 1
        topic_with_id = TitleSummaryWithId(
            id=str(self._topic_id),
            title=data.title,
            summary=data.summary,
            timestamp=data.timestamp,
            duration=data.duration,
            transcript=data.transcript,
        )
        self.topics.append(topic_with_id)

        # Pass through the original topic
        await self.emit(data)

    def get_topics(self) -> List[TitleSummaryWithId]:
        return self.topics


async def process_audio_file_with_diarization(
    filename,
    event_callback,
    only_transcript=False,
    source_language="en",
    target_language="en",
    enable_diarization=True,
    diarization_backend="modal",
):
    # Create temp file for audio if diarization is enabled
    audio_temp_path = None
    if enable_diarization:
        audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        audio_temp_path = audio_temp_file.name
        audio_temp_file.close()

    # Create processor for collecting topics
    topic_collector = TopicCollectorProcessor()

    # Build pipeline for audio processing
    processors = []

    # Add audio file writer at the beginning if diarization is enabled
    if enable_diarization:
        processors.append(AudioFileWriterProcessor(audio_temp_path))

    # Add the rest of the processors
    processors += [
        AudioChunkerProcessor(),
        AudioMergeProcessor(),
        AudioTranscriptAutoProcessor.as_threaded(),
    ]

    processors += [
        TranscriptLinerProcessor(),
        TranscriptTranslatorProcessor.as_threaded(),
    ]

    if not only_transcript:
        processors += [
            TranscriptTopicDetectorProcessor.as_threaded(),
            # Collect topics for diarization
            topic_collector,
            BroadcastProcessor(
                processors=[
                    TranscriptFinalTitleProcessor.as_threaded(),
                    TranscriptFinalSummaryProcessor.as_threaded(),
                ],
            ),
        ]

    # Create main pipeline
    pipeline = Pipeline(*processors)
    pipeline.set_pref("audio:source_language", source_language)
    pipeline.set_pref("audio:target_language", target_language)
    pipeline.describe()
    pipeline.on(event_callback)

    # Start processing audio
    logger.info(f"Opening {filename}")
    container = av.open(filename)
    try:
        logger.info("Start pushing audio into the pipeline")
        for frame in container.decode(audio=0):
            await pipeline.push(frame)
    finally:
        logger.info("Flushing the pipeline")
        await pipeline.flush()

    # Run diarization if enabled and we have topics
    if enable_diarization and not only_transcript and audio_temp_path:
        topics = topic_collector.get_topics()

        if topics:
            logger.info(f"Starting diarization with {len(topics)} topics")

            try:
                # Import diarization processor
                from reflector.processors import AudioDiarizationAutoProcessor

                # Create diarization processor
                diarization_processor = AudioDiarizationAutoProcessor(
                    name=diarization_backend
                )
                diarization_processor.on(event_callback)

                # For Modal backend, we need to upload the file to S3 first
                if diarization_backend == "modal":
                    from reflector.storage import get_transcripts_storage
                    from reflector.utils.s3_temp_file import S3TemporaryFile
                    from datetime import datetime

                    storage = get_transcripts_storage()

                    # Generate a unique filename in evaluation folder
                    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
                    audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"

                    # Use context manager for automatic cleanup
                    async with S3TemporaryFile(storage, audio_filename) as s3_file:
                        # Read and upload the audio file
                        with open(audio_temp_path, "rb") as f:
                            audio_data = f.read()

                        audio_url = await s3_file.upload(audio_data)
                        logger.info(f"Uploaded audio to S3: {audio_filename}")

                        # Create diarization input with S3 URL
                        diarization_input = AudioDiarizationInput(
                            audio_url=audio_url, topics=topics
                        )

                        # Run diarization
                        await diarization_processor.push(diarization_input)
                        await diarization_processor.flush()

                        logger.info("Diarization complete")
                    # File will be automatically cleaned up when exiting the context
                else:
                    # For local backend, use local file path
                    audio_url = audio_temp_path

                    # Create diarization input
                    diarization_input = AudioDiarizationInput(
                        audio_url=audio_url, topics=topics
                    )

                    # Run diarization
                    await diarization_processor.push(diarization_input)
                    await diarization_processor.flush()

                    logger.info("Diarization complete")

            except ImportError as e:
                logger.error(f"Failed to import diarization dependencies: {e}")
                logger.error(
                    "Install with: uv pip install pyannote.audio torch torchaudio"
                )
                logger.error(
                    "And set HF_TOKEN environment variable for pyannote models"
                )
                raise SystemExit(1)
            except Exception as e:
                logger.error(f"Diarization failed: {e}")
                raise SystemExit(1)
        else:
            logger.warning("Skipping diarization: no topics available")

    # Clean up temp file
    if audio_temp_path:
        try:
            Path(audio_temp_path).unlink()
        except Exception as e:
            logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")

    logger.info("All done!")


if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="Process audio files with optional speaker diarization"
    )
    parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
    parser.add_argument(
        "--only-transcript",
        "-t",
        action="store_true",
        help="Only generate transcript without topics/summaries",
    )
    parser.add_argument(
        "--source-language", default="en", help="Source language code (default: en)"
    )
    parser.add_argument(
        "--target-language", default="en", help="Target language code (default: en)"
    )
    parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
    parser.add_argument(
        "--enable-diarization",
        "-d",
        action="store_true",
        help="Enable speaker diarization",
    )
    parser.add_argument(
        "--diarization-backend",
        default="modal",
        choices=["modal"],
        help="Diarization backend to use (default: modal)",
    )
    args = parser.parse_args()

    # Set REDIS_HOST to localhost if not provided
    if "REDIS_HOST" not in os.environ:
        os.environ["REDIS_HOST"] = "localhost"
        logger.info("REDIS_HOST not set, defaulting to localhost")

    output_fd = None
    if args.output:
        output_fd = open(args.output, "w")

    async def event_callback(event: PipelineEvent):
        processor = event.processor
        data = event.data

        # Ignore internal processors
        if processor in (
            "AudioChunkerProcessor",
            "AudioMergeProcessor",
            "AudioFileWriterProcessor",
            "TopicCollectorProcessor",
            "BroadcastProcessor",
        ):
            return

        # If diarization is enabled, skip the original topic events from the pipeline
        # The diarization processor will emit the same topics but with speaker info
        if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
            return

        # Log all events
        logger.info(f"Event: {processor} - {type(data).__name__}")

        # Write to output
        if output_fd:
            output_fd.write(event.model_dump_json())
            output_fd.write("\n")
            output_fd.flush()

    asyncio.run(
        process_audio_file_with_diarization(
            args.source,
            event_callback,
            only_transcript=args.only_transcript,
            source_language=args.source_language,
            target_language=args.target_language,
            enable_diarization=args.enable_diarization,
            diarization_backend=args.diarization_backend,
        )
    )

    if output_fd:
        output_fd.close()
        logger.info(f"Output written to {args.output}")
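The JSONL written with --output can be inspected offline to sanity-check a run. A minimal sketch, assuming only that each line is a serialized PipelineEvent carrying a processor field (as written by event.model_dump_json() above); the shape of the data payload varies by processor and is not assumed here:

import json
from collections import Counter

# Count how many events each processor emitted in an output.jsonl produced
# by process_with_diarization.py (the path is illustrative).
with open("output.jsonl") as fd:
    counts = Counter(json.loads(line)["processor"] for line in fd if line.strip())

for processor, n in counts.most_common():
    print(f"{processor}: {n} events")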
server/reflector/tools/test_diarization.py (new file, 96 lines)
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
"""
@vibe-generated
Test script for the diarization CLI tool
=========================================

This script helps test the diarization functionality with sample audio files.
"""

import asyncio
import json
import sys
from pathlib import Path

from reflector.logger import logger


async def test_diarization(audio_file: str):
    """Test the diarization functionality"""

    # Import the processing function
    from process_with_diarization import process_audio_file_with_diarization

    # Collect events
    events = []

    async def event_callback(event):
        events.append({
            "processor": event.processor,
            "data": event.data
        })
        logger.info(f"Event from {event.processor}")

    # Process the audio file
    logger.info(f"Processing audio file: {audio_file}")

    try:
        await process_audio_file_with_diarization(
            audio_file,
            event_callback,
            only_transcript=False,
            source_language="en",
            target_language="en",
            enable_diarization=True,
            diarization_backend="modal",
        )

        # Analyze results
        logger.info(f"Processing complete. Received {len(events)} events")

        # Look for diarization results
        diarized_topics = []
        for event in events:
            if "TitleSummary" in event["processor"]:
                # Check if words have speaker information
                if hasattr(event["data"], "transcript") and event["data"].transcript:
                    words = event["data"].transcript.words
                    if words and hasattr(words[0], "speaker"):
                        speakers = set(w.speaker for w in words if hasattr(w, "speaker"))
                        logger.info(f"Found {len(speakers)} speakers in topic: {event['data'].title}")
                        diarized_topics.append(event["data"])

        if diarized_topics:
            logger.info(f"Successfully diarized {len(diarized_topics)} topics")

            # Print sample output
            sample_topic = diarized_topics[0]
            logger.info("Sample diarized output:")
            for i, word in enumerate(sample_topic.transcript.words[:10]):
                logger.info(f"  Word {i}: '{word.text}' - Speaker {word.speaker}")
        else:
            logger.warning("No diarization results found in output")

        return events

    except Exception as e:
        logger.error(f"Error during processing: {e}")
        raise


def main():
    if len(sys.argv) < 2:
        print("Usage: python test_diarization.py <audio_file>")
        sys.exit(1)

    audio_file = sys.argv[1]
    if not Path(audio_file).exists():
        print(f"Error: Audio file '{audio_file}' not found")
        sys.exit(1)

    # Run the test
    asyncio.run(test_diarization(audio_file))


if __name__ == "__main__":
    main()
server/reflector/utils/s3_temp_file.py (new file, 149 lines)
@@ -0,0 +1,149 @@
"""
@vibe-generated
S3 Temporary File Context Manager

Provides automatic cleanup of S3 files with retry logic and proper error handling.
"""

from typing import Optional
from reflector.storage.base import Storage
from reflector.logger import logger
from reflector.utils.retry import retry


class S3TemporaryFile:
    """
    Async context manager for temporary S3 files with automatic cleanup.

    Ensures that uploaded files are deleted even if exceptions occur during processing.
    Uses retry logic for all S3 operations to handle transient failures.

    Example:
        async with S3TemporaryFile(storage, "temp/audio.wav") as s3_file:
            url = await s3_file.upload(audio_data)
            # Use url for processing
        # File is automatically cleaned up here
    """

    def __init__(self, storage: Storage, filepath: str):
        """
        Initialize the temporary file context.

        Args:
            storage: Storage instance for S3 operations
            filepath: S3 key/path for the temporary file
        """
        self.storage = storage
        self.filepath = filepath
        self.uploaded = False
        self._url: Optional[str] = None

    async def __aenter__(self):
        """Enter the context manager."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Exit the context manager and clean up the file.

        Cleanup is attempted even if an exception occurred during processing.
        Cleanup failures are logged but don't raise exceptions.
        """
        if self.uploaded:
            try:
                await self._delete_with_retry()
                logger.info(f"Successfully cleaned up S3 file: {self.filepath}")
            except Exception as e:
                # Log the error but don't raise - we don't want cleanup failures
                # to mask the original exception
                logger.warning(
                    f"Failed to cleanup S3 file {self.filepath} after retries: {e}"
                )
        return False  # Don't suppress exceptions

    async def upload(self, data: bytes) -> str:
        """
        Upload data to S3 and return the public URL.

        Args:
            data: File data to upload

        Returns:
            Public URL for the uploaded file

        Raises:
            Exception: If upload or URL generation fails after retries
        """
        await self._upload_with_retry(data)
        self.uploaded = True
        self._url = await self._get_url_with_retry()
        return self._url

    @property
    def url(self) -> Optional[str]:
        """Get the URL of the uploaded file, if available."""
        return self._url

    async def _upload_with_retry(self, data: bytes):
        """Upload file to S3 with retry logic."""

        async def upload():
            await self.storage.put_file(self.filepath, data)
            logger.debug(f"Successfully uploaded file to S3: {self.filepath}")
            return True  # Return something to indicate success

        await retry(upload)(
            retry_attempts=3,
            retry_timeout=30.0,
            retry_backoff_interval=0.5,
            retry_backoff_max=5.0,
        )

    async def _get_url_with_retry(self) -> str:
        """Get public URL for the file with retry logic."""

        async def get_url():
            url = await self.storage.get_file_url(self.filepath)
            logger.debug(f"Generated public URL for S3 file: {self.filepath}")
            return url

        return await retry(get_url)(
            retry_attempts=3,
            retry_timeout=30.0,
            retry_backoff_interval=0.5,
            retry_backoff_max=5.0,
        )

    async def _delete_with_retry(self):
        """Delete file from S3 with retry logic."""

        async def delete():
            await self.storage.delete_file(self.filepath)
            logger.debug(f"Successfully deleted S3 file: {self.filepath}")
            return True  # Return something to indicate success

        await retry(delete)(
            retry_attempts=3,
            retry_timeout=30.0,
            retry_backoff_interval=0.5,
            retry_backoff_max=5.0,
        )


# Convenience function for simpler usage
# Plain function (not async): the returned S3TemporaryFile is itself the async
# context manager, so the documented "async with temporary_s3_file(...)" works.
def temporary_s3_file(storage: Storage, filepath: str):
    """
    Create a temporary S3 file context manager.

    This is a convenience wrapper around S3TemporaryFile for simpler usage.

    Args:
        storage: Storage instance for S3 operations
        filepath: S3 key/path for the temporary file

    Example:
        async with temporary_s3_file(storage, "temp/audio.wav") as s3_file:
            url = await s3_file.upload(audio_data)
            # Use url for processing
    """
    return S3TemporaryFile(storage, filepath)
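The retry helper imported from reflector.utils.retry is not part of this diff; the calls above only pin down its shape: retry(fn) wraps a zero-argument coroutine function and returns a callable that accepts the retry settings. A hypothetical, simplified stand-in with that shape (illustration only, not the project's implementation):

import asyncio


def retry(func):
    # Wrap a zero-argument async callable; the returned wrapper takes the retry
    # settings used by S3TemporaryFile above and retries with capped backoff.
    async def wrapper(
        *,
        retry_attempts=3,
        retry_timeout=30.0,
        retry_backoff_interval=0.5,
        retry_backoff_max=5.0,
    ):
        delay = retry_backoff_interval
        last_exc = None
        for attempt in range(retry_attempts):
            try:
                # Bound each attempt by the per-call timeout.
                return await asyncio.wait_for(func(), timeout=retry_timeout)
            except Exception as exc:
                last_exc = exc
                if attempt + 1 < retry_attempts:
                    await asyncio.sleep(delay)
                    delay = min(delay * 2, retry_backoff_max)
        raise last_exc

    return wrapper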
server/tests/test_s3_temp_file.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""
@vibe-generated
Tests for S3 temporary file context manager.
"""

import pytest
from unittest.mock import Mock, AsyncMock
from reflector.utils.s3_temp_file import S3TemporaryFile


@pytest.mark.asyncio
async def test_successful_upload_and_cleanup():
    """Test that file is uploaded and cleaned up on success."""
    # Mock storage
    mock_storage = Mock()
    mock_storage.put_file = AsyncMock()
    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
    mock_storage.delete_file = AsyncMock()

    # Use context manager
    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
        url = await s3_file.upload(b"test data")
        assert url == "https://example.com/file.wav"
        assert s3_file.url == "https://example.com/file.wav"

    # Verify operations
    mock_storage.put_file.assert_called_once_with("test/file.wav", b"test data")
    mock_storage.get_file_url.assert_called_once_with("test/file.wav")
    mock_storage.delete_file.assert_called_once_with("test/file.wav")


@pytest.mark.asyncio
async def test_cleanup_on_exception():
    """Test that cleanup happens even when an exception occurs."""
    # Mock storage
    mock_storage = Mock()
    mock_storage.put_file = AsyncMock()
    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
    mock_storage.delete_file = AsyncMock()

    # Use context manager with exception
    with pytest.raises(ValueError):
        async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
            await s3_file.upload(b"test data")
            raise ValueError("Simulated error during processing")

    # Verify cleanup still happened
    mock_storage.delete_file.assert_called_once_with("test/file.wav")


@pytest.mark.asyncio
async def test_no_cleanup_if_not_uploaded():
    """Test that cleanup is skipped if file was never uploaded."""
    # Mock storage
    mock_storage = Mock()
    mock_storage.delete_file = AsyncMock()

    # Use context manager without uploading
    async with S3TemporaryFile(mock_storage, "test/file.wav"):
        pass  # Don't upload anything

    # Verify no cleanup attempted
    mock_storage.delete_file.assert_not_called()


@pytest.mark.asyncio
async def test_cleanup_failure_is_logged_not_raised():
    """Test that cleanup failures are logged but don't raise exceptions."""
    # Mock storage
    mock_storage = Mock()
    mock_storage.put_file = AsyncMock()
    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
    mock_storage.delete_file = AsyncMock(side_effect=Exception("Delete failed"))

    # Use context manager - should not raise
    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
        await s3_file.upload(b"test data")

    # Verify delete was attempted (3 times due to retry)
    assert mock_storage.delete_file.call_count == 3


@pytest.mark.asyncio
async def test_upload_retry_on_failure():
    """Test that upload is retried on failure."""
    # Mock storage with failures then success
    mock_storage = Mock()
    mock_storage.put_file = AsyncMock(
        side_effect=[Exception("Network error"), None]  # Fail once, then succeed
    )
    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
    mock_storage.delete_file = AsyncMock()

    # Use context manager
    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
        url = await s3_file.upload(b"test data")
        assert url == "https://example.com/file.wav"

    # Verify upload was retried
    assert mock_storage.put_file.call_count == 2


@pytest.mark.asyncio
async def test_delete_retry_on_failure():
    """Test that delete is retried on failure."""
    # Mock storage
    mock_storage = Mock()
    mock_storage.put_file = AsyncMock()
    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
    mock_storage.delete_file = AsyncMock(
        side_effect=[Exception("Network error"), None]  # Fail once, then succeed
    )

    # Use context manager
    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
        await s3_file.upload(b"test data")

    # Verify delete was retried
    assert mock_storage.delete_file.call_count == 2


@pytest.mark.asyncio
async def test_properties_before_upload():
    """Test that properties work correctly before upload."""
    mock_storage = Mock()

    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
        assert s3_file.url is None
        assert s3_file.uploaded is False
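These tests use pytest.mark.asyncio, so a pytest-asyncio-style plugin is assumed to be available in the server test environment; running pytest tests/test_s3_temp_file.py from the server/ directory should exercise them.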