Renamed max_tokens to max_completion_tokens

vasilije 2025-08-17 12:39:12 +02:00
parent c4ec6799a6
commit 1bd40f1401
20 changed files with 92 additions and 73 deletions


@@ -91,7 +91,7 @@ async def cognify(
- LangchainChunker: Recursive character splitting with overlap
Determines how documents are segmented for processing.
chunk_size: Maximum tokens per chunk. Auto-calculated based on LLM if None.
Formula: min(embedding_max_tokens, llm_max_tokens // 2)
Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
Default limits: ~512-8192 tokens, depending on the model.
Smaller chunks = more granular but potentially fragmented knowledge.
ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types.
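For orientation, a minimal sketch of the chunk-size formula from the docstring above, with illustrative numbers that are not taken from this commit:

# Illustrative values only; the real limits come from the configured models.
embedding_max_completion_tokens = 8191
llm_max_completion_tokens = 16384

# A chunk must fit the embedding model and stay within half of the LLM context.
chunk_size = min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
print(chunk_size)  # min(8191, 8192) -> 8191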


@@ -70,7 +70,7 @@ class ResponseRequest(InDTO):
tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
user: Optional[str] = None
temperature: Optional[float] = 1.0
max_tokens: Optional[int] = None
max_completion_tokens: Optional[int] = None
class ToolCallOutput(BaseModel):


@@ -41,11 +41,11 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
self,
model: Optional[str] = "openai/text-embedding-3-large",
dimensions: Optional[int] = 3072,
max_tokens: int = 512,
max_completion_tokens: int = 512,
):
self.model = model
self.dimensions = dimensions
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.tokenizer = self.get_tokenizer()
# self.retry_count = 0
self.embedding_model = TextEmbedding(model_name=model)
@@ -112,7 +112,9 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
"""
logger.debug("Loading tokenizer for FastembedEmbeddingEngine...")
tokenizer = TikTokenTokenizer(model="gpt-4o", max_tokens=self.max_tokens)
tokenizer = TikTokenTokenizer(
model="gpt-4o", max_completion_tokens=self.max_completion_tokens
)
logger.debug("Tokenizer loaded for for FastembedEmbeddingEngine")
return tokenizer
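A hypothetical construction sketch showing the renamed keyword for callers of this engine (the import path is assumed, not taken from this diff):

# Assumed import path; adjust to the actual cognee module layout.
from cognee.infrastructure.databases.vector.embeddings import FastembedEmbeddingEngine

engine = FastembedEmbeddingEngine(
    model="openai/text-embedding-3-large",
    dimensions=3072,
    max_completion_tokens=512,  # formerly max_tokens
)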


@@ -57,7 +57,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str = None,
endpoint: str = None,
api_version: str = None,
max_tokens: int = 512,
max_completion_tokens: int = 512,
):
self.api_key = api_key
self.endpoint = endpoint
@@ -65,7 +65,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.provider = provider
self.model = model
self.dimensions = dimensions
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.tokenizer = self.get_tokenizer()
self.retry_count = 0
@@ -179,20 +179,29 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
model = self.model.split("/")[-1]
if "openai" in self.provider.lower():
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
tokenizer = TikTokenTokenizer(
model=model, max_completion_tokens=self.max_completion_tokens
)
elif "gemini" in self.provider.lower():
tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
tokenizer = GeminiTokenizer(
model=model, max_completion_tokens=self.max_completion_tokens
)
elif "mistral" in self.provider.lower():
tokenizer = MistralTokenizer(model=model, max_tokens=self.max_tokens)
tokenizer = MistralTokenizer(
model=model, max_completion_tokens=self.max_completion_tokens
)
else:
try:
tokenizer = HuggingFaceTokenizer(
model=self.model.replace("hosted_vllm/", ""), max_tokens=self.max_tokens
model=self.model.replace("hosted_vllm/", ""),
max_completion_tokens=self.max_completion_tokens,
)
except Exception as e:
logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
logger.info("Switching to TikToken default tokenizer.")
tokenizer = TikTokenTokenizer(model=None, max_tokens=self.max_tokens)
tokenizer = TikTokenTokenizer(
model=None, max_completion_tokens=self.max_completion_tokens
)
logger.debug(f"Tokenizer loaded for model: {self.model}")
return tokenizer


@@ -30,7 +30,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
Instance variables:
- model
- dimensions
- max_tokens
- max_completion_tokens
- endpoint
- mock
- huggingface_tokenizer_name
@@ -39,7 +39,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
model: str
dimensions: int
max_tokens: int
max_completion_tokens: int
endpoint: str
mock: bool
huggingface_tokenizer_name: str
@@ -50,13 +50,13 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
self,
model: Optional[str] = "avr/sfr-embedding-mistral:latest",
dimensions: Optional[int] = 1024,
max_tokens: int = 512,
max_completion_tokens: int = 512,
endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
):
self.model = model
self.dimensions = dimensions
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.endpoint = endpoint
self.huggingface_tokenizer_name = huggingface_tokenizer
self.tokenizer = self.get_tokenizer()
@@ -132,7 +132,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
"""
logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...")
tokenizer = HuggingFaceTokenizer(
model=self.huggingface_tokenizer_name, max_tokens=self.max_tokens
model=self.huggingface_tokenizer_name, max_completion_tokens=self.max_completion_tokens
)
logger.debug("Tokenizer loaded for OllamaEmbeddingEngine")
return tokenizer


@@ -18,7 +18,7 @@ class EmbeddingConfig(BaseSettings):
embedding_endpoint: Optional[str] = None
embedding_api_key: Optional[str] = None
embedding_api_version: Optional[str] = None
embedding_max_tokens: Optional[int] = 8191
embedding_max_completion_tokens: Optional[int] = 8191
huggingface_tokenizer: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@@ -38,7 +38,7 @@ class EmbeddingConfig(BaseSettings):
"embedding_endpoint": self.embedding_endpoint,
"embedding_api_key": self.embedding_api_key,
"embedding_api_version": self.embedding_api_version,
"embedding_max_tokens": self.embedding_max_tokens,
"embedding_max_completion_tokens": self.embedding_max_completion_tokens,
"huggingface_tokenizer": self.huggingface_tokenizer,
}
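Because EmbeddingConfig is a pydantic BaseSettings class, the renamed field is resolved from a matching environment variable; a minimal sketch, assuming pydantic-settings' default field-name mapping and no env_prefix:

import os

# Entries that still use the old EMBEDDING_MAX_TOKENS name are no longer mapped
# to this field; the default of 8191 applies unless the new name is set.
os.environ["EMBEDDING_MAX_COMPLETION_TOKENS"] = "8191"

config = EmbeddingConfig()  # class shown in the hunk above; import from its cognee module
assert config.to_dict()["embedding_max_completion_tokens"] == 8191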


@@ -27,7 +27,7 @@ def get_embedding_engine() -> EmbeddingEngine:
config.embedding_provider,
config.embedding_model,
config.embedding_dimensions,
config.embedding_max_tokens,
config.embedding_max_completion_tokens,
config.embedding_endpoint,
config.embedding_api_key,
config.embedding_api_version,
@@ -41,7 +41,7 @@ def create_embedding_engine(
embedding_provider,
embedding_model,
embedding_dimensions,
embedding_max_tokens,
embedding_max_completion_tokens,
embedding_endpoint,
embedding_api_key,
embedding_api_version,
@@ -58,7 +58,7 @@ def create_embedding_engine(
'ollama', or another supported provider.
- embedding_model: The model to be used for the embedding engine.
- embedding_dimensions: The number of dimensions for the embeddings.
- embedding_max_tokens: The maximum number of tokens for the embeddings.
- embedding_max_completion_tokens: The maximum number of tokens for the embeddings.
- embedding_endpoint: The endpoint for the embedding service, relevant for certain
providers.
- embedding_api_key: API key to authenticate with the embedding service, if
@@ -81,7 +81,7 @@ def create_embedding_engine(
return FastembedEmbeddingEngine(
model=embedding_model,
dimensions=embedding_dimensions,
max_tokens=embedding_max_tokens,
max_completion_tokens=embedding_max_completion_tokens,
)
if embedding_provider == "ollama":
@@ -90,7 +90,7 @@ def create_embedding_engine(
return OllamaEmbeddingEngine(
model=embedding_model,
dimensions=embedding_dimensions,
max_tokens=embedding_max_tokens,
max_completion_tokens=embedding_max_completion_tokens,
endpoint=embedding_endpoint,
huggingface_tokenizer=huggingface_tokenizer,
)
@@ -104,5 +104,5 @@ def create_embedding_engine(
api_version=embedding_api_version,
model=embedding_model,
dimensions=embedding_dimensions,
max_tokens=embedding_max_tokens,
max_completion_tokens=embedding_max_completion_tokens,
)


@@ -18,7 +18,7 @@ class LLMConfig(BaseSettings):
- llm_api_version
- llm_temperature
- llm_streaming
- llm_max_tokens
- llm_max_completion_tokens
- transcription_model
- graph_prompt_path
- llm_rate_limit_enabled
@@ -41,7 +41,7 @@ class LLMConfig(BaseSettings):
llm_api_version: Optional[str] = None
llm_temperature: float = 0.0
llm_streaming: bool = False
llm_max_tokens: int = 16384
llm_max_completion_tokens: int = 16384
baml_llm_provider: str = "openai"
baml_llm_model: str = "gpt-5-mini"
@@ -171,7 +171,7 @@ class LLMConfig(BaseSettings):
"api_version": self.llm_api_version,
"temperature": self.llm_temperature,
"streaming": self.llm_streaming,
"max_tokens": self.llm_max_tokens,
"max_completion_tokens": self.llm_max_completion_tokens,
"transcription_model": self.transcription_model,
"graph_prompt_path": self.graph_prompt_path,
"rate_limit_enabled": self.llm_rate_limit_enabled,


@@ -23,7 +23,7 @@ class AnthropicAdapter(LLMInterface):
name = "Anthropic"
model: str
def __init__(self, max_tokens: int, model: str = None):
def __init__(self, max_completion_tokens: int, model: str = None):
import anthropic
self.aclient = instructor.patch(
@@ -31,7 +31,7 @@ class AnthropicAdapter(LLMInterface):
)
self.model = model
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
@sleep_and_retry_async()
@rate_limit_async
@@ -57,7 +57,7 @@ class AnthropicAdapter(LLMInterface):
return await self.aclient(
model=self.model,
max_tokens=4096,
max_completion_tokens=4096,
max_retries=5,
messages=[
{


@@ -34,7 +34,7 @@ class GeminiAdapter(LLMInterface):
self,
api_key: str,
model: str,
max_tokens: int,
max_completion_tokens: int,
endpoint: Optional[str] = None,
api_version: Optional[str] = None,
streaming: bool = False,
@@ -44,7 +44,7 @@ class GeminiAdapter(LLMInterface):
self.endpoint = endpoint
self.api_version = api_version
self.streaming = streaming
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
@observe(as_type="generation")
@sleep_and_retry_async()
@@ -90,7 +90,7 @@ class GeminiAdapter(LLMInterface):
model=f"{self.model}",
messages=messages,
api_key=self.api_key,
max_tokens=self.max_tokens,
max_completion_tokens=self.max_completion_tokens,
temperature=0.1,
response_format=response_schema,
timeout=100,


@@ -41,7 +41,7 @@ class GenericAPIAdapter(LLMInterface):
api_key: str,
model: str,
name: str,
max_tokens: int,
max_completion_tokens: int,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@@ -50,7 +50,7 @@ class GenericAPIAdapter(LLMInterface):
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key


@@ -54,11 +54,15 @@ def get_llm_client():
# Check if a max token value is defined in LiteLLM for the given model;
# if not, use the value from the cognee configuration
from cognee.infrastructure.llm.utils import (
get_model_max_tokens,
get_model_max_completion_tokens,
) # imported here to avoid circular imports
model_max_tokens = get_model_max_tokens(llm_config.llm_model)
max_tokens = model_max_tokens if model_max_tokens else llm_config.llm_max_tokens
model_max_completion_tokens = get_model_max_completion_tokens(llm_config.llm_model)
max_completion_tokens = (
model_max_completion_tokens
if model_max_completion_tokens
else llm_config.llm_max_completion_tokens
)
if provider == LLMProvider.OPENAI:
if llm_config.llm_api_key is None:
@@ -74,7 +78,7 @@ def get_llm_client():
api_version=llm_config.llm_api_version,
model=llm_config.llm_model,
transcription_model=llm_config.transcription_model,
max_tokens=max_tokens,
max_completion_tokens=max_completion_tokens,
streaming=llm_config.llm_streaming,
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
@@ -94,7 +98,7 @@ def get_llm_client():
llm_config.llm_api_key,
llm_config.llm_model,
"Ollama",
max_tokens=max_tokens,
max_completion_tokens=max_completion_tokens,
)
elif provider == LLMProvider.ANTHROPIC:
@@ -102,7 +106,9 @@ def get_llm_client():
AnthropicAdapter,
)
return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
return AnthropicAdapter(
max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
)
elif provider == LLMProvider.CUSTOM:
if llm_config.llm_api_key is None:
@@ -117,7 +123,7 @@ def get_llm_client():
llm_config.llm_api_key,
llm_config.llm_model,
"Custom",
max_tokens=max_tokens,
max_completion_tokens=max_completion_tokens,
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
fallback_model=llm_config.fallback_model,
@@ -134,7 +140,7 @@ def get_llm_client():
return GeminiAdapter(
api_key=llm_config.llm_api_key,
model=llm_config.llm_model,
max_tokens=max_tokens,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
api_version=llm_config.llm_api_version,
streaming=llm_config.llm_streaming,


@@ -30,16 +30,18 @@ class OllamaAPIAdapter(LLMInterface):
- model
- api_key
- endpoint
- max_tokens
- max_completion_tokens
- aclient
"""
def __init__(self, endpoint: str, api_key: str, model: str, name: str, max_tokens: int):
def __init__(
self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
):
self.name = name
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.aclient = instructor.from_openai(
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
@@ -159,7 +161,7 @@ class OllamaAPIAdapter(LLMInterface):
],
}
],
max_tokens=300,
max_completion_tokens=300,
)
# Ensure response is valid before accessing .choices[0].message.content


@@ -64,7 +64,7 @@ class OpenAIAdapter(LLMInterface):
api_version: str,
model: str,
transcription_model: str,
max_tokens: int,
max_completion_tokens: int,
streaming: bool = False,
fallback_model: str = None,
fallback_api_key: str = None,
@@ -77,7 +77,7 @@ class OpenAIAdapter(LLMInterface):
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.streaming = streaming
self.fallback_model = fallback_model
@@ -301,7 +301,7 @@ class OpenAIAdapter(LLMInterface):
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
max_tokens=300,
max_completion_tokens=300,
max_retries=self.MAX_RETRIES,
)


@@ -17,10 +17,10 @@ class GeminiTokenizer(TokenizerInterface):
def __init__(
self,
model: str,
max_tokens: int = 3072,
max_completion_tokens: int = 3072,
):
self.model = model
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
# Get LLM API key from config
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config


@@ -14,17 +14,17 @@ class HuggingFaceTokenizer(TokenizerInterface):
Instance variables include:
- model: str
- max_tokens: int
- max_completion_tokens: int
- tokenizer: AutoTokenizer
"""
def __init__(
self,
model: str,
max_tokens: int = 512,
max_completion_tokens: int = 512,
):
self.model = model
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
# Import here to make it an optional dependency
from transformers import AutoTokenizer


@@ -16,17 +16,17 @@ class MistralTokenizer(TokenizerInterface):
Instance variables include:
- model: str
- max_tokens: int
- max_completion_tokens: int
- tokenizer: MistralTokenizer
"""
def __init__(
self,
model: str,
max_tokens: int = 3072,
max_completion_tokens: int = 3072,
):
self.model = model
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
# Import here to make it an optional dependency
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer


@@ -13,10 +13,10 @@ class TikTokenTokenizer(TokenizerInterface):
def __init__(
self,
model: Optional[str] = None,
max_tokens: int = 8191,
max_completion_tokens: int = 8191,
):
self.model = model
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
# Initialize TikToken for GPT based on model
if model:
self.tokenizer = tiktoken.encoding_for_model(self.model)
@@ -93,9 +93,9 @@ class TikTokenTokenizer(TokenizerInterface):
num_tokens = len(self.tokenizer.encode(text))
return num_tokens
def trim_text_to_max_tokens(self, text: str) -> str:
def trim_text_to_max_completion_tokens(self, text: str) -> str:
"""
Trim the text so that the number of tokens does not exceed max_tokens.
Trim the text so that the number of tokens does not exceed max_completion_tokens.
Parameters:
-----------
@@ -111,13 +111,13 @@ class TikTokenTokenizer(TokenizerInterface):
num_tokens = self.count_tokens(text)
# If the number of tokens is within the limit, return the text as is
if num_tokens <= self.max_tokens:
if num_tokens <= self.max_completion_tokens:
return text
# If the number exceeds the limit, trim the text
# This is a simple trim; it may cut words in half. Consider using word boundaries for a cleaner cut.
encoded_text = self.tokenizer.encode(text)
trimmed_encoded_text = encoded_text[: self.max_tokens]
trimmed_encoded_text = encoded_text[: self.max_completion_tokens]
# Decoding the trimmed text
trimmed_text = self.tokenizer.decode(trimmed_encoded_text)
return trimmed_text


@@ -32,13 +32,13 @@ def get_max_chunk_tokens():
# We need to make sure the chunk size won't take more than half of the LLM's max context size,
# but it also can't be bigger than the embedding engine's max token size
llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division
max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
llm_cutoff_point = llm_client.max_completion_tokens // 2 # Round down the division
max_chunk_tokens = min(embedding_engine.max_completion_tokens, llm_cutoff_point)
return max_chunk_tokens
def get_model_max_tokens(model_name: str):
def get_model_max_completion_tokens(model_name: str):
"""
Retrieve the maximum token limit for a specified model name if it exists.
@@ -56,15 +56,15 @@ def get_model_max_tokens(model_name: str):
Maximum number of tokens for the model, or None if the model is unknown
"""
max_tokens = None
max_completion_tokens = None
if model_name in litellm.model_cost:
max_tokens = litellm.model_cost[model_name]["max_tokens"]
logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
max_completion_tokens = litellm.model_cost[model_name]["max_tokens"]
logger.debug(f"Max input tokens for {model_name}: {max_completion_tokens}")
else:
logger.info("Model not found in LiteLLM's model_cost.")
return max_tokens
return max_completion_tokens
async def test_llm_connection():
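For context, a minimal sketch of the LiteLLM lookup that the get_model_max_completion_tokens helper above wraps (the example model name is illustrative; available keys and values depend on the installed litellm version):

import litellm

# model_cost maps model names to metadata dictionaries maintained by LiteLLM.
entry = litellm.model_cost.get("gpt-4o", {})
print(entry.get("max_tokens"))        # the limit read by the helper above
print(entry.get("max_input_tokens"))  # also present for most chat models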


@@ -43,7 +43,7 @@ class QABenchmarkGraphiti(QABenchmarkRAG):
async def initialize_rag(self) -> Any:
"""Initialize Graphiti and LLM."""
llm_config = LLMConfig(model=self.config.model_name, max_tokens=65536)
llm_config = LLMConfig(model=self.config.model_name, max_completion_tokens=65536)
llm_client = OpenAIClient(config=llm_config)
graphiti = Graphiti(
self.config.db_url,