mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
feat: use llamaindex everywhere (#525)
* feat: use llamaindex for transcript final title too * refactor: removed llm backend, replaced with one single class+llamaindex * refactor: self-review * fix: typing * fix: tests * refactor: extract clean_title and add tests * test: fix * test: remove ensure_casing/nltk * fix: tiny mistake
This commit is contained in:
@@ -46,38 +46,11 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
|
|||||||
## llm backend implementation
|
## llm backend implementation
|
||||||
## =======================================================
|
## =======================================================
|
||||||
|
|
||||||
## Using serverless modal.com (require reflector-gpu-modal deployed)
|
|
||||||
LLM_BACKEND=modal
|
|
||||||
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
|
|
||||||
LLM_MODAL_API_KEY=
|
|
||||||
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
|
|
||||||
|
|
||||||
|
|
||||||
## Using OpenAI
|
|
||||||
#LLM_BACKEND=openai
|
|
||||||
#LLM_OPENAI_KEY=xxx
|
|
||||||
#LLM_OPENAI_MODEL=gpt-3.5-turbo
|
|
||||||
|
|
||||||
## Using GPT4ALL
|
|
||||||
#LLM_BACKEND=openai
|
|
||||||
#LLM_URL=http://localhost:4891/v1/completions
|
|
||||||
#LLM_OPENAI_MODEL="GPT4All Falcon"
|
|
||||||
|
|
||||||
## Default LLM MODEL NAME
|
|
||||||
#DEFAULT_LLM=lmsys/vicuna-13b-v1.5
|
|
||||||
|
|
||||||
## Cache directory to store models
|
|
||||||
CACHE_DIR=data
|
|
||||||
|
|
||||||
## =======================================================
|
|
||||||
## Summary LLM configuration
|
|
||||||
## =======================================================
|
|
||||||
|
|
||||||
## Context size for summary generation (tokens)
|
## Context size for summary generation (tokens)
|
||||||
SUMMARY_LLM_CONTEXT_SIZE_TOKENS=16000
|
# LLM_MODEL=microsoft/phi-4
|
||||||
SUMMARY_LLM_URL=
|
LLM_CONTEXT_WINDOW=16000
|
||||||
SUMMARY_LLM_API_KEY=sk-
|
LLM_URL=
|
||||||
SUMMARY_MODEL=
|
LLM_API_KEY=sk-
|
||||||
|
|
||||||
## =======================================================
|
## =======================================================
|
||||||
## Diarization
|
## Diarization
|
||||||
|
|||||||
@@ -3,8 +3,9 @@
|
|||||||
This repository hold an API for the GPU implementation of the Reflector API service,
|
This repository hold an API for the GPU implementation of the Reflector API service,
|
||||||
and use [Modal.com](https://modal.com)
|
and use [Modal.com](https://modal.com)
|
||||||
|
|
||||||
- `reflector_llm.py` - LLM API
|
- `reflector_diarizer.py` - Diarization API
|
||||||
- `reflector_transcriber.py` - Transcription API
|
- `reflector_transcriber.py` - Transcription API
|
||||||
|
- `reflector_translator.py` - Translation API
|
||||||
|
|
||||||
## Modal.com deployment
|
## Modal.com deployment
|
||||||
|
|
||||||
|
|||||||
@@ -1,213 +0,0 @@
|
|||||||
"""
|
|
||||||
Reflector GPU backend - LLM
|
|
||||||
===========================
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from modal import App, Image, Secret, asgi_app, enter, exit, method
|
|
||||||
|
|
||||||
# LLM
|
|
||||||
LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
|
|
||||||
LLM_LOW_CPU_MEM_USAGE: bool = True
|
|
||||||
LLM_TORCH_DTYPE: str = "bfloat16"
|
|
||||||
LLM_MAX_NEW_TOKENS: int = 300
|
|
||||||
|
|
||||||
IMAGE_MODEL_DIR = "/root/llm_models"
|
|
||||||
|
|
||||||
app = App(name="reflector-llm")
|
|
||||||
|
|
||||||
|
|
||||||
def download_llm():
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
print("Downloading LLM model")
|
|
||||||
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
|
|
||||||
print("LLM model downloaded")
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_cache_llm():
|
|
||||||
"""
|
|
||||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
|
||||||
Migrating your old cache. This is a one-time only operation. You can
|
|
||||||
interrupt this and resume the migration later on by calling
|
|
||||||
`transformers.utils.move_cache()`.
|
|
||||||
"""
|
|
||||||
from transformers.utils.hub import move_cache
|
|
||||||
|
|
||||||
print("Moving LLM cache")
|
|
||||||
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
|
|
||||||
print("LLM cache moved")
|
|
||||||
|
|
||||||
|
|
||||||
llm_image = (
|
|
||||||
Image.debian_slim(python_version="3.10.8")
|
|
||||||
.apt_install("git")
|
|
||||||
.pip_install(
|
|
||||||
"transformers",
|
|
||||||
"torch",
|
|
||||||
"sentencepiece",
|
|
||||||
"protobuf",
|
|
||||||
"jsonformer==0.12.0",
|
|
||||||
"accelerate==0.21.0",
|
|
||||||
"einops==0.6.1",
|
|
||||||
"hf-transfer~=0.1",
|
|
||||||
"huggingface_hub==0.16.4",
|
|
||||||
)
|
|
||||||
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
|
|
||||||
.run_function(download_llm)
|
|
||||||
.run_function(migrate_cache_llm)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.cls(
|
|
||||||
gpu="A100",
|
|
||||||
timeout=60 * 5,
|
|
||||||
scaledown_window=60 * 5,
|
|
||||||
allow_concurrent_inputs=15,
|
|
||||||
image=llm_image,
|
|
||||||
)
|
|
||||||
class LLM:
|
|
||||||
@enter()
|
|
||||||
def enter(self):
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
||||||
|
|
||||||
print("Instance llm model")
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
LLM_MODEL,
|
|
||||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
|
||||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
|
||||||
cache_dir=IMAGE_MODEL_DIR,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# JSONFormer doesn't yet support generation configs
|
|
||||||
print("Instance llm generation config")
|
|
||||||
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
|
|
||||||
|
|
||||||
# generation configuration
|
|
||||||
gen_cfg = GenerationConfig.from_model_config(model.config)
|
|
||||||
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
|
|
||||||
|
|
||||||
# load tokenizer
|
|
||||||
print("Instance llm tokenizer")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# move model to gpu
|
|
||||||
print("Move llm model to GPU")
|
|
||||||
model = model.cuda()
|
|
||||||
|
|
||||||
print("Warmup llm done")
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.gen_cfg = gen_cfg
|
|
||||||
self.GenerationConfig = GenerationConfig
|
|
||||||
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
|
|
||||||
@exit()
|
|
||||||
def exit():
|
|
||||||
print("Exit llm")
|
|
||||||
|
|
||||||
@method()
|
|
||||||
def generate(
|
|
||||||
self, prompt: str, gen_schema: str | None, gen_cfg: str | None
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Perform a generation action using the LLM
|
|
||||||
"""
|
|
||||||
print(f"Generate {prompt=}")
|
|
||||||
if gen_cfg:
|
|
||||||
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
|
|
||||||
else:
|
|
||||||
gen_cfg = self.gen_cfg
|
|
||||||
|
|
||||||
# If a gen_schema is given, conform to gen_schema
|
|
||||||
with self.lock:
|
|
||||||
if gen_schema:
|
|
||||||
import jsonformer
|
|
||||||
|
|
||||||
print(f"Schema {gen_schema=}")
|
|
||||||
jsonformer_llm = jsonformer.Jsonformer(
|
|
||||||
model=self.model,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
json_schema=json.loads(gen_schema),
|
|
||||||
prompt=prompt,
|
|
||||||
max_string_token_length=gen_cfg.max_new_tokens,
|
|
||||||
)
|
|
||||||
response = jsonformer_llm()
|
|
||||||
else:
|
|
||||||
# If no gen_schema, perform prompt only generation
|
|
||||||
|
|
||||||
# tokenize prompt
|
|
||||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
|
||||||
self.model.device
|
|
||||||
)
|
|
||||||
output = self.model.generate(input_ids, generation_config=gen_cfg)
|
|
||||||
|
|
||||||
# decode output
|
|
||||||
response = self.tokenizer.decode(
|
|
||||||
output[0].cpu(), skip_special_tokens=True
|
|
||||||
)
|
|
||||||
response = response[len(prompt) :]
|
|
||||||
print(f"Generated {response=}")
|
|
||||||
return {"text": response}
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------
|
|
||||||
# Web API
|
|
||||||
# -------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@app.function(
|
|
||||||
scaledown_window=60 * 10,
|
|
||||||
timeout=60 * 5,
|
|
||||||
allow_concurrent_inputs=45,
|
|
||||||
secrets=[
|
|
||||||
Secret.from_name("reflector-gpu"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@asgi_app()
|
|
||||||
def web():
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
llmstub = LLM()
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
|
||||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid API key",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
class LLMRequest(BaseModel):
|
|
||||||
prompt: str
|
|
||||||
gen_schema: Optional[dict] = None
|
|
||||||
gen_cfg: Optional[dict] = None
|
|
||||||
|
|
||||||
@app.post("/llm", dependencies=[Depends(apikey_auth)])
|
|
||||||
def llm(
|
|
||||||
req: LLMRequest,
|
|
||||||
):
|
|
||||||
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
|
|
||||||
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
|
|
||||||
func = llmstub.generate.spawn(
|
|
||||||
prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
|
|
||||||
)
|
|
||||||
result = func.get()
|
|
||||||
return result
|
|
||||||
|
|
||||||
return app
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
"""
|
|
||||||
Reflector GPU backend - LLM
|
|
||||||
===========================
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from modal import App, Image, Secret, asgi_app, enter, exit, method
|
|
||||||
|
|
||||||
# LLM
|
|
||||||
LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-alpha"
|
|
||||||
LLM_LOW_CPU_MEM_USAGE: bool = True
|
|
||||||
LLM_TORCH_DTYPE: str = "bfloat16"
|
|
||||||
LLM_MAX_NEW_TOKENS: int = 300
|
|
||||||
|
|
||||||
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"
|
|
||||||
|
|
||||||
app = App(name="reflector-llm-zephyr")
|
|
||||||
|
|
||||||
|
|
||||||
def download_llm():
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
print("Downloading LLM model")
|
|
||||||
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
|
|
||||||
print("LLM model downloaded")
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_cache_llm():
|
|
||||||
"""
|
|
||||||
XXX The cache for model files in Transformers v4.22.0 has been updated.
|
|
||||||
Migrating your old cache. This is a one-time only operation. You can
|
|
||||||
interrupt this and resume the migration later on by calling
|
|
||||||
`transformers.utils.move_cache()`.
|
|
||||||
"""
|
|
||||||
from transformers.utils.hub import move_cache
|
|
||||||
|
|
||||||
print("Moving LLM cache")
|
|
||||||
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
|
|
||||||
print("LLM cache moved")
|
|
||||||
|
|
||||||
|
|
||||||
llm_image = (
|
|
||||||
Image.debian_slim(python_version="3.10.8")
|
|
||||||
.apt_install("git")
|
|
||||||
.pip_install(
|
|
||||||
"transformers==4.34.0",
|
|
||||||
"torch",
|
|
||||||
"sentencepiece",
|
|
||||||
"protobuf",
|
|
||||||
"jsonformer==0.12.0",
|
|
||||||
"accelerate==0.21.0",
|
|
||||||
"einops==0.6.1",
|
|
||||||
"hf-transfer~=0.1",
|
|
||||||
"huggingface_hub==0.16.4",
|
|
||||||
)
|
|
||||||
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
|
|
||||||
.run_function(download_llm)
|
|
||||||
.run_function(migrate_cache_llm)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.cls(
|
|
||||||
gpu="A10G",
|
|
||||||
timeout=60 * 5,
|
|
||||||
scaledown_window=60 * 5,
|
|
||||||
allow_concurrent_inputs=10,
|
|
||||||
image=llm_image,
|
|
||||||
)
|
|
||||||
class LLM:
|
|
||||||
@enter()
|
|
||||||
def enter(self):
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
||||||
|
|
||||||
print("Instance llm model")
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
LLM_MODEL,
|
|
||||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
|
||||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
|
||||||
cache_dir=IMAGE_MODEL_DIR,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# JSONFormer doesn't yet support generation configs
|
|
||||||
print("Instance llm generation config")
|
|
||||||
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
|
|
||||||
|
|
||||||
# generation configuration
|
|
||||||
gen_cfg = GenerationConfig.from_model_config(model.config)
|
|
||||||
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
|
|
||||||
|
|
||||||
# load tokenizer
|
|
||||||
print("Instance llm tokenizer")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
|
|
||||||
)
|
|
||||||
gen_cfg.pad_token_id = tokenizer.eos_token_id
|
|
||||||
gen_cfg.eos_token_id = tokenizer.eos_token_id
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
model.config.pad_token_id = tokenizer.eos_token_id
|
|
||||||
|
|
||||||
# move model to gpu
|
|
||||||
print("Move llm model to GPU")
|
|
||||||
model = model.cuda()
|
|
||||||
|
|
||||||
print("Warmup llm done")
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.gen_cfg = gen_cfg
|
|
||||||
self.GenerationConfig = GenerationConfig
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
|
|
||||||
@exit()
|
|
||||||
def exit():
|
|
||||||
print("Exit llm")
|
|
||||||
|
|
||||||
@method()
|
|
||||||
def generate(
|
|
||||||
self, prompt: str, gen_schema: str | None, gen_cfg: str | None
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Perform a generation action using the LLM
|
|
||||||
"""
|
|
||||||
print(f"Generate {prompt=}")
|
|
||||||
if gen_cfg:
|
|
||||||
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
|
|
||||||
gen_cfg.pad_token_id = self.tokenizer.eos_token_id
|
|
||||||
gen_cfg.eos_token_id = self.tokenizer.eos_token_id
|
|
||||||
else:
|
|
||||||
gen_cfg = self.gen_cfg
|
|
||||||
|
|
||||||
# If a gen_schema is given, conform to gen_schema
|
|
||||||
with self.lock:
|
|
||||||
if gen_schema:
|
|
||||||
import jsonformer
|
|
||||||
|
|
||||||
print(f"Schema {gen_schema=}")
|
|
||||||
jsonformer_llm = jsonformer.Jsonformer(
|
|
||||||
model=self.model,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
json_schema=json.loads(gen_schema),
|
|
||||||
prompt=prompt,
|
|
||||||
max_string_token_length=gen_cfg.max_new_tokens,
|
|
||||||
)
|
|
||||||
response = jsonformer_llm()
|
|
||||||
else:
|
|
||||||
# If no gen_schema, perform prompt only generation
|
|
||||||
|
|
||||||
# tokenize prompt
|
|
||||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
|
||||||
self.model.device
|
|
||||||
)
|
|
||||||
output = self.model.generate(input_ids, generation_config=gen_cfg)
|
|
||||||
|
|
||||||
# decode output
|
|
||||||
response = self.tokenizer.decode(
|
|
||||||
output[0].cpu(), skip_special_tokens=True
|
|
||||||
)
|
|
||||||
response = response[len(prompt) :]
|
|
||||||
response = {"long_summary": response}
|
|
||||||
print(f"Generated {response=}")
|
|
||||||
return {"text": response}
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------
|
|
||||||
# Web API
|
|
||||||
# -------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@app.function(
|
|
||||||
scaledown_window=60 * 10,
|
|
||||||
timeout=60 * 5,
|
|
||||||
allow_concurrent_inputs=30,
|
|
||||||
secrets=[
|
|
||||||
Secret.from_name("reflector-gpu"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@asgi_app()
|
|
||||||
def web():
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
llmstub = LLM()
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
|
||||||
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid API key",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
class LLMRequest(BaseModel):
|
|
||||||
prompt: str
|
|
||||||
gen_schema: Optional[dict] = None
|
|
||||||
gen_cfg: Optional[dict] = None
|
|
||||||
|
|
||||||
@app.post("/llm", dependencies=[Depends(apikey_auth)])
|
|
||||||
def llm(
|
|
||||||
req: LLMRequest,
|
|
||||||
):
|
|
||||||
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
|
|
||||||
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
|
|
||||||
func = llmstub.generate.spawn(
|
|
||||||
prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
|
|
||||||
)
|
|
||||||
result = func.get()
|
|
||||||
return result
|
|
||||||
|
|
||||||
return app
|
|
||||||
83
server/reflector/llm.py
Normal file
83
server/reflector/llm.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
from typing import Type, TypeVar
|
||||||
|
|
||||||
|
from llama_index.core import Settings
|
||||||
|
from llama_index.core.output_parsers import PydanticOutputParser
|
||||||
|
from llama_index.core.program import LLMTextCompletionProgram
|
||||||
|
from llama_index.core.response_synthesizers import TreeSummarize
|
||||||
|
from llama_index.llms.openai_like import OpenAILike
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
|
||||||
|
Based on the following analysis, provide the information in the requested JSON format:
|
||||||
|
|
||||||
|
Analysis:
|
||||||
|
{analysis}
|
||||||
|
|
||||||
|
{format_instructions}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class LLM:
|
||||||
|
def __init__(self, settings, temperature: float = 0.4, max_tokens: int = 2048):
|
||||||
|
self.settings_obj = settings
|
||||||
|
self.model_name = settings.LLM_MODEL
|
||||||
|
self.url = settings.LLM_URL
|
||||||
|
self.api_key = settings.LLM_API_KEY
|
||||||
|
self.context_window = settings.LLM_CONTEXT_WINDOW
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
# Configure llamaindex Settings
|
||||||
|
self._configure_llamaindex()
|
||||||
|
|
||||||
|
def _configure_llamaindex(self):
|
||||||
|
"""Configure llamaindex Settings with OpenAILike LLM"""
|
||||||
|
Settings.llm = OpenAILike(
|
||||||
|
model=self.model_name,
|
||||||
|
api_base=self.url,
|
||||||
|
api_key=self.api_key,
|
||||||
|
context_window=self.context_window,
|
||||||
|
is_chat_model=True,
|
||||||
|
is_function_calling_model=False,
|
||||||
|
temperature=self.temperature,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_response(
|
||||||
|
self, prompt: str, texts: list[str], tone_name: str | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Get a text response using TreeSummarize for non-function-calling models"""
|
||||||
|
summarizer = TreeSummarize(verbose=False)
|
||||||
|
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
|
||||||
|
return str(response).strip()
|
||||||
|
|
||||||
|
async def get_structured_response(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
texts: list[str],
|
||||||
|
output_cls: Type[T],
|
||||||
|
tone_name: str | None = None,
|
||||||
|
) -> T:
|
||||||
|
"""Get structured output from LLM for non-function-calling models"""
|
||||||
|
summarizer = TreeSummarize(verbose=True)
|
||||||
|
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
|
||||||
|
|
||||||
|
output_parser = PydanticOutputParser(output_cls)
|
||||||
|
|
||||||
|
program = LLMTextCompletionProgram.from_defaults(
|
||||||
|
output_parser=output_parser,
|
||||||
|
prompt_template_str=STRUCTURED_RESPONSE_PROMPT_TEMPLATE,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
format_instructions = output_parser.format(
|
||||||
|
"Please structure the above information in the following JSON format:"
|
||||||
|
)
|
||||||
|
|
||||||
|
output = await program.acall(
|
||||||
|
analysis=str(response), format_instructions=format_instructions
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
from .base import LLM # noqa: F401
|
|
||||||
from .llm_params import LLMTaskParams # noqa: F401
|
|
||||||
@@ -1,347 +0,0 @@
|
|||||||
import importlib
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
import nltk
|
|
||||||
from prometheus_client import Counter, Histogram
|
|
||||||
from transformers import GenerationConfig
|
|
||||||
|
|
||||||
from reflector.llm.llm_params import TaskParams
|
|
||||||
from reflector.logger import logger as reflector_logger
|
|
||||||
from reflector.settings import settings
|
|
||||||
from reflector.utils.retry import retry
|
|
||||||
|
|
||||||
T = TypeVar("T", bound="LLM")
|
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
|
||||||
_nltk_downloaded = False
|
|
||||||
_registry = {}
|
|
||||||
model_name: str
|
|
||||||
m_generate = Histogram(
|
|
||||||
"llm_generate",
|
|
||||||
"Time spent in LLM.generate",
|
|
||||||
["backend"],
|
|
||||||
)
|
|
||||||
m_generate_call = Counter(
|
|
||||||
"llm_generate_call",
|
|
||||||
"Number of calls to LLM.generate",
|
|
||||||
["backend"],
|
|
||||||
)
|
|
||||||
m_generate_success = Counter(
|
|
||||||
"llm_generate_success",
|
|
||||||
"Number of successful calls to LLM.generate",
|
|
||||||
["backend"],
|
|
||||||
)
|
|
||||||
m_generate_failure = Counter(
|
|
||||||
"llm_generate_failure",
|
|
||||||
"Number of failed calls to LLM.generate",
|
|
||||||
["backend"],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def ensure_nltk(cls):
|
|
||||||
"""
|
|
||||||
Make sure NLTK package is installed. Searches in the cache and
|
|
||||||
downloads only if needed.
|
|
||||||
"""
|
|
||||||
if not cls._nltk_downloaded:
|
|
||||||
nltk.download("punkt_tab")
|
|
||||||
# For POS tagging
|
|
||||||
nltk.download("averaged_perceptron_tagger_eng")
|
|
||||||
cls._nltk_downloaded = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, name, klass):
|
|
||||||
cls._registry[name] = klass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls, model_name: str | None = None, name: str = None) -> T:
|
|
||||||
"""
|
|
||||||
Return an instance depending on the settings.
|
|
||||||
Settings used:
|
|
||||||
|
|
||||||
- `LLM_BACKEND`: key of the backend
|
|
||||||
- `LLM_URL`: url of the backend
|
|
||||||
"""
|
|
||||||
if name is None:
|
|
||||||
name = settings.LLM_BACKEND
|
|
||||||
if name not in cls._registry:
|
|
||||||
module_name = f"reflector.llm.llm_{name}"
|
|
||||||
importlib.import_module(module_name)
|
|
||||||
cls.ensure_nltk()
|
|
||||||
|
|
||||||
return cls._registry[name](model_name)
|
|
||||||
|
|
||||||
def get_model_name(self) -> str:
|
|
||||||
"""
|
|
||||||
Get the currently set model name
|
|
||||||
"""
|
|
||||||
return self._get_model_name()
|
|
||||||
|
|
||||||
def _get_model_name(self) -> str:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def set_model_name(self, model_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
Update the model name with the provided model name
|
|
||||||
"""
|
|
||||||
return self._set_model_name(model_name)
|
|
||||||
|
|
||||||
def _set_model_name(self, model_name: str) -> bool:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
|
||||||
def template(self) -> str:
|
|
||||||
"""
|
|
||||||
Return the LLM Prompt template
|
|
||||||
"""
|
|
||||||
return """
|
|
||||||
### Human:
|
|
||||||
{instruct}
|
|
||||||
|
|
||||||
{text}
|
|
||||||
|
|
||||||
### Assistant:
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
name = self.__class__.__name__
|
|
||||||
self.m_generate = self.m_generate.labels(name)
|
|
||||||
self.m_generate_call = self.m_generate_call.labels(name)
|
|
||||||
self.m_generate_success = self.m_generate_success.labels(name)
|
|
||||||
self.m_generate_failure = self.m_generate_failure.labels(name)
|
|
||||||
self.detokenizer = nltk.tokenize.treebank.TreebankWordDetokenizer()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tokenizer(self):
|
|
||||||
"""
|
|
||||||
Return the tokenizer instance used by LLM
|
|
||||||
"""
|
|
||||||
return self._get_tokenizer()
|
|
||||||
|
|
||||||
def _get_tokenizer(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def has_structured_output(self):
|
|
||||||
# whether implementation supports structured output
|
|
||||||
# on the model side (otherwise it's prompt engineering)
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def generate(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
logger: reflector_logger,
|
|
||||||
gen_schema: dict | None = None,
|
|
||||||
gen_cfg: GenerationConfig | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> dict:
|
|
||||||
logger.info("LLM generate", prompt=repr(prompt))
|
|
||||||
|
|
||||||
if gen_cfg:
|
|
||||||
gen_cfg = gen_cfg.to_dict()
|
|
||||||
self.m_generate_call.inc()
|
|
||||||
try:
|
|
||||||
with self.m_generate.time():
|
|
||||||
result = await retry(self._generate)(
|
|
||||||
prompt=prompt,
|
|
||||||
gen_schema=gen_schema,
|
|
||||||
gen_cfg=gen_cfg,
|
|
||||||
logger=logger,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
self.m_generate_success.inc()
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to call llm after retrying")
|
|
||||||
self.m_generate_failure.inc()
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.debug("LLM result [raw]", result=repr(result))
|
|
||||||
if isinstance(result, str):
|
|
||||||
result = self._parse_json(result)
|
|
||||||
logger.debug("LLM result [parsed]", result=repr(result))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def completion(
|
|
||||||
self, messages: list, logger: reflector_logger, **kwargs
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Use /v1/chat/completion Open-AI compatible endpoint from the URL
|
|
||||||
It's up to the user to validate anything or transform the result
|
|
||||||
"""
|
|
||||||
logger.info("LLM completions", messages=messages)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.m_generate.time():
|
|
||||||
result = await retry(self._completion)(
|
|
||||||
messages=messages, **{**kwargs, "logger": logger}
|
|
||||||
)
|
|
||||||
self.m_generate_success.inc()
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to call llm after retrying")
|
|
||||||
self.m_generate_failure.inc()
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.debug("LLM completion result", result=repr(result))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def ensure_casing(self, title: str) -> str:
|
|
||||||
"""
|
|
||||||
LLM takes care of word casing, but in rare cases this
|
|
||||||
can falter. This is a fallback to ensure the casing of
|
|
||||||
topics is in a proper format.
|
|
||||||
|
|
||||||
We select nouns, verbs and adjectives and check if camel
|
|
||||||
casing is present and fix it, if not. Will not perform
|
|
||||||
any other changes.
|
|
||||||
"""
|
|
||||||
tokens = nltk.word_tokenize(title)
|
|
||||||
pos_tags = nltk.pos_tag(tokens)
|
|
||||||
camel_cased = []
|
|
||||||
|
|
||||||
whitelisted_pos_tags = [
|
|
||||||
"NN",
|
|
||||||
"NNS",
|
|
||||||
"NNP",
|
|
||||||
"NNPS", # Noun POS
|
|
||||||
"VB",
|
|
||||||
"VBD",
|
|
||||||
"VBG",
|
|
||||||
"VBN",
|
|
||||||
"VBP",
|
|
||||||
"VBZ", # Verb POS
|
|
||||||
"JJ",
|
|
||||||
"JJR",
|
|
||||||
"JJS", # Adjective POS
|
|
||||||
]
|
|
||||||
|
|
||||||
# If at all there is an exception, do not block other reflector
|
|
||||||
# processes. Return the LLM generated title, at the least.
|
|
||||||
try:
|
|
||||||
for word, pos in pos_tags:
|
|
||||||
if pos in whitelisted_pos_tags and word[0].islower():
|
|
||||||
camel_cased.append(word[0].upper() + word[1:])
|
|
||||||
else:
|
|
||||||
camel_cased.append(word)
|
|
||||||
modified_title = self.detokenizer.detokenize(camel_cased)
|
|
||||||
|
|
||||||
# Irrespective of casing changes, the starting letter
|
|
||||||
# of title is always upper-cased
|
|
||||||
title = modified_title[0].upper() + modified_title[1:]
|
|
||||||
except Exception as e:
|
|
||||||
reflector_logger.info(
|
|
||||||
f"Failed to ensure casing on {title=} with exception : {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return title
|
|
||||||
|
|
||||||
def trim_title(self, title: str) -> str:
|
|
||||||
"""
|
|
||||||
List of manual trimming to the title.
|
|
||||||
|
|
||||||
Longer titles are prone to run into A prefix of phrases that don't
|
|
||||||
really add any descriptive information and in some cases, this
|
|
||||||
behaviour can be repeated for several consecutive topics. Trim the
|
|
||||||
titles to maintain quality of titles.
|
|
||||||
"""
|
|
||||||
phrases_to_remove = ["Discussing", "Discussion on", "Discussion about"]
|
|
||||||
try:
|
|
||||||
pattern = (
|
|
||||||
r"\b(?:"
|
|
||||||
+ "|".join(re.escape(phrase) for phrase in phrases_to_remove)
|
|
||||||
+ r")\b"
|
|
||||||
)
|
|
||||||
title = re.sub(pattern, "", title, flags=re.IGNORECASE)
|
|
||||||
except Exception as e:
|
|
||||||
reflector_logger.info(f"Failed to trim {title=} with exception : {str(e)}")
|
|
||||||
return title
|
|
||||||
|
|
||||||
async def _generate(
|
|
||||||
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
|
|
||||||
) -> str:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def _completion(self, messages: list, **kwargs) -> dict:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _parse_json(self, result: str) -> dict:
|
|
||||||
result = result.strip()
|
|
||||||
# try detecting code block if exist
|
|
||||||
# starts with ```json\n, ends with ```
|
|
||||||
# or starts with ```\n, ends with ```
|
|
||||||
# or starts with \n```javascript\n, ends with ```
|
|
||||||
|
|
||||||
regex = r"```(json|javascript|)?(.*)```"
|
|
||||||
matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
|
|
||||||
if matches:
|
|
||||||
result = matches[0][1]
|
|
||||||
|
|
||||||
else:
|
|
||||||
# maybe the prompt has been started with ```json
|
|
||||||
# so if text ends with ```, just remove it and use it as json
|
|
||||||
if result.endswith("```"):
|
|
||||||
result = result[:-3]
|
|
||||||
|
|
||||||
return json.loads(result.strip())
|
|
||||||
|
|
||||||
def text_token_threshold(self, task_params: TaskParams | None) -> int:
|
|
||||||
"""
|
|
||||||
Choose the token size to set as the threshold to pack the LLM calls
|
|
||||||
"""
|
|
||||||
buffer_token_size = 100
|
|
||||||
default_output_tokens = 1000
|
|
||||||
context_window = self.tokenizer.model_max_length
|
|
||||||
tokens = self.tokenizer.tokenize(
|
|
||||||
self.create_prompt(instruct=task_params.instruct, text="")
|
|
||||||
)
|
|
||||||
threshold = context_window - len(tokens) - buffer_token_size
|
|
||||||
if task_params.gen_cfg:
|
|
||||||
threshold -= task_params.gen_cfg.max_new_tokens
|
|
||||||
else:
|
|
||||||
threshold -= default_output_tokens
|
|
||||||
return threshold
|
|
||||||
|
|
||||||
def split_corpus(
|
|
||||||
self,
|
|
||||||
corpus: str,
|
|
||||||
task_params: TaskParams,
|
|
||||||
token_threshold: int | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
"""
|
|
||||||
Split the input to the LLM due to CUDA memory limitations and LLM context window
|
|
||||||
restrictions.
|
|
||||||
|
|
||||||
Accumulate tokens from full sentences till threshold and yield accumulated
|
|
||||||
tokens. Reset accumulation when threshold is reached and repeat process.
|
|
||||||
"""
|
|
||||||
if not token_threshold:
|
|
||||||
token_threshold = self.text_token_threshold(task_params=task_params)
|
|
||||||
|
|
||||||
accumulated_tokens = []
|
|
||||||
accumulated_sentences = []
|
|
||||||
accumulated_token_count = 0
|
|
||||||
corpus_sentences = nltk.sent_tokenize(corpus)
|
|
||||||
|
|
||||||
for sentence in corpus_sentences:
|
|
||||||
tokens = self.tokenizer.tokenize(sentence)
|
|
||||||
if accumulated_token_count + len(tokens) <= token_threshold:
|
|
||||||
accumulated_token_count += len(tokens)
|
|
||||||
accumulated_tokens.extend(tokens)
|
|
||||||
accumulated_sentences.append(sentence)
|
|
||||||
else:
|
|
||||||
yield "".join(accumulated_sentences)
|
|
||||||
accumulated_token_count = len(tokens)
|
|
||||||
accumulated_tokens = tokens
|
|
||||||
accumulated_sentences = [sentence]
|
|
||||||
|
|
||||||
if accumulated_tokens:
|
|
||||||
yield " ".join(accumulated_sentences)
|
|
||||||
|
|
||||||
def create_prompt(self, instruct: str, text: str) -> str:
|
|
||||||
"""
|
|
||||||
Create a consumable prompt based on the prompt template
|
|
||||||
"""
|
|
||||||
return self.template.format(instruct=instruct, text=text)
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
import httpx
|
|
||||||
from transformers import AutoTokenizer, GenerationConfig
|
|
||||||
|
|
||||||
from reflector.llm.base import LLM
|
|
||||||
from reflector.logger import logger as reflector_logger
|
|
||||||
from reflector.settings import settings
|
|
||||||
from reflector.utils.retry import retry
|
|
||||||
|
|
||||||
|
|
||||||
class ModalLLM(LLM):
|
|
||||||
def __init__(self, model_name: str | None = None):
|
|
||||||
super().__init__()
|
|
||||||
self.timeout = settings.LLM_TIMEOUT
|
|
||||||
self.llm_url = settings.LLM_URL + "/llm"
|
|
||||||
self.headers = {
|
|
||||||
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
|
|
||||||
}
|
|
||||||
self._set_model_name(model_name if model_name else settings.DEFAULT_LLM)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_models(self):
|
|
||||||
"""
|
|
||||||
List of currently supported models on this GPU platform
|
|
||||||
"""
|
|
||||||
# TODO: Query the specific GPU platform
|
|
||||||
# Replace this with a HTTP call
|
|
||||||
return [
|
|
||||||
"lmsys/vicuna-13b-v1.5",
|
|
||||||
"HuggingFaceH4/zephyr-7b-alpha",
|
|
||||||
"NousResearch/Hermes-3-Llama-3.1-8B",
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _generate(
|
|
||||||
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
|
|
||||||
) -> str:
|
|
||||||
json_payload = {"prompt": prompt}
|
|
||||||
if gen_schema:
|
|
||||||
json_payload["gen_schema"] = gen_schema
|
|
||||||
if gen_cfg:
|
|
||||||
json_payload["gen_cfg"] = gen_cfg
|
|
||||||
|
|
||||||
# Handing over generation of the final summary to Zephyr model
|
|
||||||
# but replacing the Vicuna model will happen after more testing
|
|
||||||
# TODO: Create a mapping of model names and cloud deployments
|
|
||||||
if self.model_name == "HuggingFaceH4/zephyr-7b-alpha":
|
|
||||||
self.llm_url = settings.ZEPHYR_LLM_URL + "/llm"
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await retry(client.post)(
|
|
||||||
self.llm_url,
|
|
||||||
headers=self.headers,
|
|
||||||
json=json_payload,
|
|
||||||
timeout=self.timeout,
|
|
||||||
retry_timeout=60 * 5,
|
|
||||||
follow_redirects=True,
|
|
||||||
logger=kwargs.get("logger", reflector_logger),
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
text = response.json()["text"]
|
|
||||||
return text
|
|
||||||
|
|
||||||
async def _completion(self, messages: list, **kwargs) -> dict:
|
|
||||||
# returns full api response
|
|
||||||
kwargs.setdefault("temperature", 0.3)
|
|
||||||
kwargs.setdefault("max_tokens", 2048)
|
|
||||||
kwargs.setdefault("stream", False)
|
|
||||||
kwargs.setdefault("repetition_penalty", 1)
|
|
||||||
kwargs.setdefault("top_p", 1)
|
|
||||||
kwargs.setdefault("top_k", -1)
|
|
||||||
kwargs.setdefault("min_p", 0.05)
|
|
||||||
data = {"messages": messages, "model": self.model_name, **kwargs}
|
|
||||||
|
|
||||||
if self.model_name == "NousResearch/Hermes-3-Llama-3.1-8B":
|
|
||||||
self.llm_url = settings.HERMES_3_8B_LLM_URL + "/v1/chat/completions"
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await retry(client.post)(
|
|
||||||
self.llm_url,
|
|
||||||
headers=self.headers,
|
|
||||||
json=data,
|
|
||||||
timeout=self.timeout,
|
|
||||||
retry_timeout=60 * 5,
|
|
||||||
follow_redirects=True,
|
|
||||||
logger=kwargs.get("logger", reflector_logger),
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def _set_model_name(self, model_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
Set the model name
|
|
||||||
"""
|
|
||||||
# Abort, if the model is not supported
|
|
||||||
if model_name not in self.supported_models:
|
|
||||||
reflector_logger.info(
|
|
||||||
f"Attempted to change {model_name=}, but is not supported."
|
|
||||||
f"Setting model and tokenizer failed !"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
# Abort, if the model is already set
|
|
||||||
elif hasattr(self, "model_name") and model_name == self._get_model_name():
|
|
||||||
reflector_logger.info("No change in model. Setting model skipped.")
|
|
||||||
return False
|
|
||||||
# Update model name and tokenizer
|
|
||||||
self.model_name = model_name
|
|
||||||
self.llm_tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.model_name, cache_dir=settings.CACHE_DIR
|
|
||||||
)
|
|
||||||
reflector_logger.info(f"Model set to {model_name=}. Tokenizer updated.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _get_tokenizer(self) -> AutoTokenizer:
|
|
||||||
"""
|
|
||||||
Return the currently used LLM tokenizer
|
|
||||||
"""
|
|
||||||
return self.llm_tokenizer
|
|
||||||
|
|
||||||
def _get_model_name(self) -> str:
|
|
||||||
"""
|
|
||||||
Return the current model name from the instance details
|
|
||||||
"""
|
|
||||||
return self.model_name
|
|
||||||
|
|
||||||
|
|
||||||
LLM.register("modal", ModalLLM)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from reflector.logger import logger
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
llm = ModalLLM()
|
|
||||||
prompt = llm.create_prompt(
|
|
||||||
instruct="Complete the following task",
|
|
||||||
text="Tell me a joke about programming.",
|
|
||||||
)
|
|
||||||
result = await llm.generate(prompt=prompt, logger=logger)
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
gen_schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"response": {"type": "string"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await llm.generate(prompt=prompt, gen_schema=gen_schema, logger=logger)
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
gen_cfg = GenerationConfig(max_new_tokens=150)
|
|
||||||
result = await llm.generate(
|
|
||||||
prompt=prompt, gen_cfg=gen_cfg, gen_schema=gen_schema, logger=logger
|
|
||||||
)
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
import httpx
|
|
||||||
from transformers import GenerationConfig
|
|
||||||
|
|
||||||
from reflector.llm.base import LLM
|
|
||||||
from reflector.logger import logger
|
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAILLM(LLM):
|
|
||||||
def __init__(self, model_name: str | None = None, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.openai_key = settings.LLM_OPENAI_KEY
|
|
||||||
self.openai_url = settings.LLM_URL
|
|
||||||
self.openai_model = settings.LLM_OPENAI_MODEL
|
|
||||||
self.openai_temperature = settings.LLM_OPENAI_TEMPERATURE
|
|
||||||
self.timeout = settings.LLM_TIMEOUT
|
|
||||||
self.max_tokens = settings.LLM_MAX_TOKENS
|
|
||||||
logger.info(f"LLM use openai backend at {self.openai_url}")
|
|
||||||
|
|
||||||
async def _generate(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
gen_schema: dict | None,
|
|
||||||
gen_cfg: GenerationConfig | None,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.openai_key}",
|
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
response = await client.post(
|
|
||||||
self.openai_url,
|
|
||||||
headers=headers,
|
|
||||||
json={
|
|
||||||
"model": self.openai_model,
|
|
||||||
"prompt": prompt,
|
|
||||||
"max_tokens": self.max_tokens,
|
|
||||||
"temperature": self.openai_temperature,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
result = response.json()
|
|
||||||
return result["choices"][0]["text"]
|
|
||||||
|
|
||||||
|
|
||||||
LLM.register("openai", OpenAILLM)
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
from typing import Optional, TypeVar
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from transformers import GenerationConfig
|
|
||||||
|
|
||||||
|
|
||||||
class TaskParams(BaseModel, arbitrary_types_allowed=True):
|
|
||||||
instruct: str
|
|
||||||
gen_cfg: Optional[GenerationConfig] = None
|
|
||||||
gen_schema: Optional[dict] = None
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound="LLMTaskParams")
|
|
||||||
|
|
||||||
|
|
||||||
class LLMTaskParams:
|
|
||||||
_registry = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, task, klass) -> None:
|
|
||||||
cls._registry[task] = klass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls, task: str) -> T:
|
|
||||||
return cls._registry[task]()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def task_params(self) -> TaskParams | None:
|
|
||||||
"""
|
|
||||||
Fetch the task related parameters
|
|
||||||
"""
|
|
||||||
return self._get_task_params()
|
|
||||||
|
|
||||||
def _get_task_params(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class FinalLongSummaryParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=1000, num_beams=3, do_sample=True, temperature=0.3
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Take the key ideas and takeaways from the text and create a short
|
|
||||||
summary. Be sure to keep the length of the response to a minimum.
|
|
||||||
Do not include trivial information in the summary.
|
|
||||||
"""
|
|
||||||
self._schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"long_summary": {"type": "string"}},
|
|
||||||
}
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""gen_schema
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
class FinalShortSummaryParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Take the key ideas and takeaways from the text and create a short
|
|
||||||
summary. Be sure to keep the length of the response to a minimum.
|
|
||||||
Do not include trivial information in the summary.
|
|
||||||
"""
|
|
||||||
self._schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"short_summary": {"type": "string"}},
|
|
||||||
}
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
class FinalTitleParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Combine the following individual titles into one single short title that
|
|
||||||
condenses the essence of all titles.
|
|
||||||
"""
|
|
||||||
self._schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"title": {"type": "string"}},
|
|
||||||
}
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
class TopicParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=500, num_beams=6, do_sample=True, temperature=0.9
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Create a JSON object as response.The JSON object must have 2 fields:
|
|
||||||
i) title and ii) summary.
|
|
||||||
For the title field, generate a very detailed and self-explanatory
|
|
||||||
title for the given text. Let the title be as descriptive as possible.
|
|
||||||
For the summary field, summarize the given text in a maximum of
|
|
||||||
two sentences.
|
|
||||||
"""
|
|
||||||
self._schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"title": {"type": "string"},
|
|
||||||
"summary": {"type": "string"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
class BulletedSummaryParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=800,
|
|
||||||
num_beams=1,
|
|
||||||
do_sample=True,
|
|
||||||
temperature=0.2,
|
|
||||||
early_stopping=True,
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Given a meeting transcript, extract the key things discussed in the
|
|
||||||
form of a list.
|
|
||||||
|
|
||||||
While generating the response, follow the constraints mentioned below.
|
|
||||||
|
|
||||||
Summary constraints:
|
|
||||||
i) Do not add new content, except to fix spelling or punctuation.
|
|
||||||
ii) Do not add any prefixes or numbering in the response.
|
|
||||||
iii) The summarization should be as information dense as possible.
|
|
||||||
iv) Do not add any additional sections like Note, Conclusion, etc. in
|
|
||||||
the response.
|
|
||||||
|
|
||||||
Response format:
|
|
||||||
i) The response should be in the form of a bulleted list.
|
|
||||||
ii) Iteratively merge all the relevant paragraphs together to keep the
|
|
||||||
number of paragraphs to a minimum.
|
|
||||||
iii) Remove any unfinished sentences from the final response.
|
|
||||||
iv) Do not include narrative or reporting clauses.
|
|
||||||
v) Use "*" as the bullet icon.
|
|
||||||
"""
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""gen_schema
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
class MergedSummaryParams(LLMTaskParams):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._gen_cfg = GenerationConfig(
|
|
||||||
max_new_tokens=600,
|
|
||||||
num_beams=1,
|
|
||||||
do_sample=True,
|
|
||||||
temperature=0.2,
|
|
||||||
early_stopping=True,
|
|
||||||
)
|
|
||||||
self._instruct = """
|
|
||||||
Given the key points of a meeting, summarize the points to describe the
|
|
||||||
meeting in the form of paragraphs.
|
|
||||||
"""
|
|
||||||
self._task_params = TaskParams(
|
|
||||||
instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_task_params(self) -> TaskParams:
|
|
||||||
"""gen_schema
|
|
||||||
Return the parameters associated with a specific LLM task
|
|
||||||
"""
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
|
|
||||||
LLMTaskParams.register("topic", TopicParams)
|
|
||||||
LLMTaskParams.register("final_title", FinalTitleParams)
|
|
||||||
LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
|
|
||||||
LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
|
|
||||||
LLMTaskParams.register("bullet_summary", BulletedSummaryParams)
|
|
||||||
LLMTaskParams.register("merged_summary", MergedSummaryParams)
|
|
||||||
@@ -1,118 +0,0 @@
|
|||||||
import httpx
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from reflector.logger import logger
|
|
||||||
|
|
||||||
|
|
||||||
def apply_gen_config(payload: dict, gen_cfg) -> None:
|
|
||||||
"""Apply generation config overrides to the payload."""
|
|
||||||
config_mapping = {
|
|
||||||
"temperature": "temperature",
|
|
||||||
"max_new_tokens": "max_tokens",
|
|
||||||
"max_tokens": "max_tokens",
|
|
||||||
"top_p": "top_p",
|
|
||||||
"frequency_penalty": "frequency_penalty",
|
|
||||||
"presence_penalty": "presence_penalty",
|
|
||||||
}
|
|
||||||
|
|
||||||
for cfg_attr, payload_key in config_mapping.items():
|
|
||||||
value = getattr(gen_cfg, cfg_attr, None)
|
|
||||||
if value is not None:
|
|
||||||
payload[payload_key] = value
|
|
||||||
if cfg_attr == "max_new_tokens": # Handle max_new_tokens taking precedence
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAILLM:
|
|
||||||
def __init__(self, config_prefix: str, settings):
|
|
||||||
self.config_prefix = config_prefix
|
|
||||||
self.settings_obj = settings
|
|
||||||
self.model_name = getattr(settings, f"{config_prefix}_MODEL")
|
|
||||||
self.url = getattr(settings, f"{config_prefix}_LLM_URL")
|
|
||||||
self.api_key = getattr(settings, f"{config_prefix}_LLM_API_KEY")
|
|
||||||
|
|
||||||
timeout = getattr(settings, f"{config_prefix}_LLM_TIMEOUT", 300)
|
|
||||||
self.temperature = getattr(settings, f"{config_prefix}_LLM_TEMPERATURE", 0.7)
|
|
||||||
self.max_tokens = getattr(settings, f"{config_prefix}_LLM_MAX_TOKENS", 1024)
|
|
||||||
self.client = httpx.AsyncClient(timeout=timeout)
|
|
||||||
|
|
||||||
# Use a tokenizer that approximates OpenAI token counting
|
|
||||||
tokenizer_name = getattr(settings, f"{config_prefix}_TOKENIZER", "gpt2")
|
|
||||||
try:
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
|
||||||
except Exception:
|
|
||||||
logger.debug(
|
|
||||||
f"Failed to load tokenizer '{tokenizer_name}', falling back to default 'gpt2' tokenizer"
|
|
||||||
)
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
|
||||||
|
|
||||||
async def generate(
|
|
||||||
self, prompt: str, gen_schema=None, gen_cfg=None, logger=None
|
|
||||||
) -> str:
|
|
||||||
if logger:
|
|
||||||
logger.debug(
|
|
||||||
"OpenAI LLM generate",
|
|
||||||
prompt=repr(prompt[:100] + "..." if len(prompt) > 100 else prompt),
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": prompt}]
|
|
||||||
result = await self.completion(
|
|
||||||
messages, gen_schema=gen_schema, gen_cfg=gen_cfg, logger=logger
|
|
||||||
)
|
|
||||||
return result["choices"][0]["message"]["content"]
|
|
||||||
|
|
||||||
async def completion(
|
|
||||||
self, messages: list, gen_schema=None, gen_cfg=None, logger=None, **kwargs
|
|
||||||
) -> dict:
|
|
||||||
if logger:
|
|
||||||
logger.info("OpenAI LLM completion", messages_count=len(messages))
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": self.temperature,
|
|
||||||
"max_tokens": self.max_tokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Apply generation config overrides
|
|
||||||
if gen_cfg:
|
|
||||||
apply_gen_config(payload, gen_cfg)
|
|
||||||
|
|
||||||
# Apply structured output schema
|
|
||||||
if gen_schema:
|
|
||||||
payload["response_format"] = {
|
|
||||||
"type": "json_schema",
|
|
||||||
"json_schema": {"name": "response", "schema": gen_schema},
|
|
||||||
}
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
}
|
|
||||||
|
|
||||||
url = f"{self.url.rstrip('/')}/chat/completions"
|
|
||||||
|
|
||||||
if logger:
|
|
||||||
logger.debug(
|
|
||||||
"OpenAI API request", url=url, payload_keys=list(payload.keys())
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await self.client.post(url, json=payload, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
if logger:
|
|
||||||
logger.debug(
|
|
||||||
"OpenAI API response",
|
|
||||||
status_code=response.status_code,
|
|
||||||
choices_count=len(result.get("choices", [])),
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
await self.client.aclose()
|
|
||||||
@@ -12,15 +12,9 @@ from textwrap import dedent
|
|||||||
from typing import Type, TypeVar
|
from typing import Type, TypeVar
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from llama_index.core import Settings
|
|
||||||
from llama_index.core.output_parsers import PydanticOutputParser
|
|
||||||
from llama_index.core.program import LLMTextCompletionProgram
|
|
||||||
from llama_index.core.response_synthesizers import TreeSummarize
|
|
||||||
from llama_index.llms.openai_like import OpenAILike
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from reflector.llm.base import LLM
|
from reflector.llm import LLM
|
||||||
from reflector.llm.openai_llm import OpenAILLM
|
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
@@ -168,23 +162,12 @@ class SummaryBuilder:
         self.summaries: list[dict[str, str]] = []
         self.subjects: list[str] = []
         self.transcription_type: TranscriptionType | None = None
-        self.llm_instance: LLM = llm
+        self.llm: LLM = llm
         self.model_name: str = llm.model_name
         self.logger = logger or structlog.get_logger()
         if filename:
             self.read_transcript_from_file(filename)
 
-        Settings.llm = OpenAILike(
-            model=llm.model_name,
-            api_base=llm.url,
-            api_key=llm.api_key,
-            context_window=settings.SUMMARY_LLM_CONTEXT_SIZE_TOKENS,
-            is_chat_model=True,
-            is_function_calling_model=llm.has_structured_output,
-            temperature=llm.temperature,
-            max_tokens=llm.max_tokens,
-        )
-
     def read_transcript_from_file(self, filename: str) -> None:
         """
         Load a transcript from a text file.
@@ -202,40 +185,16 @@ class SummaryBuilder:
         self.transcript = transcript
 
     def set_llm_instance(self, llm: LLM) -> None:
-        self.llm_instance = llm
+        self.llm = llm
 
     async def _get_structured_response(
         self, prompt: str, output_cls: Type[T], tone_name: str | None = None
-    ) -> Type[T]:
+    ) -> T:
         """Generic function to get structured output from LLM for non-function-calling models."""
-        # First, use TreeSummarize to get the response
-        summarizer = TreeSummarize(verbose=True)
-
-        response = await summarizer.aget_response(
-            prompt, [self.transcript], tone_name=tone_name
-        )
-
-        # Then, use PydanticOutputParser to structure the response
-        output_parser = PydanticOutputParser(output_cls)
-
-        prompt_template_str = STRUCTURED_RESPONSE_PROMPT_TEMPLATE
-
-        program = LLMTextCompletionProgram.from_defaults(
-            output_parser=output_parser,
-            prompt_template_str=prompt_template_str,
-            verbose=False,
-        )
-
-        format_instructions = output_parser.format(
-            "Please structure the above information in the following JSON format:"
-        )
-
-        output = await program.acall(
-            analysis=str(response), format_instructions=format_instructions
-        )
-
-        return output
+        return await self.llm.get_structured_response(
+            prompt, [self.transcript], output_cls, tone_name=tone_name
+        )
 
     # ----------------------------------------------------------------------------
     # Participants
     # ----------------------------------------------------------------------------
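Note: the TreeSummarize + PydanticOutputParser pipeline deleted above is, per the commit message, consolidated into the single reflector.llm.LLM class rather than dropped. The call site above fixes the method name and signature; the body below is only a hedged sketch of what that helper could look like if the deleted logic moved over mostly unchanged, not the actual implementation.

# Hedged sketch of LLM.get_structured_response; the internals are an assumption.
from typing import Type, TypeVar

from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.response_synthesizers import TreeSummarize
from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)

# Placeholder; the real template lives in the reflector codebase.
STRUCTURED_RESPONSE_PROMPT_TEMPLATE = "{analysis}\n\n{format_instructions}"


class LLM:
    async def get_structured_response(
        self,
        prompt: str,
        texts: list[str],
        output_cls: Type[T],
        tone_name: str | None = None,
    ) -> T:
        # First pass: free-form answer over the input texts
        summarizer = TreeSummarize(verbose=False)
        response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)

        # Second pass: coerce the free-form answer into the requested Pydantic model
        output_parser = PydanticOutputParser(output_cls)
        program = LLMTextCompletionProgram.from_defaults(
            output_parser=output_parser,
            prompt_template_str=STRUCTURED_RESPONSE_PROMPT_TEMPLATE,
            verbose=False,
        )
        format_instructions = output_parser.format(
            "Please structure the above information in the following JSON format:"
        )
        return await program.acall(
            analysis=str(response), format_instructions=format_instructions
        )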
@@ -354,19 +313,18 @@ class SummaryBuilder:
     async def generate_subject_summaries(self) -> None:
         """Generate detailed summaries for each extracted subject."""
         assert self.transcript is not None
-        summarizer = TreeSummarize(verbose=False)
         summaries = []
 
         for subject in self.subjects:
             detailed_prompt = DETAILED_SUBJECT_PROMPT_TEMPLATE.format(subject=subject)
 
-            detailed_response = await summarizer.aget_response(
+            detailed_response = await self.llm.get_response(
                 detailed_prompt, [self.transcript], tone_name="Topic assistant"
             )
 
             paragraph_prompt = PARAGRAPH_SUMMARY_PROMPT
 
-            paragraph_response = await summarizer.aget_response(
+            paragraph_response = await self.llm.get_response(
                 paragraph_prompt, [str(detailed_response)], tone_name="Topic summarizer"
             )
 
@@ -377,7 +335,6 @@ class SummaryBuilder:
 
     async def generate_recap(self) -> None:
         """Generate a quick recap from the subject summaries."""
-        summarizer = TreeSummarize(verbose=True)
 
         summaries_text = "\n\n".join(
             [
@@ -388,7 +345,7 @@ class SummaryBuilder:
 
         recap_prompt = RECAP_PROMPT
 
-        recap_response = await summarizer.aget_response(
+        recap_response = await self.llm.get_response(
             recap_prompt, [summaries_text], tone_name="Recap summarizer"
         )
 
@@ -483,7 +440,7 @@ if __name__ == "__main__":
     async def main():
         # build the summary
 
-        llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
+        llm = LLM(settings=settings)
        sm = SummaryBuilder(llm=llm, filename=args.transcript)
 
         if args.subjects:
@@ -1,4 +1,4 @@
-from reflector.llm.openai_llm import OpenAILLM
+from reflector.llm import LLM
 from reflector.processors.base import Processor
 from reflector.processors.summary.summary_builder import SummaryBuilder
 from reflector.processors.types import FinalLongSummary, FinalShortSummary, TitleSummary
@@ -17,7 +17,7 @@ class TranscriptFinalSummaryProcessor(Processor):
         super().__init__(**kwargs)
         self.transcript = transcript
         self.chunks: list[TitleSummary] = []
-        self.llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
+        self.llm = LLM(settings=settings)
         self.builder = None
 
     async def _push(self, data: TitleSummary):
@@ -1,67 +1,72 @@
-from reflector.llm import LLM, LLMTaskParams
+from textwrap import dedent
+
+from reflector.llm import LLM
 from reflector.processors.base import Processor
 from reflector.processors.types import FinalTitle, TitleSummary
+from reflector.settings import settings
+from reflector.utils.text import clean_title
+
+TITLE_PROMPT = dedent(
+    """
+    Generate a concise title for this meeting based on the following topic titles.
+    Ignore casual conversation, greetings, or administrative matters.
+
+    The title must:
+    - Be maximum 10 words
+    - Use noun phrases when possible (e.g., "Q1 Budget Review" not "Reviewing the Q1 Budget")
+    - Avoid generic terms like "Team Meeting" or "Discussion"
+
+    If multiple unrelated topics were discussed, prioritize the most significant one.
+    or create a compound title (e.g., "Product Launch and Budget Planning").
+
+    <topics_discussed>
+    {titles}
+    </topics_discussed>
+
+    Do not explain, just output the meeting title as a single line.
+    """
+).strip()
 
 
 class TranscriptFinalTitleProcessor(Processor):
     """
-    Assemble all summary into a line-based json
+    Generate a final title from topic titles using LlamaIndex
     """
 
     INPUT_TYPE = TitleSummary
     OUTPUT_TYPE = FinalTitle
-    TASK = "final_title"
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.chunks: list[TitleSummary] = []
-        self.llm = LLM.get_instance()
-        self.params = LLMTaskParams.get_instance(self.TASK).task_params
+        self.llm = LLM(settings=settings, temperature=0.5, max_tokens=200)
 
     async def _push(self, data: TitleSummary):
         self.chunks.append(data)
 
-    async def get_title(self, text: str) -> dict:
+    async def get_title(self, accumulated_titles: str) -> str:
         """
-        Generate a title for the whole recording
+        Generate a title for the whole recording using LLM
         """
-        chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
-
-        if len(chunks) == 1:
-            chunk = chunks[0]
-            prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
-            title_result = await self.llm.generate(
-                prompt=prompt,
-                gen_schema=self.params.gen_schema,
-                gen_cfg=self.params.gen_cfg,
-                logger=self.logger,
-            )
-            return title_result
-        else:
-            accumulated_titles = ""
-            for chunk in chunks:
-                prompt = self.llm.create_prompt(
-                    instruct=self.params.instruct, text=chunk
-                )
-                title_result = await self.llm.generate(
-                    prompt=prompt,
-                    gen_schema=self.params.gen_schema,
-                    gen_cfg=self.params.gen_cfg,
-                    logger=self.logger,
-                )
-                accumulated_titles += title_result["title"]
-
-            return await self.get_title(accumulated_titles)
+        prompt = TITLE_PROMPT.format(titles=accumulated_titles)
+        response = await self.llm.get_response(
+            prompt,
+            [accumulated_titles],
+            tone_name="Title generator",
+        )
+
+        self.logger.info(f"Generated title response: {response}")
+
+        return response
 
     async def _flush(self):
         if not self.chunks:
             self.logger.warning("No summary to output")
             return
 
-        accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
-        title_result = await self.get_title(accumulated_titles)
-        final_title = self.llm.trim_title(title_result["title"])
-        final_title = self.llm.ensure_casing(final_title)
-
-        final_title = FinalTitle(title=final_title)
+        accumulated_titles = "\n".join([f"- {chunk.title}" for chunk in self.chunks])
+        title = await self.get_title(accumulated_titles)
+        title = clean_title(title)
+
+        final_title = FinalTitle(title=title)
         await self.emit(final_title)
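For reference, a rough end-to-end sketch of the rewritten title flow above: accumulate the per-topic titles as a bullet list, format TITLE_PROMPT, ask the LLM wrapper for a single line, then normalize it with clean_title. The topic titles and the raw model output here are invented for illustration; only TITLE_PROMPT, LLM.get_response and clean_title come from the diff.

# Illustrative only; assumes it runs in the same module as TITLE_PROMPT above.
from reflector.utils.text import clean_title

topic_titles = ["Q1 budget review", "Hiring plan for the platform team"]  # made up
accumulated_titles = "\n".join(f"- {t}" for t in topic_titles)
prompt = TITLE_PROMPT.format(titles=accumulated_titles)

# response = await self.llm.get_response(prompt, [accumulated_titles], tone_name="Title generator")
response = '"q1 budget review and hiring plan"'  # hypothetical raw model output
print(clean_title(response))  # Q1 Budget Review and Hiring Plan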
@@ -1,7 +1,41 @@
-from reflector.llm import LLM, LLMTaskParams
+from textwrap import dedent
+
+from pydantic import BaseModel, Field
+
+from reflector.llm import LLM
 from reflector.processors.base import Processor
 from reflector.processors.types import TitleSummary, Transcript
 from reflector.settings import settings
+from reflector.utils.text import clean_title
+
+TOPIC_PROMPT = dedent(
+    """
+    Analyze the following transcript segment and extract the main topic being discussed.
+    Focus on the substantive content and ignore small talk or administrative chatter.
+
+    Create a title that:
+    - Captures the specific subject matter being discussed
+    - Is descriptive and self-explanatory
+    - Uses professional language
+    - Is specific rather than generic
+
+    For the summary:
+    - Summarize the key points in maximum two sentences
+    - Focus on what was discussed, decided, or accomplished
+    - Be concise but informative
+
+    <transcript>
+    {text}
+    </transcript>
+    """
+).strip()
+
+
+class TopicResponse(BaseModel):
+    """Structured response for topic detection"""
+
+    title: str = Field(description="A descriptive title for the topic being discussed")
+    summary: str = Field(description="A concise 1-2 sentence summary of the discussion")
 
 
 class TranscriptTopicDetectorProcessor(Processor):
@@ -11,7 +45,6 @@ class TranscriptTopicDetectorProcessor(Processor):
 
     INPUT_TYPE = Transcript
     OUTPUT_TYPE = TitleSummary
-    TASK = "topic"
 
     def __init__(
         self, min_transcript_length: int = int(settings.MIN_TRANSCRIPT_LENGTH), **kwargs
@@ -19,8 +52,7 @@ class TranscriptTopicDetectorProcessor(Processor):
         super().__init__(**kwargs)
         self.transcript = None
         self.min_transcript_length = min_transcript_length
-        self.llm = LLM.get_instance()
-        self.params = LLMTaskParams.get_instance(self.TASK).task_params
+        self.llm = LLM(settings=settings, temperature=0.9, max_tokens=500)
 
     async def _push(self, data: Transcript):
         if self.transcript is None:
@@ -34,18 +66,15 @@ class TranscriptTopicDetectorProcessor(Processor):
             return
         await self.flush()
 
-    async def get_topic(self, text: str) -> dict:
+    async def get_topic(self, text: str) -> TopicResponse:
         """
-        Generate a topic and description for a transcription excerpt
+        Generate a topic and description for a transcription excerpt using LLM
         """
-        prompt = self.llm.create_prompt(instruct=self.params.instruct, text=text)
-        topic_result = await self.llm.generate(
-            prompt=prompt,
-            gen_schema=self.params.gen_schema,
-            gen_cfg=self.params.gen_cfg,
-            logger=self.logger,
+        prompt = TOPIC_PROMPT.format(text=text)
+        response = await self.llm.get_structured_response(
+            prompt, [text], TopicResponse, tone_name="Topic analyzer"
         )
-        return topic_result
+        return response
 
     async def _flush(self):
         if not self.transcript:
@@ -53,13 +82,13 @@ class TranscriptTopicDetectorProcessor(Processor):
 
         text = self.transcript.text
         self.logger.info(f"Topic detector got {len(text)} length transcript")
 
         topic_result = await self.get_topic(text=text)
-        title = self.llm.trim_title(topic_result["title"])
-        title = self.llm.ensure_casing(title)
+        title = clean_title(topic_result.title)
 
         summary = TitleSummary(
             title=title,
-            summary=topic_result["summary"],
+            summary=topic_result.summary,
             timestamp=self.transcript.timestamp,
             duration=self.transcript.duration,
             transcript=self.transcript,
@@ -13,14 +13,13 @@ class TranscriptTranslatorProcessor(Processor):
 
     INPUT_TYPE = Transcript
     OUTPUT_TYPE = Transcript
-    TASK = "translate"
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.transcript = None
         self.translate_url = settings.TRANSLATE_URL
         self.timeout = settings.TRANSLATE_TIMEOUT
-        self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
+        self.headers = {"Authorization": f"Bearer {settings.TRANSCRIPT_MODAL_API_KEY}"}
 
     async def _push(self, data: Transcript):
         self.transcript = data
@@ -9,13 +9,14 @@ class Settings(BaseSettings):
     )
 
     # CORS
+    UI_BASE_URL: str = "http://localhost:3000"
     CORS_ORIGIN: str = "*"
     CORS_ALLOW_CREDENTIALS: bool = False
 
     # Database
     DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
 
-    # local data directory (audio for no)
+    # local data directory
     DATA_DIR: str = "./data"
 
     # Audio Transcription
@@ -24,10 +25,6 @@ class Settings(BaseSettings):
     TRANSCRIPT_URL: str | None = None
     TRANSCRIPT_TIMEOUT: int = 90
 
-    # Translate into the target language
-    TRANSLATE_URL: str | None = None
-    TRANSLATE_TIMEOUT: int = 90
-
     # Audio transcription modal.com configuration
     TRANSCRIPT_MODAL_API_KEY: str | None = None
 
@@ -40,31 +37,15 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
 
+    # Translate into the target language
+    TRANSLATE_URL: str | None = None
+    TRANSLATE_TIMEOUT: int = 90
+
     # LLM
-    # available backend: openai, modal
-    LLM_BACKEND: str = "modal"
-
-    # LLM common configuration
+    LLM_MODEL: str = "microsoft/phi-4"
     LLM_URL: str | None = None
-    LLM_HOST: str = "localhost"
-    LLM_PORT: int = 7860
-    LLM_OPENAI_KEY: str | None = None
-    LLM_OPENAI_MODEL: str = "gpt-3.5-turbo"
-    LLM_OPENAI_TEMPERATURE: float = 0.7
-    LLM_TIMEOUT: int = 60 * 5 # take cold start into account
-    LLM_MAX_TOKENS: int = 1024
-    LLM_TEMPERATURE: float = 0.7
-    ZEPHYR_LLM_URL: str | None = None
-    HERMES_3_8B_LLM_URL: str | None = None
-
-    # LLM Modal configuration
-    LLM_MODAL_API_KEY: str | None = None
-
-    # per-task cases
-    SUMMARY_MODEL: str = "monadical/private/smart"
-    SUMMARY_LLM_URL: str | None = None
-    SUMMARY_LLM_API_KEY: str | None = None
-    SUMMARY_LLM_CONTEXT_SIZE_TOKENS: int = 16000
+    LLM_API_KEY: str | None = None
+    LLM_CONTEXT_WINDOW: int = 16000
 
     # Diarization
     DIARIZATION_ENABLED: bool = True
@@ -86,12 +67,6 @@ class Settings(BaseSettings):
     # if set, all anonymous record will be public
     PUBLIC_MODE: bool = False
 
-    # Default LLM model name
-    DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5"
-
-    # Cache directory for all model storage
-    CACHE_DIR: str = "./data"
-
     # Min transcript length to generate topic + summary
     MIN_TRANSCRIPT_LENGTH: int = 750
 
@@ -116,24 +91,20 @@ class Settings(BaseSettings):
     # Healthcheck
     HEALTHCHECK_URL: str | None = None
 
-    AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
-    SQS_POLLING_TIMEOUT_SECONDS: int = 60
-
+    # Whereby integration
     WHEREBY_API_URL: str = "https://api.whereby.dev/v1"
 
     WHEREBY_API_KEY: str | None = None
+    WHEREBY_WEBHOOK_SECRET: str | None = None
     AWS_WHEREBY_S3_BUCKET: str | None = None
     AWS_WHEREBY_ACCESS_KEY_ID: str | None = None
     AWS_WHEREBY_ACCESS_KEY_SECRET: str | None = None
+    AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
+    SQS_POLLING_TIMEOUT_SECONDS: int = 60
 
+    # Zulip integration
     ZULIP_REALM: str | None = None
     ZULIP_API_KEY: str | None = None
     ZULIP_BOT_EMAIL: str | None = None
 
-    UI_BASE_URL: str = "http://localhost:3000"
-
-    WHEREBY_WEBHOOK_SECRET: str | None = None
-
 
 settings = Settings()
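The per-task SUMMARY_* knobs and backend-specific options above collapse into four generic settings (LLM_MODEL, LLM_URL, LLM_API_KEY, LLM_CONTEXT_WINDOW). Below is a hedged sketch of how these could be wired into the LlamaIndex client inside reflector.llm.LLM, mirroring the OpenAILike block removed from SummaryBuilder earlier in this diff; the temperature and max_tokens values are illustrative (the processors above override them per task), and this is not the actual code.

from llama_index.llms.openai_like import OpenAILike

from reflector.settings import settings

llamaindex_llm = OpenAILike(
    model=settings.LLM_MODEL,  # defaults to "microsoft/phi-4"
    api_base=settings.LLM_URL,
    api_key=settings.LLM_API_KEY,
    context_window=settings.LLM_CONTEXT_WINDOW,
    is_chat_model=True,
    temperature=0.5,  # e.g. 0.5 for titles, 0.9 for topics, per the processors above
    max_tokens=200,
)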
server/reflector/utils/text.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+def clean_title(title: str) -> str:
+    """
+    Clean and format a title string for consistent capitalization.
+
+    Rules:
+    - Strip surrounding quotes (single or double)
+    - Capitalize the first word
+    - Capitalize words longer than 3 characters
+    - Keep words with 3 or fewer characters lowercase (except first word)
+
+    Args:
+        title: The title string to clean
+
+    Returns:
+        The cleaned title with consistent capitalization
+
+    Examples:
+        >>> clean_title("hello world")
+        "Hello World"
+        >>> clean_title("meeting with the team")
+        "Meeting With the Team"
+        >>> clean_title("'Title with quotes'")
+        "Title With Quotes"
+    """
+    title = title.strip("\"'")
+    words = title.split()
+    if words:
+        words = [
+            word.capitalize() if i == 0 or len(word) > 3 else word.lower()
+            for i, word in enumerate(words)
+        ]
+        title = " ".join(words)
+    return title
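A quick illustration of the length rule in clean_title: words of three characters or fewer, including acronyms, stay lowercase unless they start the title. This matches the parametrized tests added at the end of this diff.

from reflector.utils.text import clean_title

assert clean_title("discussion about API design") == "Discussion About api Design"
assert clean_title("'meeting with the team'") == "Meeting With the Team"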
@@ -37,8 +37,12 @@ def dummy_processors():
             "reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation"
         ) as mock_translate,
     ):
-        mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
-        mock_title.return_value = {"title": "LLM TITLE"}
+        from reflector.processors.transcript_topic_detector import TopicResponse
+
+        mock_topic.return_value = TopicResponse(
+            title="LLM TITLE", summary="LLM SUMMARY"
+        )
+        mock_title.return_value = "LLM Title"
         mock_long_summary.return_value = "LLM LONG SUMMARY"
         mock_short_summary.return_value = "LLM SHORT SUMMARY"
         mock_translate.return_value = "Bonjour le monde"
@@ -103,14 +107,15 @@ async def dummy_diarization():
 
 @pytest.fixture
 async def dummy_llm():
-    from reflector.llm.base import LLM
+    from reflector.llm import LLM
 
     class TestLLM(LLM):
         def __init__(self):
             self.model_name = "DUMMY MODEL"
             self.llm_tokenizer = "DUMMY TOKENIZER"
 
-    with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
+    # LLM doesn't have get_instance anymore, mocking constructor instead
+    with patch("reflector.llm.LLM") as mock_llm:
         mock_llm.return_value = TestLLM()
         yield
 
@@ -129,22 +134,19 @@ async def dummy_storage():
         async def _get_file_url(self, *args, **kwargs):
             return "http://fake_server/audio.mp3"
 
-    with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
-        mock_storage.return_value = DummyStorage()
-        yield
-
-
-@pytest.fixture
-def nltk():
-    with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
-        mock_nltk.return_value = "NLTK PACKAGE"
-        yield
-
-
-@pytest.fixture
-def ensure_casing():
-    with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
-        mock_casing.return_value = "LLM TITLE"
+        async def _get_file(self, *args, **kwargs):
+            from pathlib import Path
+
+            test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
+            return test_mp3.read_bytes()
+
+    dummy = DummyStorage()
+    with (
+        patch("reflector.storage.base.Storage.get_instance") as mock_storage,
+        patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
+    ):
+        mock_storage.return_value = dummy
+        mock_get_transcripts.return_value = dummy
         yield
 
 
@@ -2,7 +2,7 @@ import pytest
 
 
 @pytest.mark.asyncio
-async def test_processor_broadcast(nltk):
+async def test_processor_broadcast():
     from reflector.processors.base import BroadcastProcessor, Pipeline, Processor
 
     class TestProcessor(Processor):
@@ -3,11 +3,9 @@ import pytest
 
 @pytest.mark.asyncio
 async def test_basic_process(
-    nltk,
     dummy_transcript,
     dummy_llm,
     dummy_processors,
-    ensure_casing,
 ):
     # goal is to start the server, and send rtc audio to it
     # validate the events received
@@ -16,8 +14,8 @@ async def test_basic_process(
     from reflector.settings import settings
     from reflector.tools.process import process_audio_file
 
-    # use an LLM test backend
-    settings.LLM_BACKEND = "test"
+    # LLM_BACKEND no longer exists in settings
+    # settings.LLM_BACKEND = "test"
     settings.TRANSCRIPT_BACKEND = "whisper"
 
     # event callback
@@ -10,7 +10,6 @@ from httpx import AsyncClient
 @pytest.mark.asyncio
 async def test_transcript_process(
     tmpdir,
-    ensure_casing,
     dummy_llm,
     dummy_processors,
     dummy_diarization,
@@ -69,7 +68,7 @@ async def test_transcript_process(
     transcript = resp.json()
     assert transcript["status"] == "ended"
     assert transcript["short_summary"] == "LLM SHORT SUMMARY"
-    assert transcript["title"] == "LLM TITLE"
+    assert transcript["title"] == "Llm Title"
 
     # check topics and transcript
     response = await ac.get(f"/transcripts/{tid}/topics")
@@ -69,8 +69,6 @@ async def test_transcript_rtc_and_websocket(
     dummy_diarization,
     dummy_storage,
     fake_mp3_upload,
-    ensure_casing,
-    nltk,
     appserver,
 ):
     # goal: start the server, exchange RTC, receive websocket events
@@ -185,7 +183,7 @@ async def test_transcript_rtc_and_websocket(
 
     assert "FINAL_TITLE" in eventnames
     ev = events[eventnames.index("FINAL_TITLE")]
-    assert ev["data"]["title"] == "LLM TITLE"
+    assert ev["data"]["title"] == "Llm Title"
 
     assert "WAVEFORM" in eventnames
     ev = events[eventnames.index("WAVEFORM")]
@@ -228,8 +226,6 @@ async def test_transcript_rtc_and_websocket_and_fr(
     dummy_diarization,
     dummy_storage,
     fake_mp3_upload,
-    ensure_casing,
-    nltk,
     appserver,
 ):
     # goal: start the server, exchange RTC, receive websocket events
@@ -353,7 +349,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
 
     assert "FINAL_TITLE" in eventnames
     ev = events[eventnames.index("FINAL_TITLE")]
-    assert ev["data"]["title"] == "LLM TITLE"
+    assert ev["data"]["title"] == "Llm Title"
 
     # check status order
     statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
@@ -10,7 +10,6 @@ from httpx import AsyncClient
 @pytest.mark.asyncio
 async def test_transcript_upload_file(
     tmpdir,
-    ensure_casing,
     dummy_llm,
     dummy_processors,
     dummy_diarization,
@@ -53,7 +52,7 @@ async def test_transcript_upload_file(
     transcript = resp.json()
     assert transcript["status"] == "ended"
     assert transcript["short_summary"] == "LLM SHORT SUMMARY"
-    assert transcript["title"] == "LLM TITLE"
+    assert transcript["title"] == "Llm Title"
 
     # check topics and transcript
     response = await ac.get(f"/transcripts/{tid}/topics")
server/tests/test_utils_text.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+import pytest
+
+from reflector.utils.text import clean_title
+
+
+@pytest.mark.parametrize(
+    "input_title,expected",
+    [
+        ("hello world", "Hello World"),
+        ("HELLO WORLD", "Hello World"),
+        ("hello WORLD", "Hello World"),
+        ("the quick brown fox", "The Quick Brown fox"),
+        ("discussion about API design", "Discussion About api Design"),
+        ("Q1 2024 budget review", "Q1 2024 Budget Review"),
+        ("'Title with quotes'", "Title With Quotes"),
+        ("'title with quotes'", "Title With Quotes"),
+        ("MiXeD CaSe WoRdS", "Mixed Case Words"),
+    ],
+)
+def test_clean_title(input_title, expected):
+    assert clean_title(input_title) == expected