Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)

Merge branch 'fix-mp3-download-while-authenticated' into sara/fix-api-auth
.github/workflows/test_server.yml (vendored, 5 changed lines)
@@ -11,6 +11,11 @@ on:
jobs:
  pytest:
    runs-on: ubuntu-latest
    services:
      redis:
        image: redis:6
        ports:
          - 6379:6379
    steps:
      - uses: actions/checkout@v3
      - name: Install poetry
README.md (12 changed lines)
@@ -6,7 +6,7 @@ The project architecture consists of three primary components:

* **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
* **Back-End**: Python server that offers an API and data persistence, found in `server/`.
* **AI Models**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
* **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.

It also uses https://github.com/fief-dev for authentication, and Vercel for deployment and configuration of the front-end.

@@ -120,6 +120,9 @@ TRANSCRIPT_MODAL_API_KEY=<omitted>
LLM_BACKEND=modal
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
LLM_MODAL_API_KEY=<omitted>
TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run

AUTH_BACKEND=fief
AUTH_FIEF_URL=https://auth.reflector.media/reflector-local

@@ -138,6 +141,10 @@ Use:
poetry run python3 -m reflector.app
```

And start the background worker

celery -A reflector.worker.app worker --loglevel=info

#### Using docker

Use:

@@ -161,4 +168,5 @@ poetry run python -m reflector.tools.process path/to/audio.wav

## AI Models

*(Documentation for this section is pending.)*
@@ -5,10 +5,19 @@ services:
      context: server
    ports:
      - 1250:1250
    environment:
      LLM_URL: "${LLM_URL}"
    volumes:
      - model-cache:/root/.cache
    environment: ENTRYPOINT=server
  worker:
    build:
      context: server
    volumes:
      - model-cache:/root/.cache
    environment: ENTRYPOINT=worker
  redis:
    image: redis:7.2
    ports:
      - 6379:6379
  web:
    build:
      context: www

@@ -17,4 +26,3 @@ services:

volumes:
  model-cache:
@@ -5,11 +5,23 @@ services:
      context: .
    ports:
      - 1250:1250
    environment:
      LLM_URL: "${LLM_URL}"
      MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
    volumes:
      - model-cache:/root/.cache
    environment:
      ENTRYPOINT: server
      REDIS_HOST: redis
  worker:
    build:
      context: .
    volumes:
      - model-cache:/root/.cache
    environment:
      ENTRYPOINT: worker
      REDIS_HOST: redis
  redis:
    image: redis:7.2
    ports:
      - 6379:6379

volumes:
  model-cache:
@@ -0,0 +1,80 @@
|
||||
"""rename back text to transcript
|
||||
|
||||
Revision ID: 38a927dcb099
|
||||
Revises: 9920ecfe2735
|
||||
Create Date: 2023-11-02 19:53:09.116240
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '38a927dcb099'
|
||||
down_revision: Union[str, None] = '9920ecfe2735'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# bind the engine
|
||||
bind = op.get_bind()
|
||||
|
||||
# Reflect the table
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
topics_json = row["topics"]
|
||||
|
||||
# Process each topic in the topics JSON array
|
||||
updated_topics = []
|
||||
for topic in topics_json:
|
||||
if "text" in topic:
|
||||
# Rename key 'text' back to 'transcript'
|
||||
topic["transcript"] = topic.pop("text")
|
||||
updated_topics.append(topic)
|
||||
|
||||
# Update the transcript table
|
||||
bind.execute(
|
||||
transcript.update()
|
||||
.where(transcript.c.id == transcript_id)
|
||||
.values(topics=updated_topics)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# bind the engine
|
||||
bind = op.get_bind()
|
||||
|
||||
# Reflect the table
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
topics_json = row["topics"]
|
||||
|
||||
# Process each topic in the topics JSON array
|
||||
updated_topics = []
|
||||
for topic in topics_json:
|
||||
if "transcript" in topic:
|
||||
# Rename key 'transcript' to 'text'
|
||||
topic["text"] = topic.pop("transcript")
|
||||
updated_topics.append(topic)
|
||||
|
||||
# Update the transcript table
|
||||
bind.execute(
|
||||
transcript.update()
|
||||
.where(transcript.c.id == transcript_id)
|
||||
.values(topics=updated_topics)
|
||||
)
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Migration transcript to text field in transcripts table
|
||||
|
||||
Revision ID: 9920ecfe2735
|
||||
Revises: 99365b0cd87b
|
||||
Create Date: 2023-11-02 18:55:17.019498
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9920ecfe2735"
|
||||
down_revision: Union[str, None] = "99365b0cd87b"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# bind the engine
|
||||
bind = op.get_bind()
|
||||
|
||||
# Reflect the table
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
topics_json = row["topics"]
|
||||
|
||||
# Process each topic in the topics JSON array
|
||||
updated_topics = []
|
||||
for topic in topics_json:
|
||||
if "transcript" in topic:
|
||||
# Rename key 'transcript' to 'text'
|
||||
topic["text"] = topic.pop("transcript")
|
||||
updated_topics.append(topic)
|
||||
|
||||
# Update the transcript table
|
||||
bind.execute(
|
||||
transcript.update()
|
||||
.where(transcript.c.id == transcript_id)
|
||||
.values(topics=updated_topics)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# bind the engine
|
||||
bind = op.get_bind()
|
||||
|
||||
# Reflect the table
|
||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||
|
||||
# Select all rows from the transcript table
|
||||
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||
|
||||
for row in results:
|
||||
transcript_id = row["id"]
|
||||
topics_json = row["topics"]
|
||||
|
||||
# Process each topic in the topics JSON array
|
||||
updated_topics = []
|
||||
for topic in topics_json:
|
||||
if "text" in topic:
|
||||
# Rename key 'text' back to 'transcript'
|
||||
topic["transcript"] = topic.pop("text")
|
||||
updated_topics.append(topic)
|
||||
|
||||
# Update the transcript table
|
||||
bind.execute(
|
||||
transcript.update()
|
||||
.where(transcript.c.id == transcript_id)
|
||||
.values(topics=updated_topics)
|
||||
)
|
||||
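Taken together, the two migrations above only rename a single key inside each entry of the `topics` JSON column: 38a927dcb099 turns `text` back into `transcript`, and 9920ecfe2735 goes the other way. A minimal, self-contained sketch of that transformation on made-up sample data:

```python
# Hypothetical sample of a "topics" JSON value stored on one transcript row.
topics = [
    {"title": "Intro", "summary": "Greetings", "transcript": "hello everyone"},
    {"title": "Roadmap", "summary": "Planning", "transcript": "next quarter we ship"},
]


def rename_key(topics: list[dict], old: str, new: str) -> list[dict]:
    """Mirror the migrations' per-topic loop: rename `old` to `new` where present."""
    updated = []
    for topic in topics:
        if old in topic:
            topic[new] = topic.pop(old)
        updated.append(topic)
    return updated


# upgrade() of 9920ecfe2735: 'transcript' -> 'text'
print(rename_key(topics, "transcript", "text"))
# upgrade() of 38a927dcb099 (and the downgrade above): 'text' -> 'transcript'
print(rename_key(topics, "text", "transcript"))
```

In practice these revisions would be applied through the usual Alembic flow (`alembic upgrade head`).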
server/poetry.lock (generated, 307 changed lines)
@@ -308,6 +308,20 @@ typing-extensions = ">=4"
|
||||
[package.extras]
|
||||
tz = ["python-dateutil"]
|
||||
|
||||
[[package]]
|
||||
name = "amqp"
|
||||
version = "5.1.1"
|
||||
description = "Low-level AMQP client for Python (fork of amqplib)."
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "amqp-5.1.1-py3-none-any.whl", hash = "sha256:6f0956d2c23d8fa6e7691934d8c3930eadb44972cbbd1a7ae3a520f735d43359"},
|
||||
{file = "amqp-5.1.1.tar.gz", hash = "sha256:2c1b13fecc0893e946c65cbd5f36427861cffa4ea2201d8f6fca22e2a373b5e2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
vine = ">=5.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
version = "0.6.0"
|
||||
@@ -474,6 +488,17 @@ files = [
|
||||
{file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "billiard"
|
||||
version = "4.1.0"
|
||||
description = "Python multiprocessing fork with improvements and bugfixes"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "billiard-4.1.0-py3-none-any.whl", hash = "sha256:0f50d6be051c6b2b75bfbc8bfd85af195c5739c281d3f5b86a5640c65563614a"},
|
||||
{file = "billiard-4.1.0.tar.gz", hash = "sha256:1ad2eeae8e28053d729ba3373d34d9d6e210f6e4d8bf0a9c64f92bd053f1edf5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "black"
|
||||
version = "23.9.1"
|
||||
@@ -556,6 +581,61 @@ urllib3 = ">=1.25.4,<1.27"
|
||||
[package.extras]
|
||||
crt = ["awscrt (==0.16.26)"]
|
||||
|
||||
[[package]]
|
||||
name = "celery"
|
||||
version = "5.3.4"
|
||||
description = "Distributed Task Queue."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "celery-5.3.4-py3-none-any.whl", hash = "sha256:1e6ed40af72695464ce98ca2c201ad0ef8fd192246f6c9eac8bba343b980ad34"},
|
||||
{file = "celery-5.3.4.tar.gz", hash = "sha256:9023df6a8962da79eb30c0c84d5f4863d9793a466354cc931d7f72423996de28"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
billiard = ">=4.1.0,<5.0"
|
||||
click = ">=8.1.2,<9.0"
|
||||
click-didyoumean = ">=0.3.0"
|
||||
click-plugins = ">=1.1.1"
|
||||
click-repl = ">=0.2.0"
|
||||
kombu = ">=5.3.2,<6.0"
|
||||
python-dateutil = ">=2.8.2"
|
||||
tzdata = ">=2022.7"
|
||||
vine = ">=5.0.0,<6.0"
|
||||
|
||||
[package.extras]
|
||||
arangodb = ["pyArango (>=2.0.2)"]
|
||||
auth = ["cryptography (==41.0.3)"]
|
||||
azureblockblob = ["azure-storage-blob (>=12.15.0)"]
|
||||
brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"]
|
||||
cassandra = ["cassandra-driver (>=3.25.0,<4)"]
|
||||
consul = ["python-consul2 (==0.1.5)"]
|
||||
cosmosdbsql = ["pydocumentdb (==2.3.5)"]
|
||||
couchbase = ["couchbase (>=3.0.0)"]
|
||||
couchdb = ["pycouchdb (==1.14.2)"]
|
||||
django = ["Django (>=2.2.28)"]
|
||||
dynamodb = ["boto3 (>=1.26.143)"]
|
||||
elasticsearch = ["elasticsearch (<8.0)"]
|
||||
eventlet = ["eventlet (>=0.32.0)"]
|
||||
gevent = ["gevent (>=1.5.0)"]
|
||||
librabbitmq = ["librabbitmq (>=2.0.0)"]
|
||||
memcache = ["pylibmc (==1.6.3)"]
|
||||
mongodb = ["pymongo[srv] (>=4.0.2)"]
|
||||
msgpack = ["msgpack (==1.0.5)"]
|
||||
pymemcache = ["python-memcached (==1.59)"]
|
||||
pyro = ["pyro4 (==4.82)"]
|
||||
pytest = ["pytest-celery (==0.0.0)"]
|
||||
redis = ["redis (>=4.5.2,!=4.5.5,<5.0.0)"]
|
||||
s3 = ["boto3 (>=1.26.143)"]
|
||||
slmq = ["softlayer-messaging (>=1.0.3)"]
|
||||
solar = ["ephem (==4.1.4)"]
|
||||
sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"]
|
||||
sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.0)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"]
|
||||
tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"]
|
||||
yaml = ["PyYAML (>=3.10)"]
|
||||
zookeeper = ["kazoo (>=1.3.1)"]
|
||||
zstd = ["zstandard (==0.21.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2023.7.22"
|
||||
@@ -744,6 +824,55 @@ files = [
|
||||
[package.dependencies]
|
||||
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
|
||||
[[package]]
|
||||
name = "click-didyoumean"
|
||||
version = "0.3.0"
|
||||
description = "Enables git-like *did-you-mean* feature in click"
|
||||
optional = false
|
||||
python-versions = ">=3.6.2,<4.0.0"
|
||||
files = [
|
||||
{file = "click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"},
|
||||
{file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
click = ">=7"
|
||||
|
||||
[[package]]
|
||||
name = "click-plugins"
|
||||
version = "1.1.1"
|
||||
description = "An extension module for click to enable registering CLI commands via setuptools entry-points."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"},
|
||||
{file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
click = ">=4.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"]
|
||||
|
||||
[[package]]
|
||||
name = "click-repl"
|
||||
version = "0.3.0"
|
||||
description = "REPL plugin for Click"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9"},
|
||||
{file = "click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
click = ">=7.0"
|
||||
prompt-toolkit = ">=3.0.36"
|
||||
|
||||
[package.extras]
|
||||
testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
@@ -981,6 +1110,24 @@ idna = ["idna (>=2.1,<4.0)"]
|
||||
trio = ["trio (>=0.14,<0.23)"]
|
||||
wmi = ["wmi (>=1.5.1,<2.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "ecdsa"
|
||||
version = "0.18.0"
|
||||
description = "ECDSA cryptographic signature library (pure python)"
|
||||
optional = false
|
||||
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
|
||||
files = [
|
||||
{file = "ecdsa-0.18.0-py2.py3-none-any.whl", hash = "sha256:80600258e7ed2f16b9aa1d7c295bd70194109ad5a30fdee0eaeefef1d4c559dd"},
|
||||
{file = "ecdsa-0.18.0.tar.gz", hash = "sha256:190348041559e21b22a1d65cee485282ca11a6f81d503fddb84d5017e9ed1e49"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
six = ">=1.9.0"
|
||||
|
||||
[package.extras]
|
||||
gmpy = ["gmpy"]
|
||||
gmpy2 = ["gmpy2"]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.100.1"
|
||||
@@ -1624,6 +1771,38 @@ files = [
|
||||
cryptography = ">=3.4"
|
||||
deprecated = "*"
|
||||
|
||||
[[package]]
|
||||
name = "kombu"
|
||||
version = "5.3.2"
|
||||
description = "Messaging library for Python."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "kombu-5.3.2-py3-none-any.whl", hash = "sha256:b753c9cfc9b1e976e637a7cbc1a65d446a22e45546cd996ea28f932082b7dc9e"},
|
||||
{file = "kombu-5.3.2.tar.gz", hash = "sha256:0ba213f630a2cb2772728aef56ac6883dc3a2f13435e10048f6e97d48506dbbd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
amqp = ">=5.1.1,<6.0.0"
|
||||
vine = "*"
|
||||
|
||||
[package.extras]
|
||||
azureservicebus = ["azure-servicebus (>=7.10.0)"]
|
||||
azurestoragequeues = ["azure-identity (>=1.12.0)", "azure-storage-queue (>=12.6.0)"]
|
||||
confluentkafka = ["confluent-kafka (==2.1.1)"]
|
||||
consul = ["python-consul2"]
|
||||
librabbitmq = ["librabbitmq (>=2.0.0)"]
|
||||
mongodb = ["pymongo (>=4.1.1)"]
|
||||
msgpack = ["msgpack"]
|
||||
pyro = ["pyro4"]
|
||||
qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"]
|
||||
redis = ["redis (>=4.5.2)"]
|
||||
slmq = ["softlayer-messaging (>=1.0.3)"]
|
||||
sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"]
|
||||
sqs = ["boto3 (>=1.26.143)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"]
|
||||
yaml = ["PyYAML (>=3.10)"]
|
||||
zookeeper = ["kazoo (>=2.8.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "levenshtein"
|
||||
version = "0.21.1"
|
||||
@@ -2151,6 +2330,20 @@ files = [
|
||||
fastapi = ">=0.38.1,<1.0.0"
|
||||
prometheus-client = ">=0.8.0,<1.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "prompt-toolkit"
|
||||
version = "3.0.39"
|
||||
description = "Library for building powerful interactive command lines in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"},
|
||||
{file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
wcwidth = "*"
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "4.24.4"
|
||||
@@ -2173,6 +2366,17 @@ files = [
|
||||
{file = "protobuf-4.24.4.tar.gz", hash = "sha256:5a70731910cd9104762161719c3d883c960151eea077134458503723b60e3667"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.5.0"
|
||||
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
|
||||
files = [
|
||||
{file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"},
|
||||
{file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.21"
|
||||
@@ -2501,6 +2705,20 @@ pytest = ">=7.0.0"
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-celery"
|
||||
version = "0.0.0"
|
||||
description = "pytest-celery a shim pytest plugin to enable celery.contrib.pytest"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pytest-celery-0.0.0.tar.gz", hash = "sha256:cfd060fc32676afa1e4f51b2938f903f7f75d952186b8c6cf631628c4088f406"},
|
||||
{file = "pytest_celery-0.0.0-py2.py3-none-any.whl", hash = "sha256:63dec132df3a839226ecb003ffdbb0c2cb88dd328550957e979c942766578060"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
celery = ">=4.4.0"
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "4.1.0"
|
||||
@@ -2565,6 +2783,28 @@ files = [
|
||||
[package.extras]
|
||||
cli = ["click (>=5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "python-jose"
|
||||
version = "3.3.0"
|
||||
description = "JOSE implementation in Python"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "python-jose-3.3.0.tar.gz", hash = "sha256:55779b5e6ad599c6336191246e95eb2293a9ddebd555f796a65f838f07e5d78a"},
|
||||
{file = "python_jose-3.3.0-py2.py3-none-any.whl", hash = "sha256:9b1376b023f8b298536eedd47ae1089bcdb848f1535ab30555cd92002d78923a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"cryptography\""}
|
||||
ecdsa = "!=0.15"
|
||||
pyasn1 = "*"
|
||||
rsa = "*"
|
||||
|
||||
[package.extras]
|
||||
cryptography = ["cryptography (>=3.4.0)"]
|
||||
pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"]
|
||||
pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyyaml"
|
||||
version = "6.0.1"
|
||||
@@ -2744,6 +2984,24 @@ files = [
|
||||
[package.extras]
|
||||
full = ["numpy"]
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "5.0.1"
|
||||
description = "Python client for Redis database and key-value store"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"},
|
||||
{file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""}
|
||||
|
||||
[package.extras]
|
||||
hiredis = ["hiredis (>=1.0.0)"]
|
||||
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "2023.10.3"
|
||||
@@ -2862,6 +3120,20 @@ urllib3 = ">=1.21.1,<3"
|
||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "rsa"
|
||||
version = "4.9"
|
||||
description = "Pure-Python RSA implementation"
|
||||
optional = false
|
||||
python-versions = ">=3.6,<4"
|
||||
files = [
|
||||
{file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
|
||||
{file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "s3transfer"
|
||||
version = "0.6.2"
|
||||
@@ -3438,6 +3710,17 @@ files = [
|
||||
{file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tzdata"
|
||||
version = "2023.3"
|
||||
description = "Provider of IANA time zone data"
|
||||
optional = false
|
||||
python-versions = ">=2"
|
||||
files = [
|
||||
{file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"},
|
||||
{file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "1.26.17"
|
||||
@@ -3523,6 +3806,17 @@ dev = ["Cython (>=0.29.32,<0.30.0)", "Sphinx (>=4.1.2,<4.2.0)", "aiohttp", "flak
|
||||
docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
|
||||
test = ["Cython (>=0.29.32,<0.30.0)", "aiohttp", "flake8 (>=3.9.2,<3.10.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=22.0.0,<22.1.0)", "pycodestyle (>=2.7.0,<2.8.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "vine"
|
||||
version = "5.0.0"
|
||||
description = "Promises, promises, promises."
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "vine-5.0.0-py2.py3-none-any.whl", hash = "sha256:4c9dceab6f76ed92105027c49c823800dd33cacce13bdedc5b914e3514b7fb30"},
|
||||
{file = "vine-5.0.0.tar.gz", hash = "sha256:7d3b1624a953da82ef63462013bbd271d3eb75751489f9807598e8f340bd637e"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "watchfiles"
|
||||
version = "0.20.0"
|
||||
@@ -3557,6 +3851,17 @@ files = [
|
||||
[package.dependencies]
|
||||
anyio = ">=3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "wcwidth"
|
||||
version = "0.2.8"
|
||||
description = "Measures the displayed width of unicode strings in a terminal"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "wcwidth-0.2.8-py2.py3-none-any.whl", hash = "sha256:77f719e01648ed600dfa5402c347481c0992263b81a027344f3e1ba25493a704"},
|
||||
{file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "websockets"
|
||||
version = "11.0.3"
|
||||
@@ -3838,4 +4143,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "61578467a70980ff9c2dc0cd787b6410b91d7c5fd2bb4c46b6951ec82690ef67"
|
||||
content-hash = "cfefbd402bde7585caa42c1a889be0496d956e285bb05db9e1e7ae5e485e91fe"
|
||||
|
||||
@@ -33,6 +33,9 @@ prometheus-fastapi-instrumentator = "^6.1.0"
sentencepiece = "^0.1.99"
protobuf = "^4.24.3"
profanityfilter = "^2.0.6"
celery = "^5.3.4"
redis = "^5.0.1"
python-jose = {extras = ["cryptography"], version = "^3.3.0"}


[tool.poetry.group.dev.dependencies]

@@ -47,6 +50,7 @@ pytest-asyncio = "^0.21.1"
pytest = "^7.4.0"
httpx-ws = "^0.4.1"
pytest-httpx = "^0.23.1"
pytest-celery = "^0.0.0"


[tool.poetry.group.aws.dependencies]
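The new `python-jose[cryptography]` dependency is presumably what backs the short-lived token used for the authenticated MP3 download later in this commit (`create_access_token({"sub": user_id}, expires_delta=timedelta(minutes=15))`). A rough sketch of issuing and checking such a token with `jose.jwt`; the secret, algorithm, and claim names below are illustrative assumptions, not the project's actual values:

```python
from datetime import datetime, timedelta

from jose import JWTError, jwt

SECRET_KEY = "change-me"  # assumption: the real key would come from settings
ALGORITHM = "HS256"       # assumption: algorithm choice is illustrative


def create_access_token(data: dict, expires_delta: timedelta) -> str:
    # Copy the claims and attach an expiry before signing.
    to_encode = dict(data)
    to_encode["exp"] = datetime.utcnow() + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


def verify_access_token(token: str) -> str | None:
    # Return the subject if the token is valid and unexpired, else None.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None
    return payload.get("sub")


token = create_access_token({"sub": "user-123"}, expires_delta=timedelta(minutes=15))
assert verify_access_token(token) == "user-123"
```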
@@ -64,6 +64,9 @@ app.include_router(transcripts_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")
add_pagination(app)

# prepare celery
from reflector.worker import app as celery_app  # noqa


# simpler openapi id
def use_route_names_as_operation_ids(app: FastAPI) -> None:
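The hunk above is cut off right at the `use_route_names_as_operation_ids` signature. The usual shape of this FastAPI helper is to walk the registered routes and reuse each route's function name as its OpenAPI `operation_id`; this is the common recipe, not necessarily the exact body used here:

```python
from fastapi import FastAPI
from fastapi.routing import APIRoute


def use_route_names_as_operation_ids(app: FastAPI) -> None:
    # Give every route a stable, human-readable operationId so generated
    # API clients get names like `transcript_get_audio_mp3` instead of
    # the auto-generated ones.
    for route in app.routes:
        if isinstance(route, APIRoute):
            route.operation_id = route.name
```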
@@ -1,32 +1,13 @@
import databases
import sqlalchemy

from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings

database = databases.Database(settings.DATABASE_URL)
metadata = sqlalchemy.MetaData()


transcripts = sqlalchemy.Table(
    "transcript",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
    sqlalchemy.Column("name", sqlalchemy.String),
    sqlalchemy.Column("status", sqlalchemy.String),
    sqlalchemy.Column("locked", sqlalchemy.Boolean),
    sqlalchemy.Column("duration", sqlalchemy.Integer),
    sqlalchemy.Column("created_at", sqlalchemy.DateTime),
    sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("topics", sqlalchemy.JSON),
    sqlalchemy.Column("events", sqlalchemy.JSON),
    sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
    # with user attached, optional
    sqlalchemy.Column("user_id", sqlalchemy.String),
)
# import models
import reflector.db.transcripts  # noqa

engine = sqlalchemy.create_engine(
    settings.DATABASE_URL, connect_args={"check_same_thread": False}
server/reflector/db/transcripts.py (new file, 296 lines)
@@ -0,0 +1,296 @@
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy
|
||||
from pydantic import BaseModel, Field
|
||||
from reflector.db import database, metadata
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
|
||||
transcripts = sqlalchemy.Table(
|
||||
"transcript",
|
||||
metadata,
|
||||
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||
sqlalchemy.Column("name", sqlalchemy.String),
|
||||
sqlalchemy.Column("status", sqlalchemy.String),
|
||||
sqlalchemy.Column("locked", sqlalchemy.Boolean),
|
||||
sqlalchemy.Column("duration", sqlalchemy.Integer),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
|
||||
sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("topics", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("events", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
|
||||
# with user attached, optional
|
||||
sqlalchemy.Column("user_id", sqlalchemy.String),
|
||||
)
|
||||
|
||||
|
||||
def generate_uuid4():
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def generate_transcript_name():
|
||||
now = datetime.utcnow()
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
|
||||
class AudioWaveform(BaseModel):
|
||||
data: list[float]
|
||||
|
||||
|
||||
class TranscriptText(BaseModel):
|
||||
text: str
|
||||
translation: str | None
|
||||
|
||||
|
||||
class TranscriptSegmentTopic(BaseModel):
|
||||
speaker: int
|
||||
text: str
|
||||
timestamp: float
|
||||
|
||||
|
||||
class TranscriptTopic(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
title: str
|
||||
summary: str
|
||||
timestamp: float
|
||||
duration: float | None = 0
|
||||
transcript: str | None = None
|
||||
words: list[ProcessorWord] = []
|
||||
|
||||
|
||||
class TranscriptFinalShortSummary(BaseModel):
|
||||
short_summary: str
|
||||
|
||||
|
||||
class TranscriptFinalLongSummary(BaseModel):
|
||||
long_summary: str
|
||||
|
||||
|
||||
class TranscriptFinalTitle(BaseModel):
|
||||
title: str
|
||||
|
||||
|
||||
class TranscriptEvent(BaseModel):
|
||||
event: str
|
||||
data: dict
|
||||
|
||||
|
||||
class Transcript(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
user_id: str | None = None
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
status: str = "idle"
|
||||
locked: bool = False
|
||||
duration: float = 0
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
title: str | None = None
|
||||
short_summary: str | None = None
|
||||
long_summary: str | None = None
|
||||
topics: list[TranscriptTopic] = []
|
||||
events: list[TranscriptEvent] = []
|
||||
source_language: str = "en"
|
||||
target_language: str = "en"
|
||||
|
||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||
self.events.append(ev)
|
||||
return ev
|
||||
|
||||
def upsert_topic(self, topic: TranscriptTopic):
|
||||
index = next((i for i, t in enumerate(self.topics) if t.id == topic.id), None)
|
||||
if index is not None:
|
||||
self.topics[index] = topic
|
||||
else:
|
||||
self.topics.append(topic)
|
||||
|
||||
def events_dump(self, mode="json"):
|
||||
return [event.model_dump(mode=mode) for event in self.events]
|
||||
|
||||
def topics_dump(self, mode="json"):
|
||||
return [topic.model_dump(mode=mode) for topic in self.topics]
|
||||
|
||||
def convert_audio_to_waveform(self, segments_count=256):
|
||||
fn = self.audio_waveform_filename
|
||||
if fn.exists():
|
||||
return
|
||||
waveform = get_audio_waveform(
|
||||
path=self.audio_mp3_filename, segments_count=segments_count
|
||||
)
|
||||
try:
|
||||
with open(fn, "w") as fd:
|
||||
json.dump(waveform, fd)
|
||||
except Exception:
|
||||
# remove file if anything happen during the write
|
||||
fn.unlink(missing_ok=True)
|
||||
raise
|
||||
return waveform
|
||||
|
||||
def unlink(self):
|
||||
self.data_path.unlink(missing_ok=True)
|
||||
|
||||
@property
|
||||
def data_path(self):
|
||||
return Path(settings.DATA_DIR) / self.id
|
||||
|
||||
@property
|
||||
def audio_mp3_filename(self):
|
||||
return self.data_path / "audio.mp3"
|
||||
|
||||
@property
|
||||
def audio_waveform_filename(self):
|
||||
return self.data_path / "audio.json"
|
||||
|
||||
@property
|
||||
def audio_waveform(self):
|
||||
try:
|
||||
with open(self.audio_waveform_filename) as fd:
|
||||
data = json.load(fd)
|
||||
except json.JSONDecodeError:
|
||||
# unlink file if it's corrupted
|
||||
self.audio_waveform_filename.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
return AudioWaveform(data=data)
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
async def get_all(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
order_by: str | None = None,
|
||||
filter_empty: bool | None = False,
|
||||
filter_recording: bool | None = False,
|
||||
) -> list[Transcript]:
|
||||
"""
|
||||
Get all transcripts
|
||||
|
||||
If `user_id` is specified, only return transcripts that belong to the user.
|
||||
Otherwise, return all anonymous transcripts.
|
||||
|
||||
Parameters:
|
||||
- `order_by`: field to order by, e.g. "-created_at"
|
||||
- `filter_empty`: filter out empty transcripts
|
||||
- `filter_recording`: filter out transcripts that are currently recording
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.user_id == user_id)
|
||||
|
||||
if order_by is not None:
|
||||
field = getattr(transcripts.c, order_by[1:])
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
|
||||
if filter_empty:
|
||||
query = query.filter(transcripts.c.status != "idle")
|
||||
|
||||
if filter_recording:
|
||||
query = query.filter(transcripts.c.status != "recording")
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
return results
|
||||
|
||||
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||
"""
|
||||
Get a transcript by id
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await database.fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
|
||||
async def add(
|
||||
self,
|
||||
name: str,
|
||||
source_language: str = "en",
|
||||
target_language: str = "en",
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Add a new transcript
|
||||
"""
|
||||
transcript = Transcript(
|
||||
name=name,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
user_id=user_id,
|
||||
)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
return transcript
|
||||
|
||||
async def update(self, transcript: Transcript, values: dict):
|
||||
"""
|
||||
Update a transcript fields with key/values in values
|
||||
"""
|
||||
query = (
|
||||
transcripts.update()
|
||||
.where(transcripts.c.id == transcript.id)
|
||||
.values(**values)
|
||||
)
|
||||
await database.execute(query)
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
|
||||
async def remove_by_id(
|
||||
self,
|
||||
transcript_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Remove a transcript by id
|
||||
"""
|
||||
transcript = await self.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
return
|
||||
if user_id is not None and transcript.user_id != user_id:
|
||||
return
|
||||
transcript.unlink()
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await database.execute(query)
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
"""
|
||||
A context manager for database transaction
|
||||
"""
|
||||
async with database.transaction(isolation="serializable"):
|
||||
yield
|
||||
|
||||
async def append_event(
|
||||
self,
|
||||
transcript: Transcript,
|
||||
event: str,
|
||||
data: Any,
|
||||
) -> TranscriptEvent:
|
||||
"""
|
||||
Append an event to a transcript
|
||||
"""
|
||||
resp = transcript.add_event(event=event, data=data)
|
||||
await self.update(transcript, {"events": transcript.events_dump()})
|
||||
return resp
|
||||
|
||||
async def upsert_topic(
|
||||
self,
|
||||
transcript: Transcript,
|
||||
topic: TranscriptTopic,
|
||||
) -> TranscriptEvent:
|
||||
"""
|
||||
Append an event to a transcript
|
||||
"""
|
||||
transcript.upsert_topic(topic)
|
||||
await self.update(transcript, {"topics": transcript.topics_dump()})
|
||||
|
||||
|
||||
transcripts_controller = TranscriptController()
|
||||
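For orientation, a small usage sketch of the controller defined above. It assumes the `database` connection is opened manually (normally the app's startup hooks do this) and only exercises methods shown in this file:

```python
import asyncio

from reflector.db import database
from reflector.db.transcripts import transcripts_controller


async def demo():
    await database.connect()  # assumption: usually handled by the app startup event
    try:
        transcript = await transcripts_controller.add(
            name="Demo call", source_language="en", target_language="en"
        )
        await transcripts_controller.update(transcript, {"title": "Weekly sync"})
        fetched = await transcripts_controller.get_by_id(transcript.id)
        print(fetched.title, fetched.status)
        await transcripts_controller.remove_by_id(transcript.id)
    finally:
        await database.disconnect()


asyncio.run(demo())
```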
@@ -47,6 +47,7 @@ class ModalLLM(LLM):
            json=json_payload,
            timeout=self.timeout,
            retry_timeout=60 * 5,
            follow_redirects=True,
        )
        response.raise_for_status()
        text = response.json()["text"]
server/reflector/pipelines/main_live_pipeline.py (new file, 362 lines)
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
Main reflector pipeline for live streaming
|
||||
==========================================
|
||||
|
||||
This is the default pipeline used in the API.
|
||||
|
||||
It is decoupled to:
|
||||
- PipelineMainLive: have limited processing during live
|
||||
- PipelineMainPost: do heavy lifting after the live
|
||||
|
||||
It is directly linked to our data model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from celery import shared_task
|
||||
from pydantic import BaseModel
|
||||
from reflector.app import app
|
||||
from reflector.db.transcripts import (
|
||||
Transcript,
|
||||
TranscriptFinalLongSummary,
|
||||
TranscriptFinalShortSummary,
|
||||
TranscriptFinalTitle,
|
||||
TranscriptText,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.runner import PipelineRunner
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioDiarizationAutoProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
BroadcastProcessor,
|
||||
Pipeline,
|
||||
TranscriptFinalLongSummaryProcessor,
|
||||
TranscriptFinalShortSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorProcessor,
|
||||
)
|
||||
from reflector.processors.types import AudioDiarizationInput
|
||||
from reflector.processors.types import (
|
||||
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
|
||||
)
|
||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||
from reflector.settings import settings
|
||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||
|
||||
|
||||
def broadcast_to_sockets(func):
|
||||
"""
|
||||
Decorator to broadcast transcript event to websockets
|
||||
concerning this transcript
|
||||
"""
|
||||
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
resp = await func(self, *args, **kwargs)
|
||||
if resp is None:
|
||||
return
|
||||
await self.ws_manager.send_json(
|
||||
room_id=self.ws_room_id,
|
||||
message=resp.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class StrValue(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class PipelineMainBase(PipelineRunner):
|
||||
transcript_id: str
|
||||
ws_room_id: str | None = None
|
||||
ws_manager: WebsocketManager | None = None
|
||||
|
||||
def prepare(self):
|
||||
# prepare websocket
|
||||
self._lock = asyncio.Lock()
|
||||
self.ws_room_id = f"ts:{self.transcript_id}"
|
||||
self.ws_manager = get_ws_manager()
|
||||
|
||||
async def get_transcript(self) -> Transcript:
|
||||
# fetch the transcript
|
||||
result = await transcripts_controller.get_by_id(
|
||||
transcript_id=self.transcript_id
|
||||
)
|
||||
if not result:
|
||||
raise Exception("Transcript not found")
|
||||
return result
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
async with self._lock:
|
||||
async with transcripts_controller.transaction():
|
||||
yield
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_status(self, status):
|
||||
# if it's the first part, update the status of the transcript
|
||||
# but do not set the ended status yet.
|
||||
if isinstance(self, PipelineMainLive):
|
||||
status_mapping = {
|
||||
"started": "recording",
|
||||
"push": "recording",
|
||||
"flush": "processing",
|
||||
"error": "error",
|
||||
}
|
||||
elif isinstance(self, PipelineMainDiarization):
|
||||
status_mapping = {
|
||||
"push": "processing",
|
||||
"flush": "processing",
|
||||
"error": "error",
|
||||
"ended": "ended",
|
||||
}
|
||||
else:
|
||||
raise Exception(f"Runner {self.__class__} is missing status mapping")
|
||||
|
||||
# mutate to model status
|
||||
status = status_mapping.get(status)
|
||||
if not status:
|
||||
return
|
||||
|
||||
# when the status of the pipeline changes, update the transcript
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
if status == transcript.status:
|
||||
return
|
||||
resp = await transcripts_controller.append_event(
|
||||
transcript=transcript,
|
||||
event="STATUS",
|
||||
data=StrValue(value=status),
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
return resp
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_transcript(self, data):
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript,
|
||||
event="TRANSCRIPT",
|
||||
data=TranscriptText(text=data.text, translation=data.translation),
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_topic(self, data):
|
||||
topic = TranscriptTopic(
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
timestamp=data.timestamp,
|
||||
transcript=data.transcript.text,
|
||||
words=data.transcript.words,
|
||||
)
|
||||
if isinstance(data, TitleSummaryWithIdProcessorType):
|
||||
topic.id = data.id
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
await transcripts_controller.upsert_topic(transcript, topic)
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript,
|
||||
event="TOPIC",
|
||||
data=topic,
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_title(self, data):
|
||||
final_title = TranscriptFinalTitle(title=data.title)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
if not transcript.title:
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"title": final_title.title,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript,
|
||||
event="FINAL_TITLE",
|
||||
data=final_title,
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_long_summary(self, data):
|
||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"long_summary": final_long_summary.long_summary,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript,
|
||||
event="FINAL_LONG_SUMMARY",
|
||||
data=final_long_summary,
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_short_summary(self, data):
|
||||
final_short_summary = TranscriptFinalShortSummary(
|
||||
short_summary=data.short_summary
|
||||
)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"short_summary": final_short_summary.short_summary,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript,
|
||||
event="FINAL_SHORT_SUMMARY",
|
||||
data=final_short_summary,
|
||||
)
|
||||
|
||||
|
||||
class PipelineMainLive(PipelineMainBase):
|
||||
audio_filename: Path | None = None
|
||||
source_language: str = "en"
|
||||
target_language: str = "en"
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
processors = [
|
||||
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
||||
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
||||
]
|
||||
),
|
||||
]
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.options = self
|
||||
pipeline.set_pref("audio:source_language", transcript.source_language)
|
||||
pipeline.set_pref("audio:target_language", transcript.target_language)
|
||||
pipeline.logger.bind(transcript_id=transcript.id)
|
||||
pipeline.logger.info(
|
||||
"Pipeline main live created",
|
||||
transcript_id=self.transcript_id,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
async def on_ended(self):
|
||||
# when the pipeline ends, connect to the post pipeline
|
||||
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
|
||||
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
|
||||
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
|
||||
|
||||
|
||||
class PipelineMainDiarization(PipelineMainBase):
|
||||
"""
|
||||
Diarization is a long time process, so we do it in a separate pipeline
|
||||
When done, adjust the short and final summary
|
||||
"""
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
processors = [
|
||||
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||
callback=self.on_long_summary
|
||||
),
|
||||
TranscriptFinalShortSummaryProcessor.as_threaded(
|
||||
callback=self.on_short_summary
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.options = self
|
||||
|
||||
# now let's start the pipeline by pushing information to the
|
||||
# first processor diarization processor
|
||||
# XXX translation is lost when converting our data model to the processor model
|
||||
transcript = await self.get_transcript()
|
||||
topics = [
|
||||
TitleSummaryWithIdProcessorType(
|
||||
id=topic.id,
|
||||
title=topic.title,
|
||||
summary=topic.summary,
|
||||
timestamp=topic.timestamp,
|
||||
duration=topic.duration,
|
||||
transcript=TranscriptProcessorType(words=topic.words),
|
||||
)
|
||||
for topic in transcript.topics
|
||||
]
|
||||
|
||||
# we need to create an url to be used for diarization
|
||||
# we can't use the audio_mp3_filename because it's not accessible
|
||||
# from the diarization processor
|
||||
from reflector.views.transcripts import create_access_token
|
||||
|
||||
path = app.url_path_for(
|
||||
"transcript_get_audio_mp3",
|
||||
transcript_id=transcript.id,
|
||||
)
|
||||
url = f"{settings.BASE_URL}{path}"
|
||||
if transcript.user_id:
|
||||
# we pass token only if the user_id is set
|
||||
# otherwise, the audio is public
|
||||
token = create_access_token(
|
||||
{"sub": transcript.user_id},
|
||||
expires_delta=timedelta(minutes=15),
|
||||
)
|
||||
url += f"?token={token}"
|
||||
audio_diarization_input = AudioDiarizationInput(
|
||||
audio_url=url,
|
||||
topics=topics,
|
||||
)
|
||||
|
||||
# as tempting to use pipeline.push, prefer to use the runner
|
||||
# to let the start just do one job.
|
||||
pipeline.logger.bind(transcript_id=transcript.id)
|
||||
pipeline.logger.info(
|
||||
"Pipeline main post created", transcript_id=self.transcript_id
|
||||
)
|
||||
self.push(audio_diarization_input)
|
||||
self.flush()
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
@shared_task
|
||||
def task_pipeline_main_post(transcript_id: str):
|
||||
logger.info(
|
||||
"Starting main post pipeline",
|
||||
transcript_id=transcript_id,
|
||||
)
|
||||
runner = PipelineMainDiarization(transcript_id=transcript_id)
|
||||
runner.start_sync()
|
||||
server/reflector/pipelines/runner.py (new file, 137 lines)
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Pipeline Runner
|
||||
===============
|
||||
|
||||
Pipeline runner designed to be executed in a asyncio task.
|
||||
|
||||
It is meant to be subclassed, and implement a create() method
|
||||
that expose/return a Pipeline instance.
|
||||
|
||||
During its lifecycle, it will emit the following status:
|
||||
- started: the pipeline has been started
|
||||
- push: the pipeline received at least one data
|
||||
- flush: the pipeline is flushing
|
||||
- ended: the pipeline has ended
|
||||
- error: the pipeline has ended with an error
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import Pipeline
|
||||
|
||||
|
||||
class PipelineRunner(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
status: str = "idle"
|
||||
pipeline: Pipeline | None = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._q_cmd = asyncio.Queue()
|
||||
self._ev_done = asyncio.Event()
|
||||
self._is_first_push = True
|
||||
self._logger = logger.bind(
|
||||
runner=id(self),
|
||||
runner_cls=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def create(self) -> Pipeline:
|
||||
"""
|
||||
Create the pipeline if not specified earlier.
|
||||
Should be implemented in a subclass
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start the pipeline as a coroutine task
|
||||
"""
|
||||
asyncio.get_event_loop().create_task(self.run())
|
||||
|
||||
def start_sync(self):
|
||||
"""
|
||||
Start the pipeline synchronously (for non-asyncio apps)
|
||||
"""
|
||||
coro = self.run()
|
||||
asyncio.run(coro)
|
||||
|
||||
def push(self, data):
|
||||
"""
|
||||
Push data to the pipeline
|
||||
"""
|
||||
self._add_cmd("PUSH", data)
|
||||
|
||||
def flush(self):
|
||||
"""
|
||||
Flush the pipeline
|
||||
"""
|
||||
self._add_cmd("FLUSH", None)
|
||||
|
||||
async def on_status(self, status):
|
||||
"""
|
||||
Called when the status of the pipeline changes
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_ended(self):
|
||||
"""
|
||||
Called when the pipeline ends
|
||||
"""
|
||||
pass
|
||||
|
||||
def _add_cmd(self, cmd: str, data):
|
||||
"""
|
||||
Enqueue a command to be executed in the runner.
|
||||
Currently supported commands: PUSH, FLUSH
|
||||
"""
|
||||
self._q_cmd.put_nowait([cmd, data])
|
||||
|
||||
async def _set_status(self, status):
|
||||
self._logger.debug("Runner status updated", status=status)
|
||||
self.status = status
|
||||
if self.on_status:
|
||||
try:
|
||||
await self.on_status(status)
|
||||
except Exception:
|
||||
self._logger.exception("Runer error while setting status")
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
# create the pipeline if not yet done
|
||||
await self._set_status("init")
|
||||
self._is_first_push = True
|
||||
if not self.pipeline:
|
||||
self.pipeline = await self.create()
|
||||
|
||||
# start the loop
|
||||
await self._set_status("started")
|
||||
while not self._ev_done.is_set():
|
||||
cmd, data = await self._q_cmd.get()
|
||||
func = getattr(self, f"cmd_{cmd.lower()}")
|
||||
if func:
|
||||
await func(data)
|
||||
else:
|
||||
raise Exception(f"Unknown command {cmd}")
|
||||
except Exception:
|
||||
self._logger.exception("Runner error")
|
||||
await self._set_status("error")
|
||||
self._ev_done.set()
|
||||
if self.on_ended:
|
||||
await self.on_ended()
|
||||
|
||||
async def cmd_push(self, data):
|
||||
if self._is_first_push:
|
||||
await self._set_status("push")
|
||||
self._is_first_push = False
|
||||
await self.pipeline.push(data)
|
||||
|
||||
async def cmd_flush(self, data):
|
||||
await self._set_status("flush")
|
||||
await self.pipeline.flush()
|
||||
await self._set_status("ended")
|
||||
self._ev_done.set()
|
||||
if self.on_ended:
|
||||
await self.on_ended()
|
||||
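The PUSH/FLUSH command loop above (producers call `push()`/`flush()`, a single `run()` task drains the queue and drives the pipeline) can be illustrated with a stripped-down, dependency-free toy. This is an analogy for how the runner sequences work, not project code:

```python
import asyncio


class ToyRunner:
    """Toy version of the PUSH/FLUSH command loop used by PipelineRunner."""

    def __init__(self):
        self.queue: asyncio.Queue = asyncio.Queue()
        self.done = asyncio.Event()
        self.processed: list[str] = []

    def push(self, data: str) -> None:
        self.queue.put_nowait(("PUSH", data))

    def flush(self) -> None:
        self.queue.put_nowait(("FLUSH", None))

    async def run(self) -> None:
        while not self.done.is_set():
            cmd, data = await self.queue.get()
            if cmd == "PUSH":
                self.processed.append(data)  # stands in for pipeline.push(data)
            elif cmd == "FLUSH":
                self.done.set()              # stands in for pipeline.flush() + "ended"


async def main():
    runner = ToyRunner()
    task = asyncio.create_task(runner.run())
    runner.push("chunk-1")
    runner.push("chunk-2")
    runner.flush()
    await task
    print(runner.processed)  # ['chunk-1', 'chunk-2']


asyncio.run(main())
```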
@@ -1,9 +1,16 @@
from .audio_chunker import AudioChunkerProcessor  # noqa: F401
from .audio_diarization_auto import AudioDiarizationAutoProcessor  # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor  # noqa: F401
from .audio_merge import AudioMergeProcessor  # noqa: F401
from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
from .audio_transcript_auto import AudioTranscriptAutoProcessor  # noqa: F401
from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor  # noqa: F401
from .base import (  # noqa: F401
    BroadcastProcessor,
    Pipeline,
    PipelineEvent,
    Processor,
    ThreadedProcessor,
)
from .transcript_final_long_summary import (  # noqa: F401
    TranscriptFinalLongSummaryProcessor,
)
server/reflector/processors/audio_diarization.py (new file, 34 lines)
@@ -0,0 +1,34 @@
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||
|
||||
|
||||
class AudioDiarizationProcessor(Processor):
|
||||
INPUT_TYPE = AudioDiarizationInput
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
async def _push(self, data: AudioDiarizationInput):
|
||||
try:
|
||||
self.logger.info("Diarization started", audio_file_url=data.audio_url)
|
||||
diarization = await self._diarize(data)
|
||||
self.logger.info("Diarization finished")
|
||||
except Exception:
|
||||
self.logger.exception("Diarization failed after retrying")
|
||||
raise
|
||||
|
||||
# now reapply speaker to topics (if any)
|
||||
# topics is a list[BaseModel] with an attribute words
|
||||
# words is a list[BaseModel] with text, start and speaker attribute
|
||||
|
||||
# mutate in place
|
||||
for topic in data.topics:
|
||||
for word in topic.transcript.words:
|
||||
for d in diarization:
|
||||
if d["start"] <= word.start <= d["end"]:
|
||||
word.speaker = d["speaker"]
|
||||
|
||||
# emit them
|
||||
for topic in data.topics:
|
||||
await self.emit(topic)
|
||||
|
||||
async def _diarize(self, data: AudioDiarizationInput):
|
||||
raise NotImplementedError
|
||||
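The speaker re-assignment loop above is plain interval matching: a word takes the speaker of whichever diarization segment contains its start time. A self-contained illustration with made-up segments and words:

```python
# Hypothetical diarization output: speaker segments with start/end seconds.
diarization = [
    {"start": 0.0, "end": 4.2, "speaker": 0},
    {"start": 4.2, "end": 9.0, "speaker": 1},
]

# Hypothetical words, each with the start time the transcript backend produced.
words = [
    {"text": "hello", "start": 0.5, "speaker": None},
    {"text": "thanks", "start": 5.1, "speaker": None},
]

# Same containment test as the processor: d["start"] <= word.start <= d["end"].
for word in words:
    for d in diarization:
        if d["start"] <= word["start"] <= d["end"]:
            word["speaker"] = d["speaker"]

print(words)  # speaker 0 for "hello", speaker 1 for "thanks"
```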
server/reflector/processors/audio_diarization_auto.py (new file, 33 lines)
@@ -0,0 +1,33 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioDiarizationAutoProcessor(AudioDiarizationProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.DIARIZATION_BACKEND
|
||||
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.audio_diarization_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "DIARIZATION_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
server/reflector/processors/audio_diarization_modal.py (new file, 37 lines)
@@ -0,0 +1,37 @@
|
||||
import httpx
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
|
||||
INPUT_TYPE = AudioDiarizationInput
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
|
||||
}
|
||||
|
||||
async def _diarize(self, data: AudioDiarizationInput):
|
||||
# Gather diarization data
|
||||
params = {
|
||||
"audio_file_url": data.audio_url,
|
||||
"timestamp": 0,
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.diarization_url,
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
timeout=None,
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()["text"]
|
||||
|
||||
|
||||
AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)
|
||||
@@ -1,6 +1,4 @@
|
||||
from profanityfilter import ProfanityFilter
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import AudioFile, Transcript
|
||||
|
||||
@@ -40,8 +38,6 @@ class AudioTranscriptProcessor(Processor):
|
||||
self.m_transcript_call = self.m_transcript_call.labels(name)
|
||||
self.m_transcript_success = self.m_transcript_success.labels(name)
|
||||
self.m_transcript_failure = self.m_transcript_failure.labels(name)
|
||||
self.profanity_filter = ProfanityFilter()
|
||||
self.profanity_filter.set_censor("*")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def _push(self, data: AudioFile):
|
||||
@@ -60,9 +56,3 @@ class AudioTranscriptProcessor(Processor):
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
raise NotImplementedError
|
||||
|
||||
def filter_profanity(self, text: str) -> str:
|
||||
"""
|
||||
Remove censored words from the transcript
|
||||
"""
|
||||
return self.profanity_filter.censor(text)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.base import Pipeline, Processor
|
||||
from reflector.processors.types import AudioFile
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
@@ -13,8 +11,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, name):
|
||||
def __new__(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.TRANSCRIPT_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.audio_transcript_{name}"
|
||||
importlib.import_module(module_name)
|
||||
@@ -30,30 +29,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_pipeline(self, pipeline: Pipeline):
|
||||
super().set_pipeline(pipeline)
|
||||
self.processor.set_pipeline(pipeline)
|
||||
|
||||
def connect(self, processor: Processor):
|
||||
self.processor.connect(processor)
|
||||
|
||||
def disconnect(self, processor: Processor):
|
||||
self.processor.disconnect(processor)
|
||||
|
||||
def on(self, callback):
|
||||
self.processor.on(callback)
|
||||
|
||||
def off(self, callback):
|
||||
self.processor.off(callback)
|
||||
|
||||
async def _push(self, data: AudioFile):
|
||||
return await self.processor._push(data)
|
||||
|
||||
async def _flush(self):
|
||||
return await self.processor._flush()
|
||||
return cls._registry[name](**config | kwargs)
|
||||
|
||||
@@ -41,6 +41,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
timeout=self.timeout,
|
||||
headers=self.headers,
|
||||
params=json_payload,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
@@ -48,10 +49,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
text = result["text"][source_language]
|
||||
text = self.filter_profanity(text)
|
||||
transcript = Transcript(
|
||||
text=text,
|
||||
words=[
|
||||
Word(
|
||||
text=word["text"],
|
||||
|
||||
@@ -30,7 +30,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
|
||||
ts = data.timestamp
|
||||
|
||||
for segment in segments:
|
||||
transcript.text += segment.text
|
||||
for word in segment.words:
|
||||
transcript.words.append(
|
||||
Word(
|
||||
|
||||
@@ -290,12 +290,12 @@ class BroadcastProcessor(Processor):
|
||||
processor.set_pipeline(pipeline)
|
||||
|
||||
async def _push(self, data):
|
||||
for processor in self.processors:
|
||||
await processor.push(data)
|
||||
coros = [processor.push(data) for processor in self.processors]
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
async def _flush(self):
|
||||
for processor in self.processors:
|
||||
await processor.flush()
|
||||
coros = [processor.flush() for processor in self.processors]
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
def connect(self, processor: Processor):
|
||||
for processor in self.processors:
|
||||
@@ -333,6 +333,7 @@ class Pipeline(Processor):
|
||||
self.logger.info("Pipeline created")
|
||||
|
||||
self.processors = processors
|
||||
self.options = None
|
||||
self.prefs = {}
|
||||
|
||||
for processor in processors:
|
||||
|
||||
@@ -36,7 +36,6 @@ class TranscriptLinerProcessor(Processor):
|
||||
# cut to the next .
|
||||
partial = Transcript(words=[])
|
||||
for word in self.transcript.words[:]:
|
||||
partial.text += word.text
|
||||
partial.words.append(word)
|
||||
if not self.is_sentence_terminated(word.text):
|
||||
continue
|
||||
|
||||
@@ -50,6 +50,7 @@ class TranscriptTranslatorProcessor(Processor):
|
||||
headers=self.headers,
|
||||
params=json_payload,
|
||||
timeout=self.timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()["text"]
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
import io
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from profanityfilter import ProfanityFilter
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
PUNC_RE = re.compile(r"[.;:?!…]")
|
||||
|
||||
profanity_filter = ProfanityFilter()
|
||||
profanity_filter.set_censor("*")
|
||||
|
||||
|
||||
class AudioFile(BaseModel):
|
||||
name: str
|
||||
@@ -43,13 +50,29 @@ class Word(BaseModel):
|
||||
text: str
|
||||
start: float
|
||||
end: float
|
||||
speaker: int = 0
|
||||
|
||||
|
||||
class TranscriptSegment(BaseModel):
|
||||
text: str
|
||||
start: float
|
||||
speaker: int = 0
|
||||
|
||||
|
||||
class Transcript(BaseModel):
|
||||
text: str = ""
|
||||
translation: str | None = None
|
||||
words: list[Word] = None
|
||||
|
||||
@property
|
||||
def raw_text(self):
|
||||
# Uncensored text
|
||||
return "".join([word.text for word in self.words])
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
# Censored text
|
||||
return profanity_filter.censor(self.raw_text).strip()
|
||||
|
||||
@property
|
||||
def human_timestamp(self):
|
||||
minutes = int(self.timestamp / 60)
|
||||
@@ -74,7 +97,6 @@ class Transcript(BaseModel):
|
||||
self.words = other.words
|
||||
else:
|
||||
self.words.extend(other.words)
|
||||
self.text += other.text
|
||||
|
||||
def add_offset(self, offset: float):
|
||||
for word in self.words:
|
||||
@@ -87,6 +109,48 @@ class Transcript(BaseModel):
|
||||
]
|
||||
return Transcript(text=self.text, translation=self.translation, words=words)
|
||||
|
||||
def as_segments(self) -> list[TranscriptSegment]:
|
||||
# from a list of words, create a list of segments
|
||||
# join words that are less than 2 seconds apart
|
||||
# but separate if the speaker changes, or if the punctuation is a . , ; : ? !
|
||||
segments = []
|
||||
current_segment = None
|
||||
MAX_SEGMENT_LENGTH = 120
|
||||
|
||||
for word in self.words:
|
||||
if current_segment is None:
|
||||
current_segment = TranscriptSegment(
|
||||
text=word.text,
|
||||
start=word.start,
|
||||
speaker=word.speaker,
|
||||
)
|
||||
continue
|
||||
|
||||
# If the word is attached to another speaker, push the current segment
|
||||
# and start a new one
|
||||
if word.speaker != current_segment.speaker:
|
||||
segments.append(current_segment)
|
||||
current_segment = TranscriptSegment(
|
||||
text=word.text,
|
||||
start=word.start,
|
||||
speaker=word.speaker,
|
||||
)
|
||||
continue
|
||||
|
||||
# if the word is the end of a sentence, and we have enough content,
|
||||
# add the word to the current segment and push it
|
||||
current_segment.text += word.text
|
||||
|
||||
have_punc = PUNC_RE.search(word.text)
|
||||
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
|
||||
segments.append(current_segment)
|
||||
current_segment = None
|
||||
|
||||
if current_segment:
|
||||
segments.append(current_segment)
|
||||
|
||||
return segments
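A minimal sketch of how the `as_segments()` rule above groups words, assuming the `reflector` package is importable; the word values are invented:

    from reflector.processors.types import Transcript, Word

    words = [
        Word(text=" hi", start=0.0, end=0.5, speaker=0),
        Word(text=" there.", start=0.5, end=1.0, speaker=0),
        Word(text=" yes.", start=1.2, end=1.6, speaker=1),
    ]
    segments = Transcript(words=words).as_segments()
    print([(s.speaker, s.text) for s in segments])
    # the speaker change forces a new segment -> [(0, ' hi there.'), (1, ' yes.')]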
|
||||
|
||||
|
||||
class TitleSummary(BaseModel):
|
||||
title: str
|
||||
@@ -103,6 +167,10 @@ class TitleSummary(BaseModel):
|
||||
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
|
||||
|
||||
|
||||
class TitleSummaryWithId(TitleSummary):
|
||||
id: str
|
||||
|
||||
|
||||
class FinalLongSummary(BaseModel):
|
||||
long_summary: str
|
||||
duration: float
|
||||
@@ -318,3 +386,8 @@ class TranslationLanguages(BaseModel):
|
||||
|
||||
def is_supported(self, lang_id: str) -> bool:
|
||||
return lang_id in self.supported_languages
|
||||
|
||||
|
||||
class AudioDiarizationInput(BaseModel):
|
||||
audio_url: str
|
||||
topics: list[TitleSummaryWithId]
|
||||
|
||||
@@ -89,6 +89,10 @@ class Settings(BaseSettings):
|
||||
# LLM Modal configuration
|
||||
LLM_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Diarization
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
|
||||
# Sentry
|
||||
SENTRY_DSN: str | None = None
|
||||
|
||||
@@ -113,5 +117,19 @@ class Settings(BaseSettings):
|
||||
# Min transcript length to generate topic + summary
|
||||
MIN_TRANSCRIPT_LENGTH: int = 750
|
||||
|
||||
# Celery
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
|
||||
# Redis
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: int = 6379
|
||||
|
||||
# Secret key
|
||||
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
|
||||
|
||||
# Current hosting/domain
|
||||
BASE_URL: str = "http://localhost:1250"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
14
server/reflector/tools/start_post_main_live_pipeline.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import argparse
|
||||
|
||||
from reflector.app import celery_app # noqa
|
||||
from reflector.pipelines.main_live_pipeline import task_pipeline_main_post
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("transcript_id", type=str)
|
||||
parser.add_argument("--delay", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.delay:
|
||||
task_pipeline_main_post.delay(args.transcript_id)
|
||||
else:
|
||||
task_pipeline_main_post(args.transcript_id)
|
||||
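As with the other tools, this one is presumably run from `server/` as `poetry run python -m reflector.tools.start_post_main_live_pipeline <transcript_id>`; passing `--delay` enqueues the post-processing task on the Celery worker instead of executing it inline.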
@@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
from enum import StrEnum
|
||||
from json import dumps, loads
|
||||
from pathlib import Path
|
||||
from json import loads
|
||||
|
||||
import av
|
||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||
@@ -10,25 +8,7 @@ from prometheus_client import Gauge
|
||||
from pydantic import BaseModel
|
||||
from reflector.events import subscribers_shutdown
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
FinalLongSummary,
|
||||
FinalShortSummary,
|
||||
Pipeline,
|
||||
TitleSummary,
|
||||
Transcript,
|
||||
TranscriptFinalLongSummaryProcessor,
|
||||
TranscriptFinalShortSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorProcessor,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor
|
||||
from reflector.processors.types import FinalTitle
|
||||
from reflector.pipelines.runner import PipelineRunner
|
||||
|
||||
sessions = []
|
||||
router = APIRouter()
|
||||
@@ -38,7 +18,7 @@ m_rtc_sessions = Gauge("rtc_sessions", "Number of active RTC sessions")
|
||||
class TranscriptionContext(object):
|
||||
def __init__(self, logger):
|
||||
self.logger = logger
|
||||
self.pipeline = None
|
||||
self.pipeline_runner = None
|
||||
self.data_channel = None
|
||||
self.status = "idle"
|
||||
self.topics = []
|
||||
@@ -60,7 +40,7 @@ class AudioStreamTrack(MediaStreamTrack):
|
||||
ctx = self.ctx
|
||||
frame = await self.track.recv()
|
||||
try:
|
||||
await ctx.pipeline.push(frame)
|
||||
ctx.pipeline_runner.push(frame)
|
||||
except Exception as e:
|
||||
ctx.logger.error("Pipeline error", error=e)
|
||||
return frame
|
||||
@@ -71,27 +51,10 @@ class RtcOffer(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class StrValue(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class PipelineEvent(StrEnum):
|
||||
TRANSCRIPT = "TRANSCRIPT"
|
||||
TOPIC = "TOPIC"
|
||||
FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
|
||||
STATUS = "STATUS"
|
||||
FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
|
||||
FINAL_TITLE = "FINAL_TITLE"
|
||||
|
||||
|
||||
async def rtc_offer_base(
|
||||
params: RtcOffer,
|
||||
request: Request,
|
||||
event_callback=None,
|
||||
event_callback_args=None,
|
||||
audio_filename: Path | None = None,
|
||||
source_language: str = "en",
|
||||
target_language: str = "en",
|
||||
pipeline_runner: PipelineRunner,
|
||||
):
|
||||
# build an rtc session
|
||||
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
|
||||
@@ -101,146 +64,10 @@ async def rtc_offer_base(
|
||||
clientid = f"{peername[0]}:{peername[1]}"
|
||||
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
||||
|
||||
async def update_status(status: str):
|
||||
changed = ctx.status != status
|
||||
if changed:
|
||||
ctx.status = status
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.STATUS,
|
||||
args=event_callback_args,
|
||||
data=StrValue(value=status),
|
||||
)
|
||||
|
||||
# build pipeline callback
|
||||
async def on_transcript(transcript: Transcript):
|
||||
ctx.logger.info("Transcript", transcript=transcript)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "SHOW_TRANSCRIPTION",
|
||||
"text": transcript.text,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.TRANSCRIPT,
|
||||
args=event_callback_args,
|
||||
data=transcript,
|
||||
)
|
||||
|
||||
async def on_topic(topic: TitleSummary):
|
||||
# FIXME: make it incremental with the frontend, not send everything
|
||||
ctx.logger.info("Topic", topic=topic)
|
||||
ctx.topics.append(
|
||||
{
|
||||
"title": topic.title,
|
||||
"timestamp": topic.timestamp,
|
||||
"transcript": topic.transcript.text,
|
||||
"desc": topic.summary,
|
||||
}
|
||||
)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
|
||||
)
|
||||
|
||||
async def on_final_short_summary(summary: FinalShortSummary):
|
||||
ctx.logger.info("FinalShortSummary", final_short_summary=summary)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
|
||||
"summary": summary.short_summary,
|
||||
"duration": summary.duration,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.FINAL_SHORT_SUMMARY,
|
||||
args=event_callback_args,
|
||||
data=summary,
|
||||
)
|
||||
|
||||
async def on_final_long_summary(summary: FinalLongSummary):
|
||||
ctx.logger.info("FinalLongSummary", final_summary=summary)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "DISPLAY_FINAL_LONG_SUMMARY",
|
||||
"summary": summary.long_summary,
|
||||
"duration": summary.duration,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.FINAL_LONG_SUMMARY,
|
||||
args=event_callback_args,
|
||||
data=summary,
|
||||
)
|
||||
|
||||
async def on_final_title(title: FinalTitle):
|
||||
ctx.logger.info("FinalTitle", final_title=title)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.FINAL_TITLE,
|
||||
args=event_callback_args,
|
||||
data=title,
|
||||
)
|
||||
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
processors = []
|
||||
if audio_filename is not None:
|
||||
processors += [AudioFileWriterProcessor(path=audio_filename)]
|
||||
processors += [
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorProcessor.as_threaded(callback=on_transcript),
|
||||
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
|
||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||
callback=on_final_long_summary
|
||||
),
|
||||
TranscriptFinalShortSummaryProcessor.as_threaded(
|
||||
callback=on_final_short_summary
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
ctx.pipeline = Pipeline(*processors)
|
||||
ctx.pipeline.set_pref("audio:source_language", source_language)
|
||||
ctx.pipeline.set_pref("audio:target_language", target_language)
|
||||
|
||||
# handle RTC peer connection
|
||||
pc = RTCPeerConnection()
|
||||
ctx.pipeline_runner = pipeline_runner
|
||||
ctx.pipeline_runner.start()
|
||||
|
||||
async def flush_pipeline_and_quit(close=True):
|
||||
# may be called twice
|
||||
@@ -249,12 +76,10 @@ async def rtc_offer_base(
|
||||
# - when we receive the close event, we do nothing.
|
||||
# 2. or the client closes the connection
|
||||
# and there is nothing to do because it is already closed
|
||||
await update_status("processing")
|
||||
await ctx.pipeline.flush()
|
||||
ctx.pipeline_runner.flush()
|
||||
if close:
|
||||
ctx.logger.debug("Closing peer connection")
|
||||
await pc.close()
|
||||
await update_status("ended")
|
||||
if pc in sessions:
|
||||
sessions.remove(pc)
|
||||
m_rtc_sessions.dec()
|
||||
@@ -287,7 +112,6 @@ async def rtc_offer_base(
|
||||
def on_track(track):
|
||||
ctx.logger.info(f"Track {track.kind} received")
|
||||
pc.addTrack(AudioStreamTrack(ctx, track))
|
||||
asyncio.get_event_loop().create_task(update_status("recording"))
|
||||
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
@@ -308,8 +132,3 @@ async def rtc_clean_sessions(_):
|
||||
logger.debug(f"Closing session {pc}")
|
||||
await pc.close()
|
||||
sessions.clear()
|
||||
|
||||
|
||||
@router.post("/offer")
|
||||
async def rtc_offer(params: RtcOffer, request: Request):
|
||||
return await rtc_offer_base(params, request)
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import reflector.auth as auth
|
||||
from fastapi import (
|
||||
@@ -12,221 +9,36 @@ from fastapi import (
|
||||
Request,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
status,
|
||||
)
|
||||
from fastapi_pagination import Page, paginate
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.logger import logger
|
||||
from reflector.db.transcripts import (
|
||||
AudioWaveform,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.processors.types import Transcript as ProcessorTranscript
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
from reflector.ws_manager import get_ws_manager
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==============================================================
|
||||
# Models to move to a database, but required for the API to work
|
||||
# ==============================================================
|
||||
ALGORITHM = "HS256"
|
||||
DOWNLOAD_EXPIRE_MINUTES = 60
|
||||
|
||||
|
||||
def generate_uuid4():
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def generate_transcript_name():
|
||||
now = datetime.utcnow()
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
|
||||
class AudioWaveform(BaseModel):
|
||||
data: list[float]
|
||||
|
||||
|
||||
class TranscriptText(BaseModel):
|
||||
text: str
|
||||
translation: str | None
|
||||
|
||||
|
||||
class TranscriptTopic(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
title: str
|
||||
summary: str
|
||||
transcript: str | None = None
|
||||
timestamp: float
|
||||
|
||||
|
||||
class TranscriptFinalShortSummary(BaseModel):
|
||||
short_summary: str
|
||||
|
||||
|
||||
class TranscriptFinalLongSummary(BaseModel):
|
||||
long_summary: str
|
||||
|
||||
|
||||
class TranscriptFinalTitle(BaseModel):
|
||||
title: str
|
||||
|
||||
|
||||
class TranscriptEvent(BaseModel):
|
||||
event: str
|
||||
data: dict
|
||||
|
||||
|
||||
class Transcript(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
user_id: str | None = None
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
status: str = "idle"
|
||||
locked: bool = False
|
||||
duration: float = 0
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
title: str | None = None
|
||||
short_summary: str | None = None
|
||||
long_summary: str | None = None
|
||||
topics: list[TranscriptTopic] = []
|
||||
events: list[TranscriptEvent] = []
|
||||
source_language: str = "en"
|
||||
target_language: str = "en"
|
||||
|
||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||
self.events.append(ev)
|
||||
return ev
|
||||
|
||||
def upsert_topic(self, topic: TranscriptTopic):
|
||||
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
|
||||
if existing_topic:
|
||||
existing_topic.update_from(topic)
|
||||
else:
|
||||
self.topics.append(topic)
|
||||
|
||||
def events_dump(self, mode="json"):
|
||||
return [event.model_dump(mode=mode) for event in self.events]
|
||||
|
||||
def topics_dump(self, mode="json"):
|
||||
return [topic.model_dump(mode=mode) for topic in self.topics]
|
||||
|
||||
def convert_audio_to_waveform(self, segments_count=256):
|
||||
fn = self.audio_waveform_filename
|
||||
if fn.exists():
|
||||
return
|
||||
waveform = get_audio_waveform(
|
||||
path=self.audio_mp3_filename, segments_count=segments_count
|
||||
)
|
||||
try:
|
||||
with open(fn, "w") as fd:
|
||||
json.dump(waveform, fd)
|
||||
except Exception:
|
||||
# remove file if anything happen during the write
|
||||
fn.unlink(missing_ok=True)
|
||||
raise
|
||||
return waveform
|
||||
|
||||
def unlink(self):
|
||||
self.data_path.unlink(missing_ok=True)
|
||||
|
||||
@property
|
||||
def data_path(self):
|
||||
return Path(settings.DATA_DIR) / self.id
|
||||
|
||||
@property
|
||||
def audio_mp3_filename(self):
|
||||
return self.data_path / "audio.mp3"
|
||||
|
||||
@property
|
||||
def audio_waveform_filename(self):
|
||||
return self.data_path / "audio.json"
|
||||
|
||||
@property
|
||||
def audio_waveform(self):
|
||||
try:
|
||||
with open(self.audio_waveform_filename) as fd:
|
||||
data = json.load(fd)
|
||||
except json.JSONDecodeError:
|
||||
# unlink file if it's corrupted
|
||||
self.audio_waveform_filename.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
return AudioWaveform(data=data)
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
async def get_all(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
order_by: str | None = None,
|
||||
filter_empty: bool | None = False,
|
||||
filter_recording: bool | None = False,
|
||||
) -> list[Transcript]:
|
||||
query = transcripts.select().where(transcripts.c.user_id == user_id)
|
||||
|
||||
if order_by is not None:
|
||||
field = getattr(transcripts.c, order_by[1:])
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
|
||||
if filter_empty:
|
||||
query = query.filter(transcripts.c.status != "idle")
|
||||
|
||||
if filter_recording:
|
||||
query = query.filter(transcripts.c.status != "recording")
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
return results
|
||||
|
||||
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await database.fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
|
||||
async def add(
|
||||
self,
|
||||
name: str,
|
||||
source_language: str = "en",
|
||||
target_language: str = "en",
|
||||
user_id: str | None = None,
|
||||
):
|
||||
transcript = Transcript(
|
||||
name=name,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
user_id=user_id,
|
||||
)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
return transcript
|
||||
|
||||
async def update(self, transcript: Transcript, values: dict):
|
||||
query = (
|
||||
transcripts.update()
|
||||
.where(transcripts.c.id == transcript.id)
|
||||
.values(**values)
|
||||
)
|
||||
await database.execute(query)
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
|
||||
async def remove_by_id(
|
||||
self, transcript_id: str, user_id: str | None = None
|
||||
) -> None:
|
||||
transcript = await self.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
return
|
||||
if user_id is not None and transcript.user_id != user_id:
|
||||
return
|
||||
transcript.unlink()
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await database.execute(query)
|
||||
|
||||
|
||||
transcripts_controller = TranscriptController()
|
||||
def create_access_token(data: dict, expires_delta: timedelta):
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
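`create_access_token` pairs with the token check in `transcript_get_audio_mp3` below: a short-lived JWT carrying the user id in `sub` lets the mp3 be fetched with a query parameter instead of an Authorization header. A sketch of issuing such a token from inside this module (the route path is assumed, and the ids are invented):

    from datetime import timedelta

    user_id = "user-123"       # invented
    transcript_id = "abc"      # invented
    token = create_access_token({"sub": user_id}, timedelta(minutes=DOWNLOAD_EXPIRE_MINUTES))
    # path assumed; the decorator for the mp3 endpoint is not shown in this diff
    download_url = f"{settings.BASE_URL}/transcripts/{transcript_id}/audio/mp3?token={token}"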
|
||||
|
||||
|
||||
# ==============================================================
|
||||
@@ -298,6 +110,55 @@ async def transcripts_create(
|
||||
# ==============================================================
|
||||
|
||||
|
||||
class GetTranscriptSegmentTopic(BaseModel):
|
||||
text: str
|
||||
start: float
|
||||
speaker: int
|
||||
|
||||
|
||||
class GetTranscriptTopic(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
summary: str
|
||||
timestamp: float
|
||||
transcript: str
|
||||
segments: list[GetTranscriptSegmentTopic] = []
|
||||
|
||||
@classmethod
|
||||
def from_transcript_topic(cls, topic: TranscriptTopic):
|
||||
if not topic.words:
|
||||
# In previous versions, words were missing
|
||||
# Just output a segment with speaker 0
|
||||
text = topic.transcript
|
||||
segments = [
|
||||
GetTranscriptSegmentTopic(
|
||||
text=topic.transcript,
|
||||
start=topic.timestamp,
|
||||
speaker=0,
|
||||
)
|
||||
]
|
||||
else:
|
||||
# New versions include words
|
||||
transcript = ProcessorTranscript(words=topic.words)
|
||||
text = transcript.text
|
||||
segments = [
|
||||
GetTranscriptSegmentTopic(
|
||||
text=segment.text,
|
||||
start=segment.start,
|
||||
speaker=segment.speaker,
|
||||
)
|
||||
for segment in transcript.as_segments()
|
||||
]
|
||||
return cls(
|
||||
id=topic.id,
|
||||
title=topic.title,
|
||||
summary=topic.summary,
|
||||
timestamp=topic.timestamp,
|
||||
transcript=text,
|
||||
segments=segments,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
async def transcript_get(
|
||||
transcript_id: str,
|
||||
@@ -320,32 +181,17 @@ async def transcript_update(
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
values = {"events": []}
|
||||
values = {}
|
||||
if info.name is not None:
|
||||
values["name"] = info.name
|
||||
if info.locked is not None:
|
||||
values["locked"] = info.locked
|
||||
if info.long_summary is not None:
|
||||
values["long_summary"] = info.long_summary
|
||||
for transcript_event in transcript.events:
|
||||
if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
|
||||
transcript_event["long_summary"] = info.long_summary
|
||||
break
|
||||
values["events"].extend(transcript.events)
|
||||
if info.short_summary is not None:
|
||||
values["short_summary"] = info.short_summary
|
||||
for transcript_event in transcript.events:
|
||||
if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
|
||||
transcript_event["short_summary"] = info.short_summary
|
||||
break
|
||||
values["events"].extend(transcript.events)
|
||||
if info.title is not None:
|
||||
values["title"] = info.title
|
||||
for transcript_event in transcript.events:
|
||||
if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
|
||||
transcript_event["title"] = info.title
|
||||
break
|
||||
values["events"].extend(transcript.events)
|
||||
await transcripts_controller.update(transcript, values)
|
||||
return transcript
|
||||
|
||||
@@ -368,8 +214,21 @@ async def transcript_get_audio_mp3(
|
||||
request: Request,
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
token: str | None = None,
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
if not user_id and token:
|
||||
unauthorized_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
except jwt.JWTError:
|
||||
raise unauthorized_exception
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
@@ -406,7 +265,10 @@ async def transcript_get_audio_waveform(
|
||||
return transcript.audio_waveform
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
||||
@router.get(
|
||||
"/transcripts/{transcript_id}/topics",
|
||||
response_model=list[GetTranscriptTopic],
|
||||
)
|
||||
async def transcript_get_topics(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
@@ -415,7 +277,16 @@ async def transcript_get_topics(
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
return transcript.topics
|
||||
|
||||
# convert to GetTranscriptTopic
|
||||
return [
|
||||
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
|
||||
]
|
||||
|
||||
|
||||
# ==============================================================
|
||||
# Websocket
|
||||
# ==============================================================
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/events")
|
||||
@@ -423,41 +294,6 @@ async def transcript_get_websocket_events(transcript_id: str):
|
||||
pass
|
||||
|
||||
|
||||
# ==============================================================
|
||||
# Websocket Manager
|
||||
# ==============================================================
|
||||
|
||||
|
||||
class WebsocketManager:
|
||||
def __init__(self):
|
||||
self.active_connections = {}
|
||||
|
||||
async def connect(self, transcript_id: str, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
if transcript_id not in self.active_connections:
|
||||
self.active_connections[transcript_id] = []
|
||||
self.active_connections[transcript_id].append(websocket)
|
||||
|
||||
def disconnect(self, transcript_id: str, websocket: WebSocket):
|
||||
if transcript_id not in self.active_connections:
|
||||
return
|
||||
self.active_connections[transcript_id].remove(websocket)
|
||||
if not self.active_connections[transcript_id]:
|
||||
del self.active_connections[transcript_id]
|
||||
|
||||
async def send_json(self, transcript_id: str, message):
|
||||
if transcript_id not in self.active_connections:
|
||||
return
|
||||
for connection in self.active_connections[transcript_id][:]:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
self.active_connections[transcript_id].remove(connection)
|
||||
|
||||
|
||||
ws_manager = WebsocketManager()
|
||||
|
||||
|
||||
@router.websocket("/transcripts/{transcript_id}/events")
|
||||
async def transcript_events_websocket(
|
||||
transcript_id: str,
|
||||
@@ -469,21 +305,31 @@ async def transcript_events_websocket(
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
await ws_manager.connect(transcript_id, websocket)
|
||||
# connect to websocket manager
|
||||
# use ts:transcript_id as room id
|
||||
room_id = f"ts:{transcript_id}"
|
||||
ws_manager = get_ws_manager()
|
||||
await ws_manager.add_user_to_room(room_id, websocket)
|
||||
|
||||
# on first connection, send all events
|
||||
for event in transcript.events:
|
||||
await websocket.send_json(event.model_dump(mode="json"))
|
||||
|
||||
# XXX if transcript is final (locked=True and status=ended)
|
||||
# XXX send a final event to the client and close the connection
|
||||
|
||||
# endless loop to wait for new events
|
||||
try:
|
||||
# on first connection, send all events only to the current user
|
||||
for event in transcript.events:
|
||||
# for now, do not send TRANSCRIPT or STATUS events - these are live events
|
||||
# not necessary to be sent to the client; but keep the rest
|
||||
name = event.event
|
||||
if name in ("TRANSCRIPT", "STATUS"):
|
||||
continue
|
||||
await websocket.send_json(event.model_dump(mode="json"))
|
||||
|
||||
# XXX if transcript is final (locked=True and status=ended)
|
||||
# XXX send a final event to the client and close the connection
|
||||
|
||||
# endless loop to wait for new events
|
||||
# we do not have a command system now,
|
||||
while True:
|
||||
await websocket.receive()
|
||||
except (RuntimeError, WebSocketDisconnect):
|
||||
ws_manager.disconnect(transcript_id, websocket)
|
||||
await ws_manager.remove_user_from_room(room_id, websocket)
|
||||
|
||||
|
||||
# ==============================================================
|
||||
@@ -491,105 +337,6 @@ async def transcript_events_websocket(
|
||||
# ==============================================================
|
||||
|
||||
|
||||
async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
# OFC the current implementation is not good,
|
||||
# but it's just a POC before persistence. It won't query the
|
||||
# transcript from the database for each event.
|
||||
# print(f"Event: {event}", args, data)
|
||||
transcript_id = args
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
# the event sent to websocket clients may not be the same as the event
|
||||
# received from the pipeline. For example, the pipeline will send
|
||||
# a TRANSCRIPT event with all words, but this is not what we want
|
||||
# to send to the websocket client.
|
||||
|
||||
# FIXME don't do copy
|
||||
if event == PipelineEvent.TRANSCRIPT:
|
||||
resp = transcript.add_event(
|
||||
event=event,
|
||||
data=TranscriptText(text=data.text, translation=data.translation),
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.TOPIC:
|
||||
topic = TranscriptTopic(
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
transcript=data.transcript.text,
|
||||
timestamp=data.timestamp,
|
||||
)
|
||||
resp = transcript.add_event(event=event, data=topic)
|
||||
transcript.upsert_topic(topic)
|
||||
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"topics": transcript.topics_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.FINAL_TITLE:
|
||||
final_title = TranscriptFinalTitle(title=data.title)
|
||||
resp = transcript.add_event(event=event, data=final_title)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"title": final_title.title,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.FINAL_LONG_SUMMARY:
|
||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||
resp = transcript.add_event(event=event, data=final_long_summary)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"long_summary": final_long_summary.long_summary,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
|
||||
final_short_summary = TranscriptFinalShortSummary(
|
||||
short_summary=data.short_summary
|
||||
)
|
||||
resp = transcript.add_event(event=event, data=final_short_summary)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"short_summary": final_short_summary.short_summary,
|
||||
},
|
||||
)
|
||||
|
||||
elif event == PipelineEvent.STATUS:
|
||||
resp = transcript.add_event(event=event, data=data)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"events": transcript.events_dump(),
|
||||
"status": data.value,
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown event: {event}")
|
||||
return
|
||||
|
||||
# transmit to websocket clients
|
||||
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||
async def transcript_record_webrtc(
|
||||
transcript_id: str,
|
||||
@@ -605,13 +352,14 @@ async def transcript_record_webrtc(
|
||||
if transcript.locked:
|
||||
raise HTTPException(status_code=400, detail="Transcript is locked")
|
||||
|
||||
# create a pipeline runner
|
||||
from reflector.pipelines.main_live_pipeline import PipelineMainLive
|
||||
|
||||
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
|
||||
|
||||
# FIXME do not allow multiple recording at the same time
|
||||
return await rtc_offer_base(
|
||||
params,
|
||||
request,
|
||||
event_callback=handle_rtc_event,
|
||||
event_callback_args=transcript_id,
|
||||
audio_filename=transcript.audio_mp3_filename,
|
||||
source_language=transcript.source_language,
|
||||
target_language=transcript.target_language,
|
||||
pipeline_runner=pipeline_runner,
|
||||
)
|
||||
|
||||
12
server/reflector/worker/app.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from celery import Celery
|
||||
from reflector.settings import settings
|
||||
|
||||
app = Celery(__name__)
|
||||
app.conf.broker_url = settings.CELERY_BROKER_URL
|
||||
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
||||
app.conf.broker_connection_retry_on_startup = True
|
||||
app.autodiscover_tasks(
|
||||
[
|
||||
"reflector.pipelines.main_live_pipeline",
|
||||
]
|
||||
)
|
||||
126
server/reflector/ws_manager.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
Websocket manager
|
||||
=================
|
||||
|
||||
This module contains the WebsocketManager class, which is responsible for
|
||||
managing websockets and handling websocket connections.
|
||||
|
||||
It uses the RedisPubSubManager class to subscribe to Redis channels and
|
||||
broadcast messages to all connected websockets.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
|
||||
import redis.asyncio as redis
|
||||
from fastapi import WebSocket
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class RedisPubSubManager:
|
||||
def __init__(self, host="localhost", port=6379):
|
||||
self.redis_host = host
|
||||
self.redis_port = port
|
||||
self.redis_connection = None
|
||||
self.pubsub = None
|
||||
|
||||
async def get_redis_connection(self) -> redis.Redis:
|
||||
return redis.Redis(
|
||||
host=self.redis_host,
|
||||
port=self.redis_port,
|
||||
auto_close_connection_pool=False,
|
||||
)
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self.redis_connection is not None:
|
||||
return
|
||||
self.redis_connection = await self.get_redis_connection()
|
||||
self.pubsub = self.redis_connection.pubsub()
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self.redis_connection is None:
|
||||
return
|
||||
await self.redis_connection.close()
|
||||
self.redis_connection = None
|
||||
|
||||
async def send_json(self, room_id: str, message: str) -> None:
|
||||
if not self.redis_connection:
|
||||
await self.connect()
|
||||
message = json.dumps(message)
|
||||
await self.redis_connection.publish(room_id, message)
|
||||
|
||||
async def subscribe(self, room_id: str) -> redis.Redis:
|
||||
await self.pubsub.subscribe(room_id)
|
||||
return self.pubsub
|
||||
|
||||
async def unsubscribe(self, room_id: str) -> None:
|
||||
await self.pubsub.unsubscribe(room_id)
|
||||
|
||||
|
||||
class WebsocketManager:
|
||||
def __init__(self, pubsub_client: RedisPubSubManager = None):
|
||||
self.rooms: dict = {}
|
||||
self.pubsub_client = pubsub_client
|
||||
|
||||
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
|
||||
if room_id in self.rooms:
|
||||
self.rooms[room_id].append(websocket)
|
||||
else:
|
||||
self.rooms[room_id] = [websocket]
|
||||
|
||||
await self.pubsub_client.connect()
|
||||
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
|
||||
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
||||
|
||||
async def send_json(self, room_id: str, message: dict) -> None:
|
||||
await self.pubsub_client.send_json(room_id, message)
|
||||
|
||||
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
|
||||
self.rooms[room_id].remove(websocket)
|
||||
|
||||
if len(self.rooms[room_id]) == 0:
|
||||
del self.rooms[room_id]
|
||||
await self.pubsub_client.unsubscribe(room_id)
|
||||
|
||||
async def _pubsub_data_reader(self, pubsub_subscriber):
|
||||
while True:
|
||||
message = await pubsub_subscriber.get_message(
|
||||
ignore_subscribe_messages=True
|
||||
)
|
||||
if message is not None:
|
||||
room_id = message["channel"].decode("utf-8")
|
||||
all_sockets = self.rooms[room_id]
|
||||
for socket in all_sockets:
|
||||
data = json.loads(message["data"].decode("utf-8"))
|
||||
await socket.send_json(data)
|
||||
|
||||
|
||||
def get_ws_manager() -> WebsocketManager:
|
||||
"""
|
||||
Returns the WebsocketManager instance for managing websockets.
|
||||
|
||||
This function initializes and returns the WebsocketManager instance,
|
||||
which is responsible for managing websockets and handling websocket
|
||||
connections.
|
||||
|
||||
Returns:
|
||||
WebsocketManager: The initialized WebsocketManager instance.
|
||||
|
||||
Raises:
|
||||
ImportError: If the 'reflector.settings' module cannot be imported.
|
||||
RedisConnectionError: If there is an error connecting to the Redis server.
|
||||
"""
|
||||
local = threading.local()
|
||||
if hasattr(local, "ws_manager"):
|
||||
return local.ws_manager
|
||||
|
||||
pubsub_client = RedisPubSubManager(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
)
|
||||
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
|
||||
local.ws_manager = ws_manager
|
||||
return ws_manager
|
||||
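The transcript events websocket above joins room `ts:<transcript_id>`, and publishers reach the same room through this manager, so events fan out across processes via Redis pub/sub. A sketch of the publish side, assuming Redis is reachable at `REDIS_HOST`/`REDIS_PORT`; the transcript id and payload are invented:

    from reflector.ws_manager import get_ws_manager

    ws_manager = get_ws_manager()
    # inside an async context:
    #   await ws_manager.send_json("ts:abc", {"event": "STATUS", "data": {"value": "ended"}})
    # every process holding websockets in that room relays the payload to them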
@@ -4,4 +4,11 @@ if [ -f "/venv/bin/activate" ]; then
|
||||
source /venv/bin/activate
|
||||
fi
|
||||
alembic upgrade head
|
||||
python -m reflector.app
|
||||
|
||||
if [ "${ENTRYPOINT}" = "server" ]; then
|
||||
python -m reflector.app
|
||||
elif [ "${ENTRYPOINT}" = "worker" ]; then
|
||||
celery -A reflector.worker.app worker --loglevel=info
|
||||
else
|
||||
echo "Unknown command"
|
||||
fi
|
||||
|
||||
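With this change the same image covers both roles: the script still runs `alembic upgrade head` first, then starts either the API (`ENTRYPOINT=server`) or the Celery worker (`ENTRYPOINT=worker`); any other value prints "Unknown command".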
@@ -45,28 +45,50 @@ async def dummy_transcript():
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
|
||||
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
|
||||
async def _transcript(self, data: AudioFile):
|
||||
source_language = self.get_pref("audio:source_language", "en")
|
||||
print("transcripting", source_language)
|
||||
print("pipeline", self.pipeline)
|
||||
print("prefs", self.pipeline.prefs)
|
||||
_time_idx = 0
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
i = self._time_idx
|
||||
self._time_idx += 2
|
||||
return Transcript(
|
||||
text="Hello world.",
|
||||
words=[
|
||||
Word(start=0.0, end=1.0, text="Hello"),
|
||||
Word(start=1.0, end=2.0, text=" world."),
|
||||
Word(start=i, end=i + 1, text="Hello", speaker=0),
|
||||
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
|
||||
],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"reflector.processors.audio_transcript_auto"
|
||||
".AudioTranscriptAutoProcessor.get_instance"
|
||||
".AudioTranscriptAutoProcessor.__new__"
|
||||
) as mock_audio:
|
||||
mock_audio.return_value = TestAudioTranscriptProcessor()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_diarization():
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
|
||||
class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
|
||||
_time_idx = 0
|
||||
|
||||
async def _diarize(self, data):
|
||||
i = self._time_idx
|
||||
self._time_idx += 2
|
||||
return [
|
||||
{"start": i, "end": i + 1, "speaker": 0},
|
||||
{"start": i + 1, "end": i + 2, "speaker": 1},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"reflector.processors.audio_diarization_auto"
|
||||
".AudioDiarizationAutoProcessor.__new__"
|
||||
) as mock_audio:
|
||||
mock_audio.return_value = TestAudioDiarizationProcessor()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_llm():
|
||||
from reflector.llm.base import LLM
|
||||
@@ -98,7 +120,17 @@ def ensure_casing():
|
||||
@pytest.fixture
|
||||
def sentence_tokenize():
|
||||
with patch(
|
||||
"reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize"
|
||||
"reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize"
|
||||
) as mock_sent_tokenize:
|
||||
mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_enable_logging():
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_config():
|
||||
return {"broker_url": "memory://", "result_backend": "rpc"}
|
||||
|
||||
161
server/tests/test_processor_transcript_segment.py
Normal file
@@ -0,0 +1,161 @@
|
||||
def test_processor_transcript_segment():
|
||||
from reflector.processors.types import Transcript, Word
|
||||
|
||||
transcript = Transcript(
|
||||
words=[
|
||||
Word(text=" the", start=5.12, end=5.48, speaker=0),
|
||||
Word(text=" different", start=5.48, end=5.8, speaker=0),
|
||||
Word(text=" projects", start=5.8, end=6.3, speaker=0),
|
||||
Word(text=" that", start=6.3, end=6.5, speaker=0),
|
||||
Word(text=" are", start=6.5, end=6.58, speaker=0),
|
||||
Word(text=" going", start=6.58, end=6.82, speaker=0),
|
||||
Word(text=" on", start=6.82, end=7.26, speaker=0),
|
||||
Word(text=" to", start=7.26, end=7.4, speaker=0),
|
||||
Word(text=" give", start=7.4, end=7.54, speaker=0),
|
||||
Word(text=" you", start=7.54, end=7.9, speaker=0),
|
||||
Word(text=" context", start=7.9, end=8.24, speaker=0),
|
||||
Word(text=" and", start=8.24, end=8.66, speaker=0),
|
||||
Word(text=" I", start=8.66, end=8.72, speaker=0),
|
||||
Word(text=" think", start=8.72, end=8.82, speaker=0),
|
||||
Word(text=" that's", start=8.82, end=9.04, speaker=0),
|
||||
Word(text=" what", start=9.04, end=9.12, speaker=0),
|
||||
Word(text=" we'll", start=9.12, end=9.24, speaker=0),
|
||||
Word(text=" do", start=9.24, end=9.32, speaker=0),
|
||||
Word(text=" this", start=9.32, end=9.52, speaker=0),
|
||||
Word(text=" week.", start=9.52, end=9.76, speaker=0),
|
||||
Word(text=" Um,", start=10.24, end=10.62, speaker=0),
|
||||
Word(text=" so,", start=11.36, end=11.94, speaker=0),
|
||||
Word(text=" um,", start=12.46, end=12.92, speaker=0),
|
||||
Word(text=" what", start=13.74, end=13.94, speaker=0),
|
||||
Word(text=" we're", start=13.94, end=14.1, speaker=0),
|
||||
Word(text=" going", start=14.1, end=14.24, speaker=0),
|
||||
Word(text=" to", start=14.24, end=14.34, speaker=0),
|
||||
Word(text=" do", start=14.34, end=14.8, speaker=0),
|
||||
Word(text=" at", start=14.8, end=14.98, speaker=0),
|
||||
Word(text=" H", start=14.98, end=15.04, speaker=0),
|
||||
Word(text=" of", start=15.04, end=15.16, speaker=0),
|
||||
Word(text=" you,", start=15.16, end=15.26, speaker=0),
|
||||
Word(text=" maybe.", start=15.28, end=15.34, speaker=0),
|
||||
Word(text=" you", start=15.36, end=15.52, speaker=0),
|
||||
Word(text=" can", start=15.52, end=15.62, speaker=0),
|
||||
Word(text=" introduce", start=15.62, end=15.98, speaker=0),
|
||||
Word(text=" yourself", start=15.98, end=16.42, speaker=0),
|
||||
Word(text=" to", start=16.42, end=16.68, speaker=0),
|
||||
Word(text=" the", start=16.68, end=16.72, speaker=0),
|
||||
Word(text=" team", start=16.72, end=17.52, speaker=0),
|
||||
Word(text=" quickly", start=17.87, end=18.65, speaker=0),
|
||||
Word(text=" and", start=18.65, end=19.63, speaker=0),
|
||||
Word(text=" Oh,", start=20.91, end=21.55, speaker=0),
|
||||
Word(text=" this", start=21.67, end=21.83, speaker=0),
|
||||
Word(text=" is", start=21.83, end=22.17, speaker=0),
|
||||
Word(text=" a", start=22.17, end=22.35, speaker=0),
|
||||
Word(text=" reflector", start=22.35, end=22.89, speaker=0),
|
||||
Word(text=" translating", start=22.89, end=23.33, speaker=0),
|
||||
Word(text=" into", start=23.33, end=23.73, speaker=0),
|
||||
Word(text=" French", start=23.73, end=23.95, speaker=0),
|
||||
Word(text=" for", start=23.95, end=24.15, speaker=0),
|
||||
Word(text=" me.", start=24.15, end=24.43, speaker=0),
|
||||
Word(text=" This", start=27.87, end=28.19, speaker=0),
|
||||
Word(text=" is", start=28.19, end=28.45, speaker=0),
|
||||
Word(text=" all", start=28.45, end=28.79, speaker=0),
|
||||
Word(text=" the", start=28.79, end=29.15, speaker=0),
|
||||
Word(text=" way,", start=29.15, end=29.15, speaker=0),
|
||||
Word(text=" please,", start=29.53, end=29.59, speaker=0),
|
||||
Word(text=" please,", start=29.73, end=29.77, speaker=0),
|
||||
Word(text=" please,", start=29.77, end=29.83, speaker=0),
|
||||
Word(text=" please.", start=29.83, end=29.97, speaker=0),
|
||||
Word(text=" Yeah,", start=29.97, end=30.17, speaker=0),
|
||||
Word(text=" that's", start=30.25, end=30.33, speaker=0),
|
||||
Word(text=" all", start=30.33, end=30.49, speaker=0),
|
||||
Word(text=" it's", start=30.49, end=30.69, speaker=0),
|
||||
Word(text=" right.", start=30.69, end=30.69, speaker=0),
|
||||
Word(text=" Right.", start=30.72, end=30.98, speaker=1),
|
||||
Word(text=" Yeah,", start=31.56, end=31.72, speaker=2),
|
||||
Word(text=" that's", start=31.86, end=31.98, speaker=2),
|
||||
Word(text=" right.", start=31.98, end=32.2, speaker=2),
|
||||
Word(text=" Because", start=32.38, end=32.46, speaker=0),
|
||||
Word(text=" I", start=32.46, end=32.58, speaker=0),
|
||||
Word(text=" thought", start=32.58, end=32.78, speaker=0),
|
||||
Word(text=" I'd", start=32.78, end=33.0, speaker=0),
|
||||
Word(text=" be", start=33.0, end=33.02, speaker=0),
|
||||
Word(text=" able", start=33.02, end=33.18, speaker=0),
|
||||
Word(text=" to", start=33.18, end=33.34, speaker=0),
|
||||
Word(text=" pull", start=33.34, end=33.52, speaker=0),
|
||||
Word(text=" out.", start=33.52, end=33.68, speaker=0),
|
||||
Word(text=" Yeah,", start=33.7, end=33.9, speaker=0),
|
||||
Word(text=" that", start=33.9, end=34.02, speaker=0),
|
||||
Word(text=" was", start=34.02, end=34.24, speaker=0),
|
||||
Word(text=" the", start=34.24, end=34.34, speaker=0),
|
||||
Word(text=" one", start=34.34, end=34.44, speaker=0),
|
||||
Word(text=" before", start=34.44, end=34.7, speaker=0),
|
||||
Word(text=" that.", start=34.7, end=35.24, speaker=0),
|
||||
Word(text=" Friends,", start=35.84, end=36.46, speaker=0),
|
||||
Word(text=" if", start=36.64, end=36.7, speaker=0),
|
||||
Word(text=" you", start=36.7, end=36.7, speaker=0),
|
||||
Word(text=" have", start=36.7, end=37.24, speaker=0),
|
||||
Word(text=" tell", start=37.24, end=37.44, speaker=0),
|
||||
Word(text=" us", start=37.44, end=37.68, speaker=0),
|
||||
Word(text=" if", start=37.68, end=37.82, speaker=0),
|
||||
Word(text=" it's", start=37.82, end=38.04, speaker=0),
|
||||
Word(text=" good,", start=38.04, end=38.58, speaker=0),
|
||||
Word(text=" exceptionally", start=38.96, end=39.1, speaker=0),
|
||||
Word(text=" good", start=39.1, end=39.6, speaker=0),
|
||||
Word(text=" and", start=39.6, end=39.86, speaker=0),
|
||||
Word(text=" tell", start=39.86, end=40.0, speaker=0),
|
||||
Word(text=" us", start=40.0, end=40.06, speaker=0),
|
||||
Word(text=" when", start=40.06, end=40.2, speaker=0),
|
||||
Word(text=" it's", start=40.2, end=40.34, speaker=0),
|
||||
Word(text=" exceptionally", start=40.34, end=40.6, speaker=0),
|
||||
Word(text=" bad.", start=40.6, end=40.94, speaker=0),
|
||||
Word(text=" We", start=40.96, end=41.26, speaker=0),
|
||||
Word(text=" don't", start=41.26, end=41.44, speaker=0),
|
||||
Word(text=" need", start=41.44, end=41.66, speaker=0),
|
||||
Word(text=" that", start=41.66, end=41.82, speaker=0),
|
||||
Word(text=" at", start=41.82, end=41.94, speaker=0),
|
||||
Word(text=" the", start=41.94, end=41.98, speaker=0),
|
||||
Word(text=" middle", start=41.98, end=42.18, speaker=0),
|
||||
Word(text=" of", start=42.18, end=42.36, speaker=0),
|
||||
Word(text=" age.", start=42.36, end=42.7, speaker=0),
|
||||
Word(text=" Okay,", start=43.26, end=43.44, speaker=0),
|
||||
Word(text=" yeah,", start=43.68, end=43.76, speaker=0),
|
||||
Word(text=" that", start=43.78, end=44.3, speaker=0),
|
||||
Word(text=" sentence", start=44.3, end=44.72, speaker=0),
|
||||
Word(text=" right", start=44.72, end=45.1, speaker=0),
|
||||
Word(text=" before.", start=45.1, end=45.56, speaker=0),
|
||||
Word(text=" it", start=46.08, end=46.36, speaker=0),
|
||||
Word(text=" realizing", start=46.36, end=47.0, speaker=0),
|
||||
Word(text=" that", start=47.0, end=47.28, speaker=0),
|
||||
Word(text=" I", start=47.28, end=47.28, speaker=0),
|
||||
Word(text=" was", start=47.28, end=47.64, speaker=0),
|
||||
Word(text=" saying", start=47.64, end=48.06, speaker=0),
|
||||
Word(text=" that", start=48.06, end=48.44, speaker=0),
|
||||
Word(text=" it's", start=48.44, end=48.54, speaker=0),
|
||||
Word(text=" interesting", start=48.54, end=48.78, speaker=0),
|
||||
Word(text=" that", start=48.78, end=48.96, speaker=0),
|
||||
Word(text=" it's", start=48.96, end=49.08, speaker=0),
|
||||
Word(text=" translating", start=49.08, end=49.32, speaker=0),
|
||||
Word(text=" the", start=49.32, end=49.56, speaker=0),
|
||||
Word(text=" French", start=49.56, end=49.76, speaker=0),
|
||||
Word(text=" was", start=49.76, end=50.16, speaker=0),
|
||||
Word(text=" completely", start=50.16, end=50.4, speaker=0),
|
||||
Word(text=" wrong.", start=50.4, end=50.7, speaker=0),
|
||||
]
)

segments = transcript.as_segments()
assert len(segments) == 7

# check speaker order
assert segments[0].speaker == 0
assert segments[1].speaker == 0
assert segments[2].speaker == 0
assert segments[3].speaker == 1
assert segments[4].speaker == 2
assert segments[5].speaker == 0
assert segments[6].speaker == 0

# check the timing (first entry, and the first entry of each other speaker)
assert segments[0].start == 5.12
assert segments[3].start == 30.72
assert segments[4].start == 31.56
assert segments[5].start == 32.38

@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
import httpx
|
||||
from reflector.utils.retry import (
|
||||
@@ -8,6 +9,31 @@ from reflector.utils.retry import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_redirect(httpx_mock):
|
||||
async def custom_response(request: httpx.Request):
|
||||
if request.url.path == "/hello":
|
||||
await asyncio.sleep(1)
|
||||
return httpx.Response(
|
||||
status_code=303, headers={"location": "https://test_url/redirected"}
|
||||
)
|
||||
elif request.url.path == "/redirected":
|
||||
return httpx.Response(status_code=200, json={"hello": "world"})
|
||||
else:
|
||||
raise Exception("Unexpected path")
|
||||
|
||||
httpx_mock.add_callback(custom_response)
|
||||
async with httpx.AsyncClient() as client:
|
||||
# the timeout should not be triggered, as it will end up ok
# even though the first request is a 303 and took more than 0.5
|
||||
resp = await retry(client.get)(
|
||||
"https://test_url/hello",
|
||||
retry_timeout=0.5,
|
||||
follow_redirects=True,
|
||||
)
|
||||
assert resp.json() == {"hello": "world"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_httpx(httpx_mock):
|
||||
# this code should force a retry
|
||||
|
||||
@@ -32,7 +32,7 @@ class ThreadedUvicorn:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def appserver(tmpdir):
|
||||
async def appserver(tmpdir, celery_session_app, celery_session_worker):
|
||||
from reflector.settings import settings
|
||||
from reflector.app import app
|
||||
|
||||
@@ -52,12 +52,20 @@ async def appserver(tmpdir):
|
||||
settings.DATA_DIR = DATA_DIR
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_includes():
|
||||
return ["reflector.pipelines.main_live_pipeline"]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
@pytest.mark.usefixtures("celery_session_worker")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_rtc_and_websocket(
|
||||
tmpdir,
|
||||
dummy_llm,
|
||||
dummy_transcript,
|
||||
dummy_processors,
|
||||
dummy_diarization,
|
||||
ensure_casing,
|
||||
appserver,
|
||||
sentence_tokenize,
|
||||
@@ -95,6 +103,7 @@ async def test_transcript_rtc_and_websocket(
|
||||
print("Test websocket: DISCONNECTED")
|
||||
|
||||
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||
print("Test websocket: TASK CREATED", websocket_task)
|
||||
|
||||
# create stream client
|
||||
import argparse
|
||||
@@ -121,14 +130,20 @@ async def test_transcript_rtc_and_websocket(
|
||||
# XXX aiortc takes a long time to close the connection
|
||||
# instead of waiting a long time, we just send a STOP
|
||||
client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||
|
||||
# wait for the processing to finish
|
||||
await asyncio.sleep(2)
|
||||
|
||||
await client.stop()
|
||||
|
||||
# wait for the processing to finish
|
||||
await asyncio.sleep(2)
|
||||
timeout = 20
|
||||
while True:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if resp.json()["status"] != "ended":
|
||||
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
||||
|
||||
# stop websocket task
|
||||
websocket_task.cancel()
|
||||
@@ -169,29 +184,28 @@ async def test_transcript_rtc_and_websocket(
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
assert statuses == ["recording", "processing", "ended"]
|
||||
assert statuses.index("recording") < statuses.index("processing")
|
||||
assert statuses.index("processing") < statuses.index("ended")
|
||||
|
||||
# ensure the last event received is ended
|
||||
assert events[-1]["event"] == "STATUS"
|
||||
assert events[-1]["data"]["value"] == "ended"
|
||||
|
||||
# check that transcript status in model is updated
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ended"
|
||||
|
||||
# check that audio/mp3 is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
@pytest.mark.usefixtures("celery_session_worker")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_rtc_and_websocket_and_fr(
|
||||
tmpdir,
|
||||
dummy_llm,
|
||||
dummy_transcript,
|
||||
dummy_processors,
|
||||
dummy_diarization,
|
||||
ensure_casing,
|
||||
appserver,
|
||||
sentence_tokenize,
|
||||
@@ -232,6 +246,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
print("Test websocket: DISCONNECTED")
|
||||
|
||||
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||
print("Test websocket: TASK CREATED", websocket_task)
|
||||
|
||||
# create stream client
|
||||
import argparse
|
||||
@@ -265,6 +280,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
await client.stop()
|
||||
|
||||
# wait for the processing to finish
|
||||
timeout = 20
|
||||
while True:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] == "ended":
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if resp.json()["status"] != "ended":
|
||||
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# stop websocket task
|
||||
@@ -306,7 +333,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
assert statuses == ["recording", "processing", "ended"]
|
||||
assert statuses.index("recording") < statuses.index("processing")
|
||||
assert statuses.index("processing") < statuses.index("ended")
|
||||
|
||||
# ensure the last event received is ended
|
||||
assert events[-1]["event"] == "STATUS"
|
||||
|
||||
@@ -3,6 +3,7 @@ import Modal from "../modal";
|
||||
import useTranscript from "../useTranscript";
|
||||
import useTopics from "../useTopics";
|
||||
import useWaveform from "../useWaveform";
|
||||
import useMp3 from "../useMp3";
|
||||
import { TopicList } from "../topicList";
|
||||
import Recorder from "../recorder";
|
||||
import { Topic } from "../webSocketTypes";
|
||||
@@ -28,6 +29,7 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
||||
const topics = useTopics(protectedPath, transcriptId);
|
||||
const waveform = useWaveform(protectedPath, transcriptId);
|
||||
const useActiveTopic = useState<Topic | null>(null);
|
||||
const mp3 = useMp3(api, transcriptId);
|
||||
|
||||
if (transcript?.error /** || topics?.error || waveform?.error **/) {
|
||||
return (
|
||||
@@ -62,6 +64,7 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
||||
waveform={waveform?.waveform}
|
||||
isPastMeeting={true}
|
||||
transcriptId={transcript?.response?.id}
|
||||
mp3Blob={mp3.blob}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -30,6 +30,7 @@ type RecorderProps = {
|
||||
waveform?: AudioWaveform | null;
|
||||
isPastMeeting: boolean;
|
||||
transcriptId?: string | null;
|
||||
mp3Blob?: Blob | null;
|
||||
};
|
||||
|
||||
export default function Recorder(props: RecorderProps) {
|
||||
@@ -108,11 +109,7 @@ export default function Recorder(props: RecorderProps) {
|
||||
if (waveformRef.current) {
|
||||
const _wavesurfer = WaveSurfer.create({
|
||||
container: waveformRef.current,
|
||||
url: props.transcriptId
|
||||
? `${process.env.NEXT_PUBLIC_API_URL}/v1/transcripts/${props.transcriptId}/audio/mp3`
|
||||
: undefined,
|
||||
peaks: props.waveform?.data,
|
||||
|
||||
hideScrollbar: true,
|
||||
autoCenter: true,
|
||||
barWidth: 2,
|
||||
@@ -146,6 +143,10 @@ export default function Recorder(props: RecorderProps) {
|
||||
|
||||
if (props.isPastMeeting) _wavesurfer.toggleInteraction(true);
|
||||
|
||||
if (props.mp3Blob) {
|
||||
_wavesurfer.loadBlob(props.mp3Blob);
|
||||
}
|
||||
|
||||
setWavesurfer(_wavesurfer);
|
||||
|
||||
return () => {
|
||||
@@ -157,6 +158,12 @@ export default function Recorder(props: RecorderProps) {
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!wavesurfer) return;
|
||||
if (!props.mp3Blob) return;
|
||||
wavesurfer.loadBlob(props.mp3Blob);
|
||||
}, [props.mp3Blob]);
|
||||
|
||||
useEffect(() => {
|
||||
topicsRef.current = props.topics;
|
||||
if (!isRecording) renderMarkers();
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
import { formatTime } from "../../lib/time";
|
||||
import ScrollToBottom from "./scrollToBottom";
|
||||
import { Topic } from "./webSocketTypes";
|
||||
import { generateHighContrastColor } from "../../lib/utils";
|
||||
|
||||
type TopicListProps = {
|
||||
topics: Topic[];
|
||||
@@ -103,7 +104,37 @@ export function TopicList({
|
||||
/>
|
||||
</div>
|
||||
{activeTopic?.id == topic.id && (
|
||||
<div className="p-2">{topic.transcript}</div>
|
||||
<div className="p-2">
|
||||
{topic.segments ? (
|
||||
<>
|
||||
{topic.segments.map((segment, index: number) => (
|
||||
<p
|
||||
key={index}
|
||||
className="text-left text-slate-500 text-sm md:text-base"
|
||||
>
|
||||
<span className="font-mono text-slate-500">
|
||||
[{formatTime(segment.start)}]
|
||||
</span>
|
||||
<span
|
||||
className="font-bold text-slate-500"
|
||||
style={{
|
||||
color: generateHighContrastColor(
|
||||
`Speaker ${segment.speaker}`,
|
||||
[96, 165, 250],
|
||||
),
|
||||
}}
|
||||
>
|
||||
{" "}
|
||||
(Speaker {segment.speaker}):
|
||||
</span>{" "}
|
||||
<span>{segment.text}</span>
|
||||
</p>
|
||||
))}
|
||||
</>
|
||||
) : (
|
||||
<>{topic.transcript}</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</button>
|
||||
))}
|
||||
|
||||
@@ -1,36 +1,64 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import { useContext, useEffect, useState } from "react";
|
||||
import {
|
||||
DefaultApi,
|
||||
V1TranscriptGetAudioMp3Request,
|
||||
// V1TranscriptGetAudioMp3Request,
|
||||
} from "../../api/apis/DefaultApi";
|
||||
import {} from "../../api";
|
||||
import { useError } from "../../(errors)/errorContext";
|
||||
import { DomainContext } from "../domainContext";
|
||||
|
||||
type Mp3Response = {
|
||||
url: string | null;
|
||||
blob: Blob | null;
|
||||
loading: boolean;
|
||||
error: Error | null;
|
||||
};
|
||||
|
||||
const useMp3 = (api: DefaultApi, id: string): Mp3Response => {
|
||||
const [url, setUrl] = useState<string | null>(null);
|
||||
const [blob, setBlob] = useState<Blob | null>(null);
|
||||
const [loading, setLoading] = useState<boolean>(false);
|
||||
const [error, setErrorState] = useState<Error | null>(null);
|
||||
const { setError } = useError();
|
||||
const { api_url } = useContext(DomainContext);
|
||||
|
||||
const getMp3 = (id: string) => {
|
||||
if (!id) throw new Error("Transcript ID is required to get transcript Mp3");
|
||||
if (!id) return;
|
||||
|
||||
setLoading(true);
|
||||
const requestParameters: V1TranscriptGetAudioMp3Request = {
|
||||
transcriptId: id,
|
||||
};
|
||||
api
|
||||
.v1TranscriptGetAudioMp3(requestParameters)
|
||||
.then((result) => {
|
||||
setUrl(result);
|
||||
setLoading(false);
|
||||
console.debug("Transcript Mp3 loaded:", result);
|
||||
// XXX the current API interface does not return a blob, so we need to fetch it manually
|
||||
// const requestParameters: V1TranscriptGetAudioMp3Request = {
|
||||
// transcriptId: id,
|
||||
// };
|
||||
// api
|
||||
// .v1TranscriptGetAudioMp3(requestParameters)
|
||||
// .then((result) => {
|
||||
// setUrl(result);
|
||||
// setLoading(false);
|
||||
// console.debug("Transcript Mp3 loaded:", result);
|
||||
// })
|
||||
// .catch((err) => {
|
||||
// setError(err);
|
||||
// setErrorState(err);
|
||||
// });
|
||||
const localUrl = `${api_url}/v1/transcripts/${id}/audio/mp3`;
|
||||
if (localUrl == url) return;
|
||||
const headers = new Headers();
|
||||
|
||||
if (api.configuration.configuration.accessToken) {
|
||||
headers.set("Authorization", api.configuration.configuration.accessToken);
|
||||
}
|
||||
|
||||
fetch(localUrl, {
|
||||
method: "GET",
|
||||
headers,
|
||||
})
|
||||
.then((response) => {
|
||||
setUrl(localUrl);
|
||||
response.blob().then((blob) => {
|
||||
setBlob(blob);
|
||||
setLoading(false);
|
||||
});
|
||||
})
|
||||
.catch((err) => {
|
||||
setError(err);
|
||||
@@ -42,7 +70,7 @@ const useMp3 = (api: DefaultApi, id: string): Mp3Response => {
|
||||
getMp3(id);
|
||||
}, [id]);
|
||||
|
||||
return { url, loading, error };
|
||||
return { url, blob, loading, error };
|
||||
};
|
||||
|
||||
export default useMp3;
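For orientation, here is one way a caller could consume this hook outside of the existing Recorder flow: keep the authenticated fetch inside the hook and only hand the resulting Blob to the UI. The component name, props, import paths, and the object-URL handling below are illustrative assumptions, not code from this change:

```tsx
// Hypothetical consumer of useMp3 (illustrative only; paths and names assumed).
// The hook performs the authenticated fetch; the component just renders the Blob.
import { useEffect, useState } from "react";
import { DefaultApi } from "../../api/apis/DefaultApi";
import useMp3 from "../useMp3";

export function TranscriptAudio({
  api,
  transcriptId,
}: {
  api: DefaultApi;
  transcriptId: string;
}) {
  const mp3 = useMp3(api, transcriptId);
  const [objectUrl, setObjectUrl] = useState<string>();

  useEffect(() => {
    if (!mp3.blob) return;
    const url = URL.createObjectURL(mp3.blob);
    setObjectUrl(url);
    // revoke the object URL when the blob changes or the component unmounts
    return () => URL.revokeObjectURL(url);
  }, [mp3.blob]);

  if (mp3.loading) return <p>Loading audio…</p>;
  if (mp3.error) return <p>Could not load audio.</p>;
  return <audio controls src={objectUrl} />;
}
```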
|
||||
|
||||
@@ -58,6 +58,39 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
title: "Topic 1: Introduction to Quantum Mechanics",
|
||||
transcript:
|
||||
"A brief overview of quantum mechanics and its principles.",
|
||||
segments: [
|
||||
{
|
||||
speaker: 1,
|
||||
start: 0,
|
||||
text: "This is the transcription of an example title",
|
||||
},
|
||||
{
|
||||
speaker: 2,
|
||||
start: 10,
|
||||
text: "This is the second speaker",
|
||||
},
|
||||
|
||||
{
|
||||
speaker: 3,
|
||||
start: 90,
|
||||
text: "This is the third speaker",
|
||||
},
|
||||
{
|
||||
speaker: 4,
|
||||
start: 90,
|
||||
text: "This is the fourth speaker",
|
||||
},
|
||||
{
|
||||
speaker: 5,
|
||||
start: 123,
|
||||
text: "This is the fifth speaker",
|
||||
},
|
||||
{
|
||||
speaker: 6,
|
||||
start: 300,
|
||||
text: "This is the sixth speaker",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "2",
|
||||
@@ -66,6 +99,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
title: "Topic 2: Machine Learning Algorithms",
|
||||
transcript:
|
||||
"Understanding the different types of machine learning algorithms.",
|
||||
segments: [
|
||||
{
|
||||
speaker: 1,
|
||||
start: 0,
|
||||
text: "This is the transcription of an example title",
|
||||
},
|
||||
{
|
||||
speaker: 2,
|
||||
start: 10,
|
||||
text: "This is the second speaker",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "3",
|
||||
@@ -73,6 +118,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
summary: "This is test topic 3",
|
||||
title: "Topic 3: Mental Health Awareness",
|
||||
transcript: "Ways to improve mental health and reduce stigma.",
|
||||
segments: [
|
||||
{
|
||||
speaker: 1,
|
||||
start: 0,
|
||||
text: "This is the transcription of an example title",
|
||||
},
|
||||
{
|
||||
speaker: 2,
|
||||
start: 10,
|
||||
text: "This is the second speaker",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "4",
|
||||
@@ -80,6 +137,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
summary: "This is test topic 4",
|
||||
title: "Topic 4: Basics of Productivity",
|
||||
transcript: "Tips and tricks to increase daily productivity.",
|
||||
segments: [
|
||||
{
|
||||
speaker: 1,
|
||||
start: 0,
|
||||
text: "This is the transcription of an example title",
|
||||
},
|
||||
{
|
||||
speaker: 2,
|
||||
start: 10,
|
||||
text: "This is the second speaker",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "5",
|
||||
@@ -88,6 +157,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
title: "Topic 5: Future of Aviation",
|
||||
transcript:
|
||||
"Exploring the advancements and possibilities in aviation.",
|
||||
segments: [
|
||||
{
|
||||
speaker: 1,
|
||||
start: 0,
|
||||
text: "This is the transcription of an example title",
|
||||
},
|
||||
{
|
||||
speaker: 2,
|
||||
start: 10,
|
||||
text: "This is the second speaker",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
@@ -106,6 +187,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
"Topic 1: Introduction to Quantum Mechanics, a brief overview of quantum mechanics and its principles.",
|
||||
transcript:
|
||||
"A brief overview of quantum mechanics and its principles.",
|
||||
segments: [
|
||||
{
|
||||
speaker: 1,
|
||||
start: 0,
|
||||
text: "This is the transcription of an example title",
|
||||
},
|
||||
{
|
||||
speaker: 2,
|
||||
start: 10,
|
||||
text: "This is the second speaker",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "2",
|
||||
@@ -115,6 +208,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
"Topic 2: Machine Learning Algorithms, understanding the different types of machine learning algorithms.",
|
||||
transcript:
|
||||
"Understanding the different types of machine learning algorithms.",
|
||||
segments: [
|
||||
{
|
||||
speaker: 1,
|
||||
start: 0,
|
||||
text: "This is the transcription of an example title",
|
||||
},
|
||||
{
|
||||
speaker: 2,
|
||||
start: 10,
|
||||
text: "This is the second speaker",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "3",
|
||||
@@ -123,6 +228,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
title:
|
||||
"Topic 3: Mental Health Awareness, ways to improve mental health and reduce stigma.",
|
||||
transcript: "Ways to improve mental health and reduce stigma.",
|
||||
segments: [
|
||||
{
|
||||
speaker: 1,
|
||||
start: 0,
|
||||
text: "This is the transcription of an example title",
|
||||
},
|
||||
{
|
||||
speaker: 2,
|
||||
start: 10,
|
||||
text: "This is the second speaker",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "4",
|
||||
@@ -131,6 +248,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
title:
|
||||
"Topic 4: Basics of Productivity, tips and tricks to increase daily productivity.",
|
||||
transcript: "Tips and tricks to increase daily productivity.",
|
||||
segments: [
|
||||
{
|
||||
speaker: 1,
|
||||
start: 0,
|
||||
text: "This is the transcription of an example title",
|
||||
},
|
||||
{
|
||||
speaker: 2,
|
||||
start: 10,
|
||||
text: "This is the second speaker",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "5",
|
||||
@@ -140,6 +269,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
"Topic 5: Future of Aviation, exploring the advancements and possibilities in aviation.",
|
||||
transcript:
|
||||
"Exploring the advancements and possibilities in aviation.",
|
||||
segments: [
|
||||
{
|
||||
speaker: 1,
|
||||
start: 0,
|
||||
text: "This is the transcription of an example title",
|
||||
},
|
||||
{
|
||||
speaker: 2,
|
||||
start: 10,
|
||||
text: "This is the second speaker",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
@@ -173,7 +314,17 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
break;
|
||||
|
||||
case "TOPIC":
|
||||
setTopics((prevTopics) => [...prevTopics, message.data]);
|
||||
setTopics((prevTopics) => {
|
||||
const topic = message.data as Topic;
|
||||
const index = prevTopics.findIndex(
|
||||
(prevTopic) => prevTopic.id === topic.id,
|
||||
);
|
||||
if (index >= 0) {
|
||||
prevTopics[index] = topic;
|
||||
return prevTopics;
|
||||
}
|
||||
return [...prevTopics, topic];
|
||||
});
|
||||
console.debug("TOPIC event:", message.data);
|
||||
break;
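One thing to note about this handler: it assigns into `prevTopics` and returns the same array reference when the topic already exists, which React may treat as an unchanged state value and skip re-rendering. A minimal sketch of an immutable variant, assuming the goal is simply update-or-append (the `upsertTopic` helper name is made up here):

```typescript
// Hypothetical immutable update-or-append helper for TOPIC events.
// Returning a fresh array guarantees React sees a new reference and re-renders.
const upsertTopic = (prevTopics: Topic[], topic: Topic): Topic[] => {
  const index = prevTopics.findIndex((prevTopic) => prevTopic.id === topic.id);
  if (index < 0) {
    return [...prevTopics, topic]; // new topic: append
  }
  const next = [...prevTopics];
  next[index] = topic; // existing topic: replace inside a copied array
  return next;
};

// usage inside the websocket handler:
// setTopics((prevTopics) => upsertTopic(prevTopics, message.data as Topic));
```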
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
export type Topic = {
|
||||
timestamp: number;
|
||||
title: string;
|
||||
transcript: string;
|
||||
summary: string;
|
||||
id: string;
|
||||
};
|
||||
import { GetTranscriptTopic } from "../../api";
|
||||
|
||||
export type Topic = GetTranscriptTopic;
|
||||
|
||||
export type Transcript = {
|
||||
text: string;
|
||||
|
||||
@@ -5,10 +5,11 @@ models/AudioWaveform.ts
|
||||
models/CreateTranscript.ts
|
||||
models/DeletionStatus.ts
|
||||
models/GetTranscript.ts
|
||||
models/GetTranscriptSegmentTopic.ts
|
||||
models/GetTranscriptTopic.ts
|
||||
models/HTTPValidationError.ts
|
||||
models/PageGetTranscript.ts
|
||||
models/RtcOffer.ts
|
||||
models/TranscriptTopic.ts
|
||||
models/UpdateTranscript.ts
|
||||
models/UserInfo.ts
|
||||
models/ValidationError.ts
|
||||
|
||||
@@ -42,10 +42,6 @@ import {
|
||||
UpdateTranscriptToJSON,
|
||||
} from "../models";
|
||||
|
||||
export interface RtcOfferRequest {
|
||||
rtcOffer: RtcOffer;
|
||||
}
|
||||
|
||||
export interface V1TranscriptDeleteRequest {
|
||||
transcriptId: any;
|
||||
}
|
||||
@@ -56,6 +52,7 @@ export interface V1TranscriptGetRequest {
|
||||
|
||||
export interface V1TranscriptGetAudioMp3Request {
|
||||
transcriptId: any;
|
||||
token?: any;
|
||||
}
|
||||
|
||||
export interface V1TranscriptGetAudioWaveformRequest {
|
||||
@@ -132,58 +129,6 @@ export class DefaultApi extends runtime.BaseAPI {
|
||||
return await response.value();
|
||||
}
|
||||
|
||||
/**
|
||||
* Rtc Offer
|
||||
*/
|
||||
async rtcOfferRaw(
|
||||
requestParameters: RtcOfferRequest,
|
||||
initOverrides?: RequestInit | runtime.InitOverrideFunction,
|
||||
): Promise<runtime.ApiResponse<any>> {
|
||||
if (
|
||||
requestParameters.rtcOffer === null ||
|
||||
requestParameters.rtcOffer === undefined
|
||||
) {
|
||||
throw new runtime.RequiredError(
|
||||
"rtcOffer",
|
||||
"Required parameter requestParameters.rtcOffer was null or undefined when calling rtcOffer.",
|
||||
);
|
||||
}
|
||||
|
||||
const queryParameters: any = {};
|
||||
|
||||
const headerParameters: runtime.HTTPHeaders = {};
|
||||
|
||||
headerParameters["Content-Type"] = "application/json";
|
||||
|
||||
const response = await this.request(
|
||||
{
|
||||
path: `/offer`,
|
||||
method: "POST",
|
||||
headers: headerParameters,
|
||||
query: queryParameters,
|
||||
body: RtcOfferToJSON(requestParameters.rtcOffer),
|
||||
},
|
||||
initOverrides,
|
||||
);
|
||||
|
||||
if (this.isJsonMime(response.headers.get("content-type"))) {
|
||||
return new runtime.JSONApiResponse<any>(response);
|
||||
} else {
|
||||
return new runtime.TextApiResponse(response) as any;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Rtc Offer
|
||||
*/
|
||||
async rtcOffer(
|
||||
requestParameters: RtcOfferRequest,
|
||||
initOverrides?: RequestInit | runtime.InitOverrideFunction,
|
||||
): Promise<any> {
|
||||
const response = await this.rtcOfferRaw(requestParameters, initOverrides);
|
||||
return await response.value();
|
||||
}
|
||||
|
||||
/**
|
||||
* Transcript Delete
|
||||
*/
|
||||
@@ -325,6 +270,10 @@ export class DefaultApi extends runtime.BaseAPI {
|
||||
|
||||
const queryParameters: any = {};
|
||||
|
||||
if (requestParameters.token !== undefined) {
|
||||
queryParameters["token"] = requestParameters.token;
|
||||
}
|
||||
|
||||
const headerParameters: runtime.HTTPHeaders = {};
|
||||
|
||||
if (this.configuration && this.configuration.accessToken) {
|
||||
|
||||
88  www/app/api/models/GetTranscriptSegmentTopic.ts  Normal file
@@ -0,0 +1,88 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
/**
|
||||
* FastAPI
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 0.1.0
|
||||
*
|
||||
*
|
||||
* NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
|
||||
* https://openapi-generator.tech
|
||||
* Do not edit the class manually.
|
||||
*/
|
||||
|
||||
import { exists, mapValues } from "../runtime";
|
||||
/**
|
||||
*
|
||||
* @export
|
||||
* @interface GetTranscriptSegmentTopic
|
||||
*/
|
||||
export interface GetTranscriptSegmentTopic {
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof GetTranscriptSegmentTopic
|
||||
*/
|
||||
text: any | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof GetTranscriptSegmentTopic
|
||||
*/
|
||||
start: any | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof GetTranscriptSegmentTopic
|
||||
*/
|
||||
speaker: any | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a given object implements the GetTranscriptSegmentTopic interface.
|
||||
*/
|
||||
export function instanceOfGetTranscriptSegmentTopic(value: object): boolean {
|
||||
let isInstance = true;
|
||||
isInstance = isInstance && "text" in value;
|
||||
isInstance = isInstance && "start" in value;
|
||||
isInstance = isInstance && "speaker" in value;
|
||||
|
||||
return isInstance;
|
||||
}
|
||||
|
||||
export function GetTranscriptSegmentTopicFromJSON(
|
||||
json: any,
|
||||
): GetTranscriptSegmentTopic {
|
||||
return GetTranscriptSegmentTopicFromJSONTyped(json, false);
|
||||
}
|
||||
|
||||
export function GetTranscriptSegmentTopicFromJSONTyped(
|
||||
json: any,
|
||||
ignoreDiscriminator: boolean,
|
||||
): GetTranscriptSegmentTopic {
|
||||
if (json === undefined || json === null) {
|
||||
return json;
|
||||
}
|
||||
return {
|
||||
text: json["text"],
|
||||
start: json["start"],
|
||||
speaker: json["speaker"],
|
||||
};
|
||||
}
|
||||
|
||||
export function GetTranscriptSegmentTopicToJSON(
|
||||
value?: GetTranscriptSegmentTopic | null,
|
||||
): any {
|
||||
if (value === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
if (value === null) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
text: value.text,
|
||||
start: value.start,
|
||||
speaker: value.speaker,
|
||||
};
|
||||
}
|
||||
112  www/app/api/models/GetTranscriptTopic.ts  Normal file
@@ -0,0 +1,112 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
/**
|
||||
* FastAPI
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 0.1.0
|
||||
*
|
||||
*
|
||||
* NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
|
||||
* https://openapi-generator.tech
|
||||
* Do not edit the class manually.
|
||||
*/
|
||||
|
||||
import { exists, mapValues } from "../runtime";
|
||||
/**
|
||||
*
|
||||
* @export
|
||||
* @interface GetTranscriptTopic
|
||||
*/
|
||||
export interface GetTranscriptTopic {
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof GetTranscriptTopic
|
||||
*/
|
||||
id: any | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof GetTranscriptTopic
|
||||
*/
|
||||
title: any | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof GetTranscriptTopic
|
||||
*/
|
||||
summary: any | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof GetTranscriptTopic
|
||||
*/
|
||||
timestamp: any | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof GetTranscriptTopic
|
||||
*/
|
||||
transcript: any | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof GetTranscriptTopic
|
||||
*/
|
||||
segments?: any | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a given object implements the GetTranscriptTopic interface.
|
||||
*/
|
||||
export function instanceOfGetTranscriptTopic(value: object): boolean {
|
||||
let isInstance = true;
|
||||
isInstance = isInstance && "id" in value;
|
||||
isInstance = isInstance && "title" in value;
|
||||
isInstance = isInstance && "summary" in value;
|
||||
isInstance = isInstance && "timestamp" in value;
|
||||
isInstance = isInstance && "transcript" in value;
|
||||
|
||||
return isInstance;
|
||||
}
|
||||
|
||||
export function GetTranscriptTopicFromJSON(json: any): GetTranscriptTopic {
|
||||
return GetTranscriptTopicFromJSONTyped(json, false);
|
||||
}
|
||||
|
||||
export function GetTranscriptTopicFromJSONTyped(
|
||||
json: any,
|
||||
ignoreDiscriminator: boolean,
|
||||
): GetTranscriptTopic {
|
||||
if (json === undefined || json === null) {
|
||||
return json;
|
||||
}
|
||||
return {
|
||||
id: json["id"],
|
||||
title: json["title"],
|
||||
summary: json["summary"],
|
||||
timestamp: json["timestamp"],
|
||||
transcript: json["transcript"],
|
||||
segments: !exists(json, "segments") ? undefined : json["segments"],
|
||||
};
|
||||
}
|
||||
|
||||
export function GetTranscriptTopicToJSON(
|
||||
value?: GetTranscriptTopic | null,
|
||||
): any {
|
||||
if (value === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
if (value === null) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
id: value.id,
|
||||
title: value.title,
|
||||
summary: value.summary,
|
||||
timestamp: value.timestamp,
|
||||
transcript: value.transcript,
|
||||
segments: value.segments,
|
||||
};
|
||||
}
|
||||
88  www/app/api/models/TranscriptSegmentTopic.ts  Normal file
@@ -0,0 +1,88 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
/**
|
||||
* FastAPI
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 0.1.0
|
||||
*
|
||||
*
|
||||
* NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
|
||||
* https://openapi-generator.tech
|
||||
* Do not edit the class manually.
|
||||
*/
|
||||
|
||||
import { exists, mapValues } from "../runtime";
|
||||
/**
|
||||
*
|
||||
* @export
|
||||
* @interface TranscriptSegmentTopic
|
||||
*/
|
||||
export interface TranscriptSegmentTopic {
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof TranscriptSegmentTopic
|
||||
*/
|
||||
speaker: any | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof TranscriptSegmentTopic
|
||||
*/
|
||||
text: any | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof TranscriptSegmentTopic
|
||||
*/
|
||||
timestamp: any | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a given object implements the TranscriptSegmentTopic interface.
|
||||
*/
|
||||
export function instanceOfTranscriptSegmentTopic(value: object): boolean {
|
||||
let isInstance = true;
|
||||
isInstance = isInstance && "speaker" in value;
|
||||
isInstance = isInstance && "text" in value;
|
||||
isInstance = isInstance && "timestamp" in value;
|
||||
|
||||
return isInstance;
|
||||
}
|
||||
|
||||
export function TranscriptSegmentTopicFromJSON(
|
||||
json: any,
|
||||
): TranscriptSegmentTopic {
|
||||
return TranscriptSegmentTopicFromJSONTyped(json, false);
|
||||
}
|
||||
|
||||
export function TranscriptSegmentTopicFromJSONTyped(
|
||||
json: any,
|
||||
ignoreDiscriminator: boolean,
|
||||
): TranscriptSegmentTopic {
|
||||
if (json === undefined || json === null) {
|
||||
return json;
|
||||
}
|
||||
return {
|
||||
speaker: json["speaker"],
|
||||
text: json["text"],
|
||||
timestamp: json["timestamp"],
|
||||
};
|
||||
}
|
||||
|
||||
export function TranscriptSegmentTopicToJSON(
|
||||
value?: TranscriptSegmentTopic | null,
|
||||
): any {
|
||||
if (value === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
if (value === null) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
speaker: value.speaker,
|
||||
text: value.text,
|
||||
timestamp: value.timestamp,
|
||||
};
|
||||
}
|
||||
@@ -42,13 +42,13 @@ export interface TranscriptTopic {
|
||||
* @type {any}
|
||||
* @memberof TranscriptTopic
|
||||
*/
|
||||
transcript?: any | null;
|
||||
timestamp: any | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof TranscriptTopic
|
||||
*/
|
||||
timestamp: any | null;
|
||||
segments?: any | null;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -78,8 +78,8 @@ export function TranscriptTopicFromJSONTyped(
|
||||
id: !exists(json, "id") ? undefined : json["id"],
|
||||
title: json["title"],
|
||||
summary: json["summary"],
|
||||
transcript: !exists(json, "transcript") ? undefined : json["transcript"],
|
||||
timestamp: json["timestamp"],
|
||||
segments: !exists(json, "segments") ? undefined : json["segments"],
|
||||
};
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ export function TranscriptTopicToJSON(value?: TranscriptTopic | null): any {
|
||||
id: value.id,
|
||||
title: value.title,
|
||||
summary: value.summary,
|
||||
transcript: value.transcript,
|
||||
timestamp: value.timestamp,
|
||||
segments: value.segments,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -4,10 +4,11 @@ export * from "./AudioWaveform";
|
||||
export * from "./CreateTranscript";
|
||||
export * from "./DeletionStatus";
|
||||
export * from "./GetTranscript";
|
||||
export * from "./GetTranscriptSegmentTopic";
|
||||
export * from "./GetTranscriptTopic";
|
||||
export * from "./HTTPValidationError";
|
||||
export * from "./PageGetTranscript";
|
||||
export * from "./RtcOffer";
|
||||
export * from "./TranscriptTopic";
|
||||
export * from "./UpdateTranscript";
|
||||
export * from "./UserInfo";
|
||||
export * from "./ValidationError";
|
||||
|
||||
@@ -1,3 +1,123 @@
|
||||
export function isDevelopment() {
|
||||
return process.env.NEXT_PUBLIC_ENV === "development";
|
||||
}
|
||||
|
||||
// Function to calculate WCAG contrast ratio
|
||||
export const getContrastRatio = (
|
||||
foreground: [number, number, number],
|
||||
background: [number, number, number],
|
||||
) => {
|
||||
const [r1, g1, b1] = foreground;
|
||||
const [r2, g2, b2] = background;
|
||||
|
||||
const lum1 =
|
||||
0.2126 * Math.pow(r1 / 255, 2.2) +
|
||||
0.7152 * Math.pow(g1 / 255, 2.2) +
|
||||
0.0722 * Math.pow(b1 / 255, 2.2);
|
||||
const lum2 =
|
||||
0.2126 * Math.pow(r2 / 255, 2.2) +
|
||||
0.7152 * Math.pow(g2 / 255, 2.2) +
|
||||
0.0722 * Math.pow(b2 / 255, 2.2);
|
||||
|
||||
return (Math.max(lum1, lum2) + 0.05) / (Math.min(lum1, lum2) + 0.05);
|
||||
};
|
||||
|
||||
// Function to hash string into 32-bit integer
|
||||
// 🔴 DO NOT USE FOR CRYPTOGRAPHY PURPOSES 🔴
|
||||
|
||||
export function murmurhash3_32_gc(key: string, seed: number = 0) {
|
||||
let remainder, bytes, h1, h1b, c1, c2, k1, i;
|
||||
|
||||
remainder = key.length & 3; // key.length % 4
|
||||
bytes = key.length - remainder;
|
||||
h1 = seed;
|
||||
c1 = 0xcc9e2d51;
|
||||
c2 = 0x1b873593;
|
||||
i = 0;
|
||||
|
||||
while (i < bytes) {
|
||||
k1 =
|
||||
(key.charCodeAt(i) & 0xff) |
|
||||
((key.charCodeAt(++i) & 0xff) << 8) |
|
||||
((key.charCodeAt(++i) & 0xff) << 16) |
|
||||
((key.charCodeAt(++i) & 0xff) << 24);
|
||||
|
||||
++i;
|
||||
|
||||
k1 =
|
||||
((k1 & 0xffff) * c1 + ((((k1 >>> 16) * c1) & 0xffff) << 16)) & 0xffffffff;
|
||||
k1 = (k1 << 15) | (k1 >>> 17);
|
||||
k1 =
|
||||
((k1 & 0xffff) * c2 + ((((k1 >>> 16) * c2) & 0xffff) << 16)) & 0xffffffff;
|
||||
|
||||
h1 ^= k1;
|
||||
h1 = (h1 << 13) | (h1 >>> 19);
|
||||
h1b =
|
||||
((h1 & 0xffff) * 5 + ((((h1 >>> 16) * 5) & 0xffff) << 16)) & 0xffffffff;
|
||||
h1 = (h1b & 0xffff) + 0x6b64 + ((((h1b >>> 16) + 0xe654) & 0xffff) << 16);
|
||||
}
|
||||
|
||||
k1 = 0;
|
||||
|
||||
switch (remainder) {
|
||||
case 3:
|
||||
k1 ^= (key.charCodeAt(i + 2) & 0xff) << 16;
|
||||
case 2:
|
||||
k1 ^= (key.charCodeAt(i + 1) & 0xff) << 8;
|
||||
case 1:
|
||||
k1 ^= key.charCodeAt(i) & 0xff;
|
||||
|
||||
k1 =
|
||||
((k1 & 0xffff) * c1 + ((((k1 >>> 16) * c1) & 0xffff) << 16)) &
|
||||
0xffffffff;
|
||||
k1 = (k1 << 15) | (k1 >>> 17);
|
||||
k1 =
|
||||
((k1 & 0xffff) * c2 + ((((k1 >>> 16) * c2) & 0xffff) << 16)) &
|
||||
0xffffffff;
|
||||
h1 ^= k1;
|
||||
}
|
||||
|
||||
h1 ^= key.length;
|
||||
|
||||
h1 ^= h1 >>> 16;
|
||||
h1 =
|
||||
((h1 & 0xffff) * 0x85ebca6b +
|
||||
((((h1 >>> 16) * 0x85ebca6b) & 0xffff) << 16)) &
|
||||
0xffffffff;
|
||||
h1 ^= h1 >>> 13;
|
||||
h1 =
|
||||
((h1 & 0xffff) * 0xc2b2ae35 +
|
||||
((((h1 >>> 16) * 0xc2b2ae35) & 0xffff) << 16)) &
|
||||
0xffffffff;
|
||||
h1 ^= h1 >>> 16;
|
||||
|
||||
return h1 >>> 0;
|
||||
}
|
||||
|
||||
// Generates a color that is guaranteed to have high contrast with the given background color (optional)
|
||||
|
||||
export const generateHighContrastColor = (
|
||||
name: string,
|
||||
backgroundColor: [number, number, number] | null = null,
|
||||
) => {
|
||||
const hash = murmurhash3_32_gc(name);
|
||||
let red = (hash & 0xff0000) >> 16;
|
||||
let green = (hash & 0x00ff00) >> 8;
|
||||
let blue = hash & 0x0000ff;
|
||||
|
||||
const getCssColor = (red: number, green: number, blue: number) =>
|
||||
`rgb(${red}, ${green}, ${blue})`;
|
||||
|
||||
if (!backgroundColor) return getCssColor(red, green, blue);
|
||||
|
||||
const contrast = getContrastRatio([red, green, blue], backgroundColor);
|
||||
|
||||
// Adjust the color to achieve better contrast if necessary (WCAG recommends at least 4.5:1 for text)
|
||||
if (contrast < 4.5) {
|
||||
red = Math.abs(255 - red);
|
||||
green = Math.abs(255 - green);
|
||||
blue = Math.abs(255 - blue);
|
||||
}
|
||||
|
||||
return getCssColor(red, green, blue);
|
||||
};
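A brief usage sketch for these helpers, matching how the topic list calls them; the `[96, 165, 250]` background RGB is the value used by the rendering code above, while the speaker label and everything else is illustrative:

```typescript
// Illustrative use of the color helpers (only the background RGB comes from the app code).
import { generateHighContrastColor, getContrastRatio } from "../../lib/utils";

const background: [number, number, number] = [96, 165, 250]; // background RGB used by the topic list
const speakerColor = generateHighContrastColor("Speaker 1", background);
// the same speaker label always hashes to the same color, so colors stay stable across renders

// getContrastRatio can also be used on its own to check a pair of RGB colors
const ratio = getContrastRatio([255, 255, 255], background);
console.log(speakerColor, ratio >= 4.5 ? "readable" : "low contrast");
```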
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
"supports-color": "^9.4.0",
|
||||
"tailwindcss": "^3.3.2",
|
||||
"typescript": "^5.1.6",
|
||||
"wavesurfer.js": "^7.0.3"
|
||||
"wavesurfer.js": "^7.4.2"
|
||||
},
|
||||
"main": "index.js",
|
||||
"repository": "https://github.com/Monadical-SAS/reflector-ui.git",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { NextPage } from "next";
|
||||
|
||||
const Forbidden: NextPage = () => {
|
||||
return <h2>Sorry, you are not authorized to access this page.</h2>;
|
||||
return <h2>Sorry, you are not authorized to access this page</h2>;
|
||||
};
|
||||
|
||||
export default Forbidden;
|
||||
|
||||
@@ -2638,10 +2638,10 @@ watchpack@2.4.0:
|
||||
glob-to-regexp "^0.4.1"
|
||||
graceful-fs "^4.1.2"
|
||||
|
||||
wavesurfer.js@^7.0.3:
|
||||
version "7.0.3"
|
||||
resolved "https://registry.npmjs.org/wavesurfer.js/-/wavesurfer.js-7.0.3.tgz"
|
||||
integrity sha512-gJ3P+Bd3Q4E8qETjjg0pneaVqm2J7jegG2Cc6vqEF5YDDKQ3m8sKsvVfgVhJkacKkO9jFAGDu58Hw4zLr7xD0A==
|
||||
wavesurfer.js@^7.4.2:
|
||||
version "7.4.2"
|
||||
resolved "https://registry.yarnpkg.com/wavesurfer.js/-/wavesurfer.js-7.4.2.tgz#59f5c87193d4eeeb199858688ddac1ad7ba86b3a"
|
||||
integrity sha512-4pNQ1porOCUBYBmd2F1TqVuBnB2wBPipaw2qI920zYLuPnada0Rd1CURgh8HRuPGKxijj2iyZDFN2UZwsaEuhA==
|
||||
|
||||
wcwidth@>=1.0.1, wcwidth@^1.0.1:
|
||||
version "1.0.1"
|
||||
|
||||