import os

from typing import Optional, ClassVar
from functools import lru_cache

from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import model_validator

try:
    from baml_py import ClientRegistry
except ImportError:
    ClientRegistry = None


class LLMConfig(BaseSettings):
    """
    Configuration settings for the LLM (Large Language Model) provider and related options.

    Public instance variables include:
    - llm_provider
    - llm_model
    - llm_endpoint
    - llm_api_key
    - llm_api_version
    - llm_temperature
    - llm_streaming
    - llm_max_completion_tokens
    - transcription_model
    - graph_prompt_path
    - llm_rate_limit_enabled
    - llm_rate_limit_requests
    - llm_rate_limit_interval
    - embedding_rate_limit_enabled
    - embedding_rate_limit_requests
    - embedding_rate_limit_interval

    Public methods include:
    - ensure_env_vars_for_ollama
    - to_dict
    """

    structured_output_framework: str = "instructor"
    llm_provider: str = "openai"
    llm_model: str = "gpt-5-mini"
    llm_endpoint: str = ""
    llm_api_key: Optional[str] = None
    llm_api_version: Optional[str] = None
    llm_temperature: float = 0.0
    llm_streaming: bool = False
    llm_max_completion_tokens: int = 16384

    baml_llm_provider: str = "openai"
    baml_llm_model: str = "gpt-5-mini"
    baml_llm_endpoint: str = ""
    baml_llm_api_key: Optional[str] = None
    baml_llm_temperature: float = 0.0
    baml_llm_api_version: str = ""

    transcription_model: str = "whisper-1"
    graph_prompt_path: str = "generate_graph_prompt.txt"
    temporal_graph_prompt_path: str = "generate_event_graph_prompt.txt"
    event_entity_prompt_path: str = "generate_event_entity_prompt.txt"
    llm_rate_limit_enabled: bool = False
    llm_rate_limit_requests: int = 60
    llm_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)
    embedding_rate_limit_enabled: bool = False
    embedding_rate_limit_requests: int = 60
    embedding_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)

    fallback_api_key: str = ""
    fallback_endpoint: str = ""
    fallback_model: str = ""

    baml_registry: ClassVar = None

    model_config = SettingsConfigDict(env_file=".env", extra="allow")
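    # Note: pydantic-settings maps each field to an environment variable of the same name
    # (case-insensitive), so values such as LLM_PROVIDER or LLM_MODEL in the .env file
    # referenced above override the defaults defined here.
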
    def model_post_init(self, __context) -> None:
        """Initialize the BAML registry after the model is created."""
        # Check if BAML is selected as structured output framework but not available
        if self.structured_output_framework == "baml" and ClientRegistry is None:
            raise ImportError(
                "BAML is selected as structured output framework but not available. "
                "Please install with 'pip install cognee[baml]' to use BAML extraction features."
            )

        if ClientRegistry is not None:
            self.baml_registry = ClientRegistry()
            self.baml_registry.add_llm_client(
                name=self.baml_llm_provider,
                provider=self.baml_llm_provider,
                options={
                    "model": self.baml_llm_model,
                    "temperature": self.baml_llm_temperature,
                    "api_key": self.baml_llm_api_key,
                    "base_url": self.baml_llm_endpoint,
                    "api_version": self.baml_llm_api_version,
                },
            )
            # Set the primary client
            self.baml_registry.set_primary(self.baml_llm_provider)

    @model_validator(mode="after")
    def ensure_env_vars_for_ollama(self) -> "LLMConfig":
        """
        Validate required environment variables for the 'ollama' LLM provider.

        Raises ValueError if some of the required environment variables are set without the others.
        Checks are only performed when 'llm_provider' is set to 'ollama'.

        Returns:
        --------

        - 'LLMConfig': The instance of LLMConfig after validation.
        """
        if self.llm_provider != "ollama":
            # Skip checks unless provider is "ollama"
            return self

        def is_env_set(var_name: str) -> bool:
            """
            Check if a given environment variable is set and non-empty.

            Parameters:
            -----------

            - var_name (str): The name of the environment variable to check.

            Returns:
            --------

            - bool: True if the environment variable exists and is not empty, otherwise False.
            """
            val = os.environ.get(var_name)
            return val is not None and val.strip() != ""

        #
        # 1. Check LLM environment variables
        #
        llm_env_vars = {
            "LLM_MODEL": is_env_set("LLM_MODEL"),
            "LLM_ENDPOINT": is_env_set("LLM_ENDPOINT"),
            "LLM_API_KEY": is_env_set("LLM_API_KEY"),
        }
        if any(llm_env_vars.values()) and not all(llm_env_vars.values()):
            missing_llm = [key for key, is_set in llm_env_vars.items() if not is_set]
            raise ValueError(
                "You have set some but not all of the required environment variables "
                f"for LLM usage (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY). Missing: {missing_llm}"
            )

        #
        # 2. Check embedding environment variables
        #
        embedding_env_vars = {
            "EMBEDDING_PROVIDER": is_env_set("EMBEDDING_PROVIDER"),
            "EMBEDDING_MODEL": is_env_set("EMBEDDING_MODEL"),
            "EMBEDDING_DIMENSIONS": is_env_set("EMBEDDING_DIMENSIONS"),
            "HUGGINGFACE_TOKENIZER": is_env_set("HUGGINGFACE_TOKENIZER"),
        }
        if any(embedding_env_vars.values()) and not all(embedding_env_vars.values()):
            missing_embed = [key for key, is_set in embedding_env_vars.items() if not is_set]
            raise ValueError(
                "You have set some but not all of the required environment variables "
                "for embeddings (EMBEDDING_PROVIDER, EMBEDDING_MODEL, "
                "EMBEDDING_DIMENSIONS, HUGGINGFACE_TOKENIZER). Missing: "
                f"{missing_embed}"
            )

        return self

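    # Illustrative .env sketch for the "ollama" provider: the variable names match the
    # checks above, but the values are placeholders, not project-documented defaults.
    #
    #   LLM_PROVIDER="ollama"
    #   LLM_MODEL="llama3"
    #   LLM_ENDPOINT="http://localhost:11434/v1"
    #   LLM_API_KEY="ollama"
    #   EMBEDDING_PROVIDER="ollama"
    #   EMBEDDING_MODEL="nomic-embed-text"
    #   EMBEDDING_DIMENSIONS="768"
    #   HUGGINGFACE_TOKENIZER="nomic-ai/nomic-embed-text-v1.5"
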
    def to_dict(self) -> dict:
        """
        Convert the LLMConfig instance into a dictionary representation.

        Returns:
        --------

        - dict: A dictionary containing the configuration settings of the LLMConfig
          instance.
        """
        return {
            "provider": self.llm_provider,
            "model": self.llm_model,
            "endpoint": self.llm_endpoint,
            "api_key": self.llm_api_key,
            "api_version": self.llm_api_version,
            "temperature": self.llm_temperature,
            "streaming": self.llm_streaming,
            "max_completion_tokens": self.llm_max_completion_tokens,
            "transcription_model": self.transcription_model,
            "graph_prompt_path": self.graph_prompt_path,
            "rate_limit_enabled": self.llm_rate_limit_enabled,
            "rate_limit_requests": self.llm_rate_limit_requests,
            "rate_limit_interval": self.llm_rate_limit_interval,
            "embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
            "embedding_rate_limit_requests": self.embedding_rate_limit_requests,
            "embedding_rate_limit_interval": self.embedding_rate_limit_interval,
            "fallback_api_key": self.fallback_api_key,
            "fallback_endpoint": self.fallback_endpoint,
            "fallback_model": self.fallback_model,
        }


@lru_cache
def get_llm_config():
    """
    Retrieve and cache the LLM configuration.

    This function returns an instance of the LLMConfig class. It leverages
    caching to ensure that repeated calls do not create new instances,
    but instead return the already created configuration object.

    Returns:
    --------

    - LLMConfig: An instance of the LLMConfig class containing the configuration for the
      LLM.
    """
    return LLMConfig()
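

# Minimal usage sketch (illustrative only; assumes the defaults above or a populated .env):
if __name__ == "__main__":
    config = get_llm_config()
    # get_llm_config() is wrapped in lru_cache, so repeated calls return the same instance.
    assert config is get_llm_config()
    print(config.llm_provider, config.llm_model)
    print(sorted(config.to_dict().keys()))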