diff --git a/.env.template b/.env.template
index ec6d01596..df8408518 100644
--- a/.env.template
+++ b/.env.template
@@ -7,7 +7,7 @@ LLM_MODEL="openai/gpt-4o-mini"
 LLM_PROVIDER="openai"
 LLM_ENDPOINT=""
 LLM_API_VERSION=""
-LLM_MAX_TOKENS="128000"
+LLM_MAX_TOKENS="16384"
 
 GRAPHISTRY_USERNAME=
 GRAPHISTRY_PASSWORD=
diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py
index 00dff82b9..48c94423e 100644
--- a/cognee/infrastructure/llm/config.py
+++ b/cognee/infrastructure/llm/config.py
@@ -11,7 +11,7 @@ class LLMConfig(BaseSettings):
     llm_api_version: Optional[str] = None
     llm_temperature: float = 0.0
     llm_streaming: bool = False
-    llm_max_tokens: int = 128000
+    llm_max_tokens: int = 16384
     transcription_model: str = "whisper-1"
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py
index f601f48b2..383355fd2 100644
--- a/cognee/infrastructure/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/get_llm_client.py
@@ -20,6 +20,16 @@ def get_llm_client():
 
     provider = LLMProvider(llm_config.llm_provider)
 
+    # Check if max_token value is defined in liteLLM for given model
+    # if not use value from cognee configuration
+    from cognee.infrastructure.llm.utils import get_model_max_tokens
+
+    max_tokens = (
+        get_model_max_tokens(llm_config.llm_model)
+        if get_model_max_tokens(llm_config.llm_model)
+        else llm_config.llm_max_tokens
+    )
+
     if provider == LLMProvider.OPENAI:
         if llm_config.llm_api_key is None:
             raise InvalidValueError(message="LLM API key is not set.")
@@ -32,7 +42,7 @@
             api_version=llm_config.llm_api_version,
             model=llm_config.llm_model,
             transcription_model=llm_config.transcription_model,
-            max_tokens=llm_config.llm_max_tokens,
+            max_tokens=max_tokens,
             streaming=llm_config.llm_streaming,
         )
 
@@ -47,13 +57,13 @@
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Ollama",
-            max_tokens=llm_config.llm_max_tokens,
+            max_tokens=max_tokens,
         )
 
     elif provider == LLMProvider.ANTHROPIC:
         from .anthropic.adapter import AnthropicAdapter
 
-        return AnthropicAdapter(max_tokens=llm_config.llm_max_tokens, model=llm_config.llm_model)
+        return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
 
     elif provider == LLMProvider.CUSTOM:
         if llm_config.llm_api_key is None:
@@ -66,7 +76,7 @@
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Custom",
-            max_tokens=llm_config.llm_max_tokens,
+            max_tokens=max_tokens,
         )
 
     else:
diff --git a/cognee/infrastructure/llm/utils.py b/cognee/infrastructure/llm/utils.py
index 816eaf285..e0aa8945a 100644
--- a/cognee/infrastructure/llm/utils.py
+++ b/cognee/infrastructure/llm/utils.py
@@ -1,6 +1,11 @@
+import logging
+import litellm
+
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 
+logger = logging.getLogger(__name__)
+
 
 def get_max_chunk_tokens():
     # Calculate max chunk size based on the following formula
@@ -13,3 +18,21 @@
     max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
 
     return max_chunk_tokens
+
+
+def get_model_max_tokens(model_name: str):
+    """
+    Args:
+        model_name: name of LLM or embedding model
+
+    Returns: Number of max tokens of model, or None if model is unknown
+    """
+    max_tokens = None
+
+    if model_name in litellm.model_cost:
+        max_tokens = litellm.model_cost[model_name]["max_tokens"]
+        logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
+    else:
+        logger.info("Model not found in LiteLLM's model_cost.")
+
+    return max_tokens