fix: add deadlock retry for neo4j and content policy filter fallback for openai models

Boris Arzentar 2025-04-24 17:21:23 +02:00
parent 46cc2a128c
commit 36507502ef
10 changed files with 301 additions and 152 deletions

View file

@@ -26,6 +26,8 @@ from .neo4j_metrics_utils import (
get_size_of_connected_components,
count_self_loops,
)
from .deadlock_retry import deadlock_retry
logger = get_logger("Neo4jAdapter", level=ERROR)
@@ -49,19 +51,16 @@ class Neo4jAdapter(GraphDBInterface):
async with self.driver.session() as session:
yield session
@deadlock_retry
async def query(
self,
query: str,
params: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
try:
async with self.get_session() as session:
result = await session.run(query, parameters=params)
data = await result.data()
return data
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True)
raise error
async with self.get_session() as session:
result = await session.run(query, parameters=params)
data = await result.data()
return data
async def has_node(self, node_id: str) -> bool:
results = self.query(

View file

@@ -0,0 +1,64 @@
import asyncio
from functools import wraps
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
logger = get_logger("deadlock_retry")
def deadlock_retry(max_retries=5):
"""
Decorator that automatically retries an asynchronous function when Neo4j deadlock
or transient errors occur.
This decorator implements an exponential backoff strategy with jitter
to handle contended transactions gracefully.
Args:
max_retries: Maximum number of retry attempts.
Returns:
The decorated async function.
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
from neo4j.exceptions import Neo4jError, DatabaseUnavailable
attempt = 0
async def wait():
backoff_time = calculate_backoff(attempt)
logger.warning(
f"Neo4j rate limit hit, retrying in {backoff_time:.2f}s "
f"Attempt {attempt}/{max_retries}"
)
await asyncio.sleep(backoff_time)
while attempt <= max_retries:
try:
attempt += 1
return await func(*args, **kwargs)
except Neo4jError as error:
if attempt > max_retries:
raise # Re-raise the original error
error_str = str(error)
if "DeadlockDetected" in error_str or "Neo.TransientError" in error_str:
await wait()
else:
raise # Re-raise the original error
except DatabaseUnavailable as error:
if attempt >= max_retries:
raise # Re-raise the original error
await wait()
return wrapper
return decorator
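A minimal usage sketch (the coroutine and its argument are illustrative; the import path matches the one used in the test module later in this commit). The decorator retries only Neo4jError instances whose message mentions DeadlockDetected or Neo.TransientError, plus DatabaseUnavailable, backing off with jitter between attempts:

from cognee.infrastructure.databases.graph.neo4j_driver.deadlock_retry import deadlock_retry

@deadlock_retry(max_retries=3)
async def contended_write(run_tx):
    # A transient "DeadlockDetected" Neo4jError raised by run_tx() is retried
    # with exponential backoff and jitter; any other Neo4jError is re-raised
    # immediately, as is the original error once max_retries is exceeded.
    return await run_tx()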

View file

@@ -23,6 +23,10 @@ class LLMConfig(BaseSettings):
embedding_rate_limit_requests: int = 60
embedding_rate_limit_interval: int = 60 # in seconds (default is 60 requests per minute)
fallback_api_key: str = None
fallback_endpoint: str = None
fallback_model: str = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@model_validator(mode="after")
@@ -97,6 +101,9 @@ class LLMConfig(BaseSettings):
"embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
"embedding_rate_limit_requests": self.embedding_rate_limit_requests,
"embedding_rate_limit_interval": self.embedding_rate_limit_interval,
"fallback_api_key": self.fallback_api_key,
"fallback_endpoint": self.fallback_endpoint,
"fallback_model": self.fallback_model,
}
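For reference, a sketch of how the new fallback settings surface at runtime; the environment variable names are an assumption based on pydantic-settings' default field-name mapping, and the values are purely illustrative:

# .env (assumed variable names, illustrative values)
# FALLBACK_API_KEY=sk-...
# FALLBACK_ENDPOINT=https://fallback.example.com/v1
# FALLBACK_MODEL=some-fallback-model

from cognee.infrastructure.llm.config import get_llm_config

config = get_llm_config()
# All three default to None; the content policy fallback path in the adapters
# below is only taken when they are configured.
print(config.fallback_model, config.fallback_endpoint, config.fallback_api_key)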

View file

@@ -0,0 +1,5 @@
from cognee.exceptions.exceptions import CriticalError
class ContentPolicyFilterError(CriticalError):
pass

View file

@@ -1,13 +1,13 @@
"""Adapter for Generic API LLM provider API"""
from typing import Type
from pydantic import BaseModel
import instructor
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async
import litellm
import instructor
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async
class GenericAPIAdapter(LLMInterface):
@@ -17,13 +17,27 @@ class GenericAPIAdapter(LLMInterface):
model: str
api_key: str
def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int):
def __init__(
self,
endpoint,
api_key: str,
model: str,
name: str,
max_tokens: int,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.name = name
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.max_tokens = max_tokens
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
)
@@ -35,20 +49,50 @@ class GenericAPIAdapter(LLMInterface):
) -> BaseModel:
"""Generate a response from a user query."""
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_base=self.endpoint,
response_model=response_model,
)
try:
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_base=self.endpoint,
response_model=response_model,
)
except litellm.exceptions.ContentPolicyViolationError:
if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
try:
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,
)
except litellm.exceptions.ContentPolicyViolationError:
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
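Callers can now distinguish content policy refusals from other failures by catching the new exception; a minimal sketch (the helper name is illustrative, and the method name acreate_structured_output is assumed from the truncated signature in the diff above):

from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError

async def extract_or_skip(llm_client, text_input, system_prompt, response_model):
    # llm_client is an adapter instance such as the GenericAPIAdapter above.
    try:
        # The adapter retries on the configured fallback model first; this
        # exception surfaces only when no fallback is configured or the
        # fallback also refuses the input.
        return await llm_client.acreate_structured_output(text_input, system_prompt, response_model)
    except ContentPolicyFilterError:
        return None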

View file

@@ -45,6 +45,9 @@ def get_llm_client():
transcription_model=llm_config.transcription_model,
max_tokens=max_tokens,
streaming=llm_config.llm_streaming,
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
fallback_model=llm_config.fallback_model,
)
elif provider == LLMProvider.OLLAMA:
@@ -78,6 +81,9 @@ def get_llm_client():
llm_config.llm_model,
"Custom",
max_tokens=max_tokens,
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
fallback_model=llm_config.fallback_model,
)
elif provider == LLMProvider.GEMINI:

View file

@@ -1,17 +1,18 @@
import os
import base64
from pathlib import Path
from typing import Type
import litellm
import instructor
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
from cognee.shared.data_models import MonitoringTool
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.shared.data_models import MonitoringTool
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
from cognee.infrastructure.llm.rate_limiter import (
rate_limit_async,
rate_limit_sync,
@@ -45,6 +46,9 @@ class OpenAIAdapter(LLMInterface):
transcription_model: str,
max_tokens: int,
streaming: bool = False,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
@@ -56,6 +60,10 @@ class OpenAIAdapter(LLMInterface):
self.max_tokens = max_tokens
self.streaming = streaming
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
@observe(as_type="generation")
@sleep_and_retry_async()
@rate_limit_async
@@ -64,25 +72,55 @@ class OpenAIAdapter(LLMInterface):
) -> BaseModel:
"""Generate a response from a user query."""
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
)
try:
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
)
except ContentFilterFinishReasonError:
if not (self.fallback_model and self.fallback_api_key):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
try:
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.fallback_api_key,
# api_base=self.fallback_endpoint,
response_model=response_model,
max_retries=self.MAX_RETRIES,
)
except ContentFilterFinishReasonError:
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
@observe
@sleep_and_retry_sync()

View file

@@ -0,0 +1,30 @@
import random
# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0 # seconds
DEFAULT_BACKOFF_FACTOR = 2.0 # exponential backoff multiplier
DEFAULT_JITTER = 0.1 # 10% jitter to avoid thundering herd
def calculate_backoff(
attempt,
initial_backoff=DEFAULT_INITIAL_BACKOFF,
backoff_factor=DEFAULT_BACKOFF_FACTOR,
jitter=DEFAULT_JITTER,
):
"""
Calculate the backoff time for a retry attempt with jitter.
Args:
attempt: The current retry attempt (0-based).
initial_backoff: The initial backoff time in seconds.
backoff_factor: The multiplier for exponential backoff.
jitter: The jitter factor to avoid thundering herd.
Returns:
float: The backoff time in seconds.
"""
backoff = initial_backoff * (backoff_factor**attempt)
jitter_amount = backoff * jitter
return backoff + random.uniform(-jitter_amount, jitter_amount)
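With the defaults above, the base delay doubles each attempt: roughly 1s, 2s, 4s, 8s, 16s for attempts 0 through 4, each randomized by ±10% jitter. A quick check:

for attempt in range(5):
    delay = calculate_backoff(attempt)
    # attempt 0 -> ~1.0s, 1 -> ~2.0s, 2 -> ~4.0s, 3 -> ~8.0s, 4 -> ~16.0s (±10%)
    print(f"attempt {attempt}: {delay:.2f}s")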

View file

@@ -0,0 +1,52 @@
import pytest
import asyncio
from unittest.mock import MagicMock
from neo4j.exceptions import Neo4jError
from cognee.infrastructure.databases.graph.neo4j_driver.deadlock_retry import deadlock_retry
async def test_deadlock_retry_errored():
mock_return = asyncio.Future()
mock_return.set_result(True)
mock_function = MagicMock(
side_effect=[Neo4jError("DeadlockDetected"), Neo4jError("DeadlockDetected"), mock_return]
)
wrapped_function = deadlock_retry(max_retries=1)(mock_function)
with pytest.raises(Neo4jError):
await wrapped_function()
async def test_deadlock_retry():
mock_return = asyncio.Future()
mock_return.set_result(True)
mock_function = MagicMock(side_effect=[Neo4jError("DeadlockDetected"), mock_return])
wrapped_function = deadlock_retry(max_retries=2)(mock_function)
result = await wrapped_function()
assert result == True, "Function should have succeded on second time"
async def test_deadlock_retry_exhaustive():
mock_return = asyncio.Future()
mock_return.set_result(True)
mock_function = MagicMock(
side_effect=[Neo4jError("DeadlockDetected"), Neo4jError("DeadlockDetected"), mock_return]
)
wrapped_function = deadlock_retry(max_retries=2)(mock_function)
result = await wrapped_function()
assert result == True, "Function should have succeeded on the third attempt"
if __name__ == "__main__":
async def main():
await test_deadlock_retry()
await test_deadlock_retry_errored()
await test_deadlock_retry_exhaustive()
asyncio.run(main())

View file

@@ -1,96 +0,0 @@
import modal
import os
from cognee.shared.logging_utils import get_logger
import asyncio
import cognee
import signal
from cognee.modules.search.types import SearchType
app = modal.App("cognee-runner")
image = (
modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
.add_local_file("pyproject.toml", remote_path="/root/pyproject.toml", copy=True)
.add_local_file("poetry.lock", remote_path="/root/poetry.lock", copy=True)
.env({"ENV": os.getenv("ENV"), "LLM_API_KEY": os.getenv("LLM_API_KEY")})
.poetry_install_from_file(poetry_pyproject_toml="pyproject.toml")
.pip_install("protobuf", "h2")
)
@app.function(image=image, max_containers=4)
async def entry(text: str, query: str):
logger = get_logger()
logger.info("Initializing Cognee")
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add(text)
await cognee.cognify()
search_results = await cognee.search(query_type=SearchType.GRAPH_COMPLETION, query_text=query)
return {
"text": text,
"query": query,
"answer": search_results[0] if search_results else None,
}
@app.local_entrypoint()
async def main():
logger = get_logger()
text_queries = [
{
"text": "NASA's Artemis program aims to return humans to the Moon by 2026, focusing on sustainable exploration and preparing for future Mars missions.",
"query": "When does NASA plan to return humans to the Moon under the Artemis program?",
},
{
"text": "According to a 2022 UN report, global food waste amounts to approximately 931 million tons annually, with households contributing 61% of the total.",
"query": "How much food waste do households contribute annually according to the 2022 UN report?",
},
{
"text": "The 2021 census data revealed that Tokyo's population reached 14 million, reflecting a 2.1% increase compared to the previous census conducted in 2015.",
"query": "What was Tokyo's population according to the 2021 census data?",
},
{
"text": "A recent study published in the Journal of Nutrition found that consuming 30 grams of almonds daily can lower LDL cholesterol levels by 7% over a 12-week period.",
"query": "How much can daily almond consumption lower LDL cholesterol according to the study?",
},
{
"text": "Amazon's Prime membership grew to 200 million subscribers in 2023, marking a 10% increase from the previous year, driven by exclusive content and faster delivery options.",
"query": "How many Prime members did Amazon have in 2023?",
},
{
"text": "A new report by the International Energy Agency states that global renewable energy capacity increased by 295 gigawatts in 2022, primarily driven by solar and wind power expansion.",
"query": "By how much did global renewable energy capacity increase in 2022 according to the report?",
},
{
"text": "The World Health Organization reported in 2023 that the global life expectancy has risen to 73.4 years, an increase of 5.5 years since the year 2000.",
"query": "What is the current global life expectancy according to the WHO's 2023 report?",
},
{
"text": "The FIFA World Cup 2022 held in Qatar attracted a record-breaking audience of 5 billion people across various digital and traditional broadcasting platforms.",
"query": "How many people watched the FIFA World Cup 2022?",
},
{
"text": "The European Space Agency's JUICE mission, launched in 2023, aims to explore Jupiter's icy moons, including Ganymede, Europa, and Callisto, over the next decade.",
"query": "Which moons is the JUICE mission set to explore?",
},
{
"text": "According to a report by the International Labour Organization, the global unemployment rate in 2023 was estimated at 5.4%, reflecting a slight decrease compared to the previous year.",
"query": "What was the global unemployment rate in 2023 according to the ILO?",
},
]
tasks = [entry.remote.aio(item["text"], item["query"]) for item in text_queries]
results = await asyncio.gather(*tasks)
logger.info("Final Results:")
for result in results:
logger.info(result)
logger.info("----")
os.kill(os.getpid(), signal.SIGTERM)