feat: Add gemini tokenizer to cognee
This commit is contained in:
parent
294ed1d960
commit
b686376c54
5 changed files with 52 additions and 2 deletions
|
|
@ -6,6 +6,7 @@ import litellm
|
||||||
import os
|
import os
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
||||||
|
from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
|
||||||
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
|
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
|
||||||
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
|
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
@ -121,8 +122,10 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
# If model also contains provider information, extract only model information
|
# If model also contains provider information, extract only model information
|
||||||
model = self.model.split("/")[-1]
|
model = self.model.split("/")[-1]
|
||||||
|
|
||||||
if "openai" in self.provider.lower() or "gpt" in self.model:
|
if "openai" in self.provider.lower():
|
||||||
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
|
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
|
||||||
|
elif "gemini" in self.provider.lower():
|
||||||
|
tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
|
||||||
else:
|
else:
|
||||||
tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
|
tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfig(BaseSettings):
|
class EmbeddingConfig(BaseSettings):
|
||||||
embedding_model: Optional[str] = "text-embedding-3-large"
|
embedding_provider: Optional[str] = "openai"
|
||||||
|
embedding_model: Optional[str] = "openai/text-embedding-3-large"
|
||||||
embedding_dimensions: Optional[int] = 3072
|
embedding_dimensions: Optional[int] = 3072
|
||||||
embedding_endpoint: Optional[str] = None
|
embedding_endpoint: Optional[str] = None
|
||||||
embedding_api_key: Optional[str] = None
|
embedding_api_key: Optional[str] = None
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ def get_embedding_engine() -> EmbeddingEngine:
|
||||||
|
|
||||||
return LiteLLMEmbeddingEngine(
|
return LiteLLMEmbeddingEngine(
|
||||||
# If OpenAI API is used for embeddings, litellm needs only the api_key.
|
# If OpenAI API is used for embeddings, litellm needs only the api_key.
|
||||||
|
provider=config.embedding_provider,
|
||||||
api_key=config.embedding_api_key or llm_config.llm_api_key,
|
api_key=config.embedding_api_key or llm_config.llm_api_key,
|
||||||
endpoint=config.embedding_endpoint,
|
endpoint=config.embedding_endpoint,
|
||||||
api_version=config.embedding_api_version,
|
api_version=config.embedding_api_version,
|
||||||
|
|
|
||||||
1
cognee/infrastructure/llm/tokenizer/Gemini/__init__.py
Normal file
1
cognee/infrastructure/llm/tokenizer/Gemini/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
from .adapter import GeminiTokenizer
|
||||||
44
cognee/infrastructure/llm/tokenizer/Gemini/adapter.py
Normal file
44
cognee/infrastructure/llm/tokenizer/Gemini/adapter.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
from ..tokenizer_interface import TokenizerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiTokenizer(TokenizerInterface):
    """Tokenizer adapter backed by Google's Gemini (``google.generativeai``) API.

    Unlike the TikToken/HuggingFace adapters, token counting happens via a
    remote API call, so every counting method requires network access and a
    configured API key.
    """

    def __init__(
        self,
        model: str,
        # NOTE: default is float("inf") (a float) despite the int annotation;
        # kept as-is because callers and the interface rely on "no limit" semantics.
        max_tokens: int = float("inf"),
    ):
        """Configure the Gemini SDK for the given model.

        Args:
            model: bare Gemini model name (provider prefix already stripped
                by the caller), e.g. "text-embedding-004".
            max_tokens: soft upper bound on tokens; float("inf") means unlimited.
        """
        self.model = model
        self.max_tokens = max_tokens

        # Imports are deferred so importing this module does not require
        # google.generativeai to be installed, and does not trigger config
        # loading as a module-level side effect.
        from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
        from cognee.infrastructure.llm.config import get_llm_config

        config = get_embedding_config()
        llm_config = get_llm_config()

        import google.generativeai as genai

        # Prefer the embedding-specific API key; fall back to the general LLM key.
        genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key)

    def extract_tokens(self, text: str) -> List[Any]:
        """Not supported: the Gemini API does not expose individual tokens."""
        raise NotImplementedError

    def num_tokens_from_text(self, text: str) -> int:
        """
        Returns the number of tokens in the given text.

        Args:
            text: str

        Returns:
            number of tokens in the given text

        BUG FIX: the previous implementation returned
        ``len(genai.embed_content(...))``, which is the number of keys in the
        embedding *response mapping* (effectively a constant), not a token
        count. Use the dedicated count_tokens endpoint instead.
        NOTE(review): assumes the configured model supports countTokens —
        verify for the embedding models used in production.
        """
        import google.generativeai as genai

        return genai.GenerativeModel(f"models/{self.model}").count_tokens(text).total_tokens

    def trim_text_to_max_tokens(self, text: str) -> str:
        """Not supported: trimming needs token offsets Gemini does not expose."""
        raise NotImplementedError
||||||
Loading…
Add table
Reference in a new issue