From 4b777cf214e63b8954e01d3970085fe601c4d71b Mon Sep 17 00:00:00 2001
From: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Date: Tue, 25 Feb 2025 21:44:45 -0800
Subject: [PATCH] feat: add validation to llm env variables (#558)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

… needed

## Description

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin

## Summary by CodeRabbit

- **New Features**
  - Implemented enhanced configuration validation for environment-based settings. Now, if any configuration parameter is provided via the environment, all required settings must be present. This improvement helps catch misconfigurations early, reducing potential errors and ensuring a smoother, more reliable user experience. These proactive measures significantly enhance overall system stability and performance.

---------

Co-authored-by: Boris
---
 cognee/infrastructure/llm/config.py | 56 +++++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py
index 48c94423e..b8ce29ccb 100644
--- a/cognee/infrastructure/llm/config.py
+++ b/cognee/infrastructure/llm/config.py
@@ -1,6 +1,8 @@
 from typing import Optional
 from functools import lru_cache
 from pydantic_settings import BaseSettings, SettingsConfigDict
+from pydantic import model_validator, Field
+import os
 
 
 class LLMConfig(BaseSettings):
@@ -16,6 +18,60 @@ class LLMConfig(BaseSettings):
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
+    @model_validator(mode="after")
+    def ensure_env_vars_for_ollama(self) -> "LLMConfig":
+        """
+        For llm_provider 'ollama' only: require complete env-var groups (else no-op).
+          - LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY: if any is set, all must be set.
+          - EMBEDDING_PROVIDER, EMBEDDING_MODEL, EMBEDDING_DIMENSIONS,
+            HUGGINGFACE_TOKENIZER: if any is set, all must be set.
+        Checks os.environ only; raises ValueError listing the missing variable names.
+        """
+
+        if self.llm_provider != "ollama":
+            # Exact, case-sensitive match: any other provider value skips validation
+            return self
+
+        def is_env_set(var_name: str) -> bool:
+            """Return True if var_name is in os.environ with a non-whitespace value."""
+            val = os.environ.get(var_name)
+            return val is not None and val.strip() != ""
+
+        # 1. LLM variables: none set is accepted, a partial set raises.
+        # NOTE(review): only os.environ is read; values defined solely in the .env
+        # file (see env_file in model_config) may go unseen here -- confirm.
+        llm_env_vars = {
+            "LLM_MODEL": is_env_set("LLM_MODEL"),
+            "LLM_ENDPOINT": is_env_set("LLM_ENDPOINT"),
+            "LLM_API_KEY": is_env_set("LLM_API_KEY"),
+        }
+        if any(llm_env_vars.values()) and not all(llm_env_vars.values()):
+            missing_llm = [key for key, is_set in llm_env_vars.items() if not is_set]
+            raise ValueError(
+                "You have set some but not all of the required environment variables "
+                f"for LLM usage (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY). Missing: {missing_llm}"
+            )
+
+        #
+        # 2. Embedding variables: same all-or-nothing rule as the LLM group
+        #
+        embedding_env_vars = {
+            "EMBEDDING_PROVIDER": is_env_set("EMBEDDING_PROVIDER"),
+            "EMBEDDING_MODEL": is_env_set("EMBEDDING_MODEL"),
+            "EMBEDDING_DIMENSIONS": is_env_set("EMBEDDING_DIMENSIONS"),
+            "HUGGINGFACE_TOKENIZER": is_env_set("HUGGINGFACE_TOKENIZER"),
+        }
+        if any(embedding_env_vars.values()) and not all(embedding_env_vars.values()):
+            missing_embed = [key for key, is_set in embedding_env_vars.items() if not is_set]
+            raise ValueError(
+                "You have set some but not all of the required environment variables "
+                "for embeddings (EMBEDDING_PROVIDER, EMBEDDING_MODEL, "
+                "EMBEDDING_DIMENSIONS, HUGGINGFACE_TOKENIZER). Missing: "
+                f"{missing_embed}"
+            )
+
+        return self
+
     def to_dict(self) -> dict:
         return {
             "provider": self.llm_provider,