renamed max tokens

This commit is contained in:
vasilije 2025-08-17 12:39:12 +02:00
parent c4ec6799a6
commit 1bd40f1401
20 changed files with 92 additions and 73 deletions

View file

@ -91,7 +91,7 @@ async def cognify(
- LangchainChunker: Recursive character splitting with overlap - LangchainChunker: Recursive character splitting with overlap
Determines how documents are segmented for processing. Determines how documents are segmented for processing.
chunk_size: Maximum tokens per chunk. Auto-calculated based on LLM if None. chunk_size: Maximum tokens per chunk. Auto-calculated based on LLM if None.
Formula: min(embedding_max_tokens, llm_max_tokens // 2) Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
Default limits: ~512-8192 tokens depending on models. Default limits: ~512-8192 tokens depending on models.
Smaller chunks = more granular but potentially fragmented knowledge. Smaller chunks = more granular but potentially fragmented knowledge.
ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types. ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types.

View file

@ -70,7 +70,7 @@ class ResponseRequest(InDTO):
tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto" tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
user: Optional[str] = None user: Optional[str] = None
temperature: Optional[float] = 1.0 temperature: Optional[float] = 1.0
max_tokens: Optional[int] = None max_completion_tokens: Optional[int] = None
class ToolCallOutput(BaseModel): class ToolCallOutput(BaseModel):

View file

@ -41,11 +41,11 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
self, self,
model: Optional[str] = "openai/text-embedding-3-large", model: Optional[str] = "openai/text-embedding-3-large",
dimensions: Optional[int] = 3072, dimensions: Optional[int] = 3072,
max_tokens: int = 512, max_completion_tokens: int = 512,
): ):
self.model = model self.model = model
self.dimensions = dimensions self.dimensions = dimensions
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
self.tokenizer = self.get_tokenizer() self.tokenizer = self.get_tokenizer()
# self.retry_count = 0 # self.retry_count = 0
self.embedding_model = TextEmbedding(model_name=model) self.embedding_model = TextEmbedding(model_name=model)
@ -112,7 +112,9 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
""" """
logger.debug("Loading tokenizer for FastembedEmbeddingEngine...") logger.debug("Loading tokenizer for FastembedEmbeddingEngine...")
tokenizer = TikTokenTokenizer(model="gpt-4o", max_tokens=self.max_tokens) tokenizer = TikTokenTokenizer(
model="gpt-4o", max_completion_tokens=self.max_completion_tokens
)
logger.debug("Tokenizer loaded for for FastembedEmbeddingEngine") logger.debug("Tokenizer loaded for for FastembedEmbeddingEngine")
return tokenizer return tokenizer

View file

@ -57,7 +57,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str = None, api_key: str = None,
endpoint: str = None, endpoint: str = None,
api_version: str = None, api_version: str = None,
max_tokens: int = 512, max_completion_tokens: int = 512,
): ):
self.api_key = api_key self.api_key = api_key
self.endpoint = endpoint self.endpoint = endpoint
@ -65,7 +65,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.provider = provider self.provider = provider
self.model = model self.model = model
self.dimensions = dimensions self.dimensions = dimensions
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
self.tokenizer = self.get_tokenizer() self.tokenizer = self.get_tokenizer()
self.retry_count = 0 self.retry_count = 0
@ -179,20 +179,29 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
model = self.model.split("/")[-1] model = self.model.split("/")[-1]
if "openai" in self.provider.lower(): if "openai" in self.provider.lower():
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens) tokenizer = TikTokenTokenizer(
model=model, max_completion_tokens=self.max_completion_tokens
)
elif "gemini" in self.provider.lower(): elif "gemini" in self.provider.lower():
tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens) tokenizer = GeminiTokenizer(
model=model, max_completion_tokens=self.max_completion_tokens
)
elif "mistral" in self.provider.lower(): elif "mistral" in self.provider.lower():
tokenizer = MistralTokenizer(model=model, max_tokens=self.max_tokens) tokenizer = MistralTokenizer(
model=model, max_completion_tokens=self.max_completion_tokens
)
else: else:
try: try:
tokenizer = HuggingFaceTokenizer( tokenizer = HuggingFaceTokenizer(
model=self.model.replace("hosted_vllm/", ""), max_tokens=self.max_tokens model=self.model.replace("hosted_vllm/", ""),
max_completion_tokens=self.max_completion_tokens,
) )
except Exception as e: except Exception as e:
logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}") logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
logger.info("Switching to TikToken default tokenizer.") logger.info("Switching to TikToken default tokenizer.")
tokenizer = TikTokenTokenizer(model=None, max_tokens=self.max_tokens) tokenizer = TikTokenTokenizer(
model=None, max_completion_tokens=self.max_completion_tokens
)
logger.debug(f"Tokenizer loaded for model: {self.model}") logger.debug(f"Tokenizer loaded for model: {self.model}")
return tokenizer return tokenizer

View file

@ -30,7 +30,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
Instance variables: Instance variables:
- model - model
- dimensions - dimensions
- max_tokens - max_completion_tokens
- endpoint - endpoint
- mock - mock
- huggingface_tokenizer_name - huggingface_tokenizer_name
@ -39,7 +39,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
model: str model: str
dimensions: int dimensions: int
max_tokens: int max_completion_tokens: int
endpoint: str endpoint: str
mock: bool mock: bool
huggingface_tokenizer_name: str huggingface_tokenizer_name: str
@ -50,13 +50,13 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
self, self,
model: Optional[str] = "avr/sfr-embedding-mistral:latest", model: Optional[str] = "avr/sfr-embedding-mistral:latest",
dimensions: Optional[int] = 1024, dimensions: Optional[int] = 1024,
max_tokens: int = 512, max_completion_tokens: int = 512,
endpoint: Optional[str] = "http://localhost:11434/api/embeddings", endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral", huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
): ):
self.model = model self.model = model
self.dimensions = dimensions self.dimensions = dimensions
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
self.endpoint = endpoint self.endpoint = endpoint
self.huggingface_tokenizer_name = huggingface_tokenizer self.huggingface_tokenizer_name = huggingface_tokenizer
self.tokenizer = self.get_tokenizer() self.tokenizer = self.get_tokenizer()
@ -132,7 +132,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
""" """
logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...") logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...")
tokenizer = HuggingFaceTokenizer( tokenizer = HuggingFaceTokenizer(
model=self.huggingface_tokenizer_name, max_tokens=self.max_tokens model=self.huggingface_tokenizer_name, max_completion_tokens=self.max_completion_tokens
) )
logger.debug("Tokenizer loaded for OllamaEmbeddingEngine") logger.debug("Tokenizer loaded for OllamaEmbeddingEngine")
return tokenizer return tokenizer

View file

@ -18,7 +18,7 @@ class EmbeddingConfig(BaseSettings):
embedding_endpoint: Optional[str] = None embedding_endpoint: Optional[str] = None
embedding_api_key: Optional[str] = None embedding_api_key: Optional[str] = None
embedding_api_version: Optional[str] = None embedding_api_version: Optional[str] = None
embedding_max_tokens: Optional[int] = 8191 embedding_max_completion_tokens: Optional[int] = 8191
huggingface_tokenizer: Optional[str] = None huggingface_tokenizer: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow") model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -38,7 +38,7 @@ class EmbeddingConfig(BaseSettings):
"embedding_endpoint": self.embedding_endpoint, "embedding_endpoint": self.embedding_endpoint,
"embedding_api_key": self.embedding_api_key, "embedding_api_key": self.embedding_api_key,
"embedding_api_version": self.embedding_api_version, "embedding_api_version": self.embedding_api_version,
"embedding_max_tokens": self.embedding_max_tokens, "embedding_max_completion_tokens": self.embedding_max_completion_tokens,
"huggingface_tokenizer": self.huggingface_tokenizer, "huggingface_tokenizer": self.huggingface_tokenizer,
} }

View file

