feat: Add gemini tokenizer to cognee

Igor Ilic committed 2025-01-23 17:55:04 +01:00
parent 294ed1d960
commit b686376c54
5 changed files with 52 additions and 2 deletions


@@ -6,6 +6,7 @@ import litellm
 import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
+from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
 from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
 from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
 from transformers import AutoTokenizer
@@ -121,8 +122,10 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         # If model also contains provider information, extract only model information
         model = self.model.split("/")[-1]
-        if "openai" in self.provider.lower() or "gpt" in self.model:
+        if "openai" in self.provider.lower():
             tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
+        elif "gemini" in self.provider.lower():
+            tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
         else:
             tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)

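The tokenizer choice now keys on the configured provider rather than on substrings of the model name (the old `or "gpt" in self.model` branch is gone). A standalone restatement of the selection logic for readability, not part of the commit; the three tokenizer classes and import paths are the ones from this diff:

from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer

def pick_tokenizer(provider: str, model: str, max_tokens: int):
    # Strip a leading "provider/" prefix,
    # e.g. "openai/text-embedding-3-large" -> "text-embedding-3-large".
    short_model = model.split("/")[-1]
    if "openai" in provider.lower():
        return TikTokenTokenizer(model=short_model, max_tokens=max_tokens)
    elif "gemini" in provider.lower():
        return GeminiTokenizer(model=short_model, max_tokens=max_tokens)
    # The HuggingFace fallback keeps the full "org/model" identifier on purpose.
    return HuggingFaceTokenizer(model=model, max_tokens=max_tokens)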

@@ -4,7 +4,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
 class EmbeddingConfig(BaseSettings):
-    embedding_model: Optional[str] = "text-embedding-3-large"
+    embedding_provider: Optional[str] = "openai"
+    embedding_model: Optional[str] = "openai/text-embedding-3-large"
     embedding_dimensions: Optional[int] = 3072
     embedding_endpoint: Optional[str] = None
     embedding_api_key: Optional[str] = None

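EmbeddingConfig is a pydantic BaseSettings, so the new embedding_provider field can be supplied through the environment. A minimal sketch, assuming pydantic-settings' default case-insensitive mapping of field names to environment variables (the SettingsConfigDict options are not shown in this diff) and an illustrative Gemini model id:

import os

# Assumed environment variable names; adjust if the config sets an env_prefix.
os.environ["EMBEDDING_PROVIDER"] = "gemini"
os.environ["EMBEDDING_MODEL"] = "gemini/text-embedding-004"  # illustrative model id
os.environ["EMBEDDING_API_KEY"] = "<gemini-api-key>"

from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config

print(get_embedding_config().embedding_provider)  # gemini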

@@ -10,6 +10,7 @@ def get_embedding_engine() -> EmbeddingEngine:
     return LiteLLMEmbeddingEngine(
         # If OpenAI API is used for embeddings, litellm needs only the api_key.
+        provider=config.embedding_provider,
         api_key=config.embedding_api_key or llm_config.llm_api_key,
         endpoint=config.embedding_endpoint,
         api_version=config.embedding_api_version,

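With provider forwarded here, the value set in EmbeddingConfig reaches the engine's tokenizer dispatch in the first hunk. A hypothetical smoke check; the module path for get_embedding_engine is an assumption, since the diff shows only the function body:

# Module path assumed for illustration; only get_embedding_engine() appears in this diff.
from cognee.infrastructure.databases.vector.embeddings.get_embedding_engine import get_embedding_engine

engine = get_embedding_engine()  # a LiteLLMEmbeddingEngine carrying embedding_provider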
cognee/infrastructure/llm/tokenizer/Gemini/__init__.py (new file)

@@ -0,0 +1 @@
+from .adapter import GeminiTokenizer

cognee/infrastructure/llm/tokenizer/Gemini/adapter.py (new file)

@@ -0,0 +1,44 @@
+from typing import List, Any
+
+from ..tokenizer_interface import TokenizerInterface
+
+
+class GeminiTokenizer(TokenizerInterface):
+    def __init__(
+        self,
+        model: str,
+        max_tokens: int = float("inf"),
+    ):
+        self.model = model
+        self.max_tokens = max_tokens
+
+        # Get LLM API key from config
+        from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
+        from cognee.infrastructure.llm.config import get_llm_config
+
+        config = get_embedding_config()
+        llm_config = get_llm_config()
+
+        import google.generativeai as genai
+
+        genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key)
+
+    def extract_tokens(self, text: str) -> List[Any]:
+        raise NotImplementedError
+
+    def num_tokens_from_text(self, text: str) -> int:
+        """
+        Returns the number of tokens in the given text.
+
+        Args:
+            text: str
+
+        Returns:
+            number of tokens in the given text
+        """
+        import google.generativeai as genai
+
+        return len(genai.embed_content(model=f"models/{self.model}", content=text))
+
+    def trim_text_to_max_tokens(self, text: str) -> str:
+        raise NotImplementedError
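
One caveat in num_tokens_from_text: genai.embed_content returns a mapping with an "embedding" key, so len() over it counts dictionary keys, and len(result["embedding"]) would count embedding dimensions; neither is a token count. A hedged alternative sketch using the SDK's count_tokens endpoint, which this commit does not use and which is served by generative model names rather than embedding model names:

import google.generativeai as genai

def num_tokens_via_count_tokens(text: str, model: str = "gemini-1.5-flash") -> int:
    # Assumes genai.configure(api_key=...) has already run, as in __init__ above.
    # The default model name here is an assumption, not taken from this commit;
    # count_tokens is exposed on generative models, not on embedding models.
    response = genai.GenerativeModel(f"models/{model}").count_tokens(text)
    return response.total_tokens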