Merge branch 'main' into reenable-non-latin-languages

This commit is contained in:
Koper
2024-01-02 18:09:08 +07:00
196 changed files with 13533 additions and 4368 deletions

View File

@@ -3,7 +3,7 @@ name: Deploy to Amazon ECS
on: [workflow_dispatch]
env:
# 384658522150.dkr.ecr.us-east-1.amazonaws.com/reflector
# 950402358378.dkr.ecr.us-east-1.amazonaws.com/reflector
AWS_REGION: us-east-1
ECR_REPOSITORY: reflector

View File

@@ -2,15 +2,20 @@ name: Unittests
on:
pull_request:
paths-ignore:
- 'www/**'
paths:
- 'server/**'
push:
paths-ignore:
- 'www/**'
paths:
- 'server/**'
jobs:
pytest:
runs-on: ubuntu-latest
services:
redis:
image: redis:6
ports:
- 6379:6379
steps:
- uses: actions/checkout@v3
- name: Install poetry

3
.gitignore vendored
View File

@@ -2,3 +2,6 @@
server/.env
.env
server/exportdanswer
.vercel
.env*.local
dump.rdb

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.11.6

39
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,39 @@
{
"configurations": [
{
"type": "aws-sam",
"request": "direct-invoke",
"name": "lambda-nodejs18.x:HelloWorldFunction (nodejs18.x)",
"invokeTarget": {
"target": "template",
"templatePath": "${workspaceFolder}/aws/lambda-nodejs18.x/template.yaml",
"logicalId": "HelloWorldFunction"
},
"lambda": {
"payload": {},
"environmentVariables": {},
"runtime": "nodejs18.x"
}
},
{
"type": "aws-sam",
"request": "direct-invoke",
"name": "API lambda-nodejs18.x:HelloWorldFunction (nodejs18.x)",
"invokeTarget": {
"target": "api",
"templatePath": "${workspaceFolder}/aws/lambda-nodejs18.x/template.yaml",
"logicalId": "HelloWorldFunction"
},
"api": {
"path": "/hello",
"httpMethod": "get",
"payload": {
"json": {}
}
},
"lambda": {
"runtime": "nodejs18.x"
}
}
]
}

113
README.md
View File

@@ -1,12 +1,14 @@
# Reflector
Reflector is a cutting-edge web application under development by Monadical. It utilizes AI to record meetings, providing a permanent record with transcripts, translations, and automated summaries.
Reflector Audio Management and Analysis is a cutting-edge web application under development by Monadical. It utilizes AI to record meetings, providing a permanent record with transcripts, translations, and automated summaries.
The project architecture consists of three primary components:
* **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
* **Back-End**: Python server that offers an API and data persistence, found in `server/`.
* **AI Models**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
* **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
It also uses https://github.com/fief-dev for authentication, and Vercel for deployment and configuration of the front-end.
## Table of Contents
@@ -21,7 +23,12 @@ The project architecture consists of three primary components:
- [OpenAPI Code Generation](#openapi-code-generation)
- [Back-End](#back-end)
- [Installation](#installation-1)
- [Start the project](#start-the-project)
- [Start the API/Backend](#start-the-apibackend)
- [Redis (Mac)](#redis-mac)
- [Redis (Windows)](#redis-windows)
- [Update the database schema (run on first install, and after each pull containing a migration)](#update-the-database-schema-run-on-first-install-and-after-each-pull-containing-a-migration)
- [Main Server](#main-server)
- [Crontab (optional)](#crontab-optional)
- [Using docker](#using-docker)
- [Using local GPT4All](#using-local-gpt4all)
- [Using local files](#using-local-files)
@@ -31,12 +38,18 @@ The project architecture consists of three primary components:
### Contribution Guidelines
All new contributions should be made in a separate branch. Before any code is merged into `master`, it requires a code review.
All new contributions should be made in a separate branch. Before any code is merged into `main`, it requires a code review.
### How to Install Blackhole (Mac Only)
To record both your voice and the meeting you're taking part in, you need :
- For an in-person meeting, make sure your microphone is in range of all participants.
- If using several miscrophones, make sure to merge the audio feeds into one with an external tool.
- For an online meeting, if you do not use headphones, your microphone should be able to pick up both your voice and the audio feed of the meeting.
- If you want to use headphones, you need to merge the audio feeds with an external tool.
This is an external tool for merging the audio feeds as explained in the previous section of this document.
Note: We currently do not have instructions for Windows users.
* Install [Blackhole](https://github.com/ExistentialAudio/BlackHole)-2ch (2 ch is enough) by 1 of 2 options listed.
* Setup ["Aggregate device"](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and local microphone input.
* Setup [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
@@ -59,8 +72,12 @@ To install the application, run:
```bash
yarn install
cp .env_template .env
cp config-template.ts config.ts
```
Then, fill in the environment variables in `.env` and the configuration in `config.ts` as needed. If you are unsure on how to proceed, ask in Zulip.
### Run the Application
To run the application in development mode, run:
@@ -69,7 +86,7 @@ To run the application in development mode, run:
yarn dev
```
Then open [http://localhost:3000](http://localhost:3000) to view it in the browser.
Then (after completing server setup and starting it) open [http://localhost:3000](http://localhost:3000) to view it in the browser.
### OpenAPI Code Generation
@@ -87,35 +104,76 @@ Start with `cd server`.
### Installation
Download [Python 3.11 from the official website](https://www.python.org/downloads/) and ensure you have version 3.11 by running `python --version`.
Run:
```bash
poetry install
python --version # It should say 3.11
pip install poetry
poetry install --no-root
cp .env_template .env
```
Then create an `.env` with:
Then fill `.env` with the omitted values (ask in Zulip). At the moment of this writing, the only value omitted is `AUTH_FIEF_CLIENT_SECRET`.
```
TRANSCRIPT_BACKEND=modal
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
TRANSCRIPT_MODAL_API_KEY=<omitted>
### Start the API/Backend
LLM_BACKEND=modal
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
LLM_MODAL_API_KEY=<omitted>
AUTH_BACKEND=fief
AUTH_FIEF_URL=https://auth.reflector.media/reflector-local
AUTH_FIEF_CLIENT_ID=KQzRsNgoY<omitted>
AUTH_FIEF_CLIENT_SECRET=<omitted>
```
### Start the project
Use:
Start the background worker:
```bash
poetry run python3 -m reflector.app
poetry run celery -A reflector.worker.app worker --loglevel=info
```
### Redis (Mac)
```bash
yarn add redis
redis-server
```
### Redis (Windows)
**Option 1**
```bash
docker compose up -d redis
```
**Option 2**
Install:
- [Git for Windows](https://gitforwindows.org/)
- [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl/install)
- Install your preferred Linux distribution via the Microsoft Store (e.g., Ubuntu).
Open your Linux distribution and update the package list:
```bash
sudo apt update
sudo apt install redis-server
redis-server
```
## Update the database schema (run on first install, and after each pull containing a migration)
```bash
poetry run alembic heads
```
## Main Server
Start the server:
```bash
poetry run python -m reflector.app
```
### Crontab (optional)
For crontab (only healthcheck for now), start the celery beat (you don't need it on your local dev environment):
```bash
poetry run celery -A reflector.worker.app beat
```
#### Using docker
@@ -141,4 +199,5 @@ poetry run python -m reflector.tools.process path/to/audio.wav
## AI Models
*(Documentation for this section is pending.)*
*(Documentation for this section is pending.)*

211
aws/lambda-nodejs18.x/.gitignore vendored Normal file
View File

@@ -0,0 +1,211 @@
# Created by https://www.toptal.com/developers/gitignore/api/osx,node,linux,windows,sam
# Edit at https://www.toptal.com/developers/gitignore?templates=osx,node,linux,windows,sam
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### Node ###
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
.grunt
# Bower dependency directory (https://bower.io/)
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# TypeScript v1 declaration files
typings/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional stylelint cache
.stylelintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variables file
.env
.env.test
.env*.local
# parcel-bundler cache (https://parceljs.org/)
.cache
.parcel-cache
# Next.js build output
.next
# Nuxt.js build / generate output
.nuxt
dist
# Storybook build outputs
.out
.storybook-out
storybook-static
# rollup.js default build output
dist/
# Gatsby files
.cache/
# Comment in the public line in if your project uses Gatsby and not Next.js
# https://nextjs.org/blog/next-9-1#public-directory-support
# public
# vuepress build output
.vuepress/dist
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# TernJS port file
.tern-port
# Stores VSCode versions used for testing VSCode extensions
.vscode-test
# Temporary folders
tmp/
temp/
### OSX ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### SAM ###
# Ignore build directories for the AWS Serverless Application Model (SAM)
# Info: https://aws.amazon.com/serverless/sam/
# Docs: https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-reference.html
**/.aws-sam
### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
.aws-sam
UpdateZulipStreams/node_modules
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
# End of https://www.toptal.com/developers/gitignore/api/osx,node,linux,windows,sam

View File

@@ -0,0 +1,38 @@
# Developing AWS SAM Applications with the AWS Toolkit For Visual Studio Code
This project contains source code and supporting files for a serverless application that you can locally run, debug, and deploy to AWS with the AWS Toolkit For Visual Studio Code.
A "SAM" (serverless application model) project is a project that contains a template.yaml file which is understood by AWS tooling (such as SAM CLI, and the AWS Toolkit For Visual Studio Code).
## Writing and Debugging Serverless Applications
The code for this application will differ based on the runtime, but the path to a handler can be found in the [`template.yaml`](./template.yaml) file through a resource's `CodeUri` and `Handler` fields.
AWS Toolkit For Visual Studio Code supports local debugging for serverless applications through VS Code's debugger. Since this application was created by the AWS Toolkit, launch configurations for all included handlers have been generated and can be found in the menu next to the Run button:
* lambda-nodejs18.x:HelloWorldFunction (nodejs18.x)
* API lambda-nodejs18.x:HelloWorldFunction (nodejs18.x)
You can debug the Lambda handlers locally by adding a breakpoint to the source file, then running the launch configuration. This works by using Docker on your local machine.
Invocation parameters, including payloads and request parameters, can be edited either by the `Edit SAM Debug Configuration` command (through the Command Palette or CodeLens) or by editing the `launch.json` file.
AWS Lambda functions not defined in the [`template.yaml`](./template.yaml) file can be invoked and debugged by creating a launch configuration through the CodeLens over the function declaration, or with the `Add SAM Debug Configuration` command.
## Deploying Serverless Applications
You can deploy a serverless application by invoking the `AWS: Deploy SAM application` command through the Command Palette or by right-clicking the Lambda node in the AWS Explorer and entering the deployment region, a valid S3 bucket from the region, and the name of a CloudFormation stack to deploy to. You can monitor your deployment's progress through the `AWS Toolkit` Output Channel.
## Interacting With Deployed Serverless Applications
A successfully-deployed serverless application can be found in the AWS Explorer under region and CloudFormation node that the serverless application was deployed to.
In the AWS Explorer, you can invoke _remote_ AWS Lambda Functions by right-clicking the Lambda node and selecting "Invoke on AWS".
Similarly, if the Function declaration contained an API Gateway event, the API Gateway API can be found in the API Gateway node under the region node the serverless application was deployed to, and can be invoked via right-clicking the API node and selecting "Invoke on AWS".
## Resources
General information about this SAM project can be found in the [`README.md`](./README.md) file in this folder.
More information about using the AWS Toolkit For Visual Studio Code with serverless applications can be found [in the AWS documentation](https://docs.aws.amazon.com/toolkit-for-vscode/latest/userguide/serverless-apps.html) .

View File

@@ -0,0 +1,127 @@
# lambda-nodejs18.x
This project contains source code and supporting files for a serverless application that you can deploy with the SAM CLI. It includes the following files and folders.
- hello-world - Code for the application's Lambda function.
- events - Invocation events that you can use to invoke the function.
- hello-world/tests - Unit tests for the application code.
- template.yaml - A template that defines the application's AWS resources.
The application uses several AWS resources, including Lambda functions and an API Gateway API. These resources are defined in the `template.yaml` file in this project. You can update the template to add AWS resources through the same deployment process that updates your application code.
If you prefer to use an integrated development environment (IDE) to build and test your application, you can use the AWS Toolkit.
The AWS Toolkit is an open source plug-in for popular IDEs that uses the SAM CLI to build and deploy serverless applications on AWS. The AWS Toolkit also adds a simplified step-through debugging experience for Lambda function code. See the following links to get started.
* [CLion](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
* [GoLand](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
* [IntelliJ](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
* [WebStorm](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
* [Rider](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
* [PhpStorm](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
* [PyCharm](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
* [RubyMine](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
* [DataGrip](https://docs.aws.amazon.com/toolkit-for-jetbrains/latest/userguide/welcome.html)
* [VS Code](https://docs.aws.amazon.com/toolkit-for-vscode/latest/userguide/welcome.html)
* [Visual Studio](https://docs.aws.amazon.com/toolkit-for-visual-studio/latest/user-guide/welcome.html)
## Deploy the sample application
The Serverless Application Model Command Line Interface (SAM CLI) is an extension of the AWS CLI that adds functionality for building and testing Lambda applications. It uses Docker to run your functions in an Amazon Linux environment that matches Lambda. It can also emulate your application's build environment and API.
To use the SAM CLI, you need the following tools.
* SAM CLI - [Install the SAM CLI](https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-cli-install.html)
* Node.js - [Install Node.js 18](https://nodejs.org/en/), including the NPM package management tool.
* Docker - [Install Docker community edition](https://hub.docker.com/search/?type=edition&offering=community)
To build and deploy your application for the first time, run the following in your shell:
```bash
sam build
sam deploy --guided
```
The first command will build the source of your application. The second command will package and deploy your application to AWS, with a series of prompts:
* **Stack Name**: The name of the stack to deploy to CloudFormation. This should be unique to your account and region, and a good starting point would be something matching your project name.
* **AWS Region**: The AWS region you want to deploy your app to.
* **Confirm changes before deploy**: If set to yes, any change sets will be shown to you before execution for manual review. If set to no, the AWS SAM CLI will automatically deploy application changes.
* **Allow SAM CLI IAM role creation**: Many AWS SAM templates, including this example, create AWS IAM roles required for the AWS Lambda function(s) included to access AWS services. By default, these are scoped down to minimum required permissions. To deploy an AWS CloudFormation stack which creates or modifies IAM roles, the `CAPABILITY_IAM` value for `capabilities` must be provided. If permission isn't provided through this prompt, to deploy this example you must explicitly pass `--capabilities CAPABILITY_IAM` to the `sam deploy` command.
* **Save arguments to samconfig.toml**: If set to yes, your choices will be saved to a configuration file inside the project, so that in the future you can just re-run `sam deploy` without parameters to deploy changes to your application.
You can find your API Gateway Endpoint URL in the output values displayed after deployment.
## Use the SAM CLI to build and test locally
Build your application with the `sam build` command.
```bash
lambda-nodejs18.x$ sam build
```
The SAM CLI installs dependencies defined in `hello-world/package.json`, creates a deployment package, and saves it in the `.aws-sam/build` folder.
Test a single function by invoking it directly with a test event. An event is a JSON document that represents the input that the function receives from the event source. Test events are included in the `events` folder in this project.
Run functions locally and invoke them with the `sam local invoke` command.
```bash
lambda-nodejs18.x$ sam local invoke HelloWorldFunction --event events/event.json
```
The SAM CLI can also emulate your application's API. Use the `sam local start-api` to run the API locally on port 3000.
```bash
lambda-nodejs18.x$ sam local start-api
lambda-nodejs18.x$ curl http://localhost:3000/
```
The SAM CLI reads the application template to determine the API's routes and the functions that they invoke. The `Events` property on each function's definition includes the route and method for each path.
```yaml
Events:
HelloWorld:
Type: Api
Properties:
Path: /hello
Method: get
```
## Add a resource to your application
The application template uses AWS Serverless Application Model (AWS SAM) to define application resources. AWS SAM is an extension of AWS CloudFormation with a simpler syntax for configuring common serverless application resources such as functions, triggers, and APIs. For resources not included in [the SAM specification](https://github.com/awslabs/serverless-application-model/blob/master/versions/2016-10-31.md), you can use standard [AWS CloudFormation](https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-template-resource-type-ref.html) resource types.
## Fetch, tail, and filter Lambda function logs
To simplify troubleshooting, SAM CLI has a command called `sam logs`. `sam logs` lets you fetch logs generated by your deployed Lambda function from the command line. In addition to printing the logs on the terminal, this command has several nifty features to help you quickly find the bug.
`NOTE`: This command works for all AWS Lambda functions; not just the ones you deploy using SAM.
```bash
lambda-nodejs18.x$ sam logs -n HelloWorldFunction --stack-name lambda-nodejs18.x --tail
```
You can find more information and examples about filtering Lambda function logs in the [SAM CLI Documentation](https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-cli-logging.html).
## Unit tests
Tests are defined in the `hello-world/tests` folder in this project. Use NPM to install the [Mocha test framework](https://mochajs.org/) and run unit tests.
```bash
lambda-nodejs18.x$ cd hello-world
hello-world$ npm install
hello-world$ npm run test
```
## Cleanup
To delete the sample application that you created, use the AWS CLI. Assuming you used your project name for the stack name, you can run the following:
```bash
sam delete --stack-name lambda-nodejs18.x
```
## Resources
See the [AWS SAM developer guide](https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/what-is-sam.html) for an introduction to SAM specification, the SAM CLI, and serverless application concepts.
Next, you can use AWS Serverless Application Repository to deploy ready to use Apps that go beyond hello world samples and learn how authors developed their applications: [AWS Serverless Application Repository main page](https://aws.amazon.com/serverless/serverlessrepo/)

View File

@@ -0,0 +1,71 @@
const AWS = require('aws-sdk');
const s3 = new AWS.S3();
const axios = require('axios');
async function getTopics(stream_id) {
const response = await axios.get(`https://${process.env.ZULIP_REALM}/api/v1/users/me/${stream_id}/topics`, {
auth: {
username: process.env.ZULIP_BOT_EMAIL || "?",
password: process.env.ZULIP_API_KEY || "?"
}
});
return response.data.topics.map(topic => topic.name);
}
async function getStreams() {
const response = await axios.get(`https://${process.env.ZULIP_REALM}/api/v1/streams`, {
auth: {
username: process.env.ZULIP_BOT_EMAIL || "?",
password: process.env.ZULIP_API_KEY || "?"
}
});
const streams = [];
for (const stream of response.data.streams) {
console.log("Loading topics for " + stream.name);
const topics = await getTopics(stream.stream_id);
streams.push({ id: stream.stream_id, name: stream.name, topics });
}
return streams;
}
const handler = async (event) => {
const streams = await getStreams();
// Convert the streams to JSON
const json_data = JSON.stringify(streams);
const bucketName = process.env.S3BUCKET_NAME;
const fileName = process.env.S3BUCKET_FILE_NAME;
// Parameters for S3 upload
const params = {
Bucket: bucketName,
Key: fileName,
Body: json_data,
ContentType: 'application/json'
};
try {
// Write the JSON data to S3
await s3.putObject(params).promise();
return {
statusCode: 200,
body: JSON.stringify('File written to S3 successfully')
};
} catch (error) {
console.error('Error writing to S3:', error);
return {
statusCode: 500,
body: JSON.stringify('Error writing file to S3')
};
}
};
module.exports = { handler };

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,16 @@
{
"name": "updatezulipstreams",
"version": "1.0.0",
"description": "Updates the JSON with the zulip streams and topics on S3",
"main": "index.js",
"author": "Andreas Bonini",
"license": "All rights reserved",
"dependencies": {
"aws-sdk": "^2.1498.0",
"axios": "^1.6.2"
},
"devDependencies": {
"chai": "^4.3.6",
"mocha": "^10.1.0"
}
}

View File

@@ -0,0 +1,62 @@
{
"body": "{\"message\": \"hello world\"}",
"resource": "/{proxy+}",
"path": "/path/to/resource",
"httpMethod": "POST",
"isBase64Encoded": false,
"queryStringParameters": {
"foo": "bar"
},
"pathParameters": {
"proxy": "/path/to/resource"
},
"stageVariables": {
"baz": "qux"
},
"headers": {
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
"Accept-Encoding": "gzip, deflate, sdch",
"Accept-Language": "en-US,en;q=0.8",
"Cache-Control": "max-age=0",
"CloudFront-Forwarded-Proto": "https",
"CloudFront-Is-Desktop-Viewer": "true",
"CloudFront-Is-Mobile-Viewer": "false",
"CloudFront-Is-SmartTV-Viewer": "false",
"CloudFront-Is-Tablet-Viewer": "false",
"CloudFront-Viewer-Country": "US",
"Host": "1234567890.execute-api.us-east-1.amazonaws.com",
"Upgrade-Insecure-Requests": "1",
"User-Agent": "Custom User Agent String",
"Via": "1.1 08f323deadbeefa7af34d5feb414ce27.cloudfront.net (CloudFront)",
"X-Amz-Cf-Id": "cDehVQoZnx43VYQb9j2-nvCh-9z396Uhbp027Y2JvkCPNLmGJHqlaA==",
"X-Forwarded-For": "127.0.0.1, 127.0.0.2",
"X-Forwarded-Port": "443",
"X-Forwarded-Proto": "https"
},
"requestContext": {
"accountId": "123456789012",
"resourceId": "123456",
"stage": "prod",
"requestId": "c6af9ac6-7b61-11e6-9a41-93e8deadbeef",
"requestTime": "09/Apr/2015:12:34:56 +0000",
"requestTimeEpoch": 1428582896000,
"identity": {
"cognitoIdentityPoolId": null,
"accountId": null,
"cognitoIdentityId": null,
"caller": null,
"accessKey": null,
"sourceIp": "127.0.0.1",
"cognitoAuthenticationType": null,
"cognitoAuthenticationProvider": null,
"userArn": null,
"userAgent": "Custom User Agent String",
"user": null
},
"path": "/prod/path/to/resource",
"resourcePath": "/{proxy+}",
"httpMethod": "POST",
"apiId": "1234567890",
"protocol": "HTTP/1.1"
}
}

View File

@@ -0,0 +1,31 @@
# More information about the configuration file can be found here:
# https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-cli-config.html
version = 0.1
[default]
[default.global.parameters]
stack_name = "lambda-nodejs18.x"
[default.build.parameters]
cached = true
parallel = true
[default.validate.parameters]
lint = true
[default.deploy.parameters]
capabilities = "CAPABILITY_IAM"
confirm_changeset = true
resolve_s3 = true
[default.package.parameters]
resolve_s3 = true
[default.sync.parameters]
watch = true
[default.local_start_api.parameters]
warm_containers = "EAGER"
[default.local_start_lambda.parameters]
warm_containers = "EAGER"

View File

@@ -0,0 +1,41 @@
AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Description: >
lambda-nodejs18.x
Sample SAM Template for lambda-nodejs18.x
# More info about Globals: https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst
Globals:
Function:
Timeout: 3
Resources:
HelloWorldFunction:
Type: AWS::Serverless::Function # More info about Function Resource: https://github.com/awslabs/serverless-application-model/blob/master/versions/2016-10-31.md#awsserverlessfunction
Properties:
CodeUri: hello-world/
Handler: app.lambdaHandler
Runtime: nodejs18.x
Architectures:
- arm64
Events:
HelloWorld:
Type: Api # More info about API Event Source: https://github.com/awslabs/serverless-application-model/blob/master/versions/2016-10-31.md#api
Properties:
Path: /hello
Method: get
Outputs:
# ServerlessRestApi is an implicit API created out of Events key under Serverless::Function
# Find out more about other implicit resources you can reference within SAM
# https://github.com/awslabs/serverless-application-model/blob/master/docs/internals/generated_resources.rst#api
HelloWorldApi:
Description: "API Gateway endpoint URL for Prod stage for Hello World function"
Value: !Sub "https://${ServerlessRestApi}.execute-api.${AWS::Region}.amazonaws.com/Prod/hello/"
HelloWorldFunction:
Description: "Hello World Lambda Function ARN"
Value: !GetAtt HelloWorldFunction.Arn
HelloWorldFunctionIamRole:
Description: "Implicit IAM Role created for Hello World function"
Value: !GetAtt HelloWorldFunctionRole.Arn

View File

@@ -5,10 +5,19 @@ services:
context: server
ports:
- 1250:1250
environment:
LLM_URL: "${LLM_URL}"
volumes:
- model-cache:/root/.cache
environment: ENTRYPOINT=server
worker:
build:
context: server
volumes:
- model-cache:/root/.cache
environment: ENTRYPOINT=worker
redis:
image: redis:7.2
ports:
- 6379:6379
web:
build:
context: www
@@ -17,4 +26,3 @@ services:
volumes:
model-cache:

20
server/.env_template Normal file
View File

@@ -0,0 +1,20 @@
TRANSCRIPT_BACKEND=modal
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
TRANSCRIPT_MODAL_API_KEY=***REMOVED***
LLM_BACKEND=modal
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
LLM_MODAL_API_KEY=<ask in zulip>
AUTH_BACKEND=fief
AUTH_FIEF_URL=https://auth.reflector.media/reflector-local
AUTH_FIEF_CLIENT_ID=***REMOVED***
AUTH_FIEF_CLIENT_SECRET=<ask in zulip> <-----------------------------------------------------------------------------------------
TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
BASE_URL=https://xxxxx.ngrok.app
DIARIZATION_ENABLED=false

2
server/.gitignore vendored
View File

@@ -178,3 +178,5 @@ audio_*.wav
# ignore local database
reflector.sqlite3
data/
dump.rdb

View File

@@ -1 +1 @@
3.11
3.11.6

View File

@@ -5,11 +5,23 @@ services:
context: .
ports:
- 1250:1250
environment:
LLM_URL: "${LLM_URL}"
MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
volumes:
- model-cache:/root/.cache
environment:
ENTRYPOINT: server
REDIS_HOST: redis
worker:
build:
context: .
volumes:
- model-cache:/root/.cache
environment:
ENTRYPOINT: worker
REDIS_HOST: redis
redis:
image: redis:7.2
ports:
- 6379:6379
volumes:
model-cache:

View File

@@ -51,17 +51,6 @@
#TRANSLATE_URL=https://xxxxx--reflector-translator-web.modal.run
#TRANSCRIPT_MODAL_API_KEY=xxxxx
## Using serverless banana.dev (require reflector-gpu-banana deployed)
## XXX this service is buggy do not use at the moment
## XXX it also require the audio to be saved to S3
#TRANSCRIPT_BACKEND=banana
#TRANSCRIPT_URL=https://reflector-gpu-banana-xxxxx.run.banana.dev
#TRANSCRIPT_BANANA_API_KEY=xxx
#TRANSCRIPT_BANANA_MODEL_KEY=xxx
#TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID=xxx
#TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY=xxx
#TRANSCRIPT_STORAGE_AWS_BUCKET_NAME="reflector-bucket/chunks"
## =======================================================
## LLM backend
##
@@ -78,13 +67,6 @@
#LLM_URL=https://xxxxxx--reflector-llm-web.modal.run
#LLM_MODAL_API_KEY=xxx
## Using serverless banana.dev (require reflector-gpu-banana deployed)
## XXX this service is buggy do not use at the moment
#LLM_BACKEND=banana
#LLM_URL=https://reflector-gpu-banana-xxxxx.run.banana.dev
#LLM_BANANA_API_KEY=xxxxx
#LLM_BANANA_MODEL_KEY=xxxxx
## Using OpenAI
#LLM_BACKEND=openai
#LLM_OPENAI_KEY=xxx

View File

@@ -0,0 +1,188 @@
"""
Reflector GPU backend - diarizer
===================================
"""
import os
import modal.gpu
from modal import Image, Secret, Stub, asgi_app, method
from pydantic import BaseModel
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.0"
MODEL_DIR = "/root/diarization_models"
stub = Stub(name="reflector-diarizer")
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
print("LLM cache moved")
def download_pyannote_audio():
from pyannote.audio import Pipeline
Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
cache_dir=MODEL_DIR,
use_auth_token="***REMOVED***"
)
diarizer_image = (
Image.debian_slim(python_version="3.10.8")
.pip_install(
"pyannote.audio",
"requests",
"onnx",
"torchaudio",
"onnxruntime-gpu",
"torch==2.0.0",
"transformers==4.34.0",
"sentencepiece",
"protobuf",
"numpy",
"huggingface_hub",
"hf-transfer"
)
.run_function(migrate_cache_llm)
.run_function(download_pyannote_audio)
.env(
{
"LD_LIBRARY_PATH": (
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
)
}
)
)
@stub.cls(
gpu=modal.gpu.A100(memory=40),
timeout=60 * 30,
container_idle_timeout=60,
allow_concurrent_inputs=1,
image=diarizer_image,
)
class Diarizer:
def __enter__(self):
import torch
from pyannote.audio import Pipeline
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
self.diarization_pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
cache_dir=MODEL_DIR
)
self.diarization_pipeline.to(torch.device(self.device))
@method()
def diarize(
self,
audio_data: str,
audio_suffix: str,
timestamp: float
):
import tempfile
import torchaudio
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
fp.write(audio_data)
print("Diarizing audio")
waveform, sample_rate = torchaudio.load(fp.name)
diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate})
words = []
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
words.append(
{
"start": round(timestamp + diarization_segment.start, 3),
"end": round(timestamp + diarization_segment.end, 3),
"speaker": int(speaker[-2:])
}
)
print("Diarization complete")
return {
"diarization": words
}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@stub.function(
timeout=60 * 10,
container_idle_timeout=60 * 3,
allow_concurrent_inputs=40,
secrets=[
Secret.from_name("reflector-gpu"),
],
image=diarizer_image
)
@asgi_app()
def web():
import requests
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
diarizerstub = Diarizer()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
def validate_audio_file(audio_file_url: str):
# Check if the audio file exists
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(
status_code=response.status_code,
detail="The audio file does not exist."
)
class DiarizationResponse(BaseModel):
result: dict
@app.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)])
def diarize(
audio_file_url: str,
timestamp: float = 0.0
) -> HTTPException | DiarizationResponse:
# Currently the uploaded files are in mp3 format
audio_suffix = "mp3"
print("Downloading audio file")
response = requests.get(audio_file_url, allow_redirects=True)
print("Audio file downloaded successfully")
func = diarizerstub.diarize.spawn(
audio_data=response.content,
audio_suffix=audio_suffix,
timestamp=timestamp
)
result = func.get()
return result
return app

View File

@@ -81,7 +81,8 @@ class LLM:
LLM_MODEL,
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
cache_dir=IMAGE_MODEL_DIR
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
# JSONFormer doesn't yet support generation configs
@@ -96,7 +97,8 @@ class LLM:
print("Instance llm tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
LLM_MODEL,
cache_dir=IMAGE_MODEL_DIR
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
# move model to gpu

View File

@@ -17,7 +17,7 @@ LLM_LOW_CPU_MEM_USAGE: bool = True
LLM_TORCH_DTYPE: str = "bfloat16"
LLM_MAX_NEW_TOKENS: int = 300
IMAGE_MODEL_DIR = "/root/llm_models"
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"
stub = Stub(name="reflector-llm-zephyr")
@@ -81,7 +81,8 @@ class LLM:
LLM_MODEL,
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
cache_dir=IMAGE_MODEL_DIR
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
# JSONFormer doesn't yet support generation configs
@@ -96,7 +97,8 @@ class LLM:
print("Instance llm tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
LLM_MODEL,
cache_dir=IMAGE_MODEL_DIR
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
gen_cfg.pad_token_id = tokenizer.eos_token_id
gen_cfg.eos_token_id = tokenizer.eos_token_id

View File

@@ -95,7 +95,8 @@ class Transcriber:
device=self.device,
compute_type=WHISPER_COMPUTE_TYPE,
num_workers=WHISPER_NUM_WORKERS,
download_root=WHISPER_MODEL_DIR
download_root=WHISPER_MODEL_DIR,
local_files_only=True
)
@method()

View File

@@ -0,0 +1,33 @@
"""add share_mode
Revision ID: 0fea6d96b096
Revises: f819277e5169
Create Date: 2023-11-07 11:12:21.614198
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "0fea6d96b096"
down_revision: Union[str, None] = "f819277e5169"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"transcript",
sa.Column("share_mode", sa.String(), server_default="private", nullable=False),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transcript", "share_mode")
# ### end Alembic commands ###

View File

@@ -0,0 +1,30 @@
"""participants
Revision ID: 125031f7cb78
Revises: 0fea6d96b096
Create Date: 2023-11-30 15:56:03.341466
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '125031f7cb78'
down_revision: Union[str, None] = '0fea6d96b096'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('transcript', sa.Column('participants', sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('transcript', 'participants')
# ### end Alembic commands ###

View File

@@ -0,0 +1,80 @@
"""rename back text to transcript
Revision ID: 38a927dcb099
Revises: 9920ecfe2735
Create Date: 2023-11-02 19:53:09.116240
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column
from sqlalchemy import select
# revision identifiers, used by Alembic.
revision: str = '38a927dcb099'
down_revision: Union[str, None] = '9920ecfe2735'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# bind the engine
bind = op.get_bind()
# Reflect the table
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]
topics_json = row["topics"]
# Process each topic in the topics JSON array
updated_topics = []
for topic in topics_json:
if "text" in topic:
# Rename key 'text' back to 'transcript'
topic["transcript"] = topic.pop("text")
updated_topics.append(topic)
# Update the transcript table
bind.execute(
transcript.update()
.where(transcript.c.id == transcript_id)
.values(topics=updated_topics)
)
def downgrade() -> None:
# bind the engine
bind = op.get_bind()
# Reflect the table
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]
topics_json = row["topics"]
# Process each topic in the topics JSON array
updated_topics = []
for topic in topics_json:
if "transcript" in topic:
# Rename key 'transcript' to 'text'
topic["text"] = topic.pop("transcript")
updated_topics.append(topic)
# Update the transcript table
bind.execute(
transcript.update()
.where(transcript.c.id == transcript_id)
.values(topics=updated_topics)
)

View File

@@ -0,0 +1,64 @@
"""fix duration
Revision ID: 4814901632bc
Revises: 38a927dcb099
Create Date: 2023-11-10 18:12:17.886522
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column
from sqlalchemy import select
# revision identifiers, used by Alembic.
revision: str = "4814901632bc"
down_revision: Union[str, None] = "38a927dcb099"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# for all the transcripts, calculate the duration from the mp3
# and update the duration column
from pathlib import Path
from reflector.settings import settings
import av
bind = op.get_bind()
transcript = table(
"transcript", column("id", sa.String), column("duration", sa.Float)
)
# select only the one with duration = 0
results = bind.execute(
select([transcript.c.id, transcript.c.duration]).where(
transcript.c.duration == 0
)
)
data_dir = Path(settings.DATA_DIR)
for row in results:
audio_path = data_dir / row["id"] / "audio.mp3"
if not audio_path.exists():
continue
try:
print(f"Processing {audio_path}")
container = av.open(audio_path.as_posix())
print(container.duration)
duration = round(float(container.duration / av.time_base), 2)
print(f"Duration: {duration}")
bind.execute(
transcript.update()
.where(transcript.c.id == row["id"])
.values(duration=duration)
)
except Exception as e:
print(f"Failed to process {audio_path}: {e}")
def downgrade() -> None:
pass

View File

@@ -0,0 +1,80 @@
"""Migration transcript to text field in transcripts table
Revision ID: 9920ecfe2735
Revises: 99365b0cd87b
Create Date: 2023-11-02 18:55:17.019498
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column
from sqlalchemy import select
# revision identifiers, used by Alembic.
revision: str = "9920ecfe2735"
down_revision: Union[str, None] = "99365b0cd87b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# bind the engine
bind = op.get_bind()
# Reflect the table
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]
topics_json = row["topics"]
# Process each topic in the topics JSON array
updated_topics = []
for topic in topics_json:
if "transcript" in topic:
# Rename key 'transcript' to 'text'
topic["text"] = topic.pop("transcript")
updated_topics.append(topic)
# Update the transcript table
bind.execute(
transcript.update()
.where(transcript.c.id == transcript_id)
.values(topics=updated_topics)
)
def downgrade() -> None:
# bind the engine
bind = op.get_bind()
# Reflect the table
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]
topics_json = row["topics"]
# Process each topic in the topics JSON array
updated_topics = []
for topic in topics_json:
if "text" in topic:
# Rename key 'text' back to 'transcript'
topic["transcript"] = topic.pop("text")
updated_topics.append(topic)
# Update the transcript table
bind.execute(
transcript.update()
.where(transcript.c.id == transcript_id)
.values(topics=updated_topics)
)

View File

@@ -0,0 +1,29 @@
"""reviewed
Revision ID: b9348748bbbc
Revises: 125031f7cb78
Create Date: 2023-12-13 15:37:51.303970
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b9348748bbbc'
down_revision: Union[str, None] = '125031f7cb78'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('transcript', sa.Column('reviewed', sa.Boolean(), server_default=sa.text('0'), nullable=False))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('transcript', 'reviewed')
# ### end Alembic commands ###

View File

@@ -0,0 +1,35 @@
"""audio_location
Revision ID: f819277e5169
Revises: 4814901632bc
Create Date: 2023-11-16 10:29:09.351664
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f819277e5169"
down_revision: Union[str, None] = "4814901632bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"transcript",
sa.Column(
"audio_location", sa.String(), server_default="local", nullable=False
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transcript", "audio_location")
# ### end Alembic commands ###

1208
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -8,12 +8,11 @@ packages = []
[tool.poetry.dependencies]
python = "^3.11"
aiohttp = "^3.8.5"
aiohttp = "^3.9.0"
aiohttp-cors = "^0.7.0"
av = "^10.0.0"
requests = "^2.31.0"
aiortc = "^1.5.0"
faster-whisper = "^0.7.1"
sortedcontainers = "^2.4.0"
loguru = "^0.7.0"
pydantic-settings = "^2.0.2"
@@ -28,16 +27,22 @@ sqlalchemy = "<1.5"
fief-client = {extras = ["fastapi"], version = "^0.17.0"}
alembic = "^1.11.3"
nltk = "^3.8.1"
transformers = "^4.32.1"
prometheus-fastapi-instrumentator = "^6.1.0"
sentencepiece = "^0.1.99"
protobuf = "^4.24.3"
profanityfilter = "^2.0.6"
celery = "^5.3.4"
redis = "^5.0.1"
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
python-multipart = "^0.0.6"
faster-whisper = "^0.10.0"
transformers = "^4.36.2"
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
stamina = "^23.1.0"
pyinstrument = "^4.6.1"
[tool.poetry.group.tests.dependencies]
@@ -47,6 +52,7 @@ pytest-asyncio = "^0.21.1"
pytest = "^7.4.0"
httpx-ws = "^0.4.1"
pytest-httpx = "^0.23.1"
pytest-celery = "^0.0.0"
[tool.poetry.group.aws.dependencies]

View File

@@ -13,6 +13,14 @@ from reflector.metrics import metrics_init
from reflector.settings import settings
from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router
from reflector.views.transcripts_audio import router as transcripts_audio_router
from reflector.views.transcripts_participants import (
router as transcripts_participants_router,
)
from reflector.views.transcripts_speaker import router as transcripts_speaker_router
from reflector.views.transcripts_upload import router as transcripts_upload_router
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
from reflector.views.transcripts_websocket import router as transcripts_websocket_router
from reflector.views.user import router as user_router
try:
@@ -41,7 +49,6 @@ if settings.SENTRY_DSN:
else:
logger.info("Sentry disabled")
# build app
app = FastAPI(lifespan=lifespan)
app.add_middleware(
@@ -61,9 +68,18 @@ metrics_init(app, instrumentator)
# register views
app.include_router(rtc_offer_router)
app.include_router(transcripts_router, prefix="/v1")
app.include_router(transcripts_audio_router, prefix="/v1")
app.include_router(transcripts_participants_router, prefix="/v1")
app.include_router(transcripts_speaker_router, prefix="/v1")
app.include_router(transcripts_upload_router, prefix="/v1")
app.include_router(transcripts_websocket_router, prefix="/v1")
app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")
add_pagination(app)
# prepare celery
from reflector.worker import app as celery_app # noqa
# simpler openapi id
def use_route_names_as_operation_ids(app: FastAPI) -> None:
@@ -84,7 +100,10 @@ def use_route_names_as_operation_ids(app: FastAPI) -> None:
version = None
if route.path.startswith("/v"):
version = route.path.split("/")[1]
opid = f"{version}_{route.name}"
if route.operation_id is not None:
opid = f"{version}_{route.operation_id}"
else:
opid = f"{version}_{route.name}"
else:
opid = route.name
@@ -94,11 +113,28 @@ def use_route_names_as_operation_ids(app: FastAPI) -> None:
"Please rename the route or the view function."
)
route.operation_id = opid
ensure_uniq_operation_ids.add(route.name)
ensure_uniq_operation_ids.add(opid)
use_route_names_as_operation_ids(app)
if settings.PROFILING:
from fastapi import Request
from fastapi.responses import HTMLResponse
from pyinstrument import Profiler
@app.middleware("http")
async def profile_request(request: Request, call_next):
profiling = request.query_params.get("profile", False)
if profiling:
profiler = Profiler(async_mode="enabled")
profiler.start()
await call_next(request)
profiler.stop()
return HTMLResponse(profiler.output_html())
else:
return await call_next(request)
if __name__ == "__main__":
import uvicorn

View File

@@ -1,32 +1,13 @@
import databases
import sqlalchemy
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings
database = databases.Database(settings.DATABASE_URL)
metadata = sqlalchemy.MetaData()
transcripts = sqlalchemy.Table(
"transcript",
metadata,
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String),
sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Integer),
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
)
# import models
import reflector.db.transcripts # noqa
engine = sqlalchemy.create_engine(
settings.DATABASE_URL, connect_args={"check_same_thread": False}

View File

@@ -0,0 +1,507 @@
import json
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Literal
from uuid import uuid4
import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import Storage
from sqlalchemy.sql import false
transcripts = sqlalchemy.Table(
"transcript",
metadata,
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String),
sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Integer),
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("participants", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column(
"reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
),
sqlalchemy.Column(
"audio_location",
sqlalchemy.String,
nullable=False,
server_default="local",
),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
sqlalchemy.Column(
"share_mode",
sqlalchemy.String,
nullable=False,
server_default="private",
),
)
def generate_uuid4() -> str:
return str(uuid4())
def generate_transcript_name() -> str:
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
def get_storage() -> Storage:
return Storage.get_instance(
name=settings.TRANSCRIPT_STORAGE_BACKEND,
settings_prefix="TRANSCRIPT_STORAGE_",
)
class AudioWaveform(BaseModel):
data: list[float]
class TranscriptText(BaseModel):
text: str
translation: str | None
class TranscriptSegmentTopic(BaseModel):
speaker: int
text: str
timestamp: float
class TranscriptTopic(BaseModel):
id: str = Field(default_factory=generate_uuid4)
title: str
summary: str
timestamp: float
duration: float | None = 0
transcript: str | None = None
words: list[ProcessorWord] = []
class TranscriptFinalShortSummary(BaseModel):
short_summary: str
class TranscriptFinalLongSummary(BaseModel):
long_summary: str
class TranscriptFinalTitle(BaseModel):
title: str
class TranscriptDuration(BaseModel):
duration: float
class TranscriptWaveform(BaseModel):
waveform: list[float]
class TranscriptEvent(BaseModel):
event: str
data: dict
class TranscriptParticipant(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4)
speaker: int | None
name: str
class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=datetime.utcnow)
title: str | None = None
short_summary: str | None = None
long_summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
participants: list[TranscriptParticipant] | None = []
source_language: str = "en"
target_language: str = "en"
share_mode: Literal["private", "semi-private", "public"] = "private"
audio_location: str = "local"
reviewed: bool = False
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
self.events.append(ev)
return ev
def upsert_topic(self, topic: TranscriptTopic):
index = next((i for i, t in enumerate(self.topics) if t.id == topic.id), None)
if index is not None:
self.topics[index] = topic
else:
self.topics.append(topic)
def upsert_participant(self, participant: TranscriptParticipant):
index = next(
(i for i, p in enumerate(self.participants) if p.id == participant.id),
None,
)
if index is not None:
self.participants[index] = participant
else:
self.participants.append(participant)
return participant
def delete_participant(self, participant_id: str):
index = next(
(i for i, p in enumerate(self.participants) if p.id == participant_id),
None,
)
if index is not None:
del self.participants[index]
def events_dump(self, mode="json"):
return [event.model_dump(mode=mode) for event in self.events]
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
def participants_dump(self, mode="json"):
return [participant.model_dump(mode=mode) for participant in self.participants]
def unlink(self):
self.data_path.unlink(missing_ok=True)
@property
def data_path(self):
return Path(settings.DATA_DIR) / self.id
@property
def audio_wav_filename(self):
return self.data_path / "audio.wav"
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@property
def audio_waveform_filename(self):
return self.data_path / "audio.json"
@property
def storage_audio_path(self):
return f"{self.id}/audio.mp3"
@property
def audio_waveform(self):
try:
with open(self.audio_waveform_filename) as fd:
data = json.load(fd)
except json.JSONDecodeError:
# unlink file if it's corrupted
self.audio_waveform_filename.unlink(missing_ok=True)
return None
return AudioWaveform(data=data)
async def get_audio_url(self) -> str:
if self.audio_location == "local":
return self._generate_local_audio_link()
elif self.audio_location == "storage":
return await self._generate_storage_audio_link()
raise Exception(f"Unknown audio location {self.audio_location}")
async def _generate_storage_audio_link(self) -> str:
return await get_storage().get_file_url(self.storage_audio_path)
def _generate_local_audio_link(self) -> str:
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from datetime import timedelta
from reflector.app import app
from reflector.views.transcripts import create_access_token
path = app.url_path_for(
"transcript_get_audio_mp3",
transcript_id=self.id,
)
url = f"{settings.BASE_URL}{path}"
if self.user_id:
# we pass token only if the user_id is set
# otherwise, the audio is public
token = create_access_token(
{"sub": self.user_id},
expires_delta=timedelta(minutes=15),
)
url += f"?token={token}"
return url
def find_empty_speaker(self) -> int:
"""
Find an empty speaker seat
"""
speakers = set(
word.speaker
for topic in self.topics
for word in topic.words
if word.speaker is not None
)
i = 0
while True:
if i not in speakers:
return i
i += 1
raise Exception("No empty speaker found")
class TranscriptController:
async def get_all(
self,
user_id: str | None = None,
order_by: str | None = None,
filter_empty: bool | None = False,
filter_recording: bool | None = False,
return_query: bool = False,
) -> list[Transcript]:
"""
Get all transcripts
If `user_id` is specified, only return transcripts that belong to the user.
Otherwise, return all anonymous transcripts.
Parameters:
- `order_by`: field to order by, e.g. "-created_at"
- `filter_empty`: filter out empty transcripts
- `filter_recording`: filter out transcripts that are currently recording
"""
query = transcripts.select().where(transcripts.c.user_id == user_id)
if order_by is not None:
field = getattr(transcripts.c, order_by[1:])
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
if filter_empty:
query = query.filter(transcripts.c.status != "idle")
if filter_recording:
query = query.filter(transcripts.c.status != "recording")
if return_query:
return query
results = await database.fetch_all(query)
return results
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
"""
Get a transcript by id
"""
query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
if not result:
return None
return Transcript(**result)
async def get_by_id_for_http(
self,
transcript_id: str,
user_id: str | None,
) -> Transcript:
"""
Get a transcript by ID for HTTP request.
If not found, it will raise a 404 error.
If the user is not allowed to access the transcript, it will raise a 403 error.
This method checks the share mode of the transcript and the user_id
to determine if the user can access the transcript.
"""
query = transcripts.select().where(transcripts.c.id == transcript_id)
result = await database.fetch_one(query)
if not result:
raise HTTPException(status_code=404, detail="Transcript not found")
# if the transcript is anonymous, share mode is not checked
transcript = Transcript(**result)
if transcript.user_id is None:
return transcript
if transcript.share_mode == "private":
# in private mode, only the owner can access the transcript
if transcript.user_id == user_id:
return transcript
elif transcript.share_mode == "semi-private":
# in semi-private mode, only the owner and the users with the link
# can access the transcript
if user_id is not None:
return transcript
elif transcript.share_mode == "public":
# in public mode, everyone can access the transcript
return transcript
raise HTTPException(status_code=403, detail="Transcript access denied")
async def add(
self,
name: str,
source_language: str = "en",
target_language: str = "en",
user_id: str | None = None,
):
"""
Add a new transcript
"""
transcript = Transcript(
name=name,
source_language=source_language,
target_language=target_language,
user_id=user_id,
)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
return transcript
async def update(self, transcript: Transcript, values: dict, mutate=True):
"""
Update a transcript fields with key/values in values
"""
query = (
transcripts.update()
.where(transcripts.c.id == transcript.id)
.values(**values)
)
await database.execute(query)
if mutate:
for key, value in values.items():
setattr(transcript, key, value)
async def remove_by_id(
self,
transcript_id: str,
user_id: str | None = None,
) -> None:
"""
Remove a transcript by id
"""
transcript = await self.get_by_id(transcript_id, user_id=user_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
@asynccontextmanager
async def transaction(self):
"""
A context manager for database transaction
"""
async with database.transaction(isolation="serializable"):
yield
async def append_event(
self,
transcript: Transcript,
event: str,
data: Any,
) -> TranscriptEvent:
"""
Append an event to a transcript
"""
resp = transcript.add_event(event=event, data=data)
await self.update(
transcript,
{"events": transcript.events_dump()},
mutate=False,
)
return resp
async def upsert_topic(
self,
transcript: Transcript,
topic: TranscriptTopic,
) -> TranscriptEvent:
"""
Append an event to a transcript
"""
transcript.upsert_topic(topic)
await self.update(
transcript,
{"topics": transcript.topics_dump()},
mutate=False,
)
async def move_mp3_to_storage(self, transcript: Transcript):
"""
Move mp3 file to storage
"""
# store the audio on external storage
await get_storage().put_file(
transcript.storage_audio_path,
transcript.audio_mp3_filename.read_bytes(),
)
# indicate on the transcript that the audio is now on storage
await self.update(transcript, {"audio_location": "storage"})
# unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True)
async def upsert_participant(
self,
transcript: Transcript,
participant: TranscriptParticipant,
) -> TranscriptParticipant:
"""
Add/update a participant to a transcript
"""
result = transcript.upsert_participant(participant)
await self.update(
transcript,
{"participants": transcript.participants_dump()},
mutate=False,
)
return result
async def delete_participant(
self,
transcript: Transcript,
participant_id: str,
):
"""
Delete a participant from a transcript
"""
transcript.delete_participant(participant_id)
await self.update(
transcript,
{"participants": transcript.participants_dump()},
mutate=False,
)
transcripts_controller = TranscriptController()

View File

@@ -1,54 +0,0 @@
import httpx
from reflector.llm.base import LLM
from reflector.settings import settings
from reflector.utils.retry import retry
class BananaLLM(LLM):
def __init__(self):
super().__init__()
self.timeout = settings.LLM_TIMEOUT
self.headers = {
"X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
"X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
}
async def _generate(
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
):
json_payload = {"prompt": prompt}
if gen_schema:
json_payload["gen_schema"] = gen_schema
if gen_cfg:
json_payload["gen_cfg"] = gen_cfg
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
settings.LLM_URL,
headers=self.headers,
json=json_payload,
timeout=self.timeout,
retry_timeout=300, # as per their sdk
)
response.raise_for_status()
text = response.json()["text"]
return text
LLM.register("banana", BananaLLM)
if __name__ == "__main__":
from reflector.logger import logger
async def main():
llm = BananaLLM()
prompt = llm.create_prompt(
instruct="Complete the following task",
text="Tell me a joke about programming.",
)
result = await llm.generate(prompt=prompt, logger=logger)
print(result)
import asyncio
asyncio.run(main())

View File

@@ -47,6 +47,7 @@ class ModalLLM(LLM):
json=json_payload,
timeout=self.timeout,
retry_timeout=60 * 5,
follow_redirects=True,
)
response.raise_for_status()
text = response.json()["text"]

View File

@@ -0,0 +1,680 @@
"""
Main reflector pipeline for live streaming
==========================================
This is the default pipeline used in the API.
It is decoupled to:
- PipelineMainLive: have limited processing during live
- PipelineMainPost: do heavy lifting after the live
It is directly linked to our data model.
"""
import asyncio
import functools
from contextlib import asynccontextmanager
from celery import chord, group, shared_task
from pydantic import BaseModel
from reflector.db.transcripts import (
Transcript,
TranscriptDuration,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
TranscriptFinalTitle,
TranscriptText,
TranscriptTopic,
TranscriptWaveform,
transcripts_controller,
)
from reflector.logger import logger
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
AudioDiarizationAutoProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
BroadcastProcessor,
Pipeline,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import (
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
)
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
from structlog import BoundLogger as Logger
def asynctask(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
coro = f(*args, **kwargs)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
return loop.run_until_complete(coro)
return asyncio.run(coro)
return wrapper
def broadcast_to_sockets(func):
"""
Decorator to broadcast transcript event to websockets
concerning this transcript
"""
async def wrapper(self, *args, **kwargs):
resp = await func(self, *args, **kwargs)
if resp is None:
return
await self.ws_manager.send_json(
room_id=self.ws_room_id,
message=resp.model_dump(mode="json"),
)
return wrapper
def get_transcript(func):
"""
Decorator to fetch the transcript from the database from the first argument
"""
async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id")
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
if not transcript:
raise Exception("Transcript {transcript_id} not found")
tlogger = logger.bind(transcript_id=transcript.id)
try:
return await func(transcript=transcript, logger=tlogger, **kwargs)
except Exception as exc:
tlogger.error("Pipeline error", exc_info=exc)
raise
return wrapper
class StrValue(BaseModel):
value: str
class PipelineMainBase(PipelineRunner):
transcript_id: str
ws_room_id: str | None = None
ws_manager: WebsocketManager | None = None
def prepare(self):
# prepare websocket
self._lock = asyncio.Lock()
self.ws_room_id = f"ts:{self.transcript_id}"
self.ws_manager = get_ws_manager()
async def get_transcript(self) -> Transcript:
# fetch the transcript
result = await transcripts_controller.get_by_id(
transcript_id=self.transcript_id
)
if not result:
raise Exception("Transcript not found")
return result
def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]:
return [
TitleSummaryWithIdProcessorType(
id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
duration=topic.duration,
transcript=TranscriptProcessorType(words=topic.words),
)
for topic in transcript.topics
]
@asynccontextmanager
async def transaction(self):
async with self._lock:
async with transcripts_controller.transaction():
yield
@broadcast_to_sockets
async def on_status(self, status):
# if it's the first part, update the status of the transcript
# but do not set the ended status yet.
if isinstance(self, PipelineMainLive):
status_mapping = {
"started": "recording",
"push": "recording",
"flush": "processing",
"error": "error",
}
elif isinstance(self, PipelineMainFinalSummaries):
status_mapping = {
"push": "processing",
"flush": "processing",
"error": "error",
"ended": "ended",
}
else:
# intermediate pipeline don't update status
return
# mutate to model status
status = status_mapping.get(status)
if not status:
return
# when the status of the pipeline changes, update the transcript
async with self.transaction():
transcript = await self.get_transcript()
if status == transcript.status:
return
resp = await transcripts_controller.append_event(
transcript=transcript,
event="STATUS",
data=StrValue(value=status),
)
await transcripts_controller.update(
transcript,
{
"status": status,
},
)
return resp
@broadcast_to_sockets
async def on_transcript(self, data):
async with self.transaction():
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript,
event="TRANSCRIPT",
data=TranscriptText(text=data.text, translation=data.translation),
)
@broadcast_to_sockets
async def on_topic(self, data):
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
transcript=data.transcript.text,
words=data.transcript.words,
)
if isinstance(data, TitleSummaryWithIdProcessorType):
topic.id = data.id
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.upsert_topic(transcript, topic)
return await transcripts_controller.append_event(
transcript=transcript,
event="TOPIC",
data=topic,
)
@broadcast_to_sockets
async def on_title(self, data):
final_title = TranscriptFinalTitle(title=data.title)
async with self.transaction():
transcript = await self.get_transcript()
if not transcript.title:
await transcripts_controller.update(
transcript,
{
"title": final_title.title,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_TITLE",
data=final_title,
)
@broadcast_to_sockets
async def on_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"long_summary": final_long_summary.long_summary,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_LONG_SUMMARY",
data=final_long_summary,
)
@broadcast_to_sockets
async def on_short_summary(self, data):
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"short_summary": final_short_summary.short_summary,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_SHORT_SUMMARY",
data=final_short_summary,
)
@broadcast_to_sockets
async def on_duration(self, data):
async with self.transaction():
duration = TranscriptDuration(duration=data)
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"duration": duration.duration,
},
)
return await transcripts_controller.append_event(
transcript=transcript, event="DURATION", data=duration
)
@broadcast_to_sockets
async def on_waveform(self, data):
async with self.transaction():
waveform = TranscriptWaveform(waveform=data)
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript, event="WAVEFORM", data=waveform
)
class PipelineMainLive(PipelineMainBase):
"""
Main pipeline for live streaming, attach to RTC connection
Any long post process should be done in the post pipeline
"""
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
transcript = await self.get_transcript()
processors = [
AudioFileWriterProcessor(
path=transcript.audio_wav_filename,
on_duration=self.on_duration,
),
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
]
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info("Pipeline main live created")
return pipeline
async def on_ended(self):
# when the pipeline ends, connect to the post pipeline
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
pipeline_post(transcript_id=self.transcript_id)
class PipelineMainDiarization(PipelineMainBase):
"""
Diarize the audio and update topics
"""
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
pipeline = Pipeline(
AudioDiarizationAutoProcessor(callback=self.on_topic),
)
pipeline.options = self
# now let's start the pipeline by pushing information to the
# first processor diarization processor
# XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript()
# diarization works only if the file is uploaded to an external storage
if transcript.audio_location == "local":
pipeline.logger.info("Audio is local, skipping diarization")
return
topics = self.get_transcript_topics(transcript)
audio_url = await transcript.get_audio_url()
audio_diarization_input = AudioDiarizationInput(
audio_url=audio_url,
topics=topics,
)
# as tempting to use pipeline.push, prefer to use the runner
# to let the start just do one job.
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info("Diarization pipeline created")
self.push(audio_diarization_input)
self.flush()
return pipeline
class PipelineMainFromTopics(PipelineMainBase):
"""
Pseudo class for generating a pipeline from topics
"""
def get_processors(self) -> list:
raise NotImplementedError
async def create(self) -> Pipeline:
self.prepare()
# get transcript
self._transcript = transcript = await self.get_transcript()
# create pipeline
processors = self.get_processors()
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
# push topics
topics = self.get_transcript_topics(transcript)
for topic in topics:
self.push(topic)
self.flush()
return pipeline
class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
"""
Generate title from the topics
"""
def get_processors(self) -> list:
return [
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_short_summary
),
]
)
]
class PipelineMainFinalSummaries(PipelineMainFromTopics):
"""
Generate summaries from the topics
"""
def get_processors(self) -> list:
return [
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=self.on_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_short_summary
),
]
),
]
class PipelineMainWaveform(PipelineMainFromTopics):
"""
Generate waveform
"""
def get_processors(self) -> list:
return [
AudioWaveformProcessor.as_threaded(
audio_path=self._transcript.audio_wav_filename,
waveform_path=self._transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
),
]
@get_transcript
async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
logger.info("Starting remove upload")
uploads = transcript.data_path.glob("upload.*")
for upload in uploads:
upload.unlink(missing_ok=True)
logger.info("Remove upload done")
@get_transcript
async def pipeline_waveform(transcript: Transcript, logger: Logger):
logger.info("Starting waveform")
runner = PipelineMainWaveform(transcript_id=transcript.id)
await runner.run()
logger.info("Waveform done")
@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
logger.info("Starting convert to mp3")
# If the audio wav is not available, just skip
wav_filename = transcript.audio_wav_filename
if not wav_filename.exists():
logger.warning("Wav file not found, may be already converted")
return
# Convert to mp3
mp3_filename = transcript.audio_mp3_filename
import av
with av.open(wav_filename.as_posix()) as in_container:
in_stream = in_container.streams.audio[0]
with av.open(mp3_filename.as_posix(), "w") as out_container:
out_stream = out_container.add_stream("mp3")
for frame in in_container.decode(in_stream):
for packet in out_stream.encode(frame):
out_container.mux(packet)
# Delete the wav file
transcript.audio_wav_filename.unlink(missing_ok=True)
logger.info("Convert to mp3 done")
@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
if not settings.TRANSCRIPT_STORAGE_BACKEND:
logger.info("No storage backend configured, skipping mp3 upload")
return
logger.info("Starting upload mp3")
# If the audio mp3 is not available, just skip
mp3_filename = transcript.audio_mp3_filename
if not mp3_filename.exists():
logger.warning("Mp3 file not found, may be already uploaded")
return
# Upload to external storage and delete the file
await transcripts_controller.move_mp3_to_storage(transcript)
logger.info("Upload mp3 done")
@get_transcript
async def pipeline_diarization(transcript: Transcript, logger: Logger):
logger.info("Starting diarization")
runner = PipelineMainDiarization(transcript_id=transcript.id)
await runner.run()
logger.info("Diarization done")
@get_transcript
async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
logger.info("Starting title and short summary")
runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
await runner.run()
logger.info("Title and short summary done")
@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger):
logger.info("Starting summaries")
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
await runner.run()
logger.info("Summaries done")
# ===================================================================
# Celery tasks that can be called from the API
# ===================================================================
@shared_task
@asynctask
async def task_pipeline_remove_upload(*, transcript_id: str):
await pipeline_remove_upload(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_waveform(*, transcript_id: str):
await pipeline_waveform(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
await pipeline_convert_to_mp3(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_upload_mp3(*, transcript_id: str):
await pipeline_upload_mp3(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_diarization(*, transcript_id: str):
await pipeline_diarization(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_title_and_short_summary(*, transcript_id: str):
await pipeline_title_and_short_summary(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_final_summaries(*, transcript_id: str):
await pipeline_summaries(transcript_id=transcript_id)
def pipeline_post(*, transcript_id: str):
"""
Run the post pipeline
"""
chain_mp3_and_diarize = (
task_pipeline_waveform.si(transcript_id=transcript_id)
| task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
| task_pipeline_remove_upload.si(transcript_id=transcript_id)
| task_pipeline_diarization.si(transcript_id=transcript_id)
)
chain_title_preview = task_pipeline_title_and_short_summary.si(
transcript_id=transcript_id
)
chain_final_summaries = task_pipeline_final_summaries.si(
transcript_id=transcript_id
)
chain = chord(
group(chain_mp3_and_diarize, chain_title_preview),
chain_final_summaries,
)
chain.delay()
@get_transcript
async def pipeline_upload(transcript: Transcript, logger: Logger):
import av
try:
# open audio
upload_filename = next(transcript.data_path.glob("upload.*"))
container = av.open(upload_filename.as_posix())
# create pipeline
pipeline = PipelineMainLive(transcript_id=transcript.id)
pipeline.start()
# push audio to pipeline
try:
logger.info("Start pushing audio into the pipeline")
for frame in container.decode(audio=0):
pipeline.push(frame)
finally:
logger.info("Flushing the pipeline")
pipeline.flush()
logger.info("Waiting for the pipeline to end")
await pipeline.join()
except Exception as exc:
logger.error("Pipeline error", exc_info=exc)
await transcripts_controller.update(
transcript,
{
"status": "error",
},
)
raise
logger.info("Pipeline ended")
@shared_task
@asynctask
async def task_pipeline_upload(*, transcript_id: str):
return await pipeline_upload(transcript_id=transcript_id)

View File

@@ -0,0 +1,152 @@
"""
Pipeline Runner
===============
Pipeline runner designed to be executed in a asyncio task.
It is meant to be subclassed, and implement a create() method
that expose/return a Pipeline instance.
During its lifecycle, it will emit the following status:
- started: the pipeline has been started
- push: the pipeline received at least one data
- flush: the pipeline is flushing
- ended: the pipeline has ended
- error: the pipeline has ended with an error
"""
import asyncio
from pydantic import BaseModel, ConfigDict
from reflector.logger import logger
from reflector.processors import Pipeline
class PipelineRunner(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
status: str = "idle"
pipeline: Pipeline | None = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._task = None
self._q_cmd = asyncio.Queue(maxsize=4096)
self._ev_done = asyncio.Event()
self._is_first_push = True
self._logger = logger.bind(
runner=id(self),
runner_cls=self.__class__.__name__,
)
def create(self) -> Pipeline:
"""
Create the pipeline if not specified earlier.
Should be implemented in a subclass
"""
raise NotImplementedError()
def start(self):
"""
Start the pipeline as a coroutine task
"""
self._task = asyncio.get_event_loop().create_task(self.run())
async def join(self):
"""
Wait for the pipeline to finish
"""
if self._task:
await self._task
def start_sync(self):
"""
Start the pipeline synchronously (for non-asyncio apps)
"""
coro = self.run()
asyncio.run(coro)
def push(self, data):
"""
Push data to the pipeline
"""
self._add_cmd("PUSH", data)
def flush(self):
"""
Flush the pipeline
"""
self._add_cmd("FLUSH", None)
async def on_status(self, status):
"""
Called when the status of the pipeline changes
"""
pass
async def on_ended(self):
"""
Called when the pipeline ends
"""
pass
def _add_cmd(self, cmd: str, data):
"""
Enqueue a command to be executed in the runner.
Currently supported commands: PUSH, FLUSH
"""
self._q_cmd.put_nowait([cmd, data])
async def _set_status(self, status):
self._logger.debug("Runner status updated", status=status)
self.status = status
if self.on_status:
try:
await self.on_status(status)
except Exception:
self._logger.exception("Runer error while setting status")
async def run(self):
try:
# create the pipeline if not yet done
await self._set_status("init")
self._is_first_push = True
if not self.pipeline:
self.pipeline = await self.create()
if not self.pipeline:
# no pipeline created in create, just finish it then.
await self._set_status("ended")
self._ev_done.set()
if self.on_ended:
await self.on_ended()
return
# start the loop
await self._set_status("started")
while not self._ev_done.is_set():
cmd, data = await self._q_cmd.get()
func = getattr(self, f"cmd_{cmd.lower()}")
if func:
await func(data)
else:
raise Exception(f"Unknown command {cmd}")
except Exception:
self._logger.exception("Runner error")
await self._set_status("error")
self._ev_done.set()
raise
async def cmd_push(self, data):
if self._is_first_push:
await self._set_status("push")
self._is_first_push = False
await self.pipeline.push(data)
async def cmd_flush(self, data):
await self._set_status("flush")
await self.pipeline.flush()
await self._set_status("ended")
self._ev_done.set()
if self.on_ended:
await self.on_ended()

View File

@@ -1,9 +1,16 @@
from .audio_chunker import AudioChunkerProcessor # noqa: F401
from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor # noqa: F401
from .base import ( # noqa: F401
BroadcastProcessor,
Pipeline,
PipelineEvent,
Processor,
ThreadedProcessor,
)
from .transcript_final_long_summary import ( # noqa: F401
TranscriptFinalLongSummaryProcessor,
)

View File

@@ -0,0 +1,181 @@
from reflector.processors.base import Processor
from reflector.processors.types import AudioDiarizationInput, TitleSummary, Word
class AudioDiarizationProcessor(Processor):
INPUT_TYPE = AudioDiarizationInput
OUTPUT_TYPE = TitleSummary
async def _push(self, data: AudioDiarizationInput):
try:
self.logger.info("Diarization started", audio_file_url=data.audio_url)
diarization = await self._diarize(data)
self.logger.info("Diarization finished")
except Exception:
self.logger.exception("Diarization failed after retrying")
raise
# now reapply speaker to topics (if any)
# topics is a list[BaseModel] with an attribute words
# words is a list[BaseModel] with text, start and speaker attribute
# create a view of words based on topics
# the current algorithm is using words index, we cannot use a generator
words = list(self.iter_words_from_topics(data.topics))
# assign speaker to words (mutate the words list)
self.assign_speaker(words, diarization)
# emit them
for topic in data.topics:
await self.emit(topic)
async def _diarize(self, data: AudioDiarizationInput):
raise NotImplementedError
def assign_speaker(self, words: list[Word], diarization: list[dict]):
self._diarization_remove_overlap(diarization)
self._diarization_remove_segment_without_words(words, diarization)
self._diarization_merge_same_speaker(words, diarization)
self._diarization_assign_speaker(words, diarization)
def iter_words_from_topics(self, topics: TitleSummary):
for topic in topics:
for word in topic.transcript.words:
yield word
def is_word_continuation(self, word_prev, word):
"""
Return True if the word is a continuation of the previous word
by checking if the previous word is ending with a punctuation
or if the current word is starting with a capital letter
"""
# is word_prev ending with a punctuation ?
if word_prev.text and word_prev.text[-1] in ".?!":
return False
elif word.text and word.text[0].isupper():
return False
return True
def _diarization_remove_overlap(self, diarization: list[dict]):
"""
Remove overlap in diarization results
When using a diarization algorithm, it's possible to have overlapping segments
This function remove the overlap by keeping the longest segment
Warning: this function mutate the diarization list
"""
# remove overlap by keeping the longest segment
diarization_idx = 0
while diarization_idx < len(diarization) - 1:
d = diarization[diarization_idx]
dnext = diarization[diarization_idx + 1]
if d["end"] > dnext["start"]:
# remove the shortest segment
if d["end"] - d["start"] > dnext["end"] - dnext["start"]:
# remove next segment
diarization.pop(diarization_idx + 1)
else:
# remove current segment
diarization.pop(diarization_idx)
else:
diarization_idx += 1
def _diarization_remove_segment_without_words(
self, words: list[Word], diarization: list[dict]
):
"""
Remove diarization segments without words
Warning: this function mutate the diarization list
"""
# count the number of words for each diarization segment
diarization_count = []
for d in diarization:
start = d["start"]
end = d["end"]
count = 0
for word in words:
if start <= word.start < end:
count += 1
elif start < word.end <= end:
count += 1
diarization_count.append(count)
# remove diarization segments with no words
diarization_idx = 0
while diarization_idx < len(diarization):
if diarization_count[diarization_idx] == 0:
diarization.pop(diarization_idx)
diarization_count.pop(diarization_idx)
else:
diarization_idx += 1
def _diarization_merge_same_speaker(
self, words: list[Word], diarization: list[dict]
):
"""
Merge diarization contigous segments with the same speaker
Warning: this function mutate the diarization list
"""
# merge segment with same speaker
diarization_idx = 0
while diarization_idx < len(diarization) - 1:
d = diarization[diarization_idx]
dnext = diarization[diarization_idx + 1]
if d["speaker"] == dnext["speaker"]:
diarization[diarization_idx]["end"] = dnext["end"]
diarization.pop(diarization_idx + 1)
else:
diarization_idx += 1
def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
"""
Assign speaker to words based on diarization
Warning: this function mutate the words list
"""
word_idx = 0
last_speaker = None
for d in diarization:
start = d["start"]
end = d["end"]
speaker = d["speaker"]
# diarization may start after the first set of words
# in this case, we assign the last speaker
for word in words[word_idx:]:
if word.start < start:
# speaker change, but what make sense for assigning the word ?
# If it's a new sentence, assign with the new speaker
# If it's a continuation, assign with the last speaker
is_continuation = False
if word_idx > 0 and word_idx < len(words) - 1:
is_continuation = self.is_word_continuation(
*words[word_idx - 1 : word_idx + 1]
)
if is_continuation:
word.speaker = last_speaker
else:
word.speaker = speaker
last_speaker = speaker
word_idx += 1
else:
break
# now continue to assign speaker until the word starts after the end
for word in words[word_idx:]:
if start <= word.start < end:
last_speaker = speaker
word.speaker = speaker
word_idx += 1
elif word.start > end:
break
# no more diarization available,
# assign last speaker to all words without speaker
for word in words[word_idx:]:
word.speaker = last_speaker

View File

@@ -0,0 +1,33 @@
import importlib
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.settings import settings
class AudioDiarizationAutoProcessor(AudioDiarizationProcessor):
_registry = {}
@classmethod
def register(cls, name, kclass):
cls._registry[name] = kclass
def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.DIARIZATION_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.audio_diarization_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
config = {}
name_upper = name.upper()
settings_prefix = "DIARIZATION_"
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
if key.startswith(config_prefix):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config | kwargs)

View File

@@ -0,0 +1,37 @@
import httpx
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
from reflector.settings import settings
class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
INPUT_TYPE = AudioDiarizationInput
OUTPUT_TYPE = TitleSummary
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
self.headers = {
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
}
async def _diarize(self, data: AudioDiarizationInput):
# Gather diarization data
params = {
"audio_file_url": data.audio_url,
"timestamp": 0,
}
async with httpx.AsyncClient() as client:
response = await client.post(
self.diarization_url,
headers=self.headers,
params=params,
timeout=None,
follow_redirects=True,
)
response.raise_for_status()
return response.json()["diarization"]
AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)

View File

@@ -12,8 +12,8 @@ class AudioFileWriterProcessor(Processor):
INPUT_TYPE = av.AudioFrame
OUTPUT_TYPE = av.AudioFrame
def __init__(self, path: Path | str):
super().__init__()
def __init__(self, path: Path | str, **kwargs):
super().__init__(**kwargs)
if isinstance(path, str):
path = Path(path)
if path.suffix not in (".mp3", ".wav"):
@@ -21,6 +21,7 @@ class AudioFileWriterProcessor(Processor):
self.path = path
self.out_container = None
self.out_stream = None
self.last_packet = None
async def _push(self, data: av.AudioFrame):
if not self.out_container:
@@ -40,12 +41,30 @@ class AudioFileWriterProcessor(Processor):
raise ValueError("Only mp3 and wav files are supported")
for packet in self.out_stream.encode(data):
self.out_container.mux(packet)
self.last_packet = packet
await self.emit(data)
async def _flush(self):
if self.out_container:
for packet in self.out_stream.encode():
self.out_container.mux(packet)
self.last_packet = packet
try:
if self.last_packet is not None:
duration = round(
float(
(self.last_packet.pts * self.last_packet.duration)
* self.last_packet.time_base
),
2,
)
except Exception:
self.logger.exception("Failed to get duration")
duration = 0
self.out_container.close()
self.out_container = None
self.out_stream = None
if duration > 0:
await self.emit(duration, name="duration")

View File

@@ -1,6 +1,4 @@
from profanityfilter import ProfanityFilter
from prometheus_client import Counter, Histogram
from reflector.processors.base import Processor
from reflector.processors.types import AudioFile, Transcript
@@ -40,8 +38,6 @@ class AudioTranscriptProcessor(Processor):
self.m_transcript_call = self.m_transcript_call.labels(name)
self.m_transcript_success = self.m_transcript_success.labels(name)
self.m_transcript_failure = self.m_transcript_failure.labels(name)
self.profanity_filter = ProfanityFilter()
self.profanity_filter.set_censor("*")
super().__init__(*args, **kwargs)
async def _push(self, data: AudioFile):
@@ -60,9 +56,3 @@ class AudioTranscriptProcessor(Processor):
async def _transcript(self, data: AudioFile):
raise NotImplementedError
def filter_profanity(self, text: str) -> str:
"""
Remove censored words from the transcript
"""
return self.profanity_filter.censor(text)

View File

@@ -1,8 +1,6 @@
import importlib
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.base import Pipeline, Processor
from reflector.processors.types import AudioFile
from reflector.settings import settings
@@ -13,8 +11,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
def register(cls, name, kclass):
cls._registry[name] = kclass
@classmethod
def get_instance(cls, name):
def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.TRANSCRIPT_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.audio_transcript_{name}"
importlib.import_module(module_name)
@@ -30,30 +29,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config)
def __init__(self, **kwargs):
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
super().__init__(**kwargs)
def set_pipeline(self, pipeline: Pipeline):
super().set_pipeline(pipeline)
self.processor.set_pipeline(pipeline)
def connect(self, processor: Processor):
self.processor.connect(processor)
def disconnect(self, processor: Processor):
self.processor.disconnect(processor)
def on(self, callback):
self.processor.on(callback)
def off(self, callback):
self.processor.off(callback)
async def _push(self, data: AudioFile):
return await self.processor._push(data)
async def _flush(self):
return await self.processor._flush()
return cls._registry[name](**config | kwargs)

View File

@@ -1,86 +0,0 @@
"""
Implementation using the GPU service from banana.
API will be a POST request to TRANSCRIPT_URL:
```json
{
"audio_url": "https://...",
"audio_ext": "wav",
"timestamp": 123.456
"language": "en"
}
```
"""
from pathlib import Path
import httpx
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
from reflector.processors.types import AudioFile, Transcript, Word
from reflector.settings import settings
from reflector.storage import Storage
from reflector.utils.retry import retry
class AudioTranscriptBananaProcessor(AudioTranscriptProcessor):
def __init__(self, banana_api_key: str, banana_model_key: str):
super().__init__()
self.transcript_url = settings.TRANSCRIPT_URL
self.timeout = settings.TRANSCRIPT_TIMEOUT
self.storage = Storage.get_instance(
settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_"
)
self.headers = {
"X-Banana-API-Key": banana_api_key,
"X-Banana-Model-Key": banana_model_key,
}
async def _transcript(self, data: AudioFile):
async with httpx.AsyncClient() as client:
print(f"Uploading audio {data.path.name} to S3")
url = await self._upload_file(data.path)
print(f"Try to transcribe audio {data.path.name}")
request_data = {
"audio_url": url,
"audio_ext": data.path.suffix[1:],
"timestamp": float(round(data.timestamp, 2)),
}
response = await retry(client.post)(
self.transcript_url,
json=request_data,
headers=self.headers,
timeout=self.timeout,
)
print(f"Transcript response: {response.status_code} {response.content}")
response.raise_for_status()
result = response.json()
transcript = Transcript(
text=result["text"],
words=[
Word(text=word["text"], start=word["start"], end=word["end"])
for word in result["words"]
],
)
# remove audio file from S3
await self._delete_file(data.path)
return transcript
@retry
async def _upload_file(self, path: Path) -> str:
upload_result = await self.storage.put_file(path.name, open(path, "rb"))
return upload_result.url
@retry
async def _delete_file(self, path: Path):
await self.storage.delete_file(path.name)
return True
AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor)

View File

@@ -41,6 +41,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
timeout=self.timeout,
headers=self.headers,
params=json_payload,
follow_redirects=True,
)
self.logger.debug(
@@ -48,10 +49,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
)
response.raise_for_status()
result = response.json()
text = result["text"][source_language]
text = self.filter_profanity(text)
transcript = Transcript(
text=text,
words=[
Word(
text=word["text"],

View File

@@ -30,7 +30,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
ts = data.timestamp
for segment in segments:
transcript.text += segment.text
for word in segment.words:
transcript.words.append(
Word(

View File

@@ -0,0 +1,36 @@
import json
from pathlib import Path
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary
from reflector.utils.audio_waveform import get_audio_waveform
class AudioWaveformProcessor(Processor):
"""
Write the waveform for the final audio
"""
INPUT_TYPE = TitleSummary
def __init__(self, audio_path: Path | str, waveform_path: str, **kwargs):
super().__init__(**kwargs)
if isinstance(audio_path, str):
audio_path = Path(audio_path)
if audio_path.suffix not in (".mp3", ".wav"):
raise ValueError("Only mp3 and wav files are supported")
self.audio_path = audio_path
self.waveform_path = waveform_path
async def _flush(self):
self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info("Waveform Processing Started")
waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
with open(self.waveform_path, "w") as fd:
json.dump(waveform, fd)
self.logger.info("Waveform Processing Finished")
await self.emit(waveform, name="waveform")
async def _push(_self, _data):
return

View File

@@ -14,7 +14,42 @@ class PipelineEvent(BaseModel):
data: Any
class Processor:
class Emitter:
def __init__(self, **kwargs):
self._callbacks = {}
# register callbacks from kwargs (on_*)
for key, value in kwargs.items():
if key.startswith("on_"):
self.on(value, name=key[3:])
def on(self, callback, name="default"):
"""
Register a callback to be called when data is emitted
"""
# ensure callback is asynchronous
if not asyncio.iscoroutinefunction(callback):
raise ValueError("Callback must be a coroutine function")
if name not in self._callbacks:
self._callbacks[name] = []
self._callbacks[name].append(callback)
def off(self, callback, name="default"):
"""
Unregister a callback to be called when data is emitted
"""
if name not in self._callbacks:
return
self._callbacks[name].remove(callback)
async def emit(self, data, name="default"):
if name not in self._callbacks:
return
for callback in self._callbacks[name]:
await callback(data)
class Processor(Emitter):
INPUT_TYPE: type = None
OUTPUT_TYPE: type = None
@@ -59,7 +94,8 @@ class Processor:
["processor"],
)
def __init__(self, callback=None, custom_logger=None):
def __init__(self, callback=None, custom_logger=None, **kwargs):
super().__init__(**kwargs)
self.name = name = self.__class__.__name__
self.m_processor = self.m_processor.labels(name)
self.m_processor_call = self.m_processor_call.labels(name)
@@ -70,9 +106,11 @@ class Processor:
self.m_processor_flush_success = self.m_processor_flush_success.labels(name)
self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name)
self._processors = []
self._callbacks = []
# register callbacks
if callback:
self.on(callback)
self.uid = uuid4().hex
self.flushed = False
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
@@ -100,21 +138,6 @@ class Processor:
"""
self._processors.remove(processor)
def on(self, callback):
"""
Register a callback to be called when data is emitted
"""
# ensure callback is asynchronous
if not asyncio.iscoroutinefunction(callback):
raise ValueError("Callback must be a coroutine function")
self._callbacks.append(callback)
def off(self, callback):
"""
Unregister a callback to be called when data is emitted
"""
self._callbacks.remove(callback)
def get_pref(self, key: str, default: Any = None):
"""
Get a preference from the pipeline prefs
@@ -123,15 +146,16 @@ class Processor:
return self.pipeline.get_pref(key, default)
return default
async def emit(self, data):
if self.pipeline:
await self.pipeline.emit(
PipelineEvent(processor=self.name, uid=self.uid, data=data)
)
for callback in self._callbacks:
await callback(data)
for processor in self._processors:
await processor.push(data)
async def emit(self, data, name="default"):
if name == "default":
if self.pipeline:
await self.pipeline.emit(
PipelineEvent(processor=self.name, uid=self.uid, data=data)
)
await super().emit(data, name=name)
if name == "default":
for processor in self._processors:
await processor.push(data)
async def push(self, data):
"""
@@ -254,11 +278,11 @@ class ThreadedProcessor(Processor):
def disconnect(self, processor: Processor):
self.processor.disconnect(processor)
def on(self, callback):
self.processor.on(callback)
def on(self, callback, name="default"):
self.processor.on(callback, name=name)
def off(self, callback):
self.processor.off(callback)
def off(self, callback, name="default"):
self.processor.off(callback, name=name)
def describe(self, level=0):
super().describe(level)
@@ -290,12 +314,12 @@ class BroadcastProcessor(Processor):
processor.set_pipeline(pipeline)
async def _push(self, data):
for processor in self.processors:
await processor.push(data)
coros = [processor.push(data) for processor in self.processors]
await asyncio.gather(*coros)
async def _flush(self):
for processor in self.processors:
await processor.flush()
coros = [processor.flush() for processor in self.processors]
await asyncio.gather(*coros)
def connect(self, processor: Processor):
for processor in self.processors:
@@ -305,13 +329,13 @@ class BroadcastProcessor(Processor):
for processor in self.processors:
processor.disconnect(processor)
def on(self, callback):
def on(self, callback, name="default"):
for processor in self.processors:
processor.on(callback)
processor.on(callback, name=name)
def off(self, callback):
def off(self, callback, name="default"):
for processor in self.processors:
processor.off(callback)
processor.off(callback, name=name)
def describe(self, level=0):
super().describe(level)
@@ -333,6 +357,7 @@ class Pipeline(Processor):
self.logger.info("Pipeline created")
self.processors = processors
self.options = None
self.prefs = {}
for processor in processors:

View File

@@ -36,7 +36,6 @@ class TranscriptLinerProcessor(Processor):
# cut to the next .
partial = Transcript(words=[])
for word in self.transcript.words[:]:
partial.text += word.text
partial.words.append(word)
if not self.is_sentence_terminated(word.text):
continue

View File

@@ -16,6 +16,7 @@ class TranscriptTranslatorProcessor(Processor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.transcript = None
self.translate_url = settings.TRANSLATE_URL
self.timeout = settings.TRANSLATE_TIMEOUT
self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
@@ -50,6 +51,7 @@ class TranscriptTranslatorProcessor(Processor):
headers=self.headers,
params=json_payload,
timeout=self.timeout,
follow_redirects=True,
)
response.raise_for_status()
result = response.json()["text"]

View File

@@ -1,8 +1,16 @@
import io
import re
import tempfile
from pathlib import Path
from profanityfilter import ProfanityFilter
from pydantic import BaseModel, PrivateAttr
from reflector.redis_cache import redis_cache
PUNC_RE = re.compile(r"[.;:?!…]")
profanity_filter = ProfanityFilter()
profanity_filter.set_censor("*")
class AudioFile(BaseModel):
@@ -43,13 +51,34 @@ class Word(BaseModel):
text: str
start: float
end: float
speaker: int = 0
class TranscriptSegment(BaseModel):
text: str
start: float
end: float
speaker: int = 0
class Transcript(BaseModel):
text: str = ""
translation: str | None = None
words: list[Word] = None
@property
def raw_text(self):
# Uncensored text
return "".join([word.text for word in self.words])
@redis_cache(prefix="profanity", duration=3600 * 24 * 7)
def _get_censored_text(self, text: str):
return profanity_filter.censor(text).strip()
@property
def text(self):
# Censored text
return self._get_censored_text(self.raw_text)
@property
def human_timestamp(self):
minutes = int(self.timestamp / 60)
@@ -74,7 +103,6 @@ class Transcript(BaseModel):
self.words = other.words
else:
self.words.extend(other.words)
self.text += other.text
def add_offset(self, offset: float):
for word in self.words:
@@ -87,6 +115,51 @@ class Transcript(BaseModel):
]
return Transcript(text=self.text, translation=self.translation, words=words)
def as_segments(self) -> list[TranscriptSegment]:
# from a list of word, create a list of segments
# join the word that are less than 2 seconds apart
# but separate if the speaker changes, or if the punctuation is a . , ; : ? !
segments = []
current_segment = None
MAX_SEGMENT_LENGTH = 120
for word in self.words:
if current_segment is None:
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
# If the word is attach to another speaker, push the current segment
# and start a new one
if word.speaker != current_segment.speaker:
segments.append(current_segment)
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
# if the word is the end of a sentence, and we have enough content,
# add the word to the current segment and push it
current_segment.text += word.text
current_segment.end = word.end
have_punc = PUNC_RE.search(word.text)
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
segments.append(current_segment)
current_segment = None
if current_segment:
segments.append(current_segment)
return segments
class TitleSummary(BaseModel):
title: str
@@ -103,6 +176,10 @@ class TitleSummary(BaseModel):
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
class TitleSummaryWithId(TitleSummary):
id: str
class FinalLongSummary(BaseModel):
long_summary: str
duration: float
@@ -318,3 +395,8 @@ class TranslationLanguages(BaseModel):
def is_supported(self, lang_id: str) -> bool:
return lang_id in self.supported_languages
class AudioDiarizationInput(BaseModel):
audio_url: str
topics: list[TitleSummaryWithId]

View File

@@ -0,0 +1,50 @@
import functools
import json
import redis
from reflector.settings import settings
redis_clients = {}
def get_redis_client(db=0):
"""
Get a Redis client for the specified database.
"""
if db not in redis_clients:
redis_clients[db] = redis.StrictRedis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=db,
)
return redis_clients[db]
def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argidx=1):
"""
Cache the result of a function in Redis.
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Check if the first argument is a string
if len(args) < (argidx + 1) or not isinstance(args[argidx], str):
return func(*args, **kwargs)
# Compute the cache key based on the arguments and prefix
cache_key = prefix + ":" + args[argidx]
redis_client = get_redis_client(db=db)
cached_result = redis_client.get(cache_key)
if cached_result:
return json.loads(cached_result.decode("utf-8"))
# If the result is not cached, call the original function
result = func(*args, **kwargs)
redis_client.setex(cache_key, duration, json.dumps(result))
return result
return wrapper
return decorator

View File

@@ -2,7 +2,11 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
OPENMP_KMP_DUPLICATE_LIB_OK: bool = False
@@ -37,7 +41,7 @@ class Settings(BaseSettings):
AUDIO_BUFFER_SIZE: int = 256 * 960
# Audio Transcription
# backends: whisper, banana, modal
# backends: whisper, modal
TRANSCRIPT_BACKEND: str = "whisper"
TRANSCRIPT_URL: str | None = None
TRANSCRIPT_TIMEOUT: int = 90
@@ -46,24 +50,20 @@ class Settings(BaseSettings):
TRANSLATE_URL: str | None = None
TRANSLATE_TIMEOUT: int = 90
# Audio transcription banana.dev configuration
TRANSCRIPT_BANANA_API_KEY: str | None = None
TRANSCRIPT_BANANA_MODEL_KEY: str | None = None
# Audio transcription modal.com configuration
TRANSCRIPT_MODAL_API_KEY: str | None = None
# Audio transcription storage
TRANSCRIPT_STORAGE_BACKEND: str = "aws"
TRANSCRIPT_STORAGE_BACKEND: str | None = None
# Storage configuration for AWS
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks"
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1"
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# LLM
# available backend: openai, banana, modal, oobabooga
# available backend: openai, modal, oobabooga
LLM_BACKEND: str = "oobabooga"
# LLM common configuration
@@ -78,13 +78,14 @@ class Settings(BaseSettings):
LLM_TEMPERATURE: float = 0.7
ZEPHYR_LLM_URL: str | None = None
# LLM Banana configuration
LLM_BANANA_API_KEY: str | None = None
LLM_BANANA_MODEL_KEY: str | None = None
# LLM Modal configuration
LLM_MODAL_API_KEY: str | None = None
# Diarization
DIARIZATION_ENABLED: bool = True
DIARIZATION_BACKEND: str = "modal"
DIARIZATION_URL: str | None = None
# Sentry
SENTRY_DSN: str | None = None
@@ -109,5 +110,26 @@ class Settings(BaseSettings):
# Min transcript length to generate topic + summary
MIN_TRANSCRIPT_LENGTH: int = 750
# Celery
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
# Redis
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
REDIS_CACHE_DB: int = 2
# Secret key
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
# Current hosting/domain
BASE_URL: str = "http://localhost:1250"
# Profiling
PROFILING: bool = False
# Healthcheck
HEALTHCHECK_URL: str | None = None
settings = Settings()

View File

@@ -1,6 +1,7 @@
import importlib
from pydantic import BaseModel
from reflector.settings import settings
import importlib
class FileResult(BaseModel):
@@ -17,7 +18,7 @@ class Storage:
cls._registry[name] = kclass
@classmethod
def get_instance(cls, name, settings_prefix=""):
def get_instance(cls, name: str, settings_prefix: str = ""):
if name not in cls._registry:
module_name = f"reflector.storage.storage_{name}"
importlib.import_module(module_name)
@@ -45,3 +46,9 @@ class Storage:
async def _delete_file(self, filename: str):
raise NotImplementedError
async def get_file_url(self, filename: str) -> str:
return await self._get_file_url(filename)
async def _get_file_url(self, filename: str) -> str:
raise NotImplementedError

View File

@@ -1,6 +1,6 @@
import aioboto3
from reflector.storage.base import Storage, FileResult
from reflector.logger import logger
from reflector.storage.base import FileResult, Storage
class AwsStorage(Storage):
@@ -44,16 +44,18 @@ class AwsStorage(Storage):
Body=data,
)
async def _get_file_url(self, filename: str) -> FileResult:
bucket = self.aws_bucket_name
folder = self.aws_folder
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
presigned_url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": s3filename},
ExpiresIn=3600,
)
return FileResult(
filename=filename,
url=presigned_url,
)
return presigned_url
async def _delete_file(self, filename: str):
bucket = self.aws_bucket_name

View File

@@ -0,0 +1,14 @@
import argparse
from reflector.app import celery_app # noqa
from reflector.pipelines.main_live_pipeline import task_pipeline_main_post
parser = argparse.ArgumentParser()
parser.add_argument("transcript_id", type=str)
parser.add_argument("--delay", action="store_true")
args = parser.parse_args()
if args.delay:
task_pipeline_main_post.delay(args.transcript_id)
else:
task_pipeline_main_post(args.transcript_id)

View File

@@ -1,7 +1,7 @@
import os
from typing import BinaryIO
from fastapi import HTTPException, Request, status
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse
@@ -57,6 +57,9 @@ def range_requests_response(
),
}
if request.method == "HEAD":
return Response(headers=headers)
if content_disposition:
headers["Content-Disposition"] = content_disposition

View File

@@ -1,7 +1,5 @@
import asyncio
from enum import StrEnum
from json import dumps, loads
from pathlib import Path
from json import loads
import av
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
@@ -10,25 +8,7 @@ from prometheus_client import Gauge
from pydantic import BaseModel
from reflector.events import subscribers_shutdown
from reflector.logger import logger
from reflector.processors import (
AudioChunkerProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
FinalLongSummary,
FinalShortSummary,
Pipeline,
TitleSummary,
Transcript,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.base import BroadcastProcessor
from reflector.processors.types import FinalTitle
from reflector.pipelines.runner import PipelineRunner
sessions = []
router = APIRouter()
@@ -38,7 +18,7 @@ m_rtc_sessions = Gauge("rtc_sessions", "Number of active RTC sessions")
class TranscriptionContext(object):
def __init__(self, logger):
self.logger = logger
self.pipeline = None
self.pipeline_runner = None
self.data_channel = None
self.status = "idle"
self.topics = []
@@ -60,7 +40,7 @@ class AudioStreamTrack(MediaStreamTrack):
ctx = self.ctx
frame = await self.track.recv()
try:
await ctx.pipeline.push(frame)
ctx.pipeline_runner.push(frame)
except Exception as e:
ctx.logger.error("Pipeline error", error=e)
return frame
@@ -71,27 +51,10 @@ class RtcOffer(BaseModel):
type: str
class StrValue(BaseModel):
value: str
class PipelineEvent(StrEnum):
TRANSCRIPT = "TRANSCRIPT"
TOPIC = "TOPIC"
FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
STATUS = "STATUS"
FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
FINAL_TITLE = "FINAL_TITLE"
async def rtc_offer_base(
params: RtcOffer,
request: Request,
event_callback=None,
event_callback_args=None,
audio_filename: Path | None = None,
source_language: str = "en",
target_language: str = "en",
pipeline_runner: PipelineRunner,
):
# build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -101,146 +64,10 @@ async def rtc_offer_base(
clientid = f"{peername[0]}:{peername[1]}"
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
async def update_status(status: str):
changed = ctx.status != status
if changed:
ctx.status = status
if event_callback:
await event_callback(
event=PipelineEvent.STATUS,
args=event_callback_args,
data=StrValue(value=status),
)
# build pipeline callback
async def on_transcript(transcript: Transcript):
ctx.logger.info("Transcript", transcript=transcript)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TRANSCRIPT,
args=event_callback_args,
data=transcript,
)
async def on_topic(topic: TitleSummary):
# FIXME: make it incremental with the frontend, not send everything
ctx.logger.info("Topic", topic=topic)
ctx.topics.append(
{
"title": topic.title,
"timestamp": topic.timestamp,
"transcript": topic.transcript.text,
"desc": topic.summary,
}
)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
)
async def on_final_short_summary(summary: FinalShortSummary):
ctx.logger.info("FinalShortSummary", final_short_summary=summary)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
"summary": summary.short_summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_SHORT_SUMMARY,
args=event_callback_args,
data=summary,
)
async def on_final_long_summary(summary: FinalLongSummary):
ctx.logger.info("FinalLongSummary", final_summary=summary)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_LONG_SUMMARY",
"summary": summary.long_summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_LONG_SUMMARY,
args=event_callback_args,
data=summary,
)
async def on_final_title(title: FinalTitle):
ctx.logger.info("FinalTitle", final_title=title)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_TITLE,
args=event_callback_args,
data=title,
)
# create a context for the whole rtc transaction
# add a customised logger to the context
processors = []
if audio_filename is not None:
processors += [AudioFileWriterProcessor(path=audio_filename)]
processors += [
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=on_final_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=on_final_short_summary
),
]
),
]
ctx.pipeline = Pipeline(*processors)
ctx.pipeline.set_pref("audio:source_language", source_language)
ctx.pipeline.set_pref("audio:target_language", target_language)
# handle RTC peer connection
pc = RTCPeerConnection()
ctx.pipeline_runner = pipeline_runner
ctx.pipeline_runner.start()
async def flush_pipeline_and_quit(close=True):
# may be called twice
@@ -249,12 +76,10 @@ async def rtc_offer_base(
# - when we receive the close event, we do nothing.
# 2. or the client close the connection
# and there is nothing to do because it is already closed
await update_status("processing")
await ctx.pipeline.flush()
ctx.pipeline_runner.flush()
if close:
ctx.logger.debug("Closing peer connection")
await pc.close()
await update_status("ended")
if pc in sessions:
sessions.remove(pc)
m_rtc_sessions.dec()
@@ -287,7 +112,6 @@ async def rtc_offer_base(
def on_track(track):
ctx.logger.info(f"Track {track.kind} received")
pc.addTrack(AudioStreamTrack(ctx, track))
asyncio.get_event_loop().create_task(update_status("recording"))
await pc.setRemoteDescription(offer)
@@ -308,8 +132,3 @@ async def rtc_clean_sessions(_):
logger.debug(f"Closing session {pc}")
await pc.close()
sessions.clear()
@router.post("/offer")
async def rtc_offer(params: RtcOffer, request: Request):
return await rtc_offer_base(params, request)

View File

@@ -1,213 +1,33 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Annotated, Optional
from uuid import uuid4
from datetime import datetime, timedelta
from typing import Annotated, Literal, Optional
import reflector.auth as auth
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
)
from fastapi_pagination import Page, paginate
from fastapi import APIRouter, Depends, HTTPException
from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate
from jose import jwt
from pydantic import BaseModel, Field
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.db.transcripts import (
TranscriptParticipant,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.processors.types import Word
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
router = APIRouter()
# ==============================================================
# Models to move to a database, but required for the API to work
# ==============================================================
ALGORITHM = "HS256"
DOWNLOAD_EXPIRE_MINUTES = 60
def generate_uuid4():
return str(uuid4())
def generate_transcript_name():
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class AudioWaveform(BaseModel):
data: list[float]
class TranscriptText(BaseModel):
text: str
translation: str | None
class TranscriptTopic(BaseModel):
id: str = Field(default_factory=generate_uuid4)
title: str
summary: str
transcript: str
timestamp: float
class TranscriptFinalShortSummary(BaseModel):
short_summary: str
class TranscriptFinalLongSummary(BaseModel):
long_summary: str
class TranscriptFinalTitle(BaseModel):
title: str
class TranscriptEvent(BaseModel):
event: str
data: dict
class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=datetime.utcnow)
title: str | None = None
short_summary: str | None = None
long_summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
self.events.append(ev)
return ev
def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
if existing_topic:
existing_topic.update_from(topic)
else:
self.topics.append(topic)
def events_dump(self, mode="json"):
return [event.model_dump(mode=mode) for event in self.events]
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_waveform(self, segments_count=256):
fn = self.audio_waveform_filename
if fn.exists():
return
waveform = get_audio_waveform(
path=self.audio_mp3_filename, segments_count=segments_count
)
try:
with open(fn, "w") as fd:
json.dump(waveform, fd)
except Exception:
# remove file if anything happen during the write
fn.unlink(missing_ok=True)
raise
return waveform
def unlink(self):
self.data_path.unlink(missing_ok=True)
@property
def data_path(self):
return Path(settings.DATA_DIR) / self.id
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@property
def audio_waveform_filename(self):
return self.data_path / "audio.json"
@property
def audio_waveform(self):
try:
with open(self.audio_waveform_filename) as fd:
data = json.load(fd)
except json.JSONDecodeError:
# unlink file if it's corrupted
self.audio_waveform_filename.unlink(missing_ok=True)
return None
return AudioWaveform(data=data)
class TranscriptController:
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
query = transcripts.select().where(transcripts.c.user_id == user_id)
results = await database.fetch_all(query)
return results
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
if not result:
return None
return Transcript(**result)
async def add(
self,
name: str,
source_language: str = "en",
target_language: str = "en",
user_id: str | None = None,
):
transcript = Transcript(
name=name,
source_language=source_language,
target_language=target_language,
user_id=user_id,
)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
return transcript
async def update(self, transcript: Transcript, values: dict):
query = (
transcripts.update()
.where(transcripts.c.id == transcript.id)
.values(**values)
)
await database.execute(query)
for key, value in values.items():
setattr(transcript, key, value)
async def remove_by_id(
self, transcript_id: str, user_id: str | None = None
) -> None:
transcript = await self.get_by_id(transcript_id, user_id=user_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
transcripts_controller = TranscriptController()
def create_access_token(data: dict, expires_delta: timedelta):
to_encode = data.copy()
expire = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# ==============================================================
@@ -217,16 +37,20 @@ transcripts_controller = TranscriptController()
class GetTranscript(BaseModel):
id: str
user_id: str | None
name: str
status: str
locked: bool
duration: int
duration: float
title: str | None
short_summary: str | None
long_summary: str | None
created_at: datetime
source_language: str
target_language: str
share_mode: str = Field("private")
source_language: str | None
target_language: str | None
participants: list[TranscriptParticipant] | None
reviewed: bool
class CreateTranscript(BaseModel):
@@ -241,6 +65,9 @@ class UpdateTranscript(BaseModel):
title: Optional[str] = Field(None)
short_summary: Optional[str] = Field(None)
long_summary: Optional[str] = Field(None)
share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None)
participants: Optional[list[TranscriptParticipant]] = Field(None)
reviewed: Optional[bool] = Field(None)
class DeletionStatus(BaseModel):
@@ -251,11 +78,20 @@ class DeletionStatus(BaseModel):
async def transcripts_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
from reflector.db import database
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None
return paginate(await transcripts_controller.get_all(user_id=user_id))
return await paginate(
database,
await transcripts_controller.get_all(
user_id=user_id,
order_by="-created_at",
return_query=True,
),
)
@router.post("/transcripts", response_model=GetTranscript)
@@ -277,16 +113,117 @@ async def transcripts_create(
# ==============================================================
class GetTranscriptSegmentTopic(BaseModel):
text: str
start: float
speaker: int
class GetTranscriptTopic(BaseModel):
id: str
title: str
summary: str
timestamp: float
duration: float | None
transcript: str
segments: list[GetTranscriptSegmentTopic] = []
@classmethod
def from_transcript_topic(cls, topic: TranscriptTopic):
if not topic.words:
# In previous version, words were missing
# Just output a segment with speaker 0
text = topic.transcript
duration = None
segments = [
GetTranscriptSegmentTopic(
text=topic.transcript,
start=topic.timestamp,
speaker=0,
)
]
else:
# New versions include words
transcript = ProcessorTranscript(words=topic.words)
text = transcript.text
duration = transcript.duration
segments = [
GetTranscriptSegmentTopic(
text=segment.text,
start=segment.start,
speaker=segment.speaker,
)
for segment in transcript.as_segments()
]
return cls(
id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
transcript=text,
segments=segments,
duration=duration,
)
class GetTranscriptTopicWithWords(GetTranscriptTopic):
words: list[Word] = []
@classmethod
def from_transcript_topic(cls, topic: TranscriptTopic):
instance = super().from_transcript_topic(topic)
if topic.words:
instance.words = topic.words
return instance
class SpeakerWords(BaseModel):
speaker: int
words: list[Word]
class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
words_per_speaker: list[SpeakerWords] = []
@classmethod
def from_transcript_topic(cls, topic: TranscriptTopic):
instance = super().from_transcript_topic(topic)
if topic.words:
words_per_speakers = []
# group words by speaker
words = []
for word in topic.words:
if words and words[-1].speaker != word.speaker:
words_per_speakers.append(
SpeakerWords(
speaker=words[-1].speaker,
words=words,
)
)
words = []
words.append(word)
if words:
words_per_speakers.append(
SpeakerWords(
speaker=words[-1].speaker,
words=words,
)
)
instance.words_per_speaker = words_per_speakers
return instance
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript
return await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
@@ -299,32 +236,7 @@ async def transcript_update(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
values = {"events": []}
if info.name is not None:
values["name"] = info.name
if info.locked is not None:
values["locked"] = info.locked
if info.long_summary is not None:
values["long_summary"] = info.long_summary
for transcript_event in transcript.events:
if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
transcript_event["long_summary"] = info.long_summary
break
values["events"].extend(transcript.events)
if info.short_summary is not None:
values["short_summary"] = info.short_summary
for transcript_event in transcript.events:
if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
transcript_event["short_summary"] = info.short_summary
break
values["events"].extend(transcript.events)
if info.title is not None:
values["title"] = info.title
for transcript_event in transcript.events:
if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
transcript_event["title"] = info.title
break
values["events"].extend(transcript.events)
values = info.dict(exclude_unset=True)
await transcripts_controller.update(transcript, values)
return transcript
@@ -342,255 +254,63 @@ async def transcript_delete(
return DeletionStatus(status="ok")
@router.get("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
truncated_id = str(transcript.id).split("-")[0]
filename = f"recording_{truncated_id}.mp3"
return range_requests_response(
request,
transcript.audio_mp3_filename,
content_type="audio/mpeg",
content_disposition=f"attachment; filename={filename}",
)
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_waveform)
return transcript.audio_waveform
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
@router.get(
"/transcripts/{transcript_id}/topics",
response_model=list[GetTranscriptTopic],
)
async def transcript_get_topics(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript.topics
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
# convert to GetTranscriptTopic
return [
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
]
@router.get("/transcripts/{transcript_id}/events")
async def transcript_get_websocket_events(transcript_id: str):
pass
# ==============================================================
# Websocket Manager
# ==============================================================
class WebsocketManager:
def __init__(self):
self.active_connections = {}
async def connect(self, transcript_id: str, websocket: WebSocket):
await websocket.accept()
if transcript_id not in self.active_connections:
self.active_connections[transcript_id] = []
self.active_connections[transcript_id].append(websocket)
def disconnect(self, transcript_id: str, websocket: WebSocket):
if transcript_id not in self.active_connections:
return
self.active_connections[transcript_id].remove(websocket)
if not self.active_connections[transcript_id]:
del self.active_connections[transcript_id]
async def send_json(self, transcript_id: str, message):
if transcript_id not in self.active_connections:
return
for connection in self.active_connections[transcript_id][:]:
try:
await connection.send_json(message)
except Exception:
self.active_connections[transcript_id].remove(connection)
ws_manager = WebsocketManager()
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
@router.get(
"/transcripts/{transcript_id}/topics/with-words",
response_model=list[GetTranscriptTopicWithWords],
)
async def transcript_get_topics_with_words(
transcript_id: str,
websocket: WebSocket,
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
# user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
await ws_manager.connect(transcript_id, websocket)
# on first connection, send all events
for event in transcript.events:
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
try:
while True:
await websocket.receive()
except (RuntimeError, WebSocketDisconnect):
ws_manager.disconnect(transcript_id, websocket)
# ==============================================================
# Web RTC
# ==============================================================
async def handle_rtc_event(event: PipelineEvent, args, data):
# OFC the current implementation is not good,
# but it's just a POC before persistence. It won't query the
# transcript from the database for each event.
# print(f"Event: {event}", args, data)
transcript_id = args
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
return
# event send to websocket clients may not be the same as the event
# received from the pipeline. For example, the pipeline will send
# a TRANSCRIPT event with all words, but this is not what we want
# to send to the websocket client.
# FIXME don't do copy
if event == PipelineEvent.TRANSCRIPT:
resp = transcript.add_event(
event=event,
data=TranscriptText(text=data.text, translation=data.translation),
)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
},
)
elif event == PipelineEvent.TOPIC:
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
transcript=data.transcript.text,
timestamp=data.timestamp,
)
resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"topics": transcript.topics_dump(),
},
)
elif event == PipelineEvent.FINAL_TITLE:
final_title = TranscriptFinalTitle(title=data.title)
resp = transcript.add_event(event=event, data=final_title)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"title": final_title.title,
},
)
elif event == PipelineEvent.FINAL_LONG_SUMMARY:
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
resp = transcript.add_event(event=event, data=final_long_summary)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"long_summary": final_long_summary.long_summary,
},
)
elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
resp = transcript.add_event(event=event, data=final_short_summary)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"short_summary": final_short_summary.short_summary,
},
)
elif event == PipelineEvent.STATUS:
resp = transcript.add_event(event=event, data=data)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"status": data.value,
},
)
else:
logger.warning(f"Unknown event: {event}")
return
# transmit to websocket clients
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str,
params: RtcOffer,
request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# FIXME do not allow multiple recording at the same time
return await rtc_offer_base(
params,
request,
event_callback=handle_rtc_event,
event_callback_args=transcript_id,
audio_filename=transcript.audio_mp3_filename,
source_language=transcript.source_language,
target_language=transcript.target_language,
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
# convert to GetTranscriptTopicWithWords
return [
GetTranscriptTopicWithWords.from_transcript_topic(topic)
for topic in transcript.topics
]
@router.get(
"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker",
response_model=GetTranscriptTopicWithWordsPerSpeaker,
)
async def transcript_get_topics_with_words_per_speaker(
transcript_id: str,
topic_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
# get the topic from the transcript
topic = next((t for t in transcript.topics if t.id == topic_id), None)
if not topic:
raise HTTPException(status_code=404, detail="Topic not found")
# convert to GetTranscriptTopicWithWordsPerSpeaker
return GetTranscriptTopicWithWordsPerSpeaker.from_transcript_topic(topic)

View File

@@ -0,0 +1,115 @@
"""
Transcripts audio related endpoints
===================================
"""
from typing import Annotated, Optional
import httpx
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from jose import jwt
from reflector.db.transcripts import AudioWaveform, transcripts_controller
from reflector.settings import settings
from reflector.views.transcripts import ALGORITHM
from ._range_requests_response import range_requests_response
router = APIRouter()
@router.get(
"/transcripts/{transcript_id}/audio/mp3",
operation_id="transcript_get_audio_mp3",
)
@router.head(
"/transcripts/{transcript_id}/audio/mp3",
operation_id="transcript_head_audio_mp3",
)
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
token: str | None = None,
):
user_id = user["sub"] if user else None
if not user_id and token:
unauthorized_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
except jwt.JWTError:
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.audio_location == "storage":
# proxy S3 file, to prevent issue with CORS
url = await transcript.get_audio_url()
headers = {}
copy_headers = ["range", "accept-encoding"]
for header in copy_headers:
if header in request.headers:
headers[header] = request.headers[header]
async with httpx.AsyncClient() as client:
resp = await client.request(request.method, url, headers=headers)
return Response(
content=resp.content,
status_code=resp.status_code,
headers=resp.headers,
)
if transcript.audio_location == "storage":
# proxy S3 file, to prevent issue with CORS
url = await transcript.get_audio_url()
headers = {}
copy_headers = ["range", "accept-encoding"]
for header in copy_headers:
if header in request.headers:
headers[header] = request.headers[header]
async with httpx.AsyncClient() as client:
resp = await client.request(request.method, url, headers=headers)
return Response(
content=resp.content,
status_code=resp.status_code,
headers=resp.headers,
)
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=500, detail="Audio not found")
truncated_id = str(transcript.id).split("-")[0]
filename = f"recording_{truncated_id}.mp3"
return range_requests_response(
request,
transcript.audio_mp3_filename,
content_type="audio/mpeg",
content_disposition=f"attachment; filename={filename}",
)
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript.audio_waveform_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return transcript.audio_waveform

View File

@@ -0,0 +1,143 @@
"""
Transcript participants API endpoints
=====================================
"""
from typing import Annotated, Optional
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
from reflector.views.types import DeletionStatus
router = APIRouter()
class Participant(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
speaker: int | None
name: str
class CreateParticipant(BaseModel):
speaker: Optional[int] = Field(None)
name: str
class UpdateParticipant(BaseModel):
speaker: Optional[int] = Field(None)
name: Optional[str] = Field(None)
@router.get("/transcripts/{transcript_id}/participants")
async def transcript_get_participants(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> list[Participant]:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
return [
Participant.model_validate(participant)
for participant in transcript.participants
]
@router.post("/transcripts/{transcript_id}/participants")
async def transcript_add_participant(
transcript_id: str,
participant: CreateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
# ensure the speaker is unique
if participant.speaker is not None:
for p in transcript.participants:
if p.speaker == participant.speaker:
raise HTTPException(
status_code=400,
detail="Speaker already assigned",
)
obj = await transcripts_controller.upsert_participant(
transcript, TranscriptParticipant(**participant.dict())
)
return Participant.model_validate(obj)
@router.get("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_get_participant(
transcript_id: str,
participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
for p in transcript.participants:
if p.id == participant_id:
return Participant.model_validate(p)
raise HTTPException(status_code=404, detail="Participant not found")
@router.patch("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_update_participant(
transcript_id: str,
participant_id: str,
participant: UpdateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
# ensure the speaker is unique
for p in transcript.participants:
if p.speaker == participant.speaker and p.id != participant_id:
raise HTTPException(
status_code=400,
detail="Speaker already assigned",
)
# find the participant
obj = None
for p in transcript.participants:
if p.id == participant_id:
obj = p
break
if not obj:
raise HTTPException(status_code=404, detail="Participant not found")
# update participant but just the fields that are set
fields = participant.dict(exclude_unset=True)
obj = obj.copy(update=fields)
await transcripts_controller.upsert_participant(transcript, obj)
return Participant.model_validate(obj)
@router.delete("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_delete_participant(
transcript_id: str,
participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> DeletionStatus:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
await transcripts_controller.delete_participant(transcript, participant_id)
return DeletionStatus(status="ok")

View File

@@ -0,0 +1,170 @@
"""
Reassign speakers in a transcript
=================================
"""
from typing import Annotated, Optional
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from reflector.db.transcripts import transcripts_controller
router = APIRouter()
class SpeakerAssignment(BaseModel):
speaker: Optional[int] = Field(None, ge=0)
participant: Optional[str] = Field(None)
timestamp_from: float
timestamp_to: float
class SpeakerAssignmentStatus(BaseModel):
status: str
class SpeakerMerge(BaseModel):
speaker_from: int
speaker_to: int
@router.patch("/transcripts/{transcript_id}/speaker/assign")
async def transcript_assign_speaker(
transcript_id: str,
assignment: SpeakerAssignment,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if assignment.speaker is None and assignment.participant is None:
raise HTTPException(
status_code=400,
detail="Either speaker or participant must be provided",
)
if assignment.speaker is not None and assignment.participant is not None:
raise HTTPException(
status_code=400,
detail="Only one of speaker or participant must be provided",
)
# if it's a participant, search for it
if assignment.speaker is not None:
speaker = assignment.speaker
elif assignment.participant is not None:
participant = next(
(
participant
for participant in transcript.participants
if participant.id == assignment.participant
),
None,
)
if not participant:
raise HTTPException(
status_code=404,
detail="Participant not found",
)
# if the participant does not have a speaker, create one
if participant.speaker is None:
participant.speaker = transcript.find_empty_speaker()
await transcripts_controller.upsert_participant(transcript, participant)
speaker = participant.speaker
# reassign speakers from words in the transcript
ts_from = assignment.timestamp_from
ts_to = assignment.timestamp_to
changed_topics = []
for topic in transcript.topics:
changed = False
for word in topic.words:
if ts_from <= word.start <= ts_to:
word.speaker = speaker
changed = True
if changed:
changed_topics.append(topic)
# batch changes
for topic in changed_topics:
transcript.upsert_topic(topic)
await transcripts_controller.update(
transcript,
{
"topics": transcript.topics_dump(),
},
)
return SpeakerAssignmentStatus(status="ok")
@router.patch("/transcripts/{transcript_id}/speaker/merge")
async def transcript_merge_speaker(
transcript_id: str,
merge: SpeakerMerge,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# ensure both speaker are not assigned to the 2 differents participants
participant_from = next(
(
participant
for participant in transcript.participants
if participant.speaker == merge.speaker_from
),
None,
)
participant_to = next(
(
participant
for participant in transcript.participants
if participant.speaker == merge.speaker_to
),
None,
)
if participant_from and participant_to:
raise HTTPException(
status_code=400,
detail="Both speakers are assigned to participants",
)
# reassign speakers from words in the transcript
speaker_from = merge.speaker_from
speaker_to = merge.speaker_to
changed_topics = []
for topic in transcript.topics:
changed = False
for word in topic.words:
if word.speaker == speaker_from:
word.speaker = speaker_to
changed = True
if changed:
changed_topics.append(topic)
# batch changes
for topic in changed_topics:
transcript.upsert_topic(topic)
await transcripts_controller.update(
transcript,
{
"topics": transcript.topics_dump(),
},
)
return SpeakerAssignmentStatus(status="ok")

View File

@@ -0,0 +1,79 @@
from typing import Annotated, Optional
import av
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from pydantic import BaseModel
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_live_pipeline import task_pipeline_upload
router = APIRouter()
class UploadStatus(BaseModel):
status: str
@router.post("/transcripts/{transcript_id}/record/upload")
async def transcript_record_upload(
transcript_id: str,
file: UploadFile,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# ensure there is no other upload in the directory (searching data_path/upload.*)
if any(transcript.data_path.glob("upload.*")):
raise HTTPException(
status_code=400, detail="There is already an upload in progress"
)
# save the file to the transcript folder
extension = file.filename.split(".")[-1]
upload_filename = transcript.data_path / f"upload.{extension}"
upload_filename.parent.mkdir(parents=True, exist_ok=True)
# ensure the file is back to the beginning
await file.seek(0)
# save the file to the transcript folder
try:
with open(upload_filename, "wb") as f:
while True:
chunk = await file.read(16384)
if not chunk:
break
f.write(chunk)
except Exception:
upload_filename.unlink()
raise
# ensure the file have audio part, using av
# XXX Trying to do this check on the initial UploadFile object is not
# possible, dunno why. UploadFile.file has no name.
# Trying to pass UploadFile.file with format=extension does not work
# it never detect audio stream...
container = av.open(upload_filename.as_posix())
try:
if not len(container.streams.audio):
raise HTTPException(status_code=400, detail="File has no audio stream")
except Exception:
# delete the uploaded file
upload_filename.unlink()
raise
finally:
container.close()
# set the status to "uploaded"
await transcripts_controller.update(transcript, {"status": "uploaded"})
# launch a background task to process the file
task_pipeline_upload.delay(transcript_id=transcript_id)
return UploadStatus(status="ok")

View File

@@ -0,0 +1,37 @@
from typing import Annotated, Optional
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException, Request
from reflector.db.transcripts import transcripts_controller
from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter()
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str,
params: RtcOffer,
request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# create a pipeline runner
from reflector.pipelines.main_live_pipeline import PipelineMainLive
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
# FIXME do not allow multiple recording at the same time
return await rtc_offer_base(
params,
request,
pipeline_runner=pipeline_runner,
)

View File

@@ -0,0 +1,53 @@
"""
Transcripts websocket API
=========================
"""
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from reflector.db.transcripts import transcripts_controller
from reflector.ws_manager import get_ws_manager
router = APIRouter()
@router.get("/transcripts/{transcript_id}/events")
async def transcript_get_websocket_events(transcript_id: str):
pass
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
transcript_id: str,
websocket: WebSocket,
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
# user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# connect to websocket manager
# use ts:transcript_id as room id
room_id = f"ts:{transcript_id}"
ws_manager = get_ws_manager()
await ws_manager.add_user_to_room(room_id, websocket)
try:
# on first connection, send all events only to the current user
for event in transcript.events:
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
# not necessary to be sent to the client; but keep the rest
name = event.event
if name in ("TRANSCRIPT", "STATUS"):
continue
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
# we do not have command system now,
while True:
await websocket.receive()
except (RuntimeError, WebSocketDisconnect):
await ws_manager.remove_user_from_room(room_id, websocket)

View File

@@ -0,0 +1,5 @@
from pydantic import BaseModel
class DeletionStatus(BaseModel):
status: str

View File

@@ -0,0 +1,32 @@
import celery
import structlog
from celery import Celery
from reflector.settings import settings
logger = structlog.get_logger(__name__)
if celery.current_app.main != "default":
logger.info(f"Celery already configured ({celery.current_app})")
app = celery.current_app
else:
app = Celery(__name__)
app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
app.conf.broker_connection_retry_on_startup = True
app.autodiscover_tasks(
[
"reflector.pipelines.main_live_pipeline",
"reflector.worker.healthcheck",
]
)
# crontab
app.conf.beat_schedule = {}
if settings.HEALTHCHECK_URL:
app.conf.beat_schedule["healthcheck_ping"] = {
"task": "reflector.worker.healthcheck.healthcheck_ping",
"schedule": 60.0 * 10,
}
logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
else:
logger.warning("Healthcheck disabled, no url configured")

View File

@@ -0,0 +1,18 @@
import httpx
import structlog
from celery import shared_task
from reflector.settings import settings
logger = structlog.get_logger(__name__)
@shared_task
def healthcheck_ping():
url = settings.HEALTHCHECK_URL
if not url:
return
try:
print("pinging healthcheck url", url)
httpx.get(url, timeout=10)
except Exception as e:
logger.error("healthcheck_ping", error=str(e))

View File

@@ -0,0 +1,126 @@
"""
Websocket manager
=================
This module contains the WebsocketManager class, which is responsible for
managing websockets and handling websocket connections.
It uses the RedisPubSubManager class to subscribe to Redis channels and
broadcast messages to all connected websockets.
"""
import asyncio
import json
import threading
import redis.asyncio as redis
from fastapi import WebSocket
from reflector.settings import settings
class RedisPubSubManager:
def __init__(self, host="localhost", port=6379):
self.redis_host = host
self.redis_port = port
self.redis_connection = None
self.pubsub = None
async def get_redis_connection(self) -> redis.Redis:
return redis.Redis(
host=self.redis_host,
port=self.redis_port,
auto_close_connection_pool=False,
)
async def connect(self) -> None:
if self.redis_connection is not None:
return
self.redis_connection = await self.get_redis_connection()
self.pubsub = self.redis_connection.pubsub()
async def disconnect(self) -> None:
if self.redis_connection is None:
return
await self.redis_connection.close()
self.redis_connection = None
async def send_json(self, room_id: str, message: str) -> None:
if not self.redis_connection:
await self.connect()
message = json.dumps(message)
await self.redis_connection.publish(room_id, message)
async def subscribe(self, room_id: str) -> redis.Redis:
await self.pubsub.subscribe(room_id)
return self.pubsub
async def unsubscribe(self, room_id: str) -> None:
await self.pubsub.unsubscribe(room_id)
class WebsocketManager:
def __init__(self, pubsub_client: RedisPubSubManager = None):
self.rooms: dict = {}
self.pubsub_client = pubsub_client
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
await websocket.accept()
if room_id in self.rooms:
self.rooms[room_id].append(websocket)
else:
self.rooms[room_id] = [websocket]
await self.pubsub_client.connect()
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
async def send_json(self, room_id: str, message: dict) -> None:
await self.pubsub_client.send_json(room_id, message)
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
self.rooms[room_id].remove(websocket)
if len(self.rooms[room_id]) == 0:
del self.rooms[room_id]
await self.pubsub_client.unsubscribe(room_id)
async def _pubsub_data_reader(self, pubsub_subscriber):
while True:
message = await pubsub_subscriber.get_message(
ignore_subscribe_messages=True
)
if message is not None:
room_id = message["channel"].decode("utf-8")
all_sockets = self.rooms[room_id]
for socket in all_sockets:
data = json.loads(message["data"].decode("utf-8"))
await socket.send_json(data)
def get_ws_manager() -> WebsocketManager:
"""
Returns the WebsocketManager instance for managing websockets.
This function initializes and returns the WebsocketManager instance,
which is responsible for managing websockets and handling websocket
connections.
Returns:
WebsocketManager: The initialized WebsocketManager instance.
Raises:
ImportError: If the 'reflector.settings' module cannot be imported.
RedisConnectionError: If there is an error connecting to the Redis server.
"""
local = threading.local()
if hasattr(local, "ws_manager"):
return local.ws_manager
pubsub_client = RedisPubSubManager(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
)
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
local.ws_manager = ws_manager
return ws_manager

View File

@@ -4,4 +4,13 @@ if [ -f "/venv/bin/activate" ]; then
source /venv/bin/activate
fi
alembic upgrade head
python -m reflector.app
if [ "${ENTRYPOINT}" = "server" ]; then
python -m reflector.app
elif [ "${ENTRYPOINT}" = "worker" ]; then
celery -A reflector.worker.app worker --loglevel=info
elif [ "${ENTRYPOINT}" = "beat" ]; then
celery -A reflector.worker.app beat --loglevel=info
else
echo "Unknown command"
fi

View File

@@ -1,4 +1,5 @@
from unittest.mock import patch
from tempfile import NamedTemporaryFile
import pytest
@@ -7,7 +8,6 @@ import pytest
@pytest.mark.asyncio
async def setup_database():
from reflector.settings import settings
from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as f:
settings.DATABASE_URL = f"sqlite:///{f.name}"
@@ -36,7 +36,13 @@ def dummy_processors():
mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
mock_translate.return_value = "Bonjour le monde"
yield mock_translate, mock_topic, mock_title, mock_long_summary, mock_short_summary # noqa
yield (
mock_translate,
mock_topic,
mock_title,
mock_long_summary,
mock_short_summary,
) # noqa
@pytest.fixture
@@ -45,28 +51,50 @@ async def dummy_transcript():
from reflector.processors.types import AudioFile, Transcript, Word
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
async def _transcript(self, data: AudioFile):
source_language = self.get_pref("audio:source_language", "en")
print("transcripting", source_language)
print("pipeline", self.pipeline)
print("prefs", self.pipeline.prefs)
_time_idx = 0
async def _transcript(self, data: AudioFile):
i = self._time_idx
self._time_idx += 2
return Transcript(
text="Hello world.",
words=[
Word(start=0.0, end=1.0, text="Hello"),
Word(start=1.0, end=2.0, text=" world."),
Word(start=i, end=i + 1, text="Hello", speaker=0),
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
],
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.get_instance"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioTranscriptProcessor()
yield
@pytest.fixture
async def dummy_diarization():
from reflector.processors.audio_diarization import AudioDiarizationProcessor
class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
_time_idx = 0
async def _diarize(self, data):
i = self._time_idx
self._time_idx += 2
return [
{"start": i, "end": i + 1, "speaker": 0},
{"start": i + 1, "end": i + 2, "speaker": 1},
]
with patch(
"reflector.processors.audio_diarization_auto"
".AudioDiarizationAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioDiarizationProcessor()
yield
@pytest.fixture
async def dummy_llm():
from reflector.llm.base import LLM
@@ -81,6 +109,25 @@ async def dummy_llm():
yield
@pytest.fixture
async def dummy_storage():
from reflector.storage.base import Storage
class DummyStorage(Storage):
async def _put_file(self, *args, **kwargs):
pass
async def _delete_file(self, *args, **kwargs):
pass
async def _get_file_url(self, *args, **kwargs):
return "http://fake_server/audio.mp3"
with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
mock_storage.return_value = DummyStorage()
yield
@pytest.fixture
def nltk():
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
@@ -98,7 +145,96 @@ def ensure_casing():
@pytest.fixture
def sentence_tokenize():
with patch(
"reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize"
"reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize"
) as mock_sent_tokenize:
mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
yield
@pytest.fixture(scope="session")
def celery_enable_logging():
return True
@pytest.fixture(scope="session")
def celery_config():
with NamedTemporaryFile() as f:
yield {
"broker_url": "memory://",
"result_backend": f"db+sqlite:///{f.name}",
}
@pytest.fixture(scope="session")
def celery_includes():
return ["reflector.pipelines.main_live_pipeline"]
@pytest.fixture(scope="session")
def fake_mp3_upload():
with patch(
"reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
) as mock_move:
mock_move.return_value = True
yield
@pytest.fixture
async def fake_transcript_with_topics(tmpdir):
from reflector.settings import settings
from reflector.app import app
from reflector.views.transcripts import transcripts_controller
from reflector.db.transcripts import TranscriptTopic
from reflector.processors.types import Word
from pathlib import Path
from httpx import AsyncClient
import shutil
settings.DATA_DIR = Path(tmpdir)
# create a transcript
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200
tid = response.json()["id"]
transcript = await transcripts_controller.get_by_id(tid)
assert transcript is not None
await transcripts_controller.update(transcript, {"status": "finished"})
# manually copy a file at the expected location
audio_filename = transcript.audio_mp3_filename
path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
audio_filename.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(path, audio_filename)
# create some topics
await transcripts_controller.upsert_topic(
transcript,
TranscriptTopic(
title="Topic 1",
summary="Topic 1 summary",
timestamp=0,
transcript="Hello world",
words=[
Word(text="Hello", start=0, end=1, speaker=0),
Word(text="world", start=1, end=2, speaker=0),
],
),
)
await transcripts_controller.upsert_topic(
transcript,
TranscriptTopic(
title="Topic 2",
summary="Topic 2 summary",
timestamp=2,
transcript="Hello world",
words=[
Word(text="Hello", start=2, end=3, speaker=0),
Word(text="world", start=3, end=4, speaker=0),
],
),
)
yield transcript

View File

@@ -0,0 +1,140 @@
import pytest
from unittest import mock
@pytest.mark.parametrize(
"name,diarization,expected",
[
[
"no overlap",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 1.0, "end": 2.0, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"same speaker",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 1.0, "end": 2.0, "speaker": "A"},
],
["A", "A", "A", "A"],
],
[
# first segment is removed because it overlap
# with the second segment, and it is smaller
"overlap at 0.5s",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 0.5, "end": 2.0, "speaker": "B"},
],
["B", "B", "B", "B"],
],
[
"junk segment at 0.5s for 0.2s",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 0.5, "end": 0.7, "speaker": "B"},
{"start": 1, "end": 2.0, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"start without diarization",
[
{"start": 0.5, "end": 1.0, "speaker": "A"},
{"start": 1.0, "end": 2.0, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"end missing diarization",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 1.0, "end": 1.5, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"continuation of next speaker",
[
{"start": 0.0, "end": 0.9, "speaker": "A"},
{"start": 1.5, "end": 2.0, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"continuation of previous speaker",
[
{"start": 0.0, "end": 0.5, "speaker": "A"},
{"start": 1.0, "end": 2.0, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"segment without words",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 1.0, "end": 2.0, "speaker": "B"},
{"start": 2.0, "end": 3.0, "speaker": "X"},
],
["A", "A", "B", "B"],
],
],
)
@pytest.mark.asyncio
async def test_processors_audio_diarization(event_loop, name, diarization, expected):
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.types import (
TitleSummaryWithId,
Transcript,
Word,
AudioDiarizationInput,
)
# create fake topic
topics = [
TitleSummaryWithId(
id="1",
title="Title1",
summary="Summary1",
timestamp=0.0,
duration=1.0,
transcript=Transcript(
words=[
Word(text="Word1", start=0.0, end=0.5),
Word(text="word2.", start=0.5, end=1.0),
]
),
),
TitleSummaryWithId(
id="2",
title="Title2",
summary="Summary2",
timestamp=0.0,
duration=1.0,
transcript=Transcript(
words=[
Word(text="Word3", start=1.0, end=1.5),
Word(text="word4.", start=1.5, end=2.0),
]
),
),
]
diarizer = AudioDiarizationProcessor()
with mock.patch.object(diarizer, "_diarize") as mock_diarize:
mock_diarize.return_value = diarization
data = AudioDiarizationInput(
audio_url="https://example.com/audio.mp3",
topics=topics,
)
await diarizer._push(data)
# check that the speaker has been assigned to the words
assert topics[0].transcript.words[0].speaker == expected[0]
assert topics[0].transcript.words[1].speaker == expected[1]
assert topics[1].transcript.words[0].speaker == expected[2]
assert topics[1].transcript.words[1].speaker == expected[3]

View File

@@ -0,0 +1,161 @@
def test_processor_transcript_segment():
from reflector.processors.types import Transcript, Word
transcript = Transcript(
words=[
Word(text=" the", start=5.12, end=5.48, speaker=0),
Word(text=" different", start=5.48, end=5.8, speaker=0),
Word(text=" projects", start=5.8, end=6.3, speaker=0),
Word(text=" that", start=6.3, end=6.5, speaker=0),
Word(text=" are", start=6.5, end=6.58, speaker=0),
Word(text=" going", start=6.58, end=6.82, speaker=0),
Word(text=" on", start=6.82, end=7.26, speaker=0),
Word(text=" to", start=7.26, end=7.4, speaker=0),
Word(text=" give", start=7.4, end=7.54, speaker=0),
Word(text=" you", start=7.54, end=7.9, speaker=0),
Word(text=" context", start=7.9, end=8.24, speaker=0),
Word(text=" and", start=8.24, end=8.66, speaker=0),
Word(text=" I", start=8.66, end=8.72, speaker=0),
Word(text=" think", start=8.72, end=8.82, speaker=0),
Word(text=" that's", start=8.82, end=9.04, speaker=0),
Word(text=" what", start=9.04, end=9.12, speaker=0),
Word(text=" we'll", start=9.12, end=9.24, speaker=0),
Word(text=" do", start=9.24, end=9.32, speaker=0),
Word(text=" this", start=9.32, end=9.52, speaker=0),
Word(text=" week.", start=9.52, end=9.76, speaker=0),
Word(text=" Um,", start=10.24, end=10.62, speaker=0),
Word(text=" so,", start=11.36, end=11.94, speaker=0),
Word(text=" um,", start=12.46, end=12.92, speaker=0),
Word(text=" what", start=13.74, end=13.94, speaker=0),
Word(text=" we're", start=13.94, end=14.1, speaker=0),
Word(text=" going", start=14.1, end=14.24, speaker=0),
Word(text=" to", start=14.24, end=14.34, speaker=0),
Word(text=" do", start=14.34, end=14.8, speaker=0),
Word(text=" at", start=14.8, end=14.98, speaker=0),
Word(text=" H", start=14.98, end=15.04, speaker=0),
Word(text=" of", start=15.04, end=15.16, speaker=0),
Word(text=" you,", start=15.16, end=15.26, speaker=0),
Word(text=" maybe.", start=15.28, end=15.34, speaker=0),
Word(text=" you", start=15.36, end=15.52, speaker=0),
Word(text=" can", start=15.52, end=15.62, speaker=0),
Word(text=" introduce", start=15.62, end=15.98, speaker=0),
Word(text=" yourself", start=15.98, end=16.42, speaker=0),
Word(text=" to", start=16.42, end=16.68, speaker=0),
Word(text=" the", start=16.68, end=16.72, speaker=0),
Word(text=" team", start=16.72, end=17.52, speaker=0),
Word(text=" quickly", start=17.87, end=18.65, speaker=0),
Word(text=" and", start=18.65, end=19.63, speaker=0),
Word(text=" Oh,", start=20.91, end=21.55, speaker=0),
Word(text=" this", start=21.67, end=21.83, speaker=0),
Word(text=" is", start=21.83, end=22.17, speaker=0),
Word(text=" a", start=22.17, end=22.35, speaker=0),
Word(text=" reflector", start=22.35, end=22.89, speaker=0),
Word(text=" translating", start=22.89, end=23.33, speaker=0),
Word(text=" into", start=23.33, end=23.73, speaker=0),
Word(text=" French", start=23.73, end=23.95, speaker=0),
Word(text=" for", start=23.95, end=24.15, speaker=0),
Word(text=" me.", start=24.15, end=24.43, speaker=0),
Word(text=" This", start=27.87, end=28.19, speaker=0),
Word(text=" is", start=28.19, end=28.45, speaker=0),
Word(text=" all", start=28.45, end=28.79, speaker=0),
Word(text=" the", start=28.79, end=29.15, speaker=0),
Word(text=" way,", start=29.15, end=29.15, speaker=0),
Word(text=" please,", start=29.53, end=29.59, speaker=0),
Word(text=" please,", start=29.73, end=29.77, speaker=0),
Word(text=" please,", start=29.77, end=29.83, speaker=0),
Word(text=" please.", start=29.83, end=29.97, speaker=0),
Word(text=" Yeah,", start=29.97, end=30.17, speaker=0),
Word(text=" that's", start=30.25, end=30.33, speaker=0),
Word(text=" all", start=30.33, end=30.49, speaker=0),
Word(text=" it's", start=30.49, end=30.69, speaker=0),
Word(text=" right.", start=30.69, end=30.69, speaker=0),
Word(text=" Right.", start=30.72, end=30.98, speaker=1),
Word(text=" Yeah,", start=31.56, end=31.72, speaker=2),
Word(text=" that's", start=31.86, end=31.98, speaker=2),
Word(text=" right.", start=31.98, end=32.2, speaker=2),
Word(text=" Because", start=32.38, end=32.46, speaker=0),
Word(text=" I", start=32.46, end=32.58, speaker=0),
Word(text=" thought", start=32.58, end=32.78, speaker=0),
Word(text=" I'd", start=32.78, end=33.0, speaker=0),
Word(text=" be", start=33.0, end=33.02, speaker=0),
Word(text=" able", start=33.02, end=33.18, speaker=0),
Word(text=" to", start=33.18, end=33.34, speaker=0),
Word(text=" pull", start=33.34, end=33.52, speaker=0),
Word(text=" out.", start=33.52, end=33.68, speaker=0),
Word(text=" Yeah,", start=33.7, end=33.9, speaker=0),
Word(text=" that", start=33.9, end=34.02, speaker=0),
Word(text=" was", start=34.02, end=34.24, speaker=0),
Word(text=" the", start=34.24, end=34.34, speaker=0),
Word(text=" one", start=34.34, end=34.44, speaker=0),
Word(text=" before", start=34.44, end=34.7, speaker=0),
Word(text=" that.", start=34.7, end=35.24, speaker=0),
Word(text=" Friends,", start=35.84, end=36.46, speaker=0),
Word(text=" if", start=36.64, end=36.7, speaker=0),
Word(text=" you", start=36.7, end=36.7, speaker=0),
Word(text=" have", start=36.7, end=37.24, speaker=0),
Word(text=" tell", start=37.24, end=37.44, speaker=0),
Word(text=" us", start=37.44, end=37.68, speaker=0),
Word(text=" if", start=37.68, end=37.82, speaker=0),
Word(text=" it's", start=37.82, end=38.04, speaker=0),
Word(text=" good,", start=38.04, end=38.58, speaker=0),
Word(text=" exceptionally", start=38.96, end=39.1, speaker=0),
Word(text=" good", start=39.1, end=39.6, speaker=0),
Word(text=" and", start=39.6, end=39.86, speaker=0),
Word(text=" tell", start=39.86, end=40.0, speaker=0),
Word(text=" us", start=40.0, end=40.06, speaker=0),
Word(text=" when", start=40.06, end=40.2, speaker=0),
Word(text=" it's", start=40.2, end=40.34, speaker=0),
Word(text=" exceptionally", start=40.34, end=40.6, speaker=0),
Word(text=" bad.", start=40.6, end=40.94, speaker=0),
Word(text=" We", start=40.96, end=41.26, speaker=0),
Word(text=" don't", start=41.26, end=41.44, speaker=0),
Word(text=" need", start=41.44, end=41.66, speaker=0),
Word(text=" that", start=41.66, end=41.82, speaker=0),
Word(text=" at", start=41.82, end=41.94, speaker=0),
Word(text=" the", start=41.94, end=41.98, speaker=0),
Word(text=" middle", start=41.98, end=42.18, speaker=0),
Word(text=" of", start=42.18, end=42.36, speaker=0),
Word(text=" age.", start=42.36, end=42.7, speaker=0),
Word(text=" Okay,", start=43.26, end=43.44, speaker=0),
Word(text=" yeah,", start=43.68, end=43.76, speaker=0),
Word(text=" that", start=43.78, end=44.3, speaker=0),
Word(text=" sentence", start=44.3, end=44.72, speaker=0),
Word(text=" right", start=44.72, end=45.1, speaker=0),
Word(text=" before.", start=45.1, end=45.56, speaker=0),
Word(text=" it", start=46.08, end=46.36, speaker=0),
Word(text=" realizing", start=46.36, end=47.0, speaker=0),
Word(text=" that", start=47.0, end=47.28, speaker=0),
Word(text=" I", start=47.28, end=47.28, speaker=0),
Word(text=" was", start=47.28, end=47.64, speaker=0),
Word(text=" saying", start=47.64, end=48.06, speaker=0),
Word(text=" that", start=48.06, end=48.44, speaker=0),
Word(text=" it's", start=48.44, end=48.54, speaker=0),
Word(text=" interesting", start=48.54, end=48.78, speaker=0),
Word(text=" that", start=48.78, end=48.96, speaker=0),
Word(text=" it's", start=48.96, end=49.08, speaker=0),
Word(text=" translating", start=49.08, end=49.32, speaker=0),
Word(text=" the", start=49.32, end=49.56, speaker=0),
Word(text=" French", start=49.56, end=49.76, speaker=0),
Word(text=" was", start=49.76, end=50.16, speaker=0),
Word(text=" completely", start=50.16, end=50.4, speaker=0),
Word(text=" wrong.", start=50.4, end=50.7, speaker=0),
]
)
segments = transcript.as_segments()
assert len(segments) == 7
# check speaker order
assert segments[0].speaker == 0
assert segments[1].speaker == 0
assert segments[2].speaker == 0
assert segments[3].speaker == 1
assert segments[4].speaker == 2
assert segments[5].speaker == 0
assert segments[6].speaker == 0
# check the timing (first entry, and first of others speakers)
assert segments[0].start == 5.12
assert segments[3].start == 30.72
assert segments[4].start == 31.56
assert segments[5].start == 32.38

View File

@@ -1,3 +1,4 @@
import asyncio
import pytest
import httpx
from reflector.utils.retry import (
@@ -8,6 +9,31 @@ from reflector.utils.retry import (
)
@pytest.mark.asyncio
async def test_retry_redirect(httpx_mock):
async def custom_response(request: httpx.Request):
if request.url.path == "/hello":
await asyncio.sleep(1)
return httpx.Response(
status_code=303, headers={"location": "https://test_url/redirected"}
)
elif request.url.path == "/redirected":
return httpx.Response(status_code=200, json={"hello": "world"})
else:
raise Exception("Unexpected path")
httpx_mock.add_callback(custom_response)
async with httpx.AsyncClient() as client:
# timeout should not triggered, as it will end up ok
# even though the first request is a 303 and took more that 0.5
resp = await retry(client.get)(
"https://test_url/hello",
retry_timeout=0.5,
follow_redirects=True,
)
assert resp.json() == {"hello": "world"}
@pytest.mark.asyncio
async def test_retry_httpx(httpx_mock):
# this code should be force a retry

View File

@@ -196,3 +196,29 @@ async def test_transcript_delete():
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 404
@pytest.mark.asyncio
async def test_transcript_mark_reviewed():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False
response = await ac.patch(f"/transcripts/{tid}", json={"reviewed": True})
assert response.status_code == 200
assert response.json()["reviewed"] is True
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["reviewed"] is True

View File

@@ -46,6 +46,34 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
assert response.status_code == 200
assert response.headers["content-type"] == content_type
# test get 404
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
assert response.status_code == 404
@pytest.mark.asyncio
@pytest.mark.parametrize(
"url_suffix,content_type",
[
["/mp3", "audio/mpeg"],
],
)
async def test_transcript_audio_download_head(
fake_transcript, url_suffix, content_type
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
assert response.status_code == 200
assert response.headers["content-type"] == content_type
# test head 404
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.head(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
assert response.status_code == 404
@pytest.mark.asyncio
@pytest.mark.parametrize(
@@ -90,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek(
assert response.status_code == 206
assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 100-")
@pytest.mark.asyncio
async def test_transcript_audio_download_waveform(fake_transcript):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
assert isinstance(response.json()["data"], list)
assert len(response.json()["data"]) >= 255

View File

@@ -0,0 +1,164 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_participants():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
# create a participant
transcript_id = response.json()["id"]
response = await ac.post(
f"/transcripts/{transcript_id}/participants", json={"name": "test"}
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] is None
assert response.json()["name"] == "test"
# create another one with a speaker
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] == 1
assert response.json()["name"] == "test2"
# get all participants via transcript
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 2
# get participants via participants endpoint
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
assert len(response.json()) == 2
@pytest.mark.asyncio
async def test_transcript_participants_same_speaker():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# create another one with the same speaker
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_transcript_participants_update_name():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# update the participant
participant_id = response.json()["id"]
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant_id}",
json={"name": "test2"},
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated
response = await ac.get(
f"/transcripts/{transcript_id}/participants/{participant_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated in transcript
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 1
assert response.json()["participants"][0]["name"] == "test2"
@pytest.mark.asyncio
async def test_transcript_participants_update_speaker():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
# create another participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 2},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
# update the participant, refused as speaker is already taken
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 400
# delete the participant 1
response = await ac.delete(
f"/transcripts/{transcript_id}/participants/{participant1_id}"
)
assert response.status_code == 200
# update the participant 2 again, should be accepted now
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 200
# ensure participant2 name is still there
response = await ac.get(
f"/transcripts/{transcript_id}/participants/{participant2_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
assert response.json()["speaker"] == 1

View File

@@ -32,7 +32,7 @@ class ThreadedUvicorn:
@pytest.fixture
async def appserver(tmpdir):
async def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
from reflector.settings import settings
from reflector.app import app
@@ -52,12 +52,23 @@ async def appserver(tmpdir):
settings.DATA_DIR = DATA_DIR
@pytest.fixture(scope="session")
def celery_includes():
return ["reflector.pipelines.main_live_pipeline"]
@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(
tmpdir,
dummy_llm,
dummy_transcript,
dummy_processors,
dummy_diarization,
dummy_storage,
fake_mp3_upload,
ensure_casing,
appserver,
sentence_tokenize,
@@ -95,6 +106,7 @@ async def test_transcript_rtc_and_websocket(
print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
print("Test websocket: TASK CREATED", websocket_task)
# create stream client
import argparse
@@ -121,14 +133,20 @@ async def test_transcript_rtc_and_websocket(
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2)
await client.stop()
# wait the processing to finish
await asyncio.sleep(2)
timeout = 20
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended")
# stop websocket task
websocket_task.cancel()
@@ -167,31 +185,47 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
assert "WAVEFORM" in eventnames
ev = events[eventnames.index("WAVEFORM")]
assert isinstance(ev["data"]["waveform"], list)
assert len(ev["data"]["waveform"]) >= 250
waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
assert waveform_resp.status_code == 200
assert waveform_resp.headers["content-type"] == "application/json"
assert isinstance(waveform_resp.json()["data"], list)
assert len(waveform_resp.json()["data"]) >= 250
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
assert statuses.index("recording") < statuses.index("processing")
assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# check that transcript status in model is updated
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"
# check on the latest response that the audio duration is > 0
assert resp.json()["duration"] > 0
assert "DURATION" in eventnames
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mpeg"
audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert audio_resp.status_code == 200
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(
tmpdir,
dummy_llm,
dummy_transcript,
dummy_processors,
dummy_diarization,
dummy_storage,
fake_mp3_upload,
ensure_casing,
appserver,
sentence_tokenize,
@@ -232,6 +266,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
print("Test websocket: TASK CREATED", websocket_task)
# create stream client
import argparse
@@ -265,6 +300,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
await client.stop()
# wait the processing to finish
timeout = 20
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] == "ended":
break
await asyncio.sleep(1)
if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended")
await asyncio.sleep(2)
# stop websocket task
@@ -306,7 +353,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
assert statuses.index("recording") < statuses.index("processing")
assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"

View File

@@ -0,0 +1,401 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_reassign_speaker(fake_transcript_with_topics):
from reflector.app import app
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
# check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 0
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker, middle of 2 topics
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 2,
"timestamp_from": 1,
"timestamp_to": 2.5,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 2
assert topics[1]["words"][0]["speaker"] == 2
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 2
assert topics[0]["segments"][0]["speaker"] == 1
assert topics[0]["segments"][1]["speaker"] == 2
assert len(topics[1]["segments"]) == 2
assert topics[1]["segments"][0]["speaker"] == 2
assert topics[1]["segments"][1]["speaker"] == 0
# reassign speaker, everything
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 4,
"timestamp_from": 0,
"timestamp_to": 100,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 4
assert topics[0]["words"][1]["speaker"] == 4
assert topics[1]["words"][0]["speaker"] == 4
assert topics[1]["words"][1]["speaker"] == 4
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 4
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 4
@pytest.mark.asyncio
async def test_transcript_merge_speaker(fake_transcript_with_topics):
from reflector.app import app
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
# check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# reassign speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# merge speakers
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/merge",
json={
"speaker_from": 1,
"speaker_to": 0,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
@pytest.mark.asyncio
async def test_transcript_reassign_with_participant(fake_transcript_with_topics):
from reflector.app import app
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
transcript = response.json()
assert len(transcript["participants"]) == 0
# create 2 participants
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={
"name": "Participant 1",
},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={
"name": "Participant 2",
},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
# check participants speakers
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] is None
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] is None
# check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 0
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker from a participant
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant1_id,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# check participants if speaker has been assigned
# first participant should have 1, because it's not used yet.
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] == 1
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] is None
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign participant, middle of 2 topics
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant2_id,
"timestamp_from": 1,
"timestamp_to": 2.5,
},
)
assert response.status_code == 200
# check participants if speaker has been assigned
# first participant should have 1, because it's not used yet.
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] == 1
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] == 2
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 2
assert topics[1]["words"][0]["speaker"] == 2
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 2
assert topics[0]["segments"][0]["speaker"] == 1
assert topics[0]["segments"][1]["speaker"] == 2
assert len(topics[1]["segments"]) == 2
assert topics[1]["segments"][0]["speaker"] == 2
assert topics[1]["segments"][1]["speaker"] == 0
# reassign speaker, everything
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant1_id,
"timestamp_from": 0,
"timestamp_to": 100,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 1
assert topics[1]["words"][1]["speaker"] == 1
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 1
@pytest.mark.asyncio
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics):
from reflector.app import app
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
transcript = response.json()
assert len(transcript["participants"]) == 0
# try reassign without any participant_id or speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 400
# try reassing with both participant_id and speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": "123",
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 400
# try reassing with non-existing participant_id
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": "123",
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 404

View File

@@ -0,0 +1,26 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_topics(fake_transcript_with_topics):
from reflector.app import app
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}/topics")
assert response.status_code == 200
assert len(response.json()) == 2
topic_id = response.json()[0]["id"]
# get words per speakers
response = await ac.get(
f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker"
)
assert response.status_code == 200
data = response.json()
assert len(data["words_per_speaker"]) == 1
assert data["words_per_speaker"][0]["speaker"] == 0
assert len(data["words_per_speaker"][0]["words"]) == 2

View File

@@ -0,0 +1,61 @@
import pytest
import asyncio
from httpx import AsyncClient
@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_upload_file(
tmpdir,
ensure_casing,
dummy_llm,
dummy_processors,
dummy_diarization,
dummy_storage,
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
# create a transcript
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["status"] == "idle"
tid = response.json()["id"]
# upload mp3
response = await ac.post(
f"/transcripts/{tid}/record/upload",
files={
"file": (
"test_short.wav",
open("tests/records/test_short.wav", "rb"),
"audio/mpeg",
)
},
)
assert response.status_code == 200
assert response.json()["status"] == "ok"
# wait the processing to finish
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
# check the transcript is ended
transcript = resp.json()
assert transcript["status"] == "ended"
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
assert transcript["title"] == "LLM TITLE"
# check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics")
assert response.status_code == 200
assert len(response.json()) == 1
assert "want to share" in response.json()[0]["transcript"]

2
www/.env_template Normal file
View File

@@ -0,0 +1,2 @@
FIEF_CLIENT_SECRET=<omitted, ask in zulip>
ZULIP_API_KEY=<omitted, ask in zulip>

2
www/.gitignore vendored
View File

@@ -39,3 +39,5 @@ next-env.d.ts
# Sentry Auth Token
.sentryclirc
config.ts

View File

@@ -1,11 +1,18 @@
"use client";
import { FiefAuthProvider } from "@fief/fief/nextjs/react";
import { createContext } from "react";
export default function FiefWrapper({ children }) {
export const CookieContext = createContext<{ hasAuthCookie: boolean }>({
hasAuthCookie: false,
});
export default function FiefWrapper({ children, hasAuthCookie }) {
return (
<FiefAuthProvider currentUserPath="/api/current-user">
{children}
</FiefAuthProvider>
<CookieContext.Provider value={{ hasAuthCookie }}>
<FiefAuthProvider currentUserPath="/api/current-user">
{children}
</FiefAuthProvider>
</CookieContext.Provider>
);
}

View File

@@ -1,29 +1,23 @@
"use client";
import {
useFiefIsAuthenticated,
useFiefUserinfo,
} from "@fief/fief/nextjs/react";
import { useFiefIsAuthenticated } from "@fief/fief/nextjs/react";
import Link from "next/link";
export default function UserInfo() {
const isAuthenticated = useFiefIsAuthenticated();
const userinfo = useFiefUserinfo();
return !isAuthenticated ? (
<span className="hover:underline focus-within:underline underline-offset-2 decoration-[.5px] font-light px-2">
<Link href="/login" className="outline-none">
Log in or create account
<Link href="/login" className="outline-none" prefetch={false}>
Log in
</Link>
</span>
) : (
<span className="font-light px-2">
{userinfo?.email} (
<span className="hover:underline focus-within:underline underline-offset-2 decoration-[.5px]">
<Link href="/logout" className="outline-none">
<Link href="/logout" className="outline-none" prefetch={false}>
Log out
</Link>
</span>
)
</span>
);
}

View File

@@ -3,7 +3,8 @@ import React, { createContext, useContext, useState } from "react";
interface ErrorContextProps {
error: Error | null;
setError: React.Dispatch<React.SetStateAction<Error | null>>;
humanMessage?: string;
setError: (error: Error, humanMessage?: string) => void;
}
const ErrorContext = createContext<ErrorContextProps | undefined>(undefined);
@@ -22,9 +23,16 @@ interface ErrorProviderProps {
export const ErrorProvider: React.FC<ErrorProviderProps> = ({ children }) => {
const [error, setError] = useState<Error | null>(null);
const [humanMessage, setHumanMessage] = useState<string | undefined>();
const declareError = (error, humanMessage?) => {
setError(error);
setHumanMessage(humanMessage);
};
return (
<ErrorContext.Provider value={{ error, setError }}>
<ErrorContext.Provider
value={{ error, setError: declareError, humanMessage }}
>
{children}
</ErrorContext.Provider>
);

View File

@@ -4,29 +4,51 @@ import { useEffect, useState } from "react";
import * as Sentry from "@sentry/react";
const ErrorMessage: React.FC = () => {
const { error, setError } = useError();
const { error, setError, humanMessage } = useError();
const [isVisible, setIsVisible] = useState<boolean>(false);
// Setup Shortcuts
useEffect(() => {
const handleKeyPress = (event: KeyboardEvent) => {
switch (event.key) {
case "^":
throw new Error("Unhandled Exception thrown by '^' shortcut");
case "$":
setError(
new Error("Unhandled Exception thrown by '$' shortcut"),
"You did this to yourself",
);
}
};
document.addEventListener("keydown", handleKeyPress);
return () => document.removeEventListener("keydown", handleKeyPress);
}, []);
useEffect(() => {
if (error) {
setIsVisible(true);
Sentry.captureException(error);
console.error("Error", error.message, error);
if (humanMessage) {
setIsVisible(true);
Sentry.captureException(Error(humanMessage, { cause: error }));
} else {
Sentry.captureException(error);
}
console.error("Error", error);
}
}, [error]);
if (!isVisible || !error) return null;
if (!isVisible || !humanMessage) return null;
return (
<button
onClick={() => {
setIsVisible(false);
setError(null);
}}
className="max-w-xs z-50 fixed bottom-5 right-5 md:bottom-10 md:right-10 border-solid bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded transition-opacity duration-300 ease-out opacity-100 hover:opacity-80 focus-visible:opacity-80 cursor-pointer transform hover:scale-105 focus-visible:scale-105"
role="alert"
>
<span className="block sm:inline">{error?.message}</span>
<span className="block sm:inline">{humanMessage}</span>
</button>
);
};

View File

@@ -0,0 +1,94 @@
"use client";
import React, { useState } from "react";
import { GetTranscript } from "../../api";
import { Title } from "../../lib/textComponents";
import Pagination from "./pagination";
import Link from "next/link";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { faGear } from "@fortawesome/free-solid-svg-icons";
import useTranscriptList from "../transcripts/useTranscriptList";
export default function TranscriptBrowser() {
const [page, setPage] = useState<number>(1);
const { loading, response } = useTranscriptList(page);
return (
<div>
{/*
<div className="flex flex-row gap-2">
<input className="text-sm p-2 w-80 ring-1 ring-slate-900/10 shadow-sm rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500 caret-blue-500" placeholder="Search" />
</div>
*/}
<div className="flex flex-row gap-2 items-center">
<Title className="mb-5 mt-5 flex-1">Past transcripts</Title>
<Pagination
page={page}
setPage={setPage}
total={response?.total || 0}
size={response?.size || 0}
/>
</div>
{loading && (
<div className="full-screen flex flex-col items-center justify-center">
<FontAwesomeIcon
icon={faGear}
className="animate-spin-slow h-14 w-14 md:h-20 md:w-20"
/>
</div>
)}
{!loading && !response && (
<div className="text-gray-500">
No transcripts found, but you can&nbsp;
<Link href="/transcripts/new" className="underline">
record a meeting
</Link>
&nbsp;to get started.
</div>
)}
<div /** center and max 900px wide */ className="mx-auto max-w-[900px]">
<div className="grid grid-cols-1 gap-2 lg:gap-4 h-full">
{response?.items.map((item: GetTranscript) => (
<div
key={item.id}
className="flex flex-col bg-blue-400/20 rounded-lg md:rounded-xl p-2 md:px-4"
>
<div className="flex flex-col">
<div className="flex flex-row gap-2 items-start">
<Link
href={`/transcripts/${item.id}`}
className="text-1xl font-semibold flex-1 pl-0 hover:underline focus-within:underline underline-offset-2 decoration-[.5px] font-light px-2"
>
{item.title || item.name}
</Link>
{item.locked ? (
<div className="inline-block bg-red-500 text-white px-2 py-1 rounded-full text-xs font-semibold">
Locked
</div>
) : (
<></>
)}
{item.source_language ? (
<div className="inline-block bg-blue-500 text-white px-2 py-1 rounded-full text-xs font-semibold">
{item.source_language}
</div>
) : (
<></>
)}
</div>
<div className="text-xs text-gray-700">
{new Date(item.created_at).toLocaleDateString("en-US")}
</div>
<div className="text-sm">{item.short_summary}</div>
</div>
</div>
))}
</div>
</div>
</div>
);
}

View File

@@ -0,0 +1,75 @@
type PaginationProps = {
page: number;
setPage: (page: number) => void;
total: number;
size: number;
};
export default function Pagination(props: PaginationProps) {
const { page, setPage, total, size } = props;
const totalPages = Math.ceil(total / size);
const pageNumbers = Array.from(
{ length: totalPages },
(_, i) => i + 1,
).filter((pageNumber) => {
if (totalPages <= 3) {
// If there are 3 or fewer total pages, show all pages.
return true;
} else if (page <= 2) {
// For the first two pages, show the first 3 pages.
return pageNumber <= 3;
} else if (page >= totalPages - 1) {
// For the last two pages, show the last 3 pages.
return pageNumber >= totalPages - 2;
} else {
// For all other cases, show 3 pages centered around the current page.
return pageNumber >= page - 1 && pageNumber <= page + 1;
}
});
const canGoPrevious = page > 1;
const canGoNext = page < totalPages;
const handlePageChange = (newPage: number) => {
if (newPage >= 1 && newPage <= totalPages) {
setPage(newPage);
}
};
return (
<div className="flex justify-center space-x-4 my-4">
<button
className={`w-10 h-10 rounded-full p-2 border border-gray-300 disabled:bg-white ${
canGoPrevious ? "text-gray-500" : "text-gray-300"
}`}
onClick={() => handlePageChange(page - 1)}
disabled={!canGoPrevious}
>
<i className="fa fa-chevron-left">&lt;</i>
</button>
{pageNumbers.map((pageNumber) => (
<button
key={pageNumber}
className={`w-10 h-10 rounded-full p-2 border ${
page === pageNumber ? "border-gray-600" : "border-gray-300"
} rounded`}
onClick={() => handlePageChange(pageNumber)}
>
{pageNumber}
</button>
))}
<button
className={`w-10 h-10 rounded-full p-2 border border-gray-300 disabled:bg-white ${
canGoNext ? "text-gray-500" : "text-gray-300"
}`}
onClick={() => handlePageChange(page + 1)}
disabled={!canGoNext}
>
<i className="fa fa-chevron-right">&gt;</i>
</button>
</div>
);
}

View File

@@ -0,0 +1,50 @@
"use client";
import { createContext, useContext, useEffect, useState } from "react";
import { DomainConfig } from "../lib/edgeConfig";
type DomainContextType = Omit<DomainConfig, "auth_callback_url">;
export const DomainContext = createContext<DomainContextType>({
features: {
requireLogin: false,
privacy: true,
browse: false,
sendToZulip: false,
},
api_url: "",
websocket_url: "",
zulip_streams: "",
});
export const DomainContextProvider = ({
config,
children,
}: {
config: DomainConfig;
children: any;
}) => {
const [context, setContext] = useState<DomainContextType>();
useEffect(() => {
if (!config) return;
const { auth_callback_url, ...others } = config;
setContext(others);
}, [config]);
if (!context) return;
return (
<DomainContext.Provider value={context}>{children}</DomainContext.Provider>
);
};
// Get feature config client-side with
export const featureEnabled = (
featureName: "requireLogin" | "privacy" | "browse" | "sendToZulip",
) => {
const context = useContext(DomainContext);
return context.features[featureName] as boolean | undefined;
};
// Get config server-side (out of react) : see lib/edgeConfig.

166
www/app/[domain]/layout.tsx Normal file
View File

@@ -0,0 +1,166 @@
import "../styles/globals.scss";
import { Poppins } from "next/font/google";
import { Metadata, Viewport } from "next";
import FiefWrapper from "../(auth)/fiefWrapper";
import UserInfo from "../(auth)/userInfo";
import { ErrorProvider } from "../(errors)/errorContext";
import ErrorMessage from "../(errors)/errorMessage";
import Image from "next/image";
import Link from "next/link";
import About from "../(aboutAndPrivacy)/about";
import Privacy from "../(aboutAndPrivacy)/privacy";
import { DomainContextProvider } from "./domainContext";
import { getConfig } from "../lib/edgeConfig";
import { ErrorBoundary } from "@sentry/nextjs";
import { cookies } from "next/dist/client/components/headers";
import { SESSION_COOKIE_NAME } from "../lib/fief";
const poppins = Poppins({ subsets: ["latin"], weight: ["200", "400", "600"] });
export const viewport: Viewport = {
themeColor: "black",
width: "device-width",
initialScale: 1,
maximumScale: 1,
};
export const metadata: Metadata = {
metadataBase: new URL(process.env.DEV_URL || "https://reflector.media"),
title: {
template: "%s Reflector",
default: "Reflector - AI-Powered Meeting Transcriptions by Monadical",
},
description:
"Reflector is an AI-powered tool that transcribes your meetings with unparalleled accuracy, divides content by topics, and provides insightful summaries. Maximize your productivity with Reflector, brought to you by Monadical. Capture the signal, not the noise",
applicationName: "Reflector",
referrer: "origin-when-cross-origin",
keywords: ["Reflector", "Monadical", "AI", "Meetings", "Transcription"],
authors: [{ name: "Monadical Team", url: "https://monadical.com/team.html" }],
formatDetection: {
email: false,
address: false,
telephone: false,
},
openGraph: {
title: "Reflector",
description:
"Reflector is an AI-powered tool that transcribes your meetings with unparalleled accuracy, divides content by topics, and provides insightful summaries. Maximize your productivity with Reflector, brought to you by Monadical. Capture the signal, not the noise.",
type: "website",
},
twitter: {
card: "summary_large_image",
title: "Reflector",
description:
"Reflector is an AI-powered tool that transcribes your meetings with unparalleled accuracy, divides content by topics, and provides insightful summaries. Maximize your productivity with Reflector, brought to you by Monadical. Capture the signal, not the noise.",
images: ["/r-icon.png"],
},
icons: {
icon: "/r-icon.png",
shortcut: "/r-icon.png",
apple: "/r-icon.png",
},
robots: { index: false, follow: false, noarchive: true, noimageindex: true },
};
type LayoutProps = {
params: {
domain: string;
};
children: any;
};
export default async function RootLayout({ children, params }: LayoutProps) {
const config = await getConfig(params.domain);
const { requireLogin, privacy, browse } = config.features;
const hasAuthCookie = !!cookies().get(SESSION_COOKIE_NAME);
return (
<html lang="en">
<body className={poppins.className + " h-screen relative"}>
<FiefWrapper hasAuthCookie={hasAuthCookie}>
<DomainContextProvider config={config}>
<ErrorBoundary fallback={<p>"something went really wrong"</p>}>
<ErrorProvider>
<ErrorMessage />
<div
id="container"
className="items-center h-[100svh] w-[100svw] p-2 md:p-4 grid grid-rows-layout gap-2 md:gap-4"
>
<header className="flex justify-between items-center w-full">
{/* Logo on the left */}
<Link
href="/"
className="flex outline-blue-300 md:outline-none focus-visible:underline underline-offset-2 decoration-[.5px] decoration-gray-500"
>
<Image
src="/reach.png"
width={16}
height={16}
className="h-10 w-auto"
alt="Reflector"
/>
<div className="hidden flex-col ml-2 md:block">
<h1 className="text-[38px] font-bold tracking-wide leading-tight">
Reflector
</h1>
<p className="text-gray-500 text-xs tracking-tighter">
Capture the signal, not the noise
</p>
</div>
</Link>
<div>
{/* Text link on the right */}
<Link
href="/transcripts/new"
className="hover:underline focus-within:underline underline-offset-2 decoration-[.5px] font-light px-2"
>
Create
</Link>
{browse ? (
<>
&nbsp;·&nbsp;
<Link
href="/browse"
className="hover:underline focus-within:underline underline-offset-2 decoration-[.5px] font-light px-2"
prefetch={false}
>
Browse
</Link>
</>
) : (
<></>
)}
&nbsp;·&nbsp;
<About buttonText="About" />
{privacy ? (
<>
&nbsp;·&nbsp;
<Privacy buttonText="Privacy" />
</>
) : (
<></>
)}
{requireLogin ? (
<>
&nbsp;·&nbsp;
<UserInfo />
</>
) : (
<></>
)}
</div>
</header>
{children}
</div>
</ErrorProvider>
</ErrorBoundary>
</DomainContextProvider>
</FiefWrapper>
</body>
</html>
);
}

View File

@@ -0,0 +1,154 @@
"use client";
import Modal from "../modal";
import useTranscript from "../useTranscript";
import useTopics from "../useTopics";
import useWaveform from "../useWaveform";
import useMp3 from "../useMp3";
import { TopicList } from "../topicList";
import { Topic } from "../webSocketTypes";
import React, { useEffect, useState } from "react";
import "../../../styles/button.css";
import FinalSummary from "../finalSummary";
import ShareLink from "../shareLink";
import QRCode from "react-qr-code";
import TranscriptTitle from "../transcriptTitle";
import ShareModal from "./shareModal";
import Player from "../player";
import WaveformLoading from "../waveformLoading";
import { useRouter } from "next/navigation";
import { featureEnabled } from "../../domainContext";
import { toShareMode } from "../../../lib/shareMode";
type TranscriptDetails = {
params: {
transcriptId: string;
};
};
export default function TranscriptDetails(details: TranscriptDetails) {
const transcriptId = details.params.transcriptId;
const router = useRouter();
const transcript = useTranscript(transcriptId);
const topics = useTopics(transcriptId);
const waveform = useWaveform(transcriptId);
const useActiveTopic = useState<Topic | null>(null);
const mp3 = useMp3(transcriptId);
const [showModal, setShowModal] = useState(false);
useEffect(() => {
const statusToRedirect = ["idle", "recording", "processing"];
if (statusToRedirect.includes(transcript.response?.status)) {
const newUrl = "/transcripts/" + details.params.transcriptId + "/record";
// Shallow redirection does not work on NextJS 13
// https://github.com/vercel/next.js/discussions/48110
// https://github.com/vercel/next.js/discussions/49540
router.push(newUrl, undefined);
// history.replaceState({}, "", newUrl);
}
}, [transcript.response?.status]);
const fullTranscript =
topics.topics
?.map((topic) => topic.transcript)
.join("\n\n")
.replace(/ +/g, " ")
.trim() || "";
if (transcript && transcript.response) {
if (transcript.error || topics?.error) {
return (
<Modal
title="Transcription Not Found"
text="A trascription with this ID does not exist."
/>
);
}
if (!transcriptId || transcript?.loading || topics?.loading) {
return <Modal title="Loading" text={"Loading transcript..."} />;
}
return (
<>
{featureEnabled("sendToZulip") && (
<ShareModal
transcript={transcript.response}
topics={topics ? topics.topics : null}
show={showModal}
setShow={(v) => setShowModal(v)}
/>
)}
<div className="flex flex-col">
{transcript?.response?.title && (
<TranscriptTitle
title={transcript.response.title}
transcriptId={transcript.response.id}
/>
)}
{waveform.waveform && mp3.media ? (
<Player
topics={topics?.topics || []}
useActiveTopic={useActiveTopic}
waveform={waveform.waveform}
media={mp3.media}
mediaDuration={transcript.response.duration}
/>
) : waveform.error ? (
<div>"error loading this recording"</div>
) : (
<WaveformLoading />
)}
</div>
<div className="grid grid-cols-1 lg:grid-cols-2 grid-rows-2 lg:grid-rows-1 gap-2 lg:gap-4 h-full">
<TopicList
topics={topics.topics || []}
useActiveTopic={useActiveTopic}
autoscroll={false}
/>
<div className="w-full h-full grid grid-rows-layout-one grid-cols-1 gap-2 lg:gap-4">
<section className=" bg-blue-400/20 rounded-lg md:rounded-xl p-2 md:px-4 h-full">
{transcript.response.long_summary ? (
<FinalSummary
fullTranscript={fullTranscript}
summary={transcript.response.long_summary}
transcriptId={transcript.response.id}
openZulipModal={() => setShowModal(true)}
/>
) : (
<div className="flex flex-col h-full justify-center content-center">
{transcript.response.status == "processing" ? (
<p>Loading Transcript</p>
) : (
<p>
There was an error generating the final summary, please
come back later
</p>
)}
</div>
)}
</section>
<section className="flex items-center">
<div className="mr-4 hidden md:block h-auto">
<QRCode
value={`${location.origin}/transcripts/${details.params.transcriptId}`}
level="L"
size={98}
/>
</div>
<div className="flex-grow max-w-full">
<ShareLink
transcriptId={transcript?.response?.id}
userId={transcript?.response?.user_id}
shareMode={toShareMode(transcript?.response?.share_mode)}
/>
</div>
</section>
</div>
</div>
</>
);
}
}

View File

@@ -6,14 +6,17 @@ import useWebRTC from "../../useWebRTC";
import useTranscript from "../../useTranscript";
import { useWebSockets } from "../../useWebSockets";
import useAudioDevice from "../../useAudioDevice";
import "../../../styles/button.css";
import "../../../../styles/button.css";
import { Topic } from "../../webSocketTypes";
import getApi from "../../../lib/getApi";
import LiveTrancription from "../../liveTranscription";
import DisconnectedIndicator from "../../disconnectedIndicator";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { faGear } from "@fortawesome/free-solid-svg-icons";
import { lockWakeState, releaseWakeState } from "../../../lib/wakeLock";
import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock";
import { useRouter } from "next/navigation";
import Player from "../../player";
import useMp3 from "../../useMp3";
import WaveformLoading from "../../waveformLoading";
type TranscriptDetails = {
params: {
@@ -37,14 +40,17 @@ const TranscriptRecord = (details: TranscriptDetails) => {
}, []);
const transcript = useTranscript(details.params.transcriptId);
const api = getApi();
const webRTC = useWebRTC(stream, details.params.transcriptId, api);
const webRTC = useWebRTC(stream, details.params.transcriptId);
const webSockets = useWebSockets(details.params.transcriptId);
const { audioDevices, getAudioStream } = useAudioDevice();
const [hasRecorded, setHasRecorded] = useState(false);
const [recordedTime, setRecordedTime] = useState(0);
const [startTime, setStartTime] = useState(0);
const [transcriptStarted, setTranscriptStarted] = useState(false);
let mp3 = useMp3(details.params.transcriptId, true);
const router = useRouter();
useEffect(() => {
if (!transcriptStarted && webSockets.transcriptText.length !== 0)
@@ -52,15 +58,25 @@ const TranscriptRecord = (details: TranscriptDetails) => {
}, [webSockets.transcriptText]);
useEffect(() => {
if (transcript?.response?.longSummary) {
const newUrl = `/transcripts/${transcript.response.id}`;
const statusToRedirect = ["ended", "error"];
//TODO if has no topic and is error, get back to new
if (
statusToRedirect.includes(transcript.response?.status) ||
statusToRedirect.includes(webSockets.status.value)
) {
const newUrl = "/transcripts/" + details.params.transcriptId;
// Shallow redirection does not work on NextJS 13
// https://github.com/vercel/next.js/discussions/48110
// https://github.com/vercel/next.js/discussions/49540
// router.push(newUrl, undefined, { shallow: true });
history.replaceState({}, "", newUrl);
}
});
router.replace(newUrl);
// history.replaceState({}, "", newUrl);
} // history.replaceState({}, "", newUrl);
}, [webSockets.status.value, transcript.response?.status]);
useEffect(() => {
if (transcript.response?.status === "ended") mp3.getNow();
}, [transcript.response]);
useEffect(() => {
lockWakeState();
@@ -71,19 +87,32 @@ const TranscriptRecord = (details: TranscriptDetails) => {
return (
<>
<Recorder
setStream={setStream}
onStop={() => {
setStream(null);
setHasRecorded(true);
webRTC?.send(JSON.stringify({ cmd: "STOP" }));
}}
topics={webSockets.topics}
getAudioStream={getAudioStream}
useActiveTopic={useActiveTopic}
isPastMeeting={false}
audioDevices={audioDevices}
/>
{webSockets.waveform && webSockets.duration && mp3?.media ? (
<Player
topics={webSockets.topics || []}
useActiveTopic={useActiveTopic}
waveform={webSockets.waveform}
media={mp3.media}
mediaDuration={webSockets.duration}
/>
) : recordedTime ? (
<WaveformLoading />
) : (
<Recorder
setStream={setStream}
onStop={() => {
setStream(null);
setRecordedTime(Date.now() - startTime);
webRTC?.send(JSON.stringify({ cmd: "STOP" }));
}}
onRecord={() => {
setStartTime(Date.now());
}}
getAudioStream={getAudioStream}
audioDevices={audioDevices}
transcriptId={details.params.transcriptId}
/>
)}
<div className="grid grid-cols-1 lg:grid-cols-2 grid-rows-mobile-inner lg:grid-rows-1 gap-2 lg:gap-4 h-full">
<TopicList
@@ -95,7 +124,7 @@ const TranscriptRecord = (details: TranscriptDetails) => {
<section
className={`w-full h-full bg-blue-400/20 rounded-lg md:rounded-xl p-2 md:px-4`}
>
{!hasRecorded ? (
{!recordedTime ? (
<>
{transcriptStarted && (
<h2 className="md:text-lg font-bold">Transcription</h2>
@@ -129,6 +158,7 @@ const TranscriptRecord = (details: TranscriptDetails) => {
couple of minutes. Please do not navigate away from the page
during this time.
</p>
{/* NTH If login required remove last sentence */}
</div>
)}
</section>

View File

@@ -0,0 +1,159 @@
import React, { useContext, useState, useEffect } from "react";
import SelectSearch from "react-select-search";
import { getZulipMessage, sendZulipMessage } from "../../../lib/zulip";
import { GetTranscript, GetTranscriptTopic } from "../../../api";
import "react-select-search/style.css";
import { DomainContext } from "../../domainContext";
type ShareModal = {
show: boolean;
setShow: (show: boolean) => void;
transcript: GetTranscript | null;
topics: GetTranscriptTopic[] | null;
};
interface Stream {
id: number;
name: string;
topics: string[];
}
interface SelectSearchOption {
name: string;
value: string;
}
const ShareModal = (props: ShareModal) => {
const [stream, setStream] = useState<string | undefined>(undefined);
const [topic, setTopic] = useState<string | undefined>(undefined);
const [includeTopics, setIncludeTopics] = useState(false);
const [isLoading, setIsLoading] = useState(true);
const [streams, setStreams] = useState<Stream[]>([]);
const { zulip_streams } = useContext(DomainContext);
useEffect(() => {
fetch(zulip_streams + "/streams.json")
.then((response) => {
if (!response.ok) {
throw new Error("Network response was not ok");
}
return response.json();
})
.then((data) => {
data = data.sort((a: Stream, b: Stream) =>
a.name.localeCompare(b.name),
);
setStreams(data);
setIsLoading(false);
// data now contains the JavaScript object decoded from JSON
})
.catch((error) => {
console.error("There was a problem with your fetch operation:", error);
});
}, []);
const handleSendToZulip = () => {
if (!props.transcript) return;
const msg = getZulipMessage(props.transcript, props.topics, includeTopics);
if (stream && topic) sendZulipMessage(stream, topic, msg);
};
if (props.show && isLoading) {
return <div>Loading...</div>;
}
let streamOptions: SelectSearchOption[] = [];
if (streams) {
streams.forEach((stream) => {
const value = stream.name;
streamOptions.push({ name: value, value: value });
});
}
return (
<div className="absolute">
{props.show && (
<div className="fixed inset-0 bg-gray-600 bg-opacity-50 overflow-y-auto h-full w-full z-50">
<div className="relative top-20 mx-auto p-5 w-96 shadow-lg rounded-md bg-white">
<div className="mt-3 text-center">
<h3 className="font-bold text-xl">Send to Zulip</h3>
{/* Checkbox for 'Include Topics' */}
<div className="mt-4 text-left ml-5">
<label className="flex items-center">
<input
type="checkbox"
className="form-checkbox rounded border-gray-300 text-indigo-600 shadow-sm focus:border-indigo-300 focus:ring focus:ring-indigo-200 focus:ring-opacity-50"
checked={includeTopics}
onChange={(e) => setIncludeTopics(e.target.checked)}
/>
<span className="ml-2">Include topics</span>
</label>
</div>
<div className="flex items-center mt-4">
<span className="mr-2">#</span>
<SelectSearch
search={true}
options={streamOptions}
value={stream}
onChange={(val) => {
setTopic(undefined);
setStream(val.toString());
}}
placeholder="Pick a stream"
/>
</div>
{stream && (
<>
<div className="flex items-center mt-4">
<span className="mr-2 invisible">#</span>
<SelectSearch
search={true}
options={
streams
.find((s) => s.name == stream)
?.topics.sort((a: string, b: string) =>
a.localeCompare(b),
)
.map((t) => ({ name: t, value: t })) || []
}
value={topic}
onChange={(val) => setTopic(val.toString())}
placeholder="Pick a topic"
/>
</div>
</>
)}
<button
className={`bg-blue-400 hover:bg-blue-500 focus-visible:bg-blue-500 text-white rounded py-2 px-4 mr-3 ${
!stream || !topic ? "opacity-50 cursor-not-allowed" : ""
}`}
disabled={!stream || !topic}
onClick={() => {
handleSendToZulip();
props.setShow(false);
}}
>
Send to Zulip
</button>
<button
className="bg-red-500 hover:bg-red-700 focus-visible:bg-red-700 text-white rounded py-2 px-4 mt-4"
onClick={() => props.setShow(false)}
>
Close
</button>
</div>
</div>
</div>
)}
</div>
);
};
export default ShareModal;

Some files were not shown because too many files have changed in this diff Show More