feat: Diarization cli (#509)

* diarisation cli

* feat: s3 upload for modal diarisation cli call

* chore: cleanup

* chore: s3 cleanup improvement

* chore: lint

* chore: cleanup

* chore: cleanup

* chore: cleanup

* chore: cleanup
This commit is contained in:
Igor Loskutov
2025-07-25 16:24:06 -04:00
committed by GitHub
parent 2289a1a231
commit 27b43d85ab
5 changed files with 689 additions and 0 deletions

View File

@@ -0,0 +1,314 @@
"""
@vibe-generated
Process audio file with diarization support
===========================================
Extended version of process.py that includes speaker diarization.
This tool processes audio files locally without requiring the full server infrastructure.
"""
import asyncio
import tempfile
from pathlib import Path
from typing import List
import uuid
import av
from reflector.logger import logger
from reflector.processors import (
AudioChunkerProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
AudioFileWriterProcessor,
Pipeline,
PipelineEvent,
TranscriptFinalSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.base import BroadcastProcessor, Processor
from reflector.processors.types import (
AudioDiarizationInput,
TitleSummary,
TitleSummaryWithId,
)
class TopicCollectorProcessor(Processor):
"""Collect topics for diarization"""
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = TitleSummary
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.topics: List[TitleSummaryWithId] = []
self._topic_id = 0
async def _push(self, data: TitleSummary):
# Convert to TitleSummaryWithId and collect
self._topic_id += 1
topic_with_id = TitleSummaryWithId(
id=str(self._topic_id),
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
duration=data.duration,
transcript=data.transcript,
)
self.topics.append(topic_with_id)
# Pass through the original topic
await self.emit(data)
def get_topics(self) -> List[TitleSummaryWithId]:
return self.topics
async def process_audio_file_with_diarization(
filename,
event_callback,
only_transcript=False,
source_language="en",
target_language="en",
enable_diarization=True,
diarization_backend="modal",
):
# Create temp file for audio if diarization is enabled
audio_temp_path = None
if enable_diarization:
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
audio_temp_path = audio_temp_file.name
audio_temp_file.close()
# Create processor for collecting topics
topic_collector = TopicCollectorProcessor()
# Build pipeline for audio processing
processors = []
# Add audio file writer at the beginning if diarization is enabled
if enable_diarization:
processors.append(AudioFileWriterProcessor(audio_temp_path))
# Add the rest of the processors
processors += [
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
]
processors += [
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(),
]
if not only_transcript:
processors += [
TranscriptTopicDetectorProcessor.as_threaded(),
# Collect topics for diarization
topic_collector,
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(),
TranscriptFinalSummaryProcessor.as_threaded(),
],
),
]
# Create main pipeline
pipeline = Pipeline(*processors)
pipeline.set_pref("audio:source_language", source_language)
pipeline.set_pref("audio:target_language", target_language)
pipeline.describe()
pipeline.on(event_callback)
# Start processing audio
logger.info(f"Opening {filename}")
container = av.open(filename)
try:
logger.info("Start pushing audio into the pipeline")
for frame in container.decode(audio=0):
await pipeline.push(frame)
finally:
logger.info("Flushing the pipeline")
await pipeline.flush()
# Run diarization if enabled and we have topics
if enable_diarization and not only_transcript and audio_temp_path:
topics = topic_collector.get_topics()
if topics:
logger.info(f"Starting diarization with {len(topics)} topics")
try:
# Import diarization processor
from reflector.processors import AudioDiarizationAutoProcessor
# Create diarization processor
diarization_processor = AudioDiarizationAutoProcessor(
name=diarization_backend
)
diarization_processor.on(event_callback)
# For Modal backend, we need to upload the file to S3 first
if diarization_backend == "modal":
from reflector.storage import get_transcripts_storage
from reflector.utils.s3_temp_file import S3TemporaryFile
from datetime import datetime
storage = get_transcripts_storage()
# Generate a unique filename in evaluation folder
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
# Use context manager for automatic cleanup
async with S3TemporaryFile(storage, audio_filename) as s3_file:
# Read and upload the audio file
with open(audio_temp_path, "rb") as f:
audio_data = f.read()
audio_url = await s3_file.upload(audio_data)
logger.info(f"Uploaded audio to S3: {audio_filename}")
# Create diarization input with S3 URL
diarization_input = AudioDiarizationInput(
audio_url=audio_url, topics=topics
)
# Run diarization
await diarization_processor.push(diarization_input)
await diarization_processor.flush()
logger.info("Diarization complete")
# File will be automatically cleaned up when exiting the context
else:
# For local backend, use local file path
audio_url = audio_temp_path
# Create diarization input
diarization_input = AudioDiarizationInput(
audio_url=audio_url, topics=topics
)
# Run diarization
await diarization_processor.push(diarization_input)
await diarization_processor.flush()
logger.info("Diarization complete")
except ImportError as e:
logger.error(f"Failed to import diarization dependencies: {e}")
logger.error(
"Install with: uv pip install pyannote.audio torch torchaudio"
)
logger.error(
"And set HF_TOKEN environment variable for pyannote models"
)
raise SystemExit(1)
except Exception as e:
logger.error(f"Diarization failed: {e}")
raise SystemExit(1)
else:
logger.warning("Skipping diarization: no topics available")
# Clean up temp file
if audio_temp_path:
try:
Path(audio_temp_path).unlink()
except Exception as e:
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
logger.info("All done!")
if __name__ == "__main__":
import argparse
import os
parser = argparse.ArgumentParser(
description="Process audio files with optional speaker diarization"
)
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
parser.add_argument(
"--only-transcript",
"-t",
action="store_true",
help="Only generate transcript without topics/summaries",
)
parser.add_argument(
"--source-language", default="en", help="Source language code (default: en)"
)
parser.add_argument(
"--target-language", default="en", help="Target language code (default: en)"
)
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
parser.add_argument(
"--enable-diarization",
"-d",
action="store_true",
help="Enable speaker diarization",
)
parser.add_argument(
"--diarization-backend",
default="modal",
choices=["modal"],
help="Diarization backend to use (default: modal)",
)
args = parser.parse_args()
# Set REDIS_HOST to localhost if not provided
if "REDIS_HOST" not in os.environ:
os.environ["REDIS_HOST"] = "localhost"
logger.info("REDIS_HOST not set, defaulting to localhost")
output_fd = None
if args.output:
output_fd = open(args.output, "w")
async def event_callback(event: PipelineEvent):
processor = event.processor
data = event.data
# Ignore internal processors
if processor in (
"AudioChunkerProcessor",
"AudioMergeProcessor",
"AudioFileWriterProcessor",
"TopicCollectorProcessor",
"BroadcastProcessor",
):
return
# If diarization is enabled, skip the original topic events from the pipeline
# The diarization processor will emit the same topics but with speaker info
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
return
# Log all events
logger.info(f"Event: {processor} - {type(data).__name__}")
# Write to output
if output_fd:
output_fd.write(event.model_dump_json())
output_fd.write("\n")
output_fd.flush()
asyncio.run(
process_audio_file_with_diarization(
args.source,
event_callback,
only_transcript=args.only_transcript,
source_language=args.source_language,
target_language=args.target_language,
enable_diarization=args.enable_diarization,
diarization_backend=args.diarization_backend,
)
)
if output_fd:
output_fd.close()
logger.info(f"Output written to {args.output}")

View File

@@ -0,0 +1,96 @@
#!/usr/bin/env python3
"""
@vibe-generated
Test script for the diarization CLI tool
=========================================
This script helps test the diarization functionality with sample audio files.
"""
import asyncio
import json
import sys
from pathlib import Path
from reflector.logger import logger
async def test_diarization(audio_file: str):
"""Test the diarization functionality"""
# Import the processing function
from process_with_diarization import process_audio_file_with_diarization
# Collect events
events = []
async def event_callback(event):
events.append({
"processor": event.processor,
"data": event.data
})
logger.info(f"Event from {event.processor}")
# Process the audio file
logger.info(f"Processing audio file: {audio_file}")
try:
await process_audio_file_with_diarization(
audio_file,
event_callback,
only_transcript=False,
source_language="en",
target_language="en",
enable_diarization=True,
diarization_backend="modal",
)
# Analyze results
logger.info(f"Processing complete. Received {len(events)} events")
# Look for diarization results
diarized_topics = []
for event in events:
if "TitleSummary" in event["processor"]:
# Check if words have speaker information
if hasattr(event["data"], "transcript") and event["data"].transcript:
words = event["data"].transcript.words
if words and hasattr(words[0], "speaker"):
speakers = set(w.speaker for w in words if hasattr(w, "speaker"))
logger.info(f"Found {len(speakers)} speakers in topic: {event['data'].title}")
diarized_topics.append(event["data"])
if diarized_topics:
logger.info(f"Successfully diarized {len(diarized_topics)} topics")
# Print sample output
sample_topic = diarized_topics[0]
logger.info("Sample diarized output:")
for i, word in enumerate(sample_topic.transcript.words[:10]):
logger.info(f" Word {i}: '{word.text}' - Speaker {word.speaker}")
else:
logger.warning("No diarization results found in output")
return events
except Exception as e:
logger.error(f"Error during processing: {e}")
raise
def main():
if len(sys.argv) < 2:
print("Usage: python test_diarization.py <audio_file>")
sys.exit(1)
audio_file = sys.argv[1]
if not Path(audio_file).exists():
print(f"Error: Audio file '{audio_file}' not found")
sys.exit(1)
# Run the test
asyncio.run(test_diarization(audio_file))
if __name__ == "__main__":
main()