diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
index ccc76cbcf..4623c1951 100644
--- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
+++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -26,6 +26,8 @@ from .neo4j_metrics_utils import (
     get_size_of_connected_components,
     count_self_loops,
 )
 
+from .deadlock_retry import deadlock_retry
+
 
 logger = get_logger("Neo4jAdapter", level=ERROR)
@@ -49,19 +51,16 @@ class Neo4jAdapter(GraphDBInterface):
         async with self.driver.session() as session:
             yield session
 
+    @deadlock_retry()
     async def query(
         self,
         query: str,
         params: Optional[Dict[str, Any]] = None,
     ) -> List[Dict[str, Any]]:
-        try:
-            async with self.get_session() as session:
-                result = await session.run(query, parameters=params)
-                data = await result.data()
-                return data
-        except Neo4jError as error:
-            logger.error("Neo4j query error: %s", error, exc_info=True)
-            raise error
+        async with self.get_session() as session:
+            result = await session.run(query, parameters=params)
+            data = await result.data()
+            return data
 
     async def has_node(self, node_id: str) -> bool:
         results = self.query(
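Note: with the decorator in place, callers of `query` no longer need their own `Neo4jError` handling. A rough usage sketch (hypothetical code; `adapter` is assumed to be an already-initialised `Neo4jAdapter`):

```python
import asyncio


# Hypothetical bulk-ingest helper. Concurrent MERGEs over overlapping nodes are a
# common source of Neo.TransientError.DeadlockDetected; the decorated query method
# now retries those with backoff instead of failing on the first deadlock.
async def ingest_edges(adapter, edges):
    await asyncio.gather(
        *(
            adapter.query(
                "MERGE (a:Node {id: $source}) "
                "MERGE (b:Node {id: $target}) "
                "MERGE (a)-[:RELATES_TO]->(b)",
                params={"source": source, "target": target},
            )
            for source, target in edges
        )
    )
```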
diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/deadlock_retry.py b/cognee/infrastructure/databases/graph/neo4j_driver/deadlock_retry.py
new file mode 100644
index 000000000..c1591ed92
--- /dev/null
+++ b/cognee/infrastructure/databases/graph/neo4j_driver/deadlock_retry.py
@@ -0,0 +1,62 @@
+import asyncio
+from functools import wraps
+
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
+
+
+logger = get_logger("deadlock_retry")
+
+
+def deadlock_retry(max_retries=5):
+    """
+    Decorator that automatically retries an asynchronous function when Neo4j reports
+    a deadlock or another transient error.
+
+    The decorator implements an exponential backoff strategy with jitter
+    so that competing transactions can clear before the next attempt.
+
+    Args:
+        max_retries: Maximum number of retry attempts.
+
+    Returns:
+        The decorated async function.
+    """
+
+    def decorator(func):
+        @wraps(func)
+        async def wrapper(*args, **kwargs):
+            from neo4j.exceptions import Neo4jError, DatabaseUnavailable
+
+            attempt = 0
+
+            async def wait():
+                backoff_time = calculate_backoff(attempt)
+                logger.warning(
+                    f"Neo4j transient error (deadlock) detected, retrying in {backoff_time:.2f}s. "
+                    f"Attempt {attempt}/{max_retries}"
+                )
+                await asyncio.sleep(backoff_time)
+
+            while attempt <= max_retries:
+                try:
+                    attempt += 1
+                    return await func(*args, **kwargs)
+                except Neo4jError as error:
+                    if attempt > max_retries:
+                        raise  # Re-raise the original error
+
+                    error_str = str(error)
+                    if "DeadlockDetected" in error_str or "Neo.TransientError" in error_str:
+                        await wait()
+                    else:
+                        raise  # Re-raise the original error
+                except DatabaseUnavailable as error:
+                    if attempt >= max_retries:
+                        raise  # Re-raise the original error
+
+                    await wait()
+
+        return wrapper
+
+    return decorator
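Note: a minimal sketch of applying the decorator outside the adapter (the decorated function and Cypher are hypothetical). Only errors whose message contains "DeadlockDetected" or "Neo.TransientError" are retried; anything else is re-raised immediately.

```python
from cognee.infrastructure.databases.graph.neo4j_driver.deadlock_retry import deadlock_retry


@deadlock_retry(max_retries=3)
async def merge_person(session, name: str):
    # Hypothetical write; concurrent MERGEs on the same node can raise
    # Neo.TransientError.DeadlockDetected, which the decorator catches and retries.
    result = await session.run("MERGE (p:Person {name: $name})", name=name)
    return await result.consume()


# With the calculate_backoff defaults (initial_backoff=1.0, backoff_factor=2.0, jitter=0.1)
# the three retries sleep roughly 2s, 4s and 8s (each +/- 10%) before the error is re-raised.
```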
diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py
index f31e18308..155179570 100644
--- a/cognee/infrastructure/llm/config.py
+++ b/cognee/infrastructure/llm/config.py
@@ -23,6 +23,10 @@ class LLMConfig(BaseSettings):
     embedding_rate_limit_requests: int = 60
     embedding_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)
 
+    fallback_api_key: str = None
+    fallback_endpoint: str = None
+    fallback_model: str = None
+
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
     @model_validator(mode="after")
@@ -97,6 +101,9 @@ class LLMConfig(BaseSettings):
             "embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
             "embedding_rate_limit_requests": self.embedding_rate_limit_requests,
             "embedding_rate_limit_interval": self.embedding_rate_limit_interval,
+            "fallback_api_key": self.fallback_api_key,
+            "fallback_endpoint": self.fallback_endpoint,
+            "fallback_model": self.fallback_model,
         }
 
 
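Note: a sketch of how the new fallback settings could be supplied. The environment-variable names are an assumption based on pydantic-settings' default field-name mapping (no `env_prefix` is configured on `SettingsConfigDict`):

```python
# Hypothetical environment values; pydantic-settings maps FALLBACK_API_KEY,
# FALLBACK_ENDPOINT and FALLBACK_MODEL onto the new fields (they can also live
# in the .env file the config already reads).
#   FALLBACK_API_KEY=sk-...
#   FALLBACK_ENDPOINT=https://fallback.example.com/v1
#   FALLBACK_MODEL=gpt-4o-mini

from cognee.infrastructure.llm.config import get_llm_config

llm_config = get_llm_config()
fallback_enabled = bool(llm_config.fallback_model and llm_config.fallback_api_key)
```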
diff --git a/cognee/infrastructure/llm/exceptions.py b/cognee/infrastructure/llm/exceptions.py
new file mode 100644
index 000000000..af3aa5832
--- /dev/null
+++ b/cognee/infrastructure/llm/exceptions.py
@@ -0,0 +1,5 @@
+from cognee.exceptions.exceptions import CriticalError
+
+
+class ContentPolicyFilterError(CriticalError):
+    pass
diff --git a/cognee/infrastructure/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py
index 07085c076..5268626c7 100644
--- a/cognee/infrastructure/llm/generic_llm_api/adapter.py
+++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py
@@ -1,13 +1,13 @@
 """Adapter for Generic API LLM provider API"""
 
-from typing import Type
-
-from pydantic import BaseModel
-import instructor
-from cognee.infrastructure.llm.llm_interface import LLMInterface
-from cognee.infrastructure.llm.config import get_llm_config
-from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async
 import litellm
+import instructor
+from typing import Type
+from pydantic import BaseModel
+
+from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
+from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async
 
 
 class GenericAPIAdapter(LLMInterface):
@@ -17,13 +17,27 @@ class GenericAPIAdapter(LLMInterface):
     model: str
     api_key: str
 
-    def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int):
+    def __init__(
+        self,
+        endpoint,
+        api_key: str,
+        model: str,
+        name: str,
+        max_tokens: int,
+        fallback_model: str = None,
+        fallback_api_key: str = None,
+        fallback_endpoint: str = None,
+    ):
         self.name = name
         self.model = model
         self.api_key = api_key
         self.endpoint = endpoint
         self.max_tokens = max_tokens
+        self.fallback_model = fallback_model
+        self.fallback_api_key = fallback_api_key
+        self.fallback_endpoint = fallback_endpoint
+
         self.aclient = instructor.from_litellm(
             litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
         )
 
@@ -35,20 +49,50 @@
     ) -> BaseModel:
         """Generate a response from a user query."""
 
-        return await self.aclient.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"""Use the given format to
-                extract information from the following input: {text_input}. """,
-                },
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-            ],
-            max_retries=5,
-            api_base=self.endpoint,
-            response_model=response_model,
-        )
+        try:
+            return await self.aclient.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"""Use the given format to
+                extract information from the following input: {text_input}. """,
+                    },
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                ],
+                max_retries=5,
+                api_base=self.endpoint,
+                response_model=response_model,
+            )
+        except litellm.exceptions.ContentPolicyViolationError:
+            if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
+                raise ContentPolicyFilterError(
+                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
+                )
+
+            try:
+                return await self.aclient.chat.completions.create(
+                    model=self.fallback_model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": f"""Use the given format to
+                extract information from the following input: {text_input}. """,
+                        },
+                        {
+                            "role": "system",
+                            "content": system_prompt,
+                        },
+                    ],
+                    max_retries=5,
+                    api_key=self.fallback_api_key,
+                    api_base=self.fallback_endpoint,
+                    response_model=response_model,
+                )
+            except litellm.exceptions.ContentPolicyViolationError:
+                raise ContentPolicyFilterError(
+                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
+                )
diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py
index 4a095d179..1183143e7 100644
--- a/cognee/infrastructure/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/get_llm_client.py
@@ -45,6 +45,9 @@ def get_llm_client():
             transcription_model=llm_config.transcription_model,
             max_tokens=max_tokens,
             streaming=llm_config.llm_streaming,
+            fallback_api_key=llm_config.fallback_api_key,
+            fallback_endpoint=llm_config.fallback_endpoint,
+            fallback_model=llm_config.fallback_model,
         )
 
     elif provider == LLMProvider.OLLAMA:
@@ -78,6 +81,9 @@ def get_llm_client():
             llm_config.llm_model,
             "Custom",
             max_tokens=max_tokens,
+            fallback_api_key=llm_config.fallback_api_key,
+            fallback_endpoint=llm_config.fallback_endpoint,
+            fallback_model=llm_config.fallback_model,
         )
 
     elif provider == LLMProvider.GEMINI:
diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py
index 417af85df..4b59595b7 100644
--- a/cognee/infrastructure/llm/openai/adapter.py
+++ b/cognee/infrastructure/llm/openai/adapter.py
@@ -1,17 +1,18 @@
 import os
 import base64
-from pathlib import Path
 from typing import Type
 
 import litellm
 import instructor
 from pydantic import BaseModel
+from openai import ContentFilterFinishReasonError
 
-from cognee.modules.data.processing.document_types.open_data_file import open_data_file
-from cognee.shared.data_models import MonitoringTool
 from cognee.exceptions import InvalidValueError
-from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.shared.data_models import MonitoringTool
 from cognee.infrastructure.llm.prompts import read_query_prompt
+from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
+from cognee.modules.data.processing.document_types.open_data_file import open_data_file
 from cognee.infrastructure.llm.rate_limiter import (
     rate_limit_async,
     rate_limit_sync,
@@ -45,6 +46,9 @@ class OpenAIAdapter(LLMInterface):
         transcription_model: str,
         max_tokens: int,
         streaming: bool = False,
+        fallback_model: str = None,
+        fallback_api_key: str = None,
+        fallback_endpoint: str = None,
     ):
         self.aclient = instructor.from_litellm(litellm.acompletion)
         self.client = instructor.from_litellm(litellm.completion)
@@ -56,6 +60,10 @@ class OpenAIAdapter(LLMInterface):
         self.max_tokens = max_tokens
         self.streaming = streaming
 
+        self.fallback_model = fallback_model
+        self.fallback_api_key = fallback_api_key
+        self.fallback_endpoint = fallback_endpoint
+
     @observe(as_type="generation")
     @sleep_and_retry_async()
     @rate_limit_async
@@ -64,25 +72,55 @@
     ) -> BaseModel:
         """Generate a response from a user query."""
 
-        return await self.aclient.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"""Use the given format to
-                extract information from the following input: {text_input}. """,
-                },
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-            ],
-            api_key=self.api_key,
-            api_base=self.endpoint,
-            api_version=self.api_version,
-            response_model=response_model,
-            max_retries=self.MAX_RETRIES,
-        )
+        try:
+            return await self.aclient.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"""Use the given format to
+                extract information from the following input: {text_input}. """,
+                    },
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                ],
+                api_key=self.api_key,
+                api_base=self.endpoint,
+                api_version=self.api_version,
+                response_model=response_model,
+                max_retries=self.MAX_RETRIES,
+            )
+        except ContentFilterFinishReasonError:
+            if not (self.fallback_model and self.fallback_api_key):
+                raise ContentPolicyFilterError(
+                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
+                )
+
+            try:
+                return await self.aclient.chat.completions.create(
+                    model=self.fallback_model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": f"""Use the given format to
+                extract information from the following input: {text_input}. """,
+                        },
+                        {
+                            "role": "system",
+                            "content": system_prompt,
+                        },
+                    ],
+                    api_key=self.fallback_api_key,
+                    # api_base=self.fallback_endpoint,
+                    response_model=response_model,
+                    max_retries=self.MAX_RETRIES,
+                )
+            except ContentFilterFinishReasonError:
+                raise ContentPolicyFilterError(
+                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
+                )
 
     @observe
     @sleep_and_retry_sync()
diff --git a/cognee/infrastructure/utils/calculate_backoff.py b/cognee/infrastructure/utils/calculate_backoff.py
new file mode 100644
index 000000000..bc6372923
--- /dev/null
+++ b/cognee/infrastructure/utils/calculate_backoff.py
@@ -0,0 +1,30 @@
+import random
+
+# Default retry settings
+DEFAULT_MAX_RETRIES = 5
+DEFAULT_INITIAL_BACKOFF = 1.0  # seconds
+DEFAULT_BACKOFF_FACTOR = 2.0  # exponential backoff multiplier
+DEFAULT_JITTER = 0.1  # 10% jitter to avoid thundering herd
+
+
+def calculate_backoff(
+    attempt,
+    initial_backoff=DEFAULT_INITIAL_BACKOFF,
+    backoff_factor=DEFAULT_BACKOFF_FACTOR,
+    jitter=DEFAULT_JITTER,
+):
+    """
+    Calculate the backoff time for a retry attempt with jitter.
+
+    Args:
+        attempt: The current retry attempt (0-based).
+        initial_backoff: The initial backoff time in seconds.
+        backoff_factor: The multiplier for exponential backoff.
+        jitter: The jitter factor to avoid thundering herd.
+
+    Returns:
+        float: The backoff time in seconds.
+    """
+    backoff = initial_backoff * (backoff_factor**attempt)
+    jitter_amount = backoff * jitter
+    return backoff + random.uniform(-jitter_amount, jitter_amount)
diff --git a/cognee/tests/unit/infrastructure/databases/graph/neo4j_deadlock_test.py b/cognee/tests/unit/infrastructure/databases/graph/neo4j_deadlock_test.py
new file mode 100644
index 000000000..f73b7b60b
--- /dev/null
+++ b/cognee/tests/unit/infrastructure/databases/graph/neo4j_deadlock_test.py
@@ -0,0 +1,52 @@
+import pytest
+import asyncio
+from unittest.mock import MagicMock
+from neo4j.exceptions import Neo4jError
+from cognee.infrastructure.databases.graph.neo4j_driver.deadlock_retry import deadlock_retry
+
+
+async def test_deadlock_retry_errored():
+    mock_return = asyncio.Future()
+    mock_return.set_result(True)
+    mock_function = MagicMock(
+        side_effect=[Neo4jError("DeadlockDetected"), Neo4jError("DeadlockDetected"), mock_return]
+    )
+
+    wrapped_function = deadlock_retry(max_retries=1)(mock_function)
+
+    with pytest.raises(Neo4jError):
+        await wrapped_function()
+
+
+async def test_deadlock_retry():
+    mock_return = asyncio.Future()
+    mock_return.set_result(True)
+    mock_function = MagicMock(side_effect=[Neo4jError("DeadlockDetected"), mock_return])
+
+    wrapped_function = deadlock_retry(max_retries=2)(mock_function)
+
+    result = await wrapped_function()
+    assert result is True, "Function should have succeeded on the second attempt"
+
+
+async def test_deadlock_retry_exhaustive():
+    mock_return = asyncio.Future()
+    mock_return.set_result(True)
+    mock_function = MagicMock(
+        side_effect=[Neo4jError("DeadlockDetected"), Neo4jError("DeadlockDetected"), mock_return]
+    )
+
+    wrapped_function = deadlock_retry(max_retries=2)(mock_function)
+
+    result = await wrapped_function()
+    assert result is True, "Function should have succeeded on the final retry"
+
+
+if __name__ == "__main__":
+
+    async def main():
+        await test_deadlock_retry()
+        await test_deadlock_retry_errored()
+        await test_deadlock_retry_exhaustive()
+
+    asyncio.run(main())
diff --git a/modal_deployment.py b/modal_deployment.py
deleted file mode 100644
index 3c9d0e05a..000000000
--- a/modal_deployment.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import modal
-import os
-from cognee.shared.logging_utils import get_logger
-import asyncio
-import cognee
-import signal
-
-
-from cognee.modules.search.types import SearchType
-
-app = modal.App("cognee-runner")
-
-image = (
-    modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
-    .add_local_file("pyproject.toml", remote_path="/root/pyproject.toml", copy=True)
-    .add_local_file("poetry.lock", remote_path="/root/poetry.lock", copy=True)
-    .env({"ENV": os.getenv("ENV"), "LLM_API_KEY": os.getenv("LLM_API_KEY")})
-    .poetry_install_from_file(poetry_pyproject_toml="pyproject.toml")
-    .pip_install("protobuf", "h2")
-)
-
-
-@app.function(image=image, max_containers=4)
-async def entry(text: str, query: str):
-    logger = get_logger()
-    logger.info("Initializing Cognee")
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-    await cognee.add(text)
-    await cognee.cognify()
-    search_results = await cognee.search(query_type=SearchType.GRAPH_COMPLETION, query_text=query)
-
-    return {
-        "text": text,
-        "query": query,
-        "answer": search_results[0] if search_results else None,
-    }
-
-
-@app.local_entrypoint()
-async def main():
-    logger = get_logger()
-    text_queries = [
-        {
-            "text": "NASA's Artemis program aims to return humans to the Moon by 2026, focusing on sustainable exploration and preparing for future Mars missions.",
-            "query": "When does NASA plan to return humans to the Moon under the Artemis program?",
-        },
-        {
-            "text": "According to a 2022 UN report, global food waste amounts to approximately 931 million tons annually, with households contributing 61% of the total.",
-            "query": "How much food waste do households contribute annually according to the 2022 UN report?",
-        },
-        {
-            "text": "The 2021 census data revealed that Tokyo's population reached 14 million, reflecting a 2.1% increase compared to the previous census conducted in 2015.",
-            "query": "What was Tokyo's population according to the 2021 census data?",
-        },
-        {
-            "text": "A recent study published in the Journal of Nutrition found that consuming 30 grams of almonds daily can lower LDL cholesterol levels by 7% over a 12-week period.",
-            "query": "How much can daily almond consumption lower LDL cholesterol according to the study?",
-        },
-        {
-            "text": "Amazon's Prime membership grew to 200 million subscribers in 2023, marking a 10% increase from the previous year, driven by exclusive content and faster delivery options.",
-            "query": "How many Prime members did Amazon have in 2023?",
-        },
-        {
-            "text": "A new report by the International Energy Agency states that global renewable energy capacity increased by 295 gigawatts in 2022, primarily driven by solar and wind power expansion.",
-            "query": "By how much did global renewable energy capacity increase in 2022 according to the report?",
-        },
-        {
-            "text": "The World Health Organization reported in 2023 that the global life expectancy has risen to 73.4 years, an increase of 5.5 years since the year 2000.",
-            "query": "What is the current global life expectancy according to the WHO's 2023 report?",
-        },
-        {
-            "text": "The FIFA World Cup 2022 held in Qatar attracted a record-breaking audience of 5 billion people across various digital and traditional broadcasting platforms.",
-            "query": "How many people watched the FIFA World Cup 2022?",
-        },
-        {
-            "text": "The European Space Agency's JUICE mission, launched in 2023, aims to explore Jupiter's icy moons, including Ganymede, Europa, and Callisto, over the next decade.",
-            "query": "Which moons is the JUICE mission set to explore?",
-        },
-        {
-            "text": "According to a report by the International Labour Organization, the global unemployment rate in 2023 was estimated at 5.4%, reflecting a slight decrease compared to the previous year.",
-            "query": "What was the global unemployment rate in 2023 according to the ILO?",
-        },
-    ]
-
-    tasks = [entry.remote.aio(item["text"], item["query"]) for item in text_queries]
-
-    results = await asyncio.gather(*tasks)
-
-    logger.info("Final Results:")
-
-    for result in results:
-        logger.info(result)
-        logger.info("----")
-
-    os.kill(os.getpid(), signal.SIGTERM)
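Note: a sketch of how calling code might react to the new `ContentPolicyFilterError` introduced by this patch. The calling pattern and the `Summary` model are hypothetical; `get_llm_client`, `acreate_structured_output` and `ContentPolicyFilterError` are the names used above.

```python
from pydantic import BaseModel

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError


class Summary(BaseModel):  # hypothetical response model
    text: str


async def summarize(raw_text: str):
    llm_client = get_llm_client()
    try:
        return await llm_client.acreate_structured_output(
            text_input=raw_text,
            system_prompt="Summarize the input.",
            response_model=Summary,
        )
    except ContentPolicyFilterError:
        # Raised only after the primary model (and the fallback, when configured)
        # rejects the input on content-policy grounds.
        return None
```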