mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Merge branch 'main' into reenable-non-latin-languages
This commit is contained in:
2
.github/workflows/deploy.yml
vendored
2
.github/workflows/deploy.yml
vendored
@@ -3,7 +3,7 @@ name: Deploy to Amazon ECS
|
||||
on: [workflow_dispatch]
|
||||
|
||||
env:
|
||||
# 384658522150.dkr.ecr.us-east-1.amazonaws.com/reflector
|
||||
# 950402358378.dkr.ecr.us-east-1.amazonaws.com/reflector
|
||||
AWS_REGION: us-east-1
|
||||
ECR_REPOSITORY: reflector
|
||||
|
||||
|
||||
13
.github/workflows/test_server.yml
vendored
13
.github/workflows/test_server.yml
vendored
@@ -2,15 +2,20 @@ name: Unittests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
- 'www/**'
|
||||
paths:
|
||||
- 'server/**'
|
||||
push:
|
||||
paths-ignore:
|
||||
- 'www/**'
|
||||
paths:
|
||||
- 'server/**'
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
redis:
|
||||
image: redis:6
|
||||
ports:
|
||||
- 6379:6379
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Install poetry
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -2,3 +2,6 @@
|
||||
server/.env
|
||||
.env
|
||||
server/exportdanswer
|
||||
.vercel
|
||||
.env*.local
|
||||
dump.rdb
|
||||
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.11.6
|
||||
39
.vscode/launch.json
vendored
Normal file
39
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"configurations": [
|
||||
{
|
||||
"type": "aws-sam",
|
||||
"request": "direct-invoke",
|
||||
"name": "lambda-nodejs18.x:HelloWorldFunction (nodejs18.x)",
|
||||
"invokeTarget": {
|
||||
"target": "template",
|
||||
"templatePath": "${workspaceFolder}/aws/lambda-nodejs18.x/template.yaml",
|
||||
"logicalId": "HelloWorldFunction"
|
||||
},
|
||||
"lambda": {
|
||||
"payload": {},
|
||||
"environmentVariables": {},
|
||||
"runtime": "nodejs18.x"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "aws-sam",
|
||||
"request": "direct-invoke",
|
||||
"name": "API lambda-nodejs18.x:HelloWorldFunction (nodejs18.x)",
|
||||
"invokeTarget": {
|
||||
"target": "api",
|
||||
"templatePath": "${workspaceFolder}/aws/lambda-nodejs18.x/template.yaml",
|
||||
"logicalId": "HelloWorldFunction"
|
||||
},
|
||||
"api": {
|
||||
"path": "/hello",
|
||||
"httpMethod": "get",
|
||||
"payload": {
|
||||
"json": {}
|
||||
}
|
||||
},
|
||||
"lambda": {
|
||||
"runtime": "nodejs18.x"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
113
README.md
113
README.md
@@ -1,12 +1,14 @@
|
||||
# Reflector
|
||||
|
||||
Reflector is a cutting-edge web application under development by Monadical. It utilizes AI to record meetings, providing a permanent record with transcripts, translations, and automated summaries.
|
||||
Reflector Audio Management and Analysis is a cutting-edge web application under development by Monadical. It utilizes AI to record meetings, providing a permanent record with transcripts, translations, and automated summaries.
|
||||
|
||||
The project architecture consists of three primary components:
|
||||
|
||||
* **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
|
||||
* **Back-End**: Python server that offers an API and data persistence, found in `server/`.
|
||||
* **AI Models**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
|
||||
* **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
|
||||
|
||||
It also uses [Fief](https://github.com/fief-dev) for authentication, and Vercel for deployment and configuration of the front-end.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
@@ -21,7 +23,12 @@ The project architecture consists of three primary components:
|
||||
- [OpenAPI Code Generation](#openapi-code-generation)
|
||||
- [Back-End](#back-end)
|
||||
- [Installation](#installation-1)
|
||||
- [Start the project](#start-the-project)
|
||||
- [Start the API/Backend](#start-the-apibackend)
|
||||
- [Redis (Mac)](#redis-mac)
|
||||
- [Redis (Windows)](#redis-windows)
|
||||
- [Update the database schema (run on first install, and after each pull containing a migration)](#update-the-database-schema-run-on-first-install-and-after-each-pull-containing-a-migration)
|
||||
- [Main Server](#main-server)
|
||||
- [Crontab (optional)](#crontab-optional)
|
||||
- [Using docker](#using-docker)
|
||||
- [Using local GPT4All](#using-local-gpt4all)
|
||||
- [Using local files](#using-local-files)
|
||||
@@ -31,12 +38,18 @@ The project architecture consists of three primary components:
|
||||
|
||||
### Contribution Guidelines
|
||||
|
||||
All new contributions should be made in a separate branch. Before any code is merged into `master`, it requires a code review.
|
||||
All new contributions should be made in a separate branch. Before any code is merged into `main`, it requires a code review.
|
||||
|
||||
### How to Install Blackhole (Mac Only)
|
||||
To record both your voice and the meeting you're taking part in, check the following:
|
||||
- For an in-person meeting, make sure your microphone is in range of all participants.
|
||||
- If using several microphones, make sure to merge the audio feeds into one with an external tool.
|
||||
- For an online meeting, if you do not use headphones, your microphone should be able to pick up both your voice and the audio feed of the meeting.
|
||||
- If you want to use headphones, you need to merge the audio feeds with an external tool.
|
||||
|
||||
|
||||
This is an external tool for merging the audio feeds as explained in the previous section of this document.
|
||||
Note: We currently do not have instructions for Windows users.
|
||||
|
||||
* Install [BlackHole](https://github.com/ExistentialAudio/BlackHole)-2ch (2 channels is enough) using one of the two installation options listed there.
|
||||
* Setup ["Aggregate device"](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and local microphone input.
|
||||
* Setup [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
|
||||
@@ -59,8 +72,12 @@ To install the application, run:
|
||||
|
||||
```bash
|
||||
yarn install
|
||||
cp .env_template .env
|
||||
cp config-template.ts config.ts
|
||||
```
|
||||
|
||||
Then, fill in the environment variables in `.env` and the configuration in `config.ts` as needed. If you are unsure how to proceed, ask in Zulip.
|
||||
|
||||
### Run the Application
|
||||
|
||||
To run the application in development mode, run:
|
||||
@@ -69,7 +86,7 @@ To run the application in development mode, run:
|
||||
yarn dev
|
||||
```
|
||||
|
||||
Then open [http://localhost:3000](http://localhost:3000) to view it in the browser.
|
||||
Then (after completing server setup and starting it) open [http://localhost:3000](http://localhost:3000) to view it in the browser.
|
||||
|
||||
### OpenAPI Code Generation
|
||||
|
||||
@@ -87,35 +104,76 @@ Start with `cd server`.
|
||||
|
||||
### Installation
|
||||
|
||||
Download [Python 3.11 from the official website](https://www.python.org/downloads/) and ensure you have version 3.11 by running `python --version`.
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
poetry install
|
||||
python --version # It should say 3.11
|
||||
pip install poetry
|
||||
poetry install --no-root
|
||||
cp .env_template .env
|
||||
```
|
||||
|
||||
Then create an `.env` with:
|
||||
Then fill `.env` with the omitted values (ask in Zulip). At the time of writing, the only omitted value is `AUTH_FIEF_CLIENT_SECRET`.
|
||||
|
||||
```
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=<omitted>
|
||||
### Start the API/Backend
|
||||
|
||||
LLM_BACKEND=modal
|
||||
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
|
||||
LLM_MODAL_API_KEY=<omitted>
|
||||
|
||||
AUTH_BACKEND=fief
|
||||
AUTH_FIEF_URL=https://auth.reflector.media/reflector-local
|
||||
AUTH_FIEF_CLIENT_ID=KQzRsNgoY<omitted>
|
||||
AUTH_FIEF_CLIENT_SECRET=<omitted>
|
||||
```
|
||||
|
||||
### Start the project
|
||||
|
||||
Use:
|
||||
Start the background worker:
|
||||
|
||||
```bash
|
||||
poetry run python3 -m reflector.app
|
||||
poetry run celery -A reflector.worker.app worker --loglevel=info
|
||||
```
|
||||
|
||||
### Redis (Mac)
|
||||
|
||||
```bash
|
||||
yarn add redis
|
||||
redis-server
|
||||
```
|
||||
|
||||
### Redis (Windows)
|
||||
|
||||
**Option 1**
|
||||
|
||||
```bash
|
||||
docker compose up -d redis
|
||||
```
|
||||
|
||||
**Option 2**
|
||||
|
||||
Install:
|
||||
- [Git for Windows](https://gitforwindows.org/)
|
||||
- [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl/install)
|
||||
- Install your preferred Linux distribution via the Microsoft Store (e.g., Ubuntu).
|
||||
|
||||
Open your Linux distribution and update the package list:
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install redis-server
|
||||
redis-server
|
||||
```
|
||||
|
||||
## Update the database schema (run on first install, and after each pull containing a migration)
|
||||
|
||||
```bash
|
||||
poetry run alembic heads
|
||||
```
|
||||
|
||||
## Main Server
|
||||
|
||||
Start the server:
|
||||
|
||||
```bash
|
||||
poetry run python -m reflector.app
|
||||
```
|
||||
|
||||
### Crontab (optional)
|
||||
|
||||
For crontab (only healthcheck for now), start the celery beat (you don't need it on your local dev environment):
|
||||
|
||||
```bash
|
||||
poetry run celery -A reflector.worker.app beat
|
||||
```
|
||||
|
||||
#### Using docker
|
||||
@@ -141,4 +199,5 @@ poetry run python -m reflector.tools.process path/to/audio.wav
|
||||
|
||||
## AI Models
|
||||
|
||||
*(Documentation for this section is pending.)*
|
||||
*(Documentation for this section is pending.)*
|
||||
|
||||
|
||||
211
aws/lambda-nodejs18.x/.gitignore
vendored
Normal file
211
aws/lambda-nodejs18.x/.gitignore
vendored
Normal file
@@ -0,0 +1,211 @@
|
||||
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/osx,node,linux,windows,sam
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=osx,node,linux,windows,sam
|
||||
|
||||
### Linux ###
|
||||
*~
|
||||
|
||||
# temporary files which can be created if a process still has a handle open of a deleted file
|
||||
.fuse_hidden*
|
||||
|
||||
# KDE directory preferences
|
||||
.directory
|
||||
|
||||
# Linux trash folder which might appear on any partition or disk
|
||||
.Trash-*
|
||||
|
||||
# .nfs files are created when an open file is removed but is still being accessed
|
||||
.nfs*
|
||||
|
||||
### Node ###
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
lerna-debug.log*
|
||||
|
||||
# Diagnostic reports (https://nodejs.org/api/report.html)
|
||||
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
|
||||
|
||||
# Runtime data
|
||||
pids
|
||||
*.pid
|
||||
*.seed
|
||||
*.pid.lock
|
||||
|
||||
# Directory for instrumented libs generated by jscoverage/JSCover
|
||||
lib-cov
|
||||
|
||||
# Coverage directory used by tools like istanbul
|
||||
coverage
|
||||
*.lcov
|
||||
|
||||
# nyc test coverage
|
||||
.nyc_output
|
||||
|
||||
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
|
||||
.grunt
|
||||
|
||||
# Bower dependency directory (https://bower.io/)
|
||||
bower_components
|
||||
|
||||
# node-waf configuration
|
||||
.lock-wscript
|
||||
|
||||
# Compiled binary addons (https://nodejs.org/api/addons.html)
|
||||
build/Release
|
||||
|
||||
# Dependency directories
|
||||
node_modules/
|
||||
jspm_packages/
|
||||
|
||||
# TypeScript v1 declaration files
|
||||
typings/
|
||||
|
||||
# TypeScript cache
|
||||
*.tsbuildinfo
|
||||
|
||||
# Optional npm cache directory
|
||||
.npm
|
||||
|
||||
# Optional eslint cache
|
||||
.eslintcache
|
||||
|
||||
# Optional stylelint cache
|
||||
.stylelintcache
|
||||
|
||||
# Microbundle cache
|
||||
.rpt2_cache/
|
||||
.rts2_cache_cjs/
|
||||
.rts2_cache_es/
|
||||
.rts2_cache_umd/
|
||||
|
||||
# Optional REPL history
|
||||
.node_repl_history
|
||||
|
||||
# Output of 'npm pack'
|
||||
*.tgz
|
||||
|
||||
# Yarn Integrity file
|
||||
.yarn-integrity
|
||||
|
||||
# dotenv environment variables file
|
||||
.env
|
||||
.env.test
|
||||
.env*.local
|
||||
|
||||
# parcel-bundler cache (https://parceljs.org/)
|
||||
.cache
|
||||
.parcel-cache
|
||||
|
||||
# Next.js build output
|
||||
.next
|
||||
|
||||
# Nuxt.js build / generate output
|
||||
.nuxt
|
||||
dist
|
||||
|
||||
# Storybook build outputs
|
||||
.out
|
||||
.storybook-out
|
||||
storybook-static
|
||||
|
||||
# rollup.js default build output
|
||||
dist/
|
||||
|
||||
# Gatsby files
|
||||
.cache/
|
||||
# Comment in the public line in if your project uses Gatsby and not Next.js
|
||||
# https://nextjs.org/blog/next-9-1#public-directory-support
|
||||
# public
|
||||
|
||||
# vuepress build output
|
||||
.vuepress/dist
|
||||
|
||||
# Serverless directories
|
||||
.serverless/
|
||||
|
||||
# FuseBox cache
|
||||
.fusebox/
|
||||
|
||||
# DynamoDB Local files
|
||||
.dynamodb/
|
||||
|
||||
# TernJS port file
|
||||
.tern-port
|
||||
|
||||
# Stores VSCode versions used for testing VSCode extensions
|
||||
.vscode-test
|
||||
|
||||
# Temporary folders
|
||||
tmp/
|
||||
temp/
|
||||
|
||||
### OSX ###
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
### SAM ###
|
||||
# Ignore build directories for the AWS Serverless Application Model (SAM)
|
||||
# Info: https://aws.amazon.com/serverless/sam/
|
||||
# Docs: https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-reference.html
|
||||
|
||||
**/.aws-sam
|
||||
|
||||
### Windows ###
|
||||
# Windows thumbnail cache files
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
|
||||
.aws-sam
|
||||
|
||||
UpdateZulipStreams/node_modules
|
||||
|
||||
# Dump file
|
||||
*.stackdump
|
||||
|
||||
# Folder config file
|
||||
[Dd]esktop.ini
|
||||
|
||||
# Recycle Bin used on file shares
|
||||
$RECYCLE.BIN/
|
||||
|
||||
# Windows Installer files
|
||||
*.cab
|
||||
*.msi
|
||||
*.msix
|
||||
*.msm
|
||||
*.msp
|
||||
|
||||
# Windows shortcuts
|
||||
*.lnk
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/osx,node,linux,windows,sam
|
||||
38
aws/lambda-nodejs18.x/README.TOOLKIT.md
Normal file
38
aws/lambda-nodejs18.x/README.TOOLKIT.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Developing AWS SAM Applications with the AWS Toolkit For Visual Studio Code
|
||||
|
||||
This project contains source code and supporting files for a serverless application that you can locally run, debug, and deploy to AWS with the AWS Toolkit For Visual Studio Code.
|
||||
|
||||
A "SAM" (serverless application model) project is a project that contains a template.yaml file which is understood by AWS tooling (such as SAM CLI, and the AWS Toolkit For Visual Studio Code).
|
||||
|
||||
## Writing and Debugging Serverless Applications
|
||||
|
||||
The code for this application will differ based on the runtime, but the path to a handler can be found in the [`template.yaml`](./template.yaml) file through a resource's `CodeUri` and `Handler` fields.
|
||||
|
||||
AWS Toolkit For Visual Studio Code supports local debugging for serverless applications through VS Code's debugger. Since this application was created by the AWS Toolkit, launch configurations for all included handlers have been generated and can be found in the menu next to the Run button:
|
||||
|
||||
* lambda-nodejs18.x:HelloWorldFunction (nodejs18.x)
|
||||
* API lambda-nodejs18.x:HelloWorldFunction (nodejs18.x)
|
||||
|
||||
You can debug the Lambda handlers locally by adding a breakpoint to the source file, then running the launch configuration. This works by using Docker on your local machine.
|
||||
|
||||
Invocation parameters, including payloads and request parameters, can be edited either by the `Edit SAM Debug Configuration` command (through the Command Palette or CodeLens) or by editing the `launch.json` file.
|
||||
|
||||
AWS Lambda functions not defined in the [`template.yaml`](./template.yaml) file can be invoked and debugged by creating a launch configuration through the CodeLens over the function declaration, or with the `Add SAM Debug Configuration` command.
|
||||
|
||||
## Deploying Serverless Applications
|
||||
|
||||
You can deploy a serverless application by invoking the `AWS: Deploy SAM application` command through the Command Palette or by right-clicking the Lambda node in the AWS Explorer and entering the deployment region, a valid S3 bucket from the region, and the name of a CloudFormation stack to deploy to. You can monitor your deployment's progress through the `AWS Toolkit` Output Channel.
|
||||
|
||||
## Interacting With Deployed Serverless Applications
|
||||
|
||||
A successfully deployed serverless application can be found in the AWS Explorer under the region and CloudFormation node that the serverless application was deployed to.
|
||||
|
||||
In the AWS Explorer, you can invoke _remote_ AWS Lambda Functions by right-clicking the Lambda node and selecting "Invoke on AWS".
|
||||
|
||||
Similarly, if the Function declaration contained an API Gateway event, the API Gateway API can be found in the API Gateway node under the region node the serverless application was deployed to, and can be invoked via right-clicking the API node and selecting "Invoke on AWS".
|
||||
|
||||
## Resources
|
||||
|
||||
General information about this SAM project can be found in the [`README.md`](./README.md) file in this folder.
|
||||
|
||||
More information about using the AWS Toolkit For Visual Studio Code with serverless applications can be found [in the AWS documentation](https://docs.aws.amazon.com/toolkit-for-vscode/latest/userguide/serverless-apps.html) .
|
||||
127
aws/lambda-nodejs18.x/README.md
Normal file
127
aws/lambda-nodejs18.x/README.md
Normal file
@@ -0,0 +1,127 @@
|
||||
# lambda-nodejs18.x
|
||||
|
||||
This project contains source code and supporting files for a serverless application that you can deploy with the SAM CLI. It includes the following files and folders.
|
||||
|
||||
- hello-world - Code for the application's Lambda function.
|
||||
- events - Invocation events that you can use to invoke the function.
|
||||
- hello-world/tests - Unit tests for the application code.
|
||||
- template.yaml - A template that defines the application's AWS resources.
|
||||
|
||||
The application uses several AWS resources, including Lambda functions and an API Gateway API. These resources are defined in the `template.yaml` file in this project. You can update the template to add AWS resources through the same deployment process that updates your application code.
|
||||
|
||||
If you prefer to use an integrated development environment (IDE) to build and test your application, you can use the AWS Toolkit.
|
||||
The AWS Toolkit is an open source plug-in for popular IDEs that uses the SAM CLI to build and deploy serverless applications on AWS. The AWS Toolkit also adds a simplified step-through debugging experience for Lambda function code. See the following links to get started.
|
||||
|
||||
* [CLion](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
|
||||
* [GoLand](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
|
||||
* [IntelliJ](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
|
||||
* [WebStorm](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
|
||||
* [Rider](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
|
||||
* [PhpStorm](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
|
||||
* [PyCharm](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
|
||||
* [RubyMine](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
|
||||
* [DataGrip](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
|
||||
* [VS Code](https://docs.aws.amazon.com/toolkit-for-vscode/latest/userguide/welcome.html)
|
||||
* [Visual Studio](https://docs.aws.amazon.com/toolkit-for-visual-studio/latest/user-guide/welcome.html)
|
||||
|
||||
## Deploy the sample application
|
||||
|
||||
The Serverless Application Model Command Line Interface (SAM CLI) is an extension of the AWS CLI that adds functionality for building and testing Lambda applications. It uses Docker to run your functions in an Amazon Linux environment that matches Lambda. It can also emulate your application's build environment and API.
|
||||
|
||||
To use the SAM CLI, you need the following tools.
|
||||
|
||||
* SAM CLI - [Install the SAM CLI](https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-cli-install.html)
|
||||
* Node.js - [Install Node.js 18](https://nodejs.org/en/), including the NPM package management tool.
|
||||
* Docker - [Install Docker community edition](https://hub.docker.com/search/?type=edition&offering=community)
|
||||
|
||||
To build and deploy your application for the first time, run the following in your shell:
|
||||
|
||||
```bash
|
||||
sam build
|
||||
sam deploy --guided
|
||||
```
|
||||
|
||||
The first command will build the source of your application. The second command will package and deploy your application to AWS, with a series of prompts:
|
||||
|
||||
* **Stack Name**: The name of the stack to deploy to CloudFormation. This should be unique to your account and region, and a good starting point would be something matching your project name.
|
||||
* **AWS Region**: The AWS region you want to deploy your app to.
|
||||
* **Confirm changes before deploy**: If set to yes, any change sets will be shown to you before execution for manual review. If set to no, the AWS SAM CLI will automatically deploy application changes.
|
||||
* **Allow SAM CLI IAM role creation**: Many AWS SAM templates, including this example, create AWS IAM roles required for the AWS Lambda function(s) included to access AWS services. By default, these are scoped down to minimum required permissions. To deploy an AWS CloudFormation stack which creates or modifies IAM roles, the `CAPABILITY_IAM` value for `capabilities` must be provided. If permission isn't provided through this prompt, to deploy this example you must explicitly pass `--capabilities CAPABILITY_IAM` to the `sam deploy` command.
|
||||
* **Save arguments to samconfig.toml**: If set to yes, your choices will be saved to a configuration file inside the project, so that in the future you can just re-run `sam deploy` without parameters to deploy changes to your application.
|
||||
|
||||
You can find your API Gateway Endpoint URL in the output values displayed after deployment.
|
||||
|
||||
## Use the SAM CLI to build and test locally
|
||||
|
||||
Build your application with the `sam build` command.
|
||||
|
||||
```bash
|
||||
lambda-nodejs18.x$ sam build
|
||||
```
|
||||
|
||||
The SAM CLI installs dependencies defined in `hello-world/package.json`, creates a deployment package, and saves it in the `.aws-sam/build` folder.
|
||||
|
||||
Test a single function by invoking it directly with a test event. An event is a JSON document that represents the input that the function receives from the event source. Test events are included in the `events` folder in this project.
|
||||
|
||||
Run functions locally and invoke them with the `sam local invoke` command.
|
||||
|
||||
```bash
|
||||
lambda-nodejs18.x$ sam local invoke HelloWorldFunction --event events/event.json
|
||||
```
|
||||
|
||||
The SAM CLI can also emulate your application's API. Use the `sam local start-api` command to run the API locally on port 3000.
|
||||
|
||||
```bash
|
||||
lambda-nodejs18.x$ sam local start-api
|
||||
lambda-nodejs18.x$ curl http://localhost:3000/
|
||||
```
|
||||
|
||||
The SAM CLI reads the application template to determine the API's routes and the functions that they invoke. The `Events` property on each function's definition includes the route and method for each path.
|
||||
|
||||
```yaml
|
||||
Events:
|
||||
HelloWorld:
|
||||
Type: Api
|
||||
Properties:
|
||||
Path: /hello
|
||||
Method: get
|
||||
```
|
||||
|
||||
## Add a resource to your application
|
||||
The application template uses AWS Serverless Application Model (AWS SAM) to define application resources. AWS SAM is an extension of AWS CloudFormation with a simpler syntax for configuring common serverless application resources such as functions, triggers, and APIs. For resources not included in [the SAM specification](https://github.com/awslabs/serverless-application-model/blob/master/versions/2016-10-31.md), you can use standard [AWS CloudFormation](https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-template-resource-type-ref.html) resource types.
|
||||
|
||||
## Fetch, tail, and filter Lambda function logs
|
||||
|
||||
To simplify troubleshooting, SAM CLI has a command called `sam logs`. `sam logs` lets you fetch logs generated by your deployed Lambda function from the command line. In addition to printing the logs on the terminal, this command has several nifty features to help you quickly find the bug.
|
||||
|
||||
`NOTE`: This command works for all AWS Lambda functions; not just the ones you deploy using SAM.
|
||||
|
||||
```bash
|
||||
lambda-nodejs18.x$ sam logs -n HelloWorldFunction --stack-name lambda-nodejs18.x --tail
|
||||
```
|
||||
|
||||
You can find more information and examples about filtering Lambda function logs in the [SAM CLI Documentation](https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-cli-logging.html).
|
||||
|
||||
## Unit tests
|
||||
|
||||
Tests are defined in the `hello-world/tests` folder in this project. Use NPM to install the [Mocha test framework](https://mochajs.org/) and run unit tests.
|
||||
|
||||
```bash
|
||||
lambda-nodejs18.x$ cd hello-world
|
||||
hello-world$ npm install
|
||||
hello-world$ npm run test
|
||||
```
|
||||
|
||||
## Cleanup
|
||||
|
||||
To delete the sample application that you created, use the AWS CLI. Assuming you used your project name for the stack name, you can run the following:
|
||||
|
||||
```bash
|
||||
sam delete --stack-name lambda-nodejs18.x
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
See the [AWS SAM developer guide](https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/what-is-sam.html) for an introduction to SAM specification, the SAM CLI, and serverless application concepts.
|
||||
|
||||
Next, you can use AWS Serverless Application Repository to deploy ready to use Apps that go beyond hello world samples and learn how authors developed their applications: [AWS Serverless Application Repository main page](https://aws.amazon.com/serverless/serverlessrepo/)
|
||||
71
aws/lambda-nodejs18.x/UpdateZulipStreams/index.js
Normal file
71
aws/lambda-nodejs18.x/UpdateZulipStreams/index.js
Normal file
@@ -0,0 +1,71 @@
|
||||
|
||||
|
||||
const AWS = require('aws-sdk');
|
||||
const s3 = new AWS.S3();
|
||||
const axios = require('axios');
|
||||
|
||||
/**
 * Fetch the names of the topics in a single Zulip stream.
 *
 * The request goes to the realm named by the ZULIP_REALM environment
 * variable and uses HTTP basic auth built from ZULIP_BOT_EMAIL and
 * ZULIP_API_KEY (each falls back to "?" when unset).
 *
 * @param stream_id Identifier of the stream whose topics are requested.
 * @returns A promise resolving to an array holding the `name` of every
 *          entry in the response's `topics` list.
 */
async function getTopics(stream_id) {
    const topicsUrl = `https://${process.env.ZULIP_REALM}/api/v1/users/me/${stream_id}/topics`;
    const auth = {
        username: process.env.ZULIP_BOT_EMAIL || "?",
        password: process.env.ZULIP_API_KEY || "?"
    };

    const { data } = await axios.get(topicsUrl, { auth });
    return data.topics.map(({ name }) => name);
}
|
||||
|
||||
/**
 * Fetch every stream returned by the Zulip realm together with its topics.
 *
 * Streams are processed one at a time: the topics of a stream are fully
 * loaded (via getTopics) before the next stream is requested.
 *
 * @returns A promise resolving to an array of `{ id, name, topics }`
 *          objects, in the order the streams appear in the response.
 */
async function getStreams() {
    const streamsUrl = `https://${process.env.ZULIP_REALM}/api/v1/streams`;
    const auth = {
        username: process.env.ZULIP_BOT_EMAIL || "?",
        password: process.env.ZULIP_API_KEY || "?"
    };
    const { data } = await axios.get(streamsUrl, { auth });

    const collected = [];
    for (const { stream_id: id, name } of data.streams) {
        console.log("Loading topics for " + name);
        // Awaited inside the loop on purpose: requests stay sequential.
        const topics = await getTopics(id);
        collected.push({ id, name, topics });
    }
    return collected;
}
|
||||
|
||||
|
||||
/**
 * Lambda entry point: snapshot all Zulip streams and their topics as a
 * JSON document and upload it to S3.
 *
 * The destination is read from the S3BUCKET_NAME and S3BUCKET_FILE_NAME
 * environment variables. Failures while loading the streams are not
 * caught here and propagate to the caller; only the S3 upload is guarded.
 *
 * @param event Invocation event (not used).
 * @returns `{ statusCode: 200, ... }` when the upload succeeds, or
 *          `{ statusCode: 500, ... }` when writing to S3 fails.
 */
const handler = async (event) => {
    // Serialise the full stream/topic listing before touching S3.
    const payload = JSON.stringify(await getStreams());

    const uploadRequest = {
        Bucket: process.env.S3BUCKET_NAME,
        Key: process.env.S3BUCKET_FILE_NAME,
        Body: payload,
        ContentType: 'application/json'
    };

    try {
        await s3.putObject(uploadRequest).promise();
    } catch (uploadError) {
        console.error('Error writing to S3:', uploadError);
        return {
            statusCode: 500,
            body: JSON.stringify('Error writing file to S3')
        };
    }

    return {
        statusCode: 200,
        body: JSON.stringify('File written to S3 successfully')
    };
};
|
||||
|
||||
module.exports = { handler };
|
||||
1462
aws/lambda-nodejs18.x/UpdateZulipStreams/package-lock.json
generated
Normal file
1462
aws/lambda-nodejs18.x/UpdateZulipStreams/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
16
aws/lambda-nodejs18.x/UpdateZulipStreams/package.json
Normal file
16
aws/lambda-nodejs18.x/UpdateZulipStreams/package.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"name": "updatezulipstreams",
|
||||
"version": "1.0.0",
|
||||
"description": "Updates the JSON with the zulip streams and topics on S3",
|
||||
"main": "index.js",
|
||||
"author": "Andreas Bonini",
|
||||
"license": "All rights reserved",
|
||||
"dependencies": {
|
||||
"aws-sdk": "^2.1498.0",
|
||||
"axios": "^1.6.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"chai": "^4.3.6",
|
||||
"mocha": "^10.1.0"
|
||||
}
|
||||
}
|
||||
62
aws/lambda-nodejs18.x/events/event.json
Normal file
62
aws/lambda-nodejs18.x/events/event.json
Normal file
@@ -0,0 +1,62 @@
|
||||
{
|
||||
"body": "{\"message\": \"hello world\"}",
|
||||
"resource": "/{proxy+}",
|
||||
"path": "/path/to/resource",
|
||||
"httpMethod": "POST",
|
||||
"isBase64Encoded": false,
|
||||
"queryStringParameters": {
|
||||
"foo": "bar"
|
||||
},
|
||||
"pathParameters": {
|
||||
"proxy": "/path/to/resource"
|
||||
},
|
||||
"stageVariables": {
|
||||
"baz": "qux"
|
||||
},
|
||||
"headers": {
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
|
||||
"Accept-Encoding": "gzip, deflate, sdch",
|
||||
"Accept-Language": "en-US,en;q=0.8",
|
||||
"Cache-Control": "max-age=0",
|
||||
"CloudFront-Forwarded-Proto": "https",
|
||||
"CloudFront-Is-Desktop-Viewer": "true",
|
||||
"CloudFront-Is-Mobile-Viewer": "false",
|
||||
"CloudFront-Is-SmartTV-Viewer": "false",
|
||||
"CloudFront-Is-Tablet-Viewer": "false",
|
||||
"CloudFront-Viewer-Country": "US",
|
||||
"Host": "1234567890.execute-api.us-east-1.amazonaws.com",
|
||||
"Upgrade-Insecure-Requests": "1",
|
||||
"User-Agent": "Custom User Agent String",
|
||||
"Via": "1.1 08f323deadbeefa7af34d5feb414ce27.cloudfront.net (CloudFront)",
|
||||
"X-Amz-Cf-Id": "cDehVQoZnx43VYQb9j2-nvCh-9z396Uhbp027Y2JvkCPNLmGJHqlaA==",
|
||||
"X-Forwarded-For": "127.0.0.1, 127.0.0.2",
|
||||
"X-Forwarded-Port": "443",
|
||||
"X-Forwarded-Proto": "https"
|
||||
},
|
||||
"requestContext": {
|
||||
"accountId": "123456789012",
|
||||
"resourceId": "123456",
|
||||
"stage": "prod",
|
||||
"requestId": "c6af9ac6-7b61-11e6-9a41-93e8deadbeef",
|
||||
"requestTime": "09/Apr/2015:12:34:56 +0000",
|
||||
"requestTimeEpoch": 1428582896000,
|
||||
"identity": {
|
||||
"cognitoIdentityPoolId": null,
|
||||
"accountId": null,
|
||||
"cognitoIdentityId": null,
|
||||
"caller": null,
|
||||
"accessKey": null,
|
||||
"sourceIp": "127.0.0.1",
|
||||
"cognitoAuthenticationType": null,
|
||||
"cognitoAuthenticationProvider": null,
|
||||
"userArn": null,
|
||||
"userAgent": "Custom User Agent String",
|
||||
"user": null
|
||||
},
|
||||
"path": "/prod/path/to/resource",
|
||||
"resourcePath": "/{proxy+}",
|
||||
"httpMethod": "POST",
|
||||
"apiId": "1234567890",
|
||||
"protocol": "HTTP/1.1"
|
||||
}
|
||||
}
|
||||
31
aws/lambda-nodejs18.x/samconfig.toml
Normal file
31
aws/lambda-nodejs18.x/samconfig.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
# More information about the configuration file can be found here:
|
||||
# https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-cli-config.html
|
||||
version = 0.1
|
||||
|
||||
[default]
|
||||
[default.global.parameters]
|
||||
stack_name = "lambda-nodejs18.x"
|
||||
|
||||
[default.build.parameters]
|
||||
cached = true
|
||||
parallel = true
|
||||
|
||||
[default.validate.parameters]
|
||||
lint = true
|
||||
|
||||
[default.deploy.parameters]
|
||||
capabilities = "CAPABILITY_IAM"
|
||||
confirm_changeset = true
|
||||
resolve_s3 = true
|
||||
|
||||
[default.package.parameters]
|
||||
resolve_s3 = true
|
||||
|
||||
[default.sync.parameters]
|
||||
watch = true
|
||||
|
||||
[default.local_start_api.parameters]
|
||||
warm_containers = "EAGER"
|
||||
|
||||
[default.local_start_lambda.parameters]
|
||||
warm_containers = "EAGER"
|
||||
41
aws/lambda-nodejs18.x/template.yaml
Normal file
41
aws/lambda-nodejs18.x/template.yaml
Normal file
@@ -0,0 +1,41 @@
|
||||
AWSTemplateFormatVersion: '2010-09-09'
|
||||
Transform: AWS::Serverless-2016-10-31
|
||||
Description: >
|
||||
lambda-nodejs18.x
|
||||
|
||||
Sample SAM Template for lambda-nodejs18.x
|
||||
|
||||
# More info about Globals: https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst
|
||||
Globals:
|
||||
Function:
|
||||
Timeout: 3
|
||||
|
||||
Resources:
|
||||
HelloWorldFunction:
|
||||
Type: AWS::Serverless::Function # More info about Function Resource: https://github.com/awslabs/serverless-application-model/blob/master/versions/2016-10-31.md#awsserverlessfunction
|
||||
Properties:
|
||||
CodeUri: hello-world/
|
||||
Handler: app.lambdaHandler
|
||||
Runtime: nodejs18.x
|
||||
Architectures:
|
||||
- arm64
|
||||
Events:
|
||||
HelloWorld:
|
||||
Type: Api # More info about API Event Source: https://github.com/awslabs/serverless-application-model/blob/master/versions/2016-10-31.md#api
|
||||
Properties:
|
||||
Path: /hello
|
||||
Method: get
|
||||
|
||||
Outputs:
|
||||
# ServerlessRestApi is an implicit API created out of Events key under Serverless::Function
|
||||
# Find out more about other implicit resources you can reference within SAM
|
||||
# https://github.com/awslabs/serverless-application-model/blob/master/docs/internals/generated_resources.rst#api
|
||||
HelloWorldApi:
|
||||
Description: "API Gateway endpoint URL for Prod stage for Hello World function"
|
||||
Value: !Sub "https://${ServerlessRestApi}.execute-api.${AWS::Region}.amazonaws.com/Prod/hello/"
|
||||
HelloWorldFunction:
|
||||
Description: "Hello World Lambda Function ARN"
|
||||
Value: !GetAtt HelloWorldFunction.Arn
|
||||
HelloWorldFunctionIamRole:
|
||||
Description: "Implicit IAM Role created for Hello World function"
|
||||
Value: !GetAtt HelloWorldFunctionRole.Arn
|
||||
@@ -5,10 +5,19 @@ services:
|
||||
context: server
|
||||
ports:
|
||||
- 1250:1250
|
||||
environment:
|
||||
LLM_URL: "${LLM_URL}"
|
||||
volumes:
|
||||
- model-cache:/root/.cache
|
||||
environment: ENTRYPOINT=server
|
||||
worker:
|
||||
build:
|
||||
context: server
|
||||
volumes:
|
||||
- model-cache:/root/.cache
|
||||
environment: ENTRYPOINT=worker
|
||||
redis:
|
||||
image: redis:7.2
|
||||
ports:
|
||||
- 6379:6379
|
||||
web:
|
||||
build:
|
||||
context: www
|
||||
@@ -17,4 +26,3 @@ services:
|
||||
|
||||
volumes:
|
||||
model-cache:
|
||||
|
||||
|
||||
20
server/.env_template
Normal file
20
server/.env_template
Normal file
@@ -0,0 +1,20 @@
|
||||
TRANSCRIPT_BACKEND=modal
|
||||
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
|
||||
TRANSCRIPT_MODAL_API_KEY=***REMOVED***
|
||||
|
||||
LLM_BACKEND=modal
|
||||
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
|
||||
LLM_MODAL_API_KEY=<ask in zulip>
|
||||
|
||||
AUTH_BACKEND=fief
|
||||
AUTH_FIEF_URL=https://auth.reflector.media/reflector-local
|
||||
AUTH_FIEF_CLIENT_ID=***REMOVED***
|
||||
AUTH_FIEF_CLIENT_SECRET=<ask in zulip> <-----------------------------------------------------------------------------------------
|
||||
|
||||
TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
|
||||
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
|
||||
DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
|
||||
|
||||
BASE_URL=https://xxxxx.ngrok.app
|
||||
DIARIZATION_ENABLED=false
|
||||
|
||||
2
server/.gitignore
vendored
2
server/.gitignore
vendored
@@ -178,3 +178,5 @@ audio_*.wav
|
||||
# ignore local database
|
||||
reflector.sqlite3
|
||||
data/
|
||||
|
||||
dump.rdb
|
||||
|
||||
@@ -1 +1 @@
|
||||
3.11
|
||||
3.11.6
|
||||
|
||||
@@ -5,11 +5,23 @@ services:
|
||||
context: .
|
||||
ports:
|
||||
- 1250:1250
|
||||
environment:
|
||||
LLM_URL: "${LLM_URL}"
|
||||
MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
|
||||
volumes:
|
||||
- model-cache:/root/.cache
|
||||
environment:
|
||||
ENTRYPOINT: server
|
||||
REDIS_HOST: redis
|
||||
worker:
|
||||
build:
|
||||
context: .
|
||||
volumes:
|
||||
- model-cache:/root/.cache
|
||||
environment:
|
||||
ENTRYPOINT: worker
|
||||
REDIS_HOST: redis
|
||||
redis:
|
||||
image: redis:7.2
|
||||
ports:
|
||||
- 6379:6379
|
||||
|
||||
volumes:
|
||||
model-cache:
|
||||
|
||||
@@ -51,17 +51,6 @@
|
||||
#TRANSLATE_URL=https://xxxxx--reflector-translator-web.modal.run
|
||||
#TRANSCRIPT_MODAL_API_KEY=xxxxx
|
||||
|
||||
## Using serverless banana.dev (require reflector-gpu-banana deployed)
|
||||
## XXX this service is buggy do not use at the moment
|
||||
## XXX it also require the audio to be saved to S3
|
||||
#TRANSCRIPT_BACKEND=banana
|
||||
#TRANSCRIPT_URL=https://reflector-gpu-banana-xxxxx.run.banana.dev
|
||||
#TRANSCRIPT_BANANA_API_KEY=xxx
|
||||
#TRANSCRIPT_BANANA_MODEL_KEY=xxx
|
||||
#TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID=xxx
|
||||
#TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY=xxx
|
||||
#TRANSCRIPT_STORAGE_AWS_BUCKET_NAME="reflector-bucket/chunks"
|
||||
|
||||
## =======================================================
|
||||
## LLM backend
|
||||
##
|
||||
@@ -78,13 +67,6 @@
|
||||
#LLM_URL=https://xxxxxx--reflector-llm-web.modal.run
|
||||
#LLM_MODAL_API_KEY=xxx
|
||||
|
||||
## Using serverless banana.dev (require reflector-gpu-banana deployed)
|
||||
## XXX this service is buggy do not use at the moment
|
||||
#LLM_BACKEND=banana
|
||||
#LLM_URL=https://reflector-gpu-banana-xxxxx.run.banana.dev
|
||||
#LLM_BANANA_API_KEY=xxxxx
|
||||
#LLM_BANANA_MODEL_KEY=xxxxx
|
||||
|
||||
## Using OpenAI
|
||||
#LLM_BACKEND=openai
|
||||
#LLM_OPENAI_KEY=xxx
|
||||
|
||||
188
server/gpu/modal/reflector_diarizer.py
Normal file
188
server/gpu/modal/reflector_diarizer.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Reflector GPU backend - diarizer
|
||||
===================================
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import modal.gpu
|
||||
from modal import Image, Secret, Stub, asgi_app, method
|
||||
from pydantic import BaseModel
|
||||
|
||||
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.0"
|
||||
MODEL_DIR = "/root/diarization_models"
|
||||
|
||||
stub = Stub(name="reflector-diarizer")
|
||||
|
||||
|
||||
def migrate_cache_llm():
    """
    Run the Transformers model-cache migration in place on MODEL_DIR.

    Transformers v4.22.0 changed the layout of its model-file cache and
    performs a one-time migration of old caches; that migration can also be
    triggered explicitly with `transformers.utils.move_cache()`, which is
    what this function does (source and destination are both MODEL_DIR).
    """
    from transformers.utils import hub

    print("Moving LLM cache")
    hub.move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
    print("LLM cache moved")
|
||||
|
||||
|
||||
def download_pyannote_audio():
    """
    Download the pyannote diarization pipeline into MODEL_DIR.

    Loading the pipeline once with `cache_dir=MODEL_DIR` populates that
    directory with the model files; the returned pipeline object is discarded.
    """
    from pyannote.audio import Pipeline

    # NOTE(review): the access token is a hard-coded literal; it should
    # presumably be supplied through a secret / environment variable instead.
    Pipeline.from_pretrained(
        # Use the module constant so the downloaded model cannot drift from
        # the model name declared at the top of the file.
        PYANNOTE_MODEL_NAME,
        cache_dir=MODEL_DIR,
        use_auth_token="***REMOVED***"
    )
|
||||
|
||||
|
||||
# Container image used by both the GPU diarizer class and the web endpoint.
# It is assembled step by step: base image -> Python packages -> build-time
# functions (cache migration, model download) -> environment variables.
diarizer_image = Image.debian_slim(python_version="3.10.8")

diarizer_image = diarizer_image.pip_install(
    "pyannote.audio",
    "requests",
    "onnx",
    "torchaudio",
    "onnxruntime-gpu",
    "torch==2.0.0",
    "transformers==4.34.0",
    "sentencepiece",
    "protobuf",
    "numpy",
    "huggingface_hub",
    "hf-transfer"
)

# Run the cache migration first, then fetch the pyannote model files.
diarizer_image = diarizer_image.run_function(migrate_cache_llm)
diarizer_image = diarizer_image.run_function(download_pyannote_audio)

# Expose the cudnn / cublas library directories to the dynamic loader.
diarizer_image = diarizer_image.env(
    {
        "LD_LIBRARY_PATH": (
            "/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
            "/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
        )
    }
)
|
||||
|
||||
|
||||
@stub.cls(
    gpu=modal.gpu.A100(memory=40),
    timeout=60 * 30,
    container_idle_timeout=60,
    allow_concurrent_inputs=1,
    image=diarizer_image,
)
class Diarizer:
    """
    Modal class that runs the pyannote speaker-diarization pipeline.

    The pipeline is loaded once per container in `__enter__` and reused by
    every `diarize` call handled by that container.
    """

    def __enter__(self):
        """Load the diarization pipeline from MODEL_DIR and move it to the GPU when one is available."""
        import torch
        from pyannote.audio import Pipeline

        self.use_gpu = torch.cuda.is_available()
        self.device = "cuda" if self.use_gpu else "cpu"
        self.diarization_pipeline = Pipeline.from_pretrained(
            PYANNOTE_MODEL_NAME,
            cache_dir=MODEL_DIR
        )
        self.diarization_pipeline.to(torch.device(self.device))

    @method()
    def diarize(
        self,
        audio_data: bytes,
        audio_suffix: str,
        timestamp: float
    ):
        """
        Diarize an audio file and return the speaker segments.

        Args:
            audio_data: raw content of the audio file (written to a binary
                temporary file, so it must be bytes).
            audio_suffix: file extension, without the dot, used for the
                temporary file (e.g. "mp3").
            timestamp: offset in seconds added to every segment boundary.

        Returns:
            {"diarization": [{"start": float, "end": float, "speaker": int}, ...]}
            with boundaries rounded to 3 decimals.
        """
        import tempfile

        import torchaudio

        with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
            fp.write(audio_data)
            # torchaudio reopens the file by name: flush so that every buffered
            # byte is on disk, otherwise the tail of the audio can be missing.
            fp.flush()

            print("Diarizing audio")
            waveform, sample_rate = torchaudio.load(fp.name)
            diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate})

            words = []
            for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
                words.append(
                    {
                        "start": round(timestamp + diarization_segment.start, 3),
                        "end": round(timestamp + diarization_segment.end, 3),
                        # The speaker number is parsed from the last two characters of the
                        # label (presumably labels look like "SPEAKER_00" - TODO confirm).
                        "speaker": int(speaker[-2:])
                    }
                )
            print("Diarization complete")
            return {
                "diarization": words
            }
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Web API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
@stub.function(
    timeout=60 * 10,
    container_idle_timeout=60 * 3,
    allow_concurrent_inputs=40,
    secrets=[
        Secret.from_name("reflector-gpu"),
    ],
    image=diarizer_image
)
@asgi_app()
def web():
    """
    Build the FastAPI application exposing the diarizer over HTTP.

    The app has a single route, POST /diarize, protected by a bearer API key
    (compared against the REFLECTOR_GPU_APIKEY environment variable) and by a
    HEAD check that rejects audio URLs answering 404.
    """
    import requests
    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.security import OAuth2PasswordBearer

    gpu_diarizer = Diarizer()
    api = FastAPI()
    bearer_scheme = OAuth2PasswordBearer(tokenUrl="token")

    def apikey_auth(apikey: str = Depends(bearer_scheme)):
        """Reject the request with 401 unless the bearer token equals the configured API key."""
        if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
            return
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )

    def validate_audio_file(audio_file_url: str):
        """Issue a HEAD request on the audio URL and raise a 404 error if the file is missing."""
        head = requests.head(audio_file_url, allow_redirects=True)
        if head.status_code != 404:
            return
        raise HTTPException(
            status_code=head.status_code,
            detail="The audio file does not exist."
        )

    class DiarizationResponse(BaseModel):
        result: dict

    @api.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)])
    def diarize(
        audio_file_url: str,
        timestamp: float = 0.0
    ) -> HTTPException | DiarizationResponse:
        """Download the audio at `audio_file_url`, run it through the GPU diarizer and return its result."""
        print("Downloading audio file")
        download = requests.get(audio_file_url, allow_redirects=True)
        print("Audio file downloaded successfully")

        # Currently the uploaded files are in mp3 format
        call = gpu_diarizer.diarize.spawn(
            audio_data=download.content,
            audio_suffix="mp3",
            timestamp=timestamp
        )
        return call.get()

    return api
|
||||
@@ -81,7 +81,8 @@ class LLM:
|
||||
LLM_MODEL,
|
||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
||||
cache_dir=IMAGE_MODEL_DIR
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# JSONFormer doesn't yet support generation configs
|
||||
@@ -96,7 +97,8 @@ class LLM:
|
||||
print("Instance llm tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
LLM_MODEL,
|
||||
cache_dir=IMAGE_MODEL_DIR
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# move model to gpu
|
||||
|
||||
@@ -17,7 +17,7 @@ LLM_LOW_CPU_MEM_USAGE: bool = True
|
||||
LLM_TORCH_DTYPE: str = "bfloat16"
|
||||
LLM_MAX_NEW_TOKENS: int = 300
|
||||
|
||||
IMAGE_MODEL_DIR = "/root/llm_models"
|
||||
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"
|
||||
|
||||
stub = Stub(name="reflector-llm-zephyr")
|
||||
|
||||
@@ -81,7 +81,8 @@ class LLM:
|
||||
LLM_MODEL,
|
||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
||||
cache_dir=IMAGE_MODEL_DIR
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# JSONFormer doesn't yet support generation configs
|
||||
@@ -96,7 +97,8 @@ class LLM:
|
||||
print("Instance llm tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
LLM_MODEL,
|
||||
cache_dir=IMAGE_MODEL_DIR
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True
|
||||
)
|
||||
gen_cfg.pad_token_id = tokenizer.eos_token_id
|
||||
gen_cfg.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
@@ -95,7 +95,8 @@ class Transcriber:
|
||||
device=self.device,
|
||||
compute_type=WHISPER_COMPUTE_TYPE,
|
||||
num_workers=WHISPER_NUM_WORKERS,
|
||||
download_root=WHISPER_MODEL_DIR
|
||||
download_root=WHISPER_MODEL_DIR,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
@method()
|
||||
|
||||
33
server/migrations/versions/0fea6d96b096_add_share_mode.py
Normal file
33
server/migrations/versions/0fea6d96b096_add_share_mode.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""add share_mode
|
||||
|
||||
Revision ID: 0fea6d96b096
|
||||
Revises: f819277e5169
|
||||
Create Date: 2023-11-07 11:12:21.614198
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0fea6d96b096"
|
||||
down_revision: Union[str, None] = "f819277e5169"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the non-nullable string column `share_mode` (server default "private") to `transcript`."""
    share_mode = sa.Column(
        "share_mode",
        sa.String(),
        server_default="private",
        nullable=False,
    )
    op.add_column("transcript", share_mode)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the `share_mode` column from `transcript`."""
    op.drop_column(table_name="transcript", column_name="share_mode")
|
||||
30
server/migrations/versions/125031f7cb78_participants.py
Normal file
30
server/migrations/versions/125031f7cb78_participants.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""participants
|
||||
|
||||
Revision ID: 125031f7cb78
|
||||
Revises: 0fea6d96b096
|
||||
Create Date: 2023-11-30 15:56:03.341466
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '125031f7cb78'
|
||||
down_revision: Union[str, None] = '0fea6d96b096'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable JSON column `participants` to `transcript`."""
    participants = sa.Column('participants', sa.JSON(), nullable=True)
    op.add_column('transcript', participants)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the `participants` column from `transcript`."""
    op.drop_column(table_name='transcript', column_name='participants')
|
||||
@@ -0,0 +1,80 @@
|
||||
"""rename back text to transcript
|
||||
|
||||
Revision ID: 38a927dcb099
|
||||
Revises: 9920ecfe2735
|
||||
Create Date: 2023-11-02 19:53:09.116240
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '38a927dcb099'
|
||||
down_revision: Union[str, None] = '9920ecfe2735'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """
    Rename the key 'text' back to 'transcript' in every topic of every transcript.

    Each row's `topics` JSON array is rewritten in Python and saved back.
    Rows whose `topics` is NULL or empty are skipped.
    """
    # bind the engine
    bind = op.get_bind()

    # Lightweight description of the columns this migration touches
    transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))

    # Select all rows from the transcript table
    results = bind.execute(select([transcript.c.id, transcript.c.topics]))

    for row in results:
        transcript_id = row["id"]
        topics_json = row["topics"]

        # A NULL topics column has nothing to rename, and iterating None
        # would abort the whole migration with a TypeError.
        if not topics_json:
            continue

        # Process each topic in the topics JSON array
        updated_topics = []
        for topic in topics_json:
            if "text" in topic:
                # Rename key 'text' back to 'transcript'
                topic["transcript"] = topic.pop("text")
            updated_topics.append(topic)

        # Update the transcript table
        bind.execute(
            transcript.update()
            .where(transcript.c.id == transcript_id)
            .values(topics=updated_topics)
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """
    Rename the key 'transcript' to 'text' in every topic of every transcript.

    Each row's `topics` JSON array is rewritten in Python and saved back.
    Rows whose `topics` is NULL or empty are skipped.
    """
    # bind the engine
    bind = op.get_bind()

    # Lightweight description of the columns this migration touches
    transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))

    # Select all rows from the transcript table
    results = bind.execute(select([transcript.c.id, transcript.c.topics]))

    for row in results:
        transcript_id = row["id"]
        topics_json = row["topics"]

        # A NULL topics column has nothing to rename, and iterating None
        # would abort the whole migration with a TypeError.
        if not topics_json:
            continue

        # Process each topic in the topics JSON array
        updated_topics = []
        for topic in topics_json:
            if "transcript" in topic:
                # Rename key 'transcript' to 'text'
                topic["text"] = topic.pop("transcript")
            updated_topics.append(topic)

        # Update the transcript table
        bind.execute(
            transcript.update()
            .where(transcript.c.id == transcript_id)
            .values(topics=updated_topics)
        )
|
||||
64
server/migrations/versions/4814901632bc_fix_duration.py
Normal file
64
server/migrations/versions/4814901632bc_fix_duration.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""fix duration
|
||||
|
||||
Revision ID: 4814901632bc
|
||||
Revises: 38a927dcb099
|
||||
Create Date: 2023-11-10 18:12:17.886522
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "4814901632bc"
|
||||
down_revision: Union[str, None] = "38a927dcb099"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """
    Recompute the duration of transcripts whose stored duration is 0.

    For each such transcript, the duration is read from
    `<DATA_DIR>/<transcript id>/audio.mp3` and written to the `duration`
    column, rounded to 2 decimals. Transcripts without an audio file are
    skipped; a failure on one file is printed and does not stop the migration.
    """
    from pathlib import Path
    from reflector.settings import settings
    import av

    bind = op.get_bind()
    transcript = table(
        "transcript", column("id", sa.String), column("duration", sa.Float)
    )

    # select only the one with duration = 0
    results = bind.execute(
        select([transcript.c.id, transcript.c.duration]).where(
            transcript.c.duration == 0
        )
    )

    data_dir = Path(settings.DATA_DIR)
    for row in results:
        audio_path = data_dir / row["id"] / "audio.mp3"
        if not audio_path.exists():
            continue

        try:
            print(f"Processing {audio_path}")
            # Open the container as a context manager so the file handle is
            # released after each transcript instead of leaking until exit.
            with av.open(audio_path.as_posix()) as container:
                print(container.duration)
                # container.duration is divided by av.time_base to get seconds
                duration = round(float(container.duration / av.time_base), 2)
            print(f"Duration: {duration}")
            bind.execute(
                transcript.update()
                .where(transcript.c.id == row["id"])
                .values(duration=duration)
            )
        except Exception as e:
            print(f"Failed to process {audio_path}: {e}")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """No-op: the upgrade only filled in duration values, nothing is reverted."""
    return None
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Migration transcript to text field in transcripts table
|
||||
|
||||
Revision ID: 9920ecfe2735
|
||||
Revises: 99365b0cd87b
|
||||
Create Date: 2023-11-02 18:55:17.019498
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9920ecfe2735"
|
||||
down_revision: Union[str, None] = "99365b0cd87b"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """
    Rename the key 'transcript' to 'text' in every topic of every transcript.

    Each row's `topics` JSON array is rewritten in Python and saved back.
    Rows whose `topics` is NULL or empty are skipped.
    """
    # bind the engine
    bind = op.get_bind()

    # Lightweight description of the columns this migration touches
    transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))

    # Select all rows from the transcript table
    results = bind.execute(select([transcript.c.id, transcript.c.topics]))

    for row in results:
        transcript_id = row["id"]
        topics_json = row["topics"]

        # A NULL topics column has nothing to rename, and iterating None
        # would abort the whole migration with a TypeError.
        if not topics_json:
            continue

        # Process each topic in the topics JSON array
        updated_topics = []
        for topic in topics_json:
            if "transcript" in topic:
                # Rename key 'transcript' to 'text'
                topic["text"] = topic.pop("transcript")
            updated_topics.append(topic)

        # Update the transcript table
        bind.execute(
            transcript.update()
            .where(transcript.c.id == transcript_id)
            .values(topics=updated_topics)
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """
    Rename the key 'text' back to 'transcript' in every topic of every transcript.

    Each row's `topics` JSON array is rewritten in Python and saved back.
    Rows whose `topics` is NULL or empty are skipped.
    """
    # bind the engine
    bind = op.get_bind()

    # Lightweight description of the columns this migration touches
    transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))

    # Select all rows from the transcript table
    results = bind.execute(select([transcript.c.id, transcript.c.topics]))

    for row in results:
        transcript_id = row["id"]
        topics_json = row["topics"]

        # A NULL topics column has nothing to rename, and iterating None
        # would abort the whole migration with a TypeError.
        if not topics_json:
            continue

        # Process each topic in the topics JSON array
        updated_topics = []
        for topic in topics_json:
            if "text" in topic:
                # Rename key 'text' back to 'transcript'
                topic["transcript"] = topic.pop("text")
            updated_topics.append(topic)

        # Update the transcript table
        bind.execute(
            transcript.update()
            .where(transcript.c.id == transcript_id)
            .values(topics=updated_topics)
        )
|
||||
29
server/migrations/versions/b9348748bbbc_reviewed.py
Normal file
29
server/migrations/versions/b9348748bbbc_reviewed.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""reviewed
|
||||
|
||||
Revision ID: b9348748bbbc
|
||||
Revises: 125031f7cb78
|
||||
Create Date: 2023-12-13 15:37:51.303970
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b9348748bbbc'
|
||||
down_revision: Union[str, None] = '125031f7cb78'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the non-nullable boolean column `reviewed` (server default false) to `transcript`."""
    # sa.false() is rendered per dialect (0 on backends without a native
    # boolean, `false` otherwise). The raw literal '0' is an integer default,
    # which backends with a strict boolean type such as PostgreSQL reject.
    op.add_column(
        'transcript',
        sa.Column('reviewed', sa.Boolean(), server_default=sa.false(), nullable=False),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the `reviewed` column from `transcript`."""
    op.drop_column(table_name='transcript', column_name='reviewed')
|
||||
35
server/migrations/versions/f819277e5169_audio_location.py
Normal file
35
server/migrations/versions/f819277e5169_audio_location.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""audio_location
|
||||
|
||||
Revision ID: f819277e5169
|
||||
Revises: 4814901632bc
|
||||
Create Date: 2023-11-16 10:29:09.351664
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f819277e5169"
|
||||
down_revision: Union[str, None] = "4814901632bc"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the non-nullable string column `audio_location` (server default "local") to `transcript`."""
    audio_location = sa.Column(
        "audio_location",
        sa.String(),
        server_default="local",
        nullable=False,
    )
    op.add_column("transcript", audio_location)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the `audio_location` column from `transcript`."""
    op.drop_column(table_name="transcript", column_name="audio_location")
|
||||
1208
server/poetry.lock
generated
1208
server/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -8,12 +8,11 @@ packages = []
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
aiohttp = "^3.8.5"
|
||||
aiohttp = "^3.9.0"
|
||||
aiohttp-cors = "^0.7.0"
|
||||
av = "^10.0.0"
|
||||
requests = "^2.31.0"
|
||||
aiortc = "^1.5.0"
|
||||
faster-whisper = "^0.7.1"
|
||||
sortedcontainers = "^2.4.0"
|
||||
loguru = "^0.7.0"
|
||||
pydantic-settings = "^2.0.2"
|
||||
@@ -28,16 +27,22 @@ sqlalchemy = "<1.5"
|
||||
fief-client = {extras = ["fastapi"], version = "^0.17.0"}
|
||||
alembic = "^1.11.3"
|
||||
nltk = "^3.8.1"
|
||||
transformers = "^4.32.1"
|
||||
prometheus-fastapi-instrumentator = "^6.1.0"
|
||||
sentencepiece = "^0.1.99"
|
||||
protobuf = "^4.24.3"
|
||||
profanityfilter = "^2.0.6"
|
||||
celery = "^5.3.4"
|
||||
redis = "^5.0.1"
|
||||
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
|
||||
python-multipart = "^0.0.6"
|
||||
faster-whisper = "^0.10.0"
|
||||
transformers = "^4.36.2"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.7.0"
|
||||
stamina = "^23.1.0"
|
||||
pyinstrument = "^4.6.1"
|
||||
|
||||
|
||||
[tool.poetry.group.tests.dependencies]
|
||||
@@ -47,6 +52,7 @@ pytest-asyncio = "^0.21.1"
|
||||
pytest = "^7.4.0"
|
||||
httpx-ws = "^0.4.1"
|
||||
pytest-httpx = "^0.23.1"
|
||||
pytest-celery = "^0.0.0"
|
||||
|
||||
|
||||
[tool.poetry.group.aws.dependencies]
|
||||
|
||||
@@ -13,6 +13,14 @@ from reflector.metrics import metrics_init
|
||||
from reflector.settings import settings
|
||||
from reflector.views.rtc_offer import router as rtc_offer_router
|
||||
from reflector.views.transcripts import router as transcripts_router
|
||||
from reflector.views.transcripts_audio import router as transcripts_audio_router
|
||||
from reflector.views.transcripts_participants import (
|
||||
router as transcripts_participants_router,
|
||||
)
|
||||
from reflector.views.transcripts_speaker import router as transcripts_speaker_router
|
||||
from reflector.views.transcripts_upload import router as transcripts_upload_router
|
||||
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
|
||||
from reflector.views.transcripts_websocket import router as transcripts_websocket_router
|
||||
from reflector.views.user import router as user_router
|
||||
|
||||
try:
|
||||
@@ -41,7 +49,6 @@ if settings.SENTRY_DSN:
|
||||
else:
|
||||
logger.info("Sentry disabled")
|
||||
|
||||
|
||||
# build app
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
@@ -61,9 +68,18 @@ metrics_init(app, instrumentator)
|
||||
# register views
|
||||
app.include_router(rtc_offer_router)
|
||||
app.include_router(transcripts_router, prefix="/v1")
|
||||
app.include_router(transcripts_audio_router, prefix="/v1")
|
||||
app.include_router(transcripts_participants_router, prefix="/v1")
|
||||
app.include_router(transcripts_speaker_router, prefix="/v1")
|
||||
app.include_router(transcripts_upload_router, prefix="/v1")
|
||||
app.include_router(transcripts_websocket_router, prefix="/v1")
|
||||
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
||||
app.include_router(user_router, prefix="/v1")
|
||||
add_pagination(app)
|
||||
|
||||
# prepare celery
|
||||
from reflector.worker import app as celery_app # noqa
|
||||
|
||||
|
||||
# simpler openapi id
|
||||
def use_route_names_as_operation_ids(app: FastAPI) -> None:
|
||||
@@ -84,7 +100,10 @@ def use_route_names_as_operation_ids(app: FastAPI) -> None:
|
||||
version = None
|
||||
if route.path.startswith("/v"):
|
||||
version = route.path.split("/")[1]
|
||||
opid = f"{version}_{route.name}"
|
||||
if route.operation_id is not None:
|
||||
opid = f"{version}_{route.operation_id}"
|
||||
else:
|
||||
opid = f"{version}_{route.name}"
|
||||
else:
|
||||
opid = route.name
|
||||
|
||||
@@ -94,11 +113,28 @@ def use_route_names_as_operation_ids(app: FastAPI) -> None:
|
||||
"Please rename the route or the view function."
|
||||
)
|
||||
route.operation_id = opid
|
||||
ensure_uniq_operation_ids.add(route.name)
|
||||
ensure_uniq_operation_ids.add(opid)
|
||||
|
||||
|
||||
use_route_names_as_operation_ids(app)
|
||||
|
||||
if settings.PROFILING:
    # These imports only happen when profiling is switched on in the settings.
    from fastapi import Request
    from fastapi.responses import HTMLResponse
    from pyinstrument import Profiler

    @app.middleware("http")
    async def profile_request(request: Request, call_next):
        """
        HTTP middleware that profiles a request on demand.

        When the request carries a truthy `profile` query parameter, the
        downstream handler is run under a pyinstrument profiler and the
        profiler's HTML report is returned in place of the handler's own
        response. Otherwise the request is passed through untouched.
        """
        wants_profile = request.query_params.get("profile", False)
        if not wants_profile:
            return await call_next(request)

        profiler = Profiler(async_mode="enabled")
        profiler.start()
        # the real response is intentionally dropped: the report replaces it
        await call_next(request)
        profiler.stop()
        return HTMLResponse(profiler.output_html())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
@@ -1,32 +1,13 @@
|
||||
import databases
|
||||
import sqlalchemy
|
||||
|
||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||
from reflector.settings import settings
|
||||
|
||||
database = databases.Database(settings.DATABASE_URL)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
transcripts = sqlalchemy.Table(
|
||||
"transcript",
|
||||
metadata,
|
||||
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||
sqlalchemy.Column("name", sqlalchemy.String),
|
||||
sqlalchemy.Column("status", sqlalchemy.String),
|
||||
sqlalchemy.Column("locked", sqlalchemy.Boolean),
|
||||
sqlalchemy.Column("duration", sqlalchemy.Integer),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
|
||||
sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("topics", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("events", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
|
||||
# with user attached, optional
|
||||
sqlalchemy.Column("user_id", sqlalchemy.String),
|
||||
)
|
||||
# import models
|
||||
import reflector.db.transcripts # noqa
|
||||
|
||||
engine = sqlalchemy.create_engine(
|
||||
settings.DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
|
||||
507
server/reflector/db/transcripts.py
Normal file
507
server/reflector/db/transcripts.py
Normal file
@@ -0,0 +1,507 @@
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from reflector.db import database, metadata
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import Storage
|
||||
from sqlalchemy.sql import false
|
||||
|
||||
# SQLAlchemy definition of the "transcript" table, registered on the shared
# `metadata` object imported from reflector.db.
transcripts = sqlalchemy.Table(
    "transcript",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
    sqlalchemy.Column("name", sqlalchemy.String),
    sqlalchemy.Column("status", sqlalchemy.String),
    sqlalchemy.Column("locked", sqlalchemy.Boolean),
    sqlalchemy.Column("duration", sqlalchemy.Integer),
    sqlalchemy.Column("created_at", sqlalchemy.DateTime),
    sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
    # JSON-serialized lists (see Transcript.topics_dump / events_dump /
    # participants_dump for how they are produced)
    sqlalchemy.Column("topics", sqlalchemy.JSON),
    sqlalchemy.Column("events", sqlalchemy.JSON),
    sqlalchemy.Column("participants", sqlalchemy.JSON),
    sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
    # non-nullable, defaults to false on the database side
    sqlalchemy.Column(
        "reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
    ),
    # where the audio lives; Transcript.get_audio_url handles "local" and
    # "storage" and rejects anything else
    sqlalchemy.Column(
        "audio_location",
        sqlalchemy.String,
        nullable=False,
        server_default="local",
    ),
    # with user attached, optional
    sqlalchemy.Column("user_id", sqlalchemy.String),
    # access policy; the Transcript model restricts it to
    # "private", "semi-private" or "public"
    sqlalchemy.Column(
        "share_mode",
        sqlalchemy.String,
        nullable=False,
        server_default="private",
    ),
)
|
||||
|
||||
|
||||
def generate_uuid4() -> str:
    """Return a newly generated UUID (version 4) in its string form."""
    new_id = uuid4()
    return str(new_id)
|
||||
|
||||
|
||||
def generate_transcript_name() -> str:
    """
    Build a default transcript name from the current UTC time.

    Returns a string of the form "Transcript YYYY-MM-DD HH:MM:SS".
    """
    timestamp = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
    return f"Transcript {timestamp}"
|
||||
|
||||
|
||||
def get_storage() -> Storage:
    """
    Return the storage instance used for transcript files.

    The backend is selected by `settings.TRANSCRIPT_STORAGE_BACKEND` and is
    asked to read its configuration under the "TRANSCRIPT_STORAGE_" prefix.
    """
    backend_name = settings.TRANSCRIPT_STORAGE_BACKEND
    return Storage.get_instance(
        name=backend_name,
        settings_prefix="TRANSCRIPT_STORAGE_",
    )
|
||||
|
||||
|
||||
class AudioWaveform(BaseModel):
    """Waveform values of an audio file, as returned by Transcript.audio_waveform."""

    # content of the waveform JSON file, expected to be a list of floats
    data: list[float]
|
||||
|
||||
|
||||
class TranscriptText(BaseModel):
    """A piece of transcribed text together with its translation, if any."""

    text: str
    # must be provided explicitly (no default), but may be None
    translation: str | None
|
||||
|
||||
|
||||
class TranscriptSegmentTopic(BaseModel):
    """A text segment attributed to a speaker at a given timestamp."""

    # numeric speaker identifier
    speaker: int
    text: str
    # NOTE(review): unit is not visible here, presumably seconds; confirm
    timestamp: float
|
||||
|
||||
|
||||
class TranscriptTopic(BaseModel):
    """
    A topic of a transcript: title, summary, position in the audio and the
    words it covers.
    """

    # generated automatically when not supplied
    id: str = Field(default_factory=generate_uuid4)
    title: str
    summary: str
    timestamp: float
    duration: float | None = 0
    # plain text of the topic, optional
    transcript: str | None = None
    # words (processor `Word` objects) belonging to this topic
    words: list[ProcessorWord] = []
|
||||
|
||||
|
||||
class TranscriptFinalShortSummary(BaseModel):
    """Container for the final short summary of a transcript."""

    short_summary: str
|
||||
|
||||
|
||||
class TranscriptFinalLongSummary(BaseModel):
    """Container for the final long summary of a transcript."""

    long_summary: str
|
||||
|
||||
|
||||
class TranscriptFinalTitle(BaseModel):
    """Container for the final title of a transcript."""

    title: str
|
||||
|
||||
|
||||
class TranscriptDuration(BaseModel):
    """Container for the duration of a transcript's audio."""

    # NOTE(review): unit is not visible in this file; confirm with the producer
    duration: float
|
||||
|
||||
|
||||
class TranscriptWaveform(BaseModel):
    """Container for the waveform values of a transcript's audio."""

    waveform: list[float]
|
||||
|
||||
|
||||
class TranscriptEvent(BaseModel):
    """
    An event recorded on a transcript: an event name plus its payload.

    Instances are created by Transcript.add_event, which stores the dumped
    form (`model_dump()`) of a pydantic model as `data`.
    """

    # event name
    event: str
    # payload of the event as a plain dict
    data: dict
|
||||
|
||||
|
||||
class TranscriptParticipant(BaseModel):
    """A named participant of a transcript, optionally linked to a speaker number."""

    # allow building the model from attribute-bearing objects
    model_config = ConfigDict(from_attributes=True)
    # generated automatically when not supplied
    id: str = Field(default_factory=generate_uuid4)
    # must be provided explicitly (no default), but may be None
    speaker: int | None
    name: str
|
||||
|
||||
|
||||
class Transcript(BaseModel):
    """
    In-memory model of one row of the `transcript` table.

    Besides the stored fields, it offers helpers to mutate the embedded
    lists (events, topics, participants), to dump them for storage, and to
    locate the transcript's files under `settings.DATA_DIR`.
    """

    id: str = Field(default_factory=generate_uuid4)
    user_id: str | None = None
    name: str = Field(default_factory=generate_transcript_name)
    status: str = "idle"
    locked: bool = False
    duration: float = 0
    created_at: datetime = Field(default_factory=datetime.utcnow)
    title: str | None = None
    short_summary: str | None = None
    long_summary: str | None = None
    topics: list[TranscriptTopic] = []
    events: list[TranscriptEvent] = []
    # may be None (the field type allows it, and the column is nullable)
    participants: list[TranscriptParticipant] | None = []
    source_language: str = "en"
    target_language: str = "en"
    share_mode: Literal["private", "semi-private", "public"] = "private"
    audio_location: str = "local"
    reviewed: bool = False

    def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
        """Append an event named `event` carrying the dumped `data`, and return it."""
        ev = TranscriptEvent(event=event, data=data.model_dump())
        self.events.append(ev)
        return ev

    def upsert_topic(self, topic: TranscriptTopic):
        """Replace the topic with the same id, or append it when not present."""
        index = next((i for i, t in enumerate(self.topics) if t.id == topic.id), None)
        if index is not None:
            self.topics[index] = topic
        else:
            self.topics.append(topic)

    def upsert_participant(self, participant: TranscriptParticipant):
        """
        Replace the participant with the same id, or append it when not
        present. Returns the participant that was passed in.
        """
        if self.participants is None:
            # the field is allowed to be None: start from an empty list
            # instead of failing on enumerate(None)
            self.participants = []
        index = next(
            (i for i, p in enumerate(self.participants) if p.id == participant.id),
            None,
        )
        if index is not None:
            self.participants[index] = participant
        else:
            self.participants.append(participant)
        return participant

    def delete_participant(self, participant_id: str):
        """Remove the participant with this id; do nothing when it is unknown."""
        if self.participants is None:
            return
        index = next(
            (i for i, p in enumerate(self.participants) if p.id == participant_id),
            None,
        )
        if index is not None:
            del self.participants[index]

    def events_dump(self, mode="json"):
        """Return the events as a list of dumped dicts."""
        return [event.model_dump(mode=mode) for event in self.events]

    def topics_dump(self, mode="json"):
        """Return the topics as a list of dumped dicts."""
        return [topic.model_dump(mode=mode) for topic in self.topics]

    def participants_dump(self, mode="json"):
        """Return the participants as a list of dumped dicts (empty when None)."""
        return [
            participant.model_dump(mode=mode)
            for participant in (self.participants or [])
        ]

    def unlink(self):
        """
        Remove the local data of this transcript.

        `data_path` is the parent of the audio/waveform files, i.e. a
        directory, and `Path.unlink` cannot remove directories. A directory
        is therefore removed recursively; anything else is unlinked as
        before, and a missing path is not an error.
        """
        import shutil

        path = self.data_path
        if path.is_dir():
            shutil.rmtree(path)
        else:
            path.unlink(missing_ok=True)

    @property
    def data_path(self):
        """Location of this transcript's files: `settings.DATA_DIR / id`."""
        return Path(settings.DATA_DIR) / self.id

    @property
    def audio_wav_filename(self):
        return self.data_path / "audio.wav"

    @property
    def audio_mp3_filename(self):
        return self.data_path / "audio.mp3"

    @property
    def audio_waveform_filename(self):
        return self.data_path / "audio.json"

    @property
    def storage_audio_path(self):
        """Path of the mp3 inside the storage backend."""
        return f"{self.id}/audio.mp3"

    @property
    def audio_waveform(self):
        """
        Load the waveform from the waveform JSON file.

        Returns None (and deletes the file) when the file is not valid JSON.
        Other errors, such as a missing file, are not handled here and
        propagate to the caller.
        """
        try:
            with open(self.audio_waveform_filename) as fd:
                data = json.load(fd)
        except json.JSONDecodeError:
            # unlink file if it's corrupted
            self.audio_waveform_filename.unlink(missing_ok=True)
            return None

        return AudioWaveform(data=data)

    async def get_audio_url(self) -> str:
        """
        Return a URL to the audio, depending on `audio_location`.

        Raises an Exception when `audio_location` is neither "local" nor
        "storage".
        """
        if self.audio_location == "local":
            return self._generate_local_audio_link()
        elif self.audio_location == "storage":
            return await self._generate_storage_audio_link()
        raise Exception(f"Unknown audio location {self.audio_location}")

    async def _generate_storage_audio_link(self) -> str:
        return await get_storage().get_file_url(self.storage_audio_path)

    def _generate_local_audio_link(self) -> str:
        # we need to create an url to be used for diarization
        # we can't use the audio_mp3_filename because it's not accessible
        # from the diarization processor
        from datetime import timedelta

        from reflector.app import app
        from reflector.views.transcripts import create_access_token

        path = app.url_path_for(
            "transcript_get_audio_mp3",
            transcript_id=self.id,
        )
        url = f"{settings.BASE_URL}{path}"
        if self.user_id:
            # we pass token only if the user_id is set
            # otherwise, the audio is public
            token = create_access_token(
                {"sub": self.user_id},
                expires_delta=timedelta(minutes=15),
            )
            url += f"?token={token}"
        return url

    def find_empty_speaker(self) -> int:
        """
        Find an empty speaker seat

        Returns the smallest non-negative integer that is not used as a
        speaker by any word of any topic.
        """
        speakers = set(
            word.speaker
            for topic in self.topics
            for word in topic.words
            if word.speaker is not None
        )
        i = 0
        # the set is finite, so a free number is always found
        while i in speakers:
            i += 1
        return i
|
||||
|
||||
|
||||
class TranscriptController:
    """
    Database access layer for transcripts: queries and updates on the
    `transcripts` table through the shared `database` connection.
    """

    async def get_all(
        self,
        user_id: str | None = None,
        order_by: str | None = None,
        filter_empty: bool | None = False,
        filter_recording: bool | None = False,
        return_query: bool = False,
    ) -> list[Transcript]:
        """
        Get all transcripts

        If `user_id` is specified, only return transcripts that belong to the user.
        Otherwise, return all anonymous transcripts.

        Parameters:
        - `order_by`: field to order by, e.g. "-created_at"; a leading "-"
          means descending, no prefix means ascending
        - `filter_empty`: filter out empty transcripts
        - `filter_recording`: filter out transcripts that are currently recording
        - `return_query`: return the query itself instead of executing it

        Note: when executed, the raw rows from `database.fetch_all` are
        returned, not `Transcript` instances.
        """
        query = transcripts.select().where(transcripts.c.user_id == user_id)

        if order_by is not None:
            # only strip the first character when it is the "-" marker,
            # otherwise "created_at" would be looked up as "reated_at"
            descending = order_by.startswith("-")
            field_name = order_by[1:] if descending else order_by
            field = getattr(transcripts.c, field_name)
            if descending:
                field = field.desc()
            query = query.order_by(field)

        if filter_empty:
            query = query.filter(transcripts.c.status != "idle")

        if filter_recording:
            query = query.filter(transcripts.c.status != "recording")

        if return_query:
            return query

        results = await database.fetch_all(query)
        return results

    async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
        """
        Get a transcript by id

        When a `user_id` keyword is passed (even as None), the transcript
        must also match that user id. Returns None when nothing matches.
        """
        query = transcripts.select().where(transcripts.c.id == transcript_id)
        if "user_id" in kwargs:
            query = query.where(transcripts.c.user_id == kwargs["user_id"])
        result = await database.fetch_one(query)
        if not result:
            return None
        return Transcript(**result)

    async def get_by_id_for_http(
        self,
        transcript_id: str,
        user_id: str | None,
    ) -> Transcript:
        """
        Get a transcript by ID for HTTP request.

        If not found, it will raise a 404 error.
        If the user is not allowed to access the transcript, it will raise a 403 error.

        This method checks the share mode of the transcript and the user_id
        to determine if the user can access the transcript.
        """
        query = transcripts.select().where(transcripts.c.id == transcript_id)
        result = await database.fetch_one(query)
        if not result:
            raise HTTPException(status_code=404, detail="Transcript not found")

        # if the transcript is anonymous, share mode is not checked
        transcript = Transcript(**result)
        if transcript.user_id is None:
            return transcript

        if transcript.share_mode == "private":
            # in private mode, only the owner can access the transcript
            if transcript.user_id == user_id:
                return transcript

        elif transcript.share_mode == "semi-private":
            # in semi-private mode, any identified user (user_id is set)
            # can access the transcript
            if user_id is not None:
                return transcript

        elif transcript.share_mode == "public":
            # in public mode, everyone can access the transcript
            return transcript

        raise HTTPException(status_code=403, detail="Transcript access denied")

    async def add(
        self,
        name: str,
        source_language: str = "en",
        target_language: str = "en",
        user_id: str | None = None,
    ):
        """
        Add a new transcript

        Creates the model, inserts it in the database and returns it.
        """
        transcript = Transcript(
            name=name,
            source_language=source_language,
            target_language=target_language,
            user_id=user_id,
        )
        query = transcripts.insert().values(**transcript.model_dump())
        await database.execute(query)
        return transcript

    async def update(self, transcript: Transcript, values: dict, mutate=True):
        """
        Update a transcript fields with key/values in values

        When `mutate` is true, the same values are also set on the
        `transcript` object after the database update.
        """
        query = (
            transcripts.update()
            .where(transcripts.c.id == transcript.id)
            .values(**values)
        )
        await database.execute(query)
        if mutate:
            for key, value in values.items():
                setattr(transcript, key, value)

    async def remove_by_id(
        self,
        transcript_id: str,
        user_id: str | None = None,
    ) -> None:
        """
        Remove a transcript by id

        Nothing happens when the transcript is not found for this
        `user_id`. Local files are removed before the database row.
        """
        transcript = await self.get_by_id(transcript_id, user_id=user_id)
        if not transcript:
            return
        if user_id is not None and transcript.user_id != user_id:
            return
        transcript.unlink()
        query = transcripts.delete().where(transcripts.c.id == transcript_id)
        await database.execute(query)

    @asynccontextmanager
    async def transaction(self):
        """
        A context manager for database transaction
        """
        async with database.transaction(isolation="serializable"):
            yield

    async def append_event(
        self,
        transcript: Transcript,
        event: str,
        data: Any,
    ) -> TranscriptEvent:
        """
        Append an event to a transcript

        The event is added to the in-memory transcript, then the whole
        events list is written to the database. Returns the new event.
        """
        resp = transcript.add_event(event=event, data=data)
        await self.update(
            transcript,
            {"events": transcript.events_dump()},
            # the object already holds the new event
            mutate=False,
        )
        return resp

    async def upsert_topic(
        self,
        transcript: Transcript,
        topic: TranscriptTopic,
    ) -> None:
        """
        Add/update a topic of a transcript and persist the topics list
        """
        transcript.upsert_topic(topic)
        await self.update(
            transcript,
            {"topics": transcript.topics_dump()},
            # the object already holds the new topic
            mutate=False,
        )

    async def move_mp3_to_storage(self, transcript: Transcript):
        """
        Move mp3 file to storage
        """

        # store the audio on external storage
        await get_storage().put_file(
            transcript.storage_audio_path,
            transcript.audio_mp3_filename.read_bytes(),
        )

        # indicate on the transcript that the audio is now on storage
        await self.update(transcript, {"audio_location": "storage"})

        # unlink the local file
        transcript.audio_mp3_filename.unlink(missing_ok=True)

    async def upsert_participant(
        self,
        transcript: Transcript,
        participant: TranscriptParticipant,
    ) -> TranscriptParticipant:
        """
        Add/update a participant to a transcript
        """
        result = transcript.upsert_participant(participant)
        await self.update(
            transcript,
            {"participants": transcript.participants_dump()},
            mutate=False,
        )
        return result

    async def delete_participant(
        self,
        transcript: Transcript,
        participant_id: str,
    ):
        """
        Delete a participant from a transcript
        """
        transcript.delete_participant(participant_id)
        await self.update(
            transcript,
            {"participants": transcript.participants_dump()},
            mutate=False,
        )
||||
|
||||
transcripts_controller = TranscriptController()
|
||||
@@ -1,54 +0,0 @@
|
||||
import httpx
|
||||
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
|
||||
class BananaLLM(LLM):
    """LLM backend that posts prompts to `settings.LLM_URL` with Banana API headers."""

    def __init__(self):
        super().__init__()
        self.timeout = settings.LLM_TIMEOUT
        self.headers = {
            "X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
            "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
        }

    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ):
        """
        Send the prompt (plus the optional generation schema/config) and
        return the "text" field of the JSON response.

        Raises the httpx error produced by `raise_for_status` on an HTTP
        error status.
        """
        json_payload = {"prompt": prompt}
        # optional entries are only sent when they are truthy
        for key, value in (("gen_schema", gen_schema), ("gen_cfg", gen_cfg)):
            if value:
                json_payload[key] = value

        async with httpx.AsyncClient() as client:
            post_with_retry = retry(client.post)
            response = await post_with_retry(
                settings.LLM_URL,
                headers=self.headers,
                json=json_payload,
                timeout=self.timeout,
                retry_timeout=300,  # as per their sdk
            )
            response.raise_for_status()
            return response.json()["text"]
|
||||
|
||||
|
||||
LLM.register("banana", BananaLLM)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: ask the Banana backend for a joke and print it.
    import asyncio

    from reflector.logger import logger

    async def main():
        backend = BananaLLM()
        joke_prompt = backend.create_prompt(
            instruct="Complete the following task",
            text="Tell me a joke about programming.",
        )
        print(await backend.generate(prompt=joke_prompt, logger=logger))

    asyncio.run(main())
|
||||
@@ -47,6 +47,7 @@ class ModalLLM(LLM):
|
||||
json=json_payload,
|
||||
timeout=self.timeout,
|
||||
retry_timeout=60 * 5,
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
text = response.json()["text"]
|
||||
|
||||
680
server/reflector/pipelines/main_live_pipeline.py
Normal file
680
server/reflector/pipelines/main_live_pipeline.py
Normal file
@@ -0,0 +1,680 @@
|
||||
"""
|
||||
Main reflector pipeline for live streaming
|
||||
==========================================
|
||||
|
||||
This is the default pipeline used in the API.
|
||||
|
||||
It is decoupled to:
|
||||
- PipelineMainLive: have limited processing during live
|
||||
- PipelineMainPost: do heavy lifting after the live
|
||||
|
||||
It is directly linked to our data model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from celery import chord, group, shared_task
|
||||
from pydantic import BaseModel
|
||||
from reflector.db.transcripts import (
|
||||
Transcript,
|
||||
TranscriptDuration,
|
||||
TranscriptFinalLongSummary,
|
||||
TranscriptFinalShortSummary,
|
||||
TranscriptFinalTitle,
|
||||
TranscriptText,
|
||||
TranscriptTopic,
|
||||
TranscriptWaveform,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.runner import PipelineRunner
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioDiarizationAutoProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
BroadcastProcessor,
|
||||
Pipeline,
|
||||
TranscriptFinalLongSummaryProcessor,
|
||||
TranscriptFinalShortSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorProcessor,
|
||||
)
|
||||
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput
|
||||
from reflector.processors.types import (
|
||||
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
|
||||
)
|
||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||
from reflector.settings import settings
|
||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||
from structlog import BoundLogger as Logger
|
||||
|
||||
|
||||
def asynctask(f):
    """
    Decorator turning an async function into a blocking, synchronous one.

    The wrapper calls `f` and runs the resulting coroutine to completion,
    returning its result:

    - when no event loop is running in the current thread, the coroutine is
      run with `asyncio.run`;
    - when a loop is already running in the current thread, the coroutine
      is run with `asyncio.run` in a short-lived helper thread and the
      caller blocks until it finishes. `loop.run_until_complete` cannot be
      used in that case: it raises RuntimeError on a running loop.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        coro = f(*args, **kwargs)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # no loop running in this thread: we can own one
            return asyncio.run(coro)

        # a loop is running in this thread; use a separate thread with its
        # own loop (this blocks the current thread until completion)
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()

    return wrapper
|
||||
|
||||
|
||||
def broadcast_to_sockets(func):
    """
    Decorator for pipeline event handlers: whatever the handler returns is
    sent as JSON to the websocket room of the pipeline.

    Nothing is sent when the handler returns None. The decorated method
    itself always returns None.
    """

    async def wrapper(self, *args, **kwargs):
        event = await func(self, *args, **kwargs)
        if event is not None:
            await self.ws_manager.send_json(
                room_id=self.ws_room_id,
                message=event.model_dump(mode="json"),
            )

    return wrapper
|
||||
|
||||
|
||||
def get_transcript(func):
    """
    Decorator to fetch the transcript from the database from the
    `transcript_id` keyword argument.

    The wrapper only accepts keyword arguments. It removes `transcript_id`,
    loads the transcript and calls `func` with `transcript=` and `logger=`
    (a logger bound to the transcript id) plus the remaining keywords.

    Raises an Exception when the transcript does not exist. Any exception
    raised by `func` is logged and re-raised.
    """

    async def wrapper(**kwargs):
        transcript_id = kwargs.pop("transcript_id")
        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
        if not transcript:
            # f-string so that the actual id shows up in the message
            raise Exception(f"Transcript {transcript_id} not found")
        tlogger = logger.bind(transcript_id=transcript.id)
        try:
            return await func(transcript=transcript, logger=tlogger, **kwargs)
        except Exception as exc:
            tlogger.error("Pipeline error", exc_info=exc)
            raise

    return wrapper
|
||||
|
||||
|
||||
class StrValue(BaseModel):
    """Wrapper model around a single string, used as the payload of the STATUS event."""

    value: str
|
||||
|
||||
|
||||
class PipelineMainBase(PipelineRunner):
    """
    Common base of the main pipelines.

    It ties a pipeline runner to one transcript (`transcript_id`) and
    provides the `on_*` handlers that store pipeline results on the
    transcript and, through `broadcast_to_sockets`, send the resulting
    event to the transcript's websocket room.
    """

    transcript_id: str
    # both are filled in by prepare()
    ws_room_id: str | None = None
    ws_manager: WebsocketManager | None = None

    def prepare(self):
        """Create the lock used by transaction() and set up the websocket room/manager."""
        # prepare websocket
        self._lock = asyncio.Lock()
        self.ws_room_id = f"ts:{self.transcript_id}"
        self.ws_manager = get_ws_manager()

    async def get_transcript(self) -> Transcript:
        """Load the transcript of this pipeline; raise an Exception when it is missing."""
        # fetch the transcript
        result = await transcripts_controller.get_by_id(
            transcript_id=self.transcript_id
        )
        if not result:
            raise Exception("Transcript not found")
        return result

    def get_transcript_topics(
        self, transcript: Transcript
    ) -> list[TitleSummaryWithIdProcessorType]:
        """Convert the topics of the transcript into the processor topic type."""
        return [
            TitleSummaryWithIdProcessorType(
                id=topic.id,
                title=topic.title,
                summary=topic.summary,
                timestamp=topic.timestamp,
                duration=topic.duration,
                transcript=TranscriptProcessorType(words=topic.words),
            )
            for topic in transcript.topics
        ]

    @asynccontextmanager
    async def transaction(self):
        """
        Context manager that takes this pipeline's lock, then opens a
        database transaction. Requires prepare() to have been called.
        """
        async with self._lock:
            async with transcripts_controller.transaction():
                yield

    @broadcast_to_sockets
    async def on_status(self, status):
        """
        Map a pipeline status to a transcript status and store it.

        Returns the STATUS event when the transcript status changed,
        otherwise None (nothing is broadcast in that case).
        """
        # if it's the first part, update the status of the transcript
        # but do not set the ended status yet.
        if isinstance(self, PipelineMainLive):
            status_mapping = {
                "started": "recording",
                "push": "recording",
                "flush": "processing",
                "error": "error",
            }
        elif isinstance(self, PipelineMainFinalSummaries):
            status_mapping = {
                "push": "processing",
                "flush": "processing",
                "error": "error",
                "ended": "ended",
            }
        else:
            # intermediate pipeline don't update status
            return

        # mutate to model status
        status = status_mapping.get(status)
        if not status:
            # pipeline status without a transcript counterpart
            return

        # when the status of the pipeline changes, update the transcript
        async with self.transaction():
            transcript = await self.get_transcript()
            if status == transcript.status:
                return
            resp = await transcripts_controller.append_event(
                transcript=transcript,
                event="STATUS",
                data=StrValue(value=status),
            )
            await transcripts_controller.update(
                transcript,
                {
                    "status": status,
                },
            )
            return resp

    @broadcast_to_sockets
    async def on_transcript(self, data):
        """Record a TRANSCRIPT event with the text and translation of `data`."""
        async with self.transaction():
            transcript = await self.get_transcript()
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="TRANSCRIPT",
                data=TranscriptText(text=data.text, translation=data.translation),
            )

    @broadcast_to_sockets
    async def on_topic(self, data):
        """Add/update the topic described by `data` and record a TOPIC event."""
        topic = TranscriptTopic(
            title=data.title,
            summary=data.summary,
            timestamp=data.timestamp,
            transcript=data.transcript.text,
            words=data.transcript.words,
        )
        if isinstance(data, TitleSummaryWithIdProcessorType):
            # keep the id so that the existing topic is updated, not duplicated
            topic.id = data.id
        async with self.transaction():
            transcript = await self.get_transcript()
            await transcripts_controller.upsert_topic(transcript, topic)
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="TOPIC",
                data=topic,
            )

    @broadcast_to_sockets
    async def on_title(self, data):
        """
        Record a FINAL_TITLE event; the transcript title itself is only
        written when the transcript has no title yet.
        """
        final_title = TranscriptFinalTitle(title=data.title)
        async with self.transaction():
            transcript = await self.get_transcript()
            if not transcript.title:
                await transcripts_controller.update(
                    transcript,
                    {
                        "title": final_title.title,
                    },
                )
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="FINAL_TITLE",
                data=final_title,
            )

    @broadcast_to_sockets
    async def on_long_summary(self, data):
        """Store the long summary and record a FINAL_LONG_SUMMARY event."""
        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
        async with self.transaction():
            transcript = await self.get_transcript()
            await transcripts_controller.update(
                transcript,
                {
                    "long_summary": final_long_summary.long_summary,
                },
            )
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="FINAL_LONG_SUMMARY",
                data=final_long_summary,
            )

    @broadcast_to_sockets
    async def on_short_summary(self, data):
        """Store the short summary and record a FINAL_SHORT_SUMMARY event."""
        final_short_summary = TranscriptFinalShortSummary(
            short_summary=data.short_summary
        )
        async with self.transaction():
            transcript = await self.get_transcript()
            await transcripts_controller.update(
                transcript,
                {
                    "short_summary": final_short_summary.short_summary,
                },
            )
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="FINAL_SHORT_SUMMARY",
                data=final_short_summary,
            )

    @broadcast_to_sockets
    async def on_duration(self, data):
        """Store the duration (`data`) and record a DURATION event."""
        async with self.transaction():
            duration = TranscriptDuration(duration=data)

            transcript = await self.get_transcript()
            await transcripts_controller.update(
                transcript,
                {
                    "duration": duration.duration,
                },
            )
            return await transcripts_controller.append_event(
                transcript=transcript, event="DURATION", data=duration
            )

    @broadcast_to_sockets
    async def on_waveform(self, data):
        """Record a WAVEFORM event carrying the waveform values in `data`."""
        async with self.transaction():
            waveform = TranscriptWaveform(waveform=data)

            transcript = await self.get_transcript()

            return await transcripts_controller.append_event(
                transcript=transcript, event="WAVEFORM", data=waveform
            )
|
||||
|
||||
|
||||
class PipelineMainLive(PipelineMainBase):
    """
    Main pipeline for live streaming, attach to RTC connection
    Any long post process should be done in the post pipeline
    """

    async def create(self) -> Pipeline:
        """
        Build the live pipeline for this transcript.

        The chain writes the audio to the transcript's wav file, then
        chunks, merges, transcribes, lines up, translates and detects
        topics. Results are reported through `on_duration`,
        `on_transcript` and `on_topic`.
        """
        # create a context for the whole rtc transaction
        # add a customised logger to the context
        self.prepare()
        transcript = await self.get_transcript()

        processors = [
            AudioFileWriterProcessor(
                path=transcript.audio_wav_filename,
                on_duration=self.on_duration,
            ),
            AudioChunkerProcessor(),
            AudioMergeProcessor(),
            AudioTranscriptAutoProcessor.as_threaded(),
            TranscriptLinerProcessor(),
            TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
            TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
        ]
        pipeline = Pipeline(*processors)
        pipeline.options = self
        pipeline.set_pref("audio:source_language", transcript.source_language)
        pipeline.set_pref("audio:target_language", transcript.target_language)
        # bind() returns the bound logger; its result must be used, otherwise
        # the transcript id never reaches the log entry
        pipeline.logger.bind(transcript_id=transcript.id).info(
            "Pipeline main live created"
        )

        return pipeline

    async def on_ended(self):
        """Called when the live pipeline ends: hand over to the post pipeline."""
        # when the pipeline ends, connect to the post pipeline
        logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
        logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
        pipeline_post(transcript_id=self.transcript_id)
|
||||
|
||||
|
||||
class PipelineMainDiarization(PipelineMainBase):
    """
    Diarize the audio and update topics
    """

    async def create(self) -> Pipeline | None:
        """
        Build the diarization pipeline and queue its single input.

        Returns ``None`` (no pipeline) when the transcript audio location is
        ``"local"``; otherwise queues an ``AudioDiarizationInput`` followed
        by a flush on this runner and returns the ``Pipeline``.
        """
        self.prepare()
        pipeline = Pipeline(
            AudioDiarizationAutoProcessor(callback=self.on_topic),
        )
        pipeline.options = self

        # now let's start the pipeline by pushing information to the
        # first processor diarization processor
        # XXX translation is lost when converting our data model to the processor model
        transcript = await self.get_transcript()

        # bind() hands back a new logger; the original discarded it, so the
        # transcript id never reached the log lines. Log through the bound one.
        bound_logger = pipeline.logger.bind(transcript_id=transcript.id)

        # diarization works only if the file is uploaded to an external storage
        if transcript.audio_location == "local":
            bound_logger.info("Audio is local, skipping diarization")
            return None

        topics = self.get_transcript_topics(transcript)
        audio_url = await transcript.get_audio_url()
        audio_diarization_input = AudioDiarizationInput(
            audio_url=audio_url,
            topics=topics,
        )

        # as tempting to use pipeline.push, prefer to use the runner
        # to let the start just do one job.
        bound_logger.info("Diarization pipeline created")
        self.push(audio_diarization_input)
        self.flush()

        return pipeline
|
||||
|
||||
|
||||
class PipelineMainFromTopics(PipelineMainBase):
    """
    Pseudo class for generating a pipeline from topics
    """

    def get_processors(self) -> list:
        """Return the processors of the pipeline; subclasses must override."""
        raise NotImplementedError

    async def create(self) -> Pipeline:
        """
        Build a pipeline from ``get_processors()`` and queue the topics.

        The transcript is stored on ``self._transcript`` before
        ``get_processors()`` is called, so subclasses may read it there.
        Every topic of the transcript is queued on this runner, followed by
        a flush.
        """
        self.prepare()

        # get transcript
        self._transcript = transcript = await self.get_transcript()

        # create pipeline
        processors = self.get_processors()
        pipeline = Pipeline(*processors)
        pipeline.options = self

        # bind() hands back a new logger; the original discarded it, so the
        # transcript id never reached the log line. Log through the bound one.
        bound_logger = pipeline.logger.bind(transcript_id=transcript.id)
        bound_logger.info(f"{self.__class__.__name__} pipeline created")

        # push topics
        for topic in self.get_transcript_topics(transcript):
            self.push(topic)

        self.flush()

        return pipeline
|
||||
|
||||
|
||||
class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
    """
    Generate title from the topics
    """

    def get_processors(self) -> list:
        """
        Return one broadcast processor feeding the title processor
        (``self.on_title``) and the short summary processor
        (``self.on_short_summary``).
        """
        title = TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title)
        short_summary = TranscriptFinalShortSummaryProcessor.as_threaded(
            callback=self.on_short_summary
        )
        broadcast = BroadcastProcessor(processors=[title, short_summary])
        return [broadcast]
|
||||
|
||||
|
||||
class PipelineMainFinalSummaries(PipelineMainFromTopics):
    """
    Generate summaries from the topics
    """

    def get_processors(self) -> list:
        """
        Return one broadcast processor feeding the long summary processor
        (``self.on_long_summary``) and the short summary processor
        (``self.on_short_summary``).
        """
        long_summary = TranscriptFinalLongSummaryProcessor.as_threaded(
            callback=self.on_long_summary
        )
        short_summary = TranscriptFinalShortSummaryProcessor.as_threaded(
            callback=self.on_short_summary
        )
        broadcast = BroadcastProcessor(processors=[long_summary, short_summary])
        return [broadcast]
|
||||
|
||||
|
||||
class PipelineMainWaveform(PipelineMainFromTopics):
    """
    Generate waveform
    """

    def get_processors(self) -> list:
        """
        Return a single threaded waveform processor that reads the
        transcript's wav file, targets its waveform file and reports through
        ``self.on_waveform``. Reads ``self._transcript``.
        """
        current = self._transcript
        waveform_processor = AudioWaveformProcessor.as_threaded(
            audio_path=current.audio_wav_filename,
            waveform_path=current.audio_waveform_filename,
            on_waveform=self.on_waveform,
        )
        return [waveform_processor]
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
    """Delete every ``upload.*`` file found in the transcript's data path."""
    logger.info("Starting remove upload")
    for uploaded_file in transcript.data_path.glob("upload.*"):
        # missing_ok: a file that disappeared in the meantime is not an error
        uploaded_file.unlink(missing_ok=True)
    logger.info("Remove upload done")
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_waveform(transcript: Transcript, logger: Logger):
    """Run ``PipelineMainWaveform`` for the transcript and wait for it."""
    logger.info("Starting waveform")
    await PipelineMainWaveform(transcript_id=transcript.id).run()
    logger.info("Waveform done")
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
    """
    Convert the transcript's wav recording to mp3, then delete the wav.

    Does nothing (besides a warning) when the wav file does not exist.
    """
    logger.info("Starting convert to mp3")

    # If the audio wav is not available, just skip
    wav_filename = transcript.audio_wav_filename
    if not wav_filename.exists():
        logger.warning("Wav file not found, may be already converted")
        return

    # Convert to mp3
    mp3_filename = transcript.audio_mp3_filename

    import av

    with av.open(wav_filename.as_posix()) as in_container:
        in_stream = in_container.streams.audio[0]
        with av.open(mp3_filename.as_posix(), "w") as out_container:
            out_stream = out_container.add_stream("mp3")
            for frame in in_container.decode(in_stream):
                for packet in out_stream.encode(frame):
                    out_container.mux(packet)
            # Flush the encoder: frames still buffered inside it would
            # otherwise never be muxed and the end of the audio would be lost.
            for packet in out_stream.encode():
                out_container.mux(packet)

    # Delete the wav file
    wav_filename.unlink(missing_ok=True)

    logger.info("Convert to mp3 done")
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
    """
    Move the transcript's mp3 to the configured storage backend.

    Skipped when ``settings.TRANSCRIPT_STORAGE_BACKEND`` is not set or when
    the mp3 file does not exist.
    """
    storage_configured = bool(settings.TRANSCRIPT_STORAGE_BACKEND)
    if not storage_configured:
        logger.info("No storage backend configured, skipping mp3 upload")
        return

    logger.info("Starting upload mp3")

    # Nothing to upload if the mp3 is absent
    if not transcript.audio_mp3_filename.exists():
        logger.warning("Mp3 file not found, may be already uploaded")
        return

    # Upload to external storage and delete the file
    await transcripts_controller.move_mp3_to_storage(transcript)

    logger.info("Upload mp3 done")
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_diarization(transcript: Transcript, logger: Logger):
    """Run ``PipelineMainDiarization`` for the transcript and wait for it."""
    logger.info("Starting diarization")
    await PipelineMainDiarization(transcript_id=transcript.id).run()
    logger.info("Diarization done")
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
    """Run ``PipelineMainTitleAndShortSummary`` for the transcript and wait for it."""
    logger.info("Starting title and short summary")
    await PipelineMainTitleAndShortSummary(transcript_id=transcript.id).run()
    logger.info("Title and short summary done")
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger):
    """Run ``PipelineMainFinalSummaries`` for the transcript and wait for it."""
    logger.info("Starting summaries")
    await PipelineMainFinalSummaries(transcript_id=transcript.id).run()
    logger.info("Summaries done")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Celery tasks that can be called from the API
|
||||
# ===================================================================
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_remove_upload(*, transcript_id: str):
    """Celery task wrapper around ``pipeline_remove_upload``."""
    await pipeline_remove_upload(
        transcript_id=transcript_id,
    )
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_waveform(*, transcript_id: str):
    """Celery task wrapper around ``pipeline_waveform``."""
    await pipeline_waveform(
        transcript_id=transcript_id,
    )
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
    """Celery task wrapper around ``pipeline_convert_to_mp3``."""
    await pipeline_convert_to_mp3(
        transcript_id=transcript_id,
    )
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_upload_mp3(*, transcript_id: str):
    """Celery task wrapper around ``pipeline_upload_mp3``."""
    await pipeline_upload_mp3(
        transcript_id=transcript_id,
    )
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_diarization(*, transcript_id: str):
    """Celery task wrapper around ``pipeline_diarization``."""
    await pipeline_diarization(
        transcript_id=transcript_id,
    )
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_title_and_short_summary(*, transcript_id: str):
    """Celery task wrapper around ``pipeline_title_and_short_summary``."""
    await pipeline_title_and_short_summary(
        transcript_id=transcript_id,
    )
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_final_summaries(*, transcript_id: str):
    """Celery task wrapper around ``pipeline_summaries``."""
    await pipeline_summaries(
        transcript_id=transcript_id,
    )
|
||||
|
||||
|
||||
def pipeline_post(*, transcript_id: str):
    """
    Run the post pipeline

    Two branches are grouped together: the audio branch (waveform, mp3
    conversion, mp3 upload, upload removal, diarization, chained in that
    order) and the title/short summary task. The final summaries task is the
    chord callback for that group.
    """
    task_kwargs = {"transcript_id": transcript_id}

    audio_branch = (
        task_pipeline_waveform.si(**task_kwargs)
        | task_pipeline_convert_to_mp3.si(**task_kwargs)
        | task_pipeline_upload_mp3.si(**task_kwargs)
        | task_pipeline_remove_upload.si(**task_kwargs)
        | task_pipeline_diarization.si(**task_kwargs)
    )
    title_branch = task_pipeline_title_and_short_summary.si(**task_kwargs)
    final_summaries = task_pipeline_final_summaries.si(**task_kwargs)

    workflow = chord(
        group(audio_branch, title_branch),
        final_summaries,
    )
    workflow.delay()
|
||||
|
||||
|
||||
@get_transcript
async def pipeline_upload(transcript: Transcript, logger: Logger):
    """
    Feed an uploaded audio file through ``PipelineMainLive``.

    The first ``upload.*`` file of the transcript's data path is decoded and
    every audio frame is pushed to the pipeline, which is then flushed and
    awaited. On any exception the transcript status is set to ``"error"``
    and the exception is re-raised.
    """
    import av

    try:
        # open audio
        upload_filename = next(transcript.data_path.glob("upload.*"))

        # The context manager closes the container on every exit path; the
        # original opened it and never closed it.
        with av.open(upload_filename.as_posix()) as container:
            # create pipeline
            pipeline = PipelineMainLive(transcript_id=transcript.id)
            pipeline.start()

            # push audio to pipeline
            # NOTE(review): this loop never awaits, so the runner task cannot
            # drain its command queue while frames are being pushed. Confirm
            # that long uploads cannot exceed the queue capacity.
            try:
                logger.info("Start pushing audio into the pipeline")
                for frame in container.decode(audio=0):
                    pipeline.push(frame)
            finally:
                logger.info("Flushing the pipeline")
                pipeline.flush()

            logger.info("Waiting for the pipeline to end")
            await pipeline.join()

    except Exception as exc:
        logger.error("Pipeline error", exc_info=exc)
        await transcripts_controller.update(
            transcript,
            {
                "status": "error",
            },
        )
        raise

    logger.info("Pipeline ended")
|
||||
|
||||
|
||||
@shared_task
@asynctask
async def task_pipeline_upload(*, transcript_id: str):
    """Celery task wrapper around ``pipeline_upload``; returns its result."""
    outcome = await pipeline_upload(transcript_id=transcript_id)
    return outcome
|
||||
152
server/reflector/pipelines/runner.py
Normal file
152
server/reflector/pipelines/runner.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Pipeline Runner
|
||||
===============
|
||||
|
||||
Pipeline runner designed to be executed in an asyncio task.
|
||||
|
||||
It is meant to be subclassed, and implement a create() method
|
||||
that expose/return a Pipeline instance.
|
||||
|
||||
During its lifecycle, it will emit the following status:
|
||||
- started: the pipeline has been started
|
||||
- push: the pipeline received at least one data
|
||||
- flush: the pipeline is flushing
|
||||
- ended: the pipeline has ended
|
||||
- error: the pipeline has ended with an error
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import Pipeline
|
||||
|
||||
|
||||
class PipelineRunner(BaseModel):
    """
    Drive a ``Pipeline`` from a queue of commands.

    ``push()`` and ``flush()`` only enqueue commands; ``run()`` consumes them
    and forwards them to the pipeline. ``status`` moves through ``idle``,
    ``init``, ``started``, ``push``, ``flush``, then ``ended`` or ``error``,
    and ``on_status`` is called on every change.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    status: str = "idle"
    pipeline: Pipeline | None = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._task = None
        # commands waiting to be executed by run(), as [cmd, data] pairs
        self._q_cmd = asyncio.Queue(maxsize=4096)
        # set once the pipeline has ended, successfully or not
        self._ev_done = asyncio.Event()
        self._is_first_push = True
        self._logger = logger.bind(
            runner=id(self),
            runner_cls=self.__class__.__name__,
        )

    async def create(self) -> Pipeline | None:
        """
        Create the pipeline if not specified earlier.
        Should be implemented in a subclass

        ``run()`` awaits this method, so it is a coroutine. Returning
        ``None`` makes ``run()`` end immediately without a pipeline.
        """
        raise NotImplementedError()

    def start(self):
        """
        Start the pipeline as a coroutine task
        """
        self._task = asyncio.get_event_loop().create_task(self.run())

    async def join(self):
        """
        Wait for the pipeline to finish
        """
        if self._task:
            await self._task

    def start_sync(self):
        """
        Start the pipeline synchronously (for non-asyncio apps)
        """
        coro = self.run()
        asyncio.run(coro)

    def push(self, data):
        """
        Push data to the pipeline
        """
        self._add_cmd("PUSH", data)

    def flush(self):
        """
        Flush the pipeline
        """
        self._add_cmd("FLUSH", None)

    async def on_status(self, status):
        """
        Called when the status of the pipeline changes
        """
        pass

    async def on_ended(self):
        """
        Called when the pipeline ends
        """
        pass

    def _add_cmd(self, cmd: str, data):
        """
        Enqueue a command to be executed in the runner.
        Currently supported commands: PUSH, FLUSH

        Uses ``put_nowait``, so ``asyncio.QueueFull`` is raised when the
        queue is at capacity.
        """
        self._q_cmd.put_nowait([cmd, data])

    async def _set_status(self, status):
        """Store ``status`` and notify ``on_status``; callback errors are logged, not raised."""
        self._logger.debug("Runner status updated", status=status)
        self.status = status
        if self.on_status:
            try:
                await self.on_status(status)
            except Exception:
                self._logger.exception("Runer error while setting status")

    async def run(self):
        """
        Create the pipeline if needed, then execute queued commands until done.

        Any exception sets the status to ``error``, marks the runner as done
        and is re-raised.
        """
        try:
            # create the pipeline if not yet done
            await self._set_status("init")
            self._is_first_push = True
            if not self.pipeline:
                self.pipeline = await self.create()

            if not self.pipeline:
                # no pipeline created in create, just finish it then.
                await self._set_status("ended")
                self._ev_done.set()
                if self.on_ended:
                    await self.on_ended()
                return

            # start the loop
            await self._set_status("started")
            while not self._ev_done.is_set():
                cmd, data = await self._q_cmd.get()
                # default of None: without it an unknown command raised
                # AttributeError and the explicit error below was unreachable
                func = getattr(self, f"cmd_{cmd.lower()}", None)
                if func is None:
                    raise Exception(f"Unknown command {cmd}")
                await func(data)
        except Exception:
            self._logger.exception("Runner error")
            await self._set_status("error")
            self._ev_done.set()
            raise

    async def cmd_push(self, data):
        """Forward ``data`` to the pipeline; the first push sets the ``push`` status."""
        if self._is_first_push:
            await self._set_status("push")
            self._is_first_push = False
        await self.pipeline.push(data)

    async def cmd_flush(self, data):
        """Flush the pipeline, mark the runner as ended and call ``on_ended``."""
        await self._set_status("flush")
        await self.pipeline.flush()
        await self._set_status("ended")
        self._ev_done.set()
        if self.on_ended:
            await self.on_ended()
|
||||
@@ -1,9 +1,16 @@
|
||||
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
||||
from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
|
||||
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
|
||||
from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor # noqa: F401
|
||||
from .base import ( # noqa: F401
|
||||
BroadcastProcessor,
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
Processor,
|
||||
ThreadedProcessor,
|
||||
)
|
||||
from .transcript_final_long_summary import ( # noqa: F401
|
||||
TranscriptFinalLongSummaryProcessor,
|
||||
)
|
||||
|
||||
181
server/reflector/processors/audio_diarization.py
Normal file
181
server/reflector/processors/audio_diarization.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary, Word
|
||||
|
||||
|
||||
class AudioDiarizationProcessor(Processor):
    """
    Base processor that assigns a speaker to every word of a set of topics.

    Subclasses implement ``_diarize``. Its result is handled here as a list
    of dicts with ``start``, ``end`` and ``speaker`` keys.
    """

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    async def _push(self, data: AudioDiarizationInput):
        """Diarize ``data``, write speakers into its topics' words, emit each topic."""
        try:
            self.logger.info("Diarization started", audio_file_url=data.audio_url)
            diarization = await self._diarize(data)
            self.logger.info("Diarization finished")
        except Exception:
            self.logger.exception("Diarization failed after retrying")
            raise

        # now reapply speaker to topics (if any)
        # topics is a list[BaseModel] with an attribute words
        # words is a list[BaseModel] with text, start and speaker attribute

        # create a view of words based on topics
        # the current algorithm is using words index, we cannot use a generator
        words = list(self.iter_words_from_topics(data.topics))

        # assign speaker to words (mutate the words list)
        self.assign_speaker(words, diarization)

        # emit them
        for topic in data.topics:
            await self.emit(topic)

    async def _diarize(self, data: AudioDiarizationInput):
        """Return the diarization segments for ``data``; implemented by subclasses."""
        raise NotImplementedError

    def assign_speaker(self, words: list[Word], diarization: list[dict]):
        """
        Clean up ``diarization`` and assign a speaker to each word.

        Both ``words`` and ``diarization`` are mutated in place.
        """
        self._diarization_remove_overlap(diarization)
        self._diarization_remove_segment_without_words(words, diarization)
        self._diarization_merge_same_speaker(words, diarization)
        self._diarization_assign_speaker(words, diarization)

    def iter_words_from_topics(self, topics: list[TitleSummary]):
        """Yield every word of every topic's transcript, in order."""
        for topic in topics:
            for word in topic.transcript.words:
                yield word

    def is_word_continuation(self, word_prev, word):
        """
        Return True if ``word`` continues the sentence of ``word_prev``.

        Returns False when ``word_prev`` ends with one of ``.?!`` or when
        ``word`` starts with an uppercase letter; True otherwise.
        """
        # is word_prev ending with a punctuation ?
        if word_prev.text and word_prev.text[-1] in ".?!":
            return False
        elif word.text and word.text[0].isupper():
            return False
        return True

    def _diarization_remove_overlap(self, diarization: list[dict]):
        """
        Remove overlap in diarization results

        When using a diarization algorithm, it's possible to have overlapping segments
        This function remove the overlap by keeping the longest segment

        Warning: this function mutate the diarization list
        """
        # remove overlap by keeping the longest segment
        diarization_idx = 0
        while diarization_idx < len(diarization) - 1:
            d = diarization[diarization_idx]
            dnext = diarization[diarization_idx + 1]
            if d["end"] > dnext["start"]:
                # remove the shortest segment
                if d["end"] - d["start"] > dnext["end"] - dnext["start"]:
                    # remove next segment
                    diarization.pop(diarization_idx + 1)
                else:
                    # remove current segment
                    diarization.pop(diarization_idx)
                # index is not advanced: the survivor is compared again
            else:
                diarization_idx += 1

    def _diarization_remove_segment_without_words(
        self, words: list[Word], diarization: list[dict]
    ):
        """
        Remove diarization segments without words

        A word counts for a segment when its start falls in ``[start, end)``
        or its end falls in ``(start, end]``.

        Warning: this function mutate the diarization list
        """
        # count the number of words for each diarization segment
        diarization_count = []
        for d in diarization:
            start = d["start"]
            end = d["end"]
            count = 0
            for word in words:
                if start <= word.start < end:
                    count += 1
                elif start < word.end <= end:
                    count += 1
            diarization_count.append(count)

        # remove diarization segments with no words
        diarization_idx = 0
        while diarization_idx < len(diarization):
            if diarization_count[diarization_idx] == 0:
                diarization.pop(diarization_idx)
                diarization_count.pop(diarization_idx)
            else:
                diarization_idx += 1

    def _diarization_merge_same_speaker(
        self, words: list[Word], diarization: list[dict]
    ):
        """
        Merge diarization contiguous segments with the same speaker

        ``words`` is not used by this step.

        Warning: this function mutate the diarization list
        """
        # merge segment with same speaker
        diarization_idx = 0
        while diarization_idx < len(diarization) - 1:
            d = diarization[diarization_idx]
            dnext = diarization[diarization_idx + 1]
            if d["speaker"] == dnext["speaker"]:
                diarization[diarization_idx]["end"] = dnext["end"]
                diarization.pop(diarization_idx + 1)
            else:
                diarization_idx += 1

    def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
        """
        Assign speaker to words based on diarization

        Warning: this function mutate the words list
        """

        word_idx = 0
        last_speaker = None
        for d in diarization:
            start = d["start"]
            end = d["end"]
            speaker = d["speaker"]

            # diarization may start after the first set of words
            # in this case, we assign the last speaker
            for word in words[word_idx:]:
                if word.start < start:
                    # speaker change, but what make sense for assigning the word ?
                    # If it's a new sentence, assign with the new speaker
                    # If it's a continuation, assign with the last speaker
                    is_continuation = False
                    if word_idx > 0 and word_idx < len(words) - 1:
                        is_continuation = self.is_word_continuation(
                            *words[word_idx - 1 : word_idx + 1]
                        )
                    if is_continuation:
                        word.speaker = last_speaker
                    else:
                        word.speaker = speaker
                        last_speaker = speaker
                    word_idx += 1
                else:
                    break

            # now continue to assign speaker until the word starts after the end
            for word in words[word_idx:]:
                if start <= word.start < end:
                    last_speaker = speaker
                    word.speaker = speaker
                    word_idx += 1
                elif word.start > end:
                    break

        # no more diarization available,
        # assign last speaker to all words without speaker
        for word in words[word_idx:]:
            word.speaker = last_speaker
|
||||
33
server/reflector/processors/audio_diarization_auto.py
Normal file
33
server/reflector/processors/audio_diarization_auto.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioDiarizationAutoProcessor(AudioDiarizationProcessor):
    """
    Factory that instantiates the diarization processor registered under a
    backend name.
    """

    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        """Register ``kclass`` as the processor class for backend ``name``."""
        cls._registry[name] = kclass

    def __new__(cls, name: str | None = None, **kwargs):
        """
        Instantiate the processor registered for ``name``.

        ``name`` defaults to ``settings.DIARIZATION_BACKEND``. When the name
        is not registered yet, ``reflector.processors.audio_diarization_<name>``
        is imported first. Settings whose key starts with
        ``DIARIZATION_<NAME>_`` are passed to the constructor as
        ``<name>_...`` keyword arguments; explicit ``kwargs`` win on conflict.
        """
        backend = settings.DIARIZATION_BACKEND if name is None else name

        if backend not in cls._registry:
            importlib.import_module(
                f"reflector.processors.audio_diarization_{backend}"
            )

        # gather specific configuration for the processor
        settings_prefix = "DIARIZATION_"
        config_prefix = f"{settings_prefix}{backend.upper()}_"
        config = {
            key[len(settings_prefix) :].lower(): value
            for key, value in settings
            if key.startswith(config_prefix)
        }

        return cls._registry[backend](**config | kwargs)
|
||||
37
server/reflector/processors/audio_diarization_modal.py
Normal file
37
server/reflector/processors/audio_diarization_modal.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import httpx
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
    """Diarization processor that calls the HTTP service at ``settings.DIARIZATION_URL``."""

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.diarization_url = settings.DIARIZATION_URL + "/diarize"
        self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}

    async def _diarize(self, data: AudioDiarizationInput):
        """
        POST the audio URL to the diarization endpoint and return the
        ``"diarization"`` entry of the JSON response.

        Raises through ``raise_for_status`` on an HTTP error status.
        """
        query = {"audio_file_url": data.audio_url, "timestamp": 0}
        async with httpx.AsyncClient() as client:
            # timeout=None: no client-side timeout is applied to this request
            response = await client.post(
                self.diarization_url,
                headers=self.headers,
                params=query,
                timeout=None,
                follow_redirects=True,
            )
            response.raise_for_status()
            payload = response.json()
            return payload["diarization"]
|
||||
|
||||
|
||||
AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)
|
||||
@@ -12,8 +12,8 @@ class AudioFileWriterProcessor(Processor):
|
||||
INPUT_TYPE = av.AudioFrame
|
||||
OUTPUT_TYPE = av.AudioFrame
|
||||
|
||||
def __init__(self, path: Path | str):
|
||||
super().__init__()
|
||||
def __init__(self, path: Path | str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if path.suffix not in (".mp3", ".wav"):
|
||||
@@ -21,6 +21,7 @@ class AudioFileWriterProcessor(Processor):
|
||||
self.path = path
|
||||
self.out_container = None
|
||||
self.out_stream = None
|
||||
self.last_packet = None
|
||||
|
||||
async def _push(self, data: av.AudioFrame):
|
||||
if not self.out_container:
|
||||
@@ -40,12 +41,30 @@ class AudioFileWriterProcessor(Processor):
|
||||
raise ValueError("Only mp3 and wav files are supported")
|
||||
for packet in self.out_stream.encode(data):
|
||||
self.out_container.mux(packet)
|
||||
self.last_packet = packet
|
||||
await self.emit(data)
|
||||
|
||||
async def _flush(self):
|
||||
if self.out_container:
|
||||
for packet in self.out_stream.encode():
|
||||
self.out_container.mux(packet)
|
||||
self.last_packet = packet
|
||||
try:
|
||||
if self.last_packet is not None:
|
||||
duration = round(
|
||||
float(
|
||||
(self.last_packet.pts * self.last_packet.duration)
|
||||
* self.last_packet.time_base
|
||||
),
|
||||
2,
|
||||
)
|
||||
except Exception:
|
||||
self.logger.exception("Failed to get duration")
|
||||
duration = 0
|
||||
|
||||
self.out_container.close()
|
||||
self.out_container = None
|
||||
self.out_stream = None
|
||||
|
||||
if duration > 0:
|
||||
await self.emit(duration, name="duration")
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from profanityfilter import ProfanityFilter
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import AudioFile, Transcript
|
||||
|
||||
@@ -40,8 +38,6 @@ class AudioTranscriptProcessor(Processor):
|
||||
self.m_transcript_call = self.m_transcript_call.labels(name)
|
||||
self.m_transcript_success = self.m_transcript_success.labels(name)
|
||||
self.m_transcript_failure = self.m_transcript_failure.labels(name)
|
||||
self.profanity_filter = ProfanityFilter()
|
||||
self.profanity_filter.set_censor("*")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def _push(self, data: AudioFile):
|
||||
@@ -60,9 +56,3 @@ class AudioTranscriptProcessor(Processor):
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
raise NotImplementedError
|
||||
|
||||
def filter_profanity(self, text: str) -> str:
|
||||
"""
|
||||
Remove censored words from the transcript
|
||||
"""
|
||||
return self.profanity_filter.censor(text)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.base import Pipeline, Processor
|
||||
from reflector.processors.types import AudioFile
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
@@ -13,8 +11,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, name):
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.TRANSCRIPT_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.audio_transcript_{name}"
|
||||
importlib.import_module(module_name)
|
||||
@@ -30,30 +29,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_pipeline(self, pipeline: Pipeline):
|
||||
super().set_pipeline(pipeline)
|
||||
self.processor.set_pipeline(pipeline)
|
||||
|
||||
def connect(self, processor: Processor):
|
||||
self.processor.connect(processor)
|
||||
|
||||
def disconnect(self, processor: Processor):
|
||||
self.processor.disconnect(processor)
|
||||
|
||||
def on(self, callback):
|
||||
self.processor.on(callback)
|
||||
|
||||
def off(self, callback):
|
||||
self.processor.off(callback)
|
||||
|
||||
async def _push(self, data: AudioFile):
|
||||
return await self.processor._push(data)
|
||||
|
||||
async def _flush(self):
|
||||
return await self.processor._flush()
|
||||
return cls._registry[name](**config | kwargs)
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
"""
|
||||
Implementation using the GPU service from banana.
|
||||
|
||||
API will be a POST request to TRANSCRIPT_URL:
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_url": "https://...",
|
||||
"audio_ext": "wav",
|
||||
"timestamp": 123.456
|
||||
"language": "en"
|
||||
}
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import Storage
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
|
||||
class AudioTranscriptBananaProcessor(AudioTranscriptProcessor):
    """
    Transcribe audio through a remote HTTP service (banana).

    The audio file is first uploaded to the configured storage backend, the
    resulting URL is POSTed to ``settings.TRANSCRIPT_URL``, and the uploaded
    file is removed again once the request has completed (successfully or
    not).
    """

    def __init__(self, banana_api_key: str, banana_model_key: str):
        """
        :param banana_api_key: value sent in the ``X-Banana-API-Key`` header
        :param banana_model_key: value sent in the ``X-Banana-Model-Key`` header
        """
        super().__init__()
        self.transcript_url = settings.TRANSCRIPT_URL
        self.timeout = settings.TRANSCRIPT_TIMEOUT
        self.storage = Storage.get_instance(
            settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_"
        )
        self.headers = {
            "X-Banana-API-Key": banana_api_key,
            "X-Banana-Model-Key": banana_model_key,
        }

    async def _transcript(self, data: AudioFile):
        """
        Upload ``data.path``, request its transcription and return a
        ``Transcript`` built from the ``text`` and ``words`` of the response.

        Raises whatever ``response.raise_for_status()`` raises on a non-2xx
        answer.
        """
        async with httpx.AsyncClient() as client:
            print(f"Uploading audio {data.path.name} to S3")
            url = await self._upload_file(data.path)

            # From here on the file exists remotely: make sure it is removed
            # even when the request or the response parsing fails.
            try:
                print(f"Try to transcribe audio {data.path.name}")
                request_data = {
                    "audio_url": url,
                    "audio_ext": data.path.suffix[1:],
                    "timestamp": float(round(data.timestamp, 2)),
                }
                response = await retry(client.post)(
                    self.transcript_url,
                    json=request_data,
                    headers=self.headers,
                    timeout=self.timeout,
                )

                print(f"Transcript response: {response.status_code} {response.content}")
                response.raise_for_status()
                result = response.json()
                transcript = Transcript(
                    text=result["text"],
                    words=[
                        Word(text=word["text"], start=word["start"], end=word["end"])
                        for word in result["words"]
                    ],
                )
            finally:
                # remove audio file from S3
                await self._delete_file(data.path)

            return transcript

    @retry
    async def _upload_file(self, path: Path) -> str:
        """Upload the file at ``path`` under its base name and return its URL."""
        # Open inside a context manager so the descriptor is released after
        # every attempt instead of leaking one handle per call.
        with open(path, "rb") as fd:
            upload_result = await self.storage.put_file(path.name, fd)
        return upload_result.url

    @retry
    async def _delete_file(self, path: Path):
        """Delete the previously uploaded file named after ``path``; returns True."""
        await self.storage.delete_file(path.name)
        return True
|
||||
|
||||
|
||||
AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor)
|
||||
@@ -41,6 +41,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
timeout=self.timeout,
|
||||
headers=self.headers,
|
||||
params=json_payload,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
@@ -48,10 +49,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
text = result["text"][source_language]
|
||||
text = self.filter_profanity(text)
|
||||
transcript = Transcript(
|
||||
text=text,
|
||||
words=[
|
||||
Word(
|
||||
text=word["text"],
|
||||
|
||||
@@ -30,7 +30,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
|
||||
ts = data.timestamp
|
||||
|
||||
for segment in segments:
|
||||
transcript.text += segment.text
|
||||
for word in segment.words:
|
||||
transcript.words.append(
|
||||
Word(
|
||||
|
||||
36
server/reflector/processors/audio_waveform_processor.py
Normal file
36
server/reflector/processors/audio_waveform_processor.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import TitleSummary
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
|
||||
|
||||
class AudioWaveformProcessor(Processor):
    """
    Write the waveform for the final audio

    Nothing happens while data is pushed. On flush, the waveform of
    ``audio_path`` is computed with ``get_audio_waveform``, dumped as JSON to
    ``waveform_path`` and emitted on the ``"waveform"`` event.
    """

    INPUT_TYPE = TitleSummary

    def __init__(self, audio_path: Path | str, waveform_path: Path | str, **kwargs):
        """
        :param audio_path: audio file to analyse; must end in ``.mp3`` or ``.wav``
        :param waveform_path: destination of the JSON waveform file
        :param kwargs: forwarded unchanged to ``Processor.__init__``
        :raises ValueError: if ``audio_path`` has any other suffix
        """
        super().__init__(**kwargs)
        if isinstance(audio_path, str):
            audio_path = Path(audio_path)
        if audio_path.suffix not in (".mp3", ".wav"):
            raise ValueError("Only mp3 and wav files are supported")
        self.audio_path = audio_path
        # _flush() uses `.parent`, which a plain string does not have, so
        # normalise to Path here (a Path argument is returned as an equal Path).
        self.waveform_path = Path(waveform_path)

    async def _flush(self):
        """Compute the waveform, write it to disk and emit it as "waveform"."""
        # The destination directory may not exist yet.
        self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
        self.logger.info("Waveform Processing Started")
        waveform = get_audio_waveform(path=self.audio_path, segments_count=255)

        with open(self.waveform_path, "w") as fd:
            json.dump(waveform, fd)
        self.logger.info("Waveform Processing Finished")
        await self.emit(waveform, name="waveform")

    async def _push(_self, _data):
        """Ignore pushed data; all the work is done in ``_flush``."""
        return
|
||||
@@ -14,7 +14,42 @@ class PipelineEvent(BaseModel):
|
||||
data: Any
|
||||
|
||||
|
||||
class Processor:
|
||||
class Emitter:
    """
    Minimal emitter of named events for asynchronous callbacks.

    Callbacks are stored per event name in ``self._callbacks`` and are awaited
    one after the other, in registration order, when the event is emitted.
    """

    def __init__(self, **kwargs):
        """
        Every keyword argument named ``on_<event>`` is registered as a
        callback for ``<event>``; other keyword arguments are ignored.
        """
        self._callbacks = {}

        # register callbacks from kwargs (on_*)
        for key, value in kwargs.items():
            if key.startswith("on_"):
                self.on(value, name=key[3:])

    def on(self, callback, name="default"):
        """
        Register a callback to be called when data is emitted

        :raises ValueError: if ``callback`` is not a coroutine function
        """
        # ensure callback is asynchronous
        if not asyncio.iscoroutinefunction(callback):
            raise ValueError("Callback must be a coroutine function")
        if name not in self._callbacks:
            self._callbacks[name] = []
        self._callbacks[name].append(callback)

    def off(self, callback, name="default"):
        """
        Unregister a callback to be called when data is emitted

        Unknown event names are ignored. For a known event, ``list.remove``
        raises ``ValueError`` if ``callback`` is not registered.
        """
        if name not in self._callbacks:
            return
        self._callbacks[name].remove(callback)

    async def emit(self, data, name="default"):
        """Await every callback registered for ``name`` with ``data``."""
        if name not in self._callbacks:
            return
        # Iterate over a snapshot: a callback may call on()/off() while it
        # runs, and mutating the list being iterated would silently skip the
        # callback that follows the removed one.
        for callback in list(self._callbacks[name]):
            await callback(data)
|
||||
|
||||
|
||||
class Processor(Emitter):
|
||||
INPUT_TYPE: type = None
|
||||
OUTPUT_TYPE: type = None
|
||||
|
||||
@@ -59,7 +94,8 @@ class Processor:
|
||||
["processor"],
|
||||
)
|
||||
|
||||
def __init__(self, callback=None, custom_logger=None):
|
||||
def __init__(self, callback=None, custom_logger=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = name = self.__class__.__name__
|
||||
self.m_processor = self.m_processor.labels(name)
|
||||
self.m_processor_call = self.m_processor_call.labels(name)
|
||||
@@ -70,9 +106,11 @@ class Processor:
|
||||
self.m_processor_flush_success = self.m_processor_flush_success.labels(name)
|
||||
self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name)
|
||||
self._processors = []
|
||||
self._callbacks = []
|
||||
|
||||
# register callbacks
|
||||
if callback:
|
||||
self.on(callback)
|
||||
|
||||
self.uid = uuid4().hex
|
||||
self.flushed = False
|
||||
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
||||
@@ -100,21 +138,6 @@ class Processor:
|
||||
"""
|
||||
self._processors.remove(processor)
|
||||
|
||||
def on(self, callback):
|
||||
"""
|
||||
Register a callback to be called when data is emitted
|
||||
"""
|
||||
# ensure callback is asynchronous
|
||||
if not asyncio.iscoroutinefunction(callback):
|
||||
raise ValueError("Callback must be a coroutine function")
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def off(self, callback):
|
||||
"""
|
||||
Unregister a callback to be called when data is emitted
|
||||
"""
|
||||
self._callbacks.remove(callback)
|
||||
|
||||
def get_pref(self, key: str, default: Any = None):
|
||||
"""
|
||||
Get a preference from the pipeline prefs
|
||||
@@ -123,15 +146,16 @@ class Processor:
|
||||
return self.pipeline.get_pref(key, default)
|
||||
return default
|
||||
|
||||
async def emit(self, data):
|
||||
if self.pipeline:
|
||||
await self.pipeline.emit(
|
||||
PipelineEvent(processor=self.name, uid=self.uid, data=data)
|
||||
)
|
||||
for callback in self._callbacks:
|
||||
await callback(data)
|
||||
for processor in self._processors:
|
||||
await processor.push(data)
|
||||
async def emit(self, data, name="default"):
|
||||
if name == "default":
|
||||
if self.pipeline:
|
||||
await self.pipeline.emit(
|
||||
PipelineEvent(processor=self.name, uid=self.uid, data=data)
|
||||
)
|
||||
await super().emit(data, name=name)
|
||||
if name == "default":
|
||||
for processor in self._processors:
|
||||
await processor.push(data)
|
||||
|
||||
async def push(self, data):
|
||||
"""
|
||||
@@ -254,11 +278,11 @@ class ThreadedProcessor(Processor):
|
||||
def disconnect(self, processor: Processor):
|
||||
self.processor.disconnect(processor)
|
||||
|
||||
def on(self, callback):
|
||||
self.processor.on(callback)
|
||||
def on(self, callback, name="default"):
|
||||
self.processor.on(callback, name=name)
|
||||
|
||||
def off(self, callback):
|
||||
self.processor.off(callback)
|
||||
def off(self, callback, name="default"):
|
||||
self.processor.off(callback, name=name)
|
||||
|
||||
def describe(self, level=0):
|
||||
super().describe(level)
|
||||
@@ -290,12 +314,12 @@ class BroadcastProcessor(Processor):
|
||||
processor.set_pipeline(pipeline)
|
||||
|
||||
async def _push(self, data):
|
||||
for processor in self.processors:
|
||||
await processor.push(data)
|
||||
coros = [processor.push(data) for processor in self.processors]
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
async def _flush(self):
|
||||
for processor in self.processors:
|
||||
await processor.flush()
|
||||
coros = [processor.flush() for processor in self.processors]
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
def connect(self, processor: Processor):
|
||||
for processor in self.processors:
|
||||
@@ -305,13 +329,13 @@ class BroadcastProcessor(Processor):
|
||||
for processor in self.processors:
|
||||
processor.disconnect(processor)
|
||||
|
||||
def on(self, callback):
|
||||
def on(self, callback, name="default"):
|
||||
for processor in self.processors:
|
||||
processor.on(callback)
|
||||
processor.on(callback, name=name)
|
||||
|
||||
def off(self, callback):
|
||||
def off(self, callback, name="default"):
|
||||
for processor in self.processors:
|
||||
processor.off(callback)
|
||||
processor.off(callback, name=name)
|
||||
|
||||
def describe(self, level=0):
|
||||
super().describe(level)
|
||||
@@ -333,6 +357,7 @@ class Pipeline(Processor):
|
||||
self.logger.info("Pipeline created")
|
||||
|
||||
self.processors = processors
|
||||
self.options = None
|
||||
self.prefs = {}
|
||||
|
||||
for processor in processors:
|
||||
|
||||
@@ -36,7 +36,6 @@ class TranscriptLinerProcessor(Processor):
|
||||
# cut to the next .
|
||||
partial = Transcript(words=[])
|
||||
for word in self.transcript.words[:]:
|
||||
partial.text += word.text
|
||||
partial.words.append(word)
|
||||
if not self.is_sentence_terminated(word.text):
|
||||
continue
|
||||
|
||||
@@ -16,6 +16,7 @@ class TranscriptTranslatorProcessor(Processor):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = None
|
||||
self.translate_url = settings.TRANSLATE_URL
|
||||
self.timeout = settings.TRANSLATE_TIMEOUT
|
||||
self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
|
||||
@@ -50,6 +51,7 @@ class TranscriptTranslatorProcessor(Processor):
|
||||
headers=self.headers,
|
||||
params=json_payload,
|
||||
timeout=self.timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["text"]
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
import io
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from profanityfilter import ProfanityFilter
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
from reflector.redis_cache import redis_cache
|
||||
|
||||
PUNC_RE = re.compile(r"[.;:?!…]")
|
||||
|
||||
profanity_filter = ProfanityFilter()
|
||||
profanity_filter.set_censor("*")
|
||||
|
||||
|
||||
class AudioFile(BaseModel):
|
||||
@@ -43,13 +51,34 @@ class Word(BaseModel):
|
||||
text: str
|
||||
start: float
|
||||
end: float
|
||||
speaker: int = 0
|
||||
|
||||
|
||||
class TranscriptSegment(BaseModel):
|
||||
text: str
|
||||
start: float
|
||||
end: float
|
||||
speaker: int = 0
|
||||
|
||||
|
||||
class Transcript(BaseModel):
|
||||
text: str = ""
|
||||
translation: str | None = None
|
||||
words: list[Word] = None
|
||||
|
||||
@property
|
||||
def raw_text(self):
|
||||
# Uncensored text
|
||||
return "".join([word.text for word in self.words])
|
||||
|
||||
@redis_cache(prefix="profanity", duration=3600 * 24 * 7)
|
||||
def _get_censored_text(self, text: str):
|
||||
return profanity_filter.censor(text).strip()
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
# Censored text
|
||||
return self._get_censored_text(self.raw_text)
|
||||
|
||||
@property
|
||||
def human_timestamp(self):
|
||||
minutes = int(self.timestamp / 60)
|
||||
@@ -74,7 +103,6 @@ class Transcript(BaseModel):
|
||||
self.words = other.words
|
||||
else:
|
||||
self.words.extend(other.words)
|
||||
self.text += other.text
|
||||
|
||||
def add_offset(self, offset: float):
|
||||
for word in self.words:
|
||||
@@ -87,6 +115,51 @@ class Transcript(BaseModel):
|
||||
]
|
||||
return Transcript(text=self.text, translation=self.translation, words=words)
|
||||
|
||||
def as_segments(self) -> list[TranscriptSegment]:
    """
    Group ``self.words`` into a list of ``TranscriptSegment``.

    Words are appended to the running segment by concatenating their text
    (no separator is inserted) and moving the segment end to the word end.
    A segment is closed when:

    - the next word has a different ``speaker`` (that word opens the next
      segment), or
    - a word matching ``PUNC_RE`` has just been appended and the segment
      text is longer than ``MAX_SEGMENT_LENGTH`` characters.

    The word that opens a segment is never tested for punctuation. Whatever
    segment is still open after the last word is returned as well.
    """
    MAX_SEGMENT_LENGTH = 120
    segments = []
    current_segment = None

    for word in self.words:
        # A speaker change closes the running segment before the word is used
        if current_segment is not None and word.speaker != current_segment.speaker:
            segments.append(current_segment)
            current_segment = None

        # No running segment: this word starts a new one
        if current_segment is None:
            current_segment = TranscriptSegment(
                text=word.text,
                start=word.start,
                end=word.end,
                speaker=word.speaker,
            )
            continue

        # Same speaker: extend the running segment with this word
        current_segment.text += word.text
        current_segment.end = word.end

        # Close on punctuation only once the segment is long enough
        if PUNC_RE.search(word.text) and len(current_segment.text) > MAX_SEGMENT_LENGTH:
            segments.append(current_segment)
            current_segment = None

    if current_segment:
        segments.append(current_segment)

    return segments
|
||||
|
||||
|
||||
class TitleSummary(BaseModel):
|
||||
title: str
|
||||
@@ -103,6 +176,10 @@ class TitleSummary(BaseModel):
|
||||
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
|
||||
|
||||
|
||||
class TitleSummaryWithId(TitleSummary):
|
||||
id: str
|
||||
|
||||
|
||||
class FinalLongSummary(BaseModel):
|
||||
long_summary: str
|
||||
duration: float
|
||||
@@ -318,3 +395,8 @@ class TranslationLanguages(BaseModel):
|
||||
|
||||
def is_supported(self, lang_id: str) -> bool:
|
||||
return lang_id in self.supported_languages
|
||||
|
||||
|
||||
class AudioDiarizationInput(BaseModel):
|
||||
audio_url: str
|
||||
topics: list[TitleSummaryWithId]
|
||||
|
||||
50
server/reflector/redis_cache.py
Normal file
50
server/reflector/redis_cache.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import functools
|
||||
import json
|
||||
|
||||
import redis
|
||||
from reflector.settings import settings
|
||||
|
||||
redis_clients = {}
|
||||
|
||||
|
||||
def get_redis_client(db=0):
    """
    Return the Redis client bound to database ``db``.

    Clients are created on first use from ``settings.REDIS_HOST`` and
    ``settings.REDIS_PORT`` and kept in the module-level ``redis_clients``
    dict, so later calls with the same ``db`` get the same object back.
    """
    client = redis_clients.get(db)
    if client is None:
        client = redis.StrictRedis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=db,
        )
        redis_clients[db] = client
    return client
|
||||
|
||||
|
||||
def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argidx=1):
    """
    Decorator caching the JSON-serialised result of a function in Redis.

    The cache key is ``"<prefix>:<args[argidx]>"``. The call is cached only
    when the positional argument at ``argidx`` exists and is a string; in any
    other case the wrapped function is called directly and Redis is not
    touched.

    :param prefix: string prepended to the key
    :param duration: time to live handed to ``setex``
    :param db: Redis database passed to ``get_redis_client``
    :param argidx: index, in ``*args``, of the string used as key
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Only a positional string at `argidx` can be turned into a key
            has_string_key = len(args) > argidx and isinstance(args[argidx], str)
            if not has_string_key:
                return func(*args, **kwargs)

            key = prefix + ":" + args[argidx]
            client = get_redis_client(db=db)

            # Cache hit: stored values are UTF-8 encoded JSON documents
            stored = client.get(key)
            if stored:
                return json.loads(stored.decode("utf-8"))

            # Cache miss: compute, store with a time to live, then return
            value = func(*args, **kwargs)
            client.setex(key, duration, json.dumps(value))
            return value

        return wrapper

    return decorator
|
||||
@@ -2,7 +2,11 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
OPENMP_KMP_DUPLICATE_LIB_OK: bool = False
|
||||
|
||||
@@ -37,7 +41,7 @@ class Settings(BaseSettings):
|
||||
AUDIO_BUFFER_SIZE: int = 256 * 960
|
||||
|
||||
# Audio Transcription
|
||||
# backends: whisper, banana, modal
|
||||
# backends: whisper, modal
|
||||
TRANSCRIPT_BACKEND: str = "whisper"
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
@@ -46,24 +50,20 @@ class Settings(BaseSettings):
|
||||
TRANSLATE_URL: str | None = None
|
||||
TRANSLATE_TIMEOUT: int = 90
|
||||
|
||||
# Audio transcription banana.dev configuration
|
||||
TRANSCRIPT_BANANA_API_KEY: str | None = None
|
||||
TRANSCRIPT_BANANA_MODEL_KEY: str | None = None
|
||||
|
||||
# Audio transcription modal.com configuration
|
||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Audio transcription storage
|
||||
TRANSCRIPT_STORAGE_BACKEND: str = "aws"
|
||||
TRANSCRIPT_STORAGE_BACKEND: str | None = None
|
||||
|
||||
# Storage configuration for AWS
|
||||
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks"
|
||||
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
|
||||
TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1"
|
||||
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
|
||||
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
|
||||
|
||||
# LLM
|
||||
# available backend: openai, banana, modal, oobabooga
|
||||
# available backend: openai, modal, oobabooga
|
||||
LLM_BACKEND: str = "oobabooga"
|
||||
|
||||
# LLM common configuration
|
||||
@@ -78,13 +78,14 @@ class Settings(BaseSettings):
|
||||
LLM_TEMPERATURE: float = 0.7
|
||||
ZEPHYR_LLM_URL: str | None = None
|
||||
|
||||
# LLM Banana configuration
|
||||
LLM_BANANA_API_KEY: str | None = None
|
||||
LLM_BANANA_MODEL_KEY: str | None = None
|
||||
|
||||
# LLM Modal configuration
|
||||
LLM_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Diarization
|
||||
DIARIZATION_ENABLED: bool = True
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
|
||||
# Sentry
|
||||
SENTRY_DSN: str | None = None
|
||||
|
||||
@@ -109,5 +110,26 @@ class Settings(BaseSettings):
|
||||
# Min transcript length to generate topic + summary
|
||||
MIN_TRANSCRIPT_LENGTH: int = 750
|
||||
|
||||
# Celery
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
|
||||
# Redis
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: int = 6379
|
||||
REDIS_CACHE_DB: int = 2
|
||||
|
||||
# Secret key
|
||||
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
|
||||
|
||||
# Current hosting/domain
|
||||
BASE_URL: str = "http://localhost:1250"
|
||||
|
||||
# Profiling
|
||||
PROFILING: bool = False
|
||||
|
||||
# Healthcheck
|
||||
HEALTHCHECK_URL: str | None = None
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import importlib
|
||||
|
||||
from pydantic import BaseModel
|
||||
from reflector.settings import settings
|
||||
import importlib
|
||||
|
||||
|
||||
class FileResult(BaseModel):
|
||||
@@ -17,7 +18,7 @@ class Storage:
|
||||
cls._registry[name] = kclass
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, name, settings_prefix=""):
|
||||
def get_instance(cls, name: str, settings_prefix: str = ""):
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.storage.storage_{name}"
|
||||
importlib.import_module(module_name)
|
||||
@@ -45,3 +46,9 @@ class Storage:
|
||||
|
||||
async def _delete_file(self, filename: str):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_file_url(self, filename: str) -> str:
|
||||
return await self._get_file_url(filename)
|
||||
|
||||
async def _get_file_url(self, filename: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import aioboto3
|
||||
from reflector.storage.base import Storage, FileResult
|
||||
from reflector.logger import logger
|
||||
from reflector.storage.base import FileResult, Storage
|
||||
|
||||
|
||||
class AwsStorage(Storage):
|
||||
@@ -44,16 +44,18 @@ class AwsStorage(Storage):
|
||||
Body=data,
|
||||
)
|
||||
|
||||
async def _get_file_url(self, filename: str) -> FileResult:
|
||||
bucket = self.aws_bucket_name
|
||||
folder = self.aws_folder
|
||||
s3filename = f"{folder}/{filename}" if folder else filename
|
||||
async with self.session.client("s3") as client:
|
||||
presigned_url = await client.generate_presigned_url(
|
||||
"get_object",
|
||||
Params={"Bucket": bucket, "Key": s3filename},
|
||||
ExpiresIn=3600,
|
||||
)
|
||||
|
||||
return FileResult(
|
||||
filename=filename,
|
||||
url=presigned_url,
|
||||
)
|
||||
return presigned_url
|
||||
|
||||
async def _delete_file(self, filename: str):
|
||||
bucket = self.aws_bucket_name
|
||||
|
||||
14
server/reflector/tools/start_post_main_live_pipeline.py
Normal file
14
server/reflector/tools/start_post_main_live_pipeline.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import argparse
|
||||
|
||||
from reflector.app import celery_app # noqa
|
||||
from reflector.pipelines.main_live_pipeline import task_pipeline_main_post
|
||||
|
||||
# Command line: <transcript_id> [--delay]
parser = argparse.ArgumentParser()
parser.add_argument("transcript_id", type=str)
parser.add_argument("--delay", action="store_true")
args = parser.parse_args()

# With --delay the task is started through its `.delay()` method
# (presumably queued on the Celery app imported above; confirm), otherwise it
# is called directly in this process.
(task_pipeline_main_post.delay if args.delay else task_pipeline_main_post)(
    args.transcript_id
)
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import BinaryIO
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi import HTTPException, Request, Response, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
@@ -57,6 +57,9 @@ def range_requests_response(
|
||||
),
|
||||
}
|
||||
|
||||
if request.method == "HEAD":
|
||||
return Response(headers=headers)
|
||||
|
||||
if content_disposition:
|
||||
headers["Content-Disposition"] = content_disposition
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
from enum import StrEnum
|
||||
from json import dumps, loads
|
||||
from pathlib import Path
|
||||
from json import loads
|
||||
|
||||
import av
|
||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||
@@ -10,25 +8,7 @@ from prometheus_client import Gauge
|
||||
from pydantic import BaseModel
|
||||
from reflector.events import subscribers_shutdown
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
FinalLongSummary,
|
||||
FinalShortSummary,
|
||||
Pipeline,
|
||||
TitleSummary,
|
||||
Transcript,
|
||||
TranscriptFinalLongSummaryProcessor,
|
||||
TranscriptFinalShortSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorProcessor,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor
|
||||
from reflector.processors.types import FinalTitle
|
||||
from reflector.pipelines.runner import PipelineRunner
|
||||
|
||||
sessions = []
|
||||
router = APIRouter()
|
||||
@@ -38,7 +18,7 @@ m_rtc_sessions = Gauge("rtc_sessions", "Number of active RTC sessions")
|
||||
class TranscriptionContext(object):
|
||||
def __init__(self, logger):
|
||||
self.logger = logger
|
||||
self.pipeline = None
|
||||
self.pipeline_runner = None
|
||||
self.data_channel = None
|
||||
self.status = "idle"
|
||||
self.topics = []
|
||||
@@ -60,7 +40,7 @@ class AudioStreamTrack(MediaStreamTrack):
|
||||
ctx = self.ctx
|
||||
frame = await self.track.recv()
|
||||
try:
|
||||
await ctx.pipeline.push(frame)
|
||||
ctx.pipeline_runner.push(frame)
|
||||
except Exception as e:
|
||||
ctx.logger.error("Pipeline error", error=e)
|
||||
return frame
|
||||
@@ -71,27 +51,10 @@ class RtcOffer(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class StrValue(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class PipelineEvent(StrEnum):
|
||||
TRANSCRIPT = "TRANSCRIPT"
|
||||
TOPIC = "TOPIC"
|
||||
FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
|
||||
STATUS = "STATUS"
|
||||
FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
|
||||
FINAL_TITLE = "FINAL_TITLE"
|
||||
|
||||
|
||||
async def rtc_offer_base(
|
||||
params: RtcOffer,
|
||||
request: Request,
|
||||
event_callback=None,
|
||||
event_callback_args=None,
|
||||
audio_filename: Path | None = None,
|
||||
source_language: str = "en",
|
||||
target_language: str = "en",
|
||||
pipeline_runner: PipelineRunner,
|
||||
):
|
||||
# build an rtc session
|
||||
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
|
||||
@@ -101,146 +64,10 @@ async def rtc_offer_base(
|
||||
clientid = f"{peername[0]}:{peername[1]}"
|
||||
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
||||
|
||||
async def update_status(status: str):
|
||||
changed = ctx.status != status
|
||||
if changed:
|
||||
ctx.status = status
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.STATUS,
|
||||
args=event_callback_args,
|
||||
data=StrValue(value=status),
|
||||
)
|
||||
|
||||
# build pipeline callback
|
||||
async def on_transcript(transcript: Transcript):
|
||||
ctx.logger.info("Transcript", transcript=transcript)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "SHOW_TRANSCRIPTION",
|
||||
"text": transcript.text,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.TRANSCRIPT,
|
||||
args=event_callback_args,
|
||||
data=transcript,
|
||||
)
|
||||
|
||||
async def on_topic(topic: TitleSummary):
|
||||
# FIXME: make it incremental with the frontend, not send everything
|
||||
ctx.logger.info("Topic", topic=topic)
|
||||
ctx.topics.append(
|
||||
{
|
||||
"title": topic.title,
|
||||
"timestamp": topic.timestamp,
|
||||
"transcript": topic.transcript.text,
|
||||
"desc": topic.summary,
|
||||
}
|
||||
)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
|
||||
)
|
||||
|
||||
async def on_final_short_summary(summary: FinalShortSummary):
|
||||
ctx.logger.info("FinalShortSummary", final_short_summary=summary)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
|
||||
"summary": summary.short_summary,
|
||||
"duration": summary.duration,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.FINAL_SHORT_SUMMARY,
|
||||
args=event_callback_args,
|
||||
data=summary,
|
||||
)
|
||||
|
||||
async def on_final_long_summary(summary: FinalLongSummary):
|
||||
ctx.logger.info("FinalLongSummary", final_summary=summary)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "DISPLAY_FINAL_LONG_SUMMARY",
|
||||
"summary": summary.long_summary,
|
||||
"duration": summary.duration,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.FINAL_LONG_SUMMARY,
|
||||
args=event_callback_args,
|
||||
data=summary,
|
||||
)
|
||||
|
||||
async def on_final_title(title: FinalTitle):
|
||||
ctx.logger.info("FinalTitle", final_title=title)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.FINAL_TITLE,
|
||||
args=event_callback_args,
|
||||
data=title,
|
||||
)
|
||||
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
processors = []
|
||||
if audio_filename is not None:
|
||||
processors += [AudioFileWriterProcessor(path=audio_filename)]
|
||||
processors += [
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorProcessor.as_threaded(callback=on_transcript),
|
||||
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
|
||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||
callback=on_final_long_summary
|
||||
),
|
||||
TranscriptFinalShortSummaryProcessor.as_threaded(
|
||||
callback=on_final_short_summary
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
ctx.pipeline = Pipeline(*processors)
|
||||
ctx.pipeline.set_pref("audio:source_language", source_language)
|
||||
ctx.pipeline.set_pref("audio:target_language", target_language)
|
||||
|
||||
# handle RTC peer connection
|
||||
pc = RTCPeerConnection()
|
||||
ctx.pipeline_runner = pipeline_runner
|
||||
ctx.pipeline_runner.start()
|
||||
|
||||
async def flush_pipeline_and_quit(close=True):
|
||||
# may be called twice
|
||||
@@ -249,12 +76,10 @@ async def rtc_offer_base(
|
||||
# - when we receive the close event, we do nothing.
|
||||
# 2. or the client close the connection
|
||||
# and there is nothing to do because it is already closed
|
||||
await update_status("processing")
|
||||
await ctx.pipeline.flush()
|
||||
ctx.pipeline_runner.flush()
|
||||
if close:
|
||||
ctx.logger.debug("Closing peer connection")
|
||||
await pc.close()
|
||||
await update_status("ended")
|
||||
if pc in sessions:
|
||||
sessions.remove(pc)
|
||||
m_rtc_sessions.dec()
|
||||
@@ -287,7 +112,6 @@ async def rtc_offer_base(
|
||||
def on_track(track):
|
||||
ctx.logger.info(f"Track {track.kind} received")
|
||||
pc.addTrack(AudioStreamTrack(ctx, track))
|
||||
asyncio.get_event_loop().create_task(update_status("recording"))
|
||||
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
@@ -308,8 +132,3 @@ async def rtc_clean_sessions(_):
|
||||
logger.debug(f"Closing session {pc}")
|
||||
await pc.close()
|
||||
sessions.clear()
|
||||
|
||||
|
||||
@router.post("/offer")
|
||||
async def rtc_offer(params: RtcOffer, request: Request):
|
||||
return await rtc_offer_base(params, request)
|
||||
|
||||
@@ -1,213 +1,33 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Optional
|
||||
from uuid import uuid4
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
import reflector.auth as auth
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
)
|
||||
from fastapi_pagination import Page, paginate
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.databases import paginate
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.logger import logger
|
||||
from reflector.db.transcripts import (
|
||||
TranscriptParticipant,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.processors.types import Transcript as ProcessorTranscript
|
||||
from reflector.processors.types import Word
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==============================================================
|
||||
# Models to move to a database, but required for the API to work
|
||||
# ==============================================================
|
||||
ALGORITHM = "HS256"
|
||||
DOWNLOAD_EXPIRE_MINUTES = 60
|
||||
|
||||
|
||||
def generate_uuid4():
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def generate_transcript_name():
|
||||
now = datetime.utcnow()
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
|
||||
class AudioWaveform(BaseModel):
|
||||
data: list[float]
|
||||
|
||||
|
||||
class TranscriptText(BaseModel):
|
||||
text: str
|
||||
translation: str | None
|
||||
|
||||
|
||||
class TranscriptTopic(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
title: str
|
||||
summary: str
|
||||
transcript: str
|
||||
timestamp: float
|
||||
|
||||
|
||||
class TranscriptFinalShortSummary(BaseModel):
|
||||
short_summary: str
|
||||
|
||||
|
||||
class TranscriptFinalLongSummary(BaseModel):
|
||||
long_summary: str
|
||||
|
||||
|
||||
class TranscriptFinalTitle(BaseModel):
|
||||
title: str
|
||||
|
||||
|
||||
class TranscriptEvent(BaseModel):
|
||||
event: str
|
||||
data: dict
|
||||
|
||||
|
||||
class Transcript(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
user_id: str | None = None
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
status: str = "idle"
|
||||
locked: bool = False
|
||||
duration: float = 0
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
title: str | None = None
|
||||
short_summary: str | None = None
|
||||
long_summary: str | None = None
|
||||
topics: list[TranscriptTopic] = []
|
||||
events: list[TranscriptEvent] = []
|
||||
source_language: str = "en"
|
||||
target_language: str = "en"
|
||||
|
||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||
self.events.append(ev)
|
||||
return ev
|
||||
|
||||
def upsert_topic(self, topic: TranscriptTopic):
|
||||
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
|
||||
if existing_topic:
|
||||
existing_topic.update_from(topic)
|
||||
else:
|
||||
self.topics.append(topic)
|
||||
|
||||
def events_dump(self, mode="json"):
|
||||
return [event.model_dump(mode=mode) for event in self.events]
|
||||
|
||||
def topics_dump(self, mode="json"):
|
||||
return [topic.model_dump(mode=mode) for topic in self.topics]
|
||||
|
||||
def convert_audio_to_waveform(self, segments_count=256):
|
||||
fn = self.audio_waveform_filename
|
||||
if fn.exists():
|
||||
return
|
||||
waveform = get_audio_waveform(
|
||||
path=self.audio_mp3_filename, segments_count=segments_count
|
||||
)
|
||||
try:
|
||||
with open(fn, "w") as fd:
|
||||
json.dump(waveform, fd)
|
||||
except Exception:
|
||||
# remove file if anything happen during the write
|
||||
fn.unlink(missing_ok=True)
|
||||
raise
|
||||
return waveform
|
||||
|
||||
def unlink(self):
|
||||
self.data_path.unlink(missing_ok=True)
|
||||
|
||||
@property
|
||||
def data_path(self):
|
||||
return Path(settings.DATA_DIR) / self.id
|
||||
|
||||
@property
|
||||
def audio_mp3_filename(self):
|
||||
return self.data_path / "audio.mp3"
|
||||
|
||||
@property
|
||||
def audio_waveform_filename(self):
|
||||
return self.data_path / "audio.json"
|
||||
|
||||
@property
|
||||
def audio_waveform(self):
|
||||
try:
|
||||
with open(self.audio_waveform_filename) as fd:
|
||||
data = json.load(fd)
|
||||
except json.JSONDecodeError:
|
||||
# unlink file if it's corrupted
|
||||
self.audio_waveform_filename.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
return AudioWaveform(data=data)
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
|
||||
query = transcripts.select().where(transcripts.c.user_id == user_id)
|
||||
results = await database.fetch_all(query)
|
||||
return results
|
||||
|
||||
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await database.fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
|
||||
async def add(
|
||||
self,
|
||||
name: str,
|
||||
source_language: str = "en",
|
||||
target_language: str = "en",
|
||||
user_id: str | None = None,
|
||||
):
|
||||
transcript = Transcript(
|
||||
name=name,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
user_id=user_id,
|
||||
)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
return transcript
|
||||
|
||||
async def update(self, transcript: Transcript, values: dict):
|
||||
query = (
|
||||
transcripts.update()
|
||||
.where(transcripts.c.id == transcript.id)
|
||||
.values(**values)
|
||||
)
|
||||
await database.execute(query)
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
|
||||
async def remove_by_id(
|
||||
self, transcript_id: str, user_id: str | None = None
|
||||
) -> None:
|
||||
transcript = await self.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
return
|
||||
if user_id is not None and transcript.user_id != user_id:
|
||||
return
|
||||
transcript.unlink()
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await database.execute(query)
|
||||
|
||||
|
||||
transcripts_controller = TranscriptController()
|
||||
def create_access_token(data: dict, expires_delta: timedelta):
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
# ==============================================================
|
||||
@@ -217,16 +37,20 @@ transcripts_controller = TranscriptController()
|
||||
|
||||
class GetTranscript(BaseModel):
|
||||
id: str
|
||||
user_id: str | None
|
||||
name: str
|
||||
status: str
|
||||
locked: bool
|
||||
duration: int
|
||||
duration: float
|
||||
title: str | None
|
||||
short_summary: str | None
|
||||
long_summary: str | None
|
||||
created_at: datetime
|
||||
source_language: str
|
||||
target_language: str
|
||||
share_mode: str = Field("private")
|
||||
source_language: str | None
|
||||
target_language: str | None
|
||||
participants: list[TranscriptParticipant] | None
|
||||
reviewed: bool
|
||||
|
||||
|
||||
class CreateTranscript(BaseModel):
|
||||
@@ -241,6 +65,9 @@ class UpdateTranscript(BaseModel):
|
||||
title: Optional[str] = Field(None)
|
||||
short_summary: Optional[str] = Field(None)
|
||||
long_summary: Optional[str] = Field(None)
|
||||
share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None)
|
||||
participants: Optional[list[TranscriptParticipant]] = Field(None)
|
||||
reviewed: Optional[bool] = Field(None)
|
||||
|
||||
|
||||
class DeletionStatus(BaseModel):
|
||||
@@ -251,11 +78,20 @@ class DeletionStatus(BaseModel):
|
||||
async def transcripts_list(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
from reflector.db import database
|
||||
|
||||
if not user and not settings.PUBLIC_MODE:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
user_id = user["sub"] if user else None
|
||||
return paginate(await transcripts_controller.get_all(user_id=user_id))
|
||||
return await paginate(
|
||||
database,
|
||||
await transcripts_controller.get_all(
|
||||
user_id=user_id,
|
||||
order_by="-created_at",
|
||||
return_query=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/transcripts", response_model=GetTranscript)
|
||||
@@ -277,16 +113,117 @@ async def transcripts_create(
|
||||
# ==============================================================
|
||||
|
||||
|
||||
class GetTranscriptSegmentTopic(BaseModel):
|
||||
text: str
|
||||
start: float
|
||||
speaker: int
|
||||
|
||||
|
||||
class GetTranscriptTopic(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
summary: str
|
||||
timestamp: float
|
||||
duration: float | None
|
||||
transcript: str
|
||||
segments: list[GetTranscriptSegmentTopic] = []
|
||||
|
||||
@classmethod
|
||||
def from_transcript_topic(cls, topic: TranscriptTopic):
|
||||
if not topic.words:
|
||||
# In previous version, words were missing
|
||||
# Just output a segment with speaker 0
|
||||
text = topic.transcript
|
||||
duration = None
|
||||
segments = [
|
||||
GetTranscriptSegmentTopic(
|
||||
text=topic.transcript,
|
||||
start=topic.timestamp,
|
||||
speaker=0,
|
||||
)
|
||||
]
|
||||
else:
|
||||
# New versions include words
|
||||
transcript = ProcessorTranscript(words=topic.words)
|
||||
text = transcript.text
|
||||
duration = transcript.duration
|
||||
segments = [
|
||||
GetTranscriptSegmentTopic(
|
||||
text=segment.text,
|
||||
start=segment.start,
|
||||
speaker=segment.speaker,
|
||||
)
|
||||
for segment in transcript.as_segments()
|
||||
]
|
||||
return cls(
|
||||
id=topic.id,
|
||||
title=topic.title,
|
||||
summary=topic.summary,
|
||||
timestamp=topic.timestamp,
|
||||
transcript=text,
|
||||
segments=segments,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
|
||||
class GetTranscriptTopicWithWords(GetTranscriptTopic):
|
||||
words: list[Word] = []
|
||||
|
||||
@classmethod
|
||||
def from_transcript_topic(cls, topic: TranscriptTopic):
|
||||
instance = super().from_transcript_topic(topic)
|
||||
if topic.words:
|
||||
instance.words = topic.words
|
||||
return instance
|
||||
|
||||
|
||||
class SpeakerWords(BaseModel):
|
||||
speaker: int
|
||||
words: list[Word]
|
||||
|
||||
|
||||
class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
|
||||
words_per_speaker: list[SpeakerWords] = []
|
||||
|
||||
@classmethod
|
||||
def from_transcript_topic(cls, topic: TranscriptTopic):
|
||||
instance = super().from_transcript_topic(topic)
|
||||
if topic.words:
|
||||
words_per_speakers = []
|
||||
# group words by speaker
|
||||
words = []
|
||||
for word in topic.words:
|
||||
if words and words[-1].speaker != word.speaker:
|
||||
words_per_speakers.append(
|
||||
SpeakerWords(
|
||||
speaker=words[-1].speaker,
|
||||
words=words,
|
||||
)
|
||||
)
|
||||
words = []
|
||||
words.append(word)
|
||||
if words:
|
||||
words_per_speakers.append(
|
||||
SpeakerWords(
|
||||
speaker=words[-1].speaker,
|
||||
words=words,
|
||||
)
|
||||
)
|
||||
|
||||
instance.words_per_speaker = words_per_speakers
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
async def transcript_get(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
return transcript
|
||||
return await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
@@ -299,32 +236,7 @@ async def transcript_update(
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
values = {"events": []}
|
||||
if info.name is not None:
|
||||
values["name"] = info.name
|
||||
if info.locked is not None:
|
||||
values["locked"] = info.locked
|
||||
if info.long_summary is not None:
|
||||
values["long_summary"] = info.long_summary
|
||||
for transcript_event in transcript.events:
|
||||
if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
|
||||
transcript_event["long_summary"] = info.long_summary
|
||||
break
|
||||
values["events"].extend(transcript.events)
|
||||
if info.short_summary is not None:
|
||||
values["short_summary"] = info.short_summary
|
||||
for transcript_event in transcript.events:
|
||||
if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
|
||||
transcript_event["short_summary"] = info.short_summary
|
||||
break
|
||||
values["events"].extend(transcript.events)
|
||||
if info.title is not None:
|
||||
values["title"] = info.title
|
||||
for transcript_event in transcript.events:
|
||||
if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
|
||||
transcript_event["title"] = info.title
|
||||
break
|
||||
values["events"].extend(transcript.events)
|
||||
values = info.dict(exclude_unset=True)
|
||||
await transcripts_controller.update(transcript, values)
|
||||
return transcript
|
||||
|
||||
@@ -342,255 +254,63 @@ async def transcript_delete(
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
||||
async def transcript_get_audio_mp3(
|
||||
request: Request,
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
if not transcript.audio_mp3_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
|
||||
truncated_id = str(transcript.id).split("-")[0]
|
||||
filename = f"recording_{truncated_id}.mp3"
|
||||
|
||||
return range_requests_response(
|
||||
request,
|
||||
transcript.audio_mp3_filename,
|
||||
content_type="audio/mpeg",
|
||||
content_disposition=f"attachment; filename={filename}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio/waveform")
|
||||
async def transcript_get_audio_waveform(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
) -> AudioWaveform:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
if not transcript.audio_mp3_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
|
||||
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
||||
|
||||
return transcript.audio_waveform
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
||||
@router.get(
|
||||
"/transcripts/{transcript_id}/topics",
|
||||
response_model=list[GetTranscriptTopic],
|
||||
)
|
||||
async def transcript_get_topics(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
return transcript.topics
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# convert to GetTranscriptTopic
|
||||
return [
|
||||
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
|
||||
]
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/events")
|
||||
async def transcript_get_websocket_events(transcript_id: str):
|
||||
pass
|
||||
|
||||
|
||||
# ==============================================================
|
||||
# Websocket Manager
|
||||
# ==============================================================
|
||||
|
||||
|
||||
class WebsocketManager:
|
||||
def __init__(self):
|
||||
self.active_connections = {}
|
||||
|
||||
async def connect(self, transcript_id: str, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
if transcript_id not in self.active_connections:
|
||||
self.active_connections[transcript_id] = []
|
||||
self.active_connections[transcript_id].append(websocket)
|
||||
|
||||
def disconnect(self, transcript_id: str, websocket: WebSocket):
|
||||
if transcript_id not in self.active_connections:
|
||||
return
|
||||
self.active_connections[transcript_id].remove(websocket)
|
||||
if not self.active_connections[transcript_id]:
|
||||
del self.active_connections[transcript_id]
|
||||
|
||||
async def send_json(self, transcript_id: str, message):
|
||||
if transcript_id not in self.active_connections:
|
||||
return
|
||||
for connection in self.active_connections[transcript_id][:]:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
self.active_connections[transcript_id].remove(connection)
|
||||
|
||||
|
||||
ws_manager = WebsocketManager()
|
||||
|
||||
|
||||
@router.websocket("/transcripts/{transcript_id}/events")
|
||||
async def transcript_events_websocket(
|
||||
@router.get(
|
||||
"/transcripts/{transcript_id}/topics/with-words",
|
||||
response_model=list[GetTranscriptTopicWithWords],
|
||||
)
|
||||
async def transcript_get_topics_with_words(
|
||||
transcript_id: str,
|
||||
websocket: WebSocket,
|
||||
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
# user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
await ws_manager.connect(transcript_id, websocket)
|
||||
|
||||
# on first connection, send all events
|
||||
for event in transcript.events:
|
||||
await websocket.send_json(event.model_dump(mode="json"))
|
||||
|
||||
# XXX if transcript is final (locked=True and status=ended)
|
||||
# XXX send a final event to the client and close the connection
|
||||
|
||||
# endless loop to wait for new events
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive()
|
||||
except (RuntimeError, WebSocketDisconnect):
|
||||
ws_manager.disconnect(transcript_id, websocket)
|
||||
|
||||
|
||||
# ==============================================================
|
||||
# Web RTC
|
||||
# ==============================================================
|
||||
|
||||
|
||||
async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
# OFC the current implementation is not good,
|
||||
# but it's just a POC before persistence. It won't query the
|
||||
# transcript from the database for each event.
|
||||
# print(f"Event: {event}", args, data)
|
||||
transcript_id = args
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
# event send to websocket clients may not be the same as the event
|
||||
# received from the pipeline. For example, the pipeline will send
|
||||
# a TRANSCRIPT event with all words, but this is not what we want
|
||||
# to send to the websocket client.
|
||||
|
||||
# FIXME don't do copy
|
||||
if event == PipelineEvent.TRANSCRIPT:
|
||||
resp = transcript.add_event(
|
||||
event=event,
|
||||
data=TranscriptText(text=data.text, translation=data.translation),
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.TOPIC:
|
||||
topic = TranscriptTopic(
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
transcript=data.transcript.text,
|
||||
timestamp=data.timestamp,
|
||||
)
|
||||
resp = transcript.add_event(event=event, data=topic)
|
||||
transcript.upsert_topic(topic)
|
||||
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"topics": transcript.topics_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.FINAL_TITLE:
|
||||
final_title = TranscriptFinalTitle(title=data.title)
|
||||
resp = transcript.add_event(event=event, data=final_title)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"title": final_title.title,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.FINAL_LONG_SUMMARY:
|
||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||
resp = transcript.add_event(event=event, data=final_long_summary)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"long_summary": final_long_summary.long_summary,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
|
||||
final_short_summary = TranscriptFinalShortSummary(
|
||||
short_summary=data.short_summary
|
||||
)
|
||||
resp = transcript.add_event(event=event, data=final_short_summary)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"short_summary": final_short_summary.short_summary,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.STATUS:
|
||||
resp = transcript.add_event(event=event, data=data)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"status": data.value,
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown event: {event}")
|
||||
return
|
||||
|
||||
# transmit to websocket clients
|
||||
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||
async def transcript_record_webrtc(
|
||||
transcript_id: str,
|
||||
params: RtcOffer,
|
||||
request: Request,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
if transcript.locked:
|
||||
raise HTTPException(status_code=400, detail="Transcript is locked")
|
||||
|
||||
# FIXME do not allow multiple recording at the same time
|
||||
return await rtc_offer_base(
|
||||
params,
|
||||
request,
|
||||
event_callback=handle_rtc_event,
|
||||
event_callback_args=transcript_id,
|
||||
audio_filename=transcript.audio_mp3_filename,
|
||||
source_language=transcript.source_language,
|
||||
target_language=transcript.target_language,
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# convert to GetTranscriptTopicWithWords
|
||||
return [
|
||||
GetTranscriptTopicWithWords.from_transcript_topic(topic)
|
||||
for topic in transcript.topics
|
||||
]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker",
|
||||
response_model=GetTranscriptTopicWithWordsPerSpeaker,
|
||||
)
|
||||
async def transcript_get_topics_with_words_per_speaker(
|
||||
transcript_id: str,
|
||||
topic_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# get the topic from the transcript
|
||||
topic = next((t for t in transcript.topics if t.id == topic_id), None)
|
||||
if not topic:
|
||||
raise HTTPException(status_code=404, detail="Topic not found")
|
||||
|
||||
# convert to GetTranscriptTopicWithWordsPerSpeaker
|
||||
return GetTranscriptTopicWithWordsPerSpeaker.from_transcript_topic(topic)
|
||||
|
||||
115
server/reflector/views/transcripts_audio.py
Normal file
115
server/reflector/views/transcripts_audio.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Transcripts audio related endpoints
|
||||
===================================
|
||||
|
||||
"""
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import httpx
|
||||
import reflector.auth as auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from jose import jwt
|
||||
from reflector.db.transcripts import AudioWaveform, transcripts_controller
|
||||
from reflector.settings import settings
|
||||
from reflector.views.transcripts import ALGORITHM
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/transcripts/{transcript_id}/audio/mp3",
|
||||
operation_id="transcript_get_audio_mp3",
|
||||
)
|
||||
@router.head(
|
||||
"/transcripts/{transcript_id}/audio/mp3",
|
||||
operation_id="transcript_head_audio_mp3",
|
||||
)
|
||||
async def transcript_get_audio_mp3(
|
||||
request: Request,
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
token: str | None = None,
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
if not user_id and token:
|
||||
unauthorized_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
except jwt.JWTError:
|
||||
raise unauthorized_exception
|
||||
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.audio_location == "storage":
|
||||
# proxy S3 file, to prevent issue with CORS
|
||||
url = await transcript.get_audio_url()
|
||||
headers = {}
|
||||
|
||||
copy_headers = ["range", "accept-encoding"]
|
||||
for header in copy_headers:
|
||||
if header in request.headers:
|
||||
headers[header] = request.headers[header]
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.request(request.method, url, headers=headers)
|
||||
return Response(
|
||||
content=resp.content,
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
)
|
||||
|
||||
if transcript.audio_location == "storage":
|
||||
# proxy S3 file, to prevent issue with CORS
|
||||
url = await transcript.get_audio_url()
|
||||
headers = {}
|
||||
|
||||
copy_headers = ["range", "accept-encoding"]
|
||||
for header in copy_headers:
|
||||
if header in request.headers:
|
||||
headers[header] = request.headers[header]
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.request(request.method, url, headers=headers)
|
||||
return Response(
|
||||
content=resp.content,
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
)
|
||||
|
||||
if not transcript.audio_mp3_filename.exists():
|
||||
raise HTTPException(status_code=500, detail="Audio not found")
|
||||
|
||||
truncated_id = str(transcript.id).split("-")[0]
|
||||
filename = f"recording_{truncated_id}.mp3"
|
||||
|
||||
return range_requests_response(
|
||||
request,
|
||||
transcript.audio_mp3_filename,
|
||||
content_type="audio/mpeg",
|
||||
content_disposition=f"attachment; filename={filename}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio/waveform")
|
||||
async def transcript_get_audio_waveform(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
) -> AudioWaveform:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not transcript.audio_waveform_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
|
||||
return transcript.audio_waveform
|
||||
143
server/reflector/views/transcripts_participants.py
Normal file
143
server/reflector/views/transcripts_participants.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Transcript participants API endpoints
|
||||
=====================================
|
||||
|
||||
"""
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import reflector.auth as auth
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
|
||||
from reflector.views.types import DeletionStatus
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class Participant(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
id: str
|
||||
speaker: int | None
|
||||
name: str
|
||||
|
||||
|
||||
class CreateParticipant(BaseModel):
|
||||
speaker: Optional[int] = Field(None)
|
||||
name: str
|
||||
|
||||
|
||||
class UpdateParticipant(BaseModel):
|
||||
speaker: Optional[int] = Field(None)
|
||||
name: Optional[str] = Field(None)
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/participants")
|
||||
async def transcript_get_participants(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
) -> list[Participant]:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
return [
|
||||
Participant.model_validate(participant)
|
||||
for participant in transcript.participants
|
||||
]
|
||||
|
||||
|
||||
@router.post("/transcripts/{transcript_id}/participants")
|
||||
async def transcript_add_participant(
|
||||
transcript_id: str,
|
||||
participant: CreateParticipant,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
) -> Participant:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# ensure the speaker is unique
|
||||
if participant.speaker is not None:
|
||||
for p in transcript.participants:
|
||||
if p.speaker == participant.speaker:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Speaker already assigned",
|
||||
)
|
||||
|
||||
obj = await transcripts_controller.upsert_participant(
|
||||
transcript, TranscriptParticipant(**participant.dict())
|
||||
)
|
||||
return Participant.model_validate(obj)
|
||||
|
||||
|
||||
# Return a single participant of the transcript, or 404 when no participant
# carries the requested id.
@router.get("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_get_participant(
    transcript_id: str,
    participant_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
    user_id = None if not user else user["sub"]
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    # first participant with a matching id wins
    found = next(
        (entry for entry in transcript.participants if entry.id == participant_id),
        None,
    )
    if found is None:
        raise HTTPException(status_code=404, detail="Participant not found")
    return Participant.model_validate(found)
|
||||
|
||||
|
||||
# Partially update a participant of the transcript.
#
# Only the fields present in the request body are applied. Answers 400 when
# the requested speaker index already belongs to another participant, and 404
# when the participant id is unknown.
@router.patch("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_update_participant(
    transcript_id: str,
    participant_id: str,
    participant: UpdateParticipant,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
    user_id = user["sub"] if user else None
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    # ensure the speaker is unique. The check only makes sense when a speaker
    # is actually requested: comparing None against the other participants
    # would reject e.g. a simple rename as soon as another participant has no
    # speaker either.
    if participant.speaker is not None:
        for p in transcript.participants:
            if p.speaker == participant.speaker and p.id != participant_id:
                raise HTTPException(
                    status_code=400,
                    detail="Speaker already assigned",
                )

    # find the participant
    obj = next(
        (p for p in transcript.participants if p.id == participant_id),
        None,
    )
    if not obj:
        raise HTTPException(status_code=404, detail="Participant not found")

    # update participant but just the fields that are set
    fields = participant.dict(exclude_unset=True)
    obj = obj.copy(update=fields)

    await transcripts_controller.upsert_participant(transcript, obj)
    return Participant.model_validate(obj)
|
||||
|
||||
|
||||
# Remove a participant from the transcript and report the deletion status.
@router.delete("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_delete_participant(
    transcript_id: str,
    participant_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> DeletionStatus:
    user_id = None if not user else user["sub"]
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    # the controller performs the actual removal
    await transcripts_controller.delete_participant(transcript, participant_id)

    result = DeletionStatus(status="ok")
    return result
|
||||
170
server/reflector/views/transcripts_speaker.py
Normal file
170
server/reflector/views/transcripts_speaker.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Reassign speakers in a transcript
|
||||
=================================
|
||||
|
||||
"""
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import reflector.auth as auth
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Request body for assigning a time range of the transcript to a speaker.
# The target is given either as a speaker index (>= 0) or as a participant id.
class SpeakerAssignment(BaseModel):
    speaker: Optional[int] = Field(None, ge=0)
    participant: Optional[str] = None
    # bounds of the range, compared against each word's `start`
    timestamp_from: float
    timestamp_to: float
|
||||
|
||||
|
||||
# Response body of the speaker endpoints; the handlers in this file answer
# with status="ok".
class SpeakerAssignmentStatus(BaseModel):
    status: str
|
||||
|
||||
|
||||
# Request body for merging two speakers: words of `speaker_from` are
# reassigned to `speaker_to`.
class SpeakerMerge(BaseModel):
    speaker_from: int
    speaker_to: int
|
||||
|
||||
|
||||
# Assign every word starting inside [timestamp_from, timestamp_to] to a
# speaker. The speaker is given directly, or through a participant; a
# participant without speaker gets a free speaker index first.
@router.patch("/transcripts/{transcript_id}/speaker/assign")
async def transcript_assign_speaker(
    transcript_id: str,
    assignment: SpeakerAssignment,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> SpeakerAssignmentStatus:
    user_id = None if not user else user["sub"]
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    # exactly one of speaker / participant must be given
    has_speaker = assignment.speaker is not None
    has_participant = assignment.participant is not None

    if not has_speaker and not has_participant:
        raise HTTPException(
            status_code=400,
            detail="Either speaker or participant must be provided",
        )

    if has_speaker and has_participant:
        raise HTTPException(
            status_code=400,
            detail="Only one of speaker or participant must be provided",
        )

    if has_speaker:
        speaker = assignment.speaker
    else:
        # resolve the participant by id
        participant = None
        for candidate in transcript.participants:
            if candidate.id == assignment.participant:
                participant = candidate
                break

        if not participant:
            raise HTTPException(
                status_code=404,
                detail="Participant not found",
            )

        # a participant without speaker gets an unused speaker index
        if participant.speaker is None:
            participant.speaker = transcript.find_empty_speaker()
            await transcripts_controller.upsert_participant(transcript, participant)

        speaker = participant.speaker

    # rewrite the speaker of every word whose start lies in the range
    ts_from, ts_to = assignment.timestamp_from, assignment.timestamp_to
    changed_topics = []
    for topic in transcript.topics:
        in_range = [word for word in topic.words if ts_from <= word.start <= ts_to]
        for word in in_range:
            word.speaker = speaker
        if in_range:
            changed_topics.append(topic)

    # persist all modified topics in a single update
    for topic in changed_topics:
        transcript.upsert_topic(topic)
    await transcripts_controller.update(
        transcript,
        {"topics": transcript.topics_dump()},
    )

    return SpeakerAssignmentStatus(status="ok")
|
||||
|
||||
|
||||
# Merge one speaker into another: every word of `speaker_from` is reassigned
# to `speaker_to`. Refused with 400 when both speakers are attached to a
# participant.
@router.patch("/transcripts/{transcript_id}/speaker/merge")
async def transcript_merge_speaker(
    transcript_id: str,
    merge: SpeakerMerge,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> SpeakerAssignmentStatus:
    user_id = None if not user else user["sub"]
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    def participant_for(speaker_id):
        # first participant attached to this speaker, or None
        for candidate in transcript.participants:
            if candidate.speaker == speaker_id:
                return candidate
        return None

    # make sure the two speakers are not assigned to two different participants
    participant_from = participant_for(merge.speaker_from)
    participant_to = participant_for(merge.speaker_to)
    if participant_from and participant_to:
        raise HTTPException(
            status_code=400,
            detail="Both speakers are assigned to participants",
        )

    # move the words from one speaker to the other
    speaker_from, speaker_to = merge.speaker_from, merge.speaker_to
    changed_topics = []
    for topic in transcript.topics:
        matching = [word for word in topic.words if word.speaker == speaker_from]
        for word in matching:
            word.speaker = speaker_to
        if matching:
            changed_topics.append(topic)

    # persist all modified topics in a single update
    for topic in changed_topics:
        transcript.upsert_topic(topic)
    await transcripts_controller.update(
        transcript,
        {"topics": transcript.topics_dump()},
    )

    return SpeakerAssignmentStatus(status="ok")
|
||||
79
server/reflector/views/transcripts_upload.py
Normal file
79
server/reflector/views/transcripts_upload.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import av
|
||||
import reflector.auth as auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.pipelines.main_live_pipeline import task_pipeline_upload
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Response body of the upload endpoint; the handler answers with status="ok".
class UploadStatus(BaseModel):
    status: str
|
||||
|
||||
|
||||
# Receive a media file for a transcript.
#
# The file is stored as `<data_path>/upload.<extension>`, checked for an audio
# stream, then the transcript status is set to "uploaded" and the processing
# task is queued. Answers 400 when the transcript is locked, when an upload
# file already exists, when the file name is unusable or when the file has no
# audio stream. Whenever storing or checking fails, the partial upload is
# removed so that a later upload is not blocked by a leftover file.
@router.post("/transcripts/{transcript_id}/record/upload")
async def transcript_record_upload(
    transcript_id: str,
    file: UploadFile,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
    user_id = user["sub"] if user else None
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    if transcript.locked:
        raise HTTPException(status_code=400, detail="Transcript is locked")

    # ensure there is no other upload in the directory (searching data_path/upload.*)
    if any(transcript.data_path.glob("upload.*")):
        raise HTTPException(
            status_code=400, detail="There is already an upload in progress"
        )

    # The file name comes from the client: it may be missing, and whatever
    # follows the last dot could contain path separators. Only accept a plain
    # alphanumeric extension.
    extension = (file.filename or "").split(".")[-1]
    if not extension.isalnum():
        raise HTTPException(status_code=400, detail="Invalid file name")
    upload_filename = transcript.data_path / f"upload.{extension}"
    upload_filename.parent.mkdir(parents=True, exist_ok=True)

    # ensure the file is back to the beginning
    await file.seek(0)

    # save the file to the transcript folder
    try:
        with open(upload_filename, "wb") as f:
            while True:
                chunk = await file.read(16384)
                if not chunk:
                    break
                f.write(chunk)
    except Exception:
        # open() itself may have failed, in which case there is nothing to
        # delete; do not let the cleanup mask the original error
        upload_filename.unlink(missing_ok=True)
        raise

    # ensure the file has an audio part, using av
    # XXX Doing this check on the initial UploadFile object did not work:
    # UploadFile.file has no name, and passing UploadFile.file with
    # format=extension never detected an audio stream.
    container = None
    try:
        # opening is part of the guarded section: a file that cannot be
        # decoded must not be left behind
        container = av.open(upload_filename.as_posix())
        if not len(container.streams.audio):
            raise HTTPException(status_code=400, detail="File has no audio stream")
    except Exception:
        # delete the uploaded file
        upload_filename.unlink(missing_ok=True)
        raise
    finally:
        if container is not None:
            container.close()

    # set the status to "uploaded"
    await transcripts_controller.update(transcript, {"status": "uploaded"})

    # launch a background task to process the file
    task_pipeline_upload.delay(transcript_id=transcript_id)

    return UploadStatus(status="ok")
|
||||
37
server/reflector/views/transcripts_webrtc.py
Normal file
37
server/reflector/views/transcripts_webrtc.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import reflector.auth as auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Answer a WebRTC offer for a transcript by attaching a live pipeline runner
# to the connection. Answers 400 when the transcript is locked.
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
    transcript_id: str,
    params: RtcOffer,
    request: Request,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
    user_id = None if not user else user["sub"]
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    if transcript.locked:
        raise HTTPException(status_code=400, detail="Transcript is locked")

    # imported here rather than at module level, as in the original code
    from reflector.pipelines.main_live_pipeline import PipelineMainLive

    # FIXME do not allow multiple recording at the same time
    runner = PipelineMainLive(transcript_id=transcript_id)
    return await rtc_offer_base(params, request, pipeline_runner=runner)
|
||||
53
server/reflector/views/transcripts_websocket.py
Normal file
53
server/reflector/views/transcripts_websocket.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
Transcripts websocket API
|
||||
=========================
|
||||
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.ws_manager import get_ws_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Placeholder GET route sharing its path with the websocket endpoint below.
# NOTE(review): presumably only there to expose the path in the HTTP API
# description - confirm.
@router.get("/transcripts/{transcript_id}/events")
async def transcript_get_websocket_events(transcript_id: str):
    return None
|
||||
|
||||
|
||||
# Websocket endpoint streaming the events of a transcript: the stored events
# are replayed to the new client, then the connection is kept open in the
# room "ts:<transcript_id>".
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
    transcript_id: str,
    websocket: WebSocket,
    # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
    # user_id = user["sub"] if user else None
    transcript = await transcripts_controller.get_by_id(transcript_id)
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    # join the room of this transcript in the websocket manager
    room_id = f"ts:{transcript_id}"
    ws_manager = get_ws_manager()
    await ws_manager.add_user_to_room(room_id, websocket)

    # live events that are not replayed to a newly connected client
    not_replayed = ("TRANSCRIPT", "STATUS")

    try:
        # on first connection, replay the stored events to this client only
        for event in transcript.events:
            if event.event not in not_replayed:
                await websocket.send_json(event.model_dump(mode="json"))

        # XXX if transcript is final (locked=True and status=ended)
        # XXX send a final event to the client and close the connection

        # there is no command system yet: just keep reading until the
        # connection goes away
        while True:
            await websocket.receive()
    except (RuntimeError, WebSocketDisconnect):
        await ws_manager.remove_user_from_room(room_id, websocket)
|
||||
5
server/reflector/views/types.py
Normal file
5
server/reflector/views/types.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Response body returned by deletion endpoints.
class DeletionStatus(BaseModel):
    status: str
|
||||
32
server/reflector/worker/app.py
Normal file
32
server/reflector/worker/app.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import celery
|
||||
import structlog
|
||||
from celery import Celery
|
||||
from reflector.settings import settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
# Module-level Celery setup: reuse the current Celery application when its
# `main` is not "default" (logged as "already configured"), otherwise create
# one and configure it from the project settings.
# NOTE(review): indentation is lost in this view, so it cannot be told whether
# the beat-schedule section below belongs to the `else` branch - confirm
# against the real file.
if celery.current_app.main != "default":
|
||||
logger.info(f"Celery already configured ({celery.current_app})")
|
||||
app = celery.current_app
|
||||
else:
|
||||
app = Celery(__name__)
|
||||
app.conf.broker_url = settings.CELERY_BROKER_URL
|
||||
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
||||
app.conf.broker_connection_retry_on_startup = True
|
||||
app.autodiscover_tasks(
|
||||
[
|
||||
"reflector.pipelines.main_live_pipeline",
|
||||
"reflector.worker.healthcheck",
|
||||
]
|
||||
)
|
||||
|
||||
# crontab
# periodic tasks: starts empty, the healthcheck ping is only registered when
# a healthcheck URL is configured
|
||||
app.conf.beat_schedule = {}
|
||||
|
||||
if settings.HEALTHCHECK_URL:
|
||||
app.conf.beat_schedule["healthcheck_ping"] = {
|
||||
"task": "reflector.worker.healthcheck.healthcheck_ping",
|
||||
# 60.0 * 10 = 600; presumably seconds, i.e. one ping every ten minutes - confirm
"schedule": 60.0 * 10,
|
||||
}
|
||||
logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
|
||||
else:
|
||||
logger.warning("Healthcheck disabled, no url configured")
|
||||
18
server/reflector/worker/healthcheck.py
Normal file
18
server/reflector/worker/healthcheck.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import httpx
|
||||
import structlog
|
||||
from celery import shared_task
|
||||
from reflector.settings import settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
@shared_task
def healthcheck_ping():
    """
    Ping the configured healthcheck URL.

    Does nothing when `settings.HEALTHCHECK_URL` is empty. Failures are never
    raised: both request errors and HTTP error responses are logged, so a
    broken healthcheck endpoint does not make the task itself fail.
    """
    url = settings.HEALTHCHECK_URL
    if not url:
        return
    try:
        # use the module logger, like the error path below, instead of print
        logger.info("pinging healthcheck url", url=url)
        response = httpx.get(url, timeout=10)
        # a 4xx/5xx answer means the ping was not accepted: make it visible
        if response.status_code >= 400:
            logger.error("healthcheck_ping", status_code=response.status_code)
    except Exception as e:
        logger.error("healthcheck_ping", error=str(e))
|
||||
126
server/reflector/ws_manager.py
Normal file
126
server/reflector/ws_manager.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
Websocket manager
|
||||
=================
|
||||
|
||||
This module contains the WebsocketManager class, which is responsible for
|
||||
managing websockets and handling websocket connections.
|
||||
|
||||
It uses the RedisPubSubManager class to subscribe to Redis channels and
|
||||
broadcast messages to all connected websockets.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
|
||||
import redis.asyncio as redis
|
||||
from fastapi import WebSocket
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class RedisPubSubManager:
    """
    Thin wrapper around a Redis connection and its pub/sub object.

    The connection is created lazily: `send_json` and `subscribe` connect on
    demand, so the manager can be used without an explicit `connect()`.
    """

    def __init__(self, host="localhost", port=6379):
        self.redis_host = host
        self.redis_port = port
        # both are created by connect() and reset by disconnect()
        self.redis_connection = None
        self.pubsub = None

    async def get_redis_connection(self) -> redis.Redis:
        """Create a new Redis client for the configured host and port."""
        return redis.Redis(
            host=self.redis_host,
            port=self.redis_port,
            auto_close_connection_pool=False,
        )

    async def connect(self) -> None:
        """Create the connection and its pub/sub object, once."""
        if self.redis_connection is not None:
            return
        self.redis_connection = await self.get_redis_connection()
        self.pubsub = self.redis_connection.pubsub()

    async def disconnect(self) -> None:
        """Close the connection, if any."""
        if self.redis_connection is None:
            return
        await self.redis_connection.close()
        self.redis_connection = None
        # the pub/sub object belongs to the closed connection; drop it so the
        # next subscribe() reconnects instead of using a stale object
        self.pubsub = None

    async def send_json(self, room_id: str, message: dict) -> None:
        """Publish `message`, serialized as JSON, on the channel `room_id`."""
        if not self.redis_connection:
            await self.connect()
        payload = json.dumps(message)
        await self.redis_connection.publish(room_id, payload)

    async def subscribe(self, room_id: str) -> "redis.client.PubSub":
        """Subscribe to the channel `room_id` and return the pub/sub object."""
        # connect on demand rather than failing on a missing pub/sub object
        if self.pubsub is None:
            await self.connect()
        await self.pubsub.subscribe(room_id)
        return self.pubsub

    async def unsubscribe(self, room_id: str) -> None:
        """Unsubscribe from the channel `room_id`; no-op when not connected."""
        if self.pubsub is None:
            return
        await self.pubsub.unsubscribe(room_id)
|
||||
|
||||
|
||||
class WebsocketManager:
    """
    Keep track of the websockets connected to each room and forward the
    messages published on the matching pub/sub channel to them.

    `rooms` maps a room id to the list of websockets in that room. A single
    background task reads the pub/sub object; it is started when the first
    websocket joins and cancelled when the last room disappears.
    """

    def __init__(self, pubsub_client: RedisPubSubManager = None):
        self.rooms: dict = {}
        self.pubsub_client = pubsub_client
        # Reference to the reader task: the event loop only keeps weak
        # references to tasks, so it must be stored to stay alive.
        self._reader_task = None

    async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
        """Accept the websocket, add it to the room and subscribe to it."""
        await websocket.accept()

        self.rooms.setdefault(room_id, []).append(websocket)

        await self.pubsub_client.connect()
        pubsub_subscriber = await self.pubsub_client.subscribe(room_id)

        # one reader is enough for all the rooms; several readers would
        # compete for the messages of the same pub/sub object
        if self._reader_task is None or self._reader_task.done():
            self._reader_task = asyncio.create_task(
                self._pubsub_data_reader(pubsub_subscriber)
            )

    async def send_json(self, room_id: str, message: dict) -> None:
        """Publish a message to every websocket of the room, through pub/sub."""
        await self.pubsub_client.send_json(room_id, message)

    async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
        """Remove the websocket from the room; drop the room once empty."""
        sockets = self.rooms.get(room_id)
        if sockets is None:
            return
        if websocket in sockets:
            sockets.remove(websocket)

        if len(sockets) == 0:
            del self.rooms[room_id]
            await self.pubsub_client.unsubscribe(room_id)

        # nobody is listening anymore: stop polling the pub/sub object
        if not self.rooms and self._reader_task is not None:
            self._reader_task.cancel()
            self._reader_task = None

    async def _pubsub_data_reader(self, pubsub_subscriber):
        """Forward every pub/sub message to the websockets of its room."""
        while True:
            message = await pubsub_subscriber.get_message(
                ignore_subscribe_messages=True
            )
            if message is None:
                continue

            room_id = message["channel"].decode("utf-8")
            # the room may already be gone when a late message arrives; copy
            # the list since it can change while we await on send_json
            sockets = list(self.rooms.get(room_id, ()))
            if not sockets:
                continue

            # decode once for the whole room
            data = json.loads(message["data"].decode("utf-8"))
            for socket in sockets:
                await socket.send_json(data)
|
||||
|
||||
|
||||
def get_ws_manager() -> WebsocketManager:
    """
    Return a WebsocketManager backed by the Redis server from the settings.

    The manager uses a RedisPubSubManager built from `settings.REDIS_HOST`
    and `settings.REDIS_PORT`.

    NOTE(review): the `threading.local()` object is created inside the
    function, so the `ws_manager` attribute is never found and every call
    builds a new manager. If a shared manager is intended, the thread-local
    has to live outside the function - confirm before changing, since it
    affects how rooms and Redis connections are shared.

    Returns:
        WebsocketManager: the manager instance.
    """
    local = threading.local()
    if hasattr(local, "ws_manager"):
        return local.ws_manager

    local.ws_manager = WebsocketManager(
        pubsub_client=RedisPubSubManager(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
        )
    )
    return local.ws_manager
|
||||
@@ -4,4 +4,13 @@ if [ -f "/venv/bin/activate" ]; then
|
||||
source /venv/bin/activate
|
||||
fi
|
||||
alembic upgrade head
|
||||
python -m reflector.app
|
||||
|
||||
if [ "${ENTRYPOINT}" = "server" ]; then
|
||||
python -m reflector.app
|
||||
elif [ "${ENTRYPOINT}" = "worker" ]; then
|
||||
celery -A reflector.worker.app worker --loglevel=info
|
||||
elif [ "${ENTRYPOINT}" = "beat" ]; then
|
||||
celery -A reflector.worker.app beat --loglevel=info
|
||||
else
|
||||
echo "Unknown command"
|
||||
fi
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from unittest.mock import patch
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -7,7 +8,6 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def setup_database():
|
||||
from reflector.settings import settings
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
with NamedTemporaryFile() as f:
|
||||
settings.DATABASE_URL = f"sqlite:///{f.name}"
|
||||
@@ -36,7 +36,13 @@ def dummy_processors():
|
||||
mock_long_summary.return_value = "LLM LONG SUMMARY"
|
||||
mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
|
||||
mock_translate.return_value = "Bonjour le monde"
|
||||
yield mock_translate, mock_topic, mock_title, mock_long_summary, mock_short_summary # noqa
|
||||
yield (
|
||||
mock_translate,
|
||||
mock_topic,
|
||||
mock_title,
|
||||
mock_long_summary,
|
||||
mock_short_summary,
|
||||
) # noqa
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -45,28 +51,50 @@ async def dummy_transcript():
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
|
||||
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
|
||||
async def _transcript(self, data: AudioFile):
|
||||
source_language = self.get_pref("audio:source_language", "en")
|
||||
print("transcripting", source_language)
|
||||
print("pipeline", self.pipeline)
|
||||
print("prefs", self.pipeline.prefs)
|
||||
_time_idx = 0
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
i = self._time_idx
|
||||
self._time_idx += 2
|
||||
return Transcript(
|
||||
text="Hello world.",
|
||||
words=[
|
||||
Word(start=0.0, end=1.0, text="Hello"),
|
||||
Word(start=1.0, end=2.0, text=" world."),
|
||||
Word(start=i, end=i + 1, text="Hello", speaker=0),
|
||||
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
|
||||
],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"reflector.processors.audio_transcript_auto"
|
||||
".AudioTranscriptAutoProcessor.get_instance"
|
||||
".AudioTranscriptAutoProcessor.__new__"
|
||||
) as mock_audio:
|
||||
mock_audio.return_value = TestAudioTranscriptProcessor()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
async def dummy_diarization():
    """
    Replace the automatic diarization processor with a fake one.

    Each call to the fake `_diarize` returns two consecutive one-unit
    segments, speaker 0 then speaker 1, and moves its time offset by 2.
    """
    from reflector.processors.audio_diarization import AudioDiarizationProcessor

    class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
        _time_idx = 0

        async def _diarize(self, data):
            start = self._time_idx
            self._time_idx = self._time_idx + 2
            first = {"start": start, "end": start + 1, "speaker": 0}
            second = {"start": start + 1, "end": start + 2, "speaker": 1}
            return [first, second]

    target = (
        "reflector.processors.audio_diarization_auto"
        ".AudioDiarizationAutoProcessor.__new__"
    )
    with patch(target) as mock_audio:
        mock_audio.return_value = TestAudioDiarizationProcessor()
        yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_llm():
|
||||
from reflector.llm.base import LLM
|
||||
@@ -81,6 +109,25 @@ async def dummy_llm():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
async def dummy_storage():
    """
    Patch `Storage.get_instance` to return a storage that stores nothing and
    always answers with the same fake file URL.
    """
    from reflector.storage.base import Storage

    class DummyStorage(Storage):
        async def _put_file(self, *args, **kwargs):
            return None

        async def _delete_file(self, *args, **kwargs):
            return None

        async def _get_file_url(self, *args, **kwargs):
            return "http://fake_server/audio.mp3"

    target = "reflector.storage.base.Storage.get_instance"
    with patch(target) as mock_storage:
        mock_storage.return_value = DummyStorage()
        yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nltk():
|
||||
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
|
||||
@@ -98,7 +145,96 @@ def ensure_casing():
|
||||
@pytest.fixture
|
||||
def sentence_tokenize():
|
||||
with patch(
|
||||
"reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize"
|
||||
"reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize"
|
||||
) as mock_sent_tokenize:
|
||||
mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def celery_enable_logging():
    """Session-wide flag fixture; always True."""
    enabled = True
    return enabled
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def celery_config():
    """
    Celery configuration for the test session: in-memory broker and a result
    backend stored in a temporary SQLite file, which lives as long as the
    fixture.
    """
    with NamedTemporaryFile() as backend_file:
        config = {
            "broker_url": "memory://",
            "result_backend": f"db+sqlite:///{backend_file.name}",
        }
        yield config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def celery_includes():
    """Module names returned as the Celery includes for the test session."""
    modules = ["reflector.pipelines.main_live_pipeline"]
    return modules
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def fake_mp3_upload():
    """
    Patch `TranscriptController.move_mp3_to_storage` so that it only returns
    True instead of running the real method.
    """
    target = "reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
    with patch(target) as mock_move:
        mock_move.return_value = True
        yield
|
||||
|
||||
|
||||
@pytest.fixture
async def fake_transcript_with_topics(tmpdir):
    """
    Create a finished transcript with an audio file and two topics.

    The transcript is created through the HTTP API, its status is set to
    "finished", a sample mp3 is copied to its expected audio location and two
    topics of two words each ("Hello", "world", speaker 0) are added, covering
    0-2 and 2-4.

    Yields:
        the transcript object, as loaded right after its creation.
    """
    from reflector.settings import settings
    from reflector.app import app
    from reflector.views.transcripts import transcripts_controller
    from reflector.db.transcripts import TranscriptTopic
    from reflector.processors.types import Word
    from pathlib import Path
    from httpx import AsyncClient
    import shutil

    settings.DATA_DIR = Path(tmpdir)

    # create a transcript; the client is closed as soon as the request is
    # done instead of being left open for the rest of the test
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.post("/transcripts", json={"name": "Test audio download"})
    assert response.status_code == 200
    tid = response.json()["id"]

    transcript = await transcripts_controller.get_by_id(tid)
    assert transcript is not None

    await transcripts_controller.update(transcript, {"status": "finished"})

    # manually copy a file at the expected location
    audio_filename = transcript.audio_mp3_filename
    path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
    audio_filename.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(path, audio_filename)

    # create some topics: (topic number, time offset of its first word)
    for number, offset in ((1, 0), (2, 2)):
        await transcripts_controller.upsert_topic(
            transcript,
            TranscriptTopic(
                title=f"Topic {number}",
                summary=f"Topic {number} summary",
                timestamp=offset,
                transcript="Hello world",
                words=[
                    Word(text="Hello", start=offset, end=offset + 1, speaker=0),
                    Word(text="world", start=offset + 1, end=offset + 2, speaker=0),
                ],
            ),
        )

    yield transcript
|
||||
|
||||
140
server/tests/test_processor_audio_diarization.py
Normal file
140
server/tests/test_processor_audio_diarization.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "name,diarization,expected",
    [
        (
            "no overlap",
            [
                {"start": 0.0, "end": 1.0, "speaker": "A"},
                {"start": 1.0, "end": 2.0, "speaker": "B"},
            ],
            ["A", "A", "B", "B"],
        ),
        (
            "same speaker",
            [
                {"start": 0.0, "end": 1.0, "speaker": "A"},
                {"start": 1.0, "end": 2.0, "speaker": "A"},
            ],
            ["A", "A", "A", "A"],
        ),
        (
            # the first segment is dropped: it overlaps with the second
            # segment and is the smaller of the two
            "overlap at 0.5s",
            [
                {"start": 0.0, "end": 1.0, "speaker": "A"},
                {"start": 0.5, "end": 2.0, "speaker": "B"},
            ],
            ["B", "B", "B", "B"],
        ),
        (
            "junk segment at 0.5s for 0.2s",
            [
                {"start": 0.0, "end": 1.0, "speaker": "A"},
                {"start": 0.5, "end": 0.7, "speaker": "B"},
                {"start": 1, "end": 2.0, "speaker": "B"},
            ],
            ["A", "A", "B", "B"],
        ),
        (
            "start without diarization",
            [
                {"start": 0.5, "end": 1.0, "speaker": "A"},
                {"start": 1.0, "end": 2.0, "speaker": "B"},
            ],
            ["A", "A", "B", "B"],
        ),
        (
            "end missing diarization",
            [
                {"start": 0.0, "end": 1.0, "speaker": "A"},
                {"start": 1.0, "end": 1.5, "speaker": "B"},
            ],
            ["A", "A", "B", "B"],
        ),
        (
            "continuation of next speaker",
            [
                {"start": 0.0, "end": 0.9, "speaker": "A"},
                {"start": 1.5, "end": 2.0, "speaker": "B"},
            ],
            ["A", "A", "B", "B"],
        ),
        (
            "continuation of previous speaker",
            [
                {"start": 0.0, "end": 0.5, "speaker": "A"},
                {"start": 1.0, "end": 2.0, "speaker": "B"},
            ],
            ["A", "A", "B", "B"],
        ),
        (
            "segment without words",
            [
                {"start": 0.0, "end": 1.0, "speaker": "A"},
                {"start": 1.0, "end": 2.0, "speaker": "B"},
                {"start": 2.0, "end": 3.0, "speaker": "X"},
            ],
            ["A", "A", "B", "B"],
        ),
    ],
)
@pytest.mark.asyncio
async def test_processors_audio_diarization(event_loop, name, diarization, expected):
    """
    Push two topics of two words each through the diarization processor,
    with `_diarize` mocked to return `diarization`, and check the speaker
    assigned to each of the four words.
    """
    from reflector.processors.audio_diarization import AudioDiarizationProcessor
    from reflector.processors.types import (
        TitleSummaryWithId,
        Transcript,
        Word,
        AudioDiarizationInput,
    )

    def make_topic(number, words):
        # `words` is a list of (text, start, end)
        return TitleSummaryWithId(
            id=str(number),
            title=f"Title{number}",
            summary=f"Summary{number}",
            timestamp=0.0,
            duration=1.0,
            transcript=Transcript(
                words=[
                    Word(text=text, start=start, end=end)
                    for text, start, end in words
                ]
            ),
        )

    # fake topics: four words covering 0.0 to 2.0
    topics = [
        make_topic(1, [("Word1", 0.0, 0.5), ("word2.", 0.5, 1.0)]),
        make_topic(2, [("Word3", 1.0, 1.5), ("word4.", 1.5, 2.0)]),
    ]

    diarizer = AudioDiarizationProcessor()
    with mock.patch.object(diarizer, "_diarize") as mock_diarize:
        mock_diarize.return_value = diarization

        data = AudioDiarizationInput(
            audio_url="https://example.com/audio.mp3",
            topics=topics,
        )
        await diarizer._push(data)

    # check that the speaker has been assigned to the words
    assigned = [word.speaker for topic in topics for word in topic.transcript.words]
    assert assigned == expected
|
||||
161
server/tests/test_processor_transcript_segment.py
Normal file
161
server/tests/test_processor_transcript_segment.py
Normal file
@@ -0,0 +1,161 @@
|
||||
def test_processor_transcript_segment():
|
||||
from reflector.processors.types import Transcript, Word
|
||||
|
||||
transcript = Transcript(
|
||||
words=[
|
||||
Word(text=" the", start=5.12, end=5.48, speaker=0),
|
||||
Word(text=" different", start=5.48, end=5.8, speaker=0),
|
||||
Word(text=" projects", start=5.8, end=6.3, speaker=0),
|
||||
Word(text=" that", start=6.3, end=6.5, speaker=0),
|
||||
Word(text=" are", start=6.5, end=6.58, speaker=0),
|
||||
Word(text=" going", start=6.58, end=6.82, speaker=0),
|
||||
Word(text=" on", start=6.82, end=7.26, speaker=0),
|
||||
Word(text=" to", start=7.26, end=7.4, speaker=0),
|
||||
Word(text=" give", start=7.4, end=7.54, speaker=0),
|
||||
Word(text=" you", start=7.54, end=7.9, speaker=0),
|
||||
Word(text=" context", start=7.9, end=8.24, speaker=0),
|
||||
Word(text=" and", start=8.24, end=8.66, speaker=0),
|
||||
Word(text=" I", start=8.66, end=8.72, speaker=0),
|
||||
Word(text=" think", start=8.72, end=8.82, speaker=0),
|
||||
Word(text=" that's", start=8.82, end=9.04, speaker=0),
|
||||
Word(text=" what", start=9.04, end=9.12, speaker=0),
|
||||
Word(text=" we'll", start=9.12, end=9.24, speaker=0),
|
||||
Word(text=" do", start=9.24, end=9.32, speaker=0),
|
||||
Word(text=" this", start=9.32, end=9.52, speaker=0),
|
||||
Word(text=" week.", start=9.52, end=9.76, speaker=0),
|
||||
Word(text=" Um,", start=10.24, end=10.62, speaker=0),
|
||||
Word(text=" so,", start=11.36, end=11.94, speaker=0),
|
||||
Word(text=" um,", start=12.46, end=12.92, speaker=0),
|
||||
Word(text=" what", start=13.74, end=13.94, speaker=0),
|
||||
Word(text=" we're", start=13.94, end=14.1, speaker=0),
|
||||
Word(text=" going", start=14.1, end=14.24, speaker=0),
|
||||
Word(text=" to", start=14.24, end=14.34, speaker=0),
|
||||
Word(text=" do", start=14.34, end=14.8, speaker=0),
|
||||
Word(text=" at", start=14.8, end=14.98, speaker=0),
|
||||
Word(text=" H", start=14.98, end=15.04, speaker=0),
|
||||
Word(text=" of", start=15.04, end=15.16, speaker=0),
|
||||
Word(text=" you,", start=15.16, end=15.26, speaker=0),
|
||||
Word(text=" maybe.", start=15.28, end=15.34, speaker=0),
|
||||
Word(text=" you", start=15.36, end=15.52, speaker=0),
|
||||
Word(text=" can", start=15.52, end=15.62, speaker=0),
|
||||
Word(text=" introduce", start=15.62, end=15.98, speaker=0),
|
||||
Word(text=" yourself", start=15.98, end=16.42, speaker=0),
|
||||
Word(text=" to", start=16.42, end=16.68, speaker=0),
|
||||
Word(text=" the", start=16.68, end=16.72, speaker=0),
|
||||
Word(text=" team", start=16.72, end=17.52, speaker=0),
|
||||
Word(text=" quickly", start=17.87, end=18.65, speaker=0),
|
||||
Word(text=" and", start=18.65, end=19.63, speaker=0),
|
||||
Word(text=" Oh,", start=20.91, end=21.55, speaker=0),
|
||||
Word(text=" this", start=21.67, end=21.83, speaker=0),
|
||||
Word(text=" is", start=21.83, end=22.17, speaker=0),
|
||||
Word(text=" a", start=22.17, end=22.35, speaker=0),
|
||||
Word(text=" reflector", start=22.35, end=22.89, speaker=0),
|
||||
Word(text=" translating", start=22.89, end=23.33, speaker=0),
|
||||
Word(text=" into", start=23.33, end=23.73, speaker=0),
|
||||
Word(text=" French", start=23.73, end=23.95, speaker=0),
|
||||
Word(text=" for", start=23.95, end=24.15, speaker=0),
|
||||
Word(text=" me.", start=24.15, end=24.43, speaker=0),
|
||||
Word(text=" This", start=27.87, end=28.19, speaker=0),
|
||||
Word(text=" is", start=28.19, end=28.45, speaker=0),
|
||||
Word(text=" all", start=28.45, end=28.79, speaker=0),
|
||||
Word(text=" the", start=28.79, end=29.15, speaker=0),
|
||||
Word(text=" way,", start=29.15, end=29.15, speaker=0),
|
||||
Word(text=" please,", start=29.53, end=29.59, speaker=0),
|
||||
Word(text=" please,", start=29.73, end=29.77, speaker=0),
|
||||
Word(text=" please,", start=29.77, end=29.83, speaker=0),
|
||||
Word(text=" please.", start=29.83, end=29.97, speaker=0),
|
||||
Word(text=" Yeah,", start=29.97, end=30.17, speaker=0),
|
||||
Word(text=" that's", start=30.25, end=30.33, speaker=0),
|
||||
Word(text=" all", start=30.33, end=30.49, speaker=0),
|
||||
Word(text=" it's", start=30.49, end=30.69, speaker=0),
|
||||
Word(text=" right.", start=30.69, end=30.69, speaker=0),
|
||||
Word(text=" Right.", start=30.72, end=30.98, speaker=1),
|
||||
Word(text=" Yeah,", start=31.56, end=31.72, speaker=2),
|
||||
Word(text=" that's", start=31.86, end=31.98, speaker=2),
|
||||
Word(text=" right.", start=31.98, end=32.2, speaker=2),
|
||||
Word(text=" Because", start=32.38, end=32.46, speaker=0),
|
||||
Word(text=" I", start=32.46, end=32.58, speaker=0),
|
||||
Word(text=" thought", start=32.58, end=32.78, speaker=0),
|
||||
Word(text=" I'd", start=32.78, end=33.0, speaker=0),
|
||||
Word(text=" be", start=33.0, end=33.02, speaker=0),
|
||||
Word(text=" able", start=33.02, end=33.18, speaker=0),
|
||||
Word(text=" to", start=33.18, end=33.34, speaker=0),
|
||||
Word(text=" pull", start=33.34, end=33.52, speaker=0),
|
||||
Word(text=" out.", start=33.52, end=33.68, speaker=0),
|
||||
Word(text=" Yeah,", start=33.7, end=33.9, speaker=0),
|
||||
Word(text=" that", start=33.9, end=34.02, speaker=0),
|
||||
Word(text=" was", start=34.02, end=34.24, speaker=0),
|
||||
Word(text=" the", start=34.24, end=34.34, speaker=0),
|
||||
Word(text=" one", start=34.34, end=34.44, speaker=0),
|
||||
Word(text=" before", start=34.44, end=34.7, speaker=0),
|
||||
Word(text=" that.", start=34.7, end=35.24, speaker=0),
|
||||
Word(text=" Friends,", start=35.84, end=36.46, speaker=0),
|
||||
Word(text=" if", start=36.64, end=36.7, speaker=0),
|
||||
Word(text=" you", start=36.7, end=36.7, speaker=0),
|
||||
Word(text=" have", start=36.7, end=37.24, speaker=0),
|
||||
Word(text=" tell", start=37.24, end=37.44, speaker=0),
|
||||
Word(text=" us", start=37.44, end=37.68, speaker=0),
|
||||
Word(text=" if", start=37.68, end=37.82, speaker=0),
|
||||
Word(text=" it's", start=37.82, end=38.04, speaker=0),
|
||||
Word(text=" good,", start=38.04, end=38.58, speaker=0),
|
||||
Word(text=" exceptionally", start=38.96, end=39.1, speaker=0),
|
||||
Word(text=" good", start=39.1, end=39.6, speaker=0),
|
||||
Word(text=" and", start=39.6, end=39.86, speaker=0),
|
||||
Word(text=" tell", start=39.86, end=40.0, speaker=0),
|
||||
Word(text=" us", start=40.0, end=40.06, speaker=0),
|
||||
Word(text=" when", start=40.06, end=40.2, speaker=0),
|
||||
Word(text=" it's", start=40.2, end=40.34, speaker=0),
|
||||
Word(text=" exceptionally", start=40.34, end=40.6, speaker=0),
|
||||
Word(text=" bad.", start=40.6, end=40.94, speaker=0),
|
||||
Word(text=" We", start=40.96, end=41.26, speaker=0),
|
||||
Word(text=" don't", start=41.26, end=41.44, speaker=0),
|
||||
Word(text=" need", start=41.44, end=41.66, speaker=0),
|
||||
Word(text=" that", start=41.66, end=41.82, speaker=0),
|
||||
Word(text=" at", start=41.82, end=41.94, speaker=0),
|
||||
Word(text=" the", start=41.94, end=41.98, speaker=0),
|
||||
Word(text=" middle", start=41.98, end=42.18, speaker=0),
|
||||
Word(text=" of", start=42.18, end=42.36, speaker=0),
|
||||
Word(text=" age.", start=42.36, end=42.7, speaker=0),
|
||||
Word(text=" Okay,", start=43.26, end=43.44, speaker=0),
|
||||
Word(text=" yeah,", start=43.68, end=43.76, speaker=0),
|
||||
Word(text=" that", start=43.78, end=44.3, speaker=0),
|
||||
Word(text=" sentence", start=44.3, end=44.72, speaker=0),
|
||||
Word(text=" right", start=44.72, end=45.1, speaker=0),
|
||||
Word(text=" before.", start=45.1, end=45.56, speaker=0),
|
||||
Word(text=" it", start=46.08, end=46.36, speaker=0),
|
||||
Word(text=" realizing", start=46.36, end=47.0, speaker=0),
|
||||
Word(text=" that", start=47.0, end=47.28, speaker=0),
|
||||
Word(text=" I", start=47.28, end=47.28, speaker=0),
|
||||
Word(text=" was", start=47.28, end=47.64, speaker=0),
|
||||
Word(text=" saying", start=47.64, end=48.06, speaker=0),
|
||||
Word(text=" that", start=48.06, end=48.44, speaker=0),
|
||||
Word(text=" it's", start=48.44, end=48.54, speaker=0),
|
||||
Word(text=" interesting", start=48.54, end=48.78, speaker=0),
|
||||
Word(text=" that", start=48.78, end=48.96, speaker=0),
|
||||
Word(text=" it's", start=48.96, end=49.08, speaker=0),
|
||||
Word(text=" translating", start=49.08, end=49.32, speaker=0),
|
||||
Word(text=" the", start=49.32, end=49.56, speaker=0),
|
||||
Word(text=" French", start=49.56, end=49.76, speaker=0),
|
||||
Word(text=" was", start=49.76, end=50.16, speaker=0),
|
||||
Word(text=" completely", start=50.16, end=50.4, speaker=0),
|
||||
Word(text=" wrong.", start=50.4, end=50.7, speaker=0),
|
||||
]
|
||||
)
|
||||
|
||||
segments = transcript.as_segments()
|
||||
assert len(segments) == 7
|
||||
|
||||
# check speaker order
|
||||
assert segments[0].speaker == 0
|
||||
assert segments[1].speaker == 0
|
||||
assert segments[2].speaker == 0
|
||||
assert segments[3].speaker == 1
|
||||
assert segments[4].speaker == 2
|
||||
assert segments[5].speaker == 0
|
||||
assert segments[6].speaker == 0
|
||||
|
||||
# check the timing (first entry, and first of others speakers)
|
||||
assert segments[0].start == 5.12
|
||||
assert segments[3].start == 30.72
|
||||
assert segments[4].start == 31.56
|
||||
assert segments[5].start == 32.38
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
import httpx
|
||||
from reflector.utils.retry import (
|
||||
@@ -8,6 +9,31 @@ from reflector.utils.retry import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_redirect(httpx_mock):
|
||||
async def custom_response(request: httpx.Request):
|
||||
if request.url.path == "/hello":
|
||||
await asyncio.sleep(1)
|
||||
return httpx.Response(
|
||||
status_code=303, headers={"location": "https://test_url/redirected"}
|
||||
)
|
||||
elif request.url.path == "/redirected":
|
||||
return httpx.Response(status_code=200, json={"hello": "world"})
|
||||
else:
|
||||
raise Exception("Unexpected path")
|
||||
|
||||
httpx_mock.add_callback(custom_response)
|
||||
async with httpx.AsyncClient() as client:
|
||||
# timeout should not triggered, as it will end up ok
|
||||
# even though the first request is a 303 and took more that 0.5
|
||||
resp = await retry(client.get)(
|
||||
"https://test_url/hello",
|
||||
retry_timeout=0.5,
|
||||
follow_redirects=True,
|
||||
)
|
||||
assert resp.json() == {"hello": "world"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_httpx(httpx_mock):
|
||||
# this code should be force a retry
|
||||
|
||||
@@ -196,3 +196,29 @@ async def test_transcript_delete():
|
||||
|
||||
response = await ac.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_mark_reviewed():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test"
|
||||
assert response.json()["reviewed"] is False
|
||||
|
||||
tid = response.json()["id"]
|
||||
|
||||
response = await ac.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test"
|
||||
assert response.json()["reviewed"] is False
|
||||
|
||||
response = await ac.patch(f"/transcripts/{tid}", json={"reviewed": True})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["reviewed"] is True
|
||||
|
||||
response = await ac.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["reviewed"] is True
|
||||
|
||||
@@ -46,6 +46,34 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == content_type
|
||||
|
||||
# test get 404
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
response = await ac.get(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"url_suffix,content_type",
|
||||
[
|
||||
["/mp3", "audio/mpeg"],
|
||||
],
|
||||
)
|
||||
async def test_transcript_audio_download_head(
|
||||
fake_transcript, url_suffix, content_type
|
||||
):
|
||||
from reflector.app import app
|
||||
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
response = await ac.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == content_type
|
||||
|
||||
# test head 404
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
response = await ac.head(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
@@ -90,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek(
|
||||
assert response.status_code == 206
|
||||
assert response.headers["content-type"] == content_type
|
||||
assert response.headers["content-range"].startswith("bytes 100-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_audio_download_waveform(fake_transcript):
|
||||
from reflector.app import app
|
||||
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "application/json"
|
||||
assert isinstance(response.json()["data"], list)
|
||||
assert len(response.json()["data"]) >= 255
|
||||
|
||||
164
server/tests/test_transcripts_participants.py
Normal file
164
server/tests/test_transcripts_participants.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_participants():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["participants"] == []
|
||||
|
||||
# create a participant
|
||||
transcript_id = response.json()["id"]
|
||||
response = await ac.post(
|
||||
f"/transcripts/{transcript_id}/participants", json={"name": "test"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] is not None
|
||||
assert response.json()["speaker"] is None
|
||||
assert response.json()["name"] == "test"
|
||||
|
||||
# create another one with a speaker
|
||||
response = await ac.post(
|
||||
f"/transcripts/{transcript_id}/participants",
|
||||
json={"name": "test2", "speaker": 1},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] is not None
|
||||
assert response.json()["speaker"] == 1
|
||||
assert response.json()["name"] == "test2"
|
||||
|
||||
# get all participants via transcript
|
||||
response = await ac.get(f"/transcripts/{transcript_id}")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()["participants"]) == 2
|
||||
|
||||
# get participants via participants endpoint
|
||||
response = await ac.get(f"/transcripts/{transcript_id}/participants")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_participants_same_speaker():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["participants"] == []
|
||||
transcript_id = response.json()["id"]
|
||||
|
||||
# create a participant
|
||||
response = await ac.post(
|
||||
f"/transcripts/{transcript_id}/participants",
|
||||
json={"name": "test", "speaker": 1},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["speaker"] == 1
|
||||
|
||||
# create another one with the same speaker
|
||||
response = await ac.post(
|
||||
f"/transcripts/{transcript_id}/participants",
|
||||
json={"name": "test2", "speaker": 1},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_participants_update_name():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["participants"] == []
|
||||
transcript_id = response.json()["id"]
|
||||
|
||||
# create a participant
|
||||
response = await ac.post(
|
||||
f"/transcripts/{transcript_id}/participants",
|
||||
json={"name": "test", "speaker": 1},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["speaker"] == 1
|
||||
|
||||
# update the participant
|
||||
participant_id = response.json()["id"]
|
||||
response = await ac.patch(
|
||||
f"/transcripts/{transcript_id}/participants/{participant_id}",
|
||||
json={"name": "test2"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test2"
|
||||
|
||||
# verify the participant was updated
|
||||
response = await ac.get(
|
||||
f"/transcripts/{transcript_id}/participants/{participant_id}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test2"
|
||||
|
||||
# verify the participant was updated in transcript
|
||||
response = await ac.get(f"/transcripts/{transcript_id}")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()["participants"]) == 1
|
||||
assert response.json()["participants"][0]["name"] == "test2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_participants_update_speaker():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["participants"] == []
|
||||
transcript_id = response.json()["id"]
|
||||
|
||||
# create a participant
|
||||
response = await ac.post(
|
||||
f"/transcripts/{transcript_id}/participants",
|
||||
json={"name": "test", "speaker": 1},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
participant1_id = response.json()["id"]
|
||||
|
||||
# create another participant
|
||||
response = await ac.post(
|
||||
f"/transcripts/{transcript_id}/participants",
|
||||
json={"name": "test2", "speaker": 2},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
participant2_id = response.json()["id"]
|
||||
|
||||
# update the participant, refused as speaker is already taken
|
||||
response = await ac.patch(
|
||||
f"/transcripts/{transcript_id}/participants/{participant2_id}",
|
||||
json={"speaker": 1},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
# delete the participant 1
|
||||
response = await ac.delete(
|
||||
f"/transcripts/{transcript_id}/participants/{participant1_id}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# update the participant 2 again, should be accepted now
|
||||
response = await ac.patch(
|
||||
f"/transcripts/{transcript_id}/participants/{participant2_id}",
|
||||
json={"speaker": 1},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# ensure participant2 name is still there
|
||||
response = await ac.get(
|
||||
f"/transcripts/{transcript_id}/participants/{participant2_id}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test2"
|
||||
assert response.json()["speaker"] == 1
|
||||
@@ -32,7 +32,7 @@ class ThreadedUvicorn:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def appserver(tmpdir):
|
||||
async def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
|
||||
from reflector.settings import settings
|
||||
from reflector.app import app
|
||||
|
||||
@@ -52,12 +52,23 @@ async def appserver(tmpdir):
|
||||
settings.DATA_DIR = DATA_DIR
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_includes():
|
||||
return ["reflector.pipelines.main_live_pipeline"]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
@pytest.mark.usefixtures("celery_session_worker")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_rtc_and_websocket(
|
||||
tmpdir,
|
||||
dummy_llm,
|
||||
dummy_transcript,
|
||||
dummy_processors,
|
||||
dummy_diarization,
|
||||
dummy_storage,
|
||||
fake_mp3_upload,
|
||||
ensure_casing,
|
||||
appserver,
|
||||
sentence_tokenize,
|
||||
@@ -95,6 +106,7 @@ async def test_transcript_rtc_and_websocket(
|
||||
print("Test websocket: DISCONNECTED")
|
||||
|
||||
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||
print("Test websocket: TASK CREATED", websocket_task)
|
||||
|
||||
# create stream client
|
||||
import argparse
|
||||
@@ -121,14 +133,20 @@ async def test_transcript_rtc_and_websocket(
|
||||
# XXX aiortc is long to close the connection
|
||||
# instead of waiting a long time, we just send a STOP
|
||||
client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||
|
||||
# wait the processing to finish
|
||||
await asyncio.sleep(2)
|
||||
|
||||
await client.stop()
|
||||
|
||||
# wait the processing to finish
|
||||
await asyncio.sleep(2)
|
||||
timeout = 20
|
||||
while True:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if resp.json()["status"] != "ended":
|
||||
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
||||
|
||||
# stop websocket task
|
||||
websocket_task.cancel()
|
||||
@@ -167,31 +185,47 @@ async def test_transcript_rtc_and_websocket(
|
||||
ev = events[eventnames.index("FINAL_TITLE")]
|
||||
assert ev["data"]["title"] == "LLM TITLE"
|
||||
|
||||
assert "WAVEFORM" in eventnames
|
||||
ev = events[eventnames.index("WAVEFORM")]
|
||||
assert isinstance(ev["data"]["waveform"], list)
|
||||
assert len(ev["data"]["waveform"]) >= 250
|
||||
waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
|
||||
assert waveform_resp.status_code == 200
|
||||
assert waveform_resp.headers["content-type"] == "application/json"
|
||||
assert isinstance(waveform_resp.json()["data"], list)
|
||||
assert len(waveform_resp.json()["data"]) >= 250
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
assert statuses == ["recording", "processing", "ended"]
|
||||
assert statuses.index("recording") < statuses.index("processing")
|
||||
assert statuses.index("processing") < statuses.index("ended")
|
||||
|
||||
# ensure the last event received is ended
|
||||
assert events[-1]["event"] == "STATUS"
|
||||
assert events[-1]["data"]["value"] == "ended"
|
||||
|
||||
# check that transcript status in model is updated
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ended"
|
||||
# check on the latest response that the audio duration is > 0
|
||||
assert resp.json()["duration"] > 0
|
||||
assert "DURATION" in eventnames
|
||||
|
||||
# check that audio/mp3 is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/mpeg"
|
||||
audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert audio_resp.status_code == 200
|
||||
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
@pytest.mark.usefixtures("celery_session_worker")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_rtc_and_websocket_and_fr(
|
||||
tmpdir,
|
||||
dummy_llm,
|
||||
dummy_transcript,
|
||||
dummy_processors,
|
||||
dummy_diarization,
|
||||
dummy_storage,
|
||||
fake_mp3_upload,
|
||||
ensure_casing,
|
||||
appserver,
|
||||
sentence_tokenize,
|
||||
@@ -232,6 +266,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
print("Test websocket: DISCONNECTED")
|
||||
|
||||
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||
print("Test websocket: TASK CREATED", websocket_task)
|
||||
|
||||
# create stream client
|
||||
import argparse
|
||||
@@ -265,6 +300,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
await client.stop()
|
||||
|
||||
# wait the processing to finish
|
||||
timeout = 20
|
||||
while True:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] == "ended":
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if resp.json()["status"] != "ended":
|
||||
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# stop websocket task
|
||||
@@ -306,7 +353,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
assert statuses == ["recording", "processing", "ended"]
|
||||
assert statuses.index("recording") < statuses.index("processing")
|
||||
assert statuses.index("processing") < statuses.index("ended")
|
||||
|
||||
# ensure the last event received is ended
|
||||
assert events[-1]["event"] == "STATUS"
|
||||
|
||||
401
server/tests/test_transcripts_speaker.py
Normal file
401
server/tests/test_transcripts_speaker.py
Normal file
@@ -0,0 +1,401 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_reassign_speaker(fake_transcript_with_topics):
|
||||
from reflector.app import app
|
||||
|
||||
transcript_id = fake_transcript_with_topics.id
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
# check the transcript exists
|
||||
response = await ac.get(f"/transcripts/{transcript_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# check initial topics of the transcript
|
||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||
assert response.status_code == 200
|
||||
topics = response.json()
|
||||
assert len(topics) == 2
|
||||
|
||||
# check through words
|
||||
assert topics[0]["words"][0]["speaker"] == 0
|
||||
assert topics[0]["words"][1]["speaker"] == 0
|
||||
assert topics[1]["words"][0]["speaker"] == 0
|
||||
assert topics[1]["words"][1]["speaker"] == 0
|
||||
# check through segments
|
||||
assert len(topics[0]["segments"]) == 1
|
||||
assert topics[0]["segments"][0]["speaker"] == 0
|
||||
assert len(topics[1]["segments"]) == 1
|
||||
assert topics[1]["segments"][0]["speaker"] == 0
|
||||
|
||||
# reassign speaker
|
||||
response = await ac.patch(
|
||||
f"/transcripts/{transcript_id}/speaker/assign",
|
||||
json={
|
||||
"speaker": 1,
|
||||
"timestamp_from": 0,
|
||||
"timestamp_to": 1,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# check topics again
|
||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||
assert response.status_code == 200
|
||||
topics = response.json()
|
||||
assert len(topics) == 2
|
||||
|
||||
# check through words
|
||||
assert topics[0]["words"][0]["speaker"] == 1
|
||||
assert topics[0]["words"][1]["speaker"] == 1
|
||||
assert topics[1]["words"][0]["speaker"] == 0
|
||||
assert topics[1]["words"][1]["speaker"] == 0
|
||||
# check segments
|
||||
assert len(topics[0]["segments"]) == 1
|
||||
assert topics[0]["segments"][0]["speaker"] == 1
|
||||
assert len(topics[1]["segments"]) == 1
|
||||
assert topics[1]["segments"][0]["speaker"] == 0
|
||||
|
||||
# reassign speaker, middle of 2 topics
|
||||
response = await ac.patch(
|
||||
f"/transcripts/{transcript_id}/speaker/assign",
|
||||
json={
|
||||
"speaker": 2,
|
||||
"timestamp_from": 1,
|
||||
"timestamp_to": 2.5,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# check topics again
|
||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||
assert response.status_code == 200
|
||||
topics = response.json()
|
||||
assert len(topics) == 2
|
||||
|
||||
# check through words
|
||||
assert topics[0]["words"][0]["speaker"] == 1
|
||||
assert topics[0]["words"][1]["speaker"] == 2
|
||||
assert topics[1]["words"][0]["speaker"] == 2
|
||||
assert topics[1]["words"][1]["speaker"] == 0
|
||||
# check segments
|
||||
assert len(topics[0]["segments"]) == 2
|
||||
assert topics[0]["segments"][0]["speaker"] == 1
|
||||
assert topics[0]["segments"][1]["speaker"] == 2
|
||||
assert len(topics[1]["segments"]) == 2
|
||||
assert topics[1]["segments"][0]["speaker"] == 2
|
||||
assert topics[1]["segments"][1]["speaker"] == 0
|
||||
|
||||
# reassign speaker, everything
|
||||
response = await ac.patch(
|
||||
f"/transcripts/{transcript_id}/speaker/assign",
|
||||
json={
|
||||
"speaker": 4,
|
||||
"timestamp_from": 0,
|
||||
"timestamp_to": 100,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# check topics again
|
||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||
assert response.status_code == 200
|
||||
topics = response.json()
|
||||
assert len(topics) == 2
|
||||
|
||||
# check through words
|
||||
assert topics[0]["words"][0]["speaker"] == 4
|
||||
assert topics[0]["words"][1]["speaker"] == 4
|
||||
assert topics[1]["words"][0]["speaker"] == 4
|
||||
assert topics[1]["words"][1]["speaker"] == 4
|
||||
# check segments
|
||||
assert len(topics[0]["segments"]) == 1
|
||||
assert topics[0]["segments"][0]["speaker"] == 4
|
||||
assert len(topics[1]["segments"]) == 1
|
||||
assert topics[1]["segments"][0]["speaker"] == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_merge_speaker(fake_transcript_with_topics):
|
||||
from reflector.app import app
|
||||
|
||||
transcript_id = fake_transcript_with_topics.id
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
# check the transcript exists
|
||||
response = await ac.get(f"/transcripts/{transcript_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# check initial topics of the transcript
|
||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||
assert response.status_code == 200
|
||||
topics = response.json()
|
||||
assert len(topics) == 2
|
||||
|
||||
# check through words
|
||||
assert topics[0]["words"][0]["speaker"] == 0
|
||||
assert topics[0]["words"][1]["speaker"] == 0
|
||||
assert topics[1]["words"][0]["speaker"] == 0
|
||||
assert topics[1]["words"][1]["speaker"] == 0
|
||||
|
||||
# reassign speaker
|
||||
response = await ac.patch(
|
||||
f"/transcripts/{transcript_id}/speaker/assign",
|
||||
json={
|
||||
"speaker": 1,
|
||||
"timestamp_from": 0,
|
||||
"timestamp_to": 1,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# check topics again
|
||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||
assert response.status_code == 200
|
||||
topics = response.json()
|
||||
assert len(topics) == 2
|
||||
|
||||
# check through words
|
||||
assert topics[0]["words"][0]["speaker"] == 1
|
||||
assert topics[0]["words"][1]["speaker"] == 1
|
||||
assert topics[1]["words"][0]["speaker"] == 0
|
||||
assert topics[1]["words"][1]["speaker"] == 0
|
||||
|
||||
# merge speakers
|
||||
response = await ac.patch(
|
||||
f"/transcripts/{transcript_id}/speaker/merge",
|
||||
json={
|
||||
"speaker_from": 1,
|
||||
"speaker_to": 0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# check topics again
|
||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||
assert response.status_code == 200
|
||||
topics = response.json()
|
||||
assert len(topics) == 2
|
||||
|
||||
# check through words
|
||||
assert topics[0]["words"][0]["speaker"] == 0
|
||||
assert topics[0]["words"][1]["speaker"] == 0
|
||||
assert topics[1]["words"][0]["speaker"] == 0
|
||||
assert topics[1]["words"][1]["speaker"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_transcript_reassign_with_participant(fake_transcript_with_topics):
    """Assign time ranges to participants and verify words, segments and speakers.

    Scenario, in order:
      1. the transcript starts without participants; two are created and
         neither has a speaker yet;
      2. every word/segment is initially attributed to speaker 0;
      3. assigning [0, 1] to participant 1 gives that participant speaker 1;
      4. assigning [1, 2.5] to participant 2 gives that participant speaker 2
         and splits both topics into two segments;
      5. assigning [0, 100] to participant 1 attributes everything to speaker 1.
    """
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id
    base = f"/transcripts/{transcript_id}"

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:

        async def add_participant(name):
            # create one participant and hand back its id
            resp = await ac.post(f"{base}/participants", json={"name": name})
            assert resp.status_code == 200
            return resp.json()["id"]

        async def assign(participant_id, start, end):
            # attribute the [start, end] time range to the given participant
            resp = await ac.patch(
                f"{base}/speaker/assign",
                json={
                    "participant": participant_id,
                    "timestamp_from": start,
                    "timestamp_to": end,
                },
            )
            assert resp.status_code == 200

        async def check_participants(expected_speakers):
            # both participants must be listed, in creation order, with the
            # expected speaker number (None while unassigned)
            resp = await ac.get(f"{base}/participants")
            assert resp.status_code == 200
            listed = resp.json()
            assert len(listed) == 2
            for idx, expected in enumerate(expected_speakers):
                assert listed[idx]["name"] == f"Participant {idx + 1}"
                if expected is None:
                    assert listed[idx]["speaker"] is None
                else:
                    assert listed[idx]["speaker"] == expected

        async def check_topics(word_speakers, segment_speakers):
            # word_speakers: per topic, the speakers of its first two words
            # segment_speakers: per topic, the speaker of every segment
            resp = await ac.get(f"{base}/topics/with-words")
            assert resp.status_code == 200
            topics = resp.json()
            assert len(topics) == 2
            for topic, expected_words in zip(topics, word_speakers):
                for pos, speaker in enumerate(expected_words):
                    assert topic["words"][pos]["speaker"] == speaker
            for topic, expected_segments in zip(topics, segment_speakers):
                assert len(topic["segments"]) == len(expected_segments)
                for pos, speaker in enumerate(expected_segments):
                    assert topic["segments"][pos]["speaker"] == speaker

        # the transcript exists and has no participant yet
        response = await ac.get(base)
        assert response.status_code == 200
        assert len(response.json()["participants"]) == 0

        # create 2 participants; none of them has a speaker so far
        participant1_id = await add_participant("Participant 1")
        participant2_id = await add_participant("Participant 2")
        await check_participants([None, None])

        # initial state: everything belongs to speaker 0
        await check_topics(
            word_speakers=[(0, 0), (0, 0)],
            segment_speakers=[(0,), (0,)],
        )

        # assign the first topic's range to participant 1;
        # the participant gets speaker 1
        await assign(participant1_id, 0, 1)
        await check_participants([1, None])
        await check_topics(
            word_speakers=[(1, 1), (0, 0)],
            segment_speakers=[(1,), (0,)],
        )

        # assign a range straddling both topics to participant 2;
        # the second participant gets speaker 2
        await assign(participant2_id, 1, 2.5)
        await check_participants([1, 2])
        await check_topics(
            word_speakers=[(1, 2), (2, 0)],
            segment_speakers=[(1, 2), (2, 0)],
        )

        # assign the whole transcript to participant 1
        await assign(participant1_id, 0, 100)
        await check_topics(
            word_speakers=[(1, 1), (1, 1)],
            segment_speakers=[(1,), (1,)],
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics):
    """Invalid speaker/assign payloads are rejected with the expected status."""
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id
    assign_url = f"/transcripts/{transcript_id}/speaker/assign"

    # (payload, expected HTTP status), sent in this order
    rejected = [
        # neither participant nor speaker given
        ({"timestamp_from": 0, "timestamp_to": 1}, 400),
        # both participant and speaker given
        (
            {
                "participant": "123",
                "speaker": 1,
                "timestamp_from": 0,
                "timestamp_to": 1,
            },
            400,
        ),
        # participant id that does not exist
        ({"participant": "123", "timestamp_from": 0, "timestamp_to": 1}, 404),
    ]

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        # the transcript exists and has no participant
        response = await ac.get(f"/transcripts/{transcript_id}")
        assert response.status_code == 200
        transcript = response.json()
        assert len(transcript["participants"]) == 0

        for payload, expected_status in rejected:
            response = await ac.patch(assign_url, json=payload)
            assert response.status_code == expected_status
|
||||
26
server/tests/test_transcripts_topics.py
Normal file
26
server/tests/test_transcripts_topics.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_transcript_topics(fake_transcript_with_topics):
    """List the topics, then fetch the words-per-speaker view of the first one."""
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id
    topics_url = f"/transcripts/{transcript_id}/topics"

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        # list the topics and keep the id of the first one
        response = await ac.get(topics_url)
        assert response.status_code == 200
        assert len(response.json()) == 2
        topic_id = response.json()[0]["id"]

        # get words per speaker for that topic
        response = await ac.get(f"{topics_url}/{topic_id}/words-per-speaker")
        assert response.status_code == 200
        data = response.json()

        # a single speaker (0) holding two words
        per_speaker = data["words_per_speaker"]
        assert len(per_speaker) == 1
        assert data["words_per_speaker"][0]["speaker"] == 0
        assert len(data["words_per_speaker"][0]["words"]) == 2
|
||||
61
server/tests/test_transcripts_upload.py
Normal file
61
server/tests/test_transcripts_upload.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_upload_file(
    tmpdir,
    ensure_casing,
    dummy_llm,
    dummy_processors,
    dummy_diarization,
    dummy_storage,
):
    """Upload a recording and wait for the transcript to be fully processed.

    Creates a transcript, uploads ``tests/records/test_short.wav``, polls the
    transcript until its status is ``ended`` or ``error``, then checks the
    title, short summary and the single resulting topic.
    """
    from reflector.app import app

    # Upper bound for the polling loop, in seconds.
    # NOTE(review): 60s is an arbitrary budget chosen so that a stuck pipeline
    # fails the test instead of hanging the suite -- tune as needed.
    processing_timeout = 60

    # the client is closed on exit, including when an assertion fails
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        # create a transcript
        response = await ac.post("/transcripts", json={"name": "test"})
        assert response.status_code == 200
        assert response.json()["status"] == "idle"
        tid = response.json()["id"]

        # upload the sample recording (a .wav file, sent with the
        # "audio/mpeg" content type); the file handle is closed afterwards
        with open("tests/records/test_short.wav", "rb") as audio:
            response = await ac.post(
                f"/transcripts/{tid}/record/upload",
                files={
                    "file": (
                        "test_short.wav",
                        audio,
                        "audio/mpeg",
                    )
                },
            )
        assert response.status_code == 200
        assert response.json()["status"] == "ok"

        # wait for the processing to finish, but never forever
        loop = asyncio.get_running_loop()
        deadline = loop.time() + processing_timeout
        while True:
            # fetch the transcript and check if it is ended
            resp = await ac.get(f"/transcripts/{tid}")
            assert resp.status_code == 200
            if resp.json()["status"] in ("ended", "error"):
                break
            assert (
                loop.time() < deadline
            ), f"transcript still not processed after {processing_timeout}s"
            await asyncio.sleep(1)

        # check the transcript is ended
        transcript = resp.json()
        assert transcript["status"] == "ended"
        assert transcript["short_summary"] == "LLM SHORT SUMMARY"
        assert transcript["title"] == "LLM TITLE"

        # check topics and transcript
        response = await ac.get(f"/transcripts/{tid}/topics")
        assert response.status_code == 200
        assert len(response.json()) == 1
        assert "want to share" in response.json()[0]["transcript"]
|
||||
2
www/.env_template
Normal file
2
www/.env_template
Normal file
@@ -0,0 +1,2 @@
|
||||
FIEF_CLIENT_SECRET=<omitted, ask in zulip>
|
||||
ZULIP_API_KEY=<omitted, ask in zulip>
|
||||
2
www/.gitignore
vendored
2
www/.gitignore
vendored
@@ -39,3 +39,5 @@ next-env.d.ts
|
||||
|
||||
# Sentry Auth Token
|
||||
.sentryclirc
|
||||
|
||||
config.ts
|
||||
@@ -1,11 +1,18 @@
|
||||
"use client";
|
||||
|
||||
import { FiefAuthProvider } from "@fief/fief/nextjs/react";
|
||||
import { createContext } from "react";
|
||||
|
||||
export default function FiefWrapper({ children }) {
|
||||
export const CookieContext = createContext<{ hasAuthCookie: boolean }>({
|
||||
hasAuthCookie: false,
|
||||
});
|
||||
|
||||
export default function FiefWrapper({ children, hasAuthCookie }) {
|
||||
return (
|
||||
<FiefAuthProvider currentUserPath="/api/current-user">
|
||||
{children}
|
||||
</FiefAuthProvider>
|
||||
<CookieContext.Provider value={{ hasAuthCookie }}>
|
||||
<FiefAuthProvider currentUserPath="/api/current-user">
|
||||
{children}
|
||||
</FiefAuthProvider>
|
||||
</CookieContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,29 +1,23 @@
|
||||
"use client";
|
||||
import {
|
||||
useFiefIsAuthenticated,
|
||||
useFiefUserinfo,
|
||||
} from "@fief/fief/nextjs/react";
|
||||
import { useFiefIsAuthenticated } from "@fief/fief/nextjs/react";
|
||||
import Link from "next/link";
|
||||
|
||||
export default function UserInfo() {
|
||||
const isAuthenticated = useFiefIsAuthenticated();
|
||||
const userinfo = useFiefUserinfo();
|
||||
|
||||
return !isAuthenticated ? (
|
||||
<span className="hover:underline focus-within:underline underline-offset-2 decoration-[.5px] font-light px-2">
|
||||
<Link href="/login" className="outline-none">
|
||||
Log in or create account
|
||||
<Link href="/login" className="outline-none" prefetch={false}>
|
||||
Log in
|
||||
</Link>
|
||||
</span>
|
||||
) : (
|
||||
<span className="font-light px-2">
|
||||
{userinfo?.email} (
|
||||
<span className="hover:underline focus-within:underline underline-offset-2 decoration-[.5px]">
|
||||
<Link href="/logout" className="outline-none">
|
||||
<Link href="/logout" className="outline-none" prefetch={false}>
|
||||
Log out
|
||||
</Link>
|
||||
</span>
|
||||
)
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ import React, { createContext, useContext, useState } from "react";
|
||||
|
||||
interface ErrorContextProps {
|
||||
error: Error | null;
|
||||
setError: React.Dispatch<React.SetStateAction<Error | null>>;
|
||||
humanMessage?: string;
|
||||
setError: (error: Error, humanMessage?: string) => void;
|
||||
}
|
||||
|
||||
const ErrorContext = createContext<ErrorContextProps | undefined>(undefined);
|
||||
@@ -22,9 +23,16 @@ interface ErrorProviderProps {
|
||||
|
||||
export const ErrorProvider: React.FC<ErrorProviderProps> = ({ children }) => {
|
||||
const [error, setError] = useState<Error | null>(null);
|
||||
const [humanMessage, setHumanMessage] = useState<string | undefined>();
|
||||
|
||||
const declareError = (error, humanMessage?) => {
|
||||
setError(error);
|
||||
setHumanMessage(humanMessage);
|
||||
};
|
||||
return (
|
||||
<ErrorContext.Provider value={{ error, setError }}>
|
||||
<ErrorContext.Provider
|
||||
value={{ error, setError: declareError, humanMessage }}
|
||||
>
|
||||
{children}
|
||||
</ErrorContext.Provider>
|
||||
);
|
||||
|
||||
@@ -4,29 +4,51 @@ import { useEffect, useState } from "react";
|
||||
import * as Sentry from "@sentry/react";
|
||||
|
||||
const ErrorMessage: React.FC = () => {
|
||||
const { error, setError } = useError();
|
||||
const { error, setError, humanMessage } = useError();
|
||||
const [isVisible, setIsVisible] = useState<boolean>(false);
|
||||
|
||||
// Setup Shortcuts
|
||||
useEffect(() => {
|
||||
const handleKeyPress = (event: KeyboardEvent) => {
|
||||
switch (event.key) {
|
||||
case "^":
|
||||
throw new Error("Unhandled Exception thrown by '^' shortcut");
|
||||
case "$":
|
||||
setError(
|
||||
new Error("Unhandled Exception thrown by '$' shortcut"),
|
||||
"You did this to yourself",
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener("keydown", handleKeyPress);
|
||||
return () => document.removeEventListener("keydown", handleKeyPress);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
setIsVisible(true);
|
||||
Sentry.captureException(error);
|
||||
console.error("Error", error.message, error);
|
||||
if (humanMessage) {
|
||||
setIsVisible(true);
|
||||
Sentry.captureException(Error(humanMessage, { cause: error }));
|
||||
} else {
|
||||
Sentry.captureException(error);
|
||||
}
|
||||
|
||||
console.error("Error", error);
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
if (!isVisible || !error) return null;
|
||||
if (!isVisible || !humanMessage) return null;
|
||||
|
||||
return (
|
||||
<button
|
||||
onClick={() => {
|
||||
setIsVisible(false);
|
||||
setError(null);
|
||||
}}
|
||||
className="max-w-xs z-50 fixed bottom-5 right-5 md:bottom-10 md:right-10 border-solid bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded transition-opacity duration-300 ease-out opacity-100 hover:opacity-80 focus-visible:opacity-80 cursor-pointer transform hover:scale-105 focus-visible:scale-105"
|
||||
role="alert"
|
||||
>
|
||||
<span className="block sm:inline">{error?.message}</span>
|
||||
<span className="block sm:inline">{humanMessage}</span>
|
||||
</button>
|
||||
);
|
||||
};
|
||||
|
||||
94
www/app/[domain]/browse/page.tsx
Normal file
94
www/app/[domain]/browse/page.tsx
Normal file
@@ -0,0 +1,94 @@
|
||||
"use client";
|
||||
import React, { useState } from "react";
|
||||
|
||||
import { GetTranscript } from "../../api";
|
||||
import { Title } from "../../lib/textComponents";
|
||||
import Pagination from "./pagination";
|
||||
import Link from "next/link";
|
||||
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
|
||||
import { faGear } from "@fortawesome/free-solid-svg-icons";
|
||||
import useTranscriptList from "../transcripts/useTranscriptList";
|
||||
|
||||
export default function TranscriptBrowser() {
|
||||
const [page, setPage] = useState<number>(1);
|
||||
const { loading, response } = useTranscriptList(page);
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/*
|
||||
<div className="flex flex-row gap-2">
|
||||
<input className="text-sm p-2 w-80 ring-1 ring-slate-900/10 shadow-sm rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500 caret-blue-500" placeholder="Search" />
|
||||
</div>
|
||||
*/}
|
||||
|
||||
<div className="flex flex-row gap-2 items-center">
|
||||
<Title className="mb-5 mt-5 flex-1">Past transcripts</Title>
|
||||
<Pagination
|
||||
page={page}
|
||||
setPage={setPage}
|
||||
total={response?.total || 0}
|
||||
size={response?.size || 0}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{loading && (
|
||||
<div className="full-screen flex flex-col items-center justify-center">
|
||||
<FontAwesomeIcon
|
||||
icon={faGear}
|
||||
className="animate-spin-slow h-14 w-14 md:h-20 md:w-20"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{!loading && !response && (
|
||||
<div className="text-gray-500">
|
||||
No transcripts found, but you can
|
||||
<Link href="/transcripts/new" className="underline">
|
||||
record a meeting
|
||||
</Link>
|
||||
to get started.
|
||||
</div>
|
||||
)}
|
||||
<div /** center and max 900px wide */ className="mx-auto max-w-[900px]">
|
||||
<div className="grid grid-cols-1 gap-2 lg:gap-4 h-full">
|
||||
{response?.items.map((item: GetTranscript) => (
|
||||
<div
|
||||
key={item.id}
|
||||
className="flex flex-col bg-blue-400/20 rounded-lg md:rounded-xl p-2 md:px-4"
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<div className="flex flex-row gap-2 items-start">
|
||||
<Link
|
||||
href={`/transcripts/${item.id}`}
|
||||
className="text-1xl font-semibold flex-1 pl-0 hover:underline focus-within:underline underline-offset-2 decoration-[.5px] font-light px-2"
|
||||
>
|
||||
{item.title || item.name}
|
||||
</Link>
|
||||
|
||||
{item.locked ? (
|
||||
<div className="inline-block bg-red-500 text-white px-2 py-1 rounded-full text-xs font-semibold">
|
||||
Locked
|
||||
</div>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
|
||||
{item.source_language ? (
|
||||
<div className="inline-block bg-blue-500 text-white px-2 py-1 rounded-full text-xs font-semibold">
|
||||
{item.source_language}
|
||||
</div>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
<div className="text-xs text-gray-700">
|
||||
{new Date(item.created_at).toLocaleDateString("en-US")}
|
||||
</div>
|
||||
<div className="text-sm">{item.short_summary}</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
75
www/app/[domain]/browse/pagination.tsx
Normal file
75
www/app/[domain]/browse/pagination.tsx
Normal file
@@ -0,0 +1,75 @@
|
||||
type PaginationProps = {
  page: number;
  setPage: (page: number) => void;
  total: number;
  size: number;
};

/**
 * Previous / next buttons around a window of at most three page buttons.
 *
 * `page` is the 1-based current page, `total` the number of items and `size`
 * the number of items per page. `setPage` is only called with a page number
 * inside `[1, totalPages]`.
 */
export default function Pagination(props: PaginationProps) {
  const { page, setPage, total, size } = props;

  // A page size of 0 would yield Infinity pages (and make Array.from throw)
  // when total > 0, so treat it as "no pages".
  const totalPages = size > 0 ? Math.ceil(total / size) : 0;

  // Window of up to 3 consecutive pages: the first three near the start, the
  // last three near the end, otherwise centered on the current page.
  const firstVisible =
    totalPages <= 3 ? 1 : Math.min(Math.max(page - 1, 1), totalPages - 2);
  const pageNumbers = Array.from(
    { length: Math.min(3, totalPages) },
    (_, i) => firstVisible + i,
  );

  const canGoPrevious = page > 1;
  const canGoNext = page < totalPages;

  // Ignore requests for pages outside of the valid range.
  const handlePageChange = (newPage: number) => {
    if (newPage >= 1 && newPage <= totalPages) {
      setPage(newPage);
    }
  };

  return (
    <div className="flex justify-center space-x-4 my-4">
      <button
        className={`w-10 h-10 rounded-full p-2 border border-gray-300 disabled:bg-white ${
          canGoPrevious ? "text-gray-500" : "text-gray-300"
        }`}
        onClick={() => handlePageChange(page - 1)}
        disabled={!canGoPrevious}
      >
        <i className="fa fa-chevron-left">&lt;</i>
      </button>

      {pageNumbers.map((pageNumber) => (
        <button
          key={pageNumber}
          className={`w-10 h-10 rounded-full p-2 border ${
            page === pageNumber ? "border-gray-600" : "border-gray-300"
          } rounded`}
          onClick={() => handlePageChange(pageNumber)}
        >
          {pageNumber}
        </button>
      ))}

      <button
        className={`w-10 h-10 rounded-full p-2 border border-gray-300 disabled:bg-white ${
          canGoNext ? "text-gray-500" : "text-gray-300"
        }`}
        onClick={() => handlePageChange(page + 1)}
        disabled={!canGoNext}
      >
        <i className="fa fa-chevron-right">&gt;</i>
      </button>
    </div>
  );
}
|
||||
50
www/app/[domain]/domainContext.tsx
Normal file
50
www/app/[domain]/domainContext.tsx
Normal file
@@ -0,0 +1,50 @@
|
||||
"use client";
|
||||
import { createContext, useContext, useEffect, useState } from "react";
|
||||
import { DomainConfig } from "../lib/edgeConfig";
|
||||
|
||||
type DomainContextType = Omit<DomainConfig, "auth_callback_url">;
|
||||
|
||||
export const DomainContext = createContext<DomainContextType>({
|
||||
features: {
|
||||
requireLogin: false,
|
||||
privacy: true,
|
||||
browse: false,
|
||||
sendToZulip: false,
|
||||
},
|
||||
api_url: "",
|
||||
websocket_url: "",
|
||||
zulip_streams: "",
|
||||
});
|
||||
|
||||
/**
 * Provides the domain configuration to client components through
 * `DomainContext`.
 *
 * The context value is `config` without its `auth_callback_url` field. It is
 * filled in by an effect, so nothing (including `children`) is rendered until
 * that effect has run with a truthy `config`.
 */
export const DomainContextProvider = ({
  config,
  children,
}: {
  config: DomainConfig;
  children: any;
}) => {
  const [context, setContext] = useState<DomainContextType>();

  useEffect(() => {
    if (!config) return;
    // Drop auth_callback_url so the stored value matches DomainContextType.
    const { auth_callback_url, ...others } = config;
    setContext(others);
  }, [config]);

  // Render nothing explicitly (null) rather than returning undefined.
  if (!context) return null;

  return (
    <DomainContext.Provider value={context}>{children}</DomainContext.Provider>
  );
};
|
||||
|
||||
// Get feature config client-side with featureEnabled(featureName).
|
||||
/**
 * Reads one feature flag from the nearest `DomainContext`.
 *
 * NOTE(review): this calls `useContext`, so although its name does not start
 * with `use` it is subject to the rules of hooks and must only be called
 * unconditionally from a component or hook body.
 *
 * @param featureName name of the flag to look up
 * @returns the flag's value from the context's `features` object
 */
export const featureEnabled = (
  featureName: "requireLogin" | "privacy" | "browse" | "sendToZulip",
) => {
  const { features } = useContext(DomainContext);
  return features[featureName] as boolean | undefined;
};
|
||||
|
||||
// Get config server-side (out of react) : see lib/edgeConfig.
|
||||
166
www/app/[domain]/layout.tsx
Normal file
166
www/app/[domain]/layout.tsx
Normal file
@@ -0,0 +1,166 @@
|
||||
import "../styles/globals.scss";
|
||||
import { Poppins } from "next/font/google";
|
||||
import { Metadata, Viewport } from "next";
|
||||
import FiefWrapper from "../(auth)/fiefWrapper";
|
||||
import UserInfo from "../(auth)/userInfo";
|
||||
import { ErrorProvider } from "../(errors)/errorContext";
|
||||
import ErrorMessage from "../(errors)/errorMessage";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import About from "../(aboutAndPrivacy)/about";
|
||||
import Privacy from "../(aboutAndPrivacy)/privacy";
|
||||
import { DomainContextProvider } from "./domainContext";
|
||||
import { getConfig } from "../lib/edgeConfig";
|
||||
import { ErrorBoundary } from "@sentry/nextjs";
|
||||
import { cookies } from "next/dist/client/components/headers";
|
||||
import { SESSION_COOKIE_NAME } from "../lib/fief";
|
||||
|
||||
const poppins = Poppins({ subsets: ["latin"], weight: ["200", "400", "600"] });
|
||||
|
||||
export const viewport: Viewport = {
|
||||
themeColor: "black",
|
||||
width: "device-width",
|
||||
initialScale: 1,
|
||||
maximumScale: 1,
|
||||
};
|
||||
|
||||
export const metadata: Metadata = {
|
||||
metadataBase: new URL(process.env.DEV_URL || "https://reflector.media"),
|
||||
title: {
|
||||
template: "%s – Reflector",
|
||||
default: "Reflector - AI-Powered Meeting Transcriptions by Monadical",
|
||||
},
|
||||
description:
|
||||
"Reflector is an AI-powered tool that transcribes your meetings with unparalleled accuracy, divides content by topics, and provides insightful summaries. Maximize your productivity with Reflector, brought to you by Monadical. Capture the signal, not the noise",
|
||||
applicationName: "Reflector",
|
||||
referrer: "origin-when-cross-origin",
|
||||
keywords: ["Reflector", "Monadical", "AI", "Meetings", "Transcription"],
|
||||
authors: [{ name: "Monadical Team", url: "https://monadical.com/team.html" }],
|
||||
formatDetection: {
|
||||
email: false,
|
||||
address: false,
|
||||
telephone: false,
|
||||
},
|
||||
|
||||
openGraph: {
|
||||
title: "Reflector",
|
||||
description:
|
||||
"Reflector is an AI-powered tool that transcribes your meetings with unparalleled accuracy, divides content by topics, and provides insightful summaries. Maximize your productivity with Reflector, brought to you by Monadical. Capture the signal, not the noise.",
|
||||
type: "website",
|
||||
},
|
||||
|
||||
twitter: {
|
||||
card: "summary_large_image",
|
||||
title: "Reflector",
|
||||
description:
|
||||
"Reflector is an AI-powered tool that transcribes your meetings with unparalleled accuracy, divides content by topics, and provides insightful summaries. Maximize your productivity with Reflector, brought to you by Monadical. Capture the signal, not the noise.",
|
||||
images: ["/r-icon.png"],
|
||||
},
|
||||
|
||||
icons: {
|
||||
icon: "/r-icon.png",
|
||||
shortcut: "/r-icon.png",
|
||||
apple: "/r-icon.png",
|
||||
},
|
||||
robots: { index: false, follow: false, noarchive: true, noimageindex: true },
|
||||
};
|
||||
|
||||
type LayoutProps = {
|
||||
params: {
|
||||
domain: string;
|
||||
};
|
||||
children: any;
|
||||
};
|
||||
|
||||
export default async function RootLayout({ children, params }: LayoutProps) {
|
||||
const config = await getConfig(params.domain);
|
||||
const { requireLogin, privacy, browse } = config.features;
|
||||
const hasAuthCookie = !!cookies().get(SESSION_COOKIE_NAME);
|
||||
|
||||
return (
|
||||
<html lang="en">
|
||||
<body className={poppins.className + " h-screen relative"}>
|
||||
<FiefWrapper hasAuthCookie={hasAuthCookie}>
|
||||
<DomainContextProvider config={config}>
|
||||
<ErrorBoundary fallback={<p>"something went really wrong"</p>}>
|
||||
<ErrorProvider>
|
||||
<ErrorMessage />
|
||||
<div
|
||||
id="container"
|
||||
className="items-center h-[100svh] w-[100svw] p-2 md:p-4 grid grid-rows-layout gap-2 md:gap-4"
|
||||
>
|
||||
<header className="flex justify-between items-center w-full">
|
||||
{/* Logo on the left */}
|
||||
<Link
|
||||
href="/"
|
||||
className="flex outline-blue-300 md:outline-none focus-visible:underline underline-offset-2 decoration-[.5px] decoration-gray-500"
|
||||
>
|
||||
<Image
|
||||
src="/reach.png"
|
||||
width={16}
|
||||
height={16}
|
||||
className="h-10 w-auto"
|
||||
alt="Reflector"
|
||||
/>
|
||||
<div className="hidden flex-col ml-2 md:block">
|
||||
<h1 className="text-[38px] font-bold tracking-wide leading-tight">
|
||||
Reflector
|
||||
</h1>
|
||||
<p className="text-gray-500 text-xs tracking-tighter">
|
||||
Capture the signal, not the noise
|
||||
</p>
|
||||
</div>
|
||||
</Link>
|
||||
<div>
|
||||
{/* Text link on the right */}
|
||||
<Link
|
||||
href="/transcripts/new"
|
||||
className="hover:underline focus-within:underline underline-offset-2 decoration-[.5px] font-light px-2"
|
||||
>
|
||||
Create
|
||||
</Link>
|
||||
{browse ? (
|
||||
<>
|
||||
·
|
||||
<Link
|
||||
href="/browse"
|
||||
className="hover:underline focus-within:underline underline-offset-2 decoration-[.5px] font-light px-2"
|
||||
prefetch={false}
|
||||
>
|
||||
Browse
|
||||
</Link>
|
||||
</>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
·
|
||||
<About buttonText="About" />
|
||||
{privacy ? (
|
||||
<>
|
||||
·
|
||||
<Privacy buttonText="Privacy" />
|
||||
</>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
{requireLogin ? (
|
||||
<>
|
||||
·
|
||||
<UserInfo />
|
||||
</>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
</header>
|
||||
|
||||
{children}
|
||||
</div>
|
||||
</ErrorProvider>
|
||||
</ErrorBoundary>
|
||||
</DomainContextProvider>
|
||||
</FiefWrapper>
|
||||
</body>
|
||||
</html>
|
||||
);
|
||||
}
|
||||
154
www/app/[domain]/transcripts/[transcriptId]/page.tsx
Normal file
154
www/app/[domain]/transcripts/[transcriptId]/page.tsx
Normal file
@@ -0,0 +1,154 @@
|
||||
"use client";
|
||||
import Modal from "../modal";
|
||||
import useTranscript from "../useTranscript";
|
||||
import useTopics from "../useTopics";
|
||||
import useWaveform from "../useWaveform";
|
||||
import useMp3 from "../useMp3";
|
||||
import { TopicList } from "../topicList";
|
||||
import { Topic } from "../webSocketTypes";
|
||||
import React, { useEffect, useState } from "react";
|
||||
import "../../../styles/button.css";
|
||||
import FinalSummary from "../finalSummary";
|
||||
import ShareLink from "../shareLink";
|
||||
import QRCode from "react-qr-code";
|
||||
import TranscriptTitle from "../transcriptTitle";
|
||||
import ShareModal from "./shareModal";
|
||||
import Player from "../player";
|
||||
import WaveformLoading from "../waveformLoading";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { featureEnabled } from "../../domainContext";
|
||||
import { toShareMode } from "../../../lib/shareMode";
|
||||
|
||||
/**
 * Props received by the transcript detail page.
 * `params.transcriptId` identifies the transcript to display.
 */
type TranscriptDetails = {
  params: {
    transcriptId: string;
  };
};

/**
 * Detail page for a single transcript.
 *
 * Loads the transcript, its topics, waveform and mp3 through the project
 * hooks, then renders the title, the audio player, the topic list, the final
 * summary and the sharing widgets (QR code + share link).
 *
 * Rendering order of the guards matters:
 *  1. any fetch error      -> "Transcription Not Found" modal
 *  2. anything still loading (or no id) -> "Loading" modal
 *  3. no response yet      -> render nothing
 * The guards run before the response check so that a failed or pending fetch
 * is actually reported instead of silently rendering nothing.
 *
 * @param details Page props carrying the transcript id.
 */
export default function TranscriptDetails(details: TranscriptDetails) {
  const transcriptId = details.params.transcriptId;
  const router = useRouter();

  const transcript = useTranscript(transcriptId);
  const topics = useTopics(transcriptId);
  const waveform = useWaveform(transcriptId);
  // The whole [value, setter] tuple is shared with Player and TopicList.
  const useActiveTopic = useState<Topic | null>(null);
  const mp3 = useMp3(transcriptId);
  const [showModal, setShowModal] = useState(false);

  // A transcript that has not finished yet belongs on the record page.
  useEffect(() => {
    const statusToRedirect = ["idle", "recording", "processing"];
    if (statusToRedirect.includes(transcript.response?.status)) {
      const newUrl = "/transcripts/" + details.params.transcriptId + "/record";
      // Shallow redirection does not work on NextJS 13
      // https://github.com/vercel/next.js/discussions/48110
      // https://github.com/vercel/next.js/discussions/49540
      router.push(newUrl, undefined);
      // history.replaceState({}, "", newUrl);
    }
  }, [transcript.response?.status]);

  // Plain-text transcript: topic transcripts joined by blank lines, with runs
  // of spaces collapsed. Empty string while topics are unavailable.
  const fullTranscript =
    topics.topics
      ?.map((topic) => topic.transcript)
      .join("\n\n")
      .replace(/ +/g, " ")
      .trim() || "";

  // Report failures first: an errored fetch has no response to render.
  if (transcript.error || topics?.error) {
    return (
      <Modal
        title="Transcription Not Found"
        text="A transcription with this ID does not exist."
      />
    );
  }

  if (!transcriptId || transcript?.loading || topics?.loading) {
    return <Modal title="Loading" text={"Loading transcript..."} />;
  }

  const response = transcript.response;
  if (!response) {
    // Neither loading nor failed, but nothing to show yet.
    return null;
  }

  return (
    <>
      {featureEnabled("sendToZulip") && (
        <ShareModal
          transcript={response}
          topics={topics ? topics.topics : null}
          show={showModal}
          setShow={(v) => setShowModal(v)}
        />
      )}
      <div className="flex flex-col">
        {response.title && (
          <TranscriptTitle title={response.title} transcriptId={response.id} />
        )}
        {waveform.waveform && mp3.media ? (
          <Player
            topics={topics?.topics || []}
            useActiveTopic={useActiveTopic}
            waveform={waveform.waveform}
            media={mp3.media}
            mediaDuration={response.duration}
          />
        ) : waveform.error ? (
          <div>"error loading this recording"</div>
        ) : (
          <WaveformLoading />
        )}
      </div>
      <div className="grid grid-cols-1 lg:grid-cols-2 grid-rows-2 lg:grid-rows-1 gap-2 lg:gap-4 h-full">
        <TopicList
          topics={topics.topics || []}
          useActiveTopic={useActiveTopic}
          autoscroll={false}
        />

        <div className="w-full h-full grid grid-rows-layout-one grid-cols-1 gap-2 lg:gap-4">
          <section className=" bg-blue-400/20 rounded-lg md:rounded-xl p-2 md:px-4 h-full">
            {response.long_summary ? (
              <FinalSummary
                fullTranscript={fullTranscript}
                summary={response.long_summary}
                transcriptId={response.id}
                openZulipModal={() => setShowModal(true)}
              />
            ) : (
              <div className="flex flex-col h-full justify-center content-center">
                {response.status === "processing" ? (
                  <p>Loading Transcript</p>
                ) : (
                  <p>
                    There was an error generating the final summary, please
                    come back later
                  </p>
                )}
              </div>
            )}
          </section>

          <section className="flex items-center">
            <div className="mr-4 hidden md:block h-auto">
              <QRCode
                value={`${location.origin}/transcripts/${details.params.transcriptId}`}
                level="L"
                size={98}
              />
            </div>
            <div className="flex-grow max-w-full">
              <ShareLink
                transcriptId={response.id}
                userId={response.user_id}
                shareMode={toShareMode(response.share_mode)}
              />
            </div>
          </section>
        </div>
      </div>
    </>
  );
}
|
||||
@@ -6,14 +6,17 @@ import useWebRTC from "../../useWebRTC";
|
||||
import useTranscript from "../../useTranscript";
|
||||
import { useWebSockets } from "../../useWebSockets";
|
||||
import useAudioDevice from "../../useAudioDevice";
|
||||
import "../../../styles/button.css";
|
||||
import "../../../../styles/button.css";
|
||||
import { Topic } from "../../webSocketTypes";
|
||||
import getApi from "../../../lib/getApi";
|
||||
import LiveTrancription from "../../liveTranscription";
|
||||
import DisconnectedIndicator from "../../disconnectedIndicator";
|
||||
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
|
||||
import { faGear } from "@fortawesome/free-solid-svg-icons";
|
||||
import { lockWakeState, releaseWakeState } from "../../../lib/wakeLock";
|
||||
import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock";
|
||||
import { useRouter } from "next/navigation";
|
||||
import Player from "../../player";
|
||||
import useMp3 from "../../useMp3";
|
||||
import WaveformLoading from "../../waveformLoading";
|
||||
|
||||
type TranscriptDetails = {
|
||||
params: {
|
||||
@@ -37,14 +40,17 @@ const TranscriptRecord = (details: TranscriptDetails) => {
|
||||
}, []);
|
||||
|
||||
const transcript = useTranscript(details.params.transcriptId);
|
||||
const api = getApi();
|
||||
const webRTC = useWebRTC(stream, details.params.transcriptId, api);
|
||||
const webRTC = useWebRTC(stream, details.params.transcriptId);
|
||||
const webSockets = useWebSockets(details.params.transcriptId);
|
||||
|
||||
const { audioDevices, getAudioStream } = useAudioDevice();
|
||||
|
||||
const [hasRecorded, setHasRecorded] = useState(false);
|
||||
const [recordedTime, setRecordedTime] = useState(0);
|
||||
const [startTime, setStartTime] = useState(0);
|
||||
const [transcriptStarted, setTranscriptStarted] = useState(false);
|
||||
let mp3 = useMp3(details.params.transcriptId, true);
|
||||
|
||||
const router = useRouter();
|
||||
|
||||
useEffect(() => {
|
||||
if (!transcriptStarted && webSockets.transcriptText.length !== 0)
|
||||
@@ -52,15 +58,25 @@ const TranscriptRecord = (details: TranscriptDetails) => {
|
||||
}, [webSockets.transcriptText]);
|
||||
|
||||
useEffect(() => {
|
||||
if (transcript?.response?.longSummary) {
|
||||
const newUrl = `/transcripts/${transcript.response.id}`;
|
||||
const statusToRedirect = ["ended", "error"];
|
||||
|
||||
//TODO if has no topic and is error, get back to new
|
||||
if (
|
||||
statusToRedirect.includes(transcript.response?.status) ||
|
||||
statusToRedirect.includes(webSockets.status.value)
|
||||
) {
|
||||
const newUrl = "/transcripts/" + details.params.transcriptId;
|
||||
// Shallow redirection does not work on NextJS 13
|
||||
// https://github.com/vercel/next.js/discussions/48110
|
||||
// https://github.com/vercel/next.js/discussions/49540
|
||||
// router.push(newUrl, undefined, { shallow: true });
|
||||
history.replaceState({}, "", newUrl);
|
||||
}
|
||||
});
|
||||
router.replace(newUrl);
|
||||
// history.replaceState({}, "", newUrl);
|
||||
} // history.replaceState({}, "", newUrl);
|
||||
}, [webSockets.status.value, transcript.response?.status]);
|
||||
|
||||
useEffect(() => {
|
||||
if (transcript.response?.status === "ended") mp3.getNow();
|
||||
}, [transcript.response]);
|
||||
|
||||
useEffect(() => {
|
||||
lockWakeState();
|
||||
@@ -71,19 +87,32 @@ const TranscriptRecord = (details: TranscriptDetails) => {
|
||||
|
||||
return (
|
||||
<>
|
||||
<Recorder
|
||||
setStream={setStream}
|
||||
onStop={() => {
|
||||
setStream(null);
|
||||
setHasRecorded(true);
|
||||
webRTC?.send(JSON.stringify({ cmd: "STOP" }));
|
||||
}}
|
||||
topics={webSockets.topics}
|
||||
getAudioStream={getAudioStream}
|
||||
useActiveTopic={useActiveTopic}
|
||||
isPastMeeting={false}
|
||||
audioDevices={audioDevices}
|
||||
/>
|
||||
{webSockets.waveform && webSockets.duration && mp3?.media ? (
|
||||
<Player
|
||||
topics={webSockets.topics || []}
|
||||
useActiveTopic={useActiveTopic}
|
||||
waveform={webSockets.waveform}
|
||||
media={mp3.media}
|
||||
mediaDuration={webSockets.duration}
|
||||
/>
|
||||
) : recordedTime ? (
|
||||
<WaveformLoading />
|
||||
) : (
|
||||
<Recorder
|
||||
setStream={setStream}
|
||||
onStop={() => {
|
||||
setStream(null);
|
||||
setRecordedTime(Date.now() - startTime);
|
||||
webRTC?.send(JSON.stringify({ cmd: "STOP" }));
|
||||
}}
|
||||
onRecord={() => {
|
||||
setStartTime(Date.now());
|
||||
}}
|
||||
getAudioStream={getAudioStream}
|
||||
audioDevices={audioDevices}
|
||||
transcriptId={details.params.transcriptId}
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="grid grid-cols-1 lg:grid-cols-2 grid-rows-mobile-inner lg:grid-rows-1 gap-2 lg:gap-4 h-full">
|
||||
<TopicList
|
||||
@@ -95,7 +124,7 @@ const TranscriptRecord = (details: TranscriptDetails) => {
|
||||
<section
|
||||
className={`w-full h-full bg-blue-400/20 rounded-lg md:rounded-xl p-2 md:px-4`}
|
||||
>
|
||||
{!hasRecorded ? (
|
||||
{!recordedTime ? (
|
||||
<>
|
||||
{transcriptStarted && (
|
||||
<h2 className="md:text-lg font-bold">Transcription</h2>
|
||||
@@ -129,6 +158,7 @@ const TranscriptRecord = (details: TranscriptDetails) => {
|
||||
couple of minutes. Please do not navigate away from the page
|
||||
during this time.
|
||||
</p>
|
||||
{/* NTH If login required remove last sentence */}
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
159
www/app/[domain]/transcripts/[transcriptId]/shareModal.tsx
Normal file
159
www/app/[domain]/transcripts/[transcriptId]/shareModal.tsx
Normal file
@@ -0,0 +1,159 @@
|
||||
import React, { useContext, useState, useEffect } from "react";
|
||||
import SelectSearch from "react-select-search";
|
||||
import { getZulipMessage, sendZulipMessage } from "../../../lib/zulip";
|
||||
import { GetTranscript, GetTranscriptTopic } from "../../../api";
|
||||
import "react-select-search/style.css";
|
||||
import { DomainContext } from "../../domainContext";
|
||||
|
||||
/**
 * Props of the "Send to Zulip" modal.
 *
 * - `show` / `setShow`: visibility flag and its setter, owned by the parent.
 * - `transcript`: transcript to share; nothing is sent while it is null.
 * - `topics`: topics passed on to the message builder.
 */
type ShareModal = {
  show: boolean;
  setShow: (show: boolean) => void;
  transcript: GetTranscript | null;
  topics: GetTranscriptTopic[] | null;
};

/** One stream entry as read from `streams.json`. */
interface Stream {
  id: number;
  name: string;
  topics: string[];
}

/** Option shape handed to the `SelectSearch` dropdowns. */
interface SelectSearchOption {
  name: string;
  value: string;
}

/**
 * Modal that lets the user pick a Zulip stream and topic and send the
 * transcript there.
 *
 * The stream list is fetched from `<zulip_streams>/streams.json`. While that
 * request is pending and the modal is shown, a plain "Loading..." placeholder
 * is rendered. If the request fails the error is logged and the modal opens
 * with an empty stream list, so the user can still close it.
 */
const ShareModal = (props: ShareModal) => {
  const [stream, setStream] = useState<string | undefined>(undefined);
  const [topic, setTopic] = useState<string | undefined>(undefined);
  const [includeTopics, setIncludeTopics] = useState(false);
  const [isLoading, setIsLoading] = useState(true);
  const [streams, setStreams] = useState<Stream[]>([]);
  const { zulip_streams } = useContext(DomainContext);

  useEffect(() => {
    // Guards against updating state after unmount or after the base URL
    // changed while a request was still in flight.
    let cancelled = false;

    fetch(zulip_streams + "/streams.json")
      .then((response) => {
        if (!response.ok) {
          throw new Error("Network response was not ok");
        }
        return response.json();
      })
      .then((data: Stream[]) => {
        // Sort a copy by name; a non-array payload throws here and is
        // handled by the catch below.
        const sorted = [...data].sort((a, b) => a.name.localeCompare(b.name));
        if (cancelled) return;
        setStreams(sorted);
        setIsLoading(false);
      })
      .catch((error) => {
        console.error("There was a problem with your fetch operation:", error);
        // Leave the loading state even on failure, otherwise an opened modal
        // would display "Loading..." forever with no way to close it.
        if (!cancelled) setIsLoading(false);
      });

    return () => {
      cancelled = true;
    };
  }, [zulip_streams]);

  /** Builds the Zulip message and sends it to the selected stream/topic. */
  const handleSendToZulip = () => {
    if (!props.transcript) return;

    const msg = getZulipMessage(props.transcript, props.topics, includeTopics);

    if (stream && topic) sendZulipMessage(stream, topic, msg);
  };

  if (props.show && isLoading) {
    return <div>Loading...</div>;
  }

  const streamOptions: SelectSearchOption[] = streams.map((s) => ({
    name: s.name,
    value: s.name,
  }));

  // Topics of the selected stream, sorted on a copy so the array held in
  // React state is never mutated during render.
  const selectedStream = streams.find((s) => s.name === stream);
  const topicOptions: SelectSearchOption[] = [...(selectedStream?.topics ?? [])]
    .sort((a, b) => a.localeCompare(b))
    .map((t) => ({ name: t, value: t }));

  return (
    <div className="absolute">
      {props.show && (
        <div className="fixed inset-0 bg-gray-600 bg-opacity-50 overflow-y-auto h-full w-full z-50">
          <div className="relative top-20 mx-auto p-5 w-96 shadow-lg rounded-md bg-white">
            <div className="mt-3 text-center">
              <h3 className="font-bold text-xl">Send to Zulip</h3>

              {/* Checkbox for 'Include Topics' */}
              <div className="mt-4 text-left ml-5">
                <label className="flex items-center">
                  <input
                    type="checkbox"
                    className="form-checkbox rounded border-gray-300 text-indigo-600 shadow-sm focus:border-indigo-300 focus:ring focus:ring-indigo-200 focus:ring-opacity-50"
                    checked={includeTopics}
                    onChange={(e) => setIncludeTopics(e.target.checked)}
                  />
                  <span className="ml-2">Include topics</span>
                </label>
              </div>

              <div className="flex items-center mt-4">
                <span className="mr-2">#</span>
                <SelectSearch
                  search={true}
                  options={streamOptions}
                  value={stream}
                  onChange={(val) => {
                    // A topic only makes sense within its stream.
                    setTopic(undefined);
                    setStream(val.toString());
                  }}
                  placeholder="Pick a stream"
                />
              </div>

              {stream && (
                <>
                  <div className="flex items-center mt-4">
                    <span className="mr-2 invisible">#</span>
                    <SelectSearch
                      search={true}
                      options={topicOptions}
                      value={topic}
                      onChange={(val) => setTopic(val.toString())}
                      placeholder="Pick a topic"
                    />
                  </div>
                </>
              )}

              <button
                className={`bg-blue-400 hover:bg-blue-500 focus-visible:bg-blue-500 text-white rounded py-2 px-4 mr-3 ${
                  !stream || !topic ? "opacity-50 cursor-not-allowed" : ""
                }`}
                disabled={!stream || !topic}
                onClick={() => {
                  handleSendToZulip();
                  props.setShow(false);
                }}
              >
                Send to Zulip
              </button>

              <button
                className="bg-red-500 hover:bg-red-700 focus-visible:bg-red-700 text-white rounded py-2 px-4 mt-4"
                onClick={() => props.setShow(false)}
              >
                Close
              </button>
            </div>
          </div>
        </div>
      )}
    </div>
  );
};
|
||||
|
||||
export default ShareModal;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user