diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3dcbe202..19d82472 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,12 +18,6 @@ repos: exclude: ^server/trials - id: detect-private-key - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.282 - hooks: - - id: ruff - files: ^server/(reflector|tests)/ - - repo: https://github.com/psf/black rev: 23.1.0 hooks: @@ -36,4 +30,10 @@ repos: - id: isort name: isort (python) files: ^server/(gpu|evaluate|reflector)/ - args: ["--profile", "black", "--filter-files"] + args: [ "--profile", "black", "--filter-files" ] + + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.0.282 + hooks: + - id: ruff + files: ^server/(reflector|tests)/ diff --git a/server/env.example b/server/env.example index 5c91b9d2..4e9b7311 100644 --- a/server/env.example +++ b/server/env.example @@ -94,6 +94,12 @@ #LLM_URL=http://localhost:4891/v1/completions #LLM_OPENAI_MODEL="GPT4All Falcon" +## Default LLM MODEL NAME +DEFAULT_LLM=lmsys/vicuna-13b-v1.5 + +## Cache directory to store models +CACHE_DIR=data + ## ======================================================= ## Sentry ## ======================================================= diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py index 9e20ff00..a4e88aae 100644 --- a/server/gpu/modal/reflector_llm.py +++ b/server/gpu/modal/reflector_llm.py @@ -55,7 +55,7 @@ llm_image = ( "accelerate==0.21.0", "einops==0.6.1", "hf-transfer~=0.1", - "huggingface_hub==0.16.4", + "huggingface_hub==0.16.4" ) .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) .run_function(download_llm) @@ -73,8 +73,7 @@ llm_image = ( class LLM: def __enter__(self): import torch - from transformers import AutoModelForCausalLM, AutoTokenizer - from transformers.generation import GenerationConfig + from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig print("Instance llm model") model = AutoModelForCausalLM.from_pretrained( @@ -84,10 +83,11 @@ class LLM: cache_dir=IMAGE_MODEL_DIR ) - # generation configuration + # JSONFormer doesn't yet support generation configs print("Instance llm generation config") - # JSONFormer doesn't yet support generation configs, but keeping for future usage model.config.max_new_tokens = LLM_MAX_NEW_TOKENS + + # generation configuration gen_cfg = GenerationConfig.from_model_config(model.config) gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS @@ -106,6 +106,7 @@ class LLM: self.model = model self.tokenizer = tokenizer self.gen_cfg = gen_cfg + self.GenerationConfig = GenerationConfig def __exit__(self, *args): print("Exit llm") @@ -116,34 +117,44 @@ class LLM: return {"status": "ok"} @method() - def generate(self, prompt: str, schema: str = None): + def generate(self, prompt: str, gen_schema: str | None, gen_cfg: str | None) -> dict: + """ + Perform a generation action using the LLM + """ print(f"Generate {prompt=}") - # If a schema is given, conform to schema - if schema: - print(f"Schema {schema=}") + if gen_cfg: + gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg)) + else: + gen_cfg = self.gen_cfg + + # If a gen_schema is given, conform to gen_schema + if gen_schema: import jsonformer - jsonformer_llm = jsonformer.Jsonformer(model=self.model, - tokenizer=self.tokenizer, - json_schema=json.loads(schema), - prompt=prompt, - max_string_token_length=self.gen_cfg.max_new_tokens) + print(f"Schema {gen_schema=}") + jsonformer_llm = jsonformer.Jsonformer( + model=self.model, + tokenizer=self.tokenizer, + 
json_schema=json.loads(gen_schema), + prompt=prompt, + max_string_token_length=gen_cfg.max_new_tokens + ) response = jsonformer_llm() else: - # If no schema, perform prompt only generation + # If no gen_schema, perform prompt only generation # tokenize prompt input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to( self.model.device ) - output = self.model.generate(input_ids, generation_config=self.gen_cfg) + output = self.model.generate(input_ids, generation_config=gen_cfg) # decode output response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True) + response = response[len(prompt):] print(f"Generated {response=}") return {"text": response} - # ------------------------------------------------------------------- # Web API # ------------------------------------------------------------------- @@ -160,7 +171,7 @@ class LLM: def web(): from fastapi import Depends, FastAPI, HTTPException, status from fastapi.security import OAuth2PasswordBearer - from pydantic import BaseModel, Field + from pydantic import BaseModel llmstub = LLM() @@ -177,16 +188,16 @@ def web(): class LLMRequest(BaseModel): prompt: str - schema_: Optional[dict] = Field(None, alias="schema") + gen_schema: Optional[dict] = None + gen_cfg: Optional[dict] = None @app.post("/llm", dependencies=[Depends(apikey_auth)]) async def llm( req: LLMRequest, ): - if req.schema_: - func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema_)) - else: - func = llmstub.generate.spawn(prompt=req.prompt) + gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None + gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None + func = llmstub.generate.spawn(prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg) result = func.get() return result diff --git a/server/migrations/versions/99365b0cd87b_add_title_short_and_long_summary_and_.py b/server/migrations/versions/99365b0cd87b_add_title_short_and_long_summary_and_.py new file mode 100644 index 00000000..5d7dc857 --- /dev/null +++ b/server/migrations/versions/99365b0cd87b_add_title_short_and_long_summary_and_.py @@ -0,0 +1,37 @@ +"""add_title/short_and_long_summary_and_remove_summary + +Revision ID: 99365b0cd87b +Revises: b3df9681cae9 +Create Date: 2023-09-01 20:19:47.216334 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '99365b0cd87b' +down_revision: Union[str, None] = 'b3df9681cae9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.execute("UPDATE transcript SET events = " + "REPLACE(events, '\"event\": \"SUMMARY\"', '\"event\": \"LONG_SUMMARY\"');") + op.alter_column('transcript', 'summary', new_column_name='long_summary') + op.add_column('transcript', sa.Column('title', sa.String(), nullable=True)) + op.add_column('transcript', sa.Column('short_summary', sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.execute("UPDATE transcript SET events = " + "REPLACE(events, '\"event\": \"LONG_SUMMARY\"', '\"event\": \"SUMMARY\"');") + op.alter_column('transcript', 'long_summary', nullable=True, new_column_name='summary') + op.drop_column('transcript', 'title') + op.drop_column('transcript', 'short_summary') + # ### end Alembic commands ### diff --git a/server/poetry.lock b/server/poetry.lock index 24823ed7..d58ce0fb 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1588,6 +1588,17 @@ files = [ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, ] +[[package]] +name = "joblib" +version = "1.3.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"}, + {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"}, +] + [[package]] name = "jwcrypto" version = "1.5.0" @@ -1934,6 +1945,31 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "nltk" +version = "3.8.1" +description = "Natural Language Toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"}, + {file = "nltk-3.8.1.zip", hash = "sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3"}, +] + +[package.dependencies] +click = "*" +joblib = "*" +regex = ">=2021.8.3" +tqdm = "*" + +[package.extras] +all = ["matplotlib", "numpy", "pyparsing", "python-crfsuite", "requests", "scikit-learn", "scipy", "twython"] +corenlp = ["requests"] +machine-learning = ["numpy", "python-crfsuite", "scikit-learn", "scipy"] +plot = ["matplotlib"] +tgrep = ["pyparsing"] +twitter = ["twython"] + [[package]] name = "numpy" version = "1.25.2" @@ -2672,6 +2708,103 @@ files = [ [package.extras] full = ["numpy"] +[[package]] +name = "regex" +version = "2023.8.8" +description = "Alternative regular expression module, to replace re." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "regex-2023.8.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:88900f521c645f784260a8d346e12a1590f79e96403971241e64c3a265c8ecdb"}, + {file = "regex-2023.8.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3611576aff55918af2697410ff0293d6071b7e00f4b09e005d614686ac4cd57c"}, + {file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8a0ccc8f2698f120e9e5742f4b38dc944c38744d4bdfc427616f3a163dd9de5"}, + {file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c662a4cbdd6280ee56f841f14620787215a171c4e2d1744c9528bed8f5816c96"}, + {file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cf0633e4a1b667bfe0bb10b5e53fe0d5f34a6243ea2530eb342491f1adf4f739"}, + {file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:551ad543fa19e94943c5b2cebc54c73353ffff08228ee5f3376bd27b3d5b9800"}, + {file = "regex-2023.8.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54de2619f5ea58474f2ac211ceea6b615af2d7e4306220d4f3fe690c91988a61"}, + {file = "regex-2023.8.8-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ec4b3f0aebbbe2fc0134ee30a791af522a92ad9f164858805a77442d7d18570"}, + {file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3ae646c35cb9f820491760ac62c25b6d6b496757fda2d51be429e0e7b67ae0ab"}, + {file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ca339088839582d01654e6f83a637a4b8194d0960477b9769d2ff2cfa0fa36d2"}, + {file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:d9b6627408021452dcd0d2cdf8da0534e19d93d070bfa8b6b4176f99711e7f90"}, + {file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:bd3366aceedf274f765a3a4bc95d6cd97b130d1dda524d8f25225d14123c01db"}, + {file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7aed90a72fc3654fba9bc4b7f851571dcc368120432ad68b226bd593f3f6c0b7"}, + {file = "regex-2023.8.8-cp310-cp310-win32.whl", hash = "sha256:80b80b889cb767cc47f31d2b2f3dec2db8126fbcd0cff31b3925b4dc6609dcdb"}, + {file = "regex-2023.8.8-cp310-cp310-win_amd64.whl", hash = "sha256:b82edc98d107cbc7357da7a5a695901b47d6eb0420e587256ba3ad24b80b7d0b"}, + {file = "regex-2023.8.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1e7d84d64c84ad97bf06f3c8cb5e48941f135ace28f450d86af6b6512f1c9a71"}, + {file = "regex-2023.8.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce0f9fbe7d295f9922c0424a3637b88c6c472b75eafeaff6f910494a1fa719ef"}, + {file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06c57e14ac723b04458df5956cfb7e2d9caa6e9d353c0b4c7d5d54fcb1325c46"}, + {file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7a9aaa5a1267125eef22cef3b63484c3241aaec6f48949b366d26c7250e0357"}, + {file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b7408511fca48a82a119d78a77c2f5eb1b22fe88b0d2450ed0756d194fe7a9a"}, + {file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14dc6f2d88192a67d708341f3085df6a4f5a0c7b03dec08d763ca2cd86e9f559"}, + {file = "regex-2023.8.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:48c640b99213643d141550326f34f0502fedb1798adb3c9eb79650b1ecb2f177"}, + {file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0085da0f6c6393428bf0d9c08d8b1874d805bb55e17cb1dfa5ddb7cfb11140bf"}, + {file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:964b16dcc10c79a4a2be9f1273fcc2684a9eedb3906439720598029a797b46e6"}, + {file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7ce606c14bb195b0e5108544b540e2c5faed6843367e4ab3deb5c6aa5e681208"}, + {file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:40f029d73b10fac448c73d6eb33d57b34607f40116e9f6e9f0d32e9229b147d7"}, + {file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3b8e6ea6be6d64104d8e9afc34c151926f8182f84e7ac290a93925c0db004bfd"}, + {file = "regex-2023.8.8-cp311-cp311-win32.whl", hash = "sha256:942f8b1f3b223638b02df7df79140646c03938d488fbfb771824f3d05fc083a8"}, + {file = "regex-2023.8.8-cp311-cp311-win_amd64.whl", hash = "sha256:51d8ea2a3a1a8fe4f67de21b8b93757005213e8ac3917567872f2865185fa7fb"}, + {file = "regex-2023.8.8-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e951d1a8e9963ea51efd7f150450803e3b95db5939f994ad3d5edac2b6f6e2b4"}, + {file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:704f63b774218207b8ccc6c47fcef5340741e5d839d11d606f70af93ee78e4d4"}, + {file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22283c769a7b01c8ac355d5be0715bf6929b6267619505e289f792b01304d898"}, + {file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91129ff1bb0619bc1f4ad19485718cc623a2dc433dff95baadbf89405c7f6b57"}, + {file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de35342190deb7b866ad6ba5cbcccb2d22c0487ee0cbb251efef0843d705f0d4"}, + {file = "regex-2023.8.8-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b993b6f524d1e274a5062488a43e3f9f8764ee9745ccd8e8193df743dbe5ee61"}, + {file = "regex-2023.8.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3026cbcf11d79095a32d9a13bbc572a458727bd5b1ca332df4a79faecd45281c"}, + {file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:293352710172239bf579c90a9864d0df57340b6fd21272345222fb6371bf82b3"}, + {file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:d909b5a3fff619dc7e48b6b1bedc2f30ec43033ba7af32f936c10839e81b9217"}, + {file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:3d370ff652323c5307d9c8e4c62efd1956fb08051b0e9210212bc51168b4ff56"}, + {file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:b076da1ed19dc37788f6a934c60adf97bd02c7eea461b73730513921a85d4235"}, + {file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e9941a4ada58f6218694f382e43fdd256e97615db9da135e77359da257a7168b"}, + {file = "regex-2023.8.8-cp36-cp36m-win32.whl", hash = "sha256:a8c65c17aed7e15a0c824cdc63a6b104dfc530f6fa8cb6ac51c437af52b481c7"}, + {file = "regex-2023.8.8-cp36-cp36m-win_amd64.whl", hash = "sha256:aadf28046e77a72f30dcc1ab185639e8de7f4104b8cb5c6dfa5d8ed860e57236"}, + {file = "regex-2023.8.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:423adfa872b4908843ac3e7a30f957f5d5282944b81ca0a3b8a7ccbbfaa06103"}, + {file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", 
hash = "sha256:4ae594c66f4a7e1ea67232a0846649a7c94c188d6c071ac0210c3e86a5f92109"}, + {file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e51c80c168074faa793685656c38eb7a06cbad7774c8cbc3ea05552d615393d8"}, + {file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:09b7f4c66aa9d1522b06e31a54f15581c37286237208df1345108fcf4e050c18"}, + {file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e73e5243af12d9cd6a9d6a45a43570dbe2e5b1cdfc862f5ae2b031e44dd95a8"}, + {file = "regex-2023.8.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:941460db8fe3bd613db52f05259c9336f5a47ccae7d7def44cc277184030a116"}, + {file = "regex-2023.8.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f0ccf3e01afeb412a1a9993049cb160d0352dba635bbca7762b2dc722aa5742a"}, + {file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2e9216e0d2cdce7dbc9be48cb3eacb962740a09b011a116fd7af8c832ab116ca"}, + {file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:5cd9cd7170459b9223c5e592ac036e0704bee765706445c353d96f2890e816c8"}, + {file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:4873ef92e03a4309b3ccd8281454801b291b689f6ad45ef8c3658b6fa761d7ac"}, + {file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:239c3c2a339d3b3ddd51c2daef10874410917cd2b998f043c13e2084cb191684"}, + {file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1005c60ed7037be0d9dea1f9c53cc42f836188227366370867222bda4c3c6bd7"}, + {file = "regex-2023.8.8-cp37-cp37m-win32.whl", hash = "sha256:e6bd1e9b95bc5614a7a9c9c44fde9539cba1c823b43a9f7bc11266446dd568e3"}, + {file = "regex-2023.8.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9a96edd79661e93327cfeac4edec72a4046e14550a1d22aa0dd2e3ca52aec921"}, + {file = "regex-2023.8.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f2181c20ef18747d5f4a7ea513e09ea03bdd50884a11ce46066bb90fe4213675"}, + {file = "regex-2023.8.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a2ad5add903eb7cdde2b7c64aaca405f3957ab34f16594d2b78d53b8b1a6a7d6"}, + {file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9233ac249b354c54146e392e8a451e465dd2d967fc773690811d3a8c240ac601"}, + {file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:920974009fb37b20d32afcdf0227a2e707eb83fe418713f7a8b7de038b870d0b"}, + {file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd2b6c5dfe0929b6c23dde9624483380b170b6e34ed79054ad131b20203a1a63"}, + {file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96979d753b1dc3b2169003e1854dc67bfc86edf93c01e84757927f810b8c3c93"}, + {file = "regex-2023.8.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ae54a338191e1356253e7883d9d19f8679b6143703086245fb14d1f20196be9"}, + {file = "regex-2023.8.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2162ae2eb8b079622176a81b65d486ba50b888271302190870b8cc488587d280"}, + {file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c884d1a59e69e03b93cf0dfee8794c63d7de0ee8f7ffb76e5f75be8131b6400a"}, + {file = 
"regex-2023.8.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:cf9273e96f3ee2ac89ffcb17627a78f78e7516b08f94dc435844ae72576a276e"}, + {file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:83215147121e15d5f3a45d99abeed9cf1fe16869d5c233b08c56cdf75f43a504"}, + {file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:3f7454aa427b8ab9101f3787eb178057c5250478e39b99540cfc2b889c7d0586"}, + {file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f0640913d2c1044d97e30d7c41728195fc37e54d190c5385eacb52115127b882"}, + {file = "regex-2023.8.8-cp38-cp38-win32.whl", hash = "sha256:0c59122ceccb905a941fb23b087b8eafc5290bf983ebcb14d2301febcbe199c7"}, + {file = "regex-2023.8.8-cp38-cp38-win_amd64.whl", hash = "sha256:c12f6f67495ea05c3d542d119d270007090bad5b843f642d418eb601ec0fa7be"}, + {file = "regex-2023.8.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:82cd0a69cd28f6cc3789cc6adeb1027f79526b1ab50b1f6062bbc3a0ccb2dbc3"}, + {file = "regex-2023.8.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bb34d1605f96a245fc39790a117ac1bac8de84ab7691637b26ab2c5efb8f228c"}, + {file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:987b9ac04d0b38ef4f89fbc035e84a7efad9cdd5f1e29024f9289182c8d99e09"}, + {file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9dd6082f4e2aec9b6a0927202c85bc1b09dcab113f97265127c1dc20e2e32495"}, + {file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7eb95fe8222932c10d4436e7a6f7c99991e3fdd9f36c949eff16a69246dee2dc"}, + {file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7098c524ba9f20717a56a8d551d2ed491ea89cbf37e540759ed3b776a4f8d6eb"}, + {file = "regex-2023.8.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b694430b3f00eb02c594ff5a16db30e054c1b9589a043fe9174584c6efa8033"}, + {file = "regex-2023.8.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b2aeab3895d778155054abea5238d0eb9a72e9242bd4b43f42fd911ef9a13470"}, + {file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:988631b9d78b546e284478c2ec15c8a85960e262e247b35ca5eaf7ee22f6050a"}, + {file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:67ecd894e56a0c6108ec5ab1d8fa8418ec0cff45844a855966b875d1039a2e34"}, + {file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:14898830f0a0eb67cae2bbbc787c1a7d6e34ecc06fbd39d3af5fe29a4468e2c9"}, + {file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:f2200e00b62568cfd920127782c61bc1c546062a879cdc741cfcc6976668dfcf"}, + {file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9691a549c19c22d26a4f3b948071e93517bdf86e41b81d8c6ac8a964bb71e5a6"}, + {file = "regex-2023.8.8-cp39-cp39-win32.whl", hash = "sha256:6ab2ed84bf0137927846b37e882745a827458689eb969028af8032b1b3dac78e"}, + {file = "regex-2023.8.8-cp39-cp39-win_amd64.whl", hash = "sha256:5543c055d8ec7801901e1193a51570643d6a6ab8751b1f7dd9af71af467538bb"}, + {file = "regex-2023.8.8.tar.gz", hash = "sha256:fcbdc5f2b0f1cd0f6a56cdb46fe41d2cce1e644e3b68832f3eeebc5fb0f7712e"}, +] + [[package]] name = "requests" version = "2.31.0" @@ -2710,6 +2843,70 @@ botocore = ">=1.12.36,<2.0a.0" [package.extras] crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] +[[package]] +name = "safetensors" +version = "0.3.3" +description 
= "Fast and Safe Tensor serialization" +optional = false +python-versions = "*" +files = [ + {file = "safetensors-0.3.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:92e4d0c8b2836120fddd134474c5bda8963f322333941f8b9f643e5b24f041eb"}, + {file = "safetensors-0.3.3-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:3dcadb6153c42addc9c625a622ebde9293fabe1973f9ef31ba10fb42c16e8536"}, + {file = "safetensors-0.3.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:08f26b61e1b0a14dc959aa9d568776bd038805f611caef1de04a80c468d4a7a4"}, + {file = "safetensors-0.3.3-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:17f41344d9a075f2f21b289a49a62e98baff54b5754240ba896063bce31626bf"}, + {file = "safetensors-0.3.3-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:f1045f798e1a16a6ced98d6a42ec72936d367a2eec81dc5fade6ed54638cd7d2"}, + {file = "safetensors-0.3.3-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:eaf0e4bc91da13f21ac846a39429eb3f3b7ed06295a32321fa3eb1a59b5c70f3"}, + {file = "safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a07121f427e646a50d18c1be0fa1a2cbf6398624c31149cd7e6b35486d72189e"}, + {file = "safetensors-0.3.3-cp310-cp310-win32.whl", hash = "sha256:a85e29cbfddfea86453cc0f4889b4bcc6b9c155be9a60e27be479a34e199e7ef"}, + {file = "safetensors-0.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:e13adad4a3e591378f71068d14e92343e626cf698ff805f61cdb946e684a218e"}, + {file = "safetensors-0.3.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:cbc3312f134baf07334dd517341a4b470b2931f090bd9284888acb7dfaf4606f"}, + {file = "safetensors-0.3.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d15030af39d5d30c22bcbc6d180c65405b7ea4c05b7bab14a570eac7d7d43722"}, + {file = "safetensors-0.3.3-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:f84a74cbe9859b28e3d6d7715ac1dd3097bebf8d772694098f6d42435245860c"}, + {file = "safetensors-0.3.3-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:10d637423d98ab2e6a4ad96abf4534eb26fcaf8ca3115623e64c00759374e90d"}, + {file = "safetensors-0.3.3-cp311-cp311-macosx_13_0_universal2.whl", hash = "sha256:3b46f5de8b44084aff2e480874c550c399c730c84b2e8ad1bddb062c94aa14e9"}, + {file = "safetensors-0.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e8fdf7407dba44587ed5e79d5de3533d242648e1f2041760b21474bd5ea5c8c"}, + {file = "safetensors-0.3.3-cp311-cp311-win32.whl", hash = "sha256:7d3b744cee8d7a46ffa68db1a2ff1a1a432488e3f7a5a97856fe69e22139d50c"}, + {file = "safetensors-0.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:f579877d30feec9b6ba409d05fa174633a4fc095675a4a82971d831a8bb60b97"}, + {file = "safetensors-0.3.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:2fff5b19a1b462c17322998b2f4b8bce43c16fe208968174d2f3a1446284ceed"}, + {file = "safetensors-0.3.3-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:41adb1d39e8aad04b16879e3e0cbcb849315999fad73bc992091a01e379cb058"}, + {file = "safetensors-0.3.3-cp37-cp37m-macosx_12_0_x86_64.whl", hash = "sha256:0f2b404250b3b877b11d34afcc30d80e7035714a1116a3df56acaca6b6c00096"}, + {file = "safetensors-0.3.3-cp37-cp37m-macosx_13_0_x86_64.whl", hash = "sha256:b43956ef20e9f4f2e648818a9e7b3499edd6b753a0f5526d4f6a6826fbee8446"}, + {file = "safetensors-0.3.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c32ee08f61cea56a5d62bbf94af95df6040c8ab574afffaeb7b44ae5da1e9e3"}, + {file = "safetensors-0.3.3-cp37-cp37m-win32.whl", hash = "sha256:351600f367badd59f7bfe86d317bb768dd8c59c1561c6fac43cafbd9c1af7827"}, + {file = 
"safetensors-0.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:034717e297849dae1af0a7027a14b8647bd2e272c24106dced64d83e10d468d1"}, + {file = "safetensors-0.3.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8530399666748634bc0b301a6a5523756931b0c2680d188e743d16304afe917a"}, + {file = "safetensors-0.3.3-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:9d741c1f1621e489ba10aa3d135b54202684f6e205df52e219d5eecd673a80c9"}, + {file = "safetensors-0.3.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:0c345fd85b4d2093a5109596ff4cd9dfc2e84992e881b4857fbc4a93a3b89ddb"}, + {file = "safetensors-0.3.3-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:69ccee8d05f55cdf76f7e6c87d2bdfb648c16778ef8acfd2ecc495e273e9233e"}, + {file = "safetensors-0.3.3-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:c08a9a4b7a4ca389232fa8d097aebc20bbd4f61e477abc7065b5c18b8202dede"}, + {file = "safetensors-0.3.3-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:a002868d2e3f49bbe81bee2655a411c24fa1f8e68b703dec6629cb989d6ae42e"}, + {file = "safetensors-0.3.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ab43aeeb9eadbb6b460df3568a662e6f1911ecc39387f8752afcb6a7d96c087"}, + {file = "safetensors-0.3.3-cp38-cp38-win32.whl", hash = "sha256:f2f59fce31dd3429daca7269a6b06f65e6547a0c248f5116976c3f1e9b73f251"}, + {file = "safetensors-0.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:c31ca0d8610f57799925bf08616856b39518ab772c65093ef1516762e796fde4"}, + {file = "safetensors-0.3.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:59a596b3225c96d59af412385981f17dd95314e3fffdf359c7e3f5bb97730a19"}, + {file = "safetensors-0.3.3-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:82a16e92210a6221edd75ab17acdd468dd958ef5023d9c6c1289606cc30d1479"}, + {file = "safetensors-0.3.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:98a929e763a581f516373ef31983ed1257d2d0da912a8e05d5cd12e9e441c93a"}, + {file = "safetensors-0.3.3-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:12b83f1986cd16ea0454c636c37b11e819d60dd952c26978310a0835133480b7"}, + {file = "safetensors-0.3.3-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:f439175c827c2f1bbd54df42789c5204a10983a30bc4242bc7deaf854a24f3f0"}, + {file = "safetensors-0.3.3-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:0085be33b8cbcb13079b3a8e131656e05b0bc5e6970530d4c24150f7afd76d70"}, + {file = "safetensors-0.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad3cc8006e7a86ee7c88bd2813ec59cd7cc75b03e6fa4af89b9c7b235b438d68"}, + {file = "safetensors-0.3.3-cp39-cp39-win32.whl", hash = "sha256:ab29f54c6b8c301ca05fa014728996bd83aac6e21528f893aaf8945c71f42b6d"}, + {file = "safetensors-0.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:0fa82004eae1a71e2aa29843ef99de9350e459a0fc2f65fc6ee0da9690933d2d"}, + {file = "safetensors-0.3.3.tar.gz", hash = "sha256:edb7072d788c4f929d0f5735d3a2fb51e5a27f833587828583b7f5747af1a2b8"}, +] + +[package.extras] +all = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (==2.11.0)", "torch (>=1.10)"] +dev = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark 
(>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (==2.11.0)", "torch (>=1.10)"] +jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)"] +numpy = ["numpy (>=1.21.6)"] +paddlepaddle = ["numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)"] +pinned-tf = ["tensorflow (==2.11.0)"] +quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] +tensorflow = ["numpy (>=1.21.6)", "tensorflow (>=2.11.0)"] +testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"] +torch = ["numpy (>=1.21.6)", "torch (>=1.10)"] + [[package]] name = "sentry-sdk" version = "1.29.2" @@ -3013,6 +3210,75 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] +[[package]] +name = "transformers" +version = "4.32.1" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "transformers-4.32.1-py3-none-any.whl", hash = "sha256:b930d3dbd907a3f300cf49e54d63a56f8a0ab16b01a2c2a61ecff37c6de1da08"}, + {file = "transformers-4.32.1.tar.gz", hash = "sha256:1edc8ae1de357d97c3d36b04412aa63d55e6fc0c4b39b419a7d380ed947d2252"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.15.1,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.3.1" +tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.14" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.20.3)"] +agents = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.9,!=1.12.0)"] +all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +codecarbon = ["codecarbon (==1.2.0)"] +deepspeed = ["accelerate (>=0.20.3)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", 
"librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] +docs-specific = ["hf-doc-builder"] +fairscale = ["fairscale (>0.3)"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = 
["ftfy"] +integrations = ["optuna", "ray[tune]", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"] +ray = ["ray[tune]"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.14)", "tensorflow-text (<2.14)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +timm = ["timm"] +tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.14)"] +torch = ["accelerate (>=0.20.3)", "torch (>=1.9,!=1.12.0)"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow (<10.0.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.15.1,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "tqdm (>=4.27)"] +video = ["av (==9.2.0)", "decord (==0.6.0)"] +vision = ["Pillow (<10.0.0)"] + [[package]] name = "typing-extensions" version = "4.7.1" diff --git a/server/pyproject.toml b/server/pyproject.toml index edced410..bdae1c4d 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -27,6 +27,8 @@ databases = {extras = ["aiosqlite", "asyncpg"], version = "^0.7.0"} sqlalchemy = "<1.5" fief-client = {extras = ["fastapi"], version = "^0.17.0"} alembic = "^1.11.3" +nltk = "^3.8.1" +transformers = "^4.32.1" prometheus-fastapi-instrumentator = "^6.1.0" diff --git a/server/reflector/app.py b/server/reflector/app.py index 136199be..b091f579 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -1,12 +1,13 @@ from contextlib import asynccontextmanager -import reflector.auth # noqa -import reflector.db # noqa from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.routing import APIRoute from fastapi_pagination import add_pagination from prometheus_fastapi_instrumentator import Instrumentator + +import reflector.auth # noqa +import reflector.db # noqa from reflector.events import subscribers_shutdown, subscribers_startup 
from reflector.logger import logger from reflector.metrics import metrics_init diff --git a/server/reflector/db/__init__.py b/server/reflector/db/__init__.py index b445e907..b68dfe20 100644 --- a/server/reflector/db/__init__.py +++ b/server/reflector/db/__init__.py @@ -1,5 +1,6 @@ import databases import sqlalchemy + from reflector.events import subscribers_shutdown, subscribers_startup from reflector.settings import settings @@ -16,7 +17,9 @@ transcripts = sqlalchemy.Table( sqlalchemy.Column("locked", sqlalchemy.Boolean), sqlalchemy.Column("duration", sqlalchemy.Integer), sqlalchemy.Column("created_at", sqlalchemy.DateTime), - sqlalchemy.Column("summary", sqlalchemy.String, nullable=True), + sqlalchemy.Column("title", sqlalchemy.String, nullable=True), + sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True), + sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), sqlalchemy.Column("topics", sqlalchemy.JSON), sqlalchemy.Column("events", sqlalchemy.JSON), sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), diff --git a/server/reflector/llm/__init__.py b/server/reflector/llm/__init__.py index f0dda3b6..446a20d6 100644 --- a/server/reflector/llm/__init__.py +++ b/server/reflector/llm/__init__.py @@ -1 +1,2 @@ from .base import LLM # noqa: F401 +from .llm_params import LLMTaskParams # noqa: F401 diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index e729148e..950a1a07 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -2,12 +2,19 @@ import importlib import json import re from time import monotonic +from typing import TypeVar +import nltk from prometheus_client import Counter, Histogram +from transformers import GenerationConfig + +from reflector.llm.llm_params import TaskParams from reflector.logger import logger as reflector_logger from reflector.settings import settings from reflector.utils.retry import retry +T = TypeVar("T", bound="LLM") + class LLM: _registry = {} @@ -32,12 +39,25 @@ class LLM: ["backend"], ) + def __enter__(self): + self.ensure_nltk() + + @classmethod + def ensure_nltk(cls): + """ + Make sure NLTK package is installed. Searches in the cache and + downloads only if needed. + """ + nltk.download("punkt", download_dir=settings.CACHE_DIR) + # For POS tagging + nltk.download("averaged_perceptron_tagger", download_dir=settings.CACHE_DIR) + @classmethod def register(cls, name, klass): cls._registry[name] = klass @classmethod - def get_instance(cls, name=None): + def get_instance(cls, model_name: str | None = None, name: str = None) -> T: """ Return an instance depending on the settings. 
Settings used: @@ -50,7 +70,39 @@ class LLM: if name not in cls._registry: module_name = f"reflector.llm.llm_{name}" importlib.import_module(module_name) - return cls._registry[name]() + return cls._registry[name](model_name) + + def get_model_name(self) -> str: + """ + Get the currently set model name + """ + return self._get_model_name() + + def _get_model_name(self) -> str: + pass + + def set_model_name(self, model_name: str) -> bool: + """ + Update the model name with the provided model name + """ + return self._set_model_name(model_name) + + def _set_model_name(self, model_name: str) -> bool: + raise NotImplementedError + + @property + def template(self) -> str: + """ + Return the LLM Prompt template + """ + return """ + ### Human: + {instruct} + + {text} + + ### Assistant: + """ def __init__(self): name = self.__class__.__name__ @@ -73,21 +125,39 @@ class LLM: async def _warmup(self, logger: reflector_logger): pass + @property + def tokenizer(self): + """ + Return the tokenizer instance used by LLM + """ + return self._get_tokenizer() + + def _get_tokenizer(self): + pass + async def generate( self, prompt: str, logger: reflector_logger, - schema: dict | None = None, + gen_schema: dict | None = None, + gen_cfg: GenerationConfig | None = None, **kwargs, ) -> dict: logger.info("LLM generate", prompt=repr(prompt)) + + if gen_cfg: + gen_cfg = gen_cfg.to_dict() self.m_generate_call.inc() try: with self.m_generate.time(): result = await retry(self._generate)( - prompt=prompt, schema=schema, **kwargs + prompt=prompt, + gen_schema=gen_schema, + gen_cfg=gen_cfg, + **kwargs, ) self.m_generate_success.inc() + except Exception: logger.exception("Failed to call llm after retrying") self.m_generate_failure.inc() @@ -100,7 +170,60 @@ class LLM: return result - async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str: + def ensure_casing(self, title: str) -> str: + """ + LLM takes care of word casing, but in rare cases this + can falter. This is a fallback to ensure the casing of + topics is in a proper format. + + We select nouns, verbs and adjectives and check if camel + casing is present and fix it, if not. Will not perform + any other changes. + """ + tokens = nltk.word_tokenize(title) + pos_tags = nltk.pos_tag(tokens) + camel_cased = [] + + whitelisted_pos_tags = [ + "NN", + "NNS", + "NNP", + "NNPS", # Noun POS + "VB", + "VBD", + "VBG", + "VBN", + "VBP", + "VBZ", # Verb POS + "JJ", + "JJR", + "JJS", # Adjective POS + ] + + # If at all there is an exception, do not block other reflector + # processes. Return the LLM generated title, at the least. + try: + for word, pos in pos_tags: + if pos in whitelisted_pos_tags and word[0].islower(): + camel_cased.append(word[0].upper() + word[1:]) + else: + camel_cased.append(word) + modified_title = " ".join(camel_cased) + + # The result can have words in braces with additional space. + # Change ( ABC ), [ ABC ], etc. ==> (ABC), [ABC], etc. 
+            pattern = r"(?<=[\[\{\(])\s+|\s+(?=[\]\}\)])"
+            title = re.sub(pattern, "", modified_title)
+        except Exception as e:
+            reflector_logger.info(
+                f"Failed to ensure casing on {title=} " f"with exception: {str(e)}"
+            )
+
+        return title
+
+    async def _generate(
+        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
+    ) -> str:
         raise NotImplementedError

     def _parse_json(self, result: str) -> dict:
@@ -122,3 +245,62 @@ class LLM:
             result = result[:-3]

         return json.loads(result.strip())
+
+    def text_token_threshold(self, task_params: TaskParams | None) -> int:
+        """
+        Choose the token count to use as the threshold when packing LLM calls
+        """
+        buffer_token_size = 25
+        default_output_tokens = 1000
+        context_window = self.tokenizer.model_max_length
+        tokens = self.tokenizer.tokenize(
+            self.create_prompt(instruct=task_params.instruct, text="")
+        )
+        threshold = context_window - len(tokens) - buffer_token_size
+        if task_params.gen_cfg:
+            threshold -= task_params.gen_cfg.max_new_tokens
+        else:
+            threshold -= default_output_tokens
+        return threshold
+
+    def split_corpus(
+        self,
+        corpus: str,
+        task_params: TaskParams,
+        token_threshold: int | None = None,
+    ) -> list[str]:
+        """
+        Split the input to the LLM due to CUDA memory limitations and LLM context
+        window restrictions.
+
+        Accumulate full sentences until the token threshold is reached, then yield
+        the accumulated sentences. Reset the accumulation and repeat.
+        """
+        if not token_threshold:
+            token_threshold = self.text_token_threshold(task_params=task_params)
+
+        accumulated_tokens = []
+        accumulated_sentences = []
+        accumulated_token_count = 0
+        corpus_sentences = nltk.sent_tokenize(corpus)
+
+        for sentence in corpus_sentences:
+            tokens = self.tokenizer.tokenize(sentence)
+            if accumulated_token_count + len(tokens) <= token_threshold:
+                accumulated_token_count += len(tokens)
+                accumulated_tokens.extend(tokens)
+                accumulated_sentences.append(sentence)
+            else:
+                yield " ".join(accumulated_sentences)
+                accumulated_token_count = len(tokens)
+                accumulated_tokens = tokens
+                accumulated_sentences = [sentence]
+
+        if accumulated_tokens:
+            yield " ".join(accumulated_sentences)
+
+    def create_prompt(self, instruct: str, text: str) -> str:
+        """
+        Create a consumable prompt based on the prompt template
+        """
+        return self.template.format(instruct=instruct, text=text)
diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py
index 56fc0e69..e0384770 100644
--- a/server/reflector/llm/llm_banana.py
+++ b/server/reflector/llm/llm_banana.py
@@ -1,4 +1,5 @@
 import httpx
+
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
@@ -13,10 +14,14 @@ class BananaLLM(LLM):
             "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
         }

