This commit is contained in:
prestonrasmussen 2025-09-07 23:36:22 -04:00
parent ef8bd41e1a
commit 13fc9cf1e4
9 changed files with 114 additions and 89 deletions

View file

@ -217,15 +217,7 @@ class GraphDriver(ABC):
client.indices.delete(index=index_name) client.indices.delete(index=index_name)
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]: def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
for index in aoss_indices: pass
if name.lower() == index['index_name']:
index['query']['query']['multi_match']['query'] = query_text
query = {'size': limit, 'query': index['query']}
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
return resp
return {}
from opensearchpy import helpers
def save_to_aoss(self, name: str, data: list[dict]) -> int: def save_to_aoss(self, name: str, data: list[dict]) -> int:
for index in aoss_indices: for index in aoss_indices:

View file

@ -18,15 +18,25 @@ import logging
from collections.abc import Coroutine from collections.abc import Coroutine
from typing import Any from typing import Any
import boto3
from neo4j import AsyncGraphDatabase, EagerResult from neo4j import AsyncGraphDatabase, EagerResult
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
from typing_extensions import LiteralString from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
import boto3
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
_HAS_OPENSEARCH = True
except ImportError:
boto3 = None
OpenSearch = None
Urllib3AWSV4SignerAuth = None
Urllib3HttpConnection = None
_HAS_OPENSEARCH = False
class Neo4jDriver(GraphDriver): class Neo4jDriver(GraphDriver):
provider = GraphProvider.NEO4J provider = GraphProvider.NEO4J
@ -49,17 +59,21 @@ class Neo4jDriver(GraphDriver):
self.aoss_client = None self.aoss_client = None
if aoss_host and aoss_port: if aoss_host and aoss_port:
session = boto3.Session() try:
self.aoss_client = OpenSearch( session = boto3.Session()
hosts=[{'host': aoss_host, 'port': aoss_port}], self.aoss_client = OpenSearch(
http_auth=Urllib3AWSV4SignerAuth( hosts=[{'host': aoss_host, 'port': aoss_port}],
session.get_credentials(), session.region_name, 'aoss' http_auth=Urllib3AWSV4SignerAuth(
), session.get_credentials(), session.region_name, 'aoss'
use_ssl=True, ),
verify_certs=True, use_ssl=True,
connection_class=Urllib3HttpConnection, verify_certs=True,
pool_maxsize=20, connection_class=Urllib3HttpConnection,
) pool_maxsize=20,
)
except Exception as e:
logger.warning(f'Failed to initialize OpenSearch client: {e}')
self.aoss_client = None
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult: async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
# Check if database_ is provided in kwargs. # Check if database_ is provided in kwargs.
@ -86,7 +100,7 @@ class Neo4jDriver(GraphDriver):
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]: def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
if self.aoss_client: if self.aoss_client:
self.delete_all_indexes_impl() return self.delete_aoss_indices()
return self.client.execute_query( return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name', 'CALL db.indexes() YIELD name DROP INDEX name',
) )

View file

@ -307,7 +307,7 @@ class EntityEdge(Edge):
if driver.provider == GraphProvider.KUZU: if driver.provider == GraphProvider.KUZU:
edge_data['attributes'] = json.dumps(self.attributes) edge_data['attributes'] = json.dumps(self.attributes)
result = await driver.execute_query( result = await driver.execute_query(
get_entity_edge_save_query(driver.provider), get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
**edge_data, **edge_data,
) )
else: else:

View file

@ -482,7 +482,7 @@ class EntityNode(Node):
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query( result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels), get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
entity_data=entity_data, entity_data=entity_data,
) )

View file

@ -237,7 +237,7 @@ def edge_search_filter_query_constructor(
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]: def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
filters = [{'term': {'group_id': group_ids}}] filters = [{'terms': {'group_id': group_ids}}]
if search_filters.node_labels: if search_filters.node_labels:
filters.append({'terms': {'node_labels': search_filters.node_labels}}) filters.append({'terms': {'node_labels': search_filters.node_labels}})
@ -246,7 +246,7 @@ def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters)
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]: def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
filters = [{'term': {'group_id': group_ids}}] filters = [{'terms': {'group_id': group_ids}}]
if search_filters.edge_types: if search_filters.edge_types:
filters.append({'terms': {'edge_types': search_filters.edge_types}}) filters.append({'terms': {'edge_types': search_filters.edge_types}})

View file

