mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: fix nltk download
This commit is contained in:
@@ -6,17 +6,17 @@ from typing import TypeVar
|
||||
|
||||
import nltk
|
||||
from prometheus_client import Counter, Histogram
|
||||
from transformers import GenerationConfig
|
||||
|
||||
from reflector.llm.llm_params import TaskParams
|
||||
from reflector.logger import logger as reflector_logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
from transformers import GenerationConfig
|
||||
|
||||
T = TypeVar("T", bound="LLM")
|
||||
|
||||
|
||||
class LLM:
|
||||
_nltk_downloaded = False
|
||||
_registry = {}
|
||||
m_generate = Histogram(
|
||||
"llm_generate",
|
||||
@@ -39,18 +39,17 @@ class LLM:
|
||||
["backend"],
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
self.ensure_nltk()
|
||||
|
||||
@classmethod
|
||||
def ensure_nltk(cls):
|
||||
"""
|
||||
Make sure NLTK package is installed. Searches in the cache and
|
||||
downloads only if needed.
|
||||
"""
|
||||
nltk.download("punkt")
|
||||
# For POS tagging
|
||||
nltk.download("averaged_perceptron_tagger")
|
||||
if not cls._nltk_downloaded:
|
||||
nltk.download("punkt")
|
||||
# For POS tagging
|
||||
nltk.download("averaged_perceptron_tagger")
|
||||
cls._nltk_downloaded = True
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, klass):
|
||||
@@ -70,6 +69,7 @@ class LLM:
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.llm.llm_{name}"
|
||||
importlib.import_module(module_name)
|
||||
cls.ensure_nltk()
|
||||
return cls._registry[name](model_name)
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user