mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
feat: use llamaindex everywhere (#525)
* feat: use llamaindex for transcript final title too * refactor: removed llm backend, replaced with one single class+llamaindex * refactor: self-review * fix: typing * fix: tests * refactor: extract clean_title and add tests * test: fix * test: remove ensure_casing/nltk * fix: tiny mistake
This commit is contained in:
@@ -46,38 +46,11 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
|
||||
## llm backend implementation
|
||||
## =======================================================
|
||||
|
||||
## Using serverless modal.com (require reflector-gpu-modal deployed)
|
||||
LLM_BACKEND=modal
|
||||
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
|
||||
LLM_MODAL_API_KEY=
|
||||
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
|
||||
|
||||
|
||||
## Using OpenAI
|
||||
#LLM_BACKEND=openai
|
||||
#LLM_OPENAI_KEY=xxx
|
||||
#LLM_OPENAI_MODEL=gpt-3.5-turbo
|
||||
|
||||
## Using GPT4ALL
|
||||
#LLM_BACKEND=openai
|
||||
#LLM_URL=http://localhost:4891/v1/completions
|
||||
#LLM_OPENAI_MODEL="GPT4All Falcon"
|
||||
|
||||
## Default LLM MODEL NAME
|
||||
#DEFAULT_LLM=lmsys/vicuna-13b-v1.5
|
||||
|
||||
## Cache directory to store models
|
||||
CACHE_DIR=data
|
||||
|
||||
## =======================================================
|
||||
## Summary LLM configuration
|
||||
## =======================================================
|
||||
|
||||
## Context size for summary generation (tokens)
|
||||
SUMMARY_LLM_CONTEXT_SIZE_TOKENS=16000
|
||||
SUMMARY_LLM_URL=
|
||||
SUMMARY_LLM_API_KEY=sk-
|
||||
SUMMARY_MODEL=
|
||||
# LLM_MODEL=microsoft/phi-4
|
||||
LLM_CONTEXT_WINDOW=16000
|
||||
LLM_URL=
|
||||
LLM_API_KEY=sk-
|
||||
|
||||
## =======================================================
|
||||
## Diarization
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
This repository hold an API for the GPU implementation of the Reflector API service,
|
||||
and use [Modal.com](https://modal.com)
|
||||
|
||||
- `reflector_llm.py` - LLM API
|
||||
- `reflector_diarizer.py` - Diarization API
|
||||
- `reflector_transcriber.py` - Transcription API
|
||||
- `reflector_translator.py` - Translation API
|
||||
|
||||
## Modal.com deployment
|
||||
|
||||
|
||||
@@ -1,213 +0,0 @@
|
||||
"""
|
||||
Reflector GPU backend - LLM
|
||||
===========================
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from modal import App, Image, Secret, asgi_app, enter, exit, method
|
||||
|
||||
# LLM
# Hugging Face model id: downloaded at image build time and loaded by the LLM class.
LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
# Forwarded as `low_cpu_mem_usage` to AutoModelForCausalLM.from_pretrained.
LLM_LOW_CPU_MEM_USAGE: bool = True
# Name of a torch dtype attribute; resolved with getattr(torch, ...) when loading.
LLM_TORCH_DTYPE: str = "bfloat16"
# Default `max_new_tokens` written to the model config and the generation config.
LLM_MAX_NEW_TOKENS: int = 300

# Used as `cache_dir` both for the build-time download and the runtime load.
IMAGE_MODEL_DIR = "/root/llm_models"

app = App(name="reflector-llm")
|
||||
|
||||
|
||||
def download_llm():
    """Fetch the ``LLM_MODEL`` snapshot into ``IMAGE_MODEL_DIR``.

    Registered on the image with ``run_function`` so the download happens
    while the image is built.
    """
    import huggingface_hub

    print("Downloading LLM model")
    huggingface_hub.snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
    print("LLM model downloaded")
|
||||
|
||||
|
||||
def migrate_cache_llm():
    """Run the Transformers cache migration in place on ``IMAGE_MODEL_DIR``.

    Transformers v4.22.0 changed the layout of its model cache and performs a
    one-time migration of old caches. Calling ``move_cache`` here, with the
    same directory as source and destination, performs it during the image
    build.
    """
    from transformers.utils import hub as transformers_hub

    print("Moving LLM cache")
    transformers_hub.move_cache(
        cache_dir=IMAGE_MODEL_DIR,
        new_cache_dir=IMAGE_MODEL_DIR,
    )
    print("LLM cache moved")
|
||||
|
||||
|
||||
# Python packages installed into the image (versions as pinned below).
_LLM_PIP_PACKAGES = (
    "transformers",
    "torch",
    "sentencepiece",
    "protobuf",
    "jsonformer==0.12.0",
    "accelerate==0.21.0",
    "einops==0.6.1",
    "hf-transfer~=0.1",
    "huggingface_hub==0.16.4",
)

# Container image: Debian slim + git + Python deps, with the model weights
# downloaded and the cache migrated as build steps.
llm_image = (
    Image.debian_slim(python_version="3.10.8")
    .apt_install("git")
    .pip_install(*_LLM_PIP_PACKAGES)
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_function(download_llm)
    .run_function(migrate_cache_llm)
)
|
||||
|
||||
|
||||
@app.cls(
    gpu="A100",
    timeout=60 * 5,
    scaledown_window=60 * 5,
    allow_concurrent_inputs=15,
    image=llm_image,
)
class LLM:
    """Modal class that loads ``LLM_MODEL`` onto the GPU and serves generation."""

    @enter()
    def enter(self):
        """Load model, tokenizer and default generation config from the image cache."""
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

        print("Instance llm model")
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL,
            torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
            low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
            cache_dir=IMAGE_MODEL_DIR,
            local_files_only=True,
        )

        # JSONFormer doesn't yet support generation configs
        print("Instance llm generation config")
        model.config.max_new_tokens = LLM_MAX_NEW_TOKENS

        # generation configuration
        gen_cfg = GenerationConfig.from_model_config(model.config)
        gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS

        # load tokenizer
        print("Instance llm tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
        )

        # move model to gpu
        print("Move llm model to GPU")
        model = model.cuda()

        print("Warmup llm done")
        self.model = model
        self.tokenizer = tokenizer
        self.gen_cfg = gen_cfg
        # Kept on the instance so generate() can build configs without
        # importing transformers again.
        self.GenerationConfig = GenerationConfig

        # generate() holds this lock around every model call, so concurrent
        # inputs on this container run one generation at a time.
        self.lock = threading.Lock()

    @exit()
    def exit(self):
        """Lifecycle hook run on container shutdown.

        The hook is invoked as a bound method, so it must accept ``self``;
        the previous zero-argument signature raised ``TypeError`` when called.
        """
        print("Exit llm")

    @method()
    def generate(
        self, prompt: str, gen_schema: str | None, gen_cfg: str | None
    ) -> dict:
        """
        Perform a generation action using the LLM

        Args:
            prompt: Text prompt given to the model.
            gen_schema: Optional JSON-encoded schema. When set, generation is
                delegated to Jsonformer so the output conforms to it.
            gen_cfg: Optional JSON-encoded dict used to build a
                ``GenerationConfig``; the default config is used otherwise.

        Returns:
            ``{"text": response}`` where ``response`` is the Jsonformer
            result, or the decoded output with the first ``len(prompt)``
            characters removed.
        """
        print(f"Generate {prompt=}")
        if gen_cfg:
            gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
        else:
            gen_cfg = self.gen_cfg

        # If a gen_schema is given, conform to gen_schema
        with self.lock:
            if gen_schema:
                import jsonformer

                print(f"Schema {gen_schema=}")
                jsonformer_llm = jsonformer.Jsonformer(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    json_schema=json.loads(gen_schema),
                    prompt=prompt,
                    max_string_token_length=gen_cfg.max_new_tokens,
                )
                response = jsonformer_llm()
            else:
                # If no gen_schema, perform prompt only generation

                # tokenize prompt
                input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
                    self.model.device
                )
                output = self.model.generate(input_ids, generation_config=gen_cfg)

                # decode output
                response = self.tokenizer.decode(
                    output[0].cpu(), skip_special_tokens=True
                )
                # NOTE(review): assumes the decoded text starts with the
                # prompt verbatim - confirm for the tokenizer in use.
                response = response[len(prompt) :]
        print(f"Generated {response=}")
        return {"text": response}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Web API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.function(
    scaledown_window=60 * 10,
    timeout=60 * 5,
    allow_concurrent_inputs=45,
    secrets=[
        Secret.from_name("reflector-gpu"),
    ],
)
@asgi_app()
def web():
    """Build the FastAPI application exposing ``POST /llm``.

    Requests must carry a bearer token equal to the
    ``REFLECTOR_GPU_APIKEY`` environment variable. The request body is
    forwarded to ``LLM.generate`` and its result is returned as-is.
    """
    import secrets

    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.security import OAuth2PasswordBearer
    from pydantic import BaseModel

    llmstub = LLM()

    app = FastAPI()
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

    def apikey_auth(apikey: str = Depends(oauth2_scheme)):
        """Reject the request with 401 unless the bearer token matches."""
        expected = os.environ["REFLECTOR_GPU_APIKEY"]
        # Constant-time comparison so response timing does not leak how much
        # of the key matched; bytes are used because compare_digest only
        # accepts ASCII-only str.
        if not secrets.compare_digest(apikey.encode(), expected.encode()):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key",
                headers={"WWW-Authenticate": "Bearer"},
            )

    class LLMRequest(BaseModel):
        prompt: str
        gen_schema: Optional[dict] = None
        gen_cfg: Optional[dict] = None

    @app.post("/llm", dependencies=[Depends(apikey_auth)])
    def llm(
        req: LLMRequest,
    ):
        # generate() takes the schema and config as JSON strings.
        gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
        gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
        func = llmstub.generate.spawn(
            prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
        )
        result = func.get()
        return result

    return app
|
||||
@@ -1,219 +0,0 @@
|
||||
"""
|
||||
Reflector GPU backend - LLM
|
||||
===========================
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from modal import App, Image, Secret, asgi_app, enter, exit, method
|
||||
|
||||
# LLM
# Hugging Face model id: downloaded at image build time and loaded by the LLM class.
LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-alpha"
# Forwarded as `low_cpu_mem_usage` to AutoModelForCausalLM.from_pretrained.
LLM_LOW_CPU_MEM_USAGE: bool = True
# Name of a torch dtype attribute; resolved with getattr(torch, ...) when loading.
LLM_TORCH_DTYPE: str = "bfloat16"
# Default `max_new_tokens` written to the model config and the generation config.
LLM_MAX_NEW_TOKENS: int = 300

# Used as `cache_dir` both for the build-time download and the runtime load.
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"

app = App(name="reflector-llm-zephyr")
|
||||
|
||||
|
||||
def download_llm():
    """Fetch the ``LLM_MODEL`` snapshot into ``IMAGE_MODEL_DIR``.

    Registered on the image with ``run_function`` so the download happens
    while the image is built.
    """
    import huggingface_hub

    print("Downloading LLM model")
    huggingface_hub.snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
    print("LLM model downloaded")
|
||||
|
||||
|
||||
def migrate_cache_llm():
    """Run the Transformers cache migration in place on ``IMAGE_MODEL_DIR``.

    Transformers v4.22.0 changed the layout of its model cache and performs a
    one-time migration of old caches. Calling ``move_cache`` here, with the
    same directory as source and destination, performs it during the image
    build.
    """
    from transformers.utils import hub as transformers_hub

    print("Moving LLM cache")
    transformers_hub.move_cache(
        cache_dir=IMAGE_MODEL_DIR,
        new_cache_dir=IMAGE_MODEL_DIR,
    )
    print("LLM cache moved")
|
||||
|
||||
|
||||
# Python packages installed into the image (versions as pinned below).
_LLM_PIP_PACKAGES = (
    "transformers==4.34.0",
    "torch",
    "sentencepiece",
    "protobuf",
    "jsonformer==0.12.0",
    "accelerate==0.21.0",
    "einops==0.6.1",
    "hf-transfer~=0.1",
    "huggingface_hub==0.16.4",
)

# Container image: Debian slim + git + Python deps, with the model weights
# downloaded and the cache migrated as build steps.
llm_image = (
    Image.debian_slim(python_version="3.10.8")
    .apt_install("git")
    .pip_install(*_LLM_PIP_PACKAGES)
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_function(download_llm)
    .run_function(migrate_cache_llm)
)
|
||||
|
||||
|
||||
@app.cls(
    gpu="A10G",
    timeout=60 * 5,
    scaledown_window=60 * 5,
    allow_concurrent_inputs=10,
    image=llm_image,
)
class LLM:
    """Modal class that loads ``LLM_MODEL`` onto the GPU and serves generation."""

    @enter()
    def enter(self):
        """Load model, tokenizer and default generation config from the image cache."""
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

        print("Instance llm model")
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL,
            torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
            low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
            cache_dir=IMAGE_MODEL_DIR,
            local_files_only=True,
        )

        # JSONFormer doesn't yet support generation configs
        print("Instance llm generation config")
        model.config.max_new_tokens = LLM_MAX_NEW_TOKENS

        # generation configuration
        gen_cfg = GenerationConfig.from_model_config(model.config)
        gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS

        # load tokenizer
        print("Instance llm tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
        )
        # Use the end-of-sequence token as padding everywhere (generation
        # config, tokenizer and model config).
        gen_cfg.pad_token_id = tokenizer.eos_token_id
        gen_cfg.eos_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id

        # move model to gpu
        print("Move llm model to GPU")
        model = model.cuda()

        print("Warmup llm done")
        self.model = model
        self.tokenizer = tokenizer
        self.gen_cfg = gen_cfg
        # Kept on the instance so generate() can build configs without
        # importing transformers again.
        self.GenerationConfig = GenerationConfig
        # generate() holds this lock around every model call, so concurrent
        # inputs on this container run one generation at a time.
        self.lock = threading.Lock()

    @exit()
    def exit(self):
        """Lifecycle hook run on container shutdown.

        The hook is invoked as a bound method, so it must accept ``self``;
        the previous zero-argument signature raised ``TypeError`` when called.
        """
        print("Exit llm")

    @method()
    def generate(
        self, prompt: str, gen_schema: str | None, gen_cfg: str | None
    ) -> dict:
        """
        Perform a generation action using the LLM

        Args:
            prompt: Text prompt given to the model.
            gen_schema: Optional JSON-encoded schema. When set, generation is
                delegated to Jsonformer so the output conforms to it.
            gen_cfg: Optional JSON-encoded dict used to build a
                ``GenerationConfig``; the default config is used otherwise.

        Returns:
            ``{"text": response}`` where ``response`` is the Jsonformer
            result, or ``{"long_summary": <generated text>}`` for prompt-only
            generation.
        """
        print(f"Generate {prompt=}")
        if gen_cfg:
            gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
            # Caller-supplied configs get the same pad/eos ids as the default.
            gen_cfg.pad_token_id = self.tokenizer.eos_token_id
            gen_cfg.eos_token_id = self.tokenizer.eos_token_id
        else:
            gen_cfg = self.gen_cfg

        # If a gen_schema is given, conform to gen_schema
        with self.lock:
            if gen_schema:
                import jsonformer

                print(f"Schema {gen_schema=}")
                jsonformer_llm = jsonformer.Jsonformer(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    json_schema=json.loads(gen_schema),
                    prompt=prompt,
                    max_string_token_length=gen_cfg.max_new_tokens,
                )
                response = jsonformer_llm()
            else:
                # If no gen_schema, perform prompt only generation

                # tokenize prompt
                input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
                    self.model.device
                )
                output = self.model.generate(input_ids, generation_config=gen_cfg)

                # decode output
                response = self.tokenizer.decode(
                    output[0].cpu(), skip_special_tokens=True
                )
                # NOTE(review): assumes the decoded text starts with the
                # prompt verbatim - confirm for the tokenizer in use.
                response = response[len(prompt) :]
                response = {"long_summary": response}
        print(f"Generated {response=}")
        return {"text": response}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Web API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.function(
    scaledown_window=60 * 10,
    timeout=60 * 5,
    allow_concurrent_inputs=30,
    secrets=[
        Secret.from_name("reflector-gpu"),
    ],
)
@asgi_app()
def web():
    """Build the FastAPI application exposing ``POST /llm``.

    Requests must carry a bearer token equal to the
    ``REFLECTOR_GPU_APIKEY`` environment variable. The request body is
    forwarded to ``LLM.generate`` and its result is returned as-is.
    """
    import secrets

    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.security import OAuth2PasswordBearer
    from pydantic import BaseModel

    llmstub = LLM()

    app = FastAPI()
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

    def apikey_auth(apikey: str = Depends(oauth2_scheme)):
        """Reject the request with 401 unless the bearer token matches."""
        expected = os.environ["REFLECTOR_GPU_APIKEY"]
        # Constant-time comparison so response timing does not leak how much
        # of the key matched; bytes are used because compare_digest only
        # accepts ASCII-only str.
        if not secrets.compare_digest(apikey.encode(), expected.encode()):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key",
                headers={"WWW-Authenticate": "Bearer"},
            )

    class LLMRequest(BaseModel):
        prompt: str
        gen_schema: Optional[dict] = None
        gen_cfg: Optional[dict] = None

    @app.post("/llm", dependencies=[Depends(apikey_auth)])
    def llm(
        req: LLMRequest,
    ):
        # generate() takes the schema and config as JSON strings.
        gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
        gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
        func = llmstub.generate.spawn(
            prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
        )
        result = func.get()
        return result

    return app
|
||||
83
server/reflector/llm.py
Normal file
83
server/reflector/llm.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Type, TypeVar
|
||||
|
||||
from llama_index.core import Settings
|
||||
from llama_index.core.output_parsers import PydanticOutputParser
|
||||
from llama_index.core.program import LLMTextCompletionProgram
|
||||
from llama_index.core.response_synthesizers import TreeSummarize
|
||||
from llama_index.llms.openai_like import OpenAILike
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
|
||||
Based on the following analysis, provide the information in the requested JSON format:
|
||||
|
||||
Analysis:
|
||||
{analysis}
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
|
||||
class LLM:
    """Wrapper around a llamaindex ``OpenAILike`` client.

    Connection details (model, URL, API key, context window) are read from
    the ``settings`` object; sampling parameters are per instance.
    """

    def __init__(self, settings, temperature: float = 0.4, max_tokens: int = 2048):
        """
        Args:
            settings: Object exposing ``LLM_MODEL``, ``LLM_URL``,
                ``LLM_API_KEY`` and ``LLM_CONTEXT_WINDOW``.
            temperature: Sampling temperature passed to the client.
            max_tokens: Maximum tokens passed to the client.
        """
        self.settings_obj = settings
        self.model_name = settings.LLM_MODEL
        self.url = settings.LLM_URL
        self.api_key = settings.LLM_API_KEY
        self.context_window = settings.LLM_CONTEXT_WINDOW
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Configure llamaindex Settings
        self._configure_llamaindex()

    def _configure_llamaindex(self):
        """Configure llamaindex Settings with OpenAILike LLM"""
        # Keep the client on the instance as well: ``Settings.llm`` is a
        # process-wide global, so relying on it alone makes every LLM
        # instance use the parameters of whichever one was built last.
        self._llm = OpenAILike(
            model=self.model_name,
            api_base=self.url,
            api_key=self.api_key,
            context_window=self.context_window,
            is_chat_model=True,
            is_function_calling_model=False,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        # Still published globally for code that reads ``Settings.llm``.
        Settings.llm = self._llm

    async def get_response(
        self, prompt: str, texts: list[str], tone_name: str | None = None
    ) -> str:
        """Get a text response using TreeSummarize for non-function-calling models

        Args:
            prompt: Query passed to the summarizer.
            texts: Text chunks to summarize.
            tone_name: Forwarded to the summarizer as an extra keyword.

        Returns:
            The response converted to ``str`` and stripped.
        """
        summarizer = TreeSummarize(llm=self._llm, verbose=False)
        response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
        return str(response).strip()

    async def get_structured_response(
        self,
        prompt: str,
        texts: list[str],
        output_cls: Type[T],
        tone_name: str | None = None,
    ) -> T:
        """Get structured output from LLM for non-function-calling models

        Runs two steps: a TreeSummarize pass produces a free-text analysis,
        then a text-completion program asks for that analysis in the JSON
        format described by ``output_cls`` and parses the answer.

        Args:
            prompt: Query passed to the summarizer.
            texts: Text chunks to analyse.
            output_cls: Pydantic model class describing the expected output.
            tone_name: Forwarded to the summarizer as an extra keyword.

        Returns:
            The output of the completion program, parsed with ``output_cls``.
        """
        # verbose matches get_response; the previous verbose=True looked
        # like leftover debugging output.
        summarizer = TreeSummarize(llm=self._llm, verbose=False)
        response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)

        output_parser = PydanticOutputParser(output_cls)

        program = LLMTextCompletionProgram.from_defaults(
            output_parser=output_parser,
            prompt_template_str=STRUCTURED_RESPONSE_PROMPT_TEMPLATE,
            llm=self._llm,
            verbose=False,
        )

        format_instructions = output_parser.format(
            "Please structure the above information in the following JSON format:"
        )

        output = await program.acall(
            analysis=str(response), format_instructions=format_instructions
        )

        return output