@ -27,7 +27,7 @@ def get_embedding_engine() -> EmbeddingEngine:
config.embedding_provider, config.embedding_provider,
config.embedding_model, config.embedding_model,
config.embedding_dimensions, config.embedding_dimensions,
config.embedding_max_tokens, config.embedding_max_completion_tokens,
config.embedding_endpoint, config.embedding_endpoint,
config.embedding_api_key, config.embedding_api_key,
config.embedding_api_version, config.embedding_api_version,
@ -41,7 +41,7 @@ def create_embedding_engine(
embedding_provider, embedding_provider,
embedding_model, embedding_model,
embedding_dimensions, embedding_dimensions,
embedding_max_tokens, embedding_max_completion_tokens,
embedding_endpoint, embedding_endpoint,
embedding_api_key, embedding_api_key,
embedding_api_version, embedding_api_version,
@ -58,7 +58,7 @@ def create_embedding_engine(
'ollama', or another supported provider. 'ollama', or another supported provider.
- embedding_model: The model to be used for the embedding engine. - embedding_model: The model to be used for the embedding engine.
- embedding_dimensions: The number of dimensions for the embeddings. - embedding_dimensions: The number of dimensions for the embeddings.
- embedding_max_tokens: The maximum number of tokens for the embeddings. - embedding_max_completion_tokens: The maximum number of tokens for the embeddings.
- embedding_endpoint: The endpoint for the embedding service, relevant for certain - embedding_endpoint: The endpoint for the embedding service, relevant for certain
providers. providers.
- embedding_api_key: API key to authenticate with the embedding service, if - embedding_api_key: API key to authenticate with the embedding service, if
@ -81,7 +81,7 @@ def create_embedding_engine(
return FastembedEmbeddingEngine( return FastembedEmbeddingEngine(
model=embedding_model, model=embedding_model,
dimensions=embedding_dimensions, dimensions=embedding_dimensions,
max_tokens=embedding_max_tokens, max_completion_tokens=embedding_max_completion_tokens,
) )
if embedding_provider == "ollama": if embedding_provider == "ollama":
@ -90,7 +90,7 @@ def create_embedding_engine(
return OllamaEmbeddingEngine( return OllamaEmbeddingEngine(
model=embedding_model, model=embedding_model,
dimensions=embedding_dimensions, dimensions=embedding_dimensions,
max_tokens=embedding_max_tokens, max_completion_tokens=embedding_max_completion_tokens,
endpoint=embedding_endpoint, endpoint=embedding_endpoint,
huggingface_tokenizer=huggingface_tokenizer, huggingface_tokenizer=huggingface_tokenizer,
) )
@ -104,5 +104,5 @@ def create_embedding_engine(
api_version=embedding_api_version, api_version=embedding_api_version,
model=embedding_model, model=embedding_model,
dimensions=embedding_dimensions, dimensions=embedding_dimensions,
max_tokens=embedding_max_tokens, max_completion_tokens=embedding_max_completion_tokens,
) )

View file

@ -18,7 +18,7 @@ class LLMConfig(BaseSettings):
- llm_api_version - llm_api_version
- llm_temperature - llm_temperature
- llm_streaming - llm_streaming
- llm_max_tokens - llm_max_completion_tokens
- transcription_model - transcription_model
- graph_prompt_path - graph_prompt_path
- llm_rate_limit_enabled - llm_rate_limit_enabled
@ -41,7 +41,7 @@ class LLMConfig(BaseSettings):
llm_api_version: Optional[str] = None llm_api_version: Optional[str] = None
llm_temperature: float = 0.0 llm_temperature: float = 0.0
llm_streaming: bool = False llm_streaming: bool = False
llm_max_tokens: int = 16384 llm_max_completion_tokens: int = 16384
baml_llm_provider: str = "openai" baml_llm_provider: str = "openai"
baml_llm_model: str = "gpt-5-mini" baml_llm_model: str = "gpt-5-mini"
@ -171,7 +171,7 @@ class LLMConfig(BaseSettings):
"api_version": self.llm_api_version, "api_version": self.llm_api_version,
"temperature": self.llm_temperature, "temperature": self.llm_temperature,
"streaming": self.llm_streaming, "streaming": self.llm_streaming,
"max_tokens": self.llm_max_tokens, "max_completion_tokens": self.llm_max_completion_tokens,
"transcription_model": self.transcription_model, "transcription_model": self.transcription_model,
"graph_prompt_path": self.graph_prompt_path, "graph_prompt_path": self.graph_prompt_path,
"rate_limit_enabled": self.llm_rate_limit_enabled, "rate_limit_enabled": self.llm_rate_limit_enabled,

View file

@ -23,7 +23,7 @@ class AnthropicAdapter(LLMInterface):
name = "Anthropic" name = "Anthropic"
model: str model: str
def __init__(self, max_tokens: int, model: str = None): def __init__(self, max_completion_tokens: int, model: str = None):
import anthropic import anthropic
self.aclient = instructor.patch( self.aclient = instructor.patch(
@ -31,7 +31,7 @@ class AnthropicAdapter(LLMInterface):
) )
self.model = model self.model = model
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
@sleep_and_retry_async() @sleep_and_retry_async()
@rate_limit_async @rate_limit_async
@ -57,7 +57,7 @@ class AnthropicAdapter(LLMInterface):
return await self.aclient( return await self.aclient(
model=self.model, model=self.model,
max_tokens=4096, max_completion_tokens=4096,
max_retries=5, max_retries=5,
messages=[ messages=[
{ {

View file

@ -34,7 +34,7 @@ class GeminiAdapter(LLMInterface):
self, self,
api_key: str, api_key: str,
model: str, model: str,
max_tokens: int, max_completion_tokens: int,
endpoint: Optional[str] = None, endpoint: Optional[str] = None,
api_version: Optional[str] = None, api_version: Optional[str] = None,
streaming: bool = False, streaming: bool = False,
@ -44,7 +44,7 @@ class GeminiAdapter(LLMInterface):
self.endpoint = endpoint self.endpoint = endpoint
self.api_version = api_version self.api_version = api_version
self.streaming = streaming self.streaming = streaming
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
@observe(as_type="generation") @observe(as_type="generation")
@sleep_and_retry_async() @sleep_and_retry_async()
@ -90,7 +90,7 @@ class GeminiAdapter(LLMInterface):
model=f"{self.model}", model=f"{self.model}",
messages=messages, messages=messages,
api_key=self.api_key, api_key=self.api_key,
max_tokens=self.max_tokens, max_completion_tokens=self.max_completion_tokens,
temperature=0.1, temperature=0.1,
response_format=response_schema, response_format=response_schema,
timeout=100, timeout=100,

View file

@ -41,7 +41,7 @@ class GenericAPIAdapter(LLMInterface):
api_key: str, api_key: str,
model: str, model: str,
name: str, name: str,
max_tokens: int, max_completion_tokens: int,
fallback_model: str = None, fallback_model: str = None,
fallback_api_key: str = None, fallback_api_key: str = None,
fallback_endpoint: str = None, fallback_endpoint: str = None,
@ -50,7 +50,7 @@ class GenericAPIAdapter(LLMInterface):
self.model = model self.model = model
self.api_key = api_key self.api_key = api_key
self.endpoint = endpoint self.endpoint = endpoint
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
self.fallback_model = fallback_model self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key self.fallback_api_key = fallback_api_key

View file

@ -54,11 +54,15 @@ def get_llm_client():
# Check if max_token value is defined in liteLLM for given model # Check if max_token value is defined in liteLLM for given model
# if not use value from cognee configuration # if not use value from cognee configuration
from cognee.infrastructure.llm.utils import ( from cognee.infrastructure.llm.utils import (
get_model_max_tokens, get_model_max_completion_tokens,
) # imported here to avoid circular imports ) # imported here to avoid circular imports
model_max_tokens = get_model_max_tokens(llm_config.llm_model) model_max_completion_tokens = get_model_max_completion_tokens(llm_config.llm_model)
max_tokens = model_max_tokens if model_max_tokens else llm_config.llm_max_tokens max_completion_tokens = (
model_max_completion_tokens
if model_max_completion_tokens
else llm_config.llm_max_completion_tokens
)
if provider == LLMProvider.OPENAI: if provider == LLMProvider.OPENAI:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
@ -74,7 +78,7 @@ def get_llm_client():
api_version=llm_config.llm_api_version, api_version=llm_config.llm_api_version,
model=llm_config.llm_model, model=llm_config.llm_model,
transcription_model=llm_config.transcription_model, transcription_model=llm_config.transcription_model,
max_tokens=max_tokens, max_completion_tokens=max_completion_tokens,
streaming=llm_config.llm_streaming, streaming=llm_config.llm_streaming,
fallback_api_key=llm_config.fallback_api_key, fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint, fallback_endpoint=llm_config.fallback_endpoint,
@ -94,7 +98,7 @@ def get_llm_client():
llm_config.llm_api_key, llm_config.llm_api_key,
llm_config.llm_model, llm_config.llm_model,
"Ollama", "Ollama",
max_tokens=max_tokens, max_completion_tokens=max_completion_tokens,
) )
elif provider == LLMProvider.ANTHROPIC: elif provider == LLMProvider.ANTHROPIC:
@ -102,7 +106,9 @@ def get_llm_client():
AnthropicAdapter, AnthropicAdapter,
) )
return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model) return AnthropicAdapter(
max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
)
elif provider == LLMProvider.CUSTOM: elif provider == LLMProvider.CUSTOM:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
@ -117,7 +123,7 @@ def get_llm_client():
llm_config.llm_api_key, llm_config.llm_api_key,
llm_config.llm_model, llm_config.llm_model,
"Custom", "Custom",
max_tokens=max_tokens, max_completion_tokens=max_completion_tokens,
fallback_api_key=llm_config.fallback_api_key, fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint, fallback_endpoint=llm_config.fallback_endpoint,
fallback_model=llm_config.fallback_model, fallback_model=llm_config.fallback_model,
@ -134,7 +140,7 @@ def get_llm_client():
return GeminiAdapter( return GeminiAdapter(
api_key=llm_config.llm_api_key, api_key=llm_config.llm_api_key,
model=llm_config.llm_model, model=llm_config.llm_model,
max_tokens=max_tokens, max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint, endpoint=llm_config.llm_endpoint,
api_version=llm_config.llm_api_version, api_version=llm_config.llm_api_version,
streaming=llm_config.llm_streaming, streaming=llm_config.llm_streaming,

View file

@ -30,16 +30,18 @@ class OllamaAPIAdapter(LLMInterface):
- model - model
- api_key - api_key
- endpoint - endpoint
- max_tokens - max_completion_tokens
- aclient - aclient
""" """
def __init__(self, endpoint: str, api_key: str, model: str, name: str, max_tokens: int): def __init__(
self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
):
self.name = name self.name = name
self.model = model self.model = model
self.api_key = api_key self.api_key = api_key
self.endpoint = endpoint self.endpoint = endpoint
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
self.aclient = instructor.from_openai( self.aclient = instructor.from_openai(
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
@ -159,7 +161,7 @@ class OllamaAPIAdapter(LLMInterface):
], ],
} }
], ],
max_tokens=300, max_completion_tokens=300,
) )
# Ensure response is valid before accessing .choices[0].message.content # Ensure response is valid before accessing .choices[0].message.content

View file

@ -64,7 +64,7 @@ class OpenAIAdapter(LLMInterface):
api_version: str, api_version: str,
model: str, model: str,
transcription_model: str, transcription_model: str,
max_tokens: int, max_completion_tokens: int,
streaming: bool = False, streaming: bool = False,
fallback_model: str = None, fallback_model: str = None,
fallback_api_key: str = None, fallback_api_key: str = None,
@ -77,7 +77,7 @@ class OpenAIAdapter(LLMInterface):
self.api_key = api_key self.api_key = api_key
self.endpoint = endpoint self.endpoint = endpoint
self.api_version = api_version self.api_version = api_version
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
self.streaming = streaming self.streaming = streaming
self.fallback_model = fallback_model self.fallback_model = fallback_model
@ -301,7 +301,7 @@ class OpenAIAdapter(LLMInterface):
api_key=self.api_key, api_key=self.api_key,
api_base=self.endpoint, api_base=self.endpoint,
api_version=self.api_version, api_version=self.api_version,
max_tokens=300, max_completion_tokens=300,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
) )

