fix: add deadlock retry for neo4j and content policy filter fallback for openai models
This commit is contained in:
parent 46cc2a128c
commit 36507502ef

10 changed files with 301 additions and 152 deletions

@@ -26,6 +26,8 @@ from .neo4j_metrics_utils import (
    get_size_of_connected_components,
    count_self_loops,
)
from .deadlock_retry import deadlock_retry


logger = get_logger("Neo4jAdapter", level=ERROR)

@@ -49,19 +51,16 @@ class Neo4jAdapter(GraphDBInterface):
        async with self.driver.session() as session:
            yield session

    @deadlock_retry()
    async def query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        try:
            async with self.get_session() as session:
                result = await session.run(query, parameters=params)
                data = await result.data()
                return data
        except Neo4jError as error:
            logger.error("Neo4j query error: %s", error, exc_info=True)
            raise error
        async with self.get_session() as session:
            result = await session.run(query, parameters=params)
            data = await result.data()
            return data

    async def has_node(self, node_id: str) -> bool:
        results = self.query(

@@ -0,0 +1,64 @@
import asyncio
from functools import wraps

from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.calculate_backoff import calculate_backoff


logger = get_logger("deadlock_retry")


def deadlock_retry(max_retries=5):
    """
    Decorator factory that automatically retries an asynchronous function when Neo4j
    deadlock or transient errors occur.

    The retry delay uses an exponential backoff strategy with jitter
    (see calculate_backoff) to avoid hammering the database.

    Args:
        max_retries: Maximum number of retry attempts.

    Returns:
        A decorator that wraps the target async function.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            from neo4j.exceptions import Neo4jError, DatabaseUnavailable

            attempt = 0

            async def wait():
                backoff_time = calculate_backoff(attempt)
                logger.warning(
                    f"Neo4j deadlock/transient error, retrying in {backoff_time:.2f}s "
                    f"(attempt {attempt}/{max_retries})"
                )
                await asyncio.sleep(backoff_time)

            while attempt <= max_retries:
                try:
                    attempt += 1
                    return await func(*args, **kwargs)
                except Neo4jError as error:
                    if attempt > max_retries:
                        raise  # Re-raise the original error

                    error_str = str(error)
                    if "DeadlockDetected" in error_str or "Neo.TransientError" in error_str:
                        await wait()
                    else:
                        raise  # Re-raise the original error
                except DatabaseUnavailable as error:
                    if attempt >= max_retries:
                        raise  # Re-raise the original error

                    await wait()

        return wrapper

    return decorator

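For reference, a minimal usage sketch of the new decorator outside the adapter; the flaky coroutine below is illustrative only and not part of this commit:

import asyncio

from neo4j.exceptions import Neo4jError

from cognee.infrastructure.databases.graph.neo4j_driver.deadlock_retry import deadlock_retry

calls = {"count": 0}


@deadlock_retry(max_retries=3)
async def flaky_write():
    # Fails twice with a deadlock, then succeeds; each failure triggers an
    # exponential-backoff sleep (via calculate_backoff) before the next attempt.
    calls["count"] += 1
    if calls["count"] < 3:
        raise Neo4jError("DeadlockDetected")
    return "ok"


print(asyncio.run(flaky_write()))  # prints "ok" after roughly 2s + 4s of backoff
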
@@ -23,6 +23,10 @@ class LLMConfig(BaseSettings):
    embedding_rate_limit_requests: int = 60
    embedding_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)

    fallback_api_key: str = None
    fallback_endpoint: str = None
    fallback_model: str = None

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    @model_validator(mode="after")

@@ -97,6 +101,9 @@ class LLMConfig(BaseSettings):
            "embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
            "embedding_rate_limit_requests": self.embedding_rate_limit_requests,
            "embedding_rate_limit_interval": self.embedding_rate_limit_interval,
            "fallback_api_key": self.fallback_api_key,
            "fallback_endpoint": self.fallback_endpoint,
            "fallback_model": self.fallback_model,
        }

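Since LLMConfig is a pydantic BaseSettings with env_file=".env", the new fallback fields can presumably be supplied through environment variables. A minimal sketch, assuming pydantic's default field-to-variable name mapping; the variable names and values are placeholders, not confirmed by this commit:

import os

# Assumed variable names derived from the field names; values are placeholders.
os.environ["FALLBACK_MODEL"] = "openai/gpt-4o-mini"
os.environ["FALLBACK_API_KEY"] = "sk-placeholder"
os.environ["FALLBACK_ENDPOINT"] = "https://fallback.example.com/v1"

from cognee.infrastructure.llm.config import get_llm_config

config = get_llm_config()
print(config.fallback_model, config.fallback_endpoint)
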
5  cognee/infrastructure/llm/exceptions.py  Normal file

@@ -0,0 +1,5 @@
from cognee.exceptions.exceptions import CriticalError


class ContentPolicyFilterError(CriticalError):
    pass

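A caller-side sketch of how the new exception can be handled; run_extraction here is a hypothetical coroutine wrapping the adapter call, not something introduced by this commit:

from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError


async def safe_extract(run_extraction, text: str):
    # run_extraction is a hypothetical wrapper around the LLM adapter call.
    try:
        return await run_extraction(text)
    except ContentPolicyFilterError:
        # Raised only after the primary model and, if configured, the fallback
        # model both refuse the input on content-policy grounds.
        return None
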
@@ -1,13 +1,13 @@
"""Adapter for Generic API LLM provider API"""

from typing import Type

from pydantic import BaseModel
import instructor
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async
import litellm
import instructor
from typing import Type
from pydantic import BaseModel

from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async


class GenericAPIAdapter(LLMInterface):

@@ -17,13 +17,27 @@ class GenericAPIAdapter(LLMInterface):
    model: str
    api_key: str

    def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int):
    def __init__(
        self,
        endpoint,
        api_key: str,
        model: str,
        name: str,
        max_tokens: int,
        fallback_model: str = None,
        fallback_api_key: str = None,
        fallback_endpoint: str = None,
    ):
        self.name = name
        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint
        self.max_tokens = max_tokens

        self.fallback_model = fallback_model
        self.fallback_api_key = fallback_api_key
        self.fallback_endpoint = fallback_endpoint

        self.aclient = instructor.from_litellm(
            litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
        )

@@ -35,20 +49,50 @@ class GenericAPIAdapter(LLMInterface):
    ) -> BaseModel:
        """Generate a response from a user query."""

        return await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": f"""Use the given format to
                    extract information from the following input: {text_input}. """,
                },
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
            max_retries=5,
            api_base=self.endpoint,
            response_model=response_model,
        )
        try:
            return await self.aclient.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "user",
                        "content": f"""Use the given format to
                        extract information from the following input: {text_input}. """,
                    },
                    {
                        "role": "system",
                        "content": system_prompt,
                    },
                ],
                max_retries=5,
                api_base=self.endpoint,
                response_model=response_model,
            )
        except litellm.exceptions.ContentPolicyViolationError:
            if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
                raise ContentPolicyFilterError(
                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
                )

            try:
                return await self.aclient.chat.completions.create(
                    model=self.fallback_model,
                    messages=[
                        {
                            "role": "user",
                            "content": f"""Use the given format to
                            extract information from the following input: {text_input}. """,
                        },
                        {
                            "role": "system",
                            "content": system_prompt,
                        },
                    ],
                    max_retries=5,
                    api_key=self.fallback_api_key,
                    api_base=self.fallback_endpoint,
                    response_model=response_model,
                )
            except litellm.exceptions.ContentPolicyViolationError:
                raise ContentPolicyFilterError(
                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
                )

@@ -45,6 +45,9 @@ def get_llm_client():
            transcription_model=llm_config.transcription_model,
            max_tokens=max_tokens,
            streaming=llm_config.llm_streaming,
            fallback_api_key=llm_config.fallback_api_key,
            fallback_endpoint=llm_config.fallback_endpoint,
            fallback_model=llm_config.fallback_model,
        )

    elif provider == LLMProvider.OLLAMA:

@@ -78,6 +81,9 @@ def get_llm_client():
            llm_config.llm_model,
            "Custom",
            max_tokens=max_tokens,
            fallback_api_key=llm_config.fallback_api_key,
            fallback_endpoint=llm_config.fallback_endpoint,
            fallback_model=llm_config.fallback_model,
        )

    elif provider == LLMProvider.GEMINI:

@@ -1,17 +1,18 @@
import os
import base64
from pathlib import Path
from typing import Type

import litellm
import instructor
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError

from cognee.modules.data.processing.document_types.open_data_file import open_data_file
from cognee.shared.data_models import MonitoringTool
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.shared.data_models import MonitoringTool
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
from cognee.infrastructure.llm.rate_limiter import (
    rate_limit_async,
    rate_limit_sync,

@@ -45,6 +46,9 @@ class OpenAIAdapter(LLMInterface):
        transcription_model: str,
        max_tokens: int,
        streaming: bool = False,
        fallback_model: str = None,
        fallback_api_key: str = None,
        fallback_endpoint: str = None,
    ):
        self.aclient = instructor.from_litellm(litellm.acompletion)
        self.client = instructor.from_litellm(litellm.completion)

@@ -56,6 +60,10 @@ class OpenAIAdapter(LLMInterface):
        self.max_tokens = max_tokens
        self.streaming = streaming

        self.fallback_model = fallback_model
        self.fallback_api_key = fallback_api_key
        self.fallback_endpoint = fallback_endpoint

    @observe(as_type="generation")
    @sleep_and_retry_async()
    @rate_limit_async

@@ -64,25 +72,55 @@ class OpenAIAdapter(LLMInterface):
    ) -> BaseModel:
        """Generate a response from a user query."""

        return await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": f"""Use the given format to
                    extract information from the following input: {text_input}. """,
                },
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            response_model=response_model,
            max_retries=self.MAX_RETRIES,
        )
        try:
            return await self.aclient.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "user",
                        "content": f"""Use the given format to
                        extract information from the following input: {text_input}. """,
                    },
                    {
                        "role": "system",
                        "content": system_prompt,
                    },
                ],
                api_key=self.api_key,
                api_base=self.endpoint,
                api_version=self.api_version,
                response_model=response_model,
                max_retries=self.MAX_RETRIES,
            )
        except ContentFilterFinishReasonError:
            if not (self.fallback_model and self.fallback_api_key):
                raise ContentPolicyFilterError(
                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
                )

            try:
                return await self.aclient.chat.completions.create(
                    model=self.fallback_model,
                    messages=[
                        {
                            "role": "user",
                            "content": f"""Use the given format to
                            extract information from the following input: {text_input}. """,
                        },
                        {
                            "role": "system",
                            "content": system_prompt,
                        },
                    ],
                    api_key=self.fallback_api_key,
                    # api_base=self.fallback_endpoint,
                    response_model=response_model,
                    max_retries=self.MAX_RETRIES,
                )
            except ContentFilterFinishReasonError:
                raise ContentPolicyFilterError(
                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
                )

    @observe
    @sleep_and_retry_sync()

30  cognee/infrastructure/utils/calculate_backoff.py  Normal file

@@ -0,0 +1,30 @@
import random

# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0  # seconds
DEFAULT_BACKOFF_FACTOR = 2.0  # exponential backoff multiplier
DEFAULT_JITTER = 0.1  # 10% jitter to avoid thundering herd


def calculate_backoff(
    attempt,
    initial_backoff=DEFAULT_INITIAL_BACKOFF,
    backoff_factor=DEFAULT_BACKOFF_FACTOR,
    jitter=DEFAULT_JITTER,
):
    """
    Calculate the backoff time for a retry attempt with jitter.

    Args:
        attempt: The current retry attempt (0-based).
        initial_backoff: The initial backoff time in seconds.
        backoff_factor: The multiplier for exponential backoff.
        jitter: The jitter factor to avoid thundering herd.

    Returns:
        float: The backoff time in seconds.
    """
    backoff = initial_backoff * (backoff_factor**attempt)
    jitter_amount = backoff * jitter
    return backoff + random.uniform(-jitter_amount, jitter_amount)

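With the defaults, each retry roughly doubles the wait, with up to 10% jitter in either direction. A quick sketch of the expected values:

from cognee.infrastructure.utils.calculate_backoff import calculate_backoff

# attempt 0 -> ~1s, 1 -> ~2s, 2 -> ~4s, 3 -> ~8s (each varying by up to 10%)
for attempt in range(4):
    print(attempt, round(calculate_backoff(attempt), 2))
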
@@ -0,0 +1,52 @@
import pytest
import asyncio
from unittest.mock import MagicMock
from neo4j.exceptions import Neo4jError
from cognee.infrastructure.databases.graph.neo4j_driver.deadlock_retry import deadlock_retry


async def test_deadlock_retry_errored():
    mock_return = asyncio.Future()
    mock_return.set_result(True)
    mock_function = MagicMock(
        side_effect=[Neo4jError("DeadlockDetected"), Neo4jError("DeadlockDetected"), mock_return]
    )

    wrapped_function = deadlock_retry(max_retries=1)(mock_function)

    with pytest.raises(Neo4jError):
        await wrapped_function()


async def test_deadlock_retry():
    mock_return = asyncio.Future()
    mock_return.set_result(True)
    mock_function = MagicMock(side_effect=[Neo4jError("DeadlockDetected"), mock_return])

    wrapped_function = deadlock_retry(max_retries=2)(mock_function)

    result = await wrapped_function()
    assert result == True, "Function should have succeeded on the second attempt"


async def test_deadlock_retry_exhaustive():
    mock_return = asyncio.Future()
    mock_return.set_result(True)
    mock_function = MagicMock(
        side_effect=[Neo4jError("DeadlockDetected"), Neo4jError("DeadlockDetected"), mock_return]
    )

    wrapped_function = deadlock_retry(max_retries=2)(mock_function)

    result = await wrapped_function()
    assert result == True, "Function should have succeeded on the third attempt"


if __name__ == "__main__":

    async def main():
        await test_deadlock_retry()
        await test_deadlock_retry_errored()
        await test_deadlock_retry_exhaustive()

    asyncio.run(main())

@@ -1,96 +0,0 @@
import modal
import os
from cognee.shared.logging_utils import get_logger
import asyncio
import cognee
import signal


from cognee.modules.search.types import SearchType

app = modal.App("cognee-runner")

image = (
    modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
    .add_local_file("pyproject.toml", remote_path="/root/pyproject.toml", copy=True)
    .add_local_file("poetry.lock", remote_path="/root/poetry.lock", copy=True)
    .env({"ENV": os.getenv("ENV"), "LLM_API_KEY": os.getenv("LLM_API_KEY")})
    .poetry_install_from_file(poetry_pyproject_toml="pyproject.toml")
    .pip_install("protobuf", "h2")
)


@app.function(image=image, max_containers=4)
async def entry(text: str, query: str):
    logger = get_logger()
    logger.info("Initializing Cognee")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await cognee.add(text)
    await cognee.cognify()
    search_results = await cognee.search(query_type=SearchType.GRAPH_COMPLETION, query_text=query)

    return {
        "text": text,
        "query": query,
        "answer": search_results[0] if search_results else None,
    }


@app.local_entrypoint()
async def main():
    logger = get_logger()
    text_queries = [
        {
            "text": "NASA's Artemis program aims to return humans to the Moon by 2026, focusing on sustainable exploration and preparing for future Mars missions.",
            "query": "When does NASA plan to return humans to the Moon under the Artemis program?",
        },
        {
            "text": "According to a 2022 UN report, global food waste amounts to approximately 931 million tons annually, with households contributing 61% of the total.",
            "query": "How much food waste do households contribute annually according to the 2022 UN report?",
        },
        {
            "text": "The 2021 census data revealed that Tokyo's population reached 14 million, reflecting a 2.1% increase compared to the previous census conducted in 2015.",
            "query": "What was Tokyo's population according to the 2021 census data?",
        },
        {
            "text": "A recent study published in the Journal of Nutrition found that consuming 30 grams of almonds daily can lower LDL cholesterol levels by 7% over a 12-week period.",
            "query": "How much can daily almond consumption lower LDL cholesterol according to the study?",
        },
        {
            "text": "Amazon's Prime membership grew to 200 million subscribers in 2023, marking a 10% increase from the previous year, driven by exclusive content and faster delivery options.",
            "query": "How many Prime members did Amazon have in 2023?",
        },
        {
            "text": "A new report by the International Energy Agency states that global renewable energy capacity increased by 295 gigawatts in 2022, primarily driven by solar and wind power expansion.",
            "query": "By how much did global renewable energy capacity increase in 2022 according to the report?",
        },
        {
            "text": "The World Health Organization reported in 2023 that the global life expectancy has risen to 73.4 years, an increase of 5.5 years since the year 2000.",
            "query": "What is the current global life expectancy according to the WHO's 2023 report?",
        },
        {
            "text": "The FIFA World Cup 2022 held in Qatar attracted a record-breaking audience of 5 billion people across various digital and traditional broadcasting platforms.",
            "query": "How many people watched the FIFA World Cup 2022?",
        },
        {
            "text": "The European Space Agency's JUICE mission, launched in 2023, aims to explore Jupiter's icy moons, including Ganymede, Europa, and Callisto, over the next decade.",
            "query": "Which moons is the JUICE mission set to explore?",
        },
        {
            "text": "According to a report by the International Labour Organization, the global unemployment rate in 2023 was estimated at 5.4%, reflecting a slight decrease compared to the previous year.",
            "query": "What was the global unemployment rate in 2023 according to the ILO?",
        },
    ]

    tasks = [entry.remote.aio(item["text"], item["query"]) for item in text_queries]

    results = await asyncio.gather(*tasks)

    logger.info("Final Results:")

    for result in results:
        logger.info(result)
        logger.info("----")

    os.kill(os.getpid(), signal.SIGTERM)