Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-20 20:29:06 +00:00
feat: Diarization cli (#509)
* diarisation cli
* feat: s3 upload for modal diarisation cli call
* chore: cleanup
* chore: s3 cleanup improvement
* chore: lint
* chore: cleanup
* chore: cleanup
* chore: cleanup
* chore: cleanup
server/.gitignore (vendored): 1 addition
@@ -180,3 +180,4 @@ reflector.sqlite3
data/
dump.rdb
server/reflector/tools/process_with_diarization.py: new file, 314 lines
@@ -0,0 +1,314 @@
"""
@vibe-generated
Process audio file with diarization support
===========================================

Extended version of process.py that includes speaker diarization.
This tool processes audio files locally without requiring the full server infrastructure.
"""

import asyncio
import tempfile
from pathlib import Path
from typing import List
import uuid

import av
from reflector.logger import logger
from reflector.processors import (
    AudioChunkerProcessor,
    AudioMergeProcessor,
    AudioTranscriptAutoProcessor,
    AudioFileWriterProcessor,
    Pipeline,
    PipelineEvent,
    TranscriptFinalSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptLinerProcessor,
    TranscriptTopicDetectorProcessor,
    TranscriptTranslatorProcessor,
)
from reflector.processors.base import BroadcastProcessor, Processor
from reflector.processors.types import (
    AudioDiarizationInput,
    TitleSummary,
    TitleSummaryWithId,
)


class TopicCollectorProcessor(Processor):
    """Collect topics for diarization"""

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = TitleSummary

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.topics: List[TitleSummaryWithId] = []
        self._topic_id = 0

    async def _push(self, data: TitleSummary):
        # Convert to TitleSummaryWithId and collect
        self._topic_id += 1
        topic_with_id = TitleSummaryWithId(
            id=str(self._topic_id),
            title=data.title,
            summary=data.summary,
            timestamp=data.timestamp,
            duration=data.duration,
            transcript=data.transcript,
        )
        self.topics.append(topic_with_id)

        # Pass through the original topic
        await self.emit(data)

    def get_topics(self) -> List[TitleSummaryWithId]:
        return self.topics


async def process_audio_file_with_diarization(
    filename,
    event_callback,
    only_transcript=False,
    source_language="en",
    target_language="en",
    enable_diarization=True,
    diarization_backend="modal",
):
    # Create temp file for audio if diarization is enabled
    audio_temp_path = None
    if enable_diarization:
        audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        audio_temp_path = audio_temp_file.name
        audio_temp_file.close()

    # Create processor for collecting topics
    topic_collector = TopicCollectorProcessor()

    # Build pipeline for audio processing
    processors = []

    # Add audio file writer at the beginning if diarization is enabled
    if enable_diarization:
        processors.append(AudioFileWriterProcessor(audio_temp_path))

    # Add the rest of the processors
    processors += [
        AudioChunkerProcessor(),
        AudioMergeProcessor(),
        AudioTranscriptAutoProcessor.as_threaded(),
    ]

    processors += [
        TranscriptLinerProcessor(),
        TranscriptTranslatorProcessor.as_threaded(),
    ]

    if not only_transcript:
        processors += [
            TranscriptTopicDetectorProcessor.as_threaded(),
            # Collect topics for diarization
            topic_collector,
            BroadcastProcessor(
                processors=[
                    TranscriptFinalTitleProcessor.as_threaded(),
                    TranscriptFinalSummaryProcessor.as_threaded(),
                ],
            ),
        ]

    # Create main pipeline
    pipeline = Pipeline(*processors)
    pipeline.set_pref("audio:source_language", source_language)
    pipeline.set_pref("audio:target_language", target_language)
    pipeline.describe()
    pipeline.on(event_callback)

    # Start processing audio
    logger.info(f"Opening {filename}")
    container = av.open(filename)
    try:
        logger.info("Start pushing audio into the pipeline")
        for frame in container.decode(audio=0):
            await pipeline.push(frame)
    finally:
        logger.info("Flushing the pipeline")
        await pipeline.flush()

    # Run diarization if enabled and we have topics
    if enable_diarization and not only_transcript and audio_temp_path:
        topics = topic_collector.get_topics()

        if topics:
            logger.info(f"Starting diarization with {len(topics)} topics")

            try:
                # Import diarization processor
                from reflector.processors import AudioDiarizationAutoProcessor

                # Create diarization processor
                diarization_processor = AudioDiarizationAutoProcessor(
                    name=diarization_backend
                )
                diarization_processor.on(event_callback)

                # For Modal backend, we need to upload the file to S3 first
                if diarization_backend == "modal":
                    from reflector.storage import get_transcripts_storage
                    from reflector.utils.s3_temp_file import S3TemporaryFile
                    from datetime import datetime

                    storage = get_transcripts_storage()

                    # Generate a unique filename in evaluation folder
                    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
                    audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"

                    # Use context manager for automatic cleanup
                    async with S3TemporaryFile(storage, audio_filename) as s3_file:
                        # Read and upload the audio file
                        with open(audio_temp_path, "rb") as f:
                            audio_data = f.read()

                        audio_url = await s3_file.upload(audio_data)
                        logger.info(f"Uploaded audio to S3: {audio_filename}")

                        # Create diarization input with S3 URL
                        diarization_input = AudioDiarizationInput(
                            audio_url=audio_url, topics=topics
                        )

                        # Run diarization
                        await diarization_processor.push(diarization_input)
                        await diarization_processor.flush()

                        logger.info("Diarization complete")
                    # File will be automatically cleaned up when exiting the context
                else:
                    # For local backend, use local file path
                    audio_url = audio_temp_path

                    # Create diarization input
                    diarization_input = AudioDiarizationInput(
                        audio_url=audio_url, topics=topics
                    )

                    # Run diarization
                    await diarization_processor.push(diarization_input)
                    await diarization_processor.flush()

                    logger.info("Diarization complete")

            except ImportError as e:
                logger.error(f"Failed to import diarization dependencies: {e}")
                logger.error(
                    "Install with: uv pip install pyannote.audio torch torchaudio"
                )
                logger.error(
                    "And set HF_TOKEN environment variable for pyannote models"
                )
                raise SystemExit(1)
            except Exception as e:
                logger.error(f"Diarization failed: {e}")
                raise SystemExit(1)
        else:
            logger.warning("Skipping diarization: no topics available")

    # Clean up temp file
    if audio_temp_path:
        try:
            Path(audio_temp_path).unlink()
        except Exception as e:
            logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")

    logger.info("All done!")


if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="Process audio files with optional speaker diarization"
    )
    parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
    parser.add_argument(
        "--only-transcript",
        "-t",
        action="store_true",
        help="Only generate transcript without topics/summaries",
    )
    parser.add_argument(
        "--source-language", default="en", help="Source language code (default: en)"
    )
    parser.add_argument(
        "--target-language", default="en", help="Target language code (default: en)"
    )
    parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
    parser.add_argument(
        "--enable-diarization",
        "-d",
        action="store_true",
        help="Enable speaker diarization",
    )
    parser.add_argument(
        "--diarization-backend",
        default="modal",
        choices=["modal"],
        help="Diarization backend to use (default: modal)",
    )
    args = parser.parse_args()

    # Set REDIS_HOST to localhost if not provided
    if "REDIS_HOST" not in os.environ:
        os.environ["REDIS_HOST"] = "localhost"
        logger.info("REDIS_HOST not set, defaulting to localhost")

    output_fd = None
    if args.output:
        output_fd = open(args.output, "w")

    async def event_callback(event: PipelineEvent):
        processor = event.processor
        data = event.data

        # Ignore internal processors
        if processor in (
            "AudioChunkerProcessor",
            "AudioMergeProcessor",
            "AudioFileWriterProcessor",
            "TopicCollectorProcessor",
            "BroadcastProcessor",
        ):
            return

        # If diarization is enabled, skip the original topic events from the pipeline
        # The diarization processor will emit the same topics but with speaker info
        if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
            return

        # Log all events
        logger.info(f"Event: {processor} - {type(data).__name__}")

        # Write to output
        if output_fd:
            output_fd.write(event.model_dump_json())
            output_fd.write("\n")
            output_fd.flush()

    asyncio.run(
        process_audio_file_with_diarization(
            args.source,
            event_callback,
            only_transcript=args.only_transcript,
            source_language=args.source_language,
            target_language=args.target_language,
            enable_diarization=args.enable_diarization,
            diarization_backend=args.diarization_backend,
        )
    )

    if output_fd:
        output_fd.close()
        logger.info(f"Output written to {args.output}")
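A typical invocation of the new CLI, for reference (a sketch, not part of the diff: it assumes the server environment is set up with uv, the transcription/LLM backends are configured, and AWS credentials are available for the Modal + S3 path; Redis falls back to localhost as handled above):

    cd server
    # transcript + topics + summaries + speaker diarization, events written as JSONL
    uv run python reflector/tools/process_with_diarization.py meeting.mp4 \
        --enable-diarization --diarization-backend modal \
        --source-language en --target-language en \
        --output output.jsonl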
server/reflector/tools/test_diarization.py: new file, 96 lines
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
"""
@vibe-generated
Test script for the diarization CLI tool
=========================================

This script helps test the diarization functionality with sample audio files.
"""

import asyncio
import json
import sys
from pathlib import Path

from reflector.logger import logger


async def test_diarization(audio_file: str):
    """Test the diarization functionality"""

    # Import the processing function
    from process_with_diarization import process_audio_file_with_diarization

    # Collect events
    events = []

    async def event_callback(event):
        events.append({
            "processor": event.processor,
            "data": event.data
        })
        logger.info(f"Event from {event.processor}")

    # Process the audio file
    logger.info(f"Processing audio file: {audio_file}")

    try:
        await process_audio_file_with_diarization(
            audio_file,
            event_callback,
            only_transcript=False,
            source_language="en",
            target_language="en",
            enable_diarization=True,
            diarization_backend="modal",
        )

        # Analyze results
        logger.info(f"Processing complete. Received {len(events)} events")

        # Look for diarization results
        diarized_topics = []
        for event in events:
            if "TitleSummary" in event["processor"]:
                # Check if words have speaker information
                if hasattr(event["data"], "transcript") and event["data"].transcript:
                    words = event["data"].transcript.words
                    if words and hasattr(words[0], "speaker"):
                        speakers = set(w.speaker for w in words if hasattr(w, "speaker"))
                        logger.info(f"Found {len(speakers)} speakers in topic: {event['data'].title}")
                        diarized_topics.append(event["data"])

        if diarized_topics:
            logger.info(f"Successfully diarized {len(diarized_topics)} topics")

            # Print sample output
            sample_topic = diarized_topics[0]
            logger.info("Sample diarized output:")
            for i, word in enumerate(sample_topic.transcript.words[:10]):
                logger.info(f"  Word {i}: '{word.text}' - Speaker {word.speaker}")
        else:
            logger.warning("No diarization results found in output")

        return events

    except Exception as e:
        logger.error(f"Error during processing: {e}")
        raise


def main():
    if len(sys.argv) < 2:
        print("Usage: python test_diarization.py <audio_file>")
        sys.exit(1)

    audio_file = sys.argv[1]
    if not Path(audio_file).exists():
        print(f"Error: Audio file '{audio_file}' not found")
        sys.exit(1)

    # Run the test
    asyncio.run(test_diarization(audio_file))


if __name__ == "__main__":
    main()
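The helper can then be pointed at a sample recording; it assumes it is run from server/reflector/tools/ so that the bare process_with_diarization import above resolves (sample_meeting.wav is an illustrative file name):

    python test_diarization.py sample_meeting.wav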
server/reflector/utils/s3_temp_file.py: new file, 149 lines
@@ -0,0 +1,149 @@
"""
@vibe-generated
S3 Temporary File Context Manager

Provides automatic cleanup of S3 files with retry logic and proper error handling.
"""

from typing import Optional

from reflector.storage.base import Storage
from reflector.logger import logger
from reflector.utils.retry import retry


class S3TemporaryFile:
    """
    Async context manager for temporary S3 files with automatic cleanup.

    Ensures that uploaded files are deleted even if exceptions occur during processing.
    Uses retry logic for all S3 operations to handle transient failures.

    Example:
        async with S3TemporaryFile(storage, "temp/audio.wav") as s3_file:
            url = await s3_file.upload(audio_data)
            # Use url for processing
        # File is automatically cleaned up here
    """

    def __init__(self, storage: Storage, filepath: str):
        """
        Initialize the temporary file context.

        Args:
            storage: Storage instance for S3 operations
            filepath: S3 key/path for the temporary file
        """
        self.storage = storage
        self.filepath = filepath
        self.uploaded = False
        self._url: Optional[str] = None

    async def __aenter__(self):
        """Enter the context manager."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Exit the context manager and clean up the file.

        Cleanup is attempted even if an exception occurred during processing.
        Cleanup failures are logged but don't raise exceptions.
        """
        if self.uploaded:
            try:
                await self._delete_with_retry()
                logger.info(f"Successfully cleaned up S3 file: {self.filepath}")
            except Exception as e:
                # Log the error but don't raise - we don't want cleanup failures
                # to mask the original exception
                logger.warning(
                    f"Failed to cleanup S3 file {self.filepath} after retries: {e}"
                )
        return False  # Don't suppress exceptions

    async def upload(self, data: bytes) -> str:
        """
        Upload data to S3 and return the public URL.

        Args:
            data: File data to upload

        Returns:
            Public URL for the uploaded file

        Raises:
            Exception: If upload or URL generation fails after retries
        """
        await self._upload_with_retry(data)
        self.uploaded = True
        self._url = await self._get_url_with_retry()
        return self._url

    @property
    def url(self) -> Optional[str]:
        """Get the URL of the uploaded file, if available."""
        return self._url

    async def _upload_with_retry(self, data: bytes):
        """Upload file to S3 with retry logic."""

        async def upload():
            await self.storage.put_file(self.filepath, data)
            logger.debug(f"Successfully uploaded file to S3: {self.filepath}")
            return True  # Return something to indicate success

        await retry(upload)(
            retry_attempts=3,
            retry_timeout=30.0,
            retry_backoff_interval=0.5,
            retry_backoff_max=5.0,
        )

    async def _get_url_with_retry(self) -> str:
        """Get public URL for the file with retry logic."""

        async def get_url():
            url = await self.storage.get_file_url(self.filepath)
            logger.debug(f"Generated public URL for S3 file: {self.filepath}")
            return url

        return await retry(get_url)(
            retry_attempts=3,
            retry_timeout=30.0,
            retry_backoff_interval=0.5,
            retry_backoff_max=5.0,
        )

    async def _delete_with_retry(self):
        """Delete file from S3 with retry logic."""

        async def delete():
            await self.storage.delete_file(self.filepath)
            logger.debug(f"Successfully deleted S3 file: {self.filepath}")
            return True  # Return something to indicate success

        await retry(delete)(
            retry_attempts=3,
            retry_timeout=30.0,
            retry_backoff_interval=0.5,
            retry_backoff_max=5.0,
        )


# Convenience function for simpler usage
def temporary_s3_file(storage: Storage, filepath: str) -> S3TemporaryFile:
    """
    Create a temporary S3 file context manager.

    This is a convenience wrapper around S3TemporaryFile for simpler usage.

    Args:
        storage: Storage instance for S3 operations
        filepath: S3 key/path for the temporary file

    Example:
        async with temporary_s3_file(storage, "temp/audio.wav") as s3_file:
            url = await s3_file.upload(audio_data)
            # Use url for processing
    """
    return S3TemporaryFile(storage, filepath)
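For context, this is roughly how the CLI above consumes the helper (a sketch; get_transcripts_storage() is the same storage factory imported in process_with_diarization.py, while diarize_via_s3, audio_bytes and key are illustrative names):

    from reflector.storage import get_transcripts_storage
    from reflector.utils.s3_temp_file import S3TemporaryFile

    async def diarize_via_s3(audio_bytes: bytes, key: str) -> None:
        storage = get_transcripts_storage()
        async with S3TemporaryFile(storage, key) as s3_file:
            url = await s3_file.upload(audio_bytes)  # uploaded with retries, public URL resolved
            ...  # hand `url` to the diarization backend while the object still exists
        # on exiting the block the S3 object is deleted (best effort, with retries)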
server/tests/test_s3_temp_file.py: new file, 129 lines
@@ -0,0 +1,129 @@
"""
@vibe-generated
Tests for S3 temporary file context manager.
"""

import pytest
from unittest.mock import Mock, AsyncMock

from reflector.utils.s3_temp_file import S3TemporaryFile


@pytest.mark.asyncio
async def test_successful_upload_and_cleanup():
    """Test that file is uploaded and cleaned up on success."""
    # Mock storage
    mock_storage = Mock()
    mock_storage.put_file = AsyncMock()
    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
    mock_storage.delete_file = AsyncMock()

    # Use context manager
    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
        url = await s3_file.upload(b"test data")
        assert url == "https://example.com/file.wav"
        assert s3_file.url == "https://example.com/file.wav"

    # Verify operations
    mock_storage.put_file.assert_called_once_with("test/file.wav", b"test data")
    mock_storage.get_file_url.assert_called_once_with("test/file.wav")
    mock_storage.delete_file.assert_called_once_with("test/file.wav")


@pytest.mark.asyncio
async def test_cleanup_on_exception():
    """Test that cleanup happens even when an exception occurs."""
    # Mock storage
    mock_storage = Mock()
    mock_storage.put_file = AsyncMock()
    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
    mock_storage.delete_file = AsyncMock()

    # Use context manager with exception
    with pytest.raises(ValueError):
        async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
            await s3_file.upload(b"test data")
            raise ValueError("Simulated error during processing")

    # Verify cleanup still happened
    mock_storage.delete_file.assert_called_once_with("test/file.wav")


@pytest.mark.asyncio
async def test_no_cleanup_if_not_uploaded():
    """Test that cleanup is skipped if file was never uploaded."""
    # Mock storage
    mock_storage = Mock()
    mock_storage.delete_file = AsyncMock()

    # Use context manager without uploading
    async with S3TemporaryFile(mock_storage, "test/file.wav"):
        pass  # Don't upload anything

    # Verify no cleanup attempted
    mock_storage.delete_file.assert_not_called()


@pytest.mark.asyncio
async def test_cleanup_failure_is_logged_not_raised():
    """Test that cleanup failures are logged but don't raise exceptions."""
    # Mock storage
    mock_storage = Mock()
    mock_storage.put_file = AsyncMock()
    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
    mock_storage.delete_file = AsyncMock(side_effect=Exception("Delete failed"))

    # Use context manager - should not raise
    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
        await s3_file.upload(b"test data")

    # Verify delete was attempted (3 times due to retry)
    assert mock_storage.delete_file.call_count == 3


@pytest.mark.asyncio
async def test_upload_retry_on_failure():
    """Test that upload is retried on failure."""
    # Mock storage with failures then success
    mock_storage = Mock()
    mock_storage.put_file = AsyncMock(
        side_effect=[Exception("Network error"), None]  # Fail once, then succeed
    )
    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
    mock_storage.delete_file = AsyncMock()

    # Use context manager
    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
        url = await s3_file.upload(b"test data")
        assert url == "https://example.com/file.wav"

    # Verify upload was retried
    assert mock_storage.put_file.call_count == 2


@pytest.mark.asyncio
async def test_delete_retry_on_failure():
    """Test that delete is retried on failure."""
    # Mock storage
    mock_storage = Mock()
    mock_storage.put_file = AsyncMock()
    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
    mock_storage.delete_file = AsyncMock(
        side_effect=[Exception("Network error"), None]  # Fail once, then succeed
    )

    # Use context manager
    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
        await s3_file.upload(b"test data")

    # Verify delete was retried
    assert mock_storage.delete_file.call_count == 2


@pytest.mark.asyncio
async def test_properties_before_upload():
    """Test that properties work correctly before upload."""
    mock_storage = Mock()

    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
        assert s3_file.url is None
        assert s3_file.uploaded is False
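These are plain pytest-asyncio tests, so they should run standalone (assuming the server dev dependencies, including pytest-asyncio, are installed):

    cd server
    uv run pytest tests/test_s3_temp_file.py -v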