feat: Add gemini tokenizer to cognee
parent 294ed1d960
commit b686376c54

5 changed files with 52 additions and 2 deletions
@@ -6,6 +6,7 @@ import litellm
 import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
+from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
 from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
 from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
 from transformers import AutoTokenizer
@@ -121,8 +122,10 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         # If model also contains provider information, extract only model information
         model = self.model.split("/")[-1]
 
-        if "openai" in self.provider.lower() or "gpt" in self.model:
+        if "openai" in self.provider.lower():
             tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
+        elif "gemini" in self.provider.lower():
+            tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
         else:
             tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
 
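
Note on the dispatch above: self.model.split("/")[-1] strips a litellm-style "provider/model" prefix, so TikTokenTokenizer and GeminiTokenizer receive the bare model name, while HuggingFaceTokenizer keeps the full self.model id. A minimal illustration of the split:

    "openai/text-embedding-3-large".split("/")[-1]  # -> "text-embedding-3-large"
    "text-embedding-3-large".split("/")[-1]         # no prefix -> unchanged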
@@ -4,7 +4,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
 
 
 class EmbeddingConfig(BaseSettings):
-    embedding_model: Optional[str] = "text-embedding-3-large"
+    embedding_provider: Optional[str] = "openai"
+    embedding_model: Optional[str] = "openai/text-embedding-3-large"
     embedding_dimensions: Optional[int] = 3072
     embedding_endpoint: Optional[str] = None
     embedding_api_key: Optional[str] = None
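
EmbeddingConfig is a pydantic BaseSettings, so these defaults can be overridden from the environment. A hedged sketch of selecting Gemini embeddings (the variable names assume pydantic's default field-to-env mapping, and "gemini/text-embedding-004" is an illustrative model id; neither is shown in this diff):

    import os

    # Assumed env var names derived from the EmbeddingConfig field names.
    os.environ["EMBEDDING_PROVIDER"] = "gemini"
    os.environ["EMBEDDING_MODEL"] = "gemini/text-embedding-004"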
@@ -10,6 +10,7 @@ def get_embedding_engine() -> EmbeddingEngine:
     return LiteLLMEmbeddingEngine(
         # If OpenAI API is used for embeddings, litellm needs only the api_key.
+        provider=config.embedding_provider,
         api_key=config.embedding_api_key or llm_config.llm_api_key,
         endpoint=config.embedding_endpoint,
         api_version=config.embedding_api_version,
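
With provider now forwarded, the factory passes the config through to the tokenizer dispatch shown earlier. A hedged usage sketch (the import path is assumed from the hunk context, not shown in this diff):

    from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine  # path assumed
    engine = get_embedding_engine()  # tokenizer choice now follows config.embedding_provider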
cognee/infrastructure/llm/tokenizer/Gemini/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from .adapter import GeminiTokenizer
cognee/infrastructure/llm/tokenizer/Gemini/adapter.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+from typing import List, Any
+
+from ..tokenizer_interface import TokenizerInterface
+
+
+class GeminiTokenizer(TokenizerInterface):
+    def __init__(
+        self,
+        model: str,
+        max_tokens: int = float("inf"),
+    ):
+        self.model = model
+        self.max_tokens = max_tokens
+
+        # Get LLM API key from config
+        from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
+        from cognee.infrastructure.llm.config import get_llm_config
+
+        config = get_embedding_config()
+        llm_config = get_llm_config()
+
+        import google.generativeai as genai
+
+        genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key)
+
+    def extract_tokens(self, text: str) -> List[Any]:
+        raise NotImplementedError
+
+    def num_tokens_from_text(self, text: str) -> int:
+        """
+        Returns the number of tokens in the given text.
+        Args:
+            text: str
+
+        Returns:
+            number of tokens in the given text
+
+        """
+        import google.generativeai as genai
+
+        return len(genai.embed_content(model=f"models/{self.model}", content=text))
+
+    def trim_text_to_max_tokens(self, text: str) -> str:
+        raise NotImplementedError
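
One caveat on num_tokens_from_text above: genai.embed_content returns a response mapping (it carries an "embedding" entry), so len() of it measures that mapping rather than a token count. The google.generativeai package also exposes a count_tokens API; a hedged alternative sketch, assuming a generative model id is acceptable for counting ("gemini-pro" is illustrative, not taken from this commit):

    import google.generativeai as genai

    def num_tokens_from_text(text: str, model: str = "gemini-pro") -> int:
        # count_tokens returns a response object whose total_tokens field
        # holds the token count for the given text.
        return genai.GenerativeModel(model).count_tokens(text).total_tokens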