mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Compare commits
83 Commits
mathieu/fi
...
mathieu/sq
| Author | SHA1 | Date | |
|---|---|---|---|
| b7f8e8ef8d | |||
| 27f19ec6ba | |||
| 2aa99fe846 | |||
| df909363f5 | |||
| ad2accb574 | |||
| a07c621bcd | |||
| f51dae8da3 | |||
| b217c7ba41 | |||
| 0b2152ea75 | |||
| e0c71c5548 | |||
| a883df0d63 | |||
| 1c9e8b9cde | |||
| 27b3b9cdee | |||
| 8ad1270229 | |||
| 617a1c8b32 | |||
| 60cc2b16ae | |||
| 606c5f5059 | |||
| 5e036d17b6 | |||
| 04a9c2f2f7 | |||
| fb5bb39716 | |||
| 4f70a7f593 | |||
| 224e40225d | |||
| 24980de4e0 | |||
| 7f178b5f9e | |||
| 0aaa42528a | |||
| 565a62900f | |||
|
|
27016e6051 | ||
| 6ddfee0b4e | |||
|
|
47716f6e5d | ||
| 1520f88e9e | |||
| 9b90aaa57f | |||
| d21b65e4e8 | |||
| 45d1608950 | |||
| 06639d4d8f | |||
| 0abcebfc94 | |||
|
|
2b723da08b | ||
| 6566e04300 | |||
| 870e860517 | |||
| 396a95d5ce | |||
| 6f680b5795 | |||
| ab859d65a6 | |||
| fa049e8d06 | |||
| 2ce7479967 | |||
| b42f7cfc60 | |||
| c546e69739 | |||
|
|
3f1fe8c9bf | ||
| 5f143fe364 | |||
|
|
79f161436e | ||
|
|
5cba5d310d | ||
| 43ea9349f5 | |||
|
|
b3a8e9739d | ||
|
|
369ecdff13 | ||
| fc363bd49b | |||
|
|
962038ee3f | ||
|
|
3b85ff3bdf | ||
|
|
cde99ca271 | ||
|
|
f81fe9948a | ||
|
|
5a5b323382 | ||
| 02a3938822 | |||
|
|
7f5a4c9ddc | ||
|
|
08d88ec349 | ||
|
|
c4d2825c81 | ||
| 0663700a61 | |||
| dc82f8bb3b | |||
| 457823e1c1 | |||
|
|
695d1a957d | ||
| ccffdba75b | |||
| 84a381220b | |||
| 5f2f0e9317 | |||
| 88ed7cfa78 | |||
| 6f0c7c1a5e | |||
| 9dfd76996f | |||
| 55cc8637c6 | |||
| f5331a2107 | |||
|
|
124ce03bf8 | ||
| 7030e0f236 | |||
| 37f0110892 | |||
| cf2896a7f4 | |||
| aabf2c2572 | |||
| 6a7b08f016 | |||
| e2736563d9 | |||
| 0f54b7782d | |||
| 359280dd34 |
5
.github/workflows/db_migrations.yml
vendored
5
.github/workflows/db_migrations.yml
vendored
@@ -2,6 +2,8 @@ name: Test Database Migrations
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "server/migrations/**"
|
||||
- "server/reflector/db/**"
|
||||
@@ -17,6 +19,9 @@ on:
|
||||
jobs:
|
||||
test-migrations:
|
||||
runs-on: ubuntu-latest
|
||||
concurrency:
|
||||
group: db-ubuntu-latest-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
|
||||
45
.github/workflows/test_next_server.yml
vendored
Normal file
45
.github/workflows/test_next_server.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
name: Test Next Server
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "www/**"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "www/**"
|
||||
|
||||
jobs:
|
||||
test-next-server:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./www
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 8
|
||||
|
||||
- name: Setup Node.js cache
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'pnpm'
|
||||
cache-dependency-path: './www/pnpm-lock.yaml'
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm test
|
||||
11
.github/workflows/test_server.yml
vendored
11
.github/workflows/test_server.yml
vendored
@@ -5,12 +5,17 @@ on:
|
||||
paths:
|
||||
- "server/**"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "server/**"
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
runs-on: ubuntu-latest
|
||||
concurrency:
|
||||
group: pytest-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
services:
|
||||
redis:
|
||||
image: redis:6
|
||||
@@ -30,6 +35,9 @@ jobs:
|
||||
|
||||
docker-amd64:
|
||||
runs-on: linux-amd64
|
||||
concurrency:
|
||||
group: docker-amd64-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Docker Buildx
|
||||
@@ -45,6 +53,9 @@ jobs:
|
||||
|
||||
docker-arm64:
|
||||
runs-on: linux-arm64
|
||||
concurrency:
|
||||
group: docker-arm64-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Docker Buildx
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -15,3 +15,6 @@ www/REFACTOR.md
|
||||
www/reload-frontend
|
||||
server/test.sqlite
|
||||
CLAUDE.local.md
|
||||
www/.env.development
|
||||
www/.env.production
|
||||
.playwright-mcp
|
||||
|
||||
1
.gitleaksignore
Normal file
1
.gitleaksignore
Normal file
@@ -0,0 +1 @@
|
||||
b9d891d3424f371642cb032ecfd0e2564470a72c:server/tests/test_transcripts_recording_deletion.py:generic-api-key:15
|
||||
@@ -27,3 +27,8 @@ repos:
|
||||
files: ^server/
|
||||
- id: ruff-format
|
||||
files: ^server/
|
||||
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.28.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
121
CHANGELOG.md
121
CHANGELOG.md
@@ -1,5 +1,126 @@
|
||||
# Changelog
|
||||
|
||||
## [0.13.1](https://github.com/Monadical-SAS/reflector/compare/v0.13.0...v0.13.1) (2025-09-22)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* TypeError on not all arguments converted during string formatting in logger ([#667](https://github.com/Monadical-SAS/reflector/issues/667)) ([565a629](https://github.com/Monadical-SAS/reflector/commit/565a62900f5a02fc946b68f9269a42190ed70ab6))
|
||||
|
||||
## [0.13.0](https://github.com/Monadical-SAS/reflector/compare/v0.12.1...v0.13.0) (2025-09-19)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* room form edit with enter ([#662](https://github.com/Monadical-SAS/reflector/issues/662)) ([47716f6](https://github.com/Monadical-SAS/reflector/commit/47716f6e5ddee952609d2fa0ffabdfa865286796))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* invalid cleanup call ([#660](https://github.com/Monadical-SAS/reflector/issues/660)) ([0abcebf](https://github.com/Monadical-SAS/reflector/commit/0abcebfc9491f87f605f21faa3e53996fafedd9a))
|
||||
|
||||
## [0.12.1](https://github.com/Monadical-SAS/reflector/compare/v0.12.0...v0.12.1) (2025-09-17)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* production blocked because having existing meeting with room_id null ([#657](https://github.com/Monadical-SAS/reflector/issues/657)) ([870e860](https://github.com/Monadical-SAS/reflector/commit/870e8605171a27155a9cbee215eeccb9a8d6c0a2))
|
||||
|
||||
## [0.12.0](https://github.com/Monadical-SAS/reflector/compare/v0.11.0...v0.12.0) (2025-09-17)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* calendar integration ([#608](https://github.com/Monadical-SAS/reflector/issues/608)) ([6f680b5](https://github.com/Monadical-SAS/reflector/commit/6f680b57954c688882c4ed49f40f161c52a00a24))
|
||||
* self-hosted gpu api ([#636](https://github.com/Monadical-SAS/reflector/issues/636)) ([ab859d6](https://github.com/Monadical-SAS/reflector/commit/ab859d65a6bded904133a163a081a651b3938d42))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* ignore player hotkeys for text inputs ([#646](https://github.com/Monadical-SAS/reflector/issues/646)) ([fa049e8](https://github.com/Monadical-SAS/reflector/commit/fa049e8d068190ce7ea015fd9fcccb8543f54a3f))
|
||||
|
||||
## [0.11.0](https://github.com/Monadical-SAS/reflector/compare/v0.10.0...v0.11.0) (2025-09-16)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* remove profanity filter that was there for conference ([#652](https://github.com/Monadical-SAS/reflector/issues/652)) ([b42f7cf](https://github.com/Monadical-SAS/reflector/commit/b42f7cfc606783afcee792590efcc78b507468ab))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* zulip and consent handler on the file pipeline ([#645](https://github.com/Monadical-SAS/reflector/issues/645)) ([5f143fe](https://github.com/Monadical-SAS/reflector/commit/5f143fe3640875dcb56c26694254a93189281d17))
|
||||
* zulip stream and topic selection in share dialog ([#644](https://github.com/Monadical-SAS/reflector/issues/644)) ([c546e69](https://github.com/Monadical-SAS/reflector/commit/c546e69739e68bb74fbc877eb62609928e5b8de6))
|
||||
|
||||
## [0.10.0](https://github.com/Monadical-SAS/reflector/compare/v0.9.0...v0.10.0) (2025-09-11)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* replace nextjs-config with environment variables ([#632](https://github.com/Monadical-SAS/reflector/issues/632)) ([369ecdf](https://github.com/Monadical-SAS/reflector/commit/369ecdff13f3862d926a9c0b87df52c9d94c4dde))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* anonymous users transcript permissions ([#621](https://github.com/Monadical-SAS/reflector/issues/621)) ([f81fe99](https://github.com/Monadical-SAS/reflector/commit/f81fe9948a9237b3e0001b2d8ca84f54d76878f9))
|
||||
* auth post ([#624](https://github.com/Monadical-SAS/reflector/issues/624)) ([cde99ca](https://github.com/Monadical-SAS/reflector/commit/cde99ca2716f84ba26798f289047732f0448742e))
|
||||
* auth post ([#626](https://github.com/Monadical-SAS/reflector/issues/626)) ([3b85ff3](https://github.com/Monadical-SAS/reflector/commit/3b85ff3bdf4fb053b103070646811bc990c0e70a))
|
||||
* auth post ([#627](https://github.com/Monadical-SAS/reflector/issues/627)) ([962038e](https://github.com/Monadical-SAS/reflector/commit/962038ee3f2a555dc3c03856be0e4409456e0996))
|
||||
* missing follow_redirects=True on modal endpoint ([#630](https://github.com/Monadical-SAS/reflector/issues/630)) ([fc363bd](https://github.com/Monadical-SAS/reflector/commit/fc363bd49b17b075e64f9186e5e0185abc325ea7))
|
||||
* sync backend and frontend token refresh logic ([#614](https://github.com/Monadical-SAS/reflector/issues/614)) ([5a5b323](https://github.com/Monadical-SAS/reflector/commit/5a5b3233820df9536da75e87ce6184a983d4713a))
|
||||
|
||||
## [0.9.0](https://github.com/Monadical-SAS/reflector/compare/v0.8.2...v0.9.0) (2025-09-06)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* frontend openapi react query ([#606](https://github.com/Monadical-SAS/reflector/issues/606)) ([c4d2825](https://github.com/Monadical-SAS/reflector/commit/c4d2825c81f81ad8835629fbf6ea8c7383f8c31b))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* align whisper transcriber api with parakeet ([#602](https://github.com/Monadical-SAS/reflector/issues/602)) ([0663700](https://github.com/Monadical-SAS/reflector/commit/0663700a615a4af69a03c96c410f049e23ec9443))
|
||||
* kv use tls explicit ([#610](https://github.com/Monadical-SAS/reflector/issues/610)) ([08d88ec](https://github.com/Monadical-SAS/reflector/commit/08d88ec349f38b0d13e0fa4cb73486c8dfd31836))
|
||||
* source kind for file processing ([#601](https://github.com/Monadical-SAS/reflector/issues/601)) ([dc82f8b](https://github.com/Monadical-SAS/reflector/commit/dc82f8bb3bdf3ab3d4088e592a30fd63907319e1))
|
||||
* token refresh locking ([#613](https://github.com/Monadical-SAS/reflector/issues/613)) ([7f5a4c9](https://github.com/Monadical-SAS/reflector/commit/7f5a4c9ddc7fd098860c8bdda2ca3b57f63ded2f))
|
||||
|
||||
## [0.8.2](https://github.com/Monadical-SAS/reflector/compare/v0.8.1...v0.8.2) (2025-08-29)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* search-logspam ([#593](https://github.com/Monadical-SAS/reflector/issues/593)) ([695d1a9](https://github.com/Monadical-SAS/reflector/commit/695d1a957d4cd862753049f9beed88836cabd5ab))
|
||||
|
||||
## [0.8.1](https://github.com/Monadical-SAS/reflector/compare/v0.8.0...v0.8.1) (2025-08-29)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* make webhook secret/url allowing null ([#590](https://github.com/Monadical-SAS/reflector/issues/590)) ([84a3812](https://github.com/Monadical-SAS/reflector/commit/84a381220bc606231d08d6f71d4babc818fa3c75))
|
||||
|
||||
## [0.8.0](https://github.com/Monadical-SAS/reflector/compare/v0.7.3...v0.8.0) (2025-08-29)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* **cleanup:** add automatic data retention for public instances ([#574](https://github.com/Monadical-SAS/reflector/issues/574)) ([6f0c7c1](https://github.com/Monadical-SAS/reflector/commit/6f0c7c1a5e751713366886c8e764c2009e12ba72))
|
||||
* **rooms:** add webhook for transcript completion ([#578](https://github.com/Monadical-SAS/reflector/issues/578)) ([88ed7cf](https://github.com/Monadical-SAS/reflector/commit/88ed7cfa7804794b9b54cad4c3facc8a98cf85fd))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* file pipeline status reporting and websocket updates ([#589](https://github.com/Monadical-SAS/reflector/issues/589)) ([9dfd769](https://github.com/Monadical-SAS/reflector/commit/9dfd76996f851cc52be54feea078adbc0816dc57))
|
||||
* Igor/evaluation ([#575](https://github.com/Monadical-SAS/reflector/issues/575)) ([124ce03](https://github.com/Monadical-SAS/reflector/commit/124ce03bf86044c18313d27228a25da4bc20c9c5))
|
||||
* optimize parakeet transcription batching algorithm ([#577](https://github.com/Monadical-SAS/reflector/issues/577)) ([7030e0f](https://github.com/Monadical-SAS/reflector/commit/7030e0f23649a8cf6c1eb6d5889684a41ce849ec))
|
||||
|
||||
## [0.7.3](https://github.com/Monadical-SAS/reflector/compare/v0.7.2...v0.7.3) (2025-08-22)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* cleaned repo, and get git-leaks clean ([359280d](https://github.com/Monadical-SAS/reflector/commit/359280dd340433ba4402ed69034094884c825e67))
|
||||
* restore previous behavior on live pipeline + audio downscaler ([#561](https://github.com/Monadical-SAS/reflector/issues/561)) ([9265d20](https://github.com/Monadical-SAS/reflector/commit/9265d201b590d23c628c5f19251b70f473859043))
|
||||
|
||||
## [0.7.2](https://github.com/Monadical-SAS/reflector/compare/v0.7.1...v0.7.2) (2025-08-21)
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,6 @@ pnpm install
|
||||
|
||||
# Copy configuration templates
|
||||
cp .env_template .env
|
||||
cp config-template.ts config.ts
|
||||
```
|
||||
|
||||
**Development:**
|
||||
|
||||
81
README.md
81
README.md
@@ -1,43 +1,60 @@
|
||||
<div align="center">
|
||||
<img width="100" alt="image" src="https://github.com/user-attachments/assets/66fb367b-2c89-4516-9912-f47ac59c6a7f"/>
|
||||
|
||||
# Reflector
|
||||
|
||||
Reflector Audio Management and Analysis is a cutting-edge web application under development by Monadical. It utilizes AI to record meetings, providing a permanent record with transcripts, translations, and automated summaries.
|
||||
Reflector is an AI-powered audio transcription and meeting analysis platform that provides real-time transcription, speaker diarization, translation and summarization for audio content and live meetings. It works 100% with local models (whisper/parakeet, pyannote, seamless-m4t, and your local llm like phi-4).
|
||||
|
||||
[](https://github.com/monadical-sas/reflector/actions/workflows/pytests.yml)
|
||||
[](https://github.com/monadical-sas/reflector/actions/workflows/test_server.yml)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
</div>
|
||||
|
||||
## Screenshots
|
||||
</div>
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/3a976930-56c1-47ef-8c76-55d3864309e3">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/3a976930-56c1-47ef-8c76-55d3864309e3" />
|
||||
<a href="https://github.com/user-attachments/assets/21f5597c-2930-4899-a154-f7bd61a59e97">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/21f5597c-2930-4899-a154-f7bd61a59e97" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/bfe3bde3-08af-4426-a9a1-11ad5cd63b33">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/bfe3bde3-08af-4426-a9a1-11ad5cd63b33" />
|
||||
<a href="https://github.com/user-attachments/assets/f6b9399a-5e51-4bae-b807-59128d0a940c">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/f6b9399a-5e51-4bae-b807-59128d0a940c" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/7b60c9d0-efe4-474f-a27b-ea13bd0fabdc">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/7b60c9d0-efe4-474f-a27b-ea13bd0fabdc" />
|
||||
<a href="https://github.com/user-attachments/assets/a42ce460-c1fd-4489-a995-270516193897">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/a42ce460-c1fd-4489-a995-270516193897" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://github.com/user-attachments/assets/21929f6d-c309-42fe-9c11-f1299e50fbd4">
|
||||
<img width="700" alt="image" src="https://github.com/user-attachments/assets/21929f6d-c309-42fe-9c11-f1299e50fbd4" />
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## What is Reflector?
|
||||
|
||||
Reflector is a web application that utilizes local models to process audio content, providing:
|
||||
|
||||
- **Real-time Transcription**: Convert speech to text using [Whisper](https://github.com/openai/whisper) (multi-language) or [Parakeet](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2) (English) models
|
||||
- **Speaker Diarization**: Identify and label different speakers using [Pyannote](https://github.com/pyannote/pyannote-audio) 3.1
|
||||
- **Live Translation**: Translate audio content in real-time to many languages with [Facebook Seamless-M4T](https://github.com/facebookresearch/seamless_communication)
|
||||
- **Topic Detection & Summarization**: Extract key topics and generate concise summaries using LLMs
|
||||
- **Meeting Recording**: Create permanent records of meetings with searchable transcripts
|
||||
|
||||
Currently we provide [modal.com](https://modal.com/) gpu template to deploy.
|
||||
|
||||
## Background
|
||||
|
||||
The project architecture consists of three primary components:
|
||||
|
||||
- **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
|
||||
- **Back-End**: Python server that offers an API and data persistence, found in `server/`.
|
||||
- **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations. Most reliable option is Modal deployment
|
||||
- **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
|
||||
- **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
|
||||
|
||||
It also uses authentik for authentication if activated, and Vercel for deployment and configuration of the front-end.
|
||||
It also uses authentik for authentication if activated.
|
||||
|
||||
## Contribution Guidelines
|
||||
|
||||
@@ -72,6 +89,8 @@ Note: We currently do not have instructions for Windows users.
|
||||
|
||||
## Installation
|
||||
|
||||
*Note: we're working toward better installation, theses instructions are not accurate for now*
|
||||
|
||||
### Frontend
|
||||
|
||||
Start with `cd www`.
|
||||
@@ -80,11 +99,10 @@ Start with `cd www`.
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
cp .env_template .env
|
||||
cp config-template.ts config.ts
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
Then, fill in the environment variables in `.env` and the configuration in `config.ts` as needed. If you are unsure on how to proceed, ask in Zulip.
|
||||
Then, fill in the environment variables in `.env` as needed. If you are unsure on how to proceed, ask in Zulip.
|
||||
|
||||
**Run in development mode**
|
||||
|
||||
@@ -149,3 +167,34 @@ You can manually process an audio file by calling the process tool:
|
||||
```bash
|
||||
uv run python -m reflector.tools.process path/to/audio.wav
|
||||
```
|
||||
|
||||
|
||||
## Feature Flags
|
||||
|
||||
Reflector uses environment variable-based feature flags to control application functionality. These flags allow you to enable or disable features without code changes.
|
||||
|
||||
### Available Feature Flags
|
||||
|
||||
| Feature Flag | Environment Variable |
|
||||
|-------------|---------------------|
|
||||
| `requireLogin` | `NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN` |
|
||||
| `privacy` | `NEXT_PUBLIC_FEATURE_PRIVACY` |
|
||||
| `browse` | `NEXT_PUBLIC_FEATURE_BROWSE` |
|
||||
| `sendToZulip` | `NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP` |
|
||||
| `rooms` | `NEXT_PUBLIC_FEATURE_ROOMS` |
|
||||
|
||||
### Setting Feature Flags
|
||||
|
||||
Feature flags are controlled via environment variables using the pattern `NEXT_PUBLIC_FEATURE_{FEATURE_NAME}` where `{FEATURE_NAME}` is the SCREAMING_SNAKE_CASE version of the feature name.
|
||||
|
||||
**Examples:**
|
||||
```bash
|
||||
# Enable user authentication requirement
|
||||
NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN=true
|
||||
|
||||
# Disable browse functionality
|
||||
NEXT_PUBLIC_FEATURE_BROWSE=false
|
||||
|
||||
# Enable Zulip integration
|
||||
NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP=true
|
||||
```
|
||||
|
||||
@@ -6,6 +6,7 @@ services:
|
||||
- 1250:1250
|
||||
volumes:
|
||||
- ./server/:/app/
|
||||
- /app/.venv
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
@@ -16,6 +17,7 @@ services:
|
||||
context: server
|
||||
volumes:
|
||||
- ./server/:/app/
|
||||
- /app/.venv
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
@@ -26,6 +28,7 @@ services:
|
||||
context: server
|
||||
volumes:
|
||||
- ./server/:/app/
|
||||
- /app/.venv
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
|
||||
33
gpu/modal_deployments/.gitignore
vendored
Normal file
33
gpu/modal_deployments/.gitignore
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
# OS / Editor
|
||||
.DS_Store
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# Env and secrets
|
||||
.env
|
||||
.env.*
|
||||
*.env
|
||||
*.secret
|
||||
|
||||
# Build / dist
|
||||
build/
|
||||
dist/
|
||||
.eggs/
|
||||
*.egg-info/
|
||||
|
||||
# Coverage / test
|
||||
.pytest_cache/
|
||||
.coverage*
|
||||
htmlcov/
|
||||
|
||||
# Modal local state (if any)
|
||||
modal_mounts/
|
||||
.modal_cache/
|
||||
608
gpu/modal_deployments/reflector_transcriber.py
Normal file
608
gpu/modal_deployments/reflector_transcriber.py
Normal file
@@ -0,0 +1,608 @@
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import modal
|
||||
|
||||
MODEL_NAME = "large-v2"
|
||||
MODEL_COMPUTE_TYPE: str = "float16"
|
||||
MODEL_NUM_WORKERS: int = 1
|
||||
MINUTES = 60 # seconds
|
||||
SAMPLERATE = 16000
|
||||
UPLOADS_PATH = "/uploads"
|
||||
CACHE_PATH = "/models"
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
|
||||
WhisperUniqFilename = NewType("WhisperUniqFilename", str)
|
||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||
|
||||
app = modal.App("reflector-transcriber")
|
||||
|
||||
model_cache = modal.Volume.from_name("models", create_if_missing=True)
|
||||
upload_volume = modal.Volume.from_name("whisper-uploads", create_if_missing=True)
|
||||
|
||||
|
||||
class TimeSegment(NamedTuple):
|
||||
"""Represents a time segment with start and end times."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
class AudioSegment(NamedTuple):
|
||||
"""Represents an audio segment with timing and audio data."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
audio: any
|
||||
|
||||
|
||||
class TranscriptResult(NamedTuple):
|
||||
"""Represents a transcription result with text and word timings."""
|
||||
|
||||
text: str
|
||||
words: list["WordTiming"]
|
||||
|
||||
|
||||
class WordTiming(TypedDict):
|
||||
"""Represents a word with its timing information."""
|
||||
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
def download_model():
|
||||
from faster_whisper import download_model
|
||||
|
||||
model_cache.reload()
|
||||
|
||||
download_model(MODEL_NAME, cache_dir=CACHE_PATH)
|
||||
|
||||
model_cache.commit()
|
||||
|
||||
|
||||
image = (
|
||||
modal.Image.debian_slim(python_version="3.12")
|
||||
.env(
|
||||
{
|
||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||
"LD_LIBRARY_PATH": (
|
||||
"/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/:"
|
||||
"/opt/conda/lib/python3.12/site-packages/nvidia/cublas/lib/"
|
||||
),
|
||||
}
|
||||
)
|
||||
.apt_install("ffmpeg")
|
||||
.pip_install(
|
||||
"huggingface_hub==0.27.1",
|
||||
"hf-transfer==0.1.9",
|
||||
"torch==2.5.1",
|
||||
"faster-whisper==1.1.1",
|
||||
"fastapi==0.115.12",
|
||||
"requests",
|
||||
"librosa==0.10.1",
|
||||
"numpy<2",
|
||||
"silero-vad==5.1.0",
|
||||
)
|
||||
.run_function(download_model, volumes={CACHE_PATH: model_cache})
|
||||
)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||
parsed_url = urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return AudioFileExtension(ext)
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return AudioFileExtension("mp3")
|
||||
if "audio/wav" in content_type:
|
||||
return AudioFileExtension("wav")
|
||||
if "audio/mp4" in content_type:
|
||||
return AudioFileExtension("mp4")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported audio format for URL: {url}. "
|
||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(
|
||||
audio_file_url: str,
|
||||
) -> tuple[WhisperUniqFilename, AudioFileExtension]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = WhisperUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
|
||||
"""Add 0.5s of silence if audio is shorter than the silence_padding window.
|
||||
|
||||
Whisper does not require this strictly, but aligning behavior with Parakeet
|
||||
avoids edge-case crashes on extremely short inputs and makes comparisons easier.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
audio_duration = len(audio_array) / sample_rate
|
||||
if audio_duration < VAD_CONFIG["silence_padding"]:
|
||||
silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio_array, silence])
|
||||
return audio_array
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=5 * MINUTES,
|
||||
scaledown_window=5 * MINUTES,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
)
|
||||
@modal.concurrent(max_inputs=10)
|
||||
class TranscriberWhisperLive:
|
||||
"""Live transcriber class for small audio segments (A10G).
|
||||
|
||||
Mirrors the Parakeet live class API but uses Faster-Whisper under the hood.
|
||||
"""
|
||||
|
||||
@modal.enter()
|
||||
def enter(self):
|
||||
import faster_whisper
|
||||
import torch
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=MODEL_COMPUTE_TYPE,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=CACHE_PATH,
|
||||
local_files_only=True,
|
||||
)
|
||||
print(f"Model is on device: {self.device}")
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
filename: str,
|
||||
language: str = "en",
|
||||
):
|
||||
"""Transcribe a single uploaded audio file by filename."""
|
||||
upload_volume.reload()
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
segments, _ = self.model.transcribe(
|
||||
file_path,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(segment.text for segment in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": word.word,
|
||||
"start": round(float(word.start), 2),
|
||||
"end": round(float(word.end), 2),
|
||||
}
|
||||
for segment in segments
|
||||
for word in segment.words
|
||||
]
|
||||
|
||||
return {"text": text, "words": words}
|
||||
|
||||
@modal.method()
|
||||
def transcribe_batch(
|
||||
self,
|
||||
filenames: list[str],
|
||||
language: str = "en",
|
||||
):
|
||||
"""Transcribe multiple uploaded audio files and return per-file results."""
|
||||
upload_volume.reload()
|
||||
|
||||
results = []
|
||||
for filename in filenames:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Batch file not found: {file_path}")
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
segments, _ = self.model.transcribe(
|
||||
file_path,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(seg.text for seg in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": w.word,
|
||||
"start": round(float(w.start), 2),
|
||||
"end": round(float(w.end), 2),
|
||||
}
|
||||
for seg in segments
|
||||
for w in seg.words
|
||||
]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"text": text,
|
||||
"words": words,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="L40S",
|
||||
timeout=15 * MINUTES,
|
||||
image=image,
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
)
|
||||
class TranscriberWhisperFile:
|
||||
"""File transcriber for larger/longer audio, using VAD-driven batching (L40S)."""
|
||||
|
||||
@modal.enter()
|
||||
def enter(self):
|
||||
import faster_whisper
|
||||
import torch
|
||||
from silero_vad import load_silero_vad
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=MODEL_COMPUTE_TYPE,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=CACHE_PATH,
|
||||
local_files_only=True,
|
||||
)
|
||||
self.vad_model = load_silero_vad(onnx=False)
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self, filename: str, timestamp_offset: float = 0.0, language: str = "en"
|
||||
):
|
||||
import librosa
|
||||
import numpy as np
|
||||
from silero_vad import VADIterator
|
||||
|
||||
def vad_segments(
|
||||
audio_array,
|
||||
sample_rate: int = SAMPLERATE,
|
||||
window_size: int = VAD_CONFIG["window_size"],
|
||||
) -> Generator[TimeSegment, None, None]:
|
||||
"""Generate speech segments as TimeSegment using Silero VAD."""
|
||||
iterator = VADIterator(self.vad_model, sampling_rate=sample_rate)
|
||||
start = None
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(
|
||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
||||
)
|
||||
speech = iterator(chunk)
|
||||
if not speech:
|
||||
continue
|
||||
if "start" in speech:
|
||||
start = speech["start"]
|
||||
continue
|
||||
if "end" in speech and start is not None:
|
||||
end = speech["end"]
|
||||
yield TimeSegment(
|
||||
start / float(SAMPLERATE), end / float(SAMPLERATE)
|
||||
)
|
||||
start = None
|
||||
iterator.reset_states()
|
||||
|
||||
upload_volume.reload()
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
audio_array, _sr = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
|
||||
# Batch segments up to ~30s windows by merging contiguous VAD segments
|
||||
merged_batches: list[TimeSegment] = []
|
||||
batch_start = None
|
||||
batch_end = None
|
||||
max_duration = VAD_CONFIG["batch_max_duration"]
|
||||
for segment in vad_segments(audio_array):
|
||||
seg_start, seg_end = segment.start, segment.end
|
||||
if batch_start is None:
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
continue
|
||||
if seg_end - batch_start <= max_duration:
|
||||
batch_end = seg_end
|
||||
else:
|
||||
merged_batches.append(TimeSegment(batch_start, batch_end))
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
if batch_start is not None and batch_end is not None:
|
||||
merged_batches.append(TimeSegment(batch_start, batch_end))
|
||||
|
||||
all_text = []
|
||||
all_words = []
|
||||
|
||||
for segment in merged_batches:
|
||||
start_time, end_time = segment.start, segment.end
|
||||
s_idx = int(start_time * SAMPLERATE)
|
||||
e_idx = int(end_time * SAMPLERATE)
|
||||
segment = audio_array[s_idx:e_idx]
|
||||
segment = pad_audio(segment, SAMPLERATE)
|
||||
|
||||
with self.lock:
|
||||
segments, _ = self.model.transcribe(
|
||||
segment,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(seg.text for seg in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": w.word,
|
||||
"start": round(float(w.start) + start_time + timestamp_offset, 2),
|
||||
"end": round(float(w.end) + start_time + timestamp_offset, 2),
|
||||
}
|
||||
for seg in segments
|
||||
for w in seg.words
|
||||
]
|
||||
if text:
|
||||
all_text.append(text)
|
||||
all_words.extend(words)
|
||||
|
||||
return {"text": " ".join(all_text), "words": all_words}
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: dict) -> str:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
url_path = urlparse(url).path
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return ext
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return "mp3"
|
||||
if "audio/wav" in content_type:
|
||||
return "wav"
|
||||
if "audio/mp4" in content_type:
|
||||
return "mp4"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(audio_file_url: str) -> tuple[str, str]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
return unique_filename, audio_suffix
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60,
|
||||
timeout=600,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
|
||||
image=image,
|
||||
)
|
||||
@modal.concurrent(max_inputs=40)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import (
|
||||
Body,
|
||||
Depends,
|
||||
FastAPI,
|
||||
Form,
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
transcriber_live = TranscriberWhisperLive()
|
||||
transcriber_file = TranscriberWhisperFile()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class TranscriptResponse(dict):
|
||||
pass
|
||||
|
||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe(
|
||||
file: UploadFile = None,
|
||||
files: list[UploadFile] | None = None,
|
||||
model: str = Form(MODEL_NAME),
|
||||
language: str = Form("en"),
|
||||
batch: bool = Form(False),
|
||||
):
|
||||
if not file and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
||||
)
|
||||
if batch and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Batch transcription requires 'files'"
|
||||
)
|
||||
|
||||
upload_files = [file] if file else files
|
||||
|
||||
uploaded_filenames: list[str] = []
|
||||
for upload_file in upload_files:
|
||||
audio_suffix = upload_file.filename.split(".")[-1]
|
||||
if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
with open(file_path, "wb") as f:
|
||||
content = upload_file.file.read()
|
||||
f.write(content)
|
||||
uploaded_filenames.append(unique_filename)
|
||||
|
||||
upload_volume.commit()
|
||||
|
||||
try:
|
||||
if batch and len(upload_files) > 1:
|
||||
func = transcriber_live.transcribe_batch.spawn(
|
||||
filenames=uploaded_filenames,
|
||||
language=language,
|
||||
)
|
||||
results = func.get()
|
||||
return {"results": results}
|
||||
|
||||
results = []
|
||||
for filename in uploaded_filenames:
|
||||
func = transcriber_live.transcribe_segment.spawn(
|
||||
filename=filename,
|
||||
language=language,
|
||||
)
|
||||
result = func.get()
|
||||
result["filename"] = filename
|
||||
results.append(result)
|
||||
|
||||
return {"results": results} if len(results) > 1 else results[0]
|
||||
finally:
|
||||
for filename in uploaded_filenames:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
os.remove(file_path)
|
||||
except Exception:
|
||||
pass
|
||||
upload_volume.commit()
|
||||
|
||||
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe_from_url(
|
||||
audio_file_url: str = Body(
|
||||
..., description="URL of the audio file to transcribe"
|
||||
),
|
||||
model: str = Body(MODEL_NAME),
|
||||
language: str = Body("en"),
|
||||
timestamp_offset: float = Body(0.0),
|
||||
):
|
||||
unique_filename, _audio_suffix = download_audio_to_volume(audio_file_url)
|
||||
try:
|
||||
func = transcriber_file.transcribe_segment.spawn(
|
||||
filename=unique_filename,
|
||||
timestamp_offset=timestamp_offset,
|
||||
language=language,
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
os.remove(file_path)
|
||||
upload_volume.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class NoStdStreams:
|
||||
def __init__(self):
|
||||
self.devnull = open(os.devnull, "w")
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
||||
self._stdout.flush()
|
||||
self._stderr.flush()
|
||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
||||
self.devnull.close()
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Mapping, NewType
|
||||
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import modal
|
||||
@@ -14,10 +14,7 @@ SAMPLERATE = 16000
|
||||
UPLOADS_PATH = "/uploads"
|
||||
CACHE_PATH = "/cache"
|
||||
VAD_CONFIG = {
|
||||
"max_segment_duration": 30.0,
|
||||
"batch_max_files": 10,
|
||||
"batch_max_duration": 5.0,
|
||||
"min_segment_duration": 0.02,
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
@@ -25,6 +22,37 @@ VAD_CONFIG = {
|
||||
ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
|
||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||
|
||||
|
||||
class TimeSegment(NamedTuple):
|
||||
"""Represents a time segment with start and end times."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
class AudioSegment(NamedTuple):
|
||||
"""Represents an audio segment with timing and audio data."""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
audio: any
|
||||
|
||||
|
||||
class TranscriptResult(NamedTuple):
|
||||
"""Represents a transcription result with text and word timings."""
|
||||
|
||||
text: str
|
||||
words: list["WordTiming"]
|
||||
|
||||
|
||||
class WordTiming(TypedDict):
|
||||
"""Represents a word with its timing information."""
|
||||
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
app = modal.App("reflector-transcriber-parakeet")
|
||||
|
||||
# Volume for caching model weights
|
||||
@@ -170,12 +198,14 @@ class TranscriberParakeetLive:
|
||||
(output,) = self.model.transcribe([padded_audio], timestamps=True)
|
||||
|
||||
text = output.text.strip()
|
||||
words = [
|
||||
{
|
||||
"word": word_info["word"] + " ",
|
||||
"start": round(word_info["start"], 2),
|
||||
"end": round(word_info["end"], 2),
|
||||
}
|
||||
words: list[WordTiming] = [
|
||||
WordTiming(
|
||||
# XXX the space added here is to match the output of whisper
|
||||
# whisper add space to each words, while parakeet don't
|
||||
word=word_info["word"] + " ",
|
||||
start=round(word_info["start"], 2),
|
||||
end=round(word_info["end"], 2),
|
||||
)
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
@@ -211,12 +241,12 @@ class TranscriberParakeetLive:
|
||||
for i, (filename, output) in enumerate(zip(filenames, outputs)):
|
||||
text = output.text.strip()
|
||||
|
||||
words = [
|
||||
{
|
||||
"word": word_info["word"] + " ",
|
||||
"start": round(word_info["start"], 2),
|
||||
"end": round(word_info["end"], 2),
|
||||
}
|
||||
words: list[WordTiming] = [
|
||||
WordTiming(
|
||||
word=word_info["word"] + " ",
|
||||
start=round(word_info["start"], 2),
|
||||
end=round(word_info["end"], 2),
|
||||
)
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
@@ -271,7 +301,9 @@ class TranscriberParakeetFile:
|
||||
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
|
||||
return audio_array
|
||||
|
||||
def vad_segment_generator(audio_array):
|
||||
def vad_segment_generator(
|
||||
audio_array,
|
||||
) -> Generator[TimeSegment, None, None]:
|
||||
"""Generate speech segments using VAD with start/end sample indices"""
|
||||
vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
|
||||
window_size = VAD_CONFIG["window_size"]
|
||||
@@ -297,107 +329,121 @@ class TranscriberParakeetFile:
|
||||
start_time = start / float(SAMPLERATE)
|
||||
end_time = end / float(SAMPLERATE)
|
||||
|
||||
# Extract the actual audio segment
|
||||
audio_segment = audio_array[start:end]
|
||||
|
||||
yield (start_time, end_time, audio_segment)
|
||||
yield TimeSegment(start_time, end_time)
|
||||
start = None
|
||||
|
||||
vad_iterator.reset_states()
|
||||
|
||||
def vad_segment_filter(segments):
|
||||
"""Filter VAD segments by duration and chunk large segments"""
|
||||
min_dur = VAD_CONFIG["min_segment_duration"]
|
||||
max_dur = VAD_CONFIG["max_segment_duration"]
|
||||
def batch_speech_segments(
|
||||
segments: Generator[TimeSegment, None, None], max_duration: int
|
||||
) -> Generator[TimeSegment, None, None]:
|
||||
"""
|
||||
Input segments:
|
||||
[0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
|
||||
|
||||
for start_time, end_time, audio_segment in segments:
|
||||
segment_duration = end_time - start_time
|
||||
↓ (max_duration=10)
|
||||
|
||||
# Skip very small segments
|
||||
if segment_duration < min_dur:
|
||||
Output batches:
|
||||
[0-8] [10-19] [20-22]
|
||||
|
||||
Note: silences are kept for better transcription, previous implementation was
|
||||
passing segments separatly, but the output was less accurate.
|
||||
"""
|
||||
batch_start_time = None
|
||||
batch_end_time = None
|
||||
|
||||
for segment in segments:
|
||||
start_time, end_time = segment.start, segment.end
|
||||
if batch_start_time is None or batch_end_time is None:
|
||||
batch_start_time = start_time
|
||||
batch_end_time = end_time
|
||||
continue
|
||||
|
||||
# If segment is within max duration, yield as-is
|
||||
if segment_duration <= max_dur:
|
||||
yield (start_time, end_time, audio_segment)
|
||||
total_duration = end_time - batch_start_time
|
||||
|
||||
if total_duration <= max_duration:
|
||||
batch_end_time = end_time
|
||||
continue
|
||||
|
||||
# Chunk large segments into smaller pieces
|
||||
chunk_samples = int(max_dur * SAMPLERATE)
|
||||
current_start = start_time
|
||||
yield TimeSegment(batch_start_time, batch_end_time)
|
||||
batch_start_time = start_time
|
||||
batch_end_time = end_time
|
||||
|
||||
for chunk_offset in range(0, len(audio_segment), chunk_samples):
|
||||
chunk_audio = audio_segment[
|
||||
chunk_offset : chunk_offset + chunk_samples
|
||||
]
|
||||
if len(chunk_audio) == 0:
|
||||
break
|
||||
if batch_start_time is None or batch_end_time is None:
|
||||
return
|
||||
|
||||
chunk_duration = len(chunk_audio) / float(SAMPLERATE)
|
||||
chunk_end = current_start + chunk_duration
|
||||
yield TimeSegment(batch_start_time, batch_end_time)
|
||||
|
||||
# Only yield chunks that meet minimum duration
|
||||
if chunk_duration >= min_dur:
|
||||
yield (current_start, chunk_end, chunk_audio)
|
||||
def batch_segment_to_audio_segment(
|
||||
segments: Generator[TimeSegment, None, None],
|
||||
audio_array,
|
||||
) -> Generator[AudioSegment, None, None]:
|
||||
"""Extract audio segments and apply padding for Parakeet compatibility.
|
||||
|
||||
current_start = chunk_end
|
||||
Uses pad_audio to ensure segments are at least 0.5s long, preventing
|
||||
Parakeet crashes. This padding may cause slight timing overlaps between
|
||||
segments, which are corrected by enforce_word_timing_constraints.
|
||||
"""
|
||||
for segment in segments:
|
||||
start_time, end_time = segment.start, segment.end
|
||||
start_sample = int(start_time * SAMPLERATE)
|
||||
end_sample = int(end_time * SAMPLERATE)
|
||||
audio_segment = audio_array[start_sample:end_sample]
|
||||
|
||||
def batch_segments(segments, max_files=10, max_duration=5.0):
|
||||
batch = []
|
||||
batch_duration = 0.0
|
||||
padded_segment = pad_audio(audio_segment, SAMPLERATE)
|
||||
|
||||
for start_time, end_time, audio_segment in segments:
|
||||
segment_duration = end_time - start_time
|
||||
yield AudioSegment(start_time, end_time, padded_segment)
|
||||
|
||||
if segment_duration < VAD_CONFIG["silence_padding"]:
|
||||
silence_samples = int(
|
||||
(VAD_CONFIG["silence_padding"] - segment_duration) * SAMPLERATE
|
||||
)
|
||||
padding = np.zeros(silence_samples, dtype=np.float32)
|
||||
audio_segment = np.concatenate([audio_segment, padding])
|
||||
segment_duration = VAD_CONFIG["silence_padding"]
|
||||
|
||||
batch.append((start_time, end_time, audio_segment))
|
||||
batch_duration += segment_duration
|
||||
|
||||
if len(batch) >= max_files or batch_duration >= max_duration:
|
||||
yield batch
|
||||
batch = []
|
||||
batch_duration = 0.0
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
def transcribe_batch(model, audio_segments):
|
||||
def transcribe_batch(model, audio_segments: list) -> list:
|
||||
with NoStdStreams():
|
||||
outputs = model.transcribe(audio_segments, timestamps=True)
|
||||
return outputs
|
||||
|
||||
def enforce_word_timing_constraints(
|
||||
words: list[WordTiming],
|
||||
) -> list[WordTiming]:
|
||||
"""Enforce that word end times don't exceed the start time of the next word.
|
||||
|
||||
Due to silence padding added in batch_segment_to_audio_segment for better
|
||||
transcription accuracy, word timings from different segments may overlap.
|
||||
This function ensures there are no overlaps by adjusting end times.
|
||||
"""
|
||||
if len(words) <= 1:
|
||||
return words
|
||||
|
||||
enforced_words = []
|
||||
for i, word in enumerate(words):
|
||||
enforced_word = word.copy()
|
||||
|
||||
if i < len(words) - 1:
|
||||
next_start = words[i + 1]["start"]
|
||||
if enforced_word["end"] > next_start:
|
||||
enforced_word["end"] = next_start
|
||||
|
||||
enforced_words.append(enforced_word)
|
||||
|
||||
return enforced_words
|
||||
|
||||
def emit_results(
|
||||
results,
|
||||
segments_info,
|
||||
batch_index,
|
||||
total_batches,
|
||||
):
|
||||
results: list,
|
||||
segments_info: list[AudioSegment],
|
||||
) -> Generator[TranscriptResult, None, None]:
|
||||
"""Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
|
||||
for i, (output, (start_time, end_time, _)) in enumerate(
|
||||
zip(results, segments_info)
|
||||
):
|
||||
for i, (output, segment) in enumerate(zip(results, segments_info)):
|
||||
start_time, end_time = segment.start, segment.end
|
||||
text = output.text.strip()
|
||||
words = [
|
||||
{
|
||||
"word": word_info["word"] + " ",
|
||||
"start": round(
|
||||
words: list[WordTiming] = [
|
||||
WordTiming(
|
||||
word=word_info["word"] + " ",
|
||||
start=round(
|
||||
word_info["start"] + start_time + timestamp_offset, 2
|
||||
),
|
||||
"end": round(
|
||||
word_info["end"] + start_time + timestamp_offset, 2
|
||||
),
|
||||
}
|
||||
end=round(word_info["end"] + start_time + timestamp_offset, 2),
|
||||
)
|
||||
for word_info in output.timestamp["word"]
|
||||
]
|
||||
|
||||
yield text, words
|
||||
yield TranscriptResult(text, words)
|
||||
|
||||
upload_volume.reload()
|
||||
|
||||
@@ -407,41 +453,31 @@ class TranscriberParakeetFile:
|
||||
|
||||
audio_array = load_and_convert_audio(file_path)
|
||||
total_duration = len(audio_array) / float(SAMPLERATE)
|
||||
processed_duration = 0.0
|
||||
|
||||
all_text_parts = []
|
||||
all_words = []
|
||||
all_text_parts: list[str] = []
|
||||
all_words: list[WordTiming] = []
|
||||
|
||||
raw_segments = vad_segment_generator(audio_array)
|
||||
filtered_segments = vad_segment_filter(raw_segments)
|
||||
batches = batch_segments(
|
||||
filtered_segments,
|
||||
VAD_CONFIG["batch_max_files"],
|
||||
speech_segments = batch_speech_segments(
|
||||
raw_segments,
|
||||
VAD_CONFIG["batch_max_duration"],
|
||||
)
|
||||
audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array)
|
||||
|
||||
batch_index = 0
|
||||
total_batches = max(
|
||||
1, int(total_duration / VAD_CONFIG["batch_max_duration"]) + 1
|
||||
)
|
||||
for batch in audio_segments:
|
||||
audio_segment = batch.audio
|
||||
results = transcribe_batch(self.model, [audio_segment])
|
||||
|
||||
for batch in batches:
|
||||
batch_index += 1
|
||||
audio_segments = [seg[2] for seg in batch]
|
||||
results = transcribe_batch(self.model, audio_segments)
|
||||
|
||||
for text, words in emit_results(
|
||||
for result in emit_results(
|
||||
results,
|
||||
batch,
|
||||
batch_index,
|
||||
total_batches,
|
||||
[batch],
|
||||
):
|
||||
if not text:
|
||||
if not result.text:
|
||||
continue
|
||||
all_text_parts.append(text)
|
||||
all_words.extend(words)
|
||||
all_text_parts.append(result.text)
|
||||
all_words.extend(result.words)
|
||||
|
||||
processed_duration += sum(len(seg[2]) / float(SAMPLERATE) for seg in batch)
|
||||
all_words = enforce_word_timing_constraints(all_words)
|
||||
|
||||
combined_text = " ".join(all_text_parts)
|
||||
return {"text": combined_text, "words": all_words}
|
||||
2
gpu/self_hosted/.env.example
Normal file
2
gpu/self_hosted/.env.example
Normal file
@@ -0,0 +1,2 @@
|
||||
REFLECTOR_GPU_APIKEY=
|
||||
HF_TOKEN=
|
||||
38
gpu/self_hosted/.gitignore
vendored
Normal file
38
gpu/self_hosted/.gitignore
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
cache/
|
||||
|
||||
# OS / Editor
|
||||
.DS_Store
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Env and secrets
|
||||
.env
|
||||
*.env
|
||||
*.secret
|
||||
HF_TOKEN
|
||||
REFLECTOR_GPU_APIKEY
|
||||
|
||||
# Virtual env / uv
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
uv/
|
||||
|
||||
# Build / dist
|
||||
build/
|
||||
dist/
|
||||
.eggs/
|
||||
*.egg-info/
|
||||
|
||||
# Coverage / test
|
||||
.pytest_cache/
|
||||
.coverage*
|
||||
htmlcov/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
46
gpu/self_hosted/Dockerfile
Normal file
46
gpu/self_hosted/Dockerfile
Normal file
@@ -0,0 +1,46 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
|
||||
WORKDIR /tmp
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
ffmpeg \
|
||||
curl \
|
||||
ca-certificates \
|
||||
gnupg \
|
||||
wget \
|
||||
&& apt-get clean
|
||||
# Add NVIDIA CUDA repo for Debian 12 (bookworm) and install cuDNN 9 for CUDA 12
|
||||
ADD https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb /cuda-keyring.deb
|
||||
RUN dpkg -i /cuda-keyring.deb \
|
||||
&& rm /cuda-keyring.deb \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-cudart-12-6 \
|
||||
libcublas-12-6 \
|
||||
libcudnn9-cuda-12 \
|
||||
libcudnn9-dev-cuda-12 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
ADD https://astral.sh/uv/install.sh /uv-installer.sh
|
||||
RUN sh /uv-installer.sh && rm /uv-installer.sh
|
||||
ENV PATH="/root/.local/bin/:$PATH"
|
||||
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH"
|
||||
|
||||
RUN mkdir -p /app
|
||||
WORKDIR /app
|
||||
COPY pyproject.toml uv.lock /app/
|
||||
|
||||
|
||||
COPY ./app /app/app
|
||||
COPY ./main.py /app/
|
||||
COPY ./runserver.sh /app/
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["sh", "/app/runserver.sh"]
|
||||
|
||||
|
||||
73
gpu/self_hosted/README.md
Normal file
73
gpu/self_hosted/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Self-hosted Model API
|
||||
|
||||
Run transcription, translation, and diarization services compatible with Reflector's GPU Model API. Works on CPU or GPU.
|
||||
|
||||
Environment variables
|
||||
|
||||
- REFLECTOR_GPU_APIKEY: Optional Bearer token. If unset, auth is disabled.
|
||||
- HF_TOKEN: Optional. Required for diarization to download pyannote pipelines
|
||||
|
||||
Requirements
|
||||
|
||||
- FFmpeg must be installed and on PATH (used for URL-based and segmented transcription)
|
||||
- Python 3.12+
|
||||
- NVIDIA GPU optional. If available, it will be used automatically
|
||||
|
||||
Local run
|
||||
Set env vars in self_hosted/.env file
|
||||
uv sync
|
||||
|
||||
uv run uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
|
||||
Authentication
|
||||
|
||||
- If REFLECTOR_GPU_APIKEY is set, include header: Authorization: Bearer <key>
|
||||
|
||||
Endpoints
|
||||
|
||||
- POST /v1/audio/transcriptions
|
||||
|
||||
- multipart/form-data
|
||||
- fields: file (single file) OR files[] (multiple files), language, batch (true/false)
|
||||
- response: single { text, words, filename } or { results: [ ... ] }
|
||||
|
||||
- POST /v1/audio/transcriptions-from-url
|
||||
|
||||
- application/json
|
||||
- body: { audio_file_url, language, timestamp_offset }
|
||||
- response: { text, words }
|
||||
|
||||
- POST /translate
|
||||
|
||||
- text: query parameter
|
||||
- body (application/json): { source_language, target_language }
|
||||
- response: { text: { <src>: original, <tgt>: translated } }
|
||||
|
||||
- POST /diarize
|
||||
- query parameters: audio_file_url, timestamp (optional)
|
||||
- requires HF_TOKEN to be set (for pyannote)
|
||||
- response: { diarization: [ { start, end, speaker } ] }
|
||||
|
||||
OpenAPI docs
|
||||
|
||||
- Visit /docs when the server is running
|
||||
|
||||
Docker
|
||||
|
||||
- Not yet provided in this directory. A Dockerfile will be added later. For now, use Local run above
|
||||
|
||||
Conformance tests
|
||||
|
||||
# From this directory
|
||||
|
||||
TRANSCRIPT_URL=http://localhost:8000 \
|
||||
TRANSCRIPT_API_KEY=dev-key \
|
||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_transcript.py
|
||||
|
||||
TRANSLATION_URL=http://localhost:8000 \
|
||||
TRANSLATION_API_KEY=dev-key \
|
||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_translation.py
|
||||
|
||||
DIARIZATION_URL=http://localhost:8000 \
|
||||
DIARIZATION_API_KEY=dev-key \
|
||||
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_diarization.py
|
||||
19
gpu/self_hosted/app/auth.py
Normal file
19
gpu/self_hosted/app/auth.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
required_key = os.environ.get("REFLECTOR_GPU_APIKEY")
|
||||
if not required_key:
|
||||
return
|
||||
if apikey == required_key:
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
12
gpu/self_hosted/app/config.py
Normal file
12
gpu/self_hosted/app/config.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
SAMPLE_RATE = 16000
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
# App-level paths
|
||||
UPLOADS_PATH = Path("/tmp/whisper-uploads")
|
||||
30
gpu/self_hosted/app/factory.py
Normal file
30
gpu/self_hosted/app/factory.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routers.diarization import router as diarization_router
|
||||
from .routers.transcription import router as transcription_router
|
||||
from .routers.translation import router as translation_router
|
||||
from .services.transcriber import WhisperService
|
||||
from .services.diarizer import PyannoteDiarizationService
|
||||
from .utils import ensure_dirs
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
ensure_dirs()
|
||||
whisper_service = WhisperService()
|
||||
whisper_service.load()
|
||||
app.state.whisper = whisper_service
|
||||
diarization_service = PyannoteDiarizationService()
|
||||
diarization_service.load()
|
||||
app.state.diarizer = diarization_service
|
||||
yield
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(transcription_router)
|
||||
app.include_router(translation_router)
|
||||
app.include_router(diarization_router)
|
||||
return app
|
||||
30
gpu/self_hosted/app/routers/diarization.py
Normal file
30
gpu/self_hosted/app/routers/diarization.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth import apikey_auth
|
||||
from ..services.diarizer import PyannoteDiarizationService
|
||||
from ..utils import download_audio_file
|
||||
|
||||
router = APIRouter(tags=["diarization"])
|
||||
|
||||
|
||||
class DiarizationSegment(BaseModel):
|
||||
start: float
|
||||
end: float
|
||||
speaker: int
|
||||
|
||||
|
||||
class DiarizationResponse(BaseModel):
|
||||
diarization: List[DiarizationSegment]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diarize", dependencies=[Depends(apikey_auth)], response_model=DiarizationResponse
|
||||
)
|
||||
def diarize(request: Request, audio_file_url: str, timestamp: float = 0.0):
|
||||
with download_audio_file(audio_file_url) as (file_path, _ext):
|
||||
file_path = str(file_path)
|
||||
diarizer: PyannoteDiarizationService = request.app.state.diarizer
|
||||
return diarizer.diarize_file(file_path, timestamp=timestamp)
|
||||
109
gpu/self_hosted/app/routers/transcription.py
Normal file
109
gpu/self_hosted/app/routers/transcription.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Form, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pathlib import Path
|
||||
from ..auth import apikey_auth
|
||||
from ..config import SUPPORTED_FILE_EXTENSIONS, UPLOADS_PATH
|
||||
from ..services.transcriber import MODEL_NAME
|
||||
from ..utils import cleanup_uploaded_files, download_audio_file
|
||||
|
||||
router = APIRouter(prefix="/v1/audio", tags=["transcription"])
|
||||
|
||||
|
||||
class WordTiming(BaseModel):
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
class TranscriptResult(BaseModel):
|
||||
text: str
|
||||
words: list[WordTiming]
|
||||
filename: Optional[str] = None
|
||||
|
||||
|
||||
class TranscriptBatchResponse(BaseModel):
|
||||
results: list[TranscriptResult]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcriptions",
|
||||
dependencies=[Depends(apikey_auth)],
|
||||
response_model=Union[TranscriptResult, TranscriptBatchResponse],
|
||||
)
|
||||
def transcribe(
|
||||
request: Request,
|
||||
file: UploadFile = None,
|
||||
files: list[UploadFile] | None = None,
|
||||
model: str = Form(MODEL_NAME),
|
||||
language: str = Form("en"),
|
||||
batch: bool = Form(False),
|
||||
):
|
||||
service = request.app.state.whisper
|
||||
if not file and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either 'file' or 'files' parameter is required"
|
||||
)
|
||||
if batch and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Batch transcription requires 'files'"
|
||||
)
|
||||
|
||||
upload_files = [file] if file else files
|
||||
|
||||
uploaded_paths: list[Path] = []
|
||||
with cleanup_uploaded_files(uploaded_paths):
|
||||
for upload_file in upload_files:
|
||||
audio_suffix = upload_file.filename.split(".")[-1].lower()
|
||||
if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path = UPLOADS_PATH / unique_filename
|
||||
with open(file_path, "wb") as f:
|
||||
content = upload_file.file.read()
|
||||
f.write(content)
|
||||
uploaded_paths.append(file_path)
|
||||
|
||||
if batch and len(upload_files) > 1:
|
||||
results = []
|
||||
for path in uploaded_paths:
|
||||
result = service.transcribe_file(str(path), language=language)
|
||||
result["filename"] = path.name
|
||||
results.append(result)
|
||||
return {"results": results}
|
||||
|
||||
results = []
|
||||
for path in uploaded_paths:
|
||||
result = service.transcribe_file(str(path), language=language)
|
||||
result["filename"] = path.name
|
||||
results.append(result)
|
||||
|
||||
return {"results": results} if len(results) > 1 else results[0]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcriptions-from-url",
|
||||
dependencies=[Depends(apikey_auth)],
|
||||
response_model=TranscriptResult,
|
||||
)
|
||||
def transcribe_from_url(
|
||||
request: Request,
|
||||
audio_file_url: str = Body(..., description="URL of the audio file to transcribe"),
|
||||
model: str = Body(MODEL_NAME),
|
||||
language: str = Body("en"),
|
||||
timestamp_offset: float = Body(0.0),
|
||||
):
|
||||
service = request.app.state.whisper
|
||||
with download_audio_file(audio_file_url) as (file_path, _ext):
|
||||
file_path = str(file_path)
|
||||
result = service.transcribe_vad_url_segment(
|
||||
file_path=file_path, timestamp_offset=timestamp_offset, language=language
|
||||
)
|
||||
return result
|
||||
28
gpu/self_hosted/app/routers/translation.py
Normal file
28
gpu/self_hosted/app/routers/translation.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth import apikey_auth
|
||||
from ..services.translator import TextTranslatorService
|
||||
|
||||
router = APIRouter(tags=["translation"])
|
||||
|
||||
translator = TextTranslatorService()
|
||||
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
text: Dict[str, str]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/translate",
|
||||
dependencies=[Depends(apikey_auth)],
|
||||
response_model=TranslationResponse,
|
||||
)
|
||||
def translate(
|
||||
text: str,
|
||||
source_language: str = Body("en"),
|
||||
target_language: str = Body("fr"),
|
||||
):
|
||||
return translator.translate(text, source_language, target_language)
|
||||
42
gpu/self_hosted/app/services/diarizer.py
Normal file
42
gpu/self_hosted/app/services/diarizer.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
|
||||
class PyannoteDiarizationService:
|
||||
def __init__(self):
|
||||
self._pipeline = None
|
||||
self._device = "cpu"
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def load(self):
|
||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self._pipeline = Pipeline.from_pretrained(
|
||||
"pyannote/speaker-diarization-3.1",
|
||||
use_auth_token=os.environ.get("HF_TOKEN"),
|
||||
)
|
||||
self._pipeline.to(torch.device(self._device))
|
||||
|
||||
def diarize_file(self, file_path: str, timestamp: float = 0.0) -> dict:
|
||||
if self._pipeline is None:
|
||||
self.load()
|
||||
waveform, sample_rate = torchaudio.load(file_path)
|
||||
with self._lock:
|
||||
diarization = self._pipeline(
|
||||
{"waveform": waveform, "sample_rate": sample_rate}
|
||||
)
|
||||
words = []
|
||||
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||
words.append(
|
||||
{
|
||||
"start": round(timestamp + diarization_segment.start, 3),
|
||||
"end": round(timestamp + diarization_segment.end, 3),
|
||||
"speaker": int(speaker[-2:])
|
||||
if speaker and speaker[-2:].isdigit()
|
||||
else 0,
|
||||
}
|
||||
)
|
||||
return {"diarization": words}
|
||||
208
gpu/self_hosted/app/services/transcriber.py
Normal file
208
gpu/self_hosted/app/services/transcriber.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Generator
|
||||
|
||||
import faster_whisper
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import HTTPException
|
||||
from silero_vad import VADIterator, load_silero_vad
|
||||
|
||||
from ..config import SAMPLE_RATE, VAD_CONFIG
|
||||
|
||||
# Whisper configuration (service-local defaults)
|
||||
MODEL_NAME = "large-v2"
|
||||
# None delegates compute type to runtime: float16 on CUDA, int8 on CPU
|
||||
MODEL_COMPUTE_TYPE = None
|
||||
MODEL_NUM_WORKERS = 1
|
||||
CACHE_PATH = os.path.join(os.path.expanduser("~"), ".cache", "reflector-whisper")
|
||||
from ..utils import NoStdStreams
|
||||
|
||||
|
||||
class WhisperService:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.device = "cpu"
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def load(self):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
compute_type = MODEL_COMPUTE_TYPE or (
|
||||
"float16" if self.device == "cuda" else "int8"
|
||||
)
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=compute_type,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=CACHE_PATH,
|
||||
)
|
||||
|
||||
def pad_audio(self, audio_array, sample_rate: int = SAMPLE_RATE):
|
||||
audio_duration = len(audio_array) / sample_rate
|
||||
if audio_duration < VAD_CONFIG["silence_padding"]:
|
||||
silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio_array, silence])
|
||||
return audio_array
|
||||
|
||||
def enforce_word_timing_constraints(self, words: list[dict]) -> list[dict]:
|
||||
if len(words) <= 1:
|
||||
return words
|
||||
enforced: list[dict] = []
|
||||
for i, word in enumerate(words):
|
||||
current = dict(word)
|
||||
if i < len(words) - 1:
|
||||
next_start = words[i + 1]["start"]
|
||||
if current["end"] > next_start:
|
||||
current["end"] = next_start
|
||||
enforced.append(current)
|
||||
return enforced
|
||||
|
||||
def transcribe_file(self, file_path: str, language: str = "en") -> dict:
|
||||
input_for_model: str | "object" = file_path
|
||||
try:
|
||||
audio_array, _sample_rate = librosa.load(
|
||||
file_path, sr=SAMPLE_RATE, mono=True
|
||||
)
|
||||
if len(audio_array) / float(SAMPLE_RATE) < VAD_CONFIG["silence_padding"]:
|
||||
input_for_model = self.pad_audio(audio_array, SAMPLE_RATE)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with self.lock:
|
||||
with NoStdStreams():
|
||||
segments, _ = self.model.transcribe(
|
||||
input_for_model,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(segment.text for segment in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": word.word,
|
||||
"start": round(float(word.start), 2),
|
||||
"end": round(float(word.end), 2),
|
||||
}
|
||||
for segment in segments
|
||||
for word in segment.words
|
||||
]
|
||||
words = self.enforce_word_timing_constraints(words)
|
||||
return {"text": text, "words": words}
|
||||
|
||||
def transcribe_vad_url_segment(
|
||||
self, file_path: str, timestamp_offset: float = 0.0, language: str = "en"
|
||||
) -> dict:
|
||||
def load_audio_via_ffmpeg(input_path: str, sample_rate: int) -> np.ndarray:
|
||||
ffmpeg_bin = shutil.which("ffmpeg") or "ffmpeg"
|
||||
cmd = [
|
||||
ffmpeg_bin,
|
||||
"-nostdin",
|
||||
"-threads",
|
||||
"1",
|
||||
"-i",
|
||||
input_path,
|
||||
"-f",
|
||||
"f32le",
|
||||
"-acodec",
|
||||
"pcm_f32le",
|
||||
"-ac",
|
||||
"1",
|
||||
"-ar",
|
||||
str(sample_rate),
|
||||
"pipe:1",
|
||||
]
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"ffmpeg failed: {e}")
|
||||
audio = np.frombuffer(proc.stdout, dtype=np.float32)
|
||||
return audio
|
||||
|
||||
def vad_segments(
|
||||
audio_array,
|
||||
sample_rate: int = SAMPLE_RATE,
|
||||
window_size: int = VAD_CONFIG["window_size"],
|
||||
) -> Generator[tuple[float, float], None, None]:
|
||||
vad_model = load_silero_vad(onnx=False)
|
||||
iterator = VADIterator(vad_model, sampling_rate=sample_rate)
|
||||
start = None
|
||||
for i in range(0, len(audio_array), window_size):
|
||||
chunk = audio_array[i : i + window_size]
|
||||
if len(chunk) < window_size:
|
||||
chunk = np.pad(
|
||||
chunk, (0, window_size - len(chunk)), mode="constant"
|
||||
)
|
||||
speech = iterator(chunk)
|
||||
if not speech:
|
||||
continue
|
||||
if "start" in speech:
|
||||
start = speech["start"]
|
||||
continue
|
||||
if "end" in speech and start is not None:
|
||||
end = speech["end"]
|
||||
yield (start / float(SAMPLE_RATE), end / float(SAMPLE_RATE))
|
||||
start = None
|
||||
iterator.reset_states()
|
||||
|
||||
audio_array = load_audio_via_ffmpeg(file_path, SAMPLE_RATE)
|
||||
|
||||
merged_batches: list[tuple[float, float]] = []
|
||||
batch_start = None
|
||||
batch_end = None
|
||||
max_duration = VAD_CONFIG["batch_max_duration"]
|
||||
for seg_start, seg_end in vad_segments(audio_array):
|
||||
if batch_start is None:
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
continue
|
||||
if seg_end - batch_start <= max_duration:
|
||||
batch_end = seg_end
|
||||
else:
|
||||
merged_batches.append((batch_start, batch_end))
|
||||
batch_start, batch_end = seg_start, seg_end
|
||||
if batch_start is not None and batch_end is not None:
|
||||
merged_batches.append((batch_start, batch_end))
|
||||
|
||||
all_text = []
|
||||
all_words = []
|
||||
for start_time, end_time in merged_batches:
|
||||
s_idx = int(start_time * SAMPLE_RATE)
|
||||
e_idx = int(end_time * SAMPLE_RATE)
|
||||
segment = audio_array[s_idx:e_idx]
|
||||
segment = self.pad_audio(segment, SAMPLE_RATE)
|
||||
with self.lock:
|
||||
segments, _ = self.model.transcribe(
|
||||
segment,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
segments = list(segments)
|
||||
text = "".join(seg.text for seg in segments).strip()
|
||||
words = [
|
||||
{
|
||||
"word": w.word,
|
||||
"start": round(float(w.start) + start_time + timestamp_offset, 2),
|
||||
"end": round(float(w.end) + start_time + timestamp_offset, 2),
|
||||
}
|
||||
for seg in segments
|
||||
for w in seg.words
|
||||
]
|
||||
if text:
|
||||
all_text.append(text)
|
||||
all_words.extend(words)
|
||||
|
||||
all_words = self.enforce_word_timing_constraints(all_words)
|
||||
return {"text": " ".join(all_text), "words": all_words}
|
||||
44
gpu/self_hosted/app/services/translator.py
Normal file
44
gpu/self_hosted/app/services/translator.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import threading
|
||||
|
||||
from transformers import MarianMTModel, MarianTokenizer, pipeline
|
||||
|
||||
|
||||
class TextTranslatorService:
|
||||
"""Simple text-to-text translator using HuggingFace MarianMT models.
|
||||
|
||||
This mirrors the modal translator API shape but uses text translation only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._pipeline = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def load(self, source_language: str = "en", target_language: str = "fr"):
|
||||
# Pick a default MarianMT model pair if available; fall back to Helsinki-NLP en->fr
|
||||
model_name = self._resolve_model_name(source_language, target_language)
|
||||
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
model = MarianMTModel.from_pretrained(model_name)
|
||||
self._pipeline = pipeline("translation", model=model, tokenizer=tokenizer)
|
||||
|
||||
def _resolve_model_name(self, src: str, tgt: str) -> str:
|
||||
# Minimal mapping; extend as needed
|
||||
pair = (src.lower(), tgt.lower())
|
||||
mapping = {
|
||||
("en", "fr"): "Helsinki-NLP/opus-mt-en-fr",
|
||||
("fr", "en"): "Helsinki-NLP/opus-mt-fr-en",
|
||||
("en", "es"): "Helsinki-NLP/opus-mt-en-es",
|
||||
("es", "en"): "Helsinki-NLP/opus-mt-es-en",
|
||||
("en", "de"): "Helsinki-NLP/opus-mt-en-de",
|
||||
("de", "en"): "Helsinki-NLP/opus-mt-de-en",
|
||||
}
|
||||
return mapping.get(pair, "Helsinki-NLP/opus-mt-en-fr")
|
||||
|
||||
def translate(self, text: str, source_language: str, target_language: str) -> dict:
|
||||
if self._pipeline is None:
|
||||
self.load(source_language, target_language)
|
||||
with self._lock:
|
||||
results = self._pipeline(
|
||||
text, src_lang=source_language, tgt_lang=target_language
|
||||
)
|
||||
translated = results[0]["translation_text"] if results else ""
|
||||
return {"text": {source_language: text, target_language: translated}}
|
||||
107
gpu/self_hosted/app/utils.py
Normal file
107
gpu/self_hosted/app/utils.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Mapping
|
||||
from urllib.parse import urlparse
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .config import SUPPORTED_FILE_EXTENSIONS, UPLOADS_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NoStdStreams:
|
||||
def __init__(self):
|
||||
self.devnull = open(os.devnull, "w")
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout, self._stderr = sys.stdout, sys.stderr
|
||||
self._stdout.flush()
|
||||
self._stderr.flush()
|
||||
sys.stdout, sys.stderr = self.devnull, self.devnull
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
sys.stdout, sys.stderr = self._stdout, self._stderr
|
||||
self.devnull.close()
|
||||
|
||||
|
||||
def ensure_dirs():
|
||||
UPLOADS_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> str:
|
||||
url_path = urlparse(url).path
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return ext
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return "mp3"
|
||||
if "audio/wav" in content_type:
|
||||
return "wav"
|
||||
if "audio/mp4" in content_type:
|
||||
return "mp4"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_uploads(audio_file_url: str) -> tuple[Path, str]:
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
|
||||
file_path: Path = UPLOADS_PATH / unique_filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return file_path, audio_suffix
|
||||
|
||||
|
||||
@contextmanager
|
||||
def download_audio_file(audio_file_url: str):
|
||||
"""Download an audio file to UPLOADS_PATH and remove it after use.
|
||||
|
||||
Yields (file_path: Path, audio_suffix: str).
|
||||
"""
|
||||
file_path, audio_suffix = download_audio_to_uploads(audio_file_url)
|
||||
try:
|
||||
yield file_path, audio_suffix
|
||||
finally:
|
||||
try:
|
||||
file_path.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
logger.error("Error deleting temporary file %s: %s", file_path, e)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cleanup_uploaded_files(file_paths: list[Path]):
|
||||
"""Ensure provided file paths are removed after use.
|
||||
|
||||
The provided list can be populated inside the context; all present entries
|
||||
at exit will be deleted.
|
||||
"""
|
||||
try:
|
||||
yield file_paths
|
||||
finally:
|
||||
for path in list(file_paths):
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
logger.error("Error deleting temporary file %s: %s", path, e)
|
||||
10
gpu/self_hosted/compose.yml
Normal file
10
gpu/self_hosted/compose.yml
Normal file
@@ -0,0 +1,10 @@
|
||||
services:
|
||||
reflector_gpu:
|
||||
build:
|
||||
context: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./cache:/root/.cache
|
||||
3
gpu/self_hosted/main.py
Normal file
3
gpu/self_hosted/main.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.factory import create_app
|
||||
|
||||
app = create_app()
|
||||
19
gpu/self_hosted/pyproject.toml
Normal file
19
gpu/self_hosted/pyproject.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
[project]
|
||||
name = "reflector-gpu"
|
||||
version = "0.1.0"
|
||||
description = "Self-hosted GPU service for speech transcription, diarization, and translation via FastAPI."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi[standard]>=0.116.1",
|
||||
"uvicorn[standard]>=0.30.0",
|
||||
"torch>=2.3.0",
|
||||
"faster-whisper>=1.1.0",
|
||||
"librosa==0.10.1",
|
||||
"numpy<2",
|
||||
"silero-vad==5.1.0",
|
||||
"transformers>=4.35.0",
|
||||
"sentencepiece",
|
||||
"pyannote.audio==3.1.0",
|
||||
"torchaudio>=2.3.0",
|
||||
]
|
||||
17
gpu/self_hosted/runserver.sh
Normal file
17
gpu/self_hosted/runserver.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
#!/bin/sh
|
||||
set -e
|
||||
|
||||
export PATH="/root/.local/bin:$PATH"
|
||||
cd /app
|
||||
|
||||
# Install Python dependencies at runtime (first run or when FORCE_SYNC=1)
|
||||
if [ ! -d "/app/.venv" ] || [ "$FORCE_SYNC" = "1" ]; then
|
||||
echo "[startup] Installing Python dependencies with uv..."
|
||||
uv sync --compile-bytecode --locked
|
||||
else
|
||||
echo "[startup] Using existing virtual environment at /app/.venv"
|
||||
fi
|
||||
|
||||
exec uv run uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
|
||||
|
||||
3013
gpu/self_hosted/uv.lock
generated
Normal file
3013
gpu/self_hosted/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
118
server/asyncio_loop_analysis.md
Normal file
118
server/asyncio_loop_analysis.md
Normal file
@@ -0,0 +1,118 @@
|
||||
# AsyncIO Event Loop Analysis for test_attendee_parsing_bug.py
|
||||
|
||||
## Problem Summary
|
||||
The test passes but encounters an error during teardown where asyncpg tries to use a different/closed event loop, resulting in:
|
||||
- `RuntimeError: Task got Future attached to a different loop`
|
||||
- `RuntimeError: Event loop is closed`
|
||||
|
||||
## Root Cause Analysis
|
||||
|
||||
### 1. Multiple Event Loop Creation Points
|
||||
|
||||
The test environment creates event loops at different scopes:
|
||||
|
||||
1. **Session-scoped loop** (conftest.py:27-34):
|
||||
- Created once per test session
|
||||
- Used by session-scoped fixtures
|
||||
- Closed after all tests complete
|
||||
|
||||
2. **Function-scoped loop** (pytest-asyncio default):
|
||||
- Created for each async test function
|
||||
- This is the loop that runs the actual test
|
||||
- Closed immediately after test completes
|
||||
|
||||
3. **AsyncPG internal loop**:
|
||||
- AsyncPG connections store a reference to the loop they were created with
|
||||
- Used for connection lifecycle management
|
||||
|
||||
### 2. Event Loop Lifecycle Mismatch
|
||||
|
||||
The issue occurs because:
|
||||
|
||||
1. **Session fixture creates database connection** on session-scoped loop
|
||||
2. **Test runs** on function-scoped loop (different from session loop)
|
||||
3. **During teardown**, the session fixture tries to rollback/close using the original session loop
|
||||
4. **AsyncPG connection** still references the function-scoped loop which is now closed
|
||||
5. **Conflict**: SQLAlchemy tries to use session loop, but asyncpg Future is attached to the closed function loop
|
||||
|
||||
### 3. Configuration Issues
|
||||
|
||||
Current pytest configuration:
|
||||
- `asyncio_mode = "auto"` in pyproject.toml
|
||||
- `asyncio_default_fixture_loop_scope=session` (shown in test output)
|
||||
- `asyncio_default_test_loop_scope=function` (shown in test output)
|
||||
|
||||
This mismatch between fixture loop scope (session) and test loop scope (function) causes the problem.
|
||||
|
||||
## Solutions
|
||||
|
||||
### Option 1: Align Loop Scopes (Recommended)
|
||||
Change pytest-asyncio configuration to use consistent loop scopes:
|
||||
|
||||
```python
|
||||
# pyproject.toml
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function" # Change from session to function
|
||||
```
|
||||
|
||||
### Option 2: Use Function-Scoped Database Fixture
|
||||
Change the `session` fixture scope from session to function:
|
||||
|
||||
```python
|
||||
@pytest_asyncio.fixture # Remove scope="session"
|
||||
async def session(setup_database):
|
||||
# ... existing code ...
|
||||
```
|
||||
|
||||
### Option 3: Explicit Loop Management
|
||||
Ensure all async operations use the same loop:
|
||||
|
||||
```python
|
||||
@pytest_asyncio.fixture
|
||||
async def session(setup_database, event_loop):
|
||||
# Force using the current event loop
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=False,
|
||||
poolclass=NullPool,
|
||||
connect_args={"loop": event_loop} # Pass explicit loop
|
||||
)
|
||||
# ... rest of fixture ...
|
||||
```
|
||||
|
||||
### Option 4: Upgrade pytest-asyncio
|
||||
The current version (1.1.0) has known issues with loop management. Consider upgrading to the latest version which has better loop scope handling.
|
||||
|
||||
## Immediate Workaround
|
||||
|
||||
For the test to run cleanly without the teardown error, you can:
|
||||
|
||||
1. Add explicit cleanup in the test:
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_attendee_parsing_bug(session):
|
||||
# ... existing test code ...
|
||||
|
||||
# Explicit cleanup before fixture teardown
|
||||
await session.commit() # or await session.close()
|
||||
```
|
||||
|
||||
2. Or suppress the teardown error (not recommended for production):
|
||||
```python
|
||||
@pytest.fixture
|
||||
async def session(setup_database):
|
||||
# ... existing setup ...
|
||||
try:
|
||||
yield session
|
||||
await session.rollback()
|
||||
except RuntimeError as e:
|
||||
if "Event loop is closed" not in str(e):
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
```
|
||||
|
||||
## Recommendation
|
||||
|
||||
The cleanest solution is to align the loop scopes by setting both fixture and test loop scopes to "function" scope. This ensures each test gets its own clean event loop and avoids cross-contamination between tests.
|
||||
95
server/docs/data_retention.md
Normal file
95
server/docs/data_retention.md
Normal file
@@ -0,0 +1,95 @@
|
||||
# Data Retention and Cleanup
|
||||
|
||||
## Overview
|
||||
|
||||
For public instances of Reflector, a data retention policy is automatically enforced to delete anonymous user data after a configurable period (default: 7 days). This ensures compliance with privacy expectations and prevents unbounded storage growth.
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
- `PUBLIC_MODE` (bool): Must be set to `true` to enable automatic cleanup
|
||||
- `PUBLIC_DATA_RETENTION_DAYS` (int): Number of days to retain anonymous data (default: 7)
|
||||
|
||||
### What Gets Deleted
|
||||
|
||||
When data reaches the retention period, the following items are automatically removed:
|
||||
|
||||
1. **Transcripts** from anonymous users (where `user_id` is NULL):
|
||||
- Database records
|
||||
- Local files (audio.wav, audio.mp3, audio.json waveform)
|
||||
- Storage files (cloud storage if configured)
|
||||
|
||||
## Automatic Cleanup
|
||||
|
||||
### Celery Beat Schedule
|
||||
|
||||
When `PUBLIC_MODE=true`, a Celery beat task runs daily at 3 AM to clean up old data:
|
||||
|
||||
```python
|
||||
# Automatically scheduled when PUBLIC_MODE=true
|
||||
"cleanup_old_public_data": {
|
||||
"task": "reflector.worker.cleanup.cleanup_old_public_data",
|
||||
"schedule": crontab(hour=3, minute=0), # Daily at 3 AM
|
||||
}
|
||||
```
|
||||
|
||||
### Running the Worker
|
||||
|
||||
Ensure both Celery worker and beat scheduler are running:
|
||||
|
||||
```bash
|
||||
# Start Celery worker
|
||||
uv run celery -A reflector.worker.app worker --loglevel=info
|
||||
|
||||
# Start Celery beat scheduler (in another terminal)
|
||||
uv run celery -A reflector.worker.app beat
|
||||
```
|
||||
|
||||
## Manual Cleanup
|
||||
|
||||
For testing or manual intervention, use the cleanup tool:
|
||||
|
||||
```bash
|
||||
# Delete data older than 7 days (default)
|
||||
uv run python -m reflector.tools.cleanup_old_data
|
||||
|
||||
# Delete data older than 30 days
|
||||
uv run python -m reflector.tools.cleanup_old_data --days 30
|
||||
```
|
||||
|
||||
Note: The manual tool uses the same implementation as the Celery worker task to ensure consistency.
|
||||
|
||||
## Important Notes
|
||||
|
||||
1. **User Data Deletion**: Only anonymous data (where `user_id` is NULL) is deleted. Authenticated user data is preserved.
|
||||
|
||||
2. **Storage Cleanup**: The system properly cleans up both local files and cloud storage when configured.
|
||||
|
||||
3. **Error Handling**: If individual deletions fail, the cleanup continues and logs errors. Failed deletions are reported in the task output.
|
||||
|
||||
4. **Public Instance Only**: The automatic cleanup task only runs when `PUBLIC_MODE=true` to prevent accidental data loss in private deployments.
|
||||
|
||||
## Testing
|
||||
|
||||
Run the cleanup tests:
|
||||
|
||||
```bash
|
||||
uv run pytest tests/test_cleanup.py -v
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
Check Celery logs for cleanup task execution:
|
||||
|
||||
```bash
|
||||
# Look for cleanup task logs
|
||||
grep "cleanup_old_public_data" celery.log
|
||||
grep "Starting cleanup of old public data" celery.log
|
||||
```
|
||||
|
||||
Task statistics are logged after each run:
|
||||
- Number of transcripts deleted
|
||||
- Number of meetings deleted
|
||||
- Number of orphaned recordings deleted
|
||||
- Any errors encountered
|
||||
194
server/docs/gpu/api-transcription.md
Normal file
194
server/docs/gpu/api-transcription.md
Normal file
@@ -0,0 +1,194 @@
|
||||
## Reflector GPU Transcription API (Specification)
|
||||
|
||||
This document defines the Reflector GPU transcription API that all implementations must adhere to. Current implementations include NVIDIA Parakeet (NeMo) and Whisper (faster-whisper), both deployed on Modal.com. The API surface and response shapes are OpenAI/Whisper-compatible, so clients can switch implementations by changing only the base URL.
|
||||
|
||||
### Base URL and Authentication
|
||||
|
||||
- Example base URLs (Modal web endpoints):
|
||||
|
||||
- Parakeet: `https://<account>--reflector-transcriber-parakeet-web.modal.run`
|
||||
- Whisper: `https://<account>--reflector-transcriber-web.modal.run`
|
||||
|
||||
- All endpoints are served under `/v1` and require a Bearer token:
|
||||
|
||||
```
|
||||
Authorization: Bearer <REFLECTOR_GPU_APIKEY>
|
||||
```
|
||||
|
||||
Note: To switch implementations, deploy the desired variant and point `TRANSCRIPT_URL` to its base URL. The API is identical.
|
||||
|
||||
### Supported file types
|
||||
|
||||
`mp3, mp4, mpeg, mpga, m4a, wav, webm`
|
||||
|
||||
### Models and languages
|
||||
|
||||
- Parakeet (NVIDIA NeMo): default `nvidia/parakeet-tdt-0.6b-v2`
|
||||
- Language support: only `en`. Other languages return HTTP 400.
|
||||
- Whisper (faster-whisper): default `large-v2` (or deployment-specific)
|
||||
- Language support: multilingual (per Whisper model capabilities).
|
||||
|
||||
Note: The `model` parameter is accepted by all implementations for interface parity. Some backends may treat it as informational.
|
||||
|
||||
### Endpoints
|
||||
|
||||
#### POST /v1/audio/transcriptions
|
||||
|
||||
Transcribe one or more uploaded audio files.
|
||||
|
||||
Request: multipart/form-data
|
||||
|
||||
- `file` (File) — optional. Single file to transcribe.
|
||||
- `files` (File[]) — optional. One or more files to transcribe.
|
||||
- `model` (string) — optional. Defaults to the implementation-specific model (see above).
|
||||
- `language` (string) — optional, defaults to `en`.
|
||||
- Parakeet: only `en` is accepted; other values return HTTP 400
|
||||
- Whisper: model-dependent; typically multilingual
|
||||
- `batch` (boolean) — optional, defaults to `false`.
|
||||
|
||||
Notes:
|
||||
|
||||
- Provide either `file` or `files`, not both. If neither is provided, HTTP 400.
|
||||
- `batch` requires `files`; using `batch=true` without `files` returns HTTP 400.
|
||||
- Response shape for multiple files is the same regardless of `batch`.
|
||||
- Files sent to this endpoint are processed in a single pass (no VAD/chunking). This is intended for short clips (roughly ≤ 30s; depends on GPU memory/model). For longer audio, prefer `/v1/audio/transcriptions-from-url` which supports VAD-based chunking.
|
||||
|
||||
Responses
|
||||
|
||||
Single file response:
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "transcribed text",
|
||||
"words": [
|
||||
{ "word": "hello", "start": 0.0, "end": 0.5 },
|
||||
{ "word": "world", "start": 0.5, "end": 1.0 }
|
||||
],
|
||||
"filename": "audio.mp3"
|
||||
}
|
||||
```
|
||||
|
||||
Multiple files response:
|
||||
|
||||
```json
|
||||
{
|
||||
"results": [
|
||||
{"filename": "a1.mp3", "text": "...", "words": [...]},
|
||||
{"filename": "a2.mp3", "text": "...", "words": [...]}]
|
||||
}
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- Word objects always include keys: `word`, `start`, `end`.
|
||||
- Some implementations may include a trailing space in `word` to match Whisper tokenization behavior; clients should trim if needed.
|
||||
|
||||
Example curl (single file):
|
||||
|
||||
```bash
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
|
||||
-F "file=@/path/to/audio.mp3" \
|
||||
-F "language=en" \
|
||||
"$BASE_URL/v1/audio/transcriptions"
|
||||
```
|
||||
|
||||
Example curl (multiple files, batch):
|
||||
|
||||
```bash
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
|
||||
-F "files=@/path/a1.mp3" -F "files=@/path/a2.mp3" \
|
||||
-F "batch=true" -F "language=en" \
|
||||
"$BASE_URL/v1/audio/transcriptions"
|
||||
```
|
||||
|
||||
#### POST /v1/audio/transcriptions-from-url
|
||||
|
||||
Transcribe a single remote audio file by URL.
|
||||
|
||||
Request: application/json
|
||||
|
||||
Body parameters:
|
||||
|
||||
- `audio_file_url` (string) — required. URL of the audio file to transcribe.
|
||||
- `model` (string) — optional. Defaults to the implementation-specific model (see above).
|
||||
- `language` (string) — optional, defaults to `en`. Parakeet only accepts `en`.
|
||||
- `timestamp_offset` (number) — optional, defaults to `0.0`. Added to each word's `start`/`end` in the response.
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_file_url": "https://example.com/audio.mp3",
|
||||
"model": "nvidia/parakeet-tdt-0.6b-v2",
|
||||
"language": "en",
|
||||
"timestamp_offset": 0.0
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "transcribed text",
|
||||
"words": [
|
||||
{ "word": "hello", "start": 10.0, "end": 10.5 },
|
||||
{ "word": "world", "start": 10.5, "end": 11.0 }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- `timestamp_offset` is added to each word’s `start`/`end` in the response.
|
||||
- Implementations may perform VAD-based chunking and batching for long-form audio; word timings are adjusted accordingly.
|
||||
|
||||
Example curl:
|
||||
|
||||
```bash
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $REFLECTOR_GPU_APIKEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"audio_file_url": "https://example.com/audio.mp3",
|
||||
"language": "en",
|
||||
"timestamp_offset": 0
|
||||
}' \
|
||||
"$BASE_URL/v1/audio/transcriptions-from-url"
|
||||
```
|
||||
|
||||
### Error handling
|
||||
|
||||
- 400 Bad Request
|
||||
- Parakeet: `language` other than `en`
|
||||
- Missing required parameters (`file`/`files` for upload; `audio_file_url` for URL endpoint)
|
||||
- Unsupported file extension
|
||||
- 401 Unauthorized
|
||||
- Missing or invalid Bearer token
|
||||
- 404 Not Found
|
||||
- `audio_file_url` does not exist
|
||||
|
||||
### Implementation details
|
||||
|
||||
- GPUs: A10G for small-file/live, L40S for large-file URL transcription (subject to deployment)
|
||||
- VAD chunking and segment batching; word timings adjusted and overlapping ends constrained
|
||||
- Pads very short segments (< 0.5s) to avoid model crashes on some backends
|
||||
|
||||
### Server configuration (Reflector API)
|
||||
|
||||
Set the Reflector server to use the Modal backend and point `TRANSCRIPT_URL` to your chosen deployment:
|
||||
|
||||
```
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://<account>--reflector-transcriber-parakeet-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=<REFLECTOR_GPU_APIKEY>
|
||||
```
|
||||
|
||||
### Conformance tests
|
||||
|
||||
Use the pytest-based conformance tests to validate any new implementation (including self-hosted) against this spec:
|
||||
|
||||
```
|
||||
TRANSCRIPT_URL=https://<your-deployment-base> \
|
||||
TRANSCRIPT_MODAL_API_KEY=your-api-key \
|
||||
uv run -m pytest -m model_api --no-cov server/tests/test_model_api_transcript.py
|
||||
```
|
||||
212
server/docs/webhook.md
Normal file
212
server/docs/webhook.md
Normal file
@@ -0,0 +1,212 @@
|
||||
# Reflector Webhook Documentation
|
||||
|
||||
## Overview
|
||||
|
||||
Reflector supports webhook notifications to notify external systems when transcript processing is completed. Webhooks can be configured per room and are triggered automatically after a transcript is successfully processed.
|
||||
|
||||
## Configuration
|
||||
|
||||
Webhooks are configured at the room level with two fields:
|
||||
- `webhook_url`: The HTTPS endpoint to receive webhook notifications
|
||||
- `webhook_secret`: Optional secret key for HMAC signature verification (auto-generated if not provided)
|
||||
|
||||
## Events
|
||||
|
||||
### `transcript.completed`
|
||||
|
||||
Triggered when a transcript has been fully processed, including transcription, diarization, summarization, and topic detection.
|
||||
|
||||
### `test`
|
||||
|
||||
A test event that can be triggered manually to verify webhook configuration.
|
||||
|
||||
## Webhook Request Format
|
||||
|
||||
### Headers
|
||||
|
||||
All webhook requests include the following headers:
|
||||
|
||||
| Header | Description | Example |
|
||||
|--------|-------------|---------|
|
||||
| `Content-Type` | Always `application/json` | `application/json` |
|
||||
| `User-Agent` | Identifies Reflector as the source | `Reflector-Webhook/1.0` |
|
||||
| `X-Webhook-Event` | The event type | `transcript.completed` or `test` |
|
||||
| `X-Webhook-Retry` | Current retry attempt number | `0`, `1`, `2`... |
|
||||
| `X-Webhook-Signature` | HMAC signature (if secret configured) | `t=1735306800,v1=abc123...` |
|
||||
|
||||
### Signature Verification
|
||||
|
||||
If a webhook secret is configured, Reflector includes an HMAC-SHA256 signature in the `X-Webhook-Signature` header to verify the webhook authenticity.
|
||||
|
||||
The signature format is: `t={timestamp},v1={signature}`
|
||||
|
||||
To verify the signature:
|
||||
1. Extract the timestamp and signature from the header
|
||||
2. Create the signed payload: `{timestamp}.{request_body}`
|
||||
3. Compute HMAC-SHA256 of the signed payload using your webhook secret
|
||||
4. Compare the computed signature with the received signature
|
||||
|
||||
Example verification (Python):
|
||||
```python
|
||||
import hmac
|
||||
import hashlib
|
||||
|
||||
def verify_webhook_signature(payload: bytes, signature_header: str, secret: str) -> bool:
|
||||
# Parse header: "t=1735306800,v1=abc123..."
|
||||
parts = dict(part.split("=") for part in signature_header.split(","))
|
||||
timestamp = parts["t"]
|
||||
received_signature = parts["v1"]
|
||||
|
||||
# Create signed payload
|
||||
signed_payload = f"{timestamp}.{payload.decode('utf-8')}"
|
||||
|
||||
# Compute expected signature
|
||||
expected_signature = hmac.new(
|
||||
secret.encode("utf-8"),
|
||||
signed_payload.encode("utf-8"),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Compare signatures
|
||||
return hmac.compare_digest(expected_signature, received_signature)
|
||||
```
|
||||
|
||||
## Event Payloads
|
||||
|
||||
### `transcript.completed` Event
|
||||
|
||||
This event includes a convenient URL for accessing the transcript:
|
||||
- `frontend_url`: Direct link to view the transcript in the web interface
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "transcript.completed",
|
||||
"event_id": "transcript.completed-abc-123-def-456",
|
||||
"timestamp": "2025-08-27T12:34:56.789012Z",
|
||||
"transcript": {
|
||||
"id": "abc-123-def-456",
|
||||
"room_id": "room-789",
|
||||
"created_at": "2025-08-27T12:00:00Z",
|
||||
"duration": 1800.5,
|
||||
"title": "Q3 Product Planning Meeting",
|
||||
"short_summary": "Team discussed Q3 product roadmap, prioritizing mobile app features and API improvements.",
|
||||
"long_summary": "The product team met to finalize the Q3 roadmap. Key decisions included...",
|
||||
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\n<v Speaker 1>Welcome everyone to today's meeting...",
|
||||
"topics": [
|
||||
{
|
||||
"title": "Introduction and Agenda",
|
||||
"summary": "Meeting kickoff with agenda review",
|
||||
"timestamp": 0.0,
|
||||
"duration": 120.0,
|
||||
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\n<v Speaker 1>Welcome everyone..."
|
||||
},
|
||||
{
|
||||
"title": "Mobile App Features Discussion",
|
||||
"summary": "Team reviewed proposed mobile app features for Q3",
|
||||
"timestamp": 120.0,
|
||||
"duration": 600.0,
|
||||
"webvtt": "WEBVTT\n\n00:02:00.000 --> 00:02:10.000\n<v Speaker 2>Let's talk about the mobile app..."
|
||||
}
|
||||
],
|
||||
"participants": [
|
||||
{
|
||||
"id": "participant-1",
|
||||
"name": "John Doe",
|
||||
"speaker": "Speaker 1"
|
||||
},
|
||||
{
|
||||
"id": "participant-2",
|
||||
"name": "Jane Smith",
|
||||
"speaker": "Speaker 2"
|
||||
}
|
||||
],
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"status": "completed",
|
||||
"frontend_url": "https://app.reflector.com/transcripts/abc-123-def-456"
|
||||
},
|
||||
"room": {
|
||||
"id": "room-789",
|
||||
"name": "Product Team Room"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### `test` Event
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "test",
|
||||
"event_id": "test.2025-08-27T12:34:56.789012Z",
|
||||
"timestamp": "2025-08-27T12:34:56.789012Z",
|
||||
"message": "This is a test webhook from Reflector",
|
||||
"room": {
|
||||
"id": "room-789",
|
||||
"name": "Product Team Room"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Retry Policy
|
||||
|
||||
Webhooks are delivered with automatic retry logic to handle transient failures. When a webhook delivery fails due to server errors or network issues, Reflector will automatically retry the delivery multiple times over an extended period.
|
||||
|
||||
### Retry Mechanism
|
||||
|
||||
Reflector implements an exponential backoff strategy for webhook retries:
|
||||
|
||||
- **Initial retry delay**: 60 seconds after the first failure
|
||||
- **Exponential backoff**: Each subsequent retry waits approximately twice as long as the previous one
|
||||
- **Maximum retry interval**: 1 hour (backoff is capped at this duration)
|
||||
- **Maximum retry attempts**: 30 attempts total
|
||||
- **Total retry duration**: Retries continue for approximately 24 hours
|
||||
|
||||
### How Retries Work
|
||||
|
||||
When a webhook fails, Reflector will:
|
||||
1. Wait 60 seconds, then retry (attempt #1)
|
||||
2. If it fails again, wait ~2 minutes, then retry (attempt #2)
|
||||
3. Continue doubling the wait time up to a maximum of 1 hour between attempts
|
||||
4. Keep retrying at 1-hour intervals until successful or 30 attempts are exhausted
|
||||
|
||||
The `X-Webhook-Retry` header indicates the current retry attempt number (0 for the initial attempt, 1 for first retry, etc.), allowing your endpoint to track retry attempts.
|
||||
|
||||
### Retry Behavior by HTTP Status Code
|
||||
|
||||
| Status Code | Behavior |
|
||||
|-------------|----------|
|
||||
| 2xx (Success) | No retry, webhook marked as delivered |
|
||||
| 4xx (Client Error) | No retry, request is considered permanently failed |
|
||||
| 5xx (Server Error) | Automatic retry with exponential backoff |
|
||||
| Network/Timeout Error | Automatic retry with exponential backoff |
|
||||
|
||||
**Important Notes:**
|
||||
- Webhooks timeout after 30 seconds. If your endpoint takes longer to respond, it will be considered a timeout error and retried.
|
||||
- During the retry period (~24 hours), you may receive the same webhook multiple times if your endpoint experiences intermittent failures.
|
||||
- There is no mechanism to manually retry failed webhooks after the retry period expires.
|
||||
|
||||
## Testing Webhooks
|
||||
|
||||
You can test your webhook configuration before processing transcripts:
|
||||
|
||||
```http
|
||||
POST /v1/rooms/{room_id}/webhook/test
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"status_code": 200,
|
||||
"message": "Webhook test successful",
|
||||
"response_preview": "OK"
|
||||
}
|
||||
```
|
||||
|
||||
Or in case of failure:
|
||||
```json
|
||||
{
|
||||
"success": false,
|
||||
"error": "Webhook request timed out (10 seconds)"
|
||||
}
|
||||
```
|
||||
@@ -1,161 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
import modal
|
||||
from pydantic import BaseModel
|
||||
|
||||
MODELS_DIR = "/models"
|
||||
|
||||
MODEL_NAME = "large-v2"
|
||||
MODEL_COMPUTE_TYPE: str = "float16"
|
||||
MODEL_NUM_WORKERS: int = 1
|
||||
|
||||
MINUTES = 60 # seconds
|
||||
|
||||
volume = modal.Volume.from_name("models", create_if_missing=True)
|
||||
|
||||
app = modal.App("reflector-transcriber")
|
||||
|
||||
|
||||
def download_model():
|
||||
from faster_whisper import download_model
|
||||
|
||||
volume.reload()
|
||||
|
||||
download_model(MODEL_NAME, cache_dir=MODELS_DIR)
|
||||
|
||||
volume.commit()
|
||||
|
||||
|
||||
image = (
|
||||
modal.Image.debian_slim(python_version="3.12")
|
||||
.pip_install(
|
||||
"huggingface_hub==0.27.1",
|
||||
"hf-transfer==0.1.9",
|
||||
"torch==2.5.1",
|
||||
"faster-whisper==1.1.1",
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||
"LD_LIBRARY_PATH": (
|
||||
"/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/:"
|
||||
"/opt/conda/lib/python3.12/site-packages/nvidia/cublas/lib/"
|
||||
),
|
||||
}
|
||||
)
|
||||
.run_function(download_model, volumes={MODELS_DIR: volume})
|
||||
)
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=5 * MINUTES,
|
||||
scaledown_window=5 * MINUTES,
|
||||
allow_concurrent_inputs=6,
|
||||
image=image,
|
||||
volumes={MODELS_DIR: volume},
|
||||
)
|
||||
class Transcriber:
|
||||
@modal.enter()
|
||||
def enter(self):
|
||||
import faster_whisper
|
||||
import torch
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
self.model = faster_whisper.WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=self.device,
|
||||
compute_type=MODEL_COMPUTE_TYPE,
|
||||
num_workers=MODEL_NUM_WORKERS,
|
||||
download_root=MODELS_DIR,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
@modal.method()
|
||||
def transcribe_segment(
|
||||
self,
|
||||
audio_data: str,
|
||||
audio_suffix: str,
|
||||
language: str,
|
||||
):
|
||||
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
|
||||
fp.write(audio_data)
|
||||
|
||||
with self.lock:
|
||||
segments, _ = self.model.transcribe(
|
||||
fp.name,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
|
||||
segments = list(segments)
|
||||
text = "".join(segment.text for segment in segments)
|
||||
words = [
|
||||
{"word": word.word, "start": word.start, "end": word.end}
|
||||
for segment in segments
|
||||
for word in segment.words
|
||||
]
|
||||
|
||||
return {"text": text, "words": words}
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60,
|
||||
timeout=60,
|
||||
allow_concurrent_inputs=40,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
volumes={MODELS_DIR: volume},
|
||||
)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from typing_extensions import Annotated
|
||||
|
||||
transcriber = Transcriber()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
supported_file_types = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class TranscriptResponse(BaseModel):
|
||||
result: dict
|
||||
|
||||
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
|
||||
def transcribe(
|
||||
file: UploadFile,
|
||||
model: str = "whisper-1",
|
||||
language: Annotated[str, Body(...)] = "en",
|
||||
) -> TranscriptResponse:
|
||||
audio_data = file.file.read()
|
||||
audio_suffix = file.filename.split(".")[-1]
|
||||
assert audio_suffix in supported_file_types
|
||||
|
||||
func = transcriber.transcribe_segment.spawn(
|
||||
audio_data=audio_data,
|
||||
audio_suffix=audio_suffix,
|
||||
language=language,
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
|
||||
return app
|
||||
@@ -3,7 +3,7 @@ from logging.config import fileConfig
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from reflector.db import metadata
|
||||
from reflector.db.base import metadata
|
||||
from reflector.settings import settings
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Add webhook fields to rooms
|
||||
|
||||
Revision ID: 0194f65cd6d3
|
||||
Revises: 5a8907fd1d78
|
||||
Create Date: 2025-08-27 09:03:19.610995
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0194f65cd6d3"
|
||||
down_revision: Union[str, None] = "5a8907fd1d78"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("webhook_url", sa.String(), nullable=True))
|
||||
batch_op.add_column(sa.Column("webhook_secret", sa.String(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.drop_column("webhook_secret")
|
||||
batch_op.drop_column("webhook_url")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,36 @@
|
||||
"""remove user_id from meeting table
|
||||
|
||||
Revision ID: 0ce521cda2ee
|
||||
Revises: 6dec9fb5b46c
|
||||
Create Date: 2025-09-10 12:40:55.688899
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0ce521cda2ee"
|
||||
down_revision: Union[str, None] = "6dec9fb5b46c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.drop_column("user_id")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True)
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,32 @@
|
||||
"""clean up orphaned room_id references in meeting table
|
||||
|
||||
Revision ID: 2ae3db106d4e
|
||||
Revises: def1b5867d4c
|
||||
Create Date: 2025-09-11 10:35:15.759967
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "2ae3db106d4e"
|
||||
down_revision: Union[str, None] = "def1b5867d4c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Set room_id to NULL for meetings that reference non-existent rooms
|
||||
op.execute("""
|
||||
UPDATE meeting
|
||||
SET room_id = NULL
|
||||
WHERE room_id IS NOT NULL
|
||||
AND room_id NOT IN (SELECT id FROM room WHERE id IS NOT NULL)
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Cannot restore orphaned references - no operation needed
|
||||
pass
|
||||
@@ -28,7 +28,7 @@ def upgrade() -> None:
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
@@ -58,7 +58,7 @@ def downgrade() -> None:
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
|
||||
@@ -36,9 +36,7 @@ def upgrade() -> None:
|
||||
|
||||
# select only the one with duration = 0
|
||||
results = bind.execute(
|
||||
select([transcript.c.id, transcript.c.duration]).where(
|
||||
transcript.c.duration == 0
|
||||
)
|
||||
select(transcript.c.id, transcript.c.duration).where(transcript.c.duration == 0)
|
||||
)
|
||||
|
||||
data_dir = Path(settings.DATA_DIR)
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
"""add cascade delete to meeting consent foreign key
|
||||
|
||||
Revision ID: 5a8907fd1d78
|
||||
Revises: 0ab2d7ffaa16
|
||||
Create Date: 2025-08-26 17:26:50.945491
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "5a8907fd1d78"
|
||||
down_revision: Union[str, None] = "0ab2d7ffaa16"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"), type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"),
|
||||
"meeting",
|
||||
["meeting_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"), type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
batch_op.f("meeting_consent_meeting_id_fkey"),
|
||||
"meeting",
|
||||
["meeting_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,53 @@
|
||||
"""remove_one_active_meeting_per_room_constraint
|
||||
|
||||
Revision ID: 6025e9b2bef2
|
||||
Revises: 2ae3db106d4e
|
||||
Create Date: 2025-08-18 18:45:44.418392
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6025e9b2bef2"
|
||||
down_revision: Union[str, None] = "2ae3db106d4e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove the unique constraint that prevents multiple active meetings per room
|
||||
# This is needed to support calendar integration with overlapping meetings
|
||||
# Check if index exists before trying to drop it
|
||||
from alembic import context
|
||||
|
||||
if context.get_context().dialect.name == "postgresql":
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM pg_indexes WHERE indexname = 'idx_one_active_meeting_per_room'"
|
||||
)
|
||||
)
|
||||
if result.fetchone():
|
||||
op.drop_index("idx_one_active_meeting_per_room", table_name="meeting")
|
||||
else:
|
||||
# For SQLite, just try to drop it
|
||||
try:
|
||||
op.drop_index("idx_one_active_meeting_per_room", table_name="meeting")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore the unique constraint
|
||||
op.create_index(
|
||||
"idx_one_active_meeting_per_room",
|
||||
"meeting",
|
||||
["room_id"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
sqlite_where=sa.text("is_active = 1"),
|
||||
)
|
||||
@@ -0,0 +1,28 @@
|
||||
"""webhook url and secret null by default
|
||||
|
||||
|
||||
Revision ID: 61882a919591
|
||||
Revises: 0194f65cd6d3
|
||||
Create Date: 2025-08-29 11:46:36.738091
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "61882a919591"
|
||||
down_revision: Union[str, None] = "0194f65cd6d3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,35 @@
|
||||
"""make meeting room_id required and add foreign key
|
||||
|
||||
Revision ID: 6dec9fb5b46c
|
||||
Revises: 61882a919591
|
||||
Create Date: 2025-09-10 10:47:06.006819
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6dec9fb5b46c"
|
||||
down_revision: Union[str, None] = "61882a919591"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.create_foreign_key(
|
||||
None, "room", ["room_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("meeting_room_id_fkey", type_="foreignkey")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -28,7 +28,7 @@ def upgrade() -> None:
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
@@ -58,7 +58,7 @@ def downgrade() -> None:
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
"""add_grace_period_fields_to_meeting
|
||||
|
||||
Revision ID: d4a1c446458c
|
||||
Revises: 6025e9b2bef2
|
||||
Create Date: 2025-08-18 18:50:37.768052
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d4a1c446458c"
|
||||
down_revision: Union[str, None] = "6025e9b2bef2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add fields to track when participants left for grace period logic
|
||||
op.add_column(
|
||||
"meeting", sa.Column("last_participant_left_at", sa.DateTime(timezone=True))
|
||||
)
|
||||
op.add_column(
|
||||
"meeting",
|
||||
sa.Column("grace_period_minutes", sa.Integer, server_default=sa.text("15")),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("meeting", "grace_period_minutes")
|
||||
op.drop_column("meeting", "last_participant_left_at")
|
||||
129
server/migrations/versions/d8e204bbf615_add_calendar.py
Normal file
129
server/migrations/versions/d8e204bbf615_add_calendar.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""add calendar
|
||||
|
||||
Revision ID: d8e204bbf615
|
||||
Revises: d4a1c446458c
|
||||
Create Date: 2025-09-10 19:56:22.295756
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d8e204bbf615"
|
||||
down_revision: Union[str, None] = "d4a1c446458c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"calendar_event",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("room_id", sa.String(), nullable=False),
|
||||
sa.Column("ics_uid", sa.Text(), nullable=False),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("attendees", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("location", sa.Text(), nullable=True),
|
||||
sa.Column("ics_raw_data", sa.Text(), nullable=True),
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"is_deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
||||
),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["room_id"],
|
||||
["room.id"],
|
||||
name="fk_calendar_event_room_id",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
|
||||
)
|
||||
with op.batch_alter_table("calendar_event", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
"idx_calendar_event_deleted",
|
||||
["is_deleted"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("NOT is_deleted"),
|
||||
)
|
||||
batch_op.create_index(
|
||||
"idx_calendar_event_room_start", ["room_id", "start_time"], unique=False
|
||||
)
|
||||
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("calendar_event_id", sa.String(), nullable=True))
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"calendar_metadata",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
batch_op.create_index(
|
||||
"idx_meeting_calendar_event", ["calendar_event_id"], unique=False
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_meeting_calendar_event_id",
|
||||
"calendar_event",
|
||||
["calendar_event_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("ics_url", sa.Text(), nullable=True))
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"ics_fetch_interval", sa.Integer(), server_default="300", nullable=True
|
||||
)
|
||||
)
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"ics_enabled",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
)
|
||||
)
|
||||
batch_op.add_column(
|
||||
sa.Column("ics_last_sync", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
batch_op.add_column(sa.Column("ics_last_etag", sa.Text(), nullable=True))
|
||||
batch_op.create_index("idx_room_ics_enabled", ["ics_enabled"], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||
batch_op.drop_index("idx_room_ics_enabled")
|
||||
batch_op.drop_column("ics_last_etag")
|
||||
batch_op.drop_column("ics_last_sync")
|
||||
batch_op.drop_column("ics_enabled")
|
||||
batch_op.drop_column("ics_fetch_interval")
|
||||
batch_op.drop_column("ics_url")
|
||||
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("fk_meeting_calendar_event_id", type_="foreignkey")
|
||||
batch_op.drop_index("idx_meeting_calendar_event")
|
||||
batch_op.drop_column("calendar_metadata")
|
||||
batch_op.drop_column("calendar_event_id")
|
||||
|
||||
with op.batch_alter_table("calendar_event", schema=None) as batch_op:
|
||||
batch_op.drop_index("idx_calendar_event_room_start")
|
||||
batch_op.drop_index(
|
||||
"idx_calendar_event_deleted", postgresql_where=sa.text("NOT is_deleted")
|
||||
)
|
||||
|
||||
op.drop_table("calendar_event")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,43 @@
|
||||
"""remove_grace_period_fields
|
||||
|
||||
Revision ID: dc035ff72fd5
|
||||
Revises: d8e204bbf615
|
||||
Create Date: 2025-09-11 10:36:45.197588
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dc035ff72fd5"
|
||||
down_revision: Union[str, None] = "d8e204bbf615"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove grace period columns from meeting table
|
||||
op.drop_column("meeting", "last_participant_left_at")
|
||||
op.drop_column("meeting", "grace_period_minutes")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back grace period columns to meeting table
|
||||
op.add_column(
|
||||
"meeting",
|
||||
sa.Column(
|
||||
"last_participant_left_at", sa.DateTime(timezone=True), nullable=True
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"meeting",
|
||||
sa.Column(
|
||||
"grace_period_minutes",
|
||||
sa.Integer(),
|
||||
server_default=sa.text("15"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,34 @@
|
||||
"""make meeting room_id nullable but keep foreign key
|
||||
|
||||
Revision ID: def1b5867d4c
|
||||
Revises: 0ce521cda2ee
|
||||
Create Date: 2025-09-11 09:42:18.697264
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "def1b5867d4c"
|
||||
down_revision: Union[str, None] = "0ce521cda2ee"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.alter_column("room_id", existing_type=sa.VARCHAR(), nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.alter_column("room_id", existing_type=sa.VARCHAR(), nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -12,7 +12,6 @@ dependencies = [
|
||||
"requests>=2.31.0",
|
||||
"aiortc>=1.5.0",
|
||||
"sortedcontainers>=2.4.0",
|
||||
"loguru>=0.7.0",
|
||||
"pydantic-settings>=2.0.2",
|
||||
"structlog>=23.1.0",
|
||||
"uvicorn[standard]>=0.23.1",
|
||||
@@ -20,14 +19,13 @@ dependencies = [
|
||||
"sentry-sdk[fastapi]>=1.29.2",
|
||||
"httpx>=0.24.1",
|
||||
"fastapi-pagination>=0.12.6",
|
||||
"databases[aiosqlite, asyncpg]>=0.7.0",
|
||||
"sqlalchemy<1.5",
|
||||
"sqlalchemy>=2.0.0",
|
||||
"asyncpg>=0.29.0",
|
||||
"alembic>=1.11.3",
|
||||
"nltk>=3.8.1",
|
||||
"prometheus-fastapi-instrumentator>=6.1.0",
|
||||
"sentencepiece>=0.1.99",
|
||||
"protobuf>=4.24.3",
|
||||
"profanityfilter>=2.0.6",
|
||||
"celery>=5.3.4",
|
||||
"redis>=5.0.1",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
@@ -40,6 +38,7 @@ dependencies = [
|
||||
"llama-index-llms-openai-like>=0.4.0",
|
||||
"pytest-env>=1.1.5",
|
||||
"webvtt-py>=0.5.0",
|
||||
"icalendar>=6.0.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
@@ -47,6 +46,7 @@ dev = [
|
||||
"black>=24.1.1",
|
||||
"stamina>=23.1.0",
|
||||
"pyinstrument>=4.6.1",
|
||||
"pytest-async-sqlalchemy>=0.2.0",
|
||||
]
|
||||
tests = [
|
||||
"pytest-cov>=4.1.0",
|
||||
@@ -112,14 +112,17 @@ source = ["reflector"]
|
||||
|
||||
[tool.pytest_env]
|
||||
ENVIRONMENT = "pytest"
|
||||
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
|
||||
DATABASE_URL = "postgresql+asyncpg://test_user:test_password@localhost:15432/reflector_test"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_debug = true
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
asyncio_default_test_loop_scope = "session"
|
||||
markers = [
|
||||
"gpu_modal: mark test to run only with GPU Modal endpoints (deselect with '-m \"not gpu_modal\"')",
|
||||
"model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
@@ -131,7 +134,7 @@ select = [
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"reflector/processors/summary/summary_builder.py" = ["E501"]
|
||||
"gpu/**.py" = ["PLC0415"]
|
||||
"gpu/modal_deployments/**.py" = ["PLC0415"]
|
||||
"reflector/tools/**.py" = ["PLC0415"]
|
||||
"migrations/versions/**.py" = ["PLC0415"]
|
||||
"tests/**.py" = ["PLC0415"]
|
||||
|
||||
20
server/reflector/asynctask.py
Normal file
20
server/reflector/asynctask.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
|
||||
def asynctask(f):
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
async def run_async():
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
coro = run_async()
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
return loop.run_until_complete(coro)
|
||||
return asyncio.run(coro)
|
||||
|
||||
return wrapper
|
||||
@@ -67,7 +67,8 @@ def current_user(
|
||||
try:
|
||||
payload = jwtauth.verify_token(token)
|
||||
sub = payload["sub"]
|
||||
return UserInfo(sub=sub)
|
||||
email = payload["email"]
|
||||
return UserInfo(sub=sub, email=email)
|
||||
except JWTError as e:
|
||||
logger.error(f"JWT error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication")
|
||||
|
||||
@@ -1,47 +1,69 @@
|
||||
import contextvars
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import databases
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from reflector.db.base import Base as Base
|
||||
from reflector.db.base import metadata as metadata
|
||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||
from reflector.settings import settings
|
||||
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
|
||||
contextvars.ContextVar("database", default=None)
|
||||
)
|
||||
_engine: AsyncEngine | None = None
|
||||
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
def get_database() -> databases.Database:
|
||||
"""Get database instance for current asyncio context"""
|
||||
db = _database_context.get()
|
||||
if db is None:
|
||||
db = databases.Database(settings.DATABASE_URL)
|
||||
_database_context.set(db)
|
||||
return db
|
||||
def get_engine() -> AsyncEngine:
|
||||
global _engine
|
||||
if _engine is None:
|
||||
_engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
return _engine
|
||||
|
||||
|
||||
# import models
|
||||
def get_session_factory() -> async_sessionmaker[AsyncSession]:
|
||||
global _session_factory
|
||||
if _session_factory is None:
|
||||
_session_factory = async_sessionmaker(
|
||||
get_engine(),
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
return _session_factory
|
||||
|
||||
|
||||
async def _get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
# necessary implementation to ease mocking on pytest
|
||||
async with get_session_factory()() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async for session in _get_session():
|
||||
yield session
|
||||
|
||||
|
||||
import reflector.db.calendar_events # noqa
|
||||
import reflector.db.meetings # noqa
|
||||
import reflector.db.recordings # noqa
|
||||
import reflector.db.rooms # noqa
|
||||
import reflector.db.transcripts # noqa
|
||||
|
||||
kwargs = {}
|
||||
if "postgres" not in settings.DATABASE_URL:
|
||||
raise Exception("Only postgres database is supported in reflector")
|
||||
engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
|
||||
|
||||
|
||||
@subscribers_startup.append
|
||||
async def database_connect(_):
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
get_engine()
|
||||
|
||||
|
||||
@subscribers_shutdown.append
|
||||
async def database_disconnect(_):
|
||||
database = get_database()
|
||||
await database.disconnect()
|
||||
global _engine
|
||||
if _engine:
|
||||
await _engine.dispose()
|
||||
_engine = None
|
||||
|
||||
237
server/reflector/db/base.py
Normal file
237
server/reflector/db/base.py
Normal file
@@ -0,0 +1,237 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(AsyncAttrs, DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class TranscriptModel(Base):
|
||||
__tablename__ = "transcript"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
name: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
status: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
locked: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
|
||||
duration: Mapped[Optional[float]] = mapped_column(sa.Float)
|
||||
created_at: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
||||
title: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
short_summary: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
long_summary: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
topics: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
||||
events: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
||||
participants: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
||||
source_language: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
target_language: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
reviewed: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
audio_location: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="local"
|
||||
)
|
||||
user_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
share_mode: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="private"
|
||||
)
|
||||
meeting_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
recording_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
zulip_message_id: Mapped[Optional[int]] = mapped_column(sa.Integer)
|
||||
source_kind: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False
|
||||
) # Enum will be handled separately
|
||||
audio_deleted: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
|
||||
room_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
webvtt: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
|
||||
__table_args__ = (
|
||||
sa.Index("idx_transcript_recording_id", "recording_id"),
|
||||
sa.Index("idx_transcript_user_id", "user_id"),
|
||||
sa.Index("idx_transcript_created_at", "created_at"),
|
||||
sa.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
||||
sa.Index("idx_transcript_room_id", "room_id"),
|
||||
sa.Index("idx_transcript_source_kind", "source_kind"),
|
||||
sa.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
TranscriptModel.search_vector_en = sa.Column(
|
||||
"search_vector_en",
|
||||
TSVECTOR,
|
||||
sa.Computed(
|
||||
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
||||
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
|
||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
|
||||
persisted=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class RoomModel(Base):
|
||||
__tablename__ = "room"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(sa.String, nullable=False, unique=True)
|
||||
user_id: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
zulip_auto_post: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
zulip_stream: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
zulip_topic: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
is_locked: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
room_mode: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="normal"
|
||||
)
|
||||
recording_type: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="cloud"
|
||||
)
|
||||
recording_trigger: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="automatic-2nd-participant"
|
||||
)
|
||||
is_shared: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
webhook_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
webhook_secret: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
ics_url: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
ics_fetch_interval: Mapped[Optional[int]] = mapped_column(
|
||||
sa.Integer, server_default=sa.text("300")
|
||||
)
|
||||
ics_enabled: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
ics_last_sync: Mapped[Optional[datetime]] = mapped_column(
|
||||
sa.DateTime(timezone=True)
|
||||
)
|
||||
ics_last_etag: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
|
||||
__table_args__ = (
|
||||
sa.Index("idx_room_is_shared", "is_shared"),
|
||||
sa.Index("idx_room_ics_enabled", "ics_enabled"),
|
||||
)
|
||||
|
||||
|
||||
class MeetingModel(Base):
|
||||
__tablename__ = "meeting"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
room_name: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
room_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
host_room_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
start_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
||||
end_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
||||
room_id: Mapped[Optional[str]] = mapped_column(
|
||||
sa.String, sa.ForeignKey("room.id", ondelete="CASCADE")
|
||||
)
|
||||
is_locked: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
room_mode: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="normal"
|
||||
)
|
||||
recording_type: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="cloud"
|
||||
)
|
||||
recording_trigger: Mapped[str] = mapped_column(
|
||||
sa.String, nullable=False, server_default="automatic-2nd-participant"
|
||||
)
|
||||
num_clients: Mapped[int] = mapped_column(
|
||||
sa.Integer, nullable=False, server_default=sa.text("0")
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("true")
|
||||
)
|
||||
calendar_event_id: Mapped[Optional[str]] = mapped_column(
|
||||
sa.String,
|
||||
sa.ForeignKey(
|
||||
"calendar_event.id",
|
||||
ondelete="SET NULL",
|
||||
name="fk_meeting_calendar_event_id",
|
||||
),
|
||||
)
|
||||
calendar_metadata: Mapped[Optional[dict]] = mapped_column(JSONB)
|
||||
|
||||
__table_args__ = (
|
||||
sa.Index("idx_meeting_room_id", "room_id"),
|
||||
sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
|
||||
)
|
||||
|
||||
|
||||
class MeetingConsentModel(Base):
|
||||
__tablename__ = "meeting_consent"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
meeting_id: Mapped[str] = mapped_column(
|
||||
sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
||||
consent_given: Mapped[bool] = mapped_column(sa.Boolean, nullable=False)
|
||||
consent_timestamp: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
|
||||
|
||||
class RecordingModel(Base):
|
||||
__tablename__ = "recording"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
meeting_id: Mapped[str] = mapped_column(
|
||||
sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
url: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
object_key: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
duration: Mapped[Optional[float]] = mapped_column(sa.Float)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (sa.Index("idx_recording_meeting_id", "meeting_id"),)
|
||||
|
||||
|
||||
class CalendarEventModel(Base):
|
||||
__tablename__ = "calendar_event"
|
||||
|
||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
||||
room_id: Mapped[str] = mapped_column(
|
||||
sa.String, sa.ForeignKey("room.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
ics_uid: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
title: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
description: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
start_time: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
end_time: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
attendees: Mapped[Optional[dict]] = mapped_column(JSONB)
|
||||
location: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
ics_raw_data: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
last_synced: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
|
||||
)
|
||||
|
||||
|
||||
metadata = Base.metadata
|
||||
187
server/reflector/db/calendar_events.py
Normal file
187
server/reflector/db/calendar_events.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db.base import CalendarEventModel
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
room_id: str
|
||||
ics_uid: str
|
||||
title: str | None = None
|
||||
description: str | None = None
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
attendees: list[dict[str, Any]] | None = None
|
||||
location: str | None = None
|
||||
ics_raw_data: str | None = None
|
||||
last_synced: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class CalendarEventController:
|
||||
async def get_upcoming_events(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
room_id: str,
|
||||
current_time: datetime,
|
||||
buffer_minutes: int = 15,
|
||||
) -> list[CalendarEvent]:
|
||||
buffer_time = current_time + timedelta(minutes=buffer_minutes)
|
||||
|
||||
query = (
|
||||
select(CalendarEventModel)
|
||||
.where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.start_time <= buffer_time,
|
||||
CalendarEventModel.end_time > current_time,
|
||||
)
|
||||
)
|
||||
.order_by(CalendarEventModel.start_time)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, event_id: str
|
||||
) -> CalendarEvent | None:
|
||||
query = select(CalendarEventModel).where(CalendarEventModel.id == event_id)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return CalendarEvent.model_validate(row)
|
||||
|
||||
async def get_by_ics_uid(
|
||||
self, session: AsyncSession, room_id: str, ics_uid: str
|
||||
) -> CalendarEvent | None:
|
||||
query = select(CalendarEventModel).where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.ics_uid == ics_uid,
|
||||
)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return CalendarEvent.model_validate(row)
|
||||
|
||||
async def upsert(
|
||||
self, session: AsyncSession, event: CalendarEvent
|
||||
) -> CalendarEvent:
|
||||
existing = await self.get_by_ics_uid(session, event.room_id, event.ics_uid)
|
||||
|
||||
if existing:
|
||||
event.updated_at = datetime.now(timezone.utc)
|
||||
query = (
|
||||
update(CalendarEventModel)
|
||||
.where(CalendarEventModel.id == existing.id)
|
||||
.values(**event.model_dump(exclude={"id"}))
|
||||
)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
return event
|
||||
else:
|
||||
new_event = CalendarEventModel(**event.model_dump())
|
||||
session.add(new_event)
|
||||
await session.commit()
|
||||
return event
|
||||
|
||||
async def delete_old_events(
|
||||
self, session: AsyncSession, room_id: str, cutoff_date: datetime
|
||||
) -> int:
|
||||
query = delete(CalendarEventModel).where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.end_time < cutoff_date,
|
||||
)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
await session.commit()
|
||||
return result.rowcount
|
||||
|
||||
async def delete_events_not_in_list(
|
||||
self, session: AsyncSession, room_id: str, keep_ics_uids: list[str]
|
||||
) -> int:
|
||||
if not keep_ics_uids:
|
||||
query = delete(CalendarEventModel).where(
|
||||
CalendarEventModel.room_id == room_id
|
||||
)
|
||||
else:
|
||||
query = delete(CalendarEventModel).where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.ics_uid.notin_(keep_ics_uids),
|
||||
)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
await session.commit()
|
||||
return result.rowcount
|
||||
|
||||
async def get_by_room(
|
||||
self, session: AsyncSession, room_id: str, include_deleted: bool = True
|
||||
) -> list[CalendarEvent]:
|
||||
query = select(CalendarEventModel).where(CalendarEventModel.room_id == room_id)
|
||||
if not include_deleted:
|
||||
query = query.where(CalendarEventModel.is_deleted == False)
|
||||
result = await session.execute(query)
|
||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_upcoming(
|
||||
self, session: AsyncSession, room_id: str, minutes_ahead: int = 120
|
||||
) -> list[CalendarEvent]:
|
||||
now = datetime.now(timezone.utc)
|
||||
buffer_time = now + timedelta(minutes=minutes_ahead)
|
||||
|
||||
query = (
|
||||
select(CalendarEventModel)
|
||||
.where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.start_time <= buffer_time,
|
||||
CalendarEventModel.end_time > now,
|
||||
CalendarEventModel.is_deleted == False,
|
||||
)
|
||||
)
|
||||
.order_by(CalendarEventModel.start_time)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def soft_delete_missing(
|
||||
self, session: AsyncSession, room_id: str, current_ics_uids: list[str]
|
||||
) -> int:
|
||||
query = (
|
||||
update(CalendarEventModel)
|
||||
.where(
|
||||
sa.and_(
|
||||
CalendarEventModel.room_id == room_id,
|
||||
CalendarEventModel.ics_uid.notin_(current_ics_uids)
|
||||
if current_ics_uids
|
||||
else True,
|
||||
CalendarEventModel.end_time > datetime.now(timezone.utc),
|
||||
)
|
||||
)
|
||||
.values(is_deleted=True)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
await session.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
calendar_events_controller = CalendarEventController()
|
||||
@@ -1,67 +1,19 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.base import MeetingConsentModel, MeetingModel
|
||||
from reflector.db.rooms import Room
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
meetings = sa.Table(
|
||||
"meeting",
|
||||
metadata,
|
||||
sa.Column("id", sa.String, primary_key=True),
|
||||
sa.Column("room_name", sa.String),
|
||||
sa.Column("room_url", sa.String),
|
||||
sa.Column("host_room_url", sa.String),
|
||||
sa.Column("start_date", sa.DateTime(timezone=True)),
|
||||
sa.Column("end_date", sa.DateTime(timezone=True)),
|
||||
sa.Column("user_id", sa.String),
|
||||
sa.Column("room_id", sa.String),
|
||||
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
|
||||
sa.Column("room_mode", sa.String, nullable=False, server_default="normal"),
|
||||
sa.Column("recording_type", sa.String, nullable=False, server_default="cloud"),
|
||||
sa.Column(
|
||||
"recording_trigger",
|
||||
sa.String,
|
||||
nullable=False,
|
||||
server_default="automatic-2nd-participant",
|
||||
),
|
||||
sa.Column(
|
||||
"num_clients",
|
||||
sa.Integer,
|
||||
nullable=False,
|
||||
server_default=sa.text("0"),
|
||||
),
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.true(),
|
||||
),
|
||||
sa.Index("idx_meeting_room_id", "room_id"),
|
||||
sa.Index(
|
||||
"idx_one_active_meeting_per_room",
|
||||
"room_id",
|
||||
unique=True,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
),
|
||||
)
|
||||
|
||||
meeting_consent = sa.Table(
|
||||
"meeting_consent",
|
||||
metadata,
|
||||
sa.Column("id", sa.String, primary_key=True),
|
||||
sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False),
|
||||
sa.Column("user_id", sa.String),
|
||||
sa.Column("consent_given", sa.Boolean, nullable=False),
|
||||
sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
|
||||
)
|
||||
|
||||
|
||||
class MeetingConsent(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
meeting_id: str
|
||||
user_id: str | None = None
|
||||
@@ -70,14 +22,15 @@ class MeetingConsent(BaseModel):
|
||||
|
||||
|
||||
class Meeting(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
room_name: str
|
||||
room_url: str
|
||||
host_room_url: str
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
room_id: str | None
|
||||
is_locked: bool = False
|
||||
room_mode: Literal["normal", "group"] = "normal"
|
||||
recording_type: Literal["none", "local", "cloud"] = "cloud"
|
||||
@@ -85,23 +38,25 @@ class Meeting(BaseModel):
|
||||
"none", "prompt", "automatic", "automatic-2nd-participant"
|
||||
] = "automatic-2nd-participant"
|
||||
num_clients: int = 0
|
||||
is_active: bool = True
|
||||
calendar_event_id: str | None = None
|
||||
calendar_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class MeetingController:
|
||||
async def create(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
id: str,
|
||||
room_name: str,
|
||||
room_url: str,
|
||||
host_room_url: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
user_id: str,
|
||||
room: Room,
|
||||
calendar_event_id: str | None = None,
|
||||
calendar_metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Create a new meeting
|
||||
"""
|
||||
meeting = Meeting(
|
||||
id=id,
|
||||
room_name=room_name,
|
||||
@@ -109,148 +64,206 @@ class MeetingController:
|
||||
host_room_url=host_room_url,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
user_id=user_id,
|
||||
room_id=room.id,
|
||||
is_locked=room.is_locked,
|
||||
room_mode=room.room_mode,
|
||||
recording_type=room.recording_type,
|
||||
recording_trigger=room.recording_trigger,
|
||||
calendar_event_id=calendar_event_id,
|
||||
calendar_metadata=calendar_metadata,
|
||||
)
|
||||
query = meetings.insert().values(**meeting.model_dump())
|
||||
await get_database().execute(query)
|
||||
new_meeting = MeetingModel(**meeting.model_dump())
|
||||
session.add(new_meeting)
|
||||
await session.commit()
|
||||
return meeting
|
||||
|
||||
async def get_all_active(self) -> list[Meeting]:
|
||||
"""
|
||||
Get active meetings.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.is_active)
|
||||
return await get_database().fetch_all(query)
|
||||
async def get_all_active(self, session: AsyncSession) -> list[Meeting]:
|
||||
query = select(MeetingModel).where(MeetingModel.is_active)
|
||||
result = await session.execute(query)
|
||||
return [Meeting.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_room_name(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
room_name: str,
|
||||
) -> Meeting:
|
||||
) -> Meeting | None:
|
||||
"""
|
||||
Get a meeting by room name.
|
||||
For backward compatibility, returns the most recent meeting.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.room_name == room_name)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = (
|
||||
select(MeetingModel)
|
||||
.where(MeetingModel.room_name == room_name)
|
||||
.order_by(MeetingModel.end_date.desc())
|
||||
)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
return Meeting(**result)
|
||||
|
||||
async def get_active(self, room: Room, current_time: datetime) -> Meeting:
|
||||
async def get_active(
|
||||
self, session: AsyncSession, room: Room, current_time: datetime
|
||||
) -> Meeting | None:
|
||||
"""
|
||||
Get latest active meeting for a room.
|
||||
For backward compatibility, returns the most recent active meeting.
|
||||
"""
|
||||
end_date = getattr(meetings.c, "end_date")
|
||||
query = (
|
||||
meetings.select()
|
||||
select(MeetingModel)
|
||||
.where(
|
||||
sa.and_(
|
||||
meetings.c.room_id == room.id,
|
||||
meetings.c.end_date > current_time,
|
||||
meetings.c.is_active,
|
||||
MeetingModel.room_id == room.id,
|
||||
MeetingModel.end_date > current_time,
|
||||
MeetingModel.is_active,
|
||||
)
|
||||
)
|
||||
.order_by(end_date.desc())
|
||||
.order_by(MeetingModel.end_date.desc())
|
||||
)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
return Meeting(**result)
|
||||
async def get_all_active_for_room(
|
||||
self, session: AsyncSession, room: Room, current_time: datetime
|
||||
) -> list[Meeting]:
|
||||
query = (
|
||||
select(MeetingModel)
|
||||
.where(
|
||||
sa.and_(
|
||||
MeetingModel.room_id == room.id,
|
||||
MeetingModel.end_date > current_time,
|
||||
MeetingModel.is_active,
|
||||
)
|
||||
)
|
||||
.order_by(MeetingModel.end_date.desc())
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return [Meeting.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_id(self, meeting_id: str, **kwargs) -> Meeting | None:
|
||||
async def get_active_by_calendar_event(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
room: Room,
|
||||
calendar_event_id: str,
|
||||
current_time: datetime,
|
||||
) -> Meeting | None:
|
||||
"""
|
||||
Get a meeting by id
|
||||
Get active meeting for a specific calendar event.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = select(MeetingModel).where(
|
||||
sa.and_(
|
||||
MeetingModel.room_id == room.id,
|
||||
MeetingModel.calendar_event_id == calendar_event_id,
|
||||
MeetingModel.end_date > current_time,
|
||||
MeetingModel.is_active,
|
||||
)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting(**result)
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Meeting:
|
||||
"""
|
||||
Get a meeting by ID for HTTP request.
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, meeting_id: str, **kwargs
|
||||
) -> Meeting | None:
|
||||
query = select(MeetingModel).where(MeetingModel.id == meeting_id)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
If not found, it will raise a 404 error.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
async def get_by_calendar_event(
|
||||
self, session: AsyncSession, calendar_event_id: str
|
||||
) -> Meeting | None:
|
||||
query = select(MeetingModel).where(
|
||||
MeetingModel.calendar_event_id == calendar_event_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
meeting = Meeting(**result)
|
||||
if result["user_id"] != user_id:
|
||||
meeting.host_room_url = ""
|
||||
|
||||
return meeting
|
||||
|
||||
async def update_meeting(self, meeting_id: str, **kwargs):
|
||||
query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
|
||||
await get_database().execute(query)
|
||||
async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs):
|
||||
query = (
|
||||
update(MeetingModel).where(MeetingModel.id == meeting_id).values(**kwargs)
|
||||
)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
|
||||
|
||||
class MeetingConsentController:
|
||||
async def get_by_meeting_id(self, meeting_id: str) -> list[MeetingConsent]:
|
||||
query = meeting_consent.select().where(
|
||||
meeting_consent.c.meeting_id == meeting_id
|
||||
async def get_by_meeting_id(
|
||||
self, session: AsyncSession, meeting_id: str
|
||||
) -> list[MeetingConsent]:
|
||||
query = select(MeetingConsentModel).where(
|
||||
MeetingConsentModel.meeting_id == meeting_id
|
||||
)
|
||||
results = await get_database().fetch_all(query)
|
||||
return [MeetingConsent(**result) for result in results]
|
||||
result = await session.execute(query)
|
||||
return [MeetingConsent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_meeting_and_user(
|
||||
self, meeting_id: str, user_id: str
|
||||
self, session: AsyncSession, meeting_id: str, user_id: str
|
||||
) -> MeetingConsent | None:
|
||||
"""Get existing consent for a specific user and meeting"""
|
||||
query = meeting_consent.select().where(
|
||||
meeting_consent.c.meeting_id == meeting_id,
|
||||
meeting_consent.c.user_id == user_id,
|
||||
query = select(MeetingConsentModel).where(
|
||||
sa.and_(
|
||||
MeetingConsentModel.meeting_id == meeting_id,
|
||||
MeetingConsentModel.user_id == user_id,
|
||||
)
|
||||
result = await get_database().fetch_one(query)
|
||||
if result is None:
|
||||
)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
return None
|
||||
return MeetingConsent(**result) if result else None
|
||||
return MeetingConsent.model_validate(row)
|
||||
|
||||
async def upsert(self, consent: MeetingConsent) -> MeetingConsent:
|
||||
"""Create new consent or update existing one for authenticated users"""
|
||||
async def upsert(
|
||||
self, session: AsyncSession, consent: MeetingConsent
|
||||
) -> MeetingConsent:
|
||||
if consent.user_id:
|
||||
# For authenticated users, check if consent already exists
|
||||
# not transactional but we're ok with that; the consents ain't deleted anyways
|
||||
existing = await self.get_by_meeting_and_user(
|
||||
consent.meeting_id, consent.user_id
|
||||
session, consent.meeting_id, consent.user_id
|
||||
)
|
||||
if existing:
|
||||
query = (
|
||||
meeting_consent.update()
|
||||
.where(meeting_consent.c.id == existing.id)
|
||||
update(MeetingConsentModel)
|
||||
.where(MeetingConsentModel.id == existing.id)
|
||||
.values(
|
||||
consent_given=consent.consent_given,
|
||||
consent_timestamp=consent.consent_timestamp,
|
||||
)
|
||||
)
|
||||
await get_database().execute(query)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
|
||||
existing.consent_given = consent.consent_given
|
||||
existing.consent_timestamp = consent.consent_timestamp
|
||||
return existing
|
||||
|
||||
query = meeting_consent.insert().values(**consent.model_dump())
|
||||
await get_database().execute(query)
|
||||
new_consent = MeetingConsentModel(**consent.model_dump())
|
||||
session.add(new_consent)
|
||||
await session.commit()
|
||||
return consent
|
||||
|
||||
async def has_any_denial(self, meeting_id: str) -> bool:
|
||||
async def has_any_denial(self, session: AsyncSession, meeting_id: str) -> bool:
|
||||
"""Check if any participant denied consent for this meeting"""
|
||||
query = meeting_consent.select().where(
|
||||
meeting_consent.c.meeting_id == meeting_id,
|
||||
meeting_consent.c.consent_given.is_(False),
|
||||
query = select(MeetingConsentModel).where(
|
||||
sa.and_(
|
||||
MeetingConsentModel.meeting_id == meeting_id,
|
||||
MeetingConsentModel.consent_given.is_(False),
|
||||
)
|
||||
result = await get_database().fetch_one(query)
|
||||
return result is not None
|
||||
)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
|
||||
meetings_controller = MeetingController()
|
||||
|
||||
@@ -1,61 +1,79 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.base import RecordingModel
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
recordings = sa.Table(
|
||||
"recording",
|
||||
metadata,
|
||||
sa.Column("id", sa.String, primary_key=True),
|
||||
sa.Column("bucket_name", sa.String, nullable=False),
|
||||
sa.Column("object_key", sa.String, nullable=False),
|
||||
sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String,
|
||||
nullable=False,
|
||||
server_default="pending",
|
||||
),
|
||||
sa.Column("meeting_id", sa.String),
|
||||
sa.Index("idx_recording_meeting_id", "meeting_id"),
|
||||
)
|
||||
|
||||
|
||||
class Recording(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
bucket_name: str
|
||||
meeting_id: str
|
||||
url: str
|
||||
object_key: str
|
||||
recorded_at: datetime
|
||||
status: Literal["pending", "processing", "completed", "failed"] = "pending"
|
||||
meeting_id: str | None = None
|
||||
duration: float | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class RecordingController:
|
||||
async def create(self, recording: Recording):
|
||||
query = recordings.insert().values(**recording.model_dump())
|
||||
await get_database().execute(query)
|
||||
async def create(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
meeting_id: str,
|
||||
url: str,
|
||||
object_key: str,
|
||||
duration: float | None = None,
|
||||
created_at: datetime | None = None,
|
||||
):
|
||||
if created_at is None:
|
||||
created_at = datetime.now(timezone.utc)
|
||||
|
||||
recording = Recording(
|
||||
meeting_id=meeting_id,
|
||||
url=url,
|
||||
object_key=object_key,
|
||||
duration=duration,
|
||||
created_at=created_at,
|
||||
)
|
||||
new_recording = RecordingModel(**recording.model_dump())
|
||||
session.add(new_recording)
|
||||
await session.commit()
|
||||
return recording
|
||||
|
||||
async def get_by_id(self, id: str) -> Recording:
|
||||
query = recordings.select().where(recordings.c.id == id)
|
||||
result = await get_database().fetch_one(query)
|
||||
return Recording(**result) if result else None
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, recording_id: str
|
||||
) -> Recording | None:
|
||||
"""
|
||||
Get a recording by id
|
||||
"""
|
||||
query = select(RecordingModel).where(RecordingModel.id == recording_id)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Recording.model_validate(row)
|
||||
|
||||
async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
|
||||
query = recordings.select().where(
|
||||
recordings.c.bucket_name == bucket_name,
|
||||
recordings.c.object_key == object_key,
|
||||
)
|
||||
result = await get_database().fetch_one(query)
|
||||
return Recording(**result) if result else None
|
||||
async def get_by_meeting_id(
|
||||
self, session: AsyncSession, meeting_id: str
|
||||
) -> list[Recording]:
|
||||
"""
|
||||
Get all recordings for a meeting
|
||||
"""
|
||||
query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id)
|
||||
result = await session.execute(query)
|
||||
return [Recording.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def remove_by_id(self, id: str) -> None:
|
||||
query = recordings.delete().where(recordings.c.id == id)
|
||||
await get_database().execute(query)
|
||||
async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None:
|
||||
"""
|
||||
Remove a recording by id
|
||||
"""
|
||||
query = delete(RecordingModel).where(RecordingModel.id == recording_id)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
|
||||
|
||||
recordings_controller = RecordingController()
|
||||
|
||||
@@ -1,50 +1,21 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from sqlite3 import IntegrityError
|
||||
from typing import Literal
|
||||
|
||||
import sqlalchemy
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.sql import false, or_
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import or_
|
||||
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.base import RoomModel
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
rooms = sqlalchemy.Table(
|
||||
"room",
|
||||
metadata,
|
||||
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||
sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
|
||||
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
|
||||
sqlalchemy.Column(
|
||||
"zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||
),
|
||||
sqlalchemy.Column("zulip_stream", sqlalchemy.String),
|
||||
sqlalchemy.Column("zulip_topic", sqlalchemy.String),
|
||||
sqlalchemy.Column(
|
||||
"is_locked", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"room_mode", sqlalchemy.String, nullable=False, server_default="normal"
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"recording_type", sqlalchemy.String, nullable=False, server_default="cloud"
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"recording_trigger",
|
||||
sqlalchemy.String,
|
||||
nullable=False,
|
||||
server_default="automatic-2nd-participant",
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"is_shared", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||
),
|
||||
sqlalchemy.Index("idx_room_is_shared", "is_shared"),
|
||||
)
|
||||
|
||||
|
||||
class Room(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
name: str
|
||||
user_id: str
|
||||
@@ -59,11 +30,19 @@ class Room(BaseModel):
|
||||
"none", "prompt", "automatic", "automatic-2nd-participant"
|
||||
] = "automatic-2nd-participant"
|
||||
is_shared: bool = False
|
||||
webhook_url: str | None = None
|
||||
webhook_secret: str | None = None
|
||||
ics_url: str | None = None
|
||||
ics_fetch_interval: int = 300
|
||||
ics_enabled: bool = False
|
||||
ics_last_sync: datetime | None = None
|
||||
ics_last_etag: str | None = None
|
||||
|
||||
|
||||
class RoomController:
|
||||
async def get_all(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
order_by: str | None = None,
|
||||
return_query: bool = False,
|
||||
@@ -77,14 +56,14 @@ class RoomController:
|
||||
Parameters:
|
||||
- `order_by`: field to order by, e.g. "-created_at"
|
||||
"""
|
||||
query = rooms.select()
|
||||
query = select(RoomModel)
|
||||
if user_id is not None:
|
||||
query = query.where(or_(rooms.c.user_id == user_id, rooms.c.is_shared))
|
||||
query = query.where(or_(RoomModel.user_id == user_id, RoomModel.is_shared))
|
||||
else:
|
||||
query = query.where(rooms.c.is_shared)
|
||||
query = query.where(RoomModel.is_shared)
|
||||
|
||||
if order_by is not None:
|
||||
field = getattr(rooms.c, order_by[1:])
|
||||
field = getattr(RoomModel, order_by[1:])
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
@@ -92,11 +71,12 @@ class RoomController:
|
||||
if return_query:
|
||||
return query
|
||||
|
||||
results = await get_database().fetch_all(query)
|
||||
return results
|
||||
result = await session.execute(query)
|
||||
return [Room.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def add(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
name: str,
|
||||
user_id: str,
|
||||
zulip_auto_post: bool,
|
||||
@@ -107,10 +87,18 @@ class RoomController:
|
||||
recording_type: str,
|
||||
recording_trigger: str,
|
||||
is_shared: bool,
|
||||
webhook_url: str = "",
|
||||
webhook_secret: str = "",
|
||||
ics_url: str | None = None,
|
||||
ics_fetch_interval: int = 300,
|
||||
ics_enabled: bool = False,
|
||||
):
|
||||
"""
|
||||
Add a new room
|
||||
"""
|
||||
if webhook_url and not webhook_secret:
|
||||
webhook_secret = secrets.token_urlsafe(32)
|
||||
|
||||
room = Room(
|
||||
name=name,
|
||||
user_id=user_id,
|
||||
@@ -122,21 +110,33 @@ class RoomController:
|
||||
recording_type=recording_type,
|
||||
recording_trigger=recording_trigger,
|
||||
is_shared=is_shared,
|
||||
webhook_url=webhook_url,
|
||||
webhook_secret=webhook_secret,
|
||||
ics_url=ics_url,
|
||||
ics_fetch_interval=ics_fetch_interval,
|
||||
ics_enabled=ics_enabled,
|
||||
)
|
||||
query = rooms.insert().values(**room.model_dump())
|
||||
new_room = RoomModel(**room.model_dump())
|
||||
session.add(new_room)
|
||||
try:
|
||||
await get_database().execute(query)
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
raise HTTPException(status_code=400, detail="Room name is not unique")
|
||||
return room
|
||||
|
||||
async def update(self, room: Room, values: dict, mutate=True):
|
||||
async def update(
|
||||
self, session: AsyncSession, room: Room, values: dict, mutate=True
|
||||
):
|
||||
"""
|
||||
Update a room fields with key/values in values
|
||||
"""
|
||||
query = rooms.update().where(rooms.c.id == room.id).values(**values)
|
||||
if values.get("webhook_url") and not values.get("webhook_secret"):
|
||||
values["webhook_secret"] = secrets.token_urlsafe(32)
|
||||
|
||||
query = update(RoomModel).where(RoomModel.id == room.id).values(**values)
|
||||
try:
|
||||
await get_database().execute(query)
|
||||
await session.execute(query)
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
raise HTTPException(status_code=400, detail="Room name is not unique")
|
||||
|
||||
@@ -144,60 +144,79 @@ class RoomController:
|
||||
for key, value in values.items():
|
||||
setattr(room, key, value)
|
||||
|
||||
async def get_by_id(self, room_id: str, **kwargs) -> Room | None:
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, room_id: str, **kwargs
|
||||
) -> Room | None:
|
||||
"""
|
||||
Get a room by id
|
||||
"""
|
||||
query = rooms.select().where(rooms.c.id == room_id)
|
||||
query = select(RoomModel).where(RoomModel.id == room_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = query.where(RoomModel.user_id == kwargs["user_id"])
|
||||
result = await session.execute(query)
|
||||
row = result.scalars().first()
|
||||
if not row:
|
||||
return None
|
||||
return Room(**result)
|
||||
return Room.model_validate(row)
|
||||
|
||||
async def get_by_name(self, room_name: str, **kwargs) -> Room | None:
|
||||
async def get_by_name(
|
||||
self, session: AsyncSession, room_name: str, **kwargs
|
||||
) -> Room | None:
|
||||
"""
|
||||
Get a room by name
|
||||
"""
|
||||
query = rooms.select().where(rooms.c.name == room_name)
|
||||
query = select(RoomModel).where(RoomModel.name == room_name)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = query.where(RoomModel.user_id == kwargs["user_id"])
|
||||
result = await session.execute(query)
|
||||
row = result.scalars().first()
|
||||
if not row:
|
||||
return None
|
||||
return Room(**result)
|
||||
return Room.model_validate(row)
|
||||
|
||||
async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Room:
|
||||
async def get_by_id_for_http(
|
||||
self, session: AsyncSession, meeting_id: str, user_id: str | None
|
||||
) -> Room:
|
||||
"""
|
||||
Get a room by ID for HTTP request.
|
||||
|
||||
If not found, it will raise a 404 error.
|
||||
"""
|
||||
query = rooms.select().where(rooms.c.id == meeting_id)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = select(RoomModel).where(RoomModel.id == meeting_id)
|
||||
result = await session.execute(query)
|
||||
row = result.scalars().first()
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
room = Room(**result)
|
||||
room = Room.model_validate(row)
|
||||
|
||||
return room
|
||||
|
||||
async def get_ics_enabled(self, session: AsyncSession) -> list[Room]:
|
||||
query = select(RoomModel).where(
|
||||
RoomModel.ics_enabled == True, RoomModel.ics_url != None
|
||||
)
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
return [Room(**row.__dict__) for row in results]
|
||||
|
||||
async def remove_by_id(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
room_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Remove a room by id
|
||||
"""
|
||||
room = await self.get_by_id(room_id, user_id=user_id)
|
||||
room = await self.get_by_id(session, room_id, user_id=user_id)
|
||||
if not room:
|
||||
return
|
||||
if user_id is not None and room.user_id != user_id:
|
||||
return
|
||||
query = rooms.delete().where(rooms.c.id == room_id)
|
||||
await get_database().execute(query)
|
||||
query = delete(RoomModel).where(RoomModel.id == room_id)
|
||||
await session.execute(query)
|
||||
await session.flush()
|
||||
|
||||
|
||||
rooms_controller = RoomController()
|
||||
|
||||
@@ -14,16 +14,17 @@ from pydantic import (
|
||||
Field,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
TypeAdapter,
|
||||
ValidationError,
|
||||
constr,
|
||||
field_serializer,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.rooms import rooms
|
||||
from reflector.db.transcripts import SourceKind, transcripts
|
||||
from reflector.db.utils import is_postgresql
|
||||
from reflector.db.base import RoomModel, TranscriptModel
|
||||
from reflector.db.transcripts import SourceKind, TranscriptStatus
|
||||
from reflector.logger import logger
|
||||
from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
|
||||
|
||||
DEFAULT_SEARCH_LIMIT = 20
|
||||
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
|
||||
@@ -31,12 +32,13 @@ DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
|
||||
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
|
||||
LONG_SUMMARY_MAX_SNIPPETS = 2
|
||||
|
||||
SearchQueryBase = constr(min_length=0, strip_whitespace=True)
|
||||
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
|
||||
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
|
||||
SearchOffsetBase = Annotated[int, Field(ge=0)]
|
||||
SearchTotalBase = Annotated[int, Field(ge=0)]
|
||||
|
||||
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
|
||||
search_query_adapter = TypeAdapter(SearchQuery)
|
||||
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
|
||||
SearchOffset = Annotated[
|
||||
SearchOffsetBase, Field(description="Number of results to skip")
|
||||
@@ -88,7 +90,7 @@ class WebVTTProcessor:
|
||||
@staticmethod
|
||||
def generate_snippets(
|
||||
webvtt_content: WebVTTContent,
|
||||
query: str,
|
||||
query: SearchQuery,
|
||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from WebVTT content."""
|
||||
@@ -125,7 +127,7 @@ class SnippetCandidate:
|
||||
class SearchParameters(BaseModel):
|
||||
"""Validated search parameters for full-text search."""
|
||||
|
||||
query_text: SearchQuery
|
||||
query_text: SearchQuery | None = None
|
||||
limit: SearchLimit = DEFAULT_SEARCH_LIMIT
|
||||
offset: SearchOffset = 0
|
||||
user_id: str | None = None
|
||||
@@ -157,7 +159,7 @@ class SearchResult(BaseModel):
|
||||
room_name: str | None = None
|
||||
source_kind: SourceKind
|
||||
created_at: datetime
|
||||
status: str = Field(..., min_length=1)
|
||||
status: TranscriptStatus = Field(..., min_length=1)
|
||||
rank: float = Field(..., ge=0, le=1)
|
||||
duration: NonNegativeFloat | None = Field(..., description="Duration in seconds")
|
||||
search_snippets: list[str] = Field(
|
||||
@@ -199,15 +201,13 @@ class SnippetGenerator:
|
||||
prev_start = start
|
||||
|
||||
@staticmethod
|
||||
def count_matches(text: str, query: str) -> NonNegativeInt:
|
||||
def count_matches(text: str, query: SearchQuery) -> NonNegativeInt:
|
||||
"""Count total number of matches for a query in text."""
|
||||
ZERO = NonNegativeInt(0)
|
||||
if not text:
|
||||
logger.warning("Empty text for search query in count_matches")
|
||||
return ZERO
|
||||
if not query:
|
||||
logger.warning("Empty query for search text in count_matches")
|
||||
return ZERO
|
||||
assert query is not None
|
||||
return NonNegativeInt(
|
||||
sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
|
||||
)
|
||||
@@ -243,13 +243,14 @@ class SnippetGenerator:
|
||||
@staticmethod
|
||||
def generate(
|
||||
text: str,
|
||||
query: str,
|
||||
query: SearchQuery,
|
||||
max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from text."""
|
||||
if not text or not query:
|
||||
logger.warning("Empty text or query for generate_snippets")
|
||||
assert query is not None
|
||||
if not text:
|
||||
logger.warning("Empty text for generate_snippets")
|
||||
return []
|
||||
|
||||
candidates = (
|
||||
@@ -270,7 +271,7 @@ class SnippetGenerator:
|
||||
@staticmethod
|
||||
def from_summary(
|
||||
summary: str,
|
||||
query: str,
|
||||
query: SearchQuery,
|
||||
max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from summary text."""
|
||||
@@ -278,9 +279,9 @@ class SnippetGenerator:
|
||||
|
||||
@staticmethod
|
||||
def combine_sources(
|
||||
summary: str | None,
|
||||
summary: NonEmptyString | None,
|
||||
webvtt: WebVTTContent | None,
|
||||
query: str,
|
||||
query: SearchQuery,
|
||||
max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> tuple[list[str], NonNegativeInt]:
|
||||
"""Combine snippets from multiple sources and return total match count.
|
||||
@@ -289,6 +290,11 @@ class SnippetGenerator:
|
||||
|
||||
snippets can be empty for real in case of e.g. title match
|
||||
"""
|
||||
|
||||
assert (
|
||||
summary is not None or webvtt is not None
|
||||
), "At least one source must be present"
|
||||
|
||||
webvtt_matches = 0
|
||||
summary_matches = 0
|
||||
|
||||
@@ -323,45 +329,39 @@ class SearchController:
|
||||
|
||||
@classmethod
|
||||
async def search_transcripts(
|
||||
cls, params: SearchParameters
|
||||
cls, session: AsyncSession, params: SearchParameters
|
||||
) -> tuple[list[SearchResult], int]:
|
||||
"""
|
||||
Full-text search for transcripts using PostgreSQL tsvector.
|
||||
Returns (results, total_count).
|
||||
"""
|
||||
|
||||
if not is_postgresql():
|
||||
logger.warning(
|
||||
"Full-text search requires PostgreSQL. Returning empty results."
|
||||
)
|
||||
return [], 0
|
||||
|
||||
base_columns = [
|
||||
transcripts.c.id,
|
||||
transcripts.c.title,
|
||||
transcripts.c.created_at,
|
||||
transcripts.c.duration,
|
||||
transcripts.c.status,
|
||||
transcripts.c.user_id,
|
||||
transcripts.c.room_id,
|
||||
transcripts.c.source_kind,
|
||||
transcripts.c.webvtt,
|
||||
transcripts.c.long_summary,
|
||||
TranscriptModel.id,
|
||||
TranscriptModel.title,
|
||||
TranscriptModel.created_at,
|
||||
TranscriptModel.duration,
|
||||
TranscriptModel.status,
|
||||
TranscriptModel.user_id,
|
||||
TranscriptModel.room_id,
|
||||
TranscriptModel.source_kind,
|
||||
TranscriptModel.webvtt,
|
||||
TranscriptModel.long_summary,
|
||||
sqlalchemy.case(
|
||||
(
|
||||
transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
|
||||
TranscriptModel.room_id.isnot(None) & RoomModel.id.is_(None),
|
||||
"Deleted Room",
|
||||
),
|
||||
else_=rooms.c.name,
|
||||
else_=RoomModel.name,
|
||||
).label("room_name"),
|
||||
]
|
||||
|
||||
if params.query_text:
|
||||
search_query = None
|
||||
if params.query_text is not None:
|
||||
search_query = sqlalchemy.func.websearch_to_tsquery(
|
||||
"english", params.query_text
|
||||
)
|
||||
rank_column = sqlalchemy.func.ts_rank(
|
||||
transcripts.c.search_vector_en,
|
||||
TranscriptModel.search_vector_en,
|
||||
search_query,
|
||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||
).label("rank")
|
||||
@@ -369,58 +369,74 @@ class SearchController:
|
||||
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
||||
|
||||
columns = base_columns + [rank_column]
|
||||
base_query = sqlalchemy.select(columns).select_from(
|
||||
transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
|
||||
base_query = (
|
||||
sqlalchemy.select(*columns)
|
||||
.select_from(TranscriptModel)
|
||||
.outerjoin(RoomModel, TranscriptModel.room_id == RoomModel.id)
|
||||
)
|
||||
|
||||
if params.query_text:
|
||||
if params.query_text is not None:
|
||||
# because already initialized based on params.query_text presence above
|
||||
assert search_query is not None
|
||||
base_query = base_query.where(
|
||||
transcripts.c.search_vector_en.op("@@")(search_query)
|
||||
TranscriptModel.search_vector_en.op("@@")(search_query)
|
||||
)
|
||||
|
||||
if params.user_id:
|
||||
base_query = base_query.where(
|
||||
sqlalchemy.or_(
|
||||
transcripts.c.user_id == params.user_id, rooms.c.is_shared
|
||||
TranscriptModel.user_id == params.user_id, RoomModel.is_shared
|
||||
)
|
||||
)
|
||||
else:
|
||||
base_query = base_query.where(rooms.c.is_shared)
|
||||
base_query = base_query.where(RoomModel.is_shared)
|
||||
if params.room_id:
|
||||
base_query = base_query.where(transcripts.c.room_id == params.room_id)
|
||||
base_query = base_query.where(TranscriptModel.room_id == params.room_id)
|
||||
if params.source_kind:
|
||||
base_query = base_query.where(
|
||||
transcripts.c.source_kind == params.source_kind
|
||||
TranscriptModel.source_kind == params.source_kind
|
||||
)
|
||||
|
||||
if params.query_text:
|
||||
if params.query_text is not None:
|
||||
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
|
||||
else:
|
||||
order_by = sqlalchemy.desc(transcripts.c.created_at)
|
||||
order_by = sqlalchemy.desc(TranscriptModel.created_at)
|
||||
|
||||
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
|
||||
|
||||
rs = await get_database().fetch_all(query)
|
||||
result = await session.execute(query)
|
||||
rs = result.mappings().all()
|
||||
|
||||
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
||||
count_query = sqlalchemy.select(sqlalchemy.func.count()).select_from(
|
||||
base_query.alias("search_results")
|
||||
)
|
||||
total = await get_database().fetch_val(count_query)
|
||||
count_result = await session.execute(count_query)
|
||||
total = count_result.scalar()
|
||||
|
||||
def _process_result(r) -> SearchResult:
|
||||
def _process_result(r: dict) -> SearchResult:
|
||||
r_dict: Dict[str, Any] = dict(r)
|
||||
|
||||
webvtt_raw: str | None = r_dict.pop("webvtt", None)
|
||||
webvtt: WebVTTContent | None
|
||||
if webvtt_raw:
|
||||
webvtt = WebVTTProcessor.parse(webvtt_raw)
|
||||
else:
|
||||
webvtt = None
|
||||
long_summary: str | None = r_dict.pop("long_summary", None)
|
||||
|
||||
long_summary_r: str | None = r_dict.pop("long_summary", None)
|
||||
long_summary: NonEmptyString = try_parse_non_empty_string(long_summary_r)
|
||||
room_name: str | None = r_dict.pop("room_name", None)
|
||||
db_result = SearchResultDB.model_validate(r_dict)
|
||||
|
||||
snippets, total_match_count = SnippetGenerator.combine_sources(
|
||||
at_least_one_source = webvtt is not None or long_summary is not None
|
||||
has_query = params.query_text is not None
|
||||
snippets, total_match_count = (
|
||||
SnippetGenerator.combine_sources(
|
||||
long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
|
||||
)
|
||||
if has_query and at_least_one_source
|
||||
else ([], 0)
|
||||
)
|
||||
|
||||
return SearchResult(
|
||||
**db_result.model_dump(),
|
||||
|
||||
@@ -7,17 +7,14 @@ from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import sqlalchemy
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy.dialects.postgresql import TSVECTOR
|
||||
from sqlalchemy.sql import false, or_
|
||||
from sqlalchemy import delete, insert, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import or_
|
||||
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.base import RoomModel, TranscriptModel
|
||||
from reflector.db.recordings import recordings_controller
|
||||
from reflector.db.rooms import rooms
|
||||
from reflector.db.utils import is_postgresql
|
||||
from reflector.logger import logger
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
from reflector.settings import settings
|
||||
@@ -32,96 +29,20 @@ class SourceKind(enum.StrEnum):
|
||||
FILE = enum.auto()
|
||||
|
||||
|
||||
transcripts = sqlalchemy.Table(
|
||||
"transcript",
|
||||
metadata,
|
||||
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||
sqlalchemy.Column("name", sqlalchemy.String),
|
||||
sqlalchemy.Column("status", sqlalchemy.String),
|
||||
sqlalchemy.Column("locked", sqlalchemy.Boolean),
|
||||
sqlalchemy.Column("duration", sqlalchemy.Float),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
|
||||
sqlalchemy.Column("title", sqlalchemy.String),
|
||||
sqlalchemy.Column("short_summary", sqlalchemy.String),
|
||||
sqlalchemy.Column("long_summary", sqlalchemy.String),
|
||||
sqlalchemy.Column("topics", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("events", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("participants", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("source_language", sqlalchemy.String),
|
||||
sqlalchemy.Column("target_language", sqlalchemy.String),
|
||||
sqlalchemy.Column(
|
||||
"reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"audio_location",
|
||||
sqlalchemy.String,
|
||||
nullable=False,
|
||||
server_default="local",
|
||||
),
|
||||
# with user attached, optional
|
||||
sqlalchemy.Column("user_id", sqlalchemy.String),
|
||||
sqlalchemy.Column(
|
||||
"share_mode",
|
||||
sqlalchemy.String,
|
||||
nullable=False,
|
||||
server_default="private",
|
||||
),
|
||||
sqlalchemy.Column(
|
||||
"meeting_id",
|
||||
sqlalchemy.String,
|
||||
),
|
||||
sqlalchemy.Column("recording_id", sqlalchemy.String),
|
||||
sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer),
|
||||
sqlalchemy.Column(
|
||||
"source_kind",
|
||||
Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]),
|
||||
nullable=False,
|
||||
),
|
||||
# indicative field: whether associated audio is deleted
|
||||
# the main "audio deleted" is the presence of the audio itself / consents not-given
|
||||
# same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
|
||||
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
|
||||
sqlalchemy.Column("room_id", sqlalchemy.String),
|
||||
sqlalchemy.Column("webvtt", sqlalchemy.Text),
|
||||
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
|
||||
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
|
||||
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
||||
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
||||
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
|
||||
sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
|
||||
sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
|
||||
)
|
||||
|
||||
# Add PostgreSQL-specific full-text search column
|
||||
# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
|
||||
if is_postgresql():
|
||||
transcripts.append_column(
|
||||
sqlalchemy.Column(
|
||||
"search_vector_en",
|
||||
TSVECTOR,
|
||||
sqlalchemy.Computed(
|
||||
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
||||
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
|
||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
|
||||
persisted=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
# Add GIN index for the search vector
|
||||
transcripts.append_constraint(
|
||||
sqlalchemy.Index(
|
||||
"idx_transcript_search_vector_en",
|
||||
"search_vector_en",
|
||||
postgresql_using="gin",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def generate_transcript_name() -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
|
||||
TranscriptStatus = Literal[
|
||||
"idle", "uploaded", "recording", "processing", "error", "ended"
|
||||
]
|
||||
|
||||
|
||||
class StrValue(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class AudioWaveform(BaseModel):
|
||||
data: list[float]
|
||||
|
||||
@@ -182,10 +103,12 @@ class TranscriptParticipant(BaseModel):
|
||||
class Transcript(BaseModel):
|
||||
"""Full transcript model with all fields."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
user_id: str | None = None
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
status: str = "idle"
|
||||
status: TranscriptStatus = "idle"
|
||||
duration: float = 0
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
title: str | None = None
|
||||
@@ -350,6 +273,7 @@ class Transcript(BaseModel):
|
||||
class TranscriptController:
|
||||
async def get_all(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
order_by: str | None = None,
|
||||
filter_empty: bool | None = False,
|
||||
@@ -374,102 +298,114 @@ class TranscriptController:
|
||||
- `search_term`: filter transcripts by search term
|
||||
"""
|
||||
|
||||
query = transcripts.select().join(
|
||||
rooms, transcripts.c.room_id == rooms.c.id, isouter=True
|
||||
query = select(TranscriptModel).join(
|
||||
RoomModel, TranscriptModel.room_id == RoomModel.id, isouter=True
|
||||
)
|
||||
|
||||
if user_id:
|
||||
query = query.where(
|
||||
or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
|
||||
or_(TranscriptModel.user_id == user_id, RoomModel.is_shared)
|
||||
)
|
||||
else:
|
||||
query = query.where(rooms.c.is_shared)
|
||||
query = query.where(RoomModel.is_shared)
|
||||
|
||||
if source_kind:
|
||||
query = query.where(transcripts.c.source_kind == source_kind)
|
||||
query = query.where(TranscriptModel.source_kind == source_kind)
|
||||
|
||||
if room_id:
|
||||
query = query.where(transcripts.c.room_id == room_id)
|
||||
query = query.where(TranscriptModel.room_id == room_id)
|
||||
|
||||
if search_term:
|
||||
query = query.where(transcripts.c.title.ilike(f"%{search_term}%"))
|
||||
query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
|
||||
|
||||
# Exclude heavy JSON columns from list queries
|
||||
# Get all ORM column attributes except excluded ones
|
||||
transcript_columns = [
|
||||
col for col in transcripts.c if col.name not in exclude_columns
|
||||
getattr(TranscriptModel, col.name)
|
||||
for col in TranscriptModel.__table__.c
|
||||
if col.name not in exclude_columns
|
||||
]
|
||||
|
||||
query = query.with_only_columns(
|
||||
transcript_columns
|
||||
+ [
|
||||
rooms.c.name.label("room_name"),
|
||||
]
|
||||
*transcript_columns,
|
||||
RoomModel.name.label("room_name"),
|
||||
)
|
||||
|
||||
if order_by is not None:
|
||||
field = getattr(transcripts.c, order_by[1:])
|
||||
field = getattr(TranscriptModel, order_by[1:])
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
|
||||
if filter_empty:
|
||||
query = query.filter(transcripts.c.status != "idle")
|
||||
query = query.filter(TranscriptModel.status != "idle")
|
||||
|
||||
if filter_recording:
|
||||
query = query.filter(transcripts.c.status != "recording")
|
||||
query = query.filter(TranscriptModel.status != "recording")
|
||||
|
||||
# print(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
|
||||
if return_query:
|
||||
return query
|
||||
|
||||
results = await get_database().fetch_all(query)
|
||||
return results
|
||||
result = await session.execute(query)
|
||||
return [dict(row) for row in result.mappings().all()]
|
||||
|
||||
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, transcript_id: str, **kwargs
|
||||
) -> Transcript | None:
|
||||
"""
|
||||
Get a transcript by id
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
return Transcript.model_validate(row)
|
||||
|
||||
async def get_by_recording_id(
|
||||
self, recording_id: str, **kwargs
|
||||
self, session: AsyncSession, recording_id: str, **kwargs
|
||||
) -> Transcript | None:
|
||||
"""
|
||||
Get a transcript by recording_id
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
|
||||
query = select(TranscriptModel).where(
|
||||
TranscriptModel.recording_id == recording_id
|
||||
)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
return Transcript.model_validate(row)
|
||||
|
||||
async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]:
|
||||
async def get_by_room_id(
|
||||
self, session: AsyncSession, room_id: str, **kwargs
|
||||
) -> list[Transcript]:
|
||||
"""
|
||||
Get transcripts by room_id (direct access without joins)
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.room_id == room_id)
|
||||
query = select(TranscriptModel).where(TranscriptModel.room_id == room_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
|
||||
if "order_by" in kwargs:
|
||||
order_by = kwargs["order_by"]
|
||||
field = getattr(transcripts.c, order_by[1:])
|
||||
field = getattr(TranscriptModel, order_by[1:])
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
results = await get_database().fetch_all(query)
|
||||
return [Transcript(**result) for result in results]
|
||||
results = await session.execute(query)
|
||||
return [
|
||||
Transcript.model_validate(dict(row)) for row in results.mappings().all()
|
||||
]
|
||||
|
||||
async def get_by_id_for_http(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript_id: str,
|
||||
user_id: str | None,
|
||||
) -> Transcript:
|
||||
@@ -482,13 +418,14 @@ class TranscriptController:
|
||||
This method checks the share mode of the transcript and the user_id
|
||||
to determine if the user can access the transcript.
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
|
||||
result = await session.execute(query)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
# if the transcript is anonymous, share mode is not checked
|
||||
transcript = Transcript(**result)
|
||||
transcript = Transcript.model_validate(row)
|
||||
if transcript.user_id is None:
|
||||
return transcript
|
||||
|
||||
@@ -511,6 +448,7 @@ class TranscriptController:
|
||||
|
||||
async def add(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
name: str,
|
||||
source_kind: SourceKind,
|
||||
source_language: str = "en",
|
||||
@@ -535,14 +473,15 @@ class TranscriptController:
|
||||
meeting_id=meeting_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await get_database().execute(query)
|
||||
query = insert(TranscriptModel).values(**transcript.model_dump())
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
return transcript
|
||||
|
||||
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
|
||||
# using mutate=True is discouraged
|
||||
async def update(
|
||||
self, transcript: Transcript, values: dict, mutate=False
|
||||
self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False
|
||||
) -> Transcript:
|
||||
"""
|
||||
Update a transcript fields with key/values in values.
|
||||
@@ -551,11 +490,12 @@ class TranscriptController:
|
||||
values = TranscriptController._handle_topics_update(values)
|
||||
|
||||
query = (
|
||||
transcripts.update()
|
||||
.where(transcripts.c.id == transcript.id)
|
||||
update(TranscriptModel)
|
||||
.where(TranscriptModel.id == transcript.id)
|
||||
.values(**values)
|
||||
)
|
||||
await get_database().execute(query)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
if mutate:
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
@@ -584,13 +524,14 @@ class TranscriptController:
|
||||
|
||||
async def remove_by_id(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Remove a transcript by id
|
||||
"""
|
||||
transcript = await self.get_by_id(transcript_id)
|
||||
transcript = await self.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
if user_id is not None and transcript.user_id != user_id:
|
||||
@@ -610,7 +551,7 @@ class TranscriptController:
|
||||
if transcript.recording_id:
|
||||
try:
|
||||
recording = await recordings_controller.get_by_id(
|
||||
transcript.recording_id
|
||||
session, transcript.recording_id
|
||||
)
|
||||
if recording:
|
||||
try:
|
||||
@@ -621,33 +562,40 @@ class TranscriptController:
|
||||
exc_info=e,
|
||||
recording_id=transcript.recording_id,
|
||||
)
|
||||
await recordings_controller.remove_by_id(transcript.recording_id)
|
||||
await recordings_controller.remove_by_id(
|
||||
session, transcript.recording_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to delete recording row",
|
||||
exc_info=e,
|
||||
recording_id=transcript.recording_id,
|
||||
)
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await get_database().execute(query)
|
||||
query = delete(TranscriptModel).where(TranscriptModel.id == transcript_id)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
|
||||
async def remove_by_recording_id(self, recording_id: str):
|
||||
async def remove_by_recording_id(self, session: AsyncSession, recording_id: str):
|
||||
"""
|
||||
Remove a transcript by recording_id
|
||||
"""
|
||||
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
|
||||
await get_database().execute(query)
|
||||
query = delete(TranscriptModel).where(
|
||||
TranscriptModel.recording_id == recording_id
|
||||
)
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
async def transaction(self, session: AsyncSession):
|
||||
"""
|
||||
A context manager for database transaction
|
||||
"""
|
||||
async with get_database().transaction(isolation="serializable"):
|
||||
async with session.begin():
|
||||
yield
|
||||
|
||||
async def append_event(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript: Transcript,
|
||||
event: str,
|
||||
data: Any,
|
||||
@@ -656,11 +604,12 @@ class TranscriptController:
|
||||
Append an event to a transcript
|
||||
"""
|
||||
resp = transcript.add_event(event=event, data=data)
|
||||
await self.update(transcript, {"events": transcript.events_dump()})
|
||||
await self.update(session, transcript, {"events": transcript.events_dump()})
|
||||
return resp
|
||||
|
||||
async def upsert_topic(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript: Transcript,
|
||||
topic: TranscriptTopic,
|
||||
) -> TranscriptEvent:
|
||||
@@ -668,9 +617,9 @@ class TranscriptController:
|
||||
Upsert topics to a transcript
|
||||
"""
|
||||
transcript.upsert_topic(topic)
|
||||
await self.update(transcript, {"topics": transcript.topics_dump()})
|
||||
await self.update(session, transcript, {"topics": transcript.topics_dump()})
|
||||
|
||||
async def move_mp3_to_storage(self, transcript: Transcript):
|
||||
async def move_mp3_to_storage(self, session: AsyncSession, transcript: Transcript):
|
||||
"""
|
||||
Move mp3 file to storage
|
||||
"""
|
||||
@@ -694,12 +643,16 @@ class TranscriptController:
|
||||
|
||||
# indicate on the transcript that the audio is now on storage
|
||||
# mutates transcript argument
|
||||
await self.update(transcript, {"audio_location": "storage"}, mutate=True)
|
||||
await self.update(
|
||||
session, transcript, {"audio_location": "storage"}, mutate=True
|
||||
)
|
||||
|
||||
# unlink the local file
|
||||
transcript.audio_mp3_filename.unlink(missing_ok=True)
|
||||
|
||||
async def download_mp3_from_storage(self, transcript: Transcript):
|
||||
async def download_mp3_from_storage(
|
||||
self, session: AsyncSession, transcript: Transcript
|
||||
):
|
||||
"""
|
||||
Download audio from storage
|
||||
"""
|
||||
@@ -711,6 +664,7 @@ class TranscriptController:
|
||||
|
||||
async def upsert_participant(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript: Transcript,
|
||||
participant: TranscriptParticipant,
|
||||
) -> TranscriptParticipant:
|
||||
@@ -718,11 +672,14 @@ class TranscriptController:
|
||||
Add/update a participant to a transcript
|
||||
"""
|
||||
result = transcript.upsert_participant(participant)
|
||||
await self.update(transcript, {"participants": transcript.participants_dump()})
|
||||
await self.update(
|
||||
session, transcript, {"participants": transcript.participants_dump()}
|
||||
)
|
||||
return result
|
||||
|
||||
async def delete_participant(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transcript: Transcript,
|
||||
participant_id: str,
|
||||
):
|
||||
@@ -730,7 +687,32 @@ class TranscriptController:
|
||||
Delete a participant from a transcript
|
||||
"""
|
||||
transcript.delete_participant(participant_id)
|
||||
await self.update(transcript, {"participants": transcript.participants_dump()})
|
||||
await self.update(
|
||||
session, transcript, {"participants": transcript.participants_dump()}
|
||||
)
|
||||
|
||||
async def set_status(
|
||||
self, session: AsyncSession, transcript_id: str, status: TranscriptStatus
|
||||
) -> TranscriptEvent | None:
|
||||
"""
|
||||
Update the status of a transcript
|
||||
|
||||
Will add an event STATUS + update the status field of transcript
|
||||
"""
|
||||
async with self.transaction(session):
|
||||
transcript = await self.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
if transcript.status == status:
|
||||
return
|
||||
resp = await self.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="STATUS",
|
||||
data=StrValue(value=status),
|
||||
)
|
||||
await self.update(session, transcript, {"status": status})
|
||||
return resp
|
||||
|
||||
|
||||
transcripts_controller = TranscriptController()
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""Database utility functions."""
|
||||
|
||||
from reflector.db import get_database
|
||||
|
||||
|
||||
def is_postgresql() -> bool:
|
||||
return get_database().url.scheme and get_database().url.scheme.startswith(
|
||||
"postgresql"
|
||||
)
|
||||
@@ -7,18 +7,30 @@ Uses parallel processing for transcription, diarization, and waveform generation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import structlog
|
||||
from celery import shared_task
|
||||
from celery import chain, shared_task
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.transcripts import (
|
||||
SourceKind,
|
||||
Transcript,
|
||||
TranscriptStatus,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.main_live_pipeline import PipelineMainBase, asynctask
|
||||
from reflector.pipelines.main_live_pipeline import (
|
||||
PipelineMainBase,
|
||||
broadcast_to_sockets,
|
||||
task_cleanup_consent,
|
||||
task_pipeline_post_to_zulip,
|
||||
)
|
||||
from reflector.processors import (
|
||||
AudioFileWriterProcessor,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
@@ -43,6 +55,8 @@ from reflector.processors.types import (
|
||||
)
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.worker.session_decorator import with_session
|
||||
from reflector.worker.webhook import send_transcript_webhook
|
||||
|
||||
|
||||
class EmptyPipeline:
|
||||
@@ -83,11 +97,32 @@ class PipelineMainFile(PipelineMainBase):
|
||||
exc_info=result,
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def set_status(self, transcript_id: str, status: TranscriptStatus):
|
||||
async with self.lock_transaction():
|
||||
async with get_session_factory()() as session:
|
||||
return await transcripts_controller.set_status(
|
||||
session, transcript_id, status
|
||||
)
|
||||
|
||||
async def process(self, file_path: Path):
|
||||
"""Main entry point for file processing"""
|
||||
self.logger.info(f"Starting file pipeline for {file_path}")
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
|
||||
# Clear transcript as we're going to regenerate everything
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"events": [],
|
||||
"topics": [],
|
||||
},
|
||||
)
|
||||
|
||||
# Extract audio and write to transcript location
|
||||
audio_path = await self.extract_and_write_audio(file_path, transcript)
|
||||
@@ -97,6 +132,7 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
# Run parallel processing
|
||||
await self.run_parallel_processing(
|
||||
session,
|
||||
audio_path,
|
||||
audio_url,
|
||||
transcript.source_language,
|
||||
@@ -105,6 +141,9 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
self.logger.info("File pipeline complete")
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.set_status(session, transcript.id, "ended")
|
||||
|
||||
async def extract_and_write_audio(
|
||||
self, file_path: Path, transcript: Transcript
|
||||
) -> Path:
|
||||
@@ -165,6 +204,7 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
async def run_parallel_processing(
|
||||
self,
|
||||
session,
|
||||
audio_path: Path,
|
||||
audio_url: str,
|
||||
source_language: str,
|
||||
@@ -178,7 +218,7 @@ class PipelineMainFile(PipelineMainBase):
|
||||
# Phase 1: Parallel processing of independent tasks
|
||||
transcription_task = self.transcribe_file(audio_url, source_language)
|
||||
diarization_task = self.diarize_file(audio_url)
|
||||
waveform_task = self.generate_waveform(audio_path)
|
||||
waveform_task = self.generate_waveform(session, audio_path)
|
||||
|
||||
results = await asyncio.gather(
|
||||
transcription_task, diarization_task, waveform_task, return_exceptions=True
|
||||
@@ -226,7 +266,7 @@ class PipelineMainFile(PipelineMainBase):
|
||||
)
|
||||
results = await asyncio.gather(
|
||||
self.generate_title(topics),
|
||||
self.generate_summaries(topics),
|
||||
self.generate_summaries(session, topics),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
@@ -278,9 +318,9 @@ class PipelineMainFile(PipelineMainBase):
|
||||
self.logger.error(f"Diarization failed: {e}")
|
||||
return None
|
||||
|
||||
async def generate_waveform(self, audio_path: Path):
|
||||
async def generate_waveform(self, session: AsyncSession, audio_path: Path):
|
||||
"""Generate and save waveform"""
|
||||
transcript = await self.get_transcript()
|
||||
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
|
||||
processor = AudioWaveformProcessor(
|
||||
audio_path=audio_path,
|
||||
@@ -333,13 +373,13 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
await processor.flush()
|
||||
|
||||
async def generate_summaries(self, topics: list[TitleSummary]):
|
||||
async def generate_summaries(self, session, topics: list[TitleSummary]):
|
||||
"""Generate long and short summaries from topics"""
|
||||
if not topics:
|
||||
self.logger.warning("No topics for summary generation")
|
||||
return
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
processor = TranscriptFinalSummaryProcessor(
|
||||
transcript=transcript,
|
||||
callback=self.on_long_summary,
|
||||
@@ -355,13 +395,40 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_file_process(*, transcript_id: str):
|
||||
"""Celery task for file pipeline processing"""
|
||||
@with_session
|
||||
async def task_send_webhook_if_needed(session, *, transcript_id: str):
|
||||
"""Send webhook if this is a room recording with webhook configured"""
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
||||
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
||||
if room and room.webhook_url:
|
||||
logger.info(
|
||||
"Dispatching webhook",
|
||||
transcript_id=transcript_id,
|
||||
room_id=room.id,
|
||||
webhook_url=room.webhook_url,
|
||||
)
|
||||
send_transcript_webhook.delay(
|
||||
transcript_id, room.id, event_id=uuid.uuid4().hex
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
@with_session
|
||||
async def task_pipeline_file_process(session, *, transcript_id: str):
|
||||
"""Celery task for file pipeline processing"""
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
try:
|
||||
await pipeline.set_status(transcript_id, "processing")
|
||||
|
||||
# Find the file to process
|
||||
audio_file = next(transcript.data_path.glob("upload.*"), None)
|
||||
if not audio_file:
|
||||
@@ -370,6 +437,16 @@ async def task_pipeline_file_process(*, transcript_id: str):
|
||||
if not audio_file:
|
||||
raise Exception("No audio file found to process")
|
||||
|
||||
# Run file pipeline
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
await pipeline.process(audio_file)
|
||||
|
||||
except Exception:
|
||||
await pipeline.set_status(transcript_id, "error")
|
||||
raise
|
||||
|
||||
# Run post-processing chain: consent cleanup -> zulip -> webhook
|
||||
post_chain = chain(
|
||||
task_cleanup_consent.si(transcript_id=transcript_id),
|
||||
task_pipeline_post_to_zulip.si(transcript_id=transcript_id),
|
||||
task_send_webhook_if_needed.si(transcript_id=transcript_id),
|
||||
)
|
||||
post_chain.delay()
|
||||
|
||||
@@ -20,9 +20,11 @@ import av
|
||||
import boto3
|
||||
from celery import chord, current_task, group, shared_task
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from structlog import BoundLogger as Logger
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
||||
from reflector.db.recordings import recordings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
@@ -32,6 +34,7 @@ from reflector.db.transcripts import (
|
||||
TranscriptFinalLongSummary,
|
||||
TranscriptFinalShortSummary,
|
||||
TranscriptFinalTitle,
|
||||
TranscriptStatus,
|
||||
TranscriptText,
|
||||
TranscriptTopic,
|
||||
TranscriptWaveform,
|
||||
@@ -61,6 +64,7 @@ from reflector.processors.types import (
|
||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.worker.session_decorator import with_session_and_transcript
|
||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||
from reflector.zulip import (
|
||||
get_zulip_message,
|
||||
@@ -69,29 +73,6 @@ from reflector.zulip import (
|
||||
)
|
||||
|
||||
|
||||
def asynctask(f):
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
async def run_with_db():
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
coro = run_with_db()
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
return loop.run_until_complete(coro)
|
||||
return asyncio.run(coro)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def broadcast_to_sockets(func):
|
||||
"""
|
||||
Decorator to broadcast transcript event to websockets
|
||||
@@ -118,9 +99,10 @@ def get_transcript(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(**kwargs):
|
||||
transcript_id = kwargs.pop("transcript_id")
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise Exception("Transcript {transcript_id} not found")
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
|
||||
# Enhanced logger with Celery task context
|
||||
tlogger = logger.bind(transcript_id=transcript.id)
|
||||
@@ -161,11 +143,9 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
self._ws_manager = get_ws_manager()
|
||||
return self._ws_manager
|
||||
|
||||
async def get_transcript(self) -> Transcript:
|
||||
async def get_transcript(self, session: AsyncSession) -> Transcript:
|
||||
# fetch the transcript
|
||||
result = await transcripts_controller.get_by_id(
|
||||
transcript_id=self.transcript_id
|
||||
)
|
||||
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
if not result:
|
||||
raise Exception("Transcript not found")
|
||||
return result
|
||||
@@ -188,24 +168,31 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
]
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
async def lock_transaction(self):
|
||||
# This lock is to prevent multiple processor starting adding
|
||||
# into event array at the same time
|
||||
async with self._lock:
|
||||
async with transcripts_controller.transaction():
|
||||
yield
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
async with self.lock_transaction():
|
||||
async with get_session_factory()() as session:
|
||||
yield session
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_status(self, status):
|
||||
# if it's the first part, update the status of the transcript
|
||||
# but do not set the ended status yet.
|
||||
if isinstance(self, PipelineMainLive):
|
||||
status_mapping = {
|
||||
status_mapping: dict[str, TranscriptStatus] = {
|
||||
"started": "recording",
|
||||
"push": "recording",
|
||||
"flush": "processing",
|
||||
"error": "error",
|
||||
}
|
||||
elif isinstance(self, PipelineMainFinalSummaries):
|
||||
status_mapping = {
|
||||
status_mapping: dict[str, TranscriptStatus] = {
|
||||
"push": "processing",
|
||||
"flush": "processing",
|
||||
"error": "error",
|
||||
@@ -221,28 +208,18 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
return
|
||||
|
||||
# when the status of the pipeline changes, update the transcript
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
if status == transcript.status:
|
||||
return
|
||||
resp = await transcripts_controller.append_event(
|
||||
transcript=transcript,
|
||||
event="STATUS",
|
||||
data=StrValue(value=status),
|
||||
async with self._lock:
|
||||
async with get_session_factory()() as session:
|
||||
return await transcripts_controller.set_status(
|
||||
session, self.transcript_id, status
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
return resp
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_transcript(self, data):
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.transaction() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="TRANSCRIPT",
|
||||
data=TranscriptText(text=data.text, translation=data.translation),
|
||||
@@ -259,10 +236,11 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
)
|
||||
if isinstance(data, TitleSummaryWithIdProcessorType):
|
||||
topic.id = data.id
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
await transcripts_controller.upsert_topic(transcript, topic)
|
||||
async with self.transaction() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.upsert_topic(session, transcript, topic)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="TOPIC",
|
||||
data=topic,
|
||||
@@ -271,16 +249,18 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
@broadcast_to_sockets
|
||||
async def on_title(self, data):
|
||||
final_title = TranscriptFinalTitle(title=data.title)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.transaction() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
if not transcript.title:
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"title": final_title.title,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_TITLE",
|
||||
data=final_title,
|
||||
@@ -289,15 +269,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
@broadcast_to_sockets
|
||||
async def on_long_summary(self, data):
|
||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.transaction() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"long_summary": final_long_summary.long_summary,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_LONG_SUMMARY",
|
||||
data=final_long_summary,
|
||||
@@ -308,15 +290,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
final_short_summary = TranscriptFinalShortSummary(
|
||||
short_summary=data.short_summary
|
||||
)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.transaction() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"short_summary": final_short_summary.short_summary,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_SHORT_SUMMARY",
|
||||
data=final_short_summary,
|
||||
@@ -324,29 +308,30 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_duration(self, data):
|
||||
async with self.transaction():
|
||||
async with self.transaction() as session:
|
||||
duration = TranscriptDuration(duration=data)
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"duration": duration.duration,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript, event="DURATION", data=duration
|
||||
session, transcript=transcript, event="DURATION", data=duration
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_waveform(self, data):
|
||||
async with self.transaction():
|
||||
async with self.transaction() as session:
|
||||
waveform = TranscriptWaveform(waveform=data)
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
transcript = await self.get_transcript(session)
|
||||
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript, event="WAVEFORM", data=waveform
|
||||
session, transcript=transcript, event="WAVEFORM", data=waveform
|
||||
)
|
||||
|
||||
|
||||
@@ -359,7 +344,8 @@ class PipelineMainLive(PipelineMainBase):
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
transcript = await self.get_transcript()
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
|
||||
processors = [
|
||||
AudioFileWriterProcessor(
|
||||
@@ -407,7 +393,8 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
||||
# now let's start the pipeline by pushing information to the
|
||||
# first processor diarization processor
|
||||
# XXX translation is lost when converting our data model to the processor model
|
||||
transcript = await self.get_transcript()
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
|
||||
# diarization works only if the file is uploaded to an external storage
|
||||
if transcript.audio_location == "local":
|
||||
@@ -440,7 +427,8 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
# get transcript
|
||||
self._transcript = transcript = await self.get_transcript()
|
||||
async with get_session_factory()() as session:
|
||||
self._transcript = transcript = await self.get_transcript(session)
|
||||
|
||||
# create pipeline
|
||||
processors = self.get_processors()
|
||||
@@ -545,8 +533,7 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
|
||||
logger.info("Convert to mp3 done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
|
||||
if not settings.TRANSCRIPT_STORAGE_BACKEND:
|
||||
logger.info("No storage backend configured, skipping mp3 upload")
|
||||
return
|
||||
@@ -564,7 +551,7 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
||||
return
|
||||
|
||||
# Upload to external storage and delete the file
|
||||
await transcripts_controller.move_mp3_to_storage(transcript)
|
||||
await transcripts_controller.move_mp3_to_storage(session, transcript)
|
||||
|
||||
logger.info("Upload mp3 done")
|
||||
|
||||
@@ -593,20 +580,23 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
|
||||
logger.info("Summaries done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
async def cleanup_consent(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting consent cleanup")
|
||||
|
||||
consent_denied = False
|
||||
recording = None
|
||||
try:
|
||||
if transcript.recording_id:
|
||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||
recording = await recordings_controller.get_by_id(
|
||||
session, transcript.recording_id
|
||||
)
|
||||
if recording and recording.meeting_id:
|
||||
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
||||
meeting = await meetings_controller.get_by_id(
|
||||
session, recording.meeting_id
|
||||
)
|
||||
if meeting:
|
||||
consent_denied = await meeting_consent_controller.has_any_denial(
|
||||
meeting.id
|
||||
session, meeting.id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
|
||||
@@ -635,7 +625,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
|
||||
|
||||
# non-transactional, files marked for deletion not actually deleted is possible
|
||||
await transcripts_controller.update(transcript, {"audio_deleted": True})
|
||||
await transcripts_controller.update(session, transcript, {"audio_deleted": True})
|
||||
# 2. Delete processed audio from transcript storage S3 bucket
|
||||
if transcript.audio_location == "storage":
|
||||
storage = get_transcripts_storage()
|
||||
@@ -659,15 +649,14 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
logger.info("Consent cleanup done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting post to zulip")
|
||||
|
||||
if not transcript.recording_id:
|
||||
logger.info("Transcript has no recording")
|
||||
return
|
||||
|
||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||
recording = await recordings_controller.get_by_id(session, transcript.recording_id)
|
||||
if not recording:
|
||||
logger.info("Recording not found")
|
||||
return
|
||||
@@ -676,12 +665,12 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
logger.info("Recording has no meeting")
|
||||
return
|
||||
|
||||
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
||||
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
|
||||
if not meeting:
|
||||
logger.info("No meeting found for this recording")
|
||||
return
|
||||
|
||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||
if not room:
|
||||
logger.error(f"Missing room for a meeting {meeting.id}")
|
||||
return
|
||||
@@ -707,7 +696,7 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
room.zulip_stream, room.zulip_topic, message
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
transcript, {"zulip_message_id": response["id"]}
|
||||
session, transcript, {"zulip_message_id": response["id"]}
|
||||
)
|
||||
|
||||
logger.info("Posted to zulip")
|
||||
@@ -738,8 +727,11 @@ async def task_pipeline_convert_to_mp3(*, transcript_id: str):
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_upload_mp3(*, transcript_id: str):
|
||||
await pipeline_upload_mp3(transcript_id=transcript_id)
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_upload_mp3(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_upload_mp3(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@@ -762,14 +754,20 @@ async def task_pipeline_final_summaries(*, transcript_id: str):
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_cleanup_consent(*, transcript_id: str):
|
||||
await cleanup_consent(transcript_id=transcript_id)
|
||||
@with_session_and_transcript
|
||||
async def task_cleanup_consent(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await cleanup_consent(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_post_to_zulip(*, transcript_id: str):
|
||||
await pipeline_post_to_zulip(transcript_id=transcript_id)
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_post_to_zulip(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
def pipeline_post(*, transcript_id: str):
|
||||
@@ -794,16 +792,18 @@ def pipeline_post(*, transcript_id: str):
|
||||
chain_final_summaries,
|
||||
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
|
||||
|
||||
chain.delay()
|
||||
return chain.delay()
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_process(transcript: Transcript, logger: Logger):
|
||||
try:
|
||||
if transcript.audio_location == "storage":
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.download_mp3_from_storage(transcript)
|
||||
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"topics": [],
|
||||
@@ -840,7 +840,9 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Pipeline error", exc_info=exc)
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"status": "error",
|
||||
|
||||
@@ -47,6 +47,7 @@ class FileDiarizationModalProcessor(FileDiarizationProcessor):
|
||||
"audio_file_url": data.audio_url,
|
||||
"timestamp": 0,
|
||||
},
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
diarization_data = response.json()["diarization"]
|
||||
|
||||
@@ -54,6 +54,7 @@ class FileTranscriptModalProcessor(FileTranscriptProcessor):
|
||||
"language": data.language,
|
||||
"batch": True,
|
||||
},
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
@@ -67,6 +68,9 @@ class FileTranscriptModalProcessor(FileTranscriptProcessor):
|
||||
for word_info in result.get("words", [])
|
||||
]
|
||||
|
||||
# words come not in order
|
||||
words.sort(key=lambda w: w.start)
|
||||
|
||||
return Transcript(words=words)
|
||||
|
||||
|
||||
|
||||
@@ -4,11 +4,8 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from profanityfilter import ProfanityFilter
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from reflector.redis_cache import redis_cache
|
||||
|
||||
|
||||
class DiarizationSegment(TypedDict):
|
||||
"""Type definition for diarization segment containing speaker information"""
|
||||
@@ -20,9 +17,6 @@ class DiarizationSegment(TypedDict):
|
||||
|
||||
PUNC_RE = re.compile(r"[.;:?!…]")
|
||||
|
||||
profanity_filter = ProfanityFilter()
|
||||
profanity_filter.set_censor("*")
|
||||
|
||||
|
||||
class AudioFile(BaseModel):
|
||||
name: str
|
||||
@@ -124,21 +118,11 @@ def words_to_segments(words: list[Word]) -> list[TranscriptSegment]:
|
||||
|
||||
class Transcript(BaseModel):
|
||||
translation: str | None = None
|
||||
words: list[Word] = None
|
||||
|
||||
@property
|
||||
def raw_text(self):
|
||||
# Uncensored text
|
||||
return "".join([word.text for word in self.words])
|
||||
|
||||
@redis_cache(prefix="profanity", duration=3600 * 24 * 7)
|
||||
def _get_censored_text(self, text: str):
|
||||
return profanity_filter.censor(text).strip()
|
||||
words: list[Word] = []
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
# Censored text
|
||||
return self._get_censored_text(self.raw_text)
|
||||
return "".join([word.text for word in self.words])
|
||||
|
||||
@property
|
||||
def human_timestamp(self):
|
||||
@@ -170,12 +154,6 @@ class Transcript(BaseModel):
|
||||
word.start += offset
|
||||
word.end += offset
|
||||
|
||||
def clone(self):
|
||||
words = [
|
||||
Word(text=word.text, start=word.start, end=word.end) for word in self.words
|
||||
]
|
||||
return Transcript(text=self.text, translation=self.translation, words=words)
|
||||
|
||||
def as_segments(self) -> list[TranscriptSegment]:
|
||||
return words_to_segments(self.words)
|
||||
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import redis
|
||||
import redis.asyncio as redis_async
|
||||
import structlog
|
||||
from redis.exceptions import LockError
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
redis_clients = {}
|
||||
|
||||
|
||||
@@ -21,6 +28,12 @@ def get_redis_client(db=0):
|
||||
return redis_clients[db]
|
||||
|
||||
|
||||
async def get_async_redis_client(db: int = 0):
|
||||
return await redis_async.from_url(
|
||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/{db}"
|
||||
)
|
||||
|
||||
|
||||
def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argidx=1):
|
||||
"""
|
||||
Cache the result of a function in Redis.
|
||||
@@ -49,3 +62,87 @@ def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argid
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RedisAsyncLock:
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
timeout: int = 120,
|
||||
extend_interval: int = 30,
|
||||
skip_if_locked: bool = False,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: Optional[float] = None,
|
||||
):
|
||||
self.key = f"async_lock:{key}"
|
||||
self.timeout = timeout
|
||||
self.extend_interval = extend_interval
|
||||
self.skip_if_locked = skip_if_locked
|
||||
self.blocking = blocking
|
||||
self.blocking_timeout = blocking_timeout
|
||||
self._lock = None
|
||||
self._redis = None
|
||||
self._extend_task = None
|
||||
self._acquired = False
|
||||
|
||||
async def _extend_lock_periodically(self):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.extend_interval)
|
||||
if self._lock:
|
||||
await self._lock.extend(self.timeout, replace_ttl=True)
|
||||
logger.debug("Extended lock", key=self.key)
|
||||
except LockError:
|
||||
logger.warning("Failed to extend lock", key=self.key)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error extending lock", key=self.key, error=str(e))
|
||||
break
|
||||
|
||||
async def __aenter__(self):
|
||||
self._redis = await get_async_redis_client()
|
||||
self._lock = self._redis.lock(
|
||||
self.key,
|
||||
timeout=self.timeout,
|
||||
blocking=self.blocking,
|
||||
blocking_timeout=self.blocking_timeout,
|
||||
)
|
||||
|
||||
self._acquired = await self._lock.acquire()
|
||||
|
||||
if not self._acquired:
|
||||
if self.skip_if_locked:
|
||||
logger.warning(
|
||||
"Lock already acquired by another process, skipping", key=self.key
|
||||
)
|
||||
return self
|
||||
else:
|
||||
raise LockError(f"Failed to acquire lock: {self.key}")
|
||||
|
||||
self._extend_task = asyncio.create_task(self._extend_lock_periodically())
|
||||
logger.info("Acquired lock", key=self.key)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._extend_task:
|
||||
self._extend_task.cancel()
|
||||
try:
|
||||
await self._extend_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._acquired and self._lock:
|
||||
try:
|
||||
await self._lock.release()
|
||||
logger.info("Released lock", key=self.key)
|
||||
except LockError:
|
||||
logger.debug("Lock already released or expired", key=self.key)
|
||||
|
||||
if self._redis:
|
||||
await self._redis.aclose()
|
||||
|
||||
@property
|
||||
def acquired(self) -> bool:
|
||||
return self._acquired
|
||||
|
||||
412
server/reflector/services/ics_sync.py
Normal file
412
server/reflector/services/ics_sync.py
Normal file
@@ -0,0 +1,412 @@
|
||||
"""
|
||||
ICS Calendar Synchronization Service
|
||||
|
||||
This module provides services for fetching, parsing, and synchronizing ICS (iCalendar)
|
||||
calendar feeds with room booking data in the database.
|
||||
|
||||
Key Components:
|
||||
- ICSFetchService: Handles HTTP fetching and parsing of ICS calendar data
|
||||
- ICSSyncService: Manages the synchronization process between ICS feeds and database
|
||||
|
||||
Example Usage:
|
||||
# Sync a room's calendar
|
||||
room = Room(id="room1", name="conference-room", ics_url="https://cal.example.com/room.ics")
|
||||
result = await ics_sync_service.sync_room_calendar(room)
|
||||
|
||||
# Result structure:
|
||||
{
|
||||
"status": "success", # success|unchanged|error|skipped
|
||||
"hash": "abc123...", # MD5 hash of ICS content
|
||||
"events_found": 5, # Events matching this room
|
||||
"total_events": 12, # Total events in calendar within time window
|
||||
"events_created": 2, # New events added to database
|
||||
"events_updated": 3, # Existing events modified
|
||||
"events_deleted": 1 # Events soft-deleted (no longer in calendar)
|
||||
}
|
||||
|
||||
Event Matching:
|
||||
Events are matched to rooms by checking if the room's full URL appears in the
|
||||
event's LOCATION or DESCRIPTION fields. Only events within a 25-hour window
|
||||
(1 hour ago to 24 hours from now) are processed.
|
||||
|
||||
Input: ICS calendar URL (e.g., "https://calendar.google.com/calendar/ical/...")
|
||||
Output: EventData objects with structured calendar information:
|
||||
{
|
||||
"ics_uid": "event123@google.com",
|
||||
"title": "Team Meeting",
|
||||
"description": "Weekly sync meeting",
|
||||
"location": "https://meet.company.com/conference-room",
|
||||
"start_time": datetime(2024, 1, 15, 14, 0, tzinfo=UTC),
|
||||
"end_time": datetime(2024, 1, 15, 15, 0, tzinfo=UTC),
|
||||
"attendees": [
|
||||
{"email": "user@company.com", "name": "John Doe", "role": "ORGANIZER"},
|
||||
{"email": "attendee@company.com", "name": "Jane Smith", "status": "ACCEPTED"}
|
||||
],
|
||||
"ics_raw_data": "BEGIN:VEVENT\nUID:event123@google.com\n..."
|
||||
}
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
import httpx
|
||||
import pytz
|
||||
import structlog
|
||||
from icalendar import Calendar, Event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
|
||||
from reflector.db.rooms import Room, rooms_controller
|
||||
from reflector.redis_cache import RedisAsyncLock
|
||||
from reflector.settings import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
EVENT_WINDOW_DELTA_START = timedelta(hours=-1)
|
||||
EVENT_WINDOW_DELTA_END = timedelta(hours=24)
|
||||
|
||||
|
||||
class SyncStatus(str, Enum):
|
||||
SUCCESS = "success"
|
||||
UNCHANGED = "unchanged"
|
||||
ERROR = "error"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
class AttendeeData(TypedDict, total=False):
|
||||
email: str | None
|
||||
name: str | None
|
||||
status: str | None
|
||||
role: str | None
|
||||
|
||||
|
||||
class EventData(TypedDict):
|
||||
ics_uid: str
|
||||
title: str | None
|
||||
description: str | None
|
||||
location: str | None
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
attendees: list[AttendeeData]
|
||||
ics_raw_data: str
|
||||
|
||||
|
||||
class SyncStats(TypedDict):
|
||||
events_created: int
|
||||
events_updated: int
|
||||
events_deleted: int
|
||||
|
||||
|
||||
class SyncResultBase(TypedDict):
|
||||
status: SyncStatus
|
||||
|
||||
|
||||
class SyncResult(SyncResultBase, total=False):
|
||||
hash: str | None
|
||||
events_found: int
|
||||
total_events: int
|
||||
events_created: int
|
||||
events_updated: int
|
||||
events_deleted: int
|
||||
error: str | None
|
||||
reason: str | None
|
||||
|
||||
|
||||
class ICSFetchService:
|
||||
def __init__(self):
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=30.0, headers={"User-Agent": "Reflector/1.0"}
|
||||
)
|
||||
|
||||
async def fetch_ics(self, url: str) -> str:
|
||||
response = await self.client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.text
|
||||
|
||||
def parse_ics(self, ics_content: str) -> Calendar:
|
||||
return Calendar.from_ical(ics_content)
|
||||
|
||||
def extract_room_events(
|
||||
self, calendar: Calendar, room_name: str, room_url: str
|
||||
) -> tuple[list[EventData], int]:
|
||||
events = []
|
||||
total_events = 0
|
||||
now = datetime.now(timezone.utc)
|
||||
window_start = now + EVENT_WINDOW_DELTA_START
|
||||
window_end = now + EVENT_WINDOW_DELTA_END
|
||||
|
||||
for component in calendar.walk():
|
||||
if component.name != "VEVENT":
|
||||
continue
|
||||
|
||||
status = component.get("STATUS", "").upper()
|
||||
if status == "CANCELLED":
|
||||
continue
|
||||
|
||||
# Count total non-cancelled events in the time window
|
||||
event_data = self._parse_event(component)
|
||||
if event_data and window_start <= event_data["start_time"] <= window_end:
|
||||
total_events += 1
|
||||
|
||||
# Check if event matches this room
|
||||
if self._event_matches_room(component, room_name, room_url):
|
||||
events.append(event_data)
|
||||
|
||||
return events, total_events
|
||||
|
||||
def _event_matches_room(self, event: Event, room_name: str, room_url: str) -> bool:
|
||||
location = str(event.get("LOCATION", ""))
|
||||
description = str(event.get("DESCRIPTION", ""))
|
||||
|
||||
# Only match full room URL
|
||||
# XXX leaved here as a patterns, to later be extended with tinyurl or such too
|
||||
patterns = [
|
||||
room_url,
|
||||
]
|
||||
|
||||
# Check location and description for patterns
|
||||
text_to_check = f"{location} {description}".lower()
|
||||
for pattern in patterns:
|
||||
if pattern.lower() in text_to_check:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _parse_event(self, event: Event) -> EventData | None:
|
||||
uid = str(event.get("UID", ""))
|
||||
summary = str(event.get("SUMMARY", ""))
|
||||
description = str(event.get("DESCRIPTION", ""))
|
||||
location = str(event.get("LOCATION", ""))
|
||||
dtstart = event.get("DTSTART")
|
||||
dtend = event.get("DTEND")
|
||||
|
||||
if not dtstart:
|
||||
return None
|
||||
|
||||
# Convert fields
|
||||
start_time = self._normalize_datetime(
|
||||
dtstart.dt if hasattr(dtstart, "dt") else dtstart
|
||||
)
|
||||
end_time = (
|
||||
self._normalize_datetime(dtend.dt if hasattr(dtend, "dt") else dtend)
|
||||
if dtend
|
||||
else start_time + timedelta(hours=1)
|
||||
)
|
||||
attendees = self._parse_attendees(event)
|
||||
|
||||
# Get raw event data for storage
|
||||
raw_data = event.to_ical().decode("utf-8")
|
||||
|
||||
return {
|
||||
"ics_uid": uid,
|
||||
"title": summary,
|
||||
"description": description,
|
||||
"location": location,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"attendees": attendees,
|
||||
"ics_raw_data": raw_data,
|
||||
}
|
||||
|
||||
def _normalize_datetime(self, dt) -> datetime:
|
||||
# Ensure datetime is with timezone, if not, assume UTC
|
||||
if isinstance(dt, date) and not isinstance(dt, datetime):
|
||||
dt = datetime.combine(dt, datetime.min.time())
|
||||
dt = pytz.UTC.localize(dt)
|
||||
elif isinstance(dt, datetime):
|
||||
if dt.tzinfo is None:
|
||||
dt = pytz.UTC.localize(dt)
|
||||
else:
|
||||
dt = dt.astimezone(pytz.UTC)
|
||||
|
||||
return dt
|
||||
|
||||
def _parse_attendees(self, event: Event) -> list[AttendeeData]:
|
||||
# Extracts attendee information from both ATTENDEE and ORGANIZER properties.
|
||||
# Handles malformed comma-separated email addresses in single ATTENDEE fields
|
||||
# by splitting them into separate attendee entries. Returns a list of attendee
|
||||
# data including email, name, status, and role information.
|
||||
final_attendees = []
|
||||
|
||||
attendees = event.get("ATTENDEE", [])
|
||||
if not isinstance(attendees, list):
|
||||
attendees = [attendees]
|
||||
for att in attendees:
|
||||
email_str = str(att).replace("mailto:", "") if att else None
|
||||
|
||||
# Handle malformed comma-separated email addresses in a single ATTENDEE field
|
||||
if email_str and "," in email_str:
|
||||
# Split comma-separated emails and create separate attendee entries
|
||||
email_parts = [email.strip() for email in email_str.split(",")]
|
||||
for email in email_parts:
|
||||
if email and "@" in email:
|
||||
clean_email = email.replace("MAILTO:", "").replace(
|
||||
"mailto:", ""
|
||||
)
|
||||
att_data: AttendeeData = {
|
||||
"email": clean_email,
|
||||
"name": att.params.get("CN")
|
||||
if hasattr(att, "params") and email == email_parts[0]
|
||||
else None,
|
||||
"status": att.params.get("PARTSTAT")
|
||||
if hasattr(att, "params") and email == email_parts[0]
|
||||
else None,
|
||||
"role": att.params.get("ROLE")
|
||||
if hasattr(att, "params") and email == email_parts[0]
|
||||
else None,
|
||||
}
|
||||
final_attendees.append(att_data)
|
||||
else:
|
||||
# Normal single attendee
|
||||
att_data: AttendeeData = {
|
||||
"email": email_str,
|
||||
"name": att.params.get("CN") if hasattr(att, "params") else None,
|
||||
"status": att.params.get("PARTSTAT")
|
||||
if hasattr(att, "params")
|
||||
else None,
|
||||
"role": att.params.get("ROLE") if hasattr(att, "params") else None,
|
||||
}
|
||||
final_attendees.append(att_data)
|
||||
|
||||
# Add organizer
|
||||
organizer = event.get("ORGANIZER")
|
||||
if organizer:
|
||||
org_email = (
|
||||
str(organizer).replace("mailto:", "").replace("MAILTO:", "")
|
||||
if organizer
|
||||
else None
|
||||
)
|
||||
org_data: AttendeeData = {
|
||||
"email": org_email,
|
||||
"name": organizer.params.get("CN")
|
||||
if hasattr(organizer, "params")
|
||||
else None,
|
||||
"role": "ORGANIZER",
|
||||
}
|
||||
final_attendees.append(org_data)
|
||||
|
||||
return final_attendees
|
||||
|
||||
|
||||
class ICSSyncService:
|
||||
def __init__(self):
|
||||
self.fetch_service = ICSFetchService()
|
||||
|
||||
async def sync_room_calendar(self, session: AsyncSession, room: Room) -> SyncResult:
|
||||
async with RedisAsyncLock(
|
||||
f"ics_sync_room:{room.id}", skip_if_locked=True
|
||||
) as lock:
|
||||
if not lock.acquired:
|
||||
logger.warning("ICS sync already in progress for room", room_id=room.id)
|
||||
return {
|
||||
"status": SyncStatus.SKIPPED,
|
||||
"reason": "Sync already in progress",
|
||||
}
|
||||
|
||||
return await self._sync_room_calendar(session, room)
|
||||
|
||||
async def _sync_room_calendar(
|
||||
self, session: AsyncSession, room: Room
|
||||
) -> SyncResult:
|
||||
if not room.ics_enabled or not room.ics_url:
|
||||
return {"status": SyncStatus.SKIPPED, "reason": "ICS not configured"}
|
||||
|
||||
try:
|
||||
if not self._should_sync(room):
|
||||
return {"status": SyncStatus.SKIPPED, "reason": "Not time to sync yet"}
|
||||
|
||||
ics_content = await self.fetch_service.fetch_ics(room.ics_url)
|
||||
calendar = self.fetch_service.parse_ics(ics_content)
|
||||
|
||||
content_hash = hashlib.md5(ics_content.encode()).hexdigest()
|
||||
if room.ics_last_etag == content_hash:
|
||||
logger.info("No changes in ICS for room", room_id=room.id)
|
||||
room_url = f"{settings.UI_BASE_URL}/{room.name}"
|
||||
events, total_events = self.fetch_service.extract_room_events(
|
||||
calendar, room.name, room_url
|
||||
)
|
||||
return {
|
||||
"status": SyncStatus.UNCHANGED,
|
||||
"hash": content_hash,
|
||||
"events_found": len(events),
|
||||
"total_events": total_events,
|
||||
"events_created": 0,
|
||||
"events_updated": 0,
|
||||
"events_deleted": 0,
|
||||
}
|
||||
|
||||
# Extract matching events
|
||||
room_url = f"{settings.UI_BASE_URL}/{room.name}"
|
||||
events, total_events = self.fetch_service.extract_room_events(
|
||||
calendar, room.name, room_url
|
||||
)
|
||||
sync_result = await self._sync_events_to_database(session, room.id, events)
|
||||
|
||||
# Update room sync metadata
|
||||
await rooms_controller.update(
|
||||
session,
|
||||
room,
|
||||
{
|
||||
"ics_last_sync": datetime.now(timezone.utc),
|
||||
"ics_last_etag": content_hash,
|
||||
},
|
||||
mutate=False,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": SyncStatus.SUCCESS,
|
||||
"hash": content_hash,
|
||||
"events_found": len(events),
|
||||
"total_events": total_events,
|
||||
**sync_result,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to sync ICS for room", room_id=room.id, error=str(e))
|
||||
return {"status": SyncStatus.ERROR, "error": str(e)}
|
||||
|
||||
def _should_sync(self, room: Room) -> bool:
|
||||
if not room.ics_last_sync:
|
||||
return True
|
||||
|
||||
time_since_sync = datetime.now(timezone.utc) - room.ics_last_sync
|
||||
return time_since_sync.total_seconds() >= room.ics_fetch_interval
|
||||
|
||||
async def _sync_events_to_database(
|
||||
self, session: AsyncSession, room_id: str, events: list[EventData]
|
||||
) -> SyncStats:
|
||||
created = 0
|
||||
updated = 0
|
||||
|
||||
current_ics_uids = []
|
||||
|
||||
for event_data in events:
|
||||
calendar_event = CalendarEvent(room_id=room_id, **event_data)
|
||||
existing = await calendar_events_controller.get_by_ics_uid(
|
||||
session, room_id, event_data["ics_uid"]
|
||||
)
|
||||
|
||||
if existing:
|
||||
updated += 1
|
||||
else:
|
||||
created += 1
|
||||
|
||||
await calendar_events_controller.upsert(session, calendar_event)
|
||||
current_ics_uids.append(event_data["ics_uid"])
|
||||
|
||||
# Soft delete events that are no longer in calendar
|
||||
deleted = await calendar_events_controller.soft_delete_missing(
|
||||
session, room_id, current_ics_uids
|
||||
)
|
||||
|
||||
return {
|
||||
"events_created": created,
|
||||
"events_updated": updated,
|
||||
"events_deleted": deleted,
|
||||
}
|
||||
|
||||
|
||||
ics_sync_service = ICSSyncService()
|
||||
@@ -1,5 +1,8 @@
|
||||
from pydantic.types import PositiveInt
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from reflector.utils.string import NonEmptyString
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
@@ -90,9 +93,8 @@ class Settings(BaseSettings):
|
||||
AUTH_JWT_PUBLIC_KEY: str | None = "authentik.monadical.com_public.pem"
|
||||
AUTH_JWT_AUDIENCE: str | None = None
|
||||
|
||||
# API public mode
|
||||
# if set, all anonymous record will be public
|
||||
PUBLIC_MODE: bool = False
|
||||
PUBLIC_DATA_RETENTION_DAYS: PositiveInt = 7
|
||||
|
||||
# Min transcript length to generate topic + summary
|
||||
MIN_TRANSCRIPT_LENGTH: int = 750
|
||||
@@ -120,7 +122,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# Whereby integration
|
||||
WHEREBY_API_URL: str = "https://api.whereby.dev/v1"
|
||||
WHEREBY_API_KEY: str | None = None
|
||||
WHEREBY_API_KEY: NonEmptyString | None = None
|
||||
WHEREBY_WEBHOOK_SECRET: str | None = None
|
||||
AWS_WHEREBY_ACCESS_KEY_ID: str | None = None
|
||||
AWS_WHEREBY_ACCESS_KEY_SECRET: str | None = None
|
||||
|
||||
72
server/reflector/tools/cleanup_old_data.py
Normal file
72
server/reflector/tools/cleanup_old_data.py
Normal file
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Manual cleanup tool for old public data.
|
||||
Uses the same implementation as the Celery worker task.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import structlog
|
||||
|
||||
from reflector.settings import settings
|
||||
from reflector.worker.cleanup import _cleanup_old_public_data
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
async def cleanup_old_data(days: int = 7):
|
||||
logger.info(
|
||||
"Starting manual cleanup",
|
||||
retention_days=days,
|
||||
public_mode=settings.PUBLIC_MODE,
|
||||
)
|
||||
|
||||
if not settings.PUBLIC_MODE:
|
||||
logger.critical(
|
||||
"WARNING: PUBLIC_MODE is False. "
|
||||
"This tool is intended for public instances only."
|
||||
)
|
||||
raise Exception("Tool intended for public instances only")
|
||||
|
||||
result = await _cleanup_old_public_data(days=days)
|
||||
|
||||
if result:
|
||||
logger.info(
|
||||
"Cleanup completed",
|
||||
transcripts_deleted=result.get("transcripts_deleted", 0),
|
||||
meetings_deleted=result.get("meetings_deleted", 0),
|
||||
recordings_deleted=result.get("recordings_deleted", 0),
|
||||
errors_count=len(result.get("errors", [])),
|
||||
)
|
||||
if result.get("errors"):
|
||||
logger.warning(
|
||||
"Errors encountered during cleanup:", errors=result["errors"][:10]
|
||||
)
|
||||
else:
|
||||
logger.info("Cleanup skipped or completed without results")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Clean up old transcripts and meetings"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--days",
|
||||
type=int,
|
||||
default=7,
|
||||
help="Number of days to keep data (default: 7)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.days < 1:
|
||||
logger.error("Days must be at least 1")
|
||||
sys.exit(1)
|
||||
|
||||
asyncio.run(cleanup_old_data(days=args.days))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -9,12 +9,12 @@ async def export_db(filename: str) -> None:
|
||||
filename = pathlib.Path(filename).resolve()
|
||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||
|
||||
from reflector.db import get_database, transcripts
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
transcripts = await database.fetch_all(transcripts.select())
|
||||
await database.disconnect()
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
transcripts = await transcripts_controller.get_all(session)
|
||||
|
||||
def export_transcript(transcript, output_dir):
|
||||
for topic in transcript.topics:
|
||||
|
||||
@@ -8,12 +8,12 @@ async def export_db(filename: str) -> None:
|
||||
filename = pathlib.Path(filename).resolve()
|
||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||
|
||||
from reflector.db import get_database, transcripts
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
transcripts = await database.fetch_all(transcripts.select())
|
||||
await database.disconnect()
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
transcripts = await transcripts_controller.get_all(session)
|
||||
|
||||
def export_transcript(transcript):
|
||||
tid = transcript.id
|
||||
|
||||
@@ -1,294 +1,210 @@
|
||||
"""
|
||||
Process audio file with diarization support
|
||||
===========================================
|
||||
|
||||
Extended version of process.py that includes speaker diarization.
|
||||
This tool processes audio files locally without requiring the full server infrastructure.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
import json
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, Literal
|
||||
|
||||
import av
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerAutoProcessor,
|
||||
AudioDownscaleProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
from reflector.pipelines.main_file_pipeline import (
|
||||
task_pipeline_file_process as task_pipeline_file_process,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor, Processor
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
TitleSummary,
|
||||
TitleSummaryWithId,
|
||||
from reflector.pipelines.main_live_pipeline import pipeline_post as live_pipeline_post
|
||||
from reflector.pipelines.main_live_pipeline import (
|
||||
pipeline_process as live_pipeline_process,
|
||||
)
|
||||
|
||||
|
||||
class TopicCollectorProcessor(Processor):
|
||||
"""Collect topics for diarization"""
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.topics: List[TitleSummaryWithId] = []
|
||||
self._topic_id = 0
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
# Convert to TitleSummaryWithId and collect
|
||||
self._topic_id += 1
|
||||
topic_with_id = TitleSummaryWithId(
|
||||
id=str(self._topic_id),
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
timestamp=data.timestamp,
|
||||
duration=data.duration,
|
||||
transcript=data.transcript,
|
||||
)
|
||||
self.topics.append(topic_with_id)
|
||||
|
||||
# Pass through the original topic
|
||||
await self.emit(data)
|
||||
|
||||
def get_topics(self) -> List[TitleSummaryWithId]:
|
||||
return self.topics
|
||||
def serialize_topics(topics: List[TranscriptTopic]) -> List[Dict[str, Any]]:
|
||||
"""Convert TranscriptTopic objects to JSON-serializable dicts"""
|
||||
serialized = []
|
||||
for topic in topics:
|
||||
topic_dict = topic.model_dump()
|
||||
serialized.append(topic_dict)
|
||||
return serialized
|
||||
|
||||
|
||||
async def process_audio_file(
|
||||
filename,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="pyannote",
|
||||
):
|
||||
# Create temp file for audio if diarization is enabled
|
||||
audio_temp_path = None
|
||||
if enable_diarization:
|
||||
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
audio_temp_path = audio_temp_file.name
|
||||
audio_temp_file.close()
|
||||
def debug_print_speakers(serialized_topics: List[Dict[str, Any]]) -> None:
|
||||
"""Print debug info about speakers found in topics"""
|
||||
all_speakers = set()
|
||||
for topic_dict in serialized_topics:
|
||||
for word in topic_dict.get("words", []):
|
||||
all_speakers.add(word.get("speaker", 0))
|
||||
|
||||
# Create processor for collecting topics
|
||||
topic_collector = TopicCollectorProcessor()
|
||||
|
||||
# Build pipeline for audio processing
|
||||
processors = []
|
||||
|
||||
# Add audio file writer at the beginning if diarization is enabled
|
||||
if enable_diarization:
|
||||
processors.append(AudioFileWriterProcessor(audio_temp_path))
|
||||
|
||||
# Add the rest of the processors
|
||||
processors += [
|
||||
AudioDownscaleProcessor(),
|
||||
AudioChunkerAutoProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||
]
|
||||
|
||||
if not only_transcript:
|
||||
processors += [
|
||||
TranscriptTopicDetectorProcessor.as_threaded(),
|
||||
# Collect topics for diarization
|
||||
topic_collector,
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(),
|
||||
TranscriptFinalSummaryProcessor.as_threaded(),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Create main pipeline
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.set_pref("audio:source_language", source_language)
|
||||
pipeline.set_pref("audio:target_language", target_language)
|
||||
pipeline.describe()
|
||||
pipeline.on(event_callback)
|
||||
|
||||
# Start processing audio
|
||||
logger.info(f"Opening {filename}")
|
||||
container = av.open(filename)
|
||||
try:
|
||||
logger.info("Start pushing audio into the pipeline")
|
||||
for frame in container.decode(audio=0):
|
||||
await pipeline.push(frame)
|
||||
finally:
|
||||
logger.info("Flushing the pipeline")
|
||||
await pipeline.flush()
|
||||
|
||||
# Run diarization if enabled and we have topics
|
||||
if enable_diarization and not only_transcript and audio_temp_path:
|
||||
topics = topic_collector.get_topics()
|
||||
|
||||
if topics:
|
||||
logger.info(f"Starting diarization with {len(topics)} topics")
|
||||
|
||||
try:
|
||||
from reflector.processors import AudioDiarizationAutoProcessor
|
||||
|
||||
diarization_processor = AudioDiarizationAutoProcessor(
|
||||
name=diarization_backend
|
||||
print(
|
||||
f"Found {len(serialized_topics)} topics with speakers: {all_speakers}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
diarization_processor.set_pipeline(pipeline)
|
||||
|
||||
# For Modal backend, we need to upload the file to S3 first
|
||||
if diarization_backend == "modal":
|
||||
from datetime import datetime
|
||||
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.utils.s3_temp_file import S3TemporaryFile
|
||||
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
# Generate a unique filename in evaluation folder
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
||||
|
||||
# Use context manager for automatic cleanup
|
||||
async with S3TemporaryFile(storage, audio_filename) as s3_file:
|
||||
# Read and upload the audio file
|
||||
with open(audio_temp_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
audio_url = await s3_file.upload(audio_data)
|
||||
logger.info(f"Uploaded audio to S3: {audio_filename}")
|
||||
|
||||
# Create diarization input with S3 URL
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
# File will be automatically cleaned up when exiting the context
|
||||
else:
|
||||
# For local backend, use local file path
|
||||
audio_url = audio_temp_path
|
||||
|
||||
# Create diarization input
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import diarization dependencies: {e}")
|
||||
logger.error(
|
||||
"Install with: uv pip install pyannote.audio torch torchaudio"
|
||||
)
|
||||
logger.error(
|
||||
"And set HF_TOKEN environment variable for pyannote models"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Diarization failed: {e}")
|
||||
raise SystemExit(1)
|
||||
else:
|
||||
logger.warning("Skipping diarization: no topics available")
|
||||
|
||||
# Clean up temp file
|
||||
if audio_temp_path:
|
||||
try:
|
||||
Path(audio_temp_path).unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
|
||||
|
||||
logger.info("All done!")
|
||||
TranscriptId = str
|
||||
|
||||
|
||||
async def process_file_pipeline(
|
||||
filename: str,
|
||||
event_callback,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
):
|
||||
"""Process audio/video file using the optimized file pipeline"""
|
||||
try:
|
||||
from reflector.db import database
|
||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||
from reflector.pipelines.main_file_pipeline import PipelineMainFile
|
||||
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
|
||||
# ideally we want to get rid of it at some point
|
||||
async def prepare_entry(
|
||||
session: AsyncSession,
|
||||
source_path: str,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> TranscriptId:
|
||||
file_path = Path(source_path)
|
||||
|
||||
await database.connect()
|
||||
try:
|
||||
# Create a temporary transcript for processing
|
||||
transcript = await transcripts_controller.add(
|
||||
"",
|
||||
session,
|
||||
file_path.name,
|
||||
# note that the real file upload has SourceKind: LIVE for the reason of it's an error
|
||||
source_kind=SourceKind.FILE,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Process the file
|
||||
pipeline = PipelineMainFile(transcript_id=transcript.id)
|
||||
await pipeline.process(Path(filename))
|
||||
logger.info(
|
||||
f"Created empty transcript {transcript.id} for file {file_path.name} because technically we need an empty transcript before we start transcript"
|
||||
)
|
||||
|
||||
# pipelines expect files as upload.*
|
||||
|
||||
extension = file_path.suffix
|
||||
upload_path = transcript.data_path / f"upload{extension}"
|
||||
upload_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(source_path, upload_path)
|
||||
logger.info(f"Copied {source_path} to {upload_path}")
|
||||
|
||||
# pipelines expect entity status "uploaded"
|
||||
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
||||
|
||||
return transcript.id
|
||||
|
||||
|
||||
# same reason as prepare_entry
|
||||
async def extract_result_from_entry(
|
||||
session: AsyncSession,
|
||||
transcript_id: TranscriptId,
|
||||
output_path: str,
|
||||
) -> None:
|
||||
post_final_transcript = await transcripts_controller.get_by_id(
|
||||
session, transcript_id
|
||||
)
|
||||
|
||||
# assert post_final_transcript.status == "ended"
|
||||
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
|
||||
topics = post_final_transcript.topics
|
||||
if not topics:
|
||||
raise RuntimeError(
|
||||
f"No topics found for transcript {transcript_id} after processing"
|
||||
)
|
||||
|
||||
serialized_topics = serialize_topics(topics)
|
||||
|
||||
if output_path:
|
||||
# Write to JSON file
|
||||
with open(output_path, "w") as f:
|
||||
for topic_dict in serialized_topics:
|
||||
json.dump(topic_dict, f)
|
||||
f.write("\n")
|
||||
print(f"Results written to {output_path}", file=sys.stderr)
|
||||
else:
|
||||
# Write to stdout as JSONL
|
||||
for topic_dict in serialized_topics:
|
||||
print(json.dumps(topic_dict))
|
||||
|
||||
debug_print_speakers(serialized_topics)
|
||||
|
||||
|
||||
async def process_live_pipeline(
|
||||
session: AsyncSession,
|
||||
transcript_id: TranscriptId,
|
||||
):
|
||||
"""Process transcript_id with transcription and diarization"""
|
||||
|
||||
print(f"Processing transcript_id {transcript_id}...", file=sys.stderr)
|
||||
await live_pipeline_process(transcript_id=transcript_id)
|
||||
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
|
||||
|
||||
pre_final_transcript = await transcripts_controller.get_by_id(
|
||||
session, transcript_id
|
||||
)
|
||||
|
||||
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
|
||||
assert pre_final_transcript.status != "ended"
|
||||
|
||||
# at this point, diarization is running but we have no access to it. run diarization in parallel - one will hopefully win after polling
|
||||
result = live_pipeline_post(transcript_id=transcript_id)
|
||||
|
||||
# result.ready() blocks even without await; it mutates result also
|
||||
while not result.ready():
|
||||
print(f"Status: {result.state}")
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
async def process_file_pipeline(
|
||||
transcript_id: TranscriptId,
|
||||
):
|
||||
"""Process audio/video file using the optimized file pipeline"""
|
||||
|
||||
# task_pipeline_file_process is a Celery task, need to use .delay() for async execution
|
||||
result = task_pipeline_file_process.delay(transcript_id=transcript_id)
|
||||
|
||||
# Wait for the Celery task to complete
|
||||
while not result.ready():
|
||||
print(f"File pipeline status: {result.state}", file=sys.stderr)
|
||||
time.sleep(2)
|
||||
|
||||
logger.info("File pipeline processing complete")
|
||||
|
||||
finally:
|
||||
await database.disconnect()
|
||||
except ImportError as e:
|
||||
logger.error(f"File pipeline not available: {e}")
|
||||
logger.info("Falling back to stream pipeline")
|
||||
# Fall back to stream pipeline
|
||||
await process_audio_file(
|
||||
filename,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
enable_diarization=enable_diarization,
|
||||
diarization_backend=diarization_backend,
|
||||
|
||||
async def process(
|
||||
source_path: str,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
pipeline: Literal["live", "file"],
|
||||
output_path: str = None,
|
||||
):
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
transcript_id = await prepare_entry(
|
||||
session,
|
||||
source_path,
|
||||
source_language,
|
||||
target_language,
|
||||
)
|
||||
|
||||
pipeline_handlers = {
|
||||
"live": lambda tid: process_live_pipeline(session, tid),
|
||||
"file": process_file_pipeline,
|
||||
}
|
||||
|
||||
handler = pipeline_handlers.get(pipeline)
|
||||
if not handler:
|
||||
raise ValueError(f"Unknown pipeline type: {pipeline}")
|
||||
|
||||
await handler(transcript_id)
|
||||
|
||||
await extract_result_from_entry(session, transcript_id, output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import os
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Process audio files with optional speaker diarization"
|
||||
description="Process audio files with speaker diarization"
|
||||
)
|
||||
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
||||
parser.add_argument(
|
||||
"--stream",
|
||||
action="store_true",
|
||||
help="Use streaming pipeline (original frame-based processing)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--only-transcript",
|
||||
"-t",
|
||||
action="store_true",
|
||||
help="Only generate transcript without topics/summaries",
|
||||
"--pipeline",
|
||||
required=True,
|
||||
choices=["live", "file"],
|
||||
help="Pipeline type to use for processing (live: streaming/incremental, file: batch/parallel)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-language", default="en", help="Source language code (default: en)"
|
||||
@@ -297,82 +213,14 @@ if __name__ == "__main__":
|
||||
"--target-language", default="en", help="Target language code (default: en)"
|
||||
)
|
||||
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
|
||||
parser.add_argument(
|
||||
"--enable-diarization",
|
||||
"-d",
|
||||
action="store_true",
|
||||
help="Enable speaker diarization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
default="pyannote",
|
||||
choices=["pyannote", "modal"],
|
||||
help="Diarization backend to use (default: pyannote)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if "REDIS_HOST" not in os.environ:
|
||||
os.environ["REDIS_HOST"] = "localhost"
|
||||
|
||||
output_fd = None
|
||||
if args.output:
|
||||
output_fd = open(args.output, "w")
|
||||
|
||||
async def event_callback(event: PipelineEvent):
|
||||
processor = event.processor
|
||||
data = event.data
|
||||
|
||||
# Ignore internal processors
|
||||
if processor in (
|
||||
"AudioDownscaleProcessor",
|
||||
"AudioChunkerAutoProcessor",
|
||||
"AudioMergeProcessor",
|
||||
"AudioFileWriterProcessor",
|
||||
"TopicCollectorProcessor",
|
||||
"BroadcastProcessor",
|
||||
):
|
||||
return
|
||||
|
||||
# If diarization is enabled, skip the original topic events from the pipeline
|
||||
# The diarization processor will emit the same topics but with speaker info
|
||||
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
|
||||
return
|
||||
|
||||
# Log all events
|
||||
logger.info(f"Event: {processor} - {type(data).__name__}")
|
||||
|
||||
# Write to output
|
||||
if output_fd:
|
||||
output_fd.write(event.model_dump_json())
|
||||
output_fd.write("\n")
|
||||
output_fd.flush()
|
||||
|
||||
if args.stream:
|
||||
# Use original streaming pipeline
|
||||
asyncio.run(
|
||||
process_audio_file(
|
||||
process(
|
||||
args.source,
|
||||
event_callback,
|
||||
only_transcript=args.only_transcript,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
enable_diarization=args.enable_diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
args.source_language,
|
||||
args.target_language,
|
||||
args.pipeline,
|
||||
args.output,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Use optimized file pipeline (default)
|
||||
asyncio.run(
|
||||
process_file_pipeline(
|
||||
args.source,
|
||||
event_callback,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
enable_diarization=args.enable_diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
)
|
||||
)
|
||||
|
||||
if output_fd:
|
||||
output_fd.close()
|
||||
logger.info(f"Output written to {args.output}")
|
||||
|
||||
@@ -1,318 +0,0 @@
|
||||
"""
|
||||
@vibe-generated
|
||||
Process audio file with diarization support
|
||||
===========================================
|
||||
|
||||
Extended version of process.py that includes speaker diarization.
|
||||
This tool processes audio files locally without requiring the full server infrastructure.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import av
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerAutoProcessor,
|
||||
AudioDownscaleProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor, Processor
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
TitleSummary,
|
||||
TitleSummaryWithId,
|
||||
)
|
||||
|
||||
|
||||
class TopicCollectorProcessor(Processor):
|
||||
"""Collect topics for diarization"""
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.topics: List[TitleSummaryWithId] = []
|
||||
self._topic_id = 0
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
# Convert to TitleSummaryWithId and collect
|
||||
self._topic_id += 1
|
||||
topic_with_id = TitleSummaryWithId(
|
||||
id=str(self._topic_id),
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
timestamp=data.timestamp,
|
||||
duration=data.duration,
|
||||
transcript=data.transcript,
|
||||
)
|
||||
self.topics.append(topic_with_id)
|
||||
|
||||
# Pass through the original topic
|
||||
await self.emit(data)
|
||||
|
||||
def get_topics(self) -> List[TitleSummaryWithId]:
|
||||
return self.topics
|
||||
|
||||
|
||||
async def process_audio_file_with_diarization(
|
||||
filename,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
):
|
||||
# Create temp file for audio if diarization is enabled
|
||||
audio_temp_path = None
|
||||
if enable_diarization:
|
||||
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
audio_temp_path = audio_temp_file.name
|
||||
audio_temp_file.close()
|
||||
|
||||
# Create processor for collecting topics
|
||||
topic_collector = TopicCollectorProcessor()
|
||||
|
||||
# Build pipeline for audio processing
|
||||
processors = []
|
||||
|
||||
# Add audio file writer at the beginning if diarization is enabled
|
||||
if enable_diarization:
|
||||
processors.append(AudioFileWriterProcessor(audio_temp_path))
|
||||
|
||||
# Add the rest of the processors
|
||||
processors += [
|
||||
AudioDownscaleProcessor(),
|
||||
AudioChunkerAutoProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
]
|
||||
|
||||
processors += [
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||
]
|
||||
|
||||
if not only_transcript:
|
||||
processors += [
|
||||
TranscriptTopicDetectorProcessor.as_threaded(),
|
||||
# Collect topics for diarization
|
||||
topic_collector,
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(),
|
||||
TranscriptFinalSummaryProcessor.as_threaded(),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Create main pipeline
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.set_pref("audio:source_language", source_language)
|
||||
pipeline.set_pref("audio:target_language", target_language)
|
||||
pipeline.describe()
|
||||
pipeline.on(event_callback)
|
||||
|
||||
# Start processing audio
|
||||
logger.info(f"Opening {filename}")
|
||||
container = av.open(filename)
|
||||
try:
|
||||
logger.info("Start pushing audio into the pipeline")
|
||||
for frame in container.decode(audio=0):
|
||||
await pipeline.push(frame)
|
||||
finally:
|
||||
logger.info("Flushing the pipeline")
|
||||
await pipeline.flush()
|
||||
|
||||
# Run diarization if enabled and we have topics
|
||||
if enable_diarization and not only_transcript and audio_temp_path:
|
||||
topics = topic_collector.get_topics()
|
||||
|
||||
if topics:
|
||||
logger.info(f"Starting diarization with {len(topics)} topics")
|
||||
|
||||
try:
|
||||
from reflector.processors import AudioDiarizationAutoProcessor
|
||||
|
||||
diarization_processor = AudioDiarizationAutoProcessor(
|
||||
name=diarization_backend
|
||||
)
|
||||
|
||||
diarization_processor.set_pipeline(pipeline)
|
||||
|
||||
# For Modal backend, we need to upload the file to S3 first
|
||||
if diarization_backend == "modal":
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.utils.s3_temp_file import S3TemporaryFile
|
||||
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
# Generate a unique filename in evaluation folder
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
||||
|
||||
# Use context manager for automatic cleanup
|
||||
async with S3TemporaryFile(storage, audio_filename) as s3_file:
|
||||
# Read and upload the audio file
|
||||
with open(audio_temp_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
audio_url = await s3_file.upload(audio_data)
|
||||
logger.info(f"Uploaded audio to S3: {audio_filename}")
|
||||
|
||||
# Create diarization input with S3 URL
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
# File will be automatically cleaned up when exiting the context
|
||||
else:
|
||||
# For local backend, use local file path
|
||||
audio_url = audio_temp_path
|
||||
|
||||
# Create diarization input
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import diarization dependencies: {e}")
|
||||
logger.error(
|
||||
"Install with: uv pip install pyannote.audio torch torchaudio"
|
||||
)
|
||||
logger.error(
|
||||
"And set HF_TOKEN environment variable for pyannote models"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Diarization failed: {e}")
|
||||
raise SystemExit(1)
|
||||
else:
|
||||
logger.warning("Skipping diarization: no topics available")
|
||||
|
||||
# Clean up temp file
|
||||
if audio_temp_path:
|
||||
try:
|
||||
Path(audio_temp_path).unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
|
||||
|
||||
logger.info("All done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import os
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Process audio files with optional speaker diarization"
|
||||
)
|
||||
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
||||
parser.add_argument(
|
||||
"--only-transcript",
|
||||
"-t",
|
||||
action="store_true",
|
||||
help="Only generate transcript without topics/summaries",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-language", default="en", help="Source language code (default: en)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-language", default="en", help="Target language code (default: en)"
|
||||
)
|
||||
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
|
||||
parser.add_argument(
|
||||
"--enable-diarization",
|
||||
"-d",
|
||||
action="store_true",
|
||||
help="Enable speaker diarization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
default="modal",
|
||||
choices=["modal"],
|
||||
help="Diarization backend to use (default: modal)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set REDIS_HOST to localhost if not provided
|
||||
if "REDIS_HOST" not in os.environ:
|
||||
os.environ["REDIS_HOST"] = "localhost"
|
||||
logger.info("REDIS_HOST not set, defaulting to localhost")
|
||||
|
||||
output_fd = None
|
||||
if args.output:
|
||||
output_fd = open(args.output, "w")
|
||||
|
||||
async def event_callback(event: PipelineEvent):
|
||||
processor = event.processor
|
||||
data = event.data
|
||||
|
||||
# Ignore internal processors
|
||||
if processor in (
|
||||
"AudioDownscaleProcessor",
|
||||
"AudioChunkerAutoProcessor",
|
||||
"AudioMergeProcessor",
|
||||
"AudioFileWriterProcessor",
|
||||
"TopicCollectorProcessor",
|
||||
"BroadcastProcessor",
|
||||
):
|
||||
return
|
||||
|
||||
# If diarization is enabled, skip the original topic events from the pipeline
|
||||
# The diarization processor will emit the same topics but with speaker info
|
||||
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
|
||||
return
|
||||
|
||||
# Log all events
|
||||
logger.info(f"Event: {processor} - {type(data).__name__}")
|
||||
|
||||
# Write to output
|
||||
if output_fd:
|
||||
output_fd.write(event.model_dump_json())
|
||||
output_fd.write("\n")
|
||||
output_fd.flush()
|
||||
|
||||
asyncio.run(
|
||||
process_audio_file_with_diarization(
|
||||
args.source,
|
||||
event_callback,
|
||||
only_transcript=args.only_transcript,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
enable_diarization=args.enable_diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
)
|
||||
)
|
||||
|
||||
if output_fd:
|
||||
output_fd.close()
|
||||
logger.info(f"Output written to {args.output}")
|
||||
@@ -1,96 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
@vibe-generated
|
||||
Test script for the diarization CLI tool
|
||||
=========================================
|
||||
|
||||
This script helps test the diarization functionality with sample audio files.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
|
||||
async def test_diarization(audio_file: str):
|
||||
"""Test the diarization functionality"""
|
||||
|
||||
# Import the processing function
|
||||
from process_with_diarization import process_audio_file_with_diarization
|
||||
|
||||
# Collect events
|
||||
events = []
|
||||
|
||||
async def event_callback(event):
|
||||
events.append({"processor": event.processor, "data": event.data})
|
||||
logger.info(f"Event from {event.processor}")
|
||||
|
||||
# Process the audio file
|
||||
logger.info(f"Processing audio file: {audio_file}")
|
||||
|
||||
try:
|
||||
await process_audio_file_with_diarization(
|
||||
audio_file,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
)
|
||||
|
||||
# Analyze results
|
||||
logger.info(f"Processing complete. Received {len(events)} events")
|
||||
|
||||
# Look for diarization results
|
||||
diarized_topics = []
|
||||
for event in events:
|
||||
if "TitleSummary" in event["processor"]:
|
||||
# Check if words have speaker information
|
||||
if hasattr(event["data"], "transcript") and event["data"].transcript:
|
||||
words = event["data"].transcript.words
|
||||
if words and hasattr(words[0], "speaker"):
|
||||
speakers = set(
|
||||
w.speaker for w in words if hasattr(w, "speaker")
|
||||
)
|
||||
logger.info(
|
||||
f"Found {len(speakers)} speakers in topic: {event['data'].title}"
|
||||
)
|
||||
diarized_topics.append(event["data"])
|
||||
|
||||
if diarized_topics:
|
||||
logger.info(f"Successfully diarized {len(diarized_topics)} topics")
|
||||
|
||||
# Print sample output
|
||||
sample_topic = diarized_topics[0]
|
||||
logger.info("Sample diarized output:")
|
||||
for i, word in enumerate(sample_topic.transcript.words[:10]):
|
||||
logger.info(f" Word {i}: '{word.text}' - Speaker {word.speaker}")
|
||||
else:
|
||||
logger.warning("No diarization results found in output")
|
||||
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during processing: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python test_diarization.py <audio_file>")
|
||||
sys.exit(1)
|
||||
|
||||
audio_file = sys.argv[1]
|
||||
if not Path(audio_file).exists():
|
||||
print(f"Error: Audio file '{audio_file}' not found")
|
||||
sys.exit(1)
|
||||
|
||||
# Run the test
|
||||
asyncio.run(test_diarization(audio_file))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
23
server/reflector/utils/string.py
Normal file
23
server/reflector/utils/string.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field, TypeAdapter, constr
|
||||
|
||||
NonEmptyStringBase = constr(min_length=1, strip_whitespace=False)
|
||||
NonEmptyString = Annotated[
|
||||
NonEmptyStringBase,
|
||||
Field(description="A non-empty string", min_length=1),
|
||||
]
|
||||
non_empty_string_adapter = TypeAdapter(NonEmptyString)
|
||||
|
||||
|
||||
def parse_non_empty_string(s: str, error: str | None = None) -> NonEmptyString:
|
||||
try:
|
||||
return non_empty_string_adapter.validate_python(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"{e}: {error}" if error else e) from e
|
||||
|
||||
|
||||
def try_parse_non_empty_string(s: str) -> NonEmptyString | None:
|
||||
if not s:
|
||||
return None
|
||||
return parse_non_empty_string(s)
|
||||
@@ -10,6 +10,7 @@ from reflector.db.meetings import (
|
||||
meeting_consent_controller,
|
||||
meetings_controller,
|
||||
)
|
||||
from reflector.db.rooms import rooms_controller
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -41,3 +42,34 @@ async def meeting_audio_consent(
|
||||
updated_consent = await meeting_consent_controller.upsert(consent)
|
||||
|
||||
return {"status": "success", "consent_id": updated_consent.id}
|
||||
|
||||
|
||||
@router.patch("/meetings/{meeting_id}/deactivate")
|
||||
async def meeting_deactivate(
|
||||
meeting_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
meeting = await meetings_controller.get_by_id(meeting_id)
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
if not meeting.is_active:
|
||||
return {"status": "success", "meeting_id": meeting_id}
|
||||
|
||||
# Only room owner or meeting creator can deactivate
|
||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
if user_id != room.user_id and user_id != meeting.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Only the room owner can deactivate meetings"
|
||||
)
|
||||
|
||||
await meetings_controller.update_meeting(meeting_id, is_active=False)
|
||||
|
||||
return {"status": "success", "meeting_id": meeting_id}
|
||||
|
||||
@@ -1,33 +1,28 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
|
||||
import asyncpg.exceptions
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.databases import apaginate
|
||||
from fastapi_pagination.ext.sqlalchemy import paginate
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import LockError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_database
|
||||
from reflector.db import get_session
|
||||
from reflector.db.calendar_events import calendar_events_controller
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.redis_cache import RedisAsyncLock
|
||||
from reflector.services.ics_sync import ics_sync_service
|
||||
from reflector.settings import settings
|
||||
from reflector.whereby import create_meeting, upload_logo
|
||||
from reflector.worker.webhook import test_webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def parse_datetime_with_timezone(iso_string: str) -> datetime:
|
||||
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
|
||||
dt = datetime.fromisoformat(iso_string)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
class Room(BaseModel):
|
||||
id: str
|
||||
@@ -42,16 +37,38 @@ class Room(BaseModel):
|
||||
recording_type: str
|
||||
recording_trigger: str
|
||||
is_shared: bool
|
||||
ics_url: Optional[str] = None
|
||||
ics_fetch_interval: int = 300
|
||||
ics_enabled: bool = False
|
||||
ics_last_sync: Optional[datetime] = None
|
||||
ics_last_etag: Optional[str] = None
|
||||
|
||||
|
||||
class RoomDetails(Room):
|
||||
webhook_url: str | None
|
||||
webhook_secret: str | None
|
||||
|
||||
|
||||
class Meeting(BaseModel):
|
||||
id: str
|
||||
room_name: str
|
||||
room_url: str
|
||||
# TODO it's not always present, | None
|
||||
host_room_url: str
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
is_locked: bool = False
|
||||
room_mode: Literal["normal", "group"] = "normal"
|
||||
recording_type: Literal["none", "local", "cloud"] = "cloud"
|
||||
recording_trigger: Literal[
|
||||
"none", "prompt", "automatic", "automatic-2nd-participant"
|
||||
] = "automatic-2nd-participant"
|
||||
num_clients: int = 0
|
||||
is_active: bool = True
|
||||
calendar_event_id: str | None = None
|
||||
calendar_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CreateRoom(BaseModel):
|
||||
@@ -64,49 +81,163 @@ class CreateRoom(BaseModel):
|
||||
recording_type: str
|
||||
recording_trigger: str
|
||||
is_shared: bool
|
||||
webhook_url: str
|
||||
webhook_secret: str
|
||||
ics_url: Optional[str] = None
|
||||
ics_fetch_interval: int = 300
|
||||
ics_enabled: bool = False
|
||||
|
||||
|
||||
class UpdateRoom(BaseModel):
|
||||
name: str
|
||||
zulip_auto_post: bool
|
||||
zulip_stream: str
|
||||
zulip_topic: str
|
||||
is_locked: bool
|
||||
room_mode: str
|
||||
recording_type: str
|
||||
recording_trigger: str
|
||||
is_shared: bool
|
||||
name: Optional[str] = None
|
||||
zulip_auto_post: Optional[bool] = None
|
||||
zulip_stream: Optional[str] = None
|
||||
zulip_topic: Optional[str] = None
|
||||
is_locked: Optional[bool] = None
|
||||
room_mode: Optional[str] = None
|
||||
recording_type: Optional[str] = None
|
||||
recording_trigger: Optional[str] = None
|
||||
is_shared: Optional[bool] = None
|
||||
webhook_url: Optional[str] = None
|
||||
webhook_secret: Optional[str] = None
|
||||
ics_url: Optional[str] = None
|
||||
ics_fetch_interval: Optional[int] = None
|
||||
ics_enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class CreateRoomMeeting(BaseModel):
|
||||
allow_duplicated: Optional[bool] = False
|
||||
|
||||
|
||||
class DeletionStatus(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
@router.get("/rooms", response_model=Page[Room])
|
||||
class WebhookTestResult(BaseModel):
|
||||
success: bool
|
||||
message: str = ""
|
||||
error: str = ""
|
||||
status_code: int | None = None
|
||||
response_preview: str | None = None
|
||||
|
||||
|
||||
class ICSStatus(BaseModel):
|
||||
status: Literal["enabled", "disabled"]
|
||||
last_sync: Optional[datetime] = None
|
||||
next_sync: Optional[datetime] = None
|
||||
last_etag: Optional[str] = None
|
||||
events_count: int = 0
|
||||
|
||||
|
||||
class SyncStatus(str, Enum):
|
||||
success = "success"
|
||||
unchanged = "unchanged"
|
||||
error = "error"
|
||||
skipped = "skipped"
|
||||
|
||||
|
||||
class ICSSyncResult(BaseModel):
|
||||
status: SyncStatus
|
||||
hash: Optional[str] = None
|
||||
events_found: int = 0
|
||||
total_events: int = 0
|
||||
events_created: int = 0
|
||||
events_updated: int = 0
|
||||
events_deleted: int = 0
|
||||
error: Optional[str] = None
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class CalendarEventResponse(BaseModel):
|
||||
id: str
|
||||
room_id: str
|
||||
ics_uid: str
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
attendees: Optional[list[dict]] = None
|
||||
location: Optional[str] = None
|
||||
last_synced: datetime
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def parse_datetime_with_timezone(iso_string: str) -> datetime:
|
||||
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
|
||||
dt = datetime.fromisoformat(iso_string)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
@router.get("/rooms", response_model=Page[RoomDetails])
|
||||
async def rooms_list(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
) -> list[Room]:
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[RoomDetails]:
|
||||
if not user and not settings.PUBLIC_MODE:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
return await apaginate(
|
||||
get_database(),
|
||||
await rooms_controller.get_all(
|
||||
user_id=user_id, order_by="-created_at", return_query=True
|
||||
),
|
||||
query = await rooms_controller.get_all(
|
||||
session, user_id=user_id, order_by="-created_at", return_query=True
|
||||
)
|
||||
return await paginate(session, query)
|
||||
|
||||
|
||||
@router.get("/rooms/{room_id}", response_model=RoomDetails)
|
||||
async def rooms_get(
|
||||
room_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
return room
|
||||
|
||||
|
||||
@router.get("/rooms/name/{room_name}", response_model=RoomDetails)
|
||||
async def rooms_get_by_name(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
# Convert to RoomDetails format (add webhook fields if user is owner)
|
||||
room_dict = room.__dict__.copy()
|
||||
if user_id == room.user_id:
|
||||
# User is owner, include webhook details if available
|
||||
room_dict["webhook_url"] = getattr(room, "webhook_url", None)
|
||||
room_dict["webhook_secret"] = getattr(room, "webhook_secret", None)
|
||||
else:
|
||||
# Non-owner, hide webhook details
|
||||
room_dict["webhook_url"] = None
|
||||
room_dict["webhook_secret"] = None
|
||||
|
||||
return RoomDetails(**room_dict)
|
||||
|
||||
|
||||
@router.post("/rooms", response_model=Room)
|
||||
async def rooms_create(
|
||||
room: CreateRoom,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
return await rooms_controller.add(
|
||||
session,
|
||||
name=room.name,
|
||||
user_id=user_id,
|
||||
zulip_auto_post=room.zulip_auto_post,
|
||||
@@ -117,21 +248,27 @@ async def rooms_create(
|
||||
recording_type=room.recording_type,
|
||||
recording_trigger=room.recording_trigger,
|
||||
is_shared=room.is_shared,
|
||||
webhook_url=room.webhook_url,
|
||||
webhook_secret=room.webhook_secret,
|
||||
ics_url=room.ics_url,
|
||||
ics_fetch_interval=room.ics_fetch_interval,
|
||||
ics_enabled=room.ics_enabled,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/rooms/{room_id}", response_model=Room)
|
||||
@router.patch("/rooms/{room_id}", response_model=RoomDetails)
|
||||
async def rooms_update(
|
||||
room_id: str,
|
||||
info: UpdateRoom,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
|
||||
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
values = info.dict(exclude_unset=True)
|
||||
await rooms_controller.update(room, values)
|
||||
await rooms_controller.update(session, room, values)
|
||||
return room
|
||||
|
||||
|
||||
@@ -139,73 +276,297 @@ async def rooms_update(
|
||||
async def rooms_delete(
|
||||
room_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_id(room_id, user_id=user_id)
|
||||
room = await rooms_controller.get_by_id(session, room_id, user_id=user_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
await rooms_controller.remove_by_id(room.id, user_id=user_id)
|
||||
await rooms_controller.remove_by_id(session, room.id, user_id=user_id)
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
|
||||
@router.post("/rooms/{room_name}/meeting", response_model=Meeting)
|
||||
async def rooms_create_meeting(
|
||||
room_name: str,
|
||||
info: CreateRoomMeeting,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
try:
|
||||
async with RedisAsyncLock(
|
||||
f"create_meeting:{room_name}",
|
||||
timeout=30,
|
||||
extend_interval=10,
|
||||
blocking_timeout=5.0,
|
||||
) as lock:
|
||||
current_time = datetime.now(timezone.utc)
|
||||
meeting = await meetings_controller.get_active(room=room, current_time=current_time)
|
||||
|
||||
meeting = None
|
||||
if not info.allow_duplicated:
|
||||
meeting = await meetings_controller.get_active(
|
||||
session, room=room, current_time=current_time
|
||||
)
|
||||
|
||||
if meeting is None:
|
||||
end_date = current_time + timedelta(hours=8)
|
||||
|
||||
whereby_meeting = await create_meeting("", end_date=end_date, room=room)
|
||||
|
||||
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
||||
|
||||
# Now try to save to database
|
||||
try:
|
||||
meeting = await meetings_controller.create(
|
||||
session,
|
||||
id=whereby_meeting["meetingId"],
|
||||
room_name=whereby_meeting["roomName"],
|
||||
room_url=whereby_meeting["roomUrl"],
|
||||
host_room_url=whereby_meeting["hostRoomUrl"],
|
||||
start_date=parse_datetime_with_timezone(whereby_meeting["startDate"]),
|
||||
start_date=parse_datetime_with_timezone(
|
||||
whereby_meeting["startDate"]
|
||||
),
|
||||
end_date=parse_datetime_with_timezone(whereby_meeting["endDate"]),
|
||||
user_id=user_id,
|
||||
room=room,
|
||||
)
|
||||
except (asyncpg.exceptions.UniqueViolationError, sqlite3.IntegrityError):
|
||||
# Another request already created a meeting for this room
|
||||
# Log this race condition occurrence
|
||||
logger.info(
|
||||
"Race condition detected for room %s - fetching existing meeting",
|
||||
room.name,
|
||||
)
|
||||
logger.warning(
|
||||
"Whereby meeting %s was created but not used (resource leak) for room %s",
|
||||
whereby_meeting["meetingId"],
|
||||
room.name,
|
||||
)
|
||||
|
||||
# Fetch the meeting that was created by the other request
|
||||
meeting = await meetings_controller.get_active(
|
||||
room=room, current_time=current_time
|
||||
)
|
||||
if meeting is None:
|
||||
# Edge case: meeting was created but expired/deleted between checks
|
||||
logger.error(
|
||||
"Meeting disappeared after race condition for room %s", room.name
|
||||
)
|
||||
except LockError:
|
||||
logger.warning("Failed to acquire lock for room %s within timeout", room_name)
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Unable to join meeting - please try again"
|
||||
status_code=503, detail="Meeting creation in progress, please try again"
|
||||
)
|
||||
|
||||
if user_id != room.user_id:
|
||||
meeting.host_room_url = ""
|
||||
|
||||
return meeting
|
||||
|
||||
|
||||
@router.post("/rooms/{room_id}/webhook/test", response_model=WebhookTestResult)
|
||||
async def rooms_test_webhook(
|
||||
room_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Test webhook configuration by sending a sample payload."""
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
room = await rooms_controller.get_by_id(session, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
if user_id and room.user_id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Not authorized to test this room's webhook"
|
||||
)
|
||||
|
||||
result = await test_webhook(room_id)
|
||||
return WebhookTestResult(**result)
|
||||
|
||||
|
||||
@router.post("/rooms/{room_name}/ics/sync", response_model=ICSSyncResult)
|
||||
async def rooms_sync_ics(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
if user_id != room.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Only room owner can trigger ICS sync"
|
||||
)
|
||||
|
||||
if not room.ics_enabled or not room.ics_url:
|
||||
raise HTTPException(status_code=400, detail="ICS not configured for this room")
|
||||
|
||||
result = await ics_sync_service.sync_room_calendar(session, room)
|
||||
|
||||
if result["status"] == "error":
|
||||
raise HTTPException(
|
||||
status_code=500, detail=result.get("error", "Unknown error")
|
||||
)
|
||||
|
||||
return ICSSyncResult(**result)
|
||||
|
||||
|
||||
@router.get("/rooms/{room_name}/ics/status", response_model=ICSStatus)
|
||||
async def rooms_ics_status(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
if user_id != room.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Only room owner can view ICS status"
|
||||
)
|
||||
|
||||
next_sync = None
|
||||
if room.ics_enabled and room.ics_last_sync:
|
||||
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
|
||||
|
||||
events = await calendar_events_controller.get_by_room(
|
||||
session, room.id, include_deleted=False
|
||||
)
|
||||
|
||||
return ICSStatus(
|
||||
status="enabled" if room.ics_enabled else "disabled",
|
||||
last_sync=room.ics_last_sync,
|
||||
next_sync=next_sync,
|
||||
last_etag=room.ics_last_etag,
|
||||
events_count=len(events),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/rooms/{room_name}/meetings", response_model=list[CalendarEventResponse])
|
||||
async def rooms_list_meetings(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
events = await calendar_events_controller.get_by_room(
|
||||
session, room.id, include_deleted=False
|
||||
)
|
||||
|
||||
if user_id != room.user_id:
|
||||
for event in events:
|
||||
event.description = None
|
||||
event.attendees = None
|
||||
|
||||
return events
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_name}/meetings/upcoming", response_model=list[CalendarEventResponse]
|
||||
)
|
||||
async def rooms_list_upcoming_meetings(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
minutes_ahead: int = 120,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
events = await calendar_events_controller.get_upcoming(
|
||||
session, room.id, minutes_ahead=minutes_ahead
|
||||
)
|
||||
|
||||
if user_id != room.user_id:
|
||||
for event in events:
|
||||
event.description = None
|
||||
event.attendees = None
|
||||
|
||||
return events
|
||||
|
||||
|
||||
@router.get("/rooms/{room_name}/meetings/active", response_model=list[Meeting])
|
||||
async def rooms_list_active_meetings(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
meetings = await meetings_controller.get_all_active_for_room(
|
||||
session, room=room, current_time=current_time
|
||||
)
|
||||
|
||||
# Hide host URLs from non-owners
|
||||
if user_id != room.user_id:
|
||||
for meeting in meetings:
|
||||
meeting.host_room_url = ""
|
||||
|
||||
return meetings
|
||||
|
||||
|
||||
@router.get("/rooms/{room_name}/meetings/{meeting_id}", response_model=Meeting)
|
||||
async def rooms_get_meeting(
|
||||
room_name: str,
|
||||
meeting_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get a single meeting by ID within a specific room."""
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
meeting = await meetings_controller.get_by_id(session, meeting_id)
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
if meeting.room_id != room.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Meeting does not belong to this room"
|
||||
)
|
||||
|
||||
if user_id != room.user_id and not room.is_shared:
|
||||
meeting.host_room_url = ""
|
||||
|
||||
return meeting
|
||||
|
||||
|
||||
@router.post("/rooms/{room_name}/meetings/{meeting_id}/join", response_model=Meeting)
|
||||
async def rooms_join_meeting(
|
||||
room_name: str,
|
||||
meeting_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
meeting = await meetings_controller.get_by_id(session, meeting_id)
|
||||
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
if meeting.room_id != room.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Meeting does not belong to this room"
|
||||
)
|
||||
|
||||
if not meeting.is_active:
|
||||
raise HTTPException(status_code=400, detail="Meeting is not active")
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
if meeting.end_date <= current_time:
|
||||
raise HTTPException(status_code=400, detail="Meeting has ended")
|
||||
|
||||
# Hide host URL from non-owners
|
||||
if user_id != room.user_id:
|
||||
meeting.host_room_url = ""
|
||||
|
||||
return meeting
|
||||
|
||||
@@ -3,12 +3,13 @@ from typing import Annotated, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.databases import apaginate
|
||||
from fastapi_pagination.ext.sqlalchemy import paginate
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field, field_serializer
|
||||
from pydantic import BaseModel, Field, constr, field_serializer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_database
|
||||
from reflector.db import get_session
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.search import (
|
||||
@@ -19,14 +20,15 @@ from reflector.db.search import (
|
||||
SearchOffsetBase,
|
||||
SearchParameters,
|
||||
SearchQuery,
|
||||
SearchQueryBase,
|
||||
SearchResult,
|
||||
SearchTotal,
|
||||
search_controller,
|
||||
search_query_adapter,
|
||||
)
|
||||
from reflector.db.transcripts import (
|
||||
SourceKind,
|
||||
TranscriptParticipant,
|
||||
TranscriptStatus,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
@@ -63,7 +65,7 @@ class GetTranscriptMinimal(BaseModel):
|
||||
id: str
|
||||
user_id: str | None
|
||||
name: str
|
||||
status: str
|
||||
status: TranscriptStatus
|
||||
locked: bool
|
||||
duration: float
|
||||
title: str | None
|
||||
@@ -96,6 +98,7 @@ class CreateTranscript(BaseModel):
|
||||
name: str
|
||||
source_language: str = Field("en")
|
||||
target_language: str = Field("en")
|
||||
source_kind: SourceKind | None = None
|
||||
|
||||
|
||||
class UpdateTranscript(BaseModel):
|
||||
@@ -114,7 +117,19 @@ class DeletionStatus(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
SearchQueryParam = Annotated[SearchQueryBase, Query(description="Search query text")]
|
||||
SearchQueryParamBase = constr(min_length=0, strip_whitespace=True)
|
||||
SearchQueryParam = Annotated[
|
||||
SearchQueryParamBase, Query(description="Search query text")
|
||||
]
|
||||
|
||||
|
||||
# http and api standards accept "q="; we would like to handle it as the absence of query, not as "empty string query"
|
||||
def parse_search_query_param(q: SearchQueryParam) -> SearchQuery | None:
|
||||
if q == "":
|
||||
return None
|
||||
return search_query_adapter.validate_python(q)
|
||||
|
||||
|
||||
SearchLimitParam = Annotated[SearchLimitBase, Query(description="Results per page")]
|
||||
SearchOffsetParam = Annotated[
|
||||
SearchOffsetBase, Query(description="Number of results to skip")
|
||||
@@ -124,7 +139,7 @@ SearchOffsetParam = Annotated[
|
||||
class SearchResponse(BaseModel):
|
||||
results: list[SearchResult]
|
||||
total: SearchTotal
|
||||
query: SearchQuery
|
||||
query: SearchQuery | None = None
|
||||
limit: SearchLimit
|
||||
offset: SearchOffset
|
||||
|
||||
@@ -135,24 +150,25 @@ async def transcripts_list(
|
||||
source_kind: SourceKind | None = None,
|
||||
room_id: str | None = None,
|
||||
search_term: str | None = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
if not user and not settings.PUBLIC_MODE:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
return await apaginate(
|
||||
get_database(),
|
||||
await transcripts_controller.get_all(
|
||||
query = await transcripts_controller.get_all(
|
||||
session,
|
||||
user_id=user_id,
|
||||
source_kind=SourceKind(source_kind) if source_kind else None,
|
||||
room_id=room_id,
|
||||
search_term=search_term,
|
||||
order_by="-created_at",
|
||||
return_query=True,
|
||||
),
|
||||
)
|
||||
|
||||
return await paginate(session, query)
|
||||
|
||||
|
||||
@router.get("/transcripts/search", response_model=SearchResponse)
|
||||
async def transcripts_search(
|
||||
@@ -164,6 +180,7 @@ async def transcripts_search(
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional)
|
||||
] = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""
|
||||
Full-text search across transcript titles and content.
|
||||
@@ -174,7 +191,7 @@ async def transcripts_search(
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
search_params = SearchParameters(
|
||||
query_text=q,
|
||||
query_text=parse_search_query_param(q),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
user_id=user_id,
|
||||
@@ -182,7 +199,7 @@ async def transcripts_search(
|
||||
source_kind=source_kind,
|
||||
)
|
||||
|
||||
results, total = await search_controller.search_transcripts(search_params)
|
||||
results, total = await search_controller.search_transcripts(session, search_params)
|
||||
|
||||
return SearchResponse(
|
||||
results=results,
|
||||
@@ -197,11 +214,13 @@ async def transcripts_search(
|
||||
async def transcripts_create(
|
||||
info: CreateTranscript,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
return await transcripts_controller.add(
|
||||
session,
|
||||
info.name,
|
||||
source_kind=SourceKind.LIVE,
|
||||
source_kind=info.source_kind or SourceKind.LIVE,
|
||||
source_language=info.source_language,
|
||||
target_language=info.target_language,
|
||||
user_id=user_id,
|
||||
@@ -319,10 +338,11 @@ class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
|
||||
async def transcript_get(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
return await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
@@ -331,15 +351,16 @@ async def transcript_update(
|
||||
transcript_id: str,
|
||||
info: UpdateTranscript,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
values = info.dict(exclude_unset=True)
|
||||
updated_transcript = await transcripts_controller.update(transcript, values)
|
||||
updated_transcript = await transcripts_controller.update(
|
||||
session, transcript, values
|
||||
)
|
||||
return updated_transcript
|
||||
|
||||
|
||||
@@ -347,19 +368,20 @@ async def transcript_update(
|
||||
async def transcript_delete(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
if transcript.meeting_id:
|
||||
meeting = await meetings_controller.get_by_id(transcript.meeting_id)
|
||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||
meeting = await meetings_controller.get_by_id(session, transcript.meeting_id)
|
||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||
if room.is_shared:
|
||||
user_id = None
|
||||
|
||||
await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
|
||||
await transcripts_controller.remove_by_id(session, transcript.id, user_id=user_id)
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
|
||||
@@ -370,10 +392,11 @@ async def transcript_delete(
|
||||
async def transcript_get_topics(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# convert to GetTranscriptTopic
|
||||
@@ -389,10 +412,11 @@ async def transcript_get_topics(
|
||||
async def transcript_get_topics_with_words(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# convert to GetTranscriptTopicWithWords
|
||||
@@ -410,10 +434,11 @@ async def transcript_get_topics_with_words_per_speaker(
|
||||
transcript_id: str,
|
||||
topic_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# get the topic from the transcript
|
||||
@@ -432,10 +457,11 @@ async def transcript_post_to_zulip(
|
||||
topic: str,
|
||||
include_topics: bool,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
@@ -455,5 +481,5 @@ async def transcript_post_to_zulip(
|
||||
if not message_updated:
|
||||
response = await send_message_to_zulip(stream, topic, content)
|
||||
await transcripts_controller.update(
|
||||
transcript, {"zulip_message_id": response["id"]}
|
||||
session, transcript, {"zulip_message_id": response["id"]}
|
||||
)
|
||||
|
||||
@@ -9,8 +9,10 @@ from typing import Annotated, Optional
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from jose import jwt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import AudioWaveform, transcripts_controller
|
||||
from reflector.settings import settings
|
||||
from reflector.views.transcripts import ALGORITHM
|
||||
@@ -32,6 +34,7 @@ async def transcript_get_audio_mp3(
|
||||
request: Request,
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
token: str | None = None,
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
@@ -48,7 +51,7 @@ async def transcript_get_audio_mp3(
|
||||
raise unauthorized_exception
|
||||
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.audio_location == "storage":
|
||||
@@ -86,7 +89,7 @@ async def transcript_get_audio_mp3(
|
||||
|
||||
return range_requests_response(
|
||||
request,
|
||||
transcript.audio_mp3_filename,
|
||||
transcript.audio_mp3_filename.as_posix(),
|
||||
content_type="audio/mpeg",
|
||||
content_disposition=f"attachment; filename={filename}",
|
||||
)
|
||||
@@ -96,13 +99,18 @@ async def transcript_get_audio_mp3(
|
||||
async def transcript_get_audio_waveform(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AudioWaveform:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not transcript.audio_waveform_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
|
||||
return transcript.audio_waveform
|
||||
audio_waveform = transcript.audio_waveform
|
||||
if not audio_waveform:
|
||||
raise HTTPException(status_code=404, detail="Audio waveform not found")
|
||||
|
||||
return audio_waveform
|
||||
|
||||
@@ -8,8 +8,10 @@ from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
|
||||
from reflector.views.types import DeletionStatus
|
||||
|
||||
@@ -37,10 +39,11 @@ class UpdateParticipant(BaseModel):
|
||||
async def transcript_get_participants(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Participant]:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.participants is None:
|
||||
@@ -57,10 +60,11 @@ async def transcript_add_participant(
|
||||
transcript_id: str,
|
||||
participant: CreateParticipant,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Participant:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# ensure the speaker is unique
|
||||
@@ -73,7 +77,7 @@ async def transcript_add_participant(
|
||||
)
|
||||
|
||||
obj = await transcripts_controller.upsert_participant(
|
||||
transcript, TranscriptParticipant(**participant.dict())
|
||||
session, transcript, TranscriptParticipant(**participant.dict())
|
||||
)
|
||||
return Participant.model_validate(obj)
|
||||
|
||||
@@ -83,10 +87,11 @@ async def transcript_get_participant(
|
||||
transcript_id: str,
|
||||
participant_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Participant:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
for p in transcript.participants:
|
||||
@@ -102,10 +107,11 @@ async def transcript_update_participant(
|
||||
participant_id: str,
|
||||
participant: UpdateParticipant,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Participant:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# ensure the speaker is unique
|
||||
@@ -130,7 +136,7 @@ async def transcript_update_participant(
|
||||
fields = participant.dict(exclude_unset=True)
|
||||
obj = obj.copy(update=fields)
|
||||
|
||||
await transcripts_controller.upsert_participant(transcript, obj)
|
||||
await transcripts_controller.upsert_participant(session, transcript, obj)
|
||||
return Participant.model_validate(obj)
|
||||
|
||||
|
||||
@@ -139,10 +145,11 @@ async def transcript_delete_participant(
|
||||
transcript_id: str,
|
||||
participant_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> DeletionStatus:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
await transcripts_controller.delete_participant(transcript, participant_id)
|
||||
await transcripts_controller.delete_participant(session, transcript, participant_id)
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
@@ -3,10 +3,12 @@ from typing import Annotated, Optional
|
||||
import celery
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.pipelines.main_live_pipeline import task_pipeline_process
|
||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -19,10 +21,11 @@ class ProcessStatus(BaseModel):
|
||||
async def transcript_process(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.locked:
|
||||
@@ -34,13 +37,13 @@ async def transcript_process(
|
||||
)
|
||||
|
||||
if task_is_scheduled_or_active(
|
||||
"reflector.pipelines.main_live_pipeline.task_pipeline_process",
|
||||
"reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
|
||||
transcript_id=transcript_id,
|
||||
):
|
||||
return ProcessStatus(status="already running")
|
||||
|
||||
# schedule a background task process the file
|
||||
task_pipeline_process.delay(transcript_id=transcript_id)
|
||||
task_pipeline_file_process.delay(transcript_id=transcript_id)
|
||||
|
||||
return ProcessStatus(status="ok")
|
||||
|
||||
|
||||
@@ -8,8 +8,10 @@ from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
router = APIRouter()
|
||||
@@ -36,10 +38,11 @@ async def transcript_assign_speaker(
|
||||
transcript_id: str,
|
||||
assignment: SpeakerAssignment,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> SpeakerAssignmentStatus:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not transcript:
|
||||
@@ -79,7 +82,9 @@ async def transcript_assign_speaker(
|
||||
# if the participant does not have a speaker, create one
|
||||
if participant.speaker is None:
|
||||
participant.speaker = transcript.find_empty_speaker()
|
||||
await transcripts_controller.upsert_participant(transcript, participant)
|
||||
await transcripts_controller.upsert_participant(
|
||||
session, transcript, participant
|
||||
)
|
||||
|
||||
speaker = participant.speaker
|
||||
|
||||
@@ -100,6 +105,7 @@ async def transcript_assign_speaker(
|
||||
for topic in changed_topics:
|
||||
transcript.upsert_topic(topic)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"topics": transcript.topics_dump(),
|
||||
@@ -114,10 +120,11 @@ async def transcript_merge_speaker(
|
||||
transcript_id: str,
|
||||
merge: SpeakerMerge,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> SpeakerAssignmentStatus:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not transcript:
|
||||
@@ -163,6 +170,7 @@ async def transcript_merge_speaker(
|
||||
for topic in changed_topics:
|
||||
transcript.upsert_topic(topic)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"topics": transcript.topics_dump(),
|
||||
|
||||
@@ -3,10 +3,12 @@ from typing import Annotated, Optional
|
||||
import av
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.pipelines.main_live_pipeline import task_pipeline_process
|
||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -22,10 +24,11 @@ async def transcript_record_upload(
|
||||
total_chunks: int,
|
||||
chunk: UploadFile,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.locked:
|
||||
@@ -89,9 +92,9 @@ async def transcript_record_upload(
|
||||
container.close()
|
||||
|
||||
# set the status to "uploaded"
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
||||
|
||||
# launch a background task to process the file
|
||||
task_pipeline_process.delay(transcript_id=transcript_id)
|
||||
task_pipeline_file_process.delay(transcript_id=transcript_id)
|
||||
|
||||
return UploadStatus(status="ok")
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||
@@ -16,10 +18,11 @@ async def transcript_record_webrtc(
|
||||
params: RtcOffer,
|
||||
request: Request,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.locked:
|
||||
|
||||
@@ -24,7 +24,7 @@ async def transcript_events_websocket(
|
||||
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
# user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
|
||||
@@ -68,8 +68,7 @@ async def whereby_webhook(event: WherebyWebhookEvent, request: Request):
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
if event.type in ["room.client.joined", "room.client.left"]:
|
||||
await meetings_controller.update_meeting(
|
||||
meeting.id, num_clients=event.data["numClients"]
|
||||
)
|
||||
update_data = {"num_clients": event.data["numClients"]}
|
||||
await meetings_controller.update_meeting(meeting.id, **update_data)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -1,18 +1,60 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from reflector.db.rooms import Room
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.string import parse_non_empty_string
|
||||
|
||||
HEADERS = {
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_headers():
|
||||
api_key = parse_non_empty_string(
|
||||
settings.WHEREBY_API_KEY, "WHEREBY_API_KEY value is required."
|
||||
)
|
||||
return {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {settings.WHEREBY_API_KEY}",
|
||||
}
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
|
||||
TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
def _get_whereby_s3_auth():
|
||||
errors = []
|
||||
try:
|
||||
bucket_name = parse_non_empty_string(
|
||||
settings.RECORDING_STORAGE_AWS_BUCKET_NAME,
|
||||
"RECORDING_STORAGE_AWS_BUCKET_NAME value is required.",
|
||||
)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
try:
|
||||
key_id = parse_non_empty_string(
|
||||
settings.AWS_WHEREBY_ACCESS_KEY_ID,
|
||||
"AWS_WHEREBY_ACCESS_KEY_ID value is required.",
|
||||
)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
try:
|
||||
key_secret = parse_non_empty_string(
|
||||
settings.AWS_WHEREBY_ACCESS_KEY_SECRET,
|
||||
"AWS_WHEREBY_ACCESS_KEY_SECRET value is required.",
|
||||
)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
if len(errors) > 0:
|
||||
raise Exception(
|
||||
f"Failed to get Whereby auth settings: {', '.join(str(e) for e in errors)}"
|
||||
)
|
||||
return bucket_name, key_id, key_secret
|
||||
|
||||
|
||||
async def create_meeting(room_name_prefix: str, end_date: datetime, room: Room):
|
||||
s3_bucket_name, s3_key_id, s3_key_secret = _get_whereby_s3_auth()
|
||||
data = {
|
||||
"isLocked": room.is_locked,
|
||||
"roomNamePrefix": room_name_prefix,
|
||||
@@ -23,23 +65,26 @@ async def create_meeting(room_name_prefix: str, end_date: datetime, room: Room):
|
||||
"type": room.recording_type,
|
||||
"destination": {
|
||||
"provider": "s3",
|
||||
"bucket": settings.RECORDING_STORAGE_AWS_BUCKET_NAME,
|
||||
"accessKeyId": settings.AWS_WHEREBY_ACCESS_KEY_ID,
|
||||
"accessKeySecret": settings.AWS_WHEREBY_ACCESS_KEY_SECRET,
|
||||
"bucket": s3_bucket_name,
|
||||
"accessKeyId": s3_key_id,
|
||||
"accessKeySecret": s3_key_secret,
|
||||
"fileFormat": "mp4",
|
||||
},
|
||||
"startTrigger": room.recording_trigger,
|
||||
},
|
||||
"fields": ["hostRoomUrl"],
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.WHEREBY_API_URL}/meetings",
|
||||
headers=HEADERS,
|
||||
headers=_get_headers(),
|
||||
json=data,
|
||||
timeout=TIMEOUT,
|
||||
)
|
||||
if response.status_code == 403:
|
||||
logger.warning(
|
||||
f"Failed to create meeting: access denied on Whereby: {response.text}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@@ -48,7 +93,7 @@ async def get_room_sessions(room_name: str):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.WHEREBY_API_URL}/insights/room-sessions?roomName={room_name}",
|
||||
headers=HEADERS,
|
||||
headers=_get_headers(),
|
||||
timeout=TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -19,6 +19,8 @@ else:
|
||||
"reflector.pipelines.main_live_pipeline",
|
||||
"reflector.worker.healthcheck",
|
||||
"reflector.worker.process",
|
||||
"reflector.worker.cleanup",
|
||||
"reflector.worker.ics_sync",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -36,8 +38,26 @@ else:
|
||||
"task": "reflector.worker.process.reprocess_failed_recordings",
|
||||
"schedule": crontab(hour=5, minute=0), # Midnight EST
|
||||
},
|
||||
"sync_all_ics_calendars": {
|
||||
"task": "reflector.worker.ics_sync.sync_all_ics_calendars",
|
||||
"schedule": 60.0, # Run every minute to check which rooms need sync
|
||||
},
|
||||
"create_upcoming_meetings": {
|
||||
"task": "reflector.worker.ics_sync.create_upcoming_meetings",
|
||||
"schedule": 30.0, # Run every 30 seconds to create upcoming meetings
|
||||
},
|
||||
}
|
||||
|
||||
if settings.PUBLIC_MODE:
|
||||
app.conf.beat_schedule["cleanup_old_public_data"] = {
|
||||
"task": "reflector.worker.cleanup.cleanup_old_public_data_task",
|
||||
"schedule": crontab(hour=3, minute=0),
|
||||
}
|
||||
logger.info(
|
||||
"Public mode cleanup enabled",
|
||||
retention_days=settings.PUBLIC_DATA_RETENTION_DAYS,
|
||||
)
|
||||
|
||||
if settings.HEALTHCHECK_URL:
|
||||
app.conf.beat_schedule["healthcheck_ping"] = {
|
||||
"task": "reflector.worker.healthcheck.healthcheck_ping",
|
||||
|
||||
166
server/reflector/worker/cleanup.py
Normal file
166
server/reflector/worker/cleanup.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Main task for cleanup old public data.
|
||||
|
||||
Deletes old anonymous transcripts and their associated meetings/recordings.
|
||||
Transcripts are the main entry point - any associated data is also removed.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TypedDict
|
||||
|
||||
import structlog
|
||||
from celery import shared_task
|
||||
from pydantic.types import PositiveInt
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_recordings_storage
|
||||
from reflector.worker.session_decorator import with_session
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class CleanupStats(TypedDict):
|
||||
"""Statistics for cleanup operation."""
|
||||
|
||||
transcripts_deleted: int
|
||||
meetings_deleted: int
|
||||
recordings_deleted: int
|
||||
errors: list[str]
|
||||
|
||||
|
||||
async def delete_single_transcript(
|
||||
session: AsyncSession, transcript_data: dict, stats: CleanupStats
|
||||
):
|
||||
transcript_id = transcript_data["id"]
|
||||
meeting_id = transcript_data["meeting_id"]
|
||||
recording_id = transcript_data["recording_id"]
|
||||
|
||||
try:
|
||||
if meeting_id:
|
||||
await session.execute(
|
||||
delete(MeetingModel).where(MeetingModel.id == meeting_id)
|
||||
)
|
||||
stats["meetings_deleted"] += 1
|
||||
logger.info("Deleted associated meeting", meeting_id=meeting_id)
|
||||
|
||||
if recording_id:
|
||||
result = await session.execute(
|
||||
select(RecordingModel).where(RecordingModel.id == recording_id)
|
||||
)
|
||||
recording = result.mappings().first()
|
||||
if recording:
|
||||
try:
|
||||
await get_recordings_storage().delete_file(recording["object_key"])
|
||||
except Exception as storage_error:
|
||||
logger.warning(
|
||||
"Failed to delete recording from storage",
|
||||
recording_id=recording_id,
|
||||
object_key=recording["object_key"],
|
||||
error=str(storage_error),
|
||||
)
|
||||
|
||||
await session.execute(
|
||||
delete(RecordingModel).where(RecordingModel.id == recording_id)
|
||||
)
|
||||
stats["recordings_deleted"] += 1
|
||||
logger.info("Deleted associated recording", recording_id=recording_id)
|
||||
|
||||
await transcripts_controller.remove_by_id(session, transcript_id)
|
||||
stats["transcripts_deleted"] += 1
|
||||
logger.info(
|
||||
"Deleted transcript",
|
||||
transcript_id=transcript_id,
|
||||
created_at=transcript_data["created_at"].isoformat(),
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to delete transcript {transcript_id}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=e)
|
||||
stats["errors"].append(error_msg)
|
||||
|
||||
|
||||
async def cleanup_old_transcripts(
|
||||
session: AsyncSession, cutoff_date: datetime, stats: CleanupStats
|
||||
):
|
||||
"""Delete old anonymous transcripts and their associated recordings/meetings."""
|
||||
query = select(
|
||||
TranscriptModel.id,
|
||||
TranscriptModel.meeting_id,
|
||||
TranscriptModel.recording_id,
|
||||
TranscriptModel.created_at,
|
||||
).where(
|
||||
(TranscriptModel.created_at < cutoff_date) & (TranscriptModel.user_id.is_(None))
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
old_transcripts = result.mappings().all()
|
||||
|
||||
logger.info(f"Found {len(old_transcripts)} old transcripts to delete")
|
||||
|
||||
for transcript_data in old_transcripts:
|
||||
try:
|
||||
await delete_single_transcript(session, transcript_data, stats)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to delete transcript {transcript_data['id']}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=e)
|
||||
stats["errors"].append(error_msg)
|
||||
|
||||
|
||||
def log_cleanup_results(stats: CleanupStats):
|
||||
logger.info(
|
||||
"Cleanup completed",
|
||||
transcripts_deleted=stats["transcripts_deleted"],
|
||||
meetings_deleted=stats["meetings_deleted"],
|
||||
recordings_deleted=stats["recordings_deleted"],
|
||||
errors_count=len(stats["errors"]),
|
||||
)
|
||||
|
||||
if stats["errors"]:
|
||||
logger.warning(
|
||||
"Cleanup completed with errors",
|
||||
errors=stats["errors"][:10],
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_old_public_data(
|
||||
session: AsyncSession,
|
||||
days: PositiveInt | None = None,
|
||||
) -> CleanupStats | None:
|
||||
if days is None:
|
||||
days = settings.PUBLIC_DATA_RETENTION_DAYS
|
||||
|
||||
if not settings.PUBLIC_MODE:
|
||||
logger.info("Skipping cleanup - not a public instance")
|
||||
return None
|
||||
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
logger.info(
|
||||
"Starting cleanup of old public data",
|
||||
cutoff_date=cutoff_date.isoformat(),
|
||||
)
|
||||
|
||||
stats: CleanupStats = {
|
||||
"transcripts_deleted": 0,
|
||||
"meetings_deleted": 0,
|
||||
"recordings_deleted": 0,
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
await cleanup_old_transcripts(session, cutoff_date, stats)
|
||||
|
||||
log_cleanup_results(stats)
|
||||
return stats
|
||||
|
||||
|
||||
@shared_task(
|
||||
autoretry_for=(Exception,),
|
||||
retry_kwargs={"max_retries": 3, "countdown": 300},
|
||||
)
|
||||
@asynctask
|
||||
@with_session
|
||||
async def cleanup_old_public_data_task(session: AsyncSession, days: int | None = None):
|
||||
await cleanup_old_public_data(session, days=days)
|
||||
186
server/reflector/worker/ics_sync.py
Normal file
186
server/reflector/worker/ics_sync.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import structlog
|
||||
from celery import shared_task
|
||||
from celery.utils.log import get_task_logger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db.calendar_events import calendar_events_controller
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.redis_cache import RedisAsyncLock
|
||||
from reflector.services.ics_sync import SyncStatus, ics_sync_service
|
||||
from reflector.whereby import create_meeting, upload_logo
|
||||
from reflector.worker.session_decorator import with_session
|
||||
|
||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
@with_session
|
||||
async def sync_room_ics(session: AsyncSession, room_id: str):
|
||||
try:
|
||||
room = await rooms_controller.get_by_id(session, room_id)
|
||||
if not room:
|
||||
logger.warning("Room not found for ICS sync", room_id=room_id)
|
||||
return
|
||||
|
||||
if not room.ics_enabled or not room.ics_url:
|
||||
logger.debug("ICS not enabled for room", room_id=room_id)
|
||||
return
|
||||
|
||||
logger.info("Starting ICS sync for room", room_id=room_id, room_name=room.name)
|
||||
result = await ics_sync_service.sync_room_calendar(session, room)
|
||||
|
||||
if result["status"] == SyncStatus.SUCCESS:
|
||||
logger.info(
|
||||
"ICS sync completed successfully",
|
||||
room_id=room_id,
|
||||
events_found=result.get("events_found", 0),
|
||||
events_created=result.get("events_created", 0),
|
||||
events_updated=result.get("events_updated", 0),
|
||||
events_deleted=result.get("events_deleted", 0),
|
||||
)
|
||||
elif result["status"] == SyncStatus.UNCHANGED:
|
||||
logger.debug("ICS content unchanged", room_id=room_id)
|
||||
elif result["status"] == SyncStatus.ERROR:
|
||||
logger.error("ICS sync failed", room_id=room_id, error=result.get("error"))
|
||||
else:
|
||||
logger.debug(
|
||||
"ICS sync skipped", room_id=room_id, reason=result.get("reason")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error during ICS sync", room_id=room_id, error=str(e))
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
@with_session
|
||||
async def sync_all_ics_calendars(session: AsyncSession):
|
||||
try:
|
||||
logger.info("Starting sync for all ICS-enabled rooms")
|
||||
|
||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
||||
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
|
||||
|
||||
for room in ics_enabled_rooms:
|
||||
if not _should_sync(room):
|
||||
logger.debug("Skipping room, not time to sync yet", room_id=room.id)
|
||||
continue
|
||||
|
||||
sync_room_ics.delay(room.id)
|
||||
|
||||
logger.info("Queued sync tasks for all eligible rooms")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in sync_all_ics_calendars", error=str(e))
|
||||
|
||||
|
||||
def _should_sync(room) -> bool:
|
||||
if not room.ics_last_sync:
|
||||
return True
|
||||
|
||||
time_since_sync = datetime.now(timezone.utc) - room.ics_last_sync
|
||||
return time_since_sync.total_seconds() >= room.ics_fetch_interval
|
||||
|
||||
|
||||
MEETING_DEFAULT_DURATION = timedelta(hours=1)
|
||||
|
||||
|
||||
async def create_upcoming_meetings_for_event(
|
||||
session: AsyncSession, event, create_window, room_id, room
|
||||
):
|
||||
if event.start_time <= create_window:
|
||||
return
|
||||
existing_meeting = await meetings_controller.get_by_calendar_event(
|
||||
session, event.id
|
||||
)
|
||||
|
||||
if existing_meeting:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Pre-creating meeting for calendar event",
|
||||
room_id=room_id,
|
||||
event_id=event.id,
|
||||
event_title=event.title,
|
||||
)
|
||||
|
||||
try:
|
||||
end_date = event.end_time or (event.start_time + MEETING_DEFAULT_DURATION)
|
||||
|
||||
whereby_meeting = await create_meeting(
|
||||
"",
|
||||
end_date=end_date,
|
||||
room=room,
|
||||
)
|
||||
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
||||
|
||||
meeting = await meetings_controller.create(
|
||||
session,
|
||||
id=whereby_meeting["meetingId"],
|
||||
room_name=whereby_meeting["roomName"],
|
||||
room_url=whereby_meeting["roomUrl"],
|
||||
host_room_url=whereby_meeting["hostRoomUrl"],
|
||||
start_date=datetime.fromisoformat(whereby_meeting["startDate"]),
|
||||
end_date=datetime.fromisoformat(whereby_meeting["endDate"]),
|
||||
room=room,
|
||||
calendar_event_id=event.id,
|
||||
calendar_metadata={
|
||||
"title": event.title,
|
||||
"description": event.description,
|
||||
"attendees": event.attendees,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Meeting pre-created successfully",
|
||||
meeting_id=meeting.id,
|
||||
event_id=event.id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to pre-create meeting",
|
||||
room_id=room_id,
|
||||
event_id=event.id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
@with_session
|
||||
async def create_upcoming_meetings(session: AsyncSession):
|
||||
async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
|
||||
if not lock.acquired:
|
||||
logger.warning(
|
||||
"Another worker is already creating upcoming meetings, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("Starting creation of upcoming meetings")
|
||||
|
||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
||||
now = datetime.now(timezone.utc)
|
||||
create_window = now - timedelta(minutes=6)
|
||||
|
||||
for room in ics_enabled_rooms:
|
||||
events = await calendar_events_controller.get_upcoming(
|
||||
session,
|
||||
room.id,
|
||||
minutes_ahead=7,
|
||||
)
|
||||
|
||||
for event in events:
|
||||
await create_upcoming_meetings_for_event(
|
||||
session, event, create_window, room.id, room
|
||||
)
|
||||
logger.info("Completed pre-creation check for upcoming meetings")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in create_upcoming_meetings", error=str(e))
|
||||
@@ -9,6 +9,8 @@ import structlog
|
||||
from celery import shared_task
|
||||
from celery.utils.log import get_task_logger
|
||||
from pydantic import ValidationError
|
||||
from redis.exceptions import LockError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.recordings import Recording, recordings_controller
|
||||
@@ -16,8 +18,10 @@ from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
from reflector.pipelines.main_live_pipeline import asynctask
|
||||
from reflector.redis_cache import get_redis_client
|
||||
from reflector.settings import settings
|
||||
from reflector.whereby import get_room_sessions
|
||||
from reflector.worker.session_decorator import with_session
|
||||
|
||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||
|
||||
@@ -73,30 +77,39 @@ def process_messages():
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def process_recording(bucket_name: str, object_key: str):
|
||||
@with_session
|
||||
async def process_recording(session: AsyncSession, bucket_name: str, object_key: str):
|
||||
logger.info("Processing recording: %s/%s", bucket_name, object_key)
|
||||
|
||||
# extract a guid and a datetime from the object key
|
||||
room_name = f"/{object_key[:36]}"
|
||||
recorded_at = parse_datetime_with_timezone(object_key[37:57])
|
||||
|
||||
meeting = await meetings_controller.get_by_room_name(room_name)
|
||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||
meeting = await meetings_controller.get_by_room_name(session, room_name)
|
||||
if not meeting:
|
||||
logger.warning("Room not found, may be deleted ?", room_name=room_name)
|
||||
return
|
||||
|
||||
recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
|
||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||
|
||||
recording = await recordings_controller.get_by_object_key(
|
||||
session, bucket_name, object_key
|
||||
)
|
||||
if not recording:
|
||||
recording = await recordings_controller.create(
|
||||
session,
|
||||
Recording(
|
||||
bucket_name=bucket_name,
|
||||
object_key=object_key,
|
||||
recorded_at=recorded_at,
|
||||
meeting_id=meeting.id,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
transcript = await transcripts_controller.get_by_recording_id(recording.id)
|
||||
transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
|
||||
if transcript:
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"topics": [],
|
||||
@@ -104,6 +117,7 @@ async def process_recording(bucket_name: str, object_key: str):
|
||||
)
|
||||
else:
|
||||
transcript = await transcripts_controller.add(
|
||||
session,
|
||||
"",
|
||||
source_kind=SourceKind.ROOM,
|
||||
source_language="en",
|
||||
@@ -139,37 +153,110 @@ async def process_recording(bucket_name: str, object_key: str):
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
||||
|
||||
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def process_meetings():
|
||||
@with_session
|
||||
async def process_meetings(session: AsyncSession):
|
||||
"""
|
||||
Checks which meetings are still active and deactivates those that have ended.
|
||||
|
||||
Deactivation logic:
|
||||
- Active sessions: Keep meeting active regardless of scheduled time
|
||||
- No active sessions:
|
||||
* Calendar meetings:
|
||||
- If previously used (had sessions): Deactivate immediately
|
||||
- If never used: Keep active until scheduled end time, then deactivate
|
||||
* On-the-fly meetings: Deactivate immediately (created when someone joins,
|
||||
so no sessions means everyone left)
|
||||
|
||||
Uses distributed locking to prevent race conditions when multiple workers
|
||||
process the same meeting simultaneously.
|
||||
"""
|
||||
logger.info("Processing meetings")
|
||||
meetings = await meetings_controller.get_all_active()
|
||||
meetings = await meetings_controller.get_all_active(session)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
redis_client = get_redis_client()
|
||||
processed_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for meeting in meetings:
|
||||
is_active = False
|
||||
logger_ = logger.bind(meeting_id=meeting.id, room_name=meeting.room_name)
|
||||
lock_key = f"meeting_process_lock:{meeting.id}"
|
||||
lock = redis_client.lock(lock_key, timeout=120)
|
||||
|
||||
try:
|
||||
if not lock.acquire(blocking=False):
|
||||
logger_.debug("Meeting is being processed by another worker, skipping")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# Process the meeting
|
||||
should_deactivate = False
|
||||
end_date = meeting.end_date
|
||||
if end_date.tzinfo is None:
|
||||
end_date = end_date.replace(tzinfo=timezone.utc)
|
||||
if end_date > datetime.now(timezone.utc):
|
||||
|
||||
# This API call could be slow, extend lock if needed
|
||||
response = await get_room_sessions(meeting.room_name)
|
||||
|
||||
try:
|
||||
# Extend lock after slow operation to ensure we still hold it
|
||||
lock.extend(120, replace_ttl=True)
|
||||
except LockError:
|
||||
logger_.warning("Lost lock for meeting, skipping")
|
||||
continue
|
||||
|
||||
room_sessions = response.get("results", [])
|
||||
is_active = not room_sessions or any(
|
||||
has_active_sessions = room_sessions and any(
|
||||
rs["endedAt"] is None for rs in room_sessions
|
||||
)
|
||||
if not is_active:
|
||||
await meetings_controller.update_meeting(meeting.id, is_active=False)
|
||||
logger.info("Meeting %s is deactivated", meeting.id)
|
||||
has_had_sessions = bool(room_sessions)
|
||||
|
||||
logger.info("Processed meetings")
|
||||
if has_active_sessions:
|
||||
logger_.debug("Meeting still has active sessions, keep it")
|
||||
elif has_had_sessions:
|
||||
should_deactivate = True
|
||||
logger_.info("Meeting ended - all participants left")
|
||||
elif current_time > end_date:
|
||||
should_deactivate = True
|
||||
logger_.info(
|
||||
"Meeting deactivated - scheduled time ended with no participants",
|
||||
)
|
||||
else:
|
||||
logger_.debug("Meeting not yet started, keep it")
|
||||
|
||||
if should_deactivate:
|
||||
await meetings_controller.update_meeting(
|
||||
session, meeting.id, is_active=False
|
||||
)
|
||||
logger_.info("Meeting is deactivated")
|
||||
|
||||
processed_count += 1
|
||||
|
||||
except Exception:
|
||||
logger_.error("Error processing meeting", exc_info=True)
|
||||
finally:
|
||||
try:
|
||||
lock.release()
|
||||
except LockError:
|
||||
pass # Lock already released or expired
|
||||
|
||||
logger.info(
|
||||
"Processed meetings finished",
|
||||
processed_count=processed_count,
|
||||
skipped_count=skipped_count,
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def reprocess_failed_recordings():
|
||||
@with_session
|
||||
async def reprocess_failed_recordings(session: AsyncSession):
|
||||
"""
|
||||
Find recordings in the S3 bucket and check if they have proper transcriptions.
|
||||
If not, requeue them for processing.
|
||||
@@ -200,7 +287,7 @@ async def reprocess_failed_recordings():
|
||||
continue
|
||||
|
||||
recording = await recordings_controller.get_by_object_key(
|
||||
bucket_name, object_key
|
||||
session, bucket_name, object_key
|
||||
)
|
||||
if not recording:
|
||||
logger.info(f"Queueing recording for processing: {object_key}")
|
||||
@@ -211,10 +298,12 @@ async def reprocess_failed_recordings():
|
||||
transcript = None
|
||||
try:
|
||||
transcript = await transcripts_controller.get_by_recording_id(
|
||||
recording.id
|
||||
session, recording.id
|
||||
)
|
||||
except ValidationError:
|
||||
await transcripts_controller.remove_by_recording_id(recording.id)
|
||||
await transcripts_controller.remove_by_recording_id(
|
||||
session, recording.id
|
||||
)
|
||||
logger.warning(
|
||||
f"Removed invalid transcript for recording: {recording.id}"
|
||||
)
|
||||
|
||||
109
server/reflector/worker/session_decorator.py
Normal file
109
server/reflector/worker/session_decorator.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Session management decorator for async worker tasks.
|
||||
|
||||
This decorator ensures that all worker tasks have a properly managed database session
|
||||
that stays open for the entire duration of the task execution.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
from celery import current_task
|
||||
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.logger import logger
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def with_session(func: F) -> F:
|
||||
"""
|
||||
Decorator that provides an AsyncSession as the first argument to the decorated function.
|
||||
|
||||
This should be used AFTER the @asynctask decorator on Celery tasks to ensure
|
||||
proper session management throughout the task execution.
|
||||
|
||||
Example:
|
||||
@shared_task
|
||||
@asynctask
|
||||
@with_session
|
||||
async def my_task(session: AsyncSession, arg1: str, arg2: int):
|
||||
# session is automatically provided and managed
|
||||
result = await some_controller.get_by_id(session, arg1)
|
||||
...
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
# Pass session as first argument to the decorated function
|
||||
return await func(session, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def with_session_and_transcript(func: F) -> F:
|
||||
"""
|
||||
Decorator that provides both an AsyncSession and a Transcript to the decorated function.
|
||||
|
||||
This decorator:
|
||||
1. Extracts transcript_id from kwargs
|
||||
2. Creates and manages a database session
|
||||
3. Fetches the transcript using the session
|
||||
4. Creates an enhanced logger with Celery task context
|
||||
5. Passes session, transcript, and logger to the decorated function
|
||||
|
||||
This should be used AFTER the @asynctask decorator on Celery tasks.
|
||||
|
||||
Example:
|
||||
@shared_task
|
||||
@asynctask
|
||||
@with_session_and_transcript
|
||||
async def my_task(session: AsyncSession, transcript: Transcript, logger: Logger, arg1: str):
|
||||
# session, transcript, and logger are automatically provided
|
||||
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
||||
...
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
transcript_id = kwargs.pop("transcript_id", None)
|
||||
if not transcript_id:
|
||||
raise ValueError(
|
||||
"transcript_id is required for @with_session_and_transcript"
|
||||
)
|
||||
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
# Fetch the transcript
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, transcript_id
|
||||
)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
|
||||
# Create enhanced logger with Celery task context
|
||||
tlogger = logger.bind(transcript_id=transcript.id)
|
||||
if current_task:
|
||||
tlogger = tlogger.bind(
|
||||
task_id=current_task.request.id,
|
||||
task_name=current_task.name,
|
||||
worker_hostname=current_task.request.hostname,
|
||||
task_retries=current_task.request.retries,
|
||||
transcript_id=transcript_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Pass session, transcript, and logger to the decorated function
|
||||
return await func(
|
||||
session, transcript=transcript, logger=tlogger, *args, **kwargs
|
||||
)
|
||||
except Exception:
|
||||
tlogger.exception("Error in task execution")
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
262
server/reflector/worker/webhook.py
Normal file
262
server/reflector/worker/webhook.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""Webhook task for sending transcript notifications."""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
from celery import shared_task
|
||||
from celery.utils.log import get_task_logger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.pipelines.main_live_pipeline import asynctask
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.webvtt import topics_to_webvtt
|
||||
from reflector.worker.session_decorator import with_session
|
||||
|
||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||
|
||||
|
||||
def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> str:
|
||||
"""Generate HMAC signature for webhook payload."""
|
||||
signed_payload = f"{timestamp}.{payload.decode('utf-8')}"
|
||||
hmac_obj = hmac.new(
|
||||
secret.encode("utf-8"),
|
||||
signed_payload.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
)
|
||||
return hmac_obj.hexdigest()
|
||||
|
||||
|
||||
@shared_task(
|
||||
bind=True,
|
||||
max_retries=30,
|
||||
default_retry_delay=60,
|
||||
retry_backoff=True,
|
||||
retry_backoff_max=3600, # Max 1 hour between retries
|
||||
)
|
||||
@asynctask
|
||||
@with_session
|
||||
async def send_transcript_webhook(
|
||||
self,
|
||||
transcript_id: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
session: AsyncSession,
|
||||
):
|
||||
log = logger.bind(
|
||||
transcript_id=transcript_id,
|
||||
room_id=room_id,
|
||||
retry_count=self.request.retries,
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch transcript and room
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
log.error("Transcript not found, skipping webhook")
|
||||
return
|
||||
|
||||
room = await rooms_controller.get_by_id(session, room_id)
|
||||
if not room:
|
||||
log.error("Room not found, skipping webhook")
|
||||
return
|
||||
|
||||
if not room.webhook_url:
|
||||
log.info("No webhook URL configured for room, skipping")
|
||||
return
|
||||
|
||||
# Generate WebVTT content from topics
|
||||
topics_data = []
|
||||
|
||||
if transcript.topics:
|
||||
# Build topics data with diarized content per topic
|
||||
for topic in transcript.topics:
|
||||
topic_webvtt = topics_to_webvtt([topic]) if topic.words else ""
|
||||
topics_data.append(
|
||||
{
|
||||
"title": topic.title,
|
||||
"summary": topic.summary,
|
||||
"timestamp": topic.timestamp,
|
||||
"duration": topic.duration,
|
||||
"webvtt": topic_webvtt,
|
||||
}
|
||||
)
|
||||
|
||||
# Build webhook payload
|
||||
frontend_url = f"{settings.UI_BASE_URL}/transcripts/{transcript.id}"
|
||||
participants = [
|
||||
{"id": p.id, "name": p.name, "speaker": p.speaker}
|
||||
for p in (transcript.participants or [])
|
||||
]
|
||||
payload_data = {
|
||||
"event": "transcript.completed",
|
||||
"event_id": event_id,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"transcript": {
|
||||
"id": transcript.id,
|
||||
"room_id": transcript.room_id,
|
||||
"created_at": transcript.created_at.isoformat(),
|
||||
"duration": transcript.duration,
|
||||
"title": transcript.title,
|
||||
"short_summary": transcript.short_summary,
|
||||
"long_summary": transcript.long_summary,
|
||||
"webvtt": transcript.webvtt,
|
||||
"topics": topics_data,
|
||||
"participants": participants,
|
||||
"source_language": transcript.source_language,
|
||||
"target_language": transcript.target_language,
|
||||
"status": transcript.status,
|
||||
"frontend_url": frontend_url,
|
||||
},
|
||||
"room": {
|
||||
"id": room.id,
|
||||
"name": room.name,
|
||||
},
|
||||
}
|
||||
|
||||
# Convert to JSON
|
||||
payload_json = json.dumps(payload_data, separators=(",", ":"))
|
||||
payload_bytes = payload_json.encode("utf-8")
|
||||
|
||||
# Generate signature if secret is configured
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "Reflector-Webhook/1.0",
|
||||
"X-Webhook-Event": "transcript.completed",
|
||||
"X-Webhook-Retry": str(self.request.retries),
|
||||
}
|
||||
|
||||
if room.webhook_secret:
|
||||
timestamp = str(int(datetime.now(timezone.utc).timestamp()))
|
||||
signature = generate_webhook_signature(
|
||||
payload_bytes, room.webhook_secret, timestamp
|
||||
)
|
||||
headers["X-Webhook-Signature"] = f"t={timestamp},v1={signature}"
|
||||
|
||||
# Send webhook with timeout
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
log.info(
|
||||
"Sending webhook",
|
||||
url=room.webhook_url,
|
||||
payload_size=len(payload_bytes),
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
room.webhook_url,
|
||||
content=payload_bytes,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
log.info(
|
||||
"Webhook sent successfully",
|
||||
status_code=response.status_code,
|
||||
response_size=len(response.content),
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
log.error(
|
||||
"Webhook failed with HTTP error",
|
||||
status_code=e.response.status_code,
|
||||
response_text=e.response.text[:500], # First 500 chars
|
||||
)
|
||||
|
||||
# Don't retry on client errors (4xx)
|
||||
if 400 <= e.response.status_code < 500:
|
||||
log.error("Client error, not retrying")
|
||||
return
|
||||
|
||||
# Retry on server errors (5xx)
|
||||
raise self.retry(exc=e)
|
||||
|
||||
except (httpx.ConnectError, httpx.TimeoutException) as e:
|
||||
# Retry on network errors
|
||||
log.error("Webhook failed with connection error", error=str(e))
|
||||
raise self.retry(exc=e)
|
||||
|
||||
except Exception as e:
|
||||
# Retry on unexpected errors
|
||||
log.exception("Unexpected error in webhook task", error=str(e))
|
||||
raise self.retry(exc=e)
|
||||
|
||||
|
||||
async def test_webhook(room_id: str) -> dict:
|
||||
"""
|
||||
Test webhook configuration by sending a sample payload.
|
||||
Returns immediately with success/failure status.
|
||||
This is the shared implementation used by both the API endpoint and Celery task.
|
||||
"""
|
||||
try:
|
||||
room = await rooms_controller.get_by_id(room_id)
|
||||
if not room:
|
||||
return {"success": False, "error": "Room not found"}
|
||||
|
||||
if not room.webhook_url:
|
||||
return {"success": False, "error": "No webhook URL configured"}
|
||||
|
||||
now = (datetime.now(timezone.utc).isoformat(),)
|
||||
payload_data = {
|
||||
"event": "test",
|
||||
"event_id": uuid.uuid4().hex,
|
||||
"timestamp": now,
|
||||
"message": "This is a test webhook from Reflector",
|
||||
"room": {
|
||||
"id": room.id,
|
||||
"name": room.name,
|
||||
},
|
||||
}
|
||||
|
||||
payload_json = json.dumps(payload_data, separators=(",", ":"))
|
||||
payload_bytes = payload_json.encode("utf-8")
|
||||
|
||||
# Generate headers with signature
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "Reflector-Webhook/1.0",
|
||||
"X-Webhook-Event": "test",
|
||||
}
|
||||
|
||||
if room.webhook_secret:
|
||||
timestamp = str(int(datetime.now(timezone.utc).timestamp()))
|
||||
signature = generate_webhook_signature(
|
||||
payload_bytes, room.webhook_secret, timestamp
|
||||
)
|
||||
headers["X-Webhook-Signature"] = f"t={timestamp},v1={signature}"
|
||||
|
||||
# Send test webhook with short timeout
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.post(
|
||||
room.webhook_url,
|
||||
content=payload_bytes,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": response.is_success,
|
||||
"status_code": response.status_code,
|
||||
"message": f"Webhook test {'successful' if response.is_success else 'failed'}",
|
||||
"response_preview": response.text if response.text else None,
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Webhook request timed out (10 seconds)",
|
||||
}
|
||||
except httpx.ConnectError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Could not connect to webhook URL: {str(e)}",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unexpected error: {str(e)}",
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user