feat: Use litellm max token size as default for model, if model exists in litellm

Igor Ilic 2025-01-28 17:00:47 +01:00
parent 710ca78d6e
commit a8644e0bd7
4 changed files with 39 additions and 6 deletions

View file

@@ -7,7 +7,7 @@ LLM_MODEL="openai/gpt-4o-mini"
 LLM_PROVIDER="openai"
 LLM_ENDPOINT=""
 LLM_API_VERSION=""
-LLM_MAX_TOKENS="128000"
+LLM_MAX_TOKENS="16384"
 GRAPHISTRY_USERNAME=
 GRAPHISTRY_PASSWORD=

View file

@@ -11,7 +11,7 @@ class LLMConfig(BaseSettings):
     llm_api_version: Optional[str] = None
     llm_temperature: float = 0.0
     llm_streaming: bool = False
-    llm_max_tokens: int = 128000
+    llm_max_tokens: int = 16384
     transcription_model: str = "whisper-1"
     model_config = SettingsConfigDict(env_file=".env", extra="allow")

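The lowered default (16384 instead of 128000) only applies when nothing else is configured; a value from .env, or an exported LLM_MAX_TOKENS, still takes precedence through pydantic-settings. A minimal sketch of that precedence, using a hypothetical stand-in class rather than cognee's actual LLMConfig:

    import os

    from pydantic_settings import BaseSettings, SettingsConfigDict


    class DemoLLMConfig(BaseSettings):
        # Hypothetical stand-in mirroring the llm_max_tokens field above.
        llm_max_tokens: int = 16384

        model_config = SettingsConfigDict(env_file=".env", extra="allow")


    os.environ["LLM_MAX_TOKENS"] = "128000"  # in practice set via .env or the shell
    print(DemoLLMConfig().llm_max_tokens)    # -> 128000; the env value overrides the default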
View file

@@ -20,6 +20,16 @@ def get_llm_client():
     provider = LLMProvider(llm_config.llm_provider)
 
+    # Check if max_token value is defined in liteLLM for given model
+    # if not use value from cognee configuration
+    from cognee.infrastructure.llm.utils import get_model_max_tokens
+
+    max_tokens = (
+        get_model_max_tokens(llm_config.llm_model)
+        if get_model_max_tokens(llm_config.llm_model)
+        else llm_config.llm_max_tokens
+    )
+
     if provider == LLMProvider.OPENAI:
         if llm_config.llm_api_key is None:
             raise InvalidValueError(message="LLM API key is not set.")
@@ -32,7 +42,7 @@ def get_llm_client():
             api_version=llm_config.llm_api_version,
             model=llm_config.llm_model,
             transcription_model=llm_config.transcription_model,
-            max_tokens=llm_config.llm_max_tokens,
+            max_tokens=max_tokens,
             streaming=llm_config.llm_streaming,
         )
@@ -47,13 +57,13 @@ def get_llm_client():
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Ollama",
-            max_tokens=llm_config.llm_max_tokens,
+            max_tokens=max_tokens,
         )
 
     elif provider == LLMProvider.ANTHROPIC:
         from .anthropic.adapter import AnthropicAdapter
 
-        return AnthropicAdapter(max_tokens=llm_config.llm_max_tokens, model=llm_config.llm_model)
+        return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
 
     elif provider == LLMProvider.CUSTOM:
         if llm_config.llm_api_key is None:
@@ -66,7 +76,7 @@ def get_llm_client():
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Custom",
-            max_tokens=llm_config.llm_max_tokens,
+            max_tokens=max_tokens,
         )
 
     else:

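Note on the hunks above: every provider branch now receives a single max_tokens value resolved before the branching. The resolution amounts to "use litellm's known limit for the model, otherwise keep the configured default". A minimal sketch of that pattern outside the diff (the model name and the 16384 fallback are example values only):

    from cognee.infrastructure.llm.utils import get_model_max_tokens

    configured_default = 16384         # stands in for llm_config.llm_max_tokens
    model_name = "openai/gpt-4o-mini"  # example; any LLM_MODEL value

    # get_model_max_tokens() returns None for models litellm does not know,
    # so `or` falls back to the configured default in exactly that case.
    max_tokens = get_model_max_tokens(model_name) or configured_default

The committed code spells this out with a conditional expression that calls get_model_max_tokens twice; the `or` form is only shorthand here, not what the diff adds.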
View file

@@ -1,6 +1,11 @@
+import logging
+
+import litellm
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 
+logger = logging.getLogger(__name__)
+
 
 def get_max_chunk_tokens():
     # Calculate max chunk size based on the following formula
@@ -13,3 +18,21 @@ def get_max_chunk_tokens():
     max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
 
     return max_chunk_tokens
+
+
+def get_model_max_tokens(model_name: str):
+    """
+    Args:
+        model_name: name of LLM or embedding model
+
+    Returns: Number of max tokens of model, or None if model is unknown
+    """
+    max_tokens = None
+
+    if model_name in litellm.model_cost:
+        max_tokens = litellm.model_cost[model_name]["max_tokens"]
+        logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
+    else:
+        logger.info("Model not found in LiteLLM's model_cost.")
+    return max_tokens
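
A quick way to exercise the new helper (the model names below are examples only; whether a name resolves depends on the keys litellm ships in model_cost):

    from cognee.infrastructure.llm.utils import get_model_max_tokens

    # Known model: returns the "max_tokens" entry from litellm.model_cost.
    print(get_model_max_tokens("gpt-4o-mini"))

    # Unknown model: logs "Model not found in LiteLLM's model_cost." and returns None,
    # which makes get_llm_client() fall back to llm_config.llm_max_tokens.
    print(get_model_max_tokens("some-private-model"))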