-    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+    async def _generate(
+        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
+    ):
         json_payload = {"prompt": prompt}
-        if schema:
-            json_payload["schema"] = schema
+        if gen_schema:
+            json_payload["gen_schema"] = gen_schema
+        if gen_cfg:
+            json_payload["gen_cfg"] = gen_cfg
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
@@ -27,18 +32,21 @@
             )
             response.raise_for_status()
             text = response.json()["text"]
-            if not schema:
-                text = text[len(prompt) :]
             return text


 LLM.register("banana", BananaLLM)

 if __name__ == "__main__":
+    from reflector.logger import logger

     async def
main(): llm = BananaLLM() - result = await llm.generate("Hello, my name is") + prompt = llm.create_prompt( + instruct="Complete the following task", + text="Tell me a joke about programming.", + ) + result = await llm.generate(prompt=prompt, logger=logger) print(result) import asyncio diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py index ce0de02a..b427833b 100644 --- a/server/reflector/llm/llm_modal.py +++ b/server/reflector/llm/llm_modal.py @@ -1,11 +1,14 @@ import httpx +from transformers import AutoTokenizer, GenerationConfig + from reflector.llm.base import LLM +from reflector.logger import logger as reflector_logger from reflector.settings import settings from reflector.utils.retry import retry class ModalLLM(LLM): - def __init__(self): + def __init__(self, model_name: str | None = None): super().__init__() self.timeout = settings.LLM_TIMEOUT self.llm_url = settings.LLM_URL + "/llm" @@ -13,6 +16,16 @@ class ModalLLM(LLM): self.headers = { "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}", } + self._set_model_name(model_name if model_name else settings.DEFAULT_LLM) + + @property + def supported_models(self): + """ + List of currently supported models on this GPU platform + """ + # TODO: Query the specific GPU platform + # Replace this with a HTTP call + return ["lmsys/vicuna-13b-v1.5"] async def _warmup(self, logger): async with httpx.AsyncClient() as client: @@ -23,10 +36,14 @@ class ModalLLM(LLM): ) response.raise_for_status() - async def _generate(self, prompt: str, schema: dict | None, **kwargs): + async def _generate( + self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs + ): json_payload = {"prompt": prompt} - if schema: - json_payload["schema"] = schema + if gen_schema: + json_payload["gen_schema"] = gen_schema + if gen_cfg: + json_payload["gen_cfg"] = gen_cfg async with httpx.AsyncClient() as client: response = await retry(client.post)( self.llm_url, @@ -37,10 +54,43 @@ class ModalLLM(LLM): ) response.raise_for_status() text = response.json()["text"] - if not schema: - text = text[len(prompt) :] return text + def _set_model_name(self, model_name: str) -> bool: + """ + Set the model name + """ + # Abort, if the model is not supported + if model_name not in self.supported_models: + reflector_logger.info( + f"Attempted to change {model_name=}, but is not supported." + f"Setting model and tokenizer failed !" + ) + return False + # Abort, if the model is already set + elif hasattr(self, "model_name") and model_name == self._get_model_name(): + reflector_logger.info("No change in model. Setting model skipped.") + return False + # Update model name and tokenizer + self.model_name = model_name + self.llm_tokenizer = AutoTokenizer.from_pretrained( + self.model_name, cache_dir=settings.CACHE_DIR + ) + reflector_logger.info(f"Model set to {model_name=}. 
Tokenizer updated.") + return True + + def _get_tokenizer(self) -> AutoTokenizer: + """ + Return the currently used LLM tokenizer + """ + return self.llm_tokenizer + + def _get_model_name(self) -> str: + """ + Return the current model name from the instance details + """ + return self.model_name + LLM.register("modal", ModalLLM) @@ -49,15 +99,25 @@ if __name__ == "__main__": async def main(): llm = ModalLLM() - result = await llm.generate("Hello, my name is", logger=logger) + prompt = llm.create_prompt( + instruct="Complete the following task", + text="Tell me a joke about programming.", + ) + result = await llm.generate(prompt=prompt, logger=logger) print(result) - schema = { + gen_schema = { "type": "object", - "properties": {"name": {"type": "string"}}, + "properties": {"response": {"type": "string"}}, } - result = await llm.generate("Hello, my name is", schema=schema, logger=logger) + result = await llm.generate(prompt=prompt, gen_schema=gen_schema, logger=logger) + print(result) + + gen_cfg = GenerationConfig(max_new_tokens=150) + result = await llm.generate( + prompt=prompt, gen_cfg=gen_cfg, gen_schema=gen_schema, logger=logger + ) print(result) import asyncio diff --git a/server/reflector/llm/llm_oobabooga.py b/server/reflector/llm/llm_oobabooga.py index 411014c5..36d3480b 100644 --- a/server/reflector/llm/llm_oobabooga.py +++ b/server/reflector/llm/llm_oobabooga.py @@ -1,13 +1,21 @@ import httpx + from reflector.llm.base import LLM from reflector.settings import settings class OobaboogaLLM(LLM): - async def _generate(self, prompt: str, schema: dict | None, **kwargs): + def __init__(self, model_name: str | None = None): + super().__init__() + + async def _generate( + self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs + ): json_payload = {"prompt": prompt} - if schema: - json_payload["schema"] = schema + if gen_schema: + json_payload["gen_schema"] = gen_schema + if gen_cfg: + json_payload.update(gen_cfg) async with httpx.AsyncClient() as client: response = await client.post( settings.LLM_URL, diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py index 7ed532b7..e28211ef 100644 --- a/server/reflector/llm/llm_openai.py +++ b/server/reflector/llm/llm_openai.py @@ -1,11 +1,13 @@ import httpx +from transformers import GenerationConfig + from reflector.llm.base import LLM from reflector.logger import logger from reflector.settings import settings class OpenAILLM(LLM): - def __init__(self, **kwargs): + def __init__(self, model_name: str | None = None, **kwargs): super().__init__(**kwargs) self.openai_key = settings.LLM_OPENAI_KEY self.openai_url = settings.LLM_URL @@ -15,7 +17,13 @@ class OpenAILLM(LLM): self.max_tokens = settings.LLM_MAX_TOKENS logger.info(f"LLM use openai backend at {self.openai_url}") - async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str: + async def _generate( + self, + prompt: str, + gen_schema: dict | None, + gen_cfg: GenerationConfig | None, + **kwargs, + ) -> str: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.openai_key}", diff --git a/server/reflector/llm/llm_params.py b/server/reflector/llm/llm_params.py new file mode 100644 index 00000000..59eea7c1 --- /dev/null +++ b/server/reflector/llm/llm_params.py @@ -0,0 +1,150 @@ +from typing import Optional, TypeVar + +from pydantic import BaseModel +from transformers import GenerationConfig + + +class TaskParams(BaseModel, arbitrary_types_allowed=True): + instruct: str + gen_cfg: Optional[GenerationConfig] = 
None
+    gen_schema: Optional[dict] = None
+
+
+T = TypeVar("T", bound="LLMTaskParams")
+
+
+class LLMTaskParams:
+    _registry = {}
+
+    @classmethod
+    def register(cls, task, klass) -> None:
+        cls._registry[task] = klass
+
+    @classmethod
+    def get_instance(cls, task: str) -> T:
+        return cls._registry[task]()
+
+    @property
+    def task_params(self) -> TaskParams | None:
+        """
+        Fetch the task related parameters
+        """
+        return self._get_task_params()
+
+    def _get_task_params(self) -> None:
+        pass
+
+
+class FinalLongSummaryParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
+        )
+        self._instruct = """
+        Take the key ideas and takeaways from the text and create a short
+        summary. Be sure to keep the length of the response to a minimum.
+        Do not include trivial information in the summary.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {"long_summary": {"type": "string"}},
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+class FinalShortSummaryParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=1300, num_beams=3, do_sample=True, temperature=0.3
+        )
+        self._instruct = """
+        Take the key ideas and takeaways from the text and create a short
+        summary. Be sure to keep the length of the response to a minimum.
+        Do not include trivial information in the summary.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {"short_summary": {"type": "string"}},
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+class FinalTitleParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
+        )
+        self._instruct = """
+        Combine the following individual titles into one single short title that
+        condenses the essence of all titles.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {"title": {"type": "string"}},
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+class TopicParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=550, num_beams=6, do_sample=True, temperature=0.9
+        )
+        self._instruct = """
+        Create a JSON object as response. The JSON object must have 2 fields:
+        i) title and ii) summary.
+        For the title field, generate a very detailed and self-explanatory
+        title for the given text. Let the title be as descriptive as possible.
+        For the summary field, summarize the given text in a maximum of
+        three sentences.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+LLMTaskParams.register("topic", TopicParams)
+LLMTaskParams.register("final_title", FinalTitleParams)
+LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
+LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py
index 52e46c34..349a41a9 100644
--- a/server/reflector/processors/__init__.py
+++ b/server/reflector/processors/__init__.py
@@ -4,7 +4,20 @@ from .audio_merge import AudioMergeProcessor  # noqa: F401
 from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
 from .audio_transcript_auto import AudioTranscriptAutoProcessor  # noqa: F401
 from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor  # noqa: F401
-from .transcript_final_summary import TranscriptFinalSummaryProcessor  # noqa: F401
+from .transcript_final_long_summary import (  # noqa: F401
+    TranscriptFinalLongSummaryProcessor,
+)
+from .transcript_final_short_summary import (  # noqa: F401
+    TranscriptFinalShortSummaryProcessor,
+)
+from .transcript_final_title import TranscriptFinalTitleProcessor  # noqa: F401
 from .transcript_liner import TranscriptLinerProcessor  # noqa: F401
 from .transcript_topic_detector import TranscriptTopicDetectorProcessor  # noqa: F401
-from .types import AudioFile, FinalSummary, TitleSummary, Transcript, Word  # noqa: F401
+from .types import (  # noqa: F401
+    AudioFile,
+    FinalLongSummary,
+    FinalShortSummary,
+    TitleSummary,
+    Transcript,
+    Word,
+)
diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py
index 4c9757a0..646a1846 100644
--- a/server/reflector/processors/base.py
+++ b/server/reflector/processors/base.py
@@ -5,6 +5,7 @@ from uuid import uuid4
 
 from prometheus_client import Counter, Gauge, Histogram
 from pydantic import BaseModel
+
 from reflector.logger import logger
 
 
@@ -296,7 +297,7 @@ class BroadcastProcessor(Processor):
     types of input.
""" - def __init__(self, processors: Processor): + def __init__(self, processors: list[Processor]): super().__init__() self.processors = processors self.INPUT_TYPE = processors[0].INPUT_TYPE diff --git a/server/reflector/processors/transcript_final_long_summary.py b/server/reflector/processors/transcript_final_long_summary.py new file mode 100644 index 00000000..477e65c9 --- /dev/null +++ b/server/reflector/processors/transcript_final_long_summary.py @@ -0,0 +1,59 @@ +from reflector.llm import LLM, LLMTaskParams +from reflector.processors.base import Processor +from reflector.processors.types import FinalLongSummary, TitleSummary + + +class TranscriptFinalLongSummaryProcessor(Processor): + """ + Get the final long summary + """ + + INPUT_TYPE = TitleSummary + OUTPUT_TYPE = FinalLongSummary + TASK = "final_long_summary" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.chunks: list[TitleSummary] = [] + self.llm = LLM.get_instance() + self.params = LLMTaskParams.get_instance(self.TASK).task_params + + async def _push(self, data: TitleSummary): + self.chunks.append(data) + + async def get_long_summary(self, text: str) -> str: + """ + Generate a long version of the final summary + """ + self.logger.info(f"Smoothing out {len(text)} length summary to a long summary") + chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params)) + + accumulated_summaries = "" + for chunk in chunks: + prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk) + summary_result = await self.llm.generate( + prompt=prompt, + gen_schema=self.params.gen_schema, + gen_cfg=self.params.gen_cfg, + logger=self.logger, + ) + accumulated_summaries += summary_result["long_summary"] + + return accumulated_summaries + + async def _flush(self): + if not self.chunks: + self.logger.warning("No summary to output") + return + + accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks]) + long_summary = await self.get_long_summary(accumulated_summaries) + + last_chunk = self.chunks[-1] + duration = last_chunk.timestamp + last_chunk.duration + + final_long_summary = FinalLongSummary( + long_summary=long_summary, + duration=duration, + ) + await self.emit(final_long_summary) diff --git a/server/reflector/processors/transcript_final_short_summary.py b/server/reflector/processors/transcript_final_short_summary.py new file mode 100644 index 00000000..fe25ebc0 --- /dev/null +++ b/server/reflector/processors/transcript_final_short_summary.py @@ -0,0 +1,72 @@ +from reflector.llm import LLM, LLMTaskParams +from reflector.processors.base import Processor +from reflector.processors.types import FinalShortSummary, TitleSummary + + +class TranscriptFinalShortSummaryProcessor(Processor): + """ + Get the final summary using a tree summarizer + """ + + INPUT_TYPE = TitleSummary + OUTPUT_TYPE = FinalShortSummary + TASK = "final_short_summary" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.chunks: list[TitleSummary] = [] + self.llm = LLM.get_instance() + self.params = LLMTaskParams.get_instance(self.TASK).task_params + + async def _push(self, data: TitleSummary): + self.chunks.append(data) + + async def get_short_summary(self, text: str) -> dict: + """ + Generata a short summary using tree summarizer + """ + self.logger.info(f"Smoothing out {len(text)} length summary to a short summary") + chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params)) + + if len(chunks) == 1: + chunk = chunks[0] + prompt = 
+            summary_result = await self.llm.generate(
+                prompt=prompt,
+                gen_schema=self.params.gen_schema,
+                gen_cfg=self.params.gen_cfg,
+                logger=self.logger,
+            )
+            return summary_result
+        else:
+            accumulated_summaries = ""
+            for chunk in chunks:
+                prompt = self.llm.create_prompt(
+                    instruct=self.params.instruct, text=chunk
+                )
+                summary_result = await self.llm.generate(
+                    prompt=prompt,
+                    gen_schema=self.params.gen_schema,
+                    gen_cfg=self.params.gen_cfg,
+                    logger=self.logger,
+                )
+                accumulated_summaries += summary_result["short_summary"]
+
+            return await self.get_short_summary(accumulated_summaries)
+
+    async def _flush(self):
+        if not self.chunks:
+            self.logger.warning("No summary to output")
+            return
+
+        accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks])
+        short_summary_result = await self.get_short_summary(accumulated_summaries)
+
+        last_chunk = self.chunks[-1]
+        duration = last_chunk.timestamp + last_chunk.duration
+
+        final_summary = FinalShortSummary(
+            short_summary=short_summary_result["short_summary"],
+            duration=duration,
+        )
+        await self.emit(final_summary)
diff --git a/server/reflector/processors/transcript_final_summary.py b/server/reflector/processors/transcript_final_summary.py
deleted file mode 100644
index 208548f5..00000000
--- a/server/reflector/processors/transcript_final_summary.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import TitleSummary, FinalSummary
-
-
-class TranscriptFinalSummaryProcessor(Processor):
-    """
-    Assemble all summary into a line-based json
-    """
-
-    INPUT_TYPE = TitleSummary
-    OUTPUT_TYPE = FinalSummary
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.chunks: list[TitleSummary] = []
-
-    async def _push(self, data: TitleSummary):
-        self.chunks.append(data)
-
-    async def _flush(self):
-        if not self.chunks:
-            self.logger.warning("No summary to output")
-            return
-
-        # FIXME improve final summary
-        result = "\n".join([chunk.summary for chunk in self.chunks])
-        last_chunk = self.chunks[-1]
-        duration = last_chunk.timestamp + last_chunk.duration
-
-        await self.emit(FinalSummary(summary=result, duration=duration))
diff --git a/server/reflector/processors/transcript_final_title.py b/server/reflector/processors/transcript_final_title.py
new file mode 100644
index 00000000..a3360d17
--- /dev/null
+++ b/server/reflector/processors/transcript_final_title.py
@@ -0,0 +1,65 @@
+from reflector.llm import LLM, LLMTaskParams
+from reflector.processors.base import Processor
+from reflector.processors.types import FinalTitle, TitleSummary
+
+
+class TranscriptFinalTitleProcessor(Processor):
+    """
+    Assemble all topic titles into a single final title
+    """
+
+    INPUT_TYPE = TitleSummary
+    OUTPUT_TYPE = FinalTitle
+    TASK = "final_title"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.chunks: list[TitleSummary] = []
+        self.llm = LLM.get_instance()
+        self.params = LLMTaskParams.get_instance(self.TASK).task_params
+
+    async def _push(self, data: TitleSummary):
+        self.chunks.append(data)
+
+    async def get_title(self, text: str) -> dict:
+        """
+        Generate a title for the whole recording
+        """
+        chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
+
+        if len(chunks) == 1:
+            chunk = chunks[0]
+            prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
+            title_result = await self.llm.generate(
+                prompt=prompt,
+                gen_schema=self.params.gen_schema,
+                gen_cfg=self.params.gen_cfg,
+                logger=self.logger,
+            )
+            return title_result
+        else:
+            accumulated_titles = ""
+            for chunk in chunks:
+                prompt = self.llm.create_prompt(
+                    instruct=self.params.instruct, text=chunk
+                )
+                title_result = await self.llm.generate(
+                    prompt=prompt,
+                    gen_schema=self.params.gen_schema,
+                    gen_cfg=self.params.gen_cfg,
+                    logger=self.logger,
+                )
+                accumulated_titles += title_result["title"]
+
+            return await self.get_title(accumulated_titles)
+
+    async def _flush(self):
+        if not self.chunks:
+            self.logger.warning("No summary to output")
+            return
+
+        accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
+        title_result = await self.get_title(accumulated_titles)
+
+        final_title = FinalTitle(title=title_result["title"])
+        await self.emit(final_title)
diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py
index 3d8e3965..dfd2a432 100644
--- a/server/reflector/processors/transcript_topic_detector.py
+++ b/server/reflector/processors/transcript_topic_detector.py
@@ -1,7 +1,6 @@
-from reflector.llm import LLM
+from reflector.llm import LLM, LLMTaskParams
 from reflector.processors.base import Processor
 from reflector.processors.types import TitleSummary, Transcript
