diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 77be7317..14e0554b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,8 +32,8 @@ repos: files: ^server/(gpu|evaluate|reflector)/ args: [ "--profile", "black", "--filter-files" ] - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.282 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.5 hooks: - id: ruff files: ^server/(reflector|tests)/ diff --git a/server/gpu/modal_deployments/reflector_vllm_hermes3.py b/server/gpu/modal_deployments/reflector_vllm_hermes3.py new file mode 100644 index 00000000..d1c86be7 --- /dev/null +++ b/server/gpu/modal_deployments/reflector_vllm_hermes3.py @@ -0,0 +1,171 @@ +# # Run an OpenAI-Compatible vLLM Server + +import modal + +MODELS_DIR = "/llamas" +MODEL_NAME = "NousResearch/Hermes-3-Llama-3.1-8B" +N_GPU = 1 + + +def download_llm(): + from huggingface_hub import snapshot_download + + print("Downloading LLM model") + snapshot_download( + MODEL_NAME, + local_dir=f"{MODELS_DIR}/{MODEL_NAME}", + ignore_patterns=[ + "*.pt", + "*.bin", + "*.pth", + "original/*", + ], # Ensure safetensors + ) + print("LLM model downloaded") + + +def move_cache(): + from transformers.utils import move_cache as transformers_move_cache + + transformers_move_cache() + + +vllm_image = ( + modal.Image.debian_slim(python_version="3.10") + .pip_install("vllm==0.5.3post1") + .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) + .pip_install( + # "accelerate==0.34.2", + "einops==0.8.0", + "hf-transfer~=0.1", + ) + .run_function(download_llm) + .run_function(move_cache) + .pip_install( + "bitsandbytes>=0.42.9", + ) +) + +app = modal.App("reflector-vllm-hermes3") + + +@app.function( + image=vllm_image, + gpu=modal.gpu.A100(count=N_GPU, size="40GB"), + timeout=60 * 5, + container_idle_timeout=60 * 5, + allow_concurrent_inputs=100, + secrets=[ + modal.Secret.from_name("reflector-gpu"), + ], +) +@modal.asgi_app() +def serve(): + import os + + import fastapi + import vllm.entrypoints.openai.api_server as api_server + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.entrypoints.logger import RequestLogger + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion + from vllm.usage.usage_lib import UsageContext + + TOKEN = os.environ["REFLECTOR_GPU_APIKEY"] + + # create a fastAPI app that uses vLLM's OpenAI-compatible router + web_app = fastapi.FastAPI( + title=f"OpenAI-compatible {MODEL_NAME} server", + description="Run an OpenAI-compatible LLM server with vLLM on modal.com", + version="0.0.1", + docs_url="/docs", + ) + + # security: CORS middleware for external requests + http_bearer = fastapi.security.HTTPBearer( + scheme_name="Bearer Token", + description="See code for authentication details.", + ) + web_app.add_middleware( + fastapi.middleware.cors.CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # security: inject dependency on authed routes + async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): + if api_key.credentials != TOKEN: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + ) + return {"username": "authenticated_user"} + + router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) + + # wrap vllm's router in auth router + 
router.include_router(api_server.router) + # add authed vllm to our fastAPI app + web_app.include_router(router) + + engine_args = AsyncEngineArgs( + model=MODELS_DIR + "/" + MODEL_NAME, + tensor_parallel_size=N_GPU, + gpu_memory_utilization=0.90, + # max_model_len=8096, + enforce_eager=False, # capture the graph for faster inference, but slower cold starts (30s > 20s) + # --- 4 bits load + # quantization="bitsandbytes", + # load_format="bitsandbytes", + ) + + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER + ) + + model_config = get_model_config(engine) + + request_logger = RequestLogger(max_log_len=2048) + + api_server.openai_serving_chat = OpenAIServingChat( + engine, + model_config=model_config, + served_model_names=[MODEL_NAME], + chat_template=None, + response_role="assistant", + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + ) + api_server.openai_serving_completion = OpenAIServingCompletion( + engine, + model_config=model_config, + served_model_names=[MODEL_NAME], + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + ) + + return web_app + + +def get_model_config(engine): + import asyncio + + try: # adapted from vLLM source -- https://github.com/vllm-project/vllm/blob/507ef787d85dec24490069ffceacbd6b161f4f72/vllm/entrypoints/openai/api_server.py#L235C1-L247C1 + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running(): + # If the current is instanced by Ray Serve, + # there is already a running event loop + model_config = event_loop.run_until_complete(engine.get_model_config()) + else: + # When using single vLLM without engine_use_ray + model_config = asyncio.run(engine.get_model_config()) + + return model_config diff --git a/server/poetry.lock b/server/poetry.lock index 1d9882ea..1d6cabeb 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1733,6 +1733,41 @@ files = [ {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"}, ] +[[package]] +name = "jsonschema" +version = "4.23.0" +description = "An implementation of JSON Schema validation for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, + {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +jsonschema-specifications = ">=2023.03.6" +referencing = ">=0.28.4" +rpds-py = ">=0.7.1" + +[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"] + +[[package]] +name = "jsonschema-specifications" +version = "2023.12.1" +description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"}, + {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"}, +] + +[package.dependencies] +referencing 
= ">=0.31.0" + [[package]] name = "jwcrypto" version = "1.5.0" @@ -3094,6 +3129,21 @@ async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2 hiredis = ["hiredis (>=1.0.0)"] ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] +[[package]] +name = "referencing" +version = "0.35.1" +description = "JSON Referencing + Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"}, + {file = "referencing-0.35.1.tar.gz", hash = "sha256:25b42124a6c8b632a425174f24087783efb348a6f1e0008e63cd4466fedf703c"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +rpds-py = ">=0.7.0" + [[package]] name = "regex" version = "2023.10.3" @@ -3212,6 +3262,118 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rpds-py" +version = "0.20.0" +description = "Python bindings to Rust's persistent data structures (rpds)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, + {file = "rpds_py-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9bb4a0d90fdb03437c109a17eade42dfbf6190408f29b2744114d11586611d6f"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6377e647bbfd0a0b159fe557f2c6c602c159fc752fa316572f012fc0bf67150"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb851b7df9dda52dc1415ebee12362047ce771fc36914586b2e9fcbd7d293b3e"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e0f80b739e5a8f54837be5d5c924483996b603d5502bfff79bf33da06164ee2"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a8c94dad2e45324fc74dce25e1645d4d14df9a4e54a30fa0ae8bad9a63928e3"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e604fe73ba048c06085beaf51147eaec7df856824bfe7b98657cf436623daf"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:df3de6b7726b52966edf29663e57306b23ef775faf0ac01a3e9f4012a24a4140"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf258ede5bc22a45c8e726b29835b9303c285ab46fc7c3a4cc770736b5304c9f"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:55fea87029cded5df854ca7e192ec7bdb7ecd1d9a3f63d5c4eb09148acf4a7ce"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae94bd0b2f02c28e199e9bc51485d0c5601f58780636185660f86bf80c89af94"}, + {file = "rpds_py-0.20.0-cp310-none-win32.whl", hash = "sha256:28527c685f237c05445efec62426d285e47a58fb05ba0090a4340b73ecda6dee"}, + {file = "rpds_py-0.20.0-cp310-none-win_amd64.whl", hash = "sha256:238a2d5b1cad28cdc6ed15faf93a998336eb041c4e440dd7f902528b8891b399"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac2f4f7a98934c2ed6505aead07b979e6f999389f16b714448fb39bbaa86a489"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:220002c1b846db9afd83371d08d239fdc865e8f8c5795bbaec20916a76db3318"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:8d7919548df3f25374a1f5d01fbcd38dacab338ef5f33e044744b5c36729c8db"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:758406267907b3781beee0f0edfe4a179fbd97c0be2e9b1154d7f0a1279cf8e5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d61339e9f84a3f0767b1995adfb171a0d00a1185192718a17af6e124728e0f5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1259c7b3705ac0a0bd38197565a5d603218591d3f6cee6e614e380b6ba61c6f6"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1dc0f53856b9cc9a0ccca0a7cc61d3d20a7088201c0937f3f4048c1718a209"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7e60cb630f674a31f0368ed32b2a6b4331b8350d67de53c0359992444b116dd3"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbe982f38565bb50cb7fb061ebf762c2f254ca3d8c20d4006878766e84266272"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:514b3293b64187172bc77c8fb0cdae26981618021053b30d8371c3a902d4d5ad"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d0a26ffe9d4dd35e4dfdd1e71f46401cff0181c75ac174711ccff0459135fa58"}, + {file = "rpds_py-0.20.0-cp311-none-win32.whl", hash = "sha256:89c19a494bf3ad08c1da49445cc5d13d8fefc265f48ee7e7556839acdacf69d0"}, + {file = "rpds_py-0.20.0-cp311-none-win_amd64.whl", hash = "sha256:c638144ce971df84650d3ed0096e2ae7af8e62ecbbb7b201c8935c370df00a2c"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a84ab91cbe7aab97f7446652d0ed37d35b68a465aeef8fc41932a9d7eee2c1a6"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:56e27147a5a4c2c21633ff8475d185734c0e4befd1c989b5b95a5d0db699b21b"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2580b0c34583b85efec8c5c5ec9edf2dfe817330cc882ee972ae650e7b5ef739"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b80d4a7900cf6b66bb9cee5c352b2d708e29e5a37fe9bf784fa97fc11504bf6c"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50eccbf054e62a7b2209b28dc7a22d6254860209d6753e6b78cfaeb0075d7bee"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:49a8063ea4296b3a7e81a5dfb8f7b2d73f0b1c20c2af401fb0cdf22e14711a96"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea438162a9fcbee3ecf36c23e6c68237479f89f962f82dae83dc15feeceb37e4"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18d7585c463087bddcfa74c2ba267339f14f2515158ac4db30b1f9cbdb62c8ef"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4c7d1a051eeb39f5c9547e82ea27cbcc28338482242e3e0b7768033cb083821"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e4df1e3b3bec320790f699890d41c59d250f6beda159ea3c44c3f5bac1976940"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2cf126d33a91ee6eedc7f3197b53e87a2acdac63602c0f03a02dd69e4b138174"}, + {file = "rpds_py-0.20.0-cp312-none-win32.whl", hash = "sha256:8bc7690f7caee50b04a79bf017a8d020c1f48c2a1077ffe172abec59870f1139"}, + {file = 
"rpds_py-0.20.0-cp312-none-win_amd64.whl", hash = "sha256:0e13e6952ef264c40587d510ad676a988df19adea20444c2b295e536457bc585"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa9a0521aeca7d4941499a73ad7d4f8ffa3d1affc50b9ea11d992cd7eff18a29"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1f1d51eccb7e6c32ae89243cb352389228ea62f89cd80823ea7dd1b98e0b91"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a86a9b96070674fc88b6f9f71a97d2c1d3e5165574615d1f9168ecba4cecb24"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c8ef2ebf76df43f5750b46851ed1cdf8f109d7787ca40035fe19fbdc1acc5a7"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b25f024b421d5859d156750ea9a65651793d51b76a2e9238c05c9d5f203a9"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57eb94a8c16ab08fef6404301c38318e2c5a32216bf5de453e2714c964c125c8"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1940dae14e715e2e02dfd5b0f64a52e8374a517a1e531ad9412319dc3ac7879"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d20277fd62e1b992a50c43f13fbe13277a31f8c9f70d59759c88f644d66c619f"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:06db23d43f26478303e954c34c75182356ca9aa7797d22c5345b16871ab9c45c"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2a5db5397d82fa847e4c624b0c98fe59d2d9b7cf0ce6de09e4d2e80f8f5b3f2"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a35df9f5548fd79cb2f52d27182108c3e6641a4feb0f39067911bf2adaa3e57"}, + {file = "rpds_py-0.20.0-cp313-none-win32.whl", hash = "sha256:fd2d84f40633bc475ef2d5490b9c19543fbf18596dcb1b291e3a12ea5d722f7a"}, + {file = "rpds_py-0.20.0-cp313-none-win_amd64.whl", hash = "sha256:9bc2d153989e3216b0559251b0c260cfd168ec78b1fac33dd485750a228db5a2"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f2fbf7db2012d4876fb0d66b5b9ba6591197b0f165db8d99371d976546472a24"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e5f3cd7397c8f86c8cc72d5a791071431c108edd79872cdd96e00abd8497d29"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce9845054c13696f7af7f2b353e6b4f676dab1b4b215d7fe5e05c6f8bb06f965"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c3e130fd0ec56cb76eb49ef52faead8ff09d13f4527e9b0c400307ff72b408e1"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b16aa0107ecb512b568244ef461f27697164d9a68d8b35090e9b0c1c8b27752"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7f429242aae2947246587d2964fad750b79e8c233a2367f71b554e9447949c"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af0fc424a5842a11e28956e69395fbbeab2c97c42253169d87e90aac2886d751"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8c00a3b1e70c1d3891f0db1b05292747f0dbcfb49c43f9244d04c70fbc40eb8"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = 
"sha256:40ce74fc86ee4645d0a225498d091d8bc61f39b709ebef8204cb8b5a464d3c0e"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4fe84294c7019456e56d93e8ababdad5a329cd25975be749c3f5f558abb48253"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:338ca4539aad4ce70a656e5187a3a31c5204f261aef9f6ab50e50bcdffaf050a"}, + {file = "rpds_py-0.20.0-cp38-none-win32.whl", hash = "sha256:54b43a2b07db18314669092bb2de584524d1ef414588780261e31e85846c26a5"}, + {file = "rpds_py-0.20.0-cp38-none-win_amd64.whl", hash = "sha256:a1862d2d7ce1674cffa6d186d53ca95c6e17ed2b06b3f4c476173565c862d232"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3fde368e9140312b6e8b6c09fb9f8c8c2f00999d1823403ae90cc00480221b22"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9824fb430c9cf9af743cf7aaf6707bf14323fb51ee74425c380f4c846ea70789"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11ef6ce74616342888b69878d45e9f779b95d4bd48b382a229fe624a409b72c5"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52d3f2f82b763a24ef52f5d24358553e8403ce05f893b5347098014f2d9eff2"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d35cef91e59ebbeaa45214861874bc6f19eb35de96db73e467a8358d701a96c"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72278a30111e5b5525c1dd96120d9e958464316f55adb030433ea905866f4de"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c29cbbba378759ac5786730d1c3cb4ec6f8ababf5c42a9ce303dc4b3d08cda"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6632f2d04f15d1bd6fe0eedd3b86d9061b836ddca4c03d5cf5c7e9e6b7c14580"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d0b67d87bb45ed1cd020e8fbf2307d449b68abc45402fe1a4ac9e46c3c8b192b"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec31a99ca63bf3cd7f1a5ac9fe95c5e2d060d3c768a09bc1d16e235840861420"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22e6c9976e38f4d8c4a63bd8a8edac5307dffd3ee7e6026d97f3cc3a2dc02a0b"}, + {file = "rpds_py-0.20.0-cp39-none-win32.whl", hash = "sha256:569b3ea770c2717b730b61998b6c54996adee3cef69fc28d444f3e7920313cf7"}, + {file = "rpds_py-0.20.0-cp39-none-win_amd64.whl", hash = "sha256:e6900ecdd50ce0facf703f7a00df12374b74bbc8ad9fe0f6559947fb20f82364"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:617c7357272c67696fd052811e352ac54ed1d9b49ab370261a80d3b6ce385045"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9426133526f69fcaba6e42146b4e12d6bc6c839b8b555097020e2b78ce908dcc"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deb62214c42a261cb3eb04d474f7155279c1a8a8c30ac89b7dcb1721d92c3c02"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcaeb7b57f1a1e071ebd748984359fef83ecb026325b9d4ca847c95bc7311c92"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d454b8749b4bd70dd0a79f428731ee263fa6995f83ccb8bada706e8d1d3ff89d"}, + {file = 
"rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d807dc2051abe041b6649681dce568f8e10668e3c1c6543ebae58f2d7e617855"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c20f0ddeb6e29126d45f89206b8291352b8c5b44384e78a6499d68b52ae511"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7f19250ceef892adf27f0399b9e5afad019288e9be756d6919cb58892129f51"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4f1ed4749a08379555cebf4650453f14452eaa9c43d0a95c49db50c18b7da075"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dcedf0b42bcb4cfff4101d7771a10532415a6106062f005ab97d1d0ab5681c60"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39ed0d010457a78f54090fafb5d108501b5aa5604cc22408fc1c0c77eac14344"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f918a1a130a6dfe1d7fe0f105064141342e7dd1611f2e6a21cd2f5c8cb1cfb3e"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f60012a73aa396be721558caa3a6fd49b3dd0033d1675c6d59c4502e870fcf0c"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d2b1ad682a3dfda2a4e8ad8572f3100f95fad98cb99faf37ff0ddfe9cbf9d03"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:614fdafe9f5f19c63ea02817fa4861c606a59a604a77c8cdef5aa01d28b97921"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa518bcd7600c584bf42e6617ee8132869e877db2f76bcdc281ec6a4113a53ab"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0475242f447cc6cb8a9dd486d68b2ef7fbee84427124c232bff5f63b1fe11e5"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90a4cd061914a60bd51c68bcb4357086991bd0bb93d8aa66a6da7701370708f"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:def7400461c3a3f26e49078302e1c1b38f6752342c77e3cf72ce91ca69fb1bc1"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:65794e4048ee837494aea3c21a28ad5fc080994dfba5b036cf84de37f7ad5074"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:faefcc78f53a88f3076b7f8be0a8f8d35133a3ecf7f3770895c25f8813460f08"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:5b4f105deeffa28bbcdff6c49b34e74903139afa690e35d2d9e3c2c2fba18cec"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdfc3a892927458d98f3d55428ae46b921d1f7543b89382fdb483f5640daaec8"}, + {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, +] + [[package]] name = "rsa" version = "4.9" @@ -4321,4 +4483,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "8b64f7a5e8282cedf8f508c9f85ed233222045bdccc49d7c4ea96cf4bf8f902b" +content-hash = "3ba5402ab9fbec271f255c345c89c67385c722735b1d79a959e887c7c34a4047" diff --git a/server/pyproject.toml b/server/pyproject.toml 
index 6cdda4cb..caa72483 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -38,6 +38,7 @@ python-multipart = "^0.0.6" faster-whisper = "^0.10.0" transformers = "^4.36.2" black = "24.1.1" +jsonschema = "^4.23.0" [tool.poetry.group.dev.dependencies] @@ -78,4 +79,5 @@ addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v" testpaths = ["tests"] asyncio_mode = "auto" - +[tool.ruff.lint.per-file-ignores] +"reflector/processors/summary/summary_builder.py" = ["E501"] diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index a1bb3ba8..a934b6f0 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -156,6 +156,27 @@ class LLM: return result + async def completion( + self, messages: list, logger: reflector_logger, **kwargs + ) -> dict: + """ + Use /v1/chat/completion Open-AI compatible endpoint from the URL + It's up to the user to validate anything or transform the result + """ + logger.info("LLM completions", messages=messages) + + try: + with self.m_generate.time(): + result = await retry(self._completion)(messages=messages, **kwargs) + self.m_generate_success.inc() + except Exception: + logger.exception("Failed to call llm after retrying") + self.m_generate_failure.inc() + raise + + logger.debug("LLM completion result", result=repr(result)) + return result + def ensure_casing(self, title: str) -> str: """ LLM takes care of word casing, but in rare cases this @@ -234,6 +255,11 @@ class LLM: ) -> str: raise NotImplementedError + async def _completion( + self, messages: list, logger: reflector_logger, **kwargs + ) -> dict: + raise NotImplementedError + def _parse_json(self, result: str) -> dict: result = result.strip() # try detecting code block if exist diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py index 4b81c5a0..63eb4db4 100644 --- a/server/reflector/llm/llm_modal.py +++ b/server/reflector/llm/llm_modal.py @@ -23,7 +23,11 @@ class ModalLLM(LLM): """ # TODO: Query the specific GPU platform # Replace this with a HTTP call - return ["lmsys/vicuna-13b-v1.5", "HuggingFaceH4/zephyr-7b-alpha"] + return [ + "lmsys/vicuna-13b-v1.5", + "HuggingFaceH4/zephyr-7b-alpha", + "NousResearch/Hermes-3-Llama-3.1-8B", + ] async def _generate( self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs @@ -53,6 +57,31 @@ class ModalLLM(LLM): text = response.json()["text"] return text + async def _completion(self, messages: list, **kwargs) -> dict: + kwargs.setdefault("temperature", 0.3) + kwargs.setdefault("max_tokens", 2048) + kwargs.setdefault("stream", False) + kwargs.setdefault("repetition_penalty", 1) + kwargs.setdefault("top_p", 1) + kwargs.setdefault("top_k", -1) + kwargs.setdefault("min_p", 0.05) + data = {"messages": messages, "model": self.model_name, **kwargs} + + if self.model_name == "NousResearch/Hermes-3-Llama-3.1-8B": + self.llm_url = settings.HERMES_3_8B_LLM_URL + "/v1/chat/completions" + + async with httpx.AsyncClient() as client: + response = await retry(client.post)( + self.llm_url, + headers=self.headers, + json=data, + timeout=self.timeout, + retry_timeout=60 * 5, + follow_redirects=True, + ) + response.raise_for_status() + return response.json() + def _set_model_name(self, model_name: str) -> bool: """ Set the model name diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 2ae79c2f..7377e7a4 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ 
b/server/reflector/pipelines/main_live_pipeline.py @@ -38,10 +38,8 @@ from reflector.processors import ( AudioFileWriterProcessor, AudioMergeProcessor, AudioTranscriptAutoProcessor, - BroadcastProcessor, Pipeline, - TranscriptFinalLongSummaryProcessor, - TranscriptFinalShortSummaryProcessor, + TranscriptFinalSummaryProcessor, TranscriptFinalTitleProcessor, TranscriptLinerProcessor, TranscriptTopicDetectorProcessor, @@ -424,21 +422,14 @@ class PipelineMainFromTopics(PipelineMainBase): return pipeline -class PipelineMainTitleAndShortSummary(PipelineMainFromTopics): +class PipelineMainTitle(PipelineMainFromTopics): """ Generate title from the topics """ def get_processors(self) -> list: return [ - BroadcastProcessor( - processors=[ - TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=self.on_short_summary - ), - ] - ) + TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), ] @@ -449,15 +440,10 @@ class PipelineMainFinalSummaries(PipelineMainFromTopics): def get_processors(self) -> list: return [ - BroadcastProcessor( - processors=[ - TranscriptFinalLongSummaryProcessor.as_threaded( - callback=self.on_long_summary - ), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=self.on_short_summary - ), - ] + TranscriptFinalSummaryProcessor.as_threaded( + transcript=self._transcript, + callback=self.on_long_summary, + on_short_summary=self.on_short_summary, ), ] @@ -552,11 +538,11 @@ async def pipeline_diarization(transcript: Transcript, logger: Logger): @get_transcript -async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger): - logger.info("Starting title and short summary") - runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id) +async def pipeline_title(transcript: Transcript, logger: Logger): + logger.info("Starting title") + runner = PipelineMainTitle(transcript_id=transcript.id) await runner.run() - logger.info("Title and short summary done") + logger.info("Title done") @get_transcript @@ -632,8 +618,8 @@ async def task_pipeline_diarization(*, transcript_id: str): @shared_task @asynctask -async def task_pipeline_title_and_short_summary(*, transcript_id: str): - await pipeline_title_and_short_summary(transcript_id=transcript_id) +async def task_pipeline_title(*, transcript_id: str): + await pipeline_title(transcript_id=transcript_id) @shared_task @@ -659,9 +645,7 @@ def pipeline_post(*, transcript_id: str): | task_pipeline_remove_upload.si(transcript_id=transcript_id) | task_pipeline_diarization.si(transcript_id=transcript_id) ) - chain_title_preview = task_pipeline_title_and_short_summary.si( - transcript_id=transcript_id - ) + chain_title_preview = task_pipeline_title.si(transcript_id=transcript_id) chain_final_summaries = task_pipeline_final_summaries.si( transcript_id=transcript_id ) diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index ed3118ae..5f0e6864 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -39,7 +39,7 @@ class PipelineRunner(BaseModel): runner_cls=self.__class__.__name__, ) - def create(self) -> Pipeline: + async def create(self) -> Pipeline: """ Create the pipeline if not specified earlier. 
Should be implemented in a subclass diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py index 1c88d6c5..0de73350 100644 --- a/server/reflector/processors/__init__.py +++ b/server/reflector/processors/__init__.py @@ -11,12 +11,7 @@ from .base import ( # noqa: F401 Processor, ThreadedProcessor, ) -from .transcript_final_long_summary import ( # noqa: F401 - TranscriptFinalLongSummaryProcessor, -) -from .transcript_final_short_summary import ( # noqa: F401 - TranscriptFinalShortSummaryProcessor, -) +from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401 from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401 from .transcript_liner import TranscriptLinerProcessor # noqa: F401 from .transcript_topic_detector import TranscriptTopicDetectorProcessor # noqa: F401 diff --git a/server/reflector/processors/summary/summary_builder.py b/server/reflector/processors/summary/summary_builder.py new file mode 100644 index 00000000..987cf53b --- /dev/null +++ b/server/reflector/processors/summary/summary_builder.py @@ -0,0 +1,851 @@ +""" +# Summary meeting notes + +This script is used to generate a summary of a meeting notes transcript. +""" + +import asyncio +import json +import re +import sys +from datetime import datetime +from enum import Enum +from functools import partial + +import jsonschema +import structlog +from reflector.llm.base import LLM +from transformers import AutoTokenizer + +JSON_SCHEMA_LIST_STRING = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "array", + "items": {"type": "string"}, +} + +JSON_SCHEMA_ACTION_ITEMS = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "array", + "items": { + "type": "object", + "properties": { + "content": {"type": "string"}, + "assigned_to": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + }, + }, + "required": ["content"], + }, +} + +JSON_SCHEMA_DECISIONS_OR_OPEN_QUESTIONS = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "array", + "items": { + "type": "object", + "properties": {"content": {"type": "string"}}, + "required": ["content"], + }, +} + +JSON_SCHEMA_TRANSCRIPTION_TYPE = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "transcription_type": {"type": "string", "enum": ["meeting", "podcast"]}, + }, + "required": ["transcription_type"], +} + + +class ItemType(Enum): + ACTION_ITEM = "action-item" + DECISION = "decision" + OPEN_QUESTION = "open-question" + + +class TranscriptionType(Enum): + MEETING = "meeting" + PODCAST = "podcast" + + +class Messages: + """ + Manage the LLM context for conversational messages, with roles (system, user, assistant). 
+ """ + + def __init__(self, messages=None, model_name=None, tokenizer=None, logger=None): + self.messages = messages or [] + self.model_name = model_name + self.tokenizer = tokenizer + self.logger = logger + + def set_model(self, model): + self.model_name = model + + def set_logger(self, logger): + self.logger = logger + + def copy(self): + m = Messages( + self.messages[:], + model_name=self.model_name, + tokenizer=self.tokenizer, + logger=self.logger, + ) + return m + + def add_system(self, content: str): + self.add("system", content) + self.print_content("SYSTEM", content) + + def add_user(self, content: str): + self.add("user", content) + self.print_content("USER", content) + + def add_assistant(self, content: str): + self.add("assistant", content) + self.print_content("ASSISTANT", content) + + def add(self, role: str, content: str): + self.messages.append({"role": role, "content": content}) + + def get_tokenizer(self): + if not self.tokenizer: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + return self.tokenizer + + def count_tokens(self): + tokenizer = self.get_tokenizer() + total_tokens = 0 + for message in self.messages: + total_tokens += len(tokenizer.tokenize(message["content"])) + return total_tokens + + def get_tokens_count(self, message): + tokenizer = self.get_tokenizer() + return len(tokenizer.tokenize(message)) + + def print_content(self, role, content): + if not self.logger: + return + for line in content.split("\n"): + self.logger.info(f">> {role}: {line}") + + +class SummaryBuilder: + def __init__(self, llm, filename: str | None = None, logger=None): + self.transcript: str | None = None + self.recap: str | None = None + self.summaries: list[dict] = [] + self.subjects: list[str] = [] + self.items_action: list = [] + self.items_decision: list = [] + self.items_question: list = [] + self.transcription_type: TranscriptionType | None = None + self.llm_instance: LLM = llm + self.model_name: str = llm.model_name + self.logger = logger or structlog.get_logger() + self.m = Messages(model_name=self.model_name, logger=self.logger) + if filename: + self.read_transcript_from_file(filename) + + def read_transcript_from_file(self, filename): + """ + Load a transcript from a text file. + Must be formatted as: + + speaker: message + speaker2: message2 + + """ + with open(filename, "r", encoding="utf-8") as f: + self.transcript = f.read().strip() + + def set_transcript(self, transcript: str): + assert isinstance(transcript, str) + self.transcript = transcript + + def set_llm_instance(self, llm): + self.llm_instance = llm + + # ---------------------------------------------------------------------------- + # Participants + # ---------------------------------------------------------------------------- + + async def identify_participants(self): + """ + From a transcript, try identify the participants. + This might not give the best result without good diarization, but it's a start. + They are appended at the end of the transcript, providing more context for the assistant. + """ + + self.logger.debug("--- identify_participants") + + m = Messages(model_name=self.model_name) + m.add_system( + "You are an advanced note-taking assistant." + "You'll be given a transcript, and identify the participants." + ) + m.add_user( + f"# Transcript\n\n{self.transcript}\n\n" + "---\n\n" + "Please identify the participants in the conversation." + "Each participant should only be listed once, even if they are mentionned multiple times in the conversation." 
+ "Participants are real people who are part of the conversation and the speakers." + "You can put participants that are mentioned by name." + "Do not put company name." + "Ensure that no duplicate names are included." + "Output the result in JSON format following the schema: " + f"\n```json-schema\n{JSON_SCHEMA_LIST_STRING}\n```" + ) + result = await self.llm( + m, + [ + self.validate_json, + partial(self.validate_json_schema, JSON_SCHEMA_LIST_STRING), + ], + ) + + # augment the transcript with the participants. + participants = self.format_list_md(result) + self.transcript += f"\n\n# Participants\n\n{participants}" + + # ---------------------------------------------------------------------------- + # Transcription identification + # ---------------------------------------------------------------------------- + + async def identify_transcription_type(self): + """ + Identify the type of transcription: meeting or podcast. + """ + + self.logger.debug("--- identify transcription type") + + m = Messages(model_name=self.model_name, logger=self.logger) + m.add_system( + "You are an advanced assistant specialize to recognize the type of an audio transcription." + "It could be a meeting or a podcast." + ) + m.add_user( + f"# Transcript\n\n{self.transcript}\n\n" + "---\n\n" + "Please identify the type of transcription (meeting or podcast). " + "Output the result in JSON format following the schema:" + f"\n```json-schema\n{JSON_SCHEMA_TRANSCRIPTION_TYPE}\n```" + ) + result = await self.llm( + m, + [ + self.validate_json, + partial(self.validate_json_schema, JSON_SCHEMA_TRANSCRIPTION_TYPE), + ], + ) + + transcription_type = result["transcription_type"] + self.transcription_type = TranscriptionType(transcription_type) + + # ---------------------------------------------------------------------------- + # Items + # ---------------------------------------------------------------------------- + + async def generate_items( + self, + search_action=False, + search_decision=False, + search_open_question=False, + ): + """ + Build a list of item about action, decision or question + """ + # require key subjects + if not self.subjects or not self.summaries: + await self.generate_summary() + + self.logger.debug("--- items") + + self.items_action = [] + self.items_decision = [] + self.items_question = [] + + item_types = [] + if search_action: + item_types.append(ItemType.ACTION_ITEM) + if search_decision: + item_types.append(ItemType.DECISION) + if search_open_question: + item_types.append(ItemType.OPEN_QUESTION) + + ## Version asking everything in one go + for item_type in item_types: + if item_type == ItemType.ACTION_ITEM: + json_schema = JSON_SCHEMA_ACTION_ITEMS + items = self.items_action + prompt_definition = ( + "An action item is a specific, actionable task designed to achieve a concrete outcome;" + "An action item scope is narrow, focused on short-term execution; " + "An action item is generally assigned to a specific person or team. " + "An action item is NOT a decision, a question, or a general topic. " + "For example: 'Gary, please send the report by Friday.' is an action item." + "But: 'I though Gary was here today. Anyway, somebody need to do an analysis.' is not an action item." + "The field assigned_to must contain a valid participant or person mentionned in the transcript." 
+ ) + + elif item_type == ItemType.DECISION: + json_schema = JSON_SCHEMA_DECISIONS_OR_OPEN_QUESTIONS + items = self.items_decision + prompt_definition = ( + "A decision defines a broad or strategic direction or course of action;" + "It's more about setting the framework, high-level goals, or vision for what needs to happen;" + "A decision's scope often affects multiple areas of the organization, and it's more about long-term impact." + ) + + elif item_type == ItemType.OPEN_QUESTION: + json_schema = JSON_SCHEMA_DECISIONS_OR_OPEN_QUESTIONS + items = self.items_question + prompt_definition = "" + + await self.build_items_type( + items, item_type, json_schema, prompt_definition + ) + + async def build_items_type( + self, + items: list, + item_type: ItemType, + json_schema: dict, + prompt_definition: str, + ): + m = Messages(model_name=self.model_name, logger=self.logger) + m.add_system( + "You are an advanced note-taking assistant." + f"You'll be given a transcript, and identify {item_type}." + + prompt_definition + ) + + if item_type in (ItemType.ACTION_ITEM, ItemType.DECISION): + # for both action_items and decision, break down per subject + for subject in self.subjects: + # find the summary of the subject + summary = "" + for entry in self.summaries: + if entry["subject"] == subject: + summary = entry["summary"] + break + + m2 = m.copy() + m2.add_user( + f"# Transcript\n\n{self.transcript}\n\n" + f"# Main subjects\n\n{self.format_list_md(self.subjects)}\n\n" + f"# Summary of {subject}\n\n{summary}\n\n" + "---\n\n" + f'What are the {item_type.value} only related to the main subject "{subject}"? ' + f"Make sure the {item_type.value} do not overlap with other subjects. " + "To recall: " + + prompt_definition + + "If there are none, just return an empty list. " + "The result must be a list following this format: " + f"\n```json-schema\n{json_schema}\n```" + ) + result = await self.llm( + m2, + [ + self.validate_json, + partial(self.validate_json_schema, json_schema), + ], + ) + if not result: + self.logger.error( + f"Error: unable to identify {item_type.value} for {subject}" + ) + continue + else: + items.extend(result) + + # and for action-items and decisions, we try to deduplicate + items = await self.deduplicate_items(item_type, items) + + elif item_type == ItemType.OPEN_QUESTION: + m2 = m.copy() + m2.add_user( + f"# Transcript\n\n{self.transcript}\n\n" + "---\n\n" + f"Identify the open questions left unanswered during the meeting." + "If there are none, just return an empty list. " + "The result must be a list following this format:" + f"\n```json-schema\n{json_schema}\n```" + ) + result = await self.llm( + m2, + [ + self.validate_json, + partial(self.validate_json_schema, json_schema), + ], + ) + if not result: + self.logger.error("Error: unable to identify open questions") + else: + items.extend(result) + + async def deduplicate_items(self, item_type: ItemType, items: list): + """ + Deduplicate items based on the transcript and the list of items gathered for all subjects + """ + m = Messages(model_name=self.model_name, logger=self.logger) + if item_type == ItemType.ACTION_ITEM: + json_schema = JSON_SCHEMA_ACTION_ITEMS + else: + json_schema = JSON_SCHEMA_DECISIONS_OR_OPEN_QUESTIONS + + title = item_type.value.replace("_", " ") + + m.add_system( + "You are an advanced assistant that correlates and consolidates information. " + f"Another agent found a list of {title}. However, the list may be redundant. " + f"Your first step will be to give information about how these {title} overlap. 
" + "In a second time, the user will ask you to consolidate according to your finding. " + f"Keep in mind that the same {title} can be written in different ways. " + ) + + md_items = [] + for item in items: + assigned_to = ", ".join(item.get("assigned_to", [])) + content = item["content"] + if assigned_to: + text = f"- **{assigned_to}**: {content}" + else: + text = f"- {content}" + md_items.append(text) + + md_text = "\n".join(md_items) + + m.add_user( + f"# Transcript\n\n{self.transcript}\n\n" + f"# {title}\n\n{md_text}\n\n--\n\n" + f"Here is a list of {title} identified by another agent. " + f"Some of the {title} seem to overlap or be redundant. " + "How can you effectively group or merge them into more consise list?" + ) + + await self.llm(m) + + m.add_user( + f"Consolidate the list of {title} according to your finding. " + f"The list must be shorter or equal than the original list. " + "Give the result using the following JSON schema:" + f"\n```json-schema\n{json_schema}\n```" + ) + + result = await self.llm( + m, + [ + self.validate_json, + partial(self.validate_json_schema, json_schema), + ], + ) + return result + + # ---------------------------------------------------------------------------- + # Summary + # ---------------------------------------------------------------------------- + + async def generate_summary(self, only_subjects=False): + """ + This is the main function to build the summary. + + It actually share the context with the different steps (subjects, quick recap) + which make it more sense to keep it in one function. + + The process is: + - Extract the main subjects + - Generate a summary for all the main subjects + - Generate a quick recap + """ + self.logger.debug("--- extract main subjects") + + m = Messages(model_name=self.model_name, logger=self.logger) + m.add_system( + ( + "You are an advanced transcription summarization assistant." + "Your task is to summarize discussions by focusing only on the main ideas contributed by participants." + # Prevent generating another transcription + "Exclude direct quotes and unnecessary details." + # Do not mention others participants just because they didn't contributed + "Only include participant names if they actively contribute to the subject." + # Prevent generation of summary with "no others participants contributed" etc + "Keep summaries concise and focused on main subjects without adding conclusions such as 'no other participant contributed'. " + # Avoid: In the discussion, they talked about... + "Do not include contextual preface. " + # Prevention to have too long summary + "Summary should fit in a single paragraph. " + # Using other pronouns that the participants or the group + 'Mention the participants or the group using "they".' + # Avoid finishing the summary with "No conclusions were added by the summarizer" + "Do not mention conclusion if there is no conclusion" + ) + ) + m.add_user( + f"# Transcript\n\n{self.transcript}\n\n" + + ( + "\n\n---\n\n" + "What are the main/key subjects discussed in this transcript ? " + "Do not include direct quotes or unnecessary details. " + "Be concise and focused on the main ideas. " + "A subject briefly mentionned should not be included. " + f"The result must follow the JSON schema: {JSON_SCHEMA_LIST_STRING}. " + ), + ) + + # Note: Asking the model the key subject sometimes generate a lot of subjects + # We need to consolidate them to avoid redundancy when it happen. 
+ m2 = m.copy() + + subjects = await self.llm( + m2, + [ + self.validate_json, + partial(self.validate_json_schema, JSON_SCHEMA_LIST_STRING), + ], + ) + if subjects: + self.subjects = subjects + + if len(self.subjects) > 6: + # the model may misbehave and generate a lot of subjects + m.add_user( + "No, that may be too much. " + "Consolidate the subjects and remove any redundancy. " + "Keep the most important ones. " + "Remember that the same subject can be written in different ways. " + "Do not consolidate subjects if they are worth keeping separate due to their importance or sensitivity. " + f"The result must follow the JSON schema: {JSON_SCHEMA_LIST_STRING}. " + ) + subjects = await self.llm( + m2, + [ + self.validate_json, + partial(self.validate_json_schema, JSON_SCHEMA_LIST_STRING), + ], + ) + if subjects: + self.subjects = subjects + + # Manually write the assistant's response to remove the redundancy in case something went wrong + m.add_assistant(self.format_list_md(self.subjects)) + + if only_subjects: + return + + summaries = [] + + # ---------------------------------------------------------------------------- + # Summarize per subject + # ---------------------------------------------------------------------------- + + m2 = m.copy() + for subject in subjects: + m2 = m # .copy() + prompt = ( + f"Summarize the main subject: '{subject}'. " + "Include only the main ideas contributed by participants. " + "Do not include direct quotes or unnecessary details. " + "Avoid introducing or restating the subject. " + "Focus on the core arguments without minor details. " + "Summarize in a few sentences. " + ) + m2.add_user(prompt) + + summary = await self.llm(m2) + summaries.append( + { + "subject": subject, + "summary": summary, + } + ) + + self.summaries = summaries + + # ---------------------------------------------------------------------------- + # Quick recap + # ---------------------------------------------------------------------------- + + m3 = m # .copy() + m3.add_user( + "Provide a quick recap of the meeting, that fits into a small to medium paragraph. 
+ ) + recap = await self.llm(m3) + self.recap = recap + + # ---------------------------------------------------------------------------- + # Markdown + # ---------------------------------------------------------------------------- + + def as_markdown(self): + lines = [] + if self.recap: + lines.append("# Quick recap") + lines.append("") + lines.append(self.recap) + lines.append("") + + if self.items_action: + lines.append("# Actions") + lines.append("") + for action in self.items_action: + assigned_to = ", ".join(action.get("assigned_to", [])) + content = action.get("content", "") + line = "-" + if assigned_to: + line += f" **{assigned_to}**:" + line += f" {content}" + lines.append(line) + lines.append("") + + if self.items_decision: + lines.append("") + lines.append("# Decisions") + for decision in self.items_decision: + content = decision.get("content", "") + lines.append(f"- {content}") + lines.append("") + + if self.items_question: + lines.append("") + lines.append("# Open questions") + for question in self.items_question: + content = question.get("content", "") + lines.append(f"- {content}") + lines.append("") + + if self.summaries: + lines.append("# Summary") + lines.append("") + for summary in self.summaries: + lines.append(f"**{summary['subject']}**") + lines.append(summary["summary"]) + lines.append("") + lines.append("") + + return "\n".join(lines) + + # ---------------------------------------------------------------------------- + # Validation API + # ---------------------------------------------------------------------------- + + def validate_list(self, result: str): + # does the list match 1. xxx\n2. xxx... ? + lines = result.split("\n") + firstline = lines[0].strip() + + if re.match(r"1\.\s.+", firstline): + # strip the numbers of the list + lines = [re.sub(r"^\d+\.\s", "", line).strip() for line in lines] + return lines + + if re.match(r"- ", firstline): + # strip the list markers + lines = [re.sub(r"^- ", "", line).strip() for line in lines] + return lines + + return result.split("\n") + + def validate_next_steps(self, result: str): + if result.lower().startswith("no"): + return None + + return result + + def validate_json(self, result): + # if result startswith ```json, strip begin/end + result = result.strip() + + # grab the json between ```json and ``` using regex if exist + match = re.search(r"```json(.*?)```", result, re.DOTALL) + if match: + result = match.group(1).strip() + + # try parsing json + return json.loads(result) + + def validate_json_schema(self, schema, result): + try: + jsonschema.validate(instance=result, schema=schema) + except Exception as e: + self.logger.exception(e) + raise + return result + + # ---------------------------------------------------------------------------- + # LLM API + # ---------------------------------------------------------------------------- + + async def llm( + self, + messages: Messages, + validate_func=None, + auto_append=True, + max_retries=3, + ): + """ + Perform a completion using the LLM model. + Automatically validate the result and retry maximum `max_retries` times if an error occurs. + Append the result to the message context if `auto_append` is True. 
+ """ + + self.logger.debug( + f"--- messages ({len(messages.messages)} messages, " + f"{messages.count_tokens()} tokens)" + ) + + if validate_func and not isinstance(validate_func, list): + validate_func = [validate_func] + + while max_retries > 0: + try: + # do the llm completion + result = result_validated = await self.completion( + messages.messages, + logger=self.logger, + ) + self.logger.debug(f"--- result\n{result_validated}") + + # validate the result using the provided functions + if validate_func: + for func in validate_func: + result_validated = func(result_validated) + + self.logger.debug(f"--- validated\n{result_validated}") + + # add the result to the message context as an assistant response + # only if the response was not guided + if auto_append: + messages.add_assistant(result) + return result_validated + except Exception as e: + self.logger.error(f"Error: {e}") + max_retries -= 1 + + async def completion(self, messages: list, **kwargs) -> str: + """ + Complete the messages using the LLM model. + The request assume a /v1/chat/completions compatible endpoint. + `messages` are a list of dict with `role` and `content` keys. + """ + + result = await self.llm_instance.completion(messages=messages, **kwargs) + return result["choices"][0]["message"]["content"] + + def format_list_md(self, data: list): + return "\n".join([f"- {item}" for item in data]) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Generate a summary of a meeting transcript" + ) + + parser.add_argument( + "transcript", + type=str, + nargs="?", + help="The transcript of the meeting", + default="transcript.txt", + ) + + parser.add_argument( + "--transcription-type", + action="store_true", + help="Identify the type of the transcript (meeting, interview, podcast...)", + ) + + parser.add_argument( + "--save", + action="store_true", + help="Save the summary to a file", + ) + + parser.add_argument( + "--summary", + action="store_true", + help="Generate a summary", + ) + + parser.add_argument( + "--items", + help="Generate a list of action items", + action="store_true", + ) + + parser.add_argument( + "--subjects", + help="Generate a list of subjects", + action="store_true", + ) + + parser.add_argument( + "--participants", + help="Generate a list of participants", + action="store_true", + ) + + args = parser.parse_args() + + async def main(): + # build the summary + llm = LLM.get_instance(model_name="NousResearch/Hermes-3-Llama-3.1-8B") + sm = SummaryBuilder(llm=llm, filename=args.transcript) + + if args.subjects: + await sm.generate_summary(only_subjects=True) + print("# Subjects\n") + print("\n".join(sm.subjects)) + sys.exit(0) + + if args.transcription_type: + await sm.identify_transcription_type() + print(sm.transcription_type) + sys.exit(0) + + if args.participants: + await sm.identify_participants() + sys.exit(0) + + # if no summary or items is asked, ask for everything + if not args.summary and not args.items and not args.subjects: + args.summary = True + args.items = True + + await sm.identify_participants() + await sm.identify_transcription_type() + + if args.summary: + await sm.generate_summary() + + if sm.transcription_type == TranscriptionType.MEETING: + if args.items: + await sm.generate_items( + search_action=True, + search_decision=True, + search_open_question=True, + ) + + print("") + print("-" * 80) + print("") + print(sm.as_markdown()) + + if args.save: + # write the summary to a file, on the format summary-.md + filename = 
f"summary-{datetime.now().isoformat()}.md" + with open(filename, "w", encoding="utf-8") as f: + f.write(sm.as_markdown()) + + print("") + print("-" * 80) + print("") + print("Saved to", filename) + + asyncio.run(main()) diff --git a/server/reflector/processors/transcript_final_long_summary.py b/server/reflector/processors/transcript_final_long_summary.py deleted file mode 100644 index 57e36636..00000000 --- a/server/reflector/processors/transcript_final_long_summary.py +++ /dev/null @@ -1,96 +0,0 @@ -import nltk -from reflector.llm import LLM, LLMTaskParams -from reflector.processors.base import Processor -from reflector.processors.types import FinalLongSummary, TitleSummary - - -class TranscriptFinalLongSummaryProcessor(Processor): - """ - Get the final long summary - """ - - INPUT_TYPE = TitleSummary - OUTPUT_TYPE = FinalLongSummary - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.chunks: list[TitleSummary] = [] - self.llm = LLM.get_instance(model_name="HuggingFaceH4/zephyr-7b-alpha") - - async def _push(self, data: TitleSummary): - self.chunks.append(data) - - async def get_bullet_summary(self, text: str) -> str: - params = LLMTaskParams.get_instance("bullet_summary").task_params - chunks = list(self.llm.split_corpus(corpus=text, task_params=params)) - - bullet_summary = "" - for chunk in chunks: - prompt = self.llm.create_prompt(instruct=params.instruct, text=chunk) - summary_result = await self.llm.generate( - prompt=prompt, - gen_schema=params.gen_schema, - gen_cfg=params.gen_cfg, - logger=self.logger, - ) - bullet_summary += summary_result["long_summary"] - return bullet_summary - - async def get_merged_summary(self, text: str) -> str: - params = LLMTaskParams.get_instance("merged_summary").task_params - chunks = list(self.llm.split_corpus(corpus=text, task_params=params)) - - merged_summary = "" - for chunk in chunks: - prompt = self.llm.create_prompt(instruct=params.instruct, text=chunk) - summary_result = await self.llm.generate( - prompt=prompt, - gen_schema=params.gen_schema, - gen_cfg=params.gen_cfg, - logger=self.logger, - ) - merged_summary += summary_result["long_summary"] - return merged_summary - - async def get_long_summary(self, text: str) -> str: - """ - Generate a long version of the final summary - """ - bullet_summary = await self.get_bullet_summary(text) - merged_summary = await self.get_merged_summary(bullet_summary) - - return merged_summary - - def sentence_tokenize(self, text: str) -> [str]: - return nltk.sent_tokenize(text) - - async def _flush(self): - if not self.chunks: - self.logger.warning("No summary to output") - return - - accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks]) - long_summary = await self.get_long_summary(accumulated_summaries) - - # Format the output as much as possible to be handled - # by front-end for displaying - summary_sentences = [] - for sentence in self.sentence_tokenize(long_summary): - sentence = str(sentence).strip() - if sentence.startswith("- "): - sentence.replace("- ", "* ") - elif not sentence.startswith("*"): - sentence = "* " + sentence - sentence += " \n" - summary_sentences.append(sentence) - - formatted_long_summary = "".join(summary_sentences) - - last_chunk = self.chunks[-1] - duration = last_chunk.timestamp + last_chunk.duration - - final_long_summary = FinalLongSummary( - long_summary=formatted_long_summary, - duration=duration, - ) - await self.emit(final_long_summary) diff --git a/server/reflector/processors/transcript_final_short_summary.py 
diff --git a/server/reflector/processors/transcript_final_long_summary.py b/server/reflector/processors/transcript_final_long_summary.py
deleted file mode 100644
index 57e36636..00000000
--- a/server/reflector/processors/transcript_final_long_summary.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import nltk
-from reflector.llm import LLM, LLMTaskParams
-from reflector.processors.base import Processor
-from reflector.processors.types import FinalLongSummary, TitleSummary
-
-
-class TranscriptFinalLongSummaryProcessor(Processor):
-    """
-    Get the final long summary
-    """
-
-    INPUT_TYPE = TitleSummary
-    OUTPUT_TYPE = FinalLongSummary
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.chunks: list[TitleSummary] = []
-        self.llm = LLM.get_instance(model_name="HuggingFaceH4/zephyr-7b-alpha")
-
-    async def _push(self, data: TitleSummary):
-        self.chunks.append(data)
-
-    async def get_bullet_summary(self, text: str) -> str:
-        params = LLMTaskParams.get_instance("bullet_summary").task_params
-        chunks = list(self.llm.split_corpus(corpus=text, task_params=params))
-
-        bullet_summary = ""
-        for chunk in chunks:
-            prompt = self.llm.create_prompt(instruct=params.instruct, text=chunk)
-            summary_result = await self.llm.generate(
-                prompt=prompt,
-                gen_schema=params.gen_schema,
-                gen_cfg=params.gen_cfg,
-                logger=self.logger,
-            )
-            bullet_summary += summary_result["long_summary"]
-        return bullet_summary
-
-    async def get_merged_summary(self, text: str) -> str:
-        params = LLMTaskParams.get_instance("merged_summary").task_params
-        chunks = list(self.llm.split_corpus(corpus=text, task_params=params))
-
-        merged_summary = ""
-        for chunk in chunks:
-            prompt = self.llm.create_prompt(instruct=params.instruct, text=chunk)
-            summary_result = await self.llm.generate(
-                prompt=prompt,
-                gen_schema=params.gen_schema,
-                gen_cfg=params.gen_cfg,
-                logger=self.logger,
-            )
-            merged_summary += summary_result["long_summary"]
-        return merged_summary
-
-    async def get_long_summary(self, text: str) -> str:
-        """
-        Generate a long version of the final summary
-        """
-        bullet_summary = await self.get_bullet_summary(text)
-        merged_summary = await self.get_merged_summary(bullet_summary)
-
-        return merged_summary
-
-    def sentence_tokenize(self, text: str) -> [str]:
-        return nltk.sent_tokenize(text)
-
-    async def _flush(self):
-        if not self.chunks:
-            self.logger.warning("No summary to output")
-            return
-
-        accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks])
-        long_summary = await self.get_long_summary(accumulated_summaries)
-
-        # Format the output as much as possible to be handled
-        # by front-end for displaying
-        summary_sentences = []
-        for sentence in self.sentence_tokenize(long_summary):
-            sentence = str(sentence).strip()
-            if sentence.startswith("- "):
-                sentence.replace("- ", "* ")
-            elif not sentence.startswith("*"):
-                sentence = "* " + sentence
-            sentence += " \n"
-            summary_sentences.append(sentence)
-
-        formatted_long_summary = "".join(summary_sentences)
-
-        last_chunk = self.chunks[-1]
-        duration = last_chunk.timestamp + last_chunk.duration
-
-        final_long_summary = FinalLongSummary(
-            long_summary=formatted_long_summary,
-            duration=duration,
-        )
-        await self.emit(final_long_summary)
diff --git a/server/reflector/processors/transcript_final_short_summary.py b/server/reflector/processors/transcript_final_short_summary.py
deleted file mode 100644
index fe25ebc0..00000000
--- a/server/reflector/processors/transcript_final_short_summary.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from reflector.llm import LLM, LLMTaskParams
-from reflector.processors.base import Processor
-from reflector.processors.types import FinalShortSummary, TitleSummary
-
-
-class TranscriptFinalShortSummaryProcessor(Processor):
-    """
-    Get the final summary using a tree summarizer
-    """
-
-    INPUT_TYPE = TitleSummary
-    OUTPUT_TYPE = FinalShortSummary
-    TASK = "final_short_summary"
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.chunks: list[TitleSummary] = []
-        self.llm = LLM.get_instance()
-        self.params = LLMTaskParams.get_instance(self.TASK).task_params
-
-    async def _push(self, data: TitleSummary):
-        self.chunks.append(data)
-
-    async def get_short_summary(self, text: str) -> dict:
-        """
-        Generata a short summary using tree summarizer
-        """
-        self.logger.info(f"Smoothing out {len(text)} length summary to a short summary")
-        chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
-
-        if len(chunks) == 1:
-            chunk = chunks[0]
-            prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
-            summary_result = await self.llm.generate(
-                prompt=prompt,
-                gen_schema=self.params.gen_schema,
-                gen_cfg=self.params.gen_cfg,
-                logger=self.logger,
-            )
-            return summary_result
-        else:
-            accumulated_summaries = ""
-            for chunk in chunks:
-                prompt = self.llm.create_prompt(
-                    instruct=self.params.instruct, text=chunk
-                )
-                summary_result = await self.llm.generate(
-                    prompt=prompt,
-                    gen_schema=self.params.gen_schema,
-                    gen_cfg=self.params.gen_cfg,
-                    logger=self.logger,
-                )
-                accumulated_summaries += summary_result["short_summary"]
-
-            return await self.get_short_summary(accumulated_summaries)
-
-    async def _flush(self):
-        if not self.chunks:
-            self.logger.warning("No summary to output")
-            return
-
-        accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks])
-        short_summary_result = await self.get_short_summary(accumulated_summaries)
-
-        last_chunk = self.chunks[-1]
-        duration = last_chunk.timestamp + last_chunk.duration
-
-        final_summary = FinalShortSummary(
-            short_summary=short_summary_result["short_summary"],
-            duration=duration,
-        )
-        await self.emit(final_summary)
diff --git a/server/reflector/processors/transcript_final_summary.py b/server/reflector/processors/transcript_final_summary.py
new file mode 100644
index 00000000..daa52e56
--- /dev/null
+++ b/server/reflector/processors/transcript_final_summary.py
@@ -0,0 +1,83 @@
+from reflector.llm import LLM
+from reflector.processors.base import Processor
+from reflector.processors.summary.summary_builder import SummaryBuilder
+from reflector.processors.types import FinalLongSummary, FinalShortSummary, TitleSummary
+
+
+class TranscriptFinalSummaryProcessor(Processor):
+    """
+    Get the final (long and short) summary
+    """
+
+    INPUT_TYPE = TitleSummary
+    OUTPUT_TYPE = FinalLongSummary
+
+    def __init__(self, transcript=None, **kwargs):
+        super().__init__(**kwargs)
+        self.transcript = transcript
+        self.chunks: list[TitleSummary] = []
+        self.llm = LLM.get_instance(model_name="NousResearch/Hermes-3-Llama-3.1-8B")
+        self.builder = None
+
+    async def _push(self, data: TitleSummary):
+        self.chunks.append(data)
+
+    async def get_summary_builder(self, text) -> SummaryBuilder:
+        builder = SummaryBuilder(self.llm)
+        builder.set_transcript(text)
+        await builder.identify_participants()
+        await builder.generate_summary()
+        return builder
+
+    async def get_long_summary(self, text) -> str:
+        if not self.builder:
+            self.builder = await self.get_summary_builder(text)
+        return self.builder.as_markdown()
+
+    async def get_short_summary(self, text) -> str | None:
+        if not self.builder:
+            self.builder = await self.get_summary_builder(text)
+        return self.builder.recap
+
+    async def _flush(self):
+        if not self.chunks:
+            self.logger.warning("No summary to output")
+            return
+
+        # build the speakermap from the transcript
+        speakermap = {}
+        if self.transcript:
+            speakermap = {
+                participant["speaker"]: participant["name"]
+                for participant in self.transcript.participants
+            }
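+            # Illustrative: assuming transcript.participants looks like
+            #   [{"speaker": 0, "name": "Alice"}, {"speaker": 1, "name": "Bob"}]
+            # speakermap ends up as {0: "Alice", 1: "Bob"}.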
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_short_summary" ) as mock_short_summary, patch( "reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation" ) as mock_translate: mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"} mock_title.return_value = {"title": "LLM TITLE"} mock_long_summary.return_value = "LLM LONG SUMMARY" - mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"} + mock_short_summary.return_value = "LLM SHORT SUMMARY" mock_translate.return_value = "Bonjour le monde" yield ( mock_translate, @@ -142,15 +142,6 @@ def ensure_casing(): yield -@pytest.fixture -def sentence_tokenize(): - with patch( - "reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize" - ) as mock_sent_tokenize: - mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"] - yield - - @pytest.fixture(scope="session") def celery_enable_logging(): return True diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py index 79825b26..44efafc0 100644 --- a/server/tests/test_processors_pipeline.py +++ b/server/tests/test_processors_pipeline.py @@ -9,7 +9,6 @@ async def test_basic_process( dummy_llm, dummy_processors, ensure_casing, - sentence_tokenize, ): # goal is to start the server, and send rtc audio to it # validate the events received @@ -38,6 +37,5 @@ async def test_basic_process( assert marks["TranscriptLinerProcessor"] == 4 assert marks["TranscriptTranslatorProcessor"] == 4 assert marks["TranscriptTopicDetectorProcessor"] == 1 - assert marks["TranscriptFinalLongSummaryProcessor"] == 1 - assert marks["TranscriptFinalShortSummaryProcessor"] == 1 + assert marks["TranscriptFinalSummaryProcessor"] == 1 assert marks["TranscriptFinalTitleProcessor"] == 1 diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index e95839f0..04d12563 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -72,7 +72,6 @@ async def test_transcript_rtc_and_websocket( ensure_casing, nltk, appserver, - sentence_tokenize, ): # goal: start the server, exchange RTC, receive websocket events # because of that, we need to start the server in a thread @@ -176,7 +175,7 @@ async def test_transcript_rtc_and_websocket( assert "FINAL_LONG_SUMMARY" in eventnames ev = events[eventnames.index("FINAL_LONG_SUMMARY")] - assert ev["data"]["long_summary"] == "* LLM LONG SUMMARY \n" + assert ev["data"]["long_summary"] == "LLM LONG SUMMARY" assert "FINAL_SHORT_SUMMARY" in eventnames ev = events[eventnames.index("FINAL_SHORT_SUMMARY")] @@ -230,7 +229,6 @@ async def test_transcript_rtc_and_websocket_and_fr( ensure_casing, nltk, appserver, - sentence_tokenize, ): # goal: start the server, exchange RTC, receive websocket events # because of that, we need to start the server in a thread @@ -343,7 +341,7 @@ async def test_transcript_rtc_and_websocket_and_fr( assert "FINAL_LONG_SUMMARY" in eventnames ev = events[eventnames.index("FINAL_LONG_SUMMARY")] - assert ev["data"]["long_summary"] == "* LLM LONG SUMMARY \n" + assert ev["data"]["long_summary"] == "LLM LONG SUMMARY" assert "FINAL_SHORT_SUMMARY" in eventnames ev = events[eventnames.index("FINAL_SHORT_SUMMARY")] diff --git a/www/app/(app)/transcripts/[transcriptId]/finalSummary.tsx b/www/app/(app)/transcripts/[transcriptId]/finalSummary.tsx index 08c20feb..ec552a15 100644 --- a/www/app/(app)/transcripts/[transcriptId]/finalSummary.tsx +++ 
+++ b/www/app/(app)/transcripts/[transcriptId]/finalSummary.tsx
@@ -97,12 +97,19 @@ export default function FinalSummary(props: FinalSummaryProps) {
       h={"100%"}
       overflowY={isEditMode ? "hidden" : "auto"}
       pb={4}
+      position="relative"
     >
-
-      Summary
-
+
       {isEditMode && (
         <>
+          Summary