mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-21 22:56:47 +00:00
feat: 3-mode selfhosted refactoring (--gpu, --cpu, --hosted) + audio token auth fallback (#896)
* fix: local processing instead of http server for cpu * add fallback token if service worker doesn't work * chore: rename processors to keep processor pattern up to date and allow other processors to be created and used with env vars
This commit is contained in:
committed by
GitHub
parent
4235ab4293
commit
a682846645
@@ -1,11 +1,12 @@
|
||||
# Self-hosted production Docker Compose — single file for everything.
|
||||
#
|
||||
# Usage: ./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy
|
||||
# or: docker compose -f docker-compose.selfhosted.yml --profile gpu [--profile ollama-gpu] [--profile garage] [--profile caddy] up -d
|
||||
# Usage: ./scripts/setup-selfhosted.sh <--gpu|--cpu|--hosted> [--ollama-gpu|--ollama-cpu] [--garage] [--caddy]
|
||||
# or: docker compose -f docker-compose.selfhosted.yml [--profile gpu] [--profile ollama-gpu] [--profile garage] [--profile caddy] up -d
|
||||
#
|
||||
# Specialized models (pick ONE — required):
|
||||
# --profile gpu NVIDIA GPU for transcription/diarization/translation
|
||||
# --profile cpu CPU-only for transcription/diarization/translation
|
||||
# ML processing modes (pick ONE — required):
|
||||
# --gpu NVIDIA GPU container for transcription/diarization/translation (profile: gpu)
|
||||
# --cpu In-process CPU processing on server/worker (no ML container needed)
|
||||
# --hosted Remote GPU service URL (no ML container needed)
|
||||
#
|
||||
# Local LLM (optional — for summarization/topics):
|
||||
# --profile ollama-gpu Local Ollama with NVIDIA GPU
|
||||
@@ -45,16 +46,9 @@ services:
|
||||
REDIS_HOST: redis
|
||||
CELERY_BROKER_URL: redis://redis:6379/1
|
||||
CELERY_RESULT_BACKEND: redis://redis:6379/1
|
||||
# Specialized models via gpu/cpu container (aliased as "transcription")
|
||||
TRANSCRIPT_BACKEND: modal
|
||||
TRANSCRIPT_URL: http://transcription:8000
|
||||
TRANSCRIPT_MODAL_API_KEY: selfhosted
|
||||
DIARIZATION_BACKEND: modal
|
||||
DIARIZATION_URL: http://transcription:8000
|
||||
TRANSLATION_BACKEND: modal
|
||||
TRANSLATE_URL: http://transcription:8000
|
||||
PADDING_BACKEND: modal
|
||||
PADDING_URL: http://transcription:8000
|
||||
# ML backend config comes from env_file (server/.env), set per-mode by setup script
|
||||
# HF_TOKEN needed for in-process pyannote diarization (--cpu mode)
|
||||
HF_TOKEN: ${HF_TOKEN:-}
|
||||
# WebRTC: fixed UDP port range for ICE candidates (mapped above)
|
||||
WEBRTC_PORT_RANGE: "51000-51100"
|
||||
depends_on:
|
||||
@@ -79,15 +73,8 @@ services:
|
||||
REDIS_HOST: redis
|
||||
CELERY_BROKER_URL: redis://redis:6379/1
|
||||
CELERY_RESULT_BACKEND: redis://redis:6379/1
|
||||
TRANSCRIPT_BACKEND: modal
|
||||
TRANSCRIPT_URL: http://transcription:8000
|
||||
TRANSCRIPT_MODAL_API_KEY: selfhosted
|
||||
DIARIZATION_BACKEND: modal
|
||||
DIARIZATION_URL: http://transcription:8000
|
||||
TRANSLATION_BACKEND: modal
|
||||
TRANSLATE_URL: http://transcription:8000
|
||||
PADDING_BACKEND: modal
|
||||
PADDING_URL: http://transcription:8000
|
||||
# ML backend config comes from env_file (server/.env), set per-mode by setup script
|
||||
HF_TOKEN: ${HF_TOKEN:-}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
@@ -165,7 +152,10 @@ services:
|
||||
|
||||
# ===========================================================
|
||||
# Specialized model containers (transcription, diarization, translation)
|
||||
# Both gpu and cpu get alias "transcription" so server config never changes.
|
||||
# Only the gpu profile is activated by the setup script (--gpu mode).
|
||||
# The cpu service definition is kept for manual/standalone use but is
|
||||
# NOT activated by --cpu mode (which uses in-process local backends).
|
||||
# Both services get alias "transcription" so server config never changes.
|
||||
# ===========================================================
|
||||
|
||||
gpu:
|
||||
|
||||
@@ -254,15 +254,15 @@ Reflector can run completely offline:
|
||||
Control where each step happens:
|
||||
|
||||
```yaml
|
||||
# All local processing
|
||||
TRANSCRIPT_BACKEND=local
|
||||
DIARIZATION_BACKEND=local
|
||||
TRANSLATION_BACKEND=local
|
||||
# All in-process processing
|
||||
TRANSCRIPT_BACKEND=whisper
|
||||
DIARIZATION_BACKEND=pyannote
|
||||
TRANSLATION_BACKEND=marian
|
||||
|
||||
# Hybrid approach
|
||||
TRANSCRIPT_BACKEND=modal # Fast GPU processing
|
||||
DIARIZATION_BACKEND=local # Sensitive speaker data
|
||||
TRANSLATION_BACKEND=modal # Non-sensitive translation
|
||||
TRANSCRIPT_BACKEND=modal # Fast GPU processing
|
||||
DIARIZATION_BACKEND=pyannote # Sensitive speaker data
|
||||
TRANSLATION_BACKEND=modal # Non-sensitive translation
|
||||
```
|
||||
|
||||
### Storage Options
|
||||
|
||||
@@ -53,9 +53,12 @@ cd reflector
|
||||
# Same but without a domain (self-signed cert, access via IP):
|
||||
./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy
|
||||
|
||||
# CPU-only (same, but slower):
|
||||
# CPU-only (in-process ML, no GPU container):
|
||||
./scripts/setup-selfhosted.sh --cpu --ollama-cpu --garage --caddy
|
||||
|
||||
# Remote GPU service (your own hosted GPU, no local ML container):
|
||||
./scripts/setup-selfhosted.sh --hosted --garage --caddy
|
||||
|
||||
# With password authentication (single admin user):
|
||||
./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy --password mysecretpass
|
||||
|
||||
@@ -65,14 +68,15 @@ cd reflector
|
||||
|
||||
That's it. The script generates env files, secrets, starts all containers, waits for health checks, and prints the URL.
|
||||
|
||||
## Specialized Models (Required)
|
||||
## ML Processing Modes (Required)
|
||||
|
||||
Pick `--gpu` or `--cpu`. This determines how **transcription, diarization, translation, and audio padding** run:
|
||||
Pick `--gpu`, `--cpu`, or `--hosted`. This determines how **transcription, diarization, translation, and audio padding** run:
|
||||
|
||||
| Flag | What it does | Requires |
|
||||
|------|-------------|----------|
|
||||
| `--gpu` | NVIDIA GPU acceleration for ML models | NVIDIA GPU + drivers + `nvidia-container-toolkit` |
|
||||
| `--cpu` | CPU-only (slower but works without GPU) | 8+ cores, 32GB+ RAM recommended |
|
||||
| `--gpu` | NVIDIA GPU container for ML models | NVIDIA GPU + drivers + `nvidia-container-toolkit` |
|
||||
| `--cpu` | In-process CPU processing on server/worker (no ML container) | 8+ cores, 16GB+ RAM (32GB recommended for large files) |
|
||||
| `--hosted` | Remote GPU service URL (no local ML container) | A running GPU service instance (e.g. `gpu/self_hosted/`) |
|
||||
|
||||
## Local LLM (Optional)
|
||||
|
||||
@@ -130,9 +134,11 @@ Browse all available models at https://ollama.com/library.
|
||||
|
||||
- **`--gpu --ollama-gpu`**: Best for servers with NVIDIA GPU. Fully self-contained, no external API keys needed.
|
||||
- **`--cpu --ollama-cpu`**: No GPU available but want everything self-contained. Slower but works.
|
||||
- **`--hosted --ollama-cpu`**: Remote GPU for ML, local CPU for LLM. Great when you have a separate GPU server.
|
||||
- **`--gpu --ollama-cpu`**: GPU for transcription, CPU for LLM. Saves GPU VRAM for ML models.
|
||||
- **`--gpu`**: Have NVIDIA GPU but prefer a cloud LLM (faster/better summaries with GPT-4, Claude, etc.).
|
||||
- **`--cpu`**: No GPU, prefer cloud LLM. Slowest transcription but best summary quality.
|
||||
- **`--hosted`**: Remote GPU, cloud LLM. No local ML at all.
|
||||
|
||||
## Other Optional Flags
|
||||
|
||||
@@ -160,7 +166,7 @@ Without `--caddy` or `--domain`, no ports are exposed. Point your own reverse pr
|
||||
4. **Generate `www/.env`** — Auto-detects server IP, sets URLs
|
||||
5. **Storage setup** — Either initializes Garage (bucket, keys, permissions) or prompts for external S3 credentials
|
||||
6. **Caddyfile** — Generates domain-specific (Let's Encrypt) or IP-specific (self-signed) configuration
|
||||
7. **Build & start** — Always builds GPU/CPU model image from source. With `--build`, also builds backend and frontend from source; otherwise pulls prebuilt images from the registry
|
||||
7. **Build & start** — For `--gpu`, builds the GPU model image from source. For `--cpu` and `--hosted`, no ML container is built. With `--build`, also builds backend and frontend from source; otherwise pulls prebuilt images from the registry
|
||||
8. **Auto-detects video platforms** — If `DAILY_API_KEY` is found in `server/.env`, generates `.env.hatchet` (dashboard URL/cookie config), starts Hatchet workflow engine, and generates an API token. If any video platform is configured, enables the Rooms feature
|
||||
9. **Health checks** — Waits for each service, pulls Ollama model if needed, warns about missing LLM config
|
||||
|
||||
@@ -181,7 +187,7 @@ Without `--caddy` or `--domain`, no ports are exposed. Point your own reverse pr
|
||||
| `ADMIN_PASSWORD_HASH` | PBKDF2 hash for password auth | *(unset)* |
|
||||
| `WEBRTC_HOST` | IP advertised in WebRTC ICE candidates | Auto-detected (server IP) |
|
||||
| `TRANSCRIPT_URL` | Specialized model endpoint | `http://transcription:8000` |
|
||||
| `PADDING_BACKEND` | Audio padding backend (`local` or `modal`) | `modal` (selfhosted), `local` (default) |
|
||||
| `PADDING_BACKEND` | Audio padding backend (`pyav` or `modal`) | `modal` (selfhosted `--gpu`/`--hosted`), `pyav` (selfhosted `--cpu`, default) |
|
||||
| `PADDING_URL` | Audio padding endpoint (when `PADDING_BACKEND=modal`) | `http://transcription:8000` |
|
||||
| `LLM_URL` | OpenAI-compatible LLM endpoint | Auto-set for Ollama modes |
|
||||
| `LLM_API_KEY` | LLM API key | `not-needed` for Ollama |
|
||||
@@ -604,10 +610,9 @@ The setup script is idempotent — it won't overwrite existing secrets or env va
|
||||
│ │ │
|
||||
v v v
|
||||
┌───────────┐ ┌─────────┐ ┌─────────┐
|
||||
│transcription│ │postgres │ │ redis │
|
||||
│(gpu/cpu) │ │ :5432 │ │ :6379 │
|
||||
│ :8000 │ └─────────┘ └─────────┘
|
||||
└───────────┘
|
||||
│ ML models │ │postgres │ │ redis │
|
||||
│ (varies) │ │ :5432 │ │ :6379 │
|
||||
└───────────┘ └─────────┘ └─────────┘
|
||||
│
|
||||
┌─────┴─────┐ ┌─────────┐
|
||||
│ ollama │ │ garage │
|
||||
@@ -622,6 +627,11 @@ The setup script is idempotent — it won't overwrite existing secrets or env va
|
||||
│ │ :8888 │──│ -cpu / -llm │ │
|
||||
│ └─────────┘ └───────────────┘ │
|
||||
└───────────────────────────────────┘
|
||||
|
||||
ML models box varies by mode:
|
||||
--gpu: Local GPU container (transcription:8000)
|
||||
--cpu: In-process on server/worker (no container)
|
||||
--hosted: Remote GPU service (user URL)
|
||||
```
|
||||
|
||||
All services communicate over Docker's internal network. Only Caddy (if enabled) exposes ports to the internet. Hatchet services are only started when `DAILY_API_KEY` is configured.
|
||||
|
||||
@@ -11,10 +11,11 @@ dependencies = [
|
||||
"faster-whisper>=1.1.0",
|
||||
"librosa==0.10.1",
|
||||
"numpy<2",
|
||||
"silero-vad==5.1.0",
|
||||
"silero-vad==5.1.2",
|
||||
"transformers>=4.35.0",
|
||||
"sentencepiece",
|
||||
"pyannote.audio==3.1.0",
|
||||
"pyannote.audio==3.4.0",
|
||||
"pytorch-lightning<2.6",
|
||||
"torchaudio>=2.3.0",
|
||||
"av>=13.1.0",
|
||||
]
|
||||
|
||||
18
gpu/self_hosted/uv.lock
generated
18
gpu/self_hosted/uv.lock
generated
@@ -1742,7 +1742,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pyannote-audio"
|
||||
version = "3.1.0"
|
||||
version = "3.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "asteroid-filterbanks" },
|
||||
@@ -1765,9 +1765,9 @@ dependencies = [
|
||||
{ name = "torchaudio" },
|
||||
{ name = "torchmetrics" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ad/55/7253267c35e2aa9188b1d86cba121eb5bdd91ed12d3194488625a008cae7/pyannote.audio-3.1.0.tar.gz", hash = "sha256:da04705443d3b74607e034d3ca88f8b572c7e9672dd9a4199cab65a0dbc33fad", size = 14812058, upload-time = "2023-11-16T12:26:38.939Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ec/1e/efe9619c38f1281ddf21640654d8ea9e3f67c459b76f78657b26d8557bbe/pyannote_audio-3.4.0.tar.gz", hash = "sha256:d523d883cb8d37cb6daf99f3ba83f9138bb193646ad71e6eae7deb89d8ddd642", size = 804850, upload-time = "2025-09-09T07:04:51.17Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a1/37/158859ce4c45b5ba2dca40b53b0c10d36f935b7f6d4e737298397167c8b1/pyannote.audio-3.1.0-py2.py3-none-any.whl", hash = "sha256:66ab485728c6e141760e80555cb7a083e7be824cd528cc79b9e6f7d6421a91ae", size = 208592, upload-time = "2023-11-16T12:26:36.726Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/13/620c6f711b723653092fd063bfee82a6af5ea3a4d3c42efc53ce623a7f4d/pyannote_audio-3.4.0-py2.py3-none-any.whl", hash = "sha256:36e38f058059f46da3478dda581cda53d9d85a21173a3e70bbdbc3ba93b5e1b7", size = 897789, upload-time = "2025-09-09T07:04:49.464Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2075,6 +2075,7 @@ dependencies = [
|
||||
{ name = "librosa" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pyannote-audio" },
|
||||
{ name = "pytorch-lightning" },
|
||||
{ name = "sentencepiece" },
|
||||
{ name = "silero-vad" },
|
||||
{ name = "torch" },
|
||||
@@ -2090,9 +2091,10 @@ requires-dist = [
|
||||
{ name = "faster-whisper", specifier = ">=1.1.0" },
|
||||
{ name = "librosa", specifier = "==0.10.1" },
|
||||
{ name = "numpy", specifier = "<2" },
|
||||
{ name = "pyannote-audio", specifier = "==3.1.0" },
|
||||
{ name = "pyannote-audio", specifier = "==3.4.0" },
|
||||
{ name = "pytorch-lightning", specifier = "<2.6" },
|
||||
{ name = "sentencepiece" },
|
||||
{ name = "silero-vad", specifier = "==5.1.0" },
|
||||
{ name = "silero-vad", specifier = "==5.1.2" },
|
||||
{ name = "torch", specifier = ">=2.3.0" },
|
||||
{ name = "torchaudio", specifier = ">=2.3.0" },
|
||||
{ name = "transformers", specifier = ">=4.35.0" },
|
||||
@@ -2472,16 +2474,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "silero-vad"
|
||||
version = "5.1"
|
||||
version = "5.1.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "onnxruntime" },
|
||||
{ name = "torch" },
|
||||
{ name = "torchaudio" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7c/5d/b912e45d21b8b61859a552554893222d2cdebfd0f9afa7e8ba69c7a3441a/silero_vad-5.1.tar.gz", hash = "sha256:c644275ba5df06cee596cc050ba0bd1e0f5237d1abfa44d58dd4618f6e77434d", size = 3996829, upload-time = "2024-07-09T13:19:24.181Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b1/b4/d0311b2e6220a11f8f4699f4a278cb088131573286cdfe804c87c7eb5123/silero_vad-5.1.2.tar.gz", hash = "sha256:c442971160026d2d7aa0ad83f0c7ee86c89797a65289fe625c8ea59fc6fb828d", size = 5098526, upload-time = "2024-10-09T09:50:47.019Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/be/0fdbc72030b93d6f55107490d5d2185ddf0dbabdc921f589649d3e92ccd5/silero_vad-5.1-py3-none-any.whl", hash = "sha256:ecb50b484f538f7a962ce5cd3c07120d9db7b9d5a0c5861ccafe459856f22c8f", size = 3939986, upload-time = "2024-07-09T13:19:21.383Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/f7/5ae11d13fbb733cd3bfd7ff1c3a3902e6f55437df4b72307c1f168146268/silero_vad-5.1.2-py3-none-any.whl", hash = "sha256:93b41953d7774b165407fda6b533c119c5803864e367d5034dc626c82cfdf661", size = 5026737, upload-time = "2024-10-09T09:50:44.355Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -4,11 +4,12 @@
|
||||
# Single script to configure and launch everything on one server.
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/setup-selfhosted.sh <--gpu|--cpu> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASSWORD] [--build]
|
||||
# ./scripts/setup-selfhosted.sh <--gpu|--cpu|--hosted> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASSWORD] [--build]
|
||||
#
|
||||
# Specialized models (pick ONE — required):
|
||||
# --gpu NVIDIA GPU for transcription/diarization/translation
|
||||
# --cpu CPU-only for transcription/diarization/translation (slower)
|
||||
# ML processing modes (pick ONE — required):
|
||||
# --gpu NVIDIA GPU container for transcription/diarization/translation
|
||||
# --cpu In-process CPU processing (no ML container, slower)
|
||||
# --hosted Remote GPU service URL (no ML container)
|
||||
#
|
||||
# Local LLM (optional — for summarization & topic detection):
|
||||
# --ollama-gpu Local Ollama with NVIDIA GPU acceleration
|
||||
@@ -29,6 +30,7 @@
|
||||
# ./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy
|
||||
# ./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy --domain reflector.example.com
|
||||
# ./scripts/setup-selfhosted.sh --cpu --ollama-cpu --garage --caddy
|
||||
# ./scripts/setup-selfhosted.sh --hosted --garage --caddy
|
||||
# ./scripts/setup-selfhosted.sh --gpu --ollama-gpu --llm-model mistral --garage --caddy
|
||||
# ./scripts/setup-selfhosted.sh --gpu --garage --caddy --password mysecretpass
|
||||
# ./scripts/setup-selfhosted.sh --gpu --garage --caddy
|
||||
@@ -183,11 +185,14 @@ for i in "${!ARGS[@]}"; do
|
||||
arg="${ARGS[$i]}"
|
||||
case "$arg" in
|
||||
--gpu)
|
||||
[[ -n "$MODEL_MODE" ]] && { err "Cannot combine --gpu and --cpu. Pick one."; exit 1; }
|
||||
[[ -n "$MODEL_MODE" ]] && { err "Cannot combine --gpu, --cpu, and --hosted. Pick one."; exit 1; }
|
||||
MODEL_MODE="gpu" ;;
|
||||
--cpu)
|
||||
[[ -n "$MODEL_MODE" ]] && { err "Cannot combine --gpu and --cpu. Pick one."; exit 1; }
|
||||
[[ -n "$MODEL_MODE" ]] && { err "Cannot combine --gpu, --cpu, and --hosted. Pick one."; exit 1; }
|
||||
MODEL_MODE="cpu" ;;
|
||||
--hosted)
|
||||
[[ -n "$MODEL_MODE" ]] && { err "Cannot combine --gpu, --cpu, and --hosted. Pick one."; exit 1; }
|
||||
MODEL_MODE="hosted" ;;
|
||||
--ollama-gpu)
|
||||
[[ -n "$OLLAMA_MODE" ]] && { err "Cannot combine --ollama-gpu and --ollama-cpu. Pick one."; exit 1; }
|
||||
OLLAMA_MODE="ollama-gpu" ;;
|
||||
@@ -224,20 +229,21 @@ for i in "${!ARGS[@]}"; do
|
||||
SKIP_NEXT=true ;;
|
||||
*)
|
||||
err "Unknown argument: $arg"
|
||||
err "Usage: $0 <--gpu|--cpu> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASS] [--build]"
|
||||
err "Usage: $0 <--gpu|--cpu|--hosted> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASS] [--build]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ -z "$MODEL_MODE" ]]; then
|
||||
err "No model mode specified. You must choose --gpu or --cpu."
|
||||
err "No model mode specified. You must choose --gpu, --cpu, or --hosted."
|
||||
err ""
|
||||
err "Usage: $0 <--gpu|--cpu> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASS] [--build]"
|
||||
err "Usage: $0 <--gpu|--cpu|--hosted> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASS] [--build]"
|
||||
err ""
|
||||
err "Specialized models (required):"
|
||||
err " --gpu NVIDIA GPU for transcription/diarization/translation"
|
||||
err " --cpu CPU-only (slower but works without GPU)"
|
||||
err "ML processing modes (required):"
|
||||
err " --gpu NVIDIA GPU container for transcription/diarization/translation"
|
||||
err " --cpu In-process CPU processing (no ML container, slower)"
|
||||
err " --hosted Remote GPU service URL (no ML container)"
|
||||
err ""
|
||||
err "Local LLM (optional):"
|
||||
err " --ollama-gpu Local Ollama with GPU (for summarization/topics)"
|
||||
@@ -255,7 +261,9 @@ if [[ -z "$MODEL_MODE" ]]; then
|
||||
fi
|
||||
|
||||
# Build profiles list — one profile per feature
|
||||
COMPOSE_PROFILES=("$MODEL_MODE")
|
||||
# Only --gpu needs a compose profile; --cpu and --hosted use in-process/remote backends
|
||||
COMPOSE_PROFILES=()
|
||||
[[ "$MODEL_MODE" == "gpu" ]] && COMPOSE_PROFILES+=("gpu")
|
||||
[[ -n "$OLLAMA_MODE" ]] && COMPOSE_PROFILES+=("$OLLAMA_MODE")
|
||||
[[ "$USE_GARAGE" == "true" ]] && COMPOSE_PROFILES+=("garage")
|
||||
[[ "$USE_CADDY" == "true" ]] && COMPOSE_PROFILES+=("caddy")
|
||||
@@ -422,43 +430,102 @@ step_server_env() {
|
||||
env_set "$SERVER_ENV" "WEBRTC_HOST" "$PRIMARY_IP"
|
||||
fi
|
||||
|
||||
# Specialized models (always via gpu/cpu container aliased as "transcription")
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_MODAL_API_KEY" "selfhosted"
|
||||
# Specialized models — backend configuration per mode
|
||||
env_set "$SERVER_ENV" "DIARIZATION_ENABLED" "true"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "TRANSLATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSLATE_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "PADDING_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "PADDING_URL" "http://transcription:8000"
|
||||
case "$MODEL_MODE" in
|
||||
gpu)
|
||||
# GPU container aliased as "transcription" on docker network
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_MODAL_API_KEY" "selfhosted"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "TRANSLATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSLATE_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "PADDING_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "PADDING_URL" "http://transcription:8000"
|
||||
ok "ML backends: GPU container (modal)"
|
||||
;;
|
||||
cpu)
|
||||
# In-process backends — no ML service container needed
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_BACKEND" "whisper"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_BACKEND" "pyannote"
|
||||
env_set "$SERVER_ENV" "TRANSLATION_BACKEND" "marian"
|
||||
env_set "$SERVER_ENV" "PADDING_BACKEND" "pyav"
|
||||
ok "ML backends: in-process CPU (whisper/pyannote/marian/pyav)"
|
||||
;;
|
||||
hosted)
|
||||
# Remote GPU service — user provides URL
|
||||
local gpu_url=""
|
||||
if env_has_key "$SERVER_ENV" "TRANSCRIPT_URL"; then
|
||||
gpu_url=$(env_get "$SERVER_ENV" "TRANSCRIPT_URL")
|
||||
fi
|
||||
if [[ -z "$gpu_url" ]] && [[ -t 0 ]]; then
|
||||
echo ""
|
||||
info "Enter the URL of your remote GPU service (e.g. https://gpu.example.com)"
|
||||
read -rp " GPU service URL: " gpu_url
|
||||
fi
|
||||
if [[ -z "$gpu_url" ]]; then
|
||||
err "GPU service URL required for --hosted mode."
|
||||
err "Set TRANSCRIPT_URL in server/.env or provide it interactively."
|
||||
exit 1
|
||||
fi
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_URL" "$gpu_url"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_URL" "$gpu_url"
|
||||
env_set "$SERVER_ENV" "TRANSLATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSLATE_URL" "$gpu_url"
|
||||
env_set "$SERVER_ENV" "PADDING_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "PADDING_URL" "$gpu_url"
|
||||
# API key for remote service
|
||||
local gpu_api_key=""
|
||||
if env_has_key "$SERVER_ENV" "TRANSCRIPT_MODAL_API_KEY"; then
|
||||
gpu_api_key=$(env_get "$SERVER_ENV" "TRANSCRIPT_MODAL_API_KEY")
|
||||
fi
|
||||
if [[ -z "$gpu_api_key" ]] && [[ -t 0 ]]; then
|
||||
read -rp " GPU service API key (or Enter to skip): " gpu_api_key
|
||||
fi
|
||||
if [[ -n "$gpu_api_key" ]]; then
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_MODAL_API_KEY" "$gpu_api_key"
|
||||
fi
|
||||
ok "ML backends: remote hosted ($gpu_url)"
|
||||
;;
|
||||
esac
|
||||
|
||||
# HuggingFace token for gated models (pyannote diarization)
|
||||
# Written to root .env so docker compose picks it up for gpu/cpu containers
|
||||
local root_env="$ROOT_DIR/.env"
|
||||
local current_hf_token="${HF_TOKEN:-}"
|
||||
if [[ -f "$root_env" ]] && env_has_key "$root_env" "HF_TOKEN"; then
|
||||
current_hf_token=$(env_get "$root_env" "HF_TOKEN")
|
||||
fi
|
||||
if [[ -z "$current_hf_token" ]]; then
|
||||
echo ""
|
||||
warn "HF_TOKEN not set. Diarization will use a public model fallback."
|
||||
warn "For best results, get a token at https://huggingface.co/settings/tokens"
|
||||
warn "and accept pyannote licenses at https://huggingface.co/pyannote/speaker-diarization-3.1"
|
||||
if [[ -t 0 ]]; then
|
||||
read -rp " HuggingFace token (or press Enter to skip): " current_hf_token
|
||||
# --gpu: written to root .env (docker compose passes to GPU container)
|
||||
# --cpu: written to both root .env and server/.env (in-process pyannote needs it)
|
||||
# --hosted: not needed (remote service handles its own auth)
|
||||
if [[ "$MODEL_MODE" != "hosted" ]]; then
|
||||
local root_env="$ROOT_DIR/.env"
|
||||
local current_hf_token="${HF_TOKEN:-}"
|
||||
if [[ -f "$root_env" ]] && env_has_key "$root_env" "HF_TOKEN"; then
|
||||
current_hf_token=$(env_get "$root_env" "HF_TOKEN")
|
||||
fi
|
||||
if [[ -z "$current_hf_token" ]]; then
|
||||
echo ""
|
||||
warn "HF_TOKEN not set. Diarization will use a public model fallback."
|
||||
warn "For best results, get a token at https://huggingface.co/settings/tokens"
|
||||
warn "and accept pyannote licenses at https://huggingface.co/pyannote/speaker-diarization-3.1"
|
||||
if [[ -t 0 ]]; then
|
||||
read -rp " HuggingFace token (or press Enter to skip): " current_hf_token
|
||||
fi
|
||||
fi
|
||||
if [[ -n "$current_hf_token" ]]; then
|
||||
touch "$root_env"
|
||||
env_set "$root_env" "HF_TOKEN" "$current_hf_token"
|
||||
export HF_TOKEN="$current_hf_token"
|
||||
# In CPU mode, server process needs HF_TOKEN directly
|
||||
if [[ "$MODEL_MODE" == "cpu" ]]; then
|
||||
env_set "$SERVER_ENV" "HF_TOKEN" "$current_hf_token"
|
||||
fi
|
||||
ok "HF_TOKEN configured"
|
||||
else
|
||||
touch "$root_env"
|
||||
env_set "$root_env" "HF_TOKEN" ""
|
||||
ok "HF_TOKEN skipped (using public model fallback)"
|
||||
fi
|
||||
fi
|
||||
if [[ -n "$current_hf_token" ]]; then
|
||||
touch "$root_env"
|
||||
env_set "$root_env" "HF_TOKEN" "$current_hf_token"
|
||||
export HF_TOKEN="$current_hf_token"
|
||||
ok "HF_TOKEN configured"
|
||||
else
|
||||
touch "$root_env"
|
||||
env_set "$root_env" "HF_TOKEN" ""
|
||||
ok "HF_TOKEN skipped (using public model fallback)"
|
||||
fi
|
||||
|
||||
# LLM configuration
|
||||
@@ -799,11 +866,12 @@ CADDYEOF
|
||||
step_services() {
|
||||
info "Step 6: Starting Docker services"
|
||||
|
||||
# Build GPU/CPU image from source (always needed — no prebuilt image)
|
||||
local build_svc="$MODEL_MODE"
|
||||
info "Building $build_svc image (first build downloads ML models, may take a while)..."
|
||||
compose_cmd build "$build_svc"
|
||||
ok "$build_svc image built"
|
||||
# Build GPU image from source (only for --gpu mode)
|
||||
if [[ "$MODEL_MODE" == "gpu" ]]; then
|
||||
info "Building gpu image (first build downloads ML models, may take a while)..."
|
||||
compose_cmd build gpu
|
||||
ok "gpu image built"
|
||||
fi
|
||||
|
||||
# Build or pull backend and frontend images
|
||||
if [[ "$BUILD_IMAGES" == "true" ]]; then
|
||||
@@ -871,25 +939,29 @@ step_services() {
|
||||
step_health() {
|
||||
info "Step 7: Health checks"
|
||||
|
||||
# Specialized model service (gpu or cpu)
|
||||
local model_svc="$MODEL_MODE"
|
||||
|
||||
info "Waiting for $model_svc service (first start downloads ~1GB of models)..."
|
||||
local model_ok=false
|
||||
for i in $(seq 1 120); do
|
||||
if curl -sf http://localhost:8000/docs > /dev/null 2>&1; then
|
||||
model_ok=true
|
||||
break
|
||||
# Specialized model service (only for --gpu mode)
|
||||
if [[ "$MODEL_MODE" == "gpu" ]]; then
|
||||
info "Waiting for gpu service (first start downloads ~1GB of models)..."
|
||||
local model_ok=false
|
||||
for i in $(seq 1 120); do
|
||||
if curl -sf http://localhost:8000/docs > /dev/null 2>&1; then
|
||||
model_ok=true
|
||||
break
|
||||
fi
|
||||
echo -ne "\r Waiting for gpu service... ($i/120)"
|
||||
sleep 5
|
||||
done
|
||||
echo ""
|
||||
if [[ "$model_ok" == "true" ]]; then
|
||||
ok "gpu service healthy (transcription + diarization)"
|
||||
else
|
||||
warn "gpu service not ready yet — it will keep loading in the background"
|
||||
warn "Check with: docker compose -f docker-compose.selfhosted.yml logs gpu"
|
||||
fi
|
||||
echo -ne "\r Waiting for $model_svc service... ($i/120)"
|
||||
sleep 5
|
||||
done
|
||||
echo ""
|
||||
if [[ "$model_ok" == "true" ]]; then
|
||||
ok "$model_svc service healthy (transcription + diarization)"
|
||||
else
|
||||
warn "$model_svc service not ready yet — it will keep loading in the background"
|
||||
warn "Check with: docker compose -f docker-compose.selfhosted.yml logs $model_svc"
|
||||
elif [[ "$MODEL_MODE" == "cpu" ]]; then
|
||||
ok "CPU mode — ML processing runs in-process on server/worker (no separate service)"
|
||||
elif [[ "$MODEL_MODE" == "hosted" ]]; then
|
||||
ok "Hosted mode — ML processing via remote GPU service (no local health check)"
|
||||
fi
|
||||
|
||||
# Ollama (if applicable)
|
||||
|
||||
@@ -89,11 +89,11 @@ LLM_CONTEXT_WINDOW=16000
|
||||
## =======================================================
|
||||
## Audio Padding
|
||||
##
|
||||
## backends: local (in-process PyAV), modal (HTTP API client)
|
||||
## Default is "local" — no external service needed.
|
||||
## backends: pyav (in-process PyAV), modal (HTTP API client)
|
||||
## Default is "pyav" — no external service needed.
|
||||
## Set to "modal" when using Modal.com or self-hosted gpu/self_hosted/ container.
|
||||
## =======================================================
|
||||
#PADDING_BACKEND=local
|
||||
#PADDING_BACKEND=pyav
|
||||
#PADDING_BACKEND=modal
|
||||
#PADDING_URL=https://xxxxx--reflector-padding-web.modal.run
|
||||
#PADDING_MODAL_API_KEY=xxxxx
|
||||
@@ -101,8 +101,8 @@ LLM_CONTEXT_WINDOW=16000
|
||||
## =======================================================
|
||||
## Diarization
|
||||
##
|
||||
## Only available on modal
|
||||
## To allow diarization, you need to expose expose the files to be dowloded by the pipeline
|
||||
## backends: modal (HTTP API), pyannote (in-process pyannote.audio)
|
||||
## To allow diarization, you need to expose the files to be downloaded by the pipeline
|
||||
## =======================================================
|
||||
DIARIZATION_ENABLED=false
|
||||
DIARIZATION_BACKEND=modal
|
||||
|
||||
@@ -32,26 +32,46 @@ AUTH_BACKEND=none
|
||||
|
||||
# =======================================================
|
||||
# Specialized Models (Transcription, Diarization, Translation)
|
||||
# These run in the gpu/cpu container — NOT an LLM.
|
||||
# The "modal" backend means "HTTP API client" — it talks to
|
||||
# the self-hosted container, not Modal.com cloud.
|
||||
# These do NOT use an LLM. Configured per mode by the setup script:
|
||||
#
|
||||
# --gpu mode: modal backends → GPU container (http://transcription:8000)
|
||||
# --cpu mode: whisper/pyannote/marian/pyav → in-process ML on server/worker
|
||||
# --hosted mode: modal backends → user-provided remote GPU service URL
|
||||
# =======================================================
|
||||
|
||||
# --- --gpu mode (default) ---
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=http://transcription:8000
|
||||
TRANSCRIPT_MODAL_API_KEY=selfhosted
|
||||
|
||||
DIARIZATION_ENABLED=true
|
||||
DIARIZATION_BACKEND=modal
|
||||
DIARIZATION_URL=http://transcription:8000
|
||||
|
||||
TRANSLATION_BACKEND=modal
|
||||
TRANSLATE_URL=http://transcription:8000
|
||||
|
||||
PADDING_BACKEND=modal
|
||||
PADDING_URL=http://transcription:8000
|
||||
|
||||
# HuggingFace token — optional, for gated models (e.g. pyannote).
|
||||
# Falls back to public S3 model bundle if not set.
|
||||
# --- --cpu mode (set by setup script) ---
|
||||
# TRANSCRIPT_BACKEND=whisper
|
||||
# DIARIZATION_BACKEND=pyannote
|
||||
# TRANSLATION_BACKEND=marian
|
||||
# PADDING_BACKEND=pyav
|
||||
|
||||
# --- --hosted mode (set by setup script) ---
|
||||
# TRANSCRIPT_BACKEND=modal
|
||||
# TRANSCRIPT_URL=https://your-gpu-service.example.com
|
||||
# DIARIZATION_BACKEND=modal
|
||||
# DIARIZATION_URL=https://your-gpu-service.example.com
|
||||
# ... (all URLs point to one remote service)
|
||||
|
||||
# Whisper model sizes for local transcription (--cpu mode)
|
||||
# Options: "tiny", "base", "small", "medium", "large-v2"
|
||||
# WHISPER_CHUNK_MODEL=tiny
|
||||
# WHISPER_FILE_MODEL=tiny
|
||||
|
||||
# HuggingFace token — for gated models (e.g. pyannote diarization).
|
||||
# Used by --gpu and --cpu modes (recommended); falls back to public S3 bundle if not set.
|
||||
# Not needed for --hosted mode (remote service handles its own auth).
|
||||
# HF_TOKEN=hf_xxxxx
|
||||
|
||||
# =======================================================
|
||||
|
||||
@@ -6,7 +6,7 @@ ENV PYTHONUNBUFFERED=1 \
|
||||
|
||||
# builder install base dependencies
|
||||
WORKDIR /tmp
|
||||
RUN apt-get update && apt-get install -y curl && apt-get clean
|
||||
RUN apt-get update && apt-get install -y curl ffmpeg && apt-get clean
|
||||
ADD https://astral.sh/uv/install.sh /uv-installer.sh
|
||||
RUN sh /uv-installer.sh && rm /uv-installer.sh
|
||||
ENV PATH="/root/.local/bin/:$PATH"
|
||||
|
||||
@@ -71,9 +71,12 @@ local = [
|
||||
"faster-whisper>=0.10.0",
|
||||
]
|
||||
silero-vad = [
|
||||
"silero-vad>=5.1.2",
|
||||
"silero-vad==5.1.2",
|
||||
"torch>=2.8.0",
|
||||
"torchaudio>=2.8.0",
|
||||
"pyannote.audio==3.4.0",
|
||||
"pytorch-lightning<2.6",
|
||||
"librosa==0.10.1",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
||||
@@ -14,6 +14,7 @@ current_user = auth_module.current_user
|
||||
current_user_optional = auth_module.current_user_optional
|
||||
parse_ws_bearer_token = auth_module.parse_ws_bearer_token
|
||||
current_user_ws_optional = auth_module.current_user_ws_optional
|
||||
verify_raw_token = auth_module.verify_raw_token
|
||||
|
||||
# Optional router (e.g. for /auth/login in password backend)
|
||||
router = getattr(auth_module, "router", None)
|
||||
|
||||
@@ -144,3 +144,8 @@ async def current_user_ws_optional(websocket: "WebSocket") -> Optional[UserInfo]
|
||||
if not token:
|
||||
return None
|
||||
return await _authenticate_user(token, None, JWTAuth())
|
||||
|
||||
|
||||
async def verify_raw_token(token: str) -> Optional[UserInfo]:
|
||||
"""Verify a raw JWT token string (used for query-param auth fallback)."""
|
||||
return await _authenticate_user(token, None, JWTAuth())
|
||||
|
||||
@@ -27,3 +27,8 @@ def parse_ws_bearer_token(websocket):
|
||||
|
||||
async def current_user_ws_optional(websocket):
|
||||
return None
|
||||
|
||||
|
||||
async def verify_raw_token(token):
|
||||
"""Verify a raw JWT token string (used for query-param auth fallback)."""
|
||||
return None
|
||||
|
||||
@@ -168,6 +168,11 @@ async def current_user_ws_optional(websocket: "WebSocket") -> Optional[UserInfo]
|
||||
return await _authenticate_user(token, None)
|
||||
|
||||
|
||||
async def verify_raw_token(token: str) -> Optional[UserInfo]:
|
||||
"""Verify a raw JWT token string (used for query-param auth fallback)."""
|
||||
return await _authenticate_user(token, None)
|
||||
|
||||
|
||||
# --- Login router ---
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
|
||||
from .audio_downscale import AudioDownscaleProcessor # noqa: F401
|
||||
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||
from .audio_padding import AudioPaddingProcessor # noqa: F401
|
||||
from .audio_padding_auto import AudioPaddingAutoProcessor # noqa: F401
|
||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
|
||||
from .base import ( # noqa: F401
|
||||
|
||||
86
server/reflector/processors/_audio_download.py
Normal file
86
server/reflector/processors/_audio_download.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Shared audio download utility for local processors.
|
||||
|
||||
Downloads audio from a URL to a temporary file for in-process ML inference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
S3_TIMEOUT = 60
|
||||
|
||||
|
||||
async def download_audio_to_temp(url: str) -> Path:
    """Download audio from URL to a temporary file.

    The blocking HTTP download runs in the event loop's default executor so
    the loop itself is not blocked while the file is transferred.

    The caller is responsible for deleting the temp file after use.

    Args:
        url: Presigned URL or public URL to download audio from.

    Returns:
        Path to the downloaded temporary file.
    """
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # documented way to obtain it.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _download_blocking, url)
|
||||
|
||||
|
||||
def _download_blocking(url: str) -> Path:
    """Blocking download implementation.

    Streams the response body to a freshly created temporary file whose
    suffix is chosen by ``_detect_extension``.

    Args:
        url: URL to fetch. Only its first 80 characters are bound to the logger.

    Returns:
        Path to the temporary file holding the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        Exception: Any error raised while writing; the partially written
            temp file is removed before the error is re-raised.
    """
    log = logger.bind(url=url[:80])
    log.info("Downloading audio to temp file")

    # With stream=True the connection stays checked out until the response is
    # closed; the context manager guarantees release on every exit path.
    with requests.get(url, stream=True, timeout=S3_TIMEOUT) as response:
        response.raise_for_status()

        # Determine extension from content-type or URL
        ext = _detect_extension(url, response.headers.get("content-type", ""))

        fd, tmp_path = tempfile.mkstemp(suffix=ext)
        try:
            total_bytes = 0
            with os.fdopen(fd, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        total_bytes += len(chunk)
            log.info("Audio downloaded", bytes=total_bytes, path=tmp_path)
            return Path(tmp_path)
        except Exception:
            # Clean up on failure (best effort; the original error is what matters)
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise
|
||||
|
||||
|
||||
def _detect_extension(url: str, content_type: str) -> str:
|
||||
"""Detect audio file extension from URL or content-type."""
|
||||
# Try URL path first
|
||||
path = url.split("?")[0] # Strip query params
|
||||
for ext in (".wav", ".mp3", ".mp4", ".m4a", ".webm", ".ogg", ".flac"):
|
||||
if path.lower().endswith(ext):
|
||||
return ext
|
||||
|
||||
# Try content-type
|
||||
ct_map = {
|
||||
"audio/wav": ".wav",
|
||||
"audio/x-wav": ".wav",
|
||||
"audio/mpeg": ".mp3",
|
||||
"audio/mp4": ".m4a",
|
||||
"audio/webm": ".webm",
|
||||
"audio/ogg": ".ogg",
|
||||
"audio/flac": ".flac",
|
||||
}
|
||||
for ct, ext in ct_map.items():
|
||||
if ct in content_type.lower():
|
||||
return ext
|
||||
|
||||
return ".audio"
|
||||
76
server/reflector/processors/_marian_translator_service.py
Normal file
76
server/reflector/processors/_marian_translator_service.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
MarianMT translation service.
|
||||
|
||||
Singleton service that loads HuggingFace MarianMT translation models
|
||||
and reuses them across all MarianMT translator processor instances.
|
||||
|
||||
Ported from gpu/self_hosted/app/services/translator.py for in-process use.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from transformers import MarianMTModel, MarianTokenizer, pipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarianTranslatorService:
    """MarianMT text translation service for in-process use.

    Holds at most one loaded translation pipeline at a time, keyed by the
    lower-cased (source, target) language pair. A lock serialises both the
    (re)loading of the pipeline and its use, so a concurrent request for a
    different pair cannot swap the pipeline between the check and the call.
    """

    # Supported (source, target) pairs and their HuggingFace model names.
    _MODEL_NAMES = {
        ("en", "fr"): "Helsinki-NLP/opus-mt-en-fr",
        ("fr", "en"): "Helsinki-NLP/opus-mt-fr-en",
        ("en", "es"): "Helsinki-NLP/opus-mt-en-es",
        ("es", "en"): "Helsinki-NLP/opus-mt-es-en",
        ("en", "de"): "Helsinki-NLP/opus-mt-en-de",
        ("de", "en"): "Helsinki-NLP/opus-mt-de-en",
    }
    # Model used when the requested pair is not in _MODEL_NAMES.
    _FALLBACK_MODEL = "Helsinki-NLP/opus-mt-en-fr"

    def __init__(self):
        self._pipeline = None
        self._current_pair = None
        self._lock = threading.Lock()

    def load(self, source_language: str = "en", target_language: str = "fr"):
        """Load the translation model for a specific language pair.

        Replaces any previously loaded pipeline. This method does not take
        the internal lock itself; ``translate`` calls it while holding it.
        """
        model_name = self._resolve_model_name(source_language, target_language)
        logger.info(
            "Loading MarianMT model: %s (%s -> %s)",
            model_name,
            source_language,
            target_language,
        )
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        self._pipeline = pipeline("translation", model=model, tokenizer=tokenizer)
        self._current_pair = (source_language.lower(), target_language.lower())

    def _resolve_model_name(self, src: str, tgt: str) -> str:
        """Resolve language pair to MarianMT model name.

        Unknown pairs fall back to the en->fr model; a warning is logged so
        the resulting wrong-language output is not silent.
        """
        pair = (src.lower(), tgt.lower())
        model_name = self._MODEL_NAMES.get(pair)
        if model_name is None:
            logger.warning(
                "No MarianMT model for %s -> %s, falling back to %s",
                src,
                tgt,
                self._FALLBACK_MODEL,
            )
            return self._FALLBACK_MODEL
        return model_name

    def translate(self, text: str, source_language: str, target_language: str) -> dict:
        """Translate text between languages.

        Args:
            text: Text to translate.
            source_language: Source language code (e.g. "en").
            target_language: Target language code (e.g. "fr").

        Returns:
            dict with "text" key containing {source_language: original, target_language: translated}.
            The translated value is "" when the pipeline returns no results.
        """
        pair = (source_language.lower(), target_language.lower())
        with self._lock:
            # Check-and-load must happen under the same lock as the call,
            # otherwise another thread could replace the pipeline in between.
            if self._pipeline is None or self._current_pair != pair:
                self.load(source_language, target_language)
            results = self._pipeline(
                text, src_lang=source_language, tgt_lang=target_language
            )
        translated = results[0]["translation_text"] if results else ""
        return {"text": {source_language: text, target_language: translated}}
|
||||
|
||||
|
||||
# Module-level singleton — shared across all MarianMT translator processors
|
||||
translator_service = MarianTranslatorService()
|
||||
133
server/reflector/processors/_pyannote_diarization_service.py
Normal file
133
server/reflector/processors/_pyannote_diarization_service.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Pyannote diarization service using pyannote.audio.
|
||||
|
||||
Singleton service that loads the pyannote speaker diarization model once
|
||||
and reuses it across all pyannote diarization processor instances.
|
||||
|
||||
Ported from gpu/self_hosted/app/services/diarizer.py for in-process use.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import tarfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from urllib.request import urlopen
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import yaml
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
S3_BUNDLE_URL = "https://reflector-public.s3.us-east-1.amazonaws.com/pyannote-speaker-diarization-3.1.tar.gz"
|
||||
BUNDLE_CACHE_DIR = Path.home() / ".cache" / "pyannote-bundle"
|
||||
|
||||
|
||||
def _ensure_model(cache_dir: Path) -> str:
    """Download and extract S3 model bundle if not cached.

    The bundle is considered cached when
    ``<cache_dir>/pyannote-speaker-diarization-3.1/config.yaml`` exists.

    Args:
        cache_dir: Directory that holds (or will hold) the extracted bundle.

    Returns:
        Path, as a string, of the extracted model directory.
    """
    model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
    config_path = model_dir / "config.yaml"

    if config_path.exists():
        logger.info("Using cached model bundle at %s", model_dir)
        return str(model_dir)

    cache_dir.mkdir(parents=True, exist_ok=True)
    tarball_path = cache_dir / "model.tar.gz"

    try:
        logger.info("Downloading model bundle from %s", S3_BUNDLE_URL)
        # The timeout applies per blocking socket operation, so a stalled
        # connection fails instead of hanging the worker forever.
        with urlopen(S3_BUNDLE_URL, timeout=60) as response, open(
            tarball_path, "wb"
        ) as f:
            while chunk := response.read(8192):
                f.write(chunk)

        logger.info("Extracting model bundle")
        with tarfile.open(tarball_path, "r:gz") as tar:
            # filter="data" rejects unsafe archive members
            tar.extractall(path=cache_dir, filter="data")
    finally:
        # Never leave a (possibly partial) tarball behind, even on failure
        tarball_path.unlink(missing_ok=True)

    _patch_config(model_dir, cache_dir)
    return str(model_dir)
|
||||
|
||||
|
||||
def _patch_config(model_dir: Path, cache_dir: Path) -> None:
    """Point the bundle's config.yaml at the locally extracted weight files.

    The ``segmentation`` and ``embedding`` pipeline params are overwritten
    with paths to ``pytorch_model.bin`` files under ``cache_dir``, and the
    file is rewritten in place.
    """
    config_file = model_dir / "config.yaml"
    with open(config_file) as src:
        document = yaml.safe_load(src)

    local_weights = {
        "segmentation": cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin",
        "embedding": cache_dir
        / "pyannote-wespeaker-voxceleb-resnet34-LM"
        / "pytorch_model.bin",
    }
    pipeline_params = document["pipeline"]["params"]
    for param_name, weights_file in local_weights.items():
        pipeline_params[param_name] = str(weights_file)

    with open(config_file, "w") as dst:
        yaml.dump(document, dst)

    logger.info("Patched config.yaml with local model paths")
|
||||
|
||||
|
||||
class PyannoteDiarizationService:
    """Pyannote speaker diarization service for in-process use.

    The pipeline is loaded lazily on the first ``diarize_file`` call. A lock
    serialises both that first load and every inference call.
    """

    def __init__(self):
        self._pipeline = None
        self._device = "cpu"
        self._lock = threading.Lock()

    def load(self):
        """Load the diarization pipeline and move it to the selected device.

        Uses HuggingFace when ``settings.HF_TOKEN`` is set, otherwise the
        S3 model bundle. This method does not take the internal lock itself;
        ``diarize_file`` calls it while holding it.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        hf_token = settings.HF_TOKEN

        if hf_token:
            logger.info("Loading pyannote model from HuggingFace (HF_TOKEN set)")
            loaded = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token,
            )
        else:
            logger.info("HF_TOKEN not set — loading model from S3 bundle")
            model_path = _ensure_model(BUNDLE_CACHE_DIR)
            config_path = Path(model_path) / "config.yaml"
            loaded = Pipeline.from_pretrained(str(config_path))

        loaded.to(torch.device(device))
        # Publish only after the pipeline is fully set up, so other threads
        # never observe a pipeline that has not been moved to its device yet.
        self._device = device
        self._pipeline = loaded

    def diarize_file(self, file_path: str, timestamp: float = 0.0) -> dict:
        """Run speaker diarization on an audio file.

        Args:
            file_path: Path to the audio file.
            timestamp: Offset to add to all segment timestamps.

        Returns:
            dict with "diarization" key containing list of
            {"start": float, "end": float, "speaker": int} segments.
        """
        # Check-and-load under the lock so concurrent first calls load once
        with self._lock:
            if self._pipeline is None:
                self.load()
        waveform, sample_rate = torchaudio.load(file_path)
        with self._lock:
            diarization = self._pipeline(
                {"waveform": waveform, "sample_rate": sample_rate}
            )
        segments = []
        for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
            segments.append(
                {
                    "start": round(timestamp + diarization_segment.start, 3),
                    "end": round(timestamp + diarization_segment.end, 3),
                    # Speaker id is the trailing two digits of the label; 0 if absent
                    "speaker": int(speaker[-2:])
                    if speaker and speaker[-2:].isdigit()
                    else 0,
                }
            )
        return {"diarization": segments}
|
||||
|
||||
|
||||
# Module-level singleton — shared across all pyannote diarization processors
|
||||
diarization_service = PyannoteDiarizationService()
|
||||
37
server/reflector/processors/audio_diarization_pyannote.py
Normal file
37
server/reflector/processors/audio_diarization_pyannote.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Pyannote audio diarization processor using pyannote.audio in-process.
|
||||
|
||||
Downloads audio from URL, runs pyannote diarization locally,
|
||||
and returns speaker segments. No HTTP backend needed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
from reflector.processors._pyannote_diarization_service import diarization_service
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput
|
||||
|
||||
|
||||
class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
    """Audio diarization processor backed by the shared in-process pyannote service."""

    INPUT_TYPE = AudioDiarizationInput

    async def _diarize(self, data: AudioDiarizationInput):
        """Run pyannote diarization on audio from URL.

        Downloads ``data.audio_url`` to a temp file, runs the blocking
        diarization in the default executor, and returns the list stored
        under the service result's "diarization" key. The temp file is
        always removed afterwards.
        """
        tmp_path = await download_audio_to_temp(data.audio_url)
        try:
            # We are inside a coroutine, so a loop is guaranteed to be running
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None, diarization_service.diarize_file, str(tmp_path)
            )
            return result["diarization"]
        finally:
            # Best-effort cleanup; a missing file is not an error here
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
|
||||
|
||||
|
||||
AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)
|
||||
23
server/reflector/processors/audio_padding.py
Normal file
23
server/reflector/processors/audio_padding.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Base class for audio padding processors.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PaddingResponse(BaseModel):
    """Result model returned by ``AudioPaddingProcessor.pad_track``."""

    # NOTE(review): presumably the size of the padded output — confirm units against implementations.
    size: int
    # Defaults to False. NOTE(review): presumably set when padding was cancelled — confirm.
    cancelled: bool = False
|
||||
|
||||
|
||||
class AudioPaddingProcessor:
    """Base class for audio padding processors.

    Subclasses implement ``pad_track``; this base implementation only raises.
    """

    async def pad_track(
        self,
        track_url: str,
        output_url: str,
        start_time_seconds: float,
        track_index: int,
    ) -> PaddingResponse:
        """Pad one audio track; must be overridden by subclasses.

        Args:
            track_url: URL of the source audio track.
            output_url: NOTE(review): presumably the URL the padded audio is
                written to — confirm against implementations.
            start_time_seconds: NOTE(review): presumably the track's start
                offset that determines the padding — confirm.
            track_index: Index of the track being padded.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
||||
@@ -1,9 +1,10 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_padding import AudioPaddingProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioPaddingAutoProcessor:
|
||||
class AudioPaddingAutoProcessor(AudioPaddingProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -6,19 +6,14 @@ import asyncio
|
||||
import os
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.hatchet.constants import TIMEOUT_AUDIO
|
||||
from reflector.logger import logger
|
||||
from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse
|
||||
from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor
|
||||
|
||||
|
||||
class PaddingResponse(BaseModel):
|
||||
size: int
|
||||
cancelled: bool = False
|
||||
|
||||
|
||||
class AudioPaddingModalProcessor:
|
||||
class AudioPaddingModalProcessor(AudioPaddingProcessor):
|
||||
"""Audio padding processor using Modal.com CPU backend via HTTP."""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Local audio padding processor using PyAV.
|
||||
PyAV audio padding processor.
|
||||
|
||||
Pads audio tracks with silence directly in-process (no HTTP).
|
||||
Reuses the shared PyAV utilities from reflector.utils.audio_padding.
|
||||
@@ -12,15 +12,15 @@ import tempfile
|
||||
import av
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse
|
||||
from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor
|
||||
from reflector.processors.audio_padding_modal import PaddingResponse
|
||||
from reflector.utils.audio_padding import apply_audio_padding_to_file
|
||||
|
||||
S3_TIMEOUT = 60
|
||||
|
||||
|
||||
class AudioPaddingLocalProcessor:
|
||||
"""Audio padding processor using local PyAV (no HTTP backend)."""
|
||||
class AudioPaddingPyavProcessor(AudioPaddingProcessor):
|
||||
"""Audio padding processor using PyAV (no HTTP backend)."""
|
||||
|
||||
async def pad_track(
|
||||
self,
|
||||
@@ -29,7 +29,7 @@ class AudioPaddingLocalProcessor:
|
||||
start_time_seconds: float,
|
||||
track_index: int,
|
||||
) -> PaddingResponse:
|
||||
"""Pad audio track with silence locally via PyAV.
|
||||
"""Pad audio track with silence via PyAV.
|
||||
|
||||
Args:
|
||||
track_url: Presigned GET URL for source audio track
|
||||
@@ -130,4 +130,4 @@ class AudioPaddingLocalProcessor:
|
||||
log.warning("Failed to cleanup temp directory", error=str(e))
|
||||
|
||||
|
||||
AudioPaddingAutoProcessor.register("local", AudioPaddingLocalProcessor)
|
||||
AudioPaddingAutoProcessor.register("pyav", AudioPaddingPyavProcessor)
|
||||
@@ -3,13 +3,17 @@ from faster_whisper import WhisperModel
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = WhisperModel(
|
||||
"tiny", device="cpu", compute_type="float32", num_workers=12
|
||||
settings.WHISPER_CHUNK_MODEL,
|
||||
device="cpu",
|
||||
compute_type="float32",
|
||||
num_workers=12,
|
||||
)
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
|
||||
39
server/reflector/processors/file_diarization_pyannote.py
Normal file
39
server/reflector/processors/file_diarization_pyannote.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
Pyannote file diarization processor using pyannote.audio in-process.
|
||||
|
||||
Downloads audio from URL, runs pyannote diarization locally,
|
||||
and returns speaker segments. No HTTP backend needed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
from reflector.processors._pyannote_diarization_service import diarization_service
|
||||
from reflector.processors.file_diarization import (
|
||||
FileDiarizationInput,
|
||||
FileDiarizationOutput,
|
||||
FileDiarizationProcessor,
|
||||
)
|
||||
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||
|
||||
|
||||
class FileDiarizationPyannoteProcessor(FileDiarizationProcessor):
    """File diarization processor backed by the shared in-process pyannote service."""

    async def _diarize(self, data: FileDiarizationInput):
        """Run pyannote diarization on file from URL.

        Downloads ``data.audio_url`` to a temp file, runs the blocking
        diarization in the default executor, and wraps the segments in a
        ``FileDiarizationOutput``. The temp file is always removed afterwards.
        """
        # The URL may be presigned: keep its query string (where the
        # signature lives) out of the logs.
        safe_url = data.audio_url.split("?")[0]
        self.logger.info(f"Starting pyannote diarization from {safe_url}")
        tmp_path = await download_audio_to_temp(data.audio_url)
        try:
            # We are inside a coroutine, so a loop is guaranteed to be running
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None, diarization_service.diarize_file, str(tmp_path)
            )
            return FileDiarizationOutput(diarization=result["diarization"])
        finally:
            # Best-effort cleanup; a missing file is not an error here
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
|
||||
|
||||
|
||||
FileDiarizationAutoProcessor.register("pyannote", FileDiarizationPyannoteProcessor)
|
||||
275
server/reflector/processors/file_transcript_whisper.py
Normal file
275
server/reflector/processors/file_transcript_whisper.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Local file transcription processor using faster-whisper with Silero VAD pipeline.
|
||||
|
||||
Downloads audio from URL, segments it using Silero VAD, transcribes each
|
||||
segment with faster-whisper, and merges results. No HTTP backend needed.
|
||||
|
||||
VAD pipeline ported from gpu/self_hosted/app/services/transcriber.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Generator
|
||||
|
||||
import numpy as np
|
||||
from silero_vad import VADIterator, load_silero_vad
|
||||
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
from reflector.processors.file_transcript import (
|
||||
FileTranscriptInput,
|
||||
FileTranscriptProcessor,
|
||||
)
|
||||
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||
from reflector.processors.types import Transcript, Word
|
||||
from reflector.settings import settings
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
|
||||
class FileTranscriptWhisperProcessor(FileTranscriptProcessor):
    """Transcribe complete audio files using local faster-whisper with VAD.

    The whisper model is created lazily on first use. One lock guards both
    model creation and every call into the model.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._model = None
        self._lock = threading.Lock()

    def _ensure_model(self):
        """Lazy-load the whisper model on first use.

        Safe to call from several executor threads: the model is created at
        most once. Must not be called while holding ``self._lock``.
        """
        if self._model is not None:
            return

        import faster_whisper
        import torch

        with self._lock:
            # Re-check: another thread may have loaded while we waited
            if self._model is not None:
                return

            device = "cuda" if torch.cuda.is_available() else "cpu"
            compute_type = "float16" if device == "cuda" else "int8"
            model_name = settings.WHISPER_FILE_MODEL

            self.logger.info(
                "Loading whisper model",
                model=model_name,
                device=device,
                compute_type=compute_type,
            )
            self._model = faster_whisper.WhisperModel(
                model_name,
                device=device,
                compute_type=compute_type,
                num_workers=1,
            )

    async def _transcript(self, data: FileTranscriptInput):
        """Download file, run VAD segmentation, transcribe each segment."""
        tmp_path = await download_audio_to_temp(data.audio_url)
        try:
            # We are inside a coroutine, so a loop is guaranteed to be running
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                self._transcribe_file_blocking,
                str(tmp_path),
                data.language,
            )
            return result
        finally:
            # Best-effort cleanup; a missing file is not an error here
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    def _run_model(self, audio, language: str) -> list:
        """Run the whisper model on a file path or sample array.

        Returns the fully materialised list of segments. The result of
        ``transcribe`` is consumed inside the lock so that all of the work
        it represents happens while the model is held exclusively.
        """
        with self._lock:
            segments, _ = self._model.transcribe(
                audio,
                language=language,
                beam_size=5,
                word_timestamps=True,
                vad_filter=True,
                vad_parameters={"min_silence_duration_ms": 500},
            )
            return list(segments)

    def _transcribe_file_blocking(self, file_path: str, language: str) -> Transcript:
        """Blocking transcription with VAD pipeline.

        Speech segments found by VAD are merged into batches no longer than
        ``VAD_CONFIG["batch_max_duration"]`` seconds; each batch is transcribed
        and its word timestamps are shifted by the batch start time.
        """
        self._ensure_model()

        audio_array = _load_audio_via_ffmpeg(file_path, SAMPLE_RATE)

        # VAD segmentation → batch merging
        merged_batches: list[tuple[float, float]] = []
        batch_start = None
        batch_end = None
        max_duration = VAD_CONFIG["batch_max_duration"]

        for seg_start, seg_end in _vad_segments(audio_array):
            if batch_start is None:
                batch_start, batch_end = seg_start, seg_end
                continue
            if seg_end - batch_start <= max_duration:
                batch_end = seg_end
            else:
                merged_batches.append((batch_start, batch_end))
                batch_start, batch_end = seg_start, seg_end

        if batch_start is not None and batch_end is not None:
            merged_batches.append((batch_start, batch_end))

        # If no speech detected, try transcribing the whole file
        if not merged_batches:
            return self._transcribe_whole_file(file_path, language)

        # Transcribe each batch
        all_words = []
        for start_time, end_time in merged_batches:
            s_idx = int(start_time * SAMPLE_RATE)
            e_idx = int(end_time * SAMPLE_RATE)
            segment = audio_array[s_idx:e_idx]
            segment = _pad_audio(segment, SAMPLE_RATE)

            for seg in self._run_model(segment, language):
                for w in seg.words:
                    all_words.append(
                        {
                            "word": w.word,
                            # Model times are relative to the batch; make them absolute
                            "start": round(float(w.start) + start_time, 2),
                            "end": round(float(w.end) + start_time, 2),
                        }
                    )

        all_words = _enforce_word_timing_constraints(all_words)

        words = [
            Word(text=w["word"], start=w["start"], end=w["end"]) for w in all_words
        ]
        words.sort(key=lambda w: w.start)
        return Transcript(words=words)

    def _transcribe_whole_file(self, file_path: str, language: str) -> Transcript:
        """Fallback: transcribe entire file without VAD segmentation."""
        words = []
        for seg in self._run_model(file_path, language):
            for w in seg.words:
                words.append(
                    Word(
                        text=w.word,
                        start=round(float(w.start), 2),
                        end=round(float(w.end), 2),
                    )
                )
        return Transcript(words=words)
|
||||
|
||||
|
||||
# --- VAD helpers (ported from gpu/self_hosted/app/services/transcriber.py) ---
|
||||
# IMPORTANT: This VAD segment logic is duplicated for deployment isolation.
|
||||
# If you modify this, consider updating the GPU service copy as well:
|
||||
# - gpu/self_hosted/app/services/transcriber.py
|
||||
# - gpu/modal_deployments/reflector_transcriber.py
|
||||
# - gpu/modal_deployments/reflector_transcriber_parakeet.py
|
||||
|
||||
|
||||
def _load_audio_via_ffmpeg(
    input_path: str, sample_rate: int = SAMPLE_RATE
) -> np.ndarray:
    """Decode an audio file with ffmpeg into a mono float32 sample array.

    ffmpeg writes raw little-endian float32 PCM at ``sample_rate`` to its
    stdout, which is captured and viewed as a numpy array. A non-zero ffmpeg
    exit status raises ``subprocess.CalledProcessError``.
    """
    executable = shutil.which("ffmpeg") or "ffmpeg"
    input_args = ["-nostdin", "-threads", "1", "-i", input_path]
    output_args = [
        "-f",
        "f32le",
        "-acodec",
        "pcm_f32le",
        "-ac",
        "1",
        "-ar",
        str(sample_rate),
        "pipe:1",
    ]
    completed = subprocess.run(
        [executable, *input_args, *output_args],
        capture_output=True,
        check=True,
    )
    return np.frombuffer(completed.stdout, dtype=np.float32)
|
||||
|
||||
|
||||
def _vad_segments(
    audio_array: np.ndarray,
    sample_rate: int = SAMPLE_RATE,
    window_size: int = VAD_CONFIG["window_size"],
) -> Generator[tuple[float, float], None, None]:
    """Detect speech segments using Silero VAD.

    The audio is fed to the VAD iterator in fixed windows of ``window_size``
    samples (the last window is zero-padded).

    Args:
        audio_array: Audio samples at ``sample_rate``.
        sample_rate: Sampling rate used both for the VAD iterator and for
            converting sample positions to seconds.
        window_size: Number of samples per VAD window.

    Yields:
        (start_seconds, end_seconds) for each detected speech segment.
    """
    vad_model = load_silero_vad(onnx=False)
    iterator = VADIterator(vad_model, sampling_rate=sample_rate)
    # Convert with the rate actually passed in, not the module default,
    # so a non-default sample_rate still yields correct times.
    rate = float(sample_rate)
    start = None

    for i in range(0, len(audio_array), window_size):
        chunk = audio_array[i : i + window_size]
        if len(chunk) < window_size:
            chunk = np.pad(chunk, (0, window_size - len(chunk)), mode="constant")
        speech = iterator(chunk)
        if not speech:
            continue
        if "start" in speech:
            start = speech["start"]
            continue
        if "end" in speech and start is not None:
            end = speech["end"]
            yield (start / rate, end / rate)
            start = None

    # Handle case where audio ends while speech is still active
    if start is not None:
        audio_duration = len(audio_array) / rate
        yield (start / rate, audio_duration)

    iterator.reset_states()
|
||||
|
||||
|
||||
def _pad_audio(audio_array: np.ndarray, sample_rate: int = SAMPLE_RATE) -> np.ndarray:
    """Append silence to clips shorter than the configured padding length.

    Clips lasting at least ``VAD_CONFIG["silence_padding"]`` seconds are
    returned unchanged; shorter ones get that many seconds of float32 zeros
    appended.
    """
    padding_seconds = VAD_CONFIG["silence_padding"]
    clip_seconds = len(audio_array) / sample_rate
    if clip_seconds >= padding_seconds:
        return audio_array
    trailing_silence = np.zeros(int(sample_rate * padding_seconds), dtype=np.float32)
    return np.concatenate([audio_array, trailing_silence])
|
||||
|
||||
|
||||
def _enforce_word_timing_constraints(words: list[dict]) -> list[dict]:
|
||||
"""Ensure no word end time exceeds the next word's start time."""
|
||||
if len(words) <= 1:
|
||||
return words
|
||||
enforced: list[dict] = []
|
||||
for i, word in enumerate(words):
|
||||
current = dict(word)
|
||||
if i < len(words) - 1:
|
||||
next_start = words[i + 1]["start"]
|
||||
if current["end"] > next_start:
|
||||
current["end"] = next_start
|
||||
enforced.append(current)
|
||||
return enforced
|
||||
|
||||
|
||||
FileTranscriptAutoProcessor.register("whisper", FileTranscriptWhisperProcessor)
|
||||
50
server/reflector/processors/transcript_translator_marian.py
Normal file
50
server/reflector/processors/transcript_translator_marian.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
MarianMT transcript translator processor using HuggingFace MarianMT in-process.
|
||||
|
||||
Translates transcript text using HuggingFace MarianMT models
|
||||
locally. No HTTP backend needed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from reflector.processors._marian_translator_service import translator_service
|
||||
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
|
||||
from reflector.processors.transcript_translator_auto import (
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.types import TranslationLanguages
|
||||
|
||||
|
||||
class TranscriptTranslatorMarianProcessor(TranscriptTranslatorProcessor):
    """Translate transcript text using MarianMT models."""

    async def _translate(self, text: str) -> str | None:
        """Translate ``text`` with the shared in-process MarianMT service.

        Source and target languages come from the "audio:source_language"
        and "audio:target_language" preferences (both default to "en").

        Returns:
            The translated text, or None when the service result has no
            entry for the target language.
        """
        source_language = self.get_pref("audio:source_language", "en")
        target_language = self.get_pref("audio:target_language", "en")

        languages = TranslationLanguages()
        assert languages.is_supported(target_language)

        self.logger.debug(f"MarianMT translate {text=}")

        # We are inside a coroutine, so a loop is guaranteed to be running;
        # the blocking model call runs in the default executor.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            translator_service.translate,
            text,
            source_language,
            target_language,
        )

        translation = result["text"].get(target_language)

        self.logger.debug(f"Translation result: {text=}, {translation=}")
        return translation
|
||||
|
||||
|
||||
# Register the MarianMT translator under the "marian" backend name. This runs
# at import time, so the module must be imported for the backend to exist.
TranscriptTranslatorAutoProcessor.register(
    "marian", TranscriptTranslatorMarianProcessor
)
|
||||
@@ -40,11 +40,19 @@ class Settings(BaseSettings):
|
||||
# backends: silero, frames
|
||||
AUDIO_CHUNKER_BACKEND: str = "frames"
|
||||
|
||||
# HuggingFace token for gated models (pyannote diarization in --cpu mode)
|
||||
HF_TOKEN: str | None = None
|
||||
|
||||
# Audio Transcription
|
||||
# backends:
|
||||
# - whisper: in-process model loading (no HTTP, runs in same process)
|
||||
# - modal: HTTP API client (works with Modal.com OR self-hosted gpu/self_hosted/)
|
||||
TRANSCRIPT_BACKEND: str = "whisper"
|
||||
|
||||
# Whisper model sizes for local transcription
|
||||
# Options: "tiny", "base", "small", "medium", "large-v2"
|
||||
WHISPER_CHUNK_MODEL: str = "tiny"
|
||||
WHISPER_FILE_MODEL: str = "tiny"
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
TRANSCRIPT_FILE_TIMEOUT: int = 600
|
||||
@@ -100,7 +108,7 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
# Diarization
|
||||
# backend: modal — HTTP API client (works with Modal.com OR self-hosted gpu/self_hosted/)
|
||||
# backends: modal — HTTP API client, pyannote — in-process pyannote.audio
|
||||
DIARIZATION_ENABLED: bool = True
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
@@ -111,9 +119,9 @@ class Settings(BaseSettings):
|
||||
|
||||
# Audio Padding
|
||||
# backends:
|
||||
# - local: in-process PyAV padding (no HTTP, runs in same process)
|
||||
# - pyav: in-process PyAV padding (no HTTP, runs in same process)
|
||||
# - modal: HTTP API client (works with Modal.com OR self-hosted gpu/self_hosted/)
|
||||
PADDING_BACKEND: str = "local"
|
||||
PADDING_BACKEND: str = "pyav"
|
||||
PADDING_URL: str | None = None
|
||||
PADDING_MODAL_API_KEY: str | None = None
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
import reflector.auth as auth
|
||||
from reflector.db.transcripts import AudioWaveform, transcripts_controller
|
||||
from reflector.settings import settings
|
||||
from reflector.views.transcripts import ALGORITHM
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
|
||||
@@ -36,16 +35,23 @@ async def transcript_get_audio_mp3(
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
if not user_id and token:
|
||||
unauthorized_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
except jwt.PyJWTError:
|
||||
raise unauthorized_exception
|
||||
token_user = await auth.verify_raw_token(token)
|
||||
except Exception:
|
||||
token_user = None
|
||||
# Fallback: try as internal HS256 token (created by _generate_local_audio_link)
|
||||
if not token_user:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||
user_id = payload.get("sub")
|
||||
except jwt.PyJWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
else:
|
||||
user_id = token_user["sub"]
|
||||
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
|
||||
450
server/tests/test_processors_cpu.py
Normal file
450
server/tests/test_processors_cpu.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Tests for in-process processor backends (--cpu mode).
|
||||
|
||||
All ML model calls are mocked — no actual model loading needed.
|
||||
Tests verify processor registration, wiring, error handling, and data flow.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.processors.file_diarization import (
|
||||
FileDiarizationInput,
|
||||
FileDiarizationOutput,
|
||||
)
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
TitleSummaryWithId,
|
||||
Transcript,
|
||||
Word,
|
||||
)
|
||||
|
||||
# ── Registration Tests ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_audio_diarization_pyannote_registers():
    """The audio diarization registry must contain a 'pyannote' backend."""
    # Imported only for its side effect: the module registers itself on import.
    import reflector.processors.audio_diarization_pyannote  # noqa: F401
    from reflector.processors.audio_diarization_auto import (
        AudioDiarizationAutoProcessor,
    )

    registered_backends = AudioDiarizationAutoProcessor._registry
    assert "pyannote" in registered_backends
|
||||
|
||||
|
||||
def test_file_diarization_pyannote_registers():
    """The file diarization registry must contain a 'pyannote' backend."""
    # Imported only for its side effect: the module registers itself on import.
    import reflector.processors.file_diarization_pyannote  # noqa: F401
    from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor

    registered_backends = FileDiarizationAutoProcessor._registry
    assert "pyannote" in registered_backends
|
||||
|
||||
|
||||
def test_transcript_translator_marian_registers():
    """The translator registry must contain a 'marian' backend."""
    # Imported only for its side effect: the module registers itself on import.
    import reflector.processors.transcript_translator_marian  # noqa: F401
    from reflector.processors.transcript_translator_auto import (
        TranscriptTranslatorAutoProcessor,
    )

    registered_backends = TranscriptTranslatorAutoProcessor._registry
    assert "marian" in registered_backends
|
||||
|
||||
|
||||
def test_file_transcript_whisper_registers():
    """The file transcript registry must contain a 'whisper' backend."""
    # Imported only for its side effect: the module registers itself on import.
    import reflector.processors.file_transcript_whisper  # noqa: F401
    from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor

    registered_backends = FileTranscriptAutoProcessor._registry
    assert "whisper" in registered_backends
|
||||
|
||||
|
||||
# ── Audio Download Utility Tests ────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_audio_to_temp_success():
    """download_audio_to_temp writes the response body to a file and returns its path.

    ``requests.get`` is patched, so no network access happens. The file the
    helper creates is removed at the end even when an assertion fails.
    """
    from reflector.processors._audio_download import download_audio_to_temp

    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "audio/wav"}
    mock_response.iter_content.return_value = [b"fake audio data"]
    mock_response.raise_for_status = MagicMock()

    with patch("reflector.processors._audio_download.requests.get") as mock_get:
        mock_get.return_value = mock_response

        result = await download_audio_to_temp("https://example.com/test.wav")

    try:
        assert isinstance(result, Path)
        assert result.exists()
        assert result.read_bytes() == b"fake audio data"
        assert result.suffix == ".wav"
    finally:
        # Clean up unconditionally so a failed assertion does not leave the
        # downloaded file behind in the temp directory.
        if isinstance(result, Path) and result.exists():
            os.unlink(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_audio_to_temp_cleanup_on_error():
    """A failure while streaming the body propagates out of download_audio_to_temp.

    The mocked response raises ``ConnectionError`` from ``iter_content``. The
    test asserts only that this error reaches the caller; it does not inspect
    the temp directory.
    """
    from reflector.processors._audio_download import download_audio_to_temp

    def interrupted_stream(*args, **kwargs):
        raise ConnectionError("Download interrupted")

    broken_response = MagicMock()
    broken_response.headers = {"content-type": "audio/wav"}
    broken_response.raise_for_status = MagicMock()
    broken_response.iter_content = interrupted_stream

    with patch(
        "reflector.processors._audio_download.requests.get",
        return_value=broken_response,
    ):
        with pytest.raises(ConnectionError, match="Download interrupted"):
            await download_audio_to_temp("https://example.com/test.wav")
|
||||
|
||||
|
||||
def test_detect_extension_from_url():
    """With an empty content type, the extension is taken from the URL path."""
    from reflector.processors._audio_download import _detect_extension

    expected_by_url = {
        "https://example.com/test.wav": ".wav",
        # A query string after the file name must not hide the extension.
        "https://example.com/test.mp3?signed=1": ".mp3",
        "https://example.com/test.webm": ".webm",
    }
    for url, expected in expected_by_url.items():
        assert _detect_extension(url, "") == expected
|
||||
|
||||
|
||||
def test_detect_extension_from_content_type():
    """When the URL has no extension, the content-type header decides it."""
    from reflector.processors._audio_download import _detect_extension

    extensionless_url = "https://s3.aws/uuid"
    expected_by_content_type = {
        "audio/mpeg": ".mp3",
        "audio/wav": ".wav",
        "audio/webm": ".webm",
    }
    for content_type, expected in expected_by_content_type.items():
        assert _detect_extension(extensionless_url, content_type) == expected
|
||||
|
||||
|
||||
def test_detect_extension_fallback():
    """An unrecognized URL and content type yield the '.audio' fallback."""
    from reflector.processors._audio_download import _detect_extension

    detected = _detect_extension("https://s3.aws/uuid", "application/octet-stream")
    assert detected == ".audio"
|
||||
|
||||
|
||||
# ── Audio Diarization Pyannote Processor Tests ──────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_diarization_pyannote_diarize():
    """The pyannote audio processor returns the service's diarization segments.

    The download helper and the diarization service are both mocked. The
    placeholder audio file is removed when the test ends, unless the
    processor has already deleted it.
    """
    from reflector.processors.audio_diarization_pyannote import (
        AudioDiarizationPyannoteProcessor,
    )

    mock_diarization_result = {
        "diarization": [
            {"start": 0.0, "end": 2.5, "speaker": 0},
            {"start": 2.5, "end": 5.0, "speaker": 1},
        ]
    }

    # Placeholder file standing in for the downloaded audio.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(b"fake audio")
    tmp_path = Path(tmp.name)

    try:
        processor = AudioDiarizationPyannoteProcessor()

        with (
            patch(
                "reflector.processors.audio_diarization_pyannote.download_audio_to_temp",
                new_callable=AsyncMock,
                return_value=tmp_path,
            ),
            patch(
                "reflector.processors.audio_diarization_pyannote.diarization_service"
            ) as mock_svc,
        ):
            mock_svc.diarize_file.return_value = mock_diarization_result

            data = AudioDiarizationInput(
                audio_url="https://example.com/test.wav",
                topics=[
                    TitleSummaryWithId(
                        id="topic-1",
                        title="Test Topic",
                        summary="A test topic",
                        timestamp=0.0,
                        duration=5.0,
                        transcript=Transcript(
                            words=[Word(text="hello", start=0.0, end=1.0)]
                        ),
                    )
                ],
            )
            result = await processor._diarize(data)

        assert result == mock_diarization_result["diarization"]
        mock_svc.diarize_file.assert_called_once()
    finally:
        # delete=False means nothing else is guaranteed to remove the file.
        tmp_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
# ── File Diarization Pyannote Processor Tests ───────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_file_diarization_pyannote_diarize():
    """The pyannote file processor wraps the segments in a FileDiarizationOutput.

    The download helper and the diarization service are both mocked. The
    placeholder audio file is removed when the test ends, unless the
    processor has already deleted it.
    """
    from reflector.processors.file_diarization_pyannote import (
        FileDiarizationPyannoteProcessor,
    )

    mock_diarization_result = {
        "diarization": [
            {"start": 0.0, "end": 3.0, "speaker": 0},
            {"start": 3.0, "end": 6.0, "speaker": 1},
        ]
    }

    # Placeholder file standing in for the downloaded audio.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(b"fake audio")
    tmp_path = Path(tmp.name)

    try:
        processor = FileDiarizationPyannoteProcessor()

        with (
            patch(
                "reflector.processors.file_diarization_pyannote.download_audio_to_temp",
                new_callable=AsyncMock,
                return_value=tmp_path,
            ),
            patch(
                "reflector.processors.file_diarization_pyannote.diarization_service"
            ) as mock_svc,
        ):
            mock_svc.diarize_file.return_value = mock_diarization_result

            data = FileDiarizationInput(audio_url="https://example.com/test.wav")
            result = await processor._diarize(data)

        assert isinstance(result, FileDiarizationOutput)
        assert len(result.diarization) == 2
        assert result.diarization[0]["start"] == 0.0
        assert result.diarization[1]["speaker"] == 1
    finally:
        # delete=False means nothing else is guaranteed to remove the file.
        tmp_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
# ── Transcript Translator Marian Processor Tests ───────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_transcript_translator_marian_translate():
    """The translator forwards text and languages to the service and returns the target entry."""
    from reflector.processors.transcript_translator_marian import (
        TranscriptTranslatorMarianProcessor,
    )

    service_output = {"text": {"en": "Hello world", "fr": "Bonjour le monde"}}
    preferences = {"audio:source_language": "en", "audio:target_language": "fr"}

    processor = TranscriptTranslatorMarianProcessor()

    with (
        patch.object(
            processor,
            "get_pref",
            side_effect=lambda key, default=None: preferences.get(key, default),
        ),
        patch(
            "reflector.processors.transcript_translator_marian.translator_service"
        ) as mock_svc,
    ):
        mock_svc.translate.return_value = service_output

        translated = await processor._translate("Hello world")

    assert translated == "Bonjour le monde"
    mock_svc.translate.assert_called_once_with("Hello world", "en", "fr")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_transcript_translator_marian_no_translation():
    """The translator yields None when the service result lacks the target language."""
    from reflector.processors.transcript_translator_marian import (
        TranscriptTranslatorMarianProcessor,
    )

    # Only the source language is present; "fr" is deliberately missing.
    service_output = {"text": {"en": "Hello world"}}
    preferences = {"audio:source_language": "en", "audio:target_language": "fr"}

    processor = TranscriptTranslatorMarianProcessor()

    with (
        patch.object(
            processor,
            "get_pref",
            side_effect=lambda key, default=None: preferences.get(key, default),
        ),
        patch(
            "reflector.processors.transcript_translator_marian.translator_service"
        ) as mock_svc,
    ):
        mock_svc.translate.return_value = service_output

        translated = await processor._translate("Hello world")

    assert translated is None
|
||||
|
||||
|
||||
# ── File Transcript Whisper Processor Tests ─────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_file_transcript_whisper_transcript():
    """The whisper file processor returns the Transcript produced by transcription.

    The download helper and the blocking transcription method are mocked.
    The placeholder audio file is removed when the test ends, unless the
    processor has already deleted it.
    """
    from reflector.processors.file_transcript import FileTranscriptInput
    from reflector.processors.file_transcript_whisper import (
        FileTranscriptWhisperProcessor,
    )

    # Placeholder file standing in for the downloaded audio.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(b"fake audio")
    tmp_path = Path(tmp.name)

    try:
        processor = FileTranscriptWhisperProcessor()

        # Result returned by the mocked blocking transcription method.
        mock_transcript = Transcript(
            words=[
                Word(text="Hello", start=0.0, end=0.5),
                Word(text=" world", start=0.5, end=1.0),
            ]
        )

        with (
            patch(
                "reflector.processors.file_transcript_whisper.download_audio_to_temp",
                new_callable=AsyncMock,
                return_value=tmp_path,
            ),
            patch.object(
                processor,
                "_transcribe_file_blocking",
                return_value=mock_transcript,
            ),
        ):
            data = FileTranscriptInput(
                audio_url="https://example.com/test.wav", language="en"
            )
            result = await processor._transcript(data)

        assert isinstance(result, Transcript)
        assert len(result.words) == 2
        assert result.words[0].text == "Hello"
    finally:
        # delete=False means nothing else is guaranteed to remove the file.
        tmp_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
# ── VAD Helper Tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_enforce_word_timing_constraints():
    """A word whose end overlaps the next word's start gets its end clamped."""
    from reflector.processors.file_transcript_whisper import (
        _enforce_word_timing_constraints,
    )

    overlapping_words = [
        {"word": "hello", "start": 0.0, "end": 1.5},  # runs past the next start
        {"word": "world", "start": 1.0, "end": 2.0},
        {"word": "test", "start": 2.0, "end": 3.0},
    ]

    enforced = _enforce_word_timing_constraints(overlapping_words)
    first, second, last = enforced

    assert first["end"] == 1.0  # clamped down to the next word's start
    assert second["end"] == 2.0  # already equal to the next start: unchanged
    assert last["end"] == 3.0  # last word has no successor: unchanged
|
||||
|
||||
|
||||
def test_enforce_word_timing_constraints_empty():
    """Empty and single-word inputs come back equal to what was passed in."""
    from reflector.processors.file_transcript_whisper import (
        _enforce_word_timing_constraints,
    )

    assert _enforce_word_timing_constraints([]) == []

    lone_word = {"word": "a", "start": 0, "end": 1}
    assert _enforce_word_timing_constraints([dict(lone_word)]) == [lone_word]
|
||||
|
||||
|
||||
def test_pad_audio_short():
    """A very short clip comes back longer because silence was appended."""
    import numpy as np

    from reflector.processors.file_transcript_whisper import _pad_audio

    clip = np.zeros(100, dtype=np.float32)  # 100 samples at 16 kHz
    padded = _pad_audio(clip, sample_rate=16000)

    # Only growth is checked; the exact padded length depends on VAD config.
    assert len(clip) < len(padded)
|
||||
|
||||
|
||||
def test_pad_audio_long():
    """A two-second clip keeps its length: no silence is appended."""
    import numpy as np

    from reflector.processors.file_transcript_whisper import _pad_audio

    clip = np.zeros(32000, dtype=np.float32)  # 2 seconds at 16 kHz
    unpadded = _pad_audio(clip, sample_rate=16000)

    assert len(unpadded) == len(clip)
|
||||
|
||||
|
||||
# ── Translator Service Tests ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_translator_service_resolve_model():
    """Known language pairs map to their own model; unknown pairs fall back to en->fr."""
    from reflector.processors._marian_translator_service import MarianTranslatorService

    service = MarianTranslatorService()

    expected_models = [
        (("en", "fr"), "Helsinki-NLP/opus-mt-en-fr"),
        (("es", "en"), "Helsinki-NLP/opus-mt-es-en"),
        (("en", "de"), "Helsinki-NLP/opus-mt-en-de"),
        # No dedicated model for this pair: the en->fr model is the fallback.
        (("ja", "ko"), "Helsinki-NLP/opus-mt-en-fr"),
    ]
    for (source, target), model_name in expected_models:
        assert service._resolve_model_name(source, target) == model_name
|
||||
|
||||
|
||||
# ── Diarization Service Tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_diarization_service_singleton():
    """The module exposes a ready PyannoteDiarizationService with no pipeline loaded yet."""
    from reflector.processors import _pyannote_diarization_service as service_module

    shared_service = service_module.diarization_service

    assert isinstance(shared_service, service_module.PyannoteDiarizationService)
    # The pipeline attribute stays None until something triggers loading.
    assert shared_service._pipeline is None
|
||||
|
||||
|
||||
def test_translator_service_singleton():
    """The module exposes a ready MarianTranslatorService with no pipeline loaded yet."""
    from reflector.processors import _marian_translator_service as service_module

    shared_service = service_module.translator_service

    assert isinstance(shared_service, service_module.MarianTranslatorService)
    # The pipeline attribute stays None until something triggers loading.
    assert shared_service._pipeline is None
|
||||
327
server/tests/test_transcripts_audio_token_auth.py
Normal file
327
server/tests/test_transcripts_audio_token_auth.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""Tests for audio mp3 endpoint token query-param authentication.
|
||||
|
||||
Covers both password (HS256) and JWT/Authentik (RS256) auth backends,
|
||||
verifying that private transcripts can be accessed via ?token= query param.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
OWNER_USER_ID = "test-owner-user-id"
|
||||
|
||||
|
||||
def _create_hs256_token(user_id: str, secret: str, expired: bool = False) -> str:
    """Build an HS256-signed JWT for ``user_id``.

    Args:
        user_id: Value stored in the ``sub`` claim.
        secret: Shared secret used to sign the token.
        expired: When true, ``exp`` is set five minutes in the past;
            otherwise 24 hours in the future.

    Returns:
        The encoded token.
    """
    if expired:
        lifetime = timedelta(minutes=-5)
    else:
        lifetime = timedelta(hours=24)

    claims = {
        "sub": user_id,
        "email": "test@example.com",
        "exp": datetime.now(timezone.utc) + lifetime,
    }
    return jwt.encode(claims, secret, algorithm="HS256")
|
||||
|
||||
|
||||
def _generate_rsa_keypair():
    """Create a new 2048-bit RSA key.

    Returns:
        A ``(private_key, public_pem)`` tuple: the private key object and the
        public key as a PEM string in SubjectPublicKeyInfo format.
    """
    private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    public_key = private_key.public_key()
    pem_bytes = public_key.public_bytes(
        serialization.Encoding.PEM,
        serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return private_key, pem_bytes.decode()
|
||||
|
||||
|
||||
def _create_rs256_token(
    authentik_uid: str,
    private_key,
    audience: str,
    expired: bool = False,
) -> str:
    """Build an RS256-signed JWT shaped like one Authentik would issue.

    Args:
        authentik_uid: Value stored in the ``sub`` claim.
        private_key: RSA private key used to sign the token.
        audience: Value stored in the ``aud`` claim.
        expired: When true, ``exp`` is set five minutes in the past;
            otherwise one hour in the future.

    Returns:
        The encoded token.
    """
    if expired:
        lifetime = timedelta(minutes=-5)
    else:
        lifetime = timedelta(hours=1)

    claims = {
        "sub": authentik_uid,
        "email": "authentik-user@example.com",
        "aud": audience,
        "exp": datetime.now(timezone.utc) + lifetime,
    }
    return jwt.encode(claims, private_key, algorithm="RS256")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
async def private_transcript(tmpdir):
    """Yield a private transcript owned by OWNER_USER_ID that has an mp3 file.

    The transcript is created directly through the controller (not over HTTP)
    so no auth override leaks into the test scope. ``settings.DATA_DIR`` is
    pointed at ``tmpdir`` for the duration of the test and restored on
    teardown, so the change does not leak into later tests.
    """
    from reflector.db.transcripts import SourceKind, transcripts_controller
    from reflector.settings import settings

    previous_data_dir = settings.DATA_DIR
    settings.DATA_DIR = Path(tmpdir)
    try:
        transcript = await transcripts_controller.add(
            "Private audio test",
            source_kind=SourceKind.FILE,
            user_id=OWNER_USER_ID,
            share_mode="private",
        )
        await transcripts_controller.update(transcript, {"status": "ended"})

        # Copy a real mp3 to the location the transcript expects.
        audio_filename = transcript.audio_mp3_filename
        mp3_source = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
        audio_filename.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(mp3_source, audio_filename)

        yield transcript
    finally:
        # Undo the global settings change made above.
        settings.DATA_DIR = previous_data_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core access control tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_mp3_private_no_auth_returns_403(private_transcript, client):
    """An unauthenticated request for a private transcript's audio gets 403."""
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3"

    response = await client.get(audio_url)

    assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_mp3_with_bearer_header(private_transcript, client):
    """The authenticated owner can fetch the private transcript's audio.

    Authentication is simulated by overriding the ``current_user_optional``
    dependency on the app for the duration of the request.
    """
    from reflector.app import app
    from reflector.auth import current_user_optional

    def owner_user():
        return {"sub": OWNER_USER_ID, "email": "test@example.com"}

    app.dependency_overrides[current_user_optional] = owner_user
    try:
        response = await client.get(f"/transcripts/{private_transcript.id}/audio/mp3")
    finally:
        # Remove the override so it does not leak into later tests.
        del app.dependency_overrides[current_user_optional]

    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_mp3_public_transcript_no_auth_ok(tmpdir, client):
    """A public transcript's audio is served without any authentication.

    ``settings.DATA_DIR`` is pointed at ``tmpdir`` while the test runs and
    restored afterwards, so the change does not leak into later tests.
    """
    from reflector.db.transcripts import SourceKind, transcripts_controller
    from reflector.settings import settings

    previous_data_dir = settings.DATA_DIR
    settings.DATA_DIR = Path(tmpdir)
    try:
        transcript = await transcripts_controller.add(
            "Public audio test",
            source_kind=SourceKind.FILE,
            user_id=OWNER_USER_ID,
            share_mode="public",
        )
        await transcripts_controller.update(transcript, {"status": "ended"})

        # Copy a real mp3 to the location the transcript expects.
        audio_filename = transcript.audio_mp3_filename
        mp3_source = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
        audio_filename.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(mp3_source, audio_filename)

        response = await client.get(f"/transcripts/{transcript.id}/audio/mp3")
    finally:
        # Undo the global settings change made above.
        settings.DATA_DIR = previous_data_dir

    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/mpeg"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Password auth backend tests (?token= with HS256)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_mp3_password_token_query_param(private_transcript, client):
    """Password backend: a valid HS256 ?token= for the owner grants access."""
    from reflector.auth.auth_password import UserInfo
    from reflector.settings import settings

    owner_token = _create_hs256_token(OWNER_USER_ID, settings.SECRET_KEY)
    verified_owner = UserInfo(sub=OWNER_USER_ID, email="test@example.com")

    # Token verification is mocked to report the transcript owner.
    with patch("reflector.auth.verify_raw_token", return_value=verified_owner):
        response = await client.get(
            f"/transcripts/{private_transcript.id}/audio/mp3?token={owner_token}"
        )

    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_mp3_password_expired_token_returns_401(private_transcript, client):
    """Password backend: an expired HS256 ?token= is rejected with 401."""
    from reflector.settings import settings

    stale_token = _create_hs256_token(OWNER_USER_ID, settings.SECRET_KEY, expired=True)
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3?token={stale_token}"

    # Token verification is mocked to fail with an expiry error.
    with patch(
        "reflector.auth.verify_raw_token",
        side_effect=jwt.ExpiredSignatureError("token expired"),
    ):
        response = await client.get(audio_url)

    assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_mp3_password_wrong_user_returns_403(private_transcript, client):
    """Password backend: a valid token belonging to another user gets 403."""
    from reflector.auth.auth_password import UserInfo
    from reflector.settings import settings

    other_token = _create_hs256_token("other-user-id", settings.SECRET_KEY)
    verified_other = UserInfo(sub="other-user-id", email="other@example.com")

    # Token verification is mocked to report a user who is not the owner.
    with patch("reflector.auth.verify_raw_token", return_value=verified_other):
        response = await client.get(
            f"/transcripts/{private_transcript.id}/audio/mp3?token={other_token}"
        )

    assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_mp3_invalid_token_returns_401(private_transcript, client):
    """A token string that is not a JWT at all is rejected with 401."""
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3?token=not-a-real-token"

    # Token verification is mocked to find no user for the token.
    with patch("reflector.auth.verify_raw_token", return_value=None):
        response = await client.get(audio_url)

    assert response.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JWT/Authentik auth backend tests (?token= with RS256)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_mp3_authentik_token_query_param(private_transcript, client):
    """Authentik backend: a valid RS256 ?token= for the owner grants access."""
    from reflector.auth.auth_password import UserInfo

    signing_key, _ = _generate_rsa_keypair()
    rs256_token = _create_rs256_token("authentik-abc123", signing_key, "test-audience")

    # The mocked verification stands in for the Authentik flow, which maps
    # the authentik uid to the internal user id.
    verified_owner = UserInfo(sub=OWNER_USER_ID, email="authentik-user@example.com")

    with patch("reflector.auth.verify_raw_token", return_value=verified_owner):
        response = await client.get(
            f"/transcripts/{private_transcript.id}/audio/mp3?token={rs256_token}"
        )

    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_mp3_authentik_expired_token_returns_401(
    private_transcript, client
):
    """Authentik backend: an expired RS256 ?token= is rejected with 401."""
    signing_key, _ = _generate_rsa_keypair()
    stale_token = _create_rs256_token(
        "authentik-abc123", signing_key, "test-audience", expired=True
    )
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3?token={stale_token}"

    # Token verification is mocked to fail with an expiry error.
    with patch(
        "reflector.auth.verify_raw_token",
        side_effect=jwt.ExpiredSignatureError("token expired"),
    ):
        response = await client.get(audio_url)

    assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audio_mp3_authentik_wrong_user_returns_403(private_transcript, client):
    """Authentik backend: a valid RS256 token for another user gets 403."""
    from reflector.auth.auth_password import UserInfo

    signing_key, _ = _generate_rsa_keypair()
    other_token = _create_rs256_token("authentik-other", signing_key, "test-audience")
    verified_other = UserInfo(sub="different-user-id", email="other@example.com")

    # Token verification is mocked to report a user who is not the owner.
    with patch("reflector.auth.verify_raw_token", return_value=verified_other):
        response = await client.get(
            f"/transcripts/{private_transcript.id}/audio/mp3?token={other_token}"
        )

    assert response.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _generate_local_audio_link produces HS256 tokens — must be verifiable
|
||||
# by any auth backend
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_local_audio_link_token_works_with_authentik_backend(
    private_transcript, client
):
    """A token from ``_generate_local_audio_link`` must still open the audio.

    The test pulls the ``token`` query parameter out of the internally
    generated audio link, then patches ``reflector.auth.verify_raw_token`` to
    raise ``jwt.exceptions.InvalidAlgorithmError`` -- simulating a JWT/Authentik
    (RS256) backend that refuses the server's own token (per the original
    notes, that token is HS256, minted via ``create_access_token``).

    Even with the backend verifier rejecting the token, the request is
    expected to succeed with 200, i.e. the internal audio URL used by the
    diarization pipeline has to remain usable under the JWT auth backend.
    NOTE(review): how the endpoint accepts the token after the verifier
    raises is not visible from this test -- confirm against the audio route.
    """
    from urllib.parse import parse_qs, urlparse

    # Generate the internal audio link (uses create_access_token → HS256)
    link = private_transcript._generate_local_audio_link()
    query = parse_qs(urlparse(link).query)
    token = query["token"][0]

    # Simulate what happens when the JWT/Authentik backend tries to verify
    # this HS256 token: JWTAuth.verify_token expects RS256, so it raises.
    rejection = jwt.exceptions.InvalidAlgorithmError(
        "the specified alg value is not allowed"
    )
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3?token={token}"

    with patch("reflector.auth.verify_raw_token", side_effect=rejection):
        response = await client.get(audio_url)

    # The token was created by our own server, so the request must succeed
    # even though the Authentik backend rejects HS256 tokens.
    assert response.status_code == 200
|
||||
840
server/uv.lock
generated
840
server/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -78,7 +78,9 @@ const useMp3 = (transcriptId: string, waiting?: boolean): Mp3Response => {
|
||||
|
||||
// Audio is not deleted, proceed to load it
|
||||
audioElement = document.createElement("audio");
|
||||
const audioUrl = `${API_URL}/v1/transcripts/${transcriptId}/audio/mp3`;
|
||||
const audioUrl = accessTokenInfo
|
||||
? `${API_URL}/v1/transcripts/${transcriptId}/audio/mp3?token=${encodeURIComponent(accessTokenInfo)}`
|
||||
: `${API_URL}/v1/transcripts/${transcriptId}/audio/mp3`;
|
||||
audioElement.src = audioUrl;
|
||||
audioElement.crossOrigin = "anonymous";
|
||||
audioElement.preload = "auto";
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
// Auth token pushed to the worker by the page via a "SET_AUTH_TOKEN" message;
// null until the first such message arrives.
let authToken = null;

// Promote a newly installed worker right away instead of leaving it in the
// waiting state behind an older version.
self.addEventListener("install", function onInstall() {
  self.skipWaiting();
});

// Claim already-open clients on activation; the activate event is held open
// until the claim promise settles.
self.addEventListener("activate", function onActivate(event) {
  const claimed = self.clients.claim();
  event.waitUntil(claimed);
});
|
||||
|
||||
self.addEventListener("message", (event) => {
|
||||
if (event.data && event.data.type === "SET_AUTH_TOKEN") {
|
||||
authToken = event.data.token;
|
||||
@@ -7,8 +15,8 @@ self.addEventListener("message", (event) => {
|
||||
});
|
||||
|
||||
self.addEventListener("fetch", function (event) {
|
||||
// Check if the request is for a media file
|
||||
if (/\/v1\/transcripts\/.*\/audio\/mp3$/.test(event.request.url)) {
|
||||
// Check if the request is for a media file (allow optional query params)
|
||||
if (/\/v1\/transcripts\/.*\/audio\/mp3(\?|$)/.test(event.request.url)) {
|
||||
// Modify the request to add the Authorization header
|
||||
const modifiedHeaders = new Headers(event.request.headers);
|
||||
if (authToken) {
|
||||
|
||||
Reference in New Issue
Block a user