mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
Feature additions (#210)
* initial * add LLM features * update LLM logic * update llm functions: change control flow * add generation config * update return types * update processors and tests * update rtc_offer * revert new title processor change * fix unit tests * add comments and fix HTTP 500 * adjust prompt * test with reflector app * revert new event for final title * update * move onus onto processors * move onus onto processors * stash * add provision for gen config * dynamically pack the LLM input using context length * tune final summary params * update consolidated class structures * update consolidated class structures * update precommit * add broadcast processors * working baseline * Organize LLMParams * minor fixes * minor fixes * minor fixes * fix unit tests * fix unit tests * fix unit tests * update tests * update tests * edit pipeline response events * update summary return types * configure tests * alembic db migration * change LLM response flow * edit main llm functions * edit main llm functions * change llm name and gen cf * Update transcript_topic_detector.py * PR review comments * checkpoint before db event migration * update DB migration of past events * update DB migration of past events * edit LLM classes * Delete unwanted file * remove List typing * remove List typing * update oobabooga API call * topic enhancements * update UI event handling * move ensure_casing to llm base * update tests * update tests
This commit is contained in:
@@ -2,12 +2,19 @@ import importlib
|
||||
import json
|
||||
import re
|
||||
from time import monotonic
|
||||
from typing import TypeVar
|
||||
|
||||
import nltk
|
||||
from prometheus_client import Counter, Histogram
|
||||
from transformers import GenerationConfig
|
||||
|
||||
from reflector.llm.llm_params import TaskParams
|
||||
from reflector.logger import logger as reflector_logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
T = TypeVar("T", bound="LLM")
|
||||
|
||||
|
||||
class LLM:
|
||||
_registry = {}
|
||||
@@ -32,12 +39,25 @@ class LLM:
|
||||
["backend"],
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
self.ensure_nltk()
|
||||
|
||||
@classmethod
|
||||
def ensure_nltk(cls):
|
||||
"""
|
||||
Make sure NLTK package is installed. Searches in the cache and
|
||||
downloads only if needed.
|
||||
"""
|
||||
nltk.download("punkt", download_dir=settings.CACHE_DIR)
|
||||
# For POS tagging
|
||||
nltk.download("averaged_perceptron_tagger", download_dir=settings.CACHE_DIR)
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, klass):
|
||||
cls._registry[name] = klass
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, name=None):
|
||||
def get_instance(cls, model_name: str | None = None, name: str = None) -> T:
|
||||
"""
|
||||
Return an instance depending on the settings.
|
||||
Settings used:
|
||||
@@ -50,7 +70,39 @@ class LLM:
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.llm.llm_{name}"
|
||||
importlib.import_module(module_name)
|
||||
return cls._registry[name]()
|
||||
return cls._registry[name](model_name)
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
"""
|
||||
Get the currently set model name
|
||||
"""
|
||||
return self._get_model_name()
|
||||
|
||||
def _get_model_name(self) -> str:
|
||||
pass
|
||||
|
||||
def set_model_name(self, model_name: str) -> bool:
|
||||
"""
|
||||
Update the model name with the provided model name
|
||||
"""
|
||||
return self._set_model_name(model_name)
|
||||
|
||||
def _set_model_name(self, model_name: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
"""
|
||||
Return the LLM Prompt template
|
||||
"""
|
||||
return """
|
||||
### Human:
|
||||
{instruct}
|
||||
|
||||
{text}
|
||||
|
||||
### Assistant:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
name = self.__class__.__name__
|
||||
@@ -73,21 +125,39 @@ class LLM:
|
||||
async def _warmup(self, logger: reflector_logger):
|
||||
pass
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
"""
|
||||
Return the tokenizer instance used by LLM
|
||||
"""
|
||||
return self._get_tokenizer()
|
||||
|
||||
def _get_tokenizer(self):
|
||||
pass
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
logger: reflector_logger,
|
||||
schema: dict | None = None,
|
||||
gen_schema: dict | None = None,
|
||||
gen_cfg: GenerationConfig | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
logger.info("LLM generate", prompt=repr(prompt))
|
||||
|
||||
if gen_cfg:
|
||||
gen_cfg = gen_cfg.to_dict()
|
||||
self.m_generate_call.inc()
|
||||
try:
|
||||
with self.m_generate.time():
|
||||
result = await retry(self._generate)(
|
||||
prompt=prompt, schema=schema, **kwargs
|
||||
prompt=prompt,
|
||||
gen_schema=gen_schema,
|
||||
gen_cfg=gen_cfg,
|
||||
**kwargs,
|
||||
)
|
||||
self.m_generate_success.inc()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to call llm after retrying")
|
||||
self.m_generate_failure.inc()
|
||||
@@ -100,7 +170,60 @@ class LLM:
|
||||
|
||||
return result
|
||||
|
||||
async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
|
||||
def ensure_casing(self, title: str) -> str:
|
||||
"""
|
||||
LLM takes care of word casing, but in rare cases this
|
||||
can falter. This is a fallback to ensure the casing of
|
||||
topics is in a proper format.
|
||||
|
||||
We select nouns, verbs and adjectives and check if camel
|
||||
casing is present and fix it, if not. Will not perform
|
||||
any other changes.
|
||||
"""
|
||||
tokens = nltk.word_tokenize(title)
|
||||
pos_tags = nltk.pos_tag(tokens)
|
||||
camel_cased = []
|
||||
|
||||
whitelisted_pos_tags = [
|
||||
"NN",
|
||||
"NNS",
|
||||
"NNP",
|
||||
"NNPS", # Noun POS
|
||||
"VB",
|
||||
"VBD",
|
||||
"VBG",
|
||||
"VBN",
|
||||
"VBP",
|
||||
"VBZ", # Verb POS
|
||||
"JJ",
|
||||
"JJR",
|
||||
"JJS", # Adjective POS
|
||||
]
|
||||
|
||||
# If at all there is an exception, do not block other reflector
|
||||
# processes. Return the LLM generated title, at the least.
|
||||
try:
|
||||
for word, pos in pos_tags:
|
||||
if pos in whitelisted_pos_tags and word[0].islower():
|
||||
camel_cased.append(word[0].upper() + word[1:])
|
||||
else:
|
||||
camel_cased.append(word)
|
||||
modified_title = " ".join(camel_cased)
|
||||
|
||||
# The result can have words in braces with additional space.
|
||||
# Change ( ABC ), [ ABC ], etc. ==> (ABC), [ABC], etc.
|
||||
pattern = r"(?<=[\[\{\(])\s+|\s+(?=[\]\}\)])"
|
||||
title = re.sub(pattern, "", modified_title)
|
||||
except Exception as e:
|
||||
reflector_logger.info(
|
||||
f"Failed to ensure casing on {title=} " f"with exception : {str(e)}"
|
||||
)
|
||||
|
||||
return title
|
||||
|
||||
async def _generate(
|
||||
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _parse_json(self, result: str) -> dict:
|
||||
@@ -122,3 +245,62 @@ class LLM:
|
||||
result = result[:-3]
|
||||
|
||||
return json.loads(result.strip())
|
||||
|
||||
def text_token_threshold(self, task_params: TaskParams | None) -> int:
|
||||
"""
|
||||
Choose the token size to set as the threshold to pack the LLM calls
|
||||
"""
|
||||
buffer_token_size = 25
|
||||
default_output_tokens = 1000
|
||||
context_window = self.tokenizer.model_max_length
|
||||
tokens = self.tokenizer.tokenize(
|
||||
self.create_prompt(instruct=task_params.instruct, text="")
|
||||
)
|
||||
threshold = context_window - len(tokens) - buffer_token_size
|
||||
if task_params.gen_cfg:
|
||||
threshold -= task_params.gen_cfg.max_new_tokens
|
||||
else:
|
||||
threshold -= default_output_tokens
|
||||
return threshold
|
||||
|
||||
def split_corpus(
|
||||
self,
|
||||
corpus: str,
|
||||
task_params: TaskParams,
|
||||
token_threshold: int | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Split the input to the LLM due to CUDA memory limitations and LLM context window
|
||||
restrictions.
|
||||
|
||||
Accumulate tokens from full sentences till threshold and yield accumulated
|
||||
tokens. Reset accumulation when threshold is reached and repeat process.
|
||||
"""
|
||||
if not token_threshold:
|
||||
token_threshold = self.text_token_threshold(task_params=task_params)
|
||||
|
||||
accumulated_tokens = []
|
||||
accumulated_sentences = []
|
||||
accumulated_token_count = 0
|
||||
corpus_sentences = nltk.sent_tokenize(corpus)
|
||||
|
||||
for sentence in corpus_sentences:
|
||||
tokens = self.tokenizer.tokenize(sentence)
|
||||
if accumulated_token_count + len(tokens) <= token_threshold:
|
||||
accumulated_token_count += len(tokens)
|
||||
accumulated_tokens.extend(tokens)
|
||||
accumulated_sentences.append(sentence)
|
||||
else:
|
||||
yield "".join(accumulated_sentences)
|
||||
accumulated_token_count = len(tokens)
|
||||
accumulated_tokens = tokens
|
||||
accumulated_sentences = [sentence]
|
||||
|
||||
if accumulated_tokens:
|
||||
yield " ".join(accumulated_sentences)
|
||||
|
||||
def create_prompt(self, instruct: str, text: str) -> str:
|
||||
"""
|
||||
Create a consumable prompt based on the prompt template
|
||||
"""
|
||||
return self.template.format(instruct=instruct, text=text)
|
||||
|
||||
Reference in New Issue
Block a user