This commit is contained in:
prestonrasmussen 2025-11-05 11:42:02 -05:00
parent 5d1ba3cf57
commit da1a417059
2 changed files with 35 additions and 36 deletions

View file

@ -20,7 +20,7 @@ from typing import Any, LiteralString
import kuzu
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.graph_queries import get_range_indices
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__)
@ -147,6 +147,39 @@ class KuzuDriver(GraphDriver):
conn.execute(SCHEMA_QUERIES)
conn.close()
async def build_indices_and_constraints(self, delete_existing: bool = False):
if delete_existing:
self.delete_all_indexes()
range_indices: list[LiteralString] = get_range_indices(self.provider)
# Skip creating fulltext indices if they already exist. Need to do this manually
# until Kuzu supports `IF NOT EXISTS` for indices.
result, _, _ = await self.execute_query('CALL SHOW_INDEXES() RETURN *;')
fulltext_indices = []
# Only load the `fts` extension if it's not already loaded, otherwise throw an error.
result, _, _ = await self.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
if len(result) == 0:
fulltext_indices.insert(
0,
"""
INSTALL fts;
LOAD fts;
""",
)
index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather(
*[
self.execute_query(
query,
)
for query in index_queries
]
)
class KuzuDriverSession(GraphDriverSession):
provider = GraphProvider.KUZU
@ -176,37 +209,3 @@ class KuzuDriverSession(GraphDriverSession):
else:
await self.driver.execute_query(query, **kwargs)
return None
async def build_indices_and_constraints(self, delete_existing: bool = False):
if delete_existing:
await self.delete_all_indexes()
range_indices: list[LiteralString] = get_range_indices(self.provider)
# Skip creating fulltext indices if they already exist. Need to do this manually
# until Kuzu supports `IF NOT EXISTS` for indices.
result, _, _ = await self.execute_query('CALL SHOW_INDEXES() RETURN *;')
if len(result) > 0:
fulltext_indices = []
# Only load the `fts` extension if it's not already loaded, otherwise throw an error.
result, _, _ = await self.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
if len(result) == 0:
fulltext_indices.insert(
0,
"""
INSTALL fts;
LOAD fts;
""",
)
index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather(
*[
self.execute_query(
query,
)
for query in index_queries
]
)

View file

@ -25,7 +25,7 @@ from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.graph_queries import get_range_indices
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__)