-from reflector.utils.retry import retry
 
 
 class TranscriptTopicDetectorProcessor(Processor):
@@ -11,34 +10,14 @@ class TranscriptTopicDetectorProcessor(Processor):
 
     INPUT_TYPE = Transcript
     OUTPUT_TYPE = TitleSummary
+    TASK = "topic"
 
-    PROMPT = """
-    ### Human:
-    Create a JSON object as response.The JSON object must have 2 fields:
-    i) title and ii) summary.
-
-    For the title field, generate a short title for the given text.
-    For the summary field, summarize the given text in a maximum of
-    three sentences.
-
-    {input_text}
-
-    ### Assistant:
-
-    """
-
-    def __init__(self, min_transcript_length=750, **kwargs):
+    def __init__(self, min_transcript_length: int = 750, **kwargs):
         super().__init__(**kwargs)
         self.transcript = None
         self.min_transcript_length = min_transcript_length
         self.llm = LLM.get_instance()
-        self.topic_detector_schema = {
-            "type": "object",
-            "properties": {
-                "title": {"type": "string"},
-                "summary": {"type": "string"},
-            },
-        }
+        self.params = LLMTaskParams.get_instance(self.TASK).task_params
 
     async def _warmup(self):
         await self.llm.warmup(logger=self.logger)
