feat: Use litellm max token size as default for model, if model exists in litellm
This commit is contained in:
parent
710ca78d6e
commit
a8644e0bd7
4 changed files with 39 additions and 6 deletions
|
|
@ -7,7 +7,7 @@ LLM_MODEL="openai/gpt-4o-mini"
|
||||||
LLM_PROVIDER="openai"
|
LLM_PROVIDER="openai"
|
||||||
LLM_ENDPOINT=""
|
LLM_ENDPOINT=""
|
||||||
LLM_API_VERSION=""
|
LLM_API_VERSION=""
|
||||||
LLM_MAX_TOKENS="128000"
|
LLM_MAX_TOKENS="16384"
|
||||||
|
|
||||||
GRAPHISTRY_USERNAME=
|
GRAPHISTRY_USERNAME=
|
||||||
GRAPHISTRY_PASSWORD=
|
GRAPHISTRY_PASSWORD=
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ class LLMConfig(BaseSettings):
|
||||||
llm_api_version: Optional[str] = None
|
llm_api_version: Optional[str] = None
|
||||||
llm_temperature: float = 0.0
|
llm_temperature: float = 0.0
|
||||||
llm_streaming: bool = False
|
llm_streaming: bool = False
|
||||||
llm_max_tokens: int = 128000
|
llm_max_tokens: int = 16384
|
||||||
transcription_model: str = "whisper-1"
|
transcription_model: str = "whisper-1"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,16 @@ def get_llm_client():
|
||||||
|
|
||||||
provider = LLMProvider(llm_config.llm_provider)
|
provider = LLMProvider(llm_config.llm_provider)
|
||||||
|
|
||||||
|
# Check if max_token value is defined in liteLLM for given model
|
||||||
|
# if not use value from cognee configuration
|
||||||
|
from cognee.infrastructure.llm.utils import get_model_max_tokens
|
||||||
|
|
||||||
|
max_tokens = (
|
||||||
|
get_model_max_tokens(llm_config.llm_model)
|
||||||
|
if get_model_max_tokens(llm_config.llm_model)
|
||||||
|
else llm_config.llm_max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
if provider == LLMProvider.OPENAI:
|
if provider == LLMProvider.OPENAI:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None:
|
||||||
raise InvalidValueError(message="LLM API key is not set.")
|
raise InvalidValueError(message="LLM API key is not set.")
|
||||||
|
|
@ -32,7 +42,7 @@ def get_llm_client():
|
||||||
api_version=llm_config.llm_api_version,
|
api_version=llm_config.llm_api_version,
|
||||||
model=llm_config.llm_model,
|
model=llm_config.llm_model,
|
||||||
transcription_model=llm_config.transcription_model,
|
transcription_model=llm_config.transcription_model,
|
||||||
max_tokens=llm_config.llm_max_tokens,
|
max_tokens=max_tokens,
|
||||||
streaming=llm_config.llm_streaming,
|
streaming=llm_config.llm_streaming,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -47,13 +57,13 @@ def get_llm_client():
|
||||||
llm_config.llm_api_key,
|
llm_config.llm_api_key,
|
||||||
llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
"Ollama",
|
"Ollama",
|
||||||
max_tokens=llm_config.llm_max_tokens,
|
max_tokens=max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif provider == LLMProvider.ANTHROPIC:
|
elif provider == LLMProvider.ANTHROPIC:
|
||||||
from .anthropic.adapter import AnthropicAdapter
|
from .anthropic.adapter import AnthropicAdapter
|
||||||
|
|
||||||
return AnthropicAdapter(max_tokens=llm_config.llm_max_tokens, model=llm_config.llm_model)
|
return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
|
||||||
|
|
||||||
elif provider == LLMProvider.CUSTOM:
|
elif provider == LLMProvider.CUSTOM:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None:
|
||||||
|
|
@ -66,7 +76,7 @@ def get_llm_client():
|
||||||
llm_config.llm_api_key,
|
llm_config.llm_api_key,
|
||||||
llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
"Custom",
|
"Custom",
|
||||||
max_tokens=llm_config.llm_max_tokens,
|
max_tokens=max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,11 @@
|
||||||
|
import logging
|
||||||
|
import litellm
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_max_chunk_tokens():
|
def get_max_chunk_tokens():
|
||||||
# Calculate max chunk size based on the following formula
|
# Calculate max chunk size based on the following formula
|
||||||
|
|
@ -13,3 +18,21 @@ def get_max_chunk_tokens():
|
||||||
max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
|
max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
|
||||||
|
|
||||||
return max_chunk_tokens
|
return max_chunk_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_max_tokens(model_name: str):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model_name: name of LLM or embedding model
|
||||||
|
|
||||||
|
Returns: Number of max tokens of model, or None if model is unknown
|
||||||
|
"""
|
||||||
|
max_tokens = None
|
||||||
|
|
||||||
|
if model_name in litellm.model_cost:
|
||||||
|
max_tokens = litellm.model_cost[model_name]["max_tokens"]
|
||||||
|
logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
|
||||||
|
else:
|
||||||
|
logger.info("Model not found in LiteLLM's model_cost.")
|
||||||
|
|
||||||
|
return max_tokens
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue