Files
reflector/server/tests/test_gpu_modal_transcript.py
Mathieu Virbel 3ea7f6b7b6 feat: pipeline improvement with file processing, parakeet, silero-vad (#540)
* feat: improve pipeline threading, and transcriber (parakeet and silero vad)

* refactor: remove whisperx, implement parakeet

* refactor: make audio_chunker smarter and wait for speech instead of a fixed frame

* refactor: make audio merge always downscale audio to 16k for transcription

* refactor: make the audio transcript modal accept batches

* refactor: improve type safety and remove prometheus metrics

- Add DiarizationSegment TypedDict for proper diarization typing
- Replace List/Optional with modern Python list/| None syntax
- Remove all Prometheus metrics from TranscriptDiarizationAssemblerProcessor
- Add comprehensive file processing pipeline with parallel execution
- Update processor imports and type annotations throughout
- Implement optimized file pipeline as default in process.py tool

* refactor: convert FileDiarizationProcessor I/O types to BaseModel

Update FileDiarizationInput and FileDiarizationOutput to inherit from
BaseModel instead of plain classes, following the standard pattern
used by other processors in the codebase.

* test: add tests for file transcript and diarization with pytest-recording

* build: add pytest-recording

* feat: add local pyannote for testing

* fix: replace PyAV AudioResampler with torchaudio for reliable audio processing

- Replace problematic PyAV AudioResampler that was causing ValueError: [Errno 22] Invalid argument
- Use torchaudio.functional.resample for robust sample rate conversion
- Optimize processing: skip conversion for already 16kHz mono audio
- Add direct WAV writing with Python wave module for better performance
- Consolidate duplicate downsample checks for cleaner code
- Maintain list[av.AudioFrame] input interface
- Required for Silero VAD which needs 16kHz mono audio

* fix: replace PyAV AudioResampler with torchaudio solution

- Resolves ValueError: [Errno 22] Invalid argument in AudioMergeProcessor
- Replaces problematic PyAV AudioResampler with torchaudio.functional.resample
- Optimizes processing to skip unnecessary conversions when audio is already 16kHz mono
- Uses direct WAV writing with Python's wave module for better performance
- Fixes test_basic_process to disable diarization (pyannote dependency not installed)
- Updates test expectations to match actual processor behavior
- Removes unused pydub dependency from pyproject.toml
- Adds comprehensive TEST_ANALYSIS.md documenting test suite status

* feat: add parameterized test for both diarization modes

- Adds @pytest.mark.parametrize to test_basic_process with enable_diarization=[False, True]
- Test with diarization=False always passes (tests core AudioMergeProcessor functionality)
- Test with diarization=True gracefully skips when pyannote.audio is not installed
- Provides comprehensive test coverage for both pipeline configurations
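The parameterized test shape described above could look roughly like this; the pipeline invocation itself is elided:

```python
import pytest


@pytest.mark.parametrize("enable_diarization", [False, True])
def test_basic_process(enable_diarization):
    if enable_diarization:
        # Gracefully skip this variant when pyannote.audio is not installed.
        pytest.importorskip("pyannote.audio")
    # ... run the pipeline with enable_diarization and assert on its output ...
```

`pytest.importorskip` turns a missing optional dependency into a skip rather than an error, so the diarization=False variant always runs.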

* fix: resolve pipeline property naming conflict in AudioDiarizationPyannoteProcessor

- Renames 'pipeline' property to 'diarization_pipeline' to avoid conflict with base Processor.pipeline attribute
- Fixes AttributeError: 'property 'pipeline' object has no setter' when set_pipeline() is called
- Updates property usage in _diarize method to use new name
- Now correctly supports pipeline initialization for diarization processing
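The naming conflict can be reproduced and resolved with a small sketch (class bodies are simplified stand-ins, not the real implementations):

```python
class Processor:
    """Minimal stand-in for the real base Processor class."""

    pipeline = None

    def set_pipeline(self, pipeline):
        # If a subclass defined a read-only property named `pipeline`,
        # this assignment would raise:
        #   AttributeError: property 'pipeline' object has no setter
        self.pipeline = pipeline


class AudioDiarizationPyannoteProcessor(Processor):
    def __init__(self):
        self._diarizer = None

    @property
    def diarization_pipeline(self):
        # Renamed from `pipeline` so it no longer shadows the base attribute.
        if self._diarizer is None:
            self._diarizer = "loaded-pyannote-pipeline"  # placeholder for lazy loading
        return self._diarizer


p = AudioDiarizationPyannoteProcessor()
p.set_pipeline("parent-pipeline")  # no longer raises
print(p.pipeline, p.diarization_pipeline)
```

A property without a setter is read-only, so any subclass property must avoid the names the base class assigns to.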

* fix: add local dependency for pyannote

* test: add diarization test

* fix: resampling on audio merge now works

* fix: correctly restore timestamp

* fix: display the exception in a threaded processor if one occurs

* Update pyproject.toml

* ci: remove option

* ci: update astral-sh/setup-uv

* test: add monadical url for pytest-recording

* refactor: remove previous version

* build: move faster-whisper to a local dependency

* test: fix missing import

* refactor: improve main_file_pipeline organization and error handling

- Move all imports to the top of the file
- Create unified EmptyPipeline class to replace duplicate mock pipeline code
- Remove timeout and fallback logic - let processors handle their own retries
- Fix error handling to raise any exception from parallel tasks
- Add proper type hints and validation for captured results
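The "raise any exception from parallel tasks" behavior can be sketched with `asyncio.gather`; the function and coroutine names here are illustrative, not from the codebase:

```python
import asyncio


async def run_parallel(tasks):
    """Run processor coroutines in parallel and re-raise the first failure."""
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        if isinstance(result, BaseException):
            raise result  # surface the error instead of silently dropping it
    return results


async def ok():
    return "done"


async def boom():
    raise RuntimeError("processor failed")


print(asyncio.run(run_parallel([ok(), ok()])))  # ['done', 'done']
try:
    asyncio.run(run_parallel([ok(), boom()]))
except RuntimeError as exc:
    print(exc)  # processor failed
```

With `return_exceptions=True`, gather waits for all tasks before the failure is re-raised, so partial results are not lost mid-flight.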

* fix: wrong function

* fix: remove task_done

* feat: add configurable file processing timeouts for modal processors

- Add TRANSCRIPT_FILE_TIMEOUT setting (default: 600s) for file transcription
- Add DIARIZATION_FILE_TIMEOUT setting (default: 600s) for file diarization
- Replace hardcoded timeout=600 with configurable settings in modal processors
- Allows customization of timeout values via environment variables

* fix: use logger

* fix: worker process meetings now use file pipeline

* fix: topic not gathered

* refactor: remove prepare(), pipeline now works

* refactor: implement many changes from Igor's review

* test: add test for test_pipeline_main_file

* refactor: remove doc

* doc: add doc

* ci: update build to use native arm64 builder

* fix: merge fixes

* refactor: changes from Igor's review + add test (not run by default) for the GPU Modal part

* ci: update to our own runner linux-amd64

* ci: try using suggested mode=min

* fix: update diarizer for latest modal, and use volume

* fix: modal file extension detection

* fix: put the diarizer as A100
2025-08-20 20:07:19 -06:00

331 lines
12 KiB
Python

"""
Tests for GPU Modal transcription endpoints.

These tests are marked with the "gpu_modal" marker and will not run by default.
Run them with: pytest -m gpu_modal tests/test_gpu_modal_transcript.py

Required environment variables:
- TRANSCRIPT_URL: URL to the Modal.com endpoint (required)
- TRANSCRIPT_MODAL_API_KEY: API key for authentication (optional)
- TRANSCRIPT_MODEL: Model name to use (optional, defaults to nvidia/parakeet-tdt-0.6b-v2)

Example with pytest (override default addopts to run ONLY gpu_modal tests):

    TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web-dev.modal.run \
    TRANSCRIPT_MODAL_API_KEY=your-api-key \
    uv run -m pytest -m gpu_modal --no-cov tests/test_gpu_modal_transcript.py

    # Or with completely clean options:
    uv run -m pytest -m gpu_modal -o addopts="" tests/

Running Modal locally for testing:

    modal serve gpu/modal_deployments/reflector_transcriber_parakeet.py

    # This gives you a local URL like
    # https://xxxxx--reflector-transcriber-parakeet-web-dev.modal.run to test against.
"""
import os
import tempfile
from pathlib import Path

import httpx
import pytest

# Test audio file URL for testing
TEST_AUDIO_URL = (
    "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
)

def get_modal_transcript_url():
    """Get and validate the Modal transcript URL from environment."""
    url = os.environ.get("TRANSCRIPT_URL")
    if not url:
        pytest.skip(
            "TRANSCRIPT_URL environment variable is required for GPU Modal tests"
        )
    return url


def get_auth_headers():
    """Get authentication headers if API key is available."""
    api_key = os.environ.get("TRANSCRIPT_MODAL_API_KEY")
    if api_key:
        return {"Authorization": f"Bearer {api_key}"}
    return {}


def get_model_name():
    """Get the model name from environment or use default."""
    return os.environ.get("TRANSCRIPT_MODEL", "nvidia/parakeet-tdt-0.6b-v2")

@pytest.mark.gpu_modal
class TestGPUModalTranscript:
    """Test suite for GPU Modal transcription endpoints."""

    def test_transcriptions_from_url(self):
        """Test the /v1/audio/transcriptions-from-url endpoint."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{url}/v1/audio/transcriptions-from-url",
                json={
                    "audio_file_url": TEST_AUDIO_URL,
                    "model": get_model_name(),
                    "language": "en",
                    "timestamp_offset": 0.0,
                },
                headers=headers,
            )

        assert response.status_code == 200, f"Request failed: {response.text}"
        result = response.json()

        # Verify response structure
        assert "text" in result
        assert "words" in result
        assert isinstance(result["text"], str)
        assert isinstance(result["words"], list)

        # Verify content is meaningful
        assert len(result["text"]) > 0, "Transcript text should not be empty"
        assert len(result["words"]) > 0, "Words list must not be empty"

        # Verify word structure
        for word in result["words"]:
            assert "word" in word
            assert "start" in word
            assert "end" in word
            assert isinstance(word["start"], (int, float))
            assert isinstance(word["end"], (int, float))
            assert word["start"] <= word["end"]

    def test_transcriptions_single_file(self):
        """Test the /v1/audio/transcriptions endpoint with a single file."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        # Download test audio file to upload
        with httpx.Client(timeout=60.0) as client:
            audio_response = client.get(TEST_AUDIO_URL)
            audio_response.raise_for_status()

            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
                tmp_file.write(audio_response.content)
                tmp_file_path = tmp_file.name

            try:
                # Upload the file for transcription
                with open(tmp_file_path, "rb") as f:
                    files = {"file": ("test_audio.mp3", f, "audio/mpeg")}
                    data = {
                        "model": get_model_name(),
                        "language": "en",
                        "batch": "false",
                    }
                    response = client.post(
                        f"{url}/v1/audio/transcriptions",
                        files=files,
                        data=data,
                        headers=headers,
                    )

                assert response.status_code == 200, f"Request failed: {response.text}"
                result = response.json()

                # Verify response structure for single file
                assert "text" in result
                assert "words" in result
                assert "filename" in result
                assert isinstance(result["text"], str)
                assert isinstance(result["words"], list)

                # Verify content
                assert len(result["text"]) > 0, "Transcript text should not be empty"
            finally:
                Path(tmp_file_path).unlink(missing_ok=True)

    def test_transcriptions_multiple_files(self):
        """Test the /v1/audio/transcriptions endpoint with multiple files (non-batch mode)."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        # Create multiple test files (we'll use the same audio content for simplicity)
        with httpx.Client(timeout=60.0) as client:
            audio_response = client.get(TEST_AUDIO_URL)
            audio_response.raise_for_status()
            audio_content = audio_response.content

            temp_files = []
            try:
                # Create 3 temporary files
                for i in range(3):
                    tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
                    tmp_file.write(audio_content)
                    tmp_file.close()
                    temp_files.append(tmp_file.name)

                # Upload multiple files for transcription (non-batch)
                files = [
                    ("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
                    for i, f in enumerate(temp_files)
                ]
                data = {
                    "model": get_model_name(),
                    "language": "en",
                    "batch": "false",
                }
                response = client.post(
                    f"{url}/v1/audio/transcriptions",
                    files=files,
                    data=data,
                    headers=headers,
                )

                # Close file handles
                for _, file_tuple in files:
                    file_tuple[1].close()

                assert response.status_code == 200, f"Request failed: {response.text}"
                result = response.json()

                # Verify response structure for multiple files (non-batch)
                assert "results" in result
                assert isinstance(result["results"], list)
                assert len(result["results"]) == 3

                for file_result in result["results"]:
                    assert "text" in file_result
                    assert "words" in file_result
                    assert "filename" in file_result
                    assert isinstance(file_result["text"], str)
                    assert isinstance(file_result["words"], list)
                    assert len(file_result["text"]) > 0
            finally:
                for f in temp_files:
                    Path(f).unlink(missing_ok=True)

    def test_transcriptions_multiple_files_batch(self):
        """Test the /v1/audio/transcriptions endpoint with multiple files in batch mode."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        # Create multiple test files
        with httpx.Client(timeout=60.0) as client:
            audio_response = client.get(TEST_AUDIO_URL)
            audio_response.raise_for_status()
            audio_content = audio_response.content

            temp_files = []
            try:
                # Create 3 temporary files
                for i in range(3):
                    tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
                    tmp_file.write(audio_content)
                    tmp_file.close()
                    temp_files.append(tmp_file.name)

                # Upload multiple files for batch transcription
                files = [
                    ("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
                    for i, f in enumerate(temp_files)
                ]
                data = {
                    "model": get_model_name(),
                    "language": "en",
                    "batch": "true",
                }
                response = client.post(
                    f"{url}/v1/audio/transcriptions",
                    files=files,
                    data=data,
                    headers=headers,
                )

                # Close file handles
                for _, file_tuple in files:
                    file_tuple[1].close()

                assert response.status_code == 200, f"Request failed: {response.text}"
                result = response.json()

                # Verify response structure for batch mode
                assert "results" in result
                assert isinstance(result["results"], list)
                assert len(result["results"]) == 3

                for batch_result in result["results"]:
                    assert "text" in batch_result
                    assert "words" in batch_result
                    assert "filename" in batch_result
                    assert isinstance(batch_result["text"], str)
                    assert isinstance(batch_result["words"], list)
                    assert len(batch_result["text"]) > 0
            finally:
                for f in temp_files:
                    Path(f).unlink(missing_ok=True)

    def test_transcriptions_error_handling(self):
        """Test error handling for invalid requests."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        with httpx.Client(timeout=60.0) as client:
            # Test with unsupported language
            response = client.post(
                f"{url}/v1/audio/transcriptions-from-url",
                json={
                    "audio_file_url": TEST_AUDIO_URL,
                    "model": get_model_name(),
                    "language": "fr",  # Parakeet only supports English
                    "timestamp_offset": 0.0,
                },
                headers=headers,
            )

        assert response.status_code == 400
        assert "only supports English" in response.text

    def test_transcriptions_with_timestamp_offset(self):
        """Test transcription with a timestamp offset parameter."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        with httpx.Client(timeout=60.0) as client:
            # Test with timestamp offset
            response = client.post(
                f"{url}/v1/audio/transcriptions-from-url",
                json={
                    "audio_file_url": TEST_AUDIO_URL,
                    "model": get_model_name(),
                    "language": "en",
                    "timestamp_offset": 10.0,  # Add a 10 second offset
                },
                headers=headers,
            )

        assert response.status_code == 200, f"Request failed: {response.text}"
        result = response.json()

        # Verify response structure
        assert "text" in result
        assert "words" in result
        assert len(result["words"]) > 0, "Words list must not be empty"

        # Verify that timestamps have been offset
        for word in result["words"]:
            # All timestamps should be >= 10.0 due to the offset
            assert (
                word["start"] >= 10.0
            ), f"Word start time {word['start']} should be >= 10.0"
            assert (
                word["end"] >= 10.0
            ), f"Word end time {word['end']} should be >= 10.0"