View file

@ -17,10 +17,10 @@ class GeminiTokenizer(TokenizerInterface):
def __init__( def __init__(
self, self,
model: str, model: str,
max_tokens: int = 3072, max_completion_tokens: int = 3072,
): ):
self.model = model self.model = model
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
# Get LLM API key from config # Get LLM API key from config
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config

View file

@ -14,17 +14,17 @@ class HuggingFaceTokenizer(TokenizerInterface):
Instance variables include: Instance variables include:
- model: str - model: str
- max_tokens: int - max_completion_tokens: int
- tokenizer: AutoTokenizer - tokenizer: AutoTokenizer
""" """
def __init__( def __init__(
self, self,
model: str, model: str,
max_tokens: int = 512, max_completion_tokens: int = 512,
): ):
self.model = model self.model = model
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
# Import here to make it an optional dependency # Import here to make it an optional dependency
from transformers import AutoTokenizer from transformers import AutoTokenizer

View file

@ -16,17 +16,17 @@ class MistralTokenizer(TokenizerInterface):
Instance variables include: Instance variables include:
- model: str - model: str
- max_tokens: int - max_completion_tokens: int
- tokenizer: MistralTokenizer - tokenizer: MistralTokenizer
""" """
def __init__( def __init__(
self, self,
model: str, model: str,
max_tokens: int = 3072, max_completion_tokens: int = 3072,
): ):
self.model = model self.model = model
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
# Import here to make it an optional dependency # Import here to make it an optional dependency
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

