mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
fix: Igor/evaluation (#575)
* fix: impossible import error (#563) * evaluation cli - database events experiment * hallucinations * evaluation - unhallucinate * evaluation - unhallucinate * roll back reliability link * self reviewio * lint * self review * add file pipeline to cli * add file pipeline to cli + sorting * remove cli tests * remove ai comments * comments
This commit is contained in:
@@ -794,7 +794,7 @@ def pipeline_post(*, transcript_id: str):
|
||||
chain_final_summaries,
|
||||
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
|
||||
|
||||
chain.delay()
|
||||
return chain.delay()
|
||||
|
||||
|
||||
@get_transcript
|
||||
|
||||
@@ -67,6 +67,9 @@ class FileTranscriptModalProcessor(FileTranscriptProcessor):
|
||||
for word_info in result.get("words", [])
|
||||
]
|
||||
|
||||
# words come not in order
|
||||
words.sort(key=lambda w: w.start)
|
||||
|
||||
return Transcript(words=words)
|
||||
|
||||
|
||||
|
||||
@@ -1,294 +1,204 @@
|
||||
"""
|
||||
Process audio file with diarization support
|
||||
===========================================
|
||||
|
||||
Extended version of process.py that includes speaker diarization.
|
||||
This tool processes audio files locally without requiring the full server infrastructure.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
import json
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import av
|
||||
from typing import Any, Dict, List, Literal
|
||||
|
||||
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerAutoProcessor,
|
||||
AudioDownscaleProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
from reflector.pipelines.main_file_pipeline import (
|
||||
task_pipeline_file_process as task_pipeline_file_process,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor, Processor
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
TitleSummary,
|
||||
TitleSummaryWithId,
|
||||
from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipeline_post
|
||||
from reflector.pipelines.main_live_pipeline import (
|
||||
pipeline_process as live_pipeline_process,
|
||||
)
|
||||
|
||||
|
||||
class TopicCollectorProcessor(Processor):
|
||||
"""Collect topics for diarization"""
|
||||
def serialize_topics(topics: List[TranscriptTopic]) -> List[Dict[str, Any]]:
|
||||
"""Convert TranscriptTopic objects to JSON-serializable dicts"""
|
||||
serialized = []
|
||||
for topic in topics:
|
||||
topic_dict = topic.model_dump()
|
||||
serialized.append(topic_dict)
|
||||
return serialized
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.topics: List[TitleSummaryWithId] = []
|
||||
self._topic_id = 0
|
||||
def debug_print_speakers(serialized_topics: List[Dict[str, Any]]) -> None:
|
||||
"""Print debug info about speakers found in topics"""
|
||||
all_speakers = set()
|
||||
for topic_dict in serialized_topics:
|
||||
for word in topic_dict.get("words", []):
|
||||
all_speakers.add(word.get("speaker", 0))
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
# Convert to TitleSummaryWithId and collect
|
||||
self._topic_id += 1
|
||||
topic_with_id = TitleSummaryWithId(
|
||||
id=str(self._topic_id),
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
timestamp=data.timestamp,
|
||||
duration=data.duration,
|
||||
transcript=data.transcript,
|
||||
print(
|
||||
f"Found {len(serialized_topics)} topics with speakers: {all_speakers}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
TranscriptId = str
|
||||
|
||||
|
||||
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
|
||||
# ideally we want to get rid of it at some point
|
||||
async def prepare_entry(
|
||||
source_path: str,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> TranscriptId:
|
||||
file_path = Path(source_path)
|
||||
|
||||
transcript = await transcripts_controller.add(
|
||||
file_path.name,
|
||||
# note that the real file upload has SourceKind: LIVE for the reason of it's an error
|
||||
source_kind=SourceKind.FILE,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created empty transcript {transcript.id} for file {file_path.name} because technically we need an empty transcript before we start transcript"
|
||||
)
|
||||
|
||||
# pipelines expect files as upload.*
|
||||
|
||||
extension = file_path.suffix
|
||||
upload_path = transcript.data_path / f"upload{extension}"
|
||||
upload_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(source_path, upload_path)
|
||||
logger.info(f"Copied {source_path} to {upload_path}")
|
||||
|
||||
# pipelines expect entity status "uploaded"
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
|
||||
return transcript.id
|
||||
|
||||
|
||||
# same reason as prepare_entry
|
||||
async def extract_result_from_entry(
|
||||
transcript_id: TranscriptId, output_path: str
|
||||
) -> None:
|
||||
post_final_transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
|
||||
# assert post_final_transcript.status == "ended"
|
||||
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
|
||||
topics = post_final_transcript.topics
|
||||
if not topics:
|
||||
raise RuntimeError(
|
||||
f"No topics found for transcript {transcript_id} after processing"
|
||||
)
|
||||
self.topics.append(topic_with_id)
|
||||
|
||||
# Pass through the original topic
|
||||
await self.emit(data)
|
||||
serialized_topics = serialize_topics(topics)
|
||||
|
||||
def get_topics(self) -> List[TitleSummaryWithId]:
|
||||
return self.topics
|
||||
if output_path:
|
||||
# Write to JSON file
|
||||
with open(output_path, "w") as f:
|
||||
for topic_dict in serialized_topics:
|
||||
json.dump(topic_dict, f)
|
||||
f.write("\n")
|
||||
print(f"Results written to {output_path}", file=sys.stderr)
|
||||
else:
|
||||
# Write to stdout as JSONL
|
||||
for topic_dict in serialized_topics:
|
||||
print(json.dumps(topic_dict))
|
||||
|
||||
debug_print_speakers(serialized_topics)
|
||||
|
||||
|
||||
async def process_audio_file(
|
||||
filename,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="pyannote",
|
||||
async def process_live_pipeline(
|
||||
transcript_id: TranscriptId,
|
||||
):
|
||||
# Create temp file for audio if diarization is enabled
|
||||
audio_temp_path = None
|
||||
if enable_diarization:
|
||||
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
audio_temp_path = audio_temp_file.name
|
||||
audio_temp_file.close()
|
||||
"""Process transcript_id with transcription and diarization"""
|
||||
|
||||
# Create processor for collecting topics
|
||||
topic_collector = TopicCollectorProcessor()
|
||||
print(f"Processing transcript_id {transcript_id}...", file=sys.stderr)
|
||||
await live_pipeline_process(transcript_id=transcript_id)
|
||||
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
|
||||
|
||||
# Build pipeline for audio processing
|
||||
processors = []
|
||||
pre_final_transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
|
||||
# Add audio file writer at the beginning if diarization is enabled
|
||||
if enable_diarization:
|
||||
processors.append(AudioFileWriterProcessor(audio_temp_path))
|
||||
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
|
||||
assert pre_final_transcript.status != "ended"
|
||||
|
||||
# Add the rest of the processors
|
||||
processors += [
|
||||
AudioDownscaleProcessor(),
|
||||
AudioChunkerAutoProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||
]
|
||||
# at this point, diarization is running but we have no access to it. run diarization in parallel - one will hopefully win after polling
|
||||
result = live_pipeline_post(transcript_id=transcript_id)
|
||||
|
||||
if not only_transcript:
|
||||
processors += [
|
||||
TranscriptTopicDetectorProcessor.as_threaded(),
|
||||
# Collect topics for diarization
|
||||
topic_collector,
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(),
|
||||
TranscriptFinalSummaryProcessor.as_threaded(),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Create main pipeline
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.set_pref("audio:source_language", source_language)
|
||||
pipeline.set_pref("audio:target_language", target_language)
|
||||
pipeline.describe()
|
||||
pipeline.on(event_callback)
|
||||
|
||||
# Start processing audio
|
||||
logger.info(f"Opening {filename}")
|
||||
container = av.open(filename)
|
||||
try:
|
||||
logger.info("Start pushing audio into the pipeline")
|
||||
for frame in container.decode(audio=0):
|
||||
await pipeline.push(frame)
|
||||
finally:
|
||||
logger.info("Flushing the pipeline")
|
||||
await pipeline.flush()
|
||||
|
||||
# Run diarization if enabled and we have topics
|
||||
if enable_diarization and not only_transcript and audio_temp_path:
|
||||
topics = topic_collector.get_topics()
|
||||
|
||||
if topics:
|
||||
logger.info(f"Starting diarization with {len(topics)} topics")
|
||||
|
||||
try:
|
||||
from reflector.processors import AudioDiarizationAutoProcessor
|
||||
|
||||
diarization_processor = AudioDiarizationAutoProcessor(
|
||||
name=diarization_backend
|
||||
)
|
||||
|
||||
diarization_processor.set_pipeline(pipeline)
|
||||
|
||||
# For Modal backend, we need to upload the file to S3 first
|
||||
if diarization_backend == "modal":
|
||||
from datetime import datetime
|
||||
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.utils.s3_temp_file import S3TemporaryFile
|
||||
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
# Generate a unique filename in evaluation folder
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
||||
|
||||
# Use context manager for automatic cleanup
|
||||
async with S3TemporaryFile(storage, audio_filename) as s3_file:
|
||||
# Read and upload the audio file
|
||||
with open(audio_temp_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
audio_url = await s3_file.upload(audio_data)
|
||||
logger.info(f"Uploaded audio to S3: {audio_filename}")
|
||||
|
||||
# Create diarization input with S3 URL
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
# File will be automatically cleaned up when exiting the context
|
||||
else:
|
||||
# For local backend, use local file path
|
||||
audio_url = audio_temp_path
|
||||
|
||||
# Create diarization input
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import diarization dependencies: {e}")
|
||||
logger.error(
|
||||
"Install with: uv pip install pyannote.audio torch torchaudio"
|
||||
)
|
||||
logger.error(
|
||||
"And set HF_TOKEN environment variable for pyannote models"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Diarization failed: {e}")
|
||||
raise SystemExit(1)
|
||||
else:
|
||||
logger.warning("Skipping diarization: no topics available")
|
||||
|
||||
# Clean up temp file
|
||||
if audio_temp_path:
|
||||
try:
|
||||
Path(audio_temp_path).unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
|
||||
|
||||
logger.info("All done!")
|
||||
# result.ready() blocks even without await; it mutates result also
|
||||
while not result.ready():
|
||||
print(f"Status: {result.state}")
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
async def process_file_pipeline(
|
||||
filename: str,
|
||||
event_callback,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
transcript_id: TranscriptId,
|
||||
):
|
||||
"""Process audio/video file using the optimized file pipeline"""
|
||||
|
||||
# task_pipeline_file_process is a Celery task, need to use .delay() for async execution
|
||||
result = task_pipeline_file_process.delay(transcript_id=transcript_id)
|
||||
|
||||
# Wait for the Celery task to complete
|
||||
while not result.ready():
|
||||
print(f"File pipeline status: {result.state}", file=sys.stderr)
|
||||
time.sleep(2)
|
||||
|
||||
logger.info("File pipeline processing complete")
|
||||
|
||||
|
||||
async def process(
|
||||
source_path: str,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
pipeline: Literal["live", "file"],
|
||||
output_path: str = None,
|
||||
):
|
||||
from reflector.db import get_database
|
||||
|
||||
database = get_database()
|
||||
# db connect is a part of ceremony
|
||||
await database.connect()
|
||||
|
||||
try:
|
||||
from reflector.db import database
|
||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||
|
||||
await database.connect()
|
||||
try:
|
||||
# Create a temporary transcript for processing
|
||||
transcript = await transcripts_controller.add(
|
||||
"",
|
||||
source_kind=SourceKind.FILE,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
)
|
||||
|
||||
# Process the file
|
||||
pipeline = PipelineMainFile(transcript_id=transcript.id)
|
||||
await pipeline.process(Path(filename))
|
||||
|
||||
logger.info("File pipeline processing complete")
|
||||
|
||||
finally:
|
||||
await database.disconnect()
|
||||
except ImportError as e:
|
||||
logger.error(f"File pipeline not available: {e}")
|
||||
logger.info("Falling back to stream pipeline")
|
||||
# Fall back to stream pipeline
|
||||
await process_audio_file(
|
||||
filename,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
enable_diarization=enable_diarization,
|
||||
diarization_backend=diarization_backend,
|
||||
transcript_id = await prepare_entry(
|
||||
source_path,
|
||||
source_language,
|
||||
target_language,
|
||||
)
|
||||
|
||||
pipeline_handlers = {
|
||||
"live": process_live_pipeline,
|
||||
"file": process_file_pipeline,
|
||||
}
|
||||
|
||||
handler = pipeline_handlers.get(pipeline)
|
||||
if not handler:
|
||||
raise ValueError(f"Unknown pipeline type: {pipeline}")
|
||||
|
||||
await handler(transcript_id)
|
||||
|
||||
await extract_result_from_entry(transcript_id, output_path)
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import os
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Process audio files with optional speaker diarization"
|
||||
description="Process audio files with speaker diarization"
|
||||
)
|
||||
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
||||
parser.add_argument(
|
||||
"--stream",
|
||||
action="store_true",
|
||||
help="Use streaming pipeline (original frame-based processing)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--only-transcript",
|
||||
"-t",
|
||||
action="store_true",
|
||||
help="Only generate transcript without topics/summaries",
|
||||
"--pipeline",
|
||||
required=True,
|
||||
choices=["live", "file"],
|
||||
help="Pipeline type to use for processing (live: streaming/incremental, file: batch/parallel)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-language", default="en", help="Source language code (default: en)"
|
||||
@@ -297,82 +207,14 @@ if __name__ == "__main__":
|
||||
"--target-language", default="en", help="Target language code (default: en)"
|
||||
)
|
||||
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
|
||||
parser.add_argument(
|
||||
"--enable-diarization",
|
||||
"-d",
|
||||
action="store_true",
|
||||
help="Enable speaker diarization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
default="pyannote",
|
||||
choices=["pyannote", "modal"],
|
||||
help="Diarization backend to use (default: pyannote)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if "REDIS_HOST" not in os.environ:
|
||||
os.environ["REDIS_HOST"] = "localhost"
|
||||
|
||||
output_fd = None
|
||||
if args.output:
|
||||
output_fd = open(args.output, "w")
|
||||
|
||||
async def event_callback(event: PipelineEvent):
|
||||
processor = event.processor
|
||||
data = event.data
|
||||
|
||||
# Ignore internal processors
|
||||
if processor in (
|
||||
"AudioDownscaleProcessor",
|
||||
"AudioChunkerAutoProcessor",
|
||||
"AudioMergeProcessor",
|
||||
"AudioFileWriterProcessor",
|
||||
"TopicCollectorProcessor",
|
||||
"BroadcastProcessor",
|
||||
):
|
||||
return
|
||||
|
||||
# If diarization is enabled, skip the original topic events from the pipeline
|
||||
# The diarization processor will emit the same topics but with speaker info
|
||||
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
|
||||
return
|
||||
|
||||
# Log all events
|
||||
logger.info(f"Event: {processor} - {type(data).__name__}")
|
||||
|
||||
# Write to output
|
||||
if output_fd:
|
||||
output_fd.write(event.model_dump_json())
|
||||
output_fd.write("\n")
|
||||
output_fd.flush()
|
||||
|
||||
if args.stream:
|
||||
# Use original streaming pipeline
|
||||
asyncio.run(
|
||||
process_audio_file(
|
||||
args.source,
|
||||
event_callback,
|
||||
only_transcript=args.only_transcript,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
enable_diarization=args.enable_diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
)
|
||||
asyncio.run(
|
||||
process(
|
||||
args.source,
|
||||
args.source_language,
|
||||
args.target_language,
|
||||
args.pipeline,
|
||||
args.output,
|
||||
)
|
||||
else:
|
||||
# Use optimized file pipeline (default)
|
||||
asyncio.run(
|
||||
process_file_pipeline(
|
||||
args.source,
|
||||
event_callback,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
enable_diarization=args.enable_diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
)
|
||||
)
|
||||
|
||||
if output_fd:
|
||||
output_fd.close()
|
||||
logger.info(f"Output written to {args.output}")
|
||||
)
|
||||
|
||||
@@ -1,318 +0,0 @@
|
||||
"""
|
||||
@vibe-generated
|
||||
Process audio file with diarization support
|
||||
===========================================
|
||||
|
||||
Extended version of process.py that includes speaker diarization.
|
||||
This tool processes audio files locally without requiring the full server infrastructure.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import av
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerAutoProcessor,
|
||||
AudioDownscaleProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor, Processor
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
TitleSummary,
|
||||
TitleSummaryWithId,
|
||||
)
|
||||
|
||||
|
||||
class TopicCollectorProcessor(Processor):
|
||||
"""Collect topics for diarization"""
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.topics: List[TitleSummaryWithId] = []
|
||||
self._topic_id = 0
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
# Convert to TitleSummaryWithId and collect
|
||||
self._topic_id += 1
|
||||
topic_with_id = TitleSummaryWithId(
|
||||
id=str(self._topic_id),
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
timestamp=data.timestamp,
|
||||
duration=data.duration,
|
||||
transcript=data.transcript,
|
||||
)
|
||||
self.topics.append(topic_with_id)
|
||||
|
||||
# Pass through the original topic
|
||||
await self.emit(data)
|
||||
|
||||
def get_topics(self) -> List[TitleSummaryWithId]:
|
||||
return self.topics
|
||||
|
||||
|
||||
async def process_audio_file_with_diarization(
|
||||
filename,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
):
|
||||
# Create temp file for audio if diarization is enabled
|
||||
audio_temp_path = None
|
||||
if enable_diarization:
|
||||
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
audio_temp_path = audio_temp_file.name
|
||||
audio_temp_file.close()
|
||||
|
||||
# Create processor for collecting topics
|
||||
topic_collector = TopicCollectorProcessor()
|
||||
|
||||
# Build pipeline for audio processing
|
||||
processors = []
|
||||
|
||||
# Add audio file writer at the beginning if diarization is enabled
|
||||
if enable_diarization:
|
||||
processors.append(AudioFileWriterProcessor(audio_temp_path))
|
||||
|
||||
# Add the rest of the processors
|
||||
processors += [
|
||||
AudioDownscaleProcessor(),
|
||||
AudioChunkerAutoProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
]
|
||||
|
||||
processors += [
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||
]
|
||||
|
||||
if not only_transcript:
|
||||
processors += [
|
||||
TranscriptTopicDetectorProcessor.as_threaded(),
|
||||
# Collect topics for diarization
|
||||
topic_collector,
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(),
|
||||
TranscriptFinalSummaryProcessor.as_threaded(),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Create main pipeline
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.set_pref("audio:source_language", source_language)
|
||||
pipeline.set_pref("audio:target_language", target_language)
|
||||
pipeline.describe()
|
||||
pipeline.on(event_callback)
|
||||
|
||||
# Start processing audio
|
||||
logger.info(f"Opening {filename}")
|
||||
container = av.open(filename)
|
||||
try:
|
||||
logger.info("Start pushing audio into the pipeline")
|
||||
for frame in container.decode(audio=0):
|
||||
await pipeline.push(frame)
|
||||
finally:
|
||||
logger.info("Flushing the pipeline")
|
||||
await pipeline.flush()
|
||||
|
||||
# Run diarization if enabled and we have topics
|
||||
if enable_diarization and not only_transcript and audio_temp_path:
|
||||
topics = topic_collector.get_topics()
|
||||
|
||||
if topics:
|
||||
logger.info(f"Starting diarization with {len(topics)} topics")
|
||||
|
||||
try:
|
||||
from reflector.processors import AudioDiarizationAutoProcessor
|
||||
|
||||
diarization_processor = AudioDiarizationAutoProcessor(
|
||||
name=diarization_backend
|
||||
)
|
||||
|
||||
diarization_processor.set_pipeline(pipeline)
|
||||
|
||||
# For Modal backend, we need to upload the file to S3 first
|
||||
if diarization_backend == "modal":
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.utils.s3_temp_file import S3TemporaryFile
|
||||
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
# Generate a unique filename in evaluation folder
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
||||
|
||||
# Use context manager for automatic cleanup
|
||||
async with S3TemporaryFile(storage, audio_filename) as s3_file:
|
||||
# Read and upload the audio file
|
||||
with open(audio_temp_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
audio_url = await s3_file.upload(audio_data)
|
||||
logger.info(f"Uploaded audio to S3: {audio_filename}")
|
||||
|
||||
# Create diarization input with S3 URL
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
# File will be automatically cleaned up when exiting the context
|
||||
else:
|
||||
# For local backend, use local file path
|
||||
audio_url = audio_temp_path
|
||||
|
||||
# Create diarization input
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import diarization dependencies: {e}")
|
||||
logger.error(
|
||||
"Install with: uv pip install pyannote.audio torch torchaudio"
|
||||
)
|
||||
logger.error(
|
||||
"And set HF_TOKEN environment variable for pyannote models"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Diarization failed: {e}")
|
||||
raise SystemExit(1)
|
||||
else:
|
||||
logger.warning("Skipping diarization: no topics available")
|
||||
|
||||
# Clean up temp file
|
||||
if audio_temp_path:
|
||||
try:
|
||||
Path(audio_temp_path).unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
|
||||
|
||||
logger.info("All done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import os
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Process audio files with optional speaker diarization"
|
||||
)
|
||||
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
||||
parser.add_argument(
|
||||
"--only-transcript",
|
||||
"-t",
|
||||
action="store_true",
|
||||
help="Only generate transcript without topics/summaries",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-language", default="en", help="Source language code (default: en)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-language", default="en", help="Target language code (default: en)"
|
||||
)
|
||||
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
|
||||
parser.add_argument(
|
||||
"--enable-diarization",
|
||||
"-d",
|
||||
action="store_true",
|
||||
help="Enable speaker diarization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
default="modal",
|
||||
choices=["modal"],
|
||||
help="Diarization backend to use (default: modal)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set REDIS_HOST to localhost if not provided
|
||||
if "REDIS_HOST" not in os.environ:
|
||||
os.environ["REDIS_HOST"] = "localhost"
|
||||
logger.info("REDIS_HOST not set, defaulting to localhost")
|
||||
|
||||
output_fd = None
|
||||
if args.output:
|
||||
output_fd = open(args.output, "w")
|
||||
|
||||
async def event_callback(event: PipelineEvent):
|
||||
processor = event.processor
|
||||
data = event.data
|
||||
|
||||
# Ignore internal processors
|
||||
if processor in (
|
||||
"AudioDownscaleProcessor",
|
||||
"AudioChunkerAutoProcessor",
|
||||
"AudioMergeProcessor",
|
||||
"AudioFileWriterProcessor",
|
||||
"TopicCollectorProcessor",
|
||||
"BroadcastProcessor",
|
||||
):
|
||||
return
|
||||
|
||||
# If diarization is enabled, skip the original topic events from the pipeline
|
||||
# The diarization processor will emit the same topics but with speaker info
|
||||
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
|
||||
return
|
||||
|
||||
# Log all events
|
||||
logger.info(f"Event: {processor} - {type(data).__name__}")
|
||||
|
||||
# Write to output
|
||||
if output_fd:
|
||||
output_fd.write(event.model_dump_json())
|
||||
output_fd.write("\n")
|
||||
output_fd.flush()
|
||||
|
||||
asyncio.run(
|
||||
process_audio_file_with_diarization(
|
||||
args.source,
|
||||
event_callback,
|
||||
only_transcript=args.only_transcript,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
enable_diarization=args.enable_diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
)
|
||||
)
|
||||
|
||||
if output_fd:
|
||||
output_fd.close()
|
||||
logger.info(f"Output written to {args.output}")
|
||||
@@ -1,96 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
@vibe-generated
|
||||
Test script for the diarization CLI tool
|
||||
=========================================
|
||||
|
||||
This script helps test the diarization functionality with sample audio files.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
|
||||
async def test_diarization(audio_file: str):
|
||||
"""Test the diarization functionality"""
|
||||
|
||||
# Import the processing function
|
||||
from process_with_diarization import process_audio_file_with_diarization
|
||||
|
||||
# Collect events
|
||||
events = []
|
||||
|
||||
async def event_callback(event):
|
||||
events.append({"processor": event.processor, "data": event.data})
|
||||
logger.info(f"Event from {event.processor}")
|
||||
|
||||
# Process the audio file
|
||||
logger.info(f"Processing audio file: {audio_file}")
|
||||
|
||||
try:
|
||||
await process_audio_file_with_diarization(
|
||||
audio_file,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
)
|
||||
|
||||
# Analyze results
|
||||
logger.info(f"Processing complete. Received {len(events)} events")
|
||||
|
||||
# Look for diarization results
|
||||
diarized_topics = []
|
||||
for event in events:
|
||||
if "TitleSummary" in event["processor"]:
|
||||
# Check if words have speaker information
|
||||
if hasattr(event["data"], "transcript") and event["data"].transcript:
|
||||
words = event["data"].transcript.words
|
||||
if words and hasattr(words[0], "speaker"):
|
||||
speakers = set(
|
||||
w.speaker for w in words if hasattr(w, "speaker")
|
||||
)
|
||||
logger.info(
|
||||
f"Found {len(speakers)} speakers in topic: {event['data'].title}"
|
||||
)
|
||||
diarized_topics.append(event["data"])
|
||||
|
||||
if diarized_topics:
|
||||
logger.info(f"Successfully diarized {len(diarized_topics)} topics")
|
||||
|
||||
# Print sample output
|
||||
sample_topic = diarized_topics[0]
|
||||
logger.info("Sample diarized output:")
|
||||
for i, word in enumerate(sample_topic.transcript.words[:10]):
|
||||
logger.info(f" Word {i}: '{word.text}' - Speaker {word.speaker}")
|
||||
else:
|
||||
logger.warning("No diarization results found in output")
|
||||
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during processing: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python test_diarization.py <audio_file>")
|
||||
sys.exit(1)
|
||||
|
||||
audio_file = sys.argv[1]
|
||||
if not Path(audio_file).exists():
|
||||
print(f"Error: Audio file '{audio_file}' not found")
|
||||
sys.exit(1)
|
||||
|
||||
# Run the test
|
||||
asyncio.run(test_diarization(audio_file))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,61 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("enable_diarization", [False, True])
|
||||
async def test_basic_process(
|
||||
dummy_transcript,
|
||||
dummy_llm,
|
||||
dummy_processors,
|
||||
enable_diarization,
|
||||
dummy_diarization,
|
||||
):
|
||||
# goal is to start the server, and send rtc audio to it
|
||||
# validate the events received
|
||||
from pathlib import Path
|
||||
|
||||
from reflector.settings import settings
|
||||
from reflector.tools.process import process_audio_file
|
||||
|
||||
# LLM_BACKEND no longer exists in settings
|
||||
# settings.LLM_BACKEND = "test"
|
||||
settings.TRANSCRIPT_BACKEND = "whisper"
|
||||
|
||||
# event callback
|
||||
marks = {}
|
||||
|
||||
async def event_callback(event):
|
||||
if event.processor not in marks:
|
||||
marks[event.processor] = 0
|
||||
marks[event.processor] += 1
|
||||
|
||||
# invoke the process and capture events
|
||||
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||
|
||||
if enable_diarization:
|
||||
# Test with diarization - may fail if pyannote.audio is not installed
|
||||
try:
|
||||
await process_audio_file(
|
||||
path.as_posix(), event_callback, enable_diarization=True
|
||||
)
|
||||
except SystemExit:
|
||||
pytest.skip("pyannote.audio not installed - skipping diarization test")
|
||||
else:
|
||||
# Test without diarization - should always work
|
||||
await process_audio_file(
|
||||
path.as_posix(), event_callback, enable_diarization=False
|
||||
)
|
||||
|
||||
print(f"Diarization: {enable_diarization}, Marks: {marks}")
|
||||
|
||||
# validate the events
|
||||
# Each processor should be called for each audio segment processed
|
||||
# The final processors (Topic, Title, Summary) should be called once at the end
|
||||
assert marks["TranscriptLinerProcessor"] > 0
|
||||
assert marks["TranscriptTranslatorPassthroughProcessor"] > 0
|
||||
assert marks["TranscriptTopicDetectorProcessor"] == 1
|
||||
assert marks["TranscriptFinalSummaryProcessor"] == 1
|
||||
assert marks["TranscriptFinalTitleProcessor"] == 1
|
||||
|
||||
if enable_diarization:
|
||||
assert marks["TestAudioDiarizationProcessor"] == 1
|
||||
Reference in New Issue
Block a user