diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index 48c94423e..b8ce29ccb 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -1,6 +1,8 @@ from typing import Optional from functools import lru_cache from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic import model_validator, Field +import os class LLMConfig(BaseSettings): @@ -16,6 +18,60 @@ class LLMConfig(BaseSettings): model_config = SettingsConfigDict(env_file=".env", extra="allow") + @model_validator(mode="after") + def ensure_env_vars_for_ollama(self) -> "LLMConfig": + """ + Only if llm_provider is 'ollama': + - If any of (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY) is set, all must be set. + - If any of (EMBEDDING_PROVIDER, EMBEDDING_MODEL, EMBEDDING_DIMENSIONS, + HUGGINGFACE_TOKENIZER) is set, all must be set. + Otherwise, skip these checks. + """ + + if self.llm_provider != "ollama": + # Skip checks unless provider is "ollama" + return self + + def is_env_set(var_name: str) -> bool: + """Return True if environment variable is present and non-empty.""" + val = os.environ.get(var_name) + return val is not None and val.strip() != "" + + # + # 1. Check LLM environment variables + # + llm_env_vars = { + "LLM_MODEL": is_env_set("LLM_MODEL"), + "LLM_ENDPOINT": is_env_set("LLM_ENDPOINT"), + "LLM_API_KEY": is_env_set("LLM_API_KEY"), + } + if any(llm_env_vars.values()) and not all(llm_env_vars.values()): + missing_llm = [key for key, is_set in llm_env_vars.items() if not is_set] + raise ValueError( + "You have set some but not all of the required environment variables " + f"for LLM usage (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY). Missing: {missing_llm}" + ) + + # + # 2. Check embedding environment variables + # + embedding_env_vars = { + "EMBEDDING_PROVIDER": is_env_set("EMBEDDING_PROVIDER"), + "EMBEDDING_MODEL": is_env_set("EMBEDDING_MODEL"), + "EMBEDDING_DIMENSIONS": is_env_set("EMBEDDING_DIMENSIONS"), + "HUGGINGFACE_TOKENIZER": is_env_set("HUGGINGFACE_TOKENIZER"), + } + if any(embedding_env_vars.values()) and not all(embedding_env_vars.values()): + missing_embed = [key for key, is_set in embedding_env_vars.items() if not is_set] + raise ValueError( + "You have set some but not all of the required environment variables " + "for embeddings (EMBEDDING_PROVIDER, EMBEDDING_MODEL, " + "EMBEDDING_DIMENSIONS, HUGGINGFACE_TOKENIZER). Missing: " + f"{missing_embed}" + ) + + return self + def to_dict(self) -> dict: return { "provider": self.llm_provider,