update kuze and neptune with build indices function

This commit is contained in:
prestonrasmussen 2025-11-05 11:35:41 -05:00
parent c144ff5995
commit 5d1ba3cf57
6 changed files with 93 additions and 30 deletions

View file

@ -15,11 +15,13 @@ limitations under the License.
""" """
import logging import logging
from typing import Any from typing import Any, LiteralString
import kuzu import kuzu
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -174,3 +176,37 @@ class KuzuDriverSession(GraphDriverSession):
else: else:
await self.driver.execute_query(query, **kwargs) await self.driver.execute_query(query, **kwargs)
return None return None
async def build_indices_and_constraints(self, delete_existing: bool = False):
if delete_existing:
await self.delete_all_indexes()
range_indices: list[LiteralString] = get_range_indices(self.provider)
# Skip creating fulltext indices if they already exist. Need to do this manually
# until Kuzu supports `IF NOT EXISTS` for indices.
result, _, _ = await self.execute_query('CALL SHOW_INDEXES() RETURN *;')
if len(result) > 0:
fulltext_indices = []
# Only load the `fts` extension if it's not already loaded, otherwise throw an error.
result, _, _ = await self.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
if len(result) == 0:
fulltext_indices.insert(
0,
"""
INSTALL fts;
LOAD fts;
""",
)
index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather(
*[
self.execute_query(
query,
)
for query in index_queries
]
)

View file

@ -106,7 +106,7 @@ class Neo4jDriver(GraphDriver):
for query in index_queries for query in index_queries
] ]
) )
async def health_check(self) -> None: async def health_check(self) -> None:
"""Check Neo4j connectivity by running the driver's verify_connectivity method.""" """Check Neo4j connectivity by running the driver's verify_connectivity method."""
try: try:

View file

@ -18,13 +18,15 @@ import asyncio
import datetime import datetime
import logging import logging
from collections.abc import Coroutine from collections.abc import Coroutine
from typing import Any from typing import Any, LiteralString
import boto3 import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10 DEFAULT_SIZE = 10
@ -267,6 +269,23 @@ class NeptuneDriver(GraphDriver):
return 0 return 0
async def build_indices_and_constraints(self, delete_existing: bool = False):
if delete_existing:
await self.delete_all_indexes()
range_indices: list[LiteralString] = get_range_indices(self.provider)
index_queries: list[LiteralString] = range_indices
await semaphore_gather(
*[
self.execute_query(
query,
)
for query in index_queries
]
)
class NeptuneDriverSession(GraphDriverSession): class NeptuneDriverSession(GraphDriverSession):
provider = GraphProvider.NEPTUNE provider = GraphProvider.NEPTUNE

View file

@ -166,13 +166,17 @@ class BaseOpenAIClient(LLMClient):
except openai.RateLimitError as e: except openai.RateLimitError as e:
raise RateLimitError from e raise RateLimitError from e
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
logger.error(f'OpenAI Authentication Error: {e}. Please verify your API key is correct.') logger.error(
f'OpenAI Authentication Error: {e}. Please verify your API key is correct.'
)
raise raise
except Exception as e: except Exception as e:
# Provide more context for connection errors # Provide more context for connection errors
error_msg = str(e) error_msg = str(e)
if 'Connection error' in error_msg or 'connection' in error_msg.lower(): if 'Connection error' in error_msg or 'connection' in error_msg.lower():
logger.error(f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}') logger.error(
f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}'
)
else: else:
logger.error(f'Error in generating LLM response: {e}') logger.error(f'Error in generating LLM response: {e}')
raise raise

View file

@ -74,7 +74,9 @@ class OpenAIClient(BaseOpenAIClient):
): ):
"""Create a structured completion using OpenAI's beta parse API.""" """Create a structured completion using OpenAI's beta parse API."""
# Reasoning models (gpt-5 family) don't support temperature # Reasoning models (gpt-5 family) don't support temperature
is_reasoning_model = model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3') is_reasoning_model = (
model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
)
response = await self.client.responses.parse( response = await self.client.responses.parse(
model=model, model=model,
@ -100,7 +102,9 @@ class OpenAIClient(BaseOpenAIClient):
): ):
"""Create a regular completion with JSON format.""" """Create a regular completion with JSON format."""
# Reasoning models (gpt-5 family) don't support temperature # Reasoning models (gpt-5 family) don't support temperature
is_reasoning_model = model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3') is_reasoning_model = (
model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
)
return await self.client.chat.completions.create( return await self.client.chat.completions.create(
model=model, model=model,

View file

@ -245,35 +245,35 @@ class GraphitiService:
db_provider = self.config.database.provider db_provider = self.config.database.provider
if db_provider.lower() == 'falkordb': if db_provider.lower() == 'falkordb':
raise RuntimeError( raise RuntimeError(
f"\n{'='*70}\n" f'\n{"=" * 70}\n'
f"Database Connection Error: FalkorDB is not running\n" f'Database Connection Error: FalkorDB is not running\n'
f"{'='*70}\n\n" f'{"=" * 70}\n\n'
f"FalkorDB at {db_config['host']}:{db_config['port']} is not accessible.\n\n" f'FalkorDB at {db_config["host"]}:{db_config["port"]} is not accessible.\n\n'
f"To start FalkorDB:\n" f'To start FalkorDB:\n'
f" - Using Docker Compose: cd mcp_server && docker compose up\n" f' - Using Docker Compose: cd mcp_server && docker compose up\n'
f" - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n" f' - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n'
f"{'='*70}\n" f'{"=" * 70}\n'
) from db_error ) from db_error
elif db_provider.lower() == 'neo4j': elif db_provider.lower() == 'neo4j':
raise RuntimeError( raise RuntimeError(
f"\n{'='*70}\n" f'\n{"=" * 70}\n'
f"Database Connection Error: Neo4j is not running\n" f'Database Connection Error: Neo4j is not running\n'
f"{'='*70}\n\n" f'{"=" * 70}\n\n'
f"Neo4j at {db_config.get('uri', 'unknown')} is not accessible.\n\n" f'Neo4j at {db_config.get("uri", "unknown")} is not accessible.\n\n'
f"To start Neo4j:\n" f'To start Neo4j:\n'
f" - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n" f' - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n'
f" - Or install Neo4j Desktop from: https://neo4j.com/download/\n" f' - Or install Neo4j Desktop from: https://neo4j.com/download/\n'
f" - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n" f' - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n'
f"{'='*70}\n" f'{"=" * 70}\n'
) from db_error ) from db_error
else: else:
raise RuntimeError( raise RuntimeError(
f"\n{'='*70}\n" f'\n{"=" * 70}\n'
f"Database Connection Error: {db_provider} is not running\n" f'Database Connection Error: {db_provider} is not running\n'
f"{'='*70}\n\n" f'{"=" * 70}\n\n'
f"{db_provider} at {db_config.get('uri', 'unknown')} is not accessible.\n\n" f'{db_provider} at {db_config.get("uri", "unknown")} is not accessible.\n\n'
f"Please ensure {db_provider} is running and accessible.\n\n" f'Please ensure {db_provider} is running and accessible.\n\n'
f"{'='*70}\n" f'{"=" * 70}\n'
) from db_error ) from db_error
# Re-raise other errors # Re-raise other errors
raise raise