@ -209,11 +209,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as id UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
AND id(e)=id AND id(e)=id
""" """
+ filter_query + filter_query
+ """ + """
AND id(e)=id AND id(e)=id
@ -265,7 +265,8 @@ async def edge_fulltext_search(
# Get edges # Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True) entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
else: else:
return [] return []
else: else:
@ -342,8 +343,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -423,7 +424,8 @@ async def edge_similarity_search(
# Get edges # Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True) entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
else: else:
query = ( query = (
@ -618,11 +620,11 @@ async def node_fulltext_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.uuid=i.id WHERE n.uuid=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -669,7 +671,8 @@ async def node_fulltext_search(
# Get nodes # Get nodes
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
return entities.sort(key=lambda e: input_uuids.get(e, 0), reverse=True) entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entities
else: else:
return [] return []
else: else:
@ -728,8 +731,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -758,11 +761,11 @@ async def node_similarity_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE id(n)=i.id WHERE id(n)=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -801,13 +804,14 @@ async def node_similarity_search(
input_uuids[r['_source']['uuid']] = r['_score'] input_uuids[r['_source']['uuid']] = r['_score']
# Get nodes # Get nodes
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True) entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_nodes
else: else:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
WITH n, """ WITH n, """
@ -985,7 +989,7 @@ async def episode_fulltext_search(
_source=['uuid'], _source=['uuid'],
query={ query={
'bool': { 'bool': {
'filter': [{'term': {'group_id': group_ids}}], 'filter': [{'terms': {'group_id': group_ids}}],
'must': [ 'must': [
{ {
'multi_match': { 'multi_match': {
@ -1161,8 +1165,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Community) MATCH (n:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1221,8 +1225,8 @@ async def community_similarity_search(
query = ( query = (
""" """
MATCH (c:Community) MATCH (c:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
WITH c, WITH c,
@ -1364,9 +1368,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1411,9 +1415,9 @@ async def get_relevant_nodes(
else: else:
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1502,9 +1506,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1574,9 +1578,9 @@ async def get_relevant_edges(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, n, m, """ WITH e, edge, n, m, """
@ -1612,9 +1616,9 @@ async def get_relevant_edges(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, """ WITH e, edge, """
@ -1687,10 +1691,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1760,10 +1764,10 @@ async def get_edge_invalidation_candidates(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, n, m, """ WITH edge, e, n, m, """
@ -1799,10 +1803,10 @@ async def get_edge_invalidation_candidates(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, """ WITH edge, e, """

View file

@ -187,12 +187,20 @@ async def add_nodes_and_edges_bulk_tx(
await tx.run(episodic_edge_query, **edge.model_dump()) await tx.run(episodic_edge_query, **edge.model_dump())
else: else:
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes) await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes) await tx.run(
get_entity_node_save_bulk_query(driver.provider, nodes),
nodes=nodes,
has_aoss=bool(driver.aoss_client),
)
await tx.run( await tx.run(
get_episodic_edge_save_bulk_query(driver.provider), get_episodic_edge_save_bulk_query(driver.provider),
episodic_edges=[edge.model_dump() for edge in episodic_edges], episodic_edges=[edge.model_dump() for edge in episodic_edges],
) )
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges) await tx.run(
get_entity_edge_save_bulk_query(driver.provider),
entity_edges=edges,
has_aoss=bool(driver.aoss_client),
)
if driver.aoss_client: if driver.aoss_client:
driver.save_to_aoss('episodes', episodes) driver.save_to_aoss('episodes', episodes)

View file

@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"]
kuzu = ["kuzu>=0.11.2"] kuzu = ["kuzu>=0.11.2"]
falkordb = ["falkordb>=1.1.2,<2.0.0"] falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"] voyageai = ["voyageai>=0.2.3"]
neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"]
sentence-transformers = ["sentence-transformers>=3.2.1"] sentence-transformers = ["sentence-transformers>=3.2.1"]
neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"] neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
dev = [ dev = [

10
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.21.0" version = "0.21.0rc1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },
@ -835,6 +835,10 @@ groq = [
kuzu = [ kuzu = [
{ name = "kuzu" }, { name = "kuzu" },
] ]
neo4j-opensearch = [
{ name = "boto3" },
{ name = "opensearch-py" },
]
neptune = [ neptune = [
{ name = "boto3" }, { name = "boto3" },
{ name = "langchain-aws" }, { name = "langchain-aws" },
@ -851,6 +855,7 @@ voyageai = [
requires-dist = [ requires-dist = [
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" }, { name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" }, { name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
{ name = "diskcache", specifier = ">=5.6.3" }, { name = "diskcache", specifier = ">=5.6.3" },
{ name = "diskcache-stubs", marker = "extra == 'dev'", specifier = ">=5.6.3.6.20240818" }, { name = "diskcache-stubs", marker = "extra == 'dev'", specifier = ">=5.6.3.6.20240818" },
@ -872,6 +877,7 @@ requires-dist = [
{ name = "neo4j", specifier = ">=5.26.0" }, { name = "neo4j", specifier = ">=5.26.0" },
{ name = "numpy", specifier = ">=1.0.0" }, { name = "numpy", specifier = ">=1.0.0" },
{ name = "openai", specifier = ">=1.91.0" }, { name = "openai", specifier = ">=1.91.0" },
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" }, { name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
{ name = "posthog", specifier = ">=3.0.0" }, { name = "posthog", specifier = ">=3.0.0" },
{ name = "pydantic", specifier = ">=2.11.5" }, { name = "pydantic", specifier = ">=2.11.5" },
@ -888,7 +894,7 @@ requires-dist = [
{ name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" }, { name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" },
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" }, { name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" },
] ]
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"] provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "neo4j-opensearch", "sentence-transformers", "neptune", "dev"]
[[package]] [[package]]
name = "groq" name = "groq"