diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index ecac11b4..1ab6a031 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -3,7 +3,7 @@ name: Deploy to Amazon ECS on: [workflow_dispatch] env: - # 384658522150.dkr.ecr.us-east-1.amazonaws.com/reflector + # 950402358378.dkr.ecr.us-east-1.amazonaws.com/reflector AWS_REGION: us-east-1 ECR_REPOSITORY: reflector diff --git a/.gitignore b/.gitignore index a43e88f7..c3b01d5a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ server/.env .env server/exportdanswer +.vercel +.env*.local diff --git a/server/env.example b/server/env.example index 8c4dcdab..c5a38bf5 100644 --- a/server/env.example +++ b/server/env.example @@ -51,17 +51,6 @@ #TRANSLATE_URL=https://xxxxx--reflector-translator-web.modal.run #TRANSCRIPT_MODAL_API_KEY=xxxxx -## Using serverless banana.dev (require reflector-gpu-banana deployed) -## XXX this service is buggy do not use at the moment -## XXX it also require the audio to be saved to S3 -#TRANSCRIPT_BACKEND=banana -#TRANSCRIPT_URL=https://reflector-gpu-banana-xxxxx.run.banana.dev -#TRANSCRIPT_BANANA_API_KEY=xxx -#TRANSCRIPT_BANANA_MODEL_KEY=xxx -#TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID=xxx -#TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY=xxx -#TRANSCRIPT_STORAGE_AWS_BUCKET_NAME="reflector-bucket/chunks" - ## ======================================================= ## LLM backend ## @@ -78,13 +67,6 @@ #LLM_URL=https://xxxxxx--reflector-llm-web.modal.run #LLM_MODAL_API_KEY=xxx -## Using serverless banana.dev (require reflector-gpu-banana deployed) -## XXX this service is buggy do not use at the moment -#LLM_BACKEND=banana -#LLM_URL=https://reflector-gpu-banana-xxxxx.run.banana.dev -#LLM_BANANA_API_KEY=xxxxx -#LLM_BANANA_MODEL_KEY=xxxxx - ## Using OpenAI #LLM_BACKEND=openai #LLM_OPENAI_KEY=xxx diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py index 02feedb7..f1e9d166 100644 --- a/server/gpu/modal/reflector_llm.py +++ b/server/gpu/modal/reflector_llm.py @@ -81,7 +81,8 @@ class LLM: LLM_MODEL, torch_dtype=getattr(torch, LLM_TORCH_DTYPE), low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE, - cache_dir=IMAGE_MODEL_DIR + cache_dir=IMAGE_MODEL_DIR, + local_files_only=True ) # JSONFormer doesn't yet support generation configs @@ -96,7 +97,8 @@ class LLM: print("Instance llm tokenizer") tokenizer = AutoTokenizer.from_pretrained( LLM_MODEL, - cache_dir=IMAGE_MODEL_DIR + cache_dir=IMAGE_MODEL_DIR, + local_files_only=True ) # move model to gpu diff --git a/server/gpu/modal/reflector_llm_zephyr.py b/server/gpu/modal/reflector_llm_zephyr.py index cbb436b0..b101f5f2 100644 --- a/server/gpu/modal/reflector_llm_zephyr.py +++ b/server/gpu/modal/reflector_llm_zephyr.py @@ -17,7 +17,7 @@ LLM_LOW_CPU_MEM_USAGE: bool = True LLM_TORCH_DTYPE: str = "bfloat16" LLM_MAX_NEW_TOKENS: int = 300 -IMAGE_MODEL_DIR = "/root/llm_models" +IMAGE_MODEL_DIR = "/root/llm_models/zephyr" stub = Stub(name="reflector-llm-zephyr") @@ -81,7 +81,8 @@ class LLM: LLM_MODEL, torch_dtype=getattr(torch, LLM_TORCH_DTYPE), low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE, - cache_dir=IMAGE_MODEL_DIR + cache_dir=IMAGE_MODEL_DIR, + local_files_only=True ) # JSONFormer doesn't yet support generation configs @@ -96,7 +97,8 @@ class LLM: print("Instance llm tokenizer") tokenizer = AutoTokenizer.from_pretrained( LLM_MODEL, - cache_dir=IMAGE_MODEL_DIR + cache_dir=IMAGE_MODEL_DIR, + local_files_only=True ) gen_cfg.pad_token_id = tokenizer.eos_token_id gen_cfg.eos_token_id = tokenizer.eos_token_id diff --git 
a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py index bee9ccd1..4f746ded 100644 --- a/server/gpu/modal/reflector_transcriber.py +++ b/server/gpu/modal/reflector_transcriber.py @@ -95,7 +95,8 @@ class Transcriber: device=self.device, compute_type=WHISPER_COMPUTE_TYPE, num_workers=WHISPER_NUM_WORKERS, - download_root=WHISPER_MODEL_DIR + download_root=WHISPER_MODEL_DIR, + local_files_only=True ) @method() diff --git a/server/migrations/versions/4814901632bc_fix_duration.py b/server/migrations/versions/4814901632bc_fix_duration.py new file mode 100644 index 00000000..66628bb5 --- /dev/null +++ b/server/migrations/versions/4814901632bc_fix_duration.py @@ -0,0 +1,64 @@ +"""fix duration + +Revision ID: 4814901632bc +Revises: 38a927dcb099 +Create Date: 2023-11-10 18:12:17.886522 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, column +from sqlalchemy import select + + +# revision identifiers, used by Alembic. +revision: str = "4814901632bc" +down_revision: Union[str, None] = "38a927dcb099" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # for all the transcripts, calculate the duration from the mp3 + # and update the duration column + from pathlib import Path + from reflector.settings import settings + import av + + bind = op.get_bind() + transcript = table( + "transcript", column("id", sa.String), column("duration", sa.Float) + ) + + # select only the one with duration = 0 + results = bind.execute( + select([transcript.c.id, transcript.c.duration]).where( + transcript.c.duration == 0 + ) + ) + + data_dir = Path(settings.DATA_DIR) + for row in results: + audio_path = data_dir / row["id"] / "audio.mp3" + if not audio_path.exists(): + continue + + try: + print(f"Processing {audio_path}") + container = av.open(audio_path.as_posix()) + print(container.duration) + duration = round(float(container.duration / av.time_base), 2) + print(f"Duration: {duration}") + bind.execute( + transcript.update() + .where(transcript.c.id == row["id"]) + .values(duration=duration) + ) + except Exception as e: + print(f"Failed to process {audio_path}: {e}") + + +def downgrade() -> None: + pass diff --git a/server/poetry.lock b/server/poetry.lock index e72ade57..b89cf400 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -2557,6 +2557,82 @@ typing-extensions = "*" [package.extras] dev = ["black", "flake8", "flake8-black", "isort", "jupyter-console", "mkdocs", "mkdocs-include-markdown-plugin", "mkdocstrings[python]", "pytest", "pytest-asyncio", "pytest-trio", "toml", "tox", "trio", "trio", "trio-typing", "twine", "twisted", "validate-pyproject[all]"] +[[package]] +name = "pyinstrument" +version = "4.6.1" +description = "Call stack profiler for Python. Shows you why your code is slow!" 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "pyinstrument-4.6.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:73476e4bc6e467ac1b2c3c0dd1f0b71c9061d4de14626676adfdfbb14aa342b4"}, + {file = "pyinstrument-4.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4d1da8efd974cf9df52ee03edaee2d3875105ddd00de35aa542760f7c612bdf7"}, + {file = "pyinstrument-4.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:507be1ee2f2b0c9fba74d622a272640dd6d1b0c9ec3388b2cdeb97ad1e77125f"}, + {file = "pyinstrument-4.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95cee6de08eb45754ef4f602ce52b640d1c535d934a6a8733a974daa095def37"}, + {file = "pyinstrument-4.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7873e8cec92321251fdf894a72b3c78f4c5c20afdd1fef0baf9042ec843bb04"}, + {file = "pyinstrument-4.6.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a242f6cac40bc83e1f3002b6b53681846dfba007f366971db0bf21e02dbb1903"}, + {file = "pyinstrument-4.6.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:97c9660cdb4bd2a43cf4f3ab52cffd22f3ac9a748d913b750178fb34e5e39e64"}, + {file = "pyinstrument-4.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e304cd0723e2b18ada5e63c187abf6d777949454c734f5974d64a0865859f0f4"}, + {file = "pyinstrument-4.6.1-cp310-cp310-win32.whl", hash = "sha256:cee21a2d78187dd8a80f72f5d0f1ddb767b2d9800f8bb4d94b6d11f217c22cdb"}, + {file = "pyinstrument-4.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:2000712f71d693fed2f8a1c1638d37b7919124f367b37976d07128d49f1445eb"}, + {file = "pyinstrument-4.6.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a366c6f3dfb11f1739bdc1dee75a01c1563ad0bf4047071e5e77598087df457f"}, + {file = "pyinstrument-4.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c6be327be65d934796558aa9cb0f75ce62ebd207d49ad1854610c97b0579ad47"}, + {file = "pyinstrument-4.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9e160d9c5d20d3e4ef82269e4e8b246ff09bdf37af5fb8cb8ccca97936d95ad6"}, + {file = "pyinstrument-4.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ffbf56605ef21c2fcb60de2fa74ff81f417d8be0c5002a407e414d6ef6dee43"}, + {file = "pyinstrument-4.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c92cc4924596d6e8f30a16182bbe90893b1572d847ae12652f72b34a9a17c24a"}, + {file = "pyinstrument-4.6.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f4b48a94d938cae981f6948d9ec603bab2087b178d2095d042d5a48aabaecaab"}, + {file = "pyinstrument-4.6.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:e7a386392275bdef4a1849712dc5b74f0023483fca14ef93d0ca27d453548982"}, + {file = "pyinstrument-4.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:871b131b83e9b1122f2325061c68ed1e861eebcb568c934d2fb193652f077f77"}, + {file = "pyinstrument-4.6.1-cp311-cp311-win32.whl", hash = "sha256:8d8515156dd91f5652d13b5fcc87e634f8fe1c07b68d1d0840348cdd50bf5ace"}, + {file = "pyinstrument-4.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb868fbe089036e9f32525a249f4c78b8dc46967612393f204b8234f439c9cc4"}, + {file = "pyinstrument-4.6.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a18cd234cce4f230f1733807f17a134e64a1f1acabf74a14d27f583cf2b183df"}, + {file = "pyinstrument-4.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:574cfca69150be4ce4461fb224712fbc0722a49b0dc02fa204d02807adf6b5a0"}, + {file = "pyinstrument-4.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e02cf505e932eb8ccf561b7527550a67ec14fcae1fe0e25319b09c9c166e914"}, + {file = "pyinstrument-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832fb2acef9d53701c1ab546564c45fb70a8770c816374f8dd11420d399103c9"}, + {file = "pyinstrument-4.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13cb57e9607545623ebe462345b3d0c4caee0125d2d02267043ece8aca8f4ea0"}, + {file = "pyinstrument-4.6.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9be89e7419bcfe8dd6abb0d959d6d9c439c613a4a873514c43d16b48dae697c9"}, + {file = "pyinstrument-4.6.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:476785cfbc44e8e1b1ad447398aa3deae81a8df4d37eb2d8bbb0c404eff979cd"}, + {file = "pyinstrument-4.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e9cebd90128a3d2fee36d3ccb665c1b9dce75261061b2046203e45c4a8012d54"}, + {file = "pyinstrument-4.6.1-cp312-cp312-win32.whl", hash = "sha256:1d0b76683df2ad5c40eff73607dc5c13828c92fbca36aff1ddf869a3c5a55fa6"}, + {file = "pyinstrument-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:c4b7af1d9d6a523cfbfedebcb69202242d5bd0cb89c4e094cc73d5d6e38279bd"}, + {file = "pyinstrument-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:79ae152f8c6a680a188fb3be5e0f360ac05db5bbf410169a6c40851dfaebcce9"}, + {file = "pyinstrument-4.6.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cad2745964c174c65aa75f1bf68a4394d1b4d28f33894837cfd315d1e836f0"}, + {file = "pyinstrument-4.6.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb81f66f7f94045d723069cf317453d42375de9ff3c69089cf6466b078ac1db4"}, + {file = "pyinstrument-4.6.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ab30ae75969da99e9a529e21ff497c18fdf958e822753db4ae7ed1e67094040"}, + {file = "pyinstrument-4.6.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f36cb5b644762fb3c86289324bbef17e95f91cd710603ac19444a47f638e8e96"}, + {file = "pyinstrument-4.6.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:8b45075d9dbbc977dbc7007fb22bb0054c6990fbe91bf48dd80c0b96c6307ba7"}, + {file = "pyinstrument-4.6.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:475ac31477f6302e092463896d6a2055f3e6abcd293bad16ff94fc9185308a88"}, + {file = "pyinstrument-4.6.1-cp37-cp37m-win32.whl", hash = "sha256:29172ab3d8609fdf821c3f2562dc61e14f1a8ff5306607c32ca743582d3a760e"}, + {file = "pyinstrument-4.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:bd176f297c99035127b264369d2bb97a65255f65f8d4e843836baf55ebb3cee4"}, + {file = "pyinstrument-4.6.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:23e9b4526978432e9999021da9a545992cf2ac3df5ee82db7beb6908fc4c978c"}, + {file = "pyinstrument-4.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2dbcaccc9f456ef95557ec501caeb292119c24446d768cb4fb43578b0f3d572c"}, + {file = "pyinstrument-4.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2097f63c66c2bc9678c826b9ff0c25acde3ed455590d9dcac21220673fe74fbf"}, + {file = "pyinstrument-4.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:205ac2e76bd65d61b9611a9ce03d5f6393e34ec5b41dd38808f25d54e6b3e067"}, + 
{file = "pyinstrument-4.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f414ddf1161976a40fc0a333000e6a4ad612719eac0b8c9bb73f47153187148"}, + {file = "pyinstrument-4.6.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:65e62ebfa2cd8fb57eda90006f4505ac4c70da00fc2f05b6d8337d776ea76d41"}, + {file = "pyinstrument-4.6.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d96309df4df10be7b4885797c5f69bb3a89414680ebaec0722d8156fde5268c3"}, + {file = "pyinstrument-4.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f3d1ad3bc8ebb4db925afa706aa865c4bfb40d52509f143491ac0df2440ee5d2"}, + {file = "pyinstrument-4.6.1-cp38-cp38-win32.whl", hash = "sha256:dc37cb988c8854eb42bda2e438aaf553536566657d157c4473cc8aad5692a779"}, + {file = "pyinstrument-4.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:2cd4ce750c34a0318fc2d6c727cc255e9658d12a5cf3f2d0473f1c27157bdaeb"}, + {file = "pyinstrument-4.6.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6ca95b21f022e995e062b371d1f42d901452bcbedd2c02f036de677119503355"}, + {file = "pyinstrument-4.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ac1e1d7e1f1b64054c4eb04eb4869a7a5eef2261440e73943cc1b1bc3c828c18"}, + {file = "pyinstrument-4.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0711845e953fce6ab781221aacffa2a66dbc3289f8343e5babd7b2ea34da6c90"}, + {file = "pyinstrument-4.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b7d28582017de35cb64eb4e4fa603e753095108ca03745f5d17295970ee631f"}, + {file = "pyinstrument-4.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7be57db08bd366a37db3aa3a6187941ee21196e8b14975db337ddc7d1490649d"}, + {file = "pyinstrument-4.6.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9a0ac0f56860398d2628ce389826ce83fb3a557d0c9a2351e8a2eac6eb869983"}, + {file = "pyinstrument-4.6.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a9045186ff13bc826fef16be53736a85029aae3c6adfe52e666cad00d7ca623b"}, + {file = "pyinstrument-4.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6c4c56b6eab9004e92ad8a48bb54913fdd71fc8a748ae42a27b9e26041646f8b"}, + {file = "pyinstrument-4.6.1-cp39-cp39-win32.whl", hash = "sha256:37e989c44b51839d0c97466fa2b623638b9470d56d79e329f359f0e8fa6d83db"}, + {file = "pyinstrument-4.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:5494c5a84fee4309d7d973366ca6b8b9f8ba1d6b254e93b7c506264ef74f2cef"}, + {file = "pyinstrument-4.6.1.tar.gz", hash = "sha256:f4731b27121350f5a983d358d2272fe3df2f538aed058f57217eef7801a89288"}, +] + +[package.extras] +bin = ["click", "nox"] +docs = ["furo (==2021.6.18b36)", "myst-parser (==0.15.1)", "sphinx (==4.2.0)", "sphinxcontrib-programoutput (==0.17)"] +examples = ["django", "numpy"] +test = ["flaky", "greenlet (>=3.0.0a1)", "ipython", "pytest", "pytest-asyncio (==0.12.0)", "sphinx-autobuild (==2021.3.14)", "trio"] +types = ["typing-extensions"] + [[package]] name = "pylibsrtp" version = "0.8.0" @@ -4143,4 +4219,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "cfefbd402bde7585caa42c1a889be0496d956e285bb05db9e1e7ae5e485e91fe" +content-hash = "91d85539f5093abad70e34aa4d533272d6a2e2bbdb539c7968fe79c28b50d01a" diff --git a/server/pyproject.toml b/server/pyproject.toml index 7681af39..2a901918 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -41,6 +41,7 @@ python-jose = {extras = ["cryptography"], 
version = "^3.3.0"} [tool.poetry.group.dev.dependencies] black = "^23.7.0" stamina = "^23.1.0" +pyinstrument = "^4.6.1" [tool.poetry.group.tests.dependencies] diff --git a/server/reflector/app.py b/server/reflector/app.py index c2e3bf7e..5bfffeca 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -41,7 +41,6 @@ if settings.SENTRY_DSN: else: logger.info("Sentry disabled") - # build app app = FastAPI(lifespan=lifespan) app.add_middleware( @@ -102,6 +101,23 @@ def use_route_names_as_operation_ids(app: FastAPI) -> None: use_route_names_as_operation_ids(app) +if settings.PROFILING: + from fastapi import Request + from fastapi.responses import HTMLResponse + from pyinstrument import Profiler + + @app.middleware("http") + async def profile_request(request: Request, call_next): + profiling = request.query_params.get("profile", False) + if profiling: + profiler = Profiler(async_mode="enabled") + profiler.start() + await call_next(request) + profiler.stop() + return HTMLResponse(profiler.output_html()) + else: + return await call_next(request) + if __name__ == "__main__": import uvicorn diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 4b91423a..c563f587 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -11,7 +11,6 @@ from pydantic import BaseModel, Field from reflector.db import database, metadata from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings -from reflector.utils.audio_waveform import get_audio_waveform transcripts = sqlalchemy.Table( "transcript", @@ -86,6 +85,14 @@ class TranscriptFinalTitle(BaseModel): title: str +class TranscriptDuration(BaseModel): + duration: float + + +class TranscriptWaveform(BaseModel): + waveform: list[float] + + class TranscriptEvent(BaseModel): event: str data: dict @@ -126,22 +133,6 @@ class Transcript(BaseModel): def topics_dump(self, mode="json"): return [topic.model_dump(mode=mode) for topic in self.topics] - def convert_audio_to_waveform(self, segments_count=256): - fn = self.audio_waveform_filename - if fn.exists(): - return - waveform = get_audio_waveform( - path=self.audio_mp3_filename, segments_count=segments_count - ) - try: - with open(fn, "w") as fd: - json.dump(waveform, fd) - except Exception: - # remove file if anything happen during the write - fn.unlink(missing_ok=True) - raise - return waveform - def unlink(self): self.data_path.unlink(missing_ok=True) diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py deleted file mode 100644 index e0384770..00000000 --- a/server/reflector/llm/llm_banana.py +++ /dev/null @@ -1,54 +0,0 @@ -import httpx - -from reflector.llm.base import LLM -from reflector.settings import settings -from reflector.utils.retry import retry - - -class BananaLLM(LLM): - def __init__(self): - super().__init__() - self.timeout = settings.LLM_TIMEOUT - self.headers = { - "X-Banana-API-Key": settings.LLM_BANANA_API_KEY, - "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY, - } - - async def _generate( - self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs - ): - json_payload = {"prompt": prompt} - if gen_schema: - json_payload["gen_schema"] = gen_schema - if gen_cfg: - json_payload["gen_cfg"] = gen_cfg - async with httpx.AsyncClient() as client: - response = await retry(client.post)( - settings.LLM_URL, - headers=self.headers, - json=json_payload, - timeout=self.timeout, - retry_timeout=300, # as per their sdk - ) - 
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 4b91423a..c563f587 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
 from reflector.db import database, metadata
 from reflector.processors.types import Word as ProcessorWord
 from reflector.settings import settings
-from reflector.utils.audio_waveform import get_audio_waveform

 transcripts = sqlalchemy.Table(
     "transcript",
@@ -86,6 +85,14 @@ class TranscriptFinalTitle(BaseModel):
     title: str


+class TranscriptDuration(BaseModel):
+    duration: float
+
+
+class TranscriptWaveform(BaseModel):
+    waveform: list[float]
+
+
 class TranscriptEvent(BaseModel):
     event: str
     data: dict
@@ -126,22 +133,6 @@ class Transcript(BaseModel):
     def topics_dump(self, mode="json"):
         return [topic.model_dump(mode=mode) for topic in self.topics]

-    def convert_audio_to_waveform(self, segments_count=256):
-        fn = self.audio_waveform_filename
-        if fn.exists():
-            return
-        waveform = get_audio_waveform(
-            path=self.audio_mp3_filename, segments_count=segments_count
-        )
-        try:
-            with open(fn, "w") as fd:
-                json.dump(waveform, fd)
-        except Exception:
-            # remove file if anything happen during the write
-            fn.unlink(missing_ok=True)
-            raise
-        return waveform
-
     def unlink(self):
         self.data_path.unlink(missing_ok=True)
diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py
deleted file mode 100644
index e0384770..00000000
--- a/server/reflector/llm/llm_banana.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import httpx
-
-from reflector.llm.base import LLM
-from reflector.settings import settings
-from reflector.utils.retry import retry
-
-
-class BananaLLM(LLM):
-    def __init__(self):
-        super().__init__()
-        self.timeout = settings.LLM_TIMEOUT
-        self.headers = {
-            "X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
-            "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
-        }
-
-    async def _generate(
-        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
-    ):
-        json_payload = {"prompt": prompt}
-        if gen_schema:
-            json_payload["gen_schema"] = gen_schema
-        if gen_cfg:
-            json_payload["gen_cfg"] = gen_cfg
-        async with httpx.AsyncClient() as client:
-            response = await retry(client.post)(
-                settings.LLM_URL,
-                headers=self.headers,
-                json=json_payload,
-                timeout=self.timeout,
-                retry_timeout=300,  # as per their sdk
-            )
-            response.raise_for_status()
-            text = response.json()["text"]
-            return text
-
-
-LLM.register("banana", BananaLLM)
-
-if __name__ == "__main__":
-    from reflector.logger import logger
-
-    async def main():
-        llm = BananaLLM()
-        prompt = llm.create_prompt(
-            instruct="Complete the following task",
-            text="Tell me a joke about programming.",
-        )
-        result = await llm.generate(prompt=prompt, logger=logger)
-        print(result)
-
-    import asyncio
-
-    asyncio.run(main())
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index b2bc51ea..3a9d1868 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -21,11 +21,13 @@ from pydantic import BaseModel
 from reflector.app import app
 from reflector.db.transcripts import (
     Transcript,
+    TranscriptDuration,
     TranscriptFinalLongSummary,
     TranscriptFinalShortSummary,
     TranscriptFinalTitle,
     TranscriptText,
     TranscriptTopic,
+    TranscriptWaveform,
     transcripts_controller,
 )
 from reflector.logger import logger
@@ -45,6 +47,7 @@ from reflector.processors import (
     TranscriptTopicDetectorProcessor,
     TranscriptTranslatorProcessor,
 )
+from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
 from reflector.processors.types import AudioDiarizationInput
 from reflector.processors.types import (
     TitleSummaryWithId as TitleSummaryWithIdProcessorType,
@@ -230,6 +233,33 @@ class PipelineMainBase(PipelineRunner):
             data=final_short_summary,
         )

+    @broadcast_to_sockets
+    async def on_duration(self, data):
+        async with self.transaction():
+            duration = TranscriptDuration(duration=data)
+
+            transcript = await self.get_transcript()
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "duration": duration.duration,
+                },
+            )
+            return await transcripts_controller.append_event(
+                transcript=transcript, event="DURATION", data=duration
+            )
+
+    @broadcast_to_sockets
+    async def on_waveform(self, data):
+        async with self.transaction():
+            waveform = TranscriptWaveform(waveform=data)
+
+            transcript = await self.get_transcript()
+
+            return await transcripts_controller.append_event(
+                transcript=transcript, event="WAVEFORM", data=waveform
+            )
+

 class PipelineMainLive(PipelineMainBase):
     audio_filename: Path | None = None
@@ -243,7 +273,10 @@
         transcript = await self.get_transcript()

         processors = [
-            AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
+            AudioFileWriterProcessor(
+                path=transcript.audio_mp3_filename,
+                on_duration=self.on_duration,
+            ),
             AudioChunkerProcessor(),
             AudioMergeProcessor(),
             AudioTranscriptAutoProcessor.as_threaded(),
@@ -253,6 +286,11 @@
             BroadcastProcessor(
                 processors=[
                     TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
+                    AudioWaveformProcessor.as_threaded(
+                        audio_path=transcript.audio_mp3_filename,
+                        waveform_path=transcript.audio_waveform_filename,
+                        on_waveform=self.on_waveform,
+                    ),
                 ]
             ),
         ]
@@ -285,8 +323,13 @@ class PipelineMainDiarization(PipelineMainBase):
         # create a context for the whole rtc transaction
         # add a customised logger to the context
         self.prepare()
-        processors = [
-            AudioDiarizationAutoProcessor(callback=self.on_topic),
+        processors = []
+        if settings.DIARIZATION_ENABLED:
+            processors += [
+                AudioDiarizationAutoProcessor(callback=self.on_topic),
+            ]
+
+        processors += [
             BroadcastProcessor(
                 processors=[
                     TranscriptFinalLongSummaryProcessor.as_threaded(
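The `on_duration=` / `on_waveform=` keyword arguments above are wired through the new `Emitter` base class added to `server/reflector/processors/base.py` further down in this diff: any `on_<name>` constructor kwarg is registered as a listener for events emitted with `emit(data, name="<name>")`, while the default channel keeps feeding downstream processors. A minimal, self-contained sketch of the pattern (illustrative class names, not the actual Reflector code):

```python
import asyncio


class Emitter:
    """Named-event emitter mirroring the pattern in processors/base.py."""

    def __init__(self, **kwargs):
        self._callbacks = {}
        # any on_<name>= kwarg becomes a listener on the "<name>" channel
        for key, value in kwargs.items():
            if key.startswith("on_"):
                self.on(value, name=key[3:])

    def on(self, callback, name="default"):
        if not asyncio.iscoroutinefunction(callback):
            raise ValueError("Callback must be a coroutine function")
        self._callbacks.setdefault(name, []).append(callback)

    async def emit(self, data, name="default"):
        # named channels reach only their registered listeners
        for callback in self._callbacks.get(name, []):
            await callback(data)


class FileWriter(Emitter):
    async def flush(self):
        # side-channel event, separate from the main data flow
        await self.emit(12.34, name="duration")


async def main():
    async def on_duration(value):
        print("duration:", value)

    writer = FileWriter(on_duration=on_duration)
    await writer.flush()  # -> duration: 12.34


asyncio.run(main())
```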
diff --git a/server/reflector/processors/audio_file_writer.py b/server/reflector/processors/audio_file_writer.py
index d34dc3f0..36ee4263 100644
--- a/server/reflector/processors/audio_file_writer.py
+++ b/server/reflector/processors/audio_file_writer.py
@@ -12,8 +12,8 @@ class AudioFileWriterProcessor(Processor):
     INPUT_TYPE = av.AudioFrame
     OUTPUT_TYPE = av.AudioFrame

-    def __init__(self, path: Path | str):
-        super().__init__()
+    def __init__(self, path: Path | str, **kwargs):
+        super().__init__(**kwargs)
         if isinstance(path, str):
             path = Path(path)
         if path.suffix not in (".mp3", ".wav"):
@@ -21,6 +21,7 @@
         self.path = path
         self.out_container = None
         self.out_stream = None
+        self.last_packet = None

     async def _push(self, data: av.AudioFrame):
         if not self.out_container:
@@ -40,12 +41,34 @@
             raise ValueError("Only mp3 and wav files are supported")
         for packet in self.out_stream.encode(data):
             self.out_container.mux(packet)
+            self.last_packet = packet
         await self.emit(data)

     async def _flush(self):
         if self.out_container:
             for packet in self.out_stream.encode():
                 self.out_container.mux(packet)
+                self.last_packet = packet
+            # start at 0 so `duration` is always defined, even when the
+            # stream produced no packets
+            duration = 0
+            try:
+                if self.last_packet is not None:
+                    # end of stream = (pts + duration) of the last packet,
+                    # converted to seconds via the packet time base
+                    duration = round(
+                        float(
+                            (self.last_packet.pts + self.last_packet.duration)
+                            * self.last_packet.time_base
+                        ),
+                        2,
+                    )
+            except Exception:
+                self.logger.exception("Failed to get duration")
+                duration = 0
+
             self.out_container.close()
             self.out_container = None
             self.out_stream = None
+
+            if duration > 0:
+                await self.emit(duration, name="duration")
diff --git a/server/reflector/processors/audio_transcript_banana.py b/server/reflector/processors/audio_transcript_banana.py
deleted file mode 100644
index fe339eea..00000000
--- a/server/reflector/processors/audio_transcript_banana.py
+++ /dev/null
@@ -1,86 +0,0 @@
-"""
-Implementation using the GPU service from banana.
- -API will be a POST request to TRANSCRIPT_URL: - -```json -{ - "audio_url": "https://...", - "audio_ext": "wav", - "timestamp": 123.456 - "language": "en" -} -``` - -""" - -from pathlib import Path - -import httpx -from reflector.processors.audio_transcript import AudioTranscriptProcessor -from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor -from reflector.processors.types import AudioFile, Transcript, Word -from reflector.settings import settings -from reflector.storage import Storage -from reflector.utils.retry import retry - - -class AudioTranscriptBananaProcessor(AudioTranscriptProcessor): - def __init__(self, banana_api_key: str, banana_model_key: str): - super().__init__() - self.transcript_url = settings.TRANSCRIPT_URL - self.timeout = settings.TRANSCRIPT_TIMEOUT - self.storage = Storage.get_instance( - settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_" - ) - self.headers = { - "X-Banana-API-Key": banana_api_key, - "X-Banana-Model-Key": banana_model_key, - } - - async def _transcript(self, data: AudioFile): - async with httpx.AsyncClient() as client: - print(f"Uploading audio {data.path.name} to S3") - url = await self._upload_file(data.path) - - print(f"Try to transcribe audio {data.path.name}") - request_data = { - "audio_url": url, - "audio_ext": data.path.suffix[1:], - "timestamp": float(round(data.timestamp, 2)), - } - response = await retry(client.post)( - self.transcript_url, - json=request_data, - headers=self.headers, - timeout=self.timeout, - ) - - print(f"Transcript response: {response.status_code} {response.content}") - response.raise_for_status() - result = response.json() - transcript = Transcript( - text=result["text"], - words=[ - Word(text=word["text"], start=word["start"], end=word["end"]) - for word in result["words"] - ], - ) - - # remove audio file from S3 - await self._delete_file(data.path) - - return transcript - - @retry - async def _upload_file(self, path: Path) -> str: - upload_result = await self.storage.put_file(path.name, open(path, "rb")) - return upload_result.url - - @retry - async def _delete_file(self, path: Path): - await self.storage.delete_file(path.name) - return True - - -AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor) diff --git a/server/reflector/processors/audio_waveform_processor.py b/server/reflector/processors/audio_waveform_processor.py new file mode 100644 index 00000000..f1a24ffd --- /dev/null +++ b/server/reflector/processors/audio_waveform_processor.py @@ -0,0 +1,36 @@ +import json +from pathlib import Path + +from reflector.processors.base import Processor +from reflector.processors.types import TitleSummary +from reflector.utils.audio_waveform import get_audio_waveform + + +class AudioWaveformProcessor(Processor): + """ + Write the waveform for the final audio + """ + + INPUT_TYPE = TitleSummary + + def __init__(self, audio_path: Path | str, waveform_path: str, **kwargs): + super().__init__(**kwargs) + if isinstance(audio_path, str): + audio_path = Path(audio_path) + if audio_path.suffix not in (".mp3", ".wav"): + raise ValueError("Only mp3 and wav files are supported") + self.audio_path = audio_path + self.waveform_path = waveform_path + + async def _flush(self): + self.waveform_path.parent.mkdir(parents=True, exist_ok=True) + self.logger.info("Waveform Processing Started") + waveform = get_audio_waveform(path=self.audio_path, segments_count=255) + + with open(self.waveform_path, "w") as fd: + json.dump(waveform, fd) + self.logger.info("Waveform Processing 
Finished") + await self.emit(waveform, name="waveform") + + async def _push(_self, _data): + return diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py index 46bfb4a5..00f0223b 100644 --- a/server/reflector/processors/base.py +++ b/server/reflector/processors/base.py @@ -14,7 +14,42 @@ class PipelineEvent(BaseModel): data: Any -class Processor: +class Emitter: + def __init__(self, **kwargs): + self._callbacks = {} + + # register callbacks from kwargs (on_*) + for key, value in kwargs.items(): + if key.startswith("on_"): + self.on(value, name=key[3:]) + + def on(self, callback, name="default"): + """ + Register a callback to be called when data is emitted + """ + # ensure callback is asynchronous + if not asyncio.iscoroutinefunction(callback): + raise ValueError("Callback must be a coroutine function") + if name not in self._callbacks: + self._callbacks[name] = [] + self._callbacks[name].append(callback) + + def off(self, callback, name="default"): + """ + Unregister a callback to be called when data is emitted + """ + if name not in self._callbacks: + return + self._callbacks[name].remove(callback) + + async def emit(self, data, name="default"): + if name not in self._callbacks: + return + for callback in self._callbacks[name]: + await callback(data) + + +class Processor(Emitter): INPUT_TYPE: type = None OUTPUT_TYPE: type = None @@ -59,7 +94,8 @@ class Processor: ["processor"], ) - def __init__(self, callback=None, custom_logger=None): + def __init__(self, callback=None, custom_logger=None, **kwargs): + super().__init__(**kwargs) self.name = name = self.__class__.__name__ self.m_processor = self.m_processor.labels(name) self.m_processor_call = self.m_processor_call.labels(name) @@ -70,9 +106,11 @@ class Processor: self.m_processor_flush_success = self.m_processor_flush_success.labels(name) self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name) self._processors = [] - self._callbacks = [] + + # register callbacks if callback: self.on(callback) + self.uid = uuid4().hex self.flushed = False self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__) @@ -100,21 +138,6 @@ class Processor: """ self._processors.remove(processor) - def on(self, callback): - """ - Register a callback to be called when data is emitted - """ - # ensure callback is asynchronous - if not asyncio.iscoroutinefunction(callback): - raise ValueError("Callback must be a coroutine function") - self._callbacks.append(callback) - - def off(self, callback): - """ - Unregister a callback to be called when data is emitted - """ - self._callbacks.remove(callback) - def get_pref(self, key: str, default: Any = None): """ Get a preference from the pipeline prefs @@ -123,15 +146,16 @@ class Processor: return self.pipeline.get_pref(key, default) return default - async def emit(self, data): - if self.pipeline: - await self.pipeline.emit( - PipelineEvent(processor=self.name, uid=self.uid, data=data) - ) - for callback in self._callbacks: - await callback(data) - for processor in self._processors: - await processor.push(data) + async def emit(self, data, name="default"): + if name == "default": + if self.pipeline: + await self.pipeline.emit( + PipelineEvent(processor=self.name, uid=self.uid, data=data) + ) + await super().emit(data, name=name) + if name == "default": + for processor in self._processors: + await processor.push(data) async def push(self, data): """ @@ -254,11 +278,11 @@ class ThreadedProcessor(Processor): def disconnect(self, processor: Processor): 
self.processor.disconnect(processor) - def on(self, callback): - self.processor.on(callback) + def on(self, callback, name="default"): + self.processor.on(callback, name=name) - def off(self, callback): - self.processor.off(callback) + def off(self, callback, name="default"): + self.processor.off(callback, name=name) def describe(self, level=0): super().describe(level) @@ -305,13 +329,13 @@ class BroadcastProcessor(Processor): for processor in self.processors: processor.disconnect(processor) - def on(self, callback): + def on(self, callback, name="default"): for processor in self.processors: - processor.on(callback) + processor.on(callback, name=name) - def off(self, callback): + def off(self, callback, name="default"): for processor in self.processors: - processor.off(callback) + processor.off(callback, name=name) def describe(self, level=0): super().describe(level) diff --git a/server/reflector/processors/transcript_translator.py b/server/reflector/processors/transcript_translator.py index 905ea423..fbb07164 100644 --- a/server/reflector/processors/transcript_translator.py +++ b/server/reflector/processors/transcript_translator.py @@ -16,6 +16,7 @@ class TranscriptTranslatorProcessor(Processor): def __init__(self, **kwargs): super().__init__(**kwargs) + self.transcript = None self.translate_url = settings.TRANSLATE_URL self.timeout = settings.TRANSLATE_TIMEOUT self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"} diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index 312f5433..93e565df 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -5,6 +5,7 @@ from pathlib import Path from profanityfilter import ProfanityFilter from pydantic import BaseModel, PrivateAttr +from reflector.redis_cache import redis_cache PUNC_RE = re.compile(r"[.;:?!…]") @@ -68,10 +69,14 @@ class Transcript(BaseModel): # Uncensored text return "".join([word.text for word in self.words]) + @redis_cache(prefix="profanity", duration=3600 * 24 * 7) + def _get_censored_text(self, text: str): + return profanity_filter.censor(text).strip() + @property def text(self): # Censored text - return profanity_filter.censor(self.raw_text).strip() + return self._get_censored_text(self.raw_text) @property def human_timestamp(self): diff --git a/server/reflector/redis_cache.py b/server/reflector/redis_cache.py new file mode 100644 index 00000000..c31471cf --- /dev/null +++ b/server/reflector/redis_cache.py @@ -0,0 +1,50 @@ +import functools +import json + +import redis +from reflector.settings import settings + +redis_clients = {} + + +def get_redis_client(db=0): + """ + Get a Redis client for the specified database. + """ + if db not in redis_clients: + redis_clients[db] = redis.StrictRedis( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=db, + ) + return redis_clients[db] + + +def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argidx=1): + """ + Cache the result of a function in Redis. 
+ """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Check if the first argument is a string + if len(args) < (argidx + 1) or not isinstance(args[argidx], str): + return func(*args, **kwargs) + + # Compute the cache key based on the arguments and prefix + cache_key = prefix + ":" + args[argidx] + redis_client = get_redis_client(db=db) + cached_result = redis_client.get(cache_key) + + if cached_result: + return json.loads(cached_result.decode("utf-8")) + + # If the result is not cached, call the original function + result = func(*args, **kwargs) + redis_client.setex(cache_key, duration, json.dumps(result)) + return result + + return wrapper + + return decorator diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 021d509f..65412310 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -41,7 +41,7 @@ class Settings(BaseSettings): AUDIO_BUFFER_SIZE: int = 256 * 960 # Audio Transcription - # backends: whisper, banana, modal + # backends: whisper, modal TRANSCRIPT_BACKEND: str = "whisper" TRANSCRIPT_URL: str | None = None TRANSCRIPT_TIMEOUT: int = 90 @@ -50,10 +50,6 @@ class Settings(BaseSettings): TRANSLATE_URL: str | None = None TRANSLATE_TIMEOUT: int = 90 - # Audio transcription banana.dev configuration - TRANSCRIPT_BANANA_API_KEY: str | None = None - TRANSCRIPT_BANANA_MODEL_KEY: str | None = None - # Audio transcription modal.com configuration TRANSCRIPT_MODAL_API_KEY: str | None = None @@ -61,13 +57,16 @@ class Settings(BaseSettings): TRANSCRIPT_STORAGE_BACKEND: str = "aws" # Storage configuration for AWS - TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks" + TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket" TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1" TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None + # Transcript MP3 storage + TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws" + # LLM - # available backend: openai, banana, modal, oobabooga + # available backend: openai, modal, oobabooga LLM_BACKEND: str = "oobabooga" # LLM common configuration @@ -82,14 +81,11 @@ class Settings(BaseSettings): LLM_TEMPERATURE: float = 0.7 ZEPHYR_LLM_URL: str | None = None - # LLM Banana configuration - LLM_BANANA_API_KEY: str | None = None - LLM_BANANA_MODEL_KEY: str | None = None - # LLM Modal configuration LLM_MODAL_API_KEY: str | None = None # Diarization + DIARIZATION_ENABLED: bool = True DIARIZATION_BACKEND: str = "modal" DIARIZATION_URL: str | None = None @@ -124,6 +120,7 @@ class Settings(BaseSettings): # Redis REDIS_HOST: str = "localhost" REDIS_PORT: int = 6379 + REDIS_CACHE_DB: int = 2 # Secret key SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21" @@ -131,5 +128,8 @@ class Settings(BaseSettings): # Current hosting/domain BASE_URL: str = "http://localhost:1250" + # Profiling + PROFILING: bool = False + settings = Settings() diff --git a/server/reflector/views/_range_requests_response.py b/server/reflector/views/_range_requests_response.py index f0c628e9..2fac632d 100644 --- a/server/reflector/views/_range_requests_response.py +++ b/server/reflector/views/_range_requests_response.py @@ -1,7 +1,7 @@ import os from typing import BinaryIO -from fastapi import HTTPException, Request, status +from fastapi import HTTPException, Request, Response, status from fastapi.responses import StreamingResponse @@ -57,6 +57,9 @@ def range_requests_response( ), } + if request.method == "HEAD": + return 
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 021d509f..65412310 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -41,7 +41,7 @@ class Settings(BaseSettings):
     AUDIO_BUFFER_SIZE: int = 256 * 960

     # Audio Transcription
-    # backends: whisper, banana, modal
+    # backends: whisper, modal
     TRANSCRIPT_BACKEND: str = "whisper"
     TRANSCRIPT_URL: str | None = None
     TRANSCRIPT_TIMEOUT: int = 90
@@ -50,10 +50,6 @@
     TRANSLATE_URL: str | None = None
     TRANSLATE_TIMEOUT: int = 90

-    # Audio transcription banana.dev configuration
-    TRANSCRIPT_BANANA_API_KEY: str | None = None
-    TRANSCRIPT_BANANA_MODEL_KEY: str | None = None
-
     # Audio transcription modal.com configuration
     TRANSCRIPT_MODAL_API_KEY: str | None = None

@@ -61,13 +57,16 @@
     TRANSCRIPT_STORAGE_BACKEND: str = "aws"

     # Storage configuration for AWS
-    TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks"
+    TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
     TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1"
     TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None

+    # Transcript MP3 storage
+    TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
+
     # LLM
-    # available backend: openai, banana, modal, oobabooga
+    # available backend: openai, modal, oobabooga
     LLM_BACKEND: str = "oobabooga"

     # LLM common configuration
@@ -82,14 +81,11 @@
     LLM_TEMPERATURE: float = 0.7
     ZEPHYR_LLM_URL: str | None = None

-    # LLM Banana configuration
-    LLM_BANANA_API_KEY: str | None = None
-    LLM_BANANA_MODEL_KEY: str | None = None
-
     # LLM Modal configuration
     LLM_MODAL_API_KEY: str | None = None

     # Diarization
+    DIARIZATION_ENABLED: bool = True
     DIARIZATION_BACKEND: str = "modal"
     DIARIZATION_URL: str | None = None
@@ -124,6 +120,7 @@
     # Redis
     REDIS_HOST: str = "localhost"
     REDIS_PORT: int = 6379
+    REDIS_CACHE_DB: int = 2

     # Secret key
     SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
@@ -131,5 +128,8 @@
     # Current hosting/domain
     BASE_URL: str = "http://localhost:1250"

+    # Profiling
+    PROFILING: bool = False
+

 settings = Settings()
diff --git a/server/reflector/views/_range_requests_response.py b/server/reflector/views/_range_requests_response.py
index f0c628e9..2fac632d 100644
--- a/server/reflector/views/_range_requests_response.py
+++ b/server/reflector/views/_range_requests_response.py
@@ -1,7 +1,7 @@
 import os
 from typing import BinaryIO

-from fastapi import HTTPException, Request, status
+from fastapi import HTTPException, Request, Response, status
 from fastapi.responses import StreamingResponse

@@ -57,6 +57,9 @@
         ),
     }

+    if request.method == "HEAD":
+        return Response(headers=headers)
+
     if content_disposition:
         headers["Content-Disposition"] = content_disposition

diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 5f1d7831..88351880 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -23,7 +23,6 @@ from reflector.db.transcripts import (
 )
 from reflector.processors.types import Transcript as ProcessorTranscript
 from reflector.settings import settings
 from reflector.ws_manager import get_ws_manager
-from starlette.concurrency import run_in_threadpool

 from ._range_requests_response import range_requests_response
 from .rtc_offer import RtcOffer, rtc_offer_base
@@ -53,7 +52,7 @@ class GetTranscript(BaseModel):
     name: str
     status: str
     locked: bool
-    duration: int
+    duration: float
     title: str | None
     short_summary: str | None
     long_summary: str | None
@@ -222,6 +221,7 @@ async def transcript_delete(


 @router.get("/transcripts/{transcript_id}/audio/mp3")
+@router.head("/transcripts/{transcript_id}/audio/mp3")
 async def transcript_get_audio_mp3(
     request: Request,
     transcript_id: str,
@@ -272,8 +272,6 @@ async def transcript_get_audio_waveform(
     if not transcript.audio_mp3_filename.exists():
         raise HTTPException(status_code=500, detail="Audio not found")

-    await run_in_threadpool(transcript.convert_audio_to_waveform)
-
     return transcript.audio_waveform
diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py
index 79cb25bf..28f83fff 100644
--- a/server/tests/test_transcripts_audio_download.py
+++ b/server/tests/test_transcripts_audio_download.py
@@ -46,6 +46,34 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_type):
     assert response.status_code == 200
     assert response.headers["content-type"] == content_type

+    # test get 404
+    ac = AsyncClient(app=app, base_url="http://test/v1")
+    response = await ac.get(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
+    assert response.status_code == 404
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "url_suffix,content_type",
+    [
+        ["/mp3", "audio/mpeg"],
+    ],
+)
+async def test_transcript_audio_download_head(
+    fake_transcript, url_suffix, content_type
+):
+    from reflector.app import app
+
+    ac = AsyncClient(app=app, base_url="http://test/v1")
+    response = await ac.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
+    assert response.status_code == 200
+    assert response.headers["content-type"] == content_type
+
+    # test head 404
+    ac = AsyncClient(app=app, base_url="http://test/v1")
+    response = await ac.head(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
+    assert response.status_code == 404
+

 @pytest.mark.asyncio
 @pytest.mark.parametrize(
@@ -90,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek(
     assert response.status_code == 206
     assert response.headers["content-type"] == content_type
     assert response.headers["content-range"].startswith("bytes 100-")
-
-
-@pytest.mark.asyncio
-async def test_transcript_audio_download_waveform(fake_transcript):
-    from reflector.app import app
-
-    ac = AsyncClient(app=app, base_url="http://test/v1")
-    response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
-    assert response.status_code == 200
-    assert response.headers["content-type"] == "application/json"
-    assert isinstance(response.json()["data"], list)
-    assert len(response.json()["data"]) >= 255
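Registering HEAD next to GET, combined with the early return in `range_requests_response`, lets clients inspect the audio without downloading it — useful for media players probing size and seek support before streaming. A client-side sketch (base URL and transcript id are placeholders):

```python
import asyncio

import httpx


async def probe_audio(transcript_id: str) -> None:
    async with httpx.AsyncClient(base_url="http://localhost:1250/v1") as client:
        # same headers as GET, but no body is streamed back
        response = await client.head(f"/transcripts/{transcript_id}/audio/mp3")
        response.raise_for_status()
        print(response.headers["content-type"])        # audio/mpeg
        print(response.headers.get("content-length"))  # size, if advertised


asyncio.run(probe_audio("some-transcript-id"))
```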
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 413c8b24..b33b1db5 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -182,6 +182,16 @@ async def test_transcript_rtc_and_websocket(
     ev = events[eventnames.index("FINAL_TITLE")]
     assert ev["data"]["title"] == "LLM TITLE"

+    assert "WAVEFORM" in eventnames
+    ev = events[eventnames.index("WAVEFORM")]
+    assert isinstance(ev["data"]["waveform"], list)
+    assert len(ev["data"]["waveform"]) >= 250
+    waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
+    assert waveform_resp.status_code == 200
+    assert waveform_resp.headers["content-type"] == "application/json"
+    assert isinstance(waveform_resp.json()["data"], list)
+    assert len(waveform_resp.json()["data"]) >= 250
+
     # check status order
     statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
     assert statuses.index("recording") < statuses.index("processing")
@@ -191,10 +201,14 @@
     assert events[-1]["event"] == "STATUS"
     assert events[-1]["data"]["value"] == "ended"

+    # check on the latest response that the audio duration is > 0
+    assert resp.json()["duration"] > 0
+    assert "DURATION" in eventnames
+
     # check that audio/mp3 is available
-    resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
-    assert resp.status_code == 200
-    assert resp.headers["Content-Type"] == "audio/mpeg"
+    audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
+    assert audio_resp.status_code == 200
+    assert audio_resp.headers["Content-Type"] == "audio/mpeg"


 @pytest.mark.usefixtures("celery_session_app")
diff --git a/www/app/(errors)/errorContext.tsx b/www/app/(errors)/errorContext.tsx
index d8a80c04..d541c6f0 100644
--- a/www/app/(errors)/errorContext.tsx
+++ b/www/app/(errors)/errorContext.tsx
@@ -3,7 +3,8 @@ import React, { createContext, useContext, useState } from "react";

 interface ErrorContextProps {
   error: Error | null;
-  setError: React.Dispatch<React.SetStateAction<Error | null>>;
+  humanMessage?: string;
+  setError: (error: Error, humanMessage?: string) => void;
 }

 const ErrorContext = createContext<ErrorContextProps | undefined>(undefined);
@@ -22,9 +23,16 @@ interface ErrorProviderProps {

 export const ErrorProvider: React.FC<ErrorProviderProps> = ({ children }) => {
   const [error, setError] = useState<Error | null>(null);
+  const [humanMessage, setHumanMessage] = useState<string | undefined>();
+
+  const declareError = (error: Error, humanMessage?: string) => {
+    setError(error);
+    setHumanMessage(humanMessage);
+  };

   return (
-    <ErrorContext.Provider value={{ error, setError }}>
+    <ErrorContext.Provider value={{ error, setError: declareError, humanMessage }}>
       {children}
     </ErrorContext.Provider>
   );
diff --git a/www/app/(errors)/errorMessage.tsx b/www/app/(errors)/errorMessage.tsx
index 8b410c4c..6d198650 100644
--- a/www/app/(errors)/errorMessage.tsx
+++ b/www/app/(errors)/errorMessage.tsx
@@ -4,29 +4,51 @@ import { useEffect, useState } from "react";
 import * as Sentry from "@sentry/react";

 const ErrorMessage: React.FC = () => {
-  const { error, setError } = useError();
+  const { error, setError, humanMessage } = useError();
   const [isVisible, setIsVisible] = useState(false);

+  // Setup Shortcuts
+  useEffect(() => {
+    const handleKeyPress = (event: KeyboardEvent) => {
+      switch (event.key) {
+        case "^":
+          throw new Error("Unhandled Exception thrown by '^' shortcut");
+        case "$":
+          setError(
+            new Error("Unhandled Exception thrown by '$' shortcut"),
+            "You did this to yourself",
+          );
+      }
+    };
+
+    document.addEventListener("keydown", handleKeyPress);
+    return () => document.removeEventListener("keydown", handleKeyPress);
+  }, []);
+
   useEffect(() => {
     if (error) {
-      setIsVisible(true);
-      Sentry.captureException(error);
-      console.error("Error", error.message, error);
+      if (humanMessage) {
+        setIsVisible(true);
+        Sentry.captureException(Error(humanMessage, { cause: error }));
+      } else {
+        Sentry.captureException(error);
+      }
+
+      console.error("Error", error);
     }
   }, [error]);

-  if (!isVisible || !error) return null;
+  if (!isVisible || !humanMessage) return null;

   return (
     [banner JSX lost in extraction]
   );
 };
diff --git a/www/app/[domain]/browse/pagination.tsx b/www/app/[domain]/browse/pagination.tsx
index 27ff5f47..e10d5321 100644
--- a/www/app/[domain]/browse/pagination.tsx
+++ b/www/app/[domain]/browse/pagination.tsx
@@ -40,7 +40,7 @@ export default function Pagination(props: PaginationProps) {
   return (
-    [one-line JSX change lost in extraction]
+    [one-line JSX change lost in extraction]
diff --git [transcript page diff header lost in extraction]
@@ [hunk header lost in extraction] @@
-          {transcript?.response?.longSummary && (
+          {transcript.response.longSummary ? (
             [unchanged JSX elided]
+          ) : (
+            <div>
+              {transcript.response.status == "processing" ? (
+                <div>Loading Transcript</div>
+              ) : (
+                <div>
+                  There was an error generating the final summary, please
+                  come back later
+                </div>
+              )}
+            </div>
           )}
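The page above switches on `transcript.response.status` ("processing", "ended", "error") — the same status the server exposes over the REST API. A hedged polling sketch against the transcript detail endpoint (route shape inferred from `GetTranscript` in `server/reflector/views/transcripts.py`; base URL assumed):

```python
import asyncio

import httpx


async def wait_until_done(transcript_id: str) -> dict:
    # poll until the pipeline reaches a terminal status
    async with httpx.AsyncClient(base_url="http://localhost:1250/v1") as client:
        while True:
            resp = await client.get(f"/transcripts/{transcript_id}")
            resp.raise_for_status()
            transcript = resp.json()
            if transcript["status"] in ("ended", "error"):
                return transcript
            await asyncio.sleep(2)


result = asyncio.run(wait_until_done("some-transcript-id"))
print(result["status"], result["duration"])
```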
diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
index 41a2d053..2c5b73e0 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
@@ -8,12 +8,15 @@ import { useWebSockets } from "../../useWebSockets";
 import useAudioDevice from "../../useAudioDevice";
 import "../../../../styles/button.css";
 import { Topic } from "../../webSocketTypes";
-import getApi from "../../../../lib/getApi";
 import LiveTrancription from "../../liveTranscription";
 import DisconnectedIndicator from "../../disconnectedIndicator";
 import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
 import { faGear } from "@fortawesome/free-solid-svg-icons";
 import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock";
+import { useRouter } from "next/navigation";
+import Player from "../../player";
+import useMp3, { Mp3Response } from "../../useMp3";
+import WaveformLoading from "../../waveformLoading";

 type TranscriptDetails = {
   params: {
@@ -42,8 +45,12 @@ const TranscriptRecord = (details: TranscriptDetails) => {

   const { audioDevices, getAudioStream } = useAudioDevice();

-  const [hasRecorded, setHasRecorded] = useState(false);
+  const [recordedTime, setRecordedTime] = useState(0);
+  const [startTime, setStartTime] = useState(0);
   const [transcriptStarted, setTranscriptStarted] = useState(false);
+  let mp3 = useMp3(details.params.transcriptId, true);
+
+  const router = useRouter();

   useEffect(() => {
     if (!transcriptStarted && webSockets.transcriptText.length !== 0)
@@ -51,15 +58,27 @@
   }, [webSockets.transcriptText]);

   useEffect(() => {
-    if (transcript?.response?.longSummary) {
-      const newUrl = `/transcripts/${transcript.response.id}`;
+    const statusToRedirect = ["ended", "error"];
+
+    //TODO if has no topic and is error, get back to new
+    if (
+      statusToRedirect.includes(transcript.response?.status) ||
+      statusToRedirect.includes(webSockets.status.value)
+    ) {
+      const newUrl = "/transcripts/" + details.params.transcriptId;
       // Shallow redirection does not work on NextJS 13
       // https://github.com/vercel/next.js/discussions/48110
       // https://github.com/vercel/next.js/discussions/49540
-      // router.push(newUrl, undefined, { shallow: true });
-      history.replaceState({}, "", newUrl);
+      router.replace(newUrl);
+      // history.replaceState({}, "", newUrl);
+    } // history.replaceState({}, "", newUrl);
+  }, [webSockets.status.value, transcript.response?.status]);
+
+  useEffect(() => {
+    if (webSockets.duration) {
+      mp3.getNow();
     }
-  });
+  }, [webSockets.duration]);

   useEffect(() => {
     lockWakeState();
@@ -70,19 +89,31 @@

   return (
     <>
-      <Recorder
-        setStream={setStream}
-        onStop={() => {
-          setStream(null);
-          setHasRecorded(true);
-          webRTC?.send(JSON.stringify({ cmd: "STOP" }));
-        }}
-        topics={webSockets.topics}
-        getAudioStream={getAudioStream}
-        useActiveTopic={useActiveTopic}
-        isPastMeeting={false}
-        audioDevices={audioDevices}
-      />
+      {webSockets.waveform && webSockets.duration && mp3?.media ? (
+        <Player
+          topics={webSockets.topics}
+          useActiveTopic={useActiveTopic}
+          waveform={webSockets.waveform}
+          media={mp3.media}
+          mediaDuration={webSockets.duration}
+        />
+      ) : recordedTime ? (
+        <WaveformLoading />
+      ) : (
+        <Recorder
+          setStream={setStream}
+          onStop={() => {
+            setStream(null);
+            setRecordedTime(Date.now() - startTime);
+            webRTC?.send(JSON.stringify({ cmd: "STOP" }));
+          }}
+          onRecord={() => {
+            setStartTime(Date.now());
+          }}
+          getAudioStream={getAudioStream}
+          audioDevices={audioDevices}
+        />
+      )}
       [JSX elided]
-      {!hasRecorded ? (
+      {!recordedTime ? (
         <>
           {transcriptStarted && (
             [JSX elided]
               Transcription
             [JSX elided]
@@ -128,6 +159,7 @@
               couple of minutes. Please do not navigate away from the page
               during this time.
             [JSX elided]
+            {/* NTH If login required remove last sentence */}
           [JSX elided]
         )}
diff --git a/www/app/[domain]/transcripts/createTranscript.ts b/www/app/[domain]/transcripts/createTranscript.ts
index 31a034f4..0d96b8db 100644
--- a/www/app/[domain]/transcripts/createTranscript.ts
+++ b/www/app/[domain]/transcripts/createTranscript.ts
@@ -45,7 +45,10 @@ const useCreateTranscript = (): CreateTranscript => {
         console.debug("New transcript created:", result);
       })
       .catch((err) => {
-        setError(err);
+        setError(
+          err,
+          "There was an issue creating a transcript, please try again.",
+        );
         setErrorState(err);
         setLoading(false);
       });
diff --git a/www/app/[domain]/transcripts/finalSummary.tsx b/www/app/[domain]/transcripts/finalSummary.tsx
index 463f6100..e0d0f1c9 100644
--- a/www/app/[domain]/transcripts/finalSummary.tsx
+++ b/www/app/[domain]/transcripts/finalSummary.tsx
@@ -87,7 +87,7 @@ export default function FinalSummary(props: FinalSummaryProps) {
   [one-line JSX change lost in extraction]
diff --git a/www/app/[domain]/transcripts/player.tsx b/www/app/[domain]/transcripts/player.tsx new file mode 100644 index 00000000..02151a68 --- /dev/null +++ b/www/app/[domain]/transcripts/player.tsx @@ -0,0 +1,166 @@ +import React, { useRef, useEffect, useState } from "react"; + +import WaveSurfer from "wavesurfer.js"; +import CustomRegionsPlugin from "../../lib/custom-plugins/regions"; + +import { formatTime } from "../../lib/time"; +import { Topic } from "./webSocketTypes"; +import { AudioWaveform } from "../../api"; +import { waveSurferStyles } from "../../styles/recorder"; + +type PlayerProps = { + topics: Topic[]; + useActiveTopic: [ + Topic | null, + React.Dispatch>, + ]; + waveform: AudioWaveform["data"]; + media: HTMLMediaElement; + mediaDuration: number; +}; + +export default function Player(props: PlayerProps) { + const waveformRef = useRef(null); + const [wavesurfer, setWavesurfer] = useState(null); + const [isPlaying, setIsPlaying] = useState(false); + const [currentTime, setCurrentTime] = useState(0); + const [waveRegions, setWaveRegions] = useState( + null, + ); + const [activeTopic, setActiveTopic] = props.useActiveTopic; + const topicsRef = useRef(props.topics); + // Waveform setup + useEffect(() => { + if (waveformRef.current) { + // XXX duration is required to prevent recomputing peaks from audio + // However, the current waveform returns only the peaks, and no duration + // And the backend does not save duration properly. + // So at the moment, we deduct the duration from the topics. + // This is not ideal, but it works for now. + const _wavesurfer = WaveSurfer.create({ + container: waveformRef.current, + peaks: props.waveform, + hideScrollbar: true, + autoCenter: true, + barWidth: 2, + height: "auto", + duration: props.mediaDuration, + + ...waveSurferStyles.player, + }); + + // styling + const wsWrapper = _wavesurfer.getWrapper(); + wsWrapper.style.cursor = waveSurferStyles.playerStyle.cursor; + wsWrapper.style.backgroundColor = + waveSurferStyles.playerStyle.backgroundColor; + wsWrapper.style.borderRadius = waveSurferStyles.playerStyle.borderRadius; + + _wavesurfer.on("play", () => { + setIsPlaying(true); + }); + _wavesurfer.on("pause", () => { + setIsPlaying(false); + }); + _wavesurfer.on("timeupdate", setCurrentTime); + + setWaveRegions(_wavesurfer.registerPlugin(CustomRegionsPlugin.create())); + + _wavesurfer.toggleInteraction(true); + + _wavesurfer.setMediaElement(props.media); + + setWavesurfer(_wavesurfer); + + return () => { + _wavesurfer.destroy(); + setIsPlaying(false); + setCurrentTime(0); + }; + } + }, []); + + useEffect(() => { + if (!wavesurfer) return; + if (!props.media) return; + wavesurfer.setMediaElement(props.media); + }, [props.media, wavesurfer]); + + useEffect(() => { + topicsRef.current = props.topics; + renderMarkers(); + }, [props.topics, waveRegions]); + + const renderMarkers = () => { + if (!waveRegions) return; + + waveRegions.clearRegions(); + + for (let topic of topicsRef.current) { + const content = document.createElement("div"); + content.setAttribute("style", waveSurferStyles.marker); + content.onmouseover = () => { + content.style.backgroundColor = + waveSurferStyles.markerHover.backgroundColor; + content.style.zIndex = "999"; + content.style.width = "300px"; + }; + content.onmouseout = () => { + content.setAttribute("style", waveSurferStyles.marker); + }; + content.textContent = topic.title; + + const region = waveRegions.addRegion({ + start: topic.timestamp, + content, + color: "f00", + drag: false, + }); + region.on("click", (e) 
=> { + e.stopPropagation(); + setActiveTopic(topic); + wavesurfer?.setTime(region.start); + }); + } + }; + + useEffect(() => { + if (activeTopic) { + wavesurfer?.setTime(activeTopic.timestamp); + } + }, [activeTopic]); + + const handlePlayClick = () => { + wavesurfer?.playPause(); + }; + + const timeLabel = () => { + if (props.mediaDuration) + return `${formatTime(currentTime)}/${formatTime(props.mediaDuration)}`; + return ""; + }; + + return ( +
+
+
+
{timeLabel()}
+
+ + +
+ ); +} diff --git a/www/app/[domain]/transcripts/recorder.tsx b/www/app/[domain]/transcripts/recorder.tsx index 765d8f09..e7c016a7 100644 --- a/www/app/[domain]/transcripts/recorder.tsx +++ b/www/app/[domain]/transcripts/recorder.tsx @@ -6,31 +6,19 @@ import CustomRegionsPlugin from "../../lib/custom-plugins/regions"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import { faMicrophone } from "@fortawesome/free-solid-svg-icons"; -import { faDownload } from "@fortawesome/free-solid-svg-icons"; import { formatTime } from "../../lib/time"; -import { Topic } from "./webSocketTypes"; -import { AudioWaveform } from "../../api"; import AudioInputsDropdown from "./audioInputsDropdown"; import { Option } from "react-dropdown"; -import { useError } from "../../(errors)/errorContext"; import { waveSurferStyles } from "../../styles/recorder"; -import useMp3 from "./useMp3"; +import { useError } from "../../(errors)/errorContext"; type RecorderProps = { - setStream?: React.Dispatch>; - onStop?: () => void; - topics: Topic[]; - getAudioStream?: (deviceId) => Promise; - audioDevices?: Option[]; - useActiveTopic: [ - Topic | null, - React.Dispatch>, - ]; - waveform?: AudioWaveform | null; - isPastMeeting: boolean; - transcriptId?: string | null; - mp3Blob?: Blob | null; + setStream: React.Dispatch>; + onStop: () => void; + onRecord?: () => void; + getAudioStream: (deviceId) => Promise; + audioDevices: Option[]; }; export default function Recorder(props: RecorderProps) { @@ -38,7 +26,7 @@ export default function Recorder(props: RecorderProps) { const [wavesurfer, setWavesurfer] = useState(null); const [record, setRecord] = useState(null); const [isRecording, setIsRecording] = useState(false); - const [hasRecorded, setHasRecorded] = useState(props.isPastMeeting); + const [hasRecorded, setHasRecorded] = useState(false); const [isPlaying, setIsPlaying] = useState(false); const [currentTime, setCurrentTime] = useState(0); const [timeInterval, setTimeInterval] = useState(null); @@ -48,8 +36,6 @@ export default function Recorder(props: RecorderProps) { ); const [deviceId, setDeviceId] = useState(null); const [recordStarted, setRecordStarted] = useState(false); - const [activeTopic, setActiveTopic] = props.useActiveTopic; - const topicsRef = useRef(props.topics); const [showDevices, setShowDevices] = useState(false); const { setError } = useError(); @@ -73,11 +59,6 @@ export default function Recorder(props: RecorderProps) { if (!record.isRecording()) return; handleRecClick(); break; - case "%": - setError(new Error("Error triggered by '%' shortcut")); - break; - case "^": - throw new Error("Unhandled Exception thrown by '^' shortcut"); case "(": location.href = "/login"; break; @@ -109,7 +90,6 @@ export default function Recorder(props: RecorderProps) { if (waveformRef.current) { const _wavesurfer = WaveSurfer.create({ container: waveformRef.current, - peaks: props.waveform?.data, hideScrollbar: true, autoCenter: true, barWidth: 2, @@ -118,10 +98,8 @@ export default function Recorder(props: RecorderProps) { ...waveSurferStyles.player, }); - if (!props.transcriptId) { - const _wshack: any = _wavesurfer; - _wshack.renderer.renderSingleCanvas = () => {}; - } + const _wshack: any = _wavesurfer; + _wshack.renderer.renderSingleCanvas = () => {}; // styling const wsWrapper = _wavesurfer.getWrapper(); @@ -141,12 +119,6 @@ export default function Recorder(props: RecorderProps) { setRecord(_wavesurfer.registerPlugin(RecordPlugin.create())); 
@@ -141,12 +119,6 @@ export default function Recorder(props: RecorderProps) {
       setRecord(_wavesurfer.registerPlugin(RecordPlugin.create()));
       setWaveRegions(_wavesurfer.registerPlugin(CustomRegionsPlugin.create()));
 
-      if (props.isPastMeeting) _wavesurfer.toggleInteraction(true);
-
-      if (props.mp3Blob) {
-        _wavesurfer.loadBlob(props.mp3Blob);
-      }
-
       setWavesurfer(_wavesurfer);
 
       return () => {
@@ -158,58 +130,6 @@ export default function Recorder(props: RecorderProps) {
     }
   }, []);
 
-  useEffect(() => {
-    if (!wavesurfer) return;
-    if (!props.mp3Blob) return;
-    wavesurfer.loadBlob(props.mp3Blob);
-  }, [props.mp3Blob]);
-
-  useEffect(() => {
-    topicsRef.current = props.topics;
-    if (!isRecording) renderMarkers();
-  }, [props.topics, waveRegions]);
-
-  const renderMarkers = () => {
-    if (!waveRegions) return;
-
-    waveRegions.clearRegions();
-
-    for (let topic of topicsRef.current) {
-      const content = document.createElement("div");
-      content.setAttribute("style", waveSurferStyles.marker);
-      content.onmouseover = () => {
-        content.style.backgroundColor =
-          waveSurferStyles.markerHover.backgroundColor;
-        content.style.zIndex = "999";
-        content.style.width = "300px";
-      };
-      content.onmouseout = () => {
-        content.setAttribute("style", waveSurferStyles.marker);
-      };
-      content.textContent = topic.title;
-
-      const region = waveRegions.addRegion({
-        start: topic.timestamp,
-        content,
-        color: "f00",
-        drag: false,
-      });
-      region.on("click", (e) => {
-        e.stopPropagation();
-        setActiveTopic(topic);
-        wavesurfer?.setTime(region.start);
-      });
-    }
-  };
-
-  useEffect(() => {
-    if (!record) return;
-
-    return record.on("stopRecording", () => {
-      renderMarkers();
-    });
-  }, [record]);
-
   useEffect(() => {
     if (isRecording) {
       const interval = window.setInterval(() => {
@@ -226,25 +146,24 @@ export default function Recorder(props: RecorderProps) {
     }
   }, [isRecording]);
 
-  useEffect(() => {
-    if (activeTopic) {
-      wavesurfer?.setTime(activeTopic.timestamp);
-    }
-  }, [activeTopic]);
-
   const handleRecClick = async () => {
     if (!record) return console.log("no record");
 
     if (record.isRecording()) {
       if (props.onStop) props.onStop();
       record.stopRecording();
+      if (screenMediaStream) {
+        screenMediaStream.getTracks().forEach((t) => t.stop());
+      }
       setIsRecording(false);
       setHasRecorded(true);
+      setScreenMediaStream(null);
+      setDestinationStream(null);
     } else {
+      if (props.onRecord) props.onRecord();
      const stream = await getCurrentStream();
       if (props.setStream) props.setStream(stream);
-      waveRegions?.clearRegions();
       if (stream) {
         await record.startRecording(stream);
         setIsRecording(true);
@@ -252,6 +171,76 @@ export default function Recorder(props: RecorderProps) {
     }
   };
 
+  const [screenMediaStream, setScreenMediaStream] =
+    useState<MediaStream | null>(null);
+
+  const handleRecordTabClick = async () => {
+    if (!record) return console.log("no record");
+    const stream: MediaStream = await navigator.mediaDevices.getDisplayMedia({
+      video: true,
+      audio: {
+        echoCancellation: true,
+        noiseSuppression: true,
+        sampleRate: 44100,
+      },
+    });
+
+    if (stream.getAudioTracks().length == 0) {
+      setError(new Error("No audio track found in screen recording."));
+      return;
+    }
+    setScreenMediaStream(stream);
+  };
+
+  const [destinationStream, setDestinationStream] =
+    useState<MediaStream | null>(null);
+
+  const startTabRecording = async () => {
+    if (!screenMediaStream) return;
+    if (!record) return;
+    if (destinationStream !== null) return console.log("already recording");
+
+    // connect mic audio (microphone)
+    const micStream = await getCurrentStream();
+    if (!micStream) {
+      console.log("no microphone audio");
+      return;
+    }
+
+    // Create MediaStreamSource nodes for the microphone and tab
+    const audioContext = new AudioContext();
+    const micSource = audioContext.createMediaStreamSource(micStream);
+    const tabSource = audioContext.createMediaStreamSource(screenMediaStream);
+
+    // Merge channels
+    // XXX If the length is not the same, we do not receive audio in WebRTC.
+    // So for now, merge the channels to have only one stereo source
+    const channelMerger = audioContext.createChannelMerger(1);
+    micSource.connect(channelMerger, 0, 0);
+    tabSource.connect(channelMerger, 0, 0);
+
+    // Create a MediaStreamDestination node
+    const destination = audioContext.createMediaStreamDestination();
+    channelMerger.connect(destination);
+
+    // Use the destination's stream for the WebRTC connection
+    setDestinationStream(destination.stream);
+  };
+
+  useEffect(() => {
+    if (!record) return;
+    if (!destinationStream) return;
+    if (props.setStream) props.setStream(destinationStream);
+    if (destinationStream) {
+      record.startRecording(destinationStream);
+      setIsRecording(true);
+    }
+  }, [record, destinationStream]);
+
+  useEffect(() => {
+    startTabRecording();
+  }, [record, screenMediaStream]);
+
   const handlePlayClick = () => {
     wavesurfer?.playPause();
   };
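The additions above implement tab recording in three steps: getDisplayMedia captures the tab (capture without an audio track is rejected with an error), a shared AudioContext merges the tab and microphone sources through a single-input ChannelMergerNode (the XXX comment explains why: WebRTC drops audio when the track layouts differ, so everything is mixed down to one source), and a MediaStreamAudioDestinationNode exposes the mix as a plain MediaStream handed to props.setStream. Condensed into one standalone function, assuming a browser context; the function name and the inlined constraints are illustrative:

// Sketch of the mix-down used above: capture tab audio plus the
// microphone and merge them into a single MediaStream for WebRTC.
async function mixTabAndMic(): Promise<MediaStream> {
  const tab = await navigator.mediaDevices.getDisplayMedia({
    video: true,
    audio: true,
  });
  if (tab.getAudioTracks().length === 0) {
    throw new Error("Tab capture has no audio track");
  }
  const mic = await navigator.mediaDevices.getUserMedia({ audio: true });

  const ctx = new AudioContext();
  // Single-input merger: both sources feed input 0 and are summed,
  // mirroring the createChannelMerger(1) hack in the diff.
  const merger = ctx.createChannelMerger(1);
  ctx.createMediaStreamSource(tab).connect(merger, 0, 0);
  ctx.createMediaStreamSource(mic).connect(merger, 0, 0);

  // Expose the merged audio as a regular MediaStream.
  const destination = ctx.createMediaStreamDestination();
  merger.connect(destination);
  return destination.stream; // hand this to the recorder / RTCPeerConnection
}

The component version splits this across state and two effects so the merge only runs once both the record plugin and the captured tab stream are available.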
@@ -300,23 +289,9 @@ export default function Recorder(props: RecorderProps) {
             } text-white ml-2 md:ml:4 md:h-[78px] md:min-w-[100px] text-lg`}
             id="play-btn"
             onClick={handlePlayClick}
-            disabled={isRecording}
           >
             {isPlaying ? "Pause" : "Play"}
-
-          {props.transcriptId && (
-
-
-
-          )}
         )}
       {!hasRecorded && (
@@ -332,6 +307,19 @@ export default function Recorder(props: RecorderProps) {
         >
           {isRecording ? "Stop" : "Record"}
+        {!isRecording && (
+
+        )}
       {props.audioDevices && props.audioDevices?.length > 0 && deviceId && (
         <>