|
||||
@@ -1,2 +0,0 @@
|
||||
from .base import LLM # noqa: F401
|
||||
from .llm_params import LLMTaskParams # noqa: F401
|
||||
@@ -1,347 +0,0 @@
|
||||
import importlib
|
||||
import json
|
||||
import re
|
||||
from typing import TypeVar
|
||||
|
||||
import nltk
|
||||
from prometheus_client import Counter, Histogram
|
||||
from transformers import GenerationConfig
|
||||
|
||||
from reflector.llm.llm_params import TaskParams
|
||||
from reflector.logger import logger as reflector_logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
T = TypeVar("T", bound="LLM")
|
||||
|
||||
|
||||
class LLM:
    """
    Base class for LLM backends.

    Backends register themselves with ``register`` and are created through
    ``get_instance``. Subclasses implement ``_generate``, ``_completion``,
    ``_get_tokenizer`` and the model name accessors. This class adds
    metrics, retrying, JSON parsing of results and corpus-splitting helpers.
    """

    _nltk_downloaded = False
    _registry = {}
    model_name: str
    m_generate = Histogram(
        "llm_generate",
        "Time spent in LLM.generate",
        ["backend"],
    )
    m_generate_call = Counter(
        "llm_generate_call",
        "Number of calls to LLM.generate",
        ["backend"],
    )
    m_generate_success = Counter(
        "llm_generate_success",
        "Number of successful calls to LLM.generate",
        ["backend"],
    )
    m_generate_failure = Counter(
        "llm_generate_failure",
        "Number of failed calls to LLM.generate",
        ["backend"],
    )

    @classmethod
    def ensure_nltk(cls):
        """
        Make sure NLTK package is installed. Searches in the cache and
        downloads only if needed.
        """
        if not cls._nltk_downloaded:
            nltk.download("punkt_tab")
            # For POS tagging
            nltk.download("averaged_perceptron_tagger_eng")
            cls._nltk_downloaded = True

    @classmethod
    def register(cls, name, klass):
        """Register backend class ``klass`` under the key ``name``."""
        cls._registry[name] = klass

    @classmethod
    def get_instance(cls, model_name: str | None = None, name: str | None = None) -> T:
        """
        Return an instance depending on the settings.
        Settings used:

        - `LLM_BACKEND`: key of the backend
        - `LLM_URL`: url of the backend

        Args:
            model_name: Passed to the backend constructor.
            name: Backend key; defaults to ``settings.LLM_BACKEND``.
        """
        if name is None:
            name = settings.LLM_BACKEND
        if name not in cls._registry:
            # Importing the backend module is expected to register it.
            module_name = f"reflector.llm.llm_{name}"
            importlib.import_module(module_name)
        cls.ensure_nltk()

        return cls._registry[name](model_name)

    def get_model_name(self) -> str:
        """
        Get the currently set model name
        """
        return self._get_model_name()

    def _get_model_name(self) -> str:
        pass

    def set_model_name(self, model_name: str) -> bool:
        """
        Update the model name with the provided model name
        """
        return self._set_model_name(model_name)

    def _set_model_name(self, model_name: str) -> bool:
        raise NotImplementedError

    @property
    def template(self) -> str:
        """
        Return the LLM Prompt template
        """
        # NOTE(review): whitespace inside this literal is reproduced as seen
        # in the reviewed source - confirm against the original file.
        return """
### Human:
{instruct}

{text}

### Assistant:
"""

    def __init__(self):
        # Bind the class-level metrics to this backend's class name.
        name = self.__class__.__name__
        self.m_generate = self.m_generate.labels(name)
        self.m_generate_call = self.m_generate_call.labels(name)
        self.m_generate_success = self.m_generate_success.labels(name)
        self.m_generate_failure = self.m_generate_failure.labels(name)
        self.detokenizer = nltk.tokenize.treebank.TreebankWordDetokenizer()

    @property
    def tokenizer(self):
        """
        Return the tokenizer instance used by LLM
        """
        return self._get_tokenizer()

    def _get_tokenizer(self):
        pass

    def has_structured_output(self):
        # whether implementation supports structured output
        # on the model side (otherwise it's prompt engineering)
        return False

    async def generate(
        self,
        prompt: str,
        logger: reflector_logger,
        gen_schema: dict | None = None,
        gen_cfg: GenerationConfig | None = None,
        **kwargs,
    ) -> dict:
        """
        Call the backend ``_generate`` with retries and metrics.

        A string result is parsed as JSON (see ``_parse_json``); any other
        result is returned unchanged. Exceptions raised after retrying are
        logged, counted as failures and re-raised.
        """
        logger.info("LLM generate", prompt=repr(prompt))

        if gen_cfg:
            gen_cfg = gen_cfg.to_dict()
        self.m_generate_call.inc()
        try:
            with self.m_generate.time():
                result = await retry(self._generate)(
                    prompt=prompt,
                    gen_schema=gen_schema,
                    gen_cfg=gen_cfg,
                    logger=logger,
                    **kwargs,
                )
            self.m_generate_success.inc()

        except Exception:
            logger.exception("Failed to call llm after retrying")
            self.m_generate_failure.inc()
            raise

        logger.debug("LLM result [raw]", result=repr(result))
        if isinstance(result, str):
            result = self._parse_json(result)
        logger.debug("LLM result [parsed]", result=repr(result))

        return result

    async def completion(
        self, messages: list, logger: reflector_logger, **kwargs
    ) -> dict:
        """
        Use /v1/chat/completion Open-AI compatible endpoint from the URL
        It's up to the user to validate anything or transform the result
        """
        logger.info("LLM completions", messages=messages)

        # Count the call like generate() does, so the success and failure
        # counters have a matching total.
        self.m_generate_call.inc()
        try:
            with self.m_generate.time():
                result = await retry(self._completion)(
                    messages=messages, **{**kwargs, "logger": logger}
                )
            self.m_generate_success.inc()
        except Exception:
            logger.exception("Failed to call llm after retrying")
            self.m_generate_failure.inc()
            raise

        logger.debug("LLM completion result", result=repr(result))
        return result

    def ensure_casing(self, title: str) -> str:
        """
        LLM takes care of word casing, but in rare cases this
        can falter. This is a fallback to ensure the casing of
        topics is in a proper format.

        Nouns, verbs and adjectives that start with a lowercase letter get
        their first letter upper-cased; nothing else is changed, except
        that the first letter of the title is always upper-cased.

        On any failure the title is returned as received.
        """
        whitelisted_pos_tags = [
            "NN",
            "NNS",
            "NNP",
            "NNPS",  # Noun POS
            "VB",
            "VBD",
            "VBG",
            "VBN",
            "VBP",
            "VBZ",  # Verb POS
            "JJ",
            "JJR",
            "JJS",  # Adjective POS
        ]

        # If at all there is an exception, do not block other reflector
        # processes. Return the LLM generated title, at the least.
        # Tokenizing and tagging are inside the try so that guarantee also
        # covers NLTK failures.
        try:
            tokens = nltk.word_tokenize(title)
            pos_tags = nltk.pos_tag(tokens)
            capitalized = [
                word[0].upper() + word[1:]
                if pos in whitelisted_pos_tags and word[0].islower()
                else word
                for word, pos in pos_tags
            ]
            modified_title = self.detokenizer.detokenize(capitalized)

            # Irrespective of casing changes, the starting letter
            # of title is always upper-cased
            title = modified_title[0].upper() + modified_title[1:]
        except Exception as e:
            reflector_logger.info(
                f"Failed to ensure casing on {title=} with exception : {str(e)}"
            )

        return title

    def trim_title(self, title: str) -> str:
        """
        List of manual trimming to the title.

        Longer titles are prone to run into a prefix of phrases that don't
        really add any descriptive information and in some cases, this
        behaviour can be repeated for several consecutive topics. Trim the
        titles to maintain quality of titles.

        The phrases are removed case-insensitively wherever they occur;
        surrounding whitespace is left untouched.
        """
        phrases_to_remove = ["Discussing", "Discussion on", "Discussion about"]
        try:
            pattern = (
                r"\b(?:"
                + "|".join(re.escape(phrase) for phrase in phrases_to_remove)
                + r")\b"
            )
            title = re.sub(pattern, "", title, flags=re.IGNORECASE)
        except Exception as e:
            reflector_logger.info(f"Failed to trim {title=} with exception : {str(e)}")
        return title

    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ) -> str:
        raise NotImplementedError

    async def _completion(self, messages: list, **kwargs) -> dict:
        raise NotImplementedError

    def _parse_json(self, result: str) -> dict:
        """Parse an LLM text result as JSON, unwrapping a code fence if present."""
        result = result.strip()
        # try detecting code block if exist
        # starts with ```json\n, ends with ```
        # or starts with ```\n, ends with ```
        # or starts with \n```javascript\n, ends with ```

        regex = r"```(json|javascript|)?(.*)```"
        matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
        if matches:
            result = matches[0][1]

        else:
            # maybe the prompt has been started with ```json
            # so if text ends with ```, just remove it and use it as json
            if result.endswith("```"):
                result = result[:-3]

        return json.loads(result.strip())

    def text_token_threshold(self, task_params: TaskParams | None) -> int:
        """
        Choose the token size to set as the threshold to pack the LLM calls

        The threshold is the tokenizer's ``model_max_length`` minus the
        tokens of the empty prompt, a fixed buffer, and the expected output
        size (``gen_cfg.max_new_tokens`` when available, else a default).

        ``task_params`` may be ``None``; an empty instruction and the
        default output size are then used.
        """
        buffer_token_size = 100
        default_output_tokens = 1000
        context_window = self.tokenizer.model_max_length
        instruct = task_params.instruct if task_params else ""
        tokens = self.tokenizer.tokenize(self.create_prompt(instruct=instruct, text=""))
        threshold = context_window - len(tokens) - buffer_token_size
        if task_params and task_params.gen_cfg:
            threshold -= task_params.gen_cfg.max_new_tokens
        else:
            threshold -= default_output_tokens
        return threshold

    def split_corpus(
        self,
        corpus: str,
        task_params: TaskParams,
        token_threshold: int | None = None,
    ) -> Iterator[str]:
        """
        Split the input to the LLM due to CUDA memory limitations and LLM context window
        restrictions.

        Accumulate tokens from full sentences till threshold and yield accumulated
        tokens. Reset accumulation when threshold is reached and repeat process.

        This is a generator. Every chunk is its sentences joined with a
        single space. A single sentence longer than the threshold is
        yielded on its own, and empty chunks are never yielded.
        """
        if not token_threshold:
            token_threshold = self.text_token_threshold(task_params=task_params)

        accumulated_sentences = []
        accumulated_token_count = 0

        for sentence in nltk.sent_tokenize(corpus):
            token_count = len(self.tokenizer.tokenize(sentence))
            if accumulated_token_count + token_count <= token_threshold:
                accumulated_token_count += token_count
                accumulated_sentences.append(sentence)
                continue

            # Threshold reached: flush what has been gathered (nothing, if
            # the very first sentence is already too long) and start over.
            if accumulated_sentences:
                yield " ".join(accumulated_sentences)
            accumulated_token_count = token_count
            accumulated_sentences = [sentence]

        if accumulated_sentences:
            yield " ".join(accumulated_sentences)

    def create_prompt(self, instruct: str, text: str) -> str:
        """
        Create a consumable prompt based on the prompt template
        """
        return self.template.format(instruct=instruct, text=text)
