server: fix nltk download

This commit is contained in:
2023-09-13 11:40:39 +02:00
parent 0ac4ee4490
commit fb93c55993

View File

@@ -6,17 +6,17 @@ from typing import TypeVar
import nltk import nltk
from prometheus_client import Counter, Histogram from prometheus_client import Counter, Histogram
from transformers import GenerationConfig
from reflector.llm.llm_params import TaskParams from reflector.llm.llm_params import TaskParams
from reflector.logger import logger as reflector_logger from reflector.logger import logger as reflector_logger
from reflector.settings import settings from reflector.settings import settings
from reflector.utils.retry import retry from reflector.utils.retry import retry
from transformers import GenerationConfig
T = TypeVar("T", bound="LLM") T = TypeVar("T", bound="LLM")
class LLM: class LLM:
_nltk_downloaded = False
_registry = {} _registry = {}
m_generate = Histogram( m_generate = Histogram(
"llm_generate", "llm_generate",
@@ -39,18 +39,17 @@ class LLM:
["backend"], ["backend"],
) )
def __enter__(self):
self.ensure_nltk()
@classmethod @classmethod
def ensure_nltk(cls): def ensure_nltk(cls):
""" """
Make sure NLTK package is installed. Searches in the cache and Make sure NLTK package is installed. Searches in the cache and
downloads only if needed. downloads only if needed.
""" """
nltk.download("punkt") if not cls._nltk_downloaded:
# For POS tagging nltk.download("punkt")
nltk.download("averaged_perceptron_tagger") # For POS tagging
nltk.download("averaged_perceptron_tagger")
cls._nltk_downloaded = True
@classmethod @classmethod
def register(cls, name, klass): def register(cls, name, klass):
@@ -70,6 +69,7 @@ class LLM:
if name not in cls._registry: if name not in cls._registry:
module_name = f"reflector.llm.llm_{name}" module_name = f"reflector.llm.llm_{name}"
importlib.import_module(module_name) importlib.import_module(module_name)
cls.ensure_nltk()
return cls._registry[name](model_name) return cls._registry[name](model_name)
def get_model_name(self) -> str: def get_model_name(self) -> str: