mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
Compare commits
9 Commits
mathieu/ca
...
v0.7.1
| Author | SHA1 | Date | |
|---|---|---|---|
| bc5b351d2b | |||
|
|
07981e8090 | ||
| 7e366f6338 | |||
| 7592679a35 | |||
| af16178f86 | |||
| 3ea7f6b7b6 | |||
|
|
009590c080 | ||
|
|
fe5d344cff | ||
|
|
86455ce573 |
77
.github/workflows/deploy.yml
vendored
77
.github/workflows/deploy.yml
vendored
@@ -8,18 +8,30 @@ env:
|
|||||||
ECR_REPOSITORY: reflector
|
ECR_REPOSITORY: reflector
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
deploy:
|
build:
|
||||||
runs-on: ubuntu-latest
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- platform: linux/amd64
|
||||||
|
runner: linux-amd64
|
||||||
|
arch: amd64
|
||||||
|
- platform: linux/arm64
|
||||||
|
runner: linux-arm64
|
||||||
|
arch: arm64
|
||||||
|
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
deployments: write
|
|
||||||
contents: read
|
contents: read
|
||||||
|
|
||||||
|
outputs:
|
||||||
|
registry: ${{ steps.login-ecr.outputs.registry }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Configure AWS credentials
|
- name: Configure AWS credentials
|
||||||
uses: aws-actions/configure-aws-credentials@0e613a0980cbf65ed5b322eb7a1e075d28913a83
|
uses: aws-actions/configure-aws-credentials@v4
|
||||||
with:
|
with:
|
||||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||||
@@ -27,21 +39,52 @@ jobs:
|
|||||||
|
|
||||||
- name: Login to Amazon ECR
|
- name: Login to Amazon ECR
|
||||||
id: login-ecr
|
id: login-ecr
|
||||||
uses: aws-actions/amazon-ecr-login@62f4f872db3836360b72999f4b87f1ff13310f3a
|
uses: aws-actions/amazon-ecr-login@v2
|
||||||
|
|
||||||
- name: Set up QEMU
|
|
||||||
uses: docker/setup-qemu-action@v2
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
- name: Build and push
|
- name: Build and push ${{ matrix.arch }}
|
||||||
id: docker_build
|
uses: docker/build-push-action@v5
|
||||||
uses: docker/build-push-action@v4
|
|
||||||
with:
|
with:
|
||||||
context: server
|
context: server
|
||||||
platforms: linux/amd64,linux/arm64
|
platforms: ${{ matrix.platform }}
|
||||||
push: true
|
push: true
|
||||||
tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest
|
tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest-${{ matrix.arch }}
|
||||||
cache-from: type=gha
|
cache-from: type=gha,scope=${{ matrix.arch }}
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
|
||||||
|
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
||||||
|
provenance: false
|
||||||
|
|
||||||
|
create-manifest:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [build]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
deployments: write
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Configure AWS credentials
|
||||||
|
uses: aws-actions/configure-aws-credentials@v4
|
||||||
|
with:
|
||||||
|
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||||
|
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||||
|
aws-region: ${{ env.AWS_REGION }}
|
||||||
|
|
||||||
|
- name: Login to Amazon ECR
|
||||||
|
uses: aws-actions/amazon-ecr-login@v2
|
||||||
|
|
||||||
|
- name: Create and push multi-arch manifest
|
||||||
|
run: |
|
||||||
|
# Get the registry URL (since we can't easily access job outputs in matrix)
|
||||||
|
ECR_REGISTRY=$(aws ecr describe-registry --query 'registryId' --output text).dkr.ecr.${{ env.AWS_REGION }}.amazonaws.com
|
||||||
|
|
||||||
|
docker manifest create \
|
||||||
|
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest \
|
||||||
|
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-amd64 \
|
||||||
|
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-arm64
|
||||||
|
|
||||||
|
docker manifest push $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest
|
||||||
|
|
||||||
|
echo "✅ Multi-arch manifest pushed: $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest"
|
||||||
|
|||||||
38
.github/workflows/test_server.yml
vendored
38
.github/workflows/test_server.yml
vendored
@@ -19,29 +19,41 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v3
|
uses: astral-sh/setup-uv@v6
|
||||||
with:
|
with:
|
||||||
enable-cache: true
|
enable-cache: true
|
||||||
working-directory: server
|
working-directory: server
|
||||||
|
|
||||||
- name: Tests
|
- name: Tests
|
||||||
run: |
|
run: |
|
||||||
cd server
|
cd server
|
||||||
uv run -m pytest -v tests
|
uv run -m pytest -v tests
|
||||||
|
|
||||||
docker:
|
docker-amd64:
|
||||||
runs-on: ubuntu-latest
|
runs-on: linux-amd64
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Set up QEMU
|
|
||||||
uses: docker/setup-qemu-action@v2
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Build and push
|
- name: Build AMD64
|
||||||
id: docker_build
|
uses: docker/build-push-action@v6
|
||||||
uses: docker/build-push-action@v4
|
|
||||||
with:
|
with:
|
||||||
context: server
|
context: server
|
||||||
platforms: linux/amd64,linux/arm64
|
platforms: linux/amd64
|
||||||
cache-from: type=gha
|
cache-from: type=gha,scope=amd64
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max,scope=amd64
|
||||||
|
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
||||||
|
|
||||||
|
docker-arm64:
|
||||||
|
runs-on: linux-arm64
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
- name: Build ARM64
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: server
|
||||||
|
platforms: linux/arm64
|
||||||
|
cache-from: type=gha,scope=arm64
|
||||||
|
cache-to: type=gha,mode=max,scope=arm64
|
||||||
|
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
||||||
|
|||||||
24
CHANGELOG.md
24
CHANGELOG.md
@@ -1,5 +1,29 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## [0.7.1](https://github.com/Monadical-SAS/reflector/compare/v0.7.0...v0.7.1) (2025-08-21)
|
||||||
|
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
* webvtt db null expectation mismatch ([#556](https://github.com/Monadical-SAS/reflector/issues/556)) ([e67ad1a](https://github.com/Monadical-SAS/reflector/commit/e67ad1a4a2054467bfeb1e0258fbac5868aaaf21))
|
||||||
|
|
||||||
|
## [0.7.0](https://github.com/Monadical-SAS/reflector/compare/v0.6.1...v0.7.0) (2025-08-21)
|
||||||
|
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* delete recording with transcript ([#547](https://github.com/Monadical-SAS/reflector/issues/547)) ([99cc984](https://github.com/Monadical-SAS/reflector/commit/99cc9840b3f5de01e0adfbfae93234042d706d13))
|
||||||
|
* pipeline improvement with file processing, parakeet, silero-vad ([#540](https://github.com/Monadical-SAS/reflector/issues/540)) ([bcc29c9](https://github.com/Monadical-SAS/reflector/commit/bcc29c9e0050ae215f89d460e9d645aaf6a5e486))
|
||||||
|
* postgresql migration and removal of sqlite in pytest ([#546](https://github.com/Monadical-SAS/reflector/issues/546)) ([cd1990f](https://github.com/Monadical-SAS/reflector/commit/cd1990f8f0fe1503ef5069512f33777a73a93d7f))
|
||||||
|
* search backend ([#537](https://github.com/Monadical-SAS/reflector/issues/537)) ([5f9b892](https://github.com/Monadical-SAS/reflector/commit/5f9b89260c9ef7f3c921319719467df22830453f))
|
||||||
|
* search frontend ([#551](https://github.com/Monadical-SAS/reflector/issues/551)) ([3657242](https://github.com/Monadical-SAS/reflector/commit/365724271ca6e615e3425125a69ae2b46ce39285))
|
||||||
|
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
* evaluation cli event wrap ([#536](https://github.com/Monadical-SAS/reflector/issues/536)) ([941c3db](https://github.com/Monadical-SAS/reflector/commit/941c3db0bdacc7b61fea412f3746cc5a7cb67836))
|
||||||
|
* use structlog not logging ([#550](https://github.com/Monadical-SAS/reflector/issues/550)) ([27e2f81](https://github.com/Monadical-SAS/reflector/commit/27e2f81fda5232e53edc729d3e99c5ef03adbfe9))
|
||||||
|
|
||||||
## [0.6.1](https://github.com/Monadical-SAS/reflector/compare/v0.6.0...v0.6.1) (2025-08-06)
|
## [0.6.1](https://github.com/Monadical-SAS/reflector/compare/v0.6.0...v0.6.1) (2025-08-06)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
3
server/.gitignore
vendored
3
server/.gitignore
vendored
@@ -176,7 +176,8 @@ artefacts/
|
|||||||
audio_*.wav
|
audio_*.wav
|
||||||
|
|
||||||
# ignore local database
|
# ignore local database
|
||||||
reflector.sqlite3
|
*.sqlite3
|
||||||
|
*.db
|
||||||
data/
|
data/
|
||||||
|
|
||||||
dump.rdb
|
dump.rdb
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
FROM python:3.12-slim
|
FROM python:3.12-slim
|
||||||
|
|
||||||
ENV PYTHONUNBUFFERED=1 \
|
ENV PYTHONUNBUFFERED=1 \
|
||||||
UV_LINK_MODE=copy
|
UV_LINK_MODE=copy \
|
||||||
|
UV_NO_CACHE=1
|
||||||
|
|
||||||
# builder install base dependencies
|
# builder install base dependencies
|
||||||
WORKDIR /tmp
|
WORKDIR /tmp
|
||||||
@@ -13,8 +14,8 @@ ENV PATH="/root/.local/bin/:$PATH"
|
|||||||
# install application dependencies
|
# install application dependencies
|
||||||
RUN mkdir -p /app
|
RUN mkdir -p /app
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
COPY pyproject.toml uv.lock /app/
|
COPY pyproject.toml uv.lock README.md /app/
|
||||||
RUN touch README.md && env uv sync --compile-bytecode --locked
|
RUN uv sync --compile-bytecode --locked
|
||||||
|
|
||||||
# pre-download nltk packages
|
# pre-download nltk packages
|
||||||
RUN uv run python -c "import nltk; nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
|
RUN uv run python -c "import nltk; nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
|
||||||
|
|||||||
@@ -40,3 +40,5 @@ uv run python -c "from reflector.pipelines.main_live_pipeline import task_pipeli
|
|||||||
```bash
|
```bash
|
||||||
uv run python -c "from reflector.pipelines.main_live_pipeline import pipeline_post; pipeline_post(transcript_id='TRANSCRIPT_ID')"
|
uv run python -c "from reflector.pipelines.main_live_pipeline import pipeline_post; pipeline_post(transcript_id='TRANSCRIPT_ID')"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
.
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ This repository hold an API for the GPU implementation of the Reflector API serv
|
|||||||
and use [Modal.com](https://modal.com)
|
and use [Modal.com](https://modal.com)
|
||||||
|
|
||||||
- `reflector_diarizer.py` - Diarization API
|
- `reflector_diarizer.py` - Diarization API
|
||||||
- `reflector_transcriber.py` - Transcription API
|
- `reflector_transcriber.py` - Transcription API (Whisper)
|
||||||
|
- `reflector_transcriber_parakeet.py` - Transcription API (NVIDIA Parakeet)
|
||||||
- `reflector_translator.py` - Translation API
|
- `reflector_translator.py` - Translation API
|
||||||
|
|
||||||
## Modal.com deployment
|
## Modal.com deployment
|
||||||
@@ -19,6 +20,10 @@ $ modal deploy reflector_transcriber.py
|
|||||||
...
|
...
|
||||||
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run
|
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run
|
||||||
|
|
||||||
|
$ modal deploy reflector_transcriber_parakeet.py
|
||||||
|
...
|
||||||
|
└── 🔨 Created web => https://xxxx--reflector-transcriber-parakeet-web.modal.run
|
||||||
|
|
||||||
$ modal deploy reflector_llm.py
|
$ modal deploy reflector_llm.py
|
||||||
...
|
...
|
||||||
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
|
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
|
||||||
@@ -68,6 +73,86 @@ Authorization: bearer <REFLECTOR_APIKEY>
|
|||||||
|
|
||||||
### Transcription
|
### Transcription
|
||||||
|
|
||||||
|
#### Parakeet Transcriber (`reflector_transcriber_parakeet.py`)
|
||||||
|
|
||||||
|
NVIDIA Parakeet is a state-of-the-art ASR model optimized for real-time transcription with superior word-level timestamps.
|
||||||
|
|
||||||
|
**GPU Configuration:**
|
||||||
|
- **A10G GPU** - Used for `/v1/audio/transcriptions` endpoint (small files, live transcription)
|
||||||
|
- Higher concurrency (max_inputs=10)
|
||||||
|
- Optimized for multiple small audio files
|
||||||
|
- Supports batch processing for efficiency
|
||||||
|
|
||||||
|
- **L40S GPU** - Used for `/v1/audio/transcriptions-from-url` endpoint (large files)
|
||||||
|
- Lower concurrency but more powerful processing
|
||||||
|
- Optimized for single large audio files
|
||||||
|
- VAD-based chunking for long-form audio
|
||||||
|
|
||||||
|
##### `/v1/audio/transcriptions` - Small file transcription
|
||||||
|
|
||||||
|
**request** (multipart/form-data)
|
||||||
|
- `file` or `files[]` - audio file(s) to transcribe
|
||||||
|
- `model` - model name (default: `nvidia/parakeet-tdt-0.6b-v2`)
|
||||||
|
- `language` - language code (default: `en`)
|
||||||
|
- `batch` - whether to use batch processing for multiple files (default: `true`)
|
||||||
|
|
||||||
|
**response**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"text": "transcribed text",
|
||||||
|
"words": [
|
||||||
|
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||||
|
{"word": "world", "start": 0.5, "end": 1.0}
|
||||||
|
],
|
||||||
|
"filename": "audio.mp3"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
For multiple files with batch=true:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"filename": "audio1.mp3",
|
||||||
|
"text": "transcribed text",
|
||||||
|
"words": [...]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"filename": "audio2.mp3",
|
||||||
|
"text": "transcribed text",
|
||||||
|
"words": [...]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
##### `/v1/audio/transcriptions-from-url` - Large file transcription
|
||||||
|
|
||||||
|
**request** (application/json)
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"audio_file_url": "https://example.com/audio.mp3",
|
||||||
|
"model": "nvidia/parakeet-tdt-0.6b-v2",
|
||||||
|
"language": "en",
|
||||||
|
"timestamp_offset": 0.0
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**response**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"text": "transcribed text from large file",
|
||||||
|
"words": [
|
||||||
|
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||||
|
{"word": "world", "start": 0.5, "end": 1.0}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Supported file types:** mp3, mp4, mpeg, mpga, m4a, wav, webm
|
||||||
|
|
||||||
|
#### Whisper Transcriber (`reflector_transcriber.py`)
|
||||||
|
|
||||||
`POST /transcribe`
|
`POST /transcribe`
|
||||||
|
|
||||||
**request** (multipart/form-data)
|
**request** (multipart/form-data)
|
||||||
|
|||||||
@@ -4,14 +4,80 @@ Reflector GPU backend - diarizer
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
|
from typing import Mapping, NewType
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import modal.gpu
|
import modal
|
||||||
from modal import App, Image, Secret, asgi_app, enter, method
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
|
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
|
||||||
MODEL_DIR = "/root/diarization_models"
|
MODEL_DIR = "/root/diarization_models"
|
||||||
app = App(name="reflector-diarizer")
|
UPLOADS_PATH = "/uploads"
|
||||||
|
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||||
|
|
||||||
|
DiarizerUniqFilename = NewType("DiarizerUniqFilename", str)
|
||||||
|
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||||
|
|
||||||
|
app = modal.App(name="reflector-diarizer")
|
||||||
|
|
||||||
|
# Volume for temporary file uploads
|
||||||
|
upload_volume = modal.Volume.from_name("diarizer-uploads", create_if_missing=True)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
url_path = parsed_url.path
|
||||||
|
|
||||||
|
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||||
|
if url_path.lower().endswith(f".{ext}"):
|
||||||
|
return AudioFileExtension(ext)
|
||||||
|
|
||||||
|
content_type = headers.get("content-type", "").lower()
|
||||||
|
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||||
|
return AudioFileExtension("mp3")
|
||||||
|
if "audio/wav" in content_type:
|
||||||
|
return AudioFileExtension("wav")
|
||||||
|
if "audio/mp4" in content_type:
|
||||||
|
return AudioFileExtension("mp4")
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported audio format for URL: {url}. "
|
||||||
|
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def download_audio_to_volume(
|
||||||
|
audio_file_url: str,
|
||||||
|
) -> tuple[DiarizerUniqFilename, AudioFileExtension]:
|
||||||
|
import requests
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
print(f"Checking audio file at: {audio_file_url}")
|
||||||
|
response = requests.head(audio_file_url, allow_redirects=True)
|
||||||
|
if response.status_code == 404:
|
||||||
|
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||||
|
|
||||||
|
print(f"Downloading audio file from: {audio_file_url}")
|
||||||
|
response = requests.get(audio_file_url, allow_redirects=True)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"Download failed with status {response.status_code}: {response.text}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=response.status_code,
|
||||||
|
detail=f"Failed to download audio file: {response.status_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||||
|
unique_filename = DiarizerUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||||
|
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||||
|
|
||||||
|
print(f"Writing file to: {file_path} (size: {len(response.content)} bytes)")
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
upload_volume.commit()
|
||||||
|
print(f"File saved as: {unique_filename}")
|
||||||
|
return unique_filename, audio_suffix
|
||||||
|
|
||||||
|
|
||||||
def migrate_cache_llm():
|
def migrate_cache_llm():
|
||||||
@@ -39,7 +105,7 @@ def download_pyannote_audio():
|
|||||||
|
|
||||||
|
|
||||||
diarizer_image = (
|
diarizer_image = (
|
||||||
Image.debian_slim(python_version="3.10.8")
|
modal.Image.debian_slim(python_version="3.10.8")
|
||||||
.pip_install(
|
.pip_install(
|
||||||
"pyannote.audio==3.1.0",
|
"pyannote.audio==3.1.0",
|
||||||
"requests",
|
"requests",
|
||||||
@@ -55,7 +121,8 @@ diarizer_image = (
|
|||||||
"hf-transfer",
|
"hf-transfer",
|
||||||
)
|
)
|
||||||
.run_function(
|
.run_function(
|
||||||
download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")]
|
download_pyannote_audio,
|
||||||
|
secrets=[modal.Secret.from_name("hf_token")],
|
||||||
)
|
)
|
||||||
.run_function(migrate_cache_llm)
|
.run_function(migrate_cache_llm)
|
||||||
.env(
|
.env(
|
||||||
@@ -70,44 +137,51 @@ diarizer_image = (
|
|||||||
|
|
||||||
|
|
||||||
@app.cls(
|
@app.cls(
|
||||||
gpu=modal.gpu.A100(size="40GB"),
|
gpu="A100",
|
||||||
timeout=60 * 30,
|
timeout=60 * 30,
|
||||||
scaledown_window=60,
|
|
||||||
allow_concurrent_inputs=1,
|
|
||||||
image=diarizer_image,
|
image=diarizer_image,
|
||||||
|
volumes={UPLOADS_PATH: upload_volume},
|
||||||
|
enable_memory_snapshot=True,
|
||||||
|
experimental_options={"enable_gpu_snapshot": True},
|
||||||
|
secrets=[
|
||||||
|
modal.Secret.from_name("hf_token"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
@modal.concurrent(max_inputs=1)
|
||||||
class Diarizer:
|
class Diarizer:
|
||||||
@enter()
|
@modal.enter(snap=True)
|
||||||
def enter(self):
|
def enter(self):
|
||||||
import torch
|
import torch
|
||||||
from pyannote.audio import Pipeline
|
from pyannote.audio import Pipeline
|
||||||
|
|
||||||
self.use_gpu = torch.cuda.is_available()
|
self.use_gpu = torch.cuda.is_available()
|
||||||
self.device = "cuda" if self.use_gpu else "cpu"
|
self.device = "cuda" if self.use_gpu else "cpu"
|
||||||
|
print(f"Using device: {self.device}")
|
||||||
self.diarization_pipeline = Pipeline.from_pretrained(
|
self.diarization_pipeline = Pipeline.from_pretrained(
|
||||||
PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR
|
PYANNOTE_MODEL_NAME,
|
||||||
|
cache_dir=MODEL_DIR,
|
||||||
|
use_auth_token=os.environ["HF_TOKEN"],
|
||||||
)
|
)
|
||||||
self.diarization_pipeline.to(torch.device(self.device))
|
self.diarization_pipeline.to(torch.device(self.device))
|
||||||
|
|
||||||
@method()
|
@modal.method()
|
||||||
def diarize(self, audio_data: str, audio_suffix: str, timestamp: float):
|
def diarize(self, filename: str, timestamp: float = 0.0):
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
|
upload_volume.reload()
|
||||||
fp.write(audio_data)
|
|
||||||
|
|
||||||
print("Diarizing audio")
|
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||||
waveform, sample_rate = torchaudio.load(fp.name)
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
print(f"Diarizing audio from: {file_path}")
|
||||||
|
waveform, sample_rate = torchaudio.load(file_path)
|
||||||
diarization = self.diarization_pipeline(
|
diarization = self.diarization_pipeline(
|
||||||
{"waveform": waveform, "sample_rate": sample_rate}
|
{"waveform": waveform, "sample_rate": sample_rate}
|
||||||
)
|
)
|
||||||
|
|
||||||
words = []
|
words = []
|
||||||
for diarization_segment, _, speaker in diarization.itertracks(
|
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||||
yield_label=True
|
|
||||||
):
|
|
||||||
words.append(
|
words.append(
|
||||||
{
|
{
|
||||||
"start": round(timestamp + diarization_segment.start, 3),
|
"start": round(timestamp + diarization_segment.start, 3),
|
||||||
@@ -127,17 +201,18 @@ class Diarizer:
|
|||||||
@app.function(
|
@app.function(
|
||||||
timeout=60 * 10,
|
timeout=60 * 10,
|
||||||
scaledown_window=60 * 3,
|
scaledown_window=60 * 3,
|
||||||
allow_concurrent_inputs=40,
|
|
||||||
secrets=[
|
secrets=[
|
||||||
Secret.from_name("reflector-gpu"),
|
modal.Secret.from_name("reflector-gpu"),
|
||||||
],
|
],
|
||||||
|
volumes={UPLOADS_PATH: upload_volume},
|
||||||
image=diarizer_image,
|
image=diarizer_image,
|
||||||
)
|
)
|
||||||
@asgi_app()
|
@modal.concurrent(max_inputs=40)
|
||||||
|
@modal.asgi_app()
|
||||||
def web():
|
def web():
|
||||||
import requests
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
from fastapi import Depends, FastAPI, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
diarizerstub = Diarizer()
|
diarizerstub = Diarizer()
|
||||||
|
|
||||||
@@ -153,35 +228,26 @@ def web():
|
|||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_audio_file(audio_file_url: str):
|
|
||||||
# Check if the audio file exists
|
|
||||||
response = requests.head(audio_file_url, allow_redirects=True)
|
|
||||||
if response.status_code == 404:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=response.status_code,
|
|
||||||
detail="The audio file does not exist.",
|
|
||||||
)
|
|
||||||
|
|
||||||
class DiarizationResponse(BaseModel):
|
class DiarizationResponse(BaseModel):
|
||||||
result: dict
|
result: dict
|
||||||
|
|
||||||
@app.post(
|
@app.post("/diarize", dependencies=[Depends(apikey_auth)])
|
||||||
"/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]
|
def diarize(audio_file_url: str, timestamp: float = 0.0) -> DiarizationResponse:
|
||||||
)
|
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
|
||||||
def diarize(
|
|
||||||
audio_file_url: str, timestamp: float = 0.0
|
|
||||||
) -> HTTPException | DiarizationResponse:
|
|
||||||
# Currently the uploaded files are in mp3 format
|
|
||||||
audio_suffix = "mp3"
|
|
||||||
|
|
||||||
print("Downloading audio file")
|
|
||||||
response = requests.get(audio_file_url, allow_redirects=True)
|
|
||||||
print("Audio file downloaded successfully")
|
|
||||||
|
|
||||||
|
try:
|
||||||
func = diarizerstub.diarize.spawn(
|
func = diarizerstub.diarize.spawn(
|
||||||
audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp
|
filename=unique_filename, timestamp=timestamp
|
||||||
)
|
)
|
||||||
result = func.get()
|
result = func.get()
|
||||||
return result
|
return result
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||||
|
print(f"Deleting file: {file_path}")
|
||||||
|
os.remove(file_path)
|
||||||
|
upload_volume.commit()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error cleaning up {unique_filename}: {e}")
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|||||||
622
server/gpu/modal_deployments/reflector_transcriber_parakeet.py
Normal file
622
server/gpu/modal_deployments/reflector_transcriber_parakeet.py
Normal file
@@ -0,0 +1,622 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
from typing import Mapping, NewType
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import modal
|
||||||
|
|
||||||
|
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
|
||||||
|
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||||
|
SAMPLERATE = 16000
|
||||||
|
UPLOADS_PATH = "/uploads"
|
||||||
|
CACHE_PATH = "/cache"
|
||||||
|
VAD_CONFIG = {
|
||||||
|
"max_segment_duration": 30.0,
|
||||||
|
"batch_max_files": 10,
|
||||||
|
"batch_max_duration": 5.0,
|
||||||
|
"min_segment_duration": 0.02,
|
||||||
|
"silence_padding": 0.5,
|
||||||
|
"window_size": 512,
|
||||||
|
}
|
||||||
|
|
||||||
|
ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
|
||||||
|
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||||
|
|
||||||
|
app = modal.App("reflector-transcriber-parakeet")
|
||||||
|
|
||||||
|
# Volume for caching model weights
|
||||||
|
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
|
||||||
|
# Volume for temporary file uploads
|
||||||
|
upload_volume = modal.Volume.from_name("parakeet-uploads", create_if_missing=True)
|
||||||
|
|
||||||
|
image = (
|
||||||
|
modal.Image.from_registry(
|
||||||
|
"nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", add_python="3.12"
|
||||||
|
)
|
||||||
|
.env(
|
||||||
|
{
|
||||||
|
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||||
|
"HF_HOME": "/cache",
|
||||||
|
"DEBIAN_FRONTEND": "noninteractive",
|
||||||
|
"CXX": "g++",
|
||||||
|
"CC": "g++",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.apt_install("ffmpeg")
|
||||||
|
.pip_install(
|
||||||
|
"hf_transfer==0.1.9",
|
||||||
|
"huggingface_hub[hf-xet]==0.31.2",
|
||||||
|
"nemo_toolkit[asr]==2.3.0",
|
||||||
|
"cuda-python==12.8.0",
|
||||||
|
"fastapi==0.115.12",
|
||||||
|
"numpy<2",
|
||||||
|
"librosa==0.10.1",
|
||||||
|
"requests",
|
||||||
|
"silero-vad==5.1.0",
|
||||||
|
"torch",
|
||||||
|
)
|
||||||
|
.entrypoint([]) # silence chatty logs by container on start
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
url_path = parsed_url.path
|
||||||
|
|
||||||
|
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||||
|
if url_path.lower().endswith(f".{ext}"):
|
||||||
|
return AudioFileExtension(ext)
|
||||||
|
|
||||||
|
content_type = headers.get("content-type", "").lower()
|
||||||
|
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||||
|
return AudioFileExtension("mp3")
|
||||||
|
if "audio/wav" in content_type:
|
||||||
|
return AudioFileExtension("wav")
|
||||||
|
if "audio/mp4" in content_type:
|
||||||
|
return AudioFileExtension("mp4")
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported audio format for URL: {url}. "
|
||||||
|
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def download_audio_to_volume(
|
||||||
|
audio_file_url: str,
|
||||||
|
) -> tuple[ParakeetUniqFilename, AudioFileExtension]:
|
||||||
|
import requests
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
response = requests.head(audio_file_url, allow_redirects=True)
|
||||||
|
if response.status_code == 404:
|
||||||
|
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||||
|
|
||||||
|
response = requests.get(audio_file_url, allow_redirects=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||||
|
unique_filename = ParakeetUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||||
|
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||||
|
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
upload_volume.commit()
|
||||||
|
return unique_filename, audio_suffix
|
||||||
|
|
||||||
|
|
||||||
|
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
|
||||||
|
"""Add 0.5 seconds of silence if audio is less than 500ms.
|
||||||
|
|
||||||
|
This is a workaround for a Parakeet bug where very short audio (<500ms) causes:
|
||||||
|
ValueError: `char_offsets`: [] and `processed_tokens`: [157, 834, 834, 841]
|
||||||
|
have to be of the same length
|
||||||
|
|
||||||
|
See: https://github.com/NVIDIA/NeMo/issues/8451
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
audio_duration = len(audio_array) / sample_rate
|
||||||
|
if audio_duration < 0.5:
|
||||||
|
silence_samples = int(sample_rate * 0.5)
|
||||||
|
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||||
|
return np.concatenate([audio_array, silence])
|
||||||
|
return audio_array
|
||||||
|
|
||||||
|
|
||||||
|
@app.cls(
|
||||||
|
gpu="A10G",
|
||||||
|
timeout=600,
|
||||||
|
scaledown_window=300,
|
||||||
|
image=image,
|
||||||
|
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||||
|
enable_memory_snapshot=True,
|
||||||
|
experimental_options={"enable_gpu_snapshot": True},
|
||||||
|
)
|
||||||
|
@modal.concurrent(max_inputs=10)
|
||||||
|
class TranscriberParakeetLive:
|
||||||
|
@modal.enter(snap=True)
|
||||||
|
def enter(self):
|
||||||
|
import nemo.collections.asr as nemo_asr
|
||||||
|
|
||||||
|
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
|
||||||
|
device = next(self.model.parameters()).device
|
||||||
|
print(f"Model is on device: {device}")
|
||||||
|
|
||||||
|
@modal.method()
|
||||||
|
def transcribe_segment(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
):
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
upload_volume.reload()
|
||||||
|
|
||||||
|
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||||
|
padded_audio = pad_audio(audio_array, sample_rate)
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
with NoStdStreams():
|
||||||
|
(output,) = self.model.transcribe([padded_audio], timestamps=True)
|
||||||
|
|
||||||
|
text = output.text.strip()
|
||||||
|
words = [
|
||||||
|
{
|
||||||
|
"word": word_info["word"],
|
||||||
|
"start": round(word_info["start"], 2),
|
||||||
|
"end": round(word_info["end"], 2),
|
||||||
|
}
|
||||||
|
for word_info in output.timestamp["word"]
|
||||||
|
]
|
||||||
|
|
||||||
|
return {"text": text, "words": words}
|
||||||
|
|
||||||
|
@modal.method()
|
||||||
|
def transcribe_batch(
|
||||||
|
self,
|
||||||
|
filenames: list[str],
|
||||||
|
):
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
upload_volume.reload()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
audio_arrays = []
|
||||||
|
|
||||||
|
# Load all audio files with padding
|
||||||
|
for filename in filenames:
|
||||||
|
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"Batch file not found: {file_path}")
|
||||||
|
|
||||||
|
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||||
|
padded_audio = pad_audio(audio_array, sample_rate)
|
||||||
|
audio_arrays.append(padded_audio)
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
with NoStdStreams():
|
||||||
|
outputs = self.model.transcribe(audio_arrays, timestamps=True)
|
||||||
|
|
||||||
|
# Process results for each file
|
||||||
|
for i, (filename, output) in enumerate(zip(filenames, outputs)):
|
||||||
|
text = output.text.strip()
|
||||||
|
|
||||||
|
words = [
|
||||||
|
{
|
||||||
|
"word": word_info["word"],
|
||||||
|
"start": round(word_info["start"], 2),
|
||||||
|
"end": round(word_info["end"], 2),
|
||||||
|
}
|
||||||
|
for word_info in output.timestamp["word"]
|
||||||
|
]
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"filename": filename,
|
||||||
|
"text": text,
|
||||||
|
"words": words,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# L40S class for file transcription (bigger files)
|
||||||
|
@app.cls(
|
||||||
|
gpu="L40S",
|
||||||
|
timeout=900,
|
||||||
|
image=image,
|
||||||
|
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||||
|
enable_memory_snapshot=True,
|
||||||
|
experimental_options={"enable_gpu_snapshot": True},
|
||||||
|
)
|
||||||
|
class TranscriberParakeetFile:
|
||||||
|
@modal.enter(snap=True)
|
||||||
|
def enter(self):
|
||||||
|
import nemo.collections.asr as nemo_asr
|
||||||
|
import torch
|
||||||
|
from silero_vad import load_silero_vad
|
||||||
|
|
||||||
|
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
|
||||||
|
device = next(self.model.parameters()).device
|
||||||
|
print(f"Model is on device: {device}")
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
self.vad_model = load_silero_vad(onnx=False)
|
||||||
|
print("Silero VAD initialized")
|
||||||
|
|
||||||
|
@modal.method()
|
||||||
|
def transcribe_segment(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
timestamp_offset: float = 0.0,
|
||||||
|
):
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
from silero_vad import VADIterator
|
||||||
|
|
||||||
|
def load_and_convert_audio(file_path):
|
||||||
|
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||||
|
return audio_array
|
||||||
|
|
||||||
|
def vad_segment_generator(audio_array):
|
||||||
|
"""Generate speech segments using VAD with start/end sample indices"""
|
||||||
|
vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
|
||||||
|
window_size = VAD_CONFIG["window_size"]
|
||||||
|
start = None
|
||||||
|
|
||||||
|
for i in range(0, len(audio_array), window_size):
|
||||||
|
chunk = audio_array[i : i + window_size]
|
||||||
|
if len(chunk) < window_size:
|
||||||
|
chunk = np.pad(
|
||||||
|
chunk, (0, window_size - len(chunk)), mode="constant"
|
||||||
|
)
|
||||||
|
|
||||||
|
speech_dict = vad_iterator(chunk)
|
||||||
|
if not speech_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "start" in speech_dict:
|
||||||
|
start = speech_dict["start"]
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "end" in speech_dict and start is not None:
|
||||||
|
end = speech_dict["end"]
|
||||||
|
start_time = start / float(SAMPLERATE)
|
||||||
|
end_time = end / float(SAMPLERATE)
|
||||||
|
|
||||||
|
# Extract the actual audio segment
|
||||||
|
audio_segment = audio_array[start:end]
|
||||||
|
|
||||||
|
yield (start_time, end_time, audio_segment)
|
||||||
|
start = None
|
||||||
|
|
||||||
|
vad_iterator.reset_states()
|
||||||
|
|
||||||
|
def vad_segment_filter(segments):
|
||||||
|
"""Filter VAD segments by duration and chunk large segments"""
|
||||||
|
min_dur = VAD_CONFIG["min_segment_duration"]
|
||||||
|
max_dur = VAD_CONFIG["max_segment_duration"]
|
||||||
|
|
||||||
|
for start_time, end_time, audio_segment in segments:
|
||||||
|
segment_duration = end_time - start_time
|
||||||
|
|
||||||
|
# Skip very small segments
|
||||||
|
if segment_duration < min_dur:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If segment is within max duration, yield as-is
|
||||||
|
if segment_duration <= max_dur:
|
||||||
|
yield (start_time, end_time, audio_segment)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Chunk large segments into smaller pieces
|
||||||
|
chunk_samples = int(max_dur * SAMPLERATE)
|
||||||
|
current_start = start_time
|
||||||
|
|
||||||
|
for chunk_offset in range(0, len(audio_segment), chunk_samples):
|
||||||
|
chunk_audio = audio_segment[
|
||||||
|
chunk_offset : chunk_offset + chunk_samples
|
||||||
|
]
|
||||||
|
if len(chunk_audio) == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
chunk_duration = len(chunk_audio) / float(SAMPLERATE)
|
||||||
|
chunk_end = current_start + chunk_duration
|
||||||
|
|
||||||
|
# Only yield chunks that meet minimum duration
|
||||||
|
if chunk_duration >= min_dur:
|
||||||
|
yield (current_start, chunk_end, chunk_audio)
|
||||||
|
|
||||||
|
current_start = chunk_end
|
||||||
|
|
||||||
|
def batch_segments(segments, max_files=10, max_duration=5.0):
|
||||||
|
batch = []
|
||||||
|
batch_duration = 0.0
|
||||||
|
|
||||||
|
for start_time, end_time, audio_segment in segments:
|
||||||
|
segment_duration = end_time - start_time
|
||||||
|
|
||||||
|
if segment_duration < VAD_CONFIG["silence_padding"]:
|
||||||
|
silence_samples = int(
|
||||||
|
(VAD_CONFIG["silence_padding"] - segment_duration) * SAMPLERATE
|
||||||
|
)
|
||||||
|
padding = np.zeros(silence_samples, dtype=np.float32)
|
||||||
|
audio_segment = np.concatenate([audio_segment, padding])
|
||||||
|
segment_duration = VAD_CONFIG["silence_padding"]
|
||||||
|
|
||||||
|
batch.append((start_time, end_time, audio_segment))
|
||||||
|
batch_duration += segment_duration
|
||||||
|
|
||||||
|
if len(batch) >= max_files or batch_duration >= max_duration:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
batch_duration = 0.0
|
||||||
|
|
||||||
|
if batch:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
def transcribe_batch(model, audio_segments):
|
||||||
|
with NoStdStreams():
|
||||||
|
outputs = model.transcribe(audio_segments, timestamps=True)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def emit_results(
|
||||||
|
results,
|
||||||
|
segments_info,
|
||||||
|
batch_index,
|
||||||
|
total_batches,
|
||||||
|
):
|
||||||
|
"""Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
|
||||||
|
for i, (output, (start_time, end_time, _)) in enumerate(
|
||||||
|
zip(results, segments_info)
|
||||||
|
):
|
||||||
|
text = output.text.strip()
|
||||||
|
words = [
|
||||||
|
{
|
||||||
|
"word": word_info["word"],
|
||||||
|
"start": round(
|
||||||
|
word_info["start"] + start_time + timestamp_offset, 2
|
||||||
|
),
|
||||||
|
"end": round(
|
||||||
|
word_info["end"] + start_time + timestamp_offset, 2
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for word_info in output.timestamp["word"]
|
||||||
|
]
|
||||||
|
|
||||||
|
yield text, words
|
||||||
|
|
||||||
|
upload_volume.reload()
|
||||||
|
|
||||||
|
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
audio_array = load_and_convert_audio(file_path)
|
||||||
|
total_duration = len(audio_array) / float(SAMPLERATE)
|
||||||
|
processed_duration = 0.0
|
||||||
|
|
||||||
|
all_text_parts = []
|
||||||
|
all_words = []
|
||||||
|
|
||||||
|
raw_segments = vad_segment_generator(audio_array)
|
||||||
|
filtered_segments = vad_segment_filter(raw_segments)
|
||||||
|
batches = batch_segments(
|
||||||
|
filtered_segments,
|
||||||
|
VAD_CONFIG["batch_max_files"],
|
||||||
|
VAD_CONFIG["batch_max_duration"],
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_index = 0
|
||||||
|
total_batches = max(
|
||||||
|
1, int(total_duration / VAD_CONFIG["batch_max_duration"]) + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch in batches:
|
||||||
|
batch_index += 1
|
||||||
|
audio_segments = [seg[2] for seg in batch]
|
||||||
|
results = transcribe_batch(self.model, audio_segments)
|
||||||
|
|
||||||
|
for text, words in emit_results(
|
||||||
|
results,
|
||||||
|
batch,
|
||||||
|
batch_index,
|
||||||
|
total_batches,
|
||||||
|
):
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
all_text_parts.append(text)
|
||||||
|
all_words.extend(words)
|
||||||
|
|
||||||
|
processed_duration += sum(len(seg[2]) / float(SAMPLERATE) for seg in batch)
|
||||||
|
|
||||||
|
combined_text = " ".join(all_text_parts)
|
||||||
|
return {"text": combined_text, "words": all_words}
|
||||||
|
|
||||||
|
|
||||||
|
@app.function(
|
||||||
|
scaledown_window=60,
|
||||||
|
timeout=600,
|
||||||
|
secrets=[
|
||||||
|
modal.Secret.from_name("reflector-gpu"),
|
||||||
|
],
|
||||||
|
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||||
|
image=image,
|
||||||
|
)
|
||||||
|
@modal.concurrent(max_inputs=40)
|
||||||
|
@modal.asgi_app()
|
||||||
|
def web():
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import (
|
||||||
|
Body,
|
||||||
|
Depends,
|
||||||
|
FastAPI,
|
||||||
|
Form,
|
||||||
|
HTTPException,
|
||||||
|
UploadFile,
|
||||||
|
status,
|
||||||
|
)
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
transcriber_live = TranscriberParakeetLive()
|
||||||
|
transcriber_file = TranscriberParakeetFile()
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||||
|
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||||
|
return
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid API key",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
class TranscriptResponse(BaseModel):
|
||||||
|
result: dict
|
||||||
|
|
||||||
|
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
||||||
|
def transcribe(
|
||||||
|
file: UploadFile = None,
|
||||||
|
files: list[UploadFile] | None = None,
|
||||||
|
model: str = Form(MODEL_NAME),
|
||||||
|
language: str = Form("en"),
|
||||||
|
batch: bool = Form(False),
|
||||||
|
):
|
||||||
|
# Parakeet only supports English
|
||||||
|
if language != "en":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Parakeet model only supports English. Got language='{language}'",
|
||||||
|
)
|
||||||
|
# Handle both single file and multiple files
|
||||||
|
if not file and not files:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
||||||
|
)
|
||||||
|
if batch and not files:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Batch transcription requires 'files'"
|
||||||
|
)
|
||||||
|
|
||||||
|
upload_files = [file] if file else files
|
||||||
|
|
||||||
|
# Upload files to volume
|
||||||
|
uploaded_filenames = []
|
||||||
|
for upload_file in upload_files:
|
||||||
|
audio_suffix = upload_file.filename.split(".")[-1]
|
||||||
|
assert audio_suffix in SUPPORTED_FILE_EXTENSIONS
|
||||||
|
|
||||||
|
# Generate unique filename
|
||||||
|
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||||
|
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||||
|
|
||||||
|
print(f"Writing file to: {file_path}")
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
content = upload_file.file.read()
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
uploaded_filenames.append(unique_filename)
|
||||||
|
|
||||||
|
upload_volume.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use A10G live transcriber for per-file transcription
|
||||||
|
if batch and len(upload_files) > 1:
|
||||||
|
# Use batch transcription
|
||||||
|
func = transcriber_live.transcribe_batch.spawn(
|
||||||
|
filenames=uploaded_filenames,
|
||||||
|
)
|
||||||
|
results = func.get()
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
|
# Per-file transcription
|
||||||
|
results = []
|
||||||
|
for filename in uploaded_filenames:
|
||||||
|
func = transcriber_live.transcribe_segment.spawn(
|
||||||
|
filename=filename,
|
||||||
|
)
|
||||||
|
result = func.get()
|
||||||
|
result["filename"] = filename
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return {"results": results} if len(results) > 1 else results[0]
|
||||||
|
|
||||||
|
finally:
|
||||||
|
for filename in uploaded_filenames:
|
||||||
|
try:
|
||||||
|
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||||
|
print(f"Deleting file: {file_path}")
|
||||||
|
os.remove(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error deleting {filename}: {e}")
|
||||||
|
|
||||||
|
upload_volume.commit()
|
||||||
|
|
||||||
|
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
|
||||||
|
def transcribe_from_url(
|
||||||
|
audio_file_url: str = Body(
|
||||||
|
..., description="URL of the audio file to transcribe"
|
||||||
|
),
|
||||||
|
model: str = Body(MODEL_NAME),
|
||||||
|
language: str = Body("en", description="Language code (only 'en' supported)"),
|
||||||
|
timestamp_offset: float = Body(0.0),
|
||||||
|
):
|
||||||
|
# Parakeet only supports English
|
||||||
|
if language != "en":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Parakeet model only supports English. Got language='{language}'",
|
||||||
|
)
|
||||||
|
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
func = transcriber_file.transcribe_segment.spawn(
|
||||||
|
filename=unique_filename,
|
||||||
|
timestamp_offset=timestamp_offset,
|
||||||
|
)
|
||||||
|
result = func.get()
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||||
|
print(f"Deleting file: {file_path}")
|
||||||
|
os.remove(file_path)
|
||||||
|
upload_volume.commit()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error cleaning up {unique_filename}: {e}")
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
class NoStdStreams:
|
||||||
|
def __init__(self):
|
||||||
|
self.devnull = open(os.devnull, "w")
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self._stdout, self._stderr = sys.stdout, sys.stderr
|
||||||
|
self._stdout.flush()
|
||||||
|
self._stderr.flush()
|
||||||
|
sys.stdout, sys.stderr = self.devnull, self.devnull
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
sys.stdout, sys.stderr = self._stdout, self._stderr
|
||||||
|
self.devnull.close()
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
"""add_long_summary_to_search_vector
|
||||||
|
|
||||||
|
Revision ID: 0ab2d7ffaa16
|
||||||
|
Revises: b1c33bd09963
|
||||||
|
Create Date: 2025-08-15 13:27:52.680211
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0ab2d7ffaa16"
|
||||||
|
down_revision: Union[str, None] = "b1c33bd09963"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Drop the existing search vector column and index
|
||||||
|
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
|
||||||
|
op.drop_column("transcript", "search_vector_en")
|
||||||
|
|
||||||
|
# Recreate the search vector column with long_summary included
|
||||||
|
op.execute("""
|
||||||
|
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
||||||
|
GENERATED ALWAYS AS (
|
||||||
|
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
||||||
|
setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') ||
|
||||||
|
setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')
|
||||||
|
) STORED
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Recreate the GIN index for the search vector
|
||||||
|
op.create_index(
|
||||||
|
"idx_transcript_search_vector_en",
|
||||||
|
"transcript",
|
||||||
|
["search_vector_en"],
|
||||||
|
postgresql_using="gin",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Drop the updated search vector column and index
|
||||||
|
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
|
||||||
|
op.drop_column("transcript", "search_vector_en")
|
||||||
|
|
||||||
|
# Recreate the original search vector column without long_summary
|
||||||
|
op.execute("""
|
||||||
|
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
||||||
|
GENERATED ALWAYS AS (
|
||||||
|
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
||||||
|
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
|
||||||
|
) STORED
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Recreate the GIN index for the search vector
|
||||||
|
op.create_index(
|
||||||
|
"idx_transcript_search_vector_en",
|
||||||
|
"transcript",
|
||||||
|
["search_vector_en"],
|
||||||
|
postgresql_using="gin",
|
||||||
|
)
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
"""add_search_optimization_indexes
|
||||||
|
|
||||||
|
Revision ID: b1c33bd09963
|
||||||
|
Revises: 9f5c78d352d6
|
||||||
|
Create Date: 2025-08-14 17:26:02.117408
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "b1c33bd09963"
|
||||||
|
down_revision: Union[str, None] = "9f5c78d352d6"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Add indexes for actual search filtering patterns used in frontend
|
||||||
|
# Based on /browse page filters: room_id and source_kind
|
||||||
|
|
||||||
|
# Index for room_id + created_at (for room-specific searches with date ordering)
|
||||||
|
op.create_index(
|
||||||
|
"idx_transcript_room_id_created_at",
|
||||||
|
"transcript",
|
||||||
|
["room_id", "created_at"],
|
||||||
|
if_not_exists=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index for source_kind alone (actively used filter in frontend)
|
||||||
|
op.create_index(
|
||||||
|
"idx_transcript_source_kind", "transcript", ["source_kind"], if_not_exists=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Remove the indexes in reverse order
|
||||||
|
op.drop_index("idx_transcript_source_kind", "transcript", if_exists=True)
|
||||||
|
op.drop_index("idx_transcript_room_id_created_at", "transcript", if_exists=True)
|
||||||
@@ -32,7 +32,6 @@ dependencies = [
|
|||||||
"redis>=5.0.1",
|
"redis>=5.0.1",
|
||||||
"python-jose[cryptography]>=3.3.0",
|
"python-jose[cryptography]>=3.3.0",
|
||||||
"python-multipart>=0.0.6",
|
"python-multipart>=0.0.6",
|
||||||
"faster-whisper>=0.10.0",
|
|
||||||
"transformers>=4.36.2",
|
"transformers>=4.36.2",
|
||||||
"jsonschema>=4.23.0",
|
"jsonschema>=4.23.0",
|
||||||
"openai>=1.59.7",
|
"openai>=1.59.7",
|
||||||
@@ -57,6 +56,7 @@ tests = [
|
|||||||
"httpx-ws>=0.4.1",
|
"httpx-ws>=0.4.1",
|
||||||
"pytest-httpx>=0.23.1",
|
"pytest-httpx>=0.23.1",
|
||||||
"pytest-celery>=0.0.0",
|
"pytest-celery>=0.0.0",
|
||||||
|
"pytest-recording>=0.13.4",
|
||||||
"pytest-docker>=3.2.3",
|
"pytest-docker>=3.2.3",
|
||||||
"asgi-lifespan>=2.1.0",
|
"asgi-lifespan>=2.1.0",
|
||||||
]
|
]
|
||||||
@@ -67,6 +67,15 @@ evaluation = [
|
|||||||
"tqdm>=4.66.0",
|
"tqdm>=4.66.0",
|
||||||
"pydantic>=2.1.1",
|
"pydantic>=2.1.1",
|
||||||
]
|
]
|
||||||
|
local = [
|
||||||
|
"pyannote-audio>=3.3.2",
|
||||||
|
"faster-whisper>=0.10.0",
|
||||||
|
]
|
||||||
|
silero-vad = [
|
||||||
|
"silero-vad>=5.1.2",
|
||||||
|
"torch>=2.8.0",
|
||||||
|
"torchaudio>=2.8.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
default-groups = [
|
default-groups = [
|
||||||
@@ -74,6 +83,21 @@ default-groups = [
|
|||||||
"tests",
|
"tests",
|
||||||
"aws",
|
"aws",
|
||||||
"evaluation",
|
"evaluation",
|
||||||
|
"local",
|
||||||
|
"silero-vad"
|
||||||
|
]
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cpu"
|
||||||
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
torch = [
|
||||||
|
{ index = "pytorch-cpu" },
|
||||||
|
]
|
||||||
|
torchaudio = [
|
||||||
|
{ index = "pytorch-cpu" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
@@ -94,6 +118,9 @@ DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_t
|
|||||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
|
markers = [
|
||||||
|
"gpu_modal: mark test to run only with GPU Modal endpoints (deselect with '-m \"not gpu_modal\"')",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
|
|||||||
@@ -1,24 +1,37 @@
|
|||||||
"""Search functionality for transcripts and other entities."""
|
"""Search functionality for transcripts and other entities."""
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Annotated, Any, Dict
|
from typing import Annotated, Any, Dict, Iterator
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
import webvtt
|
import webvtt
|
||||||
from pydantic import BaseModel, Field, constr, field_serializer
|
from fastapi import HTTPException
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
NonNegativeFloat,
|
||||||
|
NonNegativeInt,
|
||||||
|
ValidationError,
|
||||||
|
constr,
|
||||||
|
field_serializer,
|
||||||
|
)
|
||||||
|
|
||||||
from reflector.db import get_database
|
from reflector.db import get_database
|
||||||
|
from reflector.db.rooms import rooms
|
||||||
from reflector.db.transcripts import SourceKind, transcripts
|
from reflector.db.transcripts import SourceKind, transcripts
|
||||||
from reflector.db.utils import is_postgresql
|
from reflector.db.utils import is_postgresql
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
|
|
||||||
DEFAULT_SEARCH_LIMIT = 20
|
DEFAULT_SEARCH_LIMIT = 20
|
||||||
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
|
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
|
||||||
DEFAULT_SNIPPET_MAX_LENGTH = 150
|
DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
|
||||||
DEFAULT_MAX_SNIPPETS = 3
|
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
|
||||||
|
LONG_SUMMARY_MAX_SNIPPETS = 2
|
||||||
|
|
||||||
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
|
SearchQueryBase = constr(min_length=0, strip_whitespace=True)
|
||||||
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
|
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
|
||||||
SearchOffsetBase = Annotated[int, Field(ge=0)]
|
SearchOffsetBase = Annotated[int, Field(ge=0)]
|
||||||
SearchTotalBase = Annotated[int, Field(ge=0)]
|
SearchTotalBase = Annotated[int, Field(ge=0)]
|
||||||
@@ -32,6 +45,82 @@ SearchTotal = Annotated[
|
|||||||
SearchTotalBase, Field(description="Total number of search results")
|
SearchTotalBase, Field(description="Total number of search results")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
WEBVTT_SPEC_HEADER = "WEBVTT"
|
||||||
|
|
||||||
|
WebVTTContent = Annotated[
|
||||||
|
str,
|
||||||
|
Field(min_length=len(WEBVTT_SPEC_HEADER), description="WebVTT content"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class WebVTTProcessor:
|
||||||
|
"""Stateless processor for WebVTT content operations."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse(raw_content: str) -> WebVTTContent:
|
||||||
|
"""Parse WebVTT content and return it as a string."""
|
||||||
|
if not raw_content.startswith(WEBVTT_SPEC_HEADER):
|
||||||
|
raise ValueError(f"Invalid WebVTT content, no header {WEBVTT_SPEC_HEADER}")
|
||||||
|
return raw_content
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_text(webvtt_content: WebVTTContent) -> str:
|
||||||
|
"""Extract plain text from WebVTT content using webvtt library."""
|
||||||
|
try:
|
||||||
|
buffer = StringIO(webvtt_content)
|
||||||
|
vtt = webvtt.read_buffer(buffer)
|
||||||
|
return " ".join(caption.text for caption in vtt if caption.text)
|
||||||
|
except webvtt.errors.MalformedFileError as e:
|
||||||
|
logger.warning(f"Malformed WebVTT content: {e}")
|
||||||
|
return ""
|
||||||
|
except (UnicodeDecodeError, ValueError) as e:
|
||||||
|
logger.warning(f"Failed to decode WebVTT content: {e}")
|
||||||
|
return ""
|
||||||
|
except AttributeError as e:
|
||||||
|
logger.error(
|
||||||
|
f"WebVTT parsing error - unexpected format: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error parsing WebVTT: {e}", exc_info=True)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_snippets(
|
||||||
|
webvtt_content: WebVTTContent,
|
||||||
|
query: str,
|
||||||
|
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate snippets from WebVTT content."""
|
||||||
|
return SnippetGenerator.generate(
|
||||||
|
WebVTTProcessor.extract_text(webvtt_content),
|
||||||
|
query,
|
||||||
|
max_snippets=max_snippets,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SnippetCandidate:
|
||||||
|
"""Represents a candidate snippet with its position."""
|
||||||
|
|
||||||
|
_text: str
|
||||||
|
start: NonNegativeInt
|
||||||
|
_original_text_length: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end(self) -> NonNegativeInt:
|
||||||
|
"""Calculate end position from start and raw text length."""
|
||||||
|
return self.start + len(self._text)
|
||||||
|
|
||||||
|
def text(self) -> str:
|
||||||
|
"""Get display text with ellipses added if needed."""
|
||||||
|
result = self._text.strip()
|
||||||
|
if self.start > 0:
|
||||||
|
result = "..." + result
|
||||||
|
if self.end < self._original_text_length:
|
||||||
|
result = result + "..."
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class SearchParameters(BaseModel):
|
class SearchParameters(BaseModel):
|
||||||
"""Validated search parameters for full-text search."""
|
"""Validated search parameters for full-text search."""
|
||||||
@@ -41,6 +130,7 @@ class SearchParameters(BaseModel):
|
|||||||
offset: SearchOffset = 0
|
offset: SearchOffset = 0
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
room_id: str | None = None
|
room_id: str | None = None
|
||||||
|
source_kind: SourceKind | None = None
|
||||||
|
|
||||||
|
|
||||||
class SearchResultDB(BaseModel):
|
class SearchResultDB(BaseModel):
|
||||||
@@ -64,13 +154,18 @@ class SearchResult(BaseModel):
|
|||||||
title: str | None = None
|
title: str | None = None
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
room_id: str | None = None
|
room_id: str | None = None
|
||||||
|
room_name: str | None = None
|
||||||
|
source_kind: SourceKind
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
status: str = Field(..., min_length=1)
|
status: str = Field(..., min_length=1)
|
||||||
rank: float = Field(..., ge=0, le=1)
|
rank: float = Field(..., ge=0, le=1)
|
||||||
duration: float | None = Field(..., ge=0, description="Duration in seconds")
|
duration: NonNegativeFloat | None = Field(..., description="Duration in seconds")
|
||||||
search_snippets: list[str] = Field(
|
search_snippets: list[str] = Field(
|
||||||
description="Text snippets around search matches"
|
description="Text snippets around search matches"
|
||||||
)
|
)
|
||||||
|
total_match_count: NonNegativeInt = Field(
|
||||||
|
default=0, description="Total number of matches found in the transcript"
|
||||||
|
)
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def serialize_datetime(self, dt: datetime) -> str:
|
def serialize_datetime(self, dt: datetime) -> str:
|
||||||
@@ -79,84 +174,153 @@ class SearchResult(BaseModel):
|
|||||||
return dt.isoformat()
|
return dt.isoformat()
|
||||||
|
|
||||||
|
|
||||||
class SearchController:
|
class SnippetGenerator:
|
||||||
"""Controller for search operations across different entities."""
|
"""Stateless generator for text snippets and match operations."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_webvtt_text(webvtt_content: str) -> str:
|
def find_all_matches(text: str, query: str) -> Iterator[int]:
|
||||||
"""Extract plain text from WebVTT content using webvtt library."""
|
"""Generate all match positions for a query in text."""
|
||||||
if not webvtt_content:
|
if not text:
|
||||||
return ""
|
logger.warning("Empty text for search query in find_all_matches")
|
||||||
|
return
|
||||||
|
if not query:
|
||||||
|
logger.warning("Empty query for search text in find_all_matches")
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
text_lower = text.lower()
|
||||||
buffer = StringIO(webvtt_content)
|
query_lower = query.lower()
|
||||||
vtt = webvtt.read_buffer(buffer)
|
start = 0
|
||||||
return " ".join(caption.text for caption in vtt if caption.text)
|
prev_start = start
|
||||||
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
|
while (pos := text_lower.find(query_lower, start)) != -1:
|
||||||
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
|
yield pos
|
||||||
return ""
|
start = pos + len(query_lower)
|
||||||
except AttributeError as e:
|
if start <= prev_start:
|
||||||
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
|
raise ValueError("panic! find_all_matches is not incremental")
|
||||||
return ""
|
prev_start = start
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_snippets(
|
def count_matches(text: str, query: str) -> NonNegativeInt:
|
||||||
text: str,
|
"""Count total number of matches for a query in text."""
|
||||||
q: SearchQuery,
|
ZERO = NonNegativeInt(0)
|
||||||
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
|
if not text:
|
||||||
max_snippets: int = DEFAULT_MAX_SNIPPETS,
|
logger.warning("Empty text for search query in count_matches")
|
||||||
) -> list[str]:
|
return ZERO
|
||||||
"""Generate multiple snippets around all occurrences of search term."""
|
if not query:
|
||||||
if not text or not q:
|
logger.warning("Empty query for search text in count_matches")
|
||||||
return []
|
return ZERO
|
||||||
|
return NonNegativeInt(
|
||||||
snippets = []
|
sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
|
||||||
lower_text = text.lower()
|
|
||||||
search_lower = q.lower()
|
|
||||||
|
|
||||||
last_snippet_end = 0
|
|
||||||
start_pos = 0
|
|
||||||
|
|
||||||
while len(snippets) < max_snippets:
|
|
||||||
match_pos = lower_text.find(search_lower, start_pos)
|
|
||||||
|
|
||||||
if match_pos == -1:
|
|
||||||
if not snippets and search_lower.split():
|
|
||||||
first_word = search_lower.split()[0]
|
|
||||||
match_pos = lower_text.find(first_word, start_pos)
|
|
||||||
if match_pos == -1:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
|
|
||||||
snippet_end = min(
|
|
||||||
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if snippet_start < last_snippet_end:
|
@staticmethod
|
||||||
start_pos = match_pos + len(search_lower)
|
def create_snippet(
|
||||||
continue
|
text: str, match_pos: int, max_length: int = DEFAULT_SNIPPET_MAX_LENGTH
|
||||||
|
) -> SnippetCandidate:
|
||||||
|
"""Create a snippet from a match position."""
|
||||||
|
snippet_start = NonNegativeInt(max(0, match_pos - SNIPPET_CONTEXT_LENGTH))
|
||||||
|
snippet_end = min(len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH)
|
||||||
|
|
||||||
snippet = text[snippet_start:snippet_end]
|
snippet_text = text[snippet_start:snippet_end]
|
||||||
|
|
||||||
if snippet_start > 0:
|
return SnippetCandidate(
|
||||||
snippet = "..." + snippet
|
_text=snippet_text, start=snippet_start, _original_text_length=len(text)
|
||||||
if snippet_end < len(text):
|
)
|
||||||
snippet = snippet + "..."
|
|
||||||
|
|
||||||
snippet = snippet.strip()
|
@staticmethod
|
||||||
|
def filter_non_overlapping(
|
||||||
|
candidates: Iterator[SnippetCandidate],
|
||||||
|
) -> Iterator[str]:
|
||||||
|
"""Filter out overlapping snippets and return only display text."""
|
||||||
|
last_end = 0
|
||||||
|
for candidate in candidates:
|
||||||
|
display_text = candidate.text()
|
||||||
|
# it means that next overlapping snippets simply don't get included
|
||||||
|
# it's fine as simplistic logic and users probably won't care much because they already have their search results just fin
|
||||||
|
if candidate.start >= last_end and display_text:
|
||||||
|
yield display_text
|
||||||
|
last_end = candidate.end
|
||||||
|
|
||||||
if snippet:
|
@staticmethod
|
||||||
snippets.append(snippet)
|
def generate(
|
||||||
last_snippet_end = snippet_end
|
text: str,
|
||||||
|
query: str,
|
||||||
|
max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||||
|
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate snippets from text."""
|
||||||
|
if not text or not query:
|
||||||
|
logger.warning("Empty text or query for generate_snippets")
|
||||||
|
return []
|
||||||
|
|
||||||
start_pos = match_pos + len(search_lower)
|
candidates = (
|
||||||
if start_pos >= len(text):
|
SnippetGenerator.create_snippet(text, pos, max_length)
|
||||||
break
|
for pos in SnippetGenerator.find_all_matches(text, query)
|
||||||
|
)
|
||||||
|
filtered = SnippetGenerator.filter_non_overlapping(candidates)
|
||||||
|
snippets = list(itertools.islice(filtered, max_snippets))
|
||||||
|
|
||||||
|
# Fallback to first word search if no full matches
|
||||||
|
# it's another assumption: proper snippet logic generation is quite complicated and tied to db logic, so simplification is used here
|
||||||
|
if not snippets and " " in query:
|
||||||
|
first_word = query.split()[0]
|
||||||
|
return SnippetGenerator.generate(text, first_word, max_length, max_snippets)
|
||||||
|
|
||||||
return snippets
|
return snippets
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_summary(
|
||||||
|
summary: str,
|
||||||
|
query: str,
|
||||||
|
max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate snippets from summary text."""
|
||||||
|
return SnippetGenerator.generate(summary, query, max_snippets=max_snippets)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def combine_sources(
|
||||||
|
summary: str | None,
|
||||||
|
webvtt: WebVTTContent | None,
|
||||||
|
query: str,
|
||||||
|
max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||||
|
) -> tuple[list[str], NonNegativeInt]:
|
||||||
|
"""Combine snippets from multiple sources and return total match count.
|
||||||
|
|
||||||
|
Returns (snippets, total_match_count) tuple.
|
||||||
|
|
||||||
|
snippets can be empty for real in case of e.g. title match
|
||||||
|
"""
|
||||||
|
webvtt_matches = 0
|
||||||
|
summary_matches = 0
|
||||||
|
|
||||||
|
if webvtt:
|
||||||
|
webvtt_text = WebVTTProcessor.extract_text(webvtt)
|
||||||
|
webvtt_matches = SnippetGenerator.count_matches(webvtt_text, query)
|
||||||
|
|
||||||
|
if summary:
|
||||||
|
summary_matches = SnippetGenerator.count_matches(summary, query)
|
||||||
|
|
||||||
|
total_matches = NonNegativeInt(webvtt_matches + summary_matches)
|
||||||
|
|
||||||
|
summary_snippets = (
|
||||||
|
SnippetGenerator.from_summary(summary, query) if summary else []
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(summary_snippets) >= max_total:
|
||||||
|
return summary_snippets[:max_total], total_matches
|
||||||
|
|
||||||
|
remaining = max_total - len(summary_snippets)
|
||||||
|
webvtt_snippets = (
|
||||||
|
WebVTTProcessor.generate_snippets(webvtt, query, remaining)
|
||||||
|
if webvtt
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
return summary_snippets + webvtt_snippets, total_matches
|
||||||
|
|
||||||
|
|
||||||
|
class SearchController:
|
||||||
|
"""Controller for search operations across different entities."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def search_transcripts(
|
async def search_transcripts(
|
||||||
cls, params: SearchParameters
|
cls, params: SearchParameters
|
||||||
@@ -172,12 +336,7 @@ class SearchController:
|
|||||||
)
|
)
|
||||||
return [], 0
|
return [], 0
|
||||||
|
|
||||||
search_query = sqlalchemy.func.websearch_to_tsquery(
|
base_columns = [
|
||||||
"english", params.query_text
|
|
||||||
)
|
|
||||||
|
|
||||||
base_query = sqlalchemy.select(
|
|
||||||
[
|
|
||||||
transcripts.c.id,
|
transcripts.c.id,
|
||||||
transcripts.c.title,
|
transcripts.c.title,
|
||||||
transcripts.c.created_at,
|
transcripts.c.created_at,
|
||||||
@@ -187,24 +346,54 @@ class SearchController:
|
|||||||
transcripts.c.room_id,
|
transcripts.c.room_id,
|
||||||
transcripts.c.source_kind,
|
transcripts.c.source_kind,
|
||||||
transcripts.c.webvtt,
|
transcripts.c.webvtt,
|
||||||
sqlalchemy.func.ts_rank(
|
transcripts.c.long_summary,
|
||||||
|
sqlalchemy.case(
|
||||||
|
(
|
||||||
|
transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
|
||||||
|
"Deleted Room",
|
||||||
|
),
|
||||||
|
else_=rooms.c.name,
|
||||||
|
).label("room_name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
if params.query_text:
|
||||||
|
search_query = sqlalchemy.func.websearch_to_tsquery(
|
||||||
|
"english", params.query_text
|
||||||
|
)
|
||||||
|
rank_column = sqlalchemy.func.ts_rank(
|
||||||
transcripts.c.search_vector_en,
|
transcripts.c.search_vector_en,
|
||||||
search_query,
|
search_query,
|
||||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||||
).label("rank"),
|
).label("rank")
|
||||||
]
|
else:
|
||||||
).where(transcripts.c.search_vector_en.op("@@")(search_query))
|
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
||||||
|
|
||||||
|
columns = base_columns + [rank_column]
|
||||||
|
base_query = sqlalchemy.select(columns).select_from(
|
||||||
|
transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.query_text:
|
||||||
|
base_query = base_query.where(
|
||||||
|
transcripts.c.search_vector_en.op("@@")(search_query)
|
||||||
|
)
|
||||||
|
|
||||||
if params.user_id:
|
if params.user_id:
|
||||||
base_query = base_query.where(transcripts.c.user_id == params.user_id)
|
base_query = base_query.where(transcripts.c.user_id == params.user_id)
|
||||||
if params.room_id:
|
if params.room_id:
|
||||||
base_query = base_query.where(transcripts.c.room_id == params.room_id)
|
base_query = base_query.where(transcripts.c.room_id == params.room_id)
|
||||||
|
if params.source_kind:
|
||||||
query = (
|
base_query = base_query.where(
|
||||||
base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
|
transcripts.c.source_kind == params.source_kind
|
||||||
.limit(params.limit)
|
|
||||||
.offset(params.offset)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if params.query_text:
|
||||||
|
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
|
||||||
|
else:
|
||||||
|
order_by = sqlalchemy.desc(transcripts.c.created_at)
|
||||||
|
|
||||||
|
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
|
||||||
|
|
||||||
rs = await get_database().fetch_all(query)
|
rs = await get_database().fetch_all(query)
|
||||||
|
|
||||||
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
||||||
@@ -214,18 +403,40 @@ class SearchController:
|
|||||||
|
|
||||||
def _process_result(r) -> SearchResult:
|
def _process_result(r) -> SearchResult:
|
||||||
r_dict: Dict[str, Any] = dict(r)
|
r_dict: Dict[str, Any] = dict(r)
|
||||||
webvtt: str | None = r_dict.pop("webvtt", None)
|
webvtt_raw: str | None = r_dict.pop("webvtt", None)
|
||||||
|
if webvtt_raw:
|
||||||
|
webvtt = WebVTTProcessor.parse(webvtt_raw)
|
||||||
|
else:
|
||||||
|
webvtt = None
|
||||||
|
long_summary: str | None = r_dict.pop("long_summary", None)
|
||||||
|
room_name: str | None = r_dict.pop("room_name", None)
|
||||||
db_result = SearchResultDB.model_validate(r_dict)
|
db_result = SearchResultDB.model_validate(r_dict)
|
||||||
|
|
||||||
snippets = []
|
snippets, total_match_count = SnippetGenerator.combine_sources(
|
||||||
if webvtt:
|
long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
|
||||||
plain_text = cls._extract_webvtt_text(webvtt)
|
)
|
||||||
snippets = cls._generate_snippets(plain_text, params.query_text)
|
|
||||||
|
|
||||||
return SearchResult(**db_result.model_dump(), search_snippets=snippets)
|
return SearchResult(
|
||||||
|
**db_result.model_dump(),
|
||||||
|
room_name=room_name,
|
||||||
|
search_snippets=snippets,
|
||||||
|
total_match_count=total_match_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
results = [_process_result(r) for r in rs]
|
results = [_process_result(r) for r in rs]
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.error(f"Invalid search result data: {e}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Internal search result data consistency error"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing search results: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
return results, total
|
return results, total
|
||||||
|
|
||||||
|
|
||||||
search_controller = SearchController()
|
search_controller = SearchController()
|
||||||
|
webvtt_processor = WebVTTProcessor()
|
||||||
|
snippet_generator = SnippetGenerator()
|
||||||
|
|||||||
@@ -88,6 +88,8 @@ transcripts = sqlalchemy.Table(
|
|||||||
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
||||||
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
||||||
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
|
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
|
||||||
|
sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
|
||||||
|
sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add PostgreSQL-specific full-text search column
|
# Add PostgreSQL-specific full-text search column
|
||||||
@@ -99,7 +101,8 @@ if is_postgresql():
|
|||||||
TSVECTOR,
|
TSVECTOR,
|
||||||
sqlalchemy.Computed(
|
sqlalchemy.Computed(
|
||||||
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
||||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
|
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
|
||||||
|
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
|
||||||
persisted=True,
|
persisted=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
375
server/reflector/pipelines/main_file_pipeline.py
Normal file
375
server/reflector/pipelines/main_file_pipeline.py
Normal file
@@ -0,0 +1,375 @@
|
|||||||
|
"""
|
||||||
|
File-based processing pipeline
|
||||||
|
==============================
|
||||||
|
|
||||||
|
Optimized pipeline for processing complete audio/video files.
|
||||||
|
Uses parallel processing for transcription, diarization, and waveform generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import av
|
||||||
|
import structlog
|
||||||
|
from celery import shared_task
|
||||||
|
|
||||||
|
from reflector.db.transcripts import (
|
||||||
|
Transcript,
|
||||||
|
transcripts_controller,
|
||||||
|
)
|
||||||
|
from reflector.logger import logger
|
||||||
|
from reflector.pipelines.main_live_pipeline import PipelineMainBase, asynctask
|
||||||
|
from reflector.processors import (
|
||||||
|
AudioFileWriterProcessor,
|
||||||
|
TranscriptFinalSummaryProcessor,
|
||||||
|
TranscriptFinalTitleProcessor,
|
||||||
|
TranscriptTopicDetectorProcessor,
|
||||||
|
)
|
||||||
|
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
||||||
|
from reflector.processors.file_diarization import FileDiarizationInput
|
||||||
|
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||||
|
from reflector.processors.file_transcript import FileTranscriptInput
|
||||||
|
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||||
|
from reflector.processors.transcript_diarization_assembler import (
|
||||||
|
TranscriptDiarizationAssemblerInput,
|
||||||
|
TranscriptDiarizationAssemblerProcessor,
|
||||||
|
)
|
||||||
|
from reflector.processors.types import (
|
||||||
|
DiarizationSegment,
|
||||||
|
TitleSummary,
|
||||||
|
)
|
||||||
|
from reflector.processors.types import (
|
||||||
|
Transcript as TranscriptType,
|
||||||
|
)
|
||||||
|
from reflector.settings import settings
|
||||||
|
from reflector.storage import get_transcripts_storage
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyPipeline:
|
||||||
|
"""Empty pipeline for processors that need a pipeline reference"""
|
||||||
|
|
||||||
|
def __init__(self, logger: structlog.BoundLogger):
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
def get_pref(self, k, d=None):
|
||||||
|
return d
|
||||||
|
|
||||||
|
async def emit(self, event):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineMainFile(PipelineMainBase):
|
||||||
|
"""
|
||||||
|
Optimized file processing pipeline.
|
||||||
|
Processes complete audio/video files with parallel execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger: structlog.BoundLogger = None
|
||||||
|
empty_pipeline = None
|
||||||
|
|
||||||
|
def __init__(self, transcript_id: str):
|
||||||
|
super().__init__(transcript_id=transcript_id)
|
||||||
|
self.logger = logger.bind(transcript_id=self.transcript_id)
|
||||||
|
self.empty_pipeline = EmptyPipeline(logger=self.logger)
|
||||||
|
|
||||||
|
def _handle_gather_exceptions(self, results: list, operation: str) -> None:
|
||||||
|
"""Handle exceptions from asyncio.gather with return_exceptions=True"""
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
if not isinstance(result, Exception):
|
||||||
|
continue
|
||||||
|
self.logger.error(
|
||||||
|
f"Error in {operation} (task {i}): {result}",
|
||||||
|
transcript_id=self.transcript_id,
|
||||||
|
exc_info=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def process(self, file_path: Path):
|
||||||
|
"""Main entry point for file processing"""
|
||||||
|
self.logger.info(f"Starting file pipeline for {file_path}")
|
||||||
|
|
||||||
|
transcript = await self.get_transcript()
|
||||||
|
|
||||||
|
# Extract audio and write to transcript location
|
||||||
|
audio_path = await self.extract_and_write_audio(file_path, transcript)
|
||||||
|
|
||||||
|
# Upload for processing
|
||||||
|
audio_url = await self.upload_audio(audio_path, transcript)
|
||||||
|
|
||||||
|
# Run parallel processing
|
||||||
|
await self.run_parallel_processing(
|
||||||
|
audio_path,
|
||||||
|
audio_url,
|
||||||
|
transcript.source_language,
|
||||||
|
transcript.target_language,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info("File pipeline complete")
|
||||||
|
|
||||||
|
async def extract_and_write_audio(
|
||||||
|
self, file_path: Path, transcript: Transcript
|
||||||
|
) -> Path:
|
||||||
|
"""Extract audio from video if needed and write to transcript location as MP3"""
|
||||||
|
self.logger.info(f"Processing audio file: {file_path}")
|
||||||
|
|
||||||
|
# Check if it's already audio-only
|
||||||
|
container = av.open(str(file_path))
|
||||||
|
has_video = len(container.streams.video) > 0
|
||||||
|
container.close()
|
||||||
|
|
||||||
|
# Use AudioFileWriterProcessor to write MP3 to transcript location
|
||||||
|
mp3_writer = AudioFileWriterProcessor(
|
||||||
|
path=transcript.audio_mp3_filename,
|
||||||
|
on_duration=self.on_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process audio frames and write to transcript location
|
||||||
|
input_container = av.open(str(file_path))
|
||||||
|
for frame in input_container.decode(audio=0):
|
||||||
|
await mp3_writer.push(frame)
|
||||||
|
|
||||||
|
await mp3_writer.flush()
|
||||||
|
input_container.close()
|
||||||
|
|
||||||
|
if has_video:
|
||||||
|
self.logger.info(
|
||||||
|
f"Extracted audio from video and saved to {transcript.audio_mp3_filename}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.logger.info(
|
||||||
|
f"Converted audio file and saved to {transcript.audio_mp3_filename}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return transcript.audio_mp3_filename
|
||||||
|
|
||||||
|
async def upload_audio(self, audio_path: Path, transcript: Transcript) -> str:
|
||||||
|
"""Upload audio to storage for processing"""
|
||||||
|
storage = get_transcripts_storage()
|
||||||
|
|
||||||
|
if not storage:
|
||||||
|
raise Exception(
|
||||||
|
"Storage backend required for file processing. Configure TRANSCRIPT_STORAGE_* settings."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info("Uploading audio to storage")
|
||||||
|
|
||||||
|
with open(audio_path, "rb") as f:
|
||||||
|
audio_data = f.read()
|
||||||
|
|
||||||
|
storage_path = f"file_pipeline/{transcript.id}/audio.mp3"
|
||||||
|
await storage.put_file(storage_path, audio_data)
|
||||||
|
|
||||||
|
audio_url = await storage.get_file_url(storage_path)
|
||||||
|
|
||||||
|
self.logger.info(f"Audio uploaded to {audio_url}")
|
||||||
|
return audio_url
|
||||||
|
|
||||||
|
async def run_parallel_processing(
|
||||||
|
self,
|
||||||
|
audio_path: Path,
|
||||||
|
audio_url: str,
|
||||||
|
source_language: str,
|
||||||
|
target_language: str,
|
||||||
|
):
|
||||||
|
"""Coordinate parallel processing of transcription, diarization, and waveform"""
|
||||||
|
self.logger.info(
|
||||||
|
"Starting parallel processing", transcript_id=self.transcript_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase 1: Parallel processing of independent tasks
|
||||||
|
transcription_task = self.transcribe_file(audio_url, source_language)
|
||||||
|
diarization_task = self.diarize_file(audio_url)
|
||||||
|
waveform_task = self.generate_waveform(audio_path)
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
transcription_task, diarization_task, waveform_task, return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
transcript_result = results[0]
|
||||||
|
diarization_result = results[1]
|
||||||
|
|
||||||
|
# Handle errors - raise any exception that occurred
|
||||||
|
self._handle_gather_exceptions(results, "parallel processing")
|
||||||
|
for result in results:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
raise result
|
||||||
|
|
||||||
|
# Phase 2: Assemble transcript with diarization
|
||||||
|
self.logger.info(
|
||||||
|
"Assembling transcript with diarization", transcript_id=self.transcript_id
|
||||||
|
)
|
||||||
|
processor = TranscriptDiarizationAssemblerProcessor()
|
||||||
|
input_data = TranscriptDiarizationAssemblerInput(
|
||||||
|
transcript=transcript_result, diarization=diarization_result or []
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store result for retrieval
|
||||||
|
diarized_transcript: Transcript | None = None
|
||||||
|
|
||||||
|
async def capture_result(transcript):
|
||||||
|
nonlocal diarized_transcript
|
||||||
|
diarized_transcript = transcript
|
||||||
|
|
||||||
|
processor.on(capture_result)
|
||||||
|
await processor.push(input_data)
|
||||||
|
await processor.flush()
|
||||||
|
|
||||||
|
if not diarized_transcript:
|
||||||
|
raise ValueError("No diarized transcript captured")
|
||||||
|
|
||||||
|
# Phase 3: Generate topics from diarized transcript
|
||||||
|
self.logger.info("Generating topics", transcript_id=self.transcript_id)
|
||||||
|
topics = await self.detect_topics(diarized_transcript, target_language)
|
||||||
|
|
||||||
|
# Phase 4: Generate title and summaries in parallel
|
||||||
|
self.logger.info(
|
||||||
|
"Generating title and summaries", transcript_id=self.transcript_id
|
||||||
|
)
|
||||||
|
results = await asyncio.gather(
|
||||||
|
self.generate_title(topics),
|
||||||
|
self.generate_summaries(topics),
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._handle_gather_exceptions(results, "title and summary generation")
|
||||||
|
|
||||||
|
async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
|
||||||
|
"""Transcribe complete file"""
|
||||||
|
processor = FileTranscriptAutoProcessor()
|
||||||
|
input_data = FileTranscriptInput(audio_url=audio_url, language=language)
|
||||||
|
|
||||||
|
# Store result for retrieval
|
||||||
|
result: TranscriptType | None = None
|
||||||
|
|
||||||
|
async def capture_result(transcript):
|
||||||
|
nonlocal result
|
||||||
|
result = transcript
|
||||||
|
|
||||||
|
processor.on(capture_result)
|
||||||
|
await processor.push(input_data)
|
||||||
|
await processor.flush()
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
raise ValueError("No transcript captured")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def diarize_file(self, audio_url: str) -> list[DiarizationSegment] | None:
|
||||||
|
"""Get diarization for file"""
|
||||||
|
if not settings.DIARIZATION_BACKEND:
|
||||||
|
self.logger.info("Diarization disabled")
|
||||||
|
return None
|
||||||
|
|
||||||
|
processor = FileDiarizationAutoProcessor()
|
||||||
|
input_data = FileDiarizationInput(audio_url=audio_url)
|
||||||
|
|
||||||
|
# Store result for retrieval
|
||||||
|
result = None
|
||||||
|
|
||||||
|
async def capture_result(diarization_output):
|
||||||
|
nonlocal result
|
||||||
|
result = diarization_output.diarization
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor.on(capture_result)
|
||||||
|
await processor.push(input_data)
|
||||||
|
await processor.flush()
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Diarization failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def generate_waveform(self, audio_path: Path):
|
||||||
|
"""Generate and save waveform"""
|
||||||
|
transcript = await self.get_transcript()
|
||||||
|
|
||||||
|
processor = AudioWaveformProcessor(
|
||||||
|
audio_path=audio_path,
|
||||||
|
waveform_path=transcript.audio_waveform_filename,
|
||||||
|
on_waveform=self.on_waveform,
|
||||||
|
)
|
||||||
|
processor.set_pipeline(self.empty_pipeline)
|
||||||
|
|
||||||
|
await processor.flush()
|
||||||
|
|
||||||
|
async def detect_topics(
|
||||||
|
self, transcript: TranscriptType, target_language: str
|
||||||
|
) -> list[TitleSummary]:
|
||||||
|
"""Detect topics from complete transcript"""
|
||||||
|
chunk_size = 300
|
||||||
|
topics: list[TitleSummary] = []
|
||||||
|
|
||||||
|
async def on_topic(topic: TitleSummary):
|
||||||
|
topics.append(topic)
|
||||||
|
return await self.on_topic(topic)
|
||||||
|
|
||||||
|
topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
|
||||||
|
topic_detector.set_pipeline(self.empty_pipeline)
|
||||||
|
|
||||||
|
for i in range(0, len(transcript.words), chunk_size):
|
||||||
|
chunk_words = transcript.words[i : i + chunk_size]
|
||||||
|
if not chunk_words:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk_transcript = TranscriptType(
|
||||||
|
words=chunk_words, translation=transcript.translation
|
||||||
|
)
|
||||||
|
|
||||||
|
await topic_detector.push(chunk_transcript)
|
||||||
|
|
||||||
|
await topic_detector.flush()
|
||||||
|
return topics
|
||||||
|
|
||||||
|
async def generate_title(self, topics: list[TitleSummary]):
|
||||||
|
"""Generate title from topics"""
|
||||||
|
if not topics:
|
||||||
|
self.logger.warning("No topics for title generation")
|
||||||
|
return
|
||||||
|
|
||||||
|
processor = TranscriptFinalTitleProcessor(callback=self.on_title)
|
||||||
|
processor.set_pipeline(self.empty_pipeline)
|
||||||
|
|
||||||
|
for topic in topics:
|
||||||
|
await processor.push(topic)
|
||||||
|
|
||||||
|
await processor.flush()
|
||||||
|
|
||||||
|
async def generate_summaries(self, topics: list[TitleSummary]):
|
||||||
|
"""Generate long and short summaries from topics"""
|
||||||
|
if not topics:
|
||||||
|
self.logger.warning("No topics for summary generation")
|
||||||
|
return
|
||||||
|
|
||||||
|
transcript = await self.get_transcript()
|
||||||
|
processor = TranscriptFinalSummaryProcessor(
|
||||||
|
transcript=transcript,
|
||||||
|
callback=self.on_long_summary,
|
||||||
|
on_short_summary=self.on_short_summary,
|
||||||
|
)
|
||||||
|
processor.set_pipeline(self.empty_pipeline)
|
||||||
|
|
||||||
|
for topic in topics:
|
||||||
|
await processor.push(topic)
|
||||||
|
|
||||||
|
await processor.flush()
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task
|
||||||
|
@asynctask
|
||||||
|
async def task_pipeline_file_process(*, transcript_id: str):
|
||||||
|
"""Celery task for file pipeline processing"""
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||||
|
if not transcript:
|
||||||
|
raise Exception(f"Transcript {transcript_id} not found")
|
||||||
|
|
||||||
|
# Find the file to process
|
||||||
|
audio_file = next(transcript.data_path.glob("upload.*"), None)
|
||||||
|
if not audio_file:
|
||||||
|
audio_file = next(transcript.data_path.glob("audio.*"), None)
|
||||||
|
|
||||||
|
if not audio_file:
|
||||||
|
raise Exception("No audio file found to process")
|
||||||
|
|
||||||
|
# Run file pipeline
|
||||||
|
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||||
|
await pipeline.process(audio_file)
|
||||||
@@ -147,15 +147,18 @@ class StrValue(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
|
class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
|
||||||
transcript_id: str
|
def __init__(self, transcript_id: str):
|
||||||
ws_room_id: str | None = None
|
super().__init__()
|
||||||
ws_manager: WebsocketManager | None = None
|
|
||||||
|
|
||||||
def prepare(self):
|
|
||||||
# prepare websocket
|
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
self.transcript_id = transcript_id
|
||||||
self.ws_room_id = f"ts:{self.transcript_id}"
|
self.ws_room_id = f"ts:{self.transcript_id}"
|
||||||
self.ws_manager = get_ws_manager()
|
self._ws_manager = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ws_manager(self) -> WebsocketManager:
|
||||||
|
if self._ws_manager is None:
|
||||||
|
self._ws_manager = get_ws_manager()
|
||||||
|
return self._ws_manager
|
||||||
|
|
||||||
async def get_transcript(self) -> Transcript:
|
async def get_transcript(self) -> Transcript:
|
||||||
# fetch the transcript
|
# fetch the transcript
|
||||||
@@ -355,7 +358,6 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
async def create(self) -> Pipeline:
|
async def create(self) -> Pipeline:
|
||||||
# create a context for the whole rtc transaction
|
# create a context for the whole rtc transaction
|
||||||
# add a customised logger to the context
|
# add a customised logger to the context
|
||||||
self.prepare()
|
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript()
|
||||||
|
|
||||||
processors = [
|
processors = [
|
||||||
@@ -376,6 +378,7 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
pipeline.set_pref("audio:target_language", transcript.target_language)
|
pipeline.set_pref("audio:target_language", transcript.target_language)
|
||||||
pipeline.logger.bind(transcript_id=transcript.id)
|
pipeline.logger.bind(transcript_id=transcript.id)
|
||||||
pipeline.logger.info("Pipeline main live created")
|
pipeline.logger.info("Pipeline main live created")
|
||||||
|
pipeline.describe()
|
||||||
|
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
@@ -394,7 +397,6 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
|||||||
async def create(self) -> Pipeline:
|
async def create(self) -> Pipeline:
|
||||||
# create a context for the whole rtc transaction
|
# create a context for the whole rtc transaction
|
||||||
# add a customised logger to the context
|
# add a customised logger to the context
|
||||||
self.prepare()
|
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||||
)
|
)
|
||||||
@@ -435,8 +437,6 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def create(self) -> Pipeline:
|
async def create(self) -> Pipeline:
|
||||||
self.prepare()
|
|
||||||
|
|
||||||
# get transcript
|
# get transcript
|
||||||
self._transcript = transcript = await self.get_transcript()
|
self._transcript = transcript = await self.get_transcript()
|
||||||
|
|
||||||
|
|||||||
@@ -18,22 +18,14 @@ During its lifecycle, it will emit the following status:
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
|
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.processors import Pipeline
|
from reflector.processors import Pipeline
|
||||||
|
|
||||||
PipelineMessage = TypeVar("PipelineMessage")
|
PipelineMessage = TypeVar("PipelineMessage")
|
||||||
|
|
||||||
|
|
||||||
class PipelineRunner(BaseModel, Generic[PipelineMessage]):
|
class PipelineRunner(Generic[PipelineMessage]):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
def __init__(self):
|
||||||
|
|
||||||
status: str = "idle"
|
|
||||||
pipeline: Pipeline | None = None
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._task = None
|
self._task = None
|
||||||
self._q_cmd = asyncio.Queue(maxsize=4096)
|
self._q_cmd = asyncio.Queue(maxsize=4096)
|
||||||
self._ev_done = asyncio.Event()
|
self._ev_done = asyncio.Event()
|
||||||
@@ -42,6 +34,8 @@ class PipelineRunner(BaseModel, Generic[PipelineMessage]):
|
|||||||
runner=id(self),
|
runner=id(self),
|
||||||
runner_cls=self.__class__.__name__,
|
runner_cls=self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
self.status = "idle"
|
||||||
|
self.pipeline: Pipeline | None = None
|
||||||
|
|
||||||
async def create(self) -> Pipeline:
|
async def create(self) -> Pipeline:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -11,6 +11,13 @@ from .base import ( # noqa: F401
|
|||||||
Processor,
|
Processor,
|
||||||
ThreadedProcessor,
|
ThreadedProcessor,
|
||||||
)
|
)
|
||||||
|
from .file_diarization import FileDiarizationProcessor # noqa: F401
|
||||||
|
from .file_diarization_auto import FileDiarizationAutoProcessor # noqa: F401
|
||||||
|
from .file_transcript import FileTranscriptProcessor # noqa: F401
|
||||||
|
from .file_transcript_auto import FileTranscriptAutoProcessor # noqa: F401
|
||||||
|
from .transcript_diarization_assembler import (
|
||||||
|
TranscriptDiarizationAssemblerProcessor, # noqa: F401
|
||||||
|
)
|
||||||
from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401
|
from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401
|
||||||
from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401
|
from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401
|
||||||
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
|
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
|
||||||
|
|||||||
@@ -1,28 +1,340 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import av
|
import av
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from silero_vad import VADIterator, load_silero_vad
|
||||||
|
|
||||||
from reflector.processors.base import Processor
|
from reflector.processors.base import Processor
|
||||||
|
|
||||||
|
|
||||||
class AudioChunkerProcessor(Processor):
|
class AudioChunkerProcessor(Processor):
|
||||||
"""
|
"""
|
||||||
Assemble audio frames into chunks
|
Assemble audio frames into chunks with VAD-based speech detection
|
||||||
"""
|
"""
|
||||||
|
|
||||||
INPUT_TYPE = av.AudioFrame
|
INPUT_TYPE = av.AudioFrame
|
||||||
OUTPUT_TYPE = list[av.AudioFrame]
|
OUTPUT_TYPE = list[av.AudioFrame]
|
||||||
|
|
||||||
def __init__(self, max_frames=256):
|
def __init__(
|
||||||
|
self,
|
||||||
|
block_frames=256,
|
||||||
|
max_frames=1024,
|
||||||
|
vad_threshold=0.5,
|
||||||
|
use_onnx=False,
|
||||||
|
min_frames=2,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.frames: list[av.AudioFrame] = []
|
self.frames: list[av.AudioFrame] = []
|
||||||
|
self.block_frames = block_frames
|
||||||
self.max_frames = max_frames
|
self.max_frames = max_frames
|
||||||
|
self.vad_threshold = vad_threshold
|
||||||
|
self.min_frames = min_frames
|
||||||
|
|
||||||
|
# Initialize Silero VAD
|
||||||
|
self._init_vad(use_onnx)
|
||||||
|
|
||||||
|
def _init_vad(self, use_onnx=False):
|
||||||
|
"""Initialize Silero VAD model"""
|
||||||
|
try:
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
self.vad_model = load_silero_vad(onnx=use_onnx)
|
||||||
|
self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
|
||||||
|
self.logger.info("Silero VAD initialized successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to initialize Silero VAD: {e}")
|
||||||
|
self.vad_model = None
|
||||||
|
self.vad_iterator = None
|
||||||
|
|
||||||
async def _push(self, data: av.AudioFrame):
|
async def _push(self, data: av.AudioFrame):
|
||||||
self.frames.append(data)
|
self.frames.append(data)
|
||||||
if len(self.frames) >= self.max_frames:
|
# print("timestamp", data.pts * data.time_base * 1000)
|
||||||
await self.flush()
|
|
||||||
|
# Check for speech segments every 32 frames (~1 second)
|
||||||
|
if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
|
||||||
|
await self._process_block()
|
||||||
|
|
||||||
|
# Safety fallback - emit if we hit max frames
|
||||||
|
elif len(self.frames) >= self.max_frames:
|
||||||
|
self.logger.warning(
|
||||||
|
f"AudioChunkerProcessor: Reached max frames ({self.max_frames}), "
|
||||||
|
f"emitting first {self.max_frames // 2} frames"
|
||||||
|
)
|
||||||
|
frames_to_emit = self.frames[: self.max_frames // 2]
|
||||||
|
self.frames = self.frames[self.max_frames // 2 :]
|
||||||
|
if len(frames_to_emit) >= self.min_frames:
|
||||||
|
await self.emit(frames_to_emit)
|
||||||
|
else:
|
||||||
|
self.logger.debug(
|
||||||
|
f"Ignoring fallback segment with {len(frames_to_emit)} frames "
|
||||||
|
f"(< {self.min_frames} minimum)"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _process_block(self):
|
||||||
|
# Need at least 32 frames for VAD detection (~1 second)
|
||||||
|
if len(self.frames) < 32 or self.vad_iterator is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Processing block with current buffer size
|
||||||
|
# print(f"Processing block: {len(self.frames)} frames in buffer")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert frames to numpy array for VAD
|
||||||
|
audio_array = self._frames_to_numpy(self.frames)
|
||||||
|
|
||||||
|
if audio_array is None:
|
||||||
|
# Fallback: emit all frames if conversion failed
|
||||||
|
frames_to_emit = self.frames[:]
|
||||||
|
self.frames = []
|
||||||
|
if len(frames_to_emit) >= self.min_frames:
|
||||||
|
await self.emit(frames_to_emit)
|
||||||
|
else:
|
||||||
|
self.logger.debug(
|
||||||
|
f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
|
||||||
|
f"(< {self.min_frames} minimum)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Find complete speech segments in the buffer
|
||||||
|
speech_end_frame = self._find_speech_segment_end(audio_array)
|
||||||
|
|
||||||
|
if speech_end_frame is None or speech_end_frame <= 0:
|
||||||
|
# No speech found but buffer is getting large
|
||||||
|
if len(self.frames) > 512:
|
||||||
|
# Check if it's all silence and can be discarded
|
||||||
|
# No speech segment found, buffer at {len(self.frames)} frames
|
||||||
|
|
||||||
|
# Could emit silence or discard old frames here
|
||||||
|
# For now, keep first 256 frames and discard older silence
|
||||||
|
if len(self.frames) > 768:
|
||||||
|
self.logger.debug(
|
||||||
|
f"Discarding {len(self.frames) - 256} old frames (likely silence)"
|
||||||
|
)
|
||||||
|
self.frames = self.frames[-256:]
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate segment timing information
|
||||||
|
frames_to_emit = self.frames[:speech_end_frame]
|
||||||
|
|
||||||
|
# Get timing from av.AudioFrame
|
||||||
|
if frames_to_emit:
|
||||||
|
first_frame = frames_to_emit[0]
|
||||||
|
last_frame = frames_to_emit[-1]
|
||||||
|
sample_rate = first_frame.sample_rate
|
||||||
|
|
||||||
|
# Calculate duration
|
||||||
|
total_samples = sum(f.samples for f in frames_to_emit)
|
||||||
|
duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0
|
||||||
|
|
||||||
|
# Get timestamps if available
|
||||||
|
start_time = (
|
||||||
|
first_frame.pts * first_frame.time_base if first_frame.pts else 0
|
||||||
|
)
|
||||||
|
end_time = (
|
||||||
|
last_frame.pts * last_frame.time_base if last_frame.pts else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to HH:MM:SS format for logging
|
||||||
|
def format_time(seconds):
|
||||||
|
if not seconds:
|
||||||
|
return "00:00:00"
|
||||||
|
total_seconds = int(float(seconds))
|
||||||
|
hours = total_seconds // 3600
|
||||||
|
minutes = (total_seconds % 3600) // 60
|
||||||
|
secs = total_seconds % 60
|
||||||
|
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
||||||
|
|
||||||
|
start_formatted = format_time(start_time)
|
||||||
|
end_formatted = format_time(end_time)
|
||||||
|
|
||||||
|
# Keep remaining frames for next processing
|
||||||
|
remaining_after = len(self.frames) - speech_end_frame
|
||||||
|
|
||||||
|
# Single structured log line
|
||||||
|
self.logger.info(
|
||||||
|
"Speech segment found",
|
||||||
|
start=start_formatted,
|
||||||
|
end=end_formatted,
|
||||||
|
frames=speech_end_frame,
|
||||||
|
duration=round(duration_seconds, 2),
|
||||||
|
buffer_before=len(self.frames),
|
||||||
|
remaining=remaining_after,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keep remaining frames for next processing
|
||||||
|
self.frames = self.frames[speech_end_frame:]
|
||||||
|
|
||||||
|
# Filter out segments with too few frames
|
||||||
|
if len(frames_to_emit) >= self.min_frames:
|
||||||
|
await self.emit(frames_to_emit)
|
||||||
|
else:
|
||||||
|
self.logger.debug(
|
||||||
|
f"Ignoring segment with {len(frames_to_emit)} frames "
|
||||||
|
f"(< {self.min_frames} minimum)"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in VAD processing: {e}")
|
||||||
|
# Fallback to simple chunking
|
||||||
|
if len(self.frames) >= self.block_frames:
|
||||||
|
frames_to_emit = self.frames[: self.block_frames]
|
||||||
|
self.frames = self.frames[self.block_frames :]
|
||||||
|
if len(frames_to_emit) >= self.min_frames:
|
||||||
|
await self.emit(frames_to_emit)
|
||||||
|
else:
|
||||||
|
self.logger.debug(
|
||||||
|
f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
|
||||||
|
f"(< {self.min_frames} minimum)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
|
||||||
|
"""Convert av.AudioFrame list to numpy array for VAD processing"""
|
||||||
|
if not frames:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
first_frame = frames[0]
|
||||||
|
original_sample_rate = first_frame.sample_rate
|
||||||
|
|
||||||
|
audio_data = []
|
||||||
|
for frame in frames:
|
||||||
|
frame_array = frame.to_ndarray()
|
||||||
|
|
||||||
|
# Handle stereo -> mono conversion
|
||||||
|
if len(frame_array.shape) == 2 and frame_array.shape[0] > 1:
|
||||||
|
frame_array = np.mean(frame_array, axis=0)
|
||||||
|
elif len(frame_array.shape) == 2:
|
||||||
|
frame_array = frame_array.flatten()
|
||||||
|
|
||||||
|
audio_data.append(frame_array)
|
||||||
|
|
||||||
|
if not audio_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
combined_audio = np.concatenate(audio_data)
|
||||||
|
|
||||||
|
# Resample from 48kHz to 16kHz if needed
|
||||||
|
if original_sample_rate != 16000:
|
||||||
|
combined_audio = self._resample_audio(
|
||||||
|
combined_audio, original_sample_rate, 16000
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure float32 format
|
||||||
|
if combined_audio.dtype == np.int16:
|
||||||
|
# Normalize int16 audio to float32 in range [-1.0, 1.0]
|
||||||
|
combined_audio = combined_audio.astype(np.float32) / 32768.0
|
||||||
|
elif combined_audio.dtype != np.float32:
|
||||||
|
combined_audio = combined_audio.astype(np.float32)
|
||||||
|
|
||||||
|
return combined_audio
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error converting frames to numpy: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _resample_audio(
|
||||||
|
self, audio: np.ndarray, from_sr: int, to_sr: int
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Simple linear resampling from from_sr to to_sr"""
|
||||||
|
if from_sr == to_sr:
|
||||||
|
return audio
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Simple linear interpolation resampling
|
||||||
|
ratio = to_sr / from_sr
|
||||||
|
new_length = int(len(audio) * ratio)
|
||||||
|
|
||||||
|
# Create indices for interpolation
|
||||||
|
old_indices = np.linspace(0, len(audio) - 1, new_length)
|
||||||
|
resampled = np.interp(old_indices, np.arange(len(audio)), audio)
|
||||||
|
|
||||||
|
return resampled.astype(np.float32)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("Resampling error", exc_info=e)
|
||||||
|
# Fallback: simple decimation/repetition
|
||||||
|
if from_sr > to_sr:
|
||||||
|
# Downsample by taking every nth sample
|
||||||
|
step = from_sr // to_sr
|
||||||
|
return audio[::step]
|
||||||
|
else:
|
||||||
|
# Upsample by repeating samples
|
||||||
|
repeat = to_sr // from_sr
|
||||||
|
return np.repeat(audio, repeat)
|
||||||
|
|
||||||
|
def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
|
||||||
|
"""Find complete speech segments and return frame index at segment end"""
|
||||||
|
if self.vad_iterator is None or len(audio_array) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Process audio in 512-sample windows for VAD
|
||||||
|
window_size = 512
|
||||||
|
min_silence_windows = 3 # Require 3 windows of silence after speech
|
||||||
|
|
||||||
|
# Track speech state
|
||||||
|
in_speech = False
|
||||||
|
speech_start = None
|
||||||
|
speech_end = None
|
||||||
|
silence_count = 0
|
||||||
|
|
||||||
|
for i in range(0, len(audio_array), window_size):
|
||||||
|
chunk = audio_array[i : i + window_size]
|
||||||
|
if len(chunk) < window_size:
|
||||||
|
chunk = np.pad(chunk, (0, window_size - len(chunk)))
|
||||||
|
|
||||||
|
# Detect if this window has speech
|
||||||
|
speech_dict = self.vad_iterator(chunk, return_seconds=True)
|
||||||
|
|
||||||
|
# VADIterator returns dict with 'start' and 'end' when speech segments are detected
|
||||||
|
if speech_dict:
|
||||||
|
if not in_speech:
|
||||||
|
# Speech started
|
||||||
|
speech_start = i
|
||||||
|
in_speech = True
|
||||||
|
# Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
|
||||||
|
silence_count = 0 # Reset silence counter
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not in_speech:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# We're in speech but found silence
|
||||||
|
silence_count += 1
|
||||||
|
if silence_count < min_silence_windows:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Found end of speech segment
|
||||||
|
speech_end = i - (min_silence_windows - 1) * window_size
|
||||||
|
# Debug: print(f"Speech END at sample {speech_end}")
|
||||||
|
|
||||||
|
# Convert sample position to frame index
|
||||||
|
samples_per_frame = self.frames[0].samples if self.frames else 1024
|
||||||
|
# Account for resampling: we process at 16kHz but frames might be 48kHz
|
||||||
|
resample_ratio = 48000 / 16000 # 3x
|
||||||
|
actual_sample_pos = int(speech_end * resample_ratio)
|
||||||
|
frame_index = actual_sample_pos // samples_per_frame
|
||||||
|
|
||||||
|
# Ensure we don't exceed buffer
|
||||||
|
frame_index = min(frame_index, len(self.frames))
|
||||||
|
return frame_index
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error finding speech segment: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
async def _flush(self):
|
async def _flush(self):
|
||||||
frames = self.frames[:]
|
frames = self.frames[:]
|
||||||
self.frames = []
|
self.frames = []
|
||||||
if frames:
|
if frames:
|
||||||
|
if len(frames) >= self.min_frames:
|
||||||
await self.emit(frames)
|
await self.emit(frames)
|
||||||
|
else:
|
||||||
|
self.logger.debug(
|
||||||
|
f"Ignoring flush segment with {len(frames)} frames "
|
||||||
|
f"(< {self.min_frames} minimum)"
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from reflector.processors.base import Processor
|
from reflector.processors.base import Processor
|
||||||
from reflector.processors.types import (
|
from reflector.processors.types import (
|
||||||
AudioDiarizationInput,
|
AudioDiarizationInput,
|
||||||
|
DiarizationSegment,
|
||||||
TitleSummary,
|
TitleSummary,
|
||||||
Word,
|
Word,
|
||||||
)
|
)
|
||||||
@@ -37,18 +38,21 @@ class AudioDiarizationProcessor(Processor):
|
|||||||
async def _diarize(self, data: AudioDiarizationInput):
|
async def _diarize(self, data: AudioDiarizationInput):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def assign_speaker(self, words: list[Word], diarization: list[dict]):
|
@classmethod
|
||||||
self._diarization_remove_overlap(diarization)
|
def assign_speaker(cls, words: list[Word], diarization: list[DiarizationSegment]):
|
||||||
self._diarization_remove_segment_without_words(words, diarization)
|
cls._diarization_remove_overlap(diarization)
|
||||||
self._diarization_merge_same_speaker(words, diarization)
|
cls._diarization_remove_segment_without_words(words, diarization)
|
||||||
self._diarization_assign_speaker(words, diarization)
|
cls._diarization_merge_same_speaker(diarization)
|
||||||
|
cls._diarization_assign_speaker(words, diarization)
|
||||||
|
|
||||||
def iter_words_from_topics(self, topics: TitleSummary):
|
@staticmethod
|
||||||
|
def iter_words_from_topics(topics: list[TitleSummary]):
|
||||||
for topic in topics:
|
for topic in topics:
|
||||||
for word in topic.transcript.words:
|
for word in topic.transcript.words:
|
||||||
yield word
|
yield word
|
||||||
|
|
||||||
def is_word_continuation(self, word_prev, word):
|
@staticmethod
|
||||||
|
def is_word_continuation(word_prev, word):
|
||||||
"""
|
"""
|
||||||
Return True if the word is a continuation of the previous word
|
Return True if the word is a continuation of the previous word
|
||||||
by checking if the previous word is ending with a punctuation
|
by checking if the previous word is ending with a punctuation
|
||||||
@@ -61,7 +65,8 @@ class AudioDiarizationProcessor(Processor):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _diarization_remove_overlap(self, diarization: list[dict]):
|
@staticmethod
|
||||||
|
def _diarization_remove_overlap(diarization: list[DiarizationSegment]):
|
||||||
"""
|
"""
|
||||||
Remove overlap in diarization results
|
Remove overlap in diarization results
|
||||||
|
|
||||||
@@ -86,8 +91,9 @@ class AudioDiarizationProcessor(Processor):
|
|||||||
else:
|
else:
|
||||||
diarization_idx += 1
|
diarization_idx += 1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _diarization_remove_segment_without_words(
|
def _diarization_remove_segment_without_words(
|
||||||
self, words: list[Word], diarization: list[dict]
|
words: list[Word], diarization: list[DiarizationSegment]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Remove diarization segments without words
|
Remove diarization segments without words
|
||||||
@@ -116,9 +122,8 @@ class AudioDiarizationProcessor(Processor):
|
|||||||
else:
|
else:
|
||||||
diarization_idx += 1
|
diarization_idx += 1
|
||||||
|
|
||||||
def _diarization_merge_same_speaker(
|
@staticmethod
|
||||||
self, words: list[Word], diarization: list[dict]
|
def _diarization_merge_same_speaker(diarization: list[DiarizationSegment]):
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Merge diarization contigous segments with the same speaker
|
Merge diarization contigous segments with the same speaker
|
||||||
|
|
||||||
@@ -135,7 +140,10 @@ class AudioDiarizationProcessor(Processor):
|
|||||||
else:
|
else:
|
||||||
diarization_idx += 1
|
diarization_idx += 1
|
||||||
|
|
||||||
def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
|
@classmethod
|
||||||
|
def _diarization_assign_speaker(
|
||||||
|
cls, words: list[Word], diarization: list[DiarizationSegment]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Assign speaker to words based on diarization
|
Assign speaker to words based on diarization
|
||||||
|
|
||||||
@@ -143,7 +151,7 @@ class AudioDiarizationProcessor(Processor):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
word_idx = 0
|
word_idx = 0
|
||||||
last_speaker = None
|
last_speaker = 0
|
||||||
for d in diarization:
|
for d in diarization:
|
||||||
start = d["start"]
|
start = d["start"]
|
||||||
end = d["end"]
|
end = d["end"]
|
||||||
@@ -158,7 +166,7 @@ class AudioDiarizationProcessor(Processor):
|
|||||||
# If it's a continuation, assign with the last speaker
|
# If it's a continuation, assign with the last speaker
|
||||||
is_continuation = False
|
is_continuation = False
|
||||||
if word_idx > 0 and word_idx < len(words) - 1:
|
if word_idx > 0 and word_idx < len(words) - 1:
|
||||||
is_continuation = self.is_word_continuation(
|
is_continuation = cls.is_word_continuation(
|
||||||
*words[word_idx - 1 : word_idx + 1]
|
*words[word_idx - 1 : word_idx + 1]
|
||||||
)
|
)
|
||||||
if is_continuation:
|
if is_continuation:
|
||||||
|
|||||||
74
server/reflector/processors/audio_diarization_pyannote.py
Normal file
74
server/reflector/processors/audio_diarization_pyannote.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
|
||||||
|
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||||
|
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||||
|
from reflector.processors.types import AudioDiarizationInput, DiarizationSegment
|
||||||
|
|
||||||
|
|
||||||
|
class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
|
||||||
|
"""Local diarization processor using pyannote.audio library"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "pyannote/speaker-diarization-3.1",
|
||||||
|
pyannote_auth_token: str | None = None,
|
||||||
|
device: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.model_name = model_name
|
||||||
|
self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
if device is None:
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
self.logger.info(f"Loading pyannote diarization model: {self.model_name}")
|
||||||
|
self.diarization_pipeline = Pipeline.from_pretrained(
|
||||||
|
self.model_name, use_auth_token=self.auth_token
|
||||||
|
)
|
||||||
|
self.diarization_pipeline.to(torch.device(self.device))
|
||||||
|
self.logger.info(f"Diarization model loaded on device: {self.device}")
|
||||||
|
|
||||||
|
async def _diarize(self, data: AudioDiarizationInput) -> list[DiarizationSegment]:
|
||||||
|
try:
|
||||||
|
# Load audio file (audio_url is assumed to be a local file path)
|
||||||
|
self.logger.info(f"Loading local audio file: {data.audio_url}")
|
||||||
|
waveform, sample_rate = torchaudio.load(data.audio_url)
|
||||||
|
audio_input = {"waveform": waveform, "sample_rate": sample_rate}
|
||||||
|
self.logger.info("Running speaker diarization")
|
||||||
|
diarization = self.diarization_pipeline(audio_input)
|
||||||
|
|
||||||
|
# Convert pyannote diarization output to our format
|
||||||
|
segments = []
|
||||||
|
for segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||||
|
# Extract speaker number from label (e.g., "SPEAKER_00" -> 0)
|
||||||
|
speaker_id = 0
|
||||||
|
if speaker.startswith("SPEAKER_"):
|
||||||
|
try:
|
||||||
|
speaker_id = int(speaker.split("_")[-1])
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
# Fallback to hash-based ID if parsing fails
|
||||||
|
speaker_id = hash(speaker) % 1000
|
||||||
|
|
||||||
|
segments.append(
|
||||||
|
{
|
||||||
|
"start": round(segment.start, 3),
|
||||||
|
"end": round(segment.end, 3),
|
||||||
|
"speaker": speaker_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(f"Diarization completed with {len(segments)} segments")
|
||||||
|
return segments
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Diarization failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)
|
||||||
@@ -3,11 +3,24 @@ from time import monotonic_ns
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import av
|
import av
|
||||||
|
from av.audio.resampler import AudioResampler
|
||||||
|
|
||||||
from reflector.processors.base import Processor
|
from reflector.processors.base import Processor
|
||||||
from reflector.processors.types import AudioFile
|
from reflector.processors.types import AudioFile
|
||||||
|
|
||||||
|
|
||||||
|
def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
|
||||||
|
frame_copy = frame.from_ndarray(
|
||||||
|
frame.to_ndarray(),
|
||||||
|
format=frame.format.name,
|
||||||
|
layout=frame.layout.name,
|
||||||
|
)
|
||||||
|
frame_copy.sample_rate = frame.sample_rate
|
||||||
|
frame_copy.pts = frame.pts
|
||||||
|
frame_copy.time_base = frame.time_base
|
||||||
|
return frame_copy
|
||||||
|
|
||||||
|
|
||||||
class AudioMergeProcessor(Processor):
|
class AudioMergeProcessor(Processor):
|
||||||
"""
|
"""
|
||||||
Merge audio frame into a single file
|
Merge audio frame into a single file
|
||||||
@@ -16,37 +29,92 @@ class AudioMergeProcessor(Processor):
|
|||||||
INPUT_TYPE = list[av.AudioFrame]
|
INPUT_TYPE = list[av.AudioFrame]
|
||||||
OUTPUT_TYPE = AudioFile
|
OUTPUT_TYPE = AudioFile
|
||||||
|
|
||||||
|
def __init__(self, downsample_to_16k_mono: bool = True, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.downsample_to_16k_mono = downsample_to_16k_mono
|
||||||
|
|
||||||
async def _push(self, data: list[av.AudioFrame]):
|
async def _push(self, data: list[av.AudioFrame]):
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
# get audio information from first frame
|
# get audio information from first frame
|
||||||
frame = data[0]
|
frame = data[0]
|
||||||
channels = len(frame.layout.channels)
|
original_channels = len(frame.layout.channels)
|
||||||
sample_rate = frame.sample_rate
|
original_sample_rate = frame.sample_rate
|
||||||
sample_width = frame.format.bytes
|
original_sample_width = frame.format.bytes
|
||||||
|
|
||||||
|
# determine if we need processing
|
||||||
|
needs_processing = self.downsample_to_16k_mono and (
|
||||||
|
original_sample_rate != 16000 or original_channels != 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# determine output parameters
|
||||||
|
if self.downsample_to_16k_mono:
|
||||||
|
output_sample_rate = 16000
|
||||||
|
output_channels = 1
|
||||||
|
output_sample_width = 2 # 16-bit = 2 bytes
|
||||||
|
else:
|
||||||
|
output_sample_rate = original_sample_rate
|
||||||
|
output_channels = original_channels
|
||||||
|
output_sample_width = original_sample_width
|
||||||
|
|
||||||
# create audio file
|
# create audio file
|
||||||
uu = uuid4().hex
|
uu = uuid4().hex
|
||||||
fd = io.BytesIO()
|
fd = io.BytesIO()
|
||||||
|
|
||||||
|
if needs_processing:
|
||||||
|
# Process with PyAV resampler
|
||||||
out_container = av.open(fd, "w", format="wav")
|
out_container = av.open(fd, "w", format="wav")
|
||||||
out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
|
out_stream = out_container.add_stream("pcm_s16le", rate=16000)
|
||||||
|
out_stream.layout = "mono"
|
||||||
|
|
||||||
|
# Create resampler if needed
|
||||||
|
resampler = None
|
||||||
|
if original_sample_rate != 16000 or original_channels != 1:
|
||||||
|
resampler = AudioResampler(format="s16", layout="mono", rate=16000)
|
||||||
|
|
||||||
for frame in data:
|
for frame in data:
|
||||||
|
if resampler:
|
||||||
|
# Resample and convert to mono
|
||||||
|
# XXX for an unknown reason, if we don't use a copy of the frame, we get
|
||||||
|
# Invalid Argumment from resample. Debugging indicate that when a previous processor
|
||||||
|
# already used the frame (like AudioFileWriter), it make it invalid argument here.
|
||||||
|
resampled_frames = resampler.resample(copy_frame(frame))
|
||||||
|
for resampled_frame in resampled_frames:
|
||||||
|
for packet in out_stream.encode(resampled_frame):
|
||||||
|
out_container.mux(packet)
|
||||||
|
else:
|
||||||
|
# Direct encoding without resampling
|
||||||
for packet in out_stream.encode(frame):
|
for packet in out_stream.encode(frame):
|
||||||
out_container.mux(packet)
|
out_container.mux(packet)
|
||||||
|
|
||||||
|
# Flush the encoder
|
||||||
for packet in out_stream.encode(None):
|
for packet in out_stream.encode(None):
|
||||||
out_container.mux(packet)
|
out_container.mux(packet)
|
||||||
out_container.close()
|
out_container.close()
|
||||||
|
else:
|
||||||
|
# Use PyAV for original frames (no processing needed)
|
||||||
|
out_container = av.open(fd, "w", format="wav")
|
||||||
|
out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
|
||||||
|
out_stream.layout = "mono" if output_channels == 1 else frame.layout
|
||||||
|
|
||||||
|
for frame in data:
|
||||||
|
for packet in out_stream.encode(frame):
|
||||||
|
out_container.mux(packet)
|
||||||
|
|
||||||
|
for packet in out_stream.encode(None):
|
||||||
|
out_container.mux(packet)
|
||||||
|
out_container.close()
|
||||||
|
|
||||||
fd.seek(0)
|
fd.seek(0)
|
||||||
|
|
||||||
# emit audio file
|
# emit audio file
|
||||||
audiofile = AudioFile(
|
audiofile = AudioFile(
|
||||||
name=f"{monotonic_ns()}-{uu}.wav",
|
name=f"{monotonic_ns()}-{uu}.wav",
|
||||||
fd=fd,
|
fd=fd,
|
||||||
sample_rate=sample_rate,
|
sample_rate=output_sample_rate,
|
||||||
channels=channels,
|
channels=output_channels,
|
||||||
sample_width=sample_width,
|
sample_width=output_sample_width,
|
||||||
timestamp=data[0].pts * data[0].time_base,
|
timestamp=data[0].pts * data[0].time_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,9 @@ API will be a POST request to TRANSCRIPT_URL:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||||
@@ -21,7 +24,9 @@ from reflector.settings import settings
|
|||||||
|
|
||||||
|
|
||||||
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
def __init__(
|
||||||
|
self, modal_api_key: str | None = None, batch_enabled: bool = True, **kwargs
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not settings.TRANSCRIPT_URL:
|
if not settings.TRANSCRIPT_URL:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
@@ -30,6 +35,126 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
|||||||
self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
|
self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
|
||||||
self.timeout = settings.TRANSCRIPT_TIMEOUT
|
self.timeout = settings.TRANSCRIPT_TIMEOUT
|
||||||
self.modal_api_key = modal_api_key
|
self.modal_api_key = modal_api_key
|
||||||
|
self.max_batch_duration = 10.0
|
||||||
|
self.max_batch_files = 15
|
||||||
|
self.batch_enabled = batch_enabled
|
||||||
|
self.pending_files: List[AudioFile] = [] # Files waiting to be processed
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _calculate_duration(cls, audio_file: AudioFile) -> float:
|
||||||
|
"""Calculate audio duration in seconds from AudioFile metadata"""
|
||||||
|
# Duration = total_samples / sample_rate
|
||||||
|
# We need to estimate total samples from the file data
|
||||||
|
import wave
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to read as WAV file to get duration
|
||||||
|
audio_file.fd.seek(0)
|
||||||
|
with wave.open(audio_file.fd, "rb") as wav_file:
|
||||||
|
frames = wav_file.getnframes()
|
||||||
|
sample_rate = wav_file.getframerate()
|
||||||
|
duration = frames / sample_rate
|
||||||
|
return duration
|
||||||
|
except Exception:
|
||||||
|
# Fallback: estimate from file size and audio parameters
|
||||||
|
audio_file.fd.seek(0, 2) # Seek to end
|
||||||
|
file_size = audio_file.fd.tell()
|
||||||
|
audio_file.fd.seek(0) # Reset to beginning
|
||||||
|
|
||||||
|
# Estimate: file_size / (sample_rate * channels * sample_width)
|
||||||
|
bytes_per_second = (
|
||||||
|
audio_file.sample_rate
|
||||||
|
* audio_file.channels
|
||||||
|
* (audio_file.sample_width // 8)
|
||||||
|
)
|
||||||
|
estimated_duration = (
|
||||||
|
file_size / bytes_per_second if bytes_per_second > 0 else 0
|
||||||
|
)
|
||||||
|
return max(0, estimated_duration)
|
||||||
|
|
||||||
|
def _create_batches(self, audio_files: List[AudioFile]) -> List[List[AudioFile]]:
|
||||||
|
"""Group audio files into batches with maximum 30s total duration"""
|
||||||
|
batches = []
|
||||||
|
current_batch = []
|
||||||
|
current_duration = 0.0
|
||||||
|
|
||||||
|
for audio_file in audio_files:
|
||||||
|
duration = self._calculate_duration(audio_file)
|
||||||
|
|
||||||
|
# If adding this file exceeds max duration, start a new batch
|
||||||
|
if current_duration + duration > self.max_batch_duration and current_batch:
|
||||||
|
batches.append(current_batch)
|
||||||
|
current_batch = [audio_file]
|
||||||
|
current_duration = duration
|
||||||
|
else:
|
||||||
|
current_batch.append(audio_file)
|
||||||
|
current_duration += duration
|
||||||
|
|
||||||
|
# Add the last batch if not empty
|
||||||
|
if current_batch:
|
||||||
|
batches.append(current_batch)
|
||||||
|
|
||||||
|
return batches
|
||||||
|
|
||||||
|
async def _transcript_batch(self, audio_files: List[AudioFile]) -> List[Transcript]:
|
||||||
|
"""Transcribe a batch of audio files using the parakeet backend"""
|
||||||
|
if not audio_files:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self.logger.debug(f"Batch transcribing {len(audio_files)} files")
|
||||||
|
|
||||||
|
# Prepare form data for batch request
|
||||||
|
data = aiohttp.FormData()
|
||||||
|
data.add_field("language", self.get_pref("audio:source_language", "en"))
|
||||||
|
data.add_field("batch", "true")
|
||||||
|
|
||||||
|
for i, audio_file in enumerate(audio_files):
|
||||||
|
audio_file.fd.seek(0)
|
||||||
|
data.add_field(
|
||||||
|
"files",
|
||||||
|
audio_file.fd,
|
||||||
|
filename=f"{audio_file.name}",
|
||||||
|
content_type="audio/wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make batch request
|
||||||
|
headers = {"Authorization": f"Bearer {self.modal_api_key}"}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(
|
||||||
|
timeout=aiohttp.ClientTimeout(total=self.timeout)
|
||||||
|
) as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.transcript_url}/audio/transcriptions",
|
||||||
|
data=data,
|
||||||
|
headers=headers,
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
error_text = await response.text()
|
||||||
|
raise Exception(
|
||||||
|
f"Batch transcription failed: {response.status} {error_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
|
||||||
|
# Process batch results
|
||||||
|
transcripts = []
|
||||||
|
results = result.get("results", [])
|
||||||
|
|
||||||
|
for i, (audio_file, file_result) in enumerate(zip(audio_files, results)):
|
||||||
|
transcript = Transcript(
|
||||||
|
words=[
|
||||||
|
Word(
|
||||||
|
text=word_info["word"],
|
||||||
|
start=word_info["start"],
|
||||||
|
end=word_info["end"],
|
||||||
|
)
|
||||||
|
for word_info in file_result.get("words", [])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
transcript.add_offset(audio_file.timestamp)
|
||||||
|
transcripts.append(transcript)
|
||||||
|
|
||||||
|
return transcripts
|
||||||
|
|
||||||
async def _transcript(self, data: AudioFile):
|
async def _transcript(self, data: AudioFile):
|
||||||
async with AsyncOpenAI(
|
async with AsyncOpenAI(
|
||||||
@@ -62,5 +187,96 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
|||||||
|
|
||||||
return transcript
|
return transcript
|
||||||
|
|
||||||
|
async def transcript_multiple(
|
||||||
|
self, audio_files: List[AudioFile]
|
||||||
|
) -> List[Transcript]:
|
||||||
|
"""Transcribe multiple audio files using batching"""
|
||||||
|
if len(audio_files) == 1:
|
||||||
|
# Single file, use existing method
|
||||||
|
return [await self._transcript(audio_files[0])]
|
||||||
|
|
||||||
|
# Create batches with max 30s duration each
|
||||||
|
batches = self._create_batches(audio_files)
|
||||||
|
|
||||||
|
self.logger.debug(
|
||||||
|
f"Processing {len(audio_files)} files in {len(batches)} batches"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process all batches concurrently
|
||||||
|
all_transcripts = []
|
||||||
|
|
||||||
|
for batch in batches:
|
||||||
|
batch_transcripts = await self._transcript_batch(batch)
|
||||||
|
all_transcripts.extend(batch_transcripts)
|
||||||
|
|
||||||
|
return all_transcripts
|
||||||
|
|
||||||
|
async def _push(self, data: AudioFile):
|
||||||
|
"""Override _push to support batching"""
|
||||||
|
if not self.batch_enabled:
|
||||||
|
# Use parent implementation for single file processing
|
||||||
|
return await super()._push(data)
|
||||||
|
|
||||||
|
# Add file to pending batch
|
||||||
|
self.pending_files.append(data)
|
||||||
|
self.logger.debug(
|
||||||
|
f"Added file to batch: {data.name}, batch size: {len(self.pending_files)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate total duration of pending files
|
||||||
|
total_duration = sum(self._calculate_duration(f) for f in self.pending_files)
|
||||||
|
|
||||||
|
# Process batch if it reaches max duration or has multiple files ready for optimization
|
||||||
|
should_process_batch = (
|
||||||
|
total_duration >= self.max_batch_duration
|
||||||
|
or len(self.pending_files) >= self.max_batch_files
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_process_batch:
|
||||||
|
await self._process_pending_batch()
|
||||||
|
|
||||||
|
async def _process_pending_batch(self):
|
||||||
|
"""Process all pending files as batches"""
|
||||||
|
if not self.pending_files:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.logger.debug(f"Processing batch of {len(self.pending_files)} files")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create batches respecting duration limit
|
||||||
|
batches = self._create_batches(self.pending_files)
|
||||||
|
|
||||||
|
# Process each batch
|
||||||
|
for batch in batches:
|
||||||
|
self.m_transcript_call.inc()
|
||||||
|
try:
|
||||||
|
with self.m_transcript.time():
|
||||||
|
# Use batch transcription
|
||||||
|
transcripts = await self._transcript_batch(batch)
|
||||||
|
|
||||||
|
self.m_transcript_success.inc()
|
||||||
|
|
||||||
|
# Emit each transcript
|
||||||
|
for transcript in transcripts:
|
||||||
|
if transcript:
|
||||||
|
await self.emit(transcript)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
self.m_transcript_failure.inc()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Release audio files
|
||||||
|
for audio_file in batch:
|
||||||
|
audio_file.release()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clear pending files
|
||||||
|
self.pending_files.clear()
|
||||||
|
|
||||||
|
async def _flush(self):
|
||||||
|
"""Process any remaining files when flushing"""
|
||||||
|
await self._process_pending_batch()
|
||||||
|
await super()._flush()
|
||||||
|
|
||||||
|
|
||||||
AudioTranscriptAutoProcessor.register("modal", AudioTranscriptModalProcessor)
|
AudioTranscriptAutoProcessor.register("modal", AudioTranscriptModalProcessor)
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ class Processor(Emitter):
|
|||||||
except Exception:
|
except Exception:
|
||||||
self.m_processor_failure.inc()
|
self.m_processor_failure.inc()
|
||||||
self.logger.exception("Error in push")
|
self.logger.exception("Error in push")
|
||||||
|
raise
|
||||||
|
|
||||||
async def flush(self):
|
async def flush(self):
|
||||||
"""
|
"""
|
||||||
@@ -240,14 +241,15 @@ class ThreadedProcessor(Processor):
|
|||||||
self.INPUT_TYPE = processor.INPUT_TYPE
|
self.INPUT_TYPE = processor.INPUT_TYPE
|
||||||
self.OUTPUT_TYPE = processor.OUTPUT_TYPE
|
self.OUTPUT_TYPE = processor.OUTPUT_TYPE
|
||||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||||
self.queue = asyncio.Queue()
|
self.queue = asyncio.Queue(maxsize=50)
|
||||||
self.task = asyncio.get_running_loop().create_task(self.loop())
|
self.task: asyncio.Task | None = None
|
||||||
|
|
||||||
def set_pipeline(self, pipeline: "Pipeline"):
|
def set_pipeline(self, pipeline: "Pipeline"):
|
||||||
super().set_pipeline(pipeline)
|
super().set_pipeline(pipeline)
|
||||||
self.processor.set_pipeline(pipeline)
|
self.processor.set_pipeline(pipeline)
|
||||||
|
|
||||||
async def loop(self):
|
async def loop(self):
|
||||||
|
try:
|
||||||
while True:
|
while True:
|
||||||
data = await self.queue.get()
|
data = await self.queue.get()
|
||||||
self.m_processor_queue.set(self.queue.qsize())
|
self.m_processor_queue.set(self.queue.qsize())
|
||||||
@@ -265,8 +267,19 @@ class ThreadedProcessor(Processor):
|
|||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self.queue.task_done()
|
self.queue.task_done()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Crash in {self.__class__.__name__}: {e}", exc_info=e)
|
||||||
|
|
||||||
|
async def _ensure_task(self):
|
||||||
|
if self.task is None:
|
||||||
|
self.task = asyncio.get_running_loop().create_task(self.loop())
|
||||||
|
|
||||||
|
# XXX not doing a sleep here make the whole pipeline prior the thread
|
||||||
|
# to be running without having a chance to work on the task here.
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
async def _push(self, data):
|
async def _push(self, data):
|
||||||
|
await self._ensure_task()
|
||||||
await self.queue.put(data)
|
await self.queue.put(data)
|
||||||
|
|
||||||
async def _flush(self):
|
async def _flush(self):
|
||||||
|
|||||||
33
server/reflector/processors/file_diarization.py
Normal file
33
server/reflector/processors/file_diarization.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from reflector.processors.base import Processor
|
||||||
|
from reflector.processors.types import DiarizationSegment
|
||||||
|
|
||||||
|
|
||||||
|
class FileDiarizationInput(BaseModel):
|
||||||
|
"""Input for file diarization containing audio URL"""
|
||||||
|
|
||||||
|
audio_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class FileDiarizationOutput(BaseModel):
|
||||||
|
"""Output for file diarization containing speaker segments"""
|
||||||
|
|
||||||
|
diarization: list[DiarizationSegment]
|
||||||
|
|
||||||
|
|
||||||
|
class FileDiarizationProcessor(Processor):
|
||||||
|
"""
|
||||||
|
Diarize complete audio files from URL
|
||||||
|
"""
|
||||||
|
|
||||||
|
INPUT_TYPE = FileDiarizationInput
|
||||||
|
OUTPUT_TYPE = FileDiarizationOutput
|
||||||
|
|
||||||
|
async def _push(self, data: FileDiarizationInput):
|
||||||
|
result = await self._diarize(data)
|
||||||
|
if result:
|
||||||
|
await self.emit(result)
|
||||||
|
|
||||||
|
async def _diarize(self, data: FileDiarizationInput):
|
||||||
|
raise NotImplementedError
|
||||||
33
server/reflector/processors/file_diarization_auto.py
Normal file
33
server/reflector/processors/file_diarization_auto.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
|
from reflector.processors.file_diarization import FileDiarizationProcessor
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class FileDiarizationAutoProcessor(FileDiarizationProcessor):
|
||||||
|
_registry = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name, kclass):
|
||||||
|
cls._registry[name] = kclass
|
||||||
|
|
||||||
|
def __new__(cls, name: str | None = None, **kwargs):
|
||||||
|
if name is None:
|
||||||
|
name = settings.DIARIZATION_BACKEND
|
||||||
|
|
||||||
|
if name not in cls._registry:
|
||||||
|
module_name = f"reflector.processors.file_diarization_{name}"
|
||||||
|
importlib.import_module(module_name)
|
||||||
|
|
||||||
|
# gather specific configuration for the processor
|
||||||
|
# search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||||
|
config = {}
|
||||||
|
name_upper = name.upper()
|
||||||
|
settings_prefix = "DIARIZATION_"
|
||||||
|
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||||
|
for key, value in settings:
|
||||||
|
if key.startswith(config_prefix):
|
||||||
|
config_name = key[len(settings_prefix) :].lower()
|
||||||
|
config[config_name] = value
|
||||||
|
|
||||||
|
return cls._registry[name](**config | kwargs)
|
||||||
57
server/reflector/processors/file_diarization_modal.py
Normal file
57
server/reflector/processors/file_diarization_modal.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""
|
||||||
|
File diarization implementation using the GPU service from modal.com
|
||||||
|
|
||||||
|
API will be a POST request to DIARIZATION_URL:
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /diarize?audio_file_url=...×tamp=0
|
||||||
|
Authorization: Bearer <modal_api_key>
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from reflector.processors.file_diarization import (
|
||||||
|
FileDiarizationInput,
|
||||||
|
FileDiarizationOutput,
|
||||||
|
FileDiarizationProcessor,
|
||||||
|
)
|
||||||
|
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class FileDiarizationModalProcessor(FileDiarizationProcessor):
|
||||||
|
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
if not settings.DIARIZATION_URL:
|
||||||
|
raise Exception(
|
||||||
|
"DIARIZATION_URL required to use FileDiarizationModalProcessor"
|
||||||
|
)
|
||||||
|
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
|
||||||
|
self.file_timeout = settings.DIARIZATION_FILE_TIMEOUT
|
||||||
|
self.modal_api_key = modal_api_key
|
||||||
|
|
||||||
|
async def _diarize(self, data: FileDiarizationInput):
|
||||||
|
"""Get speaker diarization for file"""
|
||||||
|
self.logger.info(f"Starting diarization from {data.audio_url}")
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
if self.modal_api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.diarization_url,
|
||||||
|
headers=headers,
|
||||||
|
params={
|
||||||
|
"audio_file_url": data.audio_url,
|
||||||
|
"timestamp": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
diarization_data = response.json()["diarization"]
|
||||||
|
|
||||||
|
return FileDiarizationOutput(diarization=diarization_data)
|
||||||
|
|
||||||
|
|
||||||
|
FileDiarizationAutoProcessor.register("modal", FileDiarizationModalProcessor)
|
||||||
65
server/reflector/processors/file_transcript.py
Normal file
65
server/reflector/processors/file_transcript.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from prometheus_client import Counter, Histogram
|
||||||
|
|
||||||
|
from reflector.processors.base import Processor
|
||||||
|
from reflector.processors.types import Transcript
|
||||||
|
|
||||||
|
|
||||||
|
class FileTranscriptInput:
|
||||||
|
"""Input for file transcription containing audio URL and language settings"""
|
||||||
|
|
||||||
|
def __init__(self, audio_url: str, language: str = "en"):
|
||||||
|
self.audio_url = audio_url
|
||||||
|
self.language = language
|
||||||
|
|
||||||
|
|
||||||
|
class FileTranscriptProcessor(Processor):
|
||||||
|
"""
|
||||||
|
Transcript complete audio files from URL
|
||||||
|
"""
|
||||||
|
|
||||||
|
INPUT_TYPE = FileTranscriptInput
|
||||||
|
OUTPUT_TYPE = Transcript
|
||||||
|
|
||||||
|
m_transcript = Histogram(
|
||||||
|
"file_transcript",
|
||||||
|
"Time spent in FileTranscript.transcript",
|
||||||
|
["backend"],
|
||||||
|
)
|
||||||
|
m_transcript_call = Counter(
|
||||||
|
"file_transcript_call",
|
||||||
|
"Number of calls to FileTranscript.transcript",
|
||||||
|
["backend"],
|
||||||
|
)
|
||||||
|
m_transcript_success = Counter(
|
||||||
|
"file_transcript_success",
|
||||||
|
"Number of successful calls to FileTranscript.transcript",
|
||||||
|
["backend"],
|
||||||
|
)
|
||||||
|
m_transcript_failure = Counter(
|
||||||
|
"file_transcript_failure",
|
||||||
|
"Number of failed calls to FileTranscript.transcript",
|
||||||
|
["backend"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
name = self.__class__.__name__
|
||||||
|
self.m_transcript = self.m_transcript.labels(name)
|
||||||
|
self.m_transcript_call = self.m_transcript_call.labels(name)
|
||||||
|
self.m_transcript_success = self.m_transcript_success.labels(name)
|
||||||
|
self.m_transcript_failure = self.m_transcript_failure.labels(name)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
async def _push(self, data: FileTranscriptInput):
|
||||||
|
try:
|
||||||
|
self.m_transcript_call.inc()
|
||||||
|
with self.m_transcript.time():
|
||||||
|
result = await self._transcript(data)
|
||||||
|
self.m_transcript_success.inc()
|
||||||
|
if result:
|
||||||
|
await self.emit(result)
|
||||||
|
except Exception:
|
||||||
|
self.m_transcript_failure.inc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _transcript(self, data: FileTranscriptInput):
|
||||||
|
raise NotImplementedError
|
||||||
32
server/reflector/processors/file_transcript_auto.py
Normal file
32
server/reflector/processors/file_transcript_auto.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
|
from reflector.processors.file_transcript import FileTranscriptProcessor
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class FileTranscriptAutoProcessor(FileTranscriptProcessor):
|
||||||
|
_registry = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name, kclass):
|
||||||
|
cls._registry[name] = kclass
|
||||||
|
|
||||||
|
def __new__(cls, name: str | None = None, **kwargs):
|
||||||
|
if name is None:
|
||||||
|
name = settings.TRANSCRIPT_BACKEND
|
||||||
|
if name not in cls._registry:
|
||||||
|
module_name = f"reflector.processors.file_transcript_{name}"
|
||||||
|
importlib.import_module(module_name)
|
||||||
|
|
||||||
|
# gather specific configuration for the processor
|
||||||
|
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||||
|
config = {}
|
||||||
|
name_upper = name.upper()
|
||||||
|
settings_prefix = "TRANSCRIPT_"
|
||||||
|
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||||
|
for key, value in settings:
|
||||||
|
if key.startswith(config_prefix):
|
||||||
|
config_name = key[len(settings_prefix) :].lower()
|
||||||
|
config[config_name] = value
|
||||||
|
|
||||||
|
return cls._registry[name](**config | kwargs)
|
||||||
74
server/reflector/processors/file_transcript_modal.py
Normal file
74
server/reflector/processors/file_transcript_modal.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""
|
||||||
|
File transcription implementation using the GPU service from modal.com
|
||||||
|
|
||||||
|
API will be a POST request to TRANSCRIPT_URL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"audio_file_url": "https://...",
|
||||||
|
"language": "en",
|
||||||
|
"model": "parakeet-tdt-0.6b-v2",
|
||||||
|
"batch": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from reflector.processors.file_transcript import (
|
||||||
|
FileTranscriptInput,
|
||||||
|
FileTranscriptProcessor,
|
||||||
|
)
|
||||||
|
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||||
|
from reflector.processors.types import Transcript, Word
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class FileTranscriptModalProcessor(FileTranscriptProcessor):
|
||||||
|
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
if not settings.TRANSCRIPT_URL:
|
||||||
|
raise Exception(
|
||||||
|
"TRANSCRIPT_URL required to use FileTranscriptModalProcessor"
|
||||||
|
)
|
||||||
|
self.transcript_url = settings.TRANSCRIPT_URL
|
||||||
|
self.file_timeout = settings.TRANSCRIPT_FILE_TIMEOUT
|
||||||
|
self.modal_api_key = modal_api_key
|
||||||
|
|
||||||
|
async def _transcript(self, data: FileTranscriptInput):
|
||||||
|
"""Send full file to Modal for transcription"""
|
||||||
|
url = f"{self.transcript_url}/v1/audio/transcriptions-from-url"
|
||||||
|
|
||||||
|
self.logger.info(f"Starting file transcription from {data.audio_url}")
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
if self.modal_api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json={
|
||||||
|
"audio_file_url": data.audio_url,
|
||||||
|
"language": data.language,
|
||||||
|
"batch": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
words = [
|
||||||
|
Word(
|
||||||
|
text=word_info["word"],
|
||||||
|
start=word_info["start"],
|
||||||
|
end=word_info["end"],
|
||||||
|
)
|
||||||
|
for word_info in result.get("words", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
return Transcript(words=words)
|
||||||
|
|
||||||
|
|
||||||
|
# Register with the auto processor
|
||||||
|
FileTranscriptAutoProcessor.register("modal", FileTranscriptModalProcessor)
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
"""
|
||||||
|
Processor to assemble transcript with diarization results
|
||||||
|
"""
|
||||||
|
|
||||||
|
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||||
|
from reflector.processors.base import Processor
|
||||||
|
from reflector.processors.types import DiarizationSegment, Transcript
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptDiarizationAssemblerInput:
|
||||||
|
"""Input containing transcript and diarization data"""
|
||||||
|
|
||||||
|
def __init__(self, transcript: Transcript, diarization: list[DiarizationSegment]):
|
||||||
|
self.transcript = transcript
|
||||||
|
self.diarization = diarization
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptDiarizationAssemblerProcessor(Processor):
|
||||||
|
"""
|
||||||
|
Assemble transcript with diarization results by applying speaker assignments
|
||||||
|
"""
|
||||||
|
|
||||||
|
INPUT_TYPE = TranscriptDiarizationAssemblerInput
|
||||||
|
OUTPUT_TYPE = Transcript
|
||||||
|
|
||||||
|
async def _push(self, data: TranscriptDiarizationAssemblerInput):
|
||||||
|
result = await self._assemble(data)
|
||||||
|
if result:
|
||||||
|
await self.emit(result)
|
||||||
|
|
||||||
|
async def _assemble(self, data: TranscriptDiarizationAssemblerInput):
|
||||||
|
"""Apply diarization to transcript words"""
|
||||||
|
if not data.diarization:
|
||||||
|
self.logger.info(
|
||||||
|
"No diarization data provided, returning original transcript"
|
||||||
|
)
|
||||||
|
return data.transcript
|
||||||
|
|
||||||
|
# Reuse logic from AudioDiarizationProcessor
|
||||||
|
processor = AudioDiarizationProcessor()
|
||||||
|
words = data.transcript.words
|
||||||
|
processor.assign_speaker(words, data.diarization)
|
||||||
|
|
||||||
|
self.logger.info(f"Applied diarization to {len(words)} words")
|
||||||
|
return data.transcript
|
||||||
@@ -2,13 +2,22 @@ import io
|
|||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated
|
from typing import Annotated, TypedDict
|
||||||
|
|
||||||
from profanityfilter import ProfanityFilter
|
from profanityfilter import ProfanityFilter
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
from reflector.redis_cache import redis_cache
|
from reflector.redis_cache import redis_cache
|
||||||
|
|
||||||
|
|
||||||
|
class DiarizationSegment(TypedDict):
|
||||||
|
"""Type definition for diarization segment containing speaker information"""
|
||||||
|
|
||||||
|
start: float
|
||||||
|
end: float
|
||||||
|
speaker: int
|
||||||
|
|
||||||
|
|
||||||
PUNC_RE = re.compile(r"[.;:?!…]")
|
PUNC_RE = re.compile(r"[.;:?!…]")
|
||||||
|
|
||||||
profanity_filter = ProfanityFilter()
|
profanity_filter = ProfanityFilter()
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class Settings(BaseSettings):
|
|||||||
TRANSCRIPT_BACKEND: str = "whisper"
|
TRANSCRIPT_BACKEND: str = "whisper"
|
||||||
TRANSCRIPT_URL: str | None = None
|
TRANSCRIPT_URL: str | None = None
|
||||||
TRANSCRIPT_TIMEOUT: int = 90
|
TRANSCRIPT_TIMEOUT: int = 90
|
||||||
|
TRANSCRIPT_FILE_TIMEOUT: int = 600
|
||||||
|
|
||||||
# Audio Transcription: modal backend
|
# Audio Transcription: modal backend
|
||||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||||
@@ -66,10 +67,14 @@ class Settings(BaseSettings):
|
|||||||
DIARIZATION_ENABLED: bool = True
|
DIARIZATION_ENABLED: bool = True
|
||||||
DIARIZATION_BACKEND: str = "modal"
|
DIARIZATION_BACKEND: str = "modal"
|
||||||
DIARIZATION_URL: str | None = None
|
DIARIZATION_URL: str | None = None
|
||||||
|
DIARIZATION_FILE_TIMEOUT: int = 600
|
||||||
|
|
||||||
# Diarization: modal backend
|
# Diarization: modal backend
|
||||||
DIARIZATION_MODAL_API_KEY: str | None = None
|
DIARIZATION_MODAL_API_KEY: str | None = None
|
||||||
|
|
||||||
|
# Diarization: local pyannote.audio
|
||||||
|
DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None
|
||||||
|
|
||||||
# Sentry
|
# Sentry
|
||||||
SENTRY_DSN: str | None = None
|
SENTRY_DSN: str | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,23 @@
|
|||||||
|
"""
|
||||||
|
Process audio file with diarization support
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Extended version of process.py that includes speaker diarization.
|
||||||
|
This tool processes audio files locally without requiring the full server infrastructure.
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import tempfile
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import av
|
import av
|
||||||
|
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.processors import (
|
from reflector.processors import (
|
||||||
AudioChunkerProcessor,
|
AudioChunkerProcessor,
|
||||||
|
AudioFileWriterProcessor,
|
||||||
AudioMergeProcessor,
|
AudioMergeProcessor,
|
||||||
AudioTranscriptAutoProcessor,
|
AudioTranscriptAutoProcessor,
|
||||||
Pipeline,
|
Pipeline,
|
||||||
@@ -15,7 +28,43 @@ from reflector.processors import (
|
|||||||
TranscriptTopicDetectorProcessor,
|
TranscriptTopicDetectorProcessor,
|
||||||
TranscriptTranslatorAutoProcessor,
|
TranscriptTranslatorAutoProcessor,
|
||||||
)
|
)
|
||||||
from reflector.processors.base import BroadcastProcessor
|
from reflector.processors.base import BroadcastProcessor, Processor
|
||||||
|
from reflector.processors.types import (
|
||||||
|
AudioDiarizationInput,
|
||||||
|
TitleSummary,
|
||||||
|
TitleSummaryWithId,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TopicCollectorProcessor(Processor):
|
||||||
|
"""Collect topics for diarization"""
|
||||||
|
|
||||||
|
INPUT_TYPE = TitleSummary
|
||||||
|
OUTPUT_TYPE = TitleSummary
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.topics: List[TitleSummaryWithId] = []
|
||||||
|
self._topic_id = 0
|
||||||
|
|
||||||
|
async def _push(self, data: TitleSummary):
|
||||||
|
# Convert to TitleSummaryWithId and collect
|
||||||
|
self._topic_id += 1
|
||||||
|
topic_with_id = TitleSummaryWithId(
|
||||||
|
id=str(self._topic_id),
|
||||||
|
title=data.title,
|
||||||
|
summary=data.summary,
|
||||||
|
timestamp=data.timestamp,
|
||||||
|
duration=data.duration,
|
||||||
|
transcript=data.transcript,
|
||||||
|
)
|
||||||
|
self.topics.append(topic_with_id)
|
||||||
|
|
||||||
|
# Pass through the original topic
|
||||||
|
await self.emit(data)
|
||||||
|
|
||||||
|
def get_topics(self) -> List[TitleSummaryWithId]:
|
||||||
|
return self.topics
|
||||||
|
|
||||||
|
|
||||||
async def process_audio_file(
|
async def process_audio_file(
|
||||||
@@ -24,18 +73,40 @@ async def process_audio_file(
|
|||||||
only_transcript=False,
|
only_transcript=False,
|
||||||
source_language="en",
|
source_language="en",
|
||||||
target_language="en",
|
target_language="en",
|
||||||
|
enable_diarization=True,
|
||||||
|
diarization_backend="pyannote",
|
||||||
):
|
):
|
||||||
# build pipeline for audio processing
|
# Create temp file for audio if diarization is enabled
|
||||||
processors = [
|
audio_temp_path = None
|
||||||
|
if enable_diarization:
|
||||||
|
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||||
|
audio_temp_path = audio_temp_file.name
|
||||||
|
audio_temp_file.close()
|
||||||
|
|
||||||
|
# Create processor for collecting topics
|
||||||
|
topic_collector = TopicCollectorProcessor()
|
||||||
|
|
||||||
|
# Build pipeline for audio processing
|
||||||
|
processors = []
|
||||||
|
|
||||||
|
# Add audio file writer at the beginning if diarization is enabled
|
||||||
|
if enable_diarization:
|
||||||
|
processors.append(AudioFileWriterProcessor(audio_temp_path))
|
||||||
|
|
||||||
|
# Add the rest of the processors
|
||||||
|
processors += [
|
||||||
AudioChunkerProcessor(),
|
AudioChunkerProcessor(),
|
||||||
AudioMergeProcessor(),
|
AudioMergeProcessor(),
|
||||||
AudioTranscriptAutoProcessor.as_threaded(),
|
AudioTranscriptAutoProcessor.as_threaded(),
|
||||||
TranscriptLinerProcessor(),
|
TranscriptLinerProcessor(),
|
||||||
TranscriptTranslatorAutoProcessor.as_threaded(),
|
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||||
]
|
]
|
||||||
|
|
||||||
if not only_transcript:
|
if not only_transcript:
|
||||||
processors += [
|
processors += [
|
||||||
TranscriptTopicDetectorProcessor.as_threaded(),
|
TranscriptTopicDetectorProcessor.as_threaded(),
|
||||||
|
# Collect topics for diarization
|
||||||
|
topic_collector,
|
||||||
BroadcastProcessor(
|
BroadcastProcessor(
|
||||||
processors=[
|
processors=[
|
||||||
TranscriptFinalTitleProcessor.as_threaded(),
|
TranscriptFinalTitleProcessor.as_threaded(),
|
||||||
@@ -44,14 +115,14 @@ async def process_audio_file(
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# transcription output
|
# Create main pipeline
|
||||||
pipeline = Pipeline(*processors)
|
pipeline = Pipeline(*processors)
|
||||||
pipeline.set_pref("audio:source_language", source_language)
|
pipeline.set_pref("audio:source_language", source_language)
|
||||||
pipeline.set_pref("audio:target_language", target_language)
|
pipeline.set_pref("audio:target_language", target_language)
|
||||||
pipeline.describe()
|
pipeline.describe()
|
||||||
pipeline.on(event_callback)
|
pipeline.on(event_callback)
|
||||||
|
|
||||||
# start processing audio
|
# Start processing audio
|
||||||
logger.info(f"Opening {filename}")
|
logger.info(f"Opening {filename}")
|
||||||
container = av.open(filename)
|
container = av.open(filename)
|
||||||
try:
|
try:
|
||||||
@@ -62,34 +133,219 @@ async def process_audio_file(
|
|||||||
logger.info("Flushing the pipeline")
|
logger.info("Flushing the pipeline")
|
||||||
await pipeline.flush()
|
await pipeline.flush()
|
||||||
|
|
||||||
logger.info("All done !")
|
# Run diarization if enabled and we have topics
|
||||||
|
if enable_diarization and not only_transcript and audio_temp_path:
|
||||||
|
topics = topic_collector.get_topics()
|
||||||
|
|
||||||
|
if topics:
|
||||||
|
logger.info(f"Starting diarization with {len(topics)} topics")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from reflector.processors import AudioDiarizationAutoProcessor
|
||||||
|
|
||||||
|
diarization_processor = AudioDiarizationAutoProcessor(
|
||||||
|
name=diarization_backend
|
||||||
|
)
|
||||||
|
|
||||||
|
diarization_processor.set_pipeline(pipeline)
|
||||||
|
|
||||||
|
# For Modal backend, we need to upload the file to S3 first
|
||||||
|
if diarization_backend == "modal":
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from reflector.storage import get_transcripts_storage
|
||||||
|
from reflector.utils.s3_temp_file import S3TemporaryFile
|
||||||
|
|
||||||
|
storage = get_transcripts_storage()
|
||||||
|
|
||||||
|
# Generate a unique filename in evaluation folder
|
||||||
|
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||||
|
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
||||||
|
|
||||||
|
# Use context manager for automatic cleanup
|
||||||
|
async with S3TemporaryFile(storage, audio_filename) as s3_file:
|
||||||
|
# Read and upload the audio file
|
||||||
|
with open(audio_temp_path, "rb") as f:
|
||||||
|
audio_data = f.read()
|
||||||
|
|
||||||
|
audio_url = await s3_file.upload(audio_data)
|
||||||
|
logger.info(f"Uploaded audio to S3: {audio_filename}")
|
||||||
|
|
||||||
|
# Create diarization input with S3 URL
|
||||||
|
diarization_input = AudioDiarizationInput(
|
||||||
|
audio_url=audio_url, topics=topics
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run diarization
|
||||||
|
await diarization_processor.push(diarization_input)
|
||||||
|
await diarization_processor.flush()
|
||||||
|
|
||||||
|
logger.info("Diarization complete")
|
||||||
|
# File will be automatically cleaned up when exiting the context
|
||||||
|
else:
|
||||||
|
# For local backend, use local file path
|
||||||
|
audio_url = audio_temp_path
|
||||||
|
|
||||||
|
# Create diarization input
|
||||||
|
diarization_input = AudioDiarizationInput(
|
||||||
|
audio_url=audio_url, topics=topics
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run diarization
|
||||||
|
await diarization_processor.push(diarization_input)
|
||||||
|
await diarization_processor.flush()
|
||||||
|
|
||||||
|
logger.info("Diarization complete")
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Failed to import diarization dependencies: {e}")
|
||||||
|
logger.error(
|
||||||
|
"Install with: uv pip install pyannote.audio torch torchaudio"
|
||||||
|
)
|
||||||
|
logger.error(
|
||||||
|
"And set HF_TOKEN environment variable for pyannote models"
|
||||||
|
)
|
||||||
|
raise SystemExit(1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Diarization failed: {e}")
|
||||||
|
raise SystemExit(1)
|
||||||
|
else:
|
||||||
|
logger.warning("Skipping diarization: no topics available")
|
||||||
|
|
||||||
|
# Clean up temp file
|
||||||
|
if audio_temp_path:
|
||||||
|
try:
|
||||||
|
Path(audio_temp_path).unlink()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
|
||||||
|
|
||||||
|
logger.info("All done!")
|
||||||
|
|
||||||
|
|
||||||
|
async def process_file_pipeline(
|
||||||
|
filename: str,
|
||||||
|
event_callback,
|
||||||
|
source_language="en",
|
||||||
|
target_language="en",
|
||||||
|
enable_diarization=True,
|
||||||
|
diarization_backend="modal",
|
||||||
|
):
|
||||||
|
"""Process audio/video file using the optimized file pipeline"""
|
||||||
|
try:
|
||||||
|
from reflector.db import database
|
||||||
|
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||||
|
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||||
|
|
||||||
|
await database.connect()
|
||||||
|
try:
|
||||||
|
# Create a temporary transcript for processing
|
||||||
|
transcript = await transcripts_controller.add(
|
||||||
|
"",
|
||||||
|
source_kind=SourceKind.FILE,
|
||||||
|
source_language=source_language,
|
||||||
|
target_language=target_language,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process the file
|
||||||
|
pipeline = PipelineMainFile(transcript_id=transcript.id)
|
||||||
|
await pipeline.process(Path(filename))
|
||||||
|
|
||||||
|
logger.info("File pipeline processing complete")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await database.disconnect()
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"File pipeline not available: {e}")
|
||||||
|
logger.info("Falling back to stream pipeline")
|
||||||
|
# Fall back to stream pipeline
|
||||||
|
await process_audio_file(
|
||||||
|
filename,
|
||||||
|
event_callback,
|
||||||
|
only_transcript=False,
|
||||||
|
source_language=source_language,
|
||||||
|
target_language=target_language,
|
||||||
|
enable_diarization=enable_diarization,
|
||||||
|
diarization_backend=diarization_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Process audio files with optional speaker diarization"
|
||||||
|
)
|
||||||
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
||||||
parser.add_argument("--only-transcript", "-t", action="store_true")
|
parser.add_argument(
|
||||||
parser.add_argument("--source-language", default="en")
|
"--stream",
|
||||||
parser.add_argument("--target-language", default="en")
|
action="store_true",
|
||||||
|
help="Use streaming pipeline (original frame-based processing)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--only-transcript",
|
||||||
|
"-t",
|
||||||
|
action="store_true",
|
||||||
|
help="Only generate transcript without topics/summaries",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--source-language", default="en", help="Source language code (default: en)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--target-language", default="en", help="Target language code (default: en)"
|
||||||
|
)
|
||||||
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
|
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-diarization",
|
||||||
|
"-d",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable speaker diarization",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--diarization-backend",
|
||||||
|
default="pyannote",
|
||||||
|
choices=["pyannote", "modal"],
|
||||||
|
help="Diarization backend to use (default: pyannote)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if "REDIS_HOST" not in os.environ:
|
||||||
|
os.environ["REDIS_HOST"] = "localhost"
|
||||||
|
|
||||||
output_fd = None
|
output_fd = None
|
||||||
if args.output:
|
if args.output:
|
||||||
output_fd = open(args.output, "w")
|
output_fd = open(args.output, "w")
|
||||||
|
|
||||||
async def event_callback(event: PipelineEvent):
|
async def event_callback(event: PipelineEvent):
|
||||||
processor = event.processor
|
processor = event.processor
|
||||||
# ignore some processor
|
data = event.data
|
||||||
if processor in ("AudioChunkerProcessor", "AudioMergeProcessor"):
|
|
||||||
|
# Ignore internal processors
|
||||||
|
if processor in (
|
||||||
|
"AudioChunkerProcessor",
|
||||||
|
"AudioMergeProcessor",
|
||||||
|
"AudioFileWriterProcessor",
|
||||||
|
"TopicCollectorProcessor",
|
||||||
|
"BroadcastProcessor",
|
||||||
|
):
|
||||||
return
|
return
|
||||||
logger.info(f"Event: {event}")
|
|
||||||
|
# If diarization is enabled, skip the original topic events from the pipeline
|
||||||
|
# The diarization processor will emit the same topics but with speaker info
|
||||||
|
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Log all events
|
||||||
|
logger.info(f"Event: {processor} - {type(data).__name__}")
|
||||||
|
|
||||||
|
# Write to output
|
||||||
if output_fd:
|
if output_fd:
|
||||||
output_fd.write(event.model_dump_json())
|
output_fd.write(event.model_dump_json())
|
||||||
output_fd.write("\n")
|
output_fd.write("\n")
|
||||||
|
output_fd.flush()
|
||||||
|
|
||||||
|
if args.stream:
|
||||||
|
# Use original streaming pipeline
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
process_audio_file(
|
process_audio_file(
|
||||||
args.source,
|
args.source,
|
||||||
@@ -97,6 +353,20 @@ if __name__ == "__main__":
|
|||||||
only_transcript=args.only_transcript,
|
only_transcript=args.only_transcript,
|
||||||
source_language=args.source_language,
|
source_language=args.source_language,
|
||||||
target_language=args.target_language,
|
target_language=args.target_language,
|
||||||
|
enable_diarization=args.enable_diarization,
|
||||||
|
diarization_backend=args.diarization_backend,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use optimized file pipeline (default)
|
||||||
|
asyncio.run(
|
||||||
|
process_file_pipeline(
|
||||||
|
args.source,
|
||||||
|
event_callback,
|
||||||
|
source_language=args.source_language,
|
||||||
|
target_language=args.target_language,
|
||||||
|
enable_diarization=args.enable_diarization,
|
||||||
|
diarization_backend=args.diarization_backend,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -160,6 +160,7 @@ async def transcripts_search(
|
|||||||
limit: SearchLimitParam = DEFAULT_SEARCH_LIMIT,
|
limit: SearchLimitParam = DEFAULT_SEARCH_LIMIT,
|
||||||
offset: SearchOffsetParam = 0,
|
offset: SearchOffsetParam = 0,
|
||||||
room_id: Optional[str] = None,
|
room_id: Optional[str] = None,
|
||||||
|
source_kind: Optional[SourceKind] = None,
|
||||||
user: Annotated[
|
user: Annotated[
|
||||||
Optional[auth.UserInfo], Depends(auth.current_user_optional)
|
Optional[auth.UserInfo], Depends(auth.current_user_optional)
|
||||||
] = None,
|
] = None,
|
||||||
@@ -173,7 +174,12 @@ async def transcripts_search(
|
|||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
search_params = SearchParameters(
|
search_params = SearchParameters(
|
||||||
query_text=q, limit=limit, offset=offset, user_id=user_id, room_id=room_id
|
query_text=q,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
user_id=user_id,
|
||||||
|
room_id=room_id,
|
||||||
|
source_kind=source_kind,
|
||||||
)
|
)
|
||||||
|
|
||||||
results, total = await search_controller.search_transcripts(search_params)
|
results, total = await search_controller.search_transcripts(search_params)
|
||||||
|
|||||||
@@ -14,7 +14,8 @@ from reflector.db.meetings import meetings_controller
|
|||||||
from reflector.db.recordings import Recording, recordings_controller
|
from reflector.db.recordings import Recording, recordings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||||
from reflector.pipelines.main_live_pipeline import asynctask, task_pipeline_process
|
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||||
|
from reflector.pipelines.main_live_pipeline import asynctask
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.whereby import get_room_sessions
|
from reflector.whereby import get_room_sessions
|
||||||
|
|
||||||
@@ -140,7 +141,7 @@ async def process_recording(bucket_name: str, object_key: str):
|
|||||||
|
|
||||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||||
|
|
||||||
task_pipeline_process.delay(transcript_id=transcript.id)
|
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
interactions:
|
||||||
|
- request:
|
||||||
|
body: ''
|
||||||
|
headers:
|
||||||
|
accept:
|
||||||
|
- '*/*'
|
||||||
|
accept-encoding:
|
||||||
|
- gzip, deflate
|
||||||
|
authorization:
|
||||||
|
- DUMMY_API_KEY
|
||||||
|
connection:
|
||||||
|
- keep-alive
|
||||||
|
content-length:
|
||||||
|
- '0'
|
||||||
|
host:
|
||||||
|
- monadical-sas--reflector-diarizer-web.modal.run
|
||||||
|
user-agent:
|
||||||
|
- python-httpx/0.27.2
|
||||||
|
method: POST
|
||||||
|
uri: https://monadical-sas--reflector-diarizer-web.modal.run/diarize?audio_file_url=https%3A%2F%2Freflector-github-pytest.s3.us-east-1.amazonaws.com%2Ftest_mathieu_hello.mp3×tamp=0
|
||||||
|
response:
|
||||||
|
body:
|
||||||
|
string: '{"diarization":[{"start":0.823,"end":1.91,"speaker":0},{"start":2.572,"end":6.409,"speaker":0},{"start":6.783,"end":10.62,"speaker":0},{"start":11.231,"end":14.168,"speaker":0},{"start":14.796,"end":19.295,"speaker":0}]}'
|
||||||
|
headers:
|
||||||
|
Alt-Svc:
|
||||||
|
- h3=":443"; ma=2592000
|
||||||
|
Content-Length:
|
||||||
|
- '220'
|
||||||
|
Content-Type:
|
||||||
|
- application/json
|
||||||
|
Date:
|
||||||
|
- Wed, 13 Aug 2025 18:25:34 GMT
|
||||||
|
Modal-Function-Call-Id:
|
||||||
|
- fc-01K2JAVNEP6N7Y1Y7W3T98BCXK
|
||||||
|
Vary:
|
||||||
|
- accept-encoding
|
||||||
|
status:
|
||||||
|
code: 200
|
||||||
|
message: OK
|
||||||
|
version: 1
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
interactions:
|
||||||
|
- request:
|
||||||
|
body: '{"audio_file_url": "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3",
|
||||||
|
"language": "en", "batch": true}'
|
||||||
|
headers:
|
||||||
|
accept:
|
||||||
|
- '*/*'
|
||||||
|
accept-encoding:
|
||||||
|
- gzip, deflate
|
||||||
|
authorization:
|
||||||
|
- DUMMY_API_KEY
|
||||||
|
connection:
|
||||||
|
- keep-alive
|
||||||
|
content-length:
|
||||||
|
- '136'
|
||||||
|
content-type:
|
||||||
|
- application/json
|
||||||
|
host:
|
||||||
|
- monadical-sas--reflector-transcriber-parakeet-web.modal.run
|
||||||
|
user-agent:
|
||||||
|
- python-httpx/0.27.2
|
||||||
|
method: POST
|
||||||
|
uri: https://monadical-sas--reflector-transcriber-parakeet-web.modal.run/v1/audio/transcriptions-from-url
|
||||||
|
response:
|
||||||
|
body:
|
||||||
|
string: '{"text":"Hi there everyone. Today I want to share my incredible experience
|
||||||
|
with Reflector. a Q teenage product that revolutionizes audio processing.
|
||||||
|
With reflector, I can easily convert any audio into accurate transcription.
|
||||||
|
saving me hours of tedious manual work.","words":[{"word":"Hi","start":0.87,"end":1.19},{"word":"there","start":1.19,"end":1.35},{"word":"everyone.","start":1.51,"end":1.83},{"word":"Today","start":2.63,"end":2.87},{"word":"I","start":3.36,"end":3.52},{"word":"want","start":3.6,"end":3.76},{"word":"to","start":3.76,"end":3.92},{"word":"share","start":3.92,"end":4.16},{"word":"my","start":4.16,"end":4.4},{"word":"incredible","start":4.32,"end":4.96},{"word":"experience","start":4.96,"end":5.44},{"word":"with","start":5.44,"end":5.68},{"word":"Reflector.","start":5.68,"end":6.24},{"word":"a","start":6.93,"end":7.01},{"word":"Q","start":7.01,"end":7.17},{"word":"teenage","start":7.25,"end":7.65},{"word":"product","start":7.89,"end":8.29},{"word":"that","start":8.29,"end":8.61},{"word":"revolutionizes","start":8.61,"end":9.65},{"word":"audio","start":9.65,"end":10.05},{"word":"processing.","start":10.05,"end":10.53},{"word":"With","start":11.27,"end":11.43},{"word":"reflector,","start":11.51,"end":12.15},{"word":"I","start":12.31,"end":12.39},{"word":"can","start":12.39,"end":12.55},{"word":"easily","start":12.55,"end":12.95},{"word":"convert","start":12.95,"end":13.43},{"word":"any","start":13.43,"end":13.67},{"word":"audio","start":13.67,"end":13.99},{"word":"into","start":14.98,"end":15.06},{"word":"accurate","start":15.22,"end":15.54},{"word":"transcription.","start":15.7,"end":16.34},{"word":"saving","start":16.99,"end":17.15},{"word":"me","start":17.31,"end":17.47},{"word":"hours","start":17.47,"end":17.87},{"word":"of","start":17.87,"end":18.11},{"word":"tedious","start":18.11,"end":18.67},{"word":"manual","start":18.67,"end":19.07},{"word":"work.","start":19.07,"end":19.31}]}'
|
||||||
|
headers:
|
||||||
|
Alt-Svc:
|
||||||
|
- h3=":443"; ma=2592000
|
||||||
|
Content-Length:
|
||||||
|
- '1933'
|
||||||
|
Content-Type:
|
||||||
|
- application/json
|
||||||
|
Date:
|
||||||
|
- Wed, 13 Aug 2025 18:26:59 GMT
|
||||||
|
Modal-Function-Call-Id:
|
||||||
|
- fc-01K2JAWC7GAMKX4DSJ21WV31NG
|
||||||
|
Vary:
|
||||||
|
- accept-encoding
|
||||||
|
status:
|
||||||
|
code: 200
|
||||||
|
message: OK
|
||||||
|
version: 1
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
interactions:
|
||||||
|
- request:
|
||||||
|
body: '{"audio_file_url": "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3",
|
||||||
|
"language": "en", "batch": true}'
|
||||||
|
headers:
|
||||||
|
accept:
|
||||||
|
- '*/*'
|
||||||
|
accept-encoding:
|
||||||
|
- gzip, deflate
|
||||||
|
authorization:
|
||||||
|
- DUMMY_API_KEY
|
||||||
|
connection:
|
||||||
|
- keep-alive
|
||||||
|
content-length:
|
||||||
|
- '136'
|
||||||
|
content-type:
|
||||||
|
- application/json
|
||||||
|
host:
|
||||||
|
- monadical-sas--reflector-transcriber-parakeet-web.modal.run
|
||||||
|
user-agent:
|
||||||
|
- python-httpx/0.27.2
|
||||||
|
method: POST
|
||||||
|
uri: https://monadical-sas--reflector-transcriber-parakeet-web.modal.run/v1/audio/transcriptions-from-url
|
||||||
|
response:
|
||||||
|
body:
|
||||||
|
string: '{"text":"Hi there everyone. Today I want to share my incredible experience
|
||||||
|
with Reflector. a Q teenage product that revolutionizes audio processing.
|
||||||
|
With reflector, I can easily convert any audio into accurate transcription.
|
||||||
|
saving me hours of tedious manual work.","words":[{"word":"Hi","start":0.87,"end":1.19},{"word":"there","start":1.19,"end":1.35},{"word":"everyone.","start":1.51,"end":1.83},{"word":"Today","start":2.63,"end":2.87},{"word":"I","start":3.36,"end":3.52},{"word":"want","start":3.6,"end":3.76},{"word":"to","start":3.76,"end":3.92},{"word":"share","start":3.92,"end":4.16},{"word":"my","start":4.16,"end":4.4},{"word":"incredible","start":4.32,"end":4.96},{"word":"experience","start":4.96,"end":5.44},{"word":"with","start":5.44,"end":5.68},{"word":"Reflector.","start":5.68,"end":6.24},{"word":"a","start":6.93,"end":7.01},{"word":"Q","start":7.01,"end":7.17},{"word":"teenage","start":7.25,"end":7.65},{"word":"product","start":7.89,"end":8.29},{"word":"that","start":8.29,"end":8.61},{"word":"revolutionizes","start":8.61,"end":9.65},{"word":"audio","start":9.65,"end":10.05},{"word":"processing.","start":10.05,"end":10.53},{"word":"With","start":11.27,"end":11.43},{"word":"reflector,","start":11.51,"end":12.15},{"word":"I","start":12.31,"end":12.39},{"word":"can","start":12.39,"end":12.55},{"word":"easily","start":12.55,"end":12.95},{"word":"convert","start":12.95,"end":13.43},{"word":"any","start":13.43,"end":13.67},{"word":"audio","start":13.67,"end":13.99},{"word":"into","start":14.98,"end":15.06},{"word":"accurate","start":15.22,"end":15.54},{"word":"transcription.","start":15.7,"end":16.34},{"word":"saving","start":16.99,"end":17.15},{"word":"me","start":17.31,"end":17.47},{"word":"hours","start":17.47,"end":17.87},{"word":"of","start":17.87,"end":18.11},{"word":"tedious","start":18.11,"end":18.67},{"word":"manual","start":18.67,"end":19.07},{"word":"work.","start":19.07,"end":19.31}]}'
|
||||||
|
headers:
|
||||||
|
Alt-Svc:
|
||||||
|
- h3=":443"; ma=2592000
|
||||||
|
Content-Length:
|
||||||
|
- '1933'
|
||||||
|
Content-Type:
|
||||||
|
- application/json
|
||||||
|
Date:
|
||||||
|
- Wed, 13 Aug 2025 18:27:02 GMT
|
||||||
|
Modal-Function-Call-Id:
|
||||||
|
- fc-01K2JAYZ1AR2HE422VJVKBWX9Z
|
||||||
|
Vary:
|
||||||
|
- accept-encoding
|
||||||
|
status:
|
||||||
|
code: 200
|
||||||
|
message: OK
|
||||||
|
- request:
|
||||||
|
body: ''
|
||||||
|
headers:
|
||||||
|
accept:
|
||||||
|
- '*/*'
|
||||||
|
accept-encoding:
|
||||||
|
- gzip, deflate
|
||||||
|
authorization:
|
||||||
|
- DUMMY_API_KEY
|
||||||
|
connection:
|
||||||
|
- keep-alive
|
||||||
|
content-length:
|
||||||
|
- '0'
|
||||||
|
host:
|
||||||
|
- monadical-sas--reflector-diarizer-web.modal.run
|
||||||
|
user-agent:
|
||||||
|
- python-httpx/0.27.2
|
||||||
|
method: POST
|
||||||
|
uri: https://monadical-sas--reflector-diarizer-web.modal.run/diarize?audio_file_url=https%3A%2F%2Freflector-github-pytest.s3.us-east-1.amazonaws.com%2Ftest_mathieu_hello.mp3×tamp=0
|
||||||
|
response:
|
||||||
|
body:
|
||||||
|
string: '{"diarization":[{"start":0.823,"end":1.91,"speaker":0},{"start":2.572,"end":6.409,"speaker":0},{"start":6.783,"end":10.62,"speaker":0},{"start":11.231,"end":14.168,"speaker":0},{"start":14.796,"end":19.295,"speaker":0}]}'
|
||||||
|
headers:
|
||||||
|
Alt-Svc:
|
||||||
|
- h3=":443"; ma=2592000
|
||||||
|
Content-Length:
|
||||||
|
- '220'
|
||||||
|
Content-Type:
|
||||||
|
- application/json
|
||||||
|
Date:
|
||||||
|
- Wed, 13 Aug 2025 18:27:18 GMT
|
||||||
|
Modal-Function-Call-Id:
|
||||||
|
- fc-01K2JAZ1M34NQRJK03CCFK95D6
|
||||||
|
Vary:
|
||||||
|
- accept-encoding
|
||||||
|
status:
|
||||||
|
code: 200
|
||||||
|
message: OK
|
||||||
|
version: 1
|
||||||
@@ -5,7 +5,29 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
# Pytest-docker configuration
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def settings_configuration():
|
||||||
|
# theses settings are linked to monadical for pytest-recording
|
||||||
|
# if a fork is done, they have to provide their own url when cassettes needs to be updated
|
||||||
|
# modal api keys has to be defined by the user
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
settings.TRANSCRIPT_BACKEND = "modal"
|
||||||
|
settings.TRANSCRIPT_URL = (
|
||||||
|
"https://monadical-sas--reflector-transcriber-parakeet-web.modal.run"
|
||||||
|
)
|
||||||
|
settings.DIARIZATION_BACKEND = "modal"
|
||||||
|
settings.DIARIZATION_URL = "https://monadical-sas--reflector-diarizer-web.modal.run"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def vcr_config():
|
||||||
|
"""VCR configuration to filter sensitive headers"""
|
||||||
|
return {
|
||||||
|
"filter_headers": [("authorization", "DUMMY_API_KEY")],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def docker_compose_file(pytestconfig):
|
def docker_compose_file(pytestconfig):
|
||||||
return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")
|
return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
version: '3.8'
|
version: "3.8"
|
||||||
services:
|
services:
|
||||||
postgres_test:
|
postgres_test:
|
||||||
image: postgres:15
|
image: postgres:17
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_DB: reflector_test
|
POSTGRES_DB: reflector_test
|
||||||
POSTGRES_USER: test_user
|
POSTGRES_USER: test_user
|
||||||
|
|||||||
330
server/tests/test_gpu_modal_transcript.py
Normal file
330
server/tests/test_gpu_modal_transcript.py
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
"""
|
||||||
|
Tests for GPU Modal transcription endpoints.
|
||||||
|
|
||||||
|
These tests are marked with the "gpu-modal" group and will not run by default.
|
||||||
|
Run them with: pytest -m gpu-modal tests/test_gpu_modal_transcript_parakeet.py
|
||||||
|
|
||||||
|
Required environment variables:
|
||||||
|
- TRANSCRIPT_URL: URL to the Modal.com endpoint (required)
|
||||||
|
- TRANSCRIPT_MODAL_API_KEY: API key for authentication (optional)
|
||||||
|
- TRANSCRIPT_MODEL: Model name to use (optional, defaults to nvidia/parakeet-tdt-0.6b-v2)
|
||||||
|
|
||||||
|
Example with pytest (override default addopts to run ONLY gpu_modal tests):
|
||||||
|
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web-dev.modal.run \
|
||||||
|
TRANSCRIPT_MODAL_API_KEY=your-api-key \
|
||||||
|
uv run -m pytest -m gpu_modal --no-cov tests/test_gpu_modal_transcript.py
|
||||||
|
|
||||||
|
# Or with completely clean options:
|
||||||
|
uv run -m pytest -m gpu_modal -o addopts="" tests/
|
||||||
|
|
||||||
|
Running Modal locally for testing:
|
||||||
|
modal serve gpu/modal_deployments/reflector_transcriber_parakeet.py
|
||||||
|
# This will give you a local URL like https://xxxxx--reflector-transcriber-parakeet-web-dev.modal.run to test against
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Test audio file URL for testing
|
||||||
|
TEST_AUDIO_URL = (
|
||||||
|
"https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_modal_transcript_url():
|
||||||
|
"""Get and validate the Modal transcript URL from environment."""
|
||||||
|
url = os.environ.get("TRANSCRIPT_URL")
|
||||||
|
if not url:
|
||||||
|
pytest.skip(
|
||||||
|
"TRANSCRIPT_URL environment variable is required for GPU Modal tests"
|
||||||
|
)
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_headers():
|
||||||
|
"""Get authentication headers if API key is available."""
|
||||||
|
api_key = os.environ.get("TRANSCRIPT_MODAL_API_KEY")
|
||||||
|
if api_key:
|
||||||
|
return {"Authorization": f"Bearer {api_key}"}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_name():
|
||||||
|
"""Get the model name from environment or use default."""
|
||||||
|
return os.environ.get("TRANSCRIPT_MODEL", "nvidia/parakeet-tdt-0.6b-v2")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.gpu_modal
|
||||||
|
class TestGPUModalTranscript:
|
||||||
|
"""Test suite for GPU Modal transcription endpoints."""
|
||||||
|
|
||||||
|
def test_transcriptions_from_url(self):
|
||||||
|
"""Test the /v1/audio/transcriptions-from-url endpoint."""
|
||||||
|
url = get_modal_transcript_url()
|
||||||
|
headers = get_auth_headers()
|
||||||
|
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{url}/v1/audio/transcriptions-from-url",
|
||||||
|
json={
|
||||||
|
"audio_file_url": TEST_AUDIO_URL,
|
||||||
|
"model": get_model_name(),
|
||||||
|
"language": "en",
|
||||||
|
"timestamp_offset": 0.0,
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# Verify response structure
|
||||||
|
assert "text" in result
|
||||||
|
assert "words" in result
|
||||||
|
assert isinstance(result["text"], str)
|
||||||
|
assert isinstance(result["words"], list)
|
||||||
|
|
||||||
|
# Verify content is meaningful
|
||||||
|
assert len(result["text"]) > 0, "Transcript text should not be empty"
|
||||||
|
assert len(result["words"]) > 0, "Words list must not be empty"
|
||||||
|
|
||||||
|
# Verify word structure
|
||||||
|
for word in result["words"]:
|
||||||
|
assert "word" in word
|
||||||
|
assert "start" in word
|
||||||
|
assert "end" in word
|
||||||
|
assert isinstance(word["start"], (int, float))
|
||||||
|
assert isinstance(word["end"], (int, float))
|
||||||
|
assert word["start"] <= word["end"]
|
||||||
|
|
||||||
|
def test_transcriptions_single_file(self):
|
||||||
|
"""Test the /v1/audio/transcriptions endpoint with a single file."""
|
||||||
|
url = get_modal_transcript_url()
|
||||||
|
headers = get_auth_headers()
|
||||||
|
|
||||||
|
# Download test audio file to upload
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
audio_response = client.get(TEST_AUDIO_URL)
|
||||||
|
audio_response.raise_for_status()
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
|
||||||
|
tmp_file.write(audio_response.content)
|
||||||
|
tmp_file_path = tmp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Upload the file for transcription
|
||||||
|
with open(tmp_file_path, "rb") as f:
|
||||||
|
files = {"file": ("test_audio.mp3", f, "audio/mpeg")}
|
||||||
|
data = {
|
||||||
|
"model": get_model_name(),
|
||||||
|
"language": "en",
|
||||||
|
"batch": "false",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
f"{url}/v1/audio/transcriptions",
|
||||||
|
files=files,
|
||||||
|
data=data,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# Verify response structure for single file
|
||||||
|
assert "text" in result
|
||||||
|
assert "words" in result
|
||||||
|
assert "filename" in result
|
||||||
|
assert isinstance(result["text"], str)
|
||||||
|
assert isinstance(result["words"], list)
|
||||||
|
|
||||||
|
# Verify content
|
||||||
|
assert len(result["text"]) > 0, "Transcript text should not be empty"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
Path(tmp_file_path).unlink(missing_ok=True)
|
||||||
|
|
||||||
|
def test_transcriptions_multiple_files(self):
|
||||||
|
"""Test the /v1/audio/transcriptions endpoint with multiple files (non-batch mode)."""
|
||||||
|
url = get_modal_transcript_url()
|
||||||
|
headers = get_auth_headers()
|
||||||
|
|
||||||
|
# Create multiple test files (we'll use the same audio content for simplicity)
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
audio_response = client.get(TEST_AUDIO_URL)
|
||||||
|
audio_response.raise_for_status()
|
||||||
|
audio_content = audio_response.content
|
||||||
|
|
||||||
|
temp_files = []
|
||||||
|
try:
|
||||||
|
# Create 3 temporary files
|
||||||
|
for i in range(3):
|
||||||
|
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
|
||||||
|
tmp_file.write(audio_content)
|
||||||
|
tmp_file.close()
|
||||||
|
temp_files.append(tmp_file.name)
|
||||||
|
|
||||||
|
# Upload multiple files for transcription (non-batch)
|
||||||
|
files = [
|
||||||
|
("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
|
||||||
|
for i, f in enumerate(temp_files)
|
||||||
|
]
|
||||||
|
data = {
|
||||||
|
"model": get_model_name(),
|
||||||
|
"language": "en",
|
||||||
|
"batch": "false",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
f"{url}/v1/audio/transcriptions",
|
||||||
|
files=files,
|
||||||
|
data=data,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Close file handles
|
||||||
|
for _, file_tuple in files:
|
||||||
|
file_tuple[1].close()
|
||||||
|
|
||||||
|
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# Verify response structure for multiple files (non-batch)
|
||||||
|
assert "results" in result
|
||||||
|
assert isinstance(result["results"], list)
|
||||||
|
assert len(result["results"]) == 3
|
||||||
|
|
||||||
|
for idx, file_result in enumerate(result["results"]):
|
||||||
|
assert "text" in file_result
|
||||||
|
assert "words" in file_result
|
||||||
|
assert "filename" in file_result
|
||||||
|
assert isinstance(file_result["text"], str)
|
||||||
|
assert isinstance(file_result["words"], list)
|
||||||
|
assert len(file_result["text"]) > 0
|
||||||
|
|
||||||
|
finally:
|
||||||
|
for f in temp_files:
|
||||||
|
Path(f).unlink(missing_ok=True)
|
||||||
|
|
||||||
|
def test_transcriptions_multiple_files_batch(self):
|
||||||
|
"""Test the /v1/audio/transcriptions endpoint with multiple files in batch mode."""
|
||||||
|
url = get_modal_transcript_url()
|
||||||
|
headers = get_auth_headers()
|
||||||
|
|
||||||
|
# Create multiple test files
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
audio_response = client.get(TEST_AUDIO_URL)
|
||||||
|
audio_response.raise_for_status()
|
||||||
|
audio_content = audio_response.content
|
||||||
|
|
||||||
|
temp_files = []
|
||||||
|
try:
|
||||||
|
# Create 3 temporary files
|
||||||
|
for i in range(3):
|
||||||
|
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
|
||||||
|
tmp_file.write(audio_content)
|
||||||
|
tmp_file.close()
|
||||||
|
temp_files.append(tmp_file.name)
|
||||||
|
|
||||||
|
# Upload multiple files for batch transcription
|
||||||
|
files = [
|
||||||
|
("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
|
||||||
|
for i, f in enumerate(temp_files)
|
||||||
|
]
|
||||||
|
data = {
|
||||||
|
"model": get_model_name(),
|
||||||
|
"language": "en",
|
||||||
|
"batch": "true",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
f"{url}/v1/audio/transcriptions",
|
||||||
|
files=files,
|
||||||
|
data=data,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Close file handles
|
||||||
|
for _, file_tuple in files:
|
||||||
|
file_tuple[1].close()
|
||||||
|
|
||||||
|
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# Verify response structure for batch mode
|
||||||
|
assert "results" in result
|
||||||
|
assert isinstance(result["results"], list)
|
||||||
|
assert len(result["results"]) == 3
|
||||||
|
|
||||||
|
for idx, batch_result in enumerate(result["results"]):
|
||||||
|
assert "text" in batch_result
|
||||||
|
assert "words" in batch_result
|
||||||
|
assert "filename" in batch_result
|
||||||
|
assert isinstance(batch_result["text"], str)
|
||||||
|
assert isinstance(batch_result["words"], list)
|
||||||
|
assert len(batch_result["text"]) > 0
|
||||||
|
|
||||||
|
finally:
|
||||||
|
for f in temp_files:
|
||||||
|
Path(f).unlink(missing_ok=True)
|
||||||
|
|
||||||
|
def test_transcriptions_error_handling(self):
|
||||||
|
"""Test error handling for invalid requests."""
|
||||||
|
url = get_modal_transcript_url()
|
||||||
|
headers = get_auth_headers()
|
||||||
|
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
# Test with unsupported language
|
||||||
|
response = client.post(
|
||||||
|
f"{url}/v1/audio/transcriptions-from-url",
|
||||||
|
json={
|
||||||
|
"audio_file_url": TEST_AUDIO_URL,
|
||||||
|
"model": get_model_name(),
|
||||||
|
"language": "fr", # Parakeet only supports English
|
||||||
|
"timestamp_offset": 0.0,
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "only supports English" in response.text
|
||||||
|
|
||||||
|
def test_transcriptions_with_timestamp_offset(self):
|
||||||
|
"""Test transcription with timestamp offset parameter."""
|
||||||
|
url = get_modal_transcript_url()
|
||||||
|
headers = get_auth_headers()
|
||||||
|
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
# Test with timestamp offset
|
||||||
|
response = client.post(
|
||||||
|
f"{url}/v1/audio/transcriptions-from-url",
|
||||||
|
json={
|
||||||
|
"audio_file_url": TEST_AUDIO_URL,
|
||||||
|
"model": get_model_name(),
|
||||||
|
"language": "en",
|
||||||
|
"timestamp_offset": 10.0, # Add 10 second offset
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# Verify response structure
|
||||||
|
assert "text" in result
|
||||||
|
assert "words" in result
|
||||||
|
assert len(result["words"]) > 0, "Words list must not be empty"
|
||||||
|
|
||||||
|
# Verify that timestamps have been offset
|
||||||
|
for word in result["words"]:
|
||||||
|
# All timestamps should be >= 10.0 due to offset
|
||||||
|
assert (
|
||||||
|
word["start"] >= 10.0
|
||||||
|
), f"Word start time {word['start']} should be >= 10.0"
|
||||||
|
assert (
|
||||||
|
word["end"] >= 10.0
|
||||||
|
), f"Word end time {word['end']} should be >= 10.0"
|
||||||
633
server/tests/test_pipeline_main_file.py
Normal file
633
server/tests/test_pipeline_main_file.py
Normal file
@@ -0,0 +1,633 @@
|
|||||||
|
"""
|
||||||
|
Tests for PipelineMainFile - file-based processing pipeline
|
||||||
|
|
||||||
|
This test verifies the complete file processing pipeline without mocking much,
|
||||||
|
ensuring all processors are correctly invoked and the happy path works correctly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||||
|
from reflector.processors.file_diarization import FileDiarizationOutput
|
||||||
|
from reflector.processors.types import (
|
||||||
|
DiarizationSegment,
|
||||||
|
TitleSummary,
|
||||||
|
Word,
|
||||||
|
)
|
||||||
|
from reflector.processors.types import (
|
||||||
|
Transcript as TranscriptType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def dummy_file_transcript():
|
||||||
|
"""Mock FileTranscriptAutoProcessor for file processing"""
|
||||||
|
from reflector.processors.file_transcript import FileTranscriptProcessor
|
||||||
|
|
||||||
|
class TestFileTranscriptProcessor(FileTranscriptProcessor):
|
||||||
|
async def _transcript(self, data):
|
||||||
|
return TranscriptType(
|
||||||
|
text="Hello world. How are you today?",
|
||||||
|
words=[
|
||||||
|
Word(start=0.0, end=0.5, text="Hello", speaker=0),
|
||||||
|
Word(start=0.5, end=0.6, text=" ", speaker=0),
|
||||||
|
Word(start=0.6, end=1.0, text="world", speaker=0),
|
||||||
|
Word(start=1.0, end=1.1, text=".", speaker=0),
|
||||||
|
Word(start=1.1, end=1.2, text=" ", speaker=0),
|
||||||
|
Word(start=1.2, end=1.5, text="How", speaker=0),
|
||||||
|
Word(start=1.5, end=1.6, text=" ", speaker=0),
|
||||||
|
Word(start=1.6, end=1.8, text="are", speaker=0),
|
||||||
|
Word(start=1.8, end=1.9, text=" ", speaker=0),
|
||||||
|
Word(start=1.9, end=2.1, text="you", speaker=0),
|
||||||
|
Word(start=2.1, end=2.2, text=" ", speaker=0),
|
||||||
|
Word(start=2.2, end=2.5, text="today", speaker=0),
|
||||||
|
Word(start=2.5, end=2.6, text="?", speaker=0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__"
|
||||||
|
) as mock_auto:
|
||||||
|
mock_auto.return_value = TestFileTranscriptProcessor()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def dummy_file_diarization():
|
||||||
|
"""Mock FileDiarizationAutoProcessor for file processing"""
|
||||||
|
from reflector.processors.file_diarization import FileDiarizationProcessor
|
||||||
|
|
||||||
|
class TestFileDiarizationProcessor(FileDiarizationProcessor):
|
||||||
|
async def _diarize(self, data):
|
||||||
|
return FileDiarizationOutput(
|
||||||
|
diarization=[
|
||||||
|
DiarizationSegment(start=0.0, end=1.1, speaker=0),
|
||||||
|
DiarizationSegment(start=1.2, end=2.6, speaker=1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__"
|
||||||
|
) as mock_auto:
|
||||||
|
mock_auto.return_value = TestFileDiarizationProcessor()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_transcript_in_db(tmpdir):
|
||||||
|
"""Create a mock transcript in the database"""
|
||||||
|
from reflector.db.transcripts import Transcript
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
# Set the DATA_DIR to our tmpdir
|
||||||
|
original_data_dir = settings.DATA_DIR
|
||||||
|
settings.DATA_DIR = str(tmpdir)
|
||||||
|
|
||||||
|
transcript_id = str(uuid4())
|
||||||
|
data_path = Path(tmpdir) / transcript_id
|
||||||
|
data_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create mock transcript object
|
||||||
|
transcript = Transcript(
|
||||||
|
id=transcript_id,
|
||||||
|
name="Test Transcript",
|
||||||
|
status="processing",
|
||||||
|
source_kind="file",
|
||||||
|
source_language="en",
|
||||||
|
target_language="en",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the controller to return our transcript
|
||||||
|
try:
|
||||||
|
with patch(
|
||||||
|
"reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
|
||||||
|
) as mock_get:
|
||||||
|
mock_get.return_value = transcript
|
||||||
|
with patch(
|
||||||
|
"reflector.pipelines.main_live_pipeline.transcripts_controller.get_by_id"
|
||||||
|
) as mock_get2:
|
||||||
|
mock_get2.return_value = transcript
|
||||||
|
with patch(
|
||||||
|
"reflector.pipelines.main_live_pipeline.transcripts_controller.update"
|
||||||
|
) as mock_update:
|
||||||
|
mock_update.return_value = None
|
||||||
|
yield transcript
|
||||||
|
finally:
|
||||||
|
# Restore original DATA_DIR
|
||||||
|
settings.DATA_DIR = original_data_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_storage():
|
||||||
|
"""Mock storage for file uploads"""
|
||||||
|
from reflector.storage.base import Storage
|
||||||
|
|
||||||
|
class TestStorage(Storage):
|
||||||
|
async def _put_file(self, path, data):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_file_url(self, path):
|
||||||
|
return f"http://test-storage/{path}"
|
||||||
|
|
||||||
|
async def _get_file(self, path):
|
||||||
|
return b"test_audio_data"
|
||||||
|
|
||||||
|
async def _delete_file(self, path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
storage = TestStorage()
|
||||||
|
# Add mock tracking for verification
|
||||||
|
storage._put_file = AsyncMock(side_effect=storage._put_file)
|
||||||
|
storage._get_file_url = AsyncMock(side_effect=storage._get_file_url)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.pipelines.main_file_pipeline.get_transcripts_storage"
|
||||||
|
) as mock_get:
|
||||||
|
mock_get.return_value = storage
|
||||||
|
yield storage
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_audio_file_writer():
|
||||||
|
"""Mock AudioFileWriterProcessor to avoid actual file writing"""
|
||||||
|
with patch(
|
||||||
|
"reflector.pipelines.main_file_pipeline.AudioFileWriterProcessor"
|
||||||
|
) as mock_writer_class:
|
||||||
|
mock_writer = AsyncMock()
|
||||||
|
mock_writer.push = AsyncMock()
|
||||||
|
mock_writer.flush = AsyncMock()
|
||||||
|
mock_writer_class.return_value = mock_writer
|
||||||
|
yield mock_writer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_waveform_processor():
|
||||||
|
"""Mock AudioWaveformProcessor"""
|
||||||
|
with patch(
|
||||||
|
"reflector.pipelines.main_file_pipeline.AudioWaveformProcessor"
|
||||||
|
) as mock_waveform_class:
|
||||||
|
mock_waveform = AsyncMock()
|
||||||
|
mock_waveform.set_pipeline = MagicMock()
|
||||||
|
mock_waveform.flush = AsyncMock()
|
||||||
|
mock_waveform_class.return_value = mock_waveform
|
||||||
|
yield mock_waveform
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_topic_detector():
|
||||||
|
"""Mock TranscriptTopicDetectorProcessor"""
|
||||||
|
with patch(
|
||||||
|
"reflector.pipelines.main_file_pipeline.TranscriptTopicDetectorProcessor"
|
||||||
|
) as mock_topic_class:
|
||||||
|
mock_topic = AsyncMock()
|
||||||
|
mock_topic.set_pipeline = MagicMock()
|
||||||
|
mock_topic.push = AsyncMock()
|
||||||
|
mock_topic.flush_called = False
|
||||||
|
|
||||||
|
# When flush is called, simulate topic detection by calling the callback
|
||||||
|
async def flush_with_callback():
|
||||||
|
mock_topic.flush_called = True
|
||||||
|
if hasattr(mock_topic, "_callback"):
|
||||||
|
# Create a minimal transcript for the TitleSummary
|
||||||
|
test_transcript = TranscriptType(words=[], text="test transcript")
|
||||||
|
await mock_topic._callback(
|
||||||
|
TitleSummary(
|
||||||
|
title="Test Topic",
|
||||||
|
summary="Test topic summary",
|
||||||
|
timestamp=0.0,
|
||||||
|
duration=10.0,
|
||||||
|
transcript=test_transcript,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_topic.flush = flush_with_callback
|
||||||
|
|
||||||
|
def init_with_callback(callback=None):
|
||||||
|
mock_topic._callback = callback
|
||||||
|
return mock_topic
|
||||||
|
|
||||||
|
mock_topic_class.side_effect = init_with_callback
|
||||||
|
yield mock_topic
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_title_processor():
|
||||||
|
"""Mock TranscriptFinalTitleProcessor"""
|
||||||
|
with patch(
|
||||||
|
"reflector.pipelines.main_file_pipeline.TranscriptFinalTitleProcessor"
|
||||||
|
) as mock_title_class:
|
||||||
|
mock_title = AsyncMock()
|
||||||
|
mock_title.set_pipeline = MagicMock()
|
||||||
|
mock_title.push = AsyncMock()
|
||||||
|
mock_title.flush_called = False
|
||||||
|
|
||||||
|
# When flush is called, simulate title generation by calling the callback
|
||||||
|
async def flush_with_callback():
|
||||||
|
mock_title.flush_called = True
|
||||||
|
if hasattr(mock_title, "_callback"):
|
||||||
|
from reflector.processors.types import FinalTitle
|
||||||
|
|
||||||
|
await mock_title._callback(FinalTitle(title="Test Title"))
|
||||||
|
|
||||||
|
mock_title.flush = flush_with_callback
|
||||||
|
|
||||||
|
def init_with_callback(callback=None):
|
||||||
|
mock_title._callback = callback
|
||||||
|
return mock_title
|
||||||
|
|
||||||
|
mock_title_class.side_effect = init_with_callback
|
||||||
|
yield mock_title
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_summary_processor():
|
||||||
|
"""Mock TranscriptFinalSummaryProcessor"""
|
||||||
|
with patch(
|
||||||
|
"reflector.pipelines.main_file_pipeline.TranscriptFinalSummaryProcessor"
|
||||||
|
) as mock_summary_class:
|
||||||
|
mock_summary = AsyncMock()
|
||||||
|
mock_summary.set_pipeline = MagicMock()
|
||||||
|
mock_summary.push = AsyncMock()
|
||||||
|
mock_summary.flush_called = False
|
||||||
|
|
||||||
|
# When flush is called, simulate summary generation by calling the callbacks
|
||||||
|
async def flush_with_callback():
|
||||||
|
mock_summary.flush_called = True
|
||||||
|
from reflector.processors.types import FinalLongSummary, FinalShortSummary
|
||||||
|
|
||||||
|
if hasattr(mock_summary, "_callback"):
|
||||||
|
await mock_summary._callback(
|
||||||
|
FinalLongSummary(long_summary="Test long summary", duration=10.0)
|
||||||
|
)
|
||||||
|
if hasattr(mock_summary, "_on_short_summary"):
|
||||||
|
await mock_summary._on_short_summary(
|
||||||
|
FinalShortSummary(short_summary="Test short summary", duration=10.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_summary.flush = flush_with_callback
|
||||||
|
|
||||||
|
def init_with_callback(transcript=None, callback=None, on_short_summary=None):
|
||||||
|
mock_summary._callback = callback
|
||||||
|
mock_summary._on_short_summary = on_short_summary
|
||||||
|
return mock_summary
|
||||||
|
|
||||||
|
mock_summary_class.side_effect = init_with_callback
|
||||||
|
yield mock_summary
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_main_file_process(
|
||||||
|
tmpdir,
|
||||||
|
mock_transcript_in_db,
|
||||||
|
dummy_file_transcript,
|
||||||
|
dummy_file_diarization,
|
||||||
|
mock_storage,
|
||||||
|
mock_audio_file_writer,
|
||||||
|
mock_waveform_processor,
|
||||||
|
mock_topic_detector,
|
||||||
|
mock_title_processor,
|
||||||
|
mock_summary_processor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test the complete PipelineMainFile processing pipeline.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
1. Audio extraction and writing
|
||||||
|
2. Audio upload to storage
|
||||||
|
3. Parallel processing of transcription, diarization, and waveform
|
||||||
|
4. Assembly of transcript with diarization
|
||||||
|
5. Topic detection
|
||||||
|
6. Title and summary generation
|
||||||
|
"""
|
||||||
|
# Create a test audio file
|
||||||
|
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||||
|
|
||||||
|
# Copy test audio to the transcript's data path as if it was uploaded
|
||||||
|
upload_path = mock_transcript_in_db.data_path / "upload.wav"
|
||||||
|
upload_path.write_bytes(test_audio_path.read_bytes())
|
||||||
|
|
||||||
|
# Also create the audio.mp3 file that would be created by AudioFileWriterProcessor
|
||||||
|
# Since we're mocking AudioFileWriterProcessor, we need to create this manually
|
||||||
|
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
|
||||||
|
mp3_path.write_bytes(b"mock_mp3_data")
|
||||||
|
|
||||||
|
# Track callback invocations
|
||||||
|
callback_marks = {
|
||||||
|
"on_status": [],
|
||||||
|
"on_duration": [],
|
||||||
|
"on_waveform": [],
|
||||||
|
"on_topic": [],
|
||||||
|
"on_title": [],
|
||||||
|
"on_long_summary": [],
|
||||||
|
"on_short_summary": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create pipeline with mocked callbacks
|
||||||
|
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
|
||||||
|
|
||||||
|
# Override callbacks to track invocations
|
||||||
|
async def track_callback(name, data):
|
||||||
|
callback_marks[name].append(data)
|
||||||
|
# Call the original callback
|
||||||
|
original = getattr(PipelineMainFile, name)
|
||||||
|
return await original(pipeline, data)
|
||||||
|
|
||||||
|
for callback_name in callback_marks.keys():
|
||||||
|
setattr(
|
||||||
|
pipeline,
|
||||||
|
callback_name,
|
||||||
|
lambda data, n=callback_name: track_callback(n, data),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock av.open for audio processing
|
||||||
|
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
|
||||||
|
# Mock container for checking video streams
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_container.streams.video = [] # No video streams (audio only)
|
||||||
|
mock_container.close = MagicMock()
|
||||||
|
|
||||||
|
# Mock container for decoding audio frames
|
||||||
|
mock_decode_container = MagicMock()
|
||||||
|
mock_decode_container.decode.return_value = iter(
|
||||||
|
[MagicMock()]
|
||||||
|
) # One mock audio frame
|
||||||
|
mock_decode_container.close = MagicMock()
|
||||||
|
|
||||||
|
# Return different containers for different calls
|
||||||
|
mock_av.side_effect = [mock_container, mock_decode_container]
|
||||||
|
|
||||||
|
# Run the pipeline
|
||||||
|
await pipeline.process(upload_path)
|
||||||
|
|
||||||
|
# Verify audio extraction and writing
|
||||||
|
assert mock_audio_file_writer.push.called
|
||||||
|
assert mock_audio_file_writer.flush.called
|
||||||
|
|
||||||
|
# Verify storage upload
|
||||||
|
assert mock_storage._put_file.called
|
||||||
|
assert mock_storage._get_file_url.called
|
||||||
|
|
||||||
|
# Verify waveform generation
|
||||||
|
assert mock_waveform_processor.flush.called
|
||||||
|
assert mock_waveform_processor.set_pipeline.called
|
||||||
|
|
||||||
|
# Verify topic detection
|
||||||
|
assert mock_topic_detector.push.called
|
||||||
|
assert mock_topic_detector.flush_called
|
||||||
|
|
||||||
|
# Verify title generation
|
||||||
|
assert mock_title_processor.push.called
|
||||||
|
assert mock_title_processor.flush_called
|
||||||
|
|
||||||
|
# Verify summary generation
|
||||||
|
assert mock_summary_processor.push.called
|
||||||
|
assert mock_summary_processor.flush_called
|
||||||
|
|
||||||
|
# Verify callbacks were invoked
|
||||||
|
assert len(callback_marks["on_topic"]) > 0, "Topic callback should be invoked"
|
||||||
|
assert len(callback_marks["on_title"]) > 0, "Title callback should be invoked"
|
||||||
|
assert (
|
||||||
|
len(callback_marks["on_long_summary"]) > 0
|
||||||
|
), "Long summary callback should be invoked"
|
||||||
|
assert (
|
||||||
|
len(callback_marks["on_short_summary"]) > 0
|
||||||
|
), "Short summary callback should be invoked"
|
||||||
|
|
||||||
|
print(f"Callback marks: {callback_marks}")
|
||||||
|
|
||||||
|
# Verify the pipeline completed successfully
|
||||||
|
assert pipeline.logger is not None
|
||||||
|
print("PipelineMainFile test completed successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_main_file_with_video(
|
||||||
|
tmpdir,
|
||||||
|
mock_transcript_in_db,
|
||||||
|
dummy_file_transcript,
|
||||||
|
dummy_file_diarization,
|
||||||
|
mock_storage,
|
||||||
|
mock_audio_file_writer,
|
||||||
|
mock_waveform_processor,
|
||||||
|
mock_topic_detector,
|
||||||
|
mock_title_processor,
|
||||||
|
mock_summary_processor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test PipelineMainFile with video input (verifies audio extraction).
|
||||||
|
"""
|
||||||
|
# Create a test audio file
|
||||||
|
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||||
|
|
||||||
|
# Copy test audio to the transcript's data path as if it was a video upload
|
||||||
|
upload_path = mock_transcript_in_db.data_path / "upload.mp4"
|
||||||
|
upload_path.write_bytes(test_audio_path.read_bytes())
|
||||||
|
|
||||||
|
# Also create the audio.mp3 file that would be created by AudioFileWriterProcessor
|
||||||
|
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
|
||||||
|
mp3_path.write_bytes(b"mock_mp3_data")
|
||||||
|
|
||||||
|
# Create pipeline
|
||||||
|
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
|
||||||
|
|
||||||
|
# Mock av.open for video processing
|
||||||
|
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
|
||||||
|
# Mock container for checking video streams
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_container.streams.video = [MagicMock()] # Has video streams
|
||||||
|
mock_container.close = MagicMock()
|
||||||
|
|
||||||
|
# Mock container for decoding audio frames
|
||||||
|
mock_decode_container = MagicMock()
|
||||||
|
mock_decode_container.decode.return_value = iter(
|
||||||
|
[MagicMock()]
|
||||||
|
) # One mock audio frame
|
||||||
|
mock_decode_container.close = MagicMock()
|
||||||
|
|
||||||
|
# Return different containers for different calls
|
||||||
|
mock_av.side_effect = [mock_container, mock_decode_container]
|
||||||
|
|
||||||
|
# Run the pipeline
|
||||||
|
await pipeline.process(upload_path)
|
||||||
|
|
||||||
|
# Verify audio extraction from video
|
||||||
|
assert mock_audio_file_writer.push.called
|
||||||
|
assert mock_audio_file_writer.flush.called
|
||||||
|
|
||||||
|
# Verify the rest of the pipeline completed
|
||||||
|
assert mock_storage._put_file.called
|
||||||
|
assert mock_waveform_processor.flush.called
|
||||||
|
assert mock_topic_detector.push.called
|
||||||
|
assert mock_title_processor.push.called
|
||||||
|
assert mock_summary_processor.push.called
|
||||||
|
|
||||||
|
print("PipelineMainFile video test completed successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_main_file_no_diarization(
|
||||||
|
tmpdir,
|
||||||
|
mock_transcript_in_db,
|
||||||
|
dummy_file_transcript,
|
||||||
|
mock_storage,
|
||||||
|
mock_audio_file_writer,
|
||||||
|
mock_waveform_processor,
|
||||||
|
mock_topic_detector,
|
||||||
|
mock_title_processor,
|
||||||
|
mock_summary_processor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test PipelineMainFile with diarization disabled.
|
||||||
|
"""
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
# Disable diarization
|
||||||
|
with patch.object(settings, "DIARIZATION_BACKEND", None):
|
||||||
|
# Create a test audio file
|
||||||
|
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||||
|
|
||||||
|
# Copy test audio to the transcript's data path
|
||||||
|
upload_path = mock_transcript_in_db.data_path / "upload.wav"
|
||||||
|
upload_path.write_bytes(test_audio_path.read_bytes())
|
||||||
|
|
||||||
|
# Also create the audio.mp3 file
|
||||||
|
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
|
||||||
|
mp3_path.write_bytes(b"mock_mp3_data")
|
||||||
|
|
||||||
|
# Create pipeline
|
||||||
|
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
|
||||||
|
|
||||||
|
# Mock av.open for audio processing
|
||||||
|
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
|
||||||
|
# Mock container for checking video streams
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_container.streams.video = [] # No video streams
|
||||||
|
mock_container.close = MagicMock()
|
||||||
|
|
||||||
|
# Mock container for decoding audio frames
|
||||||
|
mock_decode_container = MagicMock()
|
||||||
|
mock_decode_container.decode.return_value = iter([MagicMock()])
|
||||||
|
mock_decode_container.close = MagicMock()
|
||||||
|
|
||||||
|
# Return different containers for different calls
|
||||||
|
mock_av.side_effect = [mock_container, mock_decode_container]
|
||||||
|
|
||||||
|
# Run the pipeline
|
||||||
|
await pipeline.process(upload_path)
|
||||||
|
|
||||||
|
# Verify the pipeline completed without diarization
|
||||||
|
assert mock_storage._put_file.called
|
||||||
|
assert mock_waveform_processor.flush.called
|
||||||
|
assert mock_topic_detector.push.called
|
||||||
|
assert mock_title_processor.push.called
|
||||||
|
assert mock_summary_processor.push.called
|
||||||
|
|
||||||
|
print("PipelineMainFile no-diarization test completed successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_task_pipeline_file_process(
|
||||||
|
tmpdir,
|
||||||
|
mock_transcript_in_db,
|
||||||
|
dummy_file_transcript,
|
||||||
|
dummy_file_diarization,
|
||||||
|
mock_storage,
|
||||||
|
mock_audio_file_writer,
|
||||||
|
mock_waveform_processor,
|
||||||
|
mock_topic_detector,
|
||||||
|
mock_title_processor,
|
||||||
|
mock_summary_processor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test the Celery task entry point for file pipeline processing.
|
||||||
|
"""
|
||||||
|
# Direct import of the underlying async function, bypassing the asynctask decorator
|
||||||
|
|
||||||
|
# Create a test audio file in the transcript's data path
|
||||||
|
test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||||
|
upload_path = mock_transcript_in_db.data_path / "upload.wav"
|
||||||
|
upload_path.write_bytes(test_audio_path.read_bytes())
|
||||||
|
|
||||||
|
# Also create the audio.mp3 file
|
||||||
|
mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
|
||||||
|
mp3_path.write_bytes(b"mock_mp3_data")
|
||||||
|
|
||||||
|
# Mock av.open for audio processing
|
||||||
|
with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
|
||||||
|
# Mock container for checking video streams
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_container.streams.video = [] # No video streams
|
||||||
|
mock_container.close = MagicMock()
|
||||||
|
|
||||||
|
# Mock container for decoding audio frames
|
||||||
|
mock_decode_container = MagicMock()
|
||||||
|
mock_decode_container.decode.return_value = iter([MagicMock()])
|
||||||
|
mock_decode_container.close = MagicMock()
|
||||||
|
|
||||||
|
# Return different containers for different calls
|
||||||
|
mock_av.side_effect = [mock_container, mock_decode_container]
|
||||||
|
|
||||||
|
# Get the original async function without the asynctask decorator
|
||||||
|
# The function is wrapped, so we need to call it differently
|
||||||
|
# For now, we test the pipeline directly since the task is just a thin wrapper
|
||||||
|
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||||
|
|
||||||
|
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
|
||||||
|
await pipeline.process(upload_path)
|
||||||
|
|
||||||
|
# Verify the pipeline was executed through the task
|
||||||
|
assert mock_audio_file_writer.push.called
|
||||||
|
assert mock_audio_file_writer.flush.called
|
||||||
|
assert mock_storage._put_file.called
|
||||||
|
assert mock_waveform_processor.flush.called
|
||||||
|
assert mock_topic_detector.push.called
|
||||||
|
assert mock_title_processor.push.called
|
||||||
|
assert mock_summary_processor.push.called
|
||||||
|
|
||||||
|
print("task_pipeline_file_process test completed successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_file_process_no_transcript():
|
||||||
|
"""
|
||||||
|
Test the pipeline with a non-existent transcript.
|
||||||
|
"""
|
||||||
|
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||||
|
|
||||||
|
# Mock the controller to return None (transcript not found)
|
||||||
|
with patch(
|
||||||
|
"reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
|
||||||
|
) as mock_get:
|
||||||
|
mock_get.return_value = None
|
||||||
|
|
||||||
|
pipeline = PipelineMainFile(transcript_id=str(uuid4()))
|
||||||
|
|
||||||
|
# Should raise an exception for missing transcript when get_transcript is called
|
||||||
|
with pytest.raises(Exception, match="Transcript not found"):
|
||||||
|
await pipeline.get_transcript()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_file_process_no_audio_file(
|
||||||
|
mock_transcript_in_db,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test the pipeline when no audio file is found.
|
||||||
|
"""
|
||||||
|
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||||
|
|
||||||
|
# Don't create any audio files in the data path
|
||||||
|
# The pipeline's process should handle missing files gracefully
|
||||||
|
|
||||||
|
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
|
||||||
|
|
||||||
|
# Try to process a non-existent file
|
||||||
|
non_existent_path = mock_transcript_in_db.data_path / "nonexistent.wav"
|
||||||
|
|
||||||
|
# This should fail when trying to open the file with av
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
await pipeline.process(non_existent_path)
|
||||||
265
server/tests/test_processors_modal.py
Normal file
265
server/tests/test_processors_modal.py
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
"""
|
||||||
|
Tests for Modal-based processors using pytest-recording for HTTP recording/playbook
|
||||||
|
|
||||||
|
Note: theses tests require full modal configuration to be able to record
|
||||||
|
vcr cassettes
|
||||||
|
|
||||||
|
Configuration required for the first recording:
|
||||||
|
- TRANSCRIPT_BACKEND=modal
|
||||||
|
- TRANSCRIPT_URL=https://xxxxx--reflector-transcriber-parakeet-web.modal.run
|
||||||
|
- TRANSCRIPT_MODAL_API_KEY=xxxxx
|
||||||
|
- DIARIZATION_BACKEND=modal
|
||||||
|
- DIARIZATION_URL=https://xxxxx--reflector-diarizer-web.modal.run
|
||||||
|
- DIARIZATION_MODAL_API_KEY=xxxxx
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reflector.processors.file_diarization import FileDiarizationInput
|
||||||
|
from reflector.processors.file_diarization_modal import FileDiarizationModalProcessor
|
||||||
|
from reflector.processors.file_transcript import FileTranscriptInput
|
||||||
|
from reflector.processors.file_transcript_modal import FileTranscriptModalProcessor
|
||||||
|
from reflector.processors.transcript_diarization_assembler import (
|
||||||
|
TranscriptDiarizationAssemblerInput,
|
||||||
|
TranscriptDiarizationAssemblerProcessor,
|
||||||
|
)
|
||||||
|
from reflector.processors.types import DiarizationSegment, Transcript, Word
|
||||||
|
|
||||||
|
# Public test audio file hosted on S3 specifically for reflector pytests
|
||||||
|
TEST_AUDIO_URL = (
|
||||||
|
"https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_transcript_modal_processor_missing_url():
|
||||||
|
with patch("reflector.processors.file_transcript_modal.settings") as mock_settings:
|
||||||
|
mock_settings.TRANSCRIPT_URL = None
|
||||||
|
with pytest.raises(Exception, match="TRANSCRIPT_URL required"):
|
||||||
|
FileTranscriptModalProcessor(modal_api_key="test-api-key")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_diarization_modal_processor_missing_url():
|
||||||
|
with patch("reflector.processors.file_diarization_modal.settings") as mock_settings:
|
||||||
|
mock_settings.DIARIZATION_URL = None
|
||||||
|
with pytest.raises(Exception, match="DIARIZATION_URL required"):
|
||||||
|
FileDiarizationModalProcessor(modal_api_key="test-api-key")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.vcr()
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_diarization_modal_processor(vcr):
|
||||||
|
"""Test FileDiarizationModalProcessor using public audio URL and Modal API"""
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
processor = FileDiarizationModalProcessor(
|
||||||
|
modal_api_key=settings.DIARIZATION_MODAL_API_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
test_input = FileDiarizationInput(audio_url=TEST_AUDIO_URL)
|
||||||
|
result = await processor._diarize(test_input)
|
||||||
|
|
||||||
|
# Verify the result structure
|
||||||
|
assert result is not None
|
||||||
|
assert hasattr(result, "diarization")
|
||||||
|
assert isinstance(result.diarization, list)
|
||||||
|
|
||||||
|
# Check structure of each diarization segment
|
||||||
|
for segment in result.diarization:
|
||||||
|
assert "start" in segment
|
||||||
|
assert "end" in segment
|
||||||
|
assert "speaker" in segment
|
||||||
|
assert isinstance(segment["start"], (int, float))
|
||||||
|
assert isinstance(segment["end"], (int, float))
|
||||||
|
assert isinstance(segment["speaker"], int)
|
||||||
|
# Basic sanity check - start should be before end
|
||||||
|
assert segment["start"] < segment["end"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.vcr()
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_transcript_modal_processor():
|
||||||
|
"""Test FileTranscriptModalProcessor using public audio URL and Modal API"""
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
processor = FileTranscriptModalProcessor(
|
||||||
|
modal_api_key=settings.TRANSCRIPT_MODAL_API_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
test_input = FileTranscriptInput(
|
||||||
|
audio_url=TEST_AUDIO_URL,
|
||||||
|
language="en",
|
||||||
|
)
|
||||||
|
|
||||||
|
# This will record the HTTP interaction on first run, replay on subsequent runs
|
||||||
|
result = await processor._transcript(test_input)
|
||||||
|
|
||||||
|
# Verify the result structure
|
||||||
|
assert result is not None
|
||||||
|
assert hasattr(result, "words")
|
||||||
|
assert isinstance(result.words, list)
|
||||||
|
|
||||||
|
# Check structure of each word if present
|
||||||
|
for word in result.words:
|
||||||
|
assert hasattr(word, "text")
|
||||||
|
assert hasattr(word, "start")
|
||||||
|
assert hasattr(word, "end")
|
||||||
|
assert isinstance(word.start, (int, float))
|
||||||
|
assert isinstance(word.end, (int, float))
|
||||||
|
assert isinstance(word.text, str)
|
||||||
|
# Basic sanity check - start should be before or equal to end
|
||||||
|
assert word.start <= word.end
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcript_diarization_assembler_processor():
|
||||||
|
"""Test TranscriptDiarizationAssemblerProcessor without VCR (no HTTP requests)"""
|
||||||
|
# Create test transcript with words
|
||||||
|
words = [
|
||||||
|
Word(text="Hello", start=0.0, end=1.0, speaker=0),
|
||||||
|
Word(text=" ", start=1.0, end=1.1, speaker=0),
|
||||||
|
Word(text="world", start=1.1, end=2.0, speaker=0),
|
||||||
|
Word(text=".", start=2.0, end=2.1, speaker=0),
|
||||||
|
Word(text=" ", start=2.1, end=2.2, speaker=0),
|
||||||
|
Word(text="How", start=2.2, end=2.8, speaker=0),
|
||||||
|
Word(text=" ", start=2.8, end=2.9, speaker=0),
|
||||||
|
Word(text="are", start=2.9, end=3.2, speaker=0),
|
||||||
|
Word(text=" ", start=3.2, end=3.3, speaker=0),
|
||||||
|
Word(text="you", start=3.3, end=3.8, speaker=0),
|
||||||
|
Word(text="?", start=3.8, end=3.9, speaker=0),
|
||||||
|
]
|
||||||
|
transcript = Transcript(words=words)
|
||||||
|
|
||||||
|
# Create test diarization segments
|
||||||
|
diarization = [
|
||||||
|
DiarizationSegment(start=0.0, end=2.1, speaker=0),
|
||||||
|
DiarizationSegment(start=2.1, end=3.9, speaker=1),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create processor and test input
|
||||||
|
processor = TranscriptDiarizationAssemblerProcessor()
|
||||||
|
test_input = TranscriptDiarizationAssemblerInput(
|
||||||
|
transcript=transcript, diarization=diarization
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track emitted results
|
||||||
|
emitted_results = []
|
||||||
|
|
||||||
|
async def capture_result(result):
|
||||||
|
emitted_results.append(result)
|
||||||
|
|
||||||
|
processor.on(capture_result)
|
||||||
|
|
||||||
|
# Process the input
|
||||||
|
await processor.push(test_input)
|
||||||
|
|
||||||
|
# Verify result was emitted
|
||||||
|
assert len(emitted_results) == 1
|
||||||
|
result = emitted_results[0]
|
||||||
|
|
||||||
|
# Verify result structure
|
||||||
|
assert isinstance(result, Transcript)
|
||||||
|
assert len(result.words) == len(words)
|
||||||
|
|
||||||
|
# Verify speaker assignments were applied
|
||||||
|
# Words 0-3 (indices) should be speaker 0 (time 0.0-2.0)
|
||||||
|
# Words 4-10 (indices) should be speaker 1 (time 2.1-3.9)
|
||||||
|
for i in range(4): # First 4 words (Hello, space, world, .)
|
||||||
|
assert (
|
||||||
|
result.words[i].speaker == 0
|
||||||
|
), f"Word {i} '{result.words[i].text}' should be speaker 0, got {result.words[i].speaker}"
|
||||||
|
|
||||||
|
for i in range(4, 11): # Remaining words (space, How, space, are, space, you, ?)
|
||||||
|
assert (
|
||||||
|
result.words[i].speaker == 1
|
||||||
|
), f"Word {i} '{result.words[i].text}' should be speaker 1, got {result.words[i].speaker}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcript_diarization_assembler_no_diarization():
|
||||||
|
"""Test TranscriptDiarizationAssemblerProcessor with no diarization data"""
|
||||||
|
# Create test transcript
|
||||||
|
words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
|
||||||
|
transcript = Transcript(words=words)
|
||||||
|
|
||||||
|
# Create processor and test input with empty diarization
|
||||||
|
processor = TranscriptDiarizationAssemblerProcessor()
|
||||||
|
test_input = TranscriptDiarizationAssemblerInput(
|
||||||
|
transcript=transcript, diarization=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track emitted results
|
||||||
|
emitted_results = []
|
||||||
|
|
||||||
|
async def capture_result(result):
|
||||||
|
emitted_results.append(result)
|
||||||
|
|
||||||
|
processor.on(capture_result)
|
||||||
|
|
||||||
|
# Process the input
|
||||||
|
await processor.push(test_input)
|
||||||
|
|
||||||
|
# Verify original transcript was returned unchanged
|
||||||
|
assert len(emitted_results) == 1
|
||||||
|
result = emitted_results[0]
|
||||||
|
assert result is transcript # Should be the same object
|
||||||
|
assert result.words[0].speaker == 0 # Original speaker unchanged
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.vcr()
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_modal_pipeline_integration(vcr):
|
||||||
|
"""Integration test: Transcription -> Diarization -> Assembly
|
||||||
|
|
||||||
|
This test demonstrates the full pipeline:
|
||||||
|
1. Run transcription via Modal
|
||||||
|
2. Run diarization via Modal
|
||||||
|
3. Assemble transcript with diarization
|
||||||
|
"""
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
# Step 1: Transcription
|
||||||
|
transcript_processor = FileTranscriptModalProcessor(
|
||||||
|
modal_api_key=settings.TRANSCRIPT_MODAL_API_KEY
|
||||||
|
)
|
||||||
|
transcript_input = FileTranscriptInput(audio_url=TEST_AUDIO_URL, language="en")
|
||||||
|
transcript = await transcript_processor._transcript(transcript_input)
|
||||||
|
|
||||||
|
# Step 2: Diarization
|
||||||
|
diarization_processor = FileDiarizationModalProcessor(
|
||||||
|
modal_api_key=settings.DIARIZATION_MODAL_API_KEY
|
||||||
|
)
|
||||||
|
diarization_input = FileDiarizationInput(audio_url=TEST_AUDIO_URL)
|
||||||
|
diarization_result = await diarization_processor._diarize(diarization_input)
|
||||||
|
|
||||||
|
# Step 3: Assembly
|
||||||
|
assembler = TranscriptDiarizationAssemblerProcessor()
|
||||||
|
assembly_input = TranscriptDiarizationAssemblerInput(
|
||||||
|
transcript=transcript, diarization=diarization_result.diarization
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track assembled result
|
||||||
|
assembled_results = []
|
||||||
|
|
||||||
|
async def capture_result(result):
|
||||||
|
assembled_results.append(result)
|
||||||
|
|
||||||
|
assembler.on(capture_result)
|
||||||
|
|
||||||
|
await assembler.push(assembly_input)
|
||||||
|
|
||||||
|
# Verify the full pipeline worked
|
||||||
|
assert len(assembled_results) == 1
|
||||||
|
final_transcript = assembled_results[0]
|
||||||
|
|
||||||
|
# Verify the final transcript has the original words with updated speaker info
|
||||||
|
assert isinstance(final_transcript, Transcript)
|
||||||
|
assert len(final_transcript.words) == len(transcript.words)
|
||||||
|
assert len(final_transcript.words) > 0
|
||||||
|
|
||||||
|
# Verify some words have been assigned speakers from diarization
|
||||||
|
speakers_found = set(word.speaker for word in final_transcript.words)
|
||||||
|
assert len(speakers_found) > 0 # At least some speaker assignments
|
||||||
@@ -2,10 +2,13 @@ import pytest
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("enable_diarization", [False, True])
|
||||||
async def test_basic_process(
|
async def test_basic_process(
|
||||||
dummy_transcript,
|
dummy_transcript,
|
||||||
dummy_llm,
|
dummy_llm,
|
||||||
dummy_processors,
|
dummy_processors,
|
||||||
|
enable_diarization,
|
||||||
|
dummy_diarization,
|
||||||
):
|
):
|
||||||
# goal is to start the server, and send rtc audio to it
|
# goal is to start the server, and send rtc audio to it
|
||||||
# validate the events received
|
# validate the events received
|
||||||
@@ -28,12 +31,31 @@ async def test_basic_process(
|
|||||||
|
|
||||||
# invoke the process and capture events
|
# invoke the process and capture events
|
||||||
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||||
await process_audio_file(path.as_posix(), event_callback)
|
|
||||||
print(marks)
|
if enable_diarization:
|
||||||
|
# Test with diarization - may fail if pyannote.audio is not installed
|
||||||
|
try:
|
||||||
|
await process_audio_file(
|
||||||
|
path.as_posix(), event_callback, enable_diarization=True
|
||||||
|
)
|
||||||
|
except SystemExit:
|
||||||
|
pytest.skip("pyannote.audio not installed - skipping diarization test")
|
||||||
|
else:
|
||||||
|
# Test without diarization - should always work
|
||||||
|
await process_audio_file(
|
||||||
|
path.as_posix(), event_callback, enable_diarization=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Diarization: {enable_diarization}, Marks: {marks}")
|
||||||
|
|
||||||
# validate the events
|
# validate the events
|
||||||
assert marks["TranscriptLinerProcessor"] == 1
|
# Each processor should be called for each audio segment processed
|
||||||
assert marks["TranscriptTranslatorPassthroughProcessor"] == 1
|
# The final processors (Topic, Title, Summary) should be called once at the end
|
||||||
|
assert marks["TranscriptLinerProcessor"] > 0
|
||||||
|
assert marks["TranscriptTranslatorPassthroughProcessor"] > 0
|
||||||
assert marks["TranscriptTopicDetectorProcessor"] == 1
|
assert marks["TranscriptTopicDetectorProcessor"] == 1
|
||||||
assert marks["TranscriptFinalSummaryProcessor"] == 1
|
assert marks["TranscriptFinalSummaryProcessor"] == 1
|
||||||
assert marks["TranscriptFinalTitleProcessor"] == 1
|
assert marks["TranscriptFinalTitleProcessor"] == 1
|
||||||
|
|
||||||
|
if enable_diarization:
|
||||||
|
assert marks["TestAudioDiarizationProcessor"] == 1
|
||||||
|
|||||||
@@ -2,13 +2,18 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from reflector.db import get_database
|
from reflector.db import get_database
|
||||||
from reflector.db.search import SearchParameters, search_controller
|
from reflector.db.search import (
|
||||||
from reflector.db.transcripts import transcripts
|
SearchController,
|
||||||
|
SearchParameters,
|
||||||
|
SearchResult,
|
||||||
|
search_controller,
|
||||||
|
)
|
||||||
|
from reflector.db.transcripts import SourceKind, transcripts
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -18,39 +23,135 @@ async def test_search_postgresql_only():
|
|||||||
assert results == []
|
assert results == []
|
||||||
assert total == 0
|
assert total == 0
|
||||||
|
|
||||||
try:
|
params_empty = SearchParameters(query_text="")
|
||||||
SearchParameters(query_text="")
|
results_empty, total_empty = await search_controller.search_transcripts(
|
||||||
assert False, "Should have raised validation error"
|
params_empty
|
||||||
except ValidationError:
|
)
|
||||||
pass # Expected
|
assert isinstance(results_empty, list)
|
||||||
|
assert isinstance(total_empty, int)
|
||||||
# Test that whitespace query raises validation error
|
|
||||||
try:
|
|
||||||
SearchParameters(query_text=" ")
|
|
||||||
assert False, "Should have raised validation error"
|
|
||||||
except ValidationError:
|
|
||||||
pass # Expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_input_validation():
|
async def test_search_with_empty_query():
|
||||||
try:
|
"""Test that empty query returns all transcripts."""
|
||||||
SearchParameters(query_text="")
|
params = SearchParameters(query_text="")
|
||||||
assert False, "Should have raised ValidationError"
|
results, total = await search_controller.search_transcripts(params)
|
||||||
except ValidationError:
|
|
||||||
pass # Expected
|
assert isinstance(results, list)
|
||||||
|
assert isinstance(total, int)
|
||||||
|
if len(results) > 1:
|
||||||
|
for i in range(len(results) - 1):
|
||||||
|
assert results[i].created_at >= results[i + 1].created_at
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_transcript_title_only_match():
|
||||||
|
"""Test that transcripts with title-only matches return empty snippets."""
|
||||||
|
test_id = "test-empty-9b3f2a8d"
|
||||||
|
|
||||||
# Test that whitespace query raises validation error
|
|
||||||
try:
|
try:
|
||||||
SearchParameters(query_text=" \t\n ")
|
await get_database().execute(
|
||||||
assert False, "Should have raised ValidationError"
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
except ValidationError:
|
)
|
||||||
pass # Expected
|
|
||||||
|
test_data = {
|
||||||
|
"id": test_id,
|
||||||
|
"name": "Empty Transcript",
|
||||||
|
"title": "Empty Meeting",
|
||||||
|
"status": "completed",
|
||||||
|
"locked": False,
|
||||||
|
"duration": 0.0,
|
||||||
|
"created_at": datetime.now(timezone.utc),
|
||||||
|
"short_summary": None,
|
||||||
|
"long_summary": None,
|
||||||
|
"topics": json.dumps([]),
|
||||||
|
"events": json.dumps([]),
|
||||||
|
"participants": json.dumps([]),
|
||||||
|
"source_language": "en",
|
||||||
|
"target_language": "en",
|
||||||
|
"reviewed": False,
|
||||||
|
"audio_location": "local",
|
||||||
|
"share_mode": "private",
|
||||||
|
"source_kind": "room",
|
||||||
|
"webvtt": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
|
|
||||||
|
params = SearchParameters(query_text="empty")
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
found = next((r for r in results if r.id == test_id), None)
|
||||||
|
assert found is not None, "Should find transcript by title match"
|
||||||
|
assert found.search_snippets == []
|
||||||
|
assert found.total_match_count == 0
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await get_database().execute(
|
||||||
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
|
)
|
||||||
|
await get_database().disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_with_long_summary():
|
||||||
|
"""Test that long_summary content is searchable."""
|
||||||
|
test_id = "test-long-summary-8a9f3c2d"
|
||||||
|
|
||||||
|
try:
|
||||||
|
await get_database().execute(
|
||||||
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
test_data = {
|
||||||
|
"id": test_id,
|
||||||
|
"name": "Test Long Summary",
|
||||||
|
"title": "Regular Meeting",
|
||||||
|
"status": "completed",
|
||||||
|
"locked": False,
|
||||||
|
"duration": 1800.0,
|
||||||
|
"created_at": datetime.now(timezone.utc),
|
||||||
|
"short_summary": "Brief overview",
|
||||||
|
"long_summary": "Detailed discussion about quantum computing applications and blockchain technology integration",
|
||||||
|
"topics": json.dumps([]),
|
||||||
|
"events": json.dumps([]),
|
||||||
|
"participants": json.dumps([]),
|
||||||
|
"source_language": "en",
|
||||||
|
"target_language": "en",
|
||||||
|
"reviewed": False,
|
||||||
|
"audio_location": "local",
|
||||||
|
"share_mode": "private",
|
||||||
|
"source_kind": "room",
|
||||||
|
"webvtt": """WEBVTT
|
||||||
|
|
||||||
|
00:00:00.000 --> 00:00:10.000
|
||||||
|
Basic meeting content without special keywords.""",
|
||||||
|
}
|
||||||
|
|
||||||
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
|
|
||||||
|
params = SearchParameters(query_text="quantum computing")
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
found = any(r.id == test_id for r in results)
|
||||||
|
assert found, "Should find transcript by long_summary content"
|
||||||
|
|
||||||
|
test_result = next((r for r in results if r.id == test_id), None)
|
||||||
|
assert test_result
|
||||||
|
assert len(test_result.search_snippets) > 0
|
||||||
|
assert "quantum computing" in test_result.search_snippets[0].lower()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await get_database().execute(
|
||||||
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
|
)
|
||||||
|
await get_database().disconnect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_postgresql_search_with_data():
|
async def test_postgresql_search_with_data():
|
||||||
# collision is improbable
|
|
||||||
test_id = "test-search-e2e-7f3a9b2c"
|
test_id = "test-search-e2e-7f3a9b2c"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -94,28 +195,24 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
|||||||
|
|
||||||
await get_database().execute(transcripts.insert().values(**test_data))
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
|
|
||||||
# Test 1: Search for a word in title
|
|
||||||
params = SearchParameters(query_text="planning")
|
params = SearchParameters(query_text="planning")
|
||||||
results, total = await search_controller.search_transcripts(params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
assert found, "Should find test transcript by title word"
|
assert found, "Should find test transcript by title word"
|
||||||
|
|
||||||
# Test 2: Search for a word in webvtt content
|
|
||||||
params = SearchParameters(query_text="tsvector")
|
params = SearchParameters(query_text="tsvector")
|
||||||
results, total = await search_controller.search_transcripts(params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
assert found, "Should find test transcript by webvtt content"
|
assert found, "Should find test transcript by webvtt content"
|
||||||
|
|
||||||
# Test 3: Search with multiple words
|
|
||||||
params = SearchParameters(query_text="engineering planning")
|
params = SearchParameters(query_text="engineering planning")
|
||||||
results, total = await search_controller.search_transcripts(params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
assert found, "Should find test transcript by multiple words"
|
assert found, "Should find test transcript by multiple words"
|
||||||
|
|
||||||
# Test 4: Verify SearchResult structure
|
|
||||||
test_result = next((r for r in results if r.id == test_id), None)
|
test_result = next((r for r in results if r.id == test_id), None)
|
||||||
if test_result:
|
if test_result:
|
||||||
assert test_result.title == "Engineering Planning Meeting Q4 2024"
|
assert test_result.title == "Engineering Planning Meeting Q4 2024"
|
||||||
@@ -123,14 +220,12 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
|||||||
assert test_result.duration == 1800.0
|
assert test_result.duration == 1800.0
|
||||||
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
|
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
|
||||||
|
|
||||||
# Test 5: Search with OR operator
|
|
||||||
params = SearchParameters(query_text="tsvector OR nosuchword")
|
params = SearchParameters(query_text="tsvector OR nosuchword")
|
||||||
results, total = await search_controller.search_transcripts(params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
assert found, "Should find test transcript with OR query"
|
assert found, "Should find test transcript with OR query"
|
||||||
|
|
||||||
# Test 6: Quoted phrase search
|
|
||||||
params = SearchParameters(query_text='"full-text search"')
|
params = SearchParameters(query_text='"full-text search"')
|
||||||
results, total = await search_controller.search_transcripts(params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
@@ -142,3 +237,240 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
|||||||
transcripts.delete().where(transcripts.c.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
await get_database().disconnect()
|
await get_database().disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_search_params():
|
||||||
|
"""Create sample search parameters for testing."""
|
||||||
|
return SearchParameters(
|
||||||
|
query_text="test query",
|
||||||
|
limit=20,
|
||||||
|
offset=0,
|
||||||
|
user_id="test-user",
|
||||||
|
room_id="room1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_db_result():
|
||||||
|
"""Create a mock database result."""
|
||||||
|
return {
|
||||||
|
"id": "test-transcript-id",
|
||||||
|
"title": "Test Transcript",
|
||||||
|
"created_at": datetime(2024, 6, 15, tzinfo=timezone.utc),
|
||||||
|
"duration": 3600.0,
|
||||||
|
"status": "completed",
|
||||||
|
"user_id": "test-user",
|
||||||
|
"room_id": "room1",
|
||||||
|
"source_kind": SourceKind.LIVE,
|
||||||
|
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\nThis is a test transcript",
|
||||||
|
"rank": 0.95,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchParameters:
|
||||||
|
"""Test SearchParameters model validation and functionality."""
|
||||||
|
|
||||||
|
def test_search_parameters_with_available_filters(self):
|
||||||
|
"""Test creating SearchParameters with currently available filter options."""
|
||||||
|
params = SearchParameters(
|
||||||
|
query_text="search term",
|
||||||
|
limit=50,
|
||||||
|
offset=10,
|
||||||
|
user_id="user123",
|
||||||
|
room_id="room1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert params.query_text == "search term"
|
||||||
|
assert params.limit == 50
|
||||||
|
assert params.offset == 10
|
||||||
|
assert params.user_id == "user123"
|
||||||
|
assert params.room_id == "room1"
|
||||||
|
|
||||||
|
def test_search_parameters_defaults(self):
|
||||||
|
"""Test SearchParameters with default values."""
|
||||||
|
params = SearchParameters(query_text="test")
|
||||||
|
|
||||||
|
assert params.query_text == "test"
|
||||||
|
assert params.limit == 20
|
||||||
|
assert params.offset == 0
|
||||||
|
assert params.user_id is None
|
||||||
|
assert params.room_id is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchControllerFilters:
|
||||||
|
"""Test SearchController functionality with various filters."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_with_source_kind_filter(self):
|
||||||
|
"""Test search filtering by source_kind."""
|
||||||
|
controller = SearchController()
|
||||||
|
with (
|
||||||
|
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||||
|
patch("reflector.db.search.get_database") as mock_db,
|
||||||
|
):
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
|
||||||
|
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
|
||||||
|
|
||||||
|
params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
|
||||||
|
|
||||||
|
results, total = await controller.search_transcripts(params)
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
assert total == 0
|
||||||
|
|
||||||
|
mock_db.return_value.fetch_all.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_with_single_room_id(self):
|
||||||
|
"""Test search filtering by single room ID (currently supported)."""
|
||||||
|
controller = SearchController()
|
||||||
|
with (
|
||||||
|
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||||
|
patch("reflector.db.search.get_database") as mock_db,
|
||||||
|
):
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
|
||||||
|
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
|
||||||
|
|
||||||
|
params = SearchParameters(
|
||||||
|
query_text="test",
|
||||||
|
room_id="room1",
|
||||||
|
)
|
||||||
|
|
||||||
|
results, total = await controller.search_transcripts(params)
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
assert total == 0
|
||||||
|
mock_db.return_value.fetch_all.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_result_includes_available_fields(self, mock_db_result):
|
||||||
|
"""Test that search results include available fields like source_kind."""
|
||||||
|
controller = SearchController()
|
||||||
|
with (
|
||||||
|
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||||
|
patch("reflector.db.search.get_database") as mock_db,
|
||||||
|
):
|
||||||
|
|
||||||
|
class MockRow:
|
||||||
|
def __init__(self, data):
|
||||||
|
self._data = data
|
||||||
|
self._mapping = data
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._data.items())
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._data[key]
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return self._data.keys()
|
||||||
|
|
||||||
|
mock_row = MockRow(mock_db_result)
|
||||||
|
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||||
|
mock_db.return_value.fetch_val = AsyncMock(return_value=1)
|
||||||
|
|
||||||
|
params = SearchParameters(query_text="test")
|
||||||
|
|
||||||
|
results, total = await controller.search_transcripts(params)
|
||||||
|
|
||||||
|
assert total == 1
|
||||||
|
assert len(results) == 1
|
||||||
|
|
||||||
|
result = results[0]
|
||||||
|
assert isinstance(result, SearchResult)
|
||||||
|
assert result.id == "test-transcript-id"
|
||||||
|
assert result.title == "Test Transcript"
|
||||||
|
assert result.rank == 0.95
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchEndpointParsing:
|
||||||
|
"""Test parameter parsing in the search endpoint."""
|
||||||
|
|
||||||
|
def test_parse_comma_separated_room_ids(self):
|
||||||
|
"""Test parsing comma-separated room IDs."""
|
||||||
|
room_ids_str = "room1,room2,room3"
|
||||||
|
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||||
|
assert parsed == ["room1", "room2", "room3"]
|
||||||
|
|
||||||
|
room_ids_str = "room1, room2 , room3"
|
||||||
|
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||||
|
assert parsed == ["room1", "room2", "room3"]
|
||||||
|
|
||||||
|
room_ids_str = "room1,,room3,"
|
||||||
|
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||||
|
assert parsed == ["room1", "room3"]
|
||||||
|
|
||||||
|
def test_parse_source_kind(self):
|
||||||
|
"""Test parsing source_kind values."""
|
||||||
|
for kind_str in ["live", "file", "room"]:
|
||||||
|
parsed = SourceKind(kind_str)
|
||||||
|
assert parsed == SourceKind(kind_str)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
SourceKind("invalid_kind")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchResultModel:
|
||||||
|
"""Test SearchResult model and serialization."""
|
||||||
|
|
||||||
|
def test_search_result_with_available_fields(self):
|
||||||
|
"""Test SearchResult model with currently available fields populated."""
|
||||||
|
result = SearchResult(
|
||||||
|
id="test-id",
|
||||||
|
title="Test Title",
|
||||||
|
user_id="user-123",
|
||||||
|
room_id="room-456",
|
||||||
|
source_kind=SourceKind.ROOM,
|
||||||
|
created_at=datetime(2024, 6, 15, tzinfo=timezone.utc),
|
||||||
|
status="completed",
|
||||||
|
rank=0.85,
|
||||||
|
duration=1800.5,
|
||||||
|
search_snippets=["snippet 1", "snippet 2"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.id == "test-id"
|
||||||
|
assert result.title == "Test Title"
|
||||||
|
assert result.user_id == "user-123"
|
||||||
|
assert result.room_id == "room-456"
|
||||||
|
assert result.status == "completed"
|
||||||
|
assert result.rank == 0.85
|
||||||
|
assert result.duration == 1800.5
|
||||||
|
assert len(result.search_snippets) == 2
|
||||||
|
|
||||||
|
def test_search_result_with_optional_fields_none(self):
|
||||||
|
"""Test SearchResult model with optional fields as None."""
|
||||||
|
result = SearchResult(
|
||||||
|
id="test-id",
|
||||||
|
source_kind=SourceKind.FILE,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
status="processing",
|
||||||
|
rank=0.5,
|
||||||
|
search_snippets=[],
|
||||||
|
title=None,
|
||||||
|
user_id=None,
|
||||||
|
room_id=None,
|
||||||
|
duration=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.title is None
|
||||||
|
assert result.user_id is None
|
||||||
|
assert result.room_id is None
|
||||||
|
assert result.duration is None
|
||||||
|
|
||||||
|
def test_search_result_datetime_field(self):
|
||||||
|
"""Test that SearchResult accepts datetime field."""
|
||||||
|
result = SearchResult(
|
||||||
|
id="test-id",
|
||||||
|
source_kind=SourceKind.LIVE,
|
||||||
|
created_at=datetime(2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc),
|
||||||
|
status="completed",
|
||||||
|
rank=0.9,
|
||||||
|
duration=None,
|
||||||
|
search_snippets=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.created_at == datetime(
|
||||||
|
2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc
|
||||||
|
)
|
||||||
|
|||||||
164
server/tests/test_search_long_summary.py
Normal file
164
server/tests/test_search_long_summary.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Tests for long_summary in search functionality."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reflector.db import get_database
|
||||||
|
from reflector.db.search import SearchParameters, search_controller
|
||||||
|
from reflector.db.transcripts import transcripts
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_long_summary_snippet_prioritization():
|
||||||
|
"""Test that snippets from long_summary are prioritized over webvtt content."""
|
||||||
|
test_id = "test-snippet-priority-3f9a2b8c"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Clean up any existing test data
|
||||||
|
await get_database().execute(
|
||||||
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
test_data = {
|
||||||
|
"id": test_id,
|
||||||
|
"name": "Test Snippet Priority",
|
||||||
|
"title": "Meeting About Projects",
|
||||||
|
"status": "completed",
|
||||||
|
"locked": False,
|
||||||
|
"duration": 1800.0,
|
||||||
|
"created_at": datetime.now(timezone.utc),
|
||||||
|
"short_summary": "Project discussion",
|
||||||
|
"long_summary": (
|
||||||
|
"The team discussed advanced robotics applications including "
|
||||||
|
"autonomous navigation systems and sensor fusion techniques. "
|
||||||
|
"Robotics development will focus on real-time processing."
|
||||||
|
),
|
||||||
|
"topics": json.dumps([]),
|
||||||
|
"events": json.dumps([]),
|
||||||
|
"participants": json.dumps([]),
|
||||||
|
"source_language": "en",
|
||||||
|
"target_language": "en",
|
||||||
|
"reviewed": False,
|
||||||
|
"audio_location": "local",
|
||||||
|
"share_mode": "private",
|
||||||
|
"source_kind": "room",
|
||||||
|
"webvtt": """WEBVTT
|
||||||
|
|
||||||
|
00:00:00.000 --> 00:00:10.000
|
||||||
|
We talked about many different topics today.
|
||||||
|
|
||||||
|
00:00:10.000 --> 00:00:20.000
|
||||||
|
The robotics project is making good progress.
|
||||||
|
|
||||||
|
00:00:20.000 --> 00:00:30.000
|
||||||
|
We need to consider various implementation approaches.""",
|
||||||
|
}
|
||||||
|
|
||||||
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
|
|
||||||
|
# Search for "robotics" which appears in both long_summary and webvtt
|
||||||
|
params = SearchParameters(query_text="robotics")
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
test_result = next((r for r in results if r.id == test_id), None)
|
||||||
|
assert test_result, "Should find the test transcript"
|
||||||
|
|
||||||
|
snippets = test_result.search_snippets
|
||||||
|
assert len(snippets) > 0, "Should have at least one snippet"
|
||||||
|
|
||||||
|
# The first snippets should be from long_summary (more detailed content)
|
||||||
|
first_snippet = snippets[0].lower()
|
||||||
|
assert (
|
||||||
|
"advanced robotics" in first_snippet or "autonomous" in first_snippet
|
||||||
|
), f"First snippet should be from long_summary with detailed content. Got: {snippets[0]}"
|
||||||
|
|
||||||
|
# With max 3 snippets, we should get both from long_summary and webvtt
|
||||||
|
assert len(snippets) <= 3, "Should respect max snippets limit"
|
||||||
|
|
||||||
|
# All snippets should contain the search term
|
||||||
|
for snippet in snippets:
|
||||||
|
assert (
|
||||||
|
"robotics" in snippet.lower()
|
||||||
|
), f"Snippet should contain search term: {snippet}"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await get_database().execute(
|
||||||
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
|
)
|
||||||
|
await get_database().disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_long_summary_only_search():
|
||||||
|
"""Test searching for content that only exists in long_summary."""
|
||||||
|
test_id = "test-long-only-8b3c9f2a"
|
||||||
|
|
||||||
|
try:
|
||||||
|
await get_database().execute(
|
||||||
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
test_data = {
|
||||||
|
"id": test_id,
|
||||||
|
"name": "Test Long Only",
|
||||||
|
"title": "Standard Meeting",
|
||||||
|
"status": "completed",
|
||||||
|
"locked": False,
|
||||||
|
"duration": 1800.0,
|
||||||
|
"created_at": datetime.now(timezone.utc),
|
||||||
|
"short_summary": "Team sync",
|
||||||
|
"long_summary": (
|
||||||
|
"Detailed analysis of cryptocurrency market trends and "
|
||||||
|
"decentralized finance protocols. Discussion included "
|
||||||
|
"yield farming strategies and liquidity pool mechanics."
|
||||||
|
),
|
||||||
|
"topics": json.dumps([]),
|
||||||
|
"events": json.dumps([]),
|
||||||
|
"participants": json.dumps([]),
|
||||||
|
"source_language": "en",
|
||||||
|
"target_language": "en",
|
||||||
|
"reviewed": False,
|
||||||
|
"audio_location": "local",
|
||||||
|
"share_mode": "private",
|
||||||
|
"source_kind": "room",
|
||||||
|
"webvtt": """WEBVTT
|
||||||
|
|
||||||
|
00:00:00.000 --> 00:00:10.000
|
||||||
|
Team meeting about general project updates.
|
||||||
|
|
||||||
|
00:00:10.000 --> 00:00:20.000
|
||||||
|
Discussion of timeline and deliverables.""",
|
||||||
|
}
|
||||||
|
|
||||||
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
|
|
||||||
|
# Search for terms only in long_summary
|
||||||
|
params = SearchParameters(query_text="cryptocurrency")
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
|
found = any(r.id == test_id for r in results)
|
||||||
|
assert found, "Should find transcript by long_summary-only content"
|
||||||
|
|
||||||
|
test_result = next((r for r in results if r.id == test_id), None)
|
||||||
|
assert test_result
|
||||||
|
assert len(test_result.search_snippets) > 0
|
||||||
|
|
||||||
|
# Verify the snippet is about cryptocurrency
|
||||||
|
snippet = test_result.search_snippets[0].lower()
|
||||||
|
assert "cryptocurrency" in snippet, "Snippet should contain the search term"
|
||||||
|
|
||||||
|
# Search for "yield farming" - a more specific term
|
||||||
|
params2 = SearchParameters(query_text="yield farming")
|
||||||
|
results2, total2 = await search_controller.search_transcripts(params2)
|
||||||
|
|
||||||
|
found2 = any(r.id == test_id for r in results2)
|
||||||
|
assert found2, "Should find transcript by specific long_summary phrase"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await get_database().execute(
|
||||||
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
|
)
|
||||||
|
await get_database().disconnect()
|
||||||
@@ -1,6 +1,10 @@
|
|||||||
"""Unit tests for search snippet generation."""
|
"""Unit tests for search snippet generation."""
|
||||||
|
|
||||||
from reflector.db.search import SearchController
|
from reflector.db.search import (
|
||||||
|
SnippetCandidate,
|
||||||
|
SnippetGenerator,
|
||||||
|
WebVTTProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestExtractWebVTT:
|
class TestExtractWebVTT:
|
||||||
@@ -16,7 +20,7 @@ class TestExtractWebVTT:
|
|||||||
00:00:10.000 --> 00:00:20.000
|
00:00:10.000 --> 00:00:20.000
|
||||||
<v Speaker1>Indeed it is a test of WebVTT parsing.
|
<v Speaker1>Indeed it is a test of WebVTT parsing.
|
||||||
"""
|
"""
|
||||||
result = SearchController._extract_webvtt_text(webvtt)
|
result = WebVTTProcessor.extract_text(webvtt)
|
||||||
assert "Hello world, this is a test" in result
|
assert "Hello world, this is a test" in result
|
||||||
assert "Indeed it is a test" in result
|
assert "Indeed it is a test" in result
|
||||||
assert "<v Speaker" not in result
|
assert "<v Speaker" not in result
|
||||||
@@ -25,12 +29,11 @@ class TestExtractWebVTT:
|
|||||||
|
|
||||||
def test_extract_empty_webvtt(self):
|
def test_extract_empty_webvtt(self):
|
||||||
"""Test empty WebVTT returns empty string."""
|
"""Test empty WebVTT returns empty string."""
|
||||||
assert SearchController._extract_webvtt_text("") == ""
|
assert WebVTTProcessor.extract_text("") == ""
|
||||||
assert SearchController._extract_webvtt_text(None) == ""
|
|
||||||
|
|
||||||
def test_extract_malformed_webvtt(self):
|
def test_extract_malformed_webvtt(self):
|
||||||
"""Test malformed WebVTT returns empty string."""
|
"""Test malformed WebVTT returns empty string."""
|
||||||
result = SearchController._extract_webvtt_text("Not a valid WebVTT")
|
result = WebVTTProcessor.extract_text("Not a valid WebVTT")
|
||||||
assert result == ""
|
assert result == ""
|
||||||
|
|
||||||
|
|
||||||
@@ -39,8 +42,7 @@ class TestGenerateSnippets:
|
|||||||
|
|
||||||
def test_multiple_matches(self):
|
def test_multiple_matches(self):
|
||||||
"""Test finding multiple occurrences of search term in long text."""
|
"""Test finding multiple occurrences of search term in long text."""
|
||||||
# Create text with Python mentions far apart to get separate snippets
|
separator = " This is filler text. " * 20
|
||||||
separator = " This is filler text. " * 20 # ~400 chars of padding
|
|
||||||
text = (
|
text = (
|
||||||
"Python is great for machine learning."
|
"Python is great for machine learning."
|
||||||
+ separator
|
+ separator
|
||||||
@@ -51,18 +53,16 @@ class TestGenerateSnippets:
|
|||||||
+ "The Python community is very supportive."
|
+ "The Python community is very supportive."
|
||||||
)
|
)
|
||||||
|
|
||||||
snippets = SearchController._generate_snippets(text, "Python")
|
snippets = SnippetGenerator.generate(text, "Python")
|
||||||
# With enough separation, we should get multiple snippets
|
assert len(snippets) >= 2
|
||||||
assert len(snippets) >= 2 # At least 2 distinct snippets
|
|
||||||
|
|
||||||
# Each snippet should contain "Python"
|
|
||||||
for snippet in snippets:
|
for snippet in snippets:
|
||||||
assert "python" in snippet.lower()
|
assert "python" in snippet.lower()
|
||||||
|
|
||||||
def test_single_match(self):
|
def test_single_match(self):
|
||||||
"""Test single occurrence returns one snippet."""
|
"""Test single occurrence returns one snippet."""
|
||||||
text = "This document discusses artificial intelligence and its applications."
|
text = "This document discusses artificial intelligence and its applications."
|
||||||
snippets = SearchController._generate_snippets(text, "artificial intelligence")
|
snippets = SnippetGenerator.generate(text, "artificial intelligence")
|
||||||
|
|
||||||
assert len(snippets) == 1
|
assert len(snippets) == 1
|
||||||
assert "artificial intelligence" in snippets[0].lower()
|
assert "artificial intelligence" in snippets[0].lower()
|
||||||
@@ -70,24 +70,22 @@ class TestGenerateSnippets:
|
|||||||
def test_no_matches(self):
|
def test_no_matches(self):
|
||||||
"""Test no matches returns empty list."""
|
"""Test no matches returns empty list."""
|
||||||
text = "This is some random text without the search term."
|
text = "This is some random text without the search term."
|
||||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||||
|
|
||||||
assert snippets == []
|
assert snippets == []
|
||||||
|
|
||||||
def test_case_insensitive_search(self):
|
def test_case_insensitive_search(self):
|
||||||
"""Test search is case insensitive."""
|
"""Test search is case insensitive."""
|
||||||
# Add enough text between matches to get separate snippets
|
|
||||||
text = (
|
text = (
|
||||||
"MACHINE LEARNING is important for modern applications. "
|
"MACHINE LEARNING is important for modern applications. "
|
||||||
+ "It requires lots of data and computational resources. " * 5 # Padding
|
+ "It requires lots of data and computational resources. " * 5
|
||||||
+ "Machine Learning rocks and transforms industries. "
|
+ "Machine Learning rocks and transforms industries. "
|
||||||
+ "Deep learning is a subset of it. " * 5 # More padding
|
+ "Deep learning is a subset of it. " * 5
|
||||||
+ "Finally, machine learning will shape our future."
|
+ "Finally, machine learning will shape our future."
|
||||||
)
|
)
|
||||||
|
|
||||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||||
|
|
||||||
# Should find at least 2 (might be 3 if text is long enough)
|
|
||||||
assert len(snippets) >= 2
|
assert len(snippets) >= 2
|
||||||
for snippet in snippets:
|
for snippet in snippets:
|
||||||
assert "machine learning" in snippet.lower()
|
assert "machine learning" in snippet.lower()
|
||||||
@@ -95,61 +93,55 @@ class TestGenerateSnippets:
|
|||||||
def test_partial_match_fallback(self):
|
def test_partial_match_fallback(self):
|
||||||
"""Test fallback to first word when exact phrase not found."""
|
"""Test fallback to first word when exact phrase not found."""
|
||||||
text = "We use machine intelligence for processing."
|
text = "We use machine intelligence for processing."
|
||||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||||
|
|
||||||
# Should fall back to finding "machine"
|
|
||||||
assert len(snippets) == 1
|
assert len(snippets) == 1
|
||||||
assert "machine" in snippets[0].lower()
|
assert "machine" in snippets[0].lower()
|
||||||
|
|
||||||
def test_snippet_ellipsis(self):
|
def test_snippet_ellipsis(self):
|
||||||
"""Test ellipsis added for truncated snippets."""
|
"""Test ellipsis added for truncated snippets."""
|
||||||
# Long text where match is in the middle
|
|
||||||
text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
|
text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
|
||||||
snippets = SearchController._generate_snippets(text, "TARGET_WORD")
|
snippets = SnippetGenerator.generate(text, "TARGET_WORD")
|
||||||
|
|
||||||
assert len(snippets) == 1
|
assert len(snippets) == 1
|
||||||
assert "..." in snippets[0] # Should have ellipsis
|
assert "..." in snippets[0]
|
||||||
assert "TARGET_WORD" in snippets[0]
|
assert "TARGET_WORD" in snippets[0]
|
||||||
|
|
||||||
def test_overlapping_snippets_deduplicated(self):
|
def test_overlapping_snippets_deduplicated(self):
|
||||||
"""Test overlapping matches don't create duplicate snippets."""
|
"""Test overlapping matches don't create duplicate snippets."""
|
||||||
text = "test test test word" * 10 # Repeated pattern
|
text = "test test test word" * 10
|
||||||
snippets = SearchController._generate_snippets(text, "test")
|
snippets = SnippetGenerator.generate(text, "test")
|
||||||
|
|
||||||
# Should get unique snippets, not duplicates
|
|
||||||
assert len(snippets) <= 3
|
assert len(snippets) <= 3
|
||||||
assert len(snippets) == len(set(snippets)) # All unique
|
assert len(snippets) == len(set(snippets))
|
||||||
|
|
||||||
def test_empty_inputs(self):
|
def test_empty_inputs(self):
|
||||||
"""Test empty text or search term returns empty list."""
|
"""Test empty text or search term returns empty list."""
|
||||||
assert SearchController._generate_snippets("", "search") == []
|
assert SnippetGenerator.generate("", "search") == []
|
||||||
assert SearchController._generate_snippets("text", "") == []
|
assert SnippetGenerator.generate("text", "") == []
|
||||||
assert SearchController._generate_snippets("", "") == []
|
assert SnippetGenerator.generate("", "") == []
|
||||||
|
|
||||||
def test_max_snippets_limit(self):
|
def test_max_snippets_limit(self):
|
||||||
"""Test respects max_snippets parameter."""
|
"""Test respects max_snippets parameter."""
|
||||||
# Create text with well-separated occurrences
|
separator = " filler " * 50
|
||||||
separator = " filler " * 50 # Ensure snippets don't overlap
|
text = ("Python is amazing" + separator) * 10
|
||||||
text = ("Python is amazing" + separator) * 10 # 10 occurrences
|
|
||||||
|
|
||||||
# Test with different limits
|
snippets_1 = SnippetGenerator.generate(text, "Python", max_snippets=1)
|
||||||
snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
|
|
||||||
assert len(snippets_1) == 1
|
assert len(snippets_1) == 1
|
||||||
|
|
||||||
snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
|
snippets_2 = SnippetGenerator.generate(text, "Python", max_snippets=2)
|
||||||
assert len(snippets_2) == 2
|
assert len(snippets_2) == 2
|
||||||
|
|
||||||
snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
|
snippets_5 = SnippetGenerator.generate(text, "Python", max_snippets=5)
|
||||||
assert len(snippets_5) == 5 # Should get exactly 5 with enough separation
|
assert len(snippets_5) == 5
|
||||||
|
|
||||||
def test_snippet_length(self):
|
def test_snippet_length(self):
|
||||||
"""Test snippet length is reasonable."""
|
"""Test snippet length is reasonable."""
|
||||||
text = "word " * 200 # Long text
|
text = "word " * 200
|
||||||
snippets = SearchController._generate_snippets(text, "word")
|
snippets = SnippetGenerator.generate(text, "word")
|
||||||
|
|
||||||
for snippet in snippets:
|
for snippet in snippets:
|
||||||
# Default max_length is 150 + some context
|
assert len(snippet) <= 200
|
||||||
assert len(snippet) <= 200 # Some buffer for ellipsis
|
|
||||||
|
|
||||||
|
|
||||||
class TestFullPipeline:
|
class TestFullPipeline:
|
||||||
@@ -157,7 +149,6 @@ class TestFullPipeline:
|
|||||||
|
|
||||||
def test_webvtt_to_snippets_integration(self):
|
def test_webvtt_to_snippets_integration(self):
|
||||||
"""Test full pipeline from WebVTT to search snippets."""
|
"""Test full pipeline from WebVTT to search snippets."""
|
||||||
# Create WebVTT with well-separated content for multiple snippets
|
|
||||||
webvtt = (
|
webvtt = (
|
||||||
"""WEBVTT
|
"""WEBVTT
|
||||||
|
|
||||||
@@ -182,17 +173,362 @@ class TestFullPipeline:
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract and generate snippets
|
plain_text = WebVTTProcessor.extract_text(webvtt)
|
||||||
plain_text = SearchController._extract_webvtt_text(webvtt)
|
snippets = SnippetGenerator.generate(plain_text, "machine learning")
|
||||||
snippets = SearchController._generate_snippets(plain_text, "machine learning")
|
|
||||||
|
|
||||||
# Should find at least 2 snippets (text might still be close together)
|
assert len(snippets) >= 1
|
||||||
assert len(snippets) >= 1 # At minimum one snippet containing matches
|
assert len(snippets) <= 3
|
||||||
assert len(snippets) <= 3 # At most 3 by default
|
|
||||||
|
|
||||||
# No WebVTT artifacts in snippets
|
|
||||||
for snippet in snippets:
|
for snippet in snippets:
|
||||||
assert "machine learning" in snippet.lower()
|
assert "machine learning" in snippet.lower()
|
||||||
assert "<v Speaker" not in snippet
|
assert "<v Speaker" not in snippet
|
||||||
assert "00:00" not in snippet
|
assert "00:00" not in snippet
|
||||||
assert "-->" not in snippet
|
assert "-->" not in snippet
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiWordQueryBehavior:
|
||||||
|
"""Tests for multi-word query behavior and exact phrase matching."""
|
||||||
|
|
||||||
|
def test_multi_word_query_snippet_behavior(self):
|
||||||
|
"""Test that multi-word queries generate snippets based on exact phrase matching."""
|
||||||
|
sample_text = """This is a sample transcript where user Alice is talking.
|
||||||
|
Later in the conversation, jordan mentions something important.
|
||||||
|
The user jordan collaboration was successful.
|
||||||
|
Another user named Bob joins the discussion."""
|
||||||
|
|
||||||
|
user_snippets = SnippetGenerator.generate(sample_text, "user")
|
||||||
|
assert len(user_snippets) == 2, "Should find 2 snippets for 'user'"
|
||||||
|
|
||||||
|
jordan_snippets = SnippetGenerator.generate(sample_text, "jordan")
|
||||||
|
assert len(jordan_snippets) >= 1, "Should find at least 1 snippet for 'jordan'"
|
||||||
|
|
||||||
|
multi_word_snippets = SnippetGenerator.generate(sample_text, "user jordan")
|
||||||
|
assert len(multi_word_snippets) == 1, (
|
||||||
|
"Should return exactly 1 snippet for 'user jordan' "
|
||||||
|
"(only the exact phrase match, not individual word occurrences)"
|
||||||
|
)
|
||||||
|
|
||||||
|
snippet = multi_word_snippets[0]
|
||||||
|
assert (
|
||||||
|
"user jordan" in snippet.lower()
|
||||||
|
), "The snippet should contain the exact phrase 'user jordan'"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
"alice" not in snippet.lower()
|
||||||
|
), "The snippet should not include the first standalone 'user' with Alice"
|
||||||
|
|
||||||
|
def test_multi_word_query_without_exact_match(self):
|
||||||
|
"""Test snippet generation when exact phrase is not found."""
|
||||||
|
sample_text = """User Alice is here. Bob and jordan are talking.
|
||||||
|
Later jordan mentions something. The user is happy."""
|
||||||
|
|
||||||
|
snippets = SnippetGenerator.generate(sample_text, "user jordan")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(snippets) >= 1
|
||||||
|
), "Should find at least 1 snippet when falling back to first word"
|
||||||
|
|
||||||
|
all_snippets_text = " ".join(snippets).lower()
|
||||||
|
assert (
|
||||||
|
"user" in all_snippets_text
|
||||||
|
), "Snippets should contain 'user' (the first word)"
|
||||||
|
|
||||||
|
def test_exact_phrase_at_text_boundaries(self):
|
||||||
|
"""Test snippet generation when exact phrase appears at text boundaries."""
|
||||||
|
|
||||||
|
text_start = "user jordan started the meeting. Other content here."
|
||||||
|
snippets = SnippetGenerator.generate(text_start, "user jordan")
|
||||||
|
assert len(snippets) == 1
|
||||||
|
assert "user jordan" in snippets[0].lower()
|
||||||
|
|
||||||
|
text_end = "Other content here. The meeting ended with user jordan"
|
||||||
|
snippets = SnippetGenerator.generate(text_end, "user jordan")
|
||||||
|
assert len(snippets) == 1
|
||||||
|
assert "user jordan" in snippets[0].lower()
|
||||||
|
|
||||||
|
def test_multi_word_query_matches_words_appearing_separately_and_together(self):
|
||||||
|
"""Test that multi-word queries prioritize exact phrase matches over individual word occurrences."""
|
||||||
|
sample_text = """This is a sample transcript where user Alice is talking.
|
||||||
|
Later in the conversation, jordan mentions something important.
|
||||||
|
The user jordan collaboration was successful.
|
||||||
|
Another user named Bob joins the discussion."""
|
||||||
|
|
||||||
|
search_query = "user jordan"
|
||||||
|
snippets = SnippetGenerator.generate(sample_text, search_query)
|
||||||
|
|
||||||
|
assert len(snippets) == 1, (
|
||||||
|
f"Expected exactly 1 snippet for '{search_query}' when exact phrase exists, "
|
||||||
|
f"got {len(snippets)}. Should ignore individual word occurrences."
|
||||||
|
)
|
||||||
|
|
||||||
|
snippet = snippets[0]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
search_query in snippet.lower()
|
||||||
|
), f"Snippet should contain the exact phrase '{search_query}'. Got: {snippet}"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
"jordan mentions" in snippet.lower()
|
||||||
|
), f"Snippet should include context before the exact phrase match. Got: {snippet}"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
"alice" not in snippet.lower()
|
||||||
|
), f"Snippet should not include separate occurrences of individual words. Got: {snippet}"
|
||||||
|
|
||||||
|
text_2 = """The alpha version was released.
|
||||||
|
Beta testing started yesterday.
|
||||||
|
The alpha beta integration is complete."""
|
||||||
|
|
||||||
|
snippets_2 = SnippetGenerator.generate(text_2, "alpha beta")
|
||||||
|
assert len(snippets_2) == 1, "Should return 1 snippet for exact phrase match"
|
||||||
|
assert "alpha beta" in snippets_2[0].lower(), "Should contain exact phrase"
|
||||||
|
assert (
|
||||||
|
"version" not in snippets_2[0].lower()
|
||||||
|
), "Should not include first separate occurrence"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSnippetGenerationEnhanced:
|
||||||
|
"""Additional snippet generation tests from test_search_enhancements.py."""
|
||||||
|
|
||||||
|
def test_snippet_generation_from_webvtt(self):
|
||||||
|
"""Test snippet generation from WebVTT content."""
|
||||||
|
webvtt_content = """WEBVTT
|
||||||
|
|
||||||
|
00:00:00.000 --> 00:00:05.000
|
||||||
|
This is the beginning of the transcript
|
||||||
|
|
||||||
|
00:00:05.000 --> 00:00:10.000
|
||||||
|
The search term appears here in the middle
|
||||||
|
|
||||||
|
00:00:10.000 --> 00:00:15.000
|
||||||
|
And this is the end of the content"""
|
||||||
|
|
||||||
|
plain_text = WebVTTProcessor.extract_text(webvtt_content)
|
||||||
|
snippets = SnippetGenerator.generate(plain_text, "search term")
|
||||||
|
|
||||||
|
assert len(snippets) > 0
|
||||||
|
assert any("search term" in snippet.lower() for snippet in snippets)
|
||||||
|
|
||||||
|
def test_extract_webvtt_text_with_malformed_variations(self):
|
||||||
|
"""Test WebVTT extraction with various malformed content."""
|
||||||
|
malformed_vtt = "This is not valid WebVTT content"
|
||||||
|
result = WebVTTProcessor.extract_text(malformed_vtt)
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
partial_vtt = "WEBVTT\nNo timestamps here"
|
||||||
|
result = WebVTTProcessor.extract_text(partial_vtt)
|
||||||
|
assert result == "" or "No timestamps" not in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestPureFunctions:
|
||||||
|
"""Test the pure functions extracted for functional programming."""
|
||||||
|
|
||||||
|
def test_find_all_matches(self):
|
||||||
|
"""Test finding all match positions in text."""
|
||||||
|
text = "Python is great. Python is powerful. I love Python."
|
||||||
|
matches = list(SnippetGenerator.find_all_matches(text, "Python"))
|
||||||
|
assert matches == [0, 17, 44]
|
||||||
|
|
||||||
|
matches = list(SnippetGenerator.find_all_matches(text, "python"))
|
||||||
|
assert matches == [0, 17, 44]
|
||||||
|
|
||||||
|
matches = list(SnippetGenerator.find_all_matches(text, "Ruby"))
|
||||||
|
assert matches == []
|
||||||
|
|
||||||
|
matches = list(SnippetGenerator.find_all_matches("", "test"))
|
||||||
|
assert matches == []
|
||||||
|
matches = list(SnippetGenerator.find_all_matches("test", ""))
|
||||||
|
assert matches == []
|
||||||
|
|
||||||
|
def test_create_snippet(self):
|
||||||
|
"""Test creating a snippet from a match position."""
|
||||||
|
text = "This is a long text with the word Python in the middle and more text after."
|
||||||
|
|
||||||
|
snippet = SnippetGenerator.create_snippet(text, 35, max_length=150)
|
||||||
|
assert "Python" in snippet.text()
|
||||||
|
assert snippet.start >= 0
|
||||||
|
assert snippet.end <= len(text)
|
||||||
|
assert isinstance(snippet, SnippetCandidate)
|
||||||
|
|
||||||
|
assert len(snippet.text()) > 0
|
||||||
|
assert snippet.start <= snippet.end
|
||||||
|
|
||||||
|
long_text = "A" * 200
|
||||||
|
snippet = SnippetGenerator.create_snippet(long_text, 100, max_length=50)
|
||||||
|
assert snippet.text().startswith("...")
|
||||||
|
assert snippet.text().endswith("...")
|
||||||
|
|
||||||
|
snippet = SnippetGenerator.create_snippet("short text", 0, max_length=100)
|
||||||
|
assert snippet.start == 0
|
||||||
|
assert "short text" in snippet.text()
|
||||||
|
|
||||||
|
def test_filter_non_overlapping(self):
|
||||||
|
"""Test filtering overlapping snippets."""
|
||||||
|
candidates = [
|
||||||
|
SnippetCandidate(_text="First snippet", start=0, _original_text_length=100),
|
||||||
|
SnippetCandidate(_text="Overlapping", start=10, _original_text_length=100),
|
||||||
|
SnippetCandidate(
|
||||||
|
_text="Third snippet", start=40, _original_text_length=100
|
||||||
|
),
|
||||||
|
SnippetCandidate(
|
||||||
|
_text="Fourth snippet", start=65, _original_text_length=100
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
filtered = list(SnippetGenerator.filter_non_overlapping(iter(candidates)))
|
||||||
|
assert filtered == [
|
||||||
|
"First snippet...",
|
||||||
|
"...Third snippet...",
|
||||||
|
"...Fourth snippet...",
|
||||||
|
]
|
||||||
|
|
||||||
|
filtered = list(SnippetGenerator.filter_non_overlapping(iter([])))
|
||||||
|
assert filtered == []
|
||||||
|
|
||||||
|
def test_generate_integration(self):
|
||||||
|
"""Test the main SnippetGenerator.generate function."""
|
||||||
|
text = "Machine learning is amazing. Machine learning transforms data. Learn machine learning today."
|
||||||
|
|
||||||
|
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||||
|
assert len(snippets) <= 3
|
||||||
|
assert all("machine learning" in s.lower() for s in snippets)
|
||||||
|
|
||||||
|
snippets = SnippetGenerator.generate(text, "machine learning", max_snippets=2)
|
||||||
|
assert len(snippets) <= 2
|
||||||
|
|
||||||
|
snippets = SnippetGenerator.generate(text, "machine vision")
|
||||||
|
assert len(snippets) > 0
|
||||||
|
assert any("machine" in s.lower() for s in snippets)
|
||||||
|
|
||||||
|
def test_extract_webvtt_text_basic(self):
|
||||||
|
"""Test WebVTT text extraction (basic test, full tests exist elsewhere)."""
|
||||||
|
webvtt = """WEBVTT
|
||||||
|
|
||||||
|
00:00:00.000 --> 00:00:02.000
|
||||||
|
Hello world
|
||||||
|
|
||||||
|
00:00:02.000 --> 00:00:04.000
|
||||||
|
This is a test"""
|
||||||
|
|
||||||
|
result = WebVTTProcessor.extract_text(webvtt)
|
||||||
|
assert "Hello world" in result
|
||||||
|
assert "This is a test" in result
|
||||||
|
|
||||||
|
# Test empty input
|
||||||
|
assert WebVTTProcessor.extract_text("") == ""
|
||||||
|
assert WebVTTProcessor.extract_text(None) == ""
|
||||||
|
|
||||||
|
def test_generate_webvtt_snippets(self):
|
||||||
|
"""Test generating snippets from WebVTT content."""
|
||||||
|
webvtt = """WEBVTT
|
||||||
|
|
||||||
|
00:00:00.000 --> 00:00:02.000
|
||||||
|
Python programming is great
|
||||||
|
|
||||||
|
00:00:02.000 --> 00:00:04.000
|
||||||
|
Learn Python today"""
|
||||||
|
|
||||||
|
snippets = WebVTTProcessor.generate_snippets(webvtt, "Python")
|
||||||
|
assert len(snippets) > 0
|
||||||
|
assert any("Python" in s for s in snippets)
|
||||||
|
|
||||||
|
snippets = WebVTTProcessor.generate_snippets("", "Python")
|
||||||
|
assert snippets == []
|
||||||
|
|
||||||
|
def test_from_summary(self):
|
||||||
|
"""Test generating snippets from summary text."""
|
||||||
|
summary = "This meeting discussed Python development and machine learning applications."
|
||||||
|
|
||||||
|
snippets = SnippetGenerator.from_summary(summary, "Python")
|
||||||
|
assert len(snippets) > 0
|
||||||
|
assert any("Python" in s for s in snippets)
|
||||||
|
|
||||||
|
long_summary = "Python " * 20
|
||||||
|
snippets = SnippetGenerator.from_summary(long_summary, "Python")
|
||||||
|
assert len(snippets) <= 2
|
||||||
|
|
||||||
|
def test_combine_sources(self):
|
||||||
|
"""Test combining snippets from multiple sources."""
|
||||||
|
summary = "Python is a great programming language."
|
||||||
|
webvtt = """WEBVTT
|
||||||
|
|
||||||
|
00:00:00.000 --> 00:00:02.000
|
||||||
|
Learn Python programming
|
||||||
|
|
||||||
|
00:00:02.000 --> 00:00:04.000
|
||||||
|
Python is powerful"""
|
||||||
|
|
||||||
|
snippets, total_count = SnippetGenerator.combine_sources(
|
||||||
|
summary, webvtt, "Python", max_total=3
|
||||||
|
)
|
||||||
|
assert len(snippets) <= 3
|
||||||
|
assert len(snippets) > 0
|
||||||
|
assert total_count > 0
|
||||||
|
|
||||||
|
snippets, total_count = SnippetGenerator.combine_sources(
|
||||||
|
summary, None, "Python", max_total=3
|
||||||
|
)
|
||||||
|
assert len(snippets) > 0
|
||||||
|
assert all("Python" in s for s in snippets)
|
||||||
|
assert total_count == 1
|
||||||
|
|
||||||
|
snippets, total_count = SnippetGenerator.combine_sources(
|
||||||
|
None, webvtt, "Python", max_total=3
|
||||||
|
)
|
||||||
|
assert len(snippets) > 0
|
||||||
|
assert total_count == 2
|
||||||
|
|
||||||
|
long_summary = "Python " * 10
|
||||||
|
snippets, total_count = SnippetGenerator.combine_sources(
|
||||||
|
long_summary, webvtt, "Python", max_total=2
|
||||||
|
)
|
||||||
|
assert len(snippets) == 2
|
||||||
|
assert total_count >= 10
|
||||||
|
|
||||||
|
def test_match_counting_sum_logic(self):
|
||||||
|
"""Test that match counting correctly sums matches from both sources."""
|
||||||
|
summary = "data science uses data analysis and data mining techniques"
|
||||||
|
webvtt = """WEBVTT
|
||||||
|
|
||||||
|
00:00:00.000 --> 00:00:02.000
|
||||||
|
Big data processing
|
||||||
|
|
||||||
|
00:00:02.000 --> 00:00:04.000
|
||||||
|
data visualization and data storage"""
|
||||||
|
|
||||||
|
snippets, total_count = SnippetGenerator.combine_sources(
|
||||||
|
summary, webvtt, "data", max_total=3
|
||||||
|
)
|
||||||
|
assert total_count == 6
|
||||||
|
assert len(snippets) <= 3
|
||||||
|
|
||||||
|
summary_snippets, summary_count = SnippetGenerator.combine_sources(
|
||||||
|
summary, None, "data", max_total=3
|
||||||
|
)
|
||||||
|
assert summary_count == 3
|
||||||
|
|
||||||
|
webvtt_snippets, webvtt_count = SnippetGenerator.combine_sources(
|
||||||
|
None, webvtt, "data", max_total=3
|
||||||
|
)
|
||||||
|
assert webvtt_count == 3
|
||||||
|
|
||||||
|
snippets_empty, count_empty = SnippetGenerator.combine_sources(
|
||||||
|
None, None, "data", max_total=3
|
||||||
|
)
|
||||||
|
assert snippets_empty == []
|
||||||
|
assert count_empty == 0
|
||||||
|
|
||||||
|
def test_edge_cases(self):
|
||||||
|
"""Test edge cases for the pure functions."""
|
||||||
|
text = "Test with special: @#$%^&*() characters"
|
||||||
|
snippets = SnippetGenerator.generate(text, "@#$%")
|
||||||
|
assert len(snippets) > 0
|
||||||
|
|
||||||
|
long_query = "a" * 100
|
||||||
|
snippets = SnippetGenerator.generate("Some text", long_query)
|
||||||
|
assert snippets == []
|
||||||
|
|
||||||
|
text = "Unicode test: café, naïve, 日本語"
|
||||||
|
snippets = SnippetGenerator.generate(text, "café")
|
||||||
|
assert len(snippets) > 0
|
||||||
|
assert "café" in snippets[0]
|
||||||
|
|||||||
873
server/uv.lock
generated
873
server/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,26 +1,67 @@
|
|||||||
import React from "react";
|
import React, { useEffect } from "react";
|
||||||
import { Pagination, IconButton, ButtonGroup } from "@chakra-ui/react";
|
import { Pagination, IconButton, ButtonGroup } from "@chakra-ui/react";
|
||||||
import { LuChevronLeft, LuChevronRight } from "react-icons/lu";
|
import { LuChevronLeft, LuChevronRight } from "react-icons/lu";
|
||||||
|
|
||||||
|
// explicitly 1-based to prevent +/-1-confusion errors
|
||||||
|
export const FIRST_PAGE = 1 as PaginationPage;
|
||||||
|
export const parsePaginationPage = (
|
||||||
|
page: number,
|
||||||
|
):
|
||||||
|
| {
|
||||||
|
value: PaginationPage;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
error: string;
|
||||||
|
} => {
|
||||||
|
if (page < FIRST_PAGE)
|
||||||
|
return {
|
||||||
|
error: "Page must be greater than 0",
|
||||||
|
};
|
||||||
|
if (!Number.isInteger(page))
|
||||||
|
return {
|
||||||
|
error: "Page must be an integer",
|
||||||
|
};
|
||||||
|
return {
|
||||||
|
value: page as PaginationPage,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
export type PaginationPage = number & { __brand: "PaginationPage" };
|
||||||
|
export const PaginationPage = (page: number): PaginationPage => {
|
||||||
|
const v = parsePaginationPage(page);
|
||||||
|
if ("error" in v) throw new Error(v.error);
|
||||||
|
return v.value;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const paginationPageTo0Based = (page: PaginationPage): number =>
|
||||||
|
page - FIRST_PAGE;
|
||||||
|
|
||||||
type PaginationProps = {
|
type PaginationProps = {
|
||||||
page: number;
|
page: PaginationPage;
|
||||||
setPage: (page: number) => void;
|
setPage: (page: PaginationPage) => void;
|
||||||
total: number;
|
total: number;
|
||||||
size: number;
|
size: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const totalPages = (total: number, size: number) => {
|
||||||
|
return Math.ceil(total / size);
|
||||||
|
};
|
||||||
|
|
||||||
export default function PaginationComponent(props: PaginationProps) {
|
export default function PaginationComponent(props: PaginationProps) {
|
||||||
const { page, setPage, total, size } = props;
|
const { page, setPage, total, size } = props;
|
||||||
const totalPages = Math.ceil(total / size);
|
useEffect(() => {
|
||||||
|
if (page > totalPages(total, size)) {
|
||||||
if (totalPages <= 1) return null;
|
console.error(
|
||||||
|
`Page number (${page}) is greater than total pages (${totalPages}) in pagination`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}, [page, totalPages(total, size)]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Pagination.Root
|
<Pagination.Root
|
||||||
count={total}
|
count={total}
|
||||||
pageSize={size}
|
pageSize={size}
|
||||||
page={page}
|
page={page}
|
||||||
onPageChange={(details) => setPage(details.page)}
|
onPageChange={(details) => setPage(PaginationPage(details.page))}
|
||||||
style={{ display: "flex", justifyContent: "center" }}
|
style={{ display: "flex", justifyContent: "center" }}
|
||||||
>
|
>
|
||||||
<ButtonGroup variant="ghost" size="xs">
|
<ButtonGroup variant="ghost" size="xs">
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
import React, { useState } from "react";
|
|
||||||
import { Flex, Input, Button } from "@chakra-ui/react";
|
|
||||||
|
|
||||||
interface SearchBarProps {
|
|
||||||
onSearch: (searchTerm: string) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export default function SearchBar({ onSearch }: SearchBarProps) {
|
|
||||||
const [searchInputValue, setSearchInputValue] = useState("");
|
|
||||||
|
|
||||||
const handleSearch = () => {
|
|
||||||
onSearch(searchInputValue);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleKeyDown = (event: React.KeyboardEvent) => {
|
|
||||||
if (event.key === "Enter") {
|
|
||||||
handleSearch();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex alignItems="center">
|
|
||||||
<Input
|
|
||||||
placeholder="Search transcriptions..."
|
|
||||||
value={searchInputValue}
|
|
||||||
onChange={(e) => setSearchInputValue(e.target.value)}
|
|
||||||
onKeyDown={handleKeyDown}
|
|
||||||
/>
|
|
||||||
<Button ml={2} onClick={handleSearch}>
|
|
||||||
Search
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -4,8 +4,8 @@ import { LuMenu, LuTrash, LuRotateCw } from "react-icons/lu";
|
|||||||
|
|
||||||
interface TranscriptActionsMenuProps {
|
interface TranscriptActionsMenuProps {
|
||||||
transcriptId: string;
|
transcriptId: string;
|
||||||
onDelete: (transcriptId: string) => (e: any) => void;
|
onDelete: (transcriptId: string) => void;
|
||||||
onReprocess: (transcriptId: string) => (e: any) => void;
|
onReprocess: (transcriptId: string) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function TranscriptActionsMenu({
|
export default function TranscriptActionsMenu({
|
||||||
@@ -24,11 +24,17 @@ export default function TranscriptActionsMenu({
|
|||||||
<Menu.Content>
|
<Menu.Content>
|
||||||
<Menu.Item
|
<Menu.Item
|
||||||
value="reprocess"
|
value="reprocess"
|
||||||
onClick={(e) => onReprocess(transcriptId)(e)}
|
onClick={() => onReprocess(transcriptId)}
|
||||||
>
|
>
|
||||||
<LuRotateCw /> Reprocess
|
<LuRotateCw /> Reprocess
|
||||||
</Menu.Item>
|
</Menu.Item>
|
||||||
<Menu.Item value="delete" onClick={(e) => onDelete(transcriptId)(e)}>
|
<Menu.Item
|
||||||
|
value="delete"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
onDelete(transcriptId);
|
||||||
|
}}
|
||||||
|
>
|
||||||
<LuTrash /> Delete
|
<LuTrash /> Delete
|
||||||
</Menu.Item>
|
</Menu.Item>
|
||||||
</Menu.Content>
|
</Menu.Content>
|
||||||
|
|||||||
@@ -1,27 +1,290 @@
|
|||||||
import React from "react";
|
import React, { useState } from "react";
|
||||||
import { Box, Stack, Text, Flex, Link, Spinner } from "@chakra-ui/react";
|
import {
|
||||||
|
Box,
|
||||||
|
Stack,
|
||||||
|
Text,
|
||||||
|
Flex,
|
||||||
|
Link,
|
||||||
|
Spinner,
|
||||||
|
Badge,
|
||||||
|
HStack,
|
||||||
|
VStack,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
import NextLink from "next/link";
|
import NextLink from "next/link";
|
||||||
import { GetTranscriptMinimal } from "../../../api";
|
|
||||||
import { formatTimeMs, formatLocalDate } from "../../../lib/time";
|
import { formatTimeMs, formatLocalDate } from "../../../lib/time";
|
||||||
import TranscriptStatusIcon from "./TranscriptStatusIcon";
|
import TranscriptStatusIcon from "./TranscriptStatusIcon";
|
||||||
import TranscriptActionsMenu from "./TranscriptActionsMenu";
|
import TranscriptActionsMenu from "./TranscriptActionsMenu";
|
||||||
|
import {
|
||||||
|
highlightMatches,
|
||||||
|
generateTextFragment,
|
||||||
|
} from "../../../lib/textHighlight";
|
||||||
|
import { SearchResult } from "../../../api";
|
||||||
|
|
||||||
interface TranscriptCardsProps {
|
interface TranscriptCardsProps {
|
||||||
transcripts: GetTranscriptMinimal[];
|
results: SearchResult[];
|
||||||
onDelete: (transcriptId: string) => (e: any) => void;
|
query: string;
|
||||||
onReprocess: (transcriptId: string) => (e: any) => void;
|
isLoading?: boolean;
|
||||||
loading?: boolean;
|
onDelete: (transcriptId: string) => void;
|
||||||
|
onReprocess: (transcriptId: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
function highlightText(text: string, query: string): React.ReactNode {
|
||||||
|
if (!query) return text;
|
||||||
|
|
||||||
|
const matches = highlightMatches(text, query);
|
||||||
|
|
||||||
|
if (matches.length === 0) return text;
|
||||||
|
|
||||||
|
// Sort matches by index to process them in order
|
||||||
|
const sortedMatches = [...matches].sort((a, b) => a.index - b.index);
|
||||||
|
|
||||||
|
const parts: React.ReactNode[] = [];
|
||||||
|
let lastIndex = 0;
|
||||||
|
|
||||||
|
sortedMatches.forEach((match, i) => {
|
||||||
|
// Add text before the match
|
||||||
|
if (match.index > lastIndex) {
|
||||||
|
parts.push(
|
||||||
|
<Text as="span" key={`text-${i}`} display="inline">
|
||||||
|
{text.slice(lastIndex, match.index)}
|
||||||
|
</Text>,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the highlighted match
|
||||||
|
parts.push(
|
||||||
|
<Text
|
||||||
|
as="mark"
|
||||||
|
key={`match-${i}`}
|
||||||
|
bg="yellow.200"
|
||||||
|
px={0.5}
|
||||||
|
display="inline"
|
||||||
|
>
|
||||||
|
{match.match}
|
||||||
|
</Text>,
|
||||||
|
);
|
||||||
|
|
||||||
|
lastIndex = match.index + match.match.length;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add remaining text after last match
|
||||||
|
if (lastIndex < text.length) {
|
||||||
|
parts.push(
|
||||||
|
<Text as="span" key={`text-end`} display="inline">
|
||||||
|
{text.slice(lastIndex)}
|
||||||
|
</Text>,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts;
|
||||||
|
}
|
||||||
|
|
||||||
|
const transcriptHref = (
|
||||||
|
transcriptId: string,
|
||||||
|
mainSnippet: string,
|
||||||
|
query: string,
|
||||||
|
): `/transcripts/${string}` => {
|
||||||
|
const urlTextFragment = mainSnippet
|
||||||
|
? generateTextFragment(mainSnippet, query)
|
||||||
|
: null;
|
||||||
|
const urlTextFragmentWithHash = urlTextFragment
|
||||||
|
? `#${urlTextFragment.k}=${encodeURIComponent(urlTextFragment.v)}`
|
||||||
|
: "";
|
||||||
|
return `/transcripts/${transcriptId}${urlTextFragmentWithHash}`;
|
||||||
|
};
|
||||||
|
|
||||||
|
// note that it's strongly tied to search logic - in case you want to use it independently, refactor
|
||||||
|
function TranscriptCard({
|
||||||
|
result,
|
||||||
|
query,
|
||||||
|
onDelete,
|
||||||
|
onReprocess,
|
||||||
|
}: {
|
||||||
|
result: SearchResult;
|
||||||
|
query: string;
|
||||||
|
onDelete: (transcriptId: string) => void;
|
||||||
|
onReprocess: (transcriptId: string) => void;
|
||||||
|
}) {
|
||||||
|
const [isExpanded, setIsExpanded] = useState(false);
|
||||||
|
|
||||||
|
const mainSnippet = result.search_snippets[0];
|
||||||
|
const additionalSnippets = result.search_snippets.slice(1);
|
||||||
|
const totalMatches = result.total_match_count || 0;
|
||||||
|
const snippetsShown = result.search_snippets.length;
|
||||||
|
const remainingMatches = totalMatches - snippetsShown;
|
||||||
|
const hasAdditionalSnippets = additionalSnippets.length > 0;
|
||||||
|
const resultTitle = result.title || "Unnamed Transcript";
|
||||||
|
|
||||||
|
const formattedDuration = result.duration
|
||||||
|
? formatTimeMs(result.duration)
|
||||||
|
: "N/A";
|
||||||
|
const formattedDate = formatLocalDate(result.created_at);
|
||||||
|
const source =
|
||||||
|
result.source_kind === "room"
|
||||||
|
? result.room_name || result.room_id
|
||||||
|
: result.source_kind;
|
||||||
|
|
||||||
|
const handleExpandClick = (e: React.MouseEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
e.stopPropagation();
|
||||||
|
setIsExpanded(!isExpanded);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box borderWidth={1} p={4} borderRadius="md" fontSize="sm">
|
||||||
|
<Flex justify="space-between" alignItems="flex-start" gap="2">
|
||||||
|
<Box>
|
||||||
|
<TranscriptStatusIcon status={result.status} />
|
||||||
|
</Box>
|
||||||
|
<Box flex="1">
|
||||||
|
{/* Title with highlighting and text fragment for deep linking */}
|
||||||
|
<Link
|
||||||
|
as={NextLink}
|
||||||
|
href={transcriptHref(result.id, mainSnippet, query)}
|
||||||
|
fontWeight="600"
|
||||||
|
display="block"
|
||||||
|
mb={2}
|
||||||
|
>
|
||||||
|
{highlightText(resultTitle, query)}
|
||||||
|
</Link>
|
||||||
|
|
||||||
|
{/* Metadata - Horizontal on desktop, vertical on mobile */}
|
||||||
|
<Flex
|
||||||
|
direction={{ base: "column", md: "row" }}
|
||||||
|
gap={{ base: 1, md: 2 }}
|
||||||
|
fontSize="xs"
|
||||||
|
color="gray.600"
|
||||||
|
flexWrap="wrap"
|
||||||
|
align={{ base: "flex-start", md: "center" }}
|
||||||
|
>
|
||||||
|
<Flex align="center" gap={1}>
|
||||||
|
<Text fontWeight="medium" color="gray.500">
|
||||||
|
Source:
|
||||||
|
</Text>
|
||||||
|
<Text>{source}</Text>
|
||||||
|
</Flex>
|
||||||
|
<Text display={{ base: "none", md: "block" }} color="gray.400">
|
||||||
|
•
|
||||||
|
</Text>
|
||||||
|
<Flex align="center" gap={1}>
|
||||||
|
<Text fontWeight="medium" color="gray.500">
|
||||||
|
Date:
|
||||||
|
</Text>
|
||||||
|
<Text>{formattedDate}</Text>
|
||||||
|
</Flex>
|
||||||
|
<Text display={{ base: "none", md: "block" }} color="gray.400">
|
||||||
|
•
|
||||||
|
</Text>
|
||||||
|
<Flex align="center" gap={1}>
|
||||||
|
<Text fontWeight="medium" color="gray.500">
|
||||||
|
Duration:
|
||||||
|
</Text>
|
||||||
|
<Text>{formattedDuration}</Text>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
{/* Search Results Section - only show when searching */}
|
||||||
|
{mainSnippet && (
|
||||||
|
<>
|
||||||
|
{/* Main Snippet */}
|
||||||
|
<Box
|
||||||
|
mt={3}
|
||||||
|
p={2}
|
||||||
|
bg="gray.50"
|
||||||
|
borderLeft="2px solid"
|
||||||
|
borderLeftColor="blue.400"
|
||||||
|
borderRadius="sm"
|
||||||
|
fontSize="xs"
|
||||||
|
>
|
||||||
|
<Text color="gray.700">
|
||||||
|
{highlightText(mainSnippet, query)}
|
||||||
|
</Text>
|
||||||
|
</Box>
|
||||||
|
|
||||||
|
{hasAdditionalSnippets && (
|
||||||
|
<>
|
||||||
|
<Flex
|
||||||
|
mt={2}
|
||||||
|
p={2}
|
||||||
|
bg="blue.50"
|
||||||
|
borderRadius="sm"
|
||||||
|
cursor="pointer"
|
||||||
|
onClick={handleExpandClick}
|
||||||
|
_hover={{ bg: "blue.100" }}
|
||||||
|
align="center"
|
||||||
|
justify="space-between"
|
||||||
|
>
|
||||||
|
<HStack gap={2}>
|
||||||
|
<Badge
|
||||||
|
bg="blue.500"
|
||||||
|
color="white"
|
||||||
|
fontSize="xs"
|
||||||
|
px={2}
|
||||||
|
borderRadius="full"
|
||||||
|
>
|
||||||
|
{remainingMatches > 0
|
||||||
|
? `${additionalSnippets.length + remainingMatches}+`
|
||||||
|
: additionalSnippets.length}
|
||||||
|
</Badge>
|
||||||
|
<Text fontSize="xs" color="blue.600" fontWeight="medium">
|
||||||
|
more{" "}
|
||||||
|
{additionalSnippets.length + remainingMatches === 1
|
||||||
|
? "match"
|
||||||
|
: "matches"}
|
||||||
|
{remainingMatches > 0 &&
|
||||||
|
` (${additionalSnippets.length} shown)`}
|
||||||
|
</Text>
|
||||||
|
</HStack>
|
||||||
|
<Text fontSize="xs" color="blue.600">
|
||||||
|
{isExpanded ? "▲" : "▼"}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
{/* Additional Snippets */}
|
||||||
|
{isExpanded && (
|
||||||
|
<VStack align="stretch" gap={2} mt={2}>
|
||||||
|
{additionalSnippets.map((snippet, index) => (
|
||||||
|
<Box
|
||||||
|
key={index}
|
||||||
|
p={2}
|
||||||
|
bg="gray.50"
|
||||||
|
borderLeft="2px solid"
|
||||||
|
borderLeftColor="gray.300"
|
||||||
|
borderRadius="sm"
|
||||||
|
fontSize="xs"
|
||||||
|
>
|
||||||
|
<Text color="gray.700">
|
||||||
|
{highlightText(snippet, query)}
|
||||||
|
</Text>
|
||||||
|
</Box>
|
||||||
|
))}
|
||||||
|
</VStack>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
<TranscriptActionsMenu
|
||||||
|
transcriptId={result.id}
|
||||||
|
onDelete={onDelete}
|
||||||
|
onReprocess={onReprocess}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function TranscriptCards({
|
export default function TranscriptCards({
|
||||||
transcripts,
|
results,
|
||||||
|
query,
|
||||||
|
isLoading,
|
||||||
onDelete,
|
onDelete,
|
||||||
onReprocess,
|
onReprocess,
|
||||||
loading,
|
|
||||||
}: TranscriptCardsProps) {
|
}: TranscriptCardsProps) {
|
||||||
return (
|
return (
|
||||||
<Box display={{ base: "block", lg: "none" }} position="relative">
|
<Box position="relative">
|
||||||
{loading && (
|
{isLoading && (
|
||||||
<Flex
|
<Flex
|
||||||
position="absolute"
|
position="absolute"
|
||||||
top={0}
|
top={0}
|
||||||
@@ -37,48 +300,19 @@ export default function TranscriptCards({
|
|||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
<Box
|
<Box
|
||||||
opacity={loading ? 0.9 : 1}
|
opacity={isLoading ? 0.9 : 1}
|
||||||
pointerEvents={loading ? "none" : "auto"}
|
pointerEvents={isLoading ? "none" : "auto"}
|
||||||
transition="opacity 0.2s ease-in-out"
|
transition="opacity 0.2s ease-in-out"
|
||||||
>
|
>
|
||||||
<Stack gap={2}>
|
<Stack gap={3}>
|
||||||
{transcripts.map((item) => (
|
{results.map((result) => (
|
||||||
<Box
|
<TranscriptCard
|
||||||
key={item.id}
|
key={result.id}
|
||||||
borderWidth={1}
|
result={result}
|
||||||
p={4}
|
query={query}
|
||||||
borderRadius="md"
|
|
||||||
fontSize="sm"
|
|
||||||
>
|
|
||||||
<Flex justify="space-between" alignItems="flex-start" gap="2">
|
|
||||||
<Box>
|
|
||||||
<TranscriptStatusIcon status={item.status} />
|
|
||||||
</Box>
|
|
||||||
<Box flex="1">
|
|
||||||
<Link
|
|
||||||
as={NextLink}
|
|
||||||
href={`/transcripts/${item.id}`}
|
|
||||||
fontWeight="600"
|
|
||||||
display="block"
|
|
||||||
>
|
|
||||||
{item.title || "Unnamed Transcript"}
|
|
||||||
</Link>
|
|
||||||
<Text>
|
|
||||||
Source:{" "}
|
|
||||||
{item.source_kind === "room"
|
|
||||||
? item.room_name
|
|
||||||
: item.source_kind}
|
|
||||||
</Text>
|
|
||||||
<Text>Date: {formatLocalDate(item.created_at)}</Text>
|
|
||||||
<Text>Duration: {formatTimeMs(item.duration)}</Text>
|
|
||||||
</Box>
|
|
||||||
<TranscriptActionsMenu
|
|
||||||
transcriptId={item.id}
|
|
||||||
onDelete={onDelete}
|
onDelete={onDelete}
|
||||||
onReprocess={onReprocess}
|
onReprocess={onReprocess}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
|
||||||
</Box>
|
|
||||||
))}
|
))}
|
||||||
</Stack>
|
</Stack>
|
||||||
</Box>
|
</Box>
|
||||||
|
|||||||
@@ -1,99 +0,0 @@
|
|||||||
import React from "react";
|
|
||||||
import { Box, Table, Link, Flex, Spinner } from "@chakra-ui/react";
|
|
||||||
import NextLink from "next/link";
|
|
||||||
import { GetTranscriptMinimal } from "../../../api";
|
|
||||||
import { formatTimeMs, formatLocalDate } from "../../../lib/time";
|
|
||||||
import TranscriptStatusIcon from "./TranscriptStatusIcon";
|
|
||||||
import TranscriptActionsMenu from "./TranscriptActionsMenu";
|
|
||||||
|
|
||||||
interface TranscriptTableProps {
|
|
||||||
transcripts: GetTranscriptMinimal[];
|
|
||||||
onDelete: (transcriptId: string) => (e: any) => void;
|
|
||||||
onReprocess: (transcriptId: string) => (e: any) => void;
|
|
||||||
loading?: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
export default function TranscriptTable({
|
|
||||||
transcripts,
|
|
||||||
onDelete,
|
|
||||||
onReprocess,
|
|
||||||
loading,
|
|
||||||
}: TranscriptTableProps) {
|
|
||||||
return (
|
|
||||||
<Box display={{ base: "none", lg: "block" }} position="relative">
|
|
||||||
{loading && (
|
|
||||||
<Flex
|
|
||||||
position="absolute"
|
|
||||||
top={0}
|
|
||||||
left={0}
|
|
||||||
right={0}
|
|
||||||
bottom={0}
|
|
||||||
align="center"
|
|
||||||
justify="center"
|
|
||||||
>
|
|
||||||
<Spinner size="xl" color="gray.700" />
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
<Box
|
|
||||||
opacity={loading ? 0.9 : 1}
|
|
||||||
pointerEvents={loading ? "none" : "auto"}
|
|
||||||
transition="opacity 0.2s ease-in-out"
|
|
||||||
>
|
|
||||||
<Table.Root>
|
|
||||||
<Table.Header>
|
|
||||||
<Table.Row>
|
|
||||||
<Table.ColumnHeader
|
|
||||||
width="16px"
|
|
||||||
fontWeight="600"
|
|
||||||
></Table.ColumnHeader>
|
|
||||||
<Table.ColumnHeader width="400px" fontWeight="600">
|
|
||||||
Transcription Title
|
|
||||||
</Table.ColumnHeader>
|
|
||||||
<Table.ColumnHeader width="150px" fontWeight="600">
|
|
||||||
Source
|
|
||||||
</Table.ColumnHeader>
|
|
||||||
<Table.ColumnHeader width="200px" fontWeight="600">
|
|
||||||
Date
|
|
||||||
</Table.ColumnHeader>
|
|
||||||
<Table.ColumnHeader width="100px" fontWeight="600">
|
|
||||||
Duration
|
|
||||||
</Table.ColumnHeader>
|
|
||||||
<Table.ColumnHeader
|
|
||||||
width="50px"
|
|
||||||
fontWeight="600"
|
|
||||||
></Table.ColumnHeader>
|
|
||||||
</Table.Row>
|
|
||||||
</Table.Header>
|
|
||||||
<Table.Body>
|
|
||||||
{transcripts.map((item) => (
|
|
||||||
<Table.Row key={item.id}>
|
|
||||||
<Table.Cell>
|
|
||||||
<TranscriptStatusIcon status={item.status} />
|
|
||||||
</Table.Cell>
|
|
||||||
<Table.Cell>
|
|
||||||
<Link as={NextLink} href={`/transcripts/${item.id}`}>
|
|
||||||
{item.title || "Unnamed Transcript"}
|
|
||||||
</Link>
|
|
||||||
</Table.Cell>
|
|
||||||
<Table.Cell>
|
|
||||||
{item.source_kind === "room"
|
|
||||||
? item.room_name
|
|
||||||
: item.source_kind}
|
|
||||||
</Table.Cell>
|
|
||||||
<Table.Cell>{formatLocalDate(item.created_at)}</Table.Cell>
|
|
||||||
<Table.Cell>{formatTimeMs(item.duration)}</Table.Cell>
|
|
||||||
<Table.Cell>
|
|
||||||
<TranscriptActionsMenu
|
|
||||||
transcriptId={item.id}
|
|
||||||
onDelete={onDelete}
|
|
||||||
onReprocess={onReprocess}
|
|
||||||
/>
|
|
||||||
</Table.Cell>
|
|
||||||
</Table.Row>
|
|
||||||
))}
|
|
||||||
</Table.Body>
|
|
||||||
</Table.Root>
|
|
||||||
</Box>
|
|
||||||
</Box>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,33 +1,264 @@
|
|||||||
"use client";
|
"use client";
|
||||||
import React, { useState, useEffect } from "react";
|
import React, { useState, useEffect } from "react";
|
||||||
import { Flex, Spinner, Heading, Text, Link } from "@chakra-ui/react";
|
import {
|
||||||
import useTranscriptList from "../transcripts/useTranscriptList";
|
Flex,
|
||||||
|
Spinner,
|
||||||
|
Heading,
|
||||||
|
Text,
|
||||||
|
Link,
|
||||||
|
Box,
|
||||||
|
Stack,
|
||||||
|
Input,
|
||||||
|
Button,
|
||||||
|
IconButton,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import {
|
||||||
|
useQueryState,
|
||||||
|
parseAsString,
|
||||||
|
parseAsInteger,
|
||||||
|
parseAsStringLiteral,
|
||||||
|
} from "nuqs";
|
||||||
|
import { LuX } from "react-icons/lu";
|
||||||
|
import { useSearchTranscripts } from "../transcripts/useSearchTranscripts";
|
||||||
import useSessionUser from "../../lib/useSessionUser";
|
import useSessionUser from "../../lib/useSessionUser";
|
||||||
import { Room } from "../../api";
|
import { Room, SourceKind, SearchResult, $SourceKind } from "../../api";
|
||||||
import Pagination from "./_components/Pagination";
|
|
||||||
import useApi from "../../lib/useApi";
|
import useApi from "../../lib/useApi";
|
||||||
import { useError } from "../../(errors)/errorContext";
|
import { useError } from "../../(errors)/errorContext";
|
||||||
import { SourceKind } from "../../api";
|
|
||||||
import FilterSidebar from "./_components/FilterSidebar";
|
import FilterSidebar from "./_components/FilterSidebar";
|
||||||
import SearchBar from "./_components/SearchBar";
|
import Pagination, {
|
||||||
import TranscriptTable from "./_components/TranscriptTable";
|
FIRST_PAGE,
|
||||||
|
PaginationPage,
|
||||||
|
parsePaginationPage,
|
||||||
|
totalPages as getTotalPages,
|
||||||
|
} from "./_components/Pagination";
|
||||||
import TranscriptCards from "./_components/TranscriptCards";
|
import TranscriptCards from "./_components/TranscriptCards";
|
||||||
import DeleteTranscriptDialog from "./_components/DeleteTranscriptDialog";
|
import DeleteTranscriptDialog from "./_components/DeleteTranscriptDialog";
|
||||||
import { formatLocalDate } from "../../lib/time";
|
import { formatLocalDate } from "../../lib/time";
|
||||||
|
import { RECORD_A_MEETING_URL } from "../../api/urls";
|
||||||
|
|
||||||
|
const SEARCH_FORM_QUERY_INPUT_NAME = "query" as const;
|
||||||
|
|
||||||
|
const usePrefetchRooms = (setRooms: (rooms: Room[]) => void): void => {
|
||||||
|
const { setError } = useError();
|
||||||
|
const api = useApi();
|
||||||
|
useEffect(() => {
|
||||||
|
if (!api) return;
|
||||||
|
api
|
||||||
|
.v1RoomsList({ page: 1 })
|
||||||
|
.then((rooms) => setRooms(rooms.items))
|
||||||
|
.catch((err) => setError(err, "There was an error fetching the rooms"));
|
||||||
|
}, [api, setError]);
|
||||||
|
};
|
||||||
|
|
||||||
|
const SearchForm: React.FC<{
|
||||||
|
setPage: (page: PaginationPage) => void;
|
||||||
|
sourceKind: SourceKind | null;
|
||||||
|
roomId: string | null;
|
||||||
|
setSourceKind: (sourceKind: SourceKind | null) => void;
|
||||||
|
setRoomId: (roomId: string | null) => void;
|
||||||
|
rooms: Room[];
|
||||||
|
searchQuery: string | null;
|
||||||
|
setSearchQuery: (query: string | null) => void;
|
||||||
|
}> = ({
|
||||||
|
setPage,
|
||||||
|
sourceKind,
|
||||||
|
roomId,
|
||||||
|
setRoomId,
|
||||||
|
setSourceKind,
|
||||||
|
rooms,
|
||||||
|
searchQuery,
|
||||||
|
setSearchQuery,
|
||||||
|
}) => {
|
||||||
|
// to keep the search input controllable + more fine grained control (urlSearchQuery is updated on submits)
|
||||||
|
const [searchInputValue, setSearchInputValue] = useState(searchQuery || "");
|
||||||
|
const handleSearchQuerySubmit = async (d: FormData) => {
|
||||||
|
await setSearchQuery((d.get(SEARCH_FORM_QUERY_INPUT_NAME) as string) || "");
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleClearSearch = () => {
|
||||||
|
setSearchInputValue("");
|
||||||
|
setSearchQuery(null);
|
||||||
|
setPage(FIRST_PAGE);
|
||||||
|
};
|
||||||
|
return (
|
||||||
|
<Stack gap={2}>
|
||||||
|
<form action={handleSearchQuerySubmit}>
|
||||||
|
<Flex alignItems="center">
|
||||||
|
<Box position="relative" flex="1">
|
||||||
|
<Input
|
||||||
|
placeholder="Search transcriptions..."
|
||||||
|
value={searchInputValue}
|
||||||
|
onChange={(e) => setSearchInputValue(e.target.value)}
|
||||||
|
name={SEARCH_FORM_QUERY_INPUT_NAME}
|
||||||
|
pr={searchQuery ? "2.5rem" : undefined}
|
||||||
|
/>
|
||||||
|
{searchQuery && (
|
||||||
|
<IconButton
|
||||||
|
aria-label="Clear search"
|
||||||
|
size="sm"
|
||||||
|
variant="ghost"
|
||||||
|
onClick={handleClearSearch}
|
||||||
|
position="absolute"
|
||||||
|
right="0.25rem"
|
||||||
|
top="50%"
|
||||||
|
transform="translateY(-50%)"
|
||||||
|
_hover={{ bg: "gray.100" }}
|
||||||
|
>
|
||||||
|
<LuX />
|
||||||
|
</IconButton>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
<Button ml={2} type="submit">
|
||||||
|
Search
|
||||||
|
</Button>
|
||||||
|
</Flex>
|
||||||
|
</form>
|
||||||
|
<UnderSearchFormFilterIndicators
|
||||||
|
sourceKind={sourceKind}
|
||||||
|
roomId={roomId}
|
||||||
|
setSourceKind={setSourceKind}
|
||||||
|
setRoomId={setRoomId}
|
||||||
|
rooms={rooms}
|
||||||
|
/>
|
||||||
|
</Stack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const UnderSearchFormFilterIndicators: React.FC<{
|
||||||
|
sourceKind: SourceKind | null;
|
||||||
|
roomId: string | null;
|
||||||
|
setSourceKind: (sourceKind: SourceKind | null) => void;
|
||||||
|
setRoomId: (roomId: string | null) => void;
|
||||||
|
rooms: Room[];
|
||||||
|
}> = ({ sourceKind, roomId, setRoomId, setSourceKind, rooms }) => {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{(sourceKind || roomId) && (
|
||||||
|
<Flex gap={2} flexWrap="wrap" align="center">
|
||||||
|
<Text fontSize="sm" color="gray.600">
|
||||||
|
Active filters:
|
||||||
|
</Text>
|
||||||
|
{sourceKind && (
|
||||||
|
<Flex
|
||||||
|
align="center"
|
||||||
|
px={2}
|
||||||
|
py={1}
|
||||||
|
bg="blue.100"
|
||||||
|
borderRadius="md"
|
||||||
|
fontSize="xs"
|
||||||
|
gap={1}
|
||||||
|
>
|
||||||
|
<Text>
|
||||||
|
{roomId
|
||||||
|
? `Room: ${
|
||||||
|
rooms.find((r) => r.id === roomId)?.name || roomId
|
||||||
|
}`
|
||||||
|
: `Source: ${sourceKind}`}
|
||||||
|
</Text>
|
||||||
|
<Button
|
||||||
|
size="xs"
|
||||||
|
variant="ghost"
|
||||||
|
minW="auto"
|
||||||
|
h="auto"
|
||||||
|
p="1px"
|
||||||
|
onClick={() => {
|
||||||
|
setSourceKind(null);
|
||||||
|
// TODO questionable
|
||||||
|
setRoomId(null);
|
||||||
|
}}
|
||||||
|
_hover={{ bg: "blue.200" }}
|
||||||
|
aria-label="Clear filter"
|
||||||
|
>
|
||||||
|
<LuX size={14} />
|
||||||
|
</Button>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const EmptyResult: React.FC<{
|
||||||
|
searchQuery: string;
|
||||||
|
}> = ({ searchQuery }) => {
|
||||||
|
return (
|
||||||
|
<Flex flexDir="column" alignItems="center" justifyContent="center" py={8}>
|
||||||
|
<Text textAlign="center">
|
||||||
|
{searchQuery
|
||||||
|
? `No results found for "${searchQuery}". Try adjusting your search terms.`
|
||||||
|
: "No transcripts found, but you can "}
|
||||||
|
{!searchQuery && (
|
||||||
|
<>
|
||||||
|
<Link href={RECORD_A_MEETING_URL} color="blue.500">
|
||||||
|
record a meeting
|
||||||
|
</Link>
|
||||||
|
{" to get started."}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
export default function TranscriptBrowser() {
|
export default function TranscriptBrowser() {
|
||||||
const [selectedSourceKind, setSelectedSourceKind] =
|
const [urlSearchQuery, setUrlSearchQuery] = useQueryState(
|
||||||
useState<SourceKind | null>(null);
|
"q",
|
||||||
const [selectedRoomId, setSelectedRoomId] = useState("");
|
parseAsString.withDefault("").withOptions({ shallow: false }),
|
||||||
const [rooms, setRooms] = useState<Room[]>([]);
|
|
||||||
const [page, setPage] = useState(1);
|
|
||||||
const [searchTerm, setSearchTerm] = useState("");
|
|
||||||
const { loading, response, refetch } = useTranscriptList(
|
|
||||||
page,
|
|
||||||
selectedSourceKind,
|
|
||||||
selectedRoomId,
|
|
||||||
searchTerm,
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const [urlSourceKind, setUrlSourceKind] = useQueryState(
|
||||||
|
"source",
|
||||||
|
parseAsStringLiteral($SourceKind.enum).withOptions({
|
||||||
|
shallow: false,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
const [urlRoomId, setUrlRoomId] = useQueryState(
|
||||||
|
"room",
|
||||||
|
parseAsString.withDefault("").withOptions({ shallow: false }),
|
||||||
|
);
|
||||||
|
|
||||||
|
const [urlPage, setPage] = useQueryState(
|
||||||
|
"page",
|
||||||
|
parseAsInteger.withDefault(1).withOptions({ shallow: false }),
|
||||||
|
);
|
||||||
|
|
||||||
|
const [page, _setSafePage] = useState(FIRST_PAGE);
|
||||||
|
|
||||||
|
// safety net
|
||||||
|
useEffect(() => {
|
||||||
|
const maybePage = parsePaginationPage(urlPage);
|
||||||
|
if ("error" in maybePage) {
|
||||||
|
setPage(FIRST_PAGE).then(() => {
|
||||||
|
/*may be called n times we dont care*/
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
_setSafePage(maybePage.value);
|
||||||
|
}, [urlPage]);
|
||||||
|
|
||||||
|
const [rooms, setRooms] = useState<Room[]>([]);
|
||||||
|
|
||||||
|
const pageSize = 20;
|
||||||
|
const {
|
||||||
|
results,
|
||||||
|
totalCount: totalResults,
|
||||||
|
isLoading,
|
||||||
|
reload,
|
||||||
|
} = useSearchTranscripts(
|
||||||
|
urlSearchQuery,
|
||||||
|
{
|
||||||
|
roomIds: urlRoomId ? [urlRoomId] : null,
|
||||||
|
sourceKind: urlSourceKind,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
pageSize,
|
||||||
|
page,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
const totalPages = getTotalPages(totalResults, pageSize);
|
||||||
|
|
||||||
const userName = useSessionUser().name;
|
const userName = useSessionUser().name;
|
||||||
const [deletionLoading, setDeletionLoading] = useState(false);
|
const [deletionLoading, setDeletionLoading] = useState(false);
|
||||||
const api = useApi();
|
const api = useApi();
|
||||||
@@ -35,66 +266,18 @@ export default function TranscriptBrowser() {
|
|||||||
const cancelRef = React.useRef(null);
|
const cancelRef = React.useRef(null);
|
||||||
const [transcriptToDeleteId, setTranscriptToDeleteId] =
|
const [transcriptToDeleteId, setTranscriptToDeleteId] =
|
||||||
React.useState<string>();
|
React.useState<string>();
|
||||||
const [deletedItemIds, setDeletedItemIds] = React.useState<string[]>();
|
|
||||||
|
|
||||||
useEffect(() => {
|
usePrefetchRooms(setRooms);
|
||||||
setDeletedItemIds([]);
|
|
||||||
}, [page, response]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (!api) return;
|
|
||||||
api
|
|
||||||
.v1RoomsList({ page: 1 })
|
|
||||||
.then((rooms) => setRooms(rooms.items))
|
|
||||||
.catch((err) => setError(err, "There was an error fetching the rooms"));
|
|
||||||
}, [api]);
|
|
||||||
|
|
||||||
const handleFilterTranscripts = (
|
const handleFilterTranscripts = (
|
||||||
sourceKind: SourceKind | null,
|
sourceKind: SourceKind | null,
|
||||||
roomId: string,
|
roomId: string,
|
||||||
) => {
|
) => {
|
||||||
setSelectedSourceKind(sourceKind);
|
setUrlSourceKind(sourceKind);
|
||||||
setSelectedRoomId(roomId);
|
setUrlRoomId(roomId);
|
||||||
setPage(1);
|
setPage(1);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleSearch = (searchTerm: string) => {
|
|
||||||
setPage(1);
|
|
||||||
setSearchTerm(searchTerm);
|
|
||||||
setSelectedSourceKind(null);
|
|
||||||
setSelectedRoomId("");
|
|
||||||
};
|
|
||||||
|
|
||||||
if (loading && !response)
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
flexDir="column"
|
|
||||||
alignItems="center"
|
|
||||||
justifyContent="center"
|
|
||||||
h="100%"
|
|
||||||
>
|
|
||||||
<Spinner size="xl" />
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!loading && !response)
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
flexDir="column"
|
|
||||||
alignItems="center"
|
|
||||||
justifyContent="center"
|
|
||||||
h="100%"
|
|
||||||
>
|
|
||||||
<Text>
|
|
||||||
No transcripts found, but you can
|
|
||||||
<Link href="/transcripts/new" className="underline">
|
|
||||||
record a meeting
|
|
||||||
</Link>
|
|
||||||
to get started.
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
|
|
||||||
const onCloseDeletion = () => setTranscriptToDeleteId(undefined);
|
const onCloseDeletion = () => setTranscriptToDeleteId(undefined);
|
||||||
|
|
||||||
const confirmDeleteTranscript = (transcriptId: string) => {
|
const confirmDeleteTranscript = (transcriptId: string) => {
|
||||||
@@ -103,12 +286,9 @@ export default function TranscriptBrowser() {
|
|||||||
api
|
api
|
||||||
.v1TranscriptDelete({ transcriptId })
|
.v1TranscriptDelete({ transcriptId })
|
||||||
.then(() => {
|
.then(() => {
|
||||||
refetch();
|
|
||||||
setDeletionLoading(false);
|
setDeletionLoading(false);
|
||||||
onCloseDeletion();
|
onCloseDeletion();
|
||||||
setDeletedItemIds((prev) =>
|
reload();
|
||||||
prev ? [...prev, transcriptId] : [transcriptId],
|
|
||||||
);
|
|
||||||
})
|
})
|
||||||
.catch((err) => {
|
.catch((err) => {
|
||||||
setDeletionLoading(false);
|
setDeletionLoading(false);
|
||||||
@@ -116,17 +296,18 @@ export default function TranscriptBrowser() {
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleDeleteTranscript = (transcriptId: string) => (e: any) => {
|
const handleProcessTranscript = (transcriptId: string) => {
|
||||||
e?.stopPropagation?.();
|
if (!api) {
|
||||||
setTranscriptToDeleteId(transcriptId);
|
console.error("API not available on handleProcessTranscript");
|
||||||
};
|
return;
|
||||||
|
}
|
||||||
const handleProcessTranscript = (transcriptId) => (e) => {
|
|
||||||
if (api) {
|
|
||||||
api
|
api
|
||||||
.v1TranscriptProcess({ transcriptId })
|
.v1TranscriptProcess({ transcriptId })
|
||||||
.then((result) => {
|
.then((result) => {
|
||||||
const status = (result as any).status;
|
const status =
|
||||||
|
result && typeof result === "object" && "status" in result
|
||||||
|
? (result as { status: string }).status
|
||||||
|
: undefined;
|
||||||
if (status === "already running") {
|
if (status === "already running") {
|
||||||
setError(
|
setError(
|
||||||
new Error("Processing is already running, please wait"),
|
new Error("Processing is already running, please wait"),
|
||||||
@@ -137,21 +318,32 @@ export default function TranscriptBrowser() {
|
|||||||
.catch((err) => {
|
.catch((err) => {
|
||||||
setError(err, "There was an error processing the transcript");
|
setError(err, "There was an error processing the transcript");
|
||||||
});
|
});
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const transcriptToDelete = response?.items?.find(
|
const transcriptToDelete = results?.find(
|
||||||
(i) => i.id === transcriptToDeleteId,
|
(i) => i.id === transcriptToDeleteId,
|
||||||
);
|
);
|
||||||
const dialogTitle = transcriptToDelete?.title || "Unnamed Transcript";
|
const dialogTitle = transcriptToDelete?.title || "Unnamed Transcript";
|
||||||
const dialogDate = transcriptToDelete?.created_at
|
const dialogDate = transcriptToDelete?.created_at
|
||||||
? formatLocalDate(transcriptToDelete.created_at)
|
? formatLocalDate(transcriptToDelete.created_at)
|
||||||
: undefined;
|
: undefined;
|
||||||
const dialogSource = transcriptToDelete
|
const dialogSource =
|
||||||
? transcriptToDelete.source_kind === "room"
|
transcriptToDelete?.source_kind === "room" && transcriptToDelete?.room_id
|
||||||
? transcriptToDelete.room_name || undefined
|
? transcriptToDelete.room_name || transcriptToDelete.room_id
|
||||||
: transcriptToDelete.source_kind
|
: transcriptToDelete?.source_kind;
|
||||||
: undefined;
|
|
||||||
|
if (isLoading && results.length === 0) {
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
flexDir="column"
|
||||||
|
alignItems="center"
|
||||||
|
justifyContent="center"
|
||||||
|
h="100%"
|
||||||
|
>
|
||||||
|
<Spinner size="xl" />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
@@ -168,15 +360,15 @@ export default function TranscriptBrowser() {
|
|||||||
>
|
>
|
||||||
<Heading size="lg">
|
<Heading size="lg">
|
||||||
{userName ? `${userName}'s Transcriptions` : "Your Transcriptions"}{" "}
|
{userName ? `${userName}'s Transcriptions` : "Your Transcriptions"}{" "}
|
||||||
{loading || (deletionLoading && <Spinner size="sm" />)}
|
{(isLoading || deletionLoading) && <Spinner size="sm" />}
|
||||||
</Heading>
|
</Heading>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Flex flexDir={{ base: "column", md: "row" }}>
|
<Flex flexDir={{ base: "column", md: "row" }}>
|
||||||
<FilterSidebar
|
<FilterSidebar
|
||||||
rooms={rooms}
|
rooms={rooms}
|
||||||
selectedSourceKind={selectedSourceKind}
|
selectedSourceKind={urlSourceKind}
|
||||||
selectedRoomId={selectedRoomId}
|
selectedRoomId={urlRoomId}
|
||||||
onFilterChange={handleFilterTranscripts}
|
onFilterChange={handleFilterTranscripts}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
@@ -188,25 +380,37 @@ export default function TranscriptBrowser() {
|
|||||||
gap={4}
|
gap={4}
|
||||||
px={{ base: 0, md: 4 }}
|
px={{ base: 0, md: 4 }}
|
||||||
>
|
>
|
||||||
<SearchBar onSearch={handleSearch} />
|
<SearchForm
|
||||||
|
setPage={setPage}
|
||||||
|
sourceKind={urlSourceKind}
|
||||||
|
roomId={urlRoomId}
|
||||||
|
searchQuery={urlSearchQuery}
|
||||||
|
setSearchQuery={setUrlSearchQuery}
|
||||||
|
setSourceKind={setUrlSourceKind}
|
||||||
|
setRoomId={setUrlRoomId}
|
||||||
|
rooms={rooms}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{totalPages > 1 ? (
|
||||||
<Pagination
|
<Pagination
|
||||||
page={page}
|
page={page}
|
||||||
setPage={setPage}
|
setPage={setPage}
|
||||||
total={response?.total || 0}
|
total={totalResults}
|
||||||
size={response?.size || 0}
|
size={pageSize}
|
||||||
/>
|
|
||||||
<TranscriptTable
|
|
||||||
transcripts={response?.items || []}
|
|
||||||
onDelete={handleDeleteTranscript}
|
|
||||||
onReprocess={handleProcessTranscript}
|
|
||||||
loading={loading}
|
|
||||||
/>
|
/>
|
||||||
|
) : null}
|
||||||
|
|
||||||
<TranscriptCards
|
<TranscriptCards
|
||||||
transcripts={response?.items || []}
|
results={results}
|
||||||
onDelete={handleDeleteTranscript}
|
query={urlSearchQuery}
|
||||||
|
isLoading={isLoading}
|
||||||
|
onDelete={setTranscriptToDeleteId}
|
||||||
onReprocess={handleProcessTranscript}
|
onReprocess={handleProcessTranscript}
|
||||||
loading={loading}
|
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
{!isLoading && results.length === 0 && (
|
||||||
|
<EmptyResult searchQuery={urlSearchQuery} />
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import Image from "next/image";
|
|||||||
import About from "../(aboutAndPrivacy)/about";
|
import About from "../(aboutAndPrivacy)/about";
|
||||||
import Privacy from "../(aboutAndPrivacy)/privacy";
|
import Privacy from "../(aboutAndPrivacy)/privacy";
|
||||||
import UserInfo from "../(auth)/userInfo";
|
import UserInfo from "../(auth)/userInfo";
|
||||||
|
import { RECORD_A_MEETING_URL } from "../api/urls";
|
||||||
|
|
||||||
export default async function AppLayout({
|
export default async function AppLayout({
|
||||||
children,
|
children,
|
||||||
@@ -53,7 +54,7 @@ export default async function AppLayout({
|
|||||||
{/* Text link on the right */}
|
{/* Text link on the right */}
|
||||||
<Link
|
<Link
|
||||||
as={NextLink}
|
as={NextLink}
|
||||||
href="/transcripts/new"
|
href={RECORD_A_MEETING_URL}
|
||||||
className="font-light px-2"
|
className="font-light px-2"
|
||||||
>
|
>
|
||||||
Create
|
Create
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import useApi from "../../lib/useApi";
|
|||||||
import useRoomList from "./useRoomList";
|
import useRoomList from "./useRoomList";
|
||||||
import { ApiError, Room } from "../../api";
|
import { ApiError, Room } from "../../api";
|
||||||
import { RoomList } from "./_components/RoomList";
|
import { RoomList } from "./_components/RoomList";
|
||||||
|
import { PaginationPage } from "../browse/_components/Pagination";
|
||||||
|
|
||||||
interface SelectOption {
|
interface SelectOption {
|
||||||
label: string;
|
label: string;
|
||||||
@@ -75,8 +76,9 @@ export default function RoomsList() {
|
|||||||
const [isEditing, setIsEditing] = useState(false);
|
const [isEditing, setIsEditing] = useState(false);
|
||||||
const [editRoomId, setEditRoomId] = useState("");
|
const [editRoomId, setEditRoomId] = useState("");
|
||||||
const api = useApi();
|
const api = useApi();
|
||||||
|
// TODO seems to be no setPage calls
|
||||||
const [page, setPage] = useState<number>(1);
|
const [page, setPage] = useState<number>(1);
|
||||||
const { loading, response, refetch } = useRoomList(page);
|
const { loading, response, refetch } = useRoomList(PaginationPage(page));
|
||||||
const [streams, setStreams] = useState<Stream[]>([]);
|
const [streams, setStreams] = useState<Stream[]>([]);
|
||||||
const [topics, setTopics] = useState<Topic[]>([]);
|
const [topics, setTopics] = useState<Topic[]>([]);
|
||||||
const [nameError, setNameError] = useState("");
|
const [nameError, setNameError] = useState("");
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import { useEffect, useState } from "react";
|
|||||||
import { useError } from "../../(errors)/errorContext";
|
import { useError } from "../../(errors)/errorContext";
|
||||||
import useApi from "../../lib/useApi";
|
import useApi from "../../lib/useApi";
|
||||||
import { Page_Room_ } from "../../api";
|
import { Page_Room_ } from "../../api";
|
||||||
|
import { PaginationPage } from "../browse/_components/Pagination";
|
||||||
|
|
||||||
type RoomList = {
|
type RoomList = {
|
||||||
response: Page_Room_ | null;
|
response: Page_Room_ | null;
|
||||||
@@ -11,7 +12,7 @@ type RoomList = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
//always protected
|
//always protected
|
||||||
const useRoomList = (page: number): RoomList => {
|
const useRoomList = (page: PaginationPage): RoomList => {
|
||||||
const [response, setResponse] = useState<Page_Room_ | null>(null);
|
const [response, setResponse] = useState<Page_Room_ | null>(null);
|
||||||
const [loading, setLoading] = useState<boolean>(true);
|
const [loading, setLoading] = useState<boolean>(true);
|
||||||
const [error, setErrorState] = useState<Error | null>(null);
|
const [error, setErrorState] = useState<Error | null>(null);
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
[
|
|
||||||
{
|
|
||||||
"id": "27c07e49-d7a3-4b86-905c-f1a047366f91",
|
|
||||||
"title": "Issue one",
|
|
||||||
"summary": "The team discusses the first issue in the list",
|
|
||||||
"timestamp": 0.0,
|
|
||||||
"transcript": "",
|
|
||||||
"duration": 33,
|
|
||||||
"segments": [
|
|
||||||
{
|
|
||||||
"text": "Let's start with issue one, Alice you've been working on that, can you give an update ?",
|
|
||||||
"start": 0.0,
|
|
||||||
"speaker": 0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Yes, I've run into an issue with the task system but Bob helped me out and I have a POC ready, should I present it now ?",
|
|
||||||
"start": 0.38,
|
|
||||||
"speaker": 1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Yeah, I had to modify the task system because it didn't account for incoming blobs",
|
|
||||||
"start": 4.5,
|
|
||||||
"speaker": 2
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Cool, yeah lets see it",
|
|
||||||
"start": 5.96,
|
|
||||||
"speaker": 0
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
@@ -11,6 +11,7 @@ import useWebRTC from "./useWebRTC";
|
|||||||
import useAudioDevice from "./useAudioDevice";
|
import useAudioDevice from "./useAudioDevice";
|
||||||
import { Box, Flex, IconButton, Menu, RadioGroup } from "@chakra-ui/react";
|
import { Box, Flex, IconButton, Menu, RadioGroup } from "@chakra-ui/react";
|
||||||
import { LuScreenShare, LuMic, LuPlay, LuCircleStop } from "react-icons/lu";
|
import { LuScreenShare, LuMic, LuPlay, LuCircleStop } from "react-icons/lu";
|
||||||
|
import { RECORD_A_MEETING_URL } from "../../api/urls";
|
||||||
|
|
||||||
type RecorderProps = {
|
type RecorderProps = {
|
||||||
transcriptId: string;
|
transcriptId: string;
|
||||||
@@ -46,7 +47,7 @@ export default function Recorder(props: RecorderProps) {
|
|||||||
location.href = "";
|
location.href = "";
|
||||||
break;
|
break;
|
||||||
case ",":
|
case ",":
|
||||||
location.href = "/transcripts/new";
|
location.href = RECORD_A_MEETING_URL;
|
||||||
break;
|
break;
|
||||||
case "!":
|
case "!":
|
||||||
if (record.isRecording()) return;
|
if (record.isRecording()) return;
|
||||||
|
|||||||
123
www/app/(app)/transcripts/useSearchTranscripts.ts
Normal file
123
www/app/(app)/transcripts/useSearchTranscripts.ts
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
// this hook is not great, we want to substitute it with a proper state management solution that is also not re-invention
|
||||||
|
|
||||||
|
import { useEffect, useRef, useState } from "react";
|
||||||
|
import { SearchResult, SourceKind } from "../../api";
|
||||||
|
import useApi from "../../lib/useApi";
|
||||||
|
import {
|
||||||
|
PaginationPage,
|
||||||
|
paginationPageTo0Based,
|
||||||
|
} from "../browse/_components/Pagination";
|
||||||
|
|
||||||
|
interface SearchFilters {
|
||||||
|
roomIds: readonly string[] | null;
|
||||||
|
sourceKind: SourceKind | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const EMPTY_SEARCH_FILTERS: SearchFilters = {
|
||||||
|
roomIds: null,
|
||||||
|
sourceKind: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
type UseSearchTranscriptsOptions = {
|
||||||
|
pageSize: number;
|
||||||
|
page: PaginationPage;
|
||||||
|
};
|
||||||
|
|
||||||
|
interface UseSearchTranscriptsReturn {
|
||||||
|
results: SearchResult[];
|
||||||
|
totalCount: number;
|
||||||
|
isLoading: boolean;
|
||||||
|
error: unknown;
|
||||||
|
reload: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
function hashEffectFilters(filters: SearchFilters): string {
|
||||||
|
return JSON.stringify(filters);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useSearchTranscripts(
|
||||||
|
query: string = "",
|
||||||
|
filters: SearchFilters = EMPTY_SEARCH_FILTERS,
|
||||||
|
options: UseSearchTranscriptsOptions = {
|
||||||
|
pageSize: 20,
|
||||||
|
page: PaginationPage(1),
|
||||||
|
},
|
||||||
|
): UseSearchTranscriptsReturn {
|
||||||
|
const { pageSize, page } = options;
|
||||||
|
|
||||||
|
const [reloadCount, setReloadCount] = useState(0);
|
||||||
|
|
||||||
|
const api = useApi();
|
||||||
|
const abortControllerRef = useRef<AbortController>();
|
||||||
|
|
||||||
|
const [data, setData] = useState<{ results: SearchResult[]; total: number }>({
|
||||||
|
results: [],
|
||||||
|
total: 0,
|
||||||
|
});
|
||||||
|
const [error, setError] = useState<any>();
|
||||||
|
const [isLoading, setIsLoading] = useState(false);
|
||||||
|
|
||||||
|
const filterHash = hashEffectFilters(filters);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!api) {
|
||||||
|
setData({ results: [], total: 0 });
|
||||||
|
setError(undefined);
|
||||||
|
setIsLoading(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (abortControllerRef.current) {
|
||||||
|
abortControllerRef.current.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
const abortController = new AbortController();
|
||||||
|
abortControllerRef.current = abortController;
|
||||||
|
|
||||||
|
const performSearch = async () => {
|
||||||
|
setIsLoading(true);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await api.v1TranscriptsSearch({
|
||||||
|
q: query || "",
|
||||||
|
limit: pageSize,
|
||||||
|
offset: paginationPageTo0Based(page) * pageSize,
|
||||||
|
roomId: filters.roomIds?.[0],
|
||||||
|
sourceKind: filters.sourceKind || undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (abortController.signal.aborted) return;
|
||||||
|
setData(response);
|
||||||
|
setError(undefined);
|
||||||
|
} catch (err: unknown) {
|
||||||
|
if ((err as Error).name === "AbortError") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (abortController.signal.aborted) {
|
||||||
|
console.error("Aborted search but error", err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setError(err);
|
||||||
|
} finally {
|
||||||
|
if (!abortController.signal.aborted) {
|
||||||
|
setIsLoading(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
performSearch().then(() => {});
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
abortController.abort();
|
||||||
|
};
|
||||||
|
}, [api, query, page, filterHash, pageSize, reloadCount]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
results: data.results,
|
||||||
|
totalCount: data.total,
|
||||||
|
isLoading,
|
||||||
|
error,
|
||||||
|
reload: () => setReloadCount(reloadCount + 1),
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
import { useEffect, useState } from "react";
|
|
||||||
import { useError } from "../../(errors)/errorContext";
|
|
||||||
import useApi from "../../lib/useApi";
|
|
||||||
import { Page_GetTranscriptMinimal_, SourceKind } from "../../api";
|
|
||||||
|
|
||||||
type TranscriptList = {
|
|
||||||
response: Page_GetTranscriptMinimal_ | null;
|
|
||||||
loading: boolean;
|
|
||||||
error: Error | null;
|
|
||||||
refetch: () => void;
|
|
||||||
};
|
|
||||||
|
|
||||||
const useTranscriptList = (
|
|
||||||
page: number,
|
|
||||||
sourceKind: SourceKind | null,
|
|
||||||
roomId: string | null,
|
|
||||||
searchTerm: string | null,
|
|
||||||
): TranscriptList => {
|
|
||||||
const [response, setResponse] = useState<Page_GetTranscriptMinimal_ | null>(
|
|
||||||
null,
|
|
||||||
);
|
|
||||||
const [loading, setLoading] = useState<boolean>(true);
|
|
||||||
const [error, setErrorState] = useState<Error | null>(null);
|
|
||||||
const { setError } = useError();
|
|
||||||
const api = useApi();
|
|
||||||
const [refetchCount, setRefetchCount] = useState(0);
|
|
||||||
|
|
||||||
const refetch = () => {
|
|
||||||
setLoading(true);
|
|
||||||
setRefetchCount(refetchCount + 1);
|
|
||||||
};
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (!api) return;
|
|
||||||
setLoading(true);
|
|
||||||
api
|
|
||||||
.v1TranscriptsList({
|
|
||||||
page,
|
|
||||||
sourceKind,
|
|
||||||
roomId,
|
|
||||||
searchTerm,
|
|
||||||
size: 10,
|
|
||||||
})
|
|
||||||
.then((response) => {
|
|
||||||
setResponse(response);
|
|
||||||
setLoading(false);
|
|
||||||
})
|
|
||||||
.catch((err) => {
|
|
||||||
setResponse(null);
|
|
||||||
setLoading(false);
|
|
||||||
setError(err);
|
|
||||||
setErrorState(err);
|
|
||||||
});
|
|
||||||
}, [api, page, refetchCount, roomId, searchTerm, sourceKind]);
|
|
||||||
|
|
||||||
return { response, loading, error, refetch };
|
|
||||||
};
|
|
||||||
|
|
||||||
export default useTranscriptList;
|
|
||||||
@@ -1002,7 +1002,7 @@ export const $SearchResponse = {
|
|||||||
},
|
},
|
||||||
query: {
|
query: {
|
||||||
type: "string",
|
type: "string",
|
||||||
minLength: 1,
|
minLength: 0,
|
||||||
title: "Query",
|
title: "Query",
|
||||||
description: "Search query text",
|
description: "Search query text",
|
||||||
},
|
},
|
||||||
@@ -1065,6 +1065,20 @@ export const $SearchResult = {
|
|||||||
],
|
],
|
||||||
title: "Room Id",
|
title: "Room Id",
|
||||||
},
|
},
|
||||||
|
room_name: {
|
||||||
|
anyOf: [
|
||||||
|
{
|
||||||
|
type: "string",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "null",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
title: "Room Name",
|
||||||
|
},
|
||||||
|
source_kind: {
|
||||||
|
$ref: "#/components/schemas/SourceKind",
|
||||||
|
},
|
||||||
created_at: {
|
created_at: {
|
||||||
type: "string",
|
type: "string",
|
||||||
title: "Created At",
|
title: "Created At",
|
||||||
@@ -1101,10 +1115,18 @@ export const $SearchResult = {
|
|||||||
title: "Search Snippets",
|
title: "Search Snippets",
|
||||||
description: "Text snippets around search matches",
|
description: "Text snippets around search matches",
|
||||||
},
|
},
|
||||||
|
total_match_count: {
|
||||||
|
type: "integer",
|
||||||
|
minimum: 0,
|
||||||
|
title: "Total Match Count",
|
||||||
|
description: "Total number of matches found in the transcript",
|
||||||
|
default: 0,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
type: "object",
|
type: "object",
|
||||||
required: [
|
required: [
|
||||||
"id",
|
"id",
|
||||||
|
"source_kind",
|
||||||
"created_at",
|
"created_at",
|
||||||
"status",
|
"status",
|
||||||
"rank",
|
"rank",
|
||||||
|
|||||||
@@ -286,6 +286,7 @@ export class DefaultService {
|
|||||||
* @param data.limit Results per page
|
* @param data.limit Results per page
|
||||||
* @param data.offset Number of results to skip
|
* @param data.offset Number of results to skip
|
||||||
* @param data.roomId
|
* @param data.roomId
|
||||||
|
* @param data.sourceKind
|
||||||
* @returns SearchResponse Successful Response
|
* @returns SearchResponse Successful Response
|
||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
@@ -300,6 +301,7 @@ export class DefaultService {
|
|||||||
limit: data.limit,
|
limit: data.limit,
|
||||||
offset: data.offset,
|
offset: data.offset,
|
||||||
room_id: data.roomId,
|
room_id: data.roomId,
|
||||||
|
source_kind: data.sourceKind,
|
||||||
},
|
},
|
||||||
errors: {
|
errors: {
|
||||||
422: "Validation Error",
|
422: "Validation Error",
|
||||||
|
|||||||
@@ -209,6 +209,8 @@ export type SearchResult = {
|
|||||||
title?: string | null;
|
title?: string | null;
|
||||||
user_id?: string | null;
|
user_id?: string | null;
|
||||||
room_id?: string | null;
|
room_id?: string | null;
|
||||||
|
room_name?: string | null;
|
||||||
|
source_kind: SourceKind;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
status: string;
|
status: string;
|
||||||
rank: number;
|
rank: number;
|
||||||
@@ -220,6 +222,10 @@ export type SearchResult = {
|
|||||||
* Text snippets around search matches
|
* Text snippets around search matches
|
||||||
*/
|
*/
|
||||||
search_snippets: Array<string>;
|
search_snippets: Array<string>;
|
||||||
|
/**
|
||||||
|
* Total number of matches found in the transcript
|
||||||
|
*/
|
||||||
|
total_match_count?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type SourceKind = "room" | "live" | "file";
|
export type SourceKind = "room" | "live" | "file";
|
||||||
@@ -407,6 +413,7 @@ export type V1TranscriptsSearchData = {
|
|||||||
*/
|
*/
|
||||||
q: string;
|
q: string;
|
||||||
roomId?: string | null;
|
roomId?: string | null;
|
||||||
|
sourceKind?: SourceKind | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type V1TranscriptsSearchResponse = SearchResponse;
|
export type V1TranscriptsSearchResponse = SearchResponse;
|
||||||
|
|||||||
2
www/app/api/urls.ts
Normal file
2
www/app/api/urls.ts
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// TODO better connection with generated schema; it's duplication
|
||||||
|
export const RECORD_A_MEETING_URL = "/transcripts/new" as const;
|
||||||
62
www/app/lib/textHighlight.tsx
Normal file
62
www/app/lib/textHighlight.tsx
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
/**
|
||||||
|
* Text highlighting and text fragment generation utilities
|
||||||
|
* Used for search result highlighting and deep linking with Chrome Text Fragments
|
||||||
|
*/
|
||||||
|
|
||||||
|
import React from "react";
|
||||||
|
|
||||||
|
export interface HighlightResult {
|
||||||
|
text: string;
|
||||||
|
matches: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Escapes special regex characters in a string
|
||||||
|
*/
|
||||||
|
function escapeRegex(str: string): string {
|
||||||
|
return str.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
|
||||||
|
}
|
||||||
|
|
||||||
|
export const highlightMatches = (
|
||||||
|
text: string,
|
||||||
|
query: string,
|
||||||
|
): { match: string; index: number }[] => {
|
||||||
|
if (!query || !text) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const queryWords = query.trim().split(/\s+/);
|
||||||
|
|
||||||
|
const regex = new RegExp(
|
||||||
|
`(${queryWords.map((word) => escapeRegex(word)).join("|")})`,
|
||||||
|
"gi",
|
||||||
|
);
|
||||||
|
|
||||||
|
return Array.from(text.matchAll(regex)).map((result) => ({
|
||||||
|
match: result[0],
|
||||||
|
index: result.index!,
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
export function findFirstHighlight(text: string, query: string): string | null {
|
||||||
|
const matches = highlightMatches(text, query);
|
||||||
|
if (matches.length === 0) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return matches[0].match;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function generateTextFragment(
|
||||||
|
text: string,
|
||||||
|
query: string,
|
||||||
|
): {
|
||||||
|
k: ":~:text";
|
||||||
|
v: string;
|
||||||
|
} | null {
|
||||||
|
const firstMatch = findFirstHighlight(text, query);
|
||||||
|
if (!firstMatch) return null;
|
||||||
|
return {
|
||||||
|
k: ":~:text",
|
||||||
|
v: firstMatch,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -136,3 +136,10 @@ export function extractDomain(url) {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function assertExists<T>(value: T | null | undefined, err?: string): T {
|
||||||
|
if (value === null || value === undefined) {
|
||||||
|
throw new Error(`Assertion failed: ${err ?? "value is null or undefined"}`);
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"use client";
|
"use client";
|
||||||
import { redirect } from "next/navigation";
|
import { redirect } from "next/navigation";
|
||||||
|
import { RECORD_A_MEETING_URL } from "./api/urls";
|
||||||
|
|
||||||
export default function Index() {
|
export default function Index() {
|
||||||
redirect("/transcripts/new");
|
redirect(RECORD_A_MEETING_URL);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,14 +5,17 @@ import system from "./styles/theme";
|
|||||||
|
|
||||||
import { WherebyProvider } from "@whereby.com/browser-sdk/react";
|
import { WherebyProvider } from "@whereby.com/browser-sdk/react";
|
||||||
import { Toaster } from "./components/ui/toaster";
|
import { Toaster } from "./components/ui/toaster";
|
||||||
|
import { NuqsAdapter } from "nuqs/adapters/next/app";
|
||||||
|
|
||||||
export function Providers({ children }: { children: React.ReactNode }) {
|
export function Providers({ children }: { children: React.ReactNode }) {
|
||||||
return (
|
return (
|
||||||
|
<NuqsAdapter>
|
||||||
<ChakraProvider value={system}>
|
<ChakraProvider value={system}>
|
||||||
<WherebyProvider>
|
<WherebyProvider>
|
||||||
{children}
|
{children}
|
||||||
<Toaster />
|
<Toaster />
|
||||||
</WherebyProvider>
|
</WherebyProvider>
|
||||||
</ChakraProvider>
|
</ChakraProvider>
|
||||||
|
</NuqsAdapter>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@
|
|||||||
"next": "^14.2.30",
|
"next": "^14.2.30",
|
||||||
"next-auth": "^4.24.7",
|
"next-auth": "^4.24.7",
|
||||||
"next-themes": "^0.4.6",
|
"next-themes": "^0.4.6",
|
||||||
|
"nuqs": "^2.4.3",
|
||||||
"postcss": "8.4.31",
|
"postcss": "8.4.31",
|
||||||
"prop-types": "^15.8.1",
|
"prop-types": "^15.8.1",
|
||||||
"react": "^18.2.0",
|
"react": "^18.2.0",
|
||||||
|
|||||||
39
www/pnpm-lock.yaml
generated
39
www/pnpm-lock.yaml
generated
@@ -67,6 +67,9 @@ importers:
|
|||||||
next-themes:
|
next-themes:
|
||||||
specifier: ^0.4.6
|
specifier: ^0.4.6
|
||||||
version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||||
|
nuqs:
|
||||||
|
specifier: ^2.4.3
|
||||||
|
version: 2.4.3(next@14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0))(react@18.3.1)
|
||||||
postcss:
|
postcss:
|
||||||
specifier: 8.4.31
|
specifier: 8.4.31
|
||||||
version: 8.4.31
|
version: 8.4.31
|
||||||
@@ -5436,6 +5439,12 @@ packages:
|
|||||||
}
|
}
|
||||||
engines: { node: ">= 8" }
|
engines: { node: ">= 8" }
|
||||||
|
|
||||||
|
mitt@3.0.1:
|
||||||
|
resolution:
|
||||||
|
{
|
||||||
|
integrity: sha512-vKivATfr97l2/QBCYAkXYDbrIWPM2IIKEl7YPhjCvKlG3kE2gm+uBo6nEXK3M5/Ffh/FLpKExzOQ3JJoJGFKBw==,
|
||||||
|
}
|
||||||
|
|
||||||
mkdirp@0.5.6:
|
mkdirp@0.5.6:
|
||||||
resolution:
|
resolution:
|
||||||
{
|
{
|
||||||
@@ -5660,6 +5669,27 @@ packages:
|
|||||||
}
|
}
|
||||||
deprecated: This package is no longer supported.
|
deprecated: This package is no longer supported.
|
||||||
|
|
||||||
|
nuqs@2.4.3:
|
||||||
|
resolution:
|
||||||
|
{
|
||||||
|
integrity: sha512-BgtlYpvRwLYiJuWzxt34q2bXu/AIS66sLU1QePIMr2LWkb+XH0vKXdbLSgn9t6p7QKzwI7f38rX3Wl9llTXQ8Q==,
|
||||||
|
}
|
||||||
|
peerDependencies:
|
||||||
|
"@remix-run/react": ">=2"
|
||||||
|
next: ">=14.2.0"
|
||||||
|
react: ">=18.2.0 || ^19.0.0-0"
|
||||||
|
react-router: ^6 || ^7
|
||||||
|
react-router-dom: ^6 || ^7
|
||||||
|
peerDependenciesMeta:
|
||||||
|
"@remix-run/react":
|
||||||
|
optional: true
|
||||||
|
next:
|
||||||
|
optional: true
|
||||||
|
react-router:
|
||||||
|
optional: true
|
||||||
|
react-router-dom:
|
||||||
|
optional: true
|
||||||
|
|
||||||
nypm@0.5.4:
|
nypm@0.5.4:
|
||||||
resolution:
|
resolution:
|
||||||
{
|
{
|
||||||
@@ -11553,6 +11583,8 @@ snapshots:
|
|||||||
minipass: 3.3.6
|
minipass: 3.3.6
|
||||||
yallist: 4.0.0
|
yallist: 4.0.0
|
||||||
|
|
||||||
|
mitt@3.0.1: {}
|
||||||
|
|
||||||
mkdirp@0.5.6:
|
mkdirp@0.5.6:
|
||||||
dependencies:
|
dependencies:
|
||||||
minimist: 1.2.8
|
minimist: 1.2.8
|
||||||
@@ -11674,6 +11706,13 @@ snapshots:
|
|||||||
gauge: 3.0.2
|
gauge: 3.0.2
|
||||||
set-blocking: 2.0.0
|
set-blocking: 2.0.0
|
||||||
|
|
||||||
|
nuqs@2.4.3(next@14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0))(react@18.3.1):
|
||||||
|
dependencies:
|
||||||
|
mitt: 3.0.1
|
||||||
|
react: 18.3.1
|
||||||
|
optionalDependencies:
|
||||||
|
next: 14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0)
|
||||||
|
|
||||||
nypm@0.5.4:
|
nypm@0.5.4:
|
||||||
dependencies:
|
dependencies:
|
||||||
citty: 0.1.6
|
citty: 0.1.6
|
||||||
|
|||||||
Reference in New Issue
Block a user