@@ -55,18 +34,30 @@ class TranscriptTopicDetectorProcessor(Processor):
             return
         await self.flush()
 
+    async def get_topic(self, text: str) -> dict:
+        """
+        Generate a topic and description for a transcription excerpt
+        """
+        prompt = self.llm.create_prompt(instruct=self.params.instruct, text=text)
+        topic_result = await self.llm.generate(
+            prompt=prompt,
+            gen_schema=self.params.gen_schema,
+            gen_cfg=self.params.gen_cfg,
+            logger=self.logger,
+        )
+        return topic_result
+
     async def _flush(self):
         if not self.transcript:
             return
+
         text = self.transcript.text
         self.logger.info(f"Topic detector got {len(text)} length transcript")
-        prompt = self.PROMPT.format(input_text=text)
-        result = await retry(self.llm.generate)(
-            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
-        )
+        topic_result = await self.get_topic(text=text)
+
         summary = TitleSummary(
-            title=result["title"],
-            summary=result["summary"],
+            title=self.llm.ensure_casing(topic_result["title"]),
+            summary=topic_result["summary"],
             timestamp=self.transcript.timestamp,
             duration=self.transcript.duration,
             transcript=self.transcript,
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index 8aab2a0d..4d0b3504 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -103,11 +103,20 @@ class TitleSummary(BaseModel):
         return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
 
 
-class FinalSummary(BaseModel):
-    summary: str
+class FinalLongSummary(BaseModel):
+    long_summary: str
     duration: float
 
 
+class FinalShortSummary(BaseModel):
+    short_summary: str
+    duration: float
+
+
+class FinalTitle(BaseModel):
+    title: str
+
+
 class TranslationLanguages(BaseModel):
     language_to_id_mapping: dict = {
         "Afrikaans": "af",
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 396ce7a3..3fc45819 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -91,5 +91,11 @@ class Settings(BaseSettings):
     # if set, all anonymous record will be public
     PUBLIC_MODE: bool = False
 
+    # Default LLM model name
+    DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5"
+
+    # Cache directory for all model storage
+    CACHE_DIR: str = "data"
+
 
 settings = Settings()
diff --git a/server/reflector/tools/process.py b/server/reflector/tools/process.py
index ae60d4a1..add1b104 100644
--- a/server/reflector/tools/process.py
+++ b/server/reflector/tools/process.py
@@ -1,6 +1,7 @@
 import asyncio
 
 import av
+
 from reflector.logger import logger
 from reflector.processors import (
     AudioChunkerProcessor,
@@ -8,10 +9,13 @@ from reflector.processors import (
     AudioTranscriptAutoProcessor,
     Pipeline,
     PipelineEvent,
-    TranscriptFinalSummaryProcessor,
+    TranscriptFinalLongSummaryProcessor,
+    TranscriptFinalShortSummaryProcessor,
+    TranscriptFinalTitleProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
 )
+from reflector.processors.base import BroadcastProcessor
 
 
 async def process_audio_file(
@@ -31,7 +35,13 @@ async def process_audio_file(
     if not only_transcript:
         processors += [
             TranscriptTopicDetectorProcessor.as_threaded(),
-            TranscriptFinalSummaryProcessor.as_threaded(),
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalTitleProcessor.as_threaded(),
+                    TranscriptFinalLongSummaryProcessor.as_threaded(),
+                    TranscriptFinalShortSummaryProcessor.as_threaded(),
+                ],
+            ),
         ]
 
     # transcription output
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index 2a5d2143..792ce244 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -8,6 +8,7 @@ from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
 from fastapi import APIRouter, Request
 from prometheus_client import Gauge
 from pydantic import BaseModel
+
 from reflector.events import subscribers_shutdown
 from reflector.logger import logger
 from reflector.processors import (
@@ -15,14 +16,19 @@ from reflector.processors import (
     AudioFileWriterProcessor,
     AudioMergeProcessor,
     AudioTranscriptAutoProcessor,
-    FinalSummary,
+    FinalLongSummary,
+    FinalShortSummary,
     Pipeline,
     TitleSummary,
     Transcript,
-    TranscriptFinalSummaryProcessor,
+    TranscriptFinalLongSummaryProcessor,
+    TranscriptFinalShortSummaryProcessor,
+    TranscriptFinalTitleProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
 )
+from reflector.processors.base import BroadcastProcessor
+from reflector.processors.types import FinalTitle
 
 sessions = []
 router = APIRouter()
@@ -72,8 +78,10 @@ class StrValue(BaseModel):
 class PipelineEvent(StrEnum):
     TRANSCRIPT = "TRANSCRIPT"
     TOPIC = "TOPIC"
-    FINAL_SUMMARY = "FINAL_SUMMARY"
+    FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
     STATUS = "STATUS"
+    FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
+    FINAL_TITLE = "FINAL_TITLE"
 
 
 async def rtc_offer_base(
@@ -124,15 +132,15 @@ async def rtc_offer_base(
                 data=transcript,
             )
 
-    async def on_topic(summary: TitleSummary):
+    async def on_topic(topic: TitleSummary):
         # FIXME: make it incremental with the frontend, not send everything
-        ctx.logger.info("Summary", summary=summary)
+        ctx.logger.info("Topic", topic=topic)
         ctx.topics.append(
             {
-                "title": summary.title,
-                "timestamp": summary.timestamp,
-                "transcript": summary.transcript.text,
-                "desc": summary.summary,
+                "title": topic.title,
+                "timestamp": topic.timestamp,
+                "transcript": topic.transcript.text,
+                "desc": topic.summary,
             }
         )
 
@@ -144,17 +152,17 @@ async def rtc_offer_base(
         # send to callback (eg. websocket)
         if event_callback:
             await event_callback(
-                event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
+                event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
             )
 
-    async def on_final_summary(summary: FinalSummary):
-        ctx.logger.info("FinalSummary", final_summary=summary)
+    async def on_final_short_summary(summary: FinalShortSummary):
+        ctx.logger.info("FinalShortSummary", final_short_summary=summary)
 
         # send to RTC
         if ctx.data_channel.readyState == "open":
             result = {
-                "cmd": "DISPLAY_FINAL_SUMMARY",
-                "summary": summary.summary,
+                "cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
+                "summary": summary.short_summary,
                 "duration": summary.duration,
             }
             ctx.data_channel.send(dumps(result))
 
@@ -162,11 +170,47 @@ async def rtc_offer_base(
         # send to callback (eg. websocket)
         if event_callback:
             await event_callback(
-                event=PipelineEvent.FINAL_SUMMARY,
+                event=PipelineEvent.FINAL_SHORT_SUMMARY,
                 args=event_callback_args,
                 data=summary,
             )
 
+    async def on_final_long_summary(summary: FinalLongSummary):
+        ctx.logger.info("FinalLongSummary", final_summary=summary)
+
+        # send to RTC
+        if ctx.data_channel.readyState == "open":
+            result = {
+                "cmd": "DISPLAY_FINAL_LONG_SUMMARY",
+                "summary": summary.long_summary,
+                "duration": summary.duration,
+            }
+            ctx.data_channel.send(dumps(result))
+
+        # send to callback (eg. websocket)
+        if event_callback:
+            await event_callback(
+                event=PipelineEvent.FINAL_LONG_SUMMARY,
+                args=event_callback_args,
+                data=summary,
+            )
+
+    async def on_final_title(title: FinalTitle):
+        ctx.logger.info("FinalTitle", final_title=title)
+
+        # send to RTC
+        if ctx.data_channel.readyState == "open":
+            result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
+            ctx.data_channel.send(dumps(result))
+
+        # send to callback (eg. websocket)
+        if event_callback:
+            await event_callback(
+                event=PipelineEvent.FINAL_TITLE,
+                args=event_callback_args,
+                data=title,
+            )
+
     # create a context for the whole rtc transaction
     # add a customised logger to the context
     processors = []
@@ -178,7 +222,17 @@ async def rtc_offer_base(
         AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
         TranscriptLinerProcessor(),
         TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
-        TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
+        BroadcastProcessor(
+            processors=[
+                TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
+                TranscriptFinalLongSummaryProcessor.as_threaded(
+                    callback=on_final_long_summary
+                ),
+                TranscriptFinalShortSummaryProcessor.as_threaded(
+                    callback=on_final_short_summary
+                ),
+            ]
+        ),
     ]
     ctx.pipeline = Pipeline(*processors)
     ctx.pipeline.set_pref("audio:source_language", source_language)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 5aed7141..f4611817 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -7,7 +7,6 @@ from typing import Annotated, Optional
 from uuid import uuid4
 
 import av
-import reflector.auth as auth
 from fastapi import (
     APIRouter,
     Depends,
@@ -18,11 +17,13 @@ from fastapi import (
 )
 from fastapi_pagination import Page, paginate
 from pydantic import BaseModel, Field
+from starlette.concurrency import run_in_threadpool
+
+import reflector.auth as auth
 from reflector.db import database, transcripts
 from reflector.logger import logger
 from reflector.settings import settings
 from reflector.utils.audio_waveform import get_audio_waveform
-from starlette.concurrency import run_in_threadpool
 
 from ._range_requests_response import range_requests_response
 from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -60,8 +61,16 @@ class TranscriptTopic(BaseModel):
     timestamp: float
 
 
-class TranscriptFinalSummary(BaseModel):
-    summary: str
+class TranscriptFinalShortSummary(BaseModel):
+    short_summary: str
+
+
+class TranscriptFinalLongSummary(BaseModel):
+    long_summary: str
+
+
+class TranscriptFinalTitle(BaseModel):
+    title: str
 
 
 class TranscriptEvent(BaseModel):
@@ -77,7 +86,9 @@ class Transcript(BaseModel):
     locked: bool = False
     duration: float = 0
     created_at: datetime = Field(default_factory=datetime.utcnow)
-    summary: str | None = None
+    title: str | None = None
+    short_summary: str | None = None
+    long_summary: str | None = None
     topics: list[TranscriptTopic] = []
     events: list[TranscriptEvent] = []
     source_language: str = "en"
@@ -241,7 +252,9 @@ class GetTranscript(BaseModel):
     status: str
     locked: bool
     duration: int
-    summary: str | None
+    title: str | None
+    short_summary: str | None
+    long_summary: str | None
     created_at: datetime
     source_language: str
     target_language: str
@@ -256,7 +269,9 @@ class CreateTranscript(BaseModel):
 class UpdateTranscript(BaseModel):
     name: Optional[str] = Field(None)
     locked: Optional[bool] = Field(None)
-    summary: Optional[str] = Field(None)
+    title: Optional[str] = Field(None)
+    short_summary: Optional[str] = Field(None)
+    long_summary: Optional[str] = Field(None)
 
 
 class DeletionStatus(BaseModel):
@@ -315,20 +330,32 @@ async def transcript_update(
     transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")
-    values = {}
+    values = {"events": []}
     if info.name is not None:
         values["name"] = info.name
     if info.locked is not None:
         values["locked"] = info.locked
-    if info.summary is not None:
-        values["summary"] = info.summary
-        # also find FINAL_SUMMARY event and patch it
-        for te in transcript.events:
-            if te["event"] == PipelineEvent.FINAL_SUMMARY:
-                te["summary"] = info.summary
+    if info.long_summary is not None:
+        values["long_summary"] = info.long_summary
+        for transcript_event in transcript.events:
+            if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
+                transcript_event["long_summary"] = info.long_summary
                 break
-        values["events"] = transcript.events
-
+        values["events"].extend(transcript.events)
+    if info.short_summary is not None:
+        values["short_summary"] = info.short_summary
+        for transcript_event in transcript.events:
+            if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
+                transcript_event["short_summary"] = info.short_summary
+                break
+        values["events"].extend(transcript.events)
+    if info.title is not None:
+        values["title"] = info.title
+        for transcript_event in transcript.events:
+            if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
+                transcript_event["title"] = info.title
+                break
+        values["events"].extend(transcript.events)
     await transcripts_controller.update(transcript, values)
     return transcript
@@ -539,14 +566,38 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
             },
         )
 
-    elif event == PipelineEvent.FINAL_SUMMARY:
-        final_summary = TranscriptFinalSummary(summary=data.summary)
-        resp = transcript.add_event(event=event, data=final_summary)
+    elif event == PipelineEvent.FINAL_TITLE:
+        final_title = TranscriptFinalTitle(title=data.title)
+        resp = transcript.add_event(event=event, data=final_title)
         await transcripts_controller.update(
             transcript,
             {
                 "events": transcript.events_dump(),
-                "summary": final_summary.summary,
+                "title": final_title.title,
             },
         )
+
+    elif event == PipelineEvent.FINAL_LONG_SUMMARY:
+        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
+        resp = transcript.add_event(event=event, data=final_long_summary)
+        await transcripts_controller.update(
+            transcript,
+            {
+                "events": transcript.events_dump(),
+                "long_summary": final_long_summary.long_summary,
+            },
+        )
+
+    elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
+        final_short_summary = TranscriptFinalShortSummary(
+            short_summary=data.short_summary
+        )
+        resp = transcript.add_event(event=event, data=final_short_summary)
+        await transcripts_controller.update(
+            transcript,
+            {
+                "events": transcript.events_dump(),
+                "short_summary": final_short_summary.short_summary,
             },
         )
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index d219a282..d0b3a26f 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -1,3 +1,5 @@
+from unittest.mock import patch
+
 import pytest
 
 
@@ -14,3 +16,50 @@ async def setup_database():
     metadata.create_all(bind=engine)
 
     yield
+
+
+@pytest.fixture
+def dummy_processors():
+    with patch(
+        "reflector.processors.transcript_topic_detector.TranscriptTopicDetectorProcessor.get_topic"
+    ) as mock_topic, patch(
+        "reflector.processors.transcript_final_title.TranscriptFinalTitleProcessor.get_title"
+    ) as mock_title, patch(
+        "reflector.processors.transcript_final_long_summary.TranscriptFinalLongSummaryProcessor.get_long_summary"
+    ) as mock_long_summary, patch(
+        "reflector.processors.transcript_final_short_summary.TranscriptFinalShortSummaryProcessor.get_short_summary"
+    ) as mock_short_summary:
+        mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
+        mock_title.return_value = {"title": "LLM FINAL TITLE"}
+        mock_long_summary.return_value = "LLM LONG SUMMARY"
+        mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
+
+        yield mock_topic, mock_title, mock_long_summary, mock_short_summary
+
+
+@pytest.fixture
+async def dummy_llm():
+    from reflector.llm.base import LLM
+
+    class TestLLM(LLM):
+        def __init__(self):
+            self.model_name = "DUMMY MODEL"
+            self.llm_tokenizer = "DUMMY TOKENIZER"
+
+    with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
+        mock_llm.return_value = TestLLM()
+        yield
+
+
+@pytest.fixture
+def nltk():
+    with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
+        mock_nltk.return_value = "NLTK PACKAGE"
+        yield
+
+
+@pytest.fixture
+def ensure_casing():
+    with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
+        mock_casing.return_value = "LLM TITLE"
+        yield
diff --git a/server/tests/test_processors_broadcast.py b/server/tests/test_processors_broadcast.py
index fcddf31c..79aec14a 100644
--- a/server/tests/test_processors_broadcast.py
+++ b/server/tests/test_processors_broadcast.py
@@ -2,7 +2,7 @@ import pytest
 
 
 @pytest.mark.asyncio
-async def test_processor_broadcast():
+async def test_processor_broadcast(nltk):
     from reflector.processors.base import Processor, BroadcastProcessor, Pipeline
 
     class TestProcessor(Processor):
diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py
index 1831e5bd..996c0908 100644
--- a/server/tests/test_processors_pipeline.py
+++ b/server/tests/test_processors_pipeline.py
@@ -2,27 +2,19 @@ import pytest
 
 
 @pytest.mark.asyncio
-async def test_basic_process(event_loop):
+async def test_basic_process(
+    event_loop, nltk, dummy_llm, dummy_processors, ensure_casing
+):
     # goal is to start the server, and send rtc audio to it
     # validate the events received
     from reflector.tools.process import process_audio_file
     from reflector.settings import settings
-    from reflector.llm.base import LLM
     from pathlib import Path
 
     # use an LLM test backend
    settings.LLM_BACKEND = "test"
     settings.TRANSCRIPT_BACKEND = "whisper"
 
-    class LLMTest(LLM):
-        async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
-            return {
-                "title": "TITLE",
-                "summary": "SUMMARY",
-            }
-
-    LLM.register("test", LLMTest)
-
     # event callback
     marks = {}
 
@@ -39,4 +31,6 @@
     # validate the events
     assert marks["TranscriptLinerProcessor"] == 5
     assert marks["TranscriptTopicDetectorProcessor"] == 1
-    assert marks["TranscriptFinalSummaryProcessor"] == 1
+    assert marks["TranscriptFinalLongSummaryProcessor"] == 1
marks["TranscriptFinalLongSummaryProcessor"] == 1 + assert marks["TranscriptFinalShortSummaryProcessor"] == 1 + assert marks["TranscriptFinalTitleProcessor"] == 1 diff --git a/server/tests/test_transcripts.py b/server/tests/test_transcripts.py index 79c7a802..800d7a5c 100644 --- a/server/tests/test_transcripts.py +++ b/server/tests/test_transcripts.py @@ -75,21 +75,52 @@ async def test_transcript_get_update_summary(): async with AsyncClient(app=app, base_url="http://test/v1") as ac: response = await ac.post("/transcripts", json={"name": "test"}) assert response.status_code == 200 - assert response.json()["summary"] is None + assert response.json()["long_summary"] is None + assert response.json()["short_summary"] is None tid = response.json()["id"] response = await ac.get(f"/transcripts/{tid}") assert response.status_code == 200 - assert response.json()["summary"] is None + assert response.json()["long_summary"] is None + assert response.json()["short_summary"] is None - response = await ac.patch(f"/transcripts/{tid}", json={"summary": "test"}) + response = await ac.patch( + f"/transcripts/{tid}", + json={"long_summary": "test_long", "short_summary": "test_short"}, + ) assert response.status_code == 200 - assert response.json()["summary"] == "test" + assert response.json()["long_summary"] == "test_long" + assert response.json()["short_summary"] == "test_short" response = await ac.get(f"/transcripts/{tid}") assert response.status_code == 200 - assert response.json()["summary"] == "test" + assert response.json()["long_summary"] == "test_long" + assert response.json()["short_summary"] == "test_short" + + +@pytest.mark.asyncio +async def test_transcript_get_update_title(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["title"] is None + + tid = response.json()["id"] + + response = await ac.get(f"/transcripts/{tid}") + assert response.status_code == 200 + assert response.json()["title"] is None + + response = await ac.patch(f"/transcripts/{tid}", json={"title": "test_title"}) + assert response.status_code == 200 + assert response.json()["title"] == "test_title" + + response = await ac.get(f"/transcripts/{tid}") + assert response.status_code == 200 + assert response.json()["title"] == "test_title" @pytest.mark.asyncio diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index c6adf320..f298e596 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -67,21 +67,10 @@ async def dummy_transcript(): yield -@pytest.fixture -async def dummy_llm(): - from reflector.llm.base import LLM - - class TestLLM(LLM): - async def _generate(self, prompt: str, schema: dict | None, **kwargs): - return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"}) - - with patch("reflector.llm.base.LLM.get_instance") as mock_llm: - mock_llm.return_value = TestLLM() - yield - - @pytest.mark.asyncio -async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm): +async def test_transcript_rtc_and_websocket( + tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing +): # goal: start the server, exchange RTC, receive websocket events # because of that, we need to start the server in a thread # to be able to connect with aiortc @@ -186,9 +175,17 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, 
     assert ev["data"]["transcript"].startswith("Hello world")
     assert ev["data"]["timestamp"] == 0.0
 
-    assert "FINAL_SUMMARY" in eventnames
-    ev = events[eventnames.index("FINAL_SUMMARY")]
-    assert ev["data"]["summary"] == "LLM SUMMARY"
+    assert "FINAL_LONG_SUMMARY" in eventnames
+    ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
+    assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
+
+    assert "FINAL_SHORT_SUMMARY" in eventnames
+    ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
+    assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
+
+    assert "FINAL_TITLE" in eventnames
+    ev = events[eventnames.index("FINAL_TITLE")]
+    assert ev["data"]["title"] == "LLM FINAL TITLE"
 
     # check status order
     statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
@@ -218,7 +215,9 @@
 
 
 @pytest.mark.asyncio
-async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dummy_llm):
+async def test_transcript_rtc_and_websocket_and_fr(
+    tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
+):
     # goal: start the server, exchange RTC, receive websocket events
     # because of that, we need to start the server in a thread
     # to be able to connect with aiortc
@@ -326,9 +325,17 @@ async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dum
     assert ev["data"]["transcript"].startswith("Hello world")
     assert ev["data"]["timestamp"] == 0.0
 
-    assert "FINAL_SUMMARY" in eventnames
-    ev = events[eventnames.index("FINAL_SUMMARY")]
-    assert ev["data"]["summary"] == "LLM SUMMARY"
+    assert "FINAL_LONG_SUMMARY" in eventnames
+    ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
+    assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
+
+    assert "FINAL_SHORT_SUMMARY" in eventnames
+    ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
+    assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
+
+    assert "FINAL_TITLE" in eventnames
+    ev = events[eventnames.index("FINAL_TITLE")]
+    assert ev["data"]["title"] == "LLM FINAL TITLE"
 
     # check status order
     statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
diff --git a/www/app/transcripts/useWebSockets.ts b/www/app/transcripts/useWebSockets.ts
index cbe257a7..ee31cc5a 100644
--- a/www/app/transcripts/useWebSockets.ts
+++ b/www/app/transcripts/useWebSockets.ts
@@ -94,6 +94,13 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
         break;
 
       case "FINAL_LONG_SUMMARY":
+        if (message.data) {
+          message.data = { summary: message.data.long_summary };
+          setFinalSummary(message.data);
+          console.debug("FINAL_LONG_SUMMARY event:", message.data);
+        }
+        break;
+
       case "FINAL_SUMMARY":
         if (message.data) {
           setFinalSummary(message.data);