Feature additions (#210)

* initial

* add LLM features

* update LLM logic

* update llm functions: change control flow

* add generation config

* update return types

* update processors and tests

* update rtc_offer

* revert new title processor change

* fix unit tests

* add comments and fix HTTP 500

* adjust prompt

* test with reflector app

* revert new event for final title

* update

* move onus onto processors

* move onus onto processors

* stash

* add provision for gen config

* dynamically pack the LLM input using context length

* tune final summary params

* update consolidated class structures

* update consolidated class structures

* update precommit

* add broadcast processors

* working baseline

* Organize LLMParams

* minor fixes

* minor fixes

* minor fixes

* fix unit tests

* fix unit tests

* fix unit tests

* update tests

* update tests

* edit pipeline response events

* update summary return types

* configure tests

* alembic db migration

* change LLM response flow

* edit main llm functions

* edit main llm functions

* change llm name and gen cf

* Update transcript_topic_detector.py

* PR review comments

* checkpoint before db event migration

* update DB migration of past events

* update DB migration of past events

* edit LLM classes

* Delete unwanted file

* remove List typing

* remove List typing

* update oobabooga API call

* topic enhancements

* update UI event handling

* move ensure_casing to llm base

* update tests

* update tests
Author: projects-g
Date: 2023-09-13 11:26:08 +05:30 (committed by GitHub)
Parent: 762d7bfc3c
Commit: 9fe261406c
33 changed files with 1334 additions and 202 deletions


@@ -18,12 +18,6 @@ repos:
exclude: ^server/trials
- id: detect-private-key
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.282
hooks:
- id: ruff
files: ^server/(reflector|tests)/
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
@@ -36,4 +30,10 @@ repos:
- id: isort
name: isort (python)
files: ^server/(gpu|evaluate|reflector)/
args: ["--profile", "black", "--filter-files"]
args: [ "--profile", "black", "--filter-files" ]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.282
hooks:
- id: ruff
files: ^server/(reflector|tests)/


@@ -94,6 +94,12 @@
#LLM_URL=http://localhost:4891/v1/completions
#LLM_OPENAI_MODEL="GPT4All Falcon"
## Default LLM MODEL NAME
DEFAULT_LLM=lmsys/vicuna-13b-v1.5
## Cache directory to store models
CACHE_DIR=data
## =======================================================
## Sentry
## =======================================================


@@ -55,7 +55,7 @@ llm_image = (
"accelerate==0.21.0",
"einops==0.6.1",
"hf-transfer~=0.1",
"huggingface_hub==0.16.4",
"huggingface_hub==0.16.4"
)
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
.run_function(download_llm)
@@ -73,8 +73,7 @@ llm_image = (
class LLM:
def __enter__(self):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
print("Instance llm model")
model = AutoModelForCausalLM.from_pretrained(
@@ -84,10 +83,11 @@ class LLM:
cache_dir=IMAGE_MODEL_DIR
)
# generation configuration
# JSONFormer doesn't yet support generation configs
print("Instance llm generation config")
# JSONFormer doesn't yet support generation configs, but keeping for future usage
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
# generation configuration
gen_cfg = GenerationConfig.from_model_config(model.config)
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
@@ -106,6 +106,7 @@ class LLM:
self.model = model
self.tokenizer = tokenizer
self.gen_cfg = gen_cfg
self.GenerationConfig = GenerationConfig
def __exit__(self, *args):
print("Exit llm")
@@ -116,34 +117,44 @@ class LLM:
return {"status": "ok"}
@method()
def generate(self, prompt: str, schema: str = None):
def generate(self, prompt: str, gen_schema: str | None, gen_cfg: str | None) -> dict:
"""
Perform a generation action using the LLM
"""
print(f"Generate {prompt=}")
# If a schema is given, conform to schema
if schema:
print(f"Schema {schema=}")
if gen_cfg:
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
else:
gen_cfg = self.gen_cfg
# If a gen_schema is given, conform to gen_schema
if gen_schema:
import jsonformer
jsonformer_llm = jsonformer.Jsonformer(model=self.model,
tokenizer=self.tokenizer,
json_schema=json.loads(schema),
prompt=prompt,
max_string_token_length=self.gen_cfg.max_new_tokens)
print(f"Schema {gen_schema=}")
jsonformer_llm = jsonformer.Jsonformer(
model=self.model,
tokenizer=self.tokenizer,
json_schema=json.loads(gen_schema),
prompt=prompt,
max_string_token_length=gen_cfg.max_new_tokens
)
response = jsonformer_llm()
else:
# If no schema, perform prompt only generation
# If no gen_schema, perform prompt only generation
# tokenize prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
output = self.model.generate(input_ids, generation_config=self.gen_cfg)
output = self.model.generate(input_ids, generation_config=gen_cfg)
# decode output
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
response = response[len(prompt):]
print(f"Generated {response=}")
return {"text": response}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@@ -160,7 +171,7 @@ class LLM:
def web():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, Field
from pydantic import BaseModel
llmstub = LLM()
@@ -177,16 +188,16 @@ def web():
class LLMRequest(BaseModel):
prompt: str
schema_: Optional[dict] = Field(None, alias="schema")
gen_schema: Optional[dict] = None
gen_cfg: Optional[dict] = None
@app.post("/llm", dependencies=[Depends(apikey_auth)])
async def llm(
req: LLMRequest,
):
if req.schema_:
func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema_))
else:
func = llmstub.generate.spawn(prompt=req.prompt)
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
func = llmstub.generate.spawn(prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg)
result = func.get()
return result
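
The endpoint above now forwards gen_schema and gen_cfg to the Modal function as JSON strings. A minimal client sketch, assuming a placeholder deployment URL and bearer token; the payload field names follow LLMRequest above:

import httpx

payload = {
    "prompt": "### Human:\nSummarize the meeting notes.\n### Assistant:\n",
    "gen_schema": {"type": "object", "properties": {"title": {"type": "string"}}},
    "gen_cfg": {"max_new_tokens": 200, "num_beams": 5},
}
# URL and API key are placeholders for the actual Modal deployment
response = httpx.post(
    "https://<modal-deployment>/llm",
    headers={"Authorization": "Bearer <LLM_MODAL_API_KEY>"},
    json=payload,
    timeout=120,
)
response.raise_for_status()
print(response.json())  # {"text": "..."} or the JSONFormer-constrained object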


@@ -0,0 +1,37 @@
"""add_title/short_and_long_summary_and_remove_summary
Revision ID: 99365b0cd87b
Revises: b3df9681cae9
Create Date: 2023-09-01 20:19:47.216334
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '99365b0cd87b'
down_revision: Union[str, None] = 'b3df9681cae9'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.execute("UPDATE transcript SET events = "
"REPLACE(events, '\"event\": \"SUMMARY\"', '\"event\": \"LONG_SUMMARY\"');")
op.alter_column('transcript', 'summary', new_column_name='long_summary')
op.add_column('transcript', sa.Column('title', sa.String(), nullable=True))
op.add_column('transcript', sa.Column('short_summary', sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.execute("UPDATE transcript SET events = "
"REPLACE(events, '\"event\": \"LONG_SUMMARY\"', '\"event\": \"SUMMARY\"');")
op.alter_column('transcript', 'long_summary', nullable=True, new_column_name='summary')
op.drop_column('transcript', 'title')
op.drop_column('transcript', 'short_summary')
# ### end Alembic commands ###
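
A short sketch of applying this revision with Alembic's command API, assuming a standard alembic.ini at the project root; the revision identifiers come from the file above:

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "99365b0cd87b")      # rename summary -> long_summary, add title/short_summary
# command.downgrade(cfg, "b3df9681cae9")  # revert to the previous revision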

server/poetry.lock (generated)

@@ -1588,6 +1588,17 @@ files = [
{file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
]
[[package]]
name = "joblib"
version = "1.3.2"
description = "Lightweight pipelining with Python functions"
optional = false
python-versions = ">=3.7"
files = [
{file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"},
{file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"},
]
[[package]]
name = "jwcrypto"
version = "1.5.0"
@@ -1934,6 +1945,31 @@ files = [
{file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
]
[[package]]
name = "nltk"
version = "3.8.1"
description = "Natural Language Toolkit"
optional = false
python-versions = ">=3.7"
files = [
{file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"},
{file = "nltk-3.8.1.zip", hash = "sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3"},
]
[package.dependencies]
click = "*"
joblib = "*"
regex = ">=2021.8.3"
tqdm = "*"
[package.extras]
all = ["matplotlib", "numpy", "pyparsing", "python-crfsuite", "requests", "scikit-learn", "scipy", "twython"]
corenlp = ["requests"]
machine-learning = ["numpy", "python-crfsuite", "scikit-learn", "scipy"]
plot = ["matplotlib"]
tgrep = ["pyparsing"]
twitter = ["twython"]
[[package]]
name = "numpy"
version = "1.25.2"
@@ -2672,6 +2708,103 @@ files = [
[package.extras]
full = ["numpy"]
[[package]]
name = "regex"
version = "2023.8.8"
description = "Alternative regular expression module, to replace re."
optional = false
python-versions = ">=3.6"
files = [
{file = "regex-2023.8.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:88900f521c645f784260a8d346e12a1590f79e96403971241e64c3a265c8ecdb"},
{file = "regex-2023.8.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3611576aff55918af2697410ff0293d6071b7e00f4b09e005d614686ac4cd57c"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8a0ccc8f2698f120e9e5742f4b38dc944c38744d4bdfc427616f3a163dd9de5"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c662a4cbdd6280ee56f841f14620787215a171c4e2d1744c9528bed8f5816c96"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cf0633e4a1b667bfe0bb10b5e53fe0d5f34a6243ea2530eb342491f1adf4f739"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:551ad543fa19e94943c5b2cebc54c73353ffff08228ee5f3376bd27b3d5b9800"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54de2619f5ea58474f2ac211ceea6b615af2d7e4306220d4f3fe690c91988a61"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ec4b3f0aebbbe2fc0134ee30a791af522a92ad9f164858805a77442d7d18570"},
{file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3ae646c35cb9f820491760ac62c25b6d6b496757fda2d51be429e0e7b67ae0ab"},
{file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ca339088839582d01654e6f83a637a4b8194d0960477b9769d2ff2cfa0fa36d2"},
{file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:d9b6627408021452dcd0d2cdf8da0534e19d93d070bfa8b6b4176f99711e7f90"},
{file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:bd3366aceedf274f765a3a4bc95d6cd97b130d1dda524d8f25225d14123c01db"},
{file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7aed90a72fc3654fba9bc4b7f851571dcc368120432ad68b226bd593f3f6c0b7"},
{file = "regex-2023.8.8-cp310-cp310-win32.whl", hash = "sha256:80b80b889cb767cc47f31d2b2f3dec2db8126fbcd0cff31b3925b4dc6609dcdb"},
{file = "regex-2023.8.8-cp310-cp310-win_amd64.whl", hash = "sha256:b82edc98d107cbc7357da7a5a695901b47d6eb0420e587256ba3ad24b80b7d0b"},
{file = "regex-2023.8.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1e7d84d64c84ad97bf06f3c8cb5e48941f135ace28f450d86af6b6512f1c9a71"},
{file = "regex-2023.8.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce0f9fbe7d295f9922c0424a3637b88c6c472b75eafeaff6f910494a1fa719ef"},
{file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06c57e14ac723b04458df5956cfb7e2d9caa6e9d353c0b4c7d5d54fcb1325c46"},
{file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7a9aaa5a1267125eef22cef3b63484c3241aaec6f48949b366d26c7250e0357"},
{file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b7408511fca48a82a119d78a77c2f5eb1b22fe88b0d2450ed0756d194fe7a9a"},
{file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14dc6f2d88192a67d708341f3085df6a4f5a0c7b03dec08d763ca2cd86e9f559"},
{file = "regex-2023.8.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48c640b99213643d141550326f34f0502fedb1798adb3c9eb79650b1ecb2f177"},
{file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0085da0f6c6393428bf0d9c08d8b1874d805bb55e17cb1dfa5ddb7cfb11140bf"},
{file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:964b16dcc10c79a4a2be9f1273fcc2684a9eedb3906439720598029a797b46e6"},
{file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7ce606c14bb195b0e5108544b540e2c5faed6843367e4ab3deb5c6aa5e681208"},
{file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:40f029d73b10fac448c73d6eb33d57b34607f40116e9f6e9f0d32e9229b147d7"},
{file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3b8e6ea6be6d64104d8e9afc34c151926f8182f84e7ac290a93925c0db004bfd"},
{file = "regex-2023.8.8-cp311-cp311-win32.whl", hash = "sha256:942f8b1f3b223638b02df7df79140646c03938d488fbfb771824f3d05fc083a8"},
{file = "regex-2023.8.8-cp311-cp311-win_amd64.whl", hash = "sha256:51d8ea2a3a1a8fe4f67de21b8b93757005213e8ac3917567872f2865185fa7fb"},
{file = "regex-2023.8.8-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e951d1a8e9963ea51efd7f150450803e3b95db5939f994ad3d5edac2b6f6e2b4"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:704f63b774218207b8ccc6c47fcef5340741e5d839d11d606f70af93ee78e4d4"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22283c769a7b01c8ac355d5be0715bf6929b6267619505e289f792b01304d898"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91129ff1bb0619bc1f4ad19485718cc623a2dc433dff95baadbf89405c7f6b57"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de35342190deb7b866ad6ba5cbcccb2d22c0487ee0cbb251efef0843d705f0d4"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b993b6f524d1e274a5062488a43e3f9f8764ee9745ccd8e8193df743dbe5ee61"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3026cbcf11d79095a32d9a13bbc572a458727bd5b1ca332df4a79faecd45281c"},
{file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:293352710172239bf579c90a9864d0df57340b6fd21272345222fb6371bf82b3"},
{file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:d909b5a3fff619dc7e48b6b1bedc2f30ec43033ba7af32f936c10839e81b9217"},
{file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:3d370ff652323c5307d9c8e4c62efd1956fb08051b0e9210212bc51168b4ff56"},
{file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:b076da1ed19dc37788f6a934c60adf97bd02c7eea461b73730513921a85d4235"},
{file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e9941a4ada58f6218694f382e43fdd256e97615db9da135e77359da257a7168b"},
{file = "regex-2023.8.8-cp36-cp36m-win32.whl", hash = "sha256:a8c65c17aed7e15a0c824cdc63a6b104dfc530f6fa8cb6ac51c437af52b481c7"},
{file = "regex-2023.8.8-cp36-cp36m-win_amd64.whl", hash = "sha256:aadf28046e77a72f30dcc1ab185639e8de7f4104b8cb5c6dfa5d8ed860e57236"},
{file = "regex-2023.8.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:423adfa872b4908843ac3e7a30f957f5d5282944b81ca0a3b8a7ccbbfaa06103"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ae594c66f4a7e1ea67232a0846649a7c94c188d6c071ac0210c3e86a5f92109"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e51c80c168074faa793685656c38eb7a06cbad7774c8cbc3ea05552d615393d8"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:09b7f4c66aa9d1522b06e31a54f15581c37286237208df1345108fcf4e050c18"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e73e5243af12d9cd6a9d6a45a43570dbe2e5b1cdfc862f5ae2b031e44dd95a8"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:941460db8fe3bd613db52f05259c9336f5a47ccae7d7def44cc277184030a116"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f0ccf3e01afeb412a1a9993049cb160d0352dba635bbca7762b2dc722aa5742a"},
{file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2e9216e0d2cdce7dbc9be48cb3eacb962740a09b011a116fd7af8c832ab116ca"},
{file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:5cd9cd7170459b9223c5e592ac036e0704bee765706445c353d96f2890e816c8"},
{file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:4873ef92e03a4309b3ccd8281454801b291b689f6ad45ef8c3658b6fa761d7ac"},
{file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:239c3c2a339d3b3ddd51c2daef10874410917cd2b998f043c13e2084cb191684"},
{file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1005c60ed7037be0d9dea1f9c53cc42f836188227366370867222bda4c3c6bd7"},
{file = "regex-2023.8.8-cp37-cp37m-win32.whl", hash = "sha256:e6bd1e9b95bc5614a7a9c9c44fde9539cba1c823b43a9f7bc11266446dd568e3"},
{file = "regex-2023.8.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9a96edd79661e93327cfeac4edec72a4046e14550a1d22aa0dd2e3ca52aec921"},
{file = "regex-2023.8.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f2181c20ef18747d5f4a7ea513e09ea03bdd50884a11ce46066bb90fe4213675"},
{file = "regex-2023.8.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a2ad5add903eb7cdde2b7c64aaca405f3957ab34f16594d2b78d53b8b1a6a7d6"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9233ac249b354c54146e392e8a451e465dd2d967fc773690811d3a8c240ac601"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:920974009fb37b20d32afcdf0227a2e707eb83fe418713f7a8b7de038b870d0b"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd2b6c5dfe0929b6c23dde9624483380b170b6e34ed79054ad131b20203a1a63"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96979d753b1dc3b2169003e1854dc67bfc86edf93c01e84757927f810b8c3c93"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ae54a338191e1356253e7883d9d19f8679b6143703086245fb14d1f20196be9"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2162ae2eb8b079622176a81b65d486ba50b888271302190870b8cc488587d280"},
{file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c884d1a59e69e03b93cf0dfee8794c63d7de0ee8f7ffb76e5f75be8131b6400a"},
{file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:cf9273e96f3ee2ac89ffcb17627a78f78e7516b08f94dc435844ae72576a276e"},
{file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:83215147121e15d5f3a45d99abeed9cf1fe16869d5c233b08c56cdf75f43a504"},
{file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:3f7454aa427b8ab9101f3787eb178057c5250478e39b99540cfc2b889c7d0586"},
{file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f0640913d2c1044d97e30d7c41728195fc37e54d190c5385eacb52115127b882"},
{file = "regex-2023.8.8-cp38-cp38-win32.whl", hash = "sha256:0c59122ceccb905a941fb23b087b8eafc5290bf983ebcb14d2301febcbe199c7"},
{file = "regex-2023.8.8-cp38-cp38-win_amd64.whl", hash = "sha256:c12f6f67495ea05c3d542d119d270007090bad5b843f642d418eb601ec0fa7be"},
{file = "regex-2023.8.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:82cd0a69cd28f6cc3789cc6adeb1027f79526b1ab50b1f6062bbc3a0ccb2dbc3"},
{file = "regex-2023.8.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bb34d1605f96a245fc39790a117ac1bac8de84ab7691637b26ab2c5efb8f228c"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:987b9ac04d0b38ef4f89fbc035e84a7efad9cdd5f1e29024f9289182c8d99e09"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9dd6082f4e2aec9b6a0927202c85bc1b09dcab113f97265127c1dc20e2e32495"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7eb95fe8222932c10d4436e7a6f7c99991e3fdd9f36c949eff16a69246dee2dc"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7098c524ba9f20717a56a8d551d2ed491ea89cbf37e540759ed3b776a4f8d6eb"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b694430b3f00eb02c594ff5a16db30e054c1b9589a043fe9174584c6efa8033"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b2aeab3895d778155054abea5238d0eb9a72e9242bd4b43f42fd911ef9a13470"},
{file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:988631b9d78b546e284478c2ec15c8a85960e262e247b35ca5eaf7ee22f6050a"},
{file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:67ecd894e56a0c6108ec5ab1d8fa8418ec0cff45844a855966b875d1039a2e34"},
{file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:14898830f0a0eb67cae2bbbc787c1a7d6e34ecc06fbd39d3af5fe29a4468e2c9"},
{file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:f2200e00b62568cfd920127782c61bc1c546062a879cdc741cfcc6976668dfcf"},
{file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9691a549c19c22d26a4f3b948071e93517bdf86e41b81d8c6ac8a964bb71e5a6"},
{file = "regex-2023.8.8-cp39-cp39-win32.whl", hash = "sha256:6ab2ed84bf0137927846b37e882745a827458689eb969028af8032b1b3dac78e"},
{file = "regex-2023.8.8-cp39-cp39-win_amd64.whl", hash = "sha256:5543c055d8ec7801901e1193a51570643d6a6ab8751b1f7dd9af71af467538bb"},
{file = "regex-2023.8.8.tar.gz", hash = "sha256:fcbdc5f2b0f1cd0f6a56cdb46fe41d2cce1e644e3b68832f3eeebc5fb0f7712e"},
]
[[package]]
name = "requests"
version = "2.31.0"
@@ -2710,6 +2843,70 @@ botocore = ">=1.12.36,<2.0a.0"
[package.extras]
crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"]
[[package]]
name = "safetensors"
version = "0.3.3"
description = "Fast and Safe Tensor serialization"
optional = false
python-versions = "*"
files = [
{file = "safetensors-0.3.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:92e4d0c8b2836120fddd134474c5bda8963f322333941f8b9f643e5b24f041eb"},
{file = "safetensors-0.3.3-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:3dcadb6153c42addc9c625a622ebde9293fabe1973f9ef31ba10fb42c16e8536"},
{file = "safetensors-0.3.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:08f26b61e1b0a14dc959aa9d568776bd038805f611caef1de04a80c468d4a7a4"},
{file = "safetensors-0.3.3-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:17f41344d9a075f2f21b289a49a62e98baff54b5754240ba896063bce31626bf"},
{file = "safetensors-0.3.3-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:f1045f798e1a16a6ced98d6a42ec72936d367a2eec81dc5fade6ed54638cd7d2"},
{file = "safetensors-0.3.3-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:eaf0e4bc91da13f21ac846a39429eb3f3b7ed06295a32321fa3eb1a59b5c70f3"},
{file = "safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a07121f427e646a50d18c1be0fa1a2cbf6398624c31149cd7e6b35486d72189e"},
{file = "safetensors-0.3.3-cp310-cp310-win32.whl", hash = "sha256:a85e29cbfddfea86453cc0f4889b4bcc6b9c155be9a60e27be479a34e199e7ef"},
{file = "safetensors-0.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:e13adad4a3e591378f71068d14e92343e626cf698ff805f61cdb946e684a218e"},
{file = "safetensors-0.3.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:cbc3312f134baf07334dd517341a4b470b2931f090bd9284888acb7dfaf4606f"},
{file = "safetensors-0.3.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d15030af39d5d30c22bcbc6d180c65405b7ea4c05b7bab14a570eac7d7d43722"},
{file = "safetensors-0.3.3-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:f84a74cbe9859b28e3d6d7715ac1dd3097bebf8d772694098f6d42435245860c"},
{file = "safetensors-0.3.3-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:10d637423d98ab2e6a4ad96abf4534eb26fcaf8ca3115623e64c00759374e90d"},
{file = "safetensors-0.3.3-cp311-cp311-macosx_13_0_universal2.whl", hash = "sha256:3b46f5de8b44084aff2e480874c550c399c730c84b2e8ad1bddb062c94aa14e9"},
{file = "safetensors-0.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e8fdf7407dba44587ed5e79d5de3533d242648e1f2041760b21474bd5ea5c8c"},
{file = "safetensors-0.3.3-cp311-cp311-win32.whl", hash = "sha256:7d3b744cee8d7a46ffa68db1a2ff1a1a432488e3f7a5a97856fe69e22139d50c"},
{file = "safetensors-0.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:f579877d30feec9b6ba409d05fa174633a4fc095675a4a82971d831a8bb60b97"},
{file = "safetensors-0.3.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:2fff5b19a1b462c17322998b2f4b8bce43c16fe208968174d2f3a1446284ceed"},
{file = "safetensors-0.3.3-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:41adb1d39e8aad04b16879e3e0cbcb849315999fad73bc992091a01e379cb058"},
{file = "safetensors-0.3.3-cp37-cp37m-macosx_12_0_x86_64.whl", hash = "sha256:0f2b404250b3b877b11d34afcc30d80e7035714a1116a3df56acaca6b6c00096"},
{file = "safetensors-0.3.3-cp37-cp37m-macosx_13_0_x86_64.whl", hash = "sha256:b43956ef20e9f4f2e648818a9e7b3499edd6b753a0f5526d4f6a6826fbee8446"},
{file = "safetensors-0.3.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c32ee08f61cea56a5d62bbf94af95df6040c8ab574afffaeb7b44ae5da1e9e3"},
{file = "safetensors-0.3.3-cp37-cp37m-win32.whl", hash = "sha256:351600f367badd59f7bfe86d317bb768dd8c59c1561c6fac43cafbd9c1af7827"},
{file = "safetensors-0.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:034717e297849dae1af0a7027a14b8647bd2e272c24106dced64d83e10d468d1"},
{file = "safetensors-0.3.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8530399666748634bc0b301a6a5523756931b0c2680d188e743d16304afe917a"},
{file = "safetensors-0.3.3-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:9d741c1f1621e489ba10aa3d135b54202684f6e205df52e219d5eecd673a80c9"},
{file = "safetensors-0.3.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:0c345fd85b4d2093a5109596ff4cd9dfc2e84992e881b4857fbc4a93a3b89ddb"},
{file = "safetensors-0.3.3-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:69ccee8d05f55cdf76f7e6c87d2bdfb648c16778ef8acfd2ecc495e273e9233e"},
{file = "safetensors-0.3.3-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:c08a9a4b7a4ca389232fa8d097aebc20bbd4f61e477abc7065b5c18b8202dede"},
{file = "safetensors-0.3.3-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:a002868d2e3f49bbe81bee2655a411c24fa1f8e68b703dec6629cb989d6ae42e"},
{file = "safetensors-0.3.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ab43aeeb9eadbb6b460df3568a662e6f1911ecc39387f8752afcb6a7d96c087"},
{file = "safetensors-0.3.3-cp38-cp38-win32.whl", hash = "sha256:f2f59fce31dd3429daca7269a6b06f65e6547a0c248f5116976c3f1e9b73f251"},
{file = "safetensors-0.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:c31ca0d8610f57799925bf08616856b39518ab772c65093ef1516762e796fde4"},
{file = "safetensors-0.3.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:59a596b3225c96d59af412385981f17dd95314e3fffdf359c7e3f5bb97730a19"},
{file = "safetensors-0.3.3-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:82a16e92210a6221edd75ab17acdd468dd958ef5023d9c6c1289606cc30d1479"},
{file = "safetensors-0.3.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:98a929e763a581f516373ef31983ed1257d2d0da912a8e05d5cd12e9e441c93a"},
{file = "safetensors-0.3.3-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:12b83f1986cd16ea0454c636c37b11e819d60dd952c26978310a0835133480b7"},
{file = "safetensors-0.3.3-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:f439175c827c2f1bbd54df42789c5204a10983a30bc4242bc7deaf854a24f3f0"},
{file = "safetensors-0.3.3-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:0085be33b8cbcb13079b3a8e131656e05b0bc5e6970530d4c24150f7afd76d70"},
{file = "safetensors-0.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad3cc8006e7a86ee7c88bd2813ec59cd7cc75b03e6fa4af89b9c7b235b438d68"},
{file = "safetensors-0.3.3-cp39-cp39-win32.whl", hash = "sha256:ab29f54c6b8c301ca05fa014728996bd83aac6e21528f893aaf8945c71f42b6d"},
{file = "safetensors-0.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:0fa82004eae1a71e2aa29843ef99de9350e459a0fc2f65fc6ee0da9690933d2d"},
{file = "safetensors-0.3.3.tar.gz", hash = "sha256:edb7072d788c4f929d0f5735d3a2fb51e5a27f833587828583b7f5747af1a2b8"},
]
[package.extras]
all = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (==2.11.0)", "torch (>=1.10)"]
dev = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (==2.11.0)", "torch (>=1.10)"]
jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)"]
numpy = ["numpy (>=1.21.6)"]
paddlepaddle = ["numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)"]
pinned-tf = ["tensorflow (==2.11.0)"]
quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"]
tensorflow = ["numpy (>=1.21.6)", "tensorflow (>=2.11.0)"]
testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"]
torch = ["numpy (>=1.21.6)", "torch (>=1.10)"]
[[package]]
name = "sentry-sdk"
version = "1.29.2"
@@ -3013,6 +3210,75 @@ notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
name = "transformers"
version = "4.32.1"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "transformers-4.32.1-py3-none-any.whl", hash = "sha256:b930d3dbd907a3f300cf49e54d63a56f8a0ab16b01a2c2a61ecff37c6de1da08"},
{file = "transformers-4.32.1.tar.gz", hash = "sha256:1edc8ae1de357d97c3d36b04412aa63d55e6fc0c4b39b419a7d380ed947d2252"},
]
[package.dependencies]
filelock = "*"
huggingface-hub = ">=0.15.1,<1.0"
numpy = ">=1.17"
packaging = ">=20.0"
pyyaml = ">=5.1"
regex = "!=2019.12.17"
requests = "*"
safetensors = ">=0.3.1"
tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.14"
tqdm = ">=4.27"
[package.extras]
accelerate = ["accelerate (>=0.20.3)"]
agents = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.9,!=1.12.0)"]
all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"]
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
codecarbon = ["codecarbon (==1.2.0)"]
deepspeed = ["accelerate (>=0.20.3)", "deepspeed (>=0.9.3)"]
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"]
dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "urllib3 (<2.0.0)"]
dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
docs = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"]
docs-specific = ["hf-doc-builder"]
fairscale = ["fairscale (>0.3)"]
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"]
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
ftfy = ["ftfy"]
integrations = ["optuna", "ray[tune]", "sigopt"]
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
modelcreation = ["cookiecutter (==1.7.3)"]
natten = ["natten (>=0.14.6)"]
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
optuna = ["optuna"]
quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"]
ray = ["ray[tune]"]
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
sagemaker = ["sagemaker (>=2.31.0)"]
sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"]
sigopt = ["sigopt"]
sklearn = ["scikit-learn"]
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "timeout-decorator"]
tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx"]
tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx"]
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
timm = ["timm"]
tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.14)"]
torch = ["accelerate (>=0.20.3)", "torch (>=1.9,!=1.12.0)"]
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
torch-vision = ["Pillow (<10.0.0)", "torchvision"]
torchhub = ["filelock", "huggingface-hub (>=0.15.1,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "tqdm (>=4.27)"]
video = ["av (==9.2.0)", "decord (==0.6.0)"]
vision = ["Pillow (<10.0.0)"]
[[package]]
name = "typing-extensions"
version = "4.7.1"


@@ -27,6 +27,8 @@ databases = {extras = ["aiosqlite", "asyncpg"], version = "^0.7.0"}
sqlalchemy = "<1.5"
fief-client = {extras = ["fastapi"], version = "^0.17.0"}
alembic = "^1.11.3"
nltk = "^3.8.1"
transformers = "^4.32.1"
prometheus-fastapi-instrumentator = "^6.1.0"


@@ -1,12 +1,13 @@
from contextlib import asynccontextmanager
import reflector.auth # noqa
import reflector.db # noqa
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute
from fastapi_pagination import add_pagination
from prometheus_fastapi_instrumentator import Instrumentator
import reflector.auth # noqa
import reflector.db # noqa
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.logger import logger
from reflector.metrics import metrics_init


@@ -1,5 +1,6 @@
import databases
import sqlalchemy
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings
@@ -16,7 +17,9 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Integer),
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
sqlalchemy.Column("summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),


@@ -1 +1,2 @@
from .base import LLM # noqa: F401
from .llm_params import LLMTaskParams # noqa: F401


@@ -2,12 +2,19 @@ import importlib
import json
import re
from time import monotonic
from typing import TypeVar
import nltk
from prometheus_client import Counter, Histogram
from transformers import GenerationConfig
from reflector.llm.llm_params import TaskParams
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry
T = TypeVar("T", bound="LLM")
class LLM:
_registry = {}
@@ -32,12 +39,25 @@ class LLM:
["backend"],
)
def __enter__(self):
self.ensure_nltk()
@classmethod
def ensure_nltk(cls):
"""
Make sure NLTK package is installed. Searches in the cache and
downloads only if needed.
"""
nltk.download("punkt", download_dir=settings.CACHE_DIR)
# For POS tagging
nltk.download("averaged_perceptron_tagger", download_dir=settings.CACHE_DIR)
@classmethod
def register(cls, name, klass):
cls._registry[name] = klass
@classmethod
def get_instance(cls, name=None):
def get_instance(cls, model_name: str | None = None, name: str = None) -> T:
"""
Return an instance depending on the settings.
Settings used:
@@ -50,7 +70,39 @@ class LLM:
if name not in cls._registry:
module_name = f"reflector.llm.llm_{name}"
importlib.import_module(module_name)
return cls._registry[name]()
return cls._registry[name](model_name)
def get_model_name(self) -> str:
"""
Get the currently set model name
"""
return self._get_model_name()
def _get_model_name(self) -> str:
pass
def set_model_name(self, model_name: str) -> bool:
"""
Update the model name with the provided model name
"""
return self._set_model_name(model_name)
def _set_model_name(self, model_name: str) -> bool:
raise NotImplementedError
@property
def template(self) -> str:
"""
Return the LLM Prompt template
"""
return """
### Human:
{instruct}
{text}
### Assistant:
"""
def __init__(self):
name = self.__class__.__name__
@@ -73,21 +125,39 @@ class LLM:
async def _warmup(self, logger: reflector_logger):
pass
@property
def tokenizer(self):
"""
Return the tokenizer instance used by LLM
"""
return self._get_tokenizer()
def _get_tokenizer(self):
pass
async def generate(
self,
prompt: str,
logger: reflector_logger,
schema: dict | None = None,
gen_schema: dict | None = None,
gen_cfg: GenerationConfig | None = None,
**kwargs,
) -> dict:
logger.info("LLM generate", prompt=repr(prompt))
if gen_cfg:
gen_cfg = gen_cfg.to_dict()
self.m_generate_call.inc()
try:
with self.m_generate.time():
result = await retry(self._generate)(
prompt=prompt, schema=schema, **kwargs
prompt=prompt,
gen_schema=gen_schema,
gen_cfg=gen_cfg,
**kwargs,
)
self.m_generate_success.inc()
except Exception:
logger.exception("Failed to call llm after retrying")
self.m_generate_failure.inc()
@@ -100,7 +170,60 @@ class LLM:
return result
async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
def ensure_casing(self, title: str) -> str:
"""
LLM takes care of word casing, but in rare cases this
can falter. This is a fallback to ensure the casing of
topics is in a proper format.
We select nouns, verbs and adjectives and check if camel
casing is present and fix it, if not. Will not perform
any other changes.
"""
tokens = nltk.word_tokenize(title)
pos_tags = nltk.pos_tag(tokens)
camel_cased = []
whitelisted_pos_tags = [
"NN",
"NNS",
"NNP",
"NNPS", # Noun POS
"VB",
"VBD",
"VBG",
"VBN",
"VBP",
"VBZ", # Verb POS
"JJ",
"JJR",
"JJS", # Adjective POS
]
# If at all there is an exception, do not block other reflector
# processes. Return the LLM generated title, at the least.
try:
for word, pos in pos_tags:
if pos in whitelisted_pos_tags and word[0].islower():
camel_cased.append(word[0].upper() + word[1:])
else:
camel_cased.append(word)
modified_title = " ".join(camel_cased)
# The result can have words in braces with additional space.
# Change ( ABC ), [ ABC ], etc. ==> (ABC), [ABC], etc.
pattern = r"(?<=[\[\{\(])\s+|\s+(?=[\]\}\)])"
title = re.sub(pattern, "", modified_title)
except Exception as e:
reflector_logger.info(
f"Failed to ensure casing on {title=} " f"with exception : {str(e)}"
)
return title
async def _generate(
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
) -> str:
raise NotImplementedError
def _parse_json(self, result: str) -> dict:
@@ -122,3 +245,62 @@ class LLM:
result = result[:-3]
return json.loads(result.strip())
def text_token_threshold(self, task_params: TaskParams | None) -> int:
"""
Choose the token size to set as the threshold to pack the LLM calls
"""
buffer_token_size = 25
default_output_tokens = 1000
context_window = self.tokenizer.model_max_length
tokens = self.tokenizer.tokenize(
self.create_prompt(instruct=task_params.instruct, text="")
)
threshold = context_window - len(tokens) - buffer_token_size
if task_params.gen_cfg:
threshold -= task_params.gen_cfg.max_new_tokens
else:
threshold -= default_output_tokens
return threshold
def split_corpus(
self,
corpus: str,
task_params: TaskParams,
token_threshold: int | None = None,
) -> list[str]:
"""
Split the input to the LLM due to CUDA memory limitations and LLM context window
restrictions.
Accumulate tokens from full sentences till threshold and yield accumulated
tokens. Reset accumulation when threshold is reached and repeat process.
"""
if not token_threshold:
token_threshold = self.text_token_threshold(task_params=task_params)
accumulated_tokens = []
accumulated_sentences = []
accumulated_token_count = 0
corpus_sentences = nltk.sent_tokenize(corpus)
for sentence in corpus_sentences:
tokens = self.tokenizer.tokenize(sentence)
if accumulated_token_count + len(tokens) <= token_threshold:
accumulated_token_count += len(tokens)
accumulated_tokens.extend(tokens)
accumulated_sentences.append(sentence)
else:
yield "".join(accumulated_sentences)
accumulated_token_count = len(tokens)
accumulated_tokens = tokens
accumulated_sentences = [sentence]
if accumulated_tokens:
yield " ".join(accumulated_sentences)
def create_prompt(self, instruct: str, text: str) -> str:
"""
Create a consumable prompt based on the prompt template
"""
return self.template.format(instruct=instruct, text=text)
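
A rough usage sketch of the new base-class helpers, mirroring the final-summary processors later in this diff; the corpus text and logger are assumed to come from the caller:

from reflector.llm import LLM, LLMTaskParams

async def summarize(corpus: str, logger) -> list[dict]:
    llm = LLM.get_instance()
    params = LLMTaskParams.get_instance("final_long_summary").task_params
    results = []
    # split_corpus packs whole sentences up to the computed token threshold
    for chunk in llm.split_corpus(corpus=corpus, task_params=params):
        prompt = llm.create_prompt(instruct=params.instruct, text=chunk)
        result = await llm.generate(
            prompt=prompt,
            gen_schema=params.gen_schema,
            gen_cfg=params.gen_cfg,
            logger=logger,
        )
        results.append(result)
    return results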


@@ -1,4 +1,5 @@
import httpx
from reflector.llm.base import LLM
from reflector.settings import settings
from reflector.utils.retry import retry
@@ -13,10 +14,14 @@ class BananaLLM(LLM):
"X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
}
async def _generate(self, prompt: str, schema: dict | None, **kwargs):
async def _generate(
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
):
json_payload = {"prompt": prompt}
if schema:
json_payload["schema"] = schema
if gen_schema:
json_payload["gen_schema"] = gen_schema
if gen_cfg:
json_payload["gen_cfg"] = gen_cfg
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
settings.LLM_URL,
@@ -27,18 +32,21 @@ class BananaLLM(LLM):
)
response.raise_for_status()
text = response.json()["text"]
if not schema:
text = text[len(prompt) :]
return text
LLM.register("banana", BananaLLM)
if __name__ == "__main__":
from reflector.logger import logger
async def main():
llm = BananaLLM()
result = await llm.generate("Hello, my name is")
prompt = llm.create_prompt(
instruct="Complete the following task",
text="Tell me a joke about programming.",
)
result = await llm.generate(prompt=prompt, logger=logger)
print(result)
import asyncio


@@ -1,11 +1,14 @@
import httpx
from transformers import AutoTokenizer, GenerationConfig
from reflector.llm.base import LLM
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry
class ModalLLM(LLM):
def __init__(self):
def __init__(self, model_name: str | None = None):
super().__init__()
self.timeout = settings.LLM_TIMEOUT
self.llm_url = settings.LLM_URL + "/llm"
@@ -13,6 +16,16 @@ class ModalLLM(LLM):
self.headers = {
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
}
self._set_model_name(model_name if model_name else settings.DEFAULT_LLM)
@property
def supported_models(self):
"""
List of currently supported models on this GPU platform
"""
# TODO: Query the specific GPU platform
# Replace this with a HTTP call
return ["lmsys/vicuna-13b-v1.5"]
async def _warmup(self, logger):
async with httpx.AsyncClient() as client:
@@ -23,10 +36,14 @@ class ModalLLM(LLM):
)
response.raise_for_status()
async def _generate(self, prompt: str, schema: dict | None, **kwargs):
async def _generate(
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
):
json_payload = {"prompt": prompt}
if schema:
json_payload["schema"] = schema
if gen_schema:
json_payload["gen_schema"] = gen_schema
if gen_cfg:
json_payload["gen_cfg"] = gen_cfg
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
self.llm_url,
@@ -37,10 +54,43 @@ class ModalLLM(LLM):
)
response.raise_for_status()
text = response.json()["text"]
if not schema:
text = text[len(prompt) :]
return text
def _set_model_name(self, model_name: str) -> bool:
"""
Set the model name
"""
# Abort, if the model is not supported
if model_name not in self.supported_models:
reflector_logger.info(
f"Attempted to change {model_name=}, but is not supported."
f"Setting model and tokenizer failed !"
)
return False
# Abort, if the model is already set
elif hasattr(self, "model_name") and model_name == self._get_model_name():
reflector_logger.info("No change in model. Setting model skipped.")
return False
# Update model name and tokenizer
self.model_name = model_name
self.llm_tokenizer = AutoTokenizer.from_pretrained(
self.model_name, cache_dir=settings.CACHE_DIR
)
reflector_logger.info(f"Model set to {model_name=}. Tokenizer updated.")
return True
def _get_tokenizer(self) -> AutoTokenizer:
"""
Return the currently used LLM tokenizer
"""
return self.llm_tokenizer
def _get_model_name(self) -> str:
"""
Return the current model name from the instance details
"""
return self.model_name
LLM.register("modal", ModalLLM)
@@ -49,15 +99,25 @@ if __name__ == "__main__":
async def main():
llm = ModalLLM()
result = await llm.generate("Hello, my name is", logger=logger)
prompt = llm.create_prompt(
instruct="Complete the following task",
text="Tell me a joke about programming.",
)
result = await llm.generate(prompt=prompt, logger=logger)
print(result)
schema = {
gen_schema = {
"type": "object",
"properties": {"name": {"type": "string"}},
"properties": {"response": {"type": "string"}},
}
result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
result = await llm.generate(prompt=prompt, gen_schema=gen_schema, logger=logger)
print(result)
gen_cfg = GenerationConfig(max_new_tokens=150)
result = await llm.generate(
prompt=prompt, gen_cfg=gen_cfg, gen_schema=gen_schema, logger=logger
)
print(result)
import asyncio


@@ -1,13 +1,21 @@
import httpx
from reflector.llm.base import LLM
from reflector.settings import settings
class OobaboogaLLM(LLM):
async def _generate(self, prompt: str, schema: dict | None, **kwargs):
def __init__(self, model_name: str | None = None):
super().__init__()
async def _generate(
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
):
json_payload = {"prompt": prompt}
if schema:
json_payload["schema"] = schema
if gen_schema:
json_payload["gen_schema"] = gen_schema
if gen_cfg:
json_payload.update(gen_cfg)
async with httpx.AsyncClient() as client:
response = await client.post(
settings.LLM_URL,


@@ -1,11 +1,13 @@
import httpx
from transformers import GenerationConfig
from reflector.llm.base import LLM
from reflector.logger import logger
from reflector.settings import settings
class OpenAILLM(LLM):
def __init__(self, **kwargs):
def __init__(self, model_name: str | None = None, **kwargs):
super().__init__(**kwargs)
self.openai_key = settings.LLM_OPENAI_KEY
self.openai_url = settings.LLM_URL
@@ -15,7 +17,13 @@ class OpenAILLM(LLM):
self.max_tokens = settings.LLM_MAX_TOKENS
logger.info(f"LLM use openai backend at {self.openai_url}")
async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
async def _generate(
self,
prompt: str,
gen_schema: dict | None,
gen_cfg: GenerationConfig | None,
**kwargs,
) -> str:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.openai_key}",


@@ -0,0 +1,150 @@
from typing import Optional, TypeVar
from pydantic import BaseModel
from transformers import GenerationConfig
class TaskParams(BaseModel, arbitrary_types_allowed=True):
instruct: str
gen_cfg: Optional[GenerationConfig] = None
gen_schema: Optional[dict] = None
T = TypeVar("T", bound="LLMTaskParams")
class LLMTaskParams:
_registry = {}
@classmethod
def register(cls, task, klass) -> None:
cls._registry[task] = klass
@classmethod
def get_instance(cls, task: str) -> T:
return cls._registry[task]()
@property
def task_params(self) -> TaskParams | None:
"""
Fetch the task related parameters
"""
return self._get_task_params()
def _get_task_params(self) -> None:
pass
class FinalLongSummaryParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
)
self._instruct = """
Take the key ideas and takeaways from the text and create a short
summary. Be sure to keep the length of the response to a minimum.
Do not include trivial information in the summary.
"""
self._schema = {
"type": "object",
"properties": {"long_summary": {"type": "string"}},
}
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""gen_schema
Return the parameters associated with a specific LLM task
"""
return self._task_params
class FinalShortSummaryParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=1300, num_beams=3, do_sample=True, temperature=0.3
)
self._instruct = """
Take the key ideas and takeaways from the text and create a short
summary. Be sure to keep the length of the response to a minimum.
Do not include trivial information in the summary.
"""
self._schema = {
"type": "object",
"properties": {"short_summary": {"type": "string"}},
}
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""
Return the parameters associated with a specific LLM task
"""
return self._task_params
class FinalTitleParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
)
self._instruct = """
Combine the following individual titles into one single short title that
condenses the essence of all titles.
"""
self._schema = {
"type": "object",
"properties": {"title": {"type": "string"}},
}
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""
Return the parameters associated with a specific LLM task
"""
return self._task_params
class TopicParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=550, num_beams=6, do_sample=True, temperature=0.9
)
self._instruct = """
Create a JSON object as a response. The JSON object must have 2 fields:
i) title and ii) summary.
For the title field, generate a very detailed and self-explanatory
title for the given text. Let the title be as descriptive as possible.
For the summary field, summarize the given text in a maximum of
three sentences.
"""
self._schema = {
"type": "object",
"properties": {
"title": {"type": "string"},
"summary": {"type": "string"},
},
}
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""
Return the parameters associated with a specific LLM task
"""
return self._task_params
LLMTaskParams.register("topic", TopicParams)
LLMTaskParams.register("final_title", FinalTitleParams)
LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)

View File

@@ -4,7 +4,20 @@ from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor # noqa: F401
from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401
from .transcript_final_long_summary import ( # noqa: F401
TranscriptFinalLongSummaryProcessor,
)
from .transcript_final_short_summary import ( # noqa: F401
TranscriptFinalShortSummaryProcessor,
)
from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
from .transcript_topic_detector import TranscriptTopicDetectorProcessor # noqa: F401
from .types import AudioFile, FinalSummary, TitleSummary, Transcript, Word # noqa: F401
from .types import ( # noqa: F401
AudioFile,
FinalLongSummary,
FinalShortSummary,
TitleSummary,
Transcript,
Word,
)

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
from prometheus_client import Counter, Gauge, Histogram
from pydantic import BaseModel
from reflector.logger import logger
@@ -296,7 +297,7 @@ class BroadcastProcessor(Processor):
types of input.
"""
def __init__(self, processors: Processor):
def __init__(self, processors: list[Processor]):
super().__init__()
self.processors = processors
self.INPUT_TYPE = processors[0].INPUT_TYPE
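Typed as a list, the broadcaster adopts the INPUT_TYPE of its first child and forwards each item it receives to all of them, which is how a single stream of TitleSummary chunks now feeds the title and both summary processors (see the pipeline wiring further down). A condensed sketch mirroring that wiring:

from reflector.processors import (
    TranscriptFinalLongSummaryProcessor,
    TranscriptFinalShortSummaryProcessor,
    TranscriptFinalTitleProcessor,
)
from reflector.processors.base import BroadcastProcessor

fanout = BroadcastProcessor(
    processors=[
        TranscriptFinalTitleProcessor.as_threaded(),
        TranscriptFinalLongSummaryProcessor.as_threaded(),
        TranscriptFinalShortSummaryProcessor.as_threaded(),
    ],
)
# fanout.INPUT_TYPE == TitleSummary, inherited from the first processor.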

View File

@@ -0,0 +1,59 @@
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import FinalLongSummary, TitleSummary
class TranscriptFinalLongSummaryProcessor(Processor):
"""
Get the final long summary
"""
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = FinalLongSummary
TASK = "final_long_summary"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.chunks: list[TitleSummary] = []
self.llm = LLM.get_instance()
self.params = LLMTaskParams.get_instance(self.TASK).task_params
async def _push(self, data: TitleSummary):
self.chunks.append(data)
async def get_long_summary(self, text: str) -> str:
"""
Generate a long version of the final summary
"""
self.logger.info(f"Smoothing out {len(text)} length summary to a long summary")
chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
accumulated_summaries = ""
for chunk in chunks:
prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
summary_result = await self.llm.generate(
prompt=prompt,
gen_schema=self.params.gen_schema,
gen_cfg=self.params.gen_cfg,
logger=self.logger,
)
accumulated_summaries += summary_result["long_summary"]
return accumulated_summaries
async def _flush(self):
if not self.chunks:
self.logger.warning("No summary to output")
return
accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks])
long_summary = await self.get_long_summary(accumulated_summaries)
last_chunk = self.chunks[-1]
duration = last_chunk.timestamp + last_chunk.duration
final_long_summary = FinalLongSummary(
long_summary=long_summary,
duration=duration,
)
await self.emit(final_long_summary)

View File

@@ -0,0 +1,72 @@
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import FinalShortSummary, TitleSummary
class TranscriptFinalShortSummaryProcessor(Processor):
"""
Get the final short summary using a tree summarizer
"""
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = FinalShortSummary
TASK = "final_short_summary"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.chunks: list[TitleSummary] = []
self.llm = LLM.get_instance()
self.params = LLMTaskParams.get_instance(self.TASK).task_params
async def _push(self, data: TitleSummary):
self.chunks.append(data)
async def get_short_summary(self, text: str) -> dict:
"""
Generate a short summary using a tree summarizer
"""
self.logger.info(f"Smoothing out {len(text)} length summary to a short summary")
chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
if len(chunks) == 1:
chunk = chunks[0]
prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
summary_result = await self.llm.generate(
prompt=prompt,
gen_schema=self.params.gen_schema,
gen_cfg=self.params.gen_cfg,
logger=self.logger,
)
return summary_result
else:
accumulated_summaries = ""
for chunk in chunks:
prompt = self.llm.create_prompt(
instruct=self.params.instruct, text=chunk
)
summary_result = await self.llm.generate(
prompt=prompt,
gen_schema=self.params.gen_schema,
gen_cfg=self.params.gen_cfg,
logger=self.logger,
)
accumulated_summaries += summary_result["short_summary"]
return await self.get_short_summary(accumulated_summaries)
async def _flush(self):
if not self.chunks:
self.logger.warning("No summary to output")
return
accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks])
short_summary_result = await self.get_short_summary(accumulated_summaries)
last_chunk = self.chunks[-1]
duration = last_chunk.timestamp + last_chunk.duration
final_summary = FinalShortSummary(
short_summary=short_summary_result["short_summary"],
duration=duration,
)
await self.emit(final_summary)

View File

@@ -1,30 +0,0 @@
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, FinalSummary
class TranscriptFinalSummaryProcessor(Processor):
"""
Assemble all summary into a line-based json
"""
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = FinalSummary
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.chunks: list[TitleSummary] = []
async def _push(self, data: TitleSummary):
self.chunks.append(data)
async def _flush(self):
if not self.chunks:
self.logger.warning("No summary to output")
return
# FIXME improve final summary
result = "\n".join([chunk.summary for chunk in self.chunks])
last_chunk = self.chunks[-1]
duration = last_chunk.timestamp + last_chunk.duration
await self.emit(FinalSummary(summary=result, duration=duration))

View File

@@ -0,0 +1,65 @@
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import FinalTitle, TitleSummary
class TranscriptFinalTitleProcessor(Processor):
"""
Generate the final title for the whole recording
"""
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = FinalTitle
TASK = "final_title"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.chunks: list[TitleSummary] = []
self.llm = LLM.get_instance()
self.params = LLMTaskParams.get_instance(self.TASK).task_params
async def _push(self, data: TitleSummary):
self.chunks.append(data)
async def get_title(self, text: str) -> dict:
"""
Generate a title for the whole recording
"""
chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
if len(chunks) == 1:
chunk = chunks[0]
prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
title_result = await self.llm.generate(
prompt=prompt,
gen_schema=self.params.gen_schema,
gen_cfg=self.params.gen_cfg,
logger=self.logger,
)
return title_result
else:
accumulated_titles = ""
for chunk in chunks:
prompt = self.llm.create_prompt(
instruct=self.params.instruct, text=chunk
)
title_result = await self.llm.generate(
prompt=prompt,
gen_schema=self.params.gen_schema,
gen_cfg=self.params.gen_cfg,
logger=self.logger,
)
accumulated_titles += title_result["summary"]
return await self.get_title(accumulated_titles)
async def _flush(self):
if not self.chunks:
self.logger.warning("No summary to output")
return
accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
title_result = await self.get_title(accumulated_titles)
final_title = FinalTitle(title=title_result["title"])
await self.emit(final_title)

View File

@@ -1,7 +1,6 @@
from reflector.llm import LLM
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, Transcript
from reflector.utils.retry import retry
class TranscriptTopicDetectorProcessor(Processor):
@@ -11,34 +10,14 @@ class TranscriptTopicDetectorProcessor(Processor):
INPUT_TYPE = Transcript
OUTPUT_TYPE = TitleSummary
TASK = "topic"
PROMPT = """
### Human:
Create a JSON object as response.The JSON object must have 2 fields:
i) title and ii) summary.
For the title field, generate a short title for the given text.
For the summary field, summarize the given text in a maximum of
three sentences.
{input_text}
### Assistant:
"""
def __init__(self, min_transcript_length=750, **kwargs):
def __init__(self, min_transcript_length: int = 750, **kwargs):
super().__init__(**kwargs)
self.transcript = None
self.min_transcript_length = min_transcript_length
self.llm = LLM.get_instance()
self.topic_detector_schema = {
"type": "object",
"properties": {
"title": {"type": "string"},
"summary": {"type": "string"},
},
}
self.params = LLMTaskParams.get_instance(self.TASK).task_params
async def _warmup(self):
await self.llm.warmup(logger=self.logger)
@@ -55,18 +34,30 @@ class TranscriptTopicDetectorProcessor(Processor):
return
await self.flush()
async def get_topic(self, text: str) -> dict:
"""
Generate a topic and description for a transcription excerpt
"""
prompt = self.llm.create_prompt(instruct=self.params.instruct, text=text)
topic_result = await self.llm.generate(
prompt=prompt,
gen_schema=self.params.gen_schema,
gen_cfg=self.params.gen_cfg,
logger=self.logger,
)
return topic_result
async def _flush(self):
if not self.transcript:
return
text = self.transcript.text
self.logger.info(f"Topic detector got {len(text)} length transcript")
prompt = self.PROMPT.format(input_text=text)
result = await retry(self.llm.generate)(
prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
)
topic_result = await self.get_topic(text=text)
summary = TitleSummary(
title=result["title"],
summary=result["summary"],
title=self.llm.ensure_casing(topic_result["title"]),
summary=topic_result["summary"],
timestamp=self.transcript.timestamp,
duration=self.transcript.duration,
transcript=self.transcript,

View File

@@ -103,11 +103,20 @@ class TitleSummary(BaseModel):
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
class FinalSummary(BaseModel):
summary: str
class FinalLongSummary(BaseModel):
long_summary: str
duration: float
class FinalShortSummary(BaseModel):
short_summary: str
duration: float
class FinalTitle(BaseModel):
title: str
class TranslationLanguages(BaseModel):
language_to_id_mapping: dict = {
"Afrikaans": "af",

View File

@@ -91,5 +91,11 @@ class Settings(BaseSettings):
# if set, all anonymous record will be public
PUBLIC_MODE: bool = False
# Default LLM model name
DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5"
# Cache directory for all model storage
CACHE_DIR: str = "data"
settings = Settings()
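DEFAULT_LLM names the checkpoint to load when no model is configured explicitly, and CACHE_DIR is the writable directory for downloaded weights. A hedged sketch of how a local backend would typically consume them (the actual loading code is not part of this hunk):

from transformers import AutoModelForCausalLM, AutoTokenizer

from reflector.settings import settings

# Hypothetical loading path: pull the configured checkpoint into the local cache.
tokenizer = AutoTokenizer.from_pretrained(
    settings.DEFAULT_LLM, cache_dir=settings.CACHE_DIR
)
model = AutoModelForCausalLM.from_pretrained(
    settings.DEFAULT_LLM, cache_dir=settings.CACHE_DIR
)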

View File

@@ -1,6 +1,7 @@
import asyncio
import av
from reflector.logger import logger
from reflector.processors import (
AudioChunkerProcessor,
@@ -8,10 +9,13 @@ from reflector.processors import (
AudioTranscriptAutoProcessor,
Pipeline,
PipelineEvent,
TranscriptFinalSummaryProcessor,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
)
from reflector.processors.base import BroadcastProcessor
async def process_audio_file(
@@ -31,7 +35,13 @@ async def process_audio_file(
if not only_transcript:
processors += [
TranscriptTopicDetectorProcessor.as_threaded(),
TranscriptFinalSummaryProcessor.as_threaded(),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(),
TranscriptFinalLongSummaryProcessor.as_threaded(),
TranscriptFinalShortSummaryProcessor.as_threaded(),
],
),
]
# transcription output

View File

@@ -8,6 +8,7 @@ from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from fastapi import APIRouter, Request
from prometheus_client import Gauge
from pydantic import BaseModel
from reflector.events import subscribers_shutdown
from reflector.logger import logger
from reflector.processors import (
@@ -15,14 +16,19 @@ from reflector.processors import (
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
FinalSummary,
FinalLongSummary,
FinalShortSummary,
Pipeline,
TitleSummary,
Transcript,
TranscriptFinalSummaryProcessor,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
)
from reflector.processors.base import BroadcastProcessor
from reflector.processors.types import FinalTitle
sessions = []
router = APIRouter()
@@ -72,8 +78,10 @@ class StrValue(BaseModel):
class PipelineEvent(StrEnum):
TRANSCRIPT = "TRANSCRIPT"
TOPIC = "TOPIC"
FINAL_SUMMARY = "FINAL_SUMMARY"
FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
STATUS = "STATUS"
FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
FINAL_TITLE = "FINAL_TITLE"
async def rtc_offer_base(
@@ -124,15 +132,15 @@ async def rtc_offer_base(
data=transcript,
)
async def on_topic(summary: TitleSummary):
async def on_topic(topic: TitleSummary):
# FIXME: make it incremental with the frontend, not send everything
ctx.logger.info("Summary", summary=summary)
ctx.logger.info("Topic", topic=topic)
ctx.topics.append(
{
"title": summary.title,
"timestamp": summary.timestamp,
"transcript": summary.transcript.text,
"desc": summary.summary,
"title": topic.title,
"timestamp": topic.timestamp,
"transcript": topic.transcript.text,
"desc": topic.summary,
}
)
@@ -144,17 +152,17 @@ async def rtc_offer_base(
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
)
async def on_final_summary(summary: FinalSummary):
ctx.logger.info("FinalSummary", final_summary=summary)
async def on_final_short_summary(summary: FinalShortSummary):
ctx.logger.info("FinalShortSummary", final_short_summary=summary)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_SUMMARY",
"summary": summary.summary,
"cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
"summary": summary.short_summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
@@ -162,11 +170,47 @@ async def rtc_offer_base(
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_SUMMARY,
event=PipelineEvent.FINAL_SHORT_SUMMARY,
args=event_callback_args,
data=summary,
)
async def on_final_long_summary(summary: FinalLongSummary):
ctx.logger.info("FinalLongSummary", final_summary=summary)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_LONG_SUMMARY",
"summary": summary.long_summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_LONG_SUMMARY,
args=event_callback_args,
data=summary,
)
async def on_final_title(title: FinalTitle):
ctx.logger.info("FinalTitle", final_title=title)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_TITLE,
args=event_callback_args,
data=title,
)
# create a context for the whole rtc transaction
# add a customised logger to the context
processors = []
@@ -178,7 +222,17 @@ async def rtc_offer_base(
AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
TranscriptLinerProcessor(),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=on_final_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=on_final_short_summary
),
]
),
]
ctx.pipeline = Pipeline(*processors)
ctx.pipeline.set_pref("audio:source_language", source_language)

View File

@@ -7,7 +7,6 @@ from typing import Annotated, Optional
from uuid import uuid4
import av
import reflector.auth as auth
from fastapi import (
APIRouter,
Depends,
@@ -18,11 +17,13 @@ from fastapi import (
)
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool
import reflector.auth as auth
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -60,8 +61,16 @@ class TranscriptTopic(BaseModel):
timestamp: float
class TranscriptFinalSummary(BaseModel):
summary: str
class TranscriptFinalShortSummary(BaseModel):
short_summary: str
class TranscriptFinalLongSummary(BaseModel):
long_summary: str
class TranscriptFinalTitle(BaseModel):
title: str
class TranscriptEvent(BaseModel):
@@ -77,7 +86,9 @@ class Transcript(BaseModel):
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=datetime.utcnow)
summary: str | None = None
title: str | None = None
short_summary: str | None = None
long_summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
source_language: str = "en"
@@ -241,7 +252,9 @@ class GetTranscript(BaseModel):
status: str
locked: bool
duration: int
summary: str | None
title: str | None
short_summary: str | None
long_summary: str | None
created_at: datetime
source_language: str
target_language: str
@@ -256,7 +269,9 @@ class CreateTranscript(BaseModel):
class UpdateTranscript(BaseModel):
name: Optional[str] = Field(None)
locked: Optional[bool] = Field(None)
summary: Optional[str] = Field(None)
title: Optional[str] = Field(None)
short_summary: Optional[str] = Field(None)
long_summary: Optional[str] = Field(None)
class DeletionStatus(BaseModel):
@@ -315,20 +330,32 @@ async def transcript_update(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
values = {}
values = {"events": []}
if info.name is not None:
values["name"] = info.name
if info.locked is not None:
values["locked"] = info.locked
if info.summary is not None:
values["summary"] = info.summary
# also find FINAL_SUMMARY event and patch it
for te in transcript.events:
if te["event"] == PipelineEvent.FINAL_SUMMARY:
te["summary"] = info.summary
if info.long_summary is not None:
values["long_summary"] = info.long_summary
for transcript_event in transcript.events:
if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
transcript_event["long_summary"] = info.long_summary
break
values["events"] = transcript.events
values["events"].extend(transcript.events)
if info.short_summary is not None:
values["short_summary"] = info.short_summary
for transcript_event in transcript.events:
if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
transcript_event["short_summary"] = info.short_summary
break
values["events"].extend(transcript.events)
if info.title is not None:
values["title"] = info.title
for transcript_event in transcript.events:
if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
transcript_event["title"] = info.title
break
values["events"].extend(transcript.events)
await transcripts_controller.update(transcript, values)
return transcript
@@ -539,14 +566,38 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
},
)
elif event == PipelineEvent.FINAL_SUMMARY:
final_summary = TranscriptFinalSummary(summary=data.summary)
resp = transcript.add_event(event=event, data=final_summary)
elif event == PipelineEvent.FINAL_TITLE:
final_title = TranscriptFinalTitle(title=data.title)
resp = transcript.add_event(event=event, data=final_title)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"summary": final_summary.summary,
"title": final_title.title,
},
)
elif event == PipelineEvent.FINAL_LONG_SUMMARY:
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
resp = transcript.add_event(event=event, data=final_long_summary)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"long_summary": final_long_summary.long_summary,
},
)
elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
resp = transcript.add_event(event=event, data=final_short_summary)
await transcripts_controller.update(
transcript,
{
"events": transcript.events_dump(),
"short_summary": final_short_summary.short_summary,
},
)

View File

@@ -1,3 +1,5 @@
from unittest.mock import patch
import pytest
@@ -14,3 +16,50 @@ async def setup_database():
metadata.create_all(bind=engine)
yield
@pytest.fixture
def dummy_processors():
with patch(
"reflector.processors.transcript_topic_detector.TranscriptTopicDetectorProcessor.get_topic"
) as mock_topic, patch(
"reflector.processors.transcript_final_title.TranscriptFinalTitleProcessor.get_title"
) as mock_title, patch(
"reflector.processors.transcript_final_long_summary.TranscriptFinalLongSummaryProcessor.get_long_summary"
) as mock_long_summary, patch(
"reflector.processors.transcript_final_short_summary.TranscriptFinalShortSummaryProcessor.get_short_summary"
) as mock_short_summary:
mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
mock_title.return_value = {"title": "LLM FINAL TITLE"}
mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
yield mock_topic, mock_title, mock_long_summary, mock_short_summary
@pytest.fixture
async def dummy_llm():
from reflector.llm.base import LLM
class TestLLM(LLM):
def __init__(self):
self.model_name = "DUMMY MODEL"
self.llm_tokenizer = "DUMMY TOKENIZER"
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
mock_llm.return_value = TestLLM()
yield
@pytest.fixture
def nltk():
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
mock_nltk.return_value = "NLTK PACKAGE"
yield
@pytest.fixture
def ensure_casing():
with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
mock_casing.return_value = "LLM TITLE"
yield

View File

@@ -2,7 +2,7 @@ import pytest
@pytest.mark.asyncio
async def test_processor_broadcast():
async def test_processor_broadcast(nltk):
from reflector.processors.base import Processor, BroadcastProcessor, Pipeline
class TestProcessor(Processor):

View File

@@ -2,27 +2,19 @@ import pytest
@pytest.mark.asyncio
async def test_basic_process(event_loop):
async def test_basic_process(
event_loop, nltk, dummy_llm, dummy_processors, ensure_casing
):
# goal is to start the server, and send rtc audio to it
# validate the events received
from reflector.tools.process import process_audio_file
from reflector.settings import settings
from reflector.llm.base import LLM
from pathlib import Path
# use an LLM test backend
settings.LLM_BACKEND = "test"
settings.TRANSCRIPT_BACKEND = "whisper"
class LLMTest(LLM):
async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
return {
"title": "TITLE",
"summary": "SUMMARY",
}
LLM.register("test", LLMTest)
# event callback
marks = {}
@@ -39,4 +31,6 @@ async def test_basic_process(event_loop):
# validate the events
assert marks["TranscriptLinerProcessor"] == 5
assert marks["TranscriptTopicDetectorProcessor"] == 1
assert marks["TranscriptFinalSummaryProcessor"] == 1
assert marks["TranscriptFinalLongSummaryProcessor"] == 1
assert marks["TranscriptFinalShortSummaryProcessor"] == 1
assert marks["TranscriptFinalTitleProcessor"] == 1

View File

@@ -75,21 +75,52 @@ async def test_transcript_get_update_summary():
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["summary"] is None
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["summary"] is None
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
response = await ac.patch(f"/transcripts/{tid}", json={"summary": "test"})
response = await ac.patch(
f"/transcripts/{tid}",
json={"long_summary": "test_long", "short_summary": "test_short"},
)
assert response.status_code == 200
assert response.json()["summary"] == "test"
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["summary"] == "test"
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
@pytest.mark.asyncio
async def test_transcript_get_update_title():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["title"] is None
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] is None
response = await ac.patch(f"/transcripts/{tid}", json={"title": "test_title"})
assert response.status_code == 200
assert response.json()["title"] == "test_title"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] == "test_title"
@pytest.mark.asyncio

View File

@@ -67,21 +67,10 @@ async def dummy_transcript():
yield
@pytest.fixture
async def dummy_llm():
from reflector.llm.base import LLM
class TestLLM(LLM):
async def _generate(self, prompt: str, schema: dict | None, **kwargs):
return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
mock_llm.return_value = TestLLM()
yield
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm):
async def test_transcript_rtc_and_websocket(
tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
# to be able to connect with aiortc
@@ -186,9 +175,17 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SUMMARY")]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert "FINAL_LONG_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
assert "FINAL_SHORT_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM FINAL TITLE"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
@@ -218,7 +215,9 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dummy_llm):
async def test_transcript_rtc_and_websocket_and_fr(
tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
# to be able to connect with aiortc
@@ -326,9 +325,17 @@ async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dum
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SUMMARY")]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert "FINAL_LONG_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
assert "FINAL_SHORT_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM FINAL TITLE"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]

View File

@@ -94,6 +94,13 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
break;
case "FINAL_LONG_SUMMARY":
if (message.data) {
message.data = { summary: message.data.long_summary };
setFinalSummary(message.data);
console.debug("FINAL_LONG_SUMMARY event:", message.data);
}
break;
case "FINAL_SUMMARY":
if (message.data) {
setFinalSummary(message.data);