renamed max tokens
parent c4ec6799a6
commit 1bd40f1401
20 changed files with 92 additions and 73 deletions

@@ -91,7 +91,7 @@ async def cognify(
      - LangchainChunker: Recursive character splitting with overlap
      Determines how documents are segmented for processing.
  chunk_size: Maximum tokens per chunk. Auto-calculated based on LLM if None.
-     Formula: min(embedding_max_tokens, llm_max_tokens // 2)
+     Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
      Default limits: ~512-8192 tokens depending on models.
      Smaller chunks = more granular but potentially fragmented knowledge.
  ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types.

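The renamed formula above mirrors the get_max_chunk_tokens helper changed later in this commit. A standalone sketch of the calculation, using the default limits visible in this diff (8191 embedding tokens, 16384 LLM completion tokens); the function name here is illustrative only, not part of the codebase:

def sketch_max_chunk_tokens(embedding_max_completion_tokens: int = 8191,
                            llm_max_completion_tokens: int = 16384) -> int:
    # Chunk size may not exceed the embedding engine's limit and must stay
    # within half of the LLM context window (floor division rounds down).
    llm_cutoff_point = llm_max_completion_tokens // 2
    return min(embedding_max_completion_tokens, llm_cutoff_point)


print(sketch_max_chunk_tokens())  # 8191 with the defaults above
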
@@ -70,7 +70,7 @@ class ResponseRequest(InDTO):
  tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
  user: Optional[str] = None
  temperature: Optional[float] = 1.0
- max_tokens: Optional[int] = None
+ max_completion_tokens: Optional[int] = None


  class ToolCallOutput(BaseModel):

@@ -41,11 +41,11 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
  self,
  model: Optional[str] = "openai/text-embedding-3-large",
  dimensions: Optional[int] = 3072,
- max_tokens: int = 512,
+ max_completion_tokens: int = 512,
  ):
  self.model = model
  self.dimensions = dimensions
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens
  self.tokenizer = self.get_tokenizer()
  # self.retry_count = 0
  self.embedding_model = TextEmbedding(model_name=model)

@@ -112,7 +112,9 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
  """
  logger.debug("Loading tokenizer for FastembedEmbeddingEngine...")

- tokenizer = TikTokenTokenizer(model="gpt-4o", max_tokens=self.max_tokens)
+ tokenizer = TikTokenTokenizer(
+     model="gpt-4o", max_completion_tokens=self.max_completion_tokens
+ )

  logger.debug("Tokenizer loaded for for FastembedEmbeddingEngine")
  return tokenizer

@@ -57,7 +57,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
  api_key: str = None,
  endpoint: str = None,
  api_version: str = None,
- max_tokens: int = 512,
+ max_completion_tokens: int = 512,
  ):
  self.api_key = api_key
  self.endpoint = endpoint

@@ -65,7 +65,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
  self.provider = provider
  self.model = model
  self.dimensions = dimensions
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens
  self.tokenizer = self.get_tokenizer()
  self.retry_count = 0

@@ -179,20 +179,29 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
  model = self.model.split("/")[-1]

  if "openai" in self.provider.lower():
- tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
+ tokenizer = TikTokenTokenizer(
+     model=model, max_completion_tokens=self.max_completion_tokens
+ )
  elif "gemini" in self.provider.lower():
- tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
+ tokenizer = GeminiTokenizer(
+     model=model, max_completion_tokens=self.max_completion_tokens
+ )
  elif "mistral" in self.provider.lower():
- tokenizer = MistralTokenizer(model=model, max_tokens=self.max_tokens)
+ tokenizer = MistralTokenizer(
+     model=model, max_completion_tokens=self.max_completion_tokens
+ )
  else:
  try:
  tokenizer = HuggingFaceTokenizer(
- model=self.model.replace("hosted_vllm/", ""), max_tokens=self.max_tokens
+ model=self.model.replace("hosted_vllm/", ""),
+ max_completion_tokens=self.max_completion_tokens,
  )
  except Exception as e:
  logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
  logger.info("Switching to TikToken default tokenizer.")
- tokenizer = TikTokenTokenizer(model=None, max_tokens=self.max_tokens)
+ tokenizer = TikTokenTokenizer(
+     model=None, max_completion_tokens=self.max_completion_tokens
+ )

  logger.debug(f"Tokenizer loaded for model: {self.model}")
  return tokenizer

@@ -30,7 +30,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
  Instance variables:
  - model
  - dimensions
- - max_tokens
+ - max_completion_tokens
  - endpoint
  - mock
  - huggingface_tokenizer_name

@@ -39,7 +39,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):

  model: str
  dimensions: int
- max_tokens: int
+ max_completion_tokens: int
  endpoint: str
  mock: bool
  huggingface_tokenizer_name: str

@@ -50,13 +50,13 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
  self,
  model: Optional[str] = "avr/sfr-embedding-mistral:latest",
  dimensions: Optional[int] = 1024,
- max_tokens: int = 512,
+ max_completion_tokens: int = 512,
  endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
  huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
  ):
  self.model = model
  self.dimensions = dimensions
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens
  self.endpoint = endpoint
  self.huggingface_tokenizer_name = huggingface_tokenizer
  self.tokenizer = self.get_tokenizer()

@@ -132,7 +132,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
  """
  logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...")
  tokenizer = HuggingFaceTokenizer(
- model=self.huggingface_tokenizer_name, max_tokens=self.max_tokens
+ model=self.huggingface_tokenizer_name, max_completion_tokens=self.max_completion_tokens
  )
  logger.debug("Tokenizer loaded for OllamaEmbeddingEngine")
  return tokenizer

@@ -18,7 +18,7 @@ class EmbeddingConfig(BaseSettings):
  embedding_endpoint: Optional[str] = None
  embedding_api_key: Optional[str] = None
  embedding_api_version: Optional[str] = None
- embedding_max_tokens: Optional[int] = 8191
+ embedding_max_completion_tokens: Optional[int] = 8191
  huggingface_tokenizer: Optional[str] = None
  model_config = SettingsConfigDict(env_file=".env", extra="allow")

@@ -38,7 +38,7 @@ class EmbeddingConfig(BaseSettings):
  "embedding_endpoint": self.embedding_endpoint,
  "embedding_api_key": self.embedding_api_key,
  "embedding_api_version": self.embedding_api_version,
- "embedding_max_tokens": self.embedding_max_tokens,
+ "embedding_max_completion_tokens": self.embedding_max_completion_tokens,
  "huggingface_tokenizer": self.huggingface_tokenizer,
  }

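EmbeddingConfig is a pydantic BaseSettings class, so the rename changes which attribute callers read. A hedged usage sketch, assuming cognee is installed and configured; the get_embedding_config import path is the one that appears later in this diff:

from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config

config = get_embedding_config()
# After this commit the field is embedding_max_completion_tokens (default 8191).
print(config.embedding_max_completion_tokens)
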
@@ -27,7 +27,7 @@ def get_embedding_engine() -> EmbeddingEngine:
  config.embedding_provider,
  config.embedding_model,
  config.embedding_dimensions,
- config.embedding_max_tokens,
+ config.embedding_max_completion_tokens,
  config.embedding_endpoint,
  config.embedding_api_key,
  config.embedding_api_version,

@@ -41,7 +41,7 @@ def create_embedding_engine(
  embedding_provider,
  embedding_model,
  embedding_dimensions,
- embedding_max_tokens,
+ embedding_max_completion_tokens,
  embedding_endpoint,
  embedding_api_key,
  embedding_api_version,

@@ -58,7 +58,7 @@ def create_embedding_engine(
  'ollama', or another supported provider.
  - embedding_model: The model to be used for the embedding engine.
  - embedding_dimensions: The number of dimensions for the embeddings.
- - embedding_max_tokens: The maximum number of tokens for the embeddings.
+ - embedding_max_completion_tokens: The maximum number of tokens for the embeddings.
  - embedding_endpoint: The endpoint for the embedding service, relevant for certain
    providers.
  - embedding_api_key: API key to authenticate with the embedding service, if

@@ -81,7 +81,7 @@ def create_embedding_engine(
  return FastembedEmbeddingEngine(
  model=embedding_model,
  dimensions=embedding_dimensions,
- max_tokens=embedding_max_tokens,
+ max_completion_tokens=embedding_max_completion_tokens,
  )

  if embedding_provider == "ollama":

@@ -90,7 +90,7 @@ def create_embedding_engine(
  return OllamaEmbeddingEngine(
  model=embedding_model,
  dimensions=embedding_dimensions,
- max_tokens=embedding_max_tokens,
+ max_completion_tokens=embedding_max_completion_tokens,
  endpoint=embedding_endpoint,
  huggingface_tokenizer=huggingface_tokenizer,
  )

@@ -104,5 +104,5 @@ def create_embedding_engine(
  api_version=embedding_api_version,
  model=embedding_model,
  dimensions=embedding_dimensions,
- max_tokens=embedding_max_tokens,
+ max_completion_tokens=embedding_max_completion_tokens,
  )

@@ -18,7 +18,7 @@ class LLMConfig(BaseSettings):
  - llm_api_version
  - llm_temperature
  - llm_streaming
- - llm_max_tokens
+ - llm_max_completion_tokens
  - transcription_model
  - graph_prompt_path
  - llm_rate_limit_enabled

@@ -41,7 +41,7 @@ class LLMConfig(BaseSettings):
  llm_api_version: Optional[str] = None
  llm_temperature: float = 0.0
  llm_streaming: bool = False
- llm_max_tokens: int = 16384
+ llm_max_completion_tokens: int = 16384

  baml_llm_provider: str = "openai"
  baml_llm_model: str = "gpt-5-mini"

@@ -171,7 +171,7 @@ class LLMConfig(BaseSettings):
  "api_version": self.llm_api_version,
  "temperature": self.llm_temperature,
  "streaming": self.llm_streaming,
- "max_tokens": self.llm_max_tokens,
+ "max_completion_tokens": self.llm_max_completion_tokens,
  "transcription_model": self.transcription_model,
  "graph_prompt_path": self.graph_prompt_path,
  "rate_limit_enabled": self.llm_rate_limit_enabled,

@@ -23,7 +23,7 @@ class AnthropicAdapter(LLMInterface):
  name = "Anthropic"
  model: str

- def __init__(self, max_tokens: int, model: str = None):
+ def __init__(self, max_completion_tokens: int, model: str = None):
  import anthropic

  self.aclient = instructor.patch(

@@ -31,7 +31,7 @@ class AnthropicAdapter(LLMInterface):
  )

  self.model = model
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens

  @sleep_and_retry_async()
  @rate_limit_async

@@ -57,7 +57,7 @@ class AnthropicAdapter(LLMInterface):

  return await self.aclient(
  model=self.model,
- max_tokens=4096,
+ max_completion_tokens=4096,
  max_retries=5,
  messages=[
  {

@@ -34,7 +34,7 @@ class GeminiAdapter(LLMInterface):
  self,
  api_key: str,
  model: str,
- max_tokens: int,
+ max_completion_tokens: int,
  endpoint: Optional[str] = None,
  api_version: Optional[str] = None,
  streaming: bool = False,

@@ -44,7 +44,7 @@ class GeminiAdapter(LLMInterface):
  self.endpoint = endpoint
  self.api_version = api_version
  self.streaming = streaming
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens

  @observe(as_type="generation")
  @sleep_and_retry_async()

@@ -90,7 +90,7 @@ class GeminiAdapter(LLMInterface):
  model=f"{self.model}",
  messages=messages,
  api_key=self.api_key,
- max_tokens=self.max_tokens,
+ max_completion_tokens=self.max_completion_tokens,
  temperature=0.1,
  response_format=response_schema,
  timeout=100,

@@ -41,7 +41,7 @@ class GenericAPIAdapter(LLMInterface):
  api_key: str,
  model: str,
  name: str,
- max_tokens: int,
+ max_completion_tokens: int,
  fallback_model: str = None,
  fallback_api_key: str = None,
  fallback_endpoint: str = None,

@@ -50,7 +50,7 @@ class GenericAPIAdapter(LLMInterface):
  self.model = model
  self.api_key = api_key
  self.endpoint = endpoint
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens

  self.fallback_model = fallback_model
  self.fallback_api_key = fallback_api_key

@@ -54,11 +54,15 @@ def get_llm_client():
  # Check if max_token value is defined in liteLLM for given model
  # if not use value from cognee configuration
  from cognee.infrastructure.llm.utils import (
- get_model_max_tokens,
+ get_model_max_completion_tokens,
  ) # imported here to avoid circular imports

- model_max_tokens = get_model_max_tokens(llm_config.llm_model)
- max_tokens = model_max_tokens if model_max_tokens else llm_config.llm_max_tokens
+ model_max_completion_tokens = get_model_max_completion_tokens(llm_config.llm_model)
+ max_completion_tokens = (
+     model_max_completion_tokens
+     if model_max_completion_tokens
+     else llm_config.llm_max_completion_tokens
+ )

  if provider == LLMProvider.OPENAI:
  if llm_config.llm_api_key is None:

@@ -74,7 +78,7 @@ def get_llm_client():
  api_version=llm_config.llm_api_version,
  model=llm_config.llm_model,
  transcription_model=llm_config.transcription_model,
- max_tokens=max_tokens,
+ max_completion_tokens=max_completion_tokens,
  streaming=llm_config.llm_streaming,
  fallback_api_key=llm_config.fallback_api_key,
  fallback_endpoint=llm_config.fallback_endpoint,

@@ -94,7 +98,7 @@ def get_llm_client():
  llm_config.llm_api_key,
  llm_config.llm_model,
  "Ollama",
- max_tokens=max_tokens,
+ max_completion_tokens=max_completion_tokens,
  )

  elif provider == LLMProvider.ANTHROPIC:

@@ -102,7 +106,9 @@ def get_llm_client():
  AnthropicAdapter,
  )

- return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
+ return AnthropicAdapter(
+     max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
+ )

  elif provider == LLMProvider.CUSTOM:
  if llm_config.llm_api_key is None:

@@ -117,7 +123,7 @@ def get_llm_client():
  llm_config.llm_api_key,
  llm_config.llm_model,
  "Custom",
- max_tokens=max_tokens,
+ max_completion_tokens=max_completion_tokens,
  fallback_api_key=llm_config.fallback_api_key,
  fallback_endpoint=llm_config.fallback_endpoint,
  fallback_model=llm_config.fallback_model,

@@ -134,7 +140,7 @@ def get_llm_client():
  return GeminiAdapter(
  api_key=llm_config.llm_api_key,
  model=llm_config.llm_model,
- max_tokens=max_tokens,
+ max_completion_tokens=max_completion_tokens,
  endpoint=llm_config.llm_endpoint,
  api_version=llm_config.llm_api_version,
  streaming=llm_config.llm_streaming,

@@ -30,16 +30,18 @@ class OllamaAPIAdapter(LLMInterface):
  - model
  - api_key
  - endpoint
- - max_tokens
+ - max_completion_tokens
  - aclient
  """

- def __init__(self, endpoint: str, api_key: str, model: str, name: str, max_tokens: int):
+ def __init__(
+     self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
+ ):
  self.name = name
  self.model = model
  self.api_key = api_key
  self.endpoint = endpoint
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens

  self.aclient = instructor.from_openai(
  OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON

@@ -159,7 +161,7 @@ class OllamaAPIAdapter(LLMInterface):
  ],
  }
  ],
- max_tokens=300,
+ max_completion_tokens=300,
  )

  # Ensure response is valid before accessing .choices[0].message.content

@@ -64,7 +64,7 @@ class OpenAIAdapter(LLMInterface):
  api_version: str,
  model: str,
  transcription_model: str,
- max_tokens: int,
+ max_completion_tokens: int,
  streaming: bool = False,
  fallback_model: str = None,
  fallback_api_key: str = None,

@@ -77,7 +77,7 @@ class OpenAIAdapter(LLMInterface):
  self.api_key = api_key
  self.endpoint = endpoint
  self.api_version = api_version
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens
  self.streaming = streaming

  self.fallback_model = fallback_model

@@ -301,7 +301,7 @@ class OpenAIAdapter(LLMInterface):
  api_key=self.api_key,
  api_base=self.endpoint,
  api_version=self.api_version,
- max_tokens=300,
+ max_completion_tokens=300,
  max_retries=self.MAX_RETRIES,
  )

@@ -17,10 +17,10 @@ class GeminiTokenizer(TokenizerInterface):
  def __init__(
  self,
  model: str,
- max_tokens: int = 3072,
+ max_completion_tokens: int = 3072,
  ):
  self.model = model
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens

  # Get LLM API key from config
  from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config

@@ -14,17 +14,17 @@ class HuggingFaceTokenizer(TokenizerInterface):

  Instance variables include:
  - model: str
- - max_tokens: int
+ - max_completion_tokens: int
  - tokenizer: AutoTokenizer
  """

  def __init__(
  self,
  model: str,
- max_tokens: int = 512,
+ max_completion_tokens: int = 512,
  ):
  self.model = model
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens

  # Import here to make it an optional dependency
  from transformers import AutoTokenizer

@@ -16,17 +16,17 @@ class MistralTokenizer(TokenizerInterface):

  Instance variables include:
  - model: str
- - max_tokens: int
+ - max_completion_tokens: int
  - tokenizer: MistralTokenizer
  """

  def __init__(
  self,
  model: str,
- max_tokens: int = 3072,
+ max_completion_tokens: int = 3072,
  ):
  self.model = model
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens

  # Import here to make it an optional dependency
  from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

@@ -13,10 +13,10 @@ class TikTokenTokenizer(TokenizerInterface):
  def __init__(
  self,
  model: Optional[str] = None,
- max_tokens: int = 8191,
+ max_completion_tokens: int = 8191,
  ):
  self.model = model
- self.max_tokens = max_tokens
+ self.max_completion_tokens = max_completion_tokens
  # Initialize TikToken for GPT based on model
  if model:
  self.tokenizer = tiktoken.encoding_for_model(self.model)

@@ -93,9 +93,9 @@ class TikTokenTokenizer(TokenizerInterface):
  num_tokens = len(self.tokenizer.encode(text))
  return num_tokens

- def trim_text_to_max_tokens(self, text: str) -> str:
+ def trim_text_to_max_completion_tokens(self, text: str) -> str:
  """
- Trim the text so that the number of tokens does not exceed max_tokens.
+ Trim the text so that the number of tokens does not exceed max_completion_tokens.

  Parameters:
  -----------

@@ -111,13 +111,13 @@ class TikTokenTokenizer(TokenizerInterface):
  num_tokens = self.count_tokens(text)

  # If the number of tokens is within the limit, return the text as is
- if num_tokens <= self.max_tokens:
+ if num_tokens <= self.max_completion_tokens:
  return text

  # If the number exceeds the limit, trim the text
  # This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
  encoded_text = self.tokenizer.encode(text)
- trimmed_encoded_text = encoded_text[: self.max_tokens]
+ trimmed_encoded_text = encoded_text[: self.max_completion_tokens]
  # Decoding the trimmed text
  trimmed_text = self.tokenizer.decode(trimmed_encoded_text)
  return trimmed_text

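The renamed trim_text_to_max_completion_tokens keeps the same behavior; only the attribute and method names change. A standalone re-implementation of that logic using tiktoken directly, as a sketch rather than the cognee class itself:

import tiktoken


def trim_to_max_completion_tokens(text: str, max_completion_tokens: int = 8191,
                                  model: str = "gpt-4o") -> str:
    # Mirrors the hunk above: encode, keep at most max_completion_tokens tokens, decode.
    tokenizer = tiktoken.encoding_for_model(model)
    encoded_text = tokenizer.encode(text)
    if len(encoded_text) <= max_completion_tokens:
        return text
    return tokenizer.decode(encoded_text[:max_completion_tokens])
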
@@ -32,13 +32,13 @@ def get_max_chunk_tokens():

  # We need to make sure chunk size won't take more than half of LLM max context token size
  # but it also can't be bigger than the embedding engine max token size
- llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division
- max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
+ llm_cutoff_point = llm_client.max_completion_tokens // 2 # Round down the division
+ max_chunk_tokens = min(embedding_engine.max_completion_tokens, llm_cutoff_point)

  return max_chunk_tokens


- def get_model_max_tokens(model_name: str):
+ def get_model_max_completion_tokens(model_name: str):
  """
  Retrieve the maximum token limit for a specified model name if it exists.

@@ -56,15 +56,15 @@ def get_model_max_tokens(model_name: str):

  Number of max tokens of model, or None if model is unknown
  """
- max_tokens = None
+ max_completion_tokens = None

  if model_name in litellm.model_cost:
- max_tokens = litellm.model_cost[model_name]["max_tokens"]
- logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
+ max_completion_tokens = litellm.model_cost[model_name]["max_tokens"]
+ logger.debug(f"Max input tokens for {model_name}: {max_completion_tokens}")
  else:
  logger.info("Model not found in LiteLLM's model_cost.")

- return max_tokens
+ return max_completion_tokens


  async def test_llm_connection():

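A hedged usage sketch of the renamed helper above; it assumes cognee and litellm are installed and uses the import path shown earlier in this diff:

from cognee.infrastructure.llm.utils import get_model_max_completion_tokens

limit = get_model_max_completion_tokens("gpt-4o")
# Returns litellm's known limit for the model, or None for unknown models,
# in which case get_llm_client falls back to llm_config.llm_max_completion_tokens.
print(limit)
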
@@ -43,7 +43,7 @@ class QABenchmarkGraphiti(QABenchmarkRAG):

  async def initialize_rag(self) -> Any:
  """Initialize Graphiti and LLM."""
- llm_config = LLMConfig(model=self.config.model_name, max_tokens=65536)
+ llm_config = LLMConfig(model=self.config.model_name, max_completion_tokens=65536)
  llm_client = OpenAIClient(config=llm_config)
  graphiti = Graphiti(
  self.config.db_url,