This commit is contained in:
prestonrasmussen 2025-09-08 10:14:25 -04:00
parent b036d45329
commit 6b57979869
6 changed files with 84 additions and 66 deletions

View file

@ -21,7 +21,7 @@ from abc import ABC, abstractmethod
from collections.abc import Coroutine
from datetime import datetime
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING, Any
try:
from opensearchpy import OpenSearch, helpers
@ -32,6 +32,9 @@ except ImportError:
helpers = None
_HAS_OPENSEARCH = False
if TYPE_CHECKING:
from opensearchpy import OpenSearch, helpers
logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10

View file

@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
class FalkorDriverSession(GraphDriverSession):
provider = GraphProvider.FALKORDB
aoss_client: None
aoss_client: None = None
def __init__(self, graph: FalkorGraph):
self.graph = graph

View file

@ -92,7 +92,7 @@ SCHEMA_QUERIES = """
class KuzuDriver(GraphDriver):
provider: GraphProvider = GraphProvider.KUZU
aoss_client: None
aoss_client: None = None
def __init__(
self,

View file

@ -22,6 +22,7 @@ from neo4j import AsyncGraphDatabase, EagerResult
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__)
@ -98,9 +99,14 @@ class Neo4jDriver(GraphDriver):
async def close(self) -> None:
return await self.client.close()
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
def delete_all_indexes(self) -> Coroutine:
if self.aoss_client:
return self.delete_aoss_indices()
return semaphore_gather(
self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
),
self.delete_aoss_indices(),
)
return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
)

View file

@ -232,6 +232,10 @@ class NeptuneDriver(GraphDriver):
for index in neptune_aoss_indices:
index_name = index['index_name']
client = self.aoss_client
if not client:
raise ValueError(
'You must provide an AOSS endpoint to create an OpenSearch driver.'
)
if not client.indices.exists(index=index_name):
client.indices.create(index=index_name, body=index['body'])

View file

@ -209,11 +209,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values
query = (
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
+ filter_query
+ """
AND id(e)=id
@ -246,10 +246,11 @@ async def edge_fulltext_search(
else:
return []
elif driver.aoss_client:
filters = build_aoss_edge_filters(group_ids, search_filter)
route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=group_ids[0],
routing=route,
_source=['uuid'],
query={
'bool': {
@ -343,8 +344,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -403,10 +404,11 @@ async def edge_similarity_search(
else:
return []
elif driver.aoss_client:
filters = build_aoss_edge_filters(group_ids, search_filter)
route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=group_ids[0],
routing=route,
_source=['uuid'],
knn={
'field': 'fact_embedding',
@ -620,11 +622,11 @@ async def node_fulltext_search(
# Match the edge ides and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -642,10 +644,11 @@ async def node_fulltext_search(
else:
return []
elif driver.aoss_client:
filters = build_aoss_node_filters(group_ids, search_filter)
route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
'entities',
routing=group_ids[0],
routing=route,
_source=['uuid'],
query={
'bool': {
@ -731,8 +734,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -761,11 +764,11 @@ async def node_similarity_search(
# Match the edge ides and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -784,10 +787,11 @@ async def node_similarity_search(
else:
return []
elif driver.aoss_client:
filters = build_aoss_node_filters(group_ids, search_filter)
route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entities',
routing=group_ids[0],
routing=route,
_source=['uuid'],
knn={
'field': 'fact_embedding',
@ -810,8 +814,8 @@ async def node_similarity_search(
else:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
WITH n, """
@ -983,9 +987,10 @@ async def episode_fulltext_search(
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
res = driver.aoss_client.search(
'episodes',
routing=group_ids[0],
routing=route,
_source=['uuid'],
query={
'bool': {
@ -1142,8 +1147,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Community)
"""
MATCH (n:Community)
"""
+ group_filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1202,8 +1207,8 @@ async def community_similarity_search(
query = (
"""
MATCH (c:Community)
"""
MATCH (c:Community)
"""
+ group_filter_query
+ """
WITH c,
@ -1345,9 +1350,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1392,9 +1397,9 @@ async def get_relevant_nodes(
else:
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1483,9 +1488,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge
@ -1555,9 +1560,9 @@ async def get_relevant_edges(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, n, m, """
@ -1593,9 +1598,9 @@ async def get_relevant_edges(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, """
@ -1668,10 +1673,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH e, edge
@ -1741,10 +1746,10 @@ async def get_edge_invalidation_candidates(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
+ filter_query
+ """
WITH edge, e, n, m, """
@ -1780,10 +1785,10 @@ async def get_edge_invalidation_candidates(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH edge, e, """