mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 267b7401ea | |||
| aea9de393c | |||
| dc177af3ff | |||
| 5bd8233657 | |||
| 28ac031ff6 | |||
| 1878834ce6 | |||
| f5b82d44e3 | |||
| ad56165b54 | |||
| 4ee19ed015 | |||
| 406164033d | |||
| 81d316cb56 | |||
| db3beae5cd | |||
|
|
03b9a18c1b | ||
|
|
7e3027adb6 | ||
|
|
27b43d85ab | ||
| 2289a1a231 | |||
| d0e130eb13 | |||
| 24fabe3e86 | |||
| 6fedbbe63f | |||
| b39175cdc9 | |||
| 2a2af5fff2 | |||
| ad44492cae | |||
| 901a239952 | |||
| d77b5611f8 | |||
| fc38345d65 | |||
| 5a1d662dc4 | |||
| 033bd4bc48 | |||
| 0eb670ca19 |
30
.github/pull_request_template.md
vendored
30
.github/pull_request_template.md
vendored
@@ -1,19 +1,21 @@
|
||||
## ⚠️ Insert the PR TITLE replacing this text ⚠️
|
||||
<!--- Provide a general summary of your changes in the Title above -->
|
||||
|
||||
⚠️ Describe your PR replacing this text. Post screenshots or videos whenever possible. ⚠️
|
||||
## Description
|
||||
<!--- Describe your changes in detail -->
|
||||
|
||||
### Checklist
|
||||
## Related Issue
|
||||
<!--- This project only accepts pull requests related to open issues -->
|
||||
<!--- If suggesting a new feature or change, please discuss it in an issue first -->
|
||||
<!--- If fixing a bug, there should be an issue describing it with steps to reproduce -->
|
||||
<!--- Please link to the issue here: -->
|
||||
|
||||
- [ ] My branch is updated with main (mandatory)
|
||||
- [ ] I wrote unit tests for this (if applies)
|
||||
- [ ] I have included migrations and tested them locally (if applies)
|
||||
- [ ] I have manually tested this feature locally
|
||||
## Motivation and Context
|
||||
<!--- Why is this change required? What problem does it solve? -->
|
||||
<!--- If it fixes an open issue, please link to the issue here. -->
|
||||
|
||||
> IMPORTANT: Remember that you are responsible for merging this PR after it's been reviewed, and once deployed
|
||||
> you should perform manual testing to make sure everything went smoothly.
|
||||
|
||||
### Urgency
|
||||
|
||||
- [ ] Urgent (deploy ASAP)
|
||||
- [ ] Non-urgent (deploying in next release is ok)
|
||||
## How Has This Been Tested?
|
||||
<!--- Please describe in detail how you tested your changes. -->
|
||||
<!--- Include details of your testing environment, and the tests you ran to -->
|
||||
<!--- see how your change affects other areas of the code, etc. -->
|
||||
|
||||
## Screenshots (if appropriate):
|
||||
|
||||
19
.github/workflows/conventional_commit_pr.yml
vendored
19
.github/workflows/conventional_commit_pr.yml
vendored
@@ -1,19 +0,0 @@
|
||||
name: Conventional commit PR
|
||||
|
||||
on: [pull_request]
|
||||
|
||||
jobs:
|
||||
cog_check_job:
|
||||
runs-on: ubuntu-latest
|
||||
name: check conventional commit compliance
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
# pick the pr HEAD instead of the merge commit
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
- name: Conventional commit check
|
||||
uses: cocogitto/cocogitto-action@v3
|
||||
with:
|
||||
check-latest-tag-only: true
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,3 +11,6 @@ ngrok.log
|
||||
restart-dev.sh
|
||||
*.log
|
||||
data/
|
||||
www/REFACTOR.md
|
||||
www/reload-frontend
|
||||
server/test.sqlite
|
||||
|
||||
@@ -15,25 +15,16 @@ repos:
|
||||
hooks:
|
||||
- id: debug-statements
|
||||
- id: trailing-whitespace
|
||||
exclude: ^server/trials
|
||||
- id: detect-private-key
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 24.1.1
|
||||
hooks:
|
||||
- id: black
|
||||
files: ^server/(reflector|tests)/
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
name: isort (python)
|
||||
files: ^server/(gpu|evaluate|reflector)/
|
||||
args: [ "--profile", "black", "--filter-files" ]
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.5
|
||||
rev: v0.8.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
files: ^server/(reflector|tests)/
|
||||
args:
|
||||
- --fix
|
||||
- --select
|
||||
- I,F401
|
||||
files: ^server/
|
||||
- id: ruff-format
|
||||
files: ^server/
|
||||
|
||||
87
CHANGELOG.md
87
CHANGELOG.md
@@ -1,5 +1,92 @@
|
||||
# Changelog
|
||||
|
||||
## [0.6.0](https://github.com/Monadical-SAS/reflector/compare/v0.5.0...v0.6.0) (2025-08-05)
|
||||
|
||||
|
||||
### ⚠ BREAKING CHANGES
|
||||
|
||||
* Configuration keys have changed. Update your .env file:
|
||||
- TRANSCRIPT_MODAL_API_KEY → TRANSCRIPT_API_KEY
|
||||
- LLM_MODAL_API_KEY → (removed, use TRANSCRIPT_API_KEY)
|
||||
- Add DIARIZATION_API_KEY and TRANSLATE_API_KEY if using those services
|
||||
|
||||
### Features
|
||||
|
||||
* implement service-specific Modal API keys with auto processor pattern ([#528](https://github.com/Monadical-SAS/reflector/issues/528)) ([650befb](https://github.com/Monadical-SAS/reflector/commit/650befb291c47a1f49e94a01ab37d8fdfcd2b65d))
|
||||
* use llamaindex everywhere ([#525](https://github.com/Monadical-SAS/reflector/issues/525)) ([3141d17](https://github.com/Monadical-SAS/reflector/commit/3141d172bc4d3b3d533370c8e6e351ea762169bf))
|
||||
|
||||
|
||||
### Miscellaneous Chores
|
||||
|
||||
* **main:** release 0.6.0 ([ecdbf00](https://github.com/Monadical-SAS/reflector/commit/ecdbf003ea2476c3e95fd231adaeb852f2943df0))
|
||||
|
||||
## [0.5.0](https://github.com/Monadical-SAS/reflector/compare/v0.4.0...v0.5.0) (2025-07-31)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* new summary using phi-4 and llama-index ([#519](https://github.com/Monadical-SAS/reflector/issues/519)) ([1bf9ce0](https://github.com/Monadical-SAS/reflector/commit/1bf9ce07c12f87f89e68a1dbb3b2c96c5ee62466))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* remove unused settings and utils files ([#522](https://github.com/Monadical-SAS/reflector/issues/522)) ([2af4790](https://github.com/Monadical-SAS/reflector/commit/2af4790e4be9e588f282fbc1bb171c88a03d6479))
|
||||
|
||||
## [0.4.0](https://github.com/Monadical-SAS/reflector/compare/v0.3.2...v0.4.0) (2025-07-25)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* Diarization cli ([#509](https://github.com/Monadical-SAS/reflector/issues/509)) ([ffc8003](https://github.com/Monadical-SAS/reflector/commit/ffc8003e6dad236930a27d0fe3e2f2adfb793890))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* remove faulty import Meeting ([#512](https://github.com/Monadical-SAS/reflector/issues/512)) ([0e68c79](https://github.com/Monadical-SAS/reflector/commit/0e68c798434e1b481f9482cc3a4702ea00365df4))
|
||||
* room concurrency (theoretically) ([#511](https://github.com/Monadical-SAS/reflector/issues/511)) ([7bb3676](https://github.com/Monadical-SAS/reflector/commit/7bb367653afeb2778cff697a0eb217abf0b81b84))
|
||||
|
||||
## [0.3.2](https://github.com/Monadical-SAS/reflector/compare/v0.3.1...v0.3.2) (2025-07-22)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* match font size for the filter sidebar ([#507](https://github.com/Monadical-SAS/reflector/issues/507)) ([4b8ba5d](https://github.com/Monadical-SAS/reflector/commit/4b8ba5db1733557e27b098ad3d1cdecadf97ae52))
|
||||
* whereby consent not displaying ([#505](https://github.com/Monadical-SAS/reflector/issues/505)) ([1120552](https://github.com/Monadical-SAS/reflector/commit/1120552c2c83d084d3a39272ad49b6aeda1af98f))
|
||||
|
||||
## [0.3.1](https://github.com/Monadical-SAS/reflector/compare/v0.3.0...v0.3.1) (2025-07-22)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* remove fief out of the source code ([#502](https://github.com/Monadical-SAS/reflector/issues/502)) ([890dd15](https://github.com/Monadical-SAS/reflector/commit/890dd15ba5a2be10dbb841e9aeb75d377885f4af))
|
||||
* remove primary color for room action menu ([#504](https://github.com/Monadical-SAS/reflector/issues/504)) ([2e33f89](https://github.com/Monadical-SAS/reflector/commit/2e33f89c0f9e5fbaafa80e8d2ae9788450ea2f31))
|
||||
|
||||
## [0.3.0](https://github.com/Monadical-SAS/reflector/compare/v0.2.1...v0.3.0) (2025-07-21)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* migrate from chakra 2 to chakra 3 ([#500](https://github.com/Monadical-SAS/reflector/issues/500)) ([a858464](https://github.com/Monadical-SAS/reflector/commit/a858464c7a80e5497acf801d933bf04092f8b526))
|
||||
|
||||
## [0.2.1](https://github.com/Monadical-SAS/reflector/compare/v0.2.0...v0.2.1) (2025-07-18)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* separate browsing page into different components, limit to 10 by default ([#498](https://github.com/Monadical-SAS/reflector/issues/498)) ([c752da6](https://github.com/Monadical-SAS/reflector/commit/c752da6b97c96318aff079a5b2a6eceadfbfcad1))
|
||||
|
||||
## [0.2.0](https://github.com/Monadical-SAS/reflector/compare/0.1.1...v0.2.0) (2025-07-17)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* improve transcript listing with room_id ([#496](https://github.com/Monadical-SAS/reflector/issues/496)) ([d2b5de5](https://github.com/Monadical-SAS/reflector/commit/d2b5de543fc0617fc220caa6a8a290e4040cb10b))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* don't attempt to load waveform/mp3 if audio was deleted ([#495](https://github.com/Monadical-SAS/reflector/issues/495)) ([f4578a7](https://github.com/Monadical-SAS/reflector/commit/f4578a743fd0f20312fbd242fa9cccdfaeb20a9e))
|
||||
|
||||
## [0.1.1](https://github.com/Monadical-SAS/reflector/compare/0.1.0...v0.1.1) (2025-07-17)
|
||||
|
||||
|
||||
|
||||
10
CLAUDE.md
10
CLAUDE.md
@@ -144,9 +144,11 @@ All endpoints prefixed `/v1/`:
|
||||
**Backend** (`server/.env`):
|
||||
- `DATABASE_URL` - Database connection string
|
||||
- `REDIS_URL` - Redis broker for Celery
|
||||
- `MODAL_TOKEN_ID`, `MODAL_TOKEN_SECRET` - Modal.com GPU processing
|
||||
- `TRANSCRIPT_BACKEND=modal` + `TRANSCRIPT_MODAL_API_KEY` - Modal.com transcription
|
||||
- `DIARIZATION_BACKEND=modal` + `DIARIZATION_MODAL_API_KEY` - Modal.com diarization
|
||||
- `TRANSLATION_BACKEND=modal` + `TRANSLATION_MODAL_API_KEY` - Modal.com translation
|
||||
- `WHEREBY_API_KEY` - Video platform integration
|
||||
- `REFLECTOR_AUTH_BACKEND` - Authentication method (none, fief, jwt)
|
||||
- `REFLECTOR_AUTH_BACKEND` - Authentication method (none, jwt)
|
||||
|
||||
**Frontend** (`www/.env`):
|
||||
- `NEXTAUTH_URL`, `NEXTAUTH_SECRET` - Authentication configuration
|
||||
@@ -172,3 +174,7 @@ Modal.com integration for scalable ML processing:
|
||||
- **Audio Routing**: Use BlackHole (Mac) for merging multiple audio sources
|
||||
- **WebRTC**: Ensure proper CORS configuration for cross-origin streaming
|
||||
- **Database**: Run `uv run alembic upgrade head` after pulling schema changes
|
||||
|
||||
## Pipeline/worker related info
|
||||
|
||||
If you need to do any worker/pipeline related work, search for "Pipeline" classes and their "create" or "build" methods to find the main processor sequence. Look for task orchestration patterns (like "chord", "group", or "chain") to identify the post-processing flow with parallel execution chains. This will give you abstract vision on how processing pipeling is organized.
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
|
||||
Reflector Audio Management and Analysis is a cutting-edge web application under development by Monadical. It utilizes AI to record meetings, providing a permanent record with transcripts, translations, and automated summaries.
|
||||
|
||||
[](https://github.com/monadical-sas/cubbi/actions/workflows/pytests.yml)
|
||||
[](https://opensource.org/licenses/AGPL-v3)
|
||||
[](https://github.com/monadical-sas/reflector/actions/workflows/pytests.yml)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
</div>
|
||||
|
||||
## Screenshots
|
||||
@@ -74,7 +74,7 @@ Note: We currently do not have instructions for Windows users.
|
||||
|
||||
### Frontend
|
||||
|
||||
Start with `cd backend`.
|
||||
Start with `cd www`.
|
||||
|
||||
**Installation**
|
||||
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=***REMOVED***
|
||||
|
||||
LLM_BACKEND=modal
|
||||
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
|
||||
LLM_MODAL_API_KEY=***REMOVED***
|
||||
|
||||
AUTH_BACKEND=fief
|
||||
AUTH_FIEF_URL=https://auth.reflector.media/reflector-local
|
||||
AUTH_FIEF_CLIENT_ID=***REMOVED***
|
||||
AUTH_FIEF_CLIENT_SECRET=<ask in zulip> <-----------------------------------------------------------------------------------------
|
||||
|
||||
TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
|
||||
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
|
||||
DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
|
||||
|
||||
BASE_URL=https://xxxxx.ngrok.app
|
||||
DIARIZATION_ENABLED=false
|
||||
|
||||
SQS_POLLING_TIMEOUT_SECONDS=60
|
||||
1
server/.gitignore
vendored
1
server/.gitignore
vendored
@@ -180,3 +180,4 @@ reflector.sqlite3
|
||||
data/
|
||||
|
||||
dump.rdb
|
||||
|
||||
|
||||
@@ -20,3 +20,23 @@ Polls SQS every 60 seconds via /server/reflector/worker/process.py:24-62:
|
||||
# Every 60 seconds, check for new recordings
|
||||
sqs = boto3.client("sqs", ...)
|
||||
response = sqs.receive_message(QueueUrl=queue_url, ...)
|
||||
|
||||
# Requeue
|
||||
|
||||
```bash
|
||||
uv run /app/requeue_uploaded_file.py TRANSCRIPT_ID
|
||||
```
|
||||
|
||||
## Pipeline Management
|
||||
|
||||
### Continue stuck pipeline from final summaries (identify_participants) step:
|
||||
|
||||
```bash
|
||||
uv run python -c "from reflector.pipelines.main_live_pipeline import task_pipeline_final_summaries; result = task_pipeline_final_summaries.delay(transcript_id='TRANSCRIPT_ID'); print(f'Task queued: {result.id}')"
|
||||
```
|
||||
|
||||
### Run full post-processing pipeline (continues to completion):
|
||||
|
||||
```bash
|
||||
uv run python -c "from reflector.pipelines.main_live_pipeline import pipeline_post; pipeline_post(transcript_id='TRANSCRIPT_ID')"
|
||||
```
|
||||
|
||||
@@ -7,11 +7,9 @@
|
||||
## User authentication
|
||||
## =======================================================
|
||||
|
||||
## Using fief (fief.dev)
|
||||
AUTH_BACKEND=fief
|
||||
AUTH_FIEF_URL=https://auth.reflector.media/reflector-local
|
||||
AUTH_FIEF_CLIENT_ID=***REMOVED***
|
||||
AUTH_FIEF_CLIENT_SECRET=<ask in zulip>
|
||||
## Using jwt/authentik
|
||||
AUTH_BACKEND=jwt
|
||||
AUTH_JWT_AUDIENCE=
|
||||
|
||||
## =======================================================
|
||||
## Transcription backend
|
||||
@@ -22,24 +20,24 @@ AUTH_FIEF_CLIENT_SECRET=<ask in zulip>
|
||||
|
||||
## Using local whisper
|
||||
#TRANSCRIPT_BACKEND=whisper
|
||||
#WHISPER_MODEL_SIZE=tiny
|
||||
|
||||
## Using serverless modal.com (require reflector-gpu-modal deployed)
|
||||
#TRANSCRIPT_BACKEND=modal
|
||||
#TRANSCRIPT_URL=https://xxxxx--reflector-transcriber-web.modal.run
|
||||
#TRANSLATE_URL=https://xxxxx--reflector-translator-web.modal.run
|
||||
#TRANSCRIPT_MODAL_API_KEY=xxxxx
|
||||
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=***REMOVED***
|
||||
TRANSCRIPT_MODAL_API_KEY=
|
||||
|
||||
## =======================================================
|
||||
## Transcription backend
|
||||
## Translation backend
|
||||
##
|
||||
## Only available in modal atm
|
||||
## =======================================================
|
||||
TRANSLATION_BACKEND=modal
|
||||
TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
|
||||
#TRANSLATION_MODAL_API_KEY=xxxxx
|
||||
|
||||
## =======================================================
|
||||
## LLM backend
|
||||
@@ -49,28 +47,11 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
|
||||
## llm backend implementation
|
||||
## =======================================================
|
||||
|
||||
## Using serverless modal.com (require reflector-gpu-modal deployed)
|
||||
LLM_BACKEND=modal
|
||||
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
|
||||
LLM_MODAL_API_KEY=***REMOVED***
|
||||
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
|
||||
|
||||
|
||||
## Using OpenAI
|
||||
#LLM_BACKEND=openai
|
||||
#LLM_OPENAI_KEY=xxx
|
||||
#LLM_OPENAI_MODEL=gpt-3.5-turbo
|
||||
|
||||
## Using GPT4ALL
|
||||
#LLM_BACKEND=openai
|
||||
#LLM_URL=http://localhost:4891/v1/completions
|
||||
#LLM_OPENAI_MODEL="GPT4All Falcon"
|
||||
|
||||
## Default LLM MODEL NAME
|
||||
#DEFAULT_LLM=lmsys/vicuna-13b-v1.5
|
||||
|
||||
## Cache directory to store models
|
||||
CACHE_DIR=data
|
||||
## Context size for summary generation (tokens)
|
||||
# LLM_MODEL=microsoft/phi-4
|
||||
LLM_CONTEXT_WINDOW=16000
|
||||
LLM_URL=
|
||||
LLM_API_KEY=sk-
|
||||
|
||||
## =======================================================
|
||||
## Diarization
|
||||
@@ -79,7 +60,9 @@ CACHE_DIR=data
|
||||
## To allow diarization, you need to expose expose the files to be dowloded by the pipeline
|
||||
## =======================================================
|
||||
DIARIZATION_ENABLED=false
|
||||
DIARIZATION_BACKEND=modal
|
||||
DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
|
||||
#DIARIZATION_MODAL_API_KEY=xxxxx
|
||||
|
||||
|
||||
## =======================================================
|
||||
@@ -88,4 +71,3 @@ DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
|
||||
|
||||
## Sentry DSN configuration
|
||||
#SENTRY_DSN=
|
||||
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
This repository hold an API for the GPU implementation of the Reflector API service,
|
||||
and use [Modal.com](https://modal.com)
|
||||
|
||||
- `reflector_llm.py` - LLM API
|
||||
- `reflector_diarizer.py` - Diarization API
|
||||
- `reflector_transcriber.py` - Transcription API
|
||||
- `reflector_translator.py` - Translation API
|
||||
|
||||
## Modal.com deployment
|
||||
|
||||
@@ -23,16 +24,20 @@ $ modal deploy reflector_llm.py
|
||||
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
|
||||
```
|
||||
|
||||
Then in your reflector api configuration `.env`, you can set theses keys:
|
||||
Then in your reflector api configuration `.env`, you can set these keys:
|
||||
|
||||
```
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
|
||||
LLM_BACKEND=modal
|
||||
LLM_URL=https://xxxx--reflector-llm-web.modal.run
|
||||
LLM_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
DIARIZATION_BACKEND=modal
|
||||
DIARIZATION_URL=https://xxxx--reflector-diarizer-web.modal.run
|
||||
DIARIZATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
|
||||
TRANSLATION_BACKEND=modal
|
||||
TRANSLATION_URL=https://xxxx--reflector-translator-web.modal.run
|
||||
TRANSLATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
"""
|
||||
Reflector GPU backend - LLM
|
||||
===========================
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import modal
|
||||
from modal import App, Image, Secret, asgi_app, enter, exit, method
|
||||
|
||||
# LLM
|
||||
LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
|
||||
LLM_LOW_CPU_MEM_USAGE: bool = True
|
||||
LLM_TORCH_DTYPE: str = "bfloat16"
|
||||
LLM_MAX_NEW_TOKENS: int = 300
|
||||
|
||||
IMAGE_MODEL_DIR = "/root/llm_models"
|
||||
|
||||
app = App(name="reflector-llm")
|
||||
|
||||
|
||||
def download_llm():
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
print("Downloading LLM model")
|
||||
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
|
||||
print("LLM model downloaded")
|
||||
|
||||
|
||||
def migrate_cache_llm():
|
||||
"""
|
||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
||||
Migrating your old cache. This is a one-time only operation. You can
|
||||
interrupt this and resume the migration later on by calling
|
||||
`transformers.utils.move_cache()`.
|
||||
"""
|
||||
from transformers.utils.hub import move_cache
|
||||
|
||||
print("Moving LLM cache")
|
||||
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
|
||||
print("LLM cache moved")
|
||||
|
||||
|
||||
llm_image = (
|
||||
Image.debian_slim(python_version="3.10.8")
|
||||
.apt_install("git")
|
||||
.pip_install(
|
||||
"transformers",
|
||||
"torch",
|
||||
"sentencepiece",
|
||||
"protobuf",
|
||||
"jsonformer==0.12.0",
|
||||
"accelerate==0.21.0",
|
||||
"einops==0.6.1",
|
||||
"hf-transfer~=0.1",
|
||||
"huggingface_hub==0.16.4",
|
||||
)
|
||||
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
|
||||
.run_function(download_llm)
|
||||
.run_function(migrate_cache_llm)
|
||||
)
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A100",
|
||||
timeout=60 * 5,
|
||||
scaledown_window=60 * 5,
|
||||
allow_concurrent_inputs=15,
|
||||
image=llm_image,
|
||||
)
|
||||
class LLM:
|
||||
@enter()
|
||||
def enter(self):
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
print("Instance llm model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
LLM_MODEL,
|
||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
# JSONFormer doesn't yet support generation configs
|
||||
print("Instance llm generation config")
|
||||
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
|
||||
|
||||
# generation configuration
|
||||
gen_cfg = GenerationConfig.from_model_config(model.config)
|
||||
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
|
||||
|
||||
# load tokenizer
|
||||
print("Instance llm tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
|
||||
)
|
||||
|
||||
# move model to gpu
|
||||
print("Move llm model to GPU")
|
||||
model = model.cuda()
|
||||
|
||||
print("Warmup llm done")
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.gen_cfg = gen_cfg
|
||||
self.GenerationConfig = GenerationConfig
|
||||
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@exit()
|
||||
def exit():
|
||||
print("Exit llm")
|
||||
|
||||
@method()
|
||||
def generate(
|
||||
self, prompt: str, gen_schema: str | None, gen_cfg: str | None
|
||||
) -> dict:
|
||||
"""
|
||||
Perform a generation action using the LLM
|
||||
"""
|
||||
print(f"Generate {prompt=}")
|
||||
if gen_cfg:
|
||||
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
|
||||
else:
|
||||
gen_cfg = self.gen_cfg
|
||||
|
||||
# If a gen_schema is given, conform to gen_schema
|
||||
with self.lock:
|
||||
if gen_schema:
|
||||
import jsonformer
|
||||
|
||||
print(f"Schema {gen_schema=}")
|
||||
jsonformer_llm = jsonformer.Jsonformer(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
json_schema=json.loads(gen_schema),
|
||||
prompt=prompt,
|
||||
max_string_token_length=gen_cfg.max_new_tokens,
|
||||
)
|
||||
response = jsonformer_llm()
|
||||
else:
|
||||
# If no gen_schema, perform prompt only generation
|
||||
|
||||
# tokenize prompt
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
||||
self.model.device
|
||||
)
|
||||
output = self.model.generate(input_ids, generation_config=gen_cfg)
|
||||
|
||||
# decode output
|
||||
response = self.tokenizer.decode(
|
||||
output[0].cpu(), skip_special_tokens=True
|
||||
)
|
||||
response = response[len(prompt) :]
|
||||
print(f"Generated {response=}")
|
||||
return {"text": response}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Web API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60 * 10,
|
||||
timeout=60 * 5,
|
||||
allow_concurrent_inputs=45,
|
||||
secrets=[
|
||||
Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
)
|
||||
@asgi_app()
|
||||
def web():
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
llmstub = LLM()
|
||||
|
||||
app = FastAPI()
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class LLMRequest(BaseModel):
|
||||
prompt: str
|
||||
gen_schema: Optional[dict] = None
|
||||
gen_cfg: Optional[dict] = None
|
||||
|
||||
@app.post("/llm", dependencies=[Depends(apikey_auth)])
|
||||
def llm(
|
||||
req: LLMRequest,
|
||||
):
|
||||
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
|
||||
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
|
||||
func = llmstub.generate.spawn(
|
||||
prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
|
||||
return app
|
||||
@@ -1,220 +0,0 @@
|
||||
"""
|
||||
Reflector GPU backend - LLM
|
||||
===========================
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import modal
|
||||
from modal import App, Image, Secret, asgi_app, enter, exit, method
|
||||
|
||||
# LLM
|
||||
LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-alpha"
|
||||
LLM_LOW_CPU_MEM_USAGE: bool = True
|
||||
LLM_TORCH_DTYPE: str = "bfloat16"
|
||||
LLM_MAX_NEW_TOKENS: int = 300
|
||||
|
||||
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"
|
||||
|
||||
app = App(name="reflector-llm-zephyr")
|
||||
|
||||
|
||||
def download_llm():
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
print("Downloading LLM model")
|
||||
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
|
||||
print("LLM model downloaded")
|
||||
|
||||
|
||||
def migrate_cache_llm():
|
||||
"""
|
||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
||||
Migrating your old cache. This is a one-time only operation. You can
|
||||
interrupt this and resume the migration later on by calling
|
||||
`transformers.utils.move_cache()`.
|
||||
"""
|
||||
from transformers.utils.hub import move_cache
|
||||
|
||||
print("Moving LLM cache")
|
||||
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
|
||||
print("LLM cache moved")
|
||||
|
||||
|
||||
llm_image = (
|
||||
Image.debian_slim(python_version="3.10.8")
|
||||
.apt_install("git")
|
||||
.pip_install(
|
||||
"transformers==4.34.0",
|
||||
"torch",
|
||||
"sentencepiece",
|
||||
"protobuf",
|
||||
"jsonformer==0.12.0",
|
||||
"accelerate==0.21.0",
|
||||
"einops==0.6.1",
|
||||
"hf-transfer~=0.1",
|
||||
"huggingface_hub==0.16.4",
|
||||
)
|
||||
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
|
||||
.run_function(download_llm)
|
||||
.run_function(migrate_cache_llm)
|
||||
)
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=60 * 5,
|
||||
scaledown_window=60 * 5,
|
||||
allow_concurrent_inputs=10,
|
||||
image=llm_image,
|
||||
)
|
||||
class LLM:
|
||||
@enter()
|
||||
def enter(self):
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
print("Instance llm model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
LLM_MODEL,
|
||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
# JSONFormer doesn't yet support generation configs
|
||||
print("Instance llm generation config")
|
||||
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
|
||||
|
||||
# generation configuration
|
||||
gen_cfg = GenerationConfig.from_model_config(model.config)
|
||||
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
|
||||
|
||||
# load tokenizer
|
||||
print("Instance llm tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
|
||||
)
|
||||
gen_cfg.pad_token_id = tokenizer.eos_token_id
|
||||
gen_cfg.eos_token_id = tokenizer.eos_token_id
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model.config.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
# move model to gpu
|
||||
print("Move llm model to GPU")
|
||||
model = model.cuda()
|
||||
|
||||
print("Warmup llm done")
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.gen_cfg = gen_cfg
|
||||
self.GenerationConfig = GenerationConfig
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@exit()
|
||||
def exit():
|
||||
print("Exit llm")
|
||||
|
||||
@method()
|
||||
def generate(
|
||||
self, prompt: str, gen_schema: str | None, gen_cfg: str | None
|
||||
) -> dict:
|
||||
"""
|
||||
Perform a generation action using the LLM
|
||||
"""
|
||||
print(f"Generate {prompt=}")
|
||||
if gen_cfg:
|
||||
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
|
||||
gen_cfg.pad_token_id = self.tokenizer.eos_token_id
|
||||
gen_cfg.eos_token_id = self.tokenizer.eos_token_id
|
||||
else:
|
||||
gen_cfg = self.gen_cfg
|
||||
|
||||
# If a gen_schema is given, conform to gen_schema
|
||||
with self.lock:
|
||||
if gen_schema:
|
||||
import jsonformer
|
||||
|
||||
print(f"Schema {gen_schema=}")
|
||||
jsonformer_llm = jsonformer.Jsonformer(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
json_schema=json.loads(gen_schema),
|
||||
prompt=prompt,
|
||||
max_string_token_length=gen_cfg.max_new_tokens,
|
||||
)
|
||||
response = jsonformer_llm()
|
||||
else:
|
||||
# If no gen_schema, perform prompt only generation
|
||||
|
||||
# tokenize prompt
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
||||
self.model.device
|
||||
)
|
||||
output = self.model.generate(input_ids, generation_config=gen_cfg)
|
||||
|
||||
# decode output
|
||||
response = self.tokenizer.decode(
|
||||
output[0].cpu(), skip_special_tokens=True
|
||||
)
|
||||
response = response[len(prompt) :]
|
||||
response = {"long_summary": response}
|
||||
print(f"Generated {response=}")
|
||||
return {"text": response}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Web API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.function(
|
||||
scaledown_window=60 * 10,
|
||||
timeout=60 * 5,
|
||||
allow_concurrent_inputs=30,
|
||||
secrets=[
|
||||
Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
)
|
||||
@asgi_app()
|
||||
def web():
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
llmstub = LLM()
|
||||
|
||||
app = FastAPI()
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
class LLMRequest(BaseModel):
|
||||
prompt: str
|
||||
gen_schema: Optional[dict] = None
|
||||
gen_cfg: Optional[dict] = None
|
||||
|
||||
@app.post("/llm", dependencies=[Depends(apikey_auth)])
|
||||
def llm(
|
||||
req: LLMRequest,
|
||||
):
|
||||
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
|
||||
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
|
||||
func = llmstub.generate.spawn(
|
||||
prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
|
||||
return app
|
||||
@@ -1,171 +0,0 @@
|
||||
# # Run an OpenAI-Compatible vLLM Server
|
||||
|
||||
import modal
|
||||
|
||||
MODELS_DIR = "/llamas"
|
||||
MODEL_NAME = "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
N_GPU = 1
|
||||
|
||||
|
||||
def download_llm():
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
print("Downloading LLM model")
|
||||
snapshot_download(
|
||||
MODEL_NAME,
|
||||
local_dir=f"{MODELS_DIR}/{MODEL_NAME}",
|
||||
ignore_patterns=[
|
||||
"*.pt",
|
||||
"*.bin",
|
||||
"*.pth",
|
||||
"original/*",
|
||||
], # Ensure safetensors
|
||||
)
|
||||
print("LLM model downloaded")
|
||||
|
||||
|
||||
def move_cache():
|
||||
from transformers.utils import move_cache as transformers_move_cache
|
||||
|
||||
transformers_move_cache()
|
||||
|
||||
|
||||
vllm_image = (
|
||||
modal.Image.debian_slim(python_version="3.10")
|
||||
.pip_install("vllm==0.5.3post1")
|
||||
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
|
||||
.pip_install(
|
||||
# "accelerate==0.34.2",
|
||||
"einops==0.8.0",
|
||||
"hf-transfer~=0.1",
|
||||
)
|
||||
.run_function(download_llm)
|
||||
.run_function(move_cache)
|
||||
.pip_install(
|
||||
"bitsandbytes>=0.42.9",
|
||||
)
|
||||
)
|
||||
|
||||
app = modal.App("reflector-vllm-hermes3")
|
||||
|
||||
|
||||
@app.function(
|
||||
image=vllm_image,
|
||||
gpu=modal.gpu.A100(count=N_GPU, size="40GB"),
|
||||
timeout=60 * 5,
|
||||
scaledown_window=60 * 5,
|
||||
allow_concurrent_inputs=100,
|
||||
secrets=[
|
||||
modal.Secret.from_name("reflector-gpu"),
|
||||
],
|
||||
)
|
||||
@modal.asgi_app()
|
||||
def serve():
|
||||
import os
|
||||
|
||||
import fastapi
|
||||
import vllm.entrypoints.openai.api_server as api_server
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
TOKEN = os.environ["REFLECTOR_GPU_APIKEY"]
|
||||
|
||||
# create a fastAPI app that uses vLLM's OpenAI-compatible router
|
||||
web_app = fastapi.FastAPI(
|
||||
title=f"OpenAI-compatible {MODEL_NAME} server",
|
||||
description="Run an OpenAI-compatible LLM server with vLLM on modal.com",
|
||||
version="0.0.1",
|
||||
docs_url="/docs",
|
||||
)
|
||||
|
||||
# security: CORS middleware for external requests
|
||||
http_bearer = fastapi.security.HTTPBearer(
|
||||
scheme_name="Bearer Token",
|
||||
description="See code for authentication details.",
|
||||
)
|
||||
web_app.add_middleware(
|
||||
fastapi.middleware.cors.CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# security: inject dependency on authed routes
|
||||
async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
|
||||
if api_key.credentials != TOKEN:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
return {"username": "authenticated_user"}
|
||||
|
||||
router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])
|
||||
|
||||
# wrap vllm's router in auth router
|
||||
router.include_router(api_server.router)
|
||||
# add authed vllm to our fastAPI app
|
||||
web_app.include_router(router)
|
||||
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=MODELS_DIR + "/" + MODEL_NAME,
|
||||
tensor_parallel_size=N_GPU,
|
||||
gpu_memory_utilization=0.90,
|
||||
# max_model_len=8096,
|
||||
enforce_eager=False, # capture the graph for faster inference, but slower cold starts (30s > 20s)
|
||||
# --- 4 bits load
|
||||
# quantization="bitsandbytes",
|
||||
# load_format="bitsandbytes",
|
||||
)
|
||||
|
||||
engine = AsyncLLMEngine.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.OPENAI_API_SERVER
|
||||
)
|
||||
|
||||
model_config = get_model_config(engine)
|
||||
|
||||
request_logger = RequestLogger(max_log_len=2048)
|
||||
|
||||
api_server.openai_serving_chat = OpenAIServingChat(
|
||||
engine,
|
||||
model_config=model_config,
|
||||
served_model_names=[MODEL_NAME],
|
||||
chat_template=None,
|
||||
response_role="assistant",
|
||||
lora_modules=[],
|
||||
prompt_adapters=[],
|
||||
request_logger=request_logger,
|
||||
)
|
||||
api_server.openai_serving_completion = OpenAIServingCompletion(
|
||||
engine,
|
||||
model_config=model_config,
|
||||
served_model_names=[MODEL_NAME],
|
||||
lora_modules=[],
|
||||
prompt_adapters=[],
|
||||
request_logger=request_logger,
|
||||
)
|
||||
|
||||
return web_app
|
||||
|
||||
|
||||
def get_model_config(engine):
|
||||
import asyncio
|
||||
|
||||
try: # adapted from vLLM source -- https://github.com/vllm-project/vllm/blob/507ef787d85dec24490069ffceacbd6b161f4f72/vllm/entrypoints/openai/api_server.py#L235C1-L247C1
|
||||
event_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
event_loop = None
|
||||
|
||||
if event_loop is not None and event_loop.is_running():
|
||||
# If the current is instanced by Ray Serve,
|
||||
# there is already a running event loop
|
||||
model_config = event_loop.run_until_complete(engine.get_model_config())
|
||||
else:
|
||||
# When using single vLLM without engine_use_ray
|
||||
model_config = asyncio.run(engine.get_model_config())
|
||||
|
||||
return model_config
|
||||
@@ -1,16 +0,0 @@
|
||||
LOAD DATABASE
|
||||
FROM sqlite:///app/reflector.sqlite3
|
||||
INTO pgsql://reflector:reflector@postgres:5432/reflector
|
||||
WITH
|
||||
include drop,
|
||||
create tables,
|
||||
create indexes,
|
||||
reset sequences,
|
||||
preserve index names,
|
||||
prefetch rows = 10
|
||||
SET
|
||||
work_mem to '512MB',
|
||||
maintenance_work_mem to '1024MB'
|
||||
CAST
|
||||
column transcript.duration to float using (lambda (val) (when val (format nil "~f" val)))
|
||||
;
|
||||
@@ -1,9 +1,10 @@
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from reflector.db import metadata
|
||||
from reflector.settings import settings
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
||||
@@ -8,7 +8,6 @@ Create Date: 2024-09-24 16:12:56.944133
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -5,11 +5,11 @@ Revises: f819277e5169
|
||||
Create Date: 2023-11-07 11:12:21.614198
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0fea6d96b096"
|
||||
|
||||
@@ -5,26 +5,26 @@ Revises: 0fea6d96b096
|
||||
Create Date: 2023-11-30 15:56:03.341466
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '125031f7cb78'
|
||||
down_revision: Union[str, None] = '0fea6d96b096'
|
||||
revision: str = "125031f7cb78"
|
||||
down_revision: Union[str, None] = "0fea6d96b096"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('transcript', sa.Column('participants', sa.JSON(), nullable=True))
|
||||
op.add_column("transcript", sa.Column("participants", sa.JSON(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('transcript', 'participants')
|
||||
op.drop_column("transcript", "participants")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -5,6 +5,7 @@ Revises: f819277e5169
|
||||
Create Date: 2025-06-17 14:00:03.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -19,16 +20,16 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'meeting_consent',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('meeting_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('consent_given', sa.Boolean(), nullable=False),
|
||||
sa.Column('consent_timestamp', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.ForeignKeyConstraint(['meeting_id'], ['meeting.id']),
|
||||
"meeting_consent",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("meeting_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("consent_given", sa.Boolean(), nullable=False),
|
||||
sa.Column("consent_timestamp", sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["meeting_id"], ["meeting.id"]),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('meeting_consent')
|
||||
op.drop_table("meeting_consent")
|
||||
|
||||
@@ -5,6 +5,7 @@ Revises: 20250617140003
|
||||
Create Date: 2025-06-18 14:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
@@ -5,36 +5,40 @@ Revises: ccd68dc784ff
|
||||
Create Date: 2025-07-15 16:53:40.397394
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2cf0b60a9d34'
|
||||
down_revision: Union[str, None] = 'ccd68dc784ff'
|
||||
revision: str = "2cf0b60a9d34"
|
||||
down_revision: Union[str, None] = "ccd68dc784ff"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('transcript', schema=None) as batch_op:
|
||||
batch_op.alter_column('duration',
|
||||
with op.batch_alter_table("transcript", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"duration",
|
||||
existing_type=sa.INTEGER(),
|
||||
type_=sa.Float(),
|
||||
existing_nullable=True)
|
||||
existing_nullable=True,
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('transcript', schema=None) as batch_op:
|
||||
batch_op.alter_column('duration',
|
||||
with op.batch_alter_table("transcript", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"duration",
|
||||
existing_type=sa.Float(),
|
||||
type_=sa.INTEGER(),
|
||||
existing_nullable=True)
|
||||
existing_nullable=True,
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -5,17 +5,17 @@ Revises: 9920ecfe2735
|
||||
Create Date: 2023-11-02 19:53:09.116240
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
from alembic import op
|
||||
from sqlalchemy import select
|
||||
|
||||
from sqlalchemy.sql import column, table
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '38a927dcb099'
|
||||
down_revision: Union[str, None] = '9920ecfe2735'
|
||||
revision: str = "38a927dcb099"
|
||||
down_revision: Union[str, None] = "9920ecfe2735"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
@@ -5,13 +5,13 @@ Revises: 38a927dcb099
|
||||
Create Date: 2023-11-10 18:12:17.886522
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
from alembic import op
|
||||
from sqlalchemy import select
|
||||
|
||||
from sqlalchemy.sql import column, table
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "4814901632bc"
|
||||
@@ -24,9 +24,11 @@ def upgrade() -> None:
|
||||
# for all the transcripts, calculate the duration from the mp3
|
||||
# and update the duration column
|
||||
from pathlib import Path
|
||||
from reflector.settings import settings
|
||||
|
||||
import av
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
bind = op.get_bind()
|
||||
transcript = table(
|
||||
"transcript", column("id", sa.String), column("duration", sa.Float)
|
||||
|
||||
@@ -5,14 +5,11 @@ Revises:
|
||||
Create Date: 2023-08-29 10:54:45.142974
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '543ed284d69a'
|
||||
revision: str = "543ed284d69a"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -8,9 +8,8 @@ Create Date: 2025-06-27 09:04:21.006823
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "62dea3db63a5"
|
||||
|
||||
@@ -5,26 +5,28 @@ Revises: 62dea3db63a5
|
||||
Create Date: 2024-09-06 14:02:06.649665
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '764ce6db4388'
|
||||
down_revision: Union[str, None] = '62dea3db63a5'
|
||||
revision: str = "764ce6db4388"
|
||||
down_revision: Union[str, None] = "62dea3db63a5"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('transcript', sa.Column('zulip_message_id', sa.Integer(), nullable=True))
|
||||
op.add_column(
|
||||
"transcript", sa.Column("zulip_message_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('transcript', 'zulip_message_id')
|
||||
op.drop_column("transcript", "zulip_message_id")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -9,8 +9,6 @@ Create Date: 2025-07-15 19:30:19.876332
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "88d292678ba2"
|
||||
@@ -21,7 +19,7 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
def upgrade() -> None:
|
||||
import json
|
||||
import re
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
# Get database connection
|
||||
@@ -58,7 +56,9 @@ def upgrade() -> None:
|
||||
fixed_events = json.dumps(jevents)
|
||||
assert "NaN" not in fixed_events
|
||||
except (json.JSONDecodeError, AssertionError) as e:
|
||||
print(f"Warning: Invalid JSON for transcript {transcript_id}, skipping: {e}")
|
||||
print(
|
||||
f"Warning: Invalid JSON for transcript {transcript_id}, skipping: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Update the record with fixed JSON
|
||||
|
||||
@@ -5,13 +5,13 @@ Revises: 99365b0cd87b
|
||||
Create Date: 2023-11-02 18:55:17.019498
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
from alembic import op
|
||||
from sqlalchemy import select
|
||||
|
||||
from sqlalchemy.sql import column, table
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9920ecfe2735"
|
||||
|
||||
@@ -8,8 +8,8 @@ Create Date: 2023-09-01 20:19:47.216334
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "99365b0cd87b"
|
||||
|
||||
@@ -9,8 +9,6 @@ Create Date: 2025-07-15 20:09:40.253018
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a9c9c229ee36"
|
||||
|
||||
@@ -5,30 +5,34 @@ Revises: 6ea59639f30e
|
||||
Create Date: 2025-01-28 10:06:50.446233
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b0e5f7876032'
|
||||
down_revision: Union[str, None] = '6ea59639f30e'
|
||||
revision: str = "b0e5f7876032"
|
||||
down_revision: Union[str, None] = "6ea59639f30e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('meeting', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('is_active', sa.Boolean(), server_default=sa.text('1'), nullable=False))
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"is_active", sa.Boolean(), server_default=sa.text("1"), nullable=False
|
||||
)
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('meeting', schema=None) as batch_op:
|
||||
batch_op.drop_column('is_active')
|
||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||
batch_op.drop_column("is_active")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -8,9 +8,8 @@ Create Date: 2025-06-27 08:57:16.306940
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b3df9681cae9"
|
||||
|
||||
@@ -8,9 +8,8 @@ Create Date: 2024-10-11 13:45:28.914902
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b469348df210"
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add_unique_constraint_one_active_meeting_per_room
|
||||
|
||||
Revision ID: b7df9609542c
|
||||
Revises: d7fbb74b673b
|
||||
Create Date: 2025-07-25 16:27:06.959868
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b7df9609542c"
|
||||
down_revision: Union[str, None] = "d7fbb74b673b"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create a partial unique index that ensures only one active meeting per room
|
||||
# This works for both PostgreSQL and SQLite
|
||||
op.create_index(
|
||||
"idx_one_active_meeting_per_room",
|
||||
"meeting",
|
||||
["room_id"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
sqlite_where=sa.text("is_active = 1"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("idx_one_active_meeting_per_room", table_name="meeting")
|
||||
@@ -5,25 +5,31 @@ Revises: 125031f7cb78
|
||||
Create Date: 2023-12-13 15:37:51.303970
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b9348748bbbc'
|
||||
down_revision: Union[str, None] = '125031f7cb78'
|
||||
revision: str = "b9348748bbbc"
|
||||
down_revision: Union[str, None] = "125031f7cb78"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('transcript', sa.Column('reviewed', sa.Boolean(), server_default=sa.text('0'), nullable=False))
|
||||
op.add_column(
|
||||
"transcript",
|
||||
sa.Column(
|
||||
"reviewed", sa.Boolean(), server_default=sa.text("0"), nullable=False
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('transcript', 'reviewed')
|
||||
op.drop_column("transcript", "reviewed")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -9,8 +9,6 @@ Create Date: 2025-07-15 11:48:42.854741
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "ccd68dc784ff"
|
||||
|
||||
@@ -8,9 +8,8 @@ Create Date: 2025-06-27 09:27:25.302152
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d3ff3a39297f"
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Add room_id to transcript
|
||||
|
||||
Revision ID: d7fbb74b673b
|
||||
Revises: a9c9c229ee36
|
||||
Create Date: 2025-07-17 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d7fbb74b673b"
|
||||
down_revision: Union[str, None] = "a9c9c229ee36"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add room_id column to transcript table
|
||||
op.add_column("transcript", sa.Column("room_id", sa.String(), nullable=True))
|
||||
|
||||
# Add index for room_id for better query performance
|
||||
op.create_index("idx_transcript_room_id", "transcript", ["room_id"])
|
||||
|
||||
# Populate room_id for existing ROOM-type transcripts
|
||||
# This joins through recording -> meeting -> room to get the room_id
|
||||
op.execute("""
|
||||
UPDATE transcript AS t
|
||||
SET room_id = r.id
|
||||
FROM recording rec
|
||||
JOIN meeting m ON rec.meeting_id = m.id
|
||||
JOIN room r ON m.room_id = r.id
|
||||
WHERE t.recording_id = rec.id
|
||||
AND t.source_kind = 'room'
|
||||
AND t.room_id IS NULL
|
||||
""")
|
||||
|
||||
# Fix missing meeting_id for ROOM-type transcripts
|
||||
# The meeting_id field exists but was never populated
|
||||
op.execute("""
|
||||
UPDATE transcript AS t
|
||||
SET meeting_id = rec.meeting_id
|
||||
FROM recording rec
|
||||
WHERE t.recording_id = rec.id
|
||||
AND t.source_kind = 'room'
|
||||
AND t.meeting_id IS NULL
|
||||
AND rec.meeting_id IS NOT NULL
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the index first
|
||||
op.drop_index("idx_transcript_room_id", "transcript")
|
||||
|
||||
# Drop the room_id column
|
||||
op.drop_column("transcript", "room_id")
|
||||
@@ -5,11 +5,11 @@ Revises: 4814901632bc
|
||||
Create Date: 2023-11-16 10:29:09.351664
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f819277e5169"
|
||||
|
||||
@@ -22,7 +22,6 @@ dependencies = [
|
||||
"fastapi-pagination>=0.12.6",
|
||||
"databases[aiosqlite, asyncpg]>=0.7.0",
|
||||
"sqlalchemy<1.5",
|
||||
"fief-client[fastapi]>=0.17.0",
|
||||
"alembic>=1.11.3",
|
||||
"nltk>=3.8.1",
|
||||
"prometheus-fastapi-instrumentator>=6.1.0",
|
||||
@@ -39,6 +38,9 @@ dependencies = [
|
||||
"jsonschema>=4.23.0",
|
||||
"openai>=1.59.7",
|
||||
"psycopg2-binary>=2.9.10",
|
||||
"llama-index>=0.12.52",
|
||||
"llama-index-llms-openai-like>=0.4.0",
|
||||
"pytest-env>=1.1.5",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
@@ -82,6 +84,10 @@ packages = ["reflector"]
|
||||
[tool.coverage.run]
|
||||
source = ["reflector"]
|
||||
|
||||
[tool.pytest_env]
|
||||
ENVIRONMENT = "pytest"
|
||||
DATABASE_URL = "sqlite:///test.sqlite"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
||||
testpaths = ["tests"]
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import reflector.auth # noqa
|
||||
import reflector.db # noqa
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi_pagination import add_pagination
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
import reflector.auth # noqa
|
||||
import reflector.db # noqa
|
||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||
from reflector.logger import logger
|
||||
from reflector.metrics import metrics_init
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from reflector.settings import settings
|
||||
from reflector.logger import logger
|
||||
import importlib
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
logger.info(f"User authentication using {settings.AUTH_BACKEND}")
|
||||
module_name = f"reflector.auth.auth_{settings.AUTH_BACKEND}"
|
||||
auth_module = importlib.import_module(module_name)
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
from fastapi.security import OAuth2AuthorizationCodeBearer
|
||||
from fief_client import FiefAccessTokenInfo, FiefAsync, FiefUserInfo
|
||||
from fief_client.integrations.fastapi import FiefAuth
|
||||
from reflector.settings import settings
|
||||
|
||||
fief = FiefAsync(
|
||||
settings.AUTH_FIEF_URL,
|
||||
settings.AUTH_FIEF_CLIENT_ID,
|
||||
settings.AUTH_FIEF_CLIENT_SECRET,
|
||||
)
|
||||
|
||||
scheme = OAuth2AuthorizationCodeBearer(
|
||||
f"{settings.AUTH_FIEF_URL}/authorize",
|
||||
f"{settings.AUTH_FIEF_URL}/api/token",
|
||||
scopes={"openid": "openid", "offline_access": "offline_access"},
|
||||
auto_error=False,
|
||||
)
|
||||
|
||||
auth = FiefAuth(fief, scheme)
|
||||
|
||||
UserInfo = FiefUserInfo
|
||||
AccessTokenInfo = FiefAccessTokenInfo
|
||||
authenticated = auth.authenticated()
|
||||
current_user = auth.current_user()
|
||||
current_user_optional = auth.current_user(optional=True)
|
||||
@@ -4,6 +4,7 @@ from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import signal
|
||||
from typing import NoReturn
|
||||
|
||||
from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.stream_client import StreamClient
|
||||
from typing import NoReturn
|
||||
|
||||
|
||||
async def main() -> NoReturn:
|
||||
@@ -51,7 +51,7 @@ async def main() -> NoReturn:
|
||||
|
||||
logger.info(f"Cancelling {len(tasks)} outstanding tasks")
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
logger.info(f'{"Flushing metrics"}')
|
||||
logger.info(f"{'Flushing metrics'}")
|
||||
loop.stop()
|
||||
|
||||
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import databases
|
||||
import sqlalchemy
|
||||
|
||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Literal
|
||||
import sqlalchemy as sa
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from reflector.db import database, metadata
|
||||
from reflector.db.rooms import Room
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
from reflector.db import database
|
||||
from reflector.db.meetings import meetings
|
||||
from reflector.db.rooms import rooms
|
||||
from reflector.db.transcripts import transcripts
|
||||
|
||||
users_to_migrate = [
|
||||
["123@lifex.pink", "63b727f5-485d-449f-b528-563d779b11ef", None],
|
||||
["ana@monadical.com", "1bae2e4d-5c04-49c2-932f-a86266a6ca13", None],
|
||||
["cspencer@sprocket.org", "614ed0be-392e-488c-bd19-6a9730fd0e9e", None],
|
||||
["daniel.f.lopez.j@gmail.com", "ca9561bd-c989-4a1e-8877-7081cf62ae7f", None],
|
||||
["jenalee@monadical.com", "c7c1e79e-b068-4b28-a9f4-29d98b1697ed", None],
|
||||
["jennifer@rootandseed.com", "f5321727-7546-4b2b-b69d-095a931ef0c4", None],
|
||||
["jose@monadical.com", "221f079c-7ce0-4677-90b7-0359b6315e27", None],
|
||||
["labenclayton@gmail.com", "40078cd0-543c-40e4-9c2e-5ce57a686428", None],
|
||||
["mathieu@monadical.com", "c7a36151-851e-4afa-9fab-aaca834bfd30", None],
|
||||
["michal.flak.96@gmail.com", "3096eb5e-b590-41fc-a0d1-d152c1895402", None],
|
||||
["sara@monadical.com", "31ab0cfe-5d2c-4c7a-84de-a29494714c99", None],
|
||||
["sara@monadical.com", "b871e5f0-754e-447f-9c3d-19f629f0082b", None],
|
||||
["sebastian@monadical.com", "f024f9d0-15d0-480f-8529-43959fc8b639", None],
|
||||
["sergey@monadical.com", "5c4798eb-b9ab-4721-a540-bd96fc434156", None],
|
||||
["sergey@monadical.com", "9dd8a6b4-247e-48fe-b1fb-4c84dd3c01bc", None],
|
||||
["transient.tran@gmail.com", "617ba2d3-09b6-4b1f-a435-a7f41c3ce060", None],
|
||||
]
|
||||
|
||||
|
||||
async def migrate_user(email, user_id):
|
||||
# if the email match the email in the users_to_migrate list
|
||||
# reassign all transcripts/rooms/meetings to the new user_id
|
||||
|
||||
user_ids = [user[1] for user in users_to_migrate if user[0] == email]
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
# do not migrate back
|
||||
if user_id in user_ids:
|
||||
return
|
||||
|
||||
for old_user_id in user_ids:
|
||||
query = (
|
||||
transcripts.update()
|
||||
.where(transcripts.c.user_id == old_user_id)
|
||||
.values(user_id=user_id)
|
||||
)
|
||||
await database.execute(query)
|
||||
|
||||
query = (
|
||||
rooms.update().where(rooms.c.user_id == old_user_id).values(user_id=user_id)
|
||||
)
|
||||
await database.execute(query)
|
||||
|
||||
query = (
|
||||
meetings.update()
|
||||
.where(meetings.c.user_id == old_user_id)
|
||||
.values(user_id=user_id)
|
||||
)
|
||||
await database.execute(query)
|
||||
@@ -3,6 +3,7 @@ from typing import Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from reflector.db import database, metadata
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
|
||||
@@ -5,9 +5,10 @@ from typing import Literal
|
||||
import sqlalchemy
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.sql import false, or_
|
||||
|
||||
from reflector.db import database, metadata
|
||||
from reflector.utils import generate_uuid4
|
||||
from sqlalchemy.sql import false, or_
|
||||
|
||||
rooms = sqlalchemy.Table(
|
||||
"room",
|
||||
|
||||
@@ -10,13 +10,14 @@ from typing import Any, Literal
|
||||
import sqlalchemy
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy.sql import false, or_
|
||||
|
||||
from reflector.db import database, metadata
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.utils import generate_uuid4
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy.sql import false, or_
|
||||
|
||||
|
||||
class SourceKind(enum.StrEnum):
|
||||
@@ -74,10 +75,12 @@ transcripts = sqlalchemy.Table(
|
||||
# the main "audio deleted" is the presence of the audio itself / consents not-given
|
||||
# same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
|
||||
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
|
||||
sqlalchemy.Column("room_id", sqlalchemy.String),
|
||||
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
|
||||
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
|
||||
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
||||
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
||||
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
|
||||
)
|
||||
|
||||
|
||||
@@ -167,6 +170,7 @@ class Transcript(BaseModel):
|
||||
zulip_message_id: int | None = None
|
||||
source_kind: SourceKind
|
||||
audio_deleted: bool | None = None
|
||||
room_id: str | None = None
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def serialize_datetime(self, dt: datetime) -> str:
|
||||
@@ -331,17 +335,10 @@ class TranscriptController:
|
||||
- `room_id`: filter transcripts by room ID
|
||||
- `search_term`: filter transcripts by search term
|
||||
"""
|
||||
from reflector.db.meetings import meetings
|
||||
from reflector.db.recordings import recordings
|
||||
from reflector.db.rooms import rooms
|
||||
|
||||
query = (
|
||||
transcripts.select()
|
||||
.join(
|
||||
recordings, transcripts.c.recording_id == recordings.c.id, isouter=True
|
||||
)
|
||||
.join(meetings, recordings.c.meeting_id == meetings.c.id, isouter=True)
|
||||
.join(rooms, meetings.c.room_id == rooms.c.id, isouter=True)
|
||||
query = transcripts.select().join(
|
||||
rooms, transcripts.c.room_id == rooms.c.id, isouter=True
|
||||
)
|
||||
|
||||
if user_id:
|
||||
@@ -355,7 +352,7 @@ class TranscriptController:
|
||||
query = query.where(transcripts.c.source_kind == source_kind)
|
||||
|
||||
if room_id:
|
||||
query = query.where(rooms.c.id == room_id)
|
||||
query = query.where(transcripts.c.room_id == room_id)
|
||||
|
||||
if search_term:
|
||||
query = query.where(transcripts.c.title.ilike(f"%{search_term}%"))
|
||||
@@ -368,7 +365,6 @@ class TranscriptController:
|
||||
query = query.with_only_columns(
|
||||
transcript_columns
|
||||
+ [
|
||||
rooms.c.id.label("room_id"),
|
||||
rooms.c.name.label("room_name"),
|
||||
]
|
||||
)
|
||||
@@ -419,6 +415,22 @@ class TranscriptController:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
|
||||
async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]:
|
||||
"""
|
||||
Get transcripts by room_id (direct access without joins)
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.room_id == room_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
if "order_by" in kwargs:
|
||||
order_by = kwargs["order_by"]
|
||||
field = getattr(transcripts.c, order_by[1:])
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
results = await database.fetch_all(query)
|
||||
return [Transcript(**result) for result in results]
|
||||
|
||||
async def get_by_id_for_http(
|
||||
self,
|
||||
transcript_id: str,
|
||||
@@ -469,6 +481,8 @@ class TranscriptController:
|
||||
user_id: str | None = None,
|
||||
recording_id: str | None = None,
|
||||
share_mode: str = "private",
|
||||
meeting_id: str | None = None,
|
||||
room_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Add a new transcript
|
||||
@@ -481,6 +495,8 @@ class TranscriptController:
|
||||
user_id=user_id,
|
||||
recording_id=recording_id,
|
||||
share_mode=share_mode,
|
||||
meeting_id=meeting_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
|
||||
83
server/reflector/llm.py
Normal file
83
server/reflector/llm.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Type, TypeVar
|
||||
|
||||
from llama_index.core import Settings
|
||||
from llama_index.core.output_parsers import PydanticOutputParser
|
||||
from llama_index.core.program import LLMTextCompletionProgram
|
||||
from llama_index.core.response_synthesizers import TreeSummarize
|
||||
from llama_index.llms.openai_like import OpenAILike
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
|
||||
Based on the following analysis, provide the information in the requested JSON format:
|
||||
|
||||
Analysis:
|
||||
{analysis}
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
|
||||
class LLM:
|
||||
def __init__(self, settings, temperature: float = 0.4, max_tokens: int = 2048):
|
||||
self.settings_obj = settings
|
||||
self.model_name = settings.LLM_MODEL
|
||||
self.url = settings.LLM_URL
|
||||
self.api_key = settings.LLM_API_KEY
|
||||
self.context_window = settings.LLM_CONTEXT_WINDOW
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# Configure llamaindex Settings
|
||||
self._configure_llamaindex()
|
||||
|
||||
def _configure_llamaindex(self):
|
||||
"""Configure llamaindex Settings with OpenAILike LLM"""
|
||||
Settings.llm = OpenAILike(
|
||||
model=self.model_name,
|
||||
api_base=self.url,
|
||||
api_key=self.api_key,
|
||||
context_window=self.context_window,
|
||||
is_chat_model=True,
|
||||
is_function_calling_model=False,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
async def get_response(
|
||||
self, prompt: str, texts: list[str], tone_name: str | None = None
|
||||
) -> str:
|
||||
"""Get a text response using TreeSummarize for non-function-calling models"""
|
||||
summarizer = TreeSummarize(verbose=False)
|
||||
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
|
||||
return str(response).strip()
|
||||
|
||||
async def get_structured_response(
|
||||
self,
|
||||
prompt: str,
|
||||
texts: list[str],
|
||||
output_cls: Type[T],
|
||||
tone_name: str | None = None,
|
||||
) -> T:
|
||||
"""Get structured output from LLM for non-function-calling models"""
|
||||
summarizer = TreeSummarize(verbose=True)
|
||||
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
|
||||
|
||||
output_parser = PydanticOutputParser(output_cls)
|
||||
|
||||
program = LLMTextCompletionProgram.from_defaults(
|
||||
output_parser=output_parser,
|
||||
prompt_template_str=STRUCTURED_RESPONSE_PROMPT_TEMPLATE,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
format_instructions = output_parser.format(
|
||||
"Please structure the above information in the following JSON format:"
|
||||
)
|
||||
|
||||
output = await program.acall(
|
||||
analysis=str(response), format_instructions=format_instructions
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -1,2 +0,0 @@
|
||||
from .base import LLM # noqa: F401
|
||||
from .llm_params import LLMTaskParams # noqa: F401
|
||||
@@ -1,338 +0,0 @@
|
||||
import importlib
|
||||
import json
|
||||
import re
|
||||
from typing import TypeVar
|
||||
|
||||
import nltk
|
||||
from prometheus_client import Counter, Histogram
|
||||
from reflector.llm.llm_params import TaskParams
|
||||
from reflector.logger import logger as reflector_logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
from transformers import GenerationConfig
|
||||
|
||||
T = TypeVar("T", bound="LLM")
|
||||
|
||||
|
||||
class LLM:
|
||||
_nltk_downloaded = False
|
||||
_registry = {}
|
||||
m_generate = Histogram(
|
||||
"llm_generate",
|
||||
"Time spent in LLM.generate",
|
||||
["backend"],
|
||||
)
|
||||
m_generate_call = Counter(
|
||||
"llm_generate_call",
|
||||
"Number of calls to LLM.generate",
|
||||
["backend"],
|
||||
)
|
||||
m_generate_success = Counter(
|
||||
"llm_generate_success",
|
||||
"Number of successful calls to LLM.generate",
|
||||
["backend"],
|
||||
)
|
||||
m_generate_failure = Counter(
|
||||
"llm_generate_failure",
|
||||
"Number of failed calls to LLM.generate",
|
||||
["backend"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def ensure_nltk(cls):
|
||||
"""
|
||||
Make sure NLTK package is installed. Searches in the cache and
|
||||
downloads only if needed.
|
||||
"""
|
||||
if not cls._nltk_downloaded:
|
||||
nltk.download("punkt_tab")
|
||||
# For POS tagging
|
||||
nltk.download("averaged_perceptron_tagger_eng")
|
||||
cls._nltk_downloaded = True
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, klass):
|
||||
cls._registry[name] = klass
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, model_name: str | None = None, name: str = None) -> T:
|
||||
"""
|
||||
Return an instance depending on the settings.
|
||||
Settings used:
|
||||
|
||||
- `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
|
||||
- `LLM_URL`: url of the backend
|
||||
"""
|
||||
if name is None:
|
||||
name = settings.LLM_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.llm.llm_{name}"
|
||||
importlib.import_module(module_name)
|
||||
cls.ensure_nltk()
|
||||
return cls._registry[name](model_name)
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
"""
|
||||
Get the currently set model name
|
||||
"""
|
||||
return self._get_model_name()
|
||||
|
||||
def _get_model_name(self) -> str:
|
||||
pass
|
||||
|
||||
def set_model_name(self, model_name: str) -> bool:
|
||||
"""
|
||||
Update the model name with the provided model name
|
||||
"""
|
||||
return self._set_model_name(model_name)
|
||||
|
||||
def _set_model_name(self, model_name: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
"""
|
||||
Return the LLM Prompt template
|
||||
"""
|
||||
return """
|
||||
### Human:
|
||||
{instruct}
|
||||
|
||||
{text}
|
||||
|
||||
### Assistant:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
name = self.__class__.__name__
|
||||
self.m_generate = self.m_generate.labels(name)
|
||||
self.m_generate_call = self.m_generate_call.labels(name)
|
||||
self.m_generate_success = self.m_generate_success.labels(name)
|
||||
self.m_generate_failure = self.m_generate_failure.labels(name)
|
||||
self.detokenizer = nltk.tokenize.treebank.TreebankWordDetokenizer()
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
"""
|
||||
Return the tokenizer instance used by LLM
|
||||
"""
|
||||
return self._get_tokenizer()
|
||||
|
||||
def _get_tokenizer(self):
|
||||
pass
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
logger: reflector_logger,
|
||||
gen_schema: dict | None = None,
|
||||
gen_cfg: GenerationConfig | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
logger.info("LLM generate", prompt=repr(prompt))
|
||||
|
||||
if gen_cfg:
|
||||
gen_cfg = gen_cfg.to_dict()
|
||||
self.m_generate_call.inc()
|
||||
try:
|
||||
with self.m_generate.time():
|
||||
result = await retry(self._generate)(
|
||||
prompt=prompt,
|
||||
gen_schema=gen_schema,
|
||||
gen_cfg=gen_cfg,
|
||||
**kwargs,
|
||||
)
|
||||
self.m_generate_success.inc()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to call llm after retrying")
|
||||
self.m_generate_failure.inc()
|
||||
raise
|
||||
|
||||
logger.debug("LLM result [raw]", result=repr(result))
|
||||
if isinstance(result, str):
|
||||
result = self._parse_json(result)
|
||||
logger.debug("LLM result [parsed]", result=repr(result))
|
||||
|
||||
return result
|
||||
|
||||
async def completion(
|
||||
self, messages: list, logger: reflector_logger, **kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
Use /v1/chat/completion Open-AI compatible endpoint from the URL
|
||||
It's up to the user to validate anything or transform the result
|
||||
"""
|
||||
logger.info("LLM completions", messages=messages)
|
||||
|
||||
try:
|
||||
with self.m_generate.time():
|
||||
result = await retry(self._completion)(messages=messages, **kwargs)
|
||||
self.m_generate_success.inc()
|
||||
except Exception:
|
||||
logger.exception("Failed to call llm after retrying")
|
||||
self.m_generate_failure.inc()
|
||||
raise
|
||||
|
||||
logger.debug("LLM completion result", result=repr(result))
|
||||
return result
|
||||
|
||||
def ensure_casing(self, title: str) -> str:
|
||||
"""
|
||||
LLM takes care of word casing, but in rare cases this
|
||||
can falter. This is a fallback to ensure the casing of
|
||||
topics is in a proper format.
|
||||
|
||||
We select nouns, verbs and adjectives and check if camel
|
||||
casing is present and fix it, if not. Will not perform
|
||||
any other changes.
|
||||
"""
|
||||
tokens = nltk.word_tokenize(title)
|
||||
pos_tags = nltk.pos_tag(tokens)
|
||||
camel_cased = []
|
||||
|
||||
whitelisted_pos_tags = [
|
||||
"NN",
|
||||
"NNS",
|
||||
"NNP",
|
||||
"NNPS", # Noun POS
|
||||
"VB",
|
||||
"VBD",
|
||||
"VBG",
|
||||
"VBN",
|
||||
"VBP",
|
||||
"VBZ", # Verb POS
|
||||
"JJ",
|
||||
"JJR",
|
||||
"JJS", # Adjective POS
|
||||
]
|
||||
|
||||
# If at all there is an exception, do not block other reflector
|
||||
# processes. Return the LLM generated title, at the least.
|
||||
try:
|
||||
for word, pos in pos_tags:
|
||||
if pos in whitelisted_pos_tags and word[0].islower():
|
||||
camel_cased.append(word[0].upper() + word[1:])
|
||||
else:
|
||||
camel_cased.append(word)
|
||||
modified_title = self.detokenizer.detokenize(camel_cased)
|
||||
|
||||
# Irrespective of casing changes, the starting letter
|
||||
# of title is always upper-cased
|
||||
title = modified_title[0].upper() + modified_title[1:]
|
||||
except Exception as e:
|
||||
reflector_logger.info(
|
||||
f"Failed to ensure casing on {title=} with exception : {str(e)}"
|
||||
)
|
||||
|
||||
return title
|
||||
|
||||
def trim_title(self, title: str) -> str:
|
||||
"""
|
||||
List of manual trimming to the title.
|
||||
|
||||
Longer titles are prone to run into A prefix of phrases that don't
|
||||
really add any descriptive information and in some cases, this
|
||||
behaviour can be repeated for several consecutive topics. Trim the
|
||||
titles to maintain quality of titles.
|
||||
"""
|
||||
phrases_to_remove = ["Discussing", "Discussion on", "Discussion about"]
|
||||
try:
|
||||
pattern = (
|
||||
r"\b(?:"
|
||||
+ "|".join(re.escape(phrase) for phrase in phrases_to_remove)
|
||||
+ r")\b"
|
||||
)
|
||||
title = re.sub(pattern, "", title, flags=re.IGNORECASE)
|
||||
except Exception as e:
|
||||
reflector_logger.info(f"Failed to trim {title=} with exception : {str(e)}")
|
||||
return title
|
||||
|
||||
async def _generate(
|
||||
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _completion(
|
||||
self, messages: list, logger: reflector_logger, **kwargs
|
||||
) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
def _parse_json(self, result: str) -> dict:
|
||||
result = result.strip()
|
||||
# try detecting code block if exist
|
||||
# starts with ```json\n, ends with ```
|
||||
# or starts with ```\n, ends with ```
|
||||
# or starts with \n```javascript\n, ends with ```
|
||||
|
||||
regex = r"```(json|javascript|)?(.*)```"
|
||||
matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
|
||||
if matches:
|
||||
result = matches[0][1]
|
||||
|
||||
else:
|
||||
# maybe the prompt has been started with ```json
|
||||
# so if text ends with ```, just remove it and use it as json
|
||||
if result.endswith("```"):
|
||||
result = result[:-3]
|
||||
|
||||
return json.loads(result.strip())
|
||||
|
||||
def text_token_threshold(self, task_params: TaskParams | None) -> int:
|
||||
"""
|
||||
Choose the token size to set as the threshold to pack the LLM calls
|
||||
"""
|
||||
buffer_token_size = 100
|
||||
default_output_tokens = 1000
|
||||
context_window = self.tokenizer.model_max_length
|
||||
tokens = self.tokenizer.tokenize(
|
||||
self.create_prompt(instruct=task_params.instruct, text="")
|
||||
)
|
||||
threshold = context_window - len(tokens) - buffer_token_size
|
||||
if task_params.gen_cfg:
|
||||
threshold -= task_params.gen_cfg.max_new_tokens
|
||||
else:
|
||||
threshold -= default_output_tokens
|
||||
return threshold
|
||||
|
||||
def split_corpus(
|
||||
self,
|
||||
corpus: str,
|
||||
task_params: TaskParams,
|
||||
token_threshold: int | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Split the input to the LLM due to CUDA memory limitations and LLM context window
|
||||
restrictions.
|
||||
|
||||
Accumulate tokens from full sentences till threshold and yield accumulated
|
||||
tokens. Reset accumulation when threshold is reached and repeat process.
|
||||
"""
|
||||
if not token_threshold:
|
||||
token_threshold = self.text_token_threshold(task_params=task_params)
|
||||
|
||||
accumulated_tokens = []
|
||||
accumulated_sentences = []
|
||||
accumulated_token_count = 0
|
||||
corpus_sentences = nltk.sent_tokenize(corpus)
|
||||
|
||||
for sentence in corpus_sentences:
|
||||
tokens = self.tokenizer.tokenize(sentence)
|
||||
if accumulated_token_count + len(tokens) <= token_threshold:
|
||||
accumulated_token_count += len(tokens)
|
||||
accumulated_tokens.extend(tokens)
|
||||
accumulated_sentences.append(sentence)
|
||||
else:
|
||||
yield "".join(accumulated_sentences)
|
||||
accumulated_token_count = len(tokens)
|
||||
accumulated_tokens = tokens
|
||||
accumulated_sentences = [sentence]
|
||||
|
||||
if accumulated_tokens:
|
||||
yield " ".join(accumulated_sentences)
|
||||
|
||||
def create_prompt(self, instruct: str, text: str) -> str:
|
||||
"""
|
||||
Create a consumable prompt based on the prompt template
|
||||
"""
|
||||
return self.template.format(instruct=instruct, text=text)
|
||||
@@ -1,151 +0,0 @@
|
||||
import httpx
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.logger import logger as reflector_logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
from transformers import AutoTokenizer, GenerationConfig
|
||||
|
||||
|
||||
class ModalLLM(LLM):
|
||||
def __init__(self, model_name: str | None = None):
|
||||
super().__init__()
|
||||
self.timeout = settings.LLM_TIMEOUT
|
||||
self.llm_url = settings.LLM_URL + "/llm"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
|
||||
}
|
||||
self._set_model_name(model_name if model_name else settings.DEFAULT_LLM)
|
||||
|
||||
@property
|
||||
def supported_models(self):
|
||||
"""
|
||||
List of currently supported models on this GPU platform
|
||||
"""
|
||||
# TODO: Query the specific GPU platform
|
||||
# Replace this with a HTTP call
|
||||
return [
|
||||
"lmsys/vicuna-13b-v1.5",
|
||||
"HuggingFaceH4/zephyr-7b-alpha",
|
||||
"NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
]
|
||||
|
||||
async def _generate(
|
||||
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
|
||||
):
|
||||
json_payload = {"prompt": prompt}
|
||||
if gen_schema:
|
||||
json_payload["gen_schema"] = gen_schema
|
||||
if gen_cfg:
|
||||
json_payload["gen_cfg"] = gen_cfg
|
||||
|
||||
# Handing over generation of the final summary to Zephyr model
|
||||
# but replacing the Vicuna model will happen after more testing
|
||||
# TODO: Create a mapping of model names and cloud deployments
|
||||
if self.model_name == "HuggingFaceH4/zephyr-7b-alpha":
|
||||
self.llm_url = settings.ZEPHYR_LLM_URL + "/llm"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await retry(client.post)(
|
||||
self.llm_url,
|
||||
headers=self.headers,
|
||||
json=json_payload,
|
||||
timeout=self.timeout,
|
||||
retry_timeout=60 * 5,
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
text = response.json()["text"]
|
||||
return text
|
||||
|
||||
async def _completion(self, messages: list, **kwargs) -> dict:
|
||||
kwargs.setdefault("temperature", 0.3)
|
||||
kwargs.setdefault("max_tokens", 2048)
|
||||
kwargs.setdefault("stream", False)
|
||||
kwargs.setdefault("repetition_penalty", 1)
|
||||
kwargs.setdefault("top_p", 1)
|
||||
kwargs.setdefault("top_k", -1)
|
||||
kwargs.setdefault("min_p", 0.05)
|
||||
data = {"messages": messages, "model": self.model_name, **kwargs}
|
||||
|
||||
if self.model_name == "NousResearch/Hermes-3-Llama-3.1-8B":
|
||||
self.llm_url = settings.HERMES_3_8B_LLM_URL + "/v1/chat/completions"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await retry(client.post)(
|
||||
self.llm_url,
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
timeout=self.timeout,
|
||||
retry_timeout=60 * 5,
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _set_model_name(self, model_name: str) -> bool:
|
||||
"""
|
||||
Set the model name
|
||||
"""
|
||||
# Abort, if the model is not supported
|
||||
if model_name not in self.supported_models:
|
||||
reflector_logger.info(
|
||||
f"Attempted to change {model_name=}, but is not supported."
|
||||
f"Setting model and tokenizer failed !"
|
||||
)
|
||||
return False
|
||||
# Abort, if the model is already set
|
||||
elif hasattr(self, "model_name") and model_name == self._get_model_name():
|
||||
reflector_logger.info("No change in model. Setting model skipped.")
|
||||
return False
|
||||
# Update model name and tokenizer
|
||||
self.model_name = model_name
|
||||
self.llm_tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name, cache_dir=settings.CACHE_DIR
|
||||
)
|
||||
reflector_logger.info(f"Model set to {model_name=}. Tokenizer updated.")
|
||||
return True
|
||||
|
||||
def _get_tokenizer(self) -> AutoTokenizer:
|
||||
"""
|
||||
Return the currently used LLM tokenizer
|
||||
"""
|
||||
return self.llm_tokenizer
|
||||
|
||||
def _get_model_name(self) -> str:
|
||||
"""
|
||||
Return the current model name from the instance details
|
||||
"""
|
||||
return self.model_name
|
||||
|
||||
|
||||
LLM.register("modal", ModalLLM)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from reflector.logger import logger
|
||||
|
||||
async def main():
|
||||
llm = ModalLLM()
|
||||
prompt = llm.create_prompt(
|
||||
instruct="Complete the following task",
|
||||
text="Tell me a joke about programming.",
|
||||
)
|
||||
result = await llm.generate(prompt=prompt, logger=logger)
|
||||
print(result)
|
||||
|
||||
gen_schema = {
|
||||
"type": "object",
|
||||
"properties": {"response": {"type": "string"}},
|
||||
}
|
||||
|
||||
result = await llm.generate(prompt=prompt, gen_schema=gen_schema, logger=logger)
|
||||
print(result)
|
||||
|
||||
gen_cfg = GenerationConfig(max_new_tokens=150)
|
||||
result = await llm.generate(
|
||||
prompt=prompt, gen_cfg=gen_cfg, gen_schema=gen_schema, logger=logger
|
||||
)
|
||||
print(result)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
@@ -1,29 +0,0 @@
|
||||
import httpx
|
||||
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class OobaboogaLLM(LLM):
|
||||
def __init__(self, model_name: str | None = None):
|
||||
super().__init__()
|
||||
|
||||
async def _generate(
|
||||
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
|
||||
):
|
||||
json_payload = {"prompt": prompt}
|
||||
if gen_schema:
|
||||
json_payload["gen_schema"] = gen_schema
|
||||
if gen_cfg:
|
||||
json_payload.update(gen_cfg)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
settings.LLM_URL,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=json_payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
LLM.register("oobabooga", OobaboogaLLM)
|
||||
@@ -1,48 +0,0 @@
|
||||
import httpx
|
||||
from transformers import GenerationConfig
|
||||
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class OpenAILLM(LLM):
|
||||
def __init__(self, model_name: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.openai_key = settings.LLM_OPENAI_KEY
|
||||
self.openai_url = settings.LLM_URL
|
||||
self.openai_model = settings.LLM_OPENAI_MODEL
|
||||
self.openai_temperature = settings.LLM_OPENAI_TEMPERATURE
|
||||
self.timeout = settings.LLM_TIMEOUT
|
||||
self.max_tokens = settings.LLM_MAX_TOKENS
|
||||
logger.info(f"LLM use openai backend at {self.openai_url}")
|
||||
|
||||
async def _generate(
|
||||
self,
|
||||
prompt: str,
|
||||
gen_schema: dict | None,
|
||||
gen_cfg: GenerationConfig | None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.openai_key}",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
self.openai_url,
|
||||
headers=headers,
|
||||
json={
|
||||
"model": self.openai_model,
|
||||
"prompt": prompt,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.openai_temperature,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result["choices"][0]["text"]
|
||||
|
||||
|
||||
LLM.register("openai", OpenAILLM)
|
||||
@@ -1,219 +0,0 @@
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import GenerationConfig
|
||||
|
||||
|
||||
class TaskParams(BaseModel, arbitrary_types_allowed=True):
|
||||
instruct: str
|
||||
gen_cfg: Optional[GenerationConfig] = None
|
||||
gen_schema: Optional[dict] = None
|
||||
|
||||
|
||||
T = TypeVar("T", bound="LLMTaskParams")
|
||||
|
||||
|
||||
class LLMTaskParams:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, task, klass) -> None:
|
||||
cls._registry[task] = klass
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, task: str) -> T:
|
||||
return cls._registry[task]()
|
||||
|
||||
@property
|
||||
def task_params(self) -> TaskParams | None:
|
||||
"""
|
||||
Fetch the task related parameters
|
||||
"""
|
||||
return self._get_task_params()
|
||||
|
||||
def _get_task_params(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class FinalLongSummaryParams(LLMTaskParams):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._gen_cfg = GenerationConfig(
|
||||
max_new_tokens=1000, num_beams=3, do_sample=True, temperature=0.3
|
||||
)
|
||||
self._instruct = """
|
||||
Take the key ideas and takeaways from the text and create a short
|
||||
summary. Be sure to keep the length of the response to a minimum.
|
||||
Do not include trivial information in the summary.
|
||||
"""
|
||||
self._schema = {
|
||||
"type": "object",
|
||||
"properties": {"long_summary": {"type": "string"}},
|
||||
}
|
||||
self._task_params = TaskParams(
|
||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
||||
)
|
||||
|
||||
def _get_task_params(self) -> TaskParams:
|
||||
"""gen_schema
|
||||
Return the parameters associated with a specific LLM task
|
||||
"""
|
||||
return self._task_params
|
||||
|
||||
|
||||
class FinalShortSummaryParams(LLMTaskParams):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._gen_cfg = GenerationConfig(
|
||||
max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
|
||||
)
|
||||
self._instruct = """
|
||||
Take the key ideas and takeaways from the text and create a short
|
||||
summary. Be sure to keep the length of the response to a minimum.
|
||||
Do not include trivial information in the summary.
|
||||
"""
|
||||
self._schema = {
|
||||
"type": "object",
|
||||
"properties": {"short_summary": {"type": "string"}},
|
||||
}
|
||||
self._task_params = TaskParams(
|
||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
||||
)
|
||||
|
||||
def _get_task_params(self) -> TaskParams:
|
||||
"""
|
||||
Return the parameters associated with a specific LLM task
|
||||
"""
|
||||
return self._task_params
|
||||
|
||||
|
||||
class FinalTitleParams(LLMTaskParams):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._gen_cfg = GenerationConfig(
|
||||
max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
|
||||
)
|
||||
self._instruct = """
|
||||
Combine the following individual titles into one single short title that
|
||||
condenses the essence of all titles.
|
||||
"""
|
||||
self._schema = {
|
||||
"type": "object",
|
||||
"properties": {"title": {"type": "string"}},
|
||||
}
|
||||
self._task_params = TaskParams(
|
||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
||||
)
|
||||
|
||||
def _get_task_params(self) -> TaskParams:
|
||||
"""
|
||||
Return the parameters associated with a specific LLM task
|
||||
"""
|
||||
return self._task_params
|
||||
|
||||
|
||||
class TopicParams(LLMTaskParams):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._gen_cfg = GenerationConfig(
|
||||
max_new_tokens=500, num_beams=6, do_sample=True, temperature=0.9
|
||||
)
|
||||
self._instruct = """
|
||||
Create a JSON object as response.The JSON object must have 2 fields:
|
||||
i) title and ii) summary.
|
||||
For the title field, generate a very detailed and self-explanatory
|
||||
title for the given text. Let the title be as descriptive as possible.
|
||||
For the summary field, summarize the given text in a maximum of
|
||||
two sentences.
|
||||
"""
|
||||
self._schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"summary": {"type": "string"},
|
||||
},
|
||||
}
|
||||
self._task_params = TaskParams(
|
||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
||||
)
|
||||
|
||||
def _get_task_params(self) -> TaskParams:
|
||||
"""
|
||||
Return the parameters associated with a specific LLM task
|
||||
"""
|
||||
return self._task_params
|
||||
|
||||
|
||||
class BulletedSummaryParams(LLMTaskParams):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._gen_cfg = GenerationConfig(
|
||||
max_new_tokens=800,
|
||||
num_beams=1,
|
||||
do_sample=True,
|
||||
temperature=0.2,
|
||||
early_stopping=True,
|
||||
)
|
||||
self._instruct = """
|
||||
Given a meeting transcript, extract the key things discussed in the
|
||||
form of a list.
|
||||
|
||||
While generating the response, follow the constraints mentioned below.
|
||||
|
||||
Summary constraints:
|
||||
i) Do not add new content, except to fix spelling or punctuation.
|
||||
ii) Do not add any prefixes or numbering in the response.
|
||||
iii) The summarization should be as information dense as possible.
|
||||
iv) Do not add any additional sections like Note, Conclusion, etc. in
|
||||
the response.
|
||||
|
||||
Response format:
|
||||
i) The response should be in the form of a bulleted list.
|
||||
ii) Iteratively merge all the relevant paragraphs together to keep the
|
||||
number of paragraphs to a minimum.
|
||||
iii) Remove any unfinished sentences from the final response.
|
||||
iv) Do not include narrative or reporting clauses.
|
||||
v) Use "*" as the bullet icon.
|
||||
"""
|
||||
self._task_params = TaskParams(
|
||||
instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
|
||||
)
|
||||
|
||||
def _get_task_params(self) -> TaskParams:
|
||||
"""gen_schema
|
||||
Return the parameters associated with a specific LLM task
|
||||
"""
|
||||
return self._task_params
|
||||
|
||||
|
||||
class MergedSummaryParams(LLMTaskParams):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._gen_cfg = GenerationConfig(
|
||||
max_new_tokens=600,
|
||||
num_beams=1,
|
||||
do_sample=True,
|
||||
temperature=0.2,
|
||||
early_stopping=True,
|
||||
)
|
||||
self._instruct = """
|
||||
Given the key points of a meeting, summarize the points to describe the
|
||||
meeting in the form of paragraphs.
|
||||
"""
|
||||
self._task_params = TaskParams(
|
||||
instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
|
||||
)
|
||||
|
||||
def _get_task_params(self) -> TaskParams:
|
||||
"""gen_schema
|
||||
Return the parameters associated with a specific LLM task
|
||||
"""
|
||||
return self._task_params
|
||||
|
||||
|
||||
LLMTaskParams.register("topic", TopicParams)
|
||||
LLMTaskParams.register("final_title", FinalTitleParams)
|
||||
LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
|
||||
LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
|
||||
LLMTaskParams.register("bullet_summary", BulletedSummaryParams)
|
||||
LLMTaskParams.register("merged_summary", MergedSummaryParams)
|
||||
@@ -16,8 +16,10 @@ import functools
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import boto3
|
||||
from celery import chord, group, shared_task
|
||||
from celery import chord, current_task, group, shared_task
|
||||
from pydantic import BaseModel
|
||||
from structlog import BoundLogger as Logger
|
||||
|
||||
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
||||
from reflector.db.recordings import recordings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
@@ -45,7 +47,7 @@ from reflector.processors import (
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorProcessor,
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput
|
||||
@@ -61,7 +63,6 @@ from reflector.zulip import (
|
||||
send_message_to_zulip,
|
||||
update_zulip_message,
|
||||
)
|
||||
from structlog import BoundLogger as Logger
|
||||
|
||||
|
||||
def asynctask(f):
|
||||
@@ -111,16 +112,29 @@ def get_transcript(func):
|
||||
Decorator to fetch the transcript from the database from the first argument
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(**kwargs):
|
||||
transcript_id = kwargs.pop("transcript_id")
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
|
||||
if not transcript:
|
||||
raise Exception("Transcript {transcript_id} not found")
|
||||
|
||||
# Enhanced logger with Celery task context
|
||||
tlogger = logger.bind(transcript_id=transcript.id)
|
||||
if current_task:
|
||||
tlogger = tlogger.bind(
|
||||
task_id=current_task.request.id,
|
||||
task_name=current_task.name,
|
||||
worker_hostname=current_task.request.hostname,
|
||||
task_retries=current_task.request.retries,
|
||||
transcript_id=transcript_id,
|
||||
)
|
||||
|
||||
try:
|
||||
return await func(transcript=transcript, logger=tlogger, **kwargs)
|
||||
result = await func(transcript=transcript, logger=tlogger, **kwargs)
|
||||
return result
|
||||
except Exception as exc:
|
||||
tlogger.error("Pipeline error", exc_info=exc)
|
||||
tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
@@ -347,7 +361,7 @@ class PipelineMainLive(PipelineMainBase):
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
||||
TranscriptTranslatorAutoProcessor.as_threaded(callback=self.on_transcript),
|
||||
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
||||
]
|
||||
pipeline = Pipeline(*processors)
|
||||
|
||||
@@ -18,6 +18,7 @@ During its lifecycle, it will emit the following status:
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import Pipeline
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401
|
||||
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
|
||||
from .transcript_topic_detector import TranscriptTopicDetectorProcessor # noqa: F401
|
||||
from .transcript_translator import TranscriptTranslatorProcessor # noqa: F401
|
||||
from .transcript_translator_auto import TranscriptTranslatorAutoProcessor # noqa: F401
|
||||
from .types import ( # noqa: F401
|
||||
AudioFile,
|
||||
FinalLongSummary,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from reflector.processors.base import Processor
|
||||
import av
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
|
||||
|
||||
class AudioChunkerProcessor(Processor):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import httpx
|
||||
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||
@@ -9,12 +10,17 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
|
||||
INPUT_TYPE = AudioDiarizationInput
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not settings.DIARIZATION_URL:
|
||||
raise Exception(
|
||||
"DIARIZATION_URL required to use AudioDiarizationModalProcessor"
|
||||
)
|
||||
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
|
||||
}
|
||||
self.modal_api_key = modal_api_key
|
||||
self.headers = {}
|
||||
if self.modal_api_key:
|
||||
self.headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||
|
||||
async def _diarize(self, data: AudioDiarizationInput):
|
||||
# Gather diarization data
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import AudioFile
|
||||
import io
|
||||
from time import monotonic_ns
|
||||
from uuid import uuid4
|
||||
import io
|
||||
|
||||
import av
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import AudioFile
|
||||
|
||||
|
||||
class AudioMergeProcessor(Processor):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import AudioFile, Transcript
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ API will be a POST request to TRANSCRIPT_URL:
|
||||
"""
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
@@ -20,16 +21,20 @@ from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
def __init__(self, modal_api_key: str):
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
super().__init__()
|
||||
if not settings.TRANSCRIPT_URL:
|
||||
raise Exception(
|
||||
"TRANSCRIPT_URL required to use AudioTranscriptModalProcessor"
|
||||
)
|
||||
self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
|
||||
self.timeout = settings.TRANSCRIPT_TIMEOUT
|
||||
self.api_key = settings.TRANSCRIPT_MODAL_API_KEY
|
||||
self.modal_api_key = modal_api_key
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
async with AsyncOpenAI(
|
||||
base_url=self.transcript_url,
|
||||
api_key=self.api_key,
|
||||
api_key=self.modal_api_key,
|
||||
timeout=self.timeout,
|
||||
) as client:
|
||||
self.logger.debug(f"Try to transcribe audio {data.name}")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
|
||||
@@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
|
||||
from prometheus_client import Counter, Gauge, Histogram
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@ from reflector.llm import LLM
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.summary.summary_builder import SummaryBuilder
|
||||
from reflector.processors.types import FinalLongSummary, FinalShortSummary, TitleSummary
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class TranscriptFinalSummaryProcessor(Processor):
|
||||
@@ -16,14 +17,14 @@ class TranscriptFinalSummaryProcessor(Processor):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = transcript
|
||||
self.chunks: list[TitleSummary] = []
|
||||
self.llm = LLM.get_instance(model_name="NousResearch/Hermes-3-Llama-3.1-8B")
|
||||
self.llm = LLM(settings=settings)
|
||||
self.builder = None
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
self.chunks.append(data)
|
||||
|
||||
async def get_summary_builder(self, text) -> SummaryBuilder:
|
||||
builder = SummaryBuilder(self.llm)
|
||||
builder = SummaryBuilder(self.llm, logger=self.logger)
|
||||
builder.set_transcript(text)
|
||||
await builder.identify_participants()
|
||||
await builder.generate_summary()
|
||||
|
||||
@@ -1,67 +1,72 @@
|
||||
from reflector.llm import LLM, LLMTaskParams
|
||||
from textwrap import dedent
|
||||
|
||||
from reflector.llm import LLM
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import FinalTitle, TitleSummary
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.text import clean_title
|
||||
|
||||
TITLE_PROMPT = dedent(
|
||||
"""
|
||||
Generate a concise title for this meeting based on the following topic titles.
|
||||
Ignore casual conversation, greetings, or administrative matters.
|
||||
|
||||
The title must:
|
||||
- Be maximum 10 words
|
||||
- Use noun phrases when possible (e.g., "Q1 Budget Review" not "Reviewing the Q1 Budget")
|
||||
- Avoid generic terms like "Team Meeting" or "Discussion"
|
||||
|
||||
If multiple unrelated topics were discussed, prioritize the most significant one.
|
||||
or create a compound title (e.g., "Product Launch and Budget Planning").
|
||||
|
||||
<topics_discussed>
|
||||
{titles}
|
||||
</topics_discussed>
|
||||
|
||||
Do not explain, just output the meeting title as a single line.
|
||||
"""
|
||||
).strip()
|
||||
|
||||
|
||||
class TranscriptFinalTitleProcessor(Processor):
|
||||
"""
|
||||
Assemble all summary into a line-based json
|
||||
Generate a final title from topic titles using LlamaIndex
|
||||
"""
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
OUTPUT_TYPE = FinalTitle
|
||||
TASK = "final_title"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.chunks: list[TitleSummary] = []
|
||||
self.llm = LLM.get_instance()
|
||||
self.params = LLMTaskParams.get_instance(self.TASK).task_params
|
||||
self.llm = LLM(settings=settings, temperature=0.5, max_tokens=200)
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
self.chunks.append(data)
|
||||
|
||||
async def get_title(self, text: str) -> dict:
|
||||
async def get_title(self, accumulated_titles: str) -> str:
|
||||
"""
|
||||
Generate a title for the whole recording
|
||||
Generate a title for the whole recording using LLM
|
||||
"""
|
||||
chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
|
||||
prompt = TITLE_PROMPT.format(titles=accumulated_titles)
|
||||
response = await self.llm.get_response(
|
||||
prompt,
|
||||
[accumulated_titles],
|
||||
tone_name="Title generator",
|
||||
)
|
||||
|
||||
if len(chunks) == 1:
|
||||
chunk = chunks[0]
|
||||
prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
|
||||
title_result = await self.llm.generate(
|
||||
prompt=prompt,
|
||||
gen_schema=self.params.gen_schema,
|
||||
gen_cfg=self.params.gen_cfg,
|
||||
logger=self.logger,
|
||||
)
|
||||
return title_result
|
||||
else:
|
||||
accumulated_titles = ""
|
||||
for chunk in chunks:
|
||||
prompt = self.llm.create_prompt(
|
||||
instruct=self.params.instruct, text=chunk
|
||||
)
|
||||
title_result = await self.llm.generate(
|
||||
prompt=prompt,
|
||||
gen_schema=self.params.gen_schema,
|
||||
gen_cfg=self.params.gen_cfg,
|
||||
logger=self.logger,
|
||||
)
|
||||
accumulated_titles += title_result["summary"]
|
||||
self.logger.info(f"Generated title response: {response}")
|
||||
|
||||
return await self.get_title(accumulated_titles)
|
||||
return response
|
||||
|
||||
async def _flush(self):
|
||||
if not self.chunks:
|
||||
self.logger.warning("No summary to output")
|
||||
return
|
||||
|
||||
accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
|
||||
title_result = await self.get_title(accumulated_titles)
|
||||
final_title = self.llm.trim_title(title_result["title"])
|
||||
final_title = self.llm.ensure_casing(final_title)
|
||||
accumulated_titles = "\n".join([f"- {chunk.title}" for chunk in self.chunks])
|
||||
title = await self.get_title(accumulated_titles)
|
||||
title = clean_title(title)
|
||||
|
||||
final_title = FinalTitle(title=final_title)
|
||||
final_title = FinalTitle(title=title)
|
||||
await self.emit(final_title)
|
||||
|
||||
@@ -1,7 +1,41 @@
|
||||
from reflector.llm import LLM, LLMTaskParams
|
||||
from textwrap import dedent
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from reflector.llm import LLM
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import TitleSummary, Transcript
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.text import clean_title
|
||||
|
||||
TOPIC_PROMPT = dedent(
|
||||
"""
|
||||
Analyze the following transcript segment and extract the main topic being discussed.
|
||||
Focus on the substantive content and ignore small talk or administrative chatter.
|
||||
|
||||
Create a title that:
|
||||
- Captures the specific subject matter being discussed
|
||||
- Is descriptive and self-explanatory
|
||||
- Uses professional language
|
||||
- Is specific rather than generic
|
||||
|
||||
For the summary:
|
||||
- Summarize the key points in maximum two sentences
|
||||
- Focus on what was discussed, decided, or accomplished
|
||||
- Be concise but informative
|
||||
|
||||
<transcript>
|
||||
{text}
|
||||
</transcript>
|
||||
"""
|
||||
).strip()
|
||||
|
||||
|
||||
class TopicResponse(BaseModel):
|
||||
"""Structured response for topic detection"""
|
||||
|
||||
title: str = Field(description="A descriptive title for the topic being discussed")
|
||||
summary: str = Field(description="A concise 1-2 sentence summary of the discussion")
|
||||
|
||||
|
||||
class TranscriptTopicDetectorProcessor(Processor):
|
||||
@@ -11,7 +45,6 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
|
||||
INPUT_TYPE = Transcript
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
TASK = "topic"
|
||||
|
||||
def __init__(
|
||||
self, min_transcript_length: int = int(settings.MIN_TRANSCRIPT_LENGTH), **kwargs
|
||||
@@ -19,8 +52,7 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = None
|
||||
self.min_transcript_length = min_transcript_length
|
||||
self.llm = LLM.get_instance()
|
||||
self.params = LLMTaskParams.get_instance(self.TASK).task_params
|
||||
self.llm = LLM(settings=settings, temperature=0.9, max_tokens=500)
|
||||
|
||||
async def _push(self, data: Transcript):
|
||||
if self.transcript is None:
|
||||
@@ -34,18 +66,15 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
return
|
||||
await self.flush()
|
||||
|
||||
async def get_topic(self, text: str) -> dict:
|
||||
async def get_topic(self, text: str) -> TopicResponse:
|
||||
"""
|
||||
Generate a topic and description for a transcription excerpt
|
||||
Generate a topic and description for a transcription excerpt using LLM
|
||||
"""
|
||||
prompt = self.llm.create_prompt(instruct=self.params.instruct, text=text)
|
||||
topic_result = await self.llm.generate(
|
||||
prompt=prompt,
|
||||
gen_schema=self.params.gen_schema,
|
||||
gen_cfg=self.params.gen_cfg,
|
||||
logger=self.logger,
|
||||
prompt = TOPIC_PROMPT.format(text=text)
|
||||
response = await self.llm.get_structured_response(
|
||||
prompt, [text], TopicResponse, tone_name="Topic analyzer"
|
||||
)
|
||||
return topic_result
|
||||
return response
|
||||
|
||||
async def _flush(self):
|
||||
if not self.transcript:
|
||||
@@ -53,13 +82,13 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
|
||||
text = self.transcript.text
|
||||
self.logger.info(f"Topic detector got {len(text)} length transcript")
|
||||
|
||||
topic_result = await self.get_topic(text=text)
|
||||
title = self.llm.trim_title(topic_result["title"])
|
||||
title = self.llm.ensure_casing(title)
|
||||
title = clean_title(topic_result.title)
|
||||
|
||||
summary = TitleSummary(
|
||||
title=title,
|
||||
summary=topic_result["summary"],
|
||||
summary=topic_result.summary,
|
||||
timestamp=self.transcript.timestamp,
|
||||
duration=self.transcript.duration,
|
||||
transcript=self.transcript,
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import httpx
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import Transcript, TranslationLanguages
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
from reflector.processors.types import Transcript
|
||||
|
||||
|
||||
class TranscriptTranslatorProcessor(Processor):
|
||||
@@ -12,60 +9,27 @@ class TranscriptTranslatorProcessor(Processor):
|
||||
|
||||
INPUT_TYPE = Transcript
|
||||
OUTPUT_TYPE = Transcript
|
||||
TASK = "translate"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = None
|
||||
self.translate_url = settings.TRANSLATE_URL
|
||||
self.timeout = settings.TRANSLATE_TIMEOUT
|
||||
self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
|
||||
|
||||
async def _push(self, data: Transcript):
|
||||
self.transcript = data
|
||||
await self.flush()
|
||||
|
||||
async def get_translation(self, text: str) -> str | None:
|
||||
# FIXME this should be a processor after, as each user may want
|
||||
# different languages
|
||||
|
||||
source_language = self.get_pref("audio:source_language", "en")
|
||||
target_language = self.get_pref("audio:target_language", "en")
|
||||
if source_language == target_language:
|
||||
return
|
||||
|
||||
languages = TranslationLanguages()
|
||||
# Only way to set the target should be the UI element like dropdown.
|
||||
# Hence, this assert should never fail.
|
||||
assert languages.is_supported(target_language)
|
||||
self.logger.debug(f"Try to translate {text=}")
|
||||
json_payload = {
|
||||
"text": text,
|
||||
"source_language": source_language,
|
||||
"target_language": target_language,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await retry(client.post)(
|
||||
self.translate_url + "/translate",
|
||||
headers=self.headers,
|
||||
params=json_payload,
|
||||
timeout=self.timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["text"]
|
||||
|
||||
# Sanity check for translation status in the result
|
||||
if target_language in result:
|
||||
translation = result[target_language]
|
||||
self.logger.debug(f"Translation response: {text=}, {translation=}")
|
||||
return translation
|
||||
async def _translate(self, text: str) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _flush(self):
|
||||
if not self.transcript:
|
||||
return
|
||||
self.transcript.translation = await self.get_translation(
|
||||
text=self.transcript.text
|
||||
)
|
||||
|
||||
source_language = self.get_pref("audio:source_language", "en")
|
||||
target_language = self.get_pref("audio:target_language", "en")
|
||||
if source_language == target_language:
|
||||
self.transcript.translation = None
|
||||
else:
|
||||
self.transcript.translation = await self._translate(self.transcript.text)
|
||||
|
||||
await self.emit(self.transcript)
|
||||
|
||||
32
server/reflector/processors/transcript_translator_auto.py
Normal file
32
server/reflector/processors/transcript_translator_auto.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class TranscriptTranslatorAutoProcessor(TranscriptTranslatorProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.TRANSLATION_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.transcript_translator_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `TRANSLATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "TRANSLATION_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
66
server/reflector/processors/transcript_translator_modal.py
Normal file
66
server/reflector/processors/transcript_translator_modal.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import httpx
|
||||
|
||||
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
|
||||
from reflector.processors.transcript_translator_auto import (
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.types import TranslationLanguages
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
|
||||
class TranscriptTranslatorModalProcessor(TranscriptTranslatorProcessor):
|
||||
"""
|
||||
Translate the transcript into the target language using Modal.com
|
||||
"""
|
||||
|
||||
def __init__(self, modal_api_key: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not settings.TRANSLATE_URL:
|
||||
raise Exception(
|
||||
"TRANSLATE_URL is required for TranscriptTranslatorModalProcessor"
|
||||
)
|
||||
self.translate_url = settings.TRANSLATE_URL
|
||||
self.timeout = settings.TRANSLATE_TIMEOUT
|
||||
self.modal_api_key = modal_api_key
|
||||
self.headers = {}
|
||||
if self.modal_api_key:
|
||||
self.headers["Authorization"] = f"Bearer {self.modal_api_key}"
|
||||
|
||||
async def _translate(self, text: str) -> str | None:
|
||||
source_language = self.get_pref("audio:source_language", "en")
|
||||
target_language = self.get_pref("audio:target_language", "en")
|
||||
|
||||
languages = TranslationLanguages()
|
||||
# Only way to set the target should be the UI element like dropdown.
|
||||
# Hence, this assert should never fail.
|
||||
assert languages.is_supported(target_language)
|
||||
self.logger.debug(f"Try to translate {text=}")
|
||||
json_payload = {
|
||||
"text": text,
|
||||
"source_language": source_language,
|
||||
"target_language": target_language,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await retry(client.post)(
|
||||
self.translate_url + "/translate",
|
||||
headers=self.headers,
|
||||
params=json_payload,
|
||||
timeout=self.timeout,
|
||||
follow_redirects=True,
|
||||
logger=self.logger,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["text"]
|
||||
|
||||
# Sanity check for translation status in the result
|
||||
if target_language in result:
|
||||
translation = result[target_language]
|
||||
else:
|
||||
translation = None
|
||||
self.logger.debug(f"Translation response: {text=}, {translation=}")
|
||||
return translation
|
||||
|
||||
|
||||
TranscriptTranslatorAutoProcessor.register("modal", TranscriptTranslatorModalProcessor)
|
||||
@@ -0,0 +1,14 @@
|
||||
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
|
||||
from reflector.processors.transcript_translator_auto import (
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
|
||||
|
||||
class TranscriptTranslatorPassthroughProcessor(TranscriptTranslatorProcessor):
|
||||
async def _translate(self, text: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
TranscriptTranslatorAutoProcessor.register(
|
||||
"passthrough", TranscriptTranslatorPassthroughProcessor
|
||||
)
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
|
||||
from profanityfilter import ProfanityFilter
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from reflector.redis_cache import redis_cache
|
||||
|
||||
PUNC_RE = re.compile(r"[.;:?!…]")
|
||||
|
||||
@@ -2,6 +2,7 @@ import functools
|
||||
import json
|
||||
|
||||
import redis
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
redis_clients = {}
|
||||
|
||||
@@ -8,49 +8,24 @@ class Settings(BaseSettings):
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
OPENMP_KMP_DUPLICATE_LIB_OK: bool = False
|
||||
|
||||
# CORS
|
||||
UI_BASE_URL: str = "http://localhost:3000"
|
||||
CORS_ORIGIN: str = "*"
|
||||
CORS_ALLOW_CREDENTIALS: bool = False
|
||||
|
||||
# Database
|
||||
DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
|
||||
|
||||
# local data directory (audio for no)
|
||||
# local data directory
|
||||
DATA_DIR: str = "./data"
|
||||
|
||||
# Whisper
|
||||
WHISPER_MODEL_SIZE: str = "tiny"
|
||||
WHISPER_REAL_TIME_MODEL_SIZE: str = "tiny"
|
||||
|
||||
# Summarizer
|
||||
SUMMARIZER_MODEL: str = "facebook/bart-large-cnn"
|
||||
SUMMARIZER_INPUT_ENCODING_MAX_LENGTH: int = 1024
|
||||
SUMMARIZER_MAX_LENGTH: int = 2048
|
||||
SUMMARIZER_BEAM_SIZE: int = 6
|
||||
SUMMARIZER_MAX_CHUNK_LENGTH: int = 1024
|
||||
SUMMARIZER_USING_CHUNKS: bool = True
|
||||
|
||||
# Audio
|
||||
AUDIO_BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME: str = "aggregator"
|
||||
AUDIO_AV_FOUNDATION_DEVICE_ID: int = 1
|
||||
AUDIO_CHANNELS: int = 2
|
||||
AUDIO_SAMPLING_RATE: int = 48000
|
||||
AUDIO_SAMPLING_WIDTH: int = 2
|
||||
AUDIO_BUFFER_SIZE: int = 256 * 960
|
||||
|
||||
# Audio Transcription
|
||||
# backends: whisper, modal
|
||||
TRANSCRIPT_BACKEND: str = "whisper"
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
|
||||
# Translate into the target language
|
||||
TRANSLATE_URL: str | None = None
|
||||
TRANSLATE_TIMEOUT: int = 90
|
||||
|
||||
# Audio transcription modal.com configuration
|
||||
# Audio Transcription: modal backend
|
||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Audio transcription storage
|
||||
@@ -62,42 +37,34 @@ class Settings(BaseSettings):
|
||||
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
|
||||
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
|
||||
|
||||
# Translate into the target language
|
||||
TRANSLATION_BACKEND: str = "passthrough"
|
||||
TRANSLATE_URL: str | None = None
|
||||
TRANSLATE_TIMEOUT: int = 90
|
||||
|
||||
# Translation: modal backend
|
||||
TRANSLATE_MODAL_API_KEY: str | None = None
|
||||
|
||||
# LLM
|
||||
# available backend: openai, modal, oobabooga
|
||||
LLM_BACKEND: str = "oobabooga"
|
||||
|
||||
# LLM common configuration
|
||||
LLM_MODEL: str = "microsoft/phi-4"
|
||||
LLM_URL: str | None = None
|
||||
LLM_HOST: str = "localhost"
|
||||
LLM_PORT: int = 7860
|
||||
LLM_OPENAI_KEY: str | None = None
|
||||
LLM_OPENAI_MODEL: str = "gpt-3.5-turbo"
|
||||
LLM_OPENAI_TEMPERATURE: float = 0.7
|
||||
LLM_TIMEOUT: int = 60 * 5 # take cold start into account
|
||||
LLM_MAX_TOKENS: int = 1024
|
||||
LLM_TEMPERATURE: float = 0.7
|
||||
ZEPHYR_LLM_URL: str | None = None
|
||||
HERMES_3_8B_LLM_URL: str | None = None
|
||||
|
||||
# LLM Modal configuration
|
||||
LLM_MODAL_API_KEY: str | None = None
|
||||
LLM_API_KEY: str | None = None
|
||||
LLM_CONTEXT_WINDOW: int = 16000
|
||||
|
||||
# Diarization
|
||||
DIARIZATION_ENABLED: bool = True
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
|
||||
# Diarization: modal backend
|
||||
DIARIZATION_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Sentry
|
||||
SENTRY_DSN: str | None = None
|
||||
|
||||
# User authentication (none, fief)
|
||||
# User authentication (none, jwt)
|
||||
AUTH_BACKEND: str = "none"
|
||||
|
||||
# User authentication using fief
|
||||
AUTH_FIEF_URL: str | None = None
|
||||
AUTH_FIEF_CLIENT_ID: str | None = None
|
||||
AUTH_FIEF_CLIENT_SECRET: str | None = None
|
||||
|
||||
# User authentication using JWT
|
||||
AUTH_JWT_ALGORITHM: str = "RS256"
|
||||
AUTH_JWT_PUBLIC_KEY: str | None = "authentik.monadical.com_public.pem"
|
||||
@@ -107,12 +74,6 @@ class Settings(BaseSettings):
|
||||
# if set, all anonymous record will be public
|
||||
PUBLIC_MODE: bool = False
|
||||
|
||||
# Default LLM model name
|
||||
DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5"
|
||||
|
||||
# Cache directory for all model storage
|
||||
CACHE_DIR: str = "./data"
|
||||
|
||||
# Min transcript length to generate topic + summary
|
||||
MIN_TRANSCRIPT_LENGTH: int = 750
|
||||
|
||||
@@ -137,24 +98,20 @@ class Settings(BaseSettings):
|
||||
# Healthcheck
|
||||
HEALTHCHECK_URL: str | None = None
|
||||
|
||||
AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
|
||||
SQS_POLLING_TIMEOUT_SECONDS: int = 60
|
||||
|
||||
# Whereby integration
|
||||
WHEREBY_API_URL: str = "https://api.whereby.dev/v1"
|
||||
|
||||
WHEREBY_API_KEY: str | None = None
|
||||
|
||||
WHEREBY_WEBHOOK_SECRET: str | None = None
|
||||
AWS_WHEREBY_S3_BUCKET: str | None = None
|
||||
AWS_WHEREBY_ACCESS_KEY_ID: str | None = None
|
||||
AWS_WHEREBY_ACCESS_KEY_SECRET: str | None = None
|
||||
AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
|
||||
SQS_POLLING_TIMEOUT_SECONDS: int = 60
|
||||
|
||||
# Zulip integration
|
||||
ZULIP_REALM: str | None = None
|
||||
ZULIP_API_KEY: str | None = None
|
||||
ZULIP_BOT_EMAIL: str | None = None
|
||||
|
||||
UI_BASE_URL: str = "http://localhost:3000"
|
||||
|
||||
WHEREBY_WEBHOOK_SECRET: str | None = None
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import importlib
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import aioboto3
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.storage.base import FileResult, Storage
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from os import environ
|
||||
|
||||
import httpx
|
||||
import stamina
|
||||
@@ -8,7 +9,6 @@ from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
from aiortc.contrib.media import MediaPlayer, MediaRelay
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class StreamClient:
|
||||
@@ -43,8 +43,9 @@ class StreamClient:
|
||||
else:
|
||||
if self.relay is None:
|
||||
self.relay = MediaRelay()
|
||||
audio_device_id = int(environ.get("AUDIO_AV_FOUNDATION_DEVICE_ID", 1))
|
||||
self.player = MediaPlayer(
|
||||
f":{settings.AUDIO_AV_FOUNDATION_DEVICE_ID}",
|
||||
f":{audio_device_id}",
|
||||
format="avfoundation",
|
||||
options={"channels": "2"},
|
||||
)
|
||||
@@ -126,7 +127,7 @@ class StreamClient:
|
||||
answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
||||
await pc.setRemoteDescription(answer)
|
||||
|
||||
self.reader = self.worker(f'{"worker"}', self.queue)
|
||||
self.reader = self.worker(f"{'worker'}", self.queue)
|
||||
|
||||
def get_reader(self):
|
||||
return self.reader
|
||||
|
||||
@@ -36,9 +36,13 @@ async def export_db(filename: str) -> None:
|
||||
if entry["event"] == "TRANSCRIPT":
|
||||
yield tid, "event_transcript", idx, "text", entry["data"]["text"]
|
||||
if entry["data"].get("translation") is not None:
|
||||
yield tid, "event_transcript", idx, "translation", entry[
|
||||
"data"
|
||||
].get("translation", None)
|
||||
yield (
|
||||
tid,
|
||||
"event_transcript",
|
||||
idx,
|
||||
"translation",
|
||||
entry["data"].get("translation", None),
|
||||
)
|
||||
|
||||
def export_transcripts(transcripts):
|
||||
for transcript in transcripts:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
import av
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
@@ -12,7 +13,7 @@ from reflector.processors import (
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorProcessor,
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor
|
||||
|
||||
@@ -30,7 +31,7 @@ async def process_audio_file(
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorProcessor.as_threaded(),
|
||||
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||
]
|
||||
if not only_transcript:
|
||||
processors += [
|
||||
|
||||
316
server/reflector/tools/process_with_diarization.py
Normal file
316
server/reflector/tools/process_with_diarization.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
@vibe-generated
|
||||
Process audio file with diarization support
|
||||
===========================================
|
||||
|
||||
Extended version of process.py that includes speaker diarization.
|
||||
This tool processes audio files locally without requiring the full server infrastructure.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import av
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor, Processor
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
TitleSummary,
|
||||
TitleSummaryWithId,
|
||||
)
|
||||
|
||||
|
||||
class TopicCollectorProcessor(Processor):
|
||||
"""Collect topics for diarization"""
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.topics: List[TitleSummaryWithId] = []
|
||||
self._topic_id = 0
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
# Convert to TitleSummaryWithId and collect
|
||||
self._topic_id += 1
|
||||
topic_with_id = TitleSummaryWithId(
|
||||
id=str(self._topic_id),
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
timestamp=data.timestamp,
|
||||
duration=data.duration,
|
||||
transcript=data.transcript,
|
||||
)
|
||||
self.topics.append(topic_with_id)
|
||||
|
||||
# Pass through the original topic
|
||||
await self.emit(data)
|
||||
|
||||
def get_topics(self) -> List[TitleSummaryWithId]:
|
||||
return self.topics
|
||||
|
||||
|
||||
async def process_audio_file_with_diarization(
|
||||
filename,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
):
|
||||
# Create temp file for audio if diarization is enabled
|
||||
audio_temp_path = None
|
||||
if enable_diarization:
|
||||
audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
audio_temp_path = audio_temp_file.name
|
||||
audio_temp_file.close()
|
||||
|
||||
# Create processor for collecting topics
|
||||
topic_collector = TopicCollectorProcessor()
|
||||
|
||||
# Build pipeline for audio processing
|
||||
processors = []
|
||||
|
||||
# Add audio file writer at the beginning if diarization is enabled
|
||||
if enable_diarization:
|
||||
processors.append(AudioFileWriterProcessor(audio_temp_path))
|
||||
|
||||
# Add the rest of the processors
|
||||
processors += [
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
]
|
||||
|
||||
processors += [
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||
]
|
||||
|
||||
if not only_transcript:
|
||||
processors += [
|
||||
TranscriptTopicDetectorProcessor.as_threaded(),
|
||||
# Collect topics for diarization
|
||||
topic_collector,
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(),
|
||||
TranscriptFinalSummaryProcessor.as_threaded(),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Create main pipeline
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.set_pref("audio:source_language", source_language)
|
||||
pipeline.set_pref("audio:target_language", target_language)
|
||||
pipeline.describe()
|
||||
pipeline.on(event_callback)
|
||||
|
||||
# Start processing audio
|
||||
logger.info(f"Opening {filename}")
|
||||
container = av.open(filename)
|
||||
try:
|
||||
logger.info("Start pushing audio into the pipeline")
|
||||
for frame in container.decode(audio=0):
|
||||
await pipeline.push(frame)
|
||||
finally:
|
||||
logger.info("Flushing the pipeline")
|
||||
await pipeline.flush()
|
||||
|
||||
# Run diarization if enabled and we have topics
|
||||
if enable_diarization and not only_transcript and audio_temp_path:
|
||||
topics = topic_collector.get_topics()
|
||||
|
||||
if topics:
|
||||
logger.info(f"Starting diarization with {len(topics)} topics")
|
||||
|
||||
try:
|
||||
# Import diarization processor
|
||||
from reflector.processors import AudioDiarizationAutoProcessor
|
||||
|
||||
# Create diarization processor
|
||||
diarization_processor = AudioDiarizationAutoProcessor(
|
||||
name=diarization_backend
|
||||
)
|
||||
diarization_processor.on(event_callback)
|
||||
|
||||
# For Modal backend, we need to upload the file to S3 first
|
||||
if diarization_backend == "modal":
|
||||
from datetime import datetime
|
||||
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.utils.s3_temp_file import S3TemporaryFile
|
||||
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
# Generate a unique filename in evaluation folder
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
||||
|
||||
# Use context manager for automatic cleanup
|
||||
async with S3TemporaryFile(storage, audio_filename) as s3_file:
|
||||
# Read and upload the audio file
|
||||
with open(audio_temp_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
audio_url = await s3_file.upload(audio_data)
|
||||
logger.info(f"Uploaded audio to S3: {audio_filename}")
|
||||
|
||||
# Create diarization input with S3 URL
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
# File will be automatically cleaned up when exiting the context
|
||||
else:
|
||||
# For local backend, use local file path
|
||||
audio_url = audio_temp_path
|
||||
|
||||
# Create diarization input
|
||||
diarization_input = AudioDiarizationInput(
|
||||
audio_url=audio_url, topics=topics
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
await diarization_processor.push(diarization_input)
|
||||
await diarization_processor.flush()
|
||||
|
||||
logger.info("Diarization complete")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import diarization dependencies: {e}")
|
||||
logger.error(
|
||||
"Install with: uv pip install pyannote.audio torch torchaudio"
|
||||
)
|
||||
logger.error(
|
||||
"And set HF_TOKEN environment variable for pyannote models"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Diarization failed: {e}")
|
||||
raise SystemExit(1)
|
||||
else:
|
||||
logger.warning("Skipping diarization: no topics available")
|
||||
|
||||
# Clean up temp file
|
||||
if audio_temp_path:
|
||||
try:
|
||||
Path(audio_temp_path).unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
|
||||
|
||||
logger.info("All done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import os
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Process audio files with optional speaker diarization"
|
||||
)
|
||||
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
||||
parser.add_argument(
|
||||
"--only-transcript",
|
||||
"-t",
|
||||
action="store_true",
|
||||
help="Only generate transcript without topics/summaries",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-language", default="en", help="Source language code (default: en)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-language", default="en", help="Target language code (default: en)"
|
||||
)
|
||||
parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
|
||||
parser.add_argument(
|
||||
"--enable-diarization",
|
||||
"-d",
|
||||
action="store_true",
|
||||
help="Enable speaker diarization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
default="modal",
|
||||
choices=["modal"],
|
||||
help="Diarization backend to use (default: modal)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set REDIS_HOST to localhost if not provided
|
||||
if "REDIS_HOST" not in os.environ:
|
||||
os.environ["REDIS_HOST"] = "localhost"
|
||||
logger.info("REDIS_HOST not set, defaulting to localhost")
|
||||
|
||||
output_fd = None
|
||||
if args.output:
|
||||
output_fd = open(args.output, "w")
|
||||
|
||||
async def event_callback(event: PipelineEvent):
|
||||
processor = event.processor
|
||||
data = event.data
|
||||
|
||||
# Ignore internal processors
|
||||
if processor in (
|
||||
"AudioChunkerProcessor",
|
||||
"AudioMergeProcessor",
|
||||
"AudioFileWriterProcessor",
|
||||
"TopicCollectorProcessor",
|
||||
"BroadcastProcessor",
|
||||
):
|
||||
return
|
||||
|
||||
# If diarization is enabled, skip the original topic events from the pipeline
|
||||
# The diarization processor will emit the same topics but with speaker info
|
||||
if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
|
||||
return
|
||||
|
||||
# Log all events
|
||||
logger.info(f"Event: {processor} - {type(data).__name__}")
|
||||
|
||||
# Write to output
|
||||
if output_fd:
|
||||
output_fd.write(event.model_dump_json())
|
||||
output_fd.write("\n")
|
||||
output_fd.flush()
|
||||
|
||||
asyncio.run(
|
||||
process_audio_file_with_diarization(
|
||||
args.source,
|
||||
event_callback,
|
||||
only_transcript=args.only_transcript,
|
||||
source_language=args.source_language,
|
||||
target_language=args.target_language,
|
||||
enable_diarization=args.enable_diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
)
|
||||
)
|
||||
|
||||
if output_fd:
|
||||
output_fd.close()
|
||||
logger.info(f"Output written to {args.output}")
|
||||
96
server/reflector/tools/test_diarization.py
Normal file
96
server/reflector/tools/test_diarization.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
@vibe-generated
|
||||
Test script for the diarization CLI tool
|
||||
=========================================
|
||||
|
||||
This script helps test the diarization functionality with sample audio files.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
|
||||
async def test_diarization(audio_file: str):
|
||||
"""Test the diarization functionality"""
|
||||
|
||||
# Import the processing function
|
||||
from process_with_diarization import process_audio_file_with_diarization
|
||||
|
||||
# Collect events
|
||||
events = []
|
||||
|
||||
async def event_callback(event):
|
||||
events.append({"processor": event.processor, "data": event.data})
|
||||
logger.info(f"Event from {event.processor}")
|
||||
|
||||
# Process the audio file
|
||||
logger.info(f"Processing audio file: {audio_file}")
|
||||
|
||||
try:
|
||||
await process_audio_file_with_diarization(
|
||||
audio_file,
|
||||
event_callback,
|
||||
only_transcript=False,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
enable_diarization=True,
|
||||
diarization_backend="modal",
|
||||
)
|
||||
|
||||
# Analyze results
|
||||
logger.info(f"Processing complete. Received {len(events)} events")
|
||||
|
||||
# Look for diarization results
|
||||
diarized_topics = []
|
||||
for event in events:
|
||||
if "TitleSummary" in event["processor"]:
|
||||
# Check if words have speaker information
|
||||
if hasattr(event["data"], "transcript") and event["data"].transcript:
|
||||
words = event["data"].transcript.words
|
||||
if words and hasattr(words[0], "speaker"):
|
||||
speakers = set(
|
||||
w.speaker for w in words if hasattr(w, "speaker")
|
||||
)
|
||||
logger.info(
|
||||
f"Found {len(speakers)} speakers in topic: {event['data'].title}"
|
||||
)
|
||||
diarized_topics.append(event["data"])
|
||||
|
||||
if diarized_topics:
|
||||
logger.info(f"Successfully diarized {len(diarized_topics)} topics")
|
||||
|
||||
# Print sample output
|
||||
sample_topic = diarized_topics[0]
|
||||
logger.info("Sample diarized output:")
|
||||
for i, word in enumerate(sample_topic.transcript.words[:10]):
|
||||
logger.info(f" Word {i}: '{word.text}' - Speaker {word.speaker}")
|
||||
else:
|
||||
logger.warning("No diarization results found in output")
|
||||
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during processing: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python test_diarization.py <audio_file>")
|
||||
sys.exit(1)
|
||||
|
||||
audio_file = sys.argv[1]
|
||||
if not Path(audio_file).exists():
|
||||
print(f"Error: Audio file '{audio_file}' not found")
|
||||
sys.exit(1)
|
||||
|
||||
# Run the test
|
||||
asyncio.run(test_diarization(audio_file))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,59 +0,0 @@
|
||||
"""
|
||||
Utility file for file handling related functions, including file downloads and
|
||||
uploads to cloud storage
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import List, NoReturn
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
|
||||
from .log_utils import LOGGER
|
||||
from .run_utils import SECRETS
|
||||
|
||||
BUCKET_NAME = SECRETS["AWS-S3"]["BUCKET_NAME"]
|
||||
|
||||
s3 = boto3.client(
|
||||
"s3",
|
||||
aws_access_key_id=SECRETS["AWS-S3"]["AWS_ACCESS_KEY"],
|
||||
aws_secret_access_key=SECRETS["AWS-S3"]["AWS_SECRET_KEY"],
|
||||
)
|
||||
|
||||
|
||||
def upload_files(files_to_upload: List[str]) -> NoReturn:
|
||||
"""
|
||||
Upload a list of files to the configured S3 bucket
|
||||
:param files_to_upload: List of files to upload
|
||||
:return: None
|
||||
"""
|
||||
for key in files_to_upload:
|
||||
LOGGER.info("Uploading file " + key)
|
||||
try:
|
||||
s3.upload_file(key, BUCKET_NAME, key)
|
||||
except botocore.exceptions.ClientError as exception:
|
||||
print(exception.response)
|
||||
|
||||
|
||||
def download_files(files_to_download: List[str]) -> NoReturn:
|
||||
"""
|
||||
Download a list of files from the configured S3 bucket
|
||||
:param files_to_download: List of files to download
|
||||
:return: None
|
||||
"""
|
||||
for key in files_to_download:
|
||||
LOGGER.info("Downloading file " + key)
|
||||
try:
|
||||
s3.download_file(BUCKET_NAME, key, key)
|
||||
except botocore.exceptions.ClientError as exception:
|
||||
if exception.response["Error"]["Code"] == "404":
|
||||
print("The object does not exist.")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if sys.argv[1] == "download":
|
||||
download_files([sys.argv[2]])
|
||||
elif sys.argv[1] == "upload":
|
||||
upload_files([sys.argv[2]])
|
||||
@@ -1,38 +0,0 @@
|
||||
"""
|
||||
Utility function to format the artefacts created during Reflector run
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
with open("../artefacts/meeting_titles_and_summaries.txt", "r", encoding="utf-8") as f:
|
||||
outputs = f.read()
|
||||
|
||||
outputs = json.loads(outputs)
|
||||
|
||||
transcript_file = open("../artefacts/meeting_transcript.txt", "a", encoding="utf-8")
|
||||
title_desc_file = open(
|
||||
"../artefacts/meeting_title_description.txt", "a", encoding="utf-8"
|
||||
)
|
||||
summary_file = open("../artefacts/meeting_summary.txt", "a", encoding="utf-8")
|
||||
|
||||
for item in outputs["topics"]:
|
||||
transcript_file.write(item["transcript"])
|
||||
summary_file.write(item["description"])
|
||||
|
||||
title_desc_file.write("TITLE: \n")
|
||||
title_desc_file.write(item["title"])
|
||||
title_desc_file.write("\n")
|
||||
|
||||
title_desc_file.write("DESCRIPTION: \n")
|
||||
title_desc_file.write(item["description"])
|
||||
title_desc_file.write("\n")
|
||||
|
||||
title_desc_file.write("TRANSCRIPT: \n")
|
||||
title_desc_file.write(item["transcript"])
|
||||
title_desc_file.write("\n")
|
||||
|
||||
title_desc_file.write("---------------------------------------- \n\n")
|
||||
|
||||
transcript_file.close()
|
||||
title_desc_file.close()
|
||||
summary_file.close()
|
||||
@@ -1,8 +1,10 @@
|
||||
from reflector.logger import logger
|
||||
from time import monotonic
|
||||
from httpx import HTTPStatusError, Response
|
||||
from random import random
|
||||
import asyncio
|
||||
from random import random
|
||||
from time import monotonic
|
||||
|
||||
from httpx import HTTPStatusError, Response
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
|
||||
class RetryException(Exception):
|
||||
@@ -34,6 +36,7 @@ def retry(fn):
|
||||
),
|
||||
)
|
||||
retry_ignore_exc_types = kwargs.pop("retry_ignore_exc_types", (Exception,))
|
||||
retry_logger = kwargs.pop("logger", logger)
|
||||
|
||||
result = None
|
||||
last_exception = None
|
||||
@@ -58,17 +61,33 @@ def retry(fn):
|
||||
if result:
|
||||
return result
|
||||
except HTTPStatusError as e:
|
||||
logger.exception(e)
|
||||
retry_logger.exception(e)
|
||||
status_code = e.response.status_code
|
||||
logger.debug(f"HTTP status {status_code} - {e}")
|
||||
|
||||
# Log detailed error information including response body
|
||||
try:
|
||||
response_text = e.response.text
|
||||
response_headers = dict(e.response.headers)
|
||||
retry_logger.error(
|
||||
f"HTTP {status_code} error for {e.request.method} {e.request.url}\n"
|
||||
f"Response headers: {response_headers}\n"
|
||||
f"Response body: {response_text}"
|
||||
)
|
||||
|
||||
except Exception as log_error:
|
||||
retry_logger.warning(
|
||||
f"Failed to log detailed error info: {log_error}"
|
||||
)
|
||||
retry_logger.debug(f"HTTP status {status_code} - {e}")
|
||||
|
||||
if status_code in retry_httpx_status_stop:
|
||||
message = f"HTTP status {status_code} is in retry_httpx_status_stop"
|
||||
raise RetryHTTPException(message) from e
|
||||
except retry_ignore_exc_types as e:
|
||||
logger.exception(e)
|
||||
retry_logger.exception(e)
|
||||
last_exception = e
|
||||
|
||||
logger.debug(
|
||||
retry_logger.debug(
|
||||
f"Retrying {fn_name} - in {retry_backoff_interval:.1f}s "
|
||||
f"({monotonic() - start:.1f}s / {retry_timeout:.1f}s)"
|
||||
)
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"""
|
||||
Utility file for server side asynchronous task running and config objects
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from functools import partial
|
||||
from threading import Lock
|
||||
from typing import ContextManager, Generic, TypeVar
|
||||
|
||||
|
||||
def run_in_executor(func, *args, executor=None, **kwargs):
|
||||
"""
|
||||
Run the function in an executor, unblocking the main loop
|
||||
:param func: Function to be run in executor
|
||||
:param args: function parameters
|
||||
:param executor: executor instance [Thread | Process]
|
||||
:param kwargs: Additional parameters
|
||||
:return: Future of function result upon completion
|
||||
"""
|
||||
callback = partial(func, *args, **kwargs)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_in_executor(executor, callback)
|
||||
|
||||
|
||||
# Genetic type template
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Mutex(Generic[T]):
|
||||
"""
|
||||
Mutex class to implement lock/release of a shared
|
||||
protected variable
|
||||
"""
|
||||
|
||||
def __init__(self, value: T):
|
||||
"""
|
||||
Create an instance of Mutex wrapper for the given resource
|
||||
:param value: Shared resources to be thread protected
|
||||
"""
|
||||
self.__value = value
|
||||
self.__lock = Lock()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def lock(self) -> ContextManager[T]:
|
||||
"""
|
||||
Lock the resource with a mutex to be used within a context block
|
||||
The lock is automatically released on context exit
|
||||
:return: Shared resource
|
||||
"""
|
||||
self.__lock.acquire()
|
||||
try:
|
||||
yield self.__value
|
||||
finally:
|
||||
self.__lock.release()
|
||||
150
server/reflector/utils/s3_temp_file.py
Normal file
150
server/reflector/utils/s3_temp_file.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
@vibe-generated
|
||||
S3 Temporary File Context Manager
|
||||
|
||||
Provides automatic cleanup of S3 files with retry logic and proper error handling.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.storage.base import Storage
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
|
||||
class S3TemporaryFile:
|
||||
"""
|
||||
Async context manager for temporary S3 files with automatic cleanup.
|
||||
|
||||
Ensures that uploaded files are deleted even if exceptions occur during processing.
|
||||
Uses retry logic for all S3 operations to handle transient failures.
|
||||
|
||||
Example:
|
||||
async with S3TemporaryFile(storage, "temp/audio.wav") as s3_file:
|
||||
url = await s3_file.upload(audio_data)
|
||||
# Use url for processing
|
||||
# File is automatically cleaned up here
|
||||
"""
|
||||
|
||||
def __init__(self, storage: Storage, filepath: str):
|
||||
"""
|
||||
Initialize the temporary file context.
|
||||
|
||||
Args:
|
||||
storage: Storage instance for S3 operations
|
||||
filepath: S3 key/path for the temporary file
|
||||
"""
|
||||
self.storage = storage
|
||||
self.filepath = filepath
|
||||
self.uploaded = False
|
||||
self._url: Optional[str] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter the context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""
|
||||
Exit the context manager and clean up the file.
|
||||
|
||||
Cleanup is attempted even if an exception occurred during processing.
|
||||
Cleanup failures are logged but don't raise exceptions.
|
||||
"""
|
||||
if self.uploaded:
|
||||
try:
|
||||
await self._delete_with_retry()
|
||||
logger.info(f"Successfully cleaned up S3 file: {self.filepath}")
|
||||
except Exception as e:
|
||||
# Log the error but don't raise - we don't want cleanup failures
|
||||
# to mask the original exception
|
||||
logger.warning(
|
||||
f"Failed to cleanup S3 file {self.filepath} after retries: {e}"
|
||||
)
|
||||
return False # Don't suppress exceptions
|
||||
|
||||
async def upload(self, data: bytes) -> str:
|
||||
"""
|
||||
Upload data to S3 and return the public URL.
|
||||
|
||||
Args:
|
||||
data: File data to upload
|
||||
|
||||
Returns:
|
||||
Public URL for the uploaded file
|
||||
|
||||
Raises:
|
||||
Exception: If upload or URL generation fails after retries
|
||||
"""
|
||||
await self._upload_with_retry(data)
|
||||
self.uploaded = True
|
||||
self._url = await self._get_url_with_retry()
|
||||
return self._url
|
||||
|
||||
@property
|
||||
def url(self) -> Optional[str]:
|
||||
"""Get the URL of the uploaded file, if available."""
|
||||
return self._url
|
||||
|
||||
async def _upload_with_retry(self, data: bytes):
|
||||
"""Upload file to S3 with retry logic."""
|
||||
|
||||
async def upload():
|
||||
await self.storage.put_file(self.filepath, data)
|
||||
logger.debug(f"Successfully uploaded file to S3: {self.filepath}")
|
||||
return True # Return something to indicate success
|
||||
|
||||
await retry(upload)(
|
||||
retry_attempts=3,
|
||||
retry_timeout=30.0,
|
||||
retry_backoff_interval=0.5,
|
||||
retry_backoff_max=5.0,
|
||||
)
|
||||
|
||||
async def _get_url_with_retry(self) -> str:
|
||||
"""Get public URL for the file with retry logic."""
|
||||
|
||||
async def get_url():
|
||||
url = await self.storage.get_file_url(self.filepath)
|
||||
logger.debug(f"Generated public URL for S3 file: {self.filepath}")
|
||||
return url
|
||||
|
||||
return await retry(get_url)(
|
||||
retry_attempts=3,
|
||||
retry_timeout=30.0,
|
||||
retry_backoff_interval=0.5,
|
||||
retry_backoff_max=5.0,
|
||||
)
|
||||
|
||||
async def _delete_with_retry(self):
|
||||
"""Delete file from S3 with retry logic."""
|
||||
|
||||
async def delete():
|
||||
await self.storage.delete_file(self.filepath)
|
||||
logger.debug(f"Successfully deleted S3 file: {self.filepath}")
|
||||
return True # Return something to indicate success
|
||||
|
||||
await retry(delete)(
|
||||
retry_attempts=3,
|
||||
retry_timeout=30.0,
|
||||
retry_backoff_interval=0.5,
|
||||
retry_backoff_max=5.0,
|
||||
)
|
||||
|
||||
|
||||
# Convenience function for simpler usage
|
||||
async def temporary_s3_file(storage: Storage, filepath: str):
|
||||
"""
|
||||
Create a temporary S3 file context manager.
|
||||
|
||||
This is a convenience wrapper around S3TemporaryFile for simpler usage.
|
||||
|
||||
Args:
|
||||
storage: Storage instance for S3 operations
|
||||
filepath: S3 key/path for the temporary file
|
||||
|
||||
Example:
|
||||
async with temporary_s3_file(storage, "temp/audio.wav") as s3_file:
|
||||
url = await s3_file.upload(audio_data)
|
||||
# Use url for processing
|
||||
"""
|
||||
return S3TemporaryFile(storage, filepath)
|
||||
33
server/reflector/utils/text.py
Normal file
33
server/reflector/utils/text.py
Normal file
@@ -0,0 +1,33 @@
|
||||
def clean_title(title: str) -> str:
|
||||
"""
|
||||
Clean and format a title string for consistent capitalization.
|
||||
|
||||
Rules:
|
||||
- Strip surrounding quotes (single or double)
|
||||
- Capitalize the first word
|
||||
- Capitalize words longer than 3 characters
|
||||
- Keep words with 3 or fewer characters lowercase (except first word)
|
||||
|
||||
Args:
|
||||
title: The title string to clean
|
||||
|
||||
Returns:
|
||||
The cleaned title with consistent capitalization
|
||||
|
||||
Examples:
|
||||
>>> clean_title("hello world")
|
||||
"Hello World"
|
||||
>>> clean_title("meeting with the team")
|
||||
"Meeting With the Team"
|
||||
>>> clean_title("'Title with quotes'")
|
||||
"Title With Quotes"
|
||||
"""
|
||||
title = title.strip("\"'")
|
||||
words = title.split()
|
||||
if words:
|
||||
words = [
|
||||
word.capitalize() if i == 0 or len(word) > 3 else word.lower()
|
||||
for i, word in enumerate(words)
|
||||
]
|
||||
title = " ".join(words)
|
||||
return title
|
||||
@@ -1,264 +0,0 @@
|
||||
"""
|
||||
Utility file for all text processing related functionalities
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import List
|
||||
|
||||
import nltk
|
||||
import torch
|
||||
from log_utils import LOGGER
|
||||
from nltk.corpus import stopwords
|
||||
from nltk.tokenize import word_tokenize
|
||||
from run_utils import CONFIG
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer
|
||||
|
||||
nltk.download("punkt", quiet=True)
|
||||
|
||||
|
||||
def preprocess_sentence(sentence: str) -> str:
|
||||
"""
|
||||
Filter out undesirable tokens from thr sentence
|
||||
:param sentence:
|
||||
:return:
|
||||
"""
|
||||
stop_words = set(stopwords.words("english"))
|
||||
tokens = word_tokenize(sentence.lower())
|
||||
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
|
||||
return " ".join(tokens)
|
||||
|
||||
|
||||
def compute_similarity(sent1: str, sent2: str) -> float:
|
||||
"""
|
||||
Compute the similarity
|
||||
"""
|
||||
tfidf_vectorizer = TfidfVectorizer()
|
||||
if sent1 is not None and sent2 is not None:
|
||||
tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
|
||||
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
|
||||
return 0.0
|
||||
|
||||
|
||||
def remove_almost_alike_sentences(sentences: List[str], threshold=0.7) -> List[str]:
|
||||
"""
|
||||
Filter sentences that are similar beyond a set threshold
|
||||
:param sentences:
|
||||
:param threshold:
|
||||
:return:
|
||||
"""
|
||||
num_sentences = len(sentences)
|
||||
removed_indices = set()
|
||||
|
||||
for i in range(num_sentences):
|
||||
if i not in removed_indices:
|
||||
for j in range(i + 1, num_sentences):
|
||||
if j not in removed_indices:
|
||||
l_i = len(sentences[i])
|
||||
l_j = len(sentences[j])
|
||||
if l_i == 0 or l_j == 0:
|
||||
if l_i == 0:
|
||||
removed_indices.add(i)
|
||||
if l_j == 0:
|
||||
removed_indices.add(j)
|
||||
else:
|
||||
sentence1 = preprocess_sentence(sentences[i])
|
||||
sentence2 = preprocess_sentence(sentences[j])
|
||||
if len(sentence1) != 0 and len(sentence2) != 0:
|
||||
similarity = compute_similarity(sentence1, sentence2)
|
||||
|
||||
if similarity >= threshold:
|
||||
removed_indices.add(max(i, j))
|
||||
|
||||
filtered_sentences = [
|
||||
sentences[i] for i in range(num_sentences) if i not in removed_indices
|
||||
]
|
||||
return filtered_sentences
|
||||
|
||||
|
||||
def remove_outright_duplicate_sentences_from_chunk(chunk: str) -> List[str]:
|
||||
"""
|
||||
Remove repetitive sentences
|
||||
:param chunk:
|
||||
:return:
|
||||
"""
|
||||
chunk_text = chunk["text"]
|
||||
sentences = nltk.sent_tokenize(chunk_text)
|
||||
nonduplicate_sentences = list(dict.fromkeys(sentences))
|
||||
return nonduplicate_sentences
|
||||
|
||||
|
||||
def remove_whisper_repetitive_hallucination(
|
||||
nonduplicate_sentences: List[str],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Remove sentences that are repeated as a result of Whisper
|
||||
hallucinations
|
||||
:param nonduplicate_sentences:
|
||||
:return:
|
||||
"""
|
||||
chunk_sentences = []
|
||||
|
||||
for sent in nonduplicate_sentences:
|
||||
temp_result = ""
|
||||
seen = {}
|
||||
words = nltk.word_tokenize(sent)
|
||||
n_gram_filter = 3
|
||||
for i in range(len(words)):
|
||||
if (
|
||||
str(words[i : i + n_gram_filter]) in seen
|
||||
and seen[str(words[i : i + n_gram_filter])]
|
||||
== words[i + 1 : i + n_gram_filter + 2]
|
||||
):
|
||||
pass
|
||||
else:
|
||||
seen[str(words[i : i + n_gram_filter])] = words[
|
||||
i + 1 : i + n_gram_filter + 2
|
||||
]
|
||||
temp_result += words[i]
|
||||
temp_result += " "
|
||||
chunk_sentences.append(temp_result)
|
||||
return chunk_sentences
|
||||
|
||||
|
||||
def post_process_transcription(whisper_result: dict) -> dict:
|
||||
"""
|
||||
Parent function to perform post-processing on the transcription result
|
||||
:param whisper_result:
|
||||
:return:
|
||||
"""
|
||||
transcript_text = ""
|
||||
for chunk in whisper_result["chunks"]:
|
||||
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
|
||||
chunk_sentences = remove_whisper_repetitive_hallucination(
|
||||
nonduplicate_sentences
|
||||
)
|
||||
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
|
||||
chunk["text"] = " ".join(similarity_matched_sentences)
|
||||
transcript_text += chunk["text"]
|
||||
whisper_result["text"] = transcript_text
|
||||
return whisper_result
|
||||
|
||||
|
||||
def summarize_chunks(chunks: List[str], tokenizer, model) -> List[str]:
|
||||
"""
|
||||
Summarize each chunk using a summarizer model
|
||||
:param chunks:
|
||||
:param tokenizer:
|
||||
:param model:
|
||||
:return:
|
||||
"""
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
summaries = []
|
||||
for c in chunks:
|
||||
input_ids = tokenizer.encode(c, return_tensors="pt")
|
||||
input_ids = input_ids.to(device)
|
||||
with torch.no_grad():
|
||||
summary_ids = model.generate(
|
||||
input_ids,
|
||||
num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
|
||||
length_penalty=2.0,
|
||||
max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
|
||||
early_stopping=True,
|
||||
)
|
||||
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
||||
summaries.append(summary)
|
||||
return summaries
|
||||
|
||||
|
||||
def chunk_text(
|
||||
text: str, max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])
|
||||
) -> List[str]:
|
||||
"""
|
||||
Split text into smaller chunks.
|
||||
:param text: Text to be chunked
|
||||
:param max_chunk_length: length of chunk
|
||||
:return: chunked texts
|
||||
"""
|
||||
sentences = nltk.sent_tokenize(text)
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
for sentence in sentences:
|
||||
if len(current_chunk) + len(sentence) < max_chunk_length:
|
||||
current_chunk += f" {sentence.strip()}"
|
||||
else:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = f"{sentence.strip()}"
|
||||
chunks.append(current_chunk.strip())
|
||||
return chunks
|
||||
|
||||
|
||||
def summarize(
|
||||
transcript_text: str,
|
||||
timestamp: datetime.datetime.timestamp,
|
||||
real_time: bool = False,
|
||||
chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"],
|
||||
):
|
||||
"""
|
||||
Summarize the given text either as a whole or as chunks as needed
|
||||
:param transcript_text:
|
||||
:param timestamp:
|
||||
:param real_time:
|
||||
:param chunk_summarize:
|
||||
:return:
|
||||
"""
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
summary_model = CONFIG["SUMMARIZER"]["SUMMARY_MODEL"]
|
||||
if not summary_model:
|
||||
summary_model = "facebook/bart-large-cnn"
|
||||
|
||||
# Summarize the generated transcript using the BART model
|
||||
LOGGER.info(f"Loading BART model: {summary_model}")
|
||||
tokenizer = BartTokenizer.from_pretrained(summary_model)
|
||||
model = BartForConditionalGeneration.from_pretrained(summary_model)
|
||||
model = model.to(device)
|
||||
|
||||
output_file = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||
if real_time:
|
||||
output_file = "real_time_" + output_file
|
||||
|
||||
if chunk_summarize != "YES":
|
||||
max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
|
||||
inputs = tokenizer.batch_encode_plus(
|
||||
[transcript_text],
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
|
||||
max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
|
||||
summaries = model.generate(
|
||||
inputs["input_ids"],
|
||||
num_beams=num_beans,
|
||||
length_penalty=2.0,
|
||||
max_length=max_length,
|
||||
early_stopping=True,
|
||||
)
|
||||
|
||||
decoded_summaries = [
|
||||
tokenizer.decode(
|
||||
summary, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
for summary in summaries
|
||||
]
|
||||
summary = " ".join(decoded_summaries)
|
||||
with open("./artefacts/" + output_file, "w", encoding="utf-8") as file:
|
||||
file.write(summary.strip() + "\n")
|
||||
else:
|
||||
LOGGER.info("Breaking transcript into smaller chunks")
|
||||
chunks = chunk_text(transcript_text)
|
||||
|
||||
LOGGER.info(
|
||||
f"Transcript broken into {len(chunks)} " f"chunks of at most 500 words"
|
||||
)
|
||||
|
||||
LOGGER.info(f"Writing summary text to: {output_file}")
|
||||
with open(output_file, "w") as f:
|
||||
summaries = summarize_chunks(chunks, tokenizer, model)
|
||||
for summary in summaries:
|
||||
f.write(summary.strip() + " ")
|
||||
@@ -1,283 +0,0 @@
|
||||
"""
|
||||
Utility file for all visualization related functions
|
||||
"""
|
||||
|
||||
import ast
|
||||
import collections
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
from typing import NoReturn
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import scattertext as st
|
||||
import spacy
|
||||
from nltk.corpus import stopwords
|
||||
from wordcloud import STOPWORDS, WordCloud
|
||||
|
||||
en = spacy.load("en_core_web_md")
|
||||
spacy_stopwords = en.Defaults.stop_words
|
||||
|
||||
STOPWORDS = (
|
||||
set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
|
||||
)
|
||||
|
||||
|
||||
def create_wordcloud(
|
||||
timestamp: datetime.datetime.timestamp, real_time: bool = False
|
||||
) -> NoReturn:
|
||||
"""
|
||||
Create a basic word cloud visualization of transcribed text
|
||||
:return: None. The wordcloud image is saved locally
|
||||
"""
|
||||
filename = "transcript"
|
||||
if real_time:
|
||||
filename = (
|
||||
"real_time_"
|
||||
+ filename
|
||||
+ "_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".txt"
|
||||
)
|
||||
else:
|
||||
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||
|
||||
with open("./artefacts/" + filename, "r") as f:
|
||||
transcription_text = f.read()
|
||||
|
||||
# python_mask = np.array(PIL.Image.open("download1.png"))
|
||||
|
||||
wordcloud = WordCloud(
|
||||
height=800,
|
||||
width=800,
|
||||
background_color="white",
|
||||
stopwords=STOPWORDS,
|
||||
min_font_size=8,
|
||||
).generate(transcription_text)
|
||||
|
||||
# Plot wordcloud and save image
|
||||
plt.figure(facecolor=None)
|
||||
plt.imshow(wordcloud, interpolation="bilinear")
|
||||
plt.axis("off")
|
||||
plt.tight_layout(pad=0)
|
||||
|
||||
wordcloud = "wordcloud"
|
||||
if real_time:
|
||||
wordcloud = (
|
||||
"real_time_"
|
||||
+ wordcloud
|
||||
+ "_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".png"
|
||||
)
|
||||
else:
|
||||
wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
||||
|
||||
plt.savefig("./artefacts/" + wordcloud)
|
||||
|
||||
|
||||
def create_talk_diff_scatter_viz(
|
||||
timestamp: datetime.datetime.timestamp, real_time: bool = False
|
||||
) -> NoReturn:
|
||||
"""
|
||||
Perform agenda vs transcription diff to see covered topics.
|
||||
Create a scatter plot of words in topics.
|
||||
:return: None. Saved locally.
|
||||
"""
|
||||
spacy_model = "en_core_web_md"
|
||||
nlp = spacy.load(spacy_model)
|
||||
nlp.add_pipe("sentencizer")
|
||||
|
||||
agenda_topics = []
|
||||
agenda = []
|
||||
# Load the agenda
|
||||
with open(os.path.join(os.getcwd(), "agenda-headers.txt"), "r") as f:
|
||||
for line in f.readlines():
|
||||
if line.strip():
|
||||
agenda.append(line.strip())
|
||||
agenda_topics.append(line.split(":")[0])
|
||||
|
||||
# Load the transcription with timestamp
|
||||
if real_time:
|
||||
filename = (
|
||||
"./artefacts/real_time_transcript_with_timestamp_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".txt"
|
||||
)
|
||||
else:
|
||||
filename = (
|
||||
"./artefacts/transcript_with_timestamp_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".txt"
|
||||
)
|
||||
with open(filename) as file:
|
||||
transcription_timestamp_text = file.read()
|
||||
|
||||
res = ast.literal_eval(transcription_timestamp_text)
|
||||
chunks = res["chunks"]
|
||||
|
||||
# create df for processing
|
||||
df = pd.DataFrame.from_dict(res["chunks"])
|
||||
|
||||
covered_items = {}
|
||||
# ts: timestamp
|
||||
# Map each timestamped chunk with top1 and top2 matched agenda
|
||||
ts_to_topic_mapping_top_1 = {}
|
||||
ts_to_topic_mapping_top_2 = {}
|
||||
|
||||
# Also create a mapping of the different timestamps
|
||||
# in which each topic was covered
|
||||
topic_to_ts_mapping_top_1 = collections.defaultdict(list)
|
||||
topic_to_ts_mapping_top_2 = collections.defaultdict(list)
|
||||
|
||||
similarity_threshold = 0.7
|
||||
|
||||
for c in chunks:
|
||||
doc_transcription = nlp(c["text"])
|
||||
topic_similarities = []
|
||||
for item in range(len(agenda)):
|
||||
item_doc = nlp(agenda[item])
|
||||
# if not doc_transcription or not all
|
||||
# (token.has_vector for token in doc_transcription):
|
||||
if not doc_transcription:
|
||||
continue
|
||||
similarity = doc_transcription.similarity(item_doc)
|
||||
topic_similarities.append((item, similarity))
|
||||
topic_similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
for i in range(2):
|
||||
if topic_similarities[i][1] >= similarity_threshold:
|
||||
covered_items[agenda[topic_similarities[i][0]]] = True
|
||||
# top1 match
|
||||
if i == 0:
|
||||
ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[
|
||||
topic_similarities[i][0]
|
||||
]
|
||||
topic_to_ts_mapping_top_1[
|
||||
agenda_topics[topic_similarities[i][0]]
|
||||
].append(c["timestamp"])
|
||||
# top2 match
|
||||
else:
|
||||
ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[
|
||||
topic_similarities[i][0]
|
||||
]
|
||||
topic_to_ts_mapping_top_2[
|
||||
agenda_topics[topic_similarities[i][0]]
|
||||
].append(c["timestamp"])
|
||||
|
||||
def create_new_columns(record: dict) -> dict:
|
||||
"""
|
||||
Accumulate the mapping information into the df
|
||||
:param record:
|
||||
:return:
|
||||
"""
|
||||
record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[
|
||||
record["timestamp"]
|
||||
]
|
||||
record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[
|
||||
record["timestamp"]
|
||||
]
|
||||
return record
|
||||
|
||||
df = df.apply(create_new_columns, axis=1)
|
||||
|
||||
# Count the number of items covered and calculate the percentage
|
||||
num_covered_items = sum(covered_items.values())
|
||||
percentage_covered = num_covered_items / len(agenda) * 100
|
||||
|
||||
# Print the results
|
||||
print("💬 Agenda items covered in the transcription:")
|
||||
for item in agenda:
|
||||
if item in covered_items and covered_items[item]:
|
||||
print("✅ ", item)
|
||||
else:
|
||||
print("❌ ", item)
|
||||
print("📊 Coverage: {:.2f}%".format(percentage_covered))
|
||||
|
||||
# Save df, mappings for further experimentation
|
||||
df_name = "df"
|
||||
if real_time:
|
||||
df_name = (
|
||||
"real_time_"
|
||||
+ df_name
|
||||
+ "_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".pkl"
|
||||
)
|
||||
else:
|
||||
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||
df.to_pickle("./artefacts/" + df_name)
|
||||
|
||||
my_mappings = [
|
||||
ts_to_topic_mapping_top_1,
|
||||
ts_to_topic_mapping_top_2,
|
||||
topic_to_ts_mapping_top_1,
|
||||
topic_to_ts_mapping_top_2,
|
||||
]
|
||||
|
||||
mappings_name = "mappings"
|
||||
if real_time:
|
||||
mappings_name = (
|
||||
"real_time_"
|
||||
+ mappings_name
|
||||
+ "_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".pkl"
|
||||
)
|
||||
else:
|
||||
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||
pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb"))
|
||||
|
||||
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
|
||||
|
||||
# pick the 2 most matched topic to be used for plotting
|
||||
topic_times = collections.defaultdict(int)
|
||||
for key in ts_to_topic_mapping_top_1.keys():
|
||||
if key[0] is None or key[1] is None:
|
||||
continue
|
||||
duration = key[1] - key[0]
|
||||
topic_times[ts_to_topic_mapping_top_1[key]] += duration
|
||||
|
||||
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
if len(topic_times) > 1:
|
||||
cat_1 = topic_times[0][0]
|
||||
cat_1_name = topic_times[0][0]
|
||||
cat_2_name = topic_times[1][0]
|
||||
|
||||
# Scatter plot of topics
|
||||
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
|
||||
corpus = (
|
||||
st.CorpusFromParsedDocuments(
|
||||
df, category_col="ts_to_topic_mapping_top_1", parsed_col="parse"
|
||||
)
|
||||
.build()
|
||||
.get_unigram_corpus()
|
||||
.compact(st.AssociationCompactor(2000))
|
||||
)
|
||||
html = st.produce_scattertext_explorer(
|
||||
corpus,
|
||||
category=cat_1,
|
||||
category_name=cat_1_name,
|
||||
not_category_name=cat_2_name,
|
||||
minimum_term_frequency=0,
|
||||
pmi_threshold_coefficient=0,
|
||||
width_in_pixels=1000,
|
||||
transform=st.Scalers.dense_rank,
|
||||
)
|
||||
if real_time:
|
||||
with open(
|
||||
"./artefacts/real_time_scatter_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".html",
|
||||
"w",
|
||||
) as file:
|
||||
file.write(html)
|
||||
else:
|
||||
with open(
|
||||
"./artefacts/scatter_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".html",
|
||||
"w",
|
||||
) as file:
|
||||
file.write(html)
|
||||
@@ -1,10 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import reflector.auth as auth
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db.meetings import (
|
||||
MeetingConsent,
|
||||
meeting_consent_controller,
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Optional, Literal
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
import reflector.auth as auth
|
||||
import asyncpg.exceptions
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.databases import paginate
|
||||
from pydantic import BaseModel
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import database
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.settings import settings
|
||||
from reflector.whereby import create_meeting, upload_logo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -149,19 +155,47 @@ async def rooms_create_meeting(
|
||||
|
||||
if meeting is None:
|
||||
end_date = current_time + timedelta(hours=8)
|
||||
meeting = await create_meeting("", end_date=end_date, room=room)
|
||||
await upload_logo(meeting["roomName"], "./images/logo.png")
|
||||
|
||||
whereby_meeting = await create_meeting("", end_date=end_date, room=room)
|
||||
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
||||
|
||||
# Now try to save to database
|
||||
try:
|
||||
meeting = await meetings_controller.create(
|
||||
id=meeting["meetingId"],
|
||||
room_name=meeting["roomName"],
|
||||
room_url=meeting["roomUrl"],
|
||||
host_room_url=meeting["hostRoomUrl"],
|
||||
start_date=datetime.fromisoformat(meeting["startDate"]),
|
||||
end_date=datetime.fromisoformat(meeting["endDate"]),
|
||||
id=whereby_meeting["meetingId"],
|
||||
room_name=whereby_meeting["roomName"],
|
||||
room_url=whereby_meeting["roomUrl"],
|
||||
host_room_url=whereby_meeting["hostRoomUrl"],
|
||||
start_date=datetime.fromisoformat(whereby_meeting["startDate"]),
|
||||
end_date=datetime.fromisoformat(whereby_meeting["endDate"]),
|
||||
user_id=user_id,
|
||||
room=room,
|
||||
)
|
||||
except (asyncpg.exceptions.UniqueViolationError, sqlite3.IntegrityError):
|
||||
# Another request already created a meeting for this room
|
||||
# Log this race condition occurrence
|
||||
logger.info(
|
||||
"Race condition detected for room %s - fetching existing meeting",
|
||||
room.name,
|
||||
)
|
||||
logger.warning(
|
||||
"Whereby meeting %s was created but not used (resource leak) for room %s",
|
||||
whereby_meeting["meetingId"],
|
||||
room.name,
|
||||
)
|
||||
|
||||
# Fetch the meeting that was created by the other request
|
||||
meeting = await meetings_controller.get_active(
|
||||
room=room, current_time=current_time
|
||||
)
|
||||
if meeting is None:
|
||||
# Edge case: meeting was created but expired/deleted between checks
|
||||
logger.error(
|
||||
"Meeting disappeared after race condition for room %s", room.name
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Unable to join meeting - please try again"
|
||||
)
|
||||
|
||||
if user_id != room.user_id:
|
||||
meeting.host_room_url = ""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user