mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 267b7401ea | |||
| aea9de393c | |||
| dc177af3ff | |||
| 5bd8233657 | |||
| 28ac031ff6 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -13,3 +13,4 @@ restart-dev.sh
|
|||||||
data/
|
data/
|
||||||
www/REFACTOR.md
|
www/REFACTOR.md
|
||||||
www/reload-frontend
|
www/reload-frontend
|
||||||
|
server/test.sqlite
|
||||||
|
|||||||
20
CHANGELOG.md
20
CHANGELOG.md
@@ -1,5 +1,25 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## [0.6.0](https://github.com/Monadical-SAS/reflector/compare/v0.5.0...v0.6.0) (2025-08-05)
|
||||||
|
|
||||||
|
|
||||||
|
### ⚠ BREAKING CHANGES
|
||||||
|
|
||||||
|
* Configuration keys have changed. Update your .env file:
|
||||||
|
- TRANSCRIPT_MODAL_API_KEY → TRANSCRIPT_API_KEY
|
||||||
|
- LLM_MODAL_API_KEY → (removed, use TRANSCRIPT_API_KEY)
|
||||||
|
- Add DIARIZATION_API_KEY and TRANSLATE_API_KEY if using those services
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* implement service-specific Modal API keys with auto processor pattern ([#528](https://github.com/Monadical-SAS/reflector/issues/528)) ([650befb](https://github.com/Monadical-SAS/reflector/commit/650befb291c47a1f49e94a01ab37d8fdfcd2b65d))
|
||||||
|
* use llamaindex everywhere ([#525](https://github.com/Monadical-SAS/reflector/issues/525)) ([3141d17](https://github.com/Monadical-SAS/reflector/commit/3141d172bc4d3b3d533370c8e6e351ea762169bf))
|
||||||
|
|
||||||
|
|
||||||
|
### Miscellaneous Chores
|
||||||
|
|
||||||
|
* **main:** release 0.6.0 ([ecdbf00](https://github.com/Monadical-SAS/reflector/commit/ecdbf003ea2476c3e95fd231adaeb852f2943df0))
|
||||||
|
|
||||||
## [0.5.0](https://github.com/Monadical-SAS/reflector/compare/v0.4.0...v0.5.0) (2025-07-31)
|
## [0.5.0](https://github.com/Monadical-SAS/reflector/compare/v0.4.0...v0.5.0) (2025-07-31)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -144,7 +144,9 @@ All endpoints prefixed `/v1/`:
|
|||||||
**Backend** (`server/.env`):
|
**Backend** (`server/.env`):
|
||||||
- `DATABASE_URL` - Database connection string
|
- `DATABASE_URL` - Database connection string
|
||||||
- `REDIS_URL` - Redis broker for Celery
|
- `REDIS_URL` - Redis broker for Celery
|
||||||
- `MODAL_TOKEN_ID`, `MODAL_TOKEN_SECRET` - Modal.com GPU processing
|
- `TRANSCRIPT_BACKEND=modal` + `TRANSCRIPT_MODAL_API_KEY` - Modal.com transcription
|
||||||
|
- `DIARIZATION_BACKEND=modal` + `DIARIZATION_MODAL_API_KEY` - Modal.com diarization
|
||||||
|
- `TRANSLATION_BACKEND=modal` + `TRANSLATION_MODAL_API_KEY` - Modal.com translation
|
||||||
- `WHEREBY_API_KEY` - Video platform integration
|
- `WHEREBY_API_KEY` - Video platform integration
|
||||||
- `REFLECTOR_AUTH_BACKEND` - Authentication method (none, jwt)
|
- `REFLECTOR_AUTH_BACKEND` - Authentication method (none, jwt)
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ AUTH_JWT_AUDIENCE=
|
|||||||
## Using serverless modal.com (require reflector-gpu-modal deployed)
|
## Using serverless modal.com (require reflector-gpu-modal deployed)
|
||||||
#TRANSCRIPT_BACKEND=modal
|
#TRANSCRIPT_BACKEND=modal
|
||||||
#TRANSCRIPT_URL=https://xxxxx--reflector-transcriber-web.modal.run
|
#TRANSCRIPT_URL=https://xxxxx--reflector-transcriber-web.modal.run
|
||||||
#TRANSLATE_URL=https://xxxxx--reflector-translator-web.modal.run
|
|
||||||
#TRANSCRIPT_MODAL_API_KEY=xxxxx
|
#TRANSCRIPT_MODAL_API_KEY=xxxxx
|
||||||
|
|
||||||
TRANSCRIPT_BACKEND=modal
|
TRANSCRIPT_BACKEND=modal
|
||||||
@@ -32,11 +31,13 @@ TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
|
|||||||
TRANSCRIPT_MODAL_API_KEY=
|
TRANSCRIPT_MODAL_API_KEY=
|
||||||
|
|
||||||
## =======================================================
|
## =======================================================
|
||||||
## Transcription backend
|
## Translation backend
|
||||||
##
|
##
|
||||||
## Only available in modal atm
|
## Only available in modal atm
|
||||||
## =======================================================
|
## =======================================================
|
||||||
|
TRANSLATION_BACKEND=modal
|
||||||
TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
|
TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
|
||||||
|
#TRANSLATION_MODAL_API_KEY=xxxxx
|
||||||
|
|
||||||
## =======================================================
|
## =======================================================
|
||||||
## LLM backend
|
## LLM backend
|
||||||
@@ -46,38 +47,11 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
|
|||||||
## llm backend implementation
|
## llm backend implementation
|
||||||
## =======================================================
|
## =======================================================
|
||||||
|
|
||||||
## Using serverless modal.com (require reflector-gpu-modal deployed)
|
|
||||||
LLM_BACKEND=modal
|
|
||||||
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
|
|
||||||
LLM_MODAL_API_KEY=
|
|
||||||
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
|
|
||||||
|
|
||||||
|
|
||||||
## Using OpenAI
|
|
||||||
#LLM_BACKEND=openai
|
|
||||||
#LLM_OPENAI_KEY=xxx
|
|
||||||
#LLM_OPENAI_MODEL=gpt-3.5-turbo
|
|
||||||
|
|
||||||
## Using GPT4ALL
|
|
||||||
#LLM_BACKEND=openai
|
|
||||||
#LLM_URL=http://localhost:4891/v1/completions
|
|
||||||
#LLM_OPENAI_MODEL="GPT4All Falcon"
|
|
||||||
|
|
||||||
## Default LLM MODEL NAME
|
|
||||||
#DEFAULT_LLM=lmsys/vicuna-13b-v1.5
|
|
||||||
|
|
||||||
## Cache directory to store models
|
|
||||||
CACHE_DIR=data
|
|
||||||
|
|
||||||
## =======================================================
|
|
||||||
## Summary LLM configuration
|
|
||||||
## =======================================================
|
|
||||||
|
|
||||||
## Context size for summary generation (tokens)
|
## Context size for summary generation (tokens)
|
||||||
SUMMARY_LLM_CONTEXT_SIZE_TOKENS=16000
|
# LLM_MODEL=microsoft/phi-4
|
||||||
SUMMARY_LLM_URL=
|
LLM_CONTEXT_WINDOW=16000
|
||||||
SUMMARY_LLM_API_KEY=sk-
|
LLM_URL=
|
||||||
SUMMARY_MODEL=
|
LLM_API_KEY=sk-
|
||||||
|
|
||||||
## =======================================================
|
## =======================================================
|
||||||
## Diarization
|
## Diarization
|
||||||
@@ -86,7 +60,9 @@ SUMMARY_MODEL=
|
|||||||
## To allow diarization, you need to expose expose the files to be dowloded by the pipeline
|
## To allow diarization, you need to expose expose the files to be dowloded by the pipeline
|
||||||
## =======================================================
|
## =======================================================
|
||||||
DIARIZATION_ENABLED=false
|
DIARIZATION_ENABLED=false
|
||||||
|
DIARIZATION_BACKEND=modal
|
||||||
DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
|
DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
|
||||||
|
#DIARIZATION_MODAL_API_KEY=xxxxx
|
||||||
|
|
||||||
|
|
||||||
## =======================================================
|
## =======================================================
|
||||||
|
|||||||
@@ -3,8 +3,9 @@
|
|||||||
This repository hold an API for the GPU implementation of the Reflector API service,
|
This repository hold an API for the GPU implementation of the Reflector API service,
|
||||||
and use [Modal.com](https://modal.com)
|
and use [Modal.com](https://modal.com)
|
||||||
|
|
||||||
- `reflector_llm.py` - LLM API
|
- `reflector_diarizer.py` - Diarization API
|
||||||
- `reflector_transcriber.py` - Transcription API
|
- `reflector_transcriber.py` - Transcription API
|
||||||
|
- `reflector_translator.py` - Translation API
|
||||||
|
|
||||||
## Modal.com deployment
|
## Modal.com deployment
|
||||||
|
|
||||||
@@ -23,16 +24,20 @@ $ modal deploy reflector_llm.py
|
|||||||
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
|
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
|
||||||
```
|
```
|
||||||
|
|
||||||
Then in your reflector api configuration `.env`, you can set theses keys:
|
Then in your reflector api configuration `.env`, you can set these keys:
|
||||||
|
|
||||||
```
|
```
|
||||||
TRANSCRIPT_BACKEND=modal
|
TRANSCRIPT_BACKEND=modal
|
||||||
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
|
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
|
||||||
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY
|
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||||
|
|
||||||
LLM_BACKEND=modal
|
DIARIZATION_BACKEND=modal
|
||||||
LLM_URL=https://xxxx--reflector-llm-web.modal.run
|
DIARIZATION_URL=https://xxxx--reflector-diarizer-web.modal.run
|
||||||
LLM_MODAL_API_KEY=REFLECTOR_APIKEY
|
DIARIZATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||||
|
|
||||||
|
TRANSLATION_BACKEND=modal
|
||||||
|
TRANSLATION_URL=https://xxxx--reflector-translator-web.modal.run
|
||||||
|
TRANSLATION_MODAL_API_KEY=REFLECTOR_APIKEY
|
||||||
```
|
```
|
||||||
|
|
||||||
## API
|
## API
|
||||||
|
|||||||
@@ -1,213 +0,0 @@
|
|||||||
"""
|
|
||||||
Reflector GPU backend - LLM
|
|
||||||
===========================
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from modal import App, Image, Secret, asgi_app, enter, exit, method
|
|
||||||
|
|
||||||
# LLM
|
|
||||||
LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
|
|
||||||
LLM_LOW_CPU_MEM_USAGE: bool = True
|
|
||||||
LLM_TORCH_DTYPE: str = "bfloat16"
|
|
||||||
LLM_MAX_NEW_TOKENS: int = 300
|
|
||||||
|
|
||||||
IMAGE_MODEL_DIR = "/root/llm_models"
|
|
||||||
|
|
||||||
app = App(name="reflector-llm")
|
|
||||||
|
|
||||||
|
|
||||||
def download_llm():
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
print("Downloading LLM model")
|
|
||||||
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
|
|
||||||
print("LLM model downloaded")
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_cache_llm():
|
|
||||||
"""
|
|
||||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
|
||||||
Migrating your old cache. This is a one-time only operation. You can
|
|
||||||
interrupt this and resume the migration later on by calling
|
|
||||||
`transformers.utils.move_cache()`.
|
|
||||||
"""
|
|
||||||
from transformers.utils.hub import move_cache
|
|
||||||
|
|
||||||
print("Moving LLM cache")
|
|
||||||
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
|
|
||||||
print("LLM cache moved")
|
|
||||||
|
|
||||||
|
|
||||||
llm_image = (
|
|
||||||
Image.debian_slim(python_version="3.10.8")
|
|
||||||
.apt_install("git")
|
|
||||||
.pip_install(
|
|
||||||
"transformers",
|
|
||||||
"torch",
|
|
||||||
"sentencepiece",
|
|
||||||
"protobuf",
|
|
||||||
"jsonformer==0.12.0",
|
|
||||||
"accelerate==0.21.0",
|
|
||||||
"einops==0.6.1",
|
|
||||||
"hf-transfer~=0.1",
|
|
||||||
"huggingface_hub==0.16.4",
|
|
||||||
)
|
|
||||||
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
|
|
||||||
.run_function(download_llm)
|
|
||||||
.run_function(migrate_cache_llm)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.cls(
|
|
||||||
gpu="A100",
|
|
||||||
timeout=60 * 5,
|
|
||||||
scaledown_window=60 * 5,
|
|
||||||
allow_concurrent_inputs=15,
|
|
||||||
image=llm_image,
|
|
||||||
)
|
|
||||||
class LLM:
|
|
||||||
@enter()
|
|
||||||
def enter(self):
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
||||||
|
|
||||||
print("Instance llm model")
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
LLM_MODEL,
|
|
||||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
|
||||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
|
||||||
cache_dir=IMAGE_MODEL_DIR,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# JSONFormer doesn't yet support generation configs
|
|
||||||
print("Instance llm generation config")
|
|
||||||
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
|
|
||||||
|
|
||||||
# generation configuration
|
|
||||||
gen_cfg = GenerationConfig.from_model_config(model.config)
|
|
||||||
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
|
|
||||||
|
|
||||||
# load tokenizer
|
|
||||||
print("Instance llm tokenizer")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# move model to gpu
|
|
||||||
print("Move llm model to GPU")
|
|
||||||
model = model.cuda()
|
|
||||||
|
|
||||||
print("Warmup llm done")
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.gen_cfg = gen_cfg
|
|
||||||
self.GenerationConfig = GenerationConfig
|
|
||||||
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
|
|
||||||
@exit()
|
|
||||||
def exit():
|
|
||||||
print("Exit llm")
|
|
||||||
|
|
||||||
@method()
|
|
||||||
def generate(
|
|
||||||
self, prompt: str, gen_schema: str | None, gen_cfg: str | None
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Perform a generation action using the LLM
|
|
||||||
"""
|
|
||||||
print(f"Generate {prompt=}")
|
|
||||||
if gen_cfg:
|
|
||||||
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
|
|
||||||
else:
|
|
||||||
gen_cfg = self.gen_cfg
|
|
||||||
|
|
||||||
# If a gen_schema is given, conform to gen_schema
|
|
||||||
with self.lock:
|
|
||||||
if gen_schema:
|
|
||||||
import jsonformer
|
|
||||||
|
|
||||||
print(f"Schema {gen_schema=}")
|
|
||||||
jsonformer_llm = jsonformer.Jsonformer(
|
|
||||||
model=self.model,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
json_schema=json.loads(gen_schema),
|
|
||||||
prompt=prompt,
|
|
||||||
max_string_token_length=gen_cfg.max_new_tokens,
|
|
||||||
)
|
|
||||||
response = jsonformer_llm()
|
|
||||||
else:
|
|
||||||
# If no gen_schema, perform prompt only generation
|
|
||||||
|
|
||||||
# tokenize prompt
|
|
||||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
|
||||||
self.model.device
|
|
||||||
)
|
|
||||||
output = self.model.generate(input_ids, generation_config=gen_cfg)
|
|
||||||
|
|
||||||
# decode output
|
|
||||||
response = self.tokenizer.decode(
|
|
||||||
output[0].cpu(), skip_special_tokens=True
|
|
||||||
)
|
|
||||||
response = response[len(prompt) :]
|
|
||||||
print(f"Generated {response=}")
|
|
||||||
return {"text": response}
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------
|
|
||||||
# Web API
|
|
||||||
# -------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@app.function(
|
|
||||||
scaledown_window=60 * 10,
|
|
||||||
timeout=60 * 5,
|
|
||||||
allow_concurrent_inputs=45,
|
|
||||||
secrets=[
|
|
||||||
Secret.from_name("reflector-gpu"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@asgi_app()
|
|
||||||
def web():
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
llmstub = LLM()
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
|
||||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid API key",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
class LLMRequest(BaseModel):
|
|
||||||
prompt: str
|
|
||||||
gen_schema: Optional[dict] = None
|
|
||||||
gen_cfg: Optional[dict] = None
|
|
||||||
|
|
||||||
@app.post("/llm", dependencies=[Depends(apikey_auth)])
|
|
||||||
def llm(
|
|
||||||
req: LLMRequest,
|
|
||||||
):
|
|
||||||
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
|
|
||||||
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
|
|
||||||
func = llmstub.generate.spawn(
|
|
||||||
prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
|
|
||||||
)
|
|
||||||
result = func.get()
|
|
||||||
return result
|
|
||||||
|
|
||||||
return app
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
"""
|
|
||||||
Reflector GPU backend - LLM
|
|
||||||
===========================
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from modal import App, Image, Secret, asgi_app, enter, exit, method
|
|
||||||
|
|
||||||
# LLM
|
|
||||||
LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-alpha"
|
|
||||||
LLM_LOW_CPU_MEM_USAGE: bool = True
|
|
||||||
LLM_TORCH_DTYPE: str = "bfloat16"
|
|
||||||
LLM_MAX_NEW_TOKENS: int = 300
|
|
||||||
|
|
||||||
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"
|
|
||||||
|
|
||||||
app = App(name="reflector-llm-zephyr")
|
|
||||||
|
|
||||||
|
|
||||||
def download_llm():
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
print("Downloading LLM model")
|
|
||||||
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
|
|
||||||
print("LLM model downloaded")
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_cache_llm():
|
|
||||||
"""
|
|
||||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
|
||||||
Migrating your old cache. This is a one-time only operation. You can
|
|
||||||
interrupt this and resume the migration later on by calling
|
|
||||||
`transformers.utils.move_cache()`.
|
|
||||||
"""
|
|
||||||
from transformers.utils.hub import move_cache
|
|
||||||
|
|
||||||
print("Moving LLM cache")
|
|
||||||
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
|
|
||||||
print("LLM cache moved")
|
|
||||||
|
|
||||||
|
|
||||||
llm_image = (
|
|
||||||
Image.debian_slim(python_version="3.10.8")
|
|
||||||
.apt_install("git")
|
|
||||||
.pip_install(
|
|
||||||
"transformers==4.34.0",
|
|
||||||
"torch",
|
|
||||||
"sentencepiece",
|
|
||||||
"protobuf",
|
|
||||||
"jsonformer==0.12.0",
|
|
||||||
"accelerate==0.21.0",
|
|
||||||
"einops==0.6.1",
|
|
||||||
"hf-transfer~=0.1",
|
|
||||||
"huggingface_hub==0.16.4",
|
|
||||||
)
|
|
||||||
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
|
|
||||||
.run_function(download_llm)
|
|
||||||
.run_function(migrate_cache_llm)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.cls(
|
|
||||||
gpu="A10G",
|
|
||||||
timeout=60 * 5,
|
|
||||||
scaledown_window=60 * 5,
|
|
||||||
allow_concurrent_inputs=10,
|
|
||||||
image=llm_image,
|
|
||||||
)
|
|
||||||
class LLM:
|
|
||||||
@enter()
|
|
||||||
def enter(self):
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
||||||
|
|
||||||
print("Instance llm model")
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
LLM_MODEL,
|
|
||||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
|
||||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
|
||||||
cache_dir=IMAGE_MODEL_DIR,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# JSONFormer doesn't yet support generation configs
|
|
||||||
print("Instance llm generation config")
|
|
||||||
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
|
|
||||||
|
|
||||||
# generation configuration
|
|
||||||
gen_cfg = GenerationConfig.from_model_config(model.config)
|
|
||||||
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
|
|
||||||
|
|
||||||
# load tokenizer
|
|
||||||
print("Instance llm tokenizer")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
|
|
||||||
)
|
|
||||||
gen_cfg.pad_token_id = tokenizer.eos_token_id
|
|
||||||
gen_cfg.eos_token_id = tokenizer.eos_token_id
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
model.config.pad_token_id = tokenizer.eos_token_id
|
|
||||||
|
|
||||||
# move model to gpu
|
|
||||||
print("Move llm model to GPU")
|
|
||||||
model = model.cuda()
|
|
||||||
|
|
||||||
print("Warmup llm done")
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.gen_cfg = gen_cfg
|
|
||||||
self.GenerationConfig = GenerationConfig
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
|
|
||||||
@exit()
|
|
||||||
def exit():
|
|
||||||
print("Exit llm")
|
|
||||||
|
|
||||||
@method()
|
|
||||||
def generate(
|
|
||||||
self, prompt: str, gen_schema: str | None, gen_cfg: str | None
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Perform a generation action using the LLM
|
|
||||||
"""
|
|
||||||
print(f"Generate {prompt=}")
|
|
||||||
if gen_cfg:
|
|
||||||
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
|
|
||||||
gen_cfg.pad_token_id = self.tokenizer.eos_token_id
|
|
||||||
gen_cfg.eos_token_id = self.tokenizer.eos_token_id
|
|
||||||
else:
|
|
||||||
gen_cfg = self.gen_cfg
|
|
||||||
|
|
||||||
# If a gen_schema is given, conform to gen_schema
|
|
||||||
with self.lock:
|
|
||||||
if gen_schema:
|
|
||||||
import jsonformer
|
|
||||||
|
|
||||||
print(f"Schema {gen_schema=}")
|
|
||||||
jsonformer_llm = jsonformer.Jsonformer(
|
|
||||||
model=self.model,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
json_schema=json.loads(gen_schema),
|
|
||||||
prompt=prompt,
|
|
||||||
max_string_token_length=gen_cfg.max_new_tokens,
|
|
||||||
)
|
|
||||||
response = jsonformer_llm()
|
|
||||||
else:
|
|
||||||
# If no gen_schema, perform prompt only generation
|
|
||||||
|
|
||||||
# tokenize prompt
|
|
||||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
|
||||||
self.model.device
|
|
||||||
)
|
|
||||||
output = self.model.generate(input_ids, generation_config=gen_cfg)
|
|
||||||
|
|
||||||
# decode output
|
|
||||||
response = self.tokenizer.decode(
|
|
||||||
output[0].cpu(), skip_special_tokens=True
|
|
||||||
)
|
|
||||||
response = response[len(prompt) :]
|
|
||||||
response = {"long_summary": response}
|
|
||||||
print(f"Generated {response=}")
|
|
||||||
return {"text": response}
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------
|
|
||||||
# Web API
|
|
||||||
# -------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@app.function(
|
|
||||||
scaledown_window=60 * 10,
|
|
||||||
timeout=60 * 5,
|
|
||||||
allow_concurrent_inputs=30,
|
|
||||||
secrets=[
|
|
||||||
Secret.from_name("reflector-gpu"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@asgi_app()
|
|
||||||
def web():
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
llmstub = LLM()
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
|
||||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid API key",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
class LLMRequest(BaseModel):
|
|
||||||
prompt: str
|
|
||||||
gen_schema: Optional[dict] = None
|
|
||||||
gen_cfg: Optional[dict] = None
|
|
||||||
|
|
||||||
@app.post("/llm", dependencies=[Depends(apikey_auth)])
|
|
||||||
def llm(
|
|
||||||
req: LLMRequest,
|
|
||||||
):
|
|
||||||
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
|
|
||||||
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
|
|
||||||
func = llmstub.generate.spawn(
|
|
||||||
prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
|
|
||||||
)
|
|
||||||
result = func.get()
|
|
||||||
return result
|
|
||||||
|
|
||||||
return app
|
|
||||||
@@ -40,6 +40,7 @@ dependencies = [
|
|||||||
"psycopg2-binary>=2.9.10",
|
"psycopg2-binary>=2.9.10",
|
||||||
"llama-index>=0.12.52",
|
"llama-index>=0.12.52",
|
||||||
"llama-index-llms-openai-like>=0.4.0",
|
"llama-index-llms-openai-like>=0.4.0",
|
||||||
|
"pytest-env>=1.1.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
@@ -83,6 +84,10 @@ packages = ["reflector"]
|
|||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
source = ["reflector"]
|
source = ["reflector"]
|
||||||
|
|
||||||
|
[tool.pytest_env]
|
||||||
|
ENVIRONMENT = "pytest"
|
||||||
|
DATABASE_URL = "sqlite:///test.sqlite"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
|||||||
83
server/reflector/llm.py
Normal file
83
server/reflector/llm.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
from typing import Type, TypeVar
|
||||||
|
|
||||||
|
from llama_index.core import Settings
|
||||||
|
from llama_index.core.output_parsers import PydanticOutputParser
|
||||||
|
from llama_index.core.program import LLMTextCompletionProgram
|
||||||
|
from llama_index.core.response_synthesizers import TreeSummarize
|
||||||
|
from llama_index.llms.openai_like import OpenAILike
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
|
||||||
|
Based on the following analysis, provide the information in the requested JSON format:
|
||||||
|
|
||||||
|
Analysis:
|
||||||
|
{analysis}
|
||||||
|
|
||||||
|
{format_instructions}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class LLM:
|
||||||
|
def __init__(self, settings, temperature: float = 0.4, max_tokens: int = 2048):
|
||||||
|
self.settings_obj = settings
|
||||||
|
self.model_name = settings.LLM_MODEL
|
||||||
|
self.url = settings.LLM_URL
|
||||||
|
self.api_key = settings.LLM_API_KEY
|
||||||
|
self.context_window = settings.LLM_CONTEXT_WINDOW
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
# Configure llamaindex Settings
|
||||||
|
self._configure_llamaindex()
|
||||||
|
|
||||||
|
def _configure_llamaindex(self):
|
||||||
|
"""Configure llamaindex Settings with OpenAILike LLM"""
|
||||||
|
Settings.llm = OpenAILike(
|
||||||
|
model=self.model_name,
|
||||||
|
api_base=self.url,
|
||||||
|
api_key=self.api_key,
|
||||||
|
context_window=self.context_window,
|
||||||
|
is_chat_model=True,
|
||||||
|
is_function_calling_model=False,
|
||||||
|
temperature=self.temperature,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_response(
|
||||||
|
self, prompt: str, texts: list[str], tone_name: str | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Get a text response using TreeSummarize for non-function-calling models"""
|
||||||
|
summarizer = TreeSummarize(verbose=False)
|
||||||
|
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
|
||||||
|
return str(response).strip()
|
||||||
|
|
||||||
|
async def get_structured_response(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
texts: list[str],
|
||||||
|
output_cls: Type[T],
|
||||||
|
tone_name: str | None = None,
|
||||||
|
) -> T:
|
||||||
|
"""Get structured output from LLM for non-function-calling models"""
|
||||||
|
summarizer = TreeSummarize(verbose=True)
|
||||||
|
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
|
||||||
|
|
||||||
|
output_parser = PydanticOutputParser(output_cls)
|
||||||
|
|
||||||
|
program = LLMTextCompletionProgram.from_defaults(
|
||||||
|
output_parser=output_parser,
|
||||||
|
prompt_template_str=STRUCTURED_RESPONSE_PROMPT_TEMPLATE,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
format_instructions = output_parser.format(
|
||||||
|
"Please structure the above information in the following JSON format:"
|
||||||
|
)
|
||||||
|
|
||||||
|
output = await program.acall(
|
||||||
|
analysis=str(response), format_instructions=format_instructions
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
from .base import LLM # noqa: F401
|
|
||||||
from .llm_params import LLMTaskParams # noqa: F401
|
|
||||||
@@ -1,347 +0,0 @@
|
|||||||
import importlib
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
import nltk
|
|
||||||
from prometheus_client import Counter, Histogram
|
|
||||||
from transformers import GenerationConfig
|
|
||||||
|
|
||||||
from reflector.llm.llm_params import TaskParams
|
|
||||||
from reflector.logger import logger as reflector_logger
|
|
||||||
from reflector.settings import settings
|
|
||||||
from reflector.utils.retry import retry
|
|
||||||
|
|
||||||
T = TypeVar("T", bound="LLM")
|
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
|
||||||
_nltk_downloaded = False
|
|
||||||
_registry = {}
|
|
||||||
model_name: str
|
|
||||||
m_generate = Histogram(
|
|
||||||
"llm_generate",
|
|
||||||
"Time spent in LLM.generate",
|
|
||||||
["backend"],
|
|
||||||
)
|
|
||||||
m_generate_call = Counter(
|
|
||||||
"llm_generate_call",
|
|
||||||
"Number of calls to LLM.generate",
|
|
||||||
["backend"],
|
|
||||||
)
|
|
||||||
m_generate_success = Counter(
|
|
||||||
"llm_generate_success",
|
|
||||||
"Number of successful calls to LLM.generate",
|
|
||||||
["backend"],
|
|
||||||
)
|
|
||||||
m_generate_failure = Counter(
|
|
||||||
"llm_generate_failure",
|
|
||||||
"Number of failed calls to LLM.generate",
|
|
||||||
["backend"],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def ensure_nltk(cls):
|
|
||||||
"""
|
|
||||||
Make sure NLTK package is installed. Searches in the cache and
|
|
||||||
downloads only if needed.
|
|
||||||
"""
|
|
||||||
if not cls._nltk_downloaded:
|
|
||||||
nltk.download("punkt_tab")
|
|
||||||
# For POS tagging
|
|
||||||
nltk.download("averaged_perceptron_tagger_eng")
|
|
||||||
cls._nltk_downloaded = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, name, klass):
|
|
||||||
cls._registry[name] = klass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls, model_name: str | None = None, name: str = None) -> T:
|
|
||||||
"""
|
|
||||||
Return an instance depending on the settings.
|
|
||||||
Settings used:
|
|
||||||
|
|
||||||
- `LLM_BACKEND`: key of the backend
|
|
||||||
- `LLM_URL`: url of the backend
|
|
||||||
"""
|
|
||||||
if name is None:
|
|
||||||
name = settings.LLM_BACKEND
|
|
||||||
if name not in cls._registry:
|
|
||||||
module_name = f"reflector.llm.llm_{name}"
|
|
||||||
importlib.import_module(module_name)
|
|
||||||
cls.ensure_nltk()
|
|
||||||
|
|
||||||
return cls._registry[name](model_name)
|
|
||||||
|
|
||||||
def get_model_name(self) -> str:
|
|
||||||
"""
|
|
||||||
Get the currently set model name
|
|
||||||
"""
|
|
||||||
return self._get_model_name()
|
|
||||||
|
|
||||||
def _get_model_name(self) -> str:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def set_model_name(self, model_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
Update the model name with the provided model name
|
|
||||||
"""
|
|
||||||
return self._set_model_name(model_name)
|
|
||||||
|
|
||||||
def _set_model_name(self, model_name: str) -> bool:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
|
||||||
def template(self) -> str:
|
|
||||||
"""
|
|
||||||
Return the LLM Prompt template
|
|
||||||
"""
|
|
||||||
return """
|
|
||||||
### Human:
|
|
||||||
{instruct}
|
|
||||||
|
|
||||||
{text}
|
|
||||||
|
|
||||||
### Assistant:
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
name = self.__class__.__name__
|
|
||||||
self.m_generate = self.m_generate.labels(name)
|
|
||||||
self.m_generate_call = self.m_generate_call.labels(name)
|
|
||||||
self.m_generate_success = self.m_generate_success.labels(name)
|
|
||||||
self.m_generate_failure = self.m_generate_failure.labels(name)
|
|
||||||
self.detokenizer = nltk.tokenize.treebank.TreebankWordDetokenizer()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tokenizer(self):
|
|
||||||
"""
|
|
||||||
Return the tokenizer instance used by LLM
|
|
||||||
"""
|
|
||||||
return self._get_tokenizer()
|
|
||||||
|
|
||||||
def _get_tokenizer(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def has_structured_output(self):
|
|
||||||
# whether implementation supports structured output
|
|
||||||
# on the model side (otherwise it's prompt engineering)
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def generate(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
logger: reflector_logger,
|
|
||||||
gen_schema: dict | None = None,
|
|
||||||
gen_cfg: GenerationConfig | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> dict:
|
|
||||||
logger.info("LLM generate", prompt=repr(prompt))
|
|
||||||
|
|
||||||
if gen_cfg:
|
|
||||||
gen_cfg = gen_cfg.to_dict()
|
|
||||||
self.m_generate_call.inc()
|
|
||||||
try:
|
|
||||||
with self.m_generate.time():
|
|
||||||
result = await retry(self._generate)(
|
|
||||||
prompt=prompt,
|
|
||||||
gen_schema=gen_schema,
|
|
||||||
gen_cfg=gen_cfg,
|
|
||||||
logger=logger,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
self.m_generate_success.inc()
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to call llm after retrying")
|
|
||||||
self.m_generate_failure.inc()
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.debug("LLM result [raw]", result=repr(result))
|
|
||||||
if isinstance(result, str):
|
|
||||||
result = self._parse_json(result)
|
|
||||||
logger.debug("LLM result [parsed]", result=repr(result))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def completion(
|
|
||||||
self, messages: list, logger: reflector_logger, **kwargs
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Use /v1/chat/completion Open-AI compatible endpoint from the URL
|
|
||||||
It's up to the user to validate anything or transform the result
|
|
||||||
"""
|
|
||||||
logger.info("LLM completions", messages=messages)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.m_generate.time():
|
|
||||||
result = await retry(self._completion)(
|
|
||||||
messages=messages, **{**kwargs, "logger": logger}
|
|
||||||
)
|
|
||||||
self.m_generate_success.inc()
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to call llm after retrying")
|
|
||||||
self.m_generate_failure.inc()
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.debug("LLM completion result", result=repr(result))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def ensure_casing(self, title: str) -> str:
|
|
||||||
"""
|
|
||||||
LLM takes care of word casing, but in rare cases this
|
|
||||||
can falter. This is a fallback to ensure the casing of
|
|
||||||
topics is in a proper format.
|
|
||||||
|
|
||||||
We select nouns, verbs and adjectives and check if camel
|
|
||||||
casing is present and fix it, if not. Will not perform
|
|
||||||
any other changes.
|
|
||||||
"""
|
|
||||||
tokens = nltk.word_tokenize(title)
|
|
||||||
pos_tags = nltk.pos_tag(tokens)
|
|
||||||
camel_cased = []
|
|
||||||
|
|
||||||
whitelisted_pos_tags = [
|
|
||||||
"NN",
|
|
||||||
"NNS",
|
|
||||||
"NNP",
|
|
||||||
"NNPS", # Noun POS
|
|
||||||
"VB",
|
|
||||||
"VBD",
|
|
||||||
"VBG",
|
|
||||||
"VBN",
|
|
||||||
"VBP",
|
|
||||||
"VBZ", # Verb POS
|
|
||||||
"JJ",
|
|
||||||
"JJR",
|
|
||||||
"JJS", # Adjective POS
|
|
||||||
]
|
|
||||||
|
|
||||||
# If at all there is an exception, do not block other reflector
|
|
||||||
# processes. Return the LLM generated title, at the least.
|
|
||||||
try:
|
|
||||||
for word, pos in pos_tags:
|
|
||||||
if pos in whitelisted_pos_tags and word[0].islower():
|
|
||||||
camel_cased.append(word[0].upper() + word[1:])
|
|
||||||
else:
|
|
||||||
camel_cased.append(word)
|
|
||||||
modified_title = self.detokenizer.detokenize(camel_cased)
|
|
||||||
|
|
||||||
# Irrespective of casing changes, the starting letter
|
|
||||||
# of title is always upper-cased
|
|
||||||
title = modified_title[0].upper() + modified_title[1:]
|
|
||||||
except Exception as e:
|
|
||||||
reflector_logger.info(
|
|
||||||
f"Failed to ensure casing on {title=} with exception : {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return title
|
|
||||||
|
|
||||||
def trim_title(self, title: str) -> str:
|
|
||||||
"""
|
|
||||||
List of manual trimming to the title.
|
|
||||||
|
|
||||||
Longer titles are prone to run into A prefix of phrases that don't
|
|
||||||
really add any descriptive information and in some cases, this
|
|
||||||
behaviour can be repeated for several consecutive topics. Trim the
|
|
||||||
titles to maintain quality of titles.
|
|
||||||
"""
|
|
||||||
phrases_to_remove = ["Discussing", "Discussion on", "Discussion about"]
|
|
||||||
try:
|
|
||||||
pattern = (
|
|
||||||
r"\b(?:"
|
|
||||||
+ "|".join(re.escape(phrase) for phrase in phrases_to_remove)
|
|
||||||
+ r")\b"
|
|
||||||
)
|
|
||||||
title = re.sub(pattern, "", title, flags=re.IGNORECASE)
|
|
||||||
except Exception as e:
|
|
||||||
reflector_logger.info(f"Failed to trim {title=} with exception : {str(e)}")
|
|
||||||
return title
|
|
||||||
|
|
||||||
async def _generate(
|
|
||||||
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
|
|
||||||
) -> str:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def _completion(self, messages: list, **kwargs) -> dict:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _parse_json(self, result: str) -> dict:
|
|
||||||
result = result.strip()
|
|
||||||
# try detecting code block if exist
|
|
||||||
# starts with ```json\n, ends with ```
|
|
||||||
# or starts with ```\n, ends with ```
|
|
||||||
# or starts with \n```javascript\n, ends with ```
|
|
||||||
|
|
||||||
regex = r"```(json|javascript|)?(.*)```"
|
|
||||||
matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
|
|
||||||
if matches:
|
|
||||||
result = matches[0][1]
|
|
||||||
|
|
||||||
else:
|
|
||||||
# maybe the prompt has been started with ```json
|
|
||||||
# so if text ends with ```, just remove it and use it as json
|
|
||||||
if result.endswith("```"):
|
|
||||||
result = result[:-3]
|
|
||||||
|
|
||||||
return json.loads(result.strip())
|
|
||||||
|
|
||||||
def text_token_threshold(self, task_params: TaskParams | None) -> int:
|
|
||||||
"""
|
|
||||||
Choose the token size to set as the threshold to pack the LLM calls
|
|
||||||
"""
|
|
||||||
buffer_token_size = 100
|
|
||||||
default_output_tokens = 1000
|
|
||||||
context_window = self.tokenizer.model_max_length
|
|
||||||
tokens = self.tokenizer.tokenize(
|
|
||||||
self.create_prompt(instruct=task_params.instruct, text="")
|
|
||||||
)
|
|
||||||
threshold = context_window - len(tokens) - buffer_token_size
|
|
||||||
if task_params.gen_cfg:
|
|
||||||
threshold -= task_params.gen_cfg.max_new_tokens
|
|
||||||
else:
|
|
||||||
threshold -= default_output_tokens
|
|
||||||
return threshold
|
|
||||||
|
|
||||||
def split_corpus(
|
|
||||||
self,
|
|
||||||
corpus: str,
|
|
||||||
task_params: TaskParams,
|
|
||||||
token_threshold: int | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
"""
|
|
||||||
Split the input to the LLM due to CUDA memory limitations and LLM context window
|
|
||||||
restrictions.
|
|
||||||
|
|
||||||
Accumulate tokens from full sentences till threshold and yield accumulated
|
|
||||||
tokens. Reset accumulation when threshold is reached and repeat process.
|
|
||||||
"""
|
|
||||||
if not token_threshold:
|
|
||||||
token_threshold = self.text_token_threshold(task_params=task_params)
|
|
||||||
|
|
||||||
accumulated_tokens = []
|
|
||||||
accumulated_sentences = []
|
|
||||||
accumulated_token_count = 0
|
|
||||||
corpus_sentences = nltk.sent_tokenize(corpus)
|
|
||||||
|
|
||||||
for sentence in corpus_sentences:
|
|
||||||
tokens = self.tokenizer.tokenize(sentence)
|
|
||||||
if accumulated_token_count + len(tokens) <= token_threshold:
|
|
||||||
accumulated_token_count += len(tokens)
|
|
||||||
accumulated_tokens.extend(tokens)
|
|
||||||
accumulated_sentences.append(sentence)
|
|
||||||
else:
|
|
||||||
yield "".join(accumulated_sentences)
|
|
||||||
accumulated_token_count = len(tokens)
|
|
||||||
accumulated_tokens = tokens
|
|
||||||
accumulated_sentences = [sentence]
|
|
||||||
|
|
||||||
if accumulated_tokens:
|
|
||||||
yield " ".join(accumulated_sentences)
|
|
||||||
|
|
||||||
def create_prompt(self, instruct: str, text: str) -> str:
|
|
||||||
"""
|
|
||||||
Create a consumable prompt based on the prompt template
|
|
||||||
"""
|
|
||||||
return self.template.format(instruct=instruct, text=text)
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
import httpx
|
|
||||||
from transformers import AutoTokenizer, GenerationConfig
|
|
||||||
|
|
||||||
from reflector.llm.base import LLM
|
|
||||||
from reflector.logger import logger as reflector_logger
|
|
||||||
from reflector.settings import settings
|
|
||||||
from reflector.utils.retry import retry
|
|
||||||
|
|
||||||
|
|
||||||
class ModalLLM(LLM):
|
|
||||||
def __init__(self, model_name: str | None = None):
|
|
||||||
super().__init__()
|
|
||||||
self.timeout = settings.LLM_TIMEOUT
|
|
||||||
self.llm_url = settings.LLM_URL + "/llm"
|
|
||||||
self.headers = {
|
|
||||||
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
|
|
||||||
}
|
|
||||||
self._set_model_name(model_name if model_name else settings.DEFAULT_LLM)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_models(self):
|
|
||||||
"""
|
|
||||||
List of currently supported models on this GPU platform
|
|
||||||
"""
|
|
||||||
# TODO: Query the specific GPU platform
|
|
||||||
# Replace this with a HTTP call
|
|
||||||
return [
|
|
||||||
"lmsys/vicuna-13b-v1.5",
|
|
||||||
"HuggingFaceH4/zephyr-7b-alpha",
|
|
||||||
"NousResearch/Hermes-3-Llama-3.1-8B",
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _generate(
|
|
||||||
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
|
|
||||||
) -> str:
|
|
||||||
json_payload = {"prompt": prompt}
|
|
||||||
if gen_schema:
|
|
||||||
json_payload["gen_schema"] = gen_schema
|
|
||||||
if gen_cfg:
|
|
||||||
json_payload["gen_cfg"] = gen_cfg
|
|
||||||
|
|
||||||
# Handing over generation of the final summary to Zephyr model
|
|
||||||
# but replacing the Vicuna model will happen after more testing
|
|
||||||
# TODO: Create a mapping of model names and cloud deployments
|
|
||||||
if self.model_name == "HuggingFaceH4/zephyr-7b-alpha":
|
|
||||||
self.llm_url = settings.ZEPHYR_LLM_URL + "/llm"
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await retry(client.post)(
|
|
||||||
self.llm_url,
|
|
||||||
headers=self.headers,
|
|
||||||
json=json_payload,
|
|
||||||
timeout=self.timeout,
|
|
||||||
retry_timeout=60 * 5,
|
|
||||||
follow_redirects=True,
|
|
||||||
logger=kwargs.get("logger", reflector_logger),
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
text = response.json()["text"]
|
|
||||||
return text
|
|
||||||
|
|
||||||
async def _completion(self, messages: list, **kwargs) -> dict:
|
|
||||||
# returns full api response
|
|
||||||
kwargs.setdefault("temperature", 0.3)
|
|
||||||
kwargs.setdefault("max_tokens", 2048)
|
|
||||||
kwargs.setdefault("stream", False)
|
|
||||||
kwargs.setdefault("repetition_penalty", 1)
|
|
||||||
kwargs.setdefault("top_p", 1)
|
|
||||||
kwargs.setdefault("top_k", -1)
|
|
||||||
kwargs.setdefault("min_p", 0.05)
|
|
||||||
data = {"messages": messages, "model": self.model_name, **kwargs}
|
|
||||||
|
|
||||||
if self.model_name == "NousResearch/Hermes-3-Llama-3.1-8B":
|
|
||||||
self.llm_url = settings.HERMES_3_8B_LLM_URL + "/v1/chat/completions"
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await retry(client.post)(
|
|
||||||
self.llm_url,
|
|
||||||
headers=self.headers,
|
|
||||||
json=data,
|
|
||||||
timeout=self.timeout,
|
|
||||||
retry_timeout=60 * 5,
|
|
||||||
follow_redirects=True,
|
|
||||||
logger=kwargs.get("logger", reflector_logger),
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def _set_model_name(self, model_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
Set the model name
|
|
||||||
"""
|
|
||||||
# Abort, if the model is not supported
|
|
||||||
if model_name not in self.supported_models:
|
|
||||||
reflector_logger.info(
|
|
||||||
f"Attempted to change {model_name=}, but is not supported."
|
|
||||||
f"Setting model and tokenizer failed !"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
# Abort, if the model is already set
|
|
||||||
elif hasattr(self, "model_name") and model_name == self._get_model_name():
|
|
||||||
reflector_logger.info("No change in model. Setting model skipped.")
|
|
||||||
return False
|
|
||||||
# Update model name and tokenizer
|
|
||||||
self.model_name = model_name
|
|
||||||
self.llm_tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.model_name, cache_dir=settings.CACHE_DIR
|
|
||||||
)
|
|
||||||
reflector_logger.info(f"Model set to {model_name=}. Tokenizer updated.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _get_tokenizer(self) -> AutoTokenizer:
|
|
||||||
"""
|
|
||||||
Return the currently used LLM tokenizer
|
|
||||||
"""
|
|
||||||
return self.llm_tokenizer
|
|
||||||
|
|
||||||
def _get_model_name(self) -> str:
|
|
||||||
"""
|
|
||||||
Return the current model name from the instance details
|
|
||||||
"""
|
|
||||||
return self.model_name
|
|
||||||
|
|
||||||
|
|
||||||
LLM.register("modal", ModalLLM)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from reflector.logger import logger
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
llm = ModalLLM()
|
|
||||||
prompt = llm.create_prompt(
|
|
||||||
instruct="Complete the following task",
|
|
||||||
text="Tell me a joke about programming.",
|
|
||||||
)
|
|
||||||
result = await llm.generate(prompt=prompt, logger=logger)
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
gen_schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"response": {"type": "string"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await llm.generate(prompt=prompt, gen_schema=gen_schema, logger=logger)
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
gen_cfg = GenerationConfig(max_new_tokens=150)
|
|
||||||
result = await llm.generate(
|
|
||||||
prompt=prompt, gen_cfg=gen_cfg, gen_schema=gen_schema, logger=logger
|
|
||||||
)
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
import httpx
|
|
||||||
from transformers import GenerationConfig
|
|
||||||
|
|
||||||
from reflector.llm.base import LLM
|
|
||||||
from reflector.logger import logger
|
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAILLM(LLM):
|
|
||||||
def __init__(self, model_name: str | None = None, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.openai_key = settings.LLM_OPENAI_KEY
|
|
||||||
self.openai_url = settings.LLM_URL
|
|
||||||
self.openai_model = settings.LLM_OPENAI_MODEL
|
|
||||||
self.openai_temperature = settings.LLM_OPENAI_TEMPERATURE
|
|
||||||
self.timeout = settings.LLM_TIMEOUT
|
|
||||||
self.max_tokens = settings.LLM_MAX_TOKENS
|
|
||||||
logger.info(f"LLM use openai backend at {self.openai_url}")
|
|
||||||
|
|
||||||
async def _generate(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
gen_schema: dict | None,
|
|
||||||
gen_cfg: GenerationConfig | None,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.openai_key}",
|
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
response = await client.post(
|
|
||||||
self.openai_url,
|
|
||||||
headers=headers,
|
|
||||||
json={
|
|
||||||
"model": self.openai_model,
|
|
||||||
"prompt": prompt,
|
|
||||||
"max_tokens": self.max_tokens,
|
|
||||||
"temperature": self.openai_temperature,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
result = response.json()
|
|
||||||
return result["choices"][0]["text"]
|
|
||||||
|
|
||||||
|
|
||||||
LLM.register("openai", OpenAILLM)
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
from typing import Optional, TypeVar
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from transformers import GenerationConfig
|
|
||||||
|
|
||||||
|
|
||||||
class TaskParams(BaseModel, arbitrary_types_allowed=True):
|
|
||||||
instruct: str
|
|
||||||
gen_cfg: Optional[GenerationConfig] = None
|
|
||||||
gen_schema: Optional[dict] = None
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound="LLMTaskParams")
|
|
||||||
|
|
||||||
|
|
||||||
class LLMTaskParams:
|
|
||||||
_registry = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, task, klass) -> None:
|
|
||||||
cls._registry[task] = klass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls, task: str) -> T:
|
|
||||||
return cls._registry[task]()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def task_params(self) -> TaskParams | None:
|
|
||||||
"""
|
|
||||||
Fetch the task related parameters
|
|
||||||
"""
|
|
||||||
return self._get_task_params()
|
|
||||||
|
|
||||||
def _get_task_params(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class FinalLongSummaryParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=1000, num_beams=3, do_sample=True, temperature=0.3
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Take the key ideas and takeaways from the text and create a short
|
|
||||||
summary. Be sure to keep the length of the response to a minimum.
|
|
||||||
Do not include trivial information in the summary.
|
|
||||||
"""
|
|
||||||
self._schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"long_summary": {"type": "string"}},
|
|
||||||
}
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""gen_schema
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
class FinalShortSummaryParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Take the key ideas and takeaways from the text and create a short
|
|
||||||
summary. Be sure to keep the length of the response to a minimum.
|
|
||||||
Do not include trivial information in the summary.
|
|
||||||
"""
|
|
||||||
self._schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"short_summary": {"type": "string"}},
|
|
||||||
}
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
class FinalTitleParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Combine the following individual titles into one single short title that
|
|
||||||
condenses the essence of all titles.
|
|
||||||
"""
|
|
||||||
self._schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"title": {"type": "string"}},
|
|
||||||
}
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
class TopicParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=500, num_beams=6, do_sample=True, temperature=0.9
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Create a JSON object as response.The JSON object must have 2 fields:
|
|
||||||
i) title and ii) summary.
|
|
||||||
For the title field, generate a very detailed and self-explanatory
|
|
||||||
title for the given text. Let the title be as descriptive as possible.
|
|
||||||
For the summary field, summarize the given text in a maximum of
|
|
||||||
two sentences.
|
|
||||||
"""
|
|
||||||
self._schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"title": {"type": "string"},
|
|
||||||
"summary": {"type": "string"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
class BulletedSummaryParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=800,
|
|
||||||
num_beams=1,
|
|
||||||
do_sample=True,
|
|
||||||
temperature=0.2,
|
|
||||||
early_stopping=True,
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Given a meeting transcript, extract the key things discussed in the
|
|
||||||
form of a list.
|
|
||||||
|
|
||||||
While generating the response, follow the constraints mentioned below.
|
|
||||||
|
|
||||||
Summary constraints:
|
|
||||||
i) Do not add new content, except to fix spelling or punctuation.
|
|
||||||
ii) Do not add any prefixes or numbering in the response.
|
|
||||||
iii) The summarization should be as information dense as possible.
|
|
||||||
iv) Do not add any additional sections like Note, Conclusion, etc. in
|
|
||||||
the response.
|
|
||||||
|
|
||||||
Response format:
|
|
||||||
i) The response should be in the form of a bulleted list.
|
|
||||||
ii) Iteratively merge all the relevant paragraphs together to keep the
|
|
||||||
number of paragraphs to a minimum.
|
|
||||||
iii) Remove any unfinished sentences from the final response.
|
|
||||||
iv) Do not include narrative or reporting clauses.
|
|
||||||
v) Use "*" as the bullet icon.
|
|
||||||
"""
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""gen_schema
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
class MergedSummaryParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=600,
|
|
||||||
num_beams=1,
|
|
||||||
do_sample=True,
|
|
||||||
temperature=0.2,
|
|
||||||
early_stopping=True,
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Given the key points of a meeting, summarize the points to describe the
|
|
||||||
meeting in the form of paragraphs.
|
|
||||||
"""
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""gen_schema
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
LLMTaskParams.register("topic", TopicParams)
|
|
||||||
LLMTaskParams.register("final_title", FinalTitleParams)
|
|
||||||
LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
|
|
||||||
LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
|
|
||||||
LLMTaskParams.register("bullet_summary", BulletedSummaryParams)
|
|
||||||
LLMTaskParams.register("merged_summary", MergedSummaryParams)
|
|
||||||
@@ -1,118 +0,0 @@
-import httpx
-from transformers import AutoTokenizer
-
-from reflector.logger import logger
-
-
-def apply_gen_config(payload: dict, gen_cfg) -> None:
-    """Apply generation config overrides to the payload."""
-    config_mapping = {
-        "temperature": "temperature",
-        "max_new_tokens": "max_tokens",
-        "max_tokens": "max_tokens",
-        "top_p": "top_p",
-        "frequency_penalty": "frequency_penalty",
-        "presence_penalty": "presence_penalty",
-    }
-
-    for cfg_attr, payload_key in config_mapping.items():
-        value = getattr(gen_cfg, cfg_attr, None)
-        if value is not None:
-            payload[payload_key] = value
-            if cfg_attr == "max_new_tokens":  # Handle max_new_tokens taking precedence
-                break
-
-
-class OpenAILLM:
-    def __init__(self, config_prefix: str, settings):
-        self.config_prefix = config_prefix
-        self.settings_obj = settings
-        self.model_name = getattr(settings, f"{config_prefix}_MODEL")
-        self.url = getattr(settings, f"{config_prefix}_LLM_URL")
-        self.api_key = getattr(settings, f"{config_prefix}_LLM_API_KEY")
-
-        timeout = getattr(settings, f"{config_prefix}_LLM_TIMEOUT", 300)
-        self.temperature = getattr(settings, f"{config_prefix}_LLM_TEMPERATURE", 0.7)
-        self.max_tokens = getattr(settings, f"{config_prefix}_LLM_MAX_TOKENS", 1024)
-        self.client = httpx.AsyncClient(timeout=timeout)
-
-        # Use a tokenizer that approximates OpenAI token counting
-        tokenizer_name = getattr(settings, f"{config_prefix}_TOKENIZER", "gpt2")
-        try:
-            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-        except Exception:
-            logger.debug(
-                f"Failed to load tokenizer '{tokenizer_name}', falling back to default 'gpt2' tokenizer"
-            )
-            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
-
-    async def generate(
-        self, prompt: str, gen_schema=None, gen_cfg=None, logger=None
-    ) -> str:
-        if logger:
-            logger.debug(
-                "OpenAI LLM generate",
-                prompt=repr(prompt[:100] + "..." if len(prompt) > 100 else prompt),
-            )
-
-        messages = [{"role": "user", "content": prompt}]
-        result = await self.completion(
-            messages, gen_schema=gen_schema, gen_cfg=gen_cfg, logger=logger
-        )
-        return result["choices"][0]["message"]["content"]
-
-    async def completion(
-        self, messages: list, gen_schema=None, gen_cfg=None, logger=None, **kwargs
-    ) -> dict:
-        if logger:
-            logger.info("OpenAI LLM completion", messages_count=len(messages))
-
-        payload = {
-            "model": self.model_name,
-            "messages": messages,
-            "temperature": self.temperature,
-            "max_tokens": self.max_tokens,
-        }
-
-        # Apply generation config overrides
-        if gen_cfg:
-            apply_gen_config(payload, gen_cfg)
-
-        # Apply structured output schema
-        if gen_schema:
-            payload["response_format"] = {
-                "type": "json_schema",
-                "json_schema": {"name": "response", "schema": gen_schema},
-            }
-
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-        }
-
-        url = f"{self.url.rstrip('/')}/chat/completions"
-
-        if logger:
-            logger.debug(
-                "OpenAI API request", url=url, payload_keys=list(payload.keys())
-            )
-
-        response = await self.client.post(url, json=payload, headers=headers)
-        response.raise_for_status()
-
-        result = response.json()
-
-        if logger:
-            logger.debug(
-                "OpenAI API response",
-                status_code=response.status_code,
-                choices_count=len(result.get("choices", [])),
-            )
-
-        return result
-
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        await self.client.aclose()
@@ -47,7 +47,7 @@ from reflector.processors import (
|
|||||||
TranscriptFinalTitleProcessor,
|
TranscriptFinalTitleProcessor,
|
||||||
TranscriptLinerProcessor,
|
TranscriptLinerProcessor,
|
||||||
TranscriptTopicDetectorProcessor,
|
TranscriptTopicDetectorProcessor,
|
||||||
TranscriptTranslatorProcessor,
|
TranscriptTranslatorAutoProcessor,
|
||||||
)
|
)
|
||||||
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
||||||
from reflector.processors.types import AudioDiarizationInput
|
from reflector.processors.types import AudioDiarizationInput
|
||||||
@@ -361,7 +361,7 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
AudioMergeProcessor(),
|
AudioMergeProcessor(),
|
||||||
AudioTranscriptAutoProcessor.as_threaded(),
|
AudioTranscriptAutoProcessor.as_threaded(),
|
||||||
TranscriptLinerProcessor(),
|
TranscriptLinerProcessor(),
|
||||||
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
TranscriptTranslatorAutoProcessor.as_threaded(callback=self.on_transcript),
|
||||||
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
||||||
]
|
]
|
||||||
pipeline = Pipeline(*processors)
|
pipeline = Pipeline(*processors)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from .transcript_final_title import TranscriptFinalTitleProcessor  # noqa: F401
 from .transcript_liner import TranscriptLinerProcessor  # noqa: F401
 from .transcript_topic_detector import TranscriptTopicDetectorProcessor  # noqa: F401
 from .transcript_translator import TranscriptTranslatorProcessor  # noqa: F401
+from .transcript_translator_auto import TranscriptTranslatorAutoProcessor  # noqa: F401
 from .types import (  # noqa: F401
     AudioFile,
     FinalLongSummary,
@@ -10,12 +10,17 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
     INPUT_TYPE = AudioDiarizationInput
     OUTPUT_TYPE = TitleSummary
 
-    def __init__(self, **kwargs):
+    def __init__(self, modal_api_key: str | None = None, **kwargs):
         super().__init__(**kwargs)
+        if not settings.DIARIZATION_URL:
+            raise Exception(
+                "DIARIZATION_URL required to use AudioDiarizationModalProcessor"
+            )
         self.diarization_url = settings.DIARIZATION_URL + "/diarize"
-        self.headers = {
-            "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
-        }
+        self.modal_api_key = modal_api_key
+        self.headers = {}
+        if self.modal_api_key:
+            self.headers["Authorization"] = f"Bearer {self.modal_api_key}"
 
     async def _diarize(self, data: AudioDiarizationInput):
         # Gather diarization data
@@ -21,16 +21,20 @@ from reflector.settings import settings
 
 
 class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
-    def __init__(self, modal_api_key: str):
+    def __init__(self, modal_api_key: str | None = None, **kwargs):
         super().__init__()
+        if not settings.TRANSCRIPT_URL:
+            raise Exception(
+                "TRANSCRIPT_URL required to use AudioTranscriptModalProcessor"
+            )
         self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
         self.timeout = settings.TRANSCRIPT_TIMEOUT
-        self.api_key = settings.TRANSCRIPT_MODAL_API_KEY
+        self.modal_api_key = modal_api_key
 
     async def _transcript(self, data: AudioFile):
         async with AsyncOpenAI(
             base_url=self.transcript_url,
-            api_key=self.api_key,
+            api_key=self.modal_api_key,
             timeout=self.timeout,
         ) as client:
             self.logger.debug(f"Try to transcribe audio {data.name}")
@@ -12,15 +12,9 @@ from textwrap import dedent
 from typing import Type, TypeVar
 
 import structlog
-from llama_index.core import Settings
-from llama_index.core.output_parsers import PydanticOutputParser
-from llama_index.core.program import LLMTextCompletionProgram
-from llama_index.core.response_synthesizers import TreeSummarize
-from llama_index.llms.openai_like import OpenAILike
 from pydantic import BaseModel, Field
 
-from reflector.llm.base import LLM
-from reflector.llm.openai_llm import OpenAILLM
+from reflector.llm import LLM
 from reflector.settings import settings
 
 T = TypeVar("T", bound=BaseModel)
@@ -168,23 +162,12 @@ class SummaryBuilder:
         self.summaries: list[dict[str, str]] = []
         self.subjects: list[str] = []
         self.transcription_type: TranscriptionType | None = None
-        self.llm_instance: LLM = llm
+        self.llm: LLM = llm
         self.model_name: str = llm.model_name
         self.logger = logger or structlog.get_logger()
         if filename:
            self.read_transcript_from_file(filename)
 
-        Settings.llm = OpenAILike(
-            model=llm.model_name,
-            api_base=llm.url,
-            api_key=llm.api_key,
-            context_window=settings.SUMMARY_LLM_CONTEXT_SIZE_TOKENS,
-            is_chat_model=True,
-            is_function_calling_model=llm.has_structured_output,
-            temperature=llm.temperature,
-            max_tokens=llm.max_tokens,
-        )
-
     def read_transcript_from_file(self, filename: str) -> None:
         """
         Load a transcript from a text file.
@@ -202,40 +185,16 @@ class SummaryBuilder:
         self.transcript = transcript
 
     def set_llm_instance(self, llm: LLM) -> None:
-        self.llm_instance = llm
+        self.llm = llm
 
     async def _get_structured_response(
         self, prompt: str, output_cls: Type[T], tone_name: str | None = None
-    ) -> Type[T]:
+    ) -> T:
         """Generic function to get structured output from LLM for non-function-calling models."""
-        # First, use TreeSummarize to get the response
-        summarizer = TreeSummarize(verbose=True)
-
-        response = await summarizer.aget_response(
-            prompt, [self.transcript], tone_name=tone_name
+        return await self.llm.get_structured_response(
+            prompt, [self.transcript], output_cls, tone_name=tone_name
         )
-
-        # Then, use PydanticOutputParser to structure the response
-        output_parser = PydanticOutputParser(output_cls)
-
-        prompt_template_str = STRUCTURED_RESPONSE_PROMPT_TEMPLATE
-
-        program = LLMTextCompletionProgram.from_defaults(
-            output_parser=output_parser,
-            prompt_template_str=prompt_template_str,
-            verbose=False,
-        )
-
-        format_instructions = output_parser.format(
-            "Please structure the above information in the following JSON format:"
-        )
-
-        output = await program.acall(
-            analysis=str(response), format_instructions=format_instructions
-        )
-
-        return output
 
     # ----------------------------------------------------------------------------
     # Participants
     # ----------------------------------------------------------------------------
@@ -354,19 +313,18 @@ class SummaryBuilder:
     async def generate_subject_summaries(self) -> None:
         """Generate detailed summaries for each extracted subject."""
         assert self.transcript is not None
-        summarizer = TreeSummarize(verbose=False)
         summaries = []
 
         for subject in self.subjects:
             detailed_prompt = DETAILED_SUBJECT_PROMPT_TEMPLATE.format(subject=subject)
 
-            detailed_response = await summarizer.aget_response(
+            detailed_response = await self.llm.get_response(
                 detailed_prompt, [self.transcript], tone_name="Topic assistant"
             )
 
             paragraph_prompt = PARAGRAPH_SUMMARY_PROMPT
 
-            paragraph_response = await summarizer.aget_response(
+            paragraph_response = await self.llm.get_response(
                 paragraph_prompt, [str(detailed_response)], tone_name="Topic summarizer"
             )
 
@@ -377,7 +335,6 @@ class SummaryBuilder:
 
     async def generate_recap(self) -> None:
         """Generate a quick recap from the subject summaries."""
-        summarizer = TreeSummarize(verbose=True)
 
         summaries_text = "\n\n".join(
             [
@@ -388,7 +345,7 @@ class SummaryBuilder:
 
         recap_prompt = RECAP_PROMPT
 
-        recap_response = await summarizer.aget_response(
+        recap_response = await self.llm.get_response(
             recap_prompt, [summaries_text], tone_name="Recap summarizer"
         )
 
@@ -483,7 +440,7 @@ if __name__ == "__main__":
     async def main():
         # build the summary
 
-        llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
+        llm = LLM(settings=settings)
         sm = SummaryBuilder(llm=llm, filename=args.transcript)
 
         if args.subjects:
@@ -1,4 +1,4 @@
-from reflector.llm.openai_llm import OpenAILLM
+from reflector.llm import LLM
 from reflector.processors.base import Processor
 from reflector.processors.summary.summary_builder import SummaryBuilder
 from reflector.processors.types import FinalLongSummary, FinalShortSummary, TitleSummary
@@ -17,7 +17,7 @@ class TranscriptFinalSummaryProcessor(Processor):
         super().__init__(**kwargs)
         self.transcript = transcript
         self.chunks: list[TitleSummary] = []
-        self.llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
+        self.llm = LLM(settings=settings)
         self.builder = None
 
     async def _push(self, data: TitleSummary):
@@ -1,67 +1,72 @@
-from reflector.llm import LLM, LLMTaskParams
+from textwrap import dedent
+
+from reflector.llm import LLM
 from reflector.processors.base import Processor
 from reflector.processors.types import FinalTitle, TitleSummary
+from reflector.settings import settings
+from reflector.utils.text import clean_title
+
+TITLE_PROMPT = dedent(
+    """
+    Generate a concise title for this meeting based on the following topic titles.
+    Ignore casual conversation, greetings, or administrative matters.
+
+    The title must:
+    - Be maximum 10 words
+    - Use noun phrases when possible (e.g., "Q1 Budget Review" not "Reviewing the Q1 Budget")
+    - Avoid generic terms like "Team Meeting" or "Discussion"
+
+    If multiple unrelated topics were discussed, prioritize the most significant one.
+    or create a compound title (e.g., "Product Launch and Budget Planning").
+
+    <topics_discussed>
+    {titles}
+    </topics_discussed>
+
+    Do not explain, just output the meeting title as a single line.
+    """
+).strip()
 
 
 class TranscriptFinalTitleProcessor(Processor):
     """
-    Assemble all summary into a line-based json
+    Generate a final title from topic titles using LlamaIndex
     """
 
     INPUT_TYPE = TitleSummary
     OUTPUT_TYPE = FinalTitle
-    TASK = "final_title"
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.chunks: list[TitleSummary] = []
-        self.llm = LLM.get_instance()
-        self.params = LLMTaskParams.get_instance(self.TASK).task_params
+        self.llm = LLM(settings=settings, temperature=0.5, max_tokens=200)
 
     async def _push(self, data: TitleSummary):
         self.chunks.append(data)
 
-    async def get_title(self, text: str) -> dict:
+    async def get_title(self, accumulated_titles: str) -> str:
         """
-        Generate a title for the whole recording
+        Generate a title for the whole recording using LLM
         """
-        chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
-
-        if len(chunks) == 1:
-            chunk = chunks[0]
-            prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
-            title_result = await self.llm.generate(
-                prompt=prompt,
-                gen_schema=self.params.gen_schema,
-                gen_cfg=self.params.gen_cfg,
-                logger=self.logger,
-            )
-            return title_result
-        else:
-            accumulated_titles = ""
-            for chunk in chunks:
-                prompt = self.llm.create_prompt(
-                    instruct=self.params.instruct, text=chunk
-                )
-                title_result = await self.llm.generate(
-                    prompt=prompt,
-                    gen_schema=self.params.gen_schema,
-                    gen_cfg=self.params.gen_cfg,
-                    logger=self.logger,
-                )
-                accumulated_titles += title_result["title"]
-
-            return await self.get_title(accumulated_titles)
+        prompt = TITLE_PROMPT.format(titles=accumulated_titles)
+        response = await self.llm.get_response(
+            prompt,
+            [accumulated_titles],
+            tone_name="Title generator",
+        )
+
+        self.logger.info(f"Generated title response: {response}")
+
+        return response
 
     async def _flush(self):
         if not self.chunks:
             self.logger.warning("No summary to output")
             return
 
-        accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
-        title_result = await self.get_title(accumulated_titles)
-        final_title = self.llm.trim_title(title_result["title"])
-        final_title = self.llm.ensure_casing(final_title)
+        accumulated_titles = "\n".join([f"- {chunk.title}" for chunk in self.chunks])
+        title = await self.get_title(accumulated_titles)
+        title = clean_title(title)
 
-        final_title = FinalTitle(title=final_title)
+        final_title = FinalTitle(title=title)
         await self.emit(final_title)
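A minimal illustration of the new title flow above: `_flush()` joins the accumulated topic titles as a bulleted list before formatting them into `TITLE_PROMPT`. The titles below are made up for the example.

```python
# Sketch only: shows the accumulated_titles shape produced in _flush() above.
titles = ["Q1 Budget Review", "Hiring Plan", "Launch Timeline"]  # hypothetical topics
accumulated_titles = "\n".join(f"- {t}" for t in titles)
print(accumulated_titles)
# - Q1 Budget Review
# - Hiring Plan
# - Launch Timeline
```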
@@ -1,7 +1,41 @@
-from reflector.llm import LLM, LLMTaskParams
+from textwrap import dedent
+
+from pydantic import BaseModel, Field
+
+from reflector.llm import LLM
 from reflector.processors.base import Processor
 from reflector.processors.types import TitleSummary, Transcript
 from reflector.settings import settings
+from reflector.utils.text import clean_title
+
+TOPIC_PROMPT = dedent(
+    """
+    Analyze the following transcript segment and extract the main topic being discussed.
+    Focus on the substantive content and ignore small talk or administrative chatter.
+
+    Create a title that:
+    - Captures the specific subject matter being discussed
+    - Is descriptive and self-explanatory
+    - Uses professional language
+    - Is specific rather than generic
+
+    For the summary:
+    - Summarize the key points in maximum two sentences
+    - Focus on what was discussed, decided, or accomplished
+    - Be concise but informative
+
+    <transcript>
+    {text}
+    </transcript>
+    """
+).strip()
+
+
+class TopicResponse(BaseModel):
+    """Structured response for topic detection"""
+
+    title: str = Field(description="A descriptive title for the topic being discussed")
+    summary: str = Field(description="A concise 1-2 sentence summary of the discussion")
 
 
 class TranscriptTopicDetectorProcessor(Processor):
@@ -11,7 +45,6 @@ class TranscriptTopicDetectorProcessor(Processor):
 
     INPUT_TYPE = Transcript
     OUTPUT_TYPE = TitleSummary
-    TASK = "topic"
 
     def __init__(
         self, min_transcript_length: int = int(settings.MIN_TRANSCRIPT_LENGTH), **kwargs
@@ -19,8 +52,7 @@ class TranscriptTopicDetectorProcessor(Processor):
         super().__init__(**kwargs)
         self.transcript = None
         self.min_transcript_length = min_transcript_length
-        self.llm = LLM.get_instance()
-        self.params = LLMTaskParams.get_instance(self.TASK).task_params
+        self.llm = LLM(settings=settings, temperature=0.9, max_tokens=500)
 
     async def _push(self, data: Transcript):
         if self.transcript is None:
@@ -34,18 +66,15 @@ class TranscriptTopicDetectorProcessor(Processor):
             return
         await self.flush()
 
-    async def get_topic(self, text: str) -> dict:
+    async def get_topic(self, text: str) -> TopicResponse:
         """
-        Generate a topic and description for a transcription excerpt
+        Generate a topic and description for a transcription excerpt using LLM
         """
-        prompt = self.llm.create_prompt(instruct=self.params.instruct, text=text)
-        topic_result = await self.llm.generate(
-            prompt=prompt,
-            gen_schema=self.params.gen_schema,
-            gen_cfg=self.params.gen_cfg,
-            logger=self.logger,
+        prompt = TOPIC_PROMPT.format(text=text)
+        response = await self.llm.get_structured_response(
+            prompt, [text], TopicResponse, tone_name="Topic analyzer"
         )
-        return topic_result
+        return response
 
     async def _flush(self):
         if not self.transcript:
@@ -53,13 +82,13 @@ class TranscriptTopicDetectorProcessor(Processor):
 
         text = self.transcript.text
         self.logger.info(f"Topic detector got {len(text)} length transcript")
 
         topic_result = await self.get_topic(text=text)
-        title = self.llm.trim_title(topic_result["title"])
-        title = self.llm.ensure_casing(title)
+        title = clean_title(topic_result.title)
 
         summary = TitleSummary(
             title=title,
-            summary=topic_result["summary"],
+            summary=topic_result.summary,
             timestamp=self.transcript.timestamp,
             duration=self.transcript.duration,
             transcript=self.transcript,
@@ -1,9 +1,5 @@
-import httpx
-
 from reflector.processors.base import Processor
-from reflector.processors.types import Transcript, TranslationLanguages
-from reflector.settings import settings
-from reflector.utils.retry import retry
+from reflector.processors.types import Transcript
 
 
 class TranscriptTranslatorProcessor(Processor):
@@ -13,61 +9,27 @@ class TranscriptTranslatorProcessor(Processor):
 
     INPUT_TYPE = Transcript
     OUTPUT_TYPE = Transcript
-    TASK = "translate"
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.transcript = None
-        self.translate_url = settings.TRANSLATE_URL
-        self.timeout = settings.TRANSLATE_TIMEOUT
-        self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
 
     async def _push(self, data: Transcript):
         self.transcript = data
         await self.flush()
 
-    async def get_translation(self, text: str) -> str | None:
-        # FIXME this should be a processor after, as each user may want
-        # different languages
-
-        source_language = self.get_pref("audio:source_language", "en")
-        target_language = self.get_pref("audio:target_language", "en")
-        if source_language == target_language:
-            return
-
-        languages = TranslationLanguages()
-        # Only way to set the target should be the UI element like dropdown.
-        # Hence, this assert should never fail.
-        assert languages.is_supported(target_language)
-        self.logger.debug(f"Try to translate {text=}")
-        json_payload = {
-            "text": text,
-            "source_language": source_language,
-            "target_language": target_language,
-        }
-
-        async with httpx.AsyncClient() as client:
-            response = await retry(client.post)(
-                self.translate_url + "/translate",
-                headers=self.headers,
-                params=json_payload,
-                timeout=self.timeout,
-                follow_redirects=True,
-                logger=self.logger,
-            )
-            response.raise_for_status()
-            result = response.json()["text"]
-
-            # Sanity check for translation status in the result
-            if target_language in result:
-                translation = result[target_language]
-                self.logger.debug(f"Translation response: {text=}, {translation=}")
-                return translation
+    async def _translate(self, text: str) -> str | None:
+        raise NotImplementedError
 
     async def _flush(self):
         if not self.transcript:
             return
-        self.transcript.translation = await self.get_translation(
-            text=self.transcript.text
-        )
+        source_language = self.get_pref("audio:source_language", "en")
+        target_language = self.get_pref("audio:target_language", "en")
+        if source_language == target_language:
+            self.transcript.translation = None
+        else:
+            self.transcript.translation = await self._translate(self.transcript.text)
 
         await self.emit(self.transcript)
server/reflector/processors/transcript_translator_auto.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+import importlib
+
+from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
+from reflector.settings import settings
+
+
+class TranscriptTranslatorAutoProcessor(TranscriptTranslatorProcessor):
+    _registry = {}
+
+    @classmethod
+    def register(cls, name, kclass):
+        cls._registry[name] = kclass
+
+    def __new__(cls, name: str | None = None, **kwargs):
+        if name is None:
+            name = settings.TRANSLATION_BACKEND
+        if name not in cls._registry:
+            module_name = f"reflector.processors.transcript_translator_{name}"
+            importlib.import_module(module_name)
+
+        # gather specific configuration for the processor
+        # search `TRANSLATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
+        config = {}
+        name_upper = name.upper()
+        settings_prefix = "TRANSLATION_"
+        config_prefix = f"{settings_prefix}{name_upper}_"
+        for key, value in settings:
+            if key.startswith(config_prefix):
+                config_name = key[len(settings_prefix) :].lower()
+                config[config_name] = value
+
+        return cls._registry[name](**config | kwargs)
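The file above is the "auto processor" pattern mentioned in the changelog: a registry keyed by backend name, resolved from `TRANSLATION_BACKEND`, with any `TRANSLATION_<BACKEND>_*` settings forwarded to the chosen class as `<backend>_*` keyword arguments. The sketch below reproduces that resolution logic in isolation so it can be run without the reflector package; `FakeSettings`, `AutoFactory`, and `EchoTranslator` are hypothetical stand-ins, not reflector classes.

```python
# Standalone sketch of the registry + settings-prefix resolution used above.
class FakeSettings:
    TRANSLATION_BACKEND = "echo"
    TRANSLATION_ECHO_GREETING = "hola"  # hypothetical backend-specific key


class AutoFactory:
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        cls._registry[name] = kclass

    @classmethod
    def create(cls, settings, name=None, **kwargs):
        if name is None:
            name = settings.TRANSLATION_BACKEND
        # collect TRANSLATION_<NAME>_* settings, pass them as <name>_* kwargs
        prefix = f"TRANSLATION_{name.upper()}_"
        config = {
            key[len("TRANSLATION_") :].lower(): value
            for key, value in vars(settings).items()
            if key.startswith(prefix)
        }
        return cls._registry[name](**config | kwargs)


class EchoTranslator:
    def __init__(self, echo_greeting="hello"):
        self.echo_greeting = echo_greeting

    def translate(self, text):
        return f"{self.echo_greeting}: {text}"


AutoFactory.register("echo", EchoTranslator)
print(AutoFactory.create(FakeSettings).translate("world"))  # -> hola: world
```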
server/reflector/processors/transcript_translator_modal.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+import httpx
+
+from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
+from reflector.processors.transcript_translator_auto import (
+    TranscriptTranslatorAutoProcessor,
+)
+from reflector.processors.types import TranslationLanguages
+from reflector.settings import settings
+from reflector.utils.retry import retry
+
+
+class TranscriptTranslatorModalProcessor(TranscriptTranslatorProcessor):
+    """
+    Translate the transcript into the target language using Modal.com
+    """
+
+    def __init__(self, modal_api_key: str | None = None, **kwargs):
+        super().__init__(**kwargs)
+        if not settings.TRANSLATE_URL:
+            raise Exception(
+                "TRANSLATE_URL is required for TranscriptTranslatorModalProcessor"
+            )
+        self.translate_url = settings.TRANSLATE_URL
+        self.timeout = settings.TRANSLATE_TIMEOUT
+        self.modal_api_key = modal_api_key
+        self.headers = {}
+        if self.modal_api_key:
+            self.headers["Authorization"] = f"Bearer {self.modal_api_key}"
+
+    async def _translate(self, text: str) -> str | None:
+        source_language = self.get_pref("audio:source_language", "en")
+        target_language = self.get_pref("audio:target_language", "en")
+
+        languages = TranslationLanguages()
+        # Only way to set the target should be the UI element like dropdown.
+        # Hence, this assert should never fail.
+        assert languages.is_supported(target_language)
+        self.logger.debug(f"Try to translate {text=}")
+        json_payload = {
+            "text": text,
+            "source_language": source_language,
+            "target_language": target_language,
+        }
+
+        async with httpx.AsyncClient() as client:
+            response = await retry(client.post)(
+                self.translate_url + "/translate",
+                headers=self.headers,
+                params=json_payload,
+                timeout=self.timeout,
+                follow_redirects=True,
+                logger=self.logger,
+            )
+            response.raise_for_status()
+            result = response.json()["text"]
+
+            # Sanity check for translation status in the result
+            if target_language in result:
+                translation = result[target_language]
+            else:
+                translation = None
+            self.logger.debug(f"Translation response: {text=}, {translation=}")
+            return translation
+
+
+TranscriptTranslatorAutoProcessor.register("modal", TranscriptTranslatorModalProcessor)
@@ -0,0 +1,14 @@
+from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
+from reflector.processors.transcript_translator_auto import (
+    TranscriptTranslatorAutoProcessor,
+)
+
+
+class TranscriptTranslatorPassthroughProcessor(TranscriptTranslatorProcessor):
+    async def _translate(self, text: str) -> None:
+        return None
+
+
+TranscriptTranslatorAutoProcessor.register(
+    "passthrough", TranscriptTranslatorPassthroughProcessor
+)
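With the base class only calling `_translate()` when the source and target languages differ, adding another translation backend amounts to one subclass plus one `register()` call. A hedged sketch, assuming the reflector modules introduced in this diff are importable; the `UppercaseTranslatorProcessor` class and the "uppercase" backend key are hypothetical:

```python
# Sketch: a custom translation backend on top of the classes added above.
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
from reflector.processors.transcript_translator_auto import (
    TranscriptTranslatorAutoProcessor,
)


class UppercaseTranslatorProcessor(TranscriptTranslatorProcessor):
    async def _translate(self, text: str) -> str | None:
        # stand-in "translation"; only invoked when source != target language
        return text.upper()


TranscriptTranslatorAutoProcessor.register("uppercase", UppercaseTranslatorProcessor)

# With TRANSLATION_BACKEND=uppercase set, pipelines that instantiate
# TranscriptTranslatorAutoProcessor would resolve to this class.
```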
@@ -9,13 +9,14 @@ class Settings(BaseSettings):
     )
 
     # CORS
+    UI_BASE_URL: str = "http://localhost:3000"
     CORS_ORIGIN: str = "*"
     CORS_ALLOW_CREDENTIALS: bool = False
 
     # Database
     DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
 
-    # local data directory (audio for no)
+    # local data directory
     DATA_DIR: str = "./data"
 
     # Audio Transcription
@@ -24,11 +25,7 @@ class Settings(BaseSettings):
     TRANSCRIPT_URL: str | None = None
     TRANSCRIPT_TIMEOUT: int = 90
 
-    # Translate into the target language
-    TRANSLATE_URL: str | None = None
-    TRANSLATE_TIMEOUT: int = 90
-
-    # Audio transcription modal.com configuration
+    # Audio Transcription: modal backend
     TRANSCRIPT_MODAL_API_KEY: str | None = None
 
     # Audio transcription storage
@@ -40,37 +37,28 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
 
+    # Translate into the target language
+    TRANSLATION_BACKEND: str = "passthrough"
+    TRANSLATE_URL: str | None = None
+    TRANSLATE_TIMEOUT: int = 90
+
+    # Translation: modal backend
+    TRANSLATE_MODAL_API_KEY: str | None = None
+
     # LLM
-    # available backend: openai, modal
-    LLM_BACKEND: str = "modal"
-
-    # LLM common configuration
+    LLM_MODEL: str = "microsoft/phi-4"
     LLM_URL: str | None = None
-    LLM_HOST: str = "localhost"
-    LLM_PORT: int = 7860
-    LLM_OPENAI_KEY: str | None = None
-    LLM_OPENAI_MODEL: str = "gpt-3.5-turbo"
-    LLM_OPENAI_TEMPERATURE: float = 0.7
-    LLM_TIMEOUT: int = 60 * 5  # take cold start into account
-    LLM_MAX_TOKENS: int = 1024
-    LLM_TEMPERATURE: float = 0.7
-    ZEPHYR_LLM_URL: str | None = None
-    HERMES_3_8B_LLM_URL: str | None = None
-
-    # LLM Modal configuration
-    LLM_MODAL_API_KEY: str | None = None
-
-    # per-task cases
-    SUMMARY_MODEL: str = "monadical/private/smart"
-    SUMMARY_LLM_URL: str | None = None
-    SUMMARY_LLM_API_KEY: str | None = None
-    SUMMARY_LLM_CONTEXT_SIZE_TOKENS: int = 16000
+    LLM_API_KEY: str | None = None
+    LLM_CONTEXT_WINDOW: int = 16000
 
     # Diarization
     DIARIZATION_ENABLED: bool = True
     DIARIZATION_BACKEND: str = "modal"
     DIARIZATION_URL: str | None = None
 
+    # Diarization: modal backend
+    DIARIZATION_MODAL_API_KEY: str | None = None
+
     # Sentry
     SENTRY_DSN: str | None = None
 
@@ -86,12 +74,6 @@ class Settings(BaseSettings):
     # if set, all anonymous record will be public
     PUBLIC_MODE: bool = False
 
-    # Default LLM model name
-    DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5"
-
-    # Cache directory for all model storage
-    CACHE_DIR: str = "./data"
-
     # Min transcript length to generate topic + summary
     MIN_TRANSCRIPT_LENGTH: int = 750
 
@@ -116,24 +98,20 @@ class Settings(BaseSettings):
     # Healthcheck
     HEALTHCHECK_URL: str | None = None
 
-    AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
-    SQS_POLLING_TIMEOUT_SECONDS: int = 60
-
+    # Whereby integration
     WHEREBY_API_URL: str = "https://api.whereby.dev/v1"
 
     WHEREBY_API_KEY: str | None = None
+    WHEREBY_WEBHOOK_SECRET: str | None = None
     AWS_WHEREBY_S3_BUCKET: str | None = None
     AWS_WHEREBY_ACCESS_KEY_ID: str | None = None
     AWS_WHEREBY_ACCESS_KEY_SECRET: str | None = None
+    AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
+    SQS_POLLING_TIMEOUT_SECONDS: int = 60
 
+    # Zulip integration
     ZULIP_REALM: str | None = None
     ZULIP_API_KEY: str | None = None
     ZULIP_BOT_EMAIL: str | None = None
 
-    UI_BASE_URL: str = "http://localhost:3000"
-
-    WHEREBY_WEBHOOK_SECRET: str | None = None
-
 
 settings = Settings()
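For reference, the environment keys touched by the settings hunks above can be sketched as follows; all values are placeholders, and in practice they would live in the server `.env` file rather than be set in code.

```python
# Sketch only: the key names mirror the Settings fields added or kept above.
import os

os.environ.update(
    {
        "TRANSLATION_BACKEND": "modal",            # default is "passthrough"
        "TRANSLATE_URL": "https://example.modal.run",  # placeholder endpoint
        "TRANSLATE_MODAL_API_KEY": "placeholder",
        "TRANSCRIPT_MODAL_API_KEY": "placeholder",
        "DIARIZATION_MODAL_API_KEY": "placeholder",
    }
)
```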
@@ -13,7 +13,7 @@ from reflector.processors import (
|
|||||||
TranscriptFinalTitleProcessor,
|
TranscriptFinalTitleProcessor,
|
||||||
TranscriptLinerProcessor,
|
TranscriptLinerProcessor,
|
||||||
TranscriptTopicDetectorProcessor,
|
TranscriptTopicDetectorProcessor,
|
||||||
TranscriptTranslatorProcessor,
|
TranscriptTranslatorAutoProcessor,
|
||||||
)
|
)
|
||||||
from reflector.processors.base import BroadcastProcessor
|
from reflector.processors.base import BroadcastProcessor
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ async def process_audio_file(
|
|||||||
AudioMergeProcessor(),
|
AudioMergeProcessor(),
|
||||||
AudioTranscriptAutoProcessor.as_threaded(),
|
AudioTranscriptAutoProcessor.as_threaded(),
|
||||||
TranscriptLinerProcessor(),
|
TranscriptLinerProcessor(),
|
||||||
TranscriptTranslatorProcessor.as_threaded(),
|
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||||
]
|
]
|
||||||
if not only_transcript:
|
if not only_transcript:
|
||||||
processors += [
|
processors += [
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from reflector.processors import (
|
|||||||
TranscriptFinalTitleProcessor,
|
TranscriptFinalTitleProcessor,
|
||||||
TranscriptLinerProcessor,
|
TranscriptLinerProcessor,
|
||||||
TranscriptTopicDetectorProcessor,
|
TranscriptTopicDetectorProcessor,
|
||||||
TranscriptTranslatorProcessor,
|
TranscriptTranslatorAutoProcessor,
|
||||||
)
|
)
|
||||||
from reflector.processors.base import BroadcastProcessor, Processor
|
from reflector.processors.base import BroadcastProcessor, Processor
|
||||||
from reflector.processors.types import (
|
from reflector.processors.types import (
|
||||||
@@ -103,7 +103,7 @@ async def process_audio_file_with_diarization(
|
|||||||
|
|
||||||
processors += [
|
processors += [
|
||||||
TranscriptLinerProcessor(),
|
TranscriptLinerProcessor(),
|
||||||
TranscriptTranslatorProcessor.as_threaded(),
|
TranscriptTranslatorAutoProcessor.as_threaded(),
|
||||||
]
|
]
|
||||||
|
|
||||||
if not only_transcript:
|
if not only_transcript:
|
||||||
|
|||||||
server/reflector/utils/text.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+def clean_title(title: str) -> str:
+    """
+    Clean and format a title string for consistent capitalization.
+
+    Rules:
+    - Strip surrounding quotes (single or double)
+    - Capitalize the first word
+    - Capitalize words longer than 3 characters
+    - Keep words with 3 or fewer characters lowercase (except first word)
+
+    Args:
+        title: The title string to clean
+
+    Returns:
+        The cleaned title with consistent capitalization
+
+    Examples:
+        >>> clean_title("hello world")
+        "Hello World"
+        >>> clean_title("meeting with the team")
+        "Meeting With the Team"
+        >>> clean_title("'Title with quotes'")
+        "Title With Quotes"
+    """
+    title = title.strip("\"'")
+    words = title.split()
+    if words:
+        words = [
+            word.capitalize() if i == 0 or len(word) > 3 else word.lower()
+            for i, word in enumerate(words)
+        ]
+        title = " ".join(words)
+    return title
@@ -7,14 +7,10 @@ import pytest
 @pytest.fixture(scope="function", autouse=True)
 @pytest.mark.asyncio
 async def setup_database():
-    from reflector.settings import settings
+    from reflector.db import engine, metadata  # noqa
 
-    with NamedTemporaryFile() as f:
-        settings.DATABASE_URL = f"sqlite:///{f.name}"
-        from reflector.db import engine, metadata
-
+    metadata.drop_all(bind=engine)
     metadata.create_all(bind=engine)
 
     yield
 
@@ -33,17 +29,16 @@ def dummy_processors():
         patch(
             "reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_short_summary"
         ) as mock_short_summary,
-        patch(
-            "reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation"
-        ) as mock_translate,
     ):
-        mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
-        mock_title.return_value = {"title": "LLM TITLE"}
+        from reflector.processors.transcript_topic_detector import TopicResponse
+
+        mock_topic.return_value = TopicResponse(
+            title="LLM TITLE", summary="LLM SUMMARY"
+        )
+        mock_title.return_value = "LLM Title"
         mock_long_summary.return_value = "LLM LONG SUMMARY"
         mock_short_summary.return_value = "LLM SHORT SUMMARY"
-        mock_translate.return_value = "Bonjour le monde"
         yield (
-            mock_translate,
             mock_topic,
             mock_title,
             mock_long_summary,
@@ -101,16 +96,38 @@ async def dummy_diarization():
     yield
 
 
+@pytest.fixture
+async def dummy_transcript_translator():
+    from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
+
+    class TestTranscriptTranslatorProcessor(TranscriptTranslatorProcessor):
+        async def _translate(self, text: str) -> str:
+            source_language = self.get_pref("audio:source_language", "en")
+            target_language = self.get_pref("audio:target_language", "en")
+            return f"{source_language}:{target_language}:{text}"
+
+    def mock_new(cls, *args, **kwargs):
+        return TestTranscriptTranslatorProcessor(*args, **kwargs)
+
+    with patch(
+        "reflector.processors.transcript_translator_auto"
+        ".TranscriptTranslatorAutoProcessor.__new__",
+        mock_new,
+    ):
+        yield
+
+
 @pytest.fixture
 async def dummy_llm():
-    from reflector.llm.base import LLM
+    from reflector.llm import LLM
 
     class TestLLM(LLM):
         def __init__(self):
             self.model_name = "DUMMY MODEL"
             self.llm_tokenizer = "DUMMY TOKENIZER"
 
-    with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
+    # LLM doesn't have get_instance anymore, mocking constructor instead
+    with patch("reflector.llm.LLM") as mock_llm:
         mock_llm.return_value = TestLLM()
         yield
 
@@ -129,22 +146,19 @@ async def dummy_storage():
         async def _get_file_url(self, *args, **kwargs):
             return "http://fake_server/audio.mp3"
 
-    with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
-        mock_storage.return_value = DummyStorage()
-        yield
-
-
-@pytest.fixture
-def nltk():
-    with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
-        mock_nltk.return_value = "NLTK PACKAGE"
-        yield
-
-
-@pytest.fixture
-def ensure_casing():
-    with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
-        mock_casing.return_value = "LLM TITLE"
+        async def _get_file(self, *args, **kwargs):
+            from pathlib import Path
+
+            test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
+            return test_mp3.read_bytes()
+
+    dummy = DummyStorage()
+    with (
+        patch("reflector.storage.base.Storage.get_instance") as mock_storage,
+        patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
+    ):
+        mock_storage.return_value = dummy
+        mock_get_transcripts.return_value = dummy
         yield
@@ -2,7 +2,7 @@ import pytest
 
 
 @pytest.mark.asyncio
-async def test_processor_broadcast(nltk):
+async def test_processor_broadcast():
     from reflector.processors.base import BroadcastProcessor, Pipeline, Processor
 
     class TestProcessor(Processor):
@@ -3,11 +3,9 @@ import pytest
 
 @pytest.mark.asyncio
 async def test_basic_process(
-    nltk,
     dummy_transcript,
     dummy_llm,
     dummy_processors,
-    ensure_casing,
 ):
     # goal is to start the server, and send rtc audio to it
     # validate the events received
@@ -16,8 +14,8 @@ async def test_basic_process(
     from reflector.settings import settings
     from reflector.tools.process import process_audio_file
 
-    # use an LLM test backend
-    settings.LLM_BACKEND = "test"
+    # LLM_BACKEND no longer exists in settings
+    # settings.LLM_BACKEND = "test"
     settings.TRANSCRIPT_BACKEND = "whisper"
 
     # event callback
@@ -35,7 +33,7 @@ async def test_basic_process(
 
     # validate the events
     assert marks["TranscriptLinerProcessor"] == 1
-    assert marks["TranscriptTranslatorProcessor"] == 1
+    assert marks["TranscriptTranslatorPassthroughProcessor"] == 1
     assert marks["TranscriptTopicDetectorProcessor"] == 1
     assert marks["TranscriptFinalSummaryProcessor"] == 1
     assert marks["TranscriptFinalTitleProcessor"] == 1
@@ -10,7 +10,6 @@ from httpx import AsyncClient
 @pytest.mark.asyncio
 async def test_transcript_process(
     tmpdir,
-    ensure_casing,
     dummy_llm,
     dummy_processors,
     dummy_diarization,
@@ -69,7 +68,7 @@ async def test_transcript_process(
     transcript = resp.json()
     assert transcript["status"] == "ended"
     assert transcript["short_summary"] == "LLM SHORT SUMMARY"
-    assert transcript["title"] == "LLM TITLE"
+    assert transcript["title"] == "Llm Title"
 
     # check topics and transcript
     response = await ac.get(f"/transcripts/{tid}/topics")
@@ -67,10 +67,9 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
dummy_transcript,
|
dummy_transcript,
|
||||||
dummy_processors,
|
dummy_processors,
|
||||||
dummy_diarization,
|
dummy_diarization,
|
||||||
|
dummy_transcript_translator,
|
||||||
dummy_storage,
|
dummy_storage,
|
||||||
fake_mp3_upload,
|
fake_mp3_upload,
|
||||||
ensure_casing,
|
|
||||||
nltk,
|
|
||||||
appserver,
|
appserver,
|
||||||
):
|
):
|
||||||
# goal: start the server, exchange RTC, receive websocket events
|
# goal: start the server, exchange RTC, receive websocket events
|
||||||
@@ -166,7 +165,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
assert "TRANSCRIPT" in eventnames
|
assert "TRANSCRIPT" in eventnames
|
||||||
ev = events[eventnames.index("TRANSCRIPT")]
|
ev = events[eventnames.index("TRANSCRIPT")]
|
||||||
assert ev["data"]["text"].startswith("Hello world.")
|
assert ev["data"]["text"].startswith("Hello world.")
|
||||||
assert ev["data"]["translation"] == "Bonjour le monde"
|
assert ev["data"]["translation"] is None
|
||||||
|
|
||||||
assert "TOPIC" in eventnames
|
assert "TOPIC" in eventnames
|
||||||
ev = events[eventnames.index("TOPIC")]
|
ev = events[eventnames.index("TOPIC")]
|
||||||
@@ -185,7 +184,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
|
|
||||||
assert "FINAL_TITLE" in eventnames
|
assert "FINAL_TITLE" in eventnames
|
||||||
ev = events[eventnames.index("FINAL_TITLE")]
|
ev = events[eventnames.index("FINAL_TITLE")]
|
||||||
assert ev["data"]["title"] == "LLM TITLE"
|
assert ev["data"]["title"] == "Llm Title"
|
||||||
|
|
||||||
assert "WAVEFORM" in eventnames
|
assert "WAVEFORM" in eventnames
|
||||||
ev = events[eventnames.index("WAVEFORM")]
|
ev = events[eventnames.index("WAVEFORM")]
|
||||||
@@ -226,10 +225,9 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
dummy_transcript,
|
dummy_transcript,
|
||||||
dummy_processors,
|
dummy_processors,
|
||||||
dummy_diarization,
|
dummy_diarization,
|
||||||
|
dummy_transcript_translator,
|
||||||
dummy_storage,
|
dummy_storage,
|
||||||
fake_mp3_upload,
|
fake_mp3_upload,
|
||||||
ensure_casing,
|
|
||||||
nltk,
|
|
||||||
appserver,
|
appserver,
|
||||||
):
|
):
|
||||||
# goal: start the server, exchange RTC, receive websocket events
|
# goal: start the server, exchange RTC, receive websocket events
|
||||||
@@ -334,7 +332,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
assert "TRANSCRIPT" in eventnames
|
assert "TRANSCRIPT" in eventnames
|
||||||
ev = events[eventnames.index("TRANSCRIPT")]
|
ev = events[eventnames.index("TRANSCRIPT")]
|
||||||
assert ev["data"]["text"].startswith("Hello world.")
|
assert ev["data"]["text"].startswith("Hello world.")
|
||||||
assert ev["data"]["translation"] == "Bonjour le monde"
|
assert ev["data"]["translation"] == "en:fr:Hello world."
|
||||||
|
|
||||||
assert "TOPIC" in eventnames
|
assert "TOPIC" in eventnames
|
||||||
ev = events[eventnames.index("TOPIC")]
|
ev = events[eventnames.index("TOPIC")]
|
||||||
@@ -353,7 +351,7 @@ async def test_transcript_rtc_and_websocket_and_fr(

     assert "FINAL_TITLE" in eventnames
     ev = events[eventnames.index("FINAL_TITLE")]
-    assert ev["data"]["title"] == "LLM TITLE"
+    assert ev["data"]["title"] == "Llm Title"

     # check status order
     statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
@@ -10,7 +10,6 @@ from httpx import AsyncClient
 @pytest.mark.asyncio
 async def test_transcript_upload_file(
     tmpdir,
-    ensure_casing,
     dummy_llm,
     dummy_processors,
     dummy_diarization,
@@ -53,7 +52,7 @@ async def test_transcript_upload_file(
     transcript = resp.json()
     assert transcript["status"] == "ended"
     assert transcript["short_summary"] == "LLM SHORT SUMMARY"
-    assert transcript["title"] == "LLM TITLE"
+    assert transcript["title"] == "Llm Title"

     # check topics and transcript
     response = await ac.get(f"/transcripts/{tid}/topics")
21 server/tests/test_utils_text.py Normal file
@@ -0,0 +1,21 @@
+import pytest
+
+from reflector.utils.text import clean_title
+
+
+@pytest.mark.parametrize(
+    "input_title,expected",
+    [
+        ("hello world", "Hello World"),
+        ("HELLO WORLD", "Hello World"),
+        ("hello WORLD", "Hello World"),
+        ("the quick brown fox", "The Quick Brown fox"),
+        ("discussion about API design", "Discussion About api Design"),
+        ("Q1 2024 budget review", "Q1 2024 Budget Review"),
+        ("'Title with quotes'", "Title With Quotes"),
+        ("'title with quotes'", "Title With Quotes"),
+        ("MiXeD CaSe WoRdS", "Mixed Case Words"),
+    ],
+)
+def test_clean_title(input_title, expected):
+    assert clean_title(input_title) == expected
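The new test pins down clean_title only through these fixtures; the diff does not include the implementation itself. As a rough illustration, a minimal sketch consistent with the cases above could strip surrounding quotes, lowercase the input, and capitalize the first word plus any word longer than three characters (which is why "fox" and "api" stay lowercase). This is an inference from the test data, not the actual code in reflector/utils/text.py.

def clean_title(title: str) -> str:
    # Hypothetical sketch, derived only from the parametrized cases above.
    # Drop surrounding whitespace and a stray pair of quotes.
    title = title.strip().strip("'\"")
    words = []
    for index, word in enumerate(title.split()):
        lowered = word.lower()
        # Capitalize the first word and any word longer than three characters;
        # short words such as "fox" or "api" stay lowercase, as in the fixtures.
        if index == 0 or len(lowered) > 3:
            lowered = lowered.capitalize()
        words.append(lowered)
    return " ".join(words)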
14 server/uv.lock generated
@@ -2428,6 +2428,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/57/79/9dae84c244dabebca6a952e098d6ac9d13719b701fc5323ba6d00abc675a/pytest_docker_tools-3.1.9-py2.py3-none-any.whl", hash = "sha256:36f8e88d56d84ea177df68a175673681243dd991d2807fbf551d90f60341bfdb", size = 29268, upload-time = "2025-03-16T13:48:22.184Z" },
 ]

+[[package]]
+name = "pytest-env"
+version = "1.1.5"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/1f/31/27f28431a16b83cab7a636dce59cf397517807d247caa38ee67d65e71ef8/pytest_env-1.1.5.tar.gz", hash = "sha256:91209840aa0e43385073ac464a554ad2947cc2fd663a9debf88d03b01e0cc1cf", size = 8911, upload-time = "2024-09-17T22:39:18.566Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/de/b8/87cfb16045c9d4092cfcf526135d73b88101aac83bc1adcf82dfb5fd3833/pytest_env-1.1.5-py3-none-any.whl", hash = "sha256:ce90cf8772878515c24b31cd97c7fa1f4481cd68d588419fd45f10ecaee6bc30", size = 6141, upload-time = "2024-09-17T22:39:16.942Z" },
+]
+
 [[package]]
 name = "pytest-httpx"
 version = "0.34.0"
@@ -2636,6 +2648,7 @@ dependencies = [
     { name = "protobuf" },
     { name = "psycopg2-binary" },
     { name = "pydantic-settings" },
+    { name = "pytest-env" },
     { name = "python-jose", extra = ["cryptography"] },
     { name = "python-multipart" },
     { name = "redis" },
@@ -2699,6 +2712,7 @@ requires-dist = [
     { name = "protobuf", specifier = ">=4.24.3" },
     { name = "psycopg2-binary", specifier = ">=2.9.10" },
     { name = "pydantic-settings", specifier = ">=2.0.2" },
+    { name = "pytest-env", specifier = ">=1.1.5" },
     { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
     { name = "python-multipart", specifier = ">=0.0.6" },
     { name = "redis", specifier = ">=5.0.1" },
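pytest-env, added above as a test dependency, exports environment variables declared under the env key of the pytest configuration (for example in [tool.pytest.ini_options] of pyproject.toml) before the test session starts, so settings read at import time see test values. The snippet below is only an illustration of that behaviour; "SOME_SETTING" is a placeholder, and the variables this repo actually pins are not shown in this diff.

import os


def test_env_var_is_pinned():
    # Assumes the pytest configuration contains env = ["SOME_SETTING=test-value"];
    # with pytest-env installed, the variable is already set when this test runs.
    assert os.environ.get("SOME_SETTING") == "test-value"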
@@ -1,86 +0,0 @@
-# Chakra UI v3 Migration - Remaining Tasks
-
-## Completed
-
-- ✅ Migrated from Chakra UI v2 to v3 in package.json
-- ✅ Updated theme.ts with whiteAlpha color palette and semantic tokens
-- ✅ Added button recipe with fontWeight 600 and hover states
-- ✅ Moved Poppins font from theme to HTML tag className
-- ✅ Fixed deprecated props across all files:
-  - ✅ `isDisabled` → `disabled` (all occurrences fixed)
-  - ✅ `isChecked` → `checked` (all occurrences fixed)
-  - ✅ `isLoading` → `loading` (all occurrences fixed)
-  - ✅ `isOpen` → `open` (all occurrences fixed)
-  - ✅ `noOfLines` → `lineClamp` (all occurrences fixed)
-  - ✅ `align` → `alignItems` on Flex/Stack components (all occurrences fixed)
-  - ✅ `justify` → `justifyContent` on Flex/Stack components (all occurrences fixed)
-
-## Migration Summary
-
-### Files Modified
-
-1. **app/(app)/rooms/page.tsx**
-   - Fixed: isDisabled, isChecked, align, justify on multiple components
-   - Updated temporary Select component props
-
-2. **app/(app)/transcripts/fileUploadButton.tsx**
-   - Fixed: isDisabled → disabled
-
-3. **app/(app)/transcripts/shareZulip.tsx**
-   - Fixed: isDisabled → disabled
-
-4. **app/(app)/transcripts/shareAndPrivacy.tsx**
-   - Fixed: isLoading → loading, isOpen → open
-   - Updated temporary Select component props
-
-5. **app/(app)/browse/page.tsx**
-   - Fixed: isOpen → open, align → alignItems, justify → justifyContent
-
-6. **app/(app)/transcripts/transcriptTitle.tsx**
-   - Fixed: noOfLines → lineClamp
-
-7. **app/(app)/transcripts/[transcriptId]/correct/topicHeader.tsx**
-   - Fixed: noOfLines → lineClamp
-
-8. **app/lib/expandableText.tsx**
-   - Fixed: noOfLines → lineClamp
-
-9. **app/[roomName]/page.tsx**
-   - Fixed: align → alignItems, justify → justifyContent
-
-10. **app/lib/WherebyWebinarEmbed.tsx**
-    - Fixed: align → alignItems, justify → justifyContent
-
-## Other Potential Issues
-
-1. Check for Modal/Dialog component imports and usage (currently using temporary replacements)
-2. Review Select component usage (using temporary replacements)
-3. Test button hover states for whiteAlpha color palette
-4. Verify all color palettes work correctly with the new semantic tokens
-
-## Testing
-
-After completing migrations:
-
-1. Run `yarn dev` and check all pages
-2. Test buttons with different color palettes
-3. Verify disabled states work correctly
-4. Check that text alignment and flex layouts are correct
-5. Test modal/dialog functionality
-
-## Next Steps
-
-The Chakra UI v3 migration is now largely complete for deprecated props. The main remaining items are:
-
-- Replace temporary Modal and Select components with proper Chakra v3 implementations
-- Thorough testing of all UI components
-- Performance optimization if needed