Mirror of https://github.com/Monadical-SAS/reflector.git
Synced 2025-12-20 20:29:06 +00:00

Compare commits: mathieu/ca ... v0.7.1 (9 commits)

Commits:
- bc5b351d2b
- 07981e8090
- 7e366f6338
- 7592679a35
- af16178f86
- 3ea7f6b7b6
- 009590c080
- fe5d344cff
- 86455ce573
.github/workflows/deploy.yml (vendored, 77 lines changed)
@@ -8,18 +8,30 @@ env:
ECR_REPOSITORY: reflector
jobs:
deploy:
runs-on: ubuntu-latest
build:
strategy:
matrix:
include:
- platform: linux/amd64
runner: linux-amd64
arch: amd64
- platform: linux/arm64
runner: linux-arm64
arch: arm64
runs-on: ${{ matrix.runner }}
permissions:
deployments: write
contents: read
outputs:
registry: ${{ steps.login-ecr.outputs.registry }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@0e613a0980cbf65ed5b322eb7a1e075d28913a83
uses: aws-actions/configure-aws-credentials@v4
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
@@ -27,21 +39,52 @@ jobs:
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@62f4f872db3836360b72999f4b87f1ff13310f3a
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: aws-actions/amazon-ecr-login@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Build and push
id: docker_build
uses: docker/build-push-action@v4
- name: Build and push ${{ matrix.arch }}
uses: docker/build-push-action@v5
with:
context: server
platforms: linux/amd64,linux/arm64
platforms: ${{ matrix.platform }}
push: true
tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest
cache-from: type=gha
cache-to: type=gha,mode=max
tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest-${{ matrix.arch }}
cache-from: type=gha,scope=${{ matrix.arch }}
cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
provenance: false

create-manifest:
runs-on: ubuntu-latest
needs: [build]
permissions:
deployments: write
contents: read
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Login to Amazon ECR
uses: aws-actions/amazon-ecr-login@v2
- name: Create and push multi-arch manifest
run: |
# Get the registry URL (since we can't easily access job outputs in matrix)
ECR_REGISTRY=$(aws ecr describe-registry --query 'registryId' --output text).dkr.ecr.${{ env.AWS_REGION }}.amazonaws.com
docker manifest create \
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest \
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-amd64 \
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-arm64
docker manifest push $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest
echo "✅ Multi-arch manifest pushed: $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest"
.github/workflows/test_server.yml (vendored, 38 lines changed)
@@ -19,29 +19,41 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
working-directory: server
- name: Tests
run: |
cd server
uv run -m pytest -v tests

docker:
runs-on: ubuntu-latest
docker-amd64:
runs-on: linux-amd64
steps:
- uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Build and push
id: docker_build
uses: docker/build-push-action@v4
uses: docker/setup-buildx-action@v3
- name: Build AMD64
uses: docker/build-push-action@v6
with:
context: server
platforms: linux/amd64,linux/arm64
cache-from: type=gha
cache-to: type=gha,mode=max
platforms: linux/amd64
cache-from: type=gha,scope=amd64
cache-to: type=gha,mode=max,scope=amd64
github-token: ${{ secrets.GHA_CACHE_TOKEN }}

docker-arm64:
runs-on: linux-arm64
steps:
- uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build ARM64
uses: docker/build-push-action@v6
with:
context: server
platforms: linux/arm64
cache-from: type=gha,scope=arm64
cache-to: type=gha,mode=max,scope=arm64
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
CHANGELOG.md (24 lines changed)
@@ -1,5 +1,29 @@
# Changelog

## [0.7.1](https://github.com/Monadical-SAS/reflector/compare/v0.7.0...v0.7.1) (2025-08-21)

### Bug Fixes

* webvtt db null expectation mismatch ([#556](https://github.com/Monadical-SAS/reflector/issues/556)) ([e67ad1a](https://github.com/Monadical-SAS/reflector/commit/e67ad1a4a2054467bfeb1e0258fbac5868aaaf21))

## [0.7.0](https://github.com/Monadical-SAS/reflector/compare/v0.6.1...v0.7.0) (2025-08-21)

### Features

* delete recording with transcript ([#547](https://github.com/Monadical-SAS/reflector/issues/547)) ([99cc984](https://github.com/Monadical-SAS/reflector/commit/99cc9840b3f5de01e0adfbfae93234042d706d13))
* pipeline improvement with file processing, parakeet, silero-vad ([#540](https://github.com/Monadical-SAS/reflector/issues/540)) ([bcc29c9](https://github.com/Monadical-SAS/reflector/commit/bcc29c9e0050ae215f89d460e9d645aaf6a5e486))
* postgresql migration and removal of sqlite in pytest ([#546](https://github.com/Monadical-SAS/reflector/issues/546)) ([cd1990f](https://github.com/Monadical-SAS/reflector/commit/cd1990f8f0fe1503ef5069512f33777a73a93d7f))
* search backend ([#537](https://github.com/Monadical-SAS/reflector/issues/537)) ([5f9b892](https://github.com/Monadical-SAS/reflector/commit/5f9b89260c9ef7f3c921319719467df22830453f))
* search frontend ([#551](https://github.com/Monadical-SAS/reflector/issues/551)) ([3657242](https://github.com/Monadical-SAS/reflector/commit/365724271ca6e615e3425125a69ae2b46ce39285))

### Bug Fixes

* evaluation cli event wrap ([#536](https://github.com/Monadical-SAS/reflector/issues/536)) ([941c3db](https://github.com/Monadical-SAS/reflector/commit/941c3db0bdacc7b61fea412f3746cc5a7cb67836))
* use structlog not logging ([#550](https://github.com/Monadical-SAS/reflector/issues/550)) ([27e2f81](https://github.com/Monadical-SAS/reflector/commit/27e2f81fda5232e53edc729d3e99c5ef03adbfe9))

## [0.6.1](https://github.com/Monadical-SAS/reflector/compare/v0.6.0...v0.6.1) (2025-08-06)
server/.gitignore (vendored, 3 lines changed)
@@ -176,7 +176,8 @@ artefacts/
audio_*.wav

# ignore local database
reflector.sqlite3
*.sqlite3
*.db
data/

dump.rdb
@@ -1,7 +1,8 @@
FROM python:3.12-slim

ENV PYTHONUNBUFFERED=1 \
UV_LINK_MODE=copy
UV_LINK_MODE=copy \
UV_NO_CACHE=1

# builder install base dependencies
WORKDIR /tmp
@@ -13,8 +14,8 @@ ENV PATH="/root/.local/bin/:$PATH"
# install application dependencies
RUN mkdir -p /app
WORKDIR /app
COPY pyproject.toml uv.lock /app/
RUN touch README.md && env uv sync --compile-bytecode --locked
COPY pyproject.toml uv.lock README.md /app/
RUN uv sync --compile-bytecode --locked

# pre-download nltk packages
RUN uv run python -c "import nltk; nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
@@ -40,3 +40,5 @@ uv run python -c "from reflector.pipelines.main_live_pipeline import task_pipeli
```bash
uv run python -c "from reflector.pipelines.main_live_pipeline import pipeline_post; pipeline_post(transcript_id='TRANSCRIPT_ID')"
```
@@ -4,7 +4,8 @@ This repository hold an API for the GPU implementation of the Reflector API serv
and use [Modal.com](https://modal.com)

- `reflector_diarizer.py` - Diarization API
- `reflector_transcriber.py` - Transcription API
- `reflector_transcriber.py` - Transcription API (Whisper)
- `reflector_transcriber_parakeet.py` - Transcription API (NVIDIA Parakeet)
- `reflector_translator.py` - Translation API

## Modal.com deployment
@@ -19,6 +20,10 @@ $ modal deploy reflector_transcriber.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run

$ modal deploy reflector_transcriber_parakeet.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-parakeet-web.modal.run

$ modal deploy reflector_llm.py
...
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
@@ -68,6 +73,86 @@ Authorization: bearer <REFLECTOR_APIKEY>

### Transcription

#### Parakeet Transcriber (`reflector_transcriber_parakeet.py`)

NVIDIA Parakeet is a state-of-the-art ASR model optimized for real-time transcription with superior word-level timestamps.

**GPU Configuration:**

- **A10G GPU** - Used for `/v1/audio/transcriptions` endpoint (small files, live transcription)
  - Higher concurrency (max_inputs=10)
  - Optimized for multiple small audio files
  - Supports batch processing for efficiency

- **L40S GPU** - Used for `/v1/audio/transcriptions-from-url` endpoint (large files)
  - Lower concurrency but more powerful processing
  - Optimized for single large audio files
  - VAD-based chunking for long-form audio
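In practice a caller picks between the two endpoints named above by workload. A minimal routing sketch; the 10 MB cutoff and the helper name are illustrative only, nothing here is enforced by the service:

```python
import os

SMALL_FILE_LIMIT = 10 * 1024 * 1024  # illustrative cutoff, not part of the API


def pick_endpoint(path: str, has_public_url: bool) -> str:
    """Send small/live audio to the A10G endpoint, large downloadable files to the L40S one."""
    if has_public_url and os.path.getsize(path) > SMALL_FILE_LIMIT:
        return "/v1/audio/transcriptions-from-url"
    return "/v1/audio/transcriptions"
```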
##### `/v1/audio/transcriptions` - Small file transcription

**request** (multipart/form-data)
- `file` or `files[]` - audio file(s) to transcribe
- `model` - model name (default: `nvidia/parakeet-tdt-0.6b-v2`)
- `language` - language code (default: `en`)
- `batch` - whether to use batch processing for multiple files (default: `true`)

**response**
```json
{
  "text": "transcribed text",
  "words": [
    {"word": "hello", "start": 0.0, "end": 0.5},
    {"word": "world", "start": 0.5, "end": 1.0}
  ],
  "filename": "audio.mp3"
}
```

For multiple files with batch=true:
```json
{
  "results": [
    {
      "filename": "audio1.mp3",
      "text": "transcribed text",
      "words": [...]
    },
    {
      "filename": "audio2.mp3",
      "text": "transcribed text",
      "words": [...]
    }
  ]
}
```
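A minimal client sketch for this endpoint, assuming the web URL printed by `modal deploy` above and a valid `REFLECTOR_APIKEY`; the file names are placeholders, and the multipart field is named `files` to match the handler signature in `reflector_transcriber_parakeet.py`:

```python
import os

import requests

BASE_URL = "https://xxxx--reflector-transcriber-parakeet-web.modal.run"  # placeholder deployment URL
HEADERS = {"Authorization": f"bearer {os.environ['REFLECTOR_APIKEY']}"}

# Batch-transcribe several small files in a single request.
paths = ["audio1.mp3", "audio2.mp3"]
files = [("files", (os.path.basename(p), open(p, "rb"), "audio/mpeg")) for p in paths]

resp = requests.post(
    f"{BASE_URL}/v1/audio/transcriptions",
    files=files,
    data={"language": "en", "batch": "true"},
    headers=HEADERS,
    timeout=600,
)
resp.raise_for_status()
for result in resp.json()["results"]:
    print(result["filename"], result["text"])
```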
##### `/v1/audio/transcriptions-from-url` - Large file transcription

**request** (application/json)
```json
{
  "audio_file_url": "https://example.com/audio.mp3",
  "model": "nvidia/parakeet-tdt-0.6b-v2",
  "language": "en",
  "timestamp_offset": 0.0
}
```

**response**
```json
{
  "text": "transcribed text from large file",
  "words": [
    {"word": "hello", "start": 0.0, "end": 0.5},
    {"word": "world", "start": 0.5, "end": 1.0}
  ]
}
```

**Supported file types:** mp3, mp4, mpeg, mpga, m4a, wav, webm
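A matching sketch for the large-file endpoint, with the same placeholder URL and API key; the audio URL must be publicly downloadable, and `timestamp_offset` simply shifts the returned word timings:

```python
import os

import requests

BASE_URL = "https://xxxx--reflector-transcriber-parakeet-web.modal.run"  # placeholder deployment URL
HEADERS = {"Authorization": f"bearer {os.environ['REFLECTOR_APIKEY']}"}

resp = requests.post(
    f"{BASE_URL}/v1/audio/transcriptions-from-url",
    json={
        "audio_file_url": "https://example.com/audio.mp3",
        "language": "en",
        "timestamp_offset": 120.0,  # returned word times are shifted by two minutes
    },
    headers=HEADERS,
    timeout=900,
)
resp.raise_for_status()
print(resp.json()["text"])
```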
#### Whisper Transcriber (`reflector_transcriber.py`)

`POST /transcribe`

**request** (multipart/form-data)
@@ -4,14 +4,80 @@ Reflector GPU backend - diarizer
"""

import os
import uuid
from typing import Mapping, NewType
from urllib.parse import urlparse

import modal.gpu
from modal import App, Image, Secret, asgi_app, enter, method
from pydantic import BaseModel
import modal

PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
MODEL_DIR = "/root/diarization_models"
app = App(name="reflector-diarizer")
UPLOADS_PATH = "/uploads"
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]

DiarizerUniqFilename = NewType("DiarizerUniqFilename", str)
AudioFileExtension = NewType("AudioFileExtension", str)

app = modal.App(name="reflector-diarizer")

# Volume for temporary file uploads
upload_volume = modal.Volume.from_name("diarizer-uploads", create_if_missing=True)

def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
parsed_url = urlparse(url)
url_path = parsed_url.path
for ext in SUPPORTED_FILE_EXTENSIONS:
if url_path.lower().endswith(f".{ext}"):
return AudioFileExtension(ext)
content_type = headers.get("content-type", "").lower()
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
return AudioFileExtension("mp3")
if "audio/wav" in content_type:
return AudioFileExtension("wav")
if "audio/mp4" in content_type:
return AudioFileExtension("mp4")
raise ValueError(
f"Unsupported audio format for URL: {url}. "
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
)

def download_audio_to_volume(
audio_file_url: str,
) -> tuple[DiarizerUniqFilename, AudioFileExtension]:
import requests
from fastapi import HTTPException
print(f"Checking audio file at: {audio_file_url}")
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(status_code=404, detail="Audio file not found")
print(f"Downloading audio file from: {audio_file_url}")
response = requests.get(audio_file_url, allow_redirects=True)
if response.status_code != 200:
print(f"Download failed with status {response.status_code}: {response.text}")
raise HTTPException(
status_code=response.status_code,
detail=f"Failed to download audio file: {response.status_code}",
)
audio_suffix = detect_audio_format(audio_file_url, response.headers)
unique_filename = DiarizerUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
file_path = f"{UPLOADS_PATH}/{unique_filename}"
print(f"Writing file to: {file_path} (size: {len(response.content)} bytes)")
with open(file_path, "wb") as f:
f.write(response.content)
upload_volume.commit()
print(f"File saved as: {unique_filename}")
return unique_filename, audio_suffix

def migrate_cache_llm():
@@ -39,7 +105,7 @@ def download_pyannote_audio():

diarizer_image = (
Image.debian_slim(python_version="3.10.8")
modal.Image.debian_slim(python_version="3.10.8")
.pip_install(
"pyannote.audio==3.1.0",
"requests",
@@ -55,7 +121,8 @@ diarizer_image = (
"hf-transfer",
)
.run_function(
download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")]
download_pyannote_audio,
secrets=[modal.Secret.from_name("hf_token")],
)
.run_function(migrate_cache_llm)
.env(
@@ -70,53 +137,60 @@ diarizer_image = (

@app.cls(
gpu=modal.gpu.A100(size="40GB"),
gpu="A100",
timeout=60 * 30,
scaledown_window=60,
allow_concurrent_inputs=1,
image=diarizer_image,
volumes={UPLOADS_PATH: upload_volume},
enable_memory_snapshot=True,
experimental_options={"enable_gpu_snapshot": True},
secrets=[
modal.Secret.from_name("hf_token"),
],
)
@modal.concurrent(max_inputs=1)
class Diarizer:
@enter()
@modal.enter(snap=True)
def enter(self):
import torch
from pyannote.audio import Pipeline
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
print(f"Using device: {self.device}")
self.diarization_pipeline = Pipeline.from_pretrained(
PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR
PYANNOTE_MODEL_NAME,
cache_dir=MODEL_DIR,
use_auth_token=os.environ["HF_TOKEN"],
)
self.diarization_pipeline.to(torch.device(self.device))

@method()
def diarize(self, audio_data: str, audio_suffix: str, timestamp: float):
import tempfile
@modal.method()
def diarize(self, filename: str, timestamp: float = 0.0):
import torchaudio
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
fp.write(audio_data)
upload_volume.reload()
print("Diarizing audio")
waveform, sample_rate = torchaudio.load(fp.name)
diarization = self.diarization_pipeline(
{"waveform": waveform, "sample_rate": sample_rate}
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
print(f"Diarizing audio from: {file_path}")
waveform, sample_rate = torchaudio.load(file_path)
diarization = self.diarization_pipeline(
{"waveform": waveform, "sample_rate": sample_rate}
)
words = []
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
words.append(
{
"start": round(timestamp + diarization_segment.start, 3),
"end": round(timestamp + diarization_segment.end, 3),
"speaker": int(speaker[-2:]),
}
)
words = []
for diarization_segment, _, speaker in diarization.itertracks(
yield_label=True
):
words.append(
{
"start": round(timestamp + diarization_segment.start, 3),
"end": round(timestamp + diarization_segment.end, 3),
"speaker": int(speaker[-2:]),
}
)
print("Diarization complete")
return {"diarization": words}
print("Diarization complete")
return {"diarization": words}

# -------------------------------------------------------------------
@@ -127,17 +201,18 @@ class Diarizer:
@app.function(
timeout=60 * 10,
scaledown_window=60 * 3,
allow_concurrent_inputs=40,
secrets=[
Secret.from_name("reflector-gpu"),
modal.Secret.from_name("reflector-gpu"),
],
volumes={UPLOADS_PATH: upload_volume},
image=diarizer_image,
)
@asgi_app()
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
def web():
import requests
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
diarizerstub = Diarizer()
@@ -153,35 +228,26 @@ def web():
headers={"WWW-Authenticate": "Bearer"},
)

def validate_audio_file(audio_file_url: str):
# Check if the audio file exists
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(
status_code=response.status_code,
detail="The audio file does not exist.",
)

class DiarizationResponse(BaseModel):
result: dict

@app.post(
"/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]
)
def diarize(
audio_file_url: str, timestamp: float = 0.0
) -> HTTPException | DiarizationResponse:
# Currently the uploaded files are in mp3 format
audio_suffix = "mp3"
@app.post("/diarize", dependencies=[Depends(apikey_auth)])
def diarize(audio_file_url: str, timestamp: float = 0.0) -> DiarizationResponse:
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
print("Downloading audio file")
response = requests.get(audio_file_url, allow_redirects=True)
print("Audio file downloaded successfully")
func = diarizerstub.diarize.spawn(
audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp
)
result = func.get()
return result
try:
func = diarizerstub.diarize.spawn(
filename=unique_filename, timestamp=timestamp
)
result = func.get()
return result
finally:
try:
file_path = f"{UPLOADS_PATH}/{unique_filename}"
print(f"Deleting file: {file_path}")
os.remove(file_path)
upload_volume.commit()
except Exception as e:
print(f"Error cleaning up {unique_filename}: {e}")

return app
server/gpu/modal_deployments/reflector_transcriber_parakeet.py (new file, 622 lines)
@@ -0,0 +1,622 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Mapping, NewType
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import modal
|
||||
|
||||
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
SAMPLERATE = 16000
|
||||
UPLOADS_PATH = "/uploads"
|
||||
CACHE_PATH = "/cache"
|
||||
VAD_CONFIG = {
|
||||
"max_segment_duration": 30.0,
|
||||
"batch_max_files": 10,
|
||||
"batch_max_duration": 5.0,
|
||||
"min_segment_duration": 0.02,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
|
||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||
|
||||
app = modal.App("reflector-transcriber-parakeet")
|
||||
|
||||
# Volume for caching model weights
|
||||
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
|
||||
# Volume for temporary file uploads
|
||||
upload_volume = modal.Volume.from_name("parakeet-uploads", create_if_missing=True)
|
||||
|
||||
image = (
|
||||
modal.Image.from_registry(
|
||||
"nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", add_python="3.12"
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||
"HF_HOME": "/cache",
|
||||
"DEBIAN_FRONTEND": "noninteractive",
|
||||
"CXX": "g++",
|
||||
"CC": "g++",
|
||||
}
|
||||
)
|
||||
.apt_install("ffmpeg")
|
||||
.pip_install(
|
||||
"hf_transfer==0.1.9",
|
||||
"huggingface_hub[hf-xet]==0.31.2",
|
||||
"nemo_toolkit[asr]==2.3.0",
|
||||
"cuda-python==12.8.0",
|
||||
"fastapi==0.115.12",
|
||||
"numpy<2",
|
||||
"librosa==0.10.1",
|
||||
"requests",
|
||||
"silero-vad==5.1.0",
|
||||
"torch",
|
||||
)
|
||||
.entrypoint([]) # silence chatty logs by container on start
|
||||
)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||
parsed_url = urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return AudioFileExtension(ext)
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return AudioFileExtension("mp3")
|
||||
if "audio/wav" in content_type:
|
||||
return AudioFileExtension("wav")
|
||||
if "audio/mp4" in content_type:
|
||||
return AudioFileExtension("mp4")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported audio format for URL: {url}. "
|
||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(
|
||||
audio_file_url: str,
|
||||
) -> tuple[ParakeetUniqFilename, AudioFileExtension]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = ParakeetUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
|
||||
"""Add 0.5 seconds of silence if audio is less than 500ms.
|
||||
|
||||
This is a workaround for a Parakeet bug where very short audio (<500ms) causes:
|
||||
ValueError: `char_offsets`: [] and `processed_tokens`: [157, 834, 834, 841]
|
||||
have to be of the same length
|
||||
|
||||
See: https://github.com/NVIDIA/NeMo/issues/8451
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
audio_duration = len(audio_array) / sample_rate
|
||||
if audio_duration < 0.5:
|
||||
silence_samples = int(sample_rate * 0.5)
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio_array, silence])
|
||||
return audio_array
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=600,
|
||||
scaledown_window=300,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
enable_memory_snapshot=True,
|
||||
experimental_options={"enable_gpu_snapshot": True},
|
||||
)
|
||||
@modal.concurrent(max_inputs=10)
|
||||
class TranscriberParakeetLive:
|
||||
@modal.enter(snap=True)
|
||||
def enter(self):
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
|
||||
device = next(self.model.parameters()).device
|
||||
print(f"Model is on device: {device}")
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
filename: str,
|
||||
):
|
||||
import librosa
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
padded_audio = pad_audio(audio_array, sample_rate)
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
(output,) = self.model.transcribe([padded_audio], timestamps=True)
|
||||
|
||||
text = output.text.strip()
|
||||
words = [
|
||||
{
|
||||
"word": word_info["word"],
|
||||
"start": round(word_info["start"], 2),
|
||||
"end": round(word_info["end"], 2),
|
||||
}
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
return {"text": text, "words": words}
|
||||
|
||||
@modal.method()
|
||||
def transcribe_batch(
|
||||
self,
|
||||
filenames: list[str],
|
||||
):
|
||||
import librosa
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
results = []
|
||||
audio_arrays = []
|
||||
|
||||
# Load all audio files with padding
|
||||
for filename in filenames:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Batch file not found: {file_path}")
|
||||
|
||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
padded_audio = pad_audio(audio_array, sample_rate)
|
||||
audio_arrays.append(padded_audio)
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
outputs = self.model.transcribe(audio_arrays, timestamps=True)
|
||||
|
||||
# Process results for each file
|
||||
for i, (filename, output) in enumerate(zip(filenames, outputs)):
|
||||
text = output.text.strip()
|
||||
|
||||
words = [
|
||||
{
|
||||
"word": word_info["word"],
|
||||
"start": round(word_info["start"], 2),
|
||||
"end": round(word_info["end"], 2),
|
||||
}
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"text": text,
|
||||
"words": words,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# L40S class for file transcription (bigger files)
|
||||
@app.cls(
|
||||
gpu="L40S",
|
||||
timeout=900,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
enable_memory_snapshot=True,
|
||||
experimental_options={"enable_gpu_snapshot": True},
|
||||
)
|
||||
class TranscriberParakeetFile:
|
||||
@modal.enter(snap=True)
|
||||
def enter(self):
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import torch
|
||||
from silero_vad import load_silero_vad
|
||||
|
||||
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
|
||||
|
||||
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
|
||||
device = next(self.model.parameters()).device
|
||||
print(f"Model is on device: {device}")
|
||||
|
||||
torch.set_num_threads(1)
|
||||
self.vad_model = load_silero_vad(onnx=False)
|
||||
print("Silero VAD initialized")
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
filename: str,
|
||||
timestamp_offset: float = 0.0,
|
||||
):
|
||||
import librosa
|
||||
import numpy as np
|
||||
from silero_vad import VADIterator
|
||||
|
||||
def load_and_convert_audio(file_path):
|
||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
return audio_array
|
||||
|
||||
def vad_segment_generator(audio_array):
|
||||
"""Generate speech segments using VAD with start/end sample indices"""
|
||||
vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
|
||||
window_size = VAD_CONFIG["window_size"]
|
||||
start = None
|
||||
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(
|
||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
||||
)
|
||||
|
||||
speech_dict = vad_iterator(chunk)
|
||||
if not speech_dict:
|
||||
continue
|
||||
|
||||
if "start" in speech_dict:
|
||||
start = speech_dict["start"]
|
||||
continue
|
||||
|
||||
if "end" in speech_dict and start is not None:
|
||||
end = speech_dict["end"]
|
||||
start_time = start / float(SAMPLERATE)
|
||||
end_time = end / float(SAMPLERATE)
|
||||
|
||||
# Extract the actual audio segment
|
||||
audio_segment = audio_array[start:end]
|
||||
|
||||
yield (start_time, end_time, audio_segment)
|
||||
start = None
|
||||
|
||||
vad_iterator.reset_states()
|
||||
|
||||
def vad_segment_filter(segments):
|
||||
"""Filter VAD segments by duration and chunk large segments"""
|
||||
min_dur = VAD_CONFIG["min_segment_duration"]
|
||||
max_dur = VAD_CONFIG["max_segment_duration"]
|
||||
|
||||
for start_time, end_time, audio_segment in segments:
|
||||
segment_duration = end_time - start_time
|
||||
|
||||
# Skip very small segments
|
||||
if segment_duration < min_dur:
|
||||
continue
|
||||
|
||||
# If segment is within max duration, yield as-is
|
||||
if segment_duration <= max_dur:
|
||||
yield (start_time, end_time, audio_segment)
|
||||
continue
|
||||
|
||||
# Chunk large segments into smaller pieces
|
||||
chunk_samples = int(max_dur * SAMPLERATE)
|
||||
current_start = start_time
|
||||
|
||||
for chunk_offset in range(0, len(audio_segment), chunk_samples):
|
||||
chunk_audio = audio_segment[
|
||||
chunk_offset : chunk_offset + chunk_samples
|
||||
]
|
||||
if len(chunk_audio) == 0:
|
||||
break
|
||||
|
||||
chunk_duration = len(chunk_audio) / float(SAMPLERATE)
|
||||
chunk_end = current_start + chunk_duration
|
||||
|
||||
# Only yield chunks that meet minimum duration
|
||||
if chunk_duration >= min_dur:
|
||||
yield (current_start, chunk_end, chunk_audio)
|
||||
|
||||
current_start = chunk_end
|
||||
|
||||
def batch_segments(segments, max_files=10, max_duration=5.0):
|
||||
batch = []
|
||||
batch_duration = 0.0
|
||||
|
||||
for start_time, end_time, audio_segment in segments:
|
||||
segment_duration = end_time - start_time
|
||||
|
||||
if segment_duration < VAD_CONFIG["silence_padding"]:
|
||||
silence_samples = int(
|
||||
(VAD_CONFIG["silence_padding"] - segment_duration) * SAMPLERATE
|
||||
)
|
||||
padding = np.zeros(silence_samples, dtype=np.float32)
|
||||
audio_segment = np.concatenate([audio_segment, padding])
|
||||
segment_duration = VAD_CONFIG["silence_padding"]
|
||||
|
||||
batch.append((start_time, end_time, audio_segment))
|
||||
batch_duration += segment_duration
|
||||
|
||||
if len(batch) >= max_files or batch_duration >= max_duration:
|
||||
yield batch
|
||||
batch = []
|
||||
batch_duration = 0.0
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
def transcribe_batch(model, audio_segments):
|
||||
with NoStdStreams():
|
||||
outputs = model.transcribe(audio_segments, timestamps=True)
|
||||
return outputs
|
||||
|
||||
def emit_results(
|
||||
results,
|
||||
segments_info,
|
||||
batch_index,
|
||||
total_batches,
|
||||
):
|
||||
"""Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
|
||||
for i, (output, (start_time, end_time, _)) in enumerate(
|
||||
zip(results, segments_info)
|
||||
):
|
||||
text = output.text.strip()
|
||||
words = [
|
||||
{
|
||||
"word": word_info["word"],
|
||||
"start": round(
|
||||
word_info["start"] + start_time + timestamp_offset, 2
|
||||
),
|
||||
"end": round(
|
||||
word_info["end"] + start_time + timestamp_offset, 2
|
||||
),
|
||||
}
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
yield text, words
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
audio_array = load_and_convert_audio(file_path)
|
||||
total_duration = len(audio_array) / float(SAMPLERATE)
|
||||
processed_duration = 0.0
|
||||
|
||||
all_text_parts = []
|
||||
all_words = []
|
||||
|
||||
raw_segments = vad_segment_generator(audio_array)
|
||||
filtered_segments = vad_segment_filter(raw_segments)
|
||||
batches = batch_segments(
|
||||
filtered_segments,
|
||||
VAD_CONFIG["batch_max_files"],
|
||||
VAD_CONFIG["batch_max_duration"],
|
||||
)
|
||||
|
||||
batch_index = 0
|
||||
total_batches = max(
|
||||
1, int(total_duration / VAD_CONFIG["batch_max_duration"]) + 1
|
||||
)
|
||||
|
||||
for batch in batches:
|
||||
batch_index += 1
|
||||
audio_segments = [seg[2] for seg in batch]
|
||||
results = transcribe_batch(self.model, audio_segments)
|
||||
|
||||
for text, words in emit_results(
|
||||
results,
|
||||
batch,
|
||||
batch_index,
|
||||
total_batches,
|
||||
):
|
||||
if not text:
|
||||
continue
|
||||
all_text_parts.append(text)
|
||||
all_words.extend(words)
|
||||
|
||||
processed_duration += sum(len(seg[2]) / float(SAMPLERATE) for seg in batch)
|
||||
|
||||
combined_text = " ".join(all_text_parts)
|
||||
return {"text": combined_text, "words": all_words}
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60,
|
||||
timeout=600,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
image=image,
|
||||
)
|
||||
@modal.concurrent(max_inputs=40)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from fastapi import (
|
||||
Body,
|
||||
Depends,
|
||||
FastAPI,
|
||||
Form,
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
transcriber_live = TranscriberParakeetLive()
|
||||
transcriber_file = TranscriberParakeetFile()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class TranscriptResponse(BaseModel):
|
||||
result: dict
|
||||
|
||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe(
|
||||
file: UploadFile = None,
|
||||
files: list[UploadFile] | None = None,
|
||||
model: str = Form(MODEL_NAME),
|
||||
language: str = Form("en"),
|
||||
batch: bool = Form(False),
|
||||
):
|
||||
# Parakeet only supports English
|
||||
if language != "en":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Parakeet model only supports English. Got language='{language}'",
|
||||
)
|
||||
# Handle both single file and multiple files
|
||||
if not file and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
||||
)
|
||||
if batch and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Batch transcription requires 'files'"
|
||||
)
|
||||
|
||||
upload_files = [file] if file else files
|
||||
|
||||
# Upload files to volume
|
||||
uploaded_filenames = []
|
||||
for upload_file in upload_files:
|
||||
audio_suffix = upload_file.filename.split(".")[-1]
|
||||
assert audio_suffix in SUPPORTED_FILE_EXTENSIONS
|
||||
|
||||
# Generate unique filename
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
print(f"Writing file to: {file_path}")
|
||||
with open(file_path, "wb") as f:
|
||||
content = upload_file.file.read()
|
||||
f.write(content)
|
||||
|
||||
uploaded_filenames.append(unique_filename)
|
||||
|
||||
upload_volume.commit()
|
||||
|
||||
try:
|
||||
# Use A10G live transcriber for per-file transcription
|
||||
if batch and len(upload_files) > 1:
|
||||
# Use batch transcription
|
||||
func = transcriber_live.transcribe_batch.spawn(
|
||||
filenames=uploaded_filenames,
|
||||
)
|
||||
results = func.get()
|
||||
return {"results": results}
|
||||
|
||||
# Per-file transcription
|
||||
results = []
|
||||
for filename in uploaded_filenames:
|
||||
func = transcriber_live.transcribe_segment.spawn(
|
||||
filename=filename,
|
||||
)
|
||||
result = func.get()
|
||||
result["filename"] = filename
|
||||
results.append(result)
|
||||
|
||||
return {"results": results} if len(results) > 1 else results[0]
|
||||
|
||||
finally:
|
||||
for filename in uploaded_filenames:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
print(f"Deleting file: {file_path}")
|
||||
os.remove(file_path)
|
||||
except Exception as e:
|
||||
print(f"Error deleting {filename}: {e}")
|
||||
|
||||
upload_volume.commit()
|
||||
|
||||
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe_from_url(
|
||||
audio_file_url: str = Body(
|
||||
..., description="URL of the audio file to transcribe"
|
||||
),
|
||||
model: str = Body(MODEL_NAME),
|
||||
language: str = Body("en", description="Language code (only 'en' supported)"),
|
||||
timestamp_offset: float = Body(0.0),
|
||||
):
|
||||
# Parakeet only supports English
|
||||
if language != "en":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Parakeet model only supports English. Got language='{language}'",
|
||||
)
|
||||
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
|
||||
|
||||
try:
|
||||
func = transcriber_file.transcribe_segment.spawn(
|
||||
filename=unique_filename,
|
||||
timestamp_offset=timestamp_offset,
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
print(f"Deleting file: {file_path}")
|
||||
os.remove(file_path)
|
||||
upload_volume.commit()
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up {unique_filename}: {e}")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class NoStdStreams:
|
||||
def __init__(self):
|
||||
self.devnull = open(os.devnull, "w")
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
||||
self._stdout.flush()
|
||||
self._stderr.flush()
|
||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
||||
self.devnull.close()
|
||||
@@ -0,0 +1,64 @@
"""add_long_summary_to_search_vector

Revision ID: 0ab2d7ffaa16
Revises: b1c33bd09963
Create Date: 2025-08-15 13:27:52.680211

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "0ab2d7ffaa16"
down_revision: Union[str, None] = "b1c33bd09963"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Drop the existing search vector column and index
    op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
    op.drop_column("transcript", "search_vector_en")

    # Recreate the search vector column with long_summary included
    op.execute("""
        ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
        GENERATED ALWAYS AS (
            setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
            setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') ||
            setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')
        ) STORED
    """)

    # Recreate the GIN index for the search vector
    op.create_index(
        "idx_transcript_search_vector_en",
        "transcript",
        ["search_vector_en"],
        postgresql_using="gin",
    )


def downgrade() -> None:
    # Drop the updated search vector column and index
    op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
    op.drop_column("transcript", "search_vector_en")

    # Recreate the original search vector column without long_summary
    op.execute("""
        ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
        GENERATED ALWAYS AS (
            setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
            setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
        ) STORED
    """)

    # Recreate the GIN index for the search vector
    op.create_index(
        "idx_transcript_search_vector_en",
        "transcript",
        ["search_vector_en"],
        postgresql_using="gin",
    )
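For context, this is the kind of query the regenerated column serves; a sketch using plain SQL through SQLAlchemy, with an invented connection URL and search term (the real search backend builds its own query elsewhere in the codebase):

```python
import sqlalchemy as sa

# Hypothetical connection URL; the application reads its own DATABASE_URL setting.
engine = sa.create_engine("postgresql://user:pass@localhost/reflector")

query = sa.text(
    """
    SELECT id, title,
           ts_rank(search_vector_en, plainto_tsquery('english', :q)) AS rank
    FROM transcript
    WHERE search_vector_en @@ plainto_tsquery('english', :q)
    ORDER BY rank DESC
    LIMIT 20
    """
)

with engine.connect() as conn:
    for row in conn.execute(query, {"q": "quarterly planning"}):
        print(row.id, row.title, row.rank)
```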
@@ -0,0 +1,41 @@
"""add_search_optimization_indexes

Revision ID: b1c33bd09963
Revises: 9f5c78d352d6
Create Date: 2025-08-14 17:26:02.117408

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "b1c33bd09963"
down_revision: Union[str, None] = "9f5c78d352d6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add indexes for actual search filtering patterns used in frontend
    # Based on /browse page filters: room_id and source_kind

    # Index for room_id + created_at (for room-specific searches with date ordering)
    op.create_index(
        "idx_transcript_room_id_created_at",
        "transcript",
        ["room_id", "created_at"],
        if_not_exists=True,
    )

    # Index for source_kind alone (actively used filter in frontend)
    op.create_index(
        "idx_transcript_source_kind", "transcript", ["source_kind"], if_not_exists=True
    )


def downgrade() -> None:
    # Remove the indexes in reverse order
    op.drop_index("idx_transcript_source_kind", "transcript", if_exists=True)
    op.drop_index("idx_transcript_room_id_created_at", "transcript", if_exists=True)
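And a sketch of the listing queries these two indexes are meant to back, using a minimal stand-in table definition (the column types and filter values are assumptions for illustration, not the real model):

```python
import sqlalchemy as sa

metadata = sa.MetaData()
# Minimal stand-in for the real transcript table; only the indexed columns are declared.
transcript = sa.Table(
    "transcript",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("room_id", sa.String),
    sa.Column("source_kind", sa.String),
    sa.Column("created_at", sa.DateTime),
)

# Backed by idx_transcript_room_id_created_at: one room's transcripts, newest first.
by_room = (
    sa.select(transcript)
    .where(transcript.c.room_id == "room-123")  # hypothetical room id
    .order_by(transcript.c.created_at.desc())
    .limit(20)
)

# Backed by idx_transcript_source_kind: filter on source kind alone.
by_kind = sa.select(transcript).where(transcript.c.source_kind == "room")  # hypothetical value
```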
@@ -32,7 +32,6 @@ dependencies = [
"redis>=5.0.1",
"python-jose[cryptography]>=3.3.0",
"python-multipart>=0.0.6",
"faster-whisper>=0.10.0",
"transformers>=4.36.2",
"jsonschema>=4.23.0",
"openai>=1.59.7",
@@ -57,6 +56,7 @@ tests = [
"httpx-ws>=0.4.1",
"pytest-httpx>=0.23.1",
"pytest-celery>=0.0.0",
"pytest-recording>=0.13.4",
"pytest-docker>=3.2.3",
"asgi-lifespan>=2.1.0",
]
@@ -67,6 +67,15 @@ evaluation = [
"tqdm>=4.66.0",
"pydantic>=2.1.1",
]
local = [
"pyannote-audio>=3.3.2",
"faster-whisper>=0.10.0",
]
silero-vad = [
"silero-vad>=5.1.2",
"torch>=2.8.0",
"torchaudio>=2.8.0",
]

[tool.uv]
default-groups = [
@@ -74,6 +83,21 @@ default-groups = [
"tests",
"aws",
"evaluation",
"local",
"silero-vad"
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[tool.uv.sources]
torch = [
{ index = "pytorch-cpu" },
]
torchaudio = [
{ index = "pytorch-cpu" },
]

[build-system]
@@ -94,6 +118,9 @@ DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_t
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
testpaths = ["tests"]
asyncio_mode = "auto"
markers = [
"gpu_modal: mark test to run only with GPU Modal endpoints (deselect with '-m \"not gpu_modal\"')",
]

[tool.ruff.lint]
select = [
@@ -1,24 +1,37 @@
|
||||
"""Search functionality for transcripts and other entities."""
|
||||
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from typing import Annotated, Any, Dict
|
||||
from typing import Annotated, Any, Dict, Iterator
|
||||
|
||||
import sqlalchemy
|
||||
import webvtt
|
||||
from pydantic import BaseModel, Field, constr, field_serializer
|
||||
from fastapi import HTTPException
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
ValidationError,
|
||||
constr,
|
||||
field_serializer,
|
||||
)
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.rooms import rooms
|
||||
from reflector.db.transcripts import SourceKind, transcripts
|
||||
from reflector.db.utils import is_postgresql
|
||||
from reflector.logger import logger
|
||||
|
||||
DEFAULT_SEARCH_LIMIT = 20
|
||||
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
|
||||
DEFAULT_SNIPPET_MAX_LENGTH = 150
|
||||
DEFAULT_MAX_SNIPPETS = 3
|
||||
DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
|
||||
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
|
||||
LONG_SUMMARY_MAX_SNIPPETS = 2
|
||||
|
||||
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
|
||||
SearchQueryBase = constr(min_length=0, strip_whitespace=True)
|
||||
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
|
||||
SearchOffsetBase = Annotated[int, Field(ge=0)]
|
||||
SearchTotalBase = Annotated[int, Field(ge=0)]
|
||||
@@ -32,6 +45,82 @@ SearchTotal = Annotated[
|
||||
SearchTotalBase, Field(description="Total number of search results")
|
||||
]
|
||||
|
||||
WEBVTT_SPEC_HEADER = "WEBVTT"
|
||||
|
||||
WebVTTContent = Annotated[
|
||||
str,
|
||||
Field(min_length=len(WEBVTT_SPEC_HEADER), description="WebVTT content"),
|
||||
]
|
||||
|
||||
|
||||
class WebVTTProcessor:
|
||||
"""Stateless processor for WebVTT content operations."""
|
||||
|
||||
@staticmethod
|
||||
def parse(raw_content: str) -> WebVTTContent:
|
||||
"""Parse WebVTT content and return it as a string."""
|
||||
if not raw_content.startswith(WEBVTT_SPEC_HEADER):
|
||||
raise ValueError(f"Invalid WebVTT content, no header {WEBVTT_SPEC_HEADER}")
|
||||
return raw_content
|
||||
|
||||
@staticmethod
|
||||
def extract_text(webvtt_content: WebVTTContent) -> str:
|
||||
"""Extract plain text from WebVTT content using webvtt library."""
|
||||
try:
|
||||
buffer = StringIO(webvtt_content)
|
||||
vtt = webvtt.read_buffer(buffer)
|
||||
return " ".join(caption.text for caption in vtt if caption.text)
|
||||
except webvtt.errors.MalformedFileError as e:
|
||||
logger.warning(f"Malformed WebVTT content: {e}")
|
||||
return ""
|
||||
except (UnicodeDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to decode WebVTT content: {e}")
|
||||
return ""
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"WebVTT parsing error - unexpected format: {e}", exc_info=True
|
||||
)
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error parsing WebVTT: {e}", exc_info=True)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def generate_snippets(
|
||||
webvtt_content: WebVTTContent,
|
||||
query: str,
|
||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from WebVTT content."""
|
||||
return SnippetGenerator.generate(
|
||||
WebVTTProcessor.extract_text(webvtt_content),
|
||||
query,
|
||||
max_snippets=max_snippets,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SnippetCandidate:
|
||||
"""Represents a candidate snippet with its position."""
|
||||
|
||||
_text: str
|
||||
start: NonNegativeInt
|
||||
_original_text_length: int
|
||||
|
||||
@property
|
||||
def end(self) -> NonNegativeInt:
|
||||
"""Calculate end position from start and raw text length."""
|
||||
return self.start + len(self._text)
|
||||
|
||||
def text(self) -> str:
|
||||
"""Get display text with ellipses added if needed."""
|
||||
result = self._text.strip()
|
||||
if self.start > 0:
|
||||
result = "..." + result
|
||||
if self.end < self._original_text_length:
|
||||
result = result + "..."
|
||||
return result
|
||||
|
||||
|
||||
class SearchParameters(BaseModel):
|
||||
"""Validated search parameters for full-text search."""
|
||||
@@ -41,6 +130,7 @@ class SearchParameters(BaseModel):
|
||||
offset: SearchOffset = 0
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
source_kind: SourceKind | None = None
|
||||
|
||||
|
||||
class SearchResultDB(BaseModel):
|
||||
@@ -64,13 +154,18 @@ class SearchResult(BaseModel):
|
||||
title: str | None = None
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
room_name: str | None = None
|
||||
source_kind: SourceKind
|
||||
created_at: datetime
|
||||
status: str = Field(..., min_length=1)
|
||||
rank: float = Field(..., ge=0, le=1)
|
||||
duration: float | None = Field(..., ge=0, description="Duration in seconds")
|
||||
duration: NonNegativeFloat | None = Field(..., description="Duration in seconds")
|
||||
search_snippets: list[str] = Field(
|
||||
description="Text snippets around search matches"
|
||||
)
|
||||
total_match_count: NonNegativeInt = Field(
|
||||
default=0, description="Total number of matches found in the transcript"
|
||||
)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def serialize_datetime(self, dt: datetime) -> str:
|
||||
@@ -79,84 +174,153 @@ class SearchResult(BaseModel):
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
class SearchController:
|
||||
"""Controller for search operations across different entities."""
|
||||
class SnippetGenerator:
|
||||
"""Stateless generator for text snippets and match operations."""
|
||||
|
||||
@staticmethod
|
||||
def _extract_webvtt_text(webvtt_content: str) -> str:
|
||||
"""Extract plain text from WebVTT content using webvtt library."""
|
||||
if not webvtt_content:
|
||||
return ""
|
||||
def find_all_matches(text: str, query: str) -> Iterator[int]:
|
||||
"""Generate all match positions for a query in text."""
|
||||
if not text:
|
||||
logger.warning("Empty text for search query in find_all_matches")
|
||||
return
|
||||
if not query:
|
||||
logger.warning("Empty query for search text in find_all_matches")
|
||||
return
|
||||
|
||||
try:
|
||||
buffer = StringIO(webvtt_content)
|
||||
vtt = webvtt.read_buffer(buffer)
|
||||
return " ".join(caption.text for caption in vtt if caption.text)
|
||||
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
|
||||
return ""
|
||||
except AttributeError as e:
|
||||
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
|
||||
return ""
|
||||
text_lower = text.lower()
|
||||
query_lower = query.lower()
|
||||
start = 0
|
||||
prev_start = start
|
||||
while (pos := text_lower.find(query_lower, start)) != -1:
|
||||
yield pos
|
||||
start = pos + len(query_lower)
|
||||
if start <= prev_start:
|
||||
raise ValueError("panic! find_all_matches is not incremental")
|
||||
prev_start = start
|
||||
|
||||
@staticmethod
|
||||
def _generate_snippets(
|
||||
def count_matches(text: str, query: str) -> NonNegativeInt:
|
||||
"""Count total number of matches for a query in text."""
|
||||
ZERO = NonNegativeInt(0)
|
||||
if not text:
|
||||
logger.warning("Empty text for search query in count_matches")
|
||||
return ZERO
|
||||
if not query:
|
||||
logger.warning("Empty query for search text in count_matches")
|
||||
return ZERO
|
||||
return NonNegativeInt(
|
||||
sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_snippet(
|
||||
text: str, match_pos: int, max_length: int = DEFAULT_SNIPPET_MAX_LENGTH
|
||||
) -> SnippetCandidate:
|
||||
"""Create a snippet from a match position."""
|
||||
snippet_start = NonNegativeInt(max(0, match_pos - SNIPPET_CONTEXT_LENGTH))
|
||||
snippet_end = min(len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH)
|
||||
|
||||
snippet_text = text[snippet_start:snippet_end]
|
||||
|
||||
return SnippetCandidate(
|
||||
_text=snippet_text, start=snippet_start, _original_text_length=len(text)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_non_overlapping(
|
||||
candidates: Iterator[SnippetCandidate],
|
||||
) -> Iterator[str]:
|
||||
"""Filter out overlapping snippets and return only display text."""
|
||||
last_end = 0
|
||||
for candidate in candidates:
|
||||
display_text = candidate.text()
|
||||
# it means that next overlapping snippets simply don't get included
|
||||
# it's fine as simplistic logic and users probably won't care much because they already have their search results just fin
|
||||
if candidate.start >= last_end and display_text:
|
||||
yield display_text
|
||||
last_end = candidate.end
|
||||
|
||||
@staticmethod
|
||||
def generate(
|
||||
text: str,
|
||||
q: SearchQuery,
|
||||
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||
max_snippets: int = DEFAULT_MAX_SNIPPETS,
|
||||
query: str,
|
||||
max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate multiple snippets around all occurrences of search term."""
|
||||
if not text or not q:
|
||||
"""Generate snippets from text."""
|
||||
if not text or not query:
|
||||
logger.warning("Empty text or query for generate_snippets")
|
||||
return []
|
||||
|
||||
snippets = []
|
||||
lower_text = text.lower()
|
||||
search_lower = q.lower()
|
||||
candidates = (
|
||||
SnippetGenerator.create_snippet(text, pos, max_length)
|
||||
for pos in SnippetGenerator.find_all_matches(text, query)
|
||||
)
|
||||
filtered = SnippetGenerator.filter_non_overlapping(candidates)
|
||||
snippets = list(itertools.islice(filtered, max_snippets))
|
||||
|
||||
last_snippet_end = 0
|
||||
start_pos = 0
|
||||
|
||||
while len(snippets) < max_snippets:
|
||||
match_pos = lower_text.find(search_lower, start_pos)
|
||||
|
||||
if match_pos == -1:
|
||||
if not snippets and search_lower.split():
|
||||
first_word = search_lower.split()[0]
|
||||
match_pos = lower_text.find(first_word, start_pos)
|
||||
if match_pos == -1:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
|
||||
snippet_end = min(
|
||||
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
|
||||
)
|
||||
|
||||
if snippet_start < last_snippet_end:
|
||||
start_pos = match_pos + len(search_lower)
|
||||
continue
|
||||
|
||||
snippet = text[snippet_start:snippet_end]
|
||||
|
||||
if snippet_start > 0:
|
||||
snippet = "..." + snippet
|
||||
if snippet_end < len(text):
|
||||
snippet = snippet + "..."
|
||||
|
||||
snippet = snippet.strip()
|
||||
|
||||
if snippet:
|
||||
snippets.append(snippet)
|
||||
last_snippet_end = snippet_end
|
||||
|
||||
start_pos = match_pos + len(search_lower)
|
||||
if start_pos >= len(text):
|
||||
break
|
||||
# Fallback to first word search if no full matches
|
||||
# another simplification: proper snippet generation is complicated and tied to the DB logic, so we fall back to the first word of the query here
|
||||
if not snippets and " " in query:
|
||||
first_word = query.split()[0]
|
||||
return SnippetGenerator.generate(text, first_word, max_length, max_snippets)
|
||||
|
||||
return snippets
|
||||
|
||||
@staticmethod
|
||||
def from_summary(
|
||||
summary: str,
|
||||
query: str,
|
||||
max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from summary text."""
|
||||
return SnippetGenerator.generate(summary, query, max_snippets=max_snippets)
|
||||
|
||||
@staticmethod
|
||||
def combine_sources(
|
||||
summary: str | None,
|
||||
webvtt: WebVTTContent | None,
|
||||
query: str,
|
||||
max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> tuple[list[str], NonNegativeInt]:
|
||||
"""Combine snippets from multiple sources and return total match count.
|
||||
|
||||
Returns (snippets, total_match_count) tuple.
|
||||
|
||||
snippets can legitimately be empty, e.g. when only the title matches
|
||||
"""
|
||||
webvtt_matches = 0
|
||||
summary_matches = 0
|
||||
|
||||
if webvtt:
|
||||
webvtt_text = WebVTTProcessor.extract_text(webvtt)
|
||||
webvtt_matches = SnippetGenerator.count_matches(webvtt_text, query)
|
||||
|
||||
if summary:
|
||||
summary_matches = SnippetGenerator.count_matches(summary, query)
|
||||
|
||||
total_matches = NonNegativeInt(webvtt_matches + summary_matches)
|
||||
|
||||
summary_snippets = (
|
||||
SnippetGenerator.from_summary(summary, query) if summary else []
|
||||
)
|
||||
|
||||
if len(summary_snippets) >= max_total:
|
||||
return summary_snippets[:max_total], total_matches
|
||||
|
||||
remaining = max_total - len(summary_snippets)
|
||||
webvtt_snippets = (
|
||||
WebVTTProcessor.generate_snippets(webvtt, query, remaining)
|
||||
if webvtt
|
||||
else []
|
||||
)
|
||||
|
||||
return summary_snippets + webvtt_snippets, total_matches
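# Minimal usage sketch (illustrative only; the summary text and query are made up, and
# webvtt is passed as None so no WebVTT parsing is involved):
#
#     snippets, total = SnippetGenerator.combine_sources(
#         summary="Weekly sync about the roadmap and the search feature.",
#         webvtt=None,
#         query="search",
#     )
#     # snippets -> at most DEFAULT_MAX_SNIPPETS summary snippets, total -> overall match count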
|
||||
|
||||
|
||||
class SearchController:
|
||||
"""Controller for search operations across different entities."""
|
||||
|
||||
@classmethod
|
||||
async def search_transcripts(
|
||||
cls, params: SearchParameters
|
||||
@@ -172,39 +336,64 @@ class SearchController:
|
||||
)
|
||||
return [], 0
|
||||
|
||||
search_query = sqlalchemy.func.websearch_to_tsquery(
|
||||
"english", params.query_text
|
||||
base_columns = [
|
||||
transcripts.c.id,
|
||||
transcripts.c.title,
|
||||
transcripts.c.created_at,
|
||||
transcripts.c.duration,
|
||||
transcripts.c.status,
|
||||
transcripts.c.user_id,
|
||||
transcripts.c.room_id,
|
||||
transcripts.c.source_kind,
|
||||
transcripts.c.webvtt,
|
||||
transcripts.c.long_summary,
|
||||
sqlalchemy.case(
|
||||
(
|
||||
transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
|
||||
"Deleted Room",
|
||||
),
|
||||
else_=rooms.c.name,
|
||||
).label("room_name"),
|
||||
]
|
||||
|
||||
if params.query_text:
|
||||
search_query = sqlalchemy.func.websearch_to_tsquery(
|
||||
"english", params.query_text
|
||||
)
|
||||
rank_column = sqlalchemy.func.ts_rank(
|
||||
transcripts.c.search_vector_en,
|
||||
search_query,
|
||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||
).label("rank")
|
||||
else:
|
||||
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
||||
|
||||
columns = base_columns + [rank_column]
|
||||
base_query = sqlalchemy.select(columns).select_from(
|
||||
transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
|
||||
)
|
||||
|
||||
base_query = sqlalchemy.select(
|
||||
[
|
||||
transcripts.c.id,
|
||||
transcripts.c.title,
|
||||
transcripts.c.created_at,
|
||||
transcripts.c.duration,
|
||||
transcripts.c.status,
|
||||
transcripts.c.user_id,
|
||||
transcripts.c.room_id,
|
||||
transcripts.c.source_kind,
|
||||
transcripts.c.webvtt,
|
||||
sqlalchemy.func.ts_rank(
|
||||
transcripts.c.search_vector_en,
|
||||
search_query,
|
||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||
).label("rank"),
|
||||
]
|
||||
).where(transcripts.c.search_vector_en.op("@@")(search_query))
|
||||
if params.query_text:
|
||||
base_query = base_query.where(
|
||||
transcripts.c.search_vector_en.op("@@")(search_query)
|
||||
)
|
||||
|
||||
if params.user_id:
|
||||
base_query = base_query.where(transcripts.c.user_id == params.user_id)
|
||||
if params.room_id:
|
||||
base_query = base_query.where(transcripts.c.room_id == params.room_id)
|
||||
if params.source_kind:
|
||||
base_query = base_query.where(
|
||||
transcripts.c.source_kind == params.source_kind
|
||||
)
|
||||
|
||||
if params.query_text:
|
||||
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
|
||||
else:
|
||||
order_by = sqlalchemy.desc(transcripts.c.created_at)
|
||||
|
||||
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
|
||||
|
||||
query = (
|
||||
base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
|
||||
.limit(params.limit)
|
||||
.offset(params.offset)
|
||||
)
|
||||
rs = await get_database().fetch_all(query)
|
||||
|
||||
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
||||
@@ -214,18 +403,40 @@ class SearchController:
|
||||
|
||||
def _process_result(r) -> SearchResult:
|
||||
r_dict: Dict[str, Any] = dict(r)
|
||||
webvtt: str | None = r_dict.pop("webvtt", None)
|
||||
webvtt_raw: str | None = r_dict.pop("webvtt", None)
|
||||
if webvtt_raw:
|
||||
webvtt = WebVTTProcessor.parse(webvtt_raw)
|
||||
else:
|
||||
webvtt = None
|
||||
long_summary: str | None = r_dict.pop("long_summary", None)
|
||||
room_name: str | None = r_dict.pop("room_name", None)
|
||||
db_result = SearchResultDB.model_validate(r_dict)
|
||||
|
||||
snippets = []
|
||||
if webvtt:
|
||||
plain_text = cls._extract_webvtt_text(webvtt)
|
||||
snippets = cls._generate_snippets(plain_text, params.query_text)
|
||||
snippets, total_match_count = SnippetGenerator.combine_sources(
|
||||
long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
|
||||
)
|
||||
|
||||
return SearchResult(**db_result.model_dump(), search_snippets=snippets)
|
||||
return SearchResult(
|
||||
**db_result.model_dump(),
|
||||
room_name=room_name,
|
||||
search_snippets=snippets,
|
||||
total_match_count=total_match_count,
|
||||
)
|
||||
|
||||
try:
|
||||
results = [_process_result(r) for r in rs]
|
||||
except ValidationError as e:
|
||||
logger.error(f"Invalid search result data: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal search result data consistency error"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing search results: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
results = [_process_result(r) for r in rs]
|
||||
return results, total
|
||||
|
||||
|
||||
search_controller = SearchController()
|
||||
webvtt_processor = WebVTTProcessor()
|
||||
snippet_generator = SnippetGenerator()
|
||||
|
||||
@@ -88,6 +88,8 @@ transcripts = sqlalchemy.Table(
|
||||
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
||||
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
||||
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
|
||||
sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
|
||||
sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
|
||||
)
|
||||
|
||||
# Add PostgreSQL-specific full-text search column
|
||||
@@ -99,7 +101,8 @@ if is_postgresql():
|
||||
TSVECTOR,
|
||||
sqlalchemy.Computed(
|
||||
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
|
||||
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
|
||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
|
||||
persisted=True,
|
||||
),
|
||||
)
|
||||
|
||||
375 server/reflector/pipelines/main_file_pipeline.py (new file)
@@ -0,0 +1,375 @@
|
||||
"""
|
||||
File-based processing pipeline
|
||||
==============================
|
||||
|
||||
Optimized pipeline for processing complete audio/video files.
|
||||
Uses parallel processing for transcription, diarization, and waveform generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import structlog
|
||||
from celery import shared_task
|
||||
|
||||
from reflector.db.transcripts import (
|
||||
Transcript,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.main_live_pipeline import PipelineMainBase, asynctask
|
||||
from reflector.processors import (
|
||||
AudioFileWriterProcessor,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
)
|
||||
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
||||
from reflector.processors.file_diarization import FileDiarizationInput
|
||||
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||
from reflector.processors.file_transcript import FileTranscriptInput
|
||||
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||
from reflector.processors.transcript_diarization_assembler import (
|
||||
TranscriptDiarizationAssemblerInput,
|
||||
TranscriptDiarizationAssemblerProcessor,
|
||||
)
|
||||
from reflector.processors.types import (
|
||||
DiarizationSegment,
|
||||
TitleSummary,
|
||||
)
|
||||
from reflector.processors.types import (
|
||||
Transcript as TranscriptType,
|
||||
)
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_transcripts_storage
|
||||
|
||||
|
||||
class EmptyPipeline:
|
||||
"""Empty pipeline for processors that need a pipeline reference"""
|
||||
|
||||
def __init__(self, logger: structlog.BoundLogger):
|
||||
self.logger = logger
|
||||
|
||||
def get_pref(self, k, d=None):
|
||||
return d
|
||||
|
||||
async def emit(self, event):
|
||||
pass
|
||||
|
||||
|
||||
class PipelineMainFile(PipelineMainBase):
|
||||
"""
|
||||
Optimized file processing pipeline.
|
||||
Processes complete audio/video files with parallel execution.
|
||||
"""
|
||||
|
||||
logger: structlog.BoundLogger = None
|
||||
empty_pipeline = None
|
||||
|
||||
def __init__(self, transcript_id: str):
|
||||
super().__init__(transcript_id=transcript_id)
|
||||
self.logger = logger.bind(transcript_id=self.transcript_id)
|
||||
self.empty_pipeline = EmptyPipeline(logger=self.logger)
|
||||
|
||||
def _handle_gather_exceptions(self, results: list, operation: str) -> None:
|
||||
"""Handle exceptions from asyncio.gather with return_exceptions=True"""
|
||||
for i, result in enumerate(results):
|
||||
if not isinstance(result, Exception):
|
||||
continue
|
||||
self.logger.error(
|
||||
f"Error in {operation} (task {i}): {result}",
|
||||
transcript_id=self.transcript_id,
|
||||
exc_info=result,
|
||||
)
|
||||
|
||||
async def process(self, file_path: Path):
|
||||
"""Main entry point for file processing"""
|
||||
self.logger.info(f"Starting file pipeline for {file_path}")
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
# Extract audio and write to transcript location
|
||||
audio_path = await self.extract_and_write_audio(file_path, transcript)
|
||||
|
||||
# Upload for processing
|
||||
audio_url = await self.upload_audio(audio_path, transcript)
|
||||
|
||||
# Run parallel processing
|
||||
await self.run_parallel_processing(
|
||||
audio_path,
|
||||
audio_url,
|
||||
transcript.source_language,
|
||||
transcript.target_language,
|
||||
)
|
||||
|
||||
self.logger.info("File pipeline complete")
|
||||
|
||||
async def extract_and_write_audio(
|
||||
self, file_path: Path, transcript: Transcript
|
||||
) -> Path:
|
||||
"""Extract audio from video if needed and write to transcript location as MP3"""
|
||||
self.logger.info(f"Processing audio file: {file_path}")
|
||||
|
||||
# Check if it's already audio-only
|
||||
container = av.open(str(file_path))
|
||||
has_video = len(container.streams.video) > 0
|
||||
container.close()
|
||||
|
||||
# Use AudioFileWriterProcessor to write MP3 to transcript location
|
||||
mp3_writer = AudioFileWriterProcessor(
|
||||
path=transcript.audio_mp3_filename,
|
||||
on_duration=self.on_duration,
|
||||
)
|
||||
|
||||
# Process audio frames and write to transcript location
|
||||
input_container = av.open(str(file_path))
|
||||
for frame in input_container.decode(audio=0):
|
||||
await mp3_writer.push(frame)
|
||||
|
||||
await mp3_writer.flush()
|
||||
input_container.close()
|
||||
|
||||
if has_video:
|
||||
self.logger.info(
|
||||
f"Extracted audio from video and saved to {transcript.audio_mp3_filename}"
|
||||
)
|
||||
else:
|
||||
self.logger.info(
|
||||
f"Converted audio file and saved to {transcript.audio_mp3_filename}"
|
||||
)
|
||||
|
||||
return transcript.audio_mp3_filename
|
||||
|
||||
async def upload_audio(self, audio_path: Path, transcript: Transcript) -> str:
|
||||
"""Upload audio to storage for processing"""
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
if not storage:
|
||||
raise Exception(
|
||||
"Storage backend required for file processing. Configure TRANSCRIPT_STORAGE_* settings."
|
||||
)
|
||||
|
||||
self.logger.info("Uploading audio to storage")
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
storage_path = f"file_pipeline/{transcript.id}/audio.mp3"
|
||||
await storage.put_file(storage_path, audio_data)
|
||||
|
||||
audio_url = await storage.get_file_url(storage_path)
|
||||
|
||||
self.logger.info(f"Audio uploaded to {audio_url}")
|
||||
return audio_url
|
||||
|
||||
async def run_parallel_processing(
|
||||
self,
|
||||
audio_path: Path,
|
||||
audio_url: str,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
):
|
||||
"""Coordinate parallel processing of transcription, diarization, and waveform"""
|
||||
self.logger.info(
|
||||
"Starting parallel processing", transcript_id=self.transcript_id
|
||||
)
|
||||
|
||||
# Phase 1: Parallel processing of independent tasks
|
||||
transcription_task = self.transcribe_file(audio_url, source_language)
|
||||
diarization_task = self.diarize_file(audio_url)
|
||||
waveform_task = self.generate_waveform(audio_path)
|
||||
|
||||
results = await asyncio.gather(
|
||||
transcription_task, diarization_task, waveform_task, return_exceptions=True
|
||||
)
|
||||
|
||||
transcript_result = results[0]
|
||||
diarization_result = results[1]
|
||||
|
||||
# Handle errors - raise any exception that occurred
|
||||
self._handle_gather_exceptions(results, "parallel processing")
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
|
||||
# Phase 2: Assemble transcript with diarization
|
||||
self.logger.info(
|
||||
"Assembling transcript with diarization", transcript_id=self.transcript_id
|
||||
)
|
||||
processor = TranscriptDiarizationAssemblerProcessor()
|
||||
input_data = TranscriptDiarizationAssemblerInput(
|
||||
transcript=transcript_result, diarization=diarization_result or []
|
||||
)
|
||||
|
||||
# Store result for retrieval
|
||||
diarized_transcript: Transcript | None = None
|
||||
|
||||
async def capture_result(transcript):
|
||||
nonlocal diarized_transcript
|
||||
diarized_transcript = transcript
|
||||
|
||||
processor.on(capture_result)
|
||||
await processor.push(input_data)
|
||||
await processor.flush()
|
||||
|
||||
if not diarized_transcript:
|
||||
raise ValueError("No diarized transcript captured")
|
||||
|
||||
# Phase 3: Generate topics from diarized transcript
|
||||
self.logger.info("Generating topics", transcript_id=self.transcript_id)
|
||||
topics = await self.detect_topics(diarized_transcript, target_language)
|
||||
|
||||
# Phase 4: Generate title and summaries in parallel
|
||||
self.logger.info(
|
||||
"Generating title and summaries", transcript_id=self.transcript_id
|
||||
)
|
||||
results = await asyncio.gather(
|
||||
self.generate_title(topics),
|
||||
self.generate_summaries(topics),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
self._handle_gather_exceptions(results, "title and summary generation")
|
||||
|
||||
async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
|
||||
"""Transcribe complete file"""
|
||||
processor = FileTranscriptAutoProcessor()
|
||||
input_data = FileTranscriptInput(audio_url=audio_url, language=language)
|
||||
|
||||
# Store result for retrieval
|
||||
result: TranscriptType | None = None
|
||||
|
||||
async def capture_result(transcript):
|
||||
nonlocal result
|
||||
result = transcript
|
||||
|
||||
processor.on(capture_result)
|
||||
await processor.push(input_data)
|
||||
await processor.flush()
|
||||
|
||||
if not result:
|
||||
raise ValueError("No transcript captured")
|
||||
|
||||
return result
|
||||
|
||||
async def diarize_file(self, audio_url: str) -> list[DiarizationSegment] | None:
|
||||
"""Get diarization for file"""
|
||||
if not settings.DIARIZATION_BACKEND:
|
||||
self.logger.info("Diarization disabled")
|
||||
return None
|
||||
|
||||
processor = FileDiarizationAutoProcessor()
|
||||
input_data = FileDiarizationInput(audio_url=audio_url)
|
||||
|
||||
# Store result for retrieval
|
||||
result = None
|
||||
|
||||
async def capture_result(diarization_output):
|
||||
nonlocal result
|
||||
result = diarization_output.diarization
|
||||
|
||||
try:
|
||||
processor.on(capture_result)
|
||||
await processor.push(input_data)
|
||||
await processor.flush()
|
||||
return result
|
||||
except Exception as e:
|
||||
self.logger.error(f"Diarization failed: {e}")
|
||||
return None
|
||||
|
||||
async def generate_waveform(self, audio_path: Path):
|
||||
"""Generate and save waveform"""
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
processor = AudioWaveformProcessor(
|
||||
audio_path=audio_path,
|
||||
waveform_path=transcript.audio_waveform_filename,
|
||||
on_waveform=self.on_waveform,
|
||||
)
|
||||
processor.set_pipeline(self.empty_pipeline)
|
||||
|
||||
await processor.flush()
|
||||
|
||||
async def detect_topics(
|
||||
self, transcript: TranscriptType, target_language: str
|
||||
) -> list[TitleSummary]:
|
||||
"""Detect topics from complete transcript"""
|
||||
chunk_size = 300
|
||||
topics: list[TitleSummary] = []
|
||||
|
||||
async def on_topic(topic: TitleSummary):
|
||||
topics.append(topic)
|
||||
return await self.on_topic(topic)
|
||||
|
||||
topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
|
||||
topic_detector.set_pipeline(self.empty_pipeline)
|
||||
|
||||
for i in range(0, len(transcript.words), chunk_size):
|
||||
chunk_words = transcript.words[i : i + chunk_size]
|
||||
if not chunk_words:
|
||||
continue
|
||||
|
||||
chunk_transcript = TranscriptType(
|
||||
words=chunk_words, translation=transcript.translation
|
||||
)
|
||||
|
||||
await topic_detector.push(chunk_transcript)
|
||||
|
||||
await topic_detector.flush()
|
||||
return topics
|
||||
|
||||
async def generate_title(self, topics: list[TitleSummary]):
|
||||
"""Generate title from topics"""
|
||||
if not topics:
|
||||
self.logger.warning("No topics for title generation")
|
||||
return
|
||||
|
||||
processor = TranscriptFinalTitleProcessor(callback=self.on_title)
|
||||
processor.set_pipeline(self.empty_pipeline)
|
||||
|
||||
for topic in topics:
|
||||
await processor.push(topic)
|
||||
|
||||
await processor.flush()
|
||||
|
||||
async def generate_summaries(self, topics: list[TitleSummary]):
|
||||
"""Generate long and short summaries from topics"""
|
||||
if not topics:
|
||||
self.logger.warning("No topics for summary generation")
|
||||
return
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
processor = TranscriptFinalSummaryProcessor(
|
||||
transcript=transcript,
|
||||
callback=self.on_long_summary,
|
||||
on_short_summary=self.on_short_summary,
|
||||
)
|
||||
processor.set_pipeline(self.empty_pipeline)
|
||||
|
||||
for topic in topics:
|
||||
await processor.push(topic)
|
||||
|
||||
await processor.flush()
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_file_process(*, transcript_id: str):
|
||||
"""Celery task for file pipeline processing"""
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
|
||||
# Find the file to process
|
||||
audio_file = next(transcript.data_path.glob("upload.*"), None)
|
||||
if not audio_file:
|
||||
audio_file = next(transcript.data_path.glob("audio.*"), None)
|
||||
|
||||
if not audio_file:
|
||||
raise Exception("No audio file found to process")
|
||||
|
||||
# Run file pipeline
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
await pipeline.process(audio_file)
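# Usage sketch (illustrative, assuming the project's standard Celery setup): the task
# above would normally be queued from the API layer rather than called directly, e.g.
#
#     task_pipeline_file_process.delay(transcript_id=transcript.id)
#
# .delay() is the regular Celery shortcut for apply_async with keyword arguments.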
|
||||
@@ -147,15 +147,18 @@ class StrValue(BaseModel):
|
||||
|
||||
|
||||
class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
|
||||
transcript_id: str
|
||||
ws_room_id: str | None = None
|
||||
ws_manager: WebsocketManager | None = None
|
||||
|
||||
def prepare(self):
|
||||
# prepare websocket
|
||||
def __init__(self, transcript_id: str):
|
||||
super().__init__()
|
||||
self._lock = asyncio.Lock()
|
||||
self.transcript_id = transcript_id
|
||||
self.ws_room_id = f"ts:{self.transcript_id}"
|
||||
self.ws_manager = get_ws_manager()
|
||||
self._ws_manager = None
|
||||
|
||||
@property
|
||||
def ws_manager(self) -> WebsocketManager:
|
||||
if self._ws_manager is None:
|
||||
self._ws_manager = get_ws_manager()
|
||||
return self._ws_manager
|
||||
|
||||
async def get_transcript(self) -> Transcript:
|
||||
# fetch the transcript
|
||||
@@ -355,7 +358,6 @@ class PipelineMainLive(PipelineMainBase):
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
processors = [
|
||||
@@ -376,6 +378,7 @@ class PipelineMainLive(PipelineMainBase):
|
||||
pipeline.set_pref("audio:target_language", transcript.target_language)
|
||||
pipeline.logger.bind(transcript_id=transcript.id)
|
||||
pipeline.logger.info("Pipeline main live created")
|
||||
pipeline.describe()
|
||||
|
||||
return pipeline
|
||||
|
||||
@@ -394,7 +397,6 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
pipeline = Pipeline(
|
||||
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||
)
|
||||
@@ -435,8 +437,6 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
|
||||
raise NotImplementedError
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
self.prepare()
|
||||
|
||||
# get transcript
|
||||
self._transcript = transcript = await self.get_transcript()
|
||||
|
||||
|
||||
@@ -18,22 +18,14 @@ During its lifecycle, it will emit the following status:
|
||||
import asyncio
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import Pipeline
|
||||
|
||||
PipelineMessage = TypeVar("PipelineMessage")
|
||||
|
||||
|
||||
class PipelineRunner(BaseModel, Generic[PipelineMessage]):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
status: str = "idle"
|
||||
pipeline: Pipeline | None = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
class PipelineRunner(Generic[PipelineMessage]):
|
||||
def __init__(self):
|
||||
self._task = None
|
||||
self._q_cmd = asyncio.Queue(maxsize=4096)
|
||||
self._ev_done = asyncio.Event()
|
||||
@@ -42,6 +34,8 @@ class PipelineRunner(BaseModel, Generic[PipelineMessage]):
|
||||
runner=id(self),
|
||||
runner_cls=self.__class__.__name__,
|
||||
)
|
||||
self.status = "idle"
|
||||
self.pipeline: Pipeline | None = None
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
"""
|
||||
|
||||
@@ -11,6 +11,13 @@ from .base import ( # noqa: F401
|
||||
Processor,
|
||||
ThreadedProcessor,
|
||||
)
|
||||
from .file_diarization import FileDiarizationProcessor # noqa: F401
|
||||
from .file_diarization_auto import FileDiarizationAutoProcessor # noqa: F401
|
||||
from .file_transcript import FileTranscriptProcessor # noqa: F401
|
||||
from .file_transcript_auto import FileTranscriptAutoProcessor # noqa: F401
|
||||
from .transcript_diarization_assembler import (
|
||||
TranscriptDiarizationAssemblerProcessor, # noqa: F401
|
||||
)
|
||||
from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401
|
||||
from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401
|
||||
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
|
||||
|
||||
@@ -1,28 +1,340 @@
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from silero_vad import VADIterator, load_silero_vad
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
|
||||
|
||||
class AudioChunkerProcessor(Processor):
|
||||
"""
|
||||
Assemble audio frames into chunks
|
||||
Assemble audio frames into chunks with VAD-based speech detection
|
||||
"""
|
||||
|
||||
INPUT_TYPE = av.AudioFrame
|
||||
OUTPUT_TYPE = list[av.AudioFrame]
|
||||
|
||||
def __init__(self, max_frames=256):
|
||||
def __init__(
|
||||
self,
|
||||
block_frames=256,
|
||||
max_frames=1024,
|
||||
vad_threshold=0.5,
|
||||
use_onnx=False,
|
||||
min_frames=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.frames: list[av.AudioFrame] = []
|
||||
self.block_frames = block_frames
|
||||
self.max_frames = max_frames
|
||||
self.vad_threshold = vad_threshold
|
||||
self.min_frames = min_frames
|
||||
|
||||
# Initialize Silero VAD
|
||||
self._init_vad(use_onnx)
|
||||
|
||||
def _init_vad(self, use_onnx=False):
|
||||
"""Initialize Silero VAD model"""
|
||||
try:
|
||||
torch.set_num_threads(1)
|
||||
self.vad_model = load_silero_vad(onnx=use_onnx)
|
||||
self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
|
||||
self.logger.info("Silero VAD initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize Silero VAD: {e}")
|
||||
self.vad_model = None
|
||||
self.vad_iterator = None
|
||||
|
||||
async def _push(self, data: av.AudioFrame):
|
||||
self.frames.append(data)
|
||||
if len(self.frames) >= self.max_frames:
|
||||
await self.flush()
|
||||
# print("timestamp", data.pts * data.time_base * 1000)
|
||||
|
||||
# Check for speech segments every 32 frames (~1 second)
|
||||
if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
|
||||
await self._process_block()
|
||||
|
||||
# Safety fallback - emit if we hit max frames
|
||||
elif len(self.frames) >= self.max_frames:
|
||||
self.logger.warning(
|
||||
f"AudioChunkerProcessor: Reached max frames ({self.max_frames}), "
|
||||
f"emitting first {self.max_frames // 2} frames"
|
||||
)
|
||||
frames_to_emit = self.frames[: self.max_frames // 2]
|
||||
self.frames = self.frames[self.max_frames // 2 :]
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
await self.emit(frames_to_emit)
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring fallback segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
async def _process_block(self):
|
||||
# Need at least 32 frames for VAD detection (~1 second)
|
||||
if len(self.frames) < 32 or self.vad_iterator is None:
|
||||
return
|
||||
|
||||
# Processing block with current buffer size
|
||||
# print(f"Processing block: {len(self.frames)} frames in buffer")
|
||||
|
||||
try:
|
||||
# Convert frames to numpy array for VAD
|
||||
audio_array = self._frames_to_numpy(self.frames)
|
||||
|
||||
if audio_array is None:
|
||||
# Fallback: emit all frames if conversion failed
|
||||
frames_to_emit = self.frames[:]
|
||||
self.frames = []
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
await self.emit(frames_to_emit)
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
return
|
||||
|
||||
# Find complete speech segments in the buffer
|
||||
speech_end_frame = self._find_speech_segment_end(audio_array)
|
||||
|
||||
if speech_end_frame is None or speech_end_frame <= 0:
|
||||
# No speech found but buffer is getting large
|
||||
if len(self.frames) > 512:
|
||||
# Check if it's all silence and can be discarded
|
||||
# No speech segment found, buffer at {len(self.frames)} frames
|
||||
|
||||
# Could emit silence or discard old frames here
|
||||
# For now, keep first 256 frames and discard older silence
|
||||
if len(self.frames) > 768:
|
||||
self.logger.debug(
|
||||
f"Discarding {len(self.frames) - 256} old frames (likely silence)"
|
||||
)
|
||||
self.frames = self.frames[-256:]
|
||||
return
|
||||
|
||||
# Calculate segment timing information
|
||||
frames_to_emit = self.frames[:speech_end_frame]
|
||||
|
||||
# Get timing from av.AudioFrame
|
||||
if frames_to_emit:
|
||||
first_frame = frames_to_emit[0]
|
||||
last_frame = frames_to_emit[-1]
|
||||
sample_rate = first_frame.sample_rate
|
||||
|
||||
# Calculate duration
|
||||
total_samples = sum(f.samples for f in frames_to_emit)
|
||||
duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0
|
||||
|
||||
# Get timestamps if available
|
||||
start_time = (
|
||||
first_frame.pts * first_frame.time_base if first_frame.pts else 0
|
||||
)
|
||||
end_time = (
|
||||
last_frame.pts * last_frame.time_base if last_frame.pts else 0
|
||||
)
|
||||
|
||||
# Convert to HH:MM:SS format for logging
|
||||
def format_time(seconds):
|
||||
if not seconds:
|
||||
return "00:00:00"
|
||||
total_seconds = int(float(seconds))
|
||||
hours = total_seconds // 3600
|
||||
minutes = (total_seconds % 3600) // 60
|
||||
secs = total_seconds % 60
|
||||
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
||||
|
||||
start_formatted = format_time(start_time)
|
||||
end_formatted = format_time(end_time)
|
||||
|
||||
# Keep remaining frames for next processing
|
||||
remaining_after = len(self.frames) - speech_end_frame
|
||||
|
||||
# Single structured log line
|
||||
self.logger.info(
|
||||
"Speech segment found",
|
||||
start=start_formatted,
|
||||
end=end_formatted,
|
||||
frames=speech_end_frame,
|
||||
duration=round(duration_seconds, 2),
|
||||
buffer_before=len(self.frames),
|
||||
remaining=remaining_after,
|
||||
)
|
||||
|
||||
# Keep remaining frames for next processing
|
||||
self.frames = self.frames[speech_end_frame:]
|
||||
|
||||
# Filter out segments with too few frames
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
await self.emit(frames_to_emit)
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in VAD processing: {e}")
|
||||
# Fallback to simple chunking
|
||||
if len(self.frames) >= self.block_frames:
|
||||
frames_to_emit = self.frames[: self.block_frames]
|
||||
self.frames = self.frames[self.block_frames :]
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
await self.emit(frames_to_emit)
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
|
||||
"""Convert av.AudioFrame list to numpy array for VAD processing"""
|
||||
if not frames:
|
||||
return None
|
||||
|
||||
try:
|
||||
first_frame = frames[0]
|
||||
original_sample_rate = first_frame.sample_rate
|
||||
|
||||
audio_data = []
|
||||
for frame in frames:
|
||||
frame_array = frame.to_ndarray()
|
||||
|
||||
# Handle stereo -> mono conversion
|
||||
if len(frame_array.shape) == 2 and frame_array.shape[0] > 1:
|
||||
frame_array = np.mean(frame_array, axis=0)
|
||||
elif len(frame_array.shape) == 2:
|
||||
frame_array = frame_array.flatten()
|
||||
|
||||
audio_data.append(frame_array)
|
||||
|
||||
if not audio_data:
|
||||
return None
|
||||
|
||||
combined_audio = np.concatenate(audio_data)
|
||||
|
||||
# Resample from 48kHz to 16kHz if needed
|
||||
if original_sample_rate != 16000:
|
||||
combined_audio = self._resample_audio(
|
||||
combined_audio, original_sample_rate, 16000
|
||||
)
|
||||
|
||||
# Ensure float32 format
|
||||
if combined_audio.dtype == np.int16:
|
||||
# Normalize int16 audio to float32 in range [-1.0, 1.0]
|
||||
combined_audio = combined_audio.astype(np.float32) / 32768.0
|
||||
elif combined_audio.dtype != np.float32:
|
||||
combined_audio = combined_audio.astype(np.float32)
|
||||
|
||||
return combined_audio
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error converting frames to numpy: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _resample_audio(
|
||||
self, audio: np.ndarray, from_sr: int, to_sr: int
|
||||
) -> np.ndarray:
|
||||
"""Simple linear resampling from from_sr to to_sr"""
|
||||
if from_sr == to_sr:
|
||||
return audio
|
||||
|
||||
try:
|
||||
# Simple linear interpolation resampling
|
||||
ratio = to_sr / from_sr
|
||||
new_length = int(len(audio) * ratio)
|
||||
|
||||
# Create indices for interpolation
|
||||
old_indices = np.linspace(0, len(audio) - 1, new_length)
|
||||
resampled = np.interp(old_indices, np.arange(len(audio)), audio)
|
||||
|
||||
return resampled.astype(np.float32)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Resampling error", exc_info=e)
|
||||
# Fallback: simple decimation/repetition
|
||||
if from_sr > to_sr:
|
||||
# Downsample by taking every nth sample
|
||||
step = from_sr // to_sr
|
||||
return audio[::step]
|
||||
else:
|
||||
# Upsample by repeating samples
|
||||
repeat = to_sr // from_sr
|
||||
return np.repeat(audio, repeat)
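# Quick sanity check on the resampling above (illustrative): going 48 kHz -> 16 kHz gives
# ratio = 16000 / 48000 = 1/3, so 48_000 input samples become
# new_length = int(48000 * 1/3) = 16000 samples; the except-branch decimation (audio[::3])
# keeps the same sample count, just without interpolation.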
|
||||
|
||||
def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
|
||||
"""Find complete speech segments and return frame index at segment end"""
|
||||
if self.vad_iterator is None or len(audio_array) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Process audio in 512-sample windows for VAD
|
||||
window_size = 512
|
||||
min_silence_windows = 3 # Require 3 windows of silence after speech
|
||||
|
||||
# Track speech state
|
||||
in_speech = False
|
||||
speech_start = None
|
||||
speech_end = None
|
||||
silence_count = 0
|
||||
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(chunk, (0, window_size - len(chunk)))
|
||||
|
||||
# Detect if this window has speech
|
||||
speech_dict = self.vad_iterator(chunk, return_seconds=True)
|
||||
|
||||
# VADIterator returns dict with 'start' and 'end' when speech segments are detected
|
||||
if speech_dict:
|
||||
if not in_speech:
|
||||
# Speech started
|
||||
speech_start = i
|
||||
in_speech = True
|
||||
# Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
|
||||
silence_count = 0 # Reset silence counter
|
||||
continue
|
||||
|
||||
if not in_speech:
|
||||
continue
|
||||
|
||||
# We're in speech but found silence
|
||||
silence_count += 1
|
||||
if silence_count < min_silence_windows:
|
||||
continue
|
||||
|
||||
# Found end of speech segment
|
||||
speech_end = i - (min_silence_windows - 1) * window_size
|
||||
# Debug: print(f"Speech END at sample {speech_end}")
|
||||
|
||||
# Convert sample position to frame index
|
||||
samples_per_frame = self.frames[0].samples if self.frames else 1024
|
||||
# Account for resampling: we process at 16kHz but frames might be 48kHz
|
||||
resample_ratio = 48000 / 16000 # 3x
|
||||
actual_sample_pos = int(speech_end * resample_ratio)
|
||||
frame_index = actual_sample_pos // samples_per_frame
|
||||
|
||||
# Ensure we don't exceed buffer
|
||||
frame_index = min(frame_index, len(self.frames))
|
||||
return frame_index
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error finding speech segment: {e}")
|
||||
return None
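# Worked example of the sample -> frame conversion above (illustrative numbers only,
# assuming 48 kHz input frames of 1024 samples, as the hard-coded ratio implies):
# a speech end at 16 kHz sample 32000 (2 s) maps to
#   actual_sample_pos = int(32000 * 3) -> 96000
#   frame_index       = 96000 // 1024  -> 93
# which is then clamped to len(self.frames) so it never points past the buffer.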
|
||||
|
||||
async def _flush(self):
|
||||
frames = self.frames[:]
|
||||
self.frames = []
|
||||
if frames:
|
||||
await self.emit(frames)
|
||||
if len(frames) >= self.min_frames:
|
||||
await self.emit(frames)
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring flush segment with {len(frames)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
DiarizationSegment,
|
||||
TitleSummary,
|
||||
Word,
|
||||
)
|
||||
@@ -37,18 +38,21 @@ class AudioDiarizationProcessor(Processor):
|
||||
async def _diarize(self, data: AudioDiarizationInput):
|
||||
raise NotImplementedError
|
||||
|
||||
def assign_speaker(self, words: list[Word], diarization: list[dict]):
|
||||
self._diarization_remove_overlap(diarization)
|
||||
self._diarization_remove_segment_without_words(words, diarization)
|
||||
self._diarization_merge_same_speaker(words, diarization)
|
||||
self._diarization_assign_speaker(words, diarization)
|
||||
@classmethod
|
||||
def assign_speaker(cls, words: list[Word], diarization: list[DiarizationSegment]):
|
||||
cls._diarization_remove_overlap(diarization)
|
||||
cls._diarization_remove_segment_without_words(words, diarization)
|
||||
cls._diarization_merge_same_speaker(diarization)
|
||||
cls._diarization_assign_speaker(words, diarization)
|
||||
|
||||
def iter_words_from_topics(self, topics: TitleSummary):
|
||||
@staticmethod
|
||||
def iter_words_from_topics(topics: list[TitleSummary]):
|
||||
for topic in topics:
|
||||
for word in topic.transcript.words:
|
||||
yield word
|
||||
|
||||
def is_word_continuation(self, word_prev, word):
|
||||
@staticmethod
|
||||
def is_word_continuation(word_prev, word):
|
||||
"""
|
||||
Return True if the word is a continuation of the previous word
|
||||
by checking if the previous word is ending with a punctuation
|
||||
@@ -61,7 +65,8 @@ class AudioDiarizationProcessor(Processor):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _diarization_remove_overlap(self, diarization: list[dict]):
|
||||
@staticmethod
|
||||
def _diarization_remove_overlap(diarization: list[DiarizationSegment]):
|
||||
"""
|
||||
Remove overlap in diarization results
|
||||
|
||||
@@ -86,8 +91,9 @@ class AudioDiarizationProcessor(Processor):
|
||||
else:
|
||||
diarization_idx += 1
|
||||
|
||||
@staticmethod
|
||||
def _diarization_remove_segment_without_words(
|
||||
self, words: list[Word], diarization: list[dict]
|
||||
words: list[Word], diarization: list[DiarizationSegment]
|
||||
):
|
||||
"""
|
||||
Remove diarization segments without words
|
||||
@@ -116,9 +122,8 @@ class AudioDiarizationProcessor(Processor):
|
||||
else:
|
||||
diarization_idx += 1
|
||||
|
||||
def _diarization_merge_same_speaker(
|
||||
self, words: list[Word], diarization: list[dict]
|
||||
):
|
||||
@staticmethod
|
||||
def _diarization_merge_same_speaker(diarization: list[DiarizationSegment]):
|
||||
"""
|
||||
Merge contiguous diarization segments with the same speaker
|
||||
|
||||
@@ -135,7 +140,10 @@ class AudioDiarizationProcessor(Processor):
|
||||
else:
|
||||
diarization_idx += 1
|
||||
|
||||
def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
|
||||
@classmethod
|
||||
def _diarization_assign_speaker(
|
||||
cls, words: list[Word], diarization: list[DiarizationSegment]
|
||||
):
|
||||
"""
|
||||
Assign speaker to words based on diarization
|
||||
|
||||
@@ -143,7 +151,7 @@ class AudioDiarizationProcessor(Processor):
|
||||
"""
|
||||
|
||||
word_idx = 0
|
||||
last_speaker = None
|
||||
last_speaker = 0
|
||||
for d in diarization:
|
||||
start = d["start"]
|
||||
end = d["end"]
|
||||
@@ -158,7 +166,7 @@ class AudioDiarizationProcessor(Processor):
|
||||
# If it's a continuation, assign with the last speaker
|
||||
is_continuation = False
|
||||
if word_idx > 0 and word_idx < len(words) - 1:
|
||||
is_continuation = self.is_word_continuation(
|
||||
is_continuation = cls.is_word_continuation(
|
||||
*words[word_idx - 1 : word_idx + 1]
|
||||
)
|
||||
if is_continuation:
|
||||
|
||||
74 server/reflector/processors/audio_diarization_pyannote.py (new file)
@@ -0,0 +1,74 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput, DiarizationSegment
|
||||
|
||||
|
||||
class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
|
||||
"""Local diarization processor using pyannote.audio library"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "pyannote/speaker-diarization-3.1",
|
||||
pyannote_auth_token: str | None = None,
|
||||
device: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.model_name = model_name
|
||||
self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
|
||||
self.device = device
|
||||
|
||||
if device is None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.logger.info(f"Loading pyannote diarization model: {self.model_name}")
|
||||
self.diarization_pipeline = Pipeline.from_pretrained(
|
||||
self.model_name, use_auth_token=self.auth_token
|
||||
)
|
||||
self.diarization_pipeline.to(torch.device(self.device))
|
||||
self.logger.info(f"Diarization model loaded on device: {self.device}")
|
||||
|
||||
async def _diarize(self, data: AudioDiarizationInput) -> list[DiarizationSegment]:
|
||||
try:
|
||||
# Load audio file (audio_url is assumed to be a local file path)
|
||||
self.logger.info(f"Loading local audio file: {data.audio_url}")
|
||||
waveform, sample_rate = torchaudio.load(data.audio_url)
|
||||
audio_input = {"waveform": waveform, "sample_rate": sample_rate}
|
||||
self.logger.info("Running speaker diarization")
|
||||
diarization = self.diarization_pipeline(audio_input)
|
||||
|
||||
# Convert pyannote diarization output to our format
|
||||
segments = []
|
||||
for segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||
# Extract speaker number from label (e.g., "SPEAKER_00" -> 0)
|
||||
speaker_id = 0
|
||||
if speaker.startswith("SPEAKER_"):
|
||||
try:
|
||||
speaker_id = int(speaker.split("_")[-1])
|
||||
except (ValueError, IndexError):
|
||||
# Fallback to hash-based ID if parsing fails
|
||||
speaker_id = hash(speaker) % 1000
|
||||
|
||||
segments.append(
|
||||
{
|
||||
"start": round(segment.start, 3),
|
||||
"end": round(segment.end, 3),
|
||||
"speaker": speaker_id,
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(f"Diarization completed with {len(segments)} segments")
|
||||
return segments
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Diarization failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)
|
||||
@@ -3,11 +3,24 @@ from time import monotonic_ns
|
||||
from uuid import uuid4
|
||||
|
||||
import av
|
||||
from av.audio.resampler import AudioResampler
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import AudioFile
|
||||
|
||||
|
||||
def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
|
||||
frame_copy = frame.from_ndarray(
|
||||
frame.to_ndarray(),
|
||||
format=frame.format.name,
|
||||
layout=frame.layout.name,
|
||||
)
|
||||
frame_copy.sample_rate = frame.sample_rate
|
||||
frame_copy.pts = frame.pts
|
||||
frame_copy.time_base = frame.time_base
|
||||
return frame_copy
|
||||
|
||||
|
||||
class AudioMergeProcessor(Processor):
|
||||
"""
|
||||
Merge audio frame into a single file
|
||||
@@ -16,37 +29,92 @@ class AudioMergeProcessor(Processor):
|
||||
INPUT_TYPE = list[av.AudioFrame]
|
||||
OUTPUT_TYPE = AudioFile
|
||||
|
||||
def __init__(self, downsample_to_16k_mono: bool = True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.downsample_to_16k_mono = downsample_to_16k_mono
|
||||
|
||||
async def _push(self, data: list[av.AudioFrame]):
|
||||
if not data:
|
||||
return
|
||||
|
||||
# get audio information from first frame
|
||||
frame = data[0]
|
||||
channels = len(frame.layout.channels)
|
||||
sample_rate = frame.sample_rate
|
||||
sample_width = frame.format.bytes
|
||||
original_channels = len(frame.layout.channels)
|
||||
original_sample_rate = frame.sample_rate
|
||||
original_sample_width = frame.format.bytes
|
||||
|
||||
# determine if we need processing
|
||||
needs_processing = self.downsample_to_16k_mono and (
|
||||
original_sample_rate != 16000 or original_channels != 1
|
||||
)
|
||||
|
||||
# determine output parameters
|
||||
if self.downsample_to_16k_mono:
|
||||
output_sample_rate = 16000
|
||||
output_channels = 1
|
||||
output_sample_width = 2 # 16-bit = 2 bytes
|
||||
else:
|
||||
output_sample_rate = original_sample_rate
|
||||
output_channels = original_channels
|
||||
output_sample_width = original_sample_width
|
||||
|
||||
# create audio file
|
||||
uu = uuid4().hex
|
||||
fd = io.BytesIO()
|
||||
|
||||
out_container = av.open(fd, "w", format="wav")
|
||||
out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
|
||||
for frame in data:
|
||||
for packet in out_stream.encode(frame):
|
||||
if needs_processing:
|
||||
# Process with PyAV resampler
|
||||
out_container = av.open(fd, "w", format="wav")
|
||||
out_stream = out_container.add_stream("pcm_s16le", rate=16000)
|
||||
out_stream.layout = "mono"
|
||||
|
||||
# Create resampler if needed
|
||||
resampler = None
|
||||
if original_sample_rate != 16000 or original_channels != 1:
|
||||
resampler = AudioResampler(format="s16", layout="mono", rate=16000)
|
||||
|
||||
for frame in data:
|
||||
if resampler:
|
||||
# Resample and convert to mono
|
||||
# XXX for an unknown reason, if we don't use a copy of the frame, we get
# "Invalid Argument" from resample. Debugging indicates that when a previous processor
# has already used the frame (like AudioFileWriter), it becomes an invalid argument here.
|
||||
resampled_frames = resampler.resample(copy_frame(frame))
|
||||
for resampled_frame in resampled_frames:
|
||||
for packet in out_stream.encode(resampled_frame):
|
||||
out_container.mux(packet)
|
||||
else:
|
||||
# Direct encoding without resampling
|
||||
for packet in out_stream.encode(frame):
|
||||
out_container.mux(packet)
|
||||
|
||||
# Flush the encoder
|
||||
for packet in out_stream.encode(None):
|
||||
out_container.mux(packet)
|
||||
for packet in out_stream.encode(None):
|
||||
out_container.mux(packet)
|
||||
out_container.close()
|
||||
out_container.close()
|
||||
else:
|
||||
# Use PyAV for original frames (no processing needed)
|
||||
out_container = av.open(fd, "w", format="wav")
|
||||
out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
|
||||
out_stream.layout = "mono" if output_channels == 1 else frame.layout
|
||||
|
||||
for frame in data:
|
||||
for packet in out_stream.encode(frame):
|
||||
out_container.mux(packet)
|
||||
|
||||
for packet in out_stream.encode(None):
|
||||
out_container.mux(packet)
|
||||
out_container.close()
|
||||
|
||||
fd.seek(0)
|
||||
|
||||
# emit audio file
|
||||
audiofile = AudioFile(
|
||||
name=f"{monotonic_ns()}-{uu}.wav",
|
||||
fd=fd,
|
||||
sample_rate=sample_rate,
|
||||
channels=channels,
|
||||
sample_width=sample_width,
|
||||
sample_rate=output_sample_rate,
|
||||
channels=output_channels,
|
||||
sample_width=output_sample_width,
|
||||
timestamp=data[0].pts * data[0].time_base,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,9 @@ API will be a POST request to TRANSCRIPT_URL:
|
||||
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
@@ -21,7 +24,9 @@ from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
def __init__(
|
||||
self, modal_api_key: str | None = None, batch_enabled: bool = True, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
if not settings.TRANSCRIPT_URL:
|
||||
raise Exception(
|
||||
@@ -30,6 +35,126 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
|
||||
self.timeout = settings.TRANSCRIPT_TIMEOUT
|
||||
self.modal_api_key = modal_api_key
|
||||
self.max_batch_duration = 10.0
|
||||
self.max_batch_files = 15
|
||||
self.batch_enabled = batch_enabled
|
||||
self.pending_files: List[AudioFile] = [] # Files waiting to be processed
|
||||
|
||||
@classmethod
|
||||
def _calculate_duration(cls, audio_file: AudioFile) -> float:
|
||||
"""Calculate audio duration in seconds from AudioFile metadata"""
|
||||
# Duration = total_samples / sample_rate
|
||||
# We need to estimate total samples from the file data
|
||||
import wave
|
||||
|
||||
try:
|
||||
# Try to read as WAV file to get duration
|
||||
audio_file.fd.seek(0)
|
||||
with wave.open(audio_file.fd, "rb") as wav_file:
|
||||
frames = wav_file.getnframes()
|
||||
sample_rate = wav_file.getframerate()
|
||||
duration = frames / sample_rate
|
||||
return duration
|
||||
except Exception:
|
||||
# Fallback: estimate from file size and audio parameters
|
||||
audio_file.fd.seek(0, 2) # Seek to end
|
||||
file_size = audio_file.fd.tell()
|
||||
audio_file.fd.seek(0) # Reset to beginning
|
||||
|
||||
# Estimate: file_size / (sample_rate * channels * sample_width)
|
||||
bytes_per_second = (
|
||||
audio_file.sample_rate
|
||||
* audio_file.channels
|
||||
* (audio_file.sample_width // 8)
|
||||
)
|
||||
estimated_duration = (
|
||||
file_size / bytes_per_second if bytes_per_second > 0 else 0
|
||||
)
|
||||
return max(0, estimated_duration)
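# Rough worked example of the fallback estimate above (illustrative only, and assuming
# sample_width is expressed in bits, as the `// 8` division implies):
# a 320_000-byte file at 16 kHz, mono, 16-bit gives
#   bytes_per_second   = 16000 * 1 * (16 // 8) -> 32000
#   estimated_duration = 320000 / 32000        -> 10.0 seconds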
|
||||
|
||||
def _create_batches(self, audio_files: List[AudioFile]) -> List[List[AudioFile]]:
"""Group audio files into batches capped at max_batch_duration seconds of total audio"""
|
||||
batches = []
|
||||
current_batch = []
|
||||
current_duration = 0.0
|
||||
|
||||
for audio_file in audio_files:
|
||||
duration = self._calculate_duration(audio_file)
|
||||
|
||||
# If adding this file exceeds max duration, start a new batch
|
||||
if current_duration + duration > self.max_batch_duration and current_batch:
|
||||
batches.append(current_batch)
|
||||
current_batch = [audio_file]
|
||||
current_duration = duration
|
||||
else:
|
||||
current_batch.append(audio_file)
|
||||
current_duration += duration
|
||||
|
||||
# Add the last batch if not empty
|
||||
if current_batch:
|
||||
batches.append(current_batch)
|
||||
|
||||
return batches
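# Toy trace of the greedy grouping above (plain durations in seconds instead of AudioFile
# objects, with max_batch_duration = 10.0 as set in __init__):
#   durations [4, 5, 3, 8] -> batches [[4, 5], [3], [8]]
# 4+5 fits under 10; adding 3 would exceed it, so a new batch starts; 3+8 would exceed it
# again, so 8 ends up in its own batch.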
|
||||
|
||||
async def _transcript_batch(self, audio_files: List[AudioFile]) -> List[Transcript]:
|
||||
"""Transcribe a batch of audio files using the parakeet backend"""
|
||||
if not audio_files:
|
||||
return []
|
||||
|
||||
self.logger.debug(f"Batch transcribing {len(audio_files)} files")
|
||||
|
||||
# Prepare form data for batch request
|
||||
data = aiohttp.FormData()
|
||||
data.add_field("language", self.get_pref("audio:source_language", "en"))
|
||||
data.add_field("batch", "true")
|
||||
|
||||
for i, audio_file in enumerate(audio_files):
|
||||
audio_file.fd.seek(0)
|
||||
data.add_field(
|
||||
"files",
|
||||
audio_file.fd,
|
||||
filename=f"{audio_file.name}",
|
||||
content_type="audio/wav",
|
||||
)
|
||||
|
||||
# Make batch request
|
||||
headers = {"Authorization": f"Bearer {self.modal_api_key}"}
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout)
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{self.transcript_url}/audio/transcriptions",
|
||||
data=data,
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(
|
||||
f"Batch transcription failed: {response.status} {error_text}"
|
||||
)
|
||||
|
||||
result = await response.json()
|
||||
|
||||
# Process batch results
|
||||
transcripts = []
|
||||
results = result.get("results", [])
|
||||
|
||||
for i, (audio_file, file_result) in enumerate(zip(audio_files, results)):
|
||||
transcript = Transcript(
|
||||
words=[
|
||||
Word(
|
||||
text=word_info["word"],
|
||||
start=word_info["start"],
|
||||
end=word_info["end"],
|
||||
)
|
||||
for word_info in file_result.get("words", [])
|
||||
]
|
||||
)
|
||||
transcript.add_offset(audio_file.timestamp)
|
||||
transcripts.append(transcript)
|
||||
|
||||
return transcripts
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
async with AsyncOpenAI(
|
||||
@@ -62,5 +187,96 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
|
||||
return transcript
|
||||
|
||||
async def transcript_multiple(
|
||||
self, audio_files: List[AudioFile]
|
||||
) -> List[Transcript]:
|
||||
"""Transcribe multiple audio files using batching"""
|
||||
if len(audio_files) == 1:
|
||||
# Single file, use existing method
|
||||
return [await self._transcript(audio_files[0])]
|
||||
|
||||
# Create batches capped at max_batch_duration seconds each
|
||||
batches = self._create_batches(audio_files)
|
||||
|
||||
self.logger.debug(
|
||||
f"Processing {len(audio_files)} files in {len(batches)} batches"
|
||||
)
|
||||
|
||||
# Process all batches concurrently
|
||||
all_transcripts = []
|
||||
|
||||
for batch in batches:
|
||||
batch_transcripts = await self._transcript_batch(batch)
|
||||
all_transcripts.extend(batch_transcripts)
|
||||
|
||||
return all_transcripts
|
||||
|
||||
async def _push(self, data: AudioFile):
|
||||
"""Override _push to support batching"""
|
||||
if not self.batch_enabled:
|
||||
# Use parent implementation for single file processing
|
||||
return await super()._push(data)
|
||||
|
||||
# Add file to pending batch
|
||||
self.pending_files.append(data)
|
||||
self.logger.debug(
|
||||
f"Added file to batch: {data.name}, batch size: {len(self.pending_files)}"
|
||||
)
|
||||
|
||||
# Calculate total duration of pending files
|
||||
total_duration = sum(self._calculate_duration(f) for f in self.pending_files)
|
||||
|
||||
# Process batch if it reaches max duration or has multiple files ready for optimization
|
||||
should_process_batch = (
|
||||
total_duration >= self.max_batch_duration
|
||||
or len(self.pending_files) >= self.max_batch_files
|
||||
)
|
||||
|
||||
if should_process_batch:
|
||||
await self._process_pending_batch()
|
||||
|
||||
async def _process_pending_batch(self):
|
||||
"""Process all pending files as batches"""
|
||||
if not self.pending_files:
|
||||
return
|
||||
|
||||
self.logger.debug(f"Processing batch of {len(self.pending_files)} files")
|
||||
|
||||
try:
|
||||
# Create batches respecting duration limit
|
||||
batches = self._create_batches(self.pending_files)
|
||||
|
||||
# Process each batch
|
||||
for batch in batches:
|
||||
self.m_transcript_call.inc()
|
||||
try:
|
||||
with self.m_transcript.time():
|
||||
# Use batch transcription
|
||||
transcripts = await self._transcript_batch(batch)
|
||||
|
||||
self.m_transcript_success.inc()
|
||||
|
||||
# Emit each transcript
|
||||
for transcript in transcripts:
|
||||
if transcript:
|
||||
await self.emit(transcript)
|
||||
|
||||
except Exception:
|
||||
self.m_transcript_failure.inc()
|
||||
raise
|
||||
finally:
|
||||
# Release audio files
|
||||
for audio_file in batch:
|
||||
audio_file.release()
|
||||
|
||||
finally:
|
||||
# Clear pending files
|
||||
self.pending_files.clear()
|
||||
|
||||
async def _flush(self):
|
||||
"""Process any remaining files when flushing"""
|
||||
await self._process_pending_batch()
|
||||
await super()._flush()
|
||||
|
||||
|
||||
AudioTranscriptAutoProcessor.register("modal", AudioTranscriptModalProcessor)
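The batching path above accumulates pending chunks and flushes them once their combined length crosses `max_batch_duration` (about 30 s) or the file count hits `max_batch_files`. A minimal sketch of that grouping step, assuming each item's length in seconds comes from a `duration(f)` callable standing in for `self._calculate_duration`:

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def create_batches(
    files: List[T],
    duration: Callable[[T], float],
    max_batch_duration: float = 30.0,
) -> List[List[T]]:
    """Group files so each batch stays under max_batch_duration seconds."""
    batches: List[List[T]] = []
    current_batch: List[T] = []
    current_duration = 0.0
    for f in files:
        d = duration(f)
        # Close the current batch when this file would push it over the limit
        if current_batch and current_duration + d > max_batch_duration:
            batches.append(current_batch)
            current_batch = []
            current_duration = 0.0
        current_batch.append(f)
        current_duration += d
    # Add the last batch if not empty
    if current_batch:
        batches.append(current_batch)
    return batches
```

Capping batches at roughly 30 s keeps each Modal request small while still amortizing the per-call overhead across several chunks.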
|
||||
|
||||
@@ -173,6 +173,7 @@ class Processor(Emitter):
|
||||
except Exception:
|
||||
self.m_processor_failure.inc()
|
||||
self.logger.exception("Error in push")
|
||||
raise
|
||||
|
||||
async def flush(self):
|
||||
"""
|
||||
@@ -240,33 +241,45 @@ class ThreadedProcessor(Processor):
|
||||
self.INPUT_TYPE = processor.INPUT_TYPE
|
||||
self.OUTPUT_TYPE = processor.OUTPUT_TYPE
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.queue = asyncio.Queue()
|
||||
self.task = asyncio.get_running_loop().create_task(self.loop())
|
||||
self.queue = asyncio.Queue(maxsize=50)
|
||||
self.task: asyncio.Task | None = None
|
||||
|
||||
def set_pipeline(self, pipeline: "Pipeline"):
|
||||
super().set_pipeline(pipeline)
|
||||
self.processor.set_pipeline(pipeline)
|
||||
|
||||
async def loop(self):
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
self.m_processor_queue.set(self.queue.qsize())
|
||||
with self.m_processor_queue_in_progress.track_inprogress():
|
||||
try:
|
||||
if data is None:
|
||||
await self.processor.flush()
|
||||
break
|
||||
try:
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
self.m_processor_queue.set(self.queue.qsize())
|
||||
with self.m_processor_queue_in_progress.track_inprogress():
|
||||
try:
|
||||
await self.processor.push(data)
|
||||
except Exception:
|
||||
self.logger.error(
|
||||
f"Error in push {self.processor.__class__.__name__}"
|
||||
", continue"
|
||||
)
|
||||
finally:
|
||||
self.queue.task_done()
|
||||
if data is None:
|
||||
await self.processor.flush()
|
||||
break
|
||||
try:
|
||||
await self.processor.push(data)
|
||||
except Exception:
|
||||
self.logger.error(
|
||||
f"Error in push {self.processor.__class__.__name__}"
|
||||
", continue"
|
||||
)
|
||||
finally:
|
||||
self.queue.task_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Crash in {self.__class__.__name__}: {e}", exc_info=e)
|
||||
|
||||
async def _ensure_task(self):
|
||||
if self.task is None:
|
||||
self.task = asyncio.get_running_loop().create_task(self.loop())
|
||||
|
||||
# XXX without a sleep here, the pipeline stages before this thread keep
# running without giving the worker task a chance to run.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def _push(self, data):
|
||||
await self._ensure_task()
|
||||
await self.queue.put(data)
|
||||
|
||||
async def _flush(self):
|
||||
|
||||
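The `ThreadedProcessor` changes above bound the queue at 50 items, create the worker task lazily on the first push instead of in `__init__`, and wrap the loop so a crash is logged instead of disappearing silently. A minimal sketch of the lazy-start part, assuming `None` is still the flush sentinel:

```python
import asyncio


class LazyWorkerSketch:
    """Illustration only; the real class is ThreadedProcessor above."""

    def __init__(self):
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=50)
        self.task: asyncio.Task | None = None

    async def _ensure_task(self):
        if self.task is None:
            self.task = asyncio.get_running_loop().create_task(self.loop())
        # Yield control so the freshly created worker gets a chance to run
        await asyncio.sleep(0)

    async def loop(self):
        try:
            while True:
                item = await self.queue.get()
                try:
                    if item is None:  # flush sentinel: stop the worker
                        break
                    ...  # process the item; per-item errors are logged, not raised
                finally:
                    self.queue.task_done()
        except Exception as e:
            print(f"Crash in worker: {e}")

    async def push(self, item):
        await self._ensure_task()
        await self.queue.put(item)
```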
server/reflector/processors/file_diarization.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from pydantic import BaseModel

from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment


class FileDiarizationInput(BaseModel):
    """Input for file diarization containing audio URL"""

    audio_url: str


class FileDiarizationOutput(BaseModel):
    """Output for file diarization containing speaker segments"""

    diarization: list[DiarizationSegment]


class FileDiarizationProcessor(Processor):
    """
    Diarize complete audio files from URL
    """

    INPUT_TYPE = FileDiarizationInput
    OUTPUT_TYPE = FileDiarizationOutput

    async def _push(self, data: FileDiarizationInput):
        result = await self._diarize(data)
        if result:
            await self.emit(result)

    async def _diarize(self, data: FileDiarizationInput):
        raise NotImplementedError
server/reflector/processors/file_diarization_auto.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import importlib

from reflector.processors.file_diarization import FileDiarizationProcessor
from reflector.settings import settings


class FileDiarizationAutoProcessor(FileDiarizationProcessor):
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        cls._registry[name] = kclass

    def __new__(cls, name: str | None = None, **kwargs):
        if name is None:
            name = settings.DIARIZATION_BACKEND

        if name not in cls._registry:
            module_name = f"reflector.processors.file_diarization_{name}"
            importlib.import_module(module_name)

        # gather specific configuration for the processor
        # search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
        config = {}
        name_upper = name.upper()
        settings_prefix = "DIARIZATION_"
        config_prefix = f"{settings_prefix}{name_upper}_"
        for key, value in settings:
            if key.startswith(config_prefix):
                config_name = key[len(settings_prefix) :].lower()
                config[config_name] = value

        return cls._registry[name](**config | kwargs)
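The `__new__` above is the configuration bridge: any setting named `DIARIZATION_<BACKEND>_<KEY>` is stripped of the `DIARIZATION_` prefix, lower-cased, and passed to the selected backend's constructor. A tiny illustration of that mapping (the value is a placeholder, not a real key):

```python
settings_prefix = "DIARIZATION_"
key = "DIARIZATION_MODAL_API_KEY"

kwarg_name = key[len(settings_prefix):].lower()
assert kwarg_name == "modal_api_key"
# So FileDiarizationAutoProcessor("modal") effectively calls
# FileDiarizationModalProcessor(modal_api_key="<value of the setting>").
```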
server/reflector/processors/file_diarization_modal.py (new file, 57 lines)
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
File diarization implementation using the GPU service from modal.com
|
||||
|
||||
API will be a POST request to DIARIZATION_URL:
|
||||
|
||||
```
|
||||
POST /diarize?audio_file_url=...&timestamp=0
|
||||
Authorization: Bearer <modal_api_key>
|
||||
```
|
||||
"""
|
||||
|
||||
import httpx
|
||||
|
||||
from reflector.processors.file_diarization import (
|
||||
FileDiarizationInput,
|
||||
FileDiarizationOutput,
|
||||
FileDiarizationProcessor,
|
||||
)
|
||||
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileDiarizationModalProcessor(FileDiarizationProcessor):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not settings.DIARIZATION_URL:
|
||||
raise Exception(
|
||||
"DIARIZATION_URL required to use FileDiarizationModalProcessor"
|
||||
)
|
||||
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
|
||||
self.file_timeout = settings.DIARIZATION_FILE_TIMEOUT
|
||||
self.modal_api_key = modal_api_key
|
||||
|
||||
async def _diarize(self, data: FileDiarizationInput):
|
||||
"""Get speaker diarization for file"""
|
||||
self.logger.info(f"Starting diarization from {data.audio_url}")
|
||||
|
||||
headers = {}
|
||||
if self.modal_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
|
||||
response = await client.post(
|
||||
self.diarization_url,
|
||||
headers=headers,
|
||||
params={
|
||||
"audio_file_url": data.audio_url,
|
||||
"timestamp": 0,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
diarization_data = response.json()["diarization"]
|
||||
|
||||
return FileDiarizationOutput(diarization=diarization_data)
|
||||
|
||||
|
||||
FileDiarizationAutoProcessor.register("modal", FileDiarizationModalProcessor)
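For reference, a hedged sketch of hitting the same endpoint directly with httpx, mirroring `_diarize` above; the host and API key are placeholders, and the response shape matches the recorded cassettes further down in this diff:

```python
import asyncio

import httpx


async def diarize(audio_url: str) -> list[dict]:
    async with httpx.AsyncClient(timeout=600) as client:
        response = await client.post(
            "https://<your-diarizer>.modal.run/diarize",
            headers={"Authorization": "Bearer <modal_api_key>"},
            params={"audio_file_url": audio_url, "timestamp": 0},
        )
        response.raise_for_status()
        # e.g. {"diarization": [{"start": 0.823, "end": 1.91, "speaker": 0}, ...]}
        return response.json()["diarization"]


# asyncio.run(diarize("https://example.com/meeting.wav"))
```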
|
||||
server/reflector/processors/file_transcript.py (new file, 65 lines)
@@ -0,0 +1,65 @@
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import Transcript
|
||||
|
||||
|
||||
class FileTranscriptInput:
|
||||
"""Input for file transcription containing audio URL and language settings"""
|
||||
|
||||
def __init__(self, audio_url: str, language: str = "en"):
|
||||
self.audio_url = audio_url
|
||||
self.language = language
|
||||
|
||||
|
||||
class FileTranscriptProcessor(Processor):
|
||||
"""
|
||||
Transcript complete audio files from URL
|
||||
"""
|
||||
|
||||
INPUT_TYPE = FileTranscriptInput
|
||||
OUTPUT_TYPE = Transcript
|
||||
|
||||
m_transcript = Histogram(
|
||||
"file_transcript",
|
||||
"Time spent in FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
m_transcript_call = Counter(
|
||||
"file_transcript_call",
|
||||
"Number of calls to FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
m_transcript_success = Counter(
|
||||
"file_transcript_success",
|
||||
"Number of successful calls to FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
m_transcript_failure = Counter(
|
||||
"file_transcript_failure",
|
||||
"Number of failed calls to FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
name = self.__class__.__name__
|
||||
self.m_transcript = self.m_transcript.labels(name)
|
||||
self.m_transcript_call = self.m_transcript_call.labels(name)
|
||||
self.m_transcript_success = self.m_transcript_success.labels(name)
|
||||
self.m_transcript_failure = self.m_transcript_failure.labels(name)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def _push(self, data: FileTranscriptInput):
|
||||
try:
|
||||
self.m_transcript_call.inc()
|
||||
with self.m_transcript.time():
|
||||
result = await self._transcript(data)
|
||||
self.m_transcript_success.inc()
|
||||
if result:
|
||||
await self.emit(result)
|
||||
except Exception:
|
||||
self.m_transcript_failure.inc()
|
||||
raise
|
||||
|
||||
async def _transcript(self, data: FileTranscriptInput):
|
||||
raise NotImplementedError
|
||||
server/reflector/processors/file_transcript_auto.py (new file, 32 lines)
@@ -0,0 +1,32 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.file_transcript import FileTranscriptProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileTranscriptAutoProcessor(FileTranscriptProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.TRANSCRIPT_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.file_transcript_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "TRANSCRIPT_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
server/reflector/processors/file_transcript_modal.py (new file, 74 lines)
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
File transcription implementation using the GPU service from modal.com
|
||||
|
||||
API will be a POST request to TRANSCRIPT_URL:
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_file_url": "https://...",
|
||||
"language": "en",
|
||||
"model": "parakeet-tdt-0.6b-v2",
|
||||
"batch": true
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
import httpx
|
||||
|
||||
from reflector.processors.file_transcript import (
|
||||
FileTranscriptInput,
|
||||
FileTranscriptProcessor,
|
||||
)
|
||||
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||
from reflector.processors.types import Transcript, Word
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileTranscriptModalProcessor(FileTranscriptProcessor):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not settings.TRANSCRIPT_URL:
|
||||
raise Exception(
|
||||
"TRANSCRIPT_URL required to use FileTranscriptModalProcessor"
|
||||
)
|
||||
self.transcript_url = settings.TRANSCRIPT_URL
|
||||
self.file_timeout = settings.TRANSCRIPT_FILE_TIMEOUT
|
||||
self.modal_api_key = modal_api_key
|
||||
|
||||
async def _transcript(self, data: FileTranscriptInput):
|
||||
"""Send full file to Modal for transcription"""
|
||||
url = f"{self.transcript_url}/v1/audio/transcriptions-from-url"
|
||||
|
||||
self.logger.info(f"Starting file transcription from {data.audio_url}")
|
||||
|
||||
headers = {}
|
||||
if self.modal_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json={
|
||||
"audio_file_url": data.audio_url,
|
||||
"language": data.language,
|
||||
"batch": True,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
words = [
|
||||
Word(
|
||||
text=word_info["word"],
|
||||
start=word_info["start"],
|
||||
end=word_info["end"],
|
||||
)
|
||||
for word_info in result.get("words", [])
|
||||
]
|
||||
|
||||
return Transcript(words=words)
|
||||
|
||||
|
||||
# Register with the auto processor
|
||||
FileTranscriptAutoProcessor.register("modal", FileTranscriptModalProcessor)
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Processor to assemble transcript with diarization results
|
||||
"""
|
||||
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import DiarizationSegment, Transcript
|
||||
|
||||
|
||||
class TranscriptDiarizationAssemblerInput:
|
||||
"""Input containing transcript and diarization data"""
|
||||
|
||||
def __init__(self, transcript: Transcript, diarization: list[DiarizationSegment]):
|
||||
self.transcript = transcript
|
||||
self.diarization = diarization
|
||||
|
||||
|
||||
class TranscriptDiarizationAssemblerProcessor(Processor):
|
||||
"""
|
||||
Assemble transcript with diarization results by applying speaker assignments
|
||||
"""
|
||||
|
||||
INPUT_TYPE = TranscriptDiarizationAssemblerInput
|
||||
OUTPUT_TYPE = Transcript
|
||||
|
||||
async def _push(self, data: TranscriptDiarizationAssemblerInput):
|
||||
result = await self._assemble(data)
|
||||
if result:
|
||||
await self.emit(result)
|
||||
|
||||
async def _assemble(self, data: TranscriptDiarizationAssemblerInput):
|
||||
"""Apply diarization to transcript words"""
|
||||
if not data.diarization:
|
||||
self.logger.info(
|
||||
"No diarization data provided, returning original transcript"
|
||||
)
|
||||
return data.transcript
|
||||
|
||||
# Reuse logic from AudioDiarizationProcessor
|
||||
processor = AudioDiarizationProcessor()
|
||||
words = data.transcript.words
|
||||
processor.assign_speaker(words, data.diarization)
|
||||
|
||||
self.logger.info(f"Applied diarization to {len(words)} words")
|
||||
return data.transcript
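`assign_speaker` itself belongs to `AudioDiarizationProcessor` and is not part of this diff, so the sketch below only illustrates the general idea: give each word the speaker of the diarization segment containing its midpoint, and keep the previous speaker when nothing overlaps.

```python
from reflector.processors.types import DiarizationSegment, Word


def assign_speaker_sketch(
    words: list[Word], diarization: list[DiarizationSegment]
) -> None:
    """Illustration only, not the actual assign_speaker implementation."""
    speaker = 0
    for word in words:
        midpoint = (word.start + word.end) / 2
        for segment in diarization:
            if segment["start"] <= midpoint <= segment["end"]:
                speaker = segment["speaker"]
                break
        # Words falling between segments keep the last speaker seen
        word.speaker = speaker
```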
|
||||
@@ -2,13 +2,22 @@ import io
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from profanityfilter import ProfanityFilter
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from reflector.redis_cache import redis_cache
|
||||
|
||||
|
||||
class DiarizationSegment(TypedDict):
|
||||
"""Type definition for diarization segment containing speaker information"""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
speaker: int
|
||||
|
||||
|
||||
PUNC_RE = re.compile(r"[.;:?!…]")
|
||||
|
||||
profanity_filter = ProfanityFilter()
|
||||
|
||||
@@ -26,6 +26,7 @@ class Settings(BaseSettings):
|
||||
TRANSCRIPT_BACKEND: str = "whisper"
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
TRANSCRIPT_FILE_TIMEOUT: int = 600
|
||||
|
||||
# Audio Transcription: modal backend
|
||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||
@@ -66,10 +67,14 @@ class Settings(BaseSettings):
|
||||
DIARIZATION_ENABLED: bool = True
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
DIARIZATION_FILE_TIMEOUT: int = 600
|
||||
|
||||
# Diarization: modal backend
|
||||
DIARIZATION_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Diarization: local pyannote.audio
|
||||
DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None
|
||||
|
||||
# Sentry
|
||||
SENTRY_DSN: str | None = None
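A hedged configuration sketch tying the new settings together, in the same style as tests/conftest.py further down; the URLs and key are placeholders rather than real endpoints:

```python
from reflector.settings import settings

settings.TRANSCRIPT_BACKEND = "modal"
settings.TRANSCRIPT_URL = "https://<your-transcriber>.modal.run"
settings.TRANSCRIPT_FILE_TIMEOUT = 600  # whole-file requests, not 90 s chunks
settings.DIARIZATION_BACKEND = "modal"
settings.DIARIZATION_URL = "https://<your-diarizer>.modal.run"
settings.DIARIZATION_FILE_TIMEOUT = 600
settings.DIARIZATION_MODAL_API_KEY = "<modal api key>"
```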
|
||||
|
||||
|
||||
@@ -1,10 +1,23 @@
|
||||
"""
|
||||
Process audio file with diarization support
|
||||
===========================================
|
||||
|
||||
Extended version of process.py that includes speaker diarization.
|
||||
This tool processes audio files locally without requiring the full server infrastructure.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import av
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
Pipeline,
|
||||
@@ -15,7 +28,43 @@ from reflector.processors import (
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor
|
||||
from reflector.processors.base import BroadcastProcessor, Processor
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
TitleSummary,
|
||||
TitleSummaryWithId,
|
||||
)
|
||||
|
||||
|
||||
class TopicCollectorProcessor(Processor):
|
||||
"""Collect topics for diarization"""
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.topics: List[TitleSummaryWithId] = []
|
||||
self._topic_id = 0
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
# Convert to TitleSummaryWithId and collect
|
||||
self._topic_id += 1
|
||||
topic_with_id = TitleSummaryWithId(
|
||||
id=str(self._topic_id),
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
timestamp=data.timestamp,
|
||||
duration=data.duration,
|
||||
transcript=data.transcript,
|
||||
)
|
||||
self.topics.append(topic_with_id)
|
||||
|
||||
# Pass through the original topic
|
||||
await self.emit(data)
|
||||
|
||||
def get_topics(self) -> List[TitleSummaryWithId]:
|
||||
return self.topics
|
||||
|
||||
|
||||
async def process_audio_file(
|
||||
@@ -24,18 +73,40 @@ async def process_audio_file(
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="pyannote",
|
||||
):
|
||||
# build pipeline for audio processing
|
||||
processors = [
|
||||
# Create temp file for audio if diarization is enabled
|
||||
audio_temp_path = None
|
||||
if enable_diarization:
|
||||
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
audio_temp_path = audio_temp_file.name
|
||||
audio_temp_file.close()
|
||||
|
||||
# Create processor for collecting topics
|
||||
topic_collector = TopicCollectorProcessor()
|
||||
|
||||
# Build pipeline for audio processing
|
||||
processors = []
|
||||
|
||||
# Add audio file writer at the beginning if diarization is enabled
|
||||
if enable_diarization:
|
||||
processors.append(AudioFileWriterProcessor(audio_temp_path))
|
||||
|
||||
# Add the rest of the processors
|
||||
processors += [
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||
]
|
||||
|
||||
if not only_transcript:
|
||||
processors += [
|
||||
TranscriptTopicDetectorProcessor.as_threaded(),
|
||||
# Collect topics for diarization
|
||||
topic_collector,
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(),
|
||||
@@ -44,14 +115,14 @@ async def process_audio_file(
|
||||
),
|
||||
]
|
||||
|
||||
# transcription output
|
||||
# Create main pipeline
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.set_pref("audio:source_language", source_language)
|
||||
pipeline.set_pref("audio:target_language", target_language)
|
||||
pipeline.describe()
|
||||
pipeline.on(event_callback)
|
||||
|
||||
# start processing audio
|
||||
# Start processing audio
|
||||
logger.info(f"Opening {filename}")
|
||||
container = av.open(filename)
|
||||
try:
|
||||
@@ -62,43 +133,242 @@ async def process_audio_file(
|
||||
logger.info("Flushing the pipeline")
|
||||
await pipeline.flush()
|
||||
|
||||
logger.info("All done !")
|
||||
# Run diarization if enabled and we have topics
|
||||
if enable_diarization and not only_transcript and audio_temp_path:
|
||||
topics = topic_collector.get_topics()
|
||||
|
||||
if topics:
|
||||
logger.info(f"Starting diarization with {len(topics)} topics")
|
||||
|
||||
try:
|
||||
from reflector.processors import AudioDiarizationAutoProcessor
|
||||
|
||||
diarization_processor = AudioDiarizationAutoProcessor(
|
||||
name=diarization_backend
|
||||
)
|
||||
|
||||
diarization_processor.set_pipeline(pipeline)
|
||||
|
||||
# For Modal backend, we need to upload the file to S3 first
|
||||
if diarization_backend == "modal":
|
||||
from datetime import datetime
|
||||
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.utils.s3_temp_file import S3TemporaryFile
|
||||
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
# Generate a unique filename in evaluation folder
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
||||
|
||||
# Use context manager for automatic cleanup
|
||||
async with S3TemporaryFile(storage, audio_filename) as s3_file:
|
||||
# Read and upload the audio file
|
||||
with open(audio_temp_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
audio_url = await s3_file.upload(audio_data)
|
||||
logger.info(f"Uploaded audio to S3: {audio_filename}")
|
||||
|
||||
# Create diarization input with S3 URL
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
# File will be automatically cleaned up when exiting the context
|
||||
else:
|
||||
# For local backend, use local file path
|
||||
audio_url = audio_temp_path
|
||||
|
||||
# Create diarization input
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import diarization dependencies: {e}")
|
||||
logger.error(
|
||||
"Install with: uv pip install pyannote.audio torch torchaudio"
|
||||
)
|
||||
logger.error(
|
||||
"And set HF_TOKEN environment variable for pyannote models"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Diarization failed: {e}")
|
||||
raise SystemExit(1)
|
||||
else:
|
||||
logger.warning("Skipping diarization: no topics available")
|
||||
|
||||
# Clean up temp file
|
||||
if audio_temp_path:
|
||||
try:
|
||||
Path(audio_temp_path).unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
|
||||
|
||||
logger.info("All done!")
|
||||
|
||||
|
||||
async def process_file_pipeline(
|
||||
filename: str,
|
||||
event_callback,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
):
|
||||
"""Process audio/video file using the optimized file pipeline"""
|
||||
try:
|
||||
from reflector.db import database
|
||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||
|
||||
await database.connect()
|
||||
try:
|
||||
# Create a temporary transcript for processing
|
||||
transcript = await transcripts_controller.add(
|
||||
"",
|
||||
source_kind=SourceKind.FILE,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
)
|
||||
|
||||
# Process the file
|
||||
pipeline = PipelineMainFile(transcript_id=transcript.id)
|
||||
await pipeline.process(Path(filename))
|
||||
|
||||
logger.info("File pipeline processing complete")
|
||||
|
||||
finally:
|
||||
await database.disconnect()
|
||||
except ImportError as e:
|
||||
logger.error(f"File pipeline not available: {e}")
|
||||
logger.info("Falling back to stream pipeline")
|
||||
# Fall back to stream pipeline
|
||||
await process_audio_file(
|
||||
filename,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
enable_diarization=enable_diarization,
|
||||
diarization_backend=diarization_backend,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import os
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Process audio files with optional speaker diarization"
|
||||
)
|
||||
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
||||
parser.add_argument("--only-transcript", "-t", action="store_true")
|
||||
parser.add_argument("--source-language", default="en")
|
||||
parser.add_argument("--target-language", default="en")
|
||||
parser.add_argument(
|
||||
"--stream",
|
||||
action="store_true",
|
||||
help="Use streaming pipeline (original frame-based processing)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--only-transcript",
|
||||
"-t",
|
||||
action="store_true",
|
||||
help="Only generate transcript without topics/summaries",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-language", default="en", help="Source language code (default: en)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-language", default="en", help="Target language code (default: en)"
|
||||
)
|
||||
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
|
||||
parser.add_argument(
|
||||
"--enable-diarization",
|
||||
"-d",
|
||||
action="store_true",
|
||||
help="Enable speaker diarization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
default="pyannote",
|
||||
choices=["pyannote", "modal"],
|
||||
help="Diarization backend to use (default: pyannote)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if "REDIS_HOST" not in os.environ:
|
||||
os.environ["REDIS_HOST"] = "localhost"
|
||||
|
||||
output_fd = None
|
||||
if args.output:
|
||||
output_fd = open(args.output, "w")
|
||||
|
||||
async def event_callback(event: PipelineEvent):
|
||||
processor = event.processor
|
||||
# ignore some processor
|
||||
if processor in ("AudioChunkerProcessor", "AudioMergeProcessor"):
|
||||
data = event.data
|
||||
|
||||
# Ignore internal processors
|
||||
if processor in (
|
||||
"AudioChunkerProcessor",
|
||||
"AudioMergeProcessor",
|
||||
"AudioFileWriterProcessor",
|
||||
"TopicCollectorProcessor",
|
||||
"BroadcastProcessor",
|
||||
):
|
||||
return
|
||||
logger.info(f"Event: {event}")
|
||||
|
||||
# If diarization is enabled, skip the original topic events from the pipeline
|
||||
# The diarization processor will emit the same topics but with speaker info
|
||||
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
|
||||
return
|
||||
|
||||
# Log all events
|
||||
logger.info(f"Event: {processor} - {type(data).__name__}")
|
||||
|
||||
# Write to output
|
||||
if output_fd:
|
||||
output_fd.write(event.model_dump_json())
|
||||
output_fd.write("\n")
|
||||
output_fd.flush()
|
||||
|
||||
asyncio.run(
|
||||
process_audio_file(
|
||||
args.source,
|
||||
event_callback,
|
||||
only_transcript=args.only_transcript,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
if args.stream:
|
||||
# Use original streaming pipeline
|
||||
asyncio.run(
|
||||
process_audio_file(
|
||||
args.source,
|
||||
event_callback,
|
||||
only_transcript=args.only_transcript,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
enable_diarization=args.enable_diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Use optimized file pipeline (default)
|
||||
asyncio.run(
|
||||
process_file_pipeline(
|
||||
args.source,
|
||||
event_callback,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
enable_diarization=args.enable_diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if output_fd:
|
||||
output_fd.close()
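A hedged invocation example for this tool; the script path is not shown in the diff and stays a placeholder, while the flags come straight from the argparse definition above (add `--stream` to force the original frame-based pipeline):

```
uv run python <path/to/this_tool>.py meeting.mp3 \
    --enable-diarization --diarization-backend modal \
    --source-language en --target-language en \
    --output output.jsonl
```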
|
||||
|
||||
@@ -160,6 +160,7 @@ async def transcripts_search(
|
||||
limit: SearchLimitParam = DEFAULT_SEARCH_LIMIT,
|
||||
offset: SearchOffsetParam = 0,
|
||||
room_id: Optional[str] = None,
|
||||
source_kind: Optional[SourceKind] = None,
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional)
|
||||
] = None,
|
||||
@@ -173,7 +174,12 @@ async def transcripts_search(
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
search_params = SearchParameters(
|
||||
query_text=q, limit=limit, offset=offset, user_id=user_id, room_id=room_id
|
||||
query_text=q,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
source_kind=source_kind,
|
||||
)
|
||||
|
||||
results, total = await search_controller.search_transcripts(search_params)
|
||||
|
||||
@@ -14,7 +14,8 @@ from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.recordings import Recording, recordings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||
from reflector.pipelines.main_live_pipeline import asynctask, task_pipeline_process
|
||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
from reflector.pipelines.main_live_pipeline import asynctask
|
||||
from reflector.settings import settings
|
||||
from reflector.whereby import get_room_sessions
|
||||
|
||||
@@ -140,7 +141,7 @@ async def process_recording(bucket_name: str, object_key: str):
|
||||
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
|
||||
task_pipeline_process.delay(transcript_id=transcript.id)
|
||||
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
||||
|
||||
|
||||
@shared_task
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: ''
|
||||
headers:
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
authorization:
|
||||
- DUMMY_API_KEY
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '0'
|
||||
host:
|
||||
- monadical-sas--reflector-diarizer-web.modal.run
|
||||
user-agent:
|
||||
- python-httpx/0.27.2
|
||||
method: POST
|
||||
uri: https://monadical-sas--reflector-diarizer-web.modal.run/diarize?audio_file_url=https%3A%2F%2Freflector-github-pytest.s3.us-east-1.amazonaws.com%2Ftest_mathieu_hello.mp3&timestamp=0
|
||||
response:
|
||||
body:
|
||||
string: '{"diarization":[{"start":0.823,"end":1.91,"speaker":0},{"start":2.572,"end":6.409,"speaker":0},{"start":6.783,"end":10.62,"speaker":0},{"start":11.231,"end":14.168,"speaker":0},{"start":14.796,"end":19.295,"speaker":0}]}'
|
||||
headers:
|
||||
Alt-Svc:
|
||||
- h3=":443"; ma=2592000
|
||||
Content-Length:
|
||||
- '220'
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Wed, 13 Aug 2025 18:25:34 GMT
|
||||
Modal-Function-Call-Id:
|
||||
- fc-01K2JAVNEP6N7Y1Y7W3T98BCXK
|
||||
Vary:
|
||||
- accept-encoding
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,46 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"audio_file_url": "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3",
|
||||
"language": "en", "batch": true}'
|
||||
headers:
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
authorization:
|
||||
- DUMMY_API_KEY
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '136'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- monadical-sas--reflector-transcriber-parakeet-web.modal.run
|
||||
user-agent:
|
||||
- python-httpx/0.27.2
|
||||
method: POST
|
||||
uri: https://monadical-sas--reflector-transcriber-parakeet-web.modal.run/v1/audio/transcriptions-from-url
|
||||
response:
|
||||
body:
|
||||
string: '{"text":"Hi there everyone. Today I want to share my incredible experience
|
||||
with Reflector. a Q teenage product that revolutionizes audio processing.
|
||||
With reflector, I can easily convert any audio into accurate transcription.
|
||||
saving me hours of tedious manual work.","words":[{"word":"Hi","start":0.87,"end":1.19},{"word":"there","start":1.19,"end":1.35},{"word":"everyone.","start":1.51,"end":1.83},{"word":"Today","start":2.63,"end":2.87},{"word":"I","start":3.36,"end":3.52},{"word":"want","start":3.6,"end":3.76},{"word":"to","start":3.76,"end":3.92},{"word":"share","start":3.92,"end":4.16},{"word":"my","start":4.16,"end":4.4},{"word":"incredible","start":4.32,"end":4.96},{"word":"experience","start":4.96,"end":5.44},{"word":"with","start":5.44,"end":5.68},{"word":"Reflector.","start":5.68,"end":6.24},{"word":"a","start":6.93,"end":7.01},{"word":"Q","start":7.01,"end":7.17},{"word":"teenage","start":7.25,"end":7.65},{"word":"product","start":7.89,"end":8.29},{"word":"that","start":8.29,"end":8.61},{"word":"revolutionizes","start":8.61,"end":9.65},{"word":"audio","start":9.65,"end":10.05},{"word":"processing.","start":10.05,"end":10.53},{"word":"With","start":11.27,"end":11.43},{"word":"reflector,","start":11.51,"end":12.15},{"word":"I","start":12.31,"end":12.39},{"word":"can","start":12.39,"end":12.55},{"word":"easily","start":12.55,"end":12.95},{"word":"convert","start":12.95,"end":13.43},{"word":"any","start":13.43,"end":13.67},{"word":"audio","start":13.67,"end":13.99},{"word":"into","start":14.98,"end":15.06},{"word":"accurate","start":15.22,"end":15.54},{"word":"transcription.","start":15.7,"end":16.34},{"word":"saving","start":16.99,"end":17.15},{"word":"me","start":17.31,"end":17.47},{"word":"hours","start":17.47,"end":17.87},{"word":"of","start":17.87,"end":18.11},{"word":"tedious","start":18.11,"end":18.67},{"word":"manual","start":18.67,"end":19.07},{"word":"work.","start":19.07,"end":19.31}]}'
|
||||
headers:
|
||||
Alt-Svc:
|
||||
- h3=":443"; ma=2592000
|
||||
Content-Length:
|
||||
- '1933'
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Wed, 13 Aug 2025 18:26:59 GMT
|
||||
Modal-Function-Call-Id:
|
||||
- fc-01K2JAWC7GAMKX4DSJ21WV31NG
|
||||
Vary:
|
||||
- accept-encoding
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,84 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"audio_file_url": "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3",
|
||||
"language": "en", "batch": true}'
|
||||
headers:
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
authorization:
|
||||
- DUMMY_API_KEY
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '136'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- monadical-sas--reflector-transcriber-parakeet-web.modal.run
|
||||
user-agent:
|
||||
- python-httpx/0.27.2
|
||||
method: POST
|
||||
uri: https://monadical-sas--reflector-transcriber-parakeet-web.modal.run/v1/audio/transcriptions-from-url
|
||||
response:
|
||||
body:
|
||||
string: '{"text":"Hi there everyone. Today I want to share my incredible experience
|
||||
with Reflector. a Q teenage product that revolutionizes audio processing.
|
||||
With reflector, I can easily convert any audio into accurate transcription.
|
||||
saving me hours of tedious manual work.","words":[{"word":"Hi","start":0.87,"end":1.19},{"word":"there","start":1.19,"end":1.35},{"word":"everyone.","start":1.51,"end":1.83},{"word":"Today","start":2.63,"end":2.87},{"word":"I","start":3.36,"end":3.52},{"word":"want","start":3.6,"end":3.76},{"word":"to","start":3.76,"end":3.92},{"word":"share","start":3.92,"end":4.16},{"word":"my","start":4.16,"end":4.4},{"word":"incredible","start":4.32,"end":4.96},{"word":"experience","start":4.96,"end":5.44},{"word":"with","start":5.44,"end":5.68},{"word":"Reflector.","start":5.68,"end":6.24},{"word":"a","start":6.93,"end":7.01},{"word":"Q","start":7.01,"end":7.17},{"word":"teenage","start":7.25,"end":7.65},{"word":"product","start":7.89,"end":8.29},{"word":"that","start":8.29,"end":8.61},{"word":"revolutionizes","start":8.61,"end":9.65},{"word":"audio","start":9.65,"end":10.05},{"word":"processing.","start":10.05,"end":10.53},{"word":"With","start":11.27,"end":11.43},{"word":"reflector,","start":11.51,"end":12.15},{"word":"I","start":12.31,"end":12.39},{"word":"can","start":12.39,"end":12.55},{"word":"easily","start":12.55,"end":12.95},{"word":"convert","start":12.95,"end":13.43},{"word":"any","start":13.43,"end":13.67},{"word":"audio","start":13.67,"end":13.99},{"word":"into","start":14.98,"end":15.06},{"word":"accurate","start":15.22,"end":15.54},{"word":"transcription.","start":15.7,"end":16.34},{"word":"saving","start":16.99,"end":17.15},{"word":"me","start":17.31,"end":17.47},{"word":"hours","start":17.47,"end":17.87},{"word":"of","start":17.87,"end":18.11},{"word":"tedious","start":18.11,"end":18.67},{"word":"manual","start":18.67,"end":19.07},{"word":"work.","start":19.07,"end":19.31}]}'
|
||||
headers:
|
||||
Alt-Svc:
|
||||
- h3=":443"; ma=2592000
|
||||
Content-Length:
|
||||
- '1933'
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Wed, 13 Aug 2025 18:27:02 GMT
|
||||
Modal-Function-Call-Id:
|
||||
- fc-01K2JAYZ1AR2HE422VJVKBWX9Z
|
||||
Vary:
|
||||
- accept-encoding
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: ''
|
||||
headers:
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
authorization:
|
||||
- DUMMY_API_KEY
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '0'
|
||||
host:
|
||||
- monadical-sas--reflector-diarizer-web.modal.run
|
||||
user-agent:
|
||||
- python-httpx/0.27.2
|
||||
method: POST
|
||||
uri: https://monadical-sas--reflector-diarizer-web.modal.run/diarize?audio_file_url=https%3A%2F%2Freflector-github-pytest.s3.us-east-1.amazonaws.com%2Ftest_mathieu_hello.mp3&timestamp=0
|
||||
response:
|
||||
body:
|
||||
string: '{"diarization":[{"start":0.823,"end":1.91,"speaker":0},{"start":2.572,"end":6.409,"speaker":0},{"start":6.783,"end":10.62,"speaker":0},{"start":11.231,"end":14.168,"speaker":0},{"start":14.796,"end":19.295,"speaker":0}]}'
|
||||
headers:
|
||||
Alt-Svc:
|
||||
- h3=":443"; ma=2592000
|
||||
Content-Length:
|
||||
- '220'
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Wed, 13 Aug 2025 18:27:18 GMT
|
||||
Modal-Function-Call-Id:
|
||||
- fc-01K2JAZ1M34NQRJK03CCFK95D6
|
||||
Vary:
|
||||
- accept-encoding
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -5,7 +5,29 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
|
||||
# Pytest-docker configuration
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def settings_configuration():
|
||||
# these settings are linked to Monadical for pytest-recording
# if a fork is done, it has to provide its own URLs when cassettes need to be updated
# the Modal API key has to be defined by the user
|
||||
from reflector.settings import settings
|
||||
|
||||
settings.TRANSCRIPT_BACKEND = "modal"
|
||||
settings.TRANSCRIPT_URL = (
|
||||
"https://monadical-sas--reflector-transcriber-parakeet-web.modal.run"
|
||||
)
|
||||
settings.DIARIZATION_BACKEND = "modal"
|
||||
settings.DIARIZATION_URL = "https://monadical-sas--reflector-diarizer-web.modal.run"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config():
|
||||
"""VCR configuration to filter sensitive headers"""
|
||||
return {
|
||||
"filter_headers": [("authorization", "DUMMY_API_KEY")],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def docker_compose_file(pytestconfig):
|
||||
return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
version: '3.8'
|
||||
version: "3.8"
|
||||
services:
|
||||
postgres_test:
|
||||
image: postgres:15
|
||||
image: postgres:17
|
||||
environment:
|
||||
POSTGRES_DB: reflector_test
|
||||
POSTGRES_USER: test_user
|
||||
@@ -10,4 +10,4 @@ services:
|
||||
- "15432:5432"
|
||||
command: postgres -c fsync=off -c synchronous_commit=off -c full_page_writes=off
|
||||
tmpfs:
|
||||
- /var/lib/postgresql/data:rw,noexec,nosuid,size=1g
|
||||
- /var/lib/postgresql/data:rw,noexec,nosuid,size=1g
|
||||
|
||||
server/tests/test_gpu_modal_transcript.py (new file, 330 lines)
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
Tests for GPU Modal transcription endpoints.
|
||||
|
||||
These tests are marked with the "gpu_modal" marker and will not run by default.
Run them with: pytest -m gpu_modal tests/test_gpu_modal_transcript.py
|
||||
|
||||
Required environment variables:
|
||||
- TRANSCRIPT_URL: URL to the Modal.com endpoint (required)
|
||||
- TRANSCRIPT_MODAL_API_KEY: API key for authentication (optional)
|
||||
- TRANSCRIPT_MODEL: Model name to use (optional, defaults to nvidia/parakeet-tdt-0.6b-v2)
|
||||
|
||||
Example with pytest (override default addopts to run ONLY gpu_modal tests):
|
||||
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web-dev.modal.run \
|
||||
TRANSCRIPT_MODAL_API_KEY=your-api-key \
|
||||
uv run -m pytest -m gpu_modal --no-cov tests/test_gpu_modal_transcript.py
|
||||
|
||||
# Or with completely clean options:
|
||||
uv run -m pytest -m gpu_modal -o addopts="" tests/
|
||||
|
||||
Running Modal locally for testing:
|
||||
modal serve gpu/modal_deployments/reflector_transcriber_parakeet.py
|
||||
# This will give you a local URL like https://xxxxx--reflector-transcriber-parakeet-web-dev.modal.run to test against
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# Test audio file URL for testing
|
||||
TEST_AUDIO_URL = (
|
||||
"https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
|
||||
)
|
||||
|
||||
|
||||
def get_modal_transcript_url():
|
||||
"""Get and validate the Modal transcript URL from environment."""
|
||||
url = os.environ.get("TRANSCRIPT_URL")
|
||||
if not url:
|
||||
pytest.skip(
|
||||
"TRANSCRIPT_URL environment variable is required for GPU Modal tests"
|
||||
)
|
||||
return url
|
||||
|
||||
|
||||
def get_auth_headers():
|
||||
"""Get authentication headers if API key is available."""
|
||||
api_key = os.environ.get("TRANSCRIPT_MODAL_API_KEY")
|
||||
if api_key:
|
||||
return {"Authorization": f"Bearer {api_key}"}
|
||||
return {}
|
||||
|
||||
|
||||
def get_model_name():
|
||||
"""Get the model name from environment or use default."""
|
||||
return os.environ.get("TRANSCRIPT_MODEL", "nvidia/parakeet-tdt-0.6b-v2")
|
||||
|
||||
|
||||
@pytest.mark.gpu_modal
|
||||
class TestGPUModalTranscript:
|
||||
"""Test suite for GPU Modal transcription endpoints."""
|
||||
|
||||
def test_transcriptions_from_url(self):
|
||||
"""Test the /v1/audio/transcriptions-from-url endpoint."""
|
||||
url = get_modal_transcript_url()
|
||||
headers = get_auth_headers()
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{url}/v1/audio/transcriptions-from-url",
|
||||
json={
|
||||
"audio_file_url": TEST_AUDIO_URL,
|
||||
"model": get_model_name(),
|
||||
"language": "en",
|
||||
"timestamp_offset": 0.0,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||
result = response.json()
|
||||
|
||||
# Verify response structure
|
||||
assert "text" in result
|
||||
assert "words" in result
|
||||
assert isinstance(result["text"], str)
|
||||
assert isinstance(result["words"], list)
|
||||
|
||||
# Verify content is meaningful
|
||||
assert len(result["text"]) > 0, "Transcript text should not be empty"
|
||||
assert len(result["words"]) > 0, "Words list must not be empty"
|
||||
|
||||
# Verify word structure
|
||||
for word in result["words"]:
|
||||
assert "word" in word
|
||||
assert "start" in word
|
||||
assert "end" in word
|
||||
assert isinstance(word["start"], (int, float))
|
||||
assert isinstance(word["end"], (int, float))
|
||||
assert word["start"] <= word["end"]
|
||||
|
||||
def test_transcriptions_single_file(self):
|
||||
"""Test the /v1/audio/transcriptions endpoint with a single file."""
|
||||
url = get_modal_transcript_url()
|
||||
headers = get_auth_headers()
|
||||
|
||||
# Download test audio file to upload
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
audio_response = client.get(TEST_AUDIO_URL)
|
||||
audio_response.raise_for_status()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
|
||||
tmp_file.write(audio_response.content)
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
try:
|
||||
# Upload the file for transcription
|
||||
with open(tmp_file_path, "rb") as f:
|
||||
files = {"file": ("test_audio.mp3", f, "audio/mpeg")}
|
||||
data = {
|
||||
"model": get_model_name(),
|
||||
"language": "en",
|
||||
"batch": "false",
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
f"{url}/v1/audio/transcriptions",
|
||||
files=files,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||
result = response.json()
|
||||
|
||||
# Verify response structure for single file
|
||||
assert "text" in result
|
||||
assert "words" in result
|
||||
assert "filename" in result
|
||||
assert isinstance(result["text"], str)
|
||||
assert isinstance(result["words"], list)
|
||||
|
||||
# Verify content
|
||||
assert len(result["text"]) > 0, "Transcript text should not be empty"
|
||||
|
||||
finally:
|
||||
Path(tmp_file_path).unlink(missing_ok=True)
|
||||
|
||||
def test_transcriptions_multiple_files(self):
|
||||
"""Test the /v1/audio/transcriptions endpoint with multiple files (non-batch mode)."""
|
||||
url = get_modal_transcript_url()
|
||||
headers = get_auth_headers()
|
||||
|
||||
# Create multiple test files (we'll use the same audio content for simplicity)
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
audio_response = client.get(TEST_AUDIO_URL)
|
||||
audio_response.raise_for_status()
|
||||
audio_content = audio_response.content
|
||||
|
||||
temp_files = []
|
||||
try:
|
||||
# Create 3 temporary files
|
||||
for i in range(3):
|
||||
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
|
||||
tmp_file.write(audio_content)
|
||||
tmp_file.close()
|
||||
temp_files.append(tmp_file.name)
|
||||
|
||||
# Upload multiple files for transcription (non-batch)
|
||||
files = [
|
||||
("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
|
||||
for i, f in enumerate(temp_files)
|
||||
]
|
||||
data = {
|
||||
"model": get_model_name(),
|
||||
"language": "en",
|
||||
"batch": "false",
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
f"{url}/v1/audio/transcriptions",
|
||||
files=files,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Close file handles
|
||||
for _, file_tuple in files:
|
||||
file_tuple[1].close()
|
||||
|
||||
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||
result = response.json()
|
||||
|
||||
# Verify response structure for multiple files (non-batch)
|
||||
assert "results" in result
|
||||
assert isinstance(result["results"], list)
|
||||
assert len(result["results"]) == 3
|
||||
|
||||
for idx, file_result in enumerate(result["results"]):
|
||||
assert "text" in file_result
|
||||
assert "words" in file_result
|
||||
assert "filename" in file_result
|
||||
assert isinstance(file_result["text"], str)
|
||||
assert isinstance(file_result["words"], list)
|
||||
assert len(file_result["text"]) > 0
|
||||
|
||||
finally:
|
||||
for f in temp_files:
|
||||
Path(f).unlink(missing_ok=True)
|
||||
|
||||
def test_transcriptions_multiple_files_batch(self):
|
||||
"""Test the /v1/audio/transcriptions endpoint with multiple files in batch mode."""
|
||||
url = get_modal_transcript_url()
|
||||
headers = get_auth_headers()
|
||||
|
||||
# Create multiple test files
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
audio_response = client.get(TEST_AUDIO_URL)
|
||||
audio_response.raise_for_status()
|
||||
audio_content = audio_response.content
|
||||
|
||||
temp_files = []
|
||||
try:
|
||||
# Create 3 temporary files
|
||||
for i in range(3):
|
||||
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
|
||||
tmp_file.write(audio_content)
|
||||
tmp_file.close()
|
||||
temp_files.append(tmp_file.name)
|
||||
|
||||
# Upload multiple files for batch transcription
|
||||
files = [
|
||||
("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
|
||||
for i, f in enumerate(temp_files)
|
||||
]
|
||||
data = {
|
||||
"model": get_model_name(),
|
||||
"language": "en",
|
||||
"batch": "true",
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
f"{url}/v1/audio/transcriptions",
|
||||
files=files,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Close file handles
|
||||
for _, file_tuple in files:
|
||||
file_tuple[1].close()
|
||||
|
||||
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||
result = response.json()
|
||||
|
||||
# Verify response structure for batch mode
|
||||
assert "results" in result
|
||||
assert isinstance(result["results"], list)
|
||||
assert len(result["results"]) == 3
|
||||
|
||||
for idx, batch_result in enumerate(result["results"]):
|
||||
assert "text" in batch_result
|
||||
assert "words" in batch_result
|
||||
assert "filename" in batch_result
|
||||
assert isinstance(batch_result["text"], str)
|
||||
assert isinstance(batch_result["words"], list)
|
||||
assert len(batch_result["text"]) > 0
|
||||
|
||||
finally:
|
||||
for f in temp_files:
|
||||
Path(f).unlink(missing_ok=True)
|
||||
|
||||
def test_transcriptions_error_handling(self):
|
||||
"""Test error handling for invalid requests."""
|
||||
url = get_modal_transcript_url()
|
||||
headers = get_auth_headers()
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
# Test with unsupported language
|
||||
response = client.post(
|
||||
f"{url}/v1/audio/transcriptions-from-url",
|
||||
json={
|
||||
"audio_file_url": TEST_AUDIO_URL,
|
||||
"model": get_model_name(),
|
||||
"language": "fr", # Parakeet only supports English
|
||||
"timestamp_offset": 0.0,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "only supports English" in response.text
|
||||
|
||||
def test_transcriptions_with_timestamp_offset(self):
|
||||
"""Test transcription with timestamp offset parameter."""
|
||||
url = get_modal_transcript_url()
|
||||
headers = get_auth_headers()
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
# Test with timestamp offset
|
||||
response = client.post(
|
||||
f"{url}/v1/audio/transcriptions-from-url",
|
||||
json={
|
||||
"audio_file_url": TEST_AUDIO_URL,
|
||||
"model": get_model_name(),
|
||||
"language": "en",
|
||||
"timestamp_offset": 10.0, # Add 10 second offset
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||
result = response.json()
|
||||
|
||||
# Verify response structure
|
||||
assert "text" in result
|
||||
assert "words" in result
|
||||
assert len(result["words"]) > 0, "Words list must not be empty"
|
||||
|
||||
# Verify that timestamps have been offset
|
||||
for word in result["words"]:
|
||||
# All timestamps should be >= 10.0 due to offset
|
||||
assert (
|
||||
word["start"] >= 10.0
|
||||
), f"Word start time {word['start']} should be >= 10.0"
|
||||
assert (
|
||||
word["end"] >= 10.0
|
||||
), f"Word end time {word['end']} should be >= 10.0"
|
||||
server/tests/test_pipeline_main_file.py (new file, 633 lines)
@@ -0,0 +1,633 @@
|
||||
"""
|
||||
Tests for PipelineMainFile - file-based processing pipeline
|
||||
|
||||
This test verifies the complete file processing pipeline without mocking much,
|
||||
ensuring all processors are correctly invoked and the happy path works correctly.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||
from reflector.processors.file_diarization import FileDiarizationOutput
|
||||
from reflector.processors.types import (
|
||||
DiarizationSegment,
|
||||
TitleSummary,
|
||||
Word,
|
||||
)
|
||||
from reflector.processors.types import (
|
||||
Transcript as TranscriptType,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_file_transcript():
|
||||
"""Mock FileTranscriptAutoProcessor for file processing"""
|
||||
from reflector.processors.file_transcript import FileTranscriptProcessor
|
||||
|
||||
class TestFileTranscriptProcessor(FileTranscriptProcessor):
|
||||
async def _transcript(self, data):
|
||||
return TranscriptType(
|
||||
text="Hello world. How are you today?",
|
||||
words=[
|
||||
Word(start=0.0, end=0.5, text="Hello", speaker=0),
|
||||
Word(start=0.5, end=0.6, text=" ", speaker=0),
|
||||
Word(start=0.6, end=1.0, text="world", speaker=0),
|
||||
Word(start=1.0, end=1.1, text=".", speaker=0),
|
||||
Word(start=1.1, end=1.2, text=" ", speaker=0),
|
||||
Word(start=1.2, end=1.5, text="How", speaker=0),
|
||||
Word(start=1.5, end=1.6, text=" ", speaker=0),
|
||||
Word(start=1.6, end=1.8, text="are", speaker=0),
|
||||
Word(start=1.8, end=1.9, text=" ", speaker=0),
|
||||
Word(start=1.9, end=2.1, text="you", speaker=0),
|
||||
Word(start=2.1, end=2.2, text=" ", speaker=0),
|
||||
Word(start=2.2, end=2.5, text="today", speaker=0),
|
||||
Word(start=2.5, end=2.6, text="?", speaker=0),
|
||||
],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__"
|
||||
) as mock_auto:
|
||||
mock_auto.return_value = TestFileTranscriptProcessor()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_file_diarization():
|
||||
"""Mock FileDiarizationAutoProcessor for file processing"""
|
||||
from reflector.processors.file_diarization import FileDiarizationProcessor
|
||||
|
||||
class TestFileDiarizationProcessor(FileDiarizationProcessor):
|
||||
async def _diarize(self, data):
|
||||
return FileDiarizationOutput(
|
||||
diarization=[
|
||||
DiarizationSegment(start=0.0, end=1.1, speaker=0),
|
||||
DiarizationSegment(start=1.2, end=2.6, speaker=1),
|
||||
]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__"
|
||||
) as mock_auto:
|
||||
mock_auto.return_value = TestFileDiarizationProcessor()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_transcript_in_db(tmpdir):
|
||||
"""Create a mock transcript in the database"""
|
||||
from reflector.db.transcripts import Transcript
|
||||
from reflector.settings import settings
|
||||
|
||||
# Set the DATA_DIR to our tmpdir
|
||||
original_data_dir = settings.DATA_DIR
|
||||
settings.DATA_DIR = str(tmpdir)
|
||||
|
||||
transcript_id = str(uuid4())
|
||||
data_path = Path(tmpdir) / transcript_id
|
||||
data_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create mock transcript object
|
||||
transcript = Transcript(
|
||||
id=transcript_id,
|
||||
name="Test Transcript",
|
||||
status="processing",
|
||||
source_kind="file",
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
)
|
||||
|
||||
# Mock the controller to return our transcript
|
||||
try:
|
||||
with patch(
|
||||
"reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
|
||||
) as mock_get:
|
||||
mock_get.return_value = transcript
|
||||
with patch(
|
||||
"reflector.pipelines.main_live_pipeline.transcripts_controller.get_by_id"
|
||||
) as mock_get2:
|
||||
mock_get2.return_value = transcript
|
||||
with patch(
|
||||
"reflector.pipelines.main_live_pipeline.transcripts_controller.update"
|
||||
) as mock_update:
|
||||
mock_update.return_value = None
|
||||
yield transcript
|
||||
finally:
|
||||
# Restore original DATA_DIR
|
||||
settings.DATA_DIR = original_data_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_storage():
|
||||
"""Mock storage for file uploads"""
|
||||
from reflector.storage.base import Storage
|
||||
|
||||
class TestStorage(Storage):
|
||||
async def _put_file(self, path, data):
|
||||
return None
|
||||
|
||||
async def _get_file_url(self, path):
|
||||
return f"http://test-storage/{path}"
|
||||
|
||||
async def _get_file(self, path):
|
||||
return b"test_audio_data"
|
||||
|
||||
async def _delete_file(self, path):
|
||||
return None
|
||||
|
||||
storage = TestStorage()
|
||||
# Add mock tracking for verification
|
||||
storage._put_file = AsyncMock(side_effect=storage._put_file)
|
||||
storage._get_file_url = AsyncMock(side_effect=storage._get_file_url)
|
||||
|
||||
with patch(
|
||||
"reflector.pipelines.main_file_pipeline.get_transcripts_storage"
|
||||
) as mock_get:
|
||||
mock_get.return_value = storage
|
||||
yield storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_audio_file_writer():
|
||||
"""Mock AudioFileWriterProcessor to avoid actual file writing"""
|
||||
with patch(
|
||||
"reflector.pipelines.main_file_pipeline.AudioFileWriterProcessor"
|
||||
) as mock_writer_class:
|
||||
mock_writer = AsyncMock()
|
||||
mock_writer.push = AsyncMock()
|
||||
mock_writer.flush = AsyncMock()
|
||||
mock_writer_class.return_value = mock_writer
|
||||
yield mock_writer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_waveform_processor():
|
||||
"""Mock AudioWaveformProcessor"""
|
||||
with patch(
|
||||
"reflector.pipelines.main_file_pipeline.AudioWaveformProcessor"
|
||||
) as mock_waveform_class:
|
||||
mock_waveform = AsyncMock()
|
||||
mock_waveform.set_pipeline = MagicMock()
|
||||
mock_waveform.flush = AsyncMock()
|
||||
mock_waveform_class.return_value = mock_waveform
|
||||
yield mock_waveform
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_topic_detector():
|
||||
"""Mock TranscriptTopicDetectorProcessor"""
|
||||
with patch(
|
||||
"reflector.pipelines.main_file_pipeline.TranscriptTopicDetectorProcessor"
|
||||
) as mock_topic_class:
|
||||
mock_topic = AsyncMock()
|
||||
mock_topic.set_pipeline = MagicMock()
|
||||
mock_topic.push = AsyncMock()
|
||||
mock_topic.flush_called = False
|
||||
|
||||
# When flush is called, simulate topic detection by calling the callback
|
||||
async def flush_with_callback():
|
||||
mock_topic.flush_called = True
|
||||
if hasattr(mock_topic, "_callback"):
|
||||
# Create a minimal transcript for the TitleSummary
|
||||
test_transcript = TranscriptType(words=[], text="test transcript")
|
||||
await mock_topic._callback(
|
||||
TitleSummary(
|
||||
title="Test Topic",
|
||||
summary="Test topic summary",
|
||||
timestamp=0.0,
|
||||
duration=10.0,
|
||||
transcript=test_transcript,
|
||||
)
|
||||
)
|
||||
|
||||
mock_topic.flush = flush_with_callback
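        # The mocked processor is constructed with a result callback;
        # init_with_callback below captures it so flush_with_callback can emit a
        # TitleSummary through it.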
|
||||
|
||||
def init_with_callback(callback=None):
|
||||
mock_topic._callback = callback
|
||||
return mock_topic
|
||||
|
||||
mock_topic_class.side_effect = init_with_callback
|
||||
yield mock_topic
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_title_processor():
|
||||
"""Mock TranscriptFinalTitleProcessor"""
|
||||
with patch(
|
||||
"reflector.pipelines.main_file_pipeline.TranscriptFinalTitleProcessor"
|
||||
) as mock_title_class:
|
||||
mock_title = AsyncMock()
|
||||
mock_title.set_pipeline = MagicMock()
|
||||
mock_title.push = AsyncMock()
|
||||
mock_title.flush_called = False
|
||||
|
||||
# When flush is called, simulate title generation by calling the callback
|
||||
async def flush_with_callback():
|
||||
mock_title.flush_called = True
|
||||
if hasattr(mock_title, "_callback"):
|
||||
from reflector.processors.types import FinalTitle
|
||||
|
||||
await mock_title._callback(FinalTitle(title="Test Title"))
|
||||
|
||||
mock_title.flush = flush_with_callback
|
||||
|
||||
def init_with_callback(callback=None):
|
||||
mock_title._callback = callback
|
||||
return mock_title
|
||||
|
||||
mock_title_class.side_effect = init_with_callback
|
||||
yield mock_title
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_summary_processor():
|
||||
"""Mock TranscriptFinalSummaryProcessor"""
|
||||
with patch(
|
||||
"reflector.pipelines.main_file_pipeline.TranscriptFinalSummaryProcessor"
|
||||
) as mock_summary_class:
|
||||
mock_summary = AsyncMock()
|
||||
mock_summary.set_pipeline = MagicMock()
|
||||
mock_summary.push = AsyncMock()
|
||||
mock_summary.flush_called = False
|
||||
|
||||
# When flush is called, simulate summary generation by calling the callbacks
|
||||
async def flush_with_callback():
|
||||
mock_summary.flush_called = True
|
||||
from reflector.processors.types import FinalLongSummary, FinalShortSummary
|
||||
|
||||
if hasattr(mock_summary, "_callback"):
|
||||
await mock_summary._callback(
|
||||
FinalLongSummary(long_summary="Test long summary", duration=10.0)
|
||||
)
|
||||
if hasattr(mock_summary, "_on_short_summary"):
|
||||
await mock_summary._on_short_summary(
|
||||
FinalShortSummary(short_summary="Test short summary", duration=10.0)
|
||||
)
|
||||
|
||||
mock_summary.flush = flush_with_callback
|
||||
|
||||
def init_with_callback(transcript=None, callback=None, on_short_summary=None):
|
||||
mock_summary._callback = callback
|
||||
mock_summary._on_short_summary = on_short_summary
|
||||
return mock_summary
|
||||
|
||||
mock_summary_class.side_effect = init_with_callback
|
||||
yield mock_summary
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_main_file_process(
|
||||
tmpdir,
|
||||
mock_transcript_in_db,
|
||||
dummy_file_transcript,
|
||||
dummy_file_diarization,
|
||||
mock_storage,
|
||||
mock_audio_file_writer,
|
||||
mock_waveform_processor,
|
||||
mock_topic_detector,
|
||||
mock_title_processor,
|
||||
mock_summary_processor,
|
||||
):
|
||||
"""
|
||||
Test the complete PipelineMainFile processing pipeline.
|
||||
|
||||
This test verifies:
|
||||
1. Audio extraction and writing
|
||||
2. Audio upload to storage
|
||||
3. Parallel processing of transcription, diarization, and waveform
|
||||
4. Assembly of transcript with diarization
|
||||
5. Topic detection
|
||||
6. Title and summary generation
|
||||
"""
|
||||
# Create a test audio file
|
||||
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||
|
||||
# Copy test audio to the transcript's data path as if it was uploaded
|
||||
upload_path = mock_transcript_in_db.data_path / "upload.wav"
|
||||
upload_path.write_bytes(test_audio_path.read_bytes())
|
||||
|
||||
# Also create the audio.mp3 file that would be created by AudioFileWriterProcessor
|
||||
# Since we're mocking AudioFileWriterProcessor, we need to create this manually
|
||||
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
|
||||
mp3_path.write_bytes(b"mock_mp3_data")
|
||||
|
||||
# Track callback invocations
|
||||
callback_marks = {
|
||||
"on_status": [],
|
||||
"on_duration": [],
|
||||
"on_waveform": [],
|
||||
"on_topic": [],
|
||||
"on_title": [],
|
||||
"on_long_summary": [],
|
||||
"on_short_summary": [],
|
||||
}
|
||||
|
||||
# Create pipeline with mocked callbacks
|
||||
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
|
||||
|
||||
# Override callbacks to track invocations
|
||||
async def track_callback(name, data):
|
||||
callback_marks[name].append(data)
|
||||
# Call the original callback
|
||||
original = getattr(PipelineMainFile, name)
|
||||
return await original(pipeline, data)
|
||||
|
||||
for callback_name in callback_marks.keys():
|
||||
setattr(
|
||||
pipeline,
|
||||
callback_name,
|
||||
lambda data, n=callback_name: track_callback(n, data),
|
||||
)
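    # The n=callback_name default argument pins the loop variable, so each lambda
    # forwards to its own callback name rather than the last one in the loop.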
|
||||
|
||||
# Mock av.open for audio processing
|
||||
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
|
||||
# Mock container for checking video streams
|
||||
mock_container = MagicMock()
|
||||
mock_container.streams.video = [] # No video streams (audio only)
|
||||
mock_container.close = MagicMock()
|
||||
|
||||
# Mock container for decoding audio frames
|
||||
mock_decode_container = MagicMock()
|
||||
mock_decode_container.decode.return_value = iter(
|
||||
[MagicMock()]
|
||||
) # One mock audio frame
|
||||
mock_decode_container.close = MagicMock()
|
||||
|
||||
# Return different containers for different calls
|
||||
mock_av.side_effect = [mock_container, mock_decode_container]
|
||||
|
||||
# Run the pipeline
|
||||
await pipeline.process(upload_path)
|
||||
|
||||
# Verify audio extraction and writing
|
||||
assert mock_audio_file_writer.push.called
|
||||
assert mock_audio_file_writer.flush.called
|
||||
|
||||
# Verify storage upload
|
||||
assert mock_storage._put_file.called
|
||||
assert mock_storage._get_file_url.called
|
||||
|
||||
# Verify waveform generation
|
||||
assert mock_waveform_processor.flush.called
|
||||
assert mock_waveform_processor.set_pipeline.called
|
||||
|
||||
# Verify topic detection
|
||||
assert mock_topic_detector.push.called
|
||||
assert mock_topic_detector.flush_called
|
||||
|
||||
# Verify title generation
|
||||
assert mock_title_processor.push.called
|
||||
assert mock_title_processor.flush_called
|
||||
|
||||
# Verify summary generation
|
||||
assert mock_summary_processor.push.called
|
||||
assert mock_summary_processor.flush_called
|
||||
|
||||
# Verify callbacks were invoked
|
||||
assert len(callback_marks["on_topic"]) > 0, "Topic callback should be invoked"
|
||||
assert len(callback_marks["on_title"]) > 0, "Title callback should be invoked"
|
||||
assert (
|
||||
len(callback_marks["on_long_summary"]) > 0
|
||||
), "Long summary callback should be invoked"
|
||||
assert (
|
||||
len(callback_marks["on_short_summary"]) > 0
|
||||
), "Short summary callback should be invoked"
|
||||
|
||||
print(f"Callback marks: {callback_marks}")
|
||||
|
||||
# Verify the pipeline completed successfully
|
||||
assert pipeline.logger is not None
|
||||
print("PipelineMainFile test completed successfully!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_main_file_with_video(
|
||||
tmpdir,
|
||||
mock_transcript_in_db,
|
||||
dummy_file_transcript,
|
||||
dummy_file_diarization,
|
||||
mock_storage,
|
||||
mock_audio_file_writer,
|
||||
mock_waveform_processor,
|
||||
mock_topic_detector,
|
||||
mock_title_processor,
|
||||
mock_summary_processor,
|
||||
):
|
||||
"""
|
||||
Test PipelineMainFile with video input (verifies audio extraction).
|
||||
"""
|
||||
# Create a test audio file
|
||||
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||
|
||||
# Copy test audio to the transcript's data path as if it was a video upload
|
||||
upload_path = mock_transcript_in_db.data_path / "upload.mp4"
|
||||
upload_path.write_bytes(test_audio_path.read_bytes())
|
||||
|
||||
# Also create the audio.mp3 file that would be created by AudioFileWriterProcessor
|
||||
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
|
||||
mp3_path.write_bytes(b"mock_mp3_data")
|
||||
|
||||
# Create pipeline
|
||||
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
|
||||
|
||||
# Mock av.open for video processing
|
||||
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
|
||||
# Mock container for checking video streams
|
||||
mock_container = MagicMock()
|
||||
mock_container.streams.video = [MagicMock()] # Has video streams
|
||||
mock_container.close = MagicMock()
|
||||
|
||||
# Mock container for decoding audio frames
|
||||
mock_decode_container = MagicMock()
|
||||
mock_decode_container.decode.return_value = iter(
|
||||
[MagicMock()]
|
||||
) # One mock audio frame
|
||||
mock_decode_container.close = MagicMock()
|
||||
|
||||
# Return different containers for different calls
|
||||
mock_av.side_effect = [mock_container, mock_decode_container]
|
||||
|
||||
# Run the pipeline
|
||||
await pipeline.process(upload_path)
|
||||
|
||||
# Verify audio extraction from video
|
||||
assert mock_audio_file_writer.push.called
|
||||
assert mock_audio_file_writer.flush.called
|
||||
|
||||
# Verify the rest of the pipeline completed
|
||||
assert mock_storage._put_file.called
|
||||
assert mock_waveform_processor.flush.called
|
||||
assert mock_topic_detector.push.called
|
||||
assert mock_title_processor.push.called
|
||||
assert mock_summary_processor.push.called
|
||||
|
||||
print("PipelineMainFile video test completed successfully!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_main_file_no_diarization(
|
||||
tmpdir,
|
||||
mock_transcript_in_db,
|
||||
dummy_file_transcript,
|
||||
mock_storage,
|
||||
mock_audio_file_writer,
|
||||
mock_waveform_processor,
|
||||
mock_topic_detector,
|
||||
mock_title_processor,
|
||||
mock_summary_processor,
|
||||
):
|
||||
"""
|
||||
Test PipelineMainFile with diarization disabled.
|
||||
"""
|
||||
from reflector.settings import settings
|
||||
|
||||
# Disable diarization
|
||||
with patch.object(settings, "DIARIZATION_BACKEND", None):
|
||||
# Create a test audio file
|
||||
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||
|
||||
# Copy test audio to the transcript's data path
|
||||
upload_path = mock_transcript_in_db.data_path / "upload.wav"
|
||||
upload_path.write_bytes(test_audio_path.read_bytes())
|
||||
|
||||
# Also create the audio.mp3 file
|
||||
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
|
||||
mp3_path.write_bytes(b"mock_mp3_data")
|
||||
|
||||
# Create pipeline
|
||||
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
|
||||
|
||||
# Mock av.open for audio processing
|
||||
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
|
||||
# Mock container for checking video streams
|
||||
mock_container = MagicMock()
|
||||
mock_container.streams.video = [] # No video streams
|
||||
mock_container.close = MagicMock()
|
||||
|
||||
# Mock container for decoding audio frames
|
||||
mock_decode_container = MagicMock()
|
||||
mock_decode_container.decode.return_value = iter([MagicMock()])
|
||||
mock_decode_container.close = MagicMock()
|
||||
|
||||
# Return different containers for different calls
|
||||
mock_av.side_effect = [mock_container, mock_decode_container]
|
||||
|
||||
# Run the pipeline
|
||||
await pipeline.process(upload_path)
|
||||
|
||||
# Verify the pipeline completed without diarization
|
||||
assert mock_storage._put_file.called
|
||||
assert mock_waveform_processor.flush.called
|
||||
assert mock_topic_detector.push.called
|
||||
assert mock_title_processor.push.called
|
||||
assert mock_summary_processor.push.called
|
||||
|
||||
print("PipelineMainFile no-diarization test completed successfully!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_pipeline_file_process(
|
||||
tmpdir,
|
||||
mock_transcript_in_db,
|
||||
dummy_file_transcript,
|
||||
dummy_file_diarization,
|
||||
mock_storage,
|
||||
mock_audio_file_writer,
|
||||
mock_waveform_processor,
|
||||
mock_topic_detector,
|
||||
mock_title_processor,
|
||||
mock_summary_processor,
|
||||
):
|
||||
"""
|
||||
Test the Celery task entry point for file pipeline processing.
|
||||
"""
|
||||
    # The Celery task is only a thin wrapper around the async pipeline, which is exercised directly below
|
||||
|
||||
# Create a test audio file in the transcript's data path
|
||||
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||
upload_path = mock_transcript_in_db.data_path / "upload.wav"
|
||||
upload_path.write_bytes(test_audio_path.read_bytes())
|
||||
|
||||
# Also create the audio.mp3 file
|
||||
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
|
||||
mp3_path.write_bytes(b"mock_mp3_data")
|
||||
|
||||
# Mock av.open for audio processing
|
||||
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
|
||||
# Mock container for checking video streams
|
||||
mock_container = MagicMock()
|
||||
mock_container.streams.video = [] # No video streams
|
||||
mock_container.close = MagicMock()
|
||||
|
||||
# Mock container for decoding audio frames
|
||||
mock_decode_container = MagicMock()
|
||||
mock_decode_container.decode.return_value = iter([MagicMock()])
|
||||
mock_decode_container.close = MagicMock()
|
||||
|
||||
# Return different containers for different calls
|
||||
mock_av.side_effect = [mock_container, mock_decode_container]
|
||||
|
||||
# Get the original async function without the asynctask decorator
|
||||
# The function is wrapped, so we need to call it differently
|
||||
# For now, we test the pipeline directly since the task is just a thin wrapper
|
||||
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||
|
||||
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
|
||||
await pipeline.process(upload_path)
|
||||
|
||||
# Verify the pipeline was executed through the task
|
||||
assert mock_audio_file_writer.push.called
|
||||
assert mock_audio_file_writer.flush.called
|
||||
assert mock_storage._put_file.called
|
||||
assert mock_waveform_processor.flush.called
|
||||
assert mock_topic_detector.push.called
|
||||
assert mock_title_processor.push.called
|
||||
assert mock_summary_processor.push.called
|
||||
|
||||
print("task_pipeline_file_process test completed successfully!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_file_process_no_transcript():
|
||||
"""
|
||||
Test the pipeline with a non-existent transcript.
|
||||
"""
|
||||
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||
|
||||
# Mock the controller to return None (transcript not found)
|
||||
with patch(
|
||||
"reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
|
||||
) as mock_get:
|
||||
mock_get.return_value = None
|
||||
|
||||
pipeline = PipelineMainFile(transcript_id=str(uuid4()))
|
||||
|
||||
# Should raise an exception for missing transcript when get_transcript is called
|
||||
with pytest.raises(Exception, match="Transcript not found"):
|
||||
await pipeline.get_transcript()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_file_process_no_audio_file(
|
||||
mock_transcript_in_db,
|
||||
):
|
||||
"""
|
||||
Test the pipeline when no audio file is found.
|
||||
"""
|
||||
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||
|
||||
# Don't create any audio files in the data path
|
||||
    # Processing should fail rather than silently succeed when the file is missing
|
||||
|
||||
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
|
||||
|
||||
# Try to process a non-existent file
|
||||
non_existent_path = mock_transcript_in_db.data_path / "nonexistent.wav"
|
||||
|
||||
# This should fail when trying to open the file with av
|
||||
with pytest.raises(Exception):
|
||||
await pipeline.process(non_existent_path)
|
||||
265
server/tests/test_processors_modal.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Tests for Modal-based processors using pytest-recording for HTTP recording/playbook
|
||||
|
||||
Note: theses tests require full modal configuration to be able to record
|
||||
vcr cassettes
|
||||
|
||||
Configuration required for the first recording:
|
||||
- TRANSCRIPT_BACKEND=modal
|
||||
- TRANSCRIPT_URL=https://xxxxx--reflector-transcriber-parakeet-web.modal.run
|
||||
- TRANSCRIPT_MODAL_API_KEY=xxxxx
|
||||
- DIARIZATION_BACKEND=modal
|
||||
- DIARIZATION_URL=https://xxxxx--reflector-diarizer-web.modal.run
|
||||
- DIARIZATION_MODAL_API_KEY=xxxxx
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.processors.file_diarization import FileDiarizationInput
|
||||
from reflector.processors.file_diarization_modal import FileDiarizationModalProcessor
|
||||
from reflector.processors.file_transcript import FileTranscriptInput
|
||||
from reflector.processors.file_transcript_modal import FileTranscriptModalProcessor
|
||||
from reflector.processors.transcript_diarization_assembler import (
|
||||
TranscriptDiarizationAssemblerInput,
|
||||
TranscriptDiarizationAssemblerProcessor,
|
||||
)
|
||||
from reflector.processors.types import DiarizationSegment, Transcript, Word
|
||||
|
||||
# Public test audio file hosted on S3 specifically for reflector pytests
|
||||
TEST_AUDIO_URL = (
|
||||
"https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_transcript_modal_processor_missing_url():
|
||||
with patch("reflector.processors.file_transcript_modal.settings") as mock_settings:
|
||||
mock_settings.TRANSCRIPT_URL = None
|
||||
with pytest.raises(Exception, match="TRANSCRIPT_URL required"):
|
||||
FileTranscriptModalProcessor(modal_api_key="test-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_diarization_modal_processor_missing_url():
|
||||
with patch("reflector.processors.file_diarization_modal.settings") as mock_settings:
|
||||
mock_settings.DIARIZATION_URL = None
|
||||
with pytest.raises(Exception, match="DIARIZATION_URL required"):
|
||||
FileDiarizationModalProcessor(modal_api_key="test-api-key")
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_diarization_modal_processor(vcr):
|
||||
"""Test FileDiarizationModalProcessor using public audio URL and Modal API"""
|
||||
from reflector.settings import settings
|
||||
|
||||
processor = FileDiarizationModalProcessor(
|
||||
modal_api_key=settings.DIARIZATION_MODAL_API_KEY
|
||||
)
|
||||
|
||||
test_input = FileDiarizationInput(audio_url=TEST_AUDIO_URL)
|
||||
result = await processor._diarize(test_input)
|
||||
|
||||
# Verify the result structure
|
||||
assert result is not None
|
||||
assert hasattr(result, "diarization")
|
||||
assert isinstance(result.diarization, list)
|
||||
|
||||
# Check structure of each diarization segment
|
||||
for segment in result.diarization:
|
||||
assert "start" in segment
|
||||
assert "end" in segment
|
||||
assert "speaker" in segment
|
||||
assert isinstance(segment["start"], (int, float))
|
||||
assert isinstance(segment["end"], (int, float))
|
||||
assert isinstance(segment["speaker"], int)
|
||||
# Basic sanity check - start should be before end
|
||||
assert segment["start"] < segment["end"]
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_transcript_modal_processor():
|
||||
"""Test FileTranscriptModalProcessor using public audio URL and Modal API"""
|
||||
from reflector.settings import settings
|
||||
|
||||
processor = FileTranscriptModalProcessor(
|
||||
modal_api_key=settings.TRANSCRIPT_MODAL_API_KEY
|
||||
)
|
||||
|
||||
test_input = FileTranscriptInput(
|
||||
audio_url=TEST_AUDIO_URL,
|
||||
language="en",
|
||||
)
|
||||
|
||||
# This will record the HTTP interaction on first run, replay on subsequent runs
|
||||
result = await processor._transcript(test_input)
|
||||
|
||||
# Verify the result structure
|
||||
assert result is not None
|
||||
assert hasattr(result, "words")
|
||||
assert isinstance(result.words, list)
|
||||
|
||||
# Check structure of each word if present
|
||||
for word in result.words:
|
||||
assert hasattr(word, "text")
|
||||
assert hasattr(word, "start")
|
||||
assert hasattr(word, "end")
|
||||
assert isinstance(word.start, (int, float))
|
||||
assert isinstance(word.end, (int, float))
|
||||
assert isinstance(word.text, str)
|
||||
# Basic sanity check - start should be before or equal to end
|
||||
assert word.start <= word.end
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_diarization_assembler_processor():
|
||||
"""Test TranscriptDiarizationAssemblerProcessor without VCR (no HTTP requests)"""
|
||||
# Create test transcript with words
|
||||
words = [
|
||||
Word(text="Hello", start=0.0, end=1.0, speaker=0),
|
||||
Word(text=" ", start=1.0, end=1.1, speaker=0),
|
||||
Word(text="world", start=1.1, end=2.0, speaker=0),
|
||||
Word(text=".", start=2.0, end=2.1, speaker=0),
|
||||
Word(text=" ", start=2.1, end=2.2, speaker=0),
|
||||
Word(text="How", start=2.2, end=2.8, speaker=0),
|
||||
Word(text=" ", start=2.8, end=2.9, speaker=0),
|
||||
Word(text="are", start=2.9, end=3.2, speaker=0),
|
||||
Word(text=" ", start=3.2, end=3.3, speaker=0),
|
||||
Word(text="you", start=3.3, end=3.8, speaker=0),
|
||||
Word(text="?", start=3.8, end=3.9, speaker=0),
|
||||
]
|
||||
transcript = Transcript(words=words)
|
||||
|
||||
# Create test diarization segments
|
||||
diarization = [
|
||||
DiarizationSegment(start=0.0, end=2.1, speaker=0),
|
||||
DiarizationSegment(start=2.1, end=3.9, speaker=1),
|
||||
]
|
||||
|
||||
# Create processor and test input
|
||||
processor = TranscriptDiarizationAssemblerProcessor()
|
||||
test_input = TranscriptDiarizationAssemblerInput(
|
||||
transcript=transcript, diarization=diarization
|
||||
)
|
||||
|
||||
# Track emitted results
|
||||
emitted_results = []
|
||||
|
||||
async def capture_result(result):
|
||||
emitted_results.append(result)
|
||||
|
||||
processor.on(capture_result)
|
||||
|
||||
# Process the input
|
||||
await processor.push(test_input)
|
||||
|
||||
# Verify result was emitted
|
||||
assert len(emitted_results) == 1
|
||||
result = emitted_results[0]
|
||||
|
||||
# Verify result structure
|
||||
assert isinstance(result, Transcript)
|
||||
assert len(result.words) == len(words)
|
||||
|
||||
# Verify speaker assignments were applied
|
||||
# Words 0-3 (indices) should be speaker 0 (time 0.0-2.0)
|
||||
# Words 4-10 (indices) should be speaker 1 (time 2.1-3.9)
|
||||
for i in range(4): # First 4 words (Hello, space, world, .)
|
||||
assert (
|
||||
result.words[i].speaker == 0
|
||||
), f"Word {i} '{result.words[i].text}' should be speaker 0, got {result.words[i].speaker}"
|
||||
|
||||
for i in range(4, 11): # Remaining words (space, How, space, are, space, you, ?)
|
||||
assert (
|
||||
result.words[i].speaker == 1
|
||||
), f"Word {i} '{result.words[i].text}' should be speaker 1, got {result.words[i].speaker}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_diarization_assembler_no_diarization():
|
||||
"""Test TranscriptDiarizationAssemblerProcessor with no diarization data"""
|
||||
# Create test transcript
|
||||
words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
|
||||
transcript = Transcript(words=words)
|
||||
|
||||
# Create processor and test input with empty diarization
|
||||
processor = TranscriptDiarizationAssemblerProcessor()
|
||||
test_input = TranscriptDiarizationAssemblerInput(
|
||||
transcript=transcript, diarization=[]
|
||||
)
|
||||
|
||||
# Track emitted results
|
||||
emitted_results = []
|
||||
|
||||
async def capture_result(result):
|
||||
emitted_results.append(result)
|
||||
|
||||
processor.on(capture_result)
|
||||
|
||||
# Process the input
|
||||
await processor.push(test_input)
|
||||
|
||||
# Verify original transcript was returned unchanged
|
||||
assert len(emitted_results) == 1
|
||||
result = emitted_results[0]
|
||||
assert result is transcript # Should be the same object
|
||||
assert result.words[0].speaker == 0 # Original speaker unchanged
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_modal_pipeline_integration(vcr):
|
||||
"""Integration test: Transcription -> Diarization -> Assembly
|
||||
|
||||
This test demonstrates the full pipeline:
|
||||
1. Run transcription via Modal
|
||||
2. Run diarization via Modal
|
||||
3. Assemble transcript with diarization
|
||||
"""
|
||||
from reflector.settings import settings
|
||||
|
||||
# Step 1: Transcription
|
||||
transcript_processor = FileTranscriptModalProcessor(
|
||||
modal_api_key=settings.TRANSCRIPT_MODAL_API_KEY
|
||||
)
|
||||
transcript_input = FileTranscriptInput(audio_url=TEST_AUDIO_URL, language="en")
|
||||
transcript = await transcript_processor._transcript(transcript_input)
|
||||
|
||||
# Step 2: Diarization
|
||||
diarization_processor = FileDiarizationModalProcessor(
|
||||
modal_api_key=settings.DIARIZATION_MODAL_API_KEY
|
||||
)
|
||||
diarization_input = FileDiarizationInput(audio_url=TEST_AUDIO_URL)
|
||||
diarization_result = await diarization_processor._diarize(diarization_input)
|
||||
|
||||
# Step 3: Assembly
|
||||
assembler = TranscriptDiarizationAssemblerProcessor()
|
||||
assembly_input = TranscriptDiarizationAssemblerInput(
|
||||
transcript=transcript, diarization=diarization_result.diarization
|
||||
)
|
||||
|
||||
# Track assembled result
|
||||
assembled_results = []
|
||||
|
||||
async def capture_result(result):
|
||||
assembled_results.append(result)
|
||||
|
||||
assembler.on(capture_result)
|
||||
|
||||
await assembler.push(assembly_input)
|
||||
|
||||
# Verify the full pipeline worked
|
||||
assert len(assembled_results) == 1
|
||||
final_transcript = assembled_results[0]
|
||||
|
||||
# Verify the final transcript has the original words with updated speaker info
|
||||
assert isinstance(final_transcript, Transcript)
|
||||
assert len(final_transcript.words) == len(transcript.words)
|
||||
assert len(final_transcript.words) > 0
|
||||
|
||||
# Verify some words have been assigned speakers from diarization
|
||||
speakers_found = set(word.speaker for word in final_transcript.words)
|
||||
assert len(speakers_found) > 0 # At least some speaker assignments
|
||||
@@ -2,10 +2,13 @@ import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("enable_diarization", [False, True])
|
||||
async def test_basic_process(
|
||||
dummy_transcript,
|
||||
dummy_llm,
|
||||
dummy_processors,
|
||||
enable_diarization,
|
||||
dummy_diarization,
|
||||
):
|
||||
# goal is to start the server, and send rtc audio to it
|
||||
# validate the events received
|
||||
@@ -28,12 +31,31 @@ async def test_basic_process(
|
||||
|
||||
# invoke the process and capture events
|
||||
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||
await process_audio_file(path.as_posix(), event_callback)
|
||||
print(marks)
|
||||
|
||||
if enable_diarization:
|
||||
# Test with diarization - may fail if pyannote.audio is not installed
|
||||
try:
|
||||
await process_audio_file(
|
||||
path.as_posix(), event_callback, enable_diarization=True
|
||||
)
|
||||
except SystemExit:
|
||||
pytest.skip("pyannote.audio not installed - skipping diarization test")
|
||||
else:
|
||||
# Test without diarization - should always work
|
||||
await process_audio_file(
|
||||
path.as_posix(), event_callback, enable_diarization=False
|
||||
)
|
||||
|
||||
print(f"Diarization: {enable_diarization}, Marks: {marks}")
|
||||
|
||||
# validate the events
|
||||
assert marks["TranscriptLinerProcessor"] == 1
|
||||
assert marks["TranscriptTranslatorPassthroughProcessor"] == 1
|
||||
# Each processor should be called for each audio segment processed
|
||||
# The final processors (Topic, Title, Summary) should be called once at the end
|
||||
assert marks["TranscriptLinerProcessor"] > 0
|
||||
assert marks["TranscriptTranslatorPassthroughProcessor"] > 0
|
||||
assert marks["TranscriptTopicDetectorProcessor"] == 1
|
||||
assert marks["TranscriptFinalSummaryProcessor"] == 1
|
||||
assert marks["TranscriptFinalTitleProcessor"] == 1
|
||||
|
||||
if enable_diarization:
|
||||
assert marks["TestAudioDiarizationProcessor"] == 1
|
||||
|
||||
@@ -2,13 +2,18 @@
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.search import SearchParameters, search_controller
|
||||
from reflector.db.transcripts import transcripts
|
||||
from reflector.db.search import (
|
||||
SearchController,
|
||||
SearchParameters,
|
||||
SearchResult,
|
||||
search_controller,
|
||||
)
|
||||
from reflector.db.transcripts import SourceKind, transcripts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -18,39 +23,135 @@ async def test_search_postgresql_only():
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
try:
|
||||
SearchParameters(query_text="")
|
||||
assert False, "Should have raised validation error"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
|
||||
# Test that whitespace query raises validation error
|
||||
try:
|
||||
SearchParameters(query_text=" ")
|
||||
assert False, "Should have raised validation error"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
params_empty = SearchParameters(query_text="")
|
||||
results_empty, total_empty = await search_controller.search_transcripts(
|
||||
params_empty
|
||||
)
|
||||
assert isinstance(results_empty, list)
|
||||
assert isinstance(total_empty, int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_input_validation():
|
||||
try:
|
||||
SearchParameters(query_text="")
|
||||
assert False, "Should have raised ValidationError"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
async def test_search_with_empty_query():
|
||||
"""Test that empty query returns all transcripts."""
|
||||
params = SearchParameters(query_text="")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(total, int)
|
||||
if len(results) > 1:
|
||||
for i in range(len(results) - 1):
|
||||
assert results[i].created_at >= results[i + 1].created_at
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_title_only_match():
|
||||
"""Test that transcripts with title-only matches return empty snippets."""
|
||||
test_id = "test-empty-9b3f2a8d"
|
||||
|
||||
# Test that whitespace query raises validation error
|
||||
try:
|
||||
SearchParameters(query_text=" \t\n ")
|
||||
assert False, "Should have raised ValidationError"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Empty Transcript",
|
||||
"title": "Empty Meeting",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 0.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": None,
|
||||
"long_summary": None,
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": None,
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
params = SearchParameters(query_text="empty")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert total >= 1
|
||||
found = next((r for r in results if r.id == test_id), None)
|
||||
assert found is not None, "Should find transcript by title match"
|
||||
assert found.search_snippets == []
|
||||
assert found.total_match_count == 0
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_long_summary():
|
||||
"""Test that long_summary content is searchable."""
|
||||
test_id = "test-long-summary-8a9f3c2d"
|
||||
|
||||
try:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Test Long Summary",
|
||||
"title": "Regular Meeting",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 1800.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": "Brief overview",
|
||||
"long_summary": "Detailed discussion about quantum computing applications and blockchain technology integration",
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:10.000
|
||||
Basic meeting content without special keywords.""",
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
params = SearchParameters(query_text="quantum computing")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find transcript by long_summary content"
|
||||
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
assert test_result
|
||||
assert len(test_result.search_snippets) > 0
|
||||
assert "quantum computing" in test_result.search_snippets[0].lower()
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgresql_search_with_data():
|
||||
# collision is improbable
|
||||
test_id = "test-search-e2e-7f3a9b2c"
|
||||
|
||||
try:
|
||||
@@ -94,28 +195,24 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
# Test 1: Search for a word in title
|
||||
params = SearchParameters(query_text="planning")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript by title word"
|
||||
|
||||
# Test 2: Search for a word in webvtt content
|
||||
params = SearchParameters(query_text="tsvector")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript by webvtt content"
|
||||
|
||||
# Test 3: Search with multiple words
|
||||
params = SearchParameters(query_text="engineering planning")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript by multiple words"
|
||||
|
||||
# Test 4: Verify SearchResult structure
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
if test_result:
|
||||
assert test_result.title == "Engineering Planning Meeting Q4 2024"
|
||||
@@ -123,14 +220,12 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
||||
assert test_result.duration == 1800.0
|
||||
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
|
||||
|
||||
# Test 5: Search with OR operator
|
||||
params = SearchParameters(query_text="tsvector OR nosuchword")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript with OR query"
|
||||
|
||||
# Test 6: Quoted phrase search
|
||||
params = SearchParameters(query_text='"full-text search"')
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
@@ -142,3 +237,240 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_search_params():
|
||||
"""Create sample search parameters for testing."""
|
||||
return SearchParameters(
|
||||
query_text="test query",
|
||||
limit=20,
|
||||
offset=0,
|
||||
user_id="test-user",
|
||||
room_id="room1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_result():
|
||||
"""Create a mock database result."""
|
||||
return {
|
||||
"id": "test-transcript-id",
|
||||
"title": "Test Transcript",
|
||||
"created_at": datetime(2024, 6, 15, tzinfo=timezone.utc),
|
||||
"duration": 3600.0,
|
||||
"status": "completed",
|
||||
"user_id": "test-user",
|
||||
"room_id": "room1",
|
||||
"source_kind": SourceKind.LIVE,
|
||||
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\nThis is a test transcript",
|
||||
"rank": 0.95,
|
||||
}
|
||||
|
||||
|
||||
class TestSearchParameters:
|
||||
"""Test SearchParameters model validation and functionality."""
|
||||
|
||||
def test_search_parameters_with_available_filters(self):
|
||||
"""Test creating SearchParameters with currently available filter options."""
|
||||
params = SearchParameters(
|
||||
query_text="search term",
|
||||
limit=50,
|
||||
offset=10,
|
||||
user_id="user123",
|
||||
room_id="room1",
|
||||
)
|
||||
|
||||
assert params.query_text == "search term"
|
||||
assert params.limit == 50
|
||||
assert params.offset == 10
|
||||
assert params.user_id == "user123"
|
||||
assert params.room_id == "room1"
|
||||
|
||||
def test_search_parameters_defaults(self):
|
||||
"""Test SearchParameters with default values."""
|
||||
params = SearchParameters(query_text="test")
|
||||
|
||||
assert params.query_text == "test"
|
||||
assert params.limit == 20
|
||||
assert params.offset == 0
|
||||
assert params.user_id is None
|
||||
assert params.room_id is None
|
||||
|
||||
|
||||
class TestSearchControllerFilters:
|
||||
"""Test SearchController functionality with various filters."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_source_kind_filter(self):
|
||||
"""Test search filtering by source_kind."""
|
||||
controller = SearchController()
|
||||
with (
|
||||
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||
patch("reflector.db.search.get_database") as mock_db,
|
||||
):
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
|
||||
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
|
||||
|
||||
params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
|
||||
|
||||
results, total = await controller.search_transcripts(params)
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
mock_db.return_value.fetch_all.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_single_room_id(self):
|
||||
"""Test search filtering by single room ID (currently supported)."""
|
||||
controller = SearchController()
|
||||
with (
|
||||
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||
patch("reflector.db.search.get_database") as mock_db,
|
||||
):
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
|
||||
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
|
||||
|
||||
params = SearchParameters(
|
||||
query_text="test",
|
||||
room_id="room1",
|
||||
)
|
||||
|
||||
results, total = await controller.search_transcripts(params)
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
mock_db.return_value.fetch_all.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_result_includes_available_fields(self, mock_db_result):
|
||||
"""Test that search results include available fields like source_kind."""
|
||||
controller = SearchController()
|
||||
with (
|
||||
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||
patch("reflector.db.search.get_database") as mock_db,
|
||||
):
|
||||
|
||||
class MockRow:
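                # Minimal stand-in for a database row: exposes _mapping plus
                # dict-style access, which is roughly what the controller reads
                # from fetch_all() results.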
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
self._mapping = data
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._data.items())
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._data[key]
|
||||
|
||||
def keys(self):
|
||||
return self._data.keys()
|
||||
|
||||
mock_row = MockRow(mock_db_result)
|
||||
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||
mock_db.return_value.fetch_val = AsyncMock(return_value=1)
|
||||
|
||||
params = SearchParameters(query_text="test")
|
||||
|
||||
results, total = await controller.search_transcripts(params)
|
||||
|
||||
assert total == 1
|
||||
assert len(results) == 1
|
||||
|
||||
result = results[0]
|
||||
assert isinstance(result, SearchResult)
|
||||
assert result.id == "test-transcript-id"
|
||||
assert result.title == "Test Transcript"
|
||||
assert result.rank == 0.95
|
||||
|
||||
|
||||
class TestSearchEndpointParsing:
|
||||
"""Test parameter parsing in the search endpoint."""
|
||||
|
||||
def test_parse_comma_separated_room_ids(self):
|
||||
"""Test parsing comma-separated room IDs."""
|
||||
room_ids_str = "room1,room2,room3"
|
||||
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||
assert parsed == ["room1", "room2", "room3"]
|
||||
|
||||
room_ids_str = "room1, room2 , room3"
|
||||
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||
assert parsed == ["room1", "room2", "room3"]
|
||||
|
||||
room_ids_str = "room1,,room3,"
|
||||
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||
assert parsed == ["room1", "room3"]
|
||||
|
||||
def test_parse_source_kind(self):
|
||||
"""Test parsing source_kind values."""
|
||||
for kind_str in ["live", "file", "room"]:
|
||||
parsed = SourceKind(kind_str)
|
||||
            assert parsed.value == kind_str
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
SourceKind("invalid_kind")
|
||||
|
||||
|
||||
class TestSearchResultModel:
|
||||
"""Test SearchResult model and serialization."""
|
||||
|
||||
def test_search_result_with_available_fields(self):
|
||||
"""Test SearchResult model with currently available fields populated."""
|
||||
result = SearchResult(
|
||||
id="test-id",
|
||||
title="Test Title",
|
||||
user_id="user-123",
|
||||
room_id="room-456",
|
||||
source_kind=SourceKind.ROOM,
|
||||
created_at=datetime(2024, 6, 15, tzinfo=timezone.utc),
|
||||
status="completed",
|
||||
rank=0.85,
|
||||
duration=1800.5,
|
||||
search_snippets=["snippet 1", "snippet 2"],
|
||||
)
|
||||
|
||||
assert result.id == "test-id"
|
||||
assert result.title == "Test Title"
|
||||
assert result.user_id == "user-123"
|
||||
assert result.room_id == "room-456"
|
||||
assert result.status == "completed"
|
||||
assert result.rank == 0.85
|
||||
assert result.duration == 1800.5
|
||||
assert len(result.search_snippets) == 2
|
||||
|
||||
def test_search_result_with_optional_fields_none(self):
|
||||
"""Test SearchResult model with optional fields as None."""
|
||||
result = SearchResult(
|
||||
id="test-id",
|
||||
source_kind=SourceKind.FILE,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
status="processing",
|
||||
rank=0.5,
|
||||
search_snippets=[],
|
||||
title=None,
|
||||
user_id=None,
|
||||
room_id=None,
|
||||
duration=None,
|
||||
)
|
||||
|
||||
assert result.title is None
|
||||
assert result.user_id is None
|
||||
assert result.room_id is None
|
||||
assert result.duration is None
|
||||
|
||||
def test_search_result_datetime_field(self):
|
||||
"""Test that SearchResult accepts datetime field."""
|
||||
result = SearchResult(
|
||||
id="test-id",
|
||||
source_kind=SourceKind.LIVE,
|
||||
created_at=datetime(2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc),
|
||||
status="completed",
|
||||
rank=0.9,
|
||||
duration=None,
|
||||
search_snippets=[],
|
||||
)
|
||||
|
||||
assert result.created_at == datetime(
|
||||
2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc
|
||||
)
|
||||
|
||||
164
server/tests/test_search_long_summary.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Tests for long_summary in search functionality."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.search import SearchParameters, search_controller
|
||||
from reflector.db.transcripts import transcripts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_summary_snippet_prioritization():
|
||||
"""Test that snippets from long_summary are prioritized over webvtt content."""
|
||||
test_id = "test-snippet-priority-3f9a2b8c"
|
||||
|
||||
try:
|
||||
# Clean up any existing test data
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Test Snippet Priority",
|
||||
"title": "Meeting About Projects",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 1800.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": "Project discussion",
|
||||
"long_summary": (
|
||||
"The team discussed advanced robotics applications including "
|
||||
"autonomous navigation systems and sensor fusion techniques. "
|
||||
"Robotics development will focus on real-time processing."
|
||||
),
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:10.000
|
||||
We talked about many different topics today.
|
||||
|
||||
00:00:10.000 --> 00:00:20.000
|
||||
The robotics project is making good progress.
|
||||
|
||||
00:00:20.000 --> 00:00:30.000
|
||||
We need to consider various implementation approaches.""",
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
# Search for "robotics" which appears in both long_summary and webvtt
|
||||
params = SearchParameters(query_text="robotics")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert total >= 1
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
assert test_result, "Should find the test transcript"
|
||||
|
||||
snippets = test_result.search_snippets
|
||||
assert len(snippets) > 0, "Should have at least one snippet"
|
||||
|
||||
# The first snippets should be from long_summary (more detailed content)
|
||||
first_snippet = snippets[0].lower()
|
||||
assert (
|
||||
"advanced robotics" in first_snippet or "autonomous" in first_snippet
|
||||
), f"First snippet should be from long_summary with detailed content. Got: {snippets[0]}"
|
||||
|
||||
# With max 3 snippets, we should get both from long_summary and webvtt
|
||||
assert len(snippets) <= 3, "Should respect max snippets limit"
|
||||
|
||||
# All snippets should contain the search term
|
||||
for snippet in snippets:
|
||||
assert (
|
||||
"robotics" in snippet.lower()
|
||||
), f"Snippet should contain search term: {snippet}"
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_summary_only_search():
|
||||
"""Test searching for content that only exists in long_summary."""
|
||||
test_id = "test-long-only-8b3c9f2a"
|
||||
|
||||
try:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Test Long Only",
|
||||
"title": "Standard Meeting",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 1800.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": "Team sync",
|
||||
"long_summary": (
|
||||
"Detailed analysis of cryptocurrency market trends and "
|
||||
"decentralized finance protocols. Discussion included "
|
||||
"yield farming strategies and liquidity pool mechanics."
|
||||
),
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:10.000
|
||||
Team meeting about general project updates.
|
||||
|
||||
00:00:10.000 --> 00:00:20.000
|
||||
Discussion of timeline and deliverables.""",
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
# Search for terms only in long_summary
|
||||
params = SearchParameters(query_text="cryptocurrency")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find transcript by long_summary-only content"
|
||||
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
assert test_result
|
||||
assert len(test_result.search_snippets) > 0
|
||||
|
||||
# Verify the snippet is about cryptocurrency
|
||||
snippet = test_result.search_snippets[0].lower()
|
||||
assert "cryptocurrency" in snippet, "Snippet should contain the search term"
|
||||
|
||||
# Search for "yield farming" - a more specific term
|
||||
params2 = SearchParameters(query_text="yield farming")
|
||||
results2, total2 = await search_controller.search_transcripts(params2)
|
||||
|
||||
found2 = any(r.id == test_id for r in results2)
|
||||
assert found2, "Should find transcript by specific long_summary phrase"
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
@@ -1,6 +1,10 @@
|
||||
"""Unit tests for search snippet generation."""
|
||||
|
||||
from reflector.db.search import SearchController
|
||||
from reflector.db.search import (
|
||||
SnippetCandidate,
|
||||
SnippetGenerator,
|
||||
WebVTTProcessor,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractWebVTT:
|
||||
@@ -16,7 +20,7 @@ class TestExtractWebVTT:
|
||||
00:00:10.000 --> 00:00:20.000
|
||||
<v Speaker1>Indeed it is a test of WebVTT parsing.
|
||||
"""
|
||||
result = SearchController._extract_webvtt_text(webvtt)
|
||||
result = WebVTTProcessor.extract_text(webvtt)
|
||||
assert "Hello world, this is a test" in result
|
||||
assert "Indeed it is a test" in result
|
||||
assert "<v Speaker" not in result
|
||||
@@ -25,12 +29,11 @@ class TestExtractWebVTT:
|
||||
|
||||
def test_extract_empty_webvtt(self):
|
||||
"""Test empty WebVTT returns empty string."""
|
||||
assert SearchController._extract_webvtt_text("") == ""
|
||||
assert SearchController._extract_webvtt_text(None) == ""
|
||||
assert WebVTTProcessor.extract_text("") == ""
|
||||
|
||||
def test_extract_malformed_webvtt(self):
|
||||
"""Test malformed WebVTT returns empty string."""
|
||||
result = SearchController._extract_webvtt_text("Not a valid WebVTT")
|
||||
result = WebVTTProcessor.extract_text("Not a valid WebVTT")
|
||||
assert result == ""
|
||||
|
||||
|
||||
@@ -39,8 +42,7 @@ class TestGenerateSnippets:
|
||||
|
||||
def test_multiple_matches(self):
|
||||
"""Test finding multiple occurrences of search term in long text."""
|
||||
# Create text with Python mentions far apart to get separate snippets
|
||||
separator = " This is filler text. " * 20 # ~400 chars of padding
|
||||
separator = " This is filler text. " * 20
|
||||
text = (
|
||||
"Python is great for machine learning."
|
||||
+ separator
|
||||
@@ -51,18 +53,16 @@ class TestGenerateSnippets:
|
||||
+ "The Python community is very supportive."
|
||||
)
|
||||
|
||||
snippets = SearchController._generate_snippets(text, "Python")
|
||||
# With enough separation, we should get multiple snippets
|
||||
assert len(snippets) >= 2 # At least 2 distinct snippets
|
||||
snippets = SnippetGenerator.generate(text, "Python")
|
||||
assert len(snippets) >= 2
|
||||
|
||||
# Each snippet should contain "Python"
|
||||
for snippet in snippets:
|
||||
assert "python" in snippet.lower()
|
||||
|
||||
def test_single_match(self):
|
||||
"""Test single occurrence returns one snippet."""
|
||||
text = "This document discusses artificial intelligence and its applications."
|
||||
snippets = SearchController._generate_snippets(text, "artificial intelligence")
|
||||
snippets = SnippetGenerator.generate(text, "artificial intelligence")
|
||||
|
||||
assert len(snippets) == 1
|
||||
assert "artificial intelligence" in snippets[0].lower()
|
||||
@@ -70,24 +70,22 @@ class TestGenerateSnippets:
|
||||
def test_no_matches(self):
|
||||
"""Test no matches returns empty list."""
|
||||
text = "This is some random text without the search term."
|
||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
|
||||
assert snippets == []
|
||||
|
||||
def test_case_insensitive_search(self):
|
||||
"""Test search is case insensitive."""
|
||||
# Add enough text between matches to get separate snippets
|
||||
text = (
|
||||
"MACHINE LEARNING is important for modern applications. "
|
||||
+ "It requires lots of data and computational resources. " * 5 # Padding
|
||||
+ "It requires lots of data and computational resources. " * 5
|
||||
+ "Machine Learning rocks and transforms industries. "
|
||||
+ "Deep learning is a subset of it. " * 5 # More padding
|
||||
+ "Deep learning is a subset of it. " * 5
|
||||
+ "Finally, machine learning will shape our future."
|
||||
)
|
||||
|
||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
|
||||
# Should find at least 2 (might be 3 if text is long enough)
|
||||
assert len(snippets) >= 2
|
||||
for snippet in snippets:
|
||||
assert "machine learning" in snippet.lower()
|
||||
@@ -95,61 +93,55 @@ class TestGenerateSnippets:
|
||||
def test_partial_match_fallback(self):
|
||||
"""Test fallback to first word when exact phrase not found."""
|
||||
text = "We use machine intelligence for processing."
|
||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
|
||||
# Should fall back to finding "machine"
|
||||
assert len(snippets) == 1
|
||||
assert "machine" in snippets[0].lower()
|
||||
|
||||
def test_snippet_ellipsis(self):
|
||||
"""Test ellipsis added for truncated snippets."""
|
||||
# Long text where match is in the middle
|
||||
text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
|
||||
snippets = SearchController._generate_snippets(text, "TARGET_WORD")
|
||||
snippets = SnippetGenerator.generate(text, "TARGET_WORD")
|
||||
|
||||
assert len(snippets) == 1
|
||||
assert "..." in snippets[0] # Should have ellipsis
|
||||
assert "..." in snippets[0]
|
||||
assert "TARGET_WORD" in snippets[0]
|
||||
|
||||
def test_overlapping_snippets_deduplicated(self):
|
||||
"""Test overlapping matches don't create duplicate snippets."""
|
||||
text = "test test test word" * 10 # Repeated pattern
|
||||
snippets = SearchController._generate_snippets(text, "test")
|
||||
text = "test test test word" * 10
|
||||
snippets = SnippetGenerator.generate(text, "test")
|
||||
|
||||
# Should get unique snippets, not duplicates
|
||||
assert len(snippets) <= 3
|
||||
assert len(snippets) == len(set(snippets)) # All unique
|
||||
assert len(snippets) == len(set(snippets))
|
||||
|
||||
def test_empty_inputs(self):
|
||||
"""Test empty text or search term returns empty list."""
|
||||
assert SearchController._generate_snippets("", "search") == []
|
||||
assert SearchController._generate_snippets("text", "") == []
|
||||
assert SearchController._generate_snippets("", "") == []
|
||||
assert SnippetGenerator.generate("", "search") == []
|
||||
assert SnippetGenerator.generate("text", "") == []
|
||||
assert SnippetGenerator.generate("", "") == []
|
||||
|
||||
def test_max_snippets_limit(self):
|
||||
"""Test respects max_snippets parameter."""
|
||||
# Create text with well-separated occurrences
|
||||
separator = " filler " * 50 # Ensure snippets don't overlap
|
||||
text = ("Python is amazing" + separator) * 10 # 10 occurrences
|
||||
separator = " filler " * 50
|
||||
text = ("Python is amazing" + separator) * 10
|
||||
|
||||
# Test with different limits
|
||||
snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
|
||||
snippets_1 = SnippetGenerator.generate(text, "Python", max_snippets=1)
|
||||
assert len(snippets_1) == 1
|
||||
|
||||
snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
|
||||
snippets_2 = SnippetGenerator.generate(text, "Python", max_snippets=2)
|
||||
assert len(snippets_2) == 2
|
||||
|
||||
snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
|
||||
assert len(snippets_5) == 5 # Should get exactly 5 with enough separation
|
||||
snippets_5 = SnippetGenerator.generate(text, "Python", max_snippets=5)
|
||||
assert len(snippets_5) == 5
|
||||
|
||||
def test_snippet_length(self):
|
||||
"""Test snippet length is reasonable."""
|
||||
text = "word " * 200 # Long text
|
||||
snippets = SearchController._generate_snippets(text, "word")
|
||||
text = "word " * 200
|
||||
snippets = SnippetGenerator.generate(text, "word")
|
||||
|
||||
for snippet in snippets:
|
||||
# Default max_length is 150 + some context
|
||||
assert len(snippet) <= 200 # Some buffer for ellipsis
|
||||
assert len(snippet) <= 200
|
||||
|
||||
|
||||
class TestFullPipeline:
|
||||
@@ -157,7 +149,6 @@ class TestFullPipeline:
|
||||
|
||||
def test_webvtt_to_snippets_integration(self):
|
||||
"""Test full pipeline from WebVTT to search snippets."""
|
||||
# Create WebVTT with well-separated content for multiple snippets
|
||||
webvtt = (
|
||||
"""WEBVTT
|
||||
|
||||
@@ -182,17 +173,362 @@ class TestFullPipeline:
|
||||
"""
|
||||
)
|
||||
|
||||
# Extract and generate snippets
|
||||
plain_text = SearchController._extract_webvtt_text(webvtt)
|
||||
snippets = SearchController._generate_snippets(plain_text, "machine learning")
|
||||
plain_text = WebVTTProcessor.extract_text(webvtt)
|
||||
snippets = SnippetGenerator.generate(plain_text, "machine learning")
|
||||
|
||||
# Should find at least 2 snippets (text might still be close together)
|
||||
assert len(snippets) >= 1 # At minimum one snippet containing matches
|
||||
assert len(snippets) <= 3 # At most 3 by default
|
||||
assert len(snippets) >= 1
|
||||
assert len(snippets) <= 3
|
||||
|
||||
# No WebVTT artifacts in snippets
|
||||
for snippet in snippets:
|
||||
assert "machine learning" in snippet.lower()
|
||||
assert "<v Speaker" not in snippet
|
||||
assert "00:00" not in snippet
|
||||
assert "-->" not in snippet
|
||||
|
||||
|
||||
class TestMultiWordQueryBehavior:
|
||||
"""Tests for multi-word query behavior and exact phrase matching."""
|
||||
|
||||
def test_multi_word_query_snippet_behavior(self):
|
||||
"""Test that multi-word queries generate snippets based on exact phrase matching."""
|
||||
sample_text = """This is a sample transcript where user Alice is talking.
|
||||
Later in the conversation, jordan mentions something important.
|
||||
The user jordan collaboration was successful.
|
||||
Another user named Bob joins the discussion."""
|
||||
|
||||
user_snippets = SnippetGenerator.generate(sample_text, "user")
|
||||
assert len(user_snippets) == 2, "Should find 2 snippets for 'user'"
|
||||
|
||||
jordan_snippets = SnippetGenerator.generate(sample_text, "jordan")
|
||||
assert len(jordan_snippets) >= 1, "Should find at least 1 snippet for 'jordan'"
|
||||
|
||||
multi_word_snippets = SnippetGenerator.generate(sample_text, "user jordan")
|
||||
assert len(multi_word_snippets) == 1, (
|
||||
"Should return exactly 1 snippet for 'user jordan' "
|
||||
"(only the exact phrase match, not individual word occurrences)"
|
||||
)
|
||||
|
||||
snippet = multi_word_snippets[0]
|
||||
assert (
|
||||
"user jordan" in snippet.lower()
|
||||
), "The snippet should contain the exact phrase 'user jordan'"
|
||||
|
||||
assert (
|
||||
"alice" not in snippet.lower()
|
||||
), "The snippet should not include the first standalone 'user' with Alice"
|
||||
|
||||
def test_multi_word_query_without_exact_match(self):
|
||||
"""Test snippet generation when exact phrase is not found."""
|
||||
sample_text = """User Alice is here. Bob and jordan are talking.
|
||||
Later jordan mentions something. The user is happy."""
|
||||
|
||||
snippets = SnippetGenerator.generate(sample_text, "user jordan")
|
||||
|
||||
assert (
|
||||
len(snippets) >= 1
|
||||
), "Should find at least 1 snippet when falling back to first word"
|
||||
|
||||
all_snippets_text = " ".join(snippets).lower()
|
||||
assert (
|
||||
"user" in all_snippets_text
|
||||
), "Snippets should contain 'user' (the first word)"
|
||||
|
||||
def test_exact_phrase_at_text_boundaries(self):
|
||||
"""Test snippet generation when exact phrase appears at text boundaries."""
|
||||
|
||||
text_start = "user jordan started the meeting. Other content here."
|
||||
snippets = SnippetGenerator.generate(text_start, "user jordan")
|
||||
assert len(snippets) == 1
|
||||
assert "user jordan" in snippets[0].lower()
|
||||
|
||||
text_end = "Other content here. The meeting ended with user jordan"
|
||||
snippets = SnippetGenerator.generate(text_end, "user jordan")
|
||||
assert len(snippets) == 1
|
||||
assert "user jordan" in snippets[0].lower()
|
||||
|
||||
def test_multi_word_query_matches_words_appearing_separately_and_together(self):
|
||||
"""Test that multi-word queries prioritize exact phrase matches over individual word occurrences."""
|
||||
sample_text = """This is a sample transcript where user Alice is talking.
|
||||
Later in the conversation, jordan mentions something important.
|
||||
The user jordan collaboration was successful.
|
||||
Another user named Bob joins the discussion."""
|
||||
|
||||
search_query = "user jordan"
|
||||
snippets = SnippetGenerator.generate(sample_text, search_query)
|
||||
|
||||
assert len(snippets) == 1, (
|
||||
f"Expected exactly 1 snippet for '{search_query}' when exact phrase exists, "
|
||||
f"got {len(snippets)}. Should ignore individual word occurrences."
|
||||
)
|
||||
|
||||
snippet = snippets[0]
|
||||
|
||||
assert (
|
||||
search_query in snippet.lower()
|
||||
), f"Snippet should contain the exact phrase '{search_query}'. Got: {snippet}"
|
||||
|
||||
assert (
|
||||
"jordan mentions" in snippet.lower()
|
||||
), f"Snippet should include context before the exact phrase match. Got: {snippet}"
|
||||
|
||||
assert (
|
||||
"alice" not in snippet.lower()
|
||||
), f"Snippet should not include separate occurrences of individual words. Got: {snippet}"
|
||||
|
||||
text_2 = """The alpha version was released.
|
||||
Beta testing started yesterday.
|
||||
The alpha beta integration is complete."""
|
||||
|
||||
snippets_2 = SnippetGenerator.generate(text_2, "alpha beta")
|
||||
assert len(snippets_2) == 1, "Should return 1 snippet for exact phrase match"
|
||||
assert "alpha beta" in snippets_2[0].lower(), "Should contain exact phrase"
|
||||
assert (
|
||||
"version" not in snippets_2[0].lower()
|
||||
), "Should not include first separate occurrence"
|
||||
|
||||
|
||||
class TestSnippetGenerationEnhanced:
|
||||
"""Additional snippet generation tests from test_search_enhancements.py."""
|
||||
|
||||
def test_snippet_generation_from_webvtt(self):
|
||||
"""Test snippet generation from WebVTT content."""
|
||||
webvtt_content = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:05.000
|
||||
This is the beginning of the transcript
|
||||
|
||||
00:00:05.000 --> 00:00:10.000
|
||||
The search term appears here in the middle
|
||||
|
||||
00:00:10.000 --> 00:00:15.000
|
||||
And this is the end of the content"""
|
||||
|
||||
plain_text = WebVTTProcessor.extract_text(webvtt_content)
|
||||
snippets = SnippetGenerator.generate(plain_text, "search term")
|
||||
|
||||
assert len(snippets) > 0
|
||||
assert any("search term" in snippet.lower() for snippet in snippets)
|
||||
|
||||
def test_extract_webvtt_text_with_malformed_variations(self):
|
||||
"""Test WebVTT extraction with various malformed content."""
|
||||
malformed_vtt = "This is not valid WebVTT content"
|
||||
result = WebVTTProcessor.extract_text(malformed_vtt)
|
||||
assert result == ""
|
||||
|
||||
partial_vtt = "WEBVTT\nNo timestamps here"
|
||||
result = WebVTTProcessor.extract_text(partial_vtt)
|
||||
assert result == "" or "No timestamps" not in result
|
||||
|
||||
|
||||
class TestPureFunctions:
|
||||
"""Test the pure functions extracted for functional programming."""
|
||||
|
||||
def test_find_all_matches(self):
|
||||
"""Test finding all match positions in text."""
|
||||
text = "Python is great. Python is powerful. I love Python."
|
||||
matches = list(SnippetGenerator.find_all_matches(text, "Python"))
|
||||
assert matches == [0, 17, 44]
|
||||
|
||||
matches = list(SnippetGenerator.find_all_matches(text, "python"))
|
||||
assert matches == [0, 17, 44]
|
||||
|
||||
matches = list(SnippetGenerator.find_all_matches(text, "Ruby"))
|
||||
assert matches == []
|
||||
|
||||
matches = list(SnippetGenerator.find_all_matches("", "test"))
|
||||
assert matches == []
|
||||
matches = list(SnippetGenerator.find_all_matches("test", ""))
|
||||
assert matches == []
|
||||
|
||||
def test_create_snippet(self):
|
||||
"""Test creating a snippet from a match position."""
|
||||
text = "This is a long text with the word Python in the middle and more text after."
|
||||
|
||||
snippet = SnippetGenerator.create_snippet(text, 35, max_length=150)
|
||||
assert "Python" in snippet.text()
|
||||
assert snippet.start >= 0
|
||||
assert snippet.end <= len(text)
|
||||
assert isinstance(snippet, SnippetCandidate)
|
||||
|
||||
assert len(snippet.text()) > 0
|
||||
assert snippet.start <= snippet.end
|
||||
|
||||
long_text = "A" * 200
|
||||
snippet = SnippetGenerator.create_snippet(long_text, 100, max_length=50)
|
||||
assert snippet.text().startswith("...")
|
||||
assert snippet.text().endswith("...")
|
||||
|
||||
snippet = SnippetGenerator.create_snippet("short text", 0, max_length=100)
|
||||
assert snippet.start == 0
|
||||
assert "short text" in snippet.text()
|
||||
|
||||
def test_filter_non_overlapping(self):
|
||||
"""Test filtering overlapping snippets."""
|
||||
candidates = [
|
||||
SnippetCandidate(_text="First snippet", start=0, _original_text_length=100),
|
||||
SnippetCandidate(_text="Overlapping", start=10, _original_text_length=100),
|
||||
SnippetCandidate(
|
||||
_text="Third snippet", start=40, _original_text_length=100
|
||||
),
|
||||
SnippetCandidate(
|
||||
_text="Fourth snippet", start=65, _original_text_length=100
|
||||
),
|
||||
]
|
||||
|
||||
filtered = list(SnippetGenerator.filter_non_overlapping(iter(candidates)))
|
||||
assert filtered == [
|
||||
"First snippet...",
|
||||
"...Third snippet...",
|
||||
"...Fourth snippet...",
|
||||
]
|
||||
|
||||
filtered = list(SnippetGenerator.filter_non_overlapping(iter([])))
|
||||
assert filtered == []
|
||||
|
||||
def test_generate_integration(self):
|
||||
"""Test the main SnippetGenerator.generate function."""
|
||||
text = "Machine learning is amazing. Machine learning transforms data. Learn machine learning today."
|
||||
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
assert len(snippets) <= 3
|
||||
assert all("machine learning" in s.lower() for s in snippets)
|
||||
|
||||
snippets = SnippetGenerator.generate(text, "machine learning", max_snippets=2)
|
||||
assert len(snippets) <= 2
|
||||
|
||||
snippets = SnippetGenerator.generate(text, "machine vision")
|
||||
assert len(snippets) > 0
|
||||
assert any("machine" in s.lower() for s in snippets)
|
||||
|
||||
def test_extract_webvtt_text_basic(self):
|
||||
"""Test WebVTT text extraction (basic test, full tests exist elsewhere)."""
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Hello world
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
This is a test"""
|
||||
|
||||
result = WebVTTProcessor.extract_text(webvtt)
|
||||
assert "Hello world" in result
|
||||
assert "This is a test" in result
|
||||
|
||||
# Test empty input
|
||||
assert WebVTTProcessor.extract_text("") == ""
|
||||
assert WebVTTProcessor.extract_text(None) == ""
|
||||
|
||||
def test_generate_webvtt_snippets(self):
|
||||
"""Test generating snippets from WebVTT content."""
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Python programming is great
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
Learn Python today"""
|
||||
|
||||
snippets = WebVTTProcessor.generate_snippets(webvtt, "Python")
|
||||
assert len(snippets) > 0
|
||||
assert any("Python" in s for s in snippets)
|
||||
|
||||
snippets = WebVTTProcessor.generate_snippets("", "Python")
|
||||
assert snippets == []
|
||||
|
||||
def test_from_summary(self):
|
||||
"""Test generating snippets from summary text."""
|
||||
summary = "This meeting discussed Python development and machine learning applications."
|
||||
|
||||
snippets = SnippetGenerator.from_summary(summary, "Python")
|
||||
assert len(snippets) > 0
|
||||
assert any("Python" in s for s in snippets)
|
||||
|
||||
long_summary = "Python " * 20
|
||||
snippets = SnippetGenerator.from_summary(long_summary, "Python")
|
||||
assert len(snippets) <= 2
|
||||
|
||||
def test_combine_sources(self):
|
||||
"""Test combining snippets from multiple sources."""
|
||||
summary = "Python is a great programming language."
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Learn Python programming
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
Python is powerful"""
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
summary, webvtt, "Python", max_total=3
|
||||
)
|
||||
assert len(snippets) <= 3
|
||||
assert len(snippets) > 0
|
||||
assert total_count > 0
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
summary, None, "Python", max_total=3
|
||||
)
|
||||
assert len(snippets) > 0
|
||||
assert all("Python" in s for s in snippets)
|
||||
assert total_count == 1
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
None, webvtt, "Python", max_total=3
|
||||
)
|
||||
assert len(snippets) > 0
|
||||
assert total_count == 2
|
||||
|
||||
long_summary = "Python " * 10
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
long_summary, webvtt, "Python", max_total=2
|
||||
)
|
||||
assert len(snippets) == 2
|
||||
assert total_count >= 10
|
||||
|
||||
def test_match_counting_sum_logic(self):
|
||||
"""Test that match counting correctly sums matches from both sources."""
|
||||
summary = "data science uses data analysis and data mining techniques"
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Big data processing
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
data visualization and data storage"""
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
summary, webvtt, "data", max_total=3
|
||||
)
|
||||
assert total_count == 6
|
||||
assert len(snippets) <= 3
|
||||
|
||||
summary_snippets, summary_count = SnippetGenerator.combine_sources(
|
||||
summary, None, "data", max_total=3
|
||||
)
|
||||
assert summary_count == 3
|
||||
|
||||
webvtt_snippets, webvtt_count = SnippetGenerator.combine_sources(
|
||||
None, webvtt, "data", max_total=3
|
||||
)
|
||||
assert webvtt_count == 3
|
||||
|
||||
snippets_empty, count_empty = SnippetGenerator.combine_sources(
|
||||
None, None, "data", max_total=3
|
||||
)
|
||||
assert snippets_empty == []
|
||||
assert count_empty == 0
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""Test edge cases for the pure functions."""
|
||||
text = "Test with special: @#$%^&*() characters"
|
||||
snippets = SnippetGenerator.generate(text, "@#$%")
|
||||
assert len(snippets) > 0
|
||||
|
||||
long_query = "a" * 100
|
||||
snippets = SnippetGenerator.generate("Some text", long_query)
|
||||
assert snippets == []
|
||||
|
||||
text = "Unicode test: café, naïve, 日本語"
|
||||
snippets = SnippetGenerator.generate(text, "café")
|
||||
assert len(snippets) > 0
|
||||
assert "café" in snippets[0]
|
||||
|
||||
873 server/uv.lock (generated; diff suppressed because it is too large)
@@ -1,26 +1,67 @@
|
||||
import React from "react";
|
||||
import React, { useEffect } from "react";
|
||||
import { Pagination, IconButton, ButtonGroup } from "@chakra-ui/react";
|
||||
import { LuChevronLeft, LuChevronRight } from "react-icons/lu";
|
||||
|
||||
// explicitly 1-based to prevent off-by-one confusion
|
||||
export const FIRST_PAGE = 1 as PaginationPage;
|
||||
export const parsePaginationPage = (
|
||||
page: number,
|
||||
):
|
||||
| {
|
||||
value: PaginationPage;
|
||||
}
|
||||
| {
|
||||
error: string;
|
||||
} => {
|
||||
if (page < FIRST_PAGE)
|
||||
return {
|
||||
error: "Page must be greater than 0",
|
||||
};
|
||||
if (!Number.isInteger(page))
|
||||
return {
|
||||
error: "Page must be an integer",
|
||||
};
|
||||
return {
|
||||
value: page as PaginationPage,
|
||||
};
|
||||
};
|
||||
export type PaginationPage = number & { __brand: "PaginationPage" };
|
||||
export const PaginationPage = (page: number): PaginationPage => {
|
||||
const v = parsePaginationPage(page);
|
||||
if ("error" in v) throw new Error(v.error);
|
||||
return v.value;
|
||||
};
|
||||
|
||||
export const paginationPageTo0Based = (page: PaginationPage): number =>
|
||||
page - FIRST_PAGE;
|
||||
|
||||
type PaginationProps = {
|
||||
page: number;
|
||||
setPage: (page: number) => void;
|
||||
page: PaginationPage;
|
||||
setPage: (page: PaginationPage) => void;
|
||||
total: number;
|
||||
size: number;
|
||||
};
|
||||
|
||||
export const totalPages = (total: number, size: number) => {
|
||||
return Math.ceil(total / size);
|
||||
};
|
||||
|
||||
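// Illustrative sketch (not part of this diff): how the branded PaginationPage
// type above is meant to be consumed. `loadTranscripts` is a hypothetical
// caller used only for illustration.
const parsedExamplePage = parsePaginationPage(3);
if ("error" in parsedExamplePage) {
  console.error(parsedExamplePage.error); // e.g. "Page must be greater than 0"
} else {
  const examplePage: PaginationPage = parsedExamplePage.value; // validated, branded value
  const exampleOffset = paginationPageTo0Based(examplePage); // 2, for 0-based backends
  // loadTranscripts({ offset: exampleOffset, limit: 20 }); // hypothetical call
}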
export default function PaginationComponent(props: PaginationProps) {
|
||||
const { page, setPage, total, size } = props;
|
||||
const totalPages = Math.ceil(total / size);
|
||||
|
||||
if (totalPages <= 1) return null;
|
||||
useEffect(() => {
|
||||
if (page > totalPages(total, size)) {
|
||||
console.error(
|
||||
`Page number (${page}) is greater than total pages (${totalPages}) in pagination`,
|
||||
);
|
||||
}
|
||||
}, [page, totalPages(total, size)]);
|
||||
|
||||
return (
|
||||
<Pagination.Root
|
||||
count={total}
|
||||
pageSize={size}
|
||||
page={page}
|
||||
onPageChange={(details) => setPage(details.page)}
|
||||
onPageChange={(details) => setPage(PaginationPage(details.page))}
|
||||
style={{ display: "flex", justifyContent: "center" }}
|
||||
>
|
||||
<ButtonGroup variant="ghost" size="xs">
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import React, { useState } from "react";
|
||||
import { Flex, Input, Button } from "@chakra-ui/react";
|
||||
|
||||
interface SearchBarProps {
|
||||
onSearch: (searchTerm: string) => void;
|
||||
}
|
||||
|
||||
export default function SearchBar({ onSearch }: SearchBarProps) {
|
||||
const [searchInputValue, setSearchInputValue] = useState("");
|
||||
|
||||
const handleSearch = () => {
|
||||
onSearch(searchInputValue);
|
||||
};
|
||||
|
||||
const handleKeyDown = (event: React.KeyboardEvent) => {
|
||||
if (event.key === "Enter") {
|
||||
handleSearch();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Flex alignItems="center">
|
||||
<Input
|
||||
placeholder="Search transcriptions..."
|
||||
value={searchInputValue}
|
||||
onChange={(e) => setSearchInputValue(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
/>
|
||||
<Button ml={2} onClick={handleSearch}>
|
||||
Search
|
||||
</Button>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
@@ -4,8 +4,8 @@ import { LuMenu, LuTrash, LuRotateCw } from "react-icons/lu";
|
||||
|
||||
interface TranscriptActionsMenuProps {
|
||||
transcriptId: string;
|
||||
onDelete: (transcriptId: string) => (e: any) => void;
|
||||
onReprocess: (transcriptId: string) => (e: any) => void;
|
||||
onDelete: (transcriptId: string) => void;
|
||||
onReprocess: (transcriptId: string) => void;
|
||||
}
|
||||
|
||||
export default function TranscriptActionsMenu({
|
||||
@@ -24,11 +24,17 @@ export default function TranscriptActionsMenu({
|
||||
<Menu.Content>
|
||||
<Menu.Item
|
||||
value="reprocess"
|
||||
onClick={(e) => onReprocess(transcriptId)(e)}
|
||||
onClick={() => onReprocess(transcriptId)}
|
||||
>
|
||||
<LuRotateCw /> Reprocess
|
||||
</Menu.Item>
|
||||
<Menu.Item value="delete" onClick={(e) => onDelete(transcriptId)(e)}>
|
||||
<Menu.Item
|
||||
value="delete"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onDelete(transcriptId);
|
||||
}}
|
||||
>
|
||||
<LuTrash /> Delete
|
||||
</Menu.Item>
|
||||
</Menu.Content>
|
||||
|
||||
@@ -1,27 +1,290 @@
|
||||
import React from "react";
|
||||
import { Box, Stack, Text, Flex, Link, Spinner } from "@chakra-ui/react";
|
||||
import React, { useState } from "react";
|
||||
import {
|
||||
Box,
|
||||
Stack,
|
||||
Text,
|
||||
Flex,
|
||||
Link,
|
||||
Spinner,
|
||||
Badge,
|
||||
HStack,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import NextLink from "next/link";
|
||||
import { GetTranscriptMinimal } from "../../../api";
|
||||
import { formatTimeMs, formatLocalDate } from "../../../lib/time";
|
||||
import TranscriptStatusIcon from "./TranscriptStatusIcon";
|
||||
import TranscriptActionsMenu from "./TranscriptActionsMenu";
|
||||
import {
|
||||
highlightMatches,
|
||||
generateTextFragment,
|
||||
} from "../../../lib/textHighlight";
|
||||
import { SearchResult } from "../../../api";
|
||||
|
||||
interface TranscriptCardsProps {
|
||||
transcripts: GetTranscriptMinimal[];
|
||||
onDelete: (transcriptId: string) => (e: any) => void;
|
||||
onReprocess: (transcriptId: string) => (e: any) => void;
|
||||
loading?: boolean;
|
||||
results: SearchResult[];
|
||||
query: string;
|
||||
isLoading?: boolean;
|
||||
onDelete: (transcriptId: string) => void;
|
||||
onReprocess: (transcriptId: string) => void;
|
||||
}
|
||||
|
||||
function highlightText(text: string, query: string): React.ReactNode {
|
||||
if (!query) return text;
|
||||
|
||||
const matches = highlightMatches(text, query);
|
||||
|
||||
if (matches.length === 0) return text;
|
||||
|
||||
// Sort matches by index to process them in order
|
||||
const sortedMatches = [...matches].sort((a, b) => a.index - b.index);
|
||||
|
||||
const parts: React.ReactNode[] = [];
|
||||
let lastIndex = 0;
|
||||
|
||||
sortedMatches.forEach((match, i) => {
|
||||
// Add text before the match
|
||||
if (match.index > lastIndex) {
|
||||
parts.push(
|
||||
<Text as="span" key={`text-${i}`} display="inline">
|
||||
{text.slice(lastIndex, match.index)}
|
||||
</Text>,
|
||||
);
|
||||
}
|
||||
|
||||
// Add the highlighted match
|
||||
parts.push(
|
||||
<Text
|
||||
as="mark"
|
||||
key={`match-${i}`}
|
||||
bg="yellow.200"
|
||||
px={0.5}
|
||||
display="inline"
|
||||
>
|
||||
{match.match}
|
||||
</Text>,
|
||||
);
|
||||
|
||||
lastIndex = match.index + match.match.length;
|
||||
});
|
||||
|
||||
// Add remaining text after last match
|
||||
if (lastIndex < text.length) {
|
||||
parts.push(
|
||||
<Text as="span" key={`text-end`} display="inline">
|
||||
{text.slice(lastIndex)}
|
||||
</Text>,
|
||||
);
|
||||
}
|
||||
|
||||
return parts;
|
||||
}
|
||||
|
||||
const transcriptHref = (
|
||||
transcriptId: string,
|
||||
mainSnippet: string,
|
||||
query: string,
|
||||
): `/transcripts/${string}` => {
|
||||
const urlTextFragment = mainSnippet
|
||||
? generateTextFragment(mainSnippet, query)
|
||||
: null;
|
||||
const urlTextFragmentWithHash = urlTextFragment
|
||||
? `#${urlTextFragment.k}=${encodeURIComponent(urlTextFragment.v)}`
|
||||
: "";
|
||||
return `/transcripts/${transcriptId}${urlTextFragmentWithHash}`;
|
||||
};
|
||||
|
||||
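// Illustrative sketch (not part of this diff): assuming generateTextFragment
// returns a key/value pair for a URL text fragment, e.g. { k: ":~:text", v: "yield farming" },
// transcriptHref would produce a deep link roughly like:
//
//   transcriptHref("abc123", "...yield farming strategies...", "yield farming")
//   // -> "/transcripts/abc123#:~:text=yield%20farming"
//
// With no snippet it falls back to the plain transcript URL:
//
//   transcriptHref("abc123", "", "yield farming")   // -> "/transcripts/abc123"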
// note: this card is strongly tied to the search logic; refactor before reusing it independently
|
||||
function TranscriptCard({
|
||||
result,
|
||||
query,
|
||||
onDelete,
|
||||
onReprocess,
|
||||
}: {
|
||||
result: SearchResult;
|
||||
query: string;
|
||||
onDelete: (transcriptId: string) => void;
|
||||
onReprocess: (transcriptId: string) => void;
|
||||
}) {
|
||||
const [isExpanded, setIsExpanded] = useState(false);
|
||||
|
||||
const mainSnippet = result.search_snippets[0];
|
||||
const additionalSnippets = result.search_snippets.slice(1);
|
||||
const totalMatches = result.total_match_count || 0;
|
||||
const snippetsShown = result.search_snippets.length;
|
||||
const remainingMatches = totalMatches - snippetsShown;
|
||||
const hasAdditionalSnippets = additionalSnippets.length > 0;
|
||||
const resultTitle = result.title || "Unnamed Transcript";
|
||||
|
||||
const formattedDuration = result.duration
|
||||
? formatTimeMs(result.duration)
|
||||
: "N/A";
|
||||
const formattedDate = formatLocalDate(result.created_at);
|
||||
const source =
|
||||
result.source_kind === "room"
|
||||
? result.room_name || result.room_id
|
||||
: result.source_kind;
|
||||
|
||||
const handleExpandClick = (e: React.MouseEvent) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsExpanded(!isExpanded);
|
||||
};
|
||||
|
||||
return (
|
||||
<Box borderWidth={1} p={4} borderRadius="md" fontSize="sm">
|
||||
<Flex justify="space-between" alignItems="flex-start" gap="2">
|
||||
<Box>
|
||||
<TranscriptStatusIcon status={result.status} />
|
||||
</Box>
|
||||
<Box flex="1">
|
||||
{/* Title with highlighting and text fragment for deep linking */}
|
||||
<Link
|
||||
as={NextLink}
|
||||
href={transcriptHref(result.id, mainSnippet, query)}
|
||||
fontWeight="600"
|
||||
display="block"
|
||||
mb={2}
|
||||
>
|
||||
{highlightText(resultTitle, query)}
|
||||
</Link>
|
||||
|
||||
{/* Metadata - Horizontal on desktop, vertical on mobile */}
|
||||
<Flex
|
||||
direction={{ base: "column", md: "row" }}
|
||||
gap={{ base: 1, md: 2 }}
|
||||
fontSize="xs"
|
||||
color="gray.600"
|
||||
flexWrap="wrap"
|
||||
align={{ base: "flex-start", md: "center" }}
|
||||
>
|
||||
<Flex align="center" gap={1}>
|
||||
<Text fontWeight="medium" color="gray.500">
|
||||
Source:
|
||||
</Text>
|
||||
<Text>{source}</Text>
|
||||
</Flex>
|
||||
<Text display={{ base: "none", md: "block" }} color="gray.400">
|
||||
•
|
||||
</Text>
|
||||
<Flex align="center" gap={1}>
|
||||
<Text fontWeight="medium" color="gray.500">
|
||||
Date:
|
||||
</Text>
|
||||
<Text>{formattedDate}</Text>
|
||||
</Flex>
|
||||
<Text display={{ base: "none", md: "block" }} color="gray.400">
|
||||
•
|
||||
</Text>
|
||||
<Flex align="center" gap={1}>
|
||||
<Text fontWeight="medium" color="gray.500">
|
||||
Duration:
|
||||
</Text>
|
||||
<Text>{formattedDuration}</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
{/* Search Results Section - only show when searching */}
|
||||
{mainSnippet && (
|
||||
<>
|
||||
{/* Main Snippet */}
|
||||
<Box
|
||||
mt={3}
|
||||
p={2}
|
||||
bg="gray.50"
|
||||
borderLeft="2px solid"
|
||||
borderLeftColor="blue.400"
|
||||
borderRadius="sm"
|
||||
fontSize="xs"
|
||||
>
|
||||
<Text color="gray.700">
|
||||
{highlightText(mainSnippet, query)}
|
||||
</Text>
|
||||
</Box>
|
||||
|
||||
{hasAdditionalSnippets && (
|
||||
<>
|
||||
<Flex
|
||||
mt={2}
|
||||
p={2}
|
||||
bg="blue.50"
|
||||
borderRadius="sm"
|
||||
cursor="pointer"
|
||||
onClick={handleExpandClick}
|
||||
_hover={{ bg: "blue.100" }}
|
||||
align="center"
|
||||
justify="space-between"
|
||||
>
|
||||
<HStack gap={2}>
|
||||
<Badge
|
||||
bg="blue.500"
|
||||
color="white"
|
||||
fontSize="xs"
|
||||
px={2}
|
||||
borderRadius="full"
|
||||
>
|
||||
{remainingMatches > 0
|
||||
? `${additionalSnippets.length + remainingMatches}+`
|
||||
: additionalSnippets.length}
|
||||
</Badge>
|
||||
<Text fontSize="xs" color="blue.600" fontWeight="medium">
|
||||
more{" "}
|
||||
{additionalSnippets.length + remainingMatches === 1
|
||||
? "match"
|
||||
: "matches"}
|
||||
{remainingMatches > 0 &&
|
||||
` (${additionalSnippets.length} shown)`}
|
||||
</Text>
|
||||
</HStack>
|
||||
<Text fontSize="xs" color="blue.600">
|
||||
{isExpanded ? "▲" : "▼"}
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
{/* Additional Snippets */}
|
||||
{isExpanded && (
|
||||
<VStack align="stretch" gap={2} mt={2}>
|
||||
{additionalSnippets.map((snippet, index) => (
|
||||
<Box
|
||||
key={index}
|
||||
p={2}
|
||||
bg="gray.50"
|
||||
borderLeft="2px solid"
|
||||
borderLeftColor="gray.300"
|
||||
borderRadius="sm"
|
||||
fontSize="xs"
|
||||
>
|
||||
<Text color="gray.700">
|
||||
{highlightText(snippet, query)}
|
||||
</Text>
|
||||
</Box>
|
||||
))}
|
||||
</VStack>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Box>
|
||||
<TranscriptActionsMenu
|
||||
transcriptId={result.id}
|
||||
onDelete={onDelete}
|
||||
onReprocess={onReprocess}
|
||||
/>
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
export default function TranscriptCards({
|
||||
transcripts,
|
||||
results,
|
||||
query,
|
||||
isLoading,
|
||||
onDelete,
|
||||
onReprocess,
|
||||
loading,
|
||||
}: TranscriptCardsProps) {
|
||||
return (
|
||||
<Box display={{ base: "block", lg: "none" }} position="relative">
|
||||
{loading && (
|
||||
<Box position="relative">
|
||||
{isLoading && (
|
||||
<Flex
|
||||
position="absolute"
|
||||
top={0}
|
||||
@@ -37,48 +300,19 @@ export default function TranscriptCards({
|
||||
</Flex>
|
||||
)}
|
||||
<Box
|
||||
opacity={loading ? 0.9 : 1}
|
||||
pointerEvents={loading ? "none" : "auto"}
|
||||
opacity={isLoading ? 0.9 : 1}
|
||||
pointerEvents={isLoading ? "none" : "auto"}
|
||||
transition="opacity 0.2s ease-in-out"
|
||||
>
|
||||
<Stack gap={2}>
|
||||
{transcripts.map((item) => (
|
||||
<Box
|
||||
key={item.id}
|
||||
borderWidth={1}
|
||||
p={4}
|
||||
borderRadius="md"
|
||||
fontSize="sm"
|
||||
>
|
||||
<Flex justify="space-between" alignItems="flex-start" gap="2">
|
||||
<Box>
|
||||
<TranscriptStatusIcon status={item.status} />
|
||||
</Box>
|
||||
<Box flex="1">
|
||||
<Link
|
||||
as={NextLink}
|
||||
href={`/transcripts/${item.id}`}
|
||||
fontWeight="600"
|
||||
display="block"
|
||||
>
|
||||
{item.title || "Unnamed Transcript"}
|
||||
</Link>
|
||||
<Text>
|
||||
Source:{" "}
|
||||
{item.source_kind === "room"
|
||||
? item.room_name
|
||||
: item.source_kind}
|
||||
</Text>
|
||||
<Text>Date: {formatLocalDate(item.created_at)}</Text>
|
||||
<Text>Duration: {formatTimeMs(item.duration)}</Text>
|
||||
</Box>
|
||||
<TranscriptActionsMenu
|
||||
transcriptId={item.id}
|
||||
onDelete={onDelete}
|
||||
onReprocess={onReprocess}
|
||||
/>
|
||||
</Flex>
|
||||
</Box>
|
||||
<Stack gap={3}>
|
||||
{results.map((result) => (
|
||||
<TranscriptCard
|
||||
key={result.id}
|
||||
result={result}
|
||||
query={query}
|
||||
onDelete={onDelete}
|
||||
onReprocess={onReprocess}
|
||||
/>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
import React from "react";
|
||||
import { Box, Table, Link, Flex, Spinner } from "@chakra-ui/react";
|
||||
import NextLink from "next/link";
|
||||
import { GetTranscriptMinimal } from "../../../api";
|
||||
import { formatTimeMs, formatLocalDate } from "../../../lib/time";
|
||||
import TranscriptStatusIcon from "./TranscriptStatusIcon";
|
||||
import TranscriptActionsMenu from "./TranscriptActionsMenu";
|
||||
|
||||
interface TranscriptTableProps {
|
||||
transcripts: GetTranscriptMinimal[];
|
||||
onDelete: (transcriptId: string) => (e: any) => void;
|
||||
onReprocess: (transcriptId: string) => (e: any) => void;
|
||||
loading?: boolean;
|
||||
}
|
||||
|
||||
export default function TranscriptTable({
|
||||
transcripts,
|
||||
onDelete,
|
||||
onReprocess,
|
||||
loading,
|
||||
}: TranscriptTableProps) {
|
||||
return (
|
||||
<Box display={{ base: "none", lg: "block" }} position="relative">
|
||||
{loading && (
|
||||
<Flex
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
align="center"
|
||||
justify="center"
|
||||
>
|
||||
<Spinner size="xl" color="gray.700" />
|
||||
</Flex>
|
||||
)}
|
||||
<Box
|
||||
opacity={loading ? 0.9 : 1}
|
||||
pointerEvents={loading ? "none" : "auto"}
|
||||
transition="opacity 0.2s ease-in-out"
|
||||
>
|
||||
<Table.Root>
|
||||
<Table.Header>
|
||||
<Table.Row>
|
||||
<Table.ColumnHeader
|
||||
width="16px"
|
||||
fontWeight="600"
|
||||
></Table.ColumnHeader>
|
||||
<Table.ColumnHeader width="400px" fontWeight="600">
|
||||
Transcription Title
|
||||
</Table.ColumnHeader>
|
||||
<Table.ColumnHeader width="150px" fontWeight="600">
|
||||
Source
|
||||
</Table.ColumnHeader>
|
||||
<Table.ColumnHeader width="200px" fontWeight="600">
|
||||
Date
|
||||
</Table.ColumnHeader>
|
||||
<Table.ColumnHeader width="100px" fontWeight="600">
|
||||
Duration
|
||||
</Table.ColumnHeader>
|
||||
<Table.ColumnHeader
|
||||
width="50px"
|
||||
fontWeight="600"
|
||||
></Table.ColumnHeader>
|
||||
</Table.Row>
|
||||
</Table.Header>
|
||||
<Table.Body>
|
||||
{transcripts.map((item) => (
|
||||
<Table.Row key={item.id}>
|
||||
<Table.Cell>
|
||||
<TranscriptStatusIcon status={item.status} />
|
||||
</Table.Cell>
|
||||
<Table.Cell>
|
||||
<Link as={NextLink} href={`/transcripts/${item.id}`}>
|
||||
{item.title || "Unnamed Transcript"}
|
||||
</Link>
|
||||
</Table.Cell>
|
||||
<Table.Cell>
|
||||
{item.source_kind === "room"
|
||||
? item.room_name
|
||||
: item.source_kind}
|
||||
</Table.Cell>
|
||||
<Table.Cell>{formatLocalDate(item.created_at)}</Table.Cell>
|
||||
<Table.Cell>{formatTimeMs(item.duration)}</Table.Cell>
|
||||
<Table.Cell>
|
||||
<TranscriptActionsMenu
|
||||
transcriptId={item.id}
|
||||
onDelete={onDelete}
|
||||
onReprocess={onReprocess}
|
||||
/>
|
||||
</Table.Cell>
|
||||
</Table.Row>
|
||||
))}
|
||||
</Table.Body>
|
||||
</Table.Root>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
@@ -1,33 +1,264 @@
|
||||
"use client";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { Flex, Spinner, Heading, Text, Link } from "@chakra-ui/react";
|
||||
import useTranscriptList from "../transcripts/useTranscriptList";
|
||||
import {
|
||||
Flex,
|
||||
Spinner,
|
||||
Heading,
|
||||
Text,
|
||||
Link,
|
||||
Box,
|
||||
Stack,
|
||||
Input,
|
||||
Button,
|
||||
IconButton,
|
||||
} from "@chakra-ui/react";
|
||||
import {
|
||||
useQueryState,
|
||||
parseAsString,
|
||||
parseAsInteger,
|
||||
parseAsStringLiteral,
|
||||
} from "nuqs";
|
||||
import { LuX } from "react-icons/lu";
|
||||
import { useSearchTranscripts } from "../transcripts/useSearchTranscripts";
|
||||
import useSessionUser from "../../lib/useSessionUser";
|
||||
import { Room } from "../../api";
|
||||
import Pagination from "./_components/Pagination";
|
||||
import { Room, SourceKind, SearchResult, $SourceKind } from "../../api";
|
||||
import useApi from "../../lib/useApi";
|
||||
import { useError } from "../../(errors)/errorContext";
|
||||
import { SourceKind } from "../../api";
|
||||
import FilterSidebar from "./_components/FilterSidebar";
|
||||
import SearchBar from "./_components/SearchBar";
|
||||
import TranscriptTable from "./_components/TranscriptTable";
|
||||
import Pagination, {
|
||||
FIRST_PAGE,
|
||||
PaginationPage,
|
||||
parsePaginationPage,
|
||||
totalPages as getTotalPages,
|
||||
} from "./_components/Pagination";
|
||||
import TranscriptCards from "./_components/TranscriptCards";
|
||||
import DeleteTranscriptDialog from "./_components/DeleteTranscriptDialog";
|
||||
import { formatLocalDate } from "../../lib/time";
|
||||
import { RECORD_A_MEETING_URL } from "../../api/urls";
|
||||
|
||||
const SEARCH_FORM_QUERY_INPUT_NAME = "query" as const;
|
||||
|
||||
const usePrefetchRooms = (setRooms: (rooms: Room[]) => void): void => {
|
||||
const { setError } = useError();
|
||||
const api = useApi();
|
||||
useEffect(() => {
|
||||
if (!api) return;
|
||||
api
|
||||
.v1RoomsList({ page: 1 })
|
||||
.then((rooms) => setRooms(rooms.items))
|
||||
.catch((err) => setError(err, "There was an error fetching the rooms"));
|
||||
}, [api, setError]);
|
||||
};
|
||||
|
||||
const SearchForm: React.FC<{
|
||||
setPage: (page: PaginationPage) => void;
|
||||
sourceKind: SourceKind | null;
|
||||
roomId: string | null;
|
||||
setSourceKind: (sourceKind: SourceKind | null) => void;
|
||||
setRoomId: (roomId: string | null) => void;
|
||||
rooms: Room[];
|
||||
searchQuery: string | null;
|
||||
setSearchQuery: (query: string | null) => void;
|
||||
}> = ({
|
||||
setPage,
|
||||
sourceKind,
|
||||
roomId,
|
||||
setRoomId,
|
||||
setSourceKind,
|
||||
rooms,
|
||||
searchQuery,
|
||||
setSearchQuery,
|
||||
}) => {
|
||||
// keeps the search input controlled and gives finer-grained control (urlSearchQuery is only updated on submit)
|
||||
const [searchInputValue, setSearchInputValue] = useState(searchQuery || "");
|
||||
const handleSearchQuerySubmit = async (d: FormData) => {
|
||||
await setSearchQuery((d.get(SEARCH_FORM_QUERY_INPUT_NAME) as string) || "");
|
||||
};
|
||||
|
||||
const handleClearSearch = () => {
|
||||
setSearchInputValue("");
|
||||
setSearchQuery(null);
|
||||
setPage(FIRST_PAGE);
|
||||
};
|
||||
return (
|
||||
<Stack gap={2}>
|
||||
<form action={handleSearchQuerySubmit}>
|
||||
<Flex alignItems="center">
|
||||
<Box position="relative" flex="1">
|
||||
<Input
|
||||
placeholder="Search transcriptions..."
|
||||
value={searchInputValue}
|
||||
onChange={(e) => setSearchInputValue(e.target.value)}
|
||||
name={SEARCH_FORM_QUERY_INPUT_NAME}
|
||||
pr={searchQuery ? "2.5rem" : undefined}
|
||||
/>
|
||||
{searchQuery && (
|
||||
<IconButton
|
||||
aria-label="Clear search"
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
onClick={handleClearSearch}
|
||||
position="absolute"
|
||||
right="0.25rem"
|
||||
top="50%"
|
||||
transform="translateY(-50%)"
|
||||
_hover={{ bg: "gray.100" }}
|
||||
>
|
||||
<LuX />
|
||||
</IconButton>
|
||||
)}
|
||||
</Box>
|
||||
<Button ml={2} type="submit">
|
||||
Search
|
||||
</Button>
|
||||
</Flex>
|
||||
</form>
|
||||
<UnderSearchFormFilterIndicators
|
||||
sourceKind={sourceKind}
|
||||
roomId={roomId}
|
||||
setSourceKind={setSourceKind}
|
||||
setRoomId={setRoomId}
|
||||
rooms={rooms}
|
||||
/>
|
||||
</Stack>
|
||||
);
|
||||
};
|
||||
|
||||
const UnderSearchFormFilterIndicators: React.FC<{
|
||||
sourceKind: SourceKind | null;
|
||||
roomId: string | null;
|
||||
setSourceKind: (sourceKind: SourceKind | null) => void;
|
||||
setRoomId: (roomId: string | null) => void;
|
||||
rooms: Room[];
|
||||
}> = ({ sourceKind, roomId, setRoomId, setSourceKind, rooms }) => {
|
||||
return (
|
||||
<>
|
||||
{(sourceKind || roomId) && (
|
||||
<Flex gap={2} flexWrap="wrap" align="center">
|
||||
<Text fontSize="sm" color="gray.600">
|
||||
Active filters:
|
||||
</Text>
|
||||
{sourceKind && (
|
||||
<Flex
|
||||
align="center"
|
||||
px={2}
|
||||
py={1}
|
||||
bg="blue.100"
|
||||
borderRadius="md"
|
||||
fontSize="xs"
|
||||
gap={1}
|
||||
>
|
||||
<Text>
|
||||
{roomId
|
||||
? `Room: ${
|
||||
rooms.find((r) => r.id === roomId)?.name || roomId
|
||||
}`
|
||||
: `Source: ${sourceKind}`}
|
||||
</Text>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
minW="auto"
|
||||
h="auto"
|
||||
p="1px"
|
||||
onClick={() => {
|
||||
setSourceKind(null);
|
||||
// TODO questionable
|
||||
setRoomId(null);
|
||||
}}
|
||||
_hover={{ bg: "blue.200" }}
|
||||
aria-label="Clear filter"
|
||||
>
|
||||
<LuX size={14} />
|
||||
</Button>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const EmptyResult: React.FC<{
|
||||
searchQuery: string;
|
||||
}> = ({ searchQuery }) => {
|
||||
return (
|
||||
<Flex flexDir="column" alignItems="center" justifyContent="center" py={8}>
|
||||
<Text textAlign="center">
|
||||
{searchQuery
|
||||
? `No results found for "${searchQuery}". Try adjusting your search terms.`
|
||||
: "No transcripts found, but you can "}
|
||||
{!searchQuery && (
|
||||
<>
|
||||
<Link href={RECORD_A_MEETING_URL} color="blue.500">
|
||||
record a meeting
|
||||
</Link>
|
||||
{" to get started."}
|
||||
</>
|
||||
)}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default function TranscriptBrowser() {
|
||||
const [selectedSourceKind, setSelectedSourceKind] =
|
||||
useState<SourceKind | null>(null);
|
||||
const [selectedRoomId, setSelectedRoomId] = useState("");
|
||||
const [rooms, setRooms] = useState<Room[]>([]);
|
||||
const [page, setPage] = useState(1);
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
const { loading, response, refetch } = useTranscriptList(
|
||||
page,
|
||||
selectedSourceKind,
|
||||
selectedRoomId,
|
||||
searchTerm,
|
||||
const [urlSearchQuery, setUrlSearchQuery] = useQueryState(
|
||||
"q",
|
||||
parseAsString.withDefault("").withOptions({ shallow: false }),
|
||||
);
|
||||
|
||||
const [urlSourceKind, setUrlSourceKind] = useQueryState(
|
||||
"source",
|
||||
parseAsStringLiteral($SourceKind.enum).withOptions({
|
||||
shallow: false,
|
||||
}),
|
||||
);
|
||||
const [urlRoomId, setUrlRoomId] = useQueryState(
|
||||
"room",
|
||||
parseAsString.withDefault("").withOptions({ shallow: false }),
|
||||
);
|
||||
|
||||
const [urlPage, setPage] = useQueryState(
|
||||
"page",
|
||||
parseAsInteger.withDefault(1).withOptions({ shallow: false }),
|
||||
);
|
||||
|
||||
const [page, _setSafePage] = useState(FIRST_PAGE);
|
||||
|
||||
// safety net
|
||||
useEffect(() => {
|
||||
const maybePage = parsePaginationPage(urlPage);
|
||||
if ("error" in maybePage) {
|
||||
setPage(FIRST_PAGE).then(() => {
|
||||
/* may be called multiple times; that's fine */
|
||||
});
|
||||
return;
|
||||
}
|
||||
_setSafePage(maybePage.value);
|
||||
}, [urlPage]);
|
||||
|
||||
const [rooms, setRooms] = useState<Room[]>([]);
|
||||
|
||||
const pageSize = 20;
|
||||
const {
|
||||
results,
|
||||
totalCount: totalResults,
|
||||
isLoading,
|
||||
reload,
|
||||
} = useSearchTranscripts(
|
||||
urlSearchQuery,
|
||||
{
|
||||
roomIds: urlRoomId ? [urlRoomId] : null,
|
||||
sourceKind: urlSourceKind,
|
||||
},
|
||||
{
|
||||
pageSize,
|
||||
page,
|
||||
},
|
||||
);
|
||||
|
||||
const totalPages = getTotalPages(totalResults, pageSize);
|
||||
|
||||
const userName = useSessionUser().name;
|
||||
const [deletionLoading, setDeletionLoading] = useState(false);
|
||||
const api = useApi();
|
||||
@@ -35,37 +266,73 @@ export default function TranscriptBrowser() {
|
||||
const cancelRef = React.useRef(null);
|
||||
const [transcriptToDeleteId, setTranscriptToDeleteId] =
|
||||
React.useState<string>();
|
||||
const [deletedItemIds, setDeletedItemIds] = React.useState<string[]>();
|
||||
|
||||
useEffect(() => {
|
||||
setDeletedItemIds([]);
|
||||
}, [page, response]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!api) return;
|
||||
api
|
||||
.v1RoomsList({ page: 1 })
|
||||
.then((rooms) => setRooms(rooms.items))
|
||||
.catch((err) => setError(err, "There was an error fetching the rooms"));
|
||||
}, [api]);
|
||||
usePrefetchRooms(setRooms);
|
||||
|
||||
const handleFilterTranscripts = (
|
||||
sourceKind: SourceKind | null,
|
||||
roomId: string,
|
||||
) => {
|
||||
setSelectedSourceKind(sourceKind);
|
||||
setSelectedRoomId(roomId);
|
||||
setUrlSourceKind(sourceKind);
|
||||
setUrlRoomId(roomId);
|
||||
setPage(1);
|
||||
};
|
||||
|
||||
const handleSearch = (searchTerm: string) => {
|
||||
setPage(1);
|
||||
setSearchTerm(searchTerm);
|
||||
setSelectedSourceKind(null);
|
||||
setSelectedRoomId("");
|
||||
const onCloseDeletion = () => setTranscriptToDeleteId(undefined);
|
||||
|
||||
const confirmDeleteTranscript = (transcriptId: string) => {
|
||||
if (!api || deletionLoading) return;
|
||||
setDeletionLoading(true);
|
||||
api
|
||||
.v1TranscriptDelete({ transcriptId })
|
||||
.then(() => {
|
||||
setDeletionLoading(false);
|
||||
onCloseDeletion();
|
||||
reload();
|
||||
})
|
||||
.catch((err) => {
|
||||
setDeletionLoading(false);
|
||||
setError(err, "There was an error deleting the transcript");
|
||||
});
|
||||
};
|
||||
|
||||
if (loading && !response)
|
||||
const handleProcessTranscript = (transcriptId: string) => {
|
||||
if (!api) {
|
||||
console.error("API not available on handleProcessTranscript");
|
||||
return;
|
||||
}
|
||||
api
|
||||
.v1TranscriptProcess({ transcriptId })
|
||||
.then((result) => {
|
||||
const status =
|
||||
result && typeof result === "object" && "status" in result
|
||||
? (result as { status: string }).status
|
||||
: undefined;
|
||||
if (status === "already running") {
|
||||
setError(
|
||||
new Error("Processing is already running, please wait"),
|
||||
"Processing is already running, please wait",
|
||||
);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
setError(err, "There was an error processing the transcript");
|
||||
});
|
||||
};
|
||||
|
||||
const transcriptToDelete = results?.find(
|
||||
(i) => i.id === transcriptToDeleteId,
|
||||
);
|
||||
const dialogTitle = transcriptToDelete?.title || "Unnamed Transcript";
|
||||
const dialogDate = transcriptToDelete?.created_at
|
||||
? formatLocalDate(transcriptToDelete.created_at)
|
||||
: undefined;
|
||||
const dialogSource =
|
||||
transcriptToDelete?.source_kind === "room" && transcriptToDelete?.room_id
|
||||
? transcriptToDelete.room_name || transcriptToDelete.room_id
|
||||
: transcriptToDelete?.source_kind;
|
||||
|
||||
if (isLoading && results.length === 0) {
|
||||
return (
|
||||
<Flex
|
||||
flexDir="column"
|
||||
@@ -76,82 +343,7 @@ export default function TranscriptBrowser() {
|
||||
<Spinner size="xl" />
|
||||
</Flex>
|
||||
);
|
||||
|
||||
if (!loading && !response)
|
||||
return (
|
||||
<Flex
|
||||
flexDir="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
h="100%"
|
||||
>
|
||||
<Text>
|
||||
No transcripts found, but you can
|
||||
<Link href="/transcripts/new" className="underline">
|
||||
record a meeting
|
||||
</Link>
|
||||
to get started.
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
|
||||
const onCloseDeletion = () => setTranscriptToDeleteId(undefined);
|
||||
|
||||
const confirmDeleteTranscript = (transcriptId: string) => {
|
||||
if (!api || deletionLoading) return;
|
||||
setDeletionLoading(true);
|
||||
api
|
||||
.v1TranscriptDelete({ transcriptId })
|
||||
.then(() => {
|
||||
refetch();
|
||||
setDeletionLoading(false);
|
||||
onCloseDeletion();
|
||||
setDeletedItemIds((prev) =>
|
||||
prev ? [...prev, transcriptId] : [transcriptId],
|
||||
);
|
||||
})
|
||||
.catch((err) => {
|
||||
setDeletionLoading(false);
|
||||
setError(err, "There was an error deleting the transcript");
|
||||
});
|
||||
};
|
||||
|
||||
const handleDeleteTranscript = (transcriptId: string) => (e: any) => {
|
||||
e?.stopPropagation?.();
|
||||
setTranscriptToDeleteId(transcriptId);
|
||||
};
|
||||
|
||||
const handleProcessTranscript = (transcriptId) => (e) => {
|
||||
if (api) {
|
||||
api
|
||||
.v1TranscriptProcess({ transcriptId })
|
||||
.then((result) => {
|
||||
const status = (result as any).status;
|
||||
if (status === "already running") {
|
||||
setError(
|
||||
new Error("Processing is already running, please wait"),
|
||||
"Processing is already running, please wait",
|
||||
);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
setError(err, "There was an error processing the transcript");
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const transcriptToDelete = response?.items?.find(
|
||||
(i) => i.id === transcriptToDeleteId,
|
||||
);
|
||||
const dialogTitle = transcriptToDelete?.title || "Unnamed Transcript";
|
||||
const dialogDate = transcriptToDelete?.created_at
|
||||
? formatLocalDate(transcriptToDelete.created_at)
|
||||
: undefined;
|
||||
const dialogSource = transcriptToDelete
|
||||
? transcriptToDelete.source_kind === "room"
|
||||
? transcriptToDelete.room_name || undefined
|
||||
: transcriptToDelete.source_kind
|
||||
: undefined;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -168,15 +360,15 @@ export default function TranscriptBrowser() {
|
||||
>
|
||||
<Heading size="lg">
|
||||
{userName ? `${userName}'s Transcriptions` : "Your Transcriptions"}{" "}
|
||||
{loading || (deletionLoading && <Spinner size="sm" />)}
|
||||
{(isLoading || deletionLoading) && <Spinner size="sm" />}
|
||||
</Heading>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir={{ base: "column", md: "row" }}>
|
||||
<FilterSidebar
|
||||
rooms={rooms}
|
||||
selectedSourceKind={selectedSourceKind}
|
||||
selectedRoomId={selectedRoomId}
|
||||
selectedSourceKind={urlSourceKind}
|
||||
selectedRoomId={urlRoomId}
|
||||
onFilterChange={handleFilterTranscripts}
|
||||
/>
|
||||
|
||||
@@ -188,25 +380,37 @@ export default function TranscriptBrowser() {
|
||||
gap={4}
|
||||
px={{ base: 0, md: 4 }}
|
||||
>
|
||||
<SearchBar onSearch={handleSearch} />
|
||||
<Pagination
|
||||
page={page}
|
||||
<SearchForm
|
||||
setPage={setPage}
|
||||
total={response?.total || 0}
|
||||
size={response?.size || 0}
|
||||
/>
|
||||
<TranscriptTable
|
||||
transcripts={response?.items || []}
|
||||
onDelete={handleDeleteTranscript}
|
||||
onReprocess={handleProcessTranscript}
|
||||
loading={loading}
|
||||
sourceKind={urlSourceKind}
|
||||
roomId={urlRoomId}
|
||||
searchQuery={urlSearchQuery}
|
||||
setSearchQuery={setUrlSearchQuery}
|
||||
setSourceKind={setUrlSourceKind}
|
||||
setRoomId={setUrlRoomId}
|
||||
rooms={rooms}
|
||||
/>
|
||||
|
||||
{totalPages > 1 ? (
|
||||
<Pagination
|
||||
page={page}
|
||||
setPage={setPage}
|
||||
total={totalResults}
|
||||
size={pageSize}
|
||||
/>
|
||||
) : null}
|
||||
|
||||
<TranscriptCards
|
||||
transcripts={response?.items || []}
|
||||
onDelete={handleDeleteTranscript}
|
||||
results={results}
|
||||
query={urlSearchQuery}
|
||||
isLoading={isLoading}
|
||||
onDelete={setTranscriptToDeleteId}
|
||||
onReprocess={handleProcessTranscript}
|
||||
loading={loading}
|
||||
/>
|
||||
|
||||
{!isLoading && results.length === 0 && (
|
||||
<EmptyResult searchQuery={urlSearchQuery} />
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import Image from "next/image";
|
||||
import About from "../(aboutAndPrivacy)/about";
|
||||
import Privacy from "../(aboutAndPrivacy)/privacy";
|
||||
import UserInfo from "../(auth)/userInfo";
|
||||
import { RECORD_A_MEETING_URL } from "../api/urls";
|
||||
|
||||
export default async function AppLayout({
|
||||
children,
|
||||
@@ -53,7 +54,7 @@ export default async function AppLayout({
|
||||
{/* Text link on the right */}
|
||||
<Link
|
||||
as={NextLink}
|
||||
href="/transcripts/new"
|
||||
href={RECORD_A_MEETING_URL}
|
||||
className="font-light px-2"
|
||||
>
|
||||
Create
|
||||
|
||||
@@ -19,6 +19,7 @@ import useApi from "../../lib/useApi";
import useRoomList from "./useRoomList";
import { ApiError, Room } from "../../api";
import { RoomList } from "./_components/RoomList";
import { PaginationPage } from "../browse/_components/Pagination";

interface SelectOption {
  label: string;
@@ -75,8 +76,9 @@ export default function RoomsList() {
  const [isEditing, setIsEditing] = useState(false);
  const [editRoomId, setEditRoomId] = useState("");
  const api = useApi();
  // TODO seems to be no setPage calls
  const [page, setPage] = useState<number>(1);
  const { loading, response, refetch } = useRoomList(page);
  const { loading, response, refetch } = useRoomList(PaginationPage(page));
  const [streams, setStreams] = useState<Stream[]>([]);
  const [topics, setTopics] = useState<Topic[]>([]);
  const [nameError, setNameError] = useState("");
@@ -2,6 +2,7 @@ import { useEffect, useState } from "react";
import { useError } from "../../(errors)/errorContext";
import useApi from "../../lib/useApi";
import { Page_Room_ } from "../../api";
import { PaginationPage } from "../browse/_components/Pagination";

type RoomList = {
  response: Page_Room_ | null;
@@ -11,7 +12,7 @@ type RoomList = {
};

//always protected
const useRoomList = (page: number): RoomList => {
const useRoomList = (page: PaginationPage): RoomList => {
  const [response, setResponse] = useState<Page_Room_ | null>(null);
  const [loading, setLoading] = useState<boolean>(true);
  const [error, setErrorState] = useState<Error | null>(null);
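The Pagination helpers imported here are not part of this diff; a plausible sketch, assuming PaginationPage is a 1-based branded number with a validating constructor (purely illustrative, not the repository's actual code):

// Illustrative only; the real definitions live in the browse Pagination component and are not shown in this change.
export type PaginationPage = number & { readonly __brand: "PaginationPage" };

export function PaginationPage(page: number): PaginationPage {
  if (!Number.isInteger(page) || page < 1) {
    throw new Error(`Invalid 1-based page number: ${page}`);
  }
  return page as PaginationPage;
}

export function paginationPageTo0Based(page: PaginationPage): number {
  return page - 1;
}

Under this reading, `useRoomList(PaginationPage(page))` forces callers to validate the 1-based page number at the boundary instead of passing a raw number around.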
@@ -1,32 +0,0 @@
[
  {
    "id": "27c07e49-d7a3-4b86-905c-f1a047366f91",
    "title": "Issue one",
    "summary": "The team discusses the first issue in the list",
    "timestamp": 0.0,
    "transcript": "",
    "duration": 33,
    "segments": [
      {
        "text": "Let's start with issue one, Alice you've been working on that, can you give an update ?",
        "start": 0.0,
        "speaker": 0
      },
      {
        "text": "Yes, I've run into an issue with the task system but Bob helped me out and I have a POC ready, should I present it now ?",
        "start": 0.38,
        "speaker": 1
      },
      {
        "text": "Yeah, I had to modify the task system because it didn't account for incoming blobs",
        "start": 4.5,
        "speaker": 2
      },
      {
        "text": "Cool, yeah lets see it",
        "start": 5.96,
        "speaker": 0
      }
    ]
  }
]
@@ -11,6 +11,7 @@ import useWebRTC from "./useWebRTC";
import useAudioDevice from "./useAudioDevice";
import { Box, Flex, IconButton, Menu, RadioGroup } from "@chakra-ui/react";
import { LuScreenShare, LuMic, LuPlay, LuCircleStop } from "react-icons/lu";
import { RECORD_A_MEETING_URL } from "../../api/urls";

type RecorderProps = {
  transcriptId: string;
@@ -46,7 +47,7 @@ export default function Recorder(props: RecorderProps) {
        location.href = "";
        break;
      case ",":
        location.href = "/transcripts/new";
        location.href = RECORD_A_MEETING_URL;
        break;
      case "!":
        if (record.isRecording()) return;
123
www/app/(app)/transcripts/useSearchTranscripts.ts
Normal file
@@ -0,0 +1,123 @@
// this hook is not great, we want to substitute it with a proper state management solution that is also not re-invention

import { useEffect, useRef, useState } from "react";
import { SearchResult, SourceKind } from "../../api";
import useApi from "../../lib/useApi";
import {
  PaginationPage,
  paginationPageTo0Based,
} from "../browse/_components/Pagination";

interface SearchFilters {
  roomIds: readonly string[] | null;
  sourceKind: SourceKind | null;
}

const EMPTY_SEARCH_FILTERS: SearchFilters = {
  roomIds: null,
  sourceKind: null,
};

type UseSearchTranscriptsOptions = {
  pageSize: number;
  page: PaginationPage;
};

interface UseSearchTranscriptsReturn {
  results: SearchResult[];
  totalCount: number;
  isLoading: boolean;
  error: unknown;
  reload: () => void;
}

function hashEffectFilters(filters: SearchFilters): string {
  return JSON.stringify(filters);
}

export function useSearchTranscripts(
  query: string = "",
  filters: SearchFilters = EMPTY_SEARCH_FILTERS,
  options: UseSearchTranscriptsOptions = {
    pageSize: 20,
    page: PaginationPage(1),
  },
): UseSearchTranscriptsReturn {
  const { pageSize, page } = options;

  const [reloadCount, setReloadCount] = useState(0);

  const api = useApi();
  const abortControllerRef = useRef<AbortController>();

  const [data, setData] = useState<{ results: SearchResult[]; total: number }>({
    results: [],
    total: 0,
  });
  const [error, setError] = useState<any>();
  const [isLoading, setIsLoading] = useState(false);

  const filterHash = hashEffectFilters(filters);

  useEffect(() => {
    if (!api) {
      setData({ results: [], total: 0 });
      setError(undefined);
      setIsLoading(false);
      return;
    }

    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
    }

    const abortController = new AbortController();
    abortControllerRef.current = abortController;

    const performSearch = async () => {
      setIsLoading(true);

      try {
        const response = await api.v1TranscriptsSearch({
          q: query || "",
          limit: pageSize,
          offset: paginationPageTo0Based(page) * pageSize,
          roomId: filters.roomIds?.[0],
          sourceKind: filters.sourceKind || undefined,
        });

        if (abortController.signal.aborted) return;
        setData(response);
        setError(undefined);
      } catch (err: unknown) {
        if ((err as Error).name === "AbortError") {
          return;
        }
        if (abortController.signal.aborted) {
          console.error("Aborted search but error", err);
          return;
        }

        setError(err);
      } finally {
        if (!abortController.signal.aborted) {
          setIsLoading(false);
        }
      }
    };

    performSearch().then(() => {});

    return () => {
      abortController.abort();
    };
  }, [api, query, page, filterHash, pageSize, reloadCount]);

  return {
    results: data.results,
    totalCount: data.total,
    isLoading,
    error,
    reload: () => setReloadCount(reloadCount + 1),
  };
}
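As a reading aid, a minimal consumer of the new hook might look like the following; the component and its markup are hypothetical, only the hook's API (results, totalCount, isLoading, error, reload) comes from the file above:

import { useSearchTranscripts } from "./useSearchTranscripts";
import { PaginationPage } from "../browse/_components/Pagination";

// Hypothetical component: shows a match count for a fixed first page of results.
function SearchSummary({ query }: { query: string }) {
  const { results, totalCount, isLoading, error, reload } = useSearchTranscripts(
    query,
    { roomIds: null, sourceKind: null },
    { pageSize: 20, page: PaginationPage(1) },
  );

  if (error) return <button onClick={reload}>Retry search</button>;
  if (isLoading) return <span>Searching…</span>;
  return (
    <span>
      Showing {results.length} of {totalCount} matches
    </span>
  );
}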
@@ -1,59 +0,0 @@
import { useEffect, useState } from "react";
import { useError } from "../../(errors)/errorContext";
import useApi from "../../lib/useApi";
import { Page_GetTranscriptMinimal_, SourceKind } from "../../api";

type TranscriptList = {
  response: Page_GetTranscriptMinimal_ | null;
  loading: boolean;
  error: Error | null;
  refetch: () => void;
};

const useTranscriptList = (
  page: number,
  sourceKind: SourceKind | null,
  roomId: string | null,
  searchTerm: string | null,
): TranscriptList => {
  const [response, setResponse] = useState<Page_GetTranscriptMinimal_ | null>(
    null,
  );
  const [loading, setLoading] = useState<boolean>(true);
  const [error, setErrorState] = useState<Error | null>(null);
  const { setError } = useError();
  const api = useApi();
  const [refetchCount, setRefetchCount] = useState(0);

  const refetch = () => {
    setLoading(true);
    setRefetchCount(refetchCount + 1);
  };

  useEffect(() => {
    if (!api) return;
    setLoading(true);
    api
      .v1TranscriptsList({
        page,
        sourceKind,
        roomId,
        searchTerm,
        size: 10,
      })
      .then((response) => {
        setResponse(response);
        setLoading(false);
      })
      .catch((err) => {
        setResponse(null);
        setLoading(false);
        setError(err);
        setErrorState(err);
      });
  }, [api, page, refetchCount, roomId, searchTerm, sourceKind]);

  return { response, loading, error, refetch };
};

export default useTranscriptList;
@@ -1002,7 +1002,7 @@ export const $SearchResponse = {
  },
  query: {
    type: "string",
    minLength: 1,
    minLength: 0,
    title: "Query",
    description: "Search query text",
  },
@@ -1065,6 +1065,20 @@ export const $SearchResult = {
      ],
      title: "Room Id",
    },
    room_name: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Room Name",
    },
    source_kind: {
      $ref: "#/components/schemas/SourceKind",
    },
    created_at: {
      type: "string",
      title: "Created At",
@@ -1101,10 +1115,18 @@ export const $SearchResult = {
      title: "Search Snippets",
      description: "Text snippets around search matches",
    },
    total_match_count: {
      type: "integer",
      minimum: 0,
      title: "Total Match Count",
      description: "Total number of matches found in the transcript",
      default: 0,
    },
  },
  type: "object",
  required: [
    "id",
    "source_kind",
    "created_at",
    "status",
    "rank",
@@ -286,6 +286,7 @@ export class DefaultService {
   * @param data.limit Results per page
   * @param data.offset Number of results to skip
   * @param data.roomId
   * @param data.sourceKind
   * @returns SearchResponse Successful Response
   * @throws ApiError
   */
@@ -300,6 +301,7 @@ export class DefaultService {
        limit: data.limit,
        offset: data.offset,
        room_id: data.roomId,
        source_kind: data.sourceKind,
      },
      errors: {
        422: "Validation Error",
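For orientation, the two parameters added here line up with a generated-client call like the one below; the filter values are invented, only the parameter names and response shape come from this change:

// Illustrative call, assuming `api` is an instance of the generated client,
// inside an async function.
const response = await api.v1TranscriptsSearch({
  q: "quarterly planning",
  limit: 20,
  offset: 0,
  roomId: "room-123", // new optional filter
  sourceKind: "room", // new optional filter: "room" | "live" | "file"
});
console.log(response.total, response.results.length);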
@@ -209,6 +209,8 @@ export type SearchResult = {
  title?: string | null;
  user_id?: string | null;
  room_id?: string | null;
  room_name?: string | null;
  source_kind: SourceKind;
  created_at: string;
  status: string;
  rank: number;
@@ -220,6 +222,10 @@ export type SearchResult = {
   * Text snippets around search matches
   */
  search_snippets: Array<string>;
  /**
   * Total number of matches found in the transcript
   */
  total_match_count?: number;
};

export type SourceKind = "room" | "live" | "file";
@@ -407,6 +413,7 @@ export type V1TranscriptsSearchData = {
   */
  q: string;
  roomId?: string | null;
  sourceKind?: SourceKind | null;
};

export type V1TranscriptsSearchResponse = SearchResponse;
2
www/app/api/urls.ts
Normal file
@@ -0,0 +1,2 @@
// TODO better connection with generated schema; it's duplication
export const RECORD_A_MEETING_URL = "/transcripts/new" as const;
62
www/app/lib/textHighlight.tsx
Normal file
@@ -0,0 +1,62 @@
/**
 * Text highlighting and text fragment generation utilities
 * Used for search result highlighting and deep linking with Chrome Text Fragments
 */

import React from "react";

export interface HighlightResult {
  text: string;
  matches: string[];
}

/**
 * Escapes special regex characters in a string
 */
function escapeRegex(str: string): string {
  return str.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}

export const highlightMatches = (
  text: string,
  query: string,
): { match: string; index: number }[] => {
  if (!query || !text) {
    return [];
  }

  const queryWords = query.trim().split(/\s+/);

  const regex = new RegExp(
    `(${queryWords.map((word) => escapeRegex(word)).join("|")})`,
    "gi",
  );

  return Array.from(text.matchAll(regex)).map((result) => ({
    match: result[0],
    index: result.index!,
  }));
};

export function findFirstHighlight(text: string, query: string): string | null {
  const matches = highlightMatches(text, query);
  if (matches.length === 0) {
    return null;
  }
  return matches[0].match;
}

export function generateTextFragment(
  text: string,
  query: string,
): {
  k: ":~:text";
  v: string;
} | null {
  const firstMatch = findFirstHighlight(text, query);
  if (!firstMatch) return null;
  return {
    k: ":~:text",
    v: firstMatch,
  };
}
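For context, here is one way the fragment helper could feed a deep link; the URL assembly below is an illustration, not code from this change:

import { generateTextFragment } from "./textHighlight";

// Builds a link that asks the browser to scroll to and highlight the first
// match of `query` inside `text`, via the `#:~:text=` Text Fragments directive.
function linkToFirstMatch(baseUrl: string, text: string, query: string): string {
  const fragment = generateTextFragment(text, query);
  if (!fragment) return baseUrl;
  return `${baseUrl}#${fragment.k}=${encodeURIComponent(fragment.v)}`;
}

// linkToFirstMatch("/transcripts/abc", "weekly planning notes", "planning")
//   -> "/transcripts/abc#:~:text=planning"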
@@ -136,3 +136,10 @@ export function extractDomain(url) {
    return null;
  }
}

export function assertExists<T>(value: T | null | undefined, err?: string): T {
  if (value === null || value === undefined) {
    throw new Error(`Assertion failed: ${err ?? "value is null or undefined"}`);
  }
  return value;
}
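A quick illustration of how assertExists narrows a nullable value; the variable name and import path are made up for the example:

import { assertExists } from "./utils"; // hypothetical path to the helper above

declare const maybeRoomId: string | null; // stand-in for any nullable value

// Throws with the given message when null/undefined; otherwise narrows to `string`.
const roomId: string = assertExists(maybeRoomId, "roomId missing from URL");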
@@ -1,6 +1,7 @@
"use client";
import { redirect } from "next/navigation";
import { RECORD_A_MEETING_URL } from "./api/urls";

export default function Index() {
  redirect("/transcripts/new");
  redirect(RECORD_A_MEETING_URL);
}
@@ -5,14 +5,17 @@ import system from "./styles/theme";

import { WherebyProvider } from "@whereby.com/browser-sdk/react";
import { Toaster } from "./components/ui/toaster";
import { NuqsAdapter } from "nuqs/adapters/next/app";

export function Providers({ children }: { children: React.ReactNode }) {
  return (
    <ChakraProvider value={system}>
      <WherebyProvider>
        {children}
        <Toaster />
      </WherebyProvider>
    </ChakraProvider>
    <NuqsAdapter>
      <ChakraProvider value={system}>
        <WherebyProvider>
          {children}
          <Toaster />
        </WherebyProvider>
      </ChakraProvider>
    </NuqsAdapter>
  );
}
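Wrapping the tree in NuqsAdapter is what lets pages keep filter state in the URL; a small hedged example of that pattern (the query key and component are invented, not taken from this change):

import { useQueryState } from "nuqs";

// Hypothetical component: keeps its search text in the `?q=` query parameter,
// so the state survives reloads and can be shared as a URL.
function SearchBox() {
  const [q, setQ] = useQueryState("q", { defaultValue: "" });
  return <input value={q} onChange={(e) => setQ(e.target.value)} />;
}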
@@ -31,6 +31,7 @@
    "next": "^14.2.30",
    "next-auth": "^4.24.7",
    "next-themes": "^0.4.6",
    "nuqs": "^2.4.3",
    "postcss": "8.4.31",
    "prop-types": "^15.8.1",
    "react": "^18.2.0",
39
www/pnpm-lock.yaml
generated
@@ -67,6 +67,9 @@ importers:
      next-themes:
        specifier: ^0.4.6
        version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
      nuqs:
        specifier: ^2.4.3
        version: 2.4.3(next@14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0))(react@18.3.1)
      postcss:
        specifier: 8.4.31
        version: 8.4.31
@@ -5436,6 +5439,12 @@ packages:
      }
    engines: { node: ">= 8" }

  mitt@3.0.1:
    resolution:
      {
        integrity: sha512-vKivATfr97l2/QBCYAkXYDbrIWPM2IIKEl7YPhjCvKlG3kE2gm+uBo6nEXK3M5/Ffh/FLpKExzOQ3JJoJGFKBw==,
      }

  mkdirp@0.5.6:
    resolution:
      {
@@ -5660,6 +5669,27 @@ packages:
      }
    deprecated: This package is no longer supported.

  nuqs@2.4.3:
    resolution:
      {
        integrity: sha512-BgtlYpvRwLYiJuWzxt34q2bXu/AIS66sLU1QePIMr2LWkb+XH0vKXdbLSgn9t6p7QKzwI7f38rX3Wl9llTXQ8Q==,
      }
    peerDependencies:
      "@remix-run/react": ">=2"
      next: ">=14.2.0"
      react: ">=18.2.0 || ^19.0.0-0"
      react-router: ^6 || ^7
      react-router-dom: ^6 || ^7
    peerDependenciesMeta:
      "@remix-run/react":
        optional: true
      next:
        optional: true
      react-router:
        optional: true
      react-router-dom:
        optional: true

  nypm@0.5.4:
    resolution:
      {
@@ -11553,6 +11583,8 @@ snapshots:
      minipass: 3.3.6
      yallist: 4.0.0

  mitt@3.0.1: {}

  mkdirp@0.5.6:
    dependencies:
      minimist: 1.2.8
@@ -11674,6 +11706,13 @@ snapshots:
      gauge: 3.0.2
      set-blocking: 2.0.0

  nuqs@2.4.3(next@14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0))(react@18.3.1):
    dependencies:
      mitt: 3.0.1
      react: 18.3.1
    optionalDependencies:
      next: 14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0)

  nypm@0.5.4:
    dependencies:
      citty: 0.1.6