diff --git a/server/.gitignore b/server/.gitignore
index 75accda0..8042ce84 100644
--- a/server/.gitignore
+++ b/server/.gitignore
@@ -180,3 +180,4 @@
 reflector.sqlite3
 data/
 dump.rdb
+
diff --git a/server/reflector/tools/process_with_diarization.py b/server/reflector/tools/process_with_diarization.py
new file mode 100644
index 00000000..78726cc1
--- /dev/null
+++ b/server/reflector/tools/process_with_diarization.py
@@ -0,0 +1,314 @@
+"""
+@vibe-generated
+Process audio file with diarization support
+===========================================
+
+Extended version of process.py that includes speaker diarization.
+This tool processes audio files locally without requiring the full server infrastructure.
+"""
+
+import asyncio
+import tempfile
+from pathlib import Path
+from typing import List
+import uuid
+
+import av
+from reflector.logger import logger
+from reflector.processors import (
+    AudioChunkerProcessor,
+    AudioMergeProcessor,
+    AudioTranscriptAutoProcessor,
+    AudioFileWriterProcessor,
+    Pipeline,
+    PipelineEvent,
+    TranscriptFinalSummaryProcessor,
+    TranscriptFinalTitleProcessor,
+    TranscriptLinerProcessor,
+    TranscriptTopicDetectorProcessor,
+    TranscriptTranslatorProcessor,
+)
+from reflector.processors.base import BroadcastProcessor, Processor
+from reflector.processors.types import (
+    AudioDiarizationInput,
+    TitleSummary,
+    TitleSummaryWithId,
+)
+
+
+class TopicCollectorProcessor(Processor):
+    """Collect topics for diarization"""
+
+    INPUT_TYPE = TitleSummary
+    OUTPUT_TYPE = TitleSummary
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.topics: List[TitleSummaryWithId] = []
+        self._topic_id = 0
+
+    async def _push(self, data: TitleSummary):
+        # Convert to TitleSummaryWithId and collect
+        self._topic_id += 1
+        topic_with_id = TitleSummaryWithId(
+            id=str(self._topic_id),
+            title=data.title,
+            summary=data.summary,
+            timestamp=data.timestamp,
+            duration=data.duration,
+            transcript=data.transcript,
+        )
+        self.topics.append(topic_with_id)
+
+        # Pass through the original topic
+        await self.emit(data)
+
+    def get_topics(self) -> List[TitleSummaryWithId]:
+        return self.topics
+
+
+async def process_audio_file_with_diarization(
+    filename,
+    event_callback,
+    only_transcript=False,
+    source_language="en",
+    target_language="en",
+    enable_diarization=True,
+    diarization_backend="modal",
+):
+    # Create temp file for audio if diarization is enabled
+    audio_temp_path = None
+    if enable_diarization:
+        audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+        audio_temp_path = audio_temp_file.name
+        audio_temp_file.close()
+
+    # Create processor for collecting topics
+    topic_collector = TopicCollectorProcessor()
+
+    # Build pipeline for audio processing
+    processors = []
+
+    # Add audio file writer at the beginning if diarization is enabled
+    if enable_diarization:
+        processors.append(AudioFileWriterProcessor(audio_temp_path))
+
+    # Add the rest of the processors
+    processors += [
+        AudioChunkerProcessor(),
+        AudioMergeProcessor(),
+        AudioTranscriptAutoProcessor.as_threaded(),
+    ]
+
+    processors += [
+        TranscriptLinerProcessor(),
+        TranscriptTranslatorProcessor.as_threaded(),
+    ]
+
+    if not only_transcript:
+        processors += [
+            TranscriptTopicDetectorProcessor.as_threaded(),
+            # Collect topics for diarization
+            topic_collector,
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalTitleProcessor.as_threaded(),
+                    TranscriptFinalSummaryProcessor.as_threaded(),
+                ],
+            ),
+        ]
+
+    # Create main pipeline
+    pipeline = Pipeline(*processors)
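+    # Pipeline stages, in order: audio frames are (optionally) written to the
+    # temp WAV file, chunked and merged, then transcribed; transcript lines
+    # are translated, and unless only_transcript is set, topics are detected,
+    # collected by topic_collector for the diarization pass below, and
+    # broadcast to the final title/summary processors.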
+    pipeline.set_pref("audio:source_language", source_language)
+    pipeline.set_pref("audio:target_language", target_language)
+    pipeline.describe()
+    pipeline.on(event_callback)
+
+    # Start processing audio
+    logger.info(f"Opening {filename}")
+    container = av.open(filename)
+    try:
+        logger.info("Start pushing audio into the pipeline")
+        for frame in container.decode(audio=0):
+            await pipeline.push(frame)
+    finally:
+        logger.info("Flushing the pipeline")
+        await pipeline.flush()
+
+    # Run diarization if enabled and we have topics
+    if enable_diarization and not only_transcript and audio_temp_path:
+        topics = topic_collector.get_topics()
+
+        if topics:
+            logger.info(f"Starting diarization with {len(topics)} topics")
+
+            try:
+                # Import diarization processor
+                from reflector.processors import AudioDiarizationAutoProcessor
+
+                # Create diarization processor
+                diarization_processor = AudioDiarizationAutoProcessor(
+                    name=diarization_backend
+                )
+                diarization_processor.on(event_callback)
+
+                # For the Modal backend, we need to upload the file to S3 first
+                if diarization_backend == "modal":
+                    from datetime import datetime
+
+                    from reflector.storage import get_transcripts_storage
+                    from reflector.utils.s3_temp_file import S3TemporaryFile
+
+                    storage = get_transcripts_storage()
+
+                    # Generate a unique filename in the evaluation folder
+                    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
+                    audio_filename = (
+                        f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
+                    )
+
+                    # Use context manager for automatic cleanup
+                    async with S3TemporaryFile(storage, audio_filename) as s3_file:
+                        # Read and upload the audio file
+                        with open(audio_temp_path, "rb") as f:
+                            audio_data = f.read()
+
+                        audio_url = await s3_file.upload(audio_data)
+                        logger.info(f"Uploaded audio to S3: {audio_filename}")
+
+                        # Create diarization input with S3 URL
+                        diarization_input = AudioDiarizationInput(
+                            audio_url=audio_url, topics=topics
+                        )
+
+                        # Run diarization
+                        await diarization_processor.push(diarization_input)
+                        await diarization_processor.flush()
+
+                        logger.info("Diarization complete")
+                        # File will be automatically cleaned up when exiting the context
+                else:
+                    # For a local backend, use the local file path directly
+                    audio_url = audio_temp_path
+
+                    # Create diarization input
+                    diarization_input = AudioDiarizationInput(
+                        audio_url=audio_url, topics=topics
+                    )
+
+                    # Run diarization
+                    await diarization_processor.push(diarization_input)
+                    await diarization_processor.flush()
+
+                    logger.info("Diarization complete")
+
+            except ImportError as e:
+                logger.error(f"Failed to import diarization dependencies: {e}")
+                logger.error(
+                    "Install with: uv pip install pyannote.audio torch torchaudio"
+                )
+                logger.error(
+                    "And set HF_TOKEN environment variable for pyannote models"
+                )
+                raise SystemExit(1)
+            except Exception as e:
+                logger.error(f"Diarization failed: {e}")
+                raise SystemExit(1)
+        else:
+            logger.warning("Skipping diarization: no topics available")
+
+    # Clean up temp file
+    if audio_temp_path:
+        try:
+            Path(audio_temp_path).unlink()
+        except Exception as e:
+            logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
+
+    logger.info("All done!")
+
+
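+# Example invocation (hypothetical paths; the "modal" backend also needs the
+# transcripts S3 storage credentials configured, since the audio is uploaded
+# there for diarization):
+#
+#   python process_with_diarization.py meeting.mp3 -d -o events.jsonl
+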
+if __name__ == "__main__":
+    import argparse
+    import os
+
+    parser = argparse.ArgumentParser(
+        description="Process audio files with optional speaker diarization"
+    )
+    parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
+    parser.add_argument(
+        "--only-transcript",
+        "-t",
+        action="store_true",
+        help="Only generate transcript without topics/summaries",
+    )
+    parser.add_argument(
+        "--source-language", default="en", help="Source language code (default: en)"
+    )
+    parser.add_argument(
+        "--target-language", default="en", help="Target language code (default: en)"
+    )
+    parser.add_argument(
+        "--output", "-o", help="Output file for JSONL events (e.g. output.jsonl)"
+    )
+    parser.add_argument(
+        "--enable-diarization",
+        "-d",
+        action="store_true",
+        help="Enable speaker diarization",
+    )
+    parser.add_argument(
+        "--diarization-backend",
+        default="modal",
+        choices=["modal"],
+        help="Diarization backend to use (default: modal)",
+    )
+    args = parser.parse_args()
+
+    # Set REDIS_HOST to localhost if not provided
+    if "REDIS_HOST" not in os.environ:
+        os.environ["REDIS_HOST"] = "localhost"
+        logger.info("REDIS_HOST not set, defaulting to localhost")
+
+    output_fd = None
+    if args.output:
+        output_fd = open(args.output, "w")
+
+    async def event_callback(event: PipelineEvent):
+        processor = event.processor
+        data = event.data
+
+        # Ignore internal processors
+        if processor in (
+            "AudioChunkerProcessor",
+            "AudioMergeProcessor",
+            "AudioFileWriterProcessor",
+            "TopicCollectorProcessor",
+            "BroadcastProcessor",
+        ):
+            return
+
+        # If diarization is enabled, skip the original topic events from the
+        # pipeline; the diarization processor will emit the same topics but
+        # with speaker info.
+        if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
+            return
+
+        # Log all events
+        logger.info(f"Event: {processor} - {type(data).__name__}")
+
+        # Write to output
+        if output_fd:
+            output_fd.write(event.model_dump_json())
+            output_fd.write("\n")
+            output_fd.flush()
+
+    asyncio.run(
+        process_audio_file_with_diarization(
+            args.source,
+            event_callback,
+            only_transcript=args.only_transcript,
+            source_language=args.source_language,
+            target_language=args.target_language,
+            enable_diarization=args.enable_diarization,
+            diarization_backend=args.diarization_backend,
+        )
+    )
+
+    if output_fd:
+        output_fd.close()
+        logger.info(f"Output written to {args.output}")
diff --git a/server/reflector/tools/test_diarization.py b/server/reflector/tools/test_diarization.py
new file mode 100644
index 00000000..20383939
--- /dev/null
+++ b/server/reflector/tools/test_diarization.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+"""
+@vibe-generated
+Test script for the diarization CLI tool
+=========================================
+
+This script helps test the diarization functionality with sample audio files.
+"""
+
+import asyncio
+import sys
+from pathlib import Path
+
+from reflector.logger import logger
+
+
+async def test_diarization(audio_file: str):
+    """Test the diarization functionality"""
+
+    # Import the processing function
+    from process_with_diarization import process_audio_file_with_diarization
+
+    # Collect events
+    events = []
+
+    async def event_callback(event):
+        events.append({"processor": event.processor, "data": event.data})
+        logger.info(f"Event from {event.processor}")
+
+    # Process the audio file
+    logger.info(f"Processing audio file: {audio_file}")
+
+    try:
+        await process_audio_file_with_diarization(
+            audio_file,
+            event_callback,
+            only_transcript=False,
+            source_language="en",
+            target_language="en",
+            enable_diarization=True,
+            diarization_backend="modal",
+        )
+
+        # Analyze results
+        logger.info(f"Processing complete. Received {len(events)} events")
+
+        # Look for diarization results
+        diarized_topics = []
+        for event in events:
+            if "TitleSummary" in event["processor"]:
+                # Check if words have speaker information
+                if hasattr(event["data"], "transcript") and event["data"].transcript:
+                    words = event["data"].transcript.words
+                    if words and hasattr(words[0], "speaker"):
+                        speakers = set(
+                            w.speaker for w in words if hasattr(w, "speaker")
+                        )
+                        logger.info(
+                            f"Found {len(speakers)} speakers in topic: "
+                            f"{event['data'].title}"
+                        )
+                        diarized_topics.append(event["data"])
+
+        if diarized_topics:
+            logger.info(f"Successfully diarized {len(diarized_topics)} topics")
+
+            # Print sample output
+            sample_topic = diarized_topics[0]
+            logger.info("Sample diarized output:")
+            for i, word in enumerate(sample_topic.transcript.words[:10]):
+                logger.info(f"  Word {i}: '{word.text}' - Speaker {word.speaker}")
+        else:
+            logger.warning("No diarization results found in output")
+
+        return events
+
+    except Exception as e:
+        logger.error(f"Error during processing: {e}")
+        raise
+
+
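+# Illustrative helper (a sketch, not called by main()): tally words per
+# speaker across diarized topics, using only the attributes the checks in
+# test_diarization() already rely on (transcript.words and an optional
+# `speaker` attribute).
+def summarize_speakers(diarized_topics):
+    counts = {}
+    for topic in diarized_topics:
+        for word in topic.transcript.words:
+            speaker = getattr(word, "speaker", None)
+            counts[speaker] = counts.get(speaker, 0) + 1
+    return counts
+
+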
+def main():
+    if len(sys.argv) < 2:
+        print("Usage: python test_diarization.py <audio_file>")
+        sys.exit(1)
+
+    audio_file = sys.argv[1]
+    if not Path(audio_file).exists():
+        print(f"Error: Audio file '{audio_file}' not found")
+        sys.exit(1)
+
+    # Run the test
+    asyncio.run(test_diarization(audio_file))
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/server/reflector/utils/s3_temp_file.py b/server/reflector/utils/s3_temp_file.py
new file mode 100644
index 00000000..17b12aaa
--- /dev/null
+++ b/server/reflector/utils/s3_temp_file.py
@@ -0,0 +1,149 @@
+"""
+@vibe-generated
+S3 Temporary File Context Manager
+
+Provides automatic cleanup of S3 files with retry logic and proper error handling.
+"""
+
+from typing import Optional
+from reflector.storage.base import Storage
+from reflector.logger import logger
+from reflector.utils.retry import retry
+
+
+class S3TemporaryFile:
+    """
+    Async context manager for temporary S3 files with automatic cleanup.
+
+    Ensures that uploaded files are deleted even if exceptions occur during processing.
+    Uses retry logic for all S3 operations to handle transient failures.
+
+    Example:
+        async with S3TemporaryFile(storage, "temp/audio.wav") as s3_file:
+            url = await s3_file.upload(audio_data)
+            # Use url for processing
+        # File is automatically cleaned up here
+    """
+
+    def __init__(self, storage: Storage, filepath: str):
+        """
+        Initialize the temporary file context.
+
+        Args:
+            storage: Storage instance for S3 operations
+            filepath: S3 key/path for the temporary file
+        """
+        self.storage = storage
+        self.filepath = filepath
+        self.uploaded = False
+        self._url: Optional[str] = None
+
+    async def __aenter__(self):
+        """Enter the context manager."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """
+        Exit the context manager and clean up the file.
+
+        Cleanup is attempted even if an exception occurred during processing.
+        Cleanup failures are logged but don't raise exceptions.
+        """
+        if self.uploaded:
+            try:
+                await self._delete_with_retry()
+                logger.info(f"Successfully cleaned up S3 file: {self.filepath}")
+            except Exception as e:
+                # Log the error but don't raise - we don't want cleanup failures
+                # to mask the original exception
+                logger.warning(
+                    f"Failed to cleanup S3 file {self.filepath} after retries: {e}"
+                )
+        return False  # Don't suppress exceptions
+
+    async def upload(self, data: bytes) -> str:
+        """
+        Upload data to S3 and return the public URL.
+
+        Args:
+            data: File data to upload
+
+        Returns:
+            Public URL for the uploaded file
+
+        Raises:
+            Exception: If upload or URL generation fails after retries
+        """
+        await self._upload_with_retry(data)
+        self.uploaded = True
+        self._url = await self._get_url_with_retry()
+        return self._url
+
+    @property
+    def url(self) -> Optional[str]:
+        """Get the URL of the uploaded file, if available."""
+        return self._url
+
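+    # Note on the call pattern below: the internal reflector.utils.retry
+    # helper wraps a coroutine and returns a callable that accepts the
+    # retry_* keyword arguments (attempt count, overall timeout, and
+    # exponential backoff bounds), as used in each _*_with_retry method.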
+    async def _upload_with_retry(self, data: bytes):
+        """Upload file to S3 with retry logic."""
+
+        async def upload():
+            await self.storage.put_file(self.filepath, data)
+            logger.debug(f"Successfully uploaded file to S3: {self.filepath}")
+            return True  # Return something to indicate success
+
+        await retry(upload)(
+            retry_attempts=3,
+            retry_timeout=30.0,
+            retry_backoff_interval=0.5,
+            retry_backoff_max=5.0,
+        )
+
+    async def _get_url_with_retry(self) -> str:
+        """Get public URL for the file with retry logic."""
+
+        async def get_url():
+            url = await self.storage.get_file_url(self.filepath)
+            logger.debug(f"Generated public URL for S3 file: {self.filepath}")
+            return url
+
+        return await retry(get_url)(
+            retry_attempts=3,
+            retry_timeout=30.0,
+            retry_backoff_interval=0.5,
+            retry_backoff_max=5.0,
+        )
+
+    async def _delete_with_retry(self):
+        """Delete file from S3 with retry logic."""
+
+        async def delete():
+            await self.storage.delete_file(self.filepath)
+            logger.debug(f"Successfully deleted S3 file: {self.filepath}")
+            return True  # Return something to indicate success
+
+        await retry(delete)(
+            retry_attempts=3,
+            retry_timeout=30.0,
+            retry_backoff_interval=0.5,
+            retry_backoff_max=5.0,
+        )
+
+
+# Convenience function for simpler usage
+def temporary_s3_file(storage: Storage, filepath: str) -> S3TemporaryFile:
+    """
+    Create a temporary S3 file context manager.
+
+    This is a convenience wrapper around S3TemporaryFile for simpler usage.
+    It is a plain function (not a coroutine), so the returned object can be
+    used directly with `async with`.
+
+    Args:
+        storage: Storage instance for S3 operations
+        filepath: S3 key/path for the temporary file
+
+    Example:
+        async with temporary_s3_file(storage, "temp/audio.wav") as s3_file:
+            url = await s3_file.upload(audio_data)
+            # Use url for processing
+    """
+    return S3TemporaryFile(storage, filepath)
diff --git a/server/tests/test_s3_temp_file.py b/server/tests/test_s3_temp_file.py
new file mode 100644
index 00000000..b212fbbd
--- /dev/null
+++ b/server/tests/test_s3_temp_file.py
@@ -0,0 +1,129 @@
+"""
+@vibe-generated
+Tests for S3 temporary file context manager.
+"""
+
+import pytest
+from unittest.mock import Mock, AsyncMock
+from reflector.utils.s3_temp_file import S3TemporaryFile
+
+
+@pytest.mark.asyncio
+async def test_successful_upload_and_cleanup():
+    """Test that file is uploaded and cleaned up on success."""
+    # Mock storage
+    mock_storage = Mock()
+    mock_storage.put_file = AsyncMock()
+    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
+    mock_storage.delete_file = AsyncMock()
+
+    # Use context manager
+    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
+        url = await s3_file.upload(b"test data")
+        assert url == "https://example.com/file.wav"
+        assert s3_file.url == "https://example.com/file.wav"
+
+    # Verify operations
+    mock_storage.put_file.assert_called_once_with("test/file.wav", b"test data")
+    mock_storage.get_file_url.assert_called_once_with("test/file.wav")
+    mock_storage.delete_file.assert_called_once_with("test/file.wav")
+
+
+@pytest.mark.asyncio
+async def test_cleanup_on_exception():
+    """Test that cleanup happens even when an exception occurs."""
+    # Mock storage
+    mock_storage = Mock()
+    mock_storage.put_file = AsyncMock()
+    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
+    mock_storage.delete_file = AsyncMock()
+
+    # Use context manager with exception
+    with pytest.raises(ValueError):
+        async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
+            await s3_file.upload(b"test data")
+            raise ValueError("Simulated error during processing")
+
+    # Verify cleanup still happened
+    mock_storage.delete_file.assert_called_once_with("test/file.wav")
+
+
+@pytest.mark.asyncio
+async def test_no_cleanup_if_not_uploaded():
+    """Test that cleanup is skipped if file was never uploaded."""
+    # Mock storage
+    mock_storage = Mock()
+    mock_storage.delete_file = AsyncMock()
+
+    # Use context manager without uploading
+    async with S3TemporaryFile(mock_storage, "test/file.wav"):
+        pass  # Don't upload anything
+
+    # Verify no cleanup attempted
+    mock_storage.delete_file.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_cleanup_failure_is_logged_not_raised():
+    """Test that cleanup failures are logged but don't raise exceptions."""
+    # Mock storage
+    mock_storage = Mock()
+    mock_storage.put_file = AsyncMock()
+    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
+    mock_storage.delete_file = AsyncMock(side_effect=Exception("Delete failed"))
+
+    # Use context manager - should not raise
+    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
+        await s3_file.upload(b"test data")
+
+    # Verify delete was attempted (3 times due to retry)
+    assert mock_storage.delete_file.call_count == 3
+
+
+@pytest.mark.asyncio
+async def test_upload_retry_on_failure():
+    """Test that upload is retried on failure."""
+    # Mock storage with one failure, then success
+    mock_storage = Mock()
+    mock_storage.put_file = AsyncMock(
+        side_effect=[Exception("Network error"), None]  # Fail once, then succeed
+    )
+    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
+    mock_storage.delete_file = AsyncMock()
+
+    # Use context manager
+    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
+        url = await s3_file.upload(b"test data")
+        assert url == "https://example.com/file.wav"
+
+    # Verify upload was retried
+    assert mock_storage.put_file.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_delete_retry_on_failure():
+    """Test that delete is retried on failure."""
+    # Mock storage
+    mock_storage = Mock()
+    mock_storage.put_file = AsyncMock()
+    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
+    mock_storage.delete_file = AsyncMock(
+        side_effect=[Exception("Network error"), None]  # Fail once, then succeed
+    )
+
+    # Use context manager
+    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
+        await s3_file.upload(b"test data")
+
+    # Verify delete was retried
+    assert mock_storage.delete_file.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_properties_before_upload():
+    """Test that properties work correctly before upload."""
+    mock_storage = Mock()
+
+    async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
+        assert s3_file.url is None
+        assert s3_file.uploaded is False
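+
+
+# Added example (a sketch using the same mocked Storage interface as above):
+# the temporary_s3_file() convenience wrapper returns an S3TemporaryFile, so
+# it should behave identically to using the class directly.
+@pytest.mark.asyncio
+async def test_temporary_s3_file_wrapper():
+    """Test that the convenience wrapper behaves like S3TemporaryFile."""
+    from reflector.utils.s3_temp_file import temporary_s3_file
+
+    mock_storage = Mock()
+    mock_storage.put_file = AsyncMock()
+    mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
+    mock_storage.delete_file = AsyncMock()
+
+    async with temporary_s3_file(mock_storage, "test/file.wav") as s3_file:
+        url = await s3_file.upload(b"test data")
+        assert url == "https://example.com/file.wav"
+
+    # Cleanup runs on exit, same as with S3TemporaryFile directly
+    mock_storage.delete_file.assert_called_once_with("test/file.wav")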