feat: Use litellm max token size as default for model, if model exists in litellm

This commit is contained in:
Igor Ilic 2025-01-28 17:00:47 +01:00
parent 710ca78d6e
commit a8644e0bd7
4 changed files with 39 additions and 6 deletions

View file

@ -7,7 +7,7 @@ LLM_MODEL="openai/gpt-4o-mini"
LLM_PROVIDER="openai" LLM_PROVIDER="openai"
LLM_ENDPOINT="" LLM_ENDPOINT=""
LLM_API_VERSION="" LLM_API_VERSION=""
LLM_MAX_TOKENS="128000" LLM_MAX_TOKENS="16384"
GRAPHISTRY_USERNAME= GRAPHISTRY_USERNAME=
GRAPHISTRY_PASSWORD= GRAPHISTRY_PASSWORD=

View file

@ -11,7 +11,7 @@ class LLMConfig(BaseSettings):
llm_api_version: Optional[str] = None llm_api_version: Optional[str] = None
llm_temperature: float = 0.0 llm_temperature: float = 0.0
llm_streaming: bool = False llm_streaming: bool = False
llm_max_tokens: int = 128000 llm_max_tokens: int = 16384
transcription_model: str = "whisper-1" transcription_model: str = "whisper-1"
model_config = SettingsConfigDict(env_file=".env", extra="allow") model_config = SettingsConfigDict(env_file=".env", extra="allow")

View file

@ -20,6 +20,16 @@ def get_llm_client():
provider = LLMProvider(llm_config.llm_provider) provider = LLMProvider(llm_config.llm_provider)
# Check if max_token value is defined in liteLLM for given model
# if not use value from cognee configuration
from cognee.infrastructure.llm.utils import get_model_max_tokens

# Look the model up once (avoids a second registry scan) and fall back to the
# configured limit when litellm does not know the model (returns None).
max_tokens = get_model_max_tokens(llm_config.llm_model) or llm_config.llm_max_tokens
if provider == LLMProvider.OPENAI: if provider == LLMProvider.OPENAI:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
raise InvalidValueError(message="LLM API key is not set.") raise InvalidValueError(message="LLM API key is not set.")
@ -32,7 +42,7 @@ def get_llm_client():
api_version=llm_config.llm_api_version, api_version=llm_config.llm_api_version,
model=llm_config.llm_model, model=llm_config.llm_model,
transcription_model=llm_config.transcription_model, transcription_model=llm_config.transcription_model,
max_tokens=llm_config.llm_max_tokens, max_tokens=max_tokens,
streaming=llm_config.llm_streaming, streaming=llm_config.llm_streaming,
) )
@ -47,13 +57,13 @@ def get_llm_client():
llm_config.llm_api_key, llm_config.llm_api_key,
llm_config.llm_model, llm_config.llm_model,
"Ollama", "Ollama",
max_tokens=llm_config.llm_max_tokens, max_tokens=max_tokens,
) )
elif provider == LLMProvider.ANTHROPIC: elif provider == LLMProvider.ANTHROPIC:
from .anthropic.adapter import AnthropicAdapter from .anthropic.adapter import AnthropicAdapter
return AnthropicAdapter(max_tokens=llm_config.llm_max_tokens, model=llm_config.llm_model) return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
elif provider == LLMProvider.CUSTOM: elif provider == LLMProvider.CUSTOM:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
@ -66,7 +76,7 @@ def get_llm_client():
llm_config.llm_api_key, llm_config.llm_api_key,
llm_config.llm_model, llm_config.llm_model,
"Custom", "Custom",
max_tokens=llm_config.llm_max_tokens, max_tokens=max_tokens,
) )
else: else:

View file

@ -1,6 +1,11 @@
import logging
import litellm
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
logger = logging.getLogger(__name__)
def get_max_chunk_tokens(): def get_max_chunk_tokens():
# Calculate max chunk size based on the following formula # Calculate max chunk size based on the following formula
@ -13,3 +18,21 @@ def get_max_chunk_tokens():
max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point) max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
return max_chunk_tokens return max_chunk_tokens
def get_model_max_tokens(model_name: str):
    """
    Look up a model's maximum token limit in LiteLLM's model registry.

    Args:
        model_name: name of LLM or embedding model

    Returns:
        Number of max tokens of the model, or None if the model is unknown
        to LiteLLM (or is registered without a "max_tokens" entry).
    """
    max_tokens = None
    # litellm.model_cost maps model names to metadata dicts. Use .get() so a
    # registered model that lacks a "max_tokens" key yields None instead of
    # raising KeyError.
    model_info = litellm.model_cost.get(model_name)
    if model_info is not None:
        max_tokens = model_info.get("max_tokens")
        # NOTE(review): litellm's "max_tokens" is the model's overall/output
        # limit, not specifically the input limit — message kept generic.
        logger.debug("Max tokens for %s: %s", model_name, max_tokens)
    else:
        logger.info("Model %s not found in LiteLLM's model_cost.", model_name)
    return max_tokens