|
||||
@@ -1,155 +0,0 @@
|
||||
import httpx
|
||||
from transformers import AutoTokenizer, GenerationConfig
|
||||
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.logger import logger as reflector_logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
|
||||
class ModalLLM(LLM):
|
||||
def __init__(self, model_name: str | None = None):
    """Set up the HTTP endpoint, auth header and model for the Modal backend.

    Falls back to ``settings.DEFAULT_LLM`` when ``model_name`` is empty.
    """
    super().__init__()
    bearer_token = f"Bearer {settings.LLM_MODAL_API_KEY}"
    self.timeout = settings.LLM_TIMEOUT
    self.llm_url = settings.LLM_URL + "/llm"
    self.headers = {"Authorization": bearer_token}
    self._set_model_name(model_name or settings.DEFAULT_LLM)
|
||||
|
||||
@property
def supported_models(self):
    """
    Names of the models currently accepted by this GPU platform.

    A fresh list is returned on every access.
    """
    # TODO: Query the specific GPU platform
    # Replace this with a HTTP call
    models = [
        "lmsys/vicuna-13b-v1.5",
        "HuggingFaceH4/zephyr-7b-alpha",
        "NousResearch/Hermes-3-Llama-3.1-8B",
    ]
    return models
|
||||
|
||||
async def _generate(
    self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
) -> str:
    """POST the prompt to the backend ``/llm`` endpoint and return its ``text``.

    Args:
        prompt: Prompt sent to the model.
        gen_schema: Optional schema forwarded as ``gen_schema``.
        gen_cfg: Optional generation config forwarded as ``gen_cfg``.
        **kwargs: May contain ``logger`` for the retry helper.

    Raises:
        httpx.HTTPStatusError: If the response status is an error.
    """
    json_payload = {"prompt": prompt}
    if gen_schema:
        json_payload["gen_schema"] = gen_schema
    if gen_cfg:
        json_payload["gen_cfg"] = gen_cfg

    # Handing over generation of the final summary to Zephyr model
    # but replacing the Vicuna model will happen after more testing
    # TODO: Create a mapping of model names and cloud deployments
    # The URL is chosen per call instead of overwriting ``self.llm_url``:
    # the overwrite was permanent, so requests kept going to the Zephyr
    # deployment after the model name had been changed back.
    llm_url = self.llm_url
    if self.model_name == "HuggingFaceH4/zephyr-7b-alpha":
        llm_url = settings.ZEPHYR_LLM_URL + "/llm"

    async with httpx.AsyncClient() as client:
        response = await retry(client.post)(
            llm_url,
            headers=self.headers,
            json=json_payload,
            timeout=self.timeout,
            retry_timeout=60 * 5,
            follow_redirects=True,
            logger=kwargs.get("logger", reflector_logger),
        )
        response.raise_for_status()
        text = response.json()["text"]
    return text
|
||||
|
||||
async def _completion(self, messages: list, **kwargs) -> dict:
|
||||
# returns full api response
|
||||
kwargs.setdefault("temperature", 0.3)
|
||||
kwargs.setdefault("max_tokens", 2048)
|
||||
kwargs.setdefault("stream", False)
|
||||
kwargs.setdefault("repetition_penalty", 1)
|
||||
kwargs.setdefault("top_p", 1)
|
||||
kwargs.setdefault("top_k", -1)
|
||||
kwargs.setdefault("min_p", 0.05)
|
||||
data = {"messages": messages, "model": self.model_name, **kwargs}
|
||||
|
||||
if self.model_name == "NousResearch/Hermes-3-Llama-3.1-8B":
|
||||
self.llm_url = settings.HERMES_3_8B_LLM_URL + "/v1/chat/completions"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await retry(client.post)(
|
||||
self.llm_url,
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
timeout=self.timeout,
|
||||
retry_timeout=60 * 5,
|
||||
follow_redirects=True,
|
||||
logger=kwargs.get("logger", reflector_logger),
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _set_model_name(self, model_name: str) -> bool:
|
||||
"""
|
||||
Set the model name
|
||||
"""
|
||||
# Abort, if the model is not supported
|
||||
if model_name not in self.supported_models:
|
||||
reflector_logger.info(
|
||||
f"Attempted to change {model_name=}, but is not supported."
|
||||
f"Setting model and tokenizer failed !"
|
||||
)
|
||||
return False
|
||||
# Abort, if the model is already set
|
||||
elif hasattr(self, "model_name") and model_name == self._get_model_name():
|
||||
reflector_logger.info("No change in model. Setting model skipped.")
|
||||
return False
|
||||
# Update model name and tokenizer
|
||||
self.model_name = model_name
|
||||
self.llm_tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name, cache_dir=settings.CACHE_DIR
|
||||
)
|
||||
reflector_logger.info(f"Model set to {model_name=}. Tokenizer updated.")
|
||||
return True
|
||||
|
||||
def _get_tokenizer(self) -> AutoTokenizer:
|
||||
"""
|
||||
Return the currently used LLM tokenizer
|
||||
"""
|
||||
return self.llm_tokenizer
|
||||
|
||||
def _get_model_name(self) -> str:
|
||||
"""
|
||||
Return the current model name from the instance details
|
||||
"""
|
||||
return self.model_name
|
||||
|
||||
|
||||
# Make this backend selectable under the name "modal".
LLM.register("modal", ModalLLM)

if __name__ == "__main__":
    # Manual smoke test: runs the same prompt three times, adding a response
    # schema on the second call and a generation config on the third.
    from reflector.logger import logger

    async def main():
        """Run three sample generations and print each result."""
        backend = ModalLLM()
        joke_prompt = backend.create_prompt(
            instruct="Complete the following task",
            text="Tell me a joke about programming.",
        )

        # 1) prompt only
        print(await backend.generate(prompt=joke_prompt, logger=logger))

        # 2) prompt + response schema
        response_schema = {
            "type": "object",
            "properties": {"response": {"type": "string"}},
        }
        print(
            await backend.generate(
                prompt=joke_prompt, gen_schema=response_schema, logger=logger
            )
        )

        # 3) prompt + response schema + generation config
        short_cfg = GenerationConfig(max_new_tokens=150)
        print(
            await backend.generate(
                prompt=joke_prompt,
                gen_cfg=short_cfg,
                gen_schema=response_schema,
                logger=logger,
            )
        )

    import asyncio

    asyncio.run(main())
|
||||
@@ -1,48 +0,0 @@
|
||||
import httpx
|
||||
from transformers import GenerationConfig
|
||||
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class OpenAILLM(LLM):
    """LLM backend that posts prompts to an OpenAI-style completion endpoint.

    Endpoint, key, model, temperature, timeout and token limit are all read
    from ``settings`` at construction time.
    """

    def __init__(self, model_name: str | None = None, **kwargs):
        """Copy the OpenAI-related settings onto the instance.

        ``model_name`` is accepted but not used by this backend; the model
        comes from ``settings.LLM_OPENAI_MODEL``.
        """
        super().__init__(**kwargs)
        self.openai_key = settings.LLM_OPENAI_KEY
        self.openai_url = settings.LLM_URL
        self.openai_model = settings.LLM_OPENAI_MODEL
        self.openai_temperature = settings.LLM_OPENAI_TEMPERATURE
        self.timeout = settings.LLM_TIMEOUT
        self.max_tokens = settings.LLM_MAX_TOKENS
        logger.info(f"LLM use openai backend at {self.openai_url}")

    async def _generate(
        self,
        prompt: str,
        gen_schema: dict | None,
        gen_cfg: GenerationConfig | None,
        **kwargs,
    ) -> str:
        """Send ``prompt`` to the endpoint and return the first choice's text.

        ``gen_schema``, ``gen_cfg`` and ``**kwargs`` are accepted for interface
        compatibility and are not forwarded to the server.

        Raises:
            Whatever ``response.raise_for_status()`` raises on an HTTP error.
        """
        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.openai_key}",
        }
        request_body = {
            "model": self.openai_model,
            "prompt": prompt,
            "max_tokens": self.max_tokens,
            "temperature": self.openai_temperature,
        }

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                self.openai_url, headers=request_headers, json=request_body
            )
            response.raise_for_status()
            first_choice = response.json()["choices"][0]
            return first_choice["text"]
|
||||
|
||||
|
||||
LLM.register("openai", OpenAILLM)
|
||||
@@ -1,219 +0,0 @@
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import GenerationConfig
|
||||
|
||||
|
||||
# Bundle of everything needed to run one LLM task: the instruction text plus
# an optional generation config and an optional output schema.
# arbitrary_types_allowed lets the model hold a GenerationConfig field.
class TaskParams(BaseModel, arbitrary_types_allowed=True):
    # Instruction text for the task.
    instruct: str
    # Generation settings; None when the task sets none.
    gen_cfg: Optional[GenerationConfig] = None
    # Schema describing the expected output; None for free-form text.
    gen_schema: Optional[dict] = None
