From fb93c55993963461ee02db7eeefe4fd0dca6ee0e Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 13 Sep 2023 11:40:39 +0200 Subject: [PATCH] server: fix nltk download --- server/reflector/llm/base.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index c7704a1f..63cc1c50 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -6,17 +6,17 @@ from typing import TypeVar import nltk from prometheus_client import Counter, Histogram -from transformers import GenerationConfig - from reflector.llm.llm_params import TaskParams from reflector.logger import logger as reflector_logger from reflector.settings import settings from reflector.utils.retry import retry +from transformers import GenerationConfig T = TypeVar("T", bound="LLM") class LLM: + _nltk_downloaded = False _registry = {} m_generate = Histogram( "llm_generate", @@ -39,18 +39,17 @@ class LLM: ["backend"], ) - def __enter__(self): - self.ensure_nltk() - @classmethod def ensure_nltk(cls): """ Make sure NLTK package is installed. Searches in the cache and downloads only if needed. """ - nltk.download("punkt") - # For POS tagging - nltk.download("averaged_perceptron_tagger") + if not cls._nltk_downloaded: + nltk.download("punkt") + # For POS tagging + nltk.download("averaged_perceptron_tagger") + cls._nltk_downloaded = True @classmethod def register(cls, name, klass): @@ -70,6 +69,7 @@ class LLM: if name not in cls._registry: module_name = f"reflector.llm.llm_{name}" importlib.import_module(module_name) + cls.ensure_nltk() return cls._registry[name](model_name) def get_model_name(self) -> str: