import importlib
import json
import re
from collections.abc import Iterator
from time import monotonic
from typing import TypeVar

import nltk
from prometheus_client import Counter, Histogram
from transformers import GenerationConfig

from reflector.llm.llm_params import TaskParams
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry

T = TypeVar("T", bound="LLM")


class LLM:
    _registry = {}

    m_generate = Histogram(
        "llm_generate",
        "Time spent in LLM.generate",
        ["backend"],
    )
    m_generate_call = Counter(
        "llm_generate_call",
        "Number of calls to LLM.generate",
        ["backend"],
    )
    m_generate_success = Counter(
        "llm_generate_success",
        "Number of successful calls to LLM.generate",
        ["backend"],
    )
    m_generate_failure = Counter(
        "llm_generate_failure",
        "Number of failed calls to LLM.generate",
        ["backend"],
    )

    def __enter__(self):
        self.ensure_nltk()
        return self

    @classmethod
    def ensure_nltk(cls):
        """
        Make sure the required NLTK data packages are installed.

        Searches the cache and downloads only if needed.
        """
        nltk.download("punkt", download_dir=settings.CACHE_DIR)
        # For POS tagging
        nltk.download("averaged_perceptron_tagger", download_dir=settings.CACHE_DIR)
        nltk.data.path.append(settings.CACHE_DIR)

    @classmethod
    def register(cls, name, klass):
        cls._registry[name] = klass

    @classmethod
    def get_instance(
        cls, model_name: str | None = None, name: str | None = None
    ) -> T:
        """
        Return an instance depending on the settings.

        Settings used:

        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
        - `LLM_URL`: url of the backend
        """
        if name is None:
            name = settings.LLM_BACKEND
        if name not in cls._registry:
            # Importing the backend module registers its class via LLM.register
            module_name = f"reflector.llm.llm_{name}"
            importlib.import_module(module_name)
        return cls._registry[name](model_name)

    def get_model_name(self) -> str:
        """
        Get the currently set model name
        """
        return self._get_model_name()

    def _get_model_name(self) -> str:
        pass

    def set_model_name(self, model_name: str) -> bool:
        """
        Update the model name with the provided model name
        """
        return self._set_model_name(model_name)

    def _set_model_name(self, model_name: str) -> bool:
        raise NotImplementedError

    @property
    def template(self) -> str:
        """
        Return the LLM prompt template
        """
        return """
### Human:
{instruct}

{text}

### Assistant:
"""

    def __init__(self):
        name = self.__class__.__name__
        self.m_generate = self.m_generate.labels(name)
        self.m_generate_call = self.m_generate_call.labels(name)
        self.m_generate_success = self.m_generate_success.labels(name)
        self.m_generate_failure = self.m_generate_failure.labels(name)

    async def warmup(self, logger: reflector_logger):
        start = monotonic()
        name = self.__class__.__name__
        logger.info(f"LLM[{name}] warming up...")
        try:
            await self._warmup(logger=logger)
            duration = monotonic() - start
            logger.info(f"LLM[{name}] warmup took {duration:.2f} seconds")
        except Exception:
            logger.exception(f"LLM[{name}] warmup failed, ignoring")

    async def _warmup(self, logger: reflector_logger):
        pass

    @property
    def tokenizer(self):
        """
        Return the tokenizer instance used by the LLM
        """
        return self._get_tokenizer()

    def _get_tokenizer(self):
        pass

    async def generate(
        self,
        prompt: str,
        logger: reflector_logger,
        gen_schema: dict | None = None,
        gen_cfg: GenerationConfig | None = None,
        **kwargs,
    ) -> dict:
        logger.info("LLM generate", prompt=repr(prompt))
        if gen_cfg:
            gen_cfg = gen_cfg.to_dict()
        self.m_generate_call.inc()
        try:
            with self.m_generate.time():
                result = await retry(self._generate)(
                    prompt=prompt,
                    gen_schema=gen_schema,
                    gen_cfg=gen_cfg,
                    **kwargs,
                )
            self.m_generate_success.inc()
        except Exception:
            logger.exception("Failed to call llm after retrying")
            self.m_generate_failure.inc()
            raise
        logger.debug("LLM result [raw]", result=repr(result))
        if isinstance(result, str):
            result = self._parse_json(result)
        logger.debug("LLM result [parsed]", result=repr(result))
        return result

    def ensure_casing(self, title: str) -> str:
        """
        The LLM takes care of word casing, but in rare cases this can falter.
        This is a fallback to ensure the casing of topics is in a proper format.

        We select nouns, verbs and adjectives, and capitalize them if they are
        not capitalized already. Will not perform any other changes.
        """
        tokens = nltk.word_tokenize(title)
        pos_tags = nltk.pos_tag(tokens)
        camel_cased = []
        whitelisted_pos_tags = [
            "NN", "NNS", "NNP", "NNPS",  # Noun POS
            "VB", "VBD", "VBG", "VBN", "VBP", "VBZ",  # Verb POS
            "JJ", "JJR", "JJS",  # Adjective POS
        ]
        # If at all there is an exception, do not block other reflector
        # processes. Return the LLM generated title, at the least.
        try:
            for word, pos in pos_tags:
                if pos in whitelisted_pos_tags and word[0].islower():
                    camel_cased.append(word[0].upper() + word[1:])
                else:
                    camel_cased.append(word)
            modified_title = " ".join(camel_cased)
            # The result can have words in braces with additional space.
            # Change ( ABC ), [ ABC ], etc. ==> (ABC), [ABC], etc.
            pattern = r"(?<=[\[\{\(])\s+|\s+(?=[\]\}\)])"
            title = re.sub(pattern, "", modified_title)
        except Exception as e:
            reflector_logger.info(
                f"Failed to ensure casing on {title=} with exception: {str(e)}"
            )
        return title

    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ) -> str:
        raise NotImplementedError

    def _parse_json(self, result: str) -> dict:
        result = result.strip()
        # try detecting a code block if one exists:
        # starts with ```json\n, ends with ```
        # or starts with ```\n, ends with ```
        # or starts with \n```javascript\n, ends with ```
        regex = r"```(json|javascript|)?(.*)```"
        matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
        if matches:
            result = matches[0][1]
        else:
            # maybe the prompt has been started with ```json,
            # so if the text ends with ```, just remove it and use it as json
            if result.endswith("```"):
                result = result[:-3]
        return json.loads(result.strip())

    def text_token_threshold(self, task_params: TaskParams | None) -> int:
        """
        Choose the token count to use as the threshold when packing LLM calls
        """
        buffer_token_size = 25
        default_output_tokens = 1000
        context_window = self.tokenizer.model_max_length
        tokens = self.tokenizer.tokenize(
            self.create_prompt(instruct=task_params.instruct, text="")
        )
        threshold = context_window - len(tokens) - buffer_token_size
        if task_params.gen_cfg:
            threshold -= task_params.gen_cfg.max_new_tokens
        else:
            threshold -= default_output_tokens
        return threshold

    def split_corpus(
        self,
        corpus: str,
        task_params: TaskParams,
        token_threshold: int | None = None,
    ) -> Iterator[str]:
        """
        Split the input to the LLM to respect CUDA memory limitations and the
        LLM context window.

        Accumulate tokens from full sentences until the threshold is reached,
        yield the accumulated sentences, then reset the accumulation and
        repeat the process.
        """
        if not token_threshold:
            token_threshold = self.text_token_threshold(task_params=task_params)
        accumulated_tokens = []
        accumulated_sentences = []
        accumulated_token_count = 0
        corpus_sentences = nltk.sent_tokenize(corpus)
        for sentence in corpus_sentences:
            tokens = self.tokenizer.tokenize(sentence)
            if accumulated_token_count + len(tokens) <= token_threshold:
                accumulated_token_count += len(tokens)
                accumulated_tokens.extend(tokens)
                accumulated_sentences.append(sentence)
            else:
                yield " ".join(accumulated_sentences)
                accumulated_token_count = len(tokens)
                accumulated_tokens = tokens
                accumulated_sentences = [sentence]
        if accumulated_tokens:
            yield " ".join(accumulated_sentences)

    def create_prompt(self, instruct: str, text: str) -> str:
        """
        Create a consumable prompt based on the prompt template
        """
        return self.template.format(instruct=instruct, text=text)