Merge branch 'fix-mp3-download-while-authenticated' into sara/fix-api-auth

This commit is contained in:
Sara
2023-11-03 12:10:48 +01:00
54 changed files with 2756 additions and 763 deletions

View File

@@ -11,6 +11,11 @@ on:
jobs:
pytest:
runs-on: ubuntu-latest
services:
redis:
image: redis:6
ports:
- 6379:6379
steps:
- uses: actions/checkout@v3
- name: Install poetry

View File

@@ -6,7 +6,7 @@ The project architecture consists of three primary components:
* **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
* **Back-End**: Python server that offers an API and data persistence, found in `server/`.
* **AI Models**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
* **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
It also uses https://github.com/fief-dev for authentication, and Vercel for deployment and configuration of the front-end.
@@ -120,6 +120,9 @@ TRANSCRIPT_MODAL_API_KEY=<omitted>
LLM_BACKEND=modal
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
LLM_MODAL_API_KEY=<omitted>
TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
AUTH_BACKEND=fief
AUTH_FIEF_URL=https://auth.reflector.media/reflector-local
@@ -138,6 +141,10 @@ Use:
poetry run python3 -m reflector.app
```
And start the background worker
celery -A reflector.worker.app worker --loglevel=info
#### Using docker
Use:
@@ -161,4 +168,5 @@ poetry run python -m reflector.tools.process path/to/audio.wav
## AI Models
*(Documentation for this section is pending.)*
*(Documentation for this section is pending.)*

View File

@@ -5,10 +5,19 @@ services:
context: server
ports:
- 1250:1250
environment:
LLM_URL: "${LLM_URL}"
volumes:
- model-cache:/root/.cache
environment: ENTRYPOINT=server
worker:
build:
context: server
volumes:
- model-cache:/root/.cache
environment: ENTRYPOINT=worker
redis:
image: redis:7.2
ports:
- 6379:6379
web:
build:
context: www
@@ -17,4 +26,3 @@ services:
volumes:
model-cache:

View File

@@ -5,11 +5,23 @@ services:
context: .
ports:
- 1250:1250
environment:
LLM_URL: "${LLM_URL}"
MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
volumes:
- model-cache:/root/.cache
environment:
ENTRYPOINT: server
REDIS_HOST: redis
worker:
build:
context: .
volumes:
- model-cache:/root/.cache
environment:
ENTRYPOINT: worker
REDIS_HOST: redis
redis:
image: redis:7.2
ports:
- 6379:6379
volumes:
model-cache:

View File

@@ -0,0 +1,80 @@
"""rename back text to transcript
Revision ID: 38a927dcb099
Revises: 9920ecfe2735
Create Date: 2023-11-02 19:53:09.116240
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column
from sqlalchemy import select
# revision identifiers, used by Alembic.
revision: str = '38a927dcb099'
down_revision: Union[str, None] = '9920ecfe2735'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# bind the engine
bind = op.get_bind()
# Reflect the table
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]
topics_json = row["topics"]
# Process each topic in the topics JSON array
updated_topics = []
for topic in topics_json:
if "text" in topic:
# Rename key 'text' back to 'transcript'
topic["transcript"] = topic.pop("text")
updated_topics.append(topic)
# Update the transcript table
bind.execute(
transcript.update()
.where(transcript.c.id == transcript_id)
.values(topics=updated_topics)
)
def downgrade() -> None:
# bind the engine
bind = op.get_bind()
# Reflect the table
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]
topics_json = row["topics"]
# Process each topic in the topics JSON array
updated_topics = []
for topic in topics_json:
if "transcript" in topic:
# Rename key 'transcript' to 'text'
topic["text"] = topic.pop("transcript")
updated_topics.append(topic)
# Update the transcript table
bind.execute(
transcript.update()
.where(transcript.c.id == transcript_id)
.values(topics=updated_topics)
)

View File

@@ -0,0 +1,80 @@
"""Migration transcript to text field in transcripts table
Revision ID: 9920ecfe2735
Revises: 99365b0cd87b
Create Date: 2023-11-02 18:55:17.019498
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column
from sqlalchemy import select
# revision identifiers, used by Alembic.
revision: str = "9920ecfe2735"
down_revision: Union[str, None] = "99365b0cd87b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# bind the engine
bind = op.get_bind()
# Reflect the table
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]
topics_json = row["topics"]
# Process each topic in the topics JSON array
updated_topics = []
for topic in topics_json:
if "transcript" in topic:
# Rename key 'transcript' to 'text'
topic["text"] = topic.pop("transcript")
updated_topics.append(topic)
# Update the transcript table
bind.execute(
transcript.update()
.where(transcript.c.id == transcript_id)
.values(topics=updated_topics)
)
def downgrade() -> None:
# bind the engine
bind = op.get_bind()
# Reflect the table
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]
topics_json = row["topics"]
# Process each topic in the topics JSON array
updated_topics = []
for topic in topics_json:
if "text" in topic:
# Rename key 'text' back to 'transcript'
topic["transcript"] = topic.pop("text")
updated_topics.append(topic)
# Update the transcript table
bind.execute(
transcript.update()
.where(transcript.c.id == transcript_id)
.values(topics=updated_topics)
)

307
server/poetry.lock generated
View File

@@ -308,6 +308,20 @@ typing-extensions = ">=4"
[package.extras]
tz = ["python-dateutil"]
[[package]]
name = "amqp"
version = "5.1.1"
description = "Low-level AMQP client for Python (fork of amqplib)."
optional = false
python-versions = ">=3.6"
files = [
{file = "amqp-5.1.1-py3-none-any.whl", hash = "sha256:6f0956d2c23d8fa6e7691934d8c3930eadb44972cbbd1a7ae3a520f735d43359"},
{file = "amqp-5.1.1.tar.gz", hash = "sha256:2c1b13fecc0893e946c65cbd5f36427861cffa4ea2201d8f6fca22e2a373b5e2"},
]
[package.dependencies]
vine = ">=5.0.0"
[[package]]
name = "annotated-types"
version = "0.6.0"
@@ -474,6 +488,17 @@ files = [
{file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"},
]
[[package]]
name = "billiard"
version = "4.1.0"
description = "Python multiprocessing fork with improvements and bugfixes"
optional = false
python-versions = ">=3.7"
files = [
{file = "billiard-4.1.0-py3-none-any.whl", hash = "sha256:0f50d6be051c6b2b75bfbc8bfd85af195c5739c281d3f5b86a5640c65563614a"},
{file = "billiard-4.1.0.tar.gz", hash = "sha256:1ad2eeae8e28053d729ba3373d34d9d6e210f6e4d8bf0a9c64f92bd053f1edf5"},
]
[[package]]
name = "black"
version = "23.9.1"
@@ -556,6 +581,61 @@ urllib3 = ">=1.25.4,<1.27"
[package.extras]
crt = ["awscrt (==0.16.26)"]
[[package]]
name = "celery"
version = "5.3.4"
description = "Distributed Task Queue."
optional = false
python-versions = ">=3.8"
files = [
{file = "celery-5.3.4-py3-none-any.whl", hash = "sha256:1e6ed40af72695464ce98ca2c201ad0ef8fd192246f6c9eac8bba343b980ad34"},
{file = "celery-5.3.4.tar.gz", hash = "sha256:9023df6a8962da79eb30c0c84d5f4863d9793a466354cc931d7f72423996de28"},
]
[package.dependencies]
billiard = ">=4.1.0,<5.0"
click = ">=8.1.2,<9.0"
click-didyoumean = ">=0.3.0"
click-plugins = ">=1.1.1"
click-repl = ">=0.2.0"
kombu = ">=5.3.2,<6.0"
python-dateutil = ">=2.8.2"
tzdata = ">=2022.7"
vine = ">=5.0.0,<6.0"
[package.extras]
arangodb = ["pyArango (>=2.0.2)"]
auth = ["cryptography (==41.0.3)"]
azureblockblob = ["azure-storage-blob (>=12.15.0)"]
brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"]
cassandra = ["cassandra-driver (>=3.25.0,<4)"]
consul = ["python-consul2 (==0.1.5)"]
cosmosdbsql = ["pydocumentdb (==2.3.5)"]
couchbase = ["couchbase (>=3.0.0)"]
couchdb = ["pycouchdb (==1.14.2)"]
django = ["Django (>=2.2.28)"]
dynamodb = ["boto3 (>=1.26.143)"]
elasticsearch = ["elasticsearch (<8.0)"]
eventlet = ["eventlet (>=0.32.0)"]
gevent = ["gevent (>=1.5.0)"]
librabbitmq = ["librabbitmq (>=2.0.0)"]
memcache = ["pylibmc (==1.6.3)"]
mongodb = ["pymongo[srv] (>=4.0.2)"]
msgpack = ["msgpack (==1.0.5)"]
pymemcache = ["python-memcached (==1.59)"]
pyro = ["pyro4 (==4.82)"]
pytest = ["pytest-celery (==0.0.0)"]
redis = ["redis (>=4.5.2,!=4.5.5,<5.0.0)"]
s3 = ["boto3 (>=1.26.143)"]
slmq = ["softlayer-messaging (>=1.0.3)"]
solar = ["ephem (==4.1.4)"]
sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"]
sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.0)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"]
tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"]
yaml = ["PyYAML (>=3.10)"]
zookeeper = ["kazoo (>=1.3.1)"]
zstd = ["zstandard (==0.21.0)"]
[[package]]
name = "certifi"
version = "2023.7.22"
@@ -744,6 +824,55 @@ files = [
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "click-didyoumean"
version = "0.3.0"
description = "Enables git-like *did-you-mean* feature in click"
optional = false
python-versions = ">=3.6.2,<4.0.0"
files = [
{file = "click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"},
{file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"},
]
[package.dependencies]
click = ">=7"
[[package]]
name = "click-plugins"
version = "1.1.1"
description = "An extension module for click to enable registering CLI commands via setuptools entry-points."
optional = false
python-versions = "*"
files = [
{file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"},
{file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"},
]
[package.dependencies]
click = ">=4.0"
[package.extras]
dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"]
[[package]]
name = "click-repl"
version = "0.3.0"
description = "REPL plugin for Click"
optional = false
python-versions = ">=3.6"
files = [
{file = "click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9"},
{file = "click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812"},
]
[package.dependencies]
click = ">=7.0"
prompt-toolkit = ">=3.0.36"
[package.extras]
testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"]
[[package]]
name = "colorama"
version = "0.4.6"
@@ -981,6 +1110,24 @@ idna = ["idna (>=2.1,<4.0)"]
trio = ["trio (>=0.14,<0.23)"]
wmi = ["wmi (>=1.5.1,<2.0.0)"]
[[package]]
name = "ecdsa"
version = "0.18.0"
description = "ECDSA cryptographic signature library (pure python)"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
files = [
{file = "ecdsa-0.18.0-py2.py3-none-any.whl", hash = "sha256:80600258e7ed2f16b9aa1d7c295bd70194109ad5a30fdee0eaeefef1d4c559dd"},
{file = "ecdsa-0.18.0.tar.gz", hash = "sha256:190348041559e21b22a1d65cee485282ca11a6f81d503fddb84d5017e9ed1e49"},
]
[package.dependencies]
six = ">=1.9.0"
[package.extras]
gmpy = ["gmpy"]
gmpy2 = ["gmpy2"]
[[package]]
name = "fastapi"
version = "0.100.1"
@@ -1624,6 +1771,38 @@ files = [
cryptography = ">=3.4"
deprecated = "*"
[[package]]
name = "kombu"
version = "5.3.2"
description = "Messaging library for Python."
optional = false
python-versions = ">=3.8"
files = [
{file = "kombu-5.3.2-py3-none-any.whl", hash = "sha256:b753c9cfc9b1e976e637a7cbc1a65d446a22e45546cd996ea28f932082b7dc9e"},
{file = "kombu-5.3.2.tar.gz", hash = "sha256:0ba213f630a2cb2772728aef56ac6883dc3a2f13435e10048f6e97d48506dbbd"},
]
[package.dependencies]
amqp = ">=5.1.1,<6.0.0"
vine = "*"
[package.extras]
azureservicebus = ["azure-servicebus (>=7.10.0)"]
azurestoragequeues = ["azure-identity (>=1.12.0)", "azure-storage-queue (>=12.6.0)"]
confluentkafka = ["confluent-kafka (==2.1.1)"]
consul = ["python-consul2"]
librabbitmq = ["librabbitmq (>=2.0.0)"]
mongodb = ["pymongo (>=4.1.1)"]
msgpack = ["msgpack"]
pyro = ["pyro4"]
qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"]
redis = ["redis (>=4.5.2)"]
slmq = ["softlayer-messaging (>=1.0.3)"]
sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"]
sqs = ["boto3 (>=1.26.143)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"]
yaml = ["PyYAML (>=3.10)"]
zookeeper = ["kazoo (>=2.8.0)"]
[[package]]
name = "levenshtein"
version = "0.21.1"
@@ -2151,6 +2330,20 @@ files = [
fastapi = ">=0.38.1,<1.0.0"
prometheus-client = ">=0.8.0,<1.0.0"
[[package]]
name = "prompt-toolkit"
version = "3.0.39"
description = "Library for building powerful interactive command lines in Python"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"},
{file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"},
]
[package.dependencies]
wcwidth = "*"
[[package]]
name = "protobuf"
version = "4.24.4"
@@ -2173,6 +2366,17 @@ files = [
{file = "protobuf-4.24.4.tar.gz", hash = "sha256:5a70731910cd9104762161719c3d883c960151eea077134458503723b60e3667"},
]
[[package]]
name = "pyasn1"
version = "0.5.0"
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
files = [
{file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"},
{file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"},
]
[[package]]
name = "pycparser"
version = "2.21"
@@ -2501,6 +2705,20 @@ pytest = ">=7.0.0"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
[[package]]
name = "pytest-celery"
version = "0.0.0"
description = "pytest-celery a shim pytest plugin to enable celery.contrib.pytest"
optional = false
python-versions = "*"
files = [
{file = "pytest-celery-0.0.0.tar.gz", hash = "sha256:cfd060fc32676afa1e4f51b2938f903f7f75d952186b8c6cf631628c4088f406"},
{file = "pytest_celery-0.0.0-py2.py3-none-any.whl", hash = "sha256:63dec132df3a839226ecb003ffdbb0c2cb88dd328550957e979c942766578060"},
]
[package.dependencies]
celery = ">=4.4.0"
[[package]]
name = "pytest-cov"
version = "4.1.0"
@@ -2565,6 +2783,28 @@ files = [
[package.extras]
cli = ["click (>=5.0)"]
[[package]]
name = "python-jose"
version = "3.3.0"
description = "JOSE implementation in Python"
optional = false
python-versions = "*"
files = [
{file = "python-jose-3.3.0.tar.gz", hash = "sha256:55779b5e6ad599c6336191246e95eb2293a9ddebd555f796a65f838f07e5d78a"},
{file = "python_jose-3.3.0-py2.py3-none-any.whl", hash = "sha256:9b1376b023f8b298536eedd47ae1089bcdb848f1535ab30555cd92002d78923a"},
]
[package.dependencies]
cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"cryptography\""}
ecdsa = "!=0.15"
pyasn1 = "*"
rsa = "*"
[package.extras]
cryptography = ["cryptography (>=3.4.0)"]
pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"]
pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"]
[[package]]
name = "pyyaml"
version = "6.0.1"
@@ -2744,6 +2984,24 @@ files = [
[package.extras]
full = ["numpy"]
[[package]]
name = "redis"
version = "5.0.1"
description = "Python client for Redis database and key-value store"
optional = false
python-versions = ">=3.7"
files = [
{file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"},
{file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"},
]
[package.dependencies]
async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""}
[package.extras]
hiredis = ["hiredis (>=1.0.0)"]
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
[[package]]
name = "regex"
version = "2023.10.3"
@@ -2862,6 +3120,20 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rsa"
version = "4.9"
description = "Pure-Python RSA implementation"
optional = false
python-versions = ">=3.6,<4"
files = [
{file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
{file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
]
[package.dependencies]
pyasn1 = ">=0.1.3"
[[package]]
name = "s3transfer"
version = "0.6.2"
@@ -3438,6 +3710,17 @@ files = [
{file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"},
]
[[package]]
name = "tzdata"
version = "2023.3"
description = "Provider of IANA time zone data"
optional = false
python-versions = ">=2"
files = [
{file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"},
{file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"},
]
[[package]]
name = "urllib3"
version = "1.26.17"
@@ -3523,6 +3806,17 @@ dev = ["Cython (>=0.29.32,<0.30.0)", "Sphinx (>=4.1.2,<4.2.0)", "aiohttp", "flak
docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
test = ["Cython (>=0.29.32,<0.30.0)", "aiohttp", "flake8 (>=3.9.2,<3.10.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=22.0.0,<22.1.0)", "pycodestyle (>=2.7.0,<2.8.0)"]
[[package]]
name = "vine"
version = "5.0.0"
description = "Promises, promises, promises."
optional = false
python-versions = ">=3.6"
files = [
{file = "vine-5.0.0-py2.py3-none-any.whl", hash = "sha256:4c9dceab6f76ed92105027c49c823800dd33cacce13bdedc5b914e3514b7fb30"},
{file = "vine-5.0.0.tar.gz", hash = "sha256:7d3b1624a953da82ef63462013bbd271d3eb75751489f9807598e8f340bd637e"},
]
[[package]]
name = "watchfiles"
version = "0.20.0"
@@ -3557,6 +3851,17 @@ files = [
[package.dependencies]
anyio = ">=3.0.0"
[[package]]
name = "wcwidth"
version = "0.2.8"
description = "Measures the displayed width of unicode strings in a terminal"
optional = false
python-versions = "*"
files = [
{file = "wcwidth-0.2.8-py2.py3-none-any.whl", hash = "sha256:77f719e01648ed600dfa5402c347481c0992263b81a027344f3e1ba25493a704"},
{file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"},
]
[[package]]
name = "websockets"
version = "11.0.3"
@@ -3838,4 +4143,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "61578467a70980ff9c2dc0cd787b6410b91d7c5fd2bb4c46b6951ec82690ef67"
content-hash = "cfefbd402bde7585caa42c1a889be0496d956e285bb05db9e1e7ae5e485e91fe"

View File

@@ -33,6 +33,9 @@ prometheus-fastapi-instrumentator = "^6.1.0"
sentencepiece = "^0.1.99"
protobuf = "^4.24.3"
profanityfilter = "^2.0.6"
celery = "^5.3.4"
redis = "^5.0.1"
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
[tool.poetry.group.dev.dependencies]
@@ -47,6 +50,7 @@ pytest-asyncio = "^0.21.1"
pytest = "^7.4.0"
httpx-ws = "^0.4.1"
pytest-httpx = "^0.23.1"
pytest-celery = "^0.0.0"
[tool.poetry.group.aws.dependencies]

View File

@@ -64,6 +64,9 @@ app.include_router(transcripts_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")
add_pagination(app)
# prepare celery
from reflector.worker import app as celery_app # noqa
# simpler openapi id
def use_route_names_as_operation_ids(app: FastAPI) -> None:

View File

@@ -1,32 +1,13 @@
import databases
import sqlalchemy
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings
database = databases.Database(settings.DATABASE_URL)
metadata = sqlalchemy.MetaData()
transcripts = sqlalchemy.Table(
"transcript",
metadata,
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String),
sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Integer),
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
)
# import models
import reflector.db.transcripts # noqa
engine = sqlalchemy.create_engine(
settings.DATABASE_URL, connect_args={"check_same_thread": False}

View File

@@ -0,0 +1,296 @@
import json
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import Any
from uuid import uuid4
import sqlalchemy
from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
transcripts = sqlalchemy.Table(
"transcript",
metadata,
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String),
sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Integer),
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
)
def generate_uuid4():
return str(uuid4())
def generate_transcript_name():
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class AudioWaveform(BaseModel):
data: list[float]
class TranscriptText(BaseModel):
text: str
translation: str | None
class TranscriptSegmentTopic(BaseModel):
speaker: int
text: str
timestamp: float
class TranscriptTopic(BaseModel):
id: str = Field(default_factory=generate_uuid4)
title: str
summary: str
timestamp: float
duration: float | None = 0
transcript: str | None = None
words: list[ProcessorWord] = []
class TranscriptFinalShortSummary(BaseModel):
short_summary: str
class TranscriptFinalLongSummary(BaseModel):
long_summary: str
class TranscriptFinalTitle(BaseModel):
title: str
class TranscriptEvent(BaseModel):
event: str
data: dict
class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=datetime.utcnow)
title: str | None = None
short_summary: str | None = None
long_summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
self.events.append(ev)
return ev
def upsert_topic(self, topic: TranscriptTopic):
index = next((i for i, t in enumerate(self.topics) if t.id == topic.id), None)
if index is not None:
self.topics[index] = topic
else:
self.topics.append(topic)
def events_dump(self, mode="json"):
return [event.model_dump(mode=mode) for event in self.events]
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_waveform(self, segments_count=256):
fn = self.audio_waveform_filename
if fn.exists():
return
waveform = get_audio_waveform(
path=self.audio_mp3_filename, segments_count=segments_count
)
try:
with open(fn, "w") as fd:
json.dump(waveform, fd)
except Exception:
# remove file if anything happen during the write
fn.unlink(missing_ok=True)
raise
return waveform
def unlink(self):
self.data_path.unlink(missing_ok=True)
@property
def data_path(self):
return Path(settings.DATA_DIR) / self.id
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@property
def audio_waveform_filename(self):
return self.data_path / "audio.json"
@property
def audio_waveform(self):
try:
with open(self.audio_waveform_filename) as fd:
data = json.load(fd)
except json.JSONDecodeError:
# unlink file if it's corrupted
self.audio_waveform_filename.unlink(missing_ok=True)
return None
return AudioWaveform(data=data)
class TranscriptController:
async def get_all(
self,
user_id: str | None = None,
order_by: str | None = None,
filter_empty: bool | None = False,
filter_recording: bool | None = False,
) -> list[Transcript]:
"""
Get all transcripts
If `user_id` is specified, only return transcripts that belong to the user.
Otherwise, return all anonymous transcripts.
Parameters:
- `order_by`: field to order by, e.g. "-created_at"
- `filter_empty`: filter out empty transcripts
- `filter_recording`: filter out transcripts that are currently recording
"""
query = transcripts.select().where(transcripts.c.user_id == user_id)
if order_by is not None:
field = getattr(transcripts.c, order_by[1:])
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
if filter_empty:
query = query.filter(transcripts.c.status != "idle")
if filter_recording:
query = query.filter(transcripts.c.status != "recording")
results = await database.fetch_all(query)
return results
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
"""
Get a transcript by id
"""
query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
if not result:
return None
return Transcript(**result)
async def add(
self,
name: str,
source_language: str = "en",
target_language: str = "en",
user_id: str | None = None,
):
"""
Add a new transcript
"""
transcript = Transcript(
name=name,
source_language=source_language,
target_language=target_language,
user_id=user_id,
)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
return transcript
async def update(self, transcript: Transcript, values: dict):
"""
Update a transcript fields with key/values in values
"""
query = (
transcripts.update()
.where(transcripts.c.id == transcript.id)
.values(**values)
)
await database.execute(query)
for key, value in values.items():
setattr(transcript, key, value)
async def remove_by_id(
self,
transcript_id: str,
user_id: str | None = None,
) -> None:
"""
Remove a transcript by id
"""
transcript = await self.get_by_id(transcript_id, user_id=user_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
@asynccontextmanager
async def transaction(self):
"""
A context manager for database transaction
"""
async with database.transaction(isolation="serializable"):
yield
async def append_event(
self,
transcript: Transcript,
event: str,
data: Any,
) -> TranscriptEvent:
"""
Append an event to a transcript
"""
resp = transcript.add_event(event=event, data=data)
await self.update(transcript, {"events": transcript.events_dump()})
return resp
async def upsert_topic(
self,
transcript: Transcript,
topic: TranscriptTopic,
) -> TranscriptEvent:
"""
Append an event to a transcript
"""
transcript.upsert_topic(topic)
await self.update(transcript, {"topics": transcript.topics_dump()})
transcripts_controller = TranscriptController()

View File

@@ -47,6 +47,7 @@ class ModalLLM(LLM):
json=json_payload,
timeout=self.timeout,
retry_timeout=60 * 5,
follow_redirects=True,
)
response.raise_for_status()
text = response.json()["text"]

View File

@@ -0,0 +1,362 @@
"""
Main reflector pipeline for live streaming
==========================================
This is the default pipeline used in the API.
It is decoupled to:
- PipelineMainLive: have limited processing during live
- PipelineMainPost: do heavy lifting after the live
It is directly linked to our data model.
"""
import asyncio
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from celery import shared_task
from pydantic import BaseModel
from reflector.app import app
from reflector.db.transcripts import (
Transcript,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
TranscriptFinalTitle,
TranscriptText,
TranscriptTopic,
transcripts_controller,
)
from reflector.logger import logger
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
AudioDiarizationAutoProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
BroadcastProcessor,
Pipeline,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import (
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
)
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
def broadcast_to_sockets(func):
"""
Decorator to broadcast transcript event to websockets
concerning this transcript
"""
async def wrapper(self, *args, **kwargs):
resp = await func(self, *args, **kwargs)
if resp is None:
return
await self.ws_manager.send_json(
room_id=self.ws_room_id,
message=resp.model_dump(mode="json"),
)
return wrapper
class StrValue(BaseModel):
value: str
class PipelineMainBase(PipelineRunner):
transcript_id: str
ws_room_id: str | None = None
ws_manager: WebsocketManager | None = None
def prepare(self):
# prepare websocket
self._lock = asyncio.Lock()
self.ws_room_id = f"ts:{self.transcript_id}"
self.ws_manager = get_ws_manager()
async def get_transcript(self) -> Transcript:
# fetch the transcript
result = await transcripts_controller.get_by_id(
transcript_id=self.transcript_id
)
if not result:
raise Exception("Transcript not found")
return result
@asynccontextmanager
async def transaction(self):
async with self._lock:
async with transcripts_controller.transaction():
yield
@broadcast_to_sockets
async def on_status(self, status):
# if it's the first part, update the status of the transcript
# but do not set the ended status yet.
if isinstance(self, PipelineMainLive):
status_mapping = {
"started": "recording",
"push": "recording",
"flush": "processing",
"error": "error",
}
elif isinstance(self, PipelineMainDiarization):
status_mapping = {
"push": "processing",
"flush": "processing",
"error": "error",
"ended": "ended",
}
else:
raise Exception(f"Runner {self.__class__} is missing status mapping")
# mutate to model status
status = status_mapping.get(status)
if not status:
return
# when the status of the pipeline changes, update the transcript
async with self.transaction():
transcript = await self.get_transcript()
if status == transcript.status:
return
resp = await transcripts_controller.append_event(
transcript=transcript,
event="STATUS",
data=StrValue(value=status),
)
await transcripts_controller.update(
transcript,
{
"status": status,
},
)
return resp
@broadcast_to_sockets
async def on_transcript(self, data):
async with self.transaction():
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript,
event="TRANSCRIPT",
data=TranscriptText(text=data.text, translation=data.translation),
)
@broadcast_to_sockets
async def on_topic(self, data):
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
transcript=data.transcript.text,
words=data.transcript.words,
)
if isinstance(data, TitleSummaryWithIdProcessorType):
topic.id = data.id
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.upsert_topic(transcript, topic)
return await transcripts_controller.append_event(
transcript=transcript,
event="TOPIC",
data=topic,
)
@broadcast_to_sockets
async def on_title(self, data):
final_title = TranscriptFinalTitle(title=data.title)
async with self.transaction():
transcript = await self.get_transcript()
if not transcript.title:
await transcripts_controller.update(
transcript,
{
"title": final_title.title,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_TITLE",
data=final_title,
)
@broadcast_to_sockets
async def on_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"long_summary": final_long_summary.long_summary,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_LONG_SUMMARY",
data=final_long_summary,
)
@broadcast_to_sockets
async def on_short_summary(self, data):
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"short_summary": final_short_summary.short_summary,
},
)
return await transcripts_controller.append_event(
transcript=transcript,
event="FINAL_SHORT_SUMMARY",
data=final_short_summary,
)
class PipelineMainLive(PipelineMainBase):
audio_filename: Path | None = None
source_language: str = "en"
target_language: str = "en"
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
transcript = await self.get_transcript()
processors = [
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
]
),
]
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(
"Pipeline main live created",
transcript_id=self.transcript_id,
)
return pipeline
async def on_ended(self):
# when the pipeline ends, connect to the post pipeline
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
class PipelineMainDiarization(PipelineMainBase):
"""
Diarization is a long time process, so we do it in a separate pipeline
When done, adjust the short and final summary
"""
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
processors = [
AudioDiarizationAutoProcessor(callback=self.on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=self.on_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_short_summary
),
]
),
]
pipeline = Pipeline(*processors)
pipeline.options = self
# now let's start the pipeline by pushing information to the
# first processor diarization processor
# XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript()
topics = [
TitleSummaryWithIdProcessorType(
id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
duration=topic.duration,
transcript=TranscriptProcessorType(words=topic.words),
)
for topic in transcript.topics
]
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from reflector.views.transcripts import create_access_token
path = app.url_path_for(
"transcript_get_audio_mp3",
transcript_id=transcript.id,
)
url = f"{settings.BASE_URL}{path}"
if transcript.user_id:
# we pass token only if the user_id is set
# otherwise, the audio is public
token = create_access_token(
{"sub": transcript.user_id},
expires_delta=timedelta(minutes=15),
)
url += f"?token={token}"
audio_diarization_input = AudioDiarizationInput(
audio_url=url,
topics=topics,
)
# as tempting to use pipeline.push, prefer to use the runner
# to let the start just do one job.
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(
"Pipeline main post created", transcript_id=self.transcript_id
)
self.push(audio_diarization_input)
self.flush()
return pipeline
@shared_task
def task_pipeline_main_post(transcript_id: str):
logger.info(
"Starting main post pipeline",
transcript_id=transcript_id,
)
runner = PipelineMainDiarization(transcript_id=transcript_id)
runner.start_sync()

View File

@@ -0,0 +1,137 @@
"""
Pipeline Runner
===============
Pipeline runner designed to be executed in a asyncio task.
It is meant to be subclassed, and implement a create() method
that expose/return a Pipeline instance.
During its lifecycle, it will emit the following status:
- started: the pipeline has been started
- push: the pipeline received at least one data
- flush: the pipeline is flushing
- ended: the pipeline has ended
- error: the pipeline has ended with an error
"""
import asyncio
from pydantic import BaseModel, ConfigDict
from reflector.logger import logger
from reflector.processors import Pipeline
class PipelineRunner(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
status: str = "idle"
pipeline: Pipeline | None = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._q_cmd = asyncio.Queue()
self._ev_done = asyncio.Event()
self._is_first_push = True
self._logger = logger.bind(
runner=id(self),
runner_cls=self.__class__.__name__,
)
def create(self) -> Pipeline:
"""
Create the pipeline if not specified earlier.
Should be implemented in a subclass
"""
raise NotImplementedError()
def start(self):
"""
Start the pipeline as a coroutine task
"""
asyncio.get_event_loop().create_task(self.run())
def start_sync(self):
"""
Start the pipeline synchronously (for non-asyncio apps)
"""
coro = self.run()
asyncio.run(coro)
def push(self, data):
"""
Push data to the pipeline
"""
self._add_cmd("PUSH", data)
def flush(self):
"""
Flush the pipeline
"""
self._add_cmd("FLUSH", None)
async def on_status(self, status):
"""
Called when the status of the pipeline changes
"""
pass
async def on_ended(self):
"""
Called when the pipeline ends
"""
pass
def _add_cmd(self, cmd: str, data):
"""
Enqueue a command to be executed in the runner.
Currently supported commands: PUSH, FLUSH
"""
self._q_cmd.put_nowait([cmd, data])
async def _set_status(self, status):
self._logger.debug("Runner status updated", status=status)
self.status = status
if self.on_status:
try:
await self.on_status(status)
except Exception:
self._logger.exception("Runer error while setting status")
async def run(self):
try:
# create the pipeline if not yet done
await self._set_status("init")
self._is_first_push = True
if not self.pipeline:
self.pipeline = await self.create()
# start the loop
await self._set_status("started")
while not self._ev_done.is_set():
cmd, data = await self._q_cmd.get()
func = getattr(self, f"cmd_{cmd.lower()}")
if func:
await func(data)
else:
raise Exception(f"Unknown command {cmd}")
except Exception:
self._logger.exception("Runner error")
await self._set_status("error")
self._ev_done.set()
if self.on_ended:
await self.on_ended()
async def cmd_push(self, data):
if self._is_first_push:
await self._set_status("push")
self._is_first_push = False
await self.pipeline.push(data)
async def cmd_flush(self, data):
await self._set_status("flush")
await self.pipeline.flush()
await self._set_status("ended")
self._ev_done.set()
if self.on_ended:
await self.on_ended()

View File

@@ -1,9 +1,16 @@
from .audio_chunker import AudioChunkerProcessor # noqa: F401
from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor # noqa: F401
from .base import ( # noqa: F401
BroadcastProcessor,
Pipeline,
PipelineEvent,
Processor,
ThreadedProcessor,
)
from .transcript_final_long_summary import ( # noqa: F401
TranscriptFinalLongSummaryProcessor,
)

View File

@@ -0,0 +1,34 @@
from reflector.processors.base import Processor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
class AudioDiarizationProcessor(Processor):
INPUT_TYPE = AudioDiarizationInput
OUTPUT_TYPE = TitleSummary
async def _push(self, data: AudioDiarizationInput):
try:
self.logger.info("Diarization started", audio_file_url=data.audio_url)
diarization = await self._diarize(data)
self.logger.info("Diarization finished")
except Exception:
self.logger.exception("Diarization failed after retrying")
raise
# now reapply speaker to topics (if any)
# topics is a list[BaseModel] with an attribute words
# words is a list[BaseModel] with text, start and speaker attribute
# mutate in place
for topic in data.topics:
for word in topic.transcript.words:
for d in diarization:
if d["start"] <= word.start <= d["end"]:
word.speaker = d["speaker"]
# emit them
for topic in data.topics:
await self.emit(topic)
async def _diarize(self, data: AudioDiarizationInput):
raise NotImplementedError

View File

@@ -0,0 +1,33 @@
import importlib
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.settings import settings
class AudioDiarizationAutoProcessor(AudioDiarizationProcessor):
_registry = {}
@classmethod
def register(cls, name, kclass):
cls._registry[name] = kclass
def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.DIARIZATION_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.audio_diarization_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
config = {}
name_upper = name.upper()
settings_prefix = "DIARIZATION_"
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
if key.startswith(config_prefix):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config | kwargs)

View File

@@ -0,0 +1,37 @@
import httpx
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
from reflector.settings import settings
class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
INPUT_TYPE = AudioDiarizationInput
OUTPUT_TYPE = TitleSummary
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
self.headers = {
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
}
async def _diarize(self, data: AudioDiarizationInput):
# Gather diarization data
params = {
"audio_file_url": data.audio_url,
"timestamp": 0,
}
async with httpx.AsyncClient() as client:
response = await client.post(
self.diarization_url,
headers=self.headers,
params=params,
timeout=None,
follow_redirects=True,
)
response.raise_for_status()
return response.json()["text"]
AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)

View File

@@ -1,6 +1,4 @@
from profanityfilter import ProfanityFilter
from prometheus_client import Counter, Histogram
from reflector.processors.base import Processor
from reflector.processors.types import AudioFile, Transcript
@@ -40,8 +38,6 @@ class AudioTranscriptProcessor(Processor):
self.m_transcript_call = self.m_transcript_call.labels(name)
self.m_transcript_success = self.m_transcript_success.labels(name)
self.m_transcript_failure = self.m_transcript_failure.labels(name)
self.profanity_filter = ProfanityFilter()
self.profanity_filter.set_censor("*")
super().__init__(*args, **kwargs)
async def _push(self, data: AudioFile):
@@ -60,9 +56,3 @@ class AudioTranscriptProcessor(Processor):
async def _transcript(self, data: AudioFile):
raise NotImplementedError
def filter_profanity(self, text: str) -> str:
"""
Remove censored words from the transcript
"""
return self.profanity_filter.censor(text)

View File

@@ -1,8 +1,6 @@
import importlib
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.base import Pipeline, Processor
from reflector.processors.types import AudioFile
from reflector.settings import settings
@@ -13,8 +11,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
def register(cls, name, kclass):
cls._registry[name] = kclass
@classmethod
def get_instance(cls, name):
def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.TRANSCRIPT_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.audio_transcript_{name}"
importlib.import_module(module_name)
@@ -30,30 +29,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config)
def __init__(self, **kwargs):
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
super().__init__(**kwargs)
def set_pipeline(self, pipeline: Pipeline):
super().set_pipeline(pipeline)
self.processor.set_pipeline(pipeline)
def connect(self, processor: Processor):
self.processor.connect(processor)
def disconnect(self, processor: Processor):
self.processor.disconnect(processor)
def on(self, callback):
self.processor.on(callback)
def off(self, callback):
self.processor.off(callback)
async def _push(self, data: AudioFile):
return await self.processor._push(data)
async def _flush(self):
return await self.processor._flush()
return cls._registry[name](**config | kwargs)

View File

@@ -41,6 +41,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
timeout=self.timeout,
headers=self.headers,
params=json_payload,
follow_redirects=True,
)
self.logger.debug(
@@ -48,10 +49,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
)
response.raise_for_status()
result = response.json()
text = result["text"][source_language]
text = self.filter_profanity(text)
transcript = Transcript(
text=text,
words=[
Word(
text=word["text"],

View File

@@ -30,7 +30,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
ts = data.timestamp
for segment in segments:
transcript.text += segment.text
for word in segment.words:
transcript.words.append(
Word(

View File

@@ -290,12 +290,12 @@ class BroadcastProcessor(Processor):
processor.set_pipeline(pipeline)
async def _push(self, data):
for processor in self.processors:
await processor.push(data)
coros = [processor.push(data) for processor in self.processors]
await asyncio.gather(*coros)
async def _flush(self):
for processor in self.processors:
await processor.flush()
coros = [processor.flush() for processor in self.processors]
await asyncio.gather(*coros)
def connect(self, processor: Processor):
for processor in self.processors:
@@ -333,6 +333,7 @@ class Pipeline(Processor):
self.logger.info("Pipeline created")
self.processors = processors
self.options = None
self.prefs = {}
for processor in processors:

View File

@@ -36,7 +36,6 @@ class TranscriptLinerProcessor(Processor):
# cut to the next .
partial = Transcript(words=[])
for word in self.transcript.words[:]:
partial.text += word.text
partial.words.append(word)
if not self.is_sentence_terminated(word.text):
continue

View File

@@ -50,6 +50,7 @@ class TranscriptTranslatorProcessor(Processor):
headers=self.headers,
params=json_payload,
timeout=self.timeout,
follow_redirects=True,
)
response.raise_for_status()
result = response.json()["text"]

View File

@@ -1,9 +1,16 @@
import io
import re
import tempfile
from pathlib import Path
from profanityfilter import ProfanityFilter
from pydantic import BaseModel, PrivateAttr
PUNC_RE = re.compile(r"[.;:?!…]")
profanity_filter = ProfanityFilter()
profanity_filter.set_censor("*")
class AudioFile(BaseModel):
name: str
@@ -43,13 +50,29 @@ class Word(BaseModel):
text: str
start: float
end: float
speaker: int = 0
class TranscriptSegment(BaseModel):
text: str
start: float
speaker: int = 0
class Transcript(BaseModel):
text: str = ""
translation: str | None = None
words: list[Word] = None
@property
def raw_text(self):
# Uncensored text
return "".join([word.text for word in self.words])
@property
def text(self):
# Censored text
return profanity_filter.censor(self.raw_text).strip()
@property
def human_timestamp(self):
minutes = int(self.timestamp / 60)
@@ -74,7 +97,6 @@ class Transcript(BaseModel):
self.words = other.words
else:
self.words.extend(other.words)
self.text += other.text
def add_offset(self, offset: float):
for word in self.words:
@@ -87,6 +109,48 @@ class Transcript(BaseModel):
]
return Transcript(text=self.text, translation=self.translation, words=words)
def as_segments(self) -> list[TranscriptSegment]:
# from a list of word, create a list of segments
# join the word that are less than 2 seconds apart
# but separate if the speaker changes, or if the punctuation is a . , ; : ? !
segments = []
current_segment = None
MAX_SEGMENT_LENGTH = 120
for word in self.words:
if current_segment is None:
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
speaker=word.speaker,
)
continue
# If the word is attach to another speaker, push the current segment
# and start a new one
if word.speaker != current_segment.speaker:
segments.append(current_segment)
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
speaker=word.speaker,
)
continue
# if the word is the end of a sentence, and we have enough content,
# add the word to the current segment and push it
current_segment.text += word.text
have_punc = PUNC_RE.search(word.text)
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
segments.append(current_segment)
current_segment = None
if current_segment:
segments.append(current_segment)
return segments
class TitleSummary(BaseModel):
title: str
@@ -103,6 +167,10 @@ class TitleSummary(BaseModel):
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
class TitleSummaryWithId(TitleSummary):
id: str
class FinalLongSummary(BaseModel):
long_summary: str
duration: float
@@ -318,3 +386,8 @@ class TranslationLanguages(BaseModel):
def is_supported(self, lang_id: str) -> bool:
return lang_id in self.supported_languages
class AudioDiarizationInput(BaseModel):
audio_url: str
topics: list[TitleSummaryWithId]

View File

@@ -89,6 +89,10 @@ class Settings(BaseSettings):
# LLM Modal configuration
LLM_MODAL_API_KEY: str | None = None
# Diarization
DIARIZATION_BACKEND: str = "modal"
DIARIZATION_URL: str | None = None
# Sentry
SENTRY_DSN: str | None = None
@@ -113,5 +117,19 @@ class Settings(BaseSettings):
# Min transcript length to generate topic + summary
MIN_TRANSCRIPT_LENGTH: int = 750
# Celery
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
# Redis
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
# Secret key
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
# Current hosting/domain
BASE_URL: str = "http://localhost:1250"
settings = Settings()

View File

@@ -0,0 +1,14 @@
import argparse
from reflector.app import celery_app # noqa
from reflector.pipelines.main_live_pipeline import task_pipeline_main_post
parser = argparse.ArgumentParser()
parser.add_argument("transcript_id", type=str)
parser.add_argument("--delay", action="store_true")
args = parser.parse_args()
if args.delay:
task_pipeline_main_post.delay(args.transcript_id)
else:
task_pipeline_main_post(args.transcript_id)

View File

@@ -1,7 +1,5 @@
import asyncio
from enum import StrEnum
from json import dumps, loads
from pathlib import Path
from json import loads
import av
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
@@ -10,25 +8,7 @@ from prometheus_client import Gauge
from pydantic import BaseModel
from reflector.events import subscribers_shutdown
from reflector.logger import logger
from reflector.processors import (
AudioChunkerProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
FinalLongSummary,
FinalShortSummary,
Pipeline,
TitleSummary,
Transcript,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.base import BroadcastProcessor
from reflector.processors.types import FinalTitle
from reflector.pipelines.runner import PipelineRunner
sessions = []
router = APIRouter()
@@ -38,7 +18,7 @@ m_rtc_sessions = Gauge("rtc_sessions", "Number of active RTC sessions")
class TranscriptionContext(object):
def __init__(self, logger):
self.logger = logger
self.pipeline = None
self.pipeline_runner = None
self.data_channel = None
self.status = "idle"
self.topics = []
@@ -60,7 +40,7 @@ class AudioStreamTrack(MediaStreamTrack):
ctx = self.ctx
frame = await self.track.recv()
try:
await ctx.pipeline.push(frame)
ctx.pipeline_runner.push(frame)
except Exception as e:
ctx.logger.error("Pipeline error", error=e)
return frame
@@ -71,27 +51,10 @@ class RtcOffer(BaseModel):
type: str
class StrValue(BaseModel):
value: str
class PipelineEvent(StrEnum):
TRANSCRIPT = "TRANSCRIPT"
TOPIC = "TOPIC"
FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
STATUS = "STATUS"
FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
FINAL_TITLE = "FINAL_TITLE"
async def rtc_offer_base(
params: RtcOffer,
request: Request,
event_callback=None,
event_callback_args=None,
audio_filename: Path | None = None,
source_language: str = "en",
target_language: str = "en",
pipeline_runner: PipelineRunner,
):
# build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -101,146 +64,10 @@ async def rtc_offer_base(
clientid = f"{peername[0]}:{peername[1]}"
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
async def update_status(status: str):
changed = ctx.status != status
if changed:
ctx.status = status
if event_callback:
await event_callback(
event=PipelineEvent.STATUS,
args=event_callback_args,
data=StrValue(value=status),
)
# build pipeline callback
async def on_transcript(transcript: Transcript):
ctx.logger.info("Transcript", transcript=transcript)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TRANSCRIPT,
args=event_callback_args,
data=transcript,
)
async def on_topic(topic: TitleSummary):
# FIXME: make it incremental with the frontend, not send everything
ctx.logger.info("Topic", topic=topic)
ctx.topics.append(
{
"title": topic.title,
"timestamp": topic.timestamp,
"transcript": topic.transcript.text,
"desc": topic.summary,
}
)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
)
async def on_final_short_summary(summary: FinalShortSummary):
ctx.logger.info("FinalShortSummary", final_short_summary=summary)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
"summary": summary.short_summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_SHORT_SUMMARY,
args=event_callback_args,
data=summary,
)
async def on_final_long_summary(summary: FinalLongSummary):
ctx.logger.info("FinalLongSummary", final_summary=summary)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_LONG_SUMMARY",
"summary": summary.long_summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_LONG_SUMMARY,
args=event_callback_args,
data=summary,
)
async def on_final_title(title: FinalTitle):
ctx.logger.info("FinalTitle", final_title=title)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_TITLE,
args=event_callback_args,
data=title,
)
# create a context for the whole rtc transaction
# add a customised logger to the context
processors = []
if audio_filename is not None:
processors += [AudioFileWriterProcessor(path=audio_filename)]
processors += [
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=on_final_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=on_final_short_summary
),
]
),
]
ctx.pipeline = Pipeline(*processors)
ctx.pipeline.set_pref("audio:source_language", source_language)
ctx.pipeline.set_pref("audio:target_language", target_language)
# handle RTC peer connection
pc = RTCPeerConnection()
ctx.pipeline_runner = pipeline_runner
ctx.pipeline_runner.start()
async def flush_pipeline_and_quit(close=True):
# may be called twice
@@ -249,12 +76,10 @@ async def rtc_offer_base(
# - when we receive the close event, we do nothing.
# 2. or the client close the connection
# and there is nothing to do because it is already closed
await update_status("processing")
await ctx.pipeline.flush()
ctx.pipeline_runner.flush()
if close:
ctx.logger.debug("Closing peer connection")
await pc.close()
await update_status("ended")
if pc in sessions:
sessions.remove(pc)
m_rtc_sessions.dec()
@@ -287,7 +112,6 @@ async def rtc_offer_base(
def on_track(track):
ctx.logger.info(f"Track {track.kind} received")
pc.addTrack(AudioStreamTrack(ctx, track))
asyncio.get_event_loop().create_task(update_status("recording"))
await pc.setRemoteDescription(offer)
@@ -308,8 +132,3 @@ async def rtc_clean_sessions(_):
logger.debug(f"Closing session {pc}")
await pc.close()
sessions.clear()
@router.post("/offer")
async def rtc_offer(params: RtcOffer, request: Request):
return await rtc_offer_base(params, request)

View File

@@ -1,8 +1,5 @@
import json
from datetime import datetime
from pathlib import Path
from datetime import datetime, timedelta
from typing import Annotated, Optional
from uuid import uuid4
import reflector.auth as auth
from fastapi import (
@@ -12,221 +9,36 @@ from fastapi import (
Request,
WebSocket,
WebSocketDisconnect,
status,
)
from fastapi_pagination import Page, paginate
from jose import jwt
from pydantic import BaseModel, Field
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.db.transcripts import (
AudioWaveform,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter()
# ==============================================================
# Models to move to a database, but required for the API to work
# ==============================================================
ALGORITHM = "HS256"
DOWNLOAD_EXPIRE_MINUTES = 60
def generate_uuid4():
return str(uuid4())
def generate_transcript_name():
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class AudioWaveform(BaseModel):
data: list[float]
class TranscriptText(BaseModel):
text: str
translation: str | None
class TranscriptTopic(BaseModel):
id: str = Field(default_factory=generate_uuid4)
title: str
summary: str
transcript: str | None = None
timestamp: float
class TranscriptFinalShortSummary(BaseModel):
short_summary: str
class TranscriptFinalLongSummary(BaseModel):
long_summary: str
class TranscriptFinalTitle(BaseModel):
title: str
class TranscriptEvent(BaseModel):
event: str
data: dict
class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=datetime.utcnow)
title: str | None = None
short_summary: str | None = None
long_summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
self.events.append(ev)
return ev
def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
if existing_topic:
existing_topic.update_from(topic)
else:
self.topics.append(topic)
def events_dump(self, mode="json"):
return [event.model_dump(mode=mode) for event in self.events]
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_waveform(self, segments_count=256):
fn = self.audio_waveform_filename
if fn.exists():
return
waveform = get_audio_waveform(
path=self.audio_mp3_filename, segments_count=segments_count
)
try:
with open(fn, "w") as fd:
json.dump(waveform, fd)
except Exception:
# remove file if anything happen during the write
fn.unlink(missing_ok=True)
raise
return waveform
def unlink(self):
self.data_path.unlink(missing_ok=True)
@property
def data_path(self):
return Path(settings.DATA_DIR) / self.id
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@property
def audio_waveform_filename(self):
return self.data_path / "audio.json"
@property
def audio_waveform(self):
try:
with open(self.audio_waveform_filename) as fd:
data = json.load(fd)
except json.JSONDecodeError:
# unlink file if it's corrupted
self.audio_waveform_filename.unlink(missing_ok=True)
return None
return AudioWaveform(data=data)
class TranscriptController:
async def get_all(
self,
user_id: str | None = None,
order_by: str | None = None,
filter_empty: bool | None = False,
filter_recording: bool | None = False,
) -> list[Transcript]:
query = transcripts.select().where(transcripts.c.user_id == user_id)
if order_by is not None:
field = getattr(transcripts.c, order_by[1:])
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
if filter_empty:
query = query.filter(transcripts.c.status != "idle")
if filter_recording:
query = query.filter(transcripts.c.status != "recording")
results = await database.fetch_all(query)
return results
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
if not result:
return None
return Transcript(**result)
async def add(
self,
name: str,
source_language: str = "en",
target_language: str = "en",
user_id: str | None = None,
):
transcript = Transcript(
name=name,
source_language=source_language,
target_language=target_language,
user_id=user_id,
)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
return transcript
async def update(self, transcript: Transcript, values: dict):
query = (
transcripts.update()
.where(transcripts.c.id == transcript.id)
.values(**values)
)
await database.execute(query)
for key, value in values.items():
setattr(transcript, key, value)
async def remove_by_id(
self, transcript_id: str, user_id: str | None = None
) -> None:
transcript = await self.get_by_id(transcript_id, user_id=user_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
transcripts_controller = TranscriptController()
def create_access_token(data: dict, expires_delta: timedelta):
to_encode = data.copy()
expire = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# ==============================================================
@@ -298,6 +110,55 @@ async def transcripts_create(
# ==============================================================
class GetTranscriptSegmentTopic(BaseModel):
text: str
start: float
speaker: int
class GetTranscriptTopic(BaseModel):
id: str
title: str
summary: str
timestamp: float
transcript: str
segments: list[GetTranscriptSegmentTopic] = []
@classmethod
def from_transcript_topic(cls, topic: TranscriptTopic):
if not topic.words:
# In previous version, words were missing
# Just output a segment with speaker 0
text = topic.transcript
segments = [
GetTranscriptSegmentTopic(
text=topic.transcript,
start=topic.timestamp,
speaker=0,
)
]
else:
# New versions include words
transcript = ProcessorTranscript(words=topic.words)
text = transcript.text
segments = [
GetTranscriptSegmentTopic(
text=segment.text,
start=segment.start,
speaker=segment.speaker,
)
for segment in transcript.as_segments()
]
return cls(
id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
transcript=text,
segments=segments,
)
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(
transcript_id: str,
@@ -320,32 +181,17 @@ async def transcript_update(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
values = {"events": []}
values = {}
if info.name is not None:
values["name"] = info.name
if info.locked is not None:
values["locked"] = info.locked
if info.long_summary is not None:
values["long_summary"] = info.long_summary
for transcript_event in transcript.events:
if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
transcript_event["long_summary"] = info.long_summary
break
values["events"].extend(transcript.events)
if info.short_summary is not None:
values["short_summary"] = info.short_summary
for transcript_event in transcript.events:
if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
transcript_event["short_summary"] = info.short_summary
break
values["events"].extend(transcript.events)
if info.title is not None:
values["title"] = info.title
for transcript_event in transcript.events:
if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
transcript_event["title"] = info.title
break
values["events"].extend(transcript.events)
await transcripts_controller.update(transcript, values)
return transcript
@@ -368,8 +214,21 @@ async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
token: str | None = None,
):
user_id = user["sub"] if user else None
if not user_id and token:
unauthorized_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
except jwt.JWTError:
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
@@ -406,7 +265,10 @@ async def transcript_get_audio_waveform(
return transcript.audio_waveform
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
@router.get(
"/transcripts/{transcript_id}/topics",
response_model=list[GetTranscriptTopic],
)
async def transcript_get_topics(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
@@ -415,7 +277,16 @@ async def transcript_get_topics(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript.topics
# convert to GetTranscriptTopic
return [
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
]
# ==============================================================
# Websocket
# ==============================================================
@router.get("/transcripts/{transcript_id}/events")
@@ -423,41 +294,6 @@ async def transcript_get_websocket_events(transcript_id: str):
pass
# ==============================================================
# Websocket Manager
# ==============================================================
class WebsocketManager:
def __init__(self):
self.active_connections = {}
async def connect(self, transcript_id: str, websocket: WebSocket):
await websocket.accept()
if transcript_id not in self.active_connections:
self.active_connections[transcript_id] = []
self.active_connections[transcript_id].append(websocket)
def disconnect(self, transcript_id: str, websocket: WebSocket):
if transcript_id not in self.active_connections:
return
self.active_connections[transcript_id].remove(websocket)
if not self.active_connections[transcript_id]:
del self.active_connections[transcript_id]
async def send_json(self, transcript_id: str, message):
if transcript_id not in self.active_connections:
return
for connection in self.active_connections[transcript_id][:]:
try:
await connection.send_json(message)
except Exception:
self.active_connections[transcript_id].remove(connection)
ws_manager = WebsocketManager()
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
transcript_id: str,
@@ -469,21 +305,31 @@ async def transcript_events_websocket(
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
await ws_manager.connect(transcript_id, websocket)
# connect to websocket manager
# use ts:transcript_id as room id
room_id = f"ts:{transcript_id}"
ws_manager = get_ws_manager()
await ws_manager.add_user_to_room(room_id, websocket)
# on first connection, send all events
for event in transcript.events:
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
try:
# on first connection, send all events only to the current user
for event in transcript.events:
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
# not necessary to be sent to the client; but keep the rest
name = event.event
if name in ("TRANSCRIPT", "STATUS"):
continue
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
# we do not have command system now,
while True:
await websocket.receive()
except (RuntimeError, WebSocketDisconnect):
ws_manager.disconnect(transcript_id, websocket)
await ws_manager.remove_user_from_room(room_id, websocket)
# ==============================================================
@@ -491,105 +337,6 @@ async def transcript_events_websocket(
# ==============================================================
async def handle_rtc_event(event: PipelineEvent, args, data):
# OFC the current implementation is not good,
# but it's just a POC before persistence. It won't query the
# transcript from the database for each event.
# print(f"Event: {event}", args, data)
transcript_id = args
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
return
# event send to websocket clients may not be the same as the event
# received from the pipeline. For example, the pipeline will send
# a TRANSCRIPT event with all words, but this is not what we want
# to send to the websocket client.
# FIXME don't do copy
if event == PipelineEvent.TRANSCRIPT:
resp = transcript.add_event(
event=event,
data=TranscriptText(text=data.text, translation=data.translation),
)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
},
)
elif event == PipelineEvent.TOPIC:
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
transcript=data.transcript.text,
timestamp=data.timestamp,
)
resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"topics": transcript.topics_dump(),
},
)
elif event == PipelineEvent.FINAL_TITLE:
final_title = TranscriptFinalTitle(title=data.title)
resp = transcript.add_event(event=event, data=final_title)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"title": final_title.title,
},
)
elif event == PipelineEvent.FINAL_LONG_SUMMARY:
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
resp = transcript.add_event(event=event, data=final_long_summary)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"long_summary": final_long_summary.long_summary,
},
)
elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
resp = transcript.add_event(event=event, data=final_short_summary)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"short_summary": final_short_summary.short_summary,
},
)
elif event == PipelineEvent.STATUS:
resp = transcript.add_event(event=event, data=data)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"status": data.value,
},
)
else:
logger.warning(f"Unknown event: {event}")
return
# transmit to websocket clients
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str,
@@ -605,13 +352,14 @@ async def transcript_record_webrtc(
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# create a pipeline runner
from reflector.pipelines.main_live_pipeline import PipelineMainLive
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
# FIXME do not allow multiple recording at the same time
return await rtc_offer_base(
params,
request,
event_callback=handle_rtc_event,
event_callback_args=transcript_id,
audio_filename=transcript.audio_mp3_filename,
source_language=transcript.source_language,
target_language=transcript.target_language,
pipeline_runner=pipeline_runner,
)

View File

@@ -0,0 +1,12 @@
from celery import Celery
from reflector.settings import settings
app = Celery(__name__)
app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
app.conf.broker_connection_retry_on_startup = True
app.autodiscover_tasks(
[
"reflector.pipelines.main_live_pipeline",
]
)

View File

@@ -0,0 +1,126 @@
"""
Websocket manager
=================
This module contains the WebsocketManager class, which is responsible for
managing websockets and handling websocket connections.
It uses the RedisPubSubManager class to subscribe to Redis channels and
broadcast messages to all connected websockets.
"""
import asyncio
import json
import threading
import redis.asyncio as redis
from fastapi import WebSocket
from reflector.settings import settings
class RedisPubSubManager:
def __init__(self, host="localhost", port=6379):
self.redis_host = host
self.redis_port = port
self.redis_connection = None
self.pubsub = None
async def get_redis_connection(self) -> redis.Redis:
return redis.Redis(
host=self.redis_host,
port=self.redis_port,
auto_close_connection_pool=False,
)
async def connect(self) -> None:
if self.redis_connection is not None:
return
self.redis_connection = await self.get_redis_connection()
self.pubsub = self.redis_connection.pubsub()
async def disconnect(self) -> None:
if self.redis_connection is None:
return
await self.redis_connection.close()
self.redis_connection = None
async def send_json(self, room_id: str, message: str) -> None:
if not self.redis_connection:
await self.connect()
message = json.dumps(message)
await self.redis_connection.publish(room_id, message)
async def subscribe(self, room_id: str) -> redis.Redis:
await self.pubsub.subscribe(room_id)
return self.pubsub
async def unsubscribe(self, room_id: str) -> None:
await self.pubsub.unsubscribe(room_id)
class WebsocketManager:
def __init__(self, pubsub_client: RedisPubSubManager = None):
self.rooms: dict = {}
self.pubsub_client = pubsub_client
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
await websocket.accept()
if room_id in self.rooms:
self.rooms[room_id].append(websocket)
else:
self.rooms[room_id] = [websocket]
await self.pubsub_client.connect()
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
async def send_json(self, room_id: str, message: dict) -> None:
await self.pubsub_client.send_json(room_id, message)
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
self.rooms[room_id].remove(websocket)
if len(self.rooms[room_id]) == 0:
del self.rooms[room_id]
await self.pubsub_client.unsubscribe(room_id)
async def _pubsub_data_reader(self, pubsub_subscriber):
while True:
message = await pubsub_subscriber.get_message(
ignore_subscribe_messages=True
)
if message is not None:
room_id = message["channel"].decode("utf-8")
all_sockets = self.rooms[room_id]
for socket in all_sockets:
data = json.loads(message["data"].decode("utf-8"))
await socket.send_json(data)
def get_ws_manager() -> WebsocketManager:
"""
Returns the WebsocketManager instance for managing websockets.
This function initializes and returns the WebsocketManager instance,
which is responsible for managing websockets and handling websocket
connections.
Returns:
WebsocketManager: The initialized WebsocketManager instance.
Raises:
ImportError: If the 'reflector.settings' module cannot be imported.
RedisConnectionError: If there is an error connecting to the Redis server.
"""
local = threading.local()
if hasattr(local, "ws_manager"):
return local.ws_manager
pubsub_client = RedisPubSubManager(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
)
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
local.ws_manager = ws_manager
return ws_manager

View File

@@ -4,4 +4,11 @@ if [ -f "/venv/bin/activate" ]; then
source /venv/bin/activate
fi
alembic upgrade head
python -m reflector.app
if [ "${ENTRYPOINT}" = "server" ]; then
python -m reflector.app
elif [ "${ENTRYPOINT}" = "worker" ]; then
celery -A reflector.worker.app worker --loglevel=info
else
echo "Unknown command"
fi

View File

@@ -45,28 +45,50 @@ async def dummy_transcript():
from reflector.processors.types import AudioFile, Transcript, Word
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
async def _transcript(self, data: AudioFile):
source_language = self.get_pref("audio:source_language", "en")
print("transcripting", source_language)
print("pipeline", self.pipeline)
print("prefs", self.pipeline.prefs)
_time_idx = 0
async def _transcript(self, data: AudioFile):
i = self._time_idx
self._time_idx += 2
return Transcript(
text="Hello world.",
words=[
Word(start=0.0, end=1.0, text="Hello"),
Word(start=1.0, end=2.0, text=" world."),
Word(start=i, end=i + 1, text="Hello", speaker=0),
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
],
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.get_instance"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioTranscriptProcessor()
yield
@pytest.fixture
async def dummy_diarization():
from reflector.processors.audio_diarization import AudioDiarizationProcessor
class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
_time_idx = 0
async def _diarize(self, data):
i = self._time_idx
self._time_idx += 2
return [
{"start": i, "end": i + 1, "speaker": 0},
{"start": i + 1, "end": i + 2, "speaker": 1},
]
with patch(
"reflector.processors.audio_diarization_auto"
".AudioDiarizationAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioDiarizationProcessor()
yield
@pytest.fixture
async def dummy_llm():
from reflector.llm.base import LLM
@@ -98,7 +120,17 @@ def ensure_casing():
@pytest.fixture
def sentence_tokenize():
with patch(
"reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize"
"reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize"
) as mock_sent_tokenize:
mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
yield
@pytest.fixture(scope="session")
def celery_enable_logging():
return True
@pytest.fixture(scope="session")
def celery_config():
return {"broker_url": "memory://", "result_backend": "rpc"}

View File

@@ -0,0 +1,161 @@
def test_processor_transcript_segment():
from reflector.processors.types import Transcript, Word
transcript = Transcript(
words=[
Word(text=" the", start=5.12, end=5.48, speaker=0),
Word(text=" different", start=5.48, end=5.8, speaker=0),
Word(text=" projects", start=5.8, end=6.3, speaker=0),
Word(text=" that", start=6.3, end=6.5, speaker=0),
Word(text=" are", start=6.5, end=6.58, speaker=0),
Word(text=" going", start=6.58, end=6.82, speaker=0),
Word(text=" on", start=6.82, end=7.26, speaker=0),
Word(text=" to", start=7.26, end=7.4, speaker=0),
Word(text=" give", start=7.4, end=7.54, speaker=0),
Word(text=" you", start=7.54, end=7.9, speaker=0),
Word(text=" context", start=7.9, end=8.24, speaker=0),
Word(text=" and", start=8.24, end=8.66, speaker=0),
Word(text=" I", start=8.66, end=8.72, speaker=0),
Word(text=" think", start=8.72, end=8.82, speaker=0),
Word(text=" that's", start=8.82, end=9.04, speaker=0),
Word(text=" what", start=9.04, end=9.12, speaker=0),
Word(text=" we'll", start=9.12, end=9.24, speaker=0),
Word(text=" do", start=9.24, end=9.32, speaker=0),
Word(text=" this", start=9.32, end=9.52, speaker=0),
Word(text=" week.", start=9.52, end=9.76, speaker=0),
Word(text=" Um,", start=10.24, end=10.62, speaker=0),
Word(text=" so,", start=11.36, end=11.94, speaker=0),
Word(text=" um,", start=12.46, end=12.92, speaker=0),
Word(text=" what", start=13.74, end=13.94, speaker=0),
Word(text=" we're", start=13.94, end=14.1, speaker=0),
Word(text=" going", start=14.1, end=14.24, speaker=0),
Word(text=" to", start=14.24, end=14.34, speaker=0),
Word(text=" do", start=14.34, end=14.8, speaker=0),
Word(text=" at", start=14.8, end=14.98, speaker=0),
Word(text=" H", start=14.98, end=15.04, speaker=0),
Word(text=" of", start=15.04, end=15.16, speaker=0),
Word(text=" you,", start=15.16, end=15.26, speaker=0),
Word(text=" maybe.", start=15.28, end=15.34, speaker=0),
Word(text=" you", start=15.36, end=15.52, speaker=0),
Word(text=" can", start=15.52, end=15.62, speaker=0),
Word(text=" introduce", start=15.62, end=15.98, speaker=0),
Word(text=" yourself", start=15.98, end=16.42, speaker=0),
Word(text=" to", start=16.42, end=16.68, speaker=0),
Word(text=" the", start=16.68, end=16.72, speaker=0),
Word(text=" team", start=16.72, end=17.52, speaker=0),
Word(text=" quickly", start=17.87, end=18.65, speaker=0),
Word(text=" and", start=18.65, end=19.63, speaker=0),
Word(text=" Oh,", start=20.91, end=21.55, speaker=0),
Word(text=" this", start=21.67, end=21.83, speaker=0),
Word(text=" is", start=21.83, end=22.17, speaker=0),
Word(text=" a", start=22.17, end=22.35, speaker=0),
Word(text=" reflector", start=22.35, end=22.89, speaker=0),
Word(text=" translating", start=22.89, end=23.33, speaker=0),
Word(text=" into", start=23.33, end=23.73, speaker=0),
Word(text=" French", start=23.73, end=23.95, speaker=0),
Word(text=" for", start=23.95, end=24.15, speaker=0),
Word(text=" me.", start=24.15, end=24.43, speaker=0),
Word(text=" This", start=27.87, end=28.19, speaker=0),
Word(text=" is", start=28.19, end=28.45, speaker=0),
Word(text=" all", start=28.45, end=28.79, speaker=0),
Word(text=" the", start=28.79, end=29.15, speaker=0),
Word(text=" way,", start=29.15, end=29.15, speaker=0),
Word(text=" please,", start=29.53, end=29.59, speaker=0),
Word(text=" please,", start=29.73, end=29.77, speaker=0),
Word(text=" please,", start=29.77, end=29.83, speaker=0),
Word(text=" please.", start=29.83, end=29.97, speaker=0),
Word(text=" Yeah,", start=29.97, end=30.17, speaker=0),
Word(text=" that's", start=30.25, end=30.33, speaker=0),
Word(text=" all", start=30.33, end=30.49, speaker=0),
Word(text=" it's", start=30.49, end=30.69, speaker=0),
Word(text=" right.", start=30.69, end=30.69, speaker=0),
Word(text=" Right.", start=30.72, end=30.98, speaker=1),
Word(text=" Yeah,", start=31.56, end=31.72, speaker=2),
Word(text=" that's", start=31.86, end=31.98, speaker=2),
Word(text=" right.", start=31.98, end=32.2, speaker=2),
Word(text=" Because", start=32.38, end=32.46, speaker=0),
Word(text=" I", start=32.46, end=32.58, speaker=0),
Word(text=" thought", start=32.58, end=32.78, speaker=0),
Word(text=" I'd", start=32.78, end=33.0, speaker=0),
Word(text=" be", start=33.0, end=33.02, speaker=0),
Word(text=" able", start=33.02, end=33.18, speaker=0),
Word(text=" to", start=33.18, end=33.34, speaker=0),
Word(text=" pull", start=33.34, end=33.52, speaker=0),
Word(text=" out.", start=33.52, end=33.68, speaker=0),
Word(text=" Yeah,", start=33.7, end=33.9, speaker=0),
Word(text=" that", start=33.9, end=34.02, speaker=0),
Word(text=" was", start=34.02, end=34.24, speaker=0),
Word(text=" the", start=34.24, end=34.34, speaker=0),
Word(text=" one", start=34.34, end=34.44, speaker=0),
Word(text=" before", start=34.44, end=34.7, speaker=0),
Word(text=" that.", start=34.7, end=35.24, speaker=0),
Word(text=" Friends,", start=35.84, end=36.46, speaker=0),
Word(text=" if", start=36.64, end=36.7, speaker=0),
Word(text=" you", start=36.7, end=36.7, speaker=0),
Word(text=" have", start=36.7, end=37.24, speaker=0),
Word(text=" tell", start=37.24, end=37.44, speaker=0),
Word(text=" us", start=37.44, end=37.68, speaker=0),
Word(text=" if", start=37.68, end=37.82, speaker=0),
Word(text=" it's", start=37.82, end=38.04, speaker=0),
Word(text=" good,", start=38.04, end=38.58, speaker=0),
Word(text=" exceptionally", start=38.96, end=39.1, speaker=0),
Word(text=" good", start=39.1, end=39.6, speaker=0),
Word(text=" and", start=39.6, end=39.86, speaker=0),
Word(text=" tell", start=39.86, end=40.0, speaker=0),
Word(text=" us", start=40.0, end=40.06, speaker=0),
Word(text=" when", start=40.06, end=40.2, speaker=0),
Word(text=" it's", start=40.2, end=40.34, speaker=0),
Word(text=" exceptionally", start=40.34, end=40.6, speaker=0),
Word(text=" bad.", start=40.6, end=40.94, speaker=0),
Word(text=" We", start=40.96, end=41.26, speaker=0),
Word(text=" don't", start=41.26, end=41.44, speaker=0),
Word(text=" need", start=41.44, end=41.66, speaker=0),
Word(text=" that", start=41.66, end=41.82, speaker=0),
Word(text=" at", start=41.82, end=41.94, speaker=0),
Word(text=" the", start=41.94, end=41.98, speaker=0),
Word(text=" middle", start=41.98, end=42.18, speaker=0),
Word(text=" of", start=42.18, end=42.36, speaker=0),
Word(text=" age.", start=42.36, end=42.7, speaker=0),
Word(text=" Okay,", start=43.26, end=43.44, speaker=0),
Word(text=" yeah,", start=43.68, end=43.76, speaker=0),
Word(text=" that", start=43.78, end=44.3, speaker=0),
Word(text=" sentence", start=44.3, end=44.72, speaker=0),
Word(text=" right", start=44.72, end=45.1, speaker=0),
Word(text=" before.", start=45.1, end=45.56, speaker=0),
Word(text=" it", start=46.08, end=46.36, speaker=0),
Word(text=" realizing", start=46.36, end=47.0, speaker=0),
Word(text=" that", start=47.0, end=47.28, speaker=0),
Word(text=" I", start=47.28, end=47.28, speaker=0),
Word(text=" was", start=47.28, end=47.64, speaker=0),
Word(text=" saying", start=47.64, end=48.06, speaker=0),
Word(text=" that", start=48.06, end=48.44, speaker=0),
Word(text=" it's", start=48.44, end=48.54, speaker=0),
Word(text=" interesting", start=48.54, end=48.78, speaker=0),
Word(text=" that", start=48.78, end=48.96, speaker=0),
Word(text=" it's", start=48.96, end=49.08, speaker=0),
Word(text=" translating", start=49.08, end=49.32, speaker=0),
Word(text=" the", start=49.32, end=49.56, speaker=0),
Word(text=" French", start=49.56, end=49.76, speaker=0),
Word(text=" was", start=49.76, end=50.16, speaker=0),
Word(text=" completely", start=50.16, end=50.4, speaker=0),
Word(text=" wrong.", start=50.4, end=50.7, speaker=0),
]
)
segments = transcript.as_segments()
assert len(segments) == 7
# check speaker order
assert segments[0].speaker == 0
assert segments[1].speaker == 0
assert segments[2].speaker == 0
assert segments[3].speaker == 1
assert segments[4].speaker == 2
assert segments[5].speaker == 0
assert segments[6].speaker == 0
# check the timing (first entry, and first of others speakers)
assert segments[0].start == 5.12
assert segments[3].start == 30.72
assert segments[4].start == 31.56
assert segments[5].start == 32.38

View File

@@ -1,3 +1,4 @@
import asyncio
import pytest
import httpx
from reflector.utils.retry import (
@@ -8,6 +9,31 @@ from reflector.utils.retry import (
)
@pytest.mark.asyncio
async def test_retry_redirect(httpx_mock):
async def custom_response(request: httpx.Request):
if request.url.path == "/hello":
await asyncio.sleep(1)
return httpx.Response(
status_code=303, headers={"location": "https://test_url/redirected"}
)
elif request.url.path == "/redirected":
return httpx.Response(status_code=200, json={"hello": "world"})
else:
raise Exception("Unexpected path")
httpx_mock.add_callback(custom_response)
async with httpx.AsyncClient() as client:
# timeout should not triggered, as it will end up ok
# even though the first request is a 303 and took more that 0.5
resp = await retry(client.get)(
"https://test_url/hello",
retry_timeout=0.5,
follow_redirects=True,
)
assert resp.json() == {"hello": "world"}
@pytest.mark.asyncio
async def test_retry_httpx(httpx_mock):
# this code should be force a retry

View File

@@ -32,7 +32,7 @@ class ThreadedUvicorn:
@pytest.fixture
async def appserver(tmpdir):
async def appserver(tmpdir, celery_session_app, celery_session_worker):
from reflector.settings import settings
from reflector.app import app
@@ -52,12 +52,20 @@ async def appserver(tmpdir):
settings.DATA_DIR = DATA_DIR
@pytest.fixture(scope="session")
def celery_includes():
return ["reflector.pipelines.main_live_pipeline"]
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(
tmpdir,
dummy_llm,
dummy_transcript,
dummy_processors,
dummy_diarization,
ensure_casing,
appserver,
sentence_tokenize,
@@ -95,6 +103,7 @@ async def test_transcript_rtc_and_websocket(
print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
print("Test websocket: TASK CREATED", websocket_task)
# create stream client
import argparse
@@ -121,14 +130,20 @@ async def test_transcript_rtc_and_websocket(
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2)
await client.stop()
# wait the processing to finish
await asyncio.sleep(2)
timeout = 20
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended")
# stop websocket task
websocket_task.cancel()
@@ -169,29 +184,28 @@ async def test_transcript_rtc_and_websocket(
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
assert statuses.index("recording") < statuses.index("processing")
assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# check that transcript status in model is updated
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mpeg"
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(
tmpdir,
dummy_llm,
dummy_transcript,
dummy_processors,
dummy_diarization,
ensure_casing,
appserver,
sentence_tokenize,
@@ -232,6 +246,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
print("Test websocket: TASK CREATED", websocket_task)
# create stream client
import argparse
@@ -265,6 +280,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
await client.stop()
# wait the processing to finish
timeout = 20
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] == "ended":
break
await asyncio.sleep(1)
if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended")
await asyncio.sleep(2)
# stop websocket task
@@ -306,7 +333,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
assert statuses.index("recording") < statuses.index("processing")
assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"

View File

@@ -3,6 +3,7 @@ import Modal from "../modal";
import useTranscript from "../useTranscript";
import useTopics from "../useTopics";
import useWaveform from "../useWaveform";
import useMp3 from "../useMp3";
import { TopicList } from "../topicList";
import Recorder from "../recorder";
import { Topic } from "../webSocketTypes";
@@ -28,6 +29,7 @@ export default function TranscriptDetails(details: TranscriptDetails) {
const topics = useTopics(protectedPath, transcriptId);
const waveform = useWaveform(protectedPath, transcriptId);
const useActiveTopic = useState<Topic | null>(null);
const mp3 = useMp3(api, transcriptId);
if (transcript?.error /** || topics?.error || waveform?.error **/) {
return (
@@ -62,6 +64,7 @@ export default function TranscriptDetails(details: TranscriptDetails) {
waveform={waveform?.waveform}
isPastMeeting={true}
transcriptId={transcript?.response?.id}
mp3Blob={mp3.blob}
/>
)}
</div>

View File

@@ -30,6 +30,7 @@ type RecorderProps = {
waveform?: AudioWaveform | null;
isPastMeeting: boolean;
transcriptId?: string | null;
mp3Blob?: Blob | null;
};
export default function Recorder(props: RecorderProps) {
@@ -108,11 +109,7 @@ export default function Recorder(props: RecorderProps) {
if (waveformRef.current) {
const _wavesurfer = WaveSurfer.create({
container: waveformRef.current,
url: props.transcriptId
? `${process.env.NEXT_PUBLIC_API_URL}/v1/transcripts/${props.transcriptId}/audio/mp3`
: undefined,
peaks: props.waveform?.data,
hideScrollbar: true,
autoCenter: true,
barWidth: 2,
@@ -146,6 +143,10 @@ export default function Recorder(props: RecorderProps) {
if (props.isPastMeeting) _wavesurfer.toggleInteraction(true);
if (props.mp3Blob) {
_wavesurfer.loadBlob(props.mp3Blob);
}
setWavesurfer(_wavesurfer);
return () => {
@@ -157,6 +158,12 @@ export default function Recorder(props: RecorderProps) {
}
}, []);
useEffect(() => {
if (!wavesurfer) return;
if (!props.mp3Blob) return;
wavesurfer.loadBlob(props.mp3Blob);
}, [props.mp3Blob]);
useEffect(() => {
topicsRef.current = props.topics;
if (!isRecording) renderMarkers();

View File

@@ -7,6 +7,7 @@ import {
import { formatTime } from "../../lib/time";
import ScrollToBottom from "./scrollToBottom";
import { Topic } from "./webSocketTypes";
import { generateHighContrastColor } from "../../lib/utils";
type TopicListProps = {
topics: Topic[];
@@ -103,7 +104,37 @@ export function TopicList({
/>
</div>
{activeTopic?.id == topic.id && (
<div className="p-2">{topic.transcript}</div>
<div className="p-2">
{topic.segments ? (
<>
{topic.segments.map((segment, index: number) => (
<p
key={index}
className="text-left text-slate-500 text-sm md:text-base"
>
<span className="font-mono text-slate-500">
[{formatTime(segment.start)}]
</span>
<span
className="font-bold text-slate-500"
style={{
color: generateHighContrastColor(
`Speaker ${segment.speaker}`,
[96, 165, 250],
),
}}
>
{" "}
(Speaker {segment.speaker}):
</span>{" "}
<span>{segment.text}</span>
</p>
))}
</>
) : (
<>{topic.transcript}</>
)}
</div>
)}
</button>
))}

View File

@@ -1,36 +1,64 @@
import { useEffect, useState } from "react";
import { useContext, useEffect, useState } from "react";
import {
DefaultApi,
V1TranscriptGetAudioMp3Request,
// V1TranscriptGetAudioMp3Request,
} from "../../api/apis/DefaultApi";
import {} from "../../api";
import { useError } from "../../(errors)/errorContext";
import { DomainContext } from "../domainContext";
type Mp3Response = {
url: string | null;
blob: Blob | null;
loading: boolean;
error: Error | null;
};
const useMp3 = (api: DefaultApi, id: string): Mp3Response => {
const [url, setUrl] = useState<string | null>(null);
const [blob, setBlob] = useState<Blob | null>(null);
const [loading, setLoading] = useState<boolean>(false);
const [error, setErrorState] = useState<Error | null>(null);
const { setError } = useError();
const { api_url } = useContext(DomainContext);
const getMp3 = (id: string) => {
if (!id) throw new Error("Transcript ID is required to get transcript Mp3");
if (!id) return;
setLoading(true);
const requestParameters: V1TranscriptGetAudioMp3Request = {
transcriptId: id,
};
api
.v1TranscriptGetAudioMp3(requestParameters)
.then((result) => {
setUrl(result);
setLoading(false);
console.debug("Transcript Mp3 loaded:", result);
// XXX Current API interface does not output a blob, we need to to is manually
// const requestParameters: V1TranscriptGetAudioMp3Request = {
// transcriptId: id,
// };
// api
// .v1TranscriptGetAudioMp3(requestParameters)
// .then((result) => {
// setUrl(result);
// setLoading(false);
// console.debug("Transcript Mp3 loaded:", result);
// })
// .catch((err) => {
// setError(err);
// setErrorState(err);
// });
const localUrl = `${api_url}/v1/transcripts/${id}/audio/mp3`;
if (localUrl == url) return;
const headers = new Headers();
if (api.configuration.configuration.accessToken) {
headers.set("Authorization", api.configuration.configuration.accessToken);
}
fetch(localUrl, {
method: "GET",
headers,
})
.then((response) => {
setUrl(localUrl);
response.blob().then((blob) => {
setBlob(blob);
setLoading(false);
});
})
.catch((err) => {
setError(err);
@@ -42,7 +70,7 @@ const useMp3 = (api: DefaultApi, id: string): Mp3Response => {
getMp3(id);
}, [id]);
return { url, loading, error };
return { url, blob, loading, error };
};
export default useMp3;

View File

@@ -58,6 +58,39 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
title: "Topic 1: Introduction to Quantum Mechanics",
transcript:
"A brief overview of quantum mechanics and its principles.",
segments: [
{
speaker: 1,
start: 0,
text: "This is the transcription of an example title",
},
{
speaker: 2,
start: 10,
text: "This is the second speaker",
},
,
{
speaker: 3,
start: 90,
text: "This is the third speaker",
},
{
speaker: 4,
start: 90,
text: "This is the fourth speaker",
},
{
speaker: 5,
start: 123,
text: "This is the fifth speaker",
},
{
speaker: 6,
start: 300,
text: "This is the sixth speaker",
},
],
},
{
id: "2",
@@ -66,6 +99,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
title: "Topic 2: Machine Learning Algorithms",
transcript:
"Understanding the different types of machine learning algorithms.",
segments: [
{
speaker: 1,
start: 0,
text: "This is the transcription of an example title",
},
{
speaker: 2,
start: 10,
text: "This is the second speaker",
},
],
},
{
id: "3",
@@ -73,6 +118,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 3",
title: "Topic 3: Mental Health Awareness",
transcript: "Ways to improve mental health and reduce stigma.",
segments: [
{
speaker: 1,
start: 0,
text: "This is the transcription of an example title",
},
{
speaker: 2,
start: 10,
text: "This is the second speaker",
},
],
},
{
id: "4",
@@ -80,6 +137,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 4",
title: "Topic 4: Basics of Productivity",
transcript: "Tips and tricks to increase daily productivity.",
segments: [
{
speaker: 1,
start: 0,
text: "This is the transcription of an example title",
},
{
speaker: 2,
start: 10,
text: "This is the second speaker",
},
],
},
{
id: "5",
@@ -88,6 +157,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
title: "Topic 5: Future of Aviation",
transcript:
"Exploring the advancements and possibilities in aviation.",
segments: [
{
speaker: 1,
start: 0,
text: "This is the transcription of an example title",
},
{
speaker: 2,
start: 10,
text: "This is the second speaker",
},
],
},
]);
@@ -106,6 +187,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
"Topic 1: Introduction to Quantum Mechanics, a brief overview of quantum mechanics and its principles.",
transcript:
"A brief overview of quantum mechanics and its principles.",
segments: [
{
speaker: 1,
start: 0,
text: "This is the transcription of an example title",
},
{
speaker: 2,
start: 10,
text: "This is the second speaker",
},
],
},
{
id: "2",
@@ -115,6 +208,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
"Topic 2: Machine Learning Algorithms, understanding the different types of machine learning algorithms.",
transcript:
"Understanding the different types of machine learning algorithms.",
segments: [
{
speaker: 1,
start: 0,
text: "This is the transcription of an example title",
},
{
speaker: 2,
start: 10,
text: "This is the second speaker",
},
],
},
{
id: "3",
@@ -123,6 +228,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
title:
"Topic 3: Mental Health Awareness, ways to improve mental health and reduce stigma.",
transcript: "Ways to improve mental health and reduce stigma.",
segments: [
{
speaker: 1,
start: 0,
text: "This is the transcription of an example title",
},
{
speaker: 2,
start: 10,
text: "This is the second speaker",
},
],
},
{
id: "4",
@@ -131,6 +248,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
title:
"Topic 4: Basics of Productivity, tips and tricks to increase daily productivity.",
transcript: "Tips and tricks to increase daily productivity.",
segments: [
{
speaker: 1,
start: 0,
text: "This is the transcription of an example title",
},
{
speaker: 2,
start: 10,
text: "This is the second speaker",
},
],
},
{
id: "5",
@@ -140,6 +269,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
"Topic 5: Future of Aviation, exploring the advancements and possibilities in aviation.",
transcript:
"Exploring the advancements and possibilities in aviation.",
segments: [
{
speaker: 1,
start: 0,
text: "This is the transcription of an example title",
},
{
speaker: 2,
start: 10,
text: "This is the second speaker",
},
],
},
]);
@@ -173,7 +314,17 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
break;
case "TOPIC":
setTopics((prevTopics) => [...prevTopics, message.data]);
setTopics((prevTopics) => {
const topic = message.data as Topic;
const index = prevTopics.findIndex(
(prevTopic) => prevTopic.id === topic.id,
);
if (index >= 0) {
prevTopics[index] = topic;
return prevTopics;
}
return [...prevTopics, topic];
});
console.debug("TOPIC event:", message.data);
break;

View File

@@ -1,10 +1,6 @@
export type Topic = {
timestamp: number;
title: string;
transcript: string;
summary: string;
id: string;
};
import { GetTranscriptTopic } from "../../api";
export type Topic = GetTranscriptTopic;
export type Transcript = {
text: string;

View File

@@ -5,10 +5,11 @@ models/AudioWaveform.ts
models/CreateTranscript.ts
models/DeletionStatus.ts
models/GetTranscript.ts
models/GetTranscriptSegmentTopic.ts
models/GetTranscriptTopic.ts
models/HTTPValidationError.ts
models/PageGetTranscript.ts
models/RtcOffer.ts
models/TranscriptTopic.ts
models/UpdateTranscript.ts
models/UserInfo.ts
models/ValidationError.ts

View File

@@ -42,10 +42,6 @@ import {
UpdateTranscriptToJSON,
} from "../models";
export interface RtcOfferRequest {
rtcOffer: RtcOffer;
}
export interface V1TranscriptDeleteRequest {
transcriptId: any;
}
@@ -56,6 +52,7 @@ export interface V1TranscriptGetRequest {
export interface V1TranscriptGetAudioMp3Request {
transcriptId: any;
token?: any;
}
export interface V1TranscriptGetAudioWaveformRequest {
@@ -132,58 +129,6 @@ export class DefaultApi extends runtime.BaseAPI {
return await response.value();
}
/**
* Rtc Offer
*/
async rtcOfferRaw(
requestParameters: RtcOfferRequest,
initOverrides?: RequestInit | runtime.InitOverrideFunction,
): Promise<runtime.ApiResponse<any>> {
if (
requestParameters.rtcOffer === null ||
requestParameters.rtcOffer === undefined
) {
throw new runtime.RequiredError(
"rtcOffer",
"Required parameter requestParameters.rtcOffer was null or undefined when calling rtcOffer.",
);
}
const queryParameters: any = {};
const headerParameters: runtime.HTTPHeaders = {};
headerParameters["Content-Type"] = "application/json";
const response = await this.request(
{
path: `/offer`,
method: "POST",
headers: headerParameters,
query: queryParameters,
body: RtcOfferToJSON(requestParameters.rtcOffer),
},
initOverrides,
);
if (this.isJsonMime(response.headers.get("content-type"))) {
return new runtime.JSONApiResponse<any>(response);
} else {
return new runtime.TextApiResponse(response) as any;
}
}
/**
* Rtc Offer
*/
async rtcOffer(
requestParameters: RtcOfferRequest,
initOverrides?: RequestInit | runtime.InitOverrideFunction,
): Promise<any> {
const response = await this.rtcOfferRaw(requestParameters, initOverrides);
return await response.value();
}
/**
* Transcript Delete
*/
@@ -325,6 +270,10 @@ export class DefaultApi extends runtime.BaseAPI {
const queryParameters: any = {};
if (requestParameters.token !== undefined) {
queryParameters["token"] = requestParameters.token;
}
const headerParameters: runtime.HTTPHeaders = {};
if (this.configuration && this.configuration.accessToken) {

View File

@@ -0,0 +1,88 @@
/* tslint:disable */
/* eslint-disable */
/**
* FastAPI
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 0.1.0
*
*
* NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
* https://openapi-generator.tech
* Do not edit the class manually.
*/
import { exists, mapValues } from "../runtime";
/**
*
* @export
* @interface GetTranscriptSegmentTopic
*/
export interface GetTranscriptSegmentTopic {
/**
*
* @type {any}
* @memberof GetTranscriptSegmentTopic
*/
text: any | null;
/**
*
* @type {any}
* @memberof GetTranscriptSegmentTopic
*/
start: any | null;
/**
*
* @type {any}
* @memberof GetTranscriptSegmentTopic
*/
speaker: any | null;
}
/**
* Check if a given object implements the GetTranscriptSegmentTopic interface.
*/
export function instanceOfGetTranscriptSegmentTopic(value: object): boolean {
let isInstance = true;
isInstance = isInstance && "text" in value;
isInstance = isInstance && "start" in value;
isInstance = isInstance && "speaker" in value;
return isInstance;
}
export function GetTranscriptSegmentTopicFromJSON(
json: any,
): GetTranscriptSegmentTopic {
return GetTranscriptSegmentTopicFromJSONTyped(json, false);
}
export function GetTranscriptSegmentTopicFromJSONTyped(
json: any,
ignoreDiscriminator: boolean,
): GetTranscriptSegmentTopic {
if (json === undefined || json === null) {
return json;
}
return {
text: json["text"],
start: json["start"],
speaker: json["speaker"],
};
}
export function GetTranscriptSegmentTopicToJSON(
value?: GetTranscriptSegmentTopic | null,
): any {
if (value === undefined) {
return undefined;
}
if (value === null) {
return null;
}
return {
text: value.text,
start: value.start,
speaker: value.speaker,
};
}

View File

@@ -0,0 +1,112 @@
/* tslint:disable */
/* eslint-disable */
/**
* FastAPI
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 0.1.0
*
*
* NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
* https://openapi-generator.tech
* Do not edit the class manually.
*/
import { exists, mapValues } from "../runtime";
/**
*
* @export
* @interface GetTranscriptTopic
*/
export interface GetTranscriptTopic {
/**
*
* @type {any}
* @memberof GetTranscriptTopic
*/
id: any | null;
/**
*
* @type {any}
* @memberof GetTranscriptTopic
*/
title: any | null;
/**
*
* @type {any}
* @memberof GetTranscriptTopic
*/
summary: any | null;
/**
*
* @type {any}
* @memberof GetTranscriptTopic
*/
timestamp: any | null;
/**
*
* @type {any}
* @memberof GetTranscriptTopic
*/
transcript: any | null;
/**
*
* @type {any}
* @memberof GetTranscriptTopic
*/
segments?: any | null;
}
/**
* Check if a given object implements the GetTranscriptTopic interface.
*/
export function instanceOfGetTranscriptTopic(value: object): boolean {
let isInstance = true;
isInstance = isInstance && "id" in value;
isInstance = isInstance && "title" in value;
isInstance = isInstance && "summary" in value;
isInstance = isInstance && "timestamp" in value;
isInstance = isInstance && "transcript" in value;
return isInstance;
}
export function GetTranscriptTopicFromJSON(json: any): GetTranscriptTopic {
return GetTranscriptTopicFromJSONTyped(json, false);
}
export function GetTranscriptTopicFromJSONTyped(
json: any,
ignoreDiscriminator: boolean,
): GetTranscriptTopic {
if (json === undefined || json === null) {
return json;
}
return {
id: json["id"],
title: json["title"],
summary: json["summary"],
timestamp: json["timestamp"],
transcript: json["transcript"],
segments: !exists(json, "segments") ? undefined : json["segments"],
};
}
export function GetTranscriptTopicToJSON(
value?: GetTranscriptTopic | null,
): any {
if (value === undefined) {
return undefined;
}
if (value === null) {
return null;
}
return {
id: value.id,
title: value.title,
summary: value.summary,
timestamp: value.timestamp,
transcript: value.transcript,
segments: value.segments,
};
}

View File

@@ -0,0 +1,88 @@
/* tslint:disable */
/* eslint-disable */
/**
* FastAPI
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 0.1.0
*
*
* NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
* https://openapi-generator.tech
* Do not edit the class manually.
*/
import { exists, mapValues } from "../runtime";
/**
*
* @export
* @interface TranscriptSegmentTopic
*/
export interface TranscriptSegmentTopic {
/**
*
* @type {any}
* @memberof TranscriptSegmentTopic
*/
speaker: any | null;
/**
*
* @type {any}
* @memberof TranscriptSegmentTopic
*/
text: any | null;
/**
*
* @type {any}
* @memberof TranscriptSegmentTopic
*/
timestamp: any | null;
}
/**
* Check if a given object implements the TranscriptSegmentTopic interface.
*/
export function instanceOfTranscriptSegmentTopic(value: object): boolean {
let isInstance = true;
isInstance = isInstance && "speaker" in value;
isInstance = isInstance && "text" in value;
isInstance = isInstance && "timestamp" in value;
return isInstance;
}
export function TranscriptSegmentTopicFromJSON(
json: any,
): TranscriptSegmentTopic {
return TranscriptSegmentTopicFromJSONTyped(json, false);
}
export function TranscriptSegmentTopicFromJSONTyped(
json: any,
ignoreDiscriminator: boolean,
): TranscriptSegmentTopic {
if (json === undefined || json === null) {
return json;
}
return {
speaker: json["speaker"],
text: json["text"],
timestamp: json["timestamp"],
};
}
export function TranscriptSegmentTopicToJSON(
value?: TranscriptSegmentTopic | null,
): any {
if (value === undefined) {
return undefined;
}
if (value === null) {
return null;
}
return {
speaker: value.speaker,
text: value.text,
timestamp: value.timestamp,
};
}

View File

@@ -42,13 +42,13 @@ export interface TranscriptTopic {
* @type {any}
* @memberof TranscriptTopic
*/
transcript?: any | null;
timestamp: any | null;
/**
*
* @type {any}
* @memberof TranscriptTopic
*/
timestamp: any | null;
segments?: any | null;
}
/**
@@ -78,8 +78,8 @@ export function TranscriptTopicFromJSONTyped(
id: !exists(json, "id") ? undefined : json["id"],
title: json["title"],
summary: json["summary"],
transcript: !exists(json, "transcript") ? undefined : json["transcript"],
timestamp: json["timestamp"],
segments: !exists(json, "segments") ? undefined : json["segments"],
};
}
@@ -94,7 +94,7 @@ export function TranscriptTopicToJSON(value?: TranscriptTopic | null): any {
id: value.id,
title: value.title,
summary: value.summary,
transcript: value.transcript,
timestamp: value.timestamp,
segments: value.segments,
};
}

View File

@@ -4,10 +4,11 @@ export * from "./AudioWaveform";
export * from "./CreateTranscript";
export * from "./DeletionStatus";
export * from "./GetTranscript";
export * from "./GetTranscriptSegmentTopic";
export * from "./GetTranscriptTopic";
export * from "./HTTPValidationError";
export * from "./PageGetTranscript";
export * from "./RtcOffer";
export * from "./TranscriptTopic";
export * from "./UpdateTranscript";
export * from "./UserInfo";
export * from "./ValidationError";

View File

@@ -1,3 +1,123 @@
export function isDevelopment() {
return process.env.NEXT_PUBLIC_ENV === "development";
}
// Function to calculate WCAG contrast ratio
export const getContrastRatio = (
foreground: [number, number, number],
background: [number, number, number],
) => {
const [r1, g1, b1] = foreground;
const [r2, g2, b2] = background;
const lum1 =
0.2126 * Math.pow(r1 / 255, 2.2) +
0.7152 * Math.pow(g1 / 255, 2.2) +
0.0722 * Math.pow(b1 / 255, 2.2);
const lum2 =
0.2126 * Math.pow(r2 / 255, 2.2) +
0.7152 * Math.pow(g2 / 255, 2.2) +
0.0722 * Math.pow(b2 / 255, 2.2);
return (Math.max(lum1, lum2) + 0.05) / (Math.min(lum1, lum2) + 0.05);
};
// Function to hash string into 32-bit integer
// 🔴 DO NOT USE FOR CRYPTOGRAPHY PURPOSES 🔴
export function murmurhash3_32_gc(key: string, seed: number = 0) {
let remainder, bytes, h1, h1b, c1, c2, k1, i;
remainder = key.length & 3; // key.length % 4
bytes = key.length - remainder;
h1 = seed;
c1 = 0xcc9e2d51;
c2 = 0x1b873593;
i = 0;
while (i < bytes) {
k1 =
(key.charCodeAt(i) & 0xff) |
((key.charCodeAt(++i) & 0xff) << 8) |
((key.charCodeAt(++i) & 0xff) << 16) |
((key.charCodeAt(++i) & 0xff) << 24);
++i;
k1 =
((k1 & 0xffff) * c1 + ((((k1 >>> 16) * c1) & 0xffff) << 16)) & 0xffffffff;
k1 = (k1 << 15) | (k1 >>> 17);
k1 =
((k1 & 0xffff) * c2 + ((((k1 >>> 16) * c2) & 0xffff) << 16)) & 0xffffffff;
h1 ^= k1;
h1 = (h1 << 13) | (h1 >>> 19);
h1b =
((h1 & 0xffff) * 5 + ((((h1 >>> 16) * 5) & 0xffff) << 16)) & 0xffffffff;
h1 = (h1b & 0xffff) + 0x6b64 + ((((h1b >>> 16) + 0xe654) & 0xffff) << 16);
}
k1 = 0;
switch (remainder) {
case 3:
k1 ^= (key.charCodeAt(i + 2) & 0xff) << 16;
case 2:
k1 ^= (key.charCodeAt(i + 1) & 0xff) << 8;
case 1:
k1 ^= key.charCodeAt(i) & 0xff;
k1 =
((k1 & 0xffff) * c1 + ((((k1 >>> 16) * c1) & 0xffff) << 16)) &
0xffffffff;
k1 = (k1 << 15) | (k1 >>> 17);
k1 =
((k1 & 0xffff) * c2 + ((((k1 >>> 16) * c2) & 0xffff) << 16)) &
0xffffffff;
h1 ^= k1;
}
h1 ^= key.length;
h1 ^= h1 >>> 16;
h1 =
((h1 & 0xffff) * 0x85ebca6b +
((((h1 >>> 16) * 0x85ebca6b) & 0xffff) << 16)) &
0xffffffff;
h1 ^= h1 >>> 13;
h1 =
((h1 & 0xffff) * 0xc2b2ae35 +
((((h1 >>> 16) * 0xc2b2ae35) & 0xffff) << 16)) &
0xffffffff;
h1 ^= h1 >>> 16;
return h1 >>> 0;
}
// Generates a color that is guaranteed to have high contrast with the given background color (optional)
export const generateHighContrastColor = (
name: string,
backgroundColor: [number, number, number] | null = null,
) => {
const hash = murmurhash3_32_gc(name);
let red = (hash & 0xff0000) >> 16;
let green = (hash & 0x00ff00) >> 8;
let blue = hash & 0x0000ff;
const getCssColor = (red: number, green: number, blue: number) =>
`rgb(${red}, ${green}, ${blue})`;
if (!backgroundColor) return getCssColor(red, green, blue);
const contrast = getContrastRatio([red, green, blue], backgroundColor);
// Adjust the color to achieve better contrast if necessary (WCAG recommends at least 4.5:1 for text)
if (contrast < 4.5) {
red = Math.abs(255 - red);
green = Math.abs(255 - green);
blue = Math.abs(255 - blue);
}
return getCssColor(red, green, blue);
};

View File

@@ -35,7 +35,7 @@
"supports-color": "^9.4.0",
"tailwindcss": "^3.3.2",
"typescript": "^5.1.6",
"wavesurfer.js": "^7.0.3"
"wavesurfer.js": "^7.4.2"
},
"main": "index.js",
"repository": "https://github.com/Monadical-SAS/reflector-ui.git",

View File

@@ -1,7 +1,7 @@
import type { NextPage } from "next";
const Forbidden: NextPage = () => {
return <h2>Sorry, you are not authorized to access this page.</h2>;
return <h2>Sorry, you are not authorized to access this page</h2>;
};
export default Forbidden;

View File

@@ -2638,10 +2638,10 @@ watchpack@2.4.0:
glob-to-regexp "^0.4.1"
graceful-fs "^4.1.2"
wavesurfer.js@^7.0.3:
version "7.0.3"
resolved "https://registry.npmjs.org/wavesurfer.js/-/wavesurfer.js-7.0.3.tgz"
integrity sha512-gJ3P+Bd3Q4E8qETjjg0pneaVqm2J7jegG2Cc6vqEF5YDDKQ3m8sKsvVfgVhJkacKkO9jFAGDu58Hw4zLr7xD0A==
wavesurfer.js@^7.4.2:
version "7.4.2"
resolved "https://registry.yarnpkg.com/wavesurfer.js/-/wavesurfer.js-7.4.2.tgz#59f5c87193d4eeeb199858688ddac1ad7ba86b3a"
integrity sha512-4pNQ1porOCUBYBmd2F1TqVuBnB2wBPipaw2qI920zYLuPnada0Rd1CURgh8HRuPGKxijj2iyZDFN2UZwsaEuhA==
wcwidth@>=1.0.1, wcwidth@^1.0.1:
version "1.0.1"