mirror of https://github.com/Monadical-SAS/reflector.git
fix: Igor/evaluation (#575)
* fix: impossible import error (#563)
* evaluation cli - database events experiment
* hallucinations
* evaluation - unhallucinate
* evaluation - unhallucinate
* roll back reliability link
* self review
* lint
* self review
* add file pipeline to cli
* add file pipeline to cli + sorting
* remove cli tests
* remove ai comments
* comments
@@ -794,7 +794,7 @@ def pipeline_post(*, transcript_id: str):
         chain_final_summaries,
     ) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
 
-    chain.delay()
+    return chain.delay()
 
 
 @get_transcript
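The only change here is that `pipeline_post` now returns the Celery chain's `AsyncResult` instead of discarding it. That return value is what lets the new evaluation CLI poll the task until it finishes. A minimal sketch of the calling pattern, mirroring `process_live_pipeline` later in this commit:

    import time

    result = pipeline_post(transcript_id=transcript_id)  # now returns chain.delay()
    while not result.ready():  # polls the Celery result backend
        print(f"Status: {result.state}")
        time.sleep(2)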
@@ -67,6 +67,9 @@ class FileTranscriptModalProcessor(FileTranscriptProcessor):
             for word_info in result.get("words", [])
         ]
 
+        # words come not in order
+        words.sort(key=lambda w: w.start)
+
         return Transcript(words=words)
 
 
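The added sort compensates for the backend returning word timings out of chronological order, so downstream consumers of `Transcript` can rely on ordering. A self-contained illustration; the `Word` shape below is a simplified stand-in for reflector's actual word type:

    from dataclasses import dataclass

    @dataclass
    class Word:  # simplified stand-in, assumed shape
        text: str
        start: float  # seconds from the start of the audio
        end: float

    # the backend may emit words out of order, e.g. grouped per chunk
    words = [Word("world", 0.6, 1.0), Word("hello", 0.0, 0.5)]
    words.sort(key=lambda w: w.start)  # the same fix as in the diff
    assert [w.text for w in words] == ["hello", "world"]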
@@ -1,294 +1,204 @@
 """
 Process audio file with diarization support
-===========================================
-
-Extended version of process.py that includes speaker diarization.
-This tool processes audio files locally without requiring the full server infrastructure.
 """
 
+import argparse
 import asyncio
-import tempfile
-import uuid
+import json
+import shutil
+import sys
+import time
 from pathlib import Path
-from typing import List
+from typing import Any, Dict, List, Literal
 
-import av
-
+from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
 from reflector.logger import logger
-from reflector.processors import (
-    AudioChunkerAutoProcessor,
-    AudioDownscaleProcessor,
-    AudioFileWriterProcessor,
-    AudioMergeProcessor,
-    AudioTranscriptAutoProcessor,
-    Pipeline,
-    PipelineEvent,
-    TranscriptFinalSummaryProcessor,
-    TranscriptFinalTitleProcessor,
-    TranscriptLinerProcessor,
-    TranscriptTopicDetectorProcessor,
-    TranscriptTranslatorAutoProcessor,
-)
-from reflector.processors.base import BroadcastProcessor, Processor
-from reflector.processors.types import (
-    AudioDiarizationInput,
-    TitleSummary,
-    TitleSummaryWithId,
-)
-
-
-class TopicCollectorProcessor(Processor):
-    """Collect topics for diarization"""
-
-    INPUT_TYPE = TitleSummary
-    OUTPUT_TYPE = TitleSummary
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.topics: List[TitleSummaryWithId] = []
-        self._topic_id = 0
-
-    async def _push(self, data: TitleSummary):
-        # Convert to TitleSummaryWithId and collect
-        self._topic_id += 1
-        topic_with_id = TitleSummaryWithId(
-            id=str(self._topic_id),
-            title=data.title,
-            summary=data.summary,
-            timestamp=data.timestamp,
-            duration=data.duration,
-            transcript=data.transcript,
-        )
-        self.topics.append(topic_with_id)
-
-        # Pass through the original topic
-        await self.emit(data)
-
-    def get_topics(self) -> List[TitleSummaryWithId]:
-        return self.topics
-
-
-async def process_audio_file(
-    filename,
-    event_callback,
-    only_transcript=False,
-    source_language="en",
-    target_language="en",
-    enable_diarization=True,
-    diarization_backend="pyannote",
-):
-    # Create temp file for audio if diarization is enabled
-    audio_temp_path = None
-    if enable_diarization:
-        audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
-        audio_temp_path = audio_temp_file.name
-        audio_temp_file.close()
-
-    # Create processor for collecting topics
-    topic_collector = TopicCollectorProcessor()
-
-    # Build pipeline for audio processing
-    processors = []
-
-    # Add audio file writer at the beginning if diarization is enabled
-    if enable_diarization:
-        processors.append(AudioFileWriterProcessor(audio_temp_path))
-
-    # Add the rest of the processors
-    processors += [
-        AudioDownscaleProcessor(),
-        AudioChunkerAutoProcessor(),
-        AudioMergeProcessor(),
-        AudioTranscriptAutoProcessor.as_threaded(),
-        TranscriptLinerProcessor(),
-        TranscriptTranslatorAutoProcessor.as_threaded(),
-    ]
-
-    if not only_transcript:
-        processors += [
-            TranscriptTopicDetectorProcessor.as_threaded(),
-            # Collect topics for diarization
-            topic_collector,
-            BroadcastProcessor(
-                processors=[
-                    TranscriptFinalTitleProcessor.as_threaded(),
-                    TranscriptFinalSummaryProcessor.as_threaded(),
-                ],
-            ),
-        ]
-
-    # Create main pipeline
-    pipeline = Pipeline(*processors)
-    pipeline.set_pref("audio:source_language", source_language)
-    pipeline.set_pref("audio:target_language", target_language)
-    pipeline.describe()
-    pipeline.on(event_callback)
-
-    # Start processing audio
-    logger.info(f"Opening {filename}")
-    container = av.open(filename)
-    try:
-        logger.info("Start pushing audio into the pipeline")
-        for frame in container.decode(audio=0):
-            await pipeline.push(frame)
-    finally:
-        logger.info("Flushing the pipeline")
-        await pipeline.flush()
-
-    # Run diarization if enabled and we have topics
-    if enable_diarization and not only_transcript and audio_temp_path:
-        topics = topic_collector.get_topics()
-
-        if topics:
-            logger.info(f"Starting diarization with {len(topics)} topics")
-
-            try:
-                from reflector.processors import AudioDiarizationAutoProcessor
-
-                diarization_processor = AudioDiarizationAutoProcessor(
-                    name=diarization_backend
-                )
-                diarization_processor.set_pipeline(pipeline)
-
-                # For Modal backend, we need to upload the file to S3 first
-                if diarization_backend == "modal":
-                    from datetime import datetime
-
-                    from reflector.storage import get_transcripts_storage
-                    from reflector.utils.s3_temp_file import S3TemporaryFile
-
-                    storage = get_transcripts_storage()
-
-                    # Generate a unique filename in evaluation folder
-                    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
-                    audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
-
-                    # Use context manager for automatic cleanup
-                    async with S3TemporaryFile(storage, audio_filename) as s3_file:
-                        # Read and upload the audio file
-                        with open(audio_temp_path, "rb") as f:
-                            audio_data = f.read()
-
-                        audio_url = await s3_file.upload(audio_data)
-                        logger.info(f"Uploaded audio to S3: {audio_filename}")
-
-                        # Create diarization input with S3 URL
-                        diarization_input = AudioDiarizationInput(
-                            audio_url=audio_url, topics=topics
-                        )
-
-                        # Run diarization
-                        await diarization_processor.push(diarization_input)
-                        await diarization_processor.flush()
-
-                        logger.info("Diarization complete")
-                        # File will be automatically cleaned up when exiting the context
-                else:
-                    # For local backend, use local file path
-                    audio_url = audio_temp_path
-
-                    # Create diarization input
-                    diarization_input = AudioDiarizationInput(
-                        audio_url=audio_url, topics=topics
-                    )
-
-                    # Run diarization
-                    await diarization_processor.push(diarization_input)
-                    await diarization_processor.flush()
-
-                    logger.info("Diarization complete")
-
-            except ImportError as e:
-                logger.error(f"Failed to import diarization dependencies: {e}")
-                logger.error(
-                    "Install with: uv pip install pyannote.audio torch torchaudio"
-                )
-                logger.error(
-                    "And set HF_TOKEN environment variable for pyannote models"
-                )
-                raise SystemExit(1)
-            except Exception as e:
-                logger.error(f"Diarization failed: {e}")
-                raise SystemExit(1)
-        else:
-            logger.warning("Skipping diarization: no topics available")
-
-    # Clean up temp file
-    if audio_temp_path:
-        try:
-            Path(audio_temp_path).unlink()
-        except Exception as e:
-            logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
-
-    logger.info("All done!")
-
-
-async def process_file_pipeline(
-    filename: str,
-    event_callback,
-    source_language="en",
-    target_language="en",
-    enable_diarization=True,
-    diarization_backend="modal",
-):
-    """Process audio/video file using the optimized file pipeline"""
-    try:
-        from reflector.db import database
-        from reflector.db.transcripts import SourceKind, transcripts_controller
-        from reflector.pipelines.main_file_pipeline import PipelineMainFile
-
-        await database.connect()
-        try:
-            # Create a temporary transcript for processing
-            transcript = await transcripts_controller.add(
-                "",
-                source_kind=SourceKind.FILE,
-                source_language=source_language,
-                target_language=target_language,
-            )
-
-            # Process the file
-            pipeline = PipelineMainFile(transcript_id=transcript.id)
-            await pipeline.process(Path(filename))
-
-            logger.info("File pipeline processing complete")
-        finally:
-            await database.disconnect()
-    except ImportError as e:
-        logger.error(f"File pipeline not available: {e}")
-        logger.info("Falling back to stream pipeline")
-        # Fall back to stream pipeline
-        await process_audio_file(
-            filename,
-            event_callback,
-            only_transcript=False,
-            source_language=source_language,
-            target_language=target_language,
-            enable_diarization=enable_diarization,
-            diarization_backend=diarization_backend,
-        )
+def serialize_topics(topics: List[TranscriptTopic]) -> List[Dict[str, Any]]:
+    """Convert TranscriptTopic objects to JSON-serializable dicts"""
+    serialized = []
+    for topic in topics:
+        topic_dict = topic.model_dump()
+        serialized.append(topic_dict)
+    return serialized
+
+
+def debug_print_speakers(serialized_topics: List[Dict[str, Any]]) -> None:
+    """Print debug info about speakers found in topics"""
+    all_speakers = set()
+    for topic_dict in serialized_topics:
+        for word in topic_dict.get("words", []):
+            all_speakers.add(word.get("speaker", 0))
+    print(
+        f"Found {len(serialized_topics)} topics with speakers: {all_speakers}",
+        file=sys.stderr,
+    )
+
+
+TranscriptId = str
+
+
+# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
+# ideally we want to get rid of it at some point
+async def prepare_entry(
+    source_path: str,
+    source_language: str,
+    target_language: str,
+) -> TranscriptId:
+    file_path = Path(source_path)
+    transcript = await transcripts_controller.add(
+        file_path.name,
+        # note that the real file upload has SourceKind: LIVE for the reason of it's an error
+        source_kind=SourceKind.FILE,
+        source_language=source_language,
+        target_language=target_language,
+        user_id=None,
+    )
+    logger.info(
+        f"Created empty transcript {transcript.id} for file {file_path.name} because technically we need an empty transcript before we start transcript"
+    )
+
+    # pipelines expect files as upload.*
+    extension = file_path.suffix
+    upload_path = transcript.data_path / f"upload{extension}"
+    upload_path.parent.mkdir(parents=True, exist_ok=True)
+    shutil.copy2(source_path, upload_path)
+    logger.info(f"Copied {source_path} to {upload_path}")
+
+    # pipelines expect entity status "uploaded"
+    await transcripts_controller.update(transcript, {"status": "uploaded"})
+
+    return transcript.id
+
+
+# same reason as prepare_entry
+async def extract_result_from_entry(
+    transcript_id: TranscriptId, output_path: str
+) -> None:
+    post_final_transcript = await transcripts_controller.get_by_id(transcript_id)
+
+    # assert post_final_transcript.status == "ended"
+    # File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
+    topics = post_final_transcript.topics
+    if not topics:
+        raise RuntimeError(
+            f"No topics found for transcript {transcript_id} after processing"
+        )
+
+    serialized_topics = serialize_topics(topics)
+
+    if output_path:
+        # Write to JSON file
+        with open(output_path, "w") as f:
+            for topic_dict in serialized_topics:
+                json.dump(topic_dict, f)
+                f.write("\n")
+        print(f"Results written to {output_path}", file=sys.stderr)
+    else:
+        # Write to stdout as JSONL
+        for topic_dict in serialized_topics:
+            print(json.dumps(topic_dict))
+
+    debug_print_speakers(serialized_topics)
+
+
+async def process_live_pipeline(
+    transcript_id: TranscriptId,
+):
+    """Process transcript_id with transcription and diarization"""
+
+    print(f"Processing transcript_id {transcript_id}...", file=sys.stderr)
+    await live_pipeline_process(transcript_id=transcript_id)
+    print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
+
+    pre_final_transcript = await transcripts_controller.get_by_id(transcript_id)
+
+    # assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
+    assert pre_final_transcript.status != "ended"
+
+    # at this point, diarization is running but we have no access to it. run diarization in parallel - one will hopefully win after polling
+    result = live_pipeline_post(transcript_id=transcript_id)
+
+    # result.ready() blocks even without await; it mutates result also
+    while not result.ready():
+        print(f"Status: {result.state}")
+        time.sleep(2)
+
+
+async def process_file_pipeline(
+    transcript_id: TranscriptId,
+):
+    """Process audio/video file using the optimized file pipeline"""
+
+    # task_pipeline_file_process is a Celery task, need to use .delay() for async execution
+    result = task_pipeline_file_process.delay(transcript_id=transcript_id)
+
+    # Wait for the Celery task to complete
+    while not result.ready():
+        print(f"File pipeline status: {result.state}", file=sys.stderr)
+        time.sleep(2)
+
+    logger.info("File pipeline processing complete")
+
+
+async def process(
+    source_path: str,
+    source_language: str,
+    target_language: str,
+    pipeline: Literal["live", "file"],
+    output_path: str = None,
+):
+    from reflector.db import get_database
+
+    database = get_database()
+    # db connect is a part of ceremony
+    await database.connect()
+
+    try:
+        transcript_id = await prepare_entry(
+            source_path,
+            source_language,
+            target_language,
+        )
+
+        pipeline_handlers = {
+            "live": process_live_pipeline,
+            "file": process_file_pipeline,
+        }
+
+        handler = pipeline_handlers.get(pipeline)
+        if not handler:
+            raise ValueError(f"Unknown pipeline type: {pipeline}")
+
+        await handler(transcript_id)
+
+        await extract_result_from_entry(transcript_id, output_path)
+    finally:
+        await database.disconnect()
 
 
 if __name__ == "__main__":
-    import argparse
-    import os
-
     parser = argparse.ArgumentParser(
-        description="Process audio files with optional speaker diarization"
+        description="Process audio files with speaker diarization"
     )
     parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
     parser.add_argument(
-        "--stream",
-        action="store_true",
-        help="Use streaming pipeline (original frame-based processing)",
-    )
-    parser.add_argument(
-        "--only-transcript",
-        "-t",
-        action="store_true",
-        help="Only generate transcript without topics/summaries",
+        "--pipeline",
+        required=True,
+        choices=["live", "file"],
+        help="Pipeline type to use for processing (live: streaming/incremental, file: batch/parallel)",
    )
     parser.add_argument(
         "--source-language", default="en", help="Source language code (default: en)"
@@ -297,82 +207,14 @@ if __name__ == "__main__":
         "--target-language", default="en", help="Target language code (default: en)"
     )
     parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
-    parser.add_argument(
-        "--enable-diarization",
-        "-d",
-        action="store_true",
-        help="Enable speaker diarization",
-    )
-    parser.add_argument(
-        "--diarization-backend",
-        default="pyannote",
-        choices=["pyannote", "modal"],
-        help="Diarization backend to use (default: pyannote)",
-    )
     args = parser.parse_args()
 
-    if "REDIS_HOST" not in os.environ:
-        os.environ["REDIS_HOST"] = "localhost"
-
-    output_fd = None
-    if args.output:
-        output_fd = open(args.output, "w")
-
-    async def event_callback(event: PipelineEvent):
-        processor = event.processor
-        data = event.data
-
-        # Ignore internal processors
-        if processor in (
-            "AudioDownscaleProcessor",
-            "AudioChunkerAutoProcessor",
-            "AudioMergeProcessor",
-            "AudioFileWriterProcessor",
-            "TopicCollectorProcessor",
-            "BroadcastProcessor",
-        ):
-            return
-
-        # If diarization is enabled, skip the original topic events from the pipeline
-        # The diarization processor will emit the same topics but with speaker info
-        if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
-            return
-
-        # Log all events
-        logger.info(f"Event: {processor} - {type(data).__name__}")
-
-        # Write to output
-        if output_fd:
-            output_fd.write(event.model_dump_json())
-            output_fd.write("\n")
-            output_fd.flush()
-
-    if args.stream:
-        # Use original streaming pipeline
-        asyncio.run(
-            process_audio_file(
-                args.source,
-                event_callback,
-                only_transcript=args.only_transcript,
-                source_language=args.source_language,
-                target_language=args.target_language,
-                enable_diarization=args.enable_diarization,
-                diarization_backend=args.diarization_backend,
-            )
-        )
-    else:
-        # Use optimized file pipeline (default)
-        asyncio.run(
-            process_file_pipeline(
-                args.source,
-                event_callback,
-                source_language=args.source_language,
-                target_language=args.target_language,
-                enable_diarization=args.enable_diarization,
-                diarization_backend=args.diarization_backend,
-            )
-        )
-
-    if output_fd:
-        output_fd.close()
-        logger.info(f"Output written to {args.output}")
+    asyncio.run(
+        process(
+            args.source,
+            args.source_language,
+            args.target_language,
+            args.pipeline,
+            args.output,
+        )
+    )
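The rewritten tool no longer builds a local processor graph; it drives the real server pipelines through Celery and the database, so Redis and the DB must be reachable when it runs. A hypothetical programmatic use of the new `process` entry point (the import path is inferred from the old tests' `reflector.tools.process` and is an assumption, not confirmed by this diff):

    import asyncio

    from reflector.tools.process import process  # assumed module path

    asyncio.run(
        process(
            "meeting.mp4",   # source file to evaluate
            "en",            # source language
            "en",            # target language
            "file",          # Celery batch pipeline; "live" replays the streaming flow
            "topics.jsonl",  # output path; None prints JSONL topics to stdout
        )
    )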
@@ -1,318 +0,0 @@
-"""
-@vibe-generated
-Process audio file with diarization support
-===========================================
-
-Extended version of process.py that includes speaker diarization.
-This tool processes audio files locally without requiring the full server infrastructure.
-"""
-
-import asyncio
-import tempfile
-import uuid
-from pathlib import Path
-from typing import List
-
-import av
-
-from reflector.logger import logger
-from reflector.processors import (
-    AudioChunkerAutoProcessor,
-    AudioDownscaleProcessor,
-    AudioFileWriterProcessor,
-    AudioMergeProcessor,
-    AudioTranscriptAutoProcessor,
-    Pipeline,
-    PipelineEvent,
-    TranscriptFinalSummaryProcessor,
-    TranscriptFinalTitleProcessor,
-    TranscriptLinerProcessor,
-    TranscriptTopicDetectorProcessor,
-    TranscriptTranslatorAutoProcessor,
-)
-from reflector.processors.base import BroadcastProcessor, Processor
-from reflector.processors.types import (
-    AudioDiarizationInput,
-    TitleSummary,
-    TitleSummaryWithId,
-)
-
-
-class TopicCollectorProcessor(Processor):
-    """Collect topics for diarization"""
-
-    INPUT_TYPE = TitleSummary
-    OUTPUT_TYPE = TitleSummary
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.topics: List[TitleSummaryWithId] = []
-        self._topic_id = 0
-
-    async def _push(self, data: TitleSummary):
-        # Convert to TitleSummaryWithId and collect
-        self._topic_id += 1
-        topic_with_id = TitleSummaryWithId(
-            id=str(self._topic_id),
-            title=data.title,
-            summary=data.summary,
-            timestamp=data.timestamp,
-            duration=data.duration,
-            transcript=data.transcript,
-        )
-        self.topics.append(topic_with_id)
-
-        # Pass through the original topic
-        await self.emit(data)
-
-    def get_topics(self) -> List[TitleSummaryWithId]:
-        return self.topics
-
-
-async def process_audio_file_with_diarization(
-    filename,
-    event_callback,
-    only_transcript=False,
-    source_language="en",
-    target_language="en",
-    enable_diarization=True,
-    diarization_backend="modal",
-):
-    # Create temp file for audio if diarization is enabled
-    audio_temp_path = None
-    if enable_diarization:
-        audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
-        audio_temp_path = audio_temp_file.name
-        audio_temp_file.close()
-
-    # Create processor for collecting topics
-    topic_collector = TopicCollectorProcessor()
-
-    # Build pipeline for audio processing
-    processors = []
-
-    # Add audio file writer at the beginning if diarization is enabled
-    if enable_diarization:
-        processors.append(AudioFileWriterProcessor(audio_temp_path))
-
-    # Add the rest of the processors
-    processors += [
-        AudioDownscaleProcessor(),
-        AudioChunkerAutoProcessor(),
-        AudioMergeProcessor(),
-        AudioTranscriptAutoProcessor.as_threaded(),
-    ]
-
-    processors += [
-        TranscriptLinerProcessor(),
-        TranscriptTranslatorAutoProcessor.as_threaded(),
-    ]
-
-    if not only_transcript:
-        processors += [
-            TranscriptTopicDetectorProcessor.as_threaded(),
-            # Collect topics for diarization
-            topic_collector,
-            BroadcastProcessor(
-                processors=[
-                    TranscriptFinalTitleProcessor.as_threaded(),
-                    TranscriptFinalSummaryProcessor.as_threaded(),
-                ],
-            ),
-        ]
-
-    # Create main pipeline
-    pipeline = Pipeline(*processors)
-    pipeline.set_pref("audio:source_language", source_language)
-    pipeline.set_pref("audio:target_language", target_language)
-    pipeline.describe()
-    pipeline.on(event_callback)
-
-    # Start processing audio
-    logger.info(f"Opening {filename}")
-    container = av.open(filename)
-    try:
-        logger.info("Start pushing audio into the pipeline")
-        for frame in container.decode(audio=0):
-            await pipeline.push(frame)
-    finally:
-        logger.info("Flushing the pipeline")
-        await pipeline.flush()
-
-    # Run diarization if enabled and we have topics
-    if enable_diarization and not only_transcript and audio_temp_path:
-        topics = topic_collector.get_topics()
-
-        if topics:
-            logger.info(f"Starting diarization with {len(topics)} topics")
-
-            try:
-                from reflector.processors import AudioDiarizationAutoProcessor
-
-                diarization_processor = AudioDiarizationAutoProcessor(
-                    name=diarization_backend
-                )
-                diarization_processor.set_pipeline(pipeline)
-
-                # For Modal backend, we need to upload the file to S3 first
-                if diarization_backend == "modal":
-                    from datetime import datetime, timezone
-
-                    from reflector.storage import get_transcripts_storage
-                    from reflector.utils.s3_temp_file import S3TemporaryFile
-
-                    storage = get_transcripts_storage()
-
-                    # Generate a unique filename in evaluation folder
-                    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
-                    audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
-
-                    # Use context manager for automatic cleanup
-                    async with S3TemporaryFile(storage, audio_filename) as s3_file:
-                        # Read and upload the audio file
-                        with open(audio_temp_path, "rb") as f:
-                            audio_data = f.read()
-
-                        audio_url = await s3_file.upload(audio_data)
-                        logger.info(f"Uploaded audio to S3: {audio_filename}")
-
-                        # Create diarization input with S3 URL
-                        diarization_input = AudioDiarizationInput(
-                            audio_url=audio_url, topics=topics
-                        )
-
-                        # Run diarization
-                        await diarization_processor.push(diarization_input)
-                        await diarization_processor.flush()
-
-                        logger.info("Diarization complete")
-                        # File will be automatically cleaned up when exiting the context
-                else:
-                    # For local backend, use local file path
-                    audio_url = audio_temp_path
-
-                    # Create diarization input
-                    diarization_input = AudioDiarizationInput(
-                        audio_url=audio_url, topics=topics
-                    )
-
-                    # Run diarization
-                    await diarization_processor.push(diarization_input)
-                    await diarization_processor.flush()
-
-                    logger.info("Diarization complete")
-
-            except ImportError as e:
-                logger.error(f"Failed to import diarization dependencies: {e}")
-                logger.error(
-                    "Install with: uv pip install pyannote.audio torch torchaudio"
-                )
-                logger.error(
-                    "And set HF_TOKEN environment variable for pyannote models"
-                )
-                raise SystemExit(1)
-            except Exception as e:
-                logger.error(f"Diarization failed: {e}")
-                raise SystemExit(1)
-        else:
-            logger.warning("Skipping diarization: no topics available")
-
-    # Clean up temp file
-    if audio_temp_path:
-        try:
-            Path(audio_temp_path).unlink()
-        except Exception as e:
-            logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
-
-    logger.info("All done!")
-
-
-if __name__ == "__main__":
-    import argparse
-    import os
-
-    parser = argparse.ArgumentParser(
-        description="Process audio files with optional speaker diarization"
-    )
-    parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
-    parser.add_argument(
-        "--only-transcript",
-        "-t",
-        action="store_true",
-        help="Only generate transcript without topics/summaries",
-    )
-    parser.add_argument(
-        "--source-language", default="en", help="Source language code (default: en)"
-    )
-    parser.add_argument(
-        "--target-language", default="en", help="Target language code (default: en)"
-    )
-    parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
-    parser.add_argument(
-        "--enable-diarization",
-        "-d",
-        action="store_true",
-        help="Enable speaker diarization",
-    )
-    parser.add_argument(
-        "--diarization-backend",
-        default="modal",
-        choices=["modal"],
-        help="Diarization backend to use (default: modal)",
-    )
-    args = parser.parse_args()
-
-    # Set REDIS_HOST to localhost if not provided
-    if "REDIS_HOST" not in os.environ:
-        os.environ["REDIS_HOST"] = "localhost"
-        logger.info("REDIS_HOST not set, defaulting to localhost")
-
-    output_fd = None
-    if args.output:
-        output_fd = open(args.output, "w")
-
-    async def event_callback(event: PipelineEvent):
-        processor = event.processor
-        data = event.data
-
-        # Ignore internal processors
-        if processor in (
-            "AudioDownscaleProcessor",
-            "AudioChunkerAutoProcessor",
-            "AudioMergeProcessor",
-            "AudioFileWriterProcessor",
-            "TopicCollectorProcessor",
-            "BroadcastProcessor",
-        ):
-            return
-
-        # If diarization is enabled, skip the original topic events from the pipeline
-        # The diarization processor will emit the same topics but with speaker info
-        if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
-            return
-
-        # Log all events
-        logger.info(f"Event: {processor} - {type(data).__name__}")
-
-        # Write to output
-        if output_fd:
-            output_fd.write(event.model_dump_json())
-            output_fd.write("\n")
-            output_fd.flush()
-
-    asyncio.run(
-        process_audio_file_with_diarization(
-            args.source,
-            event_callback,
-            only_transcript=args.only_transcript,
-            source_language=args.source_language,
-            target_language=args.target_language,
-            enable_diarization=args.enable_diarization,
-            diarization_backend=args.diarization_backend,
-        )
-    )
-
-    if output_fd:
-        output_fd.close()
-        logger.info(f"Output written to {args.output}")
@@ -1,96 +0,0 @@
-#!/usr/bin/env python3
-"""
-@vibe-generated
-Test script for the diarization CLI tool
-=========================================
-
-This script helps test the diarization functionality with sample audio files.
-"""
-
-import asyncio
-import sys
-from pathlib import Path
-
-from reflector.logger import logger
-
-
-async def test_diarization(audio_file: str):
-    """Test the diarization functionality"""
-
-    # Import the processing function
-    from process_with_diarization import process_audio_file_with_diarization
-
-    # Collect events
-    events = []
-
-    async def event_callback(event):
-        events.append({"processor": event.processor, "data": event.data})
-        logger.info(f"Event from {event.processor}")
-
-    # Process the audio file
-    logger.info(f"Processing audio file: {audio_file}")
-
-    try:
-        await process_audio_file_with_diarization(
-            audio_file,
-            event_callback,
-            only_transcript=False,
-            source_language="en",
-            target_language="en",
-            enable_diarization=True,
-            diarization_backend="modal",
-        )
-
-        # Analyze results
-        logger.info(f"Processing complete. Received {len(events)} events")
-
-        # Look for diarization results
-        diarized_topics = []
-        for event in events:
-            if "TitleSummary" in event["processor"]:
-                # Check if words have speaker information
-                if hasattr(event["data"], "transcript") and event["data"].transcript:
-                    words = event["data"].transcript.words
-                    if words and hasattr(words[0], "speaker"):
-                        speakers = set(
-                            w.speaker for w in words if hasattr(w, "speaker")
-                        )
-                        logger.info(
-                            f"Found {len(speakers)} speakers in topic: {event['data'].title}"
-                        )
-                        diarized_topics.append(event["data"])
-
-        if diarized_topics:
-            logger.info(f"Successfully diarized {len(diarized_topics)} topics")
-
-            # Print sample output
-            sample_topic = diarized_topics[0]
-            logger.info("Sample diarized output:")
-            for i, word in enumerate(sample_topic.transcript.words[:10]):
-                logger.info(f"  Word {i}: '{word.text}' - Speaker {word.speaker}")
-        else:
-            logger.warning("No diarization results found in output")
-
-        return events
-
-    except Exception as e:
-        logger.error(f"Error during processing: {e}")
-        raise
-
-
-def main():
-    if len(sys.argv) < 2:
-        print("Usage: python test_diarization.py <audio_file>")
-        sys.exit(1)
-
-    audio_file = sys.argv[1]
-    if not Path(audio_file).exists():
-        print(f"Error: Audio file '{audio_file}' not found")
-        sys.exit(1)
-
-    # Run the test
-    asyncio.run(test_diarization(audio_file))
-
-
-if __name__ == "__main__":
-    main()
@@ -1,61 +0,0 @@
-import pytest
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("enable_diarization", [False, True])
-async def test_basic_process(
-    dummy_transcript,
-    dummy_llm,
-    dummy_processors,
-    enable_diarization,
-    dummy_diarization,
-):
-    # goal is to start the server, and send rtc audio to it
-    # validate the events received
-    from pathlib import Path
-
-    from reflector.settings import settings
-    from reflector.tools.process import process_audio_file
-
-    # LLM_BACKEND no longer exists in settings
-    # settings.LLM_BACKEND = "test"
-    settings.TRANSCRIPT_BACKEND = "whisper"
-
-    # event callback
-    marks = {}
-
-    async def event_callback(event):
-        if event.processor not in marks:
-            marks[event.processor] = 0
-        marks[event.processor] += 1
-
-    # invoke the process and capture events
-    path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
-
-    if enable_diarization:
-        # Test with diarization - may fail if pyannote.audio is not installed
-        try:
-            await process_audio_file(
-                path.as_posix(), event_callback, enable_diarization=True
-            )
-        except SystemExit:
-            pytest.skip("pyannote.audio not installed - skipping diarization test")
-    else:
-        # Test without diarization - should always work
-        await process_audio_file(
-            path.as_posix(), event_callback, enable_diarization=False
-        )
-
-    print(f"Diarization: {enable_diarization}, Marks: {marks}")
-
-    # validate the events
-    # Each processor should be called for each audio segment processed
-    # The final processors (Topic, Title, Summary) should be called once at the end
-    assert marks["TranscriptLinerProcessor"] > 0
-    assert marks["TranscriptTranslatorPassthroughProcessor"] > 0
-    assert marks["TranscriptTopicDetectorProcessor"] == 1
-    assert marks["TranscriptFinalSummaryProcessor"] == 1
-    assert marks["TranscriptFinalTitleProcessor"] == 1
-
-    if enable_diarization:
-        assert marks["TestAudioDiarizationProcessor"] == 1