feat: Add gemini tokenizer to cognee

This commit is contained in:
Igor Ilic 2025-01-23 17:55:04 +01:00
parent 294ed1d960
commit b686376c54
5 changed files with 52 additions and 2 deletions

View file

@@ -6,6 +6,7 @@ import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
from transformers import AutoTokenizer
@@ -121,8 +122,10 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
# If model also contains provider information, extract only model information
model = self.model.split("/")[-1]
if "openai" in self.provider.lower() or "gpt" in self.model:
if "openai" in self.provider.lower():
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
elif "gemini" in self.provider.lower():
tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
else:
tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)

View file

@@ -4,7 +4,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class EmbeddingConfig(BaseSettings):
embedding_model: Optional[str] = "text-embedding-3-large"
embedding_provider: Optional[str] = "openai"
embedding_model: Optional[str] = "openai/text-embedding-3-large"
embedding_dimensions: Optional[int] = 3072
embedding_endpoint: Optional[str] = None
embedding_api_key: Optional[str] = None

View file

@@ -10,6 +10,7 @@ def get_embedding_engine() -> EmbeddingEngine:
return LiteLLMEmbeddingEngine(
# If OpenAI API is used for embeddings, litellm needs only the api_key.
provider=config.embedding_provider,
api_key=config.embedding_api_key or llm_config.llm_api_key,
endpoint=config.embedding_endpoint,
api_version=config.embedding_api_version,

View file

@@ -0,0 +1 @@
from .adapter import GeminiTokenizer

View file

@@ -0,0 +1,44 @@
from typing import List, Any
from ..tokenizer_interface import TokenizerInterface
class GeminiTokenizer(TokenizerInterface):
    """Tokenizer backed by the Google Generative AI (Gemini) API.

    Configures the ``google.generativeai`` client on construction using the
    embedding API key (falling back to the generic LLM API key) and delegates
    token counting to the Gemini ``countTokens`` endpoint.
    """

    def __init__(
        self,
        model: str,
        max_tokens: int = float("inf"),
    ):
        """
        Args:
            model: Gemini model name without the "models/" prefix
                (e.g. "text-embedding-004").
            max_tokens: upper token budget carried by this tokenizer;
                defaults to unbounded (float("inf"), matching the other
                tokenizer adapters' convention).
        """
        self.model = model
        self.max_tokens = max_tokens

        # Imports are kept local so the module can be imported even when the
        # optional google-generativeai dependency is not installed, and to
        # avoid a circular import with the embeddings config module.
        from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
        from cognee.infrastructure.llm.config import get_llm_config

        config = get_embedding_config()
        llm_config = get_llm_config()

        import google.generativeai as genai

        # Prefer the embedding-specific key; fall back to the general LLM key.
        genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key)

    def extract_tokens(self, text: str) -> List[Any]:
        # The Gemini API exposes token counts but not the token sequence itself.
        raise NotImplementedError

    def num_tokens_from_text(self, text: str) -> int:
        """
        Returns the number of tokens in the given text.

        Args:
            text: str

        Returns:
            number of tokens in the given text
        """
        import google.generativeai as genai

        # BUG FIX: the previous implementation returned
        # len(genai.embed_content(...)). embed_content returns a response
        # dict ({"embedding": [...]}) whose len() is its key count (always 1),
        # not a token count. Use the SDK's countTokens endpoint instead.
        # NOTE(review): assumes the configured model supports countTokens —
        # confirm for the embedding models used in practice.
        model = genai.GenerativeModel(f"models/{self.model}")
        return model.count_tokens(text).total_tokens

    def trim_text_to_max_tokens(self, text: str) -> str:
        # Not implementable without access to token boundaries (see
        # extract_tokens above).
        raise NotImplementedError