Compare commits

...

12 Commits

Author SHA1 Message Date
52f9f533d7 chore(main): release 0.7.2 (#559) 2025-08-21 21:00:05 -06:00
0c3878ac3c fix: docker image not loading libgomp.so.1 for torch (#560)
On ARM64, the docker iamge crash because torch cannot load libgomp.so.1
-- Look like pytorch does not install the same packages depending the
platform.

AMD64:

/app/.venv/lib/python3.12/site-packages/torch/lib/libgomp.so.1
/app/.venv/lib/python3.12/site-packages/ctranslate2.libs/libgomp-a34b3233.so.1.0.0
/app/.venv/lib/python3.12/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0

ARM64:

/app/.venv/lib/python3.12/site-packages/ctranslate2.libs/libgomp-d22c30c5.so.1.0.0
/app/.venv/lib/python3.12/site-packages/scikit_learn.libs/libgomp-947d5fa1.so.1.0.0
/app/.venv/lib/python3.12/site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0
2025-08-21 16:41:35 -06:00
Igor Loskutov
d70beee51b fix: include shared rooms to search (#558)
* include shared rooms to search

* tests vibe

* tests vibe

* tests vibe

* tests vibe

* tests vibe

* tests vibe

* tests vibe

* remove tests, thats too much
2025-08-21 14:52:29 -04:00
bc5b351d2b chore(main): release 0.7.1 (#557) 2025-08-20 23:23:27 -06:00
Igor Loskutov
07981e8090 fix: webvtt db null expectation mismatch (#556) 2025-08-20 23:22:41 -06:00
7e366f6338 chore(main): release 0.7.0 (#541) 2025-08-20 22:24:36 -06:00
7592679a35 build: separate silero-vad and force torch to be resolved without nvidia (#555)
* build: separate silero-vad and force torch to be resolved without nvidia

* build: also add torchaudio as cpu version
2025-08-20 22:23:48 -06:00
af16178f86 ci: use github-token to get around potential api throttling + rework dockerfile (#554)
* ci: use github-token to get around potential api throttling

* build: put pyannote-audio separate to the project

* fix: now that we have a readme, use it

* build: add UV_NO_CACHE
2025-08-20 21:59:29 -06:00
3ea7f6b7b6 feat: pipeline improvement with file processing, parakeet, silero-vad (#540)
* feat: improve pipeline threading, and transcriber (parakeet and silero vad)

* refactor: remove whisperx, implement parakeet

* refactor: make audio_chunker more smart and wait for speech, instead of fixed frame

* refactor: make audio merge to always downscale the audio to 16k for transcription

* refactor: make the audio transcript modal accepting batches

* refactor: improve type safety and remove prometheus metrics

- Add DiarizationSegment TypedDict for proper diarization typing
- Replace List/Optional with modern Python list/| None syntax
- Remove all Prometheus metrics from TranscriptDiarizationAssemblerProcessor
- Add comprehensive file processing pipeline with parallel execution
- Update processor imports and type annotations throughout
- Implement optimized file pipeline as default in process.py tool

* refactor: convert FileDiarizationProcessor I/O types to BaseModel

Update FileDiarizationInput and FileDiarizationOutput to inherit from
BaseModel instead of plain classes, following the standard pattern
used by other processors in the codebase.

* test: add tests for file transcript and diarization with pytest-recording

* build: add pytest-recording

* feat: add local pyannote for testing

* fix: replace PyAV AudioResampler with torchaudio for reliable audio processing

- Replace problematic PyAV AudioResampler that was causing ValueError: [Errno 22] Invalid argument
- Use torchaudio.functional.resample for robust sample rate conversion
- Optimize processing: skip conversion for already 16kHz mono audio
- Add direct WAV writing with Python wave module for better performance
- Consolidate duplicate downsample checks for cleaner code
- Maintain list[av.AudioFrame] input interface
- Required for Silero VAD which needs 16kHz mono audio

* fix: replace PyAV AudioResampler with torchaudio solution

- Resolves ValueError: [Errno 22] Invalid argument in AudioMergeProcessor
- Replaces problematic PyAV AudioResampler with torchaudio.functional.resample
- Optimizes processing to skip unnecessary conversions when audio is already 16kHz mono
- Uses direct WAV writing with Python's wave module for better performance
- Fixes test_basic_process to disable diarization (pyannote dependency not installed)
- Updates test expectations to match actual processor behavior
- Removes unused pydub dependency from pyproject.toml
- Adds comprehensive TEST_ANALYSIS.md documenting test suite status

* feat: add parameterized test for both diarization modes

- Adds @pytest.mark.parametrize to test_basic_process with enable_diarization=[False, True]
- Test with diarization=False always passes (tests core AudioMergeProcessor functionality)
- Test with diarization=True gracefully skips when pyannote.audio is not installed
- Provides comprehensive test coverage for both pipeline configurations

* fix: resolve pipeline property naming conflict in AudioDiarizationPyannoteProcessor

- Renames 'pipeline' property to 'diarization_pipeline' to avoid conflict with base Processor.pipeline attribute
- Fixes AttributeError: 'property 'pipeline' object has no setter' when set_pipeline() is called
- Updates property usage in _diarize method to use new name
- Now correctly supports pipeline initialization for diarization processing

* fix: add local for pyannote

* test: add diarization test

* fix: resample on audio merge now working

* fix: correctly restore timestamp

* fix: display exception in a threaded processor if that happen

* Update pyproject.toml

* ci: remove option

* ci: update astral-sh/setup-uv

* test: add monadical url for pytest-recording

* refactor: remove previous version

* build: move faster whisper to local dep

* test: fix missing import

* refactor: improve main_file_pipeline organization and error handling

- Move all imports to the top of the file
- Create unified EmptyPipeline class to replace duplicate mock pipeline code
- Remove timeout and fallback logic - let processors handle their own retries
- Fix error handling to raise any exception from parallel tasks
- Add proper type hints and validation for captured results

* fix: wrong function

* fix: remove task_done

* feat: add configurable file processing timeouts for modal processors

- Add TRANSCRIPT_FILE_TIMEOUT setting (default: 600s) for file transcription
- Add DIARIZATION_FILE_TIMEOUT setting (default: 600s) for file diarization
- Replace hardcoded timeout=600 with configurable settings in modal processors
- Allows customization of timeout values via environment variables

* fix: use logger

* fix: worker process meetings now use file pipeline

* fix: topic not gathered

* refactor: remove prepare(), pipeline now work

* refactor: implement many review from Igor

* test: add test for test_pipeline_main_file

* refactor: remove doc

* doc: add doc

* ci: update build to use native arm64 builder

* fix: merge fixes

* refactor: changes from Igor review + add test (not by default) to test gpu modal part

* ci: update to our own runner linux-amd64

* ci: try using suggested mode=min

* fix: update diarizer for latest modal, and use volume

* fix: modal file extension detection

* fix: put the diarizer as A100
2025-08-20 20:07:19 -06:00
Igor Loskutov
009590c080 feat: search frontend (#551)
* feat: better highlight

* feat(search): add long_summary to search vector for improved search results

- Update search vector to include long_summary with weight B (between title A and webvtt C)
- Modify SearchController to fetch long_summary and prioritize its snippets
- Generate snippets from long_summary first (max 2), then from webvtt for remaining slots
- Add comprehensive tests for long_summary search functionality
- Create migration to update search_vector_en column in PostgreSQL

This improves search quality by including summarized content which often contains
key topics and themes that may not be explicitly mentioned in the transcript.

* fix: address code review feedback for search enhancements

- Fix test file inconsistencies by removing references to non-existent model fields
  - Comment out tests for unimplemented features (room_ids, status filters, date ranges)
  - Update tests to only use currently available fields (room_id singular, no room_name/processing_status)
  - Mark future functionality tests with @pytest.mark.skip

- Make snippet counts configurable
  - Add LONG_SUMMARY_MAX_SNIPPETS constant (default: 2)
  - Replace hardcoded value with configurable constant

- Improve error handling consistency in WebVTT parsing
  - Use different log levels for different error types (debug for malformed, warning for decode, error for unexpected)
  - Add catch-all exception handler for unexpected errors
  - Include stack trace for critical errors

All existing tests pass with these changes.

* fix: correct datetime test to include required duration field

* feat: better highlight

* feat: search room names

* feat: acknowledge deleted room

* feat: search filters fix and rank removal

* chore: minor refactoring

* feat: better matches frontend

* chore: self-review (vibe)

* chore: self-review WIP

* chore: self-review WIP

* chore: self-review WIP

* chore: self-review WIP

* chore: self-review WIP

* chore: self-review WIP

* chore: self-review WIP

* remove swc (vibe)

* search url query sync (vibe)

* search url query sync (vibe)

* better casts and cap while

* PR review + simplify frontend hook

* pr: remove search db timeouts

* cleanup tests

* tests cleanup

* frontend cleanup

* index declarations

* refactor frontend (self-review)

* fix search pagination

* clear "x" for search input

* pagination max pages fix

* chore: cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* lockfile

* pr review
2025-08-20 20:56:45 -04:00
Igor Loskutov
fe5d344cff diarization cli: throw on modal errors (#553) 2025-08-20 10:21:52 -04:00
Igor Loskutov
86455ce573 chore: type fixes (#544)
* chore: type fixes

* chore: type fixes
2025-08-18 16:31:23 -04:00
72 changed files with 7441 additions and 834 deletions

View File

@@ -8,18 +8,30 @@ env:
ECR_REPOSITORY: reflector
jobs:
deploy:
runs-on: ubuntu-latest
build:
strategy:
matrix:
include:
- platform: linux/amd64
runner: linux-amd64
arch: amd64
- platform: linux/arm64
runner: linux-arm64
arch: arm64
runs-on: ${{ matrix.runner }}
permissions:
deployments: write
contents: read
outputs:
registry: ${{ steps.login-ecr.outputs.registry }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@0e613a0980cbf65ed5b322eb7a1e075d28913a83
uses: aws-actions/configure-aws-credentials@v4
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
@@ -27,21 +39,52 @@ jobs:
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@62f4f872db3836360b72999f4b87f1ff13310f3a
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: aws-actions/amazon-ecr-login@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Build and push
id: docker_build
uses: docker/build-push-action@v4
- name: Build and push ${{ matrix.arch }}
uses: docker/build-push-action@v5
with:
context: server
platforms: linux/amd64,linux/arm64
platforms: ${{ matrix.platform }}
push: true
tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest
cache-from: type=gha
cache-to: type=gha,mode=max
tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest-${{ matrix.arch }}
cache-from: type=gha,scope=${{ matrix.arch }}
cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
provenance: false
create-manifest:
runs-on: ubuntu-latest
needs: [build]
permissions:
deployments: write
contents: read
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Login to Amazon ECR
uses: aws-actions/amazon-ecr-login@v2
- name: Create and push multi-arch manifest
run: |
# Get the registry URL (since we can't easily access job outputs in matrix)
ECR_REGISTRY=$(aws ecr describe-registry --query 'registryId' --output text).dkr.ecr.${{ env.AWS_REGION }}.amazonaws.com
docker manifest create \
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest \
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-amd64 \
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-arm64
docker manifest push $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest
echo "✅ Multi-arch manifest pushed: $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest"

View File

@@ -19,29 +19,41 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
working-directory: server
- name: Tests
run: |
cd server
uv run -m pytest -v tests
docker:
runs-on: ubuntu-latest
docker-amd64:
runs-on: linux-amd64
steps:
- uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Build and push
id: docker_build
uses: docker/build-push-action@v4
uses: docker/setup-buildx-action@v3
- name: Build AMD64
uses: docker/build-push-action@v6
with:
context: server
platforms: linux/amd64,linux/arm64
cache-from: type=gha
cache-to: type=gha,mode=max
platforms: linux/amd64
cache-from: type=gha,scope=amd64
cache-to: type=gha,mode=max,scope=amd64
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
docker-arm64:
runs-on: linux-arm64
steps:
- uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build ARM64
uses: docker/build-push-action@v6
with:
context: server
platforms: linux/arm64
cache-from: type=gha,scope=arm64
cache-to: type=gha,mode=max,scope=arm64
github-token: ${{ secrets.GHA_CACHE_TOKEN }}

View File

@@ -1,5 +1,37 @@
# Changelog
## [0.7.2](https://github.com/Monadical-SAS/reflector/compare/v0.7.1...v0.7.2) (2025-08-21)
### Bug Fixes
* docker image not loading libgomp.so.1 for torch ([#560](https://github.com/Monadical-SAS/reflector/issues/560)) ([773fccd](https://github.com/Monadical-SAS/reflector/commit/773fccd93e887c3493abc2e4a4864dddce610177))
* include shared rooms to search ([#558](https://github.com/Monadical-SAS/reflector/issues/558)) ([499eced](https://github.com/Monadical-SAS/reflector/commit/499eced3360b84fb3a90e1c8a3b554290d21adc2))
## [0.7.1](https://github.com/Monadical-SAS/reflector/compare/v0.7.0...v0.7.1) (2025-08-21)
### Bug Fixes
* webvtt db null expectation mismatch ([#556](https://github.com/Monadical-SAS/reflector/issues/556)) ([e67ad1a](https://github.com/Monadical-SAS/reflector/commit/e67ad1a4a2054467bfeb1e0258fbac5868aaaf21))
## [0.7.0](https://github.com/Monadical-SAS/reflector/compare/v0.6.1...v0.7.0) (2025-08-21)
### Features
* delete recording with transcript ([#547](https://github.com/Monadical-SAS/reflector/issues/547)) ([99cc984](https://github.com/Monadical-SAS/reflector/commit/99cc9840b3f5de01e0adfbfae93234042d706d13))
* pipeline improvement with file processing, parakeet, silero-vad ([#540](https://github.com/Monadical-SAS/reflector/issues/540)) ([bcc29c9](https://github.com/Monadical-SAS/reflector/commit/bcc29c9e0050ae215f89d460e9d645aaf6a5e486))
* postgresql migration and removal of sqlite in pytest ([#546](https://github.com/Monadical-SAS/reflector/issues/546)) ([cd1990f](https://github.com/Monadical-SAS/reflector/commit/cd1990f8f0fe1503ef5069512f33777a73a93d7f))
* search backend ([#537](https://github.com/Monadical-SAS/reflector/issues/537)) ([5f9b892](https://github.com/Monadical-SAS/reflector/commit/5f9b89260c9ef7f3c921319719467df22830453f))
* search frontend ([#551](https://github.com/Monadical-SAS/reflector/issues/551)) ([3657242](https://github.com/Monadical-SAS/reflector/commit/365724271ca6e615e3425125a69ae2b46ce39285))
### Bug Fixes
* evaluation cli event wrap ([#536](https://github.com/Monadical-SAS/reflector/issues/536)) ([941c3db](https://github.com/Monadical-SAS/reflector/commit/941c3db0bdacc7b61fea412f3746cc5a7cb67836))
* use structlog not logging ([#550](https://github.com/Monadical-SAS/reflector/issues/550)) ([27e2f81](https://github.com/Monadical-SAS/reflector/commit/27e2f81fda5232e53edc729d3e99c5ef03adbfe9))
## [0.6.1](https://github.com/Monadical-SAS/reflector/compare/v0.6.0...v0.6.1) (2025-08-06)

3
server/.gitignore vendored
View File

@@ -176,7 +176,8 @@ artefacts/
audio_*.wav
# ignore local database
reflector.sqlite3
*.sqlite3
*.db
data/
dump.rdb

View File

@@ -1,7 +1,8 @@
FROM python:3.12-slim
ENV PYTHONUNBUFFERED=1 \
UV_LINK_MODE=copy
UV_LINK_MODE=copy \
UV_NO_CACHE=1
# builder install base dependencies
WORKDIR /tmp
@@ -13,8 +14,8 @@ ENV PATH="/root/.local/bin/:$PATH"
# install application dependencies
RUN mkdir -p /app
WORKDIR /app
COPY pyproject.toml uv.lock /app/
RUN touch README.md && env uv sync --compile-bytecode --locked
COPY pyproject.toml uv.lock README.md /app/
RUN uv sync --compile-bytecode --locked
# pre-download nltk packages
RUN uv run python -c "import nltk; nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
@@ -26,4 +27,15 @@ COPY migrations /app/migrations
COPY reflector /app/reflector
WORKDIR /app
# Create symlink for libgomp if it doesn't exist (for ARM64 compatibility)
RUN if [ "$(uname -m)" = "aarch64" ] && [ ! -f /usr/lib/libgomp.so.1 ]; then \
LIBGOMP_PATH=$(find /app/.venv/lib -path "*/torch.libs/libgomp*.so.*" 2>/dev/null | head -n1); \
if [ -n "$LIBGOMP_PATH" ]; then \
ln -sf "$LIBGOMP_PATH" /usr/lib/libgomp.so.1; \
fi \
fi
# Pre-check just to make sure the image will not fail
RUN uv run python -c "import silero_vad.model"
CMD ["./runserver.sh"]

View File

@@ -40,3 +40,5 @@ uv run python -c "from reflector.pipelines.main_live_pipeline import task_pipeli
```bash
uv run python -c "from reflector.pipelines.main_live_pipeline import pipeline_post; pipeline_post(transcript_id='TRANSCRIPT_ID')"
```
.

View File

@@ -4,7 +4,8 @@ This repository hold an API for the GPU implementation of the Reflector API serv
and use [Modal.com](https://modal.com)
- `reflector_diarizer.py` - Diarization API
- `reflector_transcriber.py` - Transcription API
- `reflector_transcriber.py` - Transcription API (Whisper)
- `reflector_transcriber_parakeet.py` - Transcription API (NVIDIA Parakeet)
- `reflector_translator.py` - Translation API
## Modal.com deployment
@@ -19,6 +20,10 @@ $ modal deploy reflector_transcriber.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run
$ modal deploy reflector_transcriber_parakeet.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-parakeet-web.modal.run
$ modal deploy reflector_llm.py
...
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
@@ -68,6 +73,86 @@ Authorization: bearer <REFLECTOR_APIKEY>
### Transcription
#### Parakeet Transcriber (`reflector_transcriber_parakeet.py`)
NVIDIA Parakeet is a state-of-the-art ASR model optimized for real-time transcription with superior word-level timestamps.
**GPU Configuration:**
- **A10G GPU** - Used for `/v1/audio/transcriptions` endpoint (small files, live transcription)
- Higher concurrency (max_inputs=10)
- Optimized for multiple small audio files
- Supports batch processing for efficiency
- **L40S GPU** - Used for `/v1/audio/transcriptions-from-url` endpoint (large files)
- Lower concurrency but more powerful processing
- Optimized for single large audio files
- VAD-based chunking for long-form audio
##### `/v1/audio/transcriptions` - Small file transcription
**request** (multipart/form-data)
- `file` or `files[]` - audio file(s) to transcribe
- `model` - model name (default: `nvidia/parakeet-tdt-0.6b-v2`)
- `language` - language code (default: `en`)
- `batch` - whether to use batch processing for multiple files (default: `true`)
**response**
```json
{
"text": "transcribed text",
"words": [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 0.5, "end": 1.0}
],
"filename": "audio.mp3"
}
```
For multiple files with batch=true:
```json
{
"results": [
{
"filename": "audio1.mp3",
"text": "transcribed text",
"words": [...]
},
{
"filename": "audio2.mp3",
"text": "transcribed text",
"words": [...]
}
]
}
```
##### `/v1/audio/transcriptions-from-url` - Large file transcription
**request** (application/json)
```json
{
"audio_file_url": "https://example.com/audio.mp3",
"model": "nvidia/parakeet-tdt-0.6b-v2",
"language": "en",
"timestamp_offset": 0.0
}
```
**response**
```json
{
"text": "transcribed text from large file",
"words": [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 0.5, "end": 1.0}
]
}
```
**Supported file types:** mp3, mp4, mpeg, mpga, m4a, wav, webm
#### Whisper Transcriber (`reflector_transcriber.py`)
`POST /transcribe`
**request** (multipart/form-data)

View File

@@ -4,14 +4,80 @@ Reflector GPU backend - diarizer
"""
import os
import uuid
from typing import Mapping, NewType
from urllib.parse import urlparse
import modal.gpu
from modal import App, Image, Secret, asgi_app, enter, method
from pydantic import BaseModel
import modal
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
MODEL_DIR = "/root/diarization_models"
app = App(name="reflector-diarizer")
UPLOADS_PATH = "/uploads"
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
DiarizerUniqFilename = NewType("DiarizerUniqFilename", str)
AudioFileExtension = NewType("AudioFileExtension", str)
app = modal.App(name="reflector-diarizer")
# Volume for temporary file uploads
upload_volume = modal.Volume.from_name("diarizer-uploads", create_if_missing=True)
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
parsed_url = urlparse(url)
url_path = parsed_url.path
for ext in SUPPORTED_FILE_EXTENSIONS:
if url_path.lower().endswith(f".{ext}"):
return AudioFileExtension(ext)
content_type = headers.get("content-type", "").lower()
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
return AudioFileExtension("mp3")
if "audio/wav" in content_type:
return AudioFileExtension("wav")
if "audio/mp4" in content_type:
return AudioFileExtension("mp4")
raise ValueError(
f"Unsupported audio format for URL: {url}. "
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
)
def download_audio_to_volume(
audio_file_url: str,
) -> tuple[DiarizerUniqFilename, AudioFileExtension]:
import requests
from fastapi import HTTPException
print(f"Checking audio file at: {audio_file_url}")
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(status_code=404, detail="Audio file not found")
print(f"Downloading audio file from: {audio_file_url}")
response = requests.get(audio_file_url, allow_redirects=True)
if response.status_code != 200:
print(f"Download failed with status {response.status_code}: {response.text}")
raise HTTPException(
status_code=response.status_code,
detail=f"Failed to download audio file: {response.status_code}",
)
audio_suffix = detect_audio_format(audio_file_url, response.headers)
unique_filename = DiarizerUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
file_path = f"{UPLOADS_PATH}/{unique_filename}"
print(f"Writing file to: {file_path} (size: {len(response.content)} bytes)")
with open(file_path, "wb") as f:
f.write(response.content)
upload_volume.commit()
print(f"File saved as: {unique_filename}")
return unique_filename, audio_suffix
def migrate_cache_llm():
@@ -39,7 +105,7 @@ def download_pyannote_audio():
diarizer_image = (
Image.debian_slim(python_version="3.10.8")
modal.Image.debian_slim(python_version="3.10.8")
.pip_install(
"pyannote.audio==3.1.0",
"requests",
@@ -55,7 +121,8 @@ diarizer_image = (
"hf-transfer",
)
.run_function(
download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")]
download_pyannote_audio,
secrets=[modal.Secret.from_name("hf_token")],
)
.run_function(migrate_cache_llm)
.env(
@@ -70,53 +137,60 @@ diarizer_image = (
@app.cls(
gpu=modal.gpu.A100(size="40GB"),
gpu="A100",
timeout=60 * 30,
scaledown_window=60,
allow_concurrent_inputs=1,
image=diarizer_image,
volumes={UPLOADS_PATH: upload_volume},
enable_memory_snapshot=True,
experimental_options={"enable_gpu_snapshot": True},
secrets=[
modal.Secret.from_name("hf_token"),
],
)
@modal.concurrent(max_inputs=1)
class Diarizer:
@enter()
@modal.enter(snap=True)
def enter(self):
import torch
from pyannote.audio import Pipeline
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
print(f"Using device: {self.device}")
self.diarization_pipeline = Pipeline.from_pretrained(
PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR
PYANNOTE_MODEL_NAME,
cache_dir=MODEL_DIR,
use_auth_token=os.environ["HF_TOKEN"],
)
self.diarization_pipeline.to(torch.device(self.device))
@method()
def diarize(self, audio_data: str, audio_suffix: str, timestamp: float):
import tempfile
@modal.method()
def diarize(self, filename: str, timestamp: float = 0.0):
import torchaudio
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
fp.write(audio_data)
upload_volume.reload()
print("Diarizing audio")
waveform, sample_rate = torchaudio.load(fp.name)
diarization = self.diarization_pipeline(
{"waveform": waveform, "sample_rate": sample_rate}
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
print(f"Diarizing audio from: {file_path}")
waveform, sample_rate = torchaudio.load(file_path)
diarization = self.diarization_pipeline(
{"waveform": waveform, "sample_rate": sample_rate}
)
words = []
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
words.append(
{
"start": round(timestamp + diarization_segment.start, 3),
"end": round(timestamp + diarization_segment.end, 3),
"speaker": int(speaker[-2:]),
}
)
words = []
for diarization_segment, _, speaker in diarization.itertracks(
yield_label=True
):
words.append(
{
"start": round(timestamp + diarization_segment.start, 3),
"end": round(timestamp + diarization_segment.end, 3),
"speaker": int(speaker[-2:]),
}
)
print("Diarization complete")
return {"diarization": words}
print("Diarization complete")
return {"diarization": words}
# -------------------------------------------------------------------
@@ -127,17 +201,18 @@ class Diarizer:
@app.function(
timeout=60 * 10,
scaledown_window=60 * 3,
allow_concurrent_inputs=40,
secrets=[
Secret.from_name("reflector-gpu"),
modal.Secret.from_name("reflector-gpu"),
],
volumes={UPLOADS_PATH: upload_volume},
image=diarizer_image,
)
@asgi_app()
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
def web():
import requests
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
diarizerstub = Diarizer()
@@ -153,35 +228,26 @@ def web():
headers={"WWW-Authenticate": "Bearer"},
)
def validate_audio_file(audio_file_url: str):
# Check if the audio file exists
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(
status_code=response.status_code,
detail="The audio file does not exist.",
)
class DiarizationResponse(BaseModel):
result: dict
@app.post(
"/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]
)
def diarize(
audio_file_url: str, timestamp: float = 0.0
) -> HTTPException | DiarizationResponse:
# Currently the uploaded files are in mp3 format
audio_suffix = "mp3"
@app.post("/diarize", dependencies=[Depends(apikey_auth)])
def diarize(audio_file_url: str, timestamp: float = 0.0) -> DiarizationResponse:
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
print("Downloading audio file")
response = requests.get(audio_file_url, allow_redirects=True)
print("Audio file downloaded successfully")
func = diarizerstub.diarize.spawn(
audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp
)
result = func.get()
return result
try:
func = diarizerstub.diarize.spawn(
filename=unique_filename, timestamp=timestamp
)
result = func.get()
return result
finally:
try:
file_path = f"{UPLOADS_PATH}/{unique_filename}"
print(f"Deleting file: {file_path}")
os.remove(file_path)
upload_volume.commit()
except Exception as e:
print(f"Error cleaning up {unique_filename}: {e}")
return app

View File

@@ -0,0 +1,622 @@
import logging
import os
import sys
import threading
import uuid
from typing import Mapping, NewType
from urllib.parse import urlparse
import modal
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
SAMPLERATE = 16000
UPLOADS_PATH = "/uploads"
CACHE_PATH = "/cache"
VAD_CONFIG = {
"max_segment_duration": 30.0,
"batch_max_files": 10,
"batch_max_duration": 5.0,
"min_segment_duration": 0.02,
"silence_padding": 0.5,
"window_size": 512,
}
ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
AudioFileExtension = NewType("AudioFileExtension", str)
app = modal.App("reflector-transcriber-parakeet")
# Volume for caching model weights
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
# Volume for temporary file uploads
upload_volume = modal.Volume.from_name("parakeet-uploads", create_if_missing=True)
image = (
modal.Image.from_registry(
"nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", add_python="3.12"
)
.env(
{
"HF_HUB_ENABLE_HF_TRANSFER": "1",
"HF_HOME": "/cache",
"DEBIAN_FRONTEND": "noninteractive",
"CXX": "g++",
"CC": "g++",
}
)
.apt_install("ffmpeg")
.pip_install(
"hf_transfer==0.1.9",
"huggingface_hub[hf-xet]==0.31.2",
"nemo_toolkit[asr]==2.3.0",
"cuda-python==12.8.0",
"fastapi==0.115.12",
"numpy<2",
"librosa==0.10.1",
"requests",
"silero-vad==5.1.0",
"torch",
)
.entrypoint([]) # silence chatty logs by container on start
)
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
parsed_url = urlparse(url)
url_path = parsed_url.path
for ext in SUPPORTED_FILE_EXTENSIONS:
if url_path.lower().endswith(f".{ext}"):
return AudioFileExtension(ext)
content_type = headers.get("content-type", "").lower()
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
return AudioFileExtension("mp3")
if "audio/wav" in content_type:
return AudioFileExtension("wav")
if "audio/mp4" in content_type:
return AudioFileExtension("mp4")
raise ValueError(
f"Unsupported audio format for URL: {url}. "
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
)
def download_audio_to_volume(
audio_file_url: str,
) -> tuple[ParakeetUniqFilename, AudioFileExtension]:
import requests
from fastapi import HTTPException
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(status_code=404, detail="Audio file not found")
response = requests.get(audio_file_url, allow_redirects=True)
response.raise_for_status()
audio_suffix = detect_audio_format(audio_file_url, response.headers)
unique_filename = ParakeetUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
file_path = f"{UPLOADS_PATH}/{unique_filename}"
with open(file_path, "wb") as f:
f.write(response.content)
upload_volume.commit()
return unique_filename, audio_suffix
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
"""Add 0.5 seconds of silence if audio is less than 500ms.
This is a workaround for a Parakeet bug where very short audio (<500ms) causes:
ValueError: `char_offsets`: [] and `processed_tokens`: [157, 834, 834, 841]
have to be of the same length
See: https://github.com/NVIDIA/NeMo/issues/8451
"""
import numpy as np
audio_duration = len(audio_array) / sample_rate
if audio_duration < 0.5:
silence_samples = int(sample_rate * 0.5)
silence = np.zeros(silence_samples, dtype=np.float32)
return np.concatenate([audio_array, silence])
return audio_array
@app.cls(
gpu="A10G",
timeout=600,
scaledown_window=300,
image=image,
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
enable_memory_snapshot=True,
experimental_options={"enable_gpu_snapshot": True},
)
@modal.concurrent(max_inputs=10)
class TranscriberParakeetLive:
@modal.enter(snap=True)
def enter(self):
import nemo.collections.asr as nemo_asr
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
self.lock = threading.Lock()
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
device = next(self.model.parameters()).device
print(f"Model is on device: {device}")
@modal.method()
def transcribe_segment(
self,
filename: str,
):
import librosa
upload_volume.reload()
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
padded_audio = pad_audio(audio_array, sample_rate)
with self.lock:
with NoStdStreams():
(output,) = self.model.transcribe([padded_audio], timestamps=True)
text = output.text.strip()
words = [
{
"word": word_info["word"],
"start": round(word_info["start"], 2),
"end": round(word_info["end"], 2),
}
for word_info in output.timestamp["word"]
]
return {"text": text, "words": words}
@modal.method()
def transcribe_batch(
self,
filenames: list[str],
):
import librosa
upload_volume.reload()
results = []
audio_arrays = []
# Load all audio files with padding
for filename in filenames:
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"Batch file not found: {file_path}")
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
padded_audio = pad_audio(audio_array, sample_rate)
audio_arrays.append(padded_audio)
with self.lock:
with NoStdStreams():
outputs = self.model.transcribe(audio_arrays, timestamps=True)
# Process results for each file
for i, (filename, output) in enumerate(zip(filenames, outputs)):
text = output.text.strip()
words = [
{
"word": word_info["word"],
"start": round(word_info["start"], 2),
"end": round(word_info["end"], 2),
}
for word_info in output.timestamp["word"]
]
results.append(
{
"filename": filename,
"text": text,
"words": words,
}
)
return results
# L40S class for file transcription (bigger files)
@app.cls(
gpu="L40S",
timeout=900,
image=image,
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
enable_memory_snapshot=True,
experimental_options={"enable_gpu_snapshot": True},
)
class TranscriberParakeetFile:
@modal.enter(snap=True)
def enter(self):
import nemo.collections.asr as nemo_asr
import torch
from silero_vad import load_silero_vad
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
device = next(self.model.parameters()).device
print(f"Model is on device: {device}")
torch.set_num_threads(1)
self.vad_model = load_silero_vad(onnx=False)
print("Silero VAD initialized")
@modal.method()
def transcribe_segment(
self,
filename: str,
timestamp_offset: float = 0.0,
):
import librosa
import numpy as np
from silero_vad import VADIterator
def load_and_convert_audio(file_path):
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
return audio_array
def vad_segment_generator(audio_array):
"""Generate speech segments using VAD with start/end sample indices"""
vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
window_size = VAD_CONFIG["window_size"]
start = None
for i in range(0, len(audio_array), window_size):
chunk = audio_array[i : i + window_size]
if len(chunk) < window_size:
chunk = np.pad(
chunk, (0, window_size - len(chunk)), mode="constant"
)
speech_dict = vad_iterator(chunk)
if not speech_dict:
continue
if "start" in speech_dict:
start = speech_dict["start"]
continue
if "end" in speech_dict and start is not None:
end = speech_dict["end"]
start_time = start / float(SAMPLERATE)
end_time = end / float(SAMPLERATE)
# Extract the actual audio segment
audio_segment = audio_array[start:end]
yield (start_time, end_time, audio_segment)
start = None
vad_iterator.reset_states()
def vad_segment_filter(segments):
"""Filter VAD segments by duration and chunk large segments"""
min_dur = VAD_CONFIG["min_segment_duration"]
max_dur = VAD_CONFIG["max_segment_duration"]
for start_time, end_time, audio_segment in segments:
segment_duration = end_time - start_time
# Skip very small segments
if segment_duration < min_dur:
continue
# If segment is within max duration, yield as-is
if segment_duration <= max_dur:
yield (start_time, end_time, audio_segment)
continue
# Chunk large segments into smaller pieces
chunk_samples = int(max_dur * SAMPLERATE)
current_start = start_time
for chunk_offset in range(0, len(audio_segment), chunk_samples):
chunk_audio = audio_segment[
chunk_offset : chunk_offset + chunk_samples
]
if len(chunk_audio) == 0:
break
chunk_duration = len(chunk_audio) / float(SAMPLERATE)
chunk_end = current_start + chunk_duration
# Only yield chunks that meet minimum duration
if chunk_duration >= min_dur:
yield (current_start, chunk_end, chunk_audio)
current_start = chunk_end
def batch_segments(segments, max_files=10, max_duration=5.0):
batch = []
batch_duration = 0.0
for start_time, end_time, audio_segment in segments:
segment_duration = end_time - start_time
if segment_duration < VAD_CONFIG["silence_padding"]:
silence_samples = int(
(VAD_CONFIG["silence_padding"] - segment_duration) * SAMPLERATE
)
padding = np.zeros(silence_samples, dtype=np.float32)
audio_segment = np.concatenate([audio_segment, padding])
segment_duration = VAD_CONFIG["silence_padding"]
batch.append((start_time, end_time, audio_segment))
batch_duration += segment_duration
if len(batch) >= max_files or batch_duration >= max_duration:
yield batch
batch = []
batch_duration = 0.0
if batch:
yield batch
def transcribe_batch(model, audio_segments):
with NoStdStreams():
outputs = model.transcribe(audio_segments, timestamps=True)
return outputs
def emit_results(
results,
segments_info,
batch_index,
total_batches,
):
"""Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
for i, (output, (start_time, end_time, _)) in enumerate(
zip(results, segments_info)
):
text = output.text.strip()
words = [
{
"word": word_info["word"],
"start": round(
word_info["start"] + start_time + timestamp_offset, 2
),
"end": round(
word_info["end"] + start_time + timestamp_offset, 2
),
}
for word_info in output.timestamp["word"]
]
yield text, words
upload_volume.reload()
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
audio_array = load_and_convert_audio(file_path)
total_duration = len(audio_array) / float(SAMPLERATE)
processed_duration = 0.0
all_text_parts = []
all_words = []
raw_segments = vad_segment_generator(audio_array)
filtered_segments = vad_segment_filter(raw_segments)
batches = batch_segments(
filtered_segments,
VAD_CONFIG["batch_max_files"],
VAD_CONFIG["batch_max_duration"],
)
batch_index = 0
total_batches = max(
1, int(total_duration / VAD_CONFIG["batch_max_duration"]) + 1
)
for batch in batches:
batch_index += 1
audio_segments = [seg[2] for seg in batch]
results = transcribe_batch(self.model, audio_segments)
for text, words in emit_results(
results,
batch,
batch_index,
total_batches,
):
if not text:
continue
all_text_parts.append(text)
all_words.extend(words)
processed_duration += sum(len(seg[2]) / float(SAMPLERATE) for seg in batch)
combined_text = " ".join(all_text_parts)
return {"text": combined_text, "words": all_words}
@app.function(
scaledown_window=60,
timeout=600,
secrets=[
modal.Secret.from_name("reflector-gpu"),
],
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
image=image,
)
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
def web():
import os
import uuid
from fastapi import (
Body,
Depends,
FastAPI,
Form,
HTTPException,
UploadFile,
status,
)
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
transcriber_live = TranscriberParakeetLive()
transcriber_file = TranscriberParakeetFile()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
return
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class TranscriptResponse(BaseModel):
result: dict
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
def transcribe(
file: UploadFile = None,
files: list[UploadFile] | None = None,
model: str = Form(MODEL_NAME),
language: str = Form("en"),
batch: bool = Form(False),
):
# Parakeet only supports English
if language != "en":
raise HTTPException(
status_code=400,
detail=f"Parakeet model only supports English. Got language='{language}'",
)
# Handle both single file and multiple files
if not file and not files:
raise HTTPException(
status_code=400, detail="Either 'file' or 'files' parameter is required"
)
if batch and not files:
raise HTTPException(
status_code=400, detail="Batch transcription requires 'files'"
)
upload_files = [file] if file else files
# Upload files to volume
uploaded_filenames = []
for upload_file in upload_files:
audio_suffix = upload_file.filename.split(".")[-1]
assert audio_suffix in SUPPORTED_FILE_EXTENSIONS
# Generate unique filename
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
file_path = f"{UPLOADS_PATH}/{unique_filename}"
print(f"Writing file to: {file_path}")
with open(file_path, "wb") as f:
content = upload_file.file.read()
f.write(content)
uploaded_filenames.append(unique_filename)
upload_volume.commit()
try:
# Use A10G live transcriber for per-file transcription
if batch and len(upload_files) > 1:
# Use batch transcription
func = transcriber_live.transcribe_batch.spawn(
filenames=uploaded_filenames,
)
results = func.get()
return {"results": results}
# Per-file transcription
results = []
for filename in uploaded_filenames:
func = transcriber_live.transcribe_segment.spawn(
filename=filename,
)
result = func.get()
result["filename"] = filename
results.append(result)
return {"results": results} if len(results) > 1 else results[0]
finally:
for filename in uploaded_filenames:
try:
file_path = f"{UPLOADS_PATH}/{filename}"
print(f"Deleting file: {file_path}")
os.remove(file_path)
except Exception as e:
print(f"Error deleting {filename}: {e}")
upload_volume.commit()
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
def transcribe_from_url(
audio_file_url: str = Body(
..., description="URL of the audio file to transcribe"
),
model: str = Body(MODEL_NAME),
language: str = Body("en", description="Language code (only 'en' supported)"),
timestamp_offset: float = Body(0.0),
):
# Parakeet only supports English
if language != "en":
raise HTTPException(
status_code=400,
detail=f"Parakeet model only supports English. Got language='{language}'",
)
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
try:
func = transcriber_file.transcribe_segment.spawn(
filename=unique_filename,
timestamp_offset=timestamp_offset,
)
result = func.get()
return result
finally:
try:
file_path = f"{UPLOADS_PATH}/{unique_filename}"
print(f"Deleting file: {file_path}")
os.remove(file_path)
upload_volume.commit()
except Exception as e:
print(f"Error cleaning up {unique_filename}: {e}")
return app
class NoStdStreams:
def __init__(self):
self.devnull = open(os.devnull, "w")
def __enter__(self):
self._stdout, self._stderr = sys.stdout, sys.stderr
self._stdout.flush()
self._stderr.flush()
sys.stdout, sys.stderr = self.devnull, self.devnull
def __exit__(self, exc_type, exc_value, traceback):
sys.stdout, sys.stderr = self._stdout, self._stderr
self.devnull.close()

View File

@@ -0,0 +1,64 @@
"""add_long_summary_to_search_vector
Revision ID: 0ab2d7ffaa16
Revises: b1c33bd09963
Create Date: 2025-08-15 13:27:52.680211
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0ab2d7ffaa16"
down_revision: Union[str, None] = "b1c33bd09963"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Drop the existing search vector column and index
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
op.drop_column("transcript", "search_vector_en")
# Recreate the search vector column with long_summary included
op.execute("""
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
GENERATED ALWAYS AS (
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') ||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')
) STORED
""")
# Recreate the GIN index for the search vector
op.create_index(
"idx_transcript_search_vector_en",
"transcript",
["search_vector_en"],
postgresql_using="gin",
)
def downgrade() -> None:
# Drop the updated search vector column and index
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
op.drop_column("transcript", "search_vector_en")
# Recreate the original search vector column without long_summary
op.execute("""
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
GENERATED ALWAYS AS (
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
) STORED
""")
# Recreate the GIN index for the search vector
op.create_index(
"idx_transcript_search_vector_en",
"transcript",
["search_vector_en"],
postgresql_using="gin",
)

View File

@@ -0,0 +1,41 @@
"""add_search_optimization_indexes
Revision ID: b1c33bd09963
Revises: 9f5c78d352d6
Create Date: 2025-08-14 17:26:02.117408
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "b1c33bd09963"
down_revision: Union[str, None] = "9f5c78d352d6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add indexes for actual search filtering patterns used in frontend
# Based on /browse page filters: room_id and source_kind
# Index for room_id + created_at (for room-specific searches with date ordering)
op.create_index(
"idx_transcript_room_id_created_at",
"transcript",
["room_id", "created_at"],
if_not_exists=True,
)
# Index for source_kind alone (actively used filter in frontend)
op.create_index(
"idx_transcript_source_kind", "transcript", ["source_kind"], if_not_exists=True
)
def downgrade() -> None:
# Remove the indexes in reverse order
op.drop_index("idx_transcript_source_kind", "transcript", if_exists=True)
op.drop_index("idx_transcript_room_id_created_at", "transcript", if_exists=True)

View File

@@ -32,7 +32,6 @@ dependencies = [
"redis>=5.0.1",
"python-jose[cryptography]>=3.3.0",
"python-multipart>=0.0.6",
"faster-whisper>=0.10.0",
"transformers>=4.36.2",
"jsonschema>=4.23.0",
"openai>=1.59.7",
@@ -57,6 +56,7 @@ tests = [
"httpx-ws>=0.4.1",
"pytest-httpx>=0.23.1",
"pytest-celery>=0.0.0",
"pytest-recording>=0.13.4",
"pytest-docker>=3.2.3",
"asgi-lifespan>=2.1.0",
]
@@ -67,6 +67,15 @@ evaluation = [
"tqdm>=4.66.0",
"pydantic>=2.1.1",
]
local = [
"pyannote-audio>=3.3.2",
"faster-whisper>=0.10.0",
]
silero-vad = [
"silero-vad>=5.1.2",
"torch>=2.8.0",
"torchaudio>=2.8.0",
]
[tool.uv]
default-groups = [
@@ -74,6 +83,21 @@ default-groups = [
"tests",
"aws",
"evaluation",
"local",
"silero-vad"
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu" },
]
torchaudio = [
{ index = "pytorch-cpu" },
]
[build-system]
@@ -94,6 +118,9 @@ DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_t
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
testpaths = ["tests"]
asyncio_mode = "auto"
markers = [
"gpu_modal: mark test to run only with GPU Modal endpoints (deselect with '-m \"not gpu_modal\"')",
]
[tool.ruff.lint]
select = [

View File

@@ -1,24 +1,37 @@
"""Search functionality for transcripts and other entities."""
import itertools
from dataclasses import dataclass
from datetime import datetime
from io import StringIO
from typing import Annotated, Any, Dict
from typing import Annotated, Any, Dict, Iterator
import sqlalchemy
import webvtt
from pydantic import BaseModel, Field, constr, field_serializer
from fastapi import HTTPException
from pydantic import (
BaseModel,
Field,
NonNegativeFloat,
NonNegativeInt,
ValidationError,
constr,
field_serializer,
)
from reflector.db import get_database
from reflector.db.rooms import rooms
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql
from reflector.logger import logger
DEFAULT_SEARCH_LIMIT = 20
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
DEFAULT_SNIPPET_MAX_LENGTH = 150
DEFAULT_MAX_SNIPPETS = 3
DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
LONG_SUMMARY_MAX_SNIPPETS = 2
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
SearchQueryBase = constr(min_length=0, strip_whitespace=True)
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
SearchOffsetBase = Annotated[int, Field(ge=0)]
SearchTotalBase = Annotated[int, Field(ge=0)]
@@ -32,6 +45,82 @@ SearchTotal = Annotated[
SearchTotalBase, Field(description="Total number of search results")
]
WEBVTT_SPEC_HEADER = "WEBVTT"
WebVTTContent = Annotated[
str,
Field(min_length=len(WEBVTT_SPEC_HEADER), description="WebVTT content"),
]
class WebVTTProcessor:
"""Stateless processor for WebVTT content operations."""
@staticmethod
def parse(raw_content: str) -> WebVTTContent:
"""Parse WebVTT content and return it as a string."""
if not raw_content.startswith(WEBVTT_SPEC_HEADER):
raise ValueError(f"Invalid WebVTT content, no header {WEBVTT_SPEC_HEADER}")
return raw_content
@staticmethod
def extract_text(webvtt_content: WebVTTContent) -> str:
"""Extract plain text from WebVTT content using webvtt library."""
try:
buffer = StringIO(webvtt_content)
vtt = webvtt.read_buffer(buffer)
return " ".join(caption.text for caption in vtt if caption.text)
except webvtt.errors.MalformedFileError as e:
logger.warning(f"Malformed WebVTT content: {e}")
return ""
except (UnicodeDecodeError, ValueError) as e:
logger.warning(f"Failed to decode WebVTT content: {e}")
return ""
except AttributeError as e:
logger.error(
f"WebVTT parsing error - unexpected format: {e}", exc_info=True
)
return ""
except Exception as e:
logger.error(f"Unexpected error parsing WebVTT: {e}", exc_info=True)
return ""
@staticmethod
def generate_snippets(
webvtt_content: WebVTTContent,
query: str,
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
) -> list[str]:
"""Generate snippets from WebVTT content."""
return SnippetGenerator.generate(
WebVTTProcessor.extract_text(webvtt_content),
query,
max_snippets=max_snippets,
)
@dataclass(frozen=True)
class SnippetCandidate:
"""Represents a candidate snippet with its position."""
_text: str
start: NonNegativeInt
_original_text_length: int
@property
def end(self) -> NonNegativeInt:
"""Calculate end position from start and raw text length."""
return self.start + len(self._text)
def text(self) -> str:
"""Get display text with ellipses added if needed."""
result = self._text.strip()
if self.start > 0:
result = "..." + result
if self.end < self._original_text_length:
result = result + "..."
return result
class SearchParameters(BaseModel):
"""Validated search parameters for full-text search."""
@@ -41,6 +130,7 @@ class SearchParameters(BaseModel):
offset: SearchOffset = 0
user_id: str | None = None
room_id: str | None = None
source_kind: SourceKind | None = None
class SearchResultDB(BaseModel):
@@ -64,13 +154,18 @@ class SearchResult(BaseModel):
title: str | None = None
user_id: str | None = None
room_id: str | None = None
room_name: str | None = None
source_kind: SourceKind
created_at: datetime
status: str = Field(..., min_length=1)
rank: float = Field(..., ge=0, le=1)
duration: float | None = Field(..., ge=0, description="Duration in seconds")
duration: NonNegativeFloat | None = Field(..., description="Duration in seconds")
search_snippets: list[str] = Field(
description="Text snippets around search matches"
)
total_match_count: NonNegativeInt = Field(
default=0, description="Total number of matches found in the transcript"
)
@field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str:
@@ -79,84 +174,153 @@ class SearchResult(BaseModel):
return dt.isoformat()
class SearchController:
"""Controller for search operations across different entities."""
class SnippetGenerator:
"""Stateless generator for text snippets and match operations."""
@staticmethod
def _extract_webvtt_text(webvtt_content: str) -> str:
"""Extract plain text from WebVTT content using webvtt library."""
if not webvtt_content:
return ""
def find_all_matches(text: str, query: str) -> Iterator[int]:
"""Generate all match positions for a query in text."""
if not text:
logger.warning("Empty text for search query in find_all_matches")
return
if not query:
logger.warning("Empty query for search text in find_all_matches")
return
try:
buffer = StringIO(webvtt_content)
vtt = webvtt.read_buffer(buffer)
return " ".join(caption.text for caption in vtt if caption.text)
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
return ""
except AttributeError as e:
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
return ""
text_lower = text.lower()
query_lower = query.lower()
start = 0
prev_start = start
while (pos := text_lower.find(query_lower, start)) != -1:
yield pos
start = pos + len(query_lower)
if start <= prev_start:
raise ValueError("panic! find_all_matches is not incremental")
prev_start = start
@staticmethod
def _generate_snippets(
def count_matches(text: str, query: str) -> NonNegativeInt:
"""Count total number of matches for a query in text."""
ZERO = NonNegativeInt(0)
if not text:
logger.warning("Empty text for search query in count_matches")
return ZERO
if not query:
logger.warning("Empty query for search text in count_matches")
return ZERO
return NonNegativeInt(
sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
)
@staticmethod
def create_snippet(
text: str, match_pos: int, max_length: int = DEFAULT_SNIPPET_MAX_LENGTH
) -> SnippetCandidate:
"""Create a snippet from a match position."""
snippet_start = NonNegativeInt(max(0, match_pos - SNIPPET_CONTEXT_LENGTH))
snippet_end = min(len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH)
snippet_text = text[snippet_start:snippet_end]
return SnippetCandidate(
_text=snippet_text, start=snippet_start, _original_text_length=len(text)
)
@staticmethod
def filter_non_overlapping(
candidates: Iterator[SnippetCandidate],
) -> Iterator[str]:
"""Filter out overlapping snippets and return only display text."""
last_end = 0
for candidate in candidates:
display_text = candidate.text()
# it means that next overlapping snippets simply don't get included
# it's fine as simplistic logic and users probably won't care much because they already have their search results just fin
if candidate.start >= last_end and display_text:
yield display_text
last_end = candidate.end
@staticmethod
def generate(
text: str,
q: SearchQuery,
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
max_snippets: int = DEFAULT_MAX_SNIPPETS,
query: str,
max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
) -> list[str]:
"""Generate multiple snippets around all occurrences of search term."""
if not text or not q:
"""Generate snippets from text."""
if not text or not query:
logger.warning("Empty text or query for generate_snippets")
return []
snippets = []
lower_text = text.lower()
search_lower = q.lower()
candidates = (
SnippetGenerator.create_snippet(text, pos, max_length)
for pos in SnippetGenerator.find_all_matches(text, query)
)
filtered = SnippetGenerator.filter_non_overlapping(candidates)
snippets = list(itertools.islice(filtered, max_snippets))
last_snippet_end = 0
start_pos = 0
while len(snippets) < max_snippets:
match_pos = lower_text.find(search_lower, start_pos)
if match_pos == -1:
if not snippets and search_lower.split():
first_word = search_lower.split()[0]
match_pos = lower_text.find(first_word, start_pos)
if match_pos == -1:
break
else:
break
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
snippet_end = min(
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
)
if snippet_start < last_snippet_end:
start_pos = match_pos + len(search_lower)
continue
snippet = text[snippet_start:snippet_end]
if snippet_start > 0:
snippet = "..." + snippet
if snippet_end < len(text):
snippet = snippet + "..."
snippet = snippet.strip()
if snippet:
snippets.append(snippet)
last_snippet_end = snippet_end
start_pos = match_pos + len(search_lower)
if start_pos >= len(text):
break
# Fallback to first word search if no full matches
# it's another assumption: proper snippet logic generation is quite complicated and tied to db logic, so simplification is used here
if not snippets and " " in query:
first_word = query.split()[0]
return SnippetGenerator.generate(text, first_word, max_length, max_snippets)
return snippets
@staticmethod
def from_summary(
summary: str,
query: str,
max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
) -> list[str]:
"""Generate snippets from summary text."""
return SnippetGenerator.generate(summary, query, max_snippets=max_snippets)
@staticmethod
def combine_sources(
summary: str | None,
webvtt: WebVTTContent | None,
query: str,
max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
) -> tuple[list[str], NonNegativeInt]:
"""Combine snippets from multiple sources and return total match count.
Returns (snippets, total_match_count) tuple.
snippets can be empty for real in case of e.g. title match
"""
webvtt_matches = 0
summary_matches = 0
if webvtt:
webvtt_text = WebVTTProcessor.extract_text(webvtt)
webvtt_matches = SnippetGenerator.count_matches(webvtt_text, query)
if summary:
summary_matches = SnippetGenerator.count_matches(summary, query)
total_matches = NonNegativeInt(webvtt_matches + summary_matches)
summary_snippets = (
SnippetGenerator.from_summary(summary, query) if summary else []
)
if len(summary_snippets) >= max_total:
return summary_snippets[:max_total], total_matches
remaining = max_total - len(summary_snippets)
webvtt_snippets = (
WebVTTProcessor.generate_snippets(webvtt, query, remaining)
if webvtt
else []
)
return summary_snippets + webvtt_snippets, total_matches
class SearchController:
"""Controller for search operations across different entities."""
@classmethod
async def search_transcripts(
cls, params: SearchParameters
@@ -172,39 +336,70 @@ class SearchController:
)
return [], 0
search_query = sqlalchemy.func.websearch_to_tsquery(
"english", params.query_text
base_columns = [
transcripts.c.id,
transcripts.c.title,
transcripts.c.created_at,
transcripts.c.duration,
transcripts.c.status,
transcripts.c.user_id,
transcripts.c.room_id,
transcripts.c.source_kind,
transcripts.c.webvtt,
transcripts.c.long_summary,
sqlalchemy.case(
(
transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
"Deleted Room",
),
else_=rooms.c.name,
).label("room_name"),
]
if params.query_text:
search_query = sqlalchemy.func.websearch_to_tsquery(
"english", params.query_text
)
rank_column = sqlalchemy.func.ts_rank(
transcripts.c.search_vector_en,
search_query,
32, # normalization flag: rank/(rank+1) for 0-1 range
).label("rank")
else:
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
columns = base_columns + [rank_column]
base_query = sqlalchemy.select(columns).select_from(
transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
)
base_query = sqlalchemy.select(
[
transcripts.c.id,
transcripts.c.title,
transcripts.c.created_at,
transcripts.c.duration,
transcripts.c.status,
transcripts.c.user_id,
transcripts.c.room_id,
transcripts.c.source_kind,
transcripts.c.webvtt,
sqlalchemy.func.ts_rank(
transcripts.c.search_vector_en,
search_query,
32, # normalization flag: rank/(rank+1) for 0-1 range
).label("rank"),
]
).where(transcripts.c.search_vector_en.op("@@")(search_query))
if params.query_text:
base_query = base_query.where(
transcripts.c.search_vector_en.op("@@")(search_query)
)
if params.user_id:
base_query = base_query.where(transcripts.c.user_id == params.user_id)
base_query = base_query.where(
sqlalchemy.or_(
transcripts.c.user_id == params.user_id, rooms.c.is_shared
)
)
else:
base_query = base_query.where(rooms.c.is_shared)
if params.room_id:
base_query = base_query.where(transcripts.c.room_id == params.room_id)
if params.source_kind:
base_query = base_query.where(
transcripts.c.source_kind == params.source_kind
)
if params.query_text:
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
else:
order_by = sqlalchemy.desc(transcripts.c.created_at)
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
query = (
base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
.limit(params.limit)
.offset(params.offset)
)
rs = await get_database().fetch_all(query)
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
@@ -214,18 +409,40 @@ class SearchController:
def _process_result(r) -> SearchResult:
r_dict: Dict[str, Any] = dict(r)
webvtt: str | None = r_dict.pop("webvtt", None)
webvtt_raw: str | None = r_dict.pop("webvtt", None)
if webvtt_raw:
webvtt = WebVTTProcessor.parse(webvtt_raw)
else:
webvtt = None
long_summary: str | None = r_dict.pop("long_summary", None)
room_name: str | None = r_dict.pop("room_name", None)
db_result = SearchResultDB.model_validate(r_dict)
snippets = []
if webvtt:
plain_text = cls._extract_webvtt_text(webvtt)
snippets = cls._generate_snippets(plain_text, params.query_text)
snippets, total_match_count = SnippetGenerator.combine_sources(
long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
)
return SearchResult(**db_result.model_dump(), search_snippets=snippets)
return SearchResult(
**db_result.model_dump(),
room_name=room_name,
search_snippets=snippets,
total_match_count=total_match_count,
)
try:
results = [_process_result(r) for r in rs]
except ValidationError as e:
logger.error(f"Invalid search result data: {e}", exc_info=True)
raise HTTPException(
status_code=500, detail="Internal search result data consistency error"
)
except Exception as e:
logger.error(f"Error processing search results: {e}", exc_info=True)
raise
results = [_process_result(r) for r in rs]
return results, total
search_controller = SearchController()
webvtt_processor = WebVTTProcessor()
snippet_generator = SnippetGenerator()

View File

@@ -88,6 +88,8 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
)
# Add PostgreSQL-specific full-text search column
@@ -99,7 +101,8 @@ if is_postgresql():
TSVECTOR,
sqlalchemy.Computed(
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
persisted=True,
),
)

View File

@@ -0,0 +1,375 @@
"""
File-based processing pipeline
==============================
Optimized pipeline for processing complete audio/video files.
Uses parallel processing for transcription, diarization, and waveform generation.
"""
import asyncio
from pathlib import Path
import av
import structlog
from celery import shared_task
from reflector.db.transcripts import (
Transcript,
transcripts_controller,
)
from reflector.logger import logger
from reflector.pipelines.main_live_pipeline import PipelineMainBase, asynctask
from reflector.processors import (
AudioFileWriterProcessor,
TranscriptFinalSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptTopicDetectorProcessor,
)
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.file_diarization import FileDiarizationInput
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.processors.file_transcript import FileTranscriptInput
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.transcript_diarization_assembler import (
TranscriptDiarizationAssemblerInput,
TranscriptDiarizationAssemblerProcessor,
)
from reflector.processors.types import (
DiarizationSegment,
TitleSummary,
)
from reflector.processors.types import (
Transcript as TranscriptType,
)
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
class EmptyPipeline:
"""Empty pipeline for processors that need a pipeline reference"""
def __init__(self, logger: structlog.BoundLogger):
self.logger = logger
def get_pref(self, k, d=None):
return d
async def emit(self, event):
pass
class PipelineMainFile(PipelineMainBase):
"""
Optimized file processing pipeline.
Processes complete audio/video files with parallel execution.
"""
logger: structlog.BoundLogger = None
empty_pipeline = None
def __init__(self, transcript_id: str):
super().__init__(transcript_id=transcript_id)
self.logger = logger.bind(transcript_id=self.transcript_id)
self.empty_pipeline = EmptyPipeline(logger=self.logger)
def _handle_gather_exceptions(self, results: list, operation: str) -> None:
"""Handle exceptions from asyncio.gather with return_exceptions=True"""
for i, result in enumerate(results):
if not isinstance(result, Exception):
continue
self.logger.error(
f"Error in {operation} (task {i}): {result}",
transcript_id=self.transcript_id,
exc_info=result,
)
async def process(self, file_path: Path):
"""Main entry point for file processing"""
self.logger.info(f"Starting file pipeline for {file_path}")
transcript = await self.get_transcript()
# Extract audio and write to transcript location
audio_path = await self.extract_and_write_audio(file_path, transcript)
# Upload for processing
audio_url = await self.upload_audio(audio_path, transcript)
# Run parallel processing
await self.run_parallel_processing(
audio_path,
audio_url,
transcript.source_language,
transcript.target_language,
)
self.logger.info("File pipeline complete")
async def extract_and_write_audio(
self, file_path: Path, transcript: Transcript
) -> Path:
"""Extract audio from video if needed and write to transcript location as MP3"""
self.logger.info(f"Processing audio file: {file_path}")
# Check if it's already audio-only
container = av.open(str(file_path))
has_video = len(container.streams.video) > 0
container.close()
# Use AudioFileWriterProcessor to write MP3 to transcript location
mp3_writer = AudioFileWriterProcessor(
path=transcript.audio_mp3_filename,
on_duration=self.on_duration,
)
# Process audio frames and write to transcript location
input_container = av.open(str(file_path))
for frame in input_container.decode(audio=0):
await mp3_writer.push(frame)
await mp3_writer.flush()
input_container.close()
if has_video:
self.logger.info(
f"Extracted audio from video and saved to {transcript.audio_mp3_filename}"
)
else:
self.logger.info(
f"Converted audio file and saved to {transcript.audio_mp3_filename}"
)
return transcript.audio_mp3_filename
async def upload_audio(self, audio_path: Path, transcript: Transcript) -> str:
"""Upload audio to storage for processing"""
storage = get_transcripts_storage()
if not storage:
raise Exception(
"Storage backend required for file processing. Configure TRANSCRIPT_STORAGE_* settings."
)
self.logger.info("Uploading audio to storage")
with open(audio_path, "rb") as f:
audio_data = f.read()
storage_path = f"file_pipeline/{transcript.id}/audio.mp3"
await storage.put_file(storage_path, audio_data)
audio_url = await storage.get_file_url(storage_path)
self.logger.info(f"Audio uploaded to {audio_url}")
return audio_url
async def run_parallel_processing(
self,
audio_path: Path,
audio_url: str,
source_language: str,
target_language: str,
):
"""Coordinate parallel processing of transcription, diarization, and waveform"""
self.logger.info(
"Starting parallel processing", transcript_id=self.transcript_id
)
# Phase 1: Parallel processing of independent tasks
transcription_task = self.transcribe_file(audio_url, source_language)
diarization_task = self.diarize_file(audio_url)
waveform_task = self.generate_waveform(audio_path)
results = await asyncio.gather(
transcription_task, diarization_task, waveform_task, return_exceptions=True
)
transcript_result = results[0]
diarization_result = results[1]
# Handle errors - raise any exception that occurred
self._handle_gather_exceptions(results, "parallel processing")
for result in results:
if isinstance(result, Exception):
raise result
# Phase 2: Assemble transcript with diarization
self.logger.info(
"Assembling transcript with diarization", transcript_id=self.transcript_id
)
processor = TranscriptDiarizationAssemblerProcessor()
input_data = TranscriptDiarizationAssemblerInput(
transcript=transcript_result, diarization=diarization_result or []
)
# Store result for retrieval
diarized_transcript: Transcript | None = None
async def capture_result(transcript):
nonlocal diarized_transcript
diarized_transcript = transcript
processor.on(capture_result)
await processor.push(input_data)
await processor.flush()
if not diarized_transcript:
raise ValueError("No diarized transcript captured")
# Phase 3: Generate topics from diarized transcript
self.logger.info("Generating topics", transcript_id=self.transcript_id)
topics = await self.detect_topics(diarized_transcript, target_language)
# Phase 4: Generate title and summaries in parallel
self.logger.info(
"Generating title and summaries", transcript_id=self.transcript_id
)
results = await asyncio.gather(
self.generate_title(topics),
self.generate_summaries(topics),
return_exceptions=True,
)
self._handle_gather_exceptions(results, "title and summary generation")
async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
"""Transcribe complete file"""
processor = FileTranscriptAutoProcessor()
input_data = FileTranscriptInput(audio_url=audio_url, language=language)
# Store result for retrieval
result: TranscriptType | None = None
async def capture_result(transcript):
nonlocal result
result = transcript
processor.on(capture_result)
await processor.push(input_data)
await processor.flush()
if not result:
raise ValueError("No transcript captured")
return result
async def diarize_file(self, audio_url: str) -> list[DiarizationSegment] | None:
"""Get diarization for file"""
if not settings.DIARIZATION_BACKEND:
self.logger.info("Diarization disabled")
return None
processor = FileDiarizationAutoProcessor()
input_data = FileDiarizationInput(audio_url=audio_url)
# Store result for retrieval
result = None
async def capture_result(diarization_output):
nonlocal result
result = diarization_output.diarization
try:
processor.on(capture_result)
await processor.push(input_data)
await processor.flush()
return result
except Exception as e:
self.logger.error(f"Diarization failed: {e}")
return None
async def generate_waveform(self, audio_path: Path):
"""Generate and save waveform"""
transcript = await self.get_transcript()
processor = AudioWaveformProcessor(
audio_path=audio_path,
waveform_path=transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
)
processor.set_pipeline(self.empty_pipeline)
await processor.flush()
async def detect_topics(
self, transcript: TranscriptType, target_language: str
) -> list[TitleSummary]:
"""Detect topics from complete transcript"""
chunk_size = 300
topics: list[TitleSummary] = []
async def on_topic(topic: TitleSummary):
topics.append(topic)
return await self.on_topic(topic)
topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
topic_detector.set_pipeline(self.empty_pipeline)
for i in range(0, len(transcript.words), chunk_size):
chunk_words = transcript.words[i : i + chunk_size]
if not chunk_words:
continue
chunk_transcript = TranscriptType(
words=chunk_words, translation=transcript.translation
)
await topic_detector.push(chunk_transcript)
await topic_detector.flush()
return topics
async def generate_title(self, topics: list[TitleSummary]):
"""Generate title from topics"""
if not topics:
self.logger.warning("No topics for title generation")
return
processor = TranscriptFinalTitleProcessor(callback=self.on_title)
processor.set_pipeline(self.empty_pipeline)
for topic in topics:
await processor.push(topic)
await processor.flush()
async def generate_summaries(self, topics: list[TitleSummary]):
"""Generate long and short summaries from topics"""
if not topics:
self.logger.warning("No topics for summary generation")
return
transcript = await self.get_transcript()
processor = TranscriptFinalSummaryProcessor(
transcript=transcript,
callback=self.on_long_summary,
on_short_summary=self.on_short_summary,
)
processor.set_pipeline(self.empty_pipeline)
for topic in topics:
await processor.push(topic)
await processor.flush()
@shared_task
@asynctask
async def task_pipeline_file_process(*, transcript_id: str):
"""Celery task for file pipeline processing"""
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
# Find the file to process
audio_file = next(transcript.data_path.glob("upload.*"), None)
if not audio_file:
audio_file = next(transcript.data_path.glob("audio.*"), None)
if not audio_file:
raise Exception("No audio file found to process")
# Run file pipeline
pipeline = PipelineMainFile(transcript_id=transcript_id)
await pipeline.process(audio_file)

View File

@@ -147,15 +147,18 @@ class StrValue(BaseModel):
class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
transcript_id: str
ws_room_id: str | None = None
ws_manager: WebsocketManager | None = None
def prepare(self):
# prepare websocket
def __init__(self, transcript_id: str):
super().__init__()
self._lock = asyncio.Lock()
self.transcript_id = transcript_id
self.ws_room_id = f"ts:{self.transcript_id}"
self.ws_manager = get_ws_manager()
self._ws_manager = None
@property
def ws_manager(self) -> WebsocketManager:
if self._ws_manager is None:
self._ws_manager = get_ws_manager()
return self._ws_manager
async def get_transcript(self) -> Transcript:
# fetch the transcript
@@ -355,7 +358,6 @@ class PipelineMainLive(PipelineMainBase):
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
transcript = await self.get_transcript()
processors = [
@@ -376,6 +378,7 @@ class PipelineMainLive(PipelineMainBase):
pipeline.set_pref("audio:target_language", transcript.target_language)
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info("Pipeline main live created")
pipeline.describe()
return pipeline
@@ -394,7 +397,6 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
pipeline = Pipeline(
AudioDiarizationAutoProcessor(callback=self.on_topic),
)
@@ -435,8 +437,6 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
raise NotImplementedError
async def create(self) -> Pipeline:
self.prepare()
# get transcript
self._transcript = transcript = await self.get_transcript()

View File

@@ -18,22 +18,14 @@ During its lifecycle, it will emit the following status:
import asyncio
from typing import Generic, TypeVar
from pydantic import BaseModel, ConfigDict
from reflector.logger import logger
from reflector.processors import Pipeline
PipelineMessage = TypeVar("PipelineMessage")
class PipelineRunner(BaseModel, Generic[PipelineMessage]):
model_config = ConfigDict(arbitrary_types_allowed=True)
status: str = "idle"
pipeline: Pipeline | None = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
class PipelineRunner(Generic[PipelineMessage]):
def __init__(self):
self._task = None
self._q_cmd = asyncio.Queue(maxsize=4096)
self._ev_done = asyncio.Event()
@@ -42,6 +34,8 @@ class PipelineRunner(BaseModel, Generic[PipelineMessage]):
runner=id(self),
runner_cls=self.__class__.__name__,
)
self.status = "idle"
self.pipeline: Pipeline | None = None
async def create(self) -> Pipeline:
"""

View File

@@ -11,6 +11,13 @@ from .base import ( # noqa: F401
Processor,
ThreadedProcessor,
)
from .file_diarization import FileDiarizationProcessor # noqa: F401
from .file_diarization_auto import FileDiarizationAutoProcessor # noqa: F401
from .file_transcript import FileTranscriptProcessor # noqa: F401
from .file_transcript_auto import FileTranscriptAutoProcessor # noqa: F401
from .transcript_diarization_assembler import (
TranscriptDiarizationAssemblerProcessor, # noqa: F401
)
from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401
from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401
from .transcript_liner import TranscriptLinerProcessor # noqa: F401

View File

@@ -1,28 +1,340 @@
from typing import Optional
import av
import numpy as np
import torch
from silero_vad import VADIterator, load_silero_vad
from reflector.processors.base import Processor
class AudioChunkerProcessor(Processor):
"""
Assemble audio frames into chunks
Assemble audio frames into chunks with VAD-based speech detection
"""
INPUT_TYPE = av.AudioFrame
OUTPUT_TYPE = list[av.AudioFrame]
def __init__(self, max_frames=256):
def __init__(
self,
block_frames=256,
max_frames=1024,
vad_threshold=0.5,
use_onnx=False,
min_frames=2,
):
super().__init__()
self.frames: list[av.AudioFrame] = []
self.block_frames = block_frames
self.max_frames = max_frames
self.vad_threshold = vad_threshold
self.min_frames = min_frames
# Initialize Silero VAD
self._init_vad(use_onnx)
def _init_vad(self, use_onnx=False):
"""Initialize Silero VAD model"""
try:
torch.set_num_threads(1)
self.vad_model = load_silero_vad(onnx=use_onnx)
self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
self.logger.info("Silero VAD initialized successfully")
except Exception as e:
self.logger.error(f"Failed to initialize Silero VAD: {e}")
self.vad_model = None
self.vad_iterator = None
async def _push(self, data: av.AudioFrame):
self.frames.append(data)
if len(self.frames) >= self.max_frames:
await self.flush()
# print("timestamp", data.pts * data.time_base * 1000)
# Check for speech segments every 32 frames (~1 second)
if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
await self._process_block()
# Safety fallback - emit if we hit max frames
elif len(self.frames) >= self.max_frames:
self.logger.warning(
f"AudioChunkerProcessor: Reached max frames ({self.max_frames}), "
f"emitting first {self.max_frames // 2} frames"
)
frames_to_emit = self.frames[: self.max_frames // 2]
self.frames = self.frames[self.max_frames // 2 :]
if len(frames_to_emit) >= self.min_frames:
await self.emit(frames_to_emit)
else:
self.logger.debug(
f"Ignoring fallback segment with {len(frames_to_emit)} frames "
f"(< {self.min_frames} minimum)"
)
async def _process_block(self):
# Need at least 32 frames for VAD detection (~1 second)
if len(self.frames) < 32 or self.vad_iterator is None:
return
# Processing block with current buffer size
# print(f"Processing block: {len(self.frames)} frames in buffer")
try:
# Convert frames to numpy array for VAD
audio_array = self._frames_to_numpy(self.frames)
if audio_array is None:
# Fallback: emit all frames if conversion failed
frames_to_emit = self.frames[:]
self.frames = []
if len(frames_to_emit) >= self.min_frames:
await self.emit(frames_to_emit)
else:
self.logger.debug(
f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
f"(< {self.min_frames} minimum)"
)
return
# Find complete speech segments in the buffer
speech_end_frame = self._find_speech_segment_end(audio_array)
if speech_end_frame is None or speech_end_frame <= 0:
# No speech found but buffer is getting large
if len(self.frames) > 512:
# Check if it's all silence and can be discarded
# No speech segment found, buffer at {len(self.frames)} frames
# Could emit silence or discard old frames here
# For now, keep first 256 frames and discard older silence
if len(self.frames) > 768:
self.logger.debug(
f"Discarding {len(self.frames) - 256} old frames (likely silence)"
)
self.frames = self.frames[-256:]
return
# Calculate segment timing information
frames_to_emit = self.frames[:speech_end_frame]
# Get timing from av.AudioFrame
if frames_to_emit:
first_frame = frames_to_emit[0]
last_frame = frames_to_emit[-1]
sample_rate = first_frame.sample_rate
# Calculate duration
total_samples = sum(f.samples for f in frames_to_emit)
duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0
# Get timestamps if available
start_time = (
first_frame.pts * first_frame.time_base if first_frame.pts else 0
)
end_time = (
last_frame.pts * last_frame.time_base if last_frame.pts else 0
)
# Convert to HH:MM:SS format for logging
def format_time(seconds):
if not seconds:
return "00:00:00"
total_seconds = int(float(seconds))
hours = total_seconds // 3600
minutes = (total_seconds % 3600) // 60
secs = total_seconds % 60
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
start_formatted = format_time(start_time)
end_formatted = format_time(end_time)
# Keep remaining frames for next processing
remaining_after = len(self.frames) - speech_end_frame
# Single structured log line
self.logger.info(
"Speech segment found",
start=start_formatted,
end=end_formatted,
frames=speech_end_frame,
duration=round(duration_seconds, 2),
buffer_before=len(self.frames),
remaining=remaining_after,
)
# Keep remaining frames for next processing
self.frames = self.frames[speech_end_frame:]
# Filter out segments with too few frames
if len(frames_to_emit) >= self.min_frames:
await self.emit(frames_to_emit)
else:
self.logger.debug(
f"Ignoring segment with {len(frames_to_emit)} frames "
f"(< {self.min_frames} minimum)"
)
except Exception as e:
self.logger.error(f"Error in VAD processing: {e}")
# Fallback to simple chunking
if len(self.frames) >= self.block_frames:
frames_to_emit = self.frames[: self.block_frames]
self.frames = self.frames[self.block_frames :]
if len(frames_to_emit) >= self.min_frames:
await self.emit(frames_to_emit)
else:
self.logger.debug(
f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
f"(< {self.min_frames} minimum)"
)
def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
"""Convert av.AudioFrame list to numpy array for VAD processing"""
if not frames:
return None
try:
first_frame = frames[0]
original_sample_rate = first_frame.sample_rate
audio_data = []
for frame in frames:
frame_array = frame.to_ndarray()
# Handle stereo -> mono conversion
if len(frame_array.shape) == 2 and frame_array.shape[0] > 1:
frame_array = np.mean(frame_array, axis=0)
elif len(frame_array.shape) == 2:
frame_array = frame_array.flatten()
audio_data.append(frame_array)
if not audio_data:
return None
combined_audio = np.concatenate(audio_data)
# Resample from 48kHz to 16kHz if needed
if original_sample_rate != 16000:
combined_audio = self._resample_audio(
combined_audio, original_sample_rate, 16000
)
# Ensure float32 format
if combined_audio.dtype == np.int16:
# Normalize int16 audio to float32 in range [-1.0, 1.0]
combined_audio = combined_audio.astype(np.float32) / 32768.0
elif combined_audio.dtype != np.float32:
combined_audio = combined_audio.astype(np.float32)
return combined_audio
except Exception as e:
self.logger.error(f"Error converting frames to numpy: {e}")
return None
def _resample_audio(
self, audio: np.ndarray, from_sr: int, to_sr: int
) -> np.ndarray:
"""Simple linear resampling from from_sr to to_sr"""
if from_sr == to_sr:
return audio
try:
# Simple linear interpolation resampling
ratio = to_sr / from_sr
new_length = int(len(audio) * ratio)
# Create indices for interpolation
old_indices = np.linspace(0, len(audio) - 1, new_length)
resampled = np.interp(old_indices, np.arange(len(audio)), audio)
return resampled.astype(np.float32)
except Exception as e:
self.logger.error("Resampling error", exc_info=e)
# Fallback: simple decimation/repetition
if from_sr > to_sr:
# Downsample by taking every nth sample
step = from_sr // to_sr
return audio[::step]
else:
# Upsample by repeating samples
repeat = to_sr // from_sr
return np.repeat(audio, repeat)
def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
"""Find complete speech segments and return frame index at segment end"""
if self.vad_iterator is None or len(audio_array) == 0:
return None
try:
# Process audio in 512-sample windows for VAD
window_size = 512
min_silence_windows = 3 # Require 3 windows of silence after speech
# Track speech state
in_speech = False
speech_start = None
speech_end = None
silence_count = 0
for i in range(0, len(audio_array), window_size):
chunk = audio_array[i : i + window_size]
if len(chunk) < window_size:
chunk = np.pad(chunk, (0, window_size - len(chunk)))
# Detect if this window has speech
speech_dict = self.vad_iterator(chunk, return_seconds=True)
# VADIterator returns dict with 'start' and 'end' when speech segments are detected
if speech_dict:
if not in_speech:
# Speech started
speech_start = i
in_speech = True
# Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
silence_count = 0 # Reset silence counter
continue
if not in_speech:
continue
# We're in speech but found silence
silence_count += 1
if silence_count < min_silence_windows:
continue
# Found end of speech segment
speech_end = i - (min_silence_windows - 1) * window_size
# Debug: print(f"Speech END at sample {speech_end}")
# Convert sample position to frame index
samples_per_frame = self.frames[0].samples if self.frames else 1024
# Account for resampling: we process at 16kHz but frames might be 48kHz
resample_ratio = 48000 / 16000 # 3x
actual_sample_pos = int(speech_end * resample_ratio)
frame_index = actual_sample_pos // samples_per_frame
# Ensure we don't exceed buffer
frame_index = min(frame_index, len(self.frames))
return frame_index
return None
except Exception as e:
self.logger.error(f"Error finding speech segment: {e}")
return None
async def _flush(self):
frames = self.frames[:]
self.frames = []
if frames:
await self.emit(frames)
if len(frames) >= self.min_frames:
await self.emit(frames)
else:
self.logger.debug(
f"Ignoring flush segment with {len(frames)} frames "
f"(< {self.min_frames} minimum)"
)

View File

@@ -1,6 +1,7 @@
from reflector.processors.base import Processor
from reflector.processors.types import (
AudioDiarizationInput,
DiarizationSegment,
TitleSummary,
Word,
)
@@ -37,18 +38,21 @@ class AudioDiarizationProcessor(Processor):
async def _diarize(self, data: AudioDiarizationInput):
raise NotImplementedError
def assign_speaker(self, words: list[Word], diarization: list[dict]):
self._diarization_remove_overlap(diarization)
self._diarization_remove_segment_without_words(words, diarization)
self._diarization_merge_same_speaker(words, diarization)
self._diarization_assign_speaker(words, diarization)
@classmethod
def assign_speaker(cls, words: list[Word], diarization: list[DiarizationSegment]):
cls._diarization_remove_overlap(diarization)
cls._diarization_remove_segment_without_words(words, diarization)
cls._diarization_merge_same_speaker(diarization)
cls._diarization_assign_speaker(words, diarization)
def iter_words_from_topics(self, topics: TitleSummary):
@staticmethod
def iter_words_from_topics(topics: list[TitleSummary]):
for topic in topics:
for word in topic.transcript.words:
yield word
def is_word_continuation(self, word_prev, word):
@staticmethod
def is_word_continuation(word_prev, word):
"""
Return True if the word is a continuation of the previous word
by checking if the previous word is ending with a punctuation
@@ -61,7 +65,8 @@ class AudioDiarizationProcessor(Processor):
return False
return True
def _diarization_remove_overlap(self, diarization: list[dict]):
@staticmethod
def _diarization_remove_overlap(diarization: list[DiarizationSegment]):
"""
Remove overlap in diarization results
@@ -86,8 +91,9 @@ class AudioDiarizationProcessor(Processor):
else:
diarization_idx += 1
@staticmethod
def _diarization_remove_segment_without_words(
self, words: list[Word], diarization: list[dict]
words: list[Word], diarization: list[DiarizationSegment]
):
"""
Remove diarization segments without words
@@ -116,9 +122,8 @@ class AudioDiarizationProcessor(Processor):
else:
diarization_idx += 1
def _diarization_merge_same_speaker(
self, words: list[Word], diarization: list[dict]
):
@staticmethod
def _diarization_merge_same_speaker(diarization: list[DiarizationSegment]):
"""
Merge diarization contigous segments with the same speaker
@@ -135,7 +140,10 @@ class AudioDiarizationProcessor(Processor):
else:
diarization_idx += 1
def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
@classmethod
def _diarization_assign_speaker(
cls, words: list[Word], diarization: list[DiarizationSegment]
):
"""
Assign speaker to words based on diarization
@@ -143,7 +151,7 @@ class AudioDiarizationProcessor(Processor):
"""
word_idx = 0
last_speaker = None
last_speaker = 0
for d in diarization:
start = d["start"]
end = d["end"]
@@ -158,7 +166,7 @@ class AudioDiarizationProcessor(Processor):
# If it's a continuation, assign with the last speaker
is_continuation = False
if word_idx > 0 and word_idx < len(words) - 1:
is_continuation = self.is_word_continuation(
is_continuation = cls.is_word_continuation(
*words[word_idx - 1 : word_idx + 1]
)
if is_continuation:

View File

@@ -0,0 +1,74 @@
import os
import torch
import torchaudio
from pyannote.audio import Pipeline
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
from reflector.processors.types import AudioDiarizationInput, DiarizationSegment
class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
"""Local diarization processor using pyannote.audio library"""
def __init__(
self,
model_name: str = "pyannote/speaker-diarization-3.1",
pyannote_auth_token: str | None = None,
device: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.model_name = model_name
self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
self.device = device
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.logger.info(f"Loading pyannote diarization model: {self.model_name}")
self.diarization_pipeline = Pipeline.from_pretrained(
self.model_name, use_auth_token=self.auth_token
)
self.diarization_pipeline.to(torch.device(self.device))
self.logger.info(f"Diarization model loaded on device: {self.device}")
async def _diarize(self, data: AudioDiarizationInput) -> list[DiarizationSegment]:
try:
# Load audio file (audio_url is assumed to be a local file path)
self.logger.info(f"Loading local audio file: {data.audio_url}")
waveform, sample_rate = torchaudio.load(data.audio_url)
audio_input = {"waveform": waveform, "sample_rate": sample_rate}
self.logger.info("Running speaker diarization")
diarization = self.diarization_pipeline(audio_input)
# Convert pyannote diarization output to our format
segments = []
for segment, _, speaker in diarization.itertracks(yield_label=True):
# Extract speaker number from label (e.g., "SPEAKER_00" -> 0)
speaker_id = 0
if speaker.startswith("SPEAKER_"):
try:
speaker_id = int(speaker.split("_")[-1])
except (ValueError, IndexError):
# Fallback to hash-based ID if parsing fails
speaker_id = hash(speaker) % 1000
segments.append(
{
"start": round(segment.start, 3),
"end": round(segment.end, 3),
"speaker": speaker_id,
}
)
self.logger.info(f"Diarization completed with {len(segments)} segments")
return segments
except Exception as e:
self.logger.exception(f"Diarization failed: {e}")
raise
AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)

View File

@@ -3,11 +3,24 @@ from time import monotonic_ns
from uuid import uuid4
import av
from av.audio.resampler import AudioResampler
from reflector.processors.base import Processor
from reflector.processors.types import AudioFile
def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
frame_copy = frame.from_ndarray(
frame.to_ndarray(),
format=frame.format.name,
layout=frame.layout.name,
)
frame_copy.sample_rate = frame.sample_rate
frame_copy.pts = frame.pts
frame_copy.time_base = frame.time_base
return frame_copy
class AudioMergeProcessor(Processor):
"""
Merge audio frame into a single file
@@ -16,37 +29,92 @@ class AudioMergeProcessor(Processor):
INPUT_TYPE = list[av.AudioFrame]
OUTPUT_TYPE = AudioFile
def __init__(self, downsample_to_16k_mono: bool = True, **kwargs):
super().__init__(**kwargs)
self.downsample_to_16k_mono = downsample_to_16k_mono
async def _push(self, data: list[av.AudioFrame]):
if not data:
return
# get audio information from first frame
frame = data[0]
channels = len(frame.layout.channels)
sample_rate = frame.sample_rate
sample_width = frame.format.bytes
original_channels = len(frame.layout.channels)
original_sample_rate = frame.sample_rate
original_sample_width = frame.format.bytes
# determine if we need processing
needs_processing = self.downsample_to_16k_mono and (
original_sample_rate != 16000 or original_channels != 1
)
# determine output parameters
if self.downsample_to_16k_mono:
output_sample_rate = 16000
output_channels = 1
output_sample_width = 2 # 16-bit = 2 bytes
else:
output_sample_rate = original_sample_rate
output_channels = original_channels
output_sample_width = original_sample_width
# create audio file
uu = uuid4().hex
fd = io.BytesIO()
out_container = av.open(fd, "w", format="wav")
out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
for frame in data:
for packet in out_stream.encode(frame):
if needs_processing:
# Process with PyAV resampler
out_container = av.open(fd, "w", format="wav")
out_stream = out_container.add_stream("pcm_s16le", rate=16000)
out_stream.layout = "mono"
# Create resampler if needed
resampler = None
if original_sample_rate != 16000 or original_channels != 1:
resampler = AudioResampler(format="s16", layout="mono", rate=16000)
for frame in data:
if resampler:
# Resample and convert to mono
# XXX for an unknown reason, if we don't use a copy of the frame, we get
# Invalid Argumment from resample. Debugging indicate that when a previous processor
# already used the frame (like AudioFileWriter), it make it invalid argument here.
resampled_frames = resampler.resample(copy_frame(frame))
for resampled_frame in resampled_frames:
for packet in out_stream.encode(resampled_frame):
out_container.mux(packet)
else:
# Direct encoding without resampling
for packet in out_stream.encode(frame):
out_container.mux(packet)
# Flush the encoder
for packet in out_stream.encode(None):
out_container.mux(packet)
for packet in out_stream.encode(None):
out_container.mux(packet)
out_container.close()
out_container.close()
else:
# Use PyAV for original frames (no processing needed)
out_container = av.open(fd, "w", format="wav")
out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
out_stream.layout = "mono" if output_channels == 1 else frame.layout
for frame in data:
for packet in out_stream.encode(frame):
out_container.mux(packet)
for packet in out_stream.encode(None):
out_container.mux(packet)
out_container.close()
fd.seek(0)
# emit audio file
audiofile = AudioFile(
name=f"{monotonic_ns()}-{uu}.wav",
fd=fd,
sample_rate=sample_rate,
channels=channels,
sample_width=sample_width,
sample_rate=output_sample_rate,
channels=output_channels,
sample_width=output_sample_width,
timestamp=data[0].pts * data[0].time_base,
)

View File

@@ -12,6 +12,9 @@ API will be a POST request to TRANSCRIPT_URL:
"""
from typing import List
import aiohttp
from openai import AsyncOpenAI
from reflector.processors.audio_transcript import AudioTranscriptProcessor
@@ -21,7 +24,9 @@ from reflector.settings import settings
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
def __init__(self, modal_api_key: str | None = None, **kwargs):
def __init__(
self, modal_api_key: str | None = None, batch_enabled: bool = True, **kwargs
):
super().__init__()
if not settings.TRANSCRIPT_URL:
raise Exception(
@@ -30,6 +35,126 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
self.timeout = settings.TRANSCRIPT_TIMEOUT
self.modal_api_key = modal_api_key
self.max_batch_duration = 10.0
self.max_batch_files = 15
self.batch_enabled = batch_enabled
self.pending_files: List[AudioFile] = [] # Files waiting to be processed
@classmethod
def _calculate_duration(cls, audio_file: AudioFile) -> float:
"""Calculate audio duration in seconds from AudioFile metadata"""
# Duration = total_samples / sample_rate
# We need to estimate total samples from the file data
import wave
try:
# Try to read as WAV file to get duration
audio_file.fd.seek(0)
with wave.open(audio_file.fd, "rb") as wav_file:
frames = wav_file.getnframes()
sample_rate = wav_file.getframerate()
duration = frames / sample_rate
return duration
except Exception:
# Fallback: estimate from file size and audio parameters
audio_file.fd.seek(0, 2) # Seek to end
file_size = audio_file.fd.tell()
audio_file.fd.seek(0) # Reset to beginning
# Estimate: file_size / (sample_rate * channels * sample_width)
bytes_per_second = (
audio_file.sample_rate
* audio_file.channels
* (audio_file.sample_width // 8)
)
estimated_duration = (
file_size / bytes_per_second if bytes_per_second > 0 else 0
)
return max(0, estimated_duration)
def _create_batches(self, audio_files: List[AudioFile]) -> List[List[AudioFile]]:
"""Group audio files into batches with maximum 30s total duration"""
batches = []
current_batch = []
current_duration = 0.0
for audio_file in audio_files:
duration = self._calculate_duration(audio_file)
# If adding this file exceeds max duration, start a new batch
if current_duration + duration > self.max_batch_duration and current_batch:
batches.append(current_batch)
current_batch = [audio_file]
current_duration = duration
else:
current_batch.append(audio_file)
current_duration += duration
# Add the last batch if not empty
if current_batch:
batches.append(current_batch)
return batches
async def _transcript_batch(self, audio_files: List[AudioFile]) -> List[Transcript]:
"""Transcribe a batch of audio files using the parakeet backend"""
if not audio_files:
return []
self.logger.debug(f"Batch transcribing {len(audio_files)} files")
# Prepare form data for batch request
data = aiohttp.FormData()
data.add_field("language", self.get_pref("audio:source_language", "en"))
data.add_field("batch", "true")
for i, audio_file in enumerate(audio_files):
audio_file.fd.seek(0)
data.add_field(
"files",
audio_file.fd,
filename=f"{audio_file.name}",
content_type="audio/wav",
)
# Make batch request
headers = {"Authorization": f"Bearer {self.modal_api_key}"}
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.timeout)
) as session:
async with session.post(
f"{self.transcript_url}/audio/transcriptions",
data=data,
headers=headers,
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(
f"Batch transcription failed: {response.status} {error_text}"
)
result = await response.json()
# Process batch results
transcripts = []
results = result.get("results", [])
for i, (audio_file, file_result) in enumerate(zip(audio_files, results)):
transcript = Transcript(
words=[
Word(
text=word_info["word"],
start=word_info["start"],
end=word_info["end"],
)
for word_info in file_result.get("words", [])
]
)
transcript.add_offset(audio_file.timestamp)
transcripts.append(transcript)
return transcripts
async def _transcript(self, data: AudioFile):
async with AsyncOpenAI(
@@ -62,5 +187,96 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
return transcript
async def transcript_multiple(
self, audio_files: List[AudioFile]
) -> List[Transcript]:
"""Transcribe multiple audio files using batching"""
if len(audio_files) == 1:
# Single file, use existing method
return [await self._transcript(audio_files[0])]
# Create batches with max 30s duration each
batches = self._create_batches(audio_files)
self.logger.debug(
f"Processing {len(audio_files)} files in {len(batches)} batches"
)
# Process all batches concurrently
all_transcripts = []
for batch in batches:
batch_transcripts = await self._transcript_batch(batch)
all_transcripts.extend(batch_transcripts)
return all_transcripts
async def _push(self, data: AudioFile):
"""Override _push to support batching"""
if not self.batch_enabled:
# Use parent implementation for single file processing
return await super()._push(data)
# Add file to pending batch
self.pending_files.append(data)
self.logger.debug(
f"Added file to batch: {data.name}, batch size: {len(self.pending_files)}"
)
# Calculate total duration of pending files
total_duration = sum(self._calculate_duration(f) for f in self.pending_files)
# Process batch if it reaches max duration or has multiple files ready for optimization
should_process_batch = (
total_duration >= self.max_batch_duration
or len(self.pending_files) >= self.max_batch_files
)
if should_process_batch:
await self._process_pending_batch()
async def _process_pending_batch(self):
"""Process all pending files as batches"""
if not self.pending_files:
return
self.logger.debug(f"Processing batch of {len(self.pending_files)} files")
try:
# Create batches respecting duration limit
batches = self._create_batches(self.pending_files)
# Process each batch
for batch in batches:
self.m_transcript_call.inc()
try:
with self.m_transcript.time():
# Use batch transcription
transcripts = await self._transcript_batch(batch)
self.m_transcript_success.inc()
# Emit each transcript
for transcript in transcripts:
if transcript:
await self.emit(transcript)
except Exception:
self.m_transcript_failure.inc()
raise
finally:
# Release audio files
for audio_file in batch:
audio_file.release()
finally:
# Clear pending files
self.pending_files.clear()
async def _flush(self):
"""Process any remaining files when flushing"""
await self._process_pending_batch()
await super()._flush()
AudioTranscriptAutoProcessor.register("modal", AudioTranscriptModalProcessor)

View File

@@ -173,6 +173,7 @@ class Processor(Emitter):
except Exception:
self.m_processor_failure.inc()
self.logger.exception("Error in push")
raise
async def flush(self):
"""
@@ -240,33 +241,45 @@ class ThreadedProcessor(Processor):
self.INPUT_TYPE = processor.INPUT_TYPE
self.OUTPUT_TYPE = processor.OUTPUT_TYPE
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.queue = asyncio.Queue()
self.task = asyncio.get_running_loop().create_task(self.loop())
self.queue = asyncio.Queue(maxsize=50)
self.task: asyncio.Task | None = None
def set_pipeline(self, pipeline: "Pipeline"):
super().set_pipeline(pipeline)
self.processor.set_pipeline(pipeline)
async def loop(self):
while True:
data = await self.queue.get()
self.m_processor_queue.set(self.queue.qsize())
with self.m_processor_queue_in_progress.track_inprogress():
try:
if data is None:
await self.processor.flush()
break
try:
while True:
data = await self.queue.get()
self.m_processor_queue.set(self.queue.qsize())
with self.m_processor_queue_in_progress.track_inprogress():
try:
await self.processor.push(data)
except Exception:
self.logger.error(
f"Error in push {self.processor.__class__.__name__}"
", continue"
)
finally:
self.queue.task_done()
if data is None:
await self.processor.flush()
break
try:
await self.processor.push(data)
except Exception:
self.logger.error(
f"Error in push {self.processor.__class__.__name__}"
", continue"
)
finally:
self.queue.task_done()
except Exception as e:
logger.error(f"Crash in {self.__class__.__name__}: {e}", exc_info=e)
async def _ensure_task(self):
if self.task is None:
self.task = asyncio.get_running_loop().create_task(self.loop())
# XXX not doing a sleep here make the whole pipeline prior the thread
# to be running without having a chance to work on the task here.
await asyncio.sleep(0)
async def _push(self, data):
await self._ensure_task()
await self.queue.put(data)
async def _flush(self):

View File

@@ -0,0 +1,33 @@
from pydantic import BaseModel
from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment
class FileDiarizationInput(BaseModel):
"""Input for file diarization containing audio URL"""
audio_url: str
class FileDiarizationOutput(BaseModel):
"""Output for file diarization containing speaker segments"""
diarization: list[DiarizationSegment]
class FileDiarizationProcessor(Processor):
"""
Diarize complete audio files from URL
"""
INPUT_TYPE = FileDiarizationInput
OUTPUT_TYPE = FileDiarizationOutput
async def _push(self, data: FileDiarizationInput):
result = await self._diarize(data)
if result:
await self.emit(result)
async def _diarize(self, data: FileDiarizationInput):
raise NotImplementedError

View File

@@ -0,0 +1,33 @@
import importlib
from reflector.processors.file_diarization import FileDiarizationProcessor
from reflector.settings import settings
class FileDiarizationAutoProcessor(FileDiarizationProcessor):
_registry = {}
@classmethod
def register(cls, name, kclass):
cls._registry[name] = kclass
def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.DIARIZATION_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.file_diarization_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
config = {}
name_upper = name.upper()
settings_prefix = "DIARIZATION_"
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
if key.startswith(config_prefix):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config | kwargs)

View File

@@ -0,0 +1,57 @@
"""
File diarization implementation using the GPU service from modal.com
API will be a POST request to DIARIZATION_URL:
```
POST /diarize?audio_file_url=...&timestamp=0
Authorization: Bearer <modal_api_key>
```
"""
import httpx
from reflector.processors.file_diarization import (
FileDiarizationInput,
FileDiarizationOutput,
FileDiarizationProcessor,
)
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.settings import settings
class FileDiarizationModalProcessor(FileDiarizationProcessor):
def __init__(self, modal_api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
if not settings.DIARIZATION_URL:
raise Exception(
"DIARIZATION_URL required to use FileDiarizationModalProcessor"
)
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
self.file_timeout = settings.DIARIZATION_FILE_TIMEOUT
self.modal_api_key = modal_api_key
async def _diarize(self, data: FileDiarizationInput):
"""Get speaker diarization for file"""
self.logger.info(f"Starting diarization from {data.audio_url}")
headers = {}
if self.modal_api_key:
headers["Authorization"] = f"Bearer {self.modal_api_key}"
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
response = await client.post(
self.diarization_url,
headers=headers,
params={
"audio_file_url": data.audio_url,
"timestamp": 0,
},
)
response.raise_for_status()
diarization_data = response.json()["diarization"]
return FileDiarizationOutput(diarization=diarization_data)
FileDiarizationAutoProcessor.register("modal", FileDiarizationModalProcessor)

View File

@@ -0,0 +1,65 @@
from prometheus_client import Counter, Histogram
from reflector.processors.base import Processor
from reflector.processors.types import Transcript
class FileTranscriptInput:
"""Input for file transcription containing audio URL and language settings"""
def __init__(self, audio_url: str, language: str = "en"):
self.audio_url = audio_url
self.language = language
class FileTranscriptProcessor(Processor):
"""
Transcript complete audio files from URL
"""
INPUT_TYPE = FileTranscriptInput
OUTPUT_TYPE = Transcript
m_transcript = Histogram(
"file_transcript",
"Time spent in FileTranscript.transcript",
["backend"],
)
m_transcript_call = Counter(
"file_transcript_call",
"Number of calls to FileTranscript.transcript",
["backend"],
)
m_transcript_success = Counter(
"file_transcript_success",
"Number of successful calls to FileTranscript.transcript",
["backend"],
)
m_transcript_failure = Counter(
"file_transcript_failure",
"Number of failed calls to FileTranscript.transcript",
["backend"],
)
def __init__(self, *args, **kwargs):
name = self.__class__.__name__
self.m_transcript = self.m_transcript.labels(name)
self.m_transcript_call = self.m_transcript_call.labels(name)
self.m_transcript_success = self.m_transcript_success.labels(name)
self.m_transcript_failure = self.m_transcript_failure.labels(name)
super().__init__(*args, **kwargs)
async def _push(self, data: FileTranscriptInput):
try:
self.m_transcript_call.inc()
with self.m_transcript.time():
result = await self._transcript(data)
self.m_transcript_success.inc()
if result:
await self.emit(result)
except Exception:
self.m_transcript_failure.inc()
raise
async def _transcript(self, data: FileTranscriptInput):
raise NotImplementedError

View File

@@ -0,0 +1,32 @@
import importlib
from reflector.processors.file_transcript import FileTranscriptProcessor
from reflector.settings import settings
class FileTranscriptAutoProcessor(FileTranscriptProcessor):
_registry = {}
@classmethod
def register(cls, name, kclass):
cls._registry[name] = kclass
def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.TRANSCRIPT_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.file_transcript_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
config = {}
name_upper = name.upper()
settings_prefix = "TRANSCRIPT_"
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
if key.startswith(config_prefix):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config | kwargs)

View File

@@ -0,0 +1,74 @@
"""
File transcription implementation using the GPU service from modal.com
API will be a POST request to TRANSCRIPT_URL:
```json
{
"audio_file_url": "https://...",
"language": "en",
"model": "parakeet-tdt-0.6b-v2",
"batch": true
}
```
"""
import httpx
from reflector.processors.file_transcript import (
FileTranscriptInput,
FileTranscriptProcessor,
)
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.types import Transcript, Word
from reflector.settings import settings
class FileTranscriptModalProcessor(FileTranscriptProcessor):
def __init__(self, modal_api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
if not settings.TRANSCRIPT_URL:
raise Exception(
"TRANSCRIPT_URL required to use FileTranscriptModalProcessor"
)
self.transcript_url = settings.TRANSCRIPT_URL
self.file_timeout = settings.TRANSCRIPT_FILE_TIMEOUT
self.modal_api_key = modal_api_key
async def _transcript(self, data: FileTranscriptInput):
"""Send full file to Modal for transcription"""
url = f"{self.transcript_url}/v1/audio/transcriptions-from-url"
self.logger.info(f"Starting file transcription from {data.audio_url}")
headers = {}
if self.modal_api_key:
headers["Authorization"] = f"Bearer {self.modal_api_key}"
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
response = await client.post(
url,
headers=headers,
json={
"audio_file_url": data.audio_url,
"language": data.language,
"batch": True,
},
)
response.raise_for_status()
result = response.json()
words = [
Word(
text=word_info["word"],
start=word_info["start"],
end=word_info["end"],
)
for word_info in result.get("words", [])
]
return Transcript(words=words)
# Register with the auto processor
FileTranscriptAutoProcessor.register("modal", FileTranscriptModalProcessor)

View File

@@ -0,0 +1,45 @@
"""
Processor to assemble transcript with diarization results
"""
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment, Transcript
class TranscriptDiarizationAssemblerInput:
"""Input containing transcript and diarization data"""
def __init__(self, transcript: Transcript, diarization: list[DiarizationSegment]):
self.transcript = transcript
self.diarization = diarization
class TranscriptDiarizationAssemblerProcessor(Processor):
"""
Assemble transcript with diarization results by applying speaker assignments
"""
INPUT_TYPE = TranscriptDiarizationAssemblerInput
OUTPUT_TYPE = Transcript
async def _push(self, data: TranscriptDiarizationAssemblerInput):
result = await self._assemble(data)
if result:
await self.emit(result)
async def _assemble(self, data: TranscriptDiarizationAssemblerInput):
"""Apply diarization to transcript words"""
if not data.diarization:
self.logger.info(
"No diarization data provided, returning original transcript"
)
return data.transcript
# Reuse logic from AudioDiarizationProcessor
processor = AudioDiarizationProcessor()
words = data.transcript.words
processor.assign_speaker(words, data.diarization)
self.logger.info(f"Applied diarization to {len(words)} words")
return data.transcript

View File

@@ -2,13 +2,22 @@ import io
import re
import tempfile
from pathlib import Path
from typing import Annotated
from typing import Annotated, TypedDict
from profanityfilter import ProfanityFilter
from pydantic import BaseModel, Field, PrivateAttr
from reflector.redis_cache import redis_cache
class DiarizationSegment(TypedDict):
"""Type definition for diarization segment containing speaker information"""
start: float
end: float
speaker: int
PUNC_RE = re.compile(r"[.;:?!…]")
profanity_filter = ProfanityFilter()

View File

@@ -26,6 +26,7 @@ class Settings(BaseSettings):
TRANSCRIPT_BACKEND: str = "whisper"
TRANSCRIPT_URL: str | None = None
TRANSCRIPT_TIMEOUT: int = 90
TRANSCRIPT_FILE_TIMEOUT: int = 600
# Audio Transcription: modal backend
TRANSCRIPT_MODAL_API_KEY: str | None = None
@@ -66,10 +67,14 @@ class Settings(BaseSettings):
DIARIZATION_ENABLED: bool = True
DIARIZATION_BACKEND: str = "modal"
DIARIZATION_URL: str | None = None
DIARIZATION_FILE_TIMEOUT: int = 600
# Diarization: modal backend
DIARIZATION_MODAL_API_KEY: str | None = None
# Diarization: local pyannote.audio
DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None
# Sentry
SENTRY_DSN: str | None = None

View File

@@ -1,10 +1,23 @@
"""
Process audio file with diarization support
===========================================
Extended version of process.py that includes speaker diarization.
This tool processes audio files locally without requiring the full server infrastructure.
"""
import asyncio
import tempfile
import uuid
from pathlib import Path
from typing import List
import av
from reflector.logger import logger
from reflector.processors import (
AudioChunkerProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
Pipeline,
@@ -15,7 +28,43 @@ from reflector.processors import (
TranscriptTopicDetectorProcessor,
TranscriptTranslatorAutoProcessor,
)
from reflector.processors.base import BroadcastProcessor
from reflector.processors.base import BroadcastProcessor, Processor
from reflector.processors.types import (
AudioDiarizationInput,
TitleSummary,
TitleSummaryWithId,
)
class TopicCollectorProcessor(Processor):
"""Collect topics for diarization"""
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = TitleSummary
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.topics: List[TitleSummaryWithId] = []
self._topic_id = 0
async def _push(self, data: TitleSummary):
# Convert to TitleSummaryWithId and collect
self._topic_id += 1
topic_with_id = TitleSummaryWithId(
id=str(self._topic_id),
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
duration=data.duration,
transcript=data.transcript,
)
self.topics.append(topic_with_id)
# Pass through the original topic
await self.emit(data)
def get_topics(self) -> List[TitleSummaryWithId]:
return self.topics
async def process_audio_file(
@@ -24,18 +73,40 @@ async def process_audio_file(
only_transcript=False,
source_language="en",
target_language="en",
enable_diarization=True,
diarization_backend="pyannote",
):
# build pipeline for audio processing
processors = [
# Create temp file for audio if diarization is enabled
audio_temp_path = None
if enable_diarization:
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
audio_temp_path = audio_temp_file.name
audio_temp_file.close()
# Create processor for collecting topics
topic_collector = TopicCollectorProcessor()
# Build pipeline for audio processing
processors = []
# Add audio file writer at the beginning if diarization is enabled
if enable_diarization:
processors.append(AudioFileWriterProcessor(audio_temp_path))
# Add the rest of the processors
processors += [
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorAutoProcessor.as_threaded(),
]
if not only_transcript:
processors += [
TranscriptTopicDetectorProcessor.as_threaded(),
# Collect topics for diarization
topic_collector,
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(),
@@ -44,14 +115,14 @@ async def process_audio_file(
),
]
# transcription output
# Create main pipeline
pipeline = Pipeline(*processors)
pipeline.set_pref("audio:source_language", source_language)
pipeline.set_pref("audio:target_language", target_language)
pipeline.describe()
pipeline.on(event_callback)
# start processing audio
# Start processing audio
logger.info(f"Opening {filename}")
container = av.open(filename)
try:
@@ -62,43 +133,242 @@ async def process_audio_file(
logger.info("Flushing the pipeline")
await pipeline.flush()
logger.info("All done !")
# Run diarization if enabled and we have topics
if enable_diarization and not only_transcript and audio_temp_path:
topics = topic_collector.get_topics()
if topics:
logger.info(f"Starting diarization with {len(topics)} topics")
try:
from reflector.processors import AudioDiarizationAutoProcessor
diarization_processor = AudioDiarizationAutoProcessor(
name=diarization_backend
)
diarization_processor.set_pipeline(pipeline)
# For Modal backend, we need to upload the file to S3 first
if diarization_backend == "modal":
from datetime import datetime
from reflector.storage import get_transcripts_storage
from reflector.utils.s3_temp_file import S3TemporaryFile
storage = get_transcripts_storage()
# Generate a unique filename in evaluation folder
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
# Use context manager for automatic cleanup
async with S3TemporaryFile(storage, audio_filename) as s3_file:
# Read and upload the audio file
with open(audio_temp_path, "rb") as f:
audio_data = f.read()
audio_url = await s3_file.upload(audio_data)
logger.info(f"Uploaded audio to S3: {audio_filename}")
# Create diarization input with S3 URL
diarization_input = AudioDiarizationInput(
audio_url=audio_url, topics=topics
)
# Run diarization
await diarization_processor.push(diarization_input)
await diarization_processor.flush()
logger.info("Diarization complete")
# File will be automatically cleaned up when exiting the context
else:
# For local backend, use local file path
audio_url = audio_temp_path
# Create diarization input
diarization_input = AudioDiarizationInput(
audio_url=audio_url, topics=topics
)
# Run diarization
await diarization_processor.push(diarization_input)
await diarization_processor.flush()
logger.info("Diarization complete")
except ImportError as e:
logger.error(f"Failed to import diarization dependencies: {e}")
logger.error(
"Install with: uv pip install pyannote.audio torch torchaudio"
)
logger.error(
"And set HF_TOKEN environment variable for pyannote models"
)
raise SystemExit(1)
except Exception as e:
logger.error(f"Diarization failed: {e}")
raise SystemExit(1)
else:
logger.warning("Skipping diarization: no topics available")
# Clean up temp file
if audio_temp_path:
try:
Path(audio_temp_path).unlink()
except Exception as e:
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
logger.info("All done!")
async def process_file_pipeline(
filename: str,
event_callback,
source_language="en",
target_language="en",
enable_diarization=True,
diarization_backend="modal",
):
"""Process audio/video file using the optimized file pipeline"""
try:
from reflector.db import database
from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.pipelines.main_file_pipeline import PipelineMainFile
await database.connect()
try:
# Create a temporary transcript for processing
transcript = await transcripts_controller.add(
"",
source_kind=SourceKind.FILE,
source_language=source_language,
target_language=target_language,
)
# Process the file
pipeline = PipelineMainFile(transcript_id=transcript.id)
await pipeline.process(Path(filename))
logger.info("File pipeline processing complete")
finally:
await database.disconnect()
except ImportError as e:
logger.error(f"File pipeline not available: {e}")
logger.info("Falling back to stream pipeline")
# Fall back to stream pipeline
await process_audio_file(
filename,
event_callback,
only_transcript=False,
source_language=source_language,
target_language=target_language,
enable_diarization=enable_diarization,
diarization_backend=diarization_backend,
)
if __name__ == "__main__":
import argparse
import os
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
description="Process audio files with optional speaker diarization"
)
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
parser.add_argument("--only-transcript", "-t", action="store_true")
parser.add_argument("--source-language", default="en")
parser.add_argument("--target-language", default="en")
parser.add_argument(
"--stream",
action="store_true",
help="Use streaming pipeline (original frame-based processing)",
)
parser.add_argument(
"--only-transcript",
"-t",
action="store_true",
help="Only generate transcript without topics/summaries",
)
parser.add_argument(
"--source-language", default="en", help="Source language code (default: en)"
)
parser.add_argument(
"--target-language", default="en", help="Target language code (default: en)"
)
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
parser.add_argument(
"--enable-diarization",
"-d",
action="store_true",
help="Enable speaker diarization",
)
parser.add_argument(
"--diarization-backend",
default="pyannote",
choices=["pyannote", "modal"],
help="Diarization backend to use (default: pyannote)",
)
args = parser.parse_args()
if "REDIS_HOST" not in os.environ:
os.environ["REDIS_HOST"] = "localhost"
output_fd = None
if args.output:
output_fd = open(args.output, "w")
async def event_callback(event: PipelineEvent):
processor = event.processor
# ignore some processor
if processor in ("AudioChunkerProcessor", "AudioMergeProcessor"):
data = event.data
# Ignore internal processors
if processor in (
"AudioChunkerProcessor",
"AudioMergeProcessor",
"AudioFileWriterProcessor",
"TopicCollectorProcessor",
"BroadcastProcessor",
):
return
logger.info(f"Event: {event}")
# If diarization is enabled, skip the original topic events from the pipeline
# The diarization processor will emit the same topics but with speaker info
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
return
# Log all events
logger.info(f"Event: {processor} - {type(data).__name__}")
# Write to output
if output_fd:
output_fd.write(event.model_dump_json())
output_fd.write("\n")
output_fd.flush()
asyncio.run(
process_audio_file(
args.source,
event_callback,
only_transcript=args.only_transcript,
source_language=args.source_language,
target_language=args.target_language,
if args.stream:
# Use original streaming pipeline
asyncio.run(
process_audio_file(
args.source,
event_callback,
only_transcript=args.only_transcript,
source_language=args.source_language,
target_language=args.target_language,
enable_diarization=args.enable_diarization,
diarization_backend=args.diarization_backend,
)
)
else:
# Use optimized file pipeline (default)
asyncio.run(
process_file_pipeline(
args.source,
event_callback,
source_language=args.source_language,
target_language=args.target_language,
enable_diarization=args.enable_diarization,
diarization_backend=args.diarization_backend,
)
)
)
if output_fd:
output_fd.close()

View File

@@ -160,6 +160,7 @@ async def transcripts_search(
limit: SearchLimitParam = DEFAULT_SEARCH_LIMIT,
offset: SearchOffsetParam = 0,
room_id: Optional[str] = None,
source_kind: Optional[SourceKind] = None,
user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional)
] = None,
@@ -173,7 +174,12 @@ async def transcripts_search(
user_id = user["sub"] if user else None
search_params = SearchParameters(
query_text=q, limit=limit, offset=offset, user_id=user_id, room_id=room_id
query_text=q,
limit=limit,
offset=offset,
user_id=user_id,
room_id=room_id,
source_kind=source_kind,
)
results, total = await search_controller.search_transcripts(search_params)

View File

@@ -14,7 +14,8 @@ from reflector.db.meetings import meetings_controller
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.pipelines.main_live_pipeline import asynctask, task_pipeline_process
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_live_pipeline import asynctask
from reflector.settings import settings
from reflector.whereby import get_room_sessions
@@ -140,7 +141,7 @@ async def process_recording(bucket_name: str, object_key: str):
await transcripts_controller.update(transcript, {"status": "uploaded"})
task_pipeline_process.delay(transcript_id=transcript.id)
task_pipeline_file_process.delay(transcript_id=transcript.id)
@shared_task

View File

@@ -0,0 +1,40 @@
interactions:
- request:
body: ''
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
authorization:
- DUMMY_API_KEY
connection:
- keep-alive
content-length:
- '0'
host:
- monadical-sas--reflector-diarizer-web.modal.run
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://monadical-sas--reflector-diarizer-web.modal.run/diarize?audio_file_url=https%3A%2F%2Freflector-github-pytest.s3.us-east-1.amazonaws.com%2Ftest_mathieu_hello.mp3&timestamp=0
response:
body:
string: '{"diarization":[{"start":0.823,"end":1.91,"speaker":0},{"start":2.572,"end":6.409,"speaker":0},{"start":6.783,"end":10.62,"speaker":0},{"start":11.231,"end":14.168,"speaker":0},{"start":14.796,"end":19.295,"speaker":0}]}'
headers:
Alt-Svc:
- h3=":443"; ma=2592000
Content-Length:
- '220'
Content-Type:
- application/json
Date:
- Wed, 13 Aug 2025 18:25:34 GMT
Modal-Function-Call-Id:
- fc-01K2JAVNEP6N7Y1Y7W3T98BCXK
Vary:
- accept-encoding
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,46 @@
interactions:
- request:
body: '{"audio_file_url": "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3",
"language": "en", "batch": true}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
authorization:
- DUMMY_API_KEY
connection:
- keep-alive
content-length:
- '136'
content-type:
- application/json
host:
- monadical-sas--reflector-transcriber-parakeet-web.modal.run
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://monadical-sas--reflector-transcriber-parakeet-web.modal.run/v1/audio/transcriptions-from-url
response:
body:
string: '{"text":"Hi there everyone. Today I want to share my incredible experience
with Reflector. a Q teenage product that revolutionizes audio processing.
With reflector, I can easily convert any audio into accurate transcription.
saving me hours of tedious manual work.","words":[{"word":"Hi","start":0.87,"end":1.19},{"word":"there","start":1.19,"end":1.35},{"word":"everyone.","start":1.51,"end":1.83},{"word":"Today","start":2.63,"end":2.87},{"word":"I","start":3.36,"end":3.52},{"word":"want","start":3.6,"end":3.76},{"word":"to","start":3.76,"end":3.92},{"word":"share","start":3.92,"end":4.16},{"word":"my","start":4.16,"end":4.4},{"word":"incredible","start":4.32,"end":4.96},{"word":"experience","start":4.96,"end":5.44},{"word":"with","start":5.44,"end":5.68},{"word":"Reflector.","start":5.68,"end":6.24},{"word":"a","start":6.93,"end":7.01},{"word":"Q","start":7.01,"end":7.17},{"word":"teenage","start":7.25,"end":7.65},{"word":"product","start":7.89,"end":8.29},{"word":"that","start":8.29,"end":8.61},{"word":"revolutionizes","start":8.61,"end":9.65},{"word":"audio","start":9.65,"end":10.05},{"word":"processing.","start":10.05,"end":10.53},{"word":"With","start":11.27,"end":11.43},{"word":"reflector,","start":11.51,"end":12.15},{"word":"I","start":12.31,"end":12.39},{"word":"can","start":12.39,"end":12.55},{"word":"easily","start":12.55,"end":12.95},{"word":"convert","start":12.95,"end":13.43},{"word":"any","start":13.43,"end":13.67},{"word":"audio","start":13.67,"end":13.99},{"word":"into","start":14.98,"end":15.06},{"word":"accurate","start":15.22,"end":15.54},{"word":"transcription.","start":15.7,"end":16.34},{"word":"saving","start":16.99,"end":17.15},{"word":"me","start":17.31,"end":17.47},{"word":"hours","start":17.47,"end":17.87},{"word":"of","start":17.87,"end":18.11},{"word":"tedious","start":18.11,"end":18.67},{"word":"manual","start":18.67,"end":19.07},{"word":"work.","start":19.07,"end":19.31}]}'
headers:
Alt-Svc:
- h3=":443"; ma=2592000
Content-Length:
- '1933'
Content-Type:
- application/json
Date:
- Wed, 13 Aug 2025 18:26:59 GMT
Modal-Function-Call-Id:
- fc-01K2JAWC7GAMKX4DSJ21WV31NG
Vary:
- accept-encoding
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,84 @@
interactions:
- request:
body: '{"audio_file_url": "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3",
"language": "en", "batch": true}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
authorization:
- DUMMY_API_KEY
connection:
- keep-alive
content-length:
- '136'
content-type:
- application/json
host:
- monadical-sas--reflector-transcriber-parakeet-web.modal.run
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://monadical-sas--reflector-transcriber-parakeet-web.modal.run/v1/audio/transcriptions-from-url
response:
body:
string: '{"text":"Hi there everyone. Today I want to share my incredible experience
with Reflector. a Q teenage product that revolutionizes audio processing.
With reflector, I can easily convert any audio into accurate transcription.
saving me hours of tedious manual work.","words":[{"word":"Hi","start":0.87,"end":1.19},{"word":"there","start":1.19,"end":1.35},{"word":"everyone.","start":1.51,"end":1.83},{"word":"Today","start":2.63,"end":2.87},{"word":"I","start":3.36,"end":3.52},{"word":"want","start":3.6,"end":3.76},{"word":"to","start":3.76,"end":3.92},{"word":"share","start":3.92,"end":4.16},{"word":"my","start":4.16,"end":4.4},{"word":"incredible","start":4.32,"end":4.96},{"word":"experience","start":4.96,"end":5.44},{"word":"with","start":5.44,"end":5.68},{"word":"Reflector.","start":5.68,"end":6.24},{"word":"a","start":6.93,"end":7.01},{"word":"Q","start":7.01,"end":7.17},{"word":"teenage","start":7.25,"end":7.65},{"word":"product","start":7.89,"end":8.29},{"word":"that","start":8.29,"end":8.61},{"word":"revolutionizes","start":8.61,"end":9.65},{"word":"audio","start":9.65,"end":10.05},{"word":"processing.","start":10.05,"end":10.53},{"word":"With","start":11.27,"end":11.43},{"word":"reflector,","start":11.51,"end":12.15},{"word":"I","start":12.31,"end":12.39},{"word":"can","start":12.39,"end":12.55},{"word":"easily","start":12.55,"end":12.95},{"word":"convert","start":12.95,"end":13.43},{"word":"any","start":13.43,"end":13.67},{"word":"audio","start":13.67,"end":13.99},{"word":"into","start":14.98,"end":15.06},{"word":"accurate","start":15.22,"end":15.54},{"word":"transcription.","start":15.7,"end":16.34},{"word":"saving","start":16.99,"end":17.15},{"word":"me","start":17.31,"end":17.47},{"word":"hours","start":17.47,"end":17.87},{"word":"of","start":17.87,"end":18.11},{"word":"tedious","start":18.11,"end":18.67},{"word":"manual","start":18.67,"end":19.07},{"word":"work.","start":19.07,"end":19.31}]}'
headers:
Alt-Svc:
- h3=":443"; ma=2592000
Content-Length:
- '1933'
Content-Type:
- application/json
Date:
- Wed, 13 Aug 2025 18:27:02 GMT
Modal-Function-Call-Id:
- fc-01K2JAYZ1AR2HE422VJVKBWX9Z
Vary:
- accept-encoding
status:
code: 200
message: OK
- request:
body: ''
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
authorization:
- DUMMY_API_KEY
connection:
- keep-alive
content-length:
- '0'
host:
- monadical-sas--reflector-diarizer-web.modal.run
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://monadical-sas--reflector-diarizer-web.modal.run/diarize?audio_file_url=https%3A%2F%2Freflector-github-pytest.s3.us-east-1.amazonaws.com%2Ftest_mathieu_hello.mp3&timestamp=0
response:
body:
string: '{"diarization":[{"start":0.823,"end":1.91,"speaker":0},{"start":2.572,"end":6.409,"speaker":0},{"start":6.783,"end":10.62,"speaker":0},{"start":11.231,"end":14.168,"speaker":0},{"start":14.796,"end":19.295,"speaker":0}]}'
headers:
Alt-Svc:
- h3=":443"; ma=2592000
Content-Length:
- '220'
Content-Type:
- application/json
Date:
- Wed, 13 Aug 2025 18:27:18 GMT
Modal-Function-Call-Id:
- fc-01K2JAZ1M34NQRJK03CCFK95D6
Vary:
- accept-encoding
status:
code: 200
message: OK
version: 1

View File

@@ -5,7 +5,29 @@ from unittest.mock import patch
import pytest
# Pytest-docker configuration
@pytest.fixture(scope="session", autouse=True)
def settings_configuration():
# theses settings are linked to monadical for pytest-recording
# if a fork is done, they have to provide their own url when cassettes needs to be updated
# modal api keys has to be defined by the user
from reflector.settings import settings
settings.TRANSCRIPT_BACKEND = "modal"
settings.TRANSCRIPT_URL = (
"https://monadical-sas--reflector-transcriber-parakeet-web.modal.run"
)
settings.DIARIZATION_BACKEND = "modal"
settings.DIARIZATION_URL = "https://monadical-sas--reflector-diarizer-web.modal.run"
@pytest.fixture(scope="module")
def vcr_config():
"""VCR configuration to filter sensitive headers"""
return {
"filter_headers": [("authorization", "DUMMY_API_KEY")],
}
@pytest.fixture(scope="session")
def docker_compose_file(pytestconfig):
return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")

View File

@@ -1,7 +1,7 @@
version: '3.8'
version: "3.8"
services:
postgres_test:
image: postgres:15
image: postgres:17
environment:
POSTGRES_DB: reflector_test
POSTGRES_USER: test_user
@@ -10,4 +10,4 @@ services:
- "15432:5432"
command: postgres -c fsync=off -c synchronous_commit=off -c full_page_writes=off
tmpfs:
- /var/lib/postgresql/data:rw,noexec,nosuid,size=1g
- /var/lib/postgresql/data:rw,noexec,nosuid,size=1g

View File

@@ -0,0 +1,330 @@
"""
Tests for GPU Modal transcription endpoints.
These tests are marked with the "gpu-modal" group and will not run by default.
Run them with: pytest -m gpu-modal tests/test_gpu_modal_transcript_parakeet.py
Required environment variables:
- TRANSCRIPT_URL: URL to the Modal.com endpoint (required)
- TRANSCRIPT_MODAL_API_KEY: API key for authentication (optional)
- TRANSCRIPT_MODEL: Model name to use (optional, defaults to nvidia/parakeet-tdt-0.6b-v2)
Example with pytest (override default addopts to run ONLY gpu_modal tests):
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web-dev.modal.run \
TRANSCRIPT_MODAL_API_KEY=your-api-key \
uv run -m pytest -m gpu_modal --no-cov tests/test_gpu_modal_transcript.py
# Or with completely clean options:
uv run -m pytest -m gpu_modal -o addopts="" tests/
Running Modal locally for testing:
modal serve gpu/modal_deployments/reflector_transcriber_parakeet.py
# This will give you a local URL like https://xxxxx--reflector-transcriber-parakeet-web-dev.modal.run to test against
"""
import os
import tempfile
from pathlib import Path
import httpx
import pytest
# Test audio file URL for testing
TEST_AUDIO_URL = (
"https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
)
def get_modal_transcript_url():
"""Get and validate the Modal transcript URL from environment."""
url = os.environ.get("TRANSCRIPT_URL")
if not url:
pytest.skip(
"TRANSCRIPT_URL environment variable is required for GPU Modal tests"
)
return url
def get_auth_headers():
"""Get authentication headers if API key is available."""
api_key = os.environ.get("TRANSCRIPT_MODAL_API_KEY")
if api_key:
return {"Authorization": f"Bearer {api_key}"}
return {}
def get_model_name():
"""Get the model name from environment or use default."""
return os.environ.get("TRANSCRIPT_MODEL", "nvidia/parakeet-tdt-0.6b-v2")
@pytest.mark.gpu_modal
class TestGPUModalTranscript:
"""Test suite for GPU Modal transcription endpoints."""
def test_transcriptions_from_url(self):
"""Test the /v1/audio/transcriptions-from-url endpoint."""
url = get_modal_transcript_url()
headers = get_auth_headers()
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{url}/v1/audio/transcriptions-from-url",
json={
"audio_file_url": TEST_AUDIO_URL,
"model": get_model_name(),
"language": "en",
"timestamp_offset": 0.0,
},
headers=headers,
)
assert response.status_code == 200, f"Request failed: {response.text}"
result = response.json()
# Verify response structure
assert "text" in result
assert "words" in result
assert isinstance(result["text"], str)
assert isinstance(result["words"], list)
# Verify content is meaningful
assert len(result["text"]) > 0, "Transcript text should not be empty"
assert len(result["words"]) > 0, "Words list must not be empty"
# Verify word structure
for word in result["words"]:
assert "word" in word
assert "start" in word
assert "end" in word
assert isinstance(word["start"], (int, float))
assert isinstance(word["end"], (int, float))
assert word["start"] <= word["end"]
def test_transcriptions_single_file(self):
"""Test the /v1/audio/transcriptions endpoint with a single file."""
url = get_modal_transcript_url()
headers = get_auth_headers()
# Download test audio file to upload
with httpx.Client(timeout=60.0) as client:
audio_response = client.get(TEST_AUDIO_URL)
audio_response.raise_for_status()
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
tmp_file.write(audio_response.content)
tmp_file_path = tmp_file.name
try:
# Upload the file for transcription
with open(tmp_file_path, "rb") as f:
files = {"file": ("test_audio.mp3", f, "audio/mpeg")}
data = {
"model": get_model_name(),
"language": "en",
"batch": "false",
}
response = client.post(
f"{url}/v1/audio/transcriptions",
files=files,
data=data,
headers=headers,
)
assert response.status_code == 200, f"Request failed: {response.text}"
result = response.json()
# Verify response structure for single file
assert "text" in result
assert "words" in result
assert "filename" in result
assert isinstance(result["text"], str)
assert isinstance(result["words"], list)
# Verify content
assert len(result["text"]) > 0, "Transcript text should not be empty"
finally:
Path(tmp_file_path).unlink(missing_ok=True)
def test_transcriptions_multiple_files(self):
"""Test the /v1/audio/transcriptions endpoint with multiple files (non-batch mode)."""
url = get_modal_transcript_url()
headers = get_auth_headers()
# Create multiple test files (we'll use the same audio content for simplicity)
with httpx.Client(timeout=60.0) as client:
audio_response = client.get(TEST_AUDIO_URL)
audio_response.raise_for_status()
audio_content = audio_response.content
temp_files = []
try:
# Create 3 temporary files
for i in range(3):
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
tmp_file.write(audio_content)
tmp_file.close()
temp_files.append(tmp_file.name)
# Upload multiple files for transcription (non-batch)
files = [
("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
for i, f in enumerate(temp_files)
]
data = {
"model": get_model_name(),
"language": "en",
"batch": "false",
}
response = client.post(
f"{url}/v1/audio/transcriptions",
files=files,
data=data,
headers=headers,
)
# Close file handles
for _, file_tuple in files:
file_tuple[1].close()
assert response.status_code == 200, f"Request failed: {response.text}"
result = response.json()
# Verify response structure for multiple files (non-batch)
assert "results" in result
assert isinstance(result["results"], list)
assert len(result["results"]) == 3
for idx, file_result in enumerate(result["results"]):
assert "text" in file_result
assert "words" in file_result
assert "filename" in file_result
assert isinstance(file_result["text"], str)
assert isinstance(file_result["words"], list)
assert len(file_result["text"]) > 0
finally:
for f in temp_files:
Path(f).unlink(missing_ok=True)
def test_transcriptions_multiple_files_batch(self):
"""Test the /v1/audio/transcriptions endpoint with multiple files in batch mode."""
url = get_modal_transcript_url()
headers = get_auth_headers()
# Create multiple test files
with httpx.Client(timeout=60.0) as client:
audio_response = client.get(TEST_AUDIO_URL)
audio_response.raise_for_status()
audio_content = audio_response.content
temp_files = []
try:
# Create 3 temporary files
for i in range(3):
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
tmp_file.write(audio_content)
tmp_file.close()
temp_files.append(tmp_file.name)
# Upload multiple files for batch transcription
files = [
("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
for i, f in enumerate(temp_files)
]
data = {
"model": get_model_name(),
"language": "en",
"batch": "true",
}
response = client.post(
f"{url}/v1/audio/transcriptions",
files=files,
data=data,
headers=headers,
)
# Close file handles
for _, file_tuple in files:
file_tuple[1].close()
assert response.status_code == 200, f"Request failed: {response.text}"
result = response.json()
# Verify response structure for batch mode
assert "results" in result
assert isinstance(result["results"], list)
assert len(result["results"]) == 3
for idx, batch_result in enumerate(result["results"]):
assert "text" in batch_result
assert "words" in batch_result
assert "filename" in batch_result
assert isinstance(batch_result["text"], str)
assert isinstance(batch_result["words"], list)
assert len(batch_result["text"]) > 0
finally:
for f in temp_files:
Path(f).unlink(missing_ok=True)
def test_transcriptions_error_handling(self):
"""Test error handling for invalid requests."""
url = get_modal_transcript_url()
headers = get_auth_headers()
with httpx.Client(timeout=60.0) as client:
# Test with unsupported language
response = client.post(
f"{url}/v1/audio/transcriptions-from-url",
json={
"audio_file_url": TEST_AUDIO_URL,
"model": get_model_name(),
"language": "fr", # Parakeet only supports English
"timestamp_offset": 0.0,
},
headers=headers,
)
assert response.status_code == 400
assert "only supports English" in response.text
def test_transcriptions_with_timestamp_offset(self):
"""Test transcription with timestamp offset parameter."""
url = get_modal_transcript_url()
headers = get_auth_headers()
with httpx.Client(timeout=60.0) as client:
# Test with timestamp offset
response = client.post(
f"{url}/v1/audio/transcriptions-from-url",
json={
"audio_file_url": TEST_AUDIO_URL,
"model": get_model_name(),
"language": "en",
"timestamp_offset": 10.0, # Add 10 second offset
},
headers=headers,
)
assert response.status_code == 200, f"Request failed: {response.text}"
result = response.json()
# Verify response structure
assert "text" in result
assert "words" in result
assert len(result["words"]) > 0, "Words list must not be empty"
# Verify that timestamps have been offset
for word in result["words"]:
# All timestamps should be >= 10.0 due to offset
assert (
word["start"] >= 10.0
), f"Word start time {word['start']} should be >= 10.0"
assert (
word["end"] >= 10.0
), f"Word end time {word['end']} should be >= 10.0"

View File

@@ -0,0 +1,633 @@
"""
Tests for PipelineMainFile - file-based processing pipeline
This test verifies the complete file processing pipeline without mocking much,
ensuring all processors are correctly invoked and the happy path works correctly.
"""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from reflector.pipelines.main_file_pipeline import PipelineMainFile
from reflector.processors.file_diarization import FileDiarizationOutput
from reflector.processors.types import (
DiarizationSegment,
TitleSummary,
Word,
)
from reflector.processors.types import (
Transcript as TranscriptType,
)
@pytest.fixture
async def dummy_file_transcript():
"""Mock FileTranscriptAutoProcessor for file processing"""
from reflector.processors.file_transcript import FileTranscriptProcessor
class TestFileTranscriptProcessor(FileTranscriptProcessor):
async def _transcript(self, data):
return TranscriptType(
text="Hello world. How are you today?",
words=[
Word(start=0.0, end=0.5, text="Hello", speaker=0),
Word(start=0.5, end=0.6, text=" ", speaker=0),
Word(start=0.6, end=1.0, text="world", speaker=0),
Word(start=1.0, end=1.1, text=".", speaker=0),
Word(start=1.1, end=1.2, text=" ", speaker=0),
Word(start=1.2, end=1.5, text="How", speaker=0),
Word(start=1.5, end=1.6, text=" ", speaker=0),
Word(start=1.6, end=1.8, text="are", speaker=0),
Word(start=1.8, end=1.9, text=" ", speaker=0),
Word(start=1.9, end=2.1, text="you", speaker=0),
Word(start=2.1, end=2.2, text=" ", speaker=0),
Word(start=2.2, end=2.5, text="today", speaker=0),
Word(start=2.5, end=2.6, text="?", speaker=0),
],
)
with patch(
"reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__"
) as mock_auto:
mock_auto.return_value = TestFileTranscriptProcessor()
yield
@pytest.fixture
async def dummy_file_diarization():
"""Mock FileDiarizationAutoProcessor for file processing"""
from reflector.processors.file_diarization import FileDiarizationProcessor
class TestFileDiarizationProcessor(FileDiarizationProcessor):
async def _diarize(self, data):
return FileDiarizationOutput(
diarization=[
DiarizationSegment(start=0.0, end=1.1, speaker=0),
DiarizationSegment(start=1.2, end=2.6, speaker=1),
]
)
with patch(
"reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__"
) as mock_auto:
mock_auto.return_value = TestFileDiarizationProcessor()
yield
@pytest.fixture
async def mock_transcript_in_db(tmpdir):
"""Create a mock transcript in the database"""
from reflector.db.transcripts import Transcript
from reflector.settings import settings
# Set the DATA_DIR to our tmpdir
original_data_dir = settings.DATA_DIR
settings.DATA_DIR = str(tmpdir)
transcript_id = str(uuid4())
data_path = Path(tmpdir) / transcript_id
data_path.mkdir(parents=True, exist_ok=True)
# Create mock transcript object
transcript = Transcript(
id=transcript_id,
name="Test Transcript",
status="processing",
source_kind="file",
source_language="en",
target_language="en",
)
# Mock the controller to return our transcript
try:
with patch(
"reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
) as mock_get:
mock_get.return_value = transcript
with patch(
"reflector.pipelines.main_live_pipeline.transcripts_controller.get_by_id"
) as mock_get2:
mock_get2.return_value = transcript
with patch(
"reflector.pipelines.main_live_pipeline.transcripts_controller.update"
) as mock_update:
mock_update.return_value = None
yield transcript
finally:
# Restore original DATA_DIR
settings.DATA_DIR = original_data_dir
@pytest.fixture
async def mock_storage():
"""Mock storage for file uploads"""
from reflector.storage.base import Storage
class TestStorage(Storage):
async def _put_file(self, path, data):
return None
async def _get_file_url(self, path):
return f"http://test-storage/{path}"
async def _get_file(self, path):
return b"test_audio_data"
async def _delete_file(self, path):
return None
storage = TestStorage()
# Add mock tracking for verification
storage._put_file = AsyncMock(side_effect=storage._put_file)
storage._get_file_url = AsyncMock(side_effect=storage._get_file_url)
with patch(
"reflector.pipelines.main_file_pipeline.get_transcripts_storage"
) as mock_get:
mock_get.return_value = storage
yield storage
@pytest.fixture
async def mock_audio_file_writer():
"""Mock AudioFileWriterProcessor to avoid actual file writing"""
with patch(
"reflector.pipelines.main_file_pipeline.AudioFileWriterProcessor"
) as mock_writer_class:
mock_writer = AsyncMock()
mock_writer.push = AsyncMock()
mock_writer.flush = AsyncMock()
mock_writer_class.return_value = mock_writer
yield mock_writer
@pytest.fixture
async def mock_waveform_processor():
"""Mock AudioWaveformProcessor"""
with patch(
"reflector.pipelines.main_file_pipeline.AudioWaveformProcessor"
) as mock_waveform_class:
mock_waveform = AsyncMock()
mock_waveform.set_pipeline = MagicMock()
mock_waveform.flush = AsyncMock()
mock_waveform_class.return_value = mock_waveform
yield mock_waveform
@pytest.fixture
async def mock_topic_detector():
"""Mock TranscriptTopicDetectorProcessor"""
with patch(
"reflector.pipelines.main_file_pipeline.TranscriptTopicDetectorProcessor"
) as mock_topic_class:
mock_topic = AsyncMock()
mock_topic.set_pipeline = MagicMock()
mock_topic.push = AsyncMock()
mock_topic.flush_called = False
# When flush is called, simulate topic detection by calling the callback
async def flush_with_callback():
mock_topic.flush_called = True
if hasattr(mock_topic, "_callback"):
# Create a minimal transcript for the TitleSummary
test_transcript = TranscriptType(words=[], text="test transcript")
await mock_topic._callback(
TitleSummary(
title="Test Topic",
summary="Test topic summary",
timestamp=0.0,
duration=10.0,
transcript=test_transcript,
)
)
mock_topic.flush = flush_with_callback
def init_with_callback(callback=None):
mock_topic._callback = callback
return mock_topic
mock_topic_class.side_effect = init_with_callback
yield mock_topic
@pytest.fixture
async def mock_title_processor():
"""Mock TranscriptFinalTitleProcessor"""
with patch(
"reflector.pipelines.main_file_pipeline.TranscriptFinalTitleProcessor"
) as mock_title_class:
mock_title = AsyncMock()
mock_title.set_pipeline = MagicMock()
mock_title.push = AsyncMock()
mock_title.flush_called = False
# When flush is called, simulate title generation by calling the callback
async def flush_with_callback():
mock_title.flush_called = True
if hasattr(mock_title, "_callback"):
from reflector.processors.types import FinalTitle
await mock_title._callback(FinalTitle(title="Test Title"))
mock_title.flush = flush_with_callback
def init_with_callback(callback=None):
mock_title._callback = callback
return mock_title
mock_title_class.side_effect = init_with_callback
yield mock_title
@pytest.fixture
async def mock_summary_processor():
"""Mock TranscriptFinalSummaryProcessor"""
with patch(
"reflector.pipelines.main_file_pipeline.TranscriptFinalSummaryProcessor"
) as mock_summary_class:
mock_summary = AsyncMock()
mock_summary.set_pipeline = MagicMock()
mock_summary.push = AsyncMock()
mock_summary.flush_called = False
# When flush is called, simulate summary generation by calling the callbacks
async def flush_with_callback():
mock_summary.flush_called = True
from reflector.processors.types import FinalLongSummary, FinalShortSummary
if hasattr(mock_summary, "_callback"):
await mock_summary._callback(
FinalLongSummary(long_summary="Test long summary", duration=10.0)
)
if hasattr(mock_summary, "_on_short_summary"):
await mock_summary._on_short_summary(
FinalShortSummary(short_summary="Test short summary", duration=10.0)
)
mock_summary.flush = flush_with_callback
def init_with_callback(transcript=None, callback=None, on_short_summary=None):
mock_summary._callback = callback
mock_summary._on_short_summary = on_short_summary
return mock_summary
mock_summary_class.side_effect = init_with_callback
yield mock_summary
@pytest.mark.asyncio
async def test_pipeline_main_file_process(
tmpdir,
mock_transcript_in_db,
dummy_file_transcript,
dummy_file_diarization,
mock_storage,
mock_audio_file_writer,
mock_waveform_processor,
mock_topic_detector,
mock_title_processor,
mock_summary_processor,
):
"""
Test the complete PipelineMainFile processing pipeline.
This test verifies:
1. Audio extraction and writing
2. Audio upload to storage
3. Parallel processing of transcription, diarization, and waveform
4. Assembly of transcript with diarization
5. Topic detection
6. Title and summary generation
"""
# Create a test audio file
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
# Copy test audio to the transcript's data path as if it was uploaded
upload_path = mock_transcript_in_db.data_path / "upload.wav"
upload_path.write_bytes(test_audio_path.read_bytes())
# Also create the audio.mp3 file that would be created by AudioFileWriterProcessor
# Since we're mocking AudioFileWriterProcessor, we need to create this manually
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
mp3_path.write_bytes(b"mock_mp3_data")
# Track callback invocations
callback_marks = {
"on_status": [],
"on_duration": [],
"on_waveform": [],
"on_topic": [],
"on_title": [],
"on_long_summary": [],
"on_short_summary": [],
}
# Create pipeline with mocked callbacks
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
# Override callbacks to track invocations
async def track_callback(name, data):
callback_marks[name].append(data)
# Call the original callback
original = getattr(PipelineMainFile, name)
return await original(pipeline, data)
for callback_name in callback_marks.keys():
setattr(
pipeline,
callback_name,
lambda data, n=callback_name: track_callback(n, data),
)
# Mock av.open for audio processing
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
# Mock container for checking video streams
mock_container = MagicMock()
mock_container.streams.video = [] # No video streams (audio only)
mock_container.close = MagicMock()
# Mock container for decoding audio frames
mock_decode_container = MagicMock()
mock_decode_container.decode.return_value = iter(
[MagicMock()]
) # One mock audio frame
mock_decode_container.close = MagicMock()
# Return different containers for different calls
mock_av.side_effect = [mock_container, mock_decode_container]
# Run the pipeline
await pipeline.process(upload_path)
# Verify audio extraction and writing
assert mock_audio_file_writer.push.called
assert mock_audio_file_writer.flush.called
# Verify storage upload
assert mock_storage._put_file.called
assert mock_storage._get_file_url.called
# Verify waveform generation
assert mock_waveform_processor.flush.called
assert mock_waveform_processor.set_pipeline.called
# Verify topic detection
assert mock_topic_detector.push.called
assert mock_topic_detector.flush_called
# Verify title generation
assert mock_title_processor.push.called
assert mock_title_processor.flush_called
# Verify summary generation
assert mock_summary_processor.push.called
assert mock_summary_processor.flush_called
# Verify callbacks were invoked
assert len(callback_marks["on_topic"]) > 0, "Topic callback should be invoked"
assert len(callback_marks["on_title"]) > 0, "Title callback should be invoked"
assert (
len(callback_marks["on_long_summary"]) > 0
), "Long summary callback should be invoked"
assert (
len(callback_marks["on_short_summary"]) > 0
), "Short summary callback should be invoked"
print(f"Callback marks: {callback_marks}")
# Verify the pipeline completed successfully
assert pipeline.logger is not None
print("PipelineMainFile test completed successfully!")
@pytest.mark.asyncio
async def test_pipeline_main_file_with_video(
tmpdir,
mock_transcript_in_db,
dummy_file_transcript,
dummy_file_diarization,
mock_storage,
mock_audio_file_writer,
mock_waveform_processor,
mock_topic_detector,
mock_title_processor,
mock_summary_processor,
):
"""
Test PipelineMainFile with video input (verifies audio extraction).
"""
# Create a test audio file
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
# Copy test audio to the transcript's data path as if it was a video upload
upload_path = mock_transcript_in_db.data_path / "upload.mp4"
upload_path.write_bytes(test_audio_path.read_bytes())
# Also create the audio.mp3 file that would be created by AudioFileWriterProcessor
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
mp3_path.write_bytes(b"mock_mp3_data")
# Create pipeline
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
# Mock av.open for video processing
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
# Mock container for checking video streams
mock_container = MagicMock()
mock_container.streams.video = [MagicMock()] # Has video streams
mock_container.close = MagicMock()
# Mock container for decoding audio frames
mock_decode_container = MagicMock()
mock_decode_container.decode.return_value = iter(
[MagicMock()]
) # One mock audio frame
mock_decode_container.close = MagicMock()
# Return different containers for different calls
mock_av.side_effect = [mock_container, mock_decode_container]
# Run the pipeline
await pipeline.process(upload_path)
# Verify audio extraction from video
assert mock_audio_file_writer.push.called
assert mock_audio_file_writer.flush.called
# Verify the rest of the pipeline completed
assert mock_storage._put_file.called
assert mock_waveform_processor.flush.called
assert mock_topic_detector.push.called
assert mock_title_processor.push.called
assert mock_summary_processor.push.called
print("PipelineMainFile video test completed successfully!")
@pytest.mark.asyncio
async def test_pipeline_main_file_no_diarization(
tmpdir,
mock_transcript_in_db,
dummy_file_transcript,
mock_storage,
mock_audio_file_writer,
mock_waveform_processor,
mock_topic_detector,
mock_title_processor,
mock_summary_processor,
):
"""
Test PipelineMainFile with diarization disabled.
"""
from reflector.settings import settings
# Disable diarization
with patch.object(settings, "DIARIZATION_BACKEND", None):
# Create a test audio file
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
# Copy test audio to the transcript's data path
upload_path = mock_transcript_in_db.data_path / "upload.wav"
upload_path.write_bytes(test_audio_path.read_bytes())
# Also create the audio.mp3 file
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
mp3_path.write_bytes(b"mock_mp3_data")
# Create pipeline
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
# Mock av.open for audio processing
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
# Mock container for checking video streams
mock_container = MagicMock()
mock_container.streams.video = [] # No video streams
mock_container.close = MagicMock()
# Mock container for decoding audio frames
mock_decode_container = MagicMock()
mock_decode_container.decode.return_value = iter([MagicMock()])
mock_decode_container.close = MagicMock()
# Return different containers for different calls
mock_av.side_effect = [mock_container, mock_decode_container]
# Run the pipeline
await pipeline.process(upload_path)
# Verify the pipeline completed without diarization
assert mock_storage._put_file.called
assert mock_waveform_processor.flush.called
assert mock_topic_detector.push.called
assert mock_title_processor.push.called
assert mock_summary_processor.push.called
print("PipelineMainFile no-diarization test completed successfully!")
@pytest.mark.asyncio
async def test_task_pipeline_file_process(
tmpdir,
mock_transcript_in_db,
dummy_file_transcript,
dummy_file_diarization,
mock_storage,
mock_audio_file_writer,
mock_waveform_processor,
mock_topic_detector,
mock_title_processor,
mock_summary_processor,
):
"""
Test the Celery task entry point for file pipeline processing.
"""
# Direct import of the underlying async function, bypassing the asynctask decorator
# Create a test audio file in the transcript's data path
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
upload_path = mock_transcript_in_db.data_path / "upload.wav"
upload_path.write_bytes(test_audio_path.read_bytes())
# Also create the audio.mp3 file
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
mp3_path.write_bytes(b"mock_mp3_data")
# Mock av.open for audio processing
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
# Mock container for checking video streams
mock_container = MagicMock()
mock_container.streams.video = [] # No video streams
mock_container.close = MagicMock()
# Mock container for decoding audio frames
mock_decode_container = MagicMock()
mock_decode_container.decode.return_value = iter([MagicMock()])
mock_decode_container.close = MagicMock()
# Return different containers for different calls
mock_av.side_effect = [mock_container, mock_decode_container]
# Get the original async function without the asynctask decorator
# The function is wrapped, so we need to call it differently
# For now, we test the pipeline directly since the task is just a thin wrapper
from reflector.pipelines.main_file_pipeline import PipelineMainFile
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
await pipeline.process(upload_path)
# Verify the pipeline was executed through the task
assert mock_audio_file_writer.push.called
assert mock_audio_file_writer.flush.called
assert mock_storage._put_file.called
assert mock_waveform_processor.flush.called
assert mock_topic_detector.push.called
assert mock_title_processor.push.called
assert mock_summary_processor.push.called
print("task_pipeline_file_process test completed successfully!")
@pytest.mark.asyncio
async def test_pipeline_file_process_no_transcript():
"""
Test the pipeline with a non-existent transcript.
"""
from reflector.pipelines.main_file_pipeline import PipelineMainFile
# Mock the controller to return None (transcript not found)
with patch(
"reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
) as mock_get:
mock_get.return_value = None
pipeline = PipelineMainFile(transcript_id=str(uuid4()))
# Should raise an exception for missing transcript when get_transcript is called
with pytest.raises(Exception, match="Transcript not found"):
await pipeline.get_transcript()
@pytest.mark.asyncio
async def test_pipeline_file_process_no_audio_file(
mock_transcript_in_db,
):
"""
Test the pipeline when no audio file is found.
"""
from reflector.pipelines.main_file_pipeline import PipelineMainFile
# Don't create any audio files in the data path
# The pipeline's process should handle missing files gracefully
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
# Try to process a non-existent file
non_existent_path = mock_transcript_in_db.data_path / "nonexistent.wav"
# This should fail when trying to open the file with av
with pytest.raises(Exception):
await pipeline.process(non_existent_path)

View File

@@ -0,0 +1,265 @@
"""
Tests for Modal-based processors using pytest-recording for HTTP recording/playbook
Note: theses tests require full modal configuration to be able to record
vcr cassettes
Configuration required for the first recording:
- TRANSCRIPT_BACKEND=modal
- TRANSCRIPT_URL=https://xxxxx--reflector-transcriber-parakeet-web.modal.run
- TRANSCRIPT_MODAL_API_KEY=xxxxx
- DIARIZATION_BACKEND=modal
- DIARIZATION_URL=https://xxxxx--reflector-diarizer-web.modal.run
- DIARIZATION_MODAL_API_KEY=xxxxx
"""
from unittest.mock import patch
import pytest
from reflector.processors.file_diarization import FileDiarizationInput
from reflector.processors.file_diarization_modal import FileDiarizationModalProcessor
from reflector.processors.file_transcript import FileTranscriptInput
from reflector.processors.file_transcript_modal import FileTranscriptModalProcessor
from reflector.processors.transcript_diarization_assembler import (
TranscriptDiarizationAssemblerInput,
TranscriptDiarizationAssemblerProcessor,
)
from reflector.processors.types import DiarizationSegment, Transcript, Word
# Public test audio file hosted on S3 specifically for reflector pytests
TEST_AUDIO_URL = (
"https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
)
@pytest.mark.asyncio
async def test_file_transcript_modal_processor_missing_url():
with patch("reflector.processors.file_transcript_modal.settings") as mock_settings:
mock_settings.TRANSCRIPT_URL = None
with pytest.raises(Exception, match="TRANSCRIPT_URL required"):
FileTranscriptModalProcessor(modal_api_key="test-api-key")
@pytest.mark.asyncio
async def test_file_diarization_modal_processor_missing_url():
with patch("reflector.processors.file_diarization_modal.settings") as mock_settings:
mock_settings.DIARIZATION_URL = None
with pytest.raises(Exception, match="DIARIZATION_URL required"):
FileDiarizationModalProcessor(modal_api_key="test-api-key")
@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_file_diarization_modal_processor(vcr):
"""Test FileDiarizationModalProcessor using public audio URL and Modal API"""
from reflector.settings import settings
processor = FileDiarizationModalProcessor(
modal_api_key=settings.DIARIZATION_MODAL_API_KEY
)
test_input = FileDiarizationInput(audio_url=TEST_AUDIO_URL)
result = await processor._diarize(test_input)
# Verify the result structure
assert result is not None
assert hasattr(result, "diarization")
assert isinstance(result.diarization, list)
# Check structure of each diarization segment
for segment in result.diarization:
assert "start" in segment
assert "end" in segment
assert "speaker" in segment
assert isinstance(segment["start"], (int, float))
assert isinstance(segment["end"], (int, float))
assert isinstance(segment["speaker"], int)
# Basic sanity check - start should be before end
assert segment["start"] < segment["end"]
@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_file_transcript_modal_processor():
"""Test FileTranscriptModalProcessor using public audio URL and Modal API"""
from reflector.settings import settings
processor = FileTranscriptModalProcessor(
modal_api_key=settings.TRANSCRIPT_MODAL_API_KEY
)
test_input = FileTranscriptInput(
audio_url=TEST_AUDIO_URL,
language="en",
)
# This will record the HTTP interaction on first run, replay on subsequent runs
result = await processor._transcript(test_input)
# Verify the result structure
assert result is not None
assert hasattr(result, "words")
assert isinstance(result.words, list)
# Check structure of each word if present
for word in result.words:
assert hasattr(word, "text")
assert hasattr(word, "start")
assert hasattr(word, "end")
assert isinstance(word.start, (int, float))
assert isinstance(word.end, (int, float))
assert isinstance(word.text, str)
# Basic sanity check - start should be before or equal to end
assert word.start <= word.end
@pytest.mark.asyncio
async def test_transcript_diarization_assembler_processor():
"""Test TranscriptDiarizationAssemblerProcessor without VCR (no HTTP requests)"""
# Create test transcript with words
words = [
Word(text="Hello", start=0.0, end=1.0, speaker=0),
Word(text=" ", start=1.0, end=1.1, speaker=0),
Word(text="world", start=1.1, end=2.0, speaker=0),
Word(text=".", start=2.0, end=2.1, speaker=0),
Word(text=" ", start=2.1, end=2.2, speaker=0),
Word(text="How", start=2.2, end=2.8, speaker=0),
Word(text=" ", start=2.8, end=2.9, speaker=0),
Word(text="are", start=2.9, end=3.2, speaker=0),
Word(text=" ", start=3.2, end=3.3, speaker=0),
Word(text="you", start=3.3, end=3.8, speaker=0),
Word(text="?", start=3.8, end=3.9, speaker=0),
]
transcript = Transcript(words=words)
# Create test diarization segments
diarization = [
DiarizationSegment(start=0.0, end=2.1, speaker=0),
DiarizationSegment(start=2.1, end=3.9, speaker=1),
]
# Create processor and test input
processor = TranscriptDiarizationAssemblerProcessor()
test_input = TranscriptDiarizationAssemblerInput(
transcript=transcript, diarization=diarization
)
# Track emitted results
emitted_results = []
async def capture_result(result):
emitted_results.append(result)
processor.on(capture_result)
# Process the input
await processor.push(test_input)
# Verify result was emitted
assert len(emitted_results) == 1
result = emitted_results[0]
# Verify result structure
assert isinstance(result, Transcript)
assert len(result.words) == len(words)
# Verify speaker assignments were applied
# Words 0-3 (indices) should be speaker 0 (time 0.0-2.0)
# Words 4-10 (indices) should be speaker 1 (time 2.1-3.9)
for i in range(4): # First 4 words (Hello, space, world, .)
assert (
result.words[i].speaker == 0
), f"Word {i} '{result.words[i].text}' should be speaker 0, got {result.words[i].speaker}"
for i in range(4, 11): # Remaining words (space, How, space, are, space, you, ?)
assert (
result.words[i].speaker == 1
), f"Word {i} '{result.words[i].text}' should be speaker 1, got {result.words[i].speaker}"
@pytest.mark.asyncio
async def test_transcript_diarization_assembler_no_diarization():
"""Test TranscriptDiarizationAssemblerProcessor with no diarization data"""
# Create test transcript
words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
transcript = Transcript(words=words)
# Create processor and test input with empty diarization
processor = TranscriptDiarizationAssemblerProcessor()
test_input = TranscriptDiarizationAssemblerInput(
transcript=transcript, diarization=[]
)
# Track emitted results
emitted_results = []
async def capture_result(result):
emitted_results.append(result)
processor.on(capture_result)
# Process the input
await processor.push(test_input)
# Verify original transcript was returned unchanged
assert len(emitted_results) == 1
result = emitted_results[0]
assert result is transcript # Should be the same object
assert result.words[0].speaker == 0 # Original speaker unchanged
@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_full_modal_pipeline_integration(vcr):
"""Integration test: Transcription -> Diarization -> Assembly
This test demonstrates the full pipeline:
1. Run transcription via Modal
2. Run diarization via Modal
3. Assemble transcript with diarization
"""
from reflector.settings import settings
# Step 1: Transcription
transcript_processor = FileTranscriptModalProcessor(
modal_api_key=settings.TRANSCRIPT_MODAL_API_KEY
)
transcript_input = FileTranscriptInput(audio_url=TEST_AUDIO_URL, language="en")
transcript = await transcript_processor._transcript(transcript_input)
# Step 2: Diarization
diarization_processor = FileDiarizationModalProcessor(
modal_api_key=settings.DIARIZATION_MODAL_API_KEY
)
diarization_input = FileDiarizationInput(audio_url=TEST_AUDIO_URL)
diarization_result = await diarization_processor._diarize(diarization_input)
# Step 3: Assembly
assembler = TranscriptDiarizationAssemblerProcessor()
assembly_input = TranscriptDiarizationAssemblerInput(
transcript=transcript, diarization=diarization_result.diarization
)
# Track assembled result
assembled_results = []
async def capture_result(result):
assembled_results.append(result)
assembler.on(capture_result)
await assembler.push(assembly_input)
# Verify the full pipeline worked
assert len(assembled_results) == 1
final_transcript = assembled_results[0]
# Verify the final transcript has the original words with updated speaker info
assert isinstance(final_transcript, Transcript)
assert len(final_transcript.words) == len(transcript.words)
assert len(final_transcript.words) > 0
# Verify some words have been assigned speakers from diarization
speakers_found = set(word.speaker for word in final_transcript.words)
assert len(speakers_found) > 0 # At least some speaker assignments

View File

@@ -2,10 +2,13 @@ import pytest
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_diarization", [False, True])
async def test_basic_process(
dummy_transcript,
dummy_llm,
dummy_processors,
enable_diarization,
dummy_diarization,
):
# goal is to start the server, and send rtc audio to it
# validate the events received
@@ -28,12 +31,31 @@ async def test_basic_process(
# invoke the process and capture events
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
await process_audio_file(path.as_posix(), event_callback)
print(marks)
if enable_diarization:
# Test with diarization - may fail if pyannote.audio is not installed
try:
await process_audio_file(
path.as_posix(), event_callback, enable_diarization=True
)
except SystemExit:
pytest.skip("pyannote.audio not installed - skipping diarization test")
else:
# Test without diarization - should always work
await process_audio_file(
path.as_posix(), event_callback, enable_diarization=False
)
print(f"Diarization: {enable_diarization}, Marks: {marks}")
# validate the events
assert marks["TranscriptLinerProcessor"] == 1
assert marks["TranscriptTranslatorPassthroughProcessor"] == 1
# Each processor should be called for each audio segment processed
# The final processors (Topic, Title, Summary) should be called once at the end
assert marks["TranscriptLinerProcessor"] > 0
assert marks["TranscriptTranslatorPassthroughProcessor"] > 0
assert marks["TranscriptTopicDetectorProcessor"] == 1
assert marks["TranscriptFinalSummaryProcessor"] == 1
assert marks["TranscriptFinalTitleProcessor"] == 1
if enable_diarization:
assert marks["TestAudioDiarizationProcessor"] == 1

View File

@@ -2,13 +2,18 @@
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import ValidationError
from reflector.db import get_database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
from reflector.db.search import (
SearchController,
SearchParameters,
SearchResult,
search_controller,
)
from reflector.db.transcripts import SourceKind, transcripts
@pytest.mark.asyncio
@@ -18,39 +23,137 @@ async def test_search_postgresql_only():
assert results == []
assert total == 0
try:
SearchParameters(query_text="")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" ")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
params_empty = SearchParameters(query_text="")
results_empty, total_empty = await search_controller.search_transcripts(
params_empty
)
assert isinstance(results_empty, list)
assert isinstance(total_empty, int)
@pytest.mark.asyncio
async def test_search_input_validation():
try:
SearchParameters(query_text="")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
async def test_search_with_empty_query():
"""Test that empty query returns all transcripts."""
params = SearchParameters(query_text="")
results, total = await search_controller.search_transcripts(params)
assert isinstance(results, list)
assert isinstance(total, int)
if len(results) > 1:
for i in range(len(results) - 1):
assert results[i].created_at >= results[i + 1].created_at
@pytest.mark.asyncio
async def test_empty_transcript_title_only_match():
"""Test that transcripts with title-only matches return empty snippets."""
test_id = "test-empty-9b3f2a8d"
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" \t\n ")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
"id": test_id,
"name": "Empty Transcript",
"title": "Empty Meeting",
"status": "completed",
"locked": False,
"duration": 0.0,
"created_at": datetime.now(timezone.utc),
"short_summary": None,
"long_summary": None,
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": None,
"user_id": "test-user-1",
}
await get_database().execute(transcripts.insert().values(**test_data))
params = SearchParameters(query_text="empty", user_id="test-user-1")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = next((r for r in results if r.id == test_id), None)
assert found is not None, "Should find transcript by title match"
assert found.search_snippets == []
assert found.total_match_count == 0
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()
@pytest.mark.asyncio
async def test_search_with_long_summary():
"""Test that long_summary content is searchable."""
test_id = "test-long-summary-8a9f3c2d"
try:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
"id": test_id,
"name": "Test Long Summary",
"title": "Regular Meeting",
"status": "completed",
"locked": False,
"duration": 1800.0,
"created_at": datetime.now(timezone.utc),
"short_summary": "Brief overview",
"long_summary": "Detailed discussion about quantum computing applications and blockchain technology integration",
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": """WEBVTT
00:00:00.000 --> 00:00:10.000
Basic meeting content without special keywords.""",
"user_id": "test-user-2",
}
await get_database().execute(transcripts.insert().values(**test_data))
params = SearchParameters(query_text="quantum computing", user_id="test-user-2")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find transcript by long_summary content"
test_result = next((r for r in results if r.id == test_id), None)
assert test_result
assert len(test_result.search_snippets) > 0
assert "quantum computing" in test_result.search_snippets[0].lower()
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()
@pytest.mark.asyncio
async def test_postgresql_search_with_data():
# collision is improbable
test_id = "test-search-e2e-7f3a9b2c"
try:
@@ -90,32 +193,31 @@ The search feature should support complex queries with ranking.
00:00:30.000 --> 00:00:40.000
We need to implement PostgreSQL tsvector for better performance.""",
"user_id": "test-user-3",
}
await get_database().execute(transcripts.insert().values(**test_data))
# Test 1: Search for a word in title
params = SearchParameters(query_text="planning")
params = SearchParameters(query_text="planning", user_id="test-user-3")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by title word"
# Test 2: Search for a word in webvtt content
params = SearchParameters(query_text="tsvector")
params = SearchParameters(query_text="tsvector", user_id="test-user-3")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by webvtt content"
# Test 3: Search with multiple words
params = SearchParameters(query_text="engineering planning")
params = SearchParameters(
query_text="engineering planning", user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by multiple words"
# Test 4: Verify SearchResult structure
test_result = next((r for r in results if r.id == test_id), None)
if test_result:
assert test_result.title == "Engineering Planning Meeting Q4 2024"
@@ -123,15 +225,17 @@ We need to implement PostgreSQL tsvector for better performance.""",
assert test_result.duration == 1800.0
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
# Test 5: Search with OR operator
params = SearchParameters(query_text="tsvector OR nosuchword")
params = SearchParameters(
query_text="tsvector OR nosuchword", user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript with OR query"
# Test 6: Quoted phrase search
params = SearchParameters(query_text='"full-text search"')
params = SearchParameters(
query_text='"full-text search"', user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
@@ -142,3 +246,240 @@ We need to implement PostgreSQL tsvector for better performance.""",
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()
@pytest.fixture
def sample_search_params():
"""Create sample search parameters for testing."""
return SearchParameters(
query_text="test query",
limit=20,
offset=0,
user_id="test-user",
room_id="room1",
)
@pytest.fixture
def mock_db_result():
"""Create a mock database result."""
return {
"id": "test-transcript-id",
"title": "Test Transcript",
"created_at": datetime(2024, 6, 15, tzinfo=timezone.utc),
"duration": 3600.0,
"status": "completed",
"user_id": "test-user",
"room_id": "room1",
"source_kind": SourceKind.LIVE,
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\nThis is a test transcript",
"rank": 0.95,
}
class TestSearchParameters:
"""Test SearchParameters model validation and functionality."""
def test_search_parameters_with_available_filters(self):
"""Test creating SearchParameters with currently available filter options."""
params = SearchParameters(
query_text="search term",
limit=50,
offset=10,
user_id="user123",
room_id="room1",
)
assert params.query_text == "search term"
assert params.limit == 50
assert params.offset == 10
assert params.user_id == "user123"
assert params.room_id == "room1"
def test_search_parameters_defaults(self):
"""Test SearchParameters with default values."""
params = SearchParameters(query_text="test")
assert params.query_text == "test"
assert params.limit == 20
assert params.offset == 0
assert params.user_id is None
assert params.room_id is None
class TestSearchControllerFilters:
"""Test SearchController functionality with various filters."""
@pytest.mark.asyncio
async def test_search_with_source_kind_filter(self):
"""Test search filtering by source_kind."""
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
):
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
results, total = await controller.search_transcripts(params)
assert results == []
assert total == 0
mock_db.return_value.fetch_all.assert_called_once()
@pytest.mark.asyncio
async def test_search_with_single_room_id(self):
"""Test search filtering by single room ID (currently supported)."""
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
):
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
params = SearchParameters(
query_text="test",
room_id="room1",
)
results, total = await controller.search_transcripts(params)
assert results == []
assert total == 0
mock_db.return_value.fetch_all.assert_called_once()
@pytest.mark.asyncio
async def test_search_result_includes_available_fields(self, mock_db_result):
"""Test that search results include available fields like source_kind."""
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
):
class MockRow:
def __init__(self, data):
self._data = data
self._mapping = data
def __iter__(self):
return iter(self._data.items())
def __getitem__(self, key):
return self._data[key]
def keys(self):
return self._data.keys()
mock_row = MockRow(mock_db_result)
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
mock_db.return_value.fetch_val = AsyncMock(return_value=1)
params = SearchParameters(query_text="test")
results, total = await controller.search_transcripts(params)
assert total == 1
assert len(results) == 1
result = results[0]
assert isinstance(result, SearchResult)
assert result.id == "test-transcript-id"
assert result.title == "Test Transcript"
assert result.rank == 0.95
class TestSearchEndpointParsing:
"""Test parameter parsing in the search endpoint."""
def test_parse_comma_separated_room_ids(self):
"""Test parsing comma-separated room IDs."""
room_ids_str = "room1,room2,room3"
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
assert parsed == ["room1", "room2", "room3"]
room_ids_str = "room1, room2 , room3"
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
assert parsed == ["room1", "room2", "room3"]
room_ids_str = "room1,,room3,"
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
assert parsed == ["room1", "room3"]
def test_parse_source_kind(self):
"""Test parsing source_kind values."""
for kind_str in ["live", "file", "room"]:
parsed = SourceKind(kind_str)
assert parsed == SourceKind(kind_str)
with pytest.raises(ValueError):
SourceKind("invalid_kind")
class TestSearchResultModel:
"""Test SearchResult model and serialization."""
def test_search_result_with_available_fields(self):
"""Test SearchResult model with currently available fields populated."""
result = SearchResult(
id="test-id",
title="Test Title",
user_id="user-123",
room_id="room-456",
source_kind=SourceKind.ROOM,
created_at=datetime(2024, 6, 15, tzinfo=timezone.utc),
status="completed",
rank=0.85,
duration=1800.5,
search_snippets=["snippet 1", "snippet 2"],
)
assert result.id == "test-id"
assert result.title == "Test Title"
assert result.user_id == "user-123"
assert result.room_id == "room-456"
assert result.status == "completed"
assert result.rank == 0.85
assert result.duration == 1800.5
assert len(result.search_snippets) == 2
def test_search_result_with_optional_fields_none(self):
"""Test SearchResult model with optional fields as None."""
result = SearchResult(
id="test-id",
source_kind=SourceKind.FILE,
created_at=datetime.now(timezone.utc),
status="processing",
rank=0.5,
search_snippets=[],
title=None,
user_id=None,
room_id=None,
duration=None,
)
assert result.title is None
assert result.user_id is None
assert result.room_id is None
assert result.duration is None
def test_search_result_datetime_field(self):
"""Test that SearchResult accepts datetime field."""
result = SearchResult(
id="test-id",
source_kind=SourceKind.LIVE,
created_at=datetime(2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc),
status="completed",
rank=0.9,
duration=None,
search_snippets=[],
)
assert result.created_at == datetime(
2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc
)

View File

@@ -0,0 +1,166 @@
"""Tests for long_summary in search functionality."""
import json
from datetime import datetime, timezone
import pytest
from reflector.db import get_database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
@pytest.mark.asyncio
async def test_long_summary_snippet_prioritization():
"""Test that snippets from long_summary are prioritized over webvtt content."""
test_id = "test-snippet-priority-3f9a2b8c"
try:
# Clean up any existing test data
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
"id": test_id,
"name": "Test Snippet Priority",
"title": "Meeting About Projects",
"status": "completed",
"locked": False,
"duration": 1800.0,
"created_at": datetime.now(timezone.utc),
"short_summary": "Project discussion",
"long_summary": (
"The team discussed advanced robotics applications including "
"autonomous navigation systems and sensor fusion techniques. "
"Robotics development will focus on real-time processing."
),
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": """WEBVTT
00:00:00.000 --> 00:00:10.000
We talked about many different topics today.
00:00:10.000 --> 00:00:20.000
The robotics project is making good progress.
00:00:20.000 --> 00:00:30.000
We need to consider various implementation approaches.""",
"user_id": "test-user-priority",
}
await get_database().execute(transcripts.insert().values(**test_data))
# Search for "robotics" which appears in both long_summary and webvtt
params = SearchParameters(query_text="robotics", user_id="test-user-priority")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
test_result = next((r for r in results if r.id == test_id), None)
assert test_result, "Should find the test transcript"
snippets = test_result.search_snippets
assert len(snippets) > 0, "Should have at least one snippet"
# The first snippets should be from long_summary (more detailed content)
first_snippet = snippets[0].lower()
assert (
"advanced robotics" in first_snippet or "autonomous" in first_snippet
), f"First snippet should be from long_summary with detailed content. Got: {snippets[0]}"
# With max 3 snippets, we should get both from long_summary and webvtt
assert len(snippets) <= 3, "Should respect max snippets limit"
# All snippets should contain the search term
for snippet in snippets:
assert (
"robotics" in snippet.lower()
), f"Snippet should contain search term: {snippet}"
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()
@pytest.mark.asyncio
async def test_long_summary_only_search():
"""Test searching for content that only exists in long_summary."""
test_id = "test-long-only-8b3c9f2a"
try:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
"id": test_id,
"name": "Test Long Only",
"title": "Standard Meeting",
"status": "completed",
"locked": False,
"duration": 1800.0,
"created_at": datetime.now(timezone.utc),
"short_summary": "Team sync",
"long_summary": (
"Detailed analysis of cryptocurrency market trends and "
"decentralized finance protocols. Discussion included "
"yield farming strategies and liquidity pool mechanics."
),
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": """WEBVTT
00:00:00.000 --> 00:00:10.000
Team meeting about general project updates.
00:00:10.000 --> 00:00:20.000
Discussion of timeline and deliverables.""",
"user_id": "test-user-long",
}
await get_database().execute(transcripts.insert().values(**test_data))
# Search for terms only in long_summary
params = SearchParameters(query_text="cryptocurrency", user_id="test-user-long")
results, total = await search_controller.search_transcripts(params)
found = any(r.id == test_id for r in results)
assert found, "Should find transcript by long_summary-only content"
test_result = next((r for r in results if r.id == test_id), None)
assert test_result
assert len(test_result.search_snippets) > 0
# Verify the snippet is about cryptocurrency
snippet = test_result.search_snippets[0].lower()
assert "cryptocurrency" in snippet, "Snippet should contain the search term"
# Search for "yield farming" - a more specific term
params2 = SearchParameters(query_text="yield farming", user_id="test-user-long")
results2, total2 = await search_controller.search_transcripts(params2)
found2 = any(r.id == test_id for r in results2)
assert found2, "Should find transcript by specific long_summary phrase"
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()

View File

@@ -1,6 +1,10 @@
"""Unit tests for search snippet generation."""
from reflector.db.search import SearchController
from reflector.db.search import (
SnippetCandidate,
SnippetGenerator,
WebVTTProcessor,
)
class TestExtractWebVTT:
@@ -16,7 +20,7 @@ class TestExtractWebVTT:
00:00:10.000 --> 00:00:20.000
<v Speaker1>Indeed it is a test of WebVTT parsing.
"""
result = SearchController._extract_webvtt_text(webvtt)
result = WebVTTProcessor.extract_text(webvtt)
assert "Hello world, this is a test" in result
assert "Indeed it is a test" in result
assert "<v Speaker" not in result
@@ -25,12 +29,11 @@ class TestExtractWebVTT:
def test_extract_empty_webvtt(self):
"""Test empty WebVTT returns empty string."""
assert SearchController._extract_webvtt_text("") == ""
assert SearchController._extract_webvtt_text(None) == ""
assert WebVTTProcessor.extract_text("") == ""
def test_extract_malformed_webvtt(self):
"""Test malformed WebVTT returns empty string."""
result = SearchController._extract_webvtt_text("Not a valid WebVTT")
result = WebVTTProcessor.extract_text("Not a valid WebVTT")
assert result == ""
@@ -39,8 +42,7 @@ class TestGenerateSnippets:
def test_multiple_matches(self):
"""Test finding multiple occurrences of search term in long text."""
# Create text with Python mentions far apart to get separate snippets
separator = " This is filler text. " * 20 # ~400 chars of padding
separator = " This is filler text. " * 20
text = (
"Python is great for machine learning."
+ separator
@@ -51,18 +53,16 @@ class TestGenerateSnippets:
+ "The Python community is very supportive."
)
snippets = SearchController._generate_snippets(text, "Python")
# With enough separation, we should get multiple snippets
assert len(snippets) >= 2 # At least 2 distinct snippets
snippets = SnippetGenerator.generate(text, "Python")
assert len(snippets) >= 2
# Each snippet should contain "Python"
for snippet in snippets:
assert "python" in snippet.lower()
def test_single_match(self):
"""Test single occurrence returns one snippet."""
text = "This document discusses artificial intelligence and its applications."
snippets = SearchController._generate_snippets(text, "artificial intelligence")
snippets = SnippetGenerator.generate(text, "artificial intelligence")
assert len(snippets) == 1
assert "artificial intelligence" in snippets[0].lower()
@@ -70,24 +70,22 @@ class TestGenerateSnippets:
def test_no_matches(self):
"""Test no matches returns empty list."""
text = "This is some random text without the search term."
snippets = SearchController._generate_snippets(text, "machine learning")
snippets = SnippetGenerator.generate(text, "machine learning")
assert snippets == []
def test_case_insensitive_search(self):
"""Test search is case insensitive."""
# Add enough text between matches to get separate snippets
text = (
"MACHINE LEARNING is important for modern applications. "
+ "It requires lots of data and computational resources. " * 5 # Padding
+ "It requires lots of data and computational resources. " * 5
+ "Machine Learning rocks and transforms industries. "
+ "Deep learning is a subset of it. " * 5 # More padding
+ "Deep learning is a subset of it. " * 5
+ "Finally, machine learning will shape our future."
)
snippets = SearchController._generate_snippets(text, "machine learning")
snippets = SnippetGenerator.generate(text, "machine learning")
# Should find at least 2 (might be 3 if text is long enough)
assert len(snippets) >= 2
for snippet in snippets:
assert "machine learning" in snippet.lower()
@@ -95,61 +93,55 @@ class TestGenerateSnippets:
def test_partial_match_fallback(self):
"""Test fallback to first word when exact phrase not found."""
text = "We use machine intelligence for processing."
snippets = SearchController._generate_snippets(text, "machine learning")
snippets = SnippetGenerator.generate(text, "machine learning")
# Should fall back to finding "machine"
assert len(snippets) == 1
assert "machine" in snippets[0].lower()
def test_snippet_ellipsis(self):
"""Test ellipsis added for truncated snippets."""
# Long text where match is in the middle
text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
snippets = SearchController._generate_snippets(text, "TARGET_WORD")
snippets = SnippetGenerator.generate(text, "TARGET_WORD")
assert len(snippets) == 1
assert "..." in snippets[0] # Should have ellipsis
assert "..." in snippets[0]
assert "TARGET_WORD" in snippets[0]
def test_overlapping_snippets_deduplicated(self):
"""Test overlapping matches don't create duplicate snippets."""
text = "test test test word" * 10 # Repeated pattern
snippets = SearchController._generate_snippets(text, "test")
text = "test test test word" * 10
snippets = SnippetGenerator.generate(text, "test")
# Should get unique snippets, not duplicates
assert len(snippets) <= 3
assert len(snippets) == len(set(snippets)) # All unique
assert len(snippets) == len(set(snippets))
def test_empty_inputs(self):
"""Test empty text or search term returns empty list."""
assert SearchController._generate_snippets("", "search") == []
assert SearchController._generate_snippets("text", "") == []
assert SearchController._generate_snippets("", "") == []
assert SnippetGenerator.generate("", "search") == []
assert SnippetGenerator.generate("text", "") == []
assert SnippetGenerator.generate("", "") == []
def test_max_snippets_limit(self):
"""Test respects max_snippets parameter."""
# Create text with well-separated occurrences
separator = " filler " * 50 # Ensure snippets don't overlap
text = ("Python is amazing" + separator) * 10 # 10 occurrences
separator = " filler " * 50
text = ("Python is amazing" + separator) * 10
# Test with different limits
snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
snippets_1 = SnippetGenerator.generate(text, "Python", max_snippets=1)
assert len(snippets_1) == 1
snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
snippets_2 = SnippetGenerator.generate(text, "Python", max_snippets=2)
assert len(snippets_2) == 2
snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
assert len(snippets_5) == 5 # Should get exactly 5 with enough separation
snippets_5 = SnippetGenerator.generate(text, "Python", max_snippets=5)
assert len(snippets_5) == 5
def test_snippet_length(self):
"""Test snippet length is reasonable."""
text = "word " * 200 # Long text
snippets = SearchController._generate_snippets(text, "word")
text = "word " * 200
snippets = SnippetGenerator.generate(text, "word")
for snippet in snippets:
# Default max_length is 150 + some context
assert len(snippet) <= 200 # Some buffer for ellipsis
assert len(snippet) <= 200
class TestFullPipeline:
@@ -157,7 +149,6 @@ class TestFullPipeline:
def test_webvtt_to_snippets_integration(self):
"""Test full pipeline from WebVTT to search snippets."""
# Create WebVTT with well-separated content for multiple snippets
webvtt = (
"""WEBVTT
@@ -182,17 +173,362 @@ class TestFullPipeline:
"""
)
# Extract and generate snippets
plain_text = SearchController._extract_webvtt_text(webvtt)
snippets = SearchController._generate_snippets(plain_text, "machine learning")
plain_text = WebVTTProcessor.extract_text(webvtt)
snippets = SnippetGenerator.generate(plain_text, "machine learning")
# Should find at least 2 snippets (text might still be close together)
assert len(snippets) >= 1 # At minimum one snippet containing matches
assert len(snippets) <= 3 # At most 3 by default
assert len(snippets) >= 1
assert len(snippets) <= 3
# No WebVTT artifacts in snippets
for snippet in snippets:
assert "machine learning" in snippet.lower()
assert "<v Speaker" not in snippet
assert "00:00" not in snippet
assert "-->" not in snippet
class TestMultiWordQueryBehavior:
"""Tests for multi-word query behavior and exact phrase matching."""
def test_multi_word_query_snippet_behavior(self):
"""Test that multi-word queries generate snippets based on exact phrase matching."""
sample_text = """This is a sample transcript where user Alice is talking.
Later in the conversation, jordan mentions something important.
The user jordan collaboration was successful.
Another user named Bob joins the discussion."""
user_snippets = SnippetGenerator.generate(sample_text, "user")
assert len(user_snippets) == 2, "Should find 2 snippets for 'user'"
jordan_snippets = SnippetGenerator.generate(sample_text, "jordan")
assert len(jordan_snippets) >= 1, "Should find at least 1 snippet for 'jordan'"
multi_word_snippets = SnippetGenerator.generate(sample_text, "user jordan")
assert len(multi_word_snippets) == 1, (
"Should return exactly 1 snippet for 'user jordan' "
"(only the exact phrase match, not individual word occurrences)"
)
snippet = multi_word_snippets[0]
assert (
"user jordan" in snippet.lower()
), "The snippet should contain the exact phrase 'user jordan'"
assert (
"alice" not in snippet.lower()
), "The snippet should not include the first standalone 'user' with Alice"
def test_multi_word_query_without_exact_match(self):
"""Test snippet generation when exact phrase is not found."""
sample_text = """User Alice is here. Bob and jordan are talking.
Later jordan mentions something. The user is happy."""
snippets = SnippetGenerator.generate(sample_text, "user jordan")
assert (
len(snippets) >= 1
), "Should find at least 1 snippet when falling back to first word"
all_snippets_text = " ".join(snippets).lower()
assert (
"user" in all_snippets_text
), "Snippets should contain 'user' (the first word)"
def test_exact_phrase_at_text_boundaries(self):
"""Test snippet generation when exact phrase appears at text boundaries."""
text_start = "user jordan started the meeting. Other content here."
snippets = SnippetGenerator.generate(text_start, "user jordan")
assert len(snippets) == 1
assert "user jordan" in snippets[0].lower()
text_end = "Other content here. The meeting ended with user jordan"
snippets = SnippetGenerator.generate(text_end, "user jordan")
assert len(snippets) == 1
assert "user jordan" in snippets[0].lower()
def test_multi_word_query_matches_words_appearing_separately_and_together(self):
"""Test that multi-word queries prioritize exact phrase matches over individual word occurrences."""
sample_text = """This is a sample transcript where user Alice is talking.
Later in the conversation, jordan mentions something important.
The user jordan collaboration was successful.
Another user named Bob joins the discussion."""
search_query = "user jordan"
snippets = SnippetGenerator.generate(sample_text, search_query)
assert len(snippets) == 1, (
f"Expected exactly 1 snippet for '{search_query}' when exact phrase exists, "
f"got {len(snippets)}. Should ignore individual word occurrences."
)
snippet = snippets[0]
assert (
search_query in snippet.lower()
), f"Snippet should contain the exact phrase '{search_query}'. Got: {snippet}"
assert (
"jordan mentions" in snippet.lower()
), f"Snippet should include context before the exact phrase match. Got: {snippet}"
assert (
"alice" not in snippet.lower()
), f"Snippet should not include separate occurrences of individual words. Got: {snippet}"
text_2 = """The alpha version was released.
Beta testing started yesterday.
The alpha beta integration is complete."""
snippets_2 = SnippetGenerator.generate(text_2, "alpha beta")
assert len(snippets_2) == 1, "Should return 1 snippet for exact phrase match"
assert "alpha beta" in snippets_2[0].lower(), "Should contain exact phrase"
assert (
"version" not in snippets_2[0].lower()
), "Should not include first separate occurrence"
class TestSnippetGenerationEnhanced:
"""Additional snippet generation tests from test_search_enhancements.py."""
def test_snippet_generation_from_webvtt(self):
"""Test snippet generation from WebVTT content."""
webvtt_content = """WEBVTT
00:00:00.000 --> 00:00:05.000
This is the beginning of the transcript
00:00:05.000 --> 00:00:10.000
The search term appears here in the middle
00:00:10.000 --> 00:00:15.000
And this is the end of the content"""
plain_text = WebVTTProcessor.extract_text(webvtt_content)
snippets = SnippetGenerator.generate(plain_text, "search term")
assert len(snippets) > 0
assert any("search term" in snippet.lower() for snippet in snippets)
def test_extract_webvtt_text_with_malformed_variations(self):
"""Test WebVTT extraction with various malformed content."""
malformed_vtt = "This is not valid WebVTT content"
result = WebVTTProcessor.extract_text(malformed_vtt)
assert result == ""
partial_vtt = "WEBVTT\nNo timestamps here"
result = WebVTTProcessor.extract_text(partial_vtt)
assert result == "" or "No timestamps" not in result
class TestPureFunctions:
"""Test the pure functions extracted for functional programming."""
def test_find_all_matches(self):
"""Test finding all match positions in text."""
text = "Python is great. Python is powerful. I love Python."
matches = list(SnippetGenerator.find_all_matches(text, "Python"))
assert matches == [0, 17, 44]
matches = list(SnippetGenerator.find_all_matches(text, "python"))
assert matches == [0, 17, 44]
matches = list(SnippetGenerator.find_all_matches(text, "Ruby"))
assert matches == []
matches = list(SnippetGenerator.find_all_matches("", "test"))
assert matches == []
matches = list(SnippetGenerator.find_all_matches("test", ""))
assert matches == []
def test_create_snippet(self):
"""Test creating a snippet from a match position."""
text = "This is a long text with the word Python in the middle and more text after."
snippet = SnippetGenerator.create_snippet(text, 35, max_length=150)
assert "Python" in snippet.text()
assert snippet.start >= 0
assert snippet.end <= len(text)
assert isinstance(snippet, SnippetCandidate)
assert len(snippet.text()) > 0
assert snippet.start <= snippet.end
long_text = "A" * 200
snippet = SnippetGenerator.create_snippet(long_text, 100, max_length=50)
assert snippet.text().startswith("...")
assert snippet.text().endswith("...")
snippet = SnippetGenerator.create_snippet("short text", 0, max_length=100)
assert snippet.start == 0
assert "short text" in snippet.text()
def test_filter_non_overlapping(self):
"""Test filtering overlapping snippets."""
candidates = [
SnippetCandidate(_text="First snippet", start=0, _original_text_length=100),
SnippetCandidate(_text="Overlapping", start=10, _original_text_length=100),
SnippetCandidate(
_text="Third snippet", start=40, _original_text_length=100
),
SnippetCandidate(
_text="Fourth snippet", start=65, _original_text_length=100
),
]
filtered = list(SnippetGenerator.filter_non_overlapping(iter(candidates)))
assert filtered == [
"First snippet...",
"...Third snippet...",
"...Fourth snippet...",
]
filtered = list(SnippetGenerator.filter_non_overlapping(iter([])))
assert filtered == []
def test_generate_integration(self):
"""Test the main SnippetGenerator.generate function."""
text = "Machine learning is amazing. Machine learning transforms data. Learn machine learning today."
snippets = SnippetGenerator.generate(text, "machine learning")
assert len(snippets) <= 3
assert all("machine learning" in s.lower() for s in snippets)
snippets = SnippetGenerator.generate(text, "machine learning", max_snippets=2)
assert len(snippets) <= 2
snippets = SnippetGenerator.generate(text, "machine vision")
assert len(snippets) > 0
assert any("machine" in s.lower() for s in snippets)
def test_extract_webvtt_text_basic(self):
"""Test WebVTT text extraction (basic test, full tests exist elsewhere)."""
webvtt = """WEBVTT
00:00:00.000 --> 00:00:02.000
Hello world
00:00:02.000 --> 00:00:04.000
This is a test"""
result = WebVTTProcessor.extract_text(webvtt)
assert "Hello world" in result
assert "This is a test" in result
# Test empty input
assert WebVTTProcessor.extract_text("") == ""
assert WebVTTProcessor.extract_text(None) == ""
def test_generate_webvtt_snippets(self):
"""Test generating snippets from WebVTT content."""
webvtt = """WEBVTT
00:00:00.000 --> 00:00:02.000
Python programming is great
00:00:02.000 --> 00:00:04.000
Learn Python today"""
snippets = WebVTTProcessor.generate_snippets(webvtt, "Python")
assert len(snippets) > 0
assert any("Python" in s for s in snippets)
snippets = WebVTTProcessor.generate_snippets("", "Python")
assert snippets == []
def test_from_summary(self):
"""Test generating snippets from summary text."""
summary = "This meeting discussed Python development and machine learning applications."
snippets = SnippetGenerator.from_summary(summary, "Python")
assert len(snippets) > 0
assert any("Python" in s for s in snippets)
long_summary = "Python " * 20
snippets = SnippetGenerator.from_summary(long_summary, "Python")
assert len(snippets) <= 2
def test_combine_sources(self):
"""Test combining snippets from multiple sources."""
summary = "Python is a great programming language."
webvtt = """WEBVTT
00:00:00.000 --> 00:00:02.000
Learn Python programming
00:00:02.000 --> 00:00:04.000
Python is powerful"""
snippets, total_count = SnippetGenerator.combine_sources(
summary, webvtt, "Python", max_total=3
)
assert len(snippets) <= 3
assert len(snippets) > 0
assert total_count > 0
snippets, total_count = SnippetGenerator.combine_sources(
summary, None, "Python", max_total=3
)
assert len(snippets) > 0
assert all("Python" in s for s in snippets)
assert total_count == 1
snippets, total_count = SnippetGenerator.combine_sources(
None, webvtt, "Python", max_total=3
)
assert len(snippets) > 0
assert total_count == 2
long_summary = "Python " * 10
snippets, total_count = SnippetGenerator.combine_sources(
long_summary, webvtt, "Python", max_total=2
)
assert len(snippets) == 2
assert total_count >= 10
def test_match_counting_sum_logic(self):
"""Test that match counting correctly sums matches from both sources."""
summary = "data science uses data analysis and data mining techniques"
webvtt = """WEBVTT
00:00:00.000 --> 00:00:02.000
Big data processing
00:00:02.000 --> 00:00:04.000
data visualization and data storage"""
snippets, total_count = SnippetGenerator.combine_sources(
summary, webvtt, "data", max_total=3
)
assert total_count == 6
assert len(snippets) <= 3
summary_snippets, summary_count = SnippetGenerator.combine_sources(
summary, None, "data", max_total=3
)
assert summary_count == 3
webvtt_snippets, webvtt_count = SnippetGenerator.combine_sources(
None, webvtt, "data", max_total=3
)
assert webvtt_count == 3
snippets_empty, count_empty = SnippetGenerator.combine_sources(
None, None, "data", max_total=3
)
assert snippets_empty == []
assert count_empty == 0
def test_edge_cases(self):
"""Test edge cases for the pure functions."""
text = "Test with special: @#$%^&*() characters"
snippets = SnippetGenerator.generate(text, "@#$%")
assert len(snippets) > 0
long_query = "a" * 100
snippets = SnippetGenerator.generate("Some text", long_query)
assert snippets == []
text = "Unicode test: café, naïve, 日本語"
snippets = SnippetGenerator.generate(text, "café")
assert len(snippets) > 0
assert "café" in snippets[0]

873
server/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,26 +1,67 @@
import React from "react";
import React, { useEffect } from "react";
import { Pagination, IconButton, ButtonGroup } from "@chakra-ui/react";
import { LuChevronLeft, LuChevronRight } from "react-icons/lu";
// explicitly 1-based to prevent +/-1-confusion errors
export const FIRST_PAGE = 1 as PaginationPage;
export const parsePaginationPage = (
page: number,
):
| {
value: PaginationPage;
}
| {
error: string;
} => {
if (page < FIRST_PAGE)
return {
error: "Page must be greater than 0",
};
if (!Number.isInteger(page))
return {
error: "Page must be an integer",
};
return {
value: page as PaginationPage,
};
};
export type PaginationPage = number & { __brand: "PaginationPage" };
export const PaginationPage = (page: number): PaginationPage => {
const v = parsePaginationPage(page);
if ("error" in v) throw new Error(v.error);
return v.value;
};
export const paginationPageTo0Based = (page: PaginationPage): number =>
page - FIRST_PAGE;
type PaginationProps = {
page: number;
setPage: (page: number) => void;
page: PaginationPage;
setPage: (page: PaginationPage) => void;
total: number;
size: number;
};
export const totalPages = (total: number, size: number) => {
return Math.ceil(total / size);
};
export default function PaginationComponent(props: PaginationProps) {
const { page, setPage, total, size } = props;
const totalPages = Math.ceil(total / size);
if (totalPages <= 1) return null;
useEffect(() => {
if (page > totalPages(total, size)) {
console.error(
`Page number (${page}) is greater than total pages (${totalPages}) in pagination`,
);
}
}, [page, totalPages(total, size)]);
return (
<Pagination.Root
count={total}
pageSize={size}
page={page}
onPageChange={(details) => setPage(details.page)}
onPageChange={(details) => setPage(PaginationPage(details.page))}
style={{ display: "flex", justifyContent: "center" }}
>
<ButtonGroup variant="ghost" size="xs">

View File

@@ -1,34 +0,0 @@
import React, { useState } from "react";
import { Flex, Input, Button } from "@chakra-ui/react";
interface SearchBarProps {
onSearch: (searchTerm: string) => void;
}
export default function SearchBar({ onSearch }: SearchBarProps) {
const [searchInputValue, setSearchInputValue] = useState("");
const handleSearch = () => {
onSearch(searchInputValue);
};
const handleKeyDown = (event: React.KeyboardEvent) => {
if (event.key === "Enter") {
handleSearch();
}
};
return (
<Flex alignItems="center">
<Input
placeholder="Search transcriptions..."
value={searchInputValue}
onChange={(e) => setSearchInputValue(e.target.value)}
onKeyDown={handleKeyDown}
/>
<Button ml={2} onClick={handleSearch}>
Search
</Button>
</Flex>
);
}

View File

@@ -4,8 +4,8 @@ import { LuMenu, LuTrash, LuRotateCw } from "react-icons/lu";
interface TranscriptActionsMenuProps {
transcriptId: string;
onDelete: (transcriptId: string) => (e: any) => void;
onReprocess: (transcriptId: string) => (e: any) => void;
onDelete: (transcriptId: string) => void;
onReprocess: (transcriptId: string) => void;
}
export default function TranscriptActionsMenu({
@@ -24,11 +24,17 @@ export default function TranscriptActionsMenu({
<Menu.Content>
<Menu.Item
value="reprocess"
onClick={(e) => onReprocess(transcriptId)(e)}
onClick={() => onReprocess(transcriptId)}
>
<LuRotateCw /> Reprocess
</Menu.Item>
<Menu.Item value="delete" onClick={(e) => onDelete(transcriptId)(e)}>
<Menu.Item
value="delete"
onClick={(e) => {
e.stopPropagation();
onDelete(transcriptId);
}}
>
<LuTrash /> Delete
</Menu.Item>
</Menu.Content>

View File

@@ -1,27 +1,290 @@
import React from "react";
import { Box, Stack, Text, Flex, Link, Spinner } from "@chakra-ui/react";
import React, { useState } from "react";
import {
Box,
Stack,
Text,
Flex,
Link,
Spinner,
Badge,
HStack,
VStack,
} from "@chakra-ui/react";
import NextLink from "next/link";
import { GetTranscriptMinimal } from "../../../api";
import { formatTimeMs, formatLocalDate } from "../../../lib/time";
import TranscriptStatusIcon from "./TranscriptStatusIcon";
import TranscriptActionsMenu from "./TranscriptActionsMenu";
import {
highlightMatches,
generateTextFragment,
} from "../../../lib/textHighlight";
import { SearchResult } from "../../../api";
interface TranscriptCardsProps {
transcripts: GetTranscriptMinimal[];
onDelete: (transcriptId: string) => (e: any) => void;
onReprocess: (transcriptId: string) => (e: any) => void;
loading?: boolean;
results: SearchResult[];
query: string;
isLoading?: boolean;
onDelete: (transcriptId: string) => void;
onReprocess: (transcriptId: string) => void;
}
function highlightText(text: string, query: string): React.ReactNode {
if (!query) return text;
const matches = highlightMatches(text, query);
if (matches.length === 0) return text;
// Sort matches by index to process them in order
const sortedMatches = [...matches].sort((a, b) => a.index - b.index);
const parts: React.ReactNode[] = [];
let lastIndex = 0;
sortedMatches.forEach((match, i) => {
// Add text before the match
if (match.index > lastIndex) {
parts.push(
<Text as="span" key={`text-${i}`} display="inline">
{text.slice(lastIndex, match.index)}
</Text>,
);
}
// Add the highlighted match
parts.push(
<Text
as="mark"
key={`match-${i}`}
bg="yellow.200"
px={0.5}
display="inline"
>
{match.match}
</Text>,
);
lastIndex = match.index + match.match.length;
});
// Add remaining text after last match
if (lastIndex < text.length) {
parts.push(
<Text as="span" key={`text-end`} display="inline">
{text.slice(lastIndex)}
</Text>,
);
}
return parts;
}
const transcriptHref = (
transcriptId: string,
mainSnippet: string,
query: string,
): `/transcripts/${string}` => {
const urlTextFragment = mainSnippet
? generateTextFragment(mainSnippet, query)
: null;
const urlTextFragmentWithHash = urlTextFragment
? `#${urlTextFragment.k}=${encodeURIComponent(urlTextFragment.v)}`
: "";
return `/transcripts/${transcriptId}${urlTextFragmentWithHash}`;
};
// note that it's strongly tied to search logic - in case you want to use it independently, refactor
function TranscriptCard({
result,
query,
onDelete,
onReprocess,
}: {
result: SearchResult;
query: string;
onDelete: (transcriptId: string) => void;
onReprocess: (transcriptId: string) => void;
}) {
const [isExpanded, setIsExpanded] = useState(false);
const mainSnippet = result.search_snippets[0];
const additionalSnippets = result.search_snippets.slice(1);
const totalMatches = result.total_match_count || 0;
const snippetsShown = result.search_snippets.length;
const remainingMatches = totalMatches - snippetsShown;
const hasAdditionalSnippets = additionalSnippets.length > 0;
const resultTitle = result.title || "Unnamed Transcript";
const formattedDuration = result.duration
? formatTimeMs(result.duration)
: "N/A";
const formattedDate = formatLocalDate(result.created_at);
const source =
result.source_kind === "room"
? result.room_name || result.room_id
: result.source_kind;
const handleExpandClick = (e: React.MouseEvent) => {
e.preventDefault();
e.stopPropagation();
setIsExpanded(!isExpanded);
};
return (
<Box borderWidth={1} p={4} borderRadius="md" fontSize="sm">
<Flex justify="space-between" alignItems="flex-start" gap="2">
<Box>
<TranscriptStatusIcon status={result.status} />
</Box>
<Box flex="1">
{/* Title with highlighting and text fragment for deep linking */}
<Link
as={NextLink}
href={transcriptHref(result.id, mainSnippet, query)}
fontWeight="600"
display="block"
mb={2}
>
{highlightText(resultTitle, query)}
</Link>
{/* Metadata - Horizontal on desktop, vertical on mobile */}
<Flex
direction={{ base: "column", md: "row" }}
gap={{ base: 1, md: 2 }}
fontSize="xs"
color="gray.600"
flexWrap="wrap"
align={{ base: "flex-start", md: "center" }}
>
<Flex align="center" gap={1}>
<Text fontWeight="medium" color="gray.500">
Source:
</Text>
<Text>{source}</Text>
</Flex>
<Text display={{ base: "none", md: "block" }} color="gray.400">
</Text>
<Flex align="center" gap={1}>
<Text fontWeight="medium" color="gray.500">
Date:
</Text>
<Text>{formattedDate}</Text>
</Flex>
<Text display={{ base: "none", md: "block" }} color="gray.400">
</Text>
<Flex align="center" gap={1}>
<Text fontWeight="medium" color="gray.500">
Duration:
</Text>
<Text>{formattedDuration}</Text>
</Flex>
</Flex>
{/* Search Results Section - only show when searching */}
{mainSnippet && (
<>
{/* Main Snippet */}
<Box
mt={3}
p={2}
bg="gray.50"
borderLeft="2px solid"
borderLeftColor="blue.400"
borderRadius="sm"
fontSize="xs"
>
<Text color="gray.700">
{highlightText(mainSnippet, query)}
</Text>
</Box>
{hasAdditionalSnippets && (
<>
<Flex
mt={2}
p={2}
bg="blue.50"
borderRadius="sm"
cursor="pointer"
onClick={handleExpandClick}
_hover={{ bg: "blue.100" }}
align="center"
justify="space-between"
>
<HStack gap={2}>
<Badge
bg="blue.500"
color="white"
fontSize="xs"
px={2}
borderRadius="full"
>
{remainingMatches > 0
? `${additionalSnippets.length + remainingMatches}+`
: additionalSnippets.length}
</Badge>
<Text fontSize="xs" color="blue.600" fontWeight="medium">
more{" "}
{additionalSnippets.length + remainingMatches === 1
? "match"
: "matches"}
{remainingMatches > 0 &&
` (${additionalSnippets.length} shown)`}
</Text>
</HStack>
<Text fontSize="xs" color="blue.600">
{isExpanded ? "▲" : "▼"}
</Text>
</Flex>
{/* Additional Snippets */}
{isExpanded && (
<VStack align="stretch" gap={2} mt={2}>
{additionalSnippets.map((snippet, index) => (
<Box
key={index}
p={2}
bg="gray.50"
borderLeft="2px solid"
borderLeftColor="gray.300"
borderRadius="sm"
fontSize="xs"
>
<Text color="gray.700">
{highlightText(snippet, query)}
</Text>
</Box>
))}
</VStack>
)}
</>
)}
</>
)}
</Box>
<TranscriptActionsMenu
transcriptId={result.id}
onDelete={onDelete}
onReprocess={onReprocess}
/>
</Flex>
</Box>
);
}
export default function TranscriptCards({
transcripts,
results,
query,
isLoading,
onDelete,
onReprocess,
loading,
}: TranscriptCardsProps) {
return (
<Box display={{ base: "block", lg: "none" }} position="relative">
{loading && (
<Box position="relative">
{isLoading && (
<Flex
position="absolute"
top={0}
@@ -37,48 +300,19 @@ export default function TranscriptCards({
</Flex>
)}
<Box
opacity={loading ? 0.9 : 1}
pointerEvents={loading ? "none" : "auto"}
opacity={isLoading ? 0.9 : 1}
pointerEvents={isLoading ? "none" : "auto"}
transition="opacity 0.2s ease-in-out"
>
<Stack gap={2}>
{transcripts.map((item) => (
<Box
key={item.id}
borderWidth={1}
p={4}
borderRadius="md"
fontSize="sm"
>
<Flex justify="space-between" alignItems="flex-start" gap="2">
<Box>
<TranscriptStatusIcon status={item.status} />
</Box>
<Box flex="1">
<Link
as={NextLink}
href={`/transcripts/${item.id}`}
fontWeight="600"
display="block"
>
{item.title || "Unnamed Transcript"}
</Link>
<Text>
Source:{" "}
{item.source_kind === "room"
? item.room_name
: item.source_kind}
</Text>
<Text>Date: {formatLocalDate(item.created_at)}</Text>
<Text>Duration: {formatTimeMs(item.duration)}</Text>
</Box>
<TranscriptActionsMenu
transcriptId={item.id}
onDelete={onDelete}
onReprocess={onReprocess}
/>
</Flex>
</Box>
<Stack gap={3}>
{results.map((result) => (
<TranscriptCard
key={result.id}
result={result}
query={query}
onDelete={onDelete}
onReprocess={onReprocess}
/>
))}
</Stack>
</Box>

View File

@@ -1,99 +0,0 @@
import React from "react";
import { Box, Table, Link, Flex, Spinner } from "@chakra-ui/react";
import NextLink from "next/link";
import { GetTranscriptMinimal } from "../../../api";
import { formatTimeMs, formatLocalDate } from "../../../lib/time";
import TranscriptStatusIcon from "./TranscriptStatusIcon";
import TranscriptActionsMenu from "./TranscriptActionsMenu";
interface TranscriptTableProps {
transcripts: GetTranscriptMinimal[];
onDelete: (transcriptId: string) => (e: any) => void;
onReprocess: (transcriptId: string) => (e: any) => void;
loading?: boolean;
}
export default function TranscriptTable({
transcripts,
onDelete,
onReprocess,
loading,
}: TranscriptTableProps) {
return (
<Box display={{ base: "none", lg: "block" }} position="relative">
{loading && (
<Flex
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
align="center"
justify="center"
>
<Spinner size="xl" color="gray.700" />
</Flex>
)}
<Box
opacity={loading ? 0.9 : 1}
pointerEvents={loading ? "none" : "auto"}
transition="opacity 0.2s ease-in-out"
>
<Table.Root>
<Table.Header>
<Table.Row>
<Table.ColumnHeader
width="16px"
fontWeight="600"
></Table.ColumnHeader>
<Table.ColumnHeader width="400px" fontWeight="600">
Transcription Title
</Table.ColumnHeader>
<Table.ColumnHeader width="150px" fontWeight="600">
Source
</Table.ColumnHeader>
<Table.ColumnHeader width="200px" fontWeight="600">
Date
</Table.ColumnHeader>
<Table.ColumnHeader width="100px" fontWeight="600">
Duration
</Table.ColumnHeader>
<Table.ColumnHeader
width="50px"
fontWeight="600"
></Table.ColumnHeader>
</Table.Row>
</Table.Header>
<Table.Body>
{transcripts.map((item) => (
<Table.Row key={item.id}>
<Table.Cell>
<TranscriptStatusIcon status={item.status} />
</Table.Cell>
<Table.Cell>
<Link as={NextLink} href={`/transcripts/${item.id}`}>
{item.title || "Unnamed Transcript"}
</Link>
</Table.Cell>
<Table.Cell>
{item.source_kind === "room"
? item.room_name
: item.source_kind}
</Table.Cell>
<Table.Cell>{formatLocalDate(item.created_at)}</Table.Cell>
<Table.Cell>{formatTimeMs(item.duration)}</Table.Cell>
<Table.Cell>
<TranscriptActionsMenu
transcriptId={item.id}
onDelete={onDelete}
onReprocess={onReprocess}
/>
</Table.Cell>
</Table.Row>
))}
</Table.Body>
</Table.Root>
</Box>
</Box>
);
}

View File

@@ -1,33 +1,264 @@
"use client";
import React, { useState, useEffect } from "react";
import { Flex, Spinner, Heading, Text, Link } from "@chakra-ui/react";
import useTranscriptList from "../transcripts/useTranscriptList";
import {
Flex,
Spinner,
Heading,
Text,
Link,
Box,
Stack,
Input,
Button,
IconButton,
} from "@chakra-ui/react";
import {
useQueryState,
parseAsString,
parseAsInteger,
parseAsStringLiteral,
} from "nuqs";
import { LuX } from "react-icons/lu";
import { useSearchTranscripts } from "../transcripts/useSearchTranscripts";
import useSessionUser from "../../lib/useSessionUser";
import { Room } from "../../api";
import Pagination from "./_components/Pagination";
import { Room, SourceKind, SearchResult, $SourceKind } from "../../api";
import useApi from "../../lib/useApi";
import { useError } from "../../(errors)/errorContext";
import { SourceKind } from "../../api";
import FilterSidebar from "./_components/FilterSidebar";
import SearchBar from "./_components/SearchBar";
import TranscriptTable from "./_components/TranscriptTable";
import Pagination, {
FIRST_PAGE,
PaginationPage,
parsePaginationPage,
totalPages as getTotalPages,
} from "./_components/Pagination";
import TranscriptCards from "./_components/TranscriptCards";
import DeleteTranscriptDialog from "./_components/DeleteTranscriptDialog";
import { formatLocalDate } from "../../lib/time";
import { RECORD_A_MEETING_URL } from "../../api/urls";
const SEARCH_FORM_QUERY_INPUT_NAME = "query" as const;
const usePrefetchRooms = (setRooms: (rooms: Room[]) => void): void => {
const { setError } = useError();
const api = useApi();
useEffect(() => {
if (!api) return;
api
.v1RoomsList({ page: 1 })
.then((rooms) => setRooms(rooms.items))
.catch((err) => setError(err, "There was an error fetching the rooms"));
}, [api, setError]);
};
const SearchForm: React.FC<{
setPage: (page: PaginationPage) => void;
sourceKind: SourceKind | null;
roomId: string | null;
setSourceKind: (sourceKind: SourceKind | null) => void;
setRoomId: (roomId: string | null) => void;
rooms: Room[];
searchQuery: string | null;
setSearchQuery: (query: string | null) => void;
}> = ({
setPage,
sourceKind,
roomId,
setRoomId,
setSourceKind,
rooms,
searchQuery,
setSearchQuery,
}) => {
// to keep the search input controllable + more fine grained control (urlSearchQuery is updated on submits)
const [searchInputValue, setSearchInputValue] = useState(searchQuery || "");
const handleSearchQuerySubmit = async (d: FormData) => {
await setSearchQuery((d.get(SEARCH_FORM_QUERY_INPUT_NAME) as string) || "");
};
const handleClearSearch = () => {
setSearchInputValue("");
setSearchQuery(null);
setPage(FIRST_PAGE);
};
return (
<Stack gap={2}>
<form action={handleSearchQuerySubmit}>
<Flex alignItems="center">
<Box position="relative" flex="1">
<Input
placeholder="Search transcriptions..."
value={searchInputValue}
onChange={(e) => setSearchInputValue(e.target.value)}
name={SEARCH_FORM_QUERY_INPUT_NAME}
pr={searchQuery ? "2.5rem" : undefined}
/>
{searchQuery && (
<IconButton
aria-label="Clear search"
size="sm"
variant="ghost"
onClick={handleClearSearch}
position="absolute"
right="0.25rem"
top="50%"
transform="translateY(-50%)"
_hover={{ bg: "gray.100" }}
>
<LuX />
</IconButton>
)}
</Box>
<Button ml={2} type="submit">
Search
</Button>
</Flex>
</form>
<UnderSearchFormFilterIndicators
sourceKind={sourceKind}
roomId={roomId}
setSourceKind={setSourceKind}
setRoomId={setRoomId}
rooms={rooms}
/>
</Stack>
);
};
const UnderSearchFormFilterIndicators: React.FC<{
sourceKind: SourceKind | null;
roomId: string | null;
setSourceKind: (sourceKind: SourceKind | null) => void;
setRoomId: (roomId: string | null) => void;
rooms: Room[];
}> = ({ sourceKind, roomId, setRoomId, setSourceKind, rooms }) => {
return (
<>
{(sourceKind || roomId) && (
<Flex gap={2} flexWrap="wrap" align="center">
<Text fontSize="sm" color="gray.600">
Active filters:
</Text>
{sourceKind && (
<Flex
align="center"
px={2}
py={1}
bg="blue.100"
borderRadius="md"
fontSize="xs"
gap={1}
>
<Text>
{roomId
? `Room: ${
rooms.find((r) => r.id === roomId)?.name || roomId
}`
: `Source: ${sourceKind}`}
</Text>
<Button
size="xs"
variant="ghost"
minW="auto"
h="auto"
p="1px"
onClick={() => {
setSourceKind(null);
// TODO questionable
setRoomId(null);
}}
_hover={{ bg: "blue.200" }}
aria-label="Clear filter"
>
<LuX size={14} />
</Button>
</Flex>
)}
</Flex>
)}
</>
);
};
const EmptyResult: React.FC<{
searchQuery: string;
}> = ({ searchQuery }) => {
return (
<Flex flexDir="column" alignItems="center" justifyContent="center" py={8}>
<Text textAlign="center">
{searchQuery
? `No results found for "${searchQuery}". Try adjusting your search terms.`
: "No transcripts found, but you can "}
{!searchQuery && (
<>
<Link href={RECORD_A_MEETING_URL} color="blue.500">
record a meeting
</Link>
{" to get started."}
</>
)}
</Text>
</Flex>
);
};
export default function TranscriptBrowser() {
const [selectedSourceKind, setSelectedSourceKind] =
useState<SourceKind | null>(null);
const [selectedRoomId, setSelectedRoomId] = useState("");
const [rooms, setRooms] = useState<Room[]>([]);
const [page, setPage] = useState(1);
const [searchTerm, setSearchTerm] = useState("");
const { loading, response, refetch } = useTranscriptList(
page,
selectedSourceKind,
selectedRoomId,
searchTerm,
const [urlSearchQuery, setUrlSearchQuery] = useQueryState(
"q",
parseAsString.withDefault("").withOptions({ shallow: false }),
);
const [urlSourceKind, setUrlSourceKind] = useQueryState(
"source",
parseAsStringLiteral($SourceKind.enum).withOptions({
shallow: false,
}),
);
const [urlRoomId, setUrlRoomId] = useQueryState(
"room",
parseAsString.withDefault("").withOptions({ shallow: false }),
);
const [urlPage, setPage] = useQueryState(
"page",
parseAsInteger.withDefault(1).withOptions({ shallow: false }),
);
const [page, _setSafePage] = useState(FIRST_PAGE);
// safety net
useEffect(() => {
const maybePage = parsePaginationPage(urlPage);
if ("error" in maybePage) {
setPage(FIRST_PAGE).then(() => {
/*may be called n times we dont care*/
});
return;
}
_setSafePage(maybePage.value);
}, [urlPage]);
const [rooms, setRooms] = useState<Room[]>([]);
const pageSize = 20;
const {
results,
totalCount: totalResults,
isLoading,
reload,
} = useSearchTranscripts(
urlSearchQuery,
{
roomIds: urlRoomId ? [urlRoomId] : null,
sourceKind: urlSourceKind,
},
{
pageSize,
page,
},
);
const totalPages = getTotalPages(totalResults, pageSize);
const userName = useSessionUser().name;
const [deletionLoading, setDeletionLoading] = useState(false);
const api = useApi();
@@ -35,37 +266,73 @@ export default function TranscriptBrowser() {
const cancelRef = React.useRef(null);
const [transcriptToDeleteId, setTranscriptToDeleteId] =
React.useState<string>();
const [deletedItemIds, setDeletedItemIds] = React.useState<string[]>();
useEffect(() => {
setDeletedItemIds([]);
}, [page, response]);
useEffect(() => {
if (!api) return;
api
.v1RoomsList({ page: 1 })
.then((rooms) => setRooms(rooms.items))
.catch((err) => setError(err, "There was an error fetching the rooms"));
}, [api]);
usePrefetchRooms(setRooms);
const handleFilterTranscripts = (
sourceKind: SourceKind | null,
roomId: string,
) => {
setSelectedSourceKind(sourceKind);
setSelectedRoomId(roomId);
setUrlSourceKind(sourceKind);
setUrlRoomId(roomId);
setPage(1);
};
const handleSearch = (searchTerm: string) => {
setPage(1);
setSearchTerm(searchTerm);
setSelectedSourceKind(null);
setSelectedRoomId("");
const onCloseDeletion = () => setTranscriptToDeleteId(undefined);
const confirmDeleteTranscript = (transcriptId: string) => {
if (!api || deletionLoading) return;
setDeletionLoading(true);
api
.v1TranscriptDelete({ transcriptId })
.then(() => {
setDeletionLoading(false);
onCloseDeletion();
reload();
})
.catch((err) => {
setDeletionLoading(false);
setError(err, "There was an error deleting the transcript");
});
};
if (loading && !response)
const handleProcessTranscript = (transcriptId: string) => {
if (!api) {
console.error("API not available on handleProcessTranscript");
return;
}
api
.v1TranscriptProcess({ transcriptId })
.then((result) => {
const status =
result && typeof result === "object" && "status" in result
? (result as { status: string }).status
: undefined;
if (status === "already running") {
setError(
new Error("Processing is already running, please wait"),
"Processing is already running, please wait",
);
}
})
.catch((err) => {
setError(err, "There was an error processing the transcript");
});
};
const transcriptToDelete = results?.find(
(i) => i.id === transcriptToDeleteId,
);
const dialogTitle = transcriptToDelete?.title || "Unnamed Transcript";
const dialogDate = transcriptToDelete?.created_at
? formatLocalDate(transcriptToDelete.created_at)
: undefined;
const dialogSource =
transcriptToDelete?.source_kind === "room" && transcriptToDelete?.room_id
? transcriptToDelete.room_name || transcriptToDelete.room_id
: transcriptToDelete?.source_kind;
if (isLoading && results.length === 0) {
return (
<Flex
flexDir="column"
@@ -76,82 +343,7 @@ export default function TranscriptBrowser() {
<Spinner size="xl" />
</Flex>
);
if (!loading && !response)
return (
<Flex
flexDir="column"
alignItems="center"
justifyContent="center"
h="100%"
>
<Text>
No transcripts found, but you can&nbsp;
<Link href="/transcripts/new" className="underline">
record a meeting
</Link>
&nbsp;to get started.
</Text>
</Flex>
);
const onCloseDeletion = () => setTranscriptToDeleteId(undefined);
const confirmDeleteTranscript = (transcriptId: string) => {
if (!api || deletionLoading) return;
setDeletionLoading(true);
api
.v1TranscriptDelete({ transcriptId })
.then(() => {
refetch();
setDeletionLoading(false);
onCloseDeletion();
setDeletedItemIds((prev) =>
prev ? [...prev, transcriptId] : [transcriptId],
);
})
.catch((err) => {
setDeletionLoading(false);
setError(err, "There was an error deleting the transcript");
});
};
const handleDeleteTranscript = (transcriptId: string) => (e: any) => {
e?.stopPropagation?.();
setTranscriptToDeleteId(transcriptId);
};
const handleProcessTranscript = (transcriptId) => (e) => {
if (api) {
api
.v1TranscriptProcess({ transcriptId })
.then((result) => {
const status = (result as any).status;
if (status === "already running") {
setError(
new Error("Processing is already running, please wait"),
"Processing is already running, please wait",
);
}
})
.catch((err) => {
setError(err, "There was an error processing the transcript");
});
}
};
const transcriptToDelete = response?.items?.find(
(i) => i.id === transcriptToDeleteId,
);
const dialogTitle = transcriptToDelete?.title || "Unnamed Transcript";
const dialogDate = transcriptToDelete?.created_at
? formatLocalDate(transcriptToDelete.created_at)
: undefined;
const dialogSource = transcriptToDelete
? transcriptToDelete.source_kind === "room"
? transcriptToDelete.room_name || undefined
: transcriptToDelete.source_kind
: undefined;
}
return (
<Flex
@@ -168,15 +360,15 @@ export default function TranscriptBrowser() {
>
<Heading size="lg">
{userName ? `${userName}'s Transcriptions` : "Your Transcriptions"}{" "}
{loading || (deletionLoading && <Spinner size="sm" />)}
{(isLoading || deletionLoading) && <Spinner size="sm" />}
</Heading>
</Flex>
<Flex flexDir={{ base: "column", md: "row" }}>
<FilterSidebar
rooms={rooms}
selectedSourceKind={selectedSourceKind}
selectedRoomId={selectedRoomId}
selectedSourceKind={urlSourceKind}
selectedRoomId={urlRoomId}
onFilterChange={handleFilterTranscripts}
/>
@@ -188,25 +380,37 @@ export default function TranscriptBrowser() {
gap={4}
px={{ base: 0, md: 4 }}
>
<SearchBar onSearch={handleSearch} />
<Pagination
page={page}
<SearchForm
setPage={setPage}
total={response?.total || 0}
size={response?.size || 0}
/>
<TranscriptTable
transcripts={response?.items || []}
onDelete={handleDeleteTranscript}
onReprocess={handleProcessTranscript}
loading={loading}
sourceKind={urlSourceKind}
roomId={urlRoomId}
searchQuery={urlSearchQuery}
setSearchQuery={setUrlSearchQuery}
setSourceKind={setUrlSourceKind}
setRoomId={setUrlRoomId}
rooms={rooms}
/>
{totalPages > 1 ? (
<Pagination
page={page}
setPage={setPage}
total={totalResults}
size={pageSize}
/>
) : null}
<TranscriptCards
transcripts={response?.items || []}
onDelete={handleDeleteTranscript}
results={results}
query={urlSearchQuery}
isLoading={isLoading}
onDelete={setTranscriptToDeleteId}
onReprocess={handleProcessTranscript}
loading={loading}
/>
{!isLoading && results.length === 0 && (
<EmptyResult searchQuery={urlSearchQuery} />
)}
</Flex>
</Flex>

View File

@@ -5,6 +5,7 @@ import Image from "next/image";
import About from "../(aboutAndPrivacy)/about";
import Privacy from "../(aboutAndPrivacy)/privacy";
import UserInfo from "../(auth)/userInfo";
import { RECORD_A_MEETING_URL } from "../api/urls";
export default async function AppLayout({
children,
@@ -53,7 +54,7 @@ export default async function AppLayout({
{/* Text link on the right */}
<Link
as={NextLink}
href="/transcripts/new"
href={RECORD_A_MEETING_URL}
className="font-light px-2"
>
Create

View File

@@ -19,6 +19,7 @@ import useApi from "../../lib/useApi";
import useRoomList from "./useRoomList";
import { ApiError, Room } from "../../api";
import { RoomList } from "./_components/RoomList";
import { PaginationPage } from "../browse/_components/Pagination";
interface SelectOption {
label: string;
@@ -75,8 +76,9 @@ export default function RoomsList() {
const [isEditing, setIsEditing] = useState(false);
const [editRoomId, setEditRoomId] = useState("");
const api = useApi();
// TODO seems to be no setPage calls
const [page, setPage] = useState<number>(1);
const { loading, response, refetch } = useRoomList(page);
const { loading, response, refetch } = useRoomList(PaginationPage(page));
const [streams, setStreams] = useState<Stream[]>([]);
const [topics, setTopics] = useState<Topic[]>([]);
const [nameError, setNameError] = useState("");

View File

@@ -2,6 +2,7 @@ import { useEffect, useState } from "react";
import { useError } from "../../(errors)/errorContext";
import useApi from "../../lib/useApi";
import { Page_Room_ } from "../../api";
import { PaginationPage } from "../browse/_components/Pagination";
type RoomList = {
response: Page_Room_ | null;
@@ -11,7 +12,7 @@ type RoomList = {
};
//always protected
const useRoomList = (page: number): RoomList => {
const useRoomList = (page: PaginationPage): RoomList => {
const [response, setResponse] = useState<Page_Room_ | null>(null);
const [loading, setLoading] = useState<boolean>(true);
const [error, setErrorState] = useState<Error | null>(null);

View File

@@ -1,32 +0,0 @@
[
{
"id": "27c07e49-d7a3-4b86-905c-f1a047366f91",
"title": "Issue one",
"summary": "The team discusses the first issue in the list",
"timestamp": 0.0,
"transcript": "",
"duration": 33,
"segments": [
{
"text": "Let's start with issue one, Alice you've been working on that, can you give an update ?",
"start": 0.0,
"speaker": 0
},
{
"text": "Yes, I've run into an issue with the task system but Bob helped me out and I have a POC ready, should I present it now ?",
"start": 0.38,
"speaker": 1
},
{
"text": "Yeah, I had to modify the task system because it didn't account for incoming blobs",
"start": 4.5,
"speaker": 2
},
{
"text": "Cool, yeah lets see it",
"start": 5.96,
"speaker": 0
}
]
}
]

View File

@@ -11,6 +11,7 @@ import useWebRTC from "./useWebRTC";
import useAudioDevice from "./useAudioDevice";
import { Box, Flex, IconButton, Menu, RadioGroup } from "@chakra-ui/react";
import { LuScreenShare, LuMic, LuPlay, LuCircleStop } from "react-icons/lu";
import { RECORD_A_MEETING_URL } from "../../api/urls";
type RecorderProps = {
transcriptId: string;
@@ -46,7 +47,7 @@ export default function Recorder(props: RecorderProps) {
location.href = "";
break;
case ",":
location.href = "/transcripts/new";
location.href = RECORD_A_MEETING_URL;
break;
case "!":
if (record.isRecording()) return;

View File

@@ -0,0 +1,123 @@
// this hook is not great, we want to substitute it with a proper state management solution that is also not re-invention
import { useEffect, useRef, useState } from "react";
import { SearchResult, SourceKind } from "../../api";
import useApi from "../../lib/useApi";
import {
PaginationPage,
paginationPageTo0Based,
} from "../browse/_components/Pagination";
interface SearchFilters {
roomIds: readonly string[] | null;
sourceKind: SourceKind | null;
}
const EMPTY_SEARCH_FILTERS: SearchFilters = {
roomIds: null,
sourceKind: null,
};
type UseSearchTranscriptsOptions = {
pageSize: number;
page: PaginationPage;
};
interface UseSearchTranscriptsReturn {
results: SearchResult[];
totalCount: number;
isLoading: boolean;
error: unknown;
reload: () => void;
}
function hashEffectFilters(filters: SearchFilters): string {
return JSON.stringify(filters);
}
export function useSearchTranscripts(
query: string = "",
filters: SearchFilters = EMPTY_SEARCH_FILTERS,
options: UseSearchTranscriptsOptions = {
pageSize: 20,
page: PaginationPage(1),
},
): UseSearchTranscriptsReturn {
const { pageSize, page } = options;
const [reloadCount, setReloadCount] = useState(0);
const api = useApi();
const abortControllerRef = useRef<AbortController>();
const [data, setData] = useState<{ results: SearchResult[]; total: number }>({
results: [],
total: 0,
});
const [error, setError] = useState<any>();
const [isLoading, setIsLoading] = useState(false);
const filterHash = hashEffectFilters(filters);
useEffect(() => {
if (!api) {
setData({ results: [], total: 0 });
setError(undefined);
setIsLoading(false);
return;
}
if (abortControllerRef.current) {
abortControllerRef.current.abort();
}
const abortController = new AbortController();
abortControllerRef.current = abortController;
const performSearch = async () => {
setIsLoading(true);
try {
const response = await api.v1TranscriptsSearch({
q: query || "",
limit: pageSize,
offset: paginationPageTo0Based(page) * pageSize,
roomId: filters.roomIds?.[0],
sourceKind: filters.sourceKind || undefined,
});
if (abortController.signal.aborted) return;
setData(response);
setError(undefined);
} catch (err: unknown) {
if ((err as Error).name === "AbortError") {
return;
}
if (abortController.signal.aborted) {
console.error("Aborted search but error", err);
return;
}
setError(err);
} finally {
if (!abortController.signal.aborted) {
setIsLoading(false);
}
}
};
performSearch().then(() => {});
return () => {
abortController.abort();
};
}, [api, query, page, filterHash, pageSize, reloadCount]);
return {
results: data.results,
totalCount: data.total,
isLoading,
error,
reload: () => setReloadCount(reloadCount + 1),
};
}

View File

@@ -1,59 +0,0 @@
import { useEffect, useState } from "react";
import { useError } from "../../(errors)/errorContext";
import useApi from "../../lib/useApi";
import { Page_GetTranscriptMinimal_, SourceKind } from "../../api";
type TranscriptList = {
response: Page_GetTranscriptMinimal_ | null;
loading: boolean;
error: Error | null;
refetch: () => void;
};
const useTranscriptList = (
page: number,
sourceKind: SourceKind | null,
roomId: string | null,
searchTerm: string | null,
): TranscriptList => {
const [response, setResponse] = useState<Page_GetTranscriptMinimal_ | null>(
null,
);
const [loading, setLoading] = useState<boolean>(true);
const [error, setErrorState] = useState<Error | null>(null);
const { setError } = useError();
const api = useApi();
const [refetchCount, setRefetchCount] = useState(0);
const refetch = () => {
setLoading(true);
setRefetchCount(refetchCount + 1);
};
useEffect(() => {
if (!api) return;
setLoading(true);
api
.v1TranscriptsList({
page,
sourceKind,
roomId,
searchTerm,
size: 10,
})
.then((response) => {
setResponse(response);
setLoading(false);
})
.catch((err) => {
setResponse(null);
setLoading(false);
setError(err);
setErrorState(err);
});
}, [api, page, refetchCount, roomId, searchTerm, sourceKind]);
return { response, loading, error, refetch };
};
export default useTranscriptList;

View File

@@ -1002,7 +1002,7 @@ export const $SearchResponse = {
},
query: {
type: "string",
minLength: 1,
minLength: 0,
title: "Query",
description: "Search query text",
},
@@ -1065,6 +1065,20 @@ export const $SearchResult = {
],
title: "Room Id",
},
room_name: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Room Name",
},
source_kind: {
$ref: "#/components/schemas/SourceKind",
},
created_at: {
type: "string",
title: "Created At",
@@ -1101,10 +1115,18 @@ export const $SearchResult = {
title: "Search Snippets",
description: "Text snippets around search matches",
},
total_match_count: {
type: "integer",
minimum: 0,
title: "Total Match Count",
description: "Total number of matches found in the transcript",
default: 0,
},
},
type: "object",
required: [
"id",
"source_kind",
"created_at",
"status",
"rank",

View File

@@ -286,6 +286,7 @@ export class DefaultService {
* @param data.limit Results per page
* @param data.offset Number of results to skip
* @param data.roomId
* @param data.sourceKind
* @returns SearchResponse Successful Response
* @throws ApiError
*/
@@ -300,6 +301,7 @@ export class DefaultService {
limit: data.limit,
offset: data.offset,
room_id: data.roomId,
source_kind: data.sourceKind,
},
errors: {
422: "Validation Error",

View File

@@ -209,6 +209,8 @@ export type SearchResult = {
title?: string | null;
user_id?: string | null;
room_id?: string | null;
room_name?: string | null;
source_kind: SourceKind;
created_at: string;
status: string;
rank: number;
@@ -220,6 +222,10 @@ export type SearchResult = {
* Text snippets around search matches
*/
search_snippets: Array<string>;
/**
* Total number of matches found in the transcript
*/
total_match_count?: number;
};
export type SourceKind = "room" | "live" | "file";
@@ -407,6 +413,7 @@ export type V1TranscriptsSearchData = {
*/
q: string;
roomId?: string | null;
sourceKind?: SourceKind | null;
};
export type V1TranscriptsSearchResponse = SearchResponse;

2
www/app/api/urls.ts Normal file
View File

@@ -0,0 +1,2 @@
// TODO better connection with generated schema; it's duplication
export const RECORD_A_MEETING_URL = "/transcripts/new" as const;

View File

@@ -0,0 +1,62 @@
/**
* Text highlighting and text fragment generation utilities
* Used for search result highlighting and deep linking with Chrome Text Fragments
*/
import React from "react";
export interface HighlightResult {
text: string;
matches: string[];
}
/**
* Escapes special regex characters in a string
*/
function escapeRegex(str: string): string {
return str.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}
export const highlightMatches = (
text: string,
query: string,
): { match: string; index: number }[] => {
if (!query || !text) {
return [];
}
const queryWords = query.trim().split(/\s+/);
const regex = new RegExp(
`(${queryWords.map((word) => escapeRegex(word)).join("|")})`,
"gi",
);
return Array.from(text.matchAll(regex)).map((result) => ({
match: result[0],
index: result.index!,
}));
};
export function findFirstHighlight(text: string, query: string): string | null {
const matches = highlightMatches(text, query);
if (matches.length === 0) {
return null;
}
return matches[0].match;
}
export function generateTextFragment(
text: string,
query: string,
): {
k: ":~:text";
v: string;
} | null {
const firstMatch = findFirstHighlight(text, query);
if (!firstMatch) return null;
return {
k: ":~:text",
v: firstMatch,
};
}

View File

@@ -136,3 +136,10 @@ export function extractDomain(url) {
return null;
}
}
export function assertExists<T>(value: T | null | undefined, err?: string): T {
if (value === null || value === undefined) {
throw new Error(`Assertion failed: ${err ?? "value is null or undefined"}`);
}
return value;
}

View File

@@ -1,6 +1,7 @@
"use client";
import { redirect } from "next/navigation";
import { RECORD_A_MEETING_URL } from "./api/urls";
export default function Index() {
redirect("/transcripts/new");
redirect(RECORD_A_MEETING_URL);
}

View File

@@ -5,14 +5,17 @@ import system from "./styles/theme";
import { WherebyProvider } from "@whereby.com/browser-sdk/react";
import { Toaster } from "./components/ui/toaster";
import { NuqsAdapter } from "nuqs/adapters/next/app";
export function Providers({ children }: { children: React.ReactNode }) {
return (
<ChakraProvider value={system}>
<WherebyProvider>
{children}
<Toaster />
</WherebyProvider>
</ChakraProvider>
<NuqsAdapter>
<ChakraProvider value={system}>
<WherebyProvider>
{children}
<Toaster />
</WherebyProvider>
</ChakraProvider>
</NuqsAdapter>
);
}

View File

@@ -31,6 +31,7 @@
"next": "^14.2.30",
"next-auth": "^4.24.7",
"next-themes": "^0.4.6",
"nuqs": "^2.4.3",
"postcss": "8.4.31",
"prop-types": "^15.8.1",
"react": "^18.2.0",

39
www/pnpm-lock.yaml generated
View File

@@ -67,6 +67,9 @@ importers:
next-themes:
specifier: ^0.4.6
version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
nuqs:
specifier: ^2.4.3
version: 2.4.3(next@14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0))(react@18.3.1)
postcss:
specifier: 8.4.31
version: 8.4.31
@@ -5436,6 +5439,12 @@ packages:
}
engines: { node: ">= 8" }
mitt@3.0.1:
resolution:
{
integrity: sha512-vKivATfr97l2/QBCYAkXYDbrIWPM2IIKEl7YPhjCvKlG3kE2gm+uBo6nEXK3M5/Ffh/FLpKExzOQ3JJoJGFKBw==,
}
mkdirp@0.5.6:
resolution:
{
@@ -5660,6 +5669,27 @@ packages:
}
deprecated: This package is no longer supported.
nuqs@2.4.3:
resolution:
{
integrity: sha512-BgtlYpvRwLYiJuWzxt34q2bXu/AIS66sLU1QePIMr2LWkb+XH0vKXdbLSgn9t6p7QKzwI7f38rX3Wl9llTXQ8Q==,
}
peerDependencies:
"@remix-run/react": ">=2"
next: ">=14.2.0"
react: ">=18.2.0 || ^19.0.0-0"
react-router: ^6 || ^7
react-router-dom: ^6 || ^7
peerDependenciesMeta:
"@remix-run/react":
optional: true
next:
optional: true
react-router:
optional: true
react-router-dom:
optional: true
nypm@0.5.4:
resolution:
{
@@ -11553,6 +11583,8 @@ snapshots:
minipass: 3.3.6
yallist: 4.0.0
mitt@3.0.1: {}
mkdirp@0.5.6:
dependencies:
minimist: 1.2.8
@@ -11674,6 +11706,13 @@ snapshots:
gauge: 3.0.2
set-blocking: 2.0.0
nuqs@2.4.3(next@14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0))(react@18.3.1):
dependencies:
mitt: 3.0.1
react: 18.3.1
optionalDependencies:
next: 14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0)
nypm@0.5.4:
dependencies:
citty: 0.1.6