This commit is contained in:
prestonrasmussen 2025-09-07 23:36:22 -04:00
parent ef8bd41e1a
commit 13fc9cf1e4
9 changed files with 114 additions and 89 deletions

View file

@ -217,15 +217,7 @@ class GraphDriver(ABC):
client.indices.delete(index=index_name)
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
for index in aoss_indices:
if name.lower() == index['index_name']:
index['query']['query']['multi_match']['query'] = query_text
query = {'size': limit, 'query': index['query']}
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
return resp
return {}
from opensearchpy import helpers
pass
def save_to_aoss(self, name: str, data: list[dict]) -> int:
for index in aoss_indices:

View file

@ -18,15 +18,25 @@ import logging
from collections.abc import Coroutine
from typing import Any
import boto3
from neo4j import AsyncGraphDatabase, EagerResult
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
logger = logging.getLogger(__name__)
try:
import boto3
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
_HAS_OPENSEARCH = True
except ImportError:
boto3 = None
OpenSearch = None
Urllib3AWSV4SignerAuth = None
Urllib3HttpConnection = None
_HAS_OPENSEARCH = False
class Neo4jDriver(GraphDriver):
provider = GraphProvider.NEO4J
@ -49,17 +59,21 @@ class Neo4jDriver(GraphDriver):
self.aoss_client = None
if aoss_host and aoss_port:
session = boto3.Session()
self.aoss_client = OpenSearch(
hosts=[{'host': aoss_host, 'port': aoss_port}],
http_auth=Urllib3AWSV4SignerAuth(
session.get_credentials(), session.region_name, 'aoss'
),
use_ssl=True,
verify_certs=True,
connection_class=Urllib3HttpConnection,
pool_maxsize=20,
)
try:
session = boto3.Session()
self.aoss_client = OpenSearch(
hosts=[{'host': aoss_host, 'port': aoss_port}],
http_auth=Urllib3AWSV4SignerAuth(
session.get_credentials(), session.region_name, 'aoss'
),
use_ssl=True,
verify_certs=True,
connection_class=Urllib3HttpConnection,
pool_maxsize=20,
)
except Exception as e:
logger.warning(f'Failed to initialize OpenSearch client: {e}')
self.aoss_client = None
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
# Check if database_ is provided in kwargs.
@ -86,7 +100,7 @@ class Neo4jDriver(GraphDriver):
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
if self.aoss_client:
self.delete_all_indexes_impl()
return self.delete_aoss_indices()
return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
)

View file

@ -307,7 +307,7 @@ class EntityEdge(Edge):
if driver.provider == GraphProvider.KUZU:
edge_data['attributes'] = json.dumps(self.attributes)
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
**edge_data,
)
else:

View file

@ -482,7 +482,7 @@ class EntityNode(Node):
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels),
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
entity_data=entity_data,
)

View file

@ -237,7 +237,7 @@ def edge_search_filter_query_constructor(
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
filters = [{'term': {'group_id': group_ids}}]
filters = [{'terms': {'group_id': group_ids}}]
if search_filters.node_labels:
filters.append({'terms': {'node_labels': search_filters.node_labels}})
@ -246,7 +246,7 @@ def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters)
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
filters = [{'term': {'group_id': group_ids}}]
filters = [{'terms': {'group_id': group_ids}}]
if search_filters.edge_types:
filters.append({'terms': {'edge_types': search_filters.edge_types}})

View file

@ -209,11 +209,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values
query = (
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
+ filter_query
+ """
AND id(e)=id
@ -265,7 +265,8 @@ async def edge_fulltext_search(
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
else:
return []
else:
@ -342,8 +343,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -423,7 +424,8 @@ async def edge_similarity_search(
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
else:
query = (
@ -618,11 +620,11 @@ async def node_fulltext_search(
# Match the edge ides and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -669,7 +671,8 @@ async def node_fulltext_search(
# Get nodes
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
return entities.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entities
else:
return []
else:
@ -728,8 +731,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -758,11 +761,11 @@ async def node_similarity_search(
# Match the edge ides and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -801,13 +804,14 @@ async def node_similarity_search(
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_nodes
else:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
WITH n, """
@ -985,7 +989,7 @@ async def episode_fulltext_search(
_source=['uuid'],
query={
'bool': {
'filter': [{'term': {'group_id': group_ids}}],
'filter': [{'terms': {'group_id': group_ids}}],
'must': [
{
'multi_match': {
@ -1161,8 +1165,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Community)
"""
MATCH (n:Community)
"""
+ group_filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1221,8 +1225,8 @@ async def community_similarity_search(
query = (
"""
MATCH (c:Community)
"""
MATCH (c:Community)
"""
+ group_filter_query
+ """
WITH c,
@ -1364,9 +1368,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1411,9 +1415,9 @@ async def get_relevant_nodes(
else:
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1502,9 +1506,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge
@ -1574,9 +1578,9 @@ async def get_relevant_edges(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, n, m, """
@ -1612,9 +1616,9 @@ async def get_relevant_edges(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, """
@ -1687,10 +1691,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH e, edge
@ -1760,10 +1764,10 @@ async def get_edge_invalidation_candidates(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
+ filter_query
+ """
WITH edge, e, n, m, """
@ -1799,10 +1803,10 @@ async def get_edge_invalidation_candidates(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH edge, e, """

View file

@ -187,12 +187,20 @@ async def add_nodes_and_edges_bulk_tx(
await tx.run(episodic_edge_query, **edge.model_dump())
else:
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
await tx.run(
get_entity_node_save_bulk_query(driver.provider, nodes),
nodes=nodes,
has_aoss=bool(driver.aoss_client),
)
await tx.run(
get_episodic_edge_save_bulk_query(driver.provider),
episodic_edges=[edge.model_dump() for edge in episodic_edges],
)
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
await tx.run(
get_entity_edge_save_bulk_query(driver.provider),
entity_edges=edges,
has_aoss=bool(driver.aoss_client),
)
if driver.aoss_client:
driver.save_to_aoss('episodes', episodes)

View file

@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"]
kuzu = ["kuzu>=0.11.2"]
falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"]
neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"]
sentence-transformers = ["sentence-transformers>=3.2.1"]
neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
dev = [

10
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.21.0"
version = "0.21.0rc1"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },
@ -835,6 +835,10 @@ groq = [
kuzu = [
{ name = "kuzu" },
]
neo4j-opensearch = [
{ name = "boto3" },
{ name = "opensearch-py" },
]
neptune = [
{ name = "boto3" },
{ name = "langchain-aws" },
@ -851,6 +855,7 @@ voyageai = [
requires-dist = [
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
{ name = "diskcache", specifier = ">=5.6.3" },
{ name = "diskcache-stubs", marker = "extra == 'dev'", specifier = ">=5.6.3.6.20240818" },
@ -872,6 +877,7 @@ requires-dist = [
{ name = "neo4j", specifier = ">=5.26.0" },
{ name = "numpy", specifier = ">=1.0.0" },
{ name = "openai", specifier = ">=1.91.0" },
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
{ name = "posthog", specifier = ">=3.0.0" },
{ name = "pydantic", specifier = ">=2.11.5" },
@ -888,7 +894,7 @@ requires-dist = [
{ name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" },
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" },
]
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "neo4j-opensearch", "sentence-transformers", "neptune", "dev"]
[[package]]
name = "groq"