Merge branch 'main' of github.com:Monadical-SAS/reflector into feat-sharing

This commit is contained in:
Sara
2023-11-21 12:11:58 +01:00
47 changed files with 1163 additions and 614 deletions

View File

@@ -51,17 +51,6 @@
#TRANSLATE_URL=https://xxxxx--reflector-translator-web.modal.run
#TRANSCRIPT_MODAL_API_KEY=xxxxx
## Using serverless banana.dev (require reflector-gpu-banana deployed)
## XXX this service is buggy do not use at the moment
## XXX it also require the audio to be saved to S3
#TRANSCRIPT_BACKEND=banana
#TRANSCRIPT_URL=https://reflector-gpu-banana-xxxxx.run.banana.dev
#TRANSCRIPT_BANANA_API_KEY=xxx
#TRANSCRIPT_BANANA_MODEL_KEY=xxx
#TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID=xxx
#TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY=xxx
#TRANSCRIPT_STORAGE_AWS_BUCKET_NAME="reflector-bucket/chunks"
## =======================================================
## LLM backend
##
@@ -78,13 +67,6 @@
#LLM_URL=https://xxxxxx--reflector-llm-web.modal.run
#LLM_MODAL_API_KEY=xxx
## Using serverless banana.dev (require reflector-gpu-banana deployed)
## XXX this service is buggy do not use at the moment
#LLM_BACKEND=banana
#LLM_URL=https://reflector-gpu-banana-xxxxx.run.banana.dev
#LLM_BANANA_API_KEY=xxxxx
#LLM_BANANA_MODEL_KEY=xxxxx
## Using OpenAI
#LLM_BACKEND=openai
#LLM_OPENAI_KEY=xxx

View File

