feat: Use litellm max token size as default for model, if model exists in litellm
parent 710ca78d6e
commit a8644e0bd7
4 changed files with 39 additions and 6 deletions
@@ -7,7 +7,7 @@ LLM_MODEL="openai/gpt-4o-mini"
 LLM_PROVIDER="openai"
 LLM_ENDPOINT=""
 LLM_API_VERSION=""
-LLM_MAX_TOKENS="128000"
+LLM_MAX_TOKENS="16384"
 
 GRAPHISTRY_USERNAME=
 GRAPHISTRY_PASSWORD=
@@ -11,7 +11,7 @@ class LLMConfig(BaseSettings):
     llm_api_version: Optional[str] = None
     llm_temperature: float = 0.0
     llm_streaming: bool = False
-    llm_max_tokens: int = 128000
+    llm_max_tokens: int = 16384
     transcription_model: str = "whisper-1"
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
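
For context, a minimal sketch of how the lowered default and the .env template above fit together: pydantic-settings matches the LLM_MAX_TOKENS environment variable (or the .env entry) to the llm_max_tokens field, so the new 16384 value only applies when nothing is set in the environment. The class below simply mirrors the fields visible in this hunk and is an illustration, not the full cognee config module.

    # Illustrative sketch only; mirrors the fields shown in the hunk above.
    from typing import Optional

    from pydantic_settings import BaseSettings, SettingsConfigDict


    class LLMConfig(BaseSettings):
        llm_api_version: Optional[str] = None
        llm_temperature: float = 0.0
        llm_streaming: bool = False
        llm_max_tokens: int = 16384          # new default; LLM_MAX_TOKENS overrides it
        transcription_model: str = "whisper-1"

        model_config = SettingsConfigDict(env_file=".env", extra="allow")


    config = LLMConfig()   # picks up LLM_MAX_TOKENS from the environment or .env if present
    print(config.llm_max_tokens)
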
@@ -20,6 +20,16 @@ def get_llm_client():
 
     provider = LLMProvider(llm_config.llm_provider)
 
+    # Check if max_token value is defined in liteLLM for given model
+    # if not use value from cognee configuration
+    from cognee.infrastructure.llm.utils import get_model_max_tokens
+
+    max_tokens = (
+        get_model_max_tokens(llm_config.llm_model)
+        if get_model_max_tokens(llm_config.llm_model)
+        else llm_config.llm_max_tokens
+    )
+
     if provider == LLMProvider.OPENAI:
         if llm_config.llm_api_key is None:
             raise InvalidValueError(message="LLM API key is not set.")
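
As written, the conditional expression calls get_model_max_tokens() twice when the model is known to LiteLLM. A slightly tighter, behaviour-equivalent sketch, assuming the helper returns either a positive integer or None (as the new utility function below does):

    # Sketch only: same fallback semantics as the added block, single lookup.
    model_max_tokens = get_model_max_tokens(llm_config.llm_model)
    max_tokens = model_max_tokens or llm_config.llm_max_tokens
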
@@ -32,7 +42,7 @@ def get_llm_client():
             api_version=llm_config.llm_api_version,
             model=llm_config.llm_model,
             transcription_model=llm_config.transcription_model,
-            max_tokens=llm_config.llm_max_tokens,
+            max_tokens=max_tokens,
             streaming=llm_config.llm_streaming,
         )
 
@@ -47,13 +57,13 @@ def get_llm_client():
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Ollama",
-            max_tokens=llm_config.llm_max_tokens,
+            max_tokens=max_tokens,
         )
 
     elif provider == LLMProvider.ANTHROPIC:
         from .anthropic.adapter import AnthropicAdapter
 
-        return AnthropicAdapter(max_tokens=llm_config.llm_max_tokens, model=llm_config.llm_model)
+        return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
 
     elif provider == LLMProvider.CUSTOM:
         if llm_config.llm_api_key is None:
@@ -66,7 +76,7 @@ def get_llm_client():
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Custom",
-            max_tokens=llm_config.llm_max_tokens,
+            max_tokens=max_tokens,
         )
 
     else:
@@ -1,6 +1,11 @@
 import logging
+import litellm
 
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 
+logger = logging.getLogger(__name__)
+
 
 def get_max_chunk_tokens():
     # Calculate max chunk size based on the following formula
@@ -13,3 +18,21 @@ def get_max_chunk_tokens():
     max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
 
     return max_chunk_tokens
+
+
+def get_model_max_tokens(model_name: str):
+    """
+    Args:
+        model_name: name of LLM or embedding model
+
+    Returns: Number of max tokens of model, or None if model is unknown
+    """
+    max_tokens = None
+
+    if model_name in litellm.model_cost:
+        max_tokens = litellm.model_cost[model_name]["max_tokens"]
+        logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
+    else:
+        logger.info("Model not found in LiteLLM's model_cost.")
+
+    return max_tokens
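
litellm.model_cost is a plain dictionary keyed by model name, whose per-model metadata includes a "max_tokens" entry, so the helper is just a membership test plus a dictionary lookup. A quick, illustrative way to exercise it (the model names are examples and may or may not exist in a given LiteLLM release):

    from cognee.infrastructure.llm.utils import get_model_max_tokens

    # Model known to LiteLLM: returns its max token count as an integer.
    print(get_model_max_tokens("gpt-4o-mini"))

    # Unknown model: returns None, so get_llm_client() falls back to llm_config.llm_max_tokens.
    print(get_model_max_tokens("my-custom-local-model"))

Note that LiteLLM's metadata also carries separate "max_input_tokens" and "max_output_tokens" fields for many models; this helper reads the generic "max_tokens" entry.
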