diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 8dfc35f9..6fb822a8 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -217,15 +217,7 @@ class GraphDriver(ABC): client.indices.delete(index=index_name) def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]: - for index in aoss_indices: - if name.lower() == index['index_name']: - index['query']['query']['multi_match']['query'] = query_text - query = {'size': limit, 'query': index['query']} - resp = self.aoss_client.search(body=query['query'], index=index['index_name']) - return resp - return {} - - from opensearchpy import helpers + pass def save_to_aoss(self, name: str, data: list[dict]) -> int: for index in aoss_indices: diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index c59eefb3..066866bf 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -18,15 +18,25 @@ import logging from collections.abc import Coroutine from typing import Any -import boto3 from neo4j import AsyncGraphDatabase, EagerResult -from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider logger = logging.getLogger(__name__) +try: + import boto3 + from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection + + _HAS_OPENSEARCH = True +except ImportError: + boto3 = None + OpenSearch = None + Urllib3AWSV4SignerAuth = None + Urllib3HttpConnection = None + _HAS_OPENSEARCH = False + class Neo4jDriver(GraphDriver): provider = GraphProvider.NEO4J @@ -49,17 +59,21 @@ class Neo4jDriver(GraphDriver): self.aoss_client = None if aoss_host and aoss_port: - session = boto3.Session() - self.aoss_client = OpenSearch( - hosts=[{'host': aoss_host, 'port': aoss_port}], - http_auth=Urllib3AWSV4SignerAuth( - session.get_credentials(), session.region_name, 'aoss' - ), - use_ssl=True, - verify_certs=True, - connection_class=Urllib3HttpConnection, - pool_maxsize=20, - ) + try: + session = boto3.Session() + self.aoss_client = OpenSearch( + hosts=[{'host': aoss_host, 'port': aoss_port}], + http_auth=Urllib3AWSV4SignerAuth( + session.get_credentials(), session.region_name, 'aoss' + ), + use_ssl=True, + verify_certs=True, + connection_class=Urllib3HttpConnection, + pool_maxsize=20, + ) + except Exception as e: + logger.warning(f'Failed to initialize OpenSearch client: {e}') + self.aoss_client = None async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult: # Check if database_ is provided in kwargs. @@ -86,7 +100,7 @@ class Neo4jDriver(GraphDriver): def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]: if self.aoss_client: - self.delete_all_indexes_impl() + return self.delete_aoss_indices() return self.client.execute_query( 'CALL db.indexes() YIELD name DROP INDEX name', ) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index fede9b27..e1c8e44e 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -307,7 +307,7 @@ class EntityEdge(Edge): if driver.provider == GraphProvider.KUZU: edge_data['attributes'] = json.dumps(self.attributes) result = await driver.execute_query( - get_entity_edge_save_query(driver.provider), + get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)), **edge_data, ) else: diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 6e83acd0..3adef419 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -482,7 +482,7 @@ class EntityNode(Node): driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( - get_entity_node_save_query(driver.provider, labels), + get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), entity_data=entity_data, ) diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index e354a1a3..05892810 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -237,7 +237,7 @@ def edge_search_filter_query_constructor( def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]: - filters = [{'term': {'group_id': group_ids}}] + filters = [{'terms': {'group_id': group_ids}}] if search_filters.node_labels: filters.append({'terms': {'node_labels': search_filters.node_labels}}) @@ -246,7 +246,7 @@ def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]: - filters = [{'term': {'group_id': group_ids}}] + filters = [{'terms': {'group_id': group_ids}}] if search_filters.edge_types: filters.append({'terms': {'edge_types': search_filters.edge_types}}) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index b2a0c5f2..07c13dde 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -209,11 +209,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -265,7 +265,8 @@ async def edge_fulltext_search( # Get edges entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) - return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True) + entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + return entity_edges else: return [] else: @@ -342,8 +343,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -423,7 +424,8 @@ async def edge_similarity_search( # Get edges entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) - return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True) + entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + return entity_edges else: query = ( @@ -618,11 +620,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -669,7 +671,8 @@ async def node_fulltext_search( # Get nodes entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) - return entities.sort(key=lambda e: input_uuids.get(e, 0), reverse=True) + entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + return entities else: return [] else: @@ -728,8 +731,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -758,11 +761,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -801,13 +804,14 @@ async def node_similarity_search( input_uuids[r['_source']['uuid']] = r['_score'] # Get edges - entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) - return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True) + entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) + entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + return entity_nodes else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -985,7 +989,7 @@ async def episode_fulltext_search( _source=['uuid'], query={ 'bool': { - 'filter': [{'term': {'group_id': group_ids}}], + 'filter': [{'terms': {'group_id': group_ids}}], 'must': [ { 'multi_match': { @@ -1161,8 +1165,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1221,8 +1225,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1364,9 +1368,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1411,9 +1415,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1502,9 +1506,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1574,9 +1578,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1612,9 +1616,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1687,10 +1691,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1760,10 +1764,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1799,10 +1803,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index bc771c43..ff76a930 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -187,12 +187,20 @@ async def add_nodes_and_edges_bulk_tx( await tx.run(episodic_edge_query, **edge.model_dump()) else: await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes) - await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes) + await tx.run( + get_entity_node_save_bulk_query(driver.provider, nodes), + nodes=nodes, + has_aoss=bool(driver.aoss_client), + ) await tx.run( get_episodic_edge_save_bulk_query(driver.provider), episodic_edges=[edge.model_dump() for edge in episodic_edges], ) - await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges) + await tx.run( + get_entity_edge_save_bulk_query(driver.provider), + entity_edges=edges, + has_aoss=bool(driver.aoss_client), + ) if driver.aoss_client: driver.save_to_aoss('episodes', episodes) diff --git a/pyproject.toml b/pyproject.toml index 2c887082..f4c8821b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"] kuzu = ["kuzu>=0.11.2"] falkordb = ["falkordb>=1.1.2,<2.0.0"] voyageai = ["voyageai>=0.2.3"] +neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"] sentence-transformers = ["sentence-transformers>=3.2.1"] neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"] dev = [ diff --git a/uv.lock b/uv.lock index 6abdaaf9..8228d4a8 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.21.0" +version = "0.21.0rc1" source = { editable = "." } dependencies = [ { name = "diskcache" }, @@ -835,6 +835,10 @@ groq = [ kuzu = [ { name = "kuzu" }, ] +neo4j-opensearch = [ + { name = "boto3" }, + { name = "opensearch-py" }, +] neptune = [ { name = "boto3" }, { name = "langchain-aws" }, @@ -851,6 +855,7 @@ voyageai = [ requires-dist = [ { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" }, { name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" }, + { name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" }, { name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" }, { name = "diskcache", specifier = ">=5.6.3" }, { name = "diskcache-stubs", marker = "extra == 'dev'", specifier = ">=5.6.3.6.20240818" }, @@ -872,6 +877,7 @@ requires-dist = [ { name = "neo4j", specifier = ">=5.26.0" }, { name = "numpy", specifier = ">=1.0.0" }, { name = "openai", specifier = ">=1.91.0" }, + { name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" }, { name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" }, { name = "posthog", specifier = ">=3.0.0" }, { name = "pydantic", specifier = ">=2.11.5" }, @@ -888,7 +894,7 @@ requires-dist = [ { name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" }, { name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" }, ] -provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"] +provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "neo4j-opensearch", "sentence-transformers", "neptune", "dev"] [[package]] name = "groq"