diff --git a/.github/workflows/test_server.yml b/.github/workflows/test_server.yml index 71bed5d0..9f3b9a6a 100644 --- a/.github/workflows/test_server.yml +++ b/.github/workflows/test_server.yml @@ -11,6 +11,11 @@ on: jobs: pytest: runs-on: ubuntu-latest + services: + redis: + image: redis:6 + ports: + - 6379:6379 steps: - uses: actions/checkout@v3 - name: Install poetry diff --git a/README.md b/README.md index 22651cc6..cb75c76b 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ The project architecture consists of three primary components: * **Front-End**: NextJS React project hosted on Vercel, located in `www/`. * **Back-End**: Python server that offers an API and data persistence, found in `server/`. -* **AI Models**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations. +* **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations. It also uses https://github.com/fief-dev for authentication, and Vercel for deployment and configuration of the front-end. 
@@ -120,6 +120,9 @@ TRANSCRIPT_MODAL_API_KEY= LLM_BACKEND=modal LLM_URL=https://monadical-sas--reflector-llm-web.modal.run LLM_MODAL_API_KEY= +TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run +ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run +DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run AUTH_BACKEND=fief AUTH_FIEF_URL=https://auth.reflector.media/reflector-local @@ -138,6 +141,10 @@ Use: poetry run python3 -m reflector.app ``` +And start the background worker + +celery -A reflector.worker.app worker --loglevel=info + #### Using docker Use: @@ -161,4 +168,5 @@ poetry run python -m reflector.tools.process path/to/audio.wav ## AI Models -*(Documentation for this section is pending.)* \ No newline at end of file +*(Documentation for this section is pending.)* + diff --git a/docker-compose.yml b/docker-compose.yml index 934baaac..9e6519af 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,10 +5,19 @@ services: context: server ports: - 1250:1250 - environment: - LLM_URL: "${LLM_URL}" volumes: - model-cache:/root/.cache + environment: {ENTRYPOINT: server} + worker: + build: + context: server + volumes: + - model-cache:/root/.cache + environment: {ENTRYPOINT: worker} + redis: + image: redis:7.2 + ports: + - 6379:6379 web: build: context: www @@ -17,4 +26,3 @@ volumes: model-cache: - diff --git a/server/docker-compose.yml b/server/docker-compose.yml index 374130fa..c8432816 100644 --- a/server/docker-compose.yml +++ b/server/docker-compose.yml @@ -5,11 +5,23 @@ services: context: . ports: - 1250:1250 - environment: - LLM_URL: "${LLM_URL}" - MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}" volumes: - model-cache:/root/.cache + environment: + ENTRYPOINT: server + REDIS_HOST: redis + worker: + build: + context: .
+ volumes: + - model-cache:/root/.cache + environment: + ENTRYPOINT: worker + REDIS_HOST: redis + redis: + image: redis:7.2 + ports: + - 6379:6379 volumes: model-cache: diff --git a/server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py b/server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py new file mode 100644 index 00000000..dffe6fa1 --- /dev/null +++ b/server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py @@ -0,0 +1,80 @@ +"""rename back text to transcript + +Revision ID: 38a927dcb099 +Revises: 9920ecfe2735 +Create Date: 2023-11-02 19:53:09.116240 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, column +from sqlalchemy import select + + +# revision identifiers, used by Alembic. +revision: str = '38a927dcb099' +down_revision: Union[str, None] = '9920ecfe2735' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # bind the engine + bind = op.get_bind() + + # Reflect the table + transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON)) + + # Select all rows from the transcript table + results = bind.execute(select([transcript.c.id, transcript.c.topics])) + + for row in results: + transcript_id = row["id"] + topics_json = row["topics"] + + # Process each topic in the topics JSON array + updated_topics = [] + for topic in topics_json: + if "text" in topic: + # Rename key 'text' back to 'transcript' + topic["transcript"] = topic.pop("text") + updated_topics.append(topic) + + # Update the transcript table + bind.execute( + transcript.update() + .where(transcript.c.id == transcript_id) + .values(topics=updated_topics) + ) + + +def downgrade() -> None: + # bind the engine + bind = op.get_bind() + + # Reflect the table + transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON)) + + # Select all rows 
from the transcript table + results = bind.execute(select([transcript.c.id, transcript.c.topics])) + + for row in results: + transcript_id = row["id"] + topics_json = row["topics"] + + # Process each topic in the topics JSON array + updated_topics = [] + for topic in topics_json: + if "transcript" in topic: + # Rename key 'transcript' to 'text' + topic["text"] = topic.pop("transcript") + updated_topics.append(topic) + + # Update the transcript table + bind.execute( + transcript.update() + .where(transcript.c.id == transcript_id) + .values(topics=updated_topics) + ) diff --git a/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py b/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py new file mode 100644 index 00000000..caecaefd --- /dev/null +++ b/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py @@ -0,0 +1,80 @@ +"""Migration transcript to text field in transcripts table + +Revision ID: 9920ecfe2735 +Revises: 99365b0cd87b +Create Date: 2023-11-02 18:55:17.019498 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, column +from sqlalchemy import select + + +# revision identifiers, used by Alembic. 
+revision: str = "9920ecfe2735" +down_revision: Union[str, None] = "99365b0cd87b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # bind the engine + bind = op.get_bind() + + # Reflect the table + transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON)) + + # Select all rows from the transcript table + results = bind.execute(select([transcript.c.id, transcript.c.topics])) + + for row in results: + transcript_id = row["id"] + topics_json = row["topics"] + + # Process each topic in the topics JSON array + updated_topics = [] + for topic in topics_json: + if "transcript" in topic: + # Rename key 'transcript' to 'text' + topic["text"] = topic.pop("transcript") + updated_topics.append(topic) + + # Update the transcript table + bind.execute( + transcript.update() + .where(transcript.c.id == transcript_id) + .values(topics=updated_topics) + ) + + +def downgrade() -> None: + # bind the engine + bind = op.get_bind() + + # Reflect the table + transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON)) + + # Select all rows from the transcript table + results = bind.execute(select([transcript.c.id, transcript.c.topics])) + + for row in results: + transcript_id = row["id"] + topics_json = row["topics"] + + # Process each topic in the topics JSON array + updated_topics = [] + for topic in topics_json: + if "text" in topic: + # Rename key 'text' back to 'transcript' + topic["transcript"] = topic.pop("text") + updated_topics.append(topic) + + # Update the transcript table + bind.execute( + transcript.update() + .where(transcript.c.id == transcript_id) + .values(topics=updated_topics) + ) diff --git a/server/poetry.lock b/server/poetry.lock index 330c23e3..e72ade57 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -308,6 +308,20 @@ typing-extensions = ">=4" [package.extras] tz = ["python-dateutil"] +[[package]] +name = "amqp" 
+version = "5.1.1" +description = "Low-level AMQP client for Python (fork of amqplib)." +optional = false +python-versions = ">=3.6" +files = [ + {file = "amqp-5.1.1-py3-none-any.whl", hash = "sha256:6f0956d2c23d8fa6e7691934d8c3930eadb44972cbbd1a7ae3a520f735d43359"}, + {file = "amqp-5.1.1.tar.gz", hash = "sha256:2c1b13fecc0893e946c65cbd5f36427861cffa4ea2201d8f6fca22e2a373b5e2"}, +] + +[package.dependencies] +vine = ">=5.0.0" + [[package]] name = "annotated-types" version = "0.6.0" @@ -474,6 +488,17 @@ files = [ {file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"}, ] +[[package]] +name = "billiard" +version = "4.1.0" +description = "Python multiprocessing fork with improvements and bugfixes" +optional = false +python-versions = ">=3.7" +files = [ + {file = "billiard-4.1.0-py3-none-any.whl", hash = "sha256:0f50d6be051c6b2b75bfbc8bfd85af195c5739c281d3f5b86a5640c65563614a"}, + {file = "billiard-4.1.0.tar.gz", hash = "sha256:1ad2eeae8e28053d729ba3373d34d9d6e210f6e4d8bf0a9c64f92bd053f1edf5"}, +] + [[package]] name = "black" version = "23.9.1" @@ -556,6 +581,61 @@ urllib3 = ">=1.25.4,<1.27" [package.extras] crt = ["awscrt (==0.16.26)"] +[[package]] +name = "celery" +version = "5.3.4" +description = "Distributed Task Queue." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "celery-5.3.4-py3-none-any.whl", hash = "sha256:1e6ed40af72695464ce98ca2c201ad0ef8fd192246f6c9eac8bba343b980ad34"}, + {file = "celery-5.3.4.tar.gz", hash = "sha256:9023df6a8962da79eb30c0c84d5f4863d9793a466354cc931d7f72423996de28"}, +] + +[package.dependencies] +billiard = ">=4.1.0,<5.0" +click = ">=8.1.2,<9.0" +click-didyoumean = ">=0.3.0" +click-plugins = ">=1.1.1" +click-repl = ">=0.2.0" +kombu = ">=5.3.2,<6.0" +python-dateutil = ">=2.8.2" +tzdata = ">=2022.7" +vine = ">=5.0.0,<6.0" + +[package.extras] +arangodb = ["pyArango (>=2.0.2)"] +auth = ["cryptography (==41.0.3)"] +azureblockblob = ["azure-storage-blob (>=12.15.0)"] +brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"] +cassandra = ["cassandra-driver (>=3.25.0,<4)"] +consul = ["python-consul2 (==0.1.5)"] +cosmosdbsql = ["pydocumentdb (==2.3.5)"] +couchbase = ["couchbase (>=3.0.0)"] +couchdb = ["pycouchdb (==1.14.2)"] +django = ["Django (>=2.2.28)"] +dynamodb = ["boto3 (>=1.26.143)"] +elasticsearch = ["elasticsearch (<8.0)"] +eventlet = ["eventlet (>=0.32.0)"] +gevent = ["gevent (>=1.5.0)"] +librabbitmq = ["librabbitmq (>=2.0.0)"] +memcache = ["pylibmc (==1.6.3)"] +mongodb = ["pymongo[srv] (>=4.0.2)"] +msgpack = ["msgpack (==1.0.5)"] +pymemcache = ["python-memcached (==1.59)"] +pyro = ["pyro4 (==4.82)"] +pytest = ["pytest-celery (==0.0.0)"] +redis = ["redis (>=4.5.2,!=4.5.5,<5.0.0)"] +s3 = ["boto3 (>=1.26.143)"] +slmq = ["softlayer-messaging (>=1.0.3)"] +solar = ["ephem (==4.1.4)"] +sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] +sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.0)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] +tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"] +yaml = ["PyYAML (>=3.10)"] +zookeeper = ["kazoo (>=1.3.1)"] +zstd = ["zstandard (==0.21.0)"] + [[package]] name = "certifi" version = "2023.7.22" @@ -744,6 +824,55 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} 
+[[package]] +name = "click-didyoumean" +version = "0.3.0" +description = "Enables git-like *did-you-mean* feature in click" +optional = false +python-versions = ">=3.6.2,<4.0.0" +files = [ + {file = "click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"}, + {file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"}, +] + +[package.dependencies] +click = ">=7" + +[[package]] +name = "click-plugins" +version = "1.1.1" +description = "An extension module for click to enable registering CLI commands via setuptools entry-points." +optional = false +python-versions = "*" +files = [ + {file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"}, + {file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"}, +] + +[package.dependencies] +click = ">=4.0" + +[package.extras] +dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"] + +[[package]] +name = "click-repl" +version = "0.3.0" +description = "REPL plugin for Click" +optional = false +python-versions = ">=3.6" +files = [ + {file = "click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9"}, + {file = "click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812"}, +] + +[package.dependencies] +click = ">=7.0" +prompt-toolkit = ">=3.0.36" + +[package.extras] +testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"] + [[package]] name = "colorama" version = "0.4.6" @@ -981,6 +1110,24 @@ idna = ["idna (>=2.1,<4.0)"] trio = ["trio (>=0.14,<0.23)"] wmi = ["wmi (>=1.5.1,<2.0.0)"] +[[package]] +name = "ecdsa" +version = "0.18.0" +description = "ECDSA cryptographic signature library (pure python)" +optional = false +python-versions = ">=2.6, 
!=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "ecdsa-0.18.0-py2.py3-none-any.whl", hash = "sha256:80600258e7ed2f16b9aa1d7c295bd70194109ad5a30fdee0eaeefef1d4c559dd"}, + {file = "ecdsa-0.18.0.tar.gz", hash = "sha256:190348041559e21b22a1d65cee485282ca11a6f81d503fddb84d5017e9ed1e49"}, +] + +[package.dependencies] +six = ">=1.9.0" + +[package.extras] +gmpy = ["gmpy"] +gmpy2 = ["gmpy2"] + [[package]] name = "fastapi" version = "0.100.1" @@ -1624,6 +1771,38 @@ files = [ cryptography = ">=3.4" deprecated = "*" +[[package]] +name = "kombu" +version = "5.3.2" +description = "Messaging library for Python." +optional = false +python-versions = ">=3.8" +files = [ + {file = "kombu-5.3.2-py3-none-any.whl", hash = "sha256:b753c9cfc9b1e976e637a7cbc1a65d446a22e45546cd996ea28f932082b7dc9e"}, + {file = "kombu-5.3.2.tar.gz", hash = "sha256:0ba213f630a2cb2772728aef56ac6883dc3a2f13435e10048f6e97d48506dbbd"}, +] + +[package.dependencies] +amqp = ">=5.1.1,<6.0.0" +vine = "*" + +[package.extras] +azureservicebus = ["azure-servicebus (>=7.10.0)"] +azurestoragequeues = ["azure-identity (>=1.12.0)", "azure-storage-queue (>=12.6.0)"] +confluentkafka = ["confluent-kafka (==2.1.1)"] +consul = ["python-consul2"] +librabbitmq = ["librabbitmq (>=2.0.0)"] +mongodb = ["pymongo (>=4.1.1)"] +msgpack = ["msgpack"] +pyro = ["pyro4"] +qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"] +redis = ["redis (>=4.5.2)"] +slmq = ["softlayer-messaging (>=1.0.3)"] +sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] +sqs = ["boto3 (>=1.26.143)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] +yaml = ["PyYAML (>=3.10)"] +zookeeper = ["kazoo (>=2.8.0)"] + [[package]] name = "levenshtein" version = "0.21.1" @@ -2151,6 +2330,20 @@ files = [ fastapi = ">=0.38.1,<1.0.0" prometheus-client = ">=0.8.0,<1.0.0" +[[package]] +name = "prompt-toolkit" +version = "3.0.39" +description = "Library for building powerful interactive command lines in Python" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = 
"prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"}, + {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"}, +] + +[package.dependencies] +wcwidth = "*" + [[package]] name = "protobuf" version = "4.24.4" @@ -2173,6 +2366,17 @@ files = [ {file = "protobuf-4.24.4.tar.gz", hash = "sha256:5a70731910cd9104762161719c3d883c960151eea077134458503723b60e3667"}, ] +[[package]] +name = "pyasn1" +version = "0.5.0" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"}, + {file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"}, +] + [[package]] name = "pycparser" version = "2.21" @@ -2501,6 +2705,20 @@ pytest = ">=7.0.0" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] +[[package]] +name = "pytest-celery" +version = "0.0.0" +description = "pytest-celery a shim pytest plugin to enable celery.contrib.pytest" +optional = false +python-versions = "*" +files = [ + {file = "pytest-celery-0.0.0.tar.gz", hash = "sha256:cfd060fc32676afa1e4f51b2938f903f7f75d952186b8c6cf631628c4088f406"}, + {file = "pytest_celery-0.0.0-py2.py3-none-any.whl", hash = "sha256:63dec132df3a839226ecb003ffdbb0c2cb88dd328550957e979c942766578060"}, +] + +[package.dependencies] +celery = ">=4.4.0" + [[package]] name = "pytest-cov" version = "4.1.0" @@ -2565,6 +2783,28 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-jose" +version = "3.3.0" +description = "JOSE implementation in Python" +optional = false 
+python-versions = "*" +files = [ + {file = "python-jose-3.3.0.tar.gz", hash = "sha256:55779b5e6ad599c6336191246e95eb2293a9ddebd555f796a65f838f07e5d78a"}, + {file = "python_jose-3.3.0-py2.py3-none-any.whl", hash = "sha256:9b1376b023f8b298536eedd47ae1089bcdb848f1535ab30555cd92002d78923a"}, +] + +[package.dependencies] +cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"cryptography\""} +ecdsa = "!=0.15" +pyasn1 = "*" +rsa = "*" + +[package.extras] +cryptography = ["cryptography (>=3.4.0)"] +pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"] +pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] + [[package]] name = "pyyaml" version = "6.0.1" @@ -2744,6 +2984,24 @@ files = [ [package.extras] full = ["numpy"] +[[package]] +name = "redis" +version = "5.0.1" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"}, + {file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""} + +[package.extras] +hiredis = ["hiredis (>=1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "regex" version = "2023.10.3" @@ -2862,6 +3120,20 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = false +python-versions = ">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + 
+[package.dependencies] +pyasn1 = ">=0.1.3" + [[package]] name = "s3transfer" version = "0.6.2" @@ -3438,6 +3710,17 @@ files = [ {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, ] +[[package]] +name = "tzdata" +version = "2023.3" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"}, + {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"}, +] + [[package]] name = "urllib3" version = "1.26.17" @@ -3523,6 +3806,17 @@ dev = ["Cython (>=0.29.32,<0.30.0)", "Sphinx (>=4.1.2,<4.2.0)", "aiohttp", "flak docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] test = ["Cython (>=0.29.32,<0.30.0)", "aiohttp", "flake8 (>=3.9.2,<3.10.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=22.0.0,<22.1.0)", "pycodestyle (>=2.7.0,<2.8.0)"] +[[package]] +name = "vine" +version = "5.0.0" +description = "Promises, promises, promises." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "vine-5.0.0-py2.py3-none-any.whl", hash = "sha256:4c9dceab6f76ed92105027c49c823800dd33cacce13bdedc5b914e3514b7fb30"}, + {file = "vine-5.0.0.tar.gz", hash = "sha256:7d3b1624a953da82ef63462013bbd271d3eb75751489f9807598e8f340bd637e"}, +] + [[package]] name = "watchfiles" version = "0.20.0" @@ -3557,6 +3851,17 @@ files = [ [package.dependencies] anyio = ">=3.0.0" +[[package]] +name = "wcwidth" +version = "0.2.8" +description = "Measures the displayed width of unicode strings in a terminal" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.8-py2.py3-none-any.whl", hash = "sha256:77f719e01648ed600dfa5402c347481c0992263b81a027344f3e1ba25493a704"}, + {file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"}, +] + [[package]] name = "websockets" version = "11.0.3" @@ -3838,4 +4143,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "61578467a70980ff9c2dc0cd787b6410b91d7c5fd2bb4c46b6951ec82690ef67" +content-hash = "cfefbd402bde7585caa42c1a889be0496d956e285bb05db9e1e7ae5e485e91fe" diff --git a/server/pyproject.toml b/server/pyproject.toml index e3b44774..7681af39 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -33,6 +33,9 @@ prometheus-fastapi-instrumentator = "^6.1.0" sentencepiece = "^0.1.99" protobuf = "^4.24.3" profanityfilter = "^2.0.6" +celery = "^5.3.4" +redis = "^5.0.1" +python-jose = {extras = ["cryptography"], version = "^3.3.0"} [tool.poetry.group.dev.dependencies] @@ -47,6 +50,7 @@ pytest-asyncio = "^0.21.1" pytest = "^7.4.0" httpx-ws = "^0.4.1" pytest-httpx = "^0.23.1" +pytest-celery = "^0.0.0" [tool.poetry.group.aws.dependencies] diff --git a/server/reflector/app.py b/server/reflector/app.py index 758faf69..c2e3bf7e 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -64,6 +64,9 @@ app.include_router(transcripts_router, prefix="/v1") 
app.include_router(user_router, prefix="/v1") add_pagination(app) +# prepare celery +from reflector.worker import app as celery_app # noqa + # simpler openapi id def use_route_names_as_operation_ids(app: FastAPI) -> None: diff --git a/server/reflector/db/__init__.py b/server/reflector/db/__init__.py index b68dfe20..9871c633 100644 --- a/server/reflector/db/__init__.py +++ b/server/reflector/db/__init__.py @@ -1,32 +1,13 @@ import databases import sqlalchemy - from reflector.events import subscribers_shutdown, subscribers_startup from reflector.settings import settings database = databases.Database(settings.DATABASE_URL) metadata = sqlalchemy.MetaData() - -transcripts = sqlalchemy.Table( - "transcript", - metadata, - sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), - sqlalchemy.Column("name", sqlalchemy.String), - sqlalchemy.Column("status", sqlalchemy.String), - sqlalchemy.Column("locked", sqlalchemy.Boolean), - sqlalchemy.Column("duration", sqlalchemy.Integer), - sqlalchemy.Column("created_at", sqlalchemy.DateTime), - sqlalchemy.Column("title", sqlalchemy.String, nullable=True), - sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True), - sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), - sqlalchemy.Column("topics", sqlalchemy.JSON), - sqlalchemy.Column("events", sqlalchemy.JSON), - sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), - sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), - # with user attached, optional - sqlalchemy.Column("user_id", sqlalchemy.String), -) +# import models +import reflector.db.transcripts # noqa engine = sqlalchemy.create_engine( settings.DATABASE_URL, connect_args={"check_same_thread": False} diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py new file mode 100644 index 00000000..6ac2e32a --- /dev/null +++ b/server/reflector/db/transcripts.py @@ -0,0 +1,296 @@ +import json +from contextlib import 
asynccontextmanager +from datetime import datetime +from pathlib import Path +from typing import Any +from uuid import uuid4 + +import sqlalchemy +from pydantic import BaseModel, Field +from reflector.db import database, metadata +from reflector.processors.types import Word as ProcessorWord +from reflector.settings import settings +from reflector.utils.audio_waveform import get_audio_waveform + +transcripts = sqlalchemy.Table( + "transcript", + metadata, + sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), + sqlalchemy.Column("name", sqlalchemy.String), + sqlalchemy.Column("status", sqlalchemy.String), + sqlalchemy.Column("locked", sqlalchemy.Boolean), + sqlalchemy.Column("duration", sqlalchemy.Integer), + sqlalchemy.Column("created_at", sqlalchemy.DateTime), + sqlalchemy.Column("title", sqlalchemy.String, nullable=True), + sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True), + sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), + sqlalchemy.Column("topics", sqlalchemy.JSON), + sqlalchemy.Column("events", sqlalchemy.JSON), + sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), + sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), + # with user attached, optional + sqlalchemy.Column("user_id", sqlalchemy.String), +) + + +def generate_uuid4(): + return str(uuid4()) + + +def generate_transcript_name(): + now = datetime.utcnow() + return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" + + +class AudioWaveform(BaseModel): + data: list[float] + + +class TranscriptText(BaseModel): + text: str + translation: str | None + + +class TranscriptSegmentTopic(BaseModel): + speaker: int + text: str + timestamp: float + + +class TranscriptTopic(BaseModel): + id: str = Field(default_factory=generate_uuid4) + title: str + summary: str + timestamp: float + duration: float | None = 0 + transcript: str | None = None + words: list[ProcessorWord] = [] + + +class TranscriptFinalShortSummary(BaseModel): + 
short_summary: str + + +class TranscriptFinalLongSummary(BaseModel): + long_summary: str + + +class TranscriptFinalTitle(BaseModel): + title: str + + +class TranscriptEvent(BaseModel): + event: str + data: dict + + +class Transcript(BaseModel): + id: str = Field(default_factory=generate_uuid4) + user_id: str | None = None + name: str = Field(default_factory=generate_transcript_name) + status: str = "idle" + locked: bool = False + duration: float = 0 + created_at: datetime = Field(default_factory=datetime.utcnow) + title: str | None = None + short_summary: str | None = None + long_summary: str | None = None + topics: list[TranscriptTopic] = [] + events: list[TranscriptEvent] = [] + source_language: str = "en" + target_language: str = "en" + + def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: + ev = TranscriptEvent(event=event, data=data.model_dump()) + self.events.append(ev) + return ev + + def upsert_topic(self, topic: TranscriptTopic): + index = next((i for i, t in enumerate(self.topics) if t.id == topic.id), None) + if index is not None: + self.topics[index] = topic + else: + self.topics.append(topic) + + def events_dump(self, mode="json"): + return [event.model_dump(mode=mode) for event in self.events] + + def topics_dump(self, mode="json"): + return [topic.model_dump(mode=mode) for topic in self.topics] + + def convert_audio_to_waveform(self, segments_count=256): + fn = self.audio_waveform_filename + if fn.exists(): + return + waveform = get_audio_waveform( + path=self.audio_mp3_filename, segments_count=segments_count + ) + try: + with open(fn, "w") as fd: + json.dump(waveform, fd) + except Exception: + # remove file if anything happen during the write + fn.unlink(missing_ok=True) + raise + return waveform + + def unlink(self): + self.data_path.unlink(missing_ok=True) + + @property + def data_path(self): + return Path(settings.DATA_DIR) / self.id + + @property + def audio_mp3_filename(self): + return self.data_path / "audio.mp3" + + @property 
+ def audio_waveform_filename(self): + return self.data_path / "audio.json" + + @property + def audio_waveform(self): + try: + with open(self.audio_waveform_filename) as fd: + data = json.load(fd) + except (FileNotFoundError, json.JSONDecodeError): + # unlink file if it's corrupted + self.audio_waveform_filename.unlink(missing_ok=True) + return None + + return AudioWaveform(data=data) + + +class TranscriptController: + async def get_all( + self, + user_id: str | None = None, + order_by: str | None = None, + filter_empty: bool | None = False, + filter_recording: bool | None = False, + ) -> list[Transcript]: + """ + Get all transcripts + + If `user_id` is specified, only return transcripts that belong to the user. + Otherwise, return all anonymous transcripts. + + Parameters: + - `order_by`: field to order by, e.g. "-created_at" + - `filter_empty`: filter out empty transcripts + - `filter_recording`: filter out transcripts that are currently recording + """ + query = transcripts.select().where(transcripts.c.user_id == user_id) + + if order_by is not None: + field = getattr(transcripts.c, order_by.lstrip("-")) + if order_by.startswith("-"): + field = field.desc() + query = query.order_by(field) + + if filter_empty: + query = query.filter(transcripts.c.status != "idle") + + if filter_recording: + query = query.filter(transcripts.c.status != "recording") + + results = await database.fetch_all(query) + return results + + async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None: + """ + Get a transcript by id + """ + query = transcripts.select().where(transcripts.c.id == transcript_id) + if "user_id" in kwargs: + query = query.where(transcripts.c.user_id == kwargs["user_id"]) + result = await database.fetch_one(query) + if not result: + return None + return Transcript(**result) + + async def add( + self, + name: str, + source_language: str = "en", + target_language: str = "en", + user_id: str | None = None, + ): + """ + Add a new transcript + """ + transcript = Transcript(
name=name, + source_language=source_language, + target_language=target_language, + user_id=user_id, + ) + query = transcripts.insert().values(**transcript.model_dump()) + await database.execute(query) + return transcript + + async def update(self, transcript: Transcript, values: dict): + """ + Update a transcript fields with key/values in values + """ + query = ( + transcripts.update() + .where(transcripts.c.id == transcript.id) + .values(**values) + ) + await database.execute(query) + for key, value in values.items(): + setattr(transcript, key, value) + + async def remove_by_id( + self, + transcript_id: str, + user_id: str | None = None, + ) -> None: + """ + Remove a transcript by id + """ + transcript = await self.get_by_id(transcript_id, user_id=user_id) + if not transcript: + return + if user_id is not None and transcript.user_id != user_id: + return + transcript.unlink() + query = transcripts.delete().where(transcripts.c.id == transcript_id) + await database.execute(query) + + @asynccontextmanager + async def transaction(self): + """ + A context manager for database transaction + """ + async with database.transaction(isolation="serializable"): + yield + + async def append_event( + self, + transcript: Transcript, + event: str, + data: Any, + ) -> TranscriptEvent: + """ + Append an event to a transcript + """ + resp = transcript.add_event(event=event, data=data) + await self.update(transcript, {"events": transcript.events_dump()}) + return resp + + async def upsert_topic( + self, + transcript: Transcript, + topic: TranscriptTopic, + ) -> None: + """ + Upsert a topic in a transcript + """ + transcript.upsert_topic(topic) + await self.update(transcript, {"topics": transcript.topics_dump()}) + + +transcripts_controller = TranscriptController() diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py index 220730e5..4b81c5a0 100644 --- a/server/reflector/llm/llm_modal.py +++ b/server/reflector/llm/llm_modal.py @@ -47,6 +47,7 @@
class ModalLLM(LLM): json=json_payload, timeout=self.timeout, retry_timeout=60 * 5, + follow_redirects=True, ) response.raise_for_status() text = response.json()["text"] diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py new file mode 100644 index 00000000..b2bc51ea --- /dev/null +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -0,0 +1,362 @@ +""" +Main reflector pipeline for live streaming +========================================== + +This is the default pipeline used in the API. + +It is decoupled to: +- PipelineMainLive: have limited processing during live +- PipelineMainPost: do heavy lifting after the live + +It is directly linked to our data model. +""" + +import asyncio +from contextlib import asynccontextmanager +from datetime import timedelta +from pathlib import Path + +from celery import shared_task +from pydantic import BaseModel +from reflector.app import app +from reflector.db.transcripts import ( + Transcript, + TranscriptFinalLongSummary, + TranscriptFinalShortSummary, + TranscriptFinalTitle, + TranscriptText, + TranscriptTopic, + transcripts_controller, +) +from reflector.logger import logger +from reflector.pipelines.runner import PipelineRunner +from reflector.processors import ( + AudioChunkerProcessor, + AudioDiarizationAutoProcessor, + AudioFileWriterProcessor, + AudioMergeProcessor, + AudioTranscriptAutoProcessor, + BroadcastProcessor, + Pipeline, + TranscriptFinalLongSummaryProcessor, + TranscriptFinalShortSummaryProcessor, + TranscriptFinalTitleProcessor, + TranscriptLinerProcessor, + TranscriptTopicDetectorProcessor, + TranscriptTranslatorProcessor, +) +from reflector.processors.types import AudioDiarizationInput +from reflector.processors.types import ( + TitleSummaryWithId as TitleSummaryWithIdProcessorType, +) +from reflector.processors.types import Transcript as TranscriptProcessorType +from reflector.settings import settings +from reflector.ws_manager import 
WebsocketManager, get_ws_manager + + +def broadcast_to_sockets(func): + """ + Decorator to broadcast transcript event to websockets + concerning this transcript + """ + + async def wrapper(self, *args, **kwargs): + resp = await func(self, *args, **kwargs) + if resp is None: + return + await self.ws_manager.send_json( + room_id=self.ws_room_id, + message=resp.model_dump(mode="json"), + ) + + return wrapper + + +class StrValue(BaseModel): + value: str + + +class PipelineMainBase(PipelineRunner): + transcript_id: str + ws_room_id: str | None = None + ws_manager: WebsocketManager | None = None + + def prepare(self): + # prepare websocket + self._lock = asyncio.Lock() + self.ws_room_id = f"ts:{self.transcript_id}" + self.ws_manager = get_ws_manager() + + async def get_transcript(self) -> Transcript: + # fetch the transcript + result = await transcripts_controller.get_by_id( + transcript_id=self.transcript_id + ) + if not result: + raise Exception("Transcript not found") + return result + + @asynccontextmanager + async def transaction(self): + async with self._lock: + async with transcripts_controller.transaction(): + yield + + @broadcast_to_sockets + async def on_status(self, status): + # if it's the first part, update the status of the transcript + # but do not set the ended status yet. 
+ if isinstance(self, PipelineMainLive): + status_mapping = { + "started": "recording", + "push": "recording", + "flush": "processing", + "error": "error", + } + elif isinstance(self, PipelineMainDiarization): + status_mapping = { + "push": "processing", + "flush": "processing", + "error": "error", + "ended": "ended", + } + else: + raise Exception(f"Runner {self.__class__} is missing status mapping") + + # mutate to model status + status = status_mapping.get(status) + if not status: + return + + # when the status of the pipeline changes, update the transcript + async with self.transaction(): + transcript = await self.get_transcript() + if status == transcript.status: + return + resp = await transcripts_controller.append_event( + transcript=transcript, + event="STATUS", + data=StrValue(value=status), + ) + await transcripts_controller.update( + transcript, + { + "status": status, + }, + ) + return resp + + @broadcast_to_sockets + async def on_transcript(self, data): + async with self.transaction(): + transcript = await self.get_transcript() + return await transcripts_controller.append_event( + transcript=transcript, + event="TRANSCRIPT", + data=TranscriptText(text=data.text, translation=data.translation), + ) + + @broadcast_to_sockets + async def on_topic(self, data): + topic = TranscriptTopic( + title=data.title, + summary=data.summary, + timestamp=data.timestamp, + transcript=data.transcript.text, + words=data.transcript.words, + ) + if isinstance(data, TitleSummaryWithIdProcessorType): + topic.id = data.id + async with self.transaction(): + transcript = await self.get_transcript() + await transcripts_controller.upsert_topic(transcript, topic) + return await transcripts_controller.append_event( + transcript=transcript, + event="TOPIC", + data=topic, + ) + + @broadcast_to_sockets + async def on_title(self, data): + final_title = TranscriptFinalTitle(title=data.title) + async with self.transaction(): + transcript = await self.get_transcript() + if not 
transcript.title: + await transcripts_controller.update( + transcript, + { + "title": final_title.title, + }, + ) + return await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_TITLE", + data=final_title, + ) + + @broadcast_to_sockets + async def on_long_summary(self, data): + final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) + async with self.transaction(): + transcript = await self.get_transcript() + await transcripts_controller.update( + transcript, + { + "long_summary": final_long_summary.long_summary, + }, + ) + return await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_LONG_SUMMARY", + data=final_long_summary, + ) + + @broadcast_to_sockets + async def on_short_summary(self, data): + final_short_summary = TranscriptFinalShortSummary( + short_summary=data.short_summary + ) + async with self.transaction(): + transcript = await self.get_transcript() + await transcripts_controller.update( + transcript, + { + "short_summary": final_short_summary.short_summary, + }, + ) + return await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_SHORT_SUMMARY", + data=final_short_summary, + ) + + +class PipelineMainLive(PipelineMainBase): + audio_filename: Path | None = None + source_language: str = "en" + target_language: str = "en" + + async def create(self) -> Pipeline: + # create a context for the whole rtc transaction + # add a customised logger to the context + self.prepare() + transcript = await self.get_transcript() + + processors = [ + AudioFileWriterProcessor(path=transcript.audio_mp3_filename), + AudioChunkerProcessor(), + AudioMergeProcessor(), + AudioTranscriptAutoProcessor.as_threaded(), + TranscriptLinerProcessor(), + TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), + TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), + BroadcastProcessor( + processors=[ + 
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), + ] + ), + ] + pipeline = Pipeline(*processors) + pipeline.options = self + pipeline.set_pref("audio:source_language", transcript.source_language) + pipeline.set_pref("audio:target_language", transcript.target_language) + pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info( + "Pipeline main live created", + transcript_id=self.transcript_id, + ) + + return pipeline + + async def on_ended(self): + # when the pipeline ends, connect to the post pipeline + logger.info("Pipeline main live ended", transcript_id=self.transcript_id) + logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id) + task_pipeline_main_post.delay(transcript_id=self.transcript_id) + + +class PipelineMainDiarization(PipelineMainBase): + """ + Diarization is a long time process, so we do it in a separate pipeline + When done, adjust the short and final summary + """ + + async def create(self) -> Pipeline: + # create a context for the whole rtc transaction + # add a customised logger to the context + self.prepare() + processors = [ + AudioDiarizationAutoProcessor(callback=self.on_topic), + BroadcastProcessor( + processors=[ + TranscriptFinalLongSummaryProcessor.as_threaded( + callback=self.on_long_summary + ), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=self.on_short_summary + ), + ] + ), + ] + pipeline = Pipeline(*processors) + pipeline.options = self + + # now let's start the pipeline by pushing information to the + # first processor diarization processor + # XXX translation is lost when converting our data model to the processor model + transcript = await self.get_transcript() + topics = [ + TitleSummaryWithIdProcessorType( + id=topic.id, + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + duration=topic.duration, + transcript=TranscriptProcessorType(words=topic.words), + ) + for topic in transcript.topics + ] + + # we need to create an url to be 
used for diarization + # we can't use the audio_mp3_filename because it's not accessible + # from the diarization processor + from reflector.views.transcripts import create_access_token + + path = app.url_path_for( + "transcript_get_audio_mp3", + transcript_id=transcript.id, + ) + url = f"{settings.BASE_URL}{path}" + if transcript.user_id: + # we pass token only if the user_id is set + # otherwise, the audio is public + token = create_access_token( + {"sub": transcript.user_id}, + expires_delta=timedelta(minutes=15), + ) + url += f"?token={token}" + audio_diarization_input = AudioDiarizationInput( + audio_url=url, + topics=topics, + ) + + # as tempting to use pipeline.push, prefer to use the runner + # to let the start just do one job. + pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info( + "Pipeline main post created", transcript_id=self.transcript_id + ) + self.push(audio_diarization_input) + self.flush() + + return pipeline + + +@shared_task +def task_pipeline_main_post(transcript_id: str): + logger.info( + "Starting main post pipeline", + transcript_id=transcript_id, + ) + runner = PipelineMainDiarization(transcript_id=transcript_id) + runner.start_sync() diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py new file mode 100644 index 00000000..a1e137a7 --- /dev/null +++ b/server/reflector/pipelines/runner.py @@ -0,0 +1,137 @@ +""" +Pipeline Runner +=============== + +Pipeline runner designed to be executed in a asyncio task. + +It is meant to be subclassed, and implement a create() method +that expose/return a Pipeline instance. 
+ +During its lifecycle, it will emit the following status: +- started: the pipeline has been started +- push: the pipeline received at least one data +- flush: the pipeline is flushing +- ended: the pipeline has ended +- error: the pipeline has ended with an error +""" + +import asyncio + +from pydantic import BaseModel, ConfigDict +from reflector.logger import logger +from reflector.processors import Pipeline + + +class PipelineRunner(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + status: str = "idle" + pipeline: Pipeline | None = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._q_cmd = asyncio.Queue() + self._ev_done = asyncio.Event() + self._is_first_push = True + self._logger = logger.bind( + runner=id(self), + runner_cls=self.__class__.__name__, + ) + + def create(self) -> Pipeline: + """ + Create the pipeline if not specified earlier. + Should be implemented in a subclass + """ + raise NotImplementedError() + + def start(self): + """ + Start the pipeline as a coroutine task + """ + asyncio.get_event_loop().create_task(self.run()) + + def start_sync(self): + """ + Start the pipeline synchronously (for non-asyncio apps) + """ + coro = self.run() + asyncio.run(coro) + + def push(self, data): + """ + Push data to the pipeline + """ + self._add_cmd("PUSH", data) + + def flush(self): + """ + Flush the pipeline + """ + self._add_cmd("FLUSH", None) + + async def on_status(self, status): + """ + Called when the status of the pipeline changes + """ + pass + + async def on_ended(self): + """ + Called when the pipeline ends + """ + pass + + def _add_cmd(self, cmd: str, data): + """ + Enqueue a command to be executed in the runner. 
+ Currently supported commands: PUSH, FLUSH + """ + self._q_cmd.put_nowait([cmd, data]) + + async def _set_status(self, status): + self._logger.debug("Runner status updated", status=status) + self.status = status + if self.on_status: + try: + await self.on_status(status) + except Exception: + self._logger.exception("Runner error while setting status") + + async def run(self): + try: + # create the pipeline if not yet done + await self._set_status("init") + self._is_first_push = True + if not self.pipeline: + self.pipeline = await self.create() + + # start the loop + await self._set_status("started") + while not self._ev_done.is_set(): + cmd, data = await self._q_cmd.get() + func = getattr(self, f"cmd_{cmd.lower()}") + if func: + await func(data) + else: + raise Exception(f"Unknown command {cmd}") + except Exception: + self._logger.exception("Runner error") + await self._set_status("error") + self._ev_done.set() + if self.on_ended: + await self.on_ended() + + async def cmd_push(self, data): + if self._is_first_push: + await self._set_status("push") + self._is_first_push = False + await self.pipeline.push(data) + + async def cmd_flush(self, data): + await self._set_status("flush") + await self.pipeline.flush() + await self._set_status("ended") + self._ev_done.set() + if self.on_ended: + await self.on_ended() diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py index 96a3941d..1c88d6c5 100644 --- a/server/reflector/processors/__init__.py +++ b/server/reflector/processors/__init__.py @@ -1,9 +1,16 @@ from .audio_chunker import AudioChunkerProcessor # noqa: F401 +from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401 from .audio_file_writer import AudioFileWriterProcessor # noqa: F401 from .audio_merge import AudioMergeProcessor # noqa: F401 from .audio_transcript import AudioTranscriptProcessor # noqa: F401 from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401 -from .base import
Pipeline, PipelineEvent, Processor, ThreadedProcessor # noqa: F401 +from .base import ( # noqa: F401 + BroadcastProcessor, + Pipeline, + PipelineEvent, + Processor, + ThreadedProcessor, +) from .transcript_final_long_summary import ( # noqa: F401 TranscriptFinalLongSummaryProcessor, ) diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py new file mode 100644 index 00000000..82c6a553 --- /dev/null +++ b/server/reflector/processors/audio_diarization.py @@ -0,0 +1,34 @@ +from reflector.processors.base import Processor +from reflector.processors.types import AudioDiarizationInput, TitleSummary + + +class AudioDiarizationProcessor(Processor): + INPUT_TYPE = AudioDiarizationInput + OUTPUT_TYPE = TitleSummary + + async def _push(self, data: AudioDiarizationInput): + try: + self.logger.info("Diarization started", audio_file_url=data.audio_url) + diarization = await self._diarize(data) + self.logger.info("Diarization finished") + except Exception: + self.logger.exception("Diarization failed after retrying") + raise + + # now reapply speaker to topics (if any) + # topics is a list[BaseModel] with an attribute words + # words is a list[BaseModel] with text, start and speaker attribute + + # mutate in place + for topic in data.topics: + for word in topic.transcript.words: + for d in diarization: + if d["start"] <= word.start <= d["end"]: + word.speaker = d["speaker"] + + # emit them + for topic in data.topics: + await self.emit(topic) + + async def _diarize(self, data: AudioDiarizationInput): + raise NotImplementedError diff --git a/server/reflector/processors/audio_diarization_auto.py b/server/reflector/processors/audio_diarization_auto.py new file mode 100644 index 00000000..0e7bfc5c --- /dev/null +++ b/server/reflector/processors/audio_diarization_auto.py @@ -0,0 +1,33 @@ +import importlib + +from reflector.processors.audio_diarization import AudioDiarizationProcessor +from reflector.settings import settings + + 
+class AudioDiarizationAutoProcessor(AudioDiarizationProcessor): + _registry = {} + + @classmethod + def register(cls, name, kclass): + cls._registry[name] = kclass + + def __new__(cls, name: str | None = None, **kwargs): + if name is None: + name = settings.DIARIZATION_BACKEND + + if name not in cls._registry: + module_name = f"reflector.processors.audio_diarization_{name}" + importlib.import_module(module_name) + + # gather specific configuration for the processor + # search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy` + config = {} + name_upper = name.upper() + settings_prefix = "DIARIZATION_" + config_prefix = f"{settings_prefix}{name_upper}_" + for key, value in settings: + if key.startswith(config_prefix): + config_name = key[len(settings_prefix) :].lower() + config[config_name] = value + + return cls._registry[name](**config | kwargs) diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py new file mode 100644 index 00000000..53de2501 --- /dev/null +++ b/server/reflector/processors/audio_diarization_modal.py @@ -0,0 +1,37 @@ +import httpx +from reflector.processors.audio_diarization import AudioDiarizationProcessor +from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor +from reflector.processors.types import AudioDiarizationInput, TitleSummary +from reflector.settings import settings + + +class AudioDiarizationModalProcessor(AudioDiarizationProcessor): + INPUT_TYPE = AudioDiarizationInput + OUTPUT_TYPE = TitleSummary + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.diarization_url = settings.DIARIZATION_URL + "/diarize" + self.headers = { + "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}", + } + + async def _diarize(self, data: AudioDiarizationInput): + # Gather diarization data + params = { + "audio_file_url": data.audio_url, + "timestamp": 0, + } + async with httpx.AsyncClient() as client: + response = 
await client.post( + self.diarization_url, + headers=self.headers, + params=params, + timeout=None, + follow_redirects=True, + ) + response.raise_for_status() + return response.json()["text"] + + +AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor) diff --git a/server/reflector/processors/audio_transcript.py b/server/reflector/processors/audio_transcript.py index f029b587..3f9dc85b 100644 --- a/server/reflector/processors/audio_transcript.py +++ b/server/reflector/processors/audio_transcript.py @@ -1,6 +1,4 @@ -from profanityfilter import ProfanityFilter from prometheus_client import Counter, Histogram - from reflector.processors.base import Processor from reflector.processors.types import AudioFile, Transcript @@ -40,8 +38,6 @@ class AudioTranscriptProcessor(Processor): self.m_transcript_call = self.m_transcript_call.labels(name) self.m_transcript_success = self.m_transcript_success.labels(name) self.m_transcript_failure = self.m_transcript_failure.labels(name) - self.profanity_filter = ProfanityFilter() - self.profanity_filter.set_censor("*") super().__init__(*args, **kwargs) async def _push(self, data: AudioFile): @@ -60,9 +56,3 @@ class AudioTranscriptProcessor(Processor): async def _transcript(self, data: AudioFile): raise NotImplementedError - - def filter_profanity(self, text: str) -> str: - """ - Remove censored words from the transcript - """ - return self.profanity_filter.censor(text) diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py index f223a52d..ac79ced0 100644 --- a/server/reflector/processors/audio_transcript_auto.py +++ b/server/reflector/processors/audio_transcript_auto.py @@ -1,8 +1,6 @@ import importlib from reflector.processors.audio_transcript import AudioTranscriptProcessor -from reflector.processors.base import Pipeline, Processor -from reflector.processors.types import AudioFile from reflector.settings import settings @@ -13,8 +11,9 @@ class 
AudioTranscriptAutoProcessor(AudioTranscriptProcessor): def register(cls, name, kclass): cls._registry[name] = kclass - @classmethod - def get_instance(cls, name): + def __new__(cls, name: str | None = None, **kwargs): + if name is None: + name = settings.TRANSCRIPT_BACKEND if name not in cls._registry: module_name = f"reflector.processors.audio_transcript_{name}" importlib.import_module(module_name) @@ -30,30 +29,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor): config_name = key[len(settings_prefix) :].lower() config[config_name] = value - return cls._registry[name](**config) - - def __init__(self, **kwargs): - self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND) - super().__init__(**kwargs) - - def set_pipeline(self, pipeline: Pipeline): - super().set_pipeline(pipeline) - self.processor.set_pipeline(pipeline) - - def connect(self, processor: Processor): - self.processor.connect(processor) - - def disconnect(self, processor: Processor): - self.processor.disconnect(processor) - - def on(self, callback): - self.processor.on(callback) - - def off(self, callback): - self.processor.off(callback) - - async def _push(self, data: AudioFile): - return await self.processor._push(data) - - async def _flush(self): - return await self.processor._flush() + return cls._registry[name](**config | kwargs) diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py index 201ed9d4..0ca4710f 100644 --- a/server/reflector/processors/audio_transcript_modal.py +++ b/server/reflector/processors/audio_transcript_modal.py @@ -41,6 +41,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor): timeout=self.timeout, headers=self.headers, params=json_payload, + follow_redirects=True, ) self.logger.debug( @@ -48,10 +49,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor): ) response.raise_for_status() result = response.json() - text = result["text"][source_language] - text = 
self.filter_profanity(text) transcript = Transcript( - text=text, words=[ Word( text=word["text"], diff --git a/server/reflector/processors/audio_transcript_whisper.py b/server/reflector/processors/audio_transcript_whisper.py index e3bd595b..cd96e01a 100644 --- a/server/reflector/processors/audio_transcript_whisper.py +++ b/server/reflector/processors/audio_transcript_whisper.py @@ -30,7 +30,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor): ts = data.timestamp for segment in segments: - transcript.text += segment.text for word in segment.words: transcript.words.append( Word( diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py index 6771e11e..46bfb4a5 100644 --- a/server/reflector/processors/base.py +++ b/server/reflector/processors/base.py @@ -290,12 +290,12 @@ class BroadcastProcessor(Processor): processor.set_pipeline(pipeline) async def _push(self, data): - for processor in self.processors: - await processor.push(data) + coros = [processor.push(data) for processor in self.processors] + await asyncio.gather(*coros) async def _flush(self): - for processor in self.processors: - await processor.flush() + coros = [processor.flush() for processor in self.processors] + await asyncio.gather(*coros) def connect(self, processor: Processor): for processor in self.processors: @@ -333,6 +333,7 @@ class Pipeline(Processor): self.logger.info("Pipeline created") self.processors = processors + self.options = None self.prefs = {} for processor in processors: diff --git a/server/reflector/processors/transcript_liner.py b/server/reflector/processors/transcript_liner.py index c1aa14a0..b4e7b5e3 100644 --- a/server/reflector/processors/transcript_liner.py +++ b/server/reflector/processors/transcript_liner.py @@ -36,7 +36,6 @@ class TranscriptLinerProcessor(Processor): # cut to the next . 
partial = Transcript(words=[]) for word in self.transcript.words[:]: - partial.text += word.text partial.words.append(word) if not self.is_sentence_terminated(word.text): continue diff --git a/server/reflector/processors/transcript_translator.py b/server/reflector/processors/transcript_translator.py index 77b8f5be..905ea423 100644 --- a/server/reflector/processors/transcript_translator.py +++ b/server/reflector/processors/transcript_translator.py @@ -50,6 +50,7 @@ class TranscriptTranslatorProcessor(Processor): headers=self.headers, params=json_payload, timeout=self.timeout, + follow_redirects=True, ) response.raise_for_status() result = response.json()["text"] diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index e867becf..312f5433 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -1,9 +1,16 @@ import io +import re import tempfile from pathlib import Path +from profanityfilter import ProfanityFilter from pydantic import BaseModel, PrivateAttr +PUNC_RE = re.compile(r"[.;:?!…]") + +profanity_filter = ProfanityFilter() +profanity_filter.set_censor("*") + class AudioFile(BaseModel): name: str @@ -43,13 +50,29 @@ class Word(BaseModel): text: str start: float end: float + speaker: int = 0 + + +class TranscriptSegment(BaseModel): + text: str + start: float + speaker: int = 0 class Transcript(BaseModel): - text: str = "" translation: str | None = None words: list[Word] = None + @property + def raw_text(self): + # Uncensored text + return "".join([word.text for word in self.words]) + + @property + def text(self): + # Censored text + return profanity_filter.censor(self.raw_text).strip() + @property def human_timestamp(self): minutes = int(self.timestamp / 60) @@ -74,7 +97,6 @@ class Transcript(BaseModel): self.words = other.words else: self.words.extend(other.words) - self.text += other.text def add_offset(self, offset: float): for word in self.words: @@ -87,6 +109,48 @@ class 
Transcript(BaseModel): ] return Transcript(text=self.text, translation=self.translation, words=words) + def as_segments(self) -> list[TranscriptSegment]: + # from a list of word, create a list of segments + # join the word that are less than 2 seconds apart + # but separate if the speaker changes, or if the punctuation is a . , ; : ? ! + segments = [] + current_segment = None + MAX_SEGMENT_LENGTH = 120 + + for word in self.words: + if current_segment is None: + current_segment = TranscriptSegment( + text=word.text, + start=word.start, + speaker=word.speaker, + ) + continue + + # If the word is attach to another speaker, push the current segment + # and start a new one + if word.speaker != current_segment.speaker: + segments.append(current_segment) + current_segment = TranscriptSegment( + text=word.text, + start=word.start, + speaker=word.speaker, + ) + continue + + # if the word is the end of a sentence, and we have enough content, + # add the word to the current segment and push it + current_segment.text += word.text + + have_punc = PUNC_RE.search(word.text) + if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH): + segments.append(current_segment) + current_segment = None + + if current_segment: + segments.append(current_segment) + + return segments + class TitleSummary(BaseModel): title: str @@ -103,6 +167,10 @@ class TitleSummary(BaseModel): return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}" +class TitleSummaryWithId(TitleSummary): + id: str + + class FinalLongSummary(BaseModel): long_summary: str duration: float @@ -318,3 +386,8 @@ class TranslationLanguages(BaseModel): def is_supported(self, lang_id: str) -> bool: return lang_id in self.supported_languages + + +class AudioDiarizationInput(BaseModel): + audio_url: str + topics: list[TitleSummaryWithId] diff --git a/server/reflector/settings.py b/server/reflector/settings.py index e0ffd826..021d509f 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -89,6 +89,10 @@ 
class Settings(BaseSettings): # LLM Modal configuration LLM_MODAL_API_KEY: str | None = None + # Diarization + DIARIZATION_BACKEND: str = "modal" + DIARIZATION_URL: str | None = None + # Sentry SENTRY_DSN: str | None = None @@ -113,5 +117,19 @@ class Settings(BaseSettings): # Min transcript length to generate topic + summary MIN_TRANSCRIPT_LENGTH: int = 750 + # Celery + CELERY_BROKER_URL: str = "redis://localhost:6379/1" + CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" + + # Redis + REDIS_HOST: str = "localhost" + REDIS_PORT: int = 6379 + + # Secret key + SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21" + + # Current hosting/domain + BASE_URL: str = "http://localhost:1250" + settings = Settings() diff --git a/server/reflector/tools/start_post_main_live_pipeline.py b/server/reflector/tools/start_post_main_live_pipeline.py new file mode 100644 index 00000000..859f03a4 --- /dev/null +++ b/server/reflector/tools/start_post_main_live_pipeline.py @@ -0,0 +1,14 @@ +import argparse + +from reflector.app import celery_app # noqa +from reflector.pipelines.main_live_pipeline import task_pipeline_main_post + +parser = argparse.ArgumentParser() +parser.add_argument("transcript_id", type=str) +parser.add_argument("--delay", action="store_true") +args = parser.parse_args() + +if args.delay: + task_pipeline_main_post.delay(args.transcript_id) +else: + task_pipeline_main_post(args.transcript_id) diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 5662d989..386ada9c 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -1,7 +1,5 @@ import asyncio -from enum import StrEnum -from json import dumps, loads -from pathlib import Path +from json import loads import av from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription @@ -10,25 +8,7 @@ from prometheus_client import Gauge from pydantic import BaseModel from reflector.events import subscribers_shutdown from 
reflector.logger import logger -from reflector.processors import ( - AudioChunkerProcessor, - AudioFileWriterProcessor, - AudioMergeProcessor, - AudioTranscriptAutoProcessor, - FinalLongSummary, - FinalShortSummary, - Pipeline, - TitleSummary, - Transcript, - TranscriptFinalLongSummaryProcessor, - TranscriptFinalShortSummaryProcessor, - TranscriptFinalTitleProcessor, - TranscriptLinerProcessor, - TranscriptTopicDetectorProcessor, - TranscriptTranslatorProcessor, -) -from reflector.processors.base import BroadcastProcessor -from reflector.processors.types import FinalTitle +from reflector.pipelines.runner import PipelineRunner sessions = [] router = APIRouter() @@ -38,7 +18,7 @@ m_rtc_sessions = Gauge("rtc_sessions", "Number of active RTC sessions") class TranscriptionContext(object): def __init__(self, logger): self.logger = logger - self.pipeline = None + self.pipeline_runner = None self.data_channel = None self.status = "idle" self.topics = [] @@ -60,7 +40,7 @@ class AudioStreamTrack(MediaStreamTrack): ctx = self.ctx frame = await self.track.recv() try: - await ctx.pipeline.push(frame) + ctx.pipeline_runner.push(frame) except Exception as e: ctx.logger.error("Pipeline error", error=e) return frame @@ -71,27 +51,10 @@ class RtcOffer(BaseModel): type: str -class StrValue(BaseModel): - value: str - - -class PipelineEvent(StrEnum): - TRANSCRIPT = "TRANSCRIPT" - TOPIC = "TOPIC" - FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY" - STATUS = "STATUS" - FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY" - FINAL_TITLE = "FINAL_TITLE" - - async def rtc_offer_base( params: RtcOffer, request: Request, - event_callback=None, - event_callback_args=None, - audio_filename: Path | None = None, - source_language: str = "en", - target_language: str = "en", + pipeline_runner: PipelineRunner, ): # build an rtc session offer = RTCSessionDescription(sdp=params.sdp, type=params.type) @@ -101,146 +64,10 @@ async def rtc_offer_base( clientid = f"{peername[0]}:{peername[1]}" ctx = 
TranscriptionContext(logger=logger.bind(client=clientid)) - async def update_status(status: str): - changed = ctx.status != status - if changed: - ctx.status = status - if event_callback: - await event_callback( - event=PipelineEvent.STATUS, - args=event_callback_args, - data=StrValue(value=status), - ) - - # build pipeline callback - async def on_transcript(transcript: Transcript): - ctx.logger.info("Transcript", transcript=transcript) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = { - "cmd": "SHOW_TRANSCRIPTION", - "text": transcript.text, - } - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.TRANSCRIPT, - args=event_callback_args, - data=transcript, - ) - - async def on_topic(topic: TitleSummary): - # FIXME: make it incremental with the frontend, not send everything - ctx.logger.info("Topic", topic=topic) - ctx.topics.append( - { - "title": topic.title, - "timestamp": topic.timestamp, - "transcript": topic.transcript.text, - "desc": topic.summary, - } - ) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics} - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.TOPIC, args=event_callback_args, data=topic - ) - - async def on_final_short_summary(summary: FinalShortSummary): - ctx.logger.info("FinalShortSummary", final_short_summary=summary) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = { - "cmd": "DISPLAY_FINAL_SHORT_SUMMARY", - "summary": summary.short_summary, - "duration": summary.duration, - } - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. 
websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_SHORT_SUMMARY, - args=event_callback_args, - data=summary, - ) - - async def on_final_long_summary(summary: FinalLongSummary): - ctx.logger.info("FinalLongSummary", final_summary=summary) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = { - "cmd": "DISPLAY_FINAL_LONG_SUMMARY", - "summary": summary.long_summary, - "duration": summary.duration, - } - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_LONG_SUMMARY, - args=event_callback_args, - data=summary, - ) - - async def on_final_title(title: FinalTitle): - ctx.logger.info("FinalTitle", final_title=title) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title} - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_TITLE, - args=event_callback_args, - data=title, - ) - - # create a context for the whole rtc transaction - # add a customised logger to the context - processors = [] - if audio_filename is not None: - processors += [AudioFileWriterProcessor(path=audio_filename)] - processors += [ - AudioChunkerProcessor(), - AudioMergeProcessor(), - AudioTranscriptAutoProcessor.as_threaded(), - TranscriptLinerProcessor(), - TranscriptTranslatorProcessor.as_threaded(callback=on_transcript), - TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic), - BroadcastProcessor( - processors=[ - TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title), - TranscriptFinalLongSummaryProcessor.as_threaded( - callback=on_final_long_summary - ), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=on_final_short_summary - ), - ] - ), - ] - ctx.pipeline = Pipeline(*processors) - ctx.pipeline.set_pref("audio:source_language", source_language) - 
ctx.pipeline.set_pref("audio:target_language", target_language) - # handle RTC peer connection pc = RTCPeerConnection() + ctx.pipeline_runner = pipeline_runner + ctx.pipeline_runner.start() async def flush_pipeline_and_quit(close=True): # may be called twice @@ -249,12 +76,10 @@ async def rtc_offer_base( # - when we receive the close event, we do nothing. # 2. or the client close the connection # and there is nothing to do because it is already closed - await update_status("processing") - await ctx.pipeline.flush() + ctx.pipeline_runner.flush() if close: ctx.logger.debug("Closing peer connection") await pc.close() - await update_status("ended") if pc in sessions: sessions.remove(pc) m_rtc_sessions.dec() @@ -287,7 +112,6 @@ async def rtc_offer_base( def on_track(track): ctx.logger.info(f"Track {track.kind} received") pc.addTrack(AudioStreamTrack(ctx, track)) - asyncio.get_event_loop().create_task(update_status("recording")) await pc.setRemoteDescription(offer) @@ -308,8 +132,3 @@ async def rtc_clean_sessions(_): logger.debug(f"Closing session {pc}") await pc.close() sessions.clear() - - -@router.post("/offer") -async def rtc_offer(params: RtcOffer, request: Request): - return await rtc_offer_base(params, request) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index a7e01b8c..e3668ecb 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,8 +1,5 @@ -import json -from datetime import datetime -from pathlib import Path +from datetime import datetime, timedelta from typing import Annotated, Optional -from uuid import uuid4 import reflector.auth as auth from fastapi import ( @@ -12,221 +9,36 @@ from fastapi import ( Request, WebSocket, WebSocketDisconnect, + status, ) from fastapi_pagination import Page, paginate +from jose import jwt from pydantic import BaseModel, Field -from reflector.db import database, transcripts -from reflector.logger import logger +from 
reflector.db.transcripts import ( + AudioWaveform, + TranscriptTopic, + transcripts_controller, +) +from reflector.processors.types import Transcript as ProcessorTranscript from reflector.settings import settings -from reflector.utils.audio_waveform import get_audio_waveform +from reflector.ws_manager import get_ws_manager from starlette.concurrency import run_in_threadpool from ._range_requests_response import range_requests_response -from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base +from .rtc_offer import RtcOffer, rtc_offer_base router = APIRouter() -# ============================================================== -# Models to move to a database, but required for the API to work -# ============================================================== +ALGORITHM = "HS256" +DOWNLOAD_EXPIRE_MINUTES = 60 -def generate_uuid4(): - return str(uuid4()) - - -def generate_transcript_name(): - now = datetime.utcnow() - return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" - - -class AudioWaveform(BaseModel): - data: list[float] - - -class TranscriptText(BaseModel): - text: str - translation: str | None - - -class TranscriptTopic(BaseModel): - id: str = Field(default_factory=generate_uuid4) - title: str - summary: str - transcript: str | None = None - timestamp: float - - -class TranscriptFinalShortSummary(BaseModel): - short_summary: str - - -class TranscriptFinalLongSummary(BaseModel): - long_summary: str - - -class TranscriptFinalTitle(BaseModel): - title: str - - -class TranscriptEvent(BaseModel): - event: str - data: dict - - -class Transcript(BaseModel): - id: str = Field(default_factory=generate_uuid4) - user_id: str | None = None - name: str = Field(default_factory=generate_transcript_name) - status: str = "idle" - locked: bool = False - duration: float = 0 - created_at: datetime = Field(default_factory=datetime.utcnow) - title: str | None = None - short_summary: str | None = None - long_summary: str | None = None - topics: list[TranscriptTopic] = [] - 
events: list[TranscriptEvent] = [] - source_language: str = "en" - target_language: str = "en" - - def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: - ev = TranscriptEvent(event=event, data=data.model_dump()) - self.events.append(ev) - return ev - - def upsert_topic(self, topic: TranscriptTopic): - existing_topic = next((t for t in self.topics if t.id == topic.id), None) - if existing_topic: - existing_topic.update_from(topic) - else: - self.topics.append(topic) - - def events_dump(self, mode="json"): - return [event.model_dump(mode=mode) for event in self.events] - - def topics_dump(self, mode="json"): - return [topic.model_dump(mode=mode) for topic in self.topics] - - def convert_audio_to_waveform(self, segments_count=256): - fn = self.audio_waveform_filename - if fn.exists(): - return - waveform = get_audio_waveform( - path=self.audio_mp3_filename, segments_count=segments_count - ) - try: - with open(fn, "w") as fd: - json.dump(waveform, fd) - except Exception: - # remove file if anything happen during the write - fn.unlink(missing_ok=True) - raise - return waveform - - def unlink(self): - self.data_path.unlink(missing_ok=True) - - @property - def data_path(self): - return Path(settings.DATA_DIR) / self.id - - @property - def audio_mp3_filename(self): - return self.data_path / "audio.mp3" - - @property - def audio_waveform_filename(self): - return self.data_path / "audio.json" - - @property - def audio_waveform(self): - try: - with open(self.audio_waveform_filename) as fd: - data = json.load(fd) - except json.JSONDecodeError: - # unlink file if it's corrupted - self.audio_waveform_filename.unlink(missing_ok=True) - return None - - return AudioWaveform(data=data) - - -class TranscriptController: - async def get_all( - self, - user_id: str | None = None, - order_by: str | None = None, - filter_empty: bool | None = False, - filter_recording: bool | None = False, - ) -> list[Transcript]: - query = transcripts.select().where(transcripts.c.user_id 
== user_id) - - if order_by is not None: - field = getattr(transcripts.c, order_by[1:]) - if order_by.startswith("-"): - field = field.desc() - query = query.order_by(field) - - if filter_empty: - query = query.filter(transcripts.c.status != "idle") - - if filter_recording: - query = query.filter(transcripts.c.status != "recording") - - results = await database.fetch_all(query) - return results - - async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None: - query = transcripts.select().where(transcripts.c.id == transcript_id) - if "user_id" in kwargs: - query = query.where(transcripts.c.user_id == kwargs["user_id"]) - result = await database.fetch_one(query) - if not result: - return None - return Transcript(**result) - - async def add( - self, - name: str, - source_language: str = "en", - target_language: str = "en", - user_id: str | None = None, - ): - transcript = Transcript( - name=name, - source_language=source_language, - target_language=target_language, - user_id=user_id, - ) - query = transcripts.insert().values(**transcript.model_dump()) - await database.execute(query) - return transcript - - async def update(self, transcript: Transcript, values: dict): - query = ( - transcripts.update() - .where(transcripts.c.id == transcript.id) - .values(**values) - ) - await database.execute(query) - for key, value in values.items(): - setattr(transcript, key, value) - - async def remove_by_id( - self, transcript_id: str, user_id: str | None = None - ) -> None: - transcript = await self.get_by_id(transcript_id, user_id=user_id) - if not transcript: - return - if user_id is not None and transcript.user_id != user_id: - return - transcript.unlink() - query = transcripts.delete().where(transcripts.c.id == transcript_id) - await database.execute(query) - - -transcripts_controller = TranscriptController() +def create_access_token(data: dict, expires_delta: timedelta): + to_encode = data.copy() + expire = datetime.utcnow() + expires_delta + 
to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt # ============================================================== @@ -298,6 +110,55 @@ async def transcripts_create( # ============================================================== +class GetTranscriptSegmentTopic(BaseModel): + text: str + start: float + speaker: int + + +class GetTranscriptTopic(BaseModel): + id: str + title: str + summary: str + timestamp: float + transcript: str + segments: list[GetTranscriptSegmentTopic] = [] + + @classmethod + def from_transcript_topic(cls, topic: TranscriptTopic): + if not topic.words: + # In previous version, words were missing + # Just output a segment with speaker 0 + text = topic.transcript + segments = [ + GetTranscriptSegmentTopic( + text=topic.transcript, + start=topic.timestamp, + speaker=0, + ) + ] + else: + # New versions include words + transcript = ProcessorTranscript(words=topic.words) + text = transcript.text + segments = [ + GetTranscriptSegmentTopic( + text=segment.text, + start=segment.start, + speaker=segment.speaker, + ) + for segment in transcript.as_segments() + ] + return cls( + id=topic.id, + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + transcript=text, + segments=segments, + ) + + @router.get("/transcripts/{transcript_id}", response_model=GetTranscript) async def transcript_get( transcript_id: str, @@ -320,32 +181,17 @@ async def transcript_update( transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - values = {"events": []} + values = {} if info.name is not None: values["name"] = info.name if info.locked is not None: values["locked"] = info.locked if info.long_summary is not None: values["long_summary"] = info.long_summary - for transcript_event in transcript.events: - if transcript_event["event"] == 
PipelineEvent.FINAL_LONG_SUMMARY: - transcript_event["long_summary"] = info.long_summary - break - values["events"].extend(transcript.events) if info.short_summary is not None: values["short_summary"] = info.short_summary - for transcript_event in transcript.events: - if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY: - transcript_event["short_summary"] = info.short_summary - break - values["events"].extend(transcript.events) if info.title is not None: values["title"] = info.title - for transcript_event in transcript.events: - if transcript_event["event"] == PipelineEvent.FINAL_TITLE: - transcript_event["title"] = info.title - break - values["events"].extend(transcript.events) await transcripts_controller.update(transcript, values) return transcript @@ -368,8 +214,21 @@ async def transcript_get_audio_mp3( request: Request, transcript_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + token: str | None = None, ): user_id = user["sub"] if user else None + if not user_id and token: + unauthorized_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + user_id: str = payload.get("sub") + except jwt.JWTError: + raise unauthorized_exception + transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") @@ -406,7 +265,10 @@ async def transcript_get_audio_waveform( return transcript.audio_waveform -@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) +@router.get( + "/transcripts/{transcript_id}/topics", + response_model=list[GetTranscriptTopic], +) async def transcript_get_topics( transcript_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], @@ -415,7 +277,16 
@@ async def transcript_get_topics( transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - return transcript.topics + + # convert to GetTranscriptTopic + return [ + GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics + ] + + +# ============================================================== +# Websocket +# ============================================================== @router.get("/transcripts/{transcript_id}/events") @@ -423,41 +294,6 @@ async def transcript_get_websocket_events(transcript_id: str): pass -# ============================================================== -# Websocket Manager -# ============================================================== - - -class WebsocketManager: - def __init__(self): - self.active_connections = {} - - async def connect(self, transcript_id: str, websocket: WebSocket): - await websocket.accept() - if transcript_id not in self.active_connections: - self.active_connections[transcript_id] = [] - self.active_connections[transcript_id].append(websocket) - - def disconnect(self, transcript_id: str, websocket: WebSocket): - if transcript_id not in self.active_connections: - return - self.active_connections[transcript_id].remove(websocket) - if not self.active_connections[transcript_id]: - del self.active_connections[transcript_id] - - async def send_json(self, transcript_id: str, message): - if transcript_id not in self.active_connections: - return - for connection in self.active_connections[transcript_id][:]: - try: - await connection.send_json(message) - except Exception: - self.active_connections[transcript_id].remove(connection) - - -ws_manager = WebsocketManager() - - @router.websocket("/transcripts/{transcript_id}/events") async def transcript_events_websocket( transcript_id: str, @@ -469,21 +305,31 @@ async def transcript_events_websocket( if not transcript: raise 
HTTPException(status_code=404, detail="Transcript not found") - await ws_manager.connect(transcript_id, websocket) + # connect to websocket manager + # use ts:transcript_id as room id + room_id = f"ts:{transcript_id}" + ws_manager = get_ws_manager() + await ws_manager.add_user_to_room(room_id, websocket) - # on first connection, send all events - for event in transcript.events: - await websocket.send_json(event.model_dump(mode="json")) - - # XXX if transcript is final (locked=True and status=ended) - # XXX send a final event to the client and close the connection - - # endless loop to wait for new events try: + # on first connection, send all events only to the current user + for event in transcript.events: + # for now, do not send TRANSCRIPT or STATUS options - theses are live event + # not necessary to be sent to the client; but keep the rest + name = event.event + if name in ("TRANSCRIPT", "STATUS"): + continue + await websocket.send_json(event.model_dump(mode="json")) + + # XXX if transcript is final (locked=True and status=ended) + # XXX send a final event to the client and close the connection + + # endless loop to wait for new events + # we do not have command system now, while True: await websocket.receive() except (RuntimeError, WebSocketDisconnect): - ws_manager.disconnect(transcript_id, websocket) + await ws_manager.remove_user_from_room(room_id, websocket) # ============================================================== @@ -491,105 +337,6 @@ async def transcript_events_websocket( # ============================================================== -async def handle_rtc_event(event: PipelineEvent, args, data): - # OFC the current implementation is not good, - # but it's just a POC before persistence. It won't query the - # transcript from the database for each event. 
- # print(f"Event: {event}", args, data) - transcript_id = args - transcript = await transcripts_controller.get_by_id(transcript_id) - if not transcript: - return - - # event send to websocket clients may not be the same as the event - # received from the pipeline. For example, the pipeline will send - # a TRANSCRIPT event with all words, but this is not what we want - # to send to the websocket client. - - # FIXME don't do copy - if event == PipelineEvent.TRANSCRIPT: - resp = transcript.add_event( - event=event, - data=TranscriptText(text=data.text, translation=data.translation), - ) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - }, - ) - - elif event == PipelineEvent.TOPIC: - topic = TranscriptTopic( - title=data.title, - summary=data.summary, - transcript=data.transcript.text, - timestamp=data.timestamp, - ) - resp = transcript.add_event(event=event, data=topic) - transcript.upsert_topic(topic) - - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "topics": transcript.topics_dump(), - }, - ) - - elif event == PipelineEvent.FINAL_TITLE: - final_title = TranscriptFinalTitle(title=data.title) - resp = transcript.add_event(event=event, data=final_title) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "title": final_title.title, - }, - ) - - elif event == PipelineEvent.FINAL_LONG_SUMMARY: - final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) - resp = transcript.add_event(event=event, data=final_long_summary) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "long_summary": final_long_summary.long_summary, - }, - ) - - elif event == PipelineEvent.FINAL_SHORT_SUMMARY: - final_short_summary = TranscriptFinalShortSummary( - short_summary=data.short_summary - ) - resp = transcript.add_event(event=event, data=final_short_summary) - await 
transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "short_summary": final_short_summary.short_summary, - }, - ) - - elif event == PipelineEvent.STATUS: - resp = transcript.add_event(event=event, data=data) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "status": data.value, - }, - ) - - else: - logger.warning(f"Unknown event: {event}") - return - - # transmit to websocket clients - await ws_manager.send_json(transcript_id, resp.model_dump(mode="json")) - - @router.post("/transcripts/{transcript_id}/record/webrtc") async def transcript_record_webrtc( transcript_id: str, @@ -605,13 +352,14 @@ async def transcript_record_webrtc( if transcript.locked: raise HTTPException(status_code=400, detail="Transcript is locked") + # create a pipeline runner + from reflector.pipelines.main_live_pipeline import PipelineMainLive + + pipeline_runner = PipelineMainLive(transcript_id=transcript_id) + # FIXME do not allow multiple recording at the same time return await rtc_offer_base( params, request, - event_callback=handle_rtc_event, - event_callback_args=transcript_id, - audio_filename=transcript.audio_mp3_filename, - source_language=transcript.source_language, - target_language=transcript.target_language, + pipeline_runner=pipeline_runner, ) diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py new file mode 100644 index 00000000..e1000364 --- /dev/null +++ b/server/reflector/worker/app.py @@ -0,0 +1,12 @@ +from celery import Celery +from reflector.settings import settings + +app = Celery(__name__) +app.conf.broker_url = settings.CELERY_BROKER_URL +app.conf.result_backend = settings.CELERY_RESULT_BACKEND +app.conf.broker_connection_retry_on_startup = True +app.autodiscover_tasks( + [ + "reflector.pipelines.main_live_pipeline", + ] +) diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py new file mode 100644 index 00000000..a84e3361 --- /dev/null +++ 
b/server/reflector/ws_manager.py @@ -0,0 +1,126 @@ +""" +Websocket manager +================= + +This module contains the WebsocketManager class, which is responsible for +managing websockets and handling websocket connections. + +It uses the RedisPubSubManager class to subscribe to Redis channels and +broadcast messages to all connected websockets. +""" + +import asyncio +import json +import threading + +import redis.asyncio as redis +from fastapi import WebSocket +from reflector.settings import settings + + +class RedisPubSubManager: + def __init__(self, host="localhost", port=6379): + self.redis_host = host + self.redis_port = port + self.redis_connection = None + self.pubsub = None + + async def get_redis_connection(self) -> redis.Redis: + return redis.Redis( + host=self.redis_host, + port=self.redis_port, + auto_close_connection_pool=False, + ) + + async def connect(self) -> None: + if self.redis_connection is not None: + return + self.redis_connection = await self.get_redis_connection() + self.pubsub = self.redis_connection.pubsub() + + async def disconnect(self) -> None: + if self.redis_connection is None: + return + await self.redis_connection.close() + self.redis_connection = None + + async def send_json(self, room_id: str, message: str) -> None: + if not self.redis_connection: + await self.connect() + message = json.dumps(message) + await self.redis_connection.publish(room_id, message) + + async def subscribe(self, room_id: str) -> redis.Redis: + await self.pubsub.subscribe(room_id) + return self.pubsub + + async def unsubscribe(self, room_id: str) -> None: + await self.pubsub.unsubscribe(room_id) + + +class WebsocketManager: + def __init__(self, pubsub_client: RedisPubSubManager = None): + self.rooms: dict = {} + self.pubsub_client = pubsub_client + + async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None: + await websocket.accept() + + if room_id in self.rooms: + self.rooms[room_id].append(websocket) + else: + self.rooms[room_id] = 
[websocket] + + await self.pubsub_client.connect() + pubsub_subscriber = await self.pubsub_client.subscribe(room_id) + asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber)) + + async def send_json(self, room_id: str, message: dict) -> None: + await self.pubsub_client.send_json(room_id, message) + + async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None: + self.rooms[room_id].remove(websocket) + + if len(self.rooms[room_id]) == 0: + del self.rooms[room_id] + await self.pubsub_client.unsubscribe(room_id) + + async def _pubsub_data_reader(self, pubsub_subscriber): + while True: + message = await pubsub_subscriber.get_message( + ignore_subscribe_messages=True + ) + if message is not None: + room_id = message["channel"].decode("utf-8") + all_sockets = self.rooms[room_id] + for socket in all_sockets: + data = json.loads(message["data"].decode("utf-8")) + await socket.send_json(data) + + +def get_ws_manager() -> WebsocketManager: + """ + Returns the WebsocketManager instance for managing websockets. + + This function initializes and returns the WebsocketManager instance, + which is responsible for managing websockets and handling websocket + connections. + + Returns: + WebsocketManager: The initialized WebsocketManager instance. + + Raises: + ImportError: If the 'reflector.settings' module cannot be imported. + RedisConnectionError: If there is an error connecting to the Redis server. 
+ """ + local = threading.local() + if hasattr(local, "ws_manager"): + return local.ws_manager + + pubsub_client = RedisPubSubManager( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + ) + ws_manager = WebsocketManager(pubsub_client=pubsub_client) + local.ws_manager = ws_manager + return ws_manager diff --git a/server/runserver.sh b/server/runserver.sh index 38eafe09..b0c3f138 100755 --- a/server/runserver.sh +++ b/server/runserver.sh @@ -4,4 +4,11 @@ if [ -f "/venv/bin/activate" ]; then source /venv/bin/activate fi alembic upgrade head -python -m reflector.app + +if [ "${ENTRYPOINT}" = "server" ]; then + python -m reflector.app +elif [ "${ENTRYPOINT}" = "worker" ]; then + celery -A reflector.worker.app worker --loglevel=info +else + echo "Unknown command" +fi diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 76b56abf..aafca9fd 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -45,28 +45,50 @@ async def dummy_transcript(): from reflector.processors.types import AudioFile, Transcript, Word class TestAudioTranscriptProcessor(AudioTranscriptProcessor): - async def _transcript(self, data: AudioFile): - source_language = self.get_pref("audio:source_language", "en") - print("transcripting", source_language) - print("pipeline", self.pipeline) - print("prefs", self.pipeline.prefs) + _time_idx = 0 + async def _transcript(self, data: AudioFile): + i = self._time_idx + self._time_idx += 2 return Transcript( text="Hello world.", words=[ - Word(start=0.0, end=1.0, text="Hello"), - Word(start=1.0, end=2.0, text=" world."), + Word(start=i, end=i + 1, text="Hello", speaker=0), + Word(start=i + 1, end=i + 2, text=" world.", speaker=0), ], ) with patch( "reflector.processors.audio_transcript_auto" - ".AudioTranscriptAutoProcessor.get_instance" + ".AudioTranscriptAutoProcessor.__new__" ) as mock_audio: mock_audio.return_value = TestAudioTranscriptProcessor() yield +@pytest.fixture +async def dummy_diarization(): + from 
reflector.processors.audio_diarization import AudioDiarizationProcessor + + class TestAudioDiarizationProcessor(AudioDiarizationProcessor): + _time_idx = 0 + + async def _diarize(self, data): + i = self._time_idx + self._time_idx += 2 + return [ + {"start": i, "end": i + 1, "speaker": 0}, + {"start": i + 1, "end": i + 2, "speaker": 1}, + ] + + with patch( + "reflector.processors.audio_diarization_auto" + ".AudioDiarizationAutoProcessor.__new__" + ) as mock_audio: + mock_audio.return_value = TestAudioDiarizationProcessor() + yield + + @pytest.fixture async def dummy_llm(): from reflector.llm.base import LLM @@ -98,7 +120,17 @@ def ensure_casing(): @pytest.fixture def sentence_tokenize(): with patch( - "reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize" + "reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize" ) as mock_sent_tokenize: mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"] yield + + +@pytest.fixture(scope="session") +def celery_enable_logging(): + return True + + +@pytest.fixture(scope="session") +def celery_config(): + return {"broker_url": "memory://", "result_backend": "rpc"} diff --git a/server/tests/test_processor_transcript_segment.py b/server/tests/test_processor_transcript_segment.py new file mode 100644 index 00000000..6fde0dd1 --- /dev/null +++ b/server/tests/test_processor_transcript_segment.py @@ -0,0 +1,161 @@ +def test_processor_transcript_segment(): + from reflector.processors.types import Transcript, Word + + transcript = Transcript( + words=[ + Word(text=" the", start=5.12, end=5.48, speaker=0), + Word(text=" different", start=5.48, end=5.8, speaker=0), + Word(text=" projects", start=5.8, end=6.3, speaker=0), + Word(text=" that", start=6.3, end=6.5, speaker=0), + Word(text=" are", start=6.5, end=6.58, speaker=0), + Word(text=" going", start=6.58, end=6.82, speaker=0), + Word(text=" on", start=6.82, end=7.26, speaker=0), + Word(text=" to", start=7.26, end=7.4, speaker=0), + Word(text=" 
give", start=7.4, end=7.54, speaker=0), + Word(text=" you", start=7.54, end=7.9, speaker=0), + Word(text=" context", start=7.9, end=8.24, speaker=0), + Word(text=" and", start=8.24, end=8.66, speaker=0), + Word(text=" I", start=8.66, end=8.72, speaker=0), + Word(text=" think", start=8.72, end=8.82, speaker=0), + Word(text=" that's", start=8.82, end=9.04, speaker=0), + Word(text=" what", start=9.04, end=9.12, speaker=0), + Word(text=" we'll", start=9.12, end=9.24, speaker=0), + Word(text=" do", start=9.24, end=9.32, speaker=0), + Word(text=" this", start=9.32, end=9.52, speaker=0), + Word(text=" week.", start=9.52, end=9.76, speaker=0), + Word(text=" Um,", start=10.24, end=10.62, speaker=0), + Word(text=" so,", start=11.36, end=11.94, speaker=0), + Word(text=" um,", start=12.46, end=12.92, speaker=0), + Word(text=" what", start=13.74, end=13.94, speaker=0), + Word(text=" we're", start=13.94, end=14.1, speaker=0), + Word(text=" going", start=14.1, end=14.24, speaker=0), + Word(text=" to", start=14.24, end=14.34, speaker=0), + Word(text=" do", start=14.34, end=14.8, speaker=0), + Word(text=" at", start=14.8, end=14.98, speaker=0), + Word(text=" H", start=14.98, end=15.04, speaker=0), + Word(text=" of", start=15.04, end=15.16, speaker=0), + Word(text=" you,", start=15.16, end=15.26, speaker=0), + Word(text=" maybe.", start=15.28, end=15.34, speaker=0), + Word(text=" you", start=15.36, end=15.52, speaker=0), + Word(text=" can", start=15.52, end=15.62, speaker=0), + Word(text=" introduce", start=15.62, end=15.98, speaker=0), + Word(text=" yourself", start=15.98, end=16.42, speaker=0), + Word(text=" to", start=16.42, end=16.68, speaker=0), + Word(text=" the", start=16.68, end=16.72, speaker=0), + Word(text=" team", start=16.72, end=17.52, speaker=0), + Word(text=" quickly", start=17.87, end=18.65, speaker=0), + Word(text=" and", start=18.65, end=19.63, speaker=0), + Word(text=" Oh,", start=20.91, end=21.55, speaker=0), + Word(text=" this", start=21.67, end=21.83, 
speaker=0), + Word(text=" is", start=21.83, end=22.17, speaker=0), + Word(text=" a", start=22.17, end=22.35, speaker=0), + Word(text=" reflector", start=22.35, end=22.89, speaker=0), + Word(text=" translating", start=22.89, end=23.33, speaker=0), + Word(text=" into", start=23.33, end=23.73, speaker=0), + Word(text=" French", start=23.73, end=23.95, speaker=0), + Word(text=" for", start=23.95, end=24.15, speaker=0), + Word(text=" me.", start=24.15, end=24.43, speaker=0), + Word(text=" This", start=27.87, end=28.19, speaker=0), + Word(text=" is", start=28.19, end=28.45, speaker=0), + Word(text=" all", start=28.45, end=28.79, speaker=0), + Word(text=" the", start=28.79, end=29.15, speaker=0), + Word(text=" way,", start=29.15, end=29.15, speaker=0), + Word(text=" please,", start=29.53, end=29.59, speaker=0), + Word(text=" please,", start=29.73, end=29.77, speaker=0), + Word(text=" please,", start=29.77, end=29.83, speaker=0), + Word(text=" please.", start=29.83, end=29.97, speaker=0), + Word(text=" Yeah,", start=29.97, end=30.17, speaker=0), + Word(text=" that's", start=30.25, end=30.33, speaker=0), + Word(text=" all", start=30.33, end=30.49, speaker=0), + Word(text=" it's", start=30.49, end=30.69, speaker=0), + Word(text=" right.", start=30.69, end=30.69, speaker=0), + Word(text=" Right.", start=30.72, end=30.98, speaker=1), + Word(text=" Yeah,", start=31.56, end=31.72, speaker=2), + Word(text=" that's", start=31.86, end=31.98, speaker=2), + Word(text=" right.", start=31.98, end=32.2, speaker=2), + Word(text=" Because", start=32.38, end=32.46, speaker=0), + Word(text=" I", start=32.46, end=32.58, speaker=0), + Word(text=" thought", start=32.58, end=32.78, speaker=0), + Word(text=" I'd", start=32.78, end=33.0, speaker=0), + Word(text=" be", start=33.0, end=33.02, speaker=0), + Word(text=" able", start=33.02, end=33.18, speaker=0), + Word(text=" to", start=33.18, end=33.34, speaker=0), + Word(text=" pull", start=33.34, end=33.52, speaker=0), + Word(text=" out.", 
start=33.52, end=33.68, speaker=0), + Word(text=" Yeah,", start=33.7, end=33.9, speaker=0), + Word(text=" that", start=33.9, end=34.02, speaker=0), + Word(text=" was", start=34.02, end=34.24, speaker=0), + Word(text=" the", start=34.24, end=34.34, speaker=0), + Word(text=" one", start=34.34, end=34.44, speaker=0), + Word(text=" before", start=34.44, end=34.7, speaker=0), + Word(text=" that.", start=34.7, end=35.24, speaker=0), + Word(text=" Friends,", start=35.84, end=36.46, speaker=0), + Word(text=" if", start=36.64, end=36.7, speaker=0), + Word(text=" you", start=36.7, end=36.7, speaker=0), + Word(text=" have", start=36.7, end=37.24, speaker=0), + Word(text=" tell", start=37.24, end=37.44, speaker=0), + Word(text=" us", start=37.44, end=37.68, speaker=0), + Word(text=" if", start=37.68, end=37.82, speaker=0), + Word(text=" it's", start=37.82, end=38.04, speaker=0), + Word(text=" good,", start=38.04, end=38.58, speaker=0), + Word(text=" exceptionally", start=38.96, end=39.1, speaker=0), + Word(text=" good", start=39.1, end=39.6, speaker=0), + Word(text=" and", start=39.6, end=39.86, speaker=0), + Word(text=" tell", start=39.86, end=40.0, speaker=0), + Word(text=" us", start=40.0, end=40.06, speaker=0), + Word(text=" when", start=40.06, end=40.2, speaker=0), + Word(text=" it's", start=40.2, end=40.34, speaker=0), + Word(text=" exceptionally", start=40.34, end=40.6, speaker=0), + Word(text=" bad.", start=40.6, end=40.94, speaker=0), + Word(text=" We", start=40.96, end=41.26, speaker=0), + Word(text=" don't", start=41.26, end=41.44, speaker=0), + Word(text=" need", start=41.44, end=41.66, speaker=0), + Word(text=" that", start=41.66, end=41.82, speaker=0), + Word(text=" at", start=41.82, end=41.94, speaker=0), + Word(text=" the", start=41.94, end=41.98, speaker=0), + Word(text=" middle", start=41.98, end=42.18, speaker=0), + Word(text=" of", start=42.18, end=42.36, speaker=0), + Word(text=" age.", start=42.36, end=42.7, speaker=0), + Word(text=" Okay,", start=43.26, 
end=43.44, speaker=0), + Word(text=" yeah,", start=43.68, end=43.76, speaker=0), + Word(text=" that", start=43.78, end=44.3, speaker=0), + Word(text=" sentence", start=44.3, end=44.72, speaker=0), + Word(text=" right", start=44.72, end=45.1, speaker=0), + Word(text=" before.", start=45.1, end=45.56, speaker=0), + Word(text=" it", start=46.08, end=46.36, speaker=0), + Word(text=" realizing", start=46.36, end=47.0, speaker=0), + Word(text=" that", start=47.0, end=47.28, speaker=0), + Word(text=" I", start=47.28, end=47.28, speaker=0), + Word(text=" was", start=47.28, end=47.64, speaker=0), + Word(text=" saying", start=47.64, end=48.06, speaker=0), + Word(text=" that", start=48.06, end=48.44, speaker=0), + Word(text=" it's", start=48.44, end=48.54, speaker=0), + Word(text=" interesting", start=48.54, end=48.78, speaker=0), + Word(text=" that", start=48.78, end=48.96, speaker=0), + Word(text=" it's", start=48.96, end=49.08, speaker=0), + Word(text=" translating", start=49.08, end=49.32, speaker=0), + Word(text=" the", start=49.32, end=49.56, speaker=0), + Word(text=" French", start=49.56, end=49.76, speaker=0), + Word(text=" was", start=49.76, end=50.16, speaker=0), + Word(text=" completely", start=50.16, end=50.4, speaker=0), + Word(text=" wrong.", start=50.4, end=50.7, speaker=0), + ] + ) + + segments = transcript.as_segments() + assert len(segments) == 7 + + # check speaker order + assert segments[0].speaker == 0 + assert segments[1].speaker == 0 + assert segments[2].speaker == 0 + assert segments[3].speaker == 1 + assert segments[4].speaker == 2 + assert segments[5].speaker == 0 + assert segments[6].speaker == 0 + + # check the timing (first entry, and first of others speakers) + assert segments[0].start == 5.12 + assert segments[3].start == 30.72 + assert segments[4].start == 31.56 + assert segments[5].start == 32.38 diff --git a/server/tests/test_retry_decorator.py b/server/tests/test_retry_decorator.py index 22729eac..c60a490f 100644 --- 
@pytest.mark.asyncio
async def test_retry_redirect(httpx_mock):
    """The retry timeout must apply per attempt, not across redirects.

    The first response is a deliberately slow 303 redirect; following it
    succeeds, so retry() must NOT raise a timeout even though the overall
    exchange takes longer than retry_timeout.
    """

    async def custom_response(request: httpx.Request):
        if request.url.path == "/hello":
            # Slow down the redirect so the exchange exceeds retry_timeout.
            await asyncio.sleep(1)
            return httpx.Response(
                status_code=303, headers={"location": "https://test_url/redirected"}
            )
        elif request.url.path == "/redirected":
            return httpx.Response(status_code=200, json={"hello": "world"})
        else:
            raise Exception("Unexpected path")

    httpx_mock.add_callback(custom_response)
    async with httpx.AsyncClient() as client:
        # The timeout should not be triggered, as the request ends up OK
        # even though the first response is a 303 and took more than 0.5s.
        resp = await retry(client.get)(
            "https://test_url/hello",
            retry_timeout=0.5,
            follow_redirects=True,
        )
        assert resp.json() == {"hello": "world"}
dummy_diarization, ensure_casing, appserver, sentence_tokenize, @@ -95,6 +103,7 @@ async def test_transcript_rtc_and_websocket( print("Test websocket: DISCONNECTED") websocket_task = asyncio.get_event_loop().create_task(websocket_task()) + print("Test websocket: TASK CREATED", websocket_task) # create stream client import argparse @@ -121,14 +130,20 @@ async def test_transcript_rtc_and_websocket( # XXX aiortc is long to close the connection # instead of waiting a long time, we just send a STOP client.channel.send(json.dumps({"cmd": "STOP"})) - - # wait the processing to finish - await asyncio.sleep(2) - await client.stop() # wait the processing to finish - await asyncio.sleep(2) + timeout = 20 + while True: + # fetch the transcript and check if it is ended + resp = await ac.get(f"/transcripts/{tid}") + assert resp.status_code == 200 + if resp.json()["status"] in ("ended", "error"): + break + await asyncio.sleep(1) + + if resp.json()["status"] != "ended": + raise TimeoutError("Timeout while waiting for transcript to be ended") # stop websocket task websocket_task.cancel() @@ -169,29 +184,28 @@ async def test_transcript_rtc_and_websocket( # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] - assert statuses == ["recording", "processing", "ended"] + assert statuses.index("recording") < statuses.index("processing") + assert statuses.index("processing") < statuses.index("ended") # ensure the last event received is ended assert events[-1]["event"] == "STATUS" assert events[-1]["data"]["value"] == "ended" - # check that transcript status in model is updated - resp = await ac.get(f"/transcripts/{tid}") - assert resp.status_code == 200 - assert resp.json()["status"] == "ended" - # check that audio/mp3 is available resp = await ac.get(f"/transcripts/{tid}/audio/mp3") assert resp.status_code == 200 assert resp.headers["Content-Type"] == "audio/mpeg" +@pytest.mark.usefixtures("celery_session_app") 
+@pytest.mark.usefixtures("celery_session_worker") @pytest.mark.asyncio async def test_transcript_rtc_and_websocket_and_fr( tmpdir, dummy_llm, dummy_transcript, dummy_processors, + dummy_diarization, ensure_casing, appserver, sentence_tokenize, @@ -232,6 +246,7 @@ async def test_transcript_rtc_and_websocket_and_fr( print("Test websocket: DISCONNECTED") websocket_task = asyncio.get_event_loop().create_task(websocket_task()) + print("Test websocket: TASK CREATED", websocket_task) # create stream client import argparse @@ -265,6 +280,18 @@ async def test_transcript_rtc_and_websocket_and_fr( await client.stop() # wait the processing to finish + timeout = 20 + while True: + # fetch the transcript and check if it is ended + resp = await ac.get(f"/transcripts/{tid}") + assert resp.status_code == 200 + if resp.json()["status"] == "ended": + break + await asyncio.sleep(1) + + if resp.json()["status"] != "ended": + raise TimeoutError("Timeout while waiting for transcript to be ended") + await asyncio.sleep(2) # stop websocket task @@ -306,7 +333,8 @@ async def test_transcript_rtc_and_websocket_and_fr( # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] - assert statuses == ["recording", "processing", "ended"] + assert statuses.index("recording") < statuses.index("processing") + assert statuses.index("processing") < statuses.index("ended") # ensure the last event received is ended assert events[-1]["event"] == "STATUS" diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index 97f1a846..d58d5247 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -3,6 +3,7 @@ import Modal from "../modal"; import useTranscript from "../useTranscript"; import useTopics from "../useTopics"; import useWaveform from "../useWaveform"; +import useMp3 from "../useMp3"; import { TopicList } from "../topicList"; import 
Recorder from "../recorder"; import { Topic } from "../webSocketTypes"; @@ -28,6 +29,7 @@ export default function TranscriptDetails(details: TranscriptDetails) { const topics = useTopics(protectedPath, transcriptId); const waveform = useWaveform(protectedPath, transcriptId); const useActiveTopic = useState(null); + const mp3 = useMp3(api, transcriptId); if (transcript?.error /** || topics?.error || waveform?.error **/) { return ( @@ -62,6 +64,7 @@ export default function TranscriptDetails(details: TranscriptDetails) { waveform={waveform?.waveform} isPastMeeting={true} transcriptId={transcript?.response?.id} + mp3Blob={mp3.blob} /> )} diff --git a/www/app/[domain]/transcripts/recorder.tsx b/www/app/[domain]/transcripts/recorder.tsx index dc6bc0a7..765d8f09 100644 --- a/www/app/[domain]/transcripts/recorder.tsx +++ b/www/app/[domain]/transcripts/recorder.tsx @@ -30,6 +30,7 @@ type RecorderProps = { waveform?: AudioWaveform | null; isPastMeeting: boolean; transcriptId?: string | null; + mp3Blob?: Blob | null; }; export default function Recorder(props: RecorderProps) { @@ -108,11 +109,7 @@ export default function Recorder(props: RecorderProps) { if (waveformRef.current) { const _wavesurfer = WaveSurfer.create({ container: waveformRef.current, - url: props.transcriptId - ? 
`${process.env.NEXT_PUBLIC_API_URL}/v1/transcripts/${props.transcriptId}/audio/mp3` - : undefined, peaks: props.waveform?.data, - hideScrollbar: true, autoCenter: true, barWidth: 2, @@ -146,6 +143,10 @@ export default function Recorder(props: RecorderProps) { if (props.isPastMeeting) _wavesurfer.toggleInteraction(true); + if (props.mp3Blob) { + _wavesurfer.loadBlob(props.mp3Blob); + } + setWavesurfer(_wavesurfer); return () => { @@ -157,6 +158,12 @@ export default function Recorder(props: RecorderProps) { } }, []); + useEffect(() => { + if (!wavesurfer) return; + if (!props.mp3Blob) return; + wavesurfer.loadBlob(props.mp3Blob); + }, [props.mp3Blob]); + useEffect(() => { topicsRef.current = props.topics; if (!isRecording) renderMarkers(); diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx index e5de09c8..e7454f79 100644 --- a/www/app/[domain]/transcripts/topicList.tsx +++ b/www/app/[domain]/transcripts/topicList.tsx @@ -7,6 +7,7 @@ import { import { formatTime } from "../../lib/time"; import ScrollToBottom from "./scrollToBottom"; import { Topic } from "./webSocketTypes"; +import { generateHighContrastColor } from "../../lib/utils"; type TopicListProps = { topics: Topic[]; @@ -103,7 +104,37 @@ export function TopicList({ /> {activeTopic?.id == topic.id && ( -
{topic.transcript}
+
+ {topic.segments ? ( + <> + {topic.segments.map((segment, index: number) => ( +

+ + [{formatTime(segment.start)}] + + + {" "} + (Speaker {segment.speaker}): + {" "} + {segment.text} +

+ ))} + + ) : ( + <>{topic.transcript} + )} +
)} ))} diff --git a/www/app/[domain]/transcripts/useMp3.ts b/www/app/[domain]/transcripts/useMp3.ts index 8bccf903..b7677180 100644 --- a/www/app/[domain]/transcripts/useMp3.ts +++ b/www/app/[domain]/transcripts/useMp3.ts @@ -1,36 +1,64 @@ -import { useEffect, useState } from "react"; +import { useContext, useEffect, useState } from "react"; import { DefaultApi, - V1TranscriptGetAudioMp3Request, + // V1TranscriptGetAudioMp3Request, } from "../../api/apis/DefaultApi"; import {} from "../../api"; import { useError } from "../../(errors)/errorContext"; +import { DomainContext } from "../domainContext"; type Mp3Response = { url: string | null; + blob: Blob | null; loading: boolean; error: Error | null; }; const useMp3 = (api: DefaultApi, id: string): Mp3Response => { const [url, setUrl] = useState(null); + const [blob, setBlob] = useState(null); const [loading, setLoading] = useState(false); const [error, setErrorState] = useState(null); const { setError } = useError(); + const { api_url } = useContext(DomainContext); const getMp3 = (id: string) => { - if (!id) throw new Error("Transcript ID is required to get transcript Mp3"); + if (!id) return; setLoading(true); - const requestParameters: V1TranscriptGetAudioMp3Request = { - transcriptId: id, - }; - api - .v1TranscriptGetAudioMp3(requestParameters) - .then((result) => { - setUrl(result); - setLoading(false); - console.debug("Transcript Mp3 loaded:", result); + // XXX Current API interface does not output a blob, we need to to is manually + // const requestParameters: V1TranscriptGetAudioMp3Request = { + // transcriptId: id, + // }; + // api + // .v1TranscriptGetAudioMp3(requestParameters) + // .then((result) => { + // setUrl(result); + // setLoading(false); + // console.debug("Transcript Mp3 loaded:", result); + // }) + // .catch((err) => { + // setError(err); + // setErrorState(err); + // }); + const localUrl = `${api_url}/v1/transcripts/${id}/audio/mp3`; + if (localUrl == url) return; + const headers = new 
Headers(); + + if (api.configuration.configuration.accessToken) { + headers.set("Authorization", api.configuration.configuration.accessToken); + } + + fetch(localUrl, { + method: "GET", + headers, + }) + .then((response) => { + setUrl(localUrl); + response.blob().then((blob) => { + setBlob(blob); + setLoading(false); + }); }) .catch((err) => { setError(err); @@ -42,7 +70,7 @@ const useMp3 = (api: DefaultApi, id: string): Mp3Response => { getMp3(id); }, [id]); - return { url, loading, error }; + return { url, blob, loading, error }; }; export default useMp3; diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts index 6bd7bf48..5610c2a4 100644 --- a/www/app/[domain]/transcripts/useWebSockets.ts +++ b/www/app/[domain]/transcripts/useWebSockets.ts @@ -58,6 +58,39 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { title: "Topic 1: Introduction to Quantum Mechanics", transcript: "A brief overview of quantum mechanics and its principles.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + , + { + speaker: 3, + start: 90, + text: "This is the third speaker", + }, + { + speaker: 4, + start: 90, + text: "This is the fourth speaker", + }, + { + speaker: 5, + start: 123, + text: "This is the fifth speaker", + }, + { + speaker: 6, + start: 300, + text: "This is the sixth speaker", + }, + ], }, { id: "2", @@ -66,6 +99,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { title: "Topic 2: Machine Learning Algorithms", transcript: "Understanding the different types of machine learning algorithms.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "3", @@ -73,6 +118,18 @@ export const 
useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 3", title: "Topic 3: Mental Health Awareness", transcript: "Ways to improve mental health and reduce stigma.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "4", @@ -80,6 +137,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 4", title: "Topic 4: Basics of Productivity", transcript: "Tips and tricks to increase daily productivity.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "5", @@ -88,6 +157,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { title: "Topic 5: Future of Aviation", transcript: "Exploring the advancements and possibilities in aviation.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, ]); @@ -106,6 +187,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { "Topic 1: Introduction to Quantum Mechanics, a brief overview of quantum mechanics and its principles.", transcript: "A brief overview of quantum mechanics and its principles.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "2", @@ -115,6 +208,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { "Topic 2: Machine Learning Algorithms, understanding the different types of machine learning algorithms.", transcript: "Understanding the different types of machine learning 
algorithms.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "3", @@ -123,6 +228,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { title: "Topic 3: Mental Health Awareness, ways to improve mental health and reduce stigma.", transcript: "Ways to improve mental health and reduce stigma.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "4", @@ -131,6 +248,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { title: "Topic 4: Basics of Productivity, tips and tricks to increase daily productivity.", transcript: "Tips and tricks to increase daily productivity.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "5", @@ -140,6 +269,18 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { "Topic 5: Future of Aviation, exploring the advancements and possibilities in aviation.", transcript: "Exploring the advancements and possibilities in aviation.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, ]); @@ -173,7 +314,17 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { break; case "TOPIC": - setTopics((prevTopics) => [...prevTopics, message.data]); + setTopics((prevTopics) => { + const topic = message.data as Topic; + const index = prevTopics.findIndex( + (prevTopic) => prevTopic.id === topic.id, + ); + if (index >= 0) { + prevTopics[index] = topic; + return prevTopics; + } + return 
[...prevTopics, topic]; + }); console.debug("TOPIC event:", message.data); break; diff --git a/www/app/[domain]/transcripts/webSocketTypes.ts b/www/app/[domain]/transcripts/webSocketTypes.ts index 450b3b1c..edd35eb6 100644 --- a/www/app/[domain]/transcripts/webSocketTypes.ts +++ b/www/app/[domain]/transcripts/webSocketTypes.ts @@ -1,10 +1,6 @@ -export type Topic = { - timestamp: number; - title: string; - transcript: string; - summary: string; - id: string; -}; +import { GetTranscriptTopic } from "../../api"; + +export type Topic = GetTranscriptTopic; export type Transcript = { text: string; diff --git a/www/app/api/.openapi-generator/FILES b/www/app/api/.openapi-generator/FILES index 16763a8d..532a6a16 100644 --- a/www/app/api/.openapi-generator/FILES +++ b/www/app/api/.openapi-generator/FILES @@ -5,10 +5,11 @@ models/AudioWaveform.ts models/CreateTranscript.ts models/DeletionStatus.ts models/GetTranscript.ts +models/GetTranscriptSegmentTopic.ts +models/GetTranscriptTopic.ts models/HTTPValidationError.ts models/PageGetTranscript.ts models/RtcOffer.ts -models/TranscriptTopic.ts models/UpdateTranscript.ts models/UserInfo.ts models/ValidationError.ts diff --git a/www/app/api/apis/DefaultApi.ts b/www/app/api/apis/DefaultApi.ts index d51d42ca..5bb2e7e9 100644 --- a/www/app/api/apis/DefaultApi.ts +++ b/www/app/api/apis/DefaultApi.ts @@ -42,10 +42,6 @@ import { UpdateTranscriptToJSON, } from "../models"; -export interface RtcOfferRequest { - rtcOffer: RtcOffer; -} - export interface V1TranscriptDeleteRequest { transcriptId: any; } @@ -56,6 +52,7 @@ export interface V1TranscriptGetRequest { export interface V1TranscriptGetAudioMp3Request { transcriptId: any; + token?: any; } export interface V1TranscriptGetAudioWaveformRequest { @@ -132,58 +129,6 @@ export class DefaultApi extends runtime.BaseAPI { return await response.value(); } - /** - * Rtc Offer - */ - async rtcOfferRaw( - requestParameters: RtcOfferRequest, - initOverrides?: RequestInit | 
runtime.InitOverrideFunction, - ): Promise> { - if ( - requestParameters.rtcOffer === null || - requestParameters.rtcOffer === undefined - ) { - throw new runtime.RequiredError( - "rtcOffer", - "Required parameter requestParameters.rtcOffer was null or undefined when calling rtcOffer.", - ); - } - - const queryParameters: any = {}; - - const headerParameters: runtime.HTTPHeaders = {}; - - headerParameters["Content-Type"] = "application/json"; - - const response = await this.request( - { - path: `/offer`, - method: "POST", - headers: headerParameters, - query: queryParameters, - body: RtcOfferToJSON(requestParameters.rtcOffer), - }, - initOverrides, - ); - - if (this.isJsonMime(response.headers.get("content-type"))) { - return new runtime.JSONApiResponse(response); - } else { - return new runtime.TextApiResponse(response) as any; - } - } - - /** - * Rtc Offer - */ - async rtcOffer( - requestParameters: RtcOfferRequest, - initOverrides?: RequestInit | runtime.InitOverrideFunction, - ): Promise { - const response = await this.rtcOfferRaw(requestParameters, initOverrides); - return await response.value(); - } - /** * Transcript Delete */ @@ -325,6 +270,10 @@ export class DefaultApi extends runtime.BaseAPI { const queryParameters: any = {}; + if (requestParameters.token !== undefined) { + queryParameters["token"] = requestParameters.token; + } + const headerParameters: runtime.HTTPHeaders = {}; if (this.configuration && this.configuration.accessToken) { diff --git a/www/app/api/models/GetTranscriptSegmentTopic.ts b/www/app/api/models/GetTranscriptSegmentTopic.ts new file mode 100644 index 00000000..cc2049b1 --- /dev/null +++ b/www/app/api/models/GetTranscriptSegmentTopic.ts @@ -0,0 +1,88 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * FastAPI + * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI 
Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + +import { exists, mapValues } from "../runtime"; +/** + * + * @export + * @interface GetTranscriptSegmentTopic + */ +export interface GetTranscriptSegmentTopic { + /** + * + * @type {any} + * @memberof GetTranscriptSegmentTopic + */ + text: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptSegmentTopic + */ + start: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptSegmentTopic + */ + speaker: any | null; +} + +/** + * Check if a given object implements the GetTranscriptSegmentTopic interface. + */ +export function instanceOfGetTranscriptSegmentTopic(value: object): boolean { + let isInstance = true; + isInstance = isInstance && "text" in value; + isInstance = isInstance && "start" in value; + isInstance = isInstance && "speaker" in value; + + return isInstance; +} + +export function GetTranscriptSegmentTopicFromJSON( + json: any, +): GetTranscriptSegmentTopic { + return GetTranscriptSegmentTopicFromJSONTyped(json, false); +} + +export function GetTranscriptSegmentTopicFromJSONTyped( + json: any, + ignoreDiscriminator: boolean, +): GetTranscriptSegmentTopic { + if (json === undefined || json === null) { + return json; + } + return { + text: json["text"], + start: json["start"], + speaker: json["speaker"], + }; +} + +export function GetTranscriptSegmentTopicToJSON( + value?: GetTranscriptSegmentTopic | null, +): any { + if (value === undefined) { + return undefined; + } + if (value === null) { + return null; + } + return { + text: value.text, + start: value.start, + speaker: value.speaker, + }; +} diff --git a/www/app/api/models/GetTranscriptTopic.ts b/www/app/api/models/GetTranscriptTopic.ts new file mode 100644 index 00000000..460b8b39 --- /dev/null +++ b/www/app/api/models/GetTranscriptTopic.ts @@ -0,0 +1,112 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * FastAPI + * No description provided (generated by 
Openapi Generator https://github.com/openapitools/openapi-generator) + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + +import { exists, mapValues } from "../runtime"; +/** + * + * @export + * @interface GetTranscriptTopic + */ +export interface GetTranscriptTopic { + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + id: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + title: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + summary: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + timestamp: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + transcript: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + segments?: any | null; +} + +/** + * Check if a given object implements the GetTranscriptTopic interface. + */ +export function instanceOfGetTranscriptTopic(value: object): boolean { + let isInstance = true; + isInstance = isInstance && "id" in value; + isInstance = isInstance && "title" in value; + isInstance = isInstance && "summary" in value; + isInstance = isInstance && "timestamp" in value; + isInstance = isInstance && "transcript" in value; + + return isInstance; +} + +export function GetTranscriptTopicFromJSON(json: any): GetTranscriptTopic { + return GetTranscriptTopicFromJSONTyped(json, false); +} + +export function GetTranscriptTopicFromJSONTyped( + json: any, + ignoreDiscriminator: boolean, +): GetTranscriptTopic { + if (json === undefined || json === null) { + return json; + } + return { + id: json["id"], + title: json["title"], + summary: json["summary"], + timestamp: json["timestamp"], + transcript: json["transcript"], + segments: !exists(json, "segments") ? 
undefined : json["segments"], + }; +} + +export function GetTranscriptTopicToJSON( + value?: GetTranscriptTopic | null, +): any { + if (value === undefined) { + return undefined; + } + if (value === null) { + return null; + } + return { + id: value.id, + title: value.title, + summary: value.summary, + timestamp: value.timestamp, + transcript: value.transcript, + segments: value.segments, + }; +} diff --git a/www/app/api/models/TranscriptSegmentTopic.ts b/www/app/api/models/TranscriptSegmentTopic.ts new file mode 100644 index 00000000..73496a67 --- /dev/null +++ b/www/app/api/models/TranscriptSegmentTopic.ts @@ -0,0 +1,88 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * FastAPI + * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + +import { exists, mapValues } from "../runtime"; +/** + * + * @export + * @interface TranscriptSegmentTopic + */ +export interface TranscriptSegmentTopic { + /** + * + * @type {any} + * @memberof TranscriptSegmentTopic + */ + speaker: any | null; + /** + * + * @type {any} + * @memberof TranscriptSegmentTopic + */ + text: any | null; + /** + * + * @type {any} + * @memberof TranscriptSegmentTopic + */ + timestamp: any | null; +} + +/** + * Check if a given object implements the TranscriptSegmentTopic interface. 
+ */ +export function instanceOfTranscriptSegmentTopic(value: object): boolean { + let isInstance = true; + isInstance = isInstance && "speaker" in value; + isInstance = isInstance && "text" in value; + isInstance = isInstance && "timestamp" in value; + + return isInstance; +} + +export function TranscriptSegmentTopicFromJSON( + json: any, +): TranscriptSegmentTopic { + return TranscriptSegmentTopicFromJSONTyped(json, false); +} + +export function TranscriptSegmentTopicFromJSONTyped( + json: any, + ignoreDiscriminator: boolean, +): TranscriptSegmentTopic { + if (json === undefined || json === null) { + return json; + } + return { + speaker: json["speaker"], + text: json["text"], + timestamp: json["timestamp"], + }; +} + +export function TranscriptSegmentTopicToJSON( + value?: TranscriptSegmentTopic | null, +): any { + if (value === undefined) { + return undefined; + } + if (value === null) { + return null; + } + return { + speaker: value.speaker, + text: value.text, + timestamp: value.timestamp, + }; +} diff --git a/www/app/api/models/TranscriptTopic.ts b/www/app/api/models/TranscriptTopic.ts index 8b395374..f22496b2 100644 --- a/www/app/api/models/TranscriptTopic.ts +++ b/www/app/api/models/TranscriptTopic.ts @@ -42,13 +42,13 @@ export interface TranscriptTopic { * @type {any} * @memberof TranscriptTopic */ - transcript?: any | null; + timestamp: any | null; /** * * @type {any} * @memberof TranscriptTopic */ - timestamp: any | null; + segments?: any | null; } /** @@ -78,8 +78,8 @@ export function TranscriptTopicFromJSONTyped( id: !exists(json, "id") ? undefined : json["id"], title: json["title"], summary: json["summary"], - transcript: !exists(json, "transcript") ? undefined : json["transcript"], timestamp: json["timestamp"], + segments: !exists(json, "segments") ? 
undefined : json["segments"], }; } @@ -94,7 +94,7 @@ export function TranscriptTopicToJSON(value?: TranscriptTopic | null): any { id: value.id, title: value.title, summary: value.summary, - transcript: value.transcript, timestamp: value.timestamp, + segments: value.segments, }; } diff --git a/www/app/api/models/index.ts b/www/app/api/models/index.ts index 99874641..9e691b09 100644 --- a/www/app/api/models/index.ts +++ b/www/app/api/models/index.ts @@ -4,10 +4,11 @@ export * from "./AudioWaveform"; export * from "./CreateTranscript"; export * from "./DeletionStatus"; export * from "./GetTranscript"; +export * from "./GetTranscriptSegmentTopic"; +export * from "./GetTranscriptTopic"; export * from "./HTTPValidationError"; export * from "./PageGetTranscript"; export * from "./RtcOffer"; -export * from "./TranscriptTopic"; export * from "./UpdateTranscript"; export * from "./UserInfo"; export * from "./ValidationError"; diff --git a/www/app/lib/utils.ts b/www/app/lib/utils.ts index db775f07..6b72ddea 100644 --- a/www/app/lib/utils.ts +++ b/www/app/lib/utils.ts @@ -1,3 +1,123 @@ export function isDevelopment() { return process.env.NEXT_PUBLIC_ENV === "development"; } + +// Function to calculate WCAG contrast ratio +export const getContrastRatio = ( + foreground: [number, number, number], + background: [number, number, number], +) => { + const [r1, g1, b1] = foreground; + const [r2, g2, b2] = background; + + const lum1 = + 0.2126 * Math.pow(r1 / 255, 2.2) + + 0.7152 * Math.pow(g1 / 255, 2.2) + + 0.0722 * Math.pow(b1 / 255, 2.2); + const lum2 = + 0.2126 * Math.pow(r2 / 255, 2.2) + + 0.7152 * Math.pow(g2 / 255, 2.2) + + 0.0722 * Math.pow(b2 / 255, 2.2); + + return (Math.max(lum1, lum2) + 0.05) / (Math.min(lum1, lum2) + 0.05); +}; + +// Function to hash string into 32-bit integer +// 🔴 DO NOT USE FOR CRYPTOGRAPHY PURPOSES 🔴 + +export function murmurhash3_32_gc(key: string, seed: number = 0) { + let remainder, bytes, h1, h1b, c1, c2, k1, i; + + remainder = key.length & 3; // 
key.length % 4 + bytes = key.length - remainder; + h1 = seed; + c1 = 0xcc9e2d51; + c2 = 0x1b873593; + i = 0; + + while (i < bytes) { + k1 = + (key.charCodeAt(i) & 0xff) | + ((key.charCodeAt(++i) & 0xff) << 8) | + ((key.charCodeAt(++i) & 0xff) << 16) | + ((key.charCodeAt(++i) & 0xff) << 24); + + ++i; + + k1 = + ((k1 & 0xffff) * c1 + ((((k1 >>> 16) * c1) & 0xffff) << 16)) & 0xffffffff; + k1 = (k1 << 15) | (k1 >>> 17); + k1 = + ((k1 & 0xffff) * c2 + ((((k1 >>> 16) * c2) & 0xffff) << 16)) & 0xffffffff; + + h1 ^= k1; + h1 = (h1 << 13) | (h1 >>> 19); + h1b = + ((h1 & 0xffff) * 5 + ((((h1 >>> 16) * 5) & 0xffff) << 16)) & 0xffffffff; + h1 = (h1b & 0xffff) + 0x6b64 + ((((h1b >>> 16) + 0xe654) & 0xffff) << 16); + } + + k1 = 0; + + switch (remainder) { + case 3: + k1 ^= (key.charCodeAt(i + 2) & 0xff) << 16; + case 2: + k1 ^= (key.charCodeAt(i + 1) & 0xff) << 8; + case 1: + k1 ^= key.charCodeAt(i) & 0xff; + + k1 = + ((k1 & 0xffff) * c1 + ((((k1 >>> 16) * c1) & 0xffff) << 16)) & + 0xffffffff; + k1 = (k1 << 15) | (k1 >>> 17); + k1 = + ((k1 & 0xffff) * c2 + ((((k1 >>> 16) * c2) & 0xffff) << 16)) & + 0xffffffff; + h1 ^= k1; + } + + h1 ^= key.length; + + h1 ^= h1 >>> 16; + h1 = + ((h1 & 0xffff) * 0x85ebca6b + + ((((h1 >>> 16) * 0x85ebca6b) & 0xffff) << 16)) & + 0xffffffff; + h1 ^= h1 >>> 13; + h1 = + ((h1 & 0xffff) * 0xc2b2ae35 + + ((((h1 >>> 16) * 0xc2b2ae35) & 0xffff) << 16)) & + 0xffffffff; + h1 ^= h1 >>> 16; + + return h1 >>> 0; +} + +// Generates a color that is guaranteed to have high contrast with the given background color (optional) + +export const generateHighContrastColor = ( + name: string, + backgroundColor: [number, number, number] | null = null, +) => { + const hash = murmurhash3_32_gc(name); + let red = (hash & 0xff0000) >> 16; + let green = (hash & 0x00ff00) >> 8; + let blue = hash & 0x0000ff; + + const getCssColor = (red: number, green: number, blue: number) => + `rgb(${red}, ${green}, ${blue})`; + + if (!backgroundColor) return getCssColor(red, green, blue); + + 
const contrast = getContrastRatio([red, green, blue], backgroundColor); + + // Adjust the color to achieve better contrast if necessary (WCAG recommends at least 4.5:1 for text) + if (contrast < 4.5) { + red = Math.abs(255 - red); + green = Math.abs(255 - green); + blue = Math.abs(255 - blue); + } + + return getCssColor(red, green, blue); +}; diff --git a/www/package.json b/www/package.json index edbc0790..55c7df73 100644 --- a/www/package.json +++ b/www/package.json @@ -35,7 +35,7 @@ "supports-color": "^9.4.0", "tailwindcss": "^3.3.2", "typescript": "^5.1.6", - "wavesurfer.js": "^7.0.3" + "wavesurfer.js": "^7.4.2" }, "main": "index.js", "repository": "https://github.com/Monadical-SAS/reflector-ui.git", diff --git a/www/pages/forbidden.tsx b/www/pages/forbidden.tsx index 31a746fc..ada3d424 100644 --- a/www/pages/forbidden.tsx +++ b/www/pages/forbidden.tsx @@ -1,7 +1,7 @@ import type { NextPage } from "next"; const Forbidden: NextPage = () => { - return

Sorry, you are not authorized to access this page.

; + return

Sorry, you are not authorized to access this page.

; }; export default Forbidden; diff --git a/www/yarn.lock b/www/yarn.lock index a67822be..8ec03382 100644 --- a/www/yarn.lock +++ b/www/yarn.lock @@ -2638,10 +2638,10 @@ watchpack@2.4.0: glob-to-regexp "^0.4.1" graceful-fs "^4.1.2" -wavesurfer.js@^7.0.3: - version "7.0.3" - resolved "https://registry.npmjs.org/wavesurfer.js/-/wavesurfer.js-7.0.3.tgz" - integrity sha512-gJ3P+Bd3Q4E8qETjjg0pneaVqm2J7jegG2Cc6vqEF5YDDKQ3m8sKsvVfgVhJkacKkO9jFAGDu58Hw4zLr7xD0A== +wavesurfer.js@^7.4.2: + version "7.4.2" + resolved "https://registry.yarnpkg.com/wavesurfer.js/-/wavesurfer.js-7.4.2.tgz#59f5c87193d4eeeb199858688ddac1ad7ba86b3a" + integrity sha512-4pNQ1porOCUBYBmd2F1TqVuBnB2wBPipaw2qI920zYLuPnada0Rd1CURgh8HRuPGKxijj2iyZDFN2UZwsaEuhA== wcwidth@>=1.0.1, wcwidth@^1.0.1: version "1.0.1"