View file

@ -13,10 +13,10 @@ class TikTokenTokenizer(TokenizerInterface):
def __init__( def __init__(
self, self,
model: Optional[str] = None, model: Optional[str] = None,
max_tokens: int = 8191, max_completion_tokens: int = 8191,
): ):
self.model = model self.model = model
self.max_tokens = max_tokens self.max_completion_tokens = max_completion_tokens
# Initialize TikToken for GPT based on model # Initialize TikToken for GPT based on model
if model: if model:
self.tokenizer = tiktoken.encoding_for_model(self.model) self.tokenizer = tiktoken.encoding_for_model(self.model)
@ -93,9 +93,9 @@ class TikTokenTokenizer(TokenizerInterface):
num_tokens = len(self.tokenizer.encode(text)) num_tokens = len(self.tokenizer.encode(text))
return num_tokens return num_tokens
def trim_text_to_max_tokens(self, text: str) -> str: def trim_text_to_max_completion_tokens(self, text: str) -> str:
""" """
Trim the text so that the number of tokens does not exceed max_tokens. Trim the text so that the number of tokens does not exceed max_completion_tokens.
Parameters: Parameters:
----------- -----------
@ -111,13 +111,13 @@ class TikTokenTokenizer(TokenizerInterface):
num_tokens = self.count_tokens(text) num_tokens = self.count_tokens(text)
# If the number of tokens is within the limit, return the text as is # If the number of tokens is within the limit, return the text as is
if num_tokens <= self.max_tokens: if num_tokens <= self.max_completion_tokens:
return text return text
# If the number exceeds the limit, trim the text # If the number exceeds the limit, trim the text
# This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut # This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
encoded_text = self.tokenizer.encode(text) encoded_text = self.tokenizer.encode(text)
trimmed_encoded_text = encoded_text[: self.max_tokens] trimmed_encoded_text = encoded_text[: self.max_completion_tokens]
# Decoding the trimmed text # Decoding the trimmed text
trimmed_text = self.tokenizer.decode(trimmed_encoded_text) trimmed_text = self.tokenizer.decode(trimmed_encoded_text)
return trimmed_text return trimmed_text

