mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-20 20:29:06 +00:00
Feature additions (#210)
* initial
* add LLM features
* update LLM logic
* update llm functions: change control flow
* add generation config
* update return types
* update processors and tests
* update rtc_offer
* revert new title processor change
* fix unit tests
* add comments and fix HTTP 500
* adjust prompt
* test with reflector app
* revert new event for final title
* update
* move onus onto processors
* move onus onto processors
* stash
* add provision for gen config
* dynamically pack the LLM input using context length
* tune final summary params
* update consolidated class structures
* update consolidated class structures
* update precommit
* add broadcast processors
* working baseline
* Organize LLMParams
* minor fixes
* minor fixes
* minor fixes
* fix unit tests
* fix unit tests
* fix unit tests
* update tests
* update tests
* edit pipeline response events
* update summary return types
* configure tests
* alembic db migration
* change LLM response flow
* edit main llm functions
* edit main llm functions
* change llm name and gen cf
* Update transcript_topic_detector.py
* PR review comments
* checkpoint before db event migration
* update DB migration of past events
* update DB migration of past events
* edit LLM classes
* Delete unwanted file
* remove List typing
* remove List typing
* update oobabooga API call
* topic enhancements
* update UI event handling
* move ensure_casing to llm base
* update tests
* update tests
@@ -18,12 +18,6 @@ repos:
         exclude: ^server/trials
       - id: detect-private-key
 
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.0.282
-    hooks:
-      - id: ruff
-        files: ^server/(reflector|tests)/
-
   - repo: https://github.com/psf/black
     rev: 23.1.0
     hooks:
@@ -36,4 +30,10 @@ repos:
       - id: isort
         name: isort (python)
        files: ^server/(gpu|evaluate|reflector)/
-        args: ["--profile", "black", "--filter-files"]
+        args: [ "--profile", "black", "--filter-files" ]
+
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: v0.0.282
+    hooks:
+      - id: ruff
+        files: ^server/(reflector|tests)/
@@ -94,6 +94,12 @@
 #LLM_URL=http://localhost:4891/v1/completions
 #LLM_OPENAI_MODEL="GPT4All Falcon"
 
+## Default LLM MODEL NAME
+DEFAULT_LLM=lmsys/vicuna-13b-v1.5
+
+## Cache directory to store models
+CACHE_DIR=data
+
 ## =======================================================
 ## Sentry
 ## =======================================================
@@ -55,7 +55,7 @@ llm_image = (
         "accelerate==0.21.0",
         "einops==0.6.1",
         "hf-transfer~=0.1",
-        "huggingface_hub==0.16.4",
+        "huggingface_hub==0.16.4"
     )
     .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
     .run_function(download_llm)
@@ -73,8 +73,7 @@ llm_image = (
 class LLM:
     def __enter__(self):
         import torch
-        from transformers import AutoModelForCausalLM, AutoTokenizer
-        from transformers.generation import GenerationConfig
+        from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 
         print("Instance llm model")
         model = AutoModelForCausalLM.from_pretrained(
@@ -84,10 +83,11 @@ class LLM:
             cache_dir=IMAGE_MODEL_DIR
         )
 
-        # generation configuration
-        # JSONFormer doesn't yet support generation configs
+        print("Instance llm generation config")
+        # JSONFormer doesn't yet support generation configs, but keeping for future usage
         model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
 
+        # generation configuration
+        gen_cfg = GenerationConfig.from_model_config(model.config)
+        gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
+
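For reference, the gen_cfg flow above follows the standard transformers GenerationConfig API; below is a minimal standalone sketch of the round-trip the PR relies on (build from the model config, tune, serialize to JSON on the web side, rebuild with from_dict on the worker side). The snippet is illustrative, not part of the diff:

```python
# Minimal sketch of the GenerationConfig round-trip used by this PR.
import json

from transformers import AutoConfig, GenerationConfig

model_config = AutoConfig.from_pretrained("lmsys/vicuna-13b-v1.5")  # the PR's default model
gen_cfg = GenerationConfig.from_model_config(model_config)
gen_cfg.max_new_tokens = 300

wire_payload = json.dumps(gen_cfg.to_dict())            # what the API ships
rebuilt = GenerationConfig.from_dict(json.loads(wire_payload))  # worker side
assert rebuilt.max_new_tokens == 300
```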
@@ -106,6 +106,7 @@ class LLM:
         self.model = model
         self.tokenizer = tokenizer
         self.gen_cfg = gen_cfg
+        self.GenerationConfig = GenerationConfig
 
     def __exit__(self, *args):
         print("Exit llm")
@@ -116,34 +117,44 @@ class LLM:
         return {"status": "ok"}
 
     @method()
-    def generate(self, prompt: str, schema: str = None):
+    def generate(self, prompt: str, gen_schema: str | None, gen_cfg: str | None) -> dict:
         """
         Perform a generation action using the LLM
         """
         print(f"Generate {prompt=}")
-        # If a schema is given, conform to schema
-        if schema:
-            print(f"Schema {schema=}")
+        if gen_cfg:
+            gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
+        else:
+            gen_cfg = self.gen_cfg
+
+        # If a gen_schema is given, conform to gen_schema
+        if gen_schema:
             import jsonformer
 
-            jsonformer_llm = jsonformer.Jsonformer(model=self.model,
-                                                   tokenizer=self.tokenizer,
-                                                   json_schema=json.loads(schema),
-                                                   prompt=prompt,
-                                                   max_string_token_length=self.gen_cfg.max_new_tokens)
+            print(f"Schema {gen_schema=}")
+            jsonformer_llm = jsonformer.Jsonformer(
+                model=self.model,
+                tokenizer=self.tokenizer,
+                json_schema=json.loads(gen_schema),
+                prompt=prompt,
+                max_string_token_length=gen_cfg.max_new_tokens
+            )
             response = jsonformer_llm()
         else:
-            # If no schema, perform prompt only generation
+            # If no gen_schema, perform prompt only generation
 
             # tokenize prompt
             input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
                 self.model.device
             )
-            output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+            output = self.model.generate(input_ids, generation_config=gen_cfg)
 
             # decode output
             response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
             response = response[len(prompt):]
         print(f"Generated {response=}")
         return {"text": response}
 
 
 # -------------------------------------------------------------------
 # Web API
 # -------------------------------------------------------------------
@@ -160,7 +171,7 @@ class LLM:
 def web():
     from fastapi import Depends, FastAPI, HTTPException, status
     from fastapi.security import OAuth2PasswordBearer
-    from pydantic import BaseModel, Field
+    from pydantic import BaseModel
 
     llmstub = LLM()
 
@@ -177,16 +188,16 @@ def web():
 
     class LLMRequest(BaseModel):
         prompt: str
-        schema_: Optional[dict] = Field(None, alias="schema")
+        gen_schema: Optional[dict] = None
+        gen_cfg: Optional[dict] = None
 
     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
         req: LLMRequest,
     ):
-        if req.schema_:
-            func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema_))
-        else:
-            func = llmstub.generate.spawn(prompt=req.prompt)
+        gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
+        gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
+        func = llmstub.generate.spawn(prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg)
         result = func.get()
         return result
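With this change a client sends the two new optional fields instead of the old "schema" alias. A hedged sketch of a call against the updated /llm endpoint (the deployment URL and API key below are placeholders, not values from this diff):

```python
# Sketch of a client call to the updated /llm endpoint.
import httpx

payload = {
    "prompt": "### Human:\nSummarize: ...\n\n### Assistant:\n",
    "gen_schema": {"type": "object", "properties": {"response": {"type": "string"}}},
    "gen_cfg": {"max_new_tokens": 150},
}
response = httpx.post(
    "https://example.modal.run/llm",                    # placeholder URL
    json=payload,
    headers={"Authorization": "Bearer <api-key>"},      # placeholder key
    timeout=60,
)
response.raise_for_status()
print(response.json()["text"])
```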
@@ -0,0 +1,37 @@
+"""add_title/short_and_long_summary_and_remove_summary
+
+Revision ID: 99365b0cd87b
+Revises: b3df9681cae9
+Create Date: 2023-09-01 20:19:47.216334
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = '99365b0cd87b'
+down_revision: Union[str, None] = 'b3df9681cae9'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.execute("UPDATE transcript SET events = "
+               "REPLACE(events, '\"event\": \"SUMMARY\"', '\"event\": \"LONG_SUMMARY\"');")
+    op.alter_column('transcript', 'summary', new_column_name='long_summary')
+    op.add_column('transcript', sa.Column('title', sa.String(), nullable=True))
+    op.add_column('transcript', sa.Column('short_summary', sa.String(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.execute("UPDATE transcript SET events = "
+               "REPLACE(events, '\"event\": \"LONG_SUMMARY\"', '\"event\": \"SUMMARY\"');")
+    op.alter_column('transcript', 'long_summary', nullable=True, new_column_name='summary')
+    op.drop_column('transcript', 'title')
+    op.drop_column('transcript', 'short_summary')
+    # ### end Alembic commands ###
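For reference, this revision applies with `alembic upgrade head` and reverts with `alembic downgrade b3df9681cae9`. Note that the `op.execute` calls rewrite the SUMMARY/LONG_SUMMARY labels inside the stored `events` JSON with a plain SQL string REPLACE, so historical events migrate in place and no application-side backfill is needed.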
266 server/poetry.lock (generated: adds joblib 1.3.2, nltk 3.8.1, regex 2023.8.8, safetensors 0.3.3, transformers 4.32.1)
@@ -27,6 +27,8 @@ databases = {extras = ["aiosqlite", "asyncpg"], version = "^0.7.0"}
 sqlalchemy = "<1.5"
 fief-client = {extras = ["fastapi"], version = "^0.17.0"}
 alembic = "^1.11.3"
+nltk = "^3.8.1"
+transformers = "^4.32.1"
 prometheus-fastapi-instrumentator = "^6.1.0"
 
@@ -1,12 +1,13 @@
 from contextlib import asynccontextmanager
 
-import reflector.auth  # noqa
-import reflector.db  # noqa
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.routing import APIRoute
 from fastapi_pagination import add_pagination
+from prometheus_fastapi_instrumentator import Instrumentator
 
+import reflector.auth  # noqa
+import reflector.db  # noqa
 from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.logger import logger
 from reflector.metrics import metrics_init
@@ -1,5 +1,6 @@
 import databases
 import sqlalchemy
 
+from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.settings import settings
 
@@ -16,7 +17,9 @@ transcripts = sqlalchemy.Table(
     sqlalchemy.Column("locked", sqlalchemy.Boolean),
     sqlalchemy.Column("duration", sqlalchemy.Integer),
     sqlalchemy.Column("created_at", sqlalchemy.DateTime),
-    sqlalchemy.Column("summary", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
     sqlalchemy.Column("topics", sqlalchemy.JSON),
     sqlalchemy.Column("events", sqlalchemy.JSON),
     sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
@@ -1 +1,2 @@
 from .base import LLM  # noqa: F401
+from .llm_params import LLMTaskParams  # noqa: F401
@@ -2,12 +2,19 @@ import importlib
 import json
 import re
 from time import monotonic
+from typing import TypeVar
 
+import nltk
 from prometheus_client import Counter, Histogram
+from transformers import GenerationConfig
 
+from reflector.llm.llm_params import TaskParams
 from reflector.logger import logger as reflector_logger
 from reflector.settings import settings
 from reflector.utils.retry import retry
 
+T = TypeVar("T", bound="LLM")
+
 
 class LLM:
     _registry = {}
@@ -32,12 +39,25 @@ class LLM:
         ["backend"],
     )
 
+    def __enter__(self):
+        self.ensure_nltk()
+
+    @classmethod
+    def ensure_nltk(cls):
+        """
+        Make sure NLTK package is installed. Searches in the cache and
+        downloads only if needed.
+        """
+        nltk.download("punkt", download_dir=settings.CACHE_DIR)
+        # For POS tagging
+        nltk.download("averaged_perceptron_tagger", download_dir=settings.CACHE_DIR)
+
     @classmethod
     def register(cls, name, klass):
         cls._registry[name] = klass
 
     @classmethod
-    def get_instance(cls, name=None):
+    def get_instance(cls, model_name: str | None = None, name: str = None) -> T:
         """
         Return an instance depending on the settings.
         Settings used:
@@ -50,7 +70,39 @@ class LLM:
         if name not in cls._registry:
             module_name = f"reflector.llm.llm_{name}"
             importlib.import_module(module_name)
-        return cls._registry[name]()
+        return cls._registry[name](model_name)
+
+    def get_model_name(self) -> str:
+        """
+        Get the currently set model name
+        """
+        return self._get_model_name()
+
+    def _get_model_name(self) -> str:
+        pass
+
+    def set_model_name(self, model_name: str) -> bool:
+        """
+        Update the model name with the provided model name
+        """
+        return self._set_model_name(model_name)
+
+    def _set_model_name(self, model_name: str) -> bool:
+        raise NotImplementedError
+
+    @property
+    def template(self) -> str:
+        """
+        Return the LLM Prompt template
+        """
+        return """
+### Human:
+{instruct}
+
+{text}
+
+### Assistant:
+"""
 
     def __init__(self):
         name = self.__class__.__name__
@@ -73,21 +125,39 @@ class LLM:
     async def _warmup(self, logger: reflector_logger):
         pass
 
+    @property
+    def tokenizer(self):
+        """
+        Return the tokenizer instance used by LLM
+        """
+        return self._get_tokenizer()
+
+    def _get_tokenizer(self):
+        pass
+
     async def generate(
         self,
         prompt: str,
         logger: reflector_logger,
-        schema: dict | None = None,
+        gen_schema: dict | None = None,
+        gen_cfg: GenerationConfig | None = None,
         **kwargs,
     ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
 
+        if gen_cfg:
+            gen_cfg = gen_cfg.to_dict()
         self.m_generate_call.inc()
         try:
             with self.m_generate.time():
                 result = await retry(self._generate)(
-                    prompt=prompt, schema=schema, **kwargs
+                    prompt=prompt,
+                    gen_schema=gen_schema,
+                    gen_cfg=gen_cfg,
+                    **kwargs,
                 )
             self.m_generate_success.inc()
 
         except Exception:
             logger.exception("Failed to call llm after retrying")
             self.m_generate_failure.inc()
@@ -100,7 +170,60 @@ class LLM:
 
         return result
 
-    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
+    def ensure_casing(self, title: str) -> str:
+        """
+        LLM takes care of word casing, but in rare cases this
+        can falter. This is a fallback to ensure the casing of
+        topics is in a proper format.
+
+        We select nouns, verbs and adjectives and check if camel
+        casing is present and fix it, if not. Will not perform
+        any other changes.
+        """
+        tokens = nltk.word_tokenize(title)
+        pos_tags = nltk.pos_tag(tokens)
+        camel_cased = []
+
+        whitelisted_pos_tags = [
+            "NN",
+            "NNS",
+            "NNP",
+            "NNPS",  # Noun POS
+            "VB",
+            "VBD",
+            "VBG",
+            "VBN",
+            "VBP",
+            "VBZ",  # Verb POS
+            "JJ",
+            "JJR",
+            "JJS",  # Adjective POS
+        ]
+
+        # If at all there is an exception, do not block other reflector
+        # processes. Return the LLM generated title, at the least.
+        try:
+            for word, pos in pos_tags:
+                if pos in whitelisted_pos_tags and word[0].islower():
+                    camel_cased.append(word[0].upper() + word[1:])
+                else:
+                    camel_cased.append(word)
+            modified_title = " ".join(camel_cased)
+
+            # The result can have words in braces with additional space.
+            # Change ( ABC ), [ ABC ], etc. ==> (ABC), [ABC], etc.
+            pattern = r"(?<=[\[\{\(])\s+|\s+(?=[\]\}\)])"
+            title = re.sub(pattern, "", modified_title)
+        except Exception as e:
+            reflector_logger.info(
+                f"Failed to ensure casing on {title=} " f"with exception : {str(e)}"
+            )
+
+        return title
+
+    async def _generate(
+        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
+    ) -> str:
         raise NotImplementedError
 
     def _parse_json(self, result: str) -> dict:
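The fallback above leans on NLTK's tokenizer and POS tagger. A standalone sketch of those mechanics (not from the diff; the tag whitelist is simplified here to tag prefixes):

```python
# Sketch of the nltk machinery behind ensure_casing: tokenize a title,
# tag parts of speech, and upper-case the first letter of nouns, verbs
# and adjectives.
import nltk

nltk.download("punkt")
nltk.download("averaged_perceptron_tagger")

title = "improving the llm pipeline"
tagged = nltk.pos_tag(nltk.word_tokenize(title))
fixed = " ".join(
    w[0].upper() + w[1:] if tag.startswith(("NN", "VB", "JJ")) and w[0].islower() else w
    for w, tag in tagged
)
print(fixed)  # roughly "Improving the Llm Pipeline"
```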
@@ -122,3 +245,62 @@ class LLM:
             result = result[:-3]
 
         return json.loads(result.strip())
+
+    def text_token_threshold(self, task_params: TaskParams | None) -> int:
+        """
+        Choose the token size to set as the threshold to pack the LLM calls
+        """
+        buffer_token_size = 25
+        default_output_tokens = 1000
+        context_window = self.tokenizer.model_max_length
+        tokens = self.tokenizer.tokenize(
+            self.create_prompt(instruct=task_params.instruct, text="")
+        )
+        threshold = context_window - len(tokens) - buffer_token_size
+        if task_params.gen_cfg:
+            threshold -= task_params.gen_cfg.max_new_tokens
+        else:
+            threshold -= default_output_tokens
+        return threshold
+
+    def split_corpus(
+        self,
+        corpus: str,
+        task_params: TaskParams,
+        token_threshold: int | None = None,
+    ) -> list[str]:
+        """
+        Split the input to the LLM due to CUDA memory limitations and LLM
+        context window restrictions.
+
+        Accumulate tokens from full sentences till threshold and yield
+        accumulated tokens. Reset accumulation when threshold is reached
+        and repeat process.
+        """
+        if not token_threshold:
+            token_threshold = self.text_token_threshold(task_params=task_params)
+
+        accumulated_tokens = []
+        accumulated_sentences = []
+        accumulated_token_count = 0
+        corpus_sentences = nltk.sent_tokenize(corpus)
+
+        for sentence in corpus_sentences:
+            tokens = self.tokenizer.tokenize(sentence)
+            if accumulated_token_count + len(tokens) <= token_threshold:
+                accumulated_token_count += len(tokens)
+                accumulated_tokens.extend(tokens)
+                accumulated_sentences.append(sentence)
+            else:
+                yield "".join(accumulated_sentences)
+                accumulated_token_count = len(tokens)
+                accumulated_tokens = tokens
+                accumulated_sentences = [sentence]
+
+        if accumulated_tokens:
+            yield " ".join(accumulated_sentences)
+
+    def create_prompt(self, instruct: str, text: str) -> str:
+        """
+        Create a consumable prompt based on the prompt template
+        """
+        return self.template.format(instruct=instruct, text=text)
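Note that split_corpus is a generator despite its list[str] annotation: it yields one chunk per iteration, each packed with whole sentences up to the computed token threshold. A sketch of the intended consumption (the names `llm`, `task_params`, `transcript_text`, and `logger` are assumed from surrounding code, not defined in this diff):

```python
# Sketch of driving the context-length packing: one LLM call per chunk.
async def summarize_in_chunks(llm, task_params, transcript_text, logger):
    results = []
    for chunk in llm.split_corpus(corpus=transcript_text, task_params=task_params):
        prompt = llm.create_prompt(instruct=task_params.instruct, text=chunk)
        results.append(await llm.generate(prompt=prompt, logger=logger))
    return results
```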
@@ -1,4 +1,5 @@
 import httpx
+
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
@@ -13,10 +14,14 @@ class BananaLLM(LLM):
         "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
     }
 
-    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+    async def _generate(
+        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
+    ):
         json_payload = {"prompt": prompt}
-        if schema:
-            json_payload["schema"] = schema
+        if gen_schema:
+            json_payload["gen_schema"] = gen_schema
+        if gen_cfg:
+            json_payload["gen_cfg"] = gen_cfg
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
@@ -27,18 +32,21 @@ class BananaLLM(LLM):
             )
             response.raise_for_status()
             text = response.json()["text"]
-            if not schema:
+            if not gen_schema:
                 text = text[len(prompt) :]
             return text
 
 
 LLM.register("banana", BananaLLM)
 
 if __name__ == "__main__":
     from reflector.logger import logger
 
     async def main():
         llm = BananaLLM()
-        result = await llm.generate("Hello, my name is")
+        prompt = llm.create_prompt(
+            instruct="Complete the following task",
+            text="Tell me a joke about programming.",
+        )
+        result = await llm.generate(prompt=prompt, logger=logger)
         print(result)
 
     import asyncio

@@ -1,11 +1,14 @@
import httpx
from transformers import AutoTokenizer, GenerationConfig

from reflector.llm.base import LLM
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry


class ModalLLM(LLM):
    def __init__(self):
    def __init__(self, model_name: str | None = None):
        super().__init__()
        self.timeout = settings.LLM_TIMEOUT
        self.llm_url = settings.LLM_URL + "/llm"

@@ -13,6 +16,16 @@ class ModalLLM(LLM):
        self.headers = {
            "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
        }
        self._set_model_name(model_name if model_name else settings.DEFAULT_LLM)

    @property
    def supported_models(self):
        """
        List of currently supported models on this GPU platform
        """
        # TODO: Query the specific GPU platform
        # Replace this with an HTTP call
        return ["lmsys/vicuna-13b-v1.5"]

    async def _warmup(self, logger):
        async with httpx.AsyncClient() as client:

@@ -23,10 +36,14 @@ class ModalLLM(LLM):
            )
            response.raise_for_status()

    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ):
        json_payload = {"prompt": prompt}
        if schema:
            json_payload["schema"] = schema
        if gen_schema:
            json_payload["gen_schema"] = gen_schema
        if gen_cfg:
            json_payload["gen_cfg"] = gen_cfg
        async with httpx.AsyncClient() as client:
            response = await retry(client.post)(
                self.llm_url,

@@ -37,10 +54,43 @@ class ModalLLM(LLM):
            )
            response.raise_for_status()
            text = response.json()["text"]
            if not schema:
                text = text[len(prompt) :]
            return text

    def _set_model_name(self, model_name: str) -> bool:
        """
        Set the model name
        """
        # Abort if the model is not supported
        if model_name not in self.supported_models:
            reflector_logger.info(
                f"Attempted to change {model_name=}, but it is not supported. "
                f"Setting model and tokenizer failed!"
            )
            return False
        # Abort if the model is already set
        elif hasattr(self, "model_name") and model_name == self._get_model_name():
            reflector_logger.info("No change in model. Setting model skipped.")
            return False
        # Update model name and tokenizer
        self.model_name = model_name
        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, cache_dir=settings.CACHE_DIR
        )
        reflector_logger.info(f"Model set to {model_name=}. Tokenizer updated.")
        return True

    def _get_tokenizer(self) -> AutoTokenizer:
        """
        Return the currently used LLM tokenizer
        """
        return self.llm_tokenizer

    def _get_model_name(self) -> str:
        """
        Return the current model name from the instance details
        """
        return self.model_name


LLM.register("modal", ModalLLM)

@@ -49,15 +99,25 @@ if __name__ == "__main__":

    async def main():
        llm = ModalLLM()
        result = await llm.generate("Hello, my name is", logger=logger)
        prompt = llm.create_prompt(
            instruct="Complete the following task",
            text="Tell me a joke about programming.",
        )
        result = await llm.generate(prompt=prompt, logger=logger)
        print(result)

        schema = {
        gen_schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "properties": {"response": {"type": "string"}},
        }

        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
        result = await llm.generate(prompt=prompt, gen_schema=gen_schema, logger=logger)
        print(result)

        gen_cfg = GenerationConfig(max_new_tokens=150)
        result = await llm.generate(
            prompt=prompt, gen_cfg=gen_cfg, gen_schema=gen_schema, logger=logger
        )
        print(result)

    import asyncio
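
The _set_model_name guard above has three outcomes: reject unsupported names,
skip no-op switches, and otherwise swap both the model name and its tokenizer.
A standalone sketch of the same control flow, with the tokenizer load stubbed
out so it runs without transformers (the class name is illustrative):

class ModelSwitcher:
    supported_models = ["lmsys/vicuna-13b-v1.5"]

    def set_model_name(self, model_name: str) -> bool:
        if model_name not in self.supported_models:
            return False  # unsupported: leave model and tokenizer untouched
        if getattr(self, "model_name", None) == model_name:
            return False  # no-op: this model is already active
        self.model_name = model_name
        # stands in for AutoTokenizer.from_pretrained(model_name, cache_dir=...)
        self.llm_tokenizer = f"tokenizer-for-{model_name}"
        return True

switcher = ModelSwitcher()
assert switcher.set_model_name("unknown/model") is False
assert switcher.set_model_name("lmsys/vicuna-13b-v1.5") is True
assert switcher.set_model_name("lmsys/vicuna-13b-v1.5") is False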

@@ -1,13 +1,21 @@
import httpx

from reflector.llm.base import LLM
from reflector.settings import settings


class OobaboogaLLM(LLM):
    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
    def __init__(self, model_name: str | None = None):
        super().__init__()

    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ):
        json_payload = {"prompt": prompt}
        if schema:
            json_payload["schema"] = schema
        if gen_schema:
            json_payload["gen_schema"] = gen_schema
        if gen_cfg:
            json_payload.update(gen_cfg)
        async with httpx.AsyncClient() as client:
            response = await client.post(
                settings.LLM_URL,

@@ -1,11 +1,13 @@
import httpx
from transformers import GenerationConfig

from reflector.llm.base import LLM
from reflector.logger import logger
from reflector.settings import settings


class OpenAILLM(LLM):
    def __init__(self, **kwargs):
    def __init__(self, model_name: str | None = None, **kwargs):
        super().__init__(**kwargs)
        self.openai_key = settings.LLM_OPENAI_KEY
        self.openai_url = settings.LLM_URL

@@ -15,7 +17,13 @@ class OpenAILLM(LLM):
        self.max_tokens = settings.LLM_MAX_TOKENS
        logger.info(f"LLM use openai backend at {self.openai_url}")

    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
    async def _generate(
        self,
        prompt: str,
        gen_schema: dict | None,
        gen_cfg: GenerationConfig | None,
        **kwargs,
    ) -> str:
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.openai_key}",
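
The hunk cuts off before the request body, so how gen_cfg is translated for
the OpenAI-style completions endpoint is not visible here; the mapping below
is an assumption for illustration only, not the commit's actual code:

from transformers import GenerationConfig

def to_openai_params(gen_cfg: GenerationConfig | None) -> dict:
    # Assumed field mapping; the names actually sent by the commit are not shown.
    if gen_cfg is None:
        return {}
    return {
        "max_tokens": gen_cfg.max_new_tokens,
        "temperature": gen_cfg.temperature,
    }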

server/reflector/llm/llm_params.py (new file, 150 lines)
@@ -0,0 +1,150 @@
from typing import Optional, TypeVar

from pydantic import BaseModel
from transformers import GenerationConfig


class TaskParams(BaseModel, arbitrary_types_allowed=True):
    instruct: str
    gen_cfg: Optional[GenerationConfig] = None
    gen_schema: Optional[dict] = None


T = TypeVar("T", bound="LLMTaskParams")


class LLMTaskParams:
    _registry = {}

    @classmethod
    def register(cls, task, klass) -> None:
        cls._registry[task] = klass

    @classmethod
    def get_instance(cls, task: str) -> T:
        return cls._registry[task]()

    @property
    def task_params(self) -> TaskParams | None:
        """
        Fetch the task-related parameters
        """
        return self._get_task_params()

    def _get_task_params(self) -> None:
        pass


class FinalLongSummaryParams(LLMTaskParams):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
        )
        self._instruct = """
        Take the key ideas and takeaways from the text and create a short
        summary. Be sure to keep the length of the response to a minimum.
        Do not include trivial information in the summary.
        """
        self._schema = {
            "type": "object",
            "properties": {"long_summary": {"type": "string"}},
        }
        self._task_params = TaskParams(
            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
        )

    def _get_task_params(self) -> TaskParams:
        """
        Return the parameters associated with a specific LLM task
        """
        return self._task_params


class FinalShortSummaryParams(LLMTaskParams):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=1300, num_beams=3, do_sample=True, temperature=0.3
        )
        self._instruct = """
        Take the key ideas and takeaways from the text and create a short
        summary. Be sure to keep the length of the response to a minimum.
        Do not include trivial information in the summary.
        """
        self._schema = {
            "type": "object",
            "properties": {"short_summary": {"type": "string"}},
        }
        self._task_params = TaskParams(
            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
        )

    def _get_task_params(self) -> TaskParams:
        """
        Return the parameters associated with a specific LLM task
        """
        return self._task_params


class FinalTitleParams(LLMTaskParams):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
        )
        self._instruct = """
        Combine the following individual titles into one single short title that
        condenses the essence of all titles.
        """
        self._schema = {
            "type": "object",
            "properties": {"title": {"type": "string"}},
        }
        self._task_params = TaskParams(
            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
        )

    def _get_task_params(self) -> TaskParams:
        """
        Return the parameters associated with a specific LLM task
        """
        return self._task_params


class TopicParams(LLMTaskParams):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=550, num_beams=6, do_sample=True, temperature=0.9
        )
        self._instruct = """
        Create a JSON object as response. The JSON object must have 2 fields:
        i) title and ii) summary.
        For the title field, generate a very detailed and self-explanatory
        title for the given text. Let the title be as descriptive as possible.
        For the summary field, summarize the given text in a maximum of
        three sentences.
        """
        self._schema = {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
                "summary": {"type": "string"},
            },
        }
        self._task_params = TaskParams(
            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
        )

    def _get_task_params(self) -> TaskParams:
        """
        Return the parameters associated with a specific LLM task
        """
        return self._task_params


LLMTaskParams.register("topic", TopicParams)
LLMTaskParams.register("final_title", FinalTitleParams)
LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
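
Consumers look a task up by name and read the bundled prompt, schema and
generation config from it; this is exactly how the processors further down
use the module:

from reflector.llm import LLMTaskParams

params = LLMTaskParams.get_instance("topic").task_params
print(params.instruct)    # the topic prompt
print(params.gen_schema)  # {"type": "object", "properties": {"title": ..., "summary": ...}}
print(params.gen_cfg)     # GenerationConfig(max_new_tokens=550, num_beams=6, ...)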

@@ -4,7 +4,20 @@ from .audio_merge import AudioMergeProcessor  # noqa: F401
from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
from .audio_transcript_auto import AudioTranscriptAutoProcessor  # noqa: F401
from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor  # noqa: F401
from .transcript_final_summary import TranscriptFinalSummaryProcessor  # noqa: F401
from .transcript_final_long_summary import (  # noqa: F401
    TranscriptFinalLongSummaryProcessor,
)
from .transcript_final_short_summary import (  # noqa: F401
    TranscriptFinalShortSummaryProcessor,
)
from .transcript_final_title import TranscriptFinalTitleProcessor  # noqa: F401
from .transcript_liner import TranscriptLinerProcessor  # noqa: F401
from .transcript_topic_detector import TranscriptTopicDetectorProcessor  # noqa: F401
from .types import AudioFile, FinalSummary, TitleSummary, Transcript, Word  # noqa: F401
from .types import (  # noqa: F401
    AudioFile,
    FinalLongSummary,
    FinalShortSummary,
    TitleSummary,
    Transcript,
    Word,
)

@@ -5,6 +5,7 @@ from uuid import uuid4

from prometheus_client import Counter, Gauge, Histogram
from pydantic import BaseModel

from reflector.logger import logger

@@ -296,7 +297,7 @@ class BroadcastProcessor(Processor):
    types of input.
    """

    def __init__(self, processors: Processor):
    def __init__(self, processors: list[Processor]):
        super().__init__()
        self.processors = processors
        self.INPUT_TYPE = processors[0].INPUT_TYPE
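
For intuition, BroadcastProcessor fans a single input out to every child
processor, which is why it can adopt the INPUT_TYPE of its first child. A
minimal, dependency-free sketch of that behaviour (the class names here are
illustrative, not the real Processor API):

import asyncio

class Broadcast:
    def __init__(self, processors: list):
        self.processors = processors

    async def push(self, data):
        for processor in self.processors:  # same item delivered to each child
            await processor.push(data)

class Collector:
    def __init__(self):
        self.items = []

    async def push(self, data):
        self.items.append(data)

a, b = Collector(), Collector()
asyncio.run(Broadcast([a, b]).push("title-summary"))
assert a.items == b.items == ["title-summary"]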

server/reflector/processors/transcript_final_long_summary.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import FinalLongSummary, TitleSummary


class TranscriptFinalLongSummaryProcessor(Processor):
    """
    Get the final long summary
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = FinalLongSummary
    TASK = "final_long_summary"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.chunks: list[TitleSummary] = []
        self.llm = LLM.get_instance()
        self.params = LLMTaskParams.get_instance(self.TASK).task_params

    async def _push(self, data: TitleSummary):
        self.chunks.append(data)

    async def get_long_summary(self, text: str) -> str:
        """
        Generate a long version of the final summary
        """
        self.logger.info(f"Smoothing out {len(text)} length summary to a long summary")
        chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))

        accumulated_summaries = ""
        for chunk in chunks:
            prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
            summary_result = await self.llm.generate(
                prompt=prompt,
                gen_schema=self.params.gen_schema,
                gen_cfg=self.params.gen_cfg,
                logger=self.logger,
            )
            accumulated_summaries += summary_result["long_summary"]

        return accumulated_summaries

    async def _flush(self):
        if not self.chunks:
            self.logger.warning("No summary to output")
            return

        accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks])
        long_summary = await self.get_long_summary(accumulated_summaries)

        last_chunk = self.chunks[-1]
        duration = last_chunk.timestamp + last_chunk.duration

        final_long_summary = FinalLongSummary(
            long_summary=long_summary,
            duration=duration,
        )
        await self.emit(final_long_summary)
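
split_corpus lives in the LLM base class and is not part of this file's hunks;
assuming it yields chunks sized to the model's context window, the long-summary
pass above is a single map over the chunks with in-order concatenation.
Reduced to its shape:

async def long_summary(chunks: list[str], summarize) -> str:
    # `summarize` stands in for the schema-constrained llm.generate call above.
    accumulated = ""
    for chunk in chunks:
        result = await summarize(chunk)  # -> {"long_summary": "..."}
        accumulated += result["long_summary"]
    return accumulated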

server/reflector/processors/transcript_final_short_summary.py (new file, 72 lines)

@@ -0,0 +1,72 @@
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import FinalShortSummary, TitleSummary


class TranscriptFinalShortSummaryProcessor(Processor):
    """
    Get the final summary using a tree summarizer
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = FinalShortSummary
    TASK = "final_short_summary"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.chunks: list[TitleSummary] = []
        self.llm = LLM.get_instance()
        self.params = LLMTaskParams.get_instance(self.TASK).task_params

    async def _push(self, data: TitleSummary):
        self.chunks.append(data)

    async def get_short_summary(self, text: str) -> dict:
        """
        Generate a short summary using the tree summarizer
        """
        self.logger.info(f"Smoothing out {len(text)} length summary to a short summary")
        chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))

        if len(chunks) == 1:
            chunk = chunks[0]
            prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
            summary_result = await self.llm.generate(
                prompt=prompt,
                gen_schema=self.params.gen_schema,
                gen_cfg=self.params.gen_cfg,
                logger=self.logger,
            )
            return summary_result
        else:
            accumulated_summaries = ""
            for chunk in chunks:
                prompt = self.llm.create_prompt(
                    instruct=self.params.instruct, text=chunk
                )
                summary_result = await self.llm.generate(
                    prompt=prompt,
                    gen_schema=self.params.gen_schema,
                    gen_cfg=self.params.gen_cfg,
                    logger=self.logger,
                )
                accumulated_summaries += summary_result["short_summary"]

            return await self.get_short_summary(accumulated_summaries)

    async def _flush(self):
        if not self.chunks:
            self.logger.warning("No summary to output")
            return

        accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks])
        short_summary_result = await self.get_short_summary(accumulated_summaries)

        last_chunk = self.chunks[-1]
        duration = last_chunk.timestamp + last_chunk.duration

        final_summary = FinalShortSummary(
            short_summary=short_summary_result["short_summary"],
            duration=duration,
        )
        await self.emit(final_summary)
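
The short-summary path above is a tree reduction: if the text fits in one
chunk, summarize it once; otherwise summarize each chunk and recurse on the
concatenation until it does fit. A runnable sketch of that recursion, with a
fake splitter and model standing in for split_corpus and llm.generate:

import asyncio

async def tree_summary(text: str, split, summarize) -> dict:
    chunks = split(text)
    if len(chunks) == 1:
        return await summarize(chunks[0])
    merged = ""
    for chunk in chunks:
        result = await summarize(chunk)
        merged += result["short_summary"]
    # each pass shrinks the text, so the recursion terminates
    return await tree_summary(merged, split, summarize)

split = lambda text: [text[i : i + 100] for i in range(0, len(text), 100)] or [""]
summarize = lambda chunk: asyncio.sleep(0, {"short_summary": chunk[:10]})
print(asyncio.run(tree_summary("x" * 500, split, summarize)))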

@@ -1,30 +0,0 @@
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, FinalSummary


class TranscriptFinalSummaryProcessor(Processor):
    """
    Assemble all summary into a line-based json
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = FinalSummary

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.chunks: list[TitleSummary] = []

    async def _push(self, data: TitleSummary):
        self.chunks.append(data)

    async def _flush(self):
        if not self.chunks:
            self.logger.warning("No summary to output")
            return

        # FIXME improve final summary
        result = "\n".join([chunk.summary for chunk in self.chunks])
        last_chunk = self.chunks[-1]
        duration = last_chunk.timestamp + last_chunk.duration

        await self.emit(FinalSummary(summary=result, duration=duration))

server/reflector/processors/transcript_final_title.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import FinalTitle, TitleSummary


class TranscriptFinalTitleProcessor(Processor):
    """
    Assemble all summary into a line-based json
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = FinalTitle
    TASK = "final_title"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.chunks: list[TitleSummary] = []
        self.llm = LLM.get_instance()
        self.params = LLMTaskParams.get_instance(self.TASK).task_params

    async def _push(self, data: TitleSummary):
        self.chunks.append(data)

    async def get_title(self, text: str) -> dict:
        """
        Generate a title for the whole recording
        """
        chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))

        if len(chunks) == 1:
            chunk = chunks[0]
            prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
            title_result = await self.llm.generate(
                prompt=prompt,
                gen_schema=self.params.gen_schema,
                gen_cfg=self.params.gen_cfg,
                logger=self.logger,
            )
            return title_result
        else:
            accumulated_titles = ""
            for chunk in chunks:
                prompt = self.llm.create_prompt(
                    instruct=self.params.instruct, text=chunk
                )
                title_result = await self.llm.generate(
                    prompt=prompt,
                    gen_schema=self.params.gen_schema,
                    gen_cfg=self.params.gen_cfg,
                    logger=self.logger,
                )
                accumulated_titles += title_result["summary"]

            return await self.get_title(accumulated_titles)

    async def _flush(self):
        if not self.chunks:
            self.logger.warning("No summary to output")
            return

        accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
        title_result = await self.get_title(accumulated_titles)

        final_title = FinalTitle(title=title_result["title"])
        await self.emit(final_title)

@@ -1,7 +1,6 @@
from reflector.llm import LLM
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, Transcript
from reflector.utils.retry import retry


class TranscriptTopicDetectorProcessor(Processor):

@@ -11,34 +10,14 @@ class TranscriptTopicDetectorProcessor(Processor):

    INPUT_TYPE = Transcript
    OUTPUT_TYPE = TitleSummary
    TASK = "topic"

    PROMPT = """
    ### Human:
    Create a JSON object as response. The JSON object must have 2 fields:
    i) title and ii) summary.

    For the title field, generate a short title for the given text.
    For the summary field, summarize the given text in a maximum of
    three sentences.

    {input_text}

    ### Assistant:

    """

    def __init__(self, min_transcript_length=750, **kwargs):
    def __init__(self, min_transcript_length: int = 750, **kwargs):
        super().__init__(**kwargs)
        self.transcript = None
        self.min_transcript_length = min_transcript_length
        self.llm = LLM.get_instance()
        self.topic_detector_schema = {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
                "summary": {"type": "string"},
            },
        }
        self.params = LLMTaskParams.get_instance(self.TASK).task_params

    async def _warmup(self):
        await self.llm.warmup(logger=self.logger)

@@ -55,18 +34,30 @@ class TranscriptTopicDetectorProcessor(Processor):
            return
        await self.flush()

    async def get_topic(self, text: str) -> dict:
        """
        Generate a topic and description for a transcription excerpt
        """
        prompt = self.llm.create_prompt(instruct=self.params.instruct, text=text)
        topic_result = await self.llm.generate(
            prompt=prompt,
            gen_schema=self.params.gen_schema,
            gen_cfg=self.params.gen_cfg,
            logger=self.logger,
        )
        return topic_result

    async def _flush(self):
        if not self.transcript:
            return

        text = self.transcript.text
        self.logger.info(f"Topic detector got {len(text)} length transcript")
        prompt = self.PROMPT.format(input_text=text)
        result = await retry(self.llm.generate)(
            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
        )
        topic_result = await self.get_topic(text=text)

        summary = TitleSummary(
            title=result["title"],
            summary=result["summary"],
            title=self.llm.ensure_casing(topic_result["title"]),
            summary=topic_result["summary"],
            timestamp=self.transcript.timestamp,
            duration=self.transcript.duration,
            transcript=self.transcript,
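
ensure_casing comes from the LLM base class and its implementation is not in
this diff; judging by the name and its use on titles only, a plausible reading
is that it normalizes all-lowercase or all-uppercase model output to title
case. A hypothetical stand-in, for illustration only:

def ensure_casing(title: str) -> str:
    if title.islower() or title.isupper():
        return title.title()  # normalize shouting or lowercase titles
    return title

assert ensure_casing("a meeting about budgets") == "A Meeting About Budgets"
assert ensure_casing("Budget Review") == "Budget Review"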

@@ -103,11 +103,20 @@ class TitleSummary(BaseModel):
        return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"


class FinalSummary(BaseModel):
    summary: str
class FinalLongSummary(BaseModel):
    long_summary: str
    duration: float


class FinalShortSummary(BaseModel):
    short_summary: str
    duration: float


class FinalTitle(BaseModel):
    title: str


class TranslationLanguages(BaseModel):
    language_to_id_mapping: dict = {
        "Afrikaans": "af",

@@ -91,5 +91,11 @@ class Settings(BaseSettings):
    # if set, all anonymous records will be public
    PUBLIC_MODE: bool = False

    # Default LLM model name
    DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5"

    # Cache directory for all model storage
    CACHE_DIR: str = "data"


settings = Settings()

@@ -1,6 +1,7 @@
import asyncio

import av

from reflector.logger import logger
from reflector.processors import (
    AudioChunkerProcessor,

@@ -8,10 +9,13 @@ from reflector.processors import (
    AudioTranscriptAutoProcessor,
    Pipeline,
    PipelineEvent,
    TranscriptFinalSummaryProcessor,
    TranscriptFinalLongSummaryProcessor,
    TranscriptFinalShortSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptLinerProcessor,
    TranscriptTopicDetectorProcessor,
)
from reflector.processors.base import BroadcastProcessor


async def process_audio_file(

@@ -31,7 +35,13 @@ async def process_audio_file(
    if not only_transcript:
        processors += [
            TranscriptTopicDetectorProcessor.as_threaded(),
            TranscriptFinalSummaryProcessor.as_threaded(),
            BroadcastProcessor(
                processors=[
                    TranscriptFinalTitleProcessor.as_threaded(),
                    TranscriptFinalLongSummaryProcessor.as_threaded(),
                    TranscriptFinalShortSummaryProcessor.as_threaded(),
                ],
            ),
        ]

    # transcription output
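
Assembled from the imports and the processor list above, the offline tool's
pipeline now ends in a fan-out, so the final title and both summaries are all
derived from the same stream of TitleSummary items (a sketch; the audio
chunker and transcript processor arguments are elided in the diff):

pipeline = Pipeline(
    AudioChunkerProcessor(),
    AudioTranscriptAutoProcessor.as_threaded(),
    TranscriptLinerProcessor(),
    TranscriptTopicDetectorProcessor.as_threaded(),
    BroadcastProcessor(
        processors=[
            TranscriptFinalTitleProcessor.as_threaded(),
            TranscriptFinalLongSummaryProcessor.as_threaded(),
            TranscriptFinalShortSummaryProcessor.as_threaded(),
        ],
    ),
)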

@@ -8,6 +8,7 @@ from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from fastapi import APIRouter, Request
from prometheus_client import Gauge
from pydantic import BaseModel

from reflector.events import subscribers_shutdown
from reflector.logger import logger
from reflector.processors import (

@@ -15,14 +16,19 @@ from reflector.processors import (
    AudioFileWriterProcessor,
    AudioMergeProcessor,
    AudioTranscriptAutoProcessor,
    FinalSummary,
    FinalLongSummary,
    FinalShortSummary,
    Pipeline,
    TitleSummary,
    Transcript,
    TranscriptFinalSummaryProcessor,
    TranscriptFinalLongSummaryProcessor,
    TranscriptFinalShortSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptLinerProcessor,
    TranscriptTopicDetectorProcessor,
)
from reflector.processors.base import BroadcastProcessor
from reflector.processors.types import FinalTitle

sessions = []
router = APIRouter()

@@ -72,8 +78,10 @@ class StrValue(BaseModel):
class PipelineEvent(StrEnum):
    TRANSCRIPT = "TRANSCRIPT"
    TOPIC = "TOPIC"
    FINAL_SUMMARY = "FINAL_SUMMARY"
    FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
    STATUS = "STATUS"
    FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
    FINAL_TITLE = "FINAL_TITLE"


async def rtc_offer_base(

@@ -124,15 +132,15 @@ async def rtc_offer_base(
            data=transcript,
        )

    async def on_topic(summary: TitleSummary):
    async def on_topic(topic: TitleSummary):
        # FIXME: make it incremental with the frontend, not send everything
        ctx.logger.info("Summary", summary=summary)
        ctx.logger.info("Topic", topic=topic)
        ctx.topics.append(
            {
                "title": summary.title,
                "timestamp": summary.timestamp,
                "transcript": summary.transcript.text,
                "desc": summary.summary,
                "title": topic.title,
                "timestamp": topic.timestamp,
                "transcript": topic.transcript.text,
                "desc": topic.summary,
            }
        )

@@ -144,17 +152,17 @@ async def rtc_offer_base(
        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
                event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
            )

    async def on_final_summary(summary: FinalSummary):
        ctx.logger.info("FinalSummary", final_summary=summary)
    async def on_final_short_summary(summary: FinalShortSummary):
        ctx.logger.info("FinalShortSummary", final_short_summary=summary)

        # send to RTC
        if ctx.data_channel.readyState == "open":
            result = {
                "cmd": "DISPLAY_FINAL_SUMMARY",
                "summary": summary.summary,
                "cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
                "summary": summary.short_summary,
                "duration": summary.duration,
            }
            ctx.data_channel.send(dumps(result))

@@ -162,11 +170,47 @@ async def rtc_offer_base(
        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.FINAL_SUMMARY,
                event=PipelineEvent.FINAL_SHORT_SUMMARY,
                args=event_callback_args,
                data=summary,
            )

    async def on_final_long_summary(summary: FinalLongSummary):
        ctx.logger.info("FinalLongSummary", final_summary=summary)

        # send to RTC
        if ctx.data_channel.readyState == "open":
            result = {
                "cmd": "DISPLAY_FINAL_LONG_SUMMARY",
                "summary": summary.long_summary,
                "duration": summary.duration,
            }
            ctx.data_channel.send(dumps(result))

        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.FINAL_LONG_SUMMARY,
                args=event_callback_args,
                data=summary,
            )

    async def on_final_title(title: FinalTitle):
        ctx.logger.info("FinalTitle", final_title=title)

        # send to RTC
        if ctx.data_channel.readyState == "open":
            result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
            ctx.data_channel.send(dumps(result))

        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.FINAL_TITLE,
                args=event_callback_args,
                data=title,
            )

    # create a context for the whole rtc transaction
    # add a customised logger to the context
    processors = []

@@ -178,7 +222,17 @@ async def rtc_offer_base(
        AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
        TranscriptLinerProcessor(),
        TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
        TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
        BroadcastProcessor(
            processors=[
                TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
                TranscriptFinalLongSummaryProcessor.as_threaded(
                    callback=on_final_long_summary
                ),
                TranscriptFinalShortSummaryProcessor.as_threaded(
                    callback=on_final_short_summary
                ),
            ]
        ),
    ]
    ctx.pipeline = Pipeline(*processors)
    ctx.pipeline.set_pref("audio:source_language", source_language)
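
The three on_final_* callbacks above share one shape: push a DISPLAY_* command
over the RTC data channel when it is open, then forward the typed event to the
websocket callback. Condensed into a single self-contained helper (a sketch;
the real file keeps the three callbacks separate and closes over ctx):

import json

async def forward(event, cmd, payload, data, data_channel=None,
                  event_callback=None, event_callback_args=None):
    # RTC path: mirrors ctx.data_channel.send(dumps(result)) in the diff
    if data_channel is not None and data_channel.readyState == "open":
        data_channel.send(json.dumps({"cmd": cmd, **payload}))
    # websocket path: mirrors the event_callback(...) calls in the diff
    if event_callback:
        await event_callback(event=event, args=event_callback_args, data=data)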

@@ -7,7 +7,6 @@ from typing import Annotated, Optional
from uuid import uuid4

import av
import reflector.auth as auth
from fastapi import (
    APIRouter,
    Depends,

@@ -18,11 +17,13 @@ from fastapi import (
)
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool

import reflector.auth as auth
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool

from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base

@@ -60,8 +61,16 @@ class TranscriptTopic(BaseModel):
    timestamp: float


class TranscriptFinalSummary(BaseModel):
    summary: str
class TranscriptFinalShortSummary(BaseModel):
    short_summary: str


class TranscriptFinalLongSummary(BaseModel):
    long_summary: str


class TranscriptFinalTitle(BaseModel):
    title: str


class TranscriptEvent(BaseModel):

@@ -77,7 +86,9 @@ class Transcript(BaseModel):
    locked: bool = False
    duration: float = 0
    created_at: datetime = Field(default_factory=datetime.utcnow)
    summary: str | None = None
    title: str | None = None
    short_summary: str | None = None
    long_summary: str | None = None
    topics: list[TranscriptTopic] = []
    events: list[TranscriptEvent] = []
    source_language: str = "en"

@@ -241,7 +252,9 @@ class GetTranscript(BaseModel):
    status: str
    locked: bool
    duration: int
    summary: str | None
    title: str | None
    short_summary: str | None
    long_summary: str | None
    created_at: datetime
    source_language: str
    target_language: str

@@ -256,7 +269,9 @@ class CreateTranscript(BaseModel):
class UpdateTranscript(BaseModel):
    name: Optional[str] = Field(None)
    locked: Optional[bool] = Field(None)
    summary: Optional[str] = Field(None)
    title: Optional[str] = Field(None)
    short_summary: Optional[str] = Field(None)
    long_summary: Optional[str] = Field(None)


class DeletionStatus(BaseModel):

@@ -315,20 +330,32 @@ async def transcript_update(
    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")
    values = {}
    values = {"events": []}
    if info.name is not None:
        values["name"] = info.name
    if info.locked is not None:
        values["locked"] = info.locked
    if info.summary is not None:
        values["summary"] = info.summary
        # also find FINAL_SUMMARY event and patch it
        for te in transcript.events:
            if te["event"] == PipelineEvent.FINAL_SUMMARY:
                te["summary"] = info.summary
    if info.long_summary is not None:
        values["long_summary"] = info.long_summary
        for transcript_event in transcript.events:
            if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
                transcript_event["long_summary"] = info.long_summary
                break
        values["events"] = transcript.events
        values["events"].extend(transcript.events)
    if info.short_summary is not None:
        values["short_summary"] = info.short_summary
        for transcript_event in transcript.events:
            if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
                transcript_event["short_summary"] = info.short_summary
                break
        values["events"].extend(transcript.events)
    if info.title is not None:
        values["title"] = info.title
        for transcript_event in transcript.events:
            if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
                transcript_event["title"] = info.title
                break
        values["events"].extend(transcript.events)
    await transcripts_controller.update(transcript, values)
    return transcript

@@ -539,14 +566,38 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
        },
    )

    elif event == PipelineEvent.FINAL_SUMMARY:
        final_summary = TranscriptFinalSummary(summary=data.summary)
        resp = transcript.add_event(event=event, data=final_summary)
    elif event == PipelineEvent.FINAL_TITLE:
        final_title = TranscriptFinalTitle(title=data.title)
        resp = transcript.add_event(event=event, data=final_title)
        await transcripts_controller.update(
            transcript,
            {
                "events": transcript.events_dump(),
                "summary": final_summary.summary,
                "title": final_title.title,
            },
        )

    elif event == PipelineEvent.FINAL_LONG_SUMMARY:
        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
        resp = transcript.add_event(event=event, data=final_long_summary)
        await transcripts_controller.update(
            transcript,
            {
                "events": transcript.events_dump(),
                "long_summary": final_long_summary.long_summary,
            },
        )

    elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
        final_short_summary = TranscriptFinalShortSummary(
            short_summary=data.short_summary
        )
        resp = transcript.add_event(event=event, data=final_short_summary)
        await transcripts_controller.update(
            transcript,
            {
                "events": transcript.events_dump(),
                "short_summary": final_short_summary.short_summary,
            },
        )
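
With the split fields in place, a client patches short_summary, long_summary
and title independently, and the matching FINAL_* events are patched along
with them. An httpx sketch against a locally running server (the base URL and
transcript id are placeholders; mirrors the tests below):

import httpx

response = httpx.patch(
    "http://localhost:8000/v1/transcripts/<transcript-id>",
    json={
        "title": "Weekly sync",
        "short_summary": "Quick recap of the decisions.",
        "long_summary": "Full narrative of the discussion...",
    },
)
response.raise_for_status()
print(response.json()["short_summary"])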

@@ -1,3 +1,5 @@
from unittest.mock import patch

import pytest

@@ -14,3 +16,50 @@ async def setup_database():
    metadata.create_all(bind=engine)

    yield


@pytest.fixture
def dummy_processors():
    with patch(
        "reflector.processors.transcript_topic_detector.TranscriptTopicDetectorProcessor.get_topic"
    ) as mock_topic, patch(
        "reflector.processors.transcript_final_title.TranscriptFinalTitleProcessor.get_title"
    ) as mock_title, patch(
        "reflector.processors.transcript_final_long_summary.TranscriptFinalLongSummaryProcessor.get_long_summary"
    ) as mock_long_summary, patch(
        "reflector.processors.transcript_final_short_summary.TranscriptFinalShortSummaryProcessor.get_short_summary"
    ) as mock_short_summary:
        mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
        mock_title.return_value = {"title": "LLM FINAL TITLE"}
        mock_long_summary.return_value = "LLM LONG SUMMARY"
        mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}

        yield mock_topic, mock_title, mock_long_summary, mock_short_summary


@pytest.fixture
async def dummy_llm():
    from reflector.llm.base import LLM

    class TestLLM(LLM):
        def __init__(self):
            self.model_name = "DUMMY MODEL"
            self.llm_tokenizer = "DUMMY TOKENIZER"

    with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
        mock_llm.return_value = TestLLM()
        yield


@pytest.fixture
def nltk():
    with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
        mock_nltk.return_value = "NLTK PACKAGE"
        yield


@pytest.fixture
def ensure_casing():
    with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
        mock_casing.return_value = "LLM TITLE"
        yield

@@ -2,7 +2,7 @@ import pytest


@pytest.mark.asyncio
async def test_processor_broadcast():
async def test_processor_broadcast(nltk):
    from reflector.processors.base import Processor, BroadcastProcessor, Pipeline

    class TestProcessor(Processor):

@@ -2,27 +2,19 @@ import pytest


@pytest.mark.asyncio
async def test_basic_process(event_loop):
async def test_basic_process(
    event_loop, nltk, dummy_llm, dummy_processors, ensure_casing
):
    # goal is to start the server, and send rtc audio to it
    # validate the events received
    from reflector.tools.process import process_audio_file
    from reflector.settings import settings
    from reflector.llm.base import LLM
    from pathlib import Path

    # use an LLM test backend
    settings.LLM_BACKEND = "test"
    settings.TRANSCRIPT_BACKEND = "whisper"

    class LLMTest(LLM):
        async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
            return {
                "title": "TITLE",
                "summary": "SUMMARY",
            }

    LLM.register("test", LLMTest)

    # event callback
    marks = {}

@@ -39,4 +31,6 @@ async def test_basic_process(event_loop):
    # validate the events
    assert marks["TranscriptLinerProcessor"] == 5
    assert marks["TranscriptTopicDetectorProcessor"] == 1
    assert marks["TranscriptFinalSummaryProcessor"] == 1
    assert marks["TranscriptFinalLongSummaryProcessor"] == 1
    assert marks["TranscriptFinalShortSummaryProcessor"] == 1
    assert marks["TranscriptFinalTitleProcessor"] == 1

@@ -75,21 +75,52 @@ async def test_transcript_get_update_summary():
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.post("/transcripts", json={"name": "test"})
        assert response.status_code == 200
        assert response.json()["summary"] is None
        assert response.json()["long_summary"] is None
        assert response.json()["short_summary"] is None

        tid = response.json()["id"]

        response = await ac.get(f"/transcripts/{tid}")
        assert response.status_code == 200
        assert response.json()["summary"] is None
        assert response.json()["long_summary"] is None
        assert response.json()["short_summary"] is None

        response = await ac.patch(f"/transcripts/{tid}", json={"summary": "test"})
        response = await ac.patch(
            f"/transcripts/{tid}",
            json={"long_summary": "test_long", "short_summary": "test_short"},
        )
        assert response.status_code == 200
        assert response.json()["summary"] == "test"
        assert response.json()["long_summary"] == "test_long"
        assert response.json()["short_summary"] == "test_short"

        response = await ac.get(f"/transcripts/{tid}")
        assert response.status_code == 200
        assert response.json()["summary"] == "test"
        assert response.json()["long_summary"] == "test_long"
        assert response.json()["short_summary"] == "test_short"


@pytest.mark.asyncio
async def test_transcript_get_update_title():
    from reflector.app import app

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.post("/transcripts", json={"name": "test"})
        assert response.status_code == 200
        assert response.json()["title"] is None

        tid = response.json()["id"]

        response = await ac.get(f"/transcripts/{tid}")
        assert response.status_code == 200
        assert response.json()["title"] is None

        response = await ac.patch(f"/transcripts/{tid}", json={"title": "test_title"})
        assert response.status_code == 200
        assert response.json()["title"] == "test_title"

        response = await ac.get(f"/transcripts/{tid}")
        assert response.status_code == 200
        assert response.json()["title"] == "test_title"


@pytest.mark.asyncio

@@ -67,21 +67,10 @@ async def dummy_transcript():
    yield


@pytest.fixture
async def dummy_llm():
    from reflector.llm.base import LLM

    class TestLLM(LLM):
        async def _generate(self, prompt: str, schema: dict | None, **kwargs):
            return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})

    with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
        mock_llm.return_value = TestLLM()
        yield


@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm):
async def test_transcript_rtc_and_websocket(
    tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
):
    # goal: start the server, exchange RTC, receive websocket events
    # because of that, we need to start the server in a thread
    # to be able to connect with aiortc

@@ -186,9 +175,17 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
    assert ev["data"]["transcript"].startswith("Hello world")
    assert ev["data"]["timestamp"] == 0.0

    assert "FINAL_SUMMARY" in eventnames
    ev = events[eventnames.index("FINAL_SUMMARY")]
    assert ev["data"]["summary"] == "LLM SUMMARY"
    assert "FINAL_LONG_SUMMARY" in eventnames
    ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
    assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"

    assert "FINAL_SHORT_SUMMARY" in eventnames
    ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
    assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"

    assert "FINAL_TITLE" in eventnames
    ev = events[eventnames.index("FINAL_TITLE")]
    assert ev["data"]["title"] == "LLM FINAL TITLE"

    # check status order
    statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]

@@ -218,7 +215,9 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)


@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dummy_llm):
async def test_transcript_rtc_and_websocket_and_fr(
    tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
):
    # goal: start the server, exchange RTC, receive websocket events
    # because of that, we need to start the server in a thread
    # to be able to connect with aiortc

@@ -326,9 +325,17 @@ async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dum
    assert ev["data"]["transcript"].startswith("Hello world")
    assert ev["data"]["timestamp"] == 0.0

    assert "FINAL_SUMMARY" in eventnames
    ev = events[eventnames.index("FINAL_SUMMARY")]
    assert ev["data"]["summary"] == "LLM SUMMARY"
    assert "FINAL_LONG_SUMMARY" in eventnames
    ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
    assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"

    assert "FINAL_SHORT_SUMMARY" in eventnames
    ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
    assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"

    assert "FINAL_TITLE" in eventnames
    ev = events[eventnames.index("FINAL_TITLE")]
    assert ev["data"]["title"] == "LLM FINAL TITLE"

    # check status order
    statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]

@@ -94,6 +94,13 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
        break;

      case "FINAL_LONG_SUMMARY":
        if (message.data) {
          message.data = { summary: message.data.long_summary };
          setFinalSummary(message.data);
          console.debug("FINAL_LONG_SUMMARY event:", message.data);
        }
        break;

      case "FINAL_SUMMARY":
        if (message.data) {
          setFinalSummary(message.data);