diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
index c037b45e0..50dde8e89 100644
--- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -6,6 +6,7 @@ import litellm
 import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
+from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
 from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
 from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
 from transformers import AutoTokenizer
@@ -121,8 +122,10 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         # If model also contains provider information, extract only model information
         model = self.model.split("/")[-1]
 
-        if "openai" in self.provider.lower() or "gpt" in self.model:
+        if "openai" in self.provider.lower():
             tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
+        elif "gemini" in self.provider.lower():
+            tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
         else:
             tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
 
diff --git a/cognee/infrastructure/databases/vector/embeddings/config.py b/cognee/infrastructure/databases/vector/embeddings/config.py
index 62335ea41..cb72a46f4 100644
--- a/cognee/infrastructure/databases/vector/embeddings/config.py
+++ b/cognee/infrastructure/databases/vector/embeddings/config.py
@@ -4,7 +4,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
 
 
 class EmbeddingConfig(BaseSettings):
-    embedding_model: Optional[str] = "text-embedding-3-large"
+    embedding_provider: Optional[str] = "openai"
+    embedding_model: Optional[str] = "openai/text-embedding-3-large"
     embedding_dimensions: Optional[int] = 3072
     embedding_endpoint: Optional[str] = None
     embedding_api_key: Optional[str] = None
diff --git a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
index e894da892..d3011f059 100644
--- a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
@@ -10,6 +10,7 @@ def get_embedding_engine() -> EmbeddingEngine:
 
     return LiteLLMEmbeddingEngine(
         # If OpenAI API is used for embeddings, litellm needs only the api_key.
+        provider=config.embedding_provider,
         api_key=config.embedding_api_key or llm_config.llm_api_key,
         endpoint=config.embedding_endpoint,
         api_version=config.embedding_api_version,
diff --git a/cognee/infrastructure/llm/tokenizer/Gemini/__init__.py b/cognee/infrastructure/llm/tokenizer/Gemini/__init__.py
new file mode 100644
index 000000000..4ed4ad4d5
--- /dev/null
+++ b/cognee/infrastructure/llm/tokenizer/Gemini/__init__.py
@@ -0,0 +1 @@
+from .adapter import GeminiTokenizer
diff --git a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py
new file mode 100644
index 000000000..697bc9577
--- /dev/null
+++ b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py
@@ -0,0 +1,46 @@
+from typing import List, Any
+
+from ..tokenizer_interface import TokenizerInterface
+
+
+class GeminiTokenizer(TokenizerInterface):
+    def __init__(
+        self,
+        model: str,
+        max_tokens: float = float("inf"),
+    ):
+        self.model = model
+        self.max_tokens = max_tokens
+
+        # Get the API key from the embedding config, falling back to the LLM config.
+        from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
+        from cognee.infrastructure.llm.config import get_llm_config
+
+        config = get_embedding_config()
+        llm_config = get_llm_config()
+
+        import google.generativeai as genai
+
+        genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key)
+
+    def extract_tokens(self, text: str) -> List[Any]:
+        raise NotImplementedError
+
+    def num_tokens_from_text(self, text: str) -> int:
+        """
+        Returns the number of tokens in the given text.
+        Args:
+            text: str
+
+        Returns:
+            number of tokens in the given text
+
+        """
+        import google.generativeai as genai
+
+        # embed_content returns a dict, so len() of it counts response keys rather
+        # than tokens; use the count_tokens endpoint (via GenerativeModel) instead.
+        return genai.GenerativeModel(f"models/{self.model}").count_tokens(text).total_tokens
+
+    def trim_text_to_max_tokens(self, text: str) -> str:
+        raise NotImplementedError
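
A minimal usage sketch of how the pieces above fit together, for reviewers. It assumes pydantic-settings' default field-to-environment mapping on `EmbeddingConfig` (`EMBEDDING_PROVIDER`, `EMBEDDING_MODEL`, with no `env_prefix`), and the model name and API key below are placeholders, not values taken from this diff.

```python
import os

# Select Gemini before the config objects are first read.
os.environ["EMBEDDING_PROVIDER"] = "gemini"
os.environ["EMBEDDING_MODEL"] = "gemini/text-embedding-004"  # placeholder model name
os.environ["LLM_API_KEY"] = "<gemini-api-key>"  # placeholder key

from cognee.infrastructure.databases.vector.embeddings.get_embedding_engine import (
    get_embedding_engine,
)
from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer

# The engine now receives provider="gemini"; since the provider string matches the
# new elif branch, it selects GeminiTokenizer with the provider prefix stripped
# from the model name.
engine = get_embedding_engine()

# The tokenizer also works standalone; its constructor configures
# google.generativeai with the embedding or LLM API key from config.
tokenizer = GeminiTokenizer(model="text-embedding-004")
print(tokenizer.num_tokens_from_text("Knowledge graphs connect entities."))
```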