mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Compare commits
5 Commits
v0.12.0
...
mathieu/ca
| Author | SHA1 | Date | |
|---|---|---|---|
| 311d453e41 | |||
| f286f0882c | |||
| ffcafb3bf2 | |||
| 27075d840c | |||
| 30b5cd45e3 |
5
.github/workflows/db_migrations.yml
vendored
5
.github/workflows/db_migrations.yml
vendored
@@ -2,8 +2,6 @@ name: Test Database Migrations
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "server/migrations/**"
|
||||
- "server/reflector/db/**"
|
||||
@@ -19,9 +17,6 @@ on:
|
||||
jobs:
|
||||
test-migrations:
|
||||
runs-on: ubuntu-latest
|
||||
concurrency:
|
||||
group: db-ubuntu-latest-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
|
||||
77
.github/workflows/deploy.yml
vendored
77
.github/workflows/deploy.yml
vendored
@@ -8,30 +8,18 @@ env:
|
||||
ECR_REPOSITORY: reflector
|
||||
|
||||
jobs:
|
||||
build:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- platform: linux/amd64
|
||||
runner: linux-amd64
|
||||
arch: amd64
|
||||
- platform: linux/arm64
|
||||
runner: linux-arm64
|
||||
arch: arm64
|
||||
|
||||
runs-on: ${{ matrix.runner }}
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
deployments: write
|
||||
contents: read
|
||||
|
||||
outputs:
|
||||
registry: ${{ steps.login-ecr.outputs.registry }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
uses: aws-actions/configure-aws-credentials@0e613a0980cbf65ed5b322eb7a1e075d28913a83
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
@@ -39,52 +27,21 @@ jobs:
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
uses: aws-actions/amazon-ecr-login@62f4f872db3836360b72999f4b87f1ff13310f3a
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v2
|
||||
|
||||
- name: Build and push ${{ matrix.arch }}
|
||||
uses: docker/build-push-action@v5
|
||||
- name: Build and push
|
||||
id: docker_build
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: server
|
||||
platforms: ${{ matrix.platform }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest-${{ matrix.arch }}
|
||||
cache-from: type=gha,scope=${{ matrix.arch }}
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
|
||||
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
||||
provenance: false
|
||||
|
||||
create-manifest:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [build]
|
||||
|
||||
permissions:
|
||||
deployments: write
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Create and push multi-arch manifest
|
||||
run: |
|
||||
# Get the registry URL (since we can't easily access job outputs in matrix)
|
||||
ECR_REGISTRY=$(aws ecr describe-registry --query 'registryId' --output text).dkr.ecr.${{ env.AWS_REGION }}.amazonaws.com
|
||||
|
||||
docker manifest create \
|
||||
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest \
|
||||
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-amd64 \
|
||||
$ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-arm64
|
||||
|
||||
docker manifest push $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest
|
||||
|
||||
echo "✅ Multi-arch manifest pushed: $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest"
|
||||
tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
45
.github/workflows/test_next_server.yml
vendored
45
.github/workflows/test_next_server.yml
vendored
@@ -1,45 +0,0 @@
|
||||
name: Test Next Server
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "www/**"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "www/**"
|
||||
|
||||
jobs:
|
||||
test-next-server:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./www
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 8
|
||||
|
||||
- name: Setup Node.js cache
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'pnpm'
|
||||
cache-dependency-path: './www/pnpm-lock.yaml'
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm test
|
||||
49
.github/workflows/test_server.yml
vendored
49
.github/workflows/test_server.yml
vendored
@@ -5,17 +5,12 @@ on:
|
||||
paths:
|
||||
- "server/**"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "server/**"
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
runs-on: ubuntu-latest
|
||||
concurrency:
|
||||
group: pytest-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
services:
|
||||
redis:
|
||||
image: redis:6
|
||||
@@ -24,47 +19,29 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
enable-cache: true
|
||||
working-directory: server
|
||||
|
||||
- name: Tests
|
||||
run: |
|
||||
cd server
|
||||
uv run -m pytest -v tests
|
||||
|
||||
docker-amd64:
|
||||
runs-on: linux-amd64
|
||||
concurrency:
|
||||
group: docker-amd64-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
docker:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build AMD64
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build and push
|
||||
id: docker_build
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: server
|
||||
platforms: linux/amd64
|
||||
cache-from: type=gha,scope=amd64
|
||||
cache-to: type=gha,mode=max,scope=amd64
|
||||
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
||||
|
||||
docker-arm64:
|
||||
runs-on: linux-arm64
|
||||
concurrency:
|
||||
group: docker-arm64-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build ARM64
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: server
|
||||
platforms: linux/arm64
|
||||
cache-from: type=gha,scope=arm64
|
||||
cache-to: type=gha,mode=max,scope=arm64
|
||||
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -14,7 +14,4 @@ data/
|
||||
www/REFACTOR.md
|
||||
www/reload-frontend
|
||||
server/test.sqlite
|
||||
CLAUDE.local.md
|
||||
www/.env.development
|
||||
www/.env.production
|
||||
.playwright-mcp
|
||||
CLAUDE.local.md
|
||||
@@ -1 +0,0 @@
|
||||
b9d891d3424f371642cb032ecfd0e2564470a72c:server/tests/test_transcripts_recording_deletion.py:generic-api-key:15
|
||||
@@ -27,8 +27,3 @@ repos:
|
||||
files: ^server/
|
||||
- id: ruff-format
|
||||
files: ^server/
|
||||
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.28.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
127
CHANGELOG.md
127
CHANGELOG.md
@@ -1,132 +1,5 @@
|
||||
# Changelog
|
||||
|
||||
## [0.12.0](https://github.com/Monadical-SAS/reflector/compare/v0.11.0...v0.12.0) (2025-09-17)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* calendar integration ([#608](https://github.com/Monadical-SAS/reflector/issues/608)) ([6f680b5](https://github.com/Monadical-SAS/reflector/commit/6f680b57954c688882c4ed49f40f161c52a00a24))
|
||||
* self-hosted gpu api ([#636](https://github.com/Monadical-SAS/reflector/issues/636)) ([ab859d6](https://github.com/Monadical-SAS/reflector/commit/ab859d65a6bded904133a163a081a651b3938d42))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* ignore player hotkeys for text inputs ([#646](https://github.com/Monadical-SAS/reflector/issues/646)) ([fa049e8](https://github.com/Monadical-SAS/reflector/commit/fa049e8d068190ce7ea015fd9fcccb8543f54a3f))
|
||||
|
||||
## [0.11.0](https://github.com/Monadical-SAS/reflector/compare/v0.10.0...v0.11.0) (2025-09-16)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* remove profanity filter that was there for conference ([#652](https://github.com/Monadical-SAS/reflector/issues/652)) ([b42f7cf](https://github.com/Monadical-SAS/reflector/commit/b42f7cfc606783afcee792590efcc78b507468ab))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* zulip and consent handler on the file pipeline ([#645](https://github.com/Monadical-SAS/reflector/issues/645)) ([5f143fe](https://github.com/Monadical-SAS/reflector/commit/5f143fe3640875dcb56c26694254a93189281d17))
|
||||
* zulip stream and topic selection in share dialog ([#644](https://github.com/Monadical-SAS/reflector/issues/644)) ([c546e69](https://github.com/Monadical-SAS/reflector/commit/c546e69739e68bb74fbc877eb62609928e5b8de6))
|
||||
|
||||
## [0.10.0](https://github.com/Monadical-SAS/reflector/compare/v0.9.0...v0.10.0) (2025-09-11)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* replace nextjs-config with environment variables ([#632](https://github.com/Monadical-SAS/reflector/issues/632)) ([369ecdf](https://github.com/Monadical-SAS/reflector/commit/369ecdff13f3862d926a9c0b87df52c9d94c4dde))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* anonymous users transcript permissions ([#621](https://github.com/Monadical-SAS/reflector/issues/621)) ([f81fe99](https://github.com/Monadical-SAS/reflector/commit/f81fe9948a9237b3e0001b2d8ca84f54d76878f9))
|
||||
* auth post ([#624](https://github.com/Monadical-SAS/reflector/issues/624)) ([cde99ca](https://github.com/Monadical-SAS/reflector/commit/cde99ca2716f84ba26798f289047732f0448742e))
|
||||
* auth post ([#626](https://github.com/Monadical-SAS/reflector/issues/626)) ([3b85ff3](https://github.com/Monadical-SAS/reflector/commit/3b85ff3bdf4fb053b103070646811bc990c0e70a))
|
||||
* auth post ([#627](https://github.com/Monadical-SAS/reflector/issues/627)) ([962038e](https://github.com/Monadical-SAS/reflector/commit/962038ee3f2a555dc3c03856be0e4409456e0996))
|
||||
* missing follow_redirects=True on modal endpoint ([#630](https://github.com/Monadical-SAS/reflector/issues/630)) ([fc363bd](https://github.com/Monadical-SAS/reflector/commit/fc363bd49b17b075e64f9186e5e0185abc325ea7))
|
||||
* sync backend and frontend token refresh logic ([#614](https://github.com/Monadical-SAS/reflector/issues/614)) ([5a5b323](https://github.com/Monadical-SAS/reflector/commit/5a5b3233820df9536da75e87ce6184a983d4713a))
|
||||
|
||||
## [0.9.0](https://github.com/Monadical-SAS/reflector/compare/v0.8.2...v0.9.0) (2025-09-06)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* frontend openapi react query ([#606](https://github.com/Monadical-SAS/reflector/issues/606)) ([c4d2825](https://github.com/Monadical-SAS/reflector/commit/c4d2825c81f81ad8835629fbf6ea8c7383f8c31b))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* align whisper transcriber api with parakeet ([#602](https://github.com/Monadical-SAS/reflector/issues/602)) ([0663700](https://github.com/Monadical-SAS/reflector/commit/0663700a615a4af69a03c96c410f049e23ec9443))
|
||||
* kv use tls explicit ([#610](https://github.com/Monadical-SAS/reflector/issues/610)) ([08d88ec](https://github.com/Monadical-SAS/reflector/commit/08d88ec349f38b0d13e0fa4cb73486c8dfd31836))
|
||||
* source kind for file processing ([#601](https://github.com/Monadical-SAS/reflector/issues/601)) ([dc82f8b](https://github.com/Monadical-SAS/reflector/commit/dc82f8bb3bdf3ab3d4088e592a30fd63907319e1))
|
||||
* token refresh locking ([#613](https://github.com/Monadical-SAS/reflector/issues/613)) ([7f5a4c9](https://github.com/Monadical-SAS/reflector/commit/7f5a4c9ddc7fd098860c8bdda2ca3b57f63ded2f))
|
||||
|
||||
## [0.8.2](https://github.com/Monadical-SAS/reflector/compare/v0.8.1...v0.8.2) (2025-08-29)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* search-logspam ([#593](https://github.com/Monadical-SAS/reflector/issues/593)) ([695d1a9](https://github.com/Monadical-SAS/reflector/commit/695d1a957d4cd862753049f9beed88836cabd5ab))
|
||||
|
||||
## [0.8.1](https://github.com/Monadical-SAS/reflector/compare/v0.8.0...v0.8.1) (2025-08-29)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* make webhook secret/url allowing null ([#590](https://github.com/Monadical-SAS/reflector/issues/590)) ([84a3812](https://github.com/Monadical-SAS/reflector/commit/84a381220bc606231d08d6f71d4babc818fa3c75))
|
||||
|
||||
## [0.8.0](https://github.com/Monadical-SAS/reflector/compare/v0.7.3...v0.8.0) (2025-08-29)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* **cleanup:** add automatic data retention for public instances ([#574](https://github.com/Monadical-SAS/reflector/issues/574)) ([6f0c7c1](https://github.com/Monadical-SAS/reflector/commit/6f0c7c1a5e751713366886c8e764c2009e12ba72))
|
||||
* **rooms:** add webhook for transcript completion ([#578](https://github.com/Monadical-SAS/reflector/issues/578)) ([88ed7cf](https://github.com/Monadical-SAS/reflector/commit/88ed7cfa7804794b9b54cad4c3facc8a98cf85fd))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* file pipeline status reporting and websocket updates ([#589](https://github.com/Monadical-SAS/reflector/issues/589)) ([9dfd769](https://github.com/Monadical-SAS/reflector/commit/9dfd76996f851cc52be54feea078adbc0816dc57))
|
||||
* Igor/evaluation ([#575](https://github.com/Monadical-SAS/reflector/issues/575)) ([124ce03](https://github.com/Monadical-SAS/reflector/commit/124ce03bf86044c18313d27228a25da4bc20c9c5))
|
||||
* optimize parakeet transcription batching algorithm ([#577](https://github.com/Monadical-SAS/reflector/issues/577)) ([7030e0f](https://github.com/Monadical-SAS/reflector/commit/7030e0f23649a8cf6c1eb6d5889684a41ce849ec))
|
||||
|
||||
## [0.7.3](https://github.com/Monadical-SAS/reflector/compare/v0.7.2...v0.7.3) (2025-08-22)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* cleaned repo, and get git-leaks clean ([359280d](https://github.com/Monadical-SAS/reflector/commit/359280dd340433ba4402ed69034094884c825e67))
|
||||
* restore previous behavior on live pipeline + audio downscaler ([#561](https://github.com/Monadical-SAS/reflector/issues/561)) ([9265d20](https://github.com/Monadical-SAS/reflector/commit/9265d201b590d23c628c5f19251b70f473859043))
|
||||
|
||||
## [0.7.2](https://github.com/Monadical-SAS/reflector/compare/v0.7.1...v0.7.2) (2025-08-21)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* docker image not loading libgomp.so.1 for torch ([#560](https://github.com/Monadical-SAS/reflector/issues/560)) ([773fccd](https://github.com/Monadical-SAS/reflector/commit/773fccd93e887c3493abc2e4a4864dddce610177))
|
||||
* include shared rooms to search ([#558](https://github.com/Monadical-SAS/reflector/issues/558)) ([499eced](https://github.com/Monadical-SAS/reflector/commit/499eced3360b84fb3a90e1c8a3b554290d21adc2))
|
||||
|
||||
## [0.7.1](https://github.com/Monadical-SAS/reflector/compare/v0.7.0...v0.7.1) (2025-08-21)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* webvtt db null expectation mismatch ([#556](https://github.com/Monadical-SAS/reflector/issues/556)) ([e67ad1a](https://github.com/Monadical-SAS/reflector/commit/e67ad1a4a2054467bfeb1e0258fbac5868aaaf21))
|
||||
|
||||
## [0.7.0](https://github.com/Monadical-SAS/reflector/compare/v0.6.1...v0.7.0) (2025-08-21)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* delete recording with transcript ([#547](https://github.com/Monadical-SAS/reflector/issues/547)) ([99cc984](https://github.com/Monadical-SAS/reflector/commit/99cc9840b3f5de01e0adfbfae93234042d706d13))
|
||||
* pipeline improvement with file processing, parakeet, silero-vad ([#540](https://github.com/Monadical-SAS/reflector/issues/540)) ([bcc29c9](https://github.com/Monadical-SAS/reflector/commit/bcc29c9e0050ae215f89d460e9d645aaf6a5e486))
|
||||
* postgresql migration and removal of sqlite in pytest ([#546](https://github.com/Monadical-SAS/reflector/issues/546)) ([cd1990f](https://github.com/Monadical-SAS/reflector/commit/cd1990f8f0fe1503ef5069512f33777a73a93d7f))
|
||||
* search backend ([#537](https://github.com/Monadical-SAS/reflector/issues/537)) ([5f9b892](https://github.com/Monadical-SAS/reflector/commit/5f9b89260c9ef7f3c921319719467df22830453f))
|
||||
* search frontend ([#551](https://github.com/Monadical-SAS/reflector/issues/551)) ([3657242](https://github.com/Monadical-SAS/reflector/commit/365724271ca6e615e3425125a69ae2b46ce39285))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* evaluation cli event wrap ([#536](https://github.com/Monadical-SAS/reflector/issues/536)) ([941c3db](https://github.com/Monadical-SAS/reflector/commit/941c3db0bdacc7b61fea412f3746cc5a7cb67836))
|
||||
* use structlog not logging ([#550](https://github.com/Monadical-SAS/reflector/issues/550)) ([27e2f81](https://github.com/Monadical-SAS/reflector/commit/27e2f81fda5232e53edc729d3e99c5ef03adbfe9))
|
||||
|
||||
## [0.6.1](https://github.com/Monadical-SAS/reflector/compare/v0.6.0...v0.6.1) (2025-08-06)
|
||||
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@ pnpm install
|
||||
|
||||
# Copy configuration templates
|
||||
cp .env_template .env
|
||||
cp config-template.ts config.ts
|
||||
```
|
||||
|
||||
**Development:**
|
||||
|
||||
497
ICS_IMPLEMENTATION.md
Normal file
497
ICS_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,497 @@
|
||||
# ICS Calendar Integration - Implementation Guide
|
||||
|
||||
## Overview
|
||||
This document provides detailed implementation guidance for integrating ICS calendar feeds with Reflector rooms. Unlike CalDAV which requires complex authentication and protocol handling, ICS integration uses simple HTTP(S) fetching of calendar files.
|
||||
|
||||
## Key Differences from CalDAV Approach
|
||||
|
||||
| Aspect | CalDAV | ICS |
|
||||
|--------|--------|-----|
|
||||
| Protocol | WebDAV extension | HTTP/HTTPS GET |
|
||||
| Authentication | Username/password, OAuth | Tokens embedded in URL |
|
||||
| Data Access | Selective event queries | Full calendar download |
|
||||
| Implementation | Complex (caldav library) | Simple (requests + icalendar) |
|
||||
| Real-time Updates | Supported | Polling only |
|
||||
| Write Access | Yes | No (read-only) |
|
||||
|
||||
## Technical Architecture
|
||||
|
||||
### 1. ICS Fetching Service
|
||||
|
||||
```python
|
||||
# reflector/services/ics_sync.py
|
||||
|
||||
import requests
|
||||
from icalendar import Calendar
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
class ICSFetchService:
|
||||
def __init__(self):
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({'User-Agent': 'Reflector/1.0'})
|
||||
|
||||
def fetch_ics(self, url: str) -> str:
|
||||
"""Fetch ICS file from URL (authentication via URL token if needed)."""
|
||||
response = self.session.get(url, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
def parse_ics(self, ics_content: str) -> Calendar:
|
||||
"""Parse ICS content into calendar object."""
|
||||
return Calendar.from_ical(ics_content)
|
||||
|
||||
def extract_room_events(self, calendar: Calendar, room_url: str) -> List[dict]:
|
||||
"""Extract events that match the room URL."""
|
||||
events = []
|
||||
|
||||
for component in calendar.walk():
|
||||
if component.name == "VEVENT":
|
||||
# Check if event matches this room
|
||||
if self._event_matches_room(component, room_url):
|
||||
events.append(self._parse_event(component))
|
||||
|
||||
return events
|
||||
|
||||
def _event_matches_room(self, event, room_url: str) -> bool:
|
||||
"""Check if event location or description contains room URL."""
|
||||
location = str(event.get('LOCATION', ''))
|
||||
description = str(event.get('DESCRIPTION', ''))
|
||||
|
||||
# Support various URL formats
|
||||
patterns = [
|
||||
room_url,
|
||||
room_url.replace('https://', ''),
|
||||
room_url.split('/')[-1], # Just room name
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
if pattern in location or pattern in description:
|
||||
return True
|
||||
|
||||
return False
|
||||
```
|
||||
|
||||
### 2. Database Schema
|
||||
|
||||
```sql
|
||||
-- Modify room table
|
||||
ALTER TABLE room ADD COLUMN ics_url TEXT; -- encrypted to protect embedded tokens
|
||||
ALTER TABLE room ADD COLUMN ics_fetch_interval INTEGER DEFAULT 300; -- seconds
|
||||
ALTER TABLE room ADD COLUMN ics_enabled BOOLEAN DEFAULT FALSE;
|
||||
ALTER TABLE room ADD COLUMN ics_last_sync TIMESTAMP;
|
||||
ALTER TABLE room ADD COLUMN ics_last_etag TEXT; -- for caching
|
||||
|
||||
-- Calendar events table
|
||||
CREATE TABLE calendar_event (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
room_id UUID REFERENCES room(id) ON DELETE CASCADE,
|
||||
external_id TEXT NOT NULL, -- ICS UID
|
||||
title TEXT,
|
||||
description TEXT,
|
||||
start_time TIMESTAMP NOT NULL,
|
||||
end_time TIMESTAMP NOT NULL,
|
||||
attendees JSONB,
|
||||
location TEXT,
|
||||
ics_raw_data TEXT, -- Store raw VEVENT for reference
|
||||
last_synced TIMESTAMP DEFAULT NOW(),
|
||||
is_deleted BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW(),
|
||||
UNIQUE(room_id, external_id)
|
||||
);
|
||||
|
||||
-- Index for efficient queries
|
||||
CREATE INDEX idx_calendar_event_room_start ON calendar_event(room_id, start_time);
|
||||
CREATE INDEX idx_calendar_event_deleted ON calendar_event(is_deleted) WHERE NOT is_deleted;
|
||||
```
|
||||
|
||||
### 3. Background Tasks
|
||||
|
||||
```python
|
||||
# reflector/worker/tasks/ics_sync.py
|
||||
|
||||
from celery import shared_task
|
||||
from datetime import datetime, timedelta
|
||||
import hashlib
|
||||
|
||||
@shared_task
|
||||
def sync_ics_calendars():
|
||||
"""Sync all enabled ICS calendars based on their fetch intervals."""
|
||||
rooms = Room.query.filter_by(ics_enabled=True).all()
|
||||
|
||||
for room in rooms:
|
||||
# Check if it's time to sync based on fetch interval
|
||||
if should_sync(room):
|
||||
sync_room_calendar.delay(room.id)
|
||||
|
||||
@shared_task
|
||||
def sync_room_calendar(room_id: str):
|
||||
"""Sync calendar for a specific room."""
|
||||
room = Room.query.get(room_id)
|
||||
if not room or not room.ics_enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
# Fetch ICS file (decrypt URL first)
|
||||
service = ICSFetchService()
|
||||
decrypted_url = decrypt_ics_url(room.ics_url)
|
||||
ics_content = service.fetch_ics(decrypted_url)
|
||||
|
||||
# Check if content changed (using ETag or hash)
|
||||
content_hash = hashlib.md5(ics_content.encode()).hexdigest()
|
||||
if room.ics_last_etag == content_hash:
|
||||
logger.info(f"No changes in ICS for room {room_id}")
|
||||
return
|
||||
|
||||
# Parse and extract events
|
||||
calendar = service.parse_ics(ics_content)
|
||||
events = service.extract_room_events(calendar, room.url)
|
||||
|
||||
# Update database
|
||||
sync_events_to_database(room_id, events)
|
||||
|
||||
# Update sync metadata
|
||||
room.ics_last_sync = datetime.utcnow()
|
||||
room.ics_last_etag = content_hash
|
||||
db.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync ICS for room {room_id}: {e}")
|
||||
|
||||
def should_sync(room) -> bool:
|
||||
"""Check if room calendar should be synced."""
|
||||
if not room.ics_last_sync:
|
||||
return True
|
||||
|
||||
time_since_sync = datetime.utcnow() - room.ics_last_sync
|
||||
return time_since_sync.total_seconds() >= room.ics_fetch_interval
|
||||
```
|
||||
|
||||
### 4. Celery Beat Schedule
|
||||
|
||||
```python
|
||||
# reflector/worker/celeryconfig.py
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
beat_schedule = {
|
||||
'sync-ics-calendars': {
|
||||
'task': 'reflector.worker.tasks.ics_sync.sync_ics_calendars',
|
||||
'schedule': 60.0, # Check every minute which calendars need syncing
|
||||
},
|
||||
'pre-create-meetings': {
|
||||
'task': 'reflector.worker.tasks.ics_sync.pre_create_calendar_meetings',
|
||||
'schedule': 60.0, # Check every minute for upcoming meetings
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Room ICS Configuration
|
||||
|
||||
```python
|
||||
# PATCH /v1/rooms/{room_id}
|
||||
{
|
||||
"ics_url": "https://calendar.google.com/calendar/ical/.../private-token/basic.ics",
|
||||
"ics_fetch_interval": 300, # seconds
|
||||
"ics_enabled": true
|
||||
# URL will be encrypted in database to protect embedded tokens
|
||||
}
|
||||
```
|
||||
|
||||
### Manual Sync Trigger
|
||||
|
||||
```python
|
||||
# POST /v1/rooms/{room_name}/ics/sync
|
||||
# Response:
|
||||
{
|
||||
"status": "syncing",
|
||||
"last_sync": "2024-01-15T10:30:00Z",
|
||||
"events_found": 5
|
||||
}
|
||||
```
|
||||
|
||||
### ICS Status
|
||||
|
||||
```python
|
||||
# GET /v1/rooms/{room_name}/ics/status
|
||||
# Response:
|
||||
{
|
||||
"enabled": true,
|
||||
"last_sync": "2024-01-15T10:30:00Z",
|
||||
"next_sync": "2024-01-15T10:35:00Z",
|
||||
"fetch_interval": 300,
|
||||
"events_count": 12,
|
||||
"upcoming_events": 3
|
||||
}
|
||||
```
|
||||
|
||||
## ICS Parsing Details
|
||||
|
||||
### Event Field Mapping
|
||||
|
||||
| ICS Field | Database Field | Notes |
|
||||
|-----------|---------------|-------|
|
||||
| UID | external_id | Unique identifier |
|
||||
| SUMMARY | title | Event title |
|
||||
| DESCRIPTION | description | Full description |
|
||||
| DTSTART | start_time | Convert to UTC |
|
||||
| DTEND | end_time | Convert to UTC |
|
||||
| LOCATION | location | Check for room URL |
|
||||
| ATTENDEE | attendees | Parse into JSON |
|
||||
| ORGANIZER | attendees | Add as organizer |
|
||||
| STATUS | (internal) | Filter cancelled events |
|
||||
|
||||
### Handling Recurring Events
|
||||
|
||||
```python
|
||||
def expand_recurring_events(event, start_date, end_date):
|
||||
"""Expand recurring events into individual occurrences."""
|
||||
from dateutil.rrule import rrulestr
|
||||
|
||||
if 'RRULE' not in event:
|
||||
return [event]
|
||||
|
||||
# Parse recurrence rule
|
||||
rrule_str = event['RRULE'].to_ical().decode()
|
||||
dtstart = event['DTSTART'].dt
|
||||
|
||||
# Generate occurrences
|
||||
rrule = rrulestr(rrule_str, dtstart=dtstart)
|
||||
occurrences = []
|
||||
|
||||
for dt in rrule.between(start_date, end_date):
|
||||
# Clone event with new date
|
||||
occurrence = event.copy()
|
||||
occurrence['DTSTART'].dt = dt
|
||||
if 'DTEND' in event:
|
||||
duration = event['DTEND'].dt - event['DTSTART'].dt
|
||||
occurrence['DTEND'].dt = dt + duration
|
||||
|
||||
# Unique ID for each occurrence
|
||||
occurrence['UID'] = f"{event['UID']}_{dt.isoformat()}"
|
||||
occurrences.append(occurrence)
|
||||
|
||||
return occurrences
|
||||
```
|
||||
|
||||
### Timezone Handling
|
||||
|
||||
```python
|
||||
def normalize_datetime(dt):
|
||||
"""Convert various datetime formats to UTC."""
|
||||
import pytz
|
||||
from datetime import datetime
|
||||
|
||||
if hasattr(dt, 'dt'): # icalendar property
|
||||
dt = dt.dt
|
||||
|
||||
if isinstance(dt, datetime):
|
||||
if dt.tzinfo is None:
|
||||
# Assume local timezone if naive
|
||||
dt = pytz.timezone('UTC').localize(dt)
|
||||
else:
|
||||
# Convert to UTC
|
||||
dt = dt.astimezone(pytz.UTC)
|
||||
|
||||
return dt
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### 1. URL Validation
|
||||
|
||||
```python
|
||||
def validate_ics_url(url: str) -> bool:
|
||||
"""Validate ICS URL for security."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Must be HTTPS in production
|
||||
if not settings.DEBUG and parsed.scheme != 'https':
|
||||
return False
|
||||
|
||||
# Prevent local file access
|
||||
if parsed.scheme in ('file', 'ftp'):
|
||||
return False
|
||||
|
||||
# Prevent internal network access
|
||||
if is_internal_ip(parsed.hostname):
|
||||
return False
|
||||
|
||||
return True
|
||||
```
|
||||
|
||||
### 2. Rate Limiting
|
||||
|
||||
```python
|
||||
# Implement per-room rate limiting
|
||||
RATE_LIMITS = {
|
||||
'min_fetch_interval': 60, # Minimum 1 minute between fetches
|
||||
'max_requests_per_hour': 60, # Max 60 requests per hour per room
|
||||
'max_file_size': 10 * 1024 * 1024, # Max 10MB ICS file
|
||||
}
|
||||
```
|
||||
|
||||
### 3. ICS URL Encryption
|
||||
|
||||
```python
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
class URLEncryption:
|
||||
def __init__(self):
|
||||
self.cipher = Fernet(settings.ENCRYPTION_KEY)
|
||||
|
||||
def encrypt_url(self, url: str) -> str:
|
||||
"""Encrypt ICS URL to protect embedded tokens."""
|
||||
return self.cipher.encrypt(url.encode()).decode()
|
||||
|
||||
def decrypt_url(self, encrypted: str) -> str:
|
||||
"""Decrypt ICS URL for fetching."""
|
||||
return self.cipher.decrypt(encrypted.encode()).decode()
|
||||
|
||||
def mask_url(self, url: str) -> str:
|
||||
"""Mask sensitive parts of URL for display."""
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
parsed = urlparse(url)
|
||||
# Keep scheme, host, and path structure but mask tokens
|
||||
if '/private-' in parsed.path:
|
||||
# Google Calendar format
|
||||
parts = parsed.path.split('/private-')
|
||||
masked_path = parts[0] + '/private-***' + parts[1].split('/')[-1]
|
||||
elif 'token=' in url:
|
||||
# Query parameter token
|
||||
masked_path = parsed.path
|
||||
parsed = parsed._replace(query='token=***')
|
||||
else:
|
||||
# Generic masking of path segments that look like tokens
|
||||
import re
|
||||
masked_path = re.sub(r'/[a-zA-Z0-9]{20,}/', '/***/', parsed.path)
|
||||
|
||||
return urlunparse(parsed._replace(path=masked_path))
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### 1. Unit Tests
|
||||
|
||||
```python
|
||||
# tests/test_ics_sync.py
|
||||
|
||||
def test_ics_parsing():
|
||||
"""Test ICS file parsing."""
|
||||
ics_content = """BEGIN:VCALENDAR
|
||||
VERSION:2.0
|
||||
BEGIN:VEVENT
|
||||
UID:test-123
|
||||
SUMMARY:Team Meeting
|
||||
LOCATION:https://reflector.monadical.com/engineering
|
||||
DTSTART:20240115T100000Z
|
||||
DTEND:20240115T110000Z
|
||||
END:VEVENT
|
||||
END:VCALENDAR"""
|
||||
|
||||
service = ICSFetchService()
|
||||
calendar = service.parse_ics(ics_content)
|
||||
events = service.extract_room_events(
|
||||
calendar,
|
||||
"https://reflector.monadical.com/engineering"
|
||||
)
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0]['title'] == 'Team Meeting'
|
||||
```
|
||||
|
||||
### 2. Integration Tests
|
||||
|
||||
```python
|
||||
def test_full_sync_flow():
|
||||
"""Test complete sync workflow."""
|
||||
# Create room with ICS URL (encrypt URL to protect tokens)
|
||||
encryption = URLEncryption()
|
||||
room = Room(
|
||||
name="test-room",
|
||||
ics_url=encryption.encrypt_url("https://example.com/calendar.ics?token=secret"),
|
||||
ics_enabled=True
|
||||
)
|
||||
|
||||
# Mock ICS fetch
|
||||
with patch('requests.get') as mock_get:
|
||||
mock_get.return_value.text = sample_ics_content
|
||||
|
||||
# Run sync
|
||||
sync_room_calendar(room.id)
|
||||
|
||||
# Verify events created
|
||||
events = CalendarEvent.query.filter_by(room_id=room.id).all()
|
||||
assert len(events) > 0
|
||||
```
|
||||
|
||||
## Common ICS Provider Configurations
|
||||
|
||||
### Google Calendar
|
||||
- URL Format: `https://calendar.google.com/calendar/ical/{calendar_id}/private-{token}/basic.ics`
|
||||
- Authentication via token embedded in URL
|
||||
- Updates every 3-8 hours by default
|
||||
|
||||
### Outlook/Office 365
|
||||
- URL Format: `https://outlook.office365.com/owa/calendar/{id}/calendar.ics`
|
||||
- May include token in URL path or query parameters
|
||||
- Real-time updates
|
||||
|
||||
### Apple iCloud
|
||||
- URL Format: `webcal://p{XX}-caldav.icloud.com/published/2/{token}`
|
||||
- Convert webcal:// to https://
|
||||
- Token embedded in URL path
|
||||
- Public calendars only
|
||||
|
||||
### Nextcloud/ownCloud
|
||||
- URL Format: `https://cloud.example.com/remote.php/dav/public-calendars/{token}`
|
||||
- Token embedded in URL path
|
||||
- Configurable update frequency
|
||||
|
||||
## Migration from CalDAV
|
||||
|
||||
If migrating from an existing CalDAV implementation:
|
||||
|
||||
1. **Database Migration**: Rename fields from `caldav_*` to `ics_*`
|
||||
2. **URL Conversion**: Most CalDAV servers provide ICS export endpoints
|
||||
3. **Authentication**: Convert from username/password to URL-embedded tokens
|
||||
4. **Remove Dependencies**: Uninstall caldav library, add icalendar
|
||||
5. **Update Background Tasks**: Replace CalDAV sync with ICS fetch
|
||||
|
||||
## Performance Optimizations
|
||||
|
||||
1. **Caching**: Use ETag/Last-Modified headers to avoid refetching unchanged calendars
|
||||
2. **Incremental Sync**: Store last sync timestamp, only process new/modified events
|
||||
3. **Batch Processing**: Process multiple room calendars in parallel
|
||||
4. **Connection Pooling**: Reuse HTTP connections for multiple requests
|
||||
5. **Compression**: Support gzip encoding for large ICS files
|
||||
|
||||
## Monitoring and Debugging
|
||||
|
||||
### Metrics to Track
|
||||
- Sync success/failure rate per room
|
||||
- Average sync duration
|
||||
- ICS file sizes
|
||||
- Number of events processed
|
||||
- Failed event matches
|
||||
|
||||
### Debug Logging
|
||||
```python
|
||||
logger.debug(f"Fetching ICS from {room.ics_url}")
|
||||
logger.debug(f"ICS content size: {len(ics_content)} bytes")
|
||||
logger.debug(f"Found {len(events)} matching events")
|
||||
logger.debug(f"Event UIDs: {[e['external_id'] for e in events]}")
|
||||
```
|
||||
|
||||
### Common Issues
|
||||
1. **SSL Certificate Errors**: Add certificate validation options
|
||||
2. **Timeout Issues**: Increase timeout for large calendars
|
||||
3. **Encoding Problems**: Handle various character encodings
|
||||
4. **Timezone Mismatches**: Always convert to UTC
|
||||
5. **Memory Issues**: Stream large ICS files instead of loading entirely
|
||||
337
PLAN.md
Normal file
337
PLAN.md
Normal file
@@ -0,0 +1,337 @@
|
||||
# ICS Calendar Integration Plan
|
||||
|
||||
## Core Concept
|
||||
ICS calendar URLs are attached to rooms (not users) to enable automatic meeting tracking and management through periodic fetching of calendar data.
|
||||
|
||||
## Database Schema Updates
|
||||
|
||||
### 1. Add ICS configuration to rooms
|
||||
- Add `ics_url` field to room table (URL to .ics file, may include auth token)
|
||||
- Add `ics_fetch_interval` field to room table (default: 5 minutes, configurable)
|
||||
- Add `ics_enabled` boolean field to room table
|
||||
- Add `ics_last_sync` timestamp field to room table
|
||||
|
||||
### 2. Create calendar_events table
|
||||
- `id` - UUID primary key
|
||||
- `room_id` - Foreign key to room
|
||||
- `external_id` - ICS event UID
|
||||
- `title` - Event title
|
||||
- `description` - Event description
|
||||
- `start_time` - Event start timestamp
|
||||
- `end_time` - Event end timestamp
|
||||
- `attendees` - JSON field with attendee list and status
|
||||
- `location` - Meeting location (should contain room name)
|
||||
- `last_synced` - Last sync timestamp
|
||||
- `is_deleted` - Boolean flag for soft delete (preserve past events)
|
||||
- `ics_raw_data` - TEXT field to store raw VEVENT data for reference
|
||||
|
||||
### 3. Update meeting table
|
||||
- Add `calendar_event_id` - Foreign key to calendar_events
|
||||
- Add `calendar_metadata` - JSON field for additional calendar data
|
||||
- Remove unique constraint on room_id + active status (allow multiple active meetings per room)
|
||||
|
||||
## Backend Implementation
|
||||
|
||||
### 1. ICS Sync Service
|
||||
- Create background task that runs based on room's `ics_fetch_interval` (default: 5 minutes)
|
||||
- For each room with ICS enabled, fetch the .ics file via HTTP/HTTPS
|
||||
- Parse ICS file using icalendar library
|
||||
- Extract VEVENT components and filter events looking for room URL (e.g., "https://reflector.monadical.com/max")
|
||||
- Store matching events in calendar_events table
|
||||
- Mark events as "upcoming" if start_time is within next 30 minutes
|
||||
- Pre-create Whereby meetings 1 minute before start (ensures no delay when users join)
|
||||
- Soft-delete future events that were removed from calendar (set is_deleted=true)
|
||||
- Never delete past events (preserve for historical record)
|
||||
- Support authenticated ICS feeds via tokens embedded in URL
|
||||
|
||||
### 2. Meeting Management Updates
|
||||
- Allow multiple active meetings per room
|
||||
- Pre-create meeting record 1 minute before calendar event starts (ensures meeting is ready)
|
||||
- Link meeting to calendar_event for metadata
|
||||
- Keep meeting active for 15 minutes after last participant leaves (grace period)
|
||||
- Don't auto-close if new participant joins within grace period
|
||||
|
||||
### 3. API Endpoints
|
||||
- `GET /v1/rooms/{room_name}/meetings` - List all active and upcoming meetings for a room
|
||||
- Returns filtered data based on user role (owner vs participant)
|
||||
- `GET /v1/rooms/{room_name}/meetings/upcoming` - List upcoming meetings (next 30 min)
|
||||
- Returns filtered data based on user role
|
||||
- `POST /v1/rooms/{room_name}/meetings/{meeting_id}/join` - Join specific meeting
|
||||
- `PATCH /v1/rooms/{room_id}` - Update room settings (including ICS configuration)
|
||||
- ICS fields only visible/editable by room owner
|
||||
- `POST /v1/rooms/{room_name}/ics/sync` - Trigger manual ICS sync
|
||||
- Only accessible by room owner
|
||||
- `GET /v1/rooms/{room_name}/ics/status` - Get ICS sync status and last fetch time
|
||||
- Only accessible by room owner
|
||||
|
||||
## Frontend Implementation
|
||||
|
||||
### 1. Room Settings Page
|
||||
- Add ICS configuration section
|
||||
- Field for ICS URL (e.g., Google Calendar private URL, Outlook ICS export)
|
||||
- Field for fetch interval (dropdown: 1 min, 5 min, 10 min, 30 min, 1 hour)
|
||||
- Test connection button (validates ICS file can be fetched and parsed)
|
||||
- Manual sync button
|
||||
- Show last sync time and next scheduled sync
|
||||
|
||||
### 2. Meeting Selection Page (New)
|
||||
- Show when accessing `/room/{room_name}`
|
||||
- **Host view** (room owner):
|
||||
- Full calendar event details
|
||||
- Meeting title and description
|
||||
- Complete attendee list with RSVP status
|
||||
- Number of current participants
|
||||
- Duration (how long it's been running)
|
||||
- **Participant view** (non-owners):
|
||||
- Meeting title only
|
||||
- Date and time
|
||||
- Number of current participants
|
||||
- Duration (how long it's been running)
|
||||
- No attendee list or description (privacy)
|
||||
- Display upcoming meetings (visible 30min before):
|
||||
- Show countdown to start
|
||||
- Can click to join early → redirected to waiting page
|
||||
- Waiting page shows countdown until meeting starts
|
||||
- Meeting pre-created by background task (ready when users arrive)
|
||||
- Option to create unscheduled meeting (uses existing flow)
|
||||
|
||||
### 3. Meeting Room Updates
|
||||
- Show calendar metadata in meeting info
|
||||
- Display invited attendees vs actual participants
|
||||
- Show meeting title from calendar event
|
||||
|
||||
## Meeting Lifecycle
|
||||
|
||||
### 1. Meeting Creation
|
||||
- Automatic: Pre-created 1 minute before calendar event starts (ensures Whereby room is ready)
|
||||
- Manual: User creates unscheduled meeting (existing `/rooms/{room_name}/meeting` endpoint)
|
||||
- Background task handles pre-creation to avoid delays when users join
|
||||
|
||||
### 2. Meeting Join Rules
|
||||
- Can join active meetings immediately
|
||||
- Can see upcoming meetings 30 minutes before start
|
||||
- Can click to join upcoming meetings early → sent to waiting page
|
||||
- Waiting page automatically transitions to meeting at scheduled time
|
||||
- Unscheduled meetings always joinable (current behavior)
|
||||
|
||||
### 3. Meeting Closure Rules
|
||||
- All meetings: 15-minute grace period after last participant leaves
|
||||
- If participant rejoins within grace period, keep meeting active
|
||||
- Calendar meetings: Force close 30 minutes after scheduled end time
|
||||
- Unscheduled meetings: Keep active for 8 hours (current behavior)
|
||||
|
||||
## ICS Parsing Logic
|
||||
|
||||
### 1. Event Matching
|
||||
- Parse ICS file using Python icalendar library
|
||||
- Iterate through VEVENT components
|
||||
- Check LOCATION field for full FQDN URL (e.g., "https://reflector.monadical.com/max")
|
||||
- Check DESCRIPTION for room URL or mention
|
||||
- Support multiple formats:
|
||||
- Full URL: "https://reflector.monadical.com/max"
|
||||
- With /room path: "https://reflector.monadical.com/room/max"
|
||||
- Partial paths: "room/max", "/max room"
|
||||
|
||||
### 2. Attendee Extraction
|
||||
- Parse ATTENDEE properties from VEVENT
|
||||
- Extract email (MAILTO), name (CN parameter), and RSVP status (PARTSTAT)
|
||||
- Store as JSON in calendar_events.attendees
|
||||
|
||||
### 3. Sync Strategy
|
||||
- Fetch complete ICS file (contains all events)
|
||||
- Filter events from (now - 1 hour) to (now + 24 hours) for processing
|
||||
- Update existing events if LAST-MODIFIED or SEQUENCE changed
|
||||
- Delete future events that no longer exist in ICS (start_time > now)
|
||||
- Keep past events for historical record (never delete if start_time < now)
|
||||
- Handle recurring events (RRULE) - expand to individual instances
|
||||
- Track deleted calendar events to clean up future meetings
|
||||
- Cache ICS file hash to detect changes and skip unnecessary processing
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### 1. ICS URL Security
|
||||
- ICS URLs may contain authentication tokens (e.g., Google Calendar private URLs)
|
||||
- Store full ICS URLs encrypted using Fernet to protect embedded tokens
|
||||
- Validate ICS URLs (must be HTTPS for production)
|
||||
- Never expose full ICS URLs in API responses (return masked version)
|
||||
- Rate limit ICS fetching to prevent abuse
|
||||
|
||||
### 2. Room Access
|
||||
- Only room owner can configure ICS URL
|
||||
- ICS URL shown as masked version to room owner (hides embedded tokens)
|
||||
- ICS settings not visible to other users
|
||||
- Meeting list visible to all room participants
|
||||
- ICS fetch logs only visible to room owner
|
||||
|
||||
### 3. Meeting Privacy
|
||||
- Full calendar details visible only to room owner
|
||||
- Participants see limited info: title, date/time only
|
||||
- Attendee list and description hidden from non-owners
|
||||
- Meeting titles visible in room listing to all
|
||||
|
||||
## Implementation Phases
|
||||
|
||||
### Phase 1: Database and ICS Setup (Week 1) ✅ COMPLETED (2025-08-18)
|
||||
1. ✅ Created database migrations for ICS fields and calendar_events table
|
||||
- Added ics_url, ics_fetch_interval, ics_enabled, ics_last_sync, ics_last_etag to room table
|
||||
- Created calendar_event table with ics_uid (instead of external_id) and proper typing
|
||||
- Added calendar_event_id and calendar_metadata (JSONB) to meeting table
|
||||
- Removed server_default from datetime fields for consistency
|
||||
2. ✅ Installed icalendar Python library for ICS parsing
|
||||
- Added icalendar>=6.0.0 to dependencies
|
||||
- No encryption needed - ICS URLs are read-only
|
||||
3. ✅ Built ICS fetch and sync service
|
||||
- Simple HTTP fetching without unnecessary validation
|
||||
- Proper TypedDict typing for event data structures
|
||||
- Supports any standard ICS format
|
||||
- Event matching on full room URL only
|
||||
4. ✅ API endpoints for ICS configuration
|
||||
- Room model updated to support ICS fields via existing PATCH endpoint
|
||||
- POST /v1/rooms/{room_name}/ics/sync - Trigger manual sync (owner only)
|
||||
- GET /v1/rooms/{room_name}/ics/status - Get sync status (owner only)
|
||||
- GET /v1/rooms/{room_name}/meetings - List meetings with privacy controls
|
||||
- GET /v1/rooms/{room_name}/meetings/upcoming - List upcoming meetings
|
||||
5. ✅ Celery background tasks for periodic sync
|
||||
- sync_room_ics - Sync individual room calendar
|
||||
- sync_all_ics_calendars - Check all rooms and queue sync based on fetch intervals
|
||||
- pre_create_upcoming_meetings - Pre-create Whereby meetings 1 minute before start
|
||||
- Tasks scheduled in beat schedule (every minute for checking, respects individual intervals)
|
||||
6. ✅ Tests written and passing
|
||||
- 6 tests for Room ICS fields
|
||||
- 7 tests for CalendarEvent model
|
||||
- 7 tests for ICS sync service
|
||||
- 11 tests for API endpoints
|
||||
- 6 tests for background tasks
|
||||
- All 31 ICS-related tests passing
|
||||
|
||||
### Phase 2: Meeting Management (Week 2) ✅ COMPLETED (2025-08-19)
|
||||
1. ✅ Updated meeting lifecycle logic with grace period support
|
||||
- 15-minute grace period after last participant leaves
|
||||
- Automatic reactivation when participants rejoin
|
||||
- Force close calendar meetings 30 minutes after scheduled end
|
||||
2. ✅ Support multiple active meetings per room
|
||||
- Removed unique constraint on active meetings
|
||||
- Added get_all_active_for_room() method
|
||||
- Added get_active_by_calendar_event() method
|
||||
3. ✅ Implemented grace period logic
|
||||
- Added last_participant_left_at and grace_period_minutes fields
|
||||
- Process meetings task handles grace period checking
|
||||
- Whereby webhooks clear grace period on participant join
|
||||
4. ✅ Link meetings to calendar events
|
||||
- Pre-created meetings properly linked via calendar_event_id
|
||||
- Calendar metadata stored with meeting
|
||||
- API endpoints for listing and joining specific meetings
|
||||
|
||||
### Phase 3: Frontend Meeting Selection (Week 3)
|
||||
1. Build meeting selection page
|
||||
2. Show active and upcoming meetings
|
||||
3. Implement waiting page for early joiners
|
||||
4. Add automatic transition from waiting to meeting
|
||||
5. Support unscheduled meeting creation
|
||||
|
||||
### Phase 4: Calendar Integration UI (Week 4)
|
||||
1. Add ICS settings to room configuration
|
||||
2. Display calendar metadata in meetings
|
||||
3. Show attendee information
|
||||
4. Add sync status indicators
|
||||
5. Show fetch interval and next sync time
|
||||
|
||||
## Success Metrics
|
||||
- Zero merged meetings from consecutive calendar events
|
||||
- Successful ICS sync from major providers (Google Calendar, Outlook, Apple Calendar, Nextcloud)
|
||||
- Meeting join accuracy: correct meeting 100% of the time
|
||||
- Grace period prevents 90% of accidental meeting closures
|
||||
- Configurable fetch intervals reduce unnecessary API calls
|
||||
|
||||
## Design Decisions
|
||||
1. **ICS attached to room, not user** - Prevents duplicate meetings from multiple calendars
|
||||
2. **Multiple active meetings per room** - Supported with meeting selection page
|
||||
3. **Grace period for rejoining** - 15 minutes after last participant leaves
|
||||
4. **Upcoming meeting visibility** - Show 30 minutes before, join only on time
|
||||
5. **Calendar data storage** - Attached to meeting record for full context
|
||||
6. **No "ad-hoc" meetings** - Use existing meeting creation flow (unscheduled meetings)
|
||||
7. **ICS configuration via room PATCH** - Reuse existing room configuration endpoint
|
||||
8. **Event deletion handling** - Soft-delete future events, preserve past meetings
|
||||
9. **Configurable fetch interval** - Balance between freshness and server load
|
||||
10. **ICS over CalDAV** - Simpler implementation, wider compatibility, no complex auth
|
||||
|
||||
## Phase 2 Implementation Files
|
||||
|
||||
### Database Migrations
|
||||
- `/server/migrations/versions/6025e9b2bef2_remove_one_active_meeting_per_room_.py` - Remove unique constraint
|
||||
- `/server/migrations/versions/d4a1c446458c_add_grace_period_fields_to_meeting.py` - Add grace period fields
|
||||
|
||||
### Updated Models
|
||||
- `/server/reflector/db/meetings.py` - Added grace period fields and new query methods
|
||||
|
||||
### Updated Services
|
||||
- `/server/reflector/worker/process.py` - Enhanced with grace period logic and multiple meeting support
|
||||
|
||||
### Updated API
|
||||
- `/server/reflector/views/rooms.py` - Added endpoints for listing active meetings and joining specific meetings
|
||||
- `/server/reflector/views/whereby.py` - Clear grace period on participant join
|
||||
|
||||
### Tests
|
||||
- `/server/tests/test_multiple_active_meetings.py` - Comprehensive tests for Phase 2 features (5 tests)
|
||||
|
||||
## Phase 1 Implementation Files Created
|
||||
|
||||
### Database Models
|
||||
- `/server/reflector/db/rooms.py` - Updated with ICS fields (url, fetch_interval, enabled, last_sync, etag)
|
||||
- `/server/reflector/db/calendar_events.py` - New CalendarEvent model with ics_uid and proper typing
|
||||
- `/server/reflector/db/meetings.py` - Updated with calendar_event_id and calendar_metadata (JSONB)
|
||||
|
||||
### Services
|
||||
- `/server/reflector/services/ics_sync.py` - ICS fetching and parsing with TypedDict for proper typing
|
||||
|
||||
### API Endpoints
|
||||
- `/server/reflector/views/rooms.py` - Added ICS management endpoints with privacy controls
|
||||
|
||||
### Background Tasks
|
||||
- `/server/reflector/worker/ics_sync.py` - Celery tasks for automatic periodic sync
|
||||
- `/server/reflector/worker/app.py` - Updated beat schedule for ICS tasks
|
||||
|
||||
### Tests
|
||||
- `/server/tests/test_room_ics.py` - Room model ICS fields tests (6 tests)
|
||||
- `/server/tests/test_calendar_event.py` - CalendarEvent model tests (7 tests)
|
||||
- `/server/tests/test_ics_sync.py` - ICS sync service tests (7 tests)
|
||||
- `/server/tests/test_room_ics_api.py` - API endpoint tests (11 tests)
|
||||
- `/server/tests/test_ics_background_tasks.py` - Background task tests (6 tests)
|
||||
|
||||
### Key Design Decisions
|
||||
- No encryption needed - ICS URLs are read-only access
|
||||
- Using ics_uid instead of external_id for clarity
|
||||
- Proper TypedDict typing for event data structures
|
||||
- Removed unnecessary URL validation and webcal handling
|
||||
- calendar_metadata in meetings stores flexible calendar data (organizer, recurrence, etc)
|
||||
- Background tasks query all rooms directly to avoid filtering issues
|
||||
- Sync intervals respected per-room configuration
|
||||
|
||||
## Implementation Approach
|
||||
|
||||
### ICS Fetching vs CalDAV
|
||||
- **ICS Benefits**:
|
||||
- Simpler implementation (HTTP GET vs CalDAV protocol)
|
||||
- Wider compatibility (all calendar apps can export ICS)
|
||||
- No authentication complexity (simple URL with optional token)
|
||||
- Easier debugging (ICS is plain text)
|
||||
- Lower server requirements (no CalDAV library dependencies)
|
||||
|
||||
### Supported Calendar Providers
|
||||
1. **Google Calendar**: Private ICS URL from calendar settings
|
||||
2. **Outlook/Office 365**: ICS export URL from calendar sharing
|
||||
3. **Apple Calendar**: Published calendar ICS URL
|
||||
4. **Nextcloud**: Public/private calendar ICS export
|
||||
5. **Any CalDAV server**: Via ICS export endpoint
|
||||
|
||||
### ICS URL Examples
|
||||
- Google: `https://calendar.google.com/calendar/ical/{calendar_id}/private-{token}/basic.ics`
|
||||
- Outlook: `https://outlook.live.com/owa/calendar/{id}/calendar.ics`
|
||||
- Custom: `https://example.com/calendars/room-schedule.ics`
|
||||
|
||||
### Fetch Interval Configuration
|
||||
- 1 minute: For critical/high-activity rooms
|
||||
- 5 minutes (default): Balance of freshness and efficiency
|
||||
- 10 minutes: Standard meeting rooms
|
||||
- 30 minutes: Low-activity rooms
|
||||
- 1 hour: Rarely-used rooms or stable schedules
|
||||
81
README.md
81
README.md
@@ -1,60 +1,43 @@
|
||||
<div align="center">
|
||||
<img width="100" alt="image" src="https://github.com/user-attachments/assets/66fb367b-2c89-4516-9912-f47ac59c6a7f"/>
|
||||
|
||||
# Reflector
|
||||
|
||||
Reflector is an AI-powered audio transcription and meeting analysis platform that provides real-time transcription, speaker diarization, translation and summarization for audio content and live meetings. It works 100% with local models (whisper/parakeet, pyannote, seamless-m4t, and your local llm like phi-4).
|
||||
Reflector Audio Management and Analysis is a cutting-edge web application under development by Monadical. It utilizes AI to record meetings, providing a permanent record with transcripts, translations, and automated summaries.
|
||||
|
||||
[](https://github.com/monadical-sas/reflector/actions/workflows/test_server.yml)
|
||||
[](https://github.com/monadical-sas/reflector/actions/workflows/pytests.yml)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Screenshots
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/21f5597c-2930-4899-a154-f7bd61a59e97">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/21f5597c-2930-4899-a154-f7bd61a59e97" />
|
||||
<a href="https://github.com/user-attachments/assets/3a976930-56c1-47ef-8c76-55d3864309e3">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/3a976930-56c1-47ef-8c76-55d3864309e3" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/f6b9399a-5e51-4bae-b807-59128d0a940c">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/f6b9399a-5e51-4bae-b807-59128d0a940c" />
|
||||
<a href="https://github.com/user-attachments/assets/bfe3bde3-08af-4426-a9a1-11ad5cd63b33">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/bfe3bde3-08af-4426-a9a1-11ad5cd63b33" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/a42ce460-c1fd-4489-a995-270516193897">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/a42ce460-c1fd-4489-a995-270516193897" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/21929f6d-c309-42fe-9c11-f1299e50fbd4">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/21929f6d-c309-42fe-9c11-f1299e50fbd4" />
|
||||
<a href="https://github.com/user-attachments/assets/7b60c9d0-efe4-474f-a27b-ea13bd0fabdc">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/7b60c9d0-efe4-474f-a27b-ea13bd0fabdc" />
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## What is Reflector?
|
||||
|
||||
Reflector is a web application that utilizes local models to process audio content, providing:
|
||||
|
||||
- **Real-time Transcription**: Convert speech to text using [Whisper](https://github.com/openai/whisper) (multi-language) or [Parakeet](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2) (English) models
|
||||
- **Speaker Diarization**: Identify and label different speakers using [Pyannote](https://github.com/pyannote/pyannote-audio) 3.1
|
||||
- **Live Translation**: Translate audio content in real-time to many languages with [Facebook Seamless-M4T](https://github.com/facebookresearch/seamless_communication)
|
||||
- **Topic Detection & Summarization**: Extract key topics and generate concise summaries using LLMs
|
||||
- **Meeting Recording**: Create permanent records of meetings with searchable transcripts
|
||||
|
||||
Currently we provide [modal.com](https://modal.com/) gpu template to deploy.
|
||||
|
||||
## Background
|
||||
|
||||
The project architecture consists of three primary components:
|
||||
|
||||
- **Back-End**: Python server that offers an API and data persistence, found in `server/`.
|
||||
- **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
|
||||
- **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
|
||||
- **Back-End**: Python server that offers an API and data persistence, found in `server/`.
|
||||
- **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations. Most reliable option is Modal deployment
|
||||
|
||||
It also uses authentik for authentication if activated.
|
||||
It also uses authentik for authentication if activated, and Vercel for deployment and configuration of the front-end.
|
||||
|
||||
## Contribution Guidelines
|
||||
|
||||
@@ -89,8 +72,6 @@ Note: We currently do not have instructions for Windows users.
|
||||
|
||||
## Installation
|
||||
|
||||
*Note: we're working toward better installation, theses instructions are not accurate for now*
|
||||
|
||||
### Frontend
|
||||
|
||||
Start with `cd www`.
|
||||
@@ -99,10 +80,11 @@ Start with `cd www`.
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
cp .env.example .env
|
||||
cp .env_template .env
|
||||
cp config-template.ts config.ts
|
||||
```
|
||||
|
||||
Then, fill in the environment variables in `.env` as needed. If you are unsure on how to proceed, ask in Zulip.
|
||||
Then, fill in the environment variables in `.env` and the configuration in `config.ts` as needed. If you are unsure on how to proceed, ask in Zulip.
|
||||
|
||||
**Run in development mode**
|
||||
|
||||
@@ -167,34 +149,3 @@ You can manually process an audio file by calling the process tool:
|
||||
```bash
|
||||
uv run python -m reflector.tools.process path/to/audio.wav
|
||||
```
|
||||
|
||||
|
||||
## Feature Flags
|
||||
|
||||
Reflector uses environment variable-based feature flags to control application functionality. These flags allow you to enable or disable features without code changes.
|
||||
|
||||
### Available Feature Flags
|
||||
|
||||
| Feature Flag | Environment Variable |
|
||||
|-------------|---------------------|
|
||||
| `requireLogin` | `NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN` |
|
||||
| `privacy` | `NEXT_PUBLIC_FEATURE_PRIVACY` |
|
||||
| `browse` | `NEXT_PUBLIC_FEATURE_BROWSE` |
|
||||
| `sendToZulip` | `NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP` |
|
||||
| `rooms` | `NEXT_PUBLIC_FEATURE_ROOMS` |
|
||||
|
||||
### Setting Feature Flags
|
||||
|
||||
Feature flags are controlled via environment variables using the pattern `NEXT_PUBLIC_FEATURE_{FEATURE_NAME}` where `{FEATURE_NAME}` is the SCREAMING_SNAKE_CASE version of the feature name.
|
||||
|
||||
**Examples:**
|
||||
```bash
|
||||
# Enable user authentication requirement
|
||||
NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN=true
|
||||
|
||||
# Disable browse functionality
|
||||
NEXT_PUBLIC_FEATURE_BROWSE=false
|
||||
|
||||
# Enable Zulip integration
|
||||
NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP=true
|
||||
```
|
||||
|
||||
@@ -6,7 +6,6 @@ services:
|
||||
- 1250:1250
|
||||
volumes:
|
||||
- ./server/:/app/
|
||||
- /app/.venv
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
@@ -17,7 +16,6 @@ services:
|
||||
context: server
|
||||
volumes:
|
||||
- ./server/:/app/
|
||||
- /app/.venv
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
@@ -28,7 +26,6 @@ services:
|
||||
context: server
|
||||
volumes:
|
||||
- ./server/:/app/
|
||||
- /app/.venv
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
|
||||
33
gpu/modal_deployments/.gitignore
vendored
33
gpu/modal_deployments/.gitignore
vendored
@@ -1,33 +0,0 @@
|
||||
# OS / Editor
|
||||
.DS_Store
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# Env and secrets
|
||||
.env
|
||||
.env.*
|
||||
*.env
|
||||
*.secret
|
||||
|
||||
# Build / dist
|
||||
build/
|
||||
dist/
|
||||
.eggs/
|
||||
*.egg-info/
|
||||
|
||||
# Coverage / test
|
||||
.pytest_cache/
|
||||
.coverage*
|
||||
htmlcov/
|
||||
|
||||
# Modal local state (if any)
|
||||
modal_mounts/
|
||||
.modal_cache/
|
||||
@@ -1,171 +0,0 @@
|
||||
# Reflector GPU implementation - Transcription and LLM
|
||||
|
||||
This repository hold an API for the GPU implementation of the Reflector API service,
|
||||
and use [Modal.com](https://modal.com)
|
||||
|
||||
- `reflector_diarizer.py` - Diarization API
|
||||
- `reflector_transcriber.py` - Transcription API (Whisper)
|
||||
- `reflector_transcriber_parakeet.py` - Transcription API (NVIDIA Parakeet)
|
||||
- `reflector_translator.py` - Translation API
|
||||
|
||||
## Modal.com deployment
|
||||
|
||||
Create a modal secret, and name it `reflector-gpu`.
|
||||
It should contain an `REFLECTOR_APIKEY` environment variable with a value.
|
||||
|
||||
The deployment is done using [Modal.com](https://modal.com) service.
|
||||
|
||||
```
|
||||
$ modal deploy reflector_transcriber.py
|
||||
...
|
||||
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run
|
||||
|
||||
$ modal deploy reflector_transcriber_parakeet.py
|
||||
...
|
||||
└── 🔨 Created web => https://xxxx--reflector-transcriber-parakeet-web.modal.run
|
||||
|
||||
$ modal deploy reflector_llm.py
|
||||
...
|
||||
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
|
||||
```
|
||||
|
||||
Then in your reflector api configuration `.env`, you can set these keys:
|
||||
|
||||
```
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
|
||||
DIARIZATION_BACKEND=modal
|
||||
DIARIZATION_URL=https://xxxx--reflector-diarizer-web.modal.run
|
||||
DIARIZATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
|
||||
TRANSLATION_BACKEND=modal
|
||||
TRANSLATION_URL=https://xxxx--reflector-translator-web.modal.run
|
||||
TRANSLATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
Authentication must be passed with the `Authorization` header, using the `bearer` scheme.
|
||||
|
||||
```
|
||||
Authorization: bearer <REFLECTOR_APIKEY>
|
||||
```
|
||||
|
||||
### LLM
|
||||
|
||||
`POST /llm`
|
||||
|
||||
**request**
|
||||
```
|
||||
{
|
||||
"prompt": "xxx"
|
||||
}
|
||||
```
|
||||
|
||||
**response**
|
||||
```
|
||||
{
|
||||
"text": "xxx completed"
|
||||
}
|
||||
```
|
||||
|
||||
### Transcription
|
||||
|
||||
#### Parakeet Transcriber (`reflector_transcriber_parakeet.py`)
|
||||
|
||||
NVIDIA Parakeet is a state-of-the-art ASR model optimized for real-time transcription with superior word-level timestamps.
|
||||
|
||||
**GPU Configuration:**
|
||||
- **A10G GPU** - Used for `/v1/audio/transcriptions` endpoint (small files, live transcription)
|
||||
- Higher concurrency (max_inputs=10)
|
||||
- Optimized for multiple small audio files
|
||||
- Supports batch processing for efficiency
|
||||
|
||||
- **L40S GPU** - Used for `/v1/audio/transcriptions-from-url` endpoint (large files)
|
||||
- Lower concurrency but more powerful processing
|
||||
- Optimized for single large audio files
|
||||
- VAD-based chunking for long-form audio
|
||||
|
||||
##### `/v1/audio/transcriptions` - Small file transcription
|
||||
|
||||
**request** (multipart/form-data)
|
||||
- `file` or `files[]` - audio file(s) to transcribe
|
||||
- `model` - model name (default: `nvidia/parakeet-tdt-0.6b-v2`)
|
||||
- `language` - language code (default: `en`)
|
||||
- `batch` - whether to use batch processing for multiple files (default: `true`)
|
||||
|
||||
**response**
|
||||
```json
|
||||
{
|
||||
"text": "transcribed text",
|
||||
"words": [
|
||||
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||
{"word": "world", "start": 0.5, "end": 1.0}
|
||||
],
|
||||
"filename": "audio.mp3"
|
||||
}
|
||||
```
|
||||
|
||||
For multiple files with batch=true:
|
||||
```json
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"filename": "audio1.mp3",
|
||||
"text": "transcribed text",
|
||||
"words": [...]
|
||||
},
|
||||
{
|
||||
"filename": "audio2.mp3",
|
||||
"text": "transcribed text",
|
||||
"words": [...]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
##### `/v1/audio/transcriptions-from-url` - Large file transcription
|
||||
|
||||
**request** (application/json)
|
||||
```json
|
||||
{
|
||||
"audio_file_url": "https://example.com/audio.mp3",
|
||||
"model": "nvidia/parakeet-tdt-0.6b-v2",
|
||||
"language": "en",
|
||||
"timestamp_offset": 0.0
|
||||
}
|
||||
```
|
||||
|
||||
**response**
|
||||
```json
|
||||
{
|
||||
"text": "transcribed text from large file",
|
||||
"words": [
|
||||
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||
{"word": "world", "start": 0.5, "end": 1.0}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Supported file types:** mp3, mp4, mpeg, mpga, m4a, wav, webm
|
||||
|
||||
#### Whisper Transcriber (`reflector_transcriber.py`)
|
||||
|
||||
`POST /transcribe`
|
||||
|
||||
**request** (multipart/form-data)
|
||||
|
||||
- `file` - audio file
|
||||
- `language` - language code (e.g. `en`)
|
||||
|
||||
**response**
|
||||
```
|
||||
{
|
||||
"text": "xxx",
|
||||
"words": [
|
||||
{"text": "xxx", "start": 0.0, "end": 1.0}
|
||||
]
|
||||
}
|
||||
```
|
||||
@@ -1,253 +0,0 @@
|
||||
"""
|
||||
Reflector GPU backend - diarizer
|
||||
===================================
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import Mapping, NewType
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import modal
|
||||
|
||||
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
|
||||
MODEL_DIR = "/root/diarization_models"
|
||||
UPLOADS_PATH = "/uploads"
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
|
||||
DiarizerUniqFilename = NewType("DiarizerUniqFilename", str)
|
||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||
|
||||
app = modal.App(name="reflector-diarizer")
|
||||
|
||||
# Volume for temporary file uploads
|
||||
upload_volume = modal.Volume.from_name("diarizer-uploads", create_if_missing=True)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||
parsed_url = urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return AudioFileExtension(ext)
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return AudioFileExtension("mp3")
|
||||
if "audio/wav" in content_type:
|
||||
return AudioFileExtension("wav")
|
||||
if "audio/mp4" in content_type:
|
||||
return AudioFileExtension("mp4")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported audio format for URL: {url}. "
|
||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(
|
||||
audio_file_url: str,
|
||||
) -> tuple[DiarizerUniqFilename, AudioFileExtension]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
print(f"Checking audio file at: {audio_file_url}")
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
print(f"Downloading audio file from: {audio_file_url}")
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"Download failed with status {response.status_code}: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to download audio file: {response.status_code}",
|
||||
)
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = DiarizerUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
print(f"Writing file to: {file_path} (size: {len(response.content)} bytes)")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
print(f"File saved as: {unique_filename}")
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
def migrate_cache_llm():
|
||||
"""
|
||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
||||
Migrating your old cache. This is a one-time only operation. You can
|
||||
interrupt this and resume the migration later on by calling
|
||||
`transformers.utils.move_cache()`.
|
||||
"""
|
||||
from transformers.utils.hub import move_cache
|
||||
|
||||
print("Moving LLM cache")
|
||||
move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
|
||||
print("LLM cache moved")
|
||||
|
||||
|
||||
def download_pyannote_audio():
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
Pipeline.from_pretrained(
|
||||
PYANNOTE_MODEL_NAME,
|
||||
cache_dir=MODEL_DIR,
|
||||
use_auth_token=os.environ["HF_TOKEN"],
|
||||
)
|
||||
|
||||
|
||||
diarizer_image = (
|
||||
modal.Image.debian_slim(python_version="3.10.8")
|
||||
.pip_install(
|
||||
"pyannote.audio==3.1.0",
|
||||
"requests",
|
||||
"onnx",
|
||||
"torchaudio",
|
||||
"onnxruntime-gpu",
|
||||
"torch==2.0.0",
|
||||
"transformers==4.34.0",
|
||||
"sentencepiece",
|
||||
"protobuf",
|
||||
"numpy",
|
||||
"huggingface_hub",
|
||||
"hf-transfer",
|
||||
)
|
||||
.run_function(
|
||||
download_pyannote_audio,
|
||||
secrets=[modal.Secret.from_name("hf_token")],
|
||||
)
|
||||
.run_function(migrate_cache_llm)
|
||||
.env(
|
||||
{
|
||||
"LD_LIBRARY_PATH": (
|
||||
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
|
||||
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A100",
|
||||
timeout=60 * 30,
|
||||
image=diarizer_image,
|
||||
volumes={UPLOADS_PATH: upload_volume},
|
||||
enable_memory_snapshot=True,
|
||||
experimental_options={"enable_gpu_snapshot": True},
|
||||
secrets=[
|
||||
modal.Secret.from_name("hf_token"),
|
||||
],
|
||||
)
|
||||
@modal.concurrent(max_inputs=1)
|
||||
class Diarizer:
|
||||
@modal.enter(snap=True)
|
||||
def enter(self):
|
||||
import torch
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
print(f"Using device: {self.device}")
|
||||
self.diarization_pipeline = Pipeline.from_pretrained(
|
||||
PYANNOTE_MODEL_NAME,
|
||||
cache_dir=MODEL_DIR,
|
||||
use_auth_token=os.environ["HF_TOKEN"],
|
||||
)
|
||||
self.diarization_pipeline.to(torch.device(self.device))
|
||||
|
||||
@modal.method()
|
||||
def diarize(self, filename: str, timestamp: float = 0.0):
|
||||
import torchaudio
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
print(f"Diarizing audio from: {file_path}")
|
||||
waveform, sample_rate = torchaudio.load(file_path)
|
||||
diarization = self.diarization_pipeline(
|
||||
{"waveform": waveform, "sample_rate": sample_rate}
|
||||
)
|
||||
|
||||
words = []
|
||||
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||
words.append(
|
||||
{
|
||||
"start": round(timestamp + diarization_segment.start, 3),
|
||||
"end": round(timestamp + diarization_segment.end, 3),
|
||||
"speaker": int(speaker[-2:]),
|
||||
}
|
||||
)
|
||||
print("Diarization complete")
|
||||
return {"diarization": words}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Web API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.function(
|
||||
timeout=60 * 10,
|
||||
scaledown_window=60 * 3,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={UPLOADS_PATH: upload_volume},
|
||||
image=diarizer_image,
|
||||
)
|
||||
@modal.concurrent(max_inputs=40)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
diarizerstub = Diarizer()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class DiarizationResponse(BaseModel):
|
||||
result: dict
|
||||
|
||||
@app.post("/diarize", dependencies=[Depends(apikey_auth)])
|
||||
def diarize(audio_file_url: str, timestamp: float = 0.0) -> DiarizationResponse:
|
||||
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
|
||||
|
||||
try:
|
||||
func = diarizerstub.diarize.spawn(
|
||||
filename=unique_filename, timestamp=timestamp
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
print(f"Deleting file: {file_path}")
|
||||
os.remove(file_path)
|
||||
upload_volume.commit()
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up {unique_filename}: {e}")
|
||||
|
||||
return app
|
||||
@@ -1,608 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import modal
|
||||
|
||||
MODEL_NAME = "large-v2"
|
||||
MODEL_COMPUTE_TYPE: str = "float16"
|
||||
MODEL_NUM_WORKERS: int = 1
|
||||
MINUTES = 60 # seconds
|
||||
SAMPLERATE = 16000
|
||||
UPLOADS_PATH = "/uploads"
|
||||
CACHE_PATH = "/models"
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
|
||||
WhisperUniqFilename = NewType("WhisperUniqFilename", str)
|
||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||
|
||||
app = modal.App("reflector-transcriber")
|
||||
|
||||
model_cache = modal.Volume.from_name("models", create_if_missing=True)
|
||||
upload_volume = modal.Volume.from_name("whisper-uploads", create_if_missing=True)
|
||||
|
||||
|
||||
class TimeSegment(NamedTuple):
|
||||
"""Represents a time segment with start and end times."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
class AudioSegment(NamedTuple):
|
||||
"""Represents an audio segment with timing and audio data."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
audio: any
|
||||
|
||||
|
||||
class TranscriptResult(NamedTuple):
|
||||
"""Represents a transcription result with text and word timings."""
|
||||
|
||||
text: str
|
||||
words: list["WordTiming"]
|
||||
|
||||
|
||||
class WordTiming(TypedDict):
|
||||
"""Represents a word with its timing information."""
|
||||
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
def download_model():
|
||||
from faster_whisper import download_model
|
||||
|
||||
model_cache.reload()
|
||||
|
||||
download_model(MODEL_NAME, cache_dir=CACHE_PATH)
|
||||
|
||||
model_cache.commit()
|
||||
|
||||
|
||||
image = (
|
||||
modal.Image.debian_slim(python_version="3.12")
|
||||
.env(
|
||||
{
|
||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||
"LD_LIBRARY_PATH": (
|
||||
"/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/:"
|
||||
"/opt/conda/lib/python3.12/site-packages/nvidia/cublas/lib/"
|
||||
),
|
||||
}
|
||||
)
|
||||
.apt_install("ffmpeg")
|
||||
.pip_install(
|
||||
"huggingface_hub==0.27.1",
|
||||
"hf-transfer==0.1.9",
|
||||
"torch==2.5.1",
|
||||
"faster-whisper==1.1.1",
|
||||
"fastapi==0.115.12",
|
||||
"requests",
|
||||
"librosa==0.10.1",
|
||||
"numpy<2",
|
||||
"silero-vad==5.1.0",
|
||||
)
|
||||
.run_function(download_model, volumes={CACHE_PATH: model_cache})
|
||||
)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||
parsed_url = urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return AudioFileExtension(ext)
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return AudioFileExtension("mp3")
|
||||
if "audio/wav" in content_type:
|
||||
return AudioFileExtension("wav")
|
||||
if "audio/mp4" in content_type:
|
||||
return AudioFileExtension("mp4")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported audio format for URL: {url}. "
|
||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(
|
||||
audio_file_url: str,
|
||||
) -> tuple[WhisperUniqFilename, AudioFileExtension]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = WhisperUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
|
||||
"""Add 0.5s of silence if audio is shorter than the silence_padding window.
|
||||
|
||||
Whisper does not require this strictly, but aligning behavior with Parakeet
|
||||
avoids edge-case crashes on extremely short inputs and makes comparisons easier.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
audio_duration = len(audio_array) / sample_rate
|
||||
if audio_duration < VAD_CONFIG["silence_padding"]:
|
||||
silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio_array, silence])
|
||||
return audio_array
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=5 * MINUTES,
|
||||
scaledown_window=5 * MINUTES,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
)
|
||||
@modal.concurrent(max_inputs=10)
|
||||
class TranscriberWhisperLive:
|
||||
"""Live transcriber class for small audio segments (A10G).
|
||||
|
||||
Mirrors the Parakeet live class API but uses Faster-Whisper under the hood.
|
||||
"""
|
||||
|
||||
@modal.enter()
|
||||
def enter(self):
|
||||
import faster_whisper
|
||||
import torch
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=MODEL_COMPUTE_TYPE,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=CACHE_PATH,
|
||||
local_files_only=True,
|
||||
)
|
||||
print(f"Model is on device: {self.device}")
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
filename: str,
|
||||
language: str = "en",
|
||||
):
|
||||
"""Transcribe a single uploaded audio file by filename."""
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
segments, _ = self.model.transcribe(
|
||||
file_path,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(segment.text for segment in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": word.word,
|
||||
"start": round(float(word.start), 2),
|
||||
"end": round(float(word.end), 2),
|
||||
}
|
||||
for segment in segments
|
||||
for word in segment.words
|
||||
]
|
||||
|
||||
return {"text": text, "words": words}
|
||||
|
||||
@modal.method()
|
||||
def transcribe_batch(
|
||||
self,
|
||||
filenames: list[str],
|
||||
language: str = "en",
|
||||
):
|
||||
"""Transcribe multiple uploaded audio files and return per-file results."""
|
||||
upload_volume.reload()
|
||||
|
||||
results = []
|
||||
for filename in filenames:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Batch file not found: {file_path}")
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
segments, _ = self.model.transcribe(
|
||||
file_path,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(seg.text for seg in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": w.word,
|
||||
"start": round(float(w.start), 2),
|
||||
"end": round(float(w.end), 2),
|
||||
}
|
||||
for seg in segments
|
||||
for w in seg.words
|
||||
]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"text": text,
|
||||
"words": words,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="L40S",
|
||||
timeout=15 * MINUTES,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
)
|
||||
class TranscriberWhisperFile:
|
||||
"""File transcriber for larger/longer audio, using VAD-driven batching (L40S)."""
|
||||
|
||||
@modal.enter()
|
||||
def enter(self):
|
||||
import faster_whisper
|
||||
import torch
|
||||
from silero_vad import load_silero_vad
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=MODEL_COMPUTE_TYPE,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=CACHE_PATH,
|
||||
local_files_only=True,
|
||||
)
|
||||
self.vad_model = load_silero_vad(onnx=False)
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self, filename: str, timestamp_offset: float = 0.0, language: str = "en"
|
||||
):
|
||||
import librosa
|
||||
import numpy as np
|
||||
from silero_vad import VADIterator
|
||||
|
||||
def vad_segments(
|
||||
audio_array,
|
||||
sample_rate: int = SAMPLERATE,
|
||||
window_size: int = VAD_CONFIG["window_size"],
|
||||
) -> Generator[TimeSegment, None, None]:
|
||||
"""Generate speech segments as TimeSegment using Silero VAD."""
|
||||
iterator = VADIterator(self.vad_model, sampling_rate=sample_rate)
|
||||
start = None
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(
|
||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
||||
)
|
||||
speech = iterator(chunk)
|
||||
if not speech:
|
||||
continue
|
||||
if "start" in speech:
|
||||
start = speech["start"]
|
||||
continue
|
||||
if "end" in speech and start is not None:
|
||||
end = speech["end"]
|
||||
yield TimeSegment(
|
||||
start / float(SAMPLERATE), end / float(SAMPLERATE)
|
||||
)
|
||||
start = None
|
||||
iterator.reset_states()
|
||||
|
||||
upload_volume.reload()
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
audio_array, _sr = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
|
||||
# Batch segments up to ~30s windows by merging contiguous VAD segments
|
||||
merged_batches: list[TimeSegment] = []
|
||||
batch_start = None
|
||||
batch_end = None
|
||||
max_duration = VAD_CONFIG["batch_max_duration"]
|
||||
for segment in vad_segments(audio_array):
|
||||
seg_start, seg_end = segment.start, segment.end
|
||||
if batch_start is None:
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
continue
|
||||
if seg_end - batch_start <= max_duration:
|
||||
batch_end = seg_end
|
||||
else:
|
||||
merged_batches.append(TimeSegment(batch_start, batch_end))
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
if batch_start is not None and batch_end is not None:
|
||||
merged_batches.append(TimeSegment(batch_start, batch_end))
|
||||
|
||||
all_text = []
|
||||
all_words = []
|
||||
|
||||
for segment in merged_batches:
|
||||
start_time, end_time = segment.start, segment.end
|
||||
s_idx = int(start_time * SAMPLERATE)
|
||||
e_idx = int(end_time * SAMPLERATE)
|
||||
segment = audio_array[s_idx:e_idx]
|
||||
segment = pad_audio(segment, SAMPLERATE)
|
||||
|
||||
with self.lock:
|
||||
segments, _ = self.model.transcribe(
|
||||
segment,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(seg.text for seg in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": w.word,
|
||||
"start": round(float(w.start) + start_time + timestamp_offset, 2),
|
||||
"end": round(float(w.end) + start_time + timestamp_offset, 2),
|
||||
}
|
||||
for seg in segments
|
||||
for w in seg.words
|
||||
]
|
||||
if text:
|
||||
all_text.append(text)
|
||||
all_words.extend(words)
|
||||
|
||||
return {"text": " ".join(all_text), "words": all_words}
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: dict) -> str:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
url_path = urlparse(url).path
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return ext
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return "mp3"
|
||||
if "audio/wav" in content_type:
|
||||
return "wav"
|
||||
if "audio/mp4" in content_type:
|
||||
return "mp4"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(audio_file_url: str) -> tuple[str, str]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60,
|
||||
timeout=600,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
image=image,
|
||||
)
|
||||
@modal.concurrent(max_inputs=40)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import (
|
||||
Body,
|
||||
Depends,
|
||||
FastAPI,
|
||||
Form,
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
transcriber_live = TranscriberWhisperLive()
|
||||
transcriber_file = TranscriberWhisperFile()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class TranscriptResponse(dict):
|
||||
pass
|
||||
|
||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe(
|
||||
file: UploadFile = None,
|
||||
files: list[UploadFile] | None = None,
|
||||
model: str = Form(MODEL_NAME),
|
||||
language: str = Form("en"),
|
||||
batch: bool = Form(False),
|
||||
):
|
||||
if not file and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
||||
)
|
||||
if batch and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Batch transcription requires 'files'"
|
||||
)
|
||||
|
||||
upload_files = [file] if file else files
|
||||
|
||||
uploaded_filenames: list[str] = []
|
||||
for upload_file in upload_files:
|
||||
audio_suffix = upload_file.filename.split(".")[-1]
|
||||
if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
with open(file_path, "wb") as f:
|
||||
content = upload_file.file.read()
|
||||
f.write(content)
|
||||
uploaded_filenames.append(unique_filename)
|
||||
|
||||
upload_volume.commit()
|
||||
|
||||
try:
|
||||
if batch and len(upload_files) > 1:
|
||||
func = transcriber_live.transcribe_batch.spawn(
|
||||
filenames=uploaded_filenames,
|
||||
language=language,
|
||||
)
|
||||
results = func.get()
|
||||
return {"results": results}
|
||||
|
||||
results = []
|
||||
for filename in uploaded_filenames:
|
||||
func = transcriber_live.transcribe_segment.spawn(
|
||||
filename=filename,
|
||||
language=language,
|
||||
)
|
||||
result = func.get()
|
||||
result["filename"] = filename
|
||||
results.append(result)
|
||||
|
||||
return {"results": results} if len(results) > 1 else results[0]
|
||||
finally:
|
||||
for filename in uploaded_filenames:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
os.remove(file_path)
|
||||
except Exception:
|
||||
pass
|
||||
upload_volume.commit()
|
||||
|
||||
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe_from_url(
|
||||
audio_file_url: str = Body(
|
||||
..., description="URL of the audio file to transcribe"
|
||||
),
|
||||
model: str = Body(MODEL_NAME),
|
||||
language: str = Body("en"),
|
||||
timestamp_offset: float = Body(0.0),
|
||||
):
|
||||
unique_filename, _audio_suffix = download_audio_to_volume(audio_file_url)
|
||||
try:
|
||||
func = transcriber_file.transcribe_segment.spawn(
|
||||
filename=unique_filename,
|
||||
timestamp_offset=timestamp_offset,
|
||||
language=language,
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
os.remove(file_path)
|
||||
upload_volume.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class NoStdStreams:
|
||||
def __init__(self):
|
||||
self.devnull = open(os.devnull, "w")
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
||||
self._stdout.flush()
|
||||
self._stderr.flush()
|
||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
||||
self.devnull.close()
|
||||
@@ -1,658 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import modal
|
||||
|
||||
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
SAMPLERATE = 16000
|
||||
UPLOADS_PATH = "/uploads"
|
||||
CACHE_PATH = "/cache"
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
|
||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||
|
||||
|
||||
class TimeSegment(NamedTuple):
|
||||
"""Represents a time segment with start and end times."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
class AudioSegment(NamedTuple):
|
||||
"""Represents an audio segment with timing and audio data."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
audio: any
|
||||
|
||||
|
||||
class TranscriptResult(NamedTuple):
|
||||
"""Represents a transcription result with text and word timings."""
|
||||
|
||||
text: str
|
||||
words: list["WordTiming"]
|
||||
|
||||
|
||||
class WordTiming(TypedDict):
|
||||
"""Represents a word with its timing information."""
|
||||
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
app = modal.App("reflector-transcriber-parakeet")
|
||||
|
||||
# Volume for caching model weights
|
||||
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
|
||||
# Volume for temporary file uploads
|
||||
upload_volume = modal.Volume.from_name("parakeet-uploads", create_if_missing=True)
|
||||
|
||||
image = (
|
||||
modal.Image.from_registry(
|
||||
"nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", add_python="3.12"
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||
"HF_HOME": "/cache",
|
||||
"DEBIAN_FRONTEND": "noninteractive",
|
||||
"CXX": "g++",
|
||||
"CC": "g++",
|
||||
}
|
||||
)
|
||||
.apt_install("ffmpeg")
|
||||
.pip_install(
|
||||
"hf_transfer==0.1.9",
|
||||
"huggingface_hub[hf-xet]==0.31.2",
|
||||
"nemo_toolkit[asr]==2.3.0",
|
||||
"cuda-python==12.8.0",
|
||||
"fastapi==0.115.12",
|
||||
"numpy<2",
|
||||
"librosa==0.10.1",
|
||||
"requests",
|
||||
"silero-vad==5.1.0",
|
||||
"torch",
|
||||
)
|
||||
.entrypoint([]) # silence chatty logs by container on start
|
||||
)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||
parsed_url = urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return AudioFileExtension(ext)
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return AudioFileExtension("mp3")
|
||||
if "audio/wav" in content_type:
|
||||
return AudioFileExtension("wav")
|
||||
if "audio/mp4" in content_type:
|
||||
return AudioFileExtension("mp4")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported audio format for URL: {url}. "
|
||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(
|
||||
audio_file_url: str,
|
||||
) -> tuple[ParakeetUniqFilename, AudioFileExtension]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = ParakeetUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
|
||||
"""Add 0.5 seconds of silence if audio is less than 500ms.
|
||||
|
||||
This is a workaround for a Parakeet bug where very short audio (<500ms) causes:
|
||||
ValueError: `char_offsets`: [] and `processed_tokens`: [157, 834, 834, 841]
|
||||
have to be of the same length
|
||||
|
||||
See: https://github.com/NVIDIA/NeMo/issues/8451
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
audio_duration = len(audio_array) / sample_rate
|
||||
if audio_duration < 0.5:
|
||||
silence_samples = int(sample_rate * 0.5)
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio_array, silence])
|
||||
return audio_array
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=600,
|
||||
scaledown_window=300,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
enable_memory_snapshot=True,
|
||||
experimental_options={"enable_gpu_snapshot": True},
|
||||
)
|
||||
@modal.concurrent(max_inputs=10)
|
||||
class TranscriberParakeetLive:
|
||||
@modal.enter(snap=True)
|
||||
def enter(self):
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
|
||||
device = next(self.model.parameters()).device
|
||||
print(f"Model is on device: {device}")
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
filename: str,
|
||||
):
|
||||
import librosa
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
padded_audio = pad_audio(audio_array, sample_rate)
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
(output,) = self.model.transcribe([padded_audio], timestamps=True)
|
||||
|
||||
text = output.text.strip()
|
||||
words: list[WordTiming] = [
|
||||
WordTiming(
|
||||
# XXX the space added here is to match the output of whisper
|
||||
# whisper add space to each words, while parakeet don't
|
||||
word=word_info["word"] + " ",
|
||||
start=round(word_info["start"], 2),
|
||||
end=round(word_info["end"], 2),
|
||||
)
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
return {"text": text, "words": words}
|
||||
|
||||
@modal.method()
|
||||
def transcribe_batch(
|
||||
self,
|
||||
filenames: list[str],
|
||||
):
|
||||
import librosa
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
results = []
|
||||
audio_arrays = []
|
||||
|
||||
# Load all audio files with padding
|
||||
for filename in filenames:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Batch file not found: {file_path}")
|
||||
|
||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
padded_audio = pad_audio(audio_array, sample_rate)
|
||||
audio_arrays.append(padded_audio)
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
outputs = self.model.transcribe(audio_arrays, timestamps=True)
|
||||
|
||||
# Process results for each file
|
||||
for i, (filename, output) in enumerate(zip(filenames, outputs)):
|
||||
text = output.text.strip()
|
||||
|
||||
words: list[WordTiming] = [
|
||||
WordTiming(
|
||||
word=word_info["word"] + " ",
|
||||
start=round(word_info["start"], 2),
|
||||
end=round(word_info["end"], 2),
|
||||
)
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"text": text,
|
||||
"words": words,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# L40S class for file transcription (bigger files)
|
||||
@app.cls(
|
||||
gpu="L40S",
|
||||
timeout=900,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
enable_memory_snapshot=True,
|
||||
experimental_options={"enable_gpu_snapshot": True},
|
||||
)
|
||||
class TranscriberParakeetFile:
|
||||
@modal.enter(snap=True)
|
||||
def enter(self):
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import torch
|
||||
from silero_vad import load_silero_vad
|
||||
|
||||
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
|
||||
|
||||
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
|
||||
device = next(self.model.parameters()).device
|
||||
print(f"Model is on device: {device}")
|
||||
|
||||
torch.set_num_threads(1)
|
||||
self.vad_model = load_silero_vad(onnx=False)
|
||||
print("Silero VAD initialized")
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
filename: str,
|
||||
timestamp_offset: float = 0.0,
|
||||
):
|
||||
import librosa
|
||||
import numpy as np
|
||||
from silero_vad import VADIterator
|
||||
|
||||
def load_and_convert_audio(file_path):
|
||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
return audio_array
|
||||
|
||||
def vad_segment_generator(
|
||||
audio_array,
|
||||
) -> Generator[TimeSegment, None, None]:
|
||||
"""Generate speech segments using VAD with start/end sample indices"""
|
||||
vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
|
||||
window_size = VAD_CONFIG["window_size"]
|
||||
start = None
|
||||
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(
|
||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
||||
)
|
||||
|
||||
speech_dict = vad_iterator(chunk)
|
||||
if not speech_dict:
|
||||
continue
|
||||
|
||||
if "start" in speech_dict:
|
||||
start = speech_dict["start"]
|
||||
continue
|
||||
|
||||
if "end" in speech_dict and start is not None:
|
||||
end = speech_dict["end"]
|
||||
start_time = start / float(SAMPLERATE)
|
||||
end_time = end / float(SAMPLERATE)
|
||||
|
||||
yield TimeSegment(start_time, end_time)
|
||||
start = None
|
||||
|
||||
vad_iterator.reset_states()
|
||||
|
||||
def batch_speech_segments(
|
||||
segments: Generator[TimeSegment, None, None], max_duration: int
|
||||
) -> Generator[TimeSegment, None, None]:
|
||||
"""
|
||||
Input segments:
|
||||
[0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
|
||||
|
||||
↓ (max_duration=10)
|
||||
|
||||
Output batches:
|
||||
[0-8] [10-19] [20-22]
|
||||
|
||||
Note: silences are kept for better transcription, previous implementation was
|
||||
passing segments separatly, but the output was less accurate.
|
||||
"""
|
||||
batch_start_time = None
|
||||
batch_end_time = None
|
||||
|
||||
for segment in segments:
|
||||
start_time, end_time = segment.start, segment.end
|
||||
if batch_start_time is None or batch_end_time is None:
|
||||
batch_start_time = start_time
|
||||
batch_end_time = end_time
|
||||
continue
|
||||
|
||||
total_duration = end_time - batch_start_time
|
||||
|
||||
if total_duration <= max_duration:
|
||||
batch_end_time = end_time
|
||||
continue
|
||||
|
||||
yield TimeSegment(batch_start_time, batch_end_time)
|
||||
batch_start_time = start_time
|
||||
batch_end_time = end_time
|
||||
|
||||
if batch_start_time is None or batch_end_time is None:
|
||||
return
|
||||
|
||||
yield TimeSegment(batch_start_time, batch_end_time)
|
||||
|
||||
def batch_segment_to_audio_segment(
|
||||
segments: Generator[TimeSegment, None, None],
|
||||
audio_array,
|
||||
) -> Generator[AudioSegment, None, None]:
|
||||
"""Extract audio segments and apply padding for Parakeet compatibility.
|
||||
|
||||
Uses pad_audio to ensure segments are at least 0.5s long, preventing
|
||||
Parakeet crashes. This padding may cause slight timing overlaps between
|
||||
segments, which are corrected by enforce_word_timing_constraints.
|
||||
"""
|
||||
for segment in segments:
|
||||
start_time, end_time = segment.start, segment.end
|
||||
start_sample = int(start_time * SAMPLERATE)
|
||||
end_sample = int(end_time * SAMPLERATE)
|
||||
audio_segment = audio_array[start_sample:end_sample]
|
||||
|
||||
padded_segment = pad_audio(audio_segment, SAMPLERATE)
|
||||
|
||||
yield AudioSegment(start_time, end_time, padded_segment)
|
||||
|
||||
def transcribe_batch(model, audio_segments: list) -> list:
|
||||
with NoStdStreams():
|
||||
outputs = model.transcribe(audio_segments, timestamps=True)
|
||||
return outputs
|
||||
|
||||
def enforce_word_timing_constraints(
|
||||
words: list[WordTiming],
|
||||
) -> list[WordTiming]:
|
||||
"""Enforce that word end times don't exceed the start time of the next word.
|
||||
|
||||
Due to silence padding added in batch_segment_to_audio_segment for better
|
||||
transcription accuracy, word timings from different segments may overlap.
|
||||
This function ensures there are no overlaps by adjusting end times.
|
||||
"""
|
||||
if len(words) <= 1:
|
||||
return words
|
||||
|
||||
enforced_words = []
|
||||
for i, word in enumerate(words):
|
||||
enforced_word = word.copy()
|
||||
|
||||
if i < len(words) - 1:
|
||||
next_start = words[i + 1]["start"]
|
||||
if enforced_word["end"] > next_start:
|
||||
enforced_word["end"] = next_start
|
||||
|
||||
enforced_words.append(enforced_word)
|
||||
|
||||
return enforced_words
|
||||
|
||||
def emit_results(
|
||||
results: list,
|
||||
segments_info: list[AudioSegment],
|
||||
) -> Generator[TranscriptResult, None, None]:
|
||||
"""Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
|
||||
for i, (output, segment) in enumerate(zip(results, segments_info)):
|
||||
start_time, end_time = segment.start, segment.end
|
||||
text = output.text.strip()
|
||||
words: list[WordTiming] = [
|
||||
WordTiming(
|
||||
word=word_info["word"] + " ",
|
||||
start=round(
|
||||
word_info["start"] + start_time + timestamp_offset, 2
|
||||
),
|
||||
end=round(word_info["end"] + start_time + timestamp_offset, 2),
|
||||
)
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
yield TranscriptResult(text, words)
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
audio_array = load_and_convert_audio(file_path)
|
||||
total_duration = len(audio_array) / float(SAMPLERATE)
|
||||
|
||||
all_text_parts: list[str] = []
|
||||
all_words: list[WordTiming] = []
|
||||
|
||||
raw_segments = vad_segment_generator(audio_array)
|
||||
speech_segments = batch_speech_segments(
|
||||
raw_segments,
|
||||
VAD_CONFIG["batch_max_duration"],
|
||||
)
|
||||
audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array)
|
||||
|
||||
for batch in audio_segments:
|
||||
audio_segment = batch.audio
|
||||
results = transcribe_batch(self.model, [audio_segment])
|
||||
|
||||
for result in emit_results(
|
||||
results,
|
||||
[batch],
|
||||
):
|
||||
if not result.text:
|
||||
continue
|
||||
all_text_parts.append(result.text)
|
||||
all_words.extend(result.words)
|
||||
|
||||
all_words = enforce_word_timing_constraints(all_words)
|
||||
|
||||
combined_text = " ".join(all_text_parts)
|
||||
return {"text": combined_text, "words": all_words}
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60,
|
||||
timeout=600,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
image=image,
|
||||
)
|
||||
@modal.concurrent(max_inputs=40)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from fastapi import (
|
||||
Body,
|
||||
Depends,
|
||||
FastAPI,
|
||||
Form,
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
transcriber_live = TranscriberParakeetLive()
|
||||
transcriber_file = TranscriberParakeetFile()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class TranscriptResponse(BaseModel):
|
||||
result: dict
|
||||
|
||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe(
|
||||
file: UploadFile = None,
|
||||
files: list[UploadFile] | None = None,
|
||||
model: str = Form(MODEL_NAME),
|
||||
language: str = Form("en"),
|
||||
batch: bool = Form(False),
|
||||
):
|
||||
# Parakeet only supports English
|
||||
if language != "en":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Parakeet model only supports English. Got language='{language}'",
|
||||
)
|
||||
# Handle both single file and multiple files
|
||||
if not file and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
||||
)
|
||||
if batch and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Batch transcription requires 'files'"
|
||||
)
|
||||
|
||||
upload_files = [file] if file else files
|
||||
|
||||
# Upload files to volume
|
||||
uploaded_filenames = []
|
||||
for upload_file in upload_files:
|
||||
audio_suffix = upload_file.filename.split(".")[-1]
|
||||
assert audio_suffix in SUPPORTED_FILE_EXTENSIONS
|
||||
|
||||
# Generate unique filename
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
print(f"Writing file to: {file_path}")
|
||||
with open(file_path, "wb") as f:
|
||||
content = upload_file.file.read()
|
||||
f.write(content)
|
||||
|
||||
uploaded_filenames.append(unique_filename)
|
||||
|
||||
upload_volume.commit()
|
||||
|
||||
try:
|
||||
# Use A10G live transcriber for per-file transcription
|
||||
if batch and len(upload_files) > 1:
|
||||
# Use batch transcription
|
||||
func = transcriber_live.transcribe_batch.spawn(
|
||||
filenames=uploaded_filenames,
|
||||
)
|
||||
results = func.get()
|
||||
return {"results": results}
|
||||
|
||||
# Per-file transcription
|
||||
results = []
|
||||
for filename in uploaded_filenames:
|
||||
func = transcriber_live.transcribe_segment.spawn(
|
||||
filename=filename,
|
||||
)
|
||||
result = func.get()
|
||||
result["filename"] = filename
|
||||
results.append(result)
|
||||
|
||||
return {"results": results} if len(results) > 1 else results[0]
|
||||
|
||||
finally:
|
||||
for filename in uploaded_filenames:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
print(f"Deleting file: {file_path}")
|
||||
os.remove(file_path)
|
||||
except Exception as e:
|
||||
print(f"Error deleting {filename}: {e}")
|
||||
|
||||
upload_volume.commit()
|
||||
|
||||
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe_from_url(
|
||||
audio_file_url: str = Body(
|
||||
..., description="URL of the audio file to transcribe"
|
||||
),
|
||||
model: str = Body(MODEL_NAME),
|
||||
language: str = Body("en", description="Language code (only 'en' supported)"),
|
||||
timestamp_offset: float = Body(0.0),
|
||||
):
|
||||
# Parakeet only supports English
|
||||
if language != "en":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Parakeet model only supports English. Got language='{language}'",
|
||||
)
|
||||
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
|
||||
|
||||
try:
|
||||
func = transcriber_file.transcribe_segment.spawn(
|
||||
filename=unique_filename,
|
||||
timestamp_offset=timestamp_offset,
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
print(f"Deleting file: {file_path}")
|
||||
os.remove(file_path)
|
||||
upload_volume.commit()
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up {unique_filename}: {e}")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class NoStdStreams:
|
||||
def __init__(self):
|
||||
self.devnull = open(os.devnull, "w")
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
||||
self._stdout.flush()
|
||||
self._stderr.flush()
|
||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
||||
self.devnull.close()
|
||||
@@ -1,2 +0,0 @@
|
||||
REFLECTOR_GPU_APIKEY=
|
||||
HF_TOKEN=
|
||||
38
gpu/self_hosted/.gitignore
vendored
38
gpu/self_hosted/.gitignore
vendored
@@ -1,38 +0,0 @@
|
||||
cache/
|
||||
|
||||
# OS / Editor
|
||||
.DS_Store
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Env and secrets
|
||||
.env
|
||||
*.env
|
||||
*.secret
|
||||
HF_TOKEN
|
||||
REFLECTOR_GPU_APIKEY
|
||||
|
||||
# Virtual env / uv
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
uv/
|
||||
|
||||
# Build / dist
|
||||
build/
|
||||
dist/
|
||||
.eggs/
|
||||
*.egg-info/
|
||||
|
||||
# Coverage / test
|
||||
.pytest_cache/
|
||||
.coverage*
|
||||
htmlcov/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
@@ -1,46 +0,0 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
|
||||
WORKDIR /tmp
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
ffmpeg \
|
||||
curl \
|
||||
ca-certificates \
|
||||
gnupg \
|
||||
wget \
|
||||
&& apt-get clean
|
||||
# Add NVIDIA CUDA repo for Debian 12 (bookworm) and install cuDNN 9 for CUDA 12
|
||||
ADD https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb /cuda-keyring.deb
|
||||
RUN dpkg -i /cuda-keyring.deb \
|
||||
&& rm /cuda-keyring.deb \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-cudart-12-6 \
|
||||
libcublas-12-6 \
|
||||
libcudnn9-cuda-12 \
|
||||
libcudnn9-dev-cuda-12 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
ADD https://astral.sh/uv/install.sh /uv-installer.sh
|
||||
RUN sh /uv-installer.sh && rm /uv-installer.sh
|
||||
ENV PATH="/root/.local/bin/:$PATH"
|
||||
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH"
|
||||
|
||||
RUN mkdir -p /app
|
||||
WORKDIR /app
|
||||
COPY pyproject.toml uv.lock /app/
|
||||
|
||||
|
||||
COPY ./app /app/app
|
||||
COPY ./main.py /app/
|
||||
COPY ./runserver.sh /app/
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["sh", "/app/runserver.sh"]
|
||||
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
# Self-hosted Model API
|
||||
|
||||
Run transcription, translation, and diarization services compatible with Reflector's GPU Model API. Works on CPU or GPU.
|
||||
|
||||
Environment variables
|
||||
|
||||
- REFLECTOR_GPU_APIKEY: Optional Bearer token. If unset, auth is disabled.
|
||||
- HF_TOKEN: Optional. Required for diarization to download pyannote pipelines
|
||||
|
||||
Requirements
|
||||
|
||||
- FFmpeg must be installed and on PATH (used for URL-based and segmented transcription)
|
||||
- Python 3.12+
|
||||
- NVIDIA GPU optional. If available, it will be used automatically
|
||||
|
||||
Local run
|
||||
Set env vars in self_hosted/.env file
|
||||
uv sync
|
||||
|
||||
uv run uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
|
||||
Authentication
|
||||
|
||||
- If REFLECTOR_GPU_APIKEY is set, include header: Authorization: Bearer <key>
|
||||
|
||||
Endpoints
|
||||
|
||||
- POST /v1/audio/transcriptions
|
||||
|
||||
- multipart/form-data
|
||||
- fields: file (single file) OR files[] (multiple files), language, batch (true/false)
|
||||
- response: single { text, words, filename } or { results: [ ... ] }
|
||||
|
||||
- POST /v1/audio/transcriptions-from-url
|
||||
|
||||
- application/json
|
||||
- body: { audio_file_url, language, timestamp_offset }
|
||||
- response: { text, words }
|
||||
|
||||
- POST /translate
|
||||
|
||||
- text: query parameter
|
||||
- body (application/json): { source_language, target_language }
|
||||
- response: { text: { <src>: original, <tgt>: translated } }
|
||||
|
||||
- POST /diarize
|
||||
- query parameters: audio_file_url, timestamp (optional)
|
||||
- requires HF_TOKEN to be set (for pyannote)
|
||||
- response: { diarization: [ { start, end, speaker } ] }
|
||||
|
||||
OpenAPI docs
|
||||
|
||||
- Visit /docs when the server is running
|
||||
|
||||
Docker
|
||||
|
||||
- Not yet provided in this directory. A Dockerfile will be added later. For now, use Local run above
|
||||
|
||||
Conformance tests
|
||||
|
||||
# From this directory
|
||||
|
||||
TRANSCRIPT_URL=http://localhost:8000 \
|
||||
TRANSCRIPT_API_KEY=dev-key \
|
||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_transcript.py
|
||||
|
||||
TRANSLATION_URL=http://localhost:8000 \
|
||||
TRANSLATION_API_KEY=dev-key \
|
||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_translation.py
|
||||
|
||||
DIARIZATION_URL=http://localhost:8000 \
|
||||
DIARIZATION_API_KEY=dev-key \
|
||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_diarization.py
|
||||
@@ -1,19 +0,0 @@
|
||||
import os
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
required_key = os.environ.get("REFLECTOR_GPU_APIKEY")
|
||||
if not required_key:
|
||||
return
|
||||
if apikey == required_key:
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
@@ -1,12 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
SAMPLE_RATE = 16000
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
# App-level paths
|
||||
UPLOADS_PATH = Path("/tmp/whisper-uploads")
|
||||
@@ -1,30 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routers.diarization import router as diarization_router
|
||||
from .routers.transcription import router as transcription_router
|
||||
from .routers.translation import router as translation_router
|
||||
from .services.transcriber import WhisperService
|
||||
from .services.diarizer import PyannoteDiarizationService
|
||||
from .utils import ensure_dirs
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
ensure_dirs()
|
||||
whisper_service = WhisperService()
|
||||
whisper_service.load()
|
||||
app.state.whisper = whisper_service
|
||||
diarization_service = PyannoteDiarizationService()
|
||||
diarization_service.load()
|
||||
app.state.diarizer = diarization_service
|
||||
yield
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(transcription_router)
|
||||
app.include_router(translation_router)
|
||||
app.include_router(diarization_router)
|
||||
return app
|
||||
@@ -1,30 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth import apikey_auth
|
||||
from ..services.diarizer import PyannoteDiarizationService
|
||||
from ..utils import download_audio_file
|
||||
|
||||
router = APIRouter(tags=["diarization"])
|
||||
|
||||
|
||||
class DiarizationSegment(BaseModel):
|
||||
start: float
|
||||
end: float
|
||||
speaker: int
|
||||
|
||||
|
||||
class DiarizationResponse(BaseModel):
|
||||
diarization: List[DiarizationSegment]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diarize", dependencies=[Depends(apikey_auth)], response_model=DiarizationResponse
|
||||
)
|
||||
def diarize(request: Request, audio_file_url: str, timestamp: float = 0.0):
|
||||
with download_audio_file(audio_file_url) as (file_path, _ext):
|
||||
file_path = str(file_path)
|
||||
diarizer: PyannoteDiarizationService = request.app.state.diarizer
|
||||
return diarizer.diarize_file(file_path, timestamp=timestamp)
|
||||
@@ -1,109 +0,0 @@
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Form, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pathlib import Path
|
||||
from ..auth import apikey_auth
|
||||
from ..config import SUPPORTED_FILE_EXTENSIONS, UPLOADS_PATH
|
||||
from ..services.transcriber import MODEL_NAME
|
||||
from ..utils import cleanup_uploaded_files, download_audio_file
|
||||
|
||||
router = APIRouter(prefix="/v1/audio", tags=["transcription"])
|
||||
|
||||
|
||||
class WordTiming(BaseModel):
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
class TranscriptResult(BaseModel):
|
||||
text: str
|
||||
words: list[WordTiming]
|
||||
filename: Optional[str] = None
|
||||
|
||||
|
||||
class TranscriptBatchResponse(BaseModel):
|
||||
results: list[TranscriptResult]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcriptions",
|
||||
dependencies=[Depends(apikey_auth)],
|
||||
response_model=Union[TranscriptResult, TranscriptBatchResponse],
|
||||
)
|
||||
def transcribe(
|
||||
request: Request,
|
||||
file: UploadFile = None,
|
||||
files: list[UploadFile] | None = None,
|
||||
model: str = Form(MODEL_NAME),
|
||||
language: str = Form("en"),
|
||||
batch: bool = Form(False),
|
||||
):
|
||||
service = request.app.state.whisper
|
||||
if not file and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
||||
)
|
||||
if batch and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Batch transcription requires 'files'"
|
||||
)
|
||||
|
||||
upload_files = [file] if file else files
|
||||
|
||||
uploaded_paths: list[Path] = []
|
||||
with cleanup_uploaded_files(uploaded_paths):
|
||||
for upload_file in upload_files:
|
||||
audio_suffix = upload_file.filename.split(".")[-1].lower()
|
||||
if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = UPLOADS_PATH / unique_filename
|
||||
with open(file_path, "wb") as f:
|
||||
content = upload_file.file.read()
|
||||
f.write(content)
|
||||
uploaded_paths.append(file_path)
|
||||
|
||||
if batch and len(upload_files) > 1:
|
||||
results = []
|
||||
for path in uploaded_paths:
|
||||
result = service.transcribe_file(str(path), language=language)
|
||||
result["filename"] = path.name
|
||||
results.append(result)
|
||||
return {"results": results}
|
||||
|
||||
results = []
|
||||
for path in uploaded_paths:
|
||||
result = service.transcribe_file(str(path), language=language)
|
||||
result["filename"] = path.name
|
||||
results.append(result)
|
||||
|
||||
return {"results": results} if len(results) > 1 else results[0]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcriptions-from-url",
|
||||
dependencies=[Depends(apikey_auth)],
|
||||
response_model=TranscriptResult,
|
||||
)
|
||||
def transcribe_from_url(
|
||||
request: Request,
|
||||
audio_file_url: str = Body(..., description="URL of the audio file to transcribe"),
|
||||
model: str = Body(MODEL_NAME),
|
||||
language: str = Body("en"),
|
||||
timestamp_offset: float = Body(0.0),
|
||||
):
|
||||
service = request.app.state.whisper
|
||||
with download_audio_file(audio_file_url) as (file_path, _ext):
|
||||
file_path = str(file_path)
|
||||
result = service.transcribe_vad_url_segment(
|
||||
file_path=file_path, timestamp_offset=timestamp_offset, language=language
|
||||
)
|
||||
return result
|
||||
@@ -1,28 +0,0 @@
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth import apikey_auth
|
||||
from ..services.translator import TextTranslatorService
|
||||
|
||||
router = APIRouter(tags=["translation"])
|
||||
|
||||
translator = TextTranslatorService()
|
||||
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
text: Dict[str, str]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/translate",
|
||||
dependencies=[Depends(apikey_auth)],
|
||||
response_model=TranslationResponse,
|
||||
)
|
||||
def translate(
|
||||
text: str,
|
||||
source_language: str = Body("en"),
|
||||
target_language: str = Body("fr"),
|
||||
):
|
||||
return translator.translate(text, source_language, target_language)
|
||||
@@ -1,42 +0,0 @@
|
||||
import os
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
|
||||
class PyannoteDiarizationService:
|
||||
def __init__(self):
|
||||
self._pipeline = None
|
||||
self._device = "cpu"
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def load(self):
|
||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self._pipeline = Pipeline.from_pretrained(
|
||||
"pyannote/speaker-diarization-3.1",
|
||||
use_auth_token=os.environ.get("HF_TOKEN"),
|
||||
)
|
||||
self._pipeline.to(torch.device(self._device))
|
||||
|
||||
def diarize_file(self, file_path: str, timestamp: float = 0.0) -> dict:
|
||||
if self._pipeline is None:
|
||||
self.load()
|
||||
waveform, sample_rate = torchaudio.load(file_path)
|
||||
with self._lock:
|
||||
diarization = self._pipeline(
|
||||
{"waveform": waveform, "sample_rate": sample_rate}
|
||||
)
|
||||
words = []
|
||||
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||
words.append(
|
||||
{
|
||||
"start": round(timestamp + diarization_segment.start, 3),
|
||||
"end": round(timestamp + diarization_segment.end, 3),
|
||||
"speaker": int(speaker[-2:])
|
||||
if speaker and speaker[-2:].isdigit()
|
||||
else 0,
|
||||
}
|
||||
)
|
||||
return {"diarization": words}
|
||||
@@ -1,208 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Generator
|
||||
|
||||
import faster_whisper
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import HTTPException
|
||||
from silero_vad import VADIterator, load_silero_vad
|
||||
|
||||
from ..config import SAMPLE_RATE, VAD_CONFIG
|
||||
|
||||
# Whisper configuration (service-local defaults)
|
||||
MODEL_NAME = "large-v2"
|
||||
# None delegates compute type to runtime: float16 on CUDA, int8 on CPU
|
||||
MODEL_COMPUTE_TYPE = None
|
||||
MODEL_NUM_WORKERS = 1
|
||||
CACHE_PATH = os.path.join(os.path.expanduser("~"), ".cache", "reflector-whisper")
|
||||
from ..utils import NoStdStreams
|
||||
|
||||
|
||||
class WhisperService:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.device = "cpu"
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def load(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
compute_type = MODEL_COMPUTE_TYPE or (
|
||||
"float16" if self.device == "cuda" else "int8"
|
||||
)
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=compute_type,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=CACHE_PATH,
|
||||
)
|
||||
|
||||
def pad_audio(self, audio_array, sample_rate: int = SAMPLE_RATE):
|
||||
audio_duration = len(audio_array) / sample_rate
|
||||
if audio_duration < VAD_CONFIG["silence_padding"]:
|
||||
silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio_array, silence])
|
||||
return audio_array
|
||||
|
||||
def enforce_word_timing_constraints(self, words: list[dict]) -> list[dict]:
|
||||
if len(words) <= 1:
|
||||
return words
|
||||
enforced: list[dict] = []
|
||||
for i, word in enumerate(words):
|
||||
current = dict(word)
|
||||
if i < len(words) - 1:
|
||||
next_start = words[i + 1]["start"]
|
||||
if current["end"] > next_start:
|
||||
current["end"] = next_start
|
||||
enforced.append(current)
|
||||
return enforced
|
||||
|
||||
def transcribe_file(self, file_path: str, language: str = "en") -> dict:
|
||||
input_for_model: str | "object" = file_path
|
||||
try:
|
||||
audio_array, _sample_rate = librosa.load(
|
||||
file_path, sr=SAMPLE_RATE, mono=True
|
||||
)
|
||||
if len(audio_array) / float(SAMPLE_RATE) < VAD_CONFIG["silence_padding"]:
|
||||
input_for_model = self.pad_audio(audio_array, SAMPLE_RATE)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
segments, _ = self.model.transcribe(
|
||||
input_for_model,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(segment.text for segment in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": word.word,
|
||||
"start": round(float(word.start), 2),
|
||||
"end": round(float(word.end), 2),
|
||||
}
|
||||
for segment in segments
|
||||
for word in segment.words
|
||||
]
|
||||
words = self.enforce_word_timing_constraints(words)
|
||||
return {"text": text, "words": words}
|
||||
|
||||
def transcribe_vad_url_segment(
|
||||
self, file_path: str, timestamp_offset: float = 0.0, language: str = "en"
|
||||
) -> dict:
|
||||
def load_audio_via_ffmpeg(input_path: str, sample_rate: int) -> np.ndarray:
|
||||
ffmpeg_bin = shutil.which("ffmpeg") or "ffmpeg"
|
||||
cmd = [
|
||||
ffmpeg_bin,
|
||||
"-nostdin",
|
||||
"-threads",
|
||||
"1",
|
||||
"-i",
|
||||
input_path,
|
||||
"-f",
|
||||
"f32le",
|
||||
"-acodec",
|
||||
"pcm_f32le",
|
||||
"-ac",
|
||||
"1",
|
||||
"-ar",
|
||||
str(sample_rate),
|
||||
"pipe:1",
|
||||
]
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"ffmpeg failed: {e}")
|
||||
audio = np.frombuffer(proc.stdout, dtype=np.float32)
|
||||
return audio
|
||||
|
||||
def vad_segments(
|
||||
audio_array,
|
||||
sample_rate: int = SAMPLE_RATE,
|
||||
window_size: int = VAD_CONFIG["window_size"],
|
||||
) -> Generator[tuple[float, float], None, None]:
|
||||
vad_model = load_silero_vad(onnx=False)
|
||||
iterator = VADIterator(vad_model, sampling_rate=sample_rate)
|
||||
start = None
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(
|
||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
||||
)
|
||||
speech = iterator(chunk)
|
||||
if not speech:
|
||||
continue
|
||||
if "start" in speech:
|
||||
start = speech["start"]
|
||||
continue
|
||||
if "end" in speech and start is not None:
|
||||
end = speech["end"]
|
||||
yield (start / float(SAMPLE_RATE), end / float(SAMPLE_RATE))
|
||||
start = None
|
||||
iterator.reset_states()
|
||||
|
||||
audio_array = load_audio_via_ffmpeg(file_path, SAMPLE_RATE)
|
||||
|
||||
merged_batches: list[tuple[float, float]] = []
|
||||
batch_start = None
|
||||
batch_end = None
|
||||
max_duration = VAD_CONFIG["batch_max_duration"]
|
||||
for seg_start, seg_end in vad_segments(audio_array):
|
||||
if batch_start is None:
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
continue
|
||||
if seg_end - batch_start <= max_duration:
|
||||
batch_end = seg_end
|
||||
else:
|
||||
merged_batches.append((batch_start, batch_end))
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
if batch_start is not None and batch_end is not None:
|
||||
merged_batches.append((batch_start, batch_end))
|
||||
|
||||
all_text = []
|
||||
all_words = []
|
||||
for start_time, end_time in merged_batches:
|
||||
s_idx = int(start_time * SAMPLE_RATE)
|
||||
e_idx = int(end_time * SAMPLE_RATE)
|
||||
segment = audio_array[s_idx:e_idx]
|
||||
segment = self.pad_audio(segment, SAMPLE_RATE)
|
||||
with self.lock:
|
||||
segments, _ = self.model.transcribe(
|
||||
segment,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
segments = list(segments)
|
||||
text = "".join(seg.text for seg in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": w.word,
|
||||
"start": round(float(w.start) + start_time + timestamp_offset, 2),
|
||||
"end": round(float(w.end) + start_time + timestamp_offset, 2),
|
||||
}
|
||||
for seg in segments
|
||||
for w in seg.words
|
||||
]
|
||||
if text:
|
||||
all_text.append(text)
|
||||
all_words.extend(words)
|
||||
|
||||
all_words = self.enforce_word_timing_constraints(all_words)
|
||||
return {"text": " ".join(all_text), "words": all_words}
|
||||
@@ -1,44 +0,0 @@
|
||||
import threading
|
||||
|
||||
from transformers import MarianMTModel, MarianTokenizer, pipeline
|
||||
|
||||
|
||||
class TextTranslatorService:
|
||||
"""Simple text-to-text translator using HuggingFace MarianMT models.
|
||||
|
||||
This mirrors the modal translator API shape but uses text translation only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._pipeline = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def load(self, source_language: str = "en", target_language: str = "fr"):
|
||||
# Pick a default MarianMT model pair if available; fall back to Helsinki-NLP en->fr
|
||||
model_name = self._resolve_model_name(source_language, target_language)
|
||||
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
model = MarianMTModel.from_pretrained(model_name)
|
||||
self._pipeline = pipeline("translation", model=model, tokenizer=tokenizer)
|
||||
|
||||
def _resolve_model_name(self, src: str, tgt: str) -> str:
|
||||
# Minimal mapping; extend as needed
|
||||
pair = (src.lower(), tgt.lower())
|
||||
mapping = {
|
||||
("en", "fr"): "Helsinki-NLP/opus-mt-en-fr",
|
||||
("fr", "en"): "Helsinki-NLP/opus-mt-fr-en",
|
||||
("en", "es"): "Helsinki-NLP/opus-mt-en-es",
|
||||
("es", "en"): "Helsinki-NLP/opus-mt-es-en",
|
||||
("en", "de"): "Helsinki-NLP/opus-mt-en-de",
|
||||
("de", "en"): "Helsinki-NLP/opus-mt-de-en",
|
||||
}
|
||||
return mapping.get(pair, "Helsinki-NLP/opus-mt-en-fr")
|
||||
|
||||
def translate(self, text: str, source_language: str, target_language: str) -> dict:
|
||||
if self._pipeline is None:
|
||||
self.load(source_language, target_language)
|
||||
with self._lock:
|
||||
results = self._pipeline(
|
||||
text, src_lang=source_language, tgt_lang=target_language
|
||||
)
|
||||
translated = results[0]["translation_text"] if results else ""
|
||||
return {"text": {source_language: text, target_language: translated}}
|
||||
@@ -1,107 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Mapping
|
||||
from urllib.parse import urlparse
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .config import SUPPORTED_FILE_EXTENSIONS, UPLOADS_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NoStdStreams:
|
||||
def __init__(self):
|
||||
self.devnull = open(os.devnull, "w")
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
||||
self._stdout.flush()
|
||||
self._stderr.flush()
|
||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
||||
self.devnull.close()
|
||||
|
||||
|
||||
def ensure_dirs():
|
||||
UPLOADS_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> str:
|
||||
url_path = urlparse(url).path
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return ext
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return "mp3"
|
||||
if "audio/wav" in content_type:
|
||||
return "wav"
|
||||
if "audio/mp4" in content_type:
|
||||
return "mp4"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_uploads(audio_file_url: str) -> tuple[Path, str]:
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path: Path = UPLOADS_PATH / unique_filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return file_path, audio_suffix
|
||||
|
||||
|
||||
@contextmanager
|
||||
def download_audio_file(audio_file_url: str):
|
||||
"""Download an audio file to UPLOADS_PATH and remove it after use.
|
||||
|
||||
Yields (file_path: Path, audio_suffix: str).
|
||||
"""
|
||||
file_path, audio_suffix = download_audio_to_uploads(audio_file_url)
|
||||
try:
|
||||
yield file_path, audio_suffix
|
||||
finally:
|
||||
try:
|
||||
file_path.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
logger.error("Error deleting temporary file %s: %s", file_path, e)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cleanup_uploaded_files(file_paths: list[Path]):
|
||||
"""Ensure provided file paths are removed after use.
|
||||
|
||||
The provided list can be populated inside the context; all present entries
|
||||
at exit will be deleted.
|
||||
"""
|
||||
try:
|
||||
yield file_paths
|
||||
finally:
|
||||
for path in list(file_paths):
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
logger.error("Error deleting temporary file %s: %s", path, e)
|
||||
@@ -1,10 +0,0 @@
|
||||
services:
|
||||
reflector_gpu:
|
||||
build:
|
||||
context: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./cache:/root/.cache
|
||||
@@ -1,3 +0,0 @@
|
||||
from app.factory import create_app
|
||||
|
||||
app = create_app()
|
||||
@@ -1,19 +0,0 @@
|
||||
[project]
|
||||
name = "reflector-gpu"
|
||||
version = "0.1.0"
|
||||
description = "Self-hosted GPU service for speech transcription, diarization, and translation via FastAPI."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi[standard]>=0.116.1",
|
||||
"uvicorn[standard]>=0.30.0",
|
||||
"torch>=2.3.0",
|
||||
"faster-whisper>=1.1.0",
|
||||
"librosa==0.10.1",
|
||||
"numpy<2",
|
||||
"silero-vad==5.1.0",
|
||||
"transformers>=4.35.0",
|
||||
"sentencepiece",
|
||||
"pyannote.audio==3.1.0",
|
||||
"torchaudio>=2.3.0",
|
||||
]
|
||||
@@ -1,17 +0,0 @@
|
||||
#!/bin/sh
|
||||
set -e
|
||||
|
||||
export PATH="/root/.local/bin:$PATH"
|
||||
cd /app
|
||||
|
||||
# Install Python dependencies at runtime (first run or when FORCE_SYNC=1)
|
||||
if [ ! -d "/app/.venv" ] || [ "$FORCE_SYNC" = "1" ]; then
|
||||
echo "[startup] Installing Python dependencies with uv..."
|
||||
uv sync --compile-bytecode --locked
|
||||
else
|
||||
echo "[startup] Using existing virtual environment at /app/.venv"
|
||||
fi
|
||||
|
||||
exec uv run uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
|
||||
|
||||
3013
gpu/self_hosted/uv.lock
generated
3013
gpu/self_hosted/uv.lock
generated
File diff suppressed because it is too large
Load Diff
3
server/.gitignore
vendored
3
server/.gitignore
vendored
@@ -176,8 +176,7 @@ artefacts/
|
||||
audio_*.wav
|
||||
|
||||
# ignore local database
|
||||
*.sqlite3
|
||||
*.db
|
||||
reflector.sqlite3
|
||||
data/
|
||||
|
||||
dump.rdb
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
UV_LINK_MODE=copy
|
||||
|
||||
# builder install base dependencies
|
||||
WORKDIR /tmp
|
||||
@@ -14,8 +13,8 @@ ENV PATH="/root/.local/bin/:$PATH"
|
||||
# install application dependencies
|
||||
RUN mkdir -p /app
|
||||
WORKDIR /app
|
||||
COPY pyproject.toml uv.lock README.md /app/
|
||||
RUN uv sync --compile-bytecode --locked
|
||||
COPY pyproject.toml uv.lock /app/
|
||||
RUN touch README.md && env uv sync --compile-bytecode --locked
|
||||
|
||||
# pre-download nltk packages
|
||||
RUN uv run python -c "import nltk; nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
|
||||
@@ -27,15 +26,4 @@ COPY migrations /app/migrations
|
||||
COPY reflector /app/reflector
|
||||
WORKDIR /app
|
||||
|
||||
# Create symlink for libgomp if it doesn't exist (for ARM64 compatibility)
|
||||
RUN if [ "$(uname -m)" = "aarch64" ] && [ ! -f /usr/lib/libgomp.so.1 ]; then \
|
||||
LIBGOMP_PATH=$(find /app/.venv/lib -path "*/torch.libs/libgomp*.so.*" 2>/dev/null | head -n1); \
|
||||
if [ -n "$LIBGOMP_PATH" ]; then \
|
||||
ln -sf "$LIBGOMP_PATH" /usr/lib/libgomp.so.1; \
|
||||
fi \
|
||||
fi
|
||||
|
||||
# Pre-check just to make sure the image will not fail
|
||||
RUN uv run python -c "import silero_vad.model"
|
||||
|
||||
CMD ["./runserver.sh"]
|
||||
|
||||
@@ -40,5 +40,3 @@ uv run python -c "from reflector.pipelines.main_live_pipeline import task_pipeli
|
||||
```bash
|
||||
uv run python -c "from reflector.pipelines.main_live_pipeline import pipeline_post; pipeline_post(transcript_id='TRANSCRIPT_ID')"
|
||||
```
|
||||
|
||||
.
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
# Data Retention and Cleanup
|
||||
|
||||
## Overview
|
||||
|
||||
For public instances of Reflector, a data retention policy is automatically enforced to delete anonymous user data after a configurable period (default: 7 days). This ensures compliance with privacy expectations and prevents unbounded storage growth.
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
- `PUBLIC_MODE` (bool): Must be set to `true` to enable automatic cleanup
|
||||
- `PUBLIC_DATA_RETENTION_DAYS` (int): Number of days to retain anonymous data (default: 7)
|
||||
|
||||
### What Gets Deleted
|
||||
|
||||
When data reaches the retention period, the following items are automatically removed:
|
||||
|
||||
1. **Transcripts** from anonymous users (where `user_id` is NULL):
|
||||
- Database records
|
||||
- Local files (audio.wav, audio.mp3, audio.json waveform)
|
||||
- Storage files (cloud storage if configured)
|
||||
|
||||
## Automatic Cleanup
|
||||
|
||||
### Celery Beat Schedule
|
||||
|
||||
When `PUBLIC_MODE=true`, a Celery beat task runs daily at 3 AM to clean up old data:
|
||||
|
||||
```python
|
||||
# Automatically scheduled when PUBLIC_MODE=true
|
||||
"cleanup_old_public_data": {
|
||||
"task": "reflector.worker.cleanup.cleanup_old_public_data",
|
||||
"schedule": crontab(hour=3, minute=0), # Daily at 3 AM
|
||||
}
|
||||
```
|
||||
|
||||
### Running the Worker
|
||||
|
||||
Ensure both Celery worker and beat scheduler are running:
|
||||
|
||||
```bash
|
||||
# Start Celery worker
|
||||
uv run celery -A reflector.worker.app worker --loglevel=info
|
||||
|
||||
# Start Celery beat scheduler (in another terminal)
|
||||
uv run celery -A reflector.worker.app beat
|
||||
```
|
||||
|
||||
## Manual Cleanup
|
||||
|
||||
For testing or manual intervention, use the cleanup tool:
|
||||
|
||||
```bash
|
||||
# Delete data older than 7 days (default)
|
||||
uv run python -m reflector.tools.cleanup_old_data
|
||||
|
||||
# Delete data older than 30 days
|
||||
uv run python -m reflector.tools.cleanup_old_data --days 30
|
||||
```
|
||||
|
||||
Note: The manual tool uses the same implementation as the Celery worker task to ensure consistency.
|
||||
|
||||
## Important Notes
|
||||
|
||||
1. **User Data Deletion**: Only anonymous data (where `user_id` is NULL) is deleted. Authenticated user data is preserved.
|
||||
|
||||
2. **Storage Cleanup**: The system properly cleans up both local files and cloud storage when configured.
|
||||
|
||||
3. **Error Handling**: If individual deletions fail, the cleanup continues and logs errors. Failed deletions are reported in the task output.
|
||||
|
||||
4. **Public Instance Only**: The automatic cleanup task only runs when `PUBLIC_MODE=true` to prevent accidental data loss in private deployments.
|
||||
|
||||
## Testing
|
||||
|
||||
Run the cleanup tests:
|
||||
|
||||
```bash
|
||||
uv run pytest tests/test_cleanup.py -v
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
Check Celery logs for cleanup task execution:
|
||||
|
||||
```bash
|
||||
# Look for cleanup task logs
|
||||
grep "cleanup_old_public_data" celery.log
|
||||
grep "Starting cleanup of old public data" celery.log
|
||||
```
|
||||
|
||||
Task statistics are logged after each run:
|
||||
- Number of transcripts deleted
|
||||
- Number of meetings deleted
|
||||
- Number of orphaned recordings deleted
|
||||
- Any errors encountered
|
||||
@@ -1,194 +0,0 @@
|
||||
## Reflector GPU Transcription API (Specification)
|
||||
|
||||
This document defines the Reflector GPU transcription API that all implementations must adhere to. Current implementations include NVIDIA Parakeet (NeMo) and Whisper (faster-whisper), both deployed on Modal.com. The API surface and response shapes are OpenAI/Whisper-compatible, so clients can switch implementations by changing only the base URL.
|
||||
|
||||
### Base URL and Authentication
|
||||
|
||||
- Example base URLs (Modal web endpoints):
|
||||
|
||||
- Parakeet: `https://<account>--reflector-transcriber-parakeet-web.modal.run`
|
||||
- Whisper: `https://<account>--reflector-transcriber-web.modal.run`
|
||||
|
||||
- All endpoints are served under `/v1` and require a Bearer token:
|
||||
|
||||
```
|
||||
Authorization: Bearer <REFLECTOR_GPU_APIKEY>
|
||||
```
|
||||
|
||||
Note: To switch implementations, deploy the desired variant and point `TRANSCRIPT_URL` to its base URL. The API is identical.
|
||||
|
||||
### Supported file types
|
||||
|
||||
`mp3, mp4, mpeg, mpga, m4a, wav, webm`
|
||||
|
||||
### Models and languages
|
||||
|
||||
- Parakeet (NVIDIA NeMo): default `nvidia/parakeet-tdt-0.6b-v2`
|
||||
- Language support: only `en`. Other languages return HTTP 400.
|
||||
- Whisper (faster-whisper): default `large-v2` (or deployment-specific)
|
||||
- Language support: multilingual (per Whisper model capabilities).
|
||||
|
||||
Note: The `model` parameter is accepted by all implementations for interface parity. Some backends may treat it as informational.
|
||||
|
||||
### Endpoints
|
||||
|
||||
#### POST /v1/audio/transcriptions
|
||||
|
||||
Transcribe one or more uploaded audio files.
|
||||
|
||||
Request: multipart/form-data
|
||||
|
||||
- `file` (File) — optional. Single file to transcribe.
|
||||
- `files` (File[]) — optional. One or more files to transcribe.
|
||||
- `model` (string) — optional. Defaults to the implementation-specific model (see above).
|
||||
- `language` (string) — optional, defaults to `en`.
|
||||
- Parakeet: only `en` is accepted; other values return HTTP 400
|
||||
- Whisper: model-dependent; typically multilingual
|
||||
- `batch` (boolean) — optional, defaults to `false`.
|
||||
|
||||
Notes:
|
||||
|
||||
- Provide either `file` or `files`, not both. If neither is provided, HTTP 400.
|
||||
- `batch` requires `files`; using `batch=true` without `files` returns HTTP 400.
|
||||
- Response shape for multiple files is the same regardless of `batch`.
|
||||
- Files sent to this endpoint are processed in a single pass (no VAD/chunking). This is intended for short clips (roughly ≤ 30s; depends on GPU memory/model). For longer audio, prefer `/v1/audio/transcriptions-from-url` which supports VAD-based chunking.
|
||||
|
||||
Responses
|
||||
|
||||
Single file response:
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "transcribed text",
|
||||
"words": [
|
||||
{ "word": "hello", "start": 0.0, "end": 0.5 },
|
||||
{ "word": "world", "start": 0.5, "end": 1.0 }
|
||||
],
|
||||
"filename": "audio.mp3"
|
||||
}
|
||||
```
|
||||
|
||||
Multiple files response:
|
||||
|
||||
```json
|
||||
{
|
||||
"results": [
|
||||
{"filename": "a1.mp3", "text": "...", "words": [...]},
|
||||
{"filename": "a2.mp3", "text": "...", "words": [...]}]
|
||||
}
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- Word objects always include keys: `word`, `start`, `end`.
|
||||
- Some implementations may include a trailing space in `word` to match Whisper tokenization behavior; clients should trim if needed.
|
||||
|
||||
Example curl (single file):
|
||||
|
||||
```bash
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
|
||||
-F "file=@/path/to/audio.mp3" \
|
||||
-F "language=en" \
|
||||
"$BASE_URL/v1/audio/transcriptions"
|
||||
```
|
||||
|
||||
Example curl (multiple files, batch):
|
||||
|
||||
```bash
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
|
||||
-F "files=@/path/a1.mp3" -F "files=@/path/a2.mp3" \
|
||||
-F "batch=true" -F "language=en" \
|
||||
"$BASE_URL/v1/audio/transcriptions"
|
||||
```
|
||||
|
||||
#### POST /v1/audio/transcriptions-from-url
|
||||
|
||||
Transcribe a single remote audio file by URL.
|
||||
|
||||
Request: application/json
|
||||
|
||||
Body parameters:
|
||||
|
||||
- `audio_file_url` (string) — required. URL of the audio file to transcribe.
|
||||
- `model` (string) — optional. Defaults to the implementation-specific model (see above).
|
||||
- `language` (string) — optional, defaults to `en`. Parakeet only accepts `en`.
|
||||
- `timestamp_offset` (number) — optional, defaults to `0.0`. Added to each word's `start`/`end` in the response.
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_file_url": "https://example.com/audio.mp3",
|
||||
"model": "nvidia/parakeet-tdt-0.6b-v2",
|
||||
"language": "en",
|
||||
"timestamp_offset": 0.0
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "transcribed text",
|
||||
"words": [
|
||||
{ "word": "hello", "start": 10.0, "end": 10.5 },
|
||||
{ "word": "world", "start": 10.5, "end": 11.0 }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- `timestamp_offset` is added to each word’s `start`/`end` in the response.
|
||||
- Implementations may perform VAD-based chunking and batching for long-form audio; word timings are adjusted accordingly.
|
||||
|
||||
Example curl:
|
||||
|
||||
```bash
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"audio_file_url": "https://example.com/audio.mp3",
|
||||
"language": "en",
|
||||
"timestamp_offset": 0
|
||||
}' \
|
||||
"$BASE_URL/v1/audio/transcriptions-from-url"
|
||||
```
|
||||
|
||||
### Error handling
|
||||
|
||||
- 400 Bad Request
|
||||
- Parakeet: `language` other than `en`
|
||||
- Missing required parameters (`file`/`files` for upload; `audio_file_url` for URL endpoint)
|
||||
- Unsupported file extension
|
||||
- 401 Unauthorized
|
||||
- Missing or invalid Bearer token
|
||||
- 404 Not Found
|
||||
- `audio_file_url` does not exist
|
||||
|
||||
### Implementation details
|
||||
|
||||
- GPUs: A10G for small-file/live, L40S for large-file URL transcription (subject to deployment)
|
||||
- VAD chunking and segment batching; word timings adjusted and overlapping ends constrained
|
||||
- Pads very short segments (< 0.5s) to avoid model crashes on some backends
|
||||
|
||||
### Server configuration (Reflector API)
|
||||
|
||||
Set the Reflector server to use the Modal backend and point `TRANSCRIPT_URL` to your chosen deployment:
|
||||
|
||||
```
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://<account>--reflector-transcriber-parakeet-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=<REFLECTOR_GPU_APIKEY>
|
||||
```
|
||||
|
||||
### Conformance tests
|
||||
|
||||
Use the pytest-based conformance tests to validate any new implementation (including self-hosted) against this spec:
|
||||
|
||||
```
|
||||
TRANSCRIPT_URL=https://<your-deployment-base> \
|
||||
TRANSCRIPT_MODAL_API_KEY=your-api-key \
|
||||
uv run -m pytest -m model_api --no-cov server/tests/test_model_api_transcript.py
|
||||
```
|
||||
@@ -1,212 +0,0 @@
|
||||
# Reflector Webhook Documentation
|
||||
|
||||
## Overview
|
||||
|
||||
Reflector supports webhook notifications to notify external systems when transcript processing is completed. Webhooks can be configured per room and are triggered automatically after a transcript is successfully processed.
|
||||
|
||||
## Configuration
|
||||
|
||||
Webhooks are configured at the room level with two fields:
|
||||
- `webhook_url`: The HTTPS endpoint to receive webhook notifications
|
||||
- `webhook_secret`: Optional secret key for HMAC signature verification (auto-generated if not provided)
|
||||
|
||||
## Events
|
||||
|
||||
### `transcript.completed`
|
||||
|
||||
Triggered when a transcript has been fully processed, including transcription, diarization, summarization, and topic detection.
|
||||
|
||||
### `test`
|
||||
|
||||
A test event that can be triggered manually to verify webhook configuration.
|
||||
|
||||
## Webhook Request Format
|
||||
|
||||
### Headers
|
||||
|
||||
All webhook requests include the following headers:
|
||||
|
||||
| Header | Description | Example |
|
||||
|--------|-------------|---------|
|
||||
| `Content-Type` | Always `application/json` | `application/json` |
|
||||
| `User-Agent` | Identifies Reflector as the source | `Reflector-Webhook/1.0` |
|
||||
| `X-Webhook-Event` | The event type | `transcript.completed` or `test` |
|
||||
| `X-Webhook-Retry` | Current retry attempt number | `0`, `1`, `2`... |
|
||||
| `X-Webhook-Signature` | HMAC signature (if secret configured) | `t=1735306800,v1=abc123...` |
|
||||
|
||||
### Signature Verification
|
||||
|
||||
If a webhook secret is configured, Reflector includes an HMAC-SHA256 signature in the `X-Webhook-Signature` header to verify the webhook authenticity.
|
||||
|
||||
The signature format is: `t={timestamp},v1={signature}`
|
||||
|
||||
To verify the signature:
|
||||
1. Extract the timestamp and signature from the header
|
||||
2. Create the signed payload: `{timestamp}.{request_body}`
|
||||
3. Compute HMAC-SHA256 of the signed payload using your webhook secret
|
||||
4. Compare the computed signature with the received signature
|
||||
|
||||
Example verification (Python):
|
||||
```python
|
||||
import hmac
|
||||
import hashlib
|
||||
|
||||
def verify_webhook_signature(payload: bytes, signature_header: str, secret: str) -> bool:
|
||||
# Parse header: "t=1735306800,v1=abc123..."
|
||||
parts = dict(part.split("=") for part in signature_header.split(","))
|
||||
timestamp = parts["t"]
|
||||
received_signature = parts["v1"]
|
||||
|
||||
# Create signed payload
|
||||
signed_payload = f"{timestamp}.{payload.decode('utf-8')}"
|
||||
|
||||
# Compute expected signature
|
||||
expected_signature = hmac.new(
|
||||
secret.encode("utf-8"),
|
||||
signed_payload.encode("utf-8"),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Compare signatures
|
||||
return hmac.compare_digest(expected_signature, received_signature)
|
||||
```
|
||||
|
||||
## Event Payloads
|
||||
|
||||
### `transcript.completed` Event
|
||||
|
||||
This event includes a convenient URL for accessing the transcript:
|
||||
- `frontend_url`: Direct link to view the transcript in the web interface
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "transcript.completed",
|
||||
"event_id": "transcript.completed-abc-123-def-456",
|
||||
"timestamp": "2025-08-27T12:34:56.789012Z",
|
||||
"transcript": {
|
||||
"id": "abc-123-def-456",
|
||||
"room_id": "room-789",
|
||||
"created_at": "2025-08-27T12:00:00Z",
|
||||
"duration": 1800.5,
|
||||
"title": "Q3 Product Planning Meeting",
|
||||
"short_summary": "Team discussed Q3 product roadmap, prioritizing mobile app features and API improvements.",
|
||||
"long_summary": "The product team met to finalize the Q3 roadmap. Key decisions included...",
|
||||
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\n<v Speaker 1>Welcome everyone to today's meeting...",
|
||||
"topics": [
|
||||
{
|
||||
"title": "Introduction and Agenda",
|
||||
"summary": "Meeting kickoff with agenda review",
|
||||
"timestamp": 0.0,
|
||||
"duration": 120.0,
|
||||
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\n<v Speaker 1>Welcome everyone..."
|
||||
},
|
||||
{
|
||||
"title": "Mobile App Features Discussion",
|
||||
"summary": "Team reviewed proposed mobile app features for Q3",
|
||||
"timestamp": 120.0,
|
||||
"duration": 600.0,
|
||||
"webvtt": "WEBVTT\n\n00:02:00.000 --> 00:02:10.000\n<v Speaker 2>Let's talk about the mobile app..."
|
||||
}
|
||||
],
|
||||
"participants": [
|
||||
{
|
||||
"id": "participant-1",
|
||||
"name": "John Doe",
|
||||
"speaker": "Speaker 1"
|
||||
},
|
||||
{
|
||||
"id": "participant-2",
|
||||
"name": "Jane Smith",
|
||||
"speaker": "Speaker 2"
|
||||
}
|
||||
],
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"status": "completed",
|
||||
"frontend_url": "https://app.reflector.com/transcripts/abc-123-def-456"
|
||||
},
|
||||
"room": {
|
||||
"id": "room-789",
|
||||
"name": "Product Team Room"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### `test` Event
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "test",
|
||||
"event_id": "test.2025-08-27T12:34:56.789012Z",
|
||||
"timestamp": "2025-08-27T12:34:56.789012Z",
|
||||
"message": "This is a test webhook from Reflector",
|
||||
"room": {
|
||||
"id": "room-789",
|
||||
"name": "Product Team Room"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Retry Policy
|
||||
|
||||
Webhooks are delivered with automatic retry logic to handle transient failures. When a webhook delivery fails due to server errors or network issues, Reflector will automatically retry the delivery multiple times over an extended period.
|
||||
|
||||
### Retry Mechanism
|
||||
|
||||
Reflector implements an exponential backoff strategy for webhook retries:
|
||||
|
||||
- **Initial retry delay**: 60 seconds after the first failure
|
||||
- **Exponential backoff**: Each subsequent retry waits approximately twice as long as the previous one
|
||||
- **Maximum retry interval**: 1 hour (backoff is capped at this duration)
|
||||
- **Maximum retry attempts**: 30 attempts total
|
||||
- **Total retry duration**: Retries continue for approximately 24 hours
|
||||
|
||||
### How Retries Work
|
||||
|
||||
When a webhook fails, Reflector will:
|
||||
1. Wait 60 seconds, then retry (attempt #1)
|
||||
2. If it fails again, wait ~2 minutes, then retry (attempt #2)
|
||||
3. Continue doubling the wait time up to a maximum of 1 hour between attempts
|
||||
4. Keep retrying at 1-hour intervals until successful or 30 attempts are exhausted
|
||||
|
||||
The `X-Webhook-Retry` header indicates the current retry attempt number (0 for the initial attempt, 1 for first retry, etc.), allowing your endpoint to track retry attempts.
|
||||
|
||||
### Retry Behavior by HTTP Status Code
|
||||
|
||||
| Status Code | Behavior |
|
||||
|-------------|----------|
|
||||
| 2xx (Success) | No retry, webhook marked as delivered |
|
||||
| 4xx (Client Error) | No retry, request is considered permanently failed |
|
||||
| 5xx (Server Error) | Automatic retry with exponential backoff |
|
||||
| Network/Timeout Error | Automatic retry with exponential backoff |
|
||||
|
||||
**Important Notes:**
|
||||
- Webhooks timeout after 30 seconds. If your endpoint takes longer to respond, it will be considered a timeout error and retried.
|
||||
- During the retry period (~24 hours), you may receive the same webhook multiple times if your endpoint experiences intermittent failures.
|
||||
- There is no mechanism to manually retry failed webhooks after the retry period expires.
|
||||
|
||||
## Testing Webhooks
|
||||
|
||||
You can test your webhook configuration before processing transcripts:
|
||||
|
||||
```http
|
||||
POST /v1/rooms/{room_id}/webhook/test
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"status_code": 200,
|
||||
"message": "Webhook test successful",
|
||||
"response_preview": "OK"
|
||||
}
|
||||
```
|
||||
|
||||
Or in case of failure:
|
||||
```json
|
||||
{
|
||||
"success": false,
|
||||
"error": "Webhook request timed out (10 seconds)"
|
||||
}
|
||||
```
|
||||
86
server/gpu/modal_deployments/README.md
Normal file
86
server/gpu/modal_deployments/README.md
Normal file
@@ -0,0 +1,86 @@
|
||||
# Reflector GPU implementation - Transcription and LLM
|
||||
|
||||
This repository hold an API for the GPU implementation of the Reflector API service,
|
||||
and use [Modal.com](https://modal.com)
|
||||
|
||||
- `reflector_diarizer.py` - Diarization API
|
||||
- `reflector_transcriber.py` - Transcription API
|
||||
- `reflector_translator.py` - Translation API
|
||||
|
||||
## Modal.com deployment
|
||||
|
||||
Create a modal secret, and name it `reflector-gpu`.
|
||||
It should contain an `REFLECTOR_APIKEY` environment variable with a value.
|
||||
|
||||
The deployment is done using [Modal.com](https://modal.com) service.
|
||||
|
||||
```
|
||||
$ modal deploy reflector_transcriber.py
|
||||
...
|
||||
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run
|
||||
|
||||
$ modal deploy reflector_llm.py
|
||||
...
|
||||
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
|
||||
```
|
||||
|
||||
Then in your reflector api configuration `.env`, you can set these keys:
|
||||
|
||||
```
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
|
||||
DIARIZATION_BACKEND=modal
|
||||
DIARIZATION_URL=https://xxxx--reflector-diarizer-web.modal.run
|
||||
DIARIZATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
|
||||
TRANSLATION_BACKEND=modal
|
||||
TRANSLATION_URL=https://xxxx--reflector-translator-web.modal.run
|
||||
TRANSLATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
Authentication must be passed with the `Authorization` header, using the `bearer` scheme.
|
||||
|
||||
```
|
||||
Authorization: bearer <REFLECTOR_APIKEY>
|
||||
```
|
||||
|
||||
### LLM
|
||||
|
||||
`POST /llm`
|
||||
|
||||
**request**
|
||||
```
|
||||
{
|
||||
"prompt": "xxx"
|
||||
}
|
||||
```
|
||||
|
||||
**response**
|
||||
```
|
||||
{
|
||||
"text": "xxx completed"
|
||||
}
|
||||
```
|
||||
|
||||
### Transcription
|
||||
|
||||
`POST /transcribe`
|
||||
|
||||
**request** (multipart/form-data)
|
||||
|
||||
- `file` - audio file
|
||||
- `language` - language code (e.g. `en`)
|
||||
|
||||
**response**
|
||||
```
|
||||
{
|
||||
"text": "xxx",
|
||||
"words": [
|
||||
{"text": "xxx", "start": 0.0, "end": 1.0}
|
||||
]
|
||||
}
|
||||
```
|
||||
187
server/gpu/modal_deployments/reflector_diarizer.py
Normal file
187
server/gpu/modal_deployments/reflector_diarizer.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Reflector GPU backend - diarizer
|
||||
===================================
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import modal.gpu
|
||||
from modal import App, Image, Secret, asgi_app, enter, method
|
||||
from pydantic import BaseModel
|
||||
|
||||
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
|
||||
MODEL_DIR = "/root/diarization_models"
|
||||
app = App(name="reflector-diarizer")
|
||||
|
||||
|
||||
def migrate_cache_llm():
|
||||
"""
|
||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
||||
Migrating your old cache. This is a one-time only operation. You can
|
||||
interrupt this and resume the migration later on by calling
|
||||
`transformers.utils.move_cache()`.
|
||||
"""
|
||||
from transformers.utils.hub import move_cache
|
||||
|
||||
print("Moving LLM cache")
|
||||
move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
|
||||
print("LLM cache moved")
|
||||
|
||||
|
||||
def download_pyannote_audio():
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
Pipeline.from_pretrained(
|
||||
PYANNOTE_MODEL_NAME,
|
||||
cache_dir=MODEL_DIR,
|
||||
use_auth_token=os.environ["HF_TOKEN"],
|
||||
)
|
||||
|
||||
|
||||
diarizer_image = (
|
||||
Image.debian_slim(python_version="3.10.8")
|
||||
.pip_install(
|
||||
"pyannote.audio==3.1.0",
|
||||
"requests",
|
||||
"onnx",
|
||||
"torchaudio",
|
||||
"onnxruntime-gpu",
|
||||
"torch==2.0.0",
|
||||
"transformers==4.34.0",
|
||||
"sentencepiece",
|
||||
"protobuf",
|
||||
"numpy",
|
||||
"huggingface_hub",
|
||||
"hf-transfer",
|
||||
)
|
||||
.run_function(
|
||||
download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")]
|
||||
)
|
||||
.run_function(migrate_cache_llm)
|
||||
.env(
|
||||
{
|
||||
"LD_LIBRARY_PATH": (
|
||||
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
|
||||
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu=modal.gpu.A100(size="40GB"),
|
||||
timeout=60 * 30,
|
||||
scaledown_window=60,
|
||||
allow_concurrent_inputs=1,
|
||||
image=diarizer_image,
|
||||
)
|
||||
class Diarizer:
|
||||
@enter()
|
||||
def enter(self):
|
||||
import torch
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.diarization_pipeline = Pipeline.from_pretrained(
|
||||
PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR
|
||||
)
|
||||
self.diarization_pipeline.to(torch.device(self.device))
|
||||
|
||||
@method()
|
||||
def diarize(self, audio_data: str, audio_suffix: str, timestamp: float):
|
||||
import tempfile
|
||||
|
||||
import torchaudio
|
||||
|
||||
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
|
||||
fp.write(audio_data)
|
||||
|
||||
print("Diarizing audio")
|
||||
waveform, sample_rate = torchaudio.load(fp.name)
|
||||
diarization = self.diarization_pipeline(
|
||||
{"waveform": waveform, "sample_rate": sample_rate}
|
||||
)
|
||||
|
||||
words = []
|
||||
for diarization_segment, _, speaker in diarization.itertracks(
|
||||
yield_label=True
|
||||
):
|
||||
words.append(
|
||||
{
|
||||
"start": round(timestamp + diarization_segment.start, 3),
|
||||
"end": round(timestamp + diarization_segment.end, 3),
|
||||
"speaker": int(speaker[-2:]),
|
||||
}
|
||||
)
|
||||
print("Diarization complete")
|
||||
return {"diarization": words}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Web API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.function(
|
||||
timeout=60 * 10,
|
||||
scaledown_window=60 * 3,
|
||||
allow_concurrent_inputs=40,
|
||||
secrets=[
|
||||
Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
image=diarizer_image,
|
||||
)
|
||||
@asgi_app()
|
||||
def web():
|
||||
import requests
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
diarizerstub = Diarizer()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
def validate_audio_file(audio_file_url: str):
|
||||
# Check if the audio file exists
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail="The audio file does not exist.",
|
||||
)
|
||||
|
||||
class DiarizationResponse(BaseModel):
|
||||
result: dict
|
||||
|
||||
@app.post(
|
||||
"/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]
|
||||
)
|
||||
def diarize(
|
||||
audio_file_url: str, timestamp: float = 0.0
|
||||
) -> HTTPException | DiarizationResponse:
|
||||
# Currently the uploaded files are in mp3 format
|
||||
audio_suffix = "mp3"
|
||||
|
||||
print("Downloading audio file")
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
print("Audio file downloaded successfully")
|
||||
|
||||
func = diarizerstub.diarize.spawn(
|
||||
audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
|
||||
return app
|
||||
161
server/gpu/modal_deployments/reflector_transcriber.py
Normal file
161
server/gpu/modal_deployments/reflector_transcriber.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
import modal
|
||||
from pydantic import BaseModel
|
||||
|
||||
MODELS_DIR = "/models"
|
||||
|
||||
MODEL_NAME = "large-v2"
|
||||
MODEL_COMPUTE_TYPE: str = "float16"
|
||||
MODEL_NUM_WORKERS: int = 1
|
||||
|
||||
MINUTES = 60 # seconds
|
||||
|
||||
volume = modal.Volume.from_name("models", create_if_missing=True)
|
||||
|
||||
app = modal.App("reflector-transcriber")
|
||||
|
||||
|
||||
def download_model():
|
||||
from faster_whisper import download_model
|
||||
|
||||
volume.reload()
|
||||
|
||||
download_model(MODEL_NAME, cache_dir=MODELS_DIR)
|
||||
|
||||
volume.commit()
|
||||
|
||||
|
||||
image = (
|
||||
modal.Image.debian_slim(python_version="3.12")
|
||||
.pip_install(
|
||||
"huggingface_hub==0.27.1",
|
||||
"hf-transfer==0.1.9",
|
||||
"torch==2.5.1",
|
||||
"faster-whisper==1.1.1",
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||
"LD_LIBRARY_PATH": (
|
||||
"/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/:"
|
||||
"/opt/conda/lib/python3.12/site-packages/nvidia/cublas/lib/"
|
||||
),
|
||||
}
|
||||
)
|
||||
.run_function(download_model, volumes={MODELS_DIR: volume})
|
||||
)
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=5 * MINUTES,
|
||||
scaledown_window=5 * MINUTES,
|
||||
allow_concurrent_inputs=6,
|
||||
image=image,
|
||||
volumes={MODELS_DIR: volume},
|
||||
)
|
||||
class Transcriber:
|
||||
@modal.enter()
|
||||
def enter(self):
|
||||
import faster_whisper
|
||||
import torch
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=MODEL_COMPUTE_TYPE,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=MODELS_DIR,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
audio_data: str,
|
||||
audio_suffix: str,
|
||||
language: str,
|
||||
):
|
||||
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
|
||||
fp.write(audio_data)
|
||||
|
||||
with self.lock:
|
||||
segments, _ = self.model.transcribe(
|
||||
fp.name,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(segment.text for segment in segments)
|
||||
words = [
|
||||
{"word": word.word, "start": word.start, "end": word.end}
|
||||
for segment in segments
|
||||
for word in segment.words
|
||||
]
|
||||
|
||||
return {"text": text, "words": words}
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60,
|
||||
timeout=60,
|
||||
allow_concurrent_inputs=40,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={MODELS_DIR: volume},
|
||||
)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from typing_extensions import Annotated
|
||||
|
||||
transcriber = Transcriber()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
supported_file_types = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class TranscriptResponse(BaseModel):
|
||||
result: dict
|
||||
|
||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe(
|
||||
file: UploadFile,
|
||||
model: str = "whisper-1",
|
||||
language: Annotated[str, Body(...)] = "en",
|
||||
) -> TranscriptResponse:
|
||||
audio_data = file.file.read()
|
||||
audio_suffix = file.filename.split(".")[-1]
|
||||
assert audio_suffix in supported_file_types
|
||||
|
||||
func = transcriber.transcribe_segment.spawn(
|
||||
audio_data=audio_data,
|
||||
audio_suffix=audio_suffix,
|
||||
language=language,
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
|
||||
return app
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Add webhook fields to rooms
|
||||
|
||||
Revision ID: 0194f65cd6d3
|
||||
Revises: 5a8907fd1d78
|
||||
Create Date: 2025-08-27 09:03:19.610995
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0194f65cd6d3"
|
||||
down_revision: Union[str, None] = "5a8907fd1d78"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("webhook_url", sa.String(), nullable=True))
|
||||
batch_op.add_column(sa.Column("webhook_secret", sa.String(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.drop_column("webhook_secret")
|
||||
batch_op.drop_column("webhook_url")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,64 +0,0 @@
|
||||
"""add_long_summary_to_search_vector
|
||||
|
||||
Revision ID: 0ab2d7ffaa16
|
||||
Revises: b1c33bd09963
|
||||
Create Date: 2025-08-15 13:27:52.680211
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0ab2d7ffaa16"
|
||||
down_revision: Union[str, None] = "b1c33bd09963"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing search vector column and index
|
||||
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
|
||||
op.drop_column("transcript", "search_vector_en")
|
||||
|
||||
# Recreate the search vector column with long_summary included
|
||||
op.execute("""
|
||||
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
||||
GENERATED ALWAYS AS (
|
||||
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
||||
setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') ||
|
||||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')
|
||||
) STORED
|
||||
""")
|
||||
|
||||
# Recreate the GIN index for the search vector
|
||||
op.create_index(
|
||||
"idx_transcript_search_vector_en",
|
||||
"transcript",
|
||||
["search_vector_en"],
|
||||
postgresql_using="gin",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the updated search vector column and index
|
||||
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
|
||||
op.drop_column("transcript", "search_vector_en")
|
||||
|
||||
# Recreate the original search vector column without long_summary
|
||||
op.execute("""
|
||||
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
||||
GENERATED ALWAYS AS (
|
||||
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
||||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
|
||||
) STORED
|
||||
""")
|
||||
|
||||
# Recreate the GIN index for the search vector
|
||||
op.create_index(
|
||||
"idx_transcript_search_vector_en",
|
||||
"transcript",
|
||||
["search_vector_en"],
|
||||
postgresql_using="gin",
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
"""remove user_id from meeting table
|
||||
|
||||
Revision ID: 0ce521cda2ee
|
||||
Revises: 6dec9fb5b46c
|
||||
Create Date: 2025-09-10 12:40:55.688899
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0ce521cda2ee"
|
||||
down_revision: Union[str, None] = "6dec9fb5b46c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.drop_column("user_id")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True)
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,32 +0,0 @@
|
||||
"""clean up orphaned room_id references in meeting table
|
||||
|
||||
Revision ID: 2ae3db106d4e
|
||||
Revises: def1b5867d4c
|
||||
Create Date: 2025-09-11 10:35:15.759967
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "2ae3db106d4e"
|
||||
down_revision: Union[str, None] = "def1b5867d4c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Set room_id to NULL for meetings that reference non-existent rooms
|
||||
op.execute("""
|
||||
UPDATE meeting
|
||||
SET room_id = NULL
|
||||
WHERE room_id IS NOT NULL
|
||||
AND room_id NOT IN (SELECT id FROM room WHERE id IS NOT NULL)
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Cannot restore orphaned references - no operation needed
|
||||
pass
|
||||
@@ -1,50 +0,0 @@
|
||||
"""add cascade delete to meeting consent foreign key
|
||||
|
||||
Revision ID: 5a8907fd1d78
|
||||
Revises: 0ab2d7ffaa16
|
||||
Create Date: 2025-08-26 17:26:50.945491
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "5a8907fd1d78"
|
||||
down_revision: Union[str, None] = "0ab2d7ffaa16"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"), type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"),
|
||||
"meeting",
|
||||
["meeting_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"), type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"),
|
||||
"meeting",
|
||||
["meeting_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,7 +1,7 @@
|
||||
"""remove_one_active_meeting_per_room_constraint
|
||||
|
||||
Revision ID: 6025e9b2bef2
|
||||
Revises: 2ae3db106d4e
|
||||
Revises: 9f5c78d352d6
|
||||
Create Date: 2025-08-18 18:45:44.418392
|
||||
|
||||
"""
|
||||
@@ -13,7 +13,7 @@ from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6025e9b2bef2"
|
||||
down_revision: Union[str, None] = "2ae3db106d4e"
|
||||
down_revision: Union[str, None] = "9f5c78d352d6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""webhook url and secret null by default
|
||||
|
||||
|
||||
Revision ID: 61882a919591
|
||||
Revises: 0194f65cd6d3
|
||||
Create Date: 2025-08-29 11:46:36.738091
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "61882a919591"
|
||||
down_revision: Union[str, None] = "0194f65cd6d3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,38 +0,0 @@
|
||||
"""make meeting room_id required and add foreign key
|
||||
|
||||
Revision ID: 6dec9fb5b46c
|
||||
Revises: 61882a919591
|
||||
Create Date: 2025-09-10 10:47:06.006819
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6dec9fb5b46c"
|
||||
down_revision: Union[str, None] = "61882a919591"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.alter_column("room_id", existing_type=sa.VARCHAR(), nullable=False)
|
||||
batch_op.create_foreign_key(
|
||||
None, "room", ["room_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("meeting_room_id_fkey", type_="foreignkey")
|
||||
batch_op.alter_column("room_id", existing_type=sa.VARCHAR(), nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,41 +0,0 @@
|
||||
"""add_search_optimization_indexes
|
||||
|
||||
Revision ID: b1c33bd09963
|
||||
Revises: 9f5c78d352d6
|
||||
Create Date: 2025-08-14 17:26:02.117408
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b1c33bd09963"
|
||||
down_revision: Union[str, None] = "9f5c78d352d6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add indexes for actual search filtering patterns used in frontend
|
||||
# Based on /browse page filters: room_id and source_kind
|
||||
|
||||
# Index for room_id + created_at (for room-specific searches with date ordering)
|
||||
op.create_index(
|
||||
"idx_transcript_room_id_created_at",
|
||||
"transcript",
|
||||
["room_id", "created_at"],
|
||||
if_not_exists=True,
|
||||
)
|
||||
|
||||
# Index for source_kind alone (actively used filter in frontend)
|
||||
op.create_index(
|
||||
"idx_transcript_source_kind", "transcript", ["source_kind"], if_not_exists=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the indexes in reverse order
|
||||
op.drop_index("idx_transcript_source_kind", "transcript", if_exists=True)
|
||||
op.drop_index("idx_transcript_room_id_created_at", "transcript", if_exists=True)
|
||||
@@ -1,129 +0,0 @@
|
||||
"""add calendar
|
||||
|
||||
Revision ID: d8e204bbf615
|
||||
Revises: d4a1c446458c
|
||||
Create Date: 2025-09-10 19:56:22.295756
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d8e204bbf615"
|
||||
down_revision: Union[str, None] = "d4a1c446458c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"calendar_event",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("room_id", sa.String(), nullable=False),
|
||||
sa.Column("ics_uid", sa.Text(), nullable=False),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("attendees", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("location", sa.Text(), nullable=True),
|
||||
sa.Column("ics_raw_data", sa.Text(), nullable=True),
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"is_deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
||||
),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["room_id"],
|
||||
["room.id"],
|
||||
name="fk_calendar_event_room_id",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
|
||||
)
|
||||
with op.batch_alter_table("calendar_event", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
"idx_calendar_event_deleted",
|
||||
["is_deleted"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("NOT is_deleted"),
|
||||
)
|
||||
batch_op.create_index(
|
||||
"idx_calendar_event_room_start", ["room_id", "start_time"], unique=False
|
||||
)
|
||||
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("calendar_event_id", sa.String(), nullable=True))
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"calendar_metadata",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
batch_op.create_index(
|
||||
"idx_meeting_calendar_event", ["calendar_event_id"], unique=False
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_meeting_calendar_event_id",
|
||||
"calendar_event",
|
||||
["calendar_event_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("ics_url", sa.Text(), nullable=True))
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"ics_fetch_interval", sa.Integer(), server_default="300", nullable=True
|
||||
)
|
||||
)
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"ics_enabled",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
)
|
||||
)
|
||||
batch_op.add_column(
|
||||
sa.Column("ics_last_sync", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
batch_op.add_column(sa.Column("ics_last_etag", sa.Text(), nullable=True))
|
||||
batch_op.create_index("idx_room_ics_enabled", ["ics_enabled"], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.drop_index("idx_room_ics_enabled")
|
||||
batch_op.drop_column("ics_last_etag")
|
||||
batch_op.drop_column("ics_last_sync")
|
||||
batch_op.drop_column("ics_enabled")
|
||||
batch_op.drop_column("ics_fetch_interval")
|
||||
batch_op.drop_column("ics_url")
|
||||
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("fk_meeting_calendar_event_id", type_="foreignkey")
|
||||
batch_op.drop_index("idx_meeting_calendar_event")
|
||||
batch_op.drop_column("calendar_metadata")
|
||||
batch_op.drop_column("calendar_event_id")
|
||||
|
||||
with op.batch_alter_table("calendar_event", schema=None) as batch_op:
|
||||
batch_op.drop_index("idx_calendar_event_room_start")
|
||||
batch_op.drop_index(
|
||||
"idx_calendar_event_deleted", postgresql_where=sa.text("NOT is_deleted")
|
||||
)
|
||||
|
||||
op.drop_table("calendar_event")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,43 +0,0 @@
|
||||
"""remove_grace_period_fields
|
||||
|
||||
Revision ID: dc035ff72fd5
|
||||
Revises: d8e204bbf615
|
||||
Create Date: 2025-09-11 10:36:45.197588
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dc035ff72fd5"
|
||||
down_revision: Union[str, None] = "d8e204bbf615"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove grace period columns from meeting table
|
||||
op.drop_column("meeting", "last_participant_left_at")
|
||||
op.drop_column("meeting", "grace_period_minutes")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back grace period columns to meeting table
|
||||
op.add_column(
|
||||
"meeting",
|
||||
sa.Column(
|
||||
"last_participant_left_at", sa.DateTime(timezone=True), nullable=True
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"meeting",
|
||||
sa.Column(
|
||||
"grace_period_minutes",
|
||||
sa.Integer(),
|
||||
server_default=sa.text("15"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
@@ -1,34 +0,0 @@
|
||||
"""make meeting room_id nullable but keep foreign key
|
||||
|
||||
Revision ID: def1b5867d4c
|
||||
Revises: 0ce521cda2ee
|
||||
Create Date: 2025-09-11 09:42:18.697264
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "def1b5867d4c"
|
||||
down_revision: Union[str, None] = "0ce521cda2ee"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.alter_column("room_id", existing_type=sa.VARCHAR(), nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.alter_column("room_id", existing_type=sa.VARCHAR(), nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -12,6 +12,7 @@ dependencies = [
|
||||
"requests>=2.31.0",
|
||||
"aiortc>=1.5.0",
|
||||
"sortedcontainers>=2.4.0",
|
||||
"loguru>=0.7.0",
|
||||
"pydantic-settings>=2.0.2",
|
||||
"structlog>=23.1.0",
|
||||
"uvicorn[standard]>=0.23.1",
|
||||
@@ -26,10 +27,12 @@ dependencies = [
|
||||
"prometheus-fastapi-instrumentator>=6.1.0",
|
||||
"sentencepiece>=0.1.99",
|
||||
"protobuf>=4.24.3",
|
||||
"profanityfilter>=2.0.6",
|
||||
"celery>=5.3.4",
|
||||
"redis>=5.0.1",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"python-multipart>=0.0.6",
|
||||
"faster-whisper>=0.10.0",
|
||||
"transformers>=4.36.2",
|
||||
"jsonschema>=4.23.0",
|
||||
"openai>=1.59.7",
|
||||
@@ -55,7 +58,6 @@ tests = [
|
||||
"httpx-ws>=0.4.1",
|
||||
"pytest-httpx>=0.23.1",
|
||||
"pytest-celery>=0.0.0",
|
||||
"pytest-recording>=0.13.4",
|
||||
"pytest-docker>=3.2.3",
|
||||
"asgi-lifespan>=2.1.0",
|
||||
]
|
||||
@@ -66,15 +68,6 @@ evaluation = [
|
||||
"tqdm>=4.66.0",
|
||||
"pydantic>=2.1.1",
|
||||
]
|
||||
local = [
|
||||
"pyannote-audio>=3.3.2",
|
||||
"faster-whisper>=0.10.0",
|
||||
]
|
||||
silero-vad = [
|
||||
"silero-vad>=5.1.2",
|
||||
"torch>=2.8.0",
|
||||
"torchaudio>=2.8.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
default-groups = [
|
||||
@@ -82,21 +75,6 @@ default-groups = [
|
||||
"tests",
|
||||
"aws",
|
||||
"evaluation",
|
||||
"local",
|
||||
"silero-vad"
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cpu" },
|
||||
]
|
||||
|
||||
[build-system]
|
||||
@@ -117,9 +95,6 @@ DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_t
|
||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
markers = [
|
||||
"model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
@@ -130,7 +105,7 @@ select = [
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"reflector/processors/summary/summary_builder.py" = ["E501"]
|
||||
"gpu/modal_deployments/**.py" = ["PLC0415"]
|
||||
"gpu/**.py" = ["PLC0415"]
|
||||
"reflector/tools/**.py" = ["PLC0415"]
|
||||
"migrations/versions/**.py" = ["PLC0415"]
|
||||
"tests/**.py" = ["PLC0415"]
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
from reflector.db import get_database
|
||||
|
||||
|
||||
def asynctask(f):
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
async def run_with_db():
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
coro = run_with_db()
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
return loop.run_until_complete(coro)
|
||||
return asyncio.run(coro)
|
||||
|
||||
return wrapper
|
||||
@@ -67,8 +67,7 @@ def current_user(
|
||||
try:
|
||||
payload = jwtauth.verify_token(token)
|
||||
sub = payload["sub"]
|
||||
email = payload["email"]
|
||||
return UserInfo(sub=sub, email=email)
|
||||
return UserInfo(sub=sub)
|
||||
except JWTError as e:
|
||||
logger.error(f"JWT error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -15,7 +15,7 @@ calendar_events = sa.Table(
|
||||
sa.Column(
|
||||
"room_id",
|
||||
sa.String,
|
||||
sa.ForeignKey("room.id", ondelete="CASCADE", name="fk_calendar_event_room_id"),
|
||||
sa.ForeignKey("room.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("ics_uid", sa.Text, nullable=False),
|
||||
@@ -65,6 +65,7 @@ class CalendarEventController:
|
||||
start_after: datetime | None = None,
|
||||
end_before: datetime | None = None,
|
||||
) -> list[CalendarEvent]:
|
||||
"""Get calendar events for a room."""
|
||||
query = calendar_events.select().where(calendar_events.c.room_id == room_id)
|
||||
|
||||
if not include_deleted:
|
||||
@@ -82,9 +83,9 @@ class CalendarEventController:
|
||||
return [CalendarEvent(**result) for result in results]
|
||||
|
||||
async def get_upcoming(
|
||||
self, room_id: str, minutes_ahead: int = 120
|
||||
self, room_id: str, minutes_ahead: int = 30
|
||||
) -> list[CalendarEvent]:
|
||||
"""Get upcoming events for a room within the specified minutes, including currently happening events."""
|
||||
"""Get upcoming events for a room within the specified minutes."""
|
||||
now = datetime.now(timezone.utc)
|
||||
future_time = now + timedelta(minutes=minutes_ahead)
|
||||
|
||||
@@ -94,8 +95,8 @@ class CalendarEventController:
|
||||
sa.and_(
|
||||
calendar_events.c.room_id == room_id,
|
||||
calendar_events.c.is_deleted == False,
|
||||
calendar_events.c.start_time >= now,
|
||||
calendar_events.c.start_time <= future_time,
|
||||
calendar_events.c.end_time >= now,
|
||||
)
|
||||
)
|
||||
.order_by(calendar_events.c.start_time.asc())
|
||||
@@ -105,6 +106,7 @@ class CalendarEventController:
|
||||
return [CalendarEvent(**result) for result in results]
|
||||
|
||||
async def get_by_ics_uid(self, room_id: str, ics_uid: str) -> CalendarEvent | None:
|
||||
"""Get a calendar event by its ICS UID."""
|
||||
query = calendar_events.select().where(
|
||||
sa.and_(
|
||||
calendar_events.c.room_id == room_id,
|
||||
@@ -115,9 +117,11 @@ class CalendarEventController:
|
||||
return CalendarEvent(**result) if result else None
|
||||
|
||||
async def upsert(self, event: CalendarEvent) -> CalendarEvent:
|
||||
"""Create or update a calendar event."""
|
||||
existing = await self.get_by_ics_uid(event.room_id, event.ics_uid)
|
||||
|
||||
if existing:
|
||||
# Update existing event
|
||||
event.id = existing.id
|
||||
event.created_at = existing.created_at
|
||||
event.updated_at = datetime.now(timezone.utc)
|
||||
@@ -128,6 +132,7 @@ class CalendarEventController:
|
||||
.values(**event.model_dump())
|
||||
)
|
||||
else:
|
||||
# Insert new event
|
||||
query = calendar_events.insert().values(**event.model_dump())
|
||||
|
||||
await get_database().execute(query)
|
||||
@@ -139,6 +144,7 @@ class CalendarEventController:
|
||||
"""Soft delete future events that are no longer in the calendar."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# First, get the IDs of events to delete
|
||||
select_query = calendar_events.select().where(
|
||||
sa.and_(
|
||||
calendar_events.c.room_id == room_id,
|
||||
@@ -154,6 +160,7 @@ class CalendarEventController:
|
||||
delete_count = len(to_delete)
|
||||
|
||||
if delete_count > 0:
|
||||
# Now update them
|
||||
update_query = (
|
||||
calendar_events.update()
|
||||
.where(
|
||||
@@ -174,9 +181,13 @@ class CalendarEventController:
|
||||
return delete_count
|
||||
|
||||
async def delete_by_room(self, room_id: str) -> int:
|
||||
"""Hard delete all events for a room (used when room is deleted)."""
|
||||
query = calendar_events.delete().where(calendar_events.c.room_id == room_id)
|
||||
result = await get_database().execute(query)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
# Add missing import
|
||||
from datetime import timedelta
|
||||
|
||||
calendar_events_controller = CalendarEventController()
|
||||
|
||||
@@ -2,6 +2,7 @@ from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
@@ -18,12 +19,8 @@ meetings = sa.Table(
|
||||
sa.Column("host_room_url", sa.String),
|
||||
sa.Column("start_date", sa.DateTime(timezone=True)),
|
||||
sa.Column("end_date", sa.DateTime(timezone=True)),
|
||||
sa.Column(
|
||||
"room_id",
|
||||
sa.String,
|
||||
sa.ForeignKey("room.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("user_id", sa.String),
|
||||
sa.Column("room_id", sa.String),
|
||||
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
|
||||
sa.Column("room_mode", sa.String, nullable=False, server_default="normal"),
|
||||
sa.Column("recording_type", sa.String, nullable=False, server_default="cloud"),
|
||||
@@ -48,13 +45,11 @@ meetings = sa.Table(
|
||||
sa.Column(
|
||||
"calendar_event_id",
|
||||
sa.String,
|
||||
sa.ForeignKey(
|
||||
"calendar_event.id",
|
||||
ondelete="SET NULL",
|
||||
name="fk_meeting_calendar_event_id",
|
||||
),
|
||||
sa.ForeignKey("calendar_event.id", ondelete="SET NULL"),
|
||||
),
|
||||
sa.Column("calendar_metadata", JSONB),
|
||||
sa.Column("last_participant_left_at", sa.DateTime(timezone=True)),
|
||||
sa.Column("grace_period_minutes", sa.Integer, server_default=sa.text("15")),
|
||||
sa.Index("idx_meeting_room_id", "room_id"),
|
||||
sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
|
||||
)
|
||||
@@ -63,12 +58,7 @@ meeting_consent = sa.Table(
|
||||
"meeting_consent",
|
||||
metadata,
|
||||
sa.Column("id", sa.String, primary_key=True),
|
||||
sa.Column(
|
||||
"meeting_id",
|
||||
sa.String,
|
||||
sa.ForeignKey("meeting.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False),
|
||||
sa.Column("user_id", sa.String),
|
||||
sa.Column("consent_given", sa.Boolean, nullable=False),
|
||||
sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
|
||||
@@ -90,7 +80,8 @@ class Meeting(BaseModel):
|
||||
host_room_url: str
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
room_id: str | None
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
is_locked: bool = False
|
||||
room_mode: Literal["normal", "group"] = "normal"
|
||||
recording_type: Literal["none", "local", "cloud"] = "cloud"
|
||||
@@ -101,6 +92,8 @@ class Meeting(BaseModel):
|
||||
is_active: bool = True
|
||||
calendar_event_id: str | None = None
|
||||
calendar_metadata: dict[str, Any] | None = None
|
||||
last_participant_left_at: datetime | None = None
|
||||
grace_period_minutes: int = 15
|
||||
|
||||
|
||||
class MeetingController:
|
||||
@@ -112,10 +105,14 @@ class MeetingController:
|
||||
host_room_url: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
user_id: str,
|
||||
room: Room,
|
||||
calendar_event_id: str | None = None,
|
||||
calendar_metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Create a new meeting
|
||||
"""
|
||||
meeting = Meeting(
|
||||
id=id,
|
||||
room_name=room_name,
|
||||
@@ -123,6 +120,7 @@ class MeetingController:
|
||||
host_room_url=host_room_url,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
user_id=user_id,
|
||||
room_id=room.id,
|
||||
is_locked=room.is_locked,
|
||||
room_mode=room.room_mode,
|
||||
@@ -136,30 +134,27 @@ class MeetingController:
|
||||
return meeting
|
||||
|
||||
async def get_all_active(self) -> list[Meeting]:
|
||||
"""
|
||||
Get active meetings.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.is_active)
|
||||
return await get_database().fetch_all(query)
|
||||
|
||||
async def get_by_room_name(
|
||||
self,
|
||||
room_name: str,
|
||||
) -> Meeting | None:
|
||||
) -> Meeting:
|
||||
"""
|
||||
Get a meeting by room name.
|
||||
For backward compatibility, returns the most recent meeting.
|
||||
"""
|
||||
end_date = getattr(meetings.c, "end_date")
|
||||
query = (
|
||||
meetings.select()
|
||||
.where(meetings.c.room_name == room_name)
|
||||
.order_by(end_date.desc())
|
||||
)
|
||||
query = meetings.select().where(meetings.c.room_name == room_name)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return Meeting(**result)
|
||||
|
||||
async def get_active(self, room: Room, current_time: datetime) -> Meeting | None:
|
||||
async def get_active(self, room: Room, current_time: datetime) -> Meeting:
|
||||
"""
|
||||
Get latest active meeting for a room.
|
||||
For backward compatibility, returns the most recent active meeting.
|
||||
@@ -185,6 +180,10 @@ class MeetingController:
|
||||
async def get_all_active_for_room(
|
||||
self, room: Room, current_time: datetime
|
||||
) -> list[Meeting]:
|
||||
"""
|
||||
Get all active meetings for a room.
|
||||
This supports multiple concurrent meetings per room.
|
||||
"""
|
||||
end_date = getattr(meetings.c, "end_date")
|
||||
query = (
|
||||
meetings.select()
|
||||
@@ -220,12 +219,32 @@ class MeetingController:
|
||||
return Meeting(**result)
|
||||
|
||||
async def get_by_id(self, meeting_id: str, **kwargs) -> Meeting | None:
|
||||
"""
|
||||
Get a meeting by id
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Meeting(**result)
|
||||
|
||||
async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Meeting:
|
||||
"""
|
||||
Get a meeting by ID for HTTP request.
|
||||
|
||||
If not found, it will raise a 404 error.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
meeting = Meeting(**result)
|
||||
if result["user_id"] != user_id:
|
||||
meeting.host_room_url = ""
|
||||
|
||||
return meeting
|
||||
|
||||
async def get_by_calendar_event(self, calendar_event_id: str) -> Meeting | None:
|
||||
query = meetings.select().where(
|
||||
meetings.c.calendar_event_id == calendar_event_id
|
||||
@@ -259,9 +278,10 @@ class MeetingConsentController:
|
||||
result = await get_database().fetch_one(query)
|
||||
if result is None:
|
||||
return None
|
||||
return MeetingConsent(**result)
|
||||
return MeetingConsent(**result) if result else None
|
||||
|
||||
async def upsert(self, consent: MeetingConsent) -> MeetingConsent:
|
||||
"""Create new consent or update existing one for authenticated users"""
|
||||
if consent.user_id:
|
||||
# For authenticated users, check if consent already exists
|
||||
# not transactional but we're ok with that; the consents ain't deleted anyways
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from sqlite3 import IntegrityError
|
||||
from typing import Literal
|
||||
@@ -41,8 +40,6 @@ rooms = sqlalchemy.Table(
|
||||
sqlalchemy.Column(
|
||||
"is_shared", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||
),
|
||||
sqlalchemy.Column("webhook_url", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("webhook_secret", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("ics_url", sqlalchemy.Text),
|
||||
sqlalchemy.Column("ics_fetch_interval", sqlalchemy.Integer, server_default="300"),
|
||||
sqlalchemy.Column(
|
||||
@@ -70,8 +67,6 @@ class Room(BaseModel):
|
||||
"none", "prompt", "automatic", "automatic-2nd-participant"
|
||||
] = "automatic-2nd-participant"
|
||||
is_shared: bool = False
|
||||
webhook_url: str | None = None
|
||||
webhook_secret: str | None = None
|
||||
ics_url: str | None = None
|
||||
ics_fetch_interval: int = 300
|
||||
ics_enabled: bool = False
|
||||
@@ -125,8 +120,6 @@ class RoomController:
|
||||
recording_type: str,
|
||||
recording_trigger: str,
|
||||
is_shared: bool,
|
||||
webhook_url: str = "",
|
||||
webhook_secret: str = "",
|
||||
ics_url: str | None = None,
|
||||
ics_fetch_interval: int = 300,
|
||||
ics_enabled: bool = False,
|
||||
@@ -134,9 +127,6 @@ class RoomController:
|
||||
"""
|
||||
Add a new room
|
||||
"""
|
||||
if webhook_url and not webhook_secret:
|
||||
webhook_secret = secrets.token_urlsafe(32)
|
||||
|
||||
room = Room(
|
||||
name=name,
|
||||
user_id=user_id,
|
||||
@@ -148,8 +138,6 @@ class RoomController:
|
||||
recording_type=recording_type,
|
||||
recording_trigger=recording_trigger,
|
||||
is_shared=is_shared,
|
||||
webhook_url=webhook_url,
|
||||
webhook_secret=webhook_secret,
|
||||
ics_url=ics_url,
|
||||
ics_fetch_interval=ics_fetch_interval,
|
||||
ics_enabled=ics_enabled,
|
||||
@@ -165,9 +153,6 @@ class RoomController:
|
||||
"""
|
||||
Update a room fields with key/values in values
|
||||
"""
|
||||
if values.get("webhook_url") and not values.get("webhook_secret"):
|
||||
values["webhook_secret"] = secrets.token_urlsafe(32)
|
||||
|
||||
query = rooms.update().where(rooms.c.id == room.id).values(**values)
|
||||
try:
|
||||
await get_database().execute(query)
|
||||
@@ -217,13 +202,6 @@ class RoomController:
|
||||
|
||||
return room
|
||||
|
||||
async def get_ics_enabled(self) -> list[Room]:
|
||||
query = rooms.select().where(
|
||||
rooms.c.ics_enabled == True, rooms.c.ics_url != None
|
||||
)
|
||||
results = await get_database().fetch_all(query)
|
||||
return [Room(**result) for result in results]
|
||||
|
||||
async def remove_by_id(
|
||||
self,
|
||||
room_id: str,
|
||||
|
||||
@@ -1,38 +1,22 @@
|
||||
"""Search functionality for transcripts and other entities."""
|
||||
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from typing import Annotated, Any, Dict, Iterator
|
||||
from typing import Annotated, Any, Dict
|
||||
|
||||
import sqlalchemy
|
||||
import webvtt
|
||||
from databases.interfaces import Record as DbRecord
|
||||
from fastapi import HTTPException
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
TypeAdapter,
|
||||
ValidationError,
|
||||
constr,
|
||||
field_serializer,
|
||||
)
|
||||
from pydantic import BaseModel, Field, constr, field_serializer
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.rooms import rooms
|
||||
from reflector.db.transcripts import SourceKind, TranscriptStatus, transcripts
|
||||
from reflector.db.transcripts import SourceKind, transcripts
|
||||
from reflector.db.utils import is_postgresql
|
||||
from reflector.logger import logger
|
||||
from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
|
||||
|
||||
DEFAULT_SEARCH_LIMIT = 20
|
||||
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
|
||||
DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
|
||||
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
|
||||
LONG_SUMMARY_MAX_SNIPPETS = 2
|
||||
DEFAULT_SNIPPET_MAX_LENGTH = 150
|
||||
DEFAULT_MAX_SNIPPETS = 3
|
||||
|
||||
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
|
||||
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
|
||||
@@ -40,7 +24,6 @@ SearchOffsetBase = Annotated[int, Field(ge=0)]
|
||||
SearchTotalBase = Annotated[int, Field(ge=0)]
|
||||
|
||||
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
|
||||
search_query_adapter = TypeAdapter(SearchQuery)
|
||||
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
|
||||
SearchOffset = Annotated[
|
||||
SearchOffsetBase, Field(description="Number of results to skip")
|
||||
@@ -49,92 +32,15 @@ SearchTotal = Annotated[
|
||||
SearchTotalBase, Field(description="Total number of search results")
|
||||
]
|
||||
|
||||
WEBVTT_SPEC_HEADER = "WEBVTT"
|
||||
|
||||
WebVTTContent = Annotated[
|
||||
str,
|
||||
Field(min_length=len(WEBVTT_SPEC_HEADER), description="WebVTT content"),
|
||||
]
|
||||
|
||||
|
||||
class WebVTTProcessor:
|
||||
"""Stateless processor for WebVTT content operations."""
|
||||
|
||||
@staticmethod
|
||||
def parse(raw_content: str) -> WebVTTContent:
|
||||
"""Parse WebVTT content and return it as a string."""
|
||||
if not raw_content.startswith(WEBVTT_SPEC_HEADER):
|
||||
raise ValueError(f"Invalid WebVTT content, no header {WEBVTT_SPEC_HEADER}")
|
||||
return raw_content
|
||||
|
||||
@staticmethod
|
||||
def extract_text(webvtt_content: WebVTTContent) -> str:
|
||||
"""Extract plain text from WebVTT content using webvtt library."""
|
||||
try:
|
||||
buffer = StringIO(webvtt_content)
|
||||
vtt = webvtt.read_buffer(buffer)
|
||||
return " ".join(caption.text for caption in vtt if caption.text)
|
||||
except webvtt.errors.MalformedFileError as e:
|
||||
logger.warning(f"Malformed WebVTT content: {e}")
|
||||
return ""
|
||||
except (UnicodeDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to decode WebVTT content: {e}")
|
||||
return ""
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"WebVTT parsing error - unexpected format: {e}", exc_info=True
|
||||
)
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error parsing WebVTT: {e}", exc_info=True)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def generate_snippets(
|
||||
webvtt_content: WebVTTContent,
|
||||
query: SearchQuery,
|
||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from WebVTT content."""
|
||||
return SnippetGenerator.generate(
|
||||
WebVTTProcessor.extract_text(webvtt_content),
|
||||
query,
|
||||
max_snippets=max_snippets,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SnippetCandidate:
|
||||
"""Represents a candidate snippet with its position."""
|
||||
|
||||
_text: str
|
||||
start: NonNegativeInt
|
||||
_original_text_length: int
|
||||
|
||||
@property
|
||||
def end(self) -> NonNegativeInt:
|
||||
"""Calculate end position from start and raw text length."""
|
||||
return self.start + len(self._text)
|
||||
|
||||
def text(self) -> str:
|
||||
"""Get display text with ellipses added if needed."""
|
||||
result = self._text.strip()
|
||||
if self.start > 0:
|
||||
result = "..." + result
|
||||
if self.end < self._original_text_length:
|
||||
result = result + "..."
|
||||
return result
|
||||
|
||||
|
||||
class SearchParameters(BaseModel):
|
||||
"""Validated search parameters for full-text search."""
|
||||
|
||||
query_text: SearchQuery | None = None
|
||||
query_text: SearchQuery
|
||||
limit: SearchLimit = DEFAULT_SEARCH_LIMIT
|
||||
offset: SearchOffset = 0
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
source_kind: SourceKind | None = None
|
||||
|
||||
|
||||
class SearchResultDB(BaseModel):
|
||||
@@ -158,18 +64,13 @@ class SearchResult(BaseModel):
|
||||
title: str | None = None
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
room_name: str | None = None
|
||||
source_kind: SourceKind
|
||||
created_at: datetime
|
||||
status: TranscriptStatus = Field(..., min_length=1)
|
||||
status: str = Field(..., min_length=1)
|
||||
rank: float = Field(..., ge=0, le=1)
|
||||
duration: NonNegativeFloat | None = Field(..., description="Duration in seconds")
|
||||
duration: float | None = Field(..., ge=0, description="Duration in seconds")
|
||||
search_snippets: list[str] = Field(
|
||||
description="Text snippets around search matches"
|
||||
)
|
||||
total_match_count: NonNegativeInt = Field(
|
||||
default=0, description="Total number of matches found in the transcript"
|
||||
)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def serialize_datetime(self, dt: datetime) -> str:
|
||||
@@ -178,157 +79,84 @@ class SearchResult(BaseModel):
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
class SnippetGenerator:
|
||||
"""Stateless generator for text snippets and match operations."""
|
||||
|
||||
@staticmethod
|
||||
def find_all_matches(text: str, query: str) -> Iterator[int]:
|
||||
"""Generate all match positions for a query in text."""
|
||||
if not text:
|
||||
logger.warning("Empty text for search query in find_all_matches")
|
||||
return
|
||||
if not query:
|
||||
logger.warning("Empty query for search text in find_all_matches")
|
||||
return
|
||||
|
||||
text_lower = text.lower()
|
||||
query_lower = query.lower()
|
||||
start = 0
|
||||
prev_start = start
|
||||
while (pos := text_lower.find(query_lower, start)) != -1:
|
||||
yield pos
|
||||
start = pos + len(query_lower)
|
||||
if start <= prev_start:
|
||||
raise ValueError("panic! find_all_matches is not incremental")
|
||||
prev_start = start
|
||||
|
||||
@staticmethod
|
||||
def count_matches(text: str, query: SearchQuery) -> NonNegativeInt:
|
||||
"""Count total number of matches for a query in text."""
|
||||
ZERO = NonNegativeInt(0)
|
||||
if not text:
|
||||
logger.warning("Empty text for search query in count_matches")
|
||||
return ZERO
|
||||
assert query is not None
|
||||
return NonNegativeInt(
|
||||
sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_snippet(
|
||||
text: str, match_pos: int, max_length: int = DEFAULT_SNIPPET_MAX_LENGTH
|
||||
) -> SnippetCandidate:
|
||||
"""Create a snippet from a match position."""
|
||||
snippet_start = NonNegativeInt(max(0, match_pos - SNIPPET_CONTEXT_LENGTH))
|
||||
snippet_end = min(len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH)
|
||||
|
||||
snippet_text = text[snippet_start:snippet_end]
|
||||
|
||||
return SnippetCandidate(
|
||||
_text=snippet_text, start=snippet_start, _original_text_length=len(text)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_non_overlapping(
|
||||
candidates: Iterator[SnippetCandidate],
|
||||
) -> Iterator[str]:
|
||||
"""Filter out overlapping snippets and return only display text."""
|
||||
last_end = 0
|
||||
for candidate in candidates:
|
||||
display_text = candidate.text()
|
||||
# it means that next overlapping snippets simply don't get included
|
||||
# it's fine as simplistic logic and users probably won't care much because they already have their search results just fin
|
||||
if candidate.start >= last_end and display_text:
|
||||
yield display_text
|
||||
last_end = candidate.end
|
||||
|
||||
@staticmethod
|
||||
def generate(
|
||||
text: str,
|
||||
query: SearchQuery,
|
||||
max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from text."""
|
||||
assert query is not None
|
||||
if not text:
|
||||
logger.warning("Empty text for generate_snippets")
|
||||
return []
|
||||
|
||||
candidates = (
|
||||
SnippetGenerator.create_snippet(text, pos, max_length)
|
||||
for pos in SnippetGenerator.find_all_matches(text, query)
|
||||
)
|
||||
filtered = SnippetGenerator.filter_non_overlapping(candidates)
|
||||
snippets = list(itertools.islice(filtered, max_snippets))
|
||||
|
||||
# Fallback to first word search if no full matches
|
||||
# it's another assumption: proper snippet logic generation is quite complicated and tied to db logic, so simplification is used here
|
||||
if not snippets and " " in query:
|
||||
first_word = query.split()[0]
|
||||
return SnippetGenerator.generate(text, first_word, max_length, max_snippets)
|
||||
|
||||
return snippets
|
||||
|
||||
@staticmethod
|
||||
def from_summary(
|
||||
summary: str,
|
||||
query: SearchQuery,
|
||||
max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from summary text."""
|
||||
return SnippetGenerator.generate(summary, query, max_snippets=max_snippets)
|
||||
|
||||
@staticmethod
|
||||
def combine_sources(
|
||||
summary: NonEmptyString | None,
|
||||
webvtt: WebVTTContent | None,
|
||||
query: SearchQuery,
|
||||
max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> tuple[list[str], NonNegativeInt]:
|
||||
"""Combine snippets from multiple sources and return total match count.
|
||||
|
||||
Returns (snippets, total_match_count) tuple.
|
||||
|
||||
snippets can be empty for real in case of e.g. title match
|
||||
"""
|
||||
|
||||
assert (
|
||||
summary is not None or webvtt is not None
|
||||
), "At least one source must be present"
|
||||
|
||||
webvtt_matches = 0
|
||||
summary_matches = 0
|
||||
|
||||
if webvtt:
|
||||
webvtt_text = WebVTTProcessor.extract_text(webvtt)
|
||||
webvtt_matches = SnippetGenerator.count_matches(webvtt_text, query)
|
||||
|
||||
if summary:
|
||||
summary_matches = SnippetGenerator.count_matches(summary, query)
|
||||
|
||||
total_matches = NonNegativeInt(webvtt_matches + summary_matches)
|
||||
|
||||
summary_snippets = (
|
||||
SnippetGenerator.from_summary(summary, query) if summary else []
|
||||
)
|
||||
|
||||
if len(summary_snippets) >= max_total:
|
||||
return summary_snippets[:max_total], total_matches
|
||||
|
||||
remaining = max_total - len(summary_snippets)
|
||||
webvtt_snippets = (
|
||||
WebVTTProcessor.generate_snippets(webvtt, query, remaining)
|
||||
if webvtt
|
||||
else []
|
||||
)
|
||||
|
||||
return summary_snippets + webvtt_snippets, total_matches
|
||||
|
||||
|
||||
class SearchController:
|
||||
"""Controller for search operations across different entities."""
|
||||
|
||||
@staticmethod
|
||||
def _extract_webvtt_text(webvtt_content: str) -> str:
|
||||
"""Extract plain text from WebVTT content using webvtt library."""
|
||||
if not webvtt_content:
|
||||
return ""
|
||||
|
||||
try:
|
||||
buffer = StringIO(webvtt_content)
|
||||
vtt = webvtt.read_buffer(buffer)
|
||||
return " ".join(caption.text for caption in vtt if caption.text)
|
||||
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
|
||||
return ""
|
||||
except AttributeError as e:
|
||||
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _generate_snippets(
|
||||
text: str,
|
||||
q: SearchQuery,
|
||||
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||
max_snippets: int = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate multiple snippets around all occurrences of search term."""
|
||||
if not text or not q:
|
||||
return []
|
||||
|
||||
snippets = []
|
||||
lower_text = text.lower()
|
||||
search_lower = q.lower()
|
||||
|
||||
last_snippet_end = 0
|
||||
start_pos = 0
|
||||
|
||||
while len(snippets) < max_snippets:
|
||||
match_pos = lower_text.find(search_lower, start_pos)
|
||||
|
||||
if match_pos == -1:
|
||||
if not snippets and search_lower.split():
|
||||
first_word = search_lower.split()[0]
|
||||
match_pos = lower_text.find(first_word, start_pos)
|
||||
if match_pos == -1:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
|
||||
snippet_end = min(
|
||||
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
|
||||
)
|
||||
|
||||
if snippet_start < last_snippet_end:
|
||||
start_pos = match_pos + len(search_lower)
|
||||
continue
|
||||
|
||||
snippet = text[snippet_start:snippet_end]
|
||||
|
||||
if snippet_start > 0:
|
||||
snippet = "..." + snippet
|
||||
if snippet_end < len(text):
|
||||
snippet = snippet + "..."
|
||||
|
||||
snippet = snippet.strip()
|
||||
|
||||
if snippet:
|
||||
snippets.append(snippet)
|
||||
last_snippet_end = snippet_end
|
||||
|
||||
start_pos = match_pos + len(search_lower)
|
||||
if start_pos >= len(text):
|
||||
break
|
||||
|
||||
return snippets
|
||||
|
||||
@classmethod
|
||||
async def search_transcripts(
|
||||
cls, params: SearchParameters
|
||||
@@ -344,72 +172,39 @@ class SearchController:
|
||||
)
|
||||
return [], 0
|
||||
|
||||
base_columns = [
|
||||
transcripts.c.id,
|
||||
transcripts.c.title,
|
||||
transcripts.c.created_at,
|
||||
transcripts.c.duration,
|
||||
transcripts.c.status,
|
||||
transcripts.c.user_id,
|
||||
transcripts.c.room_id,
|
||||
transcripts.c.source_kind,
|
||||
transcripts.c.webvtt,
|
||||
transcripts.c.long_summary,
|
||||
sqlalchemy.case(
|
||||
(
|
||||
transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
|
||||
"Deleted Room",
|
||||
),
|
||||
else_=rooms.c.name,
|
||||
).label("room_name"),
|
||||
]
|
||||
search_query = None
|
||||
if params.query_text is not None:
|
||||
search_query = sqlalchemy.func.websearch_to_tsquery(
|
||||
"english", params.query_text
|
||||
)
|
||||
rank_column = sqlalchemy.func.ts_rank(
|
||||
transcripts.c.search_vector_en,
|
||||
search_query,
|
||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||
).label("rank")
|
||||
else:
|
||||
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
||||
|
||||
columns = base_columns + [rank_column]
|
||||
base_query = sqlalchemy.select(columns).select_from(
|
||||
transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
|
||||
search_query = sqlalchemy.func.websearch_to_tsquery(
|
||||
"english", params.query_text
|
||||
)
|
||||
|
||||
if params.query_text is not None:
|
||||
# because already initialized based on params.query_text presence above
|
||||
assert search_query is not None
|
||||
base_query = base_query.where(
|
||||
transcripts.c.search_vector_en.op("@@")(search_query)
|
||||
)
|
||||
base_query = sqlalchemy.select(
|
||||
[
|
||||
transcripts.c.id,
|
||||
transcripts.c.title,
|
||||
transcripts.c.created_at,
|
||||
transcripts.c.duration,
|
||||
transcripts.c.status,
|
||||
transcripts.c.user_id,
|
||||
transcripts.c.room_id,
|
||||
transcripts.c.source_kind,
|
||||
transcripts.c.webvtt,
|
||||
sqlalchemy.func.ts_rank(
|
||||
transcripts.c.search_vector_en,
|
||||
search_query,
|
||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||
).label("rank"),
|
||||
]
|
||||
).where(transcripts.c.search_vector_en.op("@@")(search_query))
|
||||
|
||||
if params.user_id:
|
||||
base_query = base_query.where(
|
||||
sqlalchemy.or_(
|
||||
transcripts.c.user_id == params.user_id, rooms.c.is_shared
|
||||
)
|
||||
)
|
||||
else:
|
||||
base_query = base_query.where(rooms.c.is_shared)
|
||||
base_query = base_query.where(transcripts.c.user_id == params.user_id)
|
||||
if params.room_id:
|
||||
base_query = base_query.where(transcripts.c.room_id == params.room_id)
|
||||
if params.source_kind:
|
||||
base_query = base_query.where(
|
||||
transcripts.c.source_kind == params.source_kind
|
||||
)
|
||||
|
||||
if params.query_text is not None:
|
||||
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
|
||||
else:
|
||||
order_by = sqlalchemy.desc(transcripts.c.created_at)
|
||||
|
||||
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
|
||||
|
||||
query = (
|
||||
base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
|
||||
.limit(params.limit)
|
||||
.offset(params.offset)
|
||||
)
|
||||
rs = await get_database().fetch_all(query)
|
||||
|
||||
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
||||
@@ -417,52 +212,20 @@ class SearchController:
|
||||
)
|
||||
total = await get_database().fetch_val(count_query)
|
||||
|
||||
def _process_result(r: DbRecord) -> SearchResult:
|
||||
def _process_result(r) -> SearchResult:
|
||||
r_dict: Dict[str, Any] = dict(r)
|
||||
|
||||
webvtt_raw: str | None = r_dict.pop("webvtt", None)
|
||||
webvtt: WebVTTContent | None
|
||||
if webvtt_raw:
|
||||
webvtt = WebVTTProcessor.parse(webvtt_raw)
|
||||
else:
|
||||
webvtt = None
|
||||
|
||||
long_summary_r: str | None = r_dict.pop("long_summary", None)
|
||||
long_summary: NonEmptyString = try_parse_non_empty_string(long_summary_r)
|
||||
room_name: str | None = r_dict.pop("room_name", None)
|
||||
webvtt: str | None = r_dict.pop("webvtt", None)
|
||||
db_result = SearchResultDB.model_validate(r_dict)
|
||||
|
||||
at_least_one_source = webvtt is not None or long_summary is not None
|
||||
has_query = params.query_text is not None
|
||||
snippets, total_match_count = (
|
||||
SnippetGenerator.combine_sources(
|
||||
long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
|
||||
)
|
||||
if has_query and at_least_one_source
|
||||
else ([], 0)
|
||||
)
|
||||
snippets = []
|
||||
if webvtt:
|
||||
plain_text = cls._extract_webvtt_text(webvtt)
|
||||
snippets = cls._generate_snippets(plain_text, params.query_text)
|
||||
|
||||
return SearchResult(
|
||||
**db_result.model_dump(),
|
||||
room_name=room_name,
|
||||
search_snippets=snippets,
|
||||
total_match_count=total_match_count,
|
||||
)
|
||||
|
||||
try:
|
||||
results = [_process_result(r) for r in rs]
|
||||
except ValidationError as e:
|
||||
logger.error(f"Invalid search result data: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal search result data consistency error"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing search results: {e}", exc_info=True)
|
||||
raise
|
||||
return SearchResult(**db_result.model_dump(), search_snippets=snippets)
|
||||
|
||||
results = [_process_result(r) for r in rs]
|
||||
return results, total
|
||||
|
||||
|
||||
search_controller = SearchController()
|
||||
webvtt_processor = WebVTTProcessor()
|
||||
snippet_generator = SnippetGenerator()
|
||||
|
||||
@@ -88,8 +88,6 @@ transcripts = sqlalchemy.Table(
|
||||
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
||||
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
||||
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
|
||||
sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
|
||||
sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
|
||||
)
|
||||
|
||||
# Add PostgreSQL-specific full-text search column
|
||||
@@ -101,8 +99,7 @@ if is_postgresql():
|
||||
TSVECTOR,
|
||||
sqlalchemy.Computed(
|
||||
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
||||
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
|
||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
|
||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
|
||||
persisted=True,
|
||||
),
|
||||
)
|
||||
@@ -122,15 +119,6 @@ def generate_transcript_name() -> str:
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
|
||||
TranscriptStatus = Literal[
|
||||
"idle", "uploaded", "recording", "processing", "error", "ended"
|
||||
]
|
||||
|
||||
|
||||
class StrValue(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class AudioWaveform(BaseModel):
|
||||
data: list[float]
|
||||
|
||||
@@ -194,7 +182,7 @@ class Transcript(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
user_id: str | None = None
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
status: TranscriptStatus = "idle"
|
||||
status: str = "idle"
|
||||
duration: float = 0
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
title: str | None = None
|
||||
@@ -741,27 +729,5 @@ class TranscriptController:
|
||||
transcript.delete_participant(participant_id)
|
||||
await self.update(transcript, {"participants": transcript.participants_dump()})
|
||||
|
||||
async def set_status(
|
||||
self, transcript_id: str, status: TranscriptStatus
|
||||
) -> TranscriptEvent | None:
|
||||
"""
|
||||
Update the status of a transcript
|
||||
|
||||
Will add an event STATUS + update the status field of transcript
|
||||
"""
|
||||
async with self.transaction():
|
||||
transcript = await self.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
if transcript.status == status:
|
||||
return
|
||||
resp = await self.append_event(
|
||||
transcript=transcript,
|
||||
event="STATUS",
|
||||
data=StrValue(value=status),
|
||||
)
|
||||
await self.update(transcript, {"status": status})
|
||||
return resp
|
||||
|
||||
|
||||
transcripts_controller = TranscriptController()
|
||||
|
||||
@@ -1,439 +0,0 @@
|
||||
"""
|
||||
File-based processing pipeline
|
||||
==============================
|
||||
|
||||
Optimized pipeline for processing complete audio/video files.
|
||||
Uses parallel processing for transcription, diarization, and waveform generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import structlog
|
||||
from celery import chain, shared_task
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.transcripts import (
|
||||
SourceKind,
|
||||
Transcript,
|
||||
TranscriptStatus,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.main_live_pipeline import (
|
||||
PipelineMainBase,
|
||||
broadcast_to_sockets,
|
||||
task_cleanup_consent,
|
||||
task_pipeline_post_to_zulip,
|
||||
)
|
||||
from reflector.processors import (
|
||||
AudioFileWriterProcessor,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
)
|
||||
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
||||
from reflector.processors.file_diarization import FileDiarizationInput
|
||||
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||
from reflector.processors.file_transcript import FileTranscriptInput
|
||||
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||
from reflector.processors.transcript_diarization_assembler import (
|
||||
TranscriptDiarizationAssemblerInput,
|
||||
TranscriptDiarizationAssemblerProcessor,
|
||||
)
|
||||
from reflector.processors.types import (
|
||||
DiarizationSegment,
|
||||
TitleSummary,
|
||||
)
|
||||
from reflector.processors.types import (
|
||||
Transcript as TranscriptType,
|
||||
)
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.worker.webhook import send_transcript_webhook
|
||||
|
||||
|
||||
class EmptyPipeline:
|
||||
"""Empty pipeline for processors that need a pipeline reference"""
|
||||
|
||||
def __init__(self, logger: structlog.BoundLogger):
|
||||
self.logger = logger
|
||||
|
||||
def get_pref(self, k, d=None):
|
||||
return d
|
||||
|
||||
async def emit(self, event):
|
||||
pass
|
||||
|
||||
|
||||
class PipelineMainFile(PipelineMainBase):
|
||||
"""
|
||||
Optimized file processing pipeline.
|
||||
Processes complete audio/video files with parallel execution.
|
||||
"""
|
||||
|
||||
logger: structlog.BoundLogger = None
|
||||
empty_pipeline = None
|
||||
|
||||
def __init__(self, transcript_id: str):
|
||||
super().__init__(transcript_id=transcript_id)
|
||||
self.logger = logger.bind(transcript_id=self.transcript_id)
|
||||
self.empty_pipeline = EmptyPipeline(logger=self.logger)
|
||||
|
||||
def _handle_gather_exceptions(self, results: list, operation: str) -> None:
|
||||
"""Handle exceptions from asyncio.gather with return_exceptions=True"""
|
||||
for i, result in enumerate(results):
|
||||
if not isinstance(result, Exception):
|
||||
continue
|
||||
self.logger.error(
|
||||
f"Error in {operation} (task {i}): {result}",
|
||||
transcript_id=self.transcript_id,
|
||||
exc_info=result,
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def set_status(self, transcript_id: str, status: TranscriptStatus):
|
||||
async with self.lock_transaction():
|
||||
return await transcripts_controller.set_status(transcript_id, status)
|
||||
|
||||
async def process(self, file_path: Path):
|
||||
"""Main entry point for file processing"""
|
||||
self.logger.info(f"Starting file pipeline for {file_path}")
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
# Clear transcript as we're going to regenerate everything
|
||||
async with self.transaction():
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": [],
|
||||
"topics": [],
|
||||
},
|
||||
)
|
||||
|
||||
# Extract audio and write to transcript location
|
||||
audio_path = await self.extract_and_write_audio(file_path, transcript)
|
||||
|
||||
# Upload for processing
|
||||
audio_url = await self.upload_audio(audio_path, transcript)
|
||||
|
||||
# Run parallel processing
|
||||
await self.run_parallel_processing(
|
||||
audio_path,
|
||||
audio_url,
|
||||
transcript.source_language,
|
||||
transcript.target_language,
|
||||
)
|
||||
|
||||
self.logger.info("File pipeline complete")
|
||||
|
||||
await transcripts_controller.set_status(transcript.id, "ended")
|
||||
|
||||
async def extract_and_write_audio(
|
||||
self, file_path: Path, transcript: Transcript
|
||||
) -> Path:
|
||||
"""Extract audio from video if needed and write to transcript location as MP3"""
|
||||
self.logger.info(f"Processing audio file: {file_path}")
|
||||
|
||||
# Check if it's already audio-only
|
||||
container = av.open(str(file_path))
|
||||
has_video = len(container.streams.video) > 0
|
||||
container.close()
|
||||
|
||||
# Use AudioFileWriterProcessor to write MP3 to transcript location
|
||||
mp3_writer = AudioFileWriterProcessor(
|
||||
path=transcript.audio_mp3_filename,
|
||||
on_duration=self.on_duration,
|
||||
)
|
||||
|
||||
# Process audio frames and write to transcript location
|
||||
input_container = av.open(str(file_path))
|
||||
for frame in input_container.decode(audio=0):
|
||||
await mp3_writer.push(frame)
|
||||
|
||||
await mp3_writer.flush()
|
||||
input_container.close()
|
||||
|
||||
if has_video:
|
||||
self.logger.info(
|
||||
f"Extracted audio from video and saved to {transcript.audio_mp3_filename}"
|
||||
)
|
||||
else:
|
||||
self.logger.info(
|
||||
f"Converted audio file and saved to {transcript.audio_mp3_filename}"
|
||||
)
|
||||
|
||||
return transcript.audio_mp3_filename
|
||||
|
||||
async def upload_audio(self, audio_path: Path, transcript: Transcript) -> str:
|
||||
"""Upload audio to storage for processing"""
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
if not storage:
|
||||
raise Exception(
|
||||
"Storage backend required for file processing. Configure TRANSCRIPT_STORAGE_* settings."
|
||||
)
|
||||
|
||||
self.logger.info("Uploading audio to storage")
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
storage_path = f"file_pipeline/{transcript.id}/audio.mp3"
|
||||
await storage.put_file(storage_path, audio_data)
|
||||
|
||||
audio_url = await storage.get_file_url(storage_path)
|
||||
|
||||
self.logger.info(f"Audio uploaded to {audio_url}")
|
||||
return audio_url
|
||||
|
||||
async def run_parallel_processing(
|
||||
self,
|
||||
audio_path: Path,
|
||||
audio_url: str,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
):
|
||||
"""Coordinate parallel processing of transcription, diarization, and waveform"""
|
||||
self.logger.info(
|
||||
"Starting parallel processing", transcript_id=self.transcript_id
|
||||
)
|
||||
|
||||
# Phase 1: Parallel processing of independent tasks
|
||||
transcription_task = self.transcribe_file(audio_url, source_language)
|
||||
diarization_task = self.diarize_file(audio_url)
|
||||
waveform_task = self.generate_waveform(audio_path)
|
||||
|
||||
results = await asyncio.gather(
|
||||
transcription_task, diarization_task, waveform_task, return_exceptions=True
|
||||
)
|
||||
|
||||
transcript_result = results[0]
|
||||
diarization_result = results[1]
|
||||
|
||||
# Handle errors - raise any exception that occurred
|
||||
self._handle_gather_exceptions(results, "parallel processing")
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
|
||||
# Phase 2: Assemble transcript with diarization
|
||||
self.logger.info(
|
||||
"Assembling transcript with diarization", transcript_id=self.transcript_id
|
||||
)
|
||||
processor = TranscriptDiarizationAssemblerProcessor()
|
||||
input_data = TranscriptDiarizationAssemblerInput(
|
||||
transcript=transcript_result, diarization=diarization_result or []
|
||||
)
|
||||
|
||||
# Store result for retrieval
|
||||
diarized_transcript: Transcript | None = None
|
||||
|
||||
async def capture_result(transcript):
|
||||
nonlocal diarized_transcript
|
||||
diarized_transcript = transcript
|
||||
|
||||
processor.on(capture_result)
|
||||
await processor.push(input_data)
|
||||
await processor.flush()
|
||||
|
||||
if not diarized_transcript:
|
||||
raise ValueError("No diarized transcript captured")
|
||||
|
||||
# Phase 3: Generate topics from diarized transcript
|
||||
self.logger.info("Generating topics", transcript_id=self.transcript_id)
|
||||
topics = await self.detect_topics(diarized_transcript, target_language)
|
||||
|
||||
# Phase 4: Generate title and summaries in parallel
|
||||
self.logger.info(
|
||||
"Generating title and summaries", transcript_id=self.transcript_id
|
||||
)
|
||||
results = await asyncio.gather(
|
||||
self.generate_title(topics),
|
||||
self.generate_summaries(topics),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
self._handle_gather_exceptions(results, "title and summary generation")
|
||||
|
||||
async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
|
||||
"""Transcribe complete file"""
|
||||
processor = FileTranscriptAutoProcessor()
|
||||
input_data = FileTranscriptInput(audio_url=audio_url, language=language)
|
||||
|
||||
# Store result for retrieval
|
||||
result: TranscriptType | None = None
|
||||
|
||||
async def capture_result(transcript):
|
||||
nonlocal result
|
||||
result = transcript
|
||||
|
||||
processor.on(capture_result)
|
||||
await processor.push(input_data)
|
||||
await processor.flush()
|
||||
|
||||
if not result:
|
||||
raise ValueError("No transcript captured")
|
||||
|
||||
return result
|
||||
|
||||
async def diarize_file(self, audio_url: str) -> list[DiarizationSegment] | None:
|
||||
"""Get diarization for file"""
|
||||
if not settings.DIARIZATION_BACKEND:
|
||||
self.logger.info("Diarization disabled")
|
||||
return None
|
||||
|
||||
processor = FileDiarizationAutoProcessor()
|
||||
input_data = FileDiarizationInput(audio_url=audio_url)
|
||||
|
||||
# Store result for retrieval
|
||||
result = None
|
||||
|
||||
async def capture_result(diarization_output):
|
||||
nonlocal result
|
||||
result = diarization_output.diarization
|
||||
|
||||
try:
|
||||
processor.on(capture_result)
|
||||
await processor.push(input_data)
|
||||
await processor.flush()
|
||||
return result
|
||||
except Exception as e:
|
||||
self.logger.error(f"Diarization failed: {e}")
|
||||
return None
|
||||
|
||||
async def generate_waveform(self, audio_path: Path):
|
||||
"""Generate and save waveform"""
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
processor = AudioWaveformProcessor(
|
||||
audio_path=audio_path,
|
||||
waveform_path=transcript.audio_waveform_filename,
|
||||
on_waveform=self.on_waveform,
|
||||
)
|
||||
processor.set_pipeline(self.empty_pipeline)
|
||||
|
||||
await processor.flush()
|
||||
|
||||
async def detect_topics(
|
||||
self, transcript: TranscriptType, target_language: str
|
||||
) -> list[TitleSummary]:
|
||||
"""Detect topics from complete transcript"""
|
||||
chunk_size = 300
|
||||
topics: list[TitleSummary] = []
|
||||
|
||||
async def on_topic(topic: TitleSummary):
|
||||
topics.append(topic)
|
||||
return await self.on_topic(topic)
|
||||
|
||||
topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
|
||||
topic_detector.set_pipeline(self.empty_pipeline)
|
||||
|
||||
for i in range(0, len(transcript.words), chunk_size):
|
||||
chunk_words = transcript.words[i : i + chunk_size]
|
||||
if not chunk_words:
|
||||
continue
|
||||
|
||||
chunk_transcript = TranscriptType(
|
||||
words=chunk_words, translation=transcript.translation
|
||||
)
|
||||
|
||||
await topic_detector.push(chunk_transcript)
|
||||
|
||||
await topic_detector.flush()
|
||||
return topics
|
||||
|
||||
async def generate_title(self, topics: list[TitleSummary]):
|
||||
"""Generate title from topics"""
|
||||
if not topics:
|
||||
self.logger.warning("No topics for title generation")
|
||||
return
|
||||
|
||||
processor = TranscriptFinalTitleProcessor(callback=self.on_title)
|
||||
processor.set_pipeline(self.empty_pipeline)
|
||||
|
||||
for topic in topics:
|
||||
await processor.push(topic)
|
||||
|
||||
await processor.flush()
|
||||
|
||||
async def generate_summaries(self, topics: list[TitleSummary]):
|
||||
"""Generate long and short summaries from topics"""
|
||||
if not topics:
|
||||
self.logger.warning("No topics for summary generation")
|
||||
return
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
processor = TranscriptFinalSummaryProcessor(
|
||||
transcript=transcript,
|
||||
callback=self.on_long_summary,
|
||||
on_short_summary=self.on_short_summary,
|
||||
)
|
||||
processor.set_pipeline(self.empty_pipeline)
|
||||
|
||||
for topic in topics:
|
||||
await processor.push(topic)
|
||||
|
||||
await processor.flush()
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_send_webhook_if_needed(*, transcript_id: str):
|
||||
"""Send webhook if this is a room recording with webhook configured"""
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
||||
room = await rooms_controller.get_by_id(transcript.room_id)
|
||||
if room and room.webhook_url:
|
||||
logger.info(
|
||||
"Dispatching webhook",
|
||||
transcript_id=transcript_id,
|
||||
room_id=room.id,
|
||||
webhook_url=room.webhook_url,
|
||||
)
|
||||
send_transcript_webhook.delay(
|
||||
transcript_id, room.id, event_id=uuid.uuid4().hex
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_file_process(*, transcript_id: str):
|
||||
"""Celery task for file pipeline processing"""
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
try:
|
||||
await pipeline.set_status(transcript_id, "processing")
|
||||
|
||||
# Find the file to process
|
||||
audio_file = next(transcript.data_path.glob("upload.*"), None)
|
||||
if not audio_file:
|
||||
audio_file = next(transcript.data_path.glob("audio.*"), None)
|
||||
|
||||
if not audio_file:
|
||||
raise Exception("No audio file found to process")
|
||||
|
||||
await pipeline.process(audio_file)
|
||||
|
||||
except Exception:
|
||||
await pipeline.set_status(transcript_id, "error")
|
||||
raise
|
||||
|
||||
# Run post-processing chain: consent cleanup -> zulip -> webhook
|
||||
post_chain = chain(
|
||||
task_cleanup_consent.si(transcript_id=transcript_id),
|
||||
task_pipeline_post_to_zulip.si(transcript_id=transcript_id),
|
||||
task_send_webhook_if_needed.si(transcript_id=transcript_id),
|
||||
)
|
||||
post_chain.delay()
|
||||
@@ -22,7 +22,7 @@ from celery import chord, current_task, group, shared_task
|
||||
from pydantic import BaseModel
|
||||
from structlog import BoundLogger as Logger
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db import get_database
|
||||
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
||||
from reflector.db.recordings import recordings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
@@ -32,7 +32,6 @@ from reflector.db.transcripts import (
|
||||
TranscriptFinalLongSummary,
|
||||
TranscriptFinalShortSummary,
|
||||
TranscriptFinalTitle,
|
||||
TranscriptStatus,
|
||||
TranscriptText,
|
||||
TranscriptTopic,
|
||||
TranscriptWaveform,
|
||||
@@ -41,9 +40,8 @@ from reflector.db.transcripts import (
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.runner import PipelineMessage, PipelineRunner
|
||||
from reflector.processors import (
|
||||
AudioChunkerAutoProcessor,
|
||||
AudioChunkerProcessor,
|
||||
AudioDiarizationAutoProcessor,
|
||||
AudioDownscaleProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
@@ -70,6 +68,29 @@ from reflector.zulip import (
|
||||
)
|
||||
|
||||
|
||||
def asynctask(f):
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
async def run_with_db():
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
coro = run_with_db()
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
return loop.run_until_complete(coro)
|
||||
return asyncio.run(coro)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def broadcast_to_sockets(func):
|
||||
"""
|
||||
Decorator to broadcast transcript event to websockets
|
||||
@@ -126,18 +147,15 @@ class StrValue(BaseModel):
|
||||
|
||||
|
||||
class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
|
||||
def __init__(self, transcript_id: str):
|
||||
super().__init__()
|
||||
self._lock = asyncio.Lock()
|
||||
self.transcript_id = transcript_id
|
||||
self.ws_room_id = f"ts:{self.transcript_id}"
|
||||
self._ws_manager = None
|
||||
transcript_id: str
|
||||
ws_room_id: str | None = None
|
||||
ws_manager: WebsocketManager | None = None
|
||||
|
||||
@property
|
||||
def ws_manager(self) -> WebsocketManager:
|
||||
if self._ws_manager is None:
|
||||
self._ws_manager = get_ws_manager()
|
||||
return self._ws_manager
|
||||
def prepare(self):
|
||||
# prepare websocket
|
||||
self._lock = asyncio.Lock()
|
||||
self.ws_room_id = f"ts:{self.transcript_id}"
|
||||
self.ws_manager = get_ws_manager()
|
||||
|
||||
async def get_transcript(self) -> Transcript:
|
||||
# fetch the transcript
|
||||
@@ -165,16 +183,9 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
for topic in topics
|
||||
]
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock_transaction(self):
|
||||
# This lock is to prevent multiple processor starting adding
|
||||
# into event array at the same time
|
||||
async with self._lock:
|
||||
yield
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
async with self.lock_transaction():
|
||||
async with self._lock:
|
||||
async with transcripts_controller.transaction():
|
||||
yield
|
||||
|
||||
@@ -183,14 +194,14 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
# if it's the first part, update the status of the transcript
|
||||
# but do not set the ended status yet.
|
||||
if isinstance(self, PipelineMainLive):
|
||||
status_mapping: dict[str, TranscriptStatus] = {
|
||||
status_mapping = {
|
||||
"started": "recording",
|
||||
"push": "recording",
|
||||
"flush": "processing",
|
||||
"error": "error",
|
||||
}
|
||||
elif isinstance(self, PipelineMainFinalSummaries):
|
||||
status_mapping: dict[str, TranscriptStatus] = {
|
||||
status_mapping = {
|
||||
"push": "processing",
|
||||
"flush": "processing",
|
||||
"error": "error",
|
||||
@@ -206,8 +217,22 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
return
|
||||
|
||||
# when the status of the pipeline changes, update the transcript
|
||||
async with self._lock:
|
||||
return await transcripts_controller.set_status(self.transcript_id, status)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
if status == transcript.status:
|
||||
return
|
||||
resp = await transcripts_controller.append_event(
|
||||
transcript=transcript,
|
||||
event="STATUS",
|
||||
data=StrValue(value=status),
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
return resp
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_transcript(self, data):
|
||||
@@ -330,6 +355,7 @@ class PipelineMainLive(PipelineMainBase):
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
processors = [
|
||||
@@ -337,8 +363,7 @@ class PipelineMainLive(PipelineMainBase):
|
||||
path=transcript.audio_wav_filename,
|
||||
on_duration=self.on_duration,
|
||||
),
|
||||
AudioDownscaleProcessor(),
|
||||
AudioChunkerAutoProcessor(),
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
@@ -351,7 +376,6 @@ class PipelineMainLive(PipelineMainBase):
|
||||
pipeline.set_pref("audio:target_language", transcript.target_language)
|
||||
pipeline.logger.bind(transcript_id=transcript.id)
|
||||
pipeline.logger.info("Pipeline main live created")
|
||||
pipeline.describe()
|
||||
|
||||
return pipeline
|
||||
|
||||
@@ -370,6 +394,7 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
pipeline = Pipeline(
|
||||
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||
)
|
||||
@@ -410,6 +435,8 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
|
||||
raise NotImplementedError
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
self.prepare()
|
||||
|
||||
# get transcript
|
||||
self._transcript = transcript = await self.get_transcript()
|
||||
|
||||
@@ -765,7 +792,7 @@ def pipeline_post(*, transcript_id: str):
|
||||
chain_final_summaries,
|
||||
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
|
||||
|
||||
return chain.delay()
|
||||
chain.delay()
|
||||
|
||||
|
||||
@get_transcript
|
||||
|
||||
@@ -18,14 +18,22 @@ During its lifecycle, it will emit the following status:
|
||||
import asyncio
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import Pipeline
|
||||
|
||||
PipelineMessage = TypeVar("PipelineMessage")
|
||||
|
||||
|
||||
class PipelineRunner(Generic[PipelineMessage]):
|
||||
def __init__(self):
|
||||
class PipelineRunner(BaseModel, Generic[PipelineMessage]):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
status: str = "idle"
|
||||
pipeline: Pipeline | None = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._task = None
|
||||
self._q_cmd = asyncio.Queue(maxsize=4096)
|
||||
self._ev_done = asyncio.Event()
|
||||
@@ -34,8 +42,6 @@ class PipelineRunner(Generic[PipelineMessage]):
|
||||
runner=id(self),
|
||||
runner_cls=self.__class__.__name__,
|
||||
)
|
||||
self.status = "idle"
|
||||
self.pipeline: Pipeline | None = None
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
||||
from .audio_chunker_auto import AudioChunkerAutoProcessor # noqa: F401
|
||||
from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
|
||||
from .audio_downscale import AudioDownscaleProcessor # noqa: F401
|
||||
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||
@@ -13,13 +11,6 @@ from .base import ( # noqa: F401
|
||||
Processor,
|
||||
ThreadedProcessor,
|
||||
)
|
||||
from .file_diarization import FileDiarizationProcessor # noqa: F401
|
||||
from .file_diarization_auto import FileDiarizationAutoProcessor # noqa: F401
|
||||
from .file_transcript import FileTranscriptProcessor # noqa: F401
|
||||
from .file_transcript_auto import FileTranscriptAutoProcessor # noqa: F401
|
||||
from .transcript_diarization_assembler import (
|
||||
TranscriptDiarizationAssemblerProcessor, # noqa: F401
|
||||
)
|
||||
from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401
|
||||
from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401
|
||||
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
|
||||
|
||||
@@ -1,78 +1,28 @@
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
|
||||
|
||||
class AudioChunkerProcessor(Processor):
|
||||
"""
|
||||
Base class for assembling audio frames into chunks
|
||||
Assemble audio frames into chunks
|
||||
"""
|
||||
|
||||
INPUT_TYPE = av.AudioFrame
|
||||
OUTPUT_TYPE = list[av.AudioFrame]
|
||||
|
||||
m_chunk = Histogram(
|
||||
"audio_chunker",
|
||||
"Time spent in AudioChunker.chunk",
|
||||
["backend"],
|
||||
)
|
||||
m_chunk_call = Counter(
|
||||
"audio_chunker_call",
|
||||
"Number of calls to AudioChunker.chunk",
|
||||
["backend"],
|
||||
)
|
||||
m_chunk_success = Counter(
|
||||
"audio_chunker_success",
|
||||
"Number of successful calls to AudioChunker.chunk",
|
||||
["backend"],
|
||||
)
|
||||
m_chunk_failure = Counter(
|
||||
"audio_chunker_failure",
|
||||
"Number of failed calls to AudioChunker.chunk",
|
||||
["backend"],
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
name = self.__class__.__name__
|
||||
self.m_chunk = self.m_chunk.labels(name)
|
||||
self.m_chunk_call = self.m_chunk_call.labels(name)
|
||||
self.m_chunk_success = self.m_chunk_success.labels(name)
|
||||
self.m_chunk_failure = self.m_chunk_failure.labels(name)
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, max_frames=256):
|
||||
super().__init__()
|
||||
self.frames: list[av.AudioFrame] = []
|
||||
self.max_frames = max_frames
|
||||
|
||||
async def _push(self, data: av.AudioFrame):
|
||||
"""Process incoming audio frame"""
|
||||
# Validate audio format on first frame
|
||||
if len(self.frames) == 0:
|
||||
if data.sample_rate != 16000 or len(data.layout.channels) != 1:
|
||||
raise ValueError(
|
||||
f"AudioChunkerProcessor expects 16kHz mono audio, got {data.sample_rate}Hz "
|
||||
f"with {len(data.layout.channels)} channel(s). "
|
||||
f"Use AudioDownscaleProcessor before this processor."
|
||||
)
|
||||
|
||||
try:
|
||||
self.m_chunk_call.inc()
|
||||
with self.m_chunk.time():
|
||||
result = await self._chunk(data)
|
||||
self.m_chunk_success.inc()
|
||||
if result:
|
||||
await self.emit(result)
|
||||
except Exception:
|
||||
self.m_chunk_failure.inc()
|
||||
raise
|
||||
|
||||
async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
|
||||
"""
|
||||
Process audio frame and return chunk when ready.
|
||||
Subclasses should implement their chunking logic here.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
self.frames.append(data)
|
||||
if len(self.frames) >= self.max_frames:
|
||||
await self.flush()
|
||||
|
||||
async def _flush(self):
|
||||
"""Flush any remaining frames when processing ends"""
|
||||
raise NotImplementedError
|
||||
frames = self.frames[:]
|
||||
self.frames = []
|
||||
if frames:
|
||||
await self.emit(frames)
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_chunker import AudioChunkerProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioChunkerAutoProcessor(AudioChunkerProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.AUDIO_CHUNKER_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.audio_chunker_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `AUDIO_CHUNKER_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "AUDIO_CHUNKER_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
@@ -1,34 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
|
||||
from reflector.processors.audio_chunker import AudioChunkerProcessor
|
||||
from reflector.processors.audio_chunker_auto import AudioChunkerAutoProcessor
|
||||
|
||||
|
||||
class AudioChunkerFramesProcessor(AudioChunkerProcessor):
|
||||
"""
|
||||
Simple frame-based audio chunker that emits chunks after a fixed number of frames
|
||||
"""
|
||||
|
||||
def __init__(self, max_frames=256, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.max_frames = max_frames
|
||||
|
||||
async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
|
||||
self.frames.append(data)
|
||||
if len(self.frames) >= self.max_frames:
|
||||
frames_to_emit = self.frames[:]
|
||||
self.frames = []
|
||||
return frames_to_emit
|
||||
|
||||
return None
|
||||
|
||||
async def _flush(self):
|
||||
frames = self.frames[:]
|
||||
self.frames = []
|
||||
if frames:
|
||||
await self.emit(frames)
|
||||
|
||||
|
||||
AudioChunkerAutoProcessor.register("frames", AudioChunkerFramesProcessor)
|
||||
@@ -1,298 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from silero_vad import VADIterator, load_silero_vad
|
||||
|
||||
from reflector.processors.audio_chunker import AudioChunkerProcessor
|
||||
from reflector.processors.audio_chunker_auto import AudioChunkerAutoProcessor
|
||||
|
||||
|
||||
class AudioChunkerSileroProcessor(AudioChunkerProcessor):
|
||||
"""
|
||||
Assemble audio frames into chunks with VAD-based speech detection using Silero VAD
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_frames=256,
|
||||
max_frames=1024,
|
||||
use_onnx=True,
|
||||
min_frames=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.block_frames = block_frames
|
||||
self.max_frames = max_frames
|
||||
self.min_frames = min_frames
|
||||
|
||||
# Initialize Silero VAD
|
||||
self._init_vad(use_onnx)
|
||||
|
||||
def _init_vad(self, use_onnx=False):
|
||||
"""Initialize Silero VAD model"""
|
||||
try:
|
||||
torch.set_num_threads(1)
|
||||
self.vad_model = load_silero_vad(onnx=use_onnx)
|
||||
self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
|
||||
self.logger.info("Silero VAD initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize Silero VAD: {e}")
|
||||
self.vad_model = None
|
||||
self.vad_iterator = None
|
||||
|
||||
async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
|
||||
"""Process audio frame and return chunk when ready"""
|
||||
self.frames.append(data)
|
||||
|
||||
# Check for speech segments every 32 frames (~1 second)
|
||||
if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
|
||||
return await self._process_block()
|
||||
|
||||
# Safety fallback - emit if we hit max frames
|
||||
elif len(self.frames) >= self.max_frames:
|
||||
self.logger.warning(
|
||||
f"AudioChunkerSileroProcessor: Reached max frames ({self.max_frames}), "
|
||||
f"emitting first {self.max_frames // 2} frames"
|
||||
)
|
||||
frames_to_emit = self.frames[: self.max_frames // 2]
|
||||
self.frames = self.frames[self.max_frames // 2 :]
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
return frames_to_emit
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring fallback segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def _process_block(self) -> Optional[list[av.AudioFrame]]:
|
||||
# Need at least 32 frames for VAD detection (~1 second)
|
||||
if len(self.frames) < 32 or self.vad_iterator is None:
|
||||
return None
|
||||
|
||||
# Processing block with current buffer size
|
||||
print(f"Processing block: {len(self.frames)} frames in buffer")
|
||||
|
||||
try:
|
||||
# Convert frames to numpy array for VAD
|
||||
audio_array = self._frames_to_numpy(self.frames)
|
||||
|
||||
if audio_array is None:
|
||||
# Fallback: emit all frames if conversion failed
|
||||
frames_to_emit = self.frames[:]
|
||||
self.frames = []
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
return frames_to_emit
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
return None
|
||||
|
||||
# Find complete speech segments in the buffer
|
||||
speech_end_frame = self._find_speech_segment_end(audio_array)
|
||||
|
||||
if speech_end_frame is None or speech_end_frame <= 0:
|
||||
# No speech found but buffer is getting large
|
||||
if len(self.frames) > 512:
|
||||
# Check if it's all silence and can be discarded
|
||||
# No speech segment found, buffer at {len(self.frames)} frames
|
||||
|
||||
# Could emit silence or discard old frames here
|
||||
# For now, keep first 256 frames and discard older silence
|
||||
if len(self.frames) > 768:
|
||||
self.logger.debug(
|
||||
f"Discarding {len(self.frames) - 256} old frames (likely silence)"
|
||||
)
|
||||
self.frames = self.frames[-256:]
|
||||
return None
|
||||
|
||||
# Calculate segment timing information
|
||||
frames_to_emit = self.frames[:speech_end_frame]
|
||||
|
||||
# Get timing from av.AudioFrame
|
||||
if frames_to_emit:
|
||||
first_frame = frames_to_emit[0]
|
||||
last_frame = frames_to_emit[-1]
|
||||
sample_rate = first_frame.sample_rate
|
||||
|
||||
# Calculate duration
|
||||
total_samples = sum(f.samples for f in frames_to_emit)
|
||||
duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0
|
||||
|
||||
# Get timestamps if available
|
||||
start_time = (
|
||||
first_frame.pts * first_frame.time_base if first_frame.pts else 0
|
||||
)
|
||||
end_time = (
|
||||
last_frame.pts * last_frame.time_base if last_frame.pts else 0
|
||||
)
|
||||
|
||||
# Convert to HH:MM:SS format for logging
|
||||
def format_time(seconds):
|
||||
if not seconds:
|
||||
return "00:00:00"
|
||||
total_seconds = int(float(seconds))
|
||||
hours = total_seconds // 3600
|
||||
minutes = (total_seconds % 3600) // 60
|
||||
secs = total_seconds % 60
|
||||
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
||||
|
||||
start_formatted = format_time(start_time)
|
||||
end_formatted = format_time(end_time)
|
||||
|
||||
# Keep remaining frames for next processing
|
||||
remaining_after = len(self.frames) - speech_end_frame
|
||||
|
||||
# Single structured log line
|
||||
self.logger.info(
|
||||
"Speech segment found",
|
||||
start=start_formatted,
|
||||
end=end_formatted,
|
||||
frames=speech_end_frame,
|
||||
duration=round(duration_seconds, 2),
|
||||
buffer_before=len(self.frames),
|
||||
remaining=remaining_after,
|
||||
)
|
||||
|
||||
# Keep remaining frames for next processing
|
||||
self.frames = self.frames[speech_end_frame:]
|
||||
|
||||
# Filter out segments with too few frames
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
return frames_to_emit
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in VAD processing: {e}")
|
||||
# Fallback to simple chunking
|
||||
if len(self.frames) >= self.block_frames:
|
||||
frames_to_emit = self.frames[: self.block_frames]
|
||||
self.frames = self.frames[self.block_frames :]
|
||||
if len(frames_to_emit) >= self.min_frames:
|
||||
return frames_to_emit
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
|
||||
"""Convert av.AudioFrame list to numpy array for VAD processing"""
|
||||
if not frames:
|
||||
return None
|
||||
|
||||
try:
|
||||
audio_data = []
|
||||
for frame in frames:
|
||||
frame_array = frame.to_ndarray()
|
||||
|
||||
if len(frame_array.shape) == 2:
|
||||
frame_array = frame_array.flatten()
|
||||
|
||||
audio_data.append(frame_array)
|
||||
|
||||
if not audio_data:
|
||||
return None
|
||||
|
||||
combined_audio = np.concatenate(audio_data)
|
||||
|
||||
# Ensure float32 format
|
||||
if combined_audio.dtype == np.int16:
|
||||
# Normalize int16 audio to float32 in range [-1.0, 1.0]
|
||||
combined_audio = combined_audio.astype(np.float32) / 32768.0
|
||||
elif combined_audio.dtype != np.float32:
|
||||
combined_audio = combined_audio.astype(np.float32)
|
||||
|
||||
return combined_audio
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error converting frames to numpy: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
|
||||
"""Find complete speech segments and return frame index at segment end"""
|
||||
if self.vad_iterator is None or len(audio_array) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Process audio in 512-sample windows for VAD
|
||||
window_size = 512
|
||||
min_silence_windows = 3 # Require 3 windows of silence after speech
|
||||
|
||||
# Track speech state
|
||||
in_speech = False
|
||||
speech_start = None
|
||||
speech_end = None
|
||||
silence_count = 0
|
||||
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(chunk, (0, window_size - len(chunk)))
|
||||
|
||||
# Detect if this window has speech
|
||||
speech_dict = self.vad_iterator(chunk, return_seconds=True)
|
||||
|
||||
# VADIterator returns dict with 'start' and 'end' when speech segments are detected
|
||||
if speech_dict:
|
||||
if not in_speech:
|
||||
# Speech started
|
||||
speech_start = i
|
||||
in_speech = True
|
||||
# Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
|
||||
silence_count = 0 # Reset silence counter
|
||||
continue
|
||||
|
||||
if not in_speech:
|
||||
continue
|
||||
|
||||
# We're in speech but found silence
|
||||
silence_count += 1
|
||||
if silence_count < min_silence_windows:
|
||||
continue
|
||||
|
||||
# Found end of speech segment
|
||||
speech_end = i - (min_silence_windows - 1) * window_size
|
||||
# Debug: print(f"Speech END at sample {speech_end}")
|
||||
|
||||
# Convert sample position to frame index
|
||||
samples_per_frame = self.frames[0].samples if self.frames else 1024
|
||||
frame_index = speech_end // samples_per_frame
|
||||
|
||||
# Ensure we don't exceed buffer
|
||||
frame_index = min(frame_index, len(self.frames))
|
||||
return frame_index
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error finding speech segment: {e}")
|
||||
return None
|
||||
|
||||
async def _flush(self):
|
||||
frames = self.frames[:]
|
||||
self.frames = []
|
||||
if frames:
|
||||
if len(frames) >= self.min_frames:
|
||||
await self.emit(frames)
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Ignoring flush segment with {len(frames)} frames "
|
||||
f"(< {self.min_frames} minimum)"
|
||||
)
|
||||
|
||||
|
||||
AudioChunkerAutoProcessor.register("silero", AudioChunkerSileroProcessor)
|
||||
@@ -1,7 +1,6 @@
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
DiarizationSegment,
|
||||
TitleSummary,
|
||||
Word,
|
||||
)
|
||||
@@ -38,21 +37,18 @@ class AudioDiarizationProcessor(Processor):
|
||||
async def _diarize(self, data: AudioDiarizationInput):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def assign_speaker(cls, words: list[Word], diarization: list[DiarizationSegment]):
|
||||
cls._diarization_remove_overlap(diarization)
|
||||
cls._diarization_remove_segment_without_words(words, diarization)
|
||||
cls._diarization_merge_same_speaker(diarization)
|
||||
cls._diarization_assign_speaker(words, diarization)
|
||||
def assign_speaker(self, words: list[Word], diarization: list[dict]):
|
||||
self._diarization_remove_overlap(diarization)
|
||||
self._diarization_remove_segment_without_words(words, diarization)
|
||||
self._diarization_merge_same_speaker(words, diarization)
|
||||
self._diarization_assign_speaker(words, diarization)
|
||||
|
||||
@staticmethod
|
||||
def iter_words_from_topics(topics: list[TitleSummary]):
|
||||
def iter_words_from_topics(self, topics: TitleSummary):
|
||||
for topic in topics:
|
||||
for word in topic.transcript.words:
|
||||
yield word
|
||||
|
||||
@staticmethod
|
||||
def is_word_continuation(word_prev, word):
|
||||
def is_word_continuation(self, word_prev, word):
|
||||
"""
|
||||
Return True if the word is a continuation of the previous word
|
||||
by checking if the previous word is ending with a punctuation
|
||||
@@ -65,8 +61,7 @@ class AudioDiarizationProcessor(Processor):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _diarization_remove_overlap(diarization: list[DiarizationSegment]):
|
||||
def _diarization_remove_overlap(self, diarization: list[dict]):
|
||||
"""
|
||||
Remove overlap in diarization results
|
||||
|
||||
@@ -91,9 +86,8 @@ class AudioDiarizationProcessor(Processor):
|
||||
else:
|
||||
diarization_idx += 1
|
||||
|
||||
@staticmethod
|
||||
def _diarization_remove_segment_without_words(
|
||||
words: list[Word], diarization: list[DiarizationSegment]
|
||||
self, words: list[Word], diarization: list[dict]
|
||||
):
|
||||
"""
|
||||
Remove diarization segments without words
|
||||
@@ -122,8 +116,9 @@ class AudioDiarizationProcessor(Processor):
|
||||
else:
|
||||
diarization_idx += 1
|
||||
|
||||
@staticmethod
|
||||
def _diarization_merge_same_speaker(diarization: list[DiarizationSegment]):
|
||||
def _diarization_merge_same_speaker(
|
||||
self, words: list[Word], diarization: list[dict]
|
||||
):
|
||||
"""
|
||||
Merge diarization contigous segments with the same speaker
|
||||
|
||||
@@ -140,10 +135,7 @@ class AudioDiarizationProcessor(Processor):
|
||||
else:
|
||||
diarization_idx += 1
|
||||
|
||||
@classmethod
|
||||
def _diarization_assign_speaker(
|
||||
cls, words: list[Word], diarization: list[DiarizationSegment]
|
||||
):
|
||||
def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
|
||||
"""
|
||||
Assign speaker to words based on diarization
|
||||
|
||||
@@ -151,7 +143,7 @@ class AudioDiarizationProcessor(Processor):
|
||||
"""
|
||||
|
||||
word_idx = 0
|
||||
last_speaker = 0
|
||||
last_speaker = None
|
||||
for d in diarization:
|
||||
start = d["start"]
|
||||
end = d["end"]
|
||||
@@ -166,7 +158,7 @@ class AudioDiarizationProcessor(Processor):
|
||||
# If it's a continuation, assign with the last speaker
|
||||
is_continuation = False
|
||||
if word_idx > 0 and word_idx < len(words) - 1:
|
||||
is_continuation = cls.is_word_continuation(
|
||||
is_continuation = self.is_word_continuation(
|
||||
*words[word_idx - 1 : word_idx + 1]
|
||||
)
|
||||
if is_continuation:
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput, DiarizationSegment
|
||||
|
||||
|
||||
class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
|
||||
"""Local diarization processor using pyannote.audio library"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "pyannote/speaker-diarization-3.1",
|
||||
pyannote_auth_token: str | None = None,
|
||||
device: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.model_name = model_name
|
||||
self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
|
||||
self.device = device
|
||||
|
||||
if device is None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.logger.info(f"Loading pyannote diarization model: {self.model_name}")
|
||||
self.diarization_pipeline = Pipeline.from_pretrained(
|
||||
self.model_name, use_auth_token=self.auth_token
|
||||
)
|
||||
self.diarization_pipeline.to(torch.device(self.device))
|
||||
self.logger.info(f"Diarization model loaded on device: {self.device}")
|
||||
|
||||
async def _diarize(self, data: AudioDiarizationInput) -> list[DiarizationSegment]:
|
||||
try:
|
||||
# Load audio file (audio_url is assumed to be a local file path)
|
||||
self.logger.info(f"Loading local audio file: {data.audio_url}")
|
||||
waveform, sample_rate = torchaudio.load(data.audio_url)
|
||||
audio_input = {"waveform": waveform, "sample_rate": sample_rate}
|
||||
self.logger.info("Running speaker diarization")
|
||||
diarization = self.diarization_pipeline(audio_input)
|
||||
|
||||
# Convert pyannote diarization output to our format
|
||||
segments = []
|
||||
for segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||
# Extract speaker number from label (e.g., "SPEAKER_00" -> 0)
|
||||
speaker_id = 0
|
||||
if speaker.startswith("SPEAKER_"):
|
||||
try:
|
||||
speaker_id = int(speaker.split("_")[-1])
|
||||
except (ValueError, IndexError):
|
||||
# Fallback to hash-based ID if parsing fails
|
||||
speaker_id = hash(speaker) % 1000
|
||||
|
||||
segments.append(
|
||||
{
|
||||
"start": round(segment.start, 3),
|
||||
"end": round(segment.end, 3),
|
||||
"speaker": speaker_id,
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(f"Diarization completed with {len(segments)} segments")
|
||||
return segments
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Diarization failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)
|
||||
@@ -1,60 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
from av.audio.resampler import AudioResampler
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
|
||||
|
||||
def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
|
||||
frame_copy = frame.from_ndarray(
|
||||
frame.to_ndarray(),
|
||||
format=frame.format.name,
|
||||
layout=frame.layout.name,
|
||||
)
|
||||
frame_copy.sample_rate = frame.sample_rate
|
||||
frame_copy.pts = frame.pts
|
||||
frame_copy.time_base = frame.time_base
|
||||
return frame_copy
|
||||
|
||||
|
||||
class AudioDownscaleProcessor(Processor):
|
||||
"""
|
||||
Downscale audio frames to 16kHz mono format
|
||||
"""
|
||||
|
||||
INPUT_TYPE = av.AudioFrame
|
||||
OUTPUT_TYPE = av.AudioFrame
|
||||
|
||||
def __init__(self, target_rate: int = 16000, target_layout: str = "mono", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.target_rate = target_rate
|
||||
self.target_layout = target_layout
|
||||
self.resampler: Optional[AudioResampler] = None
|
||||
self.needs_resampling: Optional[bool] = None
|
||||
|
||||
async def _push(self, data: av.AudioFrame):
|
||||
if self.needs_resampling is None:
|
||||
self.needs_resampling = (
|
||||
data.sample_rate != self.target_rate
|
||||
or data.layout.name != self.target_layout
|
||||
)
|
||||
|
||||
if self.needs_resampling:
|
||||
self.resampler = AudioResampler(
|
||||
format="s16", layout=self.target_layout, rate=self.target_rate
|
||||
)
|
||||
|
||||
if not self.needs_resampling or not self.resampler:
|
||||
await self.emit(data)
|
||||
return
|
||||
|
||||
resampled_frames = self.resampler.resample(copy_frame(data))
|
||||
for resampled_frame in resampled_frames:
|
||||
await self.emit(resampled_frame)
|
||||
|
||||
async def _flush(self):
|
||||
if self.needs_resampling and self.resampler:
|
||||
final_frames = self.resampler.resample(None)
|
||||
for frame in final_frames:
|
||||
await self.emit(frame)
|
||||
@@ -16,46 +16,37 @@ class AudioMergeProcessor(Processor):
|
||||
INPUT_TYPE = list[av.AudioFrame]
|
||||
OUTPUT_TYPE = AudioFile
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def _push(self, data: list[av.AudioFrame]):
|
||||
if not data:
|
||||
return
|
||||
|
||||
# get audio information from first frame
|
||||
frame = data[0]
|
||||
output_channels = len(frame.layout.channels)
|
||||
output_sample_rate = frame.sample_rate
|
||||
output_sample_width = frame.format.bytes
|
||||
channels = len(frame.layout.channels)
|
||||
sample_rate = frame.sample_rate
|
||||
sample_width = frame.format.bytes
|
||||
|
||||
# create audio file
|
||||
uu = uuid4().hex
|
||||
fd = io.BytesIO()
|
||||
|
||||
# Use PyAV to write frames
|
||||
out_container = av.open(fd, "w", format="wav")
|
||||
out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
|
||||
out_stream.layout = frame.layout.name
|
||||
|
||||
out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
|
||||
for frame in data:
|
||||
for packet in out_stream.encode(frame):
|
||||
out_container.mux(packet)
|
||||
|
||||
# Flush the encoder
|
||||
for packet in out_stream.encode(None):
|
||||
out_container.mux(packet)
|
||||
out_container.close()
|
||||
|
||||
fd.seek(0)
|
||||
|
||||
# emit audio file
|
||||
audiofile = AudioFile(
|
||||
name=f"{monotonic_ns()}-{uu}.wav",
|
||||
fd=fd,
|
||||
sample_rate=output_sample_rate,
|
||||
channels=output_channels,
|
||||
sample_width=output_sample_width,
|
||||
sample_rate=sample_rate,
|
||||
channels=channels,
|
||||
sample_width=sample_width,
|
||||
timestamp=data[0].pts * data[0].time_base,
|
||||
)
|
||||
|
||||
|
||||
@@ -21,11 +21,7 @@ from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
modal_api_key: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
super().__init__()
|
||||
if not settings.TRANSCRIPT_URL:
|
||||
raise Exception(
|
||||
|
||||
@@ -173,7 +173,6 @@ class Processor(Emitter):
|
||||
except Exception:
|
||||
self.m_processor_failure.inc()
|
||||
self.logger.exception("Error in push")
|
||||
raise
|
||||
|
||||
async def flush(self):
|
||||
"""
|
||||
@@ -241,45 +240,33 @@ class ThreadedProcessor(Processor):
|
||||
self.INPUT_TYPE = processor.INPUT_TYPE
|
||||
self.OUTPUT_TYPE = processor.OUTPUT_TYPE
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.queue = asyncio.Queue(maxsize=50)
|
||||
self.task: asyncio.Task | None = None
|
||||
self.queue = asyncio.Queue()
|
||||
self.task = asyncio.get_running_loop().create_task(self.loop())
|
||||
|
||||
def set_pipeline(self, pipeline: "Pipeline"):
|
||||
super().set_pipeline(pipeline)
|
||||
self.processor.set_pipeline(pipeline)
|
||||
|
||||
async def loop(self):
|
||||
try:
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
self.m_processor_queue.set(self.queue.qsize())
|
||||
with self.m_processor_queue_in_progress.track_inprogress():
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
self.m_processor_queue.set(self.queue.qsize())
|
||||
with self.m_processor_queue_in_progress.track_inprogress():
|
||||
try:
|
||||
if data is None:
|
||||
await self.processor.flush()
|
||||
break
|
||||
try:
|
||||
if data is None:
|
||||
await self.processor.flush()
|
||||
break
|
||||
try:
|
||||
await self.processor.push(data)
|
||||
except Exception:
|
||||
self.logger.error(
|
||||
f"Error in push {self.processor.__class__.__name__}"
|
||||
", continue"
|
||||
)
|
||||
finally:
|
||||
self.queue.task_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Crash in {self.__class__.__name__}: {e}", exc_info=e)
|
||||
|
||||
async def _ensure_task(self):
|
||||
if self.task is None:
|
||||
self.task = asyncio.get_running_loop().create_task(self.loop())
|
||||
|
||||
# XXX not doing a sleep here make the whole pipeline prior the thread
|
||||
# to be running without having a chance to work on the task here.
|
||||
await asyncio.sleep(0)
|
||||
await self.processor.push(data)
|
||||
except Exception:
|
||||
self.logger.error(
|
||||
f"Error in push {self.processor.__class__.__name__}"
|
||||
", continue"
|
||||
)
|
||||
finally:
|
||||
self.queue.task_done()
|
||||
|
||||
async def _push(self, data):
|
||||
await self._ensure_task()
|
||||
await self.queue.put(data)
|
||||
|
||||
async def _flush(self):
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import DiarizationSegment
|
||||
|
||||
|
||||
class FileDiarizationInput(BaseModel):
|
||||
"""Input for file diarization containing audio URL"""
|
||||
|
||||
audio_url: str
|
||||
|
||||
|
||||
class FileDiarizationOutput(BaseModel):
|
||||
"""Output for file diarization containing speaker segments"""
|
||||
|
||||
diarization: list[DiarizationSegment]
|
||||
|
||||
|
||||
class FileDiarizationProcessor(Processor):
|
||||
"""
|
||||
Diarize complete audio files from URL
|
||||
"""
|
||||
|
||||
INPUT_TYPE = FileDiarizationInput
|
||||
OUTPUT_TYPE = FileDiarizationOutput
|
||||
|
||||
async def _push(self, data: FileDiarizationInput):
|
||||
result = await self._diarize(data)
|
||||
if result:
|
||||
await self.emit(result)
|
||||
|
||||
async def _diarize(self, data: FileDiarizationInput):
|
||||
raise NotImplementedError
|
||||
@@ -1,33 +0,0 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.file_diarization import FileDiarizationProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileDiarizationAutoProcessor(FileDiarizationProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.DIARIZATION_BACKEND
|
||||
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.file_diarization_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "DIARIZATION_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
@@ -1,58 +0,0 @@
|
||||
"""
|
||||
File diarization implementation using the GPU service from modal.com
|
||||
|
||||
API will be a POST request to DIARIZATION_URL:
|
||||
|
||||
```
|
||||
POST /diarize?audio_file_url=...×tamp=0
|
||||
Authorization: Bearer <modal_api_key>
|
||||
```
|
||||
"""
|
||||
|
||||
import httpx
|
||||
|
||||
from reflector.processors.file_diarization import (
|
||||
FileDiarizationInput,
|
||||
FileDiarizationOutput,
|
||||
FileDiarizationProcessor,
|
||||
)
|
||||
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileDiarizationModalProcessor(FileDiarizationProcessor):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not settings.DIARIZATION_URL:
|
||||
raise Exception(
|
||||
"DIARIZATION_URL required to use FileDiarizationModalProcessor"
|
||||
)
|
||||
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
|
||||
self.file_timeout = settings.DIARIZATION_FILE_TIMEOUT
|
||||
self.modal_api_key = modal_api_key
|
||||
|
||||
async def _diarize(self, data: FileDiarizationInput):
|
||||
"""Get speaker diarization for file"""
|
||||
self.logger.info(f"Starting diarization from {data.audio_url}")
|
||||
|
||||
headers = {}
|
||||
if self.modal_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
|
||||
response = await client.post(
|
||||
self.diarization_url,
|
||||
headers=headers,
|
||||
params={
|
||||
"audio_file_url": data.audio_url,
|
||||
"timestamp": 0,
|
||||
},
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
diarization_data = response.json()["diarization"]
|
||||
|
||||
return FileDiarizationOutput(diarization=diarization_data)
|
||||
|
||||
|
||||
FileDiarizationAutoProcessor.register("modal", FileDiarizationModalProcessor)
|
||||
@@ -1,65 +0,0 @@
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import Transcript
|
||||
|
||||
|
||||
class FileTranscriptInput:
|
||||
"""Input for file transcription containing audio URL and language settings"""
|
||||
|
||||
def __init__(self, audio_url: str, language: str = "en"):
|
||||
self.audio_url = audio_url
|
||||
self.language = language
|
||||
|
||||
|
||||
class FileTranscriptProcessor(Processor):
|
||||
"""
|
||||
Transcript complete audio files from URL
|
||||
"""
|
||||
|
||||
INPUT_TYPE = FileTranscriptInput
|
||||
OUTPUT_TYPE = Transcript
|
||||
|
||||
m_transcript = Histogram(
|
||||
"file_transcript",
|
||||
"Time spent in FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
m_transcript_call = Counter(
|
||||
"file_transcript_call",
|
||||
"Number of calls to FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
m_transcript_success = Counter(
|
||||
"file_transcript_success",
|
||||
"Number of successful calls to FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
m_transcript_failure = Counter(
|
||||
"file_transcript_failure",
|
||||
"Number of failed calls to FileTranscript.transcript",
|
||||
["backend"],
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
name = self.__class__.__name__
|
||||
self.m_transcript = self.m_transcript.labels(name)
|
||||
self.m_transcript_call = self.m_transcript_call.labels(name)
|
||||
self.m_transcript_success = self.m_transcript_success.labels(name)
|
||||
self.m_transcript_failure = self.m_transcript_failure.labels(name)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def _push(self, data: FileTranscriptInput):
|
||||
try:
|
||||
self.m_transcript_call.inc()
|
||||
with self.m_transcript.time():
|
||||
result = await self._transcript(data)
|
||||
self.m_transcript_success.inc()
|
||||
if result:
|
||||
await self.emit(result)
|
||||
except Exception:
|
||||
self.m_transcript_failure.inc()
|
||||
raise
|
||||
|
||||
async def _transcript(self, data: FileTranscriptInput):
|
||||
raise NotImplementedError
|
||||
@@ -1,32 +0,0 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.file_transcript import FileTranscriptProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileTranscriptAutoProcessor(FileTranscriptProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.TRANSCRIPT_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.file_transcript_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "TRANSCRIPT_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
@@ -1,78 +0,0 @@
|
||||
"""
|
||||
File transcription implementation using the GPU service from modal.com
|
||||
|
||||
API will be a POST request to TRANSCRIPT_URL:
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_file_url": "https://...",
|
||||
"language": "en",
|
||||
"model": "parakeet-tdt-0.6b-v2",
|
||||
"batch": true
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
import httpx
|
||||
|
||||
from reflector.processors.file_transcript import (
|
||||
FileTranscriptInput,
|
||||
FileTranscriptProcessor,
|
||||
)
|
||||
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||
from reflector.processors.types import Transcript, Word
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class FileTranscriptModalProcessor(FileTranscriptProcessor):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not settings.TRANSCRIPT_URL:
|
||||
raise Exception(
|
||||
"TRANSCRIPT_URL required to use FileTranscriptModalProcessor"
|
||||
)
|
||||
self.transcript_url = settings.TRANSCRIPT_URL
|
||||
self.file_timeout = settings.TRANSCRIPT_FILE_TIMEOUT
|
||||
self.modal_api_key = modal_api_key
|
||||
|
||||
async def _transcript(self, data: FileTranscriptInput):
|
||||
"""Send full file to Modal for transcription"""
|
||||
url = f"{self.transcript_url}/v1/audio/transcriptions-from-url"
|
||||
|
||||
self.logger.info(f"Starting file transcription from {data.audio_url}")
|
||||
|
||||
headers = {}
|
||||
if self.modal_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json={
|
||||
"audio_file_url": data.audio_url,
|
||||
"language": data.language,
|
||||
"batch": True,
|
||||
},
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
words = [
|
||||
Word(
|
||||
text=word_info["word"],
|
||||
start=word_info["start"],
|
||||
end=word_info["end"],
|
||||
)
|
||||
for word_info in result.get("words", [])
|
||||
]
|
||||
|
||||
# words come not in order
|
||||
words.sort(key=lambda w: w.start)
|
||||
|
||||
return Transcript(words=words)
|
||||
|
||||
|
||||
# Register with the auto processor
|
||||
FileTranscriptAutoProcessor.register("modal", FileTranscriptModalProcessor)
|
||||
@@ -1,45 +0,0 @@
|
||||
"""
|
||||
Processor to assemble transcript with diarization results
|
||||
"""
|
||||
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import DiarizationSegment, Transcript
|
||||
|
||||
|
||||
class TranscriptDiarizationAssemblerInput:
|
||||
"""Input containing transcript and diarization data"""
|
||||
|
||||
def __init__(self, transcript: Transcript, diarization: list[DiarizationSegment]):
|
||||
self.transcript = transcript
|
||||
self.diarization = diarization
|
||||
|
||||
|
||||
class TranscriptDiarizationAssemblerProcessor(Processor):
|
||||
"""
|
||||
Assemble transcript with diarization results by applying speaker assignments
|
||||
"""
|
||||
|
||||
INPUT_TYPE = TranscriptDiarizationAssemblerInput
|
||||
OUTPUT_TYPE = Transcript
|
||||
|
||||
async def _push(self, data: TranscriptDiarizationAssemblerInput):
|
||||
result = await self._assemble(data)
|
||||
if result:
|
||||
await self.emit(result)
|
||||
|
||||
async def _assemble(self, data: TranscriptDiarizationAssemblerInput):
|
||||
"""Apply diarization to transcript words"""
|
||||
if not data.diarization:
|
||||
self.logger.info(
|
||||
"No diarization data provided, returning original transcript"
|
||||
)
|
||||
return data.transcript
|
||||
|
||||
# Reuse logic from AudioDiarizationProcessor
|
||||
processor = AudioDiarizationProcessor()
|
||||
words = data.transcript.words
|
||||
processor.assign_speaker(words, data.diarization)
|
||||
|
||||
self.logger.info(f"Applied diarization to {len(words)} words")
|
||||
return data.transcript
|
||||
@@ -2,21 +2,18 @@ import io
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Annotated, TypedDict
|
||||
from typing import Annotated
|
||||
|
||||
from profanityfilter import ProfanityFilter
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
|
||||
class DiarizationSegment(TypedDict):
|
||||
"""Type definition for diarization segment containing speaker information"""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
speaker: int
|
||||
|
||||
from reflector.redis_cache import redis_cache
|
||||
|
||||
PUNC_RE = re.compile(r"[.;:?!…]")
|
||||
|
||||
profanity_filter = ProfanityFilter()
|
||||
profanity_filter.set_censor("*")
|
||||
|
||||
|
||||
class AudioFile(BaseModel):
|
||||
name: str
|
||||
@@ -118,11 +115,21 @@ def words_to_segments(words: list[Word]) -> list[TranscriptSegment]:
|
||||
|
||||
class Transcript(BaseModel):
|
||||
translation: str | None = None
|
||||
words: list[Word] = []
|
||||
words: list[Word] = None
|
||||
|
||||
@property
|
||||
def raw_text(self):
|
||||
# Uncensored text
|
||||
return "".join([word.text for word in self.words])
|
||||
|
||||
@redis_cache(prefix="profanity", duration=3600 * 24 * 7)
|
||||
def _get_censored_text(self, text: str):
|
||||
return profanity_filter.censor(text).strip()
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
return "".join([word.text for word in self.words])
|
||||
# Censored text
|
||||
return self._get_censored_text(self.raw_text)
|
||||
|
||||
@property
|
||||
def human_timestamp(self):
|
||||
@@ -154,6 +161,12 @@ class Transcript(BaseModel):
|
||||
word.start += offset
|
||||
word.end += offset
|
||||
|
||||
def clone(self):
|
||||
words = [
|
||||
Word(text=word.text, start=word.start, end=word.end) for word in self.words
|
||||
]
|
||||
return Transcript(text=self.text, translation=self.translation, words=words)
|
||||
|
||||
def as_segments(self) -> list[TranscriptSegment]:
|
||||
return words_to_segments(self.words)
|
||||
|
||||
|
||||
@@ -1,17 +1,10 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import redis
|
||||
import redis.asyncio as redis_async
|
||||
import structlog
|
||||
from redis.exceptions import LockError
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
redis_clients = {}
|
||||
|
||||
|
||||
@@ -28,12 +21,6 @@ def get_redis_client(db=0):
|
||||
return redis_clients[db]
|
||||
|
||||
|
||||
async def get_async_redis_client(db: int = 0):
|
||||
return await redis_async.from_url(
|
||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/{db}"
|
||||
)
|
||||
|
||||
|
||||
def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argidx=1):
|
||||
"""
|
||||
Cache the result of a function in Redis.
|
||||
@@ -62,87 +49,3 @@ def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argid
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RedisAsyncLock:
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
timeout: int = 120,
|
||||
extend_interval: int = 30,
|
||||
skip_if_locked: bool = False,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: Optional[float] = None,
|
||||
):
|
||||
self.key = f"async_lock:{key}"
|
||||
self.timeout = timeout
|
||||
self.extend_interval = extend_interval
|
||||
self.skip_if_locked = skip_if_locked
|
||||
self.blocking = blocking
|
||||
self.blocking_timeout = blocking_timeout
|
||||
self._lock = None
|
||||
self._redis = None
|
||||
self._extend_task = None
|
||||
self._acquired = False
|
||||
|
||||
async def _extend_lock_periodically(self):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.extend_interval)
|
||||
if self._lock:
|
||||
await self._lock.extend(self.timeout, replace_ttl=True)
|
||||
logger.debug("Extended lock", key=self.key)
|
||||
except LockError:
|
||||
logger.warning("Failed to extend lock", key=self.key)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error extending lock", key=self.key, error=str(e))
|
||||
break
|
||||
|
||||
async def __aenter__(self):
|
||||
self._redis = await get_async_redis_client()
|
||||
self._lock = self._redis.lock(
|
||||
self.key,
|
||||
timeout=self.timeout,
|
||||
blocking=self.blocking,
|
||||
blocking_timeout=self.blocking_timeout,
|
||||
)
|
||||
|
||||
self._acquired = await self._lock.acquire()
|
||||
|
||||
if not self._acquired:
|
||||
if self.skip_if_locked:
|
||||
logger.warning(
|
||||
"Lock already acquired by another process, skipping", key=self.key
|
||||
)
|
||||
return self
|
||||
else:
|
||||
raise LockError(f"Failed to acquire lock: {self.key}")
|
||||
|
||||
self._extend_task = asyncio.create_task(self._extend_lock_periodically())
|
||||
logger.info("Acquired lock", key=self.key)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._extend_task:
|
||||
self._extend_task.cancel()
|
||||
try:
|
||||
await self._extend_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._acquired and self._lock:
|
||||
try:
|
||||
await self._lock.release()
|
||||
logger.info("Released lock", key=self.key)
|
||||
except LockError:
|
||||
logger.debug("Lock already released or expired", key=self.key)
|
||||
|
||||
if self._redis:
|
||||
await self._redis.aclose()
|
||||
|
||||
@property
|
||||
def acquired(self) -> bool:
|
||||
return self._acquired
|
||||
|
||||
@@ -1,78 +1,16 @@
|
||||
"""
|
||||
ICS Calendar Synchronization Service
|
||||
|
||||
This module provides services for fetching, parsing, and synchronizing ICS (iCalendar)
|
||||
calendar feeds with room booking data in the database.
|
||||
|
||||
Key Components:
|
||||
- ICSFetchService: Handles HTTP fetching and parsing of ICS calendar data
|
||||
- ICSSyncService: Manages the synchronization process between ICS feeds and database
|
||||
|
||||
Example Usage:
|
||||
# Sync a room's calendar
|
||||
room = Room(id="room1", name="conference-room", ics_url="https://cal.example.com/room.ics")
|
||||
result = await ics_sync_service.sync_room_calendar(room)
|
||||
|
||||
# Result structure:
|
||||
{
|
||||
"status": "success", # success|unchanged|error|skipped
|
||||
"hash": "abc123...", # MD5 hash of ICS content
|
||||
"events_found": 5, # Events matching this room
|
||||
"total_events": 12, # Total events in calendar within time window
|
||||
"events_created": 2, # New events added to database
|
||||
"events_updated": 3, # Existing events modified
|
||||
"events_deleted": 1 # Events soft-deleted (no longer in calendar)
|
||||
}
|
||||
|
||||
Event Matching:
|
||||
Events are matched to rooms by checking if the room's full URL appears in the
|
||||
event's LOCATION or DESCRIPTION fields. Only events within a 25-hour window
|
||||
(1 hour ago to 24 hours from now) are processed.
|
||||
|
||||
Input: ICS calendar URL (e.g., "https://calendar.google.com/calendar/ical/...")
|
||||
Output: EventData objects with structured calendar information:
|
||||
{
|
||||
"ics_uid": "event123@google.com",
|
||||
"title": "Team Meeting",
|
||||
"description": "Weekly sync meeting",
|
||||
"location": "https://meet.company.com/conference-room",
|
||||
"start_time": datetime(2024, 1, 15, 14, 0, tzinfo=UTC),
|
||||
"end_time": datetime(2024, 1, 15, 15, 0, tzinfo=UTC),
|
||||
"attendees": [
|
||||
{"email": "user@company.com", "name": "John Doe", "role": "ORGANIZER"},
|
||||
{"email": "attendee@company.com", "name": "Jane Smith", "status": "ACCEPTED"}
|
||||
],
|
||||
"ics_raw_data": "BEGIN:VEVENT\nUID:event123@google.com\n..."
|
||||
}
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
import httpx
|
||||
import pytz
|
||||
import structlog
|
||||
from icalendar import Calendar, Event
|
||||
from loguru import logger
|
||||
|
||||
from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
|
||||
from reflector.db.rooms import Room, rooms_controller
|
||||
from reflector.redis_cache import RedisAsyncLock
|
||||
from reflector.settings import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
EVENT_WINDOW_DELTA_START = timedelta(hours=-1)
|
||||
EVENT_WINDOW_DELTA_END = timedelta(hours=24)
|
||||
|
||||
|
||||
class SyncStatus(str, Enum):
|
||||
SUCCESS = "success"
|
||||
UNCHANGED = "unchanged"
|
||||
ERROR = "error"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
class AttendeeData(TypedDict, total=False):
|
||||
email: str | None
|
||||
@@ -98,21 +36,6 @@ class SyncStats(TypedDict):
|
||||
events_deleted: int
|
||||
|
||||
|
||||
class SyncResultBase(TypedDict):
|
||||
status: SyncStatus
|
||||
|
||||
|
||||
class SyncResult(SyncResultBase, total=False):
|
||||
hash: str | None
|
||||
events_found: int
|
||||
total_events: int
|
||||
events_created: int
|
||||
events_updated: int
|
||||
events_deleted: int
|
||||
error: str | None
|
||||
reason: str | None
|
||||
|
||||
|
||||
class ICSFetchService:
|
||||
def __init__(self):
|
||||
self.client = httpx.AsyncClient(
|
||||
@@ -130,44 +53,46 @@ class ICSFetchService:
|
||||
|
||||
def extract_room_events(
|
||||
self, calendar: Calendar, room_name: str, room_url: str
|
||||
) -> tuple[list[EventData], int]:
|
||||
) -> list[EventData]:
|
||||
events = []
|
||||
total_events = 0
|
||||
now = datetime.now(timezone.utc)
|
||||
window_start = now + EVENT_WINDOW_DELTA_START
|
||||
window_end = now + EVENT_WINDOW_DELTA_END
|
||||
window_start = now - timedelta(hours=1)
|
||||
window_end = now + timedelta(hours=24)
|
||||
|
||||
for component in calendar.walk():
|
||||
if component.name != "VEVENT":
|
||||
continue
|
||||
|
||||
status = component.get("STATUS", "").upper()
|
||||
if status == "CANCELLED":
|
||||
continue
|
||||
|
||||
# Count total non-cancelled events in the time window
|
||||
event_data = self._parse_event(component)
|
||||
if event_data and window_start <= event_data["start_time"] <= window_end:
|
||||
total_events += 1
|
||||
if component.name == "VEVENT":
|
||||
# Skip cancelled events
|
||||
status = component.get("STATUS", "").upper()
|
||||
if status == "CANCELLED":
|
||||
continue
|
||||
|
||||
# Check if event matches this room
|
||||
if self._event_matches_room(component, room_name, room_url):
|
||||
events.append(event_data)
|
||||
event_data = self._parse_event(component)
|
||||
|
||||
return events, total_events
|
||||
# Only include events in our time window
|
||||
if (
|
||||
event_data
|
||||
and window_start <= event_data["start_time"] <= window_end
|
||||
):
|
||||
events.append(event_data)
|
||||
|
||||
return events
|
||||
|
||||
def _event_matches_room(self, event: Event, room_name: str, room_url: str) -> bool:
|
||||
location = str(event.get("LOCATION", ""))
|
||||
description = str(event.get("DESCRIPTION", ""))
|
||||
|
||||
# Only match full room URL
|
||||
# XXX leaved here as a patterns, to later be extended with tinyurl or such too
|
||||
# Only match full room URL (with or without protocol)
|
||||
patterns = [
|
||||
room_url,
|
||||
room_url, # Full URL with protocol
|
||||
room_url.replace("https://", ""), # Without https protocol
|
||||
room_url.replace("http://", ""), # Without http protocol
|
||||
]
|
||||
|
||||
# Check location and description for patterns
|
||||
text_to_check = f"{location} {description}".lower()
|
||||
|
||||
for pattern in patterns:
|
||||
if pattern.lower() in text_to_check:
|
||||
return True
|
||||
@@ -175,17 +100,20 @@ class ICSFetchService:
|
||||
return False
|
||||
|
||||
def _parse_event(self, event: Event) -> EventData | None:
|
||||
# Extract basic fields
|
||||
uid = str(event.get("UID", ""))
|
||||
summary = str(event.get("SUMMARY", ""))
|
||||
description = str(event.get("DESCRIPTION", ""))
|
||||
location = str(event.get("LOCATION", ""))
|
||||
|
||||
# Parse dates
|
||||
dtstart = event.get("DTSTART")
|
||||
dtend = event.get("DTEND")
|
||||
|
||||
if not dtstart:
|
||||
return None
|
||||
|
||||
# Convert fields
|
||||
# Convert to datetime
|
||||
start_time = self._normalize_datetime(
|
||||
dtstart.dt if hasattr(dtstart, "dt") else dtstart
|
||||
)
|
||||
@@ -194,6 +122,8 @@ class ICSFetchService:
|
||||
if dtend
|
||||
else start_time + timedelta(hours=1)
|
||||
)
|
||||
|
||||
# Parse attendees
|
||||
attendees = self._parse_attendees(event)
|
||||
|
||||
# Get raw event data for storage
|
||||
@@ -211,135 +141,89 @@ class ICSFetchService:
|
||||
}
|
||||
|
||||
def _normalize_datetime(self, dt) -> datetime:
|
||||
# Ensure datetime is with timezone, if not, assume UTC
|
||||
# Handle date objects (all-day events)
|
||||
if isinstance(dt, date) and not isinstance(dt, datetime):
|
||||
# Convert to datetime at start of day in UTC
|
||||
dt = datetime.combine(dt, datetime.min.time())
|
||||
dt = pytz.UTC.localize(dt)
|
||||
elif isinstance(dt, datetime):
|
||||
# Add UTC timezone if naive
|
||||
if dt.tzinfo is None:
|
||||
dt = pytz.UTC.localize(dt)
|
||||
else:
|
||||
# Convert to UTC
|
||||
dt = dt.astimezone(pytz.UTC)
|
||||
|
||||
return dt
|
||||
|
||||
def _parse_attendees(self, event: Event) -> list[AttendeeData]:
|
||||
# Extracts attendee information from both ATTENDEE and ORGANIZER properties.
|
||||
# Handles malformed comma-separated email addresses in single ATTENDEE fields
|
||||
# by splitting them into separate attendee entries. Returns a list of attendee
|
||||
# data including email, name, status, and role information.
|
||||
final_attendees = []
|
||||
attendees = []
|
||||
|
||||
attendees = event.get("ATTENDEE", [])
|
||||
if not isinstance(attendees, list):
|
||||
attendees = [attendees]
|
||||
for att in attendees:
|
||||
email_str = str(att).replace("mailto:", "") if att else None
|
||||
# Parse ATTENDEE properties
|
||||
for attendee in event.get("ATTENDEE", []):
|
||||
if not isinstance(attendee, list):
|
||||
attendee = [attendee]
|
||||
|
||||
# Handle malformed comma-separated email addresses in a single ATTENDEE field
|
||||
if email_str and "," in email_str:
|
||||
# Split comma-separated emails and create separate attendee entries
|
||||
email_parts = [email.strip() for email in email_str.split(",")]
|
||||
for email in email_parts:
|
||||
if email and "@" in email:
|
||||
clean_email = email.replace("MAILTO:", "").replace(
|
||||
"mailto:", ""
|
||||
)
|
||||
att_data: AttendeeData = {
|
||||
"email": clean_email,
|
||||
"name": att.params.get("CN")
|
||||
if hasattr(att, "params") and email == email_parts[0]
|
||||
else None,
|
||||
"status": att.params.get("PARTSTAT")
|
||||
if hasattr(att, "params") and email == email_parts[0]
|
||||
else None,
|
||||
"role": att.params.get("ROLE")
|
||||
if hasattr(att, "params") and email == email_parts[0]
|
||||
else None,
|
||||
}
|
||||
final_attendees.append(att_data)
|
||||
else:
|
||||
# Normal single attendee
|
||||
for att in attendee:
|
||||
att_data: AttendeeData = {
|
||||
"email": email_str,
|
||||
"email": str(att).replace("mailto:", "") if att else None,
|
||||
"name": att.params.get("CN") if hasattr(att, "params") else None,
|
||||
"status": att.params.get("PARTSTAT")
|
||||
if hasattr(att, "params")
|
||||
else None,
|
||||
"role": att.params.get("ROLE") if hasattr(att, "params") else None,
|
||||
}
|
||||
final_attendees.append(att_data)
|
||||
attendees.append(att_data)
|
||||
|
||||
# Add organizer
|
||||
organizer = event.get("ORGANIZER")
|
||||
if organizer:
|
||||
org_email = (
|
||||
str(organizer).replace("mailto:", "").replace("MAILTO:", "")
|
||||
if organizer
|
||||
else None
|
||||
)
|
||||
org_data: AttendeeData = {
|
||||
"email": org_email,
|
||||
"email": str(organizer).replace("mailto:", "") if organizer else None,
|
||||
"name": organizer.params.get("CN")
|
||||
if hasattr(organizer, "params")
|
||||
else None,
|
||||
"role": "ORGANIZER",
|
||||
}
|
||||
final_attendees.append(org_data)
|
||||
attendees.append(org_data)
|
||||
|
||||
return final_attendees
|
||||
return attendees
|
||||
|
||||
|
||||
class ICSSyncService:
|
||||
def __init__(self):
|
||||
self.fetch_service = ICSFetchService()
|
||||
|
||||
async def sync_room_calendar(self, room: Room) -> SyncResult:
|
||||
async with RedisAsyncLock(
|
||||
f"ics_sync_room:{room.id}", skip_if_locked=True
|
||||
) as lock:
|
||||
if not lock.acquired:
|
||||
logger.warning("ICS sync already in progress for room", room_id=room.id)
|
||||
return {
|
||||
"status": SyncStatus.SKIPPED,
|
||||
"reason": "Sync already in progress",
|
||||
}
|
||||
|
||||
return await self._sync_room_calendar(room)
|
||||
|
||||
async def _sync_room_calendar(self, room: Room) -> SyncResult:
|
||||
async def sync_room_calendar(self, room: Room) -> dict:
|
||||
if not room.ics_enabled or not room.ics_url:
|
||||
return {"status": SyncStatus.SKIPPED, "reason": "ICS not configured"}
|
||||
return {"status": "skipped", "reason": "ICS not configured"}
|
||||
|
||||
try:
|
||||
# Check if it's time to sync
|
||||
if not self._should_sync(room):
|
||||
return {"status": SyncStatus.SKIPPED, "reason": "Not time to sync yet"}
|
||||
return {"status": "skipped", "reason": "Not time to sync yet"}
|
||||
|
||||
# Fetch ICS file
|
||||
ics_content = await self.fetch_service.fetch_ics(room.ics_url)
|
||||
calendar = self.fetch_service.parse_ics(ics_content)
|
||||
|
||||
# Check if content changed
|
||||
content_hash = hashlib.md5(ics_content.encode()).hexdigest()
|
||||
if room.ics_last_etag == content_hash:
|
||||
logger.info("No changes in ICS for room", room_id=room.id)
|
||||
room_url = f"{settings.UI_BASE_URL}/{room.name}"
|
||||
events, total_events = self.fetch_service.extract_room_events(
|
||||
calendar, room.name, room_url
|
||||
)
|
||||
return {
|
||||
"status": SyncStatus.UNCHANGED,
|
||||
"hash": content_hash,
|
||||
"events_found": len(events),
|
||||
"total_events": total_events,
|
||||
"events_created": 0,
|
||||
"events_updated": 0,
|
||||
"events_deleted": 0,
|
||||
}
|
||||
logger.info(f"No changes in ICS for room {room.id}")
|
||||
return {"status": "unchanged", "hash": content_hash}
|
||||
|
||||
# Parse calendar
|
||||
calendar = self.fetch_service.parse_ics(ics_content)
|
||||
|
||||
# Build room URL
|
||||
room_url = f"{settings.BASE_URL}/room/{room.name}"
|
||||
|
||||
# Extract matching events
|
||||
room_url = f"{settings.UI_BASE_URL}/{room.name}"
|
||||
events, total_events = self.fetch_service.extract_room_events(
|
||||
events = self.fetch_service.extract_room_events(
|
||||
calendar, room.name, room_url
|
||||
)
|
||||
|
||||
# Sync events to database
|
||||
sync_result = await self._sync_events_to_database(room.id, events)
|
||||
|
||||
# Update room sync metadata
|
||||
@@ -353,16 +237,15 @@ class ICSSyncService:
|
||||
)
|
||||
|
||||
return {
|
||||
"status": SyncStatus.SUCCESS,
|
||||
"status": "success",
|
||||
"hash": content_hash,
|
||||
"events_found": len(events),
|
||||
"total_events": total_events,
|
||||
**sync_result,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to sync ICS for room", room_id=room.id, error=str(e))
|
||||
return {"status": SyncStatus.ERROR, "error": str(e)}
|
||||
logger.error(f"Failed to sync ICS for room {room.id}: {e}")
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def _should_sync(self, room: Room) -> bool:
|
||||
if not room.ics_last_sync:
|
||||
@@ -377,10 +260,14 @@ class ICSSyncService:
|
||||
created = 0
|
||||
updated = 0
|
||||
|
||||
# Track current event IDs
|
||||
current_ics_uids = []
|
||||
|
||||
for event_data in events:
|
||||
# Create CalendarEvent object
|
||||
calendar_event = CalendarEvent(room_id=room_id, **event_data)
|
||||
|
||||
# Upsert event
|
||||
existing = await calendar_events_controller.get_by_ics_uid(
|
||||
room_id, event_data["ics_uid"]
|
||||
)
|
||||
@@ -405,4 +292,5 @@ class ICSSyncService:
|
||||
}
|
||||
|
||||
|
||||
# Global instance
|
||||
ics_sync_service = ICSSyncService()
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from pydantic.types import PositiveInt
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from reflector.utils.string import NonEmptyString
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
@@ -24,16 +21,11 @@ class Settings(BaseSettings):
|
||||
# local data directory
|
||||
DATA_DIR: str = "./data"
|
||||
|
||||
# Audio Chunking
|
||||
# backends: silero, frames
|
||||
AUDIO_CHUNKER_BACKEND: str = "frames"
|
||||
|
||||
# Audio Transcription
|
||||
# backends: whisper, modal
|
||||
TRANSCRIPT_BACKEND: str = "whisper"
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
TRANSCRIPT_FILE_TIMEOUT: int = 600
|
||||
|
||||
# Audio Transcription: modal backend
|
||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||
@@ -74,14 +66,10 @@ class Settings(BaseSettings):
|
||||
DIARIZATION_ENABLED: bool = True
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
DIARIZATION_FILE_TIMEOUT: int = 600
|
||||
|
||||
# Diarization: modal backend
|
||||
DIARIZATION_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Diarization: local pyannote.audio
|
||||
DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None
|
||||
|
||||
# Sentry
|
||||
SENTRY_DSN: str | None = None
|
||||
|
||||
@@ -93,8 +81,9 @@ class Settings(BaseSettings):
|
||||
AUTH_JWT_PUBLIC_KEY: str | None = "authentik.monadical.com_public.pem"
|
||||
AUTH_JWT_AUDIENCE: str | None = None
|
||||
|
||||
# API public mode
|
||||
# if set, all anonymous record will be public
|
||||
PUBLIC_MODE: bool = False
|
||||
PUBLIC_DATA_RETENTION_DAYS: PositiveInt = 7
|
||||
|
||||
# Min transcript length to generate topic + summary
|
||||
MIN_TRANSCRIPT_LENGTH: int = 750
|
||||
@@ -122,7 +111,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# Whereby integration
|
||||
WHEREBY_API_URL: str = "https://api.whereby.dev/v1"
|
||||
WHEREBY_API_KEY: NonEmptyString | None = None
|
||||
WHEREBY_API_KEY: str | None = None
|
||||
WHEREBY_WEBHOOK_SECRET: str | None = None
|
||||
AWS_WHEREBY_ACCESS_KEY_ID: str | None = None
|
||||
AWS_WHEREBY_ACCESS_KEY_SECRET: str | None = None
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Manual cleanup tool for old public data.
|
||||
Uses the same implementation as the Celery worker task.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import structlog
|
||||
|
||||
from reflector.settings import settings
|
||||
from reflector.worker.cleanup import _cleanup_old_public_data
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
async def cleanup_old_data(days: int = 7):
|
||||
logger.info(
|
||||
"Starting manual cleanup",
|
||||
retention_days=days,
|
||||
public_mode=settings.PUBLIC_MODE,
|
||||
)
|
||||
|
||||
if not settings.PUBLIC_MODE:
|
||||
logger.critical(
|
||||
"WARNING: PUBLIC_MODE is False. "
|
||||
"This tool is intended for public instances only."
|
||||
)
|
||||
raise Exception("Tool intended for public instances only")
|
||||
|
||||
result = await _cleanup_old_public_data(days=days)
|
||||
|
||||
if result:
|
||||
logger.info(
|
||||
"Cleanup completed",
|
||||
transcripts_deleted=result.get("transcripts_deleted", 0),
|
||||
meetings_deleted=result.get("meetings_deleted", 0),
|
||||
recordings_deleted=result.get("recordings_deleted", 0),
|
||||
errors_count=len(result.get("errors", [])),
|
||||
)
|
||||
if result.get("errors"):
|
||||
logger.warning(
|
||||
"Errors encountered during cleanup:", errors=result["errors"][:10]
|
||||
)
|
||||
else:
|
||||
logger.info("Cleanup skipped or completed without results")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Clean up old transcripts and meetings"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--days",
|
||||
type=int,
|
||||
default=7,
|
||||
help="Number of days to keep data (default: 7)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.days < 1:
|
||||
logger.error("Days must be at least 1")
|
||||
sys.exit(1)
|
||||
|
||||
asyncio.run(cleanup_old_data(days=args.days))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,220 +1,105 @@
|
||||
"""
|
||||
Process audio file with diarization support
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal
|
||||
|
||||
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
||||
import av
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.main_file_pipeline import (
|
||||
task_pipeline_file_process as task_pipeline_file_process,
|
||||
)
|
||||
from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipeline_post
|
||||
from reflector.pipelines.main_live_pipeline import (
|
||||
pipeline_process as live_pipeline_process,
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor
|
||||
|
||||
|
||||
def serialize_topics(topics: List[TranscriptTopic]) -> List[Dict[str, Any]]:
|
||||
"""Convert TranscriptTopic objects to JSON-serializable dicts"""
|
||||
serialized = []
|
||||
for topic in topics:
|
||||
topic_dict = topic.model_dump()
|
||||
serialized.append(topic_dict)
|
||||
return serialized
|
||||
|
||||
|
||||
def debug_print_speakers(serialized_topics: List[Dict[str, Any]]) -> None:
|
||||
"""Print debug info about speakers found in topics"""
|
||||
all_speakers = set()
|
||||
for topic_dict in serialized_topics:
|
||||
for word in topic_dict.get("words", []):
|
||||
all_speakers.add(word.get("speaker", 0))
|
||||
|
||||
print(
|
||||
f"Found {len(serialized_topics)} topics with speakers: {all_speakers}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
TranscriptId = str
|
||||
|
||||
|
||||
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
|
||||
# ideally we want to get rid of it at some point
|
||||
async def prepare_entry(
|
||||
source_path: str,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> TranscriptId:
|
||||
file_path = Path(source_path)
|
||||
|
||||
transcript = await transcripts_controller.add(
|
||||
file_path.name,
|
||||
# note that the real file upload has SourceKind: LIVE for the reason of it's an error
|
||||
source_kind=SourceKind.FILE,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created empty transcript {transcript.id} for file {file_path.name} because technically we need an empty transcript before we start transcript"
|
||||
)
|
||||
|
||||
# pipelines expect files as upload.*
|
||||
|
||||
extension = file_path.suffix
|
||||
upload_path = transcript.data_path / f"upload{extension}"
|
||||
upload_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(source_path, upload_path)
|
||||
logger.info(f"Copied {source_path} to {upload_path}")
|
||||
|
||||
# pipelines expect entity status "uploaded"
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
|
||||
return transcript.id
|
||||
|
||||
|
||||
# same reason as prepare_entry
|
||||
async def extract_result_from_entry(
|
||||
transcript_id: TranscriptId, output_path: str
|
||||
) -> None:
|
||||
post_final_transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
|
||||
# assert post_final_transcript.status == "ended"
|
||||
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
|
||||
topics = post_final_transcript.topics
|
||||
if not topics:
|
||||
raise RuntimeError(
|
||||
f"No topics found for transcript {transcript_id} after processing"
|
||||
)
|
||||
|
||||
serialized_topics = serialize_topics(topics)
|
||||
|
||||
if output_path:
|
||||
# Write to JSON file
|
||||
with open(output_path, "w") as f:
|
||||
for topic_dict in serialized_topics:
|
||||
json.dump(topic_dict, f)
|
||||
f.write("\n")
|
||||
print(f"Results written to {output_path}", file=sys.stderr)
|
||||
else:
|
||||
# Write to stdout as JSONL
|
||||
for topic_dict in serialized_topics:
|
||||
print(json.dumps(topic_dict))
|
||||
|
||||
debug_print_speakers(serialized_topics)
|
||||
|
||||
|
||||
async def process_live_pipeline(
|
||||
transcript_id: TranscriptId,
|
||||
async def process_audio_file(
|
||||
filename,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
):
|
||||
"""Process transcript_id with transcription and diarization"""
|
||||
# build pipeline for audio processing
|
||||
processors = [
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||
]
|
||||
if not only_transcript:
|
||||
processors += [
|
||||
TranscriptTopicDetectorProcessor.as_threaded(),
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(),
|
||||
TranscriptFinalSummaryProcessor.as_threaded(),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
print(f"Processing transcript_id {transcript_id}...", file=sys.stderr)
|
||||
await live_pipeline_process(transcript_id=transcript_id)
|
||||
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
|
||||
|
||||
pre_final_transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
|
||||
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
|
||||
assert pre_final_transcript.status != "ended"
|
||||
|
||||
# at this point, diarization is running but we have no access to it. run diarization in parallel - one will hopefully win after polling
|
||||
result = live_pipeline_post(transcript_id=transcript_id)
|
||||
|
||||
# result.ready() blocks even without await; it mutates result also
|
||||
while not result.ready():
|
||||
print(f"Status: {result.state}")
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
async def process_file_pipeline(
|
||||
transcript_id: TranscriptId,
|
||||
):
|
||||
"""Process audio/video file using the optimized file pipeline"""
|
||||
|
||||
# task_pipeline_file_process is a Celery task, need to use .delay() for async execution
|
||||
result = task_pipeline_file_process.delay(transcript_id=transcript_id)
|
||||
|
||||
# Wait for the Celery task to complete
|
||||
while not result.ready():
|
||||
print(f"File pipeline status: {result.state}", file=sys.stderr)
|
||||
time.sleep(2)
|
||||
|
||||
logger.info("File pipeline processing complete")
|
||||
|
||||
|
||||
async def process(
|
||||
source_path: str,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
pipeline: Literal["live", "file"],
|
||||
output_path: str = None,
|
||||
):
|
||||
from reflector.db import get_database
|
||||
|
||||
database = get_database()
|
||||
# db connect is a part of ceremony
|
||||
await database.connect()
|
||||
# transcription output
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.set_pref("audio:source_language", source_language)
|
||||
pipeline.set_pref("audio:target_language", target_language)
|
||||
pipeline.describe()
|
||||
pipeline.on(event_callback)
|
||||
|
||||
# start processing audio
|
||||
logger.info(f"Opening {filename}")
|
||||
container = av.open(filename)
|
||||
try:
|
||||
transcript_id = await prepare_entry(
|
||||
source_path,
|
||||
source_language,
|
||||
target_language,
|
||||
)
|
||||
|
||||
pipeline_handlers = {
|
||||
"live": process_live_pipeline,
|
||||
"file": process_file_pipeline,
|
||||
}
|
||||
|
||||
handler = pipeline_handlers.get(pipeline)
|
||||
if not handler:
|
||||
raise ValueError(f"Unknown pipeline type: {pipeline}")
|
||||
|
||||
await handler(transcript_id)
|
||||
|
||||
await extract_result_from_entry(transcript_id, output_path)
|
||||
logger.info("Start pushing audio into the pipeline")
|
||||
for frame in container.decode(audio=0):
|
||||
await pipeline.push(frame)
|
||||
finally:
|
||||
await database.disconnect()
|
||||
logger.info("Flushing the pipeline")
|
||||
await pipeline.flush()
|
||||
|
||||
logger.info("All done !")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Process audio files with speaker diarization"
|
||||
)
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
||||
parser.add_argument(
|
||||
"--pipeline",
|
||||
required=True,
|
||||
choices=["live", "file"],
|
||||
help="Pipeline type to use for processing (live: streaming/incremental, file: batch/parallel)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-language", default="en", help="Source language code (default: en)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-language", default="en", help="Target language code (default: en)"
|
||||
)
|
||||
parser.add_argument("--only-transcript", "-t", action="store_true")
|
||||
parser.add_argument("--source-language", default="en")
|
||||
parser.add_argument("--target-language", default="en")
|
||||
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
|
||||
args = parser.parse_args()
|
||||
|
||||
output_fd = None
|
||||
if args.output:
|
||||
output_fd = open(args.output, "w")
|
||||
|
||||
async def event_callback(event: PipelineEvent):
|
||||
processor = event.processor
|
||||
# ignore some processor
|
||||
if processor in ("AudioChunkerProcessor", "AudioMergeProcessor"):
|
||||
return
|
||||
logger.info(f"Event: {event}")
|
||||
if output_fd:
|
||||
output_fd.write(event.model_dump_json())
|
||||
output_fd.write("\n")
|
||||
|
||||
asyncio.run(
|
||||
process(
|
||||
process_audio_file(
|
||||
args.source,
|
||||
args.source_language,
|
||||
args.target_language,
|
||||
args.pipeline,
|
||||
args.output,
|
||||
event_callback,
|
||||
only_transcript=args.only_transcript,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
)
|
||||
)
|
||||
|
||||
if output_fd:
|
||||
output_fd.close()
|
||||
logger.info(f"Output written to {args.output}")
|
||||
|
||||
315
server/reflector/tools/process_with_diarization.py
Normal file
315
server/reflector/tools/process_with_diarization.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""
|
||||
@vibe-generated
|
||||
Process audio file with diarization support
|
||||
===========================================
|
||||
|
||||
Extended version of process.py that includes speaker diarization.
|
||||
This tool processes audio files locally without requiring the full server infrastructure.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import av
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor, Processor
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
TitleSummary,
|
||||
TitleSummaryWithId,
|
||||
)
|
||||
|
||||
|
||||
class TopicCollectorProcessor(Processor):
|
||||
"""Collect topics for diarization"""
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.topics: List[TitleSummaryWithId] = []
|
||||
self._topic_id = 0
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
# Convert to TitleSummaryWithId and collect
|
||||
self._topic_id += 1
|
||||
topic_with_id = TitleSummaryWithId(
|
||||
id=str(self._topic_id),
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
timestamp=data.timestamp,
|
||||
duration=data.duration,
|
||||
transcript=data.transcript,
|
||||
)
|
||||
self.topics.append(topic_with_id)
|
||||
|
||||
# Pass through the original topic
|
||||
await self.emit(data)
|
||||
|
||||
def get_topics(self) -> List[TitleSummaryWithId]:
|
||||
return self.topics
|
||||
|
||||
|
||||
async def process_audio_file_with_diarization(
|
||||
filename,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
):
|
||||
# Create temp file for audio if diarization is enabled
|
||||
audio_temp_path = None
|
||||
if enable_diarization:
|
||||
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
audio_temp_path = audio_temp_file.name
|
||||
audio_temp_file.close()
|
||||
|
||||
# Create processor for collecting topics
|
||||
topic_collector = TopicCollectorProcessor()
|
||||
|
||||
# Build pipeline for audio processing
|
||||
processors = []
|
||||
|
||||
# Add audio file writer at the beginning if diarization is enabled
|
||||
if enable_diarization:
|
||||
processors.append(AudioFileWriterProcessor(audio_temp_path))
|
||||
|
||||
# Add the rest of the processors
|
||||
processors += [
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
]
|
||||
|
||||
processors += [
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||
]
|
||||
|
||||
if not only_transcript:
|
||||
processors += [
|
||||
TranscriptTopicDetectorProcessor.as_threaded(),
|
||||
# Collect topics for diarization
|
||||
topic_collector,
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(),
|
||||
TranscriptFinalSummaryProcessor.as_threaded(),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Create main pipeline
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.set_pref("audio:source_language", source_language)
|
||||
pipeline.set_pref("audio:target_language", target_language)
|
||||
pipeline.describe()
|
||||
pipeline.on(event_callback)
|
||||
|
||||
# Start processing audio
|
||||
logger.info(f"Opening {filename}")
|
||||
container = av.open(filename)
|
||||
try:
|
||||
logger.info("Start pushing audio into the pipeline")
|
||||
for frame in container.decode(audio=0):
|
||||
await pipeline.push(frame)
|
||||
finally:
|
||||
logger.info("Flushing the pipeline")
|
||||
await pipeline.flush()
|
||||
|
||||
# Run diarization if enabled and we have topics
|
||||
if enable_diarization and not only_transcript and audio_temp_path:
|
||||
topics = topic_collector.get_topics()
|
||||
|
||||
if topics:
|
||||
logger.info(f"Starting diarization with {len(topics)} topics")
|
||||
|
||||
try:
|
||||
from reflector.processors import AudioDiarizationAutoProcessor
|
||||
|
||||
diarization_processor = AudioDiarizationAutoProcessor(
|
||||
name=diarization_backend
|
||||
)
|
||||
|
||||
diarization_processor.set_pipeline(pipeline)
|
||||
|
||||
# For Modal backend, we need to upload the file to S3 first
|
||||
if diarization_backend == "modal":
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.utils.s3_temp_file import S3TemporaryFile
|
||||
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
# Generate a unique filename in evaluation folder
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
||||
|
||||
# Use context manager for automatic cleanup
|
||||
async with S3TemporaryFile(storage, audio_filename) as s3_file:
|
||||
# Read and upload the audio file
|
||||
with open(audio_temp_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
audio_url = await s3_file.upload(audio_data)
|
||||
logger.info(f"Uploaded audio to S3: {audio_filename}")
|
||||
|
||||
# Create diarization input with S3 URL
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
# File will be automatically cleaned up when exiting the context
|
||||
else:
|
||||
# For local backend, use local file path
|
||||
audio_url = audio_temp_path
|
||||
|
||||
# Create diarization input
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import diarization dependencies: {e}")
|
||||
logger.error(
|
||||
"Install with: uv pip install pyannote.audio torch torchaudio"
|
||||
)
|
||||
logger.error(
|
||||
"And set HF_TOKEN environment variable for pyannote models"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Diarization failed: {e}")
|
||||
raise SystemExit(1)
|
||||
else:
|
||||
logger.warning("Skipping diarization: no topics available")
|
||||
|
||||
# Clean up temp file
|
||||
if audio_temp_path:
|
||||
try:
|
||||
Path(audio_temp_path).unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
|
||||
|
||||
logger.info("All done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import os
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Process audio files with optional speaker diarization"
|
||||
)
|
||||
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
||||
parser.add_argument(
|
||||
"--only-transcript",
|
||||
"-t",
|
||||
action="store_true",
|
||||
help="Only generate transcript without topics/summaries",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-language", default="en", help="Source language code (default: en)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-language", default="en", help="Target language code (default: en)"
|
||||
)
|
||||
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
|
||||
parser.add_argument(
|
||||
"--enable-diarization",
|
||||
"-d",
|
||||
action="store_true",
|
||||
help="Enable speaker diarization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
default="modal",
|
||||
choices=["modal"],
|
||||
help="Diarization backend to use (default: modal)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set REDIS_HOST to localhost if not provided
|
||||
if "REDIS_HOST" not in os.environ:
|
||||
os.environ["REDIS_HOST"] = "localhost"
|
||||
logger.info("REDIS_HOST not set, defaulting to localhost")
|
||||
|
||||
output_fd = None
|
||||
if args.output:
|
||||
output_fd = open(args.output, "w")
|
||||
|
||||
async def event_callback(event: PipelineEvent):
|
||||
processor = event.processor
|
||||
data = event.data
|
||||
|
||||
# Ignore internal processors
|
||||
if processor in (
|
||||
"AudioChunkerProcessor",
|
||||
"AudioMergeProcessor",
|
||||
"AudioFileWriterProcessor",
|
||||
"TopicCollectorProcessor",
|
||||
"BroadcastProcessor",
|
||||
):
|
||||
return
|
||||
|
||||
# If diarization is enabled, skip the original topic events from the pipeline
|
||||
# The diarization processor will emit the same topics but with speaker info
|
||||
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
|
||||
return
|
||||
|
||||
# Log all events
|
||||
logger.info(f"Event: {processor} - {type(data).__name__}")
|
||||
|
||||
# Write to output
|
||||
if output_fd:
|
||||
output_fd.write(event.model_dump_json())
|
||||
output_fd.write("\n")
|
||||
output_fd.flush()
|
||||
|
||||
asyncio.run(
|
||||
process_audio_file_with_diarization(
|
||||
args.source,
|
||||
event_callback,
|
||||
only_transcript=args.only_transcript,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
enable_diarization=args.enable_diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
)
|
||||
)
|
||||
|
||||
if output_fd:
|
||||
output_fd.close()
|
||||
logger.info(f"Output written to {args.output}")
|
||||
@@ -53,7 +53,7 @@ async def run_single_processor(args):
|
||||
async def event_callback(event: PipelineEvent):
|
||||
processor = event.processor
|
||||
# ignore some processor
|
||||
if processor in ("AudioChunkerAutoProcessor", "AudioMergeProcessor"):
|
||||
if processor in ("AudioChunkerProcessor", "AudioMergeProcessor"):
|
||||
return
|
||||
print(f"Event: {event}")
|
||||
if output_fd:
|
||||
|
||||
96
server/reflector/tools/test_diarization.py
Normal file
96
server/reflector/tools/test_diarization.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
@vibe-generated
|
||||
Test script for the diarization CLI tool
|
||||
=========================================
|
||||
|
||||
This script helps test the diarization functionality with sample audio files.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
|
||||
async def test_diarization(audio_file: str):
|
||||
"""Test the diarization functionality"""
|
||||
|
||||
# Import the processing function
|
||||
from process_with_diarization import process_audio_file_with_diarization
|
||||
|
||||
# Collect events
|
||||
events = []
|
||||
|
||||
async def event_callback(event):
|
||||
events.append({"processor": event.processor, "data": event.data})
|
||||
logger.info(f"Event from {event.processor}")
|
||||
|
||||
# Process the audio file
|
||||
logger.info(f"Processing audio file: {audio_file}")
|
||||
|
||||
try:
|
||||
await process_audio_file_with_diarization(
|
||||
audio_file,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
)
|
||||
|
||||
# Analyze results
|
||||
logger.info(f"Processing complete. Received {len(events)} events")
|
||||
|
||||
# Look for diarization results
|
||||
diarized_topics = []
|
||||
for event in events:
|
||||
if "TitleSummary" in event["processor"]:
|
||||
# Check if words have speaker information
|
||||
if hasattr(event["data"], "transcript") and event["data"].transcript:
|
||||
words = event["data"].transcript.words
|
||||
if words and hasattr(words[0], "speaker"):
|
||||
speakers = set(
|
||||
w.speaker for w in words if hasattr(w, "speaker")
|
||||
)
|
||||
logger.info(
|
||||
f"Found {len(speakers)} speakers in topic: {event['data'].title}"
|
||||
)
|
||||
diarized_topics.append(event["data"])
|
||||
|
||||
if diarized_topics:
|
||||
logger.info(f"Successfully diarized {len(diarized_topics)} topics")
|
||||
|
||||
# Print sample output
|
||||
sample_topic = diarized_topics[0]
|
||||
logger.info("Sample diarized output:")
|
||||
for i, word in enumerate(sample_topic.transcript.words[:10]):
|
||||
logger.info(f" Word {i}: '{word.text}' - Speaker {word.speaker}")
|
||||
else:
|
||||
logger.warning("No diarization results found in output")
|
||||
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during processing: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python test_diarization.py <audio_file>")
|
||||
sys.exit(1)
|
||||
|
||||
audio_file = sys.argv[1]
|
||||
if not Path(audio_file).exists():
|
||||
print(f"Error: Audio file '{audio_file}' not found")
|
||||
sys.exit(1)
|
||||
|
||||
# Run the test
|
||||
asyncio.run(test_diarization(audio_file))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,23 +0,0 @@
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field, TypeAdapter, constr
|
||||
|
||||
NonEmptyStringBase = constr(min_length=1, strip_whitespace=False)
|
||||
NonEmptyString = Annotated[
|
||||
NonEmptyStringBase,
|
||||
Field(description="A non-empty string", min_length=1),
|
||||
]
|
||||
non_empty_string_adapter = TypeAdapter(NonEmptyString)
|
||||
|
||||
|
||||
def parse_non_empty_string(s: str, error: str | None = None) -> NonEmptyString:
|
||||
try:
|
||||
return non_empty_string_adapter.validate_python(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"{e}: {error}" if error else e) from e
|
||||
|
||||
|
||||
def try_parse_non_empty_string(s: str) -> NonEmptyString | None:
|
||||
if not s:
|
||||
return None
|
||||
return parse_non_empty_string(s)
|
||||
@@ -10,7 +10,6 @@ from reflector.db.meetings import (
|
||||
meeting_consent_controller,
|
||||
meetings_controller,
|
||||
)
|
||||
from reflector.db.rooms import rooms_controller
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -42,34 +41,3 @@ async def meeting_audio_consent(
|
||||
updated_consent = await meeting_consent_controller.upsert(consent)
|
||||
|
||||
return {"status": "success", "consent_id": updated_consent.id}
|
||||
|
||||
|
||||
@router.patch("/meetings/{meeting_id}/deactivate")
|
||||
async def meeting_deactivate(
|
||||
meeting_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
meeting = await meetings_controller.get_by_id(meeting_id)
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
if not meeting.is_active:
|
||||
return {"status": "success", "meeting_id": meeting_id}
|
||||
|
||||
# Only room owner or meeting creator can deactivate
|
||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
if user_id != room.user_id and user_id != meeting.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Only the room owner can deactivate meetings"
|
||||
)
|
||||
|
||||
await meetings_controller.update_meeting(meeting_id, is_active=False)
|
||||
|
||||
return {"status": "success", "meeting_id": meeting_id}
|
||||
|
||||
@@ -1,27 +1,33 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
import asyncpg.exceptions
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.databases import apaginate
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import LockError
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_database
|
||||
from reflector.db.calendar_events import calendar_events_controller
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.redis_cache import RedisAsyncLock
|
||||
from reflector.services.ics_sync import ics_sync_service
|
||||
from reflector.settings import settings
|
||||
from reflector.whereby import create_meeting, upload_logo
|
||||
from reflector.worker.webhook import test_webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def parse_datetime_with_timezone(iso_string: str) -> datetime:
|
||||
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
|
||||
dt = datetime.fromisoformat(iso_string)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
class Room(BaseModel):
|
||||
id: str
|
||||
@@ -43,31 +49,14 @@ class Room(BaseModel):
|
||||
ics_last_etag: Optional[str] = None
|
||||
|
||||
|
||||
class RoomDetails(Room):
|
||||
webhook_url: str | None
|
||||
webhook_secret: str | None
|
||||
|
||||
|
||||
class Meeting(BaseModel):
|
||||
id: str
|
||||
room_name: str
|
||||
room_url: str
|
||||
# TODO it's not always present, | None
|
||||
host_room_url: str
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
is_locked: bool = False
|
||||
room_mode: Literal["normal", "group"] = "normal"
|
||||
recording_type: Literal["none", "local", "cloud"] = "cloud"
|
||||
recording_trigger: Literal[
|
||||
"none", "prompt", "automatic", "automatic-2nd-participant"
|
||||
] = "automatic-2nd-participant"
|
||||
num_clients: int = 0
|
||||
is_active: bool = True
|
||||
calendar_event_id: str | None = None
|
||||
calendar_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CreateRoom(BaseModel):
|
||||
@@ -80,8 +69,6 @@ class CreateRoom(BaseModel):
|
||||
recording_type: str
|
||||
recording_trigger: str
|
||||
is_shared: bool
|
||||
webhook_url: str
|
||||
webhook_secret: str
|
||||
ics_url: Optional[str] = None
|
||||
ics_fetch_interval: int = 300
|
||||
ics_enabled: bool = False
|
||||
@@ -97,86 +84,19 @@ class UpdateRoom(BaseModel):
|
||||
recording_type: Optional[str] = None
|
||||
recording_trigger: Optional[str] = None
|
||||
is_shared: Optional[bool] = None
|
||||
webhook_url: Optional[str] = None
|
||||
webhook_secret: Optional[str] = None
|
||||
ics_url: Optional[str] = None
|
||||
ics_fetch_interval: Optional[int] = None
|
||||
ics_enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class CreateRoomMeeting(BaseModel):
|
||||
allow_duplicated: Optional[bool] = False
|
||||
|
||||
|
||||
class DeletionStatus(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class WebhookTestResult(BaseModel):
|
||||
success: bool
|
||||
message: str = ""
|
||||
error: str = ""
|
||||
status_code: int | None = None
|
||||
response_preview: str | None = None
|
||||
|
||||
|
||||
class ICSStatus(BaseModel):
|
||||
status: Literal["enabled", "disabled"]
|
||||
last_sync: Optional[datetime] = None
|
||||
next_sync: Optional[datetime] = None
|
||||
last_etag: Optional[str] = None
|
||||
events_count: int = 0
|
||||
|
||||
|
||||
class SyncStatus(str, Enum):
|
||||
success = "success"
|
||||
unchanged = "unchanged"
|
||||
error = "error"
|
||||
skipped = "skipped"
|
||||
|
||||
|
||||
class ICSSyncResult(BaseModel):
|
||||
status: SyncStatus
|
||||
hash: Optional[str] = None
|
||||
events_found: int = 0
|
||||
total_events: int = 0
|
||||
events_created: int = 0
|
||||
events_updated: int = 0
|
||||
events_deleted: int = 0
|
||||
error: Optional[str] = None
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class CalendarEventResponse(BaseModel):
|
||||
id: str
|
||||
room_id: str
|
||||
ics_uid: str
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
attendees: Optional[list[dict]] = None
|
||||
location: Optional[str] = None
|
||||
last_synced: datetime
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def parse_datetime_with_timezone(iso_string: str) -> datetime:
|
||||
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
|
||||
dt = datetime.fromisoformat(iso_string)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
@router.get("/rooms", response_model=Page[RoomDetails])
|
||||
@router.get("/rooms", response_model=Page[Room])
|
||||
async def rooms_list(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
) -> list[RoomDetails]:
|
||||
) -> list[Room]:
|
||||
if not user and not settings.PUBLIC_MODE:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
@@ -190,42 +110,6 @@ async def rooms_list(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/rooms/{room_id}", response_model=RoomDetails)
|
||||
async def rooms_get(
|
||||
room_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
return room
|
||||
|
||||
|
||||
@router.get("/rooms/name/{room_name}", response_model=RoomDetails)
|
||||
async def rooms_get_by_name(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
# Convert to RoomDetails format (add webhook fields if user is owner)
|
||||
room_dict = room.__dict__.copy()
|
||||
if user_id == room.user_id:
|
||||
# User is owner, include webhook details if available
|
||||
room_dict["webhook_url"] = getattr(room, "webhook_url", None)
|
||||
room_dict["webhook_secret"] = getattr(room, "webhook_secret", None)
|
||||
else:
|
||||
# Non-owner, hide webhook details
|
||||
room_dict["webhook_url"] = None
|
||||
room_dict["webhook_secret"] = None
|
||||
|
||||
return RoomDetails(**room_dict)
|
||||
|
||||
|
||||
@router.post("/rooms", response_model=Room)
|
||||
async def rooms_create(
|
||||
room: CreateRoom,
|
||||
@@ -244,15 +128,13 @@ async def rooms_create(
|
||||
recording_type=room.recording_type,
|
||||
recording_trigger=room.recording_trigger,
|
||||
is_shared=room.is_shared,
|
||||
webhook_url=room.webhook_url,
|
||||
webhook_secret=room.webhook_secret,
|
||||
ics_url=room.ics_url,
|
||||
ics_fetch_interval=room.ics_fetch_interval,
|
||||
ics_enabled=room.ics_enabled,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/rooms/{room_id}", response_model=RoomDetails)
|
||||
@router.patch("/rooms/{room_id}", response_model=Room)
|
||||
async def rooms_update(
|
||||
room_id: str,
|
||||
info: UpdateRoom,
|
||||
@@ -283,7 +165,6 @@ async def rooms_delete(
|
||||
@router.post("/rooms/{room_name}/meeting", response_model=Meeting)
|
||||
async def rooms_create_meeting(
|
||||
room_name: str,
|
||||
info: CreateRoomMeeting,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
@@ -291,44 +172,52 @@ async def rooms_create_meeting(
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
try:
|
||||
async with RedisAsyncLock(
|
||||
f"create_meeting:{room_name}",
|
||||
timeout=30,
|
||||
extend_interval=10,
|
||||
blocking_timeout=5.0,
|
||||
) as lock:
|
||||
current_time = datetime.now(timezone.utc)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
meeting = await meetings_controller.get_active(room=room, current_time=current_time)
|
||||
|
||||
meeting = None
|
||||
if not info.allow_duplicated:
|
||||
meeting = await meetings_controller.get_active(
|
||||
room=room, current_time=current_time
|
||||
)
|
||||
if meeting is None:
|
||||
end_date = current_time + timedelta(hours=8)
|
||||
|
||||
whereby_meeting = await create_meeting("", end_date=end_date, room=room)
|
||||
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
||||
|
||||
# Now try to save to database
|
||||
try:
|
||||
meeting = await meetings_controller.create(
|
||||
id=whereby_meeting["meetingId"],
|
||||
room_name=whereby_meeting["roomName"],
|
||||
room_url=whereby_meeting["roomUrl"],
|
||||
host_room_url=whereby_meeting["hostRoomUrl"],
|
||||
start_date=parse_datetime_with_timezone(whereby_meeting["startDate"]),
|
||||
end_date=parse_datetime_with_timezone(whereby_meeting["endDate"]),
|
||||
user_id=user_id,
|
||||
room=room,
|
||||
)
|
||||
except (asyncpg.exceptions.UniqueViolationError, sqlite3.IntegrityError):
|
||||
# Another request already created a meeting for this room
|
||||
# Log this race condition occurrence
|
||||
logger.info(
|
||||
"Race condition detected for room %s - fetching existing meeting",
|
||||
room.name,
|
||||
)
|
||||
logger.warning(
|
||||
"Whereby meeting %s was created but not used (resource leak) for room %s",
|
||||
whereby_meeting["meetingId"],
|
||||
room.name,
|
||||
)
|
||||
|
||||
# Fetch the meeting that was created by the other request
|
||||
meeting = await meetings_controller.get_active(
|
||||
room=room, current_time=current_time
|
||||
)
|
||||
if meeting is None:
|
||||
end_date = current_time + timedelta(hours=8)
|
||||
|
||||
whereby_meeting = await create_meeting("", end_date=end_date, room=room)
|
||||
|
||||
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
||||
|
||||
meeting = await meetings_controller.create(
|
||||
id=whereby_meeting["meetingId"],
|
||||
room_name=whereby_meeting["roomName"],
|
||||
room_url=whereby_meeting["roomUrl"],
|
||||
host_room_url=whereby_meeting["hostRoomUrl"],
|
||||
start_date=parse_datetime_with_timezone(
|
||||
whereby_meeting["startDate"]
|
||||
),
|
||||
end_date=parse_datetime_with_timezone(whereby_meeting["endDate"]),
|
||||
room=room,
|
||||
# Edge case: meeting was created but expired/deleted between checks
|
||||
logger.error(
|
||||
"Meeting disappeared after race condition for room %s", room.name
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Unable to join meeting - please try again"
|
||||
)
|
||||
except LockError:
|
||||
logger.warning("Failed to acquire lock for room %s within timeout", room_name)
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Meeting creation in progress, please try again"
|
||||
)
|
||||
|
||||
if user_id != room.user_id:
|
||||
meeting.host_room_url = ""
|
||||
@@ -336,25 +225,22 @@ async def rooms_create_meeting(
|
||||
return meeting
|
||||
|
||||
|
||||
@router.post("/rooms/{room_id}/webhook/test", response_model=WebhookTestResult)
|
||||
async def rooms_test_webhook(
|
||||
room_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
"""Test webhook configuration by sending a sample payload."""
|
||||
user_id = user["sub"] if user else None
|
||||
class ICSStatus(BaseModel):
|
||||
status: str
|
||||
last_sync: Optional[datetime] = None
|
||||
next_sync: Optional[datetime] = None
|
||||
last_etag: Optional[str] = None
|
||||
events_count: int = 0
|
||||
|
||||
room = await rooms_controller.get_by_id(room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
if user_id and room.user_id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Not authorized to test this room's webhook"
|
||||
)
|
||||
|
||||
result = await test_webhook(room_id)
|
||||
return WebhookTestResult(**result)
|
||||
class ICSSyncResult(BaseModel):
|
||||
status: str
|
||||
hash: Optional[str] = None
|
||||
events_found: int = 0
|
||||
events_created: int = 0
|
||||
events_updated: int = 0
|
||||
events_deleted: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/rooms/{room_name}/ics/sync", response_model=ICSSyncResult)
|
||||
@@ -376,6 +262,8 @@ async def rooms_sync_ics(
|
||||
if not room.ics_enabled or not room.ics_url:
|
||||
raise HTTPException(status_code=400, detail="ICS not configured for this room")
|
||||
|
||||
from reflector.services.ics_sync import ics_sync_service
|
||||
|
||||
result = await ics_sync_service.sync_room_calendar(room)
|
||||
|
||||
if result["status"] == "error":
|
||||
@@ -406,6 +294,8 @@ async def rooms_ics_status(
|
||||
if room.ics_enabled and room.ics_last_sync:
|
||||
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
|
||||
|
||||
from reflector.db.calendar_events import calendar_events_controller
|
||||
|
||||
events = await calendar_events_controller.get_by_room(
|
||||
room.id, include_deleted=False
|
||||
)
|
||||
@@ -419,6 +309,21 @@ async def rooms_ics_status(
|
||||
)
|
||||
|
||||
|
||||
class CalendarEventResponse(BaseModel):
|
||||
id: str
|
||||
room_id: str
|
||||
ics_uid: str
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
attendees: Optional[list[dict]] = None
|
||||
location: Optional[str] = None
|
||||
last_synced: datetime
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
@router.get("/rooms/{room_name}/meetings", response_model=list[CalendarEventResponse])
|
||||
async def rooms_list_meetings(
|
||||
room_name: str,
|
||||
@@ -430,6 +335,8 @@ async def rooms_list_meetings(
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
from reflector.db.calendar_events import calendar_events_controller
|
||||
|
||||
events = await calendar_events_controller.get_by_room(
|
||||
room.id, include_deleted=False
|
||||
)
|
||||
@@ -448,7 +355,7 @@ async def rooms_list_meetings(
|
||||
async def rooms_list_upcoming_meetings(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
minutes_ahead: int = 120,
|
||||
minutes_ahead: int = 30,
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
@@ -456,6 +363,8 @@ async def rooms_list_upcoming_meetings(
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
from reflector.db.calendar_events import calendar_events_controller
|
||||
|
||||
events = await calendar_events_controller.get_upcoming(
|
||||
room.id, minutes_ahead=minutes_ahead
|
||||
)
|
||||
@@ -473,6 +382,7 @@ async def rooms_list_active_meetings(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
"""List all active meetings for a room (supports multiple active meetings)"""
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
|
||||
@@ -492,40 +402,13 @@ async def rooms_list_active_meetings(
|
||||
return meetings
|
||||
|
||||
|
||||
@router.get("/rooms/{room_name}/meetings/{meeting_id}", response_model=Meeting)
|
||||
async def rooms_get_meeting(
|
||||
room_name: str,
|
||||
meeting_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
"""Get a single meeting by ID within a specific room."""
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
meeting = await meetings_controller.get_by_id(meeting_id)
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
if meeting.room_id != room.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Meeting does not belong to this room"
|
||||
)
|
||||
|
||||
if user_id != room.user_id and not room.is_shared:
|
||||
meeting.host_room_url = ""
|
||||
|
||||
return meeting
|
||||
|
||||
|
||||
@router.post("/rooms/{room_name}/meetings/{meeting_id}/join", response_model=Meeting)
|
||||
async def rooms_join_meeting(
|
||||
room_name: str,
|
||||
meeting_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
"""Join a specific meeting by ID"""
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user