|
||||
|
||||
|
||||
T = TypeVar("T", bound="LLMTaskParams")
|
||||
|
||||
|
||||
class LLMTaskParams:
    """Base class and name-based registry for per-task LLM parameters.

    Subclasses override ``_get_task_params`` and are made available by name
    through ``register``; ``get_instance`` builds one from its task name.
    """

    # Task name -> registered class. Defined once here, so every subclass
    # that does not override it shares the same mapping.
    _registry: dict = {}

    @classmethod
    def register(cls, task, klass) -> None:
        """Register ``klass`` as the parameter class for ``task``.

        Registering the same task again replaces the previous class.
        """
        cls._registry[task] = klass

    @classmethod
    def get_instance(cls, task: str) -> T:
        """Instantiate the class registered for ``task`` (no constructor args).

        Raises:
            KeyError: if no class is registered under ``task``; the message
                lists the registered task names.
        """
        try:
            klass = cls._registry[task]
        except KeyError:
            known = ", ".join(map(repr, cls._registry)) or "none"
            raise KeyError(
                f"Unknown LLM task {task!r} (registered: {known})"
            ) from None
        return klass()

    @property
    def task_params(self) -> TaskParams | None:
        """
        Fetch the task related parameters
        """
        return self._get_task_params()

    def _get_task_params(self) -> TaskParams | None:
        """Hook for subclasses; the base implementation has no parameters."""
        return None
|
||||
|
||||
|
||||
class FinalLongSummaryParams(LLMTaskParams):
    """Instruction, output schema and generation config for the long summary."""

    def __init__(self, **kwargs):
        """Build the task parameters once and keep them on the instance."""
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=1000,
            num_beams=3,
            do_sample=True,
            temperature=0.3,
        )
        # NOTE(review): line breaks match the original literal; any leading
        # indentation it carried inside the string is not visible - confirm.
        self._instruct = (
            "\n"
            "Take the key ideas and takeaways from the text and create a short\n"
            "summary. Be sure to keep the length of the response to a minimum.\n"
            "Do not include trivial information in the summary.\n"
        )
        self._schema = {
            "type": "object",
            "properties": {
                "long_summary": {"type": "string"},
            },
        }
        self._task_params = TaskParams(
            instruct=self._instruct,
            gen_schema=self._schema,
            gen_cfg=self._gen_cfg,
        )

    def _get_task_params(self) -> TaskParams:
        """Return the parameters built in ``__init__``."""
        return self._task_params
|
||||
|
||||
|
||||
class FinalShortSummaryParams(LLMTaskParams):
    """Instruction, output schema and generation config for the short summary."""

    def __init__(self, **kwargs):
        """Build the task parameters once and keep them on the instance."""
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=800,
            num_beams=3,
            do_sample=True,
            temperature=0.3,
        )
        # NOTE(review): line breaks match the original literal; any leading
        # indentation it carried inside the string is not visible - confirm.
        self._instruct = (
            "\n"
            "Take the key ideas and takeaways from the text and create a short\n"
            "summary. Be sure to keep the length of the response to a minimum.\n"
            "Do not include trivial information in the summary.\n"
        )
        self._schema = {
            "type": "object",
            "properties": {
                "short_summary": {"type": "string"},
            },
        }
        self._task_params = TaskParams(
            instruct=self._instruct,
            gen_schema=self._schema,
            gen_cfg=self._gen_cfg,
        )

    def _get_task_params(self) -> TaskParams:
        """Return the parameters built in ``__init__``."""
        return self._task_params
|
||||
|
||||
|
||||
class FinalTitleParams(LLMTaskParams):
    """Instruction, output schema and generation config for the final title."""

    def __init__(self, **kwargs):
        """Build the task parameters once and keep them on the instance."""
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=200,
            num_beams=5,
            do_sample=True,
            temperature=0.5,
        )
        # NOTE(review): line breaks match the original literal; any leading
        # indentation it carried inside the string is not visible - confirm.
        self._instruct = (
            "\n"
            "Combine the following individual titles into one single short title that\n"
            "condenses the essence of all titles.\n"
        )
        self._schema = {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
            },
        }
        self._task_params = TaskParams(
            instruct=self._instruct,
            gen_schema=self._schema,
            gen_cfg=self._gen_cfg,
        )

    def _get_task_params(self) -> TaskParams:
        """Return the parameters built in ``__init__``."""
        return self._task_params
|
||||
|
||||
|
||||
class TopicParams(LLMTaskParams):
    """Instruction, output schema and generation config for topic detection.

    The schema asks for two string fields: ``title`` and ``summary``.
    """

    def __init__(self, **kwargs):
        """Build the task parameters once and keep them on the instance."""
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=500,
            num_beams=6,
            do_sample=True,
            temperature=0.9,
        )
        # NOTE(review): line breaks match the original literal; any leading
        # indentation it carried inside the string is not visible - confirm.
        self._instruct = (
            "\n"
            "Create a JSON object as response.The JSON object must have 2 fields:\n"
            "i) title and ii) summary.\n"
            "For the title field, generate a very detailed and self-explanatory\n"
            "title for the given text. Let the title be as descriptive as possible.\n"
            "For the summary field, summarize the given text in a maximum of\n"
            "two sentences.\n"
        )
        self._schema = {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
                "summary": {"type": "string"},
            },
        }
        self._task_params = TaskParams(
            instruct=self._instruct,
            gen_schema=self._schema,
            gen_cfg=self._gen_cfg,
        )

    def _get_task_params(self) -> TaskParams:
        """Return the parameters built in ``__init__``."""
        return self._task_params
|
||||
|
||||
|
||||
class BulletedSummaryParams(LLMTaskParams):
    """Instruction and generation config for the bulleted summary.

    This task sets no output schema (``gen_schema`` is ``None``).
    """

    def __init__(self, **kwargs):
        """Build the task parameters once and keep them on the instance."""
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=800,
            num_beams=1,
            do_sample=True,
            temperature=0.2,
            early_stopping=True,
        )
        # NOTE(review): line breaks match the original literal; any leading
        # indentation it carried inside the string is not visible - confirm.
        self._instruct = (
            "\n"
            "Given a meeting transcript, extract the key things discussed in the\n"
            "form of a list.\n"
            "\n"
            "While generating the response, follow the constraints mentioned below.\n"
            "\n"
            "Summary constraints:\n"
            "i) Do not add new content, except to fix spelling or punctuation.\n"
            "ii) Do not add any prefixes or numbering in the response.\n"
            "iii) The summarization should be as information dense as possible.\n"
            "iv) Do not add any additional sections like Note, Conclusion, etc. in\n"
            "the response.\n"
            "\n"
            "Response format:\n"
            "i) The response should be in the form of a bulleted list.\n"
            "ii) Iteratively merge all the relevant paragraphs together to keep the\n"
            "number of paragraphs to a minimum.\n"
            "iii) Remove any unfinished sentences from the final response.\n"
            "iv) Do not include narrative or reporting clauses.\n"
            'v) Use "*" as the bullet icon.\n'
        )
        self._task_params = TaskParams(
            instruct=self._instruct,
            gen_schema=None,
            gen_cfg=self._gen_cfg,
        )

    def _get_task_params(self) -> TaskParams:
        """Return the parameters built in ``__init__``."""
        return self._task_params
|
||||
|
||||
|
||||
class MergedSummaryParams(LLMTaskParams):
    """Instruction and generation config for the merged summary.

    This task sets no output schema (``gen_schema`` is ``None``).
    """

    def __init__(self, **kwargs):
        """Build the task parameters once and keep them on the instance."""
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=600,
            num_beams=1,
            do_sample=True,
            temperature=0.2,
            early_stopping=True,
        )
        # NOTE(review): line breaks match the original literal; any leading
        # indentation it carried inside the string is not visible - confirm.
        self._instruct = (
            "\n"
            "Given the key points of a meeting, summarize the points to describe the\n"
            "meeting in the form of paragraphs.\n"
        )
        self._task_params = TaskParams(
            instruct=self._instruct,
            gen_schema=None,
            gen_cfg=self._gen_cfg,
        )

    def _get_task_params(self) -> TaskParams:
        """Return the parameters built in ``__init__``."""
        return self._task_params
|
||||
|
||||
|
||||
LLMTaskParams.register("topic", TopicParams)
|
||||
LLMTaskParams.register("final_title", FinalTitleParams)
|
||||
LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
|
||||
LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
|
||||
LLMTaskParams.register("bullet_summary", BulletedSummaryParams)
|
||||
LLMTaskParams.register("merged_summary", MergedSummaryParams)
|
||||
@@ -1,118 +0,0 @@
|
||||
import httpx
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
|
||||
def apply_gen_config(payload: dict, gen_cfg) -> None:
    """Apply generation config overrides to the payload.

    Every supported attribute that is present on ``gen_cfg`` and not ``None``
    overwrites the matching key in ``payload``. Attributes that are missing or
    ``None`` leave the payload untouched.

    Args:
        payload: request body to update in place.
        gen_cfg: any object exposing some of ``temperature``, ``top_p``,
            ``frequency_penalty``, ``presence_penalty``, ``max_new_tokens``,
            ``max_tokens`` as attributes.
    """
    # Attributes whose name is the same on the config and in the payload.
    for name in ("temperature", "top_p", "frequency_penalty", "presence_penalty"):
        value = getattr(gen_cfg, name, None)
        if value is not None:
            payload[name] = value

    # Both attributes feed the single "max_tokens" payload key, and
    # max_new_tokens takes precedence. This is handled on its own so that
    # stopping at the first match cannot skip the unrelated options above.
    for name in ("max_new_tokens", "max_tokens"):
        value = getattr(gen_cfg, name, None)
        if value is not None:
            payload["max_tokens"] = value
            break
