Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-21 12:49:06 +00:00)

Compare commits: mathieu/sq ... mathieu/ca (5 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 311d453e41 | |
| | f286f0882c | |
| | ffcafb3bf2 | |
| | 27075d840c | |
| | 30b5cd45e3 | |
.github/workflows/db_migrations.yml (vendored, 5 lines changed)

@@ -2,8 +2,6 @@ name: Test Database Migrations
 on:
   push:
-    branches:
-      - main
     paths:
       - "server/migrations/**"
       - "server/reflector/db/**"
@@ -19,9 +17,6 @@ on:
 jobs:
   test-migrations:
     runs-on: ubuntu-latest
-    concurrency:
-      group: db-ubuntu-latest-${{ github.ref }}
-      cancel-in-progress: true
     services:
       postgres:
         image: postgres:17
.github/workflows/deploy.yml (vendored, 77 lines changed)

@@ -8,30 +8,18 @@ env:
   ECR_REPOSITORY: reflector

 jobs:
-  build:
-    strategy:
-      matrix:
-        include:
-          - platform: linux/amd64
-            runner: linux-amd64
-            arch: amd64
-          - platform: linux/arm64
-            runner: linux-arm64
-            arch: arm64
-
-    runs-on: ${{ matrix.runner }}
-
+  deploy:
+    runs-on: ubuntu-latest
     permissions:
+      deployments: write
       contents: read

-    outputs:
-      registry: ${{ steps.login-ecr.outputs.registry }}
-
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v3

       - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@v4
+        uses: aws-actions/configure-aws-credentials@0e613a0980cbf65ed5b322eb7a1e075d28913a83
         with:
           aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
           aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
@@ -39,52 +27,21 @@ jobs:

       - name: Login to Amazon ECR
         id: login-ecr
-        uses: aws-actions/amazon-ecr-login@v2
+        uses: aws-actions/amazon-ecr-login@62f4f872db3836360b72999f4b87f1ff13310f3a
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v2

       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@v2

-      - name: Build and push ${{ matrix.arch }}
-        uses: docker/build-push-action@v5
+      - name: Build and push
+        id: docker_build
+        uses: docker/build-push-action@v4
         with:
           context: server
-          platforms: ${{ matrix.platform }}
+          platforms: linux/amd64,linux/arm64
           push: true
-          tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest-${{ matrix.arch }}
-          cache-from: type=gha,scope=${{ matrix.arch }}
-          cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
-          github-token: ${{ secrets.GHA_CACHE_TOKEN }}
-          provenance: false
-
-  create-manifest:
-    runs-on: ubuntu-latest
-    needs: [build]
-
-    permissions:
-      deployments: write
-      contents: read
-
-    steps:
-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@v4
-        with:
-          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
-          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-          aws-region: ${{ env.AWS_REGION }}
-
-      - name: Login to Amazon ECR
-        uses: aws-actions/amazon-ecr-login@v2
-
-      - name: Create and push multi-arch manifest
-        run: |
-          # Get the registry URL (since we can't easily access job outputs in matrix)
-          ECR_REGISTRY=$(aws ecr describe-registry --query 'registryId' --output text).dkr.ecr.${{ env.AWS_REGION }}.amazonaws.com
-
-          docker manifest create \
-            $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest \
-            $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-amd64 \
-            $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-arm64
-
-          docker manifest push $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest
-
-          echo "✅ Multi-arch manifest pushed: $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest"
+          tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
.github/workflows/test_next_server.yml (vendored, 45 lines): file removed

@@ -1,45 +0,0 @@
-name: Test Next Server
-
-on:
-  pull_request:
-    paths:
-      - "www/**"
-  push:
-    branches:
-      - main
-    paths:
-      - "www/**"
-
-jobs:
-  test-next-server:
-    runs-on: ubuntu-latest
-
-    defaults:
-      run:
-        working-directory: ./www
-
-    steps:
-      - uses: actions/checkout@v4
-
-      - name: Setup Node.js
-        uses: actions/setup-node@v4
-        with:
-          node-version: '20'
-
-      - name: Install pnpm
-        uses: pnpm/action-setup@v4
-        with:
-          version: 8
-
-      - name: Setup Node.js cache
-        uses: actions/setup-node@v4
-        with:
-          node-version: '20'
-          cache: 'pnpm'
-          cache-dependency-path: './www/pnpm-lock.yaml'
-
-      - name: Install dependencies
-        run: pnpm install
-
-      - name: Run tests
-        run: pnpm test
.github/workflows/test_server.yml (vendored, 49 lines changed)

@@ -5,17 +5,12 @@ on:
     paths:
       - "server/**"
   push:
-    branches:
-      - main
     paths:
       - "server/**"

 jobs:
   pytest:
     runs-on: ubuntu-latest
-    concurrency:
-      group: pytest-${{ github.ref }}
-      cancel-in-progress: true
     services:
       redis:
         image: redis:6
@@ -24,47 +19,29 @@ jobs:
     steps:
       - uses: actions/checkout@v4
       - name: Install uv
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v3
         with:
          enable-cache: true
          working-directory: server

      - name: Tests
        run: |
          cd server
          uv run -m pytest -v tests

-  docker-amd64:
-    runs-on: linux-amd64
-    concurrency:
-      group: docker-amd64-${{ github.ref }}
-      cancel-in-progress: true
+  docker:
+    runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v2
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-      - name: Build AMD64
-        uses: docker/build-push-action@v6
+        uses: docker/setup-buildx-action@v2
+      - name: Build and push
+        id: docker_build
+        uses: docker/build-push-action@v4
         with:
           context: server
-          platforms: linux/amd64
-          cache-from: type=gha,scope=amd64
-          cache-to: type=gha,mode=max,scope=amd64
-          github-token: ${{ secrets.GHA_CACHE_TOKEN }}
-
-  docker-arm64:
-    runs-on: linux-arm64
-    concurrency:
-      group: docker-arm64-${{ github.ref }}
-      cancel-in-progress: true
-    steps:
-      - uses: actions/checkout@v4
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-      - name: Build ARM64
-        uses: docker/build-push-action@v6
-        with:
-          context: server
-          platforms: linux/arm64
-          cache-from: type=gha,scope=arm64
-          cache-to: type=gha,mode=max,scope=arm64
-          github-token: ${{ secrets.GHA_CACHE_TOKEN }}
+          platforms: linux/amd64,linux/arm64
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
.gitignore (vendored, 3 lines changed)

@@ -15,6 +15,3 @@ www/REFACTOR.md
 www/reload-frontend
 server/test.sqlite
 CLAUDE.local.md
-www/.env.development
-www/.env.production
-.playwright-mcp
@@ -1 +0,0 @@
-b9d891d3424f371642cb032ecfd0e2564470a72c:server/tests/test_transcripts_recording_deletion.py:generic-api-key:15

@@ -27,8 +27,3 @@ repos:
       files: ^server/
     - id: ruff-format
       files: ^server/
-
-  - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.28.0
-    hooks:
-      - id: gitleaks
CHANGELOG.md (153 lines changed)

@@ -1,158 +1,5 @@
 # Changelog

-## [0.13.1](https://github.com/Monadical-SAS/reflector/compare/v0.13.0...v0.13.1) (2025-09-22)
-
-
-### Bug Fixes
-
-* TypeError on not all arguments converted during string formatting in logger ([#667](https://github.com/Monadical-SAS/reflector/issues/667)) ([565a629](https://github.com/Monadical-SAS/reflector/commit/565a62900f5a02fc946b68f9269a42190ed70ab6))
-
-## [0.13.0](https://github.com/Monadical-SAS/reflector/compare/v0.12.1...v0.13.0) (2025-09-19)
-
-
-### Features
-
-* room form edit with enter ([#662](https://github.com/Monadical-SAS/reflector/issues/662)) ([47716f6](https://github.com/Monadical-SAS/reflector/commit/47716f6e5ddee952609d2fa0ffabdfa865286796))
-
-
-### Bug Fixes
-
-* invalid cleanup call ([#660](https://github.com/Monadical-SAS/reflector/issues/660)) ([0abcebf](https://github.com/Monadical-SAS/reflector/commit/0abcebfc9491f87f605f21faa3e53996fafedd9a))
-
-## [0.12.1](https://github.com/Monadical-SAS/reflector/compare/v0.12.0...v0.12.1) (2025-09-17)
-
-
-### Bug Fixes
-
-* production blocked because having existing meeting with room_id null ([#657](https://github.com/Monadical-SAS/reflector/issues/657)) ([870e860](https://github.com/Monadical-SAS/reflector/commit/870e8605171a27155a9cbee215eeccb9a8d6c0a2))
-
-## [0.12.0](https://github.com/Monadical-SAS/reflector/compare/v0.11.0...v0.12.0) (2025-09-17)
-
-
-### Features
-
-* calendar integration ([#608](https://github.com/Monadical-SAS/reflector/issues/608)) ([6f680b5](https://github.com/Monadical-SAS/reflector/commit/6f680b57954c688882c4ed49f40f161c52a00a24))
-* self-hosted gpu api ([#636](https://github.com/Monadical-SAS/reflector/issues/636)) ([ab859d6](https://github.com/Monadical-SAS/reflector/commit/ab859d65a6bded904133a163a081a651b3938d42))
-
-
-### Bug Fixes
-
-* ignore player hotkeys for text inputs ([#646](https://github.com/Monadical-SAS/reflector/issues/646)) ([fa049e8](https://github.com/Monadical-SAS/reflector/commit/fa049e8d068190ce7ea015fd9fcccb8543f54a3f))
-
-## [0.11.0](https://github.com/Monadical-SAS/reflector/compare/v0.10.0...v0.11.0) (2025-09-16)
-
-
-### Features
-
-* remove profanity filter that was there for conference ([#652](https://github.com/Monadical-SAS/reflector/issues/652)) ([b42f7cf](https://github.com/Monadical-SAS/reflector/commit/b42f7cfc606783afcee792590efcc78b507468ab))
-
-
-### Bug Fixes
-
-* zulip and consent handler on the file pipeline ([#645](https://github.com/Monadical-SAS/reflector/issues/645)) ([5f143fe](https://github.com/Monadical-SAS/reflector/commit/5f143fe3640875dcb56c26694254a93189281d17))
-* zulip stream and topic selection in share dialog ([#644](https://github.com/Monadical-SAS/reflector/issues/644)) ([c546e69](https://github.com/Monadical-SAS/reflector/commit/c546e69739e68bb74fbc877eb62609928e5b8de6))
-
-## [0.10.0](https://github.com/Monadical-SAS/reflector/compare/v0.9.0...v0.10.0) (2025-09-11)
-
-
-### Features
-
-* replace nextjs-config with environment variables ([#632](https://github.com/Monadical-SAS/reflector/issues/632)) ([369ecdf](https://github.com/Monadical-SAS/reflector/commit/369ecdff13f3862d926a9c0b87df52c9d94c4dde))
-
-
-### Bug Fixes
-
-* anonymous users transcript permissions ([#621](https://github.com/Monadical-SAS/reflector/issues/621)) ([f81fe99](https://github.com/Monadical-SAS/reflector/commit/f81fe9948a9237b3e0001b2d8ca84f54d76878f9))
-* auth post ([#624](https://github.com/Monadical-SAS/reflector/issues/624)) ([cde99ca](https://github.com/Monadical-SAS/reflector/commit/cde99ca2716f84ba26798f289047732f0448742e))
-* auth post ([#626](https://github.com/Monadical-SAS/reflector/issues/626)) ([3b85ff3](https://github.com/Monadical-SAS/reflector/commit/3b85ff3bdf4fb053b103070646811bc990c0e70a))
-* auth post ([#627](https://github.com/Monadical-SAS/reflector/issues/627)) ([962038e](https://github.com/Monadical-SAS/reflector/commit/962038ee3f2a555dc3c03856be0e4409456e0996))
-* missing follow_redirects=True on modal endpoint ([#630](https://github.com/Monadical-SAS/reflector/issues/630)) ([fc363bd](https://github.com/Monadical-SAS/reflector/commit/fc363bd49b17b075e64f9186e5e0185abc325ea7))
-* sync backend and frontend token refresh logic ([#614](https://github.com/Monadical-SAS/reflector/issues/614)) ([5a5b323](https://github.com/Monadical-SAS/reflector/commit/5a5b3233820df9536da75e87ce6184a983d4713a))
-
-## [0.9.0](https://github.com/Monadical-SAS/reflector/compare/v0.8.2...v0.9.0) (2025-09-06)
-
-
-### Features
-
-* frontend openapi react query ([#606](https://github.com/Monadical-SAS/reflector/issues/606)) ([c4d2825](https://github.com/Monadical-SAS/reflector/commit/c4d2825c81f81ad8835629fbf6ea8c7383f8c31b))
-
-
-### Bug Fixes
-
-* align whisper transcriber api with parakeet ([#602](https://github.com/Monadical-SAS/reflector/issues/602)) ([0663700](https://github.com/Monadical-SAS/reflector/commit/0663700a615a4af69a03c96c410f049e23ec9443))
-* kv use tls explicit ([#610](https://github.com/Monadical-SAS/reflector/issues/610)) ([08d88ec](https://github.com/Monadical-SAS/reflector/commit/08d88ec349f38b0d13e0fa4cb73486c8dfd31836))
-* source kind for file processing ([#601](https://github.com/Monadical-SAS/reflector/issues/601)) ([dc82f8b](https://github.com/Monadical-SAS/reflector/commit/dc82f8bb3bdf3ab3d4088e592a30fd63907319e1))
-* token refresh locking ([#613](https://github.com/Monadical-SAS/reflector/issues/613)) ([7f5a4c9](https://github.com/Monadical-SAS/reflector/commit/7f5a4c9ddc7fd098860c8bdda2ca3b57f63ded2f))
-
-## [0.8.2](https://github.com/Monadical-SAS/reflector/compare/v0.8.1...v0.8.2) (2025-08-29)
-
-
-### Bug Fixes
-
-* search-logspam ([#593](https://github.com/Monadical-SAS/reflector/issues/593)) ([695d1a9](https://github.com/Monadical-SAS/reflector/commit/695d1a957d4cd862753049f9beed88836cabd5ab))
-
-## [0.8.1](https://github.com/Monadical-SAS/reflector/compare/v0.8.0...v0.8.1) (2025-08-29)
-
-
-### Bug Fixes
-
-* make webhook secret/url allowing null ([#590](https://github.com/Monadical-SAS/reflector/issues/590)) ([84a3812](https://github.com/Monadical-SAS/reflector/commit/84a381220bc606231d08d6f71d4babc818fa3c75))
-
-## [0.8.0](https://github.com/Monadical-SAS/reflector/compare/v0.7.3...v0.8.0) (2025-08-29)
-
-
-### Features
-
-* **cleanup:** add automatic data retention for public instances ([#574](https://github.com/Monadical-SAS/reflector/issues/574)) ([6f0c7c1](https://github.com/Monadical-SAS/reflector/commit/6f0c7c1a5e751713366886c8e764c2009e12ba72))
-* **rooms:** add webhook for transcript completion ([#578](https://github.com/Monadical-SAS/reflector/issues/578)) ([88ed7cf](https://github.com/Monadical-SAS/reflector/commit/88ed7cfa7804794b9b54cad4c3facc8a98cf85fd))
-
-
-### Bug Fixes
-
-* file pipeline status reporting and websocket updates ([#589](https://github.com/Monadical-SAS/reflector/issues/589)) ([9dfd769](https://github.com/Monadical-SAS/reflector/commit/9dfd76996f851cc52be54feea078adbc0816dc57))
-* Igor/evaluation ([#575](https://github.com/Monadical-SAS/reflector/issues/575)) ([124ce03](https://github.com/Monadical-SAS/reflector/commit/124ce03bf86044c18313d27228a25da4bc20c9c5))
-* optimize parakeet transcription batching algorithm ([#577](https://github.com/Monadical-SAS/reflector/issues/577)) ([7030e0f](https://github.com/Monadical-SAS/reflector/commit/7030e0f23649a8cf6c1eb6d5889684a41ce849ec))
-
-## [0.7.3](https://github.com/Monadical-SAS/reflector/compare/v0.7.2...v0.7.3) (2025-08-22)
-
-
-### Bug Fixes
-
-* cleaned repo, and get git-leaks clean ([359280d](https://github.com/Monadical-SAS/reflector/commit/359280dd340433ba4402ed69034094884c825e67))
-* restore previous behavior on live pipeline + audio downscaler ([#561](https://github.com/Monadical-SAS/reflector/issues/561)) ([9265d20](https://github.com/Monadical-SAS/reflector/commit/9265d201b590d23c628c5f19251b70f473859043))
-
-## [0.7.2](https://github.com/Monadical-SAS/reflector/compare/v0.7.1...v0.7.2) (2025-08-21)
-
-
-### Bug Fixes
-
-* docker image not loading libgomp.so.1 for torch ([#560](https://github.com/Monadical-SAS/reflector/issues/560)) ([773fccd](https://github.com/Monadical-SAS/reflector/commit/773fccd93e887c3493abc2e4a4864dddce610177))
-* include shared rooms to search ([#558](https://github.com/Monadical-SAS/reflector/issues/558)) ([499eced](https://github.com/Monadical-SAS/reflector/commit/499eced3360b84fb3a90e1c8a3b554290d21adc2))
-
-## [0.7.1](https://github.com/Monadical-SAS/reflector/compare/v0.7.0...v0.7.1) (2025-08-21)
-
-
-### Bug Fixes
-
-* webvtt db null expectation mismatch ([#556](https://github.com/Monadical-SAS/reflector/issues/556)) ([e67ad1a](https://github.com/Monadical-SAS/reflector/commit/e67ad1a4a2054467bfeb1e0258fbac5868aaaf21))
-
-## [0.7.0](https://github.com/Monadical-SAS/reflector/compare/v0.6.1...v0.7.0) (2025-08-21)
-
-
-### Features
-
-* delete recording with transcript ([#547](https://github.com/Monadical-SAS/reflector/issues/547)) ([99cc984](https://github.com/Monadical-SAS/reflector/commit/99cc9840b3f5de01e0adfbfae93234042d706d13))
-* pipeline improvement with file processing, parakeet, silero-vad ([#540](https://github.com/Monadical-SAS/reflector/issues/540)) ([bcc29c9](https://github.com/Monadical-SAS/reflector/commit/bcc29c9e0050ae215f89d460e9d645aaf6a5e486))
-* postgresql migration and removal of sqlite in pytest ([#546](https://github.com/Monadical-SAS/reflector/issues/546)) ([cd1990f](https://github.com/Monadical-SAS/reflector/commit/cd1990f8f0fe1503ef5069512f33777a73a93d7f))
-* search backend ([#537](https://github.com/Monadical-SAS/reflector/issues/537)) ([5f9b892](https://github.com/Monadical-SAS/reflector/commit/5f9b89260c9ef7f3c921319719467df22830453f))
-* search frontend ([#551](https://github.com/Monadical-SAS/reflector/issues/551)) ([3657242](https://github.com/Monadical-SAS/reflector/commit/365724271ca6e615e3425125a69ae2b46ce39285))
-
-
-### Bug Fixes
-
-* evaluation cli event wrap ([#536](https://github.com/Monadical-SAS/reflector/issues/536)) ([941c3db](https://github.com/Monadical-SAS/reflector/commit/941c3db0bdacc7b61fea412f3746cc5a7cb67836))
-* use structlog not logging ([#550](https://github.com/Monadical-SAS/reflector/issues/550)) ([27e2f81](https://github.com/Monadical-SAS/reflector/commit/27e2f81fda5232e53edc729d3e99c5ef03adbfe9))
-
 ## [0.6.1](https://github.com/Monadical-SAS/reflector/compare/v0.6.0...v0.6.1) (2025-08-06)
@@ -66,6 +66,7 @@ pnpm install
 # Copy configuration templates
 cp .env_template .env
+cp config-template.ts config.ts
 ```

 **Development:**
ICS_IMPLEMENTATION.md (new file, 497 lines)

# ICS Calendar Integration - Implementation Guide

## Overview

This document provides detailed implementation guidance for integrating ICS calendar feeds with Reflector rooms. Unlike CalDAV, which requires complex authentication and protocol handling, ICS integration uses simple HTTP(S) fetching of calendar files.

## Key Differences from CalDAV Approach

| Aspect | CalDAV | ICS |
|--------|--------|-----|
| Protocol | WebDAV extension | HTTP/HTTPS GET |
| Authentication | Username/password, OAuth | Tokens embedded in URL |
| Data Access | Selective event queries | Full calendar download |
| Implementation | Complex (caldav library) | Simple (requests + icalendar) |
| Real-time Updates | Supported | Polling only |
| Write Access | Yes | No (read-only) |

## Technical Architecture

### 1. ICS Fetching Service

```python
# reflector/services/ics_sync.py

import requests
from icalendar import Calendar
from typing import List, Optional
from datetime import datetime, timedelta

class ICSFetchService:
    def __init__(self):
        self.session = requests.Session()
        self.session.headers.update({'User-Agent': 'Reflector/1.0'})

    def fetch_ics(self, url: str) -> str:
        """Fetch ICS file from URL (authentication via URL token if needed)."""
        response = self.session.get(url, timeout=30)
        response.raise_for_status()
        return response.text

    def parse_ics(self, ics_content: str) -> Calendar:
        """Parse ICS content into calendar object."""
        return Calendar.from_ical(ics_content)

    def extract_room_events(self, calendar: Calendar, room_url: str) -> List[dict]:
        """Extract events that match the room URL."""
        events = []

        for component in calendar.walk():
            if component.name == "VEVENT":
                # Check if event matches this room
                if self._event_matches_room(component, room_url):
                    events.append(self._parse_event(component))

        return events

    def _event_matches_room(self, event, room_url: str) -> bool:
        """Check if event location or description contains room URL."""
        location = str(event.get('LOCATION', ''))
        description = str(event.get('DESCRIPTION', ''))

        # Support various URL formats
        patterns = [
            room_url,
            room_url.replace('https://', ''),
            room_url.split('/')[-1],  # Just room name
        ]

        for pattern in patterns:
            if pattern in location or pattern in description:
                return True

        return False
```

### 2. Database Schema

```sql
-- Modify room table
ALTER TABLE room ADD COLUMN ics_url TEXT; -- encrypted to protect embedded tokens
ALTER TABLE room ADD COLUMN ics_fetch_interval INTEGER DEFAULT 300; -- seconds
ALTER TABLE room ADD COLUMN ics_enabled BOOLEAN DEFAULT FALSE;
ALTER TABLE room ADD COLUMN ics_last_sync TIMESTAMP;
ALTER TABLE room ADD COLUMN ics_last_etag TEXT; -- for caching

-- Calendar events table
CREATE TABLE calendar_event (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    room_id UUID REFERENCES room(id) ON DELETE CASCADE,
    external_id TEXT NOT NULL, -- ICS UID
    title TEXT,
    description TEXT,
    start_time TIMESTAMP NOT NULL,
    end_time TIMESTAMP NOT NULL,
    attendees JSONB,
    location TEXT,
    ics_raw_data TEXT, -- Store raw VEVENT for reference
    last_synced TIMESTAMP DEFAULT NOW(),
    is_deleted BOOLEAN DEFAULT FALSE,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW(),
    UNIQUE(room_id, external_id)
);

-- Index for efficient queries
CREATE INDEX idx_calendar_event_room_start ON calendar_event(room_id, start_time);
CREATE INDEX idx_calendar_event_deleted ON calendar_event(is_deleted) WHERE NOT is_deleted;
```

### 3. Background Tasks

```python
# reflector/worker/tasks/ics_sync.py

from celery import shared_task
from datetime import datetime, timedelta
import hashlib

@shared_task
def sync_ics_calendars():
    """Sync all enabled ICS calendars based on their fetch intervals."""
    rooms = Room.query.filter_by(ics_enabled=True).all()

    for room in rooms:
        # Check if it's time to sync based on fetch interval
        if should_sync(room):
            sync_room_calendar.delay(room.id)

@shared_task
def sync_room_calendar(room_id: str):
    """Sync calendar for a specific room."""
    room = Room.query.get(room_id)
    if not room or not room.ics_enabled:
        return

    try:
        # Fetch ICS file (decrypt URL first)
        service = ICSFetchService()
        decrypted_url = decrypt_ics_url(room.ics_url)
        ics_content = service.fetch_ics(decrypted_url)

        # Check if content changed (using ETag or hash)
        content_hash = hashlib.md5(ics_content.encode()).hexdigest()
        if room.ics_last_etag == content_hash:
            logger.info(f"No changes in ICS for room {room_id}")
            return

        # Parse and extract events
        calendar = service.parse_ics(ics_content)
        events = service.extract_room_events(calendar, room.url)

        # Update database
        sync_events_to_database(room_id, events)

        # Update sync metadata
        room.ics_last_sync = datetime.utcnow()
        room.ics_last_etag = content_hash
        db.session.commit()

    except Exception as e:
        logger.error(f"Failed to sync ICS for room {room_id}: {e}")

def should_sync(room) -> bool:
    """Check if room calendar should be synced."""
    if not room.ics_last_sync:
        return True

    time_since_sync = datetime.utcnow() - room.ics_last_sync
    return time_since_sync.total_seconds() >= room.ics_fetch_interval
```

### 4. Celery Beat Schedule

```python
# reflector/worker/celeryconfig.py

from celery.schedules import crontab

beat_schedule = {
    'sync-ics-calendars': {
        'task': 'reflector.worker.tasks.ics_sync.sync_ics_calendars',
        'schedule': 60.0,  # Check every minute which calendars need syncing
    },
    'pre-create-meetings': {
        'task': 'reflector.worker.tasks.ics_sync.pre_create_calendar_meetings',
        'schedule': 60.0,  # Check every minute for upcoming meetings
    },
}
```
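The beat schedule above references a `pre_create_calendar_meetings` task that is not spelled out elsewhere in this guide. A minimal sketch of what it could look like, assuming the same Flask-SQLAlchemy-style models used in the other examples here and a hypothetical `create_meeting_for_event` helper:

```python
# Sketch only; CalendarEvent/Meeting query style and create_meeting_for_event are assumptions.
from datetime import datetime, timedelta

from celery import shared_task

@shared_task
def pre_create_calendar_meetings():
    """Pre-create meetings for calendar events starting within the next minute."""
    now = datetime.utcnow()
    window_end = now + timedelta(minutes=1)

    upcoming = CalendarEvent.query.filter(
        CalendarEvent.start_time >= now,
        CalendarEvent.start_time <= window_end,
        CalendarEvent.is_deleted.is_(False),
    ).all()

    for event in upcoming:
        # Skip events that already have a meeting attached
        existing = Meeting.query.filter_by(calendar_event_id=event.id).first()
        if existing:
            continue

        # Create the meeting record (and the Whereby room) ahead of time
        # so it is ready when the first participant joins.
        create_meeting_for_event(event)
```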
## API Endpoints

### Room ICS Configuration

```python
# PATCH /v1/rooms/{room_id}
{
    "ics_url": "https://calendar.google.com/calendar/ical/.../private-token/basic.ics",
    "ics_fetch_interval": 300,  # seconds
    "ics_enabled": true
    # URL will be encrypted in database to protect embedded tokens
}
```

### Manual Sync Trigger

```python
# POST /v1/rooms/{room_name}/ics/sync
# Response:
{
    "status": "syncing",
    "last_sync": "2024-01-15T10:30:00Z",
    "events_found": 5
}
```

### ICS Status

```python
# GET /v1/rooms/{room_name}/ics/status
# Response:
{
    "enabled": true,
    "last_sync": "2024-01-15T10:30:00Z",
    "next_sync": "2024-01-15T10:35:00Z",
    "fetch_interval": 300,
    "events_count": 12,
    "upcoming_events": 3
}
```

## ICS Parsing Details

### Event Field Mapping

| ICS Field | Database Field | Notes |
|-----------|---------------|-------|
| UID | external_id | Unique identifier |
| SUMMARY | title | Event title |
| DESCRIPTION | description | Full description |
| DTSTART | start_time | Convert to UTC |
| DTEND | end_time | Convert to UTC |
| LOCATION | location | Check for room URL |
| ATTENDEE | attendees | Parse into JSON |
| ORGANIZER | attendees | Add as organizer |
| STATUS | (internal) | Filter cancelled events |
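The `_parse_event` helper used by the fetching service above is not shown in this guide. A minimal sketch of how it could apply this field mapping, assuming an icalendar VEVENT component as input and the `normalize_datetime` helper defined later in this document:

```python
# Sketch only; attendee handling is deliberately simplified.
def _parse_event(self, event) -> dict:
    """Map a VEVENT component onto the calendar_event columns."""
    raw_attendees = event.get('ATTENDEE')
    if raw_attendees is None:
        raw_attendees = []
    elif not isinstance(raw_attendees, list):  # single attendee is not wrapped in a list
        raw_attendees = [raw_attendees]

    attendees = [str(a) for a in raw_attendees]  # e.g. "mailto:alice@example.com"
    organizer = event.get('ORGANIZER')
    if organizer is not None:
        attendees.append(str(organizer))

    return {
        'external_id': str(event.get('UID')),
        'title': str(event.get('SUMMARY', '')),
        'description': str(event.get('DESCRIPTION', '')),
        'start_time': normalize_datetime(event.get('DTSTART')),
        'end_time': normalize_datetime(event.get('DTEND')),
        'location': str(event.get('LOCATION', '')),
        'attendees': attendees,
        'ics_raw_data': event.to_ical().decode(),
    }
```

Cancelled events (STATUS) would be filtered by the caller before the dict is stored.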
### Handling Recurring Events

```python
def expand_recurring_events(event, start_date, end_date):
    """Expand recurring events into individual occurrences."""
    from dateutil.rrule import rrulestr

    if 'RRULE' not in event:
        return [event]

    # Parse recurrence rule
    rrule_str = event['RRULE'].to_ical().decode()
    dtstart = event['DTSTART'].dt

    # Generate occurrences
    rrule = rrulestr(rrule_str, dtstart=dtstart)
    occurrences = []

    for dt in rrule.between(start_date, end_date):
        # Clone event with new date
        occurrence = event.copy()
        occurrence['DTSTART'].dt = dt
        if 'DTEND' in event:
            duration = event['DTEND'].dt - event['DTSTART'].dt
            occurrence['DTEND'].dt = dt + duration

        # Unique ID for each occurrence
        occurrence['UID'] = f"{event['UID']}_{dt.isoformat()}"
        occurrences.append(occurrence)

    return occurrences
```

### Timezone Handling

```python
def normalize_datetime(dt):
    """Convert various datetime formats to UTC."""
    import pytz
    from datetime import datetime

    if hasattr(dt, 'dt'):  # icalendar property
        dt = dt.dt

    if isinstance(dt, datetime):
        if dt.tzinfo is None:
            # Assume local timezone if naive
            dt = pytz.timezone('UTC').localize(dt)
        else:
            # Convert to UTC
            dt = dt.astimezone(pytz.UTC)

    return dt
```

## Security Considerations

### 1. URL Validation

```python
def validate_ics_url(url: str) -> bool:
    """Validate ICS URL for security."""
    from urllib.parse import urlparse

    parsed = urlparse(url)

    # Must be HTTPS in production
    if not settings.DEBUG and parsed.scheme != 'https':
        return False

    # Prevent local file access
    if parsed.scheme in ('file', 'ftp'):
        return False

    # Prevent internal network access
    if is_internal_ip(parsed.hostname):
        return False

    return True
```

### 2. Rate Limiting

```python
# Implement per-room rate limiting
RATE_LIMITS = {
    'min_fetch_interval': 60,  # Minimum 1 minute between fetches
    'max_requests_per_hour': 60,  # Max 60 requests per hour per room
    'max_file_size': 10 * 1024 * 1024,  # Max 10MB ICS file
}
```

### 3. ICS URL Encryption

```python
from cryptography.fernet import Fernet

class URLEncryption:
    def __init__(self):
        self.cipher = Fernet(settings.ENCRYPTION_KEY)

    def encrypt_url(self, url: str) -> str:
        """Encrypt ICS URL to protect embedded tokens."""
        return self.cipher.encrypt(url.encode()).decode()

    def decrypt_url(self, encrypted: str) -> str:
        """Decrypt ICS URL for fetching."""
        return self.cipher.decrypt(encrypted.encode()).decode()

    def mask_url(self, url: str) -> str:
        """Mask sensitive parts of URL for display."""
        from urllib.parse import urlparse, urlunparse

        parsed = urlparse(url)
        # Keep scheme, host, and path structure but mask tokens
        if '/private-' in parsed.path:
            # Google Calendar format
            parts = parsed.path.split('/private-')
            masked_path = parts[0] + '/private-***' + parts[1].split('/')[-1]
        elif 'token=' in url:
            # Query parameter token
            masked_path = parsed.path
            parsed = parsed._replace(query='token=***')
        else:
            # Generic masking of path segments that look like tokens
            import re
            masked_path = re.sub(r'/[a-zA-Z0-9]{20,}/', '/***/', parsed.path)

        return urlunparse(parsed._replace(path=masked_path))
```

## Testing Strategy

### 1. Unit Tests

```python
# tests/test_ics_sync.py

def test_ics_parsing():
    """Test ICS file parsing."""
    ics_content = """BEGIN:VCALENDAR
VERSION:2.0
BEGIN:VEVENT
UID:test-123
SUMMARY:Team Meeting
LOCATION:https://reflector.monadical.com/engineering
DTSTART:20240115T100000Z
DTEND:20240115T110000Z
END:VEVENT
END:VCALENDAR"""

    service = ICSFetchService()
    calendar = service.parse_ics(ics_content)
    events = service.extract_room_events(
        calendar,
        "https://reflector.monadical.com/engineering"
    )

    assert len(events) == 1
    assert events[0]['title'] == 'Team Meeting'
```

### 2. Integration Tests

```python
def test_full_sync_flow():
    """Test complete sync workflow."""
    # Create room with ICS URL (encrypt URL to protect tokens)
    encryption = URLEncryption()
    room = Room(
        name="test-room",
        ics_url=encryption.encrypt_url("https://example.com/calendar.ics?token=secret"),
        ics_enabled=True
    )

    # Mock ICS fetch
    with patch('requests.get') as mock_get:
        mock_get.return_value.text = sample_ics_content

        # Run sync
        sync_room_calendar(room.id)

        # Verify events created
        events = CalendarEvent.query.filter_by(room_id=room.id).all()
        assert len(events) > 0
```

## Common ICS Provider Configurations

### Google Calendar
- URL Format: `https://calendar.google.com/calendar/ical/{calendar_id}/private-{token}/basic.ics`
- Authentication via token embedded in URL
- Updates every 3-8 hours by default

### Outlook/Office 365
- URL Format: `https://outlook.office365.com/owa/calendar/{id}/calendar.ics`
- May include token in URL path or query parameters
- Real-time updates

### Apple iCloud
- URL Format: `webcal://p{XX}-caldav.icloud.com/published/2/{token}`
- Convert webcal:// to https://
- Token embedded in URL path
- Public calendars only

### Nextcloud/ownCloud
- URL Format: `https://cloud.example.com/remote.php/dav/public-calendars/{token}`
- Token embedded in URL path
- Configurable update frequency

## Migration from CalDAV

If migrating from an existing CalDAV implementation:

1. **Database Migration**: Rename fields from `caldav_*` to `ics_*`
2. **URL Conversion**: Most CalDAV servers provide ICS export endpoints
3. **Authentication**: Convert from username/password to URL-embedded tokens
4. **Remove Dependencies**: Uninstall caldav library, add icalendar
5. **Update Background Tasks**: Replace CalDAV sync with ICS fetch

## Performance Optimizations

1. **Caching**: Use ETag/Last-Modified headers to avoid refetching unchanged calendars (see the sketch after this list)
2. **Incremental Sync**: Store last sync timestamp, only process new/modified events
3. **Batch Processing**: Process multiple room calendars in parallel
4. **Connection Pooling**: Reuse HTTP connections for multiple requests
5. **Compression**: Support gzip encoding for large ICS files
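A minimal sketch of the ETag/Last-Modified caching idea from item 1, assuming the `requests` session used by `ICSFetchService`; where the header values get stored is left open (the `ics_last_etag` column above only holds a content hash today):

```python
# Sketch of a conditional fetch; returns None when the calendar is unchanged.
def fetch_ics_if_changed(session, url, last_etag=None, last_modified=None):
    headers = {}
    if last_etag:
        headers['If-None-Match'] = last_etag
    if last_modified:
        headers['If-Modified-Since'] = last_modified

    response = session.get(url, headers=headers, timeout=30)
    if response.status_code == 304:  # Not Modified: skip downloading and parsing
        return None
    response.raise_for_status()

    return {
        'content': response.text,
        'etag': response.headers.get('ETag'),
        'last_modified': response.headers.get('Last-Modified'),
    }
```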
## Monitoring and Debugging

### Metrics to Track
- Sync success/failure rate per room
- Average sync duration
- ICS file sizes
- Number of events processed
- Failed event matches

### Debug Logging

```python
logger.debug(f"Fetching ICS from {room.ics_url}")
logger.debug(f"ICS content size: {len(ics_content)} bytes")
logger.debug(f"Found {len(events)} matching events")
logger.debug(f"Event UIDs: {[e['external_id'] for e in events]}")
```

### Common Issues
1. **SSL Certificate Errors**: Add certificate validation options
2. **Timeout Issues**: Increase timeout for large calendars
3. **Encoding Problems**: Handle various character encodings
4. **Timezone Mismatches**: Always convert to UTC
5. **Memory Issues**: Stream large ICS files instead of loading entirely
PLAN.md (new file, 337 lines)

# ICS Calendar Integration Plan

## Core Concept
ICS calendar URLs are attached to rooms (not users) to enable automatic meeting tracking and management through periodic fetching of calendar data.

## Database Schema Updates

### 1. Add ICS configuration to rooms
- Add `ics_url` field to room table (URL to .ics file, may include auth token)
- Add `ics_fetch_interval` field to room table (default: 5 minutes, configurable)
- Add `ics_enabled` boolean field to room table
- Add `ics_last_sync` timestamp field to room table

### 2. Create calendar_events table
- `id` - UUID primary key
- `room_id` - Foreign key to room
- `external_id` - ICS event UID
- `title` - Event title
- `description` - Event description
- `start_time` - Event start timestamp
- `end_time` - Event end timestamp
- `attendees` - JSON field with attendee list and status
- `location` - Meeting location (should contain room name)
- `last_synced` - Last sync timestamp
- `is_deleted` - Boolean flag for soft delete (preserve past events)
- `ics_raw_data` - TEXT field to store raw VEVENT data for reference

### 3. Update meeting table
- Add `calendar_event_id` - Foreign key to calendar_events
- Add `calendar_metadata` - JSON field for additional calendar data
- Remove unique constraint on room_id + active status (allow multiple active meetings per room)

## Backend Implementation

### 1. ICS Sync Service
- Create background task that runs based on room's `ics_fetch_interval` (default: 5 minutes)
- For each room with ICS enabled, fetch the .ics file via HTTP/HTTPS
- Parse ICS file using icalendar library
- Extract VEVENT components and filter events looking for room URL (e.g., "https://reflector.monadical.com/max")
- Store matching events in calendar_events table
- Mark events as "upcoming" if start_time is within next 30 minutes
- Pre-create Whereby meetings 1 minute before start (ensures no delay when users join)
- Soft-delete future events that were removed from calendar (set is_deleted=true)
- Never delete past events (preserve for historical record)
- Support authenticated ICS feeds via tokens embedded in URL

### 2. Meeting Management Updates
- Allow multiple active meetings per room
- Pre-create meeting record 1 minute before calendar event starts (ensures meeting is ready)
- Link meeting to calendar_event for metadata
- Keep meeting active for 15 minutes after last participant leaves (grace period)
- Don't auto-close if new participant joins within grace period

### 3. API Endpoints
- `GET /v1/rooms/{room_name}/meetings` - List all active and upcoming meetings for a room
  - Returns filtered data based on user role (owner vs participant)
- `GET /v1/rooms/{room_name}/meetings/upcoming` - List upcoming meetings (next 30 min)
  - Returns filtered data based on user role
- `POST /v1/rooms/{room_name}/meetings/{meeting_id}/join` - Join specific meeting
- `PATCH /v1/rooms/{room_id}` - Update room settings (including ICS configuration)
  - ICS fields only visible/editable by room owner
- `POST /v1/rooms/{room_name}/ics/sync` - Trigger manual ICS sync
  - Only accessible by room owner
- `GET /v1/rooms/{room_name}/ics/status` - Get ICS sync status and last fetch time
  - Only accessible by room owner

## Frontend Implementation

### 1. Room Settings Page
- Add ICS configuration section
- Field for ICS URL (e.g., Google Calendar private URL, Outlook ICS export)
- Field for fetch interval (dropdown: 1 min, 5 min, 10 min, 30 min, 1 hour)
- Test connection button (validates ICS file can be fetched and parsed)
- Manual sync button
- Show last sync time and next scheduled sync

### 2. Meeting Selection Page (New)
- Show when accessing `/room/{room_name}`
- **Host view** (room owner):
  - Full calendar event details
  - Meeting title and description
  - Complete attendee list with RSVP status
  - Number of current participants
  - Duration (how long it's been running)
- **Participant view** (non-owners):
  - Meeting title only
  - Date and time
  - Number of current participants
  - Duration (how long it's been running)
  - No attendee list or description (privacy)
- Display upcoming meetings (visible 30min before):
  - Show countdown to start
  - Can click to join early → redirected to waiting page
  - Waiting page shows countdown until meeting starts
  - Meeting pre-created by background task (ready when users arrive)
- Option to create unscheduled meeting (uses existing flow)

### 3. Meeting Room Updates
- Show calendar metadata in meeting info
- Display invited attendees vs actual participants
- Show meeting title from calendar event

## Meeting Lifecycle

### 1. Meeting Creation
- Automatic: Pre-created 1 minute before calendar event starts (ensures Whereby room is ready)
- Manual: User creates unscheduled meeting (existing `/rooms/{room_name}/meeting` endpoint)
- Background task handles pre-creation to avoid delays when users join

### 2. Meeting Join Rules
- Can join active meetings immediately
- Can see upcoming meetings 30 minutes before start
- Can click to join upcoming meetings early → sent to waiting page
- Waiting page automatically transitions to meeting at scheduled time
- Unscheduled meetings always joinable (current behavior)

### 3. Meeting Closure Rules
- All meetings: 15-minute grace period after last participant leaves (see the sketch after this list)
- If participant rejoins within grace period, keep meeting active
- Calendar meetings: Force close 30 minutes after scheduled end time
- Unscheduled meetings: Keep active for 8 hours (current behavior)
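A minimal sketch of how the periodic worker could evaluate these closure rules. The `last_participant_left_at` and `grace_period_minutes` fields are the ones listed under Phase 2 below; the `calendar_event` relationship and `created_at` field are assumptions, not the exact Reflector models:

```python
# Sketch only; field names are assumptions based on the rules above.
from datetime import datetime, timedelta

def should_close_meeting(meeting, now=None):
    """Return True when a meeting should be deactivated."""
    now = now or datetime.utcnow()

    # Calendar-backed meetings: force close 30 minutes after the scheduled end.
    if meeting.calendar_event is not None:
        if now >= meeting.calendar_event.end_time + timedelta(minutes=30):
            return True

    # Grace period: only close once everyone has been gone long enough.
    # Rejoining clears last_participant_left_at, which keeps the meeting open.
    if meeting.last_participant_left_at is not None:
        grace = timedelta(minutes=meeting.grace_period_minutes or 15)
        return now >= meeting.last_participant_left_at + grace

    # Unscheduled meetings with no calendar event: 8-hour cap.
    if meeting.calendar_event is None:
        return now >= meeting.created_at + timedelta(hours=8)

    return False
```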
## ICS Parsing Logic

### 1. Event Matching
- Parse ICS file using Python icalendar library
- Iterate through VEVENT components
- Check LOCATION field for full FQDN URL (e.g., "https://reflector.monadical.com/max")
- Check DESCRIPTION for room URL or mention
- Support multiple formats:
  - Full URL: "https://reflector.monadical.com/max"
  - With /room path: "https://reflector.monadical.com/room/max"
  - Partial paths: "room/max", "/max room"

### 2. Attendee Extraction
- Parse ATTENDEE properties from VEVENT
- Extract email (MAILTO), name (CN parameter), and RSVP status (PARTSTAT)
- Store as JSON in calendar_events.attendees (see the sketch after this list)
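A minimal sketch of that attendee extraction with the icalendar library; the output shape is an assumption, not necessarily the JSON format Reflector stores:

```python
# Sketch only; assumes an icalendar VEVENT component as input.
def extract_attendees(event) -> list[dict]:
    """Turn ATTENDEE properties into JSON-serializable dicts."""
    raw = event.get('ATTENDEE')
    if raw is None:
        return []
    if not isinstance(raw, list):  # a single attendee is not wrapped in a list
        raw = [raw]

    attendees = []
    for attendee in raw:
        email = str(attendee)
        if email.lower().startswith('mailto:'):
            email = email[len('mailto:'):]
        params = getattr(attendee, 'params', {})
        attendees.append({
            'email': email,
            'name': params.get('CN'),
            'status': params.get('PARTSTAT'),  # e.g. ACCEPTED, DECLINED, NEEDS-ACTION
        })
    return attendees
```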
### 3. Sync Strategy
- Fetch complete ICS file (contains all events)
- Filter events from (now - 1 hour) to (now + 24 hours) for processing
- Update existing events if LAST-MODIFIED or SEQUENCE changed
- Delete future events that no longer exist in ICS (start_time > now), as in the reconciliation sketch after this list
- Keep past events for historical record (never delete if start_time < now)
- Handle recurring events (RRULE) - expand to individual instances
- Track deleted calendar events to clean up future meetings
- Cache ICS file hash to detect changes and skip unnecessary processing
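The implementation guide calls `sync_events_to_database` without defining it. A minimal sketch of the reconciliation step described above; `upsert_event`, `get_events_for_room`, and `soft_delete_event` are hypothetical helpers standing in for the actual DB layer:

```python
# Sketch only; storage helpers are placeholders, not the Reflector API.
from datetime import datetime

def sync_events_to_database(room_id: str, events: list[dict]) -> None:
    """Upsert fetched events and soft-delete future events that disappeared."""
    now = datetime.utcnow()
    seen_uids = set()

    for event in events:
        seen_uids.add(event['external_id'])
        upsert_event(room_id, event)  # insert, or update when LAST-MODIFIED/SEQUENCE changed

    # Future events no longer in the feed were removed from the calendar.
    for existing in get_events_for_room(room_id):
        if existing.start_time > now and existing.external_id not in seen_uids:
            soft_delete_event(existing)  # sets is_deleted = True; past events are never touched
```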
## Security Considerations
|
||||||
|
|
||||||
|
### 1. ICS URL Security
|
||||||
|
- ICS URLs may contain authentication tokens (e.g., Google Calendar private URLs)
|
||||||
|
- Store full ICS URLs encrypted using Fernet to protect embedded tokens
|
||||||
|
- Validate ICS URLs (must be HTTPS for production)
|
||||||
|
- Never expose full ICS URLs in API responses (return masked version)
|
||||||
|
- Rate limit ICS fetching to prevent abuse
|
||||||
|
|
||||||
|
### 2. Room Access
|
||||||
|
- Only room owner can configure ICS URL
|
||||||
|
- ICS URL shown as masked version to room owner (hides embedded tokens)
|
||||||
|
- ICS settings not visible to other users
|
||||||
|
- Meeting list visible to all room participants
|
||||||
|
- ICS fetch logs only visible to room owner
|
||||||
|
|
||||||
|
### 3. Meeting Privacy
|
||||||
|
- Full calendar details visible only to room owner
|
||||||
|
- Participants see limited info: title, date/time only
|
||||||
|
- Attendee list and description hidden from non-owners
|
||||||
|
- Meeting titles visible in room listing to all
|
||||||
|
|
||||||
|
## Implementation Phases
|
||||||
|
|
||||||
|
### Phase 1: Database and ICS Setup (Week 1) ✅ COMPLETED (2025-08-18)
|
||||||
|
1. ✅ Created database migrations for ICS fields and calendar_events table
|
||||||
|
- Added ics_url, ics_fetch_interval, ics_enabled, ics_last_sync, ics_last_etag to room table
|
||||||
|
- Created calendar_event table with ics_uid (instead of external_id) and proper typing
|
||||||
|
- Added calendar_event_id and calendar_metadata (JSONB) to meeting table
|
||||||
|
- Removed server_default from datetime fields for consistency
|
||||||
|
2. ✅ Installed icalendar Python library for ICS parsing
|
||||||
|
- Added icalendar>=6.0.0 to dependencies
|
||||||
|
- No encryption needed - ICS URLs are read-only
|
||||||
|
3. ✅ Built ICS fetch and sync service
|
||||||
|
- Simple HTTP fetching without unnecessary validation
|
||||||
|
- Proper TypedDict typing for event data structures
|
||||||
|
- Supports any standard ICS format
|
||||||
|
- Event matching on full room URL only
|
||||||
|
4. ✅ API endpoints for ICS configuration
|
||||||
|
- Room model updated to support ICS fields via existing PATCH endpoint
|
||||||
|
- POST /v1/rooms/{room_name}/ics/sync - Trigger manual sync (owner only)
|
||||||
|
- GET /v1/rooms/{room_name}/ics/status - Get sync status (owner only)
|
||||||
|
- GET /v1/rooms/{room_name}/meetings - List meetings with privacy controls
|
||||||
|
- GET /v1/rooms/{room_name}/meetings/upcoming - List upcoming meetings
|
||||||
|
5. ✅ Celery background tasks for periodic sync
|
||||||
|
- sync_room_ics - Sync individual room calendar
|
||||||
|
- sync_all_ics_calendars - Check all rooms and queue sync based on fetch intervals
|
||||||
|
- pre_create_upcoming_meetings - Pre-create Whereby meetings 1 minute before start
|
||||||
|
- Tasks scheduled in beat schedule (every minute for checking, respects individual intervals)
|
||||||
|
6. ✅ Tests written and passing
|
||||||
|
- 6 tests for Room ICS fields
|
||||||
|
- 7 tests for CalendarEvent model
|
||||||
|
- 7 tests for ICS sync service
|
||||||
|
- 11 tests for API endpoints
|
||||||
|
- 6 tests for background tasks
|
||||||
|
- All 31 ICS-related tests passing
|
||||||
|
|
||||||
|
### Phase 2: Meeting Management (Week 2) ✅ COMPLETED (2025-08-19)
|
||||||
|
1. ✅ Updated meeting lifecycle logic with grace period support
|
||||||
|
- 15-minute grace period after last participant leaves
|
||||||
|
- Automatic reactivation when participants rejoin
|
||||||
|
- Force close calendar meetings 30 minutes after scheduled end
|
||||||
|
2. ✅ Support multiple active meetings per room
|
||||||
|
- Removed unique constraint on active meetings
|
||||||
|
- Added get_all_active_for_room() method
|
||||||
|
- Added get_active_by_calendar_event() method
|
||||||
|
3. ✅ Implemented grace period logic
|
||||||
|
- Added last_participant_left_at and grace_period_minutes fields
|
||||||
|
- Process meetings task handles grace period checking
|
||||||
|
- Whereby webhooks clear grace period on participant join
|
||||||
|
4. ✅ Link meetings to calendar events
|
||||||
|
- Pre-created meetings properly linked via calendar_event_id
|
||||||
|
- Calendar metadata stored with meeting
|
||||||
|
- API endpoints for listing and joining specific meetings
|
||||||
|
|
||||||
|
### Phase 3: Frontend Meeting Selection (Week 3)
|
||||||
|
1. Build meeting selection page
|
||||||
|
2. Show active and upcoming meetings
|
||||||
|
3. Implement waiting page for early joiners
|
||||||
|
4. Add automatic transition from waiting to meeting
|
||||||
|
5. Support unscheduled meeting creation
|
||||||
|
|
||||||
|
### Phase 4: Calendar Integration UI (Week 4)
|
||||||
|
1. Add ICS settings to room configuration
|
||||||
|
2. Display calendar metadata in meetings
|
||||||
|
3. Show attendee information
|
||||||
|
4. Add sync status indicators
|
||||||
|
5. Show fetch interval and next sync time
|
||||||
|
|
||||||
|
## Success Metrics
|
||||||
|
- Zero merged meetings from consecutive calendar events
|
||||||
|
- Successful ICS sync from major providers (Google Calendar, Outlook, Apple Calendar, Nextcloud)
|
||||||
|
- Meeting join accuracy: correct meeting 100% of the time
|
||||||
|
- Grace period prevents 90% of accidental meeting closures
|
||||||
|
- Configurable fetch intervals reduce unnecessary API calls
|
||||||
|
|
||||||
|
## Design Decisions
|
||||||
|
1. **ICS attached to room, not user** - Prevents duplicate meetings from multiple calendars
|
||||||
|
2. **Multiple active meetings per room** - Supported with meeting selection page
|
||||||
|
3. **Grace period for rejoining** - 15 minutes after last participant leaves
|
||||||
|
4. **Upcoming meeting visibility** - Show 30 minutes before, join only on time
|
||||||
|
5. **Calendar data storage** - Attached to meeting record for full context
|
||||||
|
6. **No "ad-hoc" meetings** - Use existing meeting creation flow (unscheduled meetings)
|
||||||
|
7. **ICS configuration via room PATCH** - Reuse existing room configuration endpoint
|
||||||
|
8. **Event deletion handling** - Soft-delete future events, preserve past meetings
|
||||||
|
9. **Configurable fetch interval** - Balance between freshness and server load
|
||||||
|
10. **ICS over CalDAV** - Simpler implementation, wider compatibility, no complex auth
|
||||||
|
|
||||||

## Phase 2 Implementation Files

### Database Migrations

- `/server/migrations/versions/6025e9b2bef2_remove_one_active_meeting_per_room_.py` - Remove unique constraint
- `/server/migrations/versions/d4a1c446458c_add_grace_period_fields_to_meeting.py` - Add grace period fields

### Updated Models

- `/server/reflector/db/meetings.py` - Added grace period fields and new query methods

### Updated Services

- `/server/reflector/worker/process.py` - Enhanced with grace period logic and multiple meeting support

### Updated API

- `/server/reflector/views/rooms.py` - Added endpoints for listing active meetings and joining specific meetings
- `/server/reflector/views/whereby.py` - Clear grace period on participant join

### Tests

- `/server/tests/test_multiple_active_meetings.py` - Comprehensive tests for Phase 2 features (5 tests)

## Phase 1 Implementation Files Created

### Database Models

- `/server/reflector/db/rooms.py` - Updated with ICS fields (url, fetch_interval, enabled, last_sync, etag)
- `/server/reflector/db/calendar_events.py` - New CalendarEvent model with ics_uid and proper typing
- `/server/reflector/db/meetings.py` - Updated with calendar_event_id and calendar_metadata (JSONB)

### Services

- `/server/reflector/services/ics_sync.py` - ICS fetching and parsing with TypedDict for proper typing

### API Endpoints

- `/server/reflector/views/rooms.py` - Added ICS management endpoints with privacy controls

### Background Tasks

- `/server/reflector/worker/ics_sync.py` - Celery tasks for automatic periodic sync
- `/server/reflector/worker/app.py` - Updated beat schedule for ICS tasks
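The beat schedule wiring can be sketched as follows; the task path and the one-minute tick are illustrative assumptions, not the actual entries in `/server/reflector/worker/app.py`:

```python
# Hedged sketch of a Celery beat entry driving the periodic ICS sync.
beat_schedule = {
    "sync-ics-calendars": {
        "task": "reflector.worker.ics_sync.sync_all_rooms",  # hypothetical task name
        "schedule": 60.0,  # seconds; each room is further gated by its own fetch interval
    },
}
```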

### Tests

- `/server/tests/test_room_ics.py` - Room model ICS fields tests (6 tests)
- `/server/tests/test_calendar_event.py` - CalendarEvent model tests (7 tests)
- `/server/tests/test_ics_sync.py` - ICS sync service tests (7 tests)
- `/server/tests/test_room_ics_api.py` - API endpoint tests (11 tests)
- `/server/tests/test_ics_background_tasks.py` - Background task tests (6 tests)

### Key Design Decisions

- No encryption needed - ICS URLs are read-only access
- Using ics_uid instead of external_id for clarity
- Proper TypedDict typing for event data structures
- Removed unnecessary URL validation and webcal handling
- calendar_metadata in meetings stores flexible calendar data (organizer, recurrence, etc.)
- Background tasks query all rooms directly to avoid filtering issues
- Sync intervals are respected per-room configuration

## Implementation Approach

### ICS Fetching vs CalDAV

- **ICS Benefits**:
  - Simpler implementation (HTTP GET vs the CalDAV protocol)
  - Wider compatibility (all calendar apps can export ICS)
  - No authentication complexity (simple URL with optional token)
  - Easier debugging (ICS is plain text)
  - Lower server requirements (no CalDAV library dependencies)
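In practice a sync pass is an HTTP GET followed by VEVENT parsing. The sketch below shows the idea with the `icalendar` package and `requests`; the real service in `/server/reflector/services/ics_sync.py` may use a different parser, and the ETag-based conditional request is an assumption built on the stored `etag` field.

```python
import requests
from icalendar import Calendar


def fetch_ics_events(ics_url: str, etag: str | None = None) -> tuple[list[dict], str | None]:
    """Fetch an ICS feed and return (events, new_etag). Illustrative sketch only."""
    headers = {"If-None-Match": etag} if etag else {}
    response = requests.get(ics_url, headers=headers, timeout=30)
    if response.status_code == 304:
        return [], etag  # nothing changed since the last sync
    response.raise_for_status()

    calendar = Calendar.from_ical(response.text)
    events = []
    for component in calendar.walk("VEVENT"):
        events.append(
            {
                "ics_uid": str(component.get("UID")),
                "title": str(component.get("SUMMARY", "")),
                "start": component.decoded("DTSTART"),
                "end": component.decoded("DTEND") if "DTEND" in component else None,
            }
        )
    return events, response.headers.get("ETag")
```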

### Supported Calendar Providers

1. **Google Calendar**: Private ICS URL from calendar settings
2. **Outlook/Office 365**: ICS export URL from calendar sharing
3. **Apple Calendar**: Published calendar ICS URL
4. **Nextcloud**: Public/private calendar ICS export
5. **Any CalDAV server**: Via ICS export endpoint

### ICS URL Examples

- Google: `https://calendar.google.com/calendar/ical/{calendar_id}/private-{token}/basic.ics`
- Outlook: `https://outlook.live.com/owa/calendar/{id}/calendar.ics`
- Custom: `https://example.com/calendars/room-schedule.ics`

### Fetch Interval Configuration

- 1 minute: For critical/high-activity rooms
- 5 minutes (default): Balance of freshness and efficiency
- 10 minutes: Standard meeting rooms
- 30 minutes: Low-activity rooms
- 1 hour: Rarely-used rooms or stable schedules
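Respecting the per-room interval then reduces to a time comparison inside the sync task. A minimal sketch, assuming the Phase 1 room fields (`ics_enabled`, `ics_url`, `ics_last_sync`, `ics_fetch_interval`) and treating the interval as minutes:

```python
from datetime import datetime, timedelta, timezone


def room_due_for_sync(room, now: datetime | None = None) -> bool:
    """True when the room's configured fetch interval has elapsed since its last sync."""
    if not room.ics_enabled or not room.ics_url:
        return False
    if room.ics_last_sync is None:
        return True  # never synced yet
    now = now or datetime.now(timezone.utc)
    interval = timedelta(minutes=room.ics_fetch_interval or 5)
    return now >= room.ics_last_sync + interval
```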
81
README.md
81
README.md
@@ -1,60 +1,43 @@
|
|||||||
<div align="center">
|
<div align="center">
|
||||||
<img width="100" alt="image" src="https://github.com/user-attachments/assets/66fb367b-2c89-4516-9912-f47ac59c6a7f"/>
|
|
||||||
|
|
||||||
# Reflector
|
# Reflector
|
||||||
|
|
||||||
Reflector is an AI-powered audio transcription and meeting analysis platform that provides real-time transcription, speaker diarization, translation and summarization for audio content and live meetings. It works 100% with local models (whisper/parakeet, pyannote, seamless-m4t, and your local llm like phi-4).
|
Reflector Audio Management and Analysis is a cutting-edge web application under development by Monadical. It utilizes AI to record meetings, providing a permanent record with transcripts, translations, and automated summaries.
|
||||||
|
|
||||||
[](https://github.com/monadical-sas/reflector/actions/workflows/test_server.yml)
|
[](https://github.com/monadical-sas/reflector/actions/workflows/pytests.yml)
|
||||||
[](https://opensource.org/licenses/MIT)
|
[](https://opensource.org/licenses/MIT)
|
||||||
</div>
|
</div>
|
||||||
</div>
|
|
||||||
|
## Screenshots
|
||||||
<table>
|
<table>
|
||||||
<tr>
|
<tr>
|
||||||
<td>
|
<td>
|
||||||
<a href="https://github.com/user-attachments/assets/21f5597c-2930-4899-a154-f7bd61a59e97">
|
<a href="https://github.com/user-attachments/assets/3a976930-56c1-47ef-8c76-55d3864309e3">
|
||||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/21f5597c-2930-4899-a154-f7bd61a59e97" />
|
<img width="700" alt="image" src="https://github.com/user-attachments/assets/3a976930-56c1-47ef-8c76-55d3864309e3" />
|
||||||
</a>
|
</a>
|
||||||
</td>
|
</td>
|
||||||
<td>
|
<td>
|
||||||
<a href="https://github.com/user-attachments/assets/f6b9399a-5e51-4bae-b807-59128d0a940c">
|
<a href="https://github.com/user-attachments/assets/bfe3bde3-08af-4426-a9a1-11ad5cd63b33">
|
||||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/f6b9399a-5e51-4bae-b807-59128d0a940c" />
|
<img width="700" alt="image" src="https://github.com/user-attachments/assets/bfe3bde3-08af-4426-a9a1-11ad5cd63b33" />
|
||||||
</a>
|
</a>
|
||||||
</td>
|
</td>
|
||||||
<td>
|
<td>
|
||||||
<a href="https://github.com/user-attachments/assets/a42ce460-c1fd-4489-a995-270516193897">
|
<a href="https://github.com/user-attachments/assets/7b60c9d0-efe4-474f-a27b-ea13bd0fabdc">
|
||||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/a42ce460-c1fd-4489-a995-270516193897" />
|
<img width="700" alt="image" src="https://github.com/user-attachments/assets/7b60c9d0-efe4-474f-a27b-ea13bd0fabdc" />
|
||||||
</a>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<a href="https://github.com/user-attachments/assets/21929f6d-c309-42fe-9c11-f1299e50fbd4">
|
|
||||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/21929f6d-c309-42fe-9c11-f1299e50fbd4" />
|
|
||||||
</a>
|
</a>
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
## What is Reflector?
|
|
||||||
|
|
||||||
Reflector is a web application that utilizes local models to process audio content, providing:
|
|
||||||
|
|
||||||
- **Real-time Transcription**: Convert speech to text using [Whisper](https://github.com/openai/whisper) (multi-language) or [Parakeet](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2) (English) models
|
|
||||||
- **Speaker Diarization**: Identify and label different speakers using [Pyannote](https://github.com/pyannote/pyannote-audio) 3.1
|
|
||||||
- **Live Translation**: Translate audio content in real-time to many languages with [Facebook Seamless-M4T](https://github.com/facebookresearch/seamless_communication)
|
|
||||||
- **Topic Detection & Summarization**: Extract key topics and generate concise summaries using LLMs
|
|
||||||
- **Meeting Recording**: Create permanent records of meetings with searchable transcripts
|
|
||||||
|
|
||||||
Currently we provide a [modal.com](https://modal.com/) GPU template for deployment.
|
|
||||||
|
|
||||||
## Background
|
## Background
|
||||||
|
|
||||||
The project architecture consists of three primary components:
|
The project architecture consists of three primary components:
|
||||||
|
|
||||||
- **Back-End**: Python server that offers an API and data persistence, found in `server/`.
|
|
||||||
- **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
|
- **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
|
||||||
- **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
|
- **Back-End**: Python server that offers an API and data persistence, found in `server/`.
|
||||||
|
- **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations. Most reliable option is Modal deployment
|
||||||
|
|
||||||
It also uses authentik for authentication if activated.
|
It also uses authentik for authentication if activated, and Vercel for deployment and configuration of the front-end.
|
||||||
|
|
||||||
## Contribution Guidelines
|
## Contribution Guidelines
|
||||||
|
|
||||||
@@ -89,8 +72,6 @@ Note: We currently do not have instructions for Windows users.
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
*Note: we're working toward a better installation process; these instructions are not accurate for now.*
|
|
||||||
|
|
||||||
### Frontend
|
### Frontend
|
||||||
|
|
||||||
Start with `cd www`.
|
Start with `cd www`.
|
||||||
@@ -99,10 +80,11 @@ Start with `cd www`.
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
pnpm install
|
pnpm install
|
||||||
cp .env.example .env
|
cp .env_template .env
|
||||||
|
cp config-template.ts config.ts
|
||||||
```
|
```
|
||||||
|
|
||||||
Then, fill in the environment variables in `.env` as needed. If you are unsure on how to proceed, ask in Zulip.
|
Then, fill in the environment variables in `.env` and the configuration in `config.ts` as needed. If you are unsure on how to proceed, ask in Zulip.
|
||||||
|
|
||||||
**Run in development mode**
|
**Run in development mode**
|
||||||
|
|
||||||
@@ -167,34 +149,3 @@ You can manually process an audio file by calling the process tool:
|
|||||||
```bash
|
```bash
|
||||||
uv run python -m reflector.tools.process path/to/audio.wav
|
uv run python -m reflector.tools.process path/to/audio.wav
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Feature Flags

Reflector uses environment variable-based feature flags to control application functionality. These flags allow you to enable or disable features without code changes.

### Available Feature Flags

| Feature Flag | Environment Variable |
|--------------|----------------------|
| `requireLogin` | `NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN` |
| `privacy` | `NEXT_PUBLIC_FEATURE_PRIVACY` |
| `browse` | `NEXT_PUBLIC_FEATURE_BROWSE` |
| `sendToZulip` | `NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP` |
| `rooms` | `NEXT_PUBLIC_FEATURE_ROOMS` |

### Setting Feature Flags

Feature flags are controlled via environment variables using the pattern `NEXT_PUBLIC_FEATURE_{FEATURE_NAME}`, where `{FEATURE_NAME}` is the SCREAMING_SNAKE_CASE version of the feature name.

**Examples:**

```bash
# Enable user authentication requirement
NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN=true

# Disable browse functionality
NEXT_PUBLIC_FEATURE_BROWSE=false

# Enable Zulip integration
NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP=true
```
|
|||||||
@@ -6,7 +6,6 @@ services:
|
|||||||
- 1250:1250
|
- 1250:1250
|
||||||
volumes:
|
volumes:
|
||||||
- ./server/:/app/
|
- ./server/:/app/
|
||||||
- /app/.venv
|
|
||||||
env_file:
|
env_file:
|
||||||
- ./server/.env
|
- ./server/.env
|
||||||
environment:
|
environment:
|
||||||
@@ -17,7 +16,6 @@ services:
|
|||||||
context: server
|
context: server
|
||||||
volumes:
|
volumes:
|
||||||
- ./server/:/app/
|
- ./server/:/app/
|
||||||
- /app/.venv
|
|
||||||
env_file:
|
env_file:
|
||||||
- ./server/.env
|
- ./server/.env
|
||||||
environment:
|
environment:
|
||||||
@@ -28,7 +26,6 @@ services:
|
|||||||
context: server
|
context: server
|
||||||
volumes:
|
volumes:
|
||||||
- ./server/:/app/
|
- ./server/:/app/
|
||||||
- /app/.venv
|
|
||||||
env_file:
|
env_file:
|
||||||
- ./server/.env
|
- ./server/.env
|
||||||
environment:
|
environment:
|
||||||
|
|||||||
33
gpu/modal_deployments/.gitignore
vendored
33
gpu/modal_deployments/.gitignore
vendored
@@ -1,33 +0,0 @@
|
|||||||
# OS / Editor
|
|
||||||
.DS_Store
|
|
||||||
.vscode/
|
|
||||||
.idea/
|
|
||||||
|
|
||||||
# Python
|
|
||||||
__pycache__/
|
|
||||||
*.py[cod]
|
|
||||||
*$py.class
|
|
||||||
|
|
||||||
# Logs
|
|
||||||
*.log
|
|
||||||
|
|
||||||
# Env and secrets
|
|
||||||
.env
|
|
||||||
.env.*
|
|
||||||
*.env
|
|
||||||
*.secret
|
|
||||||
|
|
||||||
# Build / dist
|
|
||||||
build/
|
|
||||||
dist/
|
|
||||||
.eggs/
|
|
||||||
*.egg-info/
|
|
||||||
|
|
||||||
# Coverage / test
|
|
||||||
.pytest_cache/
|
|
||||||
.coverage*
|
|
||||||
htmlcov/
|
|
||||||
|
|
||||||
# Modal local state (if any)
|
|
||||||
modal_mounts/
|
|
||||||
.modal_cache/
|
|
||||||
@@ -1,171 +0,0 @@
|
|||||||
# Reflector GPU implementation - Transcription and LLM

This repository holds an API for the GPU implementation of the Reflector API service,
and uses [Modal.com](https://modal.com)

- `reflector_diarizer.py` - Diarization API
- `reflector_transcriber.py` - Transcription API (Whisper)
- `reflector_transcriber_parakeet.py` - Transcription API (NVIDIA Parakeet)
- `reflector_translator.py` - Translation API

## Modal.com deployment

Create a Modal secret and name it `reflector-gpu`.
It should contain a `REFLECTOR_APIKEY` environment variable with a value.

The deployment is done using the [Modal.com](https://modal.com) service.

```
$ modal deploy reflector_transcriber.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run

$ modal deploy reflector_transcriber_parakeet.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-parakeet-web.modal.run

$ modal deploy reflector_llm.py
...
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
```

Then, in your Reflector API configuration `.env`, you can set these keys:

```
TRANSCRIPT_BACKEND=modal
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY

DIARIZATION_BACKEND=modal
DIARIZATION_URL=https://xxxx--reflector-diarizer-web.modal.run
DIARIZATION_MODAL_API_KEY=REFLECTOR_APIKEY

TRANSLATION_BACKEND=modal
TRANSLATION_URL=https://xxxx--reflector-translator-web.modal.run
TRANSLATION_MODAL_API_KEY=REFLECTOR_APIKEY
```

## API

Authentication must be passed with the `Authorization` header, using the `bearer` scheme.

```
Authorization: bearer <REFLECTOR_APIKEY>
```

### LLM

`POST /llm`

**request**
```
{
  "prompt": "xxx"
}
```

**response**
```
{
  "text": "xxx completed"
}
```

### Transcription

#### Parakeet Transcriber (`reflector_transcriber_parakeet.py`)

NVIDIA Parakeet is a state-of-the-art ASR model optimized for real-time transcription with superior word-level timestamps.

**GPU Configuration:**

- **A10G GPU** - Used for the `/v1/audio/transcriptions` endpoint (small files, live transcription)
  - Higher concurrency (max_inputs=10)
  - Optimized for multiple small audio files
  - Supports batch processing for efficiency

- **L40S GPU** - Used for the `/v1/audio/transcriptions-from-url` endpoint (large files)
  - Lower concurrency but more powerful processing
  - Optimized for single large audio files
  - VAD-based chunking for long-form audio

##### `/v1/audio/transcriptions` - Small file transcription

**request** (multipart/form-data)

- `file` or `files[]` - audio file(s) to transcribe
- `model` - model name (default: `nvidia/parakeet-tdt-0.6b-v2`)
- `language` - language code (default: `en`)
- `batch` - whether to use batch processing for multiple files (default: `true`)

**response**
```json
{
  "text": "transcribed text",
  "words": [
    {"word": "hello", "start": 0.0, "end": 0.5},
    {"word": "world", "start": 0.5, "end": 1.0}
  ],
  "filename": "audio.mp3"
}
```

For multiple files with `batch=true`:
```json
{
  "results": [
    {
      "filename": "audio1.mp3",
      "text": "transcribed text",
      "words": [...]
    },
    {
      "filename": "audio2.mp3",
      "text": "transcribed text",
      "words": [...]
    }
  ]
}
```

##### `/v1/audio/transcriptions-from-url` - Large file transcription

**request** (application/json)
```json
{
  "audio_file_url": "https://example.com/audio.mp3",
  "model": "nvidia/parakeet-tdt-0.6b-v2",
  "language": "en",
  "timestamp_offset": 0.0
}
```

**response**
```json
{
  "text": "transcribed text from large file",
  "words": [
    {"word": "hello", "start": 0.0, "end": 0.5},
    {"word": "world", "start": 0.5, "end": 1.0}
  ]
}
```

**Supported file types:** mp3, mp4, mpeg, mpga, m4a, wav, webm
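For reference, calling the large-file endpoint from Python looks roughly like the snippet below (a sketch only: the URL is the placeholder printed by `modal deploy` above, and the key is whatever value you stored in the `reflector-gpu` secret):

```python
import requests

PARAKEET_URL = "https://xxxx--reflector-transcriber-parakeet-web.modal.run"
API_KEY = "<REFLECTOR_APIKEY>"

response = requests.post(
    f"{PARAKEET_URL}/v1/audio/transcriptions-from-url",
    json={
        "audio_file_url": "https://example.com/audio.mp3",
        "model": "nvidia/parakeet-tdt-0.6b-v2",
        "language": "en",
        "timestamp_offset": 0.0,
    },
    headers={"Authorization": f"bearer {API_KEY}"},
    timeout=900,
)
response.raise_for_status()
print(response.json()["text"])
```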

#### Whisper Transcriber (`reflector_transcriber.py`)

`POST /transcribe`

**request** (multipart/form-data)

- `file` - audio file
- `language` - language code (e.g. `en`)

**response**
```
{
  "text": "xxx",
  "words": [
    {"text": "xxx", "start": 0.0, "end": 1.0}
  ]
}
```
@@ -1,253 +0,0 @@
|
|||||||
"""
|
|
||||||
Reflector GPU backend - diarizer
|
|
||||||
===================================
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from typing import Mapping, NewType
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import modal
|
|
||||||
|
|
||||||
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
|
|
||||||
MODEL_DIR = "/root/diarization_models"
|
|
||||||
UPLOADS_PATH = "/uploads"
|
|
||||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
|
||||||
|
|
||||||
DiarizerUniqFilename = NewType("DiarizerUniqFilename", str)
|
|
||||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
|
||||||
|
|
||||||
app = modal.App(name="reflector-diarizer")
|
|
||||||
|
|
||||||
# Volume for temporary file uploads
|
|
||||||
upload_volume = modal.Volume.from_name("diarizer-uploads", create_if_missing=True)
|
|
||||||
|
|
||||||
|
|
||||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
url_path = parsed_url.path
|
|
||||||
|
|
||||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
|
||||||
if url_path.lower().endswith(f".{ext}"):
|
|
||||||
return AudioFileExtension(ext)
|
|
||||||
|
|
||||||
content_type = headers.get("content-type", "").lower()
|
|
||||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
|
||||||
return AudioFileExtension("mp3")
|
|
||||||
if "audio/wav" in content_type:
|
|
||||||
return AudioFileExtension("wav")
|
|
||||||
if "audio/mp4" in content_type:
|
|
||||||
return AudioFileExtension("mp4")
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported audio format for URL: {url}. "
|
|
||||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def download_audio_to_volume(
|
|
||||||
audio_file_url: str,
|
|
||||||
) -> tuple[DiarizerUniqFilename, AudioFileExtension]:
|
|
||||||
import requests
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
print(f"Checking audio file at: {audio_file_url}")
|
|
||||||
response = requests.head(audio_file_url, allow_redirects=True)
|
|
||||||
if response.status_code == 404:
|
|
||||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
|
||||||
|
|
||||||
print(f"Downloading audio file from: {audio_file_url}")
|
|
||||||
response = requests.get(audio_file_url, allow_redirects=True)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
print(f"Download failed with status {response.status_code}: {response.text}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=response.status_code,
|
|
||||||
detail=f"Failed to download audio file: {response.status_code}",
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
|
||||||
unique_filename = DiarizerUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
|
||||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
|
||||||
|
|
||||||
print(f"Writing file to: {file_path} (size: {len(response.content)} bytes)")
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(response.content)
|
|
||||||
|
|
||||||
upload_volume.commit()
|
|
||||||
print(f"File saved as: {unique_filename}")
|
|
||||||
return unique_filename, audio_suffix
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_cache_llm():
|
|
||||||
"""
|
|
||||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
|
||||||
Migrating your old cache. This is a one-time only operation. You can
|
|
||||||
interrupt this and resume the migration later on by calling
|
|
||||||
`transformers.utils.move_cache()`.
|
|
||||||
"""
|
|
||||||
from transformers.utils.hub import move_cache
|
|
||||||
|
|
||||||
print("Moving LLM cache")
|
|
||||||
move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
|
|
||||||
print("LLM cache moved")
|
|
||||||
|
|
||||||
|
|
||||||
def download_pyannote_audio():
|
|
||||||
from pyannote.audio import Pipeline
|
|
||||||
|
|
||||||
Pipeline.from_pretrained(
|
|
||||||
PYANNOTE_MODEL_NAME,
|
|
||||||
cache_dir=MODEL_DIR,
|
|
||||||
use_auth_token=os.environ["HF_TOKEN"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
diarizer_image = (
|
|
||||||
modal.Image.debian_slim(python_version="3.10.8")
|
|
||||||
.pip_install(
|
|
||||||
"pyannote.audio==3.1.0",
|
|
||||||
"requests",
|
|
||||||
"onnx",
|
|
||||||
"torchaudio",
|
|
||||||
"onnxruntime-gpu",
|
|
||||||
"torch==2.0.0",
|
|
||||||
"transformers==4.34.0",
|
|
||||||
"sentencepiece",
|
|
||||||
"protobuf",
|
|
||||||
"numpy",
|
|
||||||
"huggingface_hub",
|
|
||||||
"hf-transfer",
|
|
||||||
)
|
|
||||||
.run_function(
|
|
||||||
download_pyannote_audio,
|
|
||||||
secrets=[modal.Secret.from_name("hf_token")],
|
|
||||||
)
|
|
||||||
.run_function(migrate_cache_llm)
|
|
||||||
.env(
|
|
||||||
{
|
|
||||||
"LD_LIBRARY_PATH": (
|
|
||||||
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
|
|
||||||
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.cls(
|
|
||||||
gpu="A100",
|
|
||||||
timeout=60 * 30,
|
|
||||||
image=diarizer_image,
|
|
||||||
volumes={UPLOADS_PATH: upload_volume},
|
|
||||||
enable_memory_snapshot=True,
|
|
||||||
experimental_options={"enable_gpu_snapshot": True},
|
|
||||||
secrets=[
|
|
||||||
modal.Secret.from_name("hf_token"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@modal.concurrent(max_inputs=1)
|
|
||||||
class Diarizer:
|
|
||||||
@modal.enter(snap=True)
|
|
||||||
def enter(self):
|
|
||||||
import torch
|
|
||||||
from pyannote.audio import Pipeline
|
|
||||||
|
|
||||||
self.use_gpu = torch.cuda.is_available()
|
|
||||||
self.device = "cuda" if self.use_gpu else "cpu"
|
|
||||||
print(f"Using device: {self.device}")
|
|
||||||
self.diarization_pipeline = Pipeline.from_pretrained(
|
|
||||||
PYANNOTE_MODEL_NAME,
|
|
||||||
cache_dir=MODEL_DIR,
|
|
||||||
use_auth_token=os.environ["HF_TOKEN"],
|
|
||||||
)
|
|
||||||
self.diarization_pipeline.to(torch.device(self.device))
|
|
||||||
|
|
||||||
@modal.method()
|
|
||||||
def diarize(self, filename: str, timestamp: float = 0.0):
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
upload_volume.reload()
|
|
||||||
|
|
||||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise FileNotFoundError(f"File not found: {file_path}")
|
|
||||||
|
|
||||||
print(f"Diarizing audio from: {file_path}")
|
|
||||||
waveform, sample_rate = torchaudio.load(file_path)
|
|
||||||
diarization = self.diarization_pipeline(
|
|
||||||
{"waveform": waveform, "sample_rate": sample_rate}
|
|
||||||
)
|
|
||||||
|
|
||||||
words = []
|
|
||||||
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
|
|
||||||
words.append(
|
|
||||||
{
|
|
||||||
"start": round(timestamp + diarization_segment.start, 3),
|
|
||||||
"end": round(timestamp + diarization_segment.end, 3),
|
|
||||||
"speaker": int(speaker[-2:]),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
print("Diarization complete")
|
|
||||||
return {"diarization": words}
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------
|
|
||||||
# Web API
|
|
||||||
# -------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@app.function(
|
|
||||||
timeout=60 * 10,
|
|
||||||
scaledown_window=60 * 3,
|
|
||||||
secrets=[
|
|
||||||
modal.Secret.from_name("reflector-gpu"),
|
|
||||||
],
|
|
||||||
volumes={UPLOADS_PATH: upload_volume},
|
|
||||||
image=diarizer_image,
|
|
||||||
)
|
|
||||||
@modal.concurrent(max_inputs=40)
|
|
||||||
@modal.asgi_app()
|
|
||||||
def web():
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
diarizerstub = Diarizer()
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
|
||||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid API key",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
class DiarizationResponse(BaseModel):
|
|
||||||
result: dict
|
|
||||||
|
|
||||||
@app.post("/diarize", dependencies=[Depends(apikey_auth)])
|
|
||||||
def diarize(audio_file_url: str, timestamp: float = 0.0) -> DiarizationResponse:
|
|
||||||
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
|
|
||||||
|
|
||||||
try:
|
|
||||||
func = diarizerstub.diarize.spawn(
|
|
||||||
filename=unique_filename, timestamp=timestamp
|
|
||||||
)
|
|
||||||
result = func.get()
|
|
||||||
return result
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
|
||||||
print(f"Deleting file: {file_path}")
|
|
||||||
os.remove(file_path)
|
|
||||||
upload_volume.commit()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error cleaning up {unique_filename}: {e}")
|
|
||||||
|
|
||||||
return app
|
|
||||||
@@ -1,608 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import uuid
|
|
||||||
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import modal
|
|
||||||
|
|
||||||
MODEL_NAME = "large-v2"
|
|
||||||
MODEL_COMPUTE_TYPE: str = "float16"
|
|
||||||
MODEL_NUM_WORKERS: int = 1
|
|
||||||
MINUTES = 60 # seconds
|
|
||||||
SAMPLERATE = 16000
|
|
||||||
UPLOADS_PATH = "/uploads"
|
|
||||||
CACHE_PATH = "/models"
|
|
||||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
|
||||||
VAD_CONFIG = {
|
|
||||||
"batch_max_duration": 30.0,
|
|
||||||
"silence_padding": 0.5,
|
|
||||||
"window_size": 512,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
WhisperUniqFilename = NewType("WhisperUniqFilename", str)
|
|
||||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
|
||||||
|
|
||||||
app = modal.App("reflector-transcriber")
|
|
||||||
|
|
||||||
model_cache = modal.Volume.from_name("models", create_if_missing=True)
|
|
||||||
upload_volume = modal.Volume.from_name("whisper-uploads", create_if_missing=True)
|
|
||||||
|
|
||||||
|
|
||||||
class TimeSegment(NamedTuple):
|
|
||||||
"""Represents a time segment with start and end times."""
|
|
||||||
|
|
||||||
start: float
|
|
||||||
end: float
|
|
||||||
|
|
||||||
|
|
||||||
class AudioSegment(NamedTuple):
|
|
||||||
"""Represents an audio segment with timing and audio data."""
|
|
||||||
|
|
||||||
start: float
|
|
||||||
end: float
|
|
||||||
audio: any
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptResult(NamedTuple):
|
|
||||||
"""Represents a transcription result with text and word timings."""
|
|
||||||
|
|
||||||
text: str
|
|
||||||
words: list["WordTiming"]
|
|
||||||
|
|
||||||
|
|
||||||
class WordTiming(TypedDict):
|
|
||||||
"""Represents a word with its timing information."""
|
|
||||||
|
|
||||||
word: str
|
|
||||||
start: float
|
|
||||||
end: float
|
|
||||||
|
|
||||||
|
|
||||||
def download_model():
|
|
||||||
from faster_whisper import download_model
|
|
||||||
|
|
||||||
model_cache.reload()
|
|
||||||
|
|
||||||
download_model(MODEL_NAME, cache_dir=CACHE_PATH)
|
|
||||||
|
|
||||||
model_cache.commit()
|
|
||||||
|
|
||||||
|
|
||||||
image = (
|
|
||||||
modal.Image.debian_slim(python_version="3.12")
|
|
||||||
.env(
|
|
||||||
{
|
|
||||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
||||||
"LD_LIBRARY_PATH": (
|
|
||||||
"/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/:"
|
|
||||||
"/opt/conda/lib/python3.12/site-packages/nvidia/cublas/lib/"
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
.apt_install("ffmpeg")
|
|
||||||
.pip_install(
|
|
||||||
"huggingface_hub==0.27.1",
|
|
||||||
"hf-transfer==0.1.9",
|
|
||||||
"torch==2.5.1",
|
|
||||||
"faster-whisper==1.1.1",
|
|
||||||
"fastapi==0.115.12",
|
|
||||||
"requests",
|
|
||||||
"librosa==0.10.1",
|
|
||||||
"numpy<2",
|
|
||||||
"silero-vad==5.1.0",
|
|
||||||
)
|
|
||||||
.run_function(download_model, volumes={CACHE_PATH: model_cache})
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
url_path = parsed_url.path
|
|
||||||
|
|
||||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
|
||||||
if url_path.lower().endswith(f".{ext}"):
|
|
||||||
return AudioFileExtension(ext)
|
|
||||||
|
|
||||||
content_type = headers.get("content-type", "").lower()
|
|
||||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
|
||||||
return AudioFileExtension("mp3")
|
|
||||||
if "audio/wav" in content_type:
|
|
||||||
return AudioFileExtension("wav")
|
|
||||||
if "audio/mp4" in content_type:
|
|
||||||
return AudioFileExtension("mp4")
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported audio format for URL: {url}. "
|
|
||||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def download_audio_to_volume(
|
|
||||||
audio_file_url: str,
|
|
||||||
) -> tuple[WhisperUniqFilename, AudioFileExtension]:
|
|
||||||
import requests
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
response = requests.head(audio_file_url, allow_redirects=True)
|
|
||||||
if response.status_code == 404:
|
|
||||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
|
||||||
|
|
||||||
response = requests.get(audio_file_url, allow_redirects=True)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
|
||||||
unique_filename = WhisperUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
|
||||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
|
||||||
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(response.content)
|
|
||||||
|
|
||||||
upload_volume.commit()
|
|
||||||
return unique_filename, audio_suffix
|
|
||||||
|
|
||||||
|
|
||||||
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
|
|
||||||
"""Add 0.5s of silence if audio is shorter than the silence_padding window.
|
|
||||||
|
|
||||||
Whisper does not require this strictly, but aligning behavior with Parakeet
|
|
||||||
avoids edge-case crashes on extremely short inputs and makes comparisons easier.
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
audio_duration = len(audio_array) / sample_rate
|
|
||||||
if audio_duration < VAD_CONFIG["silence_padding"]:
|
|
||||||
silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
|
|
||||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
|
||||||
return np.concatenate([audio_array, silence])
|
|
||||||
return audio_array
|
|
||||||
|
|
||||||
|
|
||||||
@app.cls(
|
|
||||||
gpu="A10G",
|
|
||||||
timeout=5 * MINUTES,
|
|
||||||
scaledown_window=5 * MINUTES,
|
|
||||||
image=image,
|
|
||||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
|
||||||
)
|
|
||||||
@modal.concurrent(max_inputs=10)
|
|
||||||
class TranscriberWhisperLive:
|
|
||||||
"""Live transcriber class for small audio segments (A10G).
|
|
||||||
|
|
||||||
Mirrors the Parakeet live class API but uses Faster-Whisper under the hood.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@modal.enter()
|
|
||||||
def enter(self):
|
|
||||||
import faster_whisper
|
|
||||||
import torch
|
|
||||||
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
self.use_gpu = torch.cuda.is_available()
|
|
||||||
self.device = "cuda" if self.use_gpu else "cpu"
|
|
||||||
self.model = faster_whisper.WhisperModel(
|
|
||||||
MODEL_NAME,
|
|
||||||
device=self.device,
|
|
||||||
compute_type=MODEL_COMPUTE_TYPE,
|
|
||||||
num_workers=MODEL_NUM_WORKERS,
|
|
||||||
download_root=CACHE_PATH,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
|
||||||
print(f"Model is on device: {self.device}")
|
|
||||||
|
|
||||||
@modal.method()
|
|
||||||
def transcribe_segment(
|
|
||||||
self,
|
|
||||||
filename: str,
|
|
||||||
language: str = "en",
|
|
||||||
):
|
|
||||||
"""Transcribe a single uploaded audio file by filename."""
|
|
||||||
upload_volume.reload()
|
|
||||||
|
|
||||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise FileNotFoundError(f"File not found: {file_path}")
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
with NoStdStreams():
|
|
||||||
segments, _ = self.model.transcribe(
|
|
||||||
file_path,
|
|
||||||
language=language,
|
|
||||||
beam_size=5,
|
|
||||||
word_timestamps=True,
|
|
||||||
vad_filter=True,
|
|
||||||
vad_parameters={"min_silence_duration_ms": 500},
|
|
||||||
)
|
|
||||||
|
|
||||||
segments = list(segments)
|
|
||||||
text = "".join(segment.text for segment in segments).strip()
|
|
||||||
words = [
|
|
||||||
{
|
|
||||||
"word": word.word,
|
|
||||||
"start": round(float(word.start), 2),
|
|
||||||
"end": round(float(word.end), 2),
|
|
||||||
}
|
|
||||||
for segment in segments
|
|
||||||
for word in segment.words
|
|
||||||
]
|
|
||||||
|
|
||||||
return {"text": text, "words": words}
|
|
||||||
|
|
||||||
@modal.method()
|
|
||||||
def transcribe_batch(
|
|
||||||
self,
|
|
||||||
filenames: list[str],
|
|
||||||
language: str = "en",
|
|
||||||
):
|
|
||||||
"""Transcribe multiple uploaded audio files and return per-file results."""
|
|
||||||
upload_volume.reload()
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for filename in filenames:
|
|
||||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise FileNotFoundError(f"Batch file not found: {file_path}")
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
with NoStdStreams():
|
|
||||||
segments, _ = self.model.transcribe(
|
|
||||||
file_path,
|
|
||||||
language=language,
|
|
||||||
beam_size=5,
|
|
||||||
word_timestamps=True,
|
|
||||||
vad_filter=True,
|
|
||||||
vad_parameters={"min_silence_duration_ms": 500},
|
|
||||||
)
|
|
||||||
|
|
||||||
segments = list(segments)
|
|
||||||
text = "".join(seg.text for seg in segments).strip()
|
|
||||||
words = [
|
|
||||||
{
|
|
||||||
"word": w.word,
|
|
||||||
"start": round(float(w.start), 2),
|
|
||||||
"end": round(float(w.end), 2),
|
|
||||||
}
|
|
||||||
for seg in segments
|
|
||||||
for w in seg.words
|
|
||||||
]
|
|
||||||
|
|
||||||
results.append(
|
|
||||||
{
|
|
||||||
"filename": filename,
|
|
||||||
"text": text,
|
|
||||||
"words": words,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
@app.cls(
|
|
||||||
gpu="L40S",
|
|
||||||
timeout=15 * MINUTES,
|
|
||||||
image=image,
|
|
||||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
|
||||||
)
|
|
||||||
class TranscriberWhisperFile:
|
|
||||||
"""File transcriber for larger/longer audio, using VAD-driven batching (L40S)."""
|
|
||||||
|
|
||||||
@modal.enter()
|
|
||||||
def enter(self):
|
|
||||||
import faster_whisper
|
|
||||||
import torch
|
|
||||||
from silero_vad import load_silero_vad
|
|
||||||
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
self.use_gpu = torch.cuda.is_available()
|
|
||||||
self.device = "cuda" if self.use_gpu else "cpu"
|
|
||||||
self.model = faster_whisper.WhisperModel(
|
|
||||||
MODEL_NAME,
|
|
||||||
device=self.device,
|
|
||||||
compute_type=MODEL_COMPUTE_TYPE,
|
|
||||||
num_workers=MODEL_NUM_WORKERS,
|
|
||||||
download_root=CACHE_PATH,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
|
||||||
self.vad_model = load_silero_vad(onnx=False)
|
|
||||||
|
|
||||||
@modal.method()
|
|
||||||
def transcribe_segment(
|
|
||||||
self, filename: str, timestamp_offset: float = 0.0, language: str = "en"
|
|
||||||
):
|
|
||||||
import librosa
|
|
||||||
import numpy as np
|
|
||||||
from silero_vad import VADIterator
|
|
||||||
|
|
||||||
def vad_segments(
|
|
||||||
audio_array,
|
|
||||||
sample_rate: int = SAMPLERATE,
|
|
||||||
window_size: int = VAD_CONFIG["window_size"],
|
|
||||||
) -> Generator[TimeSegment, None, None]:
|
|
||||||
"""Generate speech segments as TimeSegment using Silero VAD."""
|
|
||||||
iterator = VADIterator(self.vad_model, sampling_rate=sample_rate)
|
|
||||||
start = None
|
|
||||||
for i in range(0, len(audio_array), window_size):
|
|
||||||
chunk = audio_array[i : i + window_size]
|
|
||||||
if len(chunk) < window_size:
|
|
||||||
chunk = np.pad(
|
|
||||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
|
||||||
)
|
|
||||||
speech = iterator(chunk)
|
|
||||||
if not speech:
|
|
||||||
continue
|
|
||||||
if "start" in speech:
|
|
||||||
start = speech["start"]
|
|
||||||
continue
|
|
||||||
if "end" in speech and start is not None:
|
|
||||||
end = speech["end"]
|
|
||||||
yield TimeSegment(
|
|
||||||
start / float(SAMPLERATE), end / float(SAMPLERATE)
|
|
||||||
)
|
|
||||||
start = None
|
|
||||||
iterator.reset_states()
|
|
||||||
|
|
||||||
upload_volume.reload()
|
|
||||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise FileNotFoundError(f"File not found: {file_path}")
|
|
||||||
|
|
||||||
audio_array, _sr = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
|
||||||
|
|
||||||
# Batch segments up to ~30s windows by merging contiguous VAD segments
|
|
||||||
merged_batches: list[TimeSegment] = []
|
|
||||||
batch_start = None
|
|
||||||
batch_end = None
|
|
||||||
max_duration = VAD_CONFIG["batch_max_duration"]
|
|
||||||
for segment in vad_segments(audio_array):
|
|
||||||
seg_start, seg_end = segment.start, segment.end
|
|
||||||
if batch_start is None:
|
|
||||||
batch_start, batch_end = seg_start, seg_end
|
|
||||||
continue
|
|
||||||
if seg_end - batch_start <= max_duration:
|
|
||||||
batch_end = seg_end
|
|
||||||
else:
|
|
||||||
merged_batches.append(TimeSegment(batch_start, batch_end))
|
|
||||||
batch_start, batch_end = seg_start, seg_end
|
|
||||||
if batch_start is not None and batch_end is not None:
|
|
||||||
merged_batches.append(TimeSegment(batch_start, batch_end))
|
|
||||||
|
|
||||||
all_text = []
|
|
||||||
all_words = []
|
|
||||||
|
|
||||||
for segment in merged_batches:
|
|
||||||
start_time, end_time = segment.start, segment.end
|
|
||||||
s_idx = int(start_time * SAMPLERATE)
|
|
||||||
e_idx = int(end_time * SAMPLERATE)
|
|
||||||
segment = audio_array[s_idx:e_idx]
|
|
||||||
segment = pad_audio(segment, SAMPLERATE)
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
segments, _ = self.model.transcribe(
|
|
||||||
segment,
|
|
||||||
language=language,
|
|
||||||
beam_size=5,
|
|
||||||
word_timestamps=True,
|
|
||||||
vad_filter=True,
|
|
||||||
vad_parameters={"min_silence_duration_ms": 500},
|
|
||||||
)
|
|
||||||
|
|
||||||
segments = list(segments)
|
|
||||||
text = "".join(seg.text for seg in segments).strip()
|
|
||||||
words = [
|
|
||||||
{
|
|
||||||
"word": w.word,
|
|
||||||
"start": round(float(w.start) + start_time + timestamp_offset, 2),
|
|
||||||
"end": round(float(w.end) + start_time + timestamp_offset, 2),
|
|
||||||
}
|
|
||||||
for seg in segments
|
|
||||||
for w in seg.words
|
|
||||||
]
|
|
||||||
if text:
|
|
||||||
all_text.append(text)
|
|
||||||
all_words.extend(words)
|
|
||||||
|
|
||||||
return {"text": " ".join(all_text), "words": all_words}
|
|
||||||
|
|
||||||
|
|
||||||
def detect_audio_format(url: str, headers: dict) -> str:
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
url_path = urlparse(url).path
|
|
||||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
|
||||||
if url_path.lower().endswith(f".{ext}"):
|
|
||||||
return ext
|
|
||||||
|
|
||||||
content_type = headers.get("content-type", "").lower()
|
|
||||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
|
||||||
return "mp3"
|
|
||||||
if "audio/wav" in content_type:
|
|
||||||
return "wav"
|
|
||||||
if "audio/mp4" in content_type:
|
|
||||||
return "mp4"
|
|
||||||
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=(
|
|
||||||
f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def download_audio_to_volume(audio_file_url: str) -> tuple[str, str]:
|
|
||||||
import requests
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
response = requests.head(audio_file_url, allow_redirects=True)
|
|
||||||
if response.status_code == 404:
|
|
||||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
|
||||||
|
|
||||||
response = requests.get(audio_file_url, allow_redirects=True)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
|
||||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
|
||||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
|
||||||
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(response.content)
|
|
||||||
|
|
||||||
upload_volume.commit()
|
|
||||||
return unique_filename, audio_suffix
|
|
||||||
|
|
||||||
|
|
||||||
@app.function(
|
|
||||||
scaledown_window=60,
|
|
||||||
timeout=600,
|
|
||||||
secrets=[
|
|
||||||
modal.Secret.from_name("reflector-gpu"),
|
|
||||||
],
|
|
||||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
|
||||||
image=image,
|
|
||||||
)
|
|
||||||
@modal.concurrent(max_inputs=40)
|
|
||||||
@modal.asgi_app()
|
|
||||||
def web():
|
|
||||||
from fastapi import (
|
|
||||||
Body,
|
|
||||||
Depends,
|
|
||||||
FastAPI,
|
|
||||||
Form,
|
|
||||||
HTTPException,
|
|
||||||
UploadFile,
|
|
||||||
status,
|
|
||||||
)
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
|
|
||||||
transcriber_live = TranscriberWhisperLive()
|
|
||||||
transcriber_file = TranscriberWhisperFile()
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
|
||||||
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
|
|
||||||
return
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid API key",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
class TranscriptResponse(dict):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
|
||||||
def transcribe(
|
|
||||||
file: UploadFile = None,
|
|
||||||
files: list[UploadFile] | None = None,
|
|
||||||
model: str = Form(MODEL_NAME),
|
|
||||||
language: str = Form("en"),
|
|
||||||
batch: bool = Form(False),
|
|
||||||
):
|
|
||||||
if not file and not files:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
|
||||||
)
|
|
||||||
if batch and not files:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="Batch transcription requires 'files'"
|
|
||||||
)
|
|
||||||
|
|
||||||
upload_files = [file] if file else files
|
|
||||||
|
|
||||||
uploaded_filenames: list[str] = []
|
|
||||||
for upload_file in upload_files:
|
|
||||||
audio_suffix = upload_file.filename.split(".")[-1]
|
|
||||||
if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=(
|
|
||||||
f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
|
||||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
content = upload_file.file.read()
|
|
||||||
f.write(content)
|
|
||||||
uploaded_filenames.append(unique_filename)
|
|
||||||
|
|
||||||
upload_volume.commit()
|
|
||||||
|
|
||||||
try:
|
|
||||||
if batch and len(upload_files) > 1:
|
|
||||||
func = transcriber_live.transcribe_batch.spawn(
|
|
||||||
filenames=uploaded_filenames,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
results = func.get()
|
|
||||||
return {"results": results}
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for filename in uploaded_filenames:
|
|
||||||
func = transcriber_live.transcribe_segment.spawn(
|
|
||||||
filename=filename,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
result = func.get()
|
|
||||||
result["filename"] = filename
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return {"results": results} if len(results) > 1 else results[0]
|
|
||||||
finally:
|
|
||||||
for filename in uploaded_filenames:
|
|
||||||
try:
|
|
||||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
|
||||||
os.remove(file_path)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
upload_volume.commit()
|
|
||||||
|
|
||||||
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
|
|
||||||
def transcribe_from_url(
|
|
||||||
audio_file_url: str = Body(
|
|
||||||
..., description="URL of the audio file to transcribe"
|
|
||||||
),
|
|
||||||
model: str = Body(MODEL_NAME),
|
|
||||||
language: str = Body("en"),
|
|
||||||
timestamp_offset: float = Body(0.0),
|
|
||||||
):
|
|
||||||
unique_filename, _audio_suffix = download_audio_to_volume(audio_file_url)
|
|
||||||
try:
|
|
||||||
func = transcriber_file.transcribe_segment.spawn(
|
|
||||||
filename=unique_filename,
|
|
||||||
timestamp_offset=timestamp_offset,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
result = func.get()
|
|
||||||
return result
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
|
||||||
os.remove(file_path)
|
|
||||||
upload_volume.commit()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
class NoStdStreams:
|
|
||||||
def __init__(self):
|
|
||||||
self.devnull = open(os.devnull, "w")
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
|
||||||
self._stdout.flush()
|
|
||||||
self._stderr.flush()
|
|
||||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
|
||||||
self.devnull.close()
|
|
||||||
@@ -1,658 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import uuid
|
|
||||||
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import modal
|
|
||||||
|
|
||||||
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
|
|
||||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
|
||||||
SAMPLERATE = 16000
|
|
||||||
UPLOADS_PATH = "/uploads"
|
|
||||||
CACHE_PATH = "/cache"
|
|
||||||
VAD_CONFIG = {
|
|
||||||
"batch_max_duration": 30.0,
|
|
||||||
"silence_padding": 0.5,
|
|
||||||
"window_size": 512,
|
|
||||||
}
|
|
||||||
|
|
||||||
ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
|
|
||||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
|
||||||
|
|
||||||
|
|
||||||
class TimeSegment(NamedTuple):
|
|
||||||
"""Represents a time segment with start and end times."""
|
|
||||||
|
|
||||||
start: float
|
|
||||||
end: float
|
|
||||||
|
|
||||||
|
|
||||||
class AudioSegment(NamedTuple):
|
|
||||||
"""Represents an audio segment with timing and audio data."""
|
|
||||||
|
|
||||||
start: float
|
|
||||||
end: float
|
|
||||||
audio: any
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptResult(NamedTuple):
|
|
||||||
"""Represents a transcription result with text and word timings."""
|
|
||||||
|
|
||||||
text: str
|
|
||||||
words: list["WordTiming"]
|
|
||||||
|
|
||||||
|
|
||||||
class WordTiming(TypedDict):
|
|
||||||
"""Represents a word with its timing information."""
|
|
||||||
|
|
||||||
word: str
|
|
||||||
start: float
|
|
||||||
end: float
|
|
||||||
|
|
||||||
|
|
||||||
app = modal.App("reflector-transcriber-parakeet")
|
|
||||||
|
|
||||||
# Volume for caching model weights
|
|
||||||
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
|
|
||||||
# Volume for temporary file uploads
|
|
||||||
upload_volume = modal.Volume.from_name("parakeet-uploads", create_if_missing=True)
|
|
||||||
|
|
||||||
image = (
|
|
||||||
modal.Image.from_registry(
|
|
||||||
"nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", add_python="3.12"
|
|
||||||
)
|
|
||||||
.env(
|
|
||||||
{
|
|
||||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
|
||||||
"HF_HOME": "/cache",
|
|
||||||
"DEBIAN_FRONTEND": "noninteractive",
|
|
||||||
"CXX": "g++",
|
|
||||||
"CC": "g++",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
.apt_install("ffmpeg")
|
|
||||||
.pip_install(
|
|
||||||
"hf_transfer==0.1.9",
|
|
||||||
"huggingface_hub[hf-xet]==0.31.2",
|
|
||||||
"nemo_toolkit[asr]==2.3.0",
|
|
||||||
"cuda-python==12.8.0",
|
|
||||||
"fastapi==0.115.12",
|
|
||||||
"numpy<2",
|
|
||||||
"librosa==0.10.1",
|
|
||||||
"requests",
|
|
||||||
"silero-vad==5.1.0",
|
|
||||||
"torch",
|
|
||||||
)
|
|
||||||
.entrypoint([]) # silence chatty logs by container on start
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
url_path = parsed_url.path
|
|
||||||
|
|
||||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
|
||||||
if url_path.lower().endswith(f".{ext}"):
|
|
||||||
return AudioFileExtension(ext)
|
|
||||||
|
|
||||||
content_type = headers.get("content-type", "").lower()
|
|
||||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
|
||||||
return AudioFileExtension("mp3")
|
|
||||||
if "audio/wav" in content_type:
|
|
||||||
return AudioFileExtension("wav")
|
|
||||||
if "audio/mp4" in content_type:
|
|
||||||
return AudioFileExtension("mp4")
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported audio format for URL: {url}. "
|
|
||||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def download_audio_to_volume(
|
|
||||||
audio_file_url: str,
|
|
||||||
) -> tuple[ParakeetUniqFilename, AudioFileExtension]:
|
|
||||||
import requests
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
response = requests.head(audio_file_url, allow_redirects=True)
|
|
||||||
if response.status_code == 404:
|
|
||||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
|
||||||
|
|
||||||
response = requests.get(audio_file_url, allow_redirects=True)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
|
||||||
unique_filename = ParakeetUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
|
||||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
|
||||||
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(response.content)
|
|
||||||
|
|
||||||
upload_volume.commit()
|
|
||||||
return unique_filename, audio_suffix
|
|
||||||
|
|
||||||
|
|
||||||
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
|
|
||||||
"""Add 0.5 seconds of silence if audio is less than 500ms.
|
|
||||||
|
|
||||||
This is a workaround for a Parakeet bug where very short audio (<500ms) causes:
|
|
||||||
ValueError: `char_offsets`: [] and `processed_tokens`: [157, 834, 834, 841]
|
|
||||||
have to be of the same length
|
|
||||||
|
|
||||||
See: https://github.com/NVIDIA/NeMo/issues/8451
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
audio_duration = len(audio_array) / sample_rate
|
|
||||||
if audio_duration < 0.5:
|
|
||||||
silence_samples = int(sample_rate * 0.5)
|
|
||||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
|
||||||
return np.concatenate([audio_array, silence])
|
|
||||||
return audio_array
|
|
||||||
|
|
||||||
|
|
||||||
@app.cls(
|
|
||||||
gpu="A10G",
|
|
||||||
timeout=600,
|
|
||||||
scaledown_window=300,
|
|
||||||
image=image,
|
|
||||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
|
||||||
enable_memory_snapshot=True,
|
|
||||||
experimental_options={"enable_gpu_snapshot": True},
|
|
||||||
)
|
|
||||||
@modal.concurrent(max_inputs=10)
|
|
||||||
class TranscriberParakeetLive:
|
|
||||||
@modal.enter(snap=True)
|
|
||||||
def enter(self):
|
|
||||||
import nemo.collections.asr as nemo_asr
|
|
||||||
|
|
||||||
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
|
|
||||||
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
|
|
||||||
device = next(self.model.parameters()).device
|
|
||||||
print(f"Model is on device: {device}")
|
|
||||||
|
|
||||||
@modal.method()
|
|
||||||
def transcribe_segment(
|
|
||||||
self,
|
|
||||||
filename: str,
|
|
||||||
):
|
|
||||||
import librosa
|
|
||||||
|
|
||||||
upload_volume.reload()
|
|
||||||
|
|
||||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise FileNotFoundError(f"File not found: {file_path}")
|
|
||||||
|
|
||||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
|
||||||
padded_audio = pad_audio(audio_array, sample_rate)
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
with NoStdStreams():
|
|
||||||
(output,) = self.model.transcribe([padded_audio], timestamps=True)
|
|
||||||
|
|
||||||
text = output.text.strip()
|
|
||||||
words: list[WordTiming] = [
|
|
||||||
WordTiming(
|
|
||||||
# XXX the space added here is to match the output of whisper
|
|
||||||
# whisper adds a space to each word, while parakeet doesn't
|
|
||||||
word=word_info["word"] + " ",
|
|
||||||
start=round(word_info["start"], 2),
|
|
||||||
end=round(word_info["end"], 2),
|
|
||||||
)
|
|
||||||
for word_info in output.timestamp["word"]
|
|
||||||
]
|
|
||||||
|
|
||||||
return {"text": text, "words": words}
|
|
||||||
|
|
||||||
@modal.method()
|
|
||||||
def transcribe_batch(
|
|
||||||
self,
|
|
||||||
filenames: list[str],
|
|
||||||
):
|
|
||||||
import librosa
|
|
||||||
|
|
||||||
upload_volume.reload()
|
|
||||||
|
|
||||||
results = []
|
|
||||||
audio_arrays = []
|
|
||||||
|
|
||||||
# Load all audio files with padding
|
|
||||||
for filename in filenames:
|
|
||||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise FileNotFoundError(f"Batch file not found: {file_path}")
|
|
||||||
|
|
||||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
|
||||||
padded_audio = pad_audio(audio_array, sample_rate)
|
|
||||||
audio_arrays.append(padded_audio)
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
with NoStdStreams():
|
|
||||||
outputs = self.model.transcribe(audio_arrays, timestamps=True)
|
|
||||||
|
|
||||||
# Process results for each file
|
|
||||||
for i, (filename, output) in enumerate(zip(filenames, outputs)):
|
|
||||||
text = output.text.strip()
|
|
||||||
|
|
||||||
words: list[WordTiming] = [
|
|
||||||
WordTiming(
|
|
||||||
word=word_info["word"] + " ",
|
|
||||||
start=round(word_info["start"], 2),
|
|
||||||
end=round(word_info["end"], 2),
|
|
||||||
)
|
|
||||||
for word_info in output.timestamp["word"]
|
|
||||||
]
|
|
||||||
|
|
||||||
results.append(
|
|
||||||
{
|
|
||||||
"filename": filename,
|
|
||||||
"text": text,
|
|
||||||
"words": words,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
# L40S class for file transcription (bigger files)
|
|
||||||
@app.cls(
|
|
||||||
gpu="L40S",
|
|
||||||
timeout=900,
|
|
||||||
image=image,
|
|
||||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
|
||||||
enable_memory_snapshot=True,
|
|
||||||
experimental_options={"enable_gpu_snapshot": True},
|
|
||||||
)
|
|
||||||
class TranscriberParakeetFile:
|
|
||||||
@modal.enter(snap=True)
|
|
||||||
def enter(self):
|
|
||||||
import nemo.collections.asr as nemo_asr
|
|
||||||
import torch
|
|
||||||
from silero_vad import load_silero_vad
|
|
||||||
|
|
||||||
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
|
|
||||||
|
|
||||||
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
|
|
||||||
device = next(self.model.parameters()).device
|
|
||||||
print(f"Model is on device: {device}")
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
self.vad_model = load_silero_vad(onnx=False)
|
|
||||||
print("Silero VAD initialized")
|
|
||||||
|
|
||||||
@modal.method()
|
|
||||||
def transcribe_segment(
|
|
||||||
self,
|
|
||||||
filename: str,
|
|
||||||
timestamp_offset: float = 0.0,
|
|
||||||
):
|
|
||||||
import librosa
|
|
||||||
import numpy as np
|
|
||||||
from silero_vad import VADIterator
|
|
||||||
|
|
||||||
def load_and_convert_audio(file_path):
|
|
||||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
|
||||||
return audio_array
|
|
||||||
|
|
||||||
def vad_segment_generator(
|
|
||||||
audio_array,
|
|
||||||
) -> Generator[TimeSegment, None, None]:
|
|
||||||
"""Generate speech segments using VAD with start/end sample indices"""
|
|
||||||
vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
|
|
||||||
window_size = VAD_CONFIG["window_size"]
|
|
||||||
start = None
|
|
||||||
|
|
||||||
for i in range(0, len(audio_array), window_size):
|
|
||||||
chunk = audio_array[i : i + window_size]
|
|
||||||
if len(chunk) < window_size:
|
|
||||||
chunk = np.pad(
|
|
||||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
|
||||||
)
|
|
||||||
|
|
||||||
speech_dict = vad_iterator(chunk)
|
|
||||||
if not speech_dict:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "start" in speech_dict:
|
|
||||||
start = speech_dict["start"]
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "end" in speech_dict and start is not None:
|
|
||||||
end = speech_dict["end"]
|
|
||||||
start_time = start / float(SAMPLERATE)
|
|
||||||
end_time = end / float(SAMPLERATE)
|
|
||||||
|
|
||||||
yield TimeSegment(start_time, end_time)
|
|
||||||
start = None
|
|
||||||
|
|
||||||
vad_iterator.reset_states()
|
|
||||||
|
|
||||||
def batch_speech_segments(
|
|
||||||
segments: Generator[TimeSegment, None, None], max_duration: int
|
|
||||||
) -> Generator[TimeSegment, None, None]:
|
|
||||||
"""
|
|
||||||
Input segments:
|
|
||||||
[0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
|
|
||||||
|
|
||||||
↓ (max_duration=10)
|
|
||||||
|
|
||||||
Output batches:
|
|
||||||
[0-8] [10-19] [20-22]
|
|
||||||
|
|
||||||
Note: silences are kept for better transcription; a previous implementation
|
|
||||||
passed segments separately, but the output was less accurate.
|
|
||||||
"""
|
|
||||||
batch_start_time = None
|
|
||||||
batch_end_time = None
|
|
||||||
|
|
||||||
for segment in segments:
|
|
||||||
start_time, end_time = segment.start, segment.end
|
|
||||||
if batch_start_time is None or batch_end_time is None:
|
|
||||||
batch_start_time = start_time
|
|
||||||
batch_end_time = end_time
|
|
||||||
continue
|
|
||||||
|
|
||||||
total_duration = end_time - batch_start_time
|
|
||||||
|
|
||||||
if total_duration <= max_duration:
|
|
||||||
batch_end_time = end_time
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield TimeSegment(batch_start_time, batch_end_time)
|
|
||||||
batch_start_time = start_time
|
|
||||||
batch_end_time = end_time
|
|
||||||
|
|
||||||
if batch_start_time is None or batch_end_time is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
yield TimeSegment(batch_start_time, batch_end_time)
|
|
||||||
|
|
||||||
def batch_segment_to_audio_segment(
|
|
||||||
segments: Generator[TimeSegment, None, None],
|
|
||||||
audio_array,
|
|
||||||
) -> Generator[AudioSegment, None, None]:
|
|
||||||
"""Extract audio segments and apply padding for Parakeet compatibility.
|
|
||||||
|
|
||||||
Uses pad_audio to ensure segments are at least 0.5s long, preventing
|
|
||||||
Parakeet crashes. This padding may cause slight timing overlaps between
|
|
||||||
segments, which are corrected by enforce_word_timing_constraints.
|
|
||||||
"""
|
|
||||||
for segment in segments:
|
|
||||||
start_time, end_time = segment.start, segment.end
|
|
||||||
start_sample = int(start_time * SAMPLERATE)
|
|
||||||
end_sample = int(end_time * SAMPLERATE)
|
|
||||||
audio_segment = audio_array[start_sample:end_sample]
|
|
||||||
|
|
||||||
padded_segment = pad_audio(audio_segment, SAMPLERATE)
|
|
||||||
|
|
||||||
yield AudioSegment(start_time, end_time, padded_segment)
|
|
||||||
|
|
||||||
def transcribe_batch(model, audio_segments: list) -> list:
|
|
||||||
with NoStdStreams():
|
|
||||||
outputs = model.transcribe(audio_segments, timestamps=True)
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def enforce_word_timing_constraints(
|
|
||||||
words: list[WordTiming],
|
|
||||||
) -> list[WordTiming]:
|
|
||||||
"""Enforce that word end times don't exceed the start time of the next word.
|
|
||||||
|
|
||||||
Due to silence padding added in batch_segment_to_audio_segment for better
|
|
||||||
transcription accuracy, word timings from different segments may overlap.
|
|
||||||
This function ensures there are no overlaps by adjusting end times.
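Example (illustrative): word spans [(0.0, 1.2), (1.0, 2.0)] become [(0.0, 1.0), (1.0, 2.0)].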
|
|
||||||
"""
|
|
||||||
if len(words) <= 1:
|
|
||||||
return words
|
|
||||||
|
|
||||||
enforced_words = []
|
|
||||||
for i, word in enumerate(words):
|
|
||||||
enforced_word = word.copy()
|
|
||||||
|
|
||||||
if i < len(words) - 1:
|
|
||||||
next_start = words[i + 1]["start"]
|
|
||||||
if enforced_word["end"] > next_start:
|
|
||||||
enforced_word["end"] = next_start
|
|
||||||
|
|
||||||
enforced_words.append(enforced_word)
|
|
||||||
|
|
||||||
return enforced_words
|
|
||||||
|
|
||||||
def emit_results(
|
|
||||||
results: list,
|
|
||||||
segments_info: list[AudioSegment],
|
|
||||||
) -> Generator[TranscriptResult, None, None]:
|
|
||||||
"""Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
|
|
||||||
for i, (output, segment) in enumerate(zip(results, segments_info)):
|
|
||||||
start_time, end_time = segment.start, segment.end
|
|
||||||
text = output.text.strip()
|
|
||||||
words: list[WordTiming] = [
|
|
||||||
WordTiming(
|
|
||||||
word=word_info["word"] + " ",
|
|
||||||
start=round(
|
|
||||||
word_info["start"] + start_time + timestamp_offset, 2
|
|
||||||
),
|
|
||||||
end=round(word_info["end"] + start_time + timestamp_offset, 2),
|
|
||||||
)
|
|
||||||
for word_info in output.timestamp["word"]
|
|
||||||
]
|
|
||||||
|
|
||||||
yield TranscriptResult(text, words)
|
|
||||||
|
|
||||||
upload_volume.reload()
|
|
||||||
|
|
||||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise FileNotFoundError(f"File not found: {file_path}")
|
|
||||||
|
|
||||||
audio_array = load_and_convert_audio(file_path)
|
|
||||||
total_duration = len(audio_array) / float(SAMPLERATE)
|
|
||||||
|
|
||||||
all_text_parts: list[str] = []
|
|
||||||
all_words: list[WordTiming] = []
|
|
||||||
|
|
||||||
raw_segments = vad_segment_generator(audio_array)
|
|
||||||
speech_segments = batch_speech_segments(
|
|
||||||
raw_segments,
|
|
||||||
VAD_CONFIG["batch_max_duration"],
|
|
||||||
)
|
|
||||||
audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array)
|
|
||||||
|
|
||||||
for batch in audio_segments:
|
|
||||||
audio_segment = batch.audio
|
|
||||||
results = transcribe_batch(self.model, [audio_segment])
|
|
||||||
|
|
||||||
for result in emit_results(
|
|
||||||
results,
|
|
||||||
[batch],
|
|
||||||
):
|
|
||||||
if not result.text:
|
|
||||||
continue
|
|
||||||
all_text_parts.append(result.text)
|
|
||||||
all_words.extend(result.words)
|
|
||||||
|
|
||||||
all_words = enforce_word_timing_constraints(all_words)
|
|
||||||
|
|
||||||
combined_text = " ".join(all_text_parts)
|
|
||||||
return {"text": combined_text, "words": all_words}
|
|
||||||
|
|
||||||
|
|
||||||
@app.function(
|
|
||||||
scaledown_window=60,
|
|
||||||
timeout=600,
|
|
||||||
secrets=[
|
|
||||||
modal.Secret.from_name("reflector-gpu"),
|
|
||||||
],
|
|
||||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
|
||||||
image=image,
|
|
||||||
)
|
|
||||||
@modal.concurrent(max_inputs=40)
|
|
||||||
@modal.asgi_app()
|
|
||||||
def web():
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import (
|
|
||||||
Body,
|
|
||||||
Depends,
|
|
||||||
FastAPI,
|
|
||||||
Form,
|
|
||||||
HTTPException,
|
|
||||||
UploadFile,
|
|
||||||
status,
|
|
||||||
)
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
transcriber_live = TranscriberParakeetLive()
|
|
||||||
transcriber_file = TranscriberParakeetFile()
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
|
||||||
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
|
|
||||||
return
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid API key",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
class TranscriptResponse(BaseModel):
|
|
||||||
result: dict
|
|
||||||
|
|
||||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
|
||||||
def transcribe(
|
|
||||||
file: UploadFile = None,
|
|
||||||
files: list[UploadFile] | None = None,
|
|
||||||
model: str = Form(MODEL_NAME),
|
|
||||||
language: str = Form("en"),
|
|
||||||
batch: bool = Form(False),
|
|
||||||
):
|
|
||||||
# Parakeet only supports English
|
|
||||||
if language != "en":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Parakeet model only supports English. Got language='{language}'",
|
|
||||||
)
|
|
||||||
# Handle both single file and multiple files
|
|
||||||
if not file and not files:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
|
||||||
)
|
|
||||||
if batch and not files:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="Batch transcription requires 'files'"
|
|
||||||
)
|
|
||||||
|
|
||||||
upload_files = [file] if file else files
|
|
||||||
|
|
||||||
# Upload files to volume
|
|
||||||
uploaded_filenames = []
|
|
||||||
for upload_file in upload_files:
|
|
||||||
audio_suffix = upload_file.filename.split(".")[-1]
|
|
||||||
assert audio_suffix in SUPPORTED_FILE_EXTENSIONS
|
|
||||||
|
|
||||||
# Generate unique filename
|
|
||||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
|
||||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
|
||||||
|
|
||||||
print(f"Writing file to: {file_path}")
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
content = upload_file.file.read()
|
|
||||||
f.write(content)
|
|
||||||
|
|
||||||
uploaded_filenames.append(unique_filename)
|
|
||||||
|
|
||||||
upload_volume.commit()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use the A10G live transcriber for uploaded files (batch or per-file)
|
|
||||||
if batch and len(upload_files) > 1:
|
|
||||||
# Use batch transcription
|
|
||||||
func = transcriber_live.transcribe_batch.spawn(
|
|
||||||
filenames=uploaded_filenames,
|
|
||||||
)
|
|
||||||
results = func.get()
|
|
||||||
return {"results": results}
|
|
||||||
|
|
||||||
# Per-file transcription
|
|
||||||
results = []
|
|
||||||
for filename in uploaded_filenames:
|
|
||||||
func = transcriber_live.transcribe_segment.spawn(
|
|
||||||
filename=filename,
|
|
||||||
)
|
|
||||||
result = func.get()
|
|
||||||
result["filename"] = filename
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return {"results": results} if len(results) > 1 else results[0]
|
|
||||||
|
|
||||||
finally:
|
|
||||||
for filename in uploaded_filenames:
|
|
||||||
try:
|
|
||||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
|
||||||
print(f"Deleting file: {file_path}")
|
|
||||||
os.remove(file_path)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error deleting {filename}: {e}")
|
|
||||||
|
|
||||||
upload_volume.commit()
|
|
||||||
|
|
||||||
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
|
|
||||||
def transcribe_from_url(
|
|
||||||
audio_file_url: str = Body(
|
|
||||||
..., description="URL of the audio file to transcribe"
|
|
||||||
),
|
|
||||||
model: str = Body(MODEL_NAME),
|
|
||||||
language: str = Body("en", description="Language code (only 'en' supported)"),
|
|
||||||
timestamp_offset: float = Body(0.0),
|
|
||||||
):
|
|
||||||
# Parakeet only supports English
|
|
||||||
if language != "en":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Parakeet model only supports English. Got language='{language}'",
|
|
||||||
)
|
|
||||||
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
|
|
||||||
|
|
||||||
try:
|
|
||||||
func = transcriber_file.transcribe_segment.spawn(
|
|
||||||
filename=unique_filename,
|
|
||||||
timestamp_offset=timestamp_offset,
|
|
||||||
)
|
|
||||||
result = func.get()
|
|
||||||
return result
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
|
||||||
print(f"Deleting file: {file_path}")
|
|
||||||
os.remove(file_path)
|
|
||||||
upload_volume.commit()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error cleaning up {unique_filename}: {e}")
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
class NoStdStreams:
|
|
||||||
def __init__(self):
|
|
||||||
self.devnull = open(os.devnull, "w")
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
|
||||||
self._stdout.flush()
|
|
||||||
self._stderr.flush()
|
|
||||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
|
||||||
self.devnull.close()
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
REFLECTOR_GPU_APIKEY=
|
|
||||||
HF_TOKEN=
|
|
||||||
38
gpu/self_hosted/.gitignore
vendored
@@ -1,38 +0,0 @@
|
|||||||
cache/
|
|
||||||
|
|
||||||
# OS / Editor
|
|
||||||
.DS_Store
|
|
||||||
.vscode/
|
|
||||||
.idea/
|
|
||||||
|
|
||||||
# Python
|
|
||||||
__pycache__/
|
|
||||||
*.py[cod]
|
|
||||||
*$py.class
|
|
||||||
|
|
||||||
# Env and secrets
|
|
||||||
.env
|
|
||||||
*.env
|
|
||||||
*.secret
|
|
||||||
HF_TOKEN
|
|
||||||
REFLECTOR_GPU_APIKEY
|
|
||||||
|
|
||||||
# Virtual env / uv
|
|
||||||
.venv/
|
|
||||||
venv/
|
|
||||||
ENV/
|
|
||||||
uv/
|
|
||||||
|
|
||||||
# Build / dist
|
|
||||||
build/
|
|
||||||
dist/
|
|
||||||
.eggs/
|
|
||||||
*.egg-info/
|
|
||||||
|
|
||||||
# Coverage / test
|
|
||||||
.pytest_cache/
|
|
||||||
.coverage*
|
|
||||||
htmlcov/
|
|
||||||
|
|
||||||
# Logs
|
|
||||||
*.log
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
FROM python:3.12-slim
|
|
||||||
|
|
||||||
ENV PYTHONUNBUFFERED=1 \
|
|
||||||
UV_LINK_MODE=copy \
|
|
||||||
UV_NO_CACHE=1
|
|
||||||
|
|
||||||
WORKDIR /tmp
|
|
||||||
RUN apt-get update \
|
|
||||||
&& apt-get install -y \
|
|
||||||
ffmpeg \
|
|
||||||
curl \
|
|
||||||
ca-certificates \
|
|
||||||
gnupg \
|
|
||||||
wget \
|
|
||||||
&& apt-get clean
|
|
||||||
# Add NVIDIA CUDA repo for Debian 12 (bookworm) and install cuDNN 9 for CUDA 12
|
|
||||||
ADD https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb /cuda-keyring.deb
|
|
||||||
RUN dpkg -i /cuda-keyring.deb \
|
|
||||||
&& rm /cuda-keyring.deb \
|
|
||||||
&& apt-get update \
|
|
||||||
&& apt-get install -y --no-install-recommends \
|
|
||||||
cuda-cudart-12-6 \
|
|
||||||
libcublas-12-6 \
|
|
||||||
libcudnn9-cuda-12 \
|
|
||||||
libcudnn9-dev-cuda-12 \
|
|
||||||
&& apt-get clean \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
ADD https://astral.sh/uv/install.sh /uv-installer.sh
|
|
||||||
RUN sh /uv-installer.sh && rm /uv-installer.sh
|
|
||||||
ENV PATH="/root/.local/bin/:$PATH"
|
|
||||||
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH"
|
|
||||||
|
|
||||||
RUN mkdir -p /app
|
|
||||||
WORKDIR /app
|
|
||||||
COPY pyproject.toml uv.lock /app/
|
|
||||||
|
|
||||||
|
|
||||||
COPY ./app /app/app
|
|
||||||
COPY ./main.py /app/
|
|
||||||
COPY ./runserver.sh /app/
|
|
||||||
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
CMD ["sh", "/app/runserver.sh"]
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
# Self-hosted Model API
|
|
||||||
|
|
||||||
Run transcription, translation, and diarization services compatible with Reflector's GPU Model API. Works on CPU or GPU.
|
|
||||||
|
|
||||||
Environment variables
|
|
||||||
|
|
||||||
- REFLECTOR_GPU_APIKEY: Optional Bearer token. If unset, auth is disabled.
|
|
||||||
- HF_TOKEN: Optional. Required for diarization to download pyannote pipelines
|
|
||||||
|
|
||||||
Requirements
|
|
||||||
|
|
||||||
- FFmpeg must be installed and on PATH (used for URL-based and segmented transcription)
|
|
||||||
- Python 3.12+
|
|
||||||
- NVIDIA GPU optional. If available, it will be used automatically
|
|
||||||
|
|
||||||
Local run
|
|
||||||
Set environment variables in the self_hosted/.env file, then run:
|
|
||||||
uv sync
|
|
||||||
|
|
||||||
uv run uvicorn main:app --host 0.0.0.0 --port 8000
|
|
||||||
|
|
||||||
Authentication
|
|
||||||
|
|
||||||
- If REFLECTOR_GPU_APIKEY is set, include header: Authorization: Bearer <key>
|
|
||||||
|
|
||||||
Endpoints
|
|
||||||
|
|
||||||
- POST /v1/audio/transcriptions
|
|
||||||
|
|
||||||
- multipart/form-data
|
|
||||||
- fields: file (single file) OR files[] (multiple files), language, batch (true/false)
|
|
||||||
- response: single { text, words, filename } or { results: [ ... ] }
|
|
||||||
|
|
||||||
- POST /v1/audio/transcriptions-from-url
|
|
||||||
|
|
||||||
- application/json
|
|
||||||
- body: { audio_file_url, language, timestamp_offset }
|
|
||||||
- response: { text, words }
|
|
||||||
|
|
||||||
- POST /translate
|
|
||||||
|
|
||||||
- text: query parameter
|
|
||||||
- body (application/json): { source_language, target_language }
|
|
||||||
- response: { text: { <src>: original, <tgt>: translated } }
|
|
||||||
|
|
||||||
- POST /diarize
|
|
||||||
- query parameters: audio_file_url, timestamp (optional)
|
|
||||||
- requires HF_TOKEN to be set (for pyannote)
|
|
||||||
- response: { diarization: [ { start, end, speaker } ] }
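
For reference, a minimal Python client sketch. The localhost URL, `dev-key` API key, and `sample.wav` path are assumptions for a local run; adjust them to your deployment.

```python
# Hypothetical local setup: server from "Local run" on http://localhost:8000
# with REFLECTOR_GPU_APIKEY=dev-key and a local sample.wav.
import requests

BASE_URL = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer dev-key"}

# Single-file transcription (multipart/form-data)
with open("sample.wav", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/v1/audio/transcriptions",
        headers=HEADERS,
        files={"file": ("sample.wav", f, "audio/wav")},
        data={"language": "en"},
    )
resp.raise_for_status()
print(resp.json())  # {"text": ..., "words": [...], "filename": ...}

# URL-based transcription (application/json)
resp = requests.post(
    f"{BASE_URL}/v1/audio/transcriptions-from-url",
    headers=HEADERS,
    json={"audio_file_url": "https://example.com/audio.mp3", "timestamp_offset": 0.0},
)
resp.raise_for_status()
print(resp.json())  # {"text": ..., "words": [...]}
```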
|
|
||||||
|
|
||||||
OpenAPI docs
|
|
||||||
|
|
||||||
- Visit /docs when the server is running
|
|
||||||
|
|
||||||
Docker
|
|
||||||
|
|
||||||
- Not yet provided in this directory. A Dockerfile will be added later. For now, use Local run above
|
|
||||||
|
|
||||||
Conformance tests
|
|
||||||
|
|
||||||
# From this directory
|
|
||||||
|
|
||||||
TRANSCRIPT_URL=http://localhost:8000 \
|
|
||||||
TRANSCRIPT_API_KEY=dev-key \
|
|
||||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_transcript.py
|
|
||||||
|
|
||||||
TRANSLATION_URL=http://localhost:8000 \
|
|
||||||
TRANSLATION_API_KEY=dev-key \
|
|
||||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_translation.py
|
|
||||||
|
|
||||||
DIARIZATION_URL=http://localhost:8000 \
|
|
||||||
DIARIZATION_API_KEY=dev-key \
|
|
||||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_diarization.py
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
|
|
||||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
|
||||||
required_key = os.environ.get("REFLECTOR_GPU_APIKEY")
|
|
||||||
if not required_key:
|
|
||||||
return
|
|
||||||
if apikey == required_key:
|
|
||||||
return
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid API key",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
|
||||||
SAMPLE_RATE = 16000
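# VAD settings used by the transcription service:
# - batch_max_duration: merge VAD speech segments into batches of at most 30 s
# - silence_padding: clips shorter than 0.5 s get 0.5 s of silence appended
# - window_size: samples per chunk fed to the silero VADIterator (512 at 16 kHz here)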
|
|
||||||
VAD_CONFIG = {
|
|
||||||
"batch_max_duration": 30.0,
|
|
||||||
"silence_padding": 0.5,
|
|
||||||
"window_size": 512,
|
|
||||||
}
|
|
||||||
|
|
||||||
# App-level paths
|
|
||||||
UPLOADS_PATH = Path("/tmp/whisper-uploads")
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
from contextlib import asynccontextmanager
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
from .routers.diarization import router as diarization_router
|
|
||||||
from .routers.transcription import router as transcription_router
|
|
||||||
from .routers.translation import router as translation_router
|
|
||||||
from .services.transcriber import WhisperService
|
|
||||||
from .services.diarizer import PyannoteDiarizationService
|
|
||||||
from .utils import ensure_dirs
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
ensure_dirs()
|
|
||||||
whisper_service = WhisperService()
|
|
||||||
whisper_service.load()
|
|
||||||
app.state.whisper = whisper_service
|
|
||||||
diarization_service = PyannoteDiarizationService()
|
|
||||||
diarization_service.load()
|
|
||||||
app.state.diarizer = diarization_service
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
|
||||||
app.include_router(transcription_router)
|
|
||||||
app.include_router(translation_router)
|
|
||||||
app.include_router(diarization_router)
|
|
||||||
return app
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from ..auth import apikey_auth
|
|
||||||
from ..services.diarizer import PyannoteDiarizationService
|
|
||||||
from ..utils import download_audio_file
|
|
||||||
|
|
||||||
router = APIRouter(tags=["diarization"])
|
|
||||||
|
|
||||||
|
|
||||||
class DiarizationSegment(BaseModel):
|
|
||||||
start: float
|
|
||||||
end: float
|
|
||||||
speaker: int
|
|
||||||
|
|
||||||
|
|
||||||
class DiarizationResponse(BaseModel):
|
|
||||||
diarization: List[DiarizationSegment]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/diarize", dependencies=[Depends(apikey_auth)], response_model=DiarizationResponse
|
|
||||||
)
|
|
||||||
def diarize(request: Request, audio_file_url: str, timestamp: float = 0.0):
|
|
||||||
with download_audio_file(audio_file_url) as (file_path, _ext):
|
|
||||||
file_path = str(file_path)
|
|
||||||
diarizer: PyannoteDiarizationService = request.app.state.diarizer
|
|
||||||
return diarizer.diarize_file(file_path, timestamp=timestamp)
|
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
import uuid
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, Form, HTTPException, Request, UploadFile
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from pathlib import Path
|
|
||||||
from ..auth import apikey_auth
|
|
||||||
from ..config import SUPPORTED_FILE_EXTENSIONS, UPLOADS_PATH
|
|
||||||
from ..services.transcriber import MODEL_NAME
|
|
||||||
from ..utils import cleanup_uploaded_files, download_audio_file
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/v1/audio", tags=["transcription"])
|
|
||||||
|
|
||||||
|
|
||||||
class WordTiming(BaseModel):
|
|
||||||
word: str
|
|
||||||
start: float
|
|
||||||
end: float
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptResult(BaseModel):
|
|
||||||
text: str
|
|
||||||
words: list[WordTiming]
|
|
||||||
filename: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptBatchResponse(BaseModel):
|
|
||||||
results: list[TranscriptResult]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/transcriptions",
|
|
||||||
dependencies=[Depends(apikey_auth)],
|
|
||||||
response_model=Union[TranscriptResult, TranscriptBatchResponse],
|
|
||||||
)
|
|
||||||
def transcribe(
|
|
||||||
request: Request,
|
|
||||||
file: UploadFile = None,
|
|
||||||
files: list[UploadFile] | None = None,
|
|
||||||
model: str = Form(MODEL_NAME),
|
|
||||||
language: str = Form("en"),
|
|
||||||
batch: bool = Form(False),
|
|
||||||
):
|
|
||||||
service = request.app.state.whisper
|
|
||||||
if not file and not files:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
|
||||||
)
|
|
||||||
if batch and not files:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="Batch transcription requires 'files'"
|
|
||||||
)
|
|
||||||
|
|
||||||
upload_files = [file] if file else files
|
|
||||||
|
|
||||||
uploaded_paths: list[Path] = []
|
|
||||||
with cleanup_uploaded_files(uploaded_paths):
|
|
||||||
for upload_file in upload_files:
|
|
||||||
audio_suffix = upload_file.filename.split(".")[-1].lower()
|
|
||||||
if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=(
|
|
||||||
f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
|
||||||
file_path = UPLOADS_PATH / unique_filename
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
content = upload_file.file.read()
|
|
||||||
f.write(content)
|
|
||||||
uploaded_paths.append(file_path)
|
|
||||||
|
|
||||||
if batch and len(upload_files) > 1:
|
|
||||||
results = []
|
|
||||||
for path in uploaded_paths:
|
|
||||||
result = service.transcribe_file(str(path), language=language)
|
|
||||||
result["filename"] = path.name
|
|
||||||
results.append(result)
|
|
||||||
return {"results": results}
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for path in uploaded_paths:
|
|
||||||
result = service.transcribe_file(str(path), language=language)
|
|
||||||
result["filename"] = path.name
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return {"results": results} if len(results) > 1 else results[0]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/transcriptions-from-url",
|
|
||||||
dependencies=[Depends(apikey_auth)],
|
|
||||||
response_model=TranscriptResult,
|
|
||||||
)
|
|
||||||
def transcribe_from_url(
|
|
||||||
request: Request,
|
|
||||||
audio_file_url: str = Body(..., description="URL of the audio file to transcribe"),
|
|
||||||
model: str = Body(MODEL_NAME),
|
|
||||||
language: str = Body("en"),
|
|
||||||
timestamp_offset: float = Body(0.0),
|
|
||||||
):
|
|
||||||
service = request.app.state.whisper
|
|
||||||
with download_audio_file(audio_file_url) as (file_path, _ext):
|
|
||||||
file_path = str(file_path)
|
|
||||||
result = service.transcribe_vad_url_segment(
|
|
||||||
file_path=file_path, timestamp_offset=timestamp_offset, language=language
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
from typing import Dict
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from ..auth import apikey_auth
|
|
||||||
from ..services.translator import TextTranslatorService
|
|
||||||
|
|
||||||
router = APIRouter(tags=["translation"])
|
|
||||||
|
|
||||||
translator = TextTranslatorService()
|
|
||||||
|
|
||||||
|
|
||||||
class TranslationResponse(BaseModel):
|
|
||||||
text: Dict[str, str]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/translate",
|
|
||||||
dependencies=[Depends(apikey_auth)],
|
|
||||||
response_model=TranslationResponse,
|
|
||||||
)
|
|
||||||
def translate(
|
|
||||||
text: str,
|
|
||||||
source_language: str = Body("en"),
|
|
||||||
target_language: str = Body("fr"),
|
|
||||||
):
|
|
||||||
return translator.translate(text, source_language, target_language)
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
import os
|
|
||||||
import threading
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from pyannote.audio import Pipeline
|
|
||||||
|
|
||||||
|
|
||||||
class PyannoteDiarizationService:
|
|
||||||
def __init__(self):
|
|
||||||
self._pipeline = None
|
|
||||||
self._device = "cpu"
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
|
|
||||||
def load(self):
|
|
||||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
self._pipeline = Pipeline.from_pretrained(
|
|
||||||
"pyannote/speaker-diarization-3.1",
|
|
||||||
use_auth_token=os.environ.get("HF_TOKEN"),
|
|
||||||
)
|
|
||||||
self._pipeline.to(torch.device(self._device))
|
|
||||||
|
|
||||||
def diarize_file(self, file_path: str, timestamp: float = 0.0) -> dict:
|
|
||||||
if self._pipeline is None:
|
|
||||||
self.load()
|
|
||||||
waveform, sample_rate = torchaudio.load(file_path)
|
|
||||||
with self._lock:
|
|
||||||
diarization = self._pipeline(
|
|
||||||
{"waveform": waveform, "sample_rate": sample_rate}
|
|
||||||
)
|
|
||||||
words = []
|
|
||||||
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
|
|
||||||
words.append(
|
|
||||||
{
|
|
||||||
"start": round(timestamp + diarization_segment.start, 3),
|
|
||||||
"end": round(timestamp + diarization_segment.end, 3),
|
|
||||||
"speaker": int(speaker[-2:])
|
|
||||||
if speaker and speaker[-2:].isdigit()
|
|
||||||
else 0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return {"diarization": words}
|
|
||||||
@@ -1,208 +0,0 @@
|
|||||||
import os
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import threading
|
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
import faster_whisper
|
|
||||||
import librosa
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from silero_vad import VADIterator, load_silero_vad
|
|
||||||
|
|
||||||
from ..config import SAMPLE_RATE, VAD_CONFIG
|
|
||||||
|
|
||||||
# Whisper configuration (service-local defaults)
|
|
||||||
MODEL_NAME = "large-v2"
|
|
||||||
# None delegates compute type to runtime: float16 on CUDA, int8 on CPU
|
|
||||||
MODEL_COMPUTE_TYPE = None
|
|
||||||
MODEL_NUM_WORKERS = 1
|
|
||||||
CACHE_PATH = os.path.join(os.path.expanduser("~"), ".cache", "reflector-whisper")
|
|
||||||
from ..utils import NoStdStreams
|
|
||||||
|
|
||||||
|
|
||||||
class WhisperService:
|
|
||||||
def __init__(self):
|
|
||||||
self.model = None
|
|
||||||
self.device = "cpu"
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
|
|
||||||
def load(self):
|
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
compute_type = MODEL_COMPUTE_TYPE or (
|
|
||||||
"float16" if self.device == "cuda" else "int8"
|
|
||||||
)
|
|
||||||
self.model = faster_whisper.WhisperModel(
|
|
||||||
MODEL_NAME,
|
|
||||||
device=self.device,
|
|
||||||
compute_type=compute_type,
|
|
||||||
num_workers=MODEL_NUM_WORKERS,
|
|
||||||
download_root=CACHE_PATH,
|
|
||||||
)
|
|
||||||
|
|
||||||
def pad_audio(self, audio_array, sample_rate: int = SAMPLE_RATE):
|
|
||||||
audio_duration = len(audio_array) / sample_rate
|
|
||||||
if audio_duration < VAD_CONFIG["silence_padding"]:
|
|
||||||
silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
|
|
||||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
|
||||||
return np.concatenate([audio_array, silence])
|
|
||||||
return audio_array
|
|
||||||
|
|
||||||
def enforce_word_timing_constraints(self, words: list[dict]) -> list[dict]:
|
|
||||||
if len(words) <= 1:
|
|
||||||
return words
|
|
||||||
enforced: list[dict] = []
|
|
||||||
for i, word in enumerate(words):
|
|
||||||
current = dict(word)
|
|
||||||
if i < len(words) - 1:
|
|
||||||
next_start = words[i + 1]["start"]
|
|
||||||
if current["end"] > next_start:
|
|
||||||
current["end"] = next_start
|
|
||||||
enforced.append(current)
|
|
||||||
return enforced
|
|
||||||
|
|
||||||
def transcribe_file(self, file_path: str, language: str = "en") -> dict:
|
|
||||||
input_for_model: str | np.ndarray = file_path  # file path or in-memory audio array
|
|
||||||
try:
|
|
||||||
audio_array, _sample_rate = librosa.load(
|
|
||||||
file_path, sr=SAMPLE_RATE, mono=True
|
|
||||||
)
|
|
||||||
if len(audio_array) / float(SAMPLE_RATE) < VAD_CONFIG["silence_padding"]:
|
|
||||||
input_for_model = self.pad_audio(audio_array, SAMPLE_RATE)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
with NoStdStreams():
|
|
||||||
segments, _ = self.model.transcribe(
|
|
||||||
input_for_model,
|
|
||||||
language=language,
|
|
||||||
beam_size=5,
|
|
||||||
word_timestamps=True,
|
|
||||||
vad_filter=True,
|
|
||||||
vad_parameters={"min_silence_duration_ms": 500},
|
|
||||||
)
|
|
||||||
|
|
||||||
segments = list(segments)
|
|
||||||
text = "".join(segment.text for segment in segments).strip()
|
|
||||||
words = [
|
|
||||||
{
|
|
||||||
"word": word.word,
|
|
||||||
"start": round(float(word.start), 2),
|
|
||||||
"end": round(float(word.end), 2),
|
|
||||||
}
|
|
||||||
for segment in segments
|
|
||||||
for word in segment.words
|
|
||||||
]
|
|
||||||
words = self.enforce_word_timing_constraints(words)
|
|
||||||
return {"text": text, "words": words}
|
|
||||||
|
|
||||||
def transcribe_vad_url_segment(
|
|
||||||
self, file_path: str, timestamp_offset: float = 0.0, language: str = "en"
|
|
||||||
) -> dict:
|
|
||||||
def load_audio_via_ffmpeg(input_path: str, sample_rate: int) -> np.ndarray:
|
|
||||||
ffmpeg_bin = shutil.which("ffmpeg") or "ffmpeg"
|
|
||||||
cmd = [
|
|
||||||
ffmpeg_bin,
|
|
||||||
"-nostdin",
|
|
||||||
"-threads",
|
|
||||||
"1",
|
|
||||||
"-i",
|
|
||||||
input_path,
|
|
||||||
"-f",
|
|
||||||
"f32le",
|
|
||||||
"-acodec",
|
|
||||||
"pcm_f32le",
|
|
||||||
"-ac",
|
|
||||||
"1",
|
|
||||||
"-ar",
|
|
||||||
str(sample_rate),
|
|
||||||
"pipe:1",
|
|
||||||
]
|
|
||||||
try:
|
|
||||||
proc = subprocess.run(
|
|
||||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=400, detail=f"ffmpeg failed: {e}")
|
|
||||||
audio = np.frombuffer(proc.stdout, dtype=np.float32)
|
|
||||||
return audio
|
|
||||||
|
|
||||||
def vad_segments(
|
|
||||||
audio_array,
|
|
||||||
sample_rate: int = SAMPLE_RATE,
|
|
||||||
window_size: int = VAD_CONFIG["window_size"],
|
|
||||||
) -> Generator[tuple[float, float], None, None]:
|
|
||||||
vad_model = load_silero_vad(onnx=False)
|
|
||||||
iterator = VADIterator(vad_model, sampling_rate=sample_rate)
|
|
||||||
start = None
|
|
||||||
for i in range(0, len(audio_array), window_size):
|
|
||||||
chunk = audio_array[i : i + window_size]
|
|
||||||
if len(chunk) < window_size:
|
|
||||||
chunk = np.pad(
|
|
||||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
|
||||||
)
|
|
||||||
speech = iterator(chunk)
|
|
||||||
if not speech:
|
|
||||||
continue
|
|
||||||
if "start" in speech:
|
|
||||||
start = speech["start"]
|
|
||||||
continue
|
|
||||||
if "end" in speech and start is not None:
|
|
||||||
end = speech["end"]
|
|
||||||
yield (start / float(SAMPLE_RATE), end / float(SAMPLE_RATE))
|
|
||||||
start = None
|
|
||||||
iterator.reset_states()
|
|
||||||
|
|
||||||
audio_array = load_audio_via_ffmpeg(file_path, SAMPLE_RATE)
|
|
||||||
|
|
||||||
merged_batches: list[tuple[float, float]] = []
|
|
||||||
batch_start = None
|
|
||||||
batch_end = None
|
|
||||||
max_duration = VAD_CONFIG["batch_max_duration"]
|
|
||||||
for seg_start, seg_end in vad_segments(audio_array):
|
|
||||||
if batch_start is None:
|
|
||||||
batch_start, batch_end = seg_start, seg_end
|
|
||||||
continue
|
|
||||||
if seg_end - batch_start <= max_duration:
|
|
||||||
batch_end = seg_end
|
|
||||||
else:
|
|
||||||
merged_batches.append((batch_start, batch_end))
|
|
||||||
batch_start, batch_end = seg_start, seg_end
|
|
||||||
if batch_start is not None and batch_end is not None:
|
|
||||||
merged_batches.append((batch_start, batch_end))
|
|
||||||
|
|
||||||
all_text = []
|
|
||||||
all_words = []
|
|
||||||
for start_time, end_time in merged_batches:
|
|
||||||
s_idx = int(start_time * SAMPLE_RATE)
|
|
||||||
e_idx = int(end_time * SAMPLE_RATE)
|
|
||||||
segment = audio_array[s_idx:e_idx]
|
|
||||||
segment = self.pad_audio(segment, SAMPLE_RATE)
|
|
||||||
with self.lock:
|
|
||||||
segments, _ = self.model.transcribe(
|
|
||||||
segment,
|
|
||||||
language=language,
|
|
||||||
beam_size=5,
|
|
||||||
word_timestamps=True,
|
|
||||||
vad_filter=True,
|
|
||||||
vad_parameters={"min_silence_duration_ms": 500},
|
|
||||||
)
|
|
||||||
segments = list(segments)
|
|
||||||
text = "".join(seg.text for seg in segments).strip()
|
|
||||||
words = [
|
|
||||||
{
|
|
||||||
"word": w.word,
|
|
||||||
"start": round(float(w.start) + start_time + timestamp_offset, 2),
|
|
||||||
"end": round(float(w.end) + start_time + timestamp_offset, 2),
|
|
||||||
}
|
|
||||||
for seg in segments
|
|
||||||
for w in seg.words
|
|
||||||
]
|
|
||||||
if text:
|
|
||||||
all_text.append(text)
|
|
||||||
all_words.extend(words)
|
|
||||||
|
|
||||||
all_words = self.enforce_word_timing_constraints(all_words)
|
|
||||||
return {"text": " ".join(all_text), "words": all_words}
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
import threading
|
|
||||||
|
|
||||||
from transformers import MarianMTModel, MarianTokenizer, pipeline
|
|
||||||
|
|
||||||
|
|
||||||
class TextTranslatorService:
|
|
||||||
"""Simple text-to-text translator using HuggingFace MarianMT models.
|
|
||||||
|
|
||||||
This mirrors the modal translator API shape but uses text translation only.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._pipeline = None
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
|
|
||||||
def load(self, source_language: str = "en", target_language: str = "fr"):
|
|
||||||
# Pick a default MarianMT model pair if available; fall back to Helsinki-NLP en->fr
|
|
||||||
model_name = self._resolve_model_name(source_language, target_language)
|
|
||||||
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
|
||||||
model = MarianMTModel.from_pretrained(model_name)
|
|
||||||
self._pipeline = pipeline("translation", model=model, tokenizer=tokenizer)
|
|
||||||
|
|
||||||
def _resolve_model_name(self, src: str, tgt: str) -> str:
|
|
||||||
# Minimal mapping; extend as needed
|
|
||||||
pair = (src.lower(), tgt.lower())
|
|
||||||
mapping = {
|
|
||||||
("en", "fr"): "Helsinki-NLP/opus-mt-en-fr",
|
|
||||||
("fr", "en"): "Helsinki-NLP/opus-mt-fr-en",
|
|
||||||
("en", "es"): "Helsinki-NLP/opus-mt-en-es",
|
|
||||||
("es", "en"): "Helsinki-NLP/opus-mt-es-en",
|
|
||||||
("en", "de"): "Helsinki-NLP/opus-mt-en-de",
|
|
||||||
("de", "en"): "Helsinki-NLP/opus-mt-de-en",
|
|
||||||
}
|
|
||||||
return mapping.get(pair, "Helsinki-NLP/opus-mt-en-fr")
|
|
||||||
|
|
||||||
def translate(self, text: str, source_language: str, target_language: str) -> dict:
|
|
||||||
if self._pipeline is None:
|
|
||||||
self.load(source_language, target_language)
|
|
||||||
with self._lock:
|
|
||||||
results = self._pipeline(
|
|
||||||
text, src_lang=source_language, tgt_lang=target_language
|
|
||||||
)
|
|
||||||
translated = results[0]["translation_text"] if results else ""
|
|
||||||
return {"text": {source_language: text, target_language: translated}}
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import uuid
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Mapping
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
from .config import SUPPORTED_FILE_EXTENSIONS, UPLOADS_PATH
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class NoStdStreams:
|
|
||||||
def __init__(self):
|
|
||||||
self.devnull = open(os.devnull, "w")
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
|
||||||
self._stdout.flush()
|
|
||||||
self._stderr.flush()
|
|
||||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
|
||||||
self.devnull.close()
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_dirs():
|
|
||||||
UPLOADS_PATH.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> str:
|
|
||||||
url_path = urlparse(url).path
|
|
||||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
|
||||||
if url_path.lower().endswith(f".{ext}"):
|
|
||||||
return ext
|
|
||||||
|
|
||||||
content_type = headers.get("content-type", "").lower()
|
|
||||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
|
||||||
return "mp3"
|
|
||||||
if "audio/wav" in content_type:
|
|
||||||
return "wav"
|
|
||||||
if "audio/mp4" in content_type:
|
|
||||||
return "mp4"
|
|
||||||
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=(
|
|
||||||
f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def download_audio_to_uploads(audio_file_url: str) -> tuple[Path, str]:
|
|
||||||
response = requests.head(audio_file_url, allow_redirects=True)
|
|
||||||
if response.status_code == 404:
|
|
||||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
|
||||||
|
|
||||||
response = requests.get(audio_file_url, allow_redirects=True)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
|
||||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
|
||||||
file_path: Path = UPLOADS_PATH / unique_filename
|
|
||||||
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(response.content)
|
|
||||||
|
|
||||||
return file_path, audio_suffix
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def download_audio_file(audio_file_url: str):
|
|
||||||
"""Download an audio file to UPLOADS_PATH and remove it after use.
|
|
||||||
|
|
||||||
Yields (file_path: Path, audio_suffix: str).
|
|
||||||
"""
|
|
||||||
file_path, audio_suffix = download_audio_to_uploads(audio_file_url)
|
|
||||||
try:
|
|
||||||
yield file_path, audio_suffix
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
file_path.unlink(missing_ok=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error deleting temporary file %s: %s", file_path, e)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def cleanup_uploaded_files(file_paths: list[Path]):
|
|
||||||
"""Ensure provided file paths are removed after use.
|
|
||||||
|
|
||||||
The provided list can be populated inside the context; all present entries
|
|
||||||
at exit will be deleted.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
yield file_paths
|
|
||||||
finally:
|
|
||||||
for path in list(file_paths):
|
|
||||||
try:
|
|
||||||
path.unlink(missing_ok=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error deleting temporary file %s: %s", path, e)
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
services:
|
|
||||||
reflector_gpu:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
ports:
|
|
||||||
- "8000:8000"
|
|
||||||
env_file:
|
|
||||||
- .env
|
|
||||||
volumes:
|
|
||||||
- ./cache:/root/.cache
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
from app.factory import create_app
|
|
||||||
|
|
||||||
app = create_app()
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "reflector-gpu"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "Self-hosted GPU service for speech transcription, diarization, and translation via FastAPI."
|
|
||||||
readme = "README.md"
|
|
||||||
requires-python = ">=3.12"
|
|
||||||
dependencies = [
|
|
||||||
"fastapi[standard]>=0.116.1",
|
|
||||||
"uvicorn[standard]>=0.30.0",
|
|
||||||
"torch>=2.3.0",
|
|
||||||
"faster-whisper>=1.1.0",
|
|
||||||
"librosa==0.10.1",
|
|
||||||
"numpy<2",
|
|
||||||
"silero-vad==5.1.0",
|
|
||||||
"transformers>=4.35.0",
|
|
||||||
"sentencepiece",
|
|
||||||
"pyannote.audio==3.1.0",
|
|
||||||
"torchaudio>=2.3.0",
|
|
||||||
]
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
set -e
|
|
||||||
|
|
||||||
export PATH="/root/.local/bin:$PATH"
|
|
||||||
cd /app
|
|
||||||
|
|
||||||
# Install Python dependencies at runtime (first run or when FORCE_SYNC=1)
|
|
||||||
if [ ! -d "/app/.venv" ] || [ "$FORCE_SYNC" = "1" ]; then
|
|
||||||
echo "[startup] Installing Python dependencies with uv..."
|
|
||||||
uv sync --compile-bytecode --locked
|
|
||||||
else
|
|
||||||
echo "[startup] Using existing virtual environment at /app/.venv"
|
|
||||||
fi
|
|
||||||
|
|
||||||
exec uv run uvicorn main:app --host 0.0.0.0 --port 8000
|
|
||||||
|
|
||||||
|
|
||||||
3013
gpu/self_hosted/uv.lock
generated
File diff suppressed because it is too large
3
server/.gitignore
vendored
@@ -176,8 +176,7 @@ artefacts/
|
|||||||
audio_*.wav
|
audio_*.wav
|
||||||
|
|
||||||
# ignore local database
|
# ignore local database
|
||||||
*.sqlite3
|
reflector.sqlite3
|
||||||
*.db
|
|
||||||
data/
|
data/
|
||||||
|
|
||||||
dump.rdb
|
dump.rdb
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
FROM python:3.12-slim
|
FROM python:3.12-slim
|
||||||
|
|
||||||
ENV PYTHONUNBUFFERED=1 \
|
ENV PYTHONUNBUFFERED=1 \
|
||||||
UV_LINK_MODE=copy \
|
UV_LINK_MODE=copy
|
||||||
UV_NO_CACHE=1
|
|
||||||
|
|
||||||
# builder install base dependencies
|
# builder install base dependencies
|
||||||
WORKDIR /tmp
|
WORKDIR /tmp
|
||||||
@@ -14,8 +13,8 @@ ENV PATH="/root/.local/bin/:$PATH"
|
|||||||
# install application dependencies
|
# install application dependencies
|
||||||
RUN mkdir -p /app
|
RUN mkdir -p /app
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
COPY pyproject.toml uv.lock README.md /app/
|
COPY pyproject.toml uv.lock /app/
|
||||||
RUN uv sync --compile-bytecode --locked
|
RUN touch README.md && env uv sync --compile-bytecode --locked
|
||||||
|
|
||||||
# pre-download nltk packages
|
# pre-download nltk packages
|
||||||
RUN uv run python -c "import nltk; nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
|
RUN uv run python -c "import nltk; nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
|
||||||
@@ -27,15 +26,4 @@ COPY migrations /app/migrations
|
|||||||
COPY reflector /app/reflector
|
COPY reflector /app/reflector
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Create symlink for libgomp if it doesn't exist (for ARM64 compatibility)
|
|
||||||
RUN if [ "$(uname -m)" = "aarch64" ] && [ ! -f /usr/lib/libgomp.so.1 ]; then \
|
|
||||||
LIBGOMP_PATH=$(find /app/.venv/lib -path "*/torch.libs/libgomp*.so.*" 2>/dev/null | head -n1); \
|
|
||||||
if [ -n "$LIBGOMP_PATH" ]; then \
|
|
||||||
ln -sf "$LIBGOMP_PATH" /usr/lib/libgomp.so.1; \
|
|
||||||
fi \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Pre-check just to make sure the image will not fail
|
|
||||||
RUN uv run python -c "import silero_vad.model"
|
|
||||||
|
|
||||||
CMD ["./runserver.sh"]
|
CMD ["./runserver.sh"]
|
||||||
|
|||||||
@@ -40,5 +40,3 @@ uv run python -c "from reflector.pipelines.main_live_pipeline import task_pipeli
|
|||||||
```bash
|
```bash
|
||||||
uv run python -c "from reflector.pipelines.main_live_pipeline import pipeline_post; pipeline_post(transcript_id='TRANSCRIPT_ID')"
|
uv run python -c "from reflector.pipelines.main_live_pipeline import pipeline_post; pipeline_post(transcript_id='TRANSCRIPT_ID')"
|
||||||
```
|
```
|
||||||
|
|
||||||
.
|
|
||||||
|
|||||||
@@ -1,118 +0,0 @@
|
|||||||
# AsyncIO Event Loop Analysis for test_attendee_parsing_bug.py
|
|
||||||
|
|
||||||
## Problem Summary
|
|
||||||
The test passes but encounters an error during teardown where asyncpg tries to use a different/closed event loop, resulting in:
|
|
||||||
- `RuntimeError: Task got Future attached to a different loop`
|
|
||||||
- `RuntimeError: Event loop is closed`
|
|
||||||
|
|
||||||
## Root Cause Analysis
|
|
||||||
|
|
||||||
### 1. Multiple Event Loop Creation Points
|
|
||||||
|
|
||||||
The test environment creates event loops at different scopes:
|
|
||||||
|
|
||||||
1. **Session-scoped loop** (conftest.py:27-34):
|
|
||||||
- Created once per test session
|
|
||||||
- Used by session-scoped fixtures
|
|
||||||
- Closed after all tests complete
|
|
||||||
|
|
||||||
2. **Function-scoped loop** (pytest-asyncio default):
|
|
||||||
- Created for each async test function
|
|
||||||
- This is the loop that runs the actual test
|
|
||||||
- Closed immediately after test completes
|
|
||||||
|
|
||||||
3. **AsyncPG internal loop**:
|
|
||||||
- AsyncPG connections store a reference to the loop they were created with
|
|
||||||
- Used for connection lifecycle management
|
|
||||||
|
|
||||||
### 2. Event Loop Lifecycle Mismatch
|
|
||||||
|
|
||||||
The issue occurs because:
|
|
||||||
|
|
||||||
1. **Session fixture creates database connection** on session-scoped loop
|
|
||||||
2. **Test runs** on function-scoped loop (different from session loop)
|
|
||||||
3. **During teardown**, the session fixture tries to rollback/close using the original session loop
|
|
||||||
4. **AsyncPG connection** still references the function-scoped loop which is now closed
|
|
||||||
5. **Conflict**: SQLAlchemy tries to use session loop, but asyncpg Future is attached to the closed function loop
|
|
||||||
|
|
||||||
### 3. Configuration Issues
|
|
||||||
|
|
||||||
Current pytest configuration:
|
|
||||||
- `asyncio_mode = "auto"` in pyproject.toml
|
|
||||||
- `asyncio_default_fixture_loop_scope=session` (shown in test output)
|
|
||||||
- `asyncio_default_test_loop_scope=function` (shown in test output)
|
|
||||||
|
|
||||||
This mismatch between fixture loop scope (session) and test loop scope (function) causes the problem.
|
|
||||||
|
|
||||||
## Solutions
|
|
||||||
|
|
||||||
### Option 1: Align Loop Scopes (Recommended)
|
|
||||||
Change pytest-asyncio configuration to use consistent loop scopes:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# pyproject.toml
|
|
||||||
[tool.pytest.ini_options]
|
|
||||||
asyncio_mode = "auto"
|
|
||||||
asyncio_default_fixture_loop_scope = "function" # Change from session to function
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 2: Use Function-Scoped Database Fixture
|
|
||||||
Change the `session` fixture scope from session to function:
|
|
||||||
|
|
||||||
```python
|
|
||||||
@pytest_asyncio.fixture # Remove scope="session"
|
|
||||||
async def session(setup_database):
|
|
||||||
# ... existing code ...
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 3: Explicit Loop Management
|
|
||||||
Ensure all async operations use the same loop:
|
|
||||||
|
|
||||||
```python
|
|
||||||
@pytest_asyncio.fixture
|
|
||||||
async def session(setup_database, event_loop):
|
|
||||||
# Force using the current event loop
|
|
||||||
engine = create_async_engine(
|
|
||||||
settings.DATABASE_URL,
|
|
||||||
echo=False,
|
|
||||||
poolclass=NullPool,
|
|
||||||
connect_args={"loop": event_loop} # Pass explicit loop
|
|
||||||
)
|
|
||||||
# ... rest of fixture ...
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 4: Upgrade pytest-asyncio
|
|
||||||
The current version (1.1.0) has known issues with loop management. Consider upgrading to the latest version which has better loop scope handling.
|
|
||||||
|
|
||||||
## Immediate Workaround
|
|
||||||
|
|
||||||
For the test to run cleanly without the teardown error, you can:
|
|
||||||
|
|
||||||
1. Add explicit cleanup in the test:
|
|
||||||
```python
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_attendee_parsing_bug(session):
|
|
||||||
# ... existing test code ...
|
|
||||||
|
|
||||||
# Explicit cleanup before fixture teardown
|
|
||||||
await session.commit() # or await session.close()
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Or suppress the teardown error (not recommended for production):
|
|
||||||
```python
|
|
||||||
@pytest.fixture
|
|
||||||
async def session(setup_database):
|
|
||||||
# ... existing setup ...
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
await session.rollback()
|
|
||||||
except RuntimeError as e:
|
|
||||||
if "Event loop is closed" not in str(e):
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
```
|
|
||||||
|
|
||||||
## Recommendation
|
|
||||||
|
|
||||||
The cleanest solution is to align the loop scopes by setting both fixture and test loop scopes to "function" scope. This ensures each test gets its own clean event loop and avoids cross-contamination between tests.
|
|
||||||
@@ -1,95 +0,0 @@
|
|||||||
# Data Retention and Cleanup
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
For public instances of Reflector, a data retention policy is automatically enforced to delete anonymous user data after a configurable period (default: 7 days). This ensures compliance with privacy expectations and prevents unbounded storage growth.
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
### Environment Variables
|
|
||||||
|
|
||||||
- `PUBLIC_MODE` (bool): Must be set to `true` to enable automatic cleanup
|
|
||||||
- `PUBLIC_DATA_RETENTION_DAYS` (int): Number of days to retain anonymous data (default: 7)
|
|
||||||
|
|
||||||
### What Gets Deleted
|
|
||||||
|
|
||||||
When data reaches the retention period, the following items are automatically removed:
|
|
||||||
|
|
||||||
1. **Transcripts** from anonymous users (where `user_id` is NULL):
|
|
||||||
- Database records
|
|
||||||
- Local files (audio.wav, audio.mp3, audio.json waveform)
|
|
||||||
- Storage files (cloud storage if configured)
|
|
||||||
|
|
||||||
## Automatic Cleanup

### Celery Beat Schedule

When `PUBLIC_MODE=true`, a Celery beat task runs daily at 3 AM to clean up old data:

```python
# Automatically scheduled when PUBLIC_MODE=true
"cleanup_old_public_data": {
    "task": "reflector.worker.cleanup.cleanup_old_public_data",
    "schedule": crontab(hour=3, minute=0),  # Daily at 3 AM
}
```

### Running the Worker

Ensure both Celery worker and beat scheduler are running:

```bash
# Start Celery worker
uv run celery -A reflector.worker.app worker --loglevel=info

# Start Celery beat scheduler (in another terminal)
uv run celery -A reflector.worker.app beat
```

## Manual Cleanup

For testing or manual intervention, use the cleanup tool:

```bash
# Delete data older than 7 days (default)
uv run python -m reflector.tools.cleanup_old_data

# Delete data older than 30 days
uv run python -m reflector.tools.cleanup_old_data --days 30
```

Note: The manual tool uses the same implementation as the Celery worker task to ensure consistency.

## Important Notes

1. **User Data Deletion**: Only anonymous data (where `user_id` is NULL) is deleted. Authenticated user data is preserved.

2. **Storage Cleanup**: The system properly cleans up both local files and cloud storage when configured.

3. **Error Handling**: If individual deletions fail, the cleanup continues and logs errors. Failed deletions are reported in the task output.

4. **Public Instance Only**: The automatic cleanup task only runs when `PUBLIC_MODE=true` to prevent accidental data loss in private deployments.

## Testing

Run the cleanup tests:

```bash
uv run pytest tests/test_cleanup.py -v
```

## Monitoring

Check Celery logs for cleanup task execution:

```bash
# Look for cleanup task logs
grep "cleanup_old_public_data" celery.log
grep "Starting cleanup of old public data" celery.log
```

Task statistics are logged after each run:
- Number of transcripts deleted
- Number of meetings deleted
- Number of orphaned recordings deleted
- Any errors encountered
@@ -1,194 +0,0 @@
## Reflector GPU Transcription API (Specification)

This document defines the Reflector GPU transcription API that all implementations must adhere to. Current implementations include NVIDIA Parakeet (NeMo) and Whisper (faster-whisper), both deployed on Modal.com. The API surface and response shapes are OpenAI/Whisper-compatible, so clients can switch implementations by changing only the base URL.

### Base URL and Authentication

- Example base URLs (Modal web endpoints):

  - Parakeet: `https://<account>--reflector-transcriber-parakeet-web.modal.run`
  - Whisper: `https://<account>--reflector-transcriber-web.modal.run`

- All endpoints are served under `/v1` and require a Bearer token:

```
Authorization: Bearer <REFLECTOR_GPU_APIKEY>
```

Note: To switch implementations, deploy the desired variant and point `TRANSCRIPT_URL` to its base URL. The API is identical.

### Supported file types

`mp3, mp4, mpeg, mpga, m4a, wav, webm`

### Models and languages

- Parakeet (NVIDIA NeMo): default `nvidia/parakeet-tdt-0.6b-v2`
  - Language support: only `en`. Other languages return HTTP 400.
- Whisper (faster-whisper): default `large-v2` (or deployment-specific)
  - Language support: multilingual (per Whisper model capabilities).

Note: The `model` parameter is accepted by all implementations for interface parity. Some backends may treat it as informational.

### Endpoints

#### POST /v1/audio/transcriptions

Transcribe one or more uploaded audio files.

Request: multipart/form-data

- `file` (File) — optional. Single file to transcribe.
- `files` (File[]) — optional. One or more files to transcribe.
- `model` (string) — optional. Defaults to the implementation-specific model (see above).
- `language` (string) — optional, defaults to `en`.
  - Parakeet: only `en` is accepted; other values return HTTP 400
  - Whisper: model-dependent; typically multilingual
- `batch` (boolean) — optional, defaults to `false`.

Notes:

- Provide either `file` or `files`, not both. If neither is provided, HTTP 400.
- `batch` requires `files`; using `batch=true` without `files` returns HTTP 400.
- Response shape for multiple files is the same regardless of `batch`.
- Files sent to this endpoint are processed in a single pass (no VAD/chunking). This is intended for short clips (roughly ≤ 30s; depends on GPU memory/model). For longer audio, prefer `/v1/audio/transcriptions-from-url` which supports VAD-based chunking.

Responses

Single file response:

```json
{
  "text": "transcribed text",
  "words": [
    { "word": "hello", "start": 0.0, "end": 0.5 },
    { "word": "world", "start": 0.5, "end": 1.0 }
  ],
  "filename": "audio.mp3"
}
```

Multiple files response:

```json
{
  "results": [
    {"filename": "a1.mp3", "text": "...", "words": [...]},
    {"filename": "a2.mp3", "text": "...", "words": [...]}
  ]
}
```

Notes:

- Word objects always include keys: `word`, `start`, `end`.
- Some implementations may include a trailing space in `word` to match Whisper tokenization behavior; clients should trim if needed.
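For illustration, a small client-side sketch (plain Python, no extra dependencies) that rebuilds the text from `words` and trims the possible trailing spaces mentioned above:

```python
# Sketch: normalize a single-file transcription response as shown above.
response_json = {
    "text": "hello world",
    "words": [
        {"word": "hello ", "start": 0.0, "end": 0.5},
        {"word": "world", "start": 0.5, "end": 1.0},
    ],
    "filename": "audio.mp3",
}

# Trim the whitespace some backends keep for Whisper-tokenization parity
words = [w["word"].strip() for w in response_json["words"]]
text = " ".join(words)
print(text)  # -> "hello world"
```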
Example curl (single file):

```bash
curl -X POST \
  -H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
  -F "file=@/path/to/audio.mp3" \
  -F "language=en" \
  "$BASE_URL/v1/audio/transcriptions"
```

Example curl (multiple files, batch):

```bash
curl -X POST \
  -H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
  -F "files=@/path/a1.mp3" -F "files=@/path/a2.mp3" \
  -F "batch=true" -F "language=en" \
  "$BASE_URL/v1/audio/transcriptions"
```

#### POST /v1/audio/transcriptions-from-url

Transcribe a single remote audio file by URL.

Request: application/json

Body parameters:

- `audio_file_url` (string) — required. URL of the audio file to transcribe.
- `model` (string) — optional. Defaults to the implementation-specific model (see above).
- `language` (string) — optional, defaults to `en`. Parakeet only accepts `en`.
- `timestamp_offset` (number) — optional, defaults to `0.0`. Added to each word's `start`/`end` in the response.

```json
{
  "audio_file_url": "https://example.com/audio.mp3",
  "model": "nvidia/parakeet-tdt-0.6b-v2",
  "language": "en",
  "timestamp_offset": 0.0
}
```

Response:

```json
{
  "text": "transcribed text",
  "words": [
    { "word": "hello", "start": 10.0, "end": 10.5 },
    { "word": "world", "start": 10.5, "end": 11.0 }
  ]
}
```

Notes:

- `timestamp_offset` is added to each word’s `start`/`end` in the response.
- Implementations may perform VAD-based chunking and batching for long-form audio; word timings are adjusted accordingly.

Example curl:

```bash
curl -X POST \
  -H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
  -H "Content-Type: application/json" \
  -d '{
    "audio_file_url": "https://example.com/audio.mp3",
    "language": "en",
    "timestamp_offset": 0
  }' \
  "$BASE_URL/v1/audio/transcriptions-from-url"
```

### Error handling

- 400 Bad Request
  - Parakeet: `language` other than `en`
  - Missing required parameters (`file`/`files` for upload; `audio_file_url` for URL endpoint)
  - Unsupported file extension
- 401 Unauthorized
  - Missing or invalid Bearer token
- 404 Not Found
  - `audio_file_url` does not exist

### Implementation details

- GPUs: A10G for small-file/live, L40S for large-file URL transcription (subject to deployment)
- VAD chunking and segment batching; word timings adjusted and overlapping ends constrained
- Pads very short segments (< 0.5s) to avoid model crashes on some backends

### Server configuration (Reflector API)

Set the Reflector server to use the Modal backend and point `TRANSCRIPT_URL` to your chosen deployment:

```
TRANSCRIPT_BACKEND=modal
TRANSCRIPT_URL=https://<account>--reflector-transcriber-parakeet-web.modal.run
TRANSCRIPT_MODAL_API_KEY=<REFLECTOR_GPU_APIKEY>
```

### Conformance tests

Use the pytest-based conformance tests to validate any new implementation (including self-hosted) against this spec:

```
TRANSCRIPT_URL=https://<your-deployment-base> \
TRANSCRIPT_MODAL_API_KEY=your-api-key \
uv run -m pytest -m model_api --no-cov server/tests/test_model_api_transcript.py
```
@@ -1,212 +0,0 @@
# Reflector Webhook Documentation

## Overview

Reflector supports webhook notifications to notify external systems when transcript processing is completed. Webhooks can be configured per room and are triggered automatically after a transcript is successfully processed.

## Configuration

Webhooks are configured at the room level with two fields:
- `webhook_url`: The HTTPS endpoint to receive webhook notifications
- `webhook_secret`: Optional secret key for HMAC signature verification (auto-generated if not provided)

## Events

### `transcript.completed`

Triggered when a transcript has been fully processed, including transcription, diarization, summarization, and topic detection.

### `test`

A test event that can be triggered manually to verify webhook configuration.

## Webhook Request Format

### Headers

All webhook requests include the following headers:

| Header | Description | Example |
|--------|-------------|---------|
| `Content-Type` | Always `application/json` | `application/json` |
| `User-Agent` | Identifies Reflector as the source | `Reflector-Webhook/1.0` |
| `X-Webhook-Event` | The event type | `transcript.completed` or `test` |
| `X-Webhook-Retry` | Current retry attempt number | `0`, `1`, `2`... |
| `X-Webhook-Signature` | HMAC signature (if secret configured) | `t=1735306800,v1=abc123...` |

### Signature Verification

If a webhook secret is configured, Reflector includes an HMAC-SHA256 signature in the `X-Webhook-Signature` header so you can verify the webhook's authenticity.

The signature format is: `t={timestamp},v1={signature}`

To verify the signature:
1. Extract the timestamp and signature from the header
2. Create the signed payload: `{timestamp}.{request_body}`
3. Compute HMAC-SHA256 of the signed payload using your webhook secret
4. Compare the computed signature with the received signature

Example verification (Python):
```python
import hmac
import hashlib

def verify_webhook_signature(payload: bytes, signature_header: str, secret: str) -> bool:
    # Parse header: "t=1735306800,v1=abc123..."
    parts = dict(part.split("=") for part in signature_header.split(","))
    timestamp = parts["t"]
    received_signature = parts["v1"]

    # Create signed payload
    signed_payload = f"{timestamp}.{payload.decode('utf-8')}"

    # Compute expected signature
    expected_signature = hmac.new(
        secret.encode("utf-8"),
        signed_payload.encode("utf-8"),
        hashlib.sha256
    ).hexdigest()

    # Compare signatures
    return hmac.compare_digest(expected_signature, received_signature)
```
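As a usage sketch only (the receiving side is yours to implement; the FastAPI endpoint below is an assumption, not part of Reflector), the verification function can be wired into a receiver like this:

```python
# Hypothetical receiver endpoint; reuses verify_webhook_signature() defined above.
from fastapi import FastAPI, Header, HTTPException, Request

app = FastAPI()
WEBHOOK_SECRET = "my-room-webhook-secret"  # assumption: stored with your room config


@app.post("/webhooks/reflector")
async def reflector_webhook(
    request: Request,
    x_webhook_event: str = Header(...),
    x_webhook_signature: str | None = Header(None),
):
    body = await request.body()
    if x_webhook_signature and not verify_webhook_signature(
        body, x_webhook_signature, WEBHOOK_SECRET
    ):
        # Reject tampered or misconfigured deliveries
        raise HTTPException(status_code=400, detail="invalid signature")
    # Any 2xx response tells Reflector the delivery succeeded
    return {"ok": True}
```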
## Event Payloads

### `transcript.completed` Event

This event includes a convenient URL for accessing the transcript:
- `frontend_url`: Direct link to view the transcript in the web interface

```json
{
  "event": "transcript.completed",
  "event_id": "transcript.completed-abc-123-def-456",
  "timestamp": "2025-08-27T12:34:56.789012Z",
  "transcript": {
    "id": "abc-123-def-456",
    "room_id": "room-789",
    "created_at": "2025-08-27T12:00:00Z",
    "duration": 1800.5,
    "title": "Q3 Product Planning Meeting",
    "short_summary": "Team discussed Q3 product roadmap, prioritizing mobile app features and API improvements.",
    "long_summary": "The product team met to finalize the Q3 roadmap. Key decisions included...",
    "webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\n<v Speaker 1>Welcome everyone to today's meeting...",
    "topics": [
      {
        "title": "Introduction and Agenda",
        "summary": "Meeting kickoff with agenda review",
        "timestamp": 0.0,
        "duration": 120.0,
        "webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\n<v Speaker 1>Welcome everyone..."
      },
      {
        "title": "Mobile App Features Discussion",
        "summary": "Team reviewed proposed mobile app features for Q3",
        "timestamp": 120.0,
        "duration": 600.0,
        "webvtt": "WEBVTT\n\n00:02:00.000 --> 00:02:10.000\n<v Speaker 2>Let's talk about the mobile app..."
      }
    ],
    "participants": [
      {
        "id": "participant-1",
        "name": "John Doe",
        "speaker": "Speaker 1"
      },
      {
        "id": "participant-2",
        "name": "Jane Smith",
        "speaker": "Speaker 2"
      }
    ],
    "source_language": "en",
    "target_language": "en",
    "status": "completed",
    "frontend_url": "https://app.reflector.com/transcripts/abc-123-def-456"
  },
  "room": {
    "id": "room-789",
    "name": "Product Team Room"
  }
}
```

### `test` Event

```json
{
  "event": "test",
  "event_id": "test.2025-08-27T12:34:56.789012Z",
  "timestamp": "2025-08-27T12:34:56.789012Z",
  "message": "This is a test webhook from Reflector",
  "room": {
    "id": "room-789",
    "name": "Product Team Room"
  }
}
```

## Retry Policy

Webhooks are delivered with automatic retry logic to handle transient failures. When a webhook delivery fails due to server errors or network issues, Reflector will automatically retry the delivery multiple times over an extended period.

### Retry Mechanism

Reflector implements an exponential backoff strategy for webhook retries:

- **Initial retry delay**: 60 seconds after the first failure
- **Exponential backoff**: Each subsequent retry waits approximately twice as long as the previous one
- **Maximum retry interval**: 1 hour (backoff is capped at this duration)
- **Maximum retry attempts**: 30 attempts total
- **Total retry duration**: Retries continue for approximately 24 hours

### How Retries Work

When a webhook fails, Reflector will:
1. Wait 60 seconds, then retry (attempt #1)
2. If it fails again, wait ~2 minutes, then retry (attempt #2)
3. Continue doubling the wait time up to a maximum of 1 hour between attempts
4. Keep retrying at 1-hour intervals until successful or 30 attempts are exhausted
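A quick arithmetic sketch of that schedule (illustrative only, not the actual worker code):

```python
# Delay before retry n (n = 1 is the first retry), doubling from 60s, capped at 1 hour.
def retry_delay_seconds(n: int, base: int = 60, cap: int = 3600) -> int:
    return min(base * 2 ** (n - 1), cap)

print([retry_delay_seconds(n) for n in range(1, 8)])
# [60, 120, 240, 480, 960, 1920, 3600] -- then a flat 3600s until 30 attempts are used
```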
The `X-Webhook-Retry` header indicates the current retry attempt number (0 for the initial attempt, 1 for the first retry, etc.), allowing your endpoint to track retry attempts.

### Retry Behavior by HTTP Status Code

| Status Code | Behavior |
|-------------|----------|
| 2xx (Success) | No retry, webhook marked as delivered |
| 4xx (Client Error) | No retry, request is considered permanently failed |
| 5xx (Server Error) | Automatic retry with exponential backoff |
| Network/Timeout Error | Automatic retry with exponential backoff |

**Important Notes:**
- Webhooks time out after 30 seconds. If your endpoint takes longer to respond, the delivery is treated as a timeout error and retried.
- During the retry period (~24 hours), you may receive the same webhook multiple times if your endpoint experiences intermittent failures.
- There is no mechanism to manually retry failed webhooks after the retry period expires.

## Testing Webhooks

You can test your webhook configuration before processing transcripts:

```http
POST /v1/rooms/{room_id}/webhook/test
```

Response:
```json
{
  "success": true,
  "status_code": 200,
  "message": "Webhook test successful",
  "response_preview": "OK"
}
```

Or in case of failure:
```json
{
  "success": false,
  "error": "Webhook request timed out (10 seconds)"
}
```
86 server/gpu/modal_deployments/README.md Normal file
@@ -0,0 +1,86 @@
# Reflector GPU implementation - Transcription and LLM

This repository holds an API for the GPU implementation of the Reflector API service,
and uses [Modal.com](https://modal.com).

- `reflector_diarizer.py` - Diarization API
- `reflector_transcriber.py` - Transcription API
- `reflector_translator.py` - Translation API

## Modal.com deployment

Create a modal secret, and name it `reflector-gpu`.
It should contain a `REFLECTOR_APIKEY` environment variable with a value.

The deployment is done using the [Modal.com](https://modal.com) service.

```
$ modal deploy reflector_transcriber.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run

$ modal deploy reflector_llm.py
...
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
```

Then in your reflector api configuration `.env`, you can set these keys:

```
TRANSCRIPT_BACKEND=modal
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY

DIARIZATION_BACKEND=modal
DIARIZATION_URL=https://xxxx--reflector-diarizer-web.modal.run
DIARIZATION_MODAL_API_KEY=REFLECTOR_APIKEY

TRANSLATION_BACKEND=modal
TRANSLATION_URL=https://xxxx--reflector-translator-web.modal.run
TRANSLATION_MODAL_API_KEY=REFLECTOR_APIKEY
```

## API

Authentication must be passed with the `Authorization` header, using the `bearer` scheme.

```
Authorization: bearer <REFLECTOR_APIKEY>
```

### LLM

`POST /llm`

**request**
```
{
    "prompt": "xxx"
}
```

**response**
```
{
    "text": "xxx completed"
}
```

### Transcription

`POST /transcribe`

**request** (multipart/form-data)

- `file` - audio file
- `language` - language code (e.g. `en`)

**response**
```
{
    "text": "xxx",
    "words": [
        {"text": "xxx", "start": 0.0, "end": 1.0}
    ]
}
```
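For illustration, a minimal Python client for the transcription endpoint documented above (a sketch: the base URL and key are placeholders, and the `requests` package is an assumed dependency):

```
# Hypothetical client for POST /transcribe; replace URL and key with your deployment values.
import requests

BASE_URL = "https://xxxx--reflector-transcriber-web.modal.run"
API_KEY = "<REFLECTOR_APIKEY>"

with open("audio.mp3", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/transcribe",
        headers={"Authorization": f"bearer {API_KEY}"},
        files={"file": f},
        data={"language": "en"},
    )
resp.raise_for_status()
print(resp.json()["text"])
```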
187 server/gpu/modal_deployments/reflector_diarizer.py Normal file
@@ -0,0 +1,187 @@
"""
Reflector GPU backend - diarizer
===================================
"""

import os

import modal.gpu
from modal import App, Image, Secret, asgi_app, enter, method
from pydantic import BaseModel

PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
MODEL_DIR = "/root/diarization_models"
app = App(name="reflector-diarizer")


def migrate_cache_llm():
    """
    XXX The cache for model files in Transformers v4.22.0 has been updated.
    Migrating your old cache. This is a one-time only operation. You can
    interrupt this and resume the migration later on by calling
    `transformers.utils.move_cache()`.
    """
    from transformers.utils.hub import move_cache

    print("Moving LLM cache")
    move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
    print("LLM cache moved")


def download_pyannote_audio():
    from pyannote.audio import Pipeline

    Pipeline.from_pretrained(
        PYANNOTE_MODEL_NAME,
        cache_dir=MODEL_DIR,
        use_auth_token=os.environ["HF_TOKEN"],
    )


diarizer_image = (
    Image.debian_slim(python_version="3.10.8")
    .pip_install(
        "pyannote.audio==3.1.0",
        "requests",
        "onnx",
        "torchaudio",
        "onnxruntime-gpu",
        "torch==2.0.0",
        "transformers==4.34.0",
        "sentencepiece",
        "protobuf",
        "numpy",
        "huggingface_hub",
        "hf-transfer",
    )
    .run_function(
        download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")]
    )
    .run_function(migrate_cache_llm)
    .env(
        {
            "LD_LIBRARY_PATH": (
                "/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
                "/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
            )
        }
    )
)


@app.cls(
    gpu=modal.gpu.A100(size="40GB"),
    timeout=60 * 30,
    scaledown_window=60,
    allow_concurrent_inputs=1,
    image=diarizer_image,
)
class Diarizer:
    @enter()
    def enter(self):
        import torch
        from pyannote.audio import Pipeline

        self.use_gpu = torch.cuda.is_available()
        self.device = "cuda" if self.use_gpu else "cpu"
        self.diarization_pipeline = Pipeline.from_pretrained(
            PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR
        )
        self.diarization_pipeline.to(torch.device(self.device))

    @method()
    def diarize(self, audio_data: str, audio_suffix: str, timestamp: float):
        import tempfile

        import torchaudio

        with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
            fp.write(audio_data)

            print("Diarizing audio")
            waveform, sample_rate = torchaudio.load(fp.name)
            diarization = self.diarization_pipeline(
                {"waveform": waveform, "sample_rate": sample_rate}
            )

        words = []
        for diarization_segment, _, speaker in diarization.itertracks(
            yield_label=True
        ):
            words.append(
                {
                    "start": round(timestamp + diarization_segment.start, 3),
                    "end": round(timestamp + diarization_segment.end, 3),
                    "speaker": int(speaker[-2:]),
                }
            )
        print("Diarization complete")
        return {"diarization": words}


# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------


@app.function(
    timeout=60 * 10,
    scaledown_window=60 * 3,
    allow_concurrent_inputs=40,
    secrets=[
        Secret.from_name("reflector-gpu"),
    ],
    image=diarizer_image,
)
@asgi_app()
def web():
    import requests
    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.security import OAuth2PasswordBearer

    diarizerstub = Diarizer()

    app = FastAPI()

    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

    def apikey_auth(apikey: str = Depends(oauth2_scheme)):
        if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key",
                headers={"WWW-Authenticate": "Bearer"},
            )

    def validate_audio_file(audio_file_url: str):
        # Check if the audio file exists
        response = requests.head(audio_file_url, allow_redirects=True)
        if response.status_code == 404:
            raise HTTPException(
                status_code=response.status_code,
                detail="The audio file does not exist.",
            )

    class DiarizationResponse(BaseModel):
        result: dict

    @app.post(
        "/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]
    )
    def diarize(
        audio_file_url: str, timestamp: float = 0.0
    ) -> HTTPException | DiarizationResponse:
        # Currently the uploaded files are in mp3 format
        audio_suffix = "mp3"

        print("Downloading audio file")
        response = requests.get(audio_file_url, allow_redirects=True)
        print("Audio file downloaded successfully")

        func = diarizerstub.diarize.spawn(
            audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp
        )
        result = func.get()
        return result

    return app

161 server/gpu/modal_deployments/reflector_transcriber.py Normal file
@@ -0,0 +1,161 @@
import os
import tempfile
import threading

import modal
from pydantic import BaseModel

MODELS_DIR = "/models"

MODEL_NAME = "large-v2"
MODEL_COMPUTE_TYPE: str = "float16"
MODEL_NUM_WORKERS: int = 1

MINUTES = 60  # seconds

volume = modal.Volume.from_name("models", create_if_missing=True)

app = modal.App("reflector-transcriber")


def download_model():
    from faster_whisper import download_model

    volume.reload()

    download_model(MODEL_NAME, cache_dir=MODELS_DIR)

    volume.commit()


image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "huggingface_hub==0.27.1",
        "hf-transfer==0.1.9",
        "torch==2.5.1",
        "faster-whisper==1.1.1",
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",
            "LD_LIBRARY_PATH": (
                "/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/:"
                "/opt/conda/lib/python3.12/site-packages/nvidia/cublas/lib/"
            ),
        }
    )
    .run_function(download_model, volumes={MODELS_DIR: volume})
)


@app.cls(
    gpu="A10G",
    timeout=5 * MINUTES,
    scaledown_window=5 * MINUTES,
    allow_concurrent_inputs=6,
    image=image,
    volumes={MODELS_DIR: volume},
)
class Transcriber:
    @modal.enter()
    def enter(self):
        import faster_whisper
        import torch

        self.lock = threading.Lock()
        self.use_gpu = torch.cuda.is_available()
        self.device = "cuda" if self.use_gpu else "cpu"
        self.model = faster_whisper.WhisperModel(
            MODEL_NAME,
            device=self.device,
            compute_type=MODEL_COMPUTE_TYPE,
            num_workers=MODEL_NUM_WORKERS,
            download_root=MODELS_DIR,
            local_files_only=True,
        )

    @modal.method()
    def transcribe_segment(
        self,
        audio_data: str,
        audio_suffix: str,
        language: str,
    ):
        with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
            fp.write(audio_data)

            with self.lock:
                segments, _ = self.model.transcribe(
                    fp.name,
                    language=language,
                    beam_size=5,
                    word_timestamps=True,
                    vad_filter=True,
                    vad_parameters={"min_silence_duration_ms": 500},
                )

                segments = list(segments)

        text = "".join(segment.text for segment in segments)
        words = [
            {"word": word.word, "start": word.start, "end": word.end}
            for segment in segments
            for word in segment.words
        ]

        return {"text": text, "words": words}


@app.function(
    scaledown_window=60,
    timeout=60,
    allow_concurrent_inputs=40,
    secrets=[
        modal.Secret.from_name("reflector-gpu"),
    ],
    volumes={MODELS_DIR: volume},
)
@modal.asgi_app()
def web():
    from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, status
    from fastapi.security import OAuth2PasswordBearer
    from typing_extensions import Annotated

    transcriber = Transcriber()

    app = FastAPI()

    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

    supported_file_types = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]

    def apikey_auth(apikey: str = Depends(oauth2_scheme)):
        if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key",
                headers={"WWW-Authenticate": "Bearer"},
            )

    class TranscriptResponse(BaseModel):
        result: dict

    @app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
    def transcribe(
        file: UploadFile,
        model: str = "whisper-1",
        language: Annotated[str, Body(...)] = "en",
    ) -> TranscriptResponse:
        audio_data = file.file.read()
        audio_suffix = file.filename.split(".")[-1]
        assert audio_suffix in supported_file_types

        func = transcriber.transcribe_segment.spawn(
            audio_data=audio_data,
            audio_suffix=audio_suffix,
            language=language,
        )
        result = func.get()
        return result

    return app
@@ -1,583 +0,0 @@
# Celery to TaskIQ Migration Guide

## Executive Summary

This document outlines the migration path from Celery to TaskIQ for the Reflector project. TaskIQ is a modern, async-first distributed task queue that provides similar functionality to Celery while being designed specifically for async Python applications.

## Current Celery Usage Analysis

### Key Patterns in Use
1. **Task Decorators**: `@shared_task`, `@asynctask`, `@with_session` decorators
2. **Task Invocation**: `.delay()`, `.si()` for signatures
3. **Workflow Patterns**: `chain()`, `group()`, `chord()` for complex pipelines
4. **Scheduled Tasks**: Celery Beat with crontab and periodic schedules
5. **Session Management**: Custom `@with_session` and `@with_session_and_transcript` decorators
6. **Retry Logic**: Auto-retry with exponential backoff
7. **Redis Backend**: Using Redis for broker and result backend

### Critical Files to Migrate
- `reflector/worker/app.py` - Celery app configuration and beat schedule
- `reflector/worker/session_decorator.py` - Session management decorators
- `reflector/pipelines/main_file_pipeline.py` - File processing pipeline
- `reflector/pipelines/main_live_pipeline.py` - Live streaming pipeline (10 tasks)
- `reflector/worker/process.py` - Background processing tasks
- `reflector/worker/ics_sync.py` - Calendar sync tasks
- `reflector/worker/cleanup.py` - Cleanup tasks
- `reflector/worker/webhook.py` - Webhook notifications

## TaskIQ Architecture Mapping

### 1. Installation

```bash
# Remove Celery dependencies
uv remove celery flower

# Install TaskIQ with Redis support
uv add taskiq taskiq-redis taskiq-pipelines
```

### 2. Broker Configuration

#### Current (Celery)
```python
# reflector/worker/app.py
from celery import Celery

app = Celery(
    "reflector",
    broker=settings.CELERY_BROKER_URL,
    backend=settings.CELERY_RESULT_BACKEND,
    include=[...],
)
```

#### New (TaskIQ)
```python
# reflector/worker/broker.py
from taskiq_redis import RedisAsyncResultBackend, RedisStreamBroker
from taskiq import PipelineMiddleware, SimpleRetryMiddleware

result_backend = RedisAsyncResultBackend(
    redis_url=settings.REDIS_URL,
    result_ex_time=86400,  # 24 hours
)

broker = RedisStreamBroker(
    url=settings.REDIS_URL,
    max_connection_pool_size=10,
).with_result_backend(result_backend).with_middlewares(
    PipelineMiddleware(),  # For chain/group/chord support
    SimpleRetryMiddleware(default_retry_count=3),
)

# For testing environment
if os.environ.get("ENVIRONMENT") == "pytest":
    from taskiq import InMemoryBroker
    broker = InMemoryBroker(await_inplace=True)
```

### 3. Task Definition Migration

#### Current (Celery)
```python
@shared_task
@asynctask
@with_session
async def task_pipeline_file_process(session: AsyncSession, transcript_id: str):
    pipeline = PipelineMainFile(transcript_id=transcript_id)
    await pipeline.process()
```

#### New (TaskIQ)
```python
from taskiq import TaskiqDepends
from reflector.worker.broker import broker
from reflector.worker.dependencies import get_db_session

@broker.task
async def task_pipeline_file_process(transcript_id: str):
    # Use get_session for proper test mocking
    async for session in get_session():
        pipeline = PipelineMainFile(transcript_id=transcript_id)
        await pipeline.process()
```

### 4. Session Management

#### Current Session Decorators (Keep Using These!)
```python
# reflector/worker/session_decorator.py
def with_session(func):
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        async with get_session_context() as session:
            return await func(session, *args, **kwargs)
    return wrapper
```

#### Session Management Strategy

**⚠️ CRITICAL**: The key insight is to maintain consistent session management patterns:

1. **For Worker Tasks**: Continue using `@with_session` decorator pattern
2. **For FastAPI endpoints**: Use `get_session` dependency injection
3. **Never use `get_session_factory()` directly** in application code

```python
# APPROACH 1: Simple migration keeping decorator pattern
from reflector.worker.session_decorator import with_session

@taskiq_broker.task
@with_session
async def task_pipeline_file_process(session, *, transcript_id: str):
    # Session is provided by decorator, just like Celery version
    transcript = await transcripts_controller.get_by_id(session, transcript_id)
    pipeline = PipelineMainFile(transcript_id=transcript_id)
    await pipeline.process()

# APPROACH 2: For test compatibility without decorator
from reflector.db import get_session

@taskiq_broker.task
async def task_pipeline_file_process(transcript_id: str):
    # Use get_session which is mocked in tests
    async for session in get_session():
        transcript = await transcripts_controller.get_by_id(session, transcript_id)
        pipeline = PipelineMainFile(transcript_id=transcript_id)
        await pipeline.process()

# APPROACH 3: Future - TaskIQ dependency injection (after full migration)
from taskiq import TaskiqDepends

async def get_session_context():
    """Context manager version of get_session for consistency"""
    async for session in get_session():
        yield session

@taskiq_broker.task
async def task_pipeline_file_process(
    transcript_id: str,
    session: AsyncSession = TaskiqDepends(get_session_context)
):
    transcript = await transcripts_controller.get_by_id(session, transcript_id)
    pipeline = PipelineMainFile(transcript_id=transcript_id)
    await pipeline.process()
```

**Key Points:**
- `@with_session` decorator works with TaskIQ tasks (remove `@asynctask`, keep `@with_session`)
- For testing: `get_session()` from `reflector.db` is properly mocked
- Never call `get_session_factory()` directly - always use the abstractions

### 5. Task Invocation

#### Current (Celery)
```python
# Simple async execution
task_pipeline_file_process.delay(transcript_id=transcript.id)

# With signature for chaining
task_cleanup_consent.si(transcript_id=transcript_id)
```

#### New (TaskIQ)
```python
# Simple async execution
await task_pipeline_file_process.kiq(transcript_id=transcript.id)

# With kicker for advanced configuration
await task_cleanup_consent.kicker().with_labels(
    priority="high"
).kiq(transcript_id=transcript_id)
```

### 6. Workflow Patterns (Chain, Group, Chord)

#### Current (Celery)
```python
from celery import chain, group, chord

# Chain example
post_chain = chain(
    task_cleanup_consent.si(transcript_id=transcript_id),
    task_pipeline_post_to_zulip.si(transcript_id=transcript_id),
    task_send_webhook_if_needed.si(transcript_id=transcript_id),
)

# Chord example (parallel + callback)
chain = chord(
    group(chain_mp3_and_diarize, chain_title_preview),
    chain_final_summaries,
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
```

#### New (TaskIQ with Pipelines)
```python
from taskiq_pipelines import Pipeline
from taskiq import gather

# Chain example using Pipeline
post_pipeline = (
    Pipeline(broker, task_cleanup_consent)
    .call_next(task_pipeline_post_to_zulip, transcript_id=transcript_id)
    .call_next(task_send_webhook_if_needed, transcript_id=transcript_id)
)
await post_pipeline.kiq(transcript_id=transcript_id)

# Parallel execution with gather
results = await gather([
    chain_mp3_and_diarize.kiq(transcript_id),
    chain_title_preview.kiq(transcript_id),
])

# Then execute callback
await chain_final_summaries.kiq(transcript_id, results)
await task_pipeline_post_to_zulip.kiq(transcript_id)
```

### 7. Scheduled Tasks (Celery Beat → TaskIQ Scheduler)

#### Current (Celery Beat)
```python
# reflector/worker/app.py
app.conf.beat_schedule = {
    "process_messages": {
        "task": "reflector.worker.process.process_messages",
        "schedule": float(settings.SQS_POLLING_TIMEOUT_SECONDS),
    },
    "reprocess_failed_recordings": {
        "task": "reflector.worker.process.reprocess_failed_recordings",
        "schedule": crontab(hour=5, minute=0),
    },
}
```

#### New (TaskIQ Scheduler)
```python
# reflector/worker/scheduler.py
from taskiq import TaskiqScheduler
from taskiq_redis import ListRedisScheduleSource

schedule_source = ListRedisScheduleSource(settings.REDIS_URL)

# Define scheduled tasks with decorators
@broker.task(
    schedule=[
        {
            "cron": f"*/{int(settings.SQS_POLLING_TIMEOUT_SECONDS)} * * * * *"
        }
    ]
)
async def process_messages():
    # Task implementation
    pass

@broker.task(
    schedule=[{"cron": "0 5 * * *"}]  # Daily at 5 AM
)
async def reprocess_failed_recordings():
    # Task implementation
    pass

# Initialize scheduler
scheduler = TaskiqScheduler(broker, sources=[schedule_source])

# Run scheduler (separate process)
# taskiq scheduler reflector.worker.scheduler:scheduler
```

### 8. Retry Configuration

#### Current (Celery)
```python
@shared_task(
    bind=True,
    max_retries=30,
    default_retry_delay=60,
    retry_backoff=True,
    retry_backoff_max=3600,
)
async def task_send_webhook_if_needed(self, ...):
    try:
        # Task logic
    except Exception as exc:
        raise self.retry(exc=exc)
```

#### New (TaskIQ)
```python
from taskiq.middlewares import SimpleRetryMiddleware

# Global middleware configuration (1:1 with Celery defaults)
broker = broker.with_middlewares(
    SimpleRetryMiddleware(default_retry_count=3),
)

# For specific tasks with custom retry logic:
@broker.task(retry_on_error=True, max_retries=30)
async def task_send_webhook_if_needed(...):
    # Task logic - exceptions auto-retry
    pass
```

## Testing Migration

### Current Pytest Setup (Celery)
```python
# tests/conftest.py
@pytest.fixture(scope="session")
def celery_config():
    return {
        "broker_url": "memory://",
        "result_backend": "cache+memory://",
    }

@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
async def test_task():
    pass
```

### New Pytest Setup (TaskIQ)
```python
# tests/conftest.py
import pytest
from taskiq import InMemoryBroker
from reflector.worker.broker import broker

@pytest.fixture(scope="function", autouse=True)
async def setup_taskiq_broker():
    """Replace broker with InMemoryBroker for testing"""
    original_broker = broker
    test_broker = InMemoryBroker(await_inplace=True)

    # Copy task registrations
    for task_name, task in original_broker._tasks.items():
        test_broker.register_task(task.original_function, task_name=task_name)

    yield test_broker
    await test_broker.shutdown()

@pytest.fixture
async def taskiq_with_db_session(db_session):
    """Setup TaskIQ with database session"""
    from reflector.worker.broker import broker
    broker.add_dependency_context({
        AsyncSession: db_session
    })
    yield
    broker.custom_dependency_context = {}

# Test example
@pytest.mark.anyio
async def test_task(taskiq_with_db_session):
    result = await task_pipeline_file_process("transcript-id")
    assert result is not None
```

## Migration Steps

### Phase 1: Setup (Week 1)
1. **Install TaskIQ packages**
   ```bash
   uv add taskiq taskiq-redis taskiq-pipelines
   ```

2. **Create new broker configuration**
   - Create `reflector/worker/broker.py` with TaskIQ broker setup
   - Create `reflector/worker/dependencies.py` for dependency injection

3. **Update settings**
   - Keep existing Redis configuration
   - Add TaskIQ-specific settings if needed

### Phase 2: Parallel Running (Week 2-3)
1. **Migrate simple tasks first**
   - Start with `cleanup.py` (1 task)
   - Move to `webhook.py` (1 task)
   - Test thoroughly in isolation

2. **Setup dual-mode operation**
   - Keep Celery tasks running
   - Add TaskIQ versions alongside
   - Use feature flags to switch between them (see the sketch below)
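A rough sketch of what such a feature flag could look like during the dual-mode period (the `USE_TASKIQ` flag is an assumption; the two task names follow the import pattern described later in this guide):

```python
# Hypothetical dispatch helper for the dual-mode period.
USE_TASKIQ = False  # assumption: would live in reflector.settings


async def enqueue_file_process(transcript_id: str) -> None:
    if USE_TASKIQ:
        # TaskIQ path (async enqueue)
        await task_pipeline_file_process_taskiq.kiq(transcript_id=transcript_id)
    else:
        # Celery path (sync enqueue)
        task_pipeline_file_process.delay(transcript_id=transcript_id)
```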
### Phase 3: Complex Tasks (Week 3-4)
1. **Migrate pipeline tasks**
   - Convert `main_file_pipeline.py`
   - Convert `main_live_pipeline.py` (most complex with 10 tasks)
   - Ensure chain/group/chord patterns work

2. **Migrate scheduled tasks**
   - Setup TaskIQ scheduler
   - Convert beat schedule to TaskIQ schedules
   - Test cron patterns

### Phase 4: Testing & Validation (Week 4-5)
1. **Update test suite**
   - Replace Celery fixtures with TaskIQ fixtures
   - Update all test files
   - Ensure coverage remains the same

2. **Performance testing**
   - Compare task execution times
   - Monitor Redis memory usage
   - Test under load

### Phase 5: Cutover (Week 5-6)
1. **Final migration**
   - Remove Celery dependencies
   - Update deployment scripts
   - Update documentation

2. **Monitoring**
   - Setup TaskIQ monitoring (if available)
   - Create health checks
   - Document operational procedures

## Key Differences to Note

### Advantages of TaskIQ
1. **Native async support** - No need for `@asynctask` wrapper
2. **Dependency injection** - Cleaner than decorators for session management
3. **Type hints** - Better IDE support and autocompletion
4. **Modern Python** - Designed for Python 3.7+
5. **Simpler testing** - InMemoryBroker makes testing easier

### Potential Challenges
1. **Less mature ecosystem** - Fewer third-party integrations
2. **Documentation** - Less comprehensive than Celery
3. **Monitoring tools** - No Flower equivalent (may need custom solution)
4. **Community support** - Smaller community than Celery

## Command Line Changes

### Current (Celery)
```bash
# Start worker
celery -A reflector.worker.app worker --loglevel=info

# Start beat scheduler
celery -A reflector.worker.app beat
```

### New (TaskIQ)
```bash
# Start worker
taskiq worker reflector.worker.broker:broker

# Start scheduler
taskiq scheduler reflector.worker.scheduler:scheduler

# With custom settings
taskiq worker reflector.worker.broker:broker --workers 4 --log-level INFO
```

## Rollback Plan

If issues arise during migration:

1. **Keep Celery code in version control** - Tag the last Celery version
2. **Maintain dual broker setup** - Can switch back via environment variable
3. **Database compatibility** - No schema changes required
4. **Redis compatibility** - Both use Redis, easy to switch back

## Success Criteria

1. ✅ All tasks migrated and functioning
2. ✅ Test coverage maintained at current levels
3. ✅ Performance equal or better than Celery
4. ✅ Scheduled tasks running reliably
5. ✅ Error handling and retries working correctly
6. ✅ WebSocket notifications still functioning
7. ✅ Pipeline processing maintaining same behavior

## Monitoring & Operations

### Health Checks
```python
# reflector/worker/healthcheck.py
@broker.task
async def healthcheck_ping():
    """TaskIQ health check task"""
    return {"status": "healthy", "timestamp": datetime.now()}
```

### Metrics Collection
- Task execution times
- Success/failure rates
- Queue depths
- Worker utilization
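One way to collect those metrics would be a small TaskIQ middleware, sketched here under the assumption that the upstream `TaskiqMiddleware` hooks behave as documented (this is not existing Reflector code):

```python
# Hypothetical metrics middleware; attach via broker.with_middlewares(MetricsMiddleware()).
import time

from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult


class MetricsMiddleware(TaskiqMiddleware):
    def __init__(self) -> None:
        super().__init__()
        self._started: dict[str, float] = {}

    def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
        self._started[message.task_id] = time.monotonic()
        return message

    def post_execute(self, message: TaskiqMessage, result: TaskiqResult) -> None:
        started = self._started.pop(message.task_id, time.monotonic())
        duration = time.monotonic() - started
        # Replace print with your metrics client (StatsD, Prometheus, ...)
        print(f"{message.task_name} took {duration:.3f}s, error={result.is_err}")
```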
## Key Implementation Points - MUST READ
|
|
||||||
|
|
||||||
### Critical Changes Required
|
|
||||||
|
|
||||||
1. **Session Management in Tasks**
|
|
||||||
- ✅ **VERIFIED**: Tasks MUST use `get_session()` from `reflector.db` for test compatibility
|
|
||||||
- ❌ Do NOT use `get_session_factory()` directly in tasks - it bypasses test mocks
|
|
||||||
- ✅ The test database session IS properly shared when using `get_session()`
|
|
||||||
|
|
||||||
2. **Task Invocation Changes**
|
|
||||||
- Replace `.delay()` with `await .kiq()`
|
|
||||||
- All task invocations become async/await
|
|
||||||
- No need to commit sessions before task invocation (controllers handle this)
|
|
||||||
|
|
||||||
3. **Broker Configuration**
|
|
||||||
- TaskIQ broker must be initialized in `worker/app.py`
|
|
||||||
- Use `InMemoryBroker(await_inplace=True)` for testing
|
|
||||||
- Use `RedisStreamBroker` for production
|
|
||||||
|
|
||||||
4. **Test Setup Requirements**
|
|
||||||
- Set `os.environ["ENVIRONMENT"] = "pytest"` at top of test files
|
|
||||||
- Add TaskIQ broker fixture to test functions
|
|
||||||
- Keep Celery fixtures for now (dual-mode operation)
|
|
||||||
|
|
||||||
5. **Import Pattern Changes**
|
|
||||||
```python
|
|
||||||
# Each file needs both imports during migration
|
|
||||||
from reflector.pipelines.main_file_pipeline import (
|
|
||||||
task_pipeline_file_process, # Celery version
|
|
||||||
task_pipeline_file_process_taskiq, # TaskIQ version
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
6. **Decorator Changes**
|
|
||||||
- Remove `@asynctask` - TaskIQ is async-native
|
|
||||||
- **Keep `@with_session`** - it works with TaskIQ tasks!
|
|
||||||
- Remove `@shared_task` from TaskIQ version
|
|
||||||
- Keep `@shared_task` on Celery version for backward compatibility
|
|
||||||
|
|
||||||
## Verified POC Results
|
|
||||||
|
|
||||||
✅ **Database transactions work correctly** across test and TaskIQ tasks
|
|
||||||
✅ **Tasks execute immediately** in tests with `InMemoryBroker(await_inplace=True)`
|
|
||||||
✅ **Session mocking works** when using `get_session()` properly
|
|
||||||
✅ **"OK" output confirmed** - TaskIQ task executes and accesses test data
|
|
||||||
|
|
||||||
## Conclusion
|
|
||||||
|
|
||||||
The migration from Celery to TaskIQ is feasible and offers several advantages for an async-first codebase like Reflector. The key challenges will be:
|
|
||||||
|
|
||||||
1. Migrating complex pipeline patterns (chain/chord)
|
|
||||||
2. Ensuring scheduled task reliability
|
|
||||||
3. **SOLVED**: Maintaining session management patterns - use `get_session()`
|
|
||||||
4. Updating the test suite
|
|
||||||
|
|
||||||
The phased approach allows for gradual migration with minimal risk. The ability to run both systems in parallel provides a safety net during the transition period.
|
|
||||||
|
|
||||||

## Appendix: Quick Reference

| Celery | TaskIQ |
|--------|--------|
| `@shared_task` | `@broker.task` |
| `.delay()` | `.kiq()` |
| `.apply_async()` | `.kicker().kiq()` |
| `chain()` | `Pipeline()` |
| `group()` | `gather()` |
| `chord()` | `gather() + callback` |
| `@task.retry()` | `retry_on_error=True` |
| Celery Beat | TaskIQ Scheduler |
| `celery worker` | `taskiq worker` |
| Flower | Custom monitoring needed |
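
As a concrete illustration of the first two rows (the task name and argument are hypothetical):

```python
# Celery (before)
@shared_task
def process_transcript(transcript_id):
    ...

process_transcript.delay(transcript_id)


# TaskIQ (after)
@broker.task
async def process_transcript(transcript_id: str) -> None:
    ...

await process_transcript.kiq(transcript_id)
```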
@@ -3,7 +3,7 @@ from logging.config import fileConfig
 from alembic import context
 from sqlalchemy import engine_from_config, pool
 
-from reflector.db.base import metadata
+from reflector.db import metadata
 from reflector.settings import settings
 
 # this is the Alembic Config object, which provides
@@ -1,36 +0,0 @@
-"""Add webhook fields to rooms
-
-Revision ID: 0194f65cd6d3
-Revises: 5a8907fd1d78
-Create Date: 2025-08-27 09:03:19.610995
-
-"""
-
-from typing import Sequence, Union
-
-import sqlalchemy as sa
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision: str = "0194f65cd6d3"
-down_revision: Union[str, None] = "5a8907fd1d78"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table("room", schema=None) as batch_op:
-        batch_op.add_column(sa.Column("webhook_url", sa.String(), nullable=True))
-        batch_op.add_column(sa.Column("webhook_secret", sa.String(), nullable=True))
-
-    # ### end Alembic commands ###
-
-
-def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table("room", schema=None) as batch_op:
-        batch_op.drop_column("webhook_secret")
-        batch_op.drop_column("webhook_url")
-
-    # ### end Alembic commands ###
@@ -1,68 +0,0 @@
|
|||||||
"""add_long_summary_to_search_vector
|
|
||||||
|
|
||||||
Revision ID: 0ab2d7ffaa16
|
|
||||||
Revises: b1c33bd09963
|
|
||||||
Create Date: 2025-08-15 13:27:52.680211
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "0ab2d7ffaa16"
|
|
||||||
down_revision: Union[str, None] = "b1c33bd09963"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Drop the existing search vector column and index
|
|
||||||
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
|
|
||||||
op.drop_column("transcript", "search_vector_en")
|
|
||||||
|
|
||||||
# Recreate the search vector column with long_summary included
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
|
||||||
GENERATED ALWAYS AS (
|
|
||||||
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
|
||||||
setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') ||
|
|
||||||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')
|
|
||||||
) STORED
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Recreate the GIN index for the search vector
|
|
||||||
op.create_index(
|
|
||||||
"idx_transcript_search_vector_en",
|
|
||||||
"transcript",
|
|
||||||
["search_vector_en"],
|
|
||||||
postgresql_using="gin",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# Drop the updated search vector column and index
|
|
||||||
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
|
|
||||||
op.drop_column("transcript", "search_vector_en")
|
|
||||||
|
|
||||||
# Recreate the original search vector column without long_summary
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
|
||||||
GENERATED ALWAYS AS (
|
|
||||||
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
|
||||||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
|
|
||||||
) STORED
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Recreate the GIN index for the search vector
|
|
||||||
op.create_index(
|
|
||||||
"idx_transcript_search_vector_en",
|
|
||||||
"transcript",
|
|
||||||
["search_vector_en"],
|
|
||||||
postgresql_using="gin",
|
|
||||||
)
|
|
||||||
@@ -1,36 +0,0 @@
-"""remove user_id from meeting table
-
-Revision ID: 0ce521cda2ee
-Revises: 6dec9fb5b46c
-Create Date: 2025-09-10 12:40:55.688899
-
-"""
-
-from typing import Sequence, Union
-
-import sqlalchemy as sa
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision: str = "0ce521cda2ee"
-down_revision: Union[str, None] = "6dec9fb5b46c"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table("meeting", schema=None) as batch_op:
-        batch_op.drop_column("user_id")
-
-    # ### end Alembic commands ###
-
-
-def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table("meeting", schema=None) as batch_op:
-        batch_op.add_column(
-            sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True)
-        )
-
-    # ### end Alembic commands ###
@@ -21,15 +21,13 @@ def upgrade() -> None:
     if conn.dialect.name != "postgresql":
         return
 
-    op.execute(
-        """
+    op.execute("""
         ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
         GENERATED ALWAYS AS (
             setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
             setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
         ) STORED
-        """
-    )
+    """)
 
     op.create_index(
         "idx_transcript_search_vector_en",
@@ -1,34 +0,0 @@
-"""clean up orphaned room_id references in meeting table
-
-Revision ID: 2ae3db106d4e
-Revises: def1b5867d4c
-Create Date: 2025-09-11 10:35:15.759967
-
-"""
-
-from typing import Sequence, Union
-
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision: str = "2ae3db106d4e"
-down_revision: Union[str, None] = "def1b5867d4c"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    # Set room_id to NULL for meetings that reference non-existent rooms
-    op.execute(
-        """
-        UPDATE meeting
-        SET room_id = NULL
-        WHERE room_id IS NOT NULL
-        AND room_id NOT IN (SELECT id FROM room WHERE id IS NOT NULL)
-        """
-    )
-
-
-def downgrade() -> None:
-    # Cannot restore orphaned references - no operation needed
-    pass
@@ -28,7 +28,7 @@ def upgrade() -> None:
     transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
 
     # Select all rows from the transcript table
-    results = bind.execute(select(transcript.c.id, transcript.c.topics))
+    results = bind.execute(select([transcript.c.id, transcript.c.topics]))
 
     for row in results:
         transcript_id = row["id"]
@@ -58,7 +58,7 @@ def downgrade() -> None:
     transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
 
     # Select all rows from the transcript table
-    results = bind.execute(select(transcript.c.id, transcript.c.topics))
+    results = bind.execute(select([transcript.c.id, transcript.c.topics]))
 
     for row in results:
         transcript_id = row["id"]
@@ -36,7 +36,9 @@ def upgrade() -> None:
 
     # select only the one with duration = 0
     results = bind.execute(
-        select(transcript.c.id, transcript.c.duration).where(transcript.c.duration == 0)
+        select([transcript.c.id, transcript.c.duration]).where(
+            transcript.c.duration == 0
+        )
     )
 
     data_dir = Path(settings.DATA_DIR)
@@ -1,50 +0,0 @@
|
|||||||
"""add cascade delete to meeting consent foreign key
|
|
||||||
|
|
||||||
Revision ID: 5a8907fd1d78
|
|
||||||
Revises: 0ab2d7ffaa16
|
|
||||||
Create Date: 2025-08-26 17:26:50.945491
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "5a8907fd1d78"
|
|
||||||
down_revision: Union[str, None] = "0ab2d7ffaa16"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
|
|
||||||
batch_op.drop_constraint(
|
|
||||||
batch_op.f("meeting_consent_meeting_id_fkey"), type_="foreignkey"
|
|
||||||
)
|
|
||||||
batch_op.create_foreign_key(
|
|
||||||
batch_op.f("meeting_consent_meeting_id_fkey"),
|
|
||||||
"meeting",
|
|
||||||
["meeting_id"],
|
|
||||||
["id"],
|
|
||||||
ondelete="CASCADE",
|
|
||||||
)
|
|
||||||
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
|
|
||||||
batch_op.drop_constraint(
|
|
||||||
batch_op.f("meeting_consent_meeting_id_fkey"), type_="foreignkey"
|
|
||||||
)
|
|
||||||
batch_op.create_foreign_key(
|
|
||||||
batch_op.f("meeting_consent_meeting_id_fkey"),
|
|
||||||
"meeting",
|
|
||||||
["meeting_id"],
|
|
||||||
["id"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
@@ -1,7 +1,7 @@
 """remove_one_active_meeting_per_room_constraint
 
 Revision ID: 6025e9b2bef2
-Revises: 2ae3db106d4e
+Revises: 9f5c78d352d6
 Create Date: 2025-08-18 18:45:44.418392
 
 """
@@ -13,7 +13,7 @@ from alembic import op
 
 # revision identifiers, used by Alembic.
 revision: str = "6025e9b2bef2"
-down_revision: Union[str, None] = "2ae3db106d4e"
+down_revision: Union[str, None] = "9f5c78d352d6"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
 
@@ -1,28 +0,0 @@
-"""webhook url and secret null by default
-
-
-Revision ID: 61882a919591
-Revises: 0194f65cd6d3
-Create Date: 2025-08-29 11:46:36.738091
-
-"""
-
-from typing import Sequence, Union
-
-# revision identifiers, used by Alembic.
-revision: str = "61882a919591"
-down_revision: Union[str, None] = "0194f65cd6d3"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    pass
-    # ### end Alembic commands ###
-
-
-def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    pass
-    # ### end Alembic commands ###
@@ -1,35 +0,0 @@
-"""make meeting room_id required and add foreign key
-
-Revision ID: 6dec9fb5b46c
-Revises: 61882a919591
-Create Date: 2025-09-10 10:47:06.006819
-
-"""
-
-from typing import Sequence, Union
-
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision: str = "6dec9fb5b46c"
-down_revision: Union[str, None] = "61882a919591"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table("meeting", schema=None) as batch_op:
-        batch_op.create_foreign_key(
-            None, "room", ["room_id"], ["id"], ondelete="CASCADE"
-        )
-
-    # ### end Alembic commands ###
-
-
-def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table("meeting", schema=None) as batch_op:
-        batch_op.drop_constraint("meeting_room_id_fkey", type_="foreignkey")
-
-    # ### end Alembic commands ###
@@ -28,7 +28,7 @@ def upgrade() -> None:
     transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
 
     # Select all rows from the transcript table
-    results = bind.execute(select(transcript.c.id, transcript.c.topics))
+    results = bind.execute(select([transcript.c.id, transcript.c.topics]))
 
     for row in results:
         transcript_id = row["id"]
@@ -58,7 +58,7 @@ def downgrade() -> None:
     transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
 
     # Select all rows from the transcript table
-    results = bind.execute(select(transcript.c.id, transcript.c.topics))
+    results = bind.execute(select([transcript.c.id, transcript.c.topics]))
 
     for row in results:
         transcript_id = row["id"]
@@ -1,41 +0,0 @@
-"""add_search_optimization_indexes
-
-Revision ID: b1c33bd09963
-Revises: 9f5c78d352d6
-Create Date: 2025-08-14 17:26:02.117408
-
-"""
-
-from typing import Sequence, Union
-
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision: str = "b1c33bd09963"
-down_revision: Union[str, None] = "9f5c78d352d6"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    # Add indexes for actual search filtering patterns used in frontend
-    # Based on /browse page filters: room_id and source_kind
-
-    # Index for room_id + created_at (for room-specific searches with date ordering)
-    op.create_index(
-        "idx_transcript_room_id_created_at",
-        "transcript",
-        ["room_id", "created_at"],
-        if_not_exists=True,
-    )
-
-    # Index for source_kind alone (actively used filter in frontend)
-    op.create_index(
-        "idx_transcript_source_kind", "transcript", ["source_kind"], if_not_exists=True
-    )
-
-
-def downgrade() -> None:
-    # Remove the indexes in reverse order
-    op.drop_index("idx_transcript_source_kind", "transcript", if_exists=True)
-    op.drop_index("idx_transcript_room_id_created_at", "transcript", if_exists=True)
@@ -27,8 +27,7 @@ def upgrade() -> None:
 
     # Populate room_id for existing ROOM-type transcripts
    # This joins through recording -> meeting -> room to get the room_id
-    op.execute(
-        """
+    op.execute("""
         UPDATE transcript AS t
         SET room_id = r.id
         FROM recording rec
@@ -37,13 +36,11 @@ def upgrade() -> None:
         WHERE t.recording_id = rec.id
         AND t.source_kind = 'room'
         AND t.room_id IS NULL
-        """
-    )
+    """)
 
     # Fix missing meeting_id for ROOM-type transcripts
     # The meeting_id field exists but was never populated
-    op.execute(
-        """
+    op.execute("""
         UPDATE transcript AS t
         SET meeting_id = rec.meeting_id
         FROM recording rec
@@ -51,8 +48,7 @@ def upgrade() -> None:
         AND t.source_kind = 'room'
         AND t.meeting_id IS NULL
         AND rec.meeting_id IS NOT NULL
-        """
-    )
+    """)
 
 
 def downgrade() -> None:
@@ -1,129 +0,0 @@
|
|||||||
"""add calendar
|
|
||||||
|
|
||||||
Revision ID: d8e204bbf615
|
|
||||||
Revises: d4a1c446458c
|
|
||||||
Create Date: 2025-09-10 19:56:22.295756
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "d8e204bbf615"
|
|
||||||
down_revision: Union[str, None] = "d4a1c446458c"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table(
|
|
||||||
"calendar_event",
|
|
||||||
sa.Column("id", sa.String(), nullable=False),
|
|
||||||
sa.Column("room_id", sa.String(), nullable=False),
|
|
||||||
sa.Column("ics_uid", sa.Text(), nullable=False),
|
|
||||||
sa.Column("title", sa.Text(), nullable=True),
|
|
||||||
sa.Column("description", sa.Text(), nullable=True),
|
|
||||||
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column("attendees", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
|
||||||
sa.Column("location", sa.Text(), nullable=True),
|
|
||||||
sa.Column("ics_raw_data", sa.Text(), nullable=True),
|
|
||||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"is_deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
|
||||||
),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["room_id"],
|
|
||||||
["room.id"],
|
|
||||||
name="fk_calendar_event_room_id",
|
|
||||||
ondelete="CASCADE",
|
|
||||||
),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
|
|
||||||
)
|
|
||||||
with op.batch_alter_table("calendar_event", schema=None) as batch_op:
|
|
||||||
batch_op.create_index(
|
|
||||||
"idx_calendar_event_deleted",
|
|
||||||
["is_deleted"],
|
|
||||||
unique=False,
|
|
||||||
postgresql_where=sa.text("NOT is_deleted"),
|
|
||||||
)
|
|
||||||
batch_op.create_index(
|
|
||||||
"idx_calendar_event_room_start", ["room_id", "start_time"], unique=False
|
|
||||||
)
|
|
||||||
|
|
||||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
|
||||||
batch_op.add_column(sa.Column("calendar_event_id", sa.String(), nullable=True))
|
|
||||||
batch_op.add_column(
|
|
||||||
sa.Column(
|
|
||||||
"calendar_metadata",
|
|
||||||
postgresql.JSONB(astext_type=sa.Text()),
|
|
||||||
nullable=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
batch_op.create_index(
|
|
||||||
"idx_meeting_calendar_event", ["calendar_event_id"], unique=False
|
|
||||||
)
|
|
||||||
batch_op.create_foreign_key(
|
|
||||||
"fk_meeting_calendar_event_id",
|
|
||||||
"calendar_event",
|
|
||||||
["calendar_event_id"],
|
|
||||||
["id"],
|
|
||||||
ondelete="SET NULL",
|
|
||||||
)
|
|
||||||
|
|
||||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
|
||||||
batch_op.add_column(sa.Column("ics_url", sa.Text(), nullable=True))
|
|
||||||
batch_op.add_column(
|
|
||||||
sa.Column(
|
|
||||||
"ics_fetch_interval", sa.Integer(), server_default="300", nullable=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
batch_op.add_column(
|
|
||||||
sa.Column(
|
|
||||||
"ics_enabled",
|
|
||||||
sa.Boolean(),
|
|
||||||
server_default=sa.text("false"),
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
batch_op.add_column(
|
|
||||||
sa.Column("ics_last_sync", sa.DateTime(timezone=True), nullable=True)
|
|
||||||
)
|
|
||||||
batch_op.add_column(sa.Column("ics_last_etag", sa.Text(), nullable=True))
|
|
||||||
batch_op.create_index("idx_room_ics_enabled", ["ics_enabled"], unique=False)
|
|
||||||
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
|
||||||
batch_op.drop_index("idx_room_ics_enabled")
|
|
||||||
batch_op.drop_column("ics_last_etag")
|
|
||||||
batch_op.drop_column("ics_last_sync")
|
|
||||||
batch_op.drop_column("ics_enabled")
|
|
||||||
batch_op.drop_column("ics_fetch_interval")
|
|
||||||
batch_op.drop_column("ics_url")
|
|
||||||
|
|
||||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
|
||||||
batch_op.drop_constraint("fk_meeting_calendar_event_id", type_="foreignkey")
|
|
||||||
batch_op.drop_index("idx_meeting_calendar_event")
|
|
||||||
batch_op.drop_column("calendar_metadata")
|
|
||||||
batch_op.drop_column("calendar_event_id")
|
|
||||||
|
|
||||||
with op.batch_alter_table("calendar_event", schema=None) as batch_op:
|
|
||||||
batch_op.drop_index("idx_calendar_event_room_start")
|
|
||||||
batch_op.drop_index(
|
|
||||||
"idx_calendar_event_deleted", postgresql_where=sa.text("NOT is_deleted")
|
|
||||||
)
|
|
||||||
|
|
||||||
op.drop_table("calendar_event")
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
@@ -1,43 +0,0 @@
-"""remove_grace_period_fields
-
-Revision ID: dc035ff72fd5
-Revises: d8e204bbf615
-Create Date: 2025-09-11 10:36:45.197588
-
-"""
-
-from typing import Sequence, Union
-
-import sqlalchemy as sa
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision: str = "dc035ff72fd5"
-down_revision: Union[str, None] = "d8e204bbf615"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    # Remove grace period columns from meeting table
-    op.drop_column("meeting", "last_participant_left_at")
-    op.drop_column("meeting", "grace_period_minutes")
-
-
-def downgrade() -> None:
-    # Add back grace period columns to meeting table
-    op.add_column(
-        "meeting",
-        sa.Column(
-            "last_participant_left_at", sa.DateTime(timezone=True), nullable=True
-        ),
-    )
-    op.add_column(
-        "meeting",
-        sa.Column(
-            "grace_period_minutes",
-            sa.Integer(),
-            server_default=sa.text("15"),
-            nullable=True,
-        ),
-    )
@@ -1,34 +0,0 @@
-"""make meeting room_id nullable but keep foreign key
-
-Revision ID: def1b5867d4c
-Revises: 0ce521cda2ee
-Create Date: 2025-09-11 09:42:18.697264
-
-"""
-
-from typing import Sequence, Union
-
-import sqlalchemy as sa
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision: str = "def1b5867d4c"
-down_revision: Union[str, None] = "0ce521cda2ee"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table("meeting", schema=None) as batch_op:
-        batch_op.alter_column("room_id", existing_type=sa.VARCHAR(), nullable=True)
-
-    # ### end Alembic commands ###
-
-
-def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table("meeting", schema=None) as batch_op:
-        batch_op.alter_column("room_id", existing_type=sa.VARCHAR(), nullable=False)
-
-    # ### end Alembic commands ###
@@ -12,6 +12,7 @@ dependencies = [
     "requests>=2.31.0",
     "aiortc>=1.5.0",
     "sortedcontainers>=2.4.0",
+    "loguru>=0.7.0",
     "pydantic-settings>=2.0.2",
     "structlog>=23.1.0",
     "uvicorn[standard]>=0.23.1",
@@ -19,16 +20,19 @@ dependencies = [
     "sentry-sdk[fastapi]>=1.29.2",
     "httpx>=0.24.1",
     "fastapi-pagination>=0.12.6",
-    "sqlalchemy>=2.0.0",
-    "asyncpg>=0.29.0",
+    "databases[aiosqlite, asyncpg]>=0.7.0",
+    "sqlalchemy<1.5",
     "alembic>=1.11.3",
     "nltk>=3.8.1",
     "prometheus-fastapi-instrumentator>=6.1.0",
     "sentencepiece>=0.1.99",
     "protobuf>=4.24.3",
+    "profanityfilter>=2.0.6",
+    "celery>=5.3.4",
     "redis>=5.0.1",
     "python-jose[cryptography]>=3.3.0",
     "python-multipart>=0.0.6",
+    "faster-whisper>=0.10.0",
     "transformers>=4.36.2",
     "jsonschema>=4.23.0",
     "openai>=1.59.7",
@@ -38,8 +42,6 @@ dependencies = [
     "pytest-env>=1.1.5",
     "webvtt-py>=0.5.0",
     "icalendar>=6.0.0",
-    "taskiq>=0.11.18",
-    "taskiq-redis>=1.1.0",
 ]
 
 [dependency-groups]
@@ -47,7 +49,6 @@ dev = [
     "black>=24.1.1",
     "stamina>=23.1.0",
     "pyinstrument>=4.6.1",
-    "pytest-async-sqlalchemy>=0.2.0",
 ]
 tests = [
     "pytest-cov>=4.1.0",
@@ -56,7 +57,7 @@ tests = [
     "pytest>=7.4.0",
     "httpx-ws>=0.4.1",
     "pytest-httpx>=0.23.1",
-    "pytest-recording>=0.13.4",
+    "pytest-celery>=0.0.0",
     "pytest-docker>=3.2.3",
     "asgi-lifespan>=2.1.0",
 ]
@@ -67,15 +68,6 @@ evaluation = [
     "tqdm>=4.66.0",
     "pydantic>=2.1.1",
 ]
-local = [
-    "pyannote-audio>=3.3.2",
-    "faster-whisper>=0.10.0",
-]
-silero-vad = [
-    "silero-vad>=5.1.2",
-    "torch>=2.8.0",
-    "torchaudio>=2.8.0",
-]
 
 [tool.uv]
 default-groups = [
@@ -83,21 +75,6 @@ default-groups = [
     "tests",
     "aws",
     "evaluation",
-    "local",
-    "silero-vad"
-]
-
-[[tool.uv.index]]
-name = "pytorch-cpu"
-url = "https://download.pytorch.org/whl/cpu"
-explicit = true
-
-[tool.uv.sources]
-torch = [
-    { index = "pytorch-cpu" },
-]
-torchaudio = [
-    { index = "pytorch-cpu" },
 ]
 
 [build-system]
@@ -112,18 +89,12 @@ source = ["reflector"]
 
 [tool.pytest_env]
 ENVIRONMENT = "pytest"
-DATABASE_URL = "postgresql+asyncpg://test_user:test_password@localhost:15432/reflector_test"
+DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
 
 [tool.pytest.ini_options]
 addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
 testpaths = ["tests"]
 asyncio_mode = "auto"
-asyncio_debug = true
-asyncio_default_fixture_loop_scope = "session"
-asyncio_default_test_loop_scope = "session"
-markers = [
-    "model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
-]
 
 [tool.ruff.lint]
 select = [
@@ -134,7 +105,7 @@ select = [
 
 [tool.ruff.lint.per-file-ignores]
 "reflector/processors/summary/summary_builder.py" = ["E501"]
-"gpu/modal_deployments/**.py" = ["PLC0415"]
+"gpu/**.py" = ["PLC0415"]
 "reflector/tools/**.py" = ["PLC0415"]
 "migrations/versions/**.py" = ["PLC0415"]
 "tests/**.py" = ["PLC0415"]
@@ -88,8 +88,8 @@ app.include_router(zulip_router, prefix="/v1")
 app.include_router(whereby_router, prefix="/v1")
 add_pagination(app)
 
-# prepare taskiq
-from reflector.worker import app as taskiq_app  # noqa
+# prepare celery
+from reflector.worker import app as celery_app  # noqa
 
 
 # simpler openapi id
@@ -67,8 +67,7 @@ def current_user(
     try:
         payload = jwtauth.verify_token(token)
         sub = payload["sub"]
-        email = payload["email"]
-        return UserInfo(sub=sub, email=email)
+        return UserInfo(sub=sub)
     except JWTError as e:
         logger.error(f"JWT error: {e}")
         raise HTTPException(status_code=401, detail="Invalid authentication")
@@ -1,82 +1,48 @@
-from contextlib import asynccontextmanager
-from typing import AsyncGenerator
+import contextvars
+from typing import Optional
 
-from sqlalchemy.ext.asyncio import (
-    AsyncEngine,
-    AsyncSession,
-    async_sessionmaker,
-    create_async_engine,
-)
-
-from reflector.db.base import Base as Base
-from reflector.db.base import metadata as metadata
+import databases
+import sqlalchemy
 
 from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.settings import settings
 
-_engine: AsyncEngine | None = None
-_session_factory: async_sessionmaker[AsyncSession] | None = None
+metadata = sqlalchemy.MetaData()
+
+_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
+    contextvars.ContextVar("database", default=None)
+)
 
 
-def get_engine() -> AsyncEngine:
-    global _engine
-    if _engine is None:
-        _engine = create_async_engine(
-            settings.DATABASE_URL,
-            echo=False,
-            pool_pre_ping=True,
-        )
-    return _engine
-
-
-def get_session_factory() -> async_sessionmaker[AsyncSession]:
-    global _session_factory
-    if _session_factory is None:
-        _session_factory = async_sessionmaker(
-            get_engine(),
-            class_=AsyncSession,
-            expire_on_commit=False,
-        )
-    return _session_factory
-
-
-async def _get_session() -> AsyncGenerator[AsyncSession, None]:
-    # necessary implementation to ease mocking on pytest
-    async with get_session_factory()() as session:
-        yield session
-
-
-async def get_session() -> AsyncGenerator[AsyncSession, None]:
-    """
-    Get a database session, fastapi dependency injection style
-    """
-    async for session in _get_session():
-        yield session
-
-
-@asynccontextmanager
-async def get_session_context():
-    """
-    Get a database session as an async context manager
-    """
-    async for session in _get_session():
-        yield session
-
-
+def get_database() -> databases.Database:
+    """Get database instance for current asyncio context"""
+    db = _database_context.get()
+    if db is None:
+        db = databases.Database(settings.DATABASE_URL)
+        _database_context.set(db)
+    return db
+
+
+# import models
 import reflector.db.calendar_events  # noqa
 import reflector.db.meetings  # noqa
 import reflector.db.recordings  # noqa
 import reflector.db.rooms  # noqa
 import reflector.db.transcripts  # noqa
 
+kwargs = {}
+if "postgres" not in settings.DATABASE_URL:
+    raise Exception("Only postgres database is supported in reflector")
+engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
+
 
 @subscribers_startup.append
 async def database_connect(_):
-    get_engine()
+    database = get_database()
+    await database.connect()
 
 
 @subscribers_shutdown.append
 async def database_disconnect(_):
-    global _engine
-    if _engine:
-        await _engine.dispose()
-        _engine = None
+    database = get_database()
+    await database.disconnect()
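
For context on the right-hand side of this hunk, callers obtain the context-local `databases` handle and run queries through it; a minimal usage sketch (the query itself is only illustrative):

```python
from reflector.db import get_database


async def count_rooms() -> int:
    # fetch_val() returns a single scalar from the shared Database instance
    return await get_database().fetch_val("SELECT COUNT(*) FROM room")
```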
@@ -1,237 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|
||||||
|
|
||||||
|
|
||||||
class Base(AsyncAttrs, DeclarativeBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptModel(Base):
|
|
||||||
__tablename__ = "transcript"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
name: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
status: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
locked: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
|
|
||||||
duration: Mapped[Optional[float]] = mapped_column(sa.Float)
|
|
||||||
created_at: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
|
||||||
title: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
short_summary: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
long_summary: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
topics: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
|
||||||
events: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
|
||||||
participants: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
|
||||||
source_language: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
target_language: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
reviewed: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
audio_location: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="local"
|
|
||||||
)
|
|
||||||
user_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
share_mode: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="private"
|
|
||||||
)
|
|
||||||
meeting_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
recording_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
zulip_message_id: Mapped[Optional[int]] = mapped_column(sa.Integer)
|
|
||||||
source_kind: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False
|
|
||||||
) # Enum will be handled separately
|
|
||||||
audio_deleted: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
|
|
||||||
room_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
webvtt: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
sa.Index("idx_transcript_recording_id", "recording_id"),
|
|
||||||
sa.Index("idx_transcript_user_id", "user_id"),
|
|
||||||
sa.Index("idx_transcript_created_at", "created_at"),
|
|
||||||
sa.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
|
||||||
sa.Index("idx_transcript_room_id", "room_id"),
|
|
||||||
sa.Index("idx_transcript_source_kind", "source_kind"),
|
|
||||||
sa.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
TranscriptModel.search_vector_en = sa.Column(
|
|
||||||
"search_vector_en",
|
|
||||||
TSVECTOR,
|
|
||||||
sa.Computed(
|
|
||||||
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
|
||||||
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
|
|
||||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
|
|
||||||
persisted=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RoomModel(Base):
|
|
||||||
__tablename__ = "room"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
name: Mapped[str] = mapped_column(sa.String, nullable=False, unique=True)
|
|
||||||
user_id: Mapped[str] = mapped_column(sa.String, nullable=False)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
zulip_auto_post: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
zulip_stream: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
zulip_topic: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
is_locked: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
room_mode: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="normal"
|
|
||||||
)
|
|
||||||
recording_type: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="cloud"
|
|
||||||
)
|
|
||||||
recording_trigger: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="automatic-2nd-participant"
|
|
||||||
)
|
|
||||||
is_shared: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
webhook_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
webhook_secret: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
ics_url: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
ics_fetch_interval: Mapped[Optional[int]] = mapped_column(
|
|
||||||
sa.Integer, server_default=sa.text("300")
|
|
||||||
)
|
|
||||||
ics_enabled: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
ics_last_sync: Mapped[Optional[datetime]] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True)
|
|
||||||
)
|
|
||||||
ics_last_etag: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
sa.Index("idx_room_is_shared", "is_shared"),
|
|
||||||
sa.Index("idx_room_ics_enabled", "ics_enabled"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MeetingModel(Base):
|
|
||||||
__tablename__ = "meeting"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
room_name: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
room_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
host_room_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
start_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
|
||||||
end_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
|
||||||
room_id: Mapped[Optional[str]] = mapped_column(
|
|
||||||
sa.String, sa.ForeignKey("room.id", ondelete="CASCADE")
|
|
||||||
)
|
|
||||||
is_locked: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
room_mode: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="normal"
|
|
||||||
)
|
|
||||||
recording_type: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="cloud"
|
|
||||||
)
|
|
||||||
recording_trigger: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="automatic-2nd-participant"
|
|
||||||
)
|
|
||||||
num_clients: Mapped[int] = mapped_column(
|
|
||||||
sa.Integer, nullable=False, server_default=sa.text("0")
|
|
||||||
)
|
|
||||||
is_active: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("true")
|
|
||||||
)
|
|
||||||
calendar_event_id: Mapped[Optional[str]] = mapped_column(
|
|
||||||
sa.String,
|
|
||||||
sa.ForeignKey(
|
|
||||||
"calendar_event.id",
|
|
||||||
ondelete="SET NULL",
|
|
||||||
name="fk_meeting_calendar_event_id",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
calendar_metadata: Mapped[Optional[dict]] = mapped_column(JSONB)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
sa.Index("idx_meeting_room_id", "room_id"),
|
|
||||||
sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MeetingConsentModel(Base):
|
|
||||||
__tablename__ = "meeting_consent"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
meeting_id: Mapped[str] = mapped_column(
|
|
||||||
sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
user_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
consent_given: Mapped[bool] = mapped_column(sa.Boolean, nullable=False)
|
|
||||||
consent_timestamp: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RecordingModel(Base):
|
|
||||||
__tablename__ = "recording"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
meeting_id: Mapped[str] = mapped_column(
|
|
||||||
sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
url: Mapped[str] = mapped_column(sa.String, nullable=False)
|
|
||||||
object_key: Mapped[str] = mapped_column(sa.String, nullable=False)
|
|
||||||
duration: Mapped[Optional[float]] = mapped_column(sa.Float)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (sa.Index("idx_recording_meeting_id", "meeting_id"),)
|
|
||||||
|
|
||||||
|
|
||||||
class CalendarEventModel(Base):
|
|
||||||
__tablename__ = "calendar_event"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
room_id: Mapped[str] = mapped_column(
|
|
||||||
sa.String, sa.ForeignKey("room.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
ics_uid: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
|
||||||
title: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
description: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
start_time: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
end_time: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
attendees: Mapped[Optional[dict]] = mapped_column(JSONB)
|
|
||||||
location: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
ics_raw_data: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
last_synced: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
is_deleted: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
metadata = Base.metadata
|
|
||||||
@@ -1,18 +1,46 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import delete, select, update
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.db.base import CalendarEventModel
|
from reflector.db import get_database, metadata
|
||||||
from reflector.utils import generate_uuid4
|
from reflector.utils import generate_uuid4
|
||||||
|
|
||||||
|
calendar_events = sa.Table(
|
||||||
|
"calendar_event",
|
||||||
|
metadata,
|
||||||
|
sa.Column("id", sa.String, primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"room_id",
|
||||||
|
sa.String,
|
||||||
|
sa.ForeignKey("room.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("ics_uid", sa.Text, nullable=False),
|
||||||
|
sa.Column("title", sa.Text),
|
||||||
|
sa.Column("description", sa.Text),
|
||||||
|
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("attendees", JSONB),
|
||||||
|
sa.Column("location", sa.Text),
|
||||||
|
sa.Column("ics_raw_data", sa.Text),
|
||||||
|
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("is_deleted", sa.Boolean, nullable=False, server_default=sa.false()),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
|
||||||
|
sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
|
||||||
|
sa.Index(
|
||||||
|
"idx_calendar_event_deleted",
|
||||||
|
"is_deleted",
|
||||||
|
postgresql_where=sa.text("NOT is_deleted"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CalendarEvent(BaseModel):
|
class CalendarEvent(BaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
room_id: str
|
room_id: str
|
||||||
ics_uid: str
|
ics_uid: str
|
||||||
@@ -30,160 +58,136 @@ class CalendarEvent(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class CalendarEventController:
|
class CalendarEventController:
|
||||||
async def get_upcoming_events(
|
|
||||||
self,
|
|
||||||
session: AsyncSession,
|
|
||||||
room_id: str,
|
|
||||||
current_time: datetime,
|
|
||||||
buffer_minutes: int = 15,
|
|
||||||
) -> list[CalendarEvent]:
|
|
||||||
buffer_time = current_time + timedelta(minutes=buffer_minutes)
|
|
||||||
|
|
||||||
query = (
|
|
||||||
select(CalendarEventModel)
|
|
||||||
.where(
|
|
||||||
sa.and_(
|
|
||||||
CalendarEventModel.room_id == room_id,
|
|
||||||
CalendarEventModel.start_time <= buffer_time,
|
|
||||||
CalendarEventModel.end_time > current_time,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.order_by(CalendarEventModel.start_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await session.execute(query)
|
|
||||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
|
||||||
|
|
||||||
async def get_by_id(
|
|
||||||
self, session: AsyncSession, event_id: str
|
|
||||||
) -> CalendarEvent | None:
|
|
||||||
query = select(CalendarEventModel).where(CalendarEventModel.id == event_id)
|
|
||||||
result = await session.execute(query)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if not row:
|
|
||||||
return None
|
|
||||||
-        return CalendarEvent.model_validate(row)
-
-    async def get_by_ics_uid(
-        self, session: AsyncSession, room_id: str, ics_uid: str
-    ) -> CalendarEvent | None:
-        query = select(CalendarEventModel).where(
-            sa.and_(
-                CalendarEventModel.room_id == room_id,
-                CalendarEventModel.ics_uid == ics_uid,
-            )
-        )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
-            return None
-        return CalendarEvent.model_validate(row)
-
-    async def upsert(
-        self, session: AsyncSession, event: CalendarEvent
-    ) -> CalendarEvent:
-        existing = await self.get_by_ics_uid(session, event.room_id, event.ics_uid)
-
-        if existing:
-            event.updated_at = datetime.now(timezone.utc)
-            query = (
-                update(CalendarEventModel)
-                .where(CalendarEventModel.id == existing.id)
-                .values(**event.model_dump(exclude={"id"}))
-            )
-            await session.execute(query)
-            await session.commit()
-            return event
-        else:
-            new_event = CalendarEventModel(**event.model_dump())
-            session.add(new_event)
-            await session.commit()
-            return event
-
-    async def delete_old_events(
-        self, session: AsyncSession, room_id: str, cutoff_date: datetime
-    ) -> int:
-        query = delete(CalendarEventModel).where(
-            sa.and_(
-                CalendarEventModel.room_id == room_id,
-                CalendarEventModel.end_time < cutoff_date,
-            )
-        )
-        result = await session.execute(query)
-        await session.commit()
-        return result.rowcount
-
-    async def delete_events_not_in_list(
-        self, session: AsyncSession, room_id: str, keep_ics_uids: list[str]
-    ) -> int:
-        if not keep_ics_uids:
-            query = delete(CalendarEventModel).where(
-                CalendarEventModel.room_id == room_id
-            )
-        else:
-            query = delete(CalendarEventModel).where(
-                sa.and_(
-                    CalendarEventModel.room_id == room_id,
-                    CalendarEventModel.ics_uid.notin_(keep_ics_uids),
-                )
-            )
-
-        result = await session.execute(query)
-        await session.commit()
-        return result.rowcount
-
-    async def get_by_room(
-        self, session: AsyncSession, room_id: str, include_deleted: bool = True
-    ) -> list[CalendarEvent]:
-        query = select(CalendarEventModel).where(CalendarEventModel.room_id == room_id)
-        if not include_deleted:
-            query = query.where(CalendarEventModel.is_deleted == False)
-        result = await session.execute(query)
-        return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
+    async def get_by_room(
+        self,
+        room_id: str,
+        include_deleted: bool = False,
+        start_after: datetime | None = None,
+        end_before: datetime | None = None,
+    ) -> list[CalendarEvent]:
+        """Get calendar events for a room."""
+        query = calendar_events.select().where(calendar_events.c.room_id == room_id)
+
+        if not include_deleted:
+            query = query.where(calendar_events.c.is_deleted == False)
+
+        if start_after:
+            query = query.where(calendar_events.c.start_time >= start_after)
+
+        if end_before:
+            query = query.where(calendar_events.c.end_time <= end_before)
+
+        query = query.order_by(calendar_events.c.start_time.asc())
+
+        results = await get_database().fetch_all(query)
+        return [CalendarEvent(**result) for result in results]
 
-    async def get_upcoming(
-        self, session: AsyncSession, room_id: str, minutes_ahead: int = 120
-    ) -> list[CalendarEvent]:
-        now = datetime.now(timezone.utc)
-        buffer_time = now + timedelta(minutes=minutes_ahead)
-        query = (
-            select(CalendarEventModel)
-            .where(
-                sa.and_(
-                    CalendarEventModel.room_id == room_id,
-                    CalendarEventModel.start_time <= buffer_time,
-                    CalendarEventModel.end_time > now,
-                    CalendarEventModel.is_deleted == False,
-                )
-            )
-            .order_by(CalendarEventModel.start_time)
-        )
-        result = await session.execute(query)
-        return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
+    async def get_upcoming(
+        self, room_id: str, minutes_ahead: int = 30
+    ) -> list[CalendarEvent]:
+        """Get upcoming events for a room within the specified minutes."""
+        now = datetime.now(timezone.utc)
+        future_time = now + timedelta(minutes=minutes_ahead)
+
+        query = (
+            calendar_events.select()
+            .where(
+                sa.and_(
+                    calendar_events.c.room_id == room_id,
+                    calendar_events.c.is_deleted == False,
+                    calendar_events.c.start_time >= now,
+                    calendar_events.c.start_time <= future_time,
+                )
+            )
+            .order_by(calendar_events.c.start_time.asc())
+        )
+
+        results = await get_database().fetch_all(query)
+        return [CalendarEvent(**result) for result in results]
+
+    async def get_by_ics_uid(self, room_id: str, ics_uid: str) -> CalendarEvent | None:
+        """Get a calendar event by its ICS UID."""
+        query = calendar_events.select().where(
+            sa.and_(
+                calendar_events.c.room_id == room_id,
+                calendar_events.c.ics_uid == ics_uid,
+            )
+        )
+        result = await get_database().fetch_one(query)
+        return CalendarEvent(**result) if result else None
+
+    async def upsert(self, event: CalendarEvent) -> CalendarEvent:
+        """Create or update a calendar event."""
+        existing = await self.get_by_ics_uid(event.room_id, event.ics_uid)
+
+        if existing:
+            # Update existing event
+            event.id = existing.id
+            event.created_at = existing.created_at
+            event.updated_at = datetime.now(timezone.utc)
+
+            query = (
+                calendar_events.update()
+                .where(calendar_events.c.id == existing.id)
+                .values(**event.model_dump())
+            )
+        else:
+            # Insert new event
+            query = calendar_events.insert().values(**event.model_dump())
+
+        await get_database().execute(query)
+        return event
 
-    async def soft_delete_missing(
-        self, session: AsyncSession, room_id: str, current_ics_uids: list[str]
-    ) -> int:
-        query = (
-            update(CalendarEventModel)
-            .where(
-                sa.and_(
-                    CalendarEventModel.room_id == room_id,
-                    (
-                        CalendarEventModel.ics_uid.notin_(current_ics_uids)
-                        if current_ics_uids
-                        else True
-                    ),
-                    CalendarEventModel.end_time > datetime.now(timezone.utc),
-                )
-            )
-            .values(is_deleted=True)
-        )
-        result = await session.execute(query)
-        await session.commit()
-        return result.rowcount
+    async def soft_delete_missing(
+        self, room_id: str, current_ics_uids: list[str]
+    ) -> int:
+        """Soft delete future events that are no longer in the calendar."""
+        now = datetime.now(timezone.utc)
+
+        # First, get the IDs of events to delete
+        select_query = calendar_events.select().where(
+            sa.and_(
+                calendar_events.c.room_id == room_id,
+                calendar_events.c.start_time > now,
+                calendar_events.c.is_deleted == False,
+                calendar_events.c.ics_uid.notin_(current_ics_uids)
+                if current_ics_uids
+                else True,
+            )
+        )
+
+        to_delete = await get_database().fetch_all(select_query)
+        delete_count = len(to_delete)
+
+        if delete_count > 0:
+            # Now update them
+            update_query = (
+                calendar_events.update()
+                .where(
+                    sa.and_(
+                        calendar_events.c.room_id == room_id,
+                        calendar_events.c.start_time > now,
+                        calendar_events.c.is_deleted == False,
+                        calendar_events.c.ics_uid.notin_(current_ics_uids)
+                        if current_ics_uids
+                        else True,
+                    )
+                )
+                .values(is_deleted=True, updated_at=now)
+            )
+
+            await get_database().execute(update_query)
+
+        return delete_count
+
+    async def delete_by_room(self, room_id: str) -> int:
+        """Hard delete all events for a room (used when room is deleted)."""
+        query = calendar_events.delete().where(calendar_events.c.room_id == room_id)
+        result = await get_database().execute(query)
+        return result.rowcount
+
+
+# Add missing import
+from datetime import timedelta
 
 calendar_events_controller = CalendarEventController()
@@ -2,18 +2,70 @@ from datetime import datetime
 from typing import Any, Literal
 
 import sqlalchemy as sa
-from pydantic import BaseModel, ConfigDict, Field
-from sqlalchemy import select, update
-from sqlalchemy.ext.asyncio import AsyncSession
+from fastapi import HTTPException
+from pydantic import BaseModel, Field
+from sqlalchemy.dialects.postgresql import JSONB
 
-from reflector.db.base import MeetingConsentModel, MeetingModel
+from reflector.db import get_database, metadata
 from reflector.db.rooms import Room
 from reflector.utils import generate_uuid4
 
+meetings = sa.Table(
+    "meeting",
+    metadata,
+    sa.Column("id", sa.String, primary_key=True),
+    sa.Column("room_name", sa.String),
+    sa.Column("room_url", sa.String),
+    sa.Column("host_room_url", sa.String),
+    sa.Column("start_date", sa.DateTime(timezone=True)),
+    sa.Column("end_date", sa.DateTime(timezone=True)),
+    sa.Column("user_id", sa.String),
+    sa.Column("room_id", sa.String),
+    sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
+    sa.Column("room_mode", sa.String, nullable=False, server_default="normal"),
+    sa.Column("recording_type", sa.String, nullable=False, server_default="cloud"),
+    sa.Column(
+        "recording_trigger",
+        sa.String,
+        nullable=False,
+        server_default="automatic-2nd-participant",
+    ),
+    sa.Column(
+        "num_clients",
+        sa.Integer,
+        nullable=False,
+        server_default=sa.text("0"),
+    ),
+    sa.Column(
+        "is_active",
+        sa.Boolean,
+        nullable=False,
+        server_default=sa.true(),
+    ),
+    sa.Column(
+        "calendar_event_id",
+        sa.String,
+        sa.ForeignKey("calendar_event.id", ondelete="SET NULL"),
+    ),
+    sa.Column("calendar_metadata", JSONB),
+    sa.Column("last_participant_left_at", sa.DateTime(timezone=True)),
+    sa.Column("grace_period_minutes", sa.Integer, server_default=sa.text("15")),
+    sa.Index("idx_meeting_room_id", "room_id"),
+    sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
+)
+
+meeting_consent = sa.Table(
+    "meeting_consent",
+    metadata,
+    sa.Column("id", sa.String, primary_key=True),
+    sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False),
+    sa.Column("user_id", sa.String),
+    sa.Column("consent_given", sa.Boolean, nullable=False),
+    sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
+)
+
 
 class MeetingConsent(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
     id: str = Field(default_factory=generate_uuid4)
     meeting_id: str
     user_id: str | None = None
@@ -22,15 +74,14 @@ class MeetingConsent(BaseModel):
 
 
 class Meeting(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
     id: str
     room_name: str
     room_url: str
     host_room_url: str
     start_date: datetime
     end_date: datetime
-    room_id: str | None
+    user_id: str | None = None
+    room_id: str | None = None
     is_locked: bool = False
     room_mode: Literal["normal", "group"] = "normal"
     recording_type: Literal["none", "local", "cloud"] = "cloud"
@@ -41,22 +92,27 @@ class Meeting(BaseModel):
     is_active: bool = True
     calendar_event_id: str | None = None
     calendar_metadata: dict[str, Any] | None = None
+    last_participant_left_at: datetime | None = None
+    grace_period_minutes: int = 15
 
 
 class MeetingController:
     async def create(
         self,
-        session: AsyncSession,
         id: str,
         room_name: str,
         room_url: str,
         host_room_url: str,
         start_date: datetime,
         end_date: datetime,
+        user_id: str,
         room: Room,
         calendar_event_id: str | None = None,
         calendar_metadata: dict[str, Any] | None = None,
     ):
+        """
+        Create a new meeting
+        """
         meeting = Meeting(
             id=id,
             room_name=room_name,
@@ -64,6 +120,7 @@ class MeetingController:
             host_room_url=host_room_url,
             start_date=start_date,
             end_date=end_date,
+            user_id=user_id,
             room_id=room.id,
             is_locked=room.is_locked,
             room_mode=room.room_mode,
@@ -72,198 +129,192 @@ class MeetingController:
             calendar_event_id=calendar_event_id,
             calendar_metadata=calendar_metadata,
         )
-        new_meeting = MeetingModel(**meeting.model_dump())
-        session.add(new_meeting)
-        await session.commit()
+        query = meetings.insert().values(**meeting.model_dump())
+        await get_database().execute(query)
         return meeting
 
-    async def get_all_active(self, session: AsyncSession) -> list[Meeting]:
-        query = select(MeetingModel).where(MeetingModel.is_active)
-        result = await session.execute(query)
-        return [Meeting.model_validate(row) for row in result.scalars().all()]
+    async def get_all_active(self) -> list[Meeting]:
+        """
+        Get active meetings.
+        """
+        query = meetings.select().where(meetings.c.is_active)
+        return await get_database().fetch_all(query)
 
     async def get_by_room_name(
         self,
-        session: AsyncSession,
         room_name: str,
-    ) -> Meeting | None:
+    ) -> Meeting:
         """
         Get a meeting by room name.
-        For backward compatibility, returns the most recent meeting.
         """
-        query = (
-            select(MeetingModel)
-            .where(MeetingModel.room_name == room_name)
-            .order_by(MeetingModel.end_date.desc())
-        )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        query = meetings.select().where(meetings.c.room_name == room_name)
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Meeting.model_validate(row)
+        return Meeting(**result)
 
-    async def get_active(
-        self, session: AsyncSession, room: Room, current_time: datetime
-    ) -> Meeting | None:
+    async def get_active(self, room: Room, current_time: datetime) -> Meeting:
         """
         Get latest active meeting for a room.
         For backward compatibility, returns the most recent active meeting.
         """
+        end_date = getattr(meetings.c, "end_date")
         query = (
-            select(MeetingModel)
+            meetings.select()
             .where(
                 sa.and_(
-                    MeetingModel.room_id == room.id,
-                    MeetingModel.end_date > current_time,
-                    MeetingModel.is_active,
+                    meetings.c.room_id == room.id,
+                    meetings.c.end_date > current_time,
+                    meetings.c.is_active,
                 )
             )
-            .order_by(MeetingModel.end_date.desc())
+            .order_by(end_date.desc())
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        result = await get_database().fetch_one(query)
+        if not result:
            return None
-        return Meeting.model_validate(row)
+        return Meeting(**result)
 
     async def get_all_active_for_room(
-        self, session: AsyncSession, room: Room, current_time: datetime
+        self, room: Room, current_time: datetime
     ) -> list[Meeting]:
+        """
+        Get all active meetings for a room.
+        This supports multiple concurrent meetings per room.
+        """
+        end_date = getattr(meetings.c, "end_date")
         query = (
-            select(MeetingModel)
+            meetings.select()
            .where(
                 sa.and_(
-                    MeetingModel.room_id == room.id,
-                    MeetingModel.end_date > current_time,
-                    MeetingModel.is_active,
+                    meetings.c.room_id == room.id,
+                    meetings.c.end_date > current_time,
+                    meetings.c.is_active,
                 )
             )
-            .order_by(MeetingModel.end_date.desc())
+            .order_by(end_date.desc())
         )
-        result = await session.execute(query)
-        return [Meeting.model_validate(row) for row in result.scalars().all()]
+        results = await get_database().fetch_all(query)
+        return [Meeting(**result) for result in results]
 
     async def get_active_by_calendar_event(
-        self,
-        session: AsyncSession,
-        room: Room,
-        calendar_event_id: str,
-        current_time: datetime,
+        self, room: Room, calendar_event_id: str, current_time: datetime
     ) -> Meeting | None:
         """
         Get active meeting for a specific calendar event.
         """
-        query = select(MeetingModel).where(
+        query = meetings.select().where(
             sa.and_(
-                MeetingModel.room_id == room.id,
-                MeetingModel.calendar_event_id == calendar_event_id,
-                MeetingModel.end_date > current_time,
-                MeetingModel.is_active,
+                meetings.c.room_id == room.id,
+                meetings.c.calendar_event_id == calendar_event_id,
+                meetings.c.end_date > current_time,
+                meetings.c.is_active,
             )
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Meeting.model_validate(row)
+        return Meeting(**result)
 
-    async def get_by_id(
-        self, session: AsyncSession, meeting_id: str, **kwargs
-    ) -> Meeting | None:
-        query = select(MeetingModel).where(MeetingModel.id == meeting_id)
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+    async def get_by_id(self, meeting_id: str, **kwargs) -> Meeting | None:
+        """
+        Get a meeting by id
+        """
+        query = meetings.select().where(meetings.c.id == meeting_id)
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Meeting.model_validate(row)
+        return Meeting(**result)
 
-    async def get_by_calendar_event(
-        self, session: AsyncSession, calendar_event_id: str
-    ) -> Meeting | None:
-        query = select(MeetingModel).where(
-            MeetingModel.calendar_event_id == calendar_event_id
+    async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Meeting:
+        """
+        Get a meeting by ID for HTTP request.
+
+        If not found, it will raise a 404 error.
+        """
+        query = meetings.select().where(meetings.c.id == meeting_id)
+        result = await get_database().fetch_one(query)
+        if not result:
+            raise HTTPException(status_code=404, detail="Meeting not found")
+
+        meeting = Meeting(**result)
+        if result["user_id"] != user_id:
+            meeting.host_room_url = ""
+
+        return meeting
+
+    async def get_by_calendar_event(self, calendar_event_id: str) -> Meeting | None:
+        query = meetings.select().where(
+            meetings.c.calendar_event_id == calendar_event_id
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Meeting.model_validate(row)
+        return Meeting(**result)
 
-    async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs):
-        query = (
-            update(MeetingModel).where(MeetingModel.id == meeting_id).values(**kwargs)
-        )
-        await session.execute(query)
-        await session.commit()
+    async def update_meeting(self, meeting_id: str, **kwargs):
+        query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
+        await get_database().execute(query)
 
 
 class MeetingConsentController:
-    async def get_by_meeting_id(
-        self, session: AsyncSession, meeting_id: str
-    ) -> list[MeetingConsent]:
-        query = select(MeetingConsentModel).where(
-            MeetingConsentModel.meeting_id == meeting_id
+    async def get_by_meeting_id(self, meeting_id: str) -> list[MeetingConsent]:
+        query = meeting_consent.select().where(
+            meeting_consent.c.meeting_id == meeting_id
         )
-        result = await session.execute(query)
-        return [MeetingConsent.model_validate(row) for row in result.scalars().all()]
+        results = await get_database().fetch_all(query)
+        return [MeetingConsent(**result) for result in results]
 
     async def get_by_meeting_and_user(
-        self, session: AsyncSession, meeting_id: str, user_id: str
+        self, meeting_id: str, user_id: str
     ) -> MeetingConsent | None:
         """Get existing consent for a specific user and meeting"""
-        query = select(MeetingConsentModel).where(
-            sa.and_(
-                MeetingConsentModel.meeting_id == meeting_id,
-                MeetingConsentModel.user_id == user_id,
-            )
+        query = meeting_consent.select().where(
+            meeting_consent.c.meeting_id == meeting_id,
+            meeting_consent.c.user_id == user_id,
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if row is None:
+        result = await get_database().fetch_one(query)
+        if result is None:
             return None
-        return MeetingConsent.model_validate(row)
+        return MeetingConsent(**result) if result else None
 
-    async def upsert(
-        self, session: AsyncSession, consent: MeetingConsent
-    ) -> MeetingConsent:
+    async def upsert(self, consent: MeetingConsent) -> MeetingConsent:
+        """Create new consent or update existing one for authenticated users"""
         if consent.user_id:
             # For authenticated users, check if consent already exists
             # not transactional but we're ok with that; the consents ain't deleted anyways
             existing = await self.get_by_meeting_and_user(
-                session, consent.meeting_id, consent.user_id
+                consent.meeting_id, consent.user_id
             )
             if existing:
                 query = (
-                    update(MeetingConsentModel)
-                    .where(MeetingConsentModel.id == existing.id)
+                    meeting_consent.update()
+                    .where(meeting_consent.c.id == existing.id)
                     .values(
                         consent_given=consent.consent_given,
                         consent_timestamp=consent.consent_timestamp,
                     )
                 )
-                await session.execute(query)
-                await session.commit()
+                await get_database().execute(query)
 
                 existing.consent_given = consent.consent_given
                 existing.consent_timestamp = consent.consent_timestamp
                 return existing
 
-        new_consent = MeetingConsentModel(**consent.model_dump())
-        session.add(new_consent)
-        await session.commit()
+        query = meeting_consent.insert().values(**consent.model_dump())
+        await get_database().execute(query)
         return consent
 
-    async def has_any_denial(self, session: AsyncSession, meeting_id: str) -> bool:
+    async def has_any_denial(self, meeting_id: str) -> bool:
         """Check if any participant denied consent for this meeting"""
-        query = select(MeetingConsentModel).where(
-            sa.and_(
-                MeetingConsentModel.meeting_id == meeting_id,
-                MeetingConsentModel.consent_given.is_(False),
-            )
+        query = meeting_consent.select().where(
+            meeting_consent.c.meeting_id == meeting_id,
+            meeting_consent.c.consent_given.is_(False),
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        return row is not None
+        result = await get_database().fetch_one(query)
+        return result is not None
 
 
 meetings_controller = MeetingController()
@@ -1,79 +1,61 @@
-from datetime import datetime, timezone
+from datetime import datetime
+from typing import Literal
 
-from pydantic import BaseModel, ConfigDict, Field
-from sqlalchemy import delete, select
-from sqlalchemy.ext.asyncio import AsyncSession
+import sqlalchemy as sa
+from pydantic import BaseModel, Field
 
-from reflector.db.base import RecordingModel
+from reflector.db import get_database, metadata
 from reflector.utils import generate_uuid4
 
+recordings = sa.Table(
+    "recording",
+    metadata,
+    sa.Column("id", sa.String, primary_key=True),
+    sa.Column("bucket_name", sa.String, nullable=False),
+    sa.Column("object_key", sa.String, nullable=False),
+    sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
+    sa.Column(
+        "status",
+        sa.String,
+        nullable=False,
+        server_default="pending",
+    ),
+    sa.Column("meeting_id", sa.String),
+    sa.Index("idx_recording_meeting_id", "meeting_id"),
+)
+
 
 class Recording(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
     id: str = Field(default_factory=generate_uuid4)
-    meeting_id: str
-    url: str
+    bucket_name: str
     object_key: str
-    duration: float | None = None
-    created_at: datetime
+    recorded_at: datetime
+    status: Literal["pending", "processing", "completed", "failed"] = "pending"
+    meeting_id: str | None = None
 
 
 class RecordingController:
-    async def create(
-        self,
-        session: AsyncSession,
-        meeting_id: str,
-        url: str,
-        object_key: str,
-        duration: float | None = None,
-        created_at: datetime | None = None,
-    ):
-        if created_at is None:
-            created_at = datetime.now(timezone.utc)
-
-        recording = Recording(
-            meeting_id=meeting_id,
-            url=url,
-            object_key=object_key,
-            duration=duration,
-            created_at=created_at,
-        )
-        new_recording = RecordingModel(**recording.model_dump())
-        session.add(new_recording)
-        await session.commit()
+    async def create(self, recording: Recording):
+        query = recordings.insert().values(**recording.model_dump())
+        await get_database().execute(query)
         return recording
 
-    async def get_by_id(
-        self, session: AsyncSession, recording_id: str
-    ) -> Recording | None:
-        """
-        Get a recording by id
-        """
-        query = select(RecordingModel).where(RecordingModel.id == recording_id)
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
-            return None
-        return Recording.model_validate(row)
+    async def get_by_id(self, id: str) -> Recording:
+        query = recordings.select().where(recordings.c.id == id)
+        result = await get_database().fetch_one(query)
+        return Recording(**result) if result else None
 
-    async def get_by_meeting_id(
-        self, session: AsyncSession, meeting_id: str
-    ) -> list[Recording]:
-        """
-        Get all recordings for a meeting
-        """
-        query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id)
-        result = await session.execute(query)
-        return [Recording.model_validate(row) for row in result.scalars().all()]
+    async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
+        query = recordings.select().where(
+            recordings.c.bucket_name == bucket_name,
+            recordings.c.object_key == object_key,
+        )
+        result = await get_database().fetch_one(query)
+        return Recording(**result) if result else None
 
-    async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None:
-        """
-        Remove a recording by id
-        """
-        query = delete(RecordingModel).where(RecordingModel.id == recording_id)
-        await session.execute(query)
-        await session.commit()
+    async def remove_by_id(self, id: str) -> None:
+        query = recordings.delete().where(recordings.c.id == id)
+        await get_database().execute(query)
 
 
 recordings_controller = RecordingController()
@@ -1,21 +1,58 @@
-import secrets
 from datetime import datetime, timezone
 from sqlite3 import IntegrityError
 from typing import Literal
 
+import sqlalchemy
 from fastapi import HTTPException
-from pydantic import BaseModel, ConfigDict, Field
-from sqlalchemy import delete, select, update
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.sql import or_
+from pydantic import BaseModel, Field
+from sqlalchemy.sql import false, or_
 
-from reflector.db.base import RoomModel
+from reflector.db import get_database, metadata
 from reflector.utils import generate_uuid4
 
+rooms = sqlalchemy.Table(
+    "room",
+    metadata,
+    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
+    sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
+    sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
+    sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
+    sqlalchemy.Column(
+        "zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
+    ),
+    sqlalchemy.Column("zulip_stream", sqlalchemy.String),
+    sqlalchemy.Column("zulip_topic", sqlalchemy.String),
+    sqlalchemy.Column(
+        "is_locked", sqlalchemy.Boolean, nullable=False, server_default=false()
+    ),
+    sqlalchemy.Column(
+        "room_mode", sqlalchemy.String, nullable=False, server_default="normal"
+    ),
+    sqlalchemy.Column(
+        "recording_type", sqlalchemy.String, nullable=False, server_default="cloud"
+    ),
+    sqlalchemy.Column(
+        "recording_trigger",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="automatic-2nd-participant",
+    ),
+    sqlalchemy.Column(
+        "is_shared", sqlalchemy.Boolean, nullable=False, server_default=false()
+    ),
+    sqlalchemy.Column("ics_url", sqlalchemy.Text),
+    sqlalchemy.Column("ics_fetch_interval", sqlalchemy.Integer, server_default="300"),
+    sqlalchemy.Column(
+        "ics_enabled", sqlalchemy.Boolean, nullable=False, server_default=false()
+    ),
+    sqlalchemy.Column("ics_last_sync", sqlalchemy.DateTime(timezone=True)),
+    sqlalchemy.Column("ics_last_etag", sqlalchemy.Text),
+    sqlalchemy.Index("idx_room_is_shared", "is_shared"),
+    sqlalchemy.Index("idx_room_ics_enabled", "ics_enabled"),
+)
+
 
 class Room(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
     id: str = Field(default_factory=generate_uuid4)
     name: str
     user_id: str
@@ -30,8 +67,6 @@ class Room(BaseModel):
         "none", "prompt", "automatic", "automatic-2nd-participant"
     ] = "automatic-2nd-participant"
     is_shared: bool = False
-    webhook_url: str | None = None
-    webhook_secret: str | None = None
     ics_url: str | None = None
     ics_fetch_interval: int = 300
     ics_enabled: bool = False
@@ -42,7 +77,6 @@ class Room(BaseModel):
 class RoomController:
     async def get_all(
         self,
-        session: AsyncSession,
         user_id: str | None = None,
         order_by: str | None = None,
         return_query: bool = False,
@@ -56,14 +90,14 @@ class RoomController:
         Parameters:
         - `order_by`: field to order by, e.g. "-created_at"
         """
-        query = select(RoomModel)
+        query = rooms.select()
         if user_id is not None:
-            query = query.where(or_(RoomModel.user_id == user_id, RoomModel.is_shared))
+            query = query.where(or_(rooms.c.user_id == user_id, rooms.c.is_shared))
         else:
-            query = query.where(RoomModel.is_shared)
+            query = query.where(rooms.c.is_shared)
 
         if order_by is not None:
-            field = getattr(RoomModel, order_by[1:])
+            field = getattr(rooms.c, order_by[1:])
             if order_by.startswith("-"):
                 field = field.desc()
             query = query.order_by(field)
@@ -71,12 +105,11 @@ class RoomController:
         if return_query:
             return query
 
-        result = await session.execute(query)
-        return [Room.model_validate(row) for row in result.scalars().all()]
+        results = await get_database().fetch_all(query)
+        return results
 
     async def add(
         self,
-        session: AsyncSession,
         name: str,
         user_id: str,
         zulip_auto_post: bool,
@@ -87,8 +120,6 @@ class RoomController:
         recording_type: str,
         recording_trigger: str,
         is_shared: bool,
-        webhook_url: str = "",
-        webhook_secret: str = "",
         ics_url: str | None = None,
         ics_fetch_interval: int = 300,
         ics_enabled: bool = False,
@@ -96,9 +127,6 @@ class RoomController:
         """
         Add a new room
         """
-        if webhook_url and not webhook_secret:
-            webhook_secret = secrets.token_urlsafe(32)
-
         room = Room(
             name=name,
             user_id=user_id,
@@ -110,33 +138,24 @@ class RoomController:
             recording_type=recording_type,
             recording_trigger=recording_trigger,
             is_shared=is_shared,
-            webhook_url=webhook_url,
-            webhook_secret=webhook_secret,
             ics_url=ics_url,
             ics_fetch_interval=ics_fetch_interval,
             ics_enabled=ics_enabled,
         )
-        new_room = RoomModel(**room.model_dump())
-        session.add(new_room)
+        query = rooms.insert().values(**room.model_dump())
         try:
-            await session.flush()
+            await get_database().execute(query)
         except IntegrityError:
             raise HTTPException(status_code=400, detail="Room name is not unique")
         return room
 
-    async def update(
-        self, session: AsyncSession, room: Room, values: dict, mutate=True
-    ):
+    async def update(self, room: Room, values: dict, mutate=True):
         """
         Update a room fields with key/values in values
         """
-        if values.get("webhook_url") and not values.get("webhook_secret"):
-            values["webhook_secret"] = secrets.token_urlsafe(32)
-
-        query = update(RoomModel).where(RoomModel.id == room.id).values(**values)
+        query = rooms.update().where(rooms.c.id == room.id).values(**values)
         try:
-            await session.execute(query)
-            await session.flush()
+            await get_database().execute(query)
         except IntegrityError:
             raise HTTPException(status_code=400, detail="Room name is not unique")
 
@@ -144,79 +163,60 @@ class RoomController:
         for key, value in values.items():
             setattr(room, key, value)
 
-    async def get_by_id(
-        self, session: AsyncSession, room_id: str, **kwargs
-    ) -> Room | None:
+    async def get_by_id(self, room_id: str, **kwargs) -> Room | None:
         """
         Get a room by id
         """
-        query = select(RoomModel).where(RoomModel.id == room_id)
+        query = rooms.select().where(rooms.c.id == room_id)
         if "user_id" in kwargs:
-            query = query.where(RoomModel.user_id == kwargs["user_id"])
-        result = await session.execute(query)
-        row = result.scalars().first()
-        if not row:
+            query = query.where(rooms.c.user_id == kwargs["user_id"])
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Room.model_validate(row)
+        return Room(**result)
 
-    async def get_by_name(
-        self, session: AsyncSession, room_name: str, **kwargs
-    ) -> Room | None:
+    async def get_by_name(self, room_name: str, **kwargs) -> Room | None:
         """
         Get a room by name
        """
-        query = select(RoomModel).where(RoomModel.name == room_name)
+        query = rooms.select().where(rooms.c.name == room_name)
         if "user_id" in kwargs:
-            query = query.where(RoomModel.user_id == kwargs["user_id"])
-        result = await session.execute(query)
-        row = result.scalars().first()
-        if not row:
+            query = query.where(rooms.c.user_id == kwargs["user_id"])
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Room.model_validate(row)
+        return Room(**result)
 
-    async def get_by_id_for_http(
-        self, session: AsyncSession, meeting_id: str, user_id: str | None
-    ) -> Room:
+    async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Room:
         """
         Get a room by ID for HTTP request.
 
         If not found, it will raise a 404 error.
         """
-        query = select(RoomModel).where(RoomModel.id == meeting_id)
-        result = await session.execute(query)
-        row = result.scalars().first()
-        if not row:
+        query = rooms.select().where(rooms.c.id == meeting_id)
+        result = await get_database().fetch_one(query)
+        if not result:
             raise HTTPException(status_code=404, detail="Room not found")
 
-        room = Room.model_validate(row)
+        room = Room(**result)
 
         return room
 
-    async def get_ics_enabled(self, session: AsyncSession) -> list[Room]:
-        query = select(RoomModel).where(
-            RoomModel.ics_enabled == True, RoomModel.ics_url != None
-        )
-        result = await session.execute(query)
-        results = result.scalars().all()
-        return [Room(**row.__dict__) for row in results]
-
     async def remove_by_id(
         self,
-        session: AsyncSession,
         room_id: str,
         user_id: str | None = None,
     ) -> None:
         """
         Remove a room by id
         """
-        room = await self.get_by_id(session, room_id, user_id=user_id)
+        room = await self.get_by_id(room_id, user_id=user_id)
         if not room:
             return
         if user_id is not None and room.user_id != user_id:
             return
-        query = delete(RoomModel).where(RoomModel.id == room_id)
-        await session.execute(query)
-        await session.flush()
+        query = rooms.delete().where(rooms.c.id == room_id)
+        await get_database().execute(query)
 
 
 rooms_controller = RoomController()
@@ -1,36 +1,22 @@
|
|||||||
"""Search functionality for transcripts and other entities."""
|
"""Search functionality for transcripts and other entities."""
|
||||||
|
|
||||||
import itertools
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Annotated, Any, Dict, Iterator
|
from typing import Annotated, Any, Dict
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
import webvtt
|
import webvtt
|
||||||
from fastapi import HTTPException
|
from pydantic import BaseModel, Field, constr, field_serializer
|
||||||
from pydantic import (
|
|
||||||
BaseModel,
|
|
||||||
Field,
|
|
||||||
NonNegativeFloat,
|
|
||||||
NonNegativeInt,
|
|
||||||
TypeAdapter,
|
|
||||||
ValidationError,
|
|
||||||
constr,
|
|
||||||
field_serializer,
|
|
||||||
)
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.db.base import RoomModel, TranscriptModel
|
from reflector.db import get_database
|
||||||
from reflector.db.transcripts import SourceKind, TranscriptStatus
|
from reflector.db.transcripts import SourceKind, transcripts
|
||||||
|
from reflector.db.utils import is_postgresql
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
|
|
||||||
|
|
||||||
DEFAULT_SEARCH_LIMIT = 20
|
DEFAULT_SEARCH_LIMIT = 20
|
||||||
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
|
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
|
||||||
DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
|
DEFAULT_SNIPPET_MAX_LENGTH = 150
|
||||||
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
|
DEFAULT_MAX_SNIPPETS = 3
|
||||||
LONG_SUMMARY_MAX_SNIPPETS = 2
|
|
||||||
|
|
||||||
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
|
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
|
||||||
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
|
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
|
||||||
@@ -38,7 +24,6 @@ SearchOffsetBase = Annotated[int, Field(ge=0)]
|
|||||||
SearchTotalBase = Annotated[int, Field(ge=0)]
|
SearchTotalBase = Annotated[int, Field(ge=0)]
|
||||||
|
|
||||||
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
|
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
|
||||||
search_query_adapter = TypeAdapter(SearchQuery)
|
|
||||||
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
|
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
|
||||||
SearchOffset = Annotated[
|
SearchOffset = Annotated[
|
||||||
SearchOffsetBase, Field(description="Number of results to skip")
|
SearchOffsetBase, Field(description="Number of results to skip")
|
||||||
@@ -47,92 +32,15 @@ SearchTotal = Annotated[
|
|||||||
SearchTotalBase, Field(description="Total number of search results")
|
SearchTotalBase, Field(description="Total number of search results")
|
||||||
]
|
]
|
||||||
|
|
||||||
WEBVTT_SPEC_HEADER = "WEBVTT"
|
|
||||||
|
|
||||||
WebVTTContent = Annotated[
|
|
||||||
str,
|
|
||||||
Field(min_length=len(WEBVTT_SPEC_HEADER), description="WebVTT content"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class WebVTTProcessor:
|
|
||||||
"""Stateless processor for WebVTT content operations."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse(raw_content: str) -> WebVTTContent:
|
|
||||||
"""Parse WebVTT content and return it as a string."""
|
|
||||||
if not raw_content.startswith(WEBVTT_SPEC_HEADER):
|
|
||||||
raise ValueError(f"Invalid WebVTT content, no header {WEBVTT_SPEC_HEADER}")
|
|
||||||
return raw_content
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_text(webvtt_content: WebVTTContent) -> str:
|
|
||||||
"""Extract plain text from WebVTT content using webvtt library."""
|
|
||||||
try:
|
|
||||||
buffer = StringIO(webvtt_content)
|
|
||||||
vtt = webvtt.read_buffer(buffer)
|
|
||||||
return " ".join(caption.text for caption in vtt if caption.text)
|
|
||||||
except webvtt.errors.MalformedFileError as e:
|
|
||||||
logger.warning(f"Malformed WebVTT content: {e}")
|
|
||||||
return ""
|
|
||||||
except (UnicodeDecodeError, ValueError) as e:
|
|
||||||
logger.warning(f"Failed to decode WebVTT content: {e}")
|
|
||||||
return ""
|
|
||||||
except AttributeError as e:
|
|
||||||
logger.error(
|
|
||||||
f"WebVTT parsing error - unexpected format: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
return ""
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error parsing WebVTT: {e}", exc_info=True)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_snippets(
|
|
||||||
webvtt_content: WebVTTContent,
|
|
||||||
query: SearchQuery,
|
|
||||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Generate snippets from WebVTT content."""
|
|
||||||
return SnippetGenerator.generate(
|
|
||||||
WebVTTProcessor.extract_text(webvtt_content),
|
|
||||||
query,
|
|
||||||
max_snippets=max_snippets,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class SnippetCandidate:
|
|
||||||
"""Represents a candidate snippet with its position."""
|
|
||||||
|
|
||||||
_text: str
|
|
||||||
start: NonNegativeInt
|
|
||||||
_original_text_length: int
|
|
||||||
|
|
||||||
@property
|
|
||||||
def end(self) -> NonNegativeInt:
|
|
||||||
"""Calculate end position from start and raw text length."""
|
|
||||||
return self.start + len(self._text)
|
|
||||||
|
|
||||||
def text(self) -> str:
|
|
||||||
"""Get display text with ellipses added if needed."""
|
|
||||||
result = self._text.strip()
|
|
||||||
if self.start > 0:
|
|
||||||
result = "..." + result
|
|
||||||
if self.end < self._original_text_length:
|
|
||||||
result = result + "..."
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class SearchParameters(BaseModel):
|
class SearchParameters(BaseModel):
|
||||||
"""Validated search parameters for full-text search."""
|
"""Validated search parameters for full-text search."""
|
||||||
|
|
||||||
query_text: SearchQuery | None = None
|
query_text: SearchQuery
|
||||||
limit: SearchLimit = DEFAULT_SEARCH_LIMIT
|
limit: SearchLimit = DEFAULT_SEARCH_LIMIT
|
||||||
offset: SearchOffset = 0
|
offset: SearchOffset = 0
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
room_id: str | None = None
|
room_id: str | None = None
|
||||||
source_kind: SourceKind | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResultDB(BaseModel):
|
class SearchResultDB(BaseModel):
|
||||||
@@ -156,18 +64,13 @@ class SearchResult(BaseModel):
|
|||||||
title: str | None = None
|
title: str | None = None
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
room_id: str | None = None
|
room_id: str | None = None
|
||||||
room_name: str | None = None
|
|
||||||
source_kind: SourceKind
|
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
status: TranscriptStatus = Field(..., min_length=1)
|
status: str = Field(..., min_length=1)
|
||||||
rank: float = Field(..., ge=0, le=1)
|
rank: float = Field(..., ge=0, le=1)
|
||||||
duration: NonNegativeFloat | None = Field(..., description="Duration in seconds")
|
duration: float | None = Field(..., ge=0, description="Duration in seconds")
|
||||||
search_snippets: list[str] = Field(
|
search_snippets: list[str] = Field(
|
||||||
description="Text snippets around search matches"
|
description="Text snippets around search matches"
|
||||||
)
|
)
|
||||||
total_match_count: NonNegativeInt = Field(
|
|
||||||
default=0, description="Total number of matches found in the transcript"
|
|
||||||
)
|
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def serialize_datetime(self, dt: datetime) -> str:
|
def serialize_datetime(self, dt: datetime) -> str:
|
||||||
@@ -176,289 +79,153 @@ class SearchResult(BaseModel):
|
|||||||
return dt.isoformat()
|
return dt.isoformat()
|
||||||
|
|
||||||
|
|
||||||
class SnippetGenerator:
|
|
||||||
"""Stateless generator for text snippets and match operations."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def find_all_matches(text: str, query: str) -> Iterator[int]:
|
|
||||||
"""Generate all match positions for a query in text."""
|
|
||||||
if not text:
|
|
||||||
logger.warning("Empty text for search query in find_all_matches")
|
|
||||||
return
|
|
||||||
if not query:
|
|
||||||
logger.warning("Empty query for search text in find_all_matches")
|
|
||||||
return
|
|
||||||
|
|
||||||
text_lower = text.lower()
|
|
||||||
query_lower = query.lower()
|
|
||||||
start = 0
|
|
||||||
prev_start = start
|
|
||||||
while (pos := text_lower.find(query_lower, start)) != -1:
|
|
||||||
yield pos
|
|
||||||
start = pos + len(query_lower)
|
|
||||||
if start <= prev_start:
|
|
||||||
raise ValueError("panic! find_all_matches is not incremental")
|
|
||||||
prev_start = start
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def count_matches(text: str, query: SearchQuery) -> NonNegativeInt:
|
|
||||||
"""Count total number of matches for a query in text."""
|
|
||||||
ZERO = NonNegativeInt(0)
|
|
||||||
if not text:
|
|
||||||
logger.warning("Empty text for search query in count_matches")
|
|
||||||
return ZERO
|
|
||||||
assert query is not None
|
|
||||||
return NonNegativeInt(
|
|
||||||
sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_snippet(
|
|
||||||
text: str, match_pos: int, max_length: int = DEFAULT_SNIPPET_MAX_LENGTH
|
|
||||||
) -> SnippetCandidate:
|
|
||||||
"""Create a snippet from a match position."""
|
|
||||||
snippet_start = NonNegativeInt(max(0, match_pos - SNIPPET_CONTEXT_LENGTH))
|
|
||||||
snippet_end = min(len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH)
|
|
||||||
|
|
||||||
snippet_text = text[snippet_start:snippet_end]
|
|
||||||
|
|
||||||
return SnippetCandidate(
|
|
||||||
_text=snippet_text, start=snippet_start, _original_text_length=len(text)
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def filter_non_overlapping(
|
|
||||||
candidates: Iterator[SnippetCandidate],
|
|
||||||
) -> Iterator[str]:
|
|
||||||
"""Filter out overlapping snippets and return only display text."""
|
|
||||||
last_end = 0
|
|
||||||
for candidate in candidates:
|
|
||||||
display_text = candidate.text()
|
|
||||||
# it means that next overlapping snippets simply don't get included
|
|
||||||
# it's fine as simplistic logic and users probably won't care much because they already have their search results just fin
|
|
||||||
if candidate.start >= last_end and display_text:
|
|
||||||
yield display_text
|
|
||||||
last_end = candidate.end
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate(
|
|
||||||
text: str,
|
|
||||||
query: SearchQuery,
|
|
||||||
max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
|
|
||||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Generate snippets from text."""
|
|
||||||
assert query is not None
|
|
||||||
if not text:
|
|
||||||
logger.warning("Empty text for generate_snippets")
|
|
||||||
return []
|
|
||||||
|
|
||||||
candidates = (
|
|
||||||
SnippetGenerator.create_snippet(text, pos, max_length)
|
|
||||||
for pos in SnippetGenerator.find_all_matches(text, query)
|
|
||||||
)
|
|
||||||
filtered = SnippetGenerator.filter_non_overlapping(candidates)
|
|
||||||
snippets = list(itertools.islice(filtered, max_snippets))
|
|
||||||
|
|
||||||
# Fallback to first word search if no full matches
|
|
||||||
# it's another assumption: proper snippet logic generation is quite complicated and tied to db logic, so simplification is used here
|
|
||||||
if not snippets and " " in query:
|
|
||||||
first_word = query.split()[0]
|
|
||||||
return SnippetGenerator.generate(text, first_word, max_length, max_snippets)
|
|
||||||
|
|
||||||
return snippets
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_summary(
|
|
||||||
summary: str,
|
|
||||||
query: SearchQuery,
|
|
||||||
max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Generate snippets from summary text."""
|
|
||||||
return SnippetGenerator.generate(summary, query, max_snippets=max_snippets)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def combine_sources(
|
|
||||||
summary: NonEmptyString | None,
|
|
||||||
webvtt: WebVTTContent | None,
|
|
||||||
query: SearchQuery,
|
|
||||||
max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
|
||||||
) -> tuple[list[str], NonNegativeInt]:
|
|
||||||
"""Combine snippets from multiple sources and return total match count.
|
|
||||||
|
|
||||||
Returns (snippets, total_match_count) tuple.
|
|
||||||
|
|
||||||
snippets can be empty for real in case of e.g. title match
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert (
|
|
||||||
summary is not None or webvtt is not None
|
|
||||||
), "At least one source must be present"
|
|
||||||
|
|
||||||
webvtt_matches = 0
|
|
||||||
summary_matches = 0
|
|
||||||
|
|
||||||
if webvtt:
|
|
||||||
webvtt_text = WebVTTProcessor.extract_text(webvtt)
|
|
||||||
webvtt_matches = SnippetGenerator.count_matches(webvtt_text, query)
|
|
||||||
|
|
||||||
if summary:
|
|
||||||
summary_matches = SnippetGenerator.count_matches(summary, query)
|
|
||||||
|
|
||||||
total_matches = NonNegativeInt(webvtt_matches + summary_matches)
|
|
||||||
|
|
||||||
summary_snippets = (
|
|
||||||
SnippetGenerator.from_summary(summary, query) if summary else []
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(summary_snippets) >= max_total:
|
|
||||||
return summary_snippets[:max_total], total_matches
|
|
||||||
|
|
||||||
remaining = max_total - len(summary_snippets)
|
|
||||||
webvtt_snippets = (
|
|
||||||
WebVTTProcessor.generate_snippets(webvtt, query, remaining)
|
|
||||||
if webvtt
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
|
|
||||||
return summary_snippets + webvtt_snippets, total_matches
|
|
||||||
|
|
||||||
|
|
||||||
class SearchController:
    """Controller for search operations across different entities."""

+    @staticmethod
+    def _extract_webvtt_text(webvtt_content: str) -> str:
+        """Extract plain text from WebVTT content using webvtt library."""
+        if not webvtt_content:
+            return ""
+
+        try:
+            buffer = StringIO(webvtt_content)
+            vtt = webvtt.read_buffer(buffer)
+            return " ".join(caption.text for caption in vtt if caption.text)
+        except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
+            logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
+            return ""
+        except AttributeError as e:
+            logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
+            return ""
+
+    @staticmethod
+    def _generate_snippets(
+        text: str,
+        q: SearchQuery,
+        max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
+        max_snippets: int = DEFAULT_MAX_SNIPPETS,
+    ) -> list[str]:
+        """Generate multiple snippets around all occurrences of search term."""
+        if not text or not q:
+            return []
+
+        snippets = []
+        lower_text = text.lower()
+        search_lower = q.lower()
+
+        last_snippet_end = 0
+        start_pos = 0
+
+        while len(snippets) < max_snippets:
+            match_pos = lower_text.find(search_lower, start_pos)
+
+            if match_pos == -1:
+                if not snippets and search_lower.split():
+                    first_word = search_lower.split()[0]
+                    match_pos = lower_text.find(first_word, start_pos)
+                    if match_pos == -1:
+                        break
+                else:
+                    break
+
+            snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
+            snippet_end = min(
+                len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
+            )
+
+            if snippet_start < last_snippet_end:
+                start_pos = match_pos + len(search_lower)
+                continue
+
+            snippet = text[snippet_start:snippet_end]
+
+            if snippet_start > 0:
+                snippet = "..." + snippet
+            if snippet_end < len(text):
+                snippet = snippet + "..."
+
+            snippet = snippet.strip()
+
+            if snippet:
+                snippets.append(snippet)
+                last_snippet_end = snippet_end
+
+            start_pos = match_pos + len(search_lower)
+            if start_pos >= len(text):
+                break
+
+        return snippets
+
    @classmethod
    async def search_transcripts(
-        cls, session: AsyncSession, params: SearchParameters
+        cls, params: SearchParameters
    ) -> tuple[list[SearchResult], int]:
        """
        Full-text search for transcripts using PostgreSQL tsvector.
        Returns (results, total_count).
        """

-        base_columns = [
-            TranscriptModel.id,
-            TranscriptModel.title,
-            TranscriptModel.created_at,
-            TranscriptModel.duration,
-            TranscriptModel.status,
-            TranscriptModel.user_id,
-            TranscriptModel.room_id,
-            TranscriptModel.source_kind,
-            TranscriptModel.webvtt,
-            TranscriptModel.long_summary,
-            sqlalchemy.case(
-                (
-                    TranscriptModel.room_id.isnot(None) & RoomModel.id.is_(None),
-                    "Deleted Room",
-                ),
-                else_=RoomModel.name,
-            ).label("room_name"),
-        ]
-        search_query = None
-        if params.query_text is not None:
-            search_query = sqlalchemy.func.websearch_to_tsquery(
-                "english", params.query_text
+        if not is_postgresql():
+            logger.warning(
+                "Full-text search requires PostgreSQL. Returning empty results."
            )
-            rank_column = sqlalchemy.func.ts_rank(
-                TranscriptModel.search_vector_en,
-                search_query,
-                32,  # normalization flag: rank/(rank+1) for 0-1 range
-            ).label("rank")
-        else:
-            rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
+            return [], 0

-        columns = base_columns + [rank_column]
-        base_query = (
-            sqlalchemy.select(*columns)
-            .select_from(TranscriptModel)
-            .outerjoin(RoomModel, TranscriptModel.room_id == RoomModel.id)
+        search_query = sqlalchemy.func.websearch_to_tsquery(
+            "english", params.query_text
        )

-        if params.query_text is not None:
-            # because already initialized based on params.query_text presence above
-            assert search_query is not None
-            base_query = base_query.where(
-                TranscriptModel.search_vector_en.op("@@")(search_query)
-            )
+        base_query = sqlalchemy.select(
+            [
+                transcripts.c.id,
+                transcripts.c.title,
+                transcripts.c.created_at,
+                transcripts.c.duration,
+                transcripts.c.status,
+                transcripts.c.user_id,
+                transcripts.c.room_id,
+                transcripts.c.source_kind,
+                transcripts.c.webvtt,
+                sqlalchemy.func.ts_rank(
+                    transcripts.c.search_vector_en,
+                    search_query,
+                    32,  # normalization flag: rank/(rank+1) for 0-1 range
+                ).label("rank"),
+            ]
+        ).where(transcripts.c.search_vector_en.op("@@")(search_query))

        if params.user_id:
-            base_query = base_query.where(
-                sqlalchemy.or_(
-                    TranscriptModel.user_id == params.user_id, RoomModel.is_shared
-                )
-            )
-        else:
-            base_query = base_query.where(RoomModel.is_shared)
+            base_query = base_query.where(transcripts.c.user_id == params.user_id)
        if params.room_id:
-            base_query = base_query.where(TranscriptModel.room_id == params.room_id)
-        if params.source_kind:
-            base_query = base_query.where(
-                TranscriptModel.source_kind == params.source_kind
-            )
+            base_query = base_query.where(transcripts.c.room_id == params.room_id)

-        if params.query_text is not None:
-            order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
-        else:
-            order_by = sqlalchemy.desc(TranscriptModel.created_at)
+        query = (
+            base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
+            .limit(params.limit)
+            .offset(params.offset)
+        )
+        rs = await get_database().fetch_all(query)

-        query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
-
-        result = await session.execute(query)
-        rs = result.mappings().all()
-
-        count_query = sqlalchemy.select(sqlalchemy.func.count()).select_from(
+        count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
            base_query.alias("search_results")
        )
-        count_result = await session.execute(count_query)
-        total = count_result.scalar()
+        total = await get_database().fetch_val(count_query)

-        def _process_result(r: dict) -> SearchResult:
+        def _process_result(r) -> SearchResult:
            r_dict: Dict[str, Any] = dict(r)
+            webvtt: str | None = r_dict.pop("webvtt", None)
-            webvtt_raw: str | None = r_dict.pop("webvtt", None)
-            webvtt: WebVTTContent | None
-            if webvtt_raw:
-                webvtt = WebVTTProcessor.parse(webvtt_raw)
-            else:
-                webvtt = None
-
-            long_summary_r: str | None = r_dict.pop("long_summary", None)
-            long_summary: NonEmptyString = try_parse_non_empty_string(long_summary_r)
-            room_name: str | None = r_dict.pop("room_name", None)
            db_result = SearchResultDB.model_validate(r_dict)

-            at_least_one_source = webvtt is not None or long_summary is not None
-            has_query = params.query_text is not None
-            snippets, total_match_count = (
-                SnippetGenerator.combine_sources(
-                    long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
-                )
-                if has_query and at_least_one_source
-                else ([], 0)
-            )
+            snippets = []
+            if webvtt:
+                plain_text = cls._extract_webvtt_text(webvtt)
+                snippets = cls._generate_snippets(plain_text, params.query_text)

-            return SearchResult(
-                **db_result.model_dump(),
-                room_name=room_name,
-                search_snippets=snippets,
-                total_match_count=total_match_count,
-            )
-
-        try:
-            results = [_process_result(r) for r in rs]
-        except ValidationError as e:
-            logger.error(f"Invalid search result data: {e}", exc_info=True)
-            raise HTTPException(
-                status_code=500, detail="Internal search result data consistency error"
-            )
-        except Exception as e:
-            logger.error(f"Error processing search results: {e}", exc_info=True)
-            raise
+            return SearchResult(**db_result.model_dump(), search_snippets=snippets)

+        results = [_process_result(r) for r in rs]
        return results, total


search_controller = SearchController()
-webvtt_processor = WebVTTProcessor()
-snippet_generator = SnippetGenerator()
@@ -2,18 +2,22 @@ import enum
import json
import os
import shutil
+from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal

+import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer
-from sqlalchemy import delete, insert, select, update
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.sql import or_
+from sqlalchemy import Enum
+from sqlalchemy.dialects.postgresql import TSVECTOR
+from sqlalchemy.sql import false, or_

-from reflector.db.base import RoomModel, TranscriptModel
+from reflector.db import get_database, metadata
from reflector.db.recordings import recordings_controller
+from reflector.db.rooms import rooms
+from reflector.db.utils import is_postgresql
from reflector.logger import logger
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
@@ -28,20 +32,93 @@ class SourceKind(enum.StrEnum):
    FILE = enum.auto()


+transcripts = sqlalchemy.Table(
+    "transcript",
+    metadata,
+    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
+    sqlalchemy.Column("name", sqlalchemy.String),
+    sqlalchemy.Column("status", sqlalchemy.String),
+    sqlalchemy.Column("locked", sqlalchemy.Boolean),
+    sqlalchemy.Column("duration", sqlalchemy.Float),
+    sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
+    sqlalchemy.Column("title", sqlalchemy.String),
+    sqlalchemy.Column("short_summary", sqlalchemy.String),
+    sqlalchemy.Column("long_summary", sqlalchemy.String),
+    sqlalchemy.Column("topics", sqlalchemy.JSON),
+    sqlalchemy.Column("events", sqlalchemy.JSON),
+    sqlalchemy.Column("participants", sqlalchemy.JSON),
+    sqlalchemy.Column("source_language", sqlalchemy.String),
+    sqlalchemy.Column("target_language", sqlalchemy.String),
+    sqlalchemy.Column(
+        "reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
+    ),
+    sqlalchemy.Column(
+        "audio_location",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="local",
+    ),
+    # with user attached, optional
+    sqlalchemy.Column("user_id", sqlalchemy.String),
+    sqlalchemy.Column(
+        "share_mode",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="private",
+    ),
+    sqlalchemy.Column(
+        "meeting_id",
+        sqlalchemy.String,
+    ),
+    sqlalchemy.Column("recording_id", sqlalchemy.String),
+    sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer),
+    sqlalchemy.Column(
+        "source_kind",
+        Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]),
+        nullable=False,
+    ),
+    # indicative field: whether associated audio is deleted
+    # the main "audio deleted" is the presence of the audio itself / consents not-given
+    # same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
+    sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
+    sqlalchemy.Column("room_id", sqlalchemy.String),
+    sqlalchemy.Column("webvtt", sqlalchemy.Text),
+    sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
+    sqlalchemy.Index("idx_transcript_user_id", "user_id"),
+    sqlalchemy.Index("idx_transcript_created_at", "created_at"),
+    sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
+    sqlalchemy.Index("idx_transcript_room_id", "room_id"),
+)
+
+# Add PostgreSQL-specific full-text search column
+# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
+if is_postgresql():
+    transcripts.append_column(
+        sqlalchemy.Column(
+            "search_vector_en",
+            TSVECTOR,
+            sqlalchemy.Computed(
+                "setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
+                "setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
+                persisted=True,
+            ),
+        )
+    )
+    # Add GIN index for the search vector
+    transcripts.append_constraint(
+        sqlalchemy.Index(
+            "idx_transcript_search_vector_en",
+            "search_vector_en",
+            postgresql_using="gin",
+        )
+    )


def generate_transcript_name() -> str:
    now = datetime.now(timezone.utc)
    return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"


-TranscriptStatus = Literal[
-    "idle", "uploaded", "recording", "processing", "error", "ended"
-]
-
-
-class StrValue(BaseModel):
-    value: str
-
-
class AudioWaveform(BaseModel):
    data: list[float]

@@ -102,12 +179,10 @@ class TranscriptParticipant(BaseModel):
class Transcript(BaseModel):
    """Full transcript model with all fields."""

-    model_config = ConfigDict(from_attributes=True)
-
    id: str = Field(default_factory=generate_uuid4)
    user_id: str | None = None
    name: str = Field(default_factory=generate_transcript_name)
-    status: TranscriptStatus = "idle"
+    status: str = "idle"
    duration: float = 0
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    title: str | None = None
@@ -272,7 +347,6 @@ class Transcript(BaseModel):
class TranscriptController:
    async def get_all(
        self,
-        session: AsyncSession,
        user_id: str | None = None,
        order_by: str | None = None,
        filter_empty: bool | None = False,
@@ -297,114 +371,102 @@ class TranscriptController:
        - `search_term`: filter transcripts by search term
        """

-        query = select(TranscriptModel).join(
-            RoomModel, TranscriptModel.room_id == RoomModel.id, isouter=True
+        query = transcripts.select().join(
+            rooms, transcripts.c.room_id == rooms.c.id, isouter=True
        )

        if user_id:
            query = query.where(
-                or_(TranscriptModel.user_id == user_id, RoomModel.is_shared)
+                or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
            )
        else:
-            query = query.where(RoomModel.is_shared)
+            query = query.where(rooms.c.is_shared)

        if source_kind:
-            query = query.where(TranscriptModel.source_kind == source_kind)
+            query = query.where(transcripts.c.source_kind == source_kind)

        if room_id:
-            query = query.where(TranscriptModel.room_id == room_id)
+            query = query.where(transcripts.c.room_id == room_id)

        if search_term:
-            query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
+            query = query.where(transcripts.c.title.ilike(f"%{search_term}%"))

        # Exclude heavy JSON columns from list queries
-        # Get all ORM column attributes except excluded ones
        transcript_columns = [
-            getattr(TranscriptModel, col.name)
-            for col in TranscriptModel.__table__.c
-            if col.name not in exclude_columns
+            col for col in transcripts.c if col.name not in exclude_columns
        ]

        query = query.with_only_columns(
-            *transcript_columns,
-            RoomModel.name.label("room_name"),
+            transcript_columns
+            + [
+                rooms.c.name.label("room_name"),
+            ]
        )

        if order_by is not None:
-            field = getattr(TranscriptModel, order_by[1:])
+            field = getattr(transcripts.c, order_by[1:])
            if order_by.startswith("-"):
                field = field.desc()
            query = query.order_by(field)

        if filter_empty:
-            query = query.filter(TranscriptModel.status != "idle")
+            query = query.filter(transcripts.c.status != "idle")

        if filter_recording:
-            query = query.filter(TranscriptModel.status != "recording")
+            query = query.filter(transcripts.c.status != "recording")

        # print(query.compile(compile_kwargs={"literal_binds": True}))

        if return_query:
            return query

-        result = await session.execute(query)
-        return [dict(row) for row in result.mappings().all()]
+        results = await get_database().fetch_all(query)
+        return results

-    async def get_by_id(
-        self, session: AsyncSession, transcript_id: str, **kwargs
-    ) -> Transcript | None:
+    async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
        """
        Get a transcript by id
        """
-        query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
+        query = transcripts.select().where(transcripts.c.id == transcript_id)
        if "user_id" in kwargs:
-            query = query.where(TranscriptModel.user_id == kwargs["user_id"])
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+            query = query.where(transcripts.c.user_id == kwargs["user_id"])
+        result = await get_database().fetch_one(query)
+        if not result:
            return None
-        return Transcript.model_validate(row)
+        return Transcript(**result)

    async def get_by_recording_id(
-        self, session: AsyncSession, recording_id: str, **kwargs
+        self, recording_id: str, **kwargs
    ) -> Transcript | None:
        """
        Get a transcript by recording_id
        """
-        query = select(TranscriptModel).where(
-            TranscriptModel.recording_id == recording_id
-        )
+        query = transcripts.select().where(transcripts.c.recording_id == recording_id)
        if "user_id" in kwargs:
-            query = query.where(TranscriptModel.user_id == kwargs["user_id"])
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+            query = query.where(transcripts.c.user_id == kwargs["user_id"])
+        result = await get_database().fetch_one(query)
+        if not result:
            return None
-        return Transcript.model_validate(row)
+        return Transcript(**result)

-    async def get_by_room_id(
-        self, session: AsyncSession, room_id: str, **kwargs
-    ) -> list[Transcript]:
+    async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]:
        """
        Get transcripts by room_id (direct access without joins)
        """
-        query = select(TranscriptModel).where(TranscriptModel.room_id == room_id)
+        query = transcripts.select().where(transcripts.c.room_id == room_id)
        if "user_id" in kwargs:
-            query = query.where(TranscriptModel.user_id == kwargs["user_id"])
+            query = query.where(transcripts.c.user_id == kwargs["user_id"])
        if "order_by" in kwargs:
            order_by = kwargs["order_by"]
-            field = getattr(TranscriptModel, order_by[1:])
+            field = getattr(transcripts.c, order_by[1:])
            if order_by.startswith("-"):
                field = field.desc()
            query = query.order_by(field)
-        results = await session.execute(query)
-        return [
-            Transcript.model_validate(dict(row)) for row in results.mappings().all()
-        ]
+        results = await get_database().fetch_all(query)
+        return [Transcript(**result) for result in results]

    async def get_by_id_for_http(
        self,
-        session: AsyncSession,
        transcript_id: str,
        user_id: str | None,
    ) -> Transcript:
@@ -417,14 +479,13 @@ class TranscriptController:
        This method checks the share mode of the transcript and the user_id
        to determine if the user can access the transcript.
        """
-        query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        query = transcripts.select().where(transcripts.c.id == transcript_id)
+        result = await get_database().fetch_one(query)
+        if not result:
            raise HTTPException(status_code=404, detail="Transcript not found")

        # if the transcript is anonymous, share mode is not checked
-        transcript = Transcript.model_validate(row)
+        transcript = Transcript(**result)
        if transcript.user_id is None:
            return transcript

@@ -447,7 +508,6 @@ class TranscriptController:

    async def add(
        self,
-        session: AsyncSession,
        name: str,
        source_kind: SourceKind,
        source_language: str = "en",
@@ -472,20 +532,14 @@ class TranscriptController:
            meeting_id=meeting_id,
            room_id=room_id,
        )
-        query = insert(TranscriptModel).values(**transcript.model_dump())
-        await session.execute(query)
-        await session.commit()
+        query = transcripts.insert().values(**transcript.model_dump())
+        await get_database().execute(query)
        return transcript

    # TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
    # using mutate=True is discouraged
    async def update(
-        self,
-        session: AsyncSession,
-        transcript: Transcript,
-        values: dict,
-        commit=True,
-        mutate=False,
+        self, transcript: Transcript, values: dict, mutate=False
    ) -> Transcript:
        """
        Update a transcript fields with key/values in values.
@@ -494,13 +548,11 @@ class TranscriptController:
        values = TranscriptController._handle_topics_update(values)

        query = (
-            update(TranscriptModel)
-            .where(TranscriptModel.id == transcript.id)
+            transcripts.update()
+            .where(transcripts.c.id == transcript.id)
            .values(**values)
        )
-        await session.execute(query)
-        if commit:
-            await session.commit()
+        await get_database().execute(query)
        if mutate:
            for key, value in values.items():
                setattr(transcript, key, value)
@@ -529,14 +581,13 @@ class TranscriptController:

    async def remove_by_id(
        self,
-        session: AsyncSession,
        transcript_id: str,
        user_id: str | None = None,
    ) -> None:
        """
        Remove a transcript by id
        """
-        transcript = await self.get_by_id(session, transcript_id)
+        transcript = await self.get_by_id(transcript_id)
        if not transcript:
            return
        if user_id is not None and transcript.user_id != user_id:
@@ -556,7 +607,7 @@ class TranscriptController:
        if transcript.recording_id:
            try:
                recording = await recordings_controller.get_by_id(
-                    session, transcript.recording_id
+                    transcript.recording_id
                )
                if recording:
                    try:
@@ -567,49 +618,46 @@ class TranscriptController:
                        exc_info=e,
                        recording_id=transcript.recording_id,
                    )
-                await recordings_controller.remove_by_id(
-                    session, transcript.recording_id
-                )
+                await recordings_controller.remove_by_id(transcript.recording_id)
            except Exception as e:
                logger.warning(
                    "Failed to delete recording row",
                    exc_info=e,
                    recording_id=transcript.recording_id,
                )
-        query = delete(TranscriptModel).where(TranscriptModel.id == transcript_id)
-        await session.execute(query)
-        await session.commit()
+        query = transcripts.delete().where(transcripts.c.id == transcript_id)
+        await get_database().execute(query)

-    async def remove_by_recording_id(self, session: AsyncSession, recording_id: str):
+    async def remove_by_recording_id(self, recording_id: str):
        """
        Remove a transcript by recording_id
        """
-        query = delete(TranscriptModel).where(
-            TranscriptModel.recording_id == recording_id
-        )
-        await session.execute(query)
-        await session.commit()
+        query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
+        await get_database().execute(query)
+
+    @asynccontextmanager
+    async def transaction(self):
+        """
+        A context manager for database transaction
+        """
+        async with get_database().transaction(isolation="serializable"):
+            yield

    async def append_event(
        self,
-        session: AsyncSession,
        transcript: Transcript,
        event: str,
        data: Any,
-        commit=True,
    ) -> TranscriptEvent:
        """
        Append an event to a transcript
        """
        resp = transcript.add_event(event=event, data=data)
-        await self.update(
-            session, transcript, {"events": transcript.events_dump()}, commit=commit
-        )
+        await self.update(transcript, {"events": transcript.events_dump()})
        return resp

    async def upsert_topic(
        self,
-        session: AsyncSession,
        transcript: Transcript,
        topic: TranscriptTopic,
    ) -> TranscriptEvent:
@@ -617,9 +665,9 @@ class TranscriptController:
        Upsert topics to a transcript
        """
        transcript.upsert_topic(topic)
-        await self.update(session, transcript, {"topics": transcript.topics_dump()})
+        await self.update(transcript, {"topics": transcript.topics_dump()})

-    async def move_mp3_to_storage(self, session: AsyncSession, transcript: Transcript):
+    async def move_mp3_to_storage(self, transcript: Transcript):
        """
        Move mp3 file to storage
        """
@@ -643,16 +691,12 @@ class TranscriptController:

        # indicate on the transcript that the audio is now on storage
        # mutates transcript argument
-        await self.update(
-            session, transcript, {"audio_location": "storage"}, mutate=True
-        )
+        await self.update(transcript, {"audio_location": "storage"}, mutate=True)

        # unlink the local file
        transcript.audio_mp3_filename.unlink(missing_ok=True)

-    async def download_mp3_from_storage(
-        self, session: AsyncSession, transcript: Transcript
-    ):
+    async def download_mp3_from_storage(self, transcript: Transcript):
        """
        Download audio from storage
        """
@@ -664,7 +708,6 @@ class TranscriptController:

    async def upsert_participant(
        self,
-        session: AsyncSession,
        transcript: Transcript,
        participant: TranscriptParticipant,
    ) -> TranscriptParticipant:
@@ -672,14 +715,11 @@ class TranscriptController:
        Add/update a participant to a transcript
        """
        result = transcript.upsert_participant(participant)
-        await self.update(
-            session, transcript, {"participants": transcript.participants_dump()}
-        )
+        await self.update(transcript, {"participants": transcript.participants_dump()})
        return result

    async def delete_participant(
        self,
-        session: AsyncSession,
        transcript: Transcript,
        participant_id: str,
    ):
@@ -687,38 +727,7 @@ class TranscriptController:
        Delete a participant from a transcript
        """
        transcript.delete_participant(participant_id)
-        await self.update(
-            session, transcript, {"participants": transcript.participants_dump()}
-        )
-
-    async def set_status(
-        self, session: AsyncSession, transcript_id: str, status: TranscriptStatus
-    ) -> TranscriptEvent | None:
-        """
-        Update the status of a transcript
-
-        Will add an event STATUS + update the status field of transcript
-        """
-        transcript = await self.get_by_id(session, transcript_id)
-        if not transcript:
-            raise Exception(f"Transcript {transcript_id} not found")
-        if transcript.status == status:
-            return
-        resp = await self.append_event(
-            session,
-            transcript=transcript,
-            event="STATUS",
-            data=StrValue(value=status),
-            commit=False,
-        )
-        await self.update(
-            session,
-            transcript,
-            {"status": status},
-            commit=False,
-        )
-        await session.commit()
-        return resp
+        await self.update(transcript, {"participants": transcript.participants_dump()})


transcripts_controller = TranscriptController()
9 server/reflector/db/utils.py Normal file
@@ -0,0 +1,9 @@
+"""Database utility functions."""
+
+from reflector.db import get_database
+
+
+def is_postgresql() -> bool:
+    return get_database().url.scheme and get_database().url.scheme.startswith(
+        "postgresql"
+    )
@@ -1,445 +0,0 @@
-"""
-File-based processing pipeline
-==============================
-
-Optimized pipeline for processing complete audio/video files.
-Uses parallel processing for transcription, diarization, and waveform generation.
-"""
-
-import asyncio
-import uuid
-from pathlib import Path
-
-import av
-import structlog
-from sqlalchemy.ext.asyncio import AsyncSession
-
-from reflector.db.rooms import rooms_controller
-from reflector.db.transcripts import (
-    SourceKind,
-    Transcript,
-    TranscriptStatus,
-    transcripts_controller,
-)
-from reflector.logger import logger
-from reflector.pipelines.main_live_pipeline import (
-    PipelineMainBase,
-    broadcast_to_sockets,
-    task_cleanup_consent_taskiq,
-    task_pipeline_post_to_zulip_taskiq,
-)
-from reflector.processors import (
-    AudioFileWriterProcessor,
-    TranscriptFinalSummaryProcessor,
-    TranscriptFinalTitleProcessor,
-    TranscriptTopicDetectorProcessor,
-)
-from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
-from reflector.processors.file_diarization import FileDiarizationInput
-from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
-from reflector.processors.file_transcript import FileTranscriptInput
-from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
-from reflector.processors.transcript_diarization_assembler import (
-    TranscriptDiarizationAssemblerInput,
-    TranscriptDiarizationAssemblerProcessor,
-)
-from reflector.processors.types import (
-    DiarizationSegment,
-    TitleSummary,
-)
-from reflector.processors.types import (
-    Transcript as TranscriptType,
-)
-from reflector.settings import settings
-from reflector.storage import get_transcripts_storage
-from reflector.worker.app import taskiq_broker
-from reflector.worker.session_decorator import catch_exception, with_session
-from reflector.worker.webhook import send_transcript_webhook_taskiq
-
-
-class EmptyPipeline:
-    """Empty pipeline for processors that need a pipeline reference"""
-
-    def __init__(self, logger: structlog.BoundLogger):
-        self.logger = logger
-
-    def get_pref(self, k, d=None):
-        return d
-
-    async def emit(self, event):
-        pass
-
-
-class PipelineMainFile(PipelineMainBase):
-    """
-    Optimized file processing pipeline.
-    Processes complete audio/video files with parallel execution.
-    """
-
-    logger: structlog.BoundLogger = None
-    empty_pipeline = None
-
-    def __init__(self, transcript_id: str):
-        super().__init__(transcript_id=transcript_id)
-        self.logger = logger.bind(transcript_id=self.transcript_id)
-        self.empty_pipeline = EmptyPipeline(logger=self.logger)
-
-    def _handle_gather_exceptions(self, results: list, operation: str) -> None:
-        """Handle exceptions from asyncio.gather with return_exceptions=True"""
-        for i, result in enumerate(results):
-            if not isinstance(result, Exception):
-                continue
-            self.logger.error(
-                f"Error in {operation} (task {i}): {result}",
-                transcript_id=self.transcript_id,
-                exc_info=result,
-            )
-
-    @broadcast_to_sockets
-    async def set_status(
-        self,
-        session: AsyncSession,
-        transcript_id: str,
-        status: TranscriptStatus,
-    ):
-        return await transcripts_controller.set_status(session, transcript_id, status)
-
-    async def process(self, session: AsyncSession, file_path: Path):
-        """Main entry point for file processing"""
-        self.logger.info(f"Starting file pipeline for {file_path}")
-
-        transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
-
-        # Clear transcript as we're going to regenerate everything
-        await transcripts_controller.update(
-            session,
-            transcript,
-            {
-                "events": [],
-                "topics": [],
-            },
-        )
-
-        # Extract audio and write to transcript location
-        audio_path = await self.extract_and_write_audio(file_path, transcript)
-
-        # Upload for processing
-        audio_url = await self.upload_audio(audio_path, transcript)
-
-        # Run parallel processing
-        await self.run_parallel_processing(
-            session,
-            audio_path,
-            audio_url,
-            transcript.source_language,
-            transcript.target_language,
-        )
-
-        self.logger.info("File pipeline complete")
-
-        await transcripts_controller.set_status(session, transcript.id, "ended")
-
-    async def extract_and_write_audio(
-        self, file_path: Path, transcript: Transcript
-    ) -> Path:
-        """Extract audio from video if needed and write to transcript location as MP3"""
-        self.logger.info(f"Processing audio file: {file_path}")
-
-        # Check if it's already audio-only
-        container = av.open(str(file_path))
-        has_video = len(container.streams.video) > 0
-        container.close()
-
-        # Use AudioFileWriterProcessor to write MP3 to transcript location
-        mp3_writer = AudioFileWriterProcessor(
-            path=transcript.audio_mp3_filename,
-            on_duration=self.on_duration,
-        )
-
-        # Process audio frames and write to transcript location
-        input_container = av.open(str(file_path))
-        for frame in input_container.decode(audio=0):
-            await mp3_writer.push(frame)
-
-        await mp3_writer.flush()
-        input_container.close()
-
-        if has_video:
-            self.logger.info(
-                f"Extracted audio from video and saved to {transcript.audio_mp3_filename}"
-            )
-        else:
-            self.logger.info(
-                f"Converted audio file and saved to {transcript.audio_mp3_filename}"
-            )
-
-        return transcript.audio_mp3_filename
-
-    async def upload_audio(self, audio_path: Path, transcript: Transcript) -> str:
-        """Upload audio to storage for processing"""
-        storage = get_transcripts_storage()
-
-        if not storage:
-            raise Exception(
-                "Storage backend required for file processing. Configure TRANSCRIPT_STORAGE_* settings."
-            )
-
-        self.logger.info("Uploading audio to storage")
-
-        with open(audio_path, "rb") as f:
-            audio_data = f.read()
-
-        storage_path = f"file_pipeline/{transcript.id}/audio.mp3"
-        await storage.put_file(storage_path, audio_data)
-
-        audio_url = await storage.get_file_url(storage_path)
-
-        self.logger.info(f"Audio uploaded to {audio_url}")
-        return audio_url
-
-    async def run_parallel_processing(
-        self,
-        session,
-        audio_path: Path,
-        audio_url: str,
-        source_language: str,
-        target_language: str,
-    ):
-        """Coordinate parallel processing of transcription, diarization, and waveform"""
-        self.logger.info(
-            "Starting parallel processing", transcript_id=self.transcript_id
-        )
-
-        # Phase 1: Parallel processing of independent tasks
-        transcription_task = self.transcribe_file(audio_url, source_language)
-        diarization_task = self.diarize_file(audio_url)
-        waveform_task = self.generate_waveform(session, audio_path)
-
-        results = await asyncio.gather(
-            transcription_task, diarization_task, waveform_task, return_exceptions=True
-        )
-
-        transcript_result = results[0]
-        diarization_result = results[1]
-
-        # Handle errors - raise any exception that occurred
-        self._handle_gather_exceptions(results, "parallel processing")
-        for result in results:
-            if isinstance(result, Exception):
-                raise result
-
-        # Phase 2: Assemble transcript with diarization
-        self.logger.info(
-            "Assembling transcript with diarization", transcript_id=self.transcript_id
-        )
-        processor = TranscriptDiarizationAssemblerProcessor()
-        input_data = TranscriptDiarizationAssemblerInput(
-            transcript=transcript_result, diarization=diarization_result or []
-        )
-
-        # Store result for retrieval
-        diarized_transcript: Transcript | None = None
-
-        async def capture_result(transcript):
-            nonlocal diarized_transcript
-            diarized_transcript = transcript
-
-        processor.on(capture_result)
-        await processor.push(input_data)
-        await processor.flush()
-
-        if not diarized_transcript:
-            raise ValueError("No diarized transcript captured")
-
-        # Phase 3: Generate topics from diarized transcript
-        self.logger.info("Generating topics", transcript_id=self.transcript_id)
-        topics = await self.detect_topics(diarized_transcript, target_language)
-
-        # Phase 4: Generate title and summaries in parallel
-        self.logger.info(
-            "Generating title and summaries", transcript_id=self.transcript_id
-        )
-        results = await asyncio.gather(
-            self.generate_title(topics),
-            self.generate_summaries(session, topics),
-            return_exceptions=True,
-        )
-
-        self._handle_gather_exceptions(results, "title and summary generation")
-
-    async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
-        """Transcribe complete file"""
-        processor = FileTranscriptAutoProcessor()
-        input_data = FileTranscriptInput(audio_url=audio_url, language=language)
-
-        # Store result for retrieval
-        result: TranscriptType | None = None
-
-        async def capture_result(transcript):
-            nonlocal result
-            result = transcript
-
-        processor.on(capture_result)
-        await processor.push(input_data)
-        await processor.flush()
-
-        if not result:
-            raise ValueError("No transcript captured")
-
-        return result
-
-    async def diarize_file(self, audio_url: str) -> list[DiarizationSegment] | None:
-        """Get diarization for file"""
-        if not settings.DIARIZATION_BACKEND:
-            self.logger.info("Diarization disabled")
-            return None
-
-        processor = FileDiarizationAutoProcessor()
-        input_data = FileDiarizationInput(audio_url=audio_url)
-
-        # Store result for retrieval
-        result = None
-
-        async def capture_result(diarization_output):
-            nonlocal result
-            result = diarization_output.diarization
-
-        try:
-            processor.on(capture_result)
-            await processor.push(input_data)
-            await processor.flush()
-            return result
-        except Exception as e:
-            self.logger.error(f"Diarization failed: {e}")
-            return None
-
-    async def generate_waveform(self, session: AsyncSession, audio_path: Path):
-        """Generate and save waveform"""
-        transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
-
-        processor = AudioWaveformProcessor(
-            audio_path=audio_path,
-            waveform_path=transcript.audio_waveform_filename,
-            on_waveform=self.on_waveform,
-        )
-        processor.set_pipeline(self.empty_pipeline)
-
-        await processor.flush()
-
-    async def detect_topics(
-        self, transcript: TranscriptType, target_language: str
-    ) -> list[TitleSummary]:
-        """Detect topics from complete transcript"""
-        chunk_size = 300
-        topics: list[TitleSummary] = []
-
-        async def on_topic(topic: TitleSummary):
-            topics.append(topic)
-            return await self.on_topic(topic)
-
-        topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
-        topic_detector.set_pipeline(self.empty_pipeline)
-
-        for i in range(0, len(transcript.words), chunk_size):
-            chunk_words = transcript.words[i : i + chunk_size]
-            if not chunk_words:
-                continue
-
-            chunk_transcript = TranscriptType(
-                words=chunk_words, translation=transcript.translation
-            )
-
-            await topic_detector.push(chunk_transcript)
-
-        await topic_detector.flush()
-        return topics
-
-    async def generate_title(self, topics: list[TitleSummary]):
-        """Generate title from topics"""
-        if not topics:
-            self.logger.warning("No topics for title generation")
-            return
-
-        processor = TranscriptFinalTitleProcessor(callback=self.on_title)
-        processor.set_pipeline(self.empty_pipeline)
-
-        for topic in topics:
-            await processor.push(topic)
-
-        await processor.flush()
-
-    async def generate_summaries(self, session, topics: list[TitleSummary]):
-        """Generate long and short summaries from topics"""
-        if not topics:
-            self.logger.warning("No topics for summary generation")
-            return
-
-        transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
-        processor = TranscriptFinalSummaryProcessor(
-            transcript=transcript,
-            callback=self.on_long_summary,
-            on_short_summary=self.on_short_summary,
-        )
-        processor.set_pipeline(self.empty_pipeline)
-
-        for topic in topics:
-            await processor.push(topic)
-
-        await processor.flush()
-
-
-@taskiq_broker.task
-@with_session
-async def task_send_webhook_if_needed(session, *, transcript_id: str):
-    transcript = await transcripts_controller.get_by_id(session, transcript_id)
-    if not transcript:
-        return
-
-    if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
-        room = await rooms_controller.get_by_id(session, transcript.room_id)
-        if room and room.webhook_url:
-            logger.info(
-                "Dispatching webhook",
-                transcript_id=transcript_id,
-                room_id=room.id,
-                webhook_url=room.webhook_url,
-            )
-            await send_transcript_webhook_taskiq.kiq(
-                transcript_id, room.id, event_id=uuid.uuid4().hex
-            )
-
-
-@taskiq_broker.task
-@catch_exception
-@with_session
-async def task_pipeline_file_process(session: AsyncSession, *, transcript_id: str):
-    transcript = await transcripts_controller.get_by_id(session, transcript_id)
-    if not transcript:
-        raise Exception(f"Transcript {transcript_id} not found")
-
-    pipeline = PipelineMainFile(transcript_id=transcript_id)
-    try:
-        await pipeline.set_status(session, transcript_id, "processing")
-
-        audio_file = next(transcript.data_path.glob("upload.*"), None)
-        if not audio_file:
-            audio_file = next(transcript.data_path.glob("audio.*"), None)
-
-        if not audio_file:
-            raise Exception("No audio file found to process")
-
-        await pipeline.process(session, audio_file)
-
-    except Exception:
-        logger.error("Error while processing the file", exc_info=True)
-        try:
-            await pipeline.set_status(session, transcript_id, "error")
-        except:
-            logger.error(
-                "Error setting status in task_pipeline_file_process during exception, ignoring it"
-            )
-        raise
-
-    await task_cleanup_consent_taskiq.kiq(transcript_id=transcript_id)
-    await task_pipeline_post_to_zulip_taskiq.kiq(transcript_id=transcript_id)
-    await task_send_webhook_if_needed.kiq(transcript_id=transcript_id)
@@ -12,16 +12,17 @@ It is directly linked to our data model.
"""

import asyncio
+import functools
from contextlib import asynccontextmanager
from typing import Generic

import av
import boto3
+from celery import chord, current_task, group, shared_task
from pydantic import BaseModel
-from sqlalchemy.ext.asyncio import AsyncSession
from structlog import BoundLogger as Logger

-from reflector.db import get_session_context
+from reflector.db import get_database
from reflector.db.meetings import meeting_consent_controller, meetings_controller
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller
@@ -31,7 +32,6 @@ from reflector.db.transcripts import (
    TranscriptFinalLongSummary,
    TranscriptFinalShortSummary,
    TranscriptFinalTitle,
-    TranscriptStatus,
    TranscriptText,
    TranscriptTopic,
    TranscriptWaveform,
@@ -40,9 +40,8 @@ from reflector.db.transcripts import (
from reflector.logger import logger
from reflector.pipelines.runner import PipelineMessage, PipelineRunner
from reflector.processors import (
-    AudioChunkerAutoProcessor,
+    AudioChunkerProcessor,
    AudioDiarizationAutoProcessor,
-    AudioDownscaleProcessor,
    AudioFileWriterProcessor,
    AudioMergeProcessor,
    AudioTranscriptAutoProcessor,
@@ -61,8 +60,6 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
-from reflector.worker.app import taskiq_broker
-from reflector.worker.session_decorator import with_session_and_transcript
from reflector.ws_manager import WebsocketManager, get_ws_manager
from reflector.zulip import (
    get_zulip_message,
@@ -71,6 +68,29 @@ from reflector.zulip import (
 )


+def asynctask(f):
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        async def run_with_db():
+            database = get_database()
+            await database.connect()
+            try:
+                return await f(*args, **kwargs)
+            finally:
+                await database.disconnect()
+
+        coro = run_with_db()
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+        if loop and loop.is_running():
+            return loop.run_until_complete(coro)
+        return asyncio.run(coro)
+
+    return wrapper
+
+
 def broadcast_to_sockets(func):
     """
     Decorator to broadcast transcript event to websockets
@@ -89,27 +109,59 @@ def broadcast_to_sockets(func):
     return wrapper


+def get_transcript(func):
+    """
+    Decorator to fetch the transcript from the database from the first argument
+    """
+
+    @functools.wraps(func)
+    async def wrapper(**kwargs):
+        transcript_id = kwargs.pop("transcript_id")
+        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
+        if not transcript:
+            raise Exception("Transcript {transcript_id} not found")
+
+        # Enhanced logger with Celery task context
+        tlogger = logger.bind(transcript_id=transcript.id)
+        if current_task:
+            tlogger = tlogger.bind(
+                task_id=current_task.request.id,
+                task_name=current_task.name,
+                worker_hostname=current_task.request.hostname,
+                task_retries=current_task.request.retries,
+                transcript_id=transcript_id,
+            )
+
+        try:
+            result = await func(transcript=transcript, logger=tlogger, **kwargs)
+            return result
+        except Exception as exc:
+            tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc)
+            raise
+
+    return wrapper
+
+
 class StrValue(BaseModel):
     value: str


 class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
-    def __init__(self, transcript_id: str):
-        super().__init__()
+    transcript_id: str
+    ws_room_id: str | None = None
+    ws_manager: WebsocketManager | None = None
+
+    def prepare(self):
+        # prepare websocket
         self._lock = asyncio.Lock()
-        self.transcript_id = transcript_id
         self.ws_room_id = f"ts:{self.transcript_id}"
-        self._ws_manager = None
-
-    @property
-    def ws_manager(self) -> WebsocketManager:
-        if self._ws_manager is None:
-            self._ws_manager = get_ws_manager()
-        return self._ws_manager
-
-    async def get_transcript(self, session: AsyncSession) -> Transcript:
+        self.ws_manager = get_ws_manager()
+
+    async def get_transcript(self) -> Transcript:
         # fetch the transcript
-        result = await transcripts_controller.get_by_id(session, self.transcript_id)
+        result = await transcripts_controller.get_by_id(
+            transcript_id=self.transcript_id
+        )
         if not result:
             raise Exception("Transcript not found")
         return result
@@ -132,31 +184,24 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
         ]

     @asynccontextmanager
-    async def lock_transaction(self):
-        # This lock is to prevent multiple processor starting adding
-        # into event array at the same time
+    async def transaction(self):
         async with self._lock:
-            yield
-
-    @asynccontextmanager
-    async def locked_session(self):
-        async with self.lock_transaction():
-            async with get_session_context() as session:
-                yield session
+            async with transcripts_controller.transaction():
+                yield

     @broadcast_to_sockets
     async def on_status(self, status):
         # if it's the first part, update the status of the transcript
         # but do not set the ended status yet.
         if isinstance(self, PipelineMainLive):
-            status_mapping: dict[str, TranscriptStatus] = {
+            status_mapping = {
                 "started": "recording",
                 "push": "recording",
                 "flush": "processing",
                 "error": "error",
             }
         elif isinstance(self, PipelineMainFinalSummaries):
-            status_mapping: dict[str, TranscriptStatus] = {
+            status_mapping = {
                 "push": "processing",
                 "flush": "processing",
                 "error": "error",
@@ -172,18 +217,28 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
             return

         # when the status of the pipeline changes, update the transcript
-        async with self._lock:
-            async with get_session_context() as session:
-                return await transcripts_controller.set_status(
-                    session, self.transcript_id, status
-                )
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            if status == transcript.status:
+                return
+            resp = await transcripts_controller.append_event(
+                transcript=transcript,
+                event="STATUS",
+                data=StrValue(value=status),
+            )
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "status": status,
+                },
+            )
+            return resp

     @broadcast_to_sockets
     async def on_transcript(self, data):
-        async with self.locked_session() as session:
-            transcript = await self.get_transcript(session)
+        async with self.transaction():
+            transcript = await self.get_transcript()
             return await transcripts_controller.append_event(
-                session,
                 transcript=transcript,
                 event="TRANSCRIPT",
                 data=TranscriptText(text=data.text, translation=data.translation),
@@ -200,11 +255,10 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
             )
         if isinstance(data, TitleSummaryWithIdProcessorType):
             topic.id = data.id
-        async with self.locked_session() as session:
-            transcript = await self.get_transcript(session)
-            await transcripts_controller.upsert_topic(session, transcript, topic)
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            await transcripts_controller.upsert_topic(transcript, topic)
             return await transcripts_controller.append_event(
-                session,
                 transcript=transcript,
                 event="TOPIC",
                 data=topic,
@@ -213,18 +267,16 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_title(self, data):
         final_title = TranscriptFinalTitle(title=data.title)
-        async with self.locked_session() as session:
-            transcript = await self.get_transcript(session)
+        async with self.transaction():
+            transcript = await self.get_transcript()
             if not transcript.title:
                 await transcripts_controller.update(
-                    session,
                     transcript,
                     {
                         "title": final_title.title,
                     },
                 )
             return await transcripts_controller.append_event(
-                session,
                 transcript=transcript,
                 event="FINAL_TITLE",
                 data=final_title,
@@ -233,17 +285,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_long_summary(self, data):
         final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
-        async with self.locked_session() as session:
-            transcript = await self.get_transcript(session)
+        async with self.transaction():
+            transcript = await self.get_transcript()
             await transcripts_controller.update(
-                session,
                 transcript,
                 {
                     "long_summary": final_long_summary.long_summary,
                 },
             )
             return await transcripts_controller.append_event(
-                session,
                 transcript=transcript,
                 event="FINAL_LONG_SUMMARY",
                 data=final_long_summary,
@@ -254,17 +304,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
         final_short_summary = TranscriptFinalShortSummary(
             short_summary=data.short_summary
         )
-        async with self.locked_session() as session:
-            transcript = await self.get_transcript(session)
+        async with self.transaction():
+            transcript = await self.get_transcript()
             await transcripts_controller.update(
-                session,
                 transcript,
                 {
                     "short_summary": final_short_summary.short_summary,
                 },
             )
             return await transcripts_controller.append_event(
-                session,
                 transcript=transcript,
                 event="FINAL_SHORT_SUMMARY",
                 data=final_short_summary,
@@ -272,30 +320,29 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]

     @broadcast_to_sockets
     async def on_duration(self, data):
-        async with self.locked_session() as session:
+        async with self.transaction():
             duration = TranscriptDuration(duration=data)

-            transcript = await self.get_transcript(session)
+            transcript = await self.get_transcript()
             await transcripts_controller.update(
-                session,
                 transcript,
                 {
                     "duration": duration.duration,
                 },
             )
             return await transcripts_controller.append_event(
-                session, transcript=transcript, event="DURATION", data=duration
+                transcript=transcript, event="DURATION", data=duration
             )

     @broadcast_to_sockets
     async def on_waveform(self, data):
-        async with self.locked_session() as session:
+        async with self.transaction():
             waveform = TranscriptWaveform(waveform=data)

-            transcript = await self.get_transcript(session)
+            transcript = await self.get_transcript()

             return await transcripts_controller.append_event(
-                session, transcript=transcript, event="WAVEFORM", data=waveform
+                transcript=transcript, event="WAVEFORM", data=waveform
             )

@@ -308,16 +355,15 @@ class PipelineMainLive(PipelineMainBase):
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
-        async with get_session_context() as session:
-            transcript = await self.get_transcript(session)
+        self.prepare()
+        transcript = await self.get_transcript()

         processors = [
             AudioFileWriterProcessor(
                 path=transcript.audio_wav_filename,
                 on_duration=self.on_duration,
             ),
-            AudioDownscaleProcessor(),
-            AudioChunkerAutoProcessor(),
+            AudioChunkerProcessor(),
             AudioMergeProcessor(),
             AudioTranscriptAutoProcessor.as_threaded(),
             TranscriptLinerProcessor(),
@@ -330,7 +376,6 @@ class PipelineMainLive(PipelineMainBase):
         pipeline.set_pref("audio:target_language", transcript.target_language)
         pipeline.logger.bind(transcript_id=transcript.id)
         pipeline.logger.info("Pipeline main live created")
-        pipeline.describe()

         return pipeline

@@ -349,6 +394,7 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
+        self.prepare()
         pipeline = Pipeline(
             AudioDiarizationAutoProcessor(callback=self.on_topic),
         )
@@ -357,8 +403,7 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
         # now let's start the pipeline by pushing information to the
         # first processor diarization processor
         # XXX translation is lost when converting our data model to the processor model
-        async with get_session_context() as session:
-            transcript = await self.get_transcript(session)
+        transcript = await self.get_transcript()

         # diarization works only if the file is uploaded to an external storage
         if transcript.audio_location == "local":
@@ -390,9 +435,10 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
         raise NotImplementedError

     async def create(self) -> Pipeline:
+        self.prepare()
+
         # get transcript
-        async with get_session_context() as session:
-            self._transcript = transcript = await self.get_transcript(session)
+        self._transcript = transcript = await self.get_transcript()

         # create pipeline
         processors = self.get_processors()
@@ -452,7 +498,8 @@ class PipelineMainWaveform(PipelineMainFromTopics):
         ]


-async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
     # for future changes: note that there's also a consent process happens, beforehand and users may not consent with keeping files. currently, we delete regardless, so it's no need for that
     logger.info("Starting remove upload")
     uploads = transcript.data_path.glob("upload.*")
@@ -461,14 +508,16 @@ async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger
     logger.info("Remove upload done")


-async def pipeline_waveform(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def pipeline_waveform(transcript: Transcript, logger: Logger):
     logger.info("Starting waveform")
     runner = PipelineMainWaveform(transcript_id=transcript.id)
     await runner.run()
     logger.info("Waveform done")


-async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
     logger.info("Starting convert to mp3")

     # If the audio wav is not available, just skip
@@ -494,7 +543,8 @@ async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logge
     logger.info("Convert to mp3 done")


-async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
     if not settings.TRANSCRIPT_STORAGE_BACKEND:
         logger.info("No storage backend configured, skipping mp3 upload")
         return
@@ -512,49 +562,49 @@ async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
        return

    # Upload to external storage and delete the file
-    await transcripts_controller.move_mp3_to_storage(session, transcript)
+    await transcripts_controller.move_mp3_to_storage(transcript)

    logger.info("Upload mp3 done")


-async def pipeline_diarization(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def pipeline_diarization(transcript: Transcript, logger: Logger):
    logger.info("Starting diarization")
    runner = PipelineMainDiarization(transcript_id=transcript.id)
    await runner.run()
    logger.info("Diarization done")


-async def pipeline_title(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def pipeline_title(transcript: Transcript, logger: Logger):
    logger.info("Starting title")
    runner = PipelineMainTitle(transcript_id=transcript.id)
    await runner.run()
    logger.info("Title done")


-async def pipeline_summaries(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def pipeline_summaries(transcript: Transcript, logger: Logger):
    logger.info("Starting summaries")
    runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
    await runner.run()
    logger.info("Summaries done")


-async def cleanup_consent(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def cleanup_consent(transcript: Transcript, logger: Logger):
    logger.info("Starting consent cleanup")

    consent_denied = False
    recording = None
    try:
        if transcript.recording_id:
-            recording = await recordings_controller.get_by_id(
-                session, transcript.recording_id
-            )
+            recording = await recordings_controller.get_by_id(transcript.recording_id)
        if recording and recording.meeting_id:
-            meeting = await meetings_controller.get_by_id(
-                session, recording.meeting_id
-            )
+            meeting = await meetings_controller.get_by_id(recording.meeting_id)
            if meeting:
                consent_denied = await meeting_consent_controller.has_any_denial(
-                    session, meeting.id
+                    meeting.id
                )
    except Exception as e:
        logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
@@ -583,7 +633,7 @@ async def cleanup_consent(session, transcript: Transcript, logger: Logger):
        logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)

    # non-transactional, files marked for deletion not actually deleted is possible
-    await transcripts_controller.update(session, transcript, {"audio_deleted": True})
+    await transcripts_controller.update(transcript, {"audio_deleted": True})
    # 2. Delete processed audio from transcript storage S3 bucket
    if transcript.audio_location == "storage":
        storage = get_transcripts_storage()
@@ -607,14 +657,15 @@ async def cleanup_consent(session, transcript: Transcript, logger: Logger):
    logger.info("Consent cleanup done")


-async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
    logger.info("Starting post to zulip")

    if not transcript.recording_id:
        logger.info("Transcript has no recording")
        return

-    recording = await recordings_controller.get_by_id(session, transcript.recording_id)
+    recording = await recordings_controller.get_by_id(transcript.recording_id)
    if not recording:
        logger.info("Recording not found")
        return
@@ -623,12 +674,12 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
        logger.info("Recording has no meeting")
        return

-    meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
+    meeting = await meetings_controller.get_by_id(recording.meeting_id)
    if not meeting:
        logger.info("No meeting found for this recording")
        return

-    room = await rooms_controller.get_by_id(session, meeting.room_id)
+    room = await rooms_controller.get_by_id(meeting.room_id)
    if not room:
        logger.error(f"Missing room for a meeting {meeting.id}")
        return
@@ -654,7 +705,7 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
        room.zulip_stream, room.zulip_topic, message
    )
    await transcripts_controller.update(
-        session, transcript, {"zulip_message_id": response["id"]}
+        transcript, {"zulip_message_id": response["id"]}
    )

    logger.info("Posted to zulip")
@@ -665,120 +716,92 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
# ===================================================================


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_pipeline_remove_upload(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
-):
-    await pipeline_remove_upload(session, transcript=transcript, logger=logger)
+@shared_task
+@asynctask
+async def task_pipeline_remove_upload(*, transcript_id: str):
+    await pipeline_remove_upload(transcript_id=transcript_id)


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_pipeline_waveform(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
-):
-    await pipeline_waveform(session, transcript=transcript, logger=logger)
+@shared_task
+@asynctask
+async def task_pipeline_waveform(*, transcript_id: str):
+    await pipeline_waveform(transcript_id=transcript_id)


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_pipeline_convert_to_mp3(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
-):
-    await pipeline_convert_to_mp3(session, transcript=transcript, logger=logger)
+@shared_task
+@asynctask
+async def task_pipeline_convert_to_mp3(*, transcript_id: str):
+    await pipeline_convert_to_mp3(transcript_id=transcript_id)


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_pipeline_upload_mp3(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
-):
-    await pipeline_upload_mp3(session, transcript=transcript, logger=logger)
+@shared_task
+@asynctask
+async def task_pipeline_upload_mp3(*, transcript_id: str):
+    await pipeline_upload_mp3(transcript_id=transcript_id)


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_pipeline_diarization(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
-):
-    await pipeline_diarization(session, transcript=transcript, logger=logger)
+@shared_task
+@asynctask
+async def task_pipeline_diarization(*, transcript_id: str):
+    await pipeline_diarization(transcript_id=transcript_id)


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_pipeline_title(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
-):
-    await pipeline_title(session, transcript=transcript, logger=logger)
+@shared_task
+@asynctask
+async def task_pipeline_title(*, transcript_id: str):
+    await pipeline_title(transcript_id=transcript_id)


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_pipeline_final_summaries(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
-):
-    await pipeline_summaries(session, transcript=transcript, logger=logger)
+@shared_task
+@asynctask
+async def task_pipeline_final_summaries(*, transcript_id: str):
+    await pipeline_summaries(transcript_id=transcript_id)


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_cleanup_consent(session, *, transcript: Transcript, logger: Logger):
-    await cleanup_consent(session, transcript=transcript, logger=logger)
+@shared_task
+@asynctask
+async def task_cleanup_consent(*, transcript_id: str):
+    await cleanup_consent(transcript_id=transcript_id)


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_pipeline_post_to_zulip(
-    session, *, transcript: Transcript, logger: Logger
-):
-    await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
+@shared_task
+@asynctask
+async def task_pipeline_post_to_zulip(*, transcript_id: str):
+    await pipeline_post_to_zulip(transcript_id=transcript_id)


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_cleanup_consent_taskiq(
-    session, *, transcript: Transcript, logger: Logger
-):
-    await cleanup_consent(session, transcript=transcript, logger=logger)
-
-
-@taskiq_broker.task
-@with_session_and_transcript
-async def task_pipeline_post_to_zulip_taskiq(
-    session, *, transcript: Transcript, logger: Logger
-):
-    await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
-
-
-async def pipeline_post(*, transcript_id: str):
-    await task_pipeline_post_sequential.kiq(transcript_id=transcript_id)
-
-
-@taskiq_broker.task
-async def task_pipeline_post_sequential(*, transcript_id: str):
-    await task_pipeline_waveform.kiq(transcript_id=transcript_id)
-    await task_pipeline_convert_to_mp3.kiq(transcript_id=transcript_id)
-    await task_pipeline_upload_mp3.kiq(transcript_id=transcript_id)
-    await task_pipeline_remove_upload.kiq(transcript_id=transcript_id)
-    await task_pipeline_diarization.kiq(transcript_id=transcript_id)
-    await task_cleanup_consent.kiq(transcript_id=transcript_id)
-
-    await asyncio.gather(
-        task_pipeline_title.kiq(transcript_id=transcript_id),
-        task_pipeline_final_summaries.kiq(transcript_id=transcript_id),
-    )
-
-    await task_pipeline_post_to_zulip.kiq(transcript_id=transcript_id)
-
-
-async def pipeline_process(session, transcript: Transcript, logger: Logger):
+def pipeline_post(*, transcript_id: str):
+    """
+    Run the post pipeline
+    """
+    chain_mp3_and_diarize = (
+        task_pipeline_waveform.si(transcript_id=transcript_id)
+        | task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_remove_upload.si(transcript_id=transcript_id)
+        | task_pipeline_diarization.si(transcript_id=transcript_id)
+        | task_cleanup_consent.si(transcript_id=transcript_id)
+    )
+    chain_title_preview = task_pipeline_title.si(transcript_id=transcript_id)
+    chain_final_summaries = task_pipeline_final_summaries.si(
+        transcript_id=transcript_id
+    )
+
+    chain = chord(
+        group(chain_mp3_and_diarize, chain_title_preview),
+        chain_final_summaries,
+    ) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
+
+    chain.delay()
+
+
+@get_transcript
+async def pipeline_process(transcript: Transcript, logger: Logger):
    try:
        if transcript.audio_location == "storage":
            await transcripts_controller.download_mp3_from_storage(transcript)
        transcript.audio_waveform_filename.unlink(missing_ok=True)
        await transcripts_controller.update(
-            session,
            transcript,
            {
                "topics": [],
@@ -816,7 +839,6 @@ async def pipeline_process(session, transcript: Transcript, logger: Logger):
    except Exception as exc:
        logger.error("Pipeline error", exc_info=exc)
        await transcripts_controller.update(
-            session,
            transcript,
            {
                "status": "error",
@@ -827,9 +849,7 @@ async def pipeline_process(session, transcript: Transcript, logger: Logger):
    logger.info("Pipeline ended")


-@taskiq_broker.task
-@with_session_and_transcript
-async def task_pipeline_process(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
-):
-    return await pipeline_process(session, transcript=transcript, logger=logger)
+@shared_task
+@asynctask
+async def task_pipeline_process(*, transcript_id: str):
+    return await pipeline_process(transcript_id=transcript_id)
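For context on the Celery composition introduced in pipeline_post above: a group fans tasks out in parallel, and a chord runs its body only once every task in the group has finished. A minimal, self-contained sketch of that same shape, assuming a local Celery app with hypothetical add/report tasks (not part of this diff):

# Minimal sketch of the chord/group shape used by pipeline_post above.
# The app name, broker/backend URLs and the add/report tasks are illustrative only.
from celery import Celery, chord, group

app = Celery("sketch", broker="memory://", backend="cache+memory://")


@app.task
def add(x, y):
    # stand-in for one pipeline step
    return x + y


@app.task
def report(results):
    # chord body: runs once every task in the group has finished
    return f"done: {results}"


def run():
    # group(...) runs in parallel; the chord body receives the list of results.
    # pipeline_post uses .si(...) (immutable signatures) so each step keeps its
    # own transcript_id argument instead of the previous step's return value.
    workflow = chord(group(add.s(1, 2), add.s(3, 4)), report.s())
    return workflow.delay()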
@@ -18,14 +18,22 @@ During its lifecycle, it will emit the following status:
 import asyncio
 from typing import Generic, TypeVar

+from pydantic import BaseModel, ConfigDict
+
 from reflector.logger import logger
 from reflector.processors import Pipeline

 PipelineMessage = TypeVar("PipelineMessage")


-class PipelineRunner(Generic[PipelineMessage]):
-    def __init__(self):
+class PipelineRunner(BaseModel, Generic[PipelineMessage]):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    status: str = "idle"
+    pipeline: Pipeline | None = None
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         self._task = None
         self._q_cmd = asyncio.Queue(maxsize=4096)
         self._ev_done = asyncio.Event()
@@ -34,8 +42,6 @@ class PipelineRunner(Generic[PipelineMessage]):
             runner=id(self),
             runner_cls=self.__class__.__name__,
         )
-        self.status = "idle"
-        self.pipeline: Pipeline | None = None

     async def create(self) -> Pipeline:
         """
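On the PipelineRunner change above: making the runner a pydantic BaseModel means declared fields (status, pipeline, transcript_id in subclasses) are validated at construction, while non-pydantic runtime objects need arbitrary_types_allowed. A small sketch of that pattern, with an illustrative Runner class that is not the reflector one:

# Sketch of the BaseModel pattern used for PipelineRunner above.
import asyncio

from pydantic import BaseModel, ConfigDict


class Runner(BaseModel):
    # queues/pipelines are not pydantic types, so opt in to arbitrary types
    model_config = ConfigDict(arbitrary_types_allowed=True)

    status: str = "idle"
    queue: asyncio.Queue | None = None


r = Runner(status="running", queue=asyncio.Queue(maxsize=16))
print(r.status, r.queue.maxsize)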
@@ -1,7 +1,5 @@
 from .audio_chunker import AudioChunkerProcessor  # noqa: F401
-from .audio_chunker_auto import AudioChunkerAutoProcessor  # noqa: F401
 from .audio_diarization_auto import AudioDiarizationAutoProcessor  # noqa: F401
-from .audio_downscale import AudioDownscaleProcessor  # noqa: F401
 from .audio_file_writer import AudioFileWriterProcessor  # noqa: F401
 from .audio_merge import AudioMergeProcessor  # noqa: F401
 from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
@@ -13,13 +11,6 @@ from .base import (  # noqa: F401
     Processor,
     ThreadedProcessor,
 )
-from .file_diarization import FileDiarizationProcessor  # noqa: F401
-from .file_diarization_auto import FileDiarizationAutoProcessor  # noqa: F401
-from .file_transcript import FileTranscriptProcessor  # noqa: F401
-from .file_transcript_auto import FileTranscriptAutoProcessor  # noqa: F401
-from .transcript_diarization_assembler import (
-    TranscriptDiarizationAssemblerProcessor,  # noqa: F401
-)
 from .transcript_final_summary import TranscriptFinalSummaryProcessor  # noqa: F401
 from .transcript_final_title import TranscriptFinalTitleProcessor  # noqa: F401
 from .transcript_liner import TranscriptLinerProcessor  # noqa: F401
@@ -1,78 +1,28 @@
-from typing import Optional
-
 import av
-from prometheus_client import Counter, Histogram

 from reflector.processors.base import Processor


 class AudioChunkerProcessor(Processor):
     """
-    Base class for assembling audio frames into chunks
+    Assemble audio frames into chunks
     """

     INPUT_TYPE = av.AudioFrame
     OUTPUT_TYPE = list[av.AudioFrame]

-    m_chunk = Histogram(
-        "audio_chunker",
-        "Time spent in AudioChunker.chunk",
-        ["backend"],
-    )
-    m_chunk_call = Counter(
-        "audio_chunker_call",
-        "Number of calls to AudioChunker.chunk",
-        ["backend"],
-    )
-    m_chunk_success = Counter(
-        "audio_chunker_success",
-        "Number of successful calls to AudioChunker.chunk",
-        ["backend"],
-    )
-    m_chunk_failure = Counter(
-        "audio_chunker_failure",
-        "Number of failed calls to AudioChunker.chunk",
-        ["backend"],
-    )
-
-    def __init__(self, *args, **kwargs):
-        name = self.__class__.__name__
-        self.m_chunk = self.m_chunk.labels(name)
-        self.m_chunk_call = self.m_chunk_call.labels(name)
-        self.m_chunk_success = self.m_chunk_success.labels(name)
-        self.m_chunk_failure = self.m_chunk_failure.labels(name)
-        super().__init__(*args, **kwargs)
+    def __init__(self, max_frames=256):
+        super().__init__()
         self.frames: list[av.AudioFrame] = []
+        self.max_frames = max_frames

     async def _push(self, data: av.AudioFrame):
-        """Process incoming audio frame"""
-        # Validate audio format on first frame
-        if len(self.frames) == 0:
-            if data.sample_rate != 16000 or len(data.layout.channels) != 1:
-                raise ValueError(
-                    f"AudioChunkerProcessor expects 16kHz mono audio, got {data.sample_rate}Hz "
-                    f"with {len(data.layout.channels)} channel(s). "
-                    f"Use AudioDownscaleProcessor before this processor."
-                )
-
-        try:
-            self.m_chunk_call.inc()
-            with self.m_chunk.time():
-                result = await self._chunk(data)
-            self.m_chunk_success.inc()
-            if result:
-                await self.emit(result)
-        except Exception:
-            self.m_chunk_failure.inc()
-            raise
-
-    async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
-        """
-        Process audio frame and return chunk when ready.
-        Subclasses should implement their chunking logic here.
-        """
-        raise NotImplementedError
+        self.frames.append(data)
+        if len(self.frames) >= self.max_frames:
+            await self.flush()

     async def _flush(self):
-        """Flush any remaining frames when processing ends"""
-        raise NotImplementedError
+        frames = self.frames[:]
+        self.frames = []
+        if frames:
+            await self.emit(frames)
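The rewritten AudioChunkerProcessor above reduces chunking to a fixed-size frame buffer. A rough standalone sketch of that buffering behaviour, with no av/Processor dependencies and purely illustrative names:

# Standalone sketch of the fixed-size buffering the simplified chunker uses:
# collect items until max_frames is reached, then emit them as one chunk.
class FrameBuffer:
    def __init__(self, max_frames: int = 256):
        self.max_frames = max_frames
        self.frames: list = []
        self.chunks: list[list] = []

    def push(self, frame) -> None:
        self.frames.append(frame)
        if len(self.frames) >= self.max_frames:
            self.flush()

    def flush(self) -> None:
        # emit whatever is buffered, even a partial chunk at end of stream
        frames, self.frames = self.frames[:], []
        if frames:
            self.chunks.append(frames)


buf = FrameBuffer(max_frames=4)
for i in range(10):
    buf.push(i)
buf.flush()
assert buf.chunks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]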
@@ -1,32 +0,0 @@
-import importlib
-
-from reflector.processors.audio_chunker import AudioChunkerProcessor
-from reflector.settings import settings
-
-
-class AudioChunkerAutoProcessor(AudioChunkerProcessor):
-    _registry = {}
-
-    @classmethod
-    def register(cls, name, kclass):
-        cls._registry[name] = kclass
-
-    def __new__(cls, name: str | None = None, **kwargs):
-        if name is None:
-            name = settings.AUDIO_CHUNKER_BACKEND
-        if name not in cls._registry:
-            module_name = f"reflector.processors.audio_chunker_{name}"
-            importlib.import_module(module_name)
-
-        # gather specific configuration for the processor
-        # search `AUDIO_CHUNKER_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
-        config = {}
-        name_upper = name.upper()
-        settings_prefix = "AUDIO_CHUNKER_"
-        config_prefix = f"{settings_prefix}{name_upper}_"
-        for key, value in settings:
-            if key.startswith(config_prefix):
-                config_name = key[len(settings_prefix) :].lower()
-                config[config_name] = value
-
-        return cls._registry[name](**config | kwargs)
@@ -1,34 +0,0 @@
-from typing import Optional
-
-import av
-
-from reflector.processors.audio_chunker import AudioChunkerProcessor
-from reflector.processors.audio_chunker_auto import AudioChunkerAutoProcessor
-
-
-class AudioChunkerFramesProcessor(AudioChunkerProcessor):
-    """
-    Simple frame-based audio chunker that emits chunks after a fixed number of frames
-    """
-
-    def __init__(self, max_frames=256, **kwargs):
-        super().__init__(**kwargs)
-        self.max_frames = max_frames
-
-    async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
-        self.frames.append(data)
-        if len(self.frames) >= self.max_frames:
-            frames_to_emit = self.frames[:]
-            self.frames = []
-            return frames_to_emit
-
-        return None
-
-    async def _flush(self):
-        frames = self.frames[:]
-        self.frames = []
-        if frames:
-            await self.emit(frames)
-
-
-AudioChunkerAutoProcessor.register("frames", AudioChunkerFramesProcessor)
@@ -1,298 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import av
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from silero_vad import VADIterator, load_silero_vad
|
|
||||||
|
|
||||||
from reflector.processors.audio_chunker import AudioChunkerProcessor
|
|
||||||
from reflector.processors.audio_chunker_auto import AudioChunkerAutoProcessor
|
|
||||||
|
|
||||||
|
|
||||||
class AudioChunkerSileroProcessor(AudioChunkerProcessor):
|
|
||||||
"""
|
|
||||||
Assemble audio frames into chunks with VAD-based speech detection using Silero VAD
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
block_frames=256,
|
|
||||||
max_frames=1024,
|
|
||||||
use_onnx=True,
|
|
||||||
min_frames=2,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.block_frames = block_frames
|
|
||||||
self.max_frames = max_frames
|
|
||||||
self.min_frames = min_frames
|
|
||||||
|
|
||||||
# Initialize Silero VAD
|
|
||||||
self._init_vad(use_onnx)
|
|
||||||
|
|
||||||
def _init_vad(self, use_onnx=False):
|
|
||||||
"""Initialize Silero VAD model"""
|
|
||||||
try:
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
self.vad_model = load_silero_vad(onnx=use_onnx)
|
|
||||||
self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
|
|
||||||
self.logger.info("Silero VAD initialized successfully")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to initialize Silero VAD: {e}")
|
|
||||||
self.vad_model = None
|
|
||||||
self.vad_iterator = None
|
|
||||||
|
|
||||||
async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
|
|
||||||
"""Process audio frame and return chunk when ready"""
|
|
||||||
self.frames.append(data)
|
|
||||||
|
|
||||||
# Check for speech segments every 32 frames (~1 second)
|
|
||||||
if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
|
|
||||||
return await self._process_block()
|
|
||||||
|
|
||||||
# Safety fallback - emit if we hit max frames
|
|
||||||
elif len(self.frames) >= self.max_frames:
|
|
||||||
self.logger.warning(
|
|
||||||
f"AudioChunkerSileroProcessor: Reached max frames ({self.max_frames}), "
|
|
||||||
f"emitting first {self.max_frames // 2} frames"
|
|
||||||
)
|
|
||||||
frames_to_emit = self.frames[: self.max_frames // 2]
|
|
||||||
self.frames = self.frames[self.max_frames // 2 :]
|
|
||||||
if len(frames_to_emit) >= self.min_frames:
|
|
||||||
return frames_to_emit
|
|
||||||
else:
|
|
||||||
self.logger.debug(
|
|
||||||
f"Ignoring fallback segment with {len(frames_to_emit)} frames "
|
|
||||||
f"(< {self.min_frames} minimum)"
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _process_block(self) -> Optional[list[av.AudioFrame]]:
|
|
||||||
# Need at least 32 frames for VAD detection (~1 second)
|
|
||||||
if len(self.frames) < 32 or self.vad_iterator is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Processing block with current buffer size
|
|
||||||
print(f"Processing block: {len(self.frames)} frames in buffer")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Convert frames to numpy array for VAD
|
|
||||||
audio_array = self._frames_to_numpy(self.frames)
|
|
||||||
|
|
||||||
if audio_array is None:
|
|
||||||
# Fallback: emit all frames if conversion failed
|
|
||||||
frames_to_emit = self.frames[:]
|
|
||||||
self.frames = []
|
|
||||||
if len(frames_to_emit) >= self.min_frames:
|
|
||||||
return frames_to_emit
|
|
||||||
else:
|
|
||||||
self.logger.debug(
|
|
||||||
f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
|
|
||||||
f"(< {self.min_frames} minimum)"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Find complete speech segments in the buffer
|
|
||||||
speech_end_frame = self._find_speech_segment_end(audio_array)
|
|
||||||
|
|
||||||
if speech_end_frame is None or speech_end_frame <= 0:
|
|
||||||
# No speech found but buffer is getting large
|
|
||||||
if len(self.frames) > 512:
|
|
||||||
# Check if it's all silence and can be discarded
|
|
||||||
# No speech segment found, buffer at {len(self.frames)} frames
|
|
||||||
|
|
||||||
# Could emit silence or discard old frames here
|
|
||||||
# For now, keep first 256 frames and discard older silence
|
|
||||||
if len(self.frames) > 768:
|
|
||||||
self.logger.debug(
|
|
||||||
f"Discarding {len(self.frames) - 256} old frames (likely silence)"
|
|
||||||
)
|
|
||||||
self.frames = self.frames[-256:]
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Calculate segment timing information
|
|
||||||
frames_to_emit = self.frames[:speech_end_frame]
|
|
||||||
|
|
||||||
# Get timing from av.AudioFrame
|
|
||||||
if frames_to_emit:
|
|
||||||
first_frame = frames_to_emit[0]
|
|
||||||
last_frame = frames_to_emit[-1]
|
|
||||||
sample_rate = first_frame.sample_rate
|
|
||||||
|
|
||||||
# Calculate duration
|
|
||||||
total_samples = sum(f.samples for f in frames_to_emit)
|
|
||||||
duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0
|
|
||||||
|
|
||||||
# Get timestamps if available
|
|
||||||
start_time = (
|
|
||||||
first_frame.pts * first_frame.time_base if first_frame.pts else 0
|
|
||||||
)
|
|
||||||
end_time = (
|
|
||||||
last_frame.pts * last_frame.time_base if last_frame.pts else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to HH:MM:SS format for logging
|
|
||||||
def format_time(seconds):
|
|
||||||
if not seconds:
|
|
||||||
return "00:00:00"
|
|
||||||
total_seconds = int(float(seconds))
|
|
||||||
hours = total_seconds // 3600
|
|
||||||
minutes = (total_seconds % 3600) // 60
|
|
||||||
secs = total_seconds % 60
|
|
||||||
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
|
||||||
|
|
||||||
start_formatted = format_time(start_time)
|
|
||||||
end_formatted = format_time(end_time)
|
|
||||||
|
|
||||||
# Keep remaining frames for next processing
|
|
||||||
remaining_after = len(self.frames) - speech_end_frame
|
|
||||||
|
|
||||||
# Single structured log line
|
|
||||||
self.logger.info(
|
|
||||||
"Speech segment found",
|
|
||||||
start=start_formatted,
|
|
||||||
end=end_formatted,
|
|
||||||
frames=speech_end_frame,
|
|
||||||
duration=round(duration_seconds, 2),
|
|
||||||
buffer_before=len(self.frames),
|
|
||||||
remaining=remaining_after,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Keep remaining frames for next processing
|
|
||||||
self.frames = self.frames[speech_end_frame:]
|
|
||||||
|
|
||||||
# Filter out segments with too few frames
|
|
||||||
if len(frames_to_emit) >= self.min_frames:
|
|
||||||
return frames_to_emit
|
|
||||||
else:
|
|
||||||
self.logger.debug(
|
|
||||||
f"Ignoring segment with {len(frames_to_emit)} frames "
|
|
||||||
f"(< {self.min_frames} minimum)"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error in VAD processing: {e}")
|
|
||||||
# Fallback to simple chunking
|
|
||||||
if len(self.frames) >= self.block_frames:
|
|
||||||
frames_to_emit = self.frames[: self.block_frames]
|
|
||||||
self.frames = self.frames[self.block_frames :]
|
|
||||||
if len(frames_to_emit) >= self.min_frames:
|
|
||||||
return frames_to_emit
|
|
||||||
else:
|
|
||||||
self.logger.debug(
|
|
||||||
f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
|
|
||||||
f"(< {self.min_frames} minimum)"
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
|
|
||||||
"""Convert av.AudioFrame list to numpy array for VAD processing"""
|
        if not frames:
            return None

        try:
            audio_data = []
            for frame in frames:
                frame_array = frame.to_ndarray()

                if len(frame_array.shape) == 2:
                    frame_array = frame_array.flatten()

                audio_data.append(frame_array)

            if not audio_data:
                return None

            combined_audio = np.concatenate(audio_data)

            # Ensure float32 format
            if combined_audio.dtype == np.int16:
                # Normalize int16 audio to float32 in range [-1.0, 1.0]
                combined_audio = combined_audio.astype(np.float32) / 32768.0
            elif combined_audio.dtype != np.float32:
                combined_audio = combined_audio.astype(np.float32)

            return combined_audio

        except Exception as e:
            self.logger.error(f"Error converting frames to numpy: {e}")
            return None

    def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
        """Find complete speech segments and return frame index at segment end"""
        if self.vad_iterator is None or len(audio_array) == 0:
            return None

        try:
            # Process audio in 512-sample windows for VAD
            window_size = 512
            min_silence_windows = 3  # Require 3 windows of silence after speech

            # Track speech state
            in_speech = False
            speech_start = None
            speech_end = None
            silence_count = 0

            for i in range(0, len(audio_array), window_size):
                chunk = audio_array[i : i + window_size]
                if len(chunk) < window_size:
                    chunk = np.pad(chunk, (0, window_size - len(chunk)))

                # Detect if this window has speech
                speech_dict = self.vad_iterator(chunk, return_seconds=True)

                # VADIterator returns dict with 'start' and 'end' when speech segments are detected
                if speech_dict:
                    if not in_speech:
                        # Speech started
                        speech_start = i
                        in_speech = True
                        # Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
                    silence_count = 0  # Reset silence counter
                    continue

                if not in_speech:
                    continue

                # We're in speech but found silence
                silence_count += 1
                if silence_count < min_silence_windows:
                    continue

                # Found end of speech segment
                speech_end = i - (min_silence_windows - 1) * window_size
                # Debug: print(f"Speech END at sample {speech_end}")

                # Convert sample position to frame index
                samples_per_frame = self.frames[0].samples if self.frames else 1024
                frame_index = speech_end // samples_per_frame

                # Ensure we don't exceed buffer
                frame_index = min(frame_index, len(self.frames))
                return frame_index

            return None

        except Exception as e:
            self.logger.error(f"Error finding speech segment: {e}")
            return None

    async def _flush(self):
        frames = self.frames[:]
        self.frames = []
        if frames:
            if len(frames) >= self.min_frames:
                await self.emit(frames)
            else:
                self.logger.debug(
                    f"Ignoring flush segment with {len(frames)} frames "
                    f"(< {self.min_frames} minimum)"
                )


AudioChunkerAutoProcessor.register("silero", AudioChunkerSileroProcessor)
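The chunker above only cuts a segment after three consecutive silent 512-sample windows. Below is a minimal standalone sketch of that windowing rule in isolation; it stubs the VAD with a simple energy threshold instead of the Silero model, and its names (`fake_vad`, `find_cut_sample`) are illustrative, not from the repository.

```python
import numpy as np

WINDOW = 512      # samples per VAD window (16 kHz mono assumed)
MIN_SILENCE = 3   # consecutive silent windows required to end a segment


def fake_vad(window: np.ndarray) -> bool:
    # stand-in for the Silero VADIterator: "speech" = enough energy
    return float(np.abs(window).mean()) > 0.01


def find_cut_sample(audio: np.ndarray) -> int | None:
    """Return the sample index where a speech segment ends, or None."""
    in_speech = False
    silence = 0
    for i in range(0, len(audio), WINDOW):
        window = audio[i : i + WINDOW]
        if len(window) < WINDOW:
            window = np.pad(window, (0, WINDOW - len(window)))
        if fake_vad(window):
            in_speech = True
            silence = 0
            continue
        if not in_speech:
            continue
        silence += 1
        if silence >= MIN_SILENCE:
            # back up to the first silent window, as the processor does
            return i - (MIN_SILENCE - 1) * WINDOW
    return None


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    speech = rng.normal(0, 0.1, 16000).astype(np.float32)  # ~1 s of "speech"
    silence = np.zeros(8000, dtype=np.float32)              # 0.5 s of silence
    print(find_cut_sample(np.concatenate([speech, silence])))  # cut shortly after sample 16000
```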
@@ -1,7 +1,6 @@
 from reflector.processors.base import Processor
 from reflector.processors.types import (
     AudioDiarizationInput,
-    DiarizationSegment,
     TitleSummary,
     Word,
 )
@@ -38,21 +37,18 @@ class AudioDiarizationProcessor(Processor):
     async def _diarize(self, data: AudioDiarizationInput):
         raise NotImplementedError
 
-    @classmethod
-    def assign_speaker(cls, words: list[Word], diarization: list[DiarizationSegment]):
-        cls._diarization_remove_overlap(diarization)
-        cls._diarization_remove_segment_without_words(words, diarization)
-        cls._diarization_merge_same_speaker(diarization)
-        cls._diarization_assign_speaker(words, diarization)
+    def assign_speaker(self, words: list[Word], diarization: list[dict]):
+        self._diarization_remove_overlap(diarization)
+        self._diarization_remove_segment_without_words(words, diarization)
+        self._diarization_merge_same_speaker(words, diarization)
+        self._diarization_assign_speaker(words, diarization)
 
-    @staticmethod
-    def iter_words_from_topics(topics: list[TitleSummary]):
+    def iter_words_from_topics(self, topics: TitleSummary):
         for topic in topics:
             for word in topic.transcript.words:
                 yield word
 
-    @staticmethod
-    def is_word_continuation(word_prev, word):
+    def is_word_continuation(self, word_prev, word):
         """
         Return True if the word is a continuation of the previous word
         by checking if the previous word is ending with a punctuation
@@ -65,8 +61,7 @@ class AudioDiarizationProcessor(Processor):
             return False
         return True
 
-    @staticmethod
-    def _diarization_remove_overlap(diarization: list[DiarizationSegment]):
+    def _diarization_remove_overlap(self, diarization: list[dict]):
         """
         Remove overlap in diarization results
 
@@ -91,9 +86,8 @@ class AudioDiarizationProcessor(Processor):
             else:
                 diarization_idx += 1
 
-    @staticmethod
     def _diarization_remove_segment_without_words(
-        words: list[Word], diarization: list[DiarizationSegment]
+        self, words: list[Word], diarization: list[dict]
     ):
         """
         Remove diarization segments without words
@@ -122,8 +116,9 @@ class AudioDiarizationProcessor(Processor):
             else:
                 diarization_idx += 1
 
-    @staticmethod
-    def _diarization_merge_same_speaker(diarization: list[DiarizationSegment]):
+    def _diarization_merge_same_speaker(
+        self, words: list[Word], diarization: list[dict]
+    ):
         """
         Merge diarization contigous segments with the same speaker
 
@@ -140,10 +135,7 @@ class AudioDiarizationProcessor(Processor):
             else:
                 diarization_idx += 1
 
-    @classmethod
-    def _diarization_assign_speaker(
-        cls, words: list[Word], diarization: list[DiarizationSegment]
-    ):
+    def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
         """
         Assign speaker to words based on diarization
 
@@ -151,7 +143,7 @@ class AudioDiarizationProcessor(Processor):
         """
 
         word_idx = 0
-        last_speaker = 0
+        last_speaker = None
         for d in diarization:
             start = d["start"]
             end = d["end"]
@@ -166,7 +158,7 @@ class AudioDiarizationProcessor(Processor):
             # If it's a continuation, assign with the last speaker
             is_continuation = False
             if word_idx > 0 and word_idx < len(words) - 1:
-                is_continuation = cls.is_word_continuation(
+                is_continuation = self.is_word_continuation(
                     *words[word_idx - 1 : word_idx + 1]
                 )
                 if is_continuation:
@@ -1,74 +0,0 @@
import os

import torch
import torchaudio
from pyannote.audio import Pipeline

from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
from reflector.processors.types import AudioDiarizationInput, DiarizationSegment


class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
    """Local diarization processor using pyannote.audio library"""

    def __init__(
        self,
        model_name: str = "pyannote/speaker-diarization-3.1",
        pyannote_auth_token: str | None = None,
        device: str | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
        self.device = device

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.logger.info(f"Loading pyannote diarization model: {self.model_name}")
        self.diarization_pipeline = Pipeline.from_pretrained(
            self.model_name, use_auth_token=self.auth_token
        )
        self.diarization_pipeline.to(torch.device(self.device))
        self.logger.info(f"Diarization model loaded on device: {self.device}")

    async def _diarize(self, data: AudioDiarizationInput) -> list[DiarizationSegment]:
        try:
            # Load audio file (audio_url is assumed to be a local file path)
            self.logger.info(f"Loading local audio file: {data.audio_url}")
            waveform, sample_rate = torchaudio.load(data.audio_url)
            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
            self.logger.info("Running speaker diarization")
            diarization = self.diarization_pipeline(audio_input)

            # Convert pyannote diarization output to our format
            segments = []
            for segment, _, speaker in diarization.itertracks(yield_label=True):
                # Extract speaker number from label (e.g., "SPEAKER_00" -> 0)
                speaker_id = 0
                if speaker.startswith("SPEAKER_"):
                    try:
                        speaker_id = int(speaker.split("_")[-1])
                    except (ValueError, IndexError):
                        # Fallback to hash-based ID if parsing fails
                        speaker_id = hash(speaker) % 1000

                segments.append(
                    {
                        "start": round(segment.start, 3),
                        "end": round(segment.end, 3),
                        "speaker": speaker_id,
                    }
                )

            self.logger.info(f"Diarization completed with {len(segments)} segments")
            return segments

        except Exception as e:
            self.logger.exception(f"Diarization failed: {e}")
            raise


AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)
@@ -1,60 +0,0 @@
from typing import Optional

import av
from av.audio.resampler import AudioResampler

from reflector.processors.base import Processor


def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
    frame_copy = frame.from_ndarray(
        frame.to_ndarray(),
        format=frame.format.name,
        layout=frame.layout.name,
    )
    frame_copy.sample_rate = frame.sample_rate
    frame_copy.pts = frame.pts
    frame_copy.time_base = frame.time_base
    return frame_copy


class AudioDownscaleProcessor(Processor):
    """
    Downscale audio frames to 16kHz mono format
    """

    INPUT_TYPE = av.AudioFrame
    OUTPUT_TYPE = av.AudioFrame

    def __init__(self, target_rate: int = 16000, target_layout: str = "mono", **kwargs):
        super().__init__(**kwargs)
        self.target_rate = target_rate
        self.target_layout = target_layout
        self.resampler: Optional[AudioResampler] = None
        self.needs_resampling: Optional[bool] = None

    async def _push(self, data: av.AudioFrame):
        if self.needs_resampling is None:
            self.needs_resampling = (
                data.sample_rate != self.target_rate
                or data.layout.name != self.target_layout
            )

            if self.needs_resampling:
                self.resampler = AudioResampler(
                    format="s16", layout=self.target_layout, rate=self.target_rate
                )

        if not self.needs_resampling or not self.resampler:
            await self.emit(data)
            return

        resampled_frames = self.resampler.resample(copy_frame(data))
        for resampled_frame in resampled_frames:
            await self.emit(resampled_frame)

    async def _flush(self):
        if self.needs_resampling and self.resampler:
            final_frames = self.resampler.resample(None)
            for frame in final_frames:
                await self.emit(frame)
@@ -16,46 +16,37 @@ class AudioMergeProcessor(Processor):
     INPUT_TYPE = list[av.AudioFrame]
     OUTPUT_TYPE = AudioFile
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
     async def _push(self, data: list[av.AudioFrame]):
         if not data:
             return
 
         # get audio information from first frame
         frame = data[0]
-        output_channels = len(frame.layout.channels)
-        output_sample_rate = frame.sample_rate
-        output_sample_width = frame.format.bytes
+        channels = len(frame.layout.channels)
+        sample_rate = frame.sample_rate
+        sample_width = frame.format.bytes
 
         # create audio file
         uu = uuid4().hex
         fd = io.BytesIO()
-
-        # Use PyAV to write frames
         out_container = av.open(fd, "w", format="wav")
-        out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
-        out_stream.layout = frame.layout.name
+        out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
 
         for frame in data:
            for packet in out_stream.encode(frame):
                out_container.mux(packet)
-
-        # Flush the encoder
        for packet in out_stream.encode(None):
            out_container.mux(packet)
        out_container.close()
 
        fd.seek(0)
 
        # emit audio file
        audiofile = AudioFile(
            name=f"{monotonic_ns()}-{uu}.wav",
            fd=fd,
-            sample_rate=output_sample_rate,
-            channels=output_channels,
-            sample_width=output_sample_width,
+            sample_rate=sample_rate,
+            channels=channels,
+            sample_width=sample_width,
            timestamp=data[0].pts * data[0].time_base,
        )
@@ -21,11 +21,7 @@ from reflector.settings import settings
 
 
 class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
-    def __init__(
-        self,
-        modal_api_key: str | None = None,
-        **kwargs,
-    ):
+    def __init__(self, modal_api_key: str | None = None, **kwargs):
         super().__init__()
         if not settings.TRANSCRIPT_URL:
             raise Exception(
@@ -173,7 +173,6 @@ class Processor(Emitter):
         except Exception:
             self.m_processor_failure.inc()
             self.logger.exception("Error in push")
-            raise
 
     async def flush(self):
         """
@@ -241,45 +240,33 @@ class ThreadedProcessor(Processor):
         self.INPUT_TYPE = processor.INPUT_TYPE
         self.OUTPUT_TYPE = processor.OUTPUT_TYPE
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self.queue = asyncio.Queue(maxsize=50)
-        self.task: asyncio.Task | None = None
+        self.queue = asyncio.Queue()
+        self.task = asyncio.get_running_loop().create_task(self.loop())
 
     def set_pipeline(self, pipeline: "Pipeline"):
         super().set_pipeline(pipeline)
         self.processor.set_pipeline(pipeline)
 
     async def loop(self):
-        try:
-            while True:
-                data = await self.queue.get()
-                self.m_processor_queue.set(self.queue.qsize())
-                with self.m_processor_queue_in_progress.track_inprogress():
-                    try:
-                        if data is None:
-                            await self.processor.flush()
-                            break
-                        try:
-                            await self.processor.push(data)
-                        except Exception:
-                            self.logger.error(
-                                f"Error in push {self.processor.__class__.__name__}"
-                                ", continue"
-                            )
-                    finally:
-                        self.queue.task_done()
-        except Exception as e:
-            logger.error(f"Crash in {self.__class__.__name__}: {e}", exc_info=e)
-
-    async def _ensure_task(self):
-        if self.task is None:
-            self.task = asyncio.get_running_loop().create_task(self.loop())
-
-            # XXX not doing a sleep here make the whole pipeline prior the thread
-            # to be running without having a chance to work on the task here.
-            await asyncio.sleep(0)
+        while True:
+            data = await self.queue.get()
+            self.m_processor_queue.set(self.queue.qsize())
+            with self.m_processor_queue_in_progress.track_inprogress():
+                try:
+                    if data is None:
+                        await self.processor.flush()
+                        break
+                    try:
+                        await self.processor.push(data)
+                    except Exception:
+                        self.logger.error(
+                            f"Error in push {self.processor.__class__.__name__}"
+                            ", continue"
+                        )
+                finally:
+                    self.queue.task_done()
 
     async def _push(self, data):
-        await self._ensure_task()
         await self.queue.put(data)
 
     async def _flush(self):
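Both sides of this hunk drain an `asyncio.Queue` in a background task and treat `None` as a flush-and-stop sentinel; the removed side additionally created the task lazily on first push, bounded the queue for backpressure, and wrapped the loop in a crash guard. A standalone sketch of that sentinel pattern follows, under illustrative names (`Worker`, `handle`) rather than the repository's `ThreadedProcessor`.

```python
import asyncio


class Worker:
    """Drain a queue in a background task; `None` means flush and stop."""

    def __init__(self, maxsize: int = 50):
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=maxsize)
        self.task: asyncio.Task | None = None

    async def _ensure_task(self):
        # lazy creation: avoids needing a running event loop in __init__
        if self.task is None:
            self.task = asyncio.get_running_loop().create_task(self._loop())
            await asyncio.sleep(0)  # let the consumer task start

    async def _loop(self):
        while True:
            item = await self.queue.get()
            try:
                if item is None:
                    print("flush")
                    break
                await self.handle(item)
            finally:
                self.queue.task_done()

    async def handle(self, item):
        print("processing", item)

    async def push(self, item):
        await self._ensure_task()
        await self.queue.put(item)

    async def flush(self):
        await self._ensure_task()
        await self.queue.put(None)
        await self.task


async def main():
    w = Worker()
    for i in range(3):
        await w.push(i)
    await w.flush()


asyncio.run(main())
```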
@@ -1,33 +0,0 @@
from pydantic import BaseModel

from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment


class FileDiarizationInput(BaseModel):
    """Input for file diarization containing audio URL"""

    audio_url: str


class FileDiarizationOutput(BaseModel):
    """Output for file diarization containing speaker segments"""

    diarization: list[DiarizationSegment]


class FileDiarizationProcessor(Processor):
    """
    Diarize complete audio files from URL
    """

    INPUT_TYPE = FileDiarizationInput
    OUTPUT_TYPE = FileDiarizationOutput

    async def _push(self, data: FileDiarizationInput):
        result = await self._diarize(data)
        if result:
            await self.emit(result)

    async def _diarize(self, data: FileDiarizationInput):
        raise NotImplementedError
@@ -1,33 +0,0 @@
import importlib

from reflector.processors.file_diarization import FileDiarizationProcessor
from reflector.settings import settings


class FileDiarizationAutoProcessor(FileDiarizationProcessor):
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        cls._registry[name] = kclass

    def __new__(cls, name: str | None = None, **kwargs):
        if name is None:
            name = settings.DIARIZATION_BACKEND

        if name not in cls._registry:
            module_name = f"reflector.processors.file_diarization_{name}"
            importlib.import_module(module_name)

        # gather specific configuration for the processor
        # search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
        config = {}
        name_upper = name.upper()
        settings_prefix = "DIARIZATION_"
        config_prefix = f"{settings_prefix}{name_upper}_"
        for key, value in settings:
            if key.startswith(config_prefix):
                config_name = key[len(settings_prefix) :].lower()
                config[config_name] = value

        return cls._registry[name](**config | kwargs)
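The `__new__` above selects a backend by name and forwards any `DIARIZATION_<BACKEND>_X` setting to the chosen class as the keyword argument `<backend>_x`. A small illustration of that key mapping, using a plain dict and hypothetical setting values (not taken from the repository's configuration):

```python
# Illustration only: how the auto-processor turns settings keys into kwargs.
settings = {
    "DIARIZATION_BACKEND": "modal",
    "DIARIZATION_MODAL_API_KEY": "sk-example",        # hypothetical value
    "DIARIZATION_URL": "https://example.invalid/gpu",  # hypothetical value
}

name = settings["DIARIZATION_BACKEND"]                # "modal"
settings_prefix = "DIARIZATION_"
config_prefix = f"{settings_prefix}{name.upper()}_"   # "DIARIZATION_MODAL_"

config = {}
for key, value in settings.items():
    if key.startswith(config_prefix):
        # strip only the generic prefix, keep the backend name, lowercase:
        # DIARIZATION_MODAL_API_KEY -> modal_api_key
        config[key[len(settings_prefix):].lower()] = value

print(config)  # {'modal_api_key': 'sk-example'}
```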
@@ -1,58 +0,0 @@
"""
File diarization implementation using the GPU service from modal.com

API will be a POST request to DIARIZATION_URL:

```
POST /diarize?audio_file_url=...&timestamp=0
Authorization: Bearer <modal_api_key>
```
"""

import httpx

from reflector.processors.file_diarization import (
    FileDiarizationInput,
    FileDiarizationOutput,
    FileDiarizationProcessor,
)
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.settings import settings


class FileDiarizationModalProcessor(FileDiarizationProcessor):
    def __init__(self, modal_api_key: str | None = None, **kwargs):
        super().__init__(**kwargs)
        if not settings.DIARIZATION_URL:
            raise Exception(
                "DIARIZATION_URL required to use FileDiarizationModalProcessor"
            )
        self.diarization_url = settings.DIARIZATION_URL + "/diarize"
        self.file_timeout = settings.DIARIZATION_FILE_TIMEOUT
        self.modal_api_key = modal_api_key

    async def _diarize(self, data: FileDiarizationInput):
        """Get speaker diarization for file"""
        self.logger.info(f"Starting diarization from {data.audio_url}")

        headers = {}
        if self.modal_api_key:
            headers["Authorization"] = f"Bearer {self.modal_api_key}"

        async with httpx.AsyncClient(timeout=self.file_timeout) as client:
            response = await client.post(
                self.diarization_url,
                headers=headers,
                params={
                    "audio_file_url": data.audio_url,
                    "timestamp": 0,
                },
                follow_redirects=True,
            )
            response.raise_for_status()
            diarization_data = response.json()["diarization"]

        return FileDiarizationOutput(diarization=diarization_data)


FileDiarizationAutoProcessor.register("modal", FileDiarizationModalProcessor)
@@ -1,65 +0,0 @@
from prometheus_client import Counter, Histogram

from reflector.processors.base import Processor
from reflector.processors.types import Transcript


class FileTranscriptInput:
    """Input for file transcription containing audio URL and language settings"""

    def __init__(self, audio_url: str, language: str = "en"):
        self.audio_url = audio_url
        self.language = language


class FileTranscriptProcessor(Processor):
    """
    Transcript complete audio files from URL
    """

    INPUT_TYPE = FileTranscriptInput
    OUTPUT_TYPE = Transcript

    m_transcript = Histogram(
        "file_transcript",
        "Time spent in FileTranscript.transcript",
        ["backend"],
    )
    m_transcript_call = Counter(
        "file_transcript_call",
        "Number of calls to FileTranscript.transcript",
        ["backend"],
    )
    m_transcript_success = Counter(
        "file_transcript_success",
        "Number of successful calls to FileTranscript.transcript",
        ["backend"],
    )
    m_transcript_failure = Counter(
        "file_transcript_failure",
        "Number of failed calls to FileTranscript.transcript",
        ["backend"],
    )

    def __init__(self, *args, **kwargs):
        name = self.__class__.__name__
        self.m_transcript = self.m_transcript.labels(name)
        self.m_transcript_call = self.m_transcript_call.labels(name)
        self.m_transcript_success = self.m_transcript_success.labels(name)
        self.m_transcript_failure = self.m_transcript_failure.labels(name)
        super().__init__(*args, **kwargs)

    async def _push(self, data: FileTranscriptInput):
        try:
            self.m_transcript_call.inc()
            with self.m_transcript.time():
                result = await self._transcript(data)
            self.m_transcript_success.inc()
            if result:
                await self.emit(result)
        except Exception:
            self.m_transcript_failure.inc()
            raise

    async def _transcript(self, data: FileTranscriptInput):
        raise NotImplementedError
@@ -1,32 +0,0 @@
import importlib

from reflector.processors.file_transcript import FileTranscriptProcessor
from reflector.settings import settings


class FileTranscriptAutoProcessor(FileTranscriptProcessor):
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        cls._registry[name] = kclass

    def __new__(cls, name: str | None = None, **kwargs):
        if name is None:
            name = settings.TRANSCRIPT_BACKEND
        if name not in cls._registry:
            module_name = f"reflector.processors.file_transcript_{name}"
            importlib.import_module(module_name)

        # gather specific configuration for the processor
        # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
        config = {}
        name_upper = name.upper()
        settings_prefix = "TRANSCRIPT_"
        config_prefix = f"{settings_prefix}{name_upper}_"
        for key, value in settings:
            if key.startswith(config_prefix):
                config_name = key[len(settings_prefix) :].lower()
                config[config_name] = value

        return cls._registry[name](**config | kwargs)
@@ -1,78 +0,0 @@
"""
File transcription implementation using the GPU service from modal.com

API will be a POST request to TRANSCRIPT_URL:

```json
{
    "audio_file_url": "https://...",
    "language": "en",
    "model": "parakeet-tdt-0.6b-v2",
    "batch": true
}
```
"""

import httpx

from reflector.processors.file_transcript import (
    FileTranscriptInput,
    FileTranscriptProcessor,
)
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.types import Transcript, Word
from reflector.settings import settings


class FileTranscriptModalProcessor(FileTranscriptProcessor):
    def __init__(self, modal_api_key: str | None = None, **kwargs):
        super().__init__(**kwargs)
        if not settings.TRANSCRIPT_URL:
            raise Exception(
                "TRANSCRIPT_URL required to use FileTranscriptModalProcessor"
            )
        self.transcript_url = settings.TRANSCRIPT_URL
        self.file_timeout = settings.TRANSCRIPT_FILE_TIMEOUT
        self.modal_api_key = modal_api_key

    async def _transcript(self, data: FileTranscriptInput):
        """Send full file to Modal for transcription"""
        url = f"{self.transcript_url}/v1/audio/transcriptions-from-url"

        self.logger.info(f"Starting file transcription from {data.audio_url}")

        headers = {}
        if self.modal_api_key:
            headers["Authorization"] = f"Bearer {self.modal_api_key}"

        async with httpx.AsyncClient(timeout=self.file_timeout) as client:
            response = await client.post(
                url,
                headers=headers,
                json={
                    "audio_file_url": data.audio_url,
                    "language": data.language,
                    "batch": True,
                },
                follow_redirects=True,
            )
            response.raise_for_status()
            result = response.json()

        words = [
            Word(
                text=word_info["word"],
                start=word_info["start"],
                end=word_info["end"],
            )
            for word_info in result.get("words", [])
        ]

        # words come not in order
        words.sort(key=lambda w: w.start)

        return Transcript(words=words)


# Register with the auto processor
FileTranscriptAutoProcessor.register("modal", FileTranscriptModalProcessor)
@@ -1,45 +0,0 @@
"""
Processor to assemble transcript with diarization results
"""

from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment, Transcript


class TranscriptDiarizationAssemblerInput:
    """Input containing transcript and diarization data"""

    def __init__(self, transcript: Transcript, diarization: list[DiarizationSegment]):
        self.transcript = transcript
        self.diarization = diarization


class TranscriptDiarizationAssemblerProcessor(Processor):
    """
    Assemble transcript with diarization results by applying speaker assignments
    """

    INPUT_TYPE = TranscriptDiarizationAssemblerInput
    OUTPUT_TYPE = Transcript

    async def _push(self, data: TranscriptDiarizationAssemblerInput):
        result = await self._assemble(data)
        if result:
            await self.emit(result)

    async def _assemble(self, data: TranscriptDiarizationAssemblerInput):
        """Apply diarization to transcript words"""
        if not data.diarization:
            self.logger.info(
                "No diarization data provided, returning original transcript"
            )
            return data.transcript

        # Reuse logic from AudioDiarizationProcessor
        processor = AudioDiarizationProcessor()
        words = data.transcript.words
        processor.assign_speaker(words, data.diarization)

        self.logger.info(f"Applied diarization to {len(words)} words")
        return data.transcript