update kuze and neptune with build indices function
This commit is contained in:
parent
c144ff5995
commit
5d1ba3cf57
6 changed files with 93 additions and 30 deletions
|
|
@ -15,11 +15,13 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, LiteralString
|
||||||
|
|
||||||
import kuzu
|
import kuzu
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||||
|
from graphiti_core.helpers import semaphore_gather
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -174,3 +176,37 @@ class KuzuDriverSession(GraphDriverSession):
|
||||||
else:
|
else:
|
||||||
await self.driver.execute_query(query, **kwargs)
|
await self.driver.execute_query(query, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||||
|
if delete_existing:
|
||||||
|
await self.delete_all_indexes()
|
||||||
|
|
||||||
|
range_indices: list[LiteralString] = get_range_indices(self.provider)
|
||||||
|
|
||||||
|
# Skip creating fulltext indices if they already exist. Need to do this manually
|
||||||
|
# until Kuzu supports `IF NOT EXISTS` for indices.
|
||||||
|
result, _, _ = await self.execute_query('CALL SHOW_INDEXES() RETURN *;')
|
||||||
|
if len(result) > 0:
|
||||||
|
fulltext_indices = []
|
||||||
|
|
||||||
|
# Only load the `fts` extension if it's not already loaded, otherwise throw an error.
|
||||||
|
result, _, _ = await self.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
|
||||||
|
if len(result) == 0:
|
||||||
|
fulltext_indices.insert(
|
||||||
|
0,
|
||||||
|
"""
|
||||||
|
INSTALL fts;
|
||||||
|
LOAD fts;
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||||
|
|
||||||
|
await semaphore_gather(
|
||||||
|
*[
|
||||||
|
self.execute_query(
|
||||||
|
query,
|
||||||
|
)
|
||||||
|
for query in index_queries
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ class Neo4jDriver(GraphDriver):
|
||||||
for query in index_queries
|
for query in index_queries
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def health_check(self) -> None:
|
async def health_check(self) -> None:
|
||||||
"""Check Neo4j connectivity by running the driver's verify_connectivity method."""
|
"""Check Neo4j connectivity by running the driver's verify_connectivity method."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -18,13 +18,15 @@ import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from typing import Any
|
from typing import Any, LiteralString
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
||||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||||
|
from graphiti_core.helpers import semaphore_gather
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
DEFAULT_SIZE = 10
|
DEFAULT_SIZE = 10
|
||||||
|
|
@ -267,6 +269,23 @@ class NeptuneDriver(GraphDriver):
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||||
|
if delete_existing:
|
||||||
|
await self.delete_all_indexes()
|
||||||
|
|
||||||
|
range_indices: list[LiteralString] = get_range_indices(self.provider)
|
||||||
|
|
||||||
|
index_queries: list[LiteralString] = range_indices
|
||||||
|
|
||||||
|
await semaphore_gather(
|
||||||
|
*[
|
||||||
|
self.execute_query(
|
||||||
|
query,
|
||||||
|
)
|
||||||
|
for query in index_queries
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NeptuneDriverSession(GraphDriverSession):
|
class NeptuneDriverSession(GraphDriverSession):
|
||||||
provider = GraphProvider.NEPTUNE
|
provider = GraphProvider.NEPTUNE
|
||||||
|
|
|
||||||
|
|
@ -166,13 +166,17 @@ class BaseOpenAIClient(LLMClient):
|
||||||
except openai.RateLimitError as e:
|
except openai.RateLimitError as e:
|
||||||
raise RateLimitError from e
|
raise RateLimitError from e
|
||||||
except openai.AuthenticationError as e:
|
except openai.AuthenticationError as e:
|
||||||
logger.error(f'OpenAI Authentication Error: {e}. Please verify your API key is correct.')
|
logger.error(
|
||||||
|
f'OpenAI Authentication Error: {e}. Please verify your API key is correct.'
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Provide more context for connection errors
|
# Provide more context for connection errors
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
if 'Connection error' in error_msg or 'connection' in error_msg.lower():
|
if 'Connection error' in error_msg or 'connection' in error_msg.lower():
|
||||||
logger.error(f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}')
|
logger.error(
|
||||||
|
f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f'Error in generating LLM response: {e}')
|
logger.error(f'Error in generating LLM response: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,9 @@ class OpenAIClient(BaseOpenAIClient):
|
||||||
):
|
):
|
||||||
"""Create a structured completion using OpenAI's beta parse API."""
|
"""Create a structured completion using OpenAI's beta parse API."""
|
||||||
# Reasoning models (gpt-5 family) don't support temperature
|
# Reasoning models (gpt-5 family) don't support temperature
|
||||||
is_reasoning_model = model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
|
is_reasoning_model = (
|
||||||
|
model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
|
||||||
|
)
|
||||||
|
|
||||||
response = await self.client.responses.parse(
|
response = await self.client.responses.parse(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -100,7 +102,9 @@ class OpenAIClient(BaseOpenAIClient):
|
||||||
):
|
):
|
||||||
"""Create a regular completion with JSON format."""
|
"""Create a regular completion with JSON format."""
|
||||||
# Reasoning models (gpt-5 family) don't support temperature
|
# Reasoning models (gpt-5 family) don't support temperature
|
||||||
is_reasoning_model = model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
|
is_reasoning_model = (
|
||||||
|
model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
|
||||||
|
)
|
||||||
|
|
||||||
return await self.client.chat.completions.create(
|
return await self.client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
|
||||||
|
|
@ -245,35 +245,35 @@ class GraphitiService:
|
||||||
db_provider = self.config.database.provider
|
db_provider = self.config.database.provider
|
||||||
if db_provider.lower() == 'falkordb':
|
if db_provider.lower() == 'falkordb':
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"\n{'='*70}\n"
|
f'\n{"=" * 70}\n'
|
||||||
f"Database Connection Error: FalkorDB is not running\n"
|
f'Database Connection Error: FalkorDB is not running\n'
|
||||||
f"{'='*70}\n\n"
|
f'{"=" * 70}\n\n'
|
||||||
f"FalkorDB at {db_config['host']}:{db_config['port']} is not accessible.\n\n"
|
f'FalkorDB at {db_config["host"]}:{db_config["port"]} is not accessible.\n\n'
|
||||||
f"To start FalkorDB:\n"
|
f'To start FalkorDB:\n'
|
||||||
f" - Using Docker Compose: cd mcp_server && docker compose up\n"
|
f' - Using Docker Compose: cd mcp_server && docker compose up\n'
|
||||||
f" - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n"
|
f' - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n'
|
||||||
f"{'='*70}\n"
|
f'{"=" * 70}\n'
|
||||||
) from db_error
|
) from db_error
|
||||||
elif db_provider.lower() == 'neo4j':
|
elif db_provider.lower() == 'neo4j':
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"\n{'='*70}\n"
|
f'\n{"=" * 70}\n'
|
||||||
f"Database Connection Error: Neo4j is not running\n"
|
f'Database Connection Error: Neo4j is not running\n'
|
||||||
f"{'='*70}\n\n"
|
f'{"=" * 70}\n\n'
|
||||||
f"Neo4j at {db_config.get('uri', 'unknown')} is not accessible.\n\n"
|
f'Neo4j at {db_config.get("uri", "unknown")} is not accessible.\n\n'
|
||||||
f"To start Neo4j:\n"
|
f'To start Neo4j:\n'
|
||||||
f" - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n"
|
f' - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n'
|
||||||
f" - Or install Neo4j Desktop from: https://neo4j.com/download/\n"
|
f' - Or install Neo4j Desktop from: https://neo4j.com/download/\n'
|
||||||
f" - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n"
|
f' - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n'
|
||||||
f"{'='*70}\n"
|
f'{"=" * 70}\n'
|
||||||
) from db_error
|
) from db_error
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"\n{'='*70}\n"
|
f'\n{"=" * 70}\n'
|
||||||
f"Database Connection Error: {db_provider} is not running\n"
|
f'Database Connection Error: {db_provider} is not running\n'
|
||||||
f"{'='*70}\n\n"
|
f'{"=" * 70}\n\n'
|
||||||
f"{db_provider} at {db_config.get('uri', 'unknown')} is not accessible.\n\n"
|
f'{db_provider} at {db_config.get("uri", "unknown")} is not accessible.\n\n'
|
||||||
f"Please ensure {db_provider} is running and accessible.\n\n"
|
f'Please ensure {db_provider} is running and accessible.\n\n'
|
||||||
f"{'='*70}\n"
|
f'{"=" * 70}\n'
|
||||||
) from db_error
|
) from db_error
|
||||||
# Re-raise other errors
|
# Re-raise other errors
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue