mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
Merge branch 'main' of github.com:Monadical-SAS/reflector into feat-sharing
This commit is contained in:
@@ -51,17 +51,6 @@
|
||||
#TRANSLATE_URL=https://xxxxx--reflector-translator-web.modal.run
|
||||
#TRANSCRIPT_MODAL_API_KEY=xxxxx
|
||||
|
||||
## Using serverless banana.dev (require reflector-gpu-banana deployed)
|
||||
## XXX this service is buggy do not use at the moment
|
||||
## XXX it also require the audio to be saved to S3
|
||||
#TRANSCRIPT_BACKEND=banana
|
||||
#TRANSCRIPT_URL=https://reflector-gpu-banana-xxxxx.run.banana.dev
|
||||
#TRANSCRIPT_BANANA_API_KEY=xxx
|
||||
#TRANSCRIPT_BANANA_MODEL_KEY=xxx
|
||||
#TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID=xxx
|
||||
#TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY=xxx
|
||||
#TRANSCRIPT_STORAGE_AWS_BUCKET_NAME="reflector-bucket/chunks"
|
||||
|
||||
## =======================================================
|
||||
## LLM backend
|
||||
##
|
||||
@@ -78,13 +67,6 @@
|
||||
#LLM_URL=https://xxxxxx--reflector-llm-web.modal.run
|
||||
#LLM_MODAL_API_KEY=xxx
|
||||
|
||||
## Using serverless banana.dev (require reflector-gpu-banana deployed)
|
||||
## XXX this service is buggy do not use at the moment
|
||||
#LLM_BACKEND=banana
|
||||
#LLM_URL=https://reflector-gpu-banana-xxxxx.run.banana.dev
|
||||
#LLM_BANANA_API_KEY=xxxxx
|
||||
#LLM_BANANA_MODEL_KEY=xxxxx
|
||||
|
||||
## Using OpenAI
|
||||
#LLM_BACKEND=openai
|
||||
#LLM_OPENAI_KEY=xxx
|
||||
|
||||
@@ -81,7 +81,8 @@ class LLM:
|
||||
LLM_MODEL,
|
||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
||||
cache_dir=IMAGE_MODEL_DIR
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# JSONFormer doesn't yet support generation configs
|
||||
@@ -96,7 +97,8 @@ class LLM:
|
||||
print("Instance llm tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
LLM_MODEL,
|
||||
cache_dir=IMAGE_MODEL_DIR
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# move model to gpu
|
||||
|
||||
@@ -17,7 +17,7 @@ LLM_LOW_CPU_MEM_USAGE: bool = True
|
||||
LLM_TORCH_DTYPE: str = "bfloat16"
|
||||
LLM_MAX_NEW_TOKENS: int = 300
|
||||
|
||||
IMAGE_MODEL_DIR = "/root/llm_models"
|
||||
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"
|
||||
|
||||
stub = Stub(name="reflector-llm-zephyr")
|
||||
|
||||
@@ -81,7 +81,8 @@ class LLM:
|
||||
LLM_MODEL,
|
||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
||||
cache_dir=IMAGE_MODEL_DIR
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# JSONFormer doesn't yet support generation configs
|
||||
@@ -96,7 +97,8 @@ class LLM:
|
||||
print("Instance llm tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
LLM_MODEL,
|
||||
cache_dir=IMAGE_MODEL_DIR
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True
|
||||
)
|
||||
gen_cfg.pad_token_id = tokenizer.eos_token_id
|
||||
gen_cfg.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
@@ -95,7 +95,8 @@ class Transcriber:
|
||||
device=self.device,
|
||||
compute_type=WHISPER_COMPUTE_TYPE,
|
||||
num_workers=WHISPER_NUM_WORKERS,
|
||||
download_root=WHISPER_MODEL_DIR
|
||||
download_root=WHISPER_MODEL_DIR,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
@method()
|
||||
|
||||
64
server/migrations/versions/4814901632bc_fix_duration.py
Normal file
64
server/migrations/versions/4814901632bc_fix_duration.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""fix duration
|
||||
|
||||
Revision ID: 4814901632bc
|
||||
Revises: 38a927dcb099
|
||||
Create Date: 2023-11-10 18:12:17.886522
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "4814901632bc"
|
||||
down_revision: Union[str, None] = "38a927dcb099"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# for all the transcripts, calculate the duration from the mp3
|
||||
# and update the duration column
|
||||
from pathlib import Path
|
||||
from reflector.settings import settings
|
||||
import av
|
||||
|
||||
bind = op.get_bind()
|
||||
transcript = table(
|
||||
"transcript", column("id", sa.String), column("duration", sa.Float)
|
||||
)
|
||||
|
||||
# select only the one with duration = 0
|
||||
results = bind.execute(
|
||||
select([transcript.c.id, transcript.c.duration]).where(
|
||||
transcript.c.duration == 0
|
||||
)
|
||||
)
|
||||
|
||||
data_dir = Path(settings.DATA_DIR)
|
||||
for row in results:
|
||||
audio_path = data_dir / row["id"] / "audio.mp3"
|
||||
if not audio_path.exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
print(f"Processing {audio_path}")
|
||||
container = av.open(audio_path.as_posix())
|
||||
print(container.duration)
|
||||
duration = round(float(container.duration / av.time_base), 2)
|
||||
print(f"Duration: {duration}")
|
||||
bind.execute(
|
||||
transcript.update()
|
||||
.where(transcript.c.id == row["id"])
|
||||
.values(duration=duration)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to process {audio_path}: {e}")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
78
server/poetry.lock
generated
78
server/poetry.lock
generated
@@ -2557,6 +2557,82 @@ typing-extensions = "*"
|
||||
[package.extras]
|
||||
dev = ["black", "flake8", "flake8-black", "isort", "jupyter-console", "mkdocs", "mkdocs-include-markdown-plugin", "mkdocstrings[python]", "pytest", "pytest-asyncio", "pytest-trio", "toml", "tox", "trio", "trio", "trio-typing", "twine", "twisted", "validate-pyproject[all]"]
|
||||
|
||||
[[package]]
|
||||
name = "pyinstrument"
|
||||
version = "4.6.1"
|
||||
description = "Call stack profiler for Python. Shows you why your code is slow!"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pyinstrument-4.6.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:73476e4bc6e467ac1b2c3c0dd1f0b71c9061d4de14626676adfdfbb14aa342b4"},
|
||||
{file = "pyinstrument-4.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4d1da8efd974cf9df52ee03edaee2d3875105ddd00de35aa542760f7c612bdf7"},
|
||||
{file = "pyinstrument-4.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:507be1ee2f2b0c9fba74d622a272640dd6d1b0c9ec3388b2cdeb97ad1e77125f"},
|
||||
{file = "pyinstrument-4.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95cee6de08eb45754ef4f602ce52b640d1c535d934a6a8733a974daa095def37"},
|
||||
{file = "pyinstrument-4.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7873e8cec92321251fdf894a72b3c78f4c5c20afdd1fef0baf9042ec843bb04"},
|
||||
{file = "pyinstrument-4.6.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a242f6cac40bc83e1f3002b6b53681846dfba007f366971db0bf21e02dbb1903"},
|
||||
{file = "pyinstrument-4.6.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:97c9660cdb4bd2a43cf4f3ab52cffd22f3ac9a748d913b750178fb34e5e39e64"},
|
||||
{file = "pyinstrument-4.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e304cd0723e2b18ada5e63c187abf6d777949454c734f5974d64a0865859f0f4"},
|
||||
{file = "pyinstrument-4.6.1-cp310-cp310-win32.whl", hash = "sha256:cee21a2d78187dd8a80f72f5d0f1ddb767b2d9800f8bb4d94b6d11f217c22cdb"},
|
||||
{file = "pyinstrument-4.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:2000712f71d693fed2f8a1c1638d37b7919124f367b37976d07128d49f1445eb"},
|
||||
{file = "pyinstrument-4.6.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a366c6f3dfb11f1739bdc1dee75a01c1563ad0bf4047071e5e77598087df457f"},
|
||||
{file = "pyinstrument-4.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c6be327be65d934796558aa9cb0f75ce62ebd207d49ad1854610c97b0579ad47"},
|
||||
{file = "pyinstrument-4.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9e160d9c5d20d3e4ef82269e4e8b246ff09bdf37af5fb8cb8ccca97936d95ad6"},
|
||||
{file = "pyinstrument-4.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ffbf56605ef21c2fcb60de2fa74ff81f417d8be0c5002a407e414d6ef6dee43"},
|
||||
{file = "pyinstrument-4.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c92cc4924596d6e8f30a16182bbe90893b1572d847ae12652f72b34a9a17c24a"},
|
||||
{file = "pyinstrument-4.6.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f4b48a94d938cae981f6948d9ec603bab2087b178d2095d042d5a48aabaecaab"},
|
||||
{file = "pyinstrument-4.6.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:e7a386392275bdef4a1849712dc5b74f0023483fca14ef93d0ca27d453548982"},
|
||||
{file = "pyinstrument-4.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:871b131b83e9b1122f2325061c68ed1e861eebcb568c934d2fb193652f077f77"},
|
||||
{file = "pyinstrument-4.6.1-cp311-cp311-win32.whl", hash = "sha256:8d8515156dd91f5652d13b5fcc87e634f8fe1c07b68d1d0840348cdd50bf5ace"},
|
||||
{file = "pyinstrument-4.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb868fbe089036e9f32525a249f4c78b8dc46967612393f204b8234f439c9cc4"},
|
||||
{file = "pyinstrument-4.6.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a18cd234cce4f230f1733807f17a134e64a1f1acabf74a14d27f583cf2b183df"},
|
||||
{file = "pyinstrument-4.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:574cfca69150be4ce4461fb224712fbc0722a49b0dc02fa204d02807adf6b5a0"},
|
||||
{file = "pyinstrument-4.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e02cf505e932eb8ccf561b7527550a67ec14fcae1fe0e25319b09c9c166e914"},
|
||||
{file = "pyinstrument-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832fb2acef9d53701c1ab546564c45fb70a8770c816374f8dd11420d399103c9"},
|
||||
{file = "pyinstrument-4.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13cb57e9607545623ebe462345b3d0c4caee0125d2d02267043ece8aca8f4ea0"},
|
||||
{file = "pyinstrument-4.6.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9be89e7419bcfe8dd6abb0d959d6d9c439c613a4a873514c43d16b48dae697c9"},
|
||||
{file = "pyinstrument-4.6.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:476785cfbc44e8e1b1ad447398aa3deae81a8df4d37eb2d8bbb0c404eff979cd"},
|
||||
{file = "pyinstrument-4.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e9cebd90128a3d2fee36d3ccb665c1b9dce75261061b2046203e45c4a8012d54"},
|
||||
{file = "pyinstrument-4.6.1-cp312-cp312-win32.whl", hash = "sha256:1d0b76683df2ad5c40eff73607dc5c13828c92fbca36aff1ddf869a3c5a55fa6"},
|
||||
{file = "pyinstrument-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:c4b7af1d9d6a523cfbfedebcb69202242d5bd0cb89c4e094cc73d5d6e38279bd"},
|
||||
{file = "pyinstrument-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:79ae152f8c6a680a188fb3be5e0f360ac05db5bbf410169a6c40851dfaebcce9"},
|
||||
{file = "pyinstrument-4.6.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cad2745964c174c65aa75f1bf68a4394d1b4d28f33894837cfd315d1e836f0"},
|
||||
{file = "pyinstrument-4.6.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb81f66f7f94045d723069cf317453d42375de9ff3c69089cf6466b078ac1db4"},
|
||||
{file = "pyinstrument-4.6.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ab30ae75969da99e9a529e21ff497c18fdf958e822753db4ae7ed1e67094040"},
|
||||
{file = "pyinstrument-4.6.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f36cb5b644762fb3c86289324bbef17e95f91cd710603ac19444a47f638e8e96"},
|
||||
{file = "pyinstrument-4.6.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:8b45075d9dbbc977dbc7007fb22bb0054c6990fbe91bf48dd80c0b96c6307ba7"},
|
||||
{file = "pyinstrument-4.6.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:475ac31477f6302e092463896d6a2055f3e6abcd293bad16ff94fc9185308a88"},
|
||||
{file = "pyinstrument-4.6.1-cp37-cp37m-win32.whl", hash = "sha256:29172ab3d8609fdf821c3f2562dc61e14f1a8ff5306607c32ca743582d3a760e"},
|
||||
{file = "pyinstrument-4.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:bd176f297c99035127b264369d2bb97a65255f65f8d4e843836baf55ebb3cee4"},
|
||||
{file = "pyinstrument-4.6.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:23e9b4526978432e9999021da9a545992cf2ac3df5ee82db7beb6908fc4c978c"},
|
||||
{file = "pyinstrument-4.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2dbcaccc9f456ef95557ec501caeb292119c24446d768cb4fb43578b0f3d572c"},
|
||||
{file = "pyinstrument-4.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2097f63c66c2bc9678c826b9ff0c25acde3ed455590d9dcac21220673fe74fbf"},
|
||||
{file = "pyinstrument-4.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:205ac2e76bd65d61b9611a9ce03d5f6393e34ec5b41dd38808f25d54e6b3e067"},
|
||||
{file = "pyinstrument-4.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f414ddf1161976a40fc0a333000e6a4ad612719eac0b8c9bb73f47153187148"},
|
||||
{file = "pyinstrument-4.6.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:65e62ebfa2cd8fb57eda90006f4505ac4c70da00fc2f05b6d8337d776ea76d41"},
|
||||
{file = "pyinstrument-4.6.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d96309df4df10be7b4885797c5f69bb3a89414680ebaec0722d8156fde5268c3"},
|
||||
{file = "pyinstrument-4.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f3d1ad3bc8ebb4db925afa706aa865c4bfb40d52509f143491ac0df2440ee5d2"},
|
||||
{file = "pyinstrument-4.6.1-cp38-cp38-win32.whl", hash = "sha256:dc37cb988c8854eb42bda2e438aaf553536566657d157c4473cc8aad5692a779"},
|
||||
{file = "pyinstrument-4.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:2cd4ce750c34a0318fc2d6c727cc255e9658d12a5cf3f2d0473f1c27157bdaeb"},
|
||||
{file = "pyinstrument-4.6.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6ca95b21f022e995e062b371d1f42d901452bcbedd2c02f036de677119503355"},
|
||||
{file = "pyinstrument-4.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ac1e1d7e1f1b64054c4eb04eb4869a7a5eef2261440e73943cc1b1bc3c828c18"},
|
||||
{file = "pyinstrument-4.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0711845e953fce6ab781221aacffa2a66dbc3289f8343e5babd7b2ea34da6c90"},
|
||||
{file = "pyinstrument-4.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b7d28582017de35cb64eb4e4fa603e753095108ca03745f5d17295970ee631f"},
|
||||
{file = "pyinstrument-4.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7be57db08bd366a37db3aa3a6187941ee21196e8b14975db337ddc7d1490649d"},
|
||||
{file = "pyinstrument-4.6.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9a0ac0f56860398d2628ce389826ce83fb3a557d0c9a2351e8a2eac6eb869983"},
|
||||
{file = "pyinstrument-4.6.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a9045186ff13bc826fef16be53736a85029aae3c6adfe52e666cad00d7ca623b"},
|
||||
{file = "pyinstrument-4.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6c4c56b6eab9004e92ad8a48bb54913fdd71fc8a748ae42a27b9e26041646f8b"},
|
||||
{file = "pyinstrument-4.6.1-cp39-cp39-win32.whl", hash = "sha256:37e989c44b51839d0c97466fa2b623638b9470d56d79e329f359f0e8fa6d83db"},
|
||||
{file = "pyinstrument-4.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:5494c5a84fee4309d7d973366ca6b8b9f8ba1d6b254e93b7c506264ef74f2cef"},
|
||||
{file = "pyinstrument-4.6.1.tar.gz", hash = "sha256:f4731b27121350f5a983d358d2272fe3df2f538aed058f57217eef7801a89288"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
bin = ["click", "nox"]
|
||||
docs = ["furo (==2021.6.18b36)", "myst-parser (==0.15.1)", "sphinx (==4.2.0)", "sphinxcontrib-programoutput (==0.17)"]
|
||||
examples = ["django", "numpy"]
|
||||
test = ["flaky", "greenlet (>=3.0.0a1)", "ipython", "pytest", "pytest-asyncio (==0.12.0)", "sphinx-autobuild (==2021.3.14)", "trio"]
|
||||
types = ["typing-extensions"]
|
||||
|
||||
[[package]]
|
||||
name = "pylibsrtp"
|
||||
version = "0.8.0"
|
||||
@@ -4143,4 +4219,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "cfefbd402bde7585caa42c1a889be0496d956e285bb05db9e1e7ae5e485e91fe"
|
||||
content-hash = "91d85539f5093abad70e34aa4d533272d6a2e2bbdb539c7968fe79c28b50d01a"
|
||||
|
||||
@@ -41,6 +41,7 @@ python-jose = {extras = ["cryptography"], version = "^3.3.0"}
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.7.0"
|
||||
stamina = "^23.1.0"
|
||||
pyinstrument = "^4.6.1"
|
||||
|
||||
|
||||
[tool.poetry.group.tests.dependencies]
|
||||
|
||||
@@ -41,7 +41,6 @@ if settings.SENTRY_DSN:
|
||||
else:
|
||||
logger.info("Sentry disabled")
|
||||
|
||||
|
||||
# build app
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
@@ -102,6 +101,23 @@ def use_route_names_as_operation_ids(app: FastAPI) -> None:
|
||||
|
||||
use_route_names_as_operation_ids(app)
|
||||
|
||||
if settings.PROFILING:
|
||||
from fastapi import Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pyinstrument import Profiler
|
||||
|
||||
@app.middleware("http")
|
||||
async def profile_request(request: Request, call_next):
|
||||
profiling = request.query_params.get("profile", False)
|
||||
if profiling:
|
||||
profiler = Profiler(async_mode="enabled")
|
||||
profiler.start()
|
||||
await call_next(request)
|
||||
profiler.stop()
|
||||
return HTMLResponse(profiler.output_html())
|
||||
else:
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
@@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
|
||||
from reflector.db import database, metadata
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
|
||||
transcripts = sqlalchemy.Table(
|
||||
"transcript",
|
||||
@@ -86,6 +85,14 @@ class TranscriptFinalTitle(BaseModel):
|
||||
title: str
|
||||
|
||||
|
||||
class TranscriptDuration(BaseModel):
|
||||
duration: float
|
||||
|
||||
|
||||
class TranscriptWaveform(BaseModel):
|
||||
waveform: list[float]
|
||||
|
||||
|
||||
class TranscriptEvent(BaseModel):
|
||||
event: str
|
||||
data: dict
|
||||
@@ -126,22 +133,6 @@ class Transcript(BaseModel):
|
||||
def topics_dump(self, mode="json"):
|
||||
return [topic.model_dump(mode=mode) for topic in self.topics]
|
||||
|
||||
def convert_audio_to_waveform(self, segments_count=256):
|
||||
fn = self.audio_waveform_filename
|
||||
if fn.exists():
|
||||
return
|
||||
waveform = get_audio_waveform(
|
||||
path=self.audio_mp3_filename, segments_count=segments_count
|
||||
)
|
||||
try:
|
||||
with open(fn, "w") as fd:
|
||||
json.dump(waveform, fd)
|
||||
except Exception:
|
||||
# remove file if anything happen during the write
|
||||
fn.unlink(missing_ok=True)
|
||||
raise
|
||||
return waveform
|
||||
|
||||
def unlink(self):
|
||||
self.data_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
import httpx
|
||||
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
|
||||
class BananaLLM(LLM):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.timeout = settings.LLM_TIMEOUT
|
||||
self.headers = {
|
||||
"X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
|
||||
"X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
|
||||
}
|
||||
|
||||
async def _generate(
|
||||
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
|
||||
):
|
||||
json_payload = {"prompt": prompt}
|
||||
if gen_schema:
|
||||
json_payload["gen_schema"] = gen_schema
|
||||
if gen_cfg:
|
||||
json_payload["gen_cfg"] = gen_cfg
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await retry(client.post)(
|
||||
settings.LLM_URL,
|
||||
headers=self.headers,
|
||||
json=json_payload,
|
||||
timeout=self.timeout,
|
||||
retry_timeout=300, # as per their sdk
|
||||
)
|
||||
response.raise_for_status()
|
||||
text = response.json()["text"]
|
||||
return text
|
||||
|
||||
|
||||
LLM.register("banana", BananaLLM)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from reflector.logger import logger
|
||||
|
||||
async def main():
|
||||
llm = BananaLLM()
|
||||
prompt = llm.create_prompt(
|
||||
instruct="Complete the following task",
|
||||
text="Tell me a joke about programming.",
|
||||
)
|
||||
result = await llm.generate(prompt=prompt, logger=logger)
|
||||
print(result)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
@@ -21,11 +21,13 @@ from pydantic import BaseModel
|
||||
from reflector.app import app
|
||||
from reflector.db.transcripts import (
|
||||
Transcript,
|
||||
TranscriptDuration,
|
||||
TranscriptFinalLongSummary,
|
||||
TranscriptFinalShortSummary,
|
||||
TranscriptFinalTitle,
|
||||
TranscriptText,
|
||||
TranscriptTopic,
|
||||
TranscriptWaveform,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
@@ -45,6 +47,7 @@ from reflector.processors import (
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorProcessor,
|
||||
)
|
||||
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput
|
||||
from reflector.processors.types import (
|
||||
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
|
||||
@@ -230,6 +233,33 @@ class PipelineMainBase(PipelineRunner):
|
||||
data=final_short_summary,
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_duration(self, data):
|
||||
async with self.transaction():
|
||||
duration = TranscriptDuration(duration=data)
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"duration": duration.duration,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript, event="DURATION", data=duration
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_waveform(self, data):
|
||||
async with self.transaction():
|
||||
waveform = TranscriptWaveform(waveform=data)
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript, event="WAVEFORM", data=waveform
|
||||
)
|
||||
|
||||
|
||||
class PipelineMainLive(PipelineMainBase):
|
||||
audio_filename: Path | None = None
|
||||
@@ -243,7 +273,10 @@ class PipelineMainLive(PipelineMainBase):
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
processors = [
|
||||
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
|
||||
AudioFileWriterProcessor(
|
||||
path=transcript.audio_mp3_filename,
|
||||
on_duration=self.on_duration,
|
||||
),
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
@@ -253,6 +286,11 @@ class PipelineMainLive(PipelineMainBase):
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
||||
AudioWaveformProcessor.as_threaded(
|
||||
audio_path=transcript.audio_mp3_filename,
|
||||
waveform_path=transcript.audio_waveform_filename,
|
||||
on_waveform=self.on_waveform,
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
@@ -285,8 +323,13 @@ class PipelineMainDiarization(PipelineMainBase):
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
processors = [
|
||||
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||
processors = []
|
||||
if settings.DIARIZATION_ENABLED:
|
||||
processors += [
|
||||
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||
]
|
||||
|
||||
processors += [
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||
|
||||
@@ -12,8 +12,8 @@ class AudioFileWriterProcessor(Processor):
|
||||
INPUT_TYPE = av.AudioFrame
|
||||
OUTPUT_TYPE = av.AudioFrame
|
||||
|
||||
def __init__(self, path: Path | str):
|
||||
super().__init__()
|
||||
def __init__(self, path: Path | str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if path.suffix not in (".mp3", ".wav"):
|
||||
@@ -21,6 +21,7 @@ class AudioFileWriterProcessor(Processor):
|
||||
self.path = path
|
||||
self.out_container = None
|
||||
self.out_stream = None
|
||||
self.last_packet = None
|
||||
|
||||
async def _push(self, data: av.AudioFrame):
|
||||
if not self.out_container:
|
||||
@@ -40,12 +41,30 @@ class AudioFileWriterProcessor(Processor):
|
||||
raise ValueError("Only mp3 and wav files are supported")
|
||||
for packet in self.out_stream.encode(data):
|
||||
self.out_container.mux(packet)
|
||||
self.last_packet = packet
|
||||
await self.emit(data)
|
||||
|
||||
async def _flush(self):
|
||||
if self.out_container:
|
||||
for packet in self.out_stream.encode():
|
||||
self.out_container.mux(packet)
|
||||
self.last_packet = packet
|
||||
try:
|
||||
if self.last_packet is not None:
|
||||
duration = round(
|
||||
float(
|
||||
(self.last_packet.pts * self.last_packet.duration)
|
||||
* self.last_packet.time_base
|
||||
),
|
||||
2,
|
||||
)
|
||||
except Exception:
|
||||
self.logger.exception("Failed to get duration")
|
||||
duration = 0
|
||||
|
||||
self.out_container.close()
|
||||
self.out_container = None
|
||||
self.out_stream = None
|
||||
|
||||
if duration > 0:
|
||||
await self.emit(duration, name="duration")
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
"""
|
||||
Implementation using the GPU service from banana.
|
||||
|
||||
API will be a POST request to TRANSCRIPT_URL:
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_url": "https://...",
|
||||
"audio_ext": "wav",
|
||||
"timestamp": 123.456
|
||||
"language": "en"
|
||||
}
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import Storage
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
|
||||
class AudioTranscriptBananaProcessor(AudioTranscriptProcessor):
|
||||
def __init__(self, banana_api_key: str, banana_model_key: str):
|
||||
super().__init__()
|
||||
self.transcript_url = settings.TRANSCRIPT_URL
|
||||
self.timeout = settings.TRANSCRIPT_TIMEOUT
|
||||
self.storage = Storage.get_instance(
|
||||
settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_"
|
||||
)
|
||||
self.headers = {
|
||||
"X-Banana-API-Key": banana_api_key,
|
||||
"X-Banana-Model-Key": banana_model_key,
|
||||
}
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
async with httpx.AsyncClient() as client:
|
||||
print(f"Uploading audio {data.path.name} to S3")
|
||||
url = await self._upload_file(data.path)
|
||||
|
||||
print(f"Try to transcribe audio {data.path.name}")
|
||||
request_data = {
|
||||
"audio_url": url,
|
||||
"audio_ext": data.path.suffix[1:],
|
||||
"timestamp": float(round(data.timestamp, 2)),
|
||||
}
|
||||
response = await retry(client.post)(
|
||||
self.transcript_url,
|
||||
json=request_data,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
print(f"Transcript response: {response.status_code} {response.content}")
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
transcript = Transcript(
|
||||
text=result["text"],
|
||||
words=[
|
||||
Word(text=word["text"], start=word["start"], end=word["end"])
|
||||
for word in result["words"]
|
||||
],
|
||||
)
|
||||
|
||||
# remove audio file from S3
|
||||
await self._delete_file(data.path)
|
||||
|
||||
return transcript
|
||||
|
||||
@retry
|
||||
async def _upload_file(self, path: Path) -> str:
|
||||
upload_result = await self.storage.put_file(path.name, open(path, "rb"))
|
||||
return upload_result.url
|
||||
|
||||
@retry
|
||||
async def _delete_file(self, path: Path):
|
||||
await self.storage.delete_file(path.name)
|
||||
return True
|
||||
|
||||
|
||||
AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor)
|
||||
36
server/reflector/processors/audio_waveform_processor.py
Normal file
36
server/reflector/processors/audio_waveform_processor.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import TitleSummary
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
|
||||
|
||||
class AudioWaveformProcessor(Processor):
|
||||
"""
|
||||
Write the waveform for the final audio
|
||||
"""
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, audio_path: Path | str, waveform_path: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if isinstance(audio_path, str):
|
||||
audio_path = Path(audio_path)
|
||||
if audio_path.suffix not in (".mp3", ".wav"):
|
||||
raise ValueError("Only mp3 and wav files are supported")
|
||||
self.audio_path = audio_path
|
||||
self.waveform_path = waveform_path
|
||||
|
||||
async def _flush(self):
|
||||
self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.logger.info("Waveform Processing Started")
|
||||
waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
|
||||
|
||||
with open(self.waveform_path, "w") as fd:
|
||||
json.dump(waveform, fd)
|
||||
self.logger.info("Waveform Processing Finished")
|
||||
await self.emit(waveform, name="waveform")
|
||||
|
||||
async def _push(_self, _data):
|
||||
return
|
||||
@@ -14,7 +14,42 @@ class PipelineEvent(BaseModel):
|
||||
data: Any
|
||||
|
||||
|
||||
class Processor:
|
||||
class Emitter:
|
||||
def __init__(self, **kwargs):
|
||||
self._callbacks = {}
|
||||
|
||||
# register callbacks from kwargs (on_*)
|
||||
for key, value in kwargs.items():
|
||||
if key.startswith("on_"):
|
||||
self.on(value, name=key[3:])
|
||||
|
||||
def on(self, callback, name="default"):
|
||||
"""
|
||||
Register a callback to be called when data is emitted
|
||||
"""
|
||||
# ensure callback is asynchronous
|
||||
if not asyncio.iscoroutinefunction(callback):
|
||||
raise ValueError("Callback must be a coroutine function")
|
||||
if name not in self._callbacks:
|
||||
self._callbacks[name] = []
|
||||
self._callbacks[name].append(callback)
|
||||
|
||||
def off(self, callback, name="default"):
|
||||
"""
|
||||
Unregister a callback to be called when data is emitted
|
||||
"""
|
||||
if name not in self._callbacks:
|
||||
return
|
||||
self._callbacks[name].remove(callback)
|
||||
|
||||
async def emit(self, data, name="default"):
|
||||
if name not in self._callbacks:
|
||||
return
|
||||
for callback in self._callbacks[name]:
|
||||
await callback(data)
|
||||
|
||||
|
||||
class Processor(Emitter):
|
||||
INPUT_TYPE: type = None
|
||||
OUTPUT_TYPE: type = None
|
||||
|
||||
@@ -59,7 +94,8 @@ class Processor:
|
||||
["processor"],
|
||||
)
|
||||
|
||||
def __init__(self, callback=None, custom_logger=None):
|
||||
def __init__(self, callback=None, custom_logger=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = name = self.__class__.__name__
|
||||
self.m_processor = self.m_processor.labels(name)
|
||||
self.m_processor_call = self.m_processor_call.labels(name)
|
||||
@@ -70,9 +106,11 @@ class Processor:
|
||||
self.m_processor_flush_success = self.m_processor_flush_success.labels(name)
|
||||
self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name)
|
||||
self._processors = []
|
||||
self._callbacks = []
|
||||
|
||||
# register callbacks
|
||||
if callback:
|
||||
self.on(callback)
|
||||
|
||||
self.uid = uuid4().hex
|
||||
self.flushed = False
|
||||
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
||||
@@ -100,21 +138,6 @@ class Processor:
|
||||
"""
|
||||
self._processors.remove(processor)
|
||||
|
||||
def on(self, callback):
|
||||
"""
|
||||
Register a callback to be called when data is emitted
|
||||
"""
|
||||
# ensure callback is asynchronous
|
||||
if not asyncio.iscoroutinefunction(callback):
|
||||
raise ValueError("Callback must be a coroutine function")
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def off(self, callback):
|
||||
"""
|
||||
Unregister a callback to be called when data is emitted
|
||||
"""
|
||||
self._callbacks.remove(callback)
|
||||
|
||||
def get_pref(self, key: str, default: Any = None):
|
||||
"""
|
||||
Get a preference from the pipeline prefs
|
||||
@@ -123,15 +146,16 @@ class Processor:
|
||||
return self.pipeline.get_pref(key, default)
|
||||
return default
|
||||
|
||||
async def emit(self, data):
|
||||
if self.pipeline:
|
||||
await self.pipeline.emit(
|
||||
PipelineEvent(processor=self.name, uid=self.uid, data=data)
|
||||
)
|
||||
for callback in self._callbacks:
|
||||
await callback(data)
|
||||
for processor in self._processors:
|
||||
await processor.push(data)
|
||||
async def emit(self, data, name="default"):
|
||||
if name == "default":
|
||||
if self.pipeline:
|
||||
await self.pipeline.emit(
|
||||
PipelineEvent(processor=self.name, uid=self.uid, data=data)
|
||||
)
|
||||
await super().emit(data, name=name)
|
||||
if name == "default":
|
||||
for processor in self._processors:
|
||||
await processor.push(data)
|
||||
|
||||
async def push(self, data):
|
||||
"""
|
||||
@@ -254,11 +278,11 @@ class ThreadedProcessor(Processor):
|
||||
def disconnect(self, processor: Processor):
|
||||
self.processor.disconnect(processor)
|
||||
|
||||
def on(self, callback):
|
||||
self.processor.on(callback)
|
||||
def on(self, callback, name="default"):
|
||||
self.processor.on(callback, name=name)
|
||||
|
||||
def off(self, callback):
|
||||
self.processor.off(callback)
|
||||
def off(self, callback, name="default"):
|
||||
self.processor.off(callback, name=name)
|
||||
|
||||
def describe(self, level=0):
|
||||
super().describe(level)
|
||||
@@ -305,13 +329,13 @@ class BroadcastProcessor(Processor):
|
||||
for processor in self.processors:
|
||||
processor.disconnect(processor)
|
||||
|
||||
def on(self, callback):
|
||||
def on(self, callback, name="default"):
|
||||
for processor in self.processors:
|
||||
processor.on(callback)
|
||||
processor.on(callback, name=name)
|
||||
|
||||
def off(self, callback):
|
||||
def off(self, callback, name="default"):
|
||||
for processor in self.processors:
|
||||
processor.off(callback)
|
||||
processor.off(callback, name=name)
|
||||
|
||||
def describe(self, level=0):
|
||||
super().describe(level)
|
||||
|
||||
@@ -16,6 +16,7 @@ class TranscriptTranslatorProcessor(Processor):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = None
|
||||
self.translate_url = settings.TRANSLATE_URL
|
||||
self.timeout = settings.TRANSLATE_TIMEOUT
|
||||
self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
|
||||
from profanityfilter import ProfanityFilter
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
from reflector.redis_cache import redis_cache
|
||||
|
||||
PUNC_RE = re.compile(r"[.;:?!…]")
|
||||
|
||||
@@ -68,10 +69,14 @@ class Transcript(BaseModel):
|
||||
# Uncensored text
|
||||
return "".join([word.text for word in self.words])
|
||||
|
||||
@redis_cache(prefix="profanity", duration=3600 * 24 * 7)
|
||||
def _get_censored_text(self, text: str):
|
||||
return profanity_filter.censor(text).strip()
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
# Censored text
|
||||
return profanity_filter.censor(self.raw_text).strip()
|
||||
return self._get_censored_text(self.raw_text)
|
||||
|
||||
@property
|
||||
def human_timestamp(self):
|
||||
|
||||
50
server/reflector/redis_cache.py
Normal file
50
server/reflector/redis_cache.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import functools
|
||||
import json
|
||||
|
||||
import redis
|
||||
from reflector.settings import settings
|
||||
|
||||
redis_clients = {}
|
||||
|
||||
|
||||
def get_redis_client(db=0):
|
||||
"""
|
||||
Get a Redis client for the specified database.
|
||||
"""
|
||||
if db not in redis_clients:
|
||||
redis_clients[db] = redis.StrictRedis(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=db,
|
||||
)
|
||||
return redis_clients[db]
|
||||
|
||||
|
||||
def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argidx=1):
|
||||
"""
|
||||
Cache the result of a function in Redis.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Check if the first argument is a string
|
||||
if len(args) < (argidx + 1) or not isinstance(args[argidx], str):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Compute the cache key based on the arguments and prefix
|
||||
cache_key = prefix + ":" + args[argidx]
|
||||
redis_client = get_redis_client(db=db)
|
||||
cached_result = redis_client.get(cache_key)
|
||||
|
||||
if cached_result:
|
||||
return json.loads(cached_result.decode("utf-8"))
|
||||
|
||||
# If the result is not cached, call the original function
|
||||
result = func(*args, **kwargs)
|
||||
redis_client.setex(cache_key, duration, json.dumps(result))
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -41,7 +41,7 @@ class Settings(BaseSettings):
|
||||
AUDIO_BUFFER_SIZE: int = 256 * 960
|
||||
|
||||
# Audio Transcription
|
||||
# backends: whisper, banana, modal
|
||||
# backends: whisper, modal
|
||||
TRANSCRIPT_BACKEND: str = "whisper"
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
@@ -50,10 +50,6 @@ class Settings(BaseSettings):
|
||||
TRANSLATE_URL: str | None = None
|
||||
TRANSLATE_TIMEOUT: int = 90
|
||||
|
||||
# Audio transcription banana.dev configuration
|
||||
TRANSCRIPT_BANANA_API_KEY: str | None = None
|
||||
TRANSCRIPT_BANANA_MODEL_KEY: str | None = None
|
||||
|
||||
# Audio transcription modal.com configuration
|
||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||
|
||||
@@ -61,13 +57,16 @@ class Settings(BaseSettings):
|
||||
TRANSCRIPT_STORAGE_BACKEND: str = "aws"
|
||||
|
||||
# Storage configuration for AWS
|
||||
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks"
|
||||
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
|
||||
TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1"
|
||||
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
|
||||
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
|
||||
|
||||
# Transcript MP3 storage
|
||||
TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
|
||||
|
||||
# LLM
|
||||
# available backend: openai, banana, modal, oobabooga
|
||||
# available backend: openai, modal, oobabooga
|
||||
LLM_BACKEND: str = "oobabooga"
|
||||
|
||||
# LLM common configuration
|
||||
@@ -82,14 +81,11 @@ class Settings(BaseSettings):
|
||||
LLM_TEMPERATURE: float = 0.7
|
||||
ZEPHYR_LLM_URL: str | None = None
|
||||
|
||||
# LLM Banana configuration
|
||||
LLM_BANANA_API_KEY: str | None = None
|
||||
LLM_BANANA_MODEL_KEY: str | None = None
|
||||
|
||||
# LLM Modal configuration
|
||||
LLM_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Diarization
|
||||
DIARIZATION_ENABLED: bool = True
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
|
||||
@@ -124,6 +120,7 @@ class Settings(BaseSettings):
|
||||
# Redis
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: int = 6379
|
||||
REDIS_CACHE_DB: int = 2
|
||||
|
||||
# Secret key
|
||||
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
|
||||
@@ -131,5 +128,8 @@ class Settings(BaseSettings):
|
||||
# Current hosting/domain
|
||||
BASE_URL: str = "http://localhost:1250"
|
||||
|
||||
# Profiling
|
||||
PROFILING: bool = False
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import BinaryIO
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi import HTTPException, Request, Response, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
@@ -57,6 +57,9 @@ def range_requests_response(
|
||||
),
|
||||
}
|
||||
|
||||
if request.method == "HEAD":
|
||||
return Response(headers=headers)
|
||||
|
||||
if content_disposition:
|
||||
headers["Content-Disposition"] = content_disposition
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from reflector.db.transcripts import (
|
||||
from reflector.processors.types import Transcript as ProcessorTranscript
|
||||
from reflector.settings import settings
|
||||
from reflector.ws_manager import get_ws_manager
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||
@@ -53,7 +52,7 @@ class GetTranscript(BaseModel):
|
||||
name: str
|
||||
status: str
|
||||
locked: bool
|
||||
duration: int
|
||||
duration: float
|
||||
title: str | None
|
||||
short_summary: str | None
|
||||
long_summary: str | None
|
||||
@@ -222,6 +221,7 @@ async def transcript_delete(
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
||||
@router.head("/transcripts/{transcript_id}/audio/mp3")
|
||||
async def transcript_get_audio_mp3(
|
||||
request: Request,
|
||||
transcript_id: str,
|
||||
@@ -272,8 +272,6 @@ async def transcript_get_audio_waveform(
|
||||
if not transcript.audio_mp3_filename.exists():
|
||||
raise HTTPException(status_code=500, detail="Audio not found")
|
||||
|
||||
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
||||
|
||||
return transcript.audio_waveform
|
||||
|
||||
|
||||
|
||||
@@ -46,6 +46,34 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == content_type
|
||||
|
||||
# test get 404
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
response = await ac.get(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"url_suffix,content_type",
|
||||
[
|
||||
["/mp3", "audio/mpeg"],
|
||||
],
|
||||
)
|
||||
async def test_transcript_audio_download_head(
|
||||
fake_transcript, url_suffix, content_type
|
||||
):
|
||||
from reflector.app import app
|
||||
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
response = await ac.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == content_type
|
||||
|
||||
# test head 404
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
response = await ac.head(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
@@ -90,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek(
|
||||
assert response.status_code == 206
|
||||
assert response.headers["content-type"] == content_type
|
||||
assert response.headers["content-range"].startswith("bytes 100-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_audio_download_waveform(fake_transcript):
|
||||
from reflector.app import app
|
||||
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "application/json"
|
||||
assert isinstance(response.json()["data"], list)
|
||||
assert len(response.json()["data"]) >= 255
|
||||
|
||||
@@ -182,6 +182,16 @@ async def test_transcript_rtc_and_websocket(
|
||||
ev = events[eventnames.index("FINAL_TITLE")]
|
||||
assert ev["data"]["title"] == "LLM TITLE"
|
||||
|
||||
assert "WAVEFORM" in eventnames
|
||||
ev = events[eventnames.index("WAVEFORM")]
|
||||
assert isinstance(ev["data"]["waveform"], list)
|
||||
assert len(ev["data"]["waveform"]) >= 250
|
||||
waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
|
||||
assert waveform_resp.status_code == 200
|
||||
assert waveform_resp.headers["content-type"] == "application/json"
|
||||
assert isinstance(waveform_resp.json()["data"], list)
|
||||
assert len(waveform_resp.json()["data"]) >= 250
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
assert statuses.index("recording") < statuses.index("processing")
|
||||
@@ -191,10 +201,14 @@ async def test_transcript_rtc_and_websocket(
|
||||
assert events[-1]["event"] == "STATUS"
|
||||
assert events[-1]["data"]["value"] == "ended"
|
||||
|
||||
# check on the latest response that the audio duration is > 0
|
||||
assert resp.json()["duration"] > 0
|
||||
assert "DURATION" in eventnames
|
||||
|
||||
# check that audio/mp3 is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/mpeg"
|
||||
audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert audio_resp.status_code == 200
|
||||
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
|
||||
Reference in New Issue
Block a user