diff --git a/.env.template b/.env.template index b91f89540..b1e7057e2 100644 --- a/.env.template +++ b/.env.template @@ -30,6 +30,9 @@ EMBEDDING_DIMENSIONS=3072 EMBEDDING_MAX_TOKENS=8191 # If embedding key is not provided same key set for LLM_API_KEY will be used #EMBEDDING_API_KEY="your_api_key" +# Note: OpenAI supports up to 2048 elements and Gemini supports a maximum of 100 elements in an embedding batch, +# Cognee sets the optimal batch size for OpenAI and Gemini, but a custom size can be defined if necessary for other models +#EMBEDDING_BATCH_SIZE=2048 # If using BAML structured output these env variables will be used BAML_LLM_PROVIDER=openai @@ -52,18 +55,18 @@ BAML_LLM_API_VERSION="" ################################################################################ # Configure storage backend (local filesystem or S3) # STORAGE_BACKEND="local" # Default: uses local filesystem -# +# # -- To switch to S3 storage, uncomment and fill these: --------------------- # STORAGE_BACKEND="s3" # STORAGE_BUCKET_NAME="your-bucket-name" # AWS_REGION="us-east-1" # AWS_ACCESS_KEY_ID="your-access-key" # AWS_SECRET_ACCESS_KEY="your-secret-key" -# +# # -- S3 Root Directories (optional) ----------------------------------------- # DATA_ROOT_DIRECTORY="s3://your-bucket/cognee/data" # SYSTEM_ROOT_DIRECTORY="s3://your-bucket/cognee/system" -# +# # -- Cache Directory (auto-configured for S3) ------------------------------- # When STORAGE_BACKEND=s3, cache automatically uses S3: s3://BUCKET/cognee/cache # To override the automatic S3 cache location, uncomment: diff --git a/.github/actions/cognee_setup/action.yml b/.github/actions/cognee_setup/action.yml index e46a42edb..1326f2d81 100644 --- a/.github/actions/cognee_setup/action.yml +++ b/.github/actions/cognee_setup/action.yml @@ -41,4 +41,4 @@ runs: EXTRA_ARGS="$EXTRA_ARGS --extra $extra" done fi - uv sync --extra api --extra docs --extra evals --extra gemini --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS + uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS diff --git a/.github/workflows/test_llms.yml b/.github/workflows/test_llms.yml index 5a0f947c9..6b0221309 100644 --- a/.github/workflows/test_llms.yml +++ b/.github/workflows/test_llms.yml @@ -27,7 +27,7 @@ jobs: env: LLM_PROVIDER: "gemini" LLM_API_KEY: ${{ secrets.GEMINI_API_KEY }} - LLM_MODEL: "gemini/gemini-1.5-flash" + LLM_MODEL: "gemini/gemini-2.0-flash" EMBEDDING_PROVIDER: "gemini" EMBEDDING_API_KEY: ${{ secrets.GEMINI_API_KEY }} EMBEDDING_MODEL: "gemini/text-embedding-004" @@ -83,4 +83,4 @@ jobs: EMBEDDING_MODEL: "openai/text-embedding-3-large" EMBEDDING_DIMENSIONS: "3072" EMBEDDING_MAX_TOKENS: "8191" - run: uv run python ./examples/python/simple_example.py \ No newline at end of file + run: uv run python ./examples/python/simple_example.py diff --git a/Dockerfile b/Dockerfile index be29f359a..b1b77513d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,7 +31,7 @@ COPY README.md pyproject.toml uv.lock entrypoint.sh ./ # Install the project's dependencies using the lockfile and settings RUN --mount=type=cache,target=/root/.cache/uv \ - uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-install-project --no-dev --no-editable + uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --frozen
--no-install-project --no-dev --no-editable # Copy Alembic configuration COPY alembic.ini /app/alembic.ini @@ -42,7 +42,7 @@ COPY alembic/ /app/alembic COPY ./cognee /app/cognee COPY ./distributed /app/distributed RUN --mount=type=cache,target=/root/.cache/uv \ -uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable +uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable FROM python:3.12-slim-bookworm diff --git a/cognee-mcp/pyproject.toml b/cognee-mcp/pyproject.toml index e1a3a092c..e4fcab909 100644 --- a/cognee-mcp/pyproject.toml +++ b/cognee-mcp/pyproject.toml @@ -8,6 +8,7 @@ requires-python = ">=3.10" dependencies = [ # For local cognee repo usage remove comment bellow and add absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes. #"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee", + # TODO: Remove gemini from optional dependencies for new Cognee version after 0.3.4 "cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.3.4", "fastmcp>=2.10.0,<3.0.0", "mcp>=1.12.0,<2.0.0", diff --git a/cognee/infrastructure/databases/vector/embeddings/EmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/EmbeddingEngine.py index afa180c67..c410d99dc 100644 --- a/cognee/infrastructure/databases/vector/embeddings/EmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/EmbeddingEngine.py @@ -34,3 +34,12 @@ class EmbeddingEngine(Protocol): - int: An integer representing the number of dimensions in the embedding vector. """ raise NotImplementedError() + + def get_batch_size(self) -> int: + """ + Return the desired batch size for embedding calls. + + Returns: + - int: The batch size to use for embedding calls. + """ + raise NotImplementedError() diff --git a/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py index acb041e76..8f922d934 100644 --- a/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py @@ -42,11 +42,13 @@ class FastembedEmbeddingEngine(EmbeddingEngine): model: Optional[str] = "openai/text-embedding-3-large", dimensions: Optional[int] = 3072, max_completion_tokens: int = 512, + batch_size: int = 100, ): self.model = model self.dimensions = dimensions self.max_completion_tokens = max_completion_tokens self.tokenizer = self.get_tokenizer() + self.batch_size = batch_size # self.retry_count = 0 self.embedding_model = TextEmbedding(model_name=model) @@ -101,6 +103,15 @@ class FastembedEmbeddingEngine(EmbeddingEngine): """ return self.dimensions + def get_batch_size(self) -> int: + """ + Return the desired batch size for embedding calls. + + Returns: + - int: The batch size to use for embedding calls. + """ + return self.batch_size + def get_tokenizer(self): """ Instantiate and return the tokenizer used for preparing text for embedding.
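
For context, here is a minimal usage sketch of the new protocol method (not part of the diff): a caller reads the engine's batch size and chunks its inputs to stay under the provider limit. The helper name embed_in_batches and the texts argument are hypothetical; the sketch assumes an engine implementing the updated EmbeddingEngine protocol, including its async embed_text method.

async def embed_in_batches(engine, texts: list[str]) -> list[list[float]]:
    # Ask the engine how many elements one embedding request may carry
    # (per this PR: 2048 for OpenAI, 100 for Gemini and other providers).
    batch_size = engine.get_batch_size()
    embeddings: list[list[float]] = []
    for i in range(0, len(texts), batch_size):
        # One request per chunk, mirroring the loop in index_data_points.py below.
        embeddings.extend(await engine.embed_text(texts[i : i + batch_size]))
    return embeddings
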
diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index 0cb8286fc..d68941d25 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -58,6 +58,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): endpoint: str = None, api_version: str = None, max_completion_tokens: int = 512, + batch_size: int = 100, ): self.api_key = api_key self.endpoint = endpoint @@ -68,6 +69,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): self.max_completion_tokens = max_completion_tokens self.tokenizer = self.get_tokenizer() self.retry_count = 0 + self.batch_size = batch_size enable_mocking = os.getenv("MOCK_EMBEDDING", "false") if isinstance(enable_mocking, bool): @@ -165,6 +167,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): """ return self.dimensions + def get_batch_size(self) -> int: + """ + Return the desired batch size for embedding calls. + + Returns: + - int: The batch size to use for embedding calls. + """ + return self.batch_size + def get_tokenizer(self): """ Load and return the appropriate tokenizer for the specified model based on the provider. @@ -183,9 +194,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): model=model, max_completion_tokens=self.max_completion_tokens ) elif "gemini" in self.provider.lower(): - tokenizer = GeminiTokenizer( - model=model, max_completion_tokens=self.max_completion_tokens + # Since Gemini tokenization requires an API request to get the token count, we use TikToken to + # count tokens locally instead tokenizer = TikTokenTokenizer( + model=None, max_completion_tokens=self.max_completion_tokens ) + # Note: Gemini Tokenizer expects an LLM model as input and not the embedding model + # tokenizer = GeminiTokenizer( + # llm_model=llm_model, max_completion_tokens=self.max_completion_tokens + # ) elif "mistral" in self.provider.lower(): tokenizer = MistralTokenizer( model=model, max_completion_tokens=self.max_completion_tokens diff --git a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py index e6e590597..e79ba3f6a 100644 --- a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py @@ -54,12 +54,14 @@ class OllamaEmbeddingEngine(EmbeddingEngine): max_completion_tokens: int = 512, endpoint: Optional[str] = "http://localhost:11434/api/embeddings", huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral", + batch_size: int = 100, ): self.model = model self.dimensions = dimensions self.max_completion_tokens = max_completion_tokens self.endpoint = endpoint self.huggingface_tokenizer_name = huggingface_tokenizer + self.batch_size = batch_size self.tokenizer = self.get_tokenizer() enable_mocking = os.getenv("MOCK_EMBEDDING", "false") @@ -122,6 +124,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine): """ return self.dimensions + def get_batch_size(self) -> int: + """ + Return the desired batch size for embedding calls. + + Returns: + - int: The batch size to use for embedding calls. + """ + return self.batch_size + def get_tokenizer(self): """ Load and return a HuggingFace tokenizer for the embedding engine.
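
As an aside, a small sketch of the TikToken-based counting that this diff swaps in for the Gemini provider: token counting happens locally, avoiding the per-call network round-trip that GeminiTokenizer.count_tokens now requires. The cl100k_base encoding is an illustrative assumption, not necessarily the encoding TikTokenTokenizer selects when model=None.

import tiktoken

# Count tokens locally; unlike the deprecated GeminiTokenizer, no API request
# to Google is made for each call.
encoding = tiktoken.get_encoding("cl100k_base")

def count_tokens(text: str) -> int:
    return len(encoding.encode(text))

print(count_tokens("Cognee batches embedding calls by provider limits."))
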
diff --git a/cognee/infrastructure/databases/vector/embeddings/config.py b/cognee/infrastructure/databases/vector/embeddings/config.py index 04a1f18f2..24f724151 100644 --- a/cognee/infrastructure/databases/vector/embeddings/config.py +++ b/cognee/infrastructure/databases/vector/embeddings/config.py @@ -19,9 +19,17 @@ class EmbeddingConfig(BaseSettings): embedding_api_key: Optional[str] = None embedding_api_version: Optional[str] = None embedding_max_completion_tokens: Optional[int] = 8191 + embedding_batch_size: Optional[int] = None huggingface_tokenizer: Optional[str] = None model_config = SettingsConfigDict(env_file=".env", extra="allow") + def model_post_init(self, __context) -> None: + # If embedding batch size is not defined, use 2048 as the default for OpenAI and 100 for all other embedding models + if not self.embedding_batch_size and self.embedding_provider.lower() == "openai": + self.embedding_batch_size = 2048 + elif not self.embedding_batch_size: + self.embedding_batch_size = 100 + def to_dict(self) -> dict: """ Serialize all embedding configuration settings to a dictionary. diff --git a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py index e7fcf4e94..de56c8971 100644 --- a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +++ b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py @@ -31,6 +31,7 @@ def get_embedding_engine() -> EmbeddingEngine: config.embedding_endpoint, config.embedding_api_key, config.embedding_api_version, + config.embedding_batch_size, config.huggingface_tokenizer, llm_config.llm_api_key, llm_config.llm_provider, @@ -46,6 +47,7 @@ def create_embedding_engine( embedding_endpoint, embedding_api_key, embedding_api_version, + embedding_batch_size, huggingface_tokenizer, llm_api_key, llm_provider, @@ -84,6 +86,7 @@ def create_embedding_engine( model=embedding_model, dimensions=embedding_dimensions, max_completion_tokens=embedding_max_completion_tokens, + batch_size=embedding_batch_size, ) if embedding_provider == "ollama": @@ -95,6 +98,7 @@ def create_embedding_engine( max_completion_tokens=embedding_max_completion_tokens, endpoint=embedding_endpoint, huggingface_tokenizer=huggingface_tokenizer, + batch_size=embedding_batch_size, ) from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine @@ -108,4 +112,5 @@ def create_embedding_engine( model=embedding_model, dimensions=embedding_dimensions, max_completion_tokens=embedding_max_completion_tokens, + batch_size=embedding_batch_size, ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 2c7b73eef..510d29ce8 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -1,115 +1,155 @@ -import litellm -from pydantic import BaseModel -from typing import Type, Optional -from litellm import acompletion, JSONSchemaValidationError +"""Adapter for the Gemini LLM provider API""" -from cognee.shared.logging_utils import get_logger -from cognee.modules.observability.get_observe import get_observe -from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError +import litellm +import instructor +from typing import Type +from pydantic import BaseModel +from openai
import ContentFilterFinishReasonError +from litellm.exceptions import ContentPolicyViolationError +from instructor.core import InstructorRetryException + +from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( LLMInterface, ) -from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import ( rate_limit_async, sleep_and_retry_async, ) -logger = get_logger() -observe = get_observe() - class GeminiAdapter(LLMInterface): """ - Handles interactions with a language model API. + Adapter for Gemini API LLM provider. - Public methods include: - - acreate_structured_output - - show_prompt + This class initializes the API adapter with necessary credentials and configurations for + interacting with the gemini LLM models. It provides methods for creating structured outputs + based on user input and system prompts. + + Public methods: + - acreate_structured_output(text_input: str, system_prompt: str, response_model: + Type[BaseModel]) -> BaseModel """ - MAX_RETRIES = 5 + name: str + model: str + api_key: str def __init__( self, + endpoint, api_key: str, model: str, + api_version: str, max_completion_tokens: int, - endpoint: Optional[str] = None, - api_version: Optional[str] = None, - streaming: bool = False, - ) -> None: - self.api_key = api_key + fallback_model: str = None, + fallback_api_key: str = None, + fallback_endpoint: str = None, + ): self.model = model + self.api_key = api_key self.endpoint = endpoint self.api_version = api_version - self.streaming = streaming self.max_completion_tokens = max_completion_tokens - @observe(as_type="generation") + self.fallback_model = fallback_model + self.fallback_api_key = fallback_api_key + self.fallback_endpoint = fallback_endpoint + + self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) + @sleep_and_retry_async() @rate_limit_async async def acreate_structured_output( self, text_input: str, system_prompt: str, response_model: Type[BaseModel] ) -> BaseModel: """ - Generate structured output from the language model based on the provided input and - system prompt. + Generate a response from a user query. - This method handles retries and raises a ValueError if the request fails or the response - does not conform to the expected schema, logging errors accordingly. + This asynchronous method sends a user query and a system prompt to a language model and + retrieves the generated response. It handles API communication and retries up to a + specified limit in case of request failures. Parameters: ----------- - - text_input (str): The user input text to generate a response for. - - system_prompt (str): The system's prompt or context to influence the language - model's generation. - - response_model (Type[BaseModel]): A model type indicating the expected format of - the response. + - text_input (str): The input text from the user to generate a response for. + - system_prompt (str): A prompt that provides context or instructions for the + response generation. + - response_model (Type[BaseModel]): A Pydantic model that defines the structure of + the expected response. Returns: -------- - - BaseModel: Returns the generated response as an instance of the specified response - model. + - BaseModel: An instance of the specified response model containing the structured + output from the language model. 
""" + try: - if response_model is str: - response_schema = {"type": "string"} - else: - response_schema = response_model + return await self.aclient.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": f"""{text_input}""", + }, + { + "role": "system", + "content": system_prompt, + }, + ], + api_key=self.api_key, + max_retries=5, + api_base=self.endpoint, + api_version=self.api_version, + response_model=response_model, + ) + except ( + ContentFilterFinishReasonError, + ContentPolicyViolationError, + InstructorRetryException, + ) as error: + if ( + isinstance(error, InstructorRetryException) + and "content management policy" not in str(error).lower() + ): + raise error - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": text_input}, - ] - - try: - response = await acompletion( - model=f"{self.model}", - messages=messages, - api_key=self.api_key, - max_completion_tokens=self.max_completion_tokens, - temperature=0.1, - response_format=response_schema, - timeout=100, - num_retries=self.MAX_RETRIES, + if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint): + raise ContentPolicyFilterError( + f"The provided input contains content that is not aligned with our content policy: {text_input}" ) - if response.choices and response.choices[0].message.content: - content = response.choices[0].message.content - if response_model is str: - return content - return response_model.model_validate_json(content) - - except litellm.exceptions.BadRequestError as e: - logger.error(f"Bad request error: {str(e)}") - raise ValueError(f"Invalid request: {str(e)}") - - raise ValueError("Failed to get valid response after retries") - - except JSONSchemaValidationError as e: - logger.error(f"Schema validation failed: {str(e)}") - logger.debug(f"Raw response: {e.raw_response}") - raise ValueError(f"Response failed schema validation: {str(e)}") + try: + return await self.aclient.chat.completions.create( + model=self.fallback_model, + messages=[ + { + "role": "user", + "content": f"""{text_input}""", + }, + { + "role": "system", + "content": system_prompt, + }, + ], + max_retries=5, + api_key=self.fallback_api_key, + api_base=self.fallback_endpoint, + response_model=response_model, + ) + except ( + ContentFilterFinishReasonError, + ContentPolicyViolationError, + InstructorRetryException, + ) as error: + if ( + isinstance(error, InstructorRetryException) + and "content management policy" not in str(error).lower() + ): + raise error + else: + raise ContentPolicyFilterError( + f"The provided input contains content that is not aligned with our content policy: {text_input}" + ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 86adac25a..8c15a5804 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -6,7 +6,7 @@ from typing import Type from pydantic import BaseModel from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError -from instructor.exceptions import InstructorRetryException +from instructor.core import InstructorRetryException from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError from 
cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( @@ -56,9 +56,7 @@ class GenericAPIAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key - ) + self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) @sleep_and_retry_async() @rate_limit_async @@ -102,6 +100,7 @@ class GenericAPIAdapter(LLMInterface): }, ], max_retries=5, + api_key=self.api_key, api_base=self.endpoint, response_model=response_model, ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index 049b27e04..0ae621428 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -143,7 +143,6 @@ def get_llm_client(raise_api_key_error: bool = True): max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, api_version=llm_config.llm_api_version, - streaming=llm_config.llm_streaming, ) else: diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 7273de976..989b240ac 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -5,15 +5,13 @@ from typing import Type from pydantic import BaseModel from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError -from instructor.exceptions import InstructorRetryException +from instructor.core import InstructorRetryException -from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( LLMInterface, ) from cognee.infrastructure.llm.exceptions import ( ContentPolicyFilterError, - MissingSystemPromptPathError, ) from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import ( diff --git a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py index a57cff3f7..d595fba2d 100644 --- a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py @@ -3,6 +3,7 @@ from typing import List, Any from ..tokenizer_interface import TokenizerInterface +# NOTE: DEPRECATED: counting tokens requires an API request to Google, which is too slow to use with Cognee class GeminiTokenizer(TokenizerInterface): """ Implements a tokenizer interface for the Gemini model, managing token extraction and @@ -16,10 +17,10 @@ def __init__( self, - model: str, + llm_model: str, max_completion_tokens: int = 3072, ): - self.model = model + self.llm_model = llm_model self.max_completion_tokens = max_completion_tokens # Get LLM API key from config @@ -28,12 +29,11 @@ get_llm_config, ) - config =
get_embedding_config() llm_config = get_llm_config() - import google.generativeai as genai + from google import genai - genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key) + self.client = genai.Client(api_key=llm_config.llm_api_key) def extract_tokens(self, text: str) -> List[Any]: """ @@ -77,6 +77,7 @@ class GeminiTokenizer(TokenizerInterface): - int: The number of tokens in the given text. """ - import google.generativeai as genai - return len(genai.embed_content(model=f"models/{self.model}", content=text)) + tokens_response = self.client.models.count_tokens(model=self.llm_model, contents=text) + + return tokens_response.total_tokens diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index 9c363c04c..362412657 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -39,7 +39,7 @@ async def index_data_points(data_points: list[DataPoint]): field_name = index_name_and_field[first_occurence + 1 :] try: # In case the amount of indexable points is too large we need to send them in batches - batch_size = 100 + batch_size = vector_engine.embedding_engine.get_batch_size() for i in range(0, len(indexable_points), batch_size): batch = indexable_points[i : i + batch_size] await vector_engine.index_data_points(index_name, field_name, batch) diff --git a/cognee/tasks/storage/index_graph_edges.py b/cognee/tasks/storage/index_graph_edges.py index 2233ab99f..f4e32f8c8 100644 --- a/cognee/tasks/storage/index_graph_edges.py +++ b/cognee/tasks/storage/index_graph_edges.py @@ -9,7 +9,7 @@ from cognee.modules.graph.models.EdgeType import EdgeType logger = get_logger(level=ERROR) -async def index_graph_edges(batch_size: int = 1024): +async def index_graph_edges(): """ Indexes graph edges by creating and managing vector indexes for relationship types. @@ -72,6 +72,8 @@ async def index_graph_edges(batch_size: int = 1024): for index_name, indexable_points in index_points.items(): index_name, field_name = index_name.split(".") + # Get maximum batch size for embedding model + batch_size = vector_engine.embedding_engine.get_batch_size() # We save the data in batches of {batch_size} to not put a lot of pressure on the database for start in range(0, len(indexable_points), batch_size): batch = indexable_points[start : start + batch_size] diff --git a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py index 9cd96f5b9..48bbc53e3 100644 --- a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +++ b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, patch, MagicMock from cognee.tasks.storage.index_graph_edges import index_graph_edges @@ -16,6 +16,7 @@ async def test_index_graph_edges_success(): ], ) mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100) # Patch the globals of the function so that when it does: # vector_engine = get_vector_engine() diff --git a/poetry.lock b/poetry.lock index f873ca5a0..85dc8c879 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. 
[[package]] name = "accelerate" @@ -878,7 +878,7 @@ description = "Extensible memoizing collections and decorators" optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"gemini\" or extra == \"docs\" or extra == \"deepeval\" or extra == \"chromadb\"" +markers = "extra == \"deepeval\" or extra == \"chromadb\" or extra == \"docs\"" files = [ {file = "cachetools-6.2.0-py3-none-any.whl", hash = "sha256:1c76a8960c0041fcc21097e357f882197c79da0dbff766e7317890a65d7d8ba6"}, {file = "cachetools-6.2.0.tar.gz", hash = "sha256:38b328c0889450f05f5e120f56ab68c8abaf424e1275522b138ffc93253f7e32"}, @@ -2749,28 +2749,6 @@ files = [ {file = "giturlparse-0.12.0.tar.gz", hash = "sha256:c0fff7c21acc435491b1779566e038757a205c1ffdcb47e4f81ea52ad8c3859a"}, ] -[[package]] -name = "google-ai-generativelanguage" -version = "0.6.15" -description = "Google Ai Generativelanguage API client library" -optional = true -python-versions = ">=3.7" -groups = ["main"] -markers = "extra == \"gemini\"" -files = [ - {file = "google_ai_generativelanguage-0.6.15-py3-none-any.whl", hash = "sha256:5a03ef86377aa184ffef3662ca28f19eeee158733e45d7947982eb953c6ebb6c"}, - {file = "google_ai_generativelanguage-0.6.15.tar.gz", hash = "sha256:8f6d9dc4c12b065fe2d0289026171acea5183ebf2d0b11cefe12f3821e159ec3"}, -] - -[package.dependencies] -google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} -google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" -proto-plus = [ - {version = ">=1.22.3,<2.0.0dev"}, - {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, -] -protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" - [[package]] name = "google-api-core" version = "2.25.1" @@ -2778,7 +2756,7 @@ description = "Google API client core library" optional = true python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"gemini\" or extra == \"docs\"" +markers = "extra == \"docs\"" files = [ {file = "google_api_core-2.25.1-py3-none-any.whl", hash = "sha256:8a2a56c1fef82987a524371f99f3bd0143702fecc670c72e600c1cda6bf8dbb7"}, {file = "google_api_core-2.25.1.tar.gz", hash = "sha256:d2aaa0b13c78c61cb3f4282c464c046e45fbd75755683c9c525e6e8f7ed0a5e8"}, @@ -2808,26 +2786,6 @@ grpc = ["grpcio (>=1.33.2,<2.0.0)", "grpcio (>=1.49.1,<2.0.0) ; python_version > grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.0)"] grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.0)"] -[[package]] -name = "google-api-python-client" -version = "2.183.0" -description = "Google API Client Library for Python" -optional = true -python-versions = ">=3.7" -groups = ["main"] -markers = "extra == \"gemini\"" -files = [ - {file = "google_api_python_client-2.183.0-py3-none-any.whl", hash = "sha256:2005b6e86c27be1db1a43f43e047a0f8e004159f3cceddecb08cf1624bddba31"}, - {file = "google_api_python_client-2.183.0.tar.gz", hash = "sha256:abae37e04fecf719388e5c02f707ed9cdf952f10b217c79a3e76c636762e3ea9"}, -] - -[package.dependencies] -google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0" -google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0" -google-auth-httplib2 = ">=0.2.0,<1.0.0" -httplib2 = ">=0.19.0,<1.0.0" -uritemplate = ">=3.0.1,<5" - [[package]] name = "google-auth" version = "2.41.0" @@ -2835,7 +2793,7 @@ description = "Google Authentication Library" optional = true python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"gemini\" or extra == \"docs\" or extra == \"deepeval\" or extra == 
\"chromadb\"" +markers = "extra == \"deepeval\" or extra == \"chromadb\" or extra == \"docs\"" files = [ {file = "google_auth-2.41.0-py2.py3-none-any.whl", hash = "sha256:d8bed9b53ab63b7b0374656b8e1bef051f95bb14ecc0cf21ba49de7911d62e09"}, {file = "google_auth-2.41.0.tar.gz", hash = "sha256:c9d7b534ea4a5d9813c552846797fafb080312263cd4994d6622dd50992ae101"}, @@ -2856,23 +2814,6 @@ requests = ["requests (>=2.20.0,<3.0.0)"] testing = ["aiohttp (<3.10.0)", "aiohttp (>=3.6.2,<4.0.0)", "aioresponses", "cryptography (<39.0.0) ; python_version < \"3.8\"", "cryptography (<39.0.0) ; python_version < \"3.8\"", "cryptography (>=38.0.3)", "cryptography (>=38.0.3)", "flask", "freezegun", "grpcio", "mock", "oauth2client", "packaging", "pyjwt (>=2.0)", "pyopenssl (<24.3.0)", "pyopenssl (>=20.0.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-localserver", "pyu2f (>=0.1.5)", "requests (>=2.20.0,<3.0.0)", "responses", "urllib3"] urllib3 = ["packaging", "urllib3"] -[[package]] -name = "google-auth-httplib2" -version = "0.2.0" -description = "Google Authentication Library: httplib2 transport" -optional = true -python-versions = "*" -groups = ["main"] -markers = "extra == \"gemini\"" -files = [ - {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, - {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"}, -] - -[package.dependencies] -google-auth = "*" -httplib2 = ">=0.19.0" - [[package]] name = "google-cloud-vision" version = "3.10.2" @@ -2922,31 +2863,6 @@ websockets = ">=13.0.0,<15.1.0" aiohttp = ["aiohttp (<4.0.0)"] local-tokenizer = ["protobuf", "sentencepiece (>=0.2.0)"] -[[package]] -name = "google-generativeai" -version = "0.8.5" -description = "Google Generative AI High level API client library and tools." 
-optional = true -python-versions = ">=3.9" -groups = ["main"] -markers = "extra == \"gemini\"" -files = [ - {file = "google_generativeai-0.8.5-py3-none-any.whl", hash = "sha256:22b420817fb263f8ed520b33285f45976d5b21e904da32b80d4fd20c055123a2"}, -] - -[package.dependencies] -google-ai-generativelanguage = "0.6.15" -google-api-core = "*" -google-api-python-client = "*" -google-auth = ">=2.15.0" -protobuf = "*" -pydantic = "*" -tqdm = "*" -typing-extensions = "*" - -[package.extras] -dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] - [[package]] name = "googleapis-common-protos" version = "1.70.0" @@ -2954,7 +2870,7 @@ description = "Common protobufs used in Google APIs" optional = true python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"gemini\" or extra == \"docs\" or extra == \"deepeval\" or extra == \"chromadb\"" +markers = "extra == \"docs\" or extra == \"deepeval\" or extra == \"chromadb\"" files = [ {file = "googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8"}, {file = "googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257"}, @@ -3104,7 +3020,7 @@ description = "HTTP/2-based RPC framework" optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"gemini\" or extra == \"docs\" or extra == \"deepeval\" or extra == \"chromadb\"" +markers = "extra == \"docs\" or extra == \"deepeval\" or extra == \"chromadb\"" files = [ {file = "grpcio-1.75.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:1712b5890b22547dd29f3215c5788d8fc759ce6dd0b85a6ba6e2731f2d04c088"}, {file = "grpcio-1.75.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:8d04e101bba4b55cea9954e4aa71c24153ba6182481b487ff376da28d4ba46cf"}, @@ -3182,7 +3098,7 @@ description = "Status proto mapping for gRPC" optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"gemini\" or extra == \"docs\"" +markers = "extra == \"docs\"" files = [ {file = "grpcio_status-1.71.2-py3-none-any.whl", hash = "sha256:803c98cb6a8b7dc6dbb785b1111aed739f241ab5e9da0bba96888aa74704cfd3"}, {file = "grpcio_status-1.71.2.tar.gz", hash = "sha256:c7a97e176df71cdc2c179cd1847d7fc86cca5832ad12e9798d7fed6b7a1aab50"}, @@ -3374,22 +3290,6 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<1.0)"] -[[package]] -name = "httplib2" -version = "0.31.0" -description = "A comprehensive HTTP client library." 
-optional = true -python-versions = ">=3.6" -groups = ["main"] -markers = "extra == \"gemini\"" -files = [ - {file = "httplib2-0.31.0-py3-none-any.whl", hash = "sha256:b9cd78abea9b4e43a7714c6e0f8b6b8561a6fc1e95d5dbd367f5bf0ef35f5d24"}, - {file = "httplib2-0.31.0.tar.gz", hash = "sha256:ac7ab497c50975147d4f7b1ade44becc7df2f8954d42b38b3d69c515f531135c"}, -] - -[package.dependencies] -pyparsing = ">=3.0.4,<4" - [[package]] name = "httptools" version = "0.6.4" @@ -4071,8 +3971,6 @@ groups = ["main"] markers = "extra == \"dlt\"" files = [ {file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"}, - {file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"}, - {file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"}, ] [package.dependencies] @@ -8184,7 +8082,7 @@ description = "Beautiful, Pythonic protocol buffers" optional = true python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"gemini\" or extra == \"docs\"" +markers = "extra == \"docs\"" files = [ {file = "proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66"}, {file = "proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012"}, @@ -8256,7 +8154,6 @@ files = [ {file = "psycopg2-2.9.10-cp311-cp311-win_amd64.whl", hash = "sha256:0435034157049f6846e95103bd8f5a668788dd913a7c30162ca9503fdf542cb4"}, {file = "psycopg2-2.9.10-cp312-cp312-win32.whl", hash = "sha256:65a63d7ab0e067e2cdb3cf266de39663203d38d6a8ed97f5ca0cb315c73fe067"}, {file = "psycopg2-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:4a579d6243da40a7b3182e0430493dbd55950c493d8c68f4eec0b302f6bbf20e"}, - {file = "psycopg2-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:91fd603a2155da8d0cfcdbf8ab24a2d54bca72795b90d2a3ed2b6da8d979dee2"}, {file = "psycopg2-2.9.10-cp39-cp39-win32.whl", hash = "sha256:9d5b3b94b79a844a986d029eee38998232451119ad653aea42bb9220a8c5066b"}, {file = "psycopg2-2.9.10-cp39-cp39-win_amd64.whl", hash = "sha256:88138c8dedcbfa96408023ea2b0c369eda40fe5d75002c0964c78f46f11fa442"}, {file = "psycopg2-2.9.10.tar.gz", hash = "sha256:12ec0b40b0273f95296233e8750441339298e6a572f7039da5b260e3c8b60e11"}, @@ -8318,7 +8215,6 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, - {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -8528,7 +8424,7 @@ description = 
"Pure-Python implementation of ASN.1 types and DER/BER/CER codecs optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"gemini\" or extra == \"docs\" or extra == \"deepeval\" or extra == \"chromadb\"" +markers = "extra == \"deepeval\" or extra == \"chromadb\" or extra == \"docs\"" files = [ {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, @@ -8541,7 +8437,7 @@ description = "A collection of ASN.1-based protocols modules" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"gemini\" or extra == \"docs\" or extra == \"deepeval\" or extra == \"chromadb\"" +markers = "extra == \"deepeval\" or extra == \"chromadb\" or extra == \"docs\"" files = [ {file = "pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a"}, {file = "pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6"}, @@ -9430,13 +9326,6 @@ optional = false python-versions = ">=3.8" groups = ["main"] files = [ - {file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"}, - {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"}, - {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"}, - {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"}, - {file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"}, - {file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"}, - {file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"}, {file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"}, @@ -10249,7 +10138,7 @@ description = "Pure-Python RSA implementation" optional = true python-versions = "<4,>=3.6" groups = ["main"] -markers = "extra == \"gemini\" or extra == \"docs\" or extra == \"deepeval\" or extra == \"chromadb\"" +markers = "extra == \"deepeval\" or extra == \"chromadb\" or extra == \"docs\"" files = [ {file = "rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762"}, {file = "rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75"}, @@ -12003,19 +11892,6 @@ files = [ [package.extras] dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", 
"flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"] -[[package]] -name = "uritemplate" -version = "4.2.0" -description = "Implementation of RFC 6570 URI Templates" -optional = true -python-versions = ">=3.9" -groups = ["main"] -markers = "extra == \"gemini\"" -files = [ - {file = "uritemplate-4.2.0-py3-none-any.whl", hash = "sha256:962201ba1c4edcab02e60f9a0d3821e82dfc5d2d6662a21abd533879bdb8a686"}, - {file = "uritemplate-4.2.0.tar.gz", hash = "sha256:480c2ed180878955863323eea31b0ede668795de182617fef9c6ca09e6ec9d0e"}, -] - [[package]] name = "urllib3" version = "2.3.0" @@ -12882,7 +12758,6 @@ dlt = ["dlt"] docs = ["unstructured"] evals = ["gdown", "matplotlib", "pandas", "plotly", "scikit-learn"] falkordb = ["falkordb"] -gemini = ["google-generativeai"] graphiti = ["graphiti-core"] groq = ["groq"] huggingface = ["transformers"] @@ -12901,4 +12776,4 @@ posthog = ["posthog"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<=3.13" -content-hash = "8c5ebda99705d8bcaf9d14e244a556d2627106d58d0fdce25318b4c1a647197a" +content-hash = "c76267fe685339b5b5665342c81850a3e891cadaf760178bf3b04058f35b1014" diff --git a/pyproject.toml b/pyproject.toml index e2c523a2d..572e9287d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,6 @@ langchain = [ "langchain_text_splitters>=0.3.2,<1.0.0", ] llama-index = ["llama-index-core>=0.12.11,<0.13"] -gemini = ["google-generativeai>=0.8.4,<0.9"] huggingface = ["transformers>=4.46.3,<5"] ollama = ["transformers>=4.46.3,<5"] mistral = ["mistral-common>=1.5.2,<2"] diff --git a/uv.lock b/uv.lock index 4e46e1d56..a36416ce5 100644 --- a/uv.lock +++ b/uv.lock @@ -962,9 +962,6 @@ evals = [ falkordb = [ { name = "falkordb" }, ] -gemini = [ - { name = "google-generativeai" }, -] graphiti = [ { name = "graphiti-core" }, ] @@ -1038,7 +1035,6 @@ requires-dist = [ { name = "filetype", specifier = ">=1.2.0,<2.0.0" }, { name = "gdown", marker = "extra == 'evals'", specifier = ">=5.2.0,<6" }, { name = "gitpython", marker = "extra == 'dev'", specifier = ">=3.1.43,<4" }, - { name = "google-generativeai", marker = "extra == 'gemini'", specifier = ">=0.8.4,<0.9" }, { name = "graphiti-core", marker = "extra == 'graphiti'", specifier = ">=0.7.0,<0.8" }, { name = "groq", marker = "extra == 'groq'", specifier = ">=0.8.0,<1.0.0" }, { name = "gunicorn", specifier = ">=20.1.0,<24" }, @@ -1108,7 +1104,7 @@ requires-dist = [ { name = "uvicorn", specifier = ">=0.34.0,<1.0.0" }, { name = "websockets", specifier = ">=15.0.1,<16.0.0" }, ] -provides-extras = ["api", "distributed", "neo4j", "neptune", "postgres", "postgres-binary", "notebook", "langchain", "llama-index", "gemini", "huggingface", "ollama", "mistral", "anthropic", "deepeval", "posthog", "falkordb", "groq", "chromadb", "docs", "codegraph", "evals", "graphiti", "aws", "dlt", "baml", "dev", "debug", "monitoring"] +provides-extras = ["api", "distributed", "neo4j", "neptune", "postgres", "postgres-binary", "notebook", "langchain", "llama-index", "huggingface", "ollama", "mistral", "anthropic", "deepeval", "posthog", "falkordb", "groq", "chromadb", "docs", "codegraph", "evals", "graphiti", "aws", "dlt", "baml", "dev", "debug", "monitoring"] [[package]] name = "colorama" @@ -2176,21 +2172,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/dd/94/c6ff3388b8e3225a014e55aed957188639aa0966443e0408d38f0c9614a7/giturlparse-0.12.0-py2.py3-none-any.whl", hash = "sha256:412b74f2855f1da2fefa89fd8dde62df48476077a72fc19b62039554d27360eb", size = 15752, upload-time = "2023-09-24T07:22:35.465Z" }, ] -[[package]] -name = "google-ai-generativelanguage" -version = "0.6.15" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core", extra = ["grpc"] }, - { name = "google-auth" }, - { name = "proto-plus" }, - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/11/d1/48fe5d7a43d278e9f6b5ada810b0a3530bbeac7ed7fcbcd366f932f05316/google_ai_generativelanguage-0.6.15.tar.gz", hash = "sha256:8f6d9dc4c12b065fe2d0289026171acea5183ebf2d0b11cefe12f3821e159ec3", size = 1375443, upload-time = "2025-01-13T21:50:47.459Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7c/a3/67b8a6ff5001a1d8864922f2d6488dc2a14367ceb651bc3f09a947f2f306/google_ai_generativelanguage-0.6.15-py3-none-any.whl", hash = "sha256:5a03ef86377aa184ffef3662ca28f19eeee158733e45d7947982eb953c6ebb6c", size = 1327356, upload-time = "2025-01-13T21:50:44.174Z" }, -] - [[package]] name = "google-api-core" version = "2.25.1" @@ -2213,22 +2194,6 @@ grpc = [ { name = "grpcio-status" }, ] -[[package]] -name = "google-api-python-client" -version = "2.183.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, - { name = "google-auth-httplib2" }, - { name = "httplib2" }, - { name = "uritemplate" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fa/1f/49a2c83fc6dcd8b127cc9efbecf7d5fc36109c2028ba22ed6cb4d072fca4/google_api_python_client-2.183.0.tar.gz", hash = "sha256:abae37e04fecf719388e5c02f707ed9cdf952f10b217c79a3e76c636762e3ea9", size = 13645623, upload-time = "2025-09-23T22:27:00.854Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ab/06/1974f937172854bc7622eff5c2390f33542ceb843f305922922c8f5f7f17/google_api_python_client-2.183.0-py3-none-any.whl", hash = "sha256:2005b6e86c27be1db1a43f43e047a0f8e004159f3cceddecb08cf1624bddba31", size = 14214837, upload-time = "2025-09-23T22:26:57.758Z" }, -] - [[package]] name = "google-auth" version = "2.41.0" @@ -2243,19 +2208,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/ff/a1c426fc9bea7268230bf92340da7d112fae18cf946cafe13ab17d14e6ee/google_auth-2.41.0-py2.py3-none-any.whl", hash = "sha256:d8bed9b53ab63b7b0374656b8e1bef051f95bb14ecc0cf21ba49de7911d62e09", size = 221168, upload-time = "2025-09-29T21:36:33.925Z" }, ] -[[package]] -name = "google-auth-httplib2" -version = "0.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-auth" }, - { name = "httplib2" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/56/be/217a598a818567b28e859ff087f347475c807a5649296fb5a817c58dacef/google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05", size = 10842, upload-time = "2023-12-12T17:40:30.722Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/be/8a/fe34d2f3f9470a27b01c9e76226965863f153d5fbe276f83608562e49c04/google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d", size = 9253, upload-time = "2023-12-12T17:40:13.055Z" }, -] - [[package]] name = "google-cloud-vision" version = "3.10.2" @@ -2290,24 +2242,6 @@ wheels = [ { 
url = "https://files.pythonhosted.org/packages/cb/c3/12c1f386184d2fcd694b73adeabc3714a5ed65c01cc183b4e3727a26b9d1/google_genai-1.39.1-py3-none-any.whl", hash = "sha256:6ca36c7e40db6fcba7049dfdd102c86da326804f34403bd7d90fa613a45e5a78", size = 244681, upload-time = "2025-09-26T20:56:17.527Z" }, ] -[[package]] -name = "google-generativeai" -version = "0.8.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-ai-generativelanguage" }, - { name = "google-api-core" }, - { name = "google-api-python-client" }, - { name = "google-auth" }, - { name = "protobuf" }, - { name = "pydantic" }, - { name = "tqdm" }, - { name = "typing-extensions" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/6e/40/c42ff9ded9f09ec9392879a8e6538a00b2dc185e834a3392917626255419/google_generativeai-0.8.5-py3-none-any.whl", hash = "sha256:22b420817fb263f8ed520b33285f45976d5b21e904da32b80d4fd20c055123a2", size = 155427, upload-time = "2025-04-17T00:40:00.67Z" }, -] - [[package]] name = "googleapis-common-protos" version = "1.70.0" @@ -2591,18 +2525,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, ] -[[package]] -name = "httplib2" -version = "0.31.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyparsing" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/52/77/6653db69c1f7ecfe5e3f9726fdadc981794656fcd7d98c4209fecfea9993/httplib2-0.31.0.tar.gz", hash = "sha256:ac7ab497c50975147d4f7b1ade44becc7df2f8954d42b38b3d69c515f531135c", size = 250759, upload-time = "2025-09-11T12:16:03.403Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8c/a2/0d269db0f6163be503775dc8b6a6fa15820cc9fdc866f6ba608d86b721f2/httplib2-0.31.0-py3-none-any.whl", hash = "sha256:b9cd78abea9b4e43a7714c6e0f8b6b8561a6fc1e95d5dbd367f5bf0ef35f5d24", size = 91148, upload-time = "2025-09-11T12:16:01.803Z" }, -] - [[package]] name = "httptools" version = "0.6.4" @@ -4869,7 +4791,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform != 'emscripten'" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -4880,7 +4802,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'emscripten'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux')" }, ] wheels = [ { url = 
"https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -4907,9 +4829,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform != 'emscripten'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform != 'emscripten'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'emscripten'" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -4920,7 +4842,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'emscripten'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -8280,7 +8202,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools", marker = "sys_platform != 'emscripten'" }, + { name = "setuptools", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" }, @@ -8565,15 +8487,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/00/3fca040d7cf8a32776d3d81a00c8ee7457e00f80c649f1e4a863c8321ae9/uri_template-1.3.0-py3-none-any.whl", hash = "sha256:a44a133ea12d44a0c0f06d7d42a52d71282e77e2f937d8abd5655b8d56fc1363", size = 11140, upload-time = "2023-06-21T01:49:03.467Z" }, ] -[[package]] -name = "uritemplate" -version = "4.2.0" 
-source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/98/60/f174043244c5306c9988380d2cb10009f91563fc4b31293d27e17201af56/uritemplate-4.2.0.tar.gz", hash = "sha256:480c2ed180878955863323eea31b0ede668795de182617fef9c6ca09e6ec9d0e", size = 33267, upload-time = "2025-06-02T15:12:06.318Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a9/99/3ae339466c9183ea5b8ae87b34c0b897eda475d2aec2307cae60e5cd4f29/uritemplate-4.2.0-py3-none-any.whl", hash = "sha256:962201ba1c4edcab02e60f9a0d3821e82dfc5d2d6662a21abd533879bdb8a686", size = 11488, upload-time = "2025-06-02T15:12:03.405Z" }, -] - [[package]] name = "urllib3" version = "2.3.0"