View file

@ -32,13 +32,13 @@ def get_max_chunk_tokens():
# We need to make sure chunk size won't take more than half of LLM max context token size # We need to make sure chunk size won't take more than half of LLM max context token size
# but it also can't be bigger than the embedding engine max token size # but it also can't be bigger than the embedding engine max token size
llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division llm_cutoff_point = llm_client.max_completion_tokens // 2 # Round down the division
max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point) max_chunk_tokens = min(embedding_engine.max_completion_tokens, llm_cutoff_point)
return max_chunk_tokens return max_chunk_tokens
def get_model_max_tokens(model_name: str): def get_model_max_completion_tokens(model_name: str):
""" """
Retrieve the maximum token limit for a specified model name if it exists. Retrieve the maximum token limit for a specified model name if it exists.
@ -56,15 +56,15 @@ def get_model_max_tokens(model_name: str):
Number of max tokens of model, or None if model is unknown Number of max tokens of model, or None if model is unknown
""" """
max_tokens = None max_completion_tokens = None
if model_name in litellm.model_cost: if model_name in litellm.model_cost:
max_tokens = litellm.model_cost[model_name]["max_tokens"] max_completion_tokens = litellm.model_cost[model_name]["max_tokens"]
logger.debug(f"Max input tokens for {model_name}: {max_tokens}") logger.debug(f"Max input tokens for {model_name}: {max_completion_tokens}")
else: else:
logger.info("Model not found in LiteLLM's model_cost.") logger.info("Model not found in LiteLLM's model_cost.")
return max_tokens return max_completion_tokens
async def test_llm_connection(): async def test_llm_connection():

View file

@ -43,7 +43,7 @@ class QABenchmarkGraphiti(QABenchmarkRAG):
async def initialize_rag(self) -> Any: async def initialize_rag(self) -> Any:
"""Initialize Graphiti and LLM.""" """Initialize Graphiti and LLM."""
llm_config = LLMConfig(model=self.config.model_name, max_tokens=65536) llm_config = LLMConfig(model=self.config.model_name, max_completion_tokens=65536)
llm_client = OpenAIClient(config=llm_config) llm_client = OpenAIClient(config=llm_config)
graphiti = Graphiti( graphiti = Graphiti(
self.config.db_url, self.config.db_url,