Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)
feat: pipeline improvement with file processing, parakeet, silero-vad (#540)
* feat: improve pipeline threading and the transcriber (Parakeet and Silero VAD)
* refactor: remove whisperx, implement Parakeet
* refactor: make audio_chunker smarter and wait for speech instead of using a fixed frame count
* refactor: make audio merge always downscale audio to 16 kHz for transcription
* refactor: make the audio transcript modal accept batches
* refactor: improve type safety and remove Prometheus metrics
  - Add DiarizationSegment TypedDict for proper diarization typing
  - Replace List/Optional with modern Python list/| None syntax
  - Remove all Prometheus metrics from TranscriptDiarizationAssemblerProcessor
  - Add comprehensive file processing pipeline with parallel execution
  - Update processor imports and type annotations throughout
  - Implement the optimized file pipeline as the default in the process.py tool
* refactor: convert FileDiarizationProcessor I/O types to BaseModel
  - Update FileDiarizationInput and FileDiarizationOutput to inherit from BaseModel instead of plain classes, following the standard pattern used by other processors in the codebase
* test: add tests for file transcript and diarization with pytest-recording
* build: add pytest-recording
* feat: add local pyannote for testing
* fix: replace PyAV AudioResampler with torchaudio for reliable audio processing (a sketch of the technique follows this list)
  - Replace the problematic PyAV AudioResampler that was causing ValueError: [Errno 22] Invalid argument
  - Use torchaudio.functional.resample for robust sample rate conversion
  - Optimize processing: skip conversion for audio that is already 16 kHz mono
  - Add direct WAV writing with Python's wave module for better performance
  - Consolidate duplicate downsample checks for cleaner code
  - Maintain the list[av.AudioFrame] input interface
  - Required for Silero VAD, which needs 16 kHz mono audio
* fix: replace PyAV AudioResampler with the torchaudio solution
  - Resolves ValueError: [Errno 22] Invalid argument in AudioMergeProcessor
  - Replaces the problematic PyAV AudioResampler with torchaudio.functional.resample
  - Optimizes processing to skip unnecessary conversions when audio is already 16 kHz mono
  - Uses direct WAV writing with Python's wave module for better performance
  - Fixes test_basic_process to disable diarization (pyannote dependency not installed)
  - Updates test expectations to match actual processor behavior
  - Removes the unused pydub dependency from pyproject.toml
  - Adds a comprehensive TEST_ANALYSIS.md documenting test suite status
* feat: add a parameterized test for both diarization modes
  - Adds @pytest.mark.parametrize to test_basic_process with enable_diarization=[False, True]
  - The diarization=False case always passes (tests core AudioMergeProcessor functionality)
  - The diarization=True case gracefully skips when pyannote.audio is not installed
  - Provides test coverage for both pipeline configurations
* fix: resolve a pipeline property naming conflict in AudioDiarizationPyannoteProcessor
  - Renames the 'pipeline' property to 'diarization_pipeline' to avoid a conflict with the base Processor.pipeline attribute
  - Fixes AttributeError: 'property 'pipeline' object has no setter' when set_pipeline() is called
  - Updates property usage in the _diarize method to the new name
  - Now correctly supports pipeline initialization for diarization processing
* fix: add local group for pyannote
* test: add diarization test
* fix: resampling on audio merge now works
* fix: correctly restore the timestamp
* fix: display the exception in a threaded processor if one happens
* Update pyproject.toml
* ci: remove option
* ci: update astral-sh/setup-uv
* test: add the Monadical URL for pytest-recording
* refactor: remove previous version
* build: move faster-whisper to the local dependency group
* test: fix missing import
* refactor: improve main_file_pipeline organization and error handling
  - Move all imports to the top of the file
  - Create a unified EmptyPipeline class to replace duplicated mock pipeline code
  - Remove timeout and fallback logic; let processors handle their own retries
  - Fix error handling to raise any exception from parallel tasks
  - Add proper type hints and validation for captured results
* fix: wrong function
* fix: remove task_done
* feat: add configurable file processing timeouts for modal processors
  - Add TRANSCRIPT_FILE_TIMEOUT setting (default: 600s) for file transcription
  - Add DIARIZATION_FILE_TIMEOUT setting (default: 600s) for file diarization
  - Replace the hardcoded timeout=600 with configurable settings in the modal processors
  - Allows customization of timeout values via environment variables
* fix: use logger
* fix: worker-processed meetings now use the file pipeline
* fix: topic not gathered
* refactor: remove prepare(); the pipeline now works
* refactor: apply review feedback from Igor
* test: add a test for test_pipeline_main_file
* refactor: remove doc
* doc: add doc
* ci: update the build to use a native arm64 builder
* fix: merge fixes
* refactor: changes from Igor's review, plus a test (not run by default) for the GPU Modal part
* ci: update to our own runner linux-amd64
* ci: try using the suggested mode=min
* fix: update the diarizer for the latest Modal, and use a Volume
* fix: modal file extension detection
* fix: run the diarizer on an A100
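For context on the resampling bullets above: the commit replaces PyAV's AudioResampler with torchaudio and writes WAV output directly with Python's stdlib wave module. The following is a minimal sketch of that technique, not the repository's actual AudioMergeProcessor code; the helper name `downsample_to_16k_wav` is hypothetical:

```python
import wave

import numpy as np
import torch
import torchaudio

TARGET_RATE = 16000  # Silero VAD requires 16 kHz mono input


def downsample_to_16k_wav(samples: np.ndarray, rate: int, path: str) -> None:
    """Hypothetical helper: resample mono float audio to 16 kHz and write it as WAV."""
    waveform = torch.from_numpy(samples.astype(np.float32))
    if rate != TARGET_RATE:
        # torchaudio.functional.resample converts between arbitrary sample rates
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=rate, new_freq=TARGET_RATE
        )
    # Convert to 16-bit PCM and write directly with the stdlib wave module
    pcm16 = (waveform.clamp(-1.0, 1.0) * 32767).to(torch.int16).numpy()
    with wave.open(path, "wb") as wav:
        wav.setnchannels(1)  # mono
        wav.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wav.setframerate(TARGET_RATE)
        wav.writeframes(pcm16.tobytes())
```

Skipping the resample call when the input is already 16 kHz mono, as the commit message describes, avoids an unnecessary copy on the common path.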
.github/workflows/deploy.yml (vendored, 76 lines changed)

@@ -8,18 +8,30 @@ env:
   ECR_REPOSITORY: reflector
 
 jobs:
-  deploy:
-    runs-on: ubuntu-latest
+  build:
+    strategy:
+      matrix:
+        include:
+          - platform: linux/amd64
+            runner: linux-amd64
+            arch: amd64
+          - platform: linux/arm64
+            runner: linux-arm64
+            arch: arm64
+
+    runs-on: ${{ matrix.runner }}
+
     permissions:
-      deployments: write
       contents: read
 
+    outputs:
+      registry: ${{ steps.login-ecr.outputs.registry }}
+
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@0e613a0980cbf65ed5b322eb7a1e075d28913a83
+        uses: aws-actions/configure-aws-credentials@v4
         with:
           aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
           aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
@@ -27,21 +39,51 @@ jobs:
 
       - name: Login to Amazon ECR
         id: login-ecr
-        uses: aws-actions/amazon-ecr-login@62f4f872db3836360b72999f4b87f1ff13310f3a
+        uses: aws-actions/amazon-ecr-login@v2
 
-      - name: Set up QEMU
-        uses: docker/setup-qemu-action@v2
-
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
+        uses: docker/setup-buildx-action@v3
 
-      - name: Build and push
-        id: docker_build
-        uses: docker/build-push-action@v4
+      - name: Build and push ${{ matrix.arch }}
+        uses: docker/build-push-action@v5
         with:
           context: server
-          platforms: linux/amd64,linux/arm64
+          platforms: ${{ matrix.platform }}
           push: true
-          tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
+          tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest-${{ matrix.arch }}
+          cache-from: type=gha,scope=${{ matrix.arch }}
+          cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
+          provenance: false
+
+  create-manifest:
+    runs-on: ubuntu-latest
+    needs: [build]
+
+    permissions:
+      deployments: write
+      contents: read
+
+    steps:
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v4
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+
+      - name: Login to Amazon ECR
+        uses: aws-actions/amazon-ecr-login@v2
+
+      - name: Create and push multi-arch manifest
+        run: |
+          # Get the registry URL (since we can't easily access job outputs in matrix)
+          ECR_REGISTRY=$(aws ecr describe-registry --query 'registryId' --output text).dkr.ecr.${{ env.AWS_REGION }}.amazonaws.com
+
+          docker manifest create \
+            $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest \
+            $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-amd64 \
+            $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-arm64
+
+          docker manifest push $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest
+
+          echo "✅ Multi-arch manifest pushed: $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest"
.github/workflows/test_server.yml (vendored, 36 lines changed)

@@ -19,29 +19,39 @@ jobs:
     steps:
       - uses: actions/checkout@v4
       - name: Install uv
-        uses: astral-sh/setup-uv@v3
+        uses: astral-sh/setup-uv@v6
         with:
           enable-cache: true
           working-directory: server
 
       - name: Tests
         run: |
           cd server
           uv run -m pytest -v tests
 
-  docker:
-    runs-on: ubuntu-latest
+  docker-amd64:
+    runs-on: linux-amd64
     steps:
       - uses: actions/checkout@v4
-      - name: Set up QEMU
-        uses: docker/setup-qemu-action@v2
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
-      - name: Build and push
-        id: docker_build
-        uses: docker/build-push-action@v4
+        uses: docker/setup-buildx-action@v3
+      - name: Build AMD64
+        uses: docker/build-push-action@v6
         with:
           context: server
-          platforms: linux/amd64,linux/arm64
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
+          platforms: linux/amd64
+          cache-from: type=gha,scope=amd64
+          cache-to: type=gha,mode=min,scope=amd64
+
+  docker-arm64:
+    runs-on: linux-arm64
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+      - name: Build ARM64
+        uses: docker/build-push-action@v6
+        with:
+          context: server
+          platforms: linux/arm64
+          cache-from: type=gha,scope=arm64
+          cache-to: type=gha,mode=min,scope=arm64
README.md (GPU backend)

@@ -4,7 +4,8 @@ This repository hold an API for the GPU implementation of the Reflector API serv
 and use [Modal.com](https://modal.com)
 
 - `reflector_diarizer.py` - Diarization API
-- `reflector_transcriber.py` - Transcription API
+- `reflector_transcriber.py` - Transcription API (Whisper)
+- `reflector_transcriber_parakeet.py` - Transcription API (NVIDIA Parakeet)
 - `reflector_translator.py` - Translation API
 
 ## Modal.com deployment
@@ -19,6 +20,10 @@ $ modal deploy reflector_transcriber.py
 ...
 └── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run
 
+$ modal deploy reflector_transcriber_parakeet.py
+...
+└── 🔨 Created web => https://xxxx--reflector-transcriber-parakeet-web.modal.run
+
 $ modal deploy reflector_llm.py
 ...
 └── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
@@ -68,6 +73,86 @@ Authorization: bearer <REFLECTOR_APIKEY>
 
 ### Transcription
 
+#### Parakeet Transcriber (`reflector_transcriber_parakeet.py`)
+
+NVIDIA Parakeet is a state-of-the-art ASR model optimized for real-time transcription with superior word-level timestamps.
+
+**GPU Configuration:**
+- **A10G GPU** - Used for `/v1/audio/transcriptions` endpoint (small files, live transcription)
+  - Higher concurrency (max_inputs=10)
+  - Optimized for multiple small audio files
+  - Supports batch processing for efficiency
+
+- **L40S GPU** - Used for `/v1/audio/transcriptions-from-url` endpoint (large files)
+  - Lower concurrency but more powerful processing
+  - Optimized for single large audio files
+  - VAD-based chunking for long-form audio
+
+##### `/v1/audio/transcriptions` - Small file transcription
+
+**request** (multipart/form-data)
+- `file` or `files[]` - audio file(s) to transcribe
+- `model` - model name (default: `nvidia/parakeet-tdt-0.6b-v2`)
+- `language` - language code (default: `en`)
+- `batch` - whether to use batch processing for multiple files (default: `true`)
+
+**response**
+```json
+{
+  "text": "transcribed text",
+  "words": [
+    {"word": "hello", "start": 0.0, "end": 0.5},
+    {"word": "world", "start": 0.5, "end": 1.0}
+  ],
+  "filename": "audio.mp3"
+}
+```
+
+For multiple files with batch=true:
+```json
+{
+  "results": [
+    {
+      "filename": "audio1.mp3",
+      "text": "transcribed text",
+      "words": [...]
+    },
+    {
+      "filename": "audio2.mp3",
+      "text": "transcribed text",
+      "words": [...]
+    }
+  ]
+}
```
+
+##### `/v1/audio/transcriptions-from-url` - Large file transcription
+
+**request** (application/json)
+```json
+{
+  "audio_file_url": "https://example.com/audio.mp3",
+  "model": "nvidia/parakeet-tdt-0.6b-v2",
+  "language": "en",
+  "timestamp_offset": 0.0
+}
+```
+
+**response**
+```json
+{
+  "text": "transcribed text from large file",
+  "words": [
+    {"word": "hello", "start": 0.0, "end": 0.5},
+    {"word": "world", "start": 0.5, "end": 1.0}
+  ]
+}
+```
+
+**Supported file types:** mp3, mp4, mpeg, mpga, m4a, wav, webm
+
+#### Whisper Transcriber (`reflector_transcriber.py`)
+
 `POST /transcribe`
 
 **request** (multipart/form-data)
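Given the request and response shapes documented in the README section above, a client call against the large-file endpoint might look like the following sketch. The base URL and API key are placeholders; the endpoint path and body fields are taken directly from the documentation:

```python
import requests

# Placeholder deployment URL and API key; substitute the values from `modal deploy`
BASE_URL = "https://xxxx--reflector-transcriber-parakeet-web.modal.run"
APIKEY = "<REFLECTOR_APIKEY>"

response = requests.post(
    f"{BASE_URL}/v1/audio/transcriptions-from-url",
    headers={"Authorization": f"bearer {APIKEY}"},
    json={
        "audio_file_url": "https://example.com/audio.mp3",
        "model": "nvidia/parakeet-tdt-0.6b-v2",
        "language": "en",
        "timestamp_offset": 0.0,
    },
)
response.raise_for_status()
result = response.json()
print(result["text"])
print(result["words"][:5])  # word-level timestamps
```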
reflector_diarizer.py

@@ -4,14 +4,80 @@ Reflector GPU backend - diarizer
 """
 
 import os
+import uuid
+from typing import Mapping, NewType
+from urllib.parse import urlparse
 
-import modal.gpu
-from modal import App, Image, Secret, asgi_app, enter, method
-from pydantic import BaseModel
+import modal
 
 PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
 MODEL_DIR = "/root/diarization_models"
-app = App(name="reflector-diarizer")
+UPLOADS_PATH = "/uploads"
+SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
+
+DiarizerUniqFilename = NewType("DiarizerUniqFilename", str)
+AudioFileExtension = NewType("AudioFileExtension", str)
+
+app = modal.App(name="reflector-diarizer")
+
+# Volume for temporary file uploads
+upload_volume = modal.Volume.from_name("diarizer-uploads", create_if_missing=True)
+
+
+def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
+    parsed_url = urlparse(url)
+    url_path = parsed_url.path
+
+    for ext in SUPPORTED_FILE_EXTENSIONS:
+        if url_path.lower().endswith(f".{ext}"):
+            return AudioFileExtension(ext)
+
+    content_type = headers.get("content-type", "").lower()
+    if "audio/mpeg" in content_type or "audio/mp3" in content_type:
+        return AudioFileExtension("mp3")
+    if "audio/wav" in content_type:
+        return AudioFileExtension("wav")
+    if "audio/mp4" in content_type:
+        return AudioFileExtension("mp4")
+
+    raise ValueError(
+        f"Unsupported audio format for URL: {url}. "
+        f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
+    )
+
+
+def download_audio_to_volume(
+    audio_file_url: str,
+) -> tuple[DiarizerUniqFilename, AudioFileExtension]:
+    import requests
+    from fastapi import HTTPException
+
+    print(f"Checking audio file at: {audio_file_url}")
+    response = requests.head(audio_file_url, allow_redirects=True)
+    if response.status_code == 404:
+        raise HTTPException(status_code=404, detail="Audio file not found")
+
+    print(f"Downloading audio file from: {audio_file_url}")
+    response = requests.get(audio_file_url, allow_redirects=True)
+
+    if response.status_code != 200:
+        print(f"Download failed with status {response.status_code}: {response.text}")
+        raise HTTPException(
+            status_code=response.status_code,
+            detail=f"Failed to download audio file: {response.status_code}",
+        )
+
+    audio_suffix = detect_audio_format(audio_file_url, response.headers)
+    unique_filename = DiarizerUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
+    file_path = f"{UPLOADS_PATH}/{unique_filename}"
+
+    print(f"Writing file to: {file_path} (size: {len(response.content)} bytes)")
+    with open(file_path, "wb") as f:
+        f.write(response.content)
+
+    upload_volume.commit()
+    print(f"File saved as: {unique_filename}")
+    return unique_filename, audio_suffix
+
 
 def migrate_cache_llm():
@@ -39,7 +105,7 @@ def download_pyannote_audio():
 
 diarizer_image = (
-    Image.debian_slim(python_version="3.10.8")
+    modal.Image.debian_slim(python_version="3.10.8")
     .pip_install(
         "pyannote.audio==3.1.0",
         "requests",
@@ -55,7 +121,8 @@ diarizer_image = (
         "hf-transfer",
     )
     .run_function(
-        download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")]
+        download_pyannote_audio,
+        secrets=[modal.Secret.from_name("hf_token")],
     )
     .run_function(migrate_cache_llm)
     .env(
@@ -70,44 +137,51 @@ diarizer_image = (
 
 @app.cls(
-    gpu=modal.gpu.A100(size="40GB"),
+    gpu="A100",
     timeout=60 * 30,
-    scaledown_window=60,
-    allow_concurrent_inputs=1,
     image=diarizer_image,
+    volumes={UPLOADS_PATH: upload_volume},
+    enable_memory_snapshot=True,
+    experimental_options={"enable_gpu_snapshot": True},
+    secrets=[
+        modal.Secret.from_name("hf_token"),
+    ],
 )
+@modal.concurrent(max_inputs=1)
 class Diarizer:
-    @enter()
+    @modal.enter(snap=True)
     def enter(self):
         import torch
         from pyannote.audio import Pipeline
 
         self.use_gpu = torch.cuda.is_available()
         self.device = "cuda" if self.use_gpu else "cpu"
+        print(f"Using device: {self.device}")
         self.diarization_pipeline = Pipeline.from_pretrained(
-            PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR
+            PYANNOTE_MODEL_NAME,
+            cache_dir=MODEL_DIR,
+            use_auth_token=os.environ["HF_TOKEN"],
         )
         self.diarization_pipeline.to(torch.device(self.device))
 
-    @method()
-    def diarize(self, audio_data: str, audio_suffix: str, timestamp: float):
-        import tempfile
-
+    @modal.method()
+    def diarize(self, filename: str, timestamp: float = 0.0):
         import torchaudio
 
-        with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
-            fp.write(audio_data)
-            print("Diarizing audio")
-            waveform, sample_rate = torchaudio.load(fp.name)
+        upload_volume.reload()
+
+        file_path = f"{UPLOADS_PATH}/{filename}"
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(f"File not found: {file_path}")
+
+        print(f"Diarizing audio from: {file_path}")
+        waveform, sample_rate = torchaudio.load(file_path)
         diarization = self.diarization_pipeline(
             {"waveform": waveform, "sample_rate": sample_rate}
         )
 
         words = []
-        for diarization_segment, _, speaker in diarization.itertracks(
-            yield_label=True
-        ):
+        for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
             words.append(
                 {
                     "start": round(timestamp + diarization_segment.start, 3),
@@ -127,17 +201,18 @@ class Diarizer:
 @app.function(
     timeout=60 * 10,
     scaledown_window=60 * 3,
-    allow_concurrent_inputs=40,
     secrets=[
-        Secret.from_name("reflector-gpu"),
+        modal.Secret.from_name("reflector-gpu"),
     ],
+    volumes={UPLOADS_PATH: upload_volume},
     image=diarizer_image,
 )
-@asgi_app()
+@modal.concurrent(max_inputs=40)
+@modal.asgi_app()
 def web():
-    import requests
     from fastapi import Depends, FastAPI, HTTPException, status
     from fastapi.security import OAuth2PasswordBearer
+    from pydantic import BaseModel
 
     diarizerstub = Diarizer()
 
@@ -153,35 +228,26 @@ def web():
             headers={"WWW-Authenticate": "Bearer"},
         )
 
-    def validate_audio_file(audio_file_url: str):
-        # Check if the audio file exists
-        response = requests.head(audio_file_url, allow_redirects=True)
-        if response.status_code == 404:
-            raise HTTPException(
-                status_code=response.status_code,
-                detail="The audio file does not exist.",
-            )
-
     class DiarizationResponse(BaseModel):
         result: dict
 
-    @app.post(
-        "/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]
-    )
-    def diarize(
-        audio_file_url: str, timestamp: float = 0.0
-    ) -> HTTPException | DiarizationResponse:
-        # Currently the uploaded files are in mp3 format
-        audio_suffix = "mp3"
-
-        print("Downloading audio file")
-        response = requests.get(audio_file_url, allow_redirects=True)
-        print("Audio file downloaded successfully")
-
-        func = diarizerstub.diarize.spawn(
-            audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp
-        )
-        result = func.get()
-        return result
+    @app.post("/diarize", dependencies=[Depends(apikey_auth)])
+    def diarize(audio_file_url: str, timestamp: float = 0.0) -> DiarizationResponse:
+        unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
+
+        try:
+            func = diarizerstub.diarize.spawn(
+                filename=unique_filename, timestamp=timestamp
+            )
+            result = func.get()
+            return result
+        finally:
+            try:
+                file_path = f"{UPLOADS_PATH}/{unique_filename}"
+                print(f"Deleting file: {file_path}")
+                os.remove(file_path)
+                upload_volume.commit()
+            except Exception as e:
+                print(f"Error cleaning up {unique_filename}: {e}")
 
     return app
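The diarizer change above replaces passing raw audio bytes into `diarize()` with a hand-off through a shared Modal Volume: the web function downloads the file and calls `upload_volume.commit()`, and the GPU container calls `upload_volume.reload()` before reading. A minimal sketch of that pattern, using illustrative names (`producer`, `consumer`, and the volume name are hypothetical):

```python
import modal

app = modal.App("volume-handoff-sketch")
shared = modal.Volume.from_name("sketch-uploads", create_if_missing=True)
UPLOADS = "/uploads"


@app.function(volumes={UPLOADS: shared})
def producer(payload: bytes) -> str:
    # Write the file, then commit so other containers can observe it
    path = f"{UPLOADS}/input.bin"
    with open(path, "wb") as f:
        f.write(payload)
    shared.commit()
    return "input.bin"


@app.function(volumes={UPLOADS: shared})
def consumer(filename: str) -> int:
    # Reload to pick up commits made after this container started
    shared.reload()
    with open(f"{UPLOADS}/{filename}", "rb") as f:
        return len(f.read())
```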
server/gpu/modal_deployments/reflector_transcriber_parakeet.py (new file, 622 lines)

@@ -0,0 +1,622 @@
import logging
import os
import sys
import threading
import uuid
from typing import Mapping, NewType
from urllib.parse import urlparse

import modal

MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
SAMPLERATE = 16000
UPLOADS_PATH = "/uploads"
CACHE_PATH = "/cache"
VAD_CONFIG = {
    "max_segment_duration": 30.0,
    "batch_max_files": 10,
    "batch_max_duration": 5.0,
    "min_segment_duration": 0.02,
    "silence_padding": 0.5,
    "window_size": 512,
}

ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
AudioFileExtension = NewType("AudioFileExtension", str)

app = modal.App("reflector-transcriber-parakeet")

# Volume for caching model weights
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
# Volume for temporary file uploads
upload_volume = modal.Volume.from_name("parakeet-uploads", create_if_missing=True)

image = (
    modal.Image.from_registry(
        "nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", add_python="3.12"
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",
            "HF_HOME": "/cache",
            "DEBIAN_FRONTEND": "noninteractive",
            "CXX": "g++",
            "CC": "g++",
        }
    )
    .apt_install("ffmpeg")
    .pip_install(
        "hf_transfer==0.1.9",
        "huggingface_hub[hf-xet]==0.31.2",
        "nemo_toolkit[asr]==2.3.0",
        "cuda-python==12.8.0",
        "fastapi==0.115.12",
        "numpy<2",
        "librosa==0.10.1",
        "requests",
        "silero-vad==5.1.0",
        "torch",
    )
    .entrypoint([])  # silence chatty logs by container on start
)


def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
    parsed_url = urlparse(url)
    url_path = parsed_url.path

    for ext in SUPPORTED_FILE_EXTENSIONS:
        if url_path.lower().endswith(f".{ext}"):
            return AudioFileExtension(ext)

    content_type = headers.get("content-type", "").lower()
    if "audio/mpeg" in content_type or "audio/mp3" in content_type:
        return AudioFileExtension("mp3")
    if "audio/wav" in content_type:
        return AudioFileExtension("wav")
    if "audio/mp4" in content_type:
        return AudioFileExtension("mp4")

    raise ValueError(
        f"Unsupported audio format for URL: {url}. "
        f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
    )


def download_audio_to_volume(
    audio_file_url: str,
) -> tuple[ParakeetUniqFilename, AudioFileExtension]:
    import requests
    from fastapi import HTTPException

    response = requests.head(audio_file_url, allow_redirects=True)
    if response.status_code == 404:
        raise HTTPException(status_code=404, detail="Audio file not found")

    response = requests.get(audio_file_url, allow_redirects=True)
    response.raise_for_status()

    audio_suffix = detect_audio_format(audio_file_url, response.headers)
    unique_filename = ParakeetUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
    file_path = f"{UPLOADS_PATH}/{unique_filename}"

    with open(file_path, "wb") as f:
        f.write(response.content)

    upload_volume.commit()
    return unique_filename, audio_suffix


def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
    """Add 0.5 seconds of silence if audio is less than 500ms.

    This is a workaround for a Parakeet bug where very short audio (<500ms) causes:
    ValueError: `char_offsets`: [] and `processed_tokens`: [157, 834, 834, 841]
    have to be of the same length

    See: https://github.com/NVIDIA/NeMo/issues/8451
    """
    import numpy as np

    audio_duration = len(audio_array) / sample_rate
    if audio_duration < 0.5:
        silence_samples = int(sample_rate * 0.5)
        silence = np.zeros(silence_samples, dtype=np.float32)
        return np.concatenate([audio_array, silence])
    return audio_array


@app.cls(
    gpu="A10G",
    timeout=600,
    scaledown_window=300,
    image=image,
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
    enable_memory_snapshot=True,
    experimental_options={"enable_gpu_snapshot": True},
)
@modal.concurrent(max_inputs=10)
class TranscriberParakeetLive:
    @modal.enter(snap=True)
    def enter(self):
        import nemo.collections.asr as nemo_asr

        logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)

        self.lock = threading.Lock()
        self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
        device = next(self.model.parameters()).device
        print(f"Model is on device: {device}")

    @modal.method()
    def transcribe_segment(
        self,
        filename: str,
    ):
        import librosa

        upload_volume.reload()

        file_path = f"{UPLOADS_PATH}/{filename}"
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
        padded_audio = pad_audio(audio_array, sample_rate)

        with self.lock:
            with NoStdStreams():
                (output,) = self.model.transcribe([padded_audio], timestamps=True)

        text = output.text.strip()
        words = [
            {
                "word": word_info["word"],
                "start": round(word_info["start"], 2),
                "end": round(word_info["end"], 2),
            }
            for word_info in output.timestamp["word"]
        ]

        return {"text": text, "words": words}

    @modal.method()
    def transcribe_batch(
        self,
        filenames: list[str],
    ):
        import librosa

        upload_volume.reload()

        results = []
        audio_arrays = []

        # Load all audio files with padding
        for filename in filenames:
            file_path = f"{UPLOADS_PATH}/{filename}"
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Batch file not found: {file_path}")

            audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
            padded_audio = pad_audio(audio_array, sample_rate)
            audio_arrays.append(padded_audio)

        with self.lock:
            with NoStdStreams():
                outputs = self.model.transcribe(audio_arrays, timestamps=True)

        # Process results for each file
        for i, (filename, output) in enumerate(zip(filenames, outputs)):
            text = output.text.strip()

            words = [
                {
                    "word": word_info["word"],
                    "start": round(word_info["start"], 2),
                    "end": round(word_info["end"], 2),
                }
                for word_info in output.timestamp["word"]
            ]

            results.append(
                {
                    "filename": filename,
                    "text": text,
                    "words": words,
                }
            )

        return results


# L40S class for file transcription (bigger files)
@app.cls(
    gpu="L40S",
    timeout=900,
    image=image,
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
    enable_memory_snapshot=True,
    experimental_options={"enable_gpu_snapshot": True},
)
class TranscriberParakeetFile:
    @modal.enter(snap=True)
    def enter(self):
        import nemo.collections.asr as nemo_asr
        import torch
        from silero_vad import load_silero_vad

        logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)

        self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
        device = next(self.model.parameters()).device
        print(f"Model is on device: {device}")

        torch.set_num_threads(1)
        self.vad_model = load_silero_vad(onnx=False)
        print("Silero VAD initialized")

    @modal.method()
    def transcribe_segment(
        self,
        filename: str,
        timestamp_offset: float = 0.0,
    ):
        import librosa
        import numpy as np
        from silero_vad import VADIterator

        def load_and_convert_audio(file_path):
            audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
            return audio_array

        def vad_segment_generator(audio_array):
            """Generate speech segments using VAD with start/end sample indices"""
            vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
            window_size = VAD_CONFIG["window_size"]
            start = None

            for i in range(0, len(audio_array), window_size):
                chunk = audio_array[i : i + window_size]
                if len(chunk) < window_size:
                    chunk = np.pad(
                        chunk, (0, window_size - len(chunk)), mode="constant"
                    )

                speech_dict = vad_iterator(chunk)
                if not speech_dict:
                    continue

                if "start" in speech_dict:
                    start = speech_dict["start"]
                    continue

                if "end" in speech_dict and start is not None:
                    end = speech_dict["end"]
                    start_time = start / float(SAMPLERATE)
                    end_time = end / float(SAMPLERATE)

                    # Extract the actual audio segment
                    audio_segment = audio_array[start:end]

                    yield (start_time, end_time, audio_segment)
                    start = None

            vad_iterator.reset_states()

        def vad_segment_filter(segments):
            """Filter VAD segments by duration and chunk large segments"""
            min_dur = VAD_CONFIG["min_segment_duration"]
            max_dur = VAD_CONFIG["max_segment_duration"]

            for start_time, end_time, audio_segment in segments:
                segment_duration = end_time - start_time

                # Skip very small segments
                if segment_duration < min_dur:
                    continue

                # If segment is within max duration, yield as-is
                if segment_duration <= max_dur:
                    yield (start_time, end_time, audio_segment)
                    continue

                # Chunk large segments into smaller pieces
                chunk_samples = int(max_dur * SAMPLERATE)
                current_start = start_time

                for chunk_offset in range(0, len(audio_segment), chunk_samples):
                    chunk_audio = audio_segment[
                        chunk_offset : chunk_offset + chunk_samples
                    ]
                    if len(chunk_audio) == 0:
                        break

                    chunk_duration = len(chunk_audio) / float(SAMPLERATE)
                    chunk_end = current_start + chunk_duration

                    # Only yield chunks that meet minimum duration
                    if chunk_duration >= min_dur:
                        yield (current_start, chunk_end, chunk_audio)

                    current_start = chunk_end

        def batch_segments(segments, max_files=10, max_duration=5.0):
            batch = []
            batch_duration = 0.0

            for start_time, end_time, audio_segment in segments:
                segment_duration = end_time - start_time

                if segment_duration < VAD_CONFIG["silence_padding"]:
                    silence_samples = int(
                        (VAD_CONFIG["silence_padding"] - segment_duration) * SAMPLERATE
                    )
                    padding = np.zeros(silence_samples, dtype=np.float32)
                    audio_segment = np.concatenate([audio_segment, padding])
                    segment_duration = VAD_CONFIG["silence_padding"]

                batch.append((start_time, end_time, audio_segment))
                batch_duration += segment_duration

                if len(batch) >= max_files or batch_duration >= max_duration:
                    yield batch
                    batch = []
                    batch_duration = 0.0

            if batch:
                yield batch

        def transcribe_batch(model, audio_segments):
            with NoStdStreams():
                outputs = model.transcribe(audio_segments, timestamps=True)
                return outputs

        def emit_results(
            results,
            segments_info,
            batch_index,
            total_batches,
        ):
            """Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
            for i, (output, (start_time, end_time, _)) in enumerate(
                zip(results, segments_info)
            ):
                text = output.text.strip()
                words = [
                    {
                        "word": word_info["word"],
                        "start": round(
                            word_info["start"] + start_time + timestamp_offset, 2
                        ),
                        "end": round(
                            word_info["end"] + start_time + timestamp_offset, 2
                        ),
                    }
                    for word_info in output.timestamp["word"]
                ]

                yield text, words

        upload_volume.reload()

        file_path = f"{UPLOADS_PATH}/{filename}"
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        audio_array = load_and_convert_audio(file_path)
        total_duration = len(audio_array) / float(SAMPLERATE)
        processed_duration = 0.0

        all_text_parts = []
        all_words = []

        raw_segments = vad_segment_generator(audio_array)
        filtered_segments = vad_segment_filter(raw_segments)
        batches = batch_segments(
            filtered_segments,
            VAD_CONFIG["batch_max_files"],
            VAD_CONFIG["batch_max_duration"],
        )

        batch_index = 0
        total_batches = max(
            1, int(total_duration / VAD_CONFIG["batch_max_duration"]) + 1
        )

        for batch in batches:
            batch_index += 1
            audio_segments = [seg[2] for seg in batch]
            results = transcribe_batch(self.model, audio_segments)

            for text, words in emit_results(
                results,
                batch,
                batch_index,
                total_batches,
            ):
                if not text:
                    continue
                all_text_parts.append(text)
                all_words.extend(words)

            processed_duration += sum(len(seg[2]) / float(SAMPLERATE) for seg in batch)

        combined_text = " ".join(all_text_parts)
        return {"text": combined_text, "words": all_words}


@app.function(
    scaledown_window=60,
    timeout=600,
    secrets=[
        modal.Secret.from_name("reflector-gpu"),
    ],
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
    image=image,
)
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
def web():
    import os
    import uuid

    from fastapi import (
        Body,
        Depends,
        FastAPI,
        Form,
        HTTPException,
        UploadFile,
        status,
    )
    from fastapi.security import OAuth2PasswordBearer
    from pydantic import BaseModel

    transcriber_live = TranscriberParakeetLive()
    transcriber_file = TranscriberParakeetFile()

    app = FastAPI()

    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

    def apikey_auth(apikey: str = Depends(oauth2_scheme)):
        if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
            return
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )

    class TranscriptResponse(BaseModel):
        result: dict

    @app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
    def transcribe(
        file: UploadFile = None,
        files: list[UploadFile] | None = None,
        model: str = Form(MODEL_NAME),
        language: str = Form("en"),
        batch: bool = Form(False),
    ):
        # Parakeet only supports English
        if language != "en":
            raise HTTPException(
                status_code=400,
                detail=f"Parakeet model only supports English. Got language='{language}'",
            )
        # Handle both single file and multiple files
        if not file and not files:
            raise HTTPException(
                status_code=400, detail="Either 'file' or 'files' parameter is required"
            )
        if batch and not files:
            raise HTTPException(
                status_code=400, detail="Batch transcription requires 'files'"
            )

        upload_files = [file] if file else files

        # Upload files to volume
        uploaded_filenames = []
        for upload_file in upload_files:
            audio_suffix = upload_file.filename.split(".")[-1]
            assert audio_suffix in SUPPORTED_FILE_EXTENSIONS

            # Generate unique filename
            unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
            file_path = f"{UPLOADS_PATH}/{unique_filename}"

            print(f"Writing file to: {file_path}")
            with open(file_path, "wb") as f:
                content = upload_file.file.read()
                f.write(content)

            uploaded_filenames.append(unique_filename)

        upload_volume.commit()

        try:
            # Use A10G live transcriber for per-file transcription
            if batch and len(upload_files) > 1:
                # Use batch transcription
                func = transcriber_live.transcribe_batch.spawn(
                    filenames=uploaded_filenames,
                )
                results = func.get()
                return {"results": results}

            # Per-file transcription
            results = []
            for filename in uploaded_filenames:
                func = transcriber_live.transcribe_segment.spawn(
                    filename=filename,
                )
                result = func.get()
                result["filename"] = filename
                results.append(result)

            return {"results": results} if len(results) > 1 else results[0]

        finally:
            for filename in uploaded_filenames:
                try:
                    file_path = f"{UPLOADS_PATH}/{filename}"
                    print(f"Deleting file: {file_path}")
                    os.remove(file_path)
                except Exception as e:
                    print(f"Error deleting {filename}: {e}")

            upload_volume.commit()

    @app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
    def transcribe_from_url(
        audio_file_url: str = Body(
            ..., description="URL of the audio file to transcribe"
        ),
        model: str = Body(MODEL_NAME),
        language: str = Body("en", description="Language code (only 'en' supported)"),
        timestamp_offset: float = Body(0.0),
    ):
        # Parakeet only supports English
        if language != "en":
            raise HTTPException(
                status_code=400,
                detail=f"Parakeet model only supports English. Got language='{language}'",
            )
        unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)

        try:
            func = transcriber_file.transcribe_segment.spawn(
                filename=unique_filename,
                timestamp_offset=timestamp_offset,
            )
            result = func.get()
            return result
        finally:
            try:
                file_path = f"{UPLOADS_PATH}/{unique_filename}"
                print(f"Deleting file: {file_path}")
                os.remove(file_path)
                upload_volume.commit()
            except Exception as e:
                print(f"Error cleaning up {unique_filename}: {e}")

    return app


class NoStdStreams:
    def __init__(self):
        self.devnull = open(os.devnull, "w")

    def __enter__(self):
        self._stdout, self._stderr = sys.stdout, sys.stderr
        self._stdout.flush()
        self._stderr.flush()
        sys.stdout, sys.stderr = self.devnull, self.devnull

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout, sys.stderr = self._stdout, self._stderr
        self.devnull.close()
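For reference on the VAD settings in the file above: at SAMPLERATE = 16000, the 512-sample window corresponds to 512 / 16000 = 32 ms of audio per step, which is the chunk size Silero VAD expects for 16 kHz input.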
pyproject.toml

@@ -32,7 +32,6 @@ dependencies = [
     "redis>=5.0.1",
     "python-jose[cryptography]>=3.3.0",
     "python-multipart>=0.0.6",
-    "faster-whisper>=0.10.0",
     "transformers>=4.36.2",
     "jsonschema>=4.23.0",
     "openai>=1.59.7",
@@ -41,6 +40,7 @@ dependencies = [
     "llama-index-llms-openai-like>=0.4.0",
     "pytest-env>=1.1.5",
     "webvtt-py>=0.5.0",
+    "silero-vad>=5.1.2",
 ]
 
 [dependency-groups]
@@ -57,6 +57,7 @@ tests = [
     "httpx-ws>=0.4.1",
     "pytest-httpx>=0.23.1",
     "pytest-celery>=0.0.0",
+    "pytest-recording>=0.13.4",
     "pytest-docker>=3.2.3",
     "asgi-lifespan>=2.1.0",
 ]
@@ -67,6 +68,10 @@ evaluation = [
     "tqdm>=4.66.0",
     "pydantic>=2.1.1",
 ]
+local = [
+    "pyannote-audio>=3.3.2",
+    "faster-whisper>=0.10.0",
+]
 
 [tool.uv]
 default-groups = [
@@ -74,6 +79,7 @@ default-groups = [
     "tests",
     "aws",
     "evaluation",
+    "local"
 ]
 
 [build-system]
@@ -94,6 +100,9 @@ DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_t
 addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
 testpaths = ["tests"]
 asyncio_mode = "auto"
+markers = [
+    "gpu_modal: mark test to run only with GPU Modal endpoints (deselect with '-m \"not gpu_modal\"')",
+]
 
 [tool.ruff.lint]
 select = [
server/reflector/pipelines/main_file_pipeline.py (new file, 375 lines)

@@ -0,0 +1,375 @@
"""
File-based processing pipeline
==============================

Optimized pipeline for processing complete audio/video files.
Uses parallel processing for transcription, diarization, and waveform generation.
"""

import asyncio
from pathlib import Path

import av
import structlog
from celery import shared_task

from reflector.db.transcripts import (
    Transcript,
    transcripts_controller,
)
from reflector.logger import logger
from reflector.pipelines.main_live_pipeline import PipelineMainBase, asynctask
from reflector.processors import (
    AudioFileWriterProcessor,
    TranscriptFinalSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptTopicDetectorProcessor,
)
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.file_diarization import FileDiarizationInput
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.processors.file_transcript import FileTranscriptInput
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.transcript_diarization_assembler import (
    TranscriptDiarizationAssemblerInput,
    TranscriptDiarizationAssemblerProcessor,
)
from reflector.processors.types import (
    DiarizationSegment,
    TitleSummary,
)
from reflector.processors.types import (
    Transcript as TranscriptType,
)
from reflector.settings import settings
from reflector.storage import get_transcripts_storage


class EmptyPipeline:
    """Empty pipeline for processors that need a pipeline reference"""

    def __init__(self, logger: structlog.BoundLogger):
        self.logger = logger

    def get_pref(self, k, d=None):
        return d

    async def emit(self, event):
        pass


class PipelineMainFile(PipelineMainBase):
    """
    Optimized file processing pipeline.
    Processes complete audio/video files with parallel execution.
    """

    logger: structlog.BoundLogger = None
    empty_pipeline = None

    def __init__(self, transcript_id: str):
        super().__init__(transcript_id=transcript_id)
        self.logger = logger.bind(transcript_id=self.transcript_id)
        self.empty_pipeline = EmptyPipeline(logger=self.logger)

    def _handle_gather_exceptions(self, results: list, operation: str) -> None:
        """Handle exceptions from asyncio.gather with return_exceptions=True"""
        for i, result in enumerate(results):
            if not isinstance(result, Exception):
                continue
            self.logger.error(
                f"Error in {operation} (task {i}): {result}",
                transcript_id=self.transcript_id,
                exc_info=result,
            )

    async def process(self, file_path: Path):
        """Main entry point for file processing"""
        self.logger.info(f"Starting file pipeline for {file_path}")

        transcript = await self.get_transcript()

        # Extract audio and write to transcript location
        audio_path = await self.extract_and_write_audio(file_path, transcript)

        # Upload for processing
        audio_url = await self.upload_audio(audio_path, transcript)

        # Run parallel processing
        await self.run_parallel_processing(
            audio_path,
            audio_url,
            transcript.source_language,
            transcript.target_language,
        )

        self.logger.info("File pipeline complete")

    async def extract_and_write_audio(
        self, file_path: Path, transcript: Transcript
    ) -> Path:
        """Extract audio from video if needed and write to transcript location as MP3"""
        self.logger.info(f"Processing audio file: {file_path}")

        # Check if it's already audio-only
        container = av.open(str(file_path))
        has_video = len(container.streams.video) > 0
        container.close()

        # Use AudioFileWriterProcessor to write MP3 to transcript location
        mp3_writer = AudioFileWriterProcessor(
            path=transcript.audio_mp3_filename,
            on_duration=self.on_duration,
        )

        # Process audio frames and write to transcript location
        input_container = av.open(str(file_path))
        for frame in input_container.decode(audio=0):
            await mp3_writer.push(frame)

        await mp3_writer.flush()
        input_container.close()

        if has_video:
            self.logger.info(
                f"Extracted audio from video and saved to {transcript.audio_mp3_filename}"
            )
        else:
            self.logger.info(
                f"Converted audio file and saved to {transcript.audio_mp3_filename}"
            )

        return transcript.audio_mp3_filename

    async def upload_audio(self, audio_path: Path, transcript: Transcript) -> str:
        """Upload audio to storage for processing"""
        storage = get_transcripts_storage()

        if not storage:
            raise Exception(
                "Storage backend required for file processing. Configure TRANSCRIPT_STORAGE_* settings."
            )

        self.logger.info("Uploading audio to storage")

        with open(audio_path, "rb") as f:
            audio_data = f.read()

        storage_path = f"file_pipeline/{transcript.id}/audio.mp3"
        await storage.put_file(storage_path, audio_data)

        audio_url = await storage.get_file_url(storage_path)

        self.logger.info(f"Audio uploaded to {audio_url}")
        return audio_url

    async def run_parallel_processing(
        self,
        audio_path: Path,
        audio_url: str,
        source_language: str,
        target_language: str,
    ):
        """Coordinate parallel processing of transcription, diarization, and waveform"""
        self.logger.info(
            "Starting parallel processing", transcript_id=self.transcript_id
        )

        # Phase 1: Parallel processing of independent tasks
        transcription_task = self.transcribe_file(audio_url, source_language)
        diarization_task = self.diarize_file(audio_url)
        waveform_task = self.generate_waveform(audio_path)

        results = await asyncio.gather(
            transcription_task, diarization_task, waveform_task, return_exceptions=True
        )

        transcript_result = results[0]
        diarization_result = results[1]

        # Handle errors - raise any exception that occurred
        self._handle_gather_exceptions(results, "parallel processing")
        for result in results:
            if isinstance(result, Exception):
                raise result

        # Phase 2: Assemble transcript with diarization
        self.logger.info(
            "Assembling transcript with diarization", transcript_id=self.transcript_id
        )
        processor = TranscriptDiarizationAssemblerProcessor()
        input_data = TranscriptDiarizationAssemblerInput(
            transcript=transcript_result, diarization=diarization_result or []
        )

        # Store result for retrieval
        diarized_transcript: Transcript | None = None

        async def capture_result(transcript):
            nonlocal diarized_transcript
            diarized_transcript = transcript

        processor.on(capture_result)
        await processor.push(input_data)
        await processor.flush()

        if not diarized_transcript:
            raise ValueError("No diarized transcript captured")

        # Phase 3: Generate topics from diarized transcript
        self.logger.info("Generating topics", transcript_id=self.transcript_id)
        topics = await self.detect_topics(diarized_transcript, target_language)

        # Phase 4: Generate title and summaries in parallel
        self.logger.info(
            "Generating title and summaries", transcript_id=self.transcript_id
        )
        results = await asyncio.gather(
            self.generate_title(topics),
            self.generate_summaries(topics),
            return_exceptions=True,
        )

        self._handle_gather_exceptions(results, "title and summary generation")

    async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
|
||||||
|
"""Transcribe complete file"""
|
||||||
|
processor = FileTranscriptAutoProcessor()
|
||||||
|
input_data = FileTranscriptInput(audio_url=audio_url, language=language)
|
||||||
|
|
||||||
|
# Store result for retrieval
|
||||||
|
result: TranscriptType | None = None
|
||||||
|
|
||||||
|
async def capture_result(transcript):
|
||||||
|
nonlocal result
|
||||||
|
result = transcript
|
||||||
|
|
||||||
|
processor.on(capture_result)
|
||||||
|
await processor.push(input_data)
|
||||||
|
await processor.flush()
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
raise ValueError("No transcript captured")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def diarize_file(self, audio_url: str) -> list[DiarizationSegment] | None:
|
||||||
|
"""Get diarization for file"""
|
||||||
|
if not settings.DIARIZATION_BACKEND:
|
||||||
|
self.logger.info("Diarization disabled")
|
||||||
|
return None
|
||||||
|
|
||||||
|
processor = FileDiarizationAutoProcessor()
|
||||||
|
input_data = FileDiarizationInput(audio_url=audio_url)
|
||||||
|
|
||||||
|
# Store result for retrieval
|
||||||
|
result = None
|
||||||
|
|
||||||
|
async def capture_result(diarization_output):
|
||||||
|
nonlocal result
|
||||||
|
result = diarization_output.diarization
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor.on(capture_result)
|
||||||
|
await processor.push(input_data)
|
||||||
|
await processor.flush()
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Diarization failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def generate_waveform(self, audio_path: Path):
|
||||||
|
"""Generate and save waveform"""
|
||||||
|
transcript = await self.get_transcript()
|
||||||
|
|
||||||
|
processor = AudioWaveformProcessor(
|
||||||
|
audio_path=audio_path,
|
||||||
|
waveform_path=transcript.audio_waveform_filename,
|
||||||
|
on_waveform=self.on_waveform,
|
||||||
|
)
|
||||||
|
processor.set_pipeline(self.empty_pipeline)
|
||||||
|
|
||||||
|
await processor.flush()
|
||||||
|
|
||||||
|
async def detect_topics(
|
||||||
|
self, transcript: TranscriptType, target_language: str
|
||||||
|
) -> list[TitleSummary]:
|
||||||
|
"""Detect topics from complete transcript"""
|
||||||
|
chunk_size = 300
|
||||||
|
topics: list[TitleSummary] = []
|
||||||
|
|
||||||
|
async def on_topic(topic: TitleSummary):
|
||||||
|
topics.append(topic)
|
||||||
|
return await self.on_topic(topic)
|
||||||
|
|
||||||
|
topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
|
||||||
|
topic_detector.set_pipeline(self.empty_pipeline)
|
||||||
|
|
||||||
|
for i in range(0, len(transcript.words), chunk_size):
|
||||||
|
chunk_words = transcript.words[i : i + chunk_size]
|
||||||
|
if not chunk_words:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk_transcript = TranscriptType(
|
||||||
|
words=chunk_words, translation=transcript.translation
|
||||||
|
)
|
||||||
|
|
||||||
|
await topic_detector.push(chunk_transcript)
|
||||||
|
|
||||||
|
await topic_detector.flush()
|
||||||
|
return topics
|
||||||
|
|
||||||
|
async def generate_title(self, topics: list[TitleSummary]):
|
||||||
|
"""Generate title from topics"""
|
||||||
|
if not topics:
|
||||||
|
self.logger.warning("No topics for title generation")
|
||||||
|
return
|
||||||
|
|
||||||
|
processor = TranscriptFinalTitleProcessor(callback=self.on_title)
|
||||||
|
processor.set_pipeline(self.empty_pipeline)
|
||||||
|
|
||||||
|
for topic in topics:
|
||||||
|
await processor.push(topic)
|
||||||
|
|
||||||
|
await processor.flush()
|
||||||
|
|
||||||
|
async def generate_summaries(self, topics: list[TitleSummary]):
|
||||||
|
"""Generate long and short summaries from topics"""
|
||||||
|
if not topics:
|
||||||
|
self.logger.warning("No topics for summary generation")
|
||||||
|
return
|
||||||
|
|
||||||
|
transcript = await self.get_transcript()
|
||||||
|
processor = TranscriptFinalSummaryProcessor(
|
||||||
|
transcript=transcript,
|
||||||
|
callback=self.on_long_summary,
|
||||||
|
on_short_summary=self.on_short_summary,
|
||||||
|
)
|
||||||
|
processor.set_pipeline(self.empty_pipeline)
|
||||||
|
|
||||||
|
for topic in topics:
|
||||||
|
await processor.push(topic)
|
||||||
|
|
||||||
|
await processor.flush()
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task
|
||||||
|
@asynctask
|
||||||
|
async def task_pipeline_file_process(*, transcript_id: str):
|
||||||
|
"""Celery task for file pipeline processing"""
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||||
|
if not transcript:
|
||||||
|
raise Exception(f"Transcript {transcript_id} not found")
|
||||||
|
|
||||||
|
# Find the file to process
|
||||||
|
audio_file = next(transcript.data_path.glob("upload.*"), None)
|
||||||
|
if not audio_file:
|
||||||
|
audio_file = next(transcript.data_path.glob("audio.*"), None)
|
||||||
|
|
||||||
|
if not audio_file:
|
||||||
|
raise Exception("No audio file found to process")
|
||||||
|
|
||||||
|
# Run file pipeline
|
||||||
|
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||||
|
await pipeline.process(audio_file)
|
||||||
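For orientation, a minimal usage sketch of the entry point above. The transcript id and file path are illustrative, and the direct-await form is only something you would do in a test; the Celery invocation assumes a configured broker:

# Minimal usage sketch (illustrative transcript id; assumes Celery is configured).
# Queue the whole file pipeline as a background task:
task_pipeline_file_process.delay(transcript_id="abc123")

# Or, in an async test, drive the pipeline directly without Celery:
# pipeline = PipelineMainFile(transcript_id="abc123")
# await pipeline.process(Path("/data/abc123/upload.mp4"))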
@@ -147,15 +147,18 @@ class StrValue(BaseModel):
 class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
-    transcript_id: str
-    ws_room_id: str | None = None
-    ws_manager: WebsocketManager | None = None
-
-    def prepare(self):
-        # prepare websocket
-        self._lock = asyncio.Lock()
+    def __init__(self, transcript_id: str):
+        super().__init__()
+        self._lock = asyncio.Lock()
+        self.transcript_id = transcript_id
         self.ws_room_id = f"ts:{self.transcript_id}"
-        self.ws_manager = get_ws_manager()
+        self._ws_manager = None
+
+    @property
+    def ws_manager(self) -> WebsocketManager:
+        if self._ws_manager is None:
+            self._ws_manager = get_ws_manager()
+        return self._ws_manager
 
     async def get_transcript(self) -> Transcript:
         # fetch the transcript
@@ -355,7 +358,6 @@ class PipelineMainLive(PipelineMainBase):
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
-        self.prepare()
         transcript = await self.get_transcript()
 
         processors = [
@@ -376,6 +378,7 @@ class PipelineMainLive(PipelineMainBase):
         pipeline.set_pref("audio:target_language", transcript.target_language)
         pipeline.logger.bind(transcript_id=transcript.id)
         pipeline.logger.info("Pipeline main live created")
+        pipeline.describe()
 
         return pipeline
 
@@ -394,7 +397,6 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
        # add a customised logger to the context
-        self.prepare()
         pipeline = Pipeline(
             AudioDiarizationAutoProcessor(callback=self.on_topic),
         )
@@ -435,8 +437,6 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
         raise NotImplementedError
 
     async def create(self) -> Pipeline:
-        self.prepare()
-
         # get transcript
         self._transcript = transcript = await self.get_transcript()
 
@@ -18,22 +18,14 @@ During its lifecycle, it will emit the following status:
 import asyncio
 from typing import Generic, TypeVar
 
-from pydantic import BaseModel, ConfigDict
-
 from reflector.logger import logger
 from reflector.processors import Pipeline
 
 PipelineMessage = TypeVar("PipelineMessage")
 
 
-class PipelineRunner(BaseModel, Generic[PipelineMessage]):
-    model_config = ConfigDict(arbitrary_types_allowed=True)
-
-    status: str = "idle"
-    pipeline: Pipeline | None = None
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+class PipelineRunner(Generic[PipelineMessage]):
+    def __init__(self):
         self._task = None
         self._q_cmd = asyncio.Queue(maxsize=4096)
         self._ev_done = asyncio.Event()
@@ -42,6 +34,8 @@ class PipelineRunner(BaseModel, Generic[PipelineMessage]):
             runner=id(self),
             runner_cls=self.__class__.__name__,
         )
+        self.status = "idle"
+        self.pipeline: Pipeline | None = None
 
     async def create(self) -> Pipeline:
         """
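The lazy `ws_manager` property above defers `get_ws_manager()` until first use, so constructing a pipeline object no longer requires a live websocket backend. The same pattern in isolation, as a minimal sketch with hypothetical names (`LazyExample`, `expensive_connect` are not in the diff):

# Minimal sketch of the lazy-initialization property pattern used above.
class LazyExample:
    def __init__(self):
        self._resource = None

    @property
    def resource(self):
        if self._resource is None:
            self._resource = expensive_connect()  # hypothetical factory
        return self._resource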
@@ -11,6 +11,13 @@ from .base import (  # noqa: F401
     Processor,
     ThreadedProcessor,
 )
+from .file_diarization import FileDiarizationProcessor  # noqa: F401
+from .file_diarization_auto import FileDiarizationAutoProcessor  # noqa: F401
+from .file_transcript import FileTranscriptProcessor  # noqa: F401
+from .file_transcript_auto import FileTranscriptAutoProcessor  # noqa: F401
+from .transcript_diarization_assembler import (
+    TranscriptDiarizationAssemblerProcessor,  # noqa: F401
+)
 from .transcript_final_summary import TranscriptFinalSummaryProcessor  # noqa: F401
 from .transcript_final_title import TranscriptFinalTitleProcessor  # noqa: F401
 from .transcript_liner import TranscriptLinerProcessor  # noqa: F401
@@ -1,28 +1,340 @@
+from typing import Optional
+
 import av
+import numpy as np
+import torch
+from silero_vad import VADIterator, load_silero_vad
 
 from reflector.processors.base import Processor
 
 
 class AudioChunkerProcessor(Processor):
     """
-    Assemble audio frames into chunks
+    Assemble audio frames into chunks with VAD-based speech detection
     """
 
     INPUT_TYPE = av.AudioFrame
     OUTPUT_TYPE = list[av.AudioFrame]
 
-    def __init__(self, max_frames=256):
+    def __init__(
+        self,
+        block_frames=256,
+        max_frames=1024,
+        vad_threshold=0.5,
+        use_onnx=False,
+        min_frames=2,
+    ):
         super().__init__()
         self.frames: list[av.AudioFrame] = []
+        self.block_frames = block_frames
         self.max_frames = max_frames
+        self.vad_threshold = vad_threshold
+        self.min_frames = min_frames
+
+        # Initialize Silero VAD
+        self._init_vad(use_onnx)
+
+    def _init_vad(self, use_onnx=False):
+        """Initialize Silero VAD model"""
+        try:
+            torch.set_num_threads(1)
+            self.vad_model = load_silero_vad(onnx=use_onnx)
+            self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
+            self.logger.info("Silero VAD initialized successfully")
+        except Exception as e:
+            self.logger.error(f"Failed to initialize Silero VAD: {e}")
+            self.vad_model = None
+            self.vad_iterator = None
 
     async def _push(self, data: av.AudioFrame):
         self.frames.append(data)
-        if len(self.frames) >= self.max_frames:
-            await self.flush()
+        # print("timestamp", data.pts * data.time_base * 1000)
+
+        # Check for speech segments every 32 frames (~1 second)
+        if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
+            await self._process_block()
+
+        # Safety fallback - emit if we hit max frames
+        elif len(self.frames) >= self.max_frames:
+            self.logger.warning(
+                f"AudioChunkerProcessor: Reached max frames ({self.max_frames}), "
+                f"emitting first {self.max_frames // 2} frames"
+            )
+            frames_to_emit = self.frames[: self.max_frames // 2]
+            self.frames = self.frames[self.max_frames // 2 :]
+            if len(frames_to_emit) >= self.min_frames:
+                await self.emit(frames_to_emit)
+            else:
+                self.logger.debug(
+                    f"Ignoring fallback segment with {len(frames_to_emit)} frames "
+                    f"(< {self.min_frames} minimum)"
+                )
+
+    async def _process_block(self):
+        # Need at least 32 frames for VAD detection (~1 second)
+        if len(self.frames) < 32 or self.vad_iterator is None:
+            return
+
+        # Processing block with current buffer size
+        # print(f"Processing block: {len(self.frames)} frames in buffer")
+
+        try:
+            # Convert frames to numpy array for VAD
+            audio_array = self._frames_to_numpy(self.frames)
+
+            if audio_array is None:
+                # Fallback: emit all frames if conversion failed
+                frames_to_emit = self.frames[:]
+                self.frames = []
+                if len(frames_to_emit) >= self.min_frames:
+                    await self.emit(frames_to_emit)
+                else:
+                    self.logger.debug(
+                        f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
+                        f"(< {self.min_frames} minimum)"
+                    )
+                return
+
+            # Find complete speech segments in the buffer
+            speech_end_frame = self._find_speech_segment_end(audio_array)
+
+            if speech_end_frame is None or speech_end_frame <= 0:
+                # No speech found but buffer is getting large
+                if len(self.frames) > 512:
+                    # Check if it's all silence and can be discarded
+                    # No speech segment found, buffer at {len(self.frames)} frames
+
+                    # Could emit silence or discard old frames here
+                    # For now, keep first 256 frames and discard older silence
+                    if len(self.frames) > 768:
+                        self.logger.debug(
+                            f"Discarding {len(self.frames) - 256} old frames (likely silence)"
+                        )
+                        self.frames = self.frames[-256:]
+                return
+
+            # Calculate segment timing information
+            frames_to_emit = self.frames[:speech_end_frame]
+
+            # Get timing from av.AudioFrame
+            if frames_to_emit:
+                first_frame = frames_to_emit[0]
+                last_frame = frames_to_emit[-1]
+                sample_rate = first_frame.sample_rate
+
+                # Calculate duration
+                total_samples = sum(f.samples for f in frames_to_emit)
+                duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0
+
+                # Get timestamps if available
+                start_time = (
+                    first_frame.pts * first_frame.time_base if first_frame.pts else 0
+                )
+                end_time = (
+                    last_frame.pts * last_frame.time_base if last_frame.pts else 0
+                )
+
+                # Convert to HH:MM:SS format for logging
+                def format_time(seconds):
+                    if not seconds:
+                        return "00:00:00"
+                    total_seconds = int(float(seconds))
+                    hours = total_seconds // 3600
+                    minutes = (total_seconds % 3600) // 60
+                    secs = total_seconds % 60
+                    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
+
+                start_formatted = format_time(start_time)
+                end_formatted = format_time(end_time)
+
+                # Keep remaining frames for next processing
+                remaining_after = len(self.frames) - speech_end_frame
+
+                # Single structured log line
+                self.logger.info(
+                    "Speech segment found",
+                    start=start_formatted,
+                    end=end_formatted,
+                    frames=speech_end_frame,
+                    duration=round(duration_seconds, 2),
+                    buffer_before=len(self.frames),
+                    remaining=remaining_after,
+                )
+
+            # Keep remaining frames for next processing
+            self.frames = self.frames[speech_end_frame:]
+
+            # Filter out segments with too few frames
+            if len(frames_to_emit) >= self.min_frames:
+                await self.emit(frames_to_emit)
+            else:
+                self.logger.debug(
+                    f"Ignoring segment with {len(frames_to_emit)} frames "
+                    f"(< {self.min_frames} minimum)"
+                )
+
+        except Exception as e:
+            self.logger.error(f"Error in VAD processing: {e}")
+            # Fallback to simple chunking
+            if len(self.frames) >= self.block_frames:
+                frames_to_emit = self.frames[: self.block_frames]
+                self.frames = self.frames[self.block_frames :]
+                if len(frames_to_emit) >= self.min_frames:
+                    await self.emit(frames_to_emit)
+                else:
+                    self.logger.debug(
+                        f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
+                        f"(< {self.min_frames} minimum)"
+                    )
+
+    def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
+        """Convert av.AudioFrame list to numpy array for VAD processing"""
+        if not frames:
+            return None
+
+        try:
+            first_frame = frames[0]
+            original_sample_rate = first_frame.sample_rate
+
+            audio_data = []
+            for frame in frames:
+                frame_array = frame.to_ndarray()
+
+                # Handle stereo -> mono conversion
+                if len(frame_array.shape) == 2 and frame_array.shape[0] > 1:
+                    frame_array = np.mean(frame_array, axis=0)
+                elif len(frame_array.shape) == 2:
+                    frame_array = frame_array.flatten()
+
+                audio_data.append(frame_array)
+
+            if not audio_data:
+                return None
+
+            combined_audio = np.concatenate(audio_data)
+
+            # Resample from 48kHz to 16kHz if needed
+            if original_sample_rate != 16000:
+                combined_audio = self._resample_audio(
+                    combined_audio, original_sample_rate, 16000
+                )
+
+            # Ensure float32 format
+            if combined_audio.dtype == np.int16:
+                # Normalize int16 audio to float32 in range [-1.0, 1.0]
+                combined_audio = combined_audio.astype(np.float32) / 32768.0
+            elif combined_audio.dtype != np.float32:
+                combined_audio = combined_audio.astype(np.float32)
+
+            return combined_audio
+
+        except Exception as e:
+            self.logger.error(f"Error converting frames to numpy: {e}")
+            return None
+
+    def _resample_audio(
+        self, audio: np.ndarray, from_sr: int, to_sr: int
+    ) -> np.ndarray:
+        """Simple linear resampling from from_sr to to_sr"""
+        if from_sr == to_sr:
+            return audio
+
+        try:
+            # Simple linear interpolation resampling
+            ratio = to_sr / from_sr
+            new_length = int(len(audio) * ratio)
+
+            # Create indices for interpolation
+            old_indices = np.linspace(0, len(audio) - 1, new_length)
+            resampled = np.interp(old_indices, np.arange(len(audio)), audio)
+
+            return resampled.astype(np.float32)
+
+        except Exception as e:
+            self.logger.error("Resampling error", exc_info=e)
+            # Fallback: simple decimation/repetition
+            if from_sr > to_sr:
+                # Downsample by taking every nth sample
+                step = from_sr // to_sr
+                return audio[::step]
+            else:
+                # Upsample by repeating samples
+                repeat = to_sr // from_sr
+                return np.repeat(audio, repeat)
+
+    def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
+        """Find complete speech segments and return frame index at segment end"""
+        if self.vad_iterator is None or len(audio_array) == 0:
+            return None
+
+        try:
+            # Process audio in 512-sample windows for VAD
+            window_size = 512
+            min_silence_windows = 3  # Require 3 windows of silence after speech
+
+            # Track speech state
+            in_speech = False
+            speech_start = None
+            speech_end = None
+            silence_count = 0
+
+            for i in range(0, len(audio_array), window_size):
+                chunk = audio_array[i : i + window_size]
+                if len(chunk) < window_size:
+                    chunk = np.pad(chunk, (0, window_size - len(chunk)))
+
+                # Detect if this window has speech
+                speech_dict = self.vad_iterator(chunk, return_seconds=True)
+
+                # VADIterator returns dict with 'start' and 'end' when speech segments are detected
+                if speech_dict:
+                    if not in_speech:
+                        # Speech started
+                        speech_start = i
+                        in_speech = True
+                        # Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
+                    silence_count = 0  # Reset silence counter
+                    continue
+
+                if not in_speech:
+                    continue
+
+                # We're in speech but found silence
+                silence_count += 1
+                if silence_count < min_silence_windows:
+                    continue
+
+                # Found end of speech segment
+                speech_end = i - (min_silence_windows - 1) * window_size
+                # Debug: print(f"Speech END at sample {speech_end}")
+
+                # Convert sample position to frame index
+                samples_per_frame = self.frames[0].samples if self.frames else 1024
+                # Account for resampling: we process at 16kHz but frames might be 48kHz
+                resample_ratio = 48000 / 16000  # 3x
+                actual_sample_pos = int(speech_end * resample_ratio)
+                frame_index = actual_sample_pos // samples_per_frame
+
+                # Ensure we don't exceed buffer
+                frame_index = min(frame_index, len(self.frames))
+                return frame_index
+
+            return None
+
+        except Exception as e:
+            self.logger.error(f"Error finding speech segment: {e}")
+            return None
 
     async def _flush(self):
         frames = self.frames[:]
         self.frames = []
         if frames:
-            await self.emit(frames)
+            if len(frames) >= self.min_frames:
+                await self.emit(frames)
+            else:
+                self.logger.debug(
+                    f"Ignoring flush segment with {len(frames)} frames "
+                    f"(< {self.min_frames} minimum)"
+                )
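For reference, a standalone sketch of the VADIterator streaming contract that `_find_speech_segment_end` builds on: 512-sample windows at 16 kHz, with a dict returned only at speech boundaries. It assumes the `silero-vad` package is installed; the synthetic near-silent audio is illustrative:

# Standalone sketch (assumes the silero-vad package); shows the 512-sample
# streaming contract used by the chunker above.
import numpy as np
from silero_vad import VADIterator, load_silero_vad

model = load_silero_vad()
vad = VADIterator(model, sampling_rate=16000)

audio = np.random.randn(16000).astype(np.float32) * 0.01  # ~1s of near-silence
for i in range(0, len(audio), 512):
    window = audio[i : i + 512]
    if len(window) < 512:
        window = np.pad(window, (0, 512 - len(window)))
    event = vad(window, return_seconds=True)
    if event:
        # {'start': ...} when speech begins, {'end': ...} when it ends
        print(event)
vad.reset_states()  # reset between independent streams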
@@ -1,6 +1,7 @@
 from reflector.processors.base import Processor
 from reflector.processors.types import (
     AudioDiarizationInput,
+    DiarizationSegment,
     TitleSummary,
     Word,
 )
@@ -38,7 +39,7 @@ class AudioDiarizationProcessor(Processor):
         raise NotImplementedError
 
     @classmethod
-    def assign_speaker(cls, words: list[Word], diarization: list[dict]):
+    def assign_speaker(cls, words: list[Word], diarization: list[DiarizationSegment]):
         cls._diarization_remove_overlap(diarization)
         cls._diarization_remove_segment_without_words(words, diarization)
         cls._diarization_merge_same_speaker(diarization)
@@ -65,7 +66,7 @@ class AudioDiarizationProcessor(Processor):
         return True
 
     @staticmethod
-    def _diarization_remove_overlap(diarization: list[dict]):
+    def _diarization_remove_overlap(diarization: list[DiarizationSegment]):
         """
         Remove overlap in diarization results
@@ -92,7 +93,7 @@ class AudioDiarizationProcessor(Processor):
 
     @staticmethod
     def _diarization_remove_segment_without_words(
-        words: list[Word], diarization: list[dict]
+        words: list[Word], diarization: list[DiarizationSegment]
     ):
         """
         Remove diarization segments without words
@@ -122,7 +123,7 @@ class AudioDiarizationProcessor(Processor):
         diarization_idx += 1
 
     @staticmethod
-    def _diarization_merge_same_speaker(diarization: list[dict]):
+    def _diarization_merge_same_speaker(diarization: list[DiarizationSegment]):
         """
         Merge contiguous diarization segments with the same speaker
@@ -140,7 +141,9 @@ class AudioDiarizationProcessor(Processor):
         diarization_idx += 1
 
     @classmethod
-    def _diarization_assign_speaker(cls, words: list[Word], diarization: list[dict]):
+    def _diarization_assign_speaker(
+        cls, words: list[Word], diarization: list[DiarizationSegment]
+    ):
         """
         Assign speaker to words based on diarization
@@ -148,7 +151,7 @@ class AudioDiarizationProcessor(Processor):
         """
 
         word_idx = 0
-        last_speaker = None
+        last_speaker = 0
         for d in diarization:
             start = d["start"]
             end = d["end"]
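To make the cleanup order in `assign_speaker` concrete, a hypothetical input and the intended effect of each pass. The segment values are made up, and the exact overlap-resolution rule is not shown in this hunk, so treat this as an illustration of intent rather than of the implementation:

# Illustrative walk-through of the assign_speaker cleanup passes
# (segment values are made up):
diarization = [
    {"start": 0.0, "end": 5.2, "speaker": 0},
    {"start": 4.8, "end": 9.0, "speaker": 1},   # overlaps previous segment
    {"start": 9.0, "end": 9.4, "speaker": 1},   # suppose no words fall inside
    {"start": 9.4, "end": 12.0, "speaker": 1},  # same speaker as neighbour
]
# _diarization_remove_overlap: one side of the 4.8-5.2 overlap is trimmed
# _diarization_remove_segment_without_words: the wordless 9.0-9.4 segment is dropped
# _diarization_merge_same_speaker: the remaining speaker-1 segments are merged
# _diarization_assign_speaker: each Word then gets the speaker of its segment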
server/reflector/processors/audio_diarization_pyannote.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import os

import torch
import torchaudio
from pyannote.audio import Pipeline

from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
from reflector.processors.types import AudioDiarizationInput, DiarizationSegment


class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
    """Local diarization processor using pyannote.audio library"""

    def __init__(
        self,
        model_name: str = "pyannote/speaker-diarization-3.1",
        pyannote_auth_token: str | None = None,
        device: str | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
        self.device = device

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.logger.info(f"Loading pyannote diarization model: {self.model_name}")
        self.diarization_pipeline = Pipeline.from_pretrained(
            self.model_name, use_auth_token=self.auth_token
        )
        self.diarization_pipeline.to(torch.device(self.device))
        self.logger.info(f"Diarization model loaded on device: {self.device}")

    async def _diarize(self, data: AudioDiarizationInput) -> list[DiarizationSegment]:
        try:
            # Load audio file (audio_url is assumed to be a local file path)
            self.logger.info(f"Loading local audio file: {data.audio_url}")
            waveform, sample_rate = torchaudio.load(data.audio_url)
            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
            self.logger.info("Running speaker diarization")
            diarization = self.diarization_pipeline(audio_input)

            # Convert pyannote diarization output to our format
            segments = []
            for segment, _, speaker in diarization.itertracks(yield_label=True):
                # Extract speaker number from label (e.g., "SPEAKER_00" -> 0)
                speaker_id = 0
                if speaker.startswith("SPEAKER_"):
                    try:
                        speaker_id = int(speaker.split("_")[-1])
                    except (ValueError, IndexError):
                        # Fallback to hash-based ID if parsing fails
                        speaker_id = hash(speaker) % 1000

                segments.append(
                    {
                        "start": round(segment.start, 3),
                        "end": round(segment.end, 3),
                        "speaker": speaker_id,
                    }
                )

            self.logger.info(f"Diarization completed with {len(segments)} segments")
            return segments

        except Exception as e:
            self.logger.exception(f"Diarization failed: {e}")
            raise


AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)
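Outside the processor, the same pyannote flow can be exercised directly. A minimal sketch, assuming a valid Hugging Face token in the environment and a local `meeting.wav` (both illustrative):

# Standalone sketch of the pyannote flow above; assumes HF_TOKEN is set
# and meeting.wav exists locally.
import torch
import torchaudio
from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

waveform, sample_rate = torchaudio.load("meeting.wav")
diarization = pipeline({"waveform": waveform, "sample_rate": sample_rate})
for segment, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{segment.start:.3f}-{segment.end:.3f} {speaker}")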
@@ -3,11 +3,24 @@ from time import monotonic_ns
 from uuid import uuid4
 
 import av
+from av.audio.resampler import AudioResampler
 
 from reflector.processors.base import Processor
 from reflector.processors.types import AudioFile
 
 
+def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
+    frame_copy = frame.from_ndarray(
+        frame.to_ndarray(),
+        format=frame.format.name,
+        layout=frame.layout.name,
+    )
+    frame_copy.sample_rate = frame.sample_rate
+    frame_copy.pts = frame.pts
+    frame_copy.time_base = frame.time_base
+    return frame_copy
+
+
 class AudioMergeProcessor(Processor):
     """
     Merge audio frames into a single file
@@ -16,37 +29,92 @@ class AudioMergeProcessor(Processor):
     INPUT_TYPE = list[av.AudioFrame]
     OUTPUT_TYPE = AudioFile
 
+    def __init__(self, downsample_to_16k_mono: bool = True, **kwargs):
+        super().__init__(**kwargs)
+        self.downsample_to_16k_mono = downsample_to_16k_mono
+
     async def _push(self, data: list[av.AudioFrame]):
         if not data:
             return
 
         # get audio information from first frame
         frame = data[0]
-        channels = len(frame.layout.channels)
-        sample_rate = frame.sample_rate
-        sample_width = frame.format.bytes
+        original_channels = len(frame.layout.channels)
+        original_sample_rate = frame.sample_rate
+        original_sample_width = frame.format.bytes
+
+        # determine if we need processing
+        needs_processing = self.downsample_to_16k_mono and (
+            original_sample_rate != 16000 or original_channels != 1
+        )
+
+        # determine output parameters
+        if self.downsample_to_16k_mono:
+            output_sample_rate = 16000
+            output_channels = 1
+            output_sample_width = 2  # 16-bit = 2 bytes
+        else:
+            output_sample_rate = original_sample_rate
+            output_channels = original_channels
+            output_sample_width = original_sample_width
 
         # create audio file
         uu = uuid4().hex
         fd = io.BytesIO()
 
-        out_container = av.open(fd, "w", format="wav")
-        out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
-        for frame in data:
-            for packet in out_stream.encode(frame):
-                out_container.mux(packet)
-        for packet in out_stream.encode(None):
-            out_container.mux(packet)
-        out_container.close()
+        if needs_processing:
+            # Process with PyAV resampler
+            out_container = av.open(fd, "w", format="wav")
+            out_stream = out_container.add_stream("pcm_s16le", rate=16000)
+            out_stream.layout = "mono"
+
+            # Create resampler if needed
+            resampler = None
+            if original_sample_rate != 16000 or original_channels != 1:
+                resampler = AudioResampler(format="s16", layout="mono", rate=16000)
+
+            for frame in data:
+                if resampler:
+                    # Resample and convert to mono
+                    # XXX for an unknown reason, if we don't use a copy of the frame,
+                    # resample() fails with "Invalid argument". Debugging indicates that
+                    # when a previous processor has already used the frame (like
+                    # AudioFileWriter), it becomes invalid here.
+                    resampled_frames = resampler.resample(copy_frame(frame))
+                    for resampled_frame in resampled_frames:
+                        for packet in out_stream.encode(resampled_frame):
+                            out_container.mux(packet)
+                else:
+                    # Direct encoding without resampling
+                    for packet in out_stream.encode(frame):
+                        out_container.mux(packet)
+
+            # Flush the encoder
+            for packet in out_stream.encode(None):
+                out_container.mux(packet)
+            out_container.close()
+        else:
+            # Use PyAV for original frames (no processing needed)
+            out_container = av.open(fd, "w", format="wav")
+            out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
+            out_stream.layout = "mono" if output_channels == 1 else frame.layout
+
+            for frame in data:
+                for packet in out_stream.encode(frame):
+                    out_container.mux(packet)
+
+            for packet in out_stream.encode(None):
+                out_container.mux(packet)
+            out_container.close()
 
         fd.seek(0)
 
         # emit audio file
         audiofile = AudioFile(
             name=f"{monotonic_ns()}-{uu}.wav",
             fd=fd,
-            sample_rate=sample_rate,
-            channels=channels,
-            sample_width=sample_width,
+            sample_rate=output_sample_rate,
+            channels=output_channels,
+            sample_width=output_sample_width,
             timestamp=data[0].pts * data[0].time_base,
         )
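A quick sanity check on what the 16 kHz mono conversion does to sample counts, useful when validating the emitted WAV. The numbers are illustrative:

# Back-of-envelope check for the 16k mono conversion (illustrative values):
original_sample_rate = 48000
original_channels = 2
n_samples = 48000 * 5  # 5 seconds per channel at the original rate

resampled = n_samples * 16000 // 48000  # 80000 samples after resampling
wav_bytes = resampled * 1 * 2           # mono, s16 -> 160000 bytes of PCM (+ header)
print(resampled, wav_bytes)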
@@ -12,6 +12,9 @@ API will be a POST request to TRANSCRIPT_URL:
 
 """
 
+from typing import List
+
+import aiohttp
 from openai import AsyncOpenAI
 
 from reflector.processors.audio_transcript import AudioTranscriptProcessor
@@ -21,7 +24,9 @@ from reflector.settings import settings
 
 
 class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
-    def __init__(self, modal_api_key: str | None = None, **kwargs):
+    def __init__(
+        self, modal_api_key: str | None = None, batch_enabled: bool = True, **kwargs
+    ):
         super().__init__()
         if not settings.TRANSCRIPT_URL:
             raise Exception(
@@ -30,6 +35,126 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
         self.timeout = settings.TRANSCRIPT_TIMEOUT
         self.modal_api_key = modal_api_key
+        self.max_batch_duration = 10.0
+        self.max_batch_files = 15
+        self.batch_enabled = batch_enabled
+        self.pending_files: List[AudioFile] = []  # Files waiting to be processed
+
+    @classmethod
+    def _calculate_duration(cls, audio_file: AudioFile) -> float:
+        """Calculate audio duration in seconds from AudioFile metadata"""
+        # Duration = total_samples / sample_rate
+        # We need to estimate total samples from the file data
+        import wave
+
+        try:
+            # Try to read as WAV file to get duration
+            audio_file.fd.seek(0)
+            with wave.open(audio_file.fd, "rb") as wav_file:
+                frames = wav_file.getnframes()
+                sample_rate = wav_file.getframerate()
+                duration = frames / sample_rate
+                return duration
+        except Exception:
+            # Fallback: estimate from file size and audio parameters
+            audio_file.fd.seek(0, 2)  # Seek to end
+            file_size = audio_file.fd.tell()
+            audio_file.fd.seek(0)  # Reset to beginning
+
+            # Estimate: file_size / (sample_rate * channels * sample_width)
+            # sample_width is already in bytes (see AudioMergeProcessor)
+            bytes_per_second = (
+                audio_file.sample_rate
+                * audio_file.channels
+                * audio_file.sample_width
+            )
+            estimated_duration = (
+                file_size / bytes_per_second if bytes_per_second > 0 else 0
+            )
+            return max(0, estimated_duration)
+
+    def _create_batches(self, audio_files: List[AudioFile]) -> List[List[AudioFile]]:
+        """Group audio files into batches of at most max_batch_duration seconds total"""
+        batches = []
+        current_batch = []
+        current_duration = 0.0
+
+        for audio_file in audio_files:
+            duration = self._calculate_duration(audio_file)
+
+            # If adding this file exceeds max duration, start a new batch
+            if current_duration + duration > self.max_batch_duration and current_batch:
+                batches.append(current_batch)
+                current_batch = [audio_file]
+                current_duration = duration
+            else:
+                current_batch.append(audio_file)
+                current_duration += duration
+
+        # Add the last batch if not empty
+        if current_batch:
+            batches.append(current_batch)
+
+        return batches
+
+    async def _transcript_batch(self, audio_files: List[AudioFile]) -> List[Transcript]:
+        """Transcribe a batch of audio files using the parakeet backend"""
+        if not audio_files:
+            return []
+
+        self.logger.debug(f"Batch transcribing {len(audio_files)} files")
+
+        # Prepare form data for batch request
+        data = aiohttp.FormData()
+        data.add_field("language", self.get_pref("audio:source_language", "en"))
+        data.add_field("batch", "true")
+
+        for i, audio_file in enumerate(audio_files):
+            audio_file.fd.seek(0)
+            data.add_field(
+                "files",
+                audio_file.fd,
+                filename=f"{audio_file.name}",
+                content_type="audio/wav",
+            )
+
+        # Make batch request
+        headers = {"Authorization": f"Bearer {self.modal_api_key}"}
+
+        async with aiohttp.ClientSession(
+            timeout=aiohttp.ClientTimeout(total=self.timeout)
+        ) as session:
+            async with session.post(
+                f"{self.transcript_url}/audio/transcriptions",
+                data=data,
+                headers=headers,
+            ) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    raise Exception(
+                        f"Batch transcription failed: {response.status} {error_text}"
+                    )
+
+                result = await response.json()
+
+        # Process batch results
+        transcripts = []
+        results = result.get("results", [])
+
+        for i, (audio_file, file_result) in enumerate(zip(audio_files, results)):
+            transcript = Transcript(
+                words=[
+                    Word(
+                        text=word_info["word"],
+                        start=word_info["start"],
+                        end=word_info["end"],
+                    )
+                    for word_info in file_result.get("words", [])
+                ]
+            )
+            transcript.add_offset(audio_file.timestamp)
+            transcripts.append(transcript)
+
+        return transcripts
+
     async def _transcript(self, data: AudioFile):
         async with AsyncOpenAI(
@@ -62,5 +187,96 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
 
         return transcript
 
+    async def transcript_multiple(
+        self, audio_files: List[AudioFile]
+    ) -> List[Transcript]:
+        """Transcribe multiple audio files using batching"""
+        if len(audio_files) == 1:
+            # Single file, use existing method
+            return [await self._transcript(audio_files[0])]
+
+        # Create batches capped at max_batch_duration seconds each
+        batches = self._create_batches(audio_files)
+
+        self.logger.debug(
+            f"Processing {len(audio_files)} files in {len(batches)} batches"
+        )
+
+        # Process all batches sequentially
+        all_transcripts = []
+
+        for batch in batches:
+            batch_transcripts = await self._transcript_batch(batch)
+            all_transcripts.extend(batch_transcripts)
+
+        return all_transcripts
+
+    async def _push(self, data: AudioFile):
+        """Override _push to support batching"""
+        if not self.batch_enabled:
+            # Use parent implementation for single file processing
+            return await super()._push(data)
+
+        # Add file to pending batch
+        self.pending_files.append(data)
+        self.logger.debug(
+            f"Added file to batch: {data.name}, batch size: {len(self.pending_files)}"
+        )
+
+        # Calculate total duration of pending files
+        total_duration = sum(self._calculate_duration(f) for f in self.pending_files)
+
+        # Process batch if it reaches max duration or max file count
+        should_process_batch = (
+            total_duration >= self.max_batch_duration
+            or len(self.pending_files) >= self.max_batch_files
+        )
+
+        if should_process_batch:
+            await self._process_pending_batch()
+
+    async def _process_pending_batch(self):
+        """Process all pending files as batches"""
+        if not self.pending_files:
+            return
+
+        self.logger.debug(f"Processing batch of {len(self.pending_files)} files")
+
+        try:
+            # Create batches respecting the duration limit
+            batches = self._create_batches(self.pending_files)
+
+            # Process each batch
+            for batch in batches:
+                self.m_transcript_call.inc()
+                try:
+                    with self.m_transcript.time():
+                        # Use batch transcription
+                        transcripts = await self._transcript_batch(batch)
+
+                    self.m_transcript_success.inc()
+
+                    # Emit each transcript
+                    for transcript in transcripts:
+                        if transcript:
+                            await self.emit(transcript)
+
+                except Exception:
+                    self.m_transcript_failure.inc()
+                    raise
+                finally:
+                    # Release audio files
+                    for audio_file in batch:
+                        audio_file.release()
+
+        finally:
+            # Clear pending files
+            self.pending_files.clear()
+
+    async def _flush(self):
+        """Process any remaining files when flushing"""
+        await self._process_pending_batch()
+        await super()._flush()
+
 
 AudioTranscriptAutoProcessor.register("modal", AudioTranscriptModalProcessor)
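A worked example of the `_create_batches` rule under the 10-second cap set in `__init__`. The durations are illustrative:

# Worked example of the batching rule (durations in seconds, illustrative):
# max_batch_duration = 10.0
durations = [4.0, 3.0, 5.0, 9.0, 2.0]
# batch 1: 4.0 + 3.0 = 7.0; adding 5.0 would exceed 10.0 -> close batch
# batch 2: 5.0; adding 9.0 would exceed 10.0 -> close batch
# batch 3: 9.0; adding 2.0 would exceed 10.0 -> close batch
# batch 4: 2.0
# result: [[4.0, 3.0], [5.0], [9.0], [2.0]]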
@@ -241,14 +241,15 @@ class ThreadedProcessor(Processor):
         self.INPUT_TYPE = processor.INPUT_TYPE
         self.OUTPUT_TYPE = processor.OUTPUT_TYPE
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self.queue = asyncio.Queue()
-        self.task = asyncio.get_running_loop().create_task(self.loop())
+        self.queue = asyncio.Queue(maxsize=50)
+        self.task: asyncio.Task | None = None
 
     def set_pipeline(self, pipeline: "Pipeline"):
         super().set_pipeline(pipeline)
         self.processor.set_pipeline(pipeline)
 
     async def loop(self):
+        try:
             while True:
                 data = await self.queue.get()
                 self.m_processor_queue.set(self.queue.qsize())
@@ -266,8 +267,19 @@ class ThreadedProcessor(Processor):
                 )
             finally:
                 self.queue.task_done()
+        except Exception as e:
+            logger.error(f"Crash in {self.__class__.__name__}: {e}", exc_info=e)
+
+    async def _ensure_task(self):
+        if self.task is None:
+            self.task = asyncio.get_running_loop().create_task(self.loop())
+
+        # XXX without a sleep here, the pipeline stages ahead of this thread keep
+        # running without ever giving the task created above a chance to work.
+        await asyncio.sleep(0)
 
     async def _push(self, data):
+        await self._ensure_task()
         await self.queue.put(data)
 
     async def _flush(self):
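The `asyncio.sleep(0)` in `_ensure_task` is a cooperative yield: a producer coroutine that never awaits anything slow would otherwise keep the event loop busy, and the freshly created `loop()` task would never get scheduled. A minimal, self-contained demonstration of the same yield point (names illustrative):

# Minimal demonstration of why the sleep(0) yield matters (illustrative):
import asyncio

async def worker(q: asyncio.Queue):
    while True:
        item = await q.get()
        print("consumed", item)
        q.task_done()

async def main():
    q = asyncio.Queue()
    task = asyncio.create_task(worker(q))
    for i in range(3):
        await q.put(i)
        await asyncio.sleep(0)  # give the worker task a chance to run
    await q.join()
    task.cancel()

asyncio.run(main())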
server/reflector/processors/file_diarization.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from pydantic import BaseModel

from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment


class FileDiarizationInput(BaseModel):
    """Input for file diarization containing audio URL"""

    audio_url: str


class FileDiarizationOutput(BaseModel):
    """Output for file diarization containing speaker segments"""

    diarization: list[DiarizationSegment]


class FileDiarizationProcessor(Processor):
    """
    Diarize complete audio files from URL
    """

    INPUT_TYPE = FileDiarizationInput
    OUTPUT_TYPE = FileDiarizationOutput

    async def _push(self, data: FileDiarizationInput):
        result = await self._diarize(data)
        if result:
            await self.emit(result)

    async def _diarize(self, data: FileDiarizationInput):
        raise NotImplementedError
server/reflector/processors/file_diarization_auto.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import importlib

from reflector.processors.file_diarization import FileDiarizationProcessor
from reflector.settings import settings


class FileDiarizationAutoProcessor(FileDiarizationProcessor):
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        cls._registry[name] = kclass

    def __new__(cls, name: str | None = None, **kwargs):
        if name is None:
            name = settings.DIARIZATION_BACKEND

        if name not in cls._registry:
            module_name = f"reflector.processors.file_diarization_{name}"
            importlib.import_module(module_name)

        # gather specific configuration for the processor
        # search `DIARIZATION_<NAME>_YYY`, push to constructor as `<name>_yyy`
        config = {}
        name_upper = name.upper()
        settings_prefix = "DIARIZATION_"
        config_prefix = f"{settings_prefix}{name_upper}_"
        for key, value in settings:
            if key.startswith(config_prefix):
                config_name = key[len(settings_prefix) :].lower()
                config[config_name] = value

        return cls._registry[name](**config | kwargs)
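To make the settings scan concrete: with a pyannote backend selected, a setting named `DIARIZATION_PYANNOTE_AUTH_TOKEN` (hypothetical name) would be forwarded to the processor constructor as `pyannote_auth_token`:

# Hypothetical illustration of the settings scan above:
# settings.DIARIZATION_BACKEND = "pyannote"
# settings.DIARIZATION_PYANNOTE_AUTH_TOKEN = "hf_..."
#
# config_prefix = "DIARIZATION_PYANNOTE_"
# key "DIARIZATION_PYANNOTE_AUTH_TOKEN" matches the prefix
# config_name = "pyannote_auth_token"
# -> registered processor is constructed as Processor(pyannote_auth_token="hf_...")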
57
server/reflector/processors/file_diarization_modal.py
Normal file
57
server/reflector/processors/file_diarization_modal.py
Normal file
@@ -0,0 +1,57 @@
"""
File diarization implementation using the GPU service from modal.com

API will be a POST request to DIARIZATION_URL:

```
POST /diarize?audio_file_url=...&timestamp=0
Authorization: Bearer <modal_api_key>
```
"""

import httpx

from reflector.processors.file_diarization import (
    FileDiarizationInput,
    FileDiarizationOutput,
    FileDiarizationProcessor,
)
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.settings import settings


class FileDiarizationModalProcessor(FileDiarizationProcessor):
    def __init__(self, modal_api_key: str | None = None, **kwargs):
        super().__init__(**kwargs)
        if not settings.DIARIZATION_URL:
            raise Exception(
                "DIARIZATION_URL required to use FileDiarizationModalProcessor"
            )
        self.diarization_url = settings.DIARIZATION_URL + "/diarize"
        self.file_timeout = settings.DIARIZATION_FILE_TIMEOUT
        self.modal_api_key = modal_api_key

    async def _diarize(self, data: FileDiarizationInput):
        """Get speaker diarization for file"""
        self.logger.info(f"Starting diarization from {data.audio_url}")

        headers = {}
        if self.modal_api_key:
            headers["Authorization"] = f"Bearer {self.modal_api_key}"

        async with httpx.AsyncClient(timeout=self.file_timeout) as client:
            response = await client.post(
                self.diarization_url,
                headers=headers,
                params={
                    "audio_file_url": data.audio_url,
                    "timestamp": 0,
                },
            )
            response.raise_for_status()
            diarization_data = response.json()["diarization"]

        return FileDiarizationOutput(diarization=diarization_data)


FileDiarizationAutoProcessor.register("modal", FileDiarizationModalProcessor)
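
A minimal usage sketch for the backend above, not part of the PR; it assumes `settings.DIARIZATION_URL` points at a deployed diarizer, and both the key and audio URL below are placeholders. It calls `_diarize` directly, bypassing the `_push`/`emit` pipeline path:

```python
import asyncio

from reflector.processors.file_diarization import FileDiarizationInput
from reflector.processors.file_diarization_modal import FileDiarizationModalProcessor


async def main():
    processor = FileDiarizationModalProcessor(modal_api_key="...")  # placeholder key
    output = await processor._diarize(
        FileDiarizationInput(audio_url="https://example.invalid/audio.mp3")
    )
    # The service answers {"diarization": [{"start": ..., "end": ..., "speaker": ...}, ...]},
    # as the recorded cassettes further down in this diff show.
    for segment in output.diarization:
        print(segment["start"], segment["end"], segment["speaker"])


asyncio.run(main())
```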

server/reflector/processors/file_transcript.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from prometheus_client import Counter, Histogram

from reflector.processors.base import Processor
from reflector.processors.types import Transcript


class FileTranscriptInput:
    """Input for file transcription containing audio URL and language settings"""

    def __init__(self, audio_url: str, language: str = "en"):
        self.audio_url = audio_url
        self.language = language


class FileTranscriptProcessor(Processor):
    """
    Transcribe complete audio files from URL
    """

    INPUT_TYPE = FileTranscriptInput
    OUTPUT_TYPE = Transcript

    m_transcript = Histogram(
        "file_transcript",
        "Time spent in FileTranscript.transcript",
        ["backend"],
    )
    m_transcript_call = Counter(
        "file_transcript_call",
        "Number of calls to FileTranscript.transcript",
        ["backend"],
    )
    m_transcript_success = Counter(
        "file_transcript_success",
        "Number of successful calls to FileTranscript.transcript",
        ["backend"],
    )
    m_transcript_failure = Counter(
        "file_transcript_failure",
        "Number of failed calls to FileTranscript.transcript",
        ["backend"],
    )

    def __init__(self, *args, **kwargs):
        name = self.__class__.__name__
        self.m_transcript = self.m_transcript.labels(name)
        self.m_transcript_call = self.m_transcript_call.labels(name)
        self.m_transcript_success = self.m_transcript_success.labels(name)
        self.m_transcript_failure = self.m_transcript_failure.labels(name)
        super().__init__(*args, **kwargs)

    async def _push(self, data: FileTranscriptInput):
        try:
            self.m_transcript_call.inc()
            with self.m_transcript.time():
                result = await self._transcript(data)
            self.m_transcript_success.inc()
            if result:
                await self.emit(result)
        except Exception:
            self.m_transcript_failure.inc()
            raise

    async def _transcript(self, data: FileTranscriptInput):
        raise NotImplementedError
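
Since `_push` above already handles the Prometheus counters, the histogram timing, and the downstream emission, a concrete backend only has to implement `_transcript`. A sketch, not part of the PR; the "echo" backend name is invented for illustration, and registration uses the auto-processor registry introduced in the next file:

```python
from reflector.processors.file_transcript import (
    FileTranscriptInput,
    FileTranscriptProcessor,
)
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.types import Transcript, Word


class FileTranscriptEchoProcessor(FileTranscriptProcessor):
    async def _transcript(self, data: FileTranscriptInput) -> Transcript:
        # A real backend would fetch data.audio_url and run a model here.
        return Transcript(words=[Word(text="hello", start=0.0, end=0.5)])


# Makes the backend selectable as FileTranscriptAutoProcessor(name="echo")
# or via TRANSCRIPT_BACKEND=echo.
FileTranscriptAutoProcessor.register("echo", FileTranscriptEchoProcessor)
```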

server/reflector/processors/file_transcript_auto.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import importlib

from reflector.processors.file_transcript import FileTranscriptProcessor
from reflector.settings import settings


class FileTranscriptAutoProcessor(FileTranscriptProcessor):
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        cls._registry[name] = kclass

    def __new__(cls, name: str | None = None, **kwargs):
        if name is None:
            name = settings.TRANSCRIPT_BACKEND
        if name not in cls._registry:
            module_name = f"reflector.processors.file_transcript_{name}"
            importlib.import_module(module_name)

        # gather backend-specific configuration for the processor:
        # search `TRANSCRIPT_<BACKEND>_XXX` settings, pass them to the
        # constructor as `<backend>_xxx`
        config = {}
        name_upper = name.upper()
        settings_prefix = "TRANSCRIPT_"
        config_prefix = f"{settings_prefix}{name_upper}_"
        for key, value in settings:
            if key.startswith(config_prefix):
                config_name = key[len(settings_prefix) :].lower()
                config[config_name] = value

        return cls._registry[name](**config | kwargs)
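
Both calls below resolve to the same registered class (a sketch, not part of the PR): the first reads the backend name from `settings.TRANSCRIPT_BACKEND`, the second pins it explicitly. Explicit kwargs win over settings-derived values because of the `config | kwargs` merge above.

```python
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor

processor = FileTranscriptAutoProcessor()              # backend taken from settings
processor = FileTranscriptAutoProcessor(name="modal")  # backend pinned explicitly
```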

server/reflector/processors/file_transcript_modal.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""
File transcription implementation using the GPU service from modal.com

API will be a POST request to TRANSCRIPT_URL:

```json
{
    "audio_file_url": "https://...",
    "language": "en",
    "model": "parakeet-tdt-0.6b-v2",
    "batch": true
}
```
"""

import httpx

from reflector.processors.file_transcript import (
    FileTranscriptInput,
    FileTranscriptProcessor,
)
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.types import Transcript, Word
from reflector.settings import settings


class FileTranscriptModalProcessor(FileTranscriptProcessor):
    def __init__(self, modal_api_key: str | None = None, **kwargs):
        super().__init__(**kwargs)
        if not settings.TRANSCRIPT_URL:
            raise Exception(
                "TRANSCRIPT_URL required to use FileTranscriptModalProcessor"
            )
        self.transcript_url = settings.TRANSCRIPT_URL
        self.file_timeout = settings.TRANSCRIPT_FILE_TIMEOUT
        self.modal_api_key = modal_api_key

    async def _transcript(self, data: FileTranscriptInput):
        """Send full file to Modal for transcription"""
        url = f"{self.transcript_url}/v1/audio/transcriptions-from-url"

        self.logger.info(f"Starting file transcription from {data.audio_url}")

        headers = {}
        if self.modal_api_key:
            headers["Authorization"] = f"Bearer {self.modal_api_key}"

        async with httpx.AsyncClient(timeout=self.file_timeout) as client:
            response = await client.post(
                url,
                headers=headers,
                json={
                    "audio_file_url": data.audio_url,
                    "language": data.language,
                    "batch": True,
                },
            )
            response.raise_for_status()
            result = response.json()

        words = [
            Word(
                text=word_info["word"],
                start=word_info["start"],
                end=word_info["end"],
            )
            for word_info in result.get("words", [])
        ]

        return Transcript(words=words)


# Register with the auto processor
FileTranscriptAutoProcessor.register("modal", FileTranscriptModalProcessor)
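
For reference, the response shape this processor consumes, abbreviated from the recorded cassette further down in this diff. Only the word timings are kept; the top-level `text` field is ignored here:

```python
from reflector.processors.types import Transcript, Word

result = {
    "text": "Hi there everyone. ...",
    "words": [
        {"word": "Hi", "start": 0.87, "end": 1.19},
        {"word": "there", "start": 1.19, "end": 1.35},
    ],
}
words = [Word(text=w["word"], start=w["start"], end=w["end"]) for w in result["words"]]
transcript = Transcript(words=words)
```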

@@ -0,0 +1,45 @@
"""
Processor to assemble transcript with diarization results
"""

from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment, Transcript


class TranscriptDiarizationAssemblerInput:
    """Input containing transcript and diarization data"""

    def __init__(self, transcript: Transcript, diarization: list[DiarizationSegment]):
        self.transcript = transcript
        self.diarization = diarization


class TranscriptDiarizationAssemblerProcessor(Processor):
    """
    Assemble transcript with diarization results by applying speaker assignments
    """

    INPUT_TYPE = TranscriptDiarizationAssemblerInput
    OUTPUT_TYPE = Transcript

    async def _push(self, data: TranscriptDiarizationAssemblerInput):
        result = await self._assemble(data)
        if result:
            await self.emit(result)

    async def _assemble(self, data: TranscriptDiarizationAssemblerInput):
        """Apply diarization to transcript words"""
        if not data.diarization:
            self.logger.info(
                "No diarization data provided, returning original transcript"
            )
            return data.transcript

        # Reuse logic from AudioDiarizationProcessor
        processor = AudioDiarizationProcessor()
        words = data.transcript.words
        processor.assign_speaker(words, data.diarization)

        self.logger.info(f"Applied diarization to {len(words)} words")
        return data.transcript
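
The heavy lifting is delegated to `AudioDiarizationProcessor.assign_speaker`, which mutates the word list in place; its exact matching rule is not shown in this diff, though assignment by time overlap is the natural reading. A sketch of the expected data shapes, not part of the PR:

```python
from reflector.processors.types import DiarizationSegment, Transcript, Word

transcript = Transcript(
    words=[
        Word(text="Hello", start=0.0, end=0.5),
        Word(text="world", start=0.6, end=1.0),
    ]
)
diarization: list[DiarizationSegment] = [
    {"start": 0.0, "end": 1.1, "speaker": 0},
]
# After _assemble runs, each Word should carry the speaker of the segment it
# falls into (speaker 0 here); an empty diarization list short-circuits and
# returns the transcript untouched.
```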

@@ -2,13 +2,22 @@ import io
 import re
 import tempfile
 from pathlib import Path
-from typing import Annotated
+from typing import Annotated, TypedDict
 
 from profanityfilter import ProfanityFilter
 from pydantic import BaseModel, Field, PrivateAttr
 
 from reflector.redis_cache import redis_cache
 
 
+class DiarizationSegment(TypedDict):
+    """Type definition for diarization segment containing speaker information"""
+
+    start: float
+    end: float
+    speaker: int
+
+
 PUNC_RE = re.compile(r"[.;:?!…]")
 
 profanity_filter = ProfanityFilter()
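
Because `DiarizationSegment` is a `TypedDict`, it is a plain dict at runtime, so the JSON segments returned by the diarization service type-check directly. For example (values taken from the recorded cassette below):

```python
from reflector.processors.types import DiarizationSegment

segment: DiarizationSegment = {"start": 0.823, "end": 1.91, "speaker": 0}
```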

@@ -26,6 +26,7 @@ class Settings(BaseSettings):
     TRANSCRIPT_BACKEND: str = "whisper"
     TRANSCRIPT_URL: str | None = None
     TRANSCRIPT_TIMEOUT: int = 90
+    TRANSCRIPT_FILE_TIMEOUT: int = 600
 
     # Audio Transcription: modal backend
     TRANSCRIPT_MODAL_API_KEY: str | None = None
@@ -66,10 +67,14 @@ class Settings(BaseSettings):
     DIARIZATION_ENABLED: bool = True
     DIARIZATION_BACKEND: str = "modal"
     DIARIZATION_URL: str | None = None
+    DIARIZATION_FILE_TIMEOUT: int = 600
 
     # Diarization: modal backend
     DIARIZATION_MODAL_API_KEY: str | None = None
 
+    # Diarization: local pyannote.audio
+    DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None
+
     # Sentry
     SENTRY_DSN: str | None = None
@@ -1,10 +1,23 @@
+"""
+Process audio file with diarization support
+===========================================
+
+Extended version of process.py that includes speaker diarization.
+This tool processes audio files locally without requiring the full server infrastructure.
+"""
+
 import asyncio
+import tempfile
+import uuid
+from pathlib import Path
+from typing import List
 
 import av
 
 from reflector.logger import logger
 from reflector.processors import (
     AudioChunkerProcessor,
+    AudioFileWriterProcessor,
     AudioMergeProcessor,
     AudioTranscriptAutoProcessor,
     Pipeline,
@@ -15,7 +28,43 @@ from reflector.processors import (
     TranscriptTopicDetectorProcessor,
     TranscriptTranslatorAutoProcessor,
 )
-from reflector.processors.base import BroadcastProcessor
+from reflector.processors.base import BroadcastProcessor, Processor
+from reflector.processors.types import (
+    AudioDiarizationInput,
+    TitleSummary,
+    TitleSummaryWithId,
+)
+
+
+class TopicCollectorProcessor(Processor):
+    """Collect topics for diarization"""
+
+    INPUT_TYPE = TitleSummary
+    OUTPUT_TYPE = TitleSummary
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.topics: List[TitleSummaryWithId] = []
+        self._topic_id = 0
+
+    async def _push(self, data: TitleSummary):
+        # Convert to TitleSummaryWithId and collect
+        self._topic_id += 1
+        topic_with_id = TitleSummaryWithId(
+            id=str(self._topic_id),
+            title=data.title,
+            summary=data.summary,
+            timestamp=data.timestamp,
+            duration=data.duration,
+            transcript=data.transcript,
+        )
+        self.topics.append(topic_with_id)
+
+        # Pass through the original topic
+        await self.emit(data)
+
+    def get_topics(self) -> List[TitleSummaryWithId]:
+        return self.topics
 
 
 async def process_audio_file(
@@ -24,18 +73,40 @@ async def process_audio_file(
     only_transcript=False,
     source_language="en",
     target_language="en",
+    enable_diarization=True,
+    diarization_backend="pyannote",
 ):
-    # build pipeline for audio processing
-    processors = [
+    # Create temp file for audio if diarization is enabled
+    audio_temp_path = None
+    if enable_diarization:
+        audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+        audio_temp_path = audio_temp_file.name
+        audio_temp_file.close()
+
+    # Create processor for collecting topics
+    topic_collector = TopicCollectorProcessor()
+
+    # Build pipeline for audio processing
+    processors = []
+
+    # Add audio file writer at the beginning if diarization is enabled
+    if enable_diarization:
+        processors.append(AudioFileWriterProcessor(audio_temp_path))
+
+    # Add the rest of the processors
+    processors += [
         AudioChunkerProcessor(),
         AudioMergeProcessor(),
         AudioTranscriptAutoProcessor.as_threaded(),
         TranscriptLinerProcessor(),
         TranscriptTranslatorAutoProcessor.as_threaded(),
     ]
 
     if not only_transcript:
         processors += [
             TranscriptTopicDetectorProcessor.as_threaded(),
+            # Collect topics for diarization
+            topic_collector,
             BroadcastProcessor(
                 processors=[
                     TranscriptFinalTitleProcessor.as_threaded(),
@@ -44,14 +115,14 @@
             ),
         ]
 
-    # transcription output
+    # Create main pipeline
     pipeline = Pipeline(*processors)
     pipeline.set_pref("audio:source_language", source_language)
     pipeline.set_pref("audio:target_language", target_language)
     pipeline.describe()
     pipeline.on(event_callback)
 
-    # start processing audio
+    # Start processing audio
     logger.info(f"Opening {filename}")
     container = av.open(filename)
     try:
@@ -62,34 +133,219 @@
     logger.info("Flushing the pipeline")
     await pipeline.flush()
 
-    logger.info("All done !")
+    # Run diarization if enabled and we have topics
+    if enable_diarization and not only_transcript and audio_temp_path:
+        topics = topic_collector.get_topics()
+
+        if topics:
+            logger.info(f"Starting diarization with {len(topics)} topics")
+
+            try:
+                from reflector.processors import AudioDiarizationAutoProcessor
+
+                diarization_processor = AudioDiarizationAutoProcessor(
+                    name=diarization_backend
+                )
+                diarization_processor.set_pipeline(pipeline)
+
+                # For Modal backend, we need to upload the file to S3 first
+                if diarization_backend == "modal":
+                    from datetime import datetime
+
+                    from reflector.storage import get_transcripts_storage
+                    from reflector.utils.s3_temp_file import S3TemporaryFile
+
+                    storage = get_transcripts_storage()
+
+                    # Generate a unique filename in evaluation folder
+                    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
+                    audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
+
+                    # Use context manager for automatic cleanup
+                    async with S3TemporaryFile(storage, audio_filename) as s3_file:
+                        # Read and upload the audio file
+                        with open(audio_temp_path, "rb") as f:
+                            audio_data = f.read()
+
+                        audio_url = await s3_file.upload(audio_data)
+                        logger.info(f"Uploaded audio to S3: {audio_filename}")
+
+                        # Create diarization input with S3 URL
+                        diarization_input = AudioDiarizationInput(
+                            audio_url=audio_url, topics=topics
+                        )
+
+                        # Run diarization
+                        await diarization_processor.push(diarization_input)
+                        await diarization_processor.flush()
+
+                        logger.info("Diarization complete")
+                        # File will be automatically cleaned up when exiting the context
+                else:
+                    # For local backend, use local file path
+                    audio_url = audio_temp_path
+
+                    # Create diarization input
+                    diarization_input = AudioDiarizationInput(
+                        audio_url=audio_url, topics=topics
+                    )
+
+                    # Run diarization
+                    await diarization_processor.push(diarization_input)
+                    await diarization_processor.flush()
+
+                    logger.info("Diarization complete")
+
+            except ImportError as e:
+                logger.error(f"Failed to import diarization dependencies: {e}")
+                logger.error(
+                    "Install with: uv pip install pyannote.audio torch torchaudio"
+                )
+                logger.error(
+                    "And set HF_TOKEN environment variable for pyannote models"
+                )
+                raise SystemExit(1)
+            except Exception as e:
+                logger.error(f"Diarization failed: {e}")
+                raise SystemExit(1)
+        else:
+            logger.warning("Skipping diarization: no topics available")
+
+    # Clean up temp file
+    if audio_temp_path:
+        try:
+            Path(audio_temp_path).unlink()
+        except Exception as e:
+            logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
+
+    logger.info("All done!")
+
+
+async def process_file_pipeline(
+    filename: str,
+    event_callback,
+    source_language="en",
+    target_language="en",
+    enable_diarization=True,
+    diarization_backend="modal",
+):
+    """Process audio/video file using the optimized file pipeline"""
+    try:
+        from reflector.db import database
+        from reflector.db.transcripts import SourceKind, transcripts_controller
+        from reflector.pipelines.main_file_pipeline import PipelineMainFile
+
+        await database.connect()
+        try:
+            # Create a temporary transcript for processing
+            transcript = await transcripts_controller.add(
+                "",
+                source_kind=SourceKind.FILE,
+                source_language=source_language,
+                target_language=target_language,
+            )
+
+            # Process the file
+            pipeline = PipelineMainFile(transcript_id=transcript.id)
+            await pipeline.process(Path(filename))
+
+            logger.info("File pipeline processing complete")
+
+        finally:
+            await database.disconnect()
+    except ImportError as e:
+        logger.error(f"File pipeline not available: {e}")
+        logger.info("Falling back to stream pipeline")
+        # Fall back to stream pipeline
+        await process_audio_file(
+            filename,
+            event_callback,
+            only_transcript=False,
+            source_language=source_language,
+            target_language=target_language,
+            enable_diarization=enable_diarization,
+            diarization_backend=diarization_backend,
+        )
 
 
 if __name__ == "__main__":
     import argparse
+    import os
 
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(
+        description="Process audio files with optional speaker diarization"
+    )
     parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
-    parser.add_argument("--only-transcript", "-t", action="store_true")
-    parser.add_argument("--source-language", default="en")
-    parser.add_argument("--target-language", default="en")
+    parser.add_argument(
+        "--stream",
+        action="store_true",
+        help="Use streaming pipeline (original frame-based processing)",
+    )
+    parser.add_argument(
+        "--only-transcript",
+        "-t",
+        action="store_true",
+        help="Only generate transcript without topics/summaries",
+    )
+    parser.add_argument(
+        "--source-language", default="en", help="Source language code (default: en)"
+    )
+    parser.add_argument(
+        "--target-language", default="en", help="Target language code (default: en)"
+    )
     parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
+    parser.add_argument(
+        "--enable-diarization",
+        "-d",
+        action="store_true",
+        help="Enable speaker diarization",
+    )
+    parser.add_argument(
+        "--diarization-backend",
+        default="pyannote",
+        choices=["pyannote", "modal"],
+        help="Diarization backend to use (default: pyannote)",
+    )
     args = parser.parse_args()
 
+    if "REDIS_HOST" not in os.environ:
+        os.environ["REDIS_HOST"] = "localhost"
+
     output_fd = None
     if args.output:
         output_fd = open(args.output, "w")
 
     async def event_callback(event: PipelineEvent):
         processor = event.processor
-        # ignore some processor
-        if processor in ("AudioChunkerProcessor", "AudioMergeProcessor"):
+        data = event.data
+
+        # Ignore internal processors
+        if processor in (
+            "AudioChunkerProcessor",
+            "AudioMergeProcessor",
+            "AudioFileWriterProcessor",
+            "TopicCollectorProcessor",
+            "BroadcastProcessor",
+        ):
             return
-        logger.info(f"Event: {event}")
+
+        # If diarization is enabled, skip the original topic events from the pipeline
+        # The diarization processor will emit the same topics but with speaker info
+        if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
+            return
+
+        # Log all events
+        logger.info(f"Event: {processor} - {type(data).__name__}")
+
+        # Write to output
        if output_fd:
             output_fd.write(event.model_dump_json())
             output_fd.write("\n")
+            output_fd.flush()
 
+    if args.stream:
+        # Use original streaming pipeline
         asyncio.run(
             process_audio_file(
                 args.source,
@@ -97,6 +353,20 @@ if __name__ == "__main__":
                 only_transcript=args.only_transcript,
                 source_language=args.source_language,
                 target_language=args.target_language,
+                enable_diarization=args.enable_diarization,
+                diarization_backend=args.diarization_backend,
+            )
+        )
+    else:
+        # Use optimized file pipeline (default)
+        asyncio.run(
+            process_file_pipeline(
+                args.source,
+                event_callback,
+                source_language=args.source_language,
+                target_language=args.target_language,
+                enable_diarization=args.enable_diarization,
+                diarization_backend=args.diarization_backend,
+            )
         )
     )
@@ -14,7 +14,8 @@ from reflector.db.meetings import meetings_controller
 from reflector.db.recordings import Recording, recordings_controller
 from reflector.db.rooms import rooms_controller
 from reflector.db.transcripts import SourceKind, transcripts_controller
-from reflector.pipelines.main_live_pipeline import asynctask, task_pipeline_process
+from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
+from reflector.pipelines.main_live_pipeline import asynctask
 from reflector.settings import settings
 from reflector.whereby import get_room_sessions
 
@@ -140,7 +141,7 @@ async def process_recording(bucket_name: str, object_key: str):
 
     await transcripts_controller.update(transcript, {"status": "uploaded"})
 
-    task_pipeline_process.delay(transcript_id=transcript.id)
+    task_pipeline_file_process.delay(transcript_id=transcript.id)
 
 
 @shared_task

@@ -0,0 +1,40 @@
interactions:
- request:
    body: ''
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      authorization:
      - DUMMY_API_KEY
      connection:
      - keep-alive
      content-length:
      - '0'
      host:
      - monadical-sas--reflector-diarizer-web.modal.run
      user-agent:
      - python-httpx/0.27.2
    method: POST
    uri: https://monadical-sas--reflector-diarizer-web.modal.run/diarize?audio_file_url=https%3A%2F%2Freflector-github-pytest.s3.us-east-1.amazonaws.com%2Ftest_mathieu_hello.mp3&timestamp=0
  response:
    body:
      string: '{"diarization":[{"start":0.823,"end":1.91,"speaker":0},{"start":2.572,"end":6.409,"speaker":0},{"start":6.783,"end":10.62,"speaker":0},{"start":11.231,"end":14.168,"speaker":0},{"start":14.796,"end":19.295,"speaker":0}]}'
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000
      Content-Length:
      - '220'
      Content-Type:
      - application/json
      Date:
      - Wed, 13 Aug 2025 18:25:34 GMT
      Modal-Function-Call-Id:
      - fc-01K2JAVNEP6N7Y1Y7W3T98BCXK
      Vary:
      - accept-encoding
    status:
      code: 200
      message: OK
version: 1

@@ -0,0 +1,46 @@
interactions:
- request:
    body: '{"audio_file_url": "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3",
      "language": "en", "batch": true}'
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      authorization:
      - DUMMY_API_KEY
      connection:
      - keep-alive
      content-length:
      - '136'
      content-type:
      - application/json
      host:
      - monadical-sas--reflector-transcriber-parakeet-web.modal.run
      user-agent:
      - python-httpx/0.27.2
    method: POST
    uri: https://monadical-sas--reflector-transcriber-parakeet-web.modal.run/v1/audio/transcriptions-from-url
  response:
    body:
      string: '{"text":"Hi there everyone. Today I want to share my incredible experience
        with Reflector. a Q teenage product that revolutionizes audio processing.
        With reflector, I can easily convert any audio into accurate transcription.
        saving me hours of tedious manual work.","words":[{"word":"Hi","start":0.87,"end":1.19},{"word":"there","start":1.19,"end":1.35},{"word":"everyone.","start":1.51,"end":1.83},{"word":"Today","start":2.63,"end":2.87},{"word":"I","start":3.36,"end":3.52},{"word":"want","start":3.6,"end":3.76},{"word":"to","start":3.76,"end":3.92},{"word":"share","start":3.92,"end":4.16},{"word":"my","start":4.16,"end":4.4},{"word":"incredible","start":4.32,"end":4.96},{"word":"experience","start":4.96,"end":5.44},{"word":"with","start":5.44,"end":5.68},{"word":"Reflector.","start":5.68,"end":6.24},{"word":"a","start":6.93,"end":7.01},{"word":"Q","start":7.01,"end":7.17},{"word":"teenage","start":7.25,"end":7.65},{"word":"product","start":7.89,"end":8.29},{"word":"that","start":8.29,"end":8.61},{"word":"revolutionizes","start":8.61,"end":9.65},{"word":"audio","start":9.65,"end":10.05},{"word":"processing.","start":10.05,"end":10.53},{"word":"With","start":11.27,"end":11.43},{"word":"reflector,","start":11.51,"end":12.15},{"word":"I","start":12.31,"end":12.39},{"word":"can","start":12.39,"end":12.55},{"word":"easily","start":12.55,"end":12.95},{"word":"convert","start":12.95,"end":13.43},{"word":"any","start":13.43,"end":13.67},{"word":"audio","start":13.67,"end":13.99},{"word":"into","start":14.98,"end":15.06},{"word":"accurate","start":15.22,"end":15.54},{"word":"transcription.","start":15.7,"end":16.34},{"word":"saving","start":16.99,"end":17.15},{"word":"me","start":17.31,"end":17.47},{"word":"hours","start":17.47,"end":17.87},{"word":"of","start":17.87,"end":18.11},{"word":"tedious","start":18.11,"end":18.67},{"word":"manual","start":18.67,"end":19.07},{"word":"work.","start":19.07,"end":19.31}]}'
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000
      Content-Length:
      - '1933'
      Content-Type:
      - application/json
      Date:
      - Wed, 13 Aug 2025 18:26:59 GMT
      Modal-Function-Call-Id:
      - fc-01K2JAWC7GAMKX4DSJ21WV31NG
      Vary:
      - accept-encoding
    status:
      code: 200
      message: OK
version: 1

@@ -0,0 +1,84 @@
interactions:
- request:
    body: '{"audio_file_url": "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3",
      "language": "en", "batch": true}'
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      authorization:
      - DUMMY_API_KEY
      connection:
      - keep-alive
      content-length:
      - '136'
      content-type:
      - application/json
      host:
      - monadical-sas--reflector-transcriber-parakeet-web.modal.run
      user-agent:
      - python-httpx/0.27.2
    method: POST
    uri: https://monadical-sas--reflector-transcriber-parakeet-web.modal.run/v1/audio/transcriptions-from-url
  response:
    body:
      string: '{"text":"Hi there everyone. Today I want to share my incredible experience
        with Reflector. a Q teenage product that revolutionizes audio processing.
        With reflector, I can easily convert any audio into accurate transcription.
        saving me hours of tedious manual work.","words":[{"word":"Hi","start":0.87,"end":1.19},{"word":"there","start":1.19,"end":1.35},{"word":"everyone.","start":1.51,"end":1.83},{"word":"Today","start":2.63,"end":2.87},{"word":"I","start":3.36,"end":3.52},{"word":"want","start":3.6,"end":3.76},{"word":"to","start":3.76,"end":3.92},{"word":"share","start":3.92,"end":4.16},{"word":"my","start":4.16,"end":4.4},{"word":"incredible","start":4.32,"end":4.96},{"word":"experience","start":4.96,"end":5.44},{"word":"with","start":5.44,"end":5.68},{"word":"Reflector.","start":5.68,"end":6.24},{"word":"a","start":6.93,"end":7.01},{"word":"Q","start":7.01,"end":7.17},{"word":"teenage","start":7.25,"end":7.65},{"word":"product","start":7.89,"end":8.29},{"word":"that","start":8.29,"end":8.61},{"word":"revolutionizes","start":8.61,"end":9.65},{"word":"audio","start":9.65,"end":10.05},{"word":"processing.","start":10.05,"end":10.53},{"word":"With","start":11.27,"end":11.43},{"word":"reflector,","start":11.51,"end":12.15},{"word":"I","start":12.31,"end":12.39},{"word":"can","start":12.39,"end":12.55},{"word":"easily","start":12.55,"end":12.95},{"word":"convert","start":12.95,"end":13.43},{"word":"any","start":13.43,"end":13.67},{"word":"audio","start":13.67,"end":13.99},{"word":"into","start":14.98,"end":15.06},{"word":"accurate","start":15.22,"end":15.54},{"word":"transcription.","start":15.7,"end":16.34},{"word":"saving","start":16.99,"end":17.15},{"word":"me","start":17.31,"end":17.47},{"word":"hours","start":17.47,"end":17.87},{"word":"of","start":17.87,"end":18.11},{"word":"tedious","start":18.11,"end":18.67},{"word":"manual","start":18.67,"end":19.07},{"word":"work.","start":19.07,"end":19.31}]}'
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000
      Content-Length:
      - '1933'
      Content-Type:
      - application/json
      Date:
      - Wed, 13 Aug 2025 18:27:02 GMT
      Modal-Function-Call-Id:
      - fc-01K2JAYZ1AR2HE422VJVKBWX9Z
      Vary:
      - accept-encoding
    status:
      code: 200
      message: OK
- request:
    body: ''
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      authorization:
      - DUMMY_API_KEY
      connection:
      - keep-alive
      content-length:
      - '0'
      host:
      - monadical-sas--reflector-diarizer-web.modal.run
      user-agent:
      - python-httpx/0.27.2
    method: POST
    uri: https://monadical-sas--reflector-diarizer-web.modal.run/diarize?audio_file_url=https%3A%2F%2Freflector-github-pytest.s3.us-east-1.amazonaws.com%2Ftest_mathieu_hello.mp3&timestamp=0
  response:
    body:
      string: '{"diarization":[{"start":0.823,"end":1.91,"speaker":0},{"start":2.572,"end":6.409,"speaker":0},{"start":6.783,"end":10.62,"speaker":0},{"start":11.231,"end":14.168,"speaker":0},{"start":14.796,"end":19.295,"speaker":0}]}'
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000
      Content-Length:
      - '220'
      Content-Type:
      - application/json
      Date:
      - Wed, 13 Aug 2025 18:27:18 GMT
      Modal-Function-Call-Id:
      - fc-01K2JAZ1M34NQRJK03CCFK95D6
      Vary:
      - accept-encoding
    status:
      code: 200
      message: OK
version: 1

@@ -5,7 +5,29 @@ from unittest.mock import patch
 import pytest
 
 
-# Pytest-docker configuration
+@pytest.fixture(scope="session", autouse=True)
+def settings_configuration():
+    # these settings are linked to monadical for pytest-recording;
+    # a fork has to provide its own URLs when cassettes need to be updated
+    # modal api keys have to be defined by the user
+    from reflector.settings import settings
+
+    settings.TRANSCRIPT_BACKEND = "modal"
+    settings.TRANSCRIPT_URL = (
+        "https://monadical-sas--reflector-transcriber-parakeet-web.modal.run"
+    )
+    settings.DIARIZATION_BACKEND = "modal"
+    settings.DIARIZATION_URL = "https://monadical-sas--reflector-diarizer-web.modal.run"
+
+
+@pytest.fixture(scope="module")
+def vcr_config():
+    """VCR configuration to filter sensitive headers"""
+    return {
+        "filter_headers": [("authorization", "DUMMY_API_KEY")],
+    }
+
+
 @pytest.fixture(scope="session")
 def docker_compose_file(pytestconfig):
     return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")
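
These fixtures are what pytest-recording keys off: `vcr_config` rewrites the Authorization header so cassettes never store real keys, and the pinned Monadical URLs only matter when cassettes are (re)recorded. A typical invocation might look like this (paths and record mode are illustrative, not from the PR):

```
# replay existing cassettes, no network needed
uv run -m pytest tests/

# re-record cassettes against the live Modal endpoints
TRANSCRIPT_MODAL_API_KEY=... DIARIZATION_MODAL_API_KEY=... \
    uv run -m pytest --record-mode=rewrite tests/
```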

@@ -1,7 +1,7 @@
-version: '3.8'
+version: "3.8"
 services:
   postgres_test:
-    image: postgres:15
+    image: postgres:17
     environment:
       POSTGRES_DB: reflector_test
       POSTGRES_USER: test_user

server/tests/test_gpu_modal_transcript.py (new file, 330 lines)
@@ -0,0 +1,330 @@
"""
Tests for GPU Modal transcription endpoints.

These tests are marked with the "gpu_modal" group and will not run by default.
Run them with: pytest -m gpu_modal tests/test_gpu_modal_transcript.py

Required environment variables:
- TRANSCRIPT_URL: URL to the Modal.com endpoint (required)
- TRANSCRIPT_MODAL_API_KEY: API key for authentication (optional)
- TRANSCRIPT_MODEL: Model name to use (optional, defaults to nvidia/parakeet-tdt-0.6b-v2)

Example with pytest (override default addopts to run ONLY gpu_modal tests):
    TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web-dev.modal.run \
    TRANSCRIPT_MODAL_API_KEY=your-api-key \
    uv run -m pytest -m gpu_modal --no-cov tests/test_gpu_modal_transcript.py

    # Or with completely clean options:
    uv run -m pytest -m gpu_modal -o addopts="" tests/

Running Modal locally for testing:
    modal serve gpu/modal_deployments/reflector_transcriber_parakeet.py
    # This will give you a local URL like https://xxxxx--reflector-transcriber-parakeet-web-dev.modal.run to test against
"""

import os
import tempfile
from pathlib import Path

import httpx
import pytest

# Test audio file URL for testing
TEST_AUDIO_URL = (
    "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
)


def get_modal_transcript_url():
    """Get and validate the Modal transcript URL from environment."""
    url = os.environ.get("TRANSCRIPT_URL")
    if not url:
        pytest.skip(
            "TRANSCRIPT_URL environment variable is required for GPU Modal tests"
        )
    return url


def get_auth_headers():
    """Get authentication headers if API key is available."""
    api_key = os.environ.get("TRANSCRIPT_MODAL_API_KEY")
    if api_key:
        return {"Authorization": f"Bearer {api_key}"}
    return {}


def get_model_name():
    """Get the model name from environment or use default."""
    return os.environ.get("TRANSCRIPT_MODEL", "nvidia/parakeet-tdt-0.6b-v2")


@pytest.mark.gpu_modal
class TestGPUModalTranscript:
    """Test suite for GPU Modal transcription endpoints."""

    def test_transcriptions_from_url(self):
        """Test the /v1/audio/transcriptions-from-url endpoint."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{url}/v1/audio/transcriptions-from-url",
                json={
                    "audio_file_url": TEST_AUDIO_URL,
                    "model": get_model_name(),
                    "language": "en",
                    "timestamp_offset": 0.0,
                },
                headers=headers,
            )

        assert response.status_code == 200, f"Request failed: {response.text}"
        result = response.json()

        # Verify response structure
        assert "text" in result
        assert "words" in result
        assert isinstance(result["text"], str)
        assert isinstance(result["words"], list)

        # Verify content is meaningful
        assert len(result["text"]) > 0, "Transcript text should not be empty"
        assert len(result["words"]) > 0, "Words list must not be empty"

        # Verify word structure
        for word in result["words"]:
            assert "word" in word
            assert "start" in word
            assert "end" in word
            assert isinstance(word["start"], (int, float))
            assert isinstance(word["end"], (int, float))
            assert word["start"] <= word["end"]

    def test_transcriptions_single_file(self):
        """Test the /v1/audio/transcriptions endpoint with a single file."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        # Download test audio file to upload
        with httpx.Client(timeout=60.0) as client:
            audio_response = client.get(TEST_AUDIO_URL)
            audio_response.raise_for_status()

            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
                tmp_file.write(audio_response.content)
                tmp_file_path = tmp_file.name

            try:
                # Upload the file for transcription
                with open(tmp_file_path, "rb") as f:
                    files = {"file": ("test_audio.mp3", f, "audio/mpeg")}
                    data = {
                        "model": get_model_name(),
                        "language": "en",
                        "batch": "false",
                    }

                    response = client.post(
                        f"{url}/v1/audio/transcriptions",
                        files=files,
                        data=data,
                        headers=headers,
                    )

                assert response.status_code == 200, f"Request failed: {response.text}"
                result = response.json()

                # Verify response structure for single file
                assert "text" in result
                assert "words" in result
                assert "filename" in result
                assert isinstance(result["text"], str)
                assert isinstance(result["words"], list)

                # Verify content
                assert len(result["text"]) > 0, "Transcript text should not be empty"

            finally:
                Path(tmp_file_path).unlink(missing_ok=True)

    def test_transcriptions_multiple_files(self):
        """Test the /v1/audio/transcriptions endpoint with multiple files (non-batch mode)."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        # Create multiple test files (we'll use the same audio content for simplicity)
        with httpx.Client(timeout=60.0) as client:
            audio_response = client.get(TEST_AUDIO_URL)
            audio_response.raise_for_status()
            audio_content = audio_response.content

            temp_files = []
            try:
                # Create 3 temporary files
                for i in range(3):
                    tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
                    tmp_file.write(audio_content)
                    tmp_file.close()
                    temp_files.append(tmp_file.name)

                # Upload multiple files for transcription (non-batch)
                files = [
                    ("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
                    for i, f in enumerate(temp_files)
                ]
                data = {
                    "model": get_model_name(),
                    "language": "en",
                    "batch": "false",
                }

                response = client.post(
                    f"{url}/v1/audio/transcriptions",
                    files=files,
                    data=data,
                    headers=headers,
                )

                # Close file handles
                for _, file_tuple in files:
                    file_tuple[1].close()

                assert response.status_code == 200, f"Request failed: {response.text}"
                result = response.json()

                # Verify response structure for multiple files (non-batch)
                assert "results" in result
                assert isinstance(result["results"], list)
                assert len(result["results"]) == 3

                for idx, file_result in enumerate(result["results"]):
                    assert "text" in file_result
                    assert "words" in file_result
                    assert "filename" in file_result
                    assert isinstance(file_result["text"], str)
                    assert isinstance(file_result["words"], list)
                    assert len(file_result["text"]) > 0

            finally:
                for f in temp_files:
                    Path(f).unlink(missing_ok=True)

    def test_transcriptions_multiple_files_batch(self):
        """Test the /v1/audio/transcriptions endpoint with multiple files in batch mode."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        # Create multiple test files
        with httpx.Client(timeout=60.0) as client:
            audio_response = client.get(TEST_AUDIO_URL)
            audio_response.raise_for_status()
            audio_content = audio_response.content

            temp_files = []
            try:
                # Create 3 temporary files
                for i in range(3):
                    tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
                    tmp_file.write(audio_content)
                    tmp_file.close()
                    temp_files.append(tmp_file.name)

                # Upload multiple files for batch transcription
                files = [
                    ("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
                    for i, f in enumerate(temp_files)
                ]
                data = {
                    "model": get_model_name(),
                    "language": "en",
                    "batch": "true",
                }

                response = client.post(
                    f"{url}/v1/audio/transcriptions",
                    files=files,
                    data=data,
                    headers=headers,
                )

                # Close file handles
                for _, file_tuple in files:
                    file_tuple[1].close()

                assert response.status_code == 200, f"Request failed: {response.text}"
                result = response.json()

                # Verify response structure for batch mode
                assert "results" in result
                assert isinstance(result["results"], list)
                assert len(result["results"]) == 3

                for idx, batch_result in enumerate(result["results"]):
                    assert "text" in batch_result
                    assert "words" in batch_result
                    assert "filename" in batch_result
                    assert isinstance(batch_result["text"], str)
                    assert isinstance(batch_result["words"], list)
                    assert len(batch_result["text"]) > 0

            finally:
                for f in temp_files:
                    Path(f).unlink(missing_ok=True)

    def test_transcriptions_error_handling(self):
        """Test error handling for invalid requests."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        with httpx.Client(timeout=60.0) as client:
            # Test with unsupported language
            response = client.post(
                f"{url}/v1/audio/transcriptions-from-url",
                json={
                    "audio_file_url": TEST_AUDIO_URL,
                    "model": get_model_name(),
                    "language": "fr",  # Parakeet only supports English
                    "timestamp_offset": 0.0,
                },
                headers=headers,
            )

        assert response.status_code == 400
        assert "only supports English" in response.text

    def test_transcriptions_with_timestamp_offset(self):
        """Test transcription with timestamp offset parameter."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        with httpx.Client(timeout=60.0) as client:
            # Test with timestamp offset
            response = client.post(
                f"{url}/v1/audio/transcriptions-from-url",
                json={
                    "audio_file_url": TEST_AUDIO_URL,
                    "model": get_model_name(),
                    "language": "en",
                    "timestamp_offset": 10.0,  # Add 10 second offset
                },
                headers=headers,
            )

        assert response.status_code == 200, f"Request failed: {response.text}"
        result = response.json()

        # Verify response structure
        assert "text" in result
        assert "words" in result
        assert len(result["words"]) > 0, "Words list must not be empty"

        # Verify that timestamps have been offset
        for word in result["words"]:
            # All timestamps should be >= 10.0 due to offset
            assert (
                word["start"] >= 10.0
            ), f"Word start time {word['start']} should be >= 10.0"
            assert (
                word["end"] >= 10.0
            ), f"Word end time {word['end']} should be >= 10.0"

server/tests/test_pipeline_main_file.py (new file, 633 lines; the diff is truncated below)
@@ -0,0 +1,633 @@
|
|||||||
|
"""
|
||||||
|
Tests for PipelineMainFile - file-based processing pipeline
|
||||||
|
|
||||||
|
This test verifies the complete file processing pipeline without mocking much,
|
||||||
|
ensuring all processors are correctly invoked and the happy path works correctly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||||
|
from reflector.processors.file_diarization import FileDiarizationOutput
|
||||||
|
from reflector.processors.types import (
|
||||||
|
DiarizationSegment,
|
||||||
|
TitleSummary,
|
||||||
|
Word,
|
||||||
|
)
|
||||||
|
from reflector.processors.types import (
|
||||||
|
Transcript as TranscriptType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def dummy_file_transcript():
|
||||||
|
"""Mock FileTranscriptAutoProcessor for file processing"""
|
||||||
|
from reflector.processors.file_transcript import FileTranscriptProcessor
|
||||||
|
|
||||||
|
class TestFileTranscriptProcessor(FileTranscriptProcessor):
|
||||||
|
async def _transcript(self, data):
|
||||||
|
return TranscriptType(
|
||||||
|
text="Hello world. How are you today?",
|
||||||
|
words=[
|
||||||
|
Word(start=0.0, end=0.5, text="Hello", speaker=0),
|
||||||
|
Word(start=0.5, end=0.6, text=" ", speaker=0),
|
||||||
|
Word(start=0.6, end=1.0, text="world", speaker=0),
|
||||||
|
Word(start=1.0, end=1.1, text=".", speaker=0),
|
||||||
|
Word(start=1.1, end=1.2, text=" ", speaker=0),
|
||||||
|
Word(start=1.2, end=1.5, text="How", speaker=0),
|
||||||
|
Word(start=1.5, end=1.6, text=" ", speaker=0),
|
||||||
|
Word(start=1.6, end=1.8, text="are", speaker=0),
|
||||||
|
Word(start=1.8, end=1.9, text=" ", speaker=0),
|
||||||
|
Word(start=1.9, end=2.1, text="you", speaker=0),
|
||||||
|
Word(start=2.1, end=2.2, text=" ", speaker=0),
|
||||||
|
Word(start=2.2, end=2.5, text="today", speaker=0),
|
||||||
|
Word(start=2.5, end=2.6, text="?", speaker=0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__"
|
||||||
|
) as mock_auto:
|
||||||
|
mock_auto.return_value = TestFileTranscriptProcessor()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def dummy_file_diarization():
|
||||||
|
"""Mock FileDiarizationAutoProcessor for file processing"""
|
||||||
|
from reflector.processors.file_diarization import FileDiarizationProcessor
|
||||||
|
|
||||||
|
class TestFileDiarizationProcessor(FileDiarizationProcessor):
|
||||||
|
async def _diarize(self, data):
|
||||||
|
return FileDiarizationOutput(
|
||||||
|
diarization=[
|
||||||
|
DiarizationSegment(start=0.0, end=1.1, speaker=0),
|
||||||
|
DiarizationSegment(start=1.2, end=2.6, speaker=1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__"
|
||||||
|
) as mock_auto:
|
||||||
|
mock_auto.return_value = TestFileDiarizationProcessor()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def mock_transcript_in_db(tmpdir):
    """Create a mock transcript in the database"""
    from reflector.db.transcripts import Transcript
    from reflector.settings import settings

    # Set the DATA_DIR to our tmpdir
    original_data_dir = settings.DATA_DIR
    settings.DATA_DIR = str(tmpdir)

    transcript_id = str(uuid4())
    data_path = Path(tmpdir) / transcript_id
    data_path.mkdir(parents=True, exist_ok=True)

    # Create mock transcript object
    transcript = Transcript(
        id=transcript_id,
        name="Test Transcript",
        status="processing",
        source_kind="file",
        source_language="en",
        target_language="en",
    )

    # Mock the controllers to return our transcript
    try:
        with patch(
            "reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
        ) as mock_get:
            mock_get.return_value = transcript
            with patch(
                "reflector.pipelines.main_live_pipeline.transcripts_controller.get_by_id"
            ) as mock_get2:
                mock_get2.return_value = transcript
                with patch(
                    "reflector.pipelines.main_live_pipeline.transcripts_controller.update"
                ) as mock_update:
                    mock_update.return_value = None
                    yield transcript
    finally:
        # Restore original DATA_DIR
        settings.DATA_DIR = original_data_dir


@pytest.fixture
async def mock_storage():
    """Mock storage for file uploads"""
    from reflector.storage.base import Storage

    class TestStorage(Storage):
        async def _put_file(self, path, data):
            return None

        async def _get_file_url(self, path):
            return f"http://test-storage/{path}"

        async def _get_file(self, path):
            return b"test_audio_data"

        async def _delete_file(self, path):
            return None

    storage = TestStorage()
    # Add mock tracking for verification
    storage._put_file = AsyncMock(side_effect=storage._put_file)
    storage._get_file_url = AsyncMock(side_effect=storage._get_file_url)

    with patch(
        "reflector.pipelines.main_file_pipeline.get_transcripts_storage"
    ) as mock_get:
        mock_get.return_value = storage
        yield storage


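# --- Illustrative sketch (not part of the original test module) -------------
# mock_storage wraps the real coroutine methods with AsyncMock(side_effect=...):
# AsyncMock awaits an async side_effect, so the original behavior is preserved
# while calls and arguments are recorded for the assertions later in this file.
class _Store:
    async def put(self, path, data):
        return f"stored:{path}"


async def _demo_wrap_and_record():
    store = _Store()
    store.put = AsyncMock(side_effect=store.put)  # wrap the bound method
    result = await store.put("audio.mp3", b"...")
    assert result == "stored:audio.mp3"  # original behavior still runs
    store.put.assert_awaited_once_with("audio.mp3", b"...")  # and is recorded

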
@pytest.fixture
async def mock_audio_file_writer():
    """Mock AudioFileWriterProcessor to avoid actual file writing"""
    with patch(
        "reflector.pipelines.main_file_pipeline.AudioFileWriterProcessor"
    ) as mock_writer_class:
        mock_writer = AsyncMock()
        mock_writer.push = AsyncMock()
        mock_writer.flush = AsyncMock()
        mock_writer_class.return_value = mock_writer
        yield mock_writer


@pytest.fixture
async def mock_waveform_processor():
    """Mock AudioWaveformProcessor"""
    with patch(
        "reflector.pipelines.main_file_pipeline.AudioWaveformProcessor"
    ) as mock_waveform_class:
        mock_waveform = AsyncMock()
        mock_waveform.set_pipeline = MagicMock()
        mock_waveform.flush = AsyncMock()
        mock_waveform_class.return_value = mock_waveform
        yield mock_waveform


@pytest.fixture
async def mock_topic_detector():
    """Mock TranscriptTopicDetectorProcessor"""
    with patch(
        "reflector.pipelines.main_file_pipeline.TranscriptTopicDetectorProcessor"
    ) as mock_topic_class:
        mock_topic = AsyncMock()
        mock_topic.set_pipeline = MagicMock()
        mock_topic.push = AsyncMock()
        mock_topic.flush_called = False

        # When flush is called, simulate topic detection by calling the callback
        async def flush_with_callback():
            mock_topic.flush_called = True
            if hasattr(mock_topic, "_callback"):
                # Create a minimal transcript for the TitleSummary
                test_transcript = TranscriptType(words=[], text="test transcript")
                await mock_topic._callback(
                    TitleSummary(
                        title="Test Topic",
                        summary="Test topic summary",
                        timestamp=0.0,
                        duration=10.0,
                        transcript=test_transcript,
                    )
                )

        mock_topic.flush = flush_with_callback

        def init_with_callback(callback=None):
            mock_topic._callback = callback
            return mock_topic

        mock_topic_class.side_effect = init_with_callback
        yield mock_topic


@pytest.fixture
async def mock_title_processor():
    """Mock TranscriptFinalTitleProcessor"""
    with patch(
        "reflector.pipelines.main_file_pipeline.TranscriptFinalTitleProcessor"
    ) as mock_title_class:
        mock_title = AsyncMock()
        mock_title.set_pipeline = MagicMock()
        mock_title.push = AsyncMock()
        mock_title.flush_called = False

        # When flush is called, simulate title generation by calling the callback
        async def flush_with_callback():
            mock_title.flush_called = True
            if hasattr(mock_title, "_callback"):
                from reflector.processors.types import FinalTitle

                await mock_title._callback(FinalTitle(title="Test Title"))

        mock_title.flush = flush_with_callback

        def init_with_callback(callback=None):
            mock_title._callback = callback
            return mock_title

        mock_title_class.side_effect = init_with_callback
        yield mock_title


@pytest.fixture
async def mock_summary_processor():
    """Mock TranscriptFinalSummaryProcessor"""
    with patch(
        "reflector.pipelines.main_file_pipeline.TranscriptFinalSummaryProcessor"
    ) as mock_summary_class:
        mock_summary = AsyncMock()
        mock_summary.set_pipeline = MagicMock()
        mock_summary.push = AsyncMock()
        mock_summary.flush_called = False

        # When flush is called, simulate summary generation by calling the callbacks
        async def flush_with_callback():
            mock_summary.flush_called = True
            from reflector.processors.types import FinalLongSummary, FinalShortSummary

            if hasattr(mock_summary, "_callback"):
                await mock_summary._callback(
                    FinalLongSummary(long_summary="Test long summary", duration=10.0)
                )
            if hasattr(mock_summary, "_on_short_summary"):
                await mock_summary._on_short_summary(
                    FinalShortSummary(short_summary="Test short summary", duration=10.0)
                )

        mock_summary.flush = flush_with_callback

        def init_with_callback(transcript=None, callback=None, on_short_summary=None):
            mock_summary._callback = callback
            mock_summary._on_short_summary = on_short_summary
            return mock_summary

        mock_summary_class.side_effect = init_with_callback
        yield mock_summary


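# --- Illustrative sketch (not part of the original test module) -------------
# The topic/title/summary fixtures all use the same trick: patch the processor
# class and set `side_effect` on the class mock, so the constructor call made
# by the pipeline is intercepted, its `callback=` argument captured, and the
# shared mock instance handed back. A fake flush() can then replay the callback.
class _Processor:  # stand-in for the Transcript*Processor classes
    def __init__(self, callback=None):
        self.callback = callback


async def _demo_callback_capture():
    mock_proc = AsyncMock()

    def init_with_callback(callback=None):
        mock_proc._callback = callback
        return mock_proc

    with patch(f"{__name__}._Processor") as mock_class:
        mock_class.side_effect = init_with_callback
        instance = _Processor(callback=AsyncMock())
        assert instance is mock_proc  # the pipeline would receive the mock
        await mock_proc._callback("fake event")  # a fake flush() could do this

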
@pytest.mark.asyncio
async def test_pipeline_main_file_process(
    tmpdir,
    mock_transcript_in_db,
    dummy_file_transcript,
    dummy_file_diarization,
    mock_storage,
    mock_audio_file_writer,
    mock_waveform_processor,
    mock_topic_detector,
    mock_title_processor,
    mock_summary_processor,
):
    """
    Test the complete PipelineMainFile processing pipeline.

    This test verifies:
    1. Audio extraction and writing
    2. Audio upload to storage
    3. Parallel processing of transcription, diarization, and waveform
    4. Assembly of transcript with diarization
    5. Topic detection
    6. Title and summary generation
    """
    # Create a test audio file
    test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"

    # Copy test audio to the transcript's data path as if it was uploaded
    upload_path = mock_transcript_in_db.data_path / "upload.wav"
    upload_path.write_bytes(test_audio_path.read_bytes())

    # Also create the audio.mp3 file that would be created by AudioFileWriterProcessor.
    # Since we're mocking AudioFileWriterProcessor, we need to create this manually.
    mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
    mp3_path.write_bytes(b"mock_mp3_data")

    # Track callback invocations
    callback_marks = {
        "on_status": [],
        "on_duration": [],
        "on_waveform": [],
        "on_topic": [],
        "on_title": [],
        "on_long_summary": [],
        "on_short_summary": [],
    }

    # Create pipeline with mocked callbacks
    pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)

    # Override callbacks to track invocations
    async def track_callback(name, data):
        callback_marks[name].append(data)
        # Call the original callback
        original = getattr(PipelineMainFile, name)
        return await original(pipeline, data)

    for callback_name in callback_marks.keys():
        setattr(
            pipeline,
            callback_name,
            lambda data, n=callback_name: track_callback(n, data),
        )

    # Mock av.open for audio processing
    with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
        # Mock container for checking video streams
        mock_container = MagicMock()
        mock_container.streams.video = []  # No video streams (audio only)
        mock_container.close = MagicMock()

        # Mock container for decoding audio frames
        mock_decode_container = MagicMock()
        mock_decode_container.decode.return_value = iter(
            [MagicMock()]
        )  # One mock audio frame
        mock_decode_container.close = MagicMock()

        # Return different containers for different calls
        mock_av.side_effect = [mock_container, mock_decode_container]

        # Run the pipeline
        await pipeline.process(upload_path)

    # Verify audio extraction and writing
    assert mock_audio_file_writer.push.called
    assert mock_audio_file_writer.flush.called

    # Verify storage upload
    assert mock_storage._put_file.called
    assert mock_storage._get_file_url.called

    # Verify waveform generation
    assert mock_waveform_processor.flush.called
    assert mock_waveform_processor.set_pipeline.called

    # Verify topic detection
    assert mock_topic_detector.push.called
    assert mock_topic_detector.flush_called

    # Verify title generation
    assert mock_title_processor.push.called
    assert mock_title_processor.flush_called

    # Verify summary generation
    assert mock_summary_processor.push.called
    assert mock_summary_processor.flush_called

    # Verify callbacks were invoked
    assert len(callback_marks["on_topic"]) > 0, "Topic callback should be invoked"
    assert len(callback_marks["on_title"]) > 0, "Title callback should be invoked"
    assert (
        len(callback_marks["on_long_summary"]) > 0
    ), "Long summary callback should be invoked"
    assert (
        len(callback_marks["on_short_summary"]) > 0
    ), "Short summary callback should be invoked"

    print(f"Callback marks: {callback_marks}")

    # Verify the pipeline completed successfully
    assert pipeline.logger is not None
    print("PipelineMainFile test completed successfully!")


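# --- Illustrative note (not part of the original test module) ----------------
# The `lambda data, n=callback_name: ...` above is deliberate: a default
# argument is evaluated when the lambda is created, while a free variable is
# looked up when the lambda is called. Without the default, every lambda would
# close over the same loop variable and record under the *last* callback name:
def _demo_late_binding():
    names = ["on_topic", "on_title"]
    wrong = {n: (lambda data: n) for n in names}
    right = {n: (lambda data, n=n: n) for n in names}
    assert wrong["on_topic"]("x") == "on_title"  # late-binding bug
    assert right["on_topic"]("x") == "on_topic"  # captured at creation time

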
@pytest.mark.asyncio
async def test_pipeline_main_file_with_video(
    tmpdir,
    mock_transcript_in_db,
    dummy_file_transcript,
    dummy_file_diarization,
    mock_storage,
    mock_audio_file_writer,
    mock_waveform_processor,
    mock_topic_detector,
    mock_title_processor,
    mock_summary_processor,
):
    """
    Test PipelineMainFile with video input (verifies audio extraction).
    """
    # Create a test audio file
    test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"

    # Copy test audio to the transcript's data path as if it was a video upload
    upload_path = mock_transcript_in_db.data_path / "upload.mp4"
    upload_path.write_bytes(test_audio_path.read_bytes())

    # Also create the audio.mp3 file that would be created by AudioFileWriterProcessor
    mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
    mp3_path.write_bytes(b"mock_mp3_data")

    # Create pipeline
    pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)

    # Mock av.open for video processing
    with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
        # Mock container for checking video streams
        mock_container = MagicMock()
        mock_container.streams.video = [MagicMock()]  # Has video streams
        mock_container.close = MagicMock()

        # Mock container for decoding audio frames
        mock_decode_container = MagicMock()
        mock_decode_container.decode.return_value = iter(
            [MagicMock()]
        )  # One mock audio frame
        mock_decode_container.close = MagicMock()

        # Return different containers for different calls
        mock_av.side_effect = [mock_container, mock_decode_container]

        # Run the pipeline
        await pipeline.process(upload_path)

    # Verify audio extraction from video
    assert mock_audio_file_writer.push.called
    assert mock_audio_file_writer.flush.called

    # Verify the rest of the pipeline completed
    assert mock_storage._put_file.called
    assert mock_waveform_processor.flush.called
    assert mock_topic_detector.push.called
    assert mock_title_processor.push.called
    assert mock_summary_processor.push.called

    print("PipelineMainFile video test completed successfully!")


@pytest.mark.asyncio
async def test_pipeline_main_file_no_diarization(
    tmpdir,
    mock_transcript_in_db,
    dummy_file_transcript,
    mock_storage,
    mock_audio_file_writer,
    mock_waveform_processor,
    mock_topic_detector,
    mock_title_processor,
    mock_summary_processor,
):
    """
    Test PipelineMainFile with diarization disabled.
    """
    from reflector.settings import settings

    # Disable diarization
    with patch.object(settings, "DIARIZATION_BACKEND", None):
        # Create a test audio file
        test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"

        # Copy test audio to the transcript's data path
        upload_path = mock_transcript_in_db.data_path / "upload.wav"
        upload_path.write_bytes(test_audio_path.read_bytes())

        # Also create the audio.mp3 file
        mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
        mp3_path.write_bytes(b"mock_mp3_data")

        # Create pipeline
        pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)

        # Mock av.open for audio processing
        with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
            # Mock container for checking video streams
            mock_container = MagicMock()
            mock_container.streams.video = []  # No video streams
            mock_container.close = MagicMock()

            # Mock container for decoding audio frames
            mock_decode_container = MagicMock()
            mock_decode_container.decode.return_value = iter([MagicMock()])
            mock_decode_container.close = MagicMock()

            # Return different containers for different calls
            mock_av.side_effect = [mock_container, mock_decode_container]

            # Run the pipeline
            await pipeline.process(upload_path)

        # Verify the pipeline completed without diarization
        assert mock_storage._put_file.called
        assert mock_waveform_processor.flush.called
        assert mock_topic_detector.push.called
        assert mock_title_processor.push.called
        assert mock_summary_processor.push.called

        print("PipelineMainFile no-diarization test completed successfully!")


@pytest.mark.asyncio
async def test_task_pipeline_file_process(
    tmpdir,
    mock_transcript_in_db,
    dummy_file_transcript,
    dummy_file_diarization,
    mock_storage,
    mock_audio_file_writer,
    mock_waveform_processor,
    mock_topic_detector,
    mock_title_processor,
    mock_summary_processor,
):
    """
    Test the Celery task entry point for file pipeline processing.
    """
    # Create a test audio file in the transcript's data path
    test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
    upload_path = mock_transcript_in_db.data_path / "upload.wav"
    upload_path.write_bytes(test_audio_path.read_bytes())

    # Also create the audio.mp3 file
    mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
    mp3_path.write_bytes(b"mock_mp3_data")

    # Mock av.open for audio processing
    with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
        # Mock container for checking video streams
        mock_container = MagicMock()
        mock_container.streams.video = []  # No video streams
        mock_container.close = MagicMock()

        # Mock container for decoding audio frames
        mock_decode_container = MagicMock()
        mock_decode_container.decode.return_value = iter([MagicMock()])
        mock_decode_container.close = MagicMock()

        # Return different containers for different calls
        mock_av.side_effect = [mock_container, mock_decode_container]

        # The Celery task wraps the underlying async function with the asynctask
        # decorator, which makes it awkward to call directly. Since the task is
        # just a thin wrapper, we exercise the pipeline itself here.
        pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
        await pipeline.process(upload_path)

    # Verify the pipeline was executed
    assert mock_audio_file_writer.push.called
    assert mock_audio_file_writer.flush.called
    assert mock_storage._put_file.called
    assert mock_waveform_processor.flush.called
    assert mock_topic_detector.push.called
    assert mock_title_processor.push.called
    assert mock_summary_processor.push.called

    print("task_pipeline_file_process test completed successfully!")


@pytest.mark.asyncio
async def test_pipeline_file_process_no_transcript():
    """
    Test the pipeline with a non-existent transcript.
    """
    # Mock the controller to return None (transcript not found)
    with patch(
        "reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
    ) as mock_get:
        mock_get.return_value = None

        pipeline = PipelineMainFile(transcript_id=str(uuid4()))

        # Should raise an exception for missing transcript when get_transcript is called
        with pytest.raises(Exception, match="Transcript not found"):
            await pipeline.get_transcript()


@pytest.mark.asyncio
async def test_pipeline_file_process_no_audio_file(
    mock_transcript_in_db,
):
    """
    Test the pipeline when no audio file is found.
    """
    # Don't create any audio files in the data path;
    # processing a missing file is expected to raise.

    pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)

    # Try to process a non-existent file
    non_existent_path = mock_transcript_in_db.data_path / "nonexistent.wav"

    # This should fail when trying to open the file with av
    with pytest.raises(Exception):
        await pipeline.process(non_existent_path)
server/tests/test_processors_modal.py (new file, 265 lines)
@@ -0,0 +1,265 @@
"""
|
||||||
|
Tests for Modal-based processors using pytest-recording for HTTP recording/playbook
|
||||||
|
|
||||||
|
Note: theses tests require full modal configuration to be able to record
|
||||||
|
vcr cassettes
|
||||||
|
|
||||||
|
Configuration required for the first recording:
|
||||||
|
- TRANSCRIPT_BACKEND=modal
|
||||||
|
- TRANSCRIPT_URL=https://xxxxx--reflector-transcriber-parakeet-web.modal.run
|
||||||
|
- TRANSCRIPT_MODAL_API_KEY=xxxxx
|
||||||
|
- DIARIZATION_BACKEND=modal
|
||||||
|
- DIARIZATION_URL=https://xxxxx--reflector-diarizer-web.modal.run
|
||||||
|
- DIARIZATION_MODAL_API_KEY=xxxxx
|
||||||
|
"""
|
||||||
|
|
||||||
|
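# --- Hypothetical conftest.py addition (not part of this diff) ---------------
# Recording these cassettes needs real Modal credentials. pytest-recording
# reads a `vcr_config` fixture and passes it to VCR, which can scrub auth
# headers so the keys never land in committed cassettes; recording itself is
# normally triggered with `pytest --record-mode=once`. The header names below
# are assumptions about what the Modal endpoints use.
import pytest


@pytest.fixture(scope="module")
def vcr_config():
    return {
        "filter_headers": ["authorization", "x-api-key"],  # scrub credentials
    }
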
from unittest.mock import patch

import pytest

from reflector.processors.file_diarization import FileDiarizationInput
from reflector.processors.file_diarization_modal import FileDiarizationModalProcessor
from reflector.processors.file_transcript import FileTranscriptInput
from reflector.processors.file_transcript_modal import FileTranscriptModalProcessor
from reflector.processors.transcript_diarization_assembler import (
    TranscriptDiarizationAssemblerInput,
    TranscriptDiarizationAssemblerProcessor,
)
from reflector.processors.types import DiarizationSegment, Transcript, Word

# Public test audio file hosted on S3 specifically for reflector pytests
TEST_AUDIO_URL = (
    "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
)


@pytest.mark.asyncio
async def test_file_transcript_modal_processor_missing_url():
    with patch("reflector.processors.file_transcript_modal.settings") as mock_settings:
        mock_settings.TRANSCRIPT_URL = None
        with pytest.raises(Exception, match="TRANSCRIPT_URL required"):
            FileTranscriptModalProcessor(modal_api_key="test-api-key")


@pytest.mark.asyncio
async def test_file_diarization_modal_processor_missing_url():
    with patch("reflector.processors.file_diarization_modal.settings") as mock_settings:
        mock_settings.DIARIZATION_URL = None
        with pytest.raises(Exception, match="DIARIZATION_URL required"):
            FileDiarizationModalProcessor(modal_api_key="test-api-key")


@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_file_diarization_modal_processor(vcr):
    """Test FileDiarizationModalProcessor using public audio URL and Modal API"""
    from reflector.settings import settings

    processor = FileDiarizationModalProcessor(
        modal_api_key=settings.DIARIZATION_MODAL_API_KEY
    )

    test_input = FileDiarizationInput(audio_url=TEST_AUDIO_URL)
    result = await processor._diarize(test_input)

    # Verify the result structure
    assert result is not None
    assert hasattr(result, "diarization")
    assert isinstance(result.diarization, list)

    # Check structure of each diarization segment
    for segment in result.diarization:
        assert "start" in segment
        assert "end" in segment
        assert "speaker" in segment
        assert isinstance(segment["start"], (int, float))
        assert isinstance(segment["end"], (int, float))
        assert isinstance(segment["speaker"], int)
        # Basic sanity check - start should be before end
        assert segment["start"] < segment["end"]


@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_file_transcript_modal_processor():
    """Test FileTranscriptModalProcessor using public audio URL and Modal API"""
    from reflector.settings import settings

    processor = FileTranscriptModalProcessor(
        modal_api_key=settings.TRANSCRIPT_MODAL_API_KEY
    )

    test_input = FileTranscriptInput(
        audio_url=TEST_AUDIO_URL,
        language="en",
    )

    # This will record the HTTP interaction on first run, replay on subsequent runs
    result = await processor._transcript(test_input)

    # Verify the result structure
    assert result is not None
    assert hasattr(result, "words")
    assert isinstance(result.words, list)

    # Check structure of each word if present
    for word in result.words:
        assert hasattr(word, "text")
        assert hasattr(word, "start")
        assert hasattr(word, "end")
        assert isinstance(word.start, (int, float))
        assert isinstance(word.end, (int, float))
        assert isinstance(word.text, str)
        # Basic sanity check - start should be before or equal to end
        assert word.start <= word.end


@pytest.mark.asyncio
async def test_transcript_diarization_assembler_processor():
    """Test TranscriptDiarizationAssemblerProcessor without VCR (no HTTP requests)"""
    # Create test transcript with words
    words = [
        Word(text="Hello", start=0.0, end=1.0, speaker=0),
        Word(text=" ", start=1.0, end=1.1, speaker=0),
        Word(text="world", start=1.1, end=2.0, speaker=0),
        Word(text=".", start=2.0, end=2.1, speaker=0),
        Word(text=" ", start=2.1, end=2.2, speaker=0),
        Word(text="How", start=2.2, end=2.8, speaker=0),
        Word(text=" ", start=2.8, end=2.9, speaker=0),
        Word(text="are", start=2.9, end=3.2, speaker=0),
        Word(text=" ", start=3.2, end=3.3, speaker=0),
        Word(text="you", start=3.3, end=3.8, speaker=0),
        Word(text="?", start=3.8, end=3.9, speaker=0),
    ]
    transcript = Transcript(words=words)

    # Create test diarization segments
    diarization = [
        DiarizationSegment(start=0.0, end=2.1, speaker=0),
        DiarizationSegment(start=2.1, end=3.9, speaker=1),
    ]

    # Create processor and test input
    processor = TranscriptDiarizationAssemblerProcessor()
    test_input = TranscriptDiarizationAssemblerInput(
        transcript=transcript, diarization=diarization
    )

    # Track emitted results
    emitted_results = []

    async def capture_result(result):
        emitted_results.append(result)

    processor.on(capture_result)

    # Process the input
    await processor.push(test_input)

    # Verify result was emitted
    assert len(emitted_results) == 1
    result = emitted_results[0]

    # Verify result structure
    assert isinstance(result, Transcript)
    assert len(result.words) == len(words)

    # Verify speaker assignments were applied:
    # words 0-3 fall in the first segment (0.0-2.1) and should be speaker 0,
    # words 4-10 fall in the second segment (2.1-3.9) and should be speaker 1.
    for i in range(4):  # First 4 words (Hello, space, world, .)
        assert (
            result.words[i].speaker == 0
        ), f"Word {i} '{result.words[i].text}' should be speaker 0, got {result.words[i].speaker}"

    for i in range(4, 11):  # Remaining words (space, How, space, are, space, you, ?)
        assert (
            result.words[i].speaker == 1
        ), f"Word {i} '{result.words[i].text}' should be speaker 1, got {result.words[i].speaker}"


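# --- Illustrative sketch (not part of the original test module) -------------
# The assertions above encode the assembly rule this test expects: each word
# takes the speaker of the diarization segment containing its start time. A
# minimal reference implementation of just that rule (the real assembler may
# treat overlaps and gaps differently):
def _assign_speakers(word_starts, segments):
    speakers = []
    for t in word_starts:
        for seg in segments:
            if seg["start"] <= t < seg["end"]:
                speakers.append(seg["speaker"])
                break
        else:
            speakers.append(None)  # no segment covers this word
    return speakers


def _demo_assignment():
    segments = [
        {"start": 0.0, "end": 2.1, "speaker": 0},
        {"start": 2.1, "end": 3.9, "speaker": 1},
    ]
    # word "." starts at 2.0 -> speaker 0; word " " starts at 2.1 -> speaker 1
    assert _assign_speakers([0.0, 2.0, 2.1, 3.8], segments) == [0, 0, 1, 1]

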
@pytest.mark.asyncio
async def test_transcript_diarization_assembler_no_diarization():
    """Test TranscriptDiarizationAssemblerProcessor with no diarization data"""
    # Create test transcript
    words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
    transcript = Transcript(words=words)

    # Create processor and test input with empty diarization
    processor = TranscriptDiarizationAssemblerProcessor()
    test_input = TranscriptDiarizationAssemblerInput(
        transcript=transcript, diarization=[]
    )

    # Track emitted results
    emitted_results = []

    async def capture_result(result):
        emitted_results.append(result)

    processor.on(capture_result)

    # Process the input
    await processor.push(test_input)

    # Verify original transcript was returned unchanged
    assert len(emitted_results) == 1
    result = emitted_results[0]
    assert result is transcript  # Should be the same object
    assert result.words[0].speaker == 0  # Original speaker unchanged


@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_full_modal_pipeline_integration(vcr):
    """Integration test: Transcription -> Diarization -> Assembly

    This test demonstrates the full pipeline:
    1. Run transcription via Modal
    2. Run diarization via Modal
    3. Assemble transcript with diarization
    """
    from reflector.settings import settings

    # Step 1: Transcription
    transcript_processor = FileTranscriptModalProcessor(
        modal_api_key=settings.TRANSCRIPT_MODAL_API_KEY
    )
    transcript_input = FileTranscriptInput(audio_url=TEST_AUDIO_URL, language="en")
    transcript = await transcript_processor._transcript(transcript_input)

    # Step 2: Diarization
    diarization_processor = FileDiarizationModalProcessor(
        modal_api_key=settings.DIARIZATION_MODAL_API_KEY
    )
    diarization_input = FileDiarizationInput(audio_url=TEST_AUDIO_URL)
    diarization_result = await diarization_processor._diarize(diarization_input)

    # Step 3: Assembly
    assembler = TranscriptDiarizationAssemblerProcessor()
    assembly_input = TranscriptDiarizationAssemblerInput(
        transcript=transcript, diarization=diarization_result.diarization
    )

    # Track assembled result
    assembled_results = []

    async def capture_result(result):
        assembled_results.append(result)

    assembler.on(capture_result)

    await assembler.push(assembly_input)

    # Verify the full pipeline worked
    assert len(assembled_results) == 1
    final_transcript = assembled_results[0]

    # Verify the final transcript has the original words with updated speaker info
    assert isinstance(final_transcript, Transcript)
    assert len(final_transcript.words) == len(transcript.words)
    assert len(final_transcript.words) > 0

    # Verify some words have been assigned speakers from diarization
    speakers_found = set(word.speaker for word in final_transcript.words)
    assert len(speakers_found) > 0  # At least some speaker assignments
@@ -2,10 +2,13 @@ import pytest

 @pytest.mark.asyncio
+@pytest.mark.parametrize("enable_diarization", [False, True])
 async def test_basic_process(
     dummy_transcript,
     dummy_llm,
     dummy_processors,
+    enable_diarization,
+    dummy_diarization,
 ):
     # goal is to start the server, and send rtc audio to it
     # validate the events received
@@ -28,12 +31,31 @@ async def test_basic_process(
     # invoke the process and capture events
     path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
-    await process_audio_file(path.as_posix(), event_callback)
-    print(marks)
+    if enable_diarization:
+        # Test with diarization - may fail if pyannote.audio is not installed
+        try:
+            await process_audio_file(
+                path.as_posix(), event_callback, enable_diarization=True
+            )
+        except SystemExit:
+            pytest.skip("pyannote.audio not installed - skipping diarization test")
+    else:
+        # Test without diarization - should always work
+        await process_audio_file(
+            path.as_posix(), event_callback, enable_diarization=False
+        )
+
+    print(f"Diarization: {enable_diarization}, Marks: {marks}")

     # validate the events
-    assert marks["TranscriptLinerProcessor"] == 1
-    assert marks["TranscriptTranslatorPassthroughProcessor"] == 1
+    # Each processor should be called for each audio segment processed
+    # The final processors (Topic, Title, Summary) should be called once at the end
+    assert marks["TranscriptLinerProcessor"] > 0
+    assert marks["TranscriptTranslatorPassthroughProcessor"] > 0
     assert marks["TranscriptTopicDetectorProcessor"] == 1
     assert marks["TranscriptFinalSummaryProcessor"] == 1
     assert marks["TranscriptFinalTitleProcessor"] == 1
+
+    if enable_diarization:
+        assert marks["TestAudioDiarizationProcessor"] == 1
server/uv.lock (generated, 937 lines): diff suppressed because it is too large