mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-22 07:06:47 +00:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb1beae90d | ||
|
|
1e396ca0ca | ||
|
|
9a2f973a2e | ||
|
|
a9200d35bf | ||
|
|
5646319e96 | ||
|
|
d0472ebf5f | ||
|
|
628a6d735c | ||
|
|
37a1f01850 | ||
|
|
72dca7cacc | ||
|
|
4ae56b730a | ||
|
|
cf6e867cf1 | ||
|
|
183601a121 | ||
|
|
b53c8da398 | ||
|
|
22a50bb94d | ||
|
|
504ca74184 | ||
|
|
a455b8090a | ||
|
|
6b0292d5f0 | ||
|
|
304315daaf | ||
|
|
7845f679c3 | ||
|
|
c155f66982 | ||
|
|
a682846645 | ||
|
|
4235ab4293 | ||
|
|
f5ec2d28cf | ||
|
|
ac46c60a7c | ||
|
|
1d1a520be9 | ||
|
|
9e64d52461 | ||
|
|
0931095f49 | ||
|
|
4d915e2a9f | ||
|
|
045eae8ff2 | ||
|
|
f6cc03286b |
139
.github/workflows/integration_tests.yml
vendored
Normal file
139
.github/workflows/integration_tests.yml
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
name: Integration Tests
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
llm_model:
|
||||
description: "LLM model name (overrides LLM_MODEL secret)"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
integration:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 60
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Start infrastructure services
|
||||
working-directory: server/tests
|
||||
env:
|
||||
LLM_URL: ${{ secrets.LLM_URL }}
|
||||
LLM_MODEL: ${{ inputs.llm_model || secrets.LLM_MODEL }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
docker compose -f docker-compose.integration.yml up -d --build postgres redis garage hatchet mock-daily
|
||||
|
||||
- name: Set up Garage bucket and keys
|
||||
working-directory: server/tests
|
||||
run: |
|
||||
GARAGE="docker compose -f docker-compose.integration.yml exec -T garage /garage"
|
||||
GARAGE_KEY_ID="GK0123456789abcdef01234567" # gitleaks:allow
|
||||
GARAGE_KEY_SECRET="0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" # gitleaks:allow
|
||||
|
||||
echo "Waiting for Garage to be healthy..."
|
||||
for i in $(seq 1 60); do
|
||||
if $GARAGE stats &>/dev/null; then break; fi
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "Setting up Garage..."
|
||||
NODE_ID=$($GARAGE node id -q 2>&1 | tr -d '[:space:]')
|
||||
LAYOUT_STATUS=$($GARAGE layout show 2>&1 || true)
|
||||
if echo "$LAYOUT_STATUS" | grep -q "No nodes"; then
|
||||
$GARAGE layout assign "$NODE_ID" -c 1G -z dc1
|
||||
$GARAGE layout apply --version 1
|
||||
fi
|
||||
|
||||
$GARAGE bucket info reflector-media &>/dev/null || $GARAGE bucket create reflector-media
|
||||
if ! $GARAGE key info reflector-test &>/dev/null; then
|
||||
$GARAGE key import --yes "$GARAGE_KEY_ID" "$GARAGE_KEY_SECRET"
|
||||
$GARAGE key rename "$GARAGE_KEY_ID" reflector-test
|
||||
fi
|
||||
$GARAGE bucket allow reflector-media --read --write --key reflector-test
|
||||
|
||||
- name: Wait for Hatchet and generate API token
|
||||
working-directory: server/tests
|
||||
run: |
|
||||
echo "Waiting for Hatchet to be healthy..."
|
||||
for i in $(seq 1 90); do
|
||||
if docker compose -f docker-compose.integration.yml exec -T hatchet curl -sf http://localhost:8888/api/live &>/dev/null; then
|
||||
echo "Hatchet is ready."
|
||||
break
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "Generating Hatchet API token..."
|
||||
HATCHET_OUTPUT=$(docker compose -f docker-compose.integration.yml exec -T hatchet \
|
||||
/hatchet-admin token create --config /config --name integration-test 2>&1)
|
||||
HATCHET_TOKEN=$(echo "$HATCHET_OUTPUT" | grep -o 'eyJ[A-Za-z0-9_.\-]*')
|
||||
if [ -z "$HATCHET_TOKEN" ]; then
|
||||
echo "ERROR: Failed to extract Hatchet JWT token"
|
||||
exit 1
|
||||
fi
|
||||
echo "HATCHET_CLIENT_TOKEN=${HATCHET_TOKEN}" >> $GITHUB_ENV
|
||||
|
||||
- name: Start backend services
|
||||
working-directory: server/tests
|
||||
env:
|
||||
LLM_URL: ${{ secrets.LLM_URL }}
|
||||
LLM_MODEL: ${{ inputs.llm_model || secrets.LLM_MODEL }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
# Export garage and hatchet credentials for backend services
|
||||
export GARAGE_KEY_ID="${{ env.GARAGE_KEY_ID }}"
|
||||
export GARAGE_KEY_SECRET="${{ env.GARAGE_KEY_SECRET }}"
|
||||
export HATCHET_CLIENT_TOKEN="${{ env.HATCHET_CLIENT_TOKEN }}"
|
||||
|
||||
docker compose -f docker-compose.integration.yml up -d \
|
||||
server worker hatchet-worker-cpu hatchet-worker-llm test-runner
|
||||
|
||||
- name: Wait for server health check
|
||||
working-directory: server/tests
|
||||
run: |
|
||||
echo "Waiting for server to be healthy..."
|
||||
for i in $(seq 1 60); do
|
||||
if docker compose -f docker-compose.integration.yml exec -T test-runner \
|
||||
curl -sf http://server:1250/health &>/dev/null; then
|
||||
echo "Server is ready."
|
||||
break
|
||||
fi
|
||||
sleep 3
|
||||
done
|
||||
|
||||
- name: Run DB migrations
|
||||
working-directory: server/tests
|
||||
run: |
|
||||
docker compose -f docker-compose.integration.yml exec -T server \
|
||||
uv run alembic upgrade head
|
||||
|
||||
- name: Run integration tests
|
||||
working-directory: server/tests
|
||||
run: |
|
||||
docker compose -f docker-compose.integration.yml exec -T test-runner \
|
||||
uv run pytest tests/integration/ -v -x
|
||||
|
||||
- name: Collect logs on failure
|
||||
if: failure()
|
||||
working-directory: server/tests
|
||||
run: |
|
||||
docker compose -f docker-compose.integration.yml logs --tail=500 > integration-logs.txt 2>&1
|
||||
|
||||
- name: Upload logs artifact
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: integration-logs
|
||||
path: server/tests/integration-logs.txt
|
||||
retention-days: 7
|
||||
|
||||
- name: Teardown
|
||||
if: always()
|
||||
working-directory: server/tests
|
||||
run: |
|
||||
docker compose -f docker-compose.integration.yml down -v --remove-orphans
|
||||
36
.github/workflows/selfhost-script.yml
vendored
Normal file
36
.github/workflows/selfhost-script.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
# Validates the self-hosted setup script: runs with --cpu and --garage,
|
||||
# brings up services, runs health checks, then tears down.
|
||||
name: Selfhost script (CPU + Garage)
|
||||
|
||||
on:
|
||||
workflow_dispatch: {}
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request: {}
|
||||
|
||||
jobs:
|
||||
selfhost-cpu-garage:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 25
|
||||
concurrency:
|
||||
group: selfhost-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Run setup-selfhosted.sh (CPU + Garage)
|
||||
run: |
|
||||
./scripts/setup-selfhosted.sh --cpu --garage
|
||||
|
||||
- name: Quick health checks
|
||||
run: |
|
||||
curl -sf http://localhost:1250/health && echo " Server OK"
|
||||
curl -sf http://localhost:3000 > /dev/null && echo " Frontend OK"
|
||||
curl -sf http://localhost:3903/metrics > /dev/null && echo " Garage admin OK"
|
||||
|
||||
- name: Teardown
|
||||
if: always()
|
||||
run: |
|
||||
docker compose -f docker-compose.selfhosted.yml --profile cpu --profile garage down -v --remove-orphans 2>/dev/null || true
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -3,6 +3,7 @@ server/.env
|
||||
server/.env.production
|
||||
.env
|
||||
Caddyfile
|
||||
.env.hatchet
|
||||
server/exportdanswer
|
||||
.vercel
|
||||
.env*.local
|
||||
@@ -20,8 +21,8 @@ CLAUDE.local.md
|
||||
www/.env.development
|
||||
www/.env.production
|
||||
.playwright-mcp
|
||||
docs/pnpm-lock.yaml
|
||||
.secrets
|
||||
opencode.json
|
||||
|
||||
vibedocs/
|
||||
server/tests/integration/logs/
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
exclude: '(^uv\.lock$|pnpm-lock\.yaml$)'
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
|
||||
55
CHANGELOG.md
55
CHANGELOG.md
@@ -1,5 +1,60 @@
|
||||
# Changelog
|
||||
|
||||
## [0.39.0](https://github.com/GreyhavenHQ/reflector/compare/v0.38.2...v0.39.0) (2026-03-18)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* migrate file and live post-processing pipelines from Celery to Hatchet workflow engine ([#911](https://github.com/GreyhavenHQ/reflector/issues/911)) ([37a1f01](https://github.com/GreyhavenHQ/reflector/commit/37a1f0185057dd43b68df2b12bb08d3b18e28d34))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* integration tests runner in CI ([#919](https://github.com/GreyhavenHQ/reflector/issues/919)) ([1e396ca](https://github.com/GreyhavenHQ/reflector/commit/1e396ca0ca91bc9d2645ddfc63a1576469491faa))
|
||||
* latest vulns ([#915](https://github.com/GreyhavenHQ/reflector/issues/915)) ([a9200d3](https://github.com/GreyhavenHQ/reflector/commit/a9200d35bf856f65f24a4f34931ebe0d75ad0382))
|
||||
|
||||
## [0.38.2](https://github.com/GreyhavenHQ/reflector/compare/v0.38.1...v0.38.2) (2026-03-12)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* add auth guards to prevent anonymous access to write endpoints in non-public mode ([#907](https://github.com/GreyhavenHQ/reflector/issues/907)) ([cf6e867](https://github.com/GreyhavenHQ/reflector/commit/cf6e867cf12c42411e5a7412f6ec44eee8351665))
|
||||
* add tests that check some of the issues are already fixed ([#905](https://github.com/GreyhavenHQ/reflector/issues/905)) ([b53c8da](https://github.com/GreyhavenHQ/reflector/commit/b53c8da3981c394bdab08504b45d25f62c35495a))
|
||||
|
||||
## [0.38.1](https://github.com/GreyhavenHQ/reflector/compare/v0.38.0...v0.38.1) (2026-03-06)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* pin hatchet sdk version ([#903](https://github.com/GreyhavenHQ/reflector/issues/903)) ([504ca74](https://github.com/GreyhavenHQ/reflector/commit/504ca74184211eda9020d0b38ba7bd2b55d09991))
|
||||
|
||||
## [0.38.0](https://github.com/GreyhavenHQ/reflector/compare/v0.37.0...v0.38.0) (2026-03-06)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* 3-mode selfhosted refactoring (--gpu, --cpu, --hosted) + audio token auth fallback ([#896](https://github.com/GreyhavenHQ/reflector/issues/896)) ([a682846](https://github.com/GreyhavenHQ/reflector/commit/a6828466456407c808302e9eb8dc4b4f0614dd6f))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* improve hatchet workflow reliability ([#900](https://github.com/GreyhavenHQ/reflector/issues/900)) ([c155f66](https://github.com/GreyhavenHQ/reflector/commit/c155f669825e8e2a6e929821a1ef0bd94237dc11))
|
||||
|
||||
## [0.37.0](https://github.com/GreyhavenHQ/reflector/compare/v0.36.0...v0.37.0) (2026-03-03)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* enable daily co in selfhosted + only schedule tasks when necessary ([#883](https://github.com/GreyhavenHQ/reflector/issues/883)) ([045eae8](https://github.com/GreyhavenHQ/reflector/commit/045eae8ff2014a7b83061045e3c8cb25cce9d60a))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* aws storage construction ([#895](https://github.com/GreyhavenHQ/reflector/issues/895)) ([f5ec2d2](https://github.com/GreyhavenHQ/reflector/commit/f5ec2d28cfa2de9b2b4aeec81966737b740689c2))
|
||||
* remaining dependabot security issues ([#890](https://github.com/GreyhavenHQ/reflector/issues/890)) ([0931095](https://github.com/GreyhavenHQ/reflector/commit/0931095f49e61216e651025ce92be460e6a9df9e))
|
||||
* test selfhosted script ([#892](https://github.com/GreyhavenHQ/reflector/issues/892)) ([4d915e2](https://github.com/GreyhavenHQ/reflector/commit/4d915e2a9fe9f05f31cbd0018d9c2580daf7854f))
|
||||
* upgrade to nextjs 16 ([#888](https://github.com/GreyhavenHQ/reflector/issues/888)) ([f6cc032](https://github.com/GreyhavenHQ/reflector/commit/f6cc03286baf3e3a115afd3b22ae993ad7a4b7e3))
|
||||
|
||||
## [0.35.1](https://github.com/GreyhavenHQ/reflector/compare/v0.35.0...v0.35.1) (2026-02-25)
|
||||
|
||||
|
||||
|
||||
17
CLAUDE.md
17
CLAUDE.md
@@ -6,7 +6,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
|
||||
Reflector is an AI-powered audio transcription and meeting analysis platform with real-time processing capabilities. The system consists of:
|
||||
|
||||
- **Frontend**: Next.js 14 React application (`www/`) with Chakra UI, real-time WebSocket integration
|
||||
- **Frontend**: Next.js 16 React application (`www/`) with Chakra UI, real-time WebSocket integration
|
||||
- **Backend**: Python FastAPI server (`server/`) with async database operations and background processing
|
||||
- **Processing**: GPU-accelerated ML pipeline for transcription, diarization, summarization via Modal.com
|
||||
- **Infrastructure**: Redis, PostgreSQL/SQLite, Celery workers, WebRTC streaming
|
||||
@@ -160,6 +160,21 @@ All endpoints prefixed `/v1/`:
|
||||
- **Frontend**: No current test suite - opportunities for Jest/React Testing Library
|
||||
- **Coverage**: Backend maintains test coverage reports in `htmlcov/`
|
||||
|
||||
### Integration Tests (DO NOT run unless explicitly asked)
|
||||
|
||||
There are end-to-end integration tests in `server/tests/integration/` that spin up the full stack (PostgreSQL, Redis, Hatchet, Garage, mock-daily, server, workers) via Docker Compose and exercise real processing pipelines. These tests are:
|
||||
|
||||
- `test_file_pipeline.py` — File upload → FilePipeline
|
||||
- `test_live_pipeline.py` — WebRTC stream → LivePostPipeline
|
||||
- `test_multitrack_pipeline.py` — Multitrack → DailyMultitrackPipeline
|
||||
|
||||
**Important:**
|
||||
- These tests are **excluded** from normal `uv run pytest` runs via `--ignore=tests/integration` in pyproject.toml.
|
||||
- Do **NOT** run them as part of verification, code review, or general testing unless the user explicitly asks.
|
||||
- They require Docker, external LLM credentials, and HuggingFace token — they cannot run in a regular test environment.
|
||||
- To run locally: `./scripts/run-integration-tests.sh` (requires env vars: `LLM_URL`, `LLM_API_KEY`, `HF_TOKEN`).
|
||||
- In CI: triggered manually via the "Integration Tests" GitHub Actions workflow (`workflow_dispatch`).
|
||||
|
||||
## GPU Processing
|
||||
|
||||
Modal.com integration for scalable ML processing:
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
# Self-hosted production Docker Compose — single file for everything.
|
||||
#
|
||||
# Usage: ./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy
|
||||
# or: docker compose -f docker-compose.selfhosted.yml --profile gpu [--profile ollama-gpu] [--profile garage] [--profile caddy] up -d
|
||||
# Usage: ./scripts/setup-selfhosted.sh <--gpu|--cpu|--hosted> [--ollama-gpu|--ollama-cpu] [--garage] [--caddy]
|
||||
# or: docker compose -f docker-compose.selfhosted.yml [--profile gpu] [--profile ollama-gpu] [--profile garage] [--profile caddy] up -d
|
||||
#
|
||||
# Specialized models (pick ONE — required):
|
||||
# --profile gpu NVIDIA GPU for transcription/diarization/translation
|
||||
# --profile cpu CPU-only for transcription/diarization/translation
|
||||
# ML processing modes (pick ONE — required):
|
||||
# --gpu NVIDIA GPU container for transcription/diarization/translation (profile: gpu)
|
||||
# --cpu In-process CPU processing on server/worker (no ML container needed)
|
||||
# --hosted Remote GPU service URL (no ML container needed)
|
||||
#
|
||||
# Local LLM (optional — for summarization/topics):
|
||||
# --profile ollama-gpu Local Ollama with NVIDIA GPU
|
||||
# --profile ollama-cpu Local Ollama on CPU only
|
||||
#
|
||||
# Daily.co multitrack processing (auto-detected from server/.env):
|
||||
# --profile dailyco Hatchet workflow engine + CPU/LLM workers
|
||||
#
|
||||
# Other optional services:
|
||||
# --profile garage Local S3-compatible storage (Garage)
|
||||
# --profile caddy Reverse proxy with auto-SSL
|
||||
@@ -32,7 +36,7 @@ services:
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "127.0.0.1:1250:1250"
|
||||
- "50000-50100:50000-50100/udp"
|
||||
- "51000-51100:51000-51100/udp"
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
@@ -42,18 +46,14 @@ services:
|
||||
REDIS_HOST: redis
|
||||
CELERY_BROKER_URL: redis://redis:6379/1
|
||||
CELERY_RESULT_BACKEND: redis://redis:6379/1
|
||||
HATCHET_CLIENT_SERVER_URL: ""
|
||||
HATCHET_CLIENT_HOST_PORT: ""
|
||||
# Specialized models via gpu/cpu container (aliased as "transcription")
|
||||
TRANSCRIPT_BACKEND: modal
|
||||
TRANSCRIPT_URL: http://transcription:8000
|
||||
TRANSCRIPT_MODAL_API_KEY: selfhosted
|
||||
DIARIZATION_BACKEND: modal
|
||||
DIARIZATION_URL: http://transcription:8000
|
||||
TRANSLATION_BACKEND: modal
|
||||
TRANSLATE_URL: http://transcription:8000
|
||||
# ML backend config comes from env_file (server/.env), set per-mode by setup script
|
||||
# HF_TOKEN needed for in-process pyannote diarization (--cpu mode)
|
||||
HF_TOKEN: ${HF_TOKEN:-}
|
||||
# WebRTC: fixed UDP port range for ICE candidates (mapped above)
|
||||
WEBRTC_PORT_RANGE: "50000-50100"
|
||||
WEBRTC_PORT_RANGE: "51000-51100"
|
||||
# Hatchet workflow engine (always-on for processing pipelines)
|
||||
HATCHET_CLIENT_SERVER_URL: ${HATCHET_CLIENT_SERVER_URL:-http://hatchet:8888}
|
||||
HATCHET_CLIENT_HOST_PORT: ${HATCHET_CLIENT_HOST_PORT:-hatchet:7077}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
@@ -76,15 +76,11 @@ services:
|
||||
REDIS_HOST: redis
|
||||
CELERY_BROKER_URL: redis://redis:6379/1
|
||||
CELERY_RESULT_BACKEND: redis://redis:6379/1
|
||||
HATCHET_CLIENT_SERVER_URL: ""
|
||||
HATCHET_CLIENT_HOST_PORT: ""
|
||||
TRANSCRIPT_BACKEND: modal
|
||||
TRANSCRIPT_URL: http://transcription:8000
|
||||
TRANSCRIPT_MODAL_API_KEY: selfhosted
|
||||
DIARIZATION_BACKEND: modal
|
||||
DIARIZATION_URL: http://transcription:8000
|
||||
TRANSLATION_BACKEND: modal
|
||||
TRANSLATE_URL: http://transcription:8000
|
||||
# ML backend config comes from env_file (server/.env), set per-mode by setup script
|
||||
HF_TOKEN: ${HF_TOKEN:-}
|
||||
# Hatchet workflow engine (always-on for processing pipelines)
|
||||
HATCHET_CLIENT_SERVER_URL: ${HATCHET_CLIENT_SERVER_URL:-http://hatchet:8888}
|
||||
HATCHET_CLIENT_HOST_PORT: ${HATCHET_CLIENT_HOST_PORT:-hatchet:7077}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
@@ -136,6 +132,8 @@ services:
|
||||
redis:
|
||||
image: redis:7.2-alpine
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "6379:6379"
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 30s
|
||||
@@ -147,12 +145,14 @@ services:
|
||||
postgres:
|
||||
image: postgres:17-alpine
|
||||
restart: unless-stopped
|
||||
command: ["postgres", "-c", "max_connections=200"]
|
||||
environment:
|
||||
POSTGRES_USER: reflector
|
||||
POSTGRES_PASSWORD: reflector
|
||||
POSTGRES_DB: reflector
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./server/docker/init-hatchet-db.sql:/docker-entrypoint-initdb.d/init-hatchet-db.sql:ro
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U reflector"]
|
||||
interval: 30s
|
||||
@@ -161,7 +161,10 @@ services:
|
||||
|
||||
# ===========================================================
|
||||
# Specialized model containers (transcription, diarization, translation)
|
||||
# Both gpu and cpu get alias "transcription" so server config never changes.
|
||||
# Only the gpu profile is activated by the setup script (--gpu mode).
|
||||
# The cpu service definition is kept for manual/standalone use but is
|
||||
# NOT activated by --cpu mode (which uses in-process local backends).
|
||||
# Both services get alias "transcription" so server config never changes.
|
||||
# ===========================================================
|
||||
|
||||
gpu:
|
||||
@@ -305,6 +308,86 @@ services:
|
||||
- web
|
||||
- server
|
||||
|
||||
# ===========================================================
|
||||
# Hatchet workflow engine + workers
|
||||
# Required for all processing pipelines (file, live, Daily.co multitrack).
|
||||
# Always-on — every selfhosted deployment needs Hatchet.
|
||||
# ===========================================================
|
||||
|
||||
hatchet:
|
||||
image: ghcr.io/hatchet-dev/hatchet/hatchet-lite:latest
|
||||
restart: on-failure
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
ports:
|
||||
- "127.0.0.1:8888:8888"
|
||||
- "127.0.0.1:7078:7077"
|
||||
env_file:
|
||||
- ./.env.hatchet
|
||||
environment:
|
||||
DATABASE_URL: "postgresql://reflector:reflector@postgres:5432/hatchet?sslmode=disable&connect_timeout=30"
|
||||
SERVER_AUTH_COOKIE_INSECURE: "t"
|
||||
SERVER_GRPC_BIND_ADDRESS: "0.0.0.0"
|
||||
SERVER_GRPC_INSECURE: "t"
|
||||
SERVER_GRPC_BROADCAST_ADDRESS: hatchet:7077
|
||||
SERVER_GRPC_PORT: "7077"
|
||||
SERVER_AUTH_SET_EMAIL_VERIFIED: "t"
|
||||
SERVER_INTERNAL_CLIENT_INTERNAL_GRPC_BROADCAST_ADDRESS: hatchet:7077
|
||||
volumes:
|
||||
- hatchet_config:/config
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8888/api/live"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
start_period: 30s
|
||||
|
||||
hatchet-worker-cpu:
|
||||
build:
|
||||
context: ./server
|
||||
dockerfile: Dockerfile
|
||||
image: monadicalsas/reflector-backend:latest
|
||||
profiles: [dailyco]
|
||||
restart: unless-stopped
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
ENTRYPOINT: hatchet-worker-cpu
|
||||
DATABASE_URL: postgresql+asyncpg://reflector:reflector@postgres:5432/reflector
|
||||
REDIS_HOST: redis
|
||||
CELERY_BROKER_URL: redis://redis:6379/1
|
||||
CELERY_RESULT_BACKEND: redis://redis:6379/1
|
||||
HATCHET_CLIENT_SERVER_URL: http://hatchet:8888
|
||||
HATCHET_CLIENT_HOST_PORT: hatchet:7077
|
||||
depends_on:
|
||||
hatchet:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- server_data:/app/data
|
||||
|
||||
hatchet-worker-llm:
|
||||
build:
|
||||
context: ./server
|
||||
dockerfile: Dockerfile
|
||||
image: monadicalsas/reflector-backend:latest
|
||||
restart: unless-stopped
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
ENTRYPOINT: hatchet-worker-llm
|
||||
DATABASE_URL: postgresql+asyncpg://reflector:reflector@postgres:5432/reflector
|
||||
REDIS_HOST: redis
|
||||
CELERY_BROKER_URL: redis://redis:6379/1
|
||||
CELERY_RESULT_BACKEND: redis://redis:6379/1
|
||||
HATCHET_CLIENT_SERVER_URL: http://hatchet:8888
|
||||
HATCHET_CLIENT_HOST_PORT: hatchet:7077
|
||||
depends_on:
|
||||
hatchet:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- server_data:/app/data
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
redis_data:
|
||||
@@ -315,6 +398,7 @@ volumes:
|
||||
ollama_data:
|
||||
caddy_data:
|
||||
caddy_config:
|
||||
hatchet_config:
|
||||
|
||||
networks:
|
||||
default:
|
||||
|
||||
@@ -93,6 +93,7 @@ services:
|
||||
environment:
|
||||
NODE_ENV: development
|
||||
SERVER_API_URL: http://host.docker.internal:1250
|
||||
KV_URL: redis://redis:6379
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
depends_on:
|
||||
|
||||
7
docs/.dockerignore
Normal file
7
docs/.dockerignore
Normal file
@@ -0,0 +1,7 @@
|
||||
node_modules
|
||||
build
|
||||
.git
|
||||
.gitignore
|
||||
*.log
|
||||
.DS_Store
|
||||
.env*
|
||||
@@ -1,14 +1,17 @@
|
||||
FROM node:18-alpine AS builder
|
||||
FROM node:20-alpine AS builder
|
||||
WORKDIR /app
|
||||
|
||||
# Install curl for fetching OpenAPI spec
|
||||
RUN apk add --no-cache curl
|
||||
|
||||
# Copy package files
|
||||
COPY package*.json ./
|
||||
# Enable pnpm
|
||||
RUN corepack enable && corepack prepare pnpm@latest --activate
|
||||
|
||||
# Copy package files and lockfile
|
||||
COPY package.json pnpm-lock.yaml* ./
|
||||
|
||||
# Install dependencies
|
||||
RUN npm ci
|
||||
RUN pnpm install --frozen-lockfile
|
||||
|
||||
# Copy source
|
||||
COPY . .
|
||||
@@ -21,7 +24,7 @@ RUN mkdir -p ./static && curl -sf "${OPENAPI_URL}" -o ./static/openapi.json || e
|
||||
RUN sed -i "s/onBrokenLinks: 'throw'/onBrokenLinks: 'warn'/g" docusaurus.config.ts
|
||||
|
||||
# Build static site (skip prebuild hook by calling docusaurus directly)
|
||||
RUN npx docusaurus build
|
||||
RUN pnpm exec docusaurus build
|
||||
|
||||
# Production image
|
||||
FROM nginx:alpine
|
||||
|
||||
@@ -5,13 +5,13 @@ This website is built using [Docusaurus](https://docusaurus.io/), a modern stati
|
||||
### Installation
|
||||
|
||||
```
|
||||
$ yarn
|
||||
$ pnpm install
|
||||
```
|
||||
|
||||
### Local Development
|
||||
|
||||
```
|
||||
$ yarn start
|
||||
$ pnpm start
|
||||
```
|
||||
|
||||
This command starts a local development server and opens up a browser window. Most changes are reflected live without having to restart the server.
|
||||
@@ -19,7 +19,7 @@ This command starts a local development server and opens up a browser window. Mo
|
||||
### Build
|
||||
|
||||
```
|
||||
$ yarn build
|
||||
$ pnpm build
|
||||
```
|
||||
|
||||
This command generates static content into the `build` directory and can be served using any static contents hosting service.
|
||||
@@ -29,13 +29,13 @@ This command generates static content into the `build` directory and can be serv
|
||||
Using SSH:
|
||||
|
||||
```
|
||||
$ USE_SSH=true yarn deploy
|
||||
$ USE_SSH=true pnpm deploy
|
||||
```
|
||||
|
||||
Not using SSH:
|
||||
|
||||
```
|
||||
$ GIT_USER=<Your GitHub username> yarn deploy
|
||||
$ GIT_USER=<Your GitHub username> pnpm deploy
|
||||
```
|
||||
|
||||
If you are using GitHub pages for hosting, this command is a convenient way to build the website and push to the `gh-pages` branch.
|
||||
|
||||
@@ -254,15 +254,15 @@ Reflector can run completely offline:
|
||||
Control where each step happens:
|
||||
|
||||
```yaml
|
||||
# All local processing
|
||||
TRANSCRIPT_BACKEND=local
|
||||
DIARIZATION_BACKEND=local
|
||||
TRANSLATION_BACKEND=local
|
||||
# All in-process processing
|
||||
TRANSCRIPT_BACKEND=whisper
|
||||
DIARIZATION_BACKEND=pyannote
|
||||
TRANSLATION_BACKEND=marian
|
||||
|
||||
# Hybrid approach
|
||||
TRANSCRIPT_BACKEND=modal # Fast GPU processing
|
||||
DIARIZATION_BACKEND=local # Sensitive speaker data
|
||||
TRANSLATION_BACKEND=modal # Non-sensitive translation
|
||||
TRANSCRIPT_BACKEND=modal # Fast GPU processing
|
||||
DIARIZATION_BACKEND=pyannote # Sensitive speaker data
|
||||
TRANSLATION_BACKEND=modal # Non-sensitive translation
|
||||
```
|
||||
|
||||
### Storage Options
|
||||
|
||||
@@ -11,7 +11,7 @@ Reflector is built as a modern, scalable, microservices-based application design
|
||||
|
||||
### Frontend Application
|
||||
|
||||
The user interface is built with **Next.js 15** using the App Router pattern, providing:
|
||||
The user interface is built with **Next.js 16** using the App Router pattern, providing:
|
||||
|
||||
- Server-side rendering for optimal performance
|
||||
- Real-time WebSocket connections for live transcription
|
||||
|
||||
@@ -36,14 +36,15 @@ This creates `docs/static/openapi.json` (should be ~70KB) which will be copied d
|
||||
The Dockerfile is already in `docs/Dockerfile`:
|
||||
|
||||
```dockerfile
|
||||
FROM node:18-alpine AS builder
|
||||
FROM node:20-alpine AS builder
|
||||
WORKDIR /app
|
||||
|
||||
# Copy package files
|
||||
COPY package*.json ./
|
||||
# Enable pnpm and copy package files + lockfile
|
||||
RUN corepack enable && corepack prepare pnpm@latest --activate
|
||||
COPY package.json pnpm-lock.yaml* ./
|
||||
|
||||
# Inshall dependencies
|
||||
RUN npm ci
|
||||
# Install dependencies
|
||||
RUN pnpm install --frozen-lockfile
|
||||
|
||||
# Copy source (includes static/openapi.json if pre-fetched)
|
||||
COPY . .
|
||||
@@ -52,7 +53,7 @@ COPY . .
|
||||
RUN sed -i "s/onBrokenLinks: 'throw'/onBrokenLinks: 'warn'/g" docusaurus.config.ts
|
||||
|
||||
# Build static site
|
||||
RUN npx docusaurus build
|
||||
RUN pnpm exec docusaurus build
|
||||
|
||||
FROM nginx:alpine
|
||||
COPY --from=builder /app/build /usr/share/nginx/html
|
||||
|
||||
@@ -46,7 +46,7 @@ Reflector consists of three main components:
|
||||
|
||||
Ready to deploy Reflector? Head over to our [Installation Guide](./installation/overview) to set up your own instance.
|
||||
|
||||
For a quick overview of how Reflector processes audio, check out our [Pipeline Documentation](./pipelines/overview).
|
||||
For a quick overview of how Reflector processes audio, check out our [Pipeline Documentation](./concepts/pipeline).
|
||||
|
||||
## Open Source
|
||||
|
||||
|
||||
@@ -124,11 +124,11 @@ const config: Config = {
|
||||
items: [
|
||||
{
|
||||
label: 'Architecture',
|
||||
to: '/docs/reference/architecture/overview',
|
||||
to: '/docs/concepts/overview',
|
||||
},
|
||||
{
|
||||
label: 'Pipelines',
|
||||
to: '/docs/pipelines/overview',
|
||||
to: '/docs/concepts/pipeline',
|
||||
},
|
||||
{
|
||||
label: 'Roadmap',
|
||||
|
||||
23526
docs/package-lock.json
generated
23526
docs/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -14,26 +14,26 @@
|
||||
"write-heading-ids": "docusaurus write-heading-ids",
|
||||
"typecheck": "tsc",
|
||||
"fetch-openapi": "./scripts/fetch-openapi.sh",
|
||||
"gen-api-docs": "npm run fetch-openapi && docusaurus gen-api-docs reflector",
|
||||
"prebuild": "npm run fetch-openapi"
|
||||
"gen-api-docs": "pnpm run fetch-openapi && docusaurus gen-api-docs reflector",
|
||||
"prebuild": "pnpm run fetch-openapi"
|
||||
},
|
||||
"dependencies": {
|
||||
"@docusaurus/core": "3.6.3",
|
||||
"@docusaurus/preset-classic": "3.6.3",
|
||||
"@mdx-js/react": "^3.0.0",
|
||||
"clsx": "^2.0.0",
|
||||
"docusaurus-plugin-openapi-docs": "^4.5.1",
|
||||
"docusaurus-theme-openapi-docs": "^4.5.1",
|
||||
"@docusaurus/theme-mermaid": "3.6.3",
|
||||
"prism-react-renderer": "^2.3.0",
|
||||
"react": "^18.0.0",
|
||||
"react-dom": "^18.0.0"
|
||||
"@docusaurus/core": "3.9.2",
|
||||
"@docusaurus/preset-classic": "3.9.2",
|
||||
"@docusaurus/theme-mermaid": "3.9.2",
|
||||
"@mdx-js/react": "^3.1.1",
|
||||
"clsx": "^2.1.1",
|
||||
"docusaurus-plugin-openapi-docs": "^4.7.1",
|
||||
"docusaurus-theme-openapi-docs": "^4.7.1",
|
||||
"prism-react-renderer": "^2.4.1",
|
||||
"react": "^19.2.4",
|
||||
"react-dom": "^19.2.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@docusaurus/module-type-aliases": "3.6.3",
|
||||
"@docusaurus/tsconfig": "3.6.3",
|
||||
"@docusaurus/types": "3.6.3",
|
||||
"typescript": "~5.6.2"
|
||||
"@docusaurus/module-type-aliases": "3.9.2",
|
||||
"@docusaurus/tsconfig": "3.9.2",
|
||||
"@docusaurus/types": "3.9.2",
|
||||
"typescript": "~5.9.3"
|
||||
},
|
||||
"browserslist": {
|
||||
"production": [
|
||||
@@ -49,5 +49,16 @@
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0"
|
||||
},
|
||||
"pnpm": {
|
||||
"overrides": {
|
||||
"minimatch@<3.1.4": "3.1.5",
|
||||
"minimatch@>=5.0.0 <5.1.8": "5.1.8",
|
||||
"minimatch@>=9.0.0 <9.0.7": "9.0.7",
|
||||
"lodash@<4.17.23": "4.17.23",
|
||||
"js-yaml@<4.1.1": "4.1.1",
|
||||
"gray-matter": "github:jonschlinkert/gray-matter#234163e",
|
||||
"serialize-javascript": "7.0.4"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
13976
docs/pnpm-lock.yaml
generated
Normal file
13976
docs/pnpm-lock.yaml
generated
Normal file
File diff suppressed because it is too large
Load Diff
4151
docs/static/openapi.json
vendored
4151
docs/static/openapi.json
vendored
File diff suppressed because it is too large
Load Diff
@@ -53,9 +53,12 @@ cd reflector
|
||||
# Same but without a domain (self-signed cert, access via IP):
|
||||
./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy
|
||||
|
||||
# CPU-only (same, but slower):
|
||||
# CPU-only (in-process ML, no GPU container):
|
||||
./scripts/setup-selfhosted.sh --cpu --ollama-cpu --garage --caddy
|
||||
|
||||
# Remote GPU service (your own hosted GPU, no local ML container):
|
||||
./scripts/setup-selfhosted.sh --hosted --garage --caddy
|
||||
|
||||
# With password authentication (single admin user):
|
||||
./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy --password mysecretpass
|
||||
|
||||
@@ -65,14 +68,15 @@ cd reflector
|
||||
|
||||
That's it. The script generates env files, secrets, starts all containers, waits for health checks, and prints the URL.
|
||||
|
||||
## Specialized Models (Required)
|
||||
## ML Processing Modes (Required)
|
||||
|
||||
Pick `--gpu` or `--cpu`. This determines how **transcription, diarization, and translation** run:
|
||||
Pick `--gpu`, `--cpu`, or `--hosted`. This determines how **transcription, diarization, translation, and audio padding** run:
|
||||
|
||||
| Flag | What it does | Requires |
|
||||
|------|-------------|----------|
|
||||
| `--gpu` | NVIDIA GPU acceleration for ML models | NVIDIA GPU + drivers + `nvidia-container-toolkit` |
|
||||
| `--cpu` | CPU-only (slower but works without GPU) | 8+ cores, 32GB+ RAM recommended |
|
||||
| `--gpu` | NVIDIA GPU container for ML models | NVIDIA GPU + drivers + `nvidia-container-toolkit` |
|
||||
| `--cpu` | In-process CPU processing on server/worker (no ML container) | 8+ cores, 16GB+ RAM (32GB recommended for large files) |
|
||||
| `--hosted` | Remote GPU service URL (no local ML container) | A running GPU service instance (e.g. `gpu/self_hosted/`) |
|
||||
|
||||
## Local LLM (Optional)
|
||||
|
||||
@@ -130,9 +134,11 @@ Browse all available models at https://ollama.com/library.
|
||||
|
||||
- **`--gpu --ollama-gpu`**: Best for servers with NVIDIA GPU. Fully self-contained, no external API keys needed.
|
||||
- **`--cpu --ollama-cpu`**: No GPU available but want everything self-contained. Slower but works.
|
||||
- **`--hosted --ollama-cpu`**: Remote GPU for ML, local CPU for LLM. Great when you have a separate GPU server.
|
||||
- **`--gpu --ollama-cpu`**: GPU for transcription, CPU for LLM. Saves GPU VRAM for ML models.
|
||||
- **`--gpu`**: Have NVIDIA GPU but prefer a cloud LLM (faster/better summaries with GPT-4, Claude, etc.).
|
||||
- **`--cpu`**: No GPU, prefer cloud LLM. Slowest transcription but best summary quality.
|
||||
- **`--hosted`**: Remote GPU, cloud LLM. No local ML at all.
|
||||
|
||||
## Other Optional Flags
|
||||
|
||||
@@ -160,8 +166,9 @@ Without `--caddy` or `--domain`, no ports are exposed. Point your own reverse pr
|
||||
4. **Generate `www/.env`** — Auto-detects server IP, sets URLs
|
||||
5. **Storage setup** — Either initializes Garage (bucket, keys, permissions) or prompts for external S3 credentials
|
||||
6. **Caddyfile** — Generates domain-specific (Let's Encrypt) or IP-specific (self-signed) configuration
|
||||
7. **Build & start** — Always builds GPU/CPU model image from source. With `--build`, also builds backend and frontend from source; otherwise pulls prebuilt images from the registry
|
||||
8. **Health checks** — Waits for each service, pulls Ollama model if needed, warns about missing LLM config
|
||||
7. **Build & start** — For `--gpu`, builds the GPU model image from source. For `--cpu` and `--hosted`, no ML container is built. With `--build`, also builds backend and frontend from source; otherwise pulls prebuilt images from the registry
|
||||
8. **Auto-detects video platforms** — If `DAILY_API_KEY` is found in `server/.env`, generates `.env.hatchet` (dashboard URL/cookie config), starts Hatchet workflow engine, and generates an API token. If any video platform is configured, enables the Rooms feature
|
||||
9. **Health checks** — Waits for each service, pulls Ollama model if needed, warns about missing LLM config
|
||||
|
||||
> For a deeper dive into each step, see [How the Self-Hosted Setup Works](selfhosted-architecture.md).
|
||||
|
||||
@@ -180,12 +187,23 @@ Without `--caddy` or `--domain`, no ports are exposed. Point your own reverse pr
|
||||
| `ADMIN_PASSWORD_HASH` | PBKDF2 hash for password auth | *(unset)* |
|
||||
| `WEBRTC_HOST` | IP advertised in WebRTC ICE candidates | Auto-detected (server IP) |
|
||||
| `TRANSCRIPT_URL` | Specialized model endpoint | `http://transcription:8000` |
|
||||
| `PADDING_BACKEND` | Audio padding backend (`pyav` or `modal`) | `modal` (selfhosted), `pyav` (default) |
|
||||
| `PADDING_URL` | Audio padding endpoint (when `PADDING_BACKEND=modal`) | `http://transcription:8000` |
|
||||
| `LLM_URL` | OpenAI-compatible LLM endpoint | Auto-set for Ollama modes |
|
||||
| `LLM_API_KEY` | LLM API key | `not-needed` for Ollama |
|
||||
| `LLM_MODEL` | LLM model name | `qwen2.5:14b` for Ollama (override with `--llm-model`) |
|
||||
| `CELERY_BEAT_POLL_INTERVAL` | Override all worker polling intervals (seconds). `0` = use individual defaults | `300` (selfhosted), `0` (other) |
|
||||
| `TRANSCRIPT_STORAGE_BACKEND` | Storage backend | `aws` |
|
||||
| `TRANSCRIPT_STORAGE_AWS_*` | S3 credentials | Auto-set for Garage |
|
||||
| `DAILY_API_KEY` | Daily.co API key (enables live rooms) | *(unset)* |
|
||||
| `DAILY_SUBDOMAIN` | Daily.co subdomain | *(unset)* |
|
||||
| `DAILYCO_STORAGE_AWS_ACCESS_KEY_ID` | AWS access key for reading Daily's recording bucket | *(unset)* |
|
||||
| `DAILYCO_STORAGE_AWS_SECRET_ACCESS_KEY` | AWS secret key for reading Daily's recording bucket | *(unset)* |
|
||||
| `HATCHET_CLIENT_TOKEN` | Hatchet API token (auto-generated) | *(unset)* |
|
||||
| `HATCHET_CLIENT_SERVER_URL` | Hatchet server URL | Auto-set when Daily.co configured |
|
||||
| `HATCHET_CLIENT_HOST_PORT` | Hatchet gRPC address | Auto-set when Daily.co configured |
|
||||
| `TRANSCRIPT_FILE_TIMEOUT` | HTTP timeout (seconds) for file transcription requests | `600` (`3600` in CPU mode) |
|
||||
| `DIARIZATION_FILE_TIMEOUT` | HTTP timeout (seconds) for file diarization requests | `600` (`3600` in CPU mode) |
|
||||
|
||||
### Frontend Environment (`www/.env`)
|
||||
|
||||
@@ -197,6 +215,7 @@ Without `--caddy` or `--domain`, no ports are exposed. Point your own reverse pr
|
||||
| `NEXTAUTH_SECRET` | Auth secret | Auto-generated |
|
||||
| `FEATURE_REQUIRE_LOGIN` | Require authentication | `false` |
|
||||
| `AUTH_PROVIDER` | Auth provider (`authentik` or `credentials`) | *(unset)* |
|
||||
| `FEATURE_ROOMS` | Enable meeting rooms UI | Auto-set when video platform configured |
|
||||
|
||||
## Storage Options
|
||||
|
||||
@@ -353,6 +372,87 @@ By default, authentication is disabled (`AUTH_BACKEND=none`, `FEATURE_REQUIRE_LO
|
||||
```
|
||||
5. Restart: `docker compose -f docker-compose.selfhosted.yml down && ./scripts/setup-selfhosted.sh <same-flags>`
|
||||
|
||||
## Enabling Daily.co Live Rooms
|
||||
|
||||
Daily.co enables real-time meeting rooms with automatic recording and per-participant
|
||||
audio tracks for improved diarization. When configured, the setup script automatically
|
||||
starts the Hatchet workflow engine for multitrack recording processing.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- **Daily.co account** — Sign up at https://www.daily.co/
|
||||
- **API key** — From Daily.co Dashboard → Developers → API Keys
|
||||
- **Subdomain** — The `yourname` part of `yourname.daily.co`
|
||||
- **AWS S3 bucket** — For Daily.co to store recordings. See [Daily.co recording storage docs](https://docs.daily.co/guides/products/live-streaming-recording/storing-recordings-in-a-custom-s3-bucket)
|
||||
- **IAM role ARN** — An AWS IAM role that Daily.co assumes to write recordings to your bucket
|
||||
|
||||
### Setup
|
||||
|
||||
1. Configure Daily.co env vars in `server/.env` **before** running the setup script:
|
||||
|
||||
```env
|
||||
DAILY_API_KEY=your-daily-api-key
|
||||
DAILY_SUBDOMAIN=your-subdomain
|
||||
DEFAULT_VIDEO_PLATFORM=daily
|
||||
DAILYCO_STORAGE_AWS_BUCKET_NAME=your-recordings-bucket
|
||||
DAILYCO_STORAGE_AWS_REGION=us-east-1
|
||||
DAILYCO_STORAGE_AWS_ROLE_ARN=arn:aws:iam::123456789:role/DailyCoAccess
|
||||
# Worker credentials for reading/deleting recordings from Daily's S3 bucket.
|
||||
# Required when transcript storage is separate from Daily's bucket
|
||||
# (e.g., selfhosted with Garage or a different S3 account).
|
||||
DAILYCO_STORAGE_AWS_ACCESS_KEY_ID=your-aws-access-key
|
||||
DAILYCO_STORAGE_AWS_SECRET_ACCESS_KEY=your-aws-secret-key
|
||||
```
|
||||
|
||||
> **Important:** The `DAILYCO_STORAGE_AWS_ACCESS_KEY_ID` and `SECRET_ACCESS_KEY` are AWS IAM
|
||||
> credentials that allow the Hatchet workers to **read and delete** recording files from Daily's
|
||||
> S3 bucket. These are separate from the `ROLE_ARN` (which Daily's API uses to *write* recordings).
|
||||
> Without these keys, multitrack processing will fail with 404 errors when transcript storage
|
||||
> (e.g., Garage) uses different credentials than the Daily recording bucket.
|
||||
|
||||
2. Run the setup script as normal:
|
||||
|
||||
```bash
|
||||
./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy
|
||||
```
|
||||
|
||||
The script detects `DAILY_API_KEY` and automatically:
|
||||
- Starts the Hatchet workflow engine (`hatchet` container)
|
||||
- Starts Hatchet CPU and LLM workers (`hatchet-worker-cpu`, `hatchet-worker-llm`)
|
||||
- Generates a `HATCHET_CLIENT_TOKEN` and saves it to `server/.env`
|
||||
- Sets `HATCHET_CLIENT_SERVER_URL` and `HATCHET_CLIENT_HOST_PORT`
|
||||
- Enables `FEATURE_ROOMS=true` in `www/.env`
|
||||
- Registers Daily.co beat tasks (recording polling, presence reconciliation)
|
||||
|
||||
3. (Optional) For faster recording discovery, configure a Daily.co webhook:
|
||||
- In the Daily.co dashboard, add a webhook pointing to `https://your-domain/v1/daily/webhook`
|
||||
- Set `DAILY_WEBHOOK_SECRET` in `server/.env` (the signing secret from Daily.co)
|
||||
- Without webhooks, the system polls the Daily.co API every 15 seconds
|
||||
|
||||
### What Gets Started
|
||||
|
||||
| Service | Purpose |
|
||||
|---------|---------|
|
||||
| `hatchet` | Workflow orchestration engine (manages multitrack processing pipelines) |
|
||||
| `hatchet-worker-cpu` | CPU-heavy audio tasks (track mixdown, waveform generation) |
|
||||
| `hatchet-worker-llm` | Transcription, LLM inference (summaries, topics, titles), orchestration |
|
||||
|
||||
### Hatchet Dashboard
|
||||
|
||||
The Hatchet workflow engine includes a web dashboard for monitoring workflow runs and debugging. The setup script auto-generates `.env.hatchet` at the project root with the dashboard URL and cookie domain configuration. This file is git-ignored.
|
||||
|
||||
- **With Caddy**: Accessible at `https://your-domain:8888` (TLS via Caddy)
|
||||
- **Without Caddy**: Accessible at `http://your-ip:8888` (direct port mapping)
|
||||
|
||||
### Conditional Beat Tasks
|
||||
|
||||
Beat tasks are registered based on which services are configured:
|
||||
|
||||
- **Whereby tasks** (only if `WHEREBY_API_KEY` or `AWS_PROCESS_RECORDING_QUEUE_URL`): `process_messages`, `reprocess_failed_recordings`
|
||||
- **Daily.co tasks** (only if `DAILY_API_KEY`): `poll_daily_recordings`, `trigger_daily_reconciliation`, `reprocess_failed_daily_recordings`
|
||||
- **Platform tasks** (if any video platform configured): `process_meetings`, `sync_all_ics_calendars`, `create_upcoming_meetings`
|
||||
- **Always registered**: `cleanup_old_public_data` (if `PUBLIC_MODE`), `healthcheck_ping` (if `HEALTHCHECK_URL`)
|
||||
|
||||
## Enabling Real Domain with Let's Encrypt
|
||||
|
||||
By default, Caddy uses self-signed certificates. For a real domain:
|
||||
@@ -446,6 +546,15 @@ docker compose -f docker-compose.selfhosted.yml logs server --tail 50
|
||||
For self-signed certs, your browser will warn. Click Advanced > Proceed.
|
||||
For Let's Encrypt, ensure ports 80/443 are open and DNS is pointed correctly.
|
||||
|
||||
### File processing timeout on CPU
|
||||
CPU transcription and diarization are significantly slower than GPU. A 20-minute audio file can take 20-40 minutes to process on CPU. The setup script automatically sets `TRANSCRIPT_FILE_TIMEOUT=3600` and `DIARIZATION_FILE_TIMEOUT=3600` (1 hour) for `--cpu` mode. If you still hit timeouts with very long files, increase these values in `server/.env`:
|
||||
```bash
|
||||
# Increase to 2 hours for files over 1 hour
|
||||
TRANSCRIPT_FILE_TIMEOUT=7200
|
||||
DIARIZATION_FILE_TIMEOUT=7200
|
||||
```
|
||||
Then restart the worker: `docker compose -f docker-compose.selfhosted.yml restart worker`
|
||||
|
||||
### Summaries/topics not generating
|
||||
Check LLM configuration:
|
||||
```bash
|
||||
@@ -501,22 +610,29 @@ The setup script is idempotent — it won't overwrite existing secrets or env va
|
||||
│ │ │
|
||||
v v v
|
||||
┌───────────┐ ┌─────────┐ ┌─────────┐
|
||||
│transcription│ │postgres │ │ redis │
|
||||
│(gpu/cpu) │ │ :5432 │ │ :6379 │
|
||||
│ :8000 │ └─────────┘ └─────────┘
|
||||
└───────────┘
|
||||
│ ML models │ │postgres │ │ redis │
|
||||
│ (varies) │ │ :5432 │ │ :6379 │
|
||||
└───────────┘ └─────────┘ └─────────┘
|
||||
│
|
||||
┌─────┴─────┐ ┌─────────┐
|
||||
│ ollama │ │ garage │
|
||||
│ (optional)│ │(optional│
|
||||
│ :11435 │ │ S3) │
|
||||
└───────────┘ └─────────┘
|
||||
|
||||
┌───────────────────────────────────┐
|
||||
│ Hatchet (optional — Daily.co) │
|
||||
│ ┌─────────┐ ┌───────────────┐ │
|
||||
│ │ hatchet │ │ hatchet-worker│ │
|
||||
│ │ :8888 │──│ -cpu / -llm │ │
|
||||
│ └─────────┘ └───────────────┘ │
|
||||
└───────────────────────────────────┘
|
||||
|
||||
ML models box varies by mode:
|
||||
--gpu: Local GPU container (transcription:8000)
|
||||
--cpu: In-process on server/worker (no container)
|
||||
--hosted: Remote GPU service (user URL)
|
||||
```
|
||||
|
||||
All services communicate over Docker's internal network. Only Caddy (if enabled) exposes ports to the internet.
|
||||
All services communicate over Docker's internal network. Only Caddy (if enabled) exposes ports to the internet. Hatchet services are only started when `DAILY_API_KEY` is configured.
|
||||
|
||||
## Future Plans for the Self-Hosted Script
|
||||
|
||||
The following features are supported by Reflector but are **not yet integrated into the self-hosted setup script** and require manual configuration:
|
||||
|
||||
- **Daily.co live rooms with multitrack processing**: Daily.co enables real-time meeting rooms with automatic recording and per-participant audio tracks for improved diarization. Requires a Daily.co account, API key, and an AWS S3 bucket for recording storage. Currently not automated in the script because the worker orchestration (hatchet) is not yet supported in the selfhosted compose setup.
|
||||
|
||||
@@ -3,6 +3,7 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routers.diarization import router as diarization_router
|
||||
from .routers.padding import router as padding_router
|
||||
from .routers.transcription import router as transcription_router
|
||||
from .routers.translation import router as translation_router
|
||||
from .services.transcriber import WhisperService
|
||||
@@ -27,4 +28,5 @@ def create_app() -> FastAPI:
|
||||
app.include_router(transcription_router)
|
||||
app.include_router(translation_router)
|
||||
app.include_router(diarization_router)
|
||||
app.include_router(padding_router)
|
||||
return app
|
||||
|
||||
199
gpu/self_hosted/app/routers/padding.py
Normal file
199
gpu/self_hosted/app/routers/padding.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Audio padding endpoint for selfhosted GPU service.
|
||||
|
||||
CPU-intensive audio padding service for adding silence to audio tracks.
|
||||
Uses PyAV filter graph (adelay) for precise track synchronization.
|
||||
|
||||
IMPORTANT: This padding logic is duplicated from server/reflector/utils/audio_padding.py
|
||||
for deployment isolation (self_hosted can't import from server/reflector/). If you modify
|
||||
the PyAV filter graph or padding algorithm, you MUST update both:
|
||||
- gpu/self_hosted/app/routers/padding.py (this file)
|
||||
- server/reflector/utils/audio_padding.py
|
||||
|
||||
Constants duplicated from server/reflector/utils/audio_constants.py for same reason.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
from fractions import Fraction
|
||||
|
||||
import av
|
||||
import requests
|
||||
from av.audio.resampler import AudioResampler
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth import apikey_auth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["padding"])
|
||||
|
||||
# ref B0F71CE8-FC59-4AA5-8414-DAFB836DB711
|
||||
OPUS_STANDARD_SAMPLE_RATE = 48000
|
||||
OPUS_DEFAULT_BIT_RATE = 128000
|
||||
|
||||
S3_TIMEOUT = 60
|
||||
|
||||
|
||||
class PaddingRequest(BaseModel):
|
||||
track_url: str
|
||||
output_url: str
|
||||
start_time_seconds: float
|
||||
track_index: int
|
||||
|
||||
|
||||
class PaddingResponse(BaseModel):
|
||||
size: int
|
||||
cancelled: bool = False
|
||||
|
||||
|
||||
@router.post("/pad", dependencies=[Depends(apikey_auth)], response_model=PaddingResponse)
|
||||
def pad_track(req: PaddingRequest):
|
||||
"""Pad audio track with silence using PyAV adelay filter graph."""
|
||||
if not req.track_url:
|
||||
raise HTTPException(status_code=400, detail="track_url cannot be empty")
|
||||
if not req.output_url:
|
||||
raise HTTPException(status_code=400, detail="output_url cannot be empty")
|
||||
if req.start_time_seconds <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"start_time_seconds must be positive, got {req.start_time_seconds}",
|
||||
)
|
||||
if req.start_time_seconds > 18000:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="start_time_seconds exceeds maximum 18000s (5 hours)",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Padding request: track %d, delay=%.3fs", req.track_index, req.start_time_seconds
|
||||
)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
input_path = None
|
||||
output_path = None
|
||||
|
||||
try:
|
||||
# Download source audio
|
||||
logger.info("Downloading track for padding")
|
||||
response = requests.get(req.track_url, stream=True, timeout=S3_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
|
||||
input_path = os.path.join(temp_dir, "track.webm")
|
||||
total_bytes = 0
|
||||
with open(input_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
total_bytes += len(chunk)
|
||||
logger.info("Track downloaded: %d bytes", total_bytes)
|
||||
|
||||
# Apply padding using PyAV
|
||||
output_path = os.path.join(temp_dir, "padded.webm")
|
||||
delay_ms = math.floor(req.start_time_seconds * 1000)
|
||||
logger.info("Padding track %d with %dms delay using PyAV", req.track_index, delay_ms)
|
||||
|
||||
in_container = av.open(input_path)
|
||||
in_stream = next((s for s in in_container.streams if s.type == "audio"), None)
|
||||
if in_stream is None:
|
||||
in_container.close()
|
||||
raise HTTPException(status_code=400, detail="No audio stream in input")
|
||||
|
||||
with av.open(output_path, "w", format="webm") as out_container:
|
||||
out_stream = out_container.add_stream("libopus", rate=OPUS_STANDARD_SAMPLE_RATE)
|
||||
out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE
|
||||
graph = av.filter.Graph()
|
||||
|
||||
abuf_args = (
|
||||
f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:"
|
||||
f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:"
|
||||
f"sample_fmt=s16:"
|
||||
f"channel_layout=stereo"
|
||||
)
|
||||
src = graph.add("abuffer", args=abuf_args, name="src")
|
||||
aresample_f = graph.add("aresample", args="async=1", name="ares")
|
||||
delays_arg = f"{delay_ms}|{delay_ms}"
|
||||
adelay_f = graph.add(
|
||||
"adelay", args=f"delays={delays_arg}:all=1", name="delay"
|
||||
)
|
||||
sink = graph.add("abuffersink", name="sink")
|
||||
|
||||
src.link_to(aresample_f)
|
||||
aresample_f.link_to(adelay_f)
|
||||
adelay_f.link_to(sink)
|
||||
graph.configure()
|
||||
|
||||
resampler = AudioResampler(
|
||||
format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE
|
||||
)
|
||||
|
||||
for frame in in_container.decode(in_stream):
|
||||
out_frames = resampler.resample(frame) or []
|
||||
for rframe in out_frames:
|
||||
rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE
|
||||
rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
|
||||
src.push(rframe)
|
||||
|
||||
while True:
|
||||
try:
|
||||
f_out = sink.pull()
|
||||
except Exception:
|
||||
break
|
||||
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
|
||||
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
|
||||
for packet in out_stream.encode(f_out):
|
||||
out_container.mux(packet)
|
||||
|
||||
# Flush filter graph
|
||||
src.push(None)
|
||||
while True:
|
||||
try:
|
||||
f_out = sink.pull()
|
||||
except Exception:
|
||||
break
|
||||
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
|
||||
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
|
||||
for packet in out_stream.encode(f_out):
|
||||
out_container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in out_stream.encode(None):
|
||||
out_container.mux(packet)
|
||||
|
||||
in_container.close()
|
||||
|
||||
file_size = os.path.getsize(output_path)
|
||||
logger.info("Padding complete: %d bytes", file_size)
|
||||
|
||||
# Upload padded track
|
||||
logger.info("Uploading padded track to S3")
|
||||
with open(output_path, "rb") as f:
|
||||
upload_response = requests.put(req.output_url, data=f, timeout=S3_TIMEOUT)
|
||||
upload_response.raise_for_status()
|
||||
logger.info("Upload complete: %d bytes", file_size)
|
||||
|
||||
return PaddingResponse(size=file_size)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Padding failed for track %d: %s", req.track_index, e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Padding failed: {e}") from e
|
||||
finally:
|
||||
if input_path and os.path.exists(input_path):
|
||||
try:
|
||||
os.unlink(input_path)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup input file: %s", e)
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.unlink(output_path)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup output file: %s", e)
|
||||
try:
|
||||
os.rmdir(temp_dir)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup temp directory: %s", e)
|
||||
@@ -11,9 +11,11 @@ dependencies = [
|
||||
"faster-whisper>=1.1.0",
|
||||
"librosa==0.10.1",
|
||||
"numpy<2",
|
||||
"silero-vad==5.1.0",
|
||||
"silero-vad==5.1.2",
|
||||
"transformers>=4.35.0",
|
||||
"sentencepiece",
|
||||
"pyannote.audio==3.1.0",
|
||||
"pyannote.audio==3.4.0",
|
||||
"pytorch-lightning<2.6",
|
||||
"torchaudio>=2.3.0",
|
||||
"av>=13.1.0",
|
||||
]
|
||||
|
||||
23
gpu/self_hosted/uv.lock
generated
23
gpu/self_hosted/uv.lock
generated
@@ -726,7 +726,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/16/035dcfcc48715ccd345f3a93183267167cdd162ad123cd93067d86f27ce4/greenlet-3.2.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f28588772bb5fb869a8eb331374ec06f24a83a9c25bfa1f38b6993afe9c1e968", size = 655185, upload-time = "2025-08-07T13:45:27.624Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/da/0386695eef69ffae1ad726881571dfe28b41970173947e7c558d9998de0f/greenlet-3.2.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5c9320971821a7cb77cfab8d956fa8e39cd07ca44b6070db358ceb7f8797c8c9", size = 649926, upload-time = "2025-08-07T13:53:15.251Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/68/88/69bf19fd4dc19981928ceacbc5fd4bb6bc2215d53199e367832e98d1d8fe/greenlet-3.2.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c60a6d84229b271d44b70fb6e5fa23781abb5d742af7b808ae3f6efd7c9c60f6", size = 651839, upload-time = "2025-08-07T13:18:30.281Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" },
|
||||
@@ -737,7 +736,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/0b/bc13f787394920b23073ca3b6c4a7a21396301ed75a655bcb47196b50e6e/greenlet-3.2.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:710638eb93b1fa52823aa91bf75326f9ecdfd5e0466f00789246a5280f4ba0fc", size = 655191, upload-time = "2025-08-07T13:45:29.752Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/d6/6adde57d1345a8d0f14d31e4ab9c23cfe8e2cd39c3baf7674b4b0338d266/greenlet-3.2.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c5111ccdc9c88f423426df3fd1811bfc40ed66264d35aa373420a34377efc98a", size = 649516, upload-time = "2025-08-07T13:53:16.314Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/3b/3a3328a788d4a473889a2d403199932be55b1b0060f4ddd96ee7cdfcad10/greenlet-3.2.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d76383238584e9711e20ebe14db6c88ddcedc1829a9ad31a584389463b5aa504", size = 652169, upload-time = "2025-08-07T13:18:32.861Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" },
|
||||
@@ -748,7 +746,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/22/5c/85273fd7cc388285632b0498dbbab97596e04b154933dfe0f3e68156c68c/greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0", size = 273586, upload-time = "2025-08-07T13:16:08.004Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/75/10aeeaa3da9332c2e761e4c50d4c3556c21113ee3f0afa2cf5769946f7a3/greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f", size = 686346, upload-time = "2025-08-07T13:42:59.944Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/aa/687d6b12ffb505a4447567d1f3abea23bd20e73a5bed63871178e0831b7a/greenlet-3.2.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c17b6b34111ea72fc5a4e4beec9711d2226285f0386ea83477cbb97c30a3f3a5", size = 699218, upload-time = "2025-08-07T13:45:30.969Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/8b/29aae55436521f1d6f8ff4e12fb676f3400de7fcf27fccd1d4d17fd8fecd/greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1", size = 694659, upload-time = "2025-08-07T13:53:17.759Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/2e/ea25914b1ebfde93b6fc4ff46d6864564fba59024e928bdc7de475affc25/greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735", size = 695355, upload-time = "2025-08-07T13:18:34.517Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/60/fc56c62046ec17f6b0d3060564562c64c862948c9d4bc8aa807cf5bd74f4/greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337", size = 657512, upload-time = "2025-08-07T13:18:33.969Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/23/6e/74407aed965a4ab6ddd93a7ded3180b730d281c77b765788419484cdfeef/greenlet-3.2.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2917bdf657f5859fbf3386b12d68ede4cf1f04c90c3a6bc1f013dd68a22e2269", size = 1612508, upload-time = "2025-11-04T12:42:23.427Z" },
|
||||
@@ -1745,7 +1742,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pyannote-audio"
|
||||
version = "3.1.0"
|
||||
version = "3.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "asteroid-filterbanks" },
|
||||
@@ -1768,9 +1765,9 @@ dependencies = [
|
||||
{ name = "torchaudio" },
|
||||
{ name = "torchmetrics" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ad/55/7253267c35e2aa9188b1d86cba121eb5bdd91ed12d3194488625a008cae7/pyannote.audio-3.1.0.tar.gz", hash = "sha256:da04705443d3b74607e034d3ca88f8b572c7e9672dd9a4199cab65a0dbc33fad", size = 14812058, upload-time = "2023-11-16T12:26:38.939Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ec/1e/efe9619c38f1281ddf21640654d8ea9e3f67c459b76f78657b26d8557bbe/pyannote_audio-3.4.0.tar.gz", hash = "sha256:d523d883cb8d37cb6daf99f3ba83f9138bb193646ad71e6eae7deb89d8ddd642", size = 804850, upload-time = "2025-09-09T07:04:51.17Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a1/37/158859ce4c45b5ba2dca40b53b0c10d36f935b7f6d4e737298397167c8b1/pyannote.audio-3.1.0-py2.py3-none-any.whl", hash = "sha256:66ab485728c6e141760e80555cb7a083e7be824cd528cc79b9e6f7d6421a91ae", size = 208592, upload-time = "2023-11-16T12:26:36.726Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/13/620c6f711b723653092fd063bfee82a6af5ea3a4d3c42efc53ce623a7f4d/pyannote_audio-3.4.0-py2.py3-none-any.whl", hash = "sha256:36e38f058059f46da3478dda581cda53d9d85a21173a3e70bbdbc3ba93b5e1b7", size = 897789, upload-time = "2025-09-09T07:04:49.464Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2072,11 +2069,13 @@ name = "reflector-gpu"
|
||||
version = "0.1.0"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "av" },
|
||||
{ name = "fastapi", extra = ["standard"] },
|
||||
{ name = "faster-whisper" },
|
||||
{ name = "librosa" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pyannote-audio" },
|
||||
{ name = "pytorch-lightning" },
|
||||
{ name = "sentencepiece" },
|
||||
{ name = "silero-vad" },
|
||||
{ name = "torch" },
|
||||
@@ -2087,13 +2086,15 @@ dependencies = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "av", specifier = ">=13.1.0" },
|
||||
{ name = "fastapi", extras = ["standard"], specifier = ">=0.116.1" },
|
||||
{ name = "faster-whisper", specifier = ">=1.1.0" },
|
||||
{ name = "librosa", specifier = "==0.10.1" },
|
||||
{ name = "numpy", specifier = "<2" },
|
||||
{ name = "pyannote-audio", specifier = "==3.1.0" },
|
||||
{ name = "pyannote-audio", specifier = "==3.4.0" },
|
||||
{ name = "pytorch-lightning", specifier = "<2.6" },
|
||||
{ name = "sentencepiece" },
|
||||
{ name = "silero-vad", specifier = "==5.1.0" },
|
||||
{ name = "silero-vad", specifier = "==5.1.2" },
|
||||
{ name = "torch", specifier = ">=2.3.0" },
|
||||
{ name = "torchaudio", specifier = ">=2.3.0" },
|
||||
{ name = "transformers", specifier = ">=4.35.0" },
|
||||
@@ -2473,16 +2474,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "silero-vad"
|
||||
version = "5.1"
|
||||
version = "5.1.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "onnxruntime" },
|
||||
{ name = "torch" },
|
||||
{ name = "torchaudio" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7c/5d/b912e45d21b8b61859a552554893222d2cdebfd0f9afa7e8ba69c7a3441a/silero_vad-5.1.tar.gz", hash = "sha256:c644275ba5df06cee596cc050ba0bd1e0f5237d1abfa44d58dd4618f6e77434d", size = 3996829, upload-time = "2024-07-09T13:19:24.181Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b1/b4/d0311b2e6220a11f8f4699f4a278cb088131573286cdfe804c87c7eb5123/silero_vad-5.1.2.tar.gz", hash = "sha256:c442971160026d2d7aa0ad83f0c7ee86c89797a65289fe625c8ea59fc6fb828d", size = 5098526, upload-time = "2024-10-09T09:50:47.019Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/be/0fdbc72030b93d6f55107490d5d2185ddf0dbabdc921f589649d3e92ccd5/silero_vad-5.1-py3-none-any.whl", hash = "sha256:ecb50b484f538f7a962ce5cd3c07120d9db7b9d5a0c5861ccafe459856f22c8f", size = 3939986, upload-time = "2024-07-09T13:19:21.383Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/f7/5ae11d13fbb733cd3bfd7ff1c3a3902e6f55437df4b72307c1f168146268/silero_vad-5.1.2-py3-none-any.whl", hash = "sha256:93b41953d7774b165407fda6b533c119c5803864e367d5034dc626c82cfdf661", size = 5026737, upload-time = "2024-10-09T09:50:44.355Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
156
scripts/run-integration-tests.sh
Executable file
156
scripts/run-integration-tests.sh
Executable file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# Run integration tests locally.
|
||||
#
|
||||
# Spins up the full stack via Docker Compose, runs the three integration tests,
|
||||
# and tears everything down afterward.
|
||||
#
|
||||
# Required environment variables:
|
||||
# LLM_URL — OpenAI-compatible LLM endpoint (e.g. https://api.openai.com/v1)
|
||||
# LLM_API_KEY — API key for the LLM endpoint
|
||||
# HF_TOKEN — HuggingFace token for pyannote gated models
|
||||
#
|
||||
# Optional:
|
||||
# LLM_MODEL — Model name (default: qwen2.5:14b)
|
||||
#
|
||||
# Usage:
|
||||
# export LLM_URL="https://api.openai.com/v1"
|
||||
# export LLM_API_KEY="sk-..."
|
||||
# export HF_TOKEN="hf_..."
|
||||
# ./scripts/run-integration-tests.sh
|
||||
#
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
COMPOSE_DIR="$REPO_ROOT/server/tests"
|
||||
COMPOSE_FILE="$COMPOSE_DIR/docker-compose.integration.yml"
|
||||
COMPOSE="docker compose -f $COMPOSE_FILE"
|
||||
|
||||
# ── Validate required env vars ──────────────────────────────────────────────
|
||||
for var in LLM_URL LLM_API_KEY HF_TOKEN; do
|
||||
if [[ -z "${!var:-}" ]]; then
|
||||
echo "ERROR: $var is not set. See script header for required env vars."
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
export LLM_MODEL="${LLM_MODEL:-qwen2.5:14b}"
|
||||
|
||||
# ── Helpers ─────────────────────────────────────────────────────────────────
|
||||
info() { echo -e "\n\033[1;34m▸ $*\033[0m"; }
|
||||
ok() { echo -e "\033[1;32m ✓ $*\033[0m"; }
|
||||
fail() { echo -e "\033[1;31m ✗ $*\033[0m"; }
|
||||
|
||||
wait_for() {
|
||||
local desc="$1" cmd="$2" max="${3:-60}"
|
||||
info "Waiting for $desc (up to ${max}s)..."
|
||||
for i in $(seq 1 "$max"); do
|
||||
if eval "$cmd" &>/dev/null; then
|
||||
ok "$desc is ready"
|
||||
return 0
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
fail "$desc did not become ready within ${max}s"
|
||||
return 1
|
||||
}
|
||||
|
||||
cleanup() {
|
||||
info "Tearing down..."
|
||||
$COMPOSE down -v --remove-orphans 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Always tear down on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
# ── Step 1: Build and start infrastructure ──────────────────────────────────
|
||||
info "Building and starting infrastructure services..."
|
||||
$COMPOSE up -d --build postgres redis garage hatchet mock-daily
|
||||
|
||||
# ── Step 2: Set up Garage (S3 bucket + keys) ───────────────────────────────
|
||||
wait_for "Garage" "$COMPOSE exec -T garage /garage stats" 60
|
||||
|
||||
info "Setting up Garage bucket and keys..."
|
||||
GARAGE="$COMPOSE exec -T garage /garage"
|
||||
|
||||
# Hardcoded test credentials — ephemeral containers, destroyed after tests
|
||||
export GARAGE_KEY_ID="GK0123456789abcdef01234567" # gitleaks:allow
|
||||
export GARAGE_KEY_SECRET="0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" # gitleaks:allow
|
||||
|
||||
# Layout
|
||||
NODE_ID=$($GARAGE node id -q 2>&1 | tr -d '[:space:]')
|
||||
LAYOUT_STATUS=$($GARAGE layout show 2>&1 || true)
|
||||
if echo "$LAYOUT_STATUS" | grep -q "No nodes"; then
|
||||
$GARAGE layout assign "$NODE_ID" -c 1G -z dc1
|
||||
$GARAGE layout apply --version 1
|
||||
fi
|
||||
|
||||
# Bucket
|
||||
$GARAGE bucket info reflector-media >/dev/null 2>&1 || $GARAGE bucket create reflector-media
|
||||
|
||||
# Import key with known credentials
|
||||
if ! $GARAGE key info reflector-test >/dev/null 2>&1; then
|
||||
$GARAGE key import --yes "$GARAGE_KEY_ID" "$GARAGE_KEY_SECRET"
|
||||
$GARAGE key rename "$GARAGE_KEY_ID" reflector-test
|
||||
fi
|
||||
|
||||
# Permissions
|
||||
$GARAGE bucket allow reflector-media --read --write --key reflector-test
|
||||
|
||||
ok "Garage ready with hardcoded test credentials"
|
||||
|
||||
# ── Step 3: Generate Hatchet API token ──────────────────────────────────────
|
||||
wait_for "Hatchet" "$COMPOSE exec -T hatchet curl -sf http://localhost:8888/api/live" 90
|
||||
|
||||
info "Generating Hatchet API token..."
|
||||
HATCHET_TOKEN_OUTPUT=$($COMPOSE exec -T hatchet /hatchet-admin token create --config /config --name local-test 2>&1)
|
||||
export HATCHET_CLIENT_TOKEN=$(echo "$HATCHET_TOKEN_OUTPUT" | grep -o 'eyJ[A-Za-z0-9_.\-]*')
|
||||
|
||||
if [[ -z "$HATCHET_CLIENT_TOKEN" ]]; then
|
||||
fail "Failed to extract Hatchet token (JWT not found in output)"
|
||||
echo " Output was: $HATCHET_TOKEN_OUTPUT"
|
||||
exit 1
|
||||
fi
|
||||
ok "Hatchet token generated"
|
||||
|
||||
# ── Step 4: Start backend services ──────────────────────────────────────────
|
||||
info "Starting backend services..."
|
||||
$COMPOSE up -d server worker hatchet-worker-cpu hatchet-worker-llm test-runner
|
||||
|
||||
# ── Step 5: Wait for server + run migrations ────────────────────────────────
|
||||
wait_for "Server" "$COMPOSE exec -T test-runner curl -sf http://server:1250/health" 60
|
||||
|
||||
info "Running database migrations..."
|
||||
$COMPOSE exec -T server uv run alembic upgrade head
|
||||
ok "Migrations applied"
|
||||
|
||||
# ── Step 6: Run integration tests ───────────────────────────────────────────
|
||||
info "Running integration tests..."
|
||||
echo ""
|
||||
|
||||
LOGS_DIR="$COMPOSE_DIR/integration/logs"
|
||||
mkdir -p "$LOGS_DIR"
|
||||
RUN_TIMESTAMP=$(date +%Y%m%d-%H%M%S)
|
||||
TEST_LOG="$LOGS_DIR/$RUN_TIMESTAMP.txt"
|
||||
|
||||
if $COMPOSE exec -T test-runner uv run pytest tests/integration/ -v -x 2>&1 | tee "$TEST_LOG.pytest"; then
|
||||
echo ""
|
||||
ok "All integration tests passed!"
|
||||
EXIT_CODE=0
|
||||
else
|
||||
echo ""
|
||||
fail "Integration tests failed!"
|
||||
EXIT_CODE=1
|
||||
fi
|
||||
|
||||
# Always collect service logs + test output into a single file
|
||||
info "Collecting logs..."
|
||||
$COMPOSE logs --tail=500 > "$TEST_LOG" 2>&1
|
||||
echo -e "\n\n=== PYTEST OUTPUT ===\n" >> "$TEST_LOG"
|
||||
cat "$TEST_LOG.pytest" >> "$TEST_LOG" 2>/dev/null
|
||||
rm -f "$TEST_LOG.pytest"
|
||||
echo " Logs saved to: server/tests/integration/logs/$RUN_TIMESTAMP.txt"
|
||||
|
||||
# cleanup runs via trap
|
||||
exit $EXIT_CODE
|
||||
@@ -4,11 +4,12 @@
|
||||
# Single script to configure and launch everything on one server.
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/setup-selfhosted.sh <--gpu|--cpu> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASSWORD] [--build]
|
||||
# ./scripts/setup-selfhosted.sh <--gpu|--cpu|--hosted> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASSWORD] [--build]
|
||||
#
|
||||
# Specialized models (pick ONE — required):
|
||||
# --gpu NVIDIA GPU for transcription/diarization/translation
|
||||
# --cpu CPU-only for transcription/diarization/translation (slower)
|
||||
# ML processing modes (pick ONE — required):
|
||||
# --gpu NVIDIA GPU container for transcription/diarization/translation
|
||||
# --cpu In-process CPU processing (no ML container, slower)
|
||||
# --hosted Remote GPU service URL (no ML container)
|
||||
#
|
||||
# Local LLM (optional — for summarization & topic detection):
|
||||
# --ollama-gpu Local Ollama with NVIDIA GPU acceleration
|
||||
@@ -29,11 +30,16 @@
|
||||
# ./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy
|
||||
# ./scripts/setup-selfhosted.sh --gpu --ollama-gpu --garage --caddy --domain reflector.example.com
|
||||
# ./scripts/setup-selfhosted.sh --cpu --ollama-cpu --garage --caddy
|
||||
# ./scripts/setup-selfhosted.sh --hosted --garage --caddy
|
||||
# ./scripts/setup-selfhosted.sh --gpu --ollama-gpu --llm-model mistral --garage --caddy
|
||||
# ./scripts/setup-selfhosted.sh --gpu --garage --caddy --password mysecretpass
|
||||
# ./scripts/setup-selfhosted.sh --gpu --garage --caddy
|
||||
# ./scripts/setup-selfhosted.sh --cpu
|
||||
#
|
||||
# The script auto-detects Daily.co (DAILY_API_KEY) and Whereby (WHEREBY_API_KEY)
|
||||
# from server/.env. If Daily.co is configured, Hatchet workflow services are
|
||||
# started automatically for multitrack recording processing.
|
||||
#
|
||||
# Idempotent — safe to re-run at any time.
|
||||
#
|
||||
set -euo pipefail
|
||||
@@ -179,11 +185,14 @@ for i in "${!ARGS[@]}"; do
|
||||
arg="${ARGS[$i]}"
|
||||
case "$arg" in
|
||||
--gpu)
|
||||
[[ -n "$MODEL_MODE" ]] && { err "Cannot combine --gpu and --cpu. Pick one."; exit 1; }
|
||||
[[ -n "$MODEL_MODE" ]] && { err "Cannot combine --gpu, --cpu, and --hosted. Pick one."; exit 1; }
|
||||
MODEL_MODE="gpu" ;;
|
||||
--cpu)
|
||||
[[ -n "$MODEL_MODE" ]] && { err "Cannot combine --gpu and --cpu. Pick one."; exit 1; }
|
||||
[[ -n "$MODEL_MODE" ]] && { err "Cannot combine --gpu, --cpu, and --hosted. Pick one."; exit 1; }
|
||||
MODEL_MODE="cpu" ;;
|
||||
--hosted)
|
||||
[[ -n "$MODEL_MODE" ]] && { err "Cannot combine --gpu, --cpu, and --hosted. Pick one."; exit 1; }
|
||||
MODEL_MODE="hosted" ;;
|
||||
--ollama-gpu)
|
||||
[[ -n "$OLLAMA_MODE" ]] && { err "Cannot combine --ollama-gpu and --ollama-cpu. Pick one."; exit 1; }
|
||||
OLLAMA_MODE="ollama-gpu" ;;
|
||||
@@ -220,20 +229,21 @@ for i in "${!ARGS[@]}"; do
|
||||
SKIP_NEXT=true ;;
|
||||
*)
|
||||
err "Unknown argument: $arg"
|
||||
err "Usage: $0 <--gpu|--cpu> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASS] [--build]"
|
||||
err "Usage: $0 <--gpu|--cpu|--hosted> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASS] [--build]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ -z "$MODEL_MODE" ]]; then
|
||||
err "No model mode specified. You must choose --gpu or --cpu."
|
||||
err "No model mode specified. You must choose --gpu, --cpu, or --hosted."
|
||||
err ""
|
||||
err "Usage: $0 <--gpu|--cpu> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASS] [--build]"
|
||||
err "Usage: $0 <--gpu|--cpu|--hosted> [--ollama-gpu|--ollama-cpu] [--llm-model MODEL] [--garage] [--caddy] [--domain DOMAIN] [--password PASS] [--build]"
|
||||
err ""
|
||||
err "Specialized models (required):"
|
||||
err " --gpu NVIDIA GPU for transcription/diarization/translation"
|
||||
err " --cpu CPU-only (slower but works without GPU)"
|
||||
err "ML processing modes (required):"
|
||||
err " --gpu NVIDIA GPU container for transcription/diarization/translation"
|
||||
err " --cpu In-process CPU processing (no ML container, slower)"
|
||||
err " --hosted Remote GPU service URL (no ML container)"
|
||||
err ""
|
||||
err "Local LLM (optional):"
|
||||
err " --ollama-gpu Local Ollama with GPU (for summarization/topics)"
|
||||
@@ -251,7 +261,11 @@ if [[ -z "$MODEL_MODE" ]]; then
|
||||
fi
|
||||
|
||||
# Build profiles list — one profile per feature
|
||||
COMPOSE_PROFILES=("$MODEL_MODE")
|
||||
# Hatchet + hatchet-worker-llm are always-on (no profile needed).
|
||||
# gpu/cpu profiles only control the ML container (transcription service).
|
||||
COMPOSE_PROFILES=()
|
||||
[[ "$MODEL_MODE" == "gpu" ]] && COMPOSE_PROFILES+=("gpu")
|
||||
[[ "$MODEL_MODE" == "cpu" ]] && COMPOSE_PROFILES+=("cpu")
|
||||
[[ -n "$OLLAMA_MODE" ]] && COMPOSE_PROFILES+=("$OLLAMA_MODE")
|
||||
[[ "$USE_GARAGE" == "true" ]] && COMPOSE_PROFILES+=("garage")
|
||||
[[ "$USE_CADDY" == "true" ]] && COMPOSE_PROFILES+=("caddy")
|
||||
@@ -418,39 +432,102 @@ step_server_env() {
|
||||
env_set "$SERVER_ENV" "WEBRTC_HOST" "$PRIMARY_IP"
|
||||
fi
|
||||
|
||||
# Specialized models (always via gpu/cpu container aliased as "transcription")
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_MODAL_API_KEY" "selfhosted"
|
||||
# Specialized models — backend configuration per mode
|
||||
env_set "$SERVER_ENV" "DIARIZATION_ENABLED" "true"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "TRANSLATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSLATE_URL" "http://transcription:8000"
|
||||
case "$MODEL_MODE" in
|
||||
gpu)
|
||||
# GPU container aliased as "transcription" on docker network
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_MODAL_API_KEY" "selfhosted"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "TRANSLATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSLATE_URL" "http://transcription:8000"
|
||||
env_set "$SERVER_ENV" "PADDING_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "PADDING_URL" "http://transcription:8000"
|
||||
ok "ML backends: GPU container (modal)"
|
||||
;;
|
||||
cpu)
|
||||
# In-process backends — no ML service container needed
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_BACKEND" "whisper"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_BACKEND" "pyannote"
|
||||
env_set "$SERVER_ENV" "TRANSLATION_BACKEND" "marian"
|
||||
env_set "$SERVER_ENV" "PADDING_BACKEND" "pyav"
|
||||
ok "ML backends: in-process CPU (whisper/pyannote/marian/pyav)"
|
||||
;;
|
||||
hosted)
|
||||
# Remote GPU service — user provides URL
|
||||
local gpu_url=""
|
||||
if env_has_key "$SERVER_ENV" "TRANSCRIPT_URL"; then
|
||||
gpu_url=$(env_get "$SERVER_ENV" "TRANSCRIPT_URL")
|
||||
fi
|
||||
if [[ -z "$gpu_url" ]] && [[ -t 0 ]]; then
|
||||
echo ""
|
||||
info "Enter the URL of your remote GPU service (e.g. https://gpu.example.com)"
|
||||
read -rp " GPU service URL: " gpu_url
|
||||
fi
|
||||
if [[ -z "$gpu_url" ]]; then
|
||||
err "GPU service URL required for --hosted mode."
|
||||
err "Set TRANSCRIPT_URL in server/.env or provide it interactively."
|
||||
exit 1
|
||||
fi
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_URL" "$gpu_url"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_URL" "$gpu_url"
|
||||
env_set "$SERVER_ENV" "TRANSLATION_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "TRANSLATE_URL" "$gpu_url"
|
||||
env_set "$SERVER_ENV" "PADDING_BACKEND" "modal"
|
||||
env_set "$SERVER_ENV" "PADDING_URL" "$gpu_url"
|
||||
# API key for remote service
|
||||
local gpu_api_key=""
|
||||
if env_has_key "$SERVER_ENV" "TRANSCRIPT_MODAL_API_KEY"; then
|
||||
gpu_api_key=$(env_get "$SERVER_ENV" "TRANSCRIPT_MODAL_API_KEY")
|
||||
fi
|
||||
if [[ -z "$gpu_api_key" ]] && [[ -t 0 ]]; then
|
||||
read -rp " GPU service API key (or Enter to skip): " gpu_api_key
|
||||
fi
|
||||
if [[ -n "$gpu_api_key" ]]; then
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_MODAL_API_KEY" "$gpu_api_key"
|
||||
fi
|
||||
ok "ML backends: remote hosted ($gpu_url)"
|
||||
;;
|
||||
esac
|
||||
|
||||
# HuggingFace token for gated models (pyannote diarization)
|
||||
# Written to root .env so docker compose picks it up for gpu/cpu containers
|
||||
local root_env="$ROOT_DIR/.env"
|
||||
local current_hf_token="${HF_TOKEN:-}"
|
||||
if [[ -f "$root_env" ]] && env_has_key "$root_env" "HF_TOKEN"; then
|
||||
current_hf_token=$(env_get "$root_env" "HF_TOKEN")
|
||||
fi
|
||||
if [[ -z "$current_hf_token" ]]; then
|
||||
echo ""
|
||||
warn "HF_TOKEN not set. Diarization will use a public model fallback."
|
||||
warn "For best results, get a token at https://huggingface.co/settings/tokens"
|
||||
warn "and accept pyannote licenses at https://huggingface.co/pyannote/speaker-diarization-3.1"
|
||||
read -rp " HuggingFace token (or press Enter to skip): " current_hf_token
|
||||
fi
|
||||
if [[ -n "$current_hf_token" ]]; then
|
||||
touch "$root_env"
|
||||
env_set "$root_env" "HF_TOKEN" "$current_hf_token"
|
||||
export HF_TOKEN="$current_hf_token"
|
||||
ok "HF_TOKEN configured"
|
||||
else
|
||||
touch "$root_env"
|
||||
env_set "$root_env" "HF_TOKEN" ""
|
||||
ok "HF_TOKEN skipped (using public model fallback)"
|
||||
# --gpu: written to root .env (docker compose passes to GPU container)
|
||||
# --cpu: written to both root .env and server/.env (in-process pyannote needs it)
|
||||
# --hosted: not needed (remote service handles its own auth)
|
||||
if [[ "$MODEL_MODE" != "hosted" ]]; then
|
||||
local root_env="$ROOT_DIR/.env"
|
||||
local current_hf_token="${HF_TOKEN:-}"
|
||||
if [[ -f "$root_env" ]] && env_has_key "$root_env" "HF_TOKEN"; then
|
||||
current_hf_token=$(env_get "$root_env" "HF_TOKEN")
|
||||
fi
|
||||
if [[ -z "$current_hf_token" ]]; then
|
||||
echo ""
|
||||
warn "HF_TOKEN not set. Diarization will use a public model fallback."
|
||||
warn "For best results, get a token at https://huggingface.co/settings/tokens"
|
||||
warn "and accept pyannote licenses at https://huggingface.co/pyannote/speaker-diarization-3.1"
|
||||
if [[ -t 0 ]]; then
|
||||
read -rp " HuggingFace token (or press Enter to skip): " current_hf_token
|
||||
fi
|
||||
fi
|
||||
if [[ -n "$current_hf_token" ]]; then
|
||||
touch "$root_env"
|
||||
env_set "$root_env" "HF_TOKEN" "$current_hf_token"
|
||||
export HF_TOKEN="$current_hf_token"
|
||||
# In CPU mode, server process needs HF_TOKEN directly
|
||||
if [[ "$MODEL_MODE" == "cpu" ]]; then
|
||||
env_set "$SERVER_ENV" "HF_TOKEN" "$current_hf_token"
|
||||
fi
|
||||
ok "HF_TOKEN configured"
|
||||
else
|
||||
touch "$root_env"
|
||||
env_set "$root_env" "HF_TOKEN" ""
|
||||
ok "HF_TOKEN skipped (using public model fallback)"
|
||||
fi
|
||||
fi
|
||||
|
||||
# LLM configuration
|
||||
@@ -466,7 +543,7 @@ step_server_env() {
|
||||
if env_has_key "$SERVER_ENV" "LLM_URL"; then
|
||||
current_llm_url=$(env_get "$SERVER_ENV" "LLM_URL")
|
||||
fi
|
||||
if [[ -z "$current_llm_url" ]] || [[ "$current_llm_url" == "http://host.docker.internal"* ]]; then
|
||||
if [[ -z "$current_llm_url" ]]; then
|
||||
warn "LLM not configured. Summarization and topic detection will NOT work."
|
||||
warn "Edit server/.env and set LLM_URL, LLM_API_KEY, LLM_MODEL"
|
||||
warn "Example: LLM_URL=https://api.openai.com/v1 LLM_MODEL=gpt-4o-mini"
|
||||
@@ -475,6 +552,18 @@ step_server_env() {
|
||||
fi
|
||||
fi
|
||||
|
||||
# CPU mode: increase file processing timeouts (default 600s is too short for long audio on CPU)
|
||||
if [[ "$MODEL_MODE" == "cpu" ]]; then
|
||||
env_set "$SERVER_ENV" "TRANSCRIPT_FILE_TIMEOUT" "3600"
|
||||
env_set "$SERVER_ENV" "DIARIZATION_FILE_TIMEOUT" "3600"
|
||||
ok "CPU mode — file processing timeouts set to 3600s (1 hour)"
|
||||
fi
|
||||
|
||||
# Hatchet is always required (file, live, and multitrack pipelines all use it)
|
||||
env_set "$SERVER_ENV" "HATCHET_CLIENT_SERVER_URL" "http://hatchet:8888"
|
||||
env_set "$SERVER_ENV" "HATCHET_CLIENT_HOST_PORT" "hatchet:7077"
|
||||
ok "Hatchet connectivity configured (workflow engine for processing pipelines)"
|
||||
|
||||
ok "server/.env ready"
|
||||
}
|
||||
|
||||
@@ -535,6 +624,19 @@ step_www_env() {
|
||||
fi
|
||||
fi
|
||||
|
||||
# Enable rooms if any video platform is configured in server/.env
|
||||
local _daily_key="" _whereby_key=""
|
||||
if env_has_key "$SERVER_ENV" "DAILY_API_KEY"; then
|
||||
_daily_key=$(env_get "$SERVER_ENV" "DAILY_API_KEY")
|
||||
fi
|
||||
if env_has_key "$SERVER_ENV" "WHEREBY_API_KEY"; then
|
||||
_whereby_key=$(env_get "$SERVER_ENV" "WHEREBY_API_KEY")
|
||||
fi
|
||||
if [[ -n "$_daily_key" ]] || [[ -n "$_whereby_key" ]]; then
|
||||
env_set "$WWW_ENV" "FEATURE_ROOMS" "true"
|
||||
ok "Rooms feature enabled (video platform configured)"
|
||||
fi
|
||||
|
||||
ok "www/.env ready (URL=$base_url)"
|
||||
}
|
||||
|
||||
@@ -739,6 +841,23 @@ CADDYEOF
|
||||
else
|
||||
ok "Caddyfile already exists"
|
||||
fi
|
||||
|
||||
# Add Hatchet dashboard route if Daily.co is detected
|
||||
if [[ "$DAILY_DETECTED" == "true" ]]; then
|
||||
if ! grep -q "hatchet" "$caddyfile" 2>/dev/null; then
|
||||
cat >> "$caddyfile" << CADDYEOF
|
||||
|
||||
# Hatchet workflow dashboard (Daily.co multitrack processing)
|
||||
:8888 {
|
||||
tls internal
|
||||
reverse_proxy hatchet:8888
|
||||
}
|
||||
CADDYEOF
|
||||
ok "Added Hatchet dashboard route to Caddyfile (port 8888)"
|
||||
else
|
||||
ok "Hatchet dashboard route already in Caddyfile"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# =========================================================
|
||||
@@ -747,11 +866,12 @@ CADDYEOF
|
||||
step_services() {
|
||||
info "Step 6: Starting Docker services"
|
||||
|
||||
# Build GPU/CPU image from source (always needed — no prebuilt image)
|
||||
local build_svc="$MODEL_MODE"
|
||||
info "Building $build_svc image (first build downloads ML models, may take a while)..."
|
||||
compose_cmd build "$build_svc"
|
||||
ok "$build_svc image built"
|
||||
# Build GPU image from source (only for --gpu mode)
|
||||
if [[ "$MODEL_MODE" == "gpu" ]]; then
|
||||
info "Building gpu image (first build downloads ML models, may take a while)..."
|
||||
compose_cmd build gpu
|
||||
ok "gpu image built"
|
||||
fi
|
||||
|
||||
# Build or pull backend and frontend images
|
||||
if [[ "$BUILD_IMAGES" == "true" ]]; then
|
||||
@@ -766,6 +886,44 @@ step_services() {
|
||||
compose_cmd pull server web || warn "Pull failed — using cached images"
|
||||
fi
|
||||
|
||||
# Hatchet is always needed (all processing pipelines use it)
|
||||
local NEEDS_HATCHET=true
|
||||
|
||||
# Build hatchet workers if Hatchet is needed (same backend image)
|
||||
if [[ "$NEEDS_HATCHET" == "true" ]] && [[ "$BUILD_IMAGES" == "true" ]]; then
|
||||
info "Building Hatchet worker images..."
|
||||
if [[ "$DAILY_DETECTED" == "true" ]]; then
|
||||
compose_cmd build hatchet-worker-cpu hatchet-worker-llm
|
||||
else
|
||||
compose_cmd build hatchet-worker-llm
|
||||
fi
|
||||
ok "Hatchet worker images built"
|
||||
fi
|
||||
|
||||
# Ensure hatchet database exists before starting hatchet (init-hatchet-db.sql only runs on fresh postgres volumes)
|
||||
if [[ "$NEEDS_HATCHET" == "true" ]]; then
|
||||
info "Ensuring postgres is running for Hatchet database setup..."
|
||||
compose_cmd up -d postgres
|
||||
local pg_ready=false
|
||||
for i in $(seq 1 30); do
|
||||
if compose_cmd exec -T postgres pg_isready -U reflector > /dev/null 2>&1; then
|
||||
pg_ready=true
|
||||
break
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
if [[ "$pg_ready" == "true" ]]; then
|
||||
compose_cmd exec -T postgres psql -U reflector -tc \
|
||||
"SELECT 1 FROM pg_database WHERE datname = 'hatchet'" 2>/dev/null \
|
||||
| grep -q 1 \
|
||||
|| compose_cmd exec -T postgres psql -U reflector -c "CREATE DATABASE hatchet" 2>/dev/null \
|
||||
|| true
|
||||
ok "Hatchet database ready"
|
||||
else
|
||||
warn "Postgres not ready — hatchet database may need to be created manually"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Start all services
|
||||
compose_cmd up -d
|
||||
ok "Containers started"
|
||||
@@ -788,25 +946,29 @@ step_services() {
|
||||
step_health() {
|
||||
info "Step 7: Health checks"
|
||||
|
||||
# Specialized model service (gpu or cpu)
|
||||
local model_svc="$MODEL_MODE"
|
||||
|
||||
info "Waiting for $model_svc service (first start downloads ~1GB of models)..."
|
||||
local model_ok=false
|
||||
for i in $(seq 1 120); do
|
||||
if curl -sf http://localhost:8000/docs > /dev/null 2>&1; then
|
||||
model_ok=true
|
||||
break
|
||||
# Specialized model service (only for --gpu mode)
|
||||
if [[ "$MODEL_MODE" == "gpu" ]]; then
|
||||
info "Waiting for gpu service (first start downloads ~1GB of models)..."
|
||||
local model_ok=false
|
||||
for i in $(seq 1 120); do
|
||||
if curl -sf http://localhost:8000/docs > /dev/null 2>&1; then
|
||||
model_ok=true
|
||||
break
|
||||
fi
|
||||
echo -ne "\r Waiting for gpu service... ($i/120)"
|
||||
sleep 5
|
||||
done
|
||||
echo ""
|
||||
if [[ "$model_ok" == "true" ]]; then
|
||||
ok "gpu service healthy (transcription + diarization)"
|
||||
else
|
||||
warn "gpu service not ready yet — it will keep loading in the background"
|
||||
warn "Check with: docker compose -f docker-compose.selfhosted.yml logs gpu"
|
||||
fi
|
||||
echo -ne "\r Waiting for $model_svc service... ($i/120)"
|
||||
sleep 5
|
||||
done
|
||||
echo ""
|
||||
if [[ "$model_ok" == "true" ]]; then
|
||||
ok "$model_svc service healthy (transcription + diarization)"
|
||||
else
|
||||
warn "$model_svc service not ready yet — it will keep loading in the background"
|
||||
warn "Check with: docker compose -f docker-compose.selfhosted.yml logs $model_svc"
|
||||
elif [[ "$MODEL_MODE" == "cpu" ]]; then
|
||||
ok "CPU mode — ML processing runs in-process on server/worker (no separate service)"
|
||||
elif [[ "$MODEL_MODE" == "hosted" ]]; then
|
||||
ok "Hosted mode — ML processing via remote GPU service (no local health check)"
|
||||
fi
|
||||
|
||||
# Ollama (if applicable)
|
||||
@@ -894,6 +1056,24 @@ step_health() {
|
||||
fi
|
||||
fi
|
||||
|
||||
# Hatchet (always-on)
|
||||
info "Waiting for Hatchet workflow engine..."
|
||||
local hatchet_ok=false
|
||||
for i in $(seq 1 60); do
|
||||
if curl -sf http://localhost:8888/api/live > /dev/null 2>&1; then
|
||||
hatchet_ok=true
|
||||
break
|
||||
fi
|
||||
echo -ne "\r Waiting for Hatchet... ($i/60)"
|
||||
sleep 3
|
||||
done
|
||||
echo ""
|
||||
if [[ "$hatchet_ok" == "true" ]]; then
|
||||
ok "Hatchet workflow engine healthy"
|
||||
else
|
||||
warn "Hatchet not ready yet. Check: docker compose logs hatchet"
|
||||
fi
|
||||
|
||||
# LLM warning for non-Ollama modes
|
||||
if [[ "$USES_OLLAMA" == "false" ]]; then
|
||||
local llm_url=""
|
||||
@@ -911,6 +1091,71 @@ step_health() {
|
||||
fi
|
||||
}
|
||||
|
||||
# =========================================================
|
||||
# Step 8: Hatchet token generation (gpu/cpu/Daily.co)
|
||||
# =========================================================
|
||||
step_hatchet_token() {
|
||||
# Hatchet is always required — no gating needed
|
||||
|
||||
# Skip if token already set
|
||||
if env_has_key "$SERVER_ENV" "HATCHET_CLIENT_TOKEN" && [[ -n "$(env_get "$SERVER_ENV" "HATCHET_CLIENT_TOKEN")" ]]; then
|
||||
ok "HATCHET_CLIENT_TOKEN already set — skipping generation"
|
||||
return
|
||||
fi
|
||||
|
||||
info "Step 8: Generating Hatchet API token"
|
||||
|
||||
# Wait for hatchet to be healthy
|
||||
local hatchet_ok=false
|
||||
for i in $(seq 1 60); do
|
||||
if curl -sf http://localhost:8888/api/live > /dev/null 2>&1; then
|
||||
hatchet_ok=true
|
||||
break
|
||||
fi
|
||||
echo -ne "\r Waiting for Hatchet API... ($i/60)"
|
||||
sleep 3
|
||||
done
|
||||
echo ""
|
||||
|
||||
if [[ "$hatchet_ok" != "true" ]]; then
|
||||
err "Hatchet not responding — cannot generate token"
|
||||
err "Check: docker compose logs hatchet"
|
||||
return
|
||||
fi
|
||||
|
||||
# Get tenant ID from hatchet database
|
||||
local tenant_id
|
||||
tenant_id=$(compose_cmd exec -T postgres psql -U reflector -d hatchet -t -c \
|
||||
"SELECT id FROM \"Tenant\" WHERE slug = 'default';" 2>/dev/null | tr -d ' \n')
|
||||
|
||||
if [[ -z "$tenant_id" ]]; then
|
||||
err "Could not find default tenant in Hatchet database"
|
||||
err "Hatchet may still be initializing. Try re-running the script."
|
||||
return
|
||||
fi
|
||||
|
||||
# Generate token via hatchet-admin
|
||||
local token
|
||||
token=$(compose_cmd exec -T hatchet /hatchet-admin token create \
|
||||
--config /config --tenant-id "$tenant_id" 2>/dev/null | tr -d '\n')
|
||||
|
||||
if [[ -z "$token" ]]; then
|
||||
err "Failed to generate Hatchet token"
|
||||
err "Try generating manually: see server/README.md"
|
||||
return
|
||||
fi
|
||||
|
||||
env_set "$SERVER_ENV" "HATCHET_CLIENT_TOKEN" "$token"
|
||||
ok "HATCHET_CLIENT_TOKEN generated and saved to server/.env"
|
||||
|
||||
# Restart services that need the token
|
||||
info "Restarting services with new Hatchet token..."
|
||||
local restart_services="server worker hatchet-worker-llm"
|
||||
[[ "$DAILY_DETECTED" == "true" ]] && restart_services="$restart_services hatchet-worker-cpu"
|
||||
compose_cmd restart $restart_services
|
||||
ok "Services restarted with Hatchet token"
|
||||
}
|
||||
|
||||
# =========================================================
|
||||
# Main
|
||||
# =========================================================
|
||||
@@ -957,6 +1202,43 @@ main() {
|
||||
echo ""
|
||||
step_server_env
|
||||
echo ""
|
||||
|
||||
# Auto-detect video platforms from server/.env (after step_server_env so file exists)
|
||||
DAILY_DETECTED=false
|
||||
WHEREBY_DETECTED=false
|
||||
if env_has_key "$SERVER_ENV" "DAILY_API_KEY" && [[ -n "$(env_get "$SERVER_ENV" "DAILY_API_KEY")" ]]; then
|
||||
DAILY_DETECTED=true
|
||||
fi
|
||||
if env_has_key "$SERVER_ENV" "WHEREBY_API_KEY" && [[ -n "$(env_get "$SERVER_ENV" "WHEREBY_API_KEY")" ]]; then
|
||||
WHEREBY_DETECTED=true
|
||||
fi
|
||||
ANY_PLATFORM_DETECTED=false
|
||||
[[ "$DAILY_DETECTED" == "true" || "$WHEREBY_DETECTED" == "true" ]] && ANY_PLATFORM_DETECTED=true
|
||||
|
||||
# Conditional profile activation for Daily.co
|
||||
if [[ "$DAILY_DETECTED" == "true" ]]; then
|
||||
COMPOSE_PROFILES+=("dailyco")
|
||||
ok "Daily.co detected — enabling Hatchet workflow services"
|
||||
fi
|
||||
|
||||
# Generate .env.hatchet for hatchet dashboard config (always needed)
|
||||
local hatchet_server_url hatchet_cookie_domain
|
||||
if [[ -n "$CUSTOM_DOMAIN" ]]; then
|
||||
hatchet_server_url="https://${CUSTOM_DOMAIN}:8888"
|
||||
hatchet_cookie_domain="$CUSTOM_DOMAIN"
|
||||
elif [[ -n "$PRIMARY_IP" ]]; then
|
||||
hatchet_server_url="http://${PRIMARY_IP}:8888"
|
||||
hatchet_cookie_domain="$PRIMARY_IP"
|
||||
else
|
||||
hatchet_server_url="http://localhost:8888"
|
||||
hatchet_cookie_domain="localhost"
|
||||
fi
|
||||
cat > "$ROOT_DIR/.env.hatchet" << EOF
|
||||
SERVER_URL=$hatchet_server_url
|
||||
SERVER_AUTH_COOKIE_DOMAIN=$hatchet_cookie_domain
|
||||
EOF
|
||||
ok "Generated .env.hatchet (dashboard URL=$hatchet_server_url)"
|
||||
|
||||
step_www_env
|
||||
echo ""
|
||||
step_storage
|
||||
@@ -966,6 +1248,8 @@ main() {
|
||||
step_services
|
||||
echo ""
|
||||
step_health
|
||||
echo ""
|
||||
step_hatchet_token
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
@@ -995,6 +1279,9 @@ main() {
|
||||
[[ "$USE_GARAGE" != "true" ]] && echo " Storage: External S3"
|
||||
[[ "$USES_OLLAMA" == "true" ]] && echo " LLM: Ollama ($OLLAMA_MODEL) for summarization/topics"
|
||||
[[ "$USES_OLLAMA" != "true" ]] && echo " LLM: External (configure in server/.env)"
|
||||
[[ "$DAILY_DETECTED" == "true" ]] && echo " Video: Daily.co (live rooms + multitrack processing via Hatchet)"
|
||||
[[ "$WHEREBY_DETECTED" == "true" ]] && echo " Video: Whereby (live rooms)"
|
||||
[[ "$ANY_PLATFORM_DETECTED" != "true" ]] && echo " Video: None (rooms disabled)"
|
||||
echo ""
|
||||
echo " To stop: docker compose -f docker-compose.selfhosted.yml down"
|
||||
echo " To re-run: ./scripts/setup-selfhosted.sh $*"
|
||||
|
||||
@@ -86,11 +86,23 @@ LLM_API_KEY=not-needed
|
||||
## Context size for summary generation (tokens)
|
||||
LLM_CONTEXT_WINDOW=16000
|
||||
|
||||
## =======================================================
|
||||
## Audio Padding
|
||||
##
|
||||
## backends: pyav (in-process PyAV), modal (HTTP API client)
|
||||
## Default is "pyav" — no external service needed.
|
||||
## Set to "modal" when using Modal.com or self-hosted gpu/self_hosted/ container.
|
||||
## =======================================================
|
||||
#PADDING_BACKEND=pyav
|
||||
#PADDING_BACKEND=modal
|
||||
#PADDING_URL=https://xxxxx--reflector-padding-web.modal.run
|
||||
#PADDING_MODAL_API_KEY=xxxxx
|
||||
|
||||
## =======================================================
|
||||
## Diarization
|
||||
##
|
||||
## Only available on modal
|
||||
## To allow diarization, you need to expose expose the files to be dowloded by the pipeline
|
||||
## backends: modal (HTTP API), pyannote (in-process pyannote.audio)
|
||||
## To allow diarization, you need to expose expose the files to be downloaded by the pipeline
|
||||
## =======================================================
|
||||
DIARIZATION_ENABLED=false
|
||||
DIARIZATION_BACKEND=modal
|
||||
@@ -137,6 +149,10 @@ TRANSCRIPT_STORAGE_AWS_REGION=us-east-1
|
||||
#DAILYCO_STORAGE_AWS_ROLE_ARN=... # IAM role ARN for Daily.co S3 access
|
||||
#DAILYCO_STORAGE_AWS_BUCKET_NAME=reflector-dailyco
|
||||
#DAILYCO_STORAGE_AWS_REGION=us-west-2
|
||||
# Worker credentials for reading/deleting from Daily's recording bucket
|
||||
# Required when transcript storage is separate from Daily's bucket (e.g., selfhosted with Garage)
|
||||
#DAILYCO_STORAGE_AWS_ACCESS_KEY_ID=your-aws-access-key
|
||||
#DAILYCO_STORAGE_AWS_SECRET_ACCESS_KEY=your-aws-secret-key
|
||||
|
||||
## Whereby (optional separate bucket)
|
||||
#WHEREBY_STORAGE_AWS_BUCKET_NAME=reflector-whereby
|
||||
|
||||
@@ -32,23 +32,46 @@ AUTH_BACKEND=none
|
||||
|
||||
# =======================================================
|
||||
# Specialized Models (Transcription, Diarization, Translation)
|
||||
# These run in the gpu/cpu container — NOT an LLM.
|
||||
# The "modal" backend means "HTTP API client" — it talks to
|
||||
# the self-hosted container, not Modal.com cloud.
|
||||
# These do NOT use an LLM. Configured per mode by the setup script:
|
||||
#
|
||||
# --gpu mode: modal backends → GPU container (http://transcription:8000)
|
||||
# --cpu mode: whisper/pyannote/marian/pyav → in-process ML on server/worker
|
||||
# --hosted mode: modal backends → user-provided remote GPU service URL
|
||||
# =======================================================
|
||||
|
||||
# --- --gpu mode (default) ---
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=http://transcription:8000
|
||||
TRANSCRIPT_MODAL_API_KEY=selfhosted
|
||||
|
||||
DIARIZATION_ENABLED=true
|
||||
DIARIZATION_BACKEND=modal
|
||||
DIARIZATION_URL=http://transcription:8000
|
||||
|
||||
TRANSLATION_BACKEND=modal
|
||||
TRANSLATE_URL=http://transcription:8000
|
||||
PADDING_BACKEND=modal
|
||||
PADDING_URL=http://transcription:8000
|
||||
|
||||
# HuggingFace token — optional, for gated models (e.g. pyannote).
|
||||
# Falls back to public S3 model bundle if not set.
|
||||
# --- --cpu mode (set by setup script) ---
|
||||
# TRANSCRIPT_BACKEND=whisper
|
||||
# DIARIZATION_BACKEND=pyannote
|
||||
# TRANSLATION_BACKEND=marian
|
||||
# PADDING_BACKEND=pyav
|
||||
|
||||
# --- --hosted mode (set by setup script) ---
|
||||
# TRANSCRIPT_BACKEND=modal
|
||||
# TRANSCRIPT_URL=https://your-gpu-service.example.com
|
||||
# DIARIZATION_BACKEND=modal
|
||||
# DIARIZATION_URL=https://your-gpu-service.example.com
|
||||
# ... (all URLs point to one remote service)
|
||||
|
||||
# Whisper model sizes for local transcription (--cpu mode)
|
||||
# Options: "tiny", "base", "small", "medium", "large-v2"
|
||||
# WHISPER_CHUNK_MODEL=tiny
|
||||
# WHISPER_FILE_MODEL=tiny
|
||||
|
||||
# HuggingFace token — for gated models (e.g. pyannote diarization).
|
||||
# Required for --gpu and --cpu modes; falls back to public S3 bundle if not set.
|
||||
# Not needed for --hosted mode (remote service handles its own auth).
|
||||
# HF_TOKEN=hf_xxxxx
|
||||
|
||||
# =======================================================
|
||||
@@ -93,15 +116,42 @@ TRANSCRIPT_STORAGE_AWS_REGION=us-east-1
|
||||
# =======================================================
|
||||
# Daily.co Live Rooms (Optional)
|
||||
# Enable real-time meeting rooms with Daily.co integration.
|
||||
# Requires a Daily.co account: https://www.daily.co/
|
||||
# Configure these BEFORE running setup-selfhosted.sh and the
|
||||
# script will auto-detect and start Hatchet workflow services.
|
||||
#
|
||||
# Prerequisites:
|
||||
# 1. Daily.co account: https://www.daily.co/
|
||||
# 2. API key: Dashboard → Developers → API Keys
|
||||
# 3. S3 bucket for recordings: https://docs.daily.co/guides/products/live-streaming-recording/storing-recordings-in-a-custom-s3-bucket
|
||||
# 4. IAM role ARN for Daily.co to write recordings to your bucket
|
||||
#
|
||||
# After configuring, run: ./scripts/setup-selfhosted.sh <your-flags>
|
||||
# The script will detect DAILY_API_KEY and automatically:
|
||||
# - Start Hatchet workflow engine + CPU/LLM workers
|
||||
# - Generate a Hatchet API token
|
||||
# - Enable FEATURE_ROOMS in the frontend
|
||||
# =======================================================
|
||||
# DEFAULT_VIDEO_PLATFORM=daily
|
||||
# DAILY_API_KEY=your-daily-api-key
|
||||
# DAILY_SUBDOMAIN=your-subdomain
|
||||
# DAILY_WEBHOOK_SECRET=your-daily-webhook-secret
|
||||
# DEFAULT_VIDEO_PLATFORM=daily
|
||||
# DAILYCO_STORAGE_AWS_BUCKET_NAME=reflector-dailyco
|
||||
# DAILYCO_STORAGE_AWS_REGION=us-east-1
|
||||
# DAILYCO_STORAGE_AWS_ROLE_ARN=arn:aws:iam::role/DailyCoAccess
|
||||
# Worker credentials for reading/deleting from Daily's recording bucket
|
||||
# Required when transcript storage is separate from Daily's bucket (e.g., selfhosted with Garage)
|
||||
# DAILYCO_STORAGE_AWS_ACCESS_KEY_ID=your-aws-access-key
|
||||
# DAILYCO_STORAGE_AWS_SECRET_ACCESS_KEY=your-aws-secret-key
|
||||
# DAILY_WEBHOOK_SECRET=your-daily-webhook-secret # optional, for faster recording discovery
|
||||
|
||||
# =======================================================
|
||||
# Hatchet Workflow Engine (Auto-configured for Daily.co)
|
||||
# Required for Daily.co multitrack recording processing.
|
||||
# The setup script generates HATCHET_CLIENT_TOKEN automatically.
|
||||
# Do not set these manually unless you know what you're doing.
|
||||
# =======================================================
|
||||
# HATCHET_CLIENT_TOKEN=<auto-generated-by-script>
|
||||
# HATCHET_CLIENT_SERVER_URL=http://hatchet:8888
|
||||
# HATCHET_CLIENT_HOST_PORT=hatchet:7077
|
||||
|
||||
# =======================================================
|
||||
# Feature Flags
|
||||
|
||||
@@ -6,7 +6,7 @@ ENV PYTHONUNBUFFERED=1 \
|
||||
|
||||
# builder install base dependencies
|
||||
WORKDIR /tmp
|
||||
RUN apt-get update && apt-get install -y curl && apt-get clean
|
||||
RUN apt-get update && apt-get install -y curl ffmpeg && apt-get clean
|
||||
ADD https://astral.sh/uv/install.sh /uv-installer.sh
|
||||
RUN sh /uv-installer.sh && rm /uv-installer.sh
|
||||
ENV PATH="/root/.local/bin/:$PATH"
|
||||
|
||||
@@ -27,7 +27,7 @@ dependencies = [
|
||||
"protobuf>=4.24.3",
|
||||
"celery>=5.3.4",
|
||||
"redis>=5.0.1",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"pyjwt[crypto]>=2.8.0",
|
||||
"python-multipart>=0.0.6",
|
||||
"transformers>=4.36.2",
|
||||
"jsonschema>=4.23.0",
|
||||
@@ -38,7 +38,7 @@ dependencies = [
|
||||
"pytest-env>=1.1.5",
|
||||
"webvtt-py>=0.5.0",
|
||||
"icalendar>=6.0.0",
|
||||
"hatchet-sdk>=0.47.0",
|
||||
"hatchet-sdk==1.22.16",
|
||||
"pydantic>=2.12.5",
|
||||
]
|
||||
|
||||
@@ -71,9 +71,12 @@ local = [
|
||||
"faster-whisper>=0.10.0",
|
||||
]
|
||||
silero-vad = [
|
||||
"silero-vad>=5.1.2",
|
||||
"silero-vad==5.1.2",
|
||||
"torch>=2.8.0",
|
||||
"torchaudio>=2.8.0",
|
||||
"pyannote.audio==3.4.0",
|
||||
"pytorch-lightning<2.6",
|
||||
"librosa==0.10.1",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
@@ -113,9 +116,10 @@ source = ["reflector"]
|
||||
ENVIRONMENT = "pytest"
|
||||
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
|
||||
AUTH_BACKEND = "jwt"
|
||||
HATCHET_CLIENT_TOKEN = "test-dummy-token"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v --ignore=tests/integration"
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
markers = [
|
||||
|
||||
13
server/reflector/_warnings_filter.py
Normal file
13
server/reflector/_warnings_filter.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Suppress known dependency warnings. Import this before any reflector/hatchet_sdk
|
||||
imports that pull in pydantic (e.g. llama_index) to hide UnsupportedFieldAttributeWarning
|
||||
about validate_default.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=".*validate_default.*",
|
||||
category=UserWarning,
|
||||
)
|
||||
@@ -12,8 +12,10 @@ AccessTokenInfo = auth_module.AccessTokenInfo
|
||||
authenticated = auth_module.authenticated
|
||||
current_user = auth_module.current_user
|
||||
current_user_optional = auth_module.current_user_optional
|
||||
current_user_optional_if_public_mode = auth_module.current_user_optional_if_public_mode
|
||||
parse_ws_bearer_token = auth_module.parse_ws_bearer_token
|
||||
current_user_ws_optional = auth_module.current_user_ws_optional
|
||||
verify_raw_token = auth_module.verify_raw_token
|
||||
|
||||
# Optional router (e.g. for /auth/login in password backend)
|
||||
router = getattr(auth_module, "router", None)
|
||||
|
||||
@@ -4,8 +4,8 @@ from fastapi import Depends, HTTPException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
import jwt
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.db.user_api_keys import user_api_keys_controller
|
||||
@@ -54,7 +54,7 @@ class JWTAuth:
|
||||
audience=jwt_audience,
|
||||
)
|
||||
return payload
|
||||
except JWTError as e:
|
||||
except jwt.PyJWTError as e:
|
||||
logger.error(f"JWT error: {e}")
|
||||
raise
|
||||
|
||||
@@ -94,7 +94,7 @@ async def _authenticate_user(
|
||||
)
|
||||
|
||||
user_infos.append(UserInfo(sub=user.id, email=email))
|
||||
except JWTError as e:
|
||||
except jwt.PyJWTError as e:
|
||||
logger.error(f"JWT error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication")
|
||||
|
||||
@@ -129,6 +129,17 @@ async def current_user_optional(
|
||||
return await _authenticate_user(jwt_token, api_key, jwtauth)
|
||||
|
||||
|
||||
async def current_user_optional_if_public_mode(
|
||||
jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
||||
api_key: Annotated[Optional[str], Depends(api_key_header)],
|
||||
jwtauth: JWTAuth = Depends(),
|
||||
) -> Optional[UserInfo]:
|
||||
user = await _authenticate_user(jwt_token, api_key, jwtauth)
|
||||
if user is None and not settings.PUBLIC_MODE:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
return user
|
||||
|
||||
|
||||
def parse_ws_bearer_token(
|
||||
websocket: "WebSocket",
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
@@ -144,3 +155,8 @@ async def current_user_ws_optional(websocket: "WebSocket") -> Optional[UserInfo]
|
||||
if not token:
|
||||
return None
|
||||
return await _authenticate_user(token, None, JWTAuth())
|
||||
|
||||
|
||||
async def verify_raw_token(token: str) -> Optional[UserInfo]:
|
||||
"""Verify a raw JWT token string (used for query-param auth fallback)."""
|
||||
return await _authenticate_user(token, None, JWTAuth())
|
||||
|
||||
@@ -21,9 +21,19 @@ def current_user_optional():
|
||||
return None
|
||||
|
||||
|
||||
def current_user_optional_if_public_mode():
|
||||
# auth_none means no authentication at all — always public
|
||||
return None
|
||||
|
||||
|
||||
def parse_ws_bearer_token(websocket):
|
||||
return None, None
|
||||
|
||||
|
||||
async def current_user_ws_optional(websocket):
|
||||
return None
|
||||
|
||||
|
||||
async def verify_raw_token(token):
|
||||
"""Verify a raw JWT token string (used for query-param auth fallback)."""
|
||||
return None
|
||||
|
||||
@@ -9,9 +9,9 @@ from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Annotated, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.auth.password_utils import verify_password
|
||||
@@ -110,7 +110,7 @@ async def _authenticate_user(
|
||||
user_id = payload["sub"]
|
||||
email = payload.get("email")
|
||||
user_infos.append(UserInfo(sub=user_id, email=email))
|
||||
except JWTError as e:
|
||||
except jwt.PyJWTError as e:
|
||||
logger.error(f"JWT error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication")
|
||||
|
||||
@@ -150,6 +150,16 @@ async def current_user_optional(
|
||||
return await _authenticate_user(jwt_token, api_key)
|
||||
|
||||
|
||||
async def current_user_optional_if_public_mode(
|
||||
jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
||||
api_key: Annotated[Optional[str], Depends(api_key_header)],
|
||||
) -> Optional[UserInfo]:
|
||||
user = await _authenticate_user(jwt_token, api_key)
|
||||
if user is None and not settings.PUBLIC_MODE:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
return user
|
||||
|
||||
|
||||
# --- WebSocket auth (same pattern as auth_jwt.py) ---
|
||||
def parse_ws_bearer_token(
|
||||
websocket: "WebSocket",
|
||||
@@ -168,6 +178,11 @@ async def current_user_ws_optional(websocket: "WebSocket") -> Optional[UserInfo]
|
||||
return await _authenticate_user(token, None)
|
||||
|
||||
|
||||
async def verify_raw_token(token: str) -> Optional[UserInfo]:
|
||||
"""Verify a raw JWT token string (used for query-param auth fallback)."""
|
||||
return await _authenticate_user(token, None)
|
||||
|
||||
|
||||
# --- Login router ---
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@@ -697,6 +697,18 @@ class TranscriptController:
|
||||
return False
|
||||
return user_id and transcript.user_id == user_id
|
||||
|
||||
@staticmethod
|
||||
def check_can_mutate(transcript: Transcript, user_id: str | None) -> None:
|
||||
"""
|
||||
Raises HTTP 403 if the user cannot mutate the transcript.
|
||||
|
||||
Policy:
|
||||
- Anonymous transcripts (user_id is None) are editable by anyone
|
||||
- Owned transcripts can only be mutated by their owner
|
||||
"""
|
||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
"""
|
||||
|
||||
@@ -26,6 +26,21 @@ class TaskName(StrEnum):
|
||||
DETECT_CHUNK_TOPIC = "detect_chunk_topic"
|
||||
GENERATE_DETAILED_SUMMARY = "generate_detailed_summary"
|
||||
|
||||
# File pipeline tasks
|
||||
EXTRACT_AUDIO = "extract_audio"
|
||||
UPLOAD_AUDIO = "upload_audio"
|
||||
TRANSCRIBE = "transcribe"
|
||||
DIARIZE = "diarize"
|
||||
ASSEMBLE_TRANSCRIPT = "assemble_transcript"
|
||||
GENERATE_SUMMARIES = "generate_summaries"
|
||||
|
||||
# Live post-processing pipeline tasks
|
||||
WAVEFORM = "waveform"
|
||||
CONVERT_MP3 = "convert_mp3"
|
||||
UPLOAD_MP3 = "upload_mp3"
|
||||
REMOVE_UPLOAD = "remove_upload"
|
||||
FINAL_SUMMARIES = "final_summaries"
|
||||
|
||||
|
||||
# Rate limit key for LLM API calls (shared across all LLM-calling tasks)
|
||||
LLM_RATE_LIMIT_KEY = "llm"
|
||||
@@ -39,5 +54,12 @@ TIMEOUT_MEDIUM = (
|
||||
300 # Single LLM calls, waveform generation (5m for slow LLM responses)
|
||||
)
|
||||
TIMEOUT_LONG = 180 # Action items (larger context LLM)
|
||||
TIMEOUT_AUDIO = 720 # Audio processing: padding, mixdown
|
||||
TIMEOUT_HEAVY = 600 # Transcription, fan-out LLM tasks
|
||||
TIMEOUT_TITLE = 300 # generate_title (single LLM call; doc: reduce from 600s)
|
||||
TIMEOUT_AUDIO = 720 # Audio processing: padding, mixdown (Hatchet execution_timeout)
|
||||
TIMEOUT_AUDIO_HTTP = (
|
||||
660 # httpx timeout for pad_track — below 720 so Hatchet doesn't race
|
||||
)
|
||||
TIMEOUT_HEAVY = 600 # Transcription, fan-out LLM tasks (Hatchet execution_timeout)
|
||||
TIMEOUT_HEAVY_HTTP = (
|
||||
540 # httpx timeout for transcribe_track — below 600 so Hatchet doesn't race
|
||||
)
|
||||
|
||||
74
server/reflector/hatchet/error_classification.py
Normal file
74
server/reflector/hatchet/error_classification.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Classify exceptions as non-retryable for Hatchet workflows.
|
||||
|
||||
When a task raises NonRetryableException (or an exception classified as
|
||||
non-retryable and re-raised as such), Hatchet stops immediately — no further
|
||||
retries. Used by with_error_handling to avoid wasting retries on config errors,
|
||||
auth failures, corrupt data, etc.
|
||||
"""
|
||||
|
||||
# Optional dependencies: only classify if the exception type is available.
|
||||
# This avoids hard dependency on openai/av/botocore for code paths that don't use them.
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
openai = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import av
|
||||
except ImportError:
|
||||
av = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
from botocore.exceptions import ClientError as BotoClientError
|
||||
except ImportError:
|
||||
BotoClientError = None # type: ignore[misc, assignment]
|
||||
|
||||
from hatchet_sdk import NonRetryableException
|
||||
from httpx import HTTPStatusError
|
||||
|
||||
from reflector.llm import LLMParseError
|
||||
|
||||
# HTTP status codes that won't change on retry (auth, not found, payment, payload)
|
||||
NON_RETRYABLE_HTTP_STATUSES = {401, 402, 403, 404, 413}
|
||||
NON_RETRYABLE_S3_CODES = {"AccessDenied", "NoSuchBucket", "NoSuchKey"}
|
||||
|
||||
|
||||
def is_non_retryable(e: BaseException) -> bool:
|
||||
"""Return True if the exception should stop Hatchet retries immediately.
|
||||
|
||||
Hard failures (config, auth, missing resource, corrupt data) return True.
|
||||
Transient errors (timeouts, 5xx, 429, connection) return False.
|
||||
"""
|
||||
if isinstance(e, NonRetryableException):
|
||||
return True
|
||||
|
||||
# Config/input errors
|
||||
if isinstance(e, (ValueError, TypeError)):
|
||||
return True
|
||||
|
||||
# HTTP status codes that won't change on retry
|
||||
if isinstance(e, HTTPStatusError):
|
||||
return e.response.status_code in NON_RETRYABLE_HTTP_STATUSES
|
||||
|
||||
# OpenAI auth errors
|
||||
if openai is not None and isinstance(e, openai.AuthenticationError):
|
||||
return True
|
||||
|
||||
# LLM parse failures (already retried internally)
|
||||
if isinstance(e, LLMParseError):
|
||||
return True
|
||||
|
||||
# S3 permission/existence errors
|
||||
if BotoClientError is not None and isinstance(e, BotoClientError):
|
||||
code = e.response.get("Error", {}).get("Code", "")
|
||||
return code in NON_RETRYABLE_S3_CODES
|
||||
|
||||
# Corrupt audio (PyAV) — AVError in some versions; fallback to InvalidDataError
|
||||
if av is not None:
|
||||
av_error = getattr(av, "AVError", None) or getattr(
|
||||
getattr(av, "error", None), "InvalidDataError", None
|
||||
)
|
||||
if av_error is not None and isinstance(e, av_error):
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -7,6 +7,7 @@ Configuration:
|
||||
- Worker affinity: pool=cpu-heavy
|
||||
"""
|
||||
|
||||
import reflector._warnings_filter # noqa: F401 -- side effect: suppress pydantic validate_default warning
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
daily_multitrack_pipeline,
|
||||
|
||||
@@ -5,10 +5,13 @@ Handles: all tasks except mixdown_tracks (transcription, LLM inference, orchestr
|
||||
|
||||
import asyncio
|
||||
|
||||
import reflector._warnings_filter # noqa: F401 -- side effect: suppress pydantic validate_default warning
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
daily_multitrack_pipeline,
|
||||
)
|
||||
from reflector.hatchet.workflows.file_pipeline import file_pipeline
|
||||
from reflector.hatchet.workflows.live_post_pipeline import live_post_pipeline
|
||||
from reflector.hatchet.workflows.subject_processing import subject_workflow
|
||||
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
|
||||
from reflector.hatchet.workflows.track_processing import track_workflow
|
||||
@@ -46,6 +49,8 @@ def main():
|
||||
},
|
||||
workflows=[
|
||||
daily_multitrack_pipeline,
|
||||
file_pipeline,
|
||||
live_post_pipeline,
|
||||
topic_chunk_workflow,
|
||||
subject_workflow,
|
||||
track_workflow,
|
||||
|
||||
@@ -27,6 +27,7 @@ from hatchet_sdk import (
|
||||
ConcurrencyExpression,
|
||||
ConcurrencyLimitStrategy,
|
||||
Context,
|
||||
NonRetryableException,
|
||||
)
|
||||
from hatchet_sdk.labels import DesiredWorkerLabel
|
||||
from pydantic import BaseModel
|
||||
@@ -43,8 +44,10 @@ from reflector.hatchet.constants import (
|
||||
TIMEOUT_LONG,
|
||||
TIMEOUT_MEDIUM,
|
||||
TIMEOUT_SHORT,
|
||||
TIMEOUT_TITLE,
|
||||
TaskName,
|
||||
)
|
||||
from reflector.hatchet.error_classification import is_non_retryable
|
||||
from reflector.hatchet.workflows.models import (
|
||||
ActionItemsResult,
|
||||
ConsentResult,
|
||||
@@ -90,7 +93,6 @@ from reflector.processors.summary.summary_builder import SummaryBuilder
|
||||
from reflector.processors.types import TitleSummary, Word
|
||||
from reflector.processors.types import Transcript as TranscriptType
|
||||
from reflector.settings import settings
|
||||
from reflector.storage.storage_aws import AwsStorage
|
||||
from reflector.utils.audio_constants import (
|
||||
PRESIGNED_URL_EXPIRATION_SECONDS,
|
||||
WAVEFORM_SEGMENTS,
|
||||
@@ -117,6 +119,7 @@ class PipelineInput(BaseModel):
|
||||
bucket_name: NonEmptyString
|
||||
transcript_id: NonEmptyString
|
||||
room_id: NonEmptyString | None = None
|
||||
source_platform: str = "daily"
|
||||
|
||||
|
||||
hatchet = HatchetClientManager.get_client()
|
||||
@@ -170,15 +173,10 @@ async def set_workflow_error_status(transcript_id: NonEmptyString) -> bool:
|
||||
|
||||
|
||||
def _spawn_storage():
|
||||
"""Create fresh storage instance."""
|
||||
# TODO: replace direct AwsStorage construction with get_transcripts_storage() factory
|
||||
return AwsStorage(
|
||||
aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
|
||||
aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
|
||||
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||
aws_endpoint_url=settings.TRANSCRIPT_STORAGE_AWS_ENDPOINT_URL,
|
||||
)
|
||||
"""Create fresh storage instance for writing to our transcript bucket."""
|
||||
from reflector.storage import get_transcripts_storage # noqa: PLC0415
|
||||
|
||||
return get_transcripts_storage()
|
||||
|
||||
|
||||
class Loggable(Protocol):
|
||||
@@ -221,6 +219,13 @@ def make_audio_progress_logger(
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _successful_run_results(
|
||||
results: list[dict[str, Any] | BaseException],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return only successful (non-exception) results from aio_run_many(return_exceptions=True)."""
|
||||
return [r for r in results if not isinstance(r, BaseException)]
|
||||
|
||||
|
||||
def with_error_handling(
|
||||
step_name: TaskName, set_error_status: bool = True
|
||||
) -> Callable[
|
||||
@@ -248,8 +253,12 @@ def with_error_handling(
|
||||
error=str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
if set_error_status:
|
||||
await set_workflow_error_status(input.transcript_id)
|
||||
if is_non_retryable(e):
|
||||
# Hard fail: stop retries, set error status, fail workflow
|
||||
if set_error_status:
|
||||
await set_workflow_error_status(input.transcript_id)
|
||||
raise NonRetryableException(str(e)) from e
|
||||
# Transient: do not set error status — Hatchet will retry
|
||||
raise
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
@@ -258,7 +267,10 @@ def with_error_handling(
|
||||
|
||||
|
||||
@daily_multitrack_pipeline.task(
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.GET_RECORDING)
|
||||
async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
|
||||
@@ -295,7 +307,9 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
|
||||
ctx.log(
|
||||
f"get_recording: calling Daily.co API for recording_id={input.recording_id}..."
|
||||
)
|
||||
async with DailyApiClient(api_key=settings.DAILY_API_KEY) as client:
|
||||
async with DailyApiClient(
|
||||
api_key=settings.DAILY_API_KEY, base_url=settings.DAILY_API_URL
|
||||
) as client:
|
||||
recording = await client.get_recording(input.recording_id)
|
||||
ctx.log(f"get_recording: Daily.co API returned successfully")
|
||||
|
||||
@@ -314,6 +328,8 @@ async def get_recording(input: PipelineInput, ctx: Context) -> RecordingResult:
|
||||
parents=[get_recording],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.GET_PARTICIPANTS)
|
||||
async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsResult:
|
||||
@@ -360,7 +376,9 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
|
||||
settings.DAILY_API_KEY, "DAILY_API_KEY is required"
|
||||
)
|
||||
|
||||
async with DailyApiClient(api_key=daily_api_key) as client:
|
||||
async with DailyApiClient(
|
||||
api_key=daily_api_key, base_url=settings.DAILY_API_URL
|
||||
) as client:
|
||||
participants = await client.get_meeting_participants(mtg_session_id)
|
||||
|
||||
id_to_name = {}
|
||||
@@ -417,6 +435,8 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
|
||||
parents=[get_participants],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
@with_error_handling(TaskName.PROCESS_TRACKS)
|
||||
async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksResult:
|
||||
@@ -434,12 +454,13 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
||||
bucket_name=input.bucket_name,
|
||||
transcript_id=input.transcript_id,
|
||||
language=source_language,
|
||||
source_platform=input.source_platform,
|
||||
)
|
||||
)
|
||||
for i, track in enumerate(input.tracks)
|
||||
]
|
||||
|
||||
results = await track_workflow.aio_run_many(bulk_runs)
|
||||
results = await track_workflow.aio_run_many(bulk_runs, return_exceptions=True)
|
||||
|
||||
target_language = participants_result.target_language
|
||||
|
||||
@@ -447,7 +468,18 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
||||
padded_tracks = []
|
||||
created_padded_files = set()
|
||||
|
||||
for result in results:
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(
|
||||
"[Hatchet] process_tracks: track workflow failed, failing step",
|
||||
transcript_id=input.transcript_id,
|
||||
track_index=i,
|
||||
error=str(result),
|
||||
)
|
||||
ctx.log(f"process_tracks: track {i} failed ({result}), failing step")
|
||||
raise ValueError(
|
||||
f"Track {i} workflow failed after retries: {result!s}"
|
||||
) from result
|
||||
transcribe_result = TranscribeTrackResult(**result[TaskName.TRANSCRIBE_TRACK])
|
||||
track_words.append(transcribe_result.words)
|
||||
|
||||
@@ -485,7 +517,9 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
||||
@daily_multitrack_pipeline.task(
|
||||
parents=[process_tracks],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
||||
retries=3,
|
||||
retries=2,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
desired_worker_labels={
|
||||
"pool": DesiredWorkerLabel(
|
||||
value="cpu-heavy",
|
||||
@@ -597,6 +631,8 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
||||
parents=[mixdown_tracks],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.GENERATE_WAVEFORM)
|
||||
async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResult:
|
||||
@@ -665,6 +701,8 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
|
||||
parents=[process_tracks],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
@with_error_handling(TaskName.DETECT_TOPICS)
|
||||
async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
||||
@@ -726,11 +764,22 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
||||
for chunk in chunks
|
||||
]
|
||||
|
||||
results = await topic_chunk_workflow.aio_run_many(bulk_runs)
|
||||
results = await topic_chunk_workflow.aio_run_many(bulk_runs, return_exceptions=True)
|
||||
|
||||
topic_chunks = [
|
||||
TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]) for result in results
|
||||
]
|
||||
topic_chunks: list[TopicChunkResult] = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(
|
||||
"[Hatchet] detect_topics: chunk workflow failed, failing step",
|
||||
transcript_id=input.transcript_id,
|
||||
chunk_index=i,
|
||||
error=str(result),
|
||||
)
|
||||
ctx.log(f"detect_topics: chunk {i} failed ({result}), failing step")
|
||||
raise ValueError(
|
||||
f"Topic chunk {i} workflow failed after retries: {result!s}"
|
||||
) from result
|
||||
topic_chunks.append(TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]))
|
||||
|
||||
async with fresh_db_connection():
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
@@ -768,8 +817,10 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
|
||||
|
||||
@daily_multitrack_pipeline.task(
|
||||
parents=[detect_topics],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_TITLE),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.GENERATE_TITLE)
|
||||
async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
|
||||
@@ -834,7 +885,9 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
|
||||
@daily_multitrack_pipeline.task(
|
||||
parents=[detect_topics],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
retries=5,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
@with_error_handling(TaskName.EXTRACT_SUBJECTS)
|
||||
async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult:
|
||||
@@ -913,6 +966,8 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
|
||||
parents=[extract_subjects],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
@with_error_handling(TaskName.PROCESS_SUBJECTS)
|
||||
async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubjectsResult:
|
||||
@@ -939,12 +994,24 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject
|
||||
for i, subject in enumerate(subjects)
|
||||
]
|
||||
|
||||
results = await subject_workflow.aio_run_many(bulk_runs)
|
||||
results = await subject_workflow.aio_run_many(bulk_runs, return_exceptions=True)
|
||||
|
||||
subject_summaries = [
|
||||
SubjectSummaryResult(**result[TaskName.GENERATE_DETAILED_SUMMARY])
|
||||
for result in results
|
||||
]
|
||||
subject_summaries: list[SubjectSummaryResult] = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(
|
||||
"[Hatchet] process_subjects: subject workflow failed, failing step",
|
||||
transcript_id=input.transcript_id,
|
||||
subject_index=i,
|
||||
error=str(result),
|
||||
)
|
||||
ctx.log(f"process_subjects: subject {i} failed ({result}), failing step")
|
||||
raise ValueError(
|
||||
f"Subject {i} workflow failed after retries: {result!s}"
|
||||
) from result
|
||||
subject_summaries.append(
|
||||
SubjectSummaryResult(**result[TaskName.GENERATE_DETAILED_SUMMARY])
|
||||
)
|
||||
|
||||
ctx.log(f"process_subjects complete: {len(subject_summaries)} summaries")
|
||||
|
||||
@@ -955,6 +1022,8 @@ async def process_subjects(input: PipelineInput, ctx: Context) -> ProcessSubject
|
||||
parents=[process_subjects],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.GENERATE_RECAP)
|
||||
async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
|
||||
@@ -1044,6 +1113,8 @@ async def generate_recap(input: PipelineInput, ctx: Context) -> RecapResult:
|
||||
parents=[extract_subjects],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_LONG),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.IDENTIFY_ACTION_ITEMS)
|
||||
async def identify_action_items(
|
||||
@@ -1112,6 +1183,8 @@ async def identify_action_items(
|
||||
parents=[process_tracks, generate_title, generate_recap, identify_action_items],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=5,
|
||||
)
|
||||
@with_error_handling(TaskName.FINALIZE)
|
||||
async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
|
||||
@@ -1181,7 +1254,11 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
|
||||
|
||||
|
||||
@daily_multitrack_pipeline.task(
|
||||
parents=[finalize], execution_timeout=timedelta(seconds=TIMEOUT_SHORT), retries=3
|
||||
parents=[finalize],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
|
||||
async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
|
||||
@@ -1195,7 +1272,10 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
|
||||
)
|
||||
from reflector.db.recordings import recordings_controller # noqa: PLC0415
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||
from reflector.storage import get_transcripts_storage # noqa: PLC0415
|
||||
from reflector.storage import ( # noqa: PLC0415
|
||||
get_source_storage,
|
||||
get_transcripts_storage,
|
||||
)
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if not transcript:
|
||||
@@ -1245,7 +1325,7 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
|
||||
deletion_errors = []
|
||||
|
||||
if input_track_keys and input.bucket_name:
|
||||
master_storage = get_transcripts_storage()
|
||||
master_storage = get_source_storage(input.source_platform)
|
||||
for key in input_track_keys:
|
||||
try:
|
||||
await master_storage.delete_file(key, bucket=input.bucket_name)
|
||||
@@ -1284,6 +1364,8 @@ async def cleanup_consent(input: PipelineInput, ctx: Context) -> ConsentResult:
|
||||
parents=[cleanup_consent],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=5,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
|
||||
async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
|
||||
@@ -1311,6 +1393,8 @@ async def post_zulip(input: PipelineInput, ctx: Context) -> ZulipResult:
|
||||
parents=[cleanup_consent],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=5,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
|
||||
async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
|
||||
@@ -1379,3 +1463,32 @@ async def send_webhook(input: PipelineInput, ctx: Context) -> WebhookResult:
|
||||
except Exception as e:
|
||||
ctx.log(f"send_webhook unexpected error, continuing anyway: {e}")
|
||||
return WebhookResult(webhook_sent=False)
|
||||
|
||||
|
||||
async def on_workflow_failure(input: PipelineInput, ctx: Context) -> None:
|
||||
"""Run when the workflow is truly dead (all retries exhausted).
|
||||
|
||||
Sets transcript status to 'error' only if it is not already 'ended'.
|
||||
Post-finalize tasks (cleanup_consent, post_zulip, send_webhook) use
|
||||
set_error_status=False; if one of them fails, we must not overwrite
|
||||
the 'ended' status that finalize already set.
|
||||
"""
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if transcript and transcript.status == "ended":
|
||||
logger.info(
|
||||
"[Hatchet] on_workflow_failure: transcript already ended, skipping error status (failure was post-finalize)",
|
||||
transcript_id=input.transcript_id,
|
||||
)
|
||||
ctx.log(
|
||||
"on_workflow_failure: transcript already ended, skipping error status"
|
||||
)
|
||||
return
|
||||
await set_workflow_error_status(input.transcript_id)
|
||||
|
||||
|
||||
@daily_multitrack_pipeline.on_failure_task()
|
||||
async def _register_on_workflow_failure(input: PipelineInput, ctx: Context) -> None:
|
||||
await on_workflow_failure(input, ctx)
|
||||
|
||||
885
server/reflector/hatchet/workflows/file_pipeline.py
Normal file
885
server/reflector/hatchet/workflows/file_pipeline.py
Normal file
@@ -0,0 +1,885 @@
|
||||
"""
|
||||
Hatchet workflow: FilePipeline
|
||||
|
||||
Processing pipeline for file uploads and Whereby recordings.
|
||||
Orchestrates: extract audio → upload → transcribe/diarize/waveform (parallel)
|
||||
→ assemble → detect topics → title/summaries (parallel) → finalize
|
||||
→ cleanup consent → post zulip / send webhook.
|
||||
|
||||
Note: This file uses deferred imports (inside functions/tasks) intentionally.
|
||||
Hatchet workers run in forked processes; fresh imports per task ensure DB connections
|
||||
are not shared across forks, avoiding connection pooling issues.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from hatchet_sdk import Context
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.hatchet.broadcast import (
|
||||
append_event_and_broadcast,
|
||||
set_status_and_broadcast,
|
||||
)
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.constants import (
|
||||
TIMEOUT_HEAVY,
|
||||
TIMEOUT_MEDIUM,
|
||||
TIMEOUT_SHORT,
|
||||
TIMEOUT_TITLE,
|
||||
TaskName,
|
||||
)
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
fresh_db_connection,
|
||||
set_workflow_error_status,
|
||||
with_error_handling,
|
||||
)
|
||||
from reflector.hatchet.workflows.models import (
|
||||
ConsentResult,
|
||||
TitleResult,
|
||||
TopicsResult,
|
||||
WaveformResult,
|
||||
WebhookResult,
|
||||
ZulipResult,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines import topic_processing
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_constants import WAVEFORM_SEGMENTS
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
|
||||
|
||||
class FilePipelineInput(BaseModel):
|
||||
transcript_id: str
|
||||
room_id: str | None = None
|
||||
|
||||
|
||||
# --- Result models specific to file pipeline ---
|
||||
|
||||
|
||||
class ExtractAudioResult(BaseModel):
|
||||
audio_path: str
|
||||
duration_ms: float = 0.0
|
||||
|
||||
|
||||
class UploadAudioResult(BaseModel):
|
||||
audio_url: str
|
||||
audio_path: str
|
||||
|
||||
|
||||
class TranscribeResult(BaseModel):
|
||||
words: list[dict]
|
||||
translation: str | None = None
|
||||
|
||||
|
||||
class DiarizeResult(BaseModel):
|
||||
diarization: list[dict] | None = None
|
||||
|
||||
|
||||
class AssembleTranscriptResult(BaseModel):
|
||||
assembled: bool
|
||||
|
||||
|
||||
class SummariesResult(BaseModel):
|
||||
generated: bool
|
||||
|
||||
|
||||
class FinalizeResult(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
hatchet = HatchetClientManager.get_client()
|
||||
|
||||
file_pipeline = hatchet.workflow(name="FilePipeline", input_validator=FilePipelineInput)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.EXTRACT_AUDIO)
|
||||
async def extract_audio(input: FilePipelineInput, ctx: Context) -> ExtractAudioResult:
|
||||
"""Extract audio from upload file, convert to MP3."""
|
||||
ctx.log(f"extract_audio: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||
|
||||
await set_status_and_broadcast(input.transcript_id, "processing", logger=logger)
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if not transcript:
|
||||
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||
|
||||
# Clear transcript as we're going to regenerate everything
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": [],
|
||||
"topics": [],
|
||||
},
|
||||
)
|
||||
|
||||
# Find upload file
|
||||
audio_file = next(transcript.data_path.glob("upload.*"), None)
|
||||
if not audio_file:
|
||||
audio_file = next(transcript.data_path.glob("audio.*"), None)
|
||||
if not audio_file:
|
||||
raise ValueError("No audio file found to process")
|
||||
|
||||
ctx.log(f"extract_audio: processing {audio_file}")
|
||||
|
||||
# Extract audio and write as MP3
|
||||
import av # noqa: PLC0415
|
||||
|
||||
from reflector.processors import AudioFileWriterProcessor # noqa: PLC0415
|
||||
|
||||
duration_ms_container = [0.0]
|
||||
|
||||
async def capture_duration(d):
|
||||
duration_ms_container[0] = d
|
||||
|
||||
mp3_writer = AudioFileWriterProcessor(
|
||||
path=transcript.audio_mp3_filename,
|
||||
on_duration=capture_duration,
|
||||
)
|
||||
input_container = av.open(str(audio_file))
|
||||
for frame in input_container.decode(audio=0):
|
||||
await mp3_writer.push(frame)
|
||||
await mp3_writer.flush()
|
||||
input_container.close()
|
||||
|
||||
duration_ms = duration_ms_container[0]
|
||||
audio_path = str(transcript.audio_mp3_filename)
|
||||
|
||||
# Persist duration to database and broadcast to websocket clients
|
||||
from reflector.db.transcripts import TranscriptDuration # noqa: PLC0415
|
||||
from reflector.db.transcripts import transcripts_controller as tc
|
||||
|
||||
await tc.update(transcript, {"duration": duration_ms})
|
||||
await append_event_and_broadcast(
|
||||
input.transcript_id,
|
||||
transcript,
|
||||
"DURATION",
|
||||
TranscriptDuration(duration=duration_ms),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
ctx.log(f"extract_audio complete: {audio_path}, duration={duration_ms}ms")
|
||||
return ExtractAudioResult(audio_path=audio_path, duration_ms=duration_ms)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[extract_audio],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.UPLOAD_AUDIO)
|
||||
async def upload_audio(input: FilePipelineInput, ctx: Context) -> UploadAudioResult:
|
||||
"""Upload audio to S3/storage, return audio_url."""
|
||||
ctx.log(f"upload_audio: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
extract_result = ctx.task_output(extract_audio)
|
||||
audio_path = extract_result.audio_path
|
||||
|
||||
from reflector.storage import get_transcripts_storage # noqa: PLC0415
|
||||
|
||||
storage = get_transcripts_storage()
|
||||
if not storage:
|
||||
raise ValueError(
|
||||
"Storage backend required for file processing. "
|
||||
"Configure TRANSCRIPT_STORAGE_* settings."
|
||||
)
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
storage_path = f"file_pipeline/{input.transcript_id}/audio.mp3"
|
||||
await storage.put_file(storage_path, audio_data)
|
||||
audio_url = await storage.get_file_url(storage_path)
|
||||
|
||||
ctx.log(f"upload_audio complete: {audio_url}")
|
||||
return UploadAudioResult(audio_url=audio_url, audio_path=audio_path)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[upload_audio],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
@with_error_handling(TaskName.TRANSCRIBE)
|
||||
async def transcribe(input: FilePipelineInput, ctx: Context) -> TranscribeResult:
|
||||
"""Transcribe the audio file using the configured backend."""
|
||||
ctx.log(f"transcribe: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
upload_result = ctx.task_output(upload_audio)
|
||||
audio_url = upload_result.audio_url
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if not transcript:
|
||||
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||
source_language = transcript.source_language
|
||||
|
||||
from reflector.pipelines.transcription_helpers import ( # noqa: PLC0415
|
||||
transcribe_file_with_processor,
|
||||
)
|
||||
|
||||
result = await transcribe_file_with_processor(audio_url, source_language)
|
||||
|
||||
ctx.log(f"transcribe complete: {len(result.words)} words")
|
||||
return TranscribeResult(
|
||||
words=[w.model_dump() for w in result.words],
|
||||
translation=result.translation,
|
||||
)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[upload_audio],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
@with_error_handling(TaskName.DIARIZE)
|
||||
async def diarize(input: FilePipelineInput, ctx: Context) -> DiarizeResult:
|
||||
"""Diarize the audio file (speaker identification)."""
|
||||
ctx.log(f"diarize: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
if not settings.DIARIZATION_BACKEND:
|
||||
ctx.log("diarize: diarization disabled, skipping")
|
||||
return DiarizeResult(diarization=None)
|
||||
|
||||
upload_result = ctx.task_output(upload_audio)
|
||||
audio_url = upload_result.audio_url
|
||||
|
||||
from reflector.processors.file_diarization import ( # noqa: PLC0415
|
||||
FileDiarizationInput,
|
||||
)
|
||||
from reflector.processors.file_diarization_auto import ( # noqa: PLC0415
|
||||
FileDiarizationAutoProcessor,
|
||||
)
|
||||
|
||||
processor = FileDiarizationAutoProcessor()
|
||||
input_data = FileDiarizationInput(audio_url=audio_url)
|
||||
|
||||
result = None
|
||||
|
||||
async def capture_result(diarization_output):
|
||||
nonlocal result
|
||||
result = diarization_output.diarization
|
||||
|
||||
try:
|
||||
processor.on(capture_result)
|
||||
await processor.push(input_data)
|
||||
await processor.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"Diarization failed: {e}")
|
||||
return DiarizeResult(diarization=None)
|
||||
|
||||
ctx.log(f"diarize complete: {len(result) if result else 0} segments")
|
||||
return DiarizeResult(diarization=list(result) if result else None)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[upload_audio],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.GENERATE_WAVEFORM)
|
||||
async def generate_waveform(input: FilePipelineInput, ctx: Context) -> WaveformResult:
|
||||
"""Generate audio waveform visualization."""
|
||||
ctx.log(f"generate_waveform: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
upload_result = ctx.task_output(upload_audio)
|
||||
audio_path = upload_result.audio_path
|
||||
|
||||
from reflector.db.transcripts import ( # noqa: PLC0415
|
||||
TranscriptWaveform,
|
||||
transcripts_controller,
|
||||
)
|
||||
|
||||
waveform = get_audio_waveform(
|
||||
path=Path(audio_path), segments_count=WAVEFORM_SEGMENTS
|
||||
)
|
||||
|
||||
async with fresh_db_connection():
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if transcript:
|
||||
transcript.data_path.mkdir(parents=True, exist_ok=True)
|
||||
with open(transcript.audio_waveform_filename, "w") as f:
|
||||
json.dump(waveform, f)
|
||||
|
||||
waveform_data = TranscriptWaveform(waveform=waveform)
|
||||
await append_event_and_broadcast(
|
||||
input.transcript_id,
|
||||
transcript,
|
||||
"WAVEFORM",
|
||||
waveform_data,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
ctx.log("generate_waveform complete")
|
||||
return WaveformResult(waveform_generated=True)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[transcribe, diarize, generate_waveform],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.ASSEMBLE_TRANSCRIPT)
|
||||
async def assemble_transcript(
|
||||
input: FilePipelineInput, ctx: Context
|
||||
) -> AssembleTranscriptResult:
|
||||
"""Merge transcription + diarization results."""
|
||||
ctx.log(f"assemble_transcript: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
transcribe_result = ctx.task_output(transcribe)
|
||||
diarize_result = ctx.task_output(diarize)
|
||||
|
||||
from reflector.processors.transcript_diarization_assembler import ( # noqa: PLC0415
|
||||
TranscriptDiarizationAssemblerInput,
|
||||
TranscriptDiarizationAssemblerProcessor,
|
||||
)
|
||||
from reflector.processors.types import ( # noqa: PLC0415
|
||||
DiarizationSegment,
|
||||
Word,
|
||||
)
|
||||
from reflector.processors.types import ( # noqa: PLC0415
|
||||
Transcript as TranscriptType,
|
||||
)
|
||||
|
||||
words = [Word(**w) for w in transcribe_result.words]
|
||||
transcript_data = TranscriptType(
|
||||
words=words, translation=transcribe_result.translation
|
||||
)
|
||||
|
||||
diarization = None
|
||||
if diarize_result.diarization:
|
||||
diarization = [DiarizationSegment(**s) for s in diarize_result.diarization]
|
||||
|
||||
processor = TranscriptDiarizationAssemblerProcessor()
|
||||
assembler_input = TranscriptDiarizationAssemblerInput(
|
||||
transcript=transcript_data, diarization=diarization or []
|
||||
)
|
||||
|
||||
diarized_transcript = None
|
||||
|
||||
async def capture_result(transcript):
|
||||
nonlocal diarized_transcript
|
||||
diarized_transcript = transcript
|
||||
|
||||
processor.on(capture_result)
|
||||
await processor.push(assembler_input)
|
||||
await processor.flush()
|
||||
|
||||
if not diarized_transcript:
|
||||
raise ValueError("No diarized transcript captured")
|
||||
|
||||
# Save the assembled transcript events to the database
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.transcripts import ( # noqa: PLC0415
|
||||
TranscriptText,
|
||||
transcripts_controller,
|
||||
)
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if transcript:
|
||||
assembled_text = diarized_transcript.text if diarized_transcript else ""
|
||||
assembled_translation = (
|
||||
diarized_transcript.translation if diarized_transcript else None
|
||||
)
|
||||
await append_event_and_broadcast(
|
||||
input.transcript_id,
|
||||
transcript,
|
||||
"TRANSCRIPT",
|
||||
TranscriptText(text=assembled_text, translation=assembled_translation),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
ctx.log("assemble_transcript complete")
|
||||
return AssembleTranscriptResult(assembled=True)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[assemble_transcript],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
@with_error_handling(TaskName.DETECT_TOPICS)
|
||||
async def detect_topics(input: FilePipelineInput, ctx: Context) -> TopicsResult:
|
||||
"""Detect topics from the assembled transcript."""
|
||||
ctx.log(f"detect_topics: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
# Re-read the transcript to get the diarized words
|
||||
transcribe_result = ctx.task_output(transcribe)
|
||||
diarize_result = ctx.task_output(diarize)
|
||||
|
||||
from reflector.db.transcripts import ( # noqa: PLC0415
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.processors.transcript_diarization_assembler import ( # noqa: PLC0415
|
||||
TranscriptDiarizationAssemblerInput,
|
||||
TranscriptDiarizationAssemblerProcessor,
|
||||
)
|
||||
from reflector.processors.types import ( # noqa: PLC0415
|
||||
DiarizationSegment,
|
||||
Word,
|
||||
)
|
||||
from reflector.processors.types import ( # noqa: PLC0415
|
||||
Transcript as TranscriptType,
|
||||
)
|
||||
|
||||
words = [Word(**w) for w in transcribe_result.words]
|
||||
transcript_data = TranscriptType(
|
||||
words=words, translation=transcribe_result.translation
|
||||
)
|
||||
|
||||
diarization = None
|
||||
if diarize_result.diarization:
|
||||
diarization = [DiarizationSegment(**s) for s in diarize_result.diarization]
|
||||
|
||||
# Re-assemble to get the diarized transcript for topic detection
|
||||
processor = TranscriptDiarizationAssemblerProcessor()
|
||||
assembler_input = TranscriptDiarizationAssemblerInput(
|
||||
transcript=transcript_data, diarization=diarization or []
|
||||
)
|
||||
|
||||
diarized_transcript = None
|
||||
|
||||
async def capture_result(transcript):
|
||||
nonlocal diarized_transcript
|
||||
diarized_transcript = transcript
|
||||
|
||||
processor.on(capture_result)
|
||||
await processor.push(assembler_input)
|
||||
await processor.flush()
|
||||
|
||||
if not diarized_transcript:
|
||||
raise ValueError("No diarized transcript for topic detection")
|
||||
|
||||
async with fresh_db_connection():
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if not transcript:
|
||||
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||
target_language = transcript.target_language
|
||||
|
||||
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
|
||||
|
||||
async def on_topic_callback(data):
|
||||
topic = TranscriptTopic(
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
timestamp=data.timestamp,
|
||||
transcript=data.transcript.text
|
||||
if hasattr(data.transcript, "text")
|
||||
else "",
|
||||
words=data.transcript.words
|
||||
if hasattr(data.transcript, "words")
|
||||
else [],
|
||||
)
|
||||
await transcripts_controller.upsert_topic(transcript, topic)
|
||||
await append_event_and_broadcast(
|
||||
input.transcript_id, transcript, "TOPIC", topic, logger=logger
|
||||
)
|
||||
|
||||
topics = await topic_processing.detect_topics(
|
||||
diarized_transcript,
|
||||
target_language,
|
||||
on_topic_callback=on_topic_callback,
|
||||
empty_pipeline=empty_pipeline,
|
||||
)
|
||||
|
||||
ctx.log(f"detect_topics complete: {len(topics)} topics")
|
||||
return TopicsResult(topics=topics)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[detect_topics],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_TITLE),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.GENERATE_TITLE)
|
||||
async def generate_title(input: FilePipelineInput, ctx: Context) -> TitleResult:
|
||||
"""Generate meeting title using LLM."""
|
||||
ctx.log(f"generate_title: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
topics_result = ctx.task_output(detect_topics)
|
||||
topics = topics_result.topics
|
||||
|
||||
from reflector.db.transcripts import ( # noqa: PLC0415
|
||||
TranscriptFinalTitle,
|
||||
transcripts_controller,
|
||||
)
|
||||
|
||||
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
|
||||
title_result = None
|
||||
|
||||
async with fresh_db_connection():
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if not transcript:
|
||||
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||
|
||||
async def on_title_callback(data):
|
||||
nonlocal title_result
|
||||
title_result = data.title
|
||||
final_title = TranscriptFinalTitle(title=data.title)
|
||||
if not transcript.title:
|
||||
await transcripts_controller.update(
|
||||
transcript, {"title": final_title.title}
|
||||
)
|
||||
await append_event_and_broadcast(
|
||||
input.transcript_id,
|
||||
transcript,
|
||||
"FINAL_TITLE",
|
||||
final_title,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
await topic_processing.generate_title(
|
||||
topics,
|
||||
on_title_callback=on_title_callback,
|
||||
empty_pipeline=empty_pipeline,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
ctx.log(f"generate_title complete: '{title_result}'")
|
||||
return TitleResult(title=title_result)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[detect_topics],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
@with_error_handling(TaskName.GENERATE_SUMMARIES)
|
||||
async def generate_summaries(input: FilePipelineInput, ctx: Context) -> SummariesResult:
|
||||
"""Generate long/short summaries and action items."""
|
||||
ctx.log(f"generate_summaries: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
topics_result = ctx.task_output(detect_topics)
|
||||
topics = topics_result.topics
|
||||
|
||||
from reflector.db.transcripts import ( # noqa: PLC0415
|
||||
TranscriptActionItems,
|
||||
TranscriptFinalLongSummary,
|
||||
TranscriptFinalShortSummary,
|
||||
transcripts_controller,
|
||||
)
|
||||
|
||||
empty_pipeline = topic_processing.EmptyPipeline(logger=logger)
|
||||
|
||||
async with fresh_db_connection():
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if not transcript:
|
||||
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||
|
||||
async def on_long_summary_callback(data):
|
||||
final_long = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||
await transcripts_controller.update(
|
||||
transcript, {"long_summary": final_long.long_summary}
|
||||
)
|
||||
await append_event_and_broadcast(
|
||||
input.transcript_id,
|
||||
transcript,
|
||||
"FINAL_LONG_SUMMARY",
|
||||
final_long,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
async def on_short_summary_callback(data):
|
||||
final_short = TranscriptFinalShortSummary(short_summary=data.short_summary)
|
||||
await transcripts_controller.update(
|
||||
transcript, {"short_summary": final_short.short_summary}
|
||||
)
|
||||
await append_event_and_broadcast(
|
||||
input.transcript_id,
|
||||
transcript,
|
||||
"FINAL_SHORT_SUMMARY",
|
||||
final_short,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
async def on_action_items_callback(data):
|
||||
action_items = TranscriptActionItems(action_items=data.action_items)
|
||||
await transcripts_controller.update(
|
||||
transcript, {"action_items": action_items.action_items}
|
||||
)
|
||||
await append_event_and_broadcast(
|
||||
input.transcript_id,
|
||||
transcript,
|
||||
"ACTION_ITEMS",
|
||||
action_items,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
await topic_processing.generate_summaries(
|
||||
topics,
|
||||
transcript,
|
||||
on_long_summary_callback=on_long_summary_callback,
|
||||
on_short_summary_callback=on_short_summary_callback,
|
||||
on_action_items_callback=on_action_items_callback,
|
||||
empty_pipeline=empty_pipeline,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
ctx.log("generate_summaries complete")
|
||||
return SummariesResult(generated=True)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[generate_title, generate_summaries],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=5,
|
||||
)
|
||||
@with_error_handling(TaskName.FINALIZE)
|
||||
async def finalize(input: FilePipelineInput, ctx: Context) -> FinalizeResult:
|
||||
"""Set transcript status to 'ended' and broadcast."""
|
||||
ctx.log("finalize: setting status to 'ended'")
|
||||
|
||||
async with fresh_db_connection():
|
||||
await set_status_and_broadcast(input.transcript_id, "ended", logger=logger)
|
||||
|
||||
ctx.log("finalize complete")
|
||||
return FinalizeResult(status="COMPLETED")
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[finalize],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
|
||||
async def cleanup_consent(input: FilePipelineInput, ctx: Context) -> ConsentResult:
|
||||
"""Check consent and delete audio files if any participant denied."""
|
||||
ctx.log(f"cleanup_consent: transcript_id={input.transcript_id}")
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.meetings import ( # noqa: PLC0415
|
||||
meeting_consent_controller,
|
||||
meetings_controller,
|
||||
)
|
||||
from reflector.db.recordings import recordings_controller # noqa: PLC0415
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||
from reflector.storage import get_transcripts_storage # noqa: PLC0415
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if not transcript:
|
||||
ctx.log("cleanup_consent: transcript not found")
|
||||
return ConsentResult()
|
||||
|
||||
consent_denied = False
|
||||
recording = None
|
||||
if transcript.recording_id:
|
||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||
if recording and recording.meeting_id:
|
||||
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
||||
if meeting:
|
||||
consent_denied = await meeting_consent_controller.has_any_denial(
|
||||
meeting.id
|
||||
)
|
||||
|
||||
if not consent_denied:
|
||||
ctx.log("cleanup_consent: consent approved, keeping all files")
|
||||
return ConsentResult()
|
||||
|
||||
ctx.log("cleanup_consent: consent denied, deleting audio files")
|
||||
|
||||
deletion_errors = []
|
||||
if recording and recording.bucket_name:
|
||||
keys_to_delete = []
|
||||
if recording.track_keys:
|
||||
keys_to_delete = recording.track_keys
|
||||
elif recording.object_key:
|
||||
keys_to_delete = [recording.object_key]
|
||||
|
||||
master_storage = get_transcripts_storage()
|
||||
for key in keys_to_delete:
|
||||
try:
|
||||
await master_storage.delete_file(key, bucket=recording.bucket_name)
|
||||
ctx.log(f"Deleted recording file: {recording.bucket_name}/{key}")
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to delete {key}: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
deletion_errors.append(error_msg)
|
||||
|
||||
if transcript.audio_location == "storage":
|
||||
storage = get_transcripts_storage()
|
||||
try:
|
||||
await storage.delete_file(transcript.storage_audio_path)
|
||||
ctx.log(f"Deleted processed audio: {transcript.storage_audio_path}")
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to delete processed audio: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
deletion_errors.append(error_msg)
|
||||
|
||||
try:
|
||||
if (
|
||||
hasattr(transcript, "audio_mp3_filename")
|
||||
and transcript.audio_mp3_filename
|
||||
):
|
||||
transcript.audio_mp3_filename.unlink(missing_ok=True)
|
||||
if (
|
||||
hasattr(transcript, "audio_wav_filename")
|
||||
and transcript.audio_wav_filename
|
||||
):
|
||||
transcript.audio_wav_filename.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to delete local audio files: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
deletion_errors.append(error_msg)
|
||||
|
||||
if deletion_errors:
|
||||
logger.warning(
|
||||
"[Hatchet] cleanup_consent completed with errors",
|
||||
transcript_id=input.transcript_id,
|
||||
error_count=len(deletion_errors),
|
||||
)
|
||||
else:
|
||||
await transcripts_controller.update(transcript, {"audio_deleted": True})
|
||||
ctx.log("cleanup_consent: all audio deleted successfully")
|
||||
|
||||
return ConsentResult()
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[cleanup_consent],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=5,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
|
||||
async def post_zulip(input: FilePipelineInput, ctx: Context) -> ZulipResult:
|
||||
"""Post notification to Zulip."""
|
||||
ctx.log(f"post_zulip: transcript_id={input.transcript_id}")
|
||||
|
||||
if not settings.ZULIP_REALM:
|
||||
ctx.log("post_zulip skipped (Zulip not configured)")
|
||||
return ZulipResult(zulip_message_id=None, skipped=True)
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||
from reflector.zulip import post_transcript_notification # noqa: PLC0415
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if transcript:
|
||||
message_id = await post_transcript_notification(transcript)
|
||||
ctx.log(f"post_zulip complete: zulip_message_id={message_id}")
|
||||
else:
|
||||
message_id = None
|
||||
|
||||
return ZulipResult(zulip_message_id=message_id)
|
||||
|
||||
|
||||
@file_pipeline.task(
|
||||
parents=[cleanup_consent],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=5,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
|
||||
async def send_webhook(input: FilePipelineInput, ctx: Context) -> WebhookResult:
|
||||
"""Send completion webhook to external service."""
|
||||
ctx.log(f"send_webhook: transcript_id={input.transcript_id}")
|
||||
|
||||
if not input.room_id:
|
||||
ctx.log("send_webhook skipped (no room_id)")
|
||||
return WebhookResult(webhook_sent=False, skipped=True)
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.rooms import rooms_controller # noqa: PLC0415
|
||||
from reflector.utils.webhook import ( # noqa: PLC0415
|
||||
fetch_transcript_webhook_payload,
|
||||
send_webhook_request,
|
||||
)
|
||||
|
||||
room = await rooms_controller.get_by_id(input.room_id)
|
||||
if not room or not room.webhook_url:
|
||||
ctx.log("send_webhook skipped (no webhook_url configured)")
|
||||
return WebhookResult(webhook_sent=False, skipped=True)
|
||||
|
||||
payload = await fetch_transcript_webhook_payload(
|
||||
transcript_id=input.transcript_id,
|
||||
room_id=input.room_id,
|
||||
)
|
||||
|
||||
if isinstance(payload, str):
|
||||
ctx.log(f"send_webhook skipped (could not build payload): {payload}")
|
||||
return WebhookResult(webhook_sent=False, skipped=True)
|
||||
|
||||
import httpx # noqa: PLC0415
|
||||
|
||||
try:
|
||||
response = await send_webhook_request(
|
||||
url=room.webhook_url,
|
||||
payload=payload,
|
||||
event_type="transcript.completed",
|
||||
webhook_secret=room.webhook_secret,
|
||||
timeout=30.0,
|
||||
)
|
||||
ctx.log(f"send_webhook complete: status_code={response.status_code}")
|
||||
return WebhookResult(webhook_sent=True, response_code=response.status_code)
|
||||
except httpx.HTTPStatusError as e:
|
||||
ctx.log(f"send_webhook failed (HTTP {e.response.status_code}), continuing")
|
||||
return WebhookResult(
|
||||
webhook_sent=False, response_code=e.response.status_code
|
||||
)
|
||||
except (httpx.ConnectError, httpx.TimeoutException) as e:
|
||||
ctx.log(f"send_webhook failed ({e}), continuing")
|
||||
return WebhookResult(webhook_sent=False)
|
||||
except Exception as e:
|
||||
ctx.log(f"send_webhook unexpected error: {e}")
|
||||
return WebhookResult(webhook_sent=False)
|
||||
|
||||
|
||||
# --- On failure handler ---
|
||||
|
||||
|
||||
async def on_workflow_failure(input: FilePipelineInput, ctx: Context) -> None:
|
||||
"""Set transcript status to 'error' only if not already 'ended'."""
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if transcript and transcript.status == "ended":
|
||||
logger.info(
|
||||
"[Hatchet] FilePipeline on_workflow_failure: transcript already ended, skipping error status",
|
||||
transcript_id=input.transcript_id,
|
||||
)
|
||||
ctx.log(
|
||||
"on_workflow_failure: transcript already ended, skipping error status"
|
||||
)
|
||||
return
|
||||
await set_workflow_error_status(input.transcript_id)
|
||||
|
||||
|
||||
@file_pipeline.on_failure_task()
|
||||
async def _register_on_workflow_failure(input: FilePipelineInput, ctx: Context) -> None:
|
||||
await on_workflow_failure(input, ctx)
|
||||
389
server/reflector/hatchet/workflows/live_post_pipeline.py
Normal file
389
server/reflector/hatchet/workflows/live_post_pipeline.py
Normal file
@@ -0,0 +1,389 @@
|
||||
"""
|
||||
Hatchet workflow: LivePostProcessingPipeline
|
||||
|
||||
Post-processing pipeline for live WebRTC meetings.
|
||||
Triggered after a live meeting ends. Orchestrates:
|
||||
Left branch: waveform → convert_mp3 → upload_mp3 → remove_upload → diarize → cleanup_consent
|
||||
Right branch: generate_title (parallel with left branch)
|
||||
Fan-in: final_summaries → post_zulip → send_webhook
|
||||
|
||||
Note: This file uses deferred imports (inside functions/tasks) intentionally.
|
||||
Hatchet workers run in forked processes; fresh imports per task ensure DB connections
|
||||
are not shared across forks, avoiding connection pooling issues.
|
||||
"""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from hatchet_sdk import Context
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.constants import (
|
||||
TIMEOUT_HEAVY,
|
||||
TIMEOUT_MEDIUM,
|
||||
TIMEOUT_SHORT,
|
||||
TIMEOUT_TITLE,
|
||||
TaskName,
|
||||
)
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
fresh_db_connection,
|
||||
set_workflow_error_status,
|
||||
with_error_handling,
|
||||
)
|
||||
from reflector.hatchet.workflows.models import (
|
||||
ConsentResult,
|
||||
TitleResult,
|
||||
WaveformResult,
|
||||
WebhookResult,
|
||||
ZulipResult,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class LivePostPipelineInput(BaseModel):
|
||||
transcript_id: str
|
||||
room_id: str | None = None
|
||||
|
||||
|
||||
# --- Result models specific to live post pipeline ---
|
||||
|
||||
|
||||
class ConvertMp3Result(BaseModel):
|
||||
converted: bool
|
||||
|
||||
|
||||
class UploadMp3Result(BaseModel):
|
||||
uploaded: bool
|
||||
|
||||
|
||||
class RemoveUploadResult(BaseModel):
|
||||
removed: bool
|
||||
|
||||
|
||||
class DiarizeResult(BaseModel):
|
||||
diarized: bool
|
||||
|
||||
|
||||
class FinalSummariesResult(BaseModel):
|
||||
generated: bool
|
||||
|
||||
|
||||
hatchet = HatchetClientManager.get_client()
|
||||
|
||||
live_post_pipeline = hatchet.workflow(
|
||||
name="LivePostProcessingPipeline", input_validator=LivePostPipelineInput
|
||||
)
|
||||
|
||||
|
||||
@live_post_pipeline.task(
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.WAVEFORM)
|
||||
async def waveform(input: LivePostPipelineInput, ctx: Context) -> WaveformResult:
|
||||
"""Generate waveform visualization from recorded audio."""
|
||||
ctx.log(f"waveform: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
|
||||
PipelineMainWaveform,
|
||||
)
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if not transcript:
|
||||
raise ValueError(f"Transcript {input.transcript_id} not found")
|
||||
|
||||
runner = PipelineMainWaveform(transcript_id=transcript.id)
|
||||
await runner.run()
|
||||
|
||||
ctx.log("waveform complete")
|
||||
return WaveformResult(waveform_generated=True)
|
||||
|
||||
|
||||
@live_post_pipeline.task(
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_TITLE),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.GENERATE_TITLE)
|
||||
async def generate_title(input: LivePostPipelineInput, ctx: Context) -> TitleResult:
|
||||
"""Generate meeting title from topics (runs in parallel with audio chain)."""
|
||||
ctx.log(f"generate_title: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
|
||||
PipelineMainTitle,
|
||||
)
|
||||
|
||||
runner = PipelineMainTitle(transcript_id=input.transcript_id)
|
||||
await runner.run()
|
||||
|
||||
ctx.log("generate_title complete")
|
||||
return TitleResult(title=None)
|
||||
|
||||
|
||||
@live_post_pipeline.task(
|
||||
parents=[waveform],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.CONVERT_MP3)
|
||||
async def convert_mp3(input: LivePostPipelineInput, ctx: Context) -> ConvertMp3Result:
|
||||
"""Convert WAV recording to MP3."""
|
||||
ctx.log(f"convert_mp3: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
|
||||
pipeline_convert_to_mp3,
|
||||
)
|
||||
|
||||
await pipeline_convert_to_mp3(transcript_id=input.transcript_id)
|
||||
|
||||
ctx.log("convert_mp3 complete")
|
||||
return ConvertMp3Result(converted=True)
|
||||
|
||||
|
||||
@live_post_pipeline.task(
|
||||
parents=[convert_mp3],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.UPLOAD_MP3)
|
||||
async def upload_mp3(input: LivePostPipelineInput, ctx: Context) -> UploadMp3Result:
|
||||
"""Upload MP3 to external storage."""
|
||||
ctx.log(f"upload_mp3: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
|
||||
pipeline_upload_mp3,
|
||||
)
|
||||
|
||||
await pipeline_upload_mp3(transcript_id=input.transcript_id)
|
||||
|
||||
ctx.log("upload_mp3 complete")
|
||||
return UploadMp3Result(uploaded=True)
|
||||
|
||||
|
||||
@live_post_pipeline.task(
|
||||
parents=[upload_mp3],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=5,
|
||||
)
|
||||
@with_error_handling(TaskName.REMOVE_UPLOAD)
|
||||
async def remove_upload(
|
||||
input: LivePostPipelineInput, ctx: Context
|
||||
) -> RemoveUploadResult:
|
||||
"""Remove the original upload file."""
|
||||
ctx.log(f"remove_upload: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
|
||||
pipeline_remove_upload,
|
||||
)
|
||||
|
||||
await pipeline_remove_upload(transcript_id=input.transcript_id)
|
||||
|
||||
ctx.log("remove_upload complete")
|
||||
return RemoveUploadResult(removed=True)
|
||||
|
||||
|
||||
@live_post_pipeline.task(
|
||||
parents=[remove_upload],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
@with_error_handling(TaskName.DIARIZE)
|
||||
async def diarize(input: LivePostPipelineInput, ctx: Context) -> DiarizeResult:
|
||||
"""Run diarization on the recorded audio."""
|
||||
ctx.log(f"diarize: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
|
||||
pipeline_diarization,
|
||||
)
|
||||
|
||||
await pipeline_diarization(transcript_id=input.transcript_id)
|
||||
|
||||
ctx.log("diarize complete")
|
||||
return DiarizeResult(diarized=True)
|
||||
|
||||
|
||||
@live_post_pipeline.task(
|
||||
parents=[diarize],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=10,
|
||||
)
|
||||
@with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)
|
||||
async def cleanup_consent(input: LivePostPipelineInput, ctx: Context) -> ConsentResult:
|
||||
"""Check consent and delete audio files if any participant denied."""
|
||||
ctx.log(f"cleanup_consent: transcript_id={input.transcript_id}")
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
|
||||
cleanup_consent as _cleanup_consent,
|
||||
)
|
||||
|
||||
await _cleanup_consent(transcript_id=input.transcript_id)
|
||||
|
||||
ctx.log("cleanup_consent complete")
|
||||
return ConsentResult()
|
||||
|
||||
|
||||
@live_post_pipeline.task(
|
||||
parents=[cleanup_consent, generate_title],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
@with_error_handling(TaskName.FINAL_SUMMARIES)
|
||||
async def final_summaries(
|
||||
input: LivePostPipelineInput, ctx: Context
|
||||
) -> FinalSummariesResult:
|
||||
"""Generate final summaries (fan-in after audio chain + title)."""
|
||||
ctx.log(f"final_summaries: starting for transcript_id={input.transcript_id}")
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.pipelines.main_live_pipeline import ( # noqa: PLC0415
|
||||
pipeline_summaries,
|
||||
)
|
||||
|
||||
await pipeline_summaries(transcript_id=input.transcript_id)
|
||||
|
||||
ctx.log("final_summaries complete")
|
||||
return FinalSummariesResult(generated=True)
|
||||
|
||||
|
||||
@live_post_pipeline.task(
|
||||
parents=[final_summaries],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_SHORT),
|
||||
retries=5,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.POST_ZULIP, set_error_status=False)
|
||||
async def post_zulip(input: LivePostPipelineInput, ctx: Context) -> ZulipResult:
|
||||
"""Post notification to Zulip."""
|
||||
ctx.log(f"post_zulip: transcript_id={input.transcript_id}")
|
||||
|
||||
if not settings.ZULIP_REALM:
|
||||
ctx.log("post_zulip skipped (Zulip not configured)")
|
||||
return ZulipResult(zulip_message_id=None, skipped=True)
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||
from reflector.zulip import post_transcript_notification # noqa: PLC0415
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if transcript:
|
||||
message_id = await post_transcript_notification(transcript)
|
||||
ctx.log(f"post_zulip complete: zulip_message_id={message_id}")
|
||||
else:
|
||||
message_id = None
|
||||
|
||||
return ZulipResult(zulip_message_id=message_id)
|
||||
|
||||
|
||||
@live_post_pipeline.task(
|
||||
parents=[final_summaries],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=5,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=15,
|
||||
)
|
||||
@with_error_handling(TaskName.SEND_WEBHOOK, set_error_status=False)
|
||||
async def send_webhook(input: LivePostPipelineInput, ctx: Context) -> WebhookResult:
|
||||
"""Send completion webhook to external service."""
|
||||
ctx.log(f"send_webhook: transcript_id={input.transcript_id}")
|
||||
|
||||
if not input.room_id:
|
||||
ctx.log("send_webhook skipped (no room_id)")
|
||||
return WebhookResult(webhook_sent=False, skipped=True)
|
||||
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.rooms import rooms_controller # noqa: PLC0415
|
||||
from reflector.utils.webhook import ( # noqa: PLC0415
|
||||
fetch_transcript_webhook_payload,
|
||||
send_webhook_request,
|
||||
)
|
||||
|
||||
room = await rooms_controller.get_by_id(input.room_id)
|
||||
if not room or not room.webhook_url:
|
||||
ctx.log("send_webhook skipped (no webhook_url configured)")
|
||||
return WebhookResult(webhook_sent=False, skipped=True)
|
||||
|
||||
payload = await fetch_transcript_webhook_payload(
|
||||
transcript_id=input.transcript_id,
|
||||
room_id=input.room_id,
|
||||
)
|
||||
|
||||
if isinstance(payload, str):
|
||||
ctx.log(f"send_webhook skipped (could not build payload): {payload}")
|
||||
return WebhookResult(webhook_sent=False, skipped=True)
|
||||
|
||||
import httpx # noqa: PLC0415
|
||||
|
||||
try:
|
||||
response = await send_webhook_request(
|
||||
url=room.webhook_url,
|
||||
payload=payload,
|
||||
event_type="transcript.completed",
|
||||
webhook_secret=room.webhook_secret,
|
||||
timeout=30.0,
|
||||
)
|
||||
ctx.log(f"send_webhook complete: status_code={response.status_code}")
|
||||
return WebhookResult(webhook_sent=True, response_code=response.status_code)
|
||||
except httpx.HTTPStatusError as e:
|
||||
ctx.log(f"send_webhook failed (HTTP {e.response.status_code}), continuing")
|
||||
return WebhookResult(
|
||||
webhook_sent=False, response_code=e.response.status_code
|
||||
)
|
||||
except (httpx.ConnectError, httpx.TimeoutException) as e:
|
||||
ctx.log(f"send_webhook failed ({e}), continuing")
|
||||
return WebhookResult(webhook_sent=False)
|
||||
except Exception as e:
|
||||
ctx.log(f"send_webhook unexpected error: {e}")
|
||||
return WebhookResult(webhook_sent=False)
|
||||
|
||||
|
||||
# --- On failure handler ---
|
||||
|
||||
|
||||
async def on_workflow_failure(input: LivePostPipelineInput, ctx: Context) -> None:
|
||||
"""Set transcript status to 'error' only if not already 'ended'."""
|
||||
async with fresh_db_connection():
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(input.transcript_id)
|
||||
if transcript and transcript.status == "ended":
|
||||
logger.info(
|
||||
"[Hatchet] LivePostProcessingPipeline on_workflow_failure: transcript already ended",
|
||||
transcript_id=input.transcript_id,
|
||||
)
|
||||
ctx.log(
|
||||
"on_workflow_failure: transcript already ended, skipping error status"
|
||||
)
|
||||
return
|
||||
await set_workflow_error_status(input.transcript_id)
|
||||
|
||||
|
||||
@live_post_pipeline.on_failure_task()
|
||||
async def _register_on_workflow_failure(
|
||||
input: LivePostPipelineInput, ctx: Context
|
||||
) -> None:
|
||||
await on_workflow_failure(input, ctx)
|
||||
@@ -24,6 +24,7 @@ class PaddingInput(BaseModel):
|
||||
s3_key: str
|
||||
bucket_name: str
|
||||
transcript_id: str
|
||||
source_platform: str = "daily"
|
||||
|
||||
|
||||
hatchet = HatchetClientManager.get_client()
|
||||
@@ -33,7 +34,12 @@ padding_workflow = hatchet.workflow(
|
||||
)
|
||||
|
||||
|
||||
@padding_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3)
|
||||
@padding_workflow.task(
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult:
|
||||
"""Pad audio track with silence based on WebM container start_time."""
|
||||
ctx.log(f"pad_track: track {input.track_index}, s3_key={input.s3_key}")
|
||||
@@ -45,20 +51,14 @@ async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult:
|
||||
)
|
||||
|
||||
try:
|
||||
# Create fresh storage instance to avoid aioboto3 fork issues
|
||||
from reflector.settings import settings # noqa: PLC0415
|
||||
from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415
|
||||
|
||||
# TODO: replace direct AwsStorage construction with get_transcripts_storage() factory
|
||||
storage = AwsStorage(
|
||||
aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
|
||||
aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
|
||||
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||
aws_endpoint_url=settings.TRANSCRIPT_STORAGE_AWS_ENDPOINT_URL,
|
||||
from reflector.storage import ( # noqa: PLC0415
|
||||
get_source_storage,
|
||||
get_transcripts_storage,
|
||||
)
|
||||
|
||||
source_url = await storage.get_file_url(
|
||||
# Source reads: use platform-specific credentials
|
||||
source_storage = get_source_storage(input.source_platform)
|
||||
source_url = await source_storage.get_file_url(
|
||||
input.s3_key,
|
||||
operation="get_object",
|
||||
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
|
||||
@@ -96,52 +96,28 @@ async def pad_track(input: PaddingInput, ctx: Context) -> PadTrackResult:
|
||||
|
||||
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{input.track_index}.webm"
|
||||
|
||||
# Presign PUT URL for output (Modal will upload directly)
|
||||
output_url = await storage.get_file_url(
|
||||
# Output writes: use transcript storage (our own bucket)
|
||||
output_storage = get_transcripts_storage()
|
||||
output_url = await output_storage.get_file_url(
|
||||
storage_path,
|
||||
operation="put_object",
|
||||
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
|
||||
)
|
||||
|
||||
import httpx # noqa: PLC0415
|
||||
|
||||
from reflector.processors.audio_padding_modal import ( # noqa: PLC0415
|
||||
AudioPaddingModalProcessor,
|
||||
from reflector.processors.audio_padding_auto import ( # noqa: PLC0415
|
||||
AudioPaddingAutoProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
processor = AudioPaddingModalProcessor()
|
||||
result = await processor.pad_track(
|
||||
track_url=source_url,
|
||||
output_url=output_url,
|
||||
start_time_seconds=start_time_seconds,
|
||||
track_index=input.track_index,
|
||||
)
|
||||
file_size = result.size
|
||||
processor = AudioPaddingAutoProcessor()
|
||||
result = await processor.pad_track(
|
||||
track_url=source_url,
|
||||
output_url=output_url,
|
||||
start_time_seconds=start_time_seconds,
|
||||
track_index=input.track_index,
|
||||
)
|
||||
file_size = result.size
|
||||
|
||||
ctx.log(f"pad_track: Modal returned size={file_size}")
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_detail = e.response.text if hasattr(e.response, "text") else str(e)
|
||||
logger.error(
|
||||
"[Hatchet] Modal padding HTTP error",
|
||||
transcript_id=input.transcript_id,
|
||||
track_index=input.track_index,
|
||||
status_code=e.response.status_code if hasattr(e, "response") else None,
|
||||
error=error_detail,
|
||||
exc_info=True,
|
||||
)
|
||||
raise Exception(
|
||||
f"Modal padding failed: HTTP {e.response.status_code}"
|
||||
) from e
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(
|
||||
"[Hatchet] Modal padding timeout",
|
||||
transcript_id=input.transcript_id,
|
||||
track_index=input.track_index,
|
||||
error=str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
raise Exception("Modal padding timeout") from e
|
||||
ctx.log(f"pad_track: padding returned size={file_size}")
|
||||
|
||||
logger.info(
|
||||
"[Hatchet] pad_track complete",
|
||||
|
||||
@@ -13,7 +13,7 @@ from hatchet_sdk.rate_limit import RateLimit
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_MEDIUM
|
||||
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, TIMEOUT_HEAVY
|
||||
from reflector.hatchet.workflows.models import SubjectSummaryResult
|
||||
from reflector.logger import logger
|
||||
from reflector.processors.summary.prompts import (
|
||||
@@ -41,8 +41,10 @@ subject_workflow = hatchet.workflow(
|
||||
|
||||
|
||||
@subject_workflow.task(
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=5,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=60,
|
||||
rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=2)],
|
||||
)
|
||||
async def generate_detailed_summary(
|
||||
|
||||
@@ -50,7 +50,9 @@ topic_chunk_workflow = hatchet.workflow(
|
||||
|
||||
@topic_chunk_workflow.task(
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_MEDIUM),
|
||||
retries=3,
|
||||
retries=5,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=60,
|
||||
rate_limits=[RateLimit(static_key=LLM_RATE_LIMIT_KEY, units=1)],
|
||||
)
|
||||
async def detect_chunk_topic(input: TopicChunkInput, ctx: Context) -> TopicChunkResult:
|
||||
|
||||
@@ -36,6 +36,7 @@ class TrackInput(BaseModel):
|
||||
bucket_name: str
|
||||
transcript_id: str
|
||||
language: str = "en"
|
||||
source_platform: str = "daily"
|
||||
|
||||
|
||||
hatchet = HatchetClientManager.get_client()
|
||||
@@ -43,7 +44,12 @@ hatchet = HatchetClientManager.get_client()
|
||||
track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
|
||||
|
||||
|
||||
@track_workflow.task(execution_timeout=timedelta(seconds=TIMEOUT_AUDIO), retries=3)
|
||||
@track_workflow.task(
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
|
||||
"""Pad single audio track with silence for alignment.
|
||||
|
||||
@@ -59,20 +65,14 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
|
||||
)
|
||||
|
||||
try:
|
||||
# Create fresh storage instance to avoid aioboto3 fork issues
|
||||
# TODO: replace direct AwsStorage construction with get_transcripts_storage() factory
|
||||
from reflector.settings import settings # noqa: PLC0415
|
||||
from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415
|
||||
|
||||
storage = AwsStorage(
|
||||
aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
|
||||
aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
|
||||
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||
aws_endpoint_url=settings.TRANSCRIPT_STORAGE_AWS_ENDPOINT_URL,
|
||||
from reflector.storage import ( # noqa: PLC0415
|
||||
get_source_storage,
|
||||
get_transcripts_storage,
|
||||
)
|
||||
|
||||
source_url = await storage.get_file_url(
|
||||
# Source reads: use platform-specific credentials
|
||||
source_storage = get_source_storage(input.source_platform)
|
||||
source_url = await source_storage.get_file_url(
|
||||
input.s3_key,
|
||||
operation="get_object",
|
||||
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
|
||||
@@ -99,18 +99,19 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
|
||||
|
||||
storage_path = f"file_pipeline_hatchet/{input.transcript_id}/tracks/padded_{input.track_index}.webm"
|
||||
|
||||
# Presign PUT URL for output (Modal uploads directly)
|
||||
output_url = await storage.get_file_url(
|
||||
# Output writes: use transcript storage (our own bucket)
|
||||
output_storage = get_transcripts_storage()
|
||||
output_url = await output_storage.get_file_url(
|
||||
storage_path,
|
||||
operation="put_object",
|
||||
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
|
||||
)
|
||||
|
||||
from reflector.processors.audio_padding_modal import ( # noqa: PLC0415
|
||||
AudioPaddingModalProcessor,
|
||||
from reflector.processors.audio_padding_auto import ( # noqa: PLC0415
|
||||
AudioPaddingAutoProcessor,
|
||||
)
|
||||
|
||||
processor = AudioPaddingModalProcessor()
|
||||
processor = AudioPaddingAutoProcessor()
|
||||
result = await processor.pad_track(
|
||||
track_url=source_url,
|
||||
output_url=output_url,
|
||||
@@ -141,7 +142,11 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
|
||||
|
||||
|
||||
@track_workflow.task(
|
||||
parents=[pad_track], execution_timeout=timedelta(seconds=TIMEOUT_HEAVY), retries=3
|
||||
parents=[pad_track],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_HEAVY),
|
||||
retries=3,
|
||||
backoff_factor=2.0,
|
||||
backoff_max_seconds=30,
|
||||
)
|
||||
async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackResult:
|
||||
"""Transcribe audio track using GPU (Modal.com) or local Whisper."""
|
||||
@@ -161,18 +166,18 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe
|
||||
raise ValueError("Missing padded_key from pad_track")
|
||||
|
||||
# Presign URL on demand (avoids stale URLs on workflow replay)
|
||||
# TODO: replace direct AwsStorage construction with get_transcripts_storage() factory
|
||||
from reflector.settings import settings # noqa: PLC0415
|
||||
from reflector.storage.storage_aws import AwsStorage # noqa: PLC0415
|
||||
|
||||
storage = AwsStorage(
|
||||
aws_bucket_name=settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME,
|
||||
aws_region=settings.TRANSCRIPT_STORAGE_AWS_REGION,
|
||||
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||
aws_endpoint_url=settings.TRANSCRIPT_STORAGE_AWS_ENDPOINT_URL,
|
||||
from reflector.storage import ( # noqa: PLC0415
|
||||
get_source_storage,
|
||||
get_transcripts_storage,
|
||||
)
|
||||
|
||||
# If bucket_name is set, file is still in the platform's source bucket (no padding applied).
|
||||
# If bucket_name is None, padded file was written to our transcript storage.
|
||||
if bucket_name:
|
||||
storage = get_source_storage(input.source_platform)
|
||||
else:
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
audio_url = await storage.get_file_url(
|
||||
padded_key,
|
||||
operation="get_object",
|
||||
|
||||
@@ -65,10 +65,25 @@ class LLM:
|
||||
async def get_response(
|
||||
self, prompt: str, texts: list[str], tone_name: str | None = None
|
||||
) -> str:
|
||||
"""Get a text response using TreeSummarize for non-function-calling models"""
|
||||
summarizer = TreeSummarize(verbose=False)
|
||||
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
|
||||
return str(response).strip()
|
||||
"""Get a text response using TreeSummarize for non-function-calling models.
|
||||
|
||||
Uses the same retry() wrapper as get_structured_response for transient
|
||||
network errors (connection, timeout, OSError) with exponential backoff.
|
||||
"""
|
||||
|
||||
async def _call():
|
||||
summarizer = TreeSummarize(verbose=False)
|
||||
response = await summarizer.aget_response(
|
||||
prompt, texts, tone_name=tone_name
|
||||
)
|
||||
return str(response).strip()
|
||||
|
||||
return await retry(_call)(
|
||||
retry_attempts=3,
|
||||
retry_backoff_interval=1.0,
|
||||
retry_backoff_max=30.0,
|
||||
retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError),
|
||||
)
|
||||
|
||||
async def get_structured_response(
|
||||
self,
|
||||
|
||||
@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
|
||||
from typing import Generic
|
||||
|
||||
import av
|
||||
from celery import chord, current_task, group, shared_task
|
||||
from celery import current_task, shared_task
|
||||
from pydantic import BaseModel
|
||||
from structlog import BoundLogger as Logger
|
||||
|
||||
@@ -397,7 +397,9 @@ class PipelineMainLive(PipelineMainBase):
|
||||
# when the pipeline ends, connect to the post pipeline
|
||||
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
|
||||
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
|
||||
pipeline_post(transcript_id=self.transcript_id)
|
||||
transcript = await transcripts_controller.get_by_id(self.transcript_id)
|
||||
room_id = transcript.room_id if transcript else None
|
||||
await pipeline_post(transcript_id=self.transcript_id, room_id=room_id)
|
||||
|
||||
|
||||
class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
||||
@@ -792,29 +794,20 @@ async def task_pipeline_post_to_zulip(*, transcript_id: str):
|
||||
await pipeline_post_to_zulip(transcript_id=transcript_id)
|
||||
|
||||
|
||||
def pipeline_post(*, transcript_id: str):
|
||||
async def pipeline_post(*, transcript_id: str, room_id: str | None = None):
|
||||
"""
|
||||
Run the post pipeline
|
||||
Run the post pipeline via Hatchet.
|
||||
"""
|
||||
chain_mp3_and_diarize = (
|
||||
task_pipeline_waveform.si(transcript_id=transcript_id)
|
||||
| task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
|
||||
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
|
||||
| task_pipeline_remove_upload.si(transcript_id=transcript_id)
|
||||
| task_pipeline_diarization.si(transcript_id=transcript_id)
|
||||
| task_cleanup_consent.si(transcript_id=transcript_id)
|
||||
)
|
||||
chain_title_preview = task_pipeline_title.si(transcript_id=transcript_id)
|
||||
chain_final_summaries = task_pipeline_final_summaries.si(
|
||||
transcript_id=transcript_id
|
||||
)
|
||||
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
|
||||
|
||||
chain = chord(
|
||||
group(chain_mp3_and_diarize, chain_title_preview),
|
||||
chain_final_summaries,
|
||||
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
|
||||
|
||||
return chain.delay()
|
||||
await HatchetClientManager.start_workflow(
|
||||
"LivePostProcessingPipeline",
|
||||
{
|
||||
"transcript_id": str(transcript_id),
|
||||
"room_id": str(room_id) if room_id else None,
|
||||
},
|
||||
additional_metadata={"transcript_id": str(transcript_id)},
|
||||
)
|
||||
|
||||
|
||||
@get_transcript
|
||||
|
||||
@@ -4,6 +4,8 @@ from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
|
||||
from .audio_downscale import AudioDownscaleProcessor # noqa: F401
|
||||
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||
from .audio_padding import AudioPaddingProcessor # noqa: F401
|
||||
from .audio_padding_auto import AudioPaddingAutoProcessor # noqa: F401
|
||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
|
||||
from .base import ( # noqa: F401
|
||||
|
||||
86
server/reflector/processors/_audio_download.py
Normal file
86
server/reflector/processors/_audio_download.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Shared audio download utility for local processors.
|
||||
|
||||
Downloads audio from a URL to a temporary file for in-process ML inference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
S3_TIMEOUT = 60
|
||||
|
||||
|
||||
async def download_audio_to_temp(url: str) -> Path:
|
||||
"""Download audio from URL to a temporary file.
|
||||
|
||||
The caller is responsible for deleting the temp file after use.
|
||||
|
||||
Args:
|
||||
url: Presigned URL or public URL to download audio from.
|
||||
|
||||
Returns:
|
||||
Path to the downloaded temporary file.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, _download_blocking, url)
|
||||
|
||||
|
||||
def _download_blocking(url: str) -> Path:
|
||||
"""Blocking download implementation."""
|
||||
log = logger.bind(url=url[:80])
|
||||
log.info("Downloading audio to temp file")
|
||||
|
||||
response = requests.get(url, stream=True, timeout=S3_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
|
||||
# Determine extension from content-type or URL
|
||||
ext = _detect_extension(url, response.headers.get("content-type", ""))
|
||||
|
||||
fd, tmp_path = tempfile.mkstemp(suffix=ext)
|
||||
try:
|
||||
total_bytes = 0
|
||||
with os.fdopen(fd, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
total_bytes += len(chunk)
|
||||
log.info("Audio downloaded", bytes=total_bytes, path=tmp_path)
|
||||
return Path(tmp_path)
|
||||
except Exception:
|
||||
# Clean up on failure
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def _detect_extension(url: str, content_type: str) -> str:
|
||||
"""Detect audio file extension from URL or content-type."""
|
||||
# Try URL path first
|
||||
path = url.split("?")[0] # Strip query params
|
||||
for ext in (".wav", ".mp3", ".mp4", ".m4a", ".webm", ".ogg", ".flac"):
|
||||
if path.lower().endswith(ext):
|
||||
return ext
|
||||
|
||||
# Try content-type
|
||||
ct_map = {
|
||||
"audio/wav": ".wav",
|
||||
"audio/x-wav": ".wav",
|
||||
"audio/mpeg": ".mp3",
|
||||
"audio/mp4": ".m4a",
|
||||
"audio/webm": ".webm",
|
||||
"audio/ogg": ".ogg",
|
||||
"audio/flac": ".flac",
|
||||
}
|
||||
for ct, ext in ct_map.items():
|
||||
if ct in content_type.lower():
|
||||
return ext
|
||||
|
||||
return ".audio"
|
||||
76
server/reflector/processors/_marian_translator_service.py
Normal file
76
server/reflector/processors/_marian_translator_service.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
MarianMT translation service.
|
||||
|
||||
Singleton service that loads HuggingFace MarianMT translation models
|
||||
and reuses them across all MarianMT translator processor instances.
|
||||
|
||||
Ported from gpu/self_hosted/app/services/translator.py for in-process use.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from transformers import MarianMTModel, MarianTokenizer, pipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarianTranslatorService:
|
||||
"""MarianMT text translation service for in-process use."""
|
||||
|
||||
def __init__(self):
|
||||
self._pipeline = None
|
||||
self._current_pair = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def load(self, source_language: str = "en", target_language: str = "fr"):
|
||||
"""Load the translation model for a specific language pair."""
|
||||
model_name = self._resolve_model_name(source_language, target_language)
|
||||
logger.info(
|
||||
"Loading MarianMT model: %s (%s -> %s)",
|
||||
model_name,
|
||||
source_language,
|
||||
target_language,
|
||||
)
|
||||
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
model = MarianMTModel.from_pretrained(model_name)
|
||||
self._pipeline = pipeline("translation", model=model, tokenizer=tokenizer)
|
||||
self._current_pair = (source_language.lower(), target_language.lower())
|
||||
|
||||
def _resolve_model_name(self, src: str, tgt: str) -> str:
|
||||
"""Resolve language pair to MarianMT model name."""
|
||||
pair = (src.lower(), tgt.lower())
|
||||
mapping = {
|
||||
("en", "fr"): "Helsinki-NLP/opus-mt-en-fr",
|
||||
("fr", "en"): "Helsinki-NLP/opus-mt-fr-en",
|
||||
("en", "es"): "Helsinki-NLP/opus-mt-en-es",
|
||||
("es", "en"): "Helsinki-NLP/opus-mt-es-en",
|
||||
("en", "de"): "Helsinki-NLP/opus-mt-en-de",
|
||||
("de", "en"): "Helsinki-NLP/opus-mt-de-en",
|
||||
}
|
||||
return mapping.get(pair, "Helsinki-NLP/opus-mt-en-fr")
|
||||
|
||||
def translate(self, text: str, source_language: str, target_language: str) -> dict:
|
||||
"""Translate text between languages.
|
||||
|
||||
Args:
|
||||
text: Text to translate.
|
||||
source_language: Source language code (e.g. "en").
|
||||
target_language: Target language code (e.g. "fr").
|
||||
|
||||
Returns:
|
||||
dict with "text" key containing {source_language: original, target_language: translated}.
|
||||
"""
|
||||
pair = (source_language.lower(), target_language.lower())
|
||||
if self._pipeline is None or self._current_pair != pair:
|
||||
self.load(source_language, target_language)
|
||||
with self._lock:
|
||||
results = self._pipeline(
|
||||
text, src_lang=source_language, tgt_lang=target_language
|
||||
)
|
||||
translated = results[0]["translation_text"] if results else ""
|
||||
return {"text": {source_language: text, target_language: translated}}
|
||||
|
||||
|
||||
# Module-level singleton — shared across all MarianMT translator processors
|
||||
translator_service = MarianTranslatorService()
|
||||
133
server/reflector/processors/_pyannote_diarization_service.py
Normal file
133
server/reflector/processors/_pyannote_diarization_service.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Pyannote diarization service using pyannote.audio.
|
||||
|
||||
Singleton service that loads the pyannote speaker diarization model once
|
||||
and reuses it across all pyannote diarization processor instances.
|
||||
|
||||
Ported from gpu/self_hosted/app/services/diarizer.py for in-process use.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import tarfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from urllib.request import urlopen
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import yaml
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
S3_BUNDLE_URL = "https://reflector-public.s3.us-east-1.amazonaws.com/pyannote-speaker-diarization-3.1.tar.gz"
|
||||
BUNDLE_CACHE_DIR = Path.home() / ".cache" / "pyannote-bundle"
|
||||
|
||||
|
||||
def _ensure_model(cache_dir: Path) -> str:
|
||||
"""Download and extract S3 model bundle if not cached."""
|
||||
model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
|
||||
config_path = model_dir / "config.yaml"
|
||||
|
||||
if config_path.exists():
|
||||
logger.info("Using cached model bundle at %s", model_dir)
|
||||
return str(model_dir)
|
||||
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
tarball_path = cache_dir / "model.tar.gz"
|
||||
|
||||
logger.info("Downloading model bundle from %s", S3_BUNDLE_URL)
|
||||
with urlopen(S3_BUNDLE_URL) as response, open(tarball_path, "wb") as f:
|
||||
while chunk := response.read(8192):
|
||||
f.write(chunk)
|
||||
|
||||
logger.info("Extracting model bundle")
|
||||
with tarfile.open(tarball_path, "r:gz") as tar:
|
||||
tar.extractall(path=cache_dir, filter="data")
|
||||
tarball_path.unlink()
|
||||
|
||||
_patch_config(model_dir, cache_dir)
|
||||
return str(model_dir)
|
||||
|
||||
|
||||
def _patch_config(model_dir: Path, cache_dir: Path) -> None:
|
||||
"""Rewrite config.yaml to reference local pytorch_model.bin paths."""
|
||||
config_path = model_dir / "config.yaml"
|
||||
with open(config_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
config["pipeline"]["params"]["segmentation"] = str(
|
||||
cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin"
|
||||
)
|
||||
config["pipeline"]["params"]["embedding"] = str(
|
||||
cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin"
|
||||
)
|
||||
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(config, f)
|
||||
|
||||
logger.info("Patched config.yaml with local model paths")
|
||||
|
||||
|
||||
class PyannoteDiarizationService:
|
||||
"""Pyannote speaker diarization service for in-process use."""
|
||||
|
||||
def __init__(self):
|
||||
self._pipeline = None
|
||||
self._device = "cpu"
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def load(self):
|
||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
hf_token = settings.HF_TOKEN
|
||||
|
||||
if hf_token:
|
||||
logger.info("Loading pyannote model from HuggingFace (HF_TOKEN set)")
|
||||
self._pipeline = Pipeline.from_pretrained(
|
||||
"pyannote/speaker-diarization-3.1",
|
||||
use_auth_token=hf_token,
|
||||
)
|
||||
else:
|
||||
logger.info("HF_TOKEN not set — loading model from S3 bundle")
|
||||
model_path = _ensure_model(BUNDLE_CACHE_DIR)
|
||||
config_path = Path(model_path) / "config.yaml"
|
||||
self._pipeline = Pipeline.from_pretrained(str(config_path))
|
||||
|
||||
self._pipeline.to(torch.device(self._device))
|
||||
|
||||
def diarize_file(self, file_path: str, timestamp: float = 0.0) -> dict:
|
||||
"""Run speaker diarization on an audio file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the audio file.
|
||||
timestamp: Offset to add to all segment timestamps.
|
||||
|
||||
Returns:
|
||||
dict with "diarization" key containing list of
|
||||
{"start": float, "end": float, "speaker": int} segments.
|
||||
"""
|
||||
if self._pipeline is None:
|
||||
self.load()
|
||||
waveform, sample_rate = torchaudio.load(file_path)
|
||||
with self._lock:
|
||||
diarization = self._pipeline(
|
||||
{"waveform": waveform, "sample_rate": sample_rate}
|
||||
)
|
||||
segments = []
|
||||
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||
segments.append(
|
||||
{
|
||||
"start": round(timestamp + diarization_segment.start, 3),
|
||||
"end": round(timestamp + diarization_segment.end, 3),
|
||||
"speaker": int(speaker[-2:])
|
||||
if speaker and speaker[-2:].isdigit()
|
||||
else 0,
|
||||
}
|
||||
)
|
||||
return {"diarization": segments}
|
||||
|
||||
|
||||
# Module-level singleton — shared across all pyannote diarization processors
|
||||
diarization_service = PyannoteDiarizationService()
|
||||
37
server/reflector/processors/audio_diarization_pyannote.py
Normal file
37
server/reflector/processors/audio_diarization_pyannote.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Pyannote audio diarization processor using pyannote.audio in-process.
|
||||
|
||||
Downloads audio from URL, runs pyannote diarization locally,
|
||||
and returns speaker segments. No HTTP backend needed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
from reflector.processors._pyannote_diarization_service import diarization_service
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput
|
||||
|
||||
|
||||
class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
|
||||
INPUT_TYPE = AudioDiarizationInput
|
||||
|
||||
async def _diarize(self, data: AudioDiarizationInput):
|
||||
"""Run pyannote diarization on audio from URL."""
|
||||
tmp_path = await download_audio_to_temp(data.audio_url)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None, diarization_service.diarize_file, str(tmp_path)
|
||||
)
|
||||
return result["diarization"]
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)
|
||||
23
server/reflector/processors/audio_padding.py
Normal file
23
server/reflector/processors/audio_padding.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Base class for audio padding processors.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PaddingResponse(BaseModel):
|
||||
size: int
|
||||
cancelled: bool = False
|
||||
|
||||
|
||||
class AudioPaddingProcessor:
|
||||
"""Base class for audio padding processors."""
|
||||
|
||||
async def pad_track(
|
||||
self,
|
||||
track_url: str,
|
||||
output_url: str,
|
||||
start_time_seconds: float,
|
||||
track_index: int,
|
||||
) -> PaddingResponse:
|
||||
raise NotImplementedError
|
||||
32
server/reflector/processors/audio_padding_auto.py
Normal file
32
server/reflector/processors/audio_padding_auto.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_padding import AudioPaddingProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioPaddingAutoProcessor(AudioPaddingProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.PADDING_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.audio_padding_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `PADDING_XXX_YYY`, push to constructor as `xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "PADDING_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
@@ -6,18 +6,14 @@ import asyncio
|
||||
import os
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.hatchet.constants import TIMEOUT_AUDIO
|
||||
from reflector.hatchet.constants import TIMEOUT_AUDIO_HTTP
|
||||
from reflector.logger import logger
|
||||
from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse
|
||||
from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor
|
||||
|
||||
|
||||
class PaddingResponse(BaseModel):
|
||||
size: int
|
||||
cancelled: bool = False
|
||||
|
||||
|
||||
class AudioPaddingModalProcessor:
|
||||
class AudioPaddingModalProcessor(AudioPaddingProcessor):
|
||||
"""Audio padding processor using Modal.com CPU backend via HTTP."""
|
||||
|
||||
def __init__(
|
||||
@@ -64,7 +60,7 @@ class AudioPaddingModalProcessor:
|
||||
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=TIMEOUT_AUDIO) as client:
|
||||
async with httpx.AsyncClient(timeout=TIMEOUT_AUDIO_HTTP) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=headers,
|
||||
@@ -111,3 +107,6 @@ class AudioPaddingModalProcessor:
|
||||
except Exception as e:
|
||||
log.error("Modal padding unexpected error", error=str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
AudioPaddingAutoProcessor.register("modal", AudioPaddingModalProcessor)
|
||||
|
||||
133
server/reflector/processors/audio_padding_pyav.py
Normal file
133
server/reflector/processors/audio_padding_pyav.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
PyAV audio padding processor.
|
||||
|
||||
Pads audio tracks with silence directly in-process (no HTTP).
|
||||
Reuses the shared PyAV utilities from reflector.utils.audio_padding.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import av
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse
|
||||
from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor
|
||||
from reflector.utils.audio_padding import apply_audio_padding_to_file
|
||||
|
||||
S3_TIMEOUT = 60
|
||||
|
||||
|
||||
class AudioPaddingPyavProcessor(AudioPaddingProcessor):
|
||||
"""Audio padding processor using PyAV (no HTTP backend)."""
|
||||
|
||||
async def pad_track(
|
||||
self,
|
||||
track_url: str,
|
||||
output_url: str,
|
||||
start_time_seconds: float,
|
||||
track_index: int,
|
||||
) -> PaddingResponse:
|
||||
"""Pad audio track with silence via PyAV.
|
||||
|
||||
Args:
|
||||
track_url: Presigned GET URL for source audio track
|
||||
output_url: Presigned PUT URL for output WebM
|
||||
start_time_seconds: Amount of silence to prepend
|
||||
track_index: Track index for logging
|
||||
"""
|
||||
if not track_url:
|
||||
raise ValueError("track_url cannot be empty")
|
||||
if start_time_seconds <= 0:
|
||||
raise ValueError(
|
||||
f"start_time_seconds must be positive, got {start_time_seconds}"
|
||||
)
|
||||
|
||||
log = logger.bind(track_index=track_index, padding_seconds=start_time_seconds)
|
||||
log.info("Starting local PyAV padding")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
self._pad_track_blocking,
|
||||
track_url,
|
||||
output_url,
|
||||
start_time_seconds,
|
||||
track_index,
|
||||
)
|
||||
|
||||
def _pad_track_blocking(
|
||||
self,
|
||||
track_url: str,
|
||||
output_url: str,
|
||||
start_time_seconds: float,
|
||||
track_index: int,
|
||||
) -> PaddingResponse:
|
||||
"""Blocking padding work: download, pad with PyAV, upload."""
|
||||
import requests
|
||||
|
||||
log = logger.bind(track_index=track_index, padding_seconds=start_time_seconds)
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
input_path = None
|
||||
output_path = None
|
||||
|
||||
try:
|
||||
# Download source audio
|
||||
log.info("Downloading track for local padding")
|
||||
response = requests.get(track_url, stream=True, timeout=S3_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
|
||||
input_path = os.path.join(temp_dir, "track.webm")
|
||||
total_bytes = 0
|
||||
with open(input_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
total_bytes += len(chunk)
|
||||
log.info("Track downloaded", bytes=total_bytes)
|
||||
|
||||
# Apply padding using shared PyAV utility
|
||||
output_path = os.path.join(temp_dir, "padded.webm")
|
||||
with av.open(input_path) as in_container:
|
||||
apply_audio_padding_to_file(
|
||||
in_container,
|
||||
output_path,
|
||||
start_time_seconds,
|
||||
track_index,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
file_size = os.path.getsize(output_path)
|
||||
log.info("Local padding complete", size=file_size)
|
||||
|
||||
# Upload padded track
|
||||
log.info("Uploading padded track to S3")
|
||||
with open(output_path, "rb") as f:
|
||||
upload_response = requests.put(output_url, data=f, timeout=S3_TIMEOUT)
|
||||
upload_response.raise_for_status()
|
||||
log.info("Upload complete", size=file_size)
|
||||
|
||||
return PaddingResponse(size=file_size)
|
||||
|
||||
except Exception as e:
|
||||
log.error("Local padding failed", error=str(e), exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if input_path and os.path.exists(input_path):
|
||||
try:
|
||||
os.unlink(input_path)
|
||||
except Exception as e:
|
||||
log.warning("Failed to cleanup input file", error=str(e))
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.unlink(output_path)
|
||||
except Exception as e:
|
||||
log.warning("Failed to cleanup output file", error=str(e))
|
||||
try:
|
||||
os.rmdir(temp_dir)
|
||||
except Exception as e:
|
||||
log.warning("Failed to cleanup temp directory", error=str(e))
|
||||
|
||||
|
||||
AudioPaddingAutoProcessor.register("pyav", AudioPaddingPyavProcessor)
|
||||
@@ -3,13 +3,17 @@ from faster_whisper import WhisperModel
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = WhisperModel(
|
||||
"tiny", device="cpu", compute_type="float32", num_workers=12
|
||||
settings.WHISPER_CHUNK_MODEL,
|
||||
device="cpu",
|
||||
compute_type="float32",
|
||||
num_workers=12,
|
||||
)
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
|
||||
39
server/reflector/processors/file_diarization_pyannote.py
Normal file
39
server/reflector/processors/file_diarization_pyannote.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
Pyannote file diarization processor using pyannote.audio in-process.
|
||||
|
||||
Downloads audio from URL, runs pyannote diarization locally,
|
||||
and returns speaker segments. No HTTP backend needed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
from reflector.processors._pyannote_diarization_service import diarization_service
|
||||
from reflector.processors.file_diarization import (
|
||||
FileDiarizationInput,
|
||||
FileDiarizationOutput,
|
||||
FileDiarizationProcessor,
|
||||
)
|
||||
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||
|
||||
|
||||
class FileDiarizationPyannoteProcessor(FileDiarizationProcessor):
|
||||
async def _diarize(self, data: FileDiarizationInput):
|
||||
"""Run pyannote diarization on file from URL."""
|
||||
self.logger.info(f"Starting pyannote diarization from {data.audio_url}")
|
||||
tmp_path = await download_audio_to_temp(data.audio_url)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None, diarization_service.diarize_file, str(tmp_path)
|
||||
)
|
||||
return FileDiarizationOutput(diarization=result["diarization"])
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
FileDiarizationAutoProcessor.register("pyannote", FileDiarizationPyannoteProcessor)
|
||||
275
server/reflector/processors/file_transcript_whisper.py
Normal file
275
server/reflector/processors/file_transcript_whisper.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Local file transcription processor using faster-whisper with Silero VAD pipeline.
|
||||
|
||||
Downloads audio from URL, segments it using Silero VAD, transcribes each
|
||||
segment with faster-whisper, and merges results. No HTTP backend needed.
|
||||
|
||||
VAD pipeline ported from gpu/self_hosted/app/services/transcriber.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Generator
|
||||
|
||||
import numpy as np
|
||||
from silero_vad import VADIterator, load_silero_vad
|
||||
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
from reflector.processors.file_transcript import (
|
||||
FileTranscriptInput,
|
||||
FileTranscriptProcessor,
|
||||
)
|
||||
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||
from reflector.processors.types import Transcript, Word
|
||||
from reflector.settings import settings
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
|
||||
class FileTranscriptWhisperProcessor(FileTranscriptProcessor):
|
||||
"""Transcribe complete audio files using local faster-whisper with VAD."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._model = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _ensure_model(self):
|
||||
"""Lazy-load the whisper model on first use."""
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
import faster_whisper
|
||||
import torch
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
compute_type = "float16" if device == "cuda" else "int8"
|
||||
model_name = settings.WHISPER_FILE_MODEL
|
||||
|
||||
self.logger.info(
|
||||
"Loading whisper model",
|
||||
model=model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
)
|
||||
self._model = faster_whisper.WhisperModel(
|
||||
model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
async def _transcript(self, data: FileTranscriptInput):
|
||||
"""Download file, run VAD segmentation, transcribe each segment."""
|
||||
tmp_path = await download_audio_to_temp(data.audio_url)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
self._transcribe_file_blocking,
|
||||
str(tmp_path),
|
||||
data.language,
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _transcribe_file_blocking(self, file_path: str, language: str) -> Transcript:
|
||||
"""Blocking transcription with VAD pipeline."""
|
||||
self._ensure_model()
|
||||
|
||||
audio_array = _load_audio_via_ffmpeg(file_path, SAMPLE_RATE)
|
||||
|
||||
# VAD segmentation → batch merging
|
||||
merged_batches: list[tuple[float, float]] = []
|
||||
batch_start = None
|
||||
batch_end = None
|
||||
max_duration = VAD_CONFIG["batch_max_duration"]
|
||||
|
||||
for seg_start, seg_end in _vad_segments(audio_array):
|
||||
if batch_start is None:
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
continue
|
||||
if seg_end - batch_start <= max_duration:
|
||||
batch_end = seg_end
|
||||
else:
|
||||
merged_batches.append((batch_start, batch_end))
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
|
||||
if batch_start is not None and batch_end is not None:
|
||||
merged_batches.append((batch_start, batch_end))
|
||||
|
||||
# If no speech detected, try transcribing the whole file
|
||||
if not merged_batches:
|
||||
return self._transcribe_whole_file(file_path, language)
|
||||
|
||||
# Transcribe each batch
|
||||
all_words = []
|
||||
for start_time, end_time in merged_batches:
|
||||
s_idx = int(start_time * SAMPLE_RATE)
|
||||
e_idx = int(end_time * SAMPLE_RATE)
|
||||
segment = audio_array[s_idx:e_idx]
|
||||
segment = _pad_audio(segment, SAMPLE_RATE)
|
||||
|
||||
with self._lock:
|
||||
segments, _ = self._model.transcribe(
|
||||
segment,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
segments = list(segments)
|
||||
|
||||
for seg in segments:
|
||||
for w in seg.words:
|
||||
all_words.append(
|
||||
{
|
||||
"word": w.word,
|
||||
"start": round(float(w.start) + start_time, 2),
|
||||
"end": round(float(w.end) + start_time, 2),
|
||||
}
|
||||
)
|
||||
|
||||
all_words = _enforce_word_timing_constraints(all_words)
|
||||
|
||||
words = [
|
||||
Word(text=w["word"], start=w["start"], end=w["end"]) for w in all_words
|
||||
]
|
||||
words.sort(key=lambda w: w.start)
|
||||
return Transcript(words=words)
|
||||
|
||||
def _transcribe_whole_file(self, file_path: str, language: str) -> Transcript:
|
||||
"""Fallback: transcribe entire file without VAD segmentation."""
|
||||
with self._lock:
|
||||
segments, _ = self._model.transcribe(
|
||||
file_path,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
segments = list(segments)
|
||||
|
||||
words = []
|
||||
for seg in segments:
|
||||
for w in seg.words:
|
||||
words.append(
|
||||
Word(
|
||||
text=w.word,
|
||||
start=round(float(w.start), 2),
|
||||
end=round(float(w.end), 2),
|
||||
)
|
||||
)
|
||||
return Transcript(words=words)
|
||||
|
||||
|
||||
# --- VAD helpers (ported from gpu/self_hosted/app/services/transcriber.py) ---
|
||||
# IMPORTANT: This VAD segment logic is duplicated for deployment isolation.
|
||||
# If you modify this, consider updating the GPU service copy as well:
|
||||
# - gpu/self_hosted/app/services/transcriber.py
|
||||
# - gpu/modal_deployments/reflector_transcriber.py
|
||||
# - gpu/modal_deployments/reflector_transcriber_parakeet.py
|
||||
|
||||
|
||||
def _load_audio_via_ffmpeg(
|
||||
input_path: str, sample_rate: int = SAMPLE_RATE
|
||||
) -> np.ndarray:
|
||||
"""Load audio file via ffmpeg, converting to mono float32 at target sample rate."""
|
||||
ffmpeg_bin = shutil.which("ffmpeg") or "ffmpeg"
|
||||
cmd = [
|
||||
ffmpeg_bin,
|
||||
"-nostdin",
|
||||
"-threads",
|
||||
"1",
|
||||
"-i",
|
||||
input_path,
|
||||
"-f",
|
||||
"f32le",
|
||||
"-acodec",
|
||||
"pcm_f32le",
|
||||
"-ac",
|
||||
"1",
|
||||
"-ar",
|
||||
str(sample_rate),
|
||||
"pipe:1",
|
||||
]
|
||||
proc = subprocess.run(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
|
||||
)
|
||||
return np.frombuffer(proc.stdout, dtype=np.float32)
|
||||
|
||||
|
||||
def _vad_segments(
|
||||
audio_array: np.ndarray,
|
||||
sample_rate: int = SAMPLE_RATE,
|
||||
window_size: int = VAD_CONFIG["window_size"],
|
||||
) -> Generator[tuple[float, float], None, None]:
|
||||
"""Detect speech segments using Silero VAD."""
|
||||
vad_model = load_silero_vad(onnx=False)
|
||||
iterator = VADIterator(vad_model, sampling_rate=sample_rate)
|
||||
start = None
|
||||
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(chunk, (0, window_size - len(chunk)), mode="constant")
|
||||
speech = iterator(chunk)
|
||||
if not speech:
|
||||
continue
|
||||
if "start" in speech:
|
||||
start = speech["start"]
|
||||
continue
|
||||
if "end" in speech and start is not None:
|
||||
end = speech["end"]
|
||||
yield (start / float(SAMPLE_RATE), end / float(SAMPLE_RATE))
|
||||
start = None
|
||||
|
||||
# Handle case where audio ends while speech is still active
|
||||
if start is not None:
|
||||
audio_duration = len(audio_array) / float(sample_rate)
|
||||
yield (start / float(SAMPLE_RATE), audio_duration)
|
||||
|
||||
iterator.reset_states()
|
||||
|
||||
|
||||
def _pad_audio(audio_array: np.ndarray, sample_rate: int = SAMPLE_RATE) -> np.ndarray:
|
||||
"""Pad short audio with silence for VAD compatibility."""
|
||||
audio_duration = len(audio_array) / sample_rate
|
||||
if audio_duration < VAD_CONFIG["silence_padding"]:
|
||||
silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio_array, silence])
|
||||
return audio_array
|
||||
|
||||
|
||||
def _enforce_word_timing_constraints(words: list[dict]) -> list[dict]:
|
||||
"""Ensure no word end time exceeds the next word's start time."""
|
||||
if len(words) <= 1:
|
||||
return words
|
||||
enforced: list[dict] = []
|
||||
for i, word in enumerate(words):
|
||||
current = dict(word)
|
||||
if i < len(words) - 1:
|
||||
next_start = words[i + 1]["start"]
|
||||
if current["end"] > next_start:
|
||||
current["end"] = next_start
|
||||
enforced.append(current)
|
||||
return enforced
|
||||
|
||||
|
||||
FileTranscriptAutoProcessor.register("whisper", FileTranscriptWhisperProcessor)
|
||||
50
server/reflector/processors/transcript_translator_marian.py
Normal file
50
server/reflector/processors/transcript_translator_marian.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
MarianMT transcript translator processor using HuggingFace MarianMT in-process.
|
||||
|
||||
Translates transcript text using HuggingFace MarianMT models
|
||||
locally. No HTTP backend needed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from reflector.processors._marian_translator_service import translator_service
|
||||
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
|
||||
from reflector.processors.transcript_translator_auto import (
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.types import TranslationLanguages
|
||||
|
||||
|
||||
class TranscriptTranslatorMarianProcessor(TranscriptTranslatorProcessor):
|
||||
"""Translate transcript text using MarianMT models."""
|
||||
|
||||
async def _translate(self, text: str) -> str | None:
|
||||
source_language = self.get_pref("audio:source_language", "en")
|
||||
target_language = self.get_pref("audio:target_language", "en")
|
||||
|
||||
languages = TranslationLanguages()
|
||||
assert languages.is_supported(target_language)
|
||||
|
||||
self.logger.debug(f"MarianMT translate {text=}")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
translator_service.translate,
|
||||
text,
|
||||
source_language,
|
||||
target_language,
|
||||
)
|
||||
|
||||
if target_language in result["text"]:
|
||||
translation = result["text"][target_language]
|
||||
else:
|
||||
translation = None
|
||||
|
||||
self.logger.debug(f"Translation result: {text=}, {translation=}")
|
||||
return translation
|
||||
|
||||
|
||||
TranscriptTranslatorAutoProcessor.register(
|
||||
"marian", TranscriptTranslatorMarianProcessor
|
||||
)
|
||||
@@ -10,7 +10,6 @@ from dataclasses import dataclass
|
||||
from typing import Literal, Union, assert_never
|
||||
|
||||
import celery
|
||||
from celery.result import AsyncResult
|
||||
from hatchet_sdk.clients.rest.exceptions import ApiException, NotFoundException
|
||||
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
||||
|
||||
@@ -18,7 +17,6 @@ from reflector.db.recordings import recordings_controller
|
||||
from reflector.db.transcripts import Transcript, transcripts_controller
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
from reflector.utils.string import NonEmptyString
|
||||
|
||||
|
||||
@@ -40,6 +38,7 @@ class MultitrackProcessingConfig:
|
||||
track_keys: list[str]
|
||||
recording_id: NonEmptyString | None = None
|
||||
room_id: NonEmptyString | None = None
|
||||
source_platform: str = "daily"
|
||||
mode: Literal["multitrack"] = "multitrack"
|
||||
|
||||
|
||||
@@ -104,11 +103,8 @@ async def validate_transcript_for_processing(
|
||||
):
|
||||
return ValidationNotReady(detail="Recording is not ready for processing")
|
||||
|
||||
# Check Celery tasks
|
||||
# Check Celery tasks (multitrack still uses Celery for some paths)
|
||||
if task_is_scheduled_or_active(
|
||||
"reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
|
||||
transcript_id=transcript.id,
|
||||
) or task_is_scheduled_or_active(
|
||||
"reflector.pipelines.main_multitrack_pipeline.task_pipeline_multitrack_process",
|
||||
transcript_id=transcript.id,
|
||||
):
|
||||
@@ -174,11 +170,8 @@ async def prepare_transcript_processing(validation: ValidationOk) -> PrepareResu
|
||||
|
||||
async def dispatch_transcript_processing(
|
||||
config: ProcessingConfig, force: bool = False
|
||||
) -> AsyncResult | None:
|
||||
"""Dispatch transcript processing to appropriate backend (Hatchet or Celery).
|
||||
|
||||
Returns AsyncResult for Celery tasks, None for Hatchet workflows.
|
||||
"""
|
||||
) -> None:
|
||||
"""Dispatch transcript processing to Hatchet workflow engine."""
|
||||
if isinstance(config, MultitrackProcessingConfig):
|
||||
# Multitrack processing always uses Hatchet (no Celery fallback)
|
||||
# First check if we can replay (outside transaction since it's read-only)
|
||||
@@ -256,6 +249,7 @@ async def dispatch_transcript_processing(
|
||||
"bucket_name": config.bucket_name,
|
||||
"transcript_id": config.transcript_id,
|
||||
"room_id": config.room_id,
|
||||
"source_platform": config.source_platform,
|
||||
},
|
||||
additional_metadata={
|
||||
"transcript_id": config.transcript_id,
|
||||
@@ -273,7 +267,21 @@ async def dispatch_transcript_processing(
|
||||
return None
|
||||
|
||||
elif isinstance(config, FileProcessingConfig):
|
||||
return task_pipeline_file_process.delay(transcript_id=config.transcript_id)
|
||||
# File processing uses Hatchet workflow
|
||||
workflow_id = await HatchetClientManager.start_workflow(
|
||||
workflow_name="FilePipeline",
|
||||
input_data={"transcript_id": config.transcript_id},
|
||||
additional_metadata={"transcript_id": config.transcript_id},
|
||||
)
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(config.transcript_id)
|
||||
if transcript:
|
||||
await transcripts_controller.update(
|
||||
transcript, {"workflow_run_id": workflow_id}
|
||||
)
|
||||
|
||||
logger.info("File pipeline dispatched via Hatchet", workflow_id=workflow_id)
|
||||
return None
|
||||
else:
|
||||
assert_never(config)
|
||||
|
||||
|
||||
@@ -40,14 +40,24 @@ class Settings(BaseSettings):
|
||||
# backends: silero, frames
|
||||
AUDIO_CHUNKER_BACKEND: str = "frames"
|
||||
|
||||
# HuggingFace token for gated models (pyannote diarization in --cpu mode)
|
||||
HF_TOKEN: str | None = None
|
||||
|
||||
# Audio Transcription
|
||||
# backends:
|
||||
# - whisper: in-process model loading (no HTTP, runs in same process)
|
||||
# - modal: HTTP API client (works with Modal.com OR self-hosted gpu/self_hosted/)
|
||||
TRANSCRIPT_BACKEND: str = "whisper"
|
||||
|
||||
# Whisper model sizes for local transcription
|
||||
# Options: "tiny", "base", "small", "medium", "large-v2"
|
||||
WHISPER_CHUNK_MODEL: str = "tiny"
|
||||
WHISPER_FILE_MODEL: str = "tiny"
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
TRANSCRIPT_FILE_TIMEOUT: int = 600
|
||||
TRANSCRIPT_FILE_TIMEOUT: int = (
|
||||
540 # Below Hatchet TIMEOUT_HEAVY (600) to avoid timeout race
|
||||
)
|
||||
|
||||
# Audio Transcription: modal backend
|
||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||
@@ -73,6 +83,9 @@ class Settings(BaseSettings):
|
||||
DAILYCO_STORAGE_AWS_BUCKET_NAME: str | None = None
|
||||
DAILYCO_STORAGE_AWS_REGION: str | None = None
|
||||
DAILYCO_STORAGE_AWS_ROLE_ARN: str | None = None
|
||||
# Worker credentials for reading/deleting from Daily's recording bucket
|
||||
DAILYCO_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
|
||||
DAILYCO_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
|
||||
|
||||
# Translate into the target language
|
||||
TRANSLATION_BACKEND: str = "passthrough"
|
||||
@@ -97,7 +110,7 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
# Diarization
|
||||
# backend: modal — HTTP API client (works with Modal.com OR self-hosted gpu/self_hosted/)
|
||||
# backends: modal — HTTP API client, pyannote — in-process pyannote.audio
|
||||
DIARIZATION_ENABLED: bool = True
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
@@ -106,7 +119,11 @@ class Settings(BaseSettings):
|
||||
# Diarization: modal backend
|
||||
DIARIZATION_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Audio Padding (Modal.com backend)
|
||||
# Audio Padding
|
||||
# backends:
|
||||
# - pyav: in-process PyAV padding (no HTTP, runs in same process)
|
||||
# - modal: HTTP API client (works with Modal.com OR self-hosted gpu/self_hosted/)
|
||||
PADDING_BACKEND: str = "pyav"
|
||||
PADDING_URL: str | None = None
|
||||
PADDING_MODAL_API_KEY: str | None = None
|
||||
|
||||
@@ -163,6 +180,7 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
# Daily.co integration
|
||||
DAILY_API_URL: str = "https://api.daily.co/v1"
|
||||
DAILY_API_KEY: str | None = None
|
||||
DAILY_WEBHOOK_SECRET: str | None = None
|
||||
DAILY_SUBDOMAIN: str | None = None
|
||||
|
||||
@@ -17,6 +17,49 @@ def get_transcripts_storage() -> Storage:
|
||||
)
|
||||
|
||||
|
||||
def get_source_storage(platform: str) -> Storage:
|
||||
"""Get storage for reading/deleting source recording files from the platform's bucket.
|
||||
|
||||
Returns an AwsStorage configured with the platform's worker credentials
|
||||
(access keys), or falls back to get_transcripts_storage() when platform-specific
|
||||
credentials aren't configured (e.g., single-bucket setups).
|
||||
|
||||
Args:
|
||||
platform: Recording platform name ("daily", "whereby", or other).
|
||||
"""
|
||||
if platform == "daily":
|
||||
if (
|
||||
settings.DAILYCO_STORAGE_AWS_ACCESS_KEY_ID
|
||||
and settings.DAILYCO_STORAGE_AWS_SECRET_ACCESS_KEY
|
||||
and settings.DAILYCO_STORAGE_AWS_BUCKET_NAME
|
||||
):
|
||||
from reflector.storage.storage_aws import AwsStorage
|
||||
|
||||
return AwsStorage(
|
||||
aws_bucket_name=settings.DAILYCO_STORAGE_AWS_BUCKET_NAME,
|
||||
aws_region=settings.DAILYCO_STORAGE_AWS_REGION or "us-east-1",
|
||||
aws_access_key_id=settings.DAILYCO_STORAGE_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.DAILYCO_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||
)
|
||||
|
||||
elif platform == "whereby":
|
||||
if (
|
||||
settings.WHEREBY_STORAGE_AWS_ACCESS_KEY_ID
|
||||
and settings.WHEREBY_STORAGE_AWS_SECRET_ACCESS_KEY
|
||||
and settings.WHEREBY_STORAGE_AWS_BUCKET_NAME
|
||||
):
|
||||
from reflector.storage.storage_aws import AwsStorage
|
||||
|
||||
return AwsStorage(
|
||||
aws_bucket_name=settings.WHEREBY_STORAGE_AWS_BUCKET_NAME,
|
||||
aws_region=settings.WHEREBY_STORAGE_AWS_REGION or "us-east-1",
|
||||
aws_access_key_id=settings.WHEREBY_STORAGE_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.WHEREBY_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||
)
|
||||
|
||||
return get_transcripts_storage()
|
||||
|
||||
|
||||
def get_whereby_storage() -> Storage:
|
||||
"""
|
||||
Get storage config for Whereby (for passing to Whereby API).
|
||||
@@ -47,6 +90,9 @@ def get_dailyco_storage() -> Storage:
|
||||
"""
|
||||
Get storage config for Daily.co (for passing to Daily API).
|
||||
|
||||
Uses role_arn only — access keys are excluded because they're for
|
||||
worker reads (get_source_storage), not for the Daily API.
|
||||
|
||||
Usage:
|
||||
daily_storage = get_dailyco_storage()
|
||||
daily_api.create_meeting(
|
||||
@@ -57,13 +103,15 @@ def get_dailyco_storage() -> Storage:
|
||||
|
||||
Do NOT use for our file operations - use get_transcripts_storage() instead.
|
||||
"""
|
||||
# Fail fast if platform-specific config missing
|
||||
if not settings.DAILYCO_STORAGE_AWS_BUCKET_NAME:
|
||||
raise ValueError(
|
||||
"DAILYCO_STORAGE_AWS_BUCKET_NAME required for Daily.co with AWS storage"
|
||||
)
|
||||
|
||||
return Storage.get_instance(
|
||||
name="aws",
|
||||
settings_prefix="DAILYCO_STORAGE_",
|
||||
from reflector.storage.storage_aws import AwsStorage
|
||||
|
||||
return AwsStorage(
|
||||
aws_bucket_name=settings.DAILYCO_STORAGE_AWS_BUCKET_NAME,
|
||||
aws_region=settings.DAILYCO_STORAGE_AWS_REGION or "us-east-1",
|
||||
aws_role_arn=settings.DAILYCO_STORAGE_AWS_ROLE_ARN,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ import asyncio
|
||||
import json
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Tuple
|
||||
from urllib.parse import unquote, urlparse
|
||||
@@ -15,10 +14,8 @@ from urllib.parse import unquote, urlparse
|
||||
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
|
||||
|
||||
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.main_file_pipeline import (
|
||||
task_pipeline_file_process as task_pipeline_file_process,
|
||||
)
|
||||
from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipeline_post
|
||||
from reflector.pipelines.main_live_pipeline import (
|
||||
pipeline_process as live_pipeline_process,
|
||||
@@ -237,29 +234,22 @@ async def process_live_pipeline(
|
||||
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
|
||||
assert pre_final_transcript.status != "ended"
|
||||
|
||||
# at this point, diarization is running but we have no access to it. run diarization in parallel - one will hopefully win after polling
|
||||
result = live_pipeline_post(transcript_id=transcript_id)
|
||||
|
||||
# result.ready() blocks even without await; it mutates result also
|
||||
while not result.ready():
|
||||
print(f"Status: {result.state}")
|
||||
time.sleep(2)
|
||||
# Trigger post-processing via Hatchet (fire-and-forget)
|
||||
await live_pipeline_post(transcript_id=transcript_id)
|
||||
print("Live post-processing pipeline triggered via Hatchet", file=sys.stderr)
|
||||
|
||||
|
||||
async def process_file_pipeline(
|
||||
transcript_id: TranscriptId,
|
||||
):
|
||||
"""Process audio/video file using the optimized file pipeline"""
|
||||
"""Process audio/video file using the optimized file pipeline via Hatchet"""
|
||||
|
||||
# task_pipeline_file_process is a Celery task, need to use .delay() for async execution
|
||||
result = task_pipeline_file_process.delay(transcript_id=transcript_id)
|
||||
|
||||
# Wait for the Celery task to complete
|
||||
while not result.ready():
|
||||
print(f"File pipeline status: {result.state}", file=sys.stderr)
|
||||
time.sleep(2)
|
||||
|
||||
logger.info("File pipeline processing complete")
|
||||
await HatchetClientManager.start_workflow(
|
||||
"FilePipeline",
|
||||
{"transcript_id": str(transcript_id)},
|
||||
additional_metadata={"transcript_id": str(transcript_id)},
|
||||
)
|
||||
print("File pipeline triggered via Hatchet", file=sys.stderr)
|
||||
|
||||
|
||||
async def process(
|
||||
@@ -293,7 +283,16 @@ async def process(
|
||||
|
||||
await handler(transcript_id)
|
||||
|
||||
await extract_result_from_entry(transcript_id, output_path)
|
||||
if pipeline == "file":
|
||||
# File pipeline is async via Hatchet — results not available immediately.
|
||||
# Use reflector.tools.process_transcript with --sync for polling.
|
||||
print(
|
||||
f"File pipeline dispatched for transcript {transcript_id}. "
|
||||
f"Results will be available once the Hatchet workflow completes.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
else:
|
||||
await extract_result_from_entry(transcript_id, output_path)
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
|
||||
@@ -11,12 +11,11 @@ Usage:
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
from celery.result import AsyncResult
|
||||
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
||||
|
||||
import reflector._warnings_filter # noqa: F401 -- side effect: suppress pydantic validate_default warning
|
||||
from reflector.db import get_database
|
||||
from reflector.db.transcripts import Transcript, transcripts_controller
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
@@ -38,7 +37,7 @@ async def process_transcript_inner(
|
||||
on_validation: Callable[[ValidationResult], None],
|
||||
on_preprocess: Callable[[PrepareResult], None],
|
||||
force: bool = False,
|
||||
) -> AsyncResult | None:
|
||||
) -> None:
|
||||
validation = await validate_transcript_for_processing(transcript)
|
||||
on_validation(validation)
|
||||
config = await prepare_transcript_processing(validation)
|
||||
@@ -86,56 +85,39 @@ async def process_transcript(
|
||||
elif isinstance(config, FileProcessingConfig):
|
||||
print(f"Dispatching file pipeline", file=sys.stderr)
|
||||
|
||||
result = await process_transcript_inner(
|
||||
await process_transcript_inner(
|
||||
transcript,
|
||||
on_validation=on_validation,
|
||||
on_preprocess=on_preprocess,
|
||||
force=force,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
# Hatchet workflow dispatched
|
||||
if sync:
|
||||
# Re-fetch transcript to get workflow_run_id
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript or not transcript.workflow_run_id:
|
||||
print("Error: workflow_run_id not found", file=sys.stderr)
|
||||
if sync:
|
||||
# Re-fetch transcript to get workflow_run_id
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript or not transcript.workflow_run_id:
|
||||
print("Error: workflow_run_id not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
print("Waiting for Hatchet workflow...", file=sys.stderr)
|
||||
while True:
|
||||
status = await HatchetClientManager.get_workflow_run_status(
|
||||
transcript.workflow_run_id
|
||||
)
|
||||
print(f" Status: {status.value}", file=sys.stderr)
|
||||
|
||||
if status == V1TaskStatus.COMPLETED:
|
||||
print("Workflow completed successfully", file=sys.stderr)
|
||||
break
|
||||
elif status in (V1TaskStatus.FAILED, V1TaskStatus.CANCELLED):
|
||||
print(f"Workflow failed: {status}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
print("Waiting for Hatchet workflow...", file=sys.stderr)
|
||||
while True:
|
||||
status = await HatchetClientManager.get_workflow_run_status(
|
||||
transcript.workflow_run_id
|
||||
)
|
||||
print(f" Status: {status.value}", file=sys.stderr)
|
||||
|
||||
if status == V1TaskStatus.COMPLETED:
|
||||
print("Workflow completed successfully", file=sys.stderr)
|
||||
break
|
||||
elif status in (V1TaskStatus.FAILED, V1TaskStatus.CANCELLED):
|
||||
print(f"Workflow failed: {status}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
print(
|
||||
"Task dispatched (use --sync to wait for completion)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
elif sync:
|
||||
print("Waiting for task completion...", file=sys.stderr)
|
||||
while not result.ready():
|
||||
print(f" Status: {result.state}", file=sys.stderr)
|
||||
time.sleep(5)
|
||||
|
||||
if result.successful():
|
||||
print("Task completed successfully", file=sys.stderr)
|
||||
else:
|
||||
print(f"Task failed: {result.result}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
print(
|
||||
"Task dispatched (use --sync to wait for completion)", file=sys.stderr
|
||||
"Task dispatched (use --sync to wait for completion)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
finally:
|
||||
|
||||
@@ -30,6 +30,7 @@ def retry(fn):
|
||||
"retry_httpx_status_stop",
|
||||
(
|
||||
401, # auth issue
|
||||
402, # payment required / no credits — needs human action
|
||||
404, # not found
|
||||
413, # payload too large
|
||||
418, # teapot
|
||||
@@ -58,8 +59,9 @@ def retry(fn):
|
||||
result = await fn(*args, **kwargs)
|
||||
if isinstance(result, Response):
|
||||
result.raise_for_status()
|
||||
if result:
|
||||
return result
|
||||
# Return any result including falsy (e.g. "" from get_response);
|
||||
# only retry on exception, not on empty string.
|
||||
return result
|
||||
except HTTPStatusError as e:
|
||||
retry_logger.exception(e)
|
||||
status_code = e.response.status_code
|
||||
|
||||
@@ -89,14 +89,16 @@ class StartRecordingRequest(BaseModel):
|
||||
|
||||
@router.post("/meetings/{meeting_id}/recordings/start")
|
||||
async def start_recording(
|
||||
meeting_id: NonEmptyString, body: StartRecordingRequest
|
||||
meeting_id: NonEmptyString,
|
||||
body: StartRecordingRequest,
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||
],
|
||||
) -> dict[str, Any]:
|
||||
"""Start cloud or raw-tracks recording via Daily.co REST API.
|
||||
|
||||
Both cloud and raw-tracks are started via REST API to bypass enable_recording limitation of allowing only 1 recording at a time.
|
||||
Uses different instanceIds for cloud vs raw-tracks (same won't work)
|
||||
|
||||
Note: No authentication required - anonymous users supported. TODO this is a DOS vector
|
||||
"""
|
||||
meeting = await meetings_controller.get_by_id(meeting_id)
|
||||
if not meeting:
|
||||
|
||||
@@ -17,7 +17,6 @@ from reflector.db.rooms import rooms_controller
|
||||
from reflector.redis_cache import RedisAsyncLock
|
||||
from reflector.schemas.platform import Platform
|
||||
from reflector.services.ics_sync import ics_sync_service
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.url import add_query_param
|
||||
from reflector.video_platforms.factory import create_platform_client
|
||||
from reflector.worker.webhook import test_webhook
|
||||
@@ -178,11 +177,10 @@ router = APIRouter()
|
||||
|
||||
@router.get("/rooms", response_model=Page[RoomDetails])
|
||||
async def rooms_list(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||
],
|
||||
) -> list[RoomDetails]:
|
||||
if not user and not settings.PUBLIC_MODE:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
paginated = await apaginate(
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated, Literal, Optional, assert_never
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.databases import apaginate
|
||||
from jose import jwt
|
||||
from pydantic import (
|
||||
AwareDatetime,
|
||||
BaseModel,
|
||||
@@ -263,16 +263,15 @@ class SearchResponse(BaseModel):
|
||||
|
||||
@router.get("/transcripts", response_model=Page[GetTranscriptMinimal])
|
||||
async def transcripts_list(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||
],
|
||||
source_kind: SourceKind | None = None,
|
||||
room_id: str | None = None,
|
||||
search_term: str | None = None,
|
||||
change_seq_from: int | None = None,
|
||||
sort_by: Literal["created_at", "change_seq"] | None = None,
|
||||
):
|
||||
if not user and not settings.PUBLIC_MODE:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
# Default behavior preserved: sort_by=None → "-created_at"
|
||||
@@ -307,13 +306,10 @@ async def transcripts_search(
|
||||
from_datetime: SearchFromDatetimeParam = None,
|
||||
to_datetime: SearchToDatetimeParam = None,
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional)
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||
] = None,
|
||||
):
|
||||
"""Full-text search across transcript titles and content."""
|
||||
if not user and not settings.PUBLIC_MODE:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
if from_datetime and to_datetime and from_datetime > to_datetime:
|
||||
@@ -346,7 +342,9 @@ async def transcripts_search(
|
||||
@router.post("/transcripts", response_model=GetTranscriptWithParticipants)
|
||||
async def transcripts_create(
|
||||
info: CreateTranscript,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||
],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.add(
|
||||
|
||||
@@ -7,13 +7,12 @@ Transcripts audio related endpoints
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from jose import jwt
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db.transcripts import AudioWaveform, transcripts_controller
|
||||
from reflector.settings import settings
|
||||
from reflector.views.transcripts import ALGORITHM
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
|
||||
@@ -36,16 +35,23 @@ async def transcript_get_audio_mp3(
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
if not user_id and token:
|
||||
unauthorized_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
except jwt.JWTError:
|
||||
raise unauthorized_exception
|
||||
token_user = await auth.verify_raw_token(token)
|
||||
except Exception:
|
||||
token_user = None
|
||||
# Fallback: try as internal HS256 token (created by _generate_local_audio_link)
|
||||
if not token_user:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||
user_id = payload.get("sub")
|
||||
except jwt.PyJWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
else:
|
||||
user_id = token_user["sub"]
|
||||
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
|
||||
@@ -62,8 +62,7 @@ async def transcript_add_participant(
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
transcripts_controller.check_can_mutate(transcript, user_id)
|
||||
|
||||
# ensure the speaker is unique
|
||||
if participant.speaker is not None and transcript.participants is not None:
|
||||
@@ -109,8 +108,7 @@ async def transcript_update_participant(
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
transcripts_controller.check_can_mutate(transcript, user_id)
|
||||
|
||||
# ensure the speaker is unique
|
||||
for p in transcript.participants:
|
||||
@@ -148,7 +146,6 @@ async def transcript_delete_participant(
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
transcripts_controller.check_can_mutate(transcript, user_id)
|
||||
await transcripts_controller.delete_participant(transcript, participant_id)
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
@@ -26,7 +26,9 @@ class ProcessStatus(BaseModel):
|
||||
@router.post("/transcripts/{transcript_id}/process")
|
||||
async def transcript_process(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||
],
|
||||
) -> ProcessStatus:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
@@ -50,5 +52,5 @@ async def transcript_process(
|
||||
if isinstance(config, ProcessError):
|
||||
raise HTTPException(status_code=500, detail=config.detail)
|
||||
else:
|
||||
await dispatch_transcript_processing(config)
|
||||
await dispatch_transcript_processing(config, force=True)
|
||||
return ProcessStatus(status="ok")
|
||||
|
||||
@@ -41,8 +41,7 @@ async def transcript_assign_speaker(
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
transcripts_controller.check_can_mutate(transcript, user_id)
|
||||
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
@@ -121,8 +120,7 @@ async def transcript_merge_speaker(
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
transcripts_controller.check_can_mutate(transcript, user_id)
|
||||
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
@@ -6,7 +6,7 @@ from pydantic import BaseModel
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -21,7 +21,9 @@ async def transcript_record_upload(
|
||||
chunk_number: int,
|
||||
total_chunks: int,
|
||||
chunk: UploadFile,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||
],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
@@ -93,7 +95,14 @@ async def transcript_record_upload(
|
||||
transcript, {"status": "uploaded", "source_kind": SourceKind.FILE}
|
||||
)
|
||||
|
||||
# launch a background task to process the file
|
||||
task_pipeline_file_process.delay(transcript_id=transcript_id)
|
||||
# launch Hatchet workflow to process the file
|
||||
workflow_id = await HatchetClientManager.start_workflow(
|
||||
"FilePipeline",
|
||||
{"transcript_id": str(transcript_id)},
|
||||
additional_metadata={"transcript_id": str(transcript_id)},
|
||||
)
|
||||
|
||||
# Save workflow_run_id for duplicate detection and status polling
|
||||
await transcripts_controller.update(transcript, {"workflow_run_id": workflow_id})
|
||||
|
||||
return UploadStatus(status="ok")
|
||||
|
||||
@@ -15,7 +15,9 @@ async def transcript_record_webrtc(
|
||||
transcript_id: str,
|
||||
params: RtcOffer,
|
||||
request: Request,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||
],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
|
||||
@@ -24,6 +24,118 @@ RECONCILIATION_INTERVAL = _override or 30.0
|
||||
ICS_SYNC_INTERVAL = _override or 60.0
|
||||
UPCOMING_MEETINGS_INTERVAL = _override or 30.0
|
||||
|
||||
|
||||
def build_beat_schedule(
|
||||
*,
|
||||
whereby_api_key=None,
|
||||
aws_process_recording_queue_url=None,
|
||||
daily_api_key=None,
|
||||
public_mode=False,
|
||||
public_data_retention_days=None,
|
||||
healthcheck_url=None,
|
||||
):
|
||||
"""Build the Celery beat schedule based on configured services.
|
||||
|
||||
Only registers tasks for services that are actually configured,
|
||||
avoiding unnecessary worker wake-ups in selfhosted deployments.
|
||||
"""
|
||||
beat_schedule = {}
|
||||
|
||||
_whereby_enabled = bool(whereby_api_key) or bool(aws_process_recording_queue_url)
|
||||
if _whereby_enabled:
|
||||
beat_schedule["process_messages"] = {
|
||||
"task": "reflector.worker.process.process_messages",
|
||||
"schedule": SQS_POLL_INTERVAL,
|
||||
}
|
||||
beat_schedule["reprocess_failed_recordings"] = {
|
||||
"task": "reflector.worker.process.reprocess_failed_recordings",
|
||||
"schedule": crontab(hour=5, minute=0), # Midnight EST
|
||||
}
|
||||
logger.info(
|
||||
"Whereby beat tasks enabled",
|
||||
tasks=["process_messages", "reprocess_failed_recordings"],
|
||||
)
|
||||
else:
|
||||
logger.info("Whereby beat tasks disabled (no WHEREBY_API_KEY or SQS URL)")
|
||||
|
||||
_daily_enabled = bool(daily_api_key)
|
||||
if _daily_enabled:
|
||||
beat_schedule["poll_daily_recordings"] = {
|
||||
"task": "reflector.worker.process.poll_daily_recordings",
|
||||
"schedule": POLL_DAILY_RECORDINGS_INTERVAL_SEC,
|
||||
}
|
||||
beat_schedule["trigger_daily_reconciliation"] = {
|
||||
"task": "reflector.worker.process.trigger_daily_reconciliation",
|
||||
"schedule": RECONCILIATION_INTERVAL,
|
||||
}
|
||||
beat_schedule["reprocess_failed_daily_recordings"] = {
|
||||
"task": "reflector.worker.process.reprocess_failed_daily_recordings",
|
||||
"schedule": crontab(hour=5, minute=0), # Midnight EST
|
||||
}
|
||||
logger.info(
|
||||
"Daily.co beat tasks enabled",
|
||||
tasks=[
|
||||
"poll_daily_recordings",
|
||||
"trigger_daily_reconciliation",
|
||||
"reprocess_failed_daily_recordings",
|
||||
],
|
||||
)
|
||||
else:
|
||||
logger.info("Daily.co beat tasks disabled (no DAILY_API_KEY)")
|
||||
|
||||
_any_platform = _whereby_enabled or _daily_enabled
|
||||
if _any_platform:
|
||||
beat_schedule["process_meetings"] = {
|
||||
"task": "reflector.worker.process.process_meetings",
|
||||
"schedule": SQS_POLL_INTERVAL,
|
||||
}
|
||||
beat_schedule["sync_all_ics_calendars"] = {
|
||||
"task": "reflector.worker.ics_sync.sync_all_ics_calendars",
|
||||
"schedule": ICS_SYNC_INTERVAL,
|
||||
}
|
||||
beat_schedule["create_upcoming_meetings"] = {
|
||||
"task": "reflector.worker.ics_sync.create_upcoming_meetings",
|
||||
"schedule": UPCOMING_MEETINGS_INTERVAL,
|
||||
}
|
||||
logger.info(
|
||||
"Platform tasks enabled",
|
||||
tasks=[
|
||||
"process_meetings",
|
||||
"sync_all_ics_calendars",
|
||||
"create_upcoming_meetings",
|
||||
],
|
||||
)
|
||||
else:
|
||||
logger.info("Platform tasks disabled (no video platform configured)")
|
||||
|
||||
if public_mode:
|
||||
beat_schedule["cleanup_old_public_data"] = {
|
||||
"task": "reflector.worker.cleanup.cleanup_old_public_data_task",
|
||||
"schedule": crontab(hour=3, minute=0),
|
||||
}
|
||||
logger.info(
|
||||
"Public mode cleanup enabled",
|
||||
retention_days=public_data_retention_days,
|
||||
)
|
||||
|
||||
if healthcheck_url:
|
||||
beat_schedule["healthcheck_ping"] = {
|
||||
"task": "reflector.worker.healthcheck.healthcheck_ping",
|
||||
"schedule": 60.0 * 10,
|
||||
}
|
||||
logger.info("Healthcheck enabled", url=healthcheck_url)
|
||||
else:
|
||||
logger.warning("Healthcheck disabled, no url configured")
|
||||
|
||||
logger.info(
|
||||
"Beat schedule configured",
|
||||
total_tasks=len(beat_schedule),
|
||||
task_names=sorted(beat_schedule.keys()),
|
||||
)
|
||||
|
||||
return beat_schedule
|
||||
|
||||
|
||||
if celery.current_app.main != "default":
|
||||
logger.info(f"Celery already configured ({celery.current_app})")
|
||||
app = celery.current_app
|
||||
@@ -42,57 +154,11 @@ else:
|
||||
]
|
||||
)
|
||||
|
||||
# crontab
|
||||
app.conf.beat_schedule = {
|
||||
"process_messages": {
|
||||
"task": "reflector.worker.process.process_messages",
|
||||
"schedule": SQS_POLL_INTERVAL,
|
||||
},
|
||||
"process_meetings": {
|
||||
"task": "reflector.worker.process.process_meetings",
|
||||
"schedule": SQS_POLL_INTERVAL,
|
||||
},
|
||||
"reprocess_failed_recordings": {
|
||||
"task": "reflector.worker.process.reprocess_failed_recordings",
|
||||
"schedule": crontab(hour=5, minute=0), # Midnight EST
|
||||
},
|
||||
"reprocess_failed_daily_recordings": {
|
||||
"task": "reflector.worker.process.reprocess_failed_daily_recordings",
|
||||
"schedule": crontab(hour=5, minute=0), # Midnight EST
|
||||
},
|
||||
"poll_daily_recordings": {
|
||||
"task": "reflector.worker.process.poll_daily_recordings",
|
||||
"schedule": POLL_DAILY_RECORDINGS_INTERVAL_SEC,
|
||||
},
|
||||
"trigger_daily_reconciliation": {
|
||||
"task": "reflector.worker.process.trigger_daily_reconciliation",
|
||||
"schedule": RECONCILIATION_INTERVAL,
|
||||
},
|
||||
"sync_all_ics_calendars": {
|
||||
"task": "reflector.worker.ics_sync.sync_all_ics_calendars",
|
||||
"schedule": ICS_SYNC_INTERVAL,
|
||||
},
|
||||
"create_upcoming_meetings": {
|
||||
"task": "reflector.worker.ics_sync.create_upcoming_meetings",
|
||||
"schedule": UPCOMING_MEETINGS_INTERVAL,
|
||||
},
|
||||
}
|
||||
|
||||
if settings.PUBLIC_MODE:
|
||||
app.conf.beat_schedule["cleanup_old_public_data"] = {
|
||||
"task": "reflector.worker.cleanup.cleanup_old_public_data_task",
|
||||
"schedule": crontab(hour=3, minute=0),
|
||||
}
|
||||
logger.info(
|
||||
"Public mode cleanup enabled",
|
||||
retention_days=settings.PUBLIC_DATA_RETENTION_DAYS,
|
||||
)
|
||||
|
||||
if settings.HEALTHCHECK_URL:
|
||||
app.conf.beat_schedule["healthcheck_ping"] = {
|
||||
"task": "reflector.worker.healthcheck.healthcheck_ping",
|
||||
"schedule": 60.0 * 10,
|
||||
}
|
||||
logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
|
||||
else:
|
||||
logger.warning("Healthcheck disabled, no url configured")
|
||||
app.conf.beat_schedule = build_beat_schedule(
|
||||
whereby_api_key=settings.WHEREBY_API_KEY,
|
||||
aws_process_recording_queue_url=settings.AWS_PROCESS_RECORDING_QUEUE_URL,
|
||||
daily_api_key=settings.DAILY_API_KEY,
|
||||
public_mode=settings.PUBLIC_MODE,
|
||||
public_data_retention_days=settings.PUBLIC_DATA_RETENTION_DAYS,
|
||||
healthcheck_url=settings.HEALTHCHECK_URL,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,6 @@ from reflector.db.transcripts import (
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
from reflector.pipelines.main_live_pipeline import asynctask
|
||||
from reflector.pipelines.topic_processing import EmptyPipeline
|
||||
from reflector.processors import AudioFileWriterProcessor
|
||||
@@ -132,7 +131,7 @@ async def process_recording(bucket_name: str, object_key: str):
|
||||
target_language="en",
|
||||
user_id=room.user_id,
|
||||
recording_id=recording.id,
|
||||
share_mode="public",
|
||||
share_mode="semi-private",
|
||||
meeting_id=meeting.id,
|
||||
room_id=room.id,
|
||||
)
|
||||
@@ -163,7 +162,14 @@ async def process_recording(bucket_name: str, object_key: str):
|
||||
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
|
||||
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
||||
await HatchetClientManager.start_workflow(
|
||||
"FilePipeline",
|
||||
{
|
||||
"transcript_id": str(transcript.id),
|
||||
"room_id": str(room.id) if room else None,
|
||||
},
|
||||
additional_metadata={"transcript_id": str(transcript.id)},
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
@@ -343,7 +349,7 @@ async def _process_multitrack_recording_inner(
|
||||
target_language="en",
|
||||
user_id=room.user_id,
|
||||
recording_id=recording.id,
|
||||
share_mode="public",
|
||||
share_mode="semi-private",
|
||||
meeting_id=meeting.id,
|
||||
room_id=room.id,
|
||||
)
|
||||
@@ -357,6 +363,7 @@ async def _process_multitrack_recording_inner(
|
||||
"bucket_name": bucket_name,
|
||||
"transcript_id": transcript.id,
|
||||
"room_id": room.id,
|
||||
"source_platform": "daily",
|
||||
},
|
||||
additional_metadata={
|
||||
"transcript_id": transcript.id,
|
||||
@@ -1068,6 +1075,7 @@ async def reprocess_failed_daily_recordings():
|
||||
"bucket_name": bucket_name,
|
||||
"transcript_id": transcript.id,
|
||||
"room_id": room.id if room else None,
|
||||
"source_platform": "daily",
|
||||
},
|
||||
additional_metadata={
|
||||
"transcript_id": transcript.id,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -437,6 +437,8 @@ async def ws_manager_in_memory(monkeypatch):
|
||||
|
||||
try:
|
||||
fastapi_app.dependency_overrides[auth.current_user_optional] = lambda: None
|
||||
# current_user_optional_if_public_mode is NOT overridden here so the real
|
||||
# implementation runs and enforces the PUBLIC_MODE check during tests.
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -491,37 +493,39 @@ async def authenticated_client2():
|
||||
@asynccontextmanager
|
||||
async def authenticated_client_ctx():
|
||||
from reflector.app import app
|
||||
from reflector.auth import current_user, current_user_optional
|
||||
from reflector.auth import (
|
||||
current_user,
|
||||
current_user_optional,
|
||||
current_user_optional_if_public_mode,
|
||||
)
|
||||
|
||||
app.dependency_overrides[current_user] = lambda: {
|
||||
"sub": "randomuserid",
|
||||
"email": "test@mail.com",
|
||||
}
|
||||
app.dependency_overrides[current_user_optional] = lambda: {
|
||||
"sub": "randomuserid",
|
||||
"email": "test@mail.com",
|
||||
}
|
||||
_user = lambda: {"sub": "randomuserid", "email": "test@mail.com"}
|
||||
app.dependency_overrides[current_user] = _user
|
||||
app.dependency_overrides[current_user_optional] = _user
|
||||
app.dependency_overrides[current_user_optional_if_public_mode] = _user
|
||||
yield
|
||||
del app.dependency_overrides[current_user]
|
||||
del app.dependency_overrides[current_user_optional]
|
||||
del app.dependency_overrides[current_user_optional_if_public_mode]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def authenticated_client2_ctx():
|
||||
from reflector.app import app
|
||||
from reflector.auth import current_user, current_user_optional
|
||||
from reflector.auth import (
|
||||
current_user,
|
||||
current_user_optional,
|
||||
current_user_optional_if_public_mode,
|
||||
)
|
||||
|
||||
app.dependency_overrides[current_user] = lambda: {
|
||||
"sub": "randomuserid2",
|
||||
"email": "test@mail.com",
|
||||
}
|
||||
app.dependency_overrides[current_user_optional] = lambda: {
|
||||
"sub": "randomuserid2",
|
||||
"email": "test@mail.com",
|
||||
}
|
||||
_user = lambda: {"sub": "randomuserid2", "email": "test@mail.com"}
|
||||
app.dependency_overrides[current_user] = _user
|
||||
app.dependency_overrides[current_user_optional] = _user
|
||||
app.dependency_overrides[current_user_optional_if_public_mode] = _user
|
||||
yield
|
||||
del app.dependency_overrides[current_user]
|
||||
del app.dependency_overrides[current_user_optional]
|
||||
del app.dependency_overrides[current_user_optional_if_public_mode]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -534,23 +538,64 @@ def fake_mp3_upload():
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_hatchet_client():
|
||||
"""Reset HatchetClientManager singleton before and after each test.
|
||||
def mock_hatchet_client():
|
||||
"""Mock HatchetClientManager for all tests.
|
||||
|
||||
This ensures test isolation - each test starts with a fresh client state.
|
||||
The fixture is autouse=True so it applies to all tests automatically.
|
||||
Prevents tests from connecting to a real Hatchet server. The dummy token
|
||||
in [tool.pytest_env] prevents the import-time ValueError, but the SDK
|
||||
would still try to connect when get_client() is called. This fixture
|
||||
mocks get_client to return a MagicMock and start_workflow to return a
|
||||
dummy workflow ID.
|
||||
"""
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
|
||||
# Reset before test
|
||||
HatchetClientManager.reset()
|
||||
yield
|
||||
# Reset after test to clean up
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.workflow.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
HatchetClientManager,
|
||||
"get_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch.object(
|
||||
HatchetClientManager,
|
||||
"start_workflow",
|
||||
new_callable=AsyncMock,
|
||||
return_value="mock-workflow-id",
|
||||
),
|
||||
patch.object(
|
||||
HatchetClientManager,
|
||||
"get_workflow_run_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch.object(
|
||||
HatchetClientManager,
|
||||
"can_replay",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
),
|
||||
patch.object(
|
||||
HatchetClientManager,
|
||||
"cancel_workflow",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch.object(
|
||||
HatchetClientManager,
|
||||
"replay_workflow",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
yield mock_client
|
||||
|
||||
HatchetClientManager.reset()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def fake_transcript_with_topics(tmpdir, client):
|
||||
async def fake_transcript_with_topics(tmpdir, client, monkeypatch):
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
@@ -559,6 +604,9 @@ async def fake_transcript_with_topics(tmpdir, client):
|
||||
from reflector.settings import settings
|
||||
from reflector.views.transcripts import transcripts_controller
|
||||
|
||||
monkeypatch.setattr(
|
||||
settings, "PUBLIC_MODE", True
|
||||
) # public mode: allow anonymous transcript creation for this test
|
||||
settings.DATA_DIR = Path(tmpdir)
|
||||
|
||||
# create a transcript
|
||||
|
||||
218
server/tests/docker-compose.integration.yml
Normal file
218
server/tests/docker-compose.integration.yml
Normal file
@@ -0,0 +1,218 @@
|
||||
# Integration test stack — full pipeline end-to-end.
|
||||
#
|
||||
# Usage:
|
||||
# docker compose -f server/tests/docker-compose.integration.yml up -d --build
|
||||
#
|
||||
# Requires .env.integration in the repo root (generated by CI workflow).
|
||||
|
||||
x-backend-env: &backend-env
|
||||
DATABASE_URL: postgresql+asyncpg://reflector:reflector@postgres:5432/reflector
|
||||
REDIS_HOST: redis
|
||||
CELERY_BROKER_URL: redis://redis:6379/1
|
||||
CELERY_RESULT_BACKEND: redis://redis:6379/1
|
||||
HATCHET_CLIENT_TOKEN: ${HATCHET_CLIENT_TOKEN:-}
|
||||
HATCHET_CLIENT_SERVER_URL: http://hatchet:8888
|
||||
HATCHET_CLIENT_HOST_PORT: hatchet:7077
|
||||
HATCHET_CLIENT_TLS_STRATEGY: none
|
||||
# ML backends — CPU-only, no external services
|
||||
TRANSCRIPT_BACKEND: whisper
|
||||
WHISPER_CHUNK_MODEL: tiny
|
||||
WHISPER_FILE_MODEL: tiny
|
||||
DIARIZATION_BACKEND: pyannote
|
||||
TRANSLATION_BACKEND: passthrough
|
||||
# Storage — local Garage S3
|
||||
TRANSCRIPT_STORAGE_BACKEND: aws
|
||||
TRANSCRIPT_STORAGE_AWS_ENDPOINT_URL: http://garage:3900
|
||||
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: reflector-media
|
||||
TRANSCRIPT_STORAGE_AWS_REGION: garage
|
||||
# Daily mock
|
||||
DAILY_API_URL: http://mock-daily:8080/v1
|
||||
DAILY_API_KEY: fake-daily-key
|
||||
# Auth
|
||||
PUBLIC_MODE: "true"
|
||||
AUTH_BACKEND: none
|
||||
# LLM (injected from CI)
|
||||
LLM_URL: ${LLM_URL:-}
|
||||
LLM_API_KEY: ${LLM_API_KEY:-}
|
||||
LLM_MODEL: ${LLM_MODEL:-gpt-4o-mini}
|
||||
# HuggingFace (for pyannote gated models)
|
||||
HF_TOKEN: ${HF_TOKEN:-}
|
||||
# Garage S3 credentials — hardcoded test keys, containers are ephemeral
|
||||
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: GK0123456789abcdef01234567 # gitleaks:allow
|
||||
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" # gitleaks:allow
|
||||
# NOTE: DAILYCO_STORAGE_AWS_* intentionally NOT set — forces fallback to
|
||||
# get_transcripts_storage() which has ENDPOINT_URL pointing at Garage.
|
||||
# Setting them would bypass the endpoint and generate presigned URLs for AWS.
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17-alpine
|
||||
command: ["postgres", "-c", "max_connections=200"]
|
||||
environment:
|
||||
POSTGRES_USER: reflector
|
||||
POSTGRES_PASSWORD: reflector
|
||||
POSTGRES_DB: reflector
|
||||
volumes:
|
||||
- ../../server/docker/init-hatchet-db.sql:/docker-entrypoint-initdb.d/init-hatchet-db.sql:ro
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U reflector"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 10
|
||||
|
||||
redis:
|
||||
image: redis:7.2-alpine
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
hatchet:
|
||||
image: ghcr.io/hatchet-dev/hatchet/hatchet-lite:latest
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
DATABASE_URL: "postgresql://reflector:reflector@postgres:5432/hatchet?sslmode=disable&connect_timeout=30"
|
||||
SERVER_AUTH_COOKIE_INSECURE: "t"
|
||||
SERVER_AUTH_COOKIE_DOMAIN: "localhost"
|
||||
SERVER_GRPC_BIND_ADDRESS: "0.0.0.0"
|
||||
SERVER_GRPC_INSECURE: "t"
|
||||
SERVER_GRPC_BROADCAST_ADDRESS: hatchet:7077
|
||||
SERVER_GRPC_PORT: "7077"
|
||||
SERVER_AUTH_SET_EMAIL_VERIFIED: "t"
|
||||
SERVER_INTERNAL_CLIENT_INTERNAL_GRPC_BROADCAST_ADDRESS: hatchet:7077
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8888/api/live"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 15
|
||||
start_period: 30s
|
||||
|
||||
garage:
|
||||
image: dxflrs/garage:v1.1.0
|
||||
volumes:
|
||||
- ./integration/garage.toml:/etc/garage.toml:ro
|
||||
healthcheck:
|
||||
test: ["CMD", "/garage", "stats"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 10
|
||||
start_period: 5s
|
||||
|
||||
mock-daily:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: integration/Dockerfile.mock-daily
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8080/v1/recordings/test')"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
server:
|
||||
build:
|
||||
context: ../../server
|
||||
dockerfile: Dockerfile
|
||||
environment:
|
||||
<<: *backend-env
|
||||
ENTRYPOINT: server
|
||||
WEBRTC_HOST: server
|
||||
WEBRTC_PORT_RANGE: "52000-52100"
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
hatchet:
|
||||
condition: service_healthy
|
||||
garage:
|
||||
condition: service_healthy
|
||||
mock-daily:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- server_data:/app/data
|
||||
|
||||
worker:
|
||||
build:
|
||||
context: ../../server
|
||||
dockerfile: Dockerfile
|
||||
environment:
|
||||
<<: *backend-env
|
||||
ENTRYPOINT: worker
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- server_data:/app/data
|
||||
|
||||
hatchet-worker-cpu:
|
||||
build:
|
||||
context: ../../server
|
||||
dockerfile: Dockerfile
|
||||
environment:
|
||||
<<: *backend-env
|
||||
ENTRYPOINT: hatchet-worker-cpu
|
||||
depends_on:
|
||||
hatchet:
|
||||
condition: service_healthy
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- server_data:/app/data
|
||||
|
||||
hatchet-worker-llm:
|
||||
build:
|
||||
context: ../../server
|
||||
dockerfile: Dockerfile
|
||||
environment:
|
||||
<<: *backend-env
|
||||
ENTRYPOINT: hatchet-worker-llm
|
||||
depends_on:
|
||||
hatchet:
|
||||
condition: service_healthy
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- server_data:/app/data
|
||||
|
||||
test-runner:
|
||||
build:
|
||||
context: ../../server
|
||||
dockerfile: Dockerfile
|
||||
environment:
|
||||
<<: *backend-env
|
||||
# Override DATABASE_URL for sync driver (used by direct DB access in tests)
|
||||
DATABASE_URL_ASYNC: postgresql+asyncpg://reflector:reflector@postgres:5432/reflector
|
||||
DATABASE_URL: postgresql+asyncpg://reflector:reflector@postgres:5432/reflector
|
||||
SERVER_URL: http://server:1250
|
||||
GARAGE_ENDPOINT: http://garage:3900
|
||||
depends_on:
|
||||
server:
|
||||
condition: service_started
|
||||
worker:
|
||||
condition: service_started
|
||||
hatchet-worker-cpu:
|
||||
condition: service_started
|
||||
hatchet-worker-llm:
|
||||
condition: service_started
|
||||
volumes:
|
||||
- server_data:/app/data
|
||||
# Mount test files into the container
|
||||
- ./records:/app/tests/records:ro
|
||||
- ./integration:/app/tests/integration:ro
|
||||
entrypoint: ["sleep", "infinity"]
|
||||
|
||||
volumes:
|
||||
server_data:
|
||||
|
||||
networks:
|
||||
default:
|
||||
attachable: true
|
||||
9
server/tests/integration/Dockerfile.mock-daily
Normal file
9
server/tests/integration/Dockerfile.mock-daily
Normal file
@@ -0,0 +1,9 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
RUN pip install --no-cache-dir fastapi uvicorn[standard]
|
||||
|
||||
WORKDIR /app
|
||||
COPY integration/mock_daily_server.py /app/mock_daily_server.py
|
||||
|
||||
EXPOSE 8080
|
||||
CMD ["uvicorn", "mock_daily_server:app", "--host", "0.0.0.0", "--port", "8080"]
|
||||
0
server/tests/integration/__init__.py
Normal file
0
server/tests/integration/__init__.py
Normal file
116
server/tests/integration/conftest.py
Normal file
116
server/tests/integration/conftest.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Integration test fixtures — no mocks, real services.
|
||||
|
||||
All services (PostgreSQL, Redis, Hatchet, Garage, server, workers) are
|
||||
expected to be running via docker-compose.integration.yml.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import boto3
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
SERVER_URL = os.environ.get("SERVER_URL", "http://server:1250")
|
||||
GARAGE_ENDPOINT = os.environ.get("GARAGE_ENDPOINT", "http://garage:3900")
|
||||
DATABASE_URL = os.environ.get(
|
||||
"DATABASE_URL_ASYNC",
|
||||
os.environ.get(
|
||||
"DATABASE_URL",
|
||||
"postgresql+asyncpg://reflector:reflector@postgres:5432/reflector",
|
||||
),
|
||||
)
|
||||
GARAGE_KEY_ID = os.environ.get("TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID", "")
|
||||
GARAGE_KEY_SECRET = os.environ.get("TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY", "")
|
||||
BUCKET_NAME = "reflector-media"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def api_client():
|
||||
"""HTTP client pointed at the running server."""
|
||||
async with httpx.AsyncClient(
|
||||
base_url=f"{SERVER_URL}/v1",
|
||||
timeout=httpx.Timeout(30.0),
|
||||
) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def s3_client():
|
||||
"""Boto3 S3 client pointed at Garage."""
|
||||
return boto3.client(
|
||||
"s3",
|
||||
endpoint_url=GARAGE_ENDPOINT,
|
||||
aws_access_key_id=GARAGE_KEY_ID,
|
||||
aws_secret_access_key=GARAGE_KEY_SECRET,
|
||||
region_name="garage",
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_engine():
|
||||
"""SQLAlchemy async engine for direct DB operations."""
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_records_dir():
|
||||
"""Path to the test audio files directory."""
|
||||
return Path(__file__).parent.parent / "records"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def bucket_name():
|
||||
"""S3 bucket name used for integration tests."""
|
||||
return BUCKET_NAME
|
||||
|
||||
|
||||
async def _poll_transcript_status(
|
||||
client: httpx.AsyncClient,
|
||||
transcript_id: str,
|
||||
target: str | tuple[str, ...],
|
||||
error: str = "error",
|
||||
max_wait: int = 300,
|
||||
interval: int = 3,
|
||||
) -> dict:
|
||||
"""
|
||||
Poll GET /transcripts/{id} until status matches target or error.
|
||||
|
||||
target can be a single status string or a tuple of acceptable statuses.
|
||||
Returns the transcript dict on success, raises on timeout or error status.
|
||||
"""
|
||||
targets = (target,) if isinstance(target, str) else target
|
||||
elapsed = 0
|
||||
status = None
|
||||
while elapsed < max_wait:
|
||||
resp = await client.get(f"/transcripts/{transcript_id}")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
status = data.get("status")
|
||||
|
||||
if status in targets:
|
||||
return data
|
||||
if status == error:
|
||||
raise AssertionError(
|
||||
f"Transcript {transcript_id} reached error status: {data}"
|
||||
)
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
elapsed += interval
|
||||
|
||||
raise TimeoutError(
|
||||
f"Transcript {transcript_id} did not reach status '{target}' "
|
||||
f"within {max_wait}s (last status: {status})"
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
def poll_transcript_status():
|
||||
"""Returns the poll_transcript_status async helper function."""
|
||||
return _poll_transcript_status
|
||||
14
server/tests/integration/garage.toml
Normal file
14
server/tests/integration/garage.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
metadata_dir = "/var/lib/garage/meta"
|
||||
data_dir = "/var/lib/garage/data"
|
||||
replication_factor = 1
|
||||
|
||||
rpc_secret = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789" # gitleaks:allow
|
||||
rpc_bind_addr = "[::]:3901"
|
||||
|
||||
[s3_api]
|
||||
api_bind_addr = "[::]:3900"
|
||||
s3_region = "garage"
|
||||
root_domain = ".s3.garage.localhost"
|
||||
|
||||
[admin]
|
||||
api_bind_addr = "[::]:3903"
|
||||
62
server/tests/integration/garage_setup.sh
Executable file
62
server/tests/integration/garage_setup.sh
Executable file
@@ -0,0 +1,62 @@
|
||||
#!/bin/sh
|
||||
#
|
||||
# Initialize Garage bucket and keys for integration tests.
|
||||
# Run inside the Garage container after it's healthy.
|
||||
#
|
||||
# Outputs KEY_ID and KEY_SECRET to stdout (last two lines).
|
||||
#
|
||||
# Note: uses /bin/sh (not bash) since the Garage container is minimal.
|
||||
#
|
||||
set -eu
|
||||
|
||||
echo "Waiting for Garage to be ready..."
|
||||
i=0
|
||||
while [ "$i" -lt 30 ]; do
|
||||
if /garage stats >/dev/null 2>&1; then
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
i=$((i + 1))
|
||||
done
|
||||
|
||||
# Layout setup
|
||||
NODE_ID=$(/garage node id -q | tr -d '[:space:]')
|
||||
LAYOUT_STATUS=$(/garage layout show 2>&1 || true)
|
||||
if echo "$LAYOUT_STATUS" | grep -q "No nodes"; then
|
||||
/garage layout assign "$NODE_ID" -c 1G -z dc1
|
||||
/garage layout apply --version 1
|
||||
echo "Layout applied."
|
||||
else
|
||||
echo "Layout already configured."
|
||||
fi
|
||||
|
||||
# Bucket
|
||||
if ! /garage bucket info reflector-media >/dev/null 2>&1; then
|
||||
/garage bucket create reflector-media
|
||||
echo "Bucket 'reflector-media' created."
|
||||
else
|
||||
echo "Bucket 'reflector-media' already exists."
|
||||
fi
|
||||
|
||||
# Key
|
||||
if /garage key info reflector-test >/dev/null 2>&1; then
|
||||
echo "Key 'reflector-test' already exists."
|
||||
KEY_OUTPUT=$(/garage key info reflector-test 2>&1)
|
||||
else
|
||||
KEY_OUTPUT=$(/garage key create reflector-test 2>&1)
|
||||
echo "Key 'reflector-test' created."
|
||||
fi
|
||||
|
||||
# Permissions
|
||||
/garage bucket allow reflector-media --read --write --key reflector-test
|
||||
|
||||
# Extract key ID and secret from output using POSIX-compatible parsing
|
||||
# garage key output format:
|
||||
# Key name: reflector-test
|
||||
# Key ID: GK...
|
||||
# Secret key: ...
|
||||
KEY_ID=$(echo "$KEY_OUTPUT" | grep "Key ID" | sed 's/.*Key ID: *//')
|
||||
KEY_SECRET=$(echo "$KEY_OUTPUT" | grep "Secret key" | sed 's/.*Secret key: *//')
|
||||
|
||||
echo "GARAGE_KEY_ID=${KEY_ID}"
|
||||
echo "GARAGE_KEY_SECRET=${KEY_SECRET}"
|
||||
75
server/tests/integration/mock_daily_server.py
Normal file
75
server/tests/integration/mock_daily_server.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Minimal FastAPI mock for Daily.co API.
|
||||
|
||||
Serves canned responses for:
|
||||
- GET /v1/recordings/{recording_id}
|
||||
- GET /v1/meetings/{meeting_id}/participants
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI(title="Mock Daily API")
|
||||
|
||||
|
||||
# Participant UUIDs must be 36-char hex UUIDs to match Daily's filename format
|
||||
PARTICIPANT_A_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
PARTICIPANT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
|
||||
# Daily-format track keys: {recording_start_ts}-{participant_id}-cam-audio-{track_start_ts}
|
||||
TRACK_KEYS = [
|
||||
f"1700000000000-{PARTICIPANT_A_ID}-cam-audio-1700000001000",
|
||||
f"1700000000000-{PARTICIPANT_B_ID}-cam-audio-1700000001000",
|
||||
]
|
||||
|
||||
|
||||
@app.get("/v1/recordings/{recording_id}")
|
||||
async def get_recording(recording_id: str):
|
||||
return {
|
||||
"id": recording_id,
|
||||
"room_name": "integration-test-room",
|
||||
"start_ts": 1700000000,
|
||||
"type": "raw-tracks",
|
||||
"status": "finished",
|
||||
"max_participants": 2,
|
||||
"duration": 5,
|
||||
"share_token": None,
|
||||
"s3": {
|
||||
"bucket_name": "reflector-media",
|
||||
"bucket_region": "garage",
|
||||
"key": None,
|
||||
"endpoint": None,
|
||||
},
|
||||
"s3key": None,
|
||||
"tracks": [
|
||||
{"type": "audio", "s3Key": key, "size": 100000} for key in TRACK_KEYS
|
||||
],
|
||||
"mtgSessionId": "mock-mtg-session-id",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/v1/meetings/{meeting_id}/participants")
|
||||
async def get_meeting_participants(meeting_id: str):
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"user_id": "user-a",
|
||||
"participant_id": PARTICIPANT_A_ID,
|
||||
"user_name": "Speaker A",
|
||||
"join_time": 1700000000,
|
||||
"duration": 300,
|
||||
},
|
||||
{
|
||||
"user_id": "user-b",
|
||||
"participant_id": PARTICIPANT_B_ID,
|
||||
"user_name": "Speaker B",
|
||||
"join_time": 1700000010,
|
||||
"duration": 290,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8080)
|
||||
61
server/tests/integration/test_file_pipeline.py
Normal file
61
server/tests/integration/test_file_pipeline.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Integration test: File upload → FilePipeline → full processing.
|
||||
|
||||
Exercises: upload endpoint → Hatchet FilePipeline → whisper transcription →
|
||||
pyannote diarization → LLM summarization/topics → status "ended".
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_pipeline_end_to_end(
|
||||
api_client, test_records_dir, poll_transcript_status
|
||||
):
|
||||
"""Upload a WAV file and verify the full pipeline completes."""
|
||||
# 1. Create transcript
|
||||
resp = await api_client.post(
|
||||
"/transcripts",
|
||||
json={"name": "integration-file-test", "source_kind": "file"},
|
||||
)
|
||||
assert resp.status_code == 200, f"Failed to create transcript: {resp.text}"
|
||||
transcript = resp.json()
|
||||
transcript_id = transcript["id"]
|
||||
|
||||
# 2. Upload audio file (single chunk)
|
||||
audio_path = test_records_dir / "test_short.wav"
|
||||
assert audio_path.exists(), f"Test audio file not found: {audio_path}"
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
resp = await api_client.post(
|
||||
f"/transcripts/{transcript_id}/record/upload",
|
||||
params={"chunk_number": 0, "total_chunks": 1},
|
||||
files={"chunk": ("test_short.wav", f, "audio/wav")},
|
||||
)
|
||||
assert resp.status_code == 200, f"Upload failed: {resp.text}"
|
||||
|
||||
# 3. Poll until pipeline completes
|
||||
data = await poll_transcript_status(
|
||||
api_client, transcript_id, target="ended", max_wait=300
|
||||
)
|
||||
|
||||
# 4. Assertions
|
||||
assert data["status"] == "ended"
|
||||
assert data.get("title") and len(data["title"]) > 0, "Title should be non-empty"
|
||||
assert (
|
||||
data.get("long_summary") and len(data["long_summary"]) > 0
|
||||
), "Long summary should be non-empty"
|
||||
assert (
|
||||
data.get("short_summary") and len(data["short_summary"]) > 0
|
||||
), "Short summary should be non-empty"
|
||||
|
||||
# Topics are served from a separate endpoint
|
||||
topics_resp = await api_client.get(f"/transcripts/{transcript_id}/topics")
|
||||
assert topics_resp.status_code == 200, f"Failed to get topics: {topics_resp.text}"
|
||||
topics = topics_resp.json()
|
||||
assert len(topics) >= 1, "Should have at least 1 topic"
|
||||
for topic in topics:
|
||||
assert topic.get("title"), "Each topic should have a title"
|
||||
assert topic.get("summary"), "Each topic should have a summary"
|
||||
|
||||
assert data.get("duration", 0) > 0, "Duration should be positive"
|
||||
109
server/tests/integration/test_live_pipeline.py
Normal file
109
server/tests/integration/test_live_pipeline.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Integration test: WebRTC stream → LivePostProcessingPipeline → full processing.
|
||||
|
||||
Exercises: WebRTC SDP exchange → live audio streaming → connection close →
|
||||
Hatchet LivePostPipeline → whisper transcription → LLM summarization/topics → status "ended".
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
from aiortc.contrib.media import MediaPlayer
|
||||
|
||||
SERVER_URL = os.environ.get("SERVER_URL", "http://server:1250")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_pipeline_end_to_end(
|
||||
api_client, test_records_dir, poll_transcript_status
|
||||
):
|
||||
"""Stream audio via WebRTC and verify the full post-processing pipeline completes."""
|
||||
# 1. Create transcript
|
||||
resp = await api_client.post(
|
||||
"/transcripts",
|
||||
json={"name": "integration-live-test"},
|
||||
)
|
||||
assert resp.status_code == 200, f"Failed to create transcript: {resp.text}"
|
||||
transcript = resp.json()
|
||||
transcript_id = transcript["id"]
|
||||
|
||||
# 2. Set up WebRTC peer connection with audio from test file
|
||||
audio_path = test_records_dir / "test_short.wav"
|
||||
assert audio_path.exists(), f"Test audio file not found: {audio_path}"
|
||||
|
||||
pc = RTCPeerConnection()
|
||||
player = MediaPlayer(audio_path.as_posix())
|
||||
|
||||
# Add audio track
|
||||
audio_track = player.audio
|
||||
pc.addTrack(audio_track)
|
||||
|
||||
# Create data channel (server expects this for STOP command)
|
||||
channel = pc.createDataChannel("data-channel")
|
||||
|
||||
# 3. Generate SDP offer
|
||||
offer = await pc.createOffer()
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
sdp_payload = {
|
||||
"sdp": pc.localDescription.sdp,
|
||||
"type": pc.localDescription.type,
|
||||
}
|
||||
|
||||
# 4. Send offer to server and get answer
|
||||
webrtc_url = f"{SERVER_URL}/v1/transcripts/{transcript_id}/record/webrtc"
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
|
||||
resp = await client.post(webrtc_url, json=sdp_payload)
|
||||
assert resp.status_code == 200, f"WebRTC offer failed: {resp.text}"
|
||||
|
||||
answer_data = resp.json()
|
||||
answer = RTCSessionDescription(sdp=answer_data["sdp"], type=answer_data["type"])
|
||||
await pc.setRemoteDescription(answer)
|
||||
|
||||
# 5. Wait for audio playback to finish
|
||||
max_stream_wait = 60
|
||||
elapsed = 0
|
||||
while elapsed < max_stream_wait:
|
||||
if audio_track.readyState == "ended":
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
elapsed += 0.5
|
||||
|
||||
# 6. Send STOP command and close connection
|
||||
try:
|
||||
channel.send(json.dumps({"cmd": "STOP"}))
|
||||
await asyncio.sleep(1)
|
||||
except Exception:
|
||||
pass # Channel may not be open if track ended quickly
|
||||
|
||||
await pc.close()
|
||||
|
||||
# 7. Poll until post-processing pipeline completes
|
||||
data = await poll_transcript_status(
|
||||
api_client, transcript_id, target="ended", max_wait=300
|
||||
)
|
||||
|
||||
# 8. Assertions
|
||||
assert data["status"] == "ended"
|
||||
assert data.get("title") and len(data["title"]) > 0, "Title should be non-empty"
|
||||
assert (
|
||||
data.get("long_summary") and len(data["long_summary"]) > 0
|
||||
), "Long summary should be non-empty"
|
||||
assert (
|
||||
data.get("short_summary") and len(data["short_summary"]) > 0
|
||||
), "Short summary should be non-empty"
|
||||
|
||||
# Topics are served from a separate endpoint
|
||||
topics_resp = await api_client.get(f"/transcripts/{transcript_id}/topics")
|
||||
assert topics_resp.status_code == 200, f"Failed to get topics: {topics_resp.text}"
|
||||
topics = topics_resp.json()
|
||||
assert len(topics) >= 1, "Should have at least 1 topic"
|
||||
for topic in topics:
|
||||
assert topic.get("title"), "Each topic should have a title"
|
||||
assert topic.get("summary"), "Each topic should have a summary"
|
||||
|
||||
assert data.get("duration", 0) > 0, "Duration should be positive"
|
||||
129
server/tests/integration/test_multitrack_pipeline.py
Normal file
129
server/tests/integration/test_multitrack_pipeline.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Integration test: Multitrack → DailyMultitrackPipeline → full processing.
|
||||
|
||||
Exercises: S3 upload → DB recording setup → process endpoint →
|
||||
Hatchet DiarizationPipeline → mock Daily API → whisper per-track transcription →
|
||||
diarization → mixdown → LLM summarization/topics → status "ended".
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
# Must match Daily's filename format: {recording_start_ts}-{participant_uuid}-cam-audio-{track_start_ts}
|
||||
# These UUIDs must match mock_daily_server.py participant IDs
|
||||
PARTICIPANT_A_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
PARTICIPANT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
TRACK_KEYS = [
|
||||
f"1700000000000-{PARTICIPANT_A_ID}-cam-audio-1700000001000",
|
||||
f"1700000000000-{PARTICIPANT_B_ID}-cam-audio-1700000001000",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multitrack_pipeline_end_to_end(
|
||||
api_client,
|
||||
s3_client,
|
||||
db_engine,
|
||||
test_records_dir,
|
||||
bucket_name,
|
||||
poll_transcript_status,
|
||||
):
|
||||
"""Set up multitrack recording in S3/DB and verify the full pipeline completes."""
|
||||
# 1. Upload test audio as two separate tracks to Garage S3
|
||||
audio_path = test_records_dir / "test_short.wav"
|
||||
assert audio_path.exists(), f"Test audio file not found: {audio_path}"
|
||||
|
||||
for track_key in TRACK_KEYS:
|
||||
s3_client.upload_file(
|
||||
str(audio_path),
|
||||
bucket_name,
|
||||
track_key,
|
||||
)
|
||||
|
||||
# 2. Create transcript via API
|
||||
resp = await api_client.post(
|
||||
"/transcripts",
|
||||
json={"name": "integration-multitrack-test"},
|
||||
)
|
||||
assert resp.status_code == 200, f"Failed to create transcript: {resp.text}"
|
||||
transcript = resp.json()
|
||||
transcript_id = transcript["id"]
|
||||
|
||||
# 3. Insert Recording row and link to transcript via direct DB access
|
||||
recording_id = f"rec-integration-{transcript_id[:8]}"
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
# Insert recording with track_keys
|
||||
await conn.execute(
|
||||
text("""
|
||||
INSERT INTO recording (id, bucket_name, object_key, recorded_at, status, track_keys)
|
||||
VALUES (:id, :bucket_name, :object_key, :recorded_at, :status, CAST(:track_keys AS json))
|
||||
"""),
|
||||
{
|
||||
"id": recording_id,
|
||||
"bucket_name": bucket_name,
|
||||
"object_key": TRACK_KEYS[0],
|
||||
"recorded_at": now,
|
||||
"status": "completed",
|
||||
"track_keys": json.dumps(TRACK_KEYS),
|
||||
},
|
||||
)
|
||||
|
||||
# Link recording to transcript and set status to uploaded
|
||||
await conn.execute(
|
||||
text("""
|
||||
UPDATE transcript
|
||||
SET recording_id = :recording_id, status = 'uploaded'
|
||||
WHERE id = :transcript_id
|
||||
"""),
|
||||
{
|
||||
"recording_id": recording_id,
|
||||
"transcript_id": transcript_id,
|
||||
},
|
||||
)
|
||||
|
||||
# 4. Trigger processing via process endpoint
|
||||
resp = await api_client.post(f"/transcripts/{transcript_id}/process")
|
||||
assert resp.status_code == 200, f"Process trigger failed: {resp.text}"
|
||||
|
||||
# 5. Poll until pipeline completes
|
||||
# The pipeline will call mock-daily for get_recording and get_participants
|
||||
# Accept "error" too — non-critical steps like action_items may fail due to
|
||||
# LLM parsing flakiness while core results (transcript, summaries) still exist.
|
||||
data = await poll_transcript_status(
|
||||
api_client, transcript_id, target=("ended", "error"), max_wait=300
|
||||
)
|
||||
|
||||
# 6. Assertions — verify core pipeline results regardless of final status
|
||||
assert data.get("title") and len(data["title"]) > 0, "Title should be non-empty"
|
||||
assert (
|
||||
data.get("long_summary") and len(data["long_summary"]) > 0
|
||||
), "Long summary should be non-empty"
|
||||
assert (
|
||||
data.get("short_summary") and len(data["short_summary"]) > 0
|
||||
), "Short summary should be non-empty"
|
||||
|
||||
# Topics are served from a separate endpoint
|
||||
topics_resp = await api_client.get(f"/transcripts/{transcript_id}/topics")
|
||||
assert topics_resp.status_code == 200, f"Failed to get topics: {topics_resp.text}"
|
||||
topics = topics_resp.json()
|
||||
assert len(topics) >= 1, "Should have at least 1 topic"
|
||||
for topic in topics:
|
||||
assert topic.get("title"), "Each topic should have a title"
|
||||
assert topic.get("summary"), "Each topic should have a summary"
|
||||
|
||||
# Participants are served from a separate endpoint
|
||||
participants_resp = await api_client.get(
|
||||
f"/transcripts/{transcript_id}/participants"
|
||||
)
|
||||
assert (
|
||||
participants_resp.status_code == 200
|
||||
), f"Failed to get participants: {participants_resp.text}"
|
||||
participants = participants_resp.json()
|
||||
assert (
|
||||
len(participants) >= 2
|
||||
), f"Expected at least 2 speakers for multitrack, got {len(participants)}"
|
||||
17
server/tests/test_app.py
Normal file
17
server/tests/test_app.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Tests for app-level endpoints (root, not under /v1)."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint_returns_healthy():
|
||||
"""GET /health returns 200 and {"status": "healthy"} for probes and CI."""
|
||||
from httpx import AsyncClient
|
||||
|
||||
from reflector.app import app
|
||||
|
||||
# Health is at app root, not under /v1
|
||||
async with AsyncClient(app=app, base_url="http://test") as root_client:
|
||||
response = await root_client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "healthy"}
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Tests for the password auth backend."""
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from jose import jwt
|
||||
|
||||
from reflector.auth.password_utils import hash_password
|
||||
from reflector.settings import settings
|
||||
|
||||
247
server/tests/test_beat_schedule.py
Normal file
247
server/tests/test_beat_schedule.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Tests for conditional Celery beat schedule registration.
|
||||
|
||||
Verifies that beat tasks are only registered when their corresponding
|
||||
services are configured (WHEREBY_API_KEY, DAILY_API_KEY, etc.).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.worker.app import build_beat_schedule
|
||||
|
||||
|
||||
# Override autouse fixtures from conftest — these tests don't need database or websockets
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_database():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def ws_manager_in_memory():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_hatchet_client():
|
||||
yield
|
||||
|
||||
|
||||
# Task name sets for each group
|
||||
WHEREBY_TASKS = {"process_messages", "reprocess_failed_recordings"}
|
||||
DAILY_TASKS = {
|
||||
"poll_daily_recordings",
|
||||
"trigger_daily_reconciliation",
|
||||
"reprocess_failed_daily_recordings",
|
||||
}
|
||||
PLATFORM_TASKS = {
|
||||
"process_meetings",
|
||||
"sync_all_ics_calendars",
|
||||
"create_upcoming_meetings",
|
||||
}
|
||||
|
||||
|
||||
class TestNoPlatformConfigured:
|
||||
"""When no video platform is configured, no platform tasks should be registered."""
|
||||
|
||||
def test_no_platform_tasks(self):
|
||||
schedule = build_beat_schedule()
|
||||
task_names = set(schedule.keys())
|
||||
assert not task_names & WHEREBY_TASKS
|
||||
assert not task_names & DAILY_TASKS
|
||||
assert not task_names & PLATFORM_TASKS
|
||||
|
||||
def test_only_healthcheck_disabled_warning(self):
|
||||
"""With no config at all, schedule should be empty (healthcheck needs URL)."""
|
||||
schedule = build_beat_schedule()
|
||||
assert len(schedule) == 0
|
||||
|
||||
def test_healthcheck_only(self):
|
||||
schedule = build_beat_schedule(healthcheck_url="https://hc.example.com/ping")
|
||||
assert set(schedule.keys()) == {"healthcheck_ping"}
|
||||
|
||||
def test_public_mode_only(self):
|
||||
schedule = build_beat_schedule(public_mode=True)
|
||||
assert set(schedule.keys()) == {"cleanup_old_public_data"}
|
||||
|
||||
|
||||
class TestWherebyOnly:
|
||||
"""When only Whereby is configured."""
|
||||
|
||||
def test_whereby_api_key(self):
|
||||
schedule = build_beat_schedule(whereby_api_key="test-key")
|
||||
task_names = set(schedule.keys())
|
||||
assert WHEREBY_TASKS <= task_names
|
||||
assert PLATFORM_TASKS <= task_names
|
||||
assert not task_names & DAILY_TASKS
|
||||
|
||||
def test_whereby_sqs_url(self):
|
||||
schedule = build_beat_schedule(
|
||||
aws_process_recording_queue_url="https://sqs.us-east-1.amazonaws.com/123/queue"
|
||||
)
|
||||
task_names = set(schedule.keys())
|
||||
assert WHEREBY_TASKS <= task_names
|
||||
assert PLATFORM_TASKS <= task_names
|
||||
assert not task_names & DAILY_TASKS
|
||||
|
||||
def test_whereby_task_count(self):
|
||||
schedule = build_beat_schedule(whereby_api_key="test-key")
|
||||
# Whereby (2) + Platform (3) = 5
|
||||
assert len(schedule) == 5
|
||||
|
||||
|
||||
class TestDailyOnly:
|
||||
"""When only Daily.co is configured."""
|
||||
|
||||
def test_daily_api_key(self):
|
||||
schedule = build_beat_schedule(daily_api_key="test-daily-key")
|
||||
task_names = set(schedule.keys())
|
||||
assert DAILY_TASKS <= task_names
|
||||
assert PLATFORM_TASKS <= task_names
|
||||
assert not task_names & WHEREBY_TASKS
|
||||
|
||||
def test_daily_task_count(self):
|
||||
schedule = build_beat_schedule(daily_api_key="test-daily-key")
|
||||
# Daily (3) + Platform (3) = 6
|
||||
assert len(schedule) == 6
|
||||
|
||||
|
||||
class TestBothPlatforms:
|
||||
"""When both Whereby and Daily.co are configured."""
|
||||
|
||||
def test_all_tasks_registered(self):
|
||||
schedule = build_beat_schedule(
|
||||
whereby_api_key="test-key",
|
||||
daily_api_key="test-daily-key",
|
||||
)
|
||||
task_names = set(schedule.keys())
|
||||
assert WHEREBY_TASKS <= task_names
|
||||
assert DAILY_TASKS <= task_names
|
||||
assert PLATFORM_TASKS <= task_names
|
||||
|
||||
def test_combined_task_count(self):
|
||||
schedule = build_beat_schedule(
|
||||
whereby_api_key="test-key",
|
||||
daily_api_key="test-daily-key",
|
||||
)
|
||||
# Whereby (2) + Daily (3) + Platform (3) = 8
|
||||
assert len(schedule) == 8
|
||||
|
||||
|
||||
class TestConditionalFlags:
|
||||
"""Test PUBLIC_MODE and HEALTHCHECK_URL interact correctly with platform tasks."""
|
||||
|
||||
def test_all_flags_enabled(self):
|
||||
schedule = build_beat_schedule(
|
||||
whereby_api_key="test-key",
|
||||
daily_api_key="test-daily-key",
|
||||
public_mode=True,
|
||||
healthcheck_url="https://hc.example.com/ping",
|
||||
)
|
||||
task_names = set(schedule.keys())
|
||||
assert "cleanup_old_public_data" in task_names
|
||||
assert "healthcheck_ping" in task_names
|
||||
assert WHEREBY_TASKS <= task_names
|
||||
assert DAILY_TASKS <= task_names
|
||||
assert PLATFORM_TASKS <= task_names
|
||||
# Whereby (2) + Daily (3) + Platform (3) + cleanup (1) + healthcheck (1) = 10
|
||||
assert len(schedule) == 10
|
||||
|
||||
def test_public_mode_with_whereby(self):
|
||||
schedule = build_beat_schedule(
|
||||
whereby_api_key="test-key",
|
||||
public_mode=True,
|
||||
)
|
||||
task_names = set(schedule.keys())
|
||||
assert "cleanup_old_public_data" in task_names
|
||||
assert WHEREBY_TASKS <= task_names
|
||||
assert PLATFORM_TASKS <= task_names
|
||||
|
||||
def test_healthcheck_with_daily(self):
|
||||
schedule = build_beat_schedule(
|
||||
daily_api_key="test-daily-key",
|
||||
healthcheck_url="https://hc.example.com/ping",
|
||||
)
|
||||
task_names = set(schedule.keys())
|
||||
assert "healthcheck_ping" in task_names
|
||||
assert DAILY_TASKS <= task_names
|
||||
assert PLATFORM_TASKS <= task_names
|
||||
|
||||
|
||||
class TestTaskDefinitions:
|
||||
"""Verify task definitions have correct structure."""
|
||||
|
||||
def test_whereby_task_paths(self):
|
||||
schedule = build_beat_schedule(whereby_api_key="test-key")
|
||||
assert (
|
||||
schedule["process_messages"]["task"]
|
||||
== "reflector.worker.process.process_messages"
|
||||
)
|
||||
assert (
|
||||
schedule["reprocess_failed_recordings"]["task"]
|
||||
== "reflector.worker.process.reprocess_failed_recordings"
|
||||
)
|
||||
|
||||
def test_daily_task_paths(self):
|
||||
schedule = build_beat_schedule(daily_api_key="test-daily-key")
|
||||
assert (
|
||||
schedule["poll_daily_recordings"]["task"]
|
||||
== "reflector.worker.process.poll_daily_recordings"
|
||||
)
|
||||
assert (
|
||||
schedule["trigger_daily_reconciliation"]["task"]
|
||||
== "reflector.worker.process.trigger_daily_reconciliation"
|
||||
)
|
||||
assert (
|
||||
schedule["reprocess_failed_daily_recordings"]["task"]
|
||||
== "reflector.worker.process.reprocess_failed_daily_recordings"
|
||||
)
|
||||
|
||||
def test_platform_task_paths(self):
|
||||
schedule = build_beat_schedule(daily_api_key="test-daily-key")
|
||||
assert (
|
||||
schedule["process_meetings"]["task"]
|
||||
== "reflector.worker.process.process_meetings"
|
||||
)
|
||||
assert (
|
||||
schedule["sync_all_ics_calendars"]["task"]
|
||||
== "reflector.worker.ics_sync.sync_all_ics_calendars"
|
||||
)
|
||||
assert (
|
||||
schedule["create_upcoming_meetings"]["task"]
|
||||
== "reflector.worker.ics_sync.create_upcoming_meetings"
|
||||
)
|
||||
|
||||
def test_all_tasks_have_schedule(self):
|
||||
"""Every registered task must have a 'schedule' key."""
|
||||
schedule = build_beat_schedule(
|
||||
whereby_api_key="test-key",
|
||||
daily_api_key="test-daily-key",
|
||||
public_mode=True,
|
||||
healthcheck_url="https://hc.example.com/ping",
|
||||
)
|
||||
for name, config in schedule.items():
|
||||
assert "schedule" in config, f"Task '{name}' missing 'schedule' key"
|
||||
assert "task" in config, f"Task '{name}' missing 'task' key"
|
||||
|
||||
|
||||
class TestEmptyStringValues:
|
||||
"""Empty strings should be treated as not configured (falsy)."""
|
||||
|
||||
def test_empty_whereby_key(self):
|
||||
schedule = build_beat_schedule(whereby_api_key="")
|
||||
assert not set(schedule.keys()) & WHEREBY_TASKS
|
||||
|
||||
def test_empty_daily_key(self):
|
||||
schedule = build_beat_schedule(daily_api_key="")
|
||||
assert not set(schedule.keys()) & DAILY_TASKS
|
||||
|
||||
def test_empty_sqs_url(self):
|
||||
schedule = build_beat_schedule(aws_process_recording_queue_url="")
|
||||
assert not set(schedule.keys()) & WHEREBY_TASKS
|
||||
|
||||
def test_none_values(self):
|
||||
schedule = build_beat_schedule(
|
||||
whereby_api_key=None,
|
||||
daily_api_key=None,
|
||||
aws_process_recording_queue_url=None,
|
||||
)
|
||||
assert len(schedule) == 0
|
||||
@@ -37,18 +37,3 @@ async def test_hatchet_client_can_replay_handles_exception():
|
||||
|
||||
# Should return False on error (workflow might be gone)
|
||||
assert can_replay is False
|
||||
|
||||
|
||||
def test_hatchet_client_raises_without_token():
|
||||
"""Test that get_client raises ValueError without token.
|
||||
|
||||
Useful: Catches if someone removes the token validation,
|
||||
which would cause cryptic errors later.
|
||||
"""
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
|
||||
with patch("reflector.hatchet.client.settings") as mock_settings:
|
||||
mock_settings.HATCHET_CLIENT_TOKEN = None
|
||||
|
||||
with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"):
|
||||
HatchetClientManager.get_client()
|
||||
|
||||
303
server/tests/test_hatchet_error_handling.py
Normal file
303
server/tests/test_hatchet_error_handling.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Tests for Hatchet error handling: NonRetryable classification and error status.
|
||||
|
||||
These tests encode the desired behavior from the Hatchet Workflow Analysis doc:
|
||||
- Transient exceptions: do NOT set error status (let Hatchet retry; user stays on "processing").
|
||||
- Hard-fail exceptions: set error status and re-raise as NonRetryableException (stop retries).
|
||||
- on_failure_task: sets error status when workflow is truly dead.
|
||||
|
||||
Run before the fix: some tests fail (reproducing the issues).
|
||||
Run after the fix: all tests pass.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from hatchet_sdk import NonRetryableException
|
||||
|
||||
from reflector.hatchet.error_classification import is_non_retryable
|
||||
from reflector.llm import LLMParseError
|
||||
|
||||
# --- Tests for is_non_retryable() (pass once error_classification exists) ---
|
||||
|
||||
|
||||
def test_is_non_retryable_returns_true_for_value_error():
|
||||
"""ValueError (e.g. missing config) should stop retries."""
|
||||
assert is_non_retryable(ValueError("DAILY_API_KEY must be set")) is True
|
||||
|
||||
|
||||
def test_is_non_retryable_returns_true_for_type_error():
|
||||
"""TypeError (bad input) should stop retries."""
|
||||
assert is_non_retryable(TypeError("expected str")) is True
|
||||
|
||||
|
||||
def test_is_non_retryable_returns_true_for_http_401():
|
||||
"""HTTP 401 auth error should stop retries."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = 401
|
||||
err = httpx.HTTPStatusError("Unauthorized", request=MagicMock(), response=resp)
|
||||
assert is_non_retryable(err) is True
|
||||
|
||||
|
||||
def test_is_non_retryable_returns_true_for_http_402():
|
||||
"""HTTP 402 (no credits) should stop retries."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = 402
|
||||
err = httpx.HTTPStatusError("Payment Required", request=MagicMock(), response=resp)
|
||||
assert is_non_retryable(err) is True
|
||||
|
||||
|
||||
def test_is_non_retryable_returns_true_for_http_404():
|
||||
"""HTTP 404 should stop retries."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = 404
|
||||
err = httpx.HTTPStatusError("Not Found", request=MagicMock(), response=resp)
|
||||
assert is_non_retryable(err) is True
|
||||
|
||||
|
||||
def test_is_non_retryable_returns_false_for_http_503():
|
||||
"""HTTP 503 is transient; retries are useful."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = 503
|
||||
err = httpx.HTTPStatusError(
|
||||
"Service Unavailable", request=MagicMock(), response=resp
|
||||
)
|
||||
assert is_non_retryable(err) is False
|
||||
|
||||
|
||||
def test_is_non_retryable_returns_false_for_timeout():
|
||||
"""Timeout is transient."""
|
||||
assert is_non_retryable(httpx.TimeoutException("timed out")) is False
|
||||
|
||||
|
||||
def test_is_non_retryable_returns_true_for_llm_parse_error():
|
||||
"""LLMParseError after internal retries should stop."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class _Dummy(BaseModel):
|
||||
pass
|
||||
|
||||
assert is_non_retryable(LLMParseError(_Dummy, "Failed to parse", 3)) is True
|
||||
|
||||
|
||||
def test_is_non_retryable_returns_true_for_non_retryable_exception():
|
||||
"""Already-wrapped NonRetryableException should stay non-retryable."""
|
||||
assert is_non_retryable(NonRetryableException("custom")) is True
|
||||
|
||||
|
||||
# --- Tests for with_error_handling (need pipeline module with patch) ---
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_module():
|
||||
"""Import daily_multitrack_pipeline with Hatchet client mocked."""
|
||||
with patch("reflector.hatchet.client.settings") as s:
|
||||
s.HATCHET_CLIENT_TOKEN = "test-token"
|
||||
s.HATCHET_DEBUG = False
|
||||
mock_client = MagicMock()
|
||||
mock_client.workflow.return_value = MagicMock()
|
||||
with patch(
|
||||
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
from reflector.hatchet.workflows import daily_multitrack_pipeline
|
||||
|
||||
return daily_multitrack_pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_input():
|
||||
"""Minimal PipelineInput for decorator tests."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import PipelineInput
|
||||
|
||||
return PipelineInput(
|
||||
recording_id="rec-1",
|
||||
tracks=[],
|
||||
bucket_name="bucket",
|
||||
transcript_id="ts-123",
|
||||
room_id=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ctx():
|
||||
"""Minimal Context-like object."""
|
||||
ctx = MagicMock()
|
||||
ctx.log = MagicMock()
|
||||
return ctx
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_error_handling_transient_does_not_set_error_status(
|
||||
pipeline_module, mock_input, mock_ctx
|
||||
):
|
||||
"""Transient exception must NOT set error status (so user stays on 'processing' during retries).
|
||||
|
||||
Before fix: set_workflow_error_status is called on every exception → FAIL.
|
||||
After fix: not called for transient → PASS.
|
||||
"""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
TaskName,
|
||||
with_error_handling,
|
||||
)
|
||||
|
||||
async def failing_task(input, ctx):
|
||||
raise httpx.TimeoutException("timed out")
|
||||
|
||||
wrapped = with_error_handling(TaskName.GET_RECORDING)(failing_task)
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
await wrapped(mock_input, mock_ctx)
|
||||
|
||||
# Desired: do NOT set error status for transient (Hatchet will retry)
|
||||
mock_set_error.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_error_handling_hard_fail_raises_non_retryable_and_sets_status(
|
||||
pipeline_module, mock_input, mock_ctx
|
||||
):
|
||||
"""Hard-fail (e.g. ValueError) must set error status and re-raise NonRetryableException.
|
||||
|
||||
Before fix: raises ValueError, set_workflow_error_status called → test would need to expect ValueError.
|
||||
After fix: raises NonRetryableException, set_workflow_error_status called → PASS.
|
||||
"""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
TaskName,
|
||||
with_error_handling,
|
||||
)
|
||||
|
||||
async def failing_task(input, ctx):
|
||||
raise ValueError("PADDING_URL must be set")
|
||||
|
||||
wrapped = with_error_handling(TaskName.GET_RECORDING)(failing_task)
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
with pytest.raises(NonRetryableException) as exc_info:
|
||||
await wrapped(mock_input, mock_ctx)
|
||||
|
||||
assert "PADDING_URL" in str(exc_info.value)
|
||||
mock_set_error.assert_called_once_with("ts-123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_error_handling_set_error_status_false_never_sets_status(
|
||||
pipeline_module, mock_input, mock_ctx
|
||||
):
|
||||
"""When set_error_status=False, we must never set error status (e.g. cleanup_consent)."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
TaskName,
|
||||
with_error_handling,
|
||||
)
|
||||
|
||||
async def failing_task(input, ctx):
|
||||
raise ValueError("something went wrong")
|
||||
|
||||
wrapped = with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)(
|
||||
failing_task
|
||||
)
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
with pytest.raises((ValueError, NonRetryableException)):
|
||||
await wrapped(mock_input, mock_ctx)
|
||||
|
||||
mock_set_error.assert_not_called()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_db_context():
|
||||
"""Async context manager that yields without touching the DB (for unit tests)."""
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_failure_task_sets_error_status(pipeline_module, mock_input, mock_ctx):
|
||||
"""When workflow fails and transcript is not yet 'ended', on_failure sets status to 'error'."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
on_workflow_failure,
|
||||
)
|
||||
|
||||
transcript_processing = MagicMock()
|
||||
transcript_processing.status = "processing"
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||
_noop_db_context,
|
||||
):
|
||||
with patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript_processing,
|
||||
):
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
await on_workflow_failure(mock_input, mock_ctx)
|
||||
mock_set_error.assert_called_once_with(mock_input.transcript_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_failure_task_does_not_overwrite_ended(
|
||||
pipeline_module, mock_input, mock_ctx
|
||||
):
|
||||
"""When workflow fails after finalize (e.g. post_zulip), do not overwrite 'ended' with 'error'.
|
||||
|
||||
cleanup_consent, post_zulip, send_webhook use set_error_status=False; if one fails,
|
||||
on_workflow_failure must not set status to 'error' when transcript is already 'ended'.
|
||||
"""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
on_workflow_failure,
|
||||
)
|
||||
|
||||
transcript_ended = MagicMock()
|
||||
transcript_ended.status = "ended"
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||
_noop_db_context,
|
||||
):
|
||||
with patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript_ended,
|
||||
):
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
await on_workflow_failure(mock_input, mock_ctx)
|
||||
mock_set_error.assert_not_called()
|
||||
|
||||
|
||||
# --- Tests for fan-out helper (_successful_run_results) ---
|
||||
|
||||
|
||||
def test_successful_run_results_filters_exceptions():
|
||||
"""_successful_run_results returns only non-exception items from aio_run_many(return_exceptions=True)."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
_successful_run_results,
|
||||
)
|
||||
|
||||
results = [
|
||||
{"key": "ok1"},
|
||||
ValueError("child failed"),
|
||||
{"key": "ok2"},
|
||||
RuntimeError("another"),
|
||||
]
|
||||
successful = _successful_run_results(results)
|
||||
assert len(successful) == 2
|
||||
assert successful[0] == {"key": "ok1"}
|
||||
assert successful[1] == {"key": "ok2"}
|
||||
233
server/tests/test_hatchet_file_pipeline.py
Normal file
233
server/tests/test_hatchet_file_pipeline.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Tests for the FilePipeline Hatchet workflow.
|
||||
|
||||
Tests verify:
|
||||
1. with_error_handling behavior for file pipeline input model
|
||||
2. on_workflow_failure logic (don't overwrite 'ended' status)
|
||||
3. Input model validation
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from hatchet_sdk import NonRetryableException
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_db_context():
|
||||
"""Async context manager that yields without touching the DB."""
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def file_pipeline_module():
|
||||
"""Import file_pipeline with Hatchet client mocked."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.workflow.return_value = MagicMock()
|
||||
with patch(
|
||||
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
from reflector.hatchet.workflows import file_pipeline
|
||||
|
||||
return file_pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_input():
|
||||
"""Minimal FilePipelineInput for tests."""
|
||||
from reflector.hatchet.workflows.file_pipeline import FilePipelineInput
|
||||
|
||||
return FilePipelineInput(
|
||||
transcript_id="ts-file-123",
|
||||
room_id="room-456",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ctx():
|
||||
"""Minimal Context-like object."""
|
||||
ctx = MagicMock()
|
||||
ctx.log = MagicMock()
|
||||
return ctx
|
||||
|
||||
|
||||
def test_file_pipeline_input_model():
|
||||
"""Test FilePipelineInput validation."""
|
||||
from reflector.hatchet.workflows.file_pipeline import FilePipelineInput
|
||||
|
||||
# Valid input with room_id
|
||||
input_with_room = FilePipelineInput(transcript_id="ts-123", room_id="room-456")
|
||||
assert input_with_room.transcript_id == "ts-123"
|
||||
assert input_with_room.room_id == "room-456"
|
||||
|
||||
# Valid input without room_id
|
||||
input_no_room = FilePipelineInput(transcript_id="ts-123")
|
||||
assert input_no_room.room_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_pipeline_error_handling_transient(
|
||||
file_pipeline_module, mock_file_input, mock_ctx
|
||||
):
|
||||
"""Transient exception must NOT set error status."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
TaskName,
|
||||
with_error_handling,
|
||||
)
|
||||
|
||||
async def failing_task(input, ctx):
|
||||
raise httpx.TimeoutException("timed out")
|
||||
|
||||
wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
await wrapped(mock_file_input, mock_ctx)
|
||||
|
||||
mock_set_error.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_pipeline_error_handling_hard_fail(
|
||||
file_pipeline_module, mock_file_input, mock_ctx
|
||||
):
|
||||
"""Hard-fail (ValueError) must set error status and raise NonRetryableException."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
TaskName,
|
||||
with_error_handling,
|
||||
)
|
||||
|
||||
async def failing_task(input, ctx):
|
||||
raise ValueError("No audio file found")
|
||||
|
||||
wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
with pytest.raises(NonRetryableException) as exc_info:
|
||||
await wrapped(mock_file_input, mock_ctx)
|
||||
|
||||
assert "No audio file found" in str(exc_info.value)
|
||||
mock_set_error.assert_called_once_with("ts-file-123")
|
||||
|
||||
|
||||
def test_diarize_result_uses_plain_dicts():
|
||||
"""DiarizationSegment is a TypedDict (plain dict), not a Pydantic model.
|
||||
|
||||
The diarize task must serialize segments as plain dicts (not call .model_dump()),
|
||||
and assemble_transcript must be able to reconstruct them with DiarizationSegment(**s).
|
||||
This was a real bug: 'dict' object has no attribute 'model_dump'.
|
||||
"""
|
||||
from reflector.hatchet.workflows.file_pipeline import DiarizeResult
|
||||
from reflector.processors.types import DiarizationSegment
|
||||
|
||||
# DiarizationSegment is a TypedDict — instances are plain dicts
|
||||
segments = [
|
||||
DiarizationSegment(start=0.0, end=1.5, speaker=0),
|
||||
DiarizationSegment(start=1.5, end=3.0, speaker=1),
|
||||
]
|
||||
assert isinstance(segments[0], dict), "DiarizationSegment should be a plain dict"
|
||||
|
||||
# DiarizeResult should accept list[dict] directly (no model_dump needed)
|
||||
result = DiarizeResult(diarization=segments)
|
||||
assert result.diarization is not None
|
||||
assert len(result.diarization) == 2
|
||||
|
||||
# Consumer (assemble_transcript) reconstructs via DiarizationSegment(**s)
|
||||
reconstructed = [DiarizationSegment(**s) for s in result.diarization]
|
||||
assert reconstructed[0]["start"] == 0.0
|
||||
assert reconstructed[0]["speaker"] == 0
|
||||
assert reconstructed[1]["end"] == 3.0
|
||||
assert reconstructed[1]["speaker"] == 1
|
||||
|
||||
|
||||
def test_diarize_result_handles_none():
|
||||
"""DiarizeResult with no diarization data (diarization disabled)."""
|
||||
from reflector.hatchet.workflows.file_pipeline import DiarizeResult
|
||||
|
||||
result = DiarizeResult(diarization=None)
|
||||
assert result.diarization is None
|
||||
|
||||
result_default = DiarizeResult()
|
||||
assert result_default.diarization is None
|
||||
|
||||
|
||||
def test_transcribe_result_words_are_pydantic():
|
||||
"""TranscribeResult words come from Pydantic Word.model_dump() — verify roundtrip."""
|
||||
from reflector.hatchet.workflows.file_pipeline import TranscribeResult
|
||||
from reflector.processors.types import Word
|
||||
|
||||
words = [
|
||||
Word(text="hello", start=0.0, end=0.5),
|
||||
Word(text="world", start=0.5, end=1.0),
|
||||
]
|
||||
# Words are Pydantic models, so model_dump() works
|
||||
word_dicts = [w.model_dump() for w in words]
|
||||
result = TranscribeResult(words=word_dicts)
|
||||
|
||||
# Consumer reconstructs via Word(**w)
|
||||
reconstructed = [Word(**w) for w in result.words]
|
||||
assert reconstructed[0].text == "hello"
|
||||
assert reconstructed[1].start == 0.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_pipeline_on_failure_sets_error_status(
|
||||
file_pipeline_module, mock_file_input, mock_ctx
|
||||
):
|
||||
"""on_workflow_failure sets error status when transcript is processing."""
|
||||
from reflector.hatchet.workflows.file_pipeline import on_workflow_failure
|
||||
|
||||
transcript_processing = MagicMock()
|
||||
transcript_processing.status = "processing"
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
|
||||
_noop_db_context,
|
||||
):
|
||||
with patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript_processing,
|
||||
):
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
await on_workflow_failure(mock_file_input, mock_ctx)
|
||||
mock_set_error.assert_called_once_with(mock_file_input.transcript_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_pipeline_on_failure_does_not_overwrite_ended(
|
||||
file_pipeline_module, mock_file_input, mock_ctx
|
||||
):
|
||||
"""on_workflow_failure must NOT overwrite 'ended' status."""
|
||||
from reflector.hatchet.workflows.file_pipeline import on_workflow_failure
|
||||
|
||||
transcript_ended = MagicMock()
|
||||
transcript_ended.status = "ended"
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
|
||||
_noop_db_context,
|
||||
):
|
||||
with patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript_ended,
|
||||
):
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
await on_workflow_failure(mock_file_input, mock_ctx)
|
||||
mock_set_error.assert_not_called()
|
||||
218
server/tests/test_hatchet_live_post_pipeline.py
Normal file
218
server/tests/test_hatchet_live_post_pipeline.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
Tests for the LivePostProcessingPipeline Hatchet workflow.
|
||||
|
||||
Tests verify:
|
||||
1. with_error_handling behavior for live post pipeline input model
|
||||
2. on_workflow_failure logic (don't overwrite 'ended' status)
|
||||
3. Input model validation
|
||||
4. pipeline_post() now triggers Hatchet instead of Celery chord
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from hatchet_sdk import NonRetryableException
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_db_context():
|
||||
"""Async context manager that yields without touching the DB."""
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def live_pipeline_module():
|
||||
"""Import live_post_pipeline with Hatchet client mocked."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.workflow.return_value = MagicMock()
|
||||
with patch(
|
||||
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
from reflector.hatchet.workflows import live_post_pipeline
|
||||
|
||||
return live_post_pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_live_input():
|
||||
"""Minimal LivePostPipelineInput for tests."""
|
||||
from reflector.hatchet.workflows.live_post_pipeline import LivePostPipelineInput
|
||||
|
||||
return LivePostPipelineInput(
|
||||
transcript_id="ts-live-789",
|
||||
room_id="room-abc",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ctx():
|
||||
"""Minimal Context-like object."""
|
||||
ctx = MagicMock()
|
||||
ctx.log = MagicMock()
|
||||
return ctx
|
||||
|
||||
|
||||
def test_live_post_pipeline_input_model():
|
||||
"""Test LivePostPipelineInput validation."""
|
||||
from reflector.hatchet.workflows.live_post_pipeline import LivePostPipelineInput
|
||||
|
||||
# Valid input with room_id
|
||||
input_with_room = LivePostPipelineInput(transcript_id="ts-123", room_id="room-456")
|
||||
assert input_with_room.transcript_id == "ts-123"
|
||||
assert input_with_room.room_id == "room-456"
|
||||
|
||||
# Valid input without room_id
|
||||
input_no_room = LivePostPipelineInput(transcript_id="ts-123")
|
||||
assert input_no_room.room_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_pipeline_error_handling_transient(
|
||||
live_pipeline_module, mock_live_input, mock_ctx
|
||||
):
|
||||
"""Transient exception must NOT set error status."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
TaskName,
|
||||
with_error_handling,
|
||||
)
|
||||
|
||||
async def failing_task(input, ctx):
|
||||
raise httpx.TimeoutException("timed out")
|
||||
|
||||
wrapped = with_error_handling(TaskName.WAVEFORM)(failing_task)
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
await wrapped(mock_live_input, mock_ctx)
|
||||
|
||||
mock_set_error.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_pipeline_error_handling_hard_fail(
|
||||
live_pipeline_module, mock_live_input, mock_ctx
|
||||
):
|
||||
"""Hard-fail must set error status and raise NonRetryableException."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
TaskName,
|
||||
with_error_handling,
|
||||
)
|
||||
|
||||
async def failing_task(input, ctx):
|
||||
raise ValueError("Transcript not found")
|
||||
|
||||
wrapped = with_error_handling(TaskName.WAVEFORM)(failing_task)
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
with pytest.raises(NonRetryableException) as exc_info:
|
||||
await wrapped(mock_live_input, mock_ctx)
|
||||
|
||||
assert "Transcript not found" in str(exc_info.value)
|
||||
mock_set_error.assert_called_once_with("ts-live-789")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_pipeline_on_failure_sets_error_status(
|
||||
live_pipeline_module, mock_live_input, mock_ctx
|
||||
):
|
||||
"""on_workflow_failure sets error status when transcript is processing."""
|
||||
from reflector.hatchet.workflows.live_post_pipeline import on_workflow_failure
|
||||
|
||||
transcript_processing = MagicMock()
|
||||
transcript_processing.status = "processing"
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.live_post_pipeline.fresh_db_connection",
|
||||
_noop_db_context,
|
||||
):
|
||||
with patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript_processing,
|
||||
):
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.live_post_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
await on_workflow_failure(mock_live_input, mock_ctx)
|
||||
mock_set_error.assert_called_once_with(mock_live_input.transcript_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_pipeline_on_failure_does_not_overwrite_ended(
|
||||
live_pipeline_module, mock_live_input, mock_ctx
|
||||
):
|
||||
"""on_workflow_failure must NOT overwrite 'ended' status."""
|
||||
from reflector.hatchet.workflows.live_post_pipeline import on_workflow_failure
|
||||
|
||||
transcript_ended = MagicMock()
|
||||
transcript_ended.status = "ended"
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.live_post_pipeline.fresh_db_connection",
|
||||
_noop_db_context,
|
||||
):
|
||||
with patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript_ended,
|
||||
):
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.live_post_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
await on_workflow_failure(mock_live_input, mock_ctx)
|
||||
mock_set_error.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_post_triggers_hatchet():
|
||||
"""pipeline_post() should trigger Hatchet LivePostProcessingPipeline workflow."""
|
||||
with patch(
|
||||
"reflector.hatchet.client.HatchetClientManager.start_workflow",
|
||||
new_callable=AsyncMock,
|
||||
return_value="workflow-run-id",
|
||||
) as mock_start:
|
||||
from reflector.pipelines.main_live_pipeline import pipeline_post
|
||||
|
||||
await pipeline_post(transcript_id="ts-test-123", room_id="room-test")
|
||||
|
||||
mock_start.assert_called_once_with(
|
||||
"LivePostProcessingPipeline",
|
||||
{
|
||||
"transcript_id": "ts-test-123",
|
||||
"room_id": "room-test",
|
||||
},
|
||||
additional_metadata={"transcript_id": "ts-test-123"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_post_triggers_hatchet_without_room_id():
|
||||
"""pipeline_post() should handle None room_id."""
|
||||
with patch(
|
||||
"reflector.hatchet.client.HatchetClientManager.start_workflow",
|
||||
new_callable=AsyncMock,
|
||||
return_value="workflow-run-id",
|
||||
) as mock_start:
|
||||
from reflector.pipelines.main_live_pipeline import pipeline_post
|
||||
|
||||
await pipeline_post(transcript_id="ts-test-456")
|
||||
|
||||
mock_start.assert_called_once_with(
|
||||
"LivePostProcessingPipeline",
|
||||
{
|
||||
"transcript_id": "ts-test-456",
|
||||
"room_id": None,
|
||||
},
|
||||
additional_metadata={"transcript_id": "ts-test-456"},
|
||||
)
|
||||
90
server/tests/test_hatchet_trigger_migration.py
Normal file
90
server/tests/test_hatchet_trigger_migration.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Tests verifying Celery-to-Hatchet trigger migration.
|
||||
|
||||
Ensures that:
|
||||
1. process_recording triggers FilePipeline via Hatchet (not Celery)
|
||||
2. transcript_record_upload triggers FilePipeline via Hatchet (not Celery)
|
||||
3. Old Celery task references are no longer in active call sites
|
||||
"""
|
||||
|
||||
|
||||
def test_process_recording_does_not_import_celery_file_task():
|
||||
"""Verify process.py no longer imports task_pipeline_file_process."""
|
||||
import inspect
|
||||
|
||||
from reflector.worker import process
|
||||
|
||||
source = inspect.getsource(process)
|
||||
# Should not contain the old Celery task import
|
||||
assert "task_pipeline_file_process" not in source
|
||||
|
||||
|
||||
def test_transcripts_upload_does_not_import_celery_file_task():
|
||||
"""Verify transcripts_upload.py no longer imports task_pipeline_file_process."""
|
||||
import inspect
|
||||
|
||||
from reflector.views import transcripts_upload
|
||||
|
||||
source = inspect.getsource(transcripts_upload)
|
||||
# Should not contain the old Celery task import
|
||||
assert "task_pipeline_file_process" not in source
|
||||
|
||||
|
||||
def test_transcripts_upload_imports_hatchet():
|
||||
"""Verify transcripts_upload.py imports HatchetClientManager."""
|
||||
import inspect
|
||||
|
||||
from reflector.views import transcripts_upload
|
||||
|
||||
source = inspect.getsource(transcripts_upload)
|
||||
assert "HatchetClientManager" in source
|
||||
|
||||
|
||||
def test_pipeline_post_is_async():
|
||||
"""Verify pipeline_post is now async (Hatchet trigger)."""
|
||||
import asyncio
|
||||
|
||||
from reflector.pipelines.main_live_pipeline import pipeline_post
|
||||
|
||||
assert asyncio.iscoroutinefunction(pipeline_post)
|
||||
|
||||
|
||||
def test_transcript_process_service_does_not_import_celery_file_task():
|
||||
"""Verify transcript_process.py service no longer imports task_pipeline_file_process."""
|
||||
import inspect
|
||||
|
||||
from reflector.services import transcript_process
|
||||
|
||||
source = inspect.getsource(transcript_process)
|
||||
assert "task_pipeline_file_process" not in source
|
||||
|
||||
|
||||
def test_transcript_process_service_dispatch_uses_hatchet():
|
||||
"""Verify dispatch_transcript_processing uses HatchetClientManager for file processing."""
|
||||
import inspect
|
||||
|
||||
from reflector.services import transcript_process
|
||||
|
||||
source = inspect.getsource(transcript_process.dispatch_transcript_processing)
|
||||
assert "HatchetClientManager" in source
|
||||
assert "FilePipeline" in source
|
||||
|
||||
|
||||
def test_new_task_names_exist():
|
||||
"""Verify new TaskName constants were added for file and live pipelines."""
|
||||
from reflector.hatchet.constants import TaskName
|
||||
|
||||
# File pipeline tasks
|
||||
assert TaskName.EXTRACT_AUDIO == "extract_audio"
|
||||
assert TaskName.UPLOAD_AUDIO == "upload_audio"
|
||||
assert TaskName.TRANSCRIBE == "transcribe"
|
||||
assert TaskName.DIARIZE == "diarize"
|
||||
assert TaskName.ASSEMBLE_TRANSCRIPT == "assemble_transcript"
|
||||
assert TaskName.GENERATE_SUMMARIES == "generate_summaries"
|
||||
|
||||
# Live post-processing pipeline tasks
|
||||
assert TaskName.WAVEFORM == "waveform"
|
||||
assert TaskName.CONVERT_MP3 == "convert_mp3"
|
||||
assert TaskName.UPLOAD_MP3 == "upload_mp3"
|
||||
assert TaskName.REMOVE_UPLOAD == "remove_upload"
|
||||
assert TaskName.FINAL_SUMMARIES == "final_summaries"
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for LLM structured output with astructured_predict + reflection retry"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
@@ -252,6 +252,63 @@ class TestNetworkErrorRetries:
|
||||
assert mock_settings.llm.astructured_predict.call_count == 3
|
||||
|
||||
|
||||
class TestGetResponseRetries:
|
||||
"""Test that get_response() uses the same retry() wrapper for transient errors."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_retries_on_connection_error(self, test_settings):
|
||||
"""Test that get_response retries on ConnectionError and returns on success."""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.aget_response = AsyncMock(
|
||||
side_effect=[
|
||||
ConnectionError("Connection refused"),
|
||||
" Summary text ",
|
||||
]
|
||||
)
|
||||
|
||||
with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
|
||||
result = await llm.get_response("Prompt", ["text"])
|
||||
|
||||
assert result == "Summary text"
|
||||
assert mock_instance.aget_response.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_exhausts_retries(self, test_settings):
|
||||
"""Test that get_response raises RetryException after retry attempts exceeded."""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.aget_response = AsyncMock(
|
||||
side_effect=ConnectionError("Connection refused")
|
||||
)
|
||||
|
||||
with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
|
||||
with pytest.raises(RetryException, match="Retry attempts exceeded"):
|
||||
await llm.get_response("Prompt", ["text"])
|
||||
|
||||
assert mock_instance.aget_response.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_returns_empty_string_without_retry(self, test_settings):
|
||||
"""Empty or whitespace-only LLM response must return '' and not raise RetryException.
|
||||
|
||||
retry() must return falsy results (e.g. '' from get_response) instead of
|
||||
treating them as 'no result' and retrying until RetryException.
|
||||
"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.aget_response = AsyncMock(return_value=" \n ") # strip() -> ""
|
||||
|
||||
with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
|
||||
result = await llm.get_response("Prompt", ["text"])
|
||||
|
||||
assert result == ""
|
||||
assert mock_instance.aget_response.call_count == 1
|
||||
|
||||
|
||||
class TestTextsInclusion:
|
||||
"""Test that texts parameter is included in the prompt sent to astructured_predict"""
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user