diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index 950a1a07..c93d6099 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -51,6 +51,7 @@ class LLM: nltk.download("punkt", download_dir=settings.CACHE_DIR) # For POS tagging nltk.download("averaged_perceptron_tagger", download_dir=settings.CACHE_DIR) + nltk.data.path.append(settings.CACHE_DIR) @classmethod def register(cls, name, klass): diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 3fc45819..249661db 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -95,7 +95,7 @@ class Settings(BaseSettings): DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5" # Cache directory for all model storage - CACHE_DIR: str = "data" + CACHE_DIR: str = "./data" settings = Settings()