feat: Diarization cli (#509)

* diarisation cli

* feat: s3 upload for modal diarisation cli call

* chore: cleanup

* chore: s3 cleanup improvement

* chore: lint

* chore: cleanup

* chore: cleanup

* chore: cleanup

* chore: cleanup
This commit is contained in:
Igor Loskutov
2025-07-25 16:24:06 -04:00
committed by GitHub
parent 2289a1a231
commit 27b43d85ab
5 changed files with 689 additions and 0 deletions

1
server/.gitignore vendored
View File

@@ -180,3 +180,4 @@ reflector.sqlite3
data/
dump.rdb

View File

@@ -0,0 +1,314 @@
"""
@vibe-generated
Process audio file with diarization support
===========================================
Extended version of process.py that includes speaker diarization.
This tool processes audio files locally without requiring the full server infrastructure.
"""
import asyncio
import tempfile
from pathlib import Path
from typing import List
import uuid
import av
from reflector.logger import logger
from reflector.processors import (
AudioChunkerProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
AudioFileWriterProcessor,
Pipeline,
PipelineEvent,
TranscriptFinalSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.base import BroadcastProcessor, Processor
from reflector.processors.types import (
AudioDiarizationInput,
TitleSummary,
TitleSummaryWithId,
)
class TopicCollectorProcessor(Processor):
"""Collect topics for diarization"""
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = TitleSummary
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.topics: List[TitleSummaryWithId] = []
self._topic_id = 0
async def _push(self, data: TitleSummary):
# Convert to TitleSummaryWithId and collect
self._topic_id += 1
topic_with_id = TitleSummaryWithId(
id=str(self._topic_id),
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
duration=data.duration,
transcript=data.transcript,
)
self.topics.append(topic_with_id)
# Pass through the original topic
await self.emit(data)
def get_topics(self) -> List[TitleSummaryWithId]:
return self.topics
async def process_audio_file_with_diarization(
filename,
event_callback,
only_transcript=False,
source_language="en",
target_language="en",
enable_diarization=True,
diarization_backend="modal",
):
# Create temp file for audio if diarization is enabled
audio_temp_path = None
if enable_diarization:
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
audio_temp_path = audio_temp_file.name
audio_temp_file.close()
# Create processor for collecting topics
topic_collector = TopicCollectorProcessor()
# Build pipeline for audio processing
processors = []
# Add audio file writer at the beginning if diarization is enabled
if enable_diarization:
processors.append(AudioFileWriterProcessor(audio_temp_path))
# Add the rest of the processors
processors += [
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
]
processors += [
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(),
]
if not only_transcript:
processors += [
TranscriptTopicDetectorProcessor.as_threaded(),
# Collect topics for diarization
topic_collector,
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(),
TranscriptFinalSummaryProcessor.as_threaded(),
],
),
]
# Create main pipeline
pipeline = Pipeline(*processors)
pipeline.set_pref("audio:source_language", source_language)
pipeline.set_pref("audio:target_language", target_language)
pipeline.describe()
pipeline.on(event_callback)
# Start processing audio
logger.info(f"Opening {filename}")
container = av.open(filename)
try:
logger.info("Start pushing audio into the pipeline")
for frame in container.decode(audio=0):
await pipeline.push(frame)
finally:
logger.info("Flushing the pipeline")
await pipeline.flush()
# Run diarization if enabled and we have topics
if enable_diarization and not only_transcript and audio_temp_path:
topics = topic_collector.get_topics()
if topics:
logger.info(f"Starting diarization with {len(topics)} topics")
try:
# Import diarization processor
from reflector.processors import AudioDiarizationAutoProcessor
# Create diarization processor
diarization_processor = AudioDiarizationAutoProcessor(
name=diarization_backend
)
diarization_processor.on(event_callback)
# For Modal backend, we need to upload the file to S3 first
if diarization_backend == "modal":
from reflector.storage import get_transcripts_storage
from reflector.utils.s3_temp_file import S3TemporaryFile
from datetime import datetime
storage = get_transcripts_storage()
# Generate a unique filename in evaluation folder
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
# Use context manager for automatic cleanup
async with S3TemporaryFile(storage, audio_filename) as s3_file:
# Read and upload the audio file
with open(audio_temp_path, "rb") as f:
audio_data = f.read()
audio_url = await s3_file.upload(audio_data)
logger.info(f"Uploaded audio to S3: {audio_filename}")
# Create diarization input with S3 URL
diarization_input = AudioDiarizationInput(
audio_url=audio_url, topics=topics
)
# Run diarization
await diarization_processor.push(diarization_input)
await diarization_processor.flush()
logger.info("Diarization complete")
# File will be automatically cleaned up when exiting the context
else:
# For local backend, use local file path
audio_url = audio_temp_path
# Create diarization input
diarization_input = AudioDiarizationInput(
audio_url=audio_url, topics=topics
)
# Run diarization
await diarization_processor.push(diarization_input)
await diarization_processor.flush()
logger.info("Diarization complete")
except ImportError as e:
logger.error(f"Failed to import diarization dependencies: {e}")
logger.error(
"Install with: uv pip install pyannote.audio torch torchaudio"
)
logger.error(
"And set HF_TOKEN environment variable for pyannote models"
)
raise SystemExit(1)
except Exception as e:
logger.error(f"Diarization failed: {e}")
raise SystemExit(1)
else:
logger.warning("Skipping diarization: no topics available")
# Clean up temp file
if audio_temp_path:
try:
Path(audio_temp_path).unlink()
except Exception as e:
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
logger.info("All done!")
if __name__ == "__main__":
import argparse
import os
parser = argparse.ArgumentParser(
description="Process audio files with optional speaker diarization"
)
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
parser.add_argument(
"--only-transcript",
"-t",
action="store_true",
help="Only generate transcript without topics/summaries",
)
parser.add_argument(
"--source-language", default="en", help="Source language code (default: en)"
)
parser.add_argument(
"--target-language", default="en", help="Target language code (default: en)"
)
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
parser.add_argument(
"--enable-diarization",
"-d",
action="store_true",
help="Enable speaker diarization",
)
parser.add_argument(
"--diarization-backend",
default="modal",
choices=["modal"],
help="Diarization backend to use (default: modal)",
)
args = parser.parse_args()
# Set REDIS_HOST to localhost if not provided
if "REDIS_HOST" not in os.environ:
os.environ["REDIS_HOST"] = "localhost"
logger.info("REDIS_HOST not set, defaulting to localhost")
output_fd = None
if args.output:
output_fd = open(args.output, "w")
async def event_callback(event: PipelineEvent):
processor = event.processor
data = event.data
# Ignore internal processors
if processor in (
"AudioChunkerProcessor",
"AudioMergeProcessor",
"AudioFileWriterProcessor",
"TopicCollectorProcessor",
"BroadcastProcessor",
):
return
# If diarization is enabled, skip the original topic events from the pipeline
# The diarization processor will emit the same topics but with speaker info
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
return
# Log all events
logger.info(f"Event: {processor} - {type(data).__name__}")
# Write to output
if output_fd:
output_fd.write(event.model_dump_json())
output_fd.write("\n")
output_fd.flush()
asyncio.run(
process_audio_file_with_diarization(
args.source,
event_callback,
only_transcript=args.only_transcript,
source_language=args.source_language,
target_language=args.target_language,
enable_diarization=args.enable_diarization,
diarization_backend=args.diarization_backend,
)
)
if output_fd:
output_fd.close()
logger.info(f"Output written to {args.output}")

View File

@@ -0,0 +1,96 @@
#!/usr/bin/env python3
"""
@vibe-generated
Test script for the diarization CLI tool
=========================================
This script helps test the diarization functionality with sample audio files.
"""
import asyncio
import json
import sys
from pathlib import Path
from reflector.logger import logger
async def test_diarization(audio_file: str):
"""Test the diarization functionality"""
# Import the processing function
from process_with_diarization import process_audio_file_with_diarization
# Collect events
events = []
async def event_callback(event):
events.append({
"processor": event.processor,
"data": event.data
})
logger.info(f"Event from {event.processor}")
# Process the audio file
logger.info(f"Processing audio file: {audio_file}")
try:
await process_audio_file_with_diarization(
audio_file,
event_callback,
only_transcript=False,
source_language="en",
target_language="en",
enable_diarization=True,
diarization_backend="modal",
)
# Analyze results
logger.info(f"Processing complete. Received {len(events)} events")
# Look for diarization results
diarized_topics = []
for event in events:
if "TitleSummary" in event["processor"]:
# Check if words have speaker information
if hasattr(event["data"], "transcript") and event["data"].transcript:
words = event["data"].transcript.words
if words and hasattr(words[0], "speaker"):
speakers = set(w.speaker for w in words if hasattr(w, "speaker"))
logger.info(f"Found {len(speakers)} speakers in topic: {event['data'].title}")
diarized_topics.append(event["data"])
if diarized_topics:
logger.info(f"Successfully diarized {len(diarized_topics)} topics")
# Print sample output
sample_topic = diarized_topics[0]
logger.info("Sample diarized output:")
for i, word in enumerate(sample_topic.transcript.words[:10]):
logger.info(f" Word {i}: '{word.text}' - Speaker {word.speaker}")
else:
logger.warning("No diarization results found in output")
return events
except Exception as e:
logger.error(f"Error during processing: {e}")
raise
def main():
if len(sys.argv) < 2:
print("Usage: python test_diarization.py <audio_file>")
sys.exit(1)
audio_file = sys.argv[1]
if not Path(audio_file).exists():
print(f"Error: Audio file '{audio_file}' not found")
sys.exit(1)
# Run the test
asyncio.run(test_diarization(audio_file))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,149 @@
"""
@vibe-generated
S3 Temporary File Context Manager
Provides automatic cleanup of S3 files with retry logic and proper error handling.
"""
from typing import Optional
from reflector.storage.base import Storage
from reflector.logger import logger
from reflector.utils.retry import retry
class S3TemporaryFile:
"""
Async context manager for temporary S3 files with automatic cleanup.
Ensures that uploaded files are deleted even if exceptions occur during processing.
Uses retry logic for all S3 operations to handle transient failures.
Example:
async with S3TemporaryFile(storage, "temp/audio.wav") as s3_file:
url = await s3_file.upload(audio_data)
# Use url for processing
# File is automatically cleaned up here
"""
def __init__(self, storage: Storage, filepath: str):
"""
Initialize the temporary file context.
Args:
storage: Storage instance for S3 operations
filepath: S3 key/path for the temporary file
"""
self.storage = storage
self.filepath = filepath
self.uploaded = False
self._url: Optional[str] = None
async def __aenter__(self):
"""Enter the context manager."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""
Exit the context manager and clean up the file.
Cleanup is attempted even if an exception occurred during processing.
Cleanup failures are logged but don't raise exceptions.
"""
if self.uploaded:
try:
await self._delete_with_retry()
logger.info(f"Successfully cleaned up S3 file: {self.filepath}")
except Exception as e:
# Log the error but don't raise - we don't want cleanup failures
# to mask the original exception
logger.warning(
f"Failed to cleanup S3 file {self.filepath} after retries: {e}"
)
return False # Don't suppress exceptions
async def upload(self, data: bytes) -> str:
"""
Upload data to S3 and return the public URL.
Args:
data: File data to upload
Returns:
Public URL for the uploaded file
Raises:
Exception: If upload or URL generation fails after retries
"""
await self._upload_with_retry(data)
self.uploaded = True
self._url = await self._get_url_with_retry()
return self._url
@property
def url(self) -> Optional[str]:
"""Get the URL of the uploaded file, if available."""
return self._url
async def _upload_with_retry(self, data: bytes):
"""Upload file to S3 with retry logic."""
async def upload():
await self.storage.put_file(self.filepath, data)
logger.debug(f"Successfully uploaded file to S3: {self.filepath}")
return True # Return something to indicate success
await retry(upload)(
retry_attempts=3,
retry_timeout=30.0,
retry_backoff_interval=0.5,
retry_backoff_max=5.0,
)
async def _get_url_with_retry(self) -> str:
"""Get public URL for the file with retry logic."""
async def get_url():
url = await self.storage.get_file_url(self.filepath)
logger.debug(f"Generated public URL for S3 file: {self.filepath}")
return url
return await retry(get_url)(
retry_attempts=3,
retry_timeout=30.0,
retry_backoff_interval=0.5,
retry_backoff_max=5.0,
)
async def _delete_with_retry(self):
"""Delete file from S3 with retry logic."""
async def delete():
await self.storage.delete_file(self.filepath)
logger.debug(f"Successfully deleted S3 file: {self.filepath}")
return True # Return something to indicate success
await retry(delete)(
retry_attempts=3,
retry_timeout=30.0,
retry_backoff_interval=0.5,
retry_backoff_max=5.0,
)
# Convenience function for simpler usage
async def temporary_s3_file(storage: Storage, filepath: str):
"""
Create a temporary S3 file context manager.
This is a convenience wrapper around S3TemporaryFile for simpler usage.
Args:
storage: Storage instance for S3 operations
filepath: S3 key/path for the temporary file
Example:
async with temporary_s3_file(storage, "temp/audio.wav") as s3_file:
url = await s3_file.upload(audio_data)
# Use url for processing
"""
return S3TemporaryFile(storage, filepath)

View File

@@ -0,0 +1,129 @@
"""
@vibe-generated
Tests for S3 temporary file context manager.
"""
import pytest
from unittest.mock import Mock, AsyncMock
from reflector.utils.s3_temp_file import S3TemporaryFile
@pytest.mark.asyncio
async def test_successful_upload_and_cleanup():
"""Test that file is uploaded and cleaned up on success."""
# Mock storage
mock_storage = Mock()
mock_storage.put_file = AsyncMock()
mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
mock_storage.delete_file = AsyncMock()
# Use context manager
async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
url = await s3_file.upload(b"test data")
assert url == "https://example.com/file.wav"
assert s3_file.url == "https://example.com/file.wav"
# Verify operations
mock_storage.put_file.assert_called_once_with("test/file.wav", b"test data")
mock_storage.get_file_url.assert_called_once_with("test/file.wav")
mock_storage.delete_file.assert_called_once_with("test/file.wav")
@pytest.mark.asyncio
async def test_cleanup_on_exception():
"""Test that cleanup happens even when an exception occurs."""
# Mock storage
mock_storage = Mock()
mock_storage.put_file = AsyncMock()
mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
mock_storage.delete_file = AsyncMock()
# Use context manager with exception
with pytest.raises(ValueError):
async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
await s3_file.upload(b"test data")
raise ValueError("Simulated error during processing")
# Verify cleanup still happened
mock_storage.delete_file.assert_called_once_with("test/file.wav")
@pytest.mark.asyncio
async def test_no_cleanup_if_not_uploaded():
"""Test that cleanup is skipped if file was never uploaded."""
# Mock storage
mock_storage = Mock()
mock_storage.delete_file = AsyncMock()
# Use context manager without uploading
async with S3TemporaryFile(mock_storage, "test/file.wav"):
pass # Don't upload anything
# Verify no cleanup attempted
mock_storage.delete_file.assert_not_called()
@pytest.mark.asyncio
async def test_cleanup_failure_is_logged_not_raised():
"""Test that cleanup failures are logged but don't raise exceptions."""
# Mock storage
mock_storage = Mock()
mock_storage.put_file = AsyncMock()
mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
mock_storage.delete_file = AsyncMock(side_effect=Exception("Delete failed"))
# Use context manager - should not raise
async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
await s3_file.upload(b"test data")
# Verify delete was attempted (3 times due to retry)
assert mock_storage.delete_file.call_count == 3
@pytest.mark.asyncio
async def test_upload_retry_on_failure():
"""Test that upload is retried on failure."""
# Mock storage with failures then success
mock_storage = Mock()
mock_storage.put_file = AsyncMock(
side_effect=[Exception("Network error"), None] # Fail once, then succeed
)
mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
mock_storage.delete_file = AsyncMock()
# Use context manager
async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
url = await s3_file.upload(b"test data")
assert url == "https://example.com/file.wav"
# Verify upload was retried
assert mock_storage.put_file.call_count == 2
@pytest.mark.asyncio
async def test_delete_retry_on_failure():
"""Test that delete is retried on failure."""
# Mock storage
mock_storage = Mock()
mock_storage.put_file = AsyncMock()
mock_storage.get_file_url = AsyncMock(return_value="https://example.com/file.wav")
mock_storage.delete_file = AsyncMock(
side_effect=[Exception("Network error"), None] # Fail once, then succeed
)
# Use context manager
async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
await s3_file.upload(b"test data")
# Verify delete was retried
assert mock_storage.delete_file.call_count == 2
@pytest.mark.asyncio
async def test_properties_before_upload():
"""Test that properties work correctly before upload."""
mock_storage = Mock()
async with S3TemporaryFile(mock_storage, "test/file.wav") as s3_file:
assert s3_file.url is None
assert s3_file.uploaded is False