mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Compare commits
99 Commits
mathieu/ca
...
mathieu/sq
| Author | SHA1 | Date | |
|---|---|---|---|
| 8fcfac80fa | |||
| 9b3da4b2c8 | |||
| d86dc59bf2 | |||
| b7f8e8ef8d | |||
| 27f19ec6ba | |||
| 2aa99fe846 | |||
| df909363f5 | |||
| ad2accb574 | |||
| a07c621bcd | |||
| f51dae8da3 | |||
| b217c7ba41 | |||
| 0b2152ea75 | |||
| e0c71c5548 | |||
| a883df0d63 | |||
| 1c9e8b9cde | |||
| 27b3b9cdee | |||
| 8ad1270229 | |||
| 617a1c8b32 | |||
| 60cc2b16ae | |||
| 606c5f5059 | |||
| 5e036d17b6 | |||
| 04a9c2f2f7 | |||
| fb5bb39716 | |||
| 4f70a7f593 | |||
| 224e40225d | |||
| 24980de4e0 | |||
| 7f178b5f9e | |||
| 0aaa42528a | |||
| 565a62900f | |||
|
|
27016e6051 | ||
| 6ddfee0b4e | |||
|
|
47716f6e5d | ||
| 1520f88e9e | |||
| 9b90aaa57f | |||
| d21b65e4e8 | |||
| 45d1608950 | |||
| 06639d4d8f | |||
| 0abcebfc94 | |||
|
|
2b723da08b | ||
| 6566e04300 | |||
| 870e860517 | |||
| 396a95d5ce | |||
| 6f680b5795 | |||
| ab859d65a6 | |||
| fa049e8d06 | |||
| 2ce7479967 | |||
| b42f7cfc60 | |||
| c546e69739 | |||
|
|
3f1fe8c9bf | ||
| 5f143fe364 | |||
|
|
79f161436e | ||
|
|
5cba5d310d | ||
| 43ea9349f5 | |||
|
|
b3a8e9739d | ||
|
|
369ecdff13 | ||
| fc363bd49b | |||
|
|
962038ee3f | ||
|
|
3b85ff3bdf | ||
|
|
cde99ca271 | ||
|
|
f81fe9948a | ||
|
|
5a5b323382 | ||
| 02a3938822 | |||
|
|
7f5a4c9ddc | ||
|
|
08d88ec349 | ||
|
|
c4d2825c81 | ||
| 0663700a61 | |||
| dc82f8bb3b | |||
| 457823e1c1 | |||
|
|
695d1a957d | ||
| ccffdba75b | |||
| 84a381220b | |||
| 5f2f0e9317 | |||
| 88ed7cfa78 | |||
| 6f0c7c1a5e | |||
| 9dfd76996f | |||
| 55cc8637c6 | |||
| f5331a2107 | |||
|
|
124ce03bf8 | ||
| 7030e0f236 | |||
| 37f0110892 | |||
| cf2896a7f4 | |||
| aabf2c2572 | |||
| 6a7b08f016 | |||
| e2736563d9 | |||
| 0f54b7782d | |||
| 359280dd34 | |||
| 9265d201b5 | |||
| 52f9f533d7 | |||
| 0c3878ac3c | |||
|
|
d70beee51b | ||
| bc5b351d2b | |||
|
|
07981e8090 | ||
| 7e366f6338 | |||
| 7592679a35 | |||
| af16178f86 | |||
| 3ea7f6b7b6 | |||
|
|
009590c080 | ||
|
|
fe5d344cff | ||
|
|
86455ce573 |
5
.github/workflows/db_migrations.yml
vendored
5
.github/workflows/db_migrations.yml
vendored
@@ -2,6 +2,8 @@ name: Test Database Migrations
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "server/migrations/**"
|
||||
- "server/reflector/db/**"
|
||||
@@ -17,6 +19,9 @@ on:
|
||||
jobs:
|
||||
test-migrations:
|
||||
runs-on: ubuntu-latest
|
||||
concurrency:
|
||||
group: db-ubuntu-latest-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
|
||||
77
.github/workflows/deploy.yml
vendored
77
.github/workflows/deploy.yml
vendored
@@ -8,18 +8,30 @@ env:
|
||||
ECR_REPOSITORY: reflector
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
build:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- platform: linux/amd64
|
||||
runner: linux-amd64
|
||||
arch: amd64
|
||||
- platform: linux/arm64
|
||||
runner: linux-arm64
|
||||
arch: arm64
|
||||
|
||||
runs-on: ${{ matrix.runner }}
|
||||
|
||||
permissions:
|
||||
deployments: write
|
||||
contents: read
|
||||
|
||||
outputs:
|
||||
registry: ${{ steps.login-ecr.outputs.registry }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@0e613a0980cbf65ed5b322eb7a1e075d28913a83
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
@@ -27,21 +39,52 @@ jobs:
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@62f4f872db3836360b72999f4b87f1ff13310f3a
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build and push
|
||||
id: docker_build
|
||||
uses: docker/build-push-action@v4
|
||||
- name: Build and push ${{ matrix.arch }}
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: true
|
||||
tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest-${{ matrix.arch }}
|
||||
cache-from: type=gha,scope=${{ matrix.arch }}
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
|
||||
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
||||
provenance: false
|
||||
|
||||
create-manifest:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [build]
|
||||
|
||||
permissions:
|
||||
deployments: write
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Create and push multi-arch manifest
|
||||
run: |
|
||||
# Get the registry URL (since we can't easily access job outputs in matrix)
|
||||
ECR_REGISTRY=$(aws ecr describe-registry --query 'registryId' --output text).dkr.ecr.${{ env.AWS_REGION }}.amazonaws.com
|
||||
|
||||
docker manifest create \
|
||||
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest \
|
||||
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-amd64 \
|
||||
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-arm64
|
||||
|
||||
docker manifest push $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest
|
||||
|
||||
echo "✅ Multi-arch manifest pushed: $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest"
|
||||
|
||||
45
.github/workflows/test_next_server.yml
vendored
Normal file
45
.github/workflows/test_next_server.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
name: Test Next Server
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "www/**"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "www/**"
|
||||
|
||||
jobs:
|
||||
test-next-server:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./www
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 8
|
||||
|
||||
- name: Setup Node.js cache
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'pnpm'
|
||||
cache-dependency-path: './www/pnpm-lock.yaml'
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm test
|
||||
49
.github/workflows/test_server.yml
vendored
49
.github/workflows/test_server.yml
vendored
@@ -5,12 +5,17 @@ on:
|
||||
paths:
|
||||
- "server/**"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "server/**"
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
runs-on: ubuntu-latest
|
||||
concurrency:
|
||||
group: pytest-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
services:
|
||||
redis:
|
||||
image: redis:6
|
||||
@@ -19,29 +24,47 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
working-directory: server
|
||||
|
||||
- name: Tests
|
||||
run: |
|
||||
cd server
|
||||
uv run -m pytest -v tests
|
||||
|
||||
docker:
|
||||
runs-on: ubuntu-latest
|
||||
docker-amd64:
|
||||
runs-on: linux-amd64
|
||||
concurrency:
|
||||
group: docker-amd64-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build and push
|
||||
id: docker_build
|
||||
uses: docker/build-push-action@v4
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build AMD64
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
platforms: linux/amd64
|
||||
cache-from: type=gha,scope=amd64
|
||||
cache-to: type=gha,mode=max,scope=amd64
|
||||
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
||||
|
||||
docker-arm64:
|
||||
runs-on: linux-arm64
|
||||
concurrency:
|
||||
group: docker-arm64-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build ARM64
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: server
|
||||
platforms: linux/arm64
|
||||
cache-from: type=gha,scope=arm64
|
||||
cache-to: type=gha,mode=max,scope=arm64
|
||||
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -14,4 +14,7 @@ data/
|
||||
www/REFACTOR.md
|
||||
www/reload-frontend
|
||||
server/test.sqlite
|
||||
CLAUDE.local.md
|
||||
CLAUDE.local.md
|
||||
www/.env.development
|
||||
www/.env.production
|
||||
.playwright-mcp
|
||||
|
||||
1
.gitleaksignore
Normal file
1
.gitleaksignore
Normal file
@@ -0,0 +1 @@
|
||||
b9d891d3424f371642cb032ecfd0e2564470a72c:server/tests/test_transcripts_recording_deletion.py:generic-api-key:15
|
||||
@@ -27,3 +27,8 @@ repos:
|
||||
files: ^server/
|
||||
- id: ruff-format
|
||||
files: ^server/
|
||||
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.28.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
153
CHANGELOG.md
153
CHANGELOG.md
@@ -1,5 +1,158 @@
|
||||
# Changelog
|
||||
|
||||
## [0.13.1](https://github.com/Monadical-SAS/reflector/compare/v0.13.0...v0.13.1) (2025-09-22)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* TypeError on not all arguments converted during string formatting in logger ([#667](https://github.com/Monadical-SAS/reflector/issues/667)) ([565a629](https://github.com/Monadical-SAS/reflector/commit/565a62900f5a02fc946b68f9269a42190ed70ab6))
|
||||
|
||||
## [0.13.0](https://github.com/Monadical-SAS/reflector/compare/v0.12.1...v0.13.0) (2025-09-19)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* room form edit with enter ([#662](https://github.com/Monadical-SAS/reflector/issues/662)) ([47716f6](https://github.com/Monadical-SAS/reflector/commit/47716f6e5ddee952609d2fa0ffabdfa865286796))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* invalid cleanup call ([#660](https://github.com/Monadical-SAS/reflector/issues/660)) ([0abcebf](https://github.com/Monadical-SAS/reflector/commit/0abcebfc9491f87f605f21faa3e53996fafedd9a))
|
||||
|
||||
## [0.12.1](https://github.com/Monadical-SAS/reflector/compare/v0.12.0...v0.12.1) (2025-09-17)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* production blocked because having existing meeting with room_id null ([#657](https://github.com/Monadical-SAS/reflector/issues/657)) ([870e860](https://github.com/Monadical-SAS/reflector/commit/870e8605171a27155a9cbee215eeccb9a8d6c0a2))
|
||||
|
||||
## [0.12.0](https://github.com/Monadical-SAS/reflector/compare/v0.11.0...v0.12.0) (2025-09-17)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* calendar integration ([#608](https://github.com/Monadical-SAS/reflector/issues/608)) ([6f680b5](https://github.com/Monadical-SAS/reflector/commit/6f680b57954c688882c4ed49f40f161c52a00a24))
|
||||
* self-hosted gpu api ([#636](https://github.com/Monadical-SAS/reflector/issues/636)) ([ab859d6](https://github.com/Monadical-SAS/reflector/commit/ab859d65a6bded904133a163a081a651b3938d42))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* ignore player hotkeys for text inputs ([#646](https://github.com/Monadical-SAS/reflector/issues/646)) ([fa049e8](https://github.com/Monadical-SAS/reflector/commit/fa049e8d068190ce7ea015fd9fcccb8543f54a3f))
|
||||
|
||||
## [0.11.0](https://github.com/Monadical-SAS/reflector/compare/v0.10.0...v0.11.0) (2025-09-16)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* remove profanity filter that was there for conference ([#652](https://github.com/Monadical-SAS/reflector/issues/652)) ([b42f7cf](https://github.com/Monadical-SAS/reflector/commit/b42f7cfc606783afcee792590efcc78b507468ab))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* zulip and consent handler on the file pipeline ([#645](https://github.com/Monadical-SAS/reflector/issues/645)) ([5f143fe](https://github.com/Monadical-SAS/reflector/commit/5f143fe3640875dcb56c26694254a93189281d17))
|
||||
* zulip stream and topic selection in share dialog ([#644](https://github.com/Monadical-SAS/reflector/issues/644)) ([c546e69](https://github.com/Monadical-SAS/reflector/commit/c546e69739e68bb74fbc877eb62609928e5b8de6))
|
||||
|
||||
## [0.10.0](https://github.com/Monadical-SAS/reflector/compare/v0.9.0...v0.10.0) (2025-09-11)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* replace nextjs-config with environment variables ([#632](https://github.com/Monadical-SAS/reflector/issues/632)) ([369ecdf](https://github.com/Monadical-SAS/reflector/commit/369ecdff13f3862d926a9c0b87df52c9d94c4dde))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* anonymous users transcript permissions ([#621](https://github.com/Monadical-SAS/reflector/issues/621)) ([f81fe99](https://github.com/Monadical-SAS/reflector/commit/f81fe9948a9237b3e0001b2d8ca84f54d76878f9))
|
||||
* auth post ([#624](https://github.com/Monadical-SAS/reflector/issues/624)) ([cde99ca](https://github.com/Monadical-SAS/reflector/commit/cde99ca2716f84ba26798f289047732f0448742e))
|
||||
* auth post ([#626](https://github.com/Monadical-SAS/reflector/issues/626)) ([3b85ff3](https://github.com/Monadical-SAS/reflector/commit/3b85ff3bdf4fb053b103070646811bc990c0e70a))
|
||||
* auth post ([#627](https://github.com/Monadical-SAS/reflector/issues/627)) ([962038e](https://github.com/Monadical-SAS/reflector/commit/962038ee3f2a555dc3c03856be0e4409456e0996))
|
||||
* missing follow_redirects=True on modal endpoint ([#630](https://github.com/Monadical-SAS/reflector/issues/630)) ([fc363bd](https://github.com/Monadical-SAS/reflector/commit/fc363bd49b17b075e64f9186e5e0185abc325ea7))
|
||||
* sync backend and frontend token refresh logic ([#614](https://github.com/Monadical-SAS/reflector/issues/614)) ([5a5b323](https://github.com/Monadical-SAS/reflector/commit/5a5b3233820df9536da75e87ce6184a983d4713a))
|
||||
|
||||
## [0.9.0](https://github.com/Monadical-SAS/reflector/compare/v0.8.2...v0.9.0) (2025-09-06)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* frontend openapi react query ([#606](https://github.com/Monadical-SAS/reflector/issues/606)) ([c4d2825](https://github.com/Monadical-SAS/reflector/commit/c4d2825c81f81ad8835629fbf6ea8c7383f8c31b))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* align whisper transcriber api with parakeet ([#602](https://github.com/Monadical-SAS/reflector/issues/602)) ([0663700](https://github.com/Monadical-SAS/reflector/commit/0663700a615a4af69a03c96c410f049e23ec9443))
|
||||
* kv use tls explicit ([#610](https://github.com/Monadical-SAS/reflector/issues/610)) ([08d88ec](https://github.com/Monadical-SAS/reflector/commit/08d88ec349f38b0d13e0fa4cb73486c8dfd31836))
|
||||
* source kind for file processing ([#601](https://github.com/Monadical-SAS/reflector/issues/601)) ([dc82f8b](https://github.com/Monadical-SAS/reflector/commit/dc82f8bb3bdf3ab3d4088e592a30fd63907319e1))
|
||||
* token refresh locking ([#613](https://github.com/Monadical-SAS/reflector/issues/613)) ([7f5a4c9](https://github.com/Monadical-SAS/reflector/commit/7f5a4c9ddc7fd098860c8bdda2ca3b57f63ded2f))
|
||||
|
||||
## [0.8.2](https://github.com/Monadical-SAS/reflector/compare/v0.8.1...v0.8.2) (2025-08-29)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* search-logspam ([#593](https://github.com/Monadical-SAS/reflector/issues/593)) ([695d1a9](https://github.com/Monadical-SAS/reflector/commit/695d1a957d4cd862753049f9beed88836cabd5ab))
|
||||
|
||||
## [0.8.1](https://github.com/Monadical-SAS/reflector/compare/v0.8.0...v0.8.1) (2025-08-29)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* make webhook secret/url allowing null ([#590](https://github.com/Monadical-SAS/reflector/issues/590)) ([84a3812](https://github.com/Monadical-SAS/reflector/commit/84a381220bc606231d08d6f71d4babc818fa3c75))
|
||||
|
||||
## [0.8.0](https://github.com/Monadical-SAS/reflector/compare/v0.7.3...v0.8.0) (2025-08-29)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* **cleanup:** add automatic data retention for public instances ([#574](https://github.com/Monadical-SAS/reflector/issues/574)) ([6f0c7c1](https://github.com/Monadical-SAS/reflector/commit/6f0c7c1a5e751713366886c8e764c2009e12ba72))
|
||||
* **rooms:** add webhook for transcript completion ([#578](https://github.com/Monadical-SAS/reflector/issues/578)) ([88ed7cf](https://github.com/Monadical-SAS/reflector/commit/88ed7cfa7804794b9b54cad4c3facc8a98cf85fd))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* file pipeline status reporting and websocket updates ([#589](https://github.com/Monadical-SAS/reflector/issues/589)) ([9dfd769](https://github.com/Monadical-SAS/reflector/commit/9dfd76996f851cc52be54feea078adbc0816dc57))
|
||||
* Igor/evaluation ([#575](https://github.com/Monadical-SAS/reflector/issues/575)) ([124ce03](https://github.com/Monadical-SAS/reflector/commit/124ce03bf86044c18313d27228a25da4bc20c9c5))
|
||||
* optimize parakeet transcription batching algorithm ([#577](https://github.com/Monadical-SAS/reflector/issues/577)) ([7030e0f](https://github.com/Monadical-SAS/reflector/commit/7030e0f23649a8cf6c1eb6d5889684a41ce849ec))
|
||||
|
||||
## [0.7.3](https://github.com/Monadical-SAS/reflector/compare/v0.7.2...v0.7.3) (2025-08-22)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* cleaned repo, and get git-leaks clean ([359280d](https://github.com/Monadical-SAS/reflector/commit/359280dd340433ba4402ed69034094884c825e67))
|
||||
* restore previous behavior on live pipeline + audio downscaler ([#561](https://github.com/Monadical-SAS/reflector/issues/561)) ([9265d20](https://github.com/Monadical-SAS/reflector/commit/9265d201b590d23c628c5f19251b70f473859043))
|
||||
|
||||
## [0.7.2](https://github.com/Monadical-SAS/reflector/compare/v0.7.1...v0.7.2) (2025-08-21)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* docker image not loading libgomp.so.1 for torch ([#560](https://github.com/Monadical-SAS/reflector/issues/560)) ([773fccd](https://github.com/Monadical-SAS/reflector/commit/773fccd93e887c3493abc2e4a4864dddce610177))
|
||||
* include shared rooms to search ([#558](https://github.com/Monadical-SAS/reflector/issues/558)) ([499eced](https://github.com/Monadical-SAS/reflector/commit/499eced3360b84fb3a90e1c8a3b554290d21adc2))
|
||||
|
||||
## [0.7.1](https://github.com/Monadical-SAS/reflector/compare/v0.7.0...v0.7.1) (2025-08-21)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* webvtt db null expectation mismatch ([#556](https://github.com/Monadical-SAS/reflector/issues/556)) ([e67ad1a](https://github.com/Monadical-SAS/reflector/commit/e67ad1a4a2054467bfeb1e0258fbac5868aaaf21))
|
||||
|
||||
## [0.7.0](https://github.com/Monadical-SAS/reflector/compare/v0.6.1...v0.7.0) (2025-08-21)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* delete recording with transcript ([#547](https://github.com/Monadical-SAS/reflector/issues/547)) ([99cc984](https://github.com/Monadical-SAS/reflector/commit/99cc9840b3f5de01e0adfbfae93234042d706d13))
|
||||
* pipeline improvement with file processing, parakeet, silero-vad ([#540](https://github.com/Monadical-SAS/reflector/issues/540)) ([bcc29c9](https://github.com/Monadical-SAS/reflector/commit/bcc29c9e0050ae215f89d460e9d645aaf6a5e486))
|
||||
* postgresql migration and removal of sqlite in pytest ([#546](https://github.com/Monadical-SAS/reflector/issues/546)) ([cd1990f](https://github.com/Monadical-SAS/reflector/commit/cd1990f8f0fe1503ef5069512f33777a73a93d7f))
|
||||
* search backend ([#537](https://github.com/Monadical-SAS/reflector/issues/537)) ([5f9b892](https://github.com/Monadical-SAS/reflector/commit/5f9b89260c9ef7f3c921319719467df22830453f))
|
||||
* search frontend ([#551](https://github.com/Monadical-SAS/reflector/issues/551)) ([3657242](https://github.com/Monadical-SAS/reflector/commit/365724271ca6e615e3425125a69ae2b46ce39285))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* evaluation cli event wrap ([#536](https://github.com/Monadical-SAS/reflector/issues/536)) ([941c3db](https://github.com/Monadical-SAS/reflector/commit/941c3db0bdacc7b61fea412f3746cc5a7cb67836))
|
||||
* use structlog not logging ([#550](https://github.com/Monadical-SAS/reflector/issues/550)) ([27e2f81](https://github.com/Monadical-SAS/reflector/commit/27e2f81fda5232e53edc729d3e99c5ef03adbfe9))
|
||||
|
||||
## [0.6.1](https://github.com/Monadical-SAS/reflector/compare/v0.6.0...v0.6.1) (2025-08-06)
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,6 @@ pnpm install
|
||||
|
||||
# Copy configuration templates
|
||||
cp .env_template .env
|
||||
cp config-template.ts config.ts
|
||||
```
|
||||
|
||||
**Development:**
|
||||
|
||||
81
README.md
81
README.md
@@ -1,43 +1,60 @@
|
||||
<div align="center">
|
||||
<img width="100" alt="image" src="https://github.com/user-attachments/assets/66fb367b-2c89-4516-9912-f47ac59c6a7f"/>
|
||||
|
||||
# Reflector
|
||||
|
||||
Reflector Audio Management and Analysis is a cutting-edge web application under development by Monadical. It utilizes AI to record meetings, providing a permanent record with transcripts, translations, and automated summaries.
|
||||
Reflector is an AI-powered audio transcription and meeting analysis platform that provides real-time transcription, speaker diarization, translation and summarization for audio content and live meetings. It works 100% with local models (whisper/parakeet, pyannote, seamless-m4t, and your local llm like phi-4).
|
||||
|
||||
[](https://github.com/monadical-sas/reflector/actions/workflows/pytests.yml)
|
||||
[](https://github.com/monadical-sas/reflector/actions/workflows/test_server.yml)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
</div>
|
||||
|
||||
## Screenshots
|
||||
</div>
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/3a976930-56c1-47ef-8c76-55d3864309e3">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/3a976930-56c1-47ef-8c76-55d3864309e3" />
|
||||
<a href="https://github.com/user-attachments/assets/21f5597c-2930-4899-a154-f7bd61a59e97">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/21f5597c-2930-4899-a154-f7bd61a59e97" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/bfe3bde3-08af-4426-a9a1-11ad5cd63b33">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/bfe3bde3-08af-4426-a9a1-11ad5cd63b33" />
|
||||
<a href="https://github.com/user-attachments/assets/f6b9399a-5e51-4bae-b807-59128d0a940c">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/f6b9399a-5e51-4bae-b807-59128d0a940c" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/7b60c9d0-efe4-474f-a27b-ea13bd0fabdc">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/7b60c9d0-efe4-474f-a27b-ea13bd0fabdc" />
|
||||
<a href="https://github.com/user-attachments/assets/a42ce460-c1fd-4489-a995-270516193897">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/a42ce460-c1fd-4489-a995-270516193897" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/21929f6d-c309-42fe-9c11-f1299e50fbd4">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/21929f6d-c309-42fe-9c11-f1299e50fbd4" />
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## What is Reflector?
|
||||
|
||||
Reflector is a web application that utilizes local models to process audio content, providing:
|
||||
|
||||
- **Real-time Transcription**: Convert speech to text using [Whisper](https://github.com/openai/whisper) (multi-language) or [Parakeet](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2) (English) models
|
||||
- **Speaker Diarization**: Identify and label different speakers using [Pyannote](https://github.com/pyannote/pyannote-audio) 3.1
|
||||
- **Live Translation**: Translate audio content in real-time to many languages with [Facebook Seamless-M4T](https://github.com/facebookresearch/seamless_communication)
|
||||
- **Topic Detection & Summarization**: Extract key topics and generate concise summaries using LLMs
|
||||
- **Meeting Recording**: Create permanent records of meetings with searchable transcripts
|
||||
|
||||
Currently we provide [modal.com](https://modal.com/) gpu template to deploy.
|
||||
|
||||
## Background
|
||||
|
||||
The project architecture consists of three primary components:
|
||||
|
||||
- **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
|
||||
- **Back-End**: Python server that offers an API and data persistence, found in `server/`.
|
||||
- **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations. Most reliable option is Modal deployment
|
||||
- **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
|
||||
- **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
|
||||
|
||||
It also uses authentik for authentication if activated, and Vercel for deployment and configuration of the front-end.
|
||||
It also uses authentik for authentication if activated.
|
||||
|
||||
## Contribution Guidelines
|
||||
|
||||
@@ -72,6 +89,8 @@ Note: We currently do not have instructions for Windows users.
|
||||
|
||||
## Installation
|
||||
|
||||
*Note: we're working toward better installation, theses instructions are not accurate for now*
|
||||
|
||||
### Frontend
|
||||
|
||||
Start with `cd www`.
|
||||
@@ -80,11 +99,10 @@ Start with `cd www`.
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
cp .env_template .env
|
||||
cp config-template.ts config.ts
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
Then, fill in the environment variables in `.env` and the configuration in `config.ts` as needed. If you are unsure on how to proceed, ask in Zulip.
|
||||
Then, fill in the environment variables in `.env` as needed. If you are unsure on how to proceed, ask in Zulip.
|
||||
|
||||
**Run in development mode**
|
||||
|
||||
@@ -149,3 +167,34 @@ You can manually process an audio file by calling the process tool:
|
||||
```bash
|
||||
uv run python -m reflector.tools.process path/to/audio.wav
|
||||
```
|
||||
|
||||
|
||||
## Feature Flags
|
||||
|
||||
Reflector uses environment variable-based feature flags to control application functionality. These flags allow you to enable or disable features without code changes.
|
||||
|
||||
### Available Feature Flags
|
||||
|
||||
| Feature Flag | Environment Variable |
|
||||
|-------------|---------------------|
|
||||
| `requireLogin` | `NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN` |
|
||||
| `privacy` | `NEXT_PUBLIC_FEATURE_PRIVACY` |
|
||||
| `browse` | `NEXT_PUBLIC_FEATURE_BROWSE` |
|
||||
| `sendToZulip` | `NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP` |
|
||||
| `rooms` | `NEXT_PUBLIC_FEATURE_ROOMS` |
|
||||
|
||||
### Setting Feature Flags
|
||||
|
||||
Feature flags are controlled via environment variables using the pattern `NEXT_PUBLIC_FEATURE_{FEATURE_NAME}` where `{FEATURE_NAME}` is the SCREAMING_SNAKE_CASE version of the feature name.
|
||||
|
||||
**Examples:**
|
||||
```bash
|
||||
# Enable user authentication requirement
|
||||
NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN=true
|
||||
|
||||
# Disable browse functionality
|
||||
NEXT_PUBLIC_FEATURE_BROWSE=false
|
||||
|
||||
# Enable Zulip integration
|
||||
NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP=true
|
||||
```
|
||||
|
||||
@@ -6,6 +6,7 @@ services:
|
||||
- 1250:1250
|
||||
volumes:
|
||||
- ./server/:/app/
|
||||
- /app/.venv
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
@@ -16,6 +17,7 @@ services:
|
||||
context: server
|
||||
volumes:
|
||||
- ./server/:/app/
|
||||
- /app/.venv
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
@@ -26,6 +28,7 @@ services:
|
||||
context: server
|
||||
volumes:
|
||||
- ./server/:/app/
|
||||
- /app/.venv
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
|
||||
33
gpu/modal_deployments/.gitignore
vendored
Normal file
33
gpu/modal_deployments/.gitignore
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
# OS / Editor
|
||||
.DS_Store
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# Env and secrets
|
||||
.env
|
||||
.env.*
|
||||
*.env
|
||||
*.secret
|
||||
|
||||
# Build / dist
|
||||
build/
|
||||
dist/
|
||||
.eggs/
|
||||
*.egg-info/
|
||||
|
||||
# Coverage / test
|
||||
.pytest_cache/
|
||||
.coverage*
|
||||
htmlcov/
|
||||
|
||||
# Modal local state (if any)
|
||||
modal_mounts/
|
||||
.modal_cache/
|
||||
171
gpu/modal_deployments/README.md
Normal file
171
gpu/modal_deployments/README.md
Normal file
@@ -0,0 +1,171 @@
|
||||
# Reflector GPU implementation - Transcription and LLM
|
||||
|
||||
This repository hold an API for the GPU implementation of the Reflector API service,
|
||||
and use [Modal.com](https://modal.com)
|
||||
|
||||
- `reflector_diarizer.py` - Diarization API
|
||||
- `reflector_transcriber.py` - Transcription API (Whisper)
|
||||
- `reflector_transcriber_parakeet.py` - Transcription API (NVIDIA Parakeet)
|
||||
- `reflector_translator.py` - Translation API
|
||||
|
||||
## Modal.com deployment
|
||||
|
||||
Create a modal secret, and name it `reflector-gpu`.
|
||||
It should contain an `REFLECTOR_APIKEY` environment variable with a value.
|
||||
|
||||
The deployment is done using [Modal.com](https://modal.com) service.
|
||||
|
||||
```
|
||||
$ modal deploy reflector_transcriber.py
|
||||
...
|
||||
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run
|
||||
|
||||
$ modal deploy reflector_transcriber_parakeet.py
|
||||
...
|
||||
└── 🔨 Created web => https://xxxx--reflector-transcriber-parakeet-web.modal.run
|
||||
|
||||
$ modal deploy reflector_llm.py
|
||||
...
|
||||
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
|
||||
```
|
||||
|
||||
Then in your reflector api configuration `.env`, you can set these keys:
|
||||
|
||||
```
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
|
||||
DIARIZATION_BACKEND=modal
|
||||
DIARIZATION_URL=https://xxxx--reflector-diarizer-web.modal.run
|
||||
DIARIZATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
|
||||
TRANSLATION_BACKEND=modal
|
||||
TRANSLATION_URL=https://xxxx--reflector-translator-web.modal.run
|
||||
TRANSLATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
Authentication must be passed with the `Authorization` header, using the `bearer` scheme.
|
||||
|
||||
```
|
||||
Authorization: bearer <REFLECTOR_APIKEY>
|
||||
```
|
||||
|
||||
### LLM
|
||||
|
||||
`POST /llm`
|
||||
|
||||
**request**
|
||||
```
|
||||
{
|
||||
"prompt": "xxx"
|
||||
}
|
||||
```
|
||||
|
||||
**response**
|
||||
```
|
||||
{
|
||||
"text": "xxx completed"
|
||||
}
|
||||
```
|
||||
|
||||
### Transcription
|
||||
|
||||
#### Parakeet Transcriber (`reflector_transcriber_parakeet.py`)
|
||||
|
||||
NVIDIA Parakeet is a state-of-the-art ASR model optimized for real-time transcription with superior word-level timestamps.
|
||||
|
||||
**GPU Configuration:**
|
||||
- **A10G GPU** - Used for `/v1/audio/transcriptions` endpoint (small files, live transcription)
|
||||
- Higher concurrency (max_inputs=10)
|
||||
- Optimized for multiple small audio files
|
||||
- Supports batch processing for efficiency
|
||||
|
||||
- **L40S GPU** - Used for `/v1/audio/transcriptions-from-url` endpoint (large files)
|
||||
- Lower concurrency but more powerful processing
|
||||
- Optimized for single large audio files
|
||||
- VAD-based chunking for long-form audio
|
||||
|
||||
##### `/v1/audio/transcriptions` - Small file transcription
|
||||
|
||||
**request** (multipart/form-data)
|
||||
- `file` or `files[]` - audio file(s) to transcribe
|
||||
- `model` - model name (default: `nvidia/parakeet-tdt-0.6b-v2`)
|
||||
- `language` - language code (default: `en`)
|
||||
- `batch` - whether to use batch processing for multiple files (default: `true`)
|
||||
|
||||
**response**
|
||||
```json
|
||||
{
|
||||
"text": "transcribed text",
|
||||
"words": [
|
||||
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||
{"word": "world", "start": 0.5, "end": 1.0}
|
||||
],
|
||||
"filename": "audio.mp3"
|
||||
}
|
||||
```
|
||||
|
||||
For multiple files with batch=true:
|
||||
```json
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"filename": "audio1.mp3",
|
||||
"text": "transcribed text",
|
||||
"words": [...]
|
||||
},
|
||||
{
|
||||
"filename": "audio2.mp3",
|
||||
"text": "transcribed text",
|
||||
"words": [...]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
##### `/v1/audio/transcriptions-from-url` - Large file transcription
|
||||
|
||||
**request** (application/json)
|
||||
```json
|
||||
{
|
||||
"audio_file_url": "https://example.com/audio.mp3",
|
||||
"model": "nvidia/parakeet-tdt-0.6b-v2",
|
||||
"language": "en",
|
||||
"timestamp_offset": 0.0
|
||||
}
|
||||
```
|
||||
|
||||
**response**
|
||||
```json
|
||||
{
|
||||
"text": "transcribed text from large file",
|
||||
"words": [
|
||||
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||
{"word": "world", "start": 0.5, "end": 1.0}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Supported file types:** mp3, mp4, mpeg, mpga, m4a, wav, webm
|
||||
|
||||
#### Whisper Transcriber (`reflector_transcriber.py`)
|
||||
|
||||
`POST /transcribe`
|
||||
|
||||
**request** (multipart/form-data)
|
||||
|
||||
- `file` - audio file
|
||||
- `language` - language code (e.g. `en`)
|
||||
|
||||
**response**
|
||||
```
|
||||
{
|
||||
"text": "xxx",
|
||||
"words": [
|
||||
{"text": "xxx", "start": 0.0, "end": 1.0}
|
||||
]
|
||||
}
|
||||
```
|
||||
253
gpu/modal_deployments/reflector_diarizer.py
Normal file
253
gpu/modal_deployments/reflector_diarizer.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Reflector GPU backend - diarizer
|
||||
===================================
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import Mapping, NewType
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import modal
|
||||
|
||||
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
|
||||
MODEL_DIR = "/root/diarization_models"
|
||||
UPLOADS_PATH = "/uploads"
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
|
||||
DiarizerUniqFilename = NewType("DiarizerUniqFilename", str)
|
||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||
|
||||
app = modal.App(name="reflector-diarizer")
|
||||
|
||||
# Volume for temporary file uploads
|
||||
upload_volume = modal.Volume.from_name("diarizer-uploads", create_if_missing=True)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||
parsed_url = urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return AudioFileExtension(ext)
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return AudioFileExtension("mp3")
|
||||
if "audio/wav" in content_type:
|
||||
return AudioFileExtension("wav")
|
||||
if "audio/mp4" in content_type:
|
||||
return AudioFileExtension("mp4")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported audio format for URL: {url}. "
|
||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(
|
||||
audio_file_url: str,
|
||||
) -> tuple[DiarizerUniqFilename, AudioFileExtension]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
print(f"Checking audio file at: {audio_file_url}")
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
print(f"Downloading audio file from: {audio_file_url}")
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"Download failed with status {response.status_code}: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to download audio file: {response.status_code}",
|
||||
)
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = DiarizerUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
print(f"Writing file to: {file_path} (size: {len(response.content)} bytes)")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
print(f"File saved as: {unique_filename}")
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
def migrate_cache_llm():
|
||||
"""
|
||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
||||
Migrating your old cache. This is a one-time only operation. You can
|
||||
interrupt this and resume the migration later on by calling
|
||||
`transformers.utils.move_cache()`.
|
||||
"""
|
||||
from transformers.utils.hub import move_cache
|
||||
|
||||
print("Moving LLM cache")
|
||||
move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
|
||||
print("LLM cache moved")
|
||||
|
||||
|
||||
def download_pyannote_audio():
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
Pipeline.from_pretrained(
|
||||
PYANNOTE_MODEL_NAME,
|
||||
cache_dir=MODEL_DIR,
|
||||
use_auth_token=os.environ["HF_TOKEN"],
|
||||
)
|
||||
|
||||
|
||||
diarizer_image = (
|
||||
modal.Image.debian_slim(python_version="3.10.8")
|
||||
.pip_install(
|
||||
"pyannote.audio==3.1.0",
|
||||
"requests",
|
||||
"onnx",
|
||||
"torchaudio",
|
||||
"onnxruntime-gpu",
|
||||
"torch==2.0.0",
|
||||
"transformers==4.34.0",
|
||||
"sentencepiece",
|
||||
"protobuf",
|
||||
"numpy",
|
||||
"huggingface_hub",
|
||||
"hf-transfer",
|
||||
)
|
||||
.run_function(
|
||||
download_pyannote_audio,
|
||||
secrets=[modal.Secret.from_name("hf_token")],
|
||||
)
|
||||
.run_function(migrate_cache_llm)
|
||||
.env(
|
||||
{
|
||||
"LD_LIBRARY_PATH": (
|
||||
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
|
||||
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A100",
|
||||
timeout=60 * 30,
|
||||
image=diarizer_image,
|
||||
volumes={UPLOADS_PATH: upload_volume},
|
||||
enable_memory_snapshot=True,
|
||||
experimental_options={"enable_gpu_snapshot": True},
|
||||
secrets=[
|
||||
modal.Secret.from_name("hf_token"),
|
||||
],
|
||||
)
|
||||
@modal.concurrent(max_inputs=1)
|
||||
class Diarizer:
|
||||
@modal.enter(snap=True)
|
||||
def enter(self):
|
||||
import torch
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
print(f"Using device: {self.device}")
|
||||
self.diarization_pipeline = Pipeline.from_pretrained(
|
||||
PYANNOTE_MODEL_NAME,
|
||||
cache_dir=MODEL_DIR,
|
||||
use_auth_token=os.environ["HF_TOKEN"],
|
||||
)
|
||||
self.diarization_pipeline.to(torch.device(self.device))
|
||||
|
||||
@modal.method()
|
||||
def diarize(self, filename: str, timestamp: float = 0.0):
|
||||
import torchaudio
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
print(f"Diarizing audio from: {file_path}")
|
||||
waveform, sample_rate = torchaudio.load(file_path)
|
||||
diarization = self.diarization_pipeline(
|
||||
{"waveform": waveform, "sample_rate": sample_rate}
|
||||
)
|
||||
|
||||
words = []
|
||||
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||
words.append(
|
||||
{
|
||||
"start": round(timestamp + diarization_segment.start, 3),
|
||||
"end": round(timestamp + diarization_segment.end, 3),
|
||||
"speaker": int(speaker[-2:]),
|
||||
}
|
||||
)
|
||||
print("Diarization complete")
|
||||
return {"diarization": words}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Web API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.function(
|
||||
timeout=60 * 10,
|
||||
scaledown_window=60 * 3,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={UPLOADS_PATH: upload_volume},
|
||||
image=diarizer_image,
|
||||
)
|
||||
@modal.concurrent(max_inputs=40)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
diarizerstub = Diarizer()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class DiarizationResponse(BaseModel):
|
||||
result: dict
|
||||
|
||||
@app.post("/diarize", dependencies=[Depends(apikey_auth)])
|
||||
def diarize(audio_file_url: str, timestamp: float = 0.0) -> DiarizationResponse:
|
||||
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
|
||||
|
||||
try:
|
||||
func = diarizerstub.diarize.spawn(
|
||||
filename=unique_filename, timestamp=timestamp
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
print(f"Deleting file: {file_path}")
|
||||
os.remove(file_path)
|
||||
upload_volume.commit()
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up {unique_filename}: {e}")
|
||||
|
||||
return app
|
||||
608
gpu/modal_deployments/reflector_transcriber.py
Normal file
608
gpu/modal_deployments/reflector_transcriber.py
Normal file
@@ -0,0 +1,608 @@
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import modal
|
||||
|
||||
MODEL_NAME = "large-v2"
|
||||
MODEL_COMPUTE_TYPE: str = "float16"
|
||||
MODEL_NUM_WORKERS: int = 1
|
||||
MINUTES = 60 # seconds
|
||||
SAMPLERATE = 16000
|
||||
UPLOADS_PATH = "/uploads"
|
||||
CACHE_PATH = "/models"
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
|
||||
WhisperUniqFilename = NewType("WhisperUniqFilename", str)
|
||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||
|
||||
app = modal.App("reflector-transcriber")
|
||||
|
||||
model_cache = modal.Volume.from_name("models", create_if_missing=True)
|
||||
upload_volume = modal.Volume.from_name("whisper-uploads", create_if_missing=True)
|
||||
|
||||
|
||||
class TimeSegment(NamedTuple):
|
||||
"""Represents a time segment with start and end times."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
class AudioSegment(NamedTuple):
|
||||
"""Represents an audio segment with timing and audio data."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
audio: any
|
||||
|
||||
|
||||
class TranscriptResult(NamedTuple):
|
||||
"""Represents a transcription result with text and word timings."""
|
||||
|
||||
text: str
|
||||
words: list["WordTiming"]
|
||||
|
||||
|
||||
class WordTiming(TypedDict):
|
||||
"""Represents a word with its timing information."""
|
||||
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
def download_model():
|
||||
from faster_whisper import download_model
|
||||
|
||||
model_cache.reload()
|
||||
|
||||
download_model(MODEL_NAME, cache_dir=CACHE_PATH)
|
||||
|
||||
model_cache.commit()
|
||||
|
||||
|
||||
image = (
|
||||
modal.Image.debian_slim(python_version="3.12")
|
||||
.env(
|
||||
{
|
||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||
"LD_LIBRARY_PATH": (
|
||||
"/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/:"
|
||||
"/opt/conda/lib/python3.12/site-packages/nvidia/cublas/lib/"
|
||||
),
|
||||
}
|
||||
)
|
||||
.apt_install("ffmpeg")
|
||||
.pip_install(
|
||||
"huggingface_hub==0.27.1",
|
||||
"hf-transfer==0.1.9",
|
||||
"torch==2.5.1",
|
||||
"faster-whisper==1.1.1",
|
||||
"fastapi==0.115.12",
|
||||
"requests",
|
||||
"librosa==0.10.1",
|
||||
"numpy<2",
|
||||
"silero-vad==5.1.0",
|
||||
)
|
||||
.run_function(download_model, volumes={CACHE_PATH: model_cache})
|
||||
)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||
parsed_url = urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return AudioFileExtension(ext)
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return AudioFileExtension("mp3")
|
||||
if "audio/wav" in content_type:
|
||||
return AudioFileExtension("wav")
|
||||
if "audio/mp4" in content_type:
|
||||
return AudioFileExtension("mp4")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported audio format for URL: {url}. "
|
||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(
|
||||
audio_file_url: str,
|
||||
) -> tuple[WhisperUniqFilename, AudioFileExtension]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = WhisperUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
|
||||
"""Add 0.5s of silence if audio is shorter than the silence_padding window.
|
||||
|
||||
Whisper does not require this strictly, but aligning behavior with Parakeet
|
||||
avoids edge-case crashes on extremely short inputs and makes comparisons easier.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
audio_duration = len(audio_array) / sample_rate
|
||||
if audio_duration < VAD_CONFIG["silence_padding"]:
|
||||
silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio_array, silence])
|
||||
return audio_array
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=5 * MINUTES,
|
||||
scaledown_window=5 * MINUTES,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
)
|
||||
@modal.concurrent(max_inputs=10)
|
||||
class TranscriberWhisperLive:
|
||||
"""Live transcriber class for small audio segments (A10G).
|
||||
|
||||
Mirrors the Parakeet live class API but uses Faster-Whisper under the hood.
|
||||
"""
|
||||
|
||||
@modal.enter()
|
||||
def enter(self):
|
||||
import faster_whisper
|
||||
import torch
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=MODEL_COMPUTE_TYPE,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=CACHE_PATH,
|
||||
local_files_only=True,
|
||||
)
|
||||
print(f"Model is on device: {self.device}")
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
filename: str,
|
||||
language: str = "en",
|
||||
):
|
||||
"""Transcribe a single uploaded audio file by filename."""
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
segments, _ = self.model.transcribe(
|
||||
file_path,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(segment.text for segment in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": word.word,
|
||||
"start": round(float(word.start), 2),
|
||||
"end": round(float(word.end), 2),
|
||||
}
|
||||
for segment in segments
|
||||
for word in segment.words
|
||||
]
|
||||
|
||||
return {"text": text, "words": words}
|
||||
|
||||
@modal.method()
|
||||
def transcribe_batch(
|
||||
self,
|
||||
filenames: list[str],
|
||||
language: str = "en",
|
||||
):
|
||||
"""Transcribe multiple uploaded audio files and return per-file results."""
|
||||
upload_volume.reload()
|
||||
|
||||
results = []
|
||||
for filename in filenames:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Batch file not found: {file_path}")
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
segments, _ = self.model.transcribe(
|
||||
file_path,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(seg.text for seg in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": w.word,
|
||||
"start": round(float(w.start), 2),
|
||||
"end": round(float(w.end), 2),
|
||||
}
|
||||
for seg in segments
|
||||
for w in seg.words
|
||||
]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"text": text,
|
||||
"words": words,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="L40S",
|
||||
timeout=15 * MINUTES,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
)
|
||||
class TranscriberWhisperFile:
|
||||
"""File transcriber for larger/longer audio, using VAD-driven batching (L40S)."""
|
||||
|
||||
@modal.enter()
|
||||
def enter(self):
|
||||
import faster_whisper
|
||||
import torch
|
||||
from silero_vad import load_silero_vad
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=MODEL_COMPUTE_TYPE,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=CACHE_PATH,
|
||||
local_files_only=True,
|
||||
)
|
||||
self.vad_model = load_silero_vad(onnx=False)
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self, filename: str, timestamp_offset: float = 0.0, language: str = "en"
|
||||
):
|
||||
import librosa
|
||||
import numpy as np
|
||||
from silero_vad import VADIterator
|
||||
|
||||
def vad_segments(
|
||||
audio_array,
|
||||
sample_rate: int = SAMPLERATE,
|
||||
window_size: int = VAD_CONFIG["window_size"],
|
||||
) -> Generator[TimeSegment, None, None]:
|
||||
"""Generate speech segments as TimeSegment using Silero VAD."""
|
||||
iterator = VADIterator(self.vad_model, sampling_rate=sample_rate)
|
||||
start = None
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(
|
||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
||||
)
|
||||
speech = iterator(chunk)
|
||||
if not speech:
|
||||
continue
|
||||
if "start" in speech:
|
||||
start = speech["start"]
|
||||
continue
|
||||
if "end" in speech and start is not None:
|
||||
end = speech["end"]
|
||||
yield TimeSegment(
|
||||
start / float(SAMPLERATE), end / float(SAMPLERATE)
|
||||
)
|
||||
start = None
|
||||
iterator.reset_states()
|
||||
|
||||
upload_volume.reload()
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
audio_array, _sr = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
|
||||
# Batch segments up to ~30s windows by merging contiguous VAD segments
|
||||
merged_batches: list[TimeSegment] = []
|
||||
batch_start = None
|
||||
batch_end = None
|
||||
max_duration = VAD_CONFIG["batch_max_duration"]
|
||||
for segment in vad_segments(audio_array):
|
||||
seg_start, seg_end = segment.start, segment.end
|
||||
if batch_start is None:
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
continue
|
||||
if seg_end - batch_start <= max_duration:
|
||||
batch_end = seg_end
|
||||
else:
|
||||
merged_batches.append(TimeSegment(batch_start, batch_end))
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
if batch_start is not None and batch_end is not None:
|
||||
merged_batches.append(TimeSegment(batch_start, batch_end))
|
||||
|
||||
all_text = []
|
||||
all_words = []
|
||||
|
||||
for segment in merged_batches:
|
||||
start_time, end_time = segment.start, segment.end
|
||||
s_idx = int(start_time * SAMPLERATE)
|
||||
e_idx = int(end_time * SAMPLERATE)
|
||||
segment = audio_array[s_idx:e_idx]
|
||||
segment = pad_audio(segment, SAMPLERATE)
|
||||
|
||||
with self.lock:
|
||||
segments, _ = self.model.transcribe(
|
||||
segment,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(seg.text for seg in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": w.word,
|
||||
"start": round(float(w.start) + start_time + timestamp_offset, 2),
|
||||
"end": round(float(w.end) + start_time + timestamp_offset, 2),
|
||||
}
|
||||
for seg in segments
|
||||
for w in seg.words
|
||||
]
|
||||
if text:
|
||||
all_text.append(text)
|
||||
all_words.extend(words)
|
||||
|
||||
return {"text": " ".join(all_text), "words": all_words}
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: dict) -> str:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
url_path = urlparse(url).path
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return ext
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return "mp3"
|
||||
if "audio/wav" in content_type:
|
||||
return "wav"
|
||||
if "audio/mp4" in content_type:
|
||||
return "mp4"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(audio_file_url: str) -> tuple[str, str]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60,
|
||||
timeout=600,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
image=image,
|
||||
)
|
||||
@modal.concurrent(max_inputs=40)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import (
|
||||
Body,
|
||||
Depends,
|
||||
FastAPI,
|
||||
Form,
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
transcriber_live = TranscriberWhisperLive()
|
||||
transcriber_file = TranscriberWhisperFile()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class TranscriptResponse(dict):
|
||||
pass
|
||||
|
||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe(
|
||||
file: UploadFile = None,
|
||||
files: list[UploadFile] | None = None,
|
||||
model: str = Form(MODEL_NAME),
|
||||
language: str = Form("en"),
|
||||
batch: bool = Form(False),
|
||||
):
|
||||
if not file and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
||||
)
|
||||
if batch and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Batch transcription requires 'files'"
|
||||
)
|
||||
|
||||
upload_files = [file] if file else files
|
||||
|
||||
uploaded_filenames: list[str] = []
|
||||
for upload_file in upload_files:
|
||||
audio_suffix = upload_file.filename.split(".")[-1]
|
||||
if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
with open(file_path, "wb") as f:
|
||||
content = upload_file.file.read()
|
||||
f.write(content)
|
||||
uploaded_filenames.append(unique_filename)
|
||||
|
||||
upload_volume.commit()
|
||||
|
||||
try:
|
||||
if batch and len(upload_files) > 1:
|
||||
func = transcriber_live.transcribe_batch.spawn(
|
||||
filenames=uploaded_filenames,
|
||||
language=language,
|
||||
)
|
||||
results = func.get()
|
||||
return {"results": results}
|
||||
|
||||
results = []
|
||||
for filename in uploaded_filenames:
|
||||
func = transcriber_live.transcribe_segment.spawn(
|
||||
filename=filename,
|
||||
language=language,
|
||||
)
|
||||
result = func.get()
|
||||
result["filename"] = filename
|
||||
results.append(result)
|
||||
|
||||
return {"results": results} if len(results) > 1 else results[0]
|
||||
finally:
|
||||
for filename in uploaded_filenames:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
os.remove(file_path)
|
||||
except Exception:
|
||||
pass
|
||||
upload_volume.commit()
|
||||
|
||||
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe_from_url(
|
||||
audio_file_url: str = Body(
|
||||
..., description="URL of the audio file to transcribe"
|
||||
),
|
||||
model: str = Body(MODEL_NAME),
|
||||
language: str = Body("en"),
|
||||
timestamp_offset: float = Body(0.0),
|
||||
):
|
||||
unique_filename, _audio_suffix = download_audio_to_volume(audio_file_url)
|
||||
try:
|
||||
func = transcriber_file.transcribe_segment.spawn(
|
||||
filename=unique_filename,
|
||||
timestamp_offset=timestamp_offset,
|
||||
language=language,
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
os.remove(file_path)
|
||||
upload_volume.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class NoStdStreams:
|
||||
def __init__(self):
|
||||
self.devnull = open(os.devnull, "w")
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
||||
self._stdout.flush()
|
||||
self._stderr.flush()
|
||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
||||
self.devnull.close()
|
||||
658
gpu/modal_deployments/reflector_transcriber_parakeet.py
Normal file
658
gpu/modal_deployments/reflector_transcriber_parakeet.py
Normal file
@@ -0,0 +1,658 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import modal
|
||||
|
||||
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
SAMPLERATE = 16000
|
||||
UPLOADS_PATH = "/uploads"
|
||||
CACHE_PATH = "/cache"
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
|
||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||
|
||||
|
||||
class TimeSegment(NamedTuple):
|
||||
"""Represents a time segment with start and end times."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
class AudioSegment(NamedTuple):
|
||||
"""Represents an audio segment with timing and audio data."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
audio: any
|
||||
|
||||
|
||||
class TranscriptResult(NamedTuple):
|
||||
"""Represents a transcription result with text and word timings."""
|
||||
|
||||
text: str
|
||||
words: list["WordTiming"]
|
||||
|
||||
|
||||
class WordTiming(TypedDict):
|
||||
"""Represents a word with its timing information."""
|
||||
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
app = modal.App("reflector-transcriber-parakeet")
|
||||
|
||||
# Volume for caching model weights
|
||||
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
|
||||
# Volume for temporary file uploads
|
||||
upload_volume = modal.Volume.from_name("parakeet-uploads", create_if_missing=True)
|
||||
|
||||
image = (
|
||||
modal.Image.from_registry(
|
||||
"nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", add_python="3.12"
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||
"HF_HOME": "/cache",
|
||||
"DEBIAN_FRONTEND": "noninteractive",
|
||||
"CXX": "g++",
|
||||
"CC": "g++",
|
||||
}
|
||||
)
|
||||
.apt_install("ffmpeg")
|
||||
.pip_install(
|
||||
"hf_transfer==0.1.9",
|
||||
"huggingface_hub[hf-xet]==0.31.2",
|
||||
"nemo_toolkit[asr]==2.3.0",
|
||||
"cuda-python==12.8.0",
|
||||
"fastapi==0.115.12",
|
||||
"numpy<2",
|
||||
"librosa==0.10.1",
|
||||
"requests",
|
||||
"silero-vad==5.1.0",
|
||||
"torch",
|
||||
)
|
||||
.entrypoint([]) # silence chatty logs by container on start
|
||||
)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||
parsed_url = urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return AudioFileExtension(ext)
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return AudioFileExtension("mp3")
|
||||
if "audio/wav" in content_type:
|
||||
return AudioFileExtension("wav")
|
||||
if "audio/mp4" in content_type:
|
||||
return AudioFileExtension("mp4")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported audio format for URL: {url}. "
|
||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(
|
||||
audio_file_url: str,
|
||||
) -> tuple[ParakeetUniqFilename, AudioFileExtension]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = ParakeetUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
|
||||
"""Add 0.5 seconds of silence if audio is less than 500ms.
|
||||
|
||||
This is a workaround for a Parakeet bug where very short audio (<500ms) causes:
|
||||
ValueError: `char_offsets`: [] and `processed_tokens`: [157, 834, 834, 841]
|
||||
have to be of the same length
|
||||
|
||||
See: https://github.com/NVIDIA/NeMo/issues/8451
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
audio_duration = len(audio_array) / sample_rate
|
||||
if audio_duration < 0.5:
|
||||
silence_samples = int(sample_rate * 0.5)
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio_array, silence])
|
||||
return audio_array
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=600,
|
||||
scaledown_window=300,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
enable_memory_snapshot=True,
|
||||
experimental_options={"enable_gpu_snapshot": True},
|
||||
)
|
||||
@modal.concurrent(max_inputs=10)
|
||||
class TranscriberParakeetLive:
|
||||
@modal.enter(snap=True)
|
||||
def enter(self):
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
|
||||
device = next(self.model.parameters()).device
|
||||
print(f"Model is on device: {device}")
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
filename: str,
|
||||
):
|
||||
import librosa
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
padded_audio = pad_audio(audio_array, sample_rate)
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
(output,) = self.model.transcribe([padded_audio], timestamps=True)
|
||||
|
||||
text = output.text.strip()
|
||||
words: list[WordTiming] = [
|
||||
WordTiming(
|
||||
# XXX the space added here is to match the output of whisper
|
||||
# whisper add space to each words, while parakeet don't
|
||||
word=word_info["word"] + " ",
|
||||
start=round(word_info["start"], 2),
|
||||
end=round(word_info["end"], 2),
|
||||
)
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
return {"text": text, "words": words}
|
||||
|
||||
@modal.method()
|
||||
def transcribe_batch(
|
||||
self,
|
||||
filenames: list[str],
|
||||
):
|
||||
import librosa
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
results = []
|
||||
audio_arrays = []
|
||||
|
||||
# Load all audio files with padding
|
||||
for filename in filenames:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Batch file not found: {file_path}")
|
||||
|
||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
padded_audio = pad_audio(audio_array, sample_rate)
|
||||
audio_arrays.append(padded_audio)
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
outputs = self.model.transcribe(audio_arrays, timestamps=True)
|
||||
|
||||
# Process results for each file
|
||||
for i, (filename, output) in enumerate(zip(filenames, outputs)):
|
||||
text = output.text.strip()
|
||||
|
||||
words: list[WordTiming] = [
|
||||
WordTiming(
|
||||
word=word_info["word"] + " ",
|
||||
start=round(word_info["start"], 2),
|
||||
end=round(word_info["end"], 2),
|
||||
)
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"text": text,
|
||||
"words": words,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# L40S class for file transcription (bigger files)
|
||||
@app.cls(
|
||||
gpu="L40S",
|
||||
timeout=900,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
enable_memory_snapshot=True,
|
||||
experimental_options={"enable_gpu_snapshot": True},
|
||||
)
|
||||
class TranscriberParakeetFile:
|
||||
@modal.enter(snap=True)
|
||||
def enter(self):
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import torch
|
||||
from silero_vad import load_silero_vad
|
||||
|
||||
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
|
||||
|
||||
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
|
||||
device = next(self.model.parameters()).device
|
||||
print(f"Model is on device: {device}")
|
||||
|
||||
torch.set_num_threads(1)
|
||||
self.vad_model = load_silero_vad(onnx=False)
|
||||
print("Silero VAD initialized")
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
filename: str,
|
||||
timestamp_offset: float = 0.0,
|
||||
):
|
||||
import librosa
|
||||
import numpy as np
|
||||
from silero_vad import VADIterator
|
||||
|
||||
def load_and_convert_audio(file_path):
|
||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
return audio_array
|
||||
|
||||
def vad_segment_generator(
|
||||
audio_array,
|
||||
) -> Generator[TimeSegment, None, None]:
|
||||
"""Generate speech segments using VAD with start/end sample indices"""
|
||||
vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
|
||||
window_size = VAD_CONFIG["window_size"]
|
||||
start = None
|
||||
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(
|
||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
||||
)
|
||||
|
||||
speech_dict = vad_iterator(chunk)
|
||||
if not speech_dict:
|
||||
continue
|
||||
|
||||
if "start" in speech_dict:
|
||||
start = speech_dict["start"]
|
||||
continue
|
||||
|
||||
if "end" in speech_dict and start is not None:
|
||||
end = speech_dict["end"]
|
||||
start_time = start / float(SAMPLERATE)
|
||||
end_time = end / float(SAMPLERATE)
|
||||
|
||||
yield TimeSegment(start_time, end_time)
|
||||
start = None
|
||||
|
||||
vad_iterator.reset_states()
|
||||
|
||||
def batch_speech_segments(
|
||||
segments: Generator[TimeSegment, None, None], max_duration: int
|
||||
) -> Generator[TimeSegment, None, None]:
|
||||
"""
|
||||
Input segments:
|
||||
[0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
|
||||
|
||||
↓ (max_duration=10)
|
||||
|
||||
Output batches:
|
||||
[0-8] [10-19] [20-22]
|
||||
|
||||
Note: silences are kept for better transcription, previous implementation was
|
||||
passing segments separatly, but the output was less accurate.
|
||||
"""
|
||||
batch_start_time = None
|
||||
batch_end_time = None
|
||||
|
||||
for segment in segments:
|
||||
start_time, end_time = segment.start, segment.end
|
||||
if batch_start_time is None or batch_end_time is None:
|
||||
batch_start_time = start_time
|
||||
batch_end_time = end_time
|
||||
continue
|
||||
|
||||
total_duration = end_time - batch_start_time
|
||||
|
||||
if total_duration <= max_duration:
|
||||
batch_end_time = end_time
|
||||
continue
|
||||
|
||||
yield TimeSegment(batch_start_time, batch_end_time)
|
||||
batch_start_time = start_time
|
||||
batch_end_time = end_time
|
||||
|
||||
if batch_start_time is None or batch_end_time is None:
|
||||
return
|
||||
|
||||
yield TimeSegment(batch_start_time, batch_end_time)
|
||||
|
||||
def batch_segment_to_audio_segment(
|
||||
segments: Generator[TimeSegment, None, None],
|
||||
audio_array,
|
||||
) -> Generator[AudioSegment, None, None]:
|
||||
"""Extract audio segments and apply padding for Parakeet compatibility.
|
||||
|
||||
Uses pad_audio to ensure segments are at least 0.5s long, preventing
|
||||
Parakeet crashes. This padding may cause slight timing overlaps between
|
||||
segments, which are corrected by enforce_word_timing_constraints.
|
||||
"""
|
||||
for segment in segments:
|
||||
start_time, end_time = segment.start, segment.end
|
||||
start_sample = int(start_time * SAMPLERATE)
|
||||
end_sample = int(end_time * SAMPLERATE)
|
||||
audio_segment = audio_array[start_sample:end_sample]
|
||||
|
||||
padded_segment = pad_audio(audio_segment, SAMPLERATE)
|
||||
|
||||
yield AudioSegment(start_time, end_time, padded_segment)
|
||||
|
||||
def transcribe_batch(model, audio_segments: list) -> list:
|
||||
with NoStdStreams():
|
||||
outputs = model.transcribe(audio_segments, timestamps=True)
|
||||
return outputs
|
||||
|
||||
def enforce_word_timing_constraints(
|
||||
words: list[WordTiming],
|
||||
) -> list[WordTiming]:
|
||||
"""Enforce that word end times don't exceed the start time of the next word.
|
||||
|
||||
Due to silence padding added in batch_segment_to_audio_segment for better
|
||||
transcription accuracy, word timings from different segments may overlap.
|
||||
This function ensures there are no overlaps by adjusting end times.
|
||||
"""
|
||||
if len(words) <= 1:
|
||||
return words
|
||||
|
||||
enforced_words = []
|
||||
for i, word in enumerate(words):
|
||||
enforced_word = word.copy()
|
||||
|
||||
if i < len(words) - 1:
|
||||
next_start = words[i + 1]["start"]
|
||||
if enforced_word["end"] > next_start:
|
||||
enforced_word["end"] = next_start
|
||||
|
||||
enforced_words.append(enforced_word)
|
||||
|
||||
return enforced_words
|
||||
|
||||
def emit_results(
|
||||
results: list,
|
||||
segments_info: list[AudioSegment],
|
||||
) -> Generator[TranscriptResult, None, None]:
|
||||
"""Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
|
||||
for i, (output, segment) in enumerate(zip(results, segments_info)):
|
||||
start_time, end_time = segment.start, segment.end
|
||||
text = output.text.strip()
|
||||
words: list[WordTiming] = [
|
||||
WordTiming(
|
||||
word=word_info["word"] + " ",
|
||||
start=round(
|
||||
word_info["start"] + start_time + timestamp_offset, 2
|
||||
),
|
||||
end=round(word_info["end"] + start_time + timestamp_offset, 2),
|
||||
)
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
yield TranscriptResult(text, words)
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
audio_array = load_and_convert_audio(file_path)
|
||||
total_duration = len(audio_array) / float(SAMPLERATE)
|
||||
|
||||
all_text_parts: list[str] = []
|
||||
all_words: list[WordTiming] = []
|
||||
|
||||
raw_segments = vad_segment_generator(audio_array)
|
||||
speech_segments = batch_speech_segments(
|
||||
raw_segments,
|
||||
VAD_CONFIG["batch_max_duration"],
|
||||
)
|
||||
audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array)
|
||||
|
||||
for batch in audio_segments:
|
||||
audio_segment = batch.audio
|
||||
results = transcribe_batch(self.model, [audio_segment])
|
||||
|
||||
for result in emit_results(
|
||||
results,
|
||||
[batch],
|
||||
):
|
||||
if not result.text:
|
||||
continue
|
||||
all_text_parts.append(result.text)
|
||||
all_words.extend(result.words)
|
||||
|
||||
all_words = enforce_word_timing_constraints(all_words)
|
||||
|
||||
combined_text = " ".join(all_text_parts)
|
||||
return {"text": combined_text, "words": all_words}
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60,
|
||||
timeout=600,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
image=image,
|
||||
)
|
||||
@modal.concurrent(max_inputs=40)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from fastapi import (
|
||||
Body,
|
||||
Depends,
|
||||
FastAPI,
|
||||
Form,
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
transcriber_live = TranscriberParakeetLive()
|
||||
transcriber_file = TranscriberParakeetFile()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class TranscriptResponse(BaseModel):
|
||||
result: dict
|
||||
|
||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe(
|
||||
file: UploadFile = None,
|
||||
files: list[UploadFile] | None = None,
|
||||
model: str = Form(MODEL_NAME),
|
||||
language: str = Form("en"),
|
||||
batch: bool = Form(False),
|
||||
):
|
||||
# Parakeet only supports English
|
||||
if language != "en":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Parakeet model only supports English. Got language='{language}'",
|
||||
)
|
||||
# Handle both single file and multiple files
|
||||
if not file and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
||||
)
|
||||
if batch and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Batch transcription requires 'files'"
|
||||
)
|
||||
|
||||
upload_files = [file] if file else files
|
||||
|
||||
# Upload files to volume
|
||||
uploaded_filenames = []
|
||||
for upload_file in upload_files:
|
||||
audio_suffix = upload_file.filename.split(".")[-1]
|
||||
assert audio_suffix in SUPPORTED_FILE_EXTENSIONS
|
||||
|
||||
# Generate unique filename
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
print(f"Writing file to: {file_path}")
|
||||
with open(file_path, "wb") as f:
|
||||
content = upload_file.file.read()
|
||||
f.write(content)
|
||||
|
||||
uploaded_filenames.append(unique_filename)
|
||||
|
||||
upload_volume.commit()
|
||||
|
||||
try:
|
||||
# Use A10G live transcriber for per-file transcription
|
||||
if batch and len(upload_files) > 1:
|
||||
# Use batch transcription
|
||||
func = transcriber_live.transcribe_batch.spawn(
|
||||
filenames=uploaded_filenames,
|
||||
)
|
||||
results = func.get()
|
||||
return {"results": results}
|
||||
|
||||
# Per-file transcription
|
||||
results = []
|
||||
for filename in uploaded_filenames:
|
||||
func = transcriber_live.transcribe_segment.spawn(
|
||||
filename=filename,
|
||||
)
|
||||
result = func.get()
|
||||
result["filename"] = filename
|
||||
results.append(result)
|
||||
|
||||
return {"results": results} if len(results) > 1 else results[0]
|
||||
|
||||
finally:
|
||||
for filename in uploaded_filenames:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
print(f"Deleting file: {file_path}")
|
||||
os.remove(file_path)
|
||||
except Exception as e:
|
||||
print(f"Error deleting {filename}: {e}")
|
||||
|
||||
upload_volume.commit()
|
||||
|
||||
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe_from_url(
|
||||
audio_file_url: str = Body(
|
||||
..., description="URL of the audio file to transcribe"
|
||||
),
|
||||
model: str = Body(MODEL_NAME),
|
||||
language: str = Body("en", description="Language code (only 'en' supported)"),
|
||||
timestamp_offset: float = Body(0.0),
|
||||
):
|
||||
# Parakeet only supports English
|
||||
if language != "en":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Parakeet model only supports English. Got language='{language}'",
|
||||
)
|
||||
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
|
||||
|
||||
try:
|
||||
func = transcriber_file.transcribe_segment.spawn(
|
||||
filename=unique_filename,
|
||||
timestamp_offset=timestamp_offset,
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
print(f"Deleting file: {file_path}")
|
||||
os.remove(file_path)
|
||||
upload_volume.commit()
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up {unique_filename}: {e}")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class NoStdStreams:
|
||||
def __init__(self):
|
||||
self.devnull = open(os.devnull, "w")
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
||||
self._stdout.flush()
|
||||
self._stderr.flush()
|
||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
||||
self.devnull.close()
|
||||
2
gpu/self_hosted/.env.example
Normal file
2
gpu/self_hosted/.env.example
Normal file
@@ -0,0 +1,2 @@
|
||||
REFLECTOR_GPU_APIKEY=
|
||||
HF_TOKEN=
|
||||
38
gpu/self_hosted/.gitignore
vendored
Normal file
38
gpu/self_hosted/.gitignore
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
cache/
|
||||
|
||||
# OS / Editor
|
||||
.DS_Store
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Env and secrets
|
||||
.env
|
||||
*.env
|
||||
*.secret
|
||||
HF_TOKEN
|
||||
REFLECTOR_GPU_APIKEY
|
||||
|
||||
# Virtual env / uv
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
uv/
|
||||
|
||||
# Build / dist
|
||||
build/
|
||||
dist/
|
||||
.eggs/
|
||||
*.egg-info/
|
||||
|
||||
# Coverage / test
|
||||
.pytest_cache/
|
||||
.coverage*
|
||||
htmlcov/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
46
gpu/self_hosted/Dockerfile
Normal file
46
gpu/self_hosted/Dockerfile
Normal file
@@ -0,0 +1,46 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
|
||||
WORKDIR /tmp
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
ffmpeg \
|
||||
curl \
|
||||
ca-certificates \
|
||||
gnupg \
|
||||
wget \
|
||||
&& apt-get clean
|
||||
# Add NVIDIA CUDA repo for Debian 12 (bookworm) and install cuDNN 9 for CUDA 12
|
||||
ADD https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb /cuda-keyring.deb
|
||||
RUN dpkg -i /cuda-keyring.deb \
|
||||
&& rm /cuda-keyring.deb \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-cudart-12-6 \
|
||||
libcublas-12-6 \
|
||||
libcudnn9-cuda-12 \
|
||||
libcudnn9-dev-cuda-12 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
ADD https://astral.sh/uv/install.sh /uv-installer.sh
|
||||
RUN sh /uv-installer.sh && rm /uv-installer.sh
|
||||
ENV PATH="/root/.local/bin/:$PATH"
|
||||
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH"
|
||||
|
||||
RUN mkdir -p /app
|
||||
WORKDIR /app
|
||||
COPY pyproject.toml uv.lock /app/
|
||||
|
||||
|
||||
COPY ./app /app/app
|
||||
COPY ./main.py /app/
|
||||
COPY ./runserver.sh /app/
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["sh", "/app/runserver.sh"]
|
||||
|
||||
|
||||
73
gpu/self_hosted/README.md
Normal file
73
gpu/self_hosted/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Self-hosted Model API
|
||||
|
||||
Run transcription, translation, and diarization services compatible with Reflector's GPU Model API. Works on CPU or GPU.
|
||||
|
||||
Environment variables
|
||||
|
||||
- REFLECTOR_GPU_APIKEY: Optional Bearer token. If unset, auth is disabled.
|
||||
- HF_TOKEN: Optional. Required for diarization to download pyannote pipelines
|
||||
|
||||
Requirements
|
||||
|
||||
- FFmpeg must be installed and on PATH (used for URL-based and segmented transcription)
|
||||
- Python 3.12+
|
||||
- NVIDIA GPU optional. If available, it will be used automatically
|
||||
|
||||
Local run
|
||||
Set env vars in self_hosted/.env file
|
||||
uv sync
|
||||
|
||||
uv run uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
|
||||
Authentication
|
||||
|
||||
- If REFLECTOR_GPU_APIKEY is set, include header: Authorization: Bearer <key>
|
||||
|
||||
Endpoints
|
||||
|
||||
- POST /v1/audio/transcriptions
|
||||
|
||||
- multipart/form-data
|
||||
- fields: file (single file) OR files[] (multiple files), language, batch (true/false)
|
||||
- response: single { text, words, filename } or { results: [ ... ] }
|
||||
|
||||
- POST /v1/audio/transcriptions-from-url
|
||||
|
||||
- application/json
|
||||
- body: { audio_file_url, language, timestamp_offset }
|
||||
- response: { text, words }
|
||||
|
||||
- POST /translate
|
||||
|
||||
- text: query parameter
|
||||
- body (application/json): { source_language, target_language }
|
||||
- response: { text: { <src>: original, <tgt>: translated } }
|
||||
|
||||
- POST /diarize
|
||||
- query parameters: audio_file_url, timestamp (optional)
|
||||
- requires HF_TOKEN to be set (for pyannote)
|
||||
- response: { diarization: [ { start, end, speaker } ] }
|
||||
|
||||
OpenAPI docs
|
||||
|
||||
- Visit /docs when the server is running
|
||||
|
||||
Docker
|
||||
|
||||
- Not yet provided in this directory. A Dockerfile will be added later. For now, use Local run above
|
||||
|
||||
Conformance tests
|
||||
|
||||
# From this directory
|
||||
|
||||
TRANSCRIPT_URL=http://localhost:8000 \
|
||||
TRANSCRIPT_API_KEY=dev-key \
|
||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_transcript.py
|
||||
|
||||
TRANSLATION_URL=http://localhost:8000 \
|
||||
TRANSLATION_API_KEY=dev-key \
|
||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_translation.py
|
||||
|
||||
DIARIZATION_URL=http://localhost:8000 \
|
||||
DIARIZATION_API_KEY=dev-key \
|
||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_diarization.py
|
||||
19
gpu/self_hosted/app/auth.py
Normal file
19
gpu/self_hosted/app/auth.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
required_key = os.environ.get("REFLECTOR_GPU_APIKEY")
|
||||
if not required_key:
|
||||
return
|
||||
if apikey == required_key:
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
12
gpu/self_hosted/app/config.py
Normal file
12
gpu/self_hosted/app/config.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
SAMPLE_RATE = 16000
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
# App-level paths
|
||||
UPLOADS_PATH = Path("/tmp/whisper-uploads")
|
||||
30
gpu/self_hosted/app/factory.py
Normal file
30
gpu/self_hosted/app/factory.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routers.diarization import router as diarization_router
|
||||
from .routers.transcription import router as transcription_router
|
||||
from .routers.translation import router as translation_router
|
||||
from .services.transcriber import WhisperService
|
||||
from .services.diarizer import PyannoteDiarizationService
|
||||
from .utils import ensure_dirs
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
ensure_dirs()
|
||||
whisper_service = WhisperService()
|
||||
whisper_service.load()
|
||||
app.state.whisper = whisper_service
|
||||
diarization_service = PyannoteDiarizationService()
|
||||
diarization_service.load()
|
||||
app.state.diarizer = diarization_service
|
||||
yield
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(transcription_router)
|
||||
app.include_router(translation_router)
|
||||
app.include_router(diarization_router)
|
||||
return app
|
||||
30
gpu/self_hosted/app/routers/diarization.py
Normal file
30
gpu/self_hosted/app/routers/diarization.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth import apikey_auth
|
||||
from ..services.diarizer import PyannoteDiarizationService
|
||||
from ..utils import download_audio_file
|
||||
|
||||
router = APIRouter(tags=["diarization"])
|
||||
|
||||
|
||||
class DiarizationSegment(BaseModel):
|
||||
start: float
|
||||
end: float
|
||||
speaker: int
|
||||
|
||||
|
||||
class DiarizationResponse(BaseModel):
|
||||
diarization: List[DiarizationSegment]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diarize", dependencies=[Depends(apikey_auth)], response_model=DiarizationResponse
|
||||
)
|
||||
def diarize(request: Request, audio_file_url: str, timestamp: float = 0.0):
|
||||
with download_audio_file(audio_file_url) as (file_path, _ext):
|
||||
file_path = str(file_path)
|
||||
diarizer: PyannoteDiarizationService = request.app.state.diarizer
|
||||
return diarizer.diarize_file(file_path, timestamp=timestamp)
|
||||
109
gpu/self_hosted/app/routers/transcription.py
Normal file
109
gpu/self_hosted/app/routers/transcription.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Form, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pathlib import Path
|
||||
from ..auth import apikey_auth
|
||||
from ..config import SUPPORTED_FILE_EXTENSIONS, UPLOADS_PATH
|
||||
from ..services.transcriber import MODEL_NAME
|
||||
from ..utils import cleanup_uploaded_files, download_audio_file
|
||||
|
||||
router = APIRouter(prefix="/v1/audio", tags=["transcription"])
|
||||
|
||||
|
||||
class WordTiming(BaseModel):
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
class TranscriptResult(BaseModel):
|
||||
text: str
|
||||
words: list[WordTiming]
|
||||
filename: Optional[str] = None
|
||||
|
||||
|
||||
class TranscriptBatchResponse(BaseModel):
|
||||
results: list[TranscriptResult]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcriptions",
|
||||
dependencies=[Depends(apikey_auth)],
|
||||
response_model=Union[TranscriptResult, TranscriptBatchResponse],
|
||||
)
|
||||
def transcribe(
|
||||
request: Request,
|
||||
file: UploadFile = None,
|
||||
files: list[UploadFile] | None = None,
|
||||
model: str = Form(MODEL_NAME),
|
||||
language: str = Form("en"),
|
||||
batch: bool = Form(False),
|
||||
):
|
||||
service = request.app.state.whisper
|
||||
if not file and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
||||
)
|
||||
if batch and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Batch transcription requires 'files'"
|
||||
)
|
||||
|
||||
upload_files = [file] if file else files
|
||||
|
||||
uploaded_paths: list[Path] = []
|
||||
with cleanup_uploaded_files(uploaded_paths):
|
||||
for upload_file in upload_files:
|
||||
audio_suffix = upload_file.filename.split(".")[-1].lower()
|
||||
if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = UPLOADS_PATH / unique_filename
|
||||
with open(file_path, "wb") as f:
|
||||
content = upload_file.file.read()
|
||||
f.write(content)
|
||||
uploaded_paths.append(file_path)
|
||||
|
||||
if batch and len(upload_files) > 1:
|
||||
results = []
|
||||
for path in uploaded_paths:
|
||||
result = service.transcribe_file(str(path), language=language)
|
||||
result["filename"] = path.name
|
||||
results.append(result)
|
||||
return {"results": results}
|
||||
|
||||
results = []
|
||||
for path in uploaded_paths:
|
||||
result = service.transcribe_file(str(path), language=language)
|
||||
result["filename"] = path.name
|
||||
results.append(result)
|
||||
|
||||
return {"results": results} if len(results) > 1 else results[0]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcriptions-from-url",
|
||||
dependencies=[Depends(apikey_auth)],
|
||||
response_model=TranscriptResult,
|
||||
)
|
||||
def transcribe_from_url(
|
||||
request: Request,
|
||||
audio_file_url: str = Body(..., description="URL of the audio file to transcribe"),
|
||||
model: str = Body(MODEL_NAME),
|
||||
language: str = Body("en"),
|
||||
timestamp_offset: float = Body(0.0),
|
||||
):
|
||||
service = request.app.state.whisper
|
||||
with download_audio_file(audio_file_url) as (file_path, _ext):
|
||||
file_path = str(file_path)
|
||||
result = service.transcribe_vad_url_segment(
|
||||
file_path=file_path, timestamp_offset=timestamp_offset, language=language
|
||||
)
|
||||
return result
|
||||
28
gpu/self_hosted/app/routers/translation.py
Normal file
28
gpu/self_hosted/app/routers/translation.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth import apikey_auth
|
||||
from ..services.translator import TextTranslatorService
|
||||
|
||||
router = APIRouter(tags=["translation"])
|
||||
|
||||
translator = TextTranslatorService()
|
||||
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
text: Dict[str, str]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/translate",
|
||||
dependencies=[Depends(apikey_auth)],
|
||||
response_model=TranslationResponse,
|
||||
)
|
||||
def translate(
|
||||
text: str,
|
||||
source_language: str = Body("en"),
|
||||
target_language: str = Body("fr"),
|
||||
):
|
||||
return translator.translate(text, source_language, target_language)
|
||||
42
gpu/self_hosted/app/services/diarizer.py
Normal file
42
gpu/self_hosted/app/services/diarizer.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
|
||||
class PyannoteDiarizationService:
|
||||
def __init__(self):
|
||||
self._pipeline = None
|
||||
self._device = "cpu"
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def load(self):
|
||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self._pipeline = Pipeline.from_pretrained(
|
||||
"pyannote/speaker-diarization-3.1",
|
||||
use_auth_token=os.environ.get("HF_TOKEN"),
|
||||
)
|
||||
self._pipeline.to(torch.device(self._device))
|
||||
|
||||
def diarize_file(self, file_path: str, timestamp: float = 0.0) -> dict:
|
||||
if self._pipeline is None:
|
||||
self.load()
|
||||
waveform, sample_rate = torchaudio.load(file_path)
|
||||
with self._lock:
|
||||
diarization = self._pipeline(
|
||||
{"waveform": waveform, "sample_rate": sample_rate}
|
||||
)
|
||||
words = []
|
||||
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||
words.append(
|
||||
{
|
||||
"start": round(timestamp + diarization_segment.start, 3),
|
||||
"end": round(timestamp + diarization_segment.end, 3),
|
||||
"speaker": int(speaker[-2:])
|
||||
if speaker and speaker[-2:].isdigit()
|
||||
else 0,
|
||||
}
|
||||
)
|
||||
return {"diarization": words}
|
||||
208
gpu/self_hosted/app/services/transcriber.py
Normal file
208
gpu/self_hosted/app/services/transcriber.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Generator
|
||||
|
||||
import faster_whisper
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import HTTPException
|
||||
from silero_vad import VADIterator, load_silero_vad
|
||||
|
||||
from ..config import SAMPLE_RATE, VAD_CONFIG
|
||||
|
||||
# Whisper configuration (service-local defaults)
|
||||
MODEL_NAME = "large-v2"
|
||||
# None delegates compute type to runtime: float16 on CUDA, int8 on CPU
|
||||
MODEL_COMPUTE_TYPE = None
|
||||
MODEL_NUM_WORKERS = 1
|
||||
CACHE_PATH = os.path.join(os.path.expanduser("~"), ".cache", "reflector-whisper")
|
||||
from ..utils import NoStdStreams
|
||||
|
||||
|
||||
class WhisperService:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.device = "cpu"
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def load(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
compute_type = MODEL_COMPUTE_TYPE or (
|
||||
"float16" if self.device == "cuda" else "int8"
|
||||
)
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=compute_type,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=CACHE_PATH,
|
||||
)
|
||||
|
||||
def pad_audio(self, audio_array, sample_rate: int = SAMPLE_RATE):
|
||||
audio_duration = len(audio_array) / sample_rate
|
||||
if audio_duration < VAD_CONFIG["silence_padding"]:
|
||||
silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio_array, silence])
|
||||
return audio_array
|
||||
|
||||
def enforce_word_timing_constraints(self, words: list[dict]) -> list[dict]:
|
||||
if len(words) <= 1:
|
||||
return words
|
||||
enforced: list[dict] = []
|
||||
for i, word in enumerate(words):
|
||||
current = dict(word)
|
||||
if i < len(words) - 1:
|
||||
next_start = words[i + 1]["start"]
|
||||
if current["end"] > next_start:
|
||||
current["end"] = next_start
|
||||
enforced.append(current)
|
||||
return enforced
|
||||
|
||||
def transcribe_file(self, file_path: str, language: str = "en") -> dict:
|
||||
input_for_model: str | "object" = file_path
|
||||
try:
|
||||
audio_array, _sample_rate = librosa.load(
|
||||
file_path, sr=SAMPLE_RATE, mono=True
|
||||
)
|
||||
if len(audio_array) / float(SAMPLE_RATE) < VAD_CONFIG["silence_padding"]:
|
||||
input_for_model = self.pad_audio(audio_array, SAMPLE_RATE)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
segments, _ = self.model.transcribe(
|
||||
input_for_model,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(segment.text for segment in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": word.word,
|
||||
"start": round(float(word.start), 2),
|
||||
"end": round(float(word.end), 2),
|
||||
}
|
||||
for segment in segments
|
||||
for word in segment.words
|
||||
]
|
||||
words = self.enforce_word_timing_constraints(words)
|
||||
return {"text": text, "words": words}
|
||||
|
||||
def transcribe_vad_url_segment(
|
||||
self, file_path: str, timestamp_offset: float = 0.0, language: str = "en"
|
||||
) -> dict:
|
||||
def load_audio_via_ffmpeg(input_path: str, sample_rate: int) -> np.ndarray:
|
||||
ffmpeg_bin = shutil.which("ffmpeg") or "ffmpeg"
|
||||
cmd = [
|
||||
ffmpeg_bin,
|
||||
"-nostdin",
|
||||
"-threads",
|
||||
"1",
|
||||
"-i",
|
||||
input_path,
|
||||
"-f",
|
||||
"f32le",
|
||||
"-acodec",
|
||||
"pcm_f32le",
|
||||
"-ac",
|
||||
"1",
|
||||
"-ar",
|
||||
str(sample_rate),
|
||||
"pipe:1",
|
||||
]
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"ffmpeg failed: {e}")
|
||||
audio = np.frombuffer(proc.stdout, dtype=np.float32)
|
||||
return audio
|
||||
|
||||
def vad_segments(
|
||||
audio_array,
|
||||
sample_rate: int = SAMPLE_RATE,
|
||||
window_size: int = VAD_CONFIG["window_size"],
|
||||
) -> Generator[tuple[float, float], None, None]:
|
||||
vad_model = load_silero_vad(onnx=False)
|
||||
iterator = VADIterator(vad_model, sampling_rate=sample_rate)
|
||||
start = None
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(
|
||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
||||
)
|
||||
speech = iterator(chunk)
|
||||
if not speech:
|
||||
continue
|
||||
if "start" in speech:
|
||||
start = speech["start"]
|
||||
continue
|
||||
if "end" in speech and start is not None:
|
||||
end = speech["end"]
|
||||
yield (start / float(SAMPLE_RATE), end / float(SAMPLE_RATE))
|
||||
start = None
|
||||
iterator.reset_states()
|
||||
|
||||
audio_array = load_audio_via_ffmpeg(file_path, SAMPLE_RATE)
|
||||
|
||||
merged_batches: list[tuple[float, float]] = []
|
||||
batch_start = None
|
||||
batch_end = None
|
||||
max_duration = VAD_CONFIG["batch_max_duration"]
|
||||
for seg_start, seg_end in vad_segments(audio_array):
|
||||
if batch_start is None:
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
continue
|
||||
if seg_end - batch_start <= max_duration:
|
||||
batch_end = seg_end
|
||||
else:
|
||||
merged_batches.append((batch_start, batch_end))
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
if batch_start is not None and batch_end is not None:
|
||||
merged_batches.append((batch_start, batch_end))
|
||||
|
||||
all_text = []
|
||||
all_words = []
|
||||
for start_time, end_time in merged_batches:
|
||||
s_idx = int(start_time * SAMPLE_RATE)
|
||||
e_idx = int(end_time * SAMPLE_RATE)
|
||||
segment = audio_array[s_idx:e_idx]
|
||||
segment = self.pad_audio(segment, SAMPLE_RATE)
|
||||
with self.lock:
|
||||
segments, _ = self.model.transcribe(
|
||||
segment,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
segments = list(segments)
|
||||
text = "".join(seg.text for seg in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": w.word,
|
||||
"start": round(float(w.start) + start_time + timestamp_offset, 2),
|
||||
"end": round(float(w.end) + start_time + timestamp_offset, 2),
|
||||
}
|
||||
for seg in segments
|
||||
for w in seg.words
|
||||
]
|
||||
if text:
|
||||
all_text.append(text)
|
||||
all_words.extend(words)
|
||||
|
||||
all_words = self.enforce_word_timing_constraints(all_words)
|
||||
return {"text": " ".join(all_text), "words": all_words}
|
||||
44
gpu/self_hosted/app/services/translator.py
Normal file
44
gpu/self_hosted/app/services/translator.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import threading
|
||||
|
||||
from transformers import MarianMTModel, MarianTokenizer, pipeline
|
||||
|
||||
|
||||
class TextTranslatorService:
|
||||
"""Simple text-to-text translator using HuggingFace MarianMT models.
|
||||
|
||||
This mirrors the modal translator API shape but uses text translation only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._pipeline = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def load(self, source_language: str = "en", target_language: str = "fr"):
|
||||
# Pick a default MarianMT model pair if available; fall back to Helsinki-NLP en->fr
|
||||
model_name = self._resolve_model_name(source_language, target_language)
|
||||
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
model = MarianMTModel.from_pretrained(model_name)
|
||||
self._pipeline = pipeline("translation", model=model, tokenizer=tokenizer)
|
||||
|
||||
def _resolve_model_name(self, src: str, tgt: str) -> str:
|
||||
# Minimal mapping; extend as needed
|
||||
pair = (src.lower(), tgt.lower())
|
||||
mapping = {
|
||||
("en", "fr"): "Helsinki-NLP/opus-mt-en-fr",
|
||||
("fr", "en"): "Helsinki-NLP/opus-mt-fr-en",
|
||||
("en", "es"): "Helsinki-NLP/opus-mt-en-es",
|
||||
("es", "en"): "Helsinki-NLP/opus-mt-es-en",
|
||||
("en", "de"): "Helsinki-NLP/opus-mt-en-de",
|
||||
("de", "en"): "Helsinki-NLP/opus-mt-de-en",
|
||||
}
|
||||
return mapping.get(pair, "Helsinki-NLP/opus-mt-en-fr")
|
||||
|
||||
def translate(self, text: str, source_language: str, target_language: str) -> dict:
|
||||
if self._pipeline is None:
|
||||
self.load(source_language, target_language)
|
||||
with self._lock:
|
||||
results = self._pipeline(
|
||||
text, src_lang=source_language, tgt_lang=target_language
|
||||
)
|
||||
translated = results[0]["translation_text"] if results else ""
|
||||
return {"text": {source_language: text, target_language: translated}}
|
||||
107
gpu/self_hosted/app/utils.py
Normal file
107
gpu/self_hosted/app/utils.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Mapping
|
||||
from urllib.parse import urlparse
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .config import SUPPORTED_FILE_EXTENSIONS, UPLOADS_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NoStdStreams:
|
||||
def __init__(self):
|
||||
self.devnull = open(os.devnull, "w")
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
||||
self._stdout.flush()
|
||||
self._stderr.flush()
|
||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
||||
self.devnull.close()
|
||||
|
||||
|
||||
def ensure_dirs():
|
||||
UPLOADS_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> str:
|
||||
url_path = urlparse(url).path
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return ext
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return "mp3"
|
||||
if "audio/wav" in content_type:
|
||||
return "wav"
|
||||
if "audio/mp4" in content_type:
|
||||
return "mp4"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_uploads(audio_file_url: str) -> tuple[Path, str]:
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path: Path = UPLOADS_PATH / unique_filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return file_path, audio_suffix
|
||||
|
||||
|
||||
@contextmanager
|
||||
def download_audio_file(audio_file_url: str):
|
||||
"""Download an audio file to UPLOADS_PATH and remove it after use.
|
||||
|
||||
Yields (file_path: Path, audio_suffix: str).
|
||||
"""
|
||||
file_path, audio_suffix = download_audio_to_uploads(audio_file_url)
|
||||
try:
|
||||
yield file_path, audio_suffix
|
||||
finally:
|
||||
try:
|
||||
file_path.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
logger.error("Error deleting temporary file %s: %s", file_path, e)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cleanup_uploaded_files(file_paths: list[Path]):
|
||||
"""Ensure provided file paths are removed after use.
|
||||
|
||||
The provided list can be populated inside the context; all present entries
|
||||
at exit will be deleted.
|
||||
"""
|
||||
try:
|
||||
yield file_paths
|
||||
finally:
|
||||
for path in list(file_paths):
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
logger.error("Error deleting temporary file %s: %s", path, e)
|
||||
10
gpu/self_hosted/compose.yml
Normal file
10
gpu/self_hosted/compose.yml
Normal file
@@ -0,0 +1,10 @@
|
||||
services:
|
||||
reflector_gpu:
|
||||
build:
|
||||
context: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./cache:/root/.cache
|
||||
3
gpu/self_hosted/main.py
Normal file
3
gpu/self_hosted/main.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.factory import create_app
|
||||
|
||||
app = create_app()
|
||||
19
gpu/self_hosted/pyproject.toml
Normal file
19
gpu/self_hosted/pyproject.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
[project]
|
||||
name = "reflector-gpu"
|
||||
version = "0.1.0"
|
||||
description = "Self-hosted GPU service for speech transcription, diarization, and translation via FastAPI."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi[standard]>=0.116.1",
|
||||
"uvicorn[standard]>=0.30.0",
|
||||
"torch>=2.3.0",
|
||||
"faster-whisper>=1.1.0",
|
||||
"librosa==0.10.1",
|
||||
"numpy<2",
|
||||
"silero-vad==5.1.0",
|
||||
"transformers>=4.35.0",
|
||||
"sentencepiece",
|
||||
"pyannote.audio==3.1.0",
|
||||
"torchaudio>=2.3.0",
|
||||
]
|
||||
17
gpu/self_hosted/runserver.sh
Normal file
17
gpu/self_hosted/runserver.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
#!/bin/sh
|
||||
set -e
|
||||
|
||||
export PATH="/root/.local/bin:$PATH"
|
||||
cd /app
|
||||
|
||||
# Install Python dependencies at runtime (first run or when FORCE_SYNC=1)
|
||||
if [ ! -d "/app/.venv" ] || [ "$FORCE_SYNC" = "1" ]; then
|
||||
echo "[startup] Installing Python dependencies with uv..."
|
||||
uv sync --compile-bytecode --locked
|
||||
else
|
||||
echo "[startup] Using existing virtual environment at /app/.venv"
|
||||
fi
|
||||
|
||||
exec uv run uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
|
||||
|
||||
3013
gpu/self_hosted/uv.lock
generated
Normal file
3013
gpu/self_hosted/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
3
server/.gitignore
vendored
3
server/.gitignore
vendored
@@ -176,7 +176,8 @@ artefacts/
|
||||
audio_*.wav
|
||||
|
||||
# ignore local database
|
||||
reflector.sqlite3
|
||||
*.sqlite3
|
||||
*.db
|
||||
data/
|
||||
|
||||
dump.rdb
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
UV_LINK_MODE=copy
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
|
||||
# builder install base dependencies
|
||||
WORKDIR /tmp
|
||||
@@ -13,8 +14,8 @@ ENV PATH="/root/.local/bin/:$PATH"
|
||||
# install application dependencies
|
||||
RUN mkdir -p /app
|
||||
WORKDIR /app
|
||||
COPY pyproject.toml uv.lock /app/
|
||||
RUN touch README.md && env uv sync --compile-bytecode --locked
|
||||
COPY pyproject.toml uv.lock README.md /app/
|
||||
RUN uv sync --compile-bytecode --locked
|
||||
|
||||
# pre-download nltk packages
|
||||
RUN uv run python -c "import nltk; nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
|
||||
@@ -26,4 +27,15 @@ COPY migrations /app/migrations
|
||||
COPY reflector /app/reflector
|
||||
WORKDIR /app
|
||||
|
||||
# Create symlink for libgomp if it doesn't exist (for ARM64 compatibility)
|
||||
RUN if [ "$(uname -m)" = "aarch64" ] && [ ! -f /usr/lib/libgomp.so.1 ]; then \
|
||||
LIBGOMP_PATH=$(find /app/.venv/lib -path "*/torch.libs/libgomp*.so.*" 2>/dev/null | head -n1); \
|
||||
if [ -n "$LIBGOMP_PATH" ]; then \
|
||||
ln -sf "$LIBGOMP_PATH" /usr/lib/libgomp.so.1; \
|
||||
fi \
|
||||
fi
|
||||
|
||||
# Pre-check just to make sure the image will not fail
|
||||
RUN uv run python -c "import silero_vad.model"
|
||||
|
||||
CMD ["./runserver.sh"]
|
||||
|
||||
@@ -40,3 +40,5 @@ uv run python -c "from reflector.pipelines.main_live_pipeline import task_pipeli
|
||||
```bash
|
||||
uv run python -c "from reflector.pipelines.main_live_pipeline import pipeline_post; pipeline_post(transcript_id='TRANSCRIPT_ID')"
|
||||
```
|
||||
|
||||
.
|
||||
|
||||
118
server/asyncio_loop_analysis.md
Normal file
118
server/asyncio_loop_analysis.md
Normal file
@@ -0,0 +1,118 @@
|
||||
# AsyncIO Event Loop Analysis for test_attendee_parsing_bug.py
|
||||
|
||||
## Problem Summary
|
||||
The test passes but encounters an error during teardown where asyncpg tries to use a different/closed event loop, resulting in:
|
||||
- `RuntimeError: Task got Future attached to a different loop`
|
||||
- `RuntimeError: Event loop is closed`
|
||||
|
||||
## Root Cause Analysis
|
||||
|
||||
### 1. Multiple Event Loop Creation Points
|
||||
|
||||
The test environment creates event loops at different scopes:
|
||||
|
||||
1. **Session-scoped loop** (conftest.py:27-34):
|
||||
- Created once per test session
|
||||
- Used by session-scoped fixtures
|
||||
- Closed after all tests complete
|
||||
|
||||
2. **Function-scoped loop** (pytest-asyncio default):
|
||||
- Created for each async test function
|
||||
- This is the loop that runs the actual test
|
||||
- Closed immediately after test completes
|
||||
|
||||
3. **AsyncPG internal loop**:
|
||||
- AsyncPG connections store a reference to the loop they were created with
|
||||
- Used for connection lifecycle management
|
||||
|
||||
### 2. Event Loop Lifecycle Mismatch
|
||||
|
||||
The issue occurs because:
|
||||
|
||||
1. **Session fixture creates database connection** on session-scoped loop
|
||||
2. **Test runs** on function-scoped loop (different from session loop)
|
||||
3. **During teardown**, the session fixture tries to rollback/close using the original session loop
|
||||
4. **AsyncPG connection** still references the function-scoped loop which is now closed
|
||||
5. **Conflict**: SQLAlchemy tries to use session loop, but asyncpg Future is attached to the closed function loop
|
||||
|
||||
### 3. Configuration Issues
|
||||
|
||||
Current pytest configuration:
|
||||
- `asyncio_mode = "auto"` in pyproject.toml
|
||||
- `asyncio_default_fixture_loop_scope=session` (shown in test output)
|
||||
- `asyncio_default_test_loop_scope=function` (shown in test output)
|
||||
|
||||
This mismatch between fixture loop scope (session) and test loop scope (function) causes the problem.
|
||||
|
||||
## Solutions
|
||||
|
||||
### Option 1: Align Loop Scopes (Recommended)
|
||||
Change pytest-asyncio configuration to use consistent loop scopes:
|
||||
|
||||
```python
|
||||
# pyproject.toml
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function" # Change from session to function
|
||||
```
|
||||
|
||||
### Option 2: Use Function-Scoped Database Fixture
|
||||
Change the `session` fixture scope from session to function:
|
||||
|
||||
```python
|
||||
@pytest_asyncio.fixture # Remove scope="session"
|
||||
async def session(setup_database):
|
||||
# ... existing code ...
|
||||
```
|
||||
|
||||
### Option 3: Explicit Loop Management
|
||||
Ensure all async operations use the same loop:
|
||||
|
||||
```python
|
||||
@pytest_asyncio.fixture
|
||||
async def session(setup_database, event_loop):
|
||||
# Force using the current event loop
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=False,
|
||||
poolclass=NullPool,
|
||||
connect_args={"loop": event_loop} # Pass explicit loop
|
||||
)
|
||||
# ... rest of fixture ...
|
||||
```
|
||||
|
||||
### Option 4: Upgrade pytest-asyncio
|
||||
The current version (1.1.0) has known issues with loop management. Consider upgrading to the latest version which has better loop scope handling.
|
||||
|
||||
## Immediate Workaround
|
||||
|
||||
For the test to run cleanly without the teardown error, you can:
|
||||
|
||||
1. Add explicit cleanup in the test:
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_attendee_parsing_bug(session):
|
||||
# ... existing test code ...
|
||||
|
||||
# Explicit cleanup before fixture teardown
|
||||
await session.commit() # or await session.close()
|
||||
```
|
||||
|
||||
2. Or suppress the teardown error (not recommended for production):
|
||||
```python
|
||||
@pytest.fixture
|
||||
async def session(setup_database):
|
||||
# ... existing setup ...
|
||||
try:
|
||||
yield session
|
||||
await session.rollback()
|
||||
except RuntimeError as e:
|
||||
if "Event loop is closed" not in str(e):
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
```
|
||||
|
||||
## Recommendation
|
||||
|
||||
The cleanest solution is to align the loop scopes by setting both fixture and test loop scopes to "function" scope. This ensures each test gets its own clean event loop and avoids cross-contamination between tests.
|
||||
95
server/docs/data_retention.md
Normal file
95
server/docs/data_retention.md
Normal file
@@ -0,0 +1,95 @@
|
||||
# Data Retention and Cleanup
|
||||
|
||||
## Overview
|
||||
|
||||
For public instances of Reflector, a data retention policy is automatically enforced to delete anonymous user data after a configurable period (default: 7 days). This ensures compliance with privacy expectations and prevents unbounded storage growth.
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
- `PUBLIC_MODE` (bool): Must be set to `true` to enable automatic cleanup
|
||||
- `PUBLIC_DATA_RETENTION_DAYS` (int): Number of days to retain anonymous data (default: 7)
|
||||
|
||||
### What Gets Deleted
|
||||
|
||||
When data reaches the retention period, the following items are automatically removed:
|
||||
|
||||
1. **Transcripts** from anonymous users (where `user_id` is NULL):
|
||||
- Database records
|
||||
- Local files (audio.wav, audio.mp3, audio.json waveform)
|
||||
- Storage files (cloud storage if configured)
|
||||
|
||||
## Automatic Cleanup
|
||||
|
||||
### Celery Beat Schedule
|
||||
|
||||
When `PUBLIC_MODE=true`, a Celery beat task runs daily at 3 AM to clean up old data:
|
||||
|
||||
```python
|
||||
# Automatically scheduled when PUBLIC_MODE=true
|
||||
"cleanup_old_public_data": {
|
||||
"task": "reflector.worker.cleanup.cleanup_old_public_data",
|
||||
"schedule": crontab(hour=3, minute=0), # Daily at 3 AM
|
||||
}
|
||||
```
|
||||
|
||||
### Running the Worker
|
||||
|
||||
Ensure both Celery worker and beat scheduler are running:
|
||||
|
||||
```bash
|
||||
# Start Celery worker
|
||||
uv run celery -A reflector.worker.app worker --loglevel=info
|
||||
|
||||
# Start Celery beat scheduler (in another terminal)
|
||||
uv run celery -A reflector.worker.app beat
|
||||
```
|
||||
|
||||
## Manual Cleanup
|
||||
|
||||
For testing or manual intervention, use the cleanup tool:
|
||||
|
||||
```bash
|
||||
# Delete data older than 7 days (default)
|
||||
uv run python -m reflector.tools.cleanup_old_data
|
||||
|
||||
# Delete data older than 30 days
|
||||
uv run python -m reflector.tools.cleanup_old_data --days 30
|
||||
```
|
||||
|
||||
Note: The manual tool uses the same implementation as the Celery worker task to ensure consistency.
|
||||
|
||||
## Important Notes
|
||||
|
||||
1. **User Data Deletion**: Only anonymous data (where `user_id` is NULL) is deleted. Authenticated user data is preserved.
|
||||
|
||||
2. **Storage Cleanup**: The system properly cleans up both local files and cloud storage when configured.
|
||||
|
||||
3. **Error Handling**: If individual deletions fail, the cleanup continues and logs errors. Failed deletions are reported in the task output.
|
||||
|
||||
4. **Public Instance Only**: The automatic cleanup task only runs when `PUBLIC_MODE=true` to prevent accidental data loss in private deployments.
|
||||
|
||||
## Testing
|
||||
|
||||
Run the cleanup tests:
|
||||
|
||||
```bash
|
||||
uv run pytest tests/test_cleanup.py -v
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
Check Celery logs for cleanup task execution:
|
||||
|
||||
```bash
|
||||
# Look for cleanup task logs
|
||||
grep "cleanup_old_public_data" celery.log
|
||||
grep "Starting cleanup of old public data" celery.log
|
||||
```
|
||||
|
||||
Task statistics are logged after each run:
|
||||
- Number of transcripts deleted
|
||||
- Number of meetings deleted
|
||||
- Number of orphaned recordings deleted
|
||||
- Any errors encountered
|
||||
194
server/docs/gpu/api-transcription.md
Normal file
194
server/docs/gpu/api-transcription.md
Normal file
@@ -0,0 +1,194 @@
|
||||
## Reflector GPU Transcription API (Specification)
|
||||
|
||||
This document defines the Reflector GPU transcription API that all implementations must adhere to. Current implementations include NVIDIA Parakeet (NeMo) and Whisper (faster-whisper), both deployed on Modal.com. The API surface and response shapes are OpenAI/Whisper-compatible, so clients can switch implementations by changing only the base URL.
|
||||
|
||||
### Base URL and Authentication
|
||||
|
||||
- Example base URLs (Modal web endpoints):
|
||||
|
||||
- Parakeet: `https://<account>--reflector-transcriber-parakeet-web.modal.run`
|
||||
- Whisper: `https://<account>--reflector-transcriber-web.modal.run`
|
||||
|
||||
- All endpoints are served under `/v1` and require a Bearer token:
|
||||
|
||||
```
|
||||
Authorization: Bearer <REFLECTOR_GPU_APIKEY>
|
||||
```
|
||||
|
||||
Note: To switch implementations, deploy the desired variant and point `TRANSCRIPT_URL` to its base URL. The API is identical.
|
||||
|
||||
### Supported file types
|
||||
|
||||
`mp3, mp4, mpeg, mpga, m4a, wav, webm`
|
||||
|
||||
### Models and languages
|
||||
|
||||
- Parakeet (NVIDIA NeMo): default `nvidia/parakeet-tdt-0.6b-v2`
|
||||
- Language support: only `en`. Other languages return HTTP 400.
|
||||
- Whisper (faster-whisper): default `large-v2` (or deployment-specific)
|
||||
- Language support: multilingual (per Whisper model capabilities).
|
||||
|
||||
Note: The `model` parameter is accepted by all implementations for interface parity. Some backends may treat it as informational.
|
||||
|
||||
### Endpoints
|
||||
|
||||
#### POST /v1/audio/transcriptions
|
||||
|
||||
Transcribe one or more uploaded audio files.
|
||||
|
||||
Request: multipart/form-data
|
||||
|
||||
- `file` (File) — optional. Single file to transcribe.
|
||||
- `files` (File[]) — optional. One or more files to transcribe.
|
||||
- `model` (string) — optional. Defaults to the implementation-specific model (see above).
|
||||
- `language` (string) — optional, defaults to `en`.
|
||||
- Parakeet: only `en` is accepted; other values return HTTP 400
|
||||
- Whisper: model-dependent; typically multilingual
|
||||
- `batch` (boolean) — optional, defaults to `false`.
|
||||
|
||||
Notes:
|
||||
|
||||
- Provide either `file` or `files`, not both. If neither is provided, HTTP 400.
|
||||
- `batch` requires `files`; using `batch=true` without `files` returns HTTP 400.
|
||||
- Response shape for multiple files is the same regardless of `batch`.
|
||||
- Files sent to this endpoint are processed in a single pass (no VAD/chunking). This is intended for short clips (roughly ≤ 30s; depends on GPU memory/model). For longer audio, prefer `/v1/audio/transcriptions-from-url` which supports VAD-based chunking.
|
||||
|
||||
Responses
|
||||
|
||||
Single file response:
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "transcribed text",
|
||||
"words": [
|
||||
{ "word": "hello", "start": 0.0, "end": 0.5 },
|
||||
{ "word": "world", "start": 0.5, "end": 1.0 }
|
||||
],
|
||||
"filename": "audio.mp3"
|
||||
}
|
||||
```
|
||||
|
||||
Multiple files response:
|
||||
|
||||
```json
|
||||
{
|
||||
"results": [
|
||||
{"filename": "a1.mp3", "text": "...", "words": [...]},
|
||||
{"filename": "a2.mp3", "text": "...", "words": [...]}]
|
||||
}
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- Word objects always include keys: `word`, `start`, `end`.
|
||||
- Some implementations may include a trailing space in `word` to match Whisper tokenization behavior; clients should trim if needed.
|
||||
|
||||
Example curl (single file):
|
||||
|
||||
```bash
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
|
||||
-F "file=@/path/to/audio.mp3" \
|
||||
-F "language=en" \
|
||||
"$BASE_URL/v1/audio/transcriptions"
|
||||
```
|
||||
|
||||
Example curl (multiple files, batch):
|
||||
|
||||
```bash
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
|
||||
-F "files=@/path/a1.mp3" -F "files=@/path/a2.mp3" \
|
||||
-F "batch=true" -F "language=en" \
|
||||
"$BASE_URL/v1/audio/transcriptions"
|
||||
```
|
||||
|
||||
#### POST /v1/audio/transcriptions-from-url
|
||||
|
||||
Transcribe a single remote audio file by URL.
|
||||
|
||||
Request: application/json
|
||||
|
||||
Body parameters:
|
||||
|
||||
- `audio_file_url` (string) — required. URL of the audio file to transcribe.
|
||||
- `model` (string) — optional. Defaults to the implementation-specific model (see above).
|
||||
- `language` (string) — optional, defaults to `en`. Parakeet only accepts `en`.
|
||||
- `timestamp_offset` (number) — optional, defaults to `0.0`. Added to each word's `start`/`end` in the response.
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_file_url": "https://example.com/audio.mp3",
|
||||
"model": "nvidia/parakeet-tdt-0.6b-v2",
|
||||
"language": "en",
|
||||
"timestamp_offset": 0.0
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "transcribed text",
|
||||
"words": [
|
||||
{ "word": "hello", "start": 10.0, "end": 10.5 },
|
||||
{ "word": "world", "start": 10.5, "end": 11.0 }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- `timestamp_offset` is added to each word’s `start`/`end` in the response.
|
||||
- Implementations may perform VAD-based chunking and batching for long-form audio; word timings are adjusted accordingly.
|
||||
|
||||
Example curl:
|
||||
|
||||
```bash
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"audio_file_url": "https://example.com/audio.mp3",
|
||||
"language": "en",
|
||||
"timestamp_offset": 0
|
||||
}' \
|
||||
"$BASE_URL/v1/audio/transcriptions-from-url"
|
||||
```
|
||||
|
||||
### Error handling
|
||||
|
||||
- 400 Bad Request
|
||||
- Parakeet: `language` other than `en`
|
||||
- Missing required parameters (`file`/`files` for upload; `audio_file_url` for URL endpoint)
|
||||
- Unsupported file extension
|
||||
- 401 Unauthorized
|
||||
- Missing or invalid Bearer token
|
||||
- 404 Not Found
|
||||
- `audio_file_url` does not exist
|
||||
|
||||
### Implementation details
|
||||
|
||||
- GPUs: A10G for small-file/live, L40S for large-file URL transcription (subject to deployment)
|
||||
- VAD chunking and segment batching; word timings adjusted and overlapping ends constrained
|
||||
- Pads very short segments (< 0.5s) to avoid model crashes on some backends
|
||||
|
||||
### Server configuration (Reflector API)
|
||||
|
||||
Set the Reflector server to use the Modal backend and point `TRANSCRIPT_URL` to your chosen deployment:
|
||||
|
||||
```
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://<account>--reflector-transcriber-parakeet-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=<REFLECTOR_GPU_APIKEY>
|
||||
```
|
||||
|
||||
### Conformance tests
|
||||
|
||||
Use the pytest-based conformance tests to validate any new implementation (including self-hosted) against this spec:
|
||||
|
||||
```
|
||||
TRANSCRIPT_URL=https://<your-deployment-base> \
|
||||
TRANSCRIPT_MODAL_API_KEY=your-api-key \
|
||||
uv run -m pytest -m model_api --no-cov server/tests/test_model_api_transcript.py
|
||||
```
|
||||
212
server/docs/webhook.md
Normal file
212
server/docs/webhook.md
Normal file
@@ -0,0 +1,212 @@
|
||||
# Reflector Webhook Documentation
|
||||
|
||||
## Overview
|
||||
|
||||
Reflector supports webhook notifications to notify external systems when transcript processing is completed. Webhooks can be configured per room and are triggered automatically after a transcript is successfully processed.
|
||||
|
||||
## Configuration
|
||||
|
||||
Webhooks are configured at the room level with two fields:
|
||||
- `webhook_url`: The HTTPS endpoint to receive webhook notifications
|
||||
- `webhook_secret`: Optional secret key for HMAC signature verification (auto-generated if not provided)
|
||||
|
||||
## Events
|
||||
|
||||
### `transcript.completed`
|
||||
|
||||
Triggered when a transcript has been fully processed, including transcription, diarization, summarization, and topic detection.
|
||||
|
||||
### `test`
|
||||
|
||||
A test event that can be triggered manually to verify webhook configuration.
|
||||
|
||||
## Webhook Request Format
|
||||
|
||||
### Headers
|
||||
|
||||
All webhook requests include the following headers:
|
||||
|
||||
| Header | Description | Example |
|
||||
|--------|-------------|---------|
|
||||
| `Content-Type` | Always `application/json` | `application/json` |
|
||||
| `User-Agent` | Identifies Reflector as the source | `Reflector-Webhook/1.0` |
|
||||
| `X-Webhook-Event` | The event type | `transcript.completed` or `test` |
|
||||
| `X-Webhook-Retry` | Current retry attempt number | `0`, `1`, `2`... |
|
||||
| `X-Webhook-Signature` | HMAC signature (if secret configured) | `t=1735306800,v1=abc123...` |
|
||||
|
||||
### Signature Verification
|
||||
|
||||
If a webhook secret is configured, Reflector includes an HMAC-SHA256 signature in the `X-Webhook-Signature` header to verify the webhook authenticity.
|
||||
|
||||
The signature format is: `t={timestamp},v1={signature}`
|
||||
|
||||
To verify the signature:
|
||||
1. Extract the timestamp and signature from the header
|
||||
2. Create the signed payload: `{timestamp}.{request_body}`
|
||||
3. Compute HMAC-SHA256 of the signed payload using your webhook secret
|
||||
4. Compare the computed signature with the received signature
|
||||
|
||||
Example verification (Python):
|
||||
```python
|
||||
import hmac
|
||||
import hashlib
|
||||
|
||||
def verify_webhook_signature(payload: bytes, signature_header: str, secret: str) -> bool:
|
||||
# Parse header: "t=1735306800,v1=abc123..."
|
||||
parts = dict(part.split("=") for part in signature_header.split(","))
|
||||
timestamp = parts["t"]
|
||||
received_signature = parts["v1"]
|
||||
|
||||
# Create signed payload
|
||||
signed_payload = f"{timestamp}.{payload.decode('utf-8')}"
|
||||
|
||||
# Compute expected signature
|
||||
expected_signature = hmac.new(
|
||||
secret.encode("utf-8"),
|
||||
signed_payload.encode("utf-8"),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Compare signatures
|
||||
return hmac.compare_digest(expected_signature, received_signature)
|
||||
```
|
||||
|
||||
## Event Payloads
|
||||
|
||||
### `transcript.completed` Event
|
||||
|
||||
This event includes a convenient URL for accessing the transcript:
|
||||
- `frontend_url`: Direct link to view the transcript in the web interface
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "transcript.completed",
|
||||
"event_id": "transcript.completed-abc-123-def-456",
|
||||
"timestamp": "2025-08-27T12:34:56.789012Z",
|
||||
"transcript": {
|
||||
"id": "abc-123-def-456",
|
||||
"room_id": "room-789",
|
||||
"created_at": "2025-08-27T12:00:00Z",
|
||||
"duration": 1800.5,
|
||||
"title": "Q3 Product Planning Meeting",
|
||||
"short_summary": "Team discussed Q3 product roadmap, prioritizing mobile app features and API improvements.",
|
||||
"long_summary": "The product team met to finalize the Q3 roadmap. Key decisions included...",
|
||||
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\n<v Speaker 1>Welcome everyone to today's meeting...",
|
||||
"topics": [
|
||||
{
|
||||
"title": "Introduction and Agenda",
|
||||
"summary": "Meeting kickoff with agenda review",
|
||||
"timestamp": 0.0,
|
||||
"duration": 120.0,
|
||||
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\n<v Speaker 1>Welcome everyone..."
|
||||
},
|
||||
{
|
||||
"title": "Mobile App Features Discussion",
|
||||
"summary": "Team reviewed proposed mobile app features for Q3",
|
||||
"timestamp": 120.0,
|
||||
"duration": 600.0,
|
||||
"webvtt": "WEBVTT\n\n00:02:00.000 --> 00:02:10.000\n<v Speaker 2>Let's talk about the mobile app..."
|
||||
}
|
||||
],
|
||||
"participants": [
|
||||
{
|
||||
"id": "participant-1",
|
||||
"name": "John Doe",
|
||||
"speaker": "Speaker 1"
|
||||
},
|
||||
{
|
||||
"id": "participant-2",
|
||||
"name": "Jane Smith",
|
||||
"speaker": "Speaker 2"
|
||||
}
|
||||
],
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"status": "completed",
|
||||
"frontend_url": "https://app.reflector.com/transcripts/abc-123-def-456"
|
||||
},
|
||||
"room": {
|
||||
"id": "room-789",
|
||||
"name": "Product Team Room"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### `test` Event
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "test",
|
||||
"event_id": "test.2025-08-27T12:34:56.789012Z",
|
||||
"timestamp": "2025-08-27T12:34:56.789012Z",
|
||||
"message": "This is a test webhook from Reflector",
|
||||
"room": {
|
||||
"id": "room-789",
|
||||
"name": "Product Team Room"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Retry Policy
|
||||
|
||||
Webhooks are delivered with automatic retry logic to handle transient failures. When a webhook delivery fails due to server errors or network issues, Reflector will automatically retry the delivery multiple times over an extended period.
|
||||
|
||||
### Retry Mechanism
|
||||
|
||||
Reflector implements an exponential backoff strategy for webhook retries:
|
||||
|
||||
- **Initial retry delay**: 60 seconds after the first failure
|
||||
- **Exponential backoff**: Each subsequent retry waits approximately twice as long as the previous one
|
||||
- **Maximum retry interval**: 1 hour (backoff is capped at this duration)
|
||||
- **Maximum retry attempts**: 30 attempts total
|
||||
- **Total retry duration**: Retries continue for approximately 24 hours
|
||||
|
||||
### How Retries Work
|
||||
|
||||
When a webhook fails, Reflector will:
|
||||
1. Wait 60 seconds, then retry (attempt #1)
|
||||
2. If it fails again, wait ~2 minutes, then retry (attempt #2)
|
||||
3. Continue doubling the wait time up to a maximum of 1 hour between attempts
|
||||
4. Keep retrying at 1-hour intervals until successful or 30 attempts are exhausted
|
||||
|
||||
The `X-Webhook-Retry` header indicates the current retry attempt number (0 for the initial attempt, 1 for first retry, etc.), allowing your endpoint to track retry attempts.
|
||||
|
||||
### Retry Behavior by HTTP Status Code
|
||||
|
||||
| Status Code | Behavior |
|
||||
|-------------|----------|
|
||||
| 2xx (Success) | No retry, webhook marked as delivered |
|
||||
| 4xx (Client Error) | No retry, request is considered permanently failed |
|
||||
| 5xx (Server Error) | Automatic retry with exponential backoff |
|
||||
| Network/Timeout Error | Automatic retry with exponential backoff |
|
||||
|
||||
**Important Notes:**
|
||||
- Webhooks timeout after 30 seconds. If your endpoint takes longer to respond, it will be considered a timeout error and retried.
|
||||
- During the retry period (~24 hours), you may receive the same webhook multiple times if your endpoint experiences intermittent failures.
|
||||
- There is no mechanism to manually retry failed webhooks after the retry period expires.
|
||||
|
||||
## Testing Webhooks
|
||||
|
||||
You can test your webhook configuration before processing transcripts:
|
||||
|
||||
```http
|
||||
POST /v1/rooms/{room_id}/webhook/test
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"status_code": 200,
|
||||
"message": "Webhook test successful",
|
||||
"response_preview": "OK"
|
||||
}
|
||||
```
|
||||
|
||||
Or in case of failure:
|
||||
```json
|
||||
{
|
||||
"success": false,
|
||||
"error": "Webhook request timed out (10 seconds)"
|
||||
}
|
||||
```
|
||||
@@ -1,86 +0,0 @@
|
||||
# Reflector GPU implementation - Transcription and LLM
|
||||
|
||||
This repository hold an API for the GPU implementation of the Reflector API service,
|
||||
and use [Modal.com](https://modal.com)
|
||||
|
||||
- `reflector_diarizer.py` - Diarization API
|
||||
- `reflector_transcriber.py` - Transcription API
|
||||
- `reflector_translator.py` - Translation API
|
||||
|
||||
## Modal.com deployment
|
||||
|
||||
Create a modal secret, and name it `reflector-gpu`.
|
||||
It should contain an `REFLECTOR_APIKEY` environment variable with a value.
|
||||
|
||||
The deployment is done using [Modal.com](https://modal.com) service.
|
||||
|
||||
```
|
||||
$ modal deploy reflector_transcriber.py
|
||||
...
|
||||
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run
|
||||
|
||||
$ modal deploy reflector_llm.py
|
||||
...
|
||||
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
|
||||
```
|
||||
|
||||
Then in your reflector api configuration `.env`, you can set these keys:
|
||||
|
||||
```
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
|
||||
DIARIZATION_BACKEND=modal
|
||||
DIARIZATION_URL=https://xxxx--reflector-diarizer-web.modal.run
|
||||
DIARIZATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
|
||||
TRANSLATION_BACKEND=modal
|
||||
TRANSLATION_URL=https://xxxx--reflector-translator-web.modal.run
|
||||
TRANSLATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
Authentication must be passed with the `Authorization` header, using the `bearer` scheme.
|
||||
|
||||
```
|
||||
Authorization: bearer <REFLECTOR_APIKEY>
|
||||
```
|
||||
|
||||
### LLM
|
||||
|
||||
`POST /llm`
|
||||
|
||||
**request**
|
||||
```
|
||||
{
|
||||
"prompt": "xxx"
|
||||
}
|
||||
```
|
||||
|
||||
**response**
|
||||
```
|
||||
{
|
||||
"text": "xxx completed"
|
||||
}
|
||||
```
|
||||
|
||||
### Transcription
|
||||
|
||||
`POST /transcribe`
|
||||
|
||||
**request** (multipart/form-data)
|
||||
|
||||
- `file` - audio file
|
||||
- `language` - language code (e.g. `en`)
|
||||
|
||||
**response**
|
||||
```
|
||||
{
|
||||
"text": "xxx",
|
||||
"words": [
|
||||
{"text": "xxx", "start": 0.0, "end": 1.0}
|
||||
]
|
||||
}
|
||||
```
|
||||
@@ -1,187 +0,0 @@
|
||||
"""
|
||||
Reflector GPU backend - diarizer
|
||||
===================================
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import modal.gpu
|
||||
from modal import App, Image, Secret, asgi_app, enter, method
|
||||
from pydantic import BaseModel
|
||||
|
||||
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
|
||||
MODEL_DIR = "/root/diarization_models"
|
||||
app = App(name="reflector-diarizer")
|
||||
|
||||
|
||||
def migrate_cache_llm():
|
||||
"""
|
||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
||||
Migrating your old cache. This is a one-time only operation. You can
|
||||
interrupt this and resume the migration later on by calling
|
||||
`transformers.utils.move_cache()`.
|
||||
"""
|
||||
from transformers.utils.hub import move_cache
|
||||
|
||||
print("Moving LLM cache")
|
||||
move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
|
||||
print("LLM cache moved")
|
||||
|
||||
|
||||
def download_pyannote_audio():
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
Pipeline.from_pretrained(
|
||||
PYANNOTE_MODEL_NAME,
|
||||
cache_dir=MODEL_DIR,
|
||||
use_auth_token=os.environ["HF_TOKEN"],
|
||||
)
|
||||
|
||||
|
||||
diarizer_image = (
|
||||
Image.debian_slim(python_version="3.10.8")
|
||||
.pip_install(
|
||||
"pyannote.audio==3.1.0",
|
||||
"requests",
|
||||
"onnx",
|
||||
"torchaudio",
|
||||
"onnxruntime-gpu",
|
||||
"torch==2.0.0",
|
||||
"transformers==4.34.0",
|
||||
"sentencepiece",
|
||||
"protobuf",
|
||||
"numpy",
|
||||
"huggingface_hub",
|
||||
"hf-transfer",
|
||||
)
|
||||
.run_function(
|
||||
download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")]
|
||||
)
|
||||
.run_function(migrate_cache_llm)
|
||||
.env(
|
||||
{
|
||||
"LD_LIBRARY_PATH": (
|
||||
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
|
||||
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu=modal.gpu.A100(size="40GB"),
|
||||
timeout=60 * 30,
|
||||
scaledown_window=60,
|
||||
allow_concurrent_inputs=1,
|
||||
image=diarizer_image,
|
||||
)
|
||||
class Diarizer:
|
||||
@enter()
|
||||
def enter(self):
|
||||
import torch
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.diarization_pipeline = Pipeline.from_pretrained(
|
||||
PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR
|
||||
)
|
||||
self.diarization_pipeline.to(torch.device(self.device))
|
||||
|
||||
@method()
|
||||
def diarize(self, audio_data: str, audio_suffix: str, timestamp: float):
|
||||
import tempfile
|
||||
|
||||
import torchaudio
|
||||
|
||||
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
|
||||
fp.write(audio_data)
|
||||
|
||||
print("Diarizing audio")
|
||||
waveform, sample_rate = torchaudio.load(fp.name)
|
||||
diarization = self.diarization_pipeline(
|
||||
{"waveform": waveform, "sample_rate": sample_rate}
|
||||
)
|
||||
|
||||
words = []
|
||||
for diarization_segment, _, speaker in diarization.itertracks(
|
||||
yield_label=True
|
||||
):
|
||||
words.append(
|
||||
{
|
||||
"start": round(timestamp + diarization_segment.start, 3),
|
||||
"end": round(timestamp + diarization_segment.end, 3),
|
||||
"speaker": int(speaker[-2:]),
|
||||
}
|
||||
)
|
||||
print("Diarization complete")
|
||||
return {"diarization": words}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Web API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.function(
|
||||
timeout=60 * 10,
|
||||
scaledown_window=60 * 3,
|
||||
allow_concurrent_inputs=40,
|
||||
secrets=[
|
||||
Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
image=diarizer_image,
|
||||
)
|
||||
@asgi_app()
|
||||
def web():
|
||||
import requests
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
diarizerstub = Diarizer()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
def validate_audio_file(audio_file_url: str):
|
||||
# Check if the audio file exists
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail="The audio file does not exist.",
|
||||
)
|
||||
|
||||
class DiarizationResponse(BaseModel):
|
||||
result: dict
|
||||
|
||||
@app.post(
|
||||
"/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]
|
||||
)
|
||||
def diarize(
|
||||
audio_file_url: str, timestamp: float = 0.0
|
||||
) -> HTTPException | DiarizationResponse:
|
||||
# Currently the uploaded files are in mp3 format
|
||||
audio_suffix = "mp3"
|
||||
|
||||
print("Downloading audio file")
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
print("Audio file downloaded successfully")
|
||||
|
||||
func = diarizerstub.diarize.spawn(
|
||||
audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
|
||||
return app
|
||||
@@ -1,161 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
import modal
|
||||
from pydantic import BaseModel
|
||||
|
||||
MODELS_DIR = "/models"
|
||||
|
||||
MODEL_NAME = "large-v2"
|
||||
MODEL_COMPUTE_TYPE: str = "float16"
|
||||
MODEL_NUM_WORKERS: int = 1
|
||||
|
||||
MINUTES = 60 # seconds
|
||||
|
||||
volume = modal.Volume.from_name("models", create_if_missing=True)
|
||||
|
||||
app = modal.App("reflector-transcriber")
|
||||
|
||||
|
||||
def download_model():
|
||||
from faster_whisper import download_model
|
||||
|
||||
volume.reload()
|
||||
|
||||
download_model(MODEL_NAME, cache_dir=MODELS_DIR)
|
||||
|
||||
volume.commit()
|
||||
|
||||
|
||||
image = (
|
||||
modal.Image.debian_slim(python_version="3.12")
|
||||
.pip_install(
|
||||
"huggingface_hub==0.27.1",
|
||||
"hf-transfer==0.1.9",
|
||||
"torch==2.5.1",
|
||||
"faster-whisper==1.1.1",
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||
"LD_LIBRARY_PATH": (
|
||||
"/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/:"
|
||||
"/opt/conda/lib/python3.12/site-packages/nvidia/cublas/lib/"
|
||||
),
|
||||
}
|
||||
)
|
||||
.run_function(download_model, volumes={MODELS_DIR: volume})
|
||||
)
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=5 * MINUTES,
|
||||
scaledown_window=5 * MINUTES,
|
||||
allow_concurrent_inputs=6,
|
||||
image=image,
|
||||
volumes={MODELS_DIR: volume},
|
||||
)
|
||||
class Transcriber:
|
||||
@modal.enter()
|
||||
def enter(self):
|
||||
import faster_whisper
|
||||
import torch
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=MODEL_COMPUTE_TYPE,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=MODELS_DIR,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
audio_data: str,
|
||||
audio_suffix: str,
|
||||
language: str,
|
||||
):
|
||||
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
|
||||
fp.write(audio_data)
|
||||
|
||||
with self.lock:
|
||||
segments, _ = self.model.transcribe(
|
||||
fp.name,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(segment.text for segment in segments)
|
||||
words = [
|
||||
{"word": word.word, "start": word.start, "end": word.end}
|
||||
for segment in segments
|
||||
for word in segment.words
|
||||
]
|
||||
|
||||
return {"text": text, "words": words}
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60,
|
||||
timeout=60,
|
||||
allow_concurrent_inputs=40,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={MODELS_DIR: volume},
|
||||
)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from typing_extensions import Annotated
|
||||
|
||||
transcriber = Transcriber()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
supported_file_types = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class TranscriptResponse(BaseModel):
|
||||
result: dict
|
||||
|
||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe(
|
||||
file: UploadFile,
|
||||
model: str = "whisper-1",
|
||||
language: Annotated[str, Body(...)] = "en",
|
||||
) -> TranscriptResponse:
|
||||
audio_data = file.file.read()
|
||||
audio_suffix = file.filename.split(".")[-1]
|
||||
assert audio_suffix in supported_file_types
|
||||
|
||||
func = transcriber.transcribe_segment.spawn(
|
||||
audio_data=audio_data,
|
||||
audio_suffix=audio_suffix,
|
||||
language=language,
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
|
||||
return app
|
||||
583
server/migration.md
Normal file
583
server/migration.md
Normal file
@@ -0,0 +1,583 @@
|
||||
# Celery to TaskIQ Migration Guide
|
||||
|
||||
## Executive Summary
|
||||
|
||||
This document outlines the migration path from Celery to TaskIQ for the Reflector project. TaskIQ is a modern, async-first distributed task queue that provides similar functionality to Celery while being designed specifically for async Python applications.
|
||||
|
||||
## Current Celery Usage Analysis
|
||||
|
||||
### Key Patterns in Use
|
||||
1. **Task Decorators**: `@shared_task`, `@asynctask`, `@with_session` decorators
|
||||
2. **Task Invocation**: `.delay()`, `.si()` for signatures
|
||||
3. **Workflow Patterns**: `chain()`, `group()`, `chord()` for complex pipelines
|
||||
4. **Scheduled Tasks**: Celery Beat with crontab and periodic schedules
|
||||
5. **Session Management**: Custom `@with_session` and `@with_session_and_transcript` decorators
|
||||
6. **Retry Logic**: Auto-retry with exponential backoff
|
||||
7. **Redis Backend**: Using Redis for broker and result backend
|
||||
|
||||
### Critical Files to Migrate
|
||||
- `reflector/worker/app.py` - Celery app configuration and beat schedule
|
||||
- `reflector/worker/session_decorator.py` - Session management decorators
|
||||
- `reflector/pipelines/main_file_pipeline.py` - File processing pipeline
|
||||
- `reflector/pipelines/main_live_pipeline.py` - Live streaming pipeline (10 tasks)
|
||||
- `reflector/worker/process.py` - Background processing tasks
|
||||
- `reflector/worker/ics_sync.py` - Calendar sync tasks
|
||||
- `reflector/worker/cleanup.py` - Cleanup tasks
|
||||
- `reflector/worker/webhook.py` - Webhook notifications
|
||||
|
||||
## TaskIQ Architecture Mapping
|
||||
|
||||
### 1. Installation
|
||||
|
||||
```bash
|
||||
# Remove Celery dependencies
|
||||
uv remove celery flower
|
||||
|
||||
# Install TaskIQ with Redis support
|
||||
uv add taskiq taskiq-redis taskiq-pipelines
|
||||
```
|
||||
|
||||
### 2. Broker Configuration
|
||||
|
||||
#### Current (Celery)
|
||||
```python
|
||||
# reflector/worker/app.py
|
||||
from celery import Celery
|
||||
|
||||
app = Celery(
|
||||
"reflector",
|
||||
broker=settings.CELERY_BROKER_URL,
|
||||
backend=settings.CELERY_RESULT_BACKEND,
|
||||
include=[...],
|
||||
)
|
||||
```
|
||||
|
||||
#### New (TaskIQ)
|
||||
```python
|
||||
# reflector/worker/broker.py
|
||||
from taskiq_redis import RedisAsyncResultBackend, RedisStreamBroker
|
||||
from taskiq import PipelineMiddleware, SimpleRetryMiddleware
|
||||
|
||||
result_backend = RedisAsyncResultBackend(
|
||||
redis_url=settings.REDIS_URL,
|
||||
result_ex_time=86400, # 24 hours
|
||||
)
|
||||
|
||||
broker = RedisStreamBroker(
|
||||
url=settings.REDIS_URL,
|
||||
max_connection_pool_size=10,
|
||||
).with_result_backend(result_backend).with_middlewares(
|
||||
PipelineMiddleware(), # For chain/group/chord support
|
||||
SimpleRetryMiddleware(default_retry_count=3),
|
||||
)
|
||||
|
||||
# For testing environment
|
||||
if os.environ.get("ENVIRONMENT") == "pytest":
|
||||
from taskiq import InMemoryBroker
|
||||
broker = InMemoryBroker(await_inplace=True)
|
||||
```
|
||||
|
||||
### 3. Task Definition Migration
|
||||
|
||||
#### Current (Celery)
|
||||
```python
|
||||
@shared_task
|
||||
@asynctask
|
||||
@with_session
|
||||
async def task_pipeline_file_process(session: AsyncSession, transcript_id: str):
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
await pipeline.process()
|
||||
```
|
||||
|
||||
#### New (TaskIQ)
|
||||
```python
|
||||
from taskiq import TaskiqDepends
|
||||
from reflector.worker.broker import broker
|
||||
from reflector.worker.dependencies import get_db_session
|
||||
|
||||
@broker.task
|
||||
async def task_pipeline_file_process(transcript_id: str):
|
||||
# Use get_session for proper test mocking
|
||||
async for session in get_session():
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
await pipeline.process()
|
||||
```
|
||||
|
||||
### 4. Session Management
|
||||
|
||||
#### Current Session Decorators (Keep Using These!)
|
||||
```python
|
||||
# reflector/worker/session_decorator.py
|
||||
def with_session(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with get_session_context() as session:
|
||||
return await func(session, *args, **kwargs)
|
||||
return wrapper
|
||||
```
|
||||
|
||||
#### Session Management Strategy
|
||||
|
||||
**⚠️ CRITICAL**: The key insight is to maintain consistent session management patterns:
|
||||
|
||||
1. **For Worker Tasks**: Continue using `@with_session` decorator pattern
|
||||
2. **For FastAPI endpoints**: Use `get_session` dependency injection
|
||||
3. **Never use `get_session_factory()` directly** in application code
|
||||
|
||||
```python
|
||||
# APPROACH 1: Simple migration keeping decorator pattern
|
||||
from reflector.worker.session_decorator import with_session
|
||||
|
||||
@taskiq_broker.task
|
||||
@with_session
|
||||
async def task_pipeline_file_process(session, *, transcript_id: str):
|
||||
# Session is provided by decorator, just like Celery version
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
await pipeline.process()
|
||||
|
||||
# APPROACH 2: For test compatibility without decorator
|
||||
from reflector.db import get_session
|
||||
|
||||
@taskiq_broker.task
|
||||
async def task_pipeline_file_process(transcript_id: str):
|
||||
# Use get_session which is mocked in tests
|
||||
async for session in get_session():
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
await pipeline.process()
|
||||
|
||||
# APPROACH 3: Future - TaskIQ dependency injection (after full migration)
|
||||
from taskiq import TaskiqDepends
|
||||
|
||||
async def get_session_context():
|
||||
"""Context manager version of get_session for consistency"""
|
||||
async for session in get_session():
|
||||
yield session
|
||||
|
||||
@taskiq_broker.task
|
||||
async def task_pipeline_file_process(
|
||||
transcript_id: str,
|
||||
session: AsyncSession = TaskiqDepends(get_session_context)
|
||||
):
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
await pipeline.process()
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- `@with_session` decorator works with TaskIQ tasks (remove `@asynctask`, keep `@with_session`)
|
||||
- For testing: `get_session()` from `reflector.db` is properly mocked
|
||||
- Never call `get_session_factory()` directly - always use the abstractions
|
||||
|
||||
### 5. Task Invocation
|
||||
|
||||
#### Current (Celery)
|
||||
```python
|
||||
# Simple async execution
|
||||
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
||||
|
||||
# With signature for chaining
|
||||
task_cleanup_consent.si(transcript_id=transcript_id)
|
||||
```
|
||||
|
||||
#### New (TaskIQ)
|
||||
```python
|
||||
# Simple async execution
|
||||
await task_pipeline_file_process.kiq(transcript_id=transcript.id)
|
||||
|
||||
# With kicker for advanced configuration
|
||||
await task_cleanup_consent.kicker().with_labels(
|
||||
priority="high"
|
||||
).kiq(transcript_id=transcript_id)
|
||||
```
|
||||
|
||||
### 6. Workflow Patterns (Chain, Group, Chord)
|
||||
|
||||
#### Current (Celery)
|
||||
```python
|
||||
from celery import chain, group, chord
|
||||
|
||||
# Chain example
|
||||
post_chain = chain(
|
||||
task_cleanup_consent.si(transcript_id=transcript_id),
|
||||
task_pipeline_post_to_zulip.si(transcript_id=transcript_id),
|
||||
task_send_webhook_if_needed.si(transcript_id=transcript_id),
|
||||
)
|
||||
|
||||
# Chord example (parallel + callback)
|
||||
chain = chord(
|
||||
group(chain_mp3_and_diarize, chain_title_preview),
|
||||
chain_final_summaries,
|
||||
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
|
||||
```
|
||||
|
||||
#### New (TaskIQ with Pipelines)
|
||||
```python
|
||||
from taskiq_pipelines import Pipeline
|
||||
from taskiq import gather
|
||||
|
||||
# Chain example using Pipeline
|
||||
post_pipeline = (
|
||||
Pipeline(broker, task_cleanup_consent)
|
||||
.call_next(task_pipeline_post_to_zulip, transcript_id=transcript_id)
|
||||
.call_next(task_send_webhook_if_needed, transcript_id=transcript_id)
|
||||
)
|
||||
await post_pipeline.kiq(transcript_id=transcript_id)
|
||||
|
||||
# Parallel execution with gather
|
||||
results = await gather([
|
||||
chain_mp3_and_diarize.kiq(transcript_id),
|
||||
chain_title_preview.kiq(transcript_id),
|
||||
])
|
||||
|
||||
# Then execute callback
|
||||
await chain_final_summaries.kiq(transcript_id, results)
|
||||
await task_pipeline_post_to_zulip.kiq(transcript_id)
|
||||
```
|
||||
|
||||
### 7. Scheduled Tasks (Celery Beat → TaskIQ Scheduler)
|
||||
|
||||
#### Current (Celery Beat)
|
||||
```python
|
||||
# reflector/worker/app.py
|
||||
app.conf.beat_schedule = {
|
||||
"process_messages": {
|
||||
"task": "reflector.worker.process.process_messages",
|
||||
"schedule": float(settings.SQS_POLLING_TIMEOUT_SECONDS),
|
||||
},
|
||||
"reprocess_failed_recordings": {
|
||||
"task": "reflector.worker.process.reprocess_failed_recordings",
|
||||
"schedule": crontab(hour=5, minute=0),
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
#### New (TaskIQ Scheduler)
|
||||
```python
|
||||
# reflector/worker/scheduler.py
|
||||
from taskiq import TaskiqScheduler
|
||||
from taskiq_redis import ListRedisScheduleSource
|
||||
|
||||
schedule_source = ListRedisScheduleSource(settings.REDIS_URL)
|
||||
|
||||
# Define scheduled tasks with decorators
|
||||
@broker.task(
|
||||
schedule=[
|
||||
{
|
||||
"cron": f"*/{int(settings.SQS_POLLING_TIMEOUT_SECONDS)} * * * * *"
|
||||
}
|
||||
]
|
||||
)
|
||||
async def process_messages():
|
||||
# Task implementation
|
||||
pass
|
||||
|
||||
@broker.task(
|
||||
schedule=[{"cron": "0 5 * * *"}] # Daily at 5 AM
|
||||
)
|
||||
async def reprocess_failed_recordings():
|
||||
# Task implementation
|
||||
pass
|
||||
|
||||
# Initialize scheduler
|
||||
scheduler = TaskiqScheduler(broker, sources=[schedule_source])
|
||||
|
||||
# Run scheduler (separate process)
|
||||
# taskiq scheduler reflector.worker.scheduler:scheduler
|
||||
```
|
||||
|
||||
### 8. Retry Configuration
|
||||
|
||||
#### Current (Celery)
|
||||
```python
|
||||
@shared_task(
|
||||
bind=True,
|
||||
max_retries=30,
|
||||
default_retry_delay=60,
|
||||
retry_backoff=True,
|
||||
retry_backoff_max=3600,
|
||||
)
|
||||
async def task_send_webhook_if_needed(self, ...):
|
||||
try:
|
||||
# Task logic
|
||||
except Exception as exc:
|
||||
raise self.retry(exc=exc)
|
||||
```
|
||||
|
||||
#### New (TaskIQ)
|
||||
```python
|
||||
from taskiq.middlewares import SimpleRetryMiddleware
|
||||
|
||||
# Global middleware configuration (1:1 with Celery defaults)
|
||||
broker = broker.with_middlewares(
|
||||
SimpleRetryMiddleware(default_retry_count=3),
|
||||
)
|
||||
|
||||
# For specific tasks with custom retry logic:
|
||||
@broker.task(retry_on_error=True, max_retries=30)
|
||||
async def task_send_webhook_if_needed(...):
|
||||
# Task logic - exceptions auto-retry
|
||||
pass
|
||||
```
|
||||
|
||||
## Testing Migration
|
||||
|
||||
### Current Pytest Setup (Celery)
|
||||
```python
|
||||
# tests/conftest.py
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_config():
|
||||
return {
|
||||
"broker_url": "memory://",
|
||||
"result_backend": "cache+memory://",
|
||||
}
|
||||
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
@pytest.mark.usefixtures("celery_session_worker")
|
||||
async def test_task():
|
||||
pass
|
||||
```
|
||||
|
||||
### New Pytest Setup (TaskIQ)
|
||||
```python
|
||||
# tests/conftest.py
|
||||
import pytest
|
||||
from taskiq import InMemoryBroker
|
||||
from reflector.worker.broker import broker
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
async def setup_taskiq_broker():
|
||||
"""Replace broker with InMemoryBroker for testing"""
|
||||
original_broker = broker
|
||||
test_broker = InMemoryBroker(await_inplace=True)
|
||||
|
||||
# Copy task registrations
|
||||
for task_name, task in original_broker._tasks.items():
|
||||
test_broker.register_task(task.original_function, task_name=task_name)
|
||||
|
||||
yield test_broker
|
||||
await test_broker.shutdown()
|
||||
|
||||
@pytest.fixture
|
||||
async def taskiq_with_db_session(db_session):
|
||||
"""Setup TaskIQ with database session"""
|
||||
from reflector.worker.broker import broker
|
||||
broker.add_dependency_context({
|
||||
AsyncSession: db_session
|
||||
})
|
||||
yield
|
||||
broker.custom_dependency_context = {}
|
||||
|
||||
# Test example
|
||||
@pytest.mark.anyio
|
||||
async def test_task(taskiq_with_db_session):
|
||||
result = await task_pipeline_file_process("transcript-id")
|
||||
assert result is not None
|
||||
```
|
||||
|
||||
## Migration Steps
|
||||
|
||||
### Phase 1: Setup (Week 1)
|
||||
1. **Install TaskIQ packages**
|
||||
```bash
|
||||
uv add taskiq taskiq-redis taskiq-pipelines
|
||||
```
|
||||
|
||||
2. **Create new broker configuration**
|
||||
- Create `reflector/worker/broker.py` with TaskIQ broker setup
|
||||
- Create `reflector/worker/dependencies.py` for dependency injection
|
||||
|
||||
3. **Update settings**
|
||||
- Keep existing Redis configuration
|
||||
- Add TaskIQ-specific settings if needed
|
||||
|
||||
### Phase 2: Parallel Running (Week 2-3)
|
||||
1. **Migrate simple tasks first**
|
||||
- Start with `cleanup.py` (1 task)
|
||||
- Move to `webhook.py` (1 task)
|
||||
- Test thoroughly in isolation
|
||||
|
||||
2. **Setup dual-mode operation**
|
||||
- Keep Celery tasks running
|
||||
- Add TaskIQ versions alongside
|
||||
- Use feature flags to switch between them
|
||||
|
||||
### Phase 3: Complex Tasks (Week 3-4)
|
||||
1. **Migrate pipeline tasks**
|
||||
- Convert `main_file_pipeline.py`
|
||||
- Convert `main_live_pipeline.py` (most complex with 10 tasks)
|
||||
- Ensure chain/group/chord patterns work
|
||||
|
||||
2. **Migrate scheduled tasks**
|
||||
- Setup TaskIQ scheduler
|
||||
- Convert beat schedule to TaskIQ schedules
|
||||
- Test cron patterns
|
||||
|
||||
### Phase 4: Testing & Validation (Week 4-5)
|
||||
1. **Update test suite**
|
||||
- Replace Celery fixtures with TaskIQ fixtures
|
||||
- Update all test files
|
||||
- Ensure coverage remains the same
|
||||
|
||||
2. **Performance testing**
|
||||
- Compare task execution times
|
||||
- Monitor Redis memory usage
|
||||
- Test under load
|
||||
|
||||
### Phase 5: Cutover (Week 5-6)
|
||||
1. **Final migration**
|
||||
- Remove Celery dependencies
|
||||
- Update deployment scripts
|
||||
- Update documentation
|
||||
|
||||
2. **Monitoring**
|
||||
- Setup TaskIQ monitoring (if available)
|
||||
- Create health checks
|
||||
- Document operational procedures
|
||||
|
||||
## Key Differences to Note
|
||||
|
||||
### Advantages of TaskIQ
|
||||
1. **Native async support** - No need for `@asynctask` wrapper
|
||||
2. **Dependency injection** - Cleaner than decorators for session management
|
||||
3. **Type hints** - Better IDE support and autocompletion
|
||||
4. **Modern Python** - Designed for Python 3.7+
|
||||
5. **Simpler testing** - InMemoryBroker makes testing easier
|
||||
|
||||
### Potential Challenges
|
||||
1. **Less mature ecosystem** - Fewer third-party integrations
|
||||
2. **Documentation** - Less comprehensive than Celery
|
||||
3. **Monitoring tools** - No Flower equivalent (may need custom solution)
|
||||
4. **Community support** - Smaller community than Celery
|
||||
|
||||
## Command Line Changes
|
||||
|
||||
### Current (Celery)
|
||||
```bash
|
||||
# Start worker
|
||||
celery -A reflector.worker.app worker --loglevel=info
|
||||
|
||||
# Start beat scheduler
|
||||
celery -A reflector.worker.app beat
|
||||
```
|
||||
|
||||
### New (TaskIQ)
|
||||
```bash
|
||||
# Start worker
|
||||
taskiq worker reflector.worker.broker:broker
|
||||
|
||||
# Start scheduler
|
||||
taskiq scheduler reflector.worker.scheduler:scheduler
|
||||
|
||||
# With custom settings
|
||||
taskiq worker reflector.worker.broker:broker --workers 4 --log-level INFO
|
||||
```
|
||||
|
||||
## Rollback Plan
|
||||
|
||||
If issues arise during migration:
|
||||
|
||||
1. **Keep Celery code in version control** - Tag the last Celery version
|
||||
2. **Maintain dual broker setup** - Can switch back via environment variable
|
||||
3. **Database compatibility** - No schema changes required
|
||||
4. **Redis compatibility** - Both use Redis, easy to switch back
|
||||
|
||||
## Success Criteria
|
||||
|
||||
1. ✅ All tasks migrated and functioning
|
||||
2. ✅ Test coverage maintained at current levels
|
||||
3. ✅ Performance equal or better than Celery
|
||||
4. ✅ Scheduled tasks running reliably
|
||||
5. ✅ Error handling and retries working correctly
|
||||
6. ✅ WebSocket notifications still functioning
|
||||
7. ✅ Pipeline processing maintaining same behavior
|
||||
|
||||
## Monitoring & Operations
|
||||
|
||||
### Health Checks
|
||||
```python
|
||||
# reflector/worker/healthcheck.py
|
||||
@broker.task
|
||||
async def healthcheck_ping():
|
||||
"""TaskIQ health check task"""
|
||||
return {"status": "healthy", "timestamp": datetime.now()}
|
||||
```
|
||||
|
||||
### Metrics Collection
|
||||
- Task execution times
|
||||
- Success/failure rates
|
||||
- Queue depths
|
||||
- Worker utilization
|
||||
|
||||
## Key Implementation Points - MUST READ
|
||||
|
||||
### Critical Changes Required
|
||||
|
||||
1. **Session Management in Tasks**
|
||||
- ✅ **VERIFIED**: Tasks MUST use `get_session()` from `reflector.db` for test compatibility
|
||||
- ❌ Do NOT use `get_session_factory()` directly in tasks - it bypasses test mocks
|
||||
- ✅ The test database session IS properly shared when using `get_session()`
|
||||
|
||||
2. **Task Invocation Changes**
|
||||
- Replace `.delay()` with `await .kiq()`
|
||||
- All task invocations become async/await
|
||||
- No need to commit sessions before task invocation (controllers handle this)
|
||||
|
||||
3. **Broker Configuration**
|
||||
- TaskIQ broker must be initialized in `worker/app.py`
|
||||
- Use `InMemoryBroker(await_inplace=True)` for testing
|
||||
- Use `RedisStreamBroker` for production
|
||||
|
||||
4. **Test Setup Requirements**
|
||||
- Set `os.environ["ENVIRONMENT"] = "pytest"` at top of test files
|
||||
- Add TaskIQ broker fixture to test functions
|
||||
- Keep Celery fixtures for now (dual-mode operation)
|
||||
|
||||
5. **Import Pattern Changes**
|
||||
```python
|
||||
# Each file needs both imports during migration
|
||||
from reflector.pipelines.main_file_pipeline import (
|
||||
task_pipeline_file_process, # Celery version
|
||||
task_pipeline_file_process_taskiq, # TaskIQ version
|
||||
)
|
||||
```
|
||||
|
||||
6. **Decorator Changes**
|
||||
- Remove `@asynctask` - TaskIQ is async-native
|
||||
- **Keep `@with_session`** - it works with TaskIQ tasks!
|
||||
- Remove `@shared_task` from TaskIQ version
|
||||
- Keep `@shared_task` on Celery version for backward compatibility
|
||||
|
||||
## Verified POC Results
|
||||
|
||||
✅ **Database transactions work correctly** across test and TaskIQ tasks
|
||||
✅ **Tasks execute immediately** in tests with `InMemoryBroker(await_inplace=True)`
|
||||
✅ **Session mocking works** when using `get_session()` properly
|
||||
✅ **"OK" output confirmed** - TaskIQ task executes and accesses test data
|
||||
|
||||
## Conclusion
|
||||
|
||||
The migration from Celery to TaskIQ is feasible and offers several advantages for an async-first codebase like Reflector. The key challenges will be:
|
||||
|
||||
1. Migrating complex pipeline patterns (chain/chord)
|
||||
2. Ensuring scheduled task reliability
|
||||
3. **SOLVED**: Maintaining session management patterns - use `get_session()`
|
||||
4. Updating the test suite
|
||||
|
||||
The phased approach allows for gradual migration with minimal risk. The ability to run both systems in parallel provides a safety net during the transition period.
|
||||
|
||||
## Appendix: Quick Reference
|
||||
|
||||
| Celery | TaskIQ |
|
||||
|--------|--------|
|
||||
| `@shared_task` | `@broker.task` |
|
||||
| `.delay()` | `.kiq()` |
|
||||
| `.apply_async()` | `.kicker().kiq()` |
|
||||
| `chain()` | `Pipeline()` |
|
||||
| `group()` | `gather()` |
|
||||
| `chord()` | `gather() + callback` |
|
||||
| `@task.retry()` | `retry_on_error=True` |
|
||||
| Celery Beat | TaskIQ Scheduler |
|
||||
| `celery worker` | `taskiq worker` |
|
||||
| Flower | Custom monitoring needed |
|
||||
@@ -3,7 +3,7 @@ from logging.config import fileConfig
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from reflector.db import metadata
|
||||
from reflector.db.base import metadata
|
||||
from reflector.settings import settings
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Add webhook fields to rooms
|
||||
|
||||
Revision ID: 0194f65cd6d3
|
||||
Revises: 5a8907fd1d78
|
||||
Create Date: 2025-08-27 09:03:19.610995
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0194f65cd6d3"
|
||||
down_revision: Union[str, None] = "5a8907fd1d78"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("webhook_url", sa.String(), nullable=True))
|
||||
batch_op.add_column(sa.Column("webhook_secret", sa.String(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.drop_column("webhook_secret")
|
||||
batch_op.drop_column("webhook_url")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,68 @@
|
||||
"""add_long_summary_to_search_vector
|
||||
|
||||
Revision ID: 0ab2d7ffaa16
|
||||
Revises: b1c33bd09963
|
||||
Create Date: 2025-08-15 13:27:52.680211
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0ab2d7ffaa16"
|
||||
down_revision: Union[str, None] = "b1c33bd09963"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing search vector column and index
|
||||
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
|
||||
op.drop_column("transcript", "search_vector_en")
|
||||
|
||||
# Recreate the search vector column with long_summary included
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
||||
GENERATED ALWAYS AS (
|
||||
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
||||
setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') ||
|
||||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')
|
||||
) STORED
|
||||
"""
|
||||
)
|
||||
|
||||
# Recreate the GIN index for the search vector
|
||||
op.create_index(
|
||||
"idx_transcript_search_vector_en",
|
||||
"transcript",
|
||||
["search_vector_en"],
|
||||
postgresql_using="gin",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the updated search vector column and index
|
||||
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
|
||||
op.drop_column("transcript", "search_vector_en")
|
||||
|
||||
# Recreate the original search vector column without long_summary
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
||||
GENERATED ALWAYS AS (
|
||||
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
||||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
|
||||
) STORED
|
||||
"""
|
||||
)
|
||||
|
||||
# Recreate the GIN index for the search vector
|
||||
op.create_index(
|
||||
"idx_transcript_search_vector_en",
|
||||
"transcript",
|
||||
["search_vector_en"],
|
||||
postgresql_using="gin",
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
"""remove user_id from meeting table
|
||||
|
||||
Revision ID: 0ce521cda2ee
|
||||
Revises: 6dec9fb5b46c
|
||||
Create Date: 2025-09-10 12:40:55.688899
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0ce521cda2ee"
|
||||
down_revision: Union[str, None] = "6dec9fb5b46c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.drop_column("user_id")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True)
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -21,13 +21,15 @@ def upgrade() -> None:
|
||||
if conn.dialect.name != "postgresql":
|
||||
return
|
||||
|
||||
op.execute("""
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
||||
GENERATED ALWAYS AS (
|
||||
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
||||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
|
||||
) STORED
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"idx_transcript_search_vector_en",
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
"""clean up orphaned room_id references in meeting table
|
||||
|
||||
Revision ID: 2ae3db106d4e
|
||||
Revises: def1b5867d4c
|
||||
Create Date: 2025-09-11 10:35:15.759967
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "2ae3db106d4e"
|
||||
down_revision: Union[str, None] = "def1b5867d4c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Set room_id to NULL for meetings that reference non-existent rooms
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE meeting
|
||||
SET room_id = NULL
|
||||
WHERE room_id IS NOT NULL
|
||||
AND room_id NOT IN (SELECT id FROM room WHERE id IS NOT NULL)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Cannot restore orphaned references - no operation needed
|
||||
pass
|
||||
@@ -28,7 +28,7 @@ def upgrade() -> None:
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
@@ -58,7 +58,7 @@ def downgrade() -> None:
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
|
||||
@@ -36,9 +36,7 @@ def upgrade() -> None:
|
||||
|
||||
# select only the one with duration = 0
|
||||
results = bind.execute(
|
||||
select([transcript.c.id, transcript.c.duration]).where(
|
||||
transcript.c.duration == 0
|
||||
)
|
||||
select(transcript.c.id, transcript.c.duration).where(transcript.c.duration == 0)
|
||||
)
|
||||
|
||||
data_dir = Path(settings.DATA_DIR)
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
"""add cascade delete to meeting consent foreign key
|
||||
|
||||
Revision ID: 5a8907fd1d78
|
||||
Revises: 0ab2d7ffaa16
|
||||
Create Date: 2025-08-26 17:26:50.945491
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "5a8907fd1d78"
|
||||
down_revision: Union[str, None] = "0ab2d7ffaa16"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"), type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"),
|
||||
"meeting",
|
||||
["meeting_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"), type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"),
|
||||
"meeting",
|
||||
["meeting_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,53 @@
|
||||
"""remove_one_active_meeting_per_room_constraint
|
||||
|
||||
Revision ID: 6025e9b2bef2
|
||||
Revises: 2ae3db106d4e
|
||||
Create Date: 2025-08-18 18:45:44.418392
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6025e9b2bef2"
|
||||
down_revision: Union[str, None] = "2ae3db106d4e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove the unique constraint that prevents multiple active meetings per room
|
||||
# This is needed to support calendar integration with overlapping meetings
|
||||
# Check if index exists before trying to drop it
|
||||
from alembic import context
|
||||
|
||||
if context.get_context().dialect.name == "postgresql":
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM pg_indexes WHERE indexname = 'idx_one_active_meeting_per_room'"
|
||||
)
|
||||
)
|
||||
if result.fetchone():
|
||||
op.drop_index("idx_one_active_meeting_per_room", table_name="meeting")
|
||||
else:
|
||||
# For SQLite, just try to drop it
|
||||
try:
|
||||
op.drop_index("idx_one_active_meeting_per_room", table_name="meeting")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore the unique constraint
|
||||
op.create_index(
|
||||
"idx_one_active_meeting_per_room",
|
||||
"meeting",
|
||||
["room_id"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
sqlite_where=sa.text("is_active = 1"),
|
||||
)
|
||||
@@ -0,0 +1,28 @@
|
||||
"""webhook url and secret null by default
|
||||
|
||||
|
||||
Revision ID: 61882a919591
|
||||
Revises: 0194f65cd6d3
|
||||
Create Date: 2025-08-29 11:46:36.738091
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "61882a919591"
|
||||
down_revision: Union[str, None] = "0194f65cd6d3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,35 @@
|
||||
"""make meeting room_id required and add foreign key
|
||||
|
||||
Revision ID: 6dec9fb5b46c
|
||||
Revises: 61882a919591
|
||||
Create Date: 2025-09-10 10:47:06.006819
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6dec9fb5b46c"
|
||||
down_revision: Union[str, None] = "61882a919591"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.create_foreign_key(
|
||||
None, "room", ["room_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("meeting_room_id_fkey", type_="foreignkey")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -28,7 +28,7 @@ def upgrade() -> None:
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
@@ -58,7 +58,7 @@ def downgrade() -> None:
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""add_search_optimization_indexes
|
||||
|
||||
Revision ID: b1c33bd09963
|
||||
Revises: 9f5c78d352d6
|
||||
Create Date: 2025-08-14 17:26:02.117408
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b1c33bd09963"
|
||||
down_revision: Union[str, None] = "9f5c78d352d6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add indexes for actual search filtering patterns used in frontend
|
||||
# Based on /browse page filters: room_id and source_kind
|
||||
|
||||
# Index for room_id + created_at (for room-specific searches with date ordering)
|
||||
op.create_index(
|
||||
"idx_transcript_room_id_created_at",
|
||||
"transcript",
|
||||
["room_id", "created_at"],
|
||||
if_not_exists=True,
|
||||
)
|
||||
|
||||
# Index for source_kind alone (actively used filter in frontend)
|
||||
op.create_index(
|
||||
"idx_transcript_source_kind", "transcript", ["source_kind"], if_not_exists=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the indexes in reverse order
|
||||
op.drop_index("idx_transcript_source_kind", "transcript", if_exists=True)
|
||||
op.drop_index("idx_transcript_room_id_created_at", "transcript", if_exists=True)
|
||||
@@ -0,0 +1,34 @@
|
||||
"""add_grace_period_fields_to_meeting
|
||||
|
||||
Revision ID: d4a1c446458c
|
||||
Revises: 6025e9b2bef2
|
||||
Create Date: 2025-08-18 18:50:37.768052
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d4a1c446458c"
|
||||
down_revision: Union[str, None] = "6025e9b2bef2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add fields to track when participants left for grace period logic
|
||||
op.add_column(
|
||||
"meeting", sa.Column("last_participant_left_at", sa.DateTime(timezone=True))
|
||||
)
|
||||
op.add_column(
|
||||
"meeting",
|
||||
sa.Column("grace_period_minutes", sa.Integer, server_default=sa.text("15")),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("meeting", "grace_period_minutes")
|
||||
op.drop_column("meeting", "last_participant_left_at")
|
||||
@@ -27,7 +27,8 @@ def upgrade() -> None:
|
||||
|
||||
# Populate room_id for existing ROOM-type transcripts
|
||||
# This joins through recording -> meeting -> room to get the room_id
|
||||
op.execute("""
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE transcript AS t
|
||||
SET room_id = r.id
|
||||
FROM recording rec
|
||||
@@ -36,11 +37,13 @@ def upgrade() -> None:
|
||||
WHERE t.recording_id = rec.id
|
||||
AND t.source_kind = 'room'
|
||||
AND t.room_id IS NULL
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
# Fix missing meeting_id for ROOM-type transcripts
|
||||
# The meeting_id field exists but was never populated
|
||||
op.execute("""
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE transcript AS t
|
||||
SET meeting_id = rec.meeting_id
|
||||
FROM recording rec
|
||||
@@ -48,7 +51,8 @@ def upgrade() -> None:
|
||||
AND t.source_kind = 'room'
|
||||
AND t.meeting_id IS NULL
|
||||
AND rec.meeting_id IS NOT NULL
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
129
server/migrations/versions/d8e204bbf615_add_calendar.py
Normal file
129
server/migrations/versions/d8e204bbf615_add_calendar.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""add calendar
|
||||
|
||||
Revision ID: d8e204bbf615
|
||||
Revises: d4a1c446458c
|
||||
Create Date: 2025-09-10 19:56:22.295756
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d8e204bbf615"
|
||||
down_revision: Union[str, None] = "d4a1c446458c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"calendar_event",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("room_id", sa.String(), nullable=False),
|
||||
sa.Column("ics_uid", sa.Text(), nullable=False),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("attendees", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("location", sa.Text(), nullable=True),
|
||||
sa.Column("ics_raw_data", sa.Text(), nullable=True),
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"is_deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
||||
),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["room_id"],
|
||||
["room.id"],
|
||||
name="fk_calendar_event_room_id",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
|
||||
)
|
||||
with op.batch_alter_table("calendar_event", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
"idx_calendar_event_deleted",
|
||||
["is_deleted"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("NOT is_deleted"),
|
||||
)
|
||||
batch_op.create_index(
|
||||
"idx_calendar_event_room_start", ["room_id", "start_time"], unique=False
|
||||
)
|
||||
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("calendar_event_id", sa.String(), nullable=True))
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"calendar_metadata",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
batch_op.create_index(
|
||||
"idx_meeting_calendar_event", ["calendar_event_id"], unique=False
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_meeting_calendar_event_id",
|
||||
"calendar_event",
|
||||
["calendar_event_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("ics_url", sa.Text(), nullable=True))
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"ics_fetch_interval", sa.Integer(), server_default="300", nullable=True
|
||||
)
|
||||
)
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"ics_enabled",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
)
|
||||
)
|
||||
batch_op.add_column(
|
||||
sa.Column("ics_last_sync", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
batch_op.add_column(sa.Column("ics_last_etag", sa.Text(), nullable=True))
|
||||
batch_op.create_index("idx_room_ics_enabled", ["ics_enabled"], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.drop_index("idx_room_ics_enabled")
|
||||
batch_op.drop_column("ics_last_etag")
|
||||
batch_op.drop_column("ics_last_sync")
|
||||
batch_op.drop_column("ics_enabled")
|
||||
batch_op.drop_column("ics_fetch_interval")
|
||||
batch_op.drop_column("ics_url")
|
||||
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("fk_meeting_calendar_event_id", type_="foreignkey")
|
||||
batch_op.drop_index("idx_meeting_calendar_event")
|
||||
batch_op.drop_column("calendar_metadata")
|
||||
batch_op.drop_column("calendar_event_id")
|
||||
|
||||
with op.batch_alter_table("calendar_event", schema=None) as batch_op:
|
||||
batch_op.drop_index("idx_calendar_event_room_start")
|
||||
batch_op.drop_index(
|
||||
"idx_calendar_event_deleted", postgresql_where=sa.text("NOT is_deleted")
|
||||
)
|
||||
|
||||
op.drop_table("calendar_event")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,43 @@
|
||||
"""remove_grace_period_fields
|
||||
|
||||
Revision ID: dc035ff72fd5
|
||||
Revises: d8e204bbf615
|
||||
Create Date: 2025-09-11 10:36:45.197588
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dc035ff72fd5"
|
||||
down_revision: Union[str, None] = "d8e204bbf615"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove grace period columns from meeting table
|
||||
op.drop_column("meeting", "last_participant_left_at")
|
||||
op.drop_column("meeting", "grace_period_minutes")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back grace period columns to meeting table
|
||||
op.add_column(
|
||||
"meeting",
|
||||
sa.Column(
|
||||
"last_participant_left_at", sa.DateTime(timezone=True), nullable=True
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"meeting",
|
||||
sa.Column(
|
||||
"grace_period_minutes",
|
||||
sa.Integer(),
|
||||
server_default=sa.text("15"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,34 @@
|
||||
"""make meeting room_id nullable but keep foreign key
|
||||
|
||||
Revision ID: def1b5867d4c
|
||||
Revises: 0ce521cda2ee
|
||||
Create Date: 2025-09-11 09:42:18.697264
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "def1b5867d4c"
|
||||
down_revision: Union[str, None] = "0ce521cda2ee"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.alter_column("room_id", existing_type=sa.VARCHAR(), nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.alter_column("room_id", existing_type=sa.VARCHAR(), nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -12,7 +12,6 @@ dependencies = [
|
||||
"requests>=2.31.0",
|
||||
"aiortc>=1.5.0",
|
||||
"sortedcontainers>=2.4.0",
|
||||
"loguru>=0.7.0",
|
||||
"pydantic-settings>=2.0.2",
|
||||
"structlog>=23.1.0",
|
||||
"uvicorn[standard]>=0.23.1",
|
||||
@@ -20,19 +19,16 @@ dependencies = [
|
||||
"sentry-sdk[fastapi]>=1.29.2",
|
||||
"httpx>=0.24.1",
|
||||
"fastapi-pagination>=0.12.6",
|
||||
"databases[aiosqlite, asyncpg]>=0.7.0",
|
||||
"sqlalchemy<1.5",
|
||||
"sqlalchemy>=2.0.0",
|
||||
"asyncpg>=0.29.0",
|
||||
"alembic>=1.11.3",
|
||||
"nltk>=3.8.1",
|
||||
"prometheus-fastapi-instrumentator>=6.1.0",
|
||||
"sentencepiece>=0.1.99",
|
||||
"protobuf>=4.24.3",
|
||||
"profanityfilter>=2.0.6",
|
||||
"celery>=5.3.4",
|
||||
"redis>=5.0.1",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"python-multipart>=0.0.6",
|
||||
"faster-whisper>=0.10.0",
|
||||
"transformers>=4.36.2",
|
||||
"jsonschema>=4.23.0",
|
||||
"openai>=1.59.7",
|
||||
@@ -41,6 +37,9 @@ dependencies = [
|
||||
"llama-index-llms-openai-like>=0.4.0",
|
||||
"pytest-env>=1.1.5",
|
||||
"webvtt-py>=0.5.0",
|
||||
"icalendar>=6.0.0",
|
||||
"taskiq>=0.11.18",
|
||||
"taskiq-redis>=1.1.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
@@ -48,6 +47,7 @@ dev = [
|
||||
"black>=24.1.1",
|
||||
"stamina>=23.1.0",
|
||||
"pyinstrument>=4.6.1",
|
||||
"pytest-async-sqlalchemy>=0.2.0",
|
||||
]
|
||||
tests = [
|
||||
"pytest-cov>=4.1.0",
|
||||
@@ -56,7 +56,7 @@ tests = [
|
||||
"pytest>=7.4.0",
|
||||
"httpx-ws>=0.4.1",
|
||||
"pytest-httpx>=0.23.1",
|
||||
"pytest-celery>=0.0.0",
|
||||
"pytest-recording>=0.13.4",
|
||||
"pytest-docker>=3.2.3",
|
||||
"asgi-lifespan>=2.1.0",
|
||||
]
|
||||
@@ -67,6 +67,15 @@ evaluation = [
|
||||
"tqdm>=4.66.0",
|
||||
"pydantic>=2.1.1",
|
||||
]
|
||||
local = [
|
||||
"pyannote-audio>=3.3.2",
|
||||
"faster-whisper>=0.10.0",
|
||||
]
|
||||
silero-vad = [
|
||||
"silero-vad>=5.1.2",
|
||||
"torch>=2.8.0",
|
||||
"torchaudio>=2.8.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
default-groups = [
|
||||
@@ -74,6 +83,21 @@ default-groups = [
|
||||
"tests",
|
||||
"aws",
|
||||
"evaluation",
|
||||
"local",
|
||||
"silero-vad"
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cpu" },
|
||||
]
|
||||
|
||||
[build-system]
|
||||
@@ -88,12 +112,18 @@ source = ["reflector"]
|
||||
|
||||
[tool.pytest_env]
|
||||
ENVIRONMENT = "pytest"
|
||||
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
|
||||
DATABASE_URL = "postgresql+asyncpg://test_user:test_password@localhost:15432/reflector_test"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_debug = true
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
asyncio_default_test_loop_scope = "session"
|
||||
markers = [
|
||||
"model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
@@ -104,7 +134,7 @@ select = [
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"reflector/processors/summary/summary_builder.py" = ["E501"]
|
||||
"gpu/**.py" = ["PLC0415"]
|
||||
"gpu/modal_deployments/**.py" = ["PLC0415"]
|
||||
"reflector/tools/**.py" = ["PLC0415"]
|
||||
"migrations/versions/**.py" = ["PLC0415"]
|
||||
"tests/**.py" = ["PLC0415"]
|
||||
|
||||
@@ -88,8 +88,8 @@ app.include_router(zulip_router, prefix="/v1")
|
||||
app.include_router(whereby_router, prefix="/v1")
|
||||
add_pagination(app)
|
||||
|
||||
# prepare celery
|
||||
from reflector.worker import app as celery_app # noqa
|
||||
# prepare taskiq
|
||||
from reflector.worker import app as taskiq_app # noqa
|
||||
|
||||
|
||||
# simpler openapi id
|
||||
|
||||
@@ -67,7 +67,8 @@ def current_user(
|
||||
try:
|
||||
payload = jwtauth.verify_token(token)
|
||||
sub = payload["sub"]
|
||||
return UserInfo(sub=sub)
|
||||
email = payload["email"]
|
||||
return UserInfo(sub=sub, email=email)
|
||||
except JWTError as e:
|
||||
logger.error(f"JWT error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication")
|
||||
|
||||
@@ -1,47 +1,82 @@
|
||||
import contextvars
|
||||
from typing import Optional
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import databases
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from reflector.db.base import Base as Base
|
||||
from reflector.db.base import metadata as metadata
|
||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||
from reflector.settings import settings
|
||||
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
|
||||
contextvars.ContextVar("database", default=None)
|
||||
)
|
||||
_engine: AsyncEngine | None = None
|
||||
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
def get_database() -> databases.Database:
|
||||
"""Get database instance for current asyncio context"""
|
||||
db = _database_context.get()
|
||||
if db is None:
|
||||
db = databases.Database(settings.DATABASE_URL)
|
||||
_database_context.set(db)
|
||||
return db
|
||||
def get_engine() -> AsyncEngine:
|
||||
global _engine
|
||||
if _engine is None:
|
||||
_engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
return _engine
|
||||
|
||||
|
||||
# import models
|
||||
def get_session_factory() -> async_sessionmaker[AsyncSession]:
|
||||
global _session_factory
|
||||
if _session_factory is None:
|
||||
_session_factory = async_sessionmaker(
|
||||
get_engine(),
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
return _session_factory
|
||||
|
||||
|
||||
async def _get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
# necessary implementation to ease mocking on pytest
|
||||
async with get_session_factory()() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Get a database session, fastapi dependency injection style
|
||||
"""
|
||||
async for session in _get_session():
|
||||
yield session
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session_context():
|
||||
"""
|
||||
Get a database session as an async context manager
|
||||
"""
|
||||
async for session in _get_session():
|
||||
yield session
|
||||
|
||||
|
||||
import reflector.db.calendar_events # noqa
|
||||
import reflector.db.meetings # noqa
|
||||
import reflector.db.recordings # noqa
|
||||
import reflector.db.rooms # noqa
|
||||
import reflector.db.transcripts # noqa
|
||||
|
||||
kwargs = {}
|
||||
if "postgres" not in settings.DATABASE_URL:
|
||||
raise Exception("Only postgres database is supported in reflector")
|
||||
engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
|
||||
|
||||
|
||||
@subscribers_startup.append
|
||||
async def database_connect(_):
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
get_engine()
|
||||
|
||||
|
||||
@subscribers_shutdown.append
|
||||
async def database_disconnect(_):
|
||||
database = get_database()
|
||||
await database.disconnect()
|
||||
global _engine
|
||||
if _engine:
|
||||
await _engine.dispose()
|
||||
_engine = None
|
||||
|
||||
237
server/reflector/db/base.py
Normal file
237
server/reflector/db/base.py
Normal file
@@ -0,0 +1,237 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(AsyncAttrs, DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class TranscriptModel(Base):
|
||||
__tablename__ = "transcript"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
name: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
status: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
locked: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
|
||||
duration: Mapped[Optional[float]] = mapped_column(sa.Float)
|
||||
created_at: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
||||
title: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
short_summary: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
long_summary: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
topics: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
||||
events: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
||||
participants: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
||||
source_language: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
target_language: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
reviewed: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
audio_location: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="local"
|
||||
)
|
||||
user_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
share_mode: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="private"
|
||||
)
|
||||
meeting_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
recording_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
zulip_message_id: Mapped[Optional[int]] = mapped_column(sa.Integer)
|
||||
source_kind: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False
|
||||
) # Enum will be handled separately
|
||||
audio_deleted: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
|
||||
room_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
webvtt: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
|
||||
__table_args__ = (
|
||||
sa.Index("idx_transcript_recording_id", "recording_id"),
|
||||
sa.Index("idx_transcript_user_id", "user_id"),
|
||||
sa.Index("idx_transcript_created_at", "created_at"),
|
||||
sa.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
||||
sa.Index("idx_transcript_room_id", "room_id"),
|
||||
sa.Index("idx_transcript_source_kind", "source_kind"),
|
||||
sa.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
TranscriptModel.search_vector_en = sa.Column(
|
||||
"search_vector_en",
|
||||
TSVECTOR,
|
||||
sa.Computed(
|
||||
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
||||
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
|
||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
|
||||
persisted=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class RoomModel(Base):
|
||||
__tablename__ = "room"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(sa.String, nullable=False, unique=True)
|
||||
user_id: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
zulip_auto_post: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
zulip_stream: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
zulip_topic: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
is_locked: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
room_mode: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="normal"
|
||||
)
|
||||
recording_type: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="cloud"
|
||||
)
|
||||
recording_trigger: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="automatic-2nd-participant"
|
||||
)
|
||||
is_shared: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
webhook_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
webhook_secret: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
ics_url: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
ics_fetch_interval: Mapped[Optional[int]] = mapped_column(
|
||||
sa.Integer, server_default=sa.text("300")
|
||||
)
|
||||
ics_enabled: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
ics_last_sync: Mapped[Optional[datetime]] = mapped_column(
|
||||
sa.DateTime(timezone=True)
|
||||
)
|
||||
ics_last_etag: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
|
||||
__table_args__ = (
|
||||
sa.Index("idx_room_is_shared", "is_shared"),
|
||||
sa.Index("idx_room_ics_enabled", "ics_enabled"),
|
||||
)
|
||||
|
||||
|
||||
class MeetingModel(Base):
|
||||
__tablename__ = "meeting"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
room_name: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
room_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
host_room_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
start_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
||||
end_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
||||
room_id: Mapped[Optional[str]] = mapped_column(
|
||||
sa.String, sa.ForeignKey("room.id", ondelete="CASCADE")
|
||||
)
|
||||
is_locked: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
room_mode: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="normal"
|
||||
)
|
||||
recording_type: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="cloud"
|
||||
)
|
||||
recording_trigger: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="automatic-2nd-participant"
|
||||
)
|
||||
num_clients: Mapped[int] = mapped_column(
|
||||
sa.Integer, nullable=False, server_default=sa.text("0")
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("true")
|
||||
)
|
||||
calendar_event_id: Mapped[Optional[str]] = mapped_column(
|
||||
sa.String,
|
||||
sa.ForeignKey(
|
||||
"calendar_event.id",
|
||||
ondelete="SET NULL",
|
||||
name="fk_meeting_calendar_event_id",
|
||||
),
|
||||
)
|
||||
calendar_metadata: Mapped[Optional[dict]] = mapped_column(JSONB)
|
||||
|
||||
__table_args__ = (
|
||||
sa.Index("idx_meeting_room_id", "room_id"),
|
||||
sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
|
||||
)
|
||||
|
||||
|
||||
class MeetingConsentModel(Base):
|
||||
__tablename__ = "meeting_consent"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
meeting_id: Mapped[str] = mapped_column(
|
||||
sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
consent_given: Mapped[bool] = mapped_column(sa.Boolean, nullable=False)
|
||||
consent_timestamp: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
|
||||
|
||||
class RecordingModel(Base):
|
||||
__tablename__ = "recording"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
meeting_id: Mapped[str] = mapped_column(
|
||||
sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
url: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
object_key: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
duration: Mapped[Optional[float]] = mapped_column(sa.Float)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (sa.Index("idx_recording_meeting_id", "meeting_id"),)
|
||||
|
||||
|
||||
class CalendarEventModel(Base):
|
||||
__tablename__ = "calendar_event"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
room_id: Mapped[str] = mapped_column(
|
||||
sa.String, sa.ForeignKey("room.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
ics_uid: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
title: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
description: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
start_time: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
end_time: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
attendees: Mapped[Optional[dict]] = mapped_column(JSONB)
|
||||
location: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
ics_raw_data: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
last_synced: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
|
||||
)
|
||||
|
||||
|
||||
metadata = Base.metadata
|
||||
189
server/reflector/db/calendar_events.py
Normal file
189
server/reflector/db/calendar_events.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db.base import CalendarEventModel
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
room_id: str
|
||||
ics_uid: str
|
||||
title: str | None = None
|
||||
description: str | None = None
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
attendees: list[dict[str, Any]] | None = None
|
||||
location: str | None = None
|
||||
ics_raw_data: str | None = None
|
||||
last_synced: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class CalendarEventController:
|
||||
async def get_upcoming_events(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
room_id: str,
|
||||
current_time: datetime,
|
||||
buffer_minutes: int = 15,
|
||||
) -> list[CalendarEvent]:
|
||||
buffer_time = current_time + timedelta(minutes=buffer_minutes)
|
||||
|
||||
query = (
|
||||
select(CalendarEventModel)
|
||||
.where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.start_time <= buffer_time,
|
||||
CalendarEventModel.end_time > current_time,
|
||||
)
|
||||
)
|
||||
.order_by(CalendarEventModel.start_time)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, event_id: str
|
||||
) -> CalendarEvent | None:
|
||||
query = select(CalendarEventModel).where(CalendarEventModel.id == event_id)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return CalendarEvent.model_validate(row)
|
||||
|
||||
async def get_by_ics_uid(
|
||||
self, session: AsyncSession, room_id: str, ics_uid: str
|
||||
) -> CalendarEvent | None:
|
||||
query = select(CalendarEventModel).where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.ics_uid == ics_uid,
|
||||
)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return CalendarEvent.model_validate(row)
|
||||
|
||||
async def upsert(
|
||||
self, session: AsyncSession, event: CalendarEvent
|
||||
) -> CalendarEvent:
|
||||
existing = await self.get_by_ics_uid(session, event.room_id, event.ics_uid)
|
||||
|
||||
if existing:
|
||||
event.updated_at = datetime.now(timezone.utc)
|
||||
query = (
|
||||
update(CalendarEventModel)
|
||||
.where(CalendarEventModel.id == existing.id)
|
||||
.values(**event.model_dump(exclude={"id"}))
|
||||
)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
return event
|
||||
else:
|
||||
new_event = CalendarEventModel(**event.model_dump())
|
||||
session.add(new_event)
|
||||
await session.commit()
|
||||
return event
|
||||
|
||||
async def delete_old_events(
|
||||
self, session: AsyncSession, room_id: str, cutoff_date: datetime
|
||||
) -> int:
|
||||
query = delete(CalendarEventModel).where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.end_time < cutoff_date,
|
||||
)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
await session.commit()
|
||||
return result.rowcount
|
||||
|
||||
async def delete_events_not_in_list(
|
||||
self, session: AsyncSession, room_id: str, keep_ics_uids: list[str]
|
||||
) -> int:
|
||||
if not keep_ics_uids:
|
||||
query = delete(CalendarEventModel).where(
|
||||
CalendarEventModel.room_id == room_id
|
||||
)
|
||||
else:
|
||||
query = delete(CalendarEventModel).where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.ics_uid.notin_(keep_ics_uids),
|
||||
)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
await session.commit()
|
||||
return result.rowcount
|
||||
|
||||
async def get_by_room(
|
||||
self, session: AsyncSession, room_id: str, include_deleted: bool = True
|
||||
) -> list[CalendarEvent]:
|
||||
query = select(CalendarEventModel).where(CalendarEventModel.room_id == room_id)
|
||||
if not include_deleted:
|
||||
query = query.where(CalendarEventModel.is_deleted == False)
|
||||
result = await session.execute(query)
|
||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_upcoming(
|
||||
self, session: AsyncSession, room_id: str, minutes_ahead: int = 120
|
||||
) -> list[CalendarEvent]:
|
||||
now = datetime.now(timezone.utc)
|
||||
buffer_time = now + timedelta(minutes=minutes_ahead)
|
||||
|
||||
query = (
|
||||
select(CalendarEventModel)
|
||||
.where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.start_time <= buffer_time,
|
||||
CalendarEventModel.end_time > now,
|
||||
CalendarEventModel.is_deleted == False,
|
||||
)
|
||||
)
|
||||
.order_by(CalendarEventModel.start_time)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def soft_delete_missing(
|
||||
self, session: AsyncSession, room_id: str, current_ics_uids: list[str]
|
||||
) -> int:
|
||||
query = (
|
||||
update(CalendarEventModel)
|
||||
.where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
(
|
||||
CalendarEventModel.ics_uid.notin_(current_ics_uids)
|
||||
if current_ics_uids
|
||||
else True
|
||||
),
|
||||
CalendarEventModel.end_time > datetime.now(timezone.utc),
|
||||
)
|
||||
)
|
||||
.values(is_deleted=True)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
await session.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
calendar_events_controller = CalendarEventController()
|
||||
@@ -1,67 +1,19 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.base import MeetingConsentModel, MeetingModel
|
||||
from reflector.db.rooms import Room
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
meetings = sa.Table(
|
||||
"meeting",
|
||||
metadata,
|
||||
sa.Column("id", sa.String, primary_key=True),
|
||||
sa.Column("room_name", sa.String),
|
||||
sa.Column("room_url", sa.String),
|
||||
sa.Column("host_room_url", sa.String),
|
||||
sa.Column("start_date", sa.DateTime(timezone=True)),
|
||||
sa.Column("end_date", sa.DateTime(timezone=True)),
|
||||
sa.Column("user_id", sa.String),
|
||||
sa.Column("room_id", sa.String),
|
||||
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
|
||||
sa.Column("room_mode", sa.String, nullable=False, server_default="normal"),
|
||||
sa.Column("recording_type", sa.String, nullable=False, server_default="cloud"),
|
||||
sa.Column(
|
||||
"recording_trigger",
|
||||
sa.String,
|
||||
nullable=False,
|
||||
server_default="automatic-2nd-participant",
|
||||
),
|
||||
sa.Column(
|
||||
"num_clients",
|
||||
sa.Integer,
|
||||
nullable=False,
|
||||
server_default=sa.text("0"),
|
||||
),
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.true(),
|
||||
),
|
||||
sa.Index("idx_meeting_room_id", "room_id"),
|
||||
sa.Index(
|
||||
"idx_one_active_meeting_per_room",
|
||||
"room_id",
|
||||
unique=True,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
),
|
||||
)
|
||||
|
||||
meeting_consent = sa.Table(
|
||||
"meeting_consent",
|
||||
metadata,
|
||||
sa.Column("id", sa.String, primary_key=True),
|
||||
sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False),
|
||||
sa.Column("user_id", sa.String),
|
||||
sa.Column("consent_given", sa.Boolean, nullable=False),
|
||||
sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
|
||||
)
|
||||
|
||||
|
||||
class MeetingConsent(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
meeting_id: str
|
||||
user_id: str | None = None
|
||||
@@ -70,14 +22,15 @@ class MeetingConsent(BaseModel):
|
||||
|
||||
|
||||
class Meeting(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
room_name: str
|
||||
room_url: str
|
||||
host_room_url: str
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
room_id: str | None
|
||||
is_locked: bool = False
|
||||
room_mode: Literal["normal", "group"] = "normal"
|
||||
recording_type: Literal["none", "local", "cloud"] = "cloud"
|
||||
@@ -85,23 +38,25 @@ class Meeting(BaseModel):
|
||||
"none", "prompt", "automatic", "automatic-2nd-participant"
|
||||
] = "automatic-2nd-participant"
|
||||
num_clients: int = 0
|
||||
is_active: bool = True
|
||||
calendar_event_id: str | None = None
|
||||
calendar_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class MeetingController:
|
||||
async def create(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
id: str,
|
||||
room_name: str,
|
||||
room_url: str,
|
||||
host_room_url: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
user_id: str,
|
||||
room: Room,
|
||||
calendar_event_id: str | None = None,
|
||||
calendar_metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Create a new meeting
|
||||
"""
|
||||
meeting = Meeting(
|
||||
id=id,
|
||||
room_name=room_name,
|
||||
@@ -109,148 +64,206 @@ class MeetingController:
|
||||
host_room_url=host_room_url,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
user_id=user_id,
|
||||
room_id=room.id,
|
||||
is_locked=room.is_locked,
|
||||
room_mode=room.room_mode,
|
||||
recording_type=room.recording_type,
|
||||
recording_trigger=room.recording_trigger,
|
||||
calendar_event_id=calendar_event_id,
|
||||
calendar_metadata=calendar_metadata,
|
||||
)
|
||||
query = meetings.insert().values(**meeting.model_dump())
|
||||
await get_database().execute(query)
|
||||
new_meeting = MeetingModel(**meeting.model_dump())
|
||||
session.add(new_meeting)
|
||||
await session.commit()
|
||||
return meeting
|
||||
|
||||
async def get_all_active(self) -> list[Meeting]:
|
||||
"""
|
||||
Get active meetings.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.is_active)
|
||||
return await get_database().fetch_all(query)
|
||||
async def get_all_active(self, session: AsyncSession) -> list[Meeting]:
|
||||
query = select(MeetingModel).where(MeetingModel.is_active)
|
||||
result = await session.execute(query)
|
||||
return [Meeting.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_room_name(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
room_name: str,
|
||||
) -> Meeting:
|
||||
) -> Meeting | None:
|
||||
"""
|
||||
Get a meeting by room name.
|
||||
For backward compatibility, returns the most recent meeting.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.room_name == room_name)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = (
|
||||
select(MeetingModel)
|
||||
.where(MeetingModel.room_name == room_name)
|
||||
.order_by(MeetingModel.end_date.desc())
|
||||
)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
return Meeting(**result)
|
||||
|
||||
async def get_active(self, room: Room, current_time: datetime) -> Meeting:
|
||||
async def get_active(
|
||||
self, session: AsyncSession, room: Room, current_time: datetime
|
||||
) -> Meeting | None:
|
||||
"""
|
||||
Get latest active meeting for a room.
|
||||
For backward compatibility, returns the most recent active meeting.
|
||||
"""
|
||||
end_date = getattr(meetings.c, "end_date")
|
||||
query = (
|
||||
meetings.select()
|
||||
select(MeetingModel)
|
||||
.where(
|
||||
sa.and_(
|
||||
meetings.c.room_id == room.id,
|
||||
meetings.c.end_date > current_time,
|
||||
meetings.c.is_active,
|
||||
MeetingModel.room_id == room.id,
|
||||
MeetingModel.end_date > current_time,
|
||||
MeetingModel.is_active,
|
||||
)
|
||||
)
|
||||
.order_by(end_date.desc())
|
||||
.order_by(MeetingModel.end_date.desc())
|
||||
)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
return Meeting(**result)
|
||||
async def get_all_active_for_room(
|
||||
self, session: AsyncSession, room: Room, current_time: datetime
|
||||
) -> list[Meeting]:
|
||||
query = (
|
||||
select(MeetingModel)
|
||||
.where(
|
||||
sa.and_(
|
||||
MeetingModel.room_id == room.id,
|
||||
MeetingModel.end_date > current_time,
|
||||
MeetingModel.is_active,
|
||||
)
|
||||
)
|
||||
.order_by(MeetingModel.end_date.desc())
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return [Meeting.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_id(self, meeting_id: str, **kwargs) -> Meeting | None:
|
||||
async def get_active_by_calendar_event(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
room: Room,
|
||||
calendar_event_id: str,
|
||||
current_time: datetime,
|
||||
) -> Meeting | None:
|
||||
"""
|
||||
Get a meeting by id
|
||||
Get active meeting for a specific calendar event.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = select(MeetingModel).where(
|
||||
sa.and_(
|
||||
MeetingModel.room_id == room.id,
|
||||
MeetingModel.calendar_event_id == calendar_event_id,
|
||||
MeetingModel.end_date > current_time,
|
||||
MeetingModel.is_active,
|
||||
)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting(**result)
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Meeting:
|
||||
"""
|
||||
Get a meeting by ID for HTTP request.
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, meeting_id: str, **kwargs
|
||||
) -> Meeting | None:
|
||||
query = select(MeetingModel).where(MeetingModel.id == meeting_id)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
If not found, it will raise a 404 error.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
async def get_by_calendar_event(
|
||||
self, session: AsyncSession, calendar_event_id: str
|
||||
) -> Meeting | None:
|
||||
query = select(MeetingModel).where(
|
||||
MeetingModel.calendar_event_id == calendar_event_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
meeting = Meeting(**result)
|
||||
if result["user_id"] != user_id:
|
||||
meeting.host_room_url = ""
|
||||
|
||||
return meeting
|
||||
|
||||
async def update_meeting(self, meeting_id: str, **kwargs):
|
||||
query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
|
||||
await get_database().execute(query)
|
||||
async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs):
|
||||
query = (
|
||||
update(MeetingModel).where(MeetingModel.id == meeting_id).values(**kwargs)
|
||||
)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
|
||||
|
||||
class MeetingConsentController:
|
||||
async def get_by_meeting_id(self, meeting_id: str) -> list[MeetingConsent]:
|
||||
query = meeting_consent.select().where(
|
||||
meeting_consent.c.meeting_id == meeting_id
|
||||
async def get_by_meeting_id(
|
||||
self, session: AsyncSession, meeting_id: str
|
||||
) -> list[MeetingConsent]:
|
||||
query = select(MeetingConsentModel).where(
|
||||
MeetingConsentModel.meeting_id == meeting_id
|
||||
)
|
||||
results = await get_database().fetch_all(query)
|
||||
return [MeetingConsent(**result) for result in results]
|
||||
result = await session.execute(query)
|
||||
return [MeetingConsent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_meeting_and_user(
|
||||
self, meeting_id: str, user_id: str
|
||||
self, session: AsyncSession, meeting_id: str, user_id: str
|
||||
) -> MeetingConsent | None:
|
||||
"""Get existing consent for a specific user and meeting"""
|
||||
query = meeting_consent.select().where(
|
||||
meeting_consent.c.meeting_id == meeting_id,
|
||||
meeting_consent.c.user_id == user_id,
|
||||
query = select(MeetingConsentModel).where(
|
||||
sa.and_(
|
||||
MeetingConsentModel.meeting_id == meeting_id,
|
||||
MeetingConsentModel.user_id == user_id,
|
||||
)
|
||||
)
|
||||
result = await get_database().fetch_one(query)
|
||||
if result is None:
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
return None
|
||||
return MeetingConsent(**result) if result else None
|
||||
return MeetingConsent.model_validate(row)
|
||||
|
||||
async def upsert(self, consent: MeetingConsent) -> MeetingConsent:
|
||||
"""Create new consent or update existing one for authenticated users"""
|
||||
async def upsert(
|
||||
self, session: AsyncSession, consent: MeetingConsent
|
||||
) -> MeetingConsent:
|
||||
if consent.user_id:
|
||||
# For authenticated users, check if consent already exists
|
||||
# not transactional but we're ok with that; the consents ain't deleted anyways
|
||||
existing = await self.get_by_meeting_and_user(
|
||||
consent.meeting_id, consent.user_id
|
||||
session, consent.meeting_id, consent.user_id
|
||||
)
|
||||
if existing:
|
||||
query = (
|
||||
meeting_consent.update()
|
||||
.where(meeting_consent.c.id == existing.id)
|
||||
update(MeetingConsentModel)
|
||||
.where(MeetingConsentModel.id == existing.id)
|
||||
.values(
|
||||
consent_given=consent.consent_given,
|
||||
consent_timestamp=consent.consent_timestamp,
|
||||
)
|
||||
)
|
||||
await get_database().execute(query)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
|
||||
existing.consent_given = consent.consent_given
|
||||
existing.consent_timestamp = consent.consent_timestamp
|
||||
return existing
|
||||
existing.consent_given = consent.consent_given
|
||||
existing.consent_timestamp = consent.consent_timestamp
|
||||
return existing
|
||||
|
||||
query = meeting_consent.insert().values(**consent.model_dump())
|
||||
await get_database().execute(query)
|
||||
new_consent = MeetingConsentModel(**consent.model_dump())
|
||||
session.add(new_consent)
|
||||
await session.commit()
|
||||
return consent
|
||||
|
||||
async def has_any_denial(self, meeting_id: str) -> bool:
|
||||
async def has_any_denial(self, session: AsyncSession, meeting_id: str) -> bool:
|
||||
"""Check if any participant denied consent for this meeting"""
|
||||
query = meeting_consent.select().where(
|
||||
meeting_consent.c.meeting_id == meeting_id,
|
||||
meeting_consent.c.consent_given.is_(False),
|
||||
query = select(MeetingConsentModel).where(
|
||||
sa.and_(
|
||||
MeetingConsentModel.meeting_id == meeting_id,
|
||||
MeetingConsentModel.consent_given.is_(False),
|
||||
)
|
||||
)
|
||||
result = await get_database().fetch_one(query)
|
||||
return result is not None
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
|
||||
meetings_controller = MeetingController()
|
||||
|
||||
@@ -1,61 +1,79 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.base import RecordingModel
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
recordings = sa.Table(
|
||||
"recording",
|
||||
metadata,
|
||||
sa.Column("id", sa.String, primary_key=True),
|
||||
sa.Column("bucket_name", sa.String, nullable=False),
|
||||
sa.Column("object_key", sa.String, nullable=False),
|
||||
sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String,
|
||||
nullable=False,
|
||||
server_default="pending",
|
||||
),
|
||||
sa.Column("meeting_id", sa.String),
|
||||
sa.Index("idx_recording_meeting_id", "meeting_id"),
|
||||
)
|
||||
|
||||
|
||||
class Recording(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
bucket_name: str
|
||||
meeting_id: str
|
||||
url: str
|
||||
object_key: str
|
||||
recorded_at: datetime
|
||||
status: Literal["pending", "processing", "completed", "failed"] = "pending"
|
||||
meeting_id: str | None = None
|
||||
duration: float | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class RecordingController:
|
||||
async def create(self, recording: Recording):
|
||||
query = recordings.insert().values(**recording.model_dump())
|
||||
await get_database().execute(query)
|
||||
async def create(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
meeting_id: str,
|
||||
url: str,
|
||||
object_key: str,
|
||||
duration: float | None = None,
|
||||
created_at: datetime | None = None,
|
||||
):
|
||||
if created_at is None:
|
||||
created_at = datetime.now(timezone.utc)
|
||||
|
||||
recording = Recording(
|
||||
meeting_id=meeting_id,
|
||||
url=url,
|
||||
object_key=object_key,
|
||||
duration=duration,
|
||||
created_at=created_at,
|
||||
)
|
||||
new_recording = RecordingModel(**recording.model_dump())
|
||||
session.add(new_recording)
|
||||
await session.commit()
|
||||
return recording
|
||||
|
||||
async def get_by_id(self, id: str) -> Recording:
|
||||
query = recordings.select().where(recordings.c.id == id)
|
||||
result = await get_database().fetch_one(query)
|
||||
return Recording(**result) if result else None
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, recording_id: str
|
||||
) -> Recording | None:
|
||||
"""
|
||||
Get a recording by id
|
||||
"""
|
||||
query = select(RecordingModel).where(RecordingModel.id == recording_id)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Recording.model_validate(row)
|
||||
|
||||
async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
|
||||
query = recordings.select().where(
|
||||
recordings.c.bucket_name == bucket_name,
|
||||
recordings.c.object_key == object_key,
|
||||
)
|
||||
result = await get_database().fetch_one(query)
|
||||
return Recording(**result) if result else None
|
||||
async def get_by_meeting_id(
|
||||
self, session: AsyncSession, meeting_id: str
|
||||
) -> list[Recording]:
|
||||
"""
|
||||
Get all recordings for a meeting
|
||||
"""
|
||||
query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id)
|
||||
result = await session.execute(query)
|
||||
return [Recording.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def remove_by_id(self, id: str) -> None:
|
||||
query = recordings.delete().where(recordings.c.id == id)
|
||||
await get_database().execute(query)
|
||||
async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None:
|
||||
"""
|
||||
Remove a recording by id
|
||||
"""
|
||||
query = delete(RecordingModel).where(RecordingModel.id == recording_id)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
|
||||
|
||||
recordings_controller = RecordingController()
|
||||
|
||||
@@ -1,50 +1,21 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from sqlite3 import IntegrityError
|
||||
from typing import Literal
|
||||
|
||||
import sqlalchemy
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.sql import false, or_
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import or_
|
||||
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.base import RoomModel
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
rooms = sqlalchemy.Table(
|
||||
"room",
|
||||
metadata,
|
||||
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||
sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
|
||||
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
|
||||
sqlalchemy.Column(
|
||||
"zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||
),
|
||||
sqlalchemy.Column("zulip_stream", sqlalchemy.String),
|
||||
sqlalchemy.Column("zulip_topic", sqlalchemy.String),
|
||||
sqlalchemy.Column(
|
||||
"is_locked", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"room_mode", sqlalchemy.String, nullable=False, server_default="normal"
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"recording_type", sqlalchemy.String, nullable=False, server_default="cloud"
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"recording_trigger",
|
||||
sqlalchemy.String,
|
||||
nullable=False,
|
||||
server_default="automatic-2nd-participant",
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"is_shared", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||
),
|
||||
sqlalchemy.Index("idx_room_is_shared", "is_shared"),
|
||||
)
|
||||
|
||||
|
||||
class Room(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
name: str
|
||||
user_id: str
|
||||
@@ -59,11 +30,19 @@ class Room(BaseModel):
|
||||
"none", "prompt", "automatic", "automatic-2nd-participant"
|
||||
] = "automatic-2nd-participant"
|
||||
is_shared: bool = False
|
||||
webhook_url: str | None = None
|
||||
webhook_secret: str | None = None
|
||||
ics_url: str | None = None
|
||||
ics_fetch_interval: int = 300
|
||||
ics_enabled: bool = False
|
||||
ics_last_sync: datetime | None = None
|
||||
ics_last_etag: str | None = None
|
||||
|
||||
|
||||
class RoomController:
|
||||
async def get_all(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
order_by: str | None = None,
|
||||
return_query: bool = False,
|
||||
@@ -77,14 +56,14 @@ class RoomController:
|
||||
Parameters:
|
||||
- `order_by`: field to order by, e.g. "-created_at"
|
||||
"""
|
||||
query = rooms.select()
|
||||
query = select(RoomModel)
|
||||
if user_id is not None:
|
||||
query = query.where(or_(rooms.c.user_id == user_id, rooms.c.is_shared))
|
||||
query = query.where(or_(RoomModel.user_id == user_id, RoomModel.is_shared))
|
||||
else:
|
||||
query = query.where(rooms.c.is_shared)
|
||||
query = query.where(RoomModel.is_shared)
|
||||
|
||||
if order_by is not None:
|
||||
field = getattr(rooms.c, order_by[1:])
|
||||
field = getattr(RoomModel, order_by[1:])
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
@@ -92,11 +71,12 @@ class RoomController:
|
||||
if return_query:
|
||||
return query
|
||||
|
||||
results = await get_database().fetch_all(query)
|
||||
return results
|
||||
result = await session.execute(query)
|
||||
return [Room.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def add(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
name: str,
|
||||
user_id: str,
|
||||
zulip_auto_post: bool,
|
||||
@@ -107,10 +87,18 @@ class RoomController:
|
||||
recording_type: str,
|
||||
recording_trigger: str,
|
||||
is_shared: bool,
|
||||
webhook_url: str = "",
|
||||
webhook_secret: str = "",
|
||||
ics_url: str | None = None,
|
||||
ics_fetch_interval: int = 300,
|
||||
ics_enabled: bool = False,
|
||||
):
|
||||
"""
|
||||
Add a new room
|
||||
"""
|
||||
if webhook_url and not webhook_secret:
|
||||
webhook_secret = secrets.token_urlsafe(32)
|
||||
|
||||
room = Room(
|
||||
name=name,
|
||||
user_id=user_id,
|
||||
@@ -122,21 +110,33 @@ class RoomController:
|
||||
recording_type=recording_type,
|
||||
recording_trigger=recording_trigger,
|
||||
is_shared=is_shared,
|
||||
webhook_url=webhook_url,
|
||||
webhook_secret=webhook_secret,
|
||||
ics_url=ics_url,
|
||||
ics_fetch_interval=ics_fetch_interval,
|
||||
ics_enabled=ics_enabled,
|
||||
)
|
||||
query = rooms.insert().values(**room.model_dump())
|
||||
new_room = RoomModel(**room.model_dump())
|
||||
session.add(new_room)
|
||||
try:
|
||||
await get_database().execute(query)
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
raise HTTPException(status_code=400, detail="Room name is not unique")
|
||||
return room
|
||||
|
||||
async def update(self, room: Room, values: dict, mutate=True):
|
||||
async def update(
|
||||
self, session: AsyncSession, room: Room, values: dict, mutate=True
|
||||
):
|
||||
"""
|
||||
Update a room fields with key/values in values
|
||||
"""
|
||||
query = rooms.update().where(rooms.c.id == room.id).values(**values)
|
||||
if values.get("webhook_url") and not values.get("webhook_secret"):
|
||||
values["webhook_secret"] = secrets.token_urlsafe(32)
|
||||
|
||||
query = update(RoomModel).where(RoomModel.id == room.id).values(**values)
|
||||
try:
|
||||
await get_database().execute(query)
|
||||
await session.execute(query)
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
raise HTTPException(status_code=400, detail="Room name is not unique")
|
||||
|
||||
@@ -144,60 +144,79 @@ class RoomController:
|
||||
for key, value in values.items():
|
||||
setattr(room, key, value)
|
||||
|
||||
async def get_by_id(self, room_id: str, **kwargs) -> Room | None:
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, room_id: str, **kwargs
|
||||
) -> Room | None:
|
||||
"""
|
||||
Get a room by id
|
||||
"""
|
||||
query = rooms.select().where(rooms.c.id == room_id)
|
||||
query = select(RoomModel).where(RoomModel.id == room_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = query.where(RoomModel.user_id == kwargs["user_id"])
|
||||
result = await session.execute(query)
|
||||
row = result.scalars().first()
|
||||
if not row:
|
||||
return None
|
||||
return Room(**result)
|
||||
return Room.model_validate(row)
|
||||
|
||||
async def get_by_name(self, room_name: str, **kwargs) -> Room | None:
|
||||
async def get_by_name(
|
||||
self, session: AsyncSession, room_name: str, **kwargs
|
||||
) -> Room | None:
|
||||
"""
|
||||
Get a room by name
|
||||
"""
|
||||
query = rooms.select().where(rooms.c.name == room_name)
|
||||
query = select(RoomModel).where(RoomModel.name == room_name)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = query.where(RoomModel.user_id == kwargs["user_id"])
|
||||
result = await session.execute(query)
|
||||
row = result.scalars().first()
|
||||
if not row:
|
||||
return None
|
||||
return Room(**result)
|
||||
return Room.model_validate(row)
|
||||
|
||||
async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Room:
|
||||
async def get_by_id_for_http(
|
||||
self, session: AsyncSession, meeting_id: str, user_id: str | None
|
||||
) -> Room:
|
||||
"""
|
||||
Get a room by ID for HTTP request.
|
||||
|
||||
If not found, it will raise a 404 error.
|
||||
"""
|
||||
query = rooms.select().where(rooms.c.id == meeting_id)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = select(RoomModel).where(RoomModel.id == meeting_id)
|
||||
result = await session.execute(query)
|
||||
row = result.scalars().first()
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
room = Room(**result)
|
||||
room = Room.model_validate(row)
|
||||
|
||||
return room
|
||||
|
||||
async def get_ics_enabled(self, session: AsyncSession) -> list[Room]:
|
||||
query = select(RoomModel).where(
|
||||
RoomModel.ics_enabled == True, RoomModel.ics_url != None
|
||||
)
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
return [Room(**row.__dict__) for row in results]
|
||||
|
||||
async def remove_by_id(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
room_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Remove a room by id
|
||||
"""
|
||||
room = await self.get_by_id(room_id, user_id=user_id)
|
||||
room = await self.get_by_id(session, room_id, user_id=user_id)
|
||||
if not room:
|
||||
return
|
||||
if user_id is not None and room.user_id != user_id:
|
||||
return
|
||||
query = rooms.delete().where(rooms.c.id == room_id)
|
||||
await get_database().execute(query)
|
||||
query = delete(RoomModel).where(RoomModel.id == room_id)
|
||||
await session.execute(query)
|
||||
await session.flush()
|
||||
|
||||
|
||||
rooms_controller = RoomController()
|
||||
|
||||
@@ -1,22 +1,36 @@
|
||||
"""Search functionality for transcripts and other entities."""
|
||||
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from typing import Annotated, Any, Dict
|
||||
from typing import Annotated, Any, Dict, Iterator
|
||||
|
||||
import sqlalchemy
|
||||
import webvtt
|
||||
from pydantic import BaseModel, Field, constr, field_serializer
|
||||
from fastapi import HTTPException
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
TypeAdapter,
|
||||
ValidationError,
|
||||
constr,
|
||||
field_serializer,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.transcripts import SourceKind, transcripts
|
||||
from reflector.db.utils import is_postgresql
|
||||
from reflector.db.base import RoomModel, TranscriptModel
|
||||
from reflector.db.transcripts import SourceKind, TranscriptStatus
|
||||
from reflector.logger import logger
|
||||
from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
|
||||
|
||||
DEFAULT_SEARCH_LIMIT = 20
|
||||
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
|
||||
DEFAULT_SNIPPET_MAX_LENGTH = 150
|
||||
DEFAULT_MAX_SNIPPETS = 3
|
||||
DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
|
||||
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
|
||||
LONG_SUMMARY_MAX_SNIPPETS = 2
|
||||
|
||||
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
|
||||
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
|
||||
@@ -24,6 +38,7 @@ SearchOffsetBase = Annotated[int, Field(ge=0)]
|
||||
SearchTotalBase = Annotated[int, Field(ge=0)]
|
||||
|
||||
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
|
||||
search_query_adapter = TypeAdapter(SearchQuery)
|
||||
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
|
||||
SearchOffset = Annotated[
|
||||
SearchOffsetBase, Field(description="Number of results to skip")
|
||||
@@ -32,15 +47,92 @@ SearchTotal = Annotated[
|
||||
SearchTotalBase, Field(description="Total number of search results")
|
||||
]
|
||||
|
||||
WEBVTT_SPEC_HEADER = "WEBVTT"
|
||||
|
||||
WebVTTContent = Annotated[
|
||||
str,
|
||||
Field(min_length=len(WEBVTT_SPEC_HEADER), description="WebVTT content"),
|
||||
]
|
||||
|
||||
|
||||
class WebVTTProcessor:
|
||||
"""Stateless processor for WebVTT content operations."""
|
||||
|
||||
@staticmethod
|
||||
def parse(raw_content: str) -> WebVTTContent:
|
||||
"""Parse WebVTT content and return it as a string."""
|
||||
if not raw_content.startswith(WEBVTT_SPEC_HEADER):
|
||||
raise ValueError(f"Invalid WebVTT content, no header {WEBVTT_SPEC_HEADER}")
|
||||
return raw_content
|
||||
|
||||
@staticmethod
|
||||
def extract_text(webvtt_content: WebVTTContent) -> str:
|
||||
"""Extract plain text from WebVTT content using webvtt library."""
|
||||
try:
|
||||
buffer = StringIO(webvtt_content)
|
||||
vtt = webvtt.read_buffer(buffer)
|
||||
return " ".join(caption.text for caption in vtt if caption.text)
|
||||
except webvtt.errors.MalformedFileError as e:
|
||||
logger.warning(f"Malformed WebVTT content: {e}")
|
||||
return ""
|
||||
except (UnicodeDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to decode WebVTT content: {e}")
|
||||
return ""
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"WebVTT parsing error - unexpected format: {e}", exc_info=True
|
||||
)
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error parsing WebVTT: {e}", exc_info=True)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def generate_snippets(
|
||||
webvtt_content: WebVTTContent,
|
||||
query: SearchQuery,
|
||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from WebVTT content."""
|
||||
return SnippetGenerator.generate(
|
||||
WebVTTProcessor.extract_text(webvtt_content),
|
||||
query,
|
||||
max_snippets=max_snippets,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SnippetCandidate:
|
||||
"""Represents a candidate snippet with its position."""
|
||||
|
||||
_text: str
|
||||
start: NonNegativeInt
|
||||
_original_text_length: int
|
||||
|
||||
@property
|
||||
def end(self) -> NonNegativeInt:
|
||||
"""Calculate end position from start and raw text length."""
|
||||
return self.start + len(self._text)
|
||||
|
||||
def text(self) -> str:
|
||||
"""Get display text with ellipses added if needed."""
|
||||
result = self._text.strip()
|
||||
if self.start > 0:
|
||||
result = "..." + result
|
||||
if self.end < self._original_text_length:
|
||||
result = result + "..."
|
||||
return result
|
||||
|
||||
|
||||
class SearchParameters(BaseModel):
|
||||
"""Validated search parameters for full-text search."""
|
||||
|
||||
query_text: SearchQuery
|
||||
query_text: SearchQuery | None = None
|
||||
limit: SearchLimit = DEFAULT_SEARCH_LIMIT
|
||||
offset: SearchOffset = 0
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
source_kind: SourceKind | None = None
|
||||
|
||||
|
||||
class SearchResultDB(BaseModel):
|
||||
@@ -64,13 +156,18 @@ class SearchResult(BaseModel):
|
||||
title: str | None = None
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
room_name: str | None = None
|
||||
source_kind: SourceKind
|
||||
created_at: datetime
|
||||
status: str = Field(..., min_length=1)
|
||||
status: TranscriptStatus = Field(..., min_length=1)
|
||||
rank: float = Field(..., ge=0, le=1)
|
||||
duration: float | None = Field(..., ge=0, description="Duration in seconds")
|
||||
duration: NonNegativeFloat | None = Field(..., description="Duration in seconds")
|
||||
search_snippets: list[str] = Field(
|
||||
description="Text snippets around search matches"
|
||||
)
|
||||
total_match_count: NonNegativeInt = Field(
|
||||
default=0, description="Total number of matches found in the transcript"
|
||||
)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def serialize_datetime(self, dt: datetime) -> str:
|
||||
@@ -79,153 +176,289 @@ class SearchResult(BaseModel):
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
class SearchController:
|
||||
"""Controller for search operations across different entities."""
|
||||
class SnippetGenerator:
|
||||
"""Stateless generator for text snippets and match operations."""
|
||||
|
||||
@staticmethod
|
||||
def _extract_webvtt_text(webvtt_content: str) -> str:
|
||||
"""Extract plain text from WebVTT content using webvtt library."""
|
||||
if not webvtt_content:
|
||||
return ""
|
||||
def find_all_matches(text: str, query: str) -> Iterator[int]:
|
||||
"""Generate all match positions for a query in text."""
|
||||
if not text:
|
||||
logger.warning("Empty text for search query in find_all_matches")
|
||||
return
|
||||
if not query:
|
||||
logger.warning("Empty query for search text in find_all_matches")
|
||||
return
|
||||
|
||||
try:
|
||||
buffer = StringIO(webvtt_content)
|
||||
vtt = webvtt.read_buffer(buffer)
|
||||
return " ".join(caption.text for caption in vtt if caption.text)
|
||||
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
|
||||
return ""
|
||||
except AttributeError as e:
|
||||
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
|
||||
return ""
|
||||
text_lower = text.lower()
|
||||
query_lower = query.lower()
|
||||
start = 0
|
||||
prev_start = start
|
||||
while (pos := text_lower.find(query_lower, start)) != -1:
|
||||
yield pos
|
||||
start = pos + len(query_lower)
|
||||
if start <= prev_start:
|
||||
raise ValueError("panic! find_all_matches is not incremental")
|
||||
prev_start = start
|
||||
|
||||
@staticmethod
|
||||
def _generate_snippets(
|
||||
def count_matches(text: str, query: SearchQuery) -> NonNegativeInt:
|
||||
"""Count total number of matches for a query in text."""
|
||||
ZERO = NonNegativeInt(0)
|
||||
if not text:
|
||||
logger.warning("Empty text for search query in count_matches")
|
||||
return ZERO
|
||||
assert query is not None
|
||||
return NonNegativeInt(
|
||||
sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_snippet(
|
||||
text: str, match_pos: int, max_length: int = DEFAULT_SNIPPET_MAX_LENGTH
|
||||
) -> SnippetCandidate:
|
||||
"""Create a snippet from a match position."""
|
||||
snippet_start = NonNegativeInt(max(0, match_pos - SNIPPET_CONTEXT_LENGTH))
|
||||
snippet_end = min(len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH)
|
||||
|
||||
snippet_text = text[snippet_start:snippet_end]
|
||||
|
||||
return SnippetCandidate(
|
||||
_text=snippet_text, start=snippet_start, _original_text_length=len(text)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_non_overlapping(
|
||||
candidates: Iterator[SnippetCandidate],
|
||||
) -> Iterator[str]:
|
||||
"""Filter out overlapping snippets and return only display text."""
|
||||
last_end = 0
|
||||
for candidate in candidates:
|
||||
display_text = candidate.text()
|
||||
# it means that next overlapping snippets simply don't get included
|
||||
# it's fine as simplistic logic and users probably won't care much because they already have their search results just fin
|
||||
if candidate.start >= last_end and display_text:
|
||||
yield display_text
|
||||
last_end = candidate.end
|
||||
|
||||
@staticmethod
|
||||
def generate(
|
||||
text: str,
|
||||
q: SearchQuery,
|
||||
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||
max_snippets: int = DEFAULT_MAX_SNIPPETS,
|
||||
query: SearchQuery,
|
||||
max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate multiple snippets around all occurrences of search term."""
|
||||
if not text or not q:
|
||||
"""Generate snippets from text."""
|
||||
assert query is not None
|
||||
if not text:
|
||||
logger.warning("Empty text for generate_snippets")
|
||||
return []
|
||||
|
||||
snippets = []
|
||||
lower_text = text.lower()
|
||||
search_lower = q.lower()
|
||||
candidates = (
|
||||
SnippetGenerator.create_snippet(text, pos, max_length)
|
||||
for pos in SnippetGenerator.find_all_matches(text, query)
|
||||
)
|
||||
filtered = SnippetGenerator.filter_non_overlapping(candidates)
|
||||
snippets = list(itertools.islice(filtered, max_snippets))
|
||||
|
||||
last_snippet_end = 0
|
||||
start_pos = 0
|
||||
|
||||
while len(snippets) < max_snippets:
|
||||
match_pos = lower_text.find(search_lower, start_pos)
|
||||
|
||||
if match_pos == -1:
|
||||
if not snippets and search_lower.split():
|
||||
first_word = search_lower.split()[0]
|
||||
match_pos = lower_text.find(first_word, start_pos)
|
||||
if match_pos == -1:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
|
||||
snippet_end = min(
|
||||
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
|
||||
)
|
||||
|
||||
if snippet_start < last_snippet_end:
|
||||
start_pos = match_pos + len(search_lower)
|
||||
continue
|
||||
|
||||
snippet = text[snippet_start:snippet_end]
|
||||
|
||||
if snippet_start > 0:
|
||||
snippet = "..." + snippet
|
||||
if snippet_end < len(text):
|
||||
snippet = snippet + "..."
|
||||
|
||||
snippet = snippet.strip()
|
||||
|
||||
if snippet:
|
||||
snippets.append(snippet)
|
||||
last_snippet_end = snippet_end
|
||||
|
||||
start_pos = match_pos + len(search_lower)
|
||||
if start_pos >= len(text):
|
||||
break
|
||||
# Fallback to first word search if no full matches
|
||||
# it's another assumption: proper snippet logic generation is quite complicated and tied to db logic, so simplification is used here
|
||||
if not snippets and " " in query:
|
||||
first_word = query.split()[0]
|
||||
return SnippetGenerator.generate(text, first_word, max_length, max_snippets)
|
||||
|
||||
return snippets
|
||||
|
||||
@staticmethod
|
||||
def from_summary(
|
||||
summary: str,
|
||||
query: SearchQuery,
|
||||
max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from summary text."""
|
||||
return SnippetGenerator.generate(summary, query, max_snippets=max_snippets)
|
||||
|
||||
@staticmethod
|
||||
def combine_sources(
|
||||
summary: NonEmptyString | None,
|
||||
webvtt: WebVTTContent | None,
|
||||
query: SearchQuery,
|
||||
max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> tuple[list[str], NonNegativeInt]:
|
||||
"""Combine snippets from multiple sources and return total match count.
|
||||
|
||||
Returns (snippets, total_match_count) tuple.
|
||||
|
||||
snippets can be empty for real in case of e.g. title match
|
||||
"""
|
||||
|
||||
assert (
|
||||
summary is not None or webvtt is not None
|
||||
), "At least one source must be present"
|
||||
|
||||
webvtt_matches = 0
|
||||
summary_matches = 0
|
||||
|
||||
if webvtt:
|
||||
webvtt_text = WebVTTProcessor.extract_text(webvtt)
|
||||
webvtt_matches = SnippetGenerator.count_matches(webvtt_text, query)
|
||||
|
||||
if summary:
|
||||
summary_matches = SnippetGenerator.count_matches(summary, query)
|
||||
|
||||
total_matches = NonNegativeInt(webvtt_matches + summary_matches)
|
||||
|
||||
summary_snippets = (
|
||||
SnippetGenerator.from_summary(summary, query) if summary else []
|
||||
)
|
||||
|
||||
if len(summary_snippets) >= max_total:
|
||||
return summary_snippets[:max_total], total_matches
|
||||
|
||||
remaining = max_total - len(summary_snippets)
|
||||
webvtt_snippets = (
|
||||
WebVTTProcessor.generate_snippets(webvtt, query, remaining)
|
||||
if webvtt
|
||||
else []
|
||||
)
|
||||
|
||||
return summary_snippets + webvtt_snippets, total_matches
|
||||
|
||||
|
||||
class SearchController:
|
||||
"""Controller for search operations across different entities."""
|
||||
|
||||
@classmethod
|
||||
async def search_transcripts(
|
||||
cls, params: SearchParameters
|
||||
cls, session: AsyncSession, params: SearchParameters
|
||||
) -> tuple[list[SearchResult], int]:
|
||||
"""
|
||||
Full-text search for transcripts using PostgreSQL tsvector.
|
||||
Returns (results, total_count).
|
||||
"""
|
||||
|
||||
if not is_postgresql():
|
||||
logger.warning(
|
||||
"Full-text search requires PostgreSQL. Returning empty results."
|
||||
base_columns = [
|
||||
TranscriptModel.id,
|
||||
TranscriptModel.title,
|
||||
TranscriptModel.created_at,
|
||||
TranscriptModel.duration,
|
||||
TranscriptModel.status,
|
||||
TranscriptModel.user_id,
|
||||
TranscriptModel.room_id,
|
||||
TranscriptModel.source_kind,
|
||||
TranscriptModel.webvtt,
|
||||
TranscriptModel.long_summary,
|
||||
sqlalchemy.case(
|
||||
(
|
||||
TranscriptModel.room_id.isnot(None) & RoomModel.id.is_(None),
|
||||
"Deleted Room",
|
||||
),
|
||||
else_=RoomModel.name,
|
||||
).label("room_name"),
|
||||
]
|
||||
search_query = None
|
||||
if params.query_text is not None:
|
||||
search_query = sqlalchemy.func.websearch_to_tsquery(
|
||||
"english", params.query_text
|
||||
)
|
||||
return [], 0
|
||||
rank_column = sqlalchemy.func.ts_rank(
|
||||
TranscriptModel.search_vector_en,
|
||||
search_query,
|
||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||
).label("rank")
|
||||
else:
|
||||
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
||||
|
||||
search_query = sqlalchemy.func.websearch_to_tsquery(
|
||||
"english", params.query_text
|
||||
columns = base_columns + [rank_column]
|
||||
base_query = (
|
||||
sqlalchemy.select(*columns)
|
||||
.select_from(TranscriptModel)
|
||||
.outerjoin(RoomModel, TranscriptModel.room_id == RoomModel.id)
|
||||
)
|
||||
|
||||
base_query = sqlalchemy.select(
|
||||
[
|
||||
transcripts.c.id,
|
||||
transcripts.c.title,
|
||||
transcripts.c.created_at,
|
||||
transcripts.c.duration,
|
||||
transcripts.c.status,
|
||||
transcripts.c.user_id,
|
||||
transcripts.c.room_id,
|
||||
transcripts.c.source_kind,
|
||||
transcripts.c.webvtt,
|
||||
sqlalchemy.func.ts_rank(
|
||||
transcripts.c.search_vector_en,
|
||||
search_query,
|
||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||
).label("rank"),
|
||||
]
|
||||
).where(transcripts.c.search_vector_en.op("@@")(search_query))
|
||||
if params.query_text is not None:
|
||||
# because already initialized based on params.query_text presence above
|
||||
assert search_query is not None
|
||||
base_query = base_query.where(
|
||||
TranscriptModel.search_vector_en.op("@@")(search_query)
|
||||
)
|
||||
|
||||
if params.user_id:
|
||||
base_query = base_query.where(transcripts.c.user_id == params.user_id)
|
||||
base_query = base_query.where(
|
||||
sqlalchemy.or_(
|
||||
TranscriptModel.user_id == params.user_id, RoomModel.is_shared
|
||||
)
|
||||
)
|
||||
else:
|
||||
base_query = base_query.where(RoomModel.is_shared)
|
||||
if params.room_id:
|
||||
base_query = base_query.where(transcripts.c.room_id == params.room_id)
|
||||
base_query = base_query.where(TranscriptModel.room_id == params.room_id)
|
||||
if params.source_kind:
|
||||
base_query = base_query.where(
|
||||
TranscriptModel.source_kind == params.source_kind
|
||||
)
|
||||
|
||||
query = (
|
||||
base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
|
||||
.limit(params.limit)
|
||||
.offset(params.offset)
|
||||
)
|
||||
rs = await get_database().fetch_all(query)
|
||||
if params.query_text is not None:
|
||||
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
|
||||
else:
|
||||
order_by = sqlalchemy.desc(TranscriptModel.created_at)
|
||||
|
||||
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
||||
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
|
||||
|
||||
result = await session.execute(query)
|
||||
rs = result.mappings().all()
|
||||
|
||||
count_query = sqlalchemy.select(sqlalchemy.func.count()).select_from(
|
||||
base_query.alias("search_results")
|
||||
)
|
||||
total = await get_database().fetch_val(count_query)
|
||||
count_result = await session.execute(count_query)
|
||||
total = count_result.scalar()
|
||||
|
||||
def _process_result(r) -> SearchResult:
|
||||
def _process_result(r: dict) -> SearchResult:
|
||||
r_dict: Dict[str, Any] = dict(r)
|
||||
webvtt: str | None = r_dict.pop("webvtt", None)
|
||||
|
||||
webvtt_raw: str | None = r_dict.pop("webvtt", None)
|
||||
webvtt: WebVTTContent | None
|
||||
if webvtt_raw:
|
||||
webvtt = WebVTTProcessor.parse(webvtt_raw)
|
||||
else:
|
||||
webvtt = None
|
||||
|
||||
long_summary_r: str | None = r_dict.pop("long_summary", None)
|
||||
long_summary: NonEmptyString = try_parse_non_empty_string(long_summary_r)
|
||||
room_name: str | None = r_dict.pop("room_name", None)
|
||||
db_result = SearchResultDB.model_validate(r_dict)
|
||||
|
||||
snippets = []
|
||||
if webvtt:
|
||||
plain_text = cls._extract_webvtt_text(webvtt)
|
||||
snippets = cls._generate_snippets(plain_text, params.query_text)
|
||||
at_least_one_source = webvtt is not None or long_summary is not None
|
||||
has_query = params.query_text is not None
|
||||
snippets, total_match_count = (
|
||||
SnippetGenerator.combine_sources(
|
||||
long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
|
||||
)
|
||||
if has_query and at_least_one_source
|
||||
else ([], 0)
|
||||
)
|
||||
|
||||
return SearchResult(**db_result.model_dump(), search_snippets=snippets)
|
||||
return SearchResult(
|
||||
**db_result.model_dump(),
|
||||
room_name=room_name,
|
||||
search_snippets=snippets,
|
||||
total_match_count=total_match_count,
|
||||
)
|
||||
|
||||
try:
|
||||
results = [_process_result(r) for r in rs]
|
||||
except ValidationError as e:
|
||||
logger.error(f"Invalid search result data: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal search result data consistency error"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing search results: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
results = [_process_result(r) for r in rs]
|
||||
return results, total
|
||||
|
||||
|
||||
search_controller = SearchController()
|
||||
webvtt_processor = WebVTTProcessor()
|
||||
snippet_generator = SnippetGenerator()
|
||||
|
||||
@@ -2,22 +2,18 @@ import enum
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import sqlalchemy
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy.dialects.postgresql import TSVECTOR
|
||||
from sqlalchemy.sql import false, or_
|
||||
from sqlalchemy import delete, insert, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import or_
|
||||
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.base import RoomModel, TranscriptModel
|
||||
from reflector.db.recordings import recordings_controller
|
||||
from reflector.db.rooms import rooms
|
||||
from reflector.db.utils import is_postgresql
|
||||
from reflector.logger import logger
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
from reflector.settings import settings
|
||||
@@ -32,93 +28,20 @@ class SourceKind(enum.StrEnum):
|
||||
FILE = enum.auto()
|
||||
|
||||
|
||||
transcripts = sqlalchemy.Table(
|
||||
"transcript",
|
||||
metadata,
|
||||
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||
sqlalchemy.Column("name", sqlalchemy.String),
|
||||
sqlalchemy.Column("status", sqlalchemy.String),
|
||||
sqlalchemy.Column("locked", sqlalchemy.Boolean),
|
||||
sqlalchemy.Column("duration", sqlalchemy.Float),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
|
||||
sqlalchemy.Column("title", sqlalchemy.String),
|
||||
sqlalchemy.Column("short_summary", sqlalchemy.String),
|
||||
sqlalchemy.Column("long_summary", sqlalchemy.String),
|
||||
sqlalchemy.Column("topics", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("events", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("participants", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("source_language", sqlalchemy.String),
|
||||
sqlalchemy.Column("target_language", sqlalchemy.String),
|
||||
sqlalchemy.Column(
|
||||
"reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"audio_location",
|
||||
sqlalchemy.String,
|
||||
nullable=False,
|
||||
server_default="local",
|
||||
),
|
||||
# with user attached, optional
|
||||
sqlalchemy.Column("user_id", sqlalchemy.String),
|
||||
sqlalchemy.Column(
|
||||
"share_mode",
|
||||
sqlalchemy.String,
|
||||
nullable=False,
|
||||
server_default="private",
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"meeting_id",
|
||||
sqlalchemy.String,
|
||||
),
|
||||
sqlalchemy.Column("recording_id", sqlalchemy.String),
|
||||
sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer),
|
||||
sqlalchemy.Column(
|
||||
"source_kind",
|
||||
Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]),
|
||||
nullable=False,
|
||||
),
|
||||
# indicative field: whether associated audio is deleted
|
||||
# the main "audio deleted" is the presence of the audio itself / consents not-given
|
||||
# same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
|
||||
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
|
||||
sqlalchemy.Column("room_id", sqlalchemy.String),
|
||||
sqlalchemy.Column("webvtt", sqlalchemy.Text),
|
||||
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
|
||||
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
|
||||
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
||||
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
||||
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
|
||||
)
|
||||
|
||||
# Add PostgreSQL-specific full-text search column
|
||||
# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
|
||||
if is_postgresql():
|
||||
transcripts.append_column(
|
||||
sqlalchemy.Column(
|
||||
"search_vector_en",
|
||||
TSVECTOR,
|
||||
sqlalchemy.Computed(
|
||||
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
|
||||
persisted=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
# Add GIN index for the search vector
|
||||
transcripts.append_constraint(
|
||||
sqlalchemy.Index(
|
||||
"idx_transcript_search_vector_en",
|
||||
"search_vector_en",
|
||||
postgresql_using="gin",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def generate_transcript_name() -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
|
||||
TranscriptStatus = Literal[
|
||||
"idle", "uploaded", "recording", "processing", "error", "ended"
|
||||
]
|
||||
|
||||
|
||||
class StrValue(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class AudioWaveform(BaseModel):
|
||||
data: list[float]
|
||||
|
||||
@@ -179,10 +102,12 @@ class TranscriptParticipant(BaseModel):
|
||||
class Transcript(BaseModel):
|
||||
"""Full transcript model with all fields."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
user_id: str | None = None
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
status: str = "idle"
|
||||
status: TranscriptStatus = "idle"
|
||||
duration: float = 0
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
title: str | None = None
|
||||
@@ -347,6 +272,7 @@ class Transcript(BaseModel):
|
||||
class TranscriptController:
|
||||
async def get_all(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
order_by: str | None = None,
|
||||
filter_empty: bool | None = False,
|
||||
@@ -371,102 +297,114 @@ class TranscriptController:
|
||||
- `search_term`: filter transcripts by search term
|
||||
"""
|
||||
|
||||
query = transcripts.select().join(
|
||||
rooms, transcripts.c.room_id == rooms.c.id, isouter=True
|
||||
query = select(TranscriptModel).join(
|
||||
RoomModel, TranscriptModel.room_id == RoomModel.id, isouter=True
|
||||
)
|
||||
|
||||
if user_id:
|
||||
query = query.where(
|
||||
or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
|
||||
or_(TranscriptModel.user_id == user_id, RoomModel.is_shared)
|
||||
)
|
||||
else:
|
||||
query = query.where(rooms.c.is_shared)
|
||||
query = query.where(RoomModel.is_shared)
|
||||
|
||||
if source_kind:
|
||||
query = query.where(transcripts.c.source_kind == source_kind)
|
||||
query = query.where(TranscriptModel.source_kind == source_kind)
|
||||
|
||||
if room_id:
|
||||
query = query.where(transcripts.c.room_id == room_id)
|
||||
query = query.where(TranscriptModel.room_id == room_id)
|
||||
|
||||
if search_term:
|
||||
query = query.where(transcripts.c.title.ilike(f"%{search_term}%"))
|
||||
query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
|
||||
|
||||
# Exclude heavy JSON columns from list queries
|
||||
# Get all ORM column attributes except excluded ones
|
||||
transcript_columns = [
|
||||
col for col in transcripts.c if col.name not in exclude_columns
|
||||
getattr(TranscriptModel, col.name)
|
||||
for col in TranscriptModel.__table__.c
|
||||
if col.name not in exclude_columns
|
||||
]
|
||||
|
||||
query = query.with_only_columns(
|
||||
transcript_columns
|
||||
+ [
|
||||
rooms.c.name.label("room_name"),
|
||||
]
|
||||
*transcript_columns,
|
||||
RoomModel.name.label("room_name"),
|
||||
)
|
||||
|
||||
if order_by is not None:
|
||||
field = getattr(transcripts.c, order_by[1:])
|
||||
field = getattr(TranscriptModel, order_by[1:])
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
|
||||
if filter_empty:
|
||||
query = query.filter(transcripts.c.status != "idle")
|
||||
query = query.filter(TranscriptModel.status != "idle")
|
||||
|
||||
if filter_recording:
|
||||
query = query.filter(transcripts.c.status != "recording")
|
||||
query = query.filter(TranscriptModel.status != "recording")
|
||||
|
||||
# print(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
|
||||
if return_query:
|
||||
return query
|
||||
|
||||
results = await get_database().fetch_all(query)
|
||||
return results
|
||||
result = await session.execute(query)
|
||||
return [dict(row) for row in result.mappings().all()]
|
||||
|
||||
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, transcript_id: str, **kwargs
|
||||
) -> Transcript | None:
|
||||
"""
|
||||
Get a transcript by id
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
return Transcript.model_validate(row)
|
||||
|
||||
async def get_by_recording_id(
|
||||
self, recording_id: str, **kwargs
|
||||
self, session: AsyncSession, recording_id: str, **kwargs
|
||||
) -> Transcript | None:
|
||||
"""
|
||||
Get a transcript by recording_id
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
|
||||
query = select(TranscriptModel).where(
|
||||
TranscriptModel.recording_id == recording_id
|
||||
)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
return Transcript.model_validate(row)
|
||||
|
||||
async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]:
|
||||
async def get_by_room_id(
|
||||
self, session: AsyncSession, room_id: str, **kwargs
|
||||
) -> list[Transcript]:
|
||||
"""
|
||||
Get transcripts by room_id (direct access without joins)
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.room_id == room_id)
|
||||
query = select(TranscriptModel).where(TranscriptModel.room_id == room_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
|
||||
if "order_by" in kwargs:
|
||||
order_by = kwargs["order_by"]
|
||||
field = getattr(transcripts.c, order_by[1:])
|
||||
field = getattr(TranscriptModel, order_by[1:])
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
results = await get_database().fetch_all(query)
|
||||
return [Transcript(**result) for result in results]
|
||||
results = await session.execute(query)
|
||||
return [
|
||||
Transcript.model_validate(dict(row)) for row in results.mappings().all()
|
||||
]
|
||||
|
||||
async def get_by_id_for_http(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript_id: str,
|
||||
user_id: str | None,
|
||||
) -> Transcript:
|
||||
@@ -479,13 +417,14 @@ class TranscriptController:
|
||||
This method checks the share mode of the transcript and the user_id
|
||||
to determine if the user can access the transcript.
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
# if the transcript is anonymous, share mode is not checked
|
||||
transcript = Transcript(**result)
|
||||
transcript = Transcript.model_validate(row)
|
||||
if transcript.user_id is None:
|
||||
return transcript
|
||||
|
||||
@@ -508,6 +447,7 @@ class TranscriptController:
|
||||
|
||||
async def add(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
name: str,
|
||||
source_kind: SourceKind,
|
||||
source_language: str = "en",
|
||||
@@ -532,14 +472,20 @@ class TranscriptController:
|
||||
meeting_id=meeting_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await get_database().execute(query)
|
||||
query = insert(TranscriptModel).values(**transcript.model_dump())
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
return transcript
|
||||
|
||||
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
|
||||
# using mutate=True is discouraged
|
||||
async def update(
|
||||
self, transcript: Transcript, values: dict, mutate=False
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript: Transcript,
|
||||
values: dict,
|
||||
commit=True,
|
||||
mutate=False,
|
||||
) -> Transcript:
|
||||
"""
|
||||
Update a transcript fields with key/values in values.
|
||||
@@ -548,11 +494,13 @@ class TranscriptController:
|
||||
values = TranscriptController._handle_topics_update(values)
|
||||
|
||||
query = (
|
||||
transcripts.update()
|
||||
.where(transcripts.c.id == transcript.id)
|
||||
update(TranscriptModel)
|
||||
.where(TranscriptModel.id == transcript.id)
|
||||
.values(**values)
|
||||
)
|
||||
await get_database().execute(query)
|
||||
await session.execute(query)
|
||||
if commit:
|
||||
await session.commit()
|
||||
if mutate:
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
@@ -581,13 +529,14 @@ class TranscriptController:
|
||||
|
||||
async def remove_by_id(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Remove a transcript by id
|
||||
"""
|
||||
transcript = await self.get_by_id(transcript_id)
|
||||
transcript = await self.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
if user_id is not None and transcript.user_id != user_id:
|
||||
@@ -607,7 +556,7 @@ class TranscriptController:
|
||||
if transcript.recording_id:
|
||||
try:
|
||||
recording = await recordings_controller.get_by_id(
|
||||
transcript.recording_id
|
||||
session, transcript.recording_id
|
||||
)
|
||||
if recording:
|
||||
try:
|
||||
@@ -618,46 +567,49 @@ class TranscriptController:
|
||||
exc_info=e,
|
||||
recording_id=transcript.recording_id,
|
||||
)
|
||||
await recordings_controller.remove_by_id(transcript.recording_id)
|
||||
await recordings_controller.remove_by_id(
|
||||
session, transcript.recording_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to delete recording row",
|
||||
exc_info=e,
|
||||
recording_id=transcript.recording_id,
|
||||
)
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await get_database().execute(query)
|
||||
query = delete(TranscriptModel).where(TranscriptModel.id == transcript_id)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
|
||||
async def remove_by_recording_id(self, recording_id: str):
|
||||
async def remove_by_recording_id(self, session: AsyncSession, recording_id: str):
|
||||
"""
|
||||
Remove a transcript by recording_id
|
||||
"""
|
||||
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
|
||||
await get_database().execute(query)
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
"""
|
||||
A context manager for database transaction
|
||||
"""
|
||||
async with get_database().transaction(isolation="serializable"):
|
||||
yield
|
||||
query = delete(TranscriptModel).where(
|
||||
TranscriptModel.recording_id == recording_id
|
||||
)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
|
||||
async def append_event(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript: Transcript,
|
||||
event: str,
|
||||
data: Any,
|
||||
commit=True,
|
||||
) -> TranscriptEvent:
|
||||
"""
|
||||
Append an event to a transcript
|
||||
"""
|
||||
resp = transcript.add_event(event=event, data=data)
|
||||
await self.update(transcript, {"events": transcript.events_dump()})
|
||||
await self.update(
|
||||
session, transcript, {"events": transcript.events_dump()}, commit=commit
|
||||
)
|
||||
return resp
|
||||
|
||||
async def upsert_topic(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript: Transcript,
|
||||
topic: TranscriptTopic,
|
||||
) -> TranscriptEvent:
|
||||
@@ -665,9 +617,9 @@ class TranscriptController:
|
||||
Upsert topics to a transcript
|
||||
"""
|
||||
transcript.upsert_topic(topic)
|
||||
await self.update(transcript, {"topics": transcript.topics_dump()})
|
||||
await self.update(session, transcript, {"topics": transcript.topics_dump()})
|
||||
|
||||
async def move_mp3_to_storage(self, transcript: Transcript):
|
||||
async def move_mp3_to_storage(self, session: AsyncSession, transcript: Transcript):
|
||||
"""
|
||||
Move mp3 file to storage
|
||||
"""
|
||||
@@ -691,12 +643,16 @@ class TranscriptController:
|
||||
|
||||
# indicate on the transcript that the audio is now on storage
|
||||
# mutates transcript argument
|
||||
await self.update(transcript, {"audio_location": "storage"}, mutate=True)
|
||||
await self.update(
|
||||
session, transcript, {"audio_location": "storage"}, mutate=True
|
||||
)
|
||||
|
||||
# unlink the local file
|
||||
transcript.audio_mp3_filename.unlink(missing_ok=True)
|
||||
|
||||
async def download_mp3_from_storage(self, transcript: Transcript):
|
||||
async def download_mp3_from_storage(
|
||||
self, session: AsyncSession, transcript: Transcript
|
||||
):
|
||||
"""
|
||||
Download audio from storage
|
||||
"""
|
||||
@@ -708,6 +664,7 @@ class TranscriptController:
|
||||
|
||||
async def upsert_participant(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript: Transcript,
|
||||
participant: TranscriptParticipant,
|
||||
) -> TranscriptParticipant:
|
||||
@@ -715,11 +672,14 @@ class TranscriptController:
|
||||
Add/update a participant to a transcript
|
||||
"""
|
||||
result = transcript.upsert_participant(participant)
|
||||
await self.update(transcript, {"participants": transcript.participants_dump()})
|
||||
await self.update(
|
||||
session, transcript, {"participants": transcript.participants_dump()}
|
||||
)
|
||||
return result
|
||||
|
||||
async def delete_participant(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript: Transcript,
|
||||
participant_id: str,
|
||||
):
|
||||
@@ -727,7 +687,38 @@ class TranscriptController:
|
||||
Delete a participant from a transcript
|
||||
"""
|
||||
transcript.delete_participant(participant_id)
|
||||
await self.update(transcript, {"participants": transcript.participants_dump()})
|
||||
await self.update(
|
||||
session, transcript, {"participants": transcript.participants_dump()}
|
||||
)
|
||||
|
||||
async def set_status(
|
||||
self, session: AsyncSession, transcript_id: str, status: TranscriptStatus
|
||||
) -> TranscriptEvent | None:
|
||||
"""
|
||||
Update the status of a transcript
|
||||
|
||||
Will add an event STATUS + update the status field of transcript
|
||||
"""
|
||||
transcript = await self.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
if transcript.status == status:
|
||||
return
|
||||
resp = await self.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="STATUS",
|
||||
data=StrValue(value=status),
|
||||
commit=False,
|
||||
)
|
||||
await self.update(
|
||||
session,
|
||||
transcript,
|
||||
{"status": status},
|
||||
commit=False,
|
||||
)
|
||||
await session.commit()
|
||||
return resp
|
||||
|
||||
|
||||
transcripts_controller = TranscriptController()
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""Database utility functions."""
|
||||
|
||||
from reflector.db import get_database
|
||||
|
||||
|
||||
def is_postgresql() -> bool:
|
||||
return get_database().url.scheme and get_database().url.scheme.startswith(
|
||||
"postgresql"
|
||||
)
|
||||
445
server/reflector/pipelines/main_file_pipeline.py
Normal file
445
server/reflector/pipelines/main_file_pipeline.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
File-based processing pipeline
|
||||
==============================
|
||||
|
||||
Optimized pipeline for processing complete audio/video files.
|
||||
Uses parallel processing for transcription, diarization, and waveform generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import structlog
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.transcripts import (
|
||||
SourceKind,
|
||||
Transcript,
|
||||
TranscriptStatus,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.main_live_pipeline import (
|
||||
PipelineMainBase,
|
||||
broadcast_to_sockets,
|
||||
task_cleanup_consent_taskiq,
|
||||
task_pipeline_post_to_zulip_taskiq,
|
||||
)
|
||||
from reflector.processors import (
|
||||
AudioFileWriterProcessor,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
)
|
||||
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
||||
from reflector.processors.file_diarization import FileDiarizationInput
|
||||
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||
from reflector.processors.file_transcript import FileTranscriptInput
|
||||
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||
from reflector.processors.transcript_diarization_assembler import (
|
||||
TranscriptDiarizationAssemblerInput,
|
||||
TranscriptDiarizationAssemblerProcessor,
|
||||
)
|
||||
from reflector.processors.types import (
|
||||
DiarizationSegment,
|
||||
TitleSummary,
|
||||
)
|
||||
from reflector.processors.types import (
|
||||
Transcript as TranscriptType,
|
||||
)
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.worker.app import taskiq_broker
|
||||
from reflector.worker.session_decorator import catch_exception, with_session
|
||||
from reflector.worker.webhook import send_transcript_webhook_taskiq
|
||||
|
||||
|
||||
class EmptyPipeline:
|
||||
"""Empty pipeline for processors that need a pipeline reference"""
|
||||
|
||||
def __init__(self, logger: structlog.BoundLogger):
|
||||
self.logger = logger
|
||||
|
||||
def get_pref(self, k, d=None):
|
||||
return d
|
||||
|
||||
async def emit(self, event):
|
||||
pass
|
||||
|
||||
|
||||
class PipelineMainFile(PipelineMainBase):
|
||||
"""
|
||||
Optimized file processing pipeline.
|
||||
Processes complete audio/video files with parallel execution.
|
||||
"""
|
||||
|
||||
logger: structlog.BoundLogger = None
|
||||
empty_pipeline = None
|
||||
|
||||
def __init__(self, transcript_id: str):
|
||||
super().__init__(transcript_id=transcript_id)
|
||||
self.logger = logger.bind(transcript_id=self.transcript_id)
|
||||
self.empty_pipeline = EmptyPipeline(logger=self.logger)
|
||||
|
||||
def _handle_gather_exceptions(self, results: list, operation: str) -> None:
|
||||
"""Handle exceptions from asyncio.gather with return_exceptions=True"""
|
||||
for i, result in enumerate(results):
|
||||
if not isinstance(result, Exception):
|
||||
continue
|
||||
self.logger.error(
|
||||
f"Error in {operation} (task {i}): {result}",
|
||||
transcript_id=self.transcript_id,
|
||||
exc_info=result,
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def set_status(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript_id: str,
|
||||
status: TranscriptStatus,
|
||||
):
|
||||
return await transcripts_controller.set_status(session, transcript_id, status)
|
||||
|
||||
async def process(self, session: AsyncSession, file_path: Path):
|
||||
"""Main entry point for file processing"""
|
||||
self.logger.info(f"Starting file pipeline for {file_path}")
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
|
||||
# Clear transcript as we're going to regenerate everything
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"events": [],
|
||||
"topics": [],
|
||||
},
|
||||
)
|
||||
|
||||
# Extract audio and write to transcript location
|
||||
audio_path = await self.extract_and_write_audio(file_path, transcript)
|
||||
|
||||
# Upload for processing
|
||||
audio_url = await self.upload_audio(audio_path, transcript)
|
||||
|
||||
# Run parallel processing
|
||||
await self.run_parallel_processing(
|
||||
session,
|
||||
audio_path,
|
||||
audio_url,
|
||||
transcript.source_language,
|
||||
transcript.target_language,
|
||||
)
|
||||
|
||||
self.logger.info("File pipeline complete")
|
||||
|
||||
await transcripts_controller.set_status(session, transcript.id, "ended")
|
||||
|
||||
async def extract_and_write_audio(
|
||||
self, file_path: Path, transcript: Transcript
|
||||
) -> Path:
|
||||
"""Extract audio from video if needed and write to transcript location as MP3"""
|
||||
self.logger.info(f"Processing audio file: {file_path}")
|
||||
|
||||
# Check if it's already audio-only
|
||||
container = av.open(str(file_path))
|
||||
has_video = len(container.streams.video) > 0
|
||||
container.close()
|
||||
|
||||
# Use AudioFileWriterProcessor to write MP3 to transcript location
|
||||
mp3_writer = AudioFileWriterProcessor(
|
||||
path=transcript.audio_mp3_filename,
|
||||
on_duration=self.on_duration,
|
||||
)
|
||||
|
||||
# Process audio frames and write to transcript location
|
||||
input_container = av.open(str(file_path))
|
||||
for frame in input_container.decode(audio=0):
|
||||
await mp3_writer.push(frame)
|
||||
|
||||
await mp3_writer.flush()
|
||||
input_container.close()
|
||||
|
||||
if has_video:
|
||||
self.logger.info(
|
||||
f"Extracted audio from video and saved to {transcript.audio_mp3_filename}"
|
||||
)
|
||||
else:
|
||||
self.logger.info(
|
||||
f"Converted audio file and saved to {transcript.audio_mp3_filename}"
|
||||
)
|
||||
|
||||
return transcript.audio_mp3_filename
|
||||
|
||||
async def upload_audio(self, audio_path: Path, transcript: Transcript) -> str:
|
||||
"""Upload audio to storage for processing"""
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
if not storage:
|
||||
raise Exception(
|
||||
"Storage backend required for file processing. Configure TRANSCRIPT_STORAGE_* settings."
|
||||
)
|
||||
|
||||
self.logger.info("Uploading audio to storage")
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
storage_path = f"file_pipeline/{transcript.id}/audio.mp3"
|
||||
await storage.put_file(storage_path, audio_data)
|
||||
|
||||
audio_url = await storage.get_file_url(storage_path)
|
||||
|
||||
self.logger.info(f"Audio uploaded to {audio_url}")
|
||||
return audio_url
|
||||
|
||||
async def run_parallel_processing(
|
||||
self,
|
||||
session,
|
||||
audio_path: Path,
|
||||
audio_url: str,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
):
|
||||
"""Coordinate parallel processing of transcription, diarization, and waveform"""
|
||||
self.logger.info(
|
||||
"Starting parallel processing", transcript_id=self.transcript_id
|
||||
)
|
||||
|
||||
# Phase 1: Parallel processing of independent tasks
|
||||
transcription_task = self.transcribe_file(audio_url, source_language)
|
||||
diarization_task = self.diarize_file(audio_url)
|
||||
waveform_task = self.generate_waveform(session, audio_path)
|
||||
|
||||
results = await asyncio.gather(
|
||||
transcription_task, diarization_task, waveform_task, return_exceptions=True
|
||||
)
|
||||
|
||||
transcript_result = results[0]
|
||||
diarization_result = results[1]
|
||||
|
||||
# Handle errors - raise any exception that occurred
|
||||
self._handle_gather_exceptions(results, "parallel processing")
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
|
||||
# Phase 2: Assemble transcript with diarization
|
||||
self.logger.info(
|
||||
"Assembling transcript with diarization", transcript_id=self.transcript_id
|
||||
)
|
||||
processor = TranscriptDiarizationAssemblerProcessor()
|
||||
input_data = TranscriptDiarizationAssemblerInput(
|
||||
transcript=transcript_result, diarization=diarization_result or []
|
||||
)
|
||||
|
||||
# Store result for retrieval
|
||||
diarized_transcript: Transcript | None = None
|
||||
|
||||
async def capture_result(transcript):
|
||||
nonlocal diarized_transcript
|
||||
diarized_transcript = transcript
|
||||
|
||||
processor.on(capture_result)
|
||||
await processor.push(input_data)
|
||||
await processor.flush()
|
||||
|
||||
if not diarized_transcript:
|
||||
raise ValueError("No diarized transcript captured")
|
||||
|
||||
# Phase 3: Generate topics from diarized transcript
|
||||
self.logger.info("Generating topics", transcript_id=self.transcript_id)
|
||||
topics = await self.detect_topics(diarized_transcript, target_language)
|
||||
|
||||
# Phase 4: Generate title and summaries in parallel
|
||||
self.logger.info(
|
||||
"Generating title and summaries", transcript_id=self.transcript_id
|
||||
)
|
||||
results = await asyncio.gather(
|
||||
self.generate_title(topics),
|
||||
self.generate_summaries(session, topics),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
self._handle_gather_exceptions(results, "title and summary generation")
|
||||
|
||||
async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
|
||||
"""Transcribe complete file"""
|
||||
processor = FileTranscriptAutoProcessor()
|
||||
input_data = FileTranscriptInput(audio_url=audio_url, language=language)
|
||||
|
||||
# Store result for retrieval
|
||||
result: TranscriptType | None = None
|
||||
|
||||
async def capture_result(transcript):
|
||||
nonlocal result
|
||||
result = transcript
|
||||
|
||||
processor.on(capture_result)
|
||||
await processor.push(input_data)
|
||||
await processor.flush()
|
||||
|
||||
if not result:
|
||||
raise ValueError("No transcript captured")
|
||||
|
||||
return result
|
||||
|
||||
async def diarize_file(self, audio_url: str) -> list[DiarizationSegment] | None:
|
||||
"""Get diarization for file"""
|
||||
if not settings.DIARIZATION_BACKEND:
|
||||
self.logger.info("Diarization disabled")
|
||||
return None
|
||||
|
||||
processor = FileDiarizationAutoProcessor()
|
||||
input_data = FileDiarizationInput(audio_url=audio_url)
|
||||
|
||||
# Store result for retrieval
|
||||
result = None
|
||||
|
||||
async def capture_result(diarization_output):
|
||||
nonlocal result
|
||||
result = diarization_output.diarization
|
||||
|
||||
try:
|
||||
processor.on(capture_result)
|
||||
await processor.push(input_data)
|
||||
await processor.flush()
|
||||
return result
|
||||
except Exception as e:
|
||||
self.logger.error(f"Diarization failed: {e}")
|
||||
return None
|
||||
|
||||
async def generate_waveform(self, session: AsyncSession, audio_path: Path):
|
||||
"""Generate and save waveform"""
|
||||
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
|
||||
processor = AudioWaveformProcessor(
|
||||
audio_path=audio_path,
|
||||
waveform_path=transcript.audio_waveform_filename,
|
||||
on_waveform=self.on_waveform,
|
||||
)
|
||||
processor.set_pipeline(self.empty_pipeline)
|
||||
|
||||
await processor.flush()
|
||||
|
||||
async def detect_topics(
|
||||
self, transcript: TranscriptType, target_language: str
|
||||
) -> list[TitleSummary]:
|
||||
"""Detect topics from complete transcript"""
|
||||
chunk_size = 300
|
||||
topics: list[TitleSummary] = []
|
||||
|
||||
async def on_topic(topic: TitleSummary):
|
||||
topics.append(topic)
|
||||
return await self.on_topic(topic)
|
||||
|
||||
topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
|
||||
topic_detector.set_pipeline(self.empty_pipeline)
|
||||
|
||||
for i in range(0, len(transcript.words), chunk_size):
|
||||
chunk_words = transcript.words[i : i + chunk_size]
|
||||
if not chunk_words:
|
||||
continue
|
||||
|
||||
chunk_transcript = TranscriptType(
|
||||
words=chunk_words, translation=transcript.translation
|
||||
)
|
||||
|
||||
await topic_detector.push(chunk_transcript)
|
||||
|
||||
await topic_detector.flush()
|
||||
return topics
|
||||
|
||||
async def generate_title(self, topics: list[TitleSummary]):
|
||||
"""Generate title from topics"""
|
||||
if not topics:
|
||||
self.logger.warning("No topics for title generation")
|
||||
return
|
||||
|
||||
processor = TranscriptFinalTitleProcessor(callback=self.on_title)
|
||||
processor.set_pipeline(self.empty_pipeline)
|
||||
|
||||
for topic in topics:
|
||||
await processor.push(topic)
|
||||
|
||||
await processor.flush()
|
||||
|
||||
async def generate_summaries(self, session, topics: list[TitleSummary]):
|
||||
"""Generate long and short summaries from topics"""
|
||||
if not topics:
|
||||
self.logger.warning("No topics for summary generation")
|
||||
return
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
processor = TranscriptFinalSummaryProcessor(
|
||||
transcript=transcript,
|
||||
callback=self.on_long_summary,
|
||||
on_short_summary=self.on_short_summary,
|
||||
)
|
||||
processor.set_pipeline(self.empty_pipeline)
|
||||
|
||||
for topic in topics:
|
||||
await processor.push(topic)
|
||||
|
||||
await processor.flush()
|
||||
|
||||
|
||||
@taskiq_broker.task
|
||||
@with_session
|
||||
async def task_send_webhook_if_needed(session, *, transcript_id: str):
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
||||
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
||||
if room and room.webhook_url:
|
||||
logger.info(
|
||||
"Dispatching webhook",
|
||||
transcript_id=transcript_id,
|
||||
room_id=room.id,
|
||||
webhook_url=room.webhook_url,
|
||||
)
|
||||
await send_transcript_webhook_taskiq.kiq(
|
||||
transcript_id, room.id, event_id=uuid.uuid4().hex
|
||||
)
|
||||
|
||||
|
||||
@taskiq_broker.task
|
||||
@catch_exception
|
||||
@with_session
|
||||
async def task_pipeline_file_process(session: AsyncSession, *, transcript_id: str):
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
try:
|
||||
await pipeline.set_status(session, transcript_id, "processing")
|
||||
|
||||
audio_file = next(transcript.data_path.glob("upload.*"), None)
|
||||
if not audio_file:
|
||||
audio_file = next(transcript.data_path.glob("audio.*"), None)
|
||||
|
||||
if not audio_file:
|
||||
raise Exception("No audio file found to process")
|
||||
|
||||
await pipeline.process(session, audio_file)
|
||||
|
||||
except Exception:
|
||||
logger.error("Error while processing the file", exc_info=True)
|
||||
try:
|
||||
await pipeline.set_status(session, transcript_id, "error")
|
||||
except:
|
||||
logger.error(
|
||||
"Error setting status in task_pipeline_file_process during exception, ignoring it"
|
||||
)
|
||||
raise
|
||||
|
||||
await task_cleanup_consent_taskiq.kiq(transcript_id=transcript_id)
|
||||
await task_pipeline_post_to_zulip_taskiq.kiq(transcript_id=transcript_id)
|
||||
await task_send_webhook_if_needed.kiq(transcript_id=transcript_id)
|
||||
@@ -12,17 +12,16 @@ It is directly linked to our data model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Generic
|
||||
|
||||
import av
|
||||
import boto3
|
||||
from celery import chord, current_task, group, shared_task
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from structlog import BoundLogger as Logger
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db import get_session_context
|
||||
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
||||
from reflector.db.recordings import recordings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
@@ -32,6 +31,7 @@ from reflector.db.transcripts import (
|
||||
TranscriptFinalLongSummary,
|
||||
TranscriptFinalShortSummary,
|
||||
TranscriptFinalTitle,
|
||||
TranscriptStatus,
|
||||
TranscriptText,
|
||||
TranscriptTopic,
|
||||
TranscriptWaveform,
|
||||
@@ -40,8 +40,9 @@ from reflector.db.transcripts import (
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.runner import PipelineMessage, PipelineRunner
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioChunkerAutoProcessor,
|
||||
AudioDiarizationAutoProcessor,
|
||||
AudioDownscaleProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
@@ -60,6 +61,8 @@ from reflector.processors.types import (
|
||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.worker.app import taskiq_broker
|
||||
from reflector.worker.session_decorator import with_session_and_transcript
|
||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||
from reflector.zulip import (
|
||||
get_zulip_message,
|
||||
@@ -68,29 +71,6 @@ from reflector.zulip import (
|
||||
)
|
||||
|
||||
|
||||
def asynctask(f):
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
async def run_with_db():
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
coro = run_with_db()
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
return loop.run_until_complete(coro)
|
||||
return asyncio.run(coro)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def broadcast_to_sockets(func):
|
||||
"""
|
||||
Decorator to broadcast transcript event to websockets
|
||||
@@ -109,59 +89,27 @@ def broadcast_to_sockets(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_transcript(func):
|
||||
"""
|
||||
Decorator to fetch the transcript from the database from the first argument
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(**kwargs):
|
||||
transcript_id = kwargs.pop("transcript_id")
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
|
||||
if not transcript:
|
||||
raise Exception("Transcript {transcript_id} not found")
|
||||
|
||||
# Enhanced logger with Celery task context
|
||||
tlogger = logger.bind(transcript_id=transcript.id)
|
||||
if current_task:
|
||||
tlogger = tlogger.bind(
|
||||
task_id=current_task.request.id,
|
||||
task_name=current_task.name,
|
||||
worker_hostname=current_task.request.hostname,
|
||||
task_retries=current_task.request.retries,
|
||||
transcript_id=transcript_id,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await func(transcript=transcript, logger=tlogger, **kwargs)
|
||||
return result
|
||||
except Exception as exc:
|
||||
tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class StrValue(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
|
||||
transcript_id: str
|
||||
ws_room_id: str | None = None
|
||||
ws_manager: WebsocketManager | None = None
|
||||
|
||||
def prepare(self):
|
||||
# prepare websocket
|
||||
def __init__(self, transcript_id: str):
|
||||
super().__init__()
|
||||
self._lock = asyncio.Lock()
|
||||
self.transcript_id = transcript_id
|
||||
self.ws_room_id = f"ts:{self.transcript_id}"
|
||||
self.ws_manager = get_ws_manager()
|
||||
self._ws_manager = None
|
||||
|
||||
async def get_transcript(self) -> Transcript:
|
||||
@property
|
||||
def ws_manager(self) -> WebsocketManager:
|
||||
if self._ws_manager is None:
|
||||
self._ws_manager = get_ws_manager()
|
||||
return self._ws_manager
|
||||
|
||||
async def get_transcript(self, session: AsyncSession) -> Transcript:
|
||||
# fetch the transcript
|
||||
result = await transcripts_controller.get_by_id(
|
||||
transcript_id=self.transcript_id
|
||||
)
|
||||
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
if not result:
|
||||
raise Exception("Transcript not found")
|
||||
return result
|
||||
@@ -184,24 +132,31 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
]
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
async def lock_transaction(self):
|
||||
# This lock is to prevent multiple processor starting adding
|
||||
# into event array at the same time
|
||||
async with self._lock:
|
||||
async with transcripts_controller.transaction():
|
||||
yield
|
||||
yield
|
||||
|
||||
@asynccontextmanager
|
||||
async def locked_session(self):
|
||||
async with self.lock_transaction():
|
||||
async with get_session_context() as session:
|
||||
yield session
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_status(self, status):
|
||||
# if it's the first part, update the status of the transcript
|
||||
# but do not set the ended status yet.
|
||||
if isinstance(self, PipelineMainLive):
|
||||
status_mapping = {
|
||||
status_mapping: dict[str, TranscriptStatus] = {
|
||||
"started": "recording",
|
||||
"push": "recording",
|
||||
"flush": "processing",
|
||||
"error": "error",
|
||||
}
|
||||
elif isinstance(self, PipelineMainFinalSummaries):
|
||||
status_mapping = {
|
||||
status_mapping: dict[str, TranscriptStatus] = {
|
||||
"push": "processing",
|
||||
"flush": "processing",
|
||||
"error": "error",
|
||||
@@ -217,28 +172,18 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
return
|
||||
|
||||
# when the status of the pipeline changes, update the transcript
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
if status == transcript.status:
|
||||
return
|
||||
resp = await transcripts_controller.append_event(
|
||||
transcript=transcript,
|
||||
event="STATUS",
|
||||
data=StrValue(value=status),
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
return resp
|
||||
async with self._lock:
|
||||
async with get_session_context() as session:
|
||||
return await transcripts_controller.set_status(
|
||||
session, self.transcript_id, status
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_transcript(self, data):
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.locked_session() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="TRANSCRIPT",
|
||||
data=TranscriptText(text=data.text, translation=data.translation),
|
||||
@@ -255,10 +200,11 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
)
|
||||
if isinstance(data, TitleSummaryWithIdProcessorType):
|
||||
topic.id = data.id
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
await transcripts_controller.upsert_topic(transcript, topic)
|
||||
async with self.locked_session() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.upsert_topic(session, transcript, topic)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="TOPIC",
|
||||
data=topic,
|
||||
@@ -267,16 +213,18 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
@broadcast_to_sockets
|
||||
async def on_title(self, data):
|
||||
final_title = TranscriptFinalTitle(title=data.title)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.locked_session() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
if not transcript.title:
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"title": final_title.title,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_TITLE",
|
||||
data=final_title,
|
||||
@@ -285,15 +233,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
@broadcast_to_sockets
|
||||
async def on_long_summary(self, data):
|
||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.locked_session() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"long_summary": final_long_summary.long_summary,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_LONG_SUMMARY",
|
||||
data=final_long_summary,
|
||||
@@ -304,15 +254,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
final_short_summary = TranscriptFinalShortSummary(
|
||||
short_summary=data.short_summary
|
||||
)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.locked_session() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"short_summary": final_short_summary.short_summary,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_SHORT_SUMMARY",
|
||||
data=final_short_summary,
|
||||
@@ -320,29 +272,30 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_duration(self, data):
|
||||
async with self.transaction():
|
||||
async with self.locked_session() as session:
|
||||
duration = TranscriptDuration(duration=data)
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"duration": duration.duration,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript, event="DURATION", data=duration
|
||||
session, transcript=transcript, event="DURATION", data=duration
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_waveform(self, data):
|
||||
async with self.transaction():
|
||||
async with self.locked_session() as session:
|
||||
waveform = TranscriptWaveform(waveform=data)
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
transcript = await self.get_transcript(session)
|
||||
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript, event="WAVEFORM", data=waveform
|
||||
session, transcript=transcript, event="WAVEFORM", data=waveform
|
||||
)
|
||||
|
||||
|
||||
@@ -355,15 +308,16 @@ class PipelineMainLive(PipelineMainBase):
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
transcript = await self.get_transcript()
|
||||
async with get_session_context() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
|
||||
processors = [
|
||||
AudioFileWriterProcessor(
|
||||
path=transcript.audio_wav_filename,
|
||||
on_duration=self.on_duration,
|
||||
),
|
||||
AudioChunkerProcessor(),
|
||||
AudioDownscaleProcessor(),
|
||||
AudioChunkerAutoProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
@@ -376,6 +330,7 @@ class PipelineMainLive(PipelineMainBase):
|
||||
pipeline.set_pref("audio:target_language", transcript.target_language)
|
||||
pipeline.logger.bind(transcript_id=transcript.id)
|
||||
pipeline.logger.info("Pipeline main live created")
|
||||
pipeline.describe()
|
||||
|
||||
return pipeline
|
||||
|
||||
@@ -394,7 +349,6 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
pipeline = Pipeline(
|
||||
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||
)
|
||||
@@ -403,7 +357,8 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
||||
# now let's start the pipeline by pushing information to the
|
||||
# first processor diarization processor
|
||||
# XXX translation is lost when converting our data model to the processor model
|
||||
transcript = await self.get_transcript()
|
||||
async with get_session_context() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
|
||||
# diarization works only if the file is uploaded to an external storage
|
||||
if transcript.audio_location == "local":
|
||||
@@ -435,10 +390,9 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
|
||||
raise NotImplementedError
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
self.prepare()
|
||||
|
||||
# get transcript
|
||||
self._transcript = transcript = await self.get_transcript()
|
||||
async with get_session_context() as session:
|
||||
self._transcript = transcript = await self.get_transcript(session)
|
||||
|
||||
# create pipeline
|
||||
processors = self.get_processors()
|
||||
@@ -498,8 +452,7 @@ class PipelineMainWaveform(PipelineMainFromTopics):
|
||||
]
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger):
|
||||
# for future changes: note that there's also a consent process happens, beforehand and users may not consent with keeping files. currently, we delete regardless, so it's no need for that
|
||||
logger.info("Starting remove upload")
|
||||
uploads = transcript.data_path.glob("upload.*")
|
||||
@@ -508,16 +461,14 @@ async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
|
||||
logger.info("Remove upload done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_waveform(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_waveform(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting waveform")
|
||||
runner = PipelineMainWaveform(transcript_id=transcript.id)
|
||||
await runner.run()
|
||||
logger.info("Waveform done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting convert to mp3")
|
||||
|
||||
# If the audio wav is not available, just skip
|
||||
@@ -543,8 +494,7 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
|
||||
logger.info("Convert to mp3 done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
|
||||
if not settings.TRANSCRIPT_STORAGE_BACKEND:
|
||||
logger.info("No storage backend configured, skipping mp3 upload")
|
||||
return
|
||||
@@ -562,49 +512,49 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
||||
return
|
||||
|
||||
# Upload to external storage and delete the file
|
||||
await transcripts_controller.move_mp3_to_storage(transcript)
|
||||
await transcripts_controller.move_mp3_to_storage(session, transcript)
|
||||
|
||||
logger.info("Upload mp3 done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_diarization(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_diarization(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting diarization")
|
||||
runner = PipelineMainDiarization(transcript_id=transcript.id)
|
||||
await runner.run()
|
||||
logger.info("Diarization done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_title(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_title(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting title")
|
||||
runner = PipelineMainTitle(transcript_id=transcript.id)
|
||||
await runner.run()
|
||||
logger.info("Title done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_summaries(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_summaries(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting summaries")
|
||||
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
|
||||
await runner.run()
|
||||
logger.info("Summaries done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
async def cleanup_consent(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting consent cleanup")
|
||||
|
||||
consent_denied = False
|
||||
recording = None
|
||||
try:
|
||||
if transcript.recording_id:
|
||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||
recording = await recordings_controller.get_by_id(
|
||||
session, transcript.recording_id
|
||||
)
|
||||
if recording and recording.meeting_id:
|
||||
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
||||
meeting = await meetings_controller.get_by_id(
|
||||
session, recording.meeting_id
|
||||
)
|
||||
if meeting:
|
||||
consent_denied = await meeting_consent_controller.has_any_denial(
|
||||
meeting.id
|
||||
session, meeting.id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
|
||||
@@ -633,7 +583,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
|
||||
|
||||
# non-transactional, files marked for deletion not actually deleted is possible
|
||||
await transcripts_controller.update(transcript, {"audio_deleted": True})
|
||||
await transcripts_controller.update(session, transcript, {"audio_deleted": True})
|
||||
# 2. Delete processed audio from transcript storage S3 bucket
|
||||
if transcript.audio_location == "storage":
|
||||
storage = get_transcripts_storage()
|
||||
@@ -657,15 +607,14 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
logger.info("Consent cleanup done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting post to zulip")
|
||||
|
||||
if not transcript.recording_id:
|
||||
logger.info("Transcript has no recording")
|
||||
return
|
||||
|
||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||
recording = await recordings_controller.get_by_id(session, transcript.recording_id)
|
||||
if not recording:
|
||||
logger.info("Recording not found")
|
||||
return
|
||||
@@ -674,12 +623,12 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
logger.info("Recording has no meeting")
|
||||
return
|
||||
|
||||
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
||||
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
|
||||
if not meeting:
|
||||
logger.info("No meeting found for this recording")
|
||||
return
|
||||
|
||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||
if not room:
|
||||
logger.error(f"Missing room for a meeting {meeting.id}")
|
||||
return
|
||||
@@ -705,7 +654,7 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
room.zulip_stream, room.zulip_topic, message
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
transcript, {"zulip_message_id": response["id"]}
|
||||
session, transcript, {"zulip_message_id": response["id"]}
|
||||
)
|
||||
|
||||
logger.info("Posted to zulip")
|
||||
@@ -716,92 +665,120 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
# ===================================================================
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_remove_upload(*, transcript_id: str):
|
||||
await pipeline_remove_upload(transcript_id=transcript_id)
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_remove_upload(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_remove_upload(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_waveform(*, transcript_id: str):
|
||||
await pipeline_waveform(transcript_id=transcript_id)
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_waveform(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_waveform(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
|
||||
await pipeline_convert_to_mp3(transcript_id=transcript_id)
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_convert_to_mp3(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_convert_to_mp3(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_upload_mp3(*, transcript_id: str):
|
||||
await pipeline_upload_mp3(transcript_id=transcript_id)
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_upload_mp3(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_upload_mp3(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_diarization(*, transcript_id: str):
|
||||
await pipeline_diarization(transcript_id=transcript_id)
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_diarization(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_diarization(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_title(*, transcript_id: str):
|
||||
await pipeline_title(transcript_id=transcript_id)
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_title(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_title(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_final_summaries(*, transcript_id: str):
|
||||
await pipeline_summaries(transcript_id=transcript_id)
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_final_summaries(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_summaries(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_cleanup_consent(*, transcript_id: str):
|
||||
await cleanup_consent(transcript_id=transcript_id)
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_cleanup_consent(session, *, transcript: Transcript, logger: Logger):
|
||||
await cleanup_consent(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_post_to_zulip(*, transcript_id: str):
|
||||
await pipeline_post_to_zulip(transcript_id=transcript_id)
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_post_to_zulip(
|
||||
session, *, transcript: Transcript, logger: Logger
|
||||
):
|
||||
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
def pipeline_post(*, transcript_id: str):
|
||||
"""
|
||||
Run the post pipeline
|
||||
"""
|
||||
chain_mp3_and_diarize = (
|
||||
task_pipeline_waveform.si(transcript_id=transcript_id)
|
||||
| task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
|
||||
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
|
||||
| task_pipeline_remove_upload.si(transcript_id=transcript_id)
|
||||
| task_pipeline_diarization.si(transcript_id=transcript_id)
|
||||
| task_cleanup_consent.si(transcript_id=transcript_id)
|
||||
)
|
||||
chain_title_preview = task_pipeline_title.si(transcript_id=transcript_id)
|
||||
chain_final_summaries = task_pipeline_final_summaries.si(
|
||||
transcript_id=transcript_id
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_cleanup_consent_taskiq(
|
||||
session, *, transcript: Transcript, logger: Logger
|
||||
):
|
||||
await cleanup_consent(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_post_to_zulip_taskiq(
|
||||
session, *, transcript: Transcript, logger: Logger
|
||||
):
|
||||
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
async def pipeline_post(*, transcript_id: str):
|
||||
await task_pipeline_post_sequential.kiq(transcript_id=transcript_id)
|
||||
|
||||
|
||||
@taskiq_broker.task
|
||||
async def task_pipeline_post_sequential(*, transcript_id: str):
|
||||
await task_pipeline_waveform.kiq(transcript_id=transcript_id)
|
||||
await task_pipeline_convert_to_mp3.kiq(transcript_id=transcript_id)
|
||||
await task_pipeline_upload_mp3.kiq(transcript_id=transcript_id)
|
||||
await task_pipeline_remove_upload.kiq(transcript_id=transcript_id)
|
||||
await task_pipeline_diarization.kiq(transcript_id=transcript_id)
|
||||
await task_cleanup_consent.kiq(transcript_id=transcript_id)
|
||||
|
||||
await asyncio.gather(
|
||||
task_pipeline_title.kiq(transcript_id=transcript_id),
|
||||
task_pipeline_final_summaries.kiq(transcript_id=transcript_id),
|
||||
)
|
||||
|
||||
chain = chord(
|
||||
group(chain_mp3_and_diarize, chain_title_preview),
|
||||
chain_final_summaries,
|
||||
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
|
||||
|
||||
chain.delay()
|
||||
await task_pipeline_post_to_zulip.kiq(transcript_id=transcript_id)
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_process(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_process(session, transcript: Transcript, logger: Logger):
|
||||
try:
|
||||
if transcript.audio_location == "storage":
|
||||
await transcripts_controller.download_mp3_from_storage(transcript)
|
||||
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"topics": [],
|
||||
@@ -839,6 +816,7 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
|
||||
except Exception as exc:
|
||||
logger.error("Pipeline error", exc_info=exc)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"status": "error",
|
||||
@@ -849,7 +827,9 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
|
||||
logger.info("Pipeline ended")
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_process(*, transcript_id: str):
|
||||
return await pipeline_process(transcript_id=transcript_id)
|
||||
@taskiq_broker.task
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_process(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
return await pipeline_process(session, transcript=transcript, logger=logger)
|
||||
|
||||
@@ -18,22 +18,14 @@ During its lifecycle, it will emit the following status:
|
||||
import asyncio
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import Pipeline
|
||||
|
||||
PipelineMessage = TypeVar("PipelineMessage")
|
||||
|
||||
|
||||
class PipelineRunner(BaseModel, Generic[PipelineMessage]):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
status: str = "idle"
|
||||
pipeline: Pipeline | None = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
class PipelineRunner(Generic[PipelineMessage]):
|
||||
def __init__(self):
|
||||
self._task = None
|
||||
self._q_cmd = asyncio.Queue(maxsize=4096)
|
||||
self._ev_done = asyncio.Event()
|
||||
@@ -42,6 +34,8 @@ class PipelineRunner(BaseModel, Generic[PipelineMessage]):
|
||||
runner=id(self),
|
||||
runner_cls=self.__class__.__name__,
|
||||
)
|
||||
self.status = "idle"
|
||||
self.pipeline: Pipeline | None = None
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
||||
from .audio_chunker_auto import AudioChunkerAutoProcessor # noqa: F401
|
||||
from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
|
||||
from .audio_downscale import AudioDownscaleProcessor # noqa: F401
|
||||
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||
@@ -11,6 +13,13 @@ from .base import ( # noqa: F401
|
||||
Processor,
|
||||
ThreadedProcessor,
|
||||
)
|
||||
from .file_diarization import FileDiarizationProcessor # noqa: F401
|
||||
from .file_diarization_auto import FileDiarizationAutoProcessor # noqa: F401
|
||||
from .file_transcript import FileTranscriptProcessor # noqa: F401
|
||||
from .file_transcript_auto import FileTranscriptAutoProcessor # noqa: F401
|
||||
from .transcript_diarization_assembler import (
|
||||
TranscriptDiarizationAssemblerProcessor, # noqa: F401
|
||||
)
|
||||
from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401
|
||||
from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401
|
||||
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
|
||||
|
||||
@@ -1,28 +1,78 @@
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
|
||||
|
||||
class AudioChunkerProcessor(Processor):
|
||||
"""
|
||||
Assemble audio frames into chunks
|
||||
Base class for assembling audio frames into chunks
|
||||
"""
|
||||
|
||||
INPUT_TYPE = av.AudioFrame
|
||||
OUTPUT_TYPE = list[av.AudioFrame]
|
||||
|
||||
def __init__(self, max_frames=256):
|
||||
super().__init__()
|
||||
m_chunk = Histogram(
|
||||
"audio_chunker",
|
||||
"Time spent in AudioChunker.chunk",
|
||||
["backend"],
|
||||
)
|
||||
m_chunk_call = Counter(
|
||||
"audio_chunker_call",
|
||||
"Number of calls to AudioChunker.chunk",
|
||||
["backend"],
|
||||
)
|
||||
m_chunk_success = Counter(
|
||||
"audio_chunker_success",
|
||||
"Number of successful calls to AudioChunker.chunk",
|
||||
["backend"],
|
||||
)
|
||||
m_chunk_failure = Counter(
|
||||
"audio_chunker_failure",
|
||||
"Number of failed calls to AudioChunker.chunk",
|
||||
["backend"],
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
name = self.__class__.__name__
|
||||
self.m_chunk = self.m_chunk.labels(name)
|
||||
self.m_chunk_call = self.m_chunk_call.labels(name)
|
||||
self.m_chunk_success = self.m_chunk_success.labels(name)
|
||||
self.m_chunk_failure = self.m_chunk_failure.labels(name)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.frames: list[av.AudioFrame] = []
|
||||
self.max_frames = max_frames
|
||||
|
||||
async def _push(self, data: av.AudioFrame):
|
||||
self.frames.append(data)
|
||||
if len(self.frames) >= self.max_frames:
|
||||
await self.flush()
|
||||
"""Process incoming audio frame"""
|
||||
# Validate audio format on first frame
|
||||
if len(self.frames) == 0:
|
||||
if data.sample_rate != 16000 or len(data.layout.channels) != 1:
|
||||
raise ValueError(
|
||||
f"AudioChunkerProcessor expects 16kHz mono audio, got {data.sample_rate}Hz "
|
||||
f"with {len(data.layout.channels)} channel(s). "
|
||||
f"Use AudioDownscaleProcessor before this processor."
|
||||
)
|
||||
|
||||
try:
|
||||
self.m_chunk_call.inc()
|
||||
with self.m_chunk.time():
|
||||
result = await self._chunk(data)
|
||||
self.m_chunk_success.inc()
|
||||
if result:
|
||||
await self.emit(result)
|
||||
except Exception:
|
||||
self.m_chunk_failure.inc()
|
||||
raise
|
||||
|
||||
async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
|
||||
"""
|
||||
Process audio frame and return chunk when ready.
|
||||
Subclasses should implement their chunking logic here.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _flush(self):
|
||||
frames = self.frames[:]
|
||||
self.frames = []
|
||||
if frames:
|
||||
await self.emit(frames)
|
||||
"""Flush any remaining frames when processing ends"""
|
||||
raise NotImplementedError
|
||||
|
||||
32
server/reflector/processors/audio_chunker_auto.py
Normal file
32
server/reflector/processors/audio_chunker_auto.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_chunker import AudioChunkerProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioChunkerAutoProcessor(AudioChunkerProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.AUDIO_CHUNKER_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.audio_chunker_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `AUDIO_CHUNKER_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "AUDIO_CHUNKER_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
34
server/reflector/processors/audio_chunker_frames.py
Normal file
34
server/reflector/processors/audio_chunker_frames.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
|
||||
from reflector.processors.audio_chunker import AudioChunkerProcessor
|
||||
from reflector.processors.audio_chunker_auto import AudioChunkerAutoProcessor
|
||||
|
||||
|
||||
class AudioChunkerFramesProcessor(AudioChunkerProcessor):
|
||||
"""
|
||||
Simple frame-based audio chunker that emits chunks after a fixed number of frames
|
||||
"""
|
||||
|
||||
def __init__(self, max_frames=256, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.max_frames = max_frames
|
||||
|
||||
async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
|
||||
self.frames.append(data)
|
||||
if len(self.frames) >= self.max_frames:
|
||||
frames_to_emit = self.frames[:]
|
||||
self.frames = []
|
||||
return frames_to_emit
|
||||
|
||||
return None
|
||||
|
||||
async def _flush(self):
|
||||
frames = self.frames[:]
|
||||
self.frames = []
|
||||
if frames:
|
||||
await self.emit(frames)
|
||||
|
||||
|
||||
AudioChunkerAutoProcessor.register("frames", AudioChunkerFramesProcessor)
|
||||
298
server/reflector/processors/audio_chunker_silero.py
Normal file
298
server/reflector/processors/audio_chunker_silero.py
Normal file
@@ -0,0 +1,298 @@
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from silero_vad import VADIterator, load_silero_vad
|
||||
|
||||
from reflector.processors.audio_chunker import AudioChunkerProcessor
|
||||
from reflector.processors.audio_chunker_auto import AudioChunkerAutoProcessor
|
||||
|
||||
|
||||
class AudioChunkerSileroProcessor(AudioChunkerProcessor):
|
||||
"""
|
||||
Assemble audio frames into chunks with VAD-based speech detection using Silero VAD
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_frames=256,
|
||||
max_frames=1024,
|
||||
use_onnx=True,
|
||||
min_frames=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.block_frames = block_frames
|
||||
self.max_frames = max_frames
|
||||
self.min_frames = min_frames
|
||||
|
||||
# Initialize Silero VAD
|
||||
self._init_vad(use_onnx)
|
||||
|
||||
def _init_vad(self, use_onnx=False):
|
||||
"""Initialize Silero VAD model"""
|
||||
try:
|
||||
torch.set_num_threads(1)
|
||||
self.vad_model = load_silero_vad(onnx=use_onnx)
|
||||
self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
|
||||
self.logger.info("Silero VAD initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize Silero VAD: {e}")
|
||||
self.vad_model = None
|
||||
self.vad_iterator = None
|
||||
|
||||
async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
|
||||
"""Process audio frame and return chunk when ready"""
|
||||
self.frames.append(data)
|
||||
|
||||
# Check for speech segments every 32 frames (~1 second)
|
||||
if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
|
||||
return await self._process_block()
|
||||
|
||||
# Safety fallback - emit if we hit max frames
|
||||
elif len(self.frames) >= self.max_frames:
|
||||
self.logger.warning(
|
||||
f"AudioChunkerSileroProcessor: Reached max frames ({self.max_frames}), "
|
||||
f"emitting first {self.max_frames // 2} frames"
|
||||
)
|
||||
frames_to_emit = self.frames[: self.max_frames // 2]
|
||||
self.frames = self.frames[self.max_frames // 2 :]
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
return frames_to_emit
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring fallback segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def _process_block(self) -> Optional[list[av.AudioFrame]]:
|
||||
# Need at least 32 frames for VAD detection (~1 second)
|
||||
if len(self.frames) < 32 or self.vad_iterator is None:
|
||||
return None
|
||||
|
||||
# Processing block with current buffer size
|
||||
print(f"Processing block: {len(self.frames)} frames in buffer")
|
||||
|
||||
try:
|
||||
# Convert frames to numpy array for VAD
|
||||
audio_array = self._frames_to_numpy(self.frames)
|
||||
|
||||
if audio_array is None:
|
||||
# Fallback: emit all frames if conversion failed
|
||||
frames_to_emit = self.frames[:]
|
||||
self.frames = []
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
return frames_to_emit
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
return None
|
||||
|
||||
# Find complete speech segments in the buffer
|
||||
speech_end_frame = self._find_speech_segment_end(audio_array)
|
||||
|
||||
if speech_end_frame is None or speech_end_frame <= 0:
|
||||
# No speech found but buffer is getting large
|
||||
if len(self.frames) > 512:
|
||||
# Check if it's all silence and can be discarded
|
||||
# No speech segment found, buffer at {len(self.frames)} frames
|
||||
|
||||
# Could emit silence or discard old frames here
|
||||
# For now, keep first 256 frames and discard older silence
|
||||
if len(self.frames) > 768:
|
||||
self.logger.debug(
|
||||
f"Discarding {len(self.frames) - 256} old frames (likely silence)"
|
||||
)
|
||||
self.frames = self.frames[-256:]
|
||||
return None
|
||||
|
||||
# Calculate segment timing information
|
||||
frames_to_emit = self.frames[:speech_end_frame]
|
||||
|
||||
# Get timing from av.AudioFrame
|
||||
if frames_to_emit:
|
||||
first_frame = frames_to_emit[0]
|
||||
last_frame = frames_to_emit[-1]
|
||||
sample_rate = first_frame.sample_rate
|
||||
|
||||
# Calculate duration
|
||||
total_samples = sum(f.samples for f in frames_to_emit)
|
||||
duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0
|
||||
|
||||
# Get timestamps if available
|
||||
start_time = (
|
||||
first_frame.pts * first_frame.time_base if first_frame.pts else 0
|
||||
)
|
||||
end_time = (
|
||||
last_frame.pts * last_frame.time_base if last_frame.pts else 0
|
||||
)
|
||||
|
||||
# Convert to HH:MM:SS format for logging
|
||||
def format_time(seconds):
|
||||
if not seconds:
|
||||
return "00:00:00"
|
||||
total_seconds = int(float(seconds))
|
||||
hours = total_seconds // 3600
|
||||
minutes = (total_seconds % 3600) // 60
|
||||
secs = total_seconds % 60
|
||||
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
||||
|
||||
start_formatted = format_time(start_time)
|
||||
end_formatted = format_time(end_time)
|
||||
|
||||
# Keep remaining frames for next processing
|
||||
remaining_after = len(self.frames) - speech_end_frame
|
||||
|
||||
# Single structured log line
|
||||
self.logger.info(
|
||||
"Speech segment found",
|
||||
start=start_formatted,
|
||||
end=end_formatted,
|
||||
frames=speech_end_frame,
|
||||
duration=round(duration_seconds, 2),
|
||||
buffer_before=len(self.frames),
|
||||
remaining=remaining_after,
|
||||
)
|
||||
|
||||
# Keep remaining frames for next processing
|
||||
self.frames = self.frames[speech_end_frame:]
|
||||
|
||||
# Filter out segments with too few frames
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
return frames_to_emit
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in VAD processing: {e}")
|
||||
# Fallback to simple chunking
|
||||
if len(self.frames) >= self.block_frames:
|
||||
frames_to_emit = self.frames[: self.block_frames]
|
||||
self.frames = self.frames[self.block_frames :]
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
return frames_to_emit
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
|
||||
"""Convert av.AudioFrame list to numpy array for VAD processing"""
|
||||
if not frames:
|
||||
return None
|
||||
|
||||
try:
|
||||
audio_data = []
|
||||
for frame in frames:
|
||||
frame_array = frame.to_ndarray()
|
||||
|
||||
if len(frame_array.shape) == 2:
|
||||
frame_array = frame_array.flatten()
|
||||
|
||||
audio_data.append(frame_array)
|
||||
|
||||
if not audio_data:
|
||||
return None
|
||||
|
||||
combined_audio = np.concatenate(audio_data)
|
||||
|
||||
# Ensure float32 format
|
||||
if combined_audio.dtype == np.int16:
|
||||
# Normalize int16 audio to float32 in range [-1.0, 1.0]
|
||||
combined_audio = combined_audio.astype(np.float32) / 32768.0
|
||||
elif combined_audio.dtype != np.float32:
|
||||
combined_audio = combined_audio.astype(np.float32)
|
||||
|
||||
return combined_audio
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error converting frames to numpy: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
|
||||
"""Find complete speech segments and return frame index at segment end"""
|
||||
if self.vad_iterator is None or len(audio_array) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Process audio in 512-sample windows for VAD
|
||||
window_size = 512
|
||||
min_silence_windows = 3 # Require 3 windows of silence after speech
|
||||
|
||||
# Track speech state
|
||||
in_speech = False
|
||||
speech_start = None
|
||||
speech_end = None
|
||||
silence_count = 0
|
||||
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(chunk, (0, window_size - len(chunk)))
|
||||
|
||||
# Detect if this window has speech
|
||||
speech_dict = self.vad_iterator(chunk, return_seconds=True)
|
||||
|
||||
# VADIterator returns dict with 'start' and 'end' when speech segments are detected
|
||||
if speech_dict:
|
||||
if not in_speech:
|
||||
# Speech started
|
||||
speech_start = i
|
||||
in_speech = True
|
||||
# Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
|
||||
silence_count = 0 # Reset silence counter
|
||||
continue
|
||||
|
||||
if not in_speech:
|
||||
continue
|
||||
|
||||
# We're in speech but found silence
|
||||
silence_count += 1
|
||||
if silence_count < min_silence_windows:
|
||||
continue
|
||||
|
||||
# Found end of speech segment
|
||||
speech_end = i - (min_silence_windows - 1) * window_size
|
||||
# Debug: print(f"Speech END at sample {speech_end}")
|
||||
|
||||
# Convert sample position to frame index
|
||||
samples_per_frame = self.frames[0].samples if self.frames else 1024
|
||||
frame_index = speech_end // samples_per_frame
|
||||
|
||||
# Ensure we don't exceed buffer
|
||||
frame_index = min(frame_index, len(self.frames))
|
||||
return frame_index
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error finding speech segment: {e}")
|
||||
return None
|
||||
|
||||
async def _flush(self):
|
||||
frames = self.frames[:]
|
||||
self.frames = []
|
||||
if frames:
|
||||
if len(frames) >= self.min_frames:
|
||||
await self.emit(frames)
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring flush segment with {len(frames)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
|
||||
AudioChunkerAutoProcessor.register("silero", AudioChunkerSileroProcessor)
|
||||
@@ -1,6 +1,7 @@
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
DiarizationSegment,
|
||||
TitleSummary,
|
||||
Word,
|
||||
)
|
||||
@@ -37,18 +38,21 @@ class AudioDiarizationProcessor(Processor):
|
||||
async def _diarize(self, data: AudioDiarizationInput):
|
||||
raise NotImplementedError
|
||||
|
||||
def assign_speaker(self, words: list[Word], diarization: list[dict]):
|
||||
self._diarization_remove_overlap(diarization)
|
||||
self._diarization_remove_segment_without_words(words, diarization)
|
||||
self._diarization_merge_same_speaker(words, diarization)
|
||||
self._diarization_assign_speaker(words, diarization)
|
||||
@classmethod
|
||||
def assign_speaker(cls, words: list[Word], diarization: list[DiarizationSegment]):
|
||||
cls._diarization_remove_overlap(diarization)
|
||||
cls._diarization_remove_segment_without_words(words, diarization)
|
||||
cls._diarization_merge_same_speaker(diarization)
|
||||
cls._diarization_assign_speaker(words, diarization)
|
||||
|
||||
def iter_words_from_topics(self, topics: TitleSummary):
|
||||
@staticmethod
|
||||
def iter_words_from_topics(topics: list[TitleSummary]):
|
||||
for topic in topics:
|
||||
for word in topic.transcript.words:
|
||||
yield word
|
||||
|
||||
def is_word_continuation(self, word_prev, word):
|
||||
@staticmethod
|
||||
def is_word_continuation(word_prev, word):
|
||||
"""
|
||||
Return True if the word is a continuation of the previous word
|
||||
by checking if the previous word is ending with a punctuation
|
||||
@@ -61,7 +65,8 @@ class AudioDiarizationProcessor(Processor):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _diarization_remove_overlap(self, diarization: list[dict]):
|
||||
@staticmethod
|
||||
def _diarization_remove_overlap(diarization: list[DiarizationSegment]):
|
||||
"""
|
||||
Remove overlap in diarization results
|
||||
|
||||
@@ -86,8 +91,9 @@ class AudioDiarizationProcessor(Processor):
|
||||
else:
|
||||
diarization_idx += 1
|
||||
|
||||
@staticmethod
|
||||
def _diarization_remove_segment_without_words(
|
||||
self, words: list[Word], diarization: list[dict]
|
||||
words: list[Word], diarization: list[DiarizationSegment]
|
||||
):
|
||||
"""
|
||||
Remove diarization segments without words
|
||||
@@ -116,9 +122,8 @@ class AudioDiarizationProcessor(Processor):
|
||||
else:
|
||||
diarization_idx += 1
|
||||
|
||||
def _diarization_merge_same_speaker(
|
||||
self, words: list[Word], diarization: list[dict]
|
||||
):
|
||||
@staticmethod
|
||||
def _diarization_merge_same_speaker(diarization: list[DiarizationSegment]):
|
||||
"""
|
||||
Merge diarization contigous segments with the same speaker
|
||||
|
||||
@@ -135,7 +140,10 @@ class AudioDiarizationProcessor(Processor):
|
||||
else:
|
||||
diarization_idx += 1
|
||||
|
||||
def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
|
||||
@classmethod
|
||||
def _diarization_assign_speaker(
|
||||
cls, words: list[Word], diarization: list[DiarizationSegment]
|
||||
):
|
||||
"""
|
||||
Assign speaker to words based on diarization
|
||||
|
||||
@@ -143,7 +151,7 @@ class AudioDiarizationProcessor(Processor):
|
||||
"""
|
||||
|
||||
word_idx = 0
|
||||
last_speaker = None
|
||||
last_speaker = 0
|
||||
for d in diarization:
|
||||
start = d["start"]
|
||||
end = d["end"]
|
||||
@@ -158,7 +166,7 @@ class AudioDiarizationProcessor(Processor):
|
||||
# If it's a continuation, assign with the last speaker
|
||||
is_continuation = False
|
||||
if word_idx > 0 and word_idx < len(words) - 1:
|
||||
is_continuation = self.is_word_continuation(
|
||||
is_continuation = cls.is_word_continuation(
|
||||
*words[word_idx - 1 : word_idx + 1]
|
||||
)
|
||||
if is_continuation:
|
||||
|
||||
74
server/reflector/processors/audio_diarization_pyannote.py
Normal file
74
server/reflector/processors/audio_diarization_pyannote.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput, DiarizationSegment
|
||||
|
||||
|
||||
class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
|
||||
"""Local diarization processor using pyannote.audio library"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "pyannote/speaker-diarization-3.1",
|
||||
pyannote_auth_token: str | None = None,
|
||||
device: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.model_name = model_name
|
||||
self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
|
||||
self.device = device
|
||||
|
||||
if device is None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.logger.info(f"Loading pyannote diarization model: {self.model_name}")
|
||||
self.diarization_pipeline = Pipeline.from_pretrained(
|
||||
self.model_name, use_auth_token=self.auth_token
|
||||
)
|
||||
self.diarization_pipeline.to(torch.device(self.device))
|
||||
self.logger.info(f"Diarization model loaded on device: {self.device}")
|
||||
|
||||
async def _diarize(self, data: AudioDiarizationInput) -> list[DiarizationSegment]:
|
||||
try:
|
||||
# Load audio file (audio_url is assumed to be a local file path)
|
||||
self.logger.info(f"Loading local audio file: {data.audio_url}")
|
||||
waveform, sample_rate = torchaudio.load(data.audio_url)
|
||||
audio_input = {"waveform": waveform, "sample_rate": sample_rate}
|
||||
self.logger.info("Running speaker diarization")
|
||||
diarization = self.diarization_pipeline(audio_input)
|
||||
|
||||
# Convert pyannote diarization output to our format
|
||||
segments = []
|
||||
for segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||
# Extract speaker number from label (e.g., "SPEAKER_00" -> 0)
|
||||
speaker_id = 0
|
||||
if speaker.startswith("SPEAKER_"):
|
||||
try:
|
||||
speaker_id = int(speaker.split("_")[-1])
|
||||
except (ValueError, IndexError):
|
||||
# Fallback to hash-based ID if parsing fails
|
||||
speaker_id = hash(speaker) % 1000
|
||||
|
||||
segments.append(
|
||||
{
|
||||
"start": round(segment.start, 3),
|
||||
"end": round(segment.end, 3),
|
||||
"speaker": speaker_id,
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(f"Diarization completed with {len(segments)} segments")
|
||||
return segments
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Diarization failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)
|
||||
60
server/reflector/processors/audio_downscale.py
Normal file
60
server/reflector/processors/audio_downscale.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
from av.audio.resampler import AudioResampler
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
|
||||
|
||||
def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
|
||||
frame_copy = frame.from_ndarray(
|
||||
frame.to_ndarray(),
|
||||
format=frame.format.name,
|
||||
layout=frame.layout.name,
|
||||
)
|
||||
frame_copy.sample_rate = frame.sample_rate
|
||||
frame_copy.pts = frame.pts
|
||||
frame_copy.time_base = frame.time_base
|
||||
return frame_copy
|
||||
|
||||
|
||||
class AudioDownscaleProcessor(Processor):
|
||||
"""
|
||||
Downscale audio frames to 16kHz mono format
|
||||
"""
|
||||
|
||||
INPUT_TYPE = av.AudioFrame
|
||||
OUTPUT_TYPE = av.AudioFrame
|
||||
|
||||
def __init__(self, target_rate: int = 16000, target_layout: str = "mono", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.target_rate = target_rate
|
||||
self.target_layout = target_layout
|
||||
self.resampler: Optional[AudioResampler] = None
|
||||
self.needs_resampling: Optional[bool] = None
|
||||
|
||||
async def _push(self, data: av.AudioFrame):
|
||||
if self.needs_resampling is None:
|
||||
self.needs_resampling = (
|
||||
data.sample_rate != self.target_rate
|
||||
or data.layout.name != self.target_layout
|
||||
)
|
||||
|
||||
if self.needs_resampling:
|
||||
self.resampler = AudioResampler(
|
||||
format="s16", layout=self.target_layout, rate=self.target_rate
|
||||
)
|
||||
|
||||
if not self.needs_resampling or not self.resampler:
|
||||
await self.emit(data)
|
||||
return
|
||||
|
||||
resampled_frames = self.resampler.resample(copy_frame(data))
|
||||
for resampled_frame in resampled_frames:
|
||||
await self.emit(resampled_frame)
|
||||
|
||||
async def _flush(self):
|
||||
if self.needs_resampling and self.resampler:
|
||||
final_frames = self.resampler.resample(None)
|
||||
for frame in final_frames:
|
||||
await self.emit(frame)
|
||||
@@ -16,37 +16,46 @@ class AudioMergeProcessor(Processor):
|
||||
INPUT_TYPE = list[av.AudioFrame]
|
||||
OUTPUT_TYPE = AudioFile
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def _push(self, data: list[av.AudioFrame]):
|
||||
if not data:
|
||||
return
|
||||
|
||||
# get audio information from first frame
|
||||
frame = data[0]
|
||||
channels = len(frame.layout.channels)
|
||||
sample_rate = frame.sample_rate
|
||||
sample_width = frame.format.bytes
|
||||
output_channels = len(frame.layout.channels)
|
||||
output_sample_rate = frame.sample_rate
|
||||
output_sample_width = frame.format.bytes
|
||||
|
||||
# create audio file
|
||||
uu = uuid4().hex
|
||||
fd = io.BytesIO()
|
||||
|
||||
# Use PyAV to write frames
|
||||
out_container = av.open(fd, "w", format="wav")
|
||||
out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
|
||||
out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
|
||||
out_stream.layout = frame.layout.name
|
||||
|
||||
for frame in data:
|
||||
for packet in out_stream.encode(frame):
|
||||
out_container.mux(packet)
|
||||
|
||||
# Flush the encoder
|
||||
for packet in out_stream.encode(None):
|
||||
out_container.mux(packet)
|
||||
out_container.close()
|
||||
|
||||
fd.seek(0)
|
||||
|
||||
# emit audio file
|
||||
audiofile = AudioFile(
|
||||
name=f"{monotonic_ns()}-{uu}.wav",
|
||||
fd=fd,
|
||||
sample_rate=sample_rate,
|
||||
channels=channels,
|
||||
sample_width=sample_width,
|
||||
sample_rate=output_sample_rate,
|
||||
channels=output_channels,
|
||||
sample_width=output_sample_width,
|
||||
timestamp=data[0].pts * data[0].time_base,
|
||||
)
|
||||
|
||||
|
||||
@@ -21,7 +21,11 @@ from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
modal_api_key: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if not settings.TRANSCRIPT_URL:
|
||||
raise Exception(
|
||||
|
||||
@@ -173,6 +173,7 @@ class Processor(Emitter):
|
||||
except Exception:
|
||||
self.m_processor_failure.inc()
|
||||
self.logger.exception("Error in push")
|
||||
raise
|
||||
|
||||
async def flush(self):
|
||||
"""
|
||||
@@ -240,33 +241,45 @@ class ThreadedProcessor(Processor):
|
||||
self.INPUT_TYPE = processor.INPUT_TYPE
|
||||
self.OUTPUT_TYPE = processor.OUTPUT_TYPE
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.queue = asyncio.Queue()
|
||||
self.task = asyncio.get_running_loop().create_task(self.loop())
|
||||
self.queue = asyncio.Queue(maxsize=50)
|
||||
self.task: asyncio.Task | None = None
|
||||
|
||||
def set_pipeline(self, pipeline: "Pipeline"):
|
||||
super().set_pipeline(pipeline)
|
||||
self.processor.set_pipeline(pipeline)
|
||||
|
||||
async def loop(self):
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
self.m_processor_queue.set(self.queue.qsize())
|
||||
with self.m_processor_queue_in_progress.track_inprogress():
|
||||
try:
|
||||
if data is None:
|
||||
await self.processor.flush()
|
||||
break
|
||||
try:
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
self.m_processor_queue.set(self.queue.qsize())
|
||||
with self.m_processor_queue_in_progress.track_inprogress():
|
||||
try:
|
||||
await self.processor.push(data)
|
||||
except Exception:
|
||||
self.logger.error(
|
||||
f"Error in push {self.processor.__class__.__name__}"
|
||||
", continue"
|
||||
)
|
||||
finally:
|
||||
self.queue.task_done()
|
||||
if data is None:
|
||||
await self.processor.flush()
|
||||
break
|
||||
try:
|
||||
await self.processor.push(data)
|
||||
except Exception:
|
||||
self.logger.error(
|
||||
f"Error in push {self.processor.__class__.__name__}"
|
||||
", continue"
|
||||
)
|
||||
finally:
|
||||
self.queue.task_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Crash in {self.__class__.__name__}: {e}", exc_info=e)
|
||||
|
||||
async def _ensure_task(self):
|
||||
if self.task is None:
|
||||
self.task = asyncio.get_running_loop().create_task(self.loop())
|
||||
|
||||
# XXX not doing a sleep here make the whole pipeline prior the thread
|
||||
# to be running without having a chance to work on the task here.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def _push(self, data):
|
||||
await self._ensure_task()
|
||||
await self.queue.put(data)
|
||||
|
||||
async def _flush(self):
|
||||
|
||||
33
server/reflector/processors/file_diarization.py
Normal file
33
server/reflector/processors/file_diarization.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import DiarizationSegment
|
||||
|
||||
|
||||
class FileDiarizationInput(BaseModel):
|
||||
"""Input for file diarization containing audio URL"""
|
||||
|
||||
audio_url: str
|
||||
|
||||
|
||||
class FileDiarizationOutput(BaseModel):
|
||||
"""Output for file diarization containing speaker segments"""
|
||||
|
||||
diarization: list[DiarizationSegment]
|
||||
|
||||
|
||||
class FileDiarizationProcessor(Processor):
|
||||
"""
|
||||
Diarize complete audio files from URL
|
||||
"""
|
||||
|
||||
INPUT_TYPE = FileDiarizationInput
|
||||
OUTPUT_TYPE = FileDiarizationOutput
|
||||
|
||||
async def _push(self, data: FileDiarizationInput):
|
||||
result = await self._diarize(data)
|
||||
if result:
|
||||
await self.emit(result)
|
||||
|
||||
async def _diarize(self, data: FileDiarizationInput):
|
||||
raise NotImplementedError
|
||||
33
server/reflector/processors/file_diarization_auto.py
Normal file
33
server/reflector/processors/file_diarization_auto.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.file_diarization import FileDiarizationProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileDiarizationAutoProcessor(FileDiarizationProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.DIARIZATION_BACKEND
|
||||
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.file_diarization_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "DIARIZATION_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
58
server/reflector/processors/file_diarization_modal.py
Normal file
58
server/reflector/processors/file_diarization_modal.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
File diarization implementation using the GPU service from modal.com
|
||||
|
||||
API will be a POST request to DIARIZATION_URL:
|
||||
|
||||
```
|
||||
POST /diarize?audio_file_url=...×tamp=0
|
||||
Authorization: Bearer <modal_api_key>
|
||||
```
|
||||
"""
|
||||
|
||||
import httpx
|
||||
|
||||
from reflector.processors.file_diarization import (
|
||||
FileDiarizationInput,
|
||||
FileDiarizationOutput,
|
||||
FileDiarizationProcessor,
|
||||
)
|
||||
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileDiarizationModalProcessor(FileDiarizationProcessor):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not settings.DIARIZATION_URL:
|
||||
raise Exception(
|
||||
"DIARIZATION_URL required to use FileDiarizationModalProcessor"
|
||||
)
|
||||
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
|
||||
self.file_timeout = settings.DIARIZATION_FILE_TIMEOUT
|
||||
self.modal_api_key = modal_api_key
|
||||
|
||||
async def _diarize(self, data: FileDiarizationInput):
|
||||
"""Get speaker diarization for file"""
|
||||
self.logger.info(f"Starting diarization from {data.audio_url}")
|
||||
|
||||
headers = {}
|
||||
if self.modal_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
|
||||
response = await client.post(
|
||||
self.diarization_url,
|
||||
headers=headers,
|
||||
params={
|
||||
"audio_file_url": data.audio_url,
|
||||
"timestamp": 0,
|
||||
},
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
diarization_data = response.json()["diarization"]
|
||||
|
||||
return FileDiarizationOutput(diarization=diarization_data)
|
||||
|
||||
|
||||
FileDiarizationAutoProcessor.register("modal", FileDiarizationModalProcessor)
|
||||
65
server/reflector/processors/file_transcript.py
Normal file
65
server/reflector/processors/file_transcript.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import Transcript
|
||||
|
||||
|
||||
class FileTranscriptInput:
|
||||
"""Input for file transcription containing audio URL and language settings"""
|
||||
|
||||
def __init__(self, audio_url: str, language: str = "en"):
|
||||
self.audio_url = audio_url
|
||||
self.language = language
|
||||
|
||||
|
||||
class FileTranscriptProcessor(Processor):
|
||||
"""
|
||||
Transcript complete audio files from URL
|
||||
"""
|
||||
|
||||
INPUT_TYPE = FileTranscriptInput
|
||||
OUTPUT_TYPE = Transcript
|
||||
|
||||
m_transcript = Histogram(
|
||||
"file_transcript",
|
||||
"Time spent in FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
m_transcript_call = Counter(
|
||||
"file_transcript_call",
|
||||
"Number of calls to FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
m_transcript_success = Counter(
|
||||
"file_transcript_success",
|
||||
"Number of successful calls to FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
m_transcript_failure = Counter(
|
||||
"file_transcript_failure",
|
||||
"Number of failed calls to FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
name = self.__class__.__name__
|
||||
self.m_transcript = self.m_transcript.labels(name)
|
||||
self.m_transcript_call = self.m_transcript_call.labels(name)
|
||||
self.m_transcript_success = self.m_transcript_success.labels(name)
|
||||
self.m_transcript_failure = self.m_transcript_failure.labels(name)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def _push(self, data: FileTranscriptInput):
|
||||
try:
|
||||
self.m_transcript_call.inc()
|
||||
with self.m_transcript.time():
|
||||
result = await self._transcript(data)
|
||||
self.m_transcript_success.inc()
|
||||
if result:
|
||||
await self.emit(result)
|
||||
except Exception:
|
||||
self.m_transcript_failure.inc()
|
||||
raise
|
||||
|
||||
async def _transcript(self, data: FileTranscriptInput):
|
||||
raise NotImplementedError
|
||||
32
server/reflector/processors/file_transcript_auto.py
Normal file
32
server/reflector/processors/file_transcript_auto.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.file_transcript import FileTranscriptProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileTranscriptAutoProcessor(FileTranscriptProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.TRANSCRIPT_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.file_transcript_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "TRANSCRIPT_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
78
server/reflector/processors/file_transcript_modal.py
Normal file
78
server/reflector/processors/file_transcript_modal.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
File transcription implementation using the GPU service from modal.com
|
||||
|
||||
API will be a POST request to TRANSCRIPT_URL:
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_file_url": "https://...",
|
||||
"language": "en",
|
||||
"model": "parakeet-tdt-0.6b-v2",
|
||||
"batch": true
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
import httpx
|
||||
|
||||
from reflector.processors.file_transcript import (
|
||||
FileTranscriptInput,
|
||||
FileTranscriptProcessor,
|
||||
)
|
||||
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||
from reflector.processors.types import Transcript, Word
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileTranscriptModalProcessor(FileTranscriptProcessor):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not settings.TRANSCRIPT_URL:
|
||||
raise Exception(
|
||||
"TRANSCRIPT_URL required to use FileTranscriptModalProcessor"
|
||||
)
|
||||
self.transcript_url = settings.TRANSCRIPT_URL
|
||||
self.file_timeout = settings.TRANSCRIPT_FILE_TIMEOUT
|
||||
self.modal_api_key = modal_api_key
|
||||
|
||||
async def _transcript(self, data: FileTranscriptInput):
|
||||
"""Send full file to Modal for transcription"""
|
||||
url = f"{self.transcript_url}/v1/audio/transcriptions-from-url"
|
||||
|
||||
self.logger.info(f"Starting file transcription from {data.audio_url}")
|
||||
|
||||
headers = {}
|
||||
if self.modal_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json={
|
||||
"audio_file_url": data.audio_url,
|
||||
"language": data.language,
|
||||
"batch": True,
|
||||
},
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
words = [
|
||||
Word(
|
||||
text=word_info["word"],
|
||||
start=word_info["start"],
|
||||
end=word_info["end"],
|
||||
)
|
||||
for word_info in result.get("words", [])
|
||||
]
|
||||
|
||||
# words come not in order
|
||||
words.sort(key=lambda w: w.start)
|
||||
|
||||
return Transcript(words=words)
|
||||
|
||||
|
||||
# Register with the auto processor
|
||||
FileTranscriptAutoProcessor.register("modal", FileTranscriptModalProcessor)
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Processor to assemble transcript with diarization results
|
||||
"""
|
||||
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import DiarizationSegment, Transcript
|
||||
|
||||
|
||||
class TranscriptDiarizationAssemblerInput:
|
||||
"""Input containing transcript and diarization data"""
|
||||
|
||||
def __init__(self, transcript: Transcript, diarization: list[DiarizationSegment]):
|
||||
self.transcript = transcript
|
||||
self.diarization = diarization
|
||||
|
||||
|
||||
class TranscriptDiarizationAssemblerProcessor(Processor):
|
||||
"""
|
||||
Assemble transcript with diarization results by applying speaker assignments
|
||||
"""
|
||||
|
||||
INPUT_TYPE = TranscriptDiarizationAssemblerInput
|
||||
OUTPUT_TYPE = Transcript
|
||||
|
||||
async def _push(self, data: TranscriptDiarizationAssemblerInput):
|
||||
result = await self._assemble(data)
|
||||
if result:
|
||||
await self.emit(result)
|
||||
|
||||
async def _assemble(self, data: TranscriptDiarizationAssemblerInput):
|
||||
"""Apply diarization to transcript words"""
|
||||
if not data.diarization:
|
||||
self.logger.info(
|
||||
"No diarization data provided, returning original transcript"
|
||||
)
|
||||
return data.transcript
|
||||
|
||||
# Reuse logic from AudioDiarizationProcessor
|
||||
processor = AudioDiarizationProcessor()
|
||||
words = data.transcript.words
|
||||
processor.assign_speaker(words, data.diarization)
|
||||
|
||||
self.logger.info(f"Applied diarization to {len(words)} words")
|
||||
return data.transcript
|
||||
@@ -2,18 +2,21 @@ import io
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from profanityfilter import ProfanityFilter
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from reflector.redis_cache import redis_cache
|
||||
|
||||
class DiarizationSegment(TypedDict):
|
||||
"""Type definition for diarization segment containing speaker information"""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
speaker: int
|
||||
|
||||
|
||||
PUNC_RE = re.compile(r"[.;:?!…]")
|
||||
|
||||
profanity_filter = ProfanityFilter()
|
||||
profanity_filter.set_censor("*")
|
||||
|
||||
|
||||
class AudioFile(BaseModel):
|
||||
name: str
|
||||
@@ -115,21 +118,11 @@ def words_to_segments(words: list[Word]) -> list[TranscriptSegment]:
|
||||
|
||||
class Transcript(BaseModel):
|
||||
translation: str | None = None
|
||||
words: list[Word] = None
|
||||
|
||||
@property
|
||||
def raw_text(self):
|
||||
# Uncensored text
|
||||
return "".join([word.text for word in self.words])
|
||||
|
||||
@redis_cache(prefix="profanity", duration=3600 * 24 * 7)
|
||||
def _get_censored_text(self, text: str):
|
||||
return profanity_filter.censor(text).strip()
|
||||
words: list[Word] = []
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
# Censored text
|
||||
return self._get_censored_text(self.raw_text)
|
||||
return "".join([word.text for word in self.words])
|
||||
|
||||
@property
|
||||
def human_timestamp(self):
|
||||
@@ -161,12 +154,6 @@ class Transcript(BaseModel):
|
||||
word.start += offset
|
||||
word.end += offset
|
||||
|
||||
def clone(self):
|
||||
words = [
|
||||
Word(text=word.text, start=word.start, end=word.end) for word in self.words
|
||||
]
|
||||
return Transcript(text=self.text, translation=self.translation, words=words)
|
||||
|
||||
def as_segments(self) -> list[TranscriptSegment]:
|
||||
return words_to_segments(self.words)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user