|
||||
|
||||
|
||||
class OpenAILLM:
    """Async client for an OpenAI-style ``/chat/completions`` endpoint.

    Configuration is read from attributes of ``settings`` whose names start
    with ``config_prefix``. The instance owns an ``httpx.AsyncClient`` and can
    be used as an async context manager, which closes that client on exit.
    """

    def __init__(self, config_prefix: str, settings):
        """Read prefixed settings, create the HTTP client and load a tokenizer.

        ``<prefix>_MODEL``, ``<prefix>_LLM_URL`` and ``<prefix>_LLM_API_KEY``
        must exist on ``settings``; the remaining options have defaults.
        """

        def prefixed(suffix: str, *fallback):
            # Without a fallback this raises AttributeError for a missing
            # setting, exactly like a plain getattr.
            return getattr(settings, f"{config_prefix}_{suffix}", *fallback)

        self.config_prefix = config_prefix
        self.settings_obj = settings
        self.model_name = prefixed("MODEL")
        self.url = prefixed("LLM_URL")
        self.api_key = prefixed("LLM_API_KEY")

        request_timeout = prefixed("LLM_TIMEOUT", 300)
        self.temperature = prefixed("LLM_TEMPERATURE", 0.7)
        self.max_tokens = prefixed("LLM_MAX_TOKENS", 1024)
        self.client = httpx.AsyncClient(timeout=request_timeout)

        # Use a tokenizer that approximates OpenAI token counting
        tokenizer_name = prefixed("TOKENIZER", "gpt2")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        except Exception:
            logger.debug(
                f"Failed to load tokenizer '{tokenizer_name}', falling back to default 'gpt2' tokenizer"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    async def generate(
        self, prompt: str, gen_schema=None, gen_cfg=None, logger=None
    ) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Returns:
            ``choices[0].message.content`` of the completion response.
        """
        if logger:
            # Log at most the first 100 characters of the prompt.
            if len(prompt) > 100:
                preview = prompt[:100] + "..."
            else:
                preview = prompt
            logger.debug("OpenAI LLM generate", prompt=repr(preview))

        completion = await self.completion(
            [{"role": "user", "content": prompt}],
            gen_schema=gen_schema,
            gen_cfg=gen_cfg,
            logger=logger,
        )
        return completion["choices"][0]["message"]["content"]

    async def completion(
        self, messages: list, gen_schema=None, gen_cfg=None, logger=None, **kwargs
    ) -> dict:
        """Post ``messages`` to the chat-completions endpoint.

        Args:
            messages: chat messages, sent as-is.
            gen_schema: optional JSON schema; when truthy it is sent as a
                ``json_schema`` response format named ``response``.
            gen_cfg: optional generation config; when truthy its values
                override the instance defaults via ``apply_gen_config``.
            logger: optional logger accepting keyword fields.
            **kwargs: accepted but not used.

        Returns:
            The decoded JSON response.

        Raises:
            Whatever ``response.raise_for_status()`` raises on an HTTP error.
        """
        if logger:
            logger.info("OpenAI LLM completion", messages_count=len(messages))

        payload = dict(
            model=self.model_name,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

        # Apply generation config overrides
        if gen_cfg:
            apply_gen_config(payload, gen_cfg)

        # Apply structured output schema
        if gen_schema:
            schema_spec = {"name": "response", "schema": gen_schema}
            payload["response_format"] = {
                "type": "json_schema",
                "json_schema": schema_spec,
            }

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        base_url = self.url.rstrip("/")
        url = f"{base_url}/chat/completions"

        if logger:
            logger.debug(
                "OpenAI API request", url=url, payload_keys=list(payload.keys())
            )

        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        result = response.json()

        if logger:
            logger.debug(
                "OpenAI API response",
                status_code=response.status_code,
                choices_count=len(result.get("choices", [])),
            )

        return result

    async def __aenter__(self):
        """Return the instance itself for ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the underlying HTTP client; exceptions are not suppressed."""
        await self.client.aclose()
|
||||
@@ -12,15 +12,9 @@ from textwrap import dedent
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import structlog
|
||||
from llama_index.core import Settings
|
||||
from llama_index.core.output_parsers import PydanticOutputParser
|
||||
from llama_index.core.program import LLMTextCompletionProgram
|
||||
from llama_index.core.response_synthesizers import TreeSummarize
|
||||
from llama_index.llms.openai_like import OpenAILike
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.llm.openai_llm import OpenAILLM
|
||||
from reflector.llm import LLM
|
||||
from reflector.settings import settings
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
@@ -168,23 +162,12 @@ class SummaryBuilder:
|
||||
self.summaries: list[dict[str, str]] = []
|
||||
self.subjects: list[str] = []
|
||||
self.transcription_type: TranscriptionType | None = None
|
||||
self.llm_instance: LLM = llm
|
||||
self.llm: LLM = llm
|
||||
self.model_name: str = llm.model_name
|
||||
self.logger = logger or structlog.get_logger()
|
||||
if filename:
|
||||
self.read_transcript_from_file(filename)
|
||||
|
||||
Settings.llm = OpenAILike(
|
||||
model=llm.model_name,
|
||||
api_base=llm.url,
|
||||
api_key=llm.api_key,
|
||||
context_window=settings.SUMMARY_LLM_CONTEXT_SIZE_TOKENS,
|
||||
is_chat_model=True,
|
||||
is_function_calling_model=llm.has_structured_output,
|
||||
temperature=llm.temperature,
|
||||
max_tokens=llm.max_tokens,
|
||||
)
|
||||
|
||||
def read_transcript_from_file(self, filename: str) -> None:
|
||||
"""
|
||||
Load a transcript from a text file.
|
||||
@@ -202,40 +185,16 @@ class SummaryBuilder:
|
||||
self.transcript = transcript
|
||||
|
||||
def set_llm_instance(self, llm: LLM) -> None:
|
||||
self.llm_instance = llm
|
||||
self.llm = llm
|
||||
|
||||
async def _get_structured_response(
|
||||
self, prompt: str, output_cls: Type[T], tone_name: str | None = None
|
||||
) -> Type[T]:
|
||||
) -> T:
|
||||
"""Generic function to get structured output from LLM for non-function-calling models."""
|
||||
# First, use TreeSummarize to get the response
|
||||
summarizer = TreeSummarize(verbose=True)
|
||||
|
||||
response = await summarizer.aget_response(
|
||||
prompt, [self.transcript], tone_name=tone_name
|
||||
return await self.llm.get_structured_response(
|
||||
prompt, [self.transcript], output_cls, tone_name=tone_name
|
||||
)
|
||||
|
||||
# Then, use PydanticOutputParser to structure the response
|
||||
output_parser = PydanticOutputParser(output_cls)
|
||||
|
||||
prompt_template_str = STRUCTURED_RESPONSE_PROMPT_TEMPLATE
|
||||
|
||||
program = LLMTextCompletionProgram.from_defaults(
|
||||
output_parser=output_parser,
|
||||
prompt_template_str=prompt_template_str,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
format_instructions = output_parser.format(
|
||||
"Please structure the above information in the following JSON format:"
|
||||
)
|
||||
|
||||
output = await program.acall(
|
||||
analysis=str(response), format_instructions=format_instructions
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Participants
|
||||
# ----------------------------------------------------------------------------
|
||||
@@ -354,19 +313,18 @@ class SummaryBuilder:
|
||||
async def generate_subject_summaries(self) -> None:
|
||||
"""Generate detailed summaries for each extracted subject."""
|
||||
assert self.transcript is not None
|
||||
summarizer = TreeSummarize(verbose=False)
|
||||
summaries = []
|
||||
|
||||
for subject in self.subjects:
|
||||
detailed_prompt = DETAILED_SUBJECT_PROMPT_TEMPLATE.format(subject=subject)
|
||||
|
||||
detailed_response = await summarizer.aget_response(
|
||||
detailed_response = await self.llm.get_response(
|
||||
detailed_prompt, [self.transcript], tone_name="Topic assistant"
|
||||
)
|
||||
|
||||
paragraph_prompt = PARAGRAPH_SUMMARY_PROMPT
|
||||
|
||||
paragraph_response = await summarizer.aget_response(
|
||||
paragraph_response = await self.llm.get_response(
|
||||
paragraph_prompt, [str(detailed_response)], tone_name="Topic summarizer"
|
||||
)
|
||||
|
||||
@@ -377,7 +335,6 @@ class SummaryBuilder:
|
||||
|
||||
async def generate_recap(self) -> None:
|
||||
"""Generate a quick recap from the subject summaries."""
|
||||
summarizer = TreeSummarize(verbose=True)
|
||||
|
||||
summaries_text = "\n\n".join(
|
||||
[
|
||||
@@ -388,7 +345,7 @@ class SummaryBuilder:
|
||||
|
||||
recap_prompt = RECAP_PROMPT
|
||||
|
||||
recap_response = await summarizer.aget_response(
|
||||
recap_response = await self.llm.get_response(
|
||||
recap_prompt, [summaries_text], tone_name="Recap summarizer"
|
||||
)
|
||||
|
||||
@@ -483,7 +440,7 @@ if __name__ == "__main__":
|
||||
async def main():
|
||||
# build the summary
|
||||
|
||||
llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
|
||||
llm = LLM(settings=settings)
|
||||
sm = SummaryBuilder(llm=llm, filename=args.transcript)
|
||||
|
||||
if args.subjects:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from reflector.llm.openai_llm import OpenAILLM
|
||||
from reflector.llm import LLM
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.summary.summary_builder import SummaryBuilder
|
||||
from reflector.processors.types import FinalLongSummary, FinalShortSummary, TitleSummary
|
||||
@@ -17,7 +17,7 @@ class TranscriptFinalSummaryProcessor(Processor):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = transcript
|
||||
self.chunks: list[TitleSummary] = []
|
||||
self.llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
|
||||
self.llm = LLM(settings=settings)
|
||||
self.builder = None
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
|
||||
@@ -1,67 +1,72 @@
|
||||
from reflector.llm import LLM, LLMTaskParams
|
||||
from textwrap import dedent
|
||||
|
||||
from reflector.llm import LLM
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import FinalTitle, TitleSummary
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.text import clean_title
|
||||
|
||||
TITLE_PROMPT = dedent(
|
||||
"""
|
||||
Generate a concise title for this meeting based on the following topic titles.
|
||||
Ignore casual conversation, greetings, or administrative matters.
|
||||
|
||||
The title must:
|
||||
- Be maximum 10 words
|
||||
- Use noun phrases when possible (e.g., "Q1 Budget Review" not "Reviewing the Q1 Budget")
|
||||
- Avoid generic terms like "Team Meeting" or "Discussion"
|
||||
|
||||
If multiple unrelated topics were discussed, prioritize the most significant one.
|
||||
or create a compound title (e.g., "Product Launch and Budget Planning").
|
||||
|
||||
<topics_discussed>
|
||||
{titles}
|
||||
</topics_discussed>
|
||||
|
||||
Do not explain, just output the meeting title as a single line.
|
||||
"""
|
||||
).strip()
|
||||
|
||||
|
||||
class TranscriptFinalTitleProcessor(Processor):
|
||||
"""
|
||||
Assemble all summary into a line-based json
|
||||
Generate a final title from topic titles using LlamaIndex
|
||||
"""
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
OUTPUT_TYPE = FinalTitle
|
||||
TASK = "final_title"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.chunks: list[TitleSummary] = []
|
||||
self.llm = LLM.get_instance()
|
||||
self.params = LLMTaskParams.get_instance(self.TASK).task_params
|
||||
self.llm = LLM(settings=settings, temperature=0.5, max_tokens=200)
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
self.chunks.append(data)
|
||||
|
||||
async def get_title(self, text: str) -> dict:
|
||||
async def get_title(self, accumulated_titles: str) -> str:
|
||||
"""
|
||||
Generate a title for the whole recording
|
||||
Generate a title for the whole recording using LLM
|
||||
"""
|
||||
chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
|
||||
prompt = TITLE_PROMPT.format(titles=accumulated_titles)
|
||||
response = await self.llm.get_response(
|
||||
prompt,
|
||||
[accumulated_titles],
|
||||
tone_name="Title generator",
|
||||
)
|
||||
|
||||
if len(chunks) == 1:
|
||||
chunk = chunks[0]
|
||||
prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
|
||||
title_result = await self.llm.generate(
|
||||
prompt=prompt,
|
||||
gen_schema=self.params.gen_schema,
|
||||
gen_cfg=self.params.gen_cfg,
|
||||
logger=self.logger,
|
||||
)
|
||||
return title_result
|
||||
else:
|
||||
accumulated_titles = ""
|
||||
for chunk in chunks:
|
||||
prompt = self.llm.create_prompt(
|
||||
instruct=self.params.instruct, text=chunk
|
||||
)
|
||||
title_result = await self.llm.generate(
|
||||
prompt=prompt,
|
||||
gen_schema=self.params.gen_schema,
|
||||
gen_cfg=self.params.gen_cfg,
|
||||
logger=self.logger,
|
||||
)
|
||||
accumulated_titles += title_result["title"]
|
||||
self.logger.info(f"Generated title response: {response}")
|
||||
|
||||
return await self.get_title(accumulated_titles)
|
||||
return response
|
||||
|
||||
async def _flush(self):
|
||||
if not self.chunks:
|
||||
self.logger.warning("No summary to output")
|
||||
return
|
||||
|
||||
accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
|
||||
title_result = await self.get_title(accumulated_titles)
|
||||
final_title = self.llm.trim_title(title_result["title"])
|
||||
final_title = self.llm.ensure_casing(final_title)
|
||||
accumulated_titles = "\n".join([f"- {chunk.title}" for chunk in self.chunks])
|
||||
title = await self.get_title(accumulated_titles)
|
||||
title = clean_title(title)
|
||||
|
||||
final_title = FinalTitle(title=final_title)
|
||||
final_title = FinalTitle(title=title)
|
||||
await self.emit(final_title)
|
||||
|
||||
@@ -1,7 +1,41 @@
|
||||
from reflector.llm import LLM, LLMTaskParams
|
||||
from textwrap import dedent
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from reflector.llm import LLM
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import TitleSummary, Transcript
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.text import clean_title
|
||||
|
||||
TOPIC_PROMPT = dedent(
|
||||
"""
|
||||
Analyze the following transcript segment and extract the main topic being discussed.
|
||||
Focus on the substantive content and ignore small talk or administrative chatter.
|
||||
|
||||
Create a title that:
|
||||
- Captures the specific subject matter being discussed
|
||||
- Is descriptive and self-explanatory
|
||||
- Uses professional language
|
||||
- Is specific rather than generic
|
||||
|
||||
For the summary:
|
||||
- Summarize the key points in maximum two sentences
|
||||
- Focus on what was discussed, decided, or accomplished
|
||||
- Be concise but informative
|
||||
|
||||
<transcript>
|
||||
{text}
|
||||
</transcript>
|
||||
"""
|
||||
).strip()
|
||||
|
||||
|
||||
class TopicResponse(BaseModel):
|
||||
"""Structured response for topic detection"""
|
||||
|
||||
title: str = Field(description="A descriptive title for the topic being discussed")
|
||||
summary: str = Field(description="A concise 1-2 sentence summary of the discussion")
|
||||
|
||||
|
||||
class TranscriptTopicDetectorProcessor(Processor):
|
||||
@@ -11,7 +45,6 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
|
||||
INPUT_TYPE = Transcript
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
TASK = "topic"
|
||||
|
||||
def __init__(
|
||||
self, min_transcript_length: int = int(settings.MIN_TRANSCRIPT_LENGTH), **kwargs
|
||||
@@ -19,8 +52,7 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = None
|
||||
self.min_transcript_length = min_transcript_length
|
||||
self.llm = LLM.get_instance()
|
||||
self.params = LLMTaskParams.get_instance(self.TASK).task_params
|
||||
self.llm = LLM(settings=settings, temperature=0.9, max_tokens=500)
|
||||
|
||||
async def _push(self, data: Transcript):
|
||||
if self.transcript is None:
|
||||
@@ -34,18 +66,15 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
return
|
||||
await self.flush()
|
||||
|
||||
async def get_topic(self, text: str) -> dict:
|
||||
async def get_topic(self, text: str) -> TopicResponse:
|
||||
"""
|
||||
Generate a topic and description for a transcription excerpt
|
||||
Generate a topic and description for a transcription excerpt using LLM
|
||||
"""
|
||||
prompt = self.llm.create_prompt(instruct=self.params.instruct, text=text)
|
||||
topic_result = await self.llm.generate(
|
||||
prompt=prompt,
|
||||
gen_schema=self.params.gen_schema,
|
||||
gen_cfg=self.params.gen_cfg,
|
||||
logger=self.logger,
|
||||
prompt = TOPIC_PROMPT.format(text=text)
|
||||
response = await self.llm.get_structured_response(
|
||||
prompt, [text], TopicResponse, tone_name="Topic analyzer"
|
||||
)
|
||||
return topic_result
|
||||
return response
|
||||
|
||||
async def _flush(self):
|
||||
if not self.transcript:
|
||||
@@ -53,13 +82,13 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
|
||||
text = self.transcript.text
|
||||
self.logger.info(f"Topic detector got {len(text)} length transcript")
|
||||
|
||||
topic_result = await self.get_topic(text=text)
|
||||
title = self.llm.trim_title(topic_result["title"])
|
||||
title = self.llm.ensure_casing(title)
|
||||
title = clean_title(topic_result.title)
|
||||
|
||||
summary = TitleSummary(
|
||||
title=title,
|
||||
summary=topic_result["summary"],
|
||||
summary=topic_result.summary,
|
||||
timestamp=self.transcript.timestamp,
|
||||
duration=self.transcript.duration,
|
||||
transcript=self.transcript,
|
||||
|
||||
@@ -13,14 +13,13 @@ class TranscriptTranslatorProcessor(Processor):
|
||||
|
||||
INPUT_TYPE = Transcript
|
||||
OUTPUT_TYPE = Transcript
|
||||
TASK = "translate"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = None
|
||||
self.translate_url = settings.TRANSLATE_URL
|
||||
self.timeout = settings.TRANSLATE_TIMEOUT
|
||||
self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
|
||||
self.headers = {"Authorization": f"Bearer {settings.TRANSCRIPT_MODAL_API_KEY}"}
|
||||
|
||||
async def _push(self, data: Transcript):
|
||||
self.transcript = data
|
||||
|
||||
@@ -9,13 +9,14 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
# CORS
|
||||
UI_BASE_URL: str = "http://localhost:3000"
|
||||
CORS_ORIGIN: str = "*"
|
||||
CORS_ALLOW_CREDENTIALS: bool = False
|
||||
|
||||
# Database
|
||||
DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
|
||||
|
||||
# local data directory (audio for no)
|
||||
# local data directory
|
||||
DATA_DIR: str = "./data"
|
||||
|
||||
# Audio Transcription
|
||||
@@ -24,10 +25,6 @@ class Settings(BaseSettings):
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
|
||||
# Translate into the target language
|
||||
TRANSLATE_URL: str | None = None
|
||||
TRANSLATE_TIMEOUT: int = 90
|
||||
|
||||
# Audio transcription modal.com configuration
|
||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||
|
||||
@@ -40,31 +37,15 @@ class Settings(BaseSettings):
|
||||
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
|
||||
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
|
||||
|
||||
# Translate into the target language
|
||||
TRANSLATE_URL: str | None = None
|
||||
TRANSLATE_TIMEOUT: int = 90
|
||||
|
||||
# LLM
|
||||
# available backend: openai, modal
|
||||
LLM_BACKEND: str = "modal"
|
||||
|
||||
# LLM common configuration
|
||||
LLM_MODEL: str = "microsoft/phi-4"
|
||||
LLM_URL: str | None = None
|
||||
LLM_HOST: str = "localhost"
|
||||
LLM_PORT: int = 7860
|
||||
LLM_OPENAI_KEY: str | None = None
|
||||
LLM_OPENAI_MODEL: str = "gpt-3.5-turbo"
|
||||
LLM_OPENAI_TEMPERATURE: float = 0.7
|
||||
LLM_TIMEOUT: int = 60 * 5 # take cold start into account
|
||||
LLM_MAX_TOKENS: int = 1024
|
||||
LLM_TEMPERATURE: float = 0.7
|
||||
ZEPHYR_LLM_URL: str | None = None
|
||||
HERMES_3_8B_LLM_URL: str | None = None
|
||||
|
||||
# LLM Modal configuration
|
||||
LLM_MODAL_API_KEY: str | None = None
|
||||
|
||||
# per-task cases
|
||||
SUMMARY_MODEL: str = "monadical/private/smart"
|
||||
SUMMARY_LLM_URL: str | None = None
|
||||
SUMMARY_LLM_API_KEY: str | None = None
|
||||
SUMMARY_LLM_CONTEXT_SIZE_TOKENS: int = 16000
|
||||
LLM_API_KEY: str | None = None
|
||||
LLM_CONTEXT_WINDOW: int = 16000
|
||||
|
||||
# Diarization
|
||||
DIARIZATION_ENABLED: bool = True
|
||||
@@ -86,12 +67,6 @@ class Settings(BaseSettings):
|
||||
# if set, all anonymous record will be public
|
||||
PUBLIC_MODE: bool = False
|
||||
|
||||
# Default LLM model name
|
||||
DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5"
|
||||
|
||||
# Cache directory for all model storage
|
||||
CACHE_DIR: str = "./data"
|
||||
|
||||
# Min transcript length to generate topic + summary
|
||||
MIN_TRANSCRIPT_LENGTH: int = 750
|
||||
|
||||
@@ -116,24 +91,20 @@ class Settings(BaseSettings):
|
||||
# Healthcheck
|
||||
HEALTHCHECK_URL: str | None = None
|
||||
|
||||
AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
|
||||
SQS_POLLING_TIMEOUT_SECONDS: int = 60
|
||||
|
||||
# Whereby integration
|
||||
WHEREBY_API_URL: str = "https://api.whereby.dev/v1"
|
||||
|
||||
WHEREBY_API_KEY: str | None = None
|
||||
|
||||
WHEREBY_WEBHOOK_SECRET: str | None = None
|
||||
AWS_WHEREBY_S3_BUCKET: str | None = None
|
||||
AWS_WHEREBY_ACCESS_KEY_ID: str | None = None
|
||||
AWS_WHEREBY_ACCESS_KEY_SECRET: str | None = None
|
||||
AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
|
||||
SQS_POLLING_TIMEOUT_SECONDS: int = 60
|
||||
|
||||
# Zulip integration
|
||||
ZULIP_REALM: str | None = None
|
||||
ZULIP_API_KEY: str | None = None
|
||||
ZULIP_BOT_EMAIL: str | None = None
|
||||
|
||||
UI_BASE_URL: str = "http://localhost:3000"
|
||||
|
||||
WHEREBY_WEBHOOK_SECRET: str | None = None
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
33
server/reflector/utils/text.py
Normal file
33
server/reflector/utils/text.py
Normal file
@@ -0,0 +1,33 @@
|
||||
def clean_title(title: str) -> str:
    """
    Clean and format a title string for consistent capitalization.

    Rules:
        - Strip surrounding whitespace and quotes (single or double)
        - Collapse runs of whitespace between words to a single space
        - Capitalize the first word
        - Capitalize words longer than 3 characters
        - Keep words with 3 or fewer characters lowercase (except first word)

    Args:
        title: The title string to clean

    Returns:
        The cleaned title with consistent capitalization. An empty or
        whitespace-only input yields an empty string.

    Examples:
        >>> clean_title("hello world")
        'Hello World'
        >>> clean_title("meeting with the team")
        'Meeting With the Team'
        >>> clean_title("'Title with quotes'")
        'Title With Quotes'
        >>> clean_title(' "quoted and padded" ')
        'Quoted And Padded'
    """
    # Whitespace must be removed before the quotes: str.strip(chars) stops at
    # the first character not in `chars`, so a leading space or trailing
    # newline would otherwise leave the quotes in place. The final strip()
    # handles padding that sat inside the quotes.
    title = title.strip().strip("\"'").strip()
    words = title.split()
    # str.capitalize() also lowercases the rest of the word, which is what
    # normalizes inputs such as "HELLO" or "MiXeD".
    return " ".join(
        word.capitalize() if i == 0 or len(word) > 3 else word.lower()
        for i, word in enumerate(words)
    )
|
||||
@@ -37,8 +37,12 @@ def dummy_processors():
|
||||
"reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation"
|
||||
) as mock_translate,
|
||||
):
|
||||
mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
|
||||
mock_title.return_value = {"title": "LLM TITLE"}
|
||||
from reflector.processors.transcript_topic_detector import TopicResponse
|
||||
|
||||
mock_topic.return_value = TopicResponse(
|
||||
title="LLM TITLE", summary="LLM SUMMARY"
|
||||
)
|
||||
mock_title.return_value = "LLM Title"
|
||||
mock_long_summary.return_value = "LLM LONG SUMMARY"
|
||||
mock_short_summary.return_value = "LLM SHORT SUMMARY"
|
||||
mock_translate.return_value = "Bonjour le monde"
|
||||
@@ -103,14 +107,15 @@ async def dummy_diarization():
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_llm():
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.llm import LLM
|
||||
|
||||
class TestLLM(LLM):
|
||||
def __init__(self):
|
||||
self.model_name = "DUMMY MODEL"
|
||||
self.llm_tokenizer = "DUMMY TOKENIZER"
|
||||
|
||||
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
|
||||
# LLM doesn't have get_instance anymore, mocking constructor instead
|
||||
with patch("reflector.llm.LLM") as mock_llm:
|
||||
mock_llm.return_value = TestLLM()
|
||||
yield
|
||||
|
||||
@@ -129,22 +134,19 @@ async def dummy_storage():
|
||||
async def _get_file_url(self, *args, **kwargs):
|
||||
return "http://fake_server/audio.mp3"
|
||||
|
||||
with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
|
||||
mock_storage.return_value = DummyStorage()
|
||||
yield
|
||||
async def _get_file(self, *args, **kwargs):
|
||||
from pathlib import Path
|
||||
|
||||
test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
||||
return test_mp3.read_bytes()
|
||||
|
||||
@pytest.fixture
|
||||
def nltk():
|
||||
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
|
||||
mock_nltk.return_value = "NLTK PACKAGE"
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ensure_casing():
|
||||
with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
|
||||
mock_casing.return_value = "LLM TITLE"
|
||||
dummy = DummyStorage()
|
||||
with (
|
||||
patch("reflector.storage.base.Storage.get_instance") as mock_storage,
|
||||
patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
|
||||
):
|
||||
mock_storage.return_value = dummy
|
||||
mock_get_transcripts.return_value = dummy
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processor_broadcast(nltk):
|
||||
async def test_processor_broadcast():
|
||||
from reflector.processors.base import BroadcastProcessor, Pipeline, Processor
|
||||
|
||||
class TestProcessor(Processor):
|
||||
|
||||
@@ -3,11 +3,9 @@ import pytest
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_process(
|
||||
nltk,
|
||||
dummy_transcript,
|
||||
dummy_llm,
|
||||
dummy_processors,
|
||||
ensure_casing,
|
||||
):
|
||||
# goal is to start the server, and send rtc audio to it
|
||||
# validate the events received
|
||||
@@ -16,8 +14,8 @@ async def test_basic_process(
|
||||
from reflector.settings import settings
|
||||
from reflector.tools.process import process_audio_file
|
||||
|
||||
# use an LLM test backend
|
||||
settings.LLM_BACKEND = "test"
|
||||
# LLM_BACKEND no longer exists in settings
|
||||
# settings.LLM_BACKEND = "test"
|
||||
settings.TRANSCRIPT_BACKEND = "whisper"
|
||||
|
||||
# event callback
|
||||
|
||||
@@ -10,7 +10,6 @@ from httpx import AsyncClient
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_process(
|
||||
tmpdir,
|
||||
ensure_casing,
|
||||
dummy_llm,
|
||||
dummy_processors,
|
||||
dummy_diarization,
|
||||
@@ -69,7 +68,7 @@ async def test_transcript_process(
|
||||
transcript = resp.json()
|
||||
assert transcript["status"] == "ended"
|
||||
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
|
||||
assert transcript["title"] == "LLM TITLE"
|
||||
assert transcript["title"] == "Llm Title"
|
||||
|
||||
# check topics and transcript
|
||||
response = await ac.get(f"/transcripts/{tid}/topics")
|
||||
|
||||
@@ -69,8 +69,6 @@ async def test_transcript_rtc_and_websocket(
|
||||
dummy_diarization,
|
||||
dummy_storage,
|
||||
fake_mp3_upload,
|
||||
ensure_casing,
|
||||
nltk,
|
||||
appserver,
|
||||
):
|
||||
# goal: start the server, exchange RTC, receive websocket events
|
||||
@@ -185,7 +183,7 @@ async def test_transcript_rtc_and_websocket(
|
||||
|
||||
assert "FINAL_TITLE" in eventnames
|
||||
ev = events[eventnames.index("FINAL_TITLE")]
|
||||
assert ev["data"]["title"] == "LLM TITLE"
|
||||
assert ev["data"]["title"] == "Llm Title"
|
||||
|
||||
assert "WAVEFORM" in eventnames
|
||||
ev = events[eventnames.index("WAVEFORM")]
|
||||
@@ -228,8 +226,6 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
dummy_diarization,
|
||||
dummy_storage,
|
||||
fake_mp3_upload,
|
||||
ensure_casing,
|
||||
nltk,
|
||||
appserver,
|
||||
):
|
||||
# goal: start the server, exchange RTC, receive websocket events
|
||||
@@ -353,7 +349,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
|
||||
assert "FINAL_TITLE" in eventnames
|
||||
ev = events[eventnames.index("FINAL_TITLE")]
|
||||
assert ev["data"]["title"] == "LLM TITLE"
|
||||
assert ev["data"]["title"] == "Llm Title"
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
|
||||
@@ -10,7 +10,6 @@ from httpx import AsyncClient
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_upload_file(
|
||||
tmpdir,
|
||||
ensure_casing,
|
||||
dummy_llm,
|
||||
dummy_processors,
|
||||
dummy_diarization,
|
||||
@@ -53,7 +52,7 @@ async def test_transcript_upload_file(
|
||||
transcript = resp.json()
|
||||
assert transcript["status"] == "ended"
|
||||
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
|
||||
assert transcript["title"] == "LLM TITLE"
|
||||
assert transcript["title"] == "Llm Title"
|
||||
|
||||
# check topics and transcript
|
||||
response = await ac.get(f"/transcripts/{tid}/topics")
|
||||
|
||||
21
server/tests/test_utils_text.py
Normal file
21
server/tests/test_utils_text.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
|
||||
from reflector.utils.text import clean_title
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("input_title", "expected"),
    [
        # Case normalization regardless of the input casing
        ("hello world", "Hello World"),
        ("HELLO WORLD", "Hello World"),
        ("hello WORLD", "Hello World"),
        ("MiXeD CaSe WoRdS", "Mixed Case Words"),
        # Words of 3 characters or fewer are lowercased unless they come first
        ("the quick brown fox", "The Quick Brown fox"),
        ("discussion about API design", "Discussion About api Design"),
        ("Q1 2024 budget review", "Q1 2024 Budget Review"),
        # Surrounding quotes are removed
        ("'Title with quotes'", "Title With Quotes"),
        ("'title with quotes'", "Title With Quotes"),
    ],
)
def test_clean_title(input_title, expected):
    """Check that clean_title produces the expected casing for each sample."""
    cleaned = clean_title(input_title)
    assert cleaned == expected
|
||||
Reference in New Issue
Block a user