renamed max tokens
This commit is contained in:
parent
c4ec6799a6
commit
1bd40f1401
20 changed files with 92 additions and 73 deletions
|
|
@ -91,7 +91,7 @@ async def cognify(
|
||||||
- LangchainChunker: Recursive character splitting with overlap
|
- LangchainChunker: Recursive character splitting with overlap
|
||||||
Determines how documents are segmented for processing.
|
Determines how documents are segmented for processing.
|
||||||
chunk_size: Maximum tokens per chunk. Auto-calculated based on LLM if None.
|
chunk_size: Maximum tokens per chunk. Auto-calculated based on LLM if None.
|
||||||
Formula: min(embedding_max_tokens, llm_max_tokens // 2)
|
Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
|
||||||
Default limits: ~512-8192 tokens depending on models.
|
Default limits: ~512-8192 tokens depending on models.
|
||||||
Smaller chunks = more granular but potentially fragmented knowledge.
|
Smaller chunks = more granular but potentially fragmented knowledge.
|
||||||
ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types.
|
ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types.
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ class ResponseRequest(InDTO):
|
||||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
|
||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
temperature: Optional[float] = 1.0
|
temperature: Optional[float] = 1.0
|
||||||
max_tokens: Optional[int] = None
|
max_completion_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class ToolCallOutput(BaseModel):
|
class ToolCallOutput(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -41,11 +41,11 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
||||||
self,
|
self,
|
||||||
model: Optional[str] = "openai/text-embedding-3-large",
|
model: Optional[str] = "openai/text-embedding-3-large",
|
||||||
dimensions: Optional[int] = 3072,
|
dimensions: Optional[int] = 3072,
|
||||||
max_tokens: int = 512,
|
max_completion_tokens: int = 512,
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
self.tokenizer = self.get_tokenizer()
|
self.tokenizer = self.get_tokenizer()
|
||||||
# self.retry_count = 0
|
# self.retry_count = 0
|
||||||
self.embedding_model = TextEmbedding(model_name=model)
|
self.embedding_model = TextEmbedding(model_name=model)
|
||||||
|
|
@ -112,7 +112,9 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
||||||
"""
|
"""
|
||||||
logger.debug("Loading tokenizer for FastembedEmbeddingEngine...")
|
logger.debug("Loading tokenizer for FastembedEmbeddingEngine...")
|
||||||
|
|
||||||
tokenizer = TikTokenTokenizer(model="gpt-4o", max_tokens=self.max_tokens)
|
tokenizer = TikTokenTokenizer(
|
||||||
|
model="gpt-4o", max_completion_tokens=self.max_completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug("Tokenizer loaded for for FastembedEmbeddingEngine")
|
logger.debug("Tokenizer loaded for for FastembedEmbeddingEngine")
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
endpoint: str = None,
|
endpoint: str = None,
|
||||||
api_version: str = None,
|
api_version: str = None,
|
||||||
max_tokens: int = 512,
|
max_completion_tokens: int = 512,
|
||||||
):
|
):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
|
|
@ -65,7 +65,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.model = model
|
self.model = model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
self.tokenizer = self.get_tokenizer()
|
self.tokenizer = self.get_tokenizer()
|
||||||
self.retry_count = 0
|
self.retry_count = 0
|
||||||
|
|
||||||
|
|
@ -179,20 +179,29 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
model = self.model.split("/")[-1]
|
model = self.model.split("/")[-1]
|
||||||
|
|
||||||
if "openai" in self.provider.lower():
|
if "openai" in self.provider.lower():
|
||||||
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
|
tokenizer = TikTokenTokenizer(
|
||||||
|
model=model, max_completion_tokens=self.max_completion_tokens
|
||||||
|
)
|
||||||
elif "gemini" in self.provider.lower():
|
elif "gemini" in self.provider.lower():
|
||||||
tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
|
tokenizer = GeminiTokenizer(
|
||||||
|
model=model, max_completion_tokens=self.max_completion_tokens
|
||||||
|
)
|
||||||
elif "mistral" in self.provider.lower():
|
elif "mistral" in self.provider.lower():
|
||||||
tokenizer = MistralTokenizer(model=model, max_tokens=self.max_tokens)
|
tokenizer = MistralTokenizer(
|
||||||
|
model=model, max_completion_tokens=self.max_completion_tokens
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
tokenizer = HuggingFaceTokenizer(
|
tokenizer = HuggingFaceTokenizer(
|
||||||
model=self.model.replace("hosted_vllm/", ""), max_tokens=self.max_tokens
|
model=self.model.replace("hosted_vllm/", ""),
|
||||||
|
max_completion_tokens=self.max_completion_tokens,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
|
logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
|
||||||
logger.info("Switching to TikToken default tokenizer.")
|
logger.info("Switching to TikToken default tokenizer.")
|
||||||
tokenizer = TikTokenTokenizer(model=None, max_tokens=self.max_tokens)
|
tokenizer = TikTokenTokenizer(
|
||||||
|
model=None, max_completion_tokens=self.max_completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"Tokenizer loaded for model: {self.model}")
|
logger.debug(f"Tokenizer loaded for model: {self.model}")
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
Instance variables:
|
Instance variables:
|
||||||
- model
|
- model
|
||||||
- dimensions
|
- dimensions
|
||||||
- max_tokens
|
- max_completion_tokens
|
||||||
- endpoint
|
- endpoint
|
||||||
- mock
|
- mock
|
||||||
- huggingface_tokenizer_name
|
- huggingface_tokenizer_name
|
||||||
|
|
@ -39,7 +39,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
dimensions: int
|
dimensions: int
|
||||||
max_tokens: int
|
max_completion_tokens: int
|
||||||
endpoint: str
|
endpoint: str
|
||||||
mock: bool
|
mock: bool
|
||||||
huggingface_tokenizer_name: str
|
huggingface_tokenizer_name: str
|
||||||
|
|
@ -50,13 +50,13 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
self,
|
self,
|
||||||
model: Optional[str] = "avr/sfr-embedding-mistral:latest",
|
model: Optional[str] = "avr/sfr-embedding-mistral:latest",
|
||||||
dimensions: Optional[int] = 1024,
|
dimensions: Optional[int] = 1024,
|
||||||
max_tokens: int = 512,
|
max_completion_tokens: int = 512,
|
||||||
endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
|
endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
|
||||||
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
|
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.huggingface_tokenizer_name = huggingface_tokenizer
|
self.huggingface_tokenizer_name = huggingface_tokenizer
|
||||||
self.tokenizer = self.get_tokenizer()
|
self.tokenizer = self.get_tokenizer()
|
||||||
|
|
@ -132,7 +132,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
"""
|
"""
|
||||||
logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...")
|
logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...")
|
||||||
tokenizer = HuggingFaceTokenizer(
|
tokenizer = HuggingFaceTokenizer(
|
||||||
model=self.huggingface_tokenizer_name, max_tokens=self.max_tokens
|
model=self.huggingface_tokenizer_name, max_completion_tokens=self.max_completion_tokens
|
||||||
)
|
)
|
||||||
logger.debug("Tokenizer loaded for OllamaEmbeddingEngine")
|
logger.debug("Tokenizer loaded for OllamaEmbeddingEngine")
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ class EmbeddingConfig(BaseSettings):
|
||||||
embedding_endpoint: Optional[str] = None
|
embedding_endpoint: Optional[str] = None
|
||||||
embedding_api_key: Optional[str] = None
|
embedding_api_key: Optional[str] = None
|
||||||
embedding_api_version: Optional[str] = None
|
embedding_api_version: Optional[str] = None
|
||||||
embedding_max_tokens: Optional[int] = 8191
|
embedding_max_completion_tokens: Optional[int] = 8191
|
||||||
huggingface_tokenizer: Optional[str] = None
|
huggingface_tokenizer: Optional[str] = None
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ class EmbeddingConfig(BaseSettings):
|
||||||
"embedding_endpoint": self.embedding_endpoint,
|
"embedding_endpoint": self.embedding_endpoint,
|
||||||
"embedding_api_key": self.embedding_api_key,
|
"embedding_api_key": self.embedding_api_key,
|
||||||
"embedding_api_version": self.embedding_api_version,
|
"embedding_api_version": self.embedding_api_version,
|
||||||
"embedding_max_tokens": self.embedding_max_tokens,
|
"embedding_max_completion_tokens": self.embedding_max_completion_tokens,
|
||||||
"huggingface_tokenizer": self.huggingface_tokenizer,
|
"huggingface_tokenizer": self.huggingface_tokenizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ def get_embedding_engine() -> EmbeddingEngine:
|
||||||
config.embedding_provider,
|
config.embedding_provider,
|
||||||
config.embedding_model,
|
config.embedding_model,
|
||||||
config.embedding_dimensions,
|
config.embedding_dimensions,
|
||||||
config.embedding_max_tokens,
|
config.embedding_max_completion_tokens,
|
||||||
config.embedding_endpoint,
|
config.embedding_endpoint,
|
||||||
config.embedding_api_key,
|
config.embedding_api_key,
|
||||||
config.embedding_api_version,
|
config.embedding_api_version,
|
||||||
|
|
@ -41,7 +41,7 @@ def create_embedding_engine(
|
||||||
embedding_provider,
|
embedding_provider,
|
||||||
embedding_model,
|
embedding_model,
|
||||||
embedding_dimensions,
|
embedding_dimensions,
|
||||||
embedding_max_tokens,
|
embedding_max_completion_tokens,
|
||||||
embedding_endpoint,
|
embedding_endpoint,
|
||||||
embedding_api_key,
|
embedding_api_key,
|
||||||
embedding_api_version,
|
embedding_api_version,
|
||||||
|
|
@ -58,7 +58,7 @@ def create_embedding_engine(
|
||||||
'ollama', or another supported provider.
|
'ollama', or another supported provider.
|
||||||
- embedding_model: The model to be used for the embedding engine.
|
- embedding_model: The model to be used for the embedding engine.
|
||||||
- embedding_dimensions: The number of dimensions for the embeddings.
|
- embedding_dimensions: The number of dimensions for the embeddings.
|
||||||
- embedding_max_tokens: The maximum number of tokens for the embeddings.
|
- embedding_max_completion_tokens: The maximum number of tokens for the embeddings.
|
||||||
- embedding_endpoint: The endpoint for the embedding service, relevant for certain
|
- embedding_endpoint: The endpoint for the embedding service, relevant for certain
|
||||||
providers.
|
providers.
|
||||||
- embedding_api_key: API key to authenticate with the embedding service, if
|
- embedding_api_key: API key to authenticate with the embedding service, if
|
||||||
|
|
@ -81,7 +81,7 @@ def create_embedding_engine(
|
||||||
return FastembedEmbeddingEngine(
|
return FastembedEmbeddingEngine(
|
||||||
model=embedding_model,
|
model=embedding_model,
|
||||||
dimensions=embedding_dimensions,
|
dimensions=embedding_dimensions,
|
||||||
max_tokens=embedding_max_tokens,
|
max_completion_tokens=embedding_max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
if embedding_provider == "ollama":
|
if embedding_provider == "ollama":
|
||||||
|
|
@ -90,7 +90,7 @@ def create_embedding_engine(
|
||||||
return OllamaEmbeddingEngine(
|
return OllamaEmbeddingEngine(
|
||||||
model=embedding_model,
|
model=embedding_model,
|
||||||
dimensions=embedding_dimensions,
|
dimensions=embedding_dimensions,
|
||||||
max_tokens=embedding_max_tokens,
|
max_completion_tokens=embedding_max_completion_tokens,
|
||||||
endpoint=embedding_endpoint,
|
endpoint=embedding_endpoint,
|
||||||
huggingface_tokenizer=huggingface_tokenizer,
|
huggingface_tokenizer=huggingface_tokenizer,
|
||||||
)
|
)
|
||||||
|
|
@ -104,5 +104,5 @@ def create_embedding_engine(
|
||||||
api_version=embedding_api_version,
|
api_version=embedding_api_version,
|
||||||
model=embedding_model,
|
model=embedding_model,
|
||||||
dimensions=embedding_dimensions,
|
dimensions=embedding_dimensions,
|
||||||
max_tokens=embedding_max_tokens,
|
max_completion_tokens=embedding_max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ class LLMConfig(BaseSettings):
|
||||||
- llm_api_version
|
- llm_api_version
|
||||||
- llm_temperature
|
- llm_temperature
|
||||||
- llm_streaming
|
- llm_streaming
|
||||||
- llm_max_tokens
|
- llm_max_completion_tokens
|
||||||
- transcription_model
|
- transcription_model
|
||||||
- graph_prompt_path
|
- graph_prompt_path
|
||||||
- llm_rate_limit_enabled
|
- llm_rate_limit_enabled
|
||||||
|
|
@ -41,7 +41,7 @@ class LLMConfig(BaseSettings):
|
||||||
llm_api_version: Optional[str] = None
|
llm_api_version: Optional[str] = None
|
||||||
llm_temperature: float = 0.0
|
llm_temperature: float = 0.0
|
||||||
llm_streaming: bool = False
|
llm_streaming: bool = False
|
||||||
llm_max_tokens: int = 16384
|
llm_max_completion_tokens: int = 16384
|
||||||
|
|
||||||
baml_llm_provider: str = "openai"
|
baml_llm_provider: str = "openai"
|
||||||
baml_llm_model: str = "gpt-5-mini"
|
baml_llm_model: str = "gpt-5-mini"
|
||||||
|
|
@ -171,7 +171,7 @@ class LLMConfig(BaseSettings):
|
||||||
"api_version": self.llm_api_version,
|
"api_version": self.llm_api_version,
|
||||||
"temperature": self.llm_temperature,
|
"temperature": self.llm_temperature,
|
||||||
"streaming": self.llm_streaming,
|
"streaming": self.llm_streaming,
|
||||||
"max_tokens": self.llm_max_tokens,
|
"max_completion_tokens": self.llm_max_completion_tokens,
|
||||||
"transcription_model": self.transcription_model,
|
"transcription_model": self.transcription_model,
|
||||||
"graph_prompt_path": self.graph_prompt_path,
|
"graph_prompt_path": self.graph_prompt_path,
|
||||||
"rate_limit_enabled": self.llm_rate_limit_enabled,
|
"rate_limit_enabled": self.llm_rate_limit_enabled,
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class AnthropicAdapter(LLMInterface):
|
||||||
name = "Anthropic"
|
name = "Anthropic"
|
||||||
model: str
|
model: str
|
||||||
|
|
||||||
def __init__(self, max_tokens: int, model: str = None):
|
def __init__(self, max_completion_tokens: int, model: str = None):
|
||||||
import anthropic
|
import anthropic
|
||||||
|
|
||||||
self.aclient = instructor.patch(
|
self.aclient = instructor.patch(
|
||||||
|
|
@ -31,7 +31,7 @@ class AnthropicAdapter(LLMInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
|
||||||
@sleep_and_retry_async()
|
@sleep_and_retry_async()
|
||||||
@rate_limit_async
|
@rate_limit_async
|
||||||
|
|
@ -57,7 +57,7 @@ class AnthropicAdapter(LLMInterface):
|
||||||
|
|
||||||
return await self.aclient(
|
return await self.aclient(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
max_tokens=4096,
|
max_completion_tokens=4096,
|
||||||
max_retries=5,
|
max_retries=5,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ class GeminiAdapter(LLMInterface):
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int,
|
max_completion_tokens: int,
|
||||||
endpoint: Optional[str] = None,
|
endpoint: Optional[str] = None,
|
||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
|
|
@ -44,7 +44,7 @@ class GeminiAdapter(LLMInterface):
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.api_version = api_version
|
self.api_version = api_version
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
|
||||||
@observe(as_type="generation")
|
@observe(as_type="generation")
|
||||||
@sleep_and_retry_async()
|
@sleep_and_retry_async()
|
||||||
|
|
@ -90,7 +90,7 @@ class GeminiAdapter(LLMInterface):
|
||||||
model=f"{self.model}",
|
model=f"{self.model}",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
max_tokens=self.max_tokens,
|
max_completion_tokens=self.max_completion_tokens,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
response_format=response_schema,
|
response_format=response_schema,
|
||||||
timeout=100,
|
timeout=100,
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
api_key: str,
|
api_key: str,
|
||||||
model: str,
|
model: str,
|
||||||
name: str,
|
name: str,
|
||||||
max_tokens: int,
|
max_completion_tokens: int,
|
||||||
fallback_model: str = None,
|
fallback_model: str = None,
|
||||||
fallback_api_key: str = None,
|
fallback_api_key: str = None,
|
||||||
fallback_endpoint: str = None,
|
fallback_endpoint: str = None,
|
||||||
|
|
@ -50,7 +50,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
|
||||||
self.fallback_model = fallback_model
|
self.fallback_model = fallback_model
|
||||||
self.fallback_api_key = fallback_api_key
|
self.fallback_api_key = fallback_api_key
|
||||||
|
|
|
||||||
|
|
@ -54,11 +54,15 @@ def get_llm_client():
|
||||||
# Check if max_token value is defined in liteLLM for given model
|
# Check if max_token value is defined in liteLLM for given model
|
||||||
# if not use value from cognee configuration
|
# if not use value from cognee configuration
|
||||||
from cognee.infrastructure.llm.utils import (
|
from cognee.infrastructure.llm.utils import (
|
||||||
get_model_max_tokens,
|
get_model_max_completion_tokens,
|
||||||
) # imported here to avoid circular imports
|
) # imported here to avoid circular imports
|
||||||
|
|
||||||
model_max_tokens = get_model_max_tokens(llm_config.llm_model)
|
model_max_completion_tokens = get_model_max_completion_tokens(llm_config.llm_model)
|
||||||
max_tokens = model_max_tokens if model_max_tokens else llm_config.llm_max_tokens
|
max_completion_tokens = (
|
||||||
|
model_max_completion_tokens
|
||||||
|
if model_max_completion_tokens
|
||||||
|
else llm_config.llm_max_completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
if provider == LLMProvider.OPENAI:
|
if provider == LLMProvider.OPENAI:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None:
|
||||||
|
|
@ -74,7 +78,7 @@ def get_llm_client():
|
||||||
api_version=llm_config.llm_api_version,
|
api_version=llm_config.llm_api_version,
|
||||||
model=llm_config.llm_model,
|
model=llm_config.llm_model,
|
||||||
transcription_model=llm_config.transcription_model,
|
transcription_model=llm_config.transcription_model,
|
||||||
max_tokens=max_tokens,
|
max_completion_tokens=max_completion_tokens,
|
||||||
streaming=llm_config.llm_streaming,
|
streaming=llm_config.llm_streaming,
|
||||||
fallback_api_key=llm_config.fallback_api_key,
|
fallback_api_key=llm_config.fallback_api_key,
|
||||||
fallback_endpoint=llm_config.fallback_endpoint,
|
fallback_endpoint=llm_config.fallback_endpoint,
|
||||||
|
|
@ -94,7 +98,7 @@ def get_llm_client():
|
||||||
llm_config.llm_api_key,
|
llm_config.llm_api_key,
|
||||||
llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
"Ollama",
|
"Ollama",
|
||||||
max_tokens=max_tokens,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif provider == LLMProvider.ANTHROPIC:
|
elif provider == LLMProvider.ANTHROPIC:
|
||||||
|
|
@ -102,7 +106,9 @@ def get_llm_client():
|
||||||
AnthropicAdapter,
|
AnthropicAdapter,
|
||||||
)
|
)
|
||||||
|
|
||||||
return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
|
return AnthropicAdapter(
|
||||||
|
max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
|
||||||
|
)
|
||||||
|
|
||||||
elif provider == LLMProvider.CUSTOM:
|
elif provider == LLMProvider.CUSTOM:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None:
|
||||||
|
|
@ -117,7 +123,7 @@ def get_llm_client():
|
||||||
llm_config.llm_api_key,
|
llm_config.llm_api_key,
|
||||||
llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
"Custom",
|
"Custom",
|
||||||
max_tokens=max_tokens,
|
max_completion_tokens=max_completion_tokens,
|
||||||
fallback_api_key=llm_config.fallback_api_key,
|
fallback_api_key=llm_config.fallback_api_key,
|
||||||
fallback_endpoint=llm_config.fallback_endpoint,
|
fallback_endpoint=llm_config.fallback_endpoint,
|
||||||
fallback_model=llm_config.fallback_model,
|
fallback_model=llm_config.fallback_model,
|
||||||
|
|
@ -134,7 +140,7 @@ def get_llm_client():
|
||||||
return GeminiAdapter(
|
return GeminiAdapter(
|
||||||
api_key=llm_config.llm_api_key,
|
api_key=llm_config.llm_api_key,
|
||||||
model=llm_config.llm_model,
|
model=llm_config.llm_model,
|
||||||
max_tokens=max_tokens,
|
max_completion_tokens=max_completion_tokens,
|
||||||
endpoint=llm_config.llm_endpoint,
|
endpoint=llm_config.llm_endpoint,
|
||||||
api_version=llm_config.llm_api_version,
|
api_version=llm_config.llm_api_version,
|
||||||
streaming=llm_config.llm_streaming,
|
streaming=llm_config.llm_streaming,
|
||||||
|
|
|
||||||
|
|
@ -30,16 +30,18 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
- model
|
- model
|
||||||
- api_key
|
- api_key
|
||||||
- endpoint
|
- endpoint
|
||||||
- max_tokens
|
- max_completion_tokens
|
||||||
- aclient
|
- aclient
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, endpoint: str, api_key: str, model: str, name: str, max_tokens: int):
|
def __init__(
|
||||||
|
self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
|
||||||
|
):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
|
||||||
self.aclient = instructor.from_openai(
|
self.aclient = instructor.from_openai(
|
||||||
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
|
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
|
||||||
|
|
@ -159,7 +161,7 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
max_tokens=300,
|
max_completion_tokens=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure response is valid before accessing .choices[0].message.content
|
# Ensure response is valid before accessing .choices[0].message.content
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
api_version: str,
|
api_version: str,
|
||||||
model: str,
|
model: str,
|
||||||
transcription_model: str,
|
transcription_model: str,
|
||||||
max_tokens: int,
|
max_completion_tokens: int,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
fallback_model: str = None,
|
fallback_model: str = None,
|
||||||
fallback_api_key: str = None,
|
fallback_api_key: str = None,
|
||||||
|
|
@ -77,7 +77,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.api_version = api_version
|
self.api_version = api_version
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
|
|
||||||
self.fallback_model = fallback_model
|
self.fallback_model = fallback_model
|
||||||
|
|
@ -301,7 +301,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
api_base=self.endpoint,
|
api_base=self.endpoint,
|
||||||
api_version=self.api_version,
|
api_version=self.api_version,
|
||||||
max_tokens=300,
|
max_completion_tokens=300,
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,10 @@ class GeminiTokenizer(TokenizerInterface):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int = 3072,
|
max_completion_tokens: int = 3072,
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
|
||||||
# Get LLM API key from config
|
# Get LLM API key from config
|
||||||
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
|
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
|
||||||
|
|
|
||||||
|
|
@ -14,17 +14,17 @@ class HuggingFaceTokenizer(TokenizerInterface):
|
||||||
|
|
||||||
Instance variables include:
|
Instance variables include:
|
||||||
- model: str
|
- model: str
|
||||||
- max_tokens: int
|
- max_completion_tokens: int
|
||||||
- tokenizer: AutoTokenizer
|
- tokenizer: AutoTokenizer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int = 512,
|
max_completion_tokens: int = 512,
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
|
||||||
# Import here to make it an optional dependency
|
# Import here to make it an optional dependency
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
|
||||||
|
|
@ -16,17 +16,17 @@ class MistralTokenizer(TokenizerInterface):
|
||||||
|
|
||||||
Instance variables include:
|
Instance variables include:
|
||||||
- model: str
|
- model: str
|
||||||
- max_tokens: int
|
- max_completion_tokens: int
|
||||||
- tokenizer: MistralTokenizer
|
- tokenizer: MistralTokenizer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int = 3072,
|
max_completion_tokens: int = 3072,
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
|
||||||
# Import here to make it an optional dependency
|
# Import here to make it an optional dependency
|
||||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||||
|
|
|
||||||
|
|
@ -13,10 +13,10 @@ class TikTokenTokenizer(TokenizerInterface):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
max_tokens: int = 8191,
|
max_completion_tokens: int = 8191,
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
# Initialize TikToken for GPT based on model
|
# Initialize TikToken for GPT based on model
|
||||||
if model:
|
if model:
|
||||||
self.tokenizer = tiktoken.encoding_for_model(self.model)
|
self.tokenizer = tiktoken.encoding_for_model(self.model)
|
||||||
|
|
@ -93,9 +93,9 @@ class TikTokenTokenizer(TokenizerInterface):
|
||||||
num_tokens = len(self.tokenizer.encode(text))
|
num_tokens = len(self.tokenizer.encode(text))
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
def trim_text_to_max_tokens(self, text: str) -> str:
|
def trim_text_to_max_completion_tokens(self, text: str) -> str:
|
||||||
"""
|
"""
|
||||||
Trim the text so that the number of tokens does not exceed max_tokens.
|
Trim the text so that the number of tokens does not exceed max_completion_tokens.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
@ -111,13 +111,13 @@ class TikTokenTokenizer(TokenizerInterface):
|
||||||
num_tokens = self.count_tokens(text)
|
num_tokens = self.count_tokens(text)
|
||||||
|
|
||||||
# If the number of tokens is within the limit, return the text as is
|
# If the number of tokens is within the limit, return the text as is
|
||||||
if num_tokens <= self.max_tokens:
|
if num_tokens <= self.max_completion_tokens:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
# If the number exceeds the limit, trim the text
|
# If the number exceeds the limit, trim the text
|
||||||
# This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
|
# This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
|
||||||
encoded_text = self.tokenizer.encode(text)
|
encoded_text = self.tokenizer.encode(text)
|
||||||
trimmed_encoded_text = encoded_text[: self.max_tokens]
|
trimmed_encoded_text = encoded_text[: self.max_completion_tokens]
|
||||||
# Decoding the trimmed text
|
# Decoding the trimmed text
|
||||||
trimmed_text = self.tokenizer.decode(trimmed_encoded_text)
|
trimmed_text = self.tokenizer.decode(trimmed_encoded_text)
|
||||||
return trimmed_text
|
return trimmed_text
|
||||||
|
|
|
||||||
|
|
@ -32,13 +32,13 @@ def get_max_chunk_tokens():
|
||||||
|
|
||||||
# We need to make sure chunk size won't take more than half of LLM max context token size
|
# We need to make sure chunk size won't take more than half of LLM max context token size
|
||||||
# but it also can't be bigger than the embedding engine max token size
|
# but it also can't be bigger than the embedding engine max token size
|
||||||
llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division
|
llm_cutoff_point = llm_client.max_completion_tokens // 2 # Round down the division
|
||||||
max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
|
max_chunk_tokens = min(embedding_engine.max_completion_tokens, llm_cutoff_point)
|
||||||
|
|
||||||
return max_chunk_tokens
|
return max_chunk_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_model_max_tokens(model_name: str):
|
def get_model_max_completion_tokens(model_name: str):
|
||||||
"""
|
"""
|
||||||
Retrieve the maximum token limit for a specified model name if it exists.
|
Retrieve the maximum token limit for a specified model name if it exists.
|
||||||
|
|
||||||
|
|
@ -56,15 +56,15 @@ def get_model_max_tokens(model_name: str):
|
||||||
|
|
||||||
Number of max tokens of model, or None if model is unknown
|
Number of max tokens of model, or None if model is unknown
|
||||||
"""
|
"""
|
||||||
max_tokens = None
|
max_completion_tokens = None
|
||||||
|
|
||||||
if model_name in litellm.model_cost:
|
if model_name in litellm.model_cost:
|
||||||
max_tokens = litellm.model_cost[model_name]["max_tokens"]
|
max_completion_tokens = litellm.model_cost[model_name]["max_tokens"]
|
||||||
logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
|
logger.debug(f"Max input tokens for {model_name}: {max_completion_tokens}")
|
||||||
else:
|
else:
|
||||||
logger.info("Model not found in LiteLLM's model_cost.")
|
logger.info("Model not found in LiteLLM's model_cost.")
|
||||||
|
|
||||||
return max_tokens
|
return max_completion_tokens
|
||||||
|
|
||||||
|
|
||||||
async def test_llm_connection():
|
async def test_llm_connection():
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ class QABenchmarkGraphiti(QABenchmarkRAG):
|
||||||
|
|
||||||
async def initialize_rag(self) -> Any:
|
async def initialize_rag(self) -> Any:
|
||||||
"""Initialize Graphiti and LLM."""
|
"""Initialize Graphiti and LLM."""
|
||||||
llm_config = LLMConfig(model=self.config.model_name, max_tokens=65536)
|
llm_config = LLMConfig(model=self.config.model_name, max_completion_tokens=65536)
|
||||||
llm_client = OpenAIClient(config=llm_config)
|
llm_client = OpenAIClient(config=llm_config)
|
||||||
graphiti = Graphiti(
|
graphiti = Graphiti(
|
||||||
self.config.db_url,
|
self.config.db_url,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue