server: fix nltk download

This commit is contained in:
2023-09-13 11:40:39 +02:00
parent 0ac4ee4490
commit fb93c55993

View File

@@ -6,17 +6,17 @@ from typing import TypeVar
import nltk
from prometheus_client import Counter, Histogram
from transformers import GenerationConfig
from reflector.llm.llm_params import TaskParams
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry
from transformers import GenerationConfig
T = TypeVar("T", bound="LLM")
class LLM:
_nltk_downloaded = False
_registry = {}
m_generate = Histogram(
"llm_generate",
@@ -39,18 +39,17 @@ class LLM:
["backend"],
)
def __enter__(self):
self.ensure_nltk()
@classmethod
def ensure_nltk(cls):
"""
Make sure NLTK package is installed. Searches in the cache and
downloads only if needed.
"""
nltk.download("punkt")
# For POS tagging
nltk.download("averaged_perceptron_tagger")
if not cls._nltk_downloaded:
nltk.download("punkt")
# For POS tagging
nltk.download("averaged_perceptron_tagger")
cls._nltk_downloaded = True
@classmethod
def register(cls, name, klass):
@@ -70,6 +69,7 @@ class LLM:
if name not in cls._registry:
module_name = f"reflector.llm.llm_{name}"
importlib.import_module(module_name)
cls.ensure_nltk()
return cls._registry[name](model_name)
def get_model_name(self) -> str: