diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 670a7426..40ce2540 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -14,15 +14,30 @@ See the License for the specific language governing permissions and limitations under the License. """ +import asyncio import copy import logging from abc import ABC, abstractmethod from collections.abc import Coroutine +from datetime import datetime from enum import Enum from typing import Any +from graphiti_core.embedder.client import EMBEDDING_DIM + +try: + from opensearchpy import OpenSearch, helpers + + _HAS_OPENSEARCH = True +except ImportError: + OpenSearch = None + helpers = None + _HAS_OPENSEARCH = False + logger = logging.getLogger(__name__) +DEFAULT_SIZE = 10 + class GraphProvider(Enum): NEO4J = 'neo4j' @@ -31,6 +46,93 @@ class GraphProvider(Enum): NEPTUNE = 'neptune' +aoss_indices = [ + { + 'index_name': 'entities', + 'body': { + 'mappings': { + 'properties': { + 'uuid': {'type': 'keyword'}, + 'name': {'type': 'text'}, + 'summary': {'type': 'text'}, + 'group_id': {'type': 'text'}, + 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'name_embedding': { + 'type': 'knn_vector', + 'dims': EMBEDDING_DIM, + 'index': True, + 'similarity': 'cosine', + 'method': { + 'engine': 'faiss', + 'space_type': 'cosinesimil', + 'name': 'hnsw', + 'parameters': {'ef_construction': 128, 'm': 16}, + }, + }, + } + } + }, + }, + { + 'index_name': 'communities', + 'body': { + 'mappings': { + 'properties': { + 'uuid': {'type': 'keyword'}, + 'name': {'type': 'text'}, + 'group_id': {'type': 'text'}, + } + } + }, + }, + { + 'index_name': 'episodes', + 'body': { + 'mappings': { + 'properties': { + 'uuid': {'type': 'keyword'}, + 'content': {'type': 'text'}, + 'source': {'type': 'text'}, + 'source_description': {'type': 'text'}, + 'group_id': {'type': 'text'}, + 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + } + } + }, + }, + { + 'index_name': 'entity_edges', + 'body': { + 'mappings': { + 'properties': { + 'uuid': {'type': 'keyword'}, + 'name': {'type': 'text'}, + 'fact': {'type': 'text'}, + 'group_id': {'type': 'text'}, + 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'fact_embedding': { + 'type': 'knn_vector', + 'dims': EMBEDDING_DIM, + 'index': True, + 'similarity': 'cosine', + 'method': { + 'engine': 'faiss', + 'space_type': 'cosinesimil', + 'name': 'hnsw', + 'parameters': {'ef_construction': 128, 'm': 16}, + }, + }, + } + } + }, + }, +] + + class GraphDriverSession(ABC): provider: GraphProvider @@ -61,6 +163,7 @@ class GraphDriver(ABC): '' # Neo4j (default) syntax does not require a prefix for fulltext queries ) _database: str + aoss_client: OpenSearch | None # type: ignore @abstractmethod def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: @@ -87,3 +190,70 @@ class GraphDriver(ABC): cloned._database = database return cloned + + async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]: + # No matter what happens above, always return True + return self.delete_aoss_indices() + + async def create_aoss_indices(self): + client = self.aoss_client + if not client: + logger.warning('No OpenSearch client found') + return + + for index in aoss_indices: + alias_name = index['index_name'] + + # If alias already exists, skip (idempotent behavior) + if client.indices.exists_alias(name=alias_name): + continue + + # Build a physical index name with timestamp + ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S') + physical_index_name = f'{alias_name}_{ts_suffix}' + + # Create the index + client.indices.create(index=physical_index_name, body=index['body']) + + # Point alias to it + client.indices.put_alias(index=physical_index_name, name=alias_name) + + # Allow some time for index creation + await asyncio.sleep(60) + + async def delete_aoss_indices(self): + for index in aoss_indices: + index_name = index['index_name'] + client = self.aoss_client + + if not client: + logger.warning('No OpenSearch client found') + return + + if client.indices.exists(index=index_name): + client.indices.delete(index=index_name) + + def save_to_aoss(self, name: str, data: list[dict]) -> int: + client = self.aoss_client + if not client or not helpers: + logger.warning('No OpenSearch client found') + return 0 + + for index in aoss_indices: + if name.lower() == index['index_name']: + to_index = [] + for d in data: + item = { + '_index': name, + '_routing': d.get('group_id'), # shard routing + } + for p in index['body']['mappings']['properties']: + if p in d: # protect against missing fields + item[p] = d[p] + to_index.append(item) + + success, failed = helpers.bulk(client, to_index, stats_only=True) + + return success if failed == 0 else success + + return 0 diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index 00c39342..7c619897 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -74,6 +74,7 @@ class FalkorDriverSession(GraphDriverSession): class FalkorDriver(GraphDriver): provider = GraphProvider.FALKORDB + aoss_client: None = None def __init__( self, diff --git a/graphiti_core/driver/kuzu_driver.py b/graphiti_core/driver/kuzu_driver.py index af371b2f..8a04a4ac 100644 --- a/graphiti_core/driver/kuzu_driver.py +++ b/graphiti_core/driver/kuzu_driver.py @@ -92,6 +92,7 @@ SCHEMA_QUERIES = """ class KuzuDriver(GraphDriver): provider: GraphProvider = GraphProvider.KUZU + aoss_client: None = None def __init__( self, diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 7ac9a5a8..f10d176b 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -22,14 +22,35 @@ from neo4j import AsyncGraphDatabase, EagerResult from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.helpers import semaphore_gather logger = logging.getLogger(__name__) +try: + import boto3 + from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection + + _HAS_OPENSEARCH = True +except ImportError: + boto3 = None + OpenSearch = None + Urllib3AWSV4SignerAuth = None + Urllib3HttpConnection = None + _HAS_OPENSEARCH = False + class Neo4jDriver(GraphDriver): provider = GraphProvider.NEO4J - def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'): + def __init__( + self, + uri: str, + user: str | None, + password: str | None, + database: str = 'neo4j', + aoss_host: str | None = None, + aoss_port: int | None = None, + ): super().__init__() self.client = AsyncGraphDatabase.driver( uri=uri, @@ -37,6 +58,24 @@ class Neo4jDriver(GraphDriver): ) self._database = database + self.aoss_client = None + if aoss_host and aoss_port and boto3 is not None: + try: + session = boto3.Session() + self.aoss_client = OpenSearch( # type: ignore + hosts=[{'host': aoss_host, 'port': aoss_port}], + http_auth=Urllib3AWSV4SignerAuth( # type: ignore + session.get_credentials(), session.region_name, 'aoss' + ), + use_ssl=True, + verify_certs=True, + connection_class=Urllib3HttpConnection, + pool_maxsize=20, + ) # type: ignore + except Exception as e: + logger.warning(f'Failed to initialize OpenSearch client: {e}') + self.aoss_client = None + async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult: # Check if database_ is provided in kwargs. # If not populated, set the value to retain backwards compatibility @@ -60,7 +99,14 @@ class Neo4jDriver(GraphDriver): async def close(self) -> None: return await self.client.close() - def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]: + def delete_all_indexes(self) -> Coroutine: + if self.aoss_client: + return semaphore_gather( + self.client.execute_query( + 'CALL db.indexes() YIELD name DROP INDEX name', + ), + self.delete_aoss_indices(), + ) return self.client.execute_query( 'CALL db.indexes() YIELD name DROP INDEX name', ) diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py index 25aa12c3..cb343163 100644 --- a/graphiti_core/driver/neptune_driver.py +++ b/graphiti_core/driver/neptune_driver.py @@ -22,16 +22,21 @@ from typing import Any import boto3 from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph -from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers +from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection -from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.driver.driver import ( + DEFAULT_SIZE, + GraphDriver, + GraphDriverSession, + GraphProvider, +) logger = logging.getLogger(__name__) -DEFAULT_SIZE = 10 -aoss_indices = [ +neptune_aoss_indices = [ { 'index_name': 'node_name_and_summary', + 'alias_name': 'entities', 'body': { 'mappings': { 'properties': { @@ -49,6 +54,7 @@ aoss_indices = [ }, { 'index_name': 'community_name', + 'alias_name': 'communities', 'body': { 'mappings': { 'properties': { @@ -65,6 +71,7 @@ aoss_indices = [ }, { 'index_name': 'episode_content', + 'alias_name': 'episodes', 'body': { 'mappings': { 'properties': { @@ -88,6 +95,7 @@ aoss_indices = [ }, { 'index_name': 'edge_name_and_fact', + 'alias_name': 'facts', 'body': { 'mappings': { 'properties': { @@ -220,54 +228,27 @@ class NeptuneDriver(GraphDriver): async def _delete_all_data(self) -> Any: return await self.execute_query('MATCH (n) DETACH DELETE n') - def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: - return self.delete_all_indexes_impl() - - async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]: - # No matter what happens above, always return True - return self.delete_aoss_indices() - async def create_aoss_indices(self): - for index in aoss_indices: + for index in neptune_aoss_indices: index_name = index['index_name'] client = self.aoss_client + if not client: + raise ValueError( + 'You must provide an AOSS endpoint to create an OpenSearch driver.' + ) if not client.indices.exists(index=index_name): client.indices.create(index=index_name, body=index['body']) + + alias_name = index.get('alias_name', index_name) + + if not client.indices.exists_alias(name=alias_name, index=index_name): + client.indices.put_alias(index=index_name, name=alias_name) + # Sleep for 1 minute to let the index creation complete await asyncio.sleep(60) - async def delete_aoss_indices(self): - for index in aoss_indices: - index_name = index['index_name'] - client = self.aoss_client - if client.indices.exists(index=index_name): - client.indices.delete(index=index_name) - - def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]: - for index in aoss_indices: - if name.lower() == index['index_name']: - index['query']['query']['multi_match']['query'] = query_text - query = {'size': limit, 'query': index['query']} - resp = self.aoss_client.search(body=query['query'], index=index['index_name']) - return resp - return {} - - def save_to_aoss(self, name: str, data: list[dict]) -> int: - for index in aoss_indices: - if name.lower() == index['index_name']: - to_index = [] - for d in data: - item = {'_index': name} - for p in index['body']['mappings']['properties']: - item[p] = d[p] - to_index.append(item) - success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True) - if failed > 0: - return success - else: - return 0 - - return 0 + def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: + return self.delete_all_indexes_impl() class NeptuneDriverSession(GraphDriverSession): diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index a427d65e..e1c8e44e 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -255,6 +255,21 @@ class EntityEdge(Edge): MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding """ + elif driver.aoss_client: + resp = driver.aoss_client.search( + body={ + 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, + 'size': 1, + }, + index='entity_edges', + routing=self.group_id, + ) + + if resp['hits']['hits']: + self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding'] + return + else: + raise EdgeNotFoundError(self.uuid) if driver.provider == GraphProvider.KUZU: query = """ @@ -292,14 +307,14 @@ class EntityEdge(Edge): if driver.provider == GraphProvider.KUZU: edge_data['attributes'] = json.dumps(self.attributes) result = await driver.execute_query( - get_entity_edge_save_query(driver.provider), + get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)), **edge_data, ) else: edge_data.update(self.attributes or {}) - if driver.provider == GraphProvider.NEPTUNE: - driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue + if driver.aoss_client: + driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_edge_save_query(driver.provider), diff --git a/graphiti_core/embedder/client.py b/graphiti_core/embedder/client.py index 9ffc0653..9e05a088 100644 --- a/graphiti_core/embedder/client.py +++ b/graphiti_core/embedder/client.py @@ -14,12 +14,13 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os from abc import ABC, abstractmethod from collections.abc import Iterable from pydantic import BaseModel, Field -EMBEDDING_DIM = 1024 +EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', 1024)) class EmbedderConfig(BaseModel): diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index 5b8d5402..92e294fb 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -60,7 +60,7 @@ EPISODIC_EDGE_RETURN = """ """ -def get_entity_edge_save_query(provider: GraphProvider) -> str: +def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str: match provider: case GraphProvider.FALKORDB: return """ @@ -99,17 +99,28 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str: RETURN e.uuid AS uuid """ case _: # Neo4j - return """ - MATCH (source:Entity {uuid: $edge_data.source_uuid}) - MATCH (target:Entity {uuid: $edge_data.target_uuid}) - MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) - SET e = $edge_data - WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding) + save_embedding_query = ( + """WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)""" + if not has_aoss + else '' + ) + return ( + ( + """ + MATCH (source:Entity {uuid: $edge_data.source_uuid}) + MATCH (target:Entity {uuid: $edge_data.target_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) + SET e = $edge_data + """ + + save_embedding_query + ) + + """ RETURN e.uuid AS uuid - """ + """ + ) -def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str: +def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str: match provider: case GraphProvider.FALKORDB: return """ @@ -152,15 +163,24 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str: RETURN e.uuid AS uuid """ case _: - return """ - UNWIND $entity_edges AS edge - MATCH (source:Entity {uuid: edge.source_node_uuid}) - MATCH (target:Entity {uuid: edge.target_node_uuid}) - MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target) - SET e = edge - WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding) + save_embedding_query = ( + 'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)' + if not has_aoss + else '' + ) + return ( + """ + UNWIND $entity_edges AS edge + MATCH (source:Entity {uuid: edge.source_node_uuid}) + MATCH (target:Entity {uuid: edge.target_node_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target) + SET e = edge + """ + + save_embedding_query + + """ RETURN edge.uuid AS uuid """ + ) def get_entity_edge_return_query(provider: GraphProvider) -> str: diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index e588a58d..0972dac4 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -126,7 +126,7 @@ EPISODIC_NODE_RETURN_NEPTUNE = """ """ -def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str: +def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str: match provider: case GraphProvider.FALKORDB: return f""" @@ -161,16 +161,27 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str: RETURN n.uuid AS uuid """ case _: - return f""" + save_embedding_query = ( + 'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)' + if not has_aoss + else '' + ) + return ( + f""" MERGE (n:Entity {{uuid: $entity_data.uuid}}) SET n:{labels} SET n = $entity_data - WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding) + """ + + save_embedding_query + + """ RETURN n.uuid AS uuid """ + ) -def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any: +def get_entity_node_save_bulk_query( + provider: GraphProvider, nodes: list[dict], has_aoss: bool = False +) -> str | Any: match provider: case GraphProvider.FALKORDB: queries = [] @@ -222,14 +233,23 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) RETURN n.uuid AS uuid """ case _: # Neo4j - return """ - UNWIND $nodes AS node - MERGE (n:Entity {uuid: node.uuid}) - SET n:$(node.labels) - SET n = node - WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding) + save_embedding_query = ( + 'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)' + if not has_aoss + else '' + ) + return ( + """ + UNWIND $nodes AS node + MERGE (n:Entity {uuid: node.uuid}) + SET n:$(node.labels) + SET n = node + """ + + save_embedding_query + + """ RETURN n.uuid AS uuid """ + ) def get_entity_node_return_query(provider: GraphProvider) -> str: diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 4c2bbf36..3adef419 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -273,20 +273,6 @@ class EpisodicNode(Node): ) async def save(self, driver: GraphDriver): - if driver.provider == GraphProvider.NEPTUNE: - driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue - 'episode_content', - [ - { - 'uuid': self.uuid, - 'group_id': self.group_id, - 'source': self.source.value, - 'content': self.content, - 'source_description': self.source_description, - } - ], - ) - episode_args = { 'uuid': self.uuid, 'name': self.name, @@ -299,6 +285,12 @@ class EpisodicNode(Node): 'source': self.source.value, } + if driver.aoss_client: + driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue + 'episodes', + [episode_args], + ) + result = await driver.execute_query( get_episode_node_save_query(driver.provider), **episode_args ) @@ -433,6 +425,22 @@ class EntityNode(Node): MATCH (n:Entity {uuid: $uuid}) RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding """ + elif driver.aoss_client: + resp = driver.aoss_client.search( + body={ + 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, + 'size': 1, + }, + index='entities', + routing=self.group_id, + ) + + if resp['hits']['hits']: + self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding'] + return + else: + raise NodeNotFoundError(self.uuid) + else: query: LiteralString = """ MATCH (n:Entity {uuid: $uuid}) @@ -470,11 +478,11 @@ class EntityNode(Node): entity_data.update(self.attributes or {}) labels = ':'.join(self.labels + ['Entity']) - if driver.provider == GraphProvider.NEPTUNE: - driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue + if driver.aoss_client: + driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( - get_entity_node_save_query(driver.provider, labels), + get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), entity_data=entity_data, ) @@ -570,7 +578,7 @@ class CommunityNode(Node): async def save(self, driver: GraphDriver): if driver.provider == GraphProvider.NEPTUNE: driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue - 'community_name', + 'communities', [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}], ) result = await driver.execute_query( diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index 93cab5ba..f5f2252c 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -54,6 +54,16 @@ class SearchFilters(BaseModel): expired_at: list[list[DateFilter]] | None = Field(default=None) +def cypher_to_opensearch_operator(op: ComparisonOperator) -> str: + mapping = { + ComparisonOperator.greater_than: 'gt', + ComparisonOperator.less_than: 'lt', + ComparisonOperator.greater_than_equal: 'gte', + ComparisonOperator.less_than_equal: 'lte', + } + return mapping.get(op, op.value) + + def node_search_filter_query_constructor( filters: SearchFilters, provider: GraphProvider, @@ -234,3 +244,38 @@ def edge_search_filter_query_constructor( filter_queries.append(expired_at_filter) return filter_queries, filter_params + + +def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]: + filters = [{'terms': {'group_id': group_ids}}] + + if search_filters.node_labels: + filters.append({'terms': {'node_labels': search_filters.node_labels}}) + + return filters + + +def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]: + filters: list[dict] = [{'terms': {'group_id': group_ids}}] + + if search_filters.edge_types: + filters.append({'terms': {'edge_types': search_filters.edge_types}}) + + for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']: + ranges = getattr(search_filters, field) + if ranges: + # OR of ANDs + should_clauses = [] + for and_group in ranges: + and_filters = [] + for df in and_group: # df is a DateFilter + range_query = { + 'range': { + field: {cypher_to_opensearch_operator(df.comparison_operator): df.date} + } + } + and_filters.append(range_query) + should_clauses.append({'bool': {'filter': and_filters}}) + filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}}) + + return filters diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 379662d5..40eeaa7e 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -51,6 +51,8 @@ from graphiti_core.nodes import ( ) from graphiti_core.search.search_filters import ( SearchFilters, + build_aoss_edge_filters, + build_aoss_node_filters, edge_search_filter_query_constructor, node_search_filter_query_constructor, ) @@ -200,7 +202,6 @@ async def edge_fulltext_search( if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue if res['hits']['total']['value'] > 0: - # Calculate Cosine similarity then return the edge ids input_ids = [] for r in res['hits']['hits']: input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) @@ -208,11 +209,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -244,6 +245,31 @@ async def edge_fulltext_search( ) else: return [] + elif driver.aoss_client: + route = group_ids[0] if group_ids else None + filters = build_aoss_edge_filters(group_ids or [], search_filter) + res = driver.aoss_client.search( + index='entity_edges', + routing=route, + _source=['uuid'], + query={ + 'bool': { + 'filter': filters, + 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}], + } + }, + ) + if res['hits']['total']['value'] > 0: + input_uuids = {} + for r in res['hits']['hits']: + input_uuids[r['_source']['uuid']] = r['_score'] + + # Get edges + entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) + entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + return entity_edges + else: + return [] else: query = ( get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider) @@ -318,8 +344,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -377,6 +403,32 @@ async def edge_similarity_search( ) else: return [] + elif driver.aoss_client: + route = group_ids[0] if group_ids else None + filters = build_aoss_edge_filters(group_ids or [], search_filter) + res = driver.aoss_client.search( + index='entity_edges', + routing=route, + _source=['uuid'], + knn={ + 'field': 'fact_embedding', + 'query_vector': search_vector, + 'k': limit, + 'num_candidates': 1000, + }, + query={'bool': {'filter': filters}}, + ) + + if res['hits']['total']['value'] > 0: + input_uuids = {} + for r in res['hits']['hits']: + input_uuids[r['_source']['uuid']] = r['_score'] + + # Get edges + entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) + entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + return entity_edges + else: query = ( match_query @@ -563,7 +615,6 @@ async def node_fulltext_search( if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue if res['hits']['total']['value'] > 0: - # Calculate Cosine similarity then return the edge ids input_ids = [] for r in res['hits']['hits']: input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) @@ -571,11 +622,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -592,6 +643,41 @@ async def node_fulltext_search( ) else: return [] + elif driver.aoss_client: + route = group_ids[0] if group_ids else None + filters = build_aoss_node_filters(group_ids or [], search_filter) + res = driver.aoss_client.search( + 'entities', + routing=route, + _source=['uuid'], + query={ + 'bool': { + 'filter': filters, + 'must': [ + { + 'multi_match': { + 'query': query, + 'field': ['name', 'summary'], + 'operator': 'or', + } + } + ], + } + }, + limit=limit, + ) + + if res['hits']['total']['value'] > 0: + input_uuids = {} + for r in res['hits']['hits']: + input_uuids[r['_source']['uuid']] = r['_score'] + + # Get nodes + entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) + entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + return entities + else: + return [] else: query = ( get_nodes_query( @@ -648,8 +734,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -678,11 +764,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -700,11 +786,36 @@ async def node_similarity_search( ) else: return [] + elif driver.aoss_client: + route = group_ids[0] if group_ids else None + filters = build_aoss_node_filters(group_ids or [], search_filter) + res = driver.aoss_client.search( + index='entities', + routing=route, + _source=['uuid'], + knn={ + 'field': 'fact_embedding', + 'query_vector': search_vector, + 'k': limit, + 'num_candidates': 1000, + }, + query={'bool': {'filter': filters}}, + ) + + if res['hits']['total']['value'] > 0: + input_uuids = {} + for r in res['hits']['hits']: + input_uuids[r['_source']['uuid']] = r['_score'] + + # Get edges + entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) + entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + return entity_nodes else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -843,7 +954,6 @@ async def episode_fulltext_search( if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue if res['hits']['total']['value'] > 0: - # Calculate Cosine similarity then return the edge ids input_ids = [] for r in res['hits']['hits']: input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) @@ -852,7 +962,7 @@ async def episode_fulltext_search( query = """ UNWIND $ids as i MATCH (e:Episodic) - WHERE e.uuid=i.id + WHERE e.uuid=i.uuid RETURN e.content AS content, e.created_at AS created_at, @@ -876,6 +986,40 @@ async def episode_fulltext_search( ) else: return [] + elif driver.aoss_client: + route = group_ids[0] if group_ids else None + res = driver.aoss_client.search( + 'episodes', + routing=route, + _source=['uuid'], + query={ + 'bool': { + 'filter': {'terms': group_ids}, + 'must': [ + { + 'multi_match': { + 'query': query, + 'field': ['name', 'content'], + 'operator': 'or', + } + } + ], + } + }, + limit=limit, + ) + + if res['hits']['total']['value'] > 0: + input_uuids = {} + for r in res['hits']['hits']: + input_uuids[r['_source']['uuid']] = r['_score'] + + # Get nodes + episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys())) + episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + return episodes + else: + return [] else: query = ( get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider) @@ -1003,8 +1147,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1063,8 +1207,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1206,9 +1350,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1253,9 +1397,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1344,9 +1488,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1416,9 +1560,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1454,9 +1598,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1529,10 +1673,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1602,10 +1746,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1641,10 +1785,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 76494800..ff76a930 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -187,12 +187,25 @@ async def add_nodes_and_edges_bulk_tx( await tx.run(episodic_edge_query, **edge.model_dump()) else: await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes) - await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes) + await tx.run( + get_entity_node_save_bulk_query(driver.provider, nodes), + nodes=nodes, + has_aoss=bool(driver.aoss_client), + ) await tx.run( get_episodic_edge_save_bulk_query(driver.provider), episodic_edges=[edge.model_dump() for edge in episodic_edges], ) - await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges) + await tx.run( + get_entity_edge_save_bulk_query(driver.provider), + entity_edges=edges, + has_aoss=bool(driver.aoss_client), + ) + + if driver.aoss_client: + driver.save_to_aoss('episodes', episodes) + driver.save_to_aoss('entities', nodes) + driver.save_to_aoss('entity_edges', edges) async def extract_nodes_and_edges_bulk( diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 801f816e..e8bbc86c 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False): - if driver.provider == GraphProvider.NEPTUNE: + if driver.aoss_client: await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue] return if delete_existing: @@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo range_indices: list[LiteralString] = get_range_indices(driver.provider) - fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider) + # Don't create fulltext indices if OpenSearch is being used + if not driver.aoss_client: + fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider) if driver.provider == GraphProvider.KUZU: # Skip creating fulltext indices if they already exist. Need to do this manually @@ -149,9 +151,9 @@ async def retrieve_episodes( query: LiteralString = ( """ - MATCH (e:Episodic) - WHERE e.valid_at <= $reference_time - """ + MATCH (e:Episodic) + WHERE e.valid_at <= $reference_time + """ + query_filter + """ RETURN diff --git a/pyproject.toml b/pyproject.toml index 04061eec..f4c8821b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.20.4" +version = "0.21.0pre1" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, @@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"] kuzu = ["kuzu>=0.11.2"] falkordb = ["falkordb>=1.1.2,<2.0.0"] voyageai = ["voyageai>=0.2.3"] +neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"] sentence-transformers = ["sentence-transformers>=3.2.1"] neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"] dev = [ diff --git a/uv.lock b/uv.lock index e2a3b855..8228d4a8 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.20.4" +version = "0.21.0rc1" source = { editable = "." } dependencies = [ { name = "diskcache" }, @@ -835,6 +835,10 @@ groq = [ kuzu = [ { name = "kuzu" }, ] +neo4j-opensearch = [ + { name = "boto3" }, + { name = "opensearch-py" }, +] neptune = [ { name = "boto3" }, { name = "langchain-aws" }, @@ -851,6 +855,7 @@ voyageai = [ requires-dist = [ { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" }, { name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" }, + { name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" }, { name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" }, { name = "diskcache", specifier = ">=5.6.3" }, { name = "diskcache-stubs", marker = "extra == 'dev'", specifier = ">=5.6.3.6.20240818" }, @@ -872,6 +877,7 @@ requires-dist = [ { name = "neo4j", specifier = ">=5.26.0" }, { name = "numpy", specifier = ">=1.0.0" }, { name = "openai", specifier = ">=1.91.0" }, + { name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" }, { name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" }, { name = "posthog", specifier = ">=3.0.0" }, { name = "pydantic", specifier = ">=2.11.5" }, @@ -888,7 +894,7 @@ requires-dist = [ { name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" }, { name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" }, ] -provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"] +provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "neo4j-opensearch", "sentence-transformers", "neptune", "dev"] [[package]] name = "groq"