cognee/cognee/infrastructure/llm/utils.py
2025-09-27 00:06:33 +02:00

107 lines
3.8 KiB
Python

import litellm
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.shared.logging_utils import get_logger
logger = get_logger()
def get_max_chunk_tokens():
"""
Calculate the maximum number of tokens allowed in a chunk.
The function determines the maximum chunk size based on the maximum token limit of the
embedding engine and half of the LLM maximum context token size. It ensures that the
chunk size does not exceed these constraints.
Returns:
--------
- int: The maximum number of tokens that can be included in a chunk, determined by
the smaller value of the embedding engine's max tokens and half of the LLM's
maximum tokens.
"""
# NOTE: Import must be done in function to avoid circular import issue
from cognee.infrastructure.databases.vector import get_vector_engine
# Calculate max chunk size based on the following formula
embedding_engine = get_vector_engine().embedding_engine
llm_client = get_llm_client(raise_api_key_error=False)
# We need to make sure chunk size won't take more than half of LLM max context token size
# but it also can't be bigger than the embedding engine max token size
llm_cutoff_point = llm_client.max_completion_tokens // 2 # Round down the division
max_chunk_tokens = min(embedding_engine.max_completion_tokens, llm_cutoff_point)
return max_chunk_tokens
def get_model_max_completion_tokens(model_name: str):
"""
Retrieve the maximum token limit for a specified model name if it exists.
Checks if the provided model name is present in the predefined model cost dictionary. If
found, it logs the maximum token count for that model and returns it. If the model name
is not recognized, it logs an informational message and returns None.
Parameters:
-----------
- model_name (str): Name of LLM or embedding model
Returns:
--------
Number of max tokens of model, or None if model is unknown
"""
max_completion_tokens = None
if model_name in litellm.model_cost:
max_completion_tokens = litellm.model_cost[model_name]["max_tokens"]
logger.debug(f"Max input tokens for {model_name}: {max_completion_tokens}")
else:
logger.debug("Model not found in LiteLLM's model_cost.")
return max_completion_tokens
async def test_llm_connection():
"""
Establish a connection to the LLM and create a structured output.
Attempt to connect to the LLM client and uses the adapter to create a structured output
with a predefined text input and system prompt. Log any exceptions encountered during
the connection attempt and re-raise the exception for further handling.
"""
try:
await LLMGateway.acreate_structured_output(
text_input="test",
system_prompt='Respond to me with the following string: "test"',
response_model=str,
)
except Exception as e:
logger.error(e)
logger.error("Connection to LLM could not be established.")
raise e
async def test_embedding_connection():
"""
Test the connection to the embedding engine by embedding a sample text.
Handles exceptions that may occur during the operation, logs the error, and re-raises
the exception if the connection to the embedding handler cannot be established.
"""
try:
# NOTE: Vector engine import must be done in function to avoid circular import issue
from cognee.infrastructure.databases.vector import get_vector_engine
await get_vector_engine().embedding_engine.embed_text("test")
except Exception as e:
logger.error(e)
logger.error("Connection to Embedding handler could not be established.")
raise e