@@ -81,7 +81,8 @@ class LLM:
LLM_MODEL,
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
cache_dir=IMAGE_MODEL_DIR
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
# JSONFormer doesn't yet support generation configs
@@ -96,7 +97,8 @@ class LLM:
print("Instance llm tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
LLM_MODEL,
cache_dir=IMAGE_MODEL_DIR
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
# move model to gpu

View File

@@ -17,7 +17,7 @@ LLM_LOW_CPU_MEM_USAGE: bool = True
LLM_TORCH_DTYPE: str = "bfloat16"
LLM_MAX_NEW_TOKENS: int = 300
IMAGE_MODEL_DIR = "/root/llm_models"
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"
stub = Stub(name="reflector-llm-zephyr")
@@ -81,7 +81,8 @@ class LLM:
LLM_MODEL,
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
cache_dir=IMAGE_MODEL_DIR
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
# JSONFormer doesn't yet support generation configs
@@ -96,7 +97,8 @@ class LLM:
print("Instance llm tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
LLM_MODEL,
cache_dir=IMAGE_MODEL_DIR
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
gen_cfg.pad_token_id = tokenizer.eos_token_id
gen_cfg.eos_token_id = tokenizer.eos_token_id

View File

@@ -95,7 +95,8 @@ class Transcriber:
device=self.device,
compute_type=WHISPER_COMPUTE_TYPE,
num_workers=WHISPER_NUM_WORKERS,
download_root=WHISPER_MODEL_DIR
download_root=WHISPER_MODEL_DIR,
local_files_only=True
)
@method()

View File

@@ -0,0 +1,64 @@
"""fix duration
Revision ID: 4814901632bc
Revises: 38a927dcb099
Create Date: 2023-11-10 18:12:17.886522
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column
from sqlalchemy import select
# revision identifiers, used by Alembic.
revision: str = "4814901632bc"
down_revision: Union[str, None] = "38a927dcb099"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# for all the transcripts, calculate the duration from the mp3
# and update the duration column
from pathlib import Path
from reflector.settings import settings
import av
bind = op.get_bind()
transcript = table(
"transcript", column("id", sa.String), column("duration", sa.Float)
)
# select only the one with duration = 0
results = bind.execute(
select([transcript.c.id, transcript.c.duration]).where(
transcript.c.duration == 0
)
)
data_dir = Path(settings.DATA_DIR)
for row in results:
audio_path = data_dir / row["id"] / "audio.mp3"
if not audio_path.exists():
continue
try:
print(f"Processing {audio_path}")
container = av.open(audio_path.as_posix())
print(container.duration)
duration = round(float(container.duration / av.time_base), 2)
print(f"Duration: {duration}")
bind.execute(
transcript.update()
.where(transcript.c.id == row["id"])
.values(duration=duration)
)
except Exception as e:
print(f"Failed to process {audio_path}: {e}")
def downgrade() -> None:
pass

78
server/poetry.lock generated
View File

@@ -2557,6 +2557,82 @@ typing-extensions = "*"
[package.extras]
dev = ["black", "flake8", "flake8-black", "isort", "jupyter-console", "mkdocs", "mkdocs-include-markdown-plugin", "mkdocstrings[python]", "pytest", "pytest-asyncio", "pytest-trio", "toml", "tox", "trio", "trio", "trio-typing", "twine", "twisted", "validate-pyproject[all]"]
[[package]]
name = "pyinstrument"
version = "4.6.1"
description = "Call stack profiler for Python. Shows you why your code is slow!"
optional = false
python-versions = ">=3.7"
files = [
{file = "pyinstrument-4.6.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:73476e4bc6e467ac1b2c3c0dd1f0b71c9061d4de14626676adfdfbb14aa342b4"},
{file = "pyinstrument-4.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4d1da8efd974cf9df52ee03edaee2d3875105ddd00de35aa542760f7c612bdf7"},
{file = "pyinstrument-4.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:507be1ee2f2b0c9fba74d622a272640dd6d1b0c9ec3388b2cdeb97ad1e77125f"},
{file = "pyinstrument-4.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95cee6de08eb45754ef4f602ce52b640d1c535d934a6a8733a974daa095def37"},
{file = "pyinstrument-4.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7873e8cec92321251fdf894a72b3c78f4c5c20afdd1fef0baf9042ec843bb04"},
{file = "pyinstrument-4.6.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a242f6cac40bc83e1f3002b6b53681846dfba007f366971db0bf21e02dbb1903"},
{file = "pyinstrument-4.6.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:97c9660cdb4bd2a43cf4f3ab52cffd22f3ac9a748d913b750178fb34e5e39e64"},
{file = "pyinstrument-4.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e304cd0723e2b18ada5e63c187abf6d777949454c734f5974d64a0865859f0f4"},
{file = "pyinstrument-4.6.1-cp310-cp310-win32.whl", hash = "sha256:cee21a2d78187dd8a80f72f5d0f1ddb767b2d9800f8bb4d94b6d11f217c22cdb"},
{file = "pyinstrument-4.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:2000712f71d693fed2f8a1c1638d37b7919124f367b37976d07128d49f1445eb"},
{file = "pyinstrument-4.6.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a366c6f3dfb11f1739bdc1dee75a01c1563ad0bf4047071e5e77598087df457f"},
{file = "pyinstrument-4.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c6be327be65d934796558aa9cb0f75ce62ebd207d49ad1854610c97b0579ad47"},
{file = "pyinstrument-4.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9e160d9c5d20d3e4ef82269e4e8b246ff09bdf37af5fb8cb8ccca97936d95ad6"},
{file = "pyinstrument-4.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ffbf56605ef21c2fcb60de2fa74ff81f417d8be0c5002a407e414d6ef6dee43"},
{file = "pyinstrument-4.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c92cc4924596d6e8f30a16182bbe90893b1572d847ae12652f72b34a9a17c24a"},
{file = "pyinstrument-4.6.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f4b48a94d938cae981f6948d9ec603bab2087b178d2095d042d5a48aabaecaab"},
{file = "pyinstrument-4.6.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:e7a386392275bdef4a1849712dc5b74f0023483fca14ef93d0ca27d453548982"},
{file = "pyinstrument-4.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:871b131b83e9b1122f2325061c68ed1e861eebcb568c934d2fb193652f077f77"},
{file = "pyinstrument-4.6.1-cp311-cp311-win32.whl", hash = "sha256:8d8515156dd91f5652d13b5fcc87e634f8fe1c07b68d1d0840348cdd50bf5ace"},
{file = "pyinstrument-4.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb868fbe089036e9f32525a249f4c78b8dc46967612393f204b8234f439c9cc4"},
{file = "pyinstrument-4.6.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a18cd234cce4f230f1733807f17a134e64a1f1acabf74a14d27f583cf2b183df"},
{file = "pyinstrument-4.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:574cfca69150be4ce4461fb224712fbc0722a49b0dc02fa204d02807adf6b5a0"},
{file = "pyinstrument-4.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e02cf505e932eb8ccf561b7527550a67ec14fcae1fe0e25319b09c9c166e914"},
{file = "pyinstrument-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832fb2acef9d53701c1ab546564c45fb70a8770c816374f8dd11420d399103c9"},
{file = "pyinstrument-4.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13cb57e9607545623ebe462345b3d0c4caee0125d2d02267043ece8aca8f4ea0"},
{file = "pyinstrument-4.6.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9be89e7419bcfe8dd6abb0d959d6d9c439c613a4a873514c43d16b48dae697c9"},
{file = "pyinstrument-4.6.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:476785cfbc44e8e1b1ad447398aa3deae81a8df4d37eb2d8bbb0c404eff979cd"},
{file = "pyinstrument-4.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e9cebd90128a3d2fee36d3ccb665c1b9dce75261061b2046203e45c4a8012d54"},
{file = "pyinstrument-4.6.1-cp312-cp312-win32.whl", hash = "sha256:1d0b76683df2ad5c40eff73607dc5c13828c92fbca36aff1ddf869a3c5a55fa6"},
{file = "pyinstrument-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:c4b7af1d9d6a523cfbfedebcb69202242d5bd0cb89c4e094cc73d5d6e38279bd"},
{file = "pyinstrument-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:79ae152f8c6a680a188fb3be5e0f360ac05db5bbf410169a6c40851dfaebcce9"},
{file = "pyinstrument-4.6.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cad2745964c174c65aa75f1bf68a4394d1b4d28f33894837cfd315d1e836f0"},
{file = "pyinstrument-4.6.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb81f66f7f94045d723069cf317453d42375de9ff3c69089cf6466b078ac1db4"},
{file = "pyinstrument-4.6.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ab30ae75969da99e9a529e21ff497c18fdf958e822753db4ae7ed1e67094040"},
{file = "pyinstrument-4.6.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f36cb5b644762fb3c86289324bbef17e95f91cd710603ac19444a47f638e8e96"},
{file = "pyinstrument-4.6.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:8b45075d9dbbc977dbc7007fb22bb0054c6990fbe91bf48dd80c0b96c6307ba7"},
{file = "pyinstrument-4.6.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:475ac31477f6302e092463896d6a2055f3e6abcd293bad16ff94fc9185308a88"},
{file = "pyinstrument-4.6.1-cp37-cp37m-win32.whl", hash = "sha256:29172ab3d8609fdf821c3f2562dc61e14f1a8ff5306607c32ca743582d3a760e"},
{file = "pyinstrument-4.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:bd176f297c99035127b264369d2bb97a65255f65f8d4e843836baf55ebb3cee4"},
{file = "pyinstrument-4.6.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:23e9b4526978432e9999021da9a545992cf2ac3df5ee82db7beb6908fc4c978c"},
{file = "pyinstrument-4.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2dbcaccc9f456ef95557ec501caeb292119c24446d768cb4fb43578b0f3d572c"},
{file = "pyinstrument-4.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2097f63c66c2bc9678c826b9ff0c25acde3ed455590d9dcac21220673fe74fbf"},
{file = "pyinstrument-4.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:205ac2e76bd65d61b9611a9ce03d5f6393e34ec5b41dd38808f25d54e6b3e067"},
{file = "pyinstrument-4.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f414ddf1161976a40fc0a333000e6a4ad612719eac0b8c9bb73f47153187148"},
{file = "pyinstrument-4.6.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:65e62ebfa2cd8fb57eda90006f4505ac4c70da00fc2f05b6d8337d776ea76d41"},
{file = "pyinstrument-4.6.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d96309df4df10be7b4885797c5f69bb3a89414680ebaec0722d8156fde5268c3"},
{file = "pyinstrument-4.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f3d1ad3bc8ebb4db925afa706aa865c4bfb40d52509f143491ac0df2440ee5d2"},
{file = "pyinstrument-4.6.1-cp38-cp38-win32.whl", hash = "sha256:dc37cb988c8854eb42bda2e438aaf553536566657d157c4473cc8aad5692a779"},
{file = "pyinstrument-4.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:2cd4ce750c34a0318fc2d6c727cc255e9658d12a5cf3f2d0473f1c27157bdaeb"},
{file = "pyinstrument-4.6.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6ca95b21f022e995e062b371d1f42d901452bcbedd2c02f036de677119503355"},
{file = "pyinstrument-4.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ac1e1d7e1f1b64054c4eb04eb4869a7a5eef2261440e73943cc1b1bc3c828c18"},
{file = "pyinstrument-4.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0711845e953fce6ab781221aacffa2a66dbc3289f8343e5babd7b2ea34da6c90"},
{file = "pyinstrument-4.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b7d28582017de35cb64eb4e4fa603e753095108ca03745f5d17295970ee631f"},
{file = "pyinstrument-4.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7be57db08bd366a37db3aa3a6187941ee21196e8b14975db337ddc7d1490649d"},
{file = "pyinstrument-4.6.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9a0ac0f56860398d2628ce389826ce83fb3a557d0c9a2351e8a2eac6eb869983"},
{file = "pyinstrument-4.6.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a9045186ff13bc826fef16be53736a85029aae3c6adfe52e666cad00d7ca623b"},
{file = "pyinstrument-4.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6c4c56b6eab9004e92ad8a48bb54913fdd71fc8a748ae42a27b9e26041646f8b"},
{file = "pyinstrument-4.6.1-cp39-cp39-win32.whl", hash = "sha256:37e989c44b51839d0c97466fa2b623638b9470d56d79e329f359f0e8fa6d83db"},
{file = "pyinstrument-4.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:5494c5a84fee4309d7d973366ca6b8b9f8ba1d6b254e93b7c506264ef74f2cef"},
{file = "pyinstrument-4.6.1.tar.gz", hash = "sha256:f4731b27121350f5a983d358d2272fe3df2f538aed058f57217eef7801a89288"},
]
[package.extras]
bin = ["click", "nox"]
docs = ["furo (==2021.6.18b36)", "myst-parser (==0.15.1)", "sphinx (==4.2.0)", "sphinxcontrib-programoutput (==0.17)"]
examples = ["django", "numpy"]
test = ["flaky", "greenlet (>=3.0.0a1)", "ipython", "pytest", "pytest-asyncio (==0.12.0)", "sphinx-autobuild (==2021.3.14)", "trio"]
types = ["typing-extensions"]
[[package]]
name = "pylibsrtp"
version = "0.8.0"
@@ -4143,4 +4219,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "cfefbd402bde7585caa42c1a889be0496d956e285bb05db9e1e7ae5e485e91fe"
content-hash = "91d85539f5093abad70e34aa4d533272d6a2e2bbdb539c7968fe79c28b50d01a"

View File

@@ -41,6 +41,7 @@ python-jose = {extras = ["cryptography"], version = "^3.3.0"}
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
stamina = "^23.1.0"
pyinstrument = "^4.6.1"
[tool.poetry.group.tests.dependencies]

View File

@@ -41,7 +41,6 @@ if settings.SENTRY_DSN:
else:
logger.info("Sentry disabled")
# build app
app = FastAPI(lifespan=lifespan)
app.add_middleware(
@@ -102,6 +101,23 @@ def use_route_names_as_operation_ids(app: FastAPI) -> None:
use_route_names_as_operation_ids(app)
if settings.PROFILING:
from fastapi import Request
from fastapi.responses import HTMLResponse
from pyinstrument import Profiler
@app.middleware("http")
async def profile_request(request: Request, call_next):
profiling = request.query_params.get("profile", False)
if profiling:
profiler = Profiler(async_mode="enabled")
profiler.start()
await call_next(request)
profiler.stop()
return HTMLResponse(profiler.output_html())
else:
return await call_next(request)
if __name__ == "__main__":
import uvicorn

View File

@@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
transcripts = sqlalchemy.Table(
"transcript",
@@ -86,6 +85,14 @@ class TranscriptFinalTitle(BaseModel):
title: str
class TranscriptDuration(BaseModel):
duration: float
class TranscriptWaveform(BaseModel):
waveform: list[float]
class TranscriptEvent(BaseModel):
event: str
data: dict
@@ -126,22 +133,6 @@ class Transcript(BaseModel):
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_waveform(self, segments_count=256):
fn = self.audio_waveform_filename
if fn.exists():
return
waveform = get_audio_waveform(
path=self.audio_mp3_filename, segments_count=segments_count
)
try:
with open(fn, "w") as fd:
json.dump(waveform, fd)
except Exception:
# remove file if anything happen during the write
fn.unlink(missing_ok=True)
raise
return waveform
def unlink(self):
self.data_path.unlink(missing_ok=True)

View File

@@ -1,54 +0,0 @@
import httpx
from reflector.llm.base import LLM
from reflector.settings import settings
from reflector.utils.retry import retry
class BananaLLM(LLM):
def __init__(self):
super().__init__()
self.timeout = settings.LLM_TIMEOUT
self.headers = {
"X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
"X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
}
async def _generate(
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
):
json_payload = {"prompt": prompt}
if gen_schema:
json_payload["gen_schema"] = gen_schema
if gen_cfg:
json_payload["gen_cfg"] = gen_cfg
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
settings.LLM_URL,
headers=self.headers,
json=json_payload,
timeout=self.timeout,
retry_timeout=300, # as per their sdk
)
response.raise_for_status()
text = response.json()["text"]
return text
LLM.register("banana", BananaLLM)
if __name__ == "__main__":
from reflector.logger import logger
async def main():
llm = BananaLLM()
prompt = llm.create_prompt(
instruct="Complete the following task",
text="Tell me a joke about programming.",
)
result = await llm.generate(prompt=prompt, logger=logger)
print(result)
import asyncio
asyncio.run(main())

View File

@@ -21,11 +21,13 @@ from pydantic import BaseModel
from reflector.app import app
from reflector.db.transcripts import (
Transcript,
TranscriptDuration,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
TranscriptFinalTitle,
TranscriptText,
TranscriptTopic,
TranscriptWaveform,
transcripts_controller,
)
from reflector.logger import logger
@@ -45,6 +47,7 @@ from reflector.processors import (
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import (
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
@@ -230,6 +233,33 @@ class PipelineMainBase(PipelineRunner):
data=final_short_summary,
)
@broadcast_to_sockets
async def on_duration(self, data):
async with self.transaction():
duration = TranscriptDuration(duration=data)
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
"duration": duration.duration,
},
)
return await transcripts_controller.append_event(
transcript=transcript, event="DURATION", data=duration
)
@broadcast_to_sockets
async def on_waveform(self, data):
async with self.transaction():
waveform = TranscriptWaveform(waveform=data)
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript, event="WAVEFORM", data=waveform
)
class PipelineMainLive(PipelineMainBase):
audio_filename: Path | None = None
@@ -243,7 +273,10 @@ class PipelineMainLive(PipelineMainBase):
transcript = await self.get_transcript()
processors = [
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioFileWriterProcessor(
path=transcript.audio_mp3_filename,
on_duration=self.on_duration,
),
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
@@ -253,6 +286,11 @@ class PipelineMainLive(PipelineMainBase):
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
AudioWaveformProcessor.as_threaded(
audio_path=transcript.audio_mp3_filename,
waveform_path=transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
),
]
),
]
@@ -285,8 +323,13 @@ class PipelineMainDiarization(PipelineMainBase):
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
processors = [
AudioDiarizationAutoProcessor(callback=self.on_topic),
processors = []
if settings.DIARIZATION_ENABLED:
processors += [
AudioDiarizationAutoProcessor(callback=self.on_topic),
]
processors += [
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(

View File

@@ -12,8 +12,8 @@ class AudioFileWriterProcessor(Processor):
INPUT_TYPE = av.AudioFrame
OUTPUT_TYPE = av.AudioFrame
def __init__(self, path: Path | str):
super().__init__()
def __init__(self, path: Path | str, **kwargs):
super().__init__(**kwargs)
if isinstance(path, str):
path = Path(path)
if path.suffix not in (".mp3", ".wav"):
@@ -21,6 +21,7 @@ class AudioFileWriterProcessor(Processor):
self.path = path
self.out_container = None
self.out_stream = None
self.last_packet = None
async def _push(self, data: av.AudioFrame):
if not self.out_container:
@@ -40,12 +41,30 @@ class AudioFileWriterProcessor(Processor):
raise ValueError("Only mp3 and wav files are supported")
for packet in self.out_stream.encode(data):
self.out_container.mux(packet)
self.last_packet = packet
await self.emit(data)
async def _flush(self):
if self.out_container:
for packet in self.out_stream.encode():
self.out_container.mux(packet)
self.last_packet = packet
try:
if self.last_packet is not None:
duration = round(
float(
(self.last_packet.pts * self.last_packet.duration)
* self.last_packet.time_base
),
2,
)
except Exception:
self.logger.exception("Failed to get duration")
duration = 0
self.out_container.close()
self.out_container = None
self.out_stream = None
if duration > 0:
await self.emit(duration, name="duration")

View File

@@ -1,86 +0,0 @@
"""
Implementation using the GPU service from banana.
API will be a POST request to TRANSCRIPT_URL:
```json
{
"audio_url": "https://...",
"audio_ext": "wav",
"timestamp": 123.456
"language": "en"
}
```
"""
from pathlib import Path
import httpx
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
from reflector.processors.types import AudioFile, Transcript, Word
from reflector.settings import settings
from reflector.storage import Storage
from reflector.utils.retry import retry
class AudioTranscriptBananaProcessor(AudioTranscriptProcessor):
def __init__(self, banana_api_key: str, banana_model_key: str):
super().__init__()
self.transcript_url = settings.TRANSCRIPT_URL
self.timeout = settings.TRANSCRIPT_TIMEOUT
self.storage = Storage.get_instance(
settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_"
)
self.headers = {
"X-Banana-API-Key": banana_api_key,
"X-Banana-Model-Key": banana_model_key,
}
async def _transcript(self, data: AudioFile):
async with httpx.AsyncClient() as client:
print(f"Uploading audio {data.path.name} to S3")
url = await self._upload_file(data.path)
print(f"Try to transcribe audio {data.path.name}")
request_data = {
"audio_url": url,
"audio_ext": data.path.suffix[1:],
"timestamp": float(round(data.timestamp, 2)),
}
response = await retry(client.post)(
self.transcript_url,
json=request_data,
headers=self.headers,
timeout=self.timeout,
)
print(f"Transcript response: {response.status_code} {response.content}")
response.raise_for_status()
result = response.json()
transcript = Transcript(
text=result["text"],
words=[
Word(text=word["text"], start=word["start"], end=word["end"])
for word in result["words"]
],
)
# remove audio file from S3
await self._delete_file(data.path)
return transcript
@retry
async def _upload_file(self, path: Path) -> str:
upload_result = await self.storage.put_file(path.name, open(path, "rb"))
return upload_result.url
@retry
async def _delete_file(self, path: Path):
await self.storage.delete_file(path.name)
return True
AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor)

View File

@@ -0,0 +1,36 @@
import json
from pathlib import Path
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary
from reflector.utils.audio_waveform import get_audio_waveform
class AudioWaveformProcessor(Processor):
"""
Write the waveform for the final audio
"""
INPUT_TYPE = TitleSummary
def __init__(self, audio_path: Path | str, waveform_path: str, **kwargs):
super().__init__(**kwargs)
if isinstance(audio_path, str):
audio_path = Path(audio_path)
if audio_path.suffix not in (".mp3", ".wav"):
raise ValueError("Only mp3 and wav files are supported")
self.audio_path = audio_path
self.waveform_path = waveform_path
async def _flush(self):
self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info("Waveform Processing Started")
waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
with open(self.waveform_path, "w") as fd:
json.dump(waveform, fd)
self.logger.info("Waveform Processing Finished")
await self.emit(waveform, name="waveform")
async def _push(_self, _data):
return

View File

@@ -14,7 +14,42 @@ class PipelineEvent(BaseModel):
data: Any
class Processor:
class Emitter:
def __init__(self, **kwargs):
self._callbacks = {}
# register callbacks from kwargs (on_*)
for key, value in kwargs.items():
if key.startswith("on_"):
self.on(value, name=key[3:])
def on(self, callback, name="default"):
"""
Register a callback to be called when data is emitted
"""
# ensure callback is asynchronous
if not asyncio.iscoroutinefunction(callback):
raise ValueError("Callback must be a coroutine function")
if name not in self._callbacks:
self._callbacks[name] = []
self._callbacks[name].append(callback)
def off(self, callback, name="default"):
"""
Unregister a callback to be called when data is emitted
"""
if name not in self._callbacks:
return
self._callbacks[name].remove(callback)
async def emit(self, data, name="default"):
if name not in self._callbacks:
return
for callback in self._callbacks[name]:
await callback(data)
class Processor(Emitter):
INPUT_TYPE: type = None
OUTPUT_TYPE: type = None
@@ -59,7 +94,8 @@ class Processor:
["processor"],
)
def __init__(self, callback=None, custom_logger=None):
def __init__(self, callback=None, custom_logger=None, **kwargs):
super().__init__(**kwargs)
self.name = name = self.__class__.__name__
self.m_processor = self.m_processor.labels(name)
self.m_processor_call = self.m_processor_call.labels(name)
@@ -70,9 +106,11 @@ class Processor:
self.m_processor_flush_success = self.m_processor_flush_success.labels(name)
self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name)
self._processors = []
self._callbacks = []
# register callbacks
if callback:
self.on(callback)
self.uid = uuid4().hex
self.flushed = False
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
@@ -100,21 +138,6 @@ class Processor:
"""
self._processors.remove(processor)
def on(self, callback):
"""
Register a callback to be called when data is emitted
"""
# ensure callback is asynchronous
if not asyncio.iscoroutinefunction(callback):
raise ValueError("Callback must be a coroutine function")
self._callbacks.append(callback)
def off(self, callback):
"""
Unregister a callback to be called when data is emitted
"""
self._callbacks.remove(callback)
def get_pref(self, key: str, default: Any = None):
"""
Get a preference from the pipeline prefs
@@ -123,15 +146,16 @@ class Processor:
return self.pipeline.get_pref(key, default)
return default
async def emit(self, data):
if self.pipeline:
await self.pipeline.emit(
PipelineEvent(processor=self.name, uid=self.uid, data=data)
)
for callback in self._callbacks:
await callback(data)
for processor in self._processors:
await processor.push(data)
async def emit(self, data, name="default"):
if name == "default":
if self.pipeline:
await self.pipeline.emit(
PipelineEvent(processor=self.name, uid=self.uid, data=data)
)
await super().emit(data, name=name)
if name == "default":
for processor in self._processors:
await processor.push(data)
async def push(self, data):
"""
@@ -254,11 +278,11 @@ class ThreadedProcessor(Processor):
def disconnect(self, processor: Processor):
self.processor.disconnect(processor)
def on(self, callback):
self.processor.on(callback)
def on(self, callback, name="default"):
self.processor.on(callback, name=name)
def off(self, callback):
self.processor.off(callback)
def off(self, callback, name="default"):
self.processor.off(callback, name=name)
def describe(self, level=0):
super().describe(level)
@@ -305,13 +329,13 @@ class BroadcastProcessor(Processor):
for processor in self.processors:
processor.disconnect(processor)
def on(self, callback):
def on(self, callback, name="default"):
for processor in self.processors:
processor.on(callback)
processor.on(callback, name=name)
def off(self, callback):
def off(self, callback, name="default"):
for processor in self.processors:
processor.off(callback)
processor.off(callback, name=name)
def describe(self, level=0):
super().describe(level)

View File

@@ -16,6 +16,7 @@ class TranscriptTranslatorProcessor(Processor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.transcript = None
self.translate_url = settings.TRANSLATE_URL
self.timeout = settings.TRANSLATE_TIMEOUT
self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}

View File

@@ -5,6 +5,7 @@ from pathlib import Path
from profanityfilter import ProfanityFilter
from pydantic import BaseModel, PrivateAttr
from reflector.redis_cache import redis_cache
PUNC_RE = re.compile(r"[.;:?!…]")
@@ -68,10 +69,14 @@ class Transcript(BaseModel):
# Uncensored text
return "".join([word.text for word in self.words])
@redis_cache(prefix="profanity", duration=3600 * 24 * 7)
def _get_censored_text(self, text: str):
return profanity_filter.censor(text).strip()
@property
def text(self):
# Censored text
return profanity_filter.censor(self.raw_text).strip()
return self._get_censored_text(self.raw_text)
@property
def human_timestamp(self):

View File

@@ -0,0 +1,50 @@
import functools
import json
import redis
from reflector.settings import settings
redis_clients = {}
def get_redis_client(db=0):
"""
Get a Redis client for the specified database.
"""
if db not in redis_clients:
redis_clients[db] = redis.StrictRedis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=db,
)
return redis_clients[db]
def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argidx=1):
"""
Cache the result of a function in Redis.
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Check if the first argument is a string
if len(args) < (argidx + 1) or not isinstance(args[argidx], str):
return func(*args, **kwargs)
# Compute the cache key based on the arguments and prefix
cache_key = prefix + ":" + args[argidx]
redis_client = get_redis_client(db=db)
cached_result = redis_client.get(cache_key)
if cached_result:
return json.loads(cached_result.decode("utf-8"))
# If the result is not cached, call the original function
result = func(*args, **kwargs)
redis_client.setex(cache_key, duration, json.dumps(result))
return result
return wrapper
return decorator

View File

@@ -41,7 +41,7 @@ class Settings(BaseSettings):
AUDIO_BUFFER_SIZE: int = 256 * 960
# Audio Transcription
# backends: whisper, banana, modal
# backends: whisper, modal
TRANSCRIPT_BACKEND: str = "whisper"
TRANSCRIPT_URL: str | None = None
TRANSCRIPT_TIMEOUT: int = 90
@@ -50,10 +50,6 @@ class Settings(BaseSettings):
TRANSLATE_URL: str | None = None
TRANSLATE_TIMEOUT: int = 90
# Audio transcription banana.dev configuration
TRANSCRIPT_BANANA_API_KEY: str | None = None
TRANSCRIPT_BANANA_MODEL_KEY: str | None = None
# Audio transcription modal.com configuration
TRANSCRIPT_MODAL_API_KEY: str | None = None
@@ -61,13 +57,16 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_BACKEND: str = "aws"
# Storage configuration for AWS
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks"
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1"
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# Transcript MP3 storage
TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
# LLM
# available backend: openai, banana, modal, oobabooga
# available backend: openai, modal, oobabooga
LLM_BACKEND: str = "oobabooga"
# LLM common configuration
@@ -82,14 +81,11 @@ class Settings(BaseSettings):
LLM_TEMPERATURE: float = 0.7
ZEPHYR_LLM_URL: str | None = None
# LLM Banana configuration
LLM_BANANA_API_KEY: str | None = None
LLM_BANANA_MODEL_KEY: str | None = None
# LLM Modal configuration
LLM_MODAL_API_KEY: str | None = None
# Diarization
DIARIZATION_ENABLED: bool = True
DIARIZATION_BACKEND: str = "modal"
DIARIZATION_URL: str | None = None
@@ -124,6 +120,7 @@ class Settings(BaseSettings):
# Redis
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
REDIS_CACHE_DB: int = 2
# Secret key
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
@@ -131,5 +128,8 @@ class Settings(BaseSettings):
# Current hosting/domain
BASE_URL: str = "http://localhost:1250"
# Profiling
PROFILING: bool = False
settings = Settings()

View File

@@ -1,7 +1,7 @@
import os
from typing import BinaryIO
from fastapi import HTTPException, Request, status
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse
@@ -57,6 +57,9 @@ def range_requests_response(
),
}
if request.method == "HEAD":
return Response(headers=headers)
if content_disposition:
headers["Content-Disposition"] = content_disposition

View File

@@ -23,7 +23,6 @@ from reflector.db.transcripts import (
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings
from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import RtcOffer, rtc_offer_base
@@ -53,7 +52,7 @@ class GetTranscript(BaseModel):
name: str
status: str
locked: bool
duration: int
duration: float
title: str | None
short_summary: str | None
long_summary: str | None
@@ -222,6 +221,7 @@ async def transcript_delete(
@router.get("/transcripts/{transcript_id}/audio/mp3")
@router.head("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
@@ -272,8 +272,6 @@ async def transcript_get_audio_waveform(
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=500, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_waveform)
return transcript.audio_waveform

View File

@@ -46,6 +46,34 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
assert response.status_code == 200
assert response.headers["content-type"] == content_type
# test get 404
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
assert response.status_code == 404
@pytest.mark.asyncio
@pytest.mark.parametrize(
"url_suffix,content_type",
[
["/mp3", "audio/mpeg"],
],
)
async def test_transcript_audio_download_head(
fake_transcript, url_suffix, content_type
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
assert response.status_code == 200
assert response.headers["content-type"] == content_type
# test head 404
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.head(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
assert response.status_code == 404
@pytest.mark.asyncio
@pytest.mark.parametrize(
@@ -90,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek(
assert response.status_code == 206
assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 100-")
@pytest.mark.asyncio
async def test_transcript_audio_download_waveform(fake_transcript):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
assert isinstance(response.json()["data"], list)
assert len(response.json()["data"]) >= 255

View File

@@ -182,6 +182,16 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
assert "WAVEFORM" in eventnames
ev = events[eventnames.index("WAVEFORM")]
assert isinstance(ev["data"]["waveform"], list)
assert len(ev["data"]["waveform"]) >= 250
waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
assert waveform_resp.status_code == 200
assert waveform_resp.headers["content-type"] == "application/json"
assert isinstance(waveform_resp.json()["data"], list)
assert len(waveform_resp.json()["data"]) >= 250
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses.index("recording") < statuses.index("processing")
@@ -191,10 +201,14 @@ async def test_transcript_rtc_and_websocket(
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# check on the latest response that the audio duration is > 0
assert resp.json()["duration"] > 0
assert "DURATION" in eventnames
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mpeg"
audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert audio_resp.status_code == 200
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@pytest.mark.usefixtures("celery_session_app")