fix: add deadlock retry for neo4j and content policy filter fallback for openai models

Boris Arzentar 2025-04-24 17:21:23 +02:00
parent 46cc2a128c
commit 36507502ef
10 changed files with 301 additions and 152 deletions

View file

@@ -26,6 +26,8 @@ from .neo4j_metrics_utils import (
     get_size_of_connected_components,
     count_self_loops,
 )
+from .deadlock_retry import deadlock_retry
+
 logger = get_logger("Neo4jAdapter", level=ERROR)
@@ -49,19 +51,16 @@ class Neo4jAdapter(GraphDBInterface):
         async with self.driver.session() as session:
             yield session

+    @deadlock_retry()
     async def query(
         self,
         query: str,
         params: Optional[Dict[str, Any]] = None,
     ) -> List[Dict[str, Any]]:
-        try:
-            async with self.get_session() as session:
-                result = await session.run(query, parameters=params)
-                data = await result.data()
-                return data
-        except Neo4jError as error:
-            logger.error("Neo4j query error: %s", error, exc_info=True)
-            raise error
+        async with self.get_session() as session:
+            result = await session.run(query, parameters=params)
+            data = await result.data()
+            return data

     async def has_node(self, node_id: str) -> bool:
         results = self.query(
View file

@@ -0,0 +1,64 @@
import asyncio
from functools import wraps

from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.calculate_backoff import calculate_backoff

logger = get_logger("deadlock_retry")


def deadlock_retry(max_retries=5):
    """
    Decorator that automatically retries an asynchronous function when Neo4j
    deadlock or transient errors occur.

    This decorator implements an exponential backoff strategy with jitter
    (see calculate_backoff) to avoid hammering the database while it recovers.

    Args:
        max_retries: Maximum number of retry attempts.

    Returns:
        The decorated async function.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            from neo4j.exceptions import Neo4jError, DatabaseUnavailable

            attempt = 0

            async def wait():
                backoff_time = calculate_backoff(attempt)
                logger.warning(
                    f"Neo4j deadlock or transient error, retrying in {backoff_time:.2f}s "
                    f"(attempt {attempt}/{max_retries})"
                )
                await asyncio.sleep(backoff_time)

            while attempt <= max_retries:
                try:
                    attempt += 1
                    return await func(*args, **kwargs)
                except Neo4jError as error:
                    if attempt > max_retries:
                        raise  # Re-raise the original error

                    error_str = str(error)
                    if "DeadlockDetected" in error_str or "Neo.TransientError" in error_str:
                        await wait()
                    else:
                        raise  # Re-raise the original error
                except DatabaseUnavailable:
                    if attempt >= max_retries:
                        raise  # Re-raise the original error

                    await wait()

        return wrapper

    return decorator
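
For orientation, a minimal usage sketch: the repository class and Cypher statement below are hypothetical, not part of this commit; only the deadlock_retry import path is taken from the test file in this diff.

from cognee.infrastructure.databases.graph.neo4j_driver.deadlock_retry import deadlock_retry


class ExampleRepository:  # hypothetical class, for illustration only
    def __init__(self, adapter):
        self.adapter = adapter  # e.g. a Neo4jAdapter instance

    @deadlock_retry(max_retries=3)
    async def touch_node(self, node_id: str):
        # Concurrent writers can raise Neo.TransientError.Transaction.DeadlockDetected;
        # the decorator sleeps with jittered exponential backoff and retries the call.
        return await self.adapter.query(
            "MATCH (n {id: $id}) SET n.touched = timestamp()",
            params={"id": node_id},
        )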

View file

@@ -23,6 +23,10 @@ class LLMConfig(BaseSettings):
     embedding_rate_limit_requests: int = 60
     embedding_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)

+    fallback_api_key: str = None
+    fallback_endpoint: str = None
+    fallback_model: str = None

     model_config = SettingsConfigDict(env_file=".env", extra="allow")

     @model_validator(mode="after")
@@ -97,6 +101,9 @@ class LLMConfig(BaseSettings):
             "embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
             "embedding_rate_limit_requests": self.embedding_rate_limit_requests,
             "embedding_rate_limit_interval": self.embedding_rate_limit_interval,
+            "fallback_api_key": self.fallback_api_key,
+            "fallback_endpoint": self.fallback_endpoint,
+            "fallback_model": self.fallback_model,
         }
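
Since LLMConfig is a pydantic BaseSettings with env_file=".env", these fields should be settable from the environment; the variable names below follow pydantic's default field-to-env mapping, and the values are placeholders, not values from this commit.

# .env (illustrative placeholders)
FALLBACK_API_KEY="sk-placeholder"
FALLBACK_ENDPOINT="https://fallback.example.com/v1"
FALLBACK_MODEL="openai/gpt-4o-mini"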

View file

@@ -0,0 +1,5 @@
from cognee.exceptions.exceptions import CriticalError


class ContentPolicyFilterError(CriticalError):
    pass
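
A minimal caller-side sketch of handling this error, assuming the adapters' structured-output entry point shown later in this diff; the wrapper function itself is hypothetical.

from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError


async def safe_extract(adapter, text, system_prompt, response_model):  # hypothetical helper
    try:
        return await adapter.acreate_structured_output(text, system_prompt, response_model)
    except ContentPolicyFilterError:
        # Raised when the primary model rejects the input on content-policy
        # grounds and no usable fallback is configured, or the fallback
        # rejects it as well.
        return None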

View file

@@ -1,13 +1,13 @@
 """Adapter for Generic API LLM provider API"""

-from typing import Type
-from pydantic import BaseModel
-import instructor
-from cognee.infrastructure.llm.llm_interface import LLMInterface
-from cognee.infrastructure.llm.config import get_llm_config
-from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async
 import litellm
+import instructor
+from typing import Type
+from pydantic import BaseModel
+from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
+from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async


 class GenericAPIAdapter(LLMInterface):
@@ -17,13 +17,27 @@ class GenericAPIAdapter(LLMInterface):
     model: str
     api_key: str

-    def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int):
+    def __init__(
+        self,
+        endpoint,
+        api_key: str,
+        model: str,
+        name: str,
+        max_tokens: int,
+        fallback_model: str = None,
+        fallback_api_key: str = None,
+        fallback_endpoint: str = None,
+    ):
         self.name = name
         self.model = model
         self.api_key = api_key
         self.endpoint = endpoint
         self.max_tokens = max_tokens
+        self.fallback_model = fallback_model
+        self.fallback_api_key = fallback_api_key
+        self.fallback_endpoint = fallback_endpoint

         self.aclient = instructor.from_litellm(
             litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
         )
@@ -35,20 +49,50 @@ class GenericAPIAdapter(LLMInterface):
     ) -> BaseModel:
         """Generate a response from a user query."""

-        return await self.aclient.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"""Use the given format to
-                    extract information from the following input: {text_input}. """,
-                },
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-            ],
-            max_retries=5,
-            api_base=self.endpoint,
-            response_model=response_model,
-        )
+        try:
+            return await self.aclient.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"""Use the given format to
+                        extract information from the following input: {text_input}. """,
+                    },
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                ],
+                max_retries=5,
+                api_base=self.endpoint,
+                response_model=response_model,
+            )
+        except litellm.exceptions.ContentPolicyViolationError:
+            if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
+                raise ContentPolicyFilterError(
+                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
+                )
+
+            try:
+                return await self.aclient.chat.completions.create(
+                    model=self.fallback_model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": f"""Use the given format to
+                            extract information from the following input: {text_input}. """,
+                        },
+                        {
+                            "role": "system",
+                            "content": system_prompt,
+                        },
+                    ],
+                    max_retries=5,
+                    api_key=self.fallback_api_key,
+                    api_base=self.fallback_endpoint,
+                    response_model=response_model,
+                )
+            except litellm.exceptions.ContentPolicyViolationError:
+                raise ContentPolicyFilterError(
+                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
+                )
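
For illustration, a sketch of constructing the adapter with a fallback target; the endpoints, keys, and model names are placeholders, not values from this commit. The fallback is consulted only when litellm raises ContentPolicyViolationError for the primary model.

adapter = GenericAPIAdapter(
    endpoint="https://primary.example.com/v1",  # placeholder
    api_key="PRIMARY_KEY",                      # placeholder
    model="openai/gpt-4o-mini",                 # placeholder
    name="Custom",
    max_tokens=16384,
    fallback_model="openai/gpt-4o",             # used only on content-policy rejection
    fallback_api_key="FALLBACK_KEY",            # placeholder
    fallback_endpoint="https://fallback.example.com/v1",
)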

View file

@@ -45,6 +45,9 @@ def get_llm_client():
             transcription_model=llm_config.transcription_model,
             max_tokens=max_tokens,
             streaming=llm_config.llm_streaming,
+            fallback_api_key=llm_config.fallback_api_key,
+            fallback_endpoint=llm_config.fallback_endpoint,
+            fallback_model=llm_config.fallback_model,
         )

     elif provider == LLMProvider.OLLAMA:
@@ -78,6 +81,9 @@ def get_llm_client():
             llm_config.llm_model,
             "Custom",
             max_tokens=max_tokens,
+            fallback_api_key=llm_config.fallback_api_key,
+            fallback_endpoint=llm_config.fallback_endpoint,
+            fallback_model=llm_config.fallback_model,
         )

     elif provider == LLMProvider.GEMINI:

View file

@@ -1,17 +1,18 @@
 import os
 import base64
-from pathlib import Path
 from typing import Type

 import litellm
 import instructor
 from pydantic import BaseModel
+from openai import ContentFilterFinishReasonError

-from cognee.modules.data.processing.document_types.open_data_file import open_data_file
-from cognee.shared.data_models import MonitoringTool
 from cognee.exceptions import InvalidValueError
-from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.shared.data_models import MonitoringTool
 from cognee.infrastructure.llm.prompts import read_query_prompt
+from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
+from cognee.modules.data.processing.document_types.open_data_file import open_data_file
 from cognee.infrastructure.llm.rate_limiter import (
     rate_limit_async,
     rate_limit_sync,
@@ -45,6 +46,9 @@ class OpenAIAdapter(LLMInterface):
         transcription_model: str,
         max_tokens: int,
         streaming: bool = False,
+        fallback_model: str = None,
+        fallback_api_key: str = None,
+        fallback_endpoint: str = None,
     ):
         self.aclient = instructor.from_litellm(litellm.acompletion)
         self.client = instructor.from_litellm(litellm.completion)
@@ -56,6 +60,10 @@ class OpenAIAdapter(LLMInterface):
         self.max_tokens = max_tokens
         self.streaming = streaming

+        self.fallback_model = fallback_model
+        self.fallback_api_key = fallback_api_key
+        self.fallback_endpoint = fallback_endpoint

     @observe(as_type="generation")
     @sleep_and_retry_async()
     @rate_limit_async
@@ -64,25 +72,55 @@ class OpenAIAdapter(LLMInterface):
     ) -> BaseModel:
         """Generate a response from a user query."""

-        return await self.aclient.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"""Use the given format to
-                    extract information from the following input: {text_input}. """,
-                },
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-            ],
-            api_key=self.api_key,
-            api_base=self.endpoint,
-            api_version=self.api_version,
-            response_model=response_model,
-            max_retries=self.MAX_RETRIES,
-        )
+        try:
+            return await self.aclient.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"""Use the given format to
+                        extract information from the following input: {text_input}. """,
+                    },
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                ],
+                api_key=self.api_key,
+                api_base=self.endpoint,
+                api_version=self.api_version,
+                response_model=response_model,
+                max_retries=self.MAX_RETRIES,
+            )
+        except ContentFilterFinishReasonError:
+            if not (self.fallback_model and self.fallback_api_key):
+                raise ContentPolicyFilterError(
+                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
+                )
+
+            try:
+                return await self.aclient.chat.completions.create(
+                    model=self.fallback_model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": f"""Use the given format to
+                            extract information from the following input: {text_input}. """,
+                        },
+                        {
+                            "role": "system",
+                            "content": system_prompt,
+                        },
+                    ],
+                    api_key=self.fallback_api_key,
+                    # api_base=self.fallback_endpoint,
+                    response_model=response_model,
+                    max_retries=self.MAX_RETRIES,
+                )
+            except ContentFilterFinishReasonError:
+                raise ContentPolicyFilterError(
+                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
+                )

     @observe
     @sleep_and_retry_sync()

View file

@@ -0,0 +1,30 @@
import random

# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0  # seconds
DEFAULT_BACKOFF_FACTOR = 2.0  # exponential backoff multiplier
DEFAULT_JITTER = 0.1  # 10% jitter to avoid thundering herd


def calculate_backoff(
    attempt,
    initial_backoff=DEFAULT_INITIAL_BACKOFF,
    backoff_factor=DEFAULT_BACKOFF_FACTOR,
    jitter=DEFAULT_JITTER,
):
    """
    Calculate the backoff time for a retry attempt with jitter.

    Args:
        attempt: The current retry attempt (0-based).
        initial_backoff: The initial backoff time in seconds.
        backoff_factor: The multiplier for exponential backoff.
        jitter: The jitter factor to avoid thundering herd.

    Returns:
        float: The backoff time in seconds.
    """
    backoff = initial_backoff * (backoff_factor**attempt)
    jitter_amount = backoff * jitter
    return backoff + random.uniform(-jitter_amount, jitter_amount)
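
As a sanity check, the nominal (jitter-free) schedule with the defaults doubles each attempt: 1s, 2s, 4s, 8s, 16s for attempts 0 through 4, each then perturbed by up to ±10%. A quick sketch:

# Illustrative only: print the default backoff schedule.
for attempt in range(5):
    delay = calculate_backoff(attempt)
    # attempt 0 -> ~1.0s, 1 -> ~2.0s, 2 -> ~4.0s, 3 -> ~8.0s, 4 -> ~16.0s (±10%)
    print(f"attempt {attempt}: sleep {delay:.2f}s")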

View file

@@ -0,0 +1,52 @@
import asyncio
from unittest.mock import MagicMock

import pytest
from neo4j.exceptions import Neo4jError

from cognee.infrastructure.databases.graph.neo4j_driver.deadlock_retry import deadlock_retry


@pytest.mark.asyncio
async def test_deadlock_retry_errored():
    mock_return = asyncio.Future()
    mock_return.set_result(True)
    mock_function = MagicMock(
        side_effect=[Neo4jError("DeadlockDetected"), Neo4jError("DeadlockDetected"), mock_return]
    )

    wrapped_function = deadlock_retry(max_retries=1)(mock_function)

    with pytest.raises(Neo4jError):
        await wrapped_function()


@pytest.mark.asyncio
async def test_deadlock_retry():
    mock_return = asyncio.Future()
    mock_return.set_result(True)
    mock_function = MagicMock(side_effect=[Neo4jError("DeadlockDetected"), mock_return])

    wrapped_function = deadlock_retry(max_retries=2)(mock_function)

    result = await wrapped_function()

    assert result is True, "Function should have succeeded on the second attempt"


@pytest.mark.asyncio
async def test_deadlock_retry_exhaustive():
    mock_return = asyncio.Future()
    mock_return.set_result(True)
    mock_function = MagicMock(
        side_effect=[Neo4jError("DeadlockDetected"), Neo4jError("DeadlockDetected"), mock_return]
    )

    wrapped_function = deadlock_retry(max_retries=2)(mock_function)

    result = await wrapped_function()

    assert result is True, "Function should have succeeded on the third attempt"


if __name__ == "__main__":

    async def main():
        await test_deadlock_retry()
        await test_deadlock_retry_errored()
        await test_deadlock_retry_exhaustive()

    asyncio.run(main())

View file

@@ -1,96 +0,0 @@
import modal
import os
from cognee.shared.logging_utils import get_logger
import asyncio
import cognee
import signal

from cognee.modules.search.types import SearchType

app = modal.App("cognee-runner")

image = (
    modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
    .add_local_file("pyproject.toml", remote_path="/root/pyproject.toml", copy=True)
    .add_local_file("poetry.lock", remote_path="/root/poetry.lock", copy=True)
    .env({"ENV": os.getenv("ENV"), "LLM_API_KEY": os.getenv("LLM_API_KEY")})
    .poetry_install_from_file(poetry_pyproject_toml="pyproject.toml")
    .pip_install("protobuf", "h2")
)


@app.function(image=image, max_containers=4)
async def entry(text: str, query: str):
    logger = get_logger()
    logger.info("Initializing Cognee")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    await cognee.add(text)
    await cognee.cognify()
    search_results = await cognee.search(query_type=SearchType.GRAPH_COMPLETION, query_text=query)

    return {
        "text": text,
        "query": query,
        "answer": search_results[0] if search_results else None,
    }


@app.local_entrypoint()
async def main():
    logger = get_logger()
    text_queries = [
        {
            "text": "NASA's Artemis program aims to return humans to the Moon by 2026, focusing on sustainable exploration and preparing for future Mars missions.",
            "query": "When does NASA plan to return humans to the Moon under the Artemis program?",
        },
        {
            "text": "According to a 2022 UN report, global food waste amounts to approximately 931 million tons annually, with households contributing 61% of the total.",
            "query": "How much food waste do households contribute annually according to the 2022 UN report?",
        },
        {
            "text": "The 2021 census data revealed that Tokyo's population reached 14 million, reflecting a 2.1% increase compared to the previous census conducted in 2015.",
            "query": "What was Tokyo's population according to the 2021 census data?",
        },
        {
            "text": "A recent study published in the Journal of Nutrition found that consuming 30 grams of almonds daily can lower LDL cholesterol levels by 7% over a 12-week period.",
            "query": "How much can daily almond consumption lower LDL cholesterol according to the study?",
        },
        {
            "text": "Amazon's Prime membership grew to 200 million subscribers in 2023, marking a 10% increase from the previous year, driven by exclusive content and faster delivery options.",
            "query": "How many Prime members did Amazon have in 2023?",
        },
        {
            "text": "A new report by the International Energy Agency states that global renewable energy capacity increased by 295 gigawatts in 2022, primarily driven by solar and wind power expansion.",
            "query": "By how much did global renewable energy capacity increase in 2022 according to the report?",
        },
        {
            "text": "The World Health Organization reported in 2023 that the global life expectancy has risen to 73.4 years, an increase of 5.5 years since the year 2000.",
            "query": "What is the current global life expectancy according to the WHO's 2023 report?",
        },
        {
            "text": "The FIFA World Cup 2022 held in Qatar attracted a record-breaking audience of 5 billion people across various digital and traditional broadcasting platforms.",
            "query": "How many people watched the FIFA World Cup 2022?",
        },
        {
            "text": "The European Space Agency's JUICE mission, launched in 2023, aims to explore Jupiter's icy moons, including Ganymede, Europa, and Callisto, over the next decade.",
            "query": "Which moons is the JUICE mission set to explore?",
        },
        {
            "text": "According to a report by the International Labour Organization, the global unemployment rate in 2023 was estimated at 5.4%, reflecting a slight decrease compared to the previous year.",
            "query": "What was the global unemployment rate in 2023 according to the ILO?",
        },
    ]

    tasks = [entry.remote.aio(item["text"], item["query"]) for item in text_queries]
    results = await asyncio.gather(*tasks)

    logger.info("Final Results:")
    for result in results:
        logger.info(result)
        logger.info("----")

    os.kill(os.getpid(), signal.SIGTERM)