From 3efe085a9246793383994b5524b6c4341bfb3451 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Sun, 14 Sep 2025 01:43:37 -0400 Subject: [PATCH] OpenSearch updates (#906) * updates * add uuid filter functionality * update * updates * bump-version * update * fix typo * use async function * update unit tests * update delete * update deletion * async update * update * update * update * update --- graphiti_core/driver/driver.py | 140 +++++++---- graphiti_core/driver/neo4j_driver.py | 25 +- graphiti_core/driver/neptune_driver.py | 4 +- graphiti_core/edges.py | 52 +++- graphiti_core/graphiti.py | 26 +- graphiti_core/nodes.py | 108 +++++++- graphiti_core/search/search_filters.py | 8 + graphiti_core/search/search_utils.py | 236 ++++++++++-------- graphiti_core/utils/bulk_utils.py | 15 +- .../utils/maintenance/edge_operations.py | 44 +++- .../maintenance/graph_data_operations.py | 8 +- pyproject.toml | 2 +- uv.lock | 2 +- 13 files changed, 479 insertions(+), 191 deletions(-) diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 40ce2540..df748d4e 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -17,16 +17,19 @@ limitations under the License. import asyncio import copy import logging +import os from abc import ABC, abstractmethod from collections.abc import Coroutine from datetime import datetime from enum import Enum from typing import Any +from dotenv import load_dotenv + from graphiti_core.embedder.client import EMBEDDING_DIM try: - from opensearchpy import OpenSearch, helpers + from opensearchpy import AsyncOpenSearch, helpers _HAS_OPENSEARCH = True except ImportError: @@ -38,6 +41,13 @@ logger = logging.getLogger(__name__) DEFAULT_SIZE = 10 +load_dotenv() + +ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities') +EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes') +COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities') +ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges') + class GraphProvider(Enum): NEO4J = 'neo4j' @@ -48,20 +58,19 @@ class GraphProvider(Enum): aoss_indices = [ { - 'index_name': 'entities', + 'index_name': ENTITY_INDEX_NAME, 'body': { + 'settings': {'index': {'knn': True}}, 'mappings': { 'properties': { 'uuid': {'type': 'keyword'}, 'name': {'type': 'text'}, 'summary': {'type': 'text'}, - 'group_id': {'type': 'text'}, - 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'group_id': {'type': 'keyword'}, + 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'name_embedding': { 'type': 'knn_vector', - 'dims': EMBEDDING_DIM, - 'index': True, - 'similarity': 'cosine', + 'dimension': EMBEDDING_DIM, 'method': { 'engine': 'faiss', 'space_type': 'cosinesimil', @@ -70,23 +79,23 @@ aoss_indices = [ }, }, } - } + }, }, }, { - 'index_name': 'communities', + 'index_name': COMMUNITY_INDEX_NAME, 'body': { 'mappings': { 'properties': { 'uuid': {'type': 'keyword'}, 'name': {'type': 'text'}, - 'group_id': {'type': 'text'}, + 'group_id': {'type': 'keyword'}, } } }, }, { - 'index_name': 'episodes', + 'index_name': EPISODE_INDEX_NAME, 'body': { 'mappings': { 'properties': { @@ -94,31 +103,30 @@ aoss_indices = [ 'content': {'type': 'text'}, 'source': {'type': 'text'}, 'source_description': {'type': 'text'}, - 'group_id': {'type': 'text'}, - 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, - 'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'group_id': {'type': 'keyword'}, + 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, + 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, } } }, }, { - 'index_name': 'entity_edges', + 'index_name': ENTITY_EDGE_INDEX_NAME, 'body': { + 'settings': {'index': {'knn': True}}, 'mappings': { 'properties': { 'uuid': {'type': 'keyword'}, 'name': {'type': 'text'}, 'fact': {'type': 'text'}, - 'group_id': {'type': 'text'}, - 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, - 'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, - 'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, - 'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'group_id': {'type': 'keyword'}, + 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, + 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, + 'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, + 'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'fact_embedding': { 'type': 'knn_vector', - 'dims': EMBEDDING_DIM, - 'index': True, - 'similarity': 'cosine', + 'dimension': EMBEDDING_DIM, 'method': { 'engine': 'faiss', 'space_type': 'cosinesimil', @@ -127,7 +135,7 @@ aoss_indices = [ }, }, } - } + }, }, }, ] @@ -163,7 +171,7 @@ class GraphDriver(ABC): '' # Neo4j (default) syntax does not require a prefix for fulltext queries ) _database: str - aoss_client: OpenSearch | None # type: ignore + aoss_client: AsyncOpenSearch | None # type: ignore @abstractmethod def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: @@ -205,7 +213,7 @@ class GraphDriver(ABC): alias_name = index['index_name'] # If alias already exists, skip (idempotent behavior) - if client.indices.exists_alias(name=alias_name): + if await client.indices.exists_alias(name=alias_name): continue # Build a physical index name with timestamp @@ -213,27 +221,67 @@ class GraphDriver(ABC): physical_index_name = f'{alias_name}_{ts_suffix}' # Create the index - client.indices.create(index=physical_index_name, body=index['body']) + await client.indices.create(index=physical_index_name, body=index['body']) # Point alias to it - client.indices.put_alias(index=physical_index_name, name=alias_name) + await client.indices.put_alias(index=physical_index_name, name=alias_name) # Allow some time for index creation - await asyncio.sleep(60) + await asyncio.sleep(1) async def delete_aoss_indices(self): + client = self.aoss_client + + if not client: + logger.warning('No OpenSearch client found') + return + + for entry in aoss_indices: + alias_name = entry['index_name'] + + try: + # Resolve alias → indices + alias_info = await client.indices.get_alias(name=alias_name) + indices = list(alias_info.keys()) + + if not indices: + logger.info(f"No indices found for alias '{alias_name}'") + continue + + for index in indices: + if await client.indices.exists(index=index): + await client.indices.delete(index=index) + logger.info(f"Deleted index '{index}' (alias: {alias_name})") + else: + logger.warning(f"Index '{index}' not found for alias '{alias_name}'") + + except Exception as e: + logger.error(f"Error deleting indices for alias '{alias_name}': {e}") + + async def clear_aoss_indices(self): + client = self.aoss_client + + if not client: + logger.warning('No OpenSearch client found') + return + for index in aoss_indices: index_name = index['index_name'] - client = self.aoss_client - if not client: - logger.warning('No OpenSearch client found') - return + if await client.indices.exists(index=index_name): + try: + # Delete all documents but keep the index + response = await client.delete_by_query( + index=index_name, + body={'query': {'match_all': {}}}, + ) + logger.info(f"Cleared index '{index_name}': {response}") + except Exception as e: + logger.error(f"Error clearing index '{index_name}': {e}") + else: + logger.warning(f"Index '{index_name}' does not exist") - if client.indices.exists(index=index_name): - client.indices.delete(index=index_name) - - def save_to_aoss(self, name: str, data: list[dict]) -> int: + async def save_to_aoss(self, name: str, data: list[dict]) -> int: client = self.aoss_client if not client or not helpers: logger.warning('No OpenSearch client found') @@ -243,16 +291,22 @@ class GraphDriver(ABC): if name.lower() == index['index_name']: to_index = [] for d in data: - item = { - '_index': name, - '_routing': d.get('group_id'), # shard routing - } + doc = {} for p in index['body']['mappings']['properties']: if p in d: # protect against missing fields - item[p] = d[p] + doc[p] = d[p] + + item = { + '_index': name, + '_id': d['uuid'], + '_routing': d.get('group_id'), + '_source': doc, + } to_index.append(item) - success, failed = helpers.bulk(client, to_index, stats_only=True) + success, failed = await helpers.async_bulk( + client, to_index, stats_only=True, request_timeout=60 + ) return success if failed == 0 else success diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index f10d176b..e88e1c8c 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -28,7 +28,13 @@ logger = logging.getLogger(__name__) try: import boto3 - from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection + from opensearchpy import ( + AIOHttpConnection, + AsyncOpenSearch, + AWSV4SignerAuth, + Urllib3AWSV4SignerAuth, + Urllib3HttpConnection, + ) _HAS_OPENSEARCH = True except ImportError: @@ -50,6 +56,9 @@ class Neo4jDriver(GraphDriver): database: str = 'neo4j', aoss_host: str | None = None, aoss_port: int | None = None, + aws_profile_name: str | None = None, + aws_region: str | None = None, + aws_service: str | None = None, ): super().__init__() self.client = AsyncGraphDatabase.driver( @@ -61,15 +70,17 @@ class Neo4jDriver(GraphDriver): self.aoss_client = None if aoss_host and aoss_port and boto3 is not None: try: - session = boto3.Session() - self.aoss_client = OpenSearch( # type: ignore + region = aws_region + service = aws_service + credentials = boto3.Session(profile_name=aws_profile_name).get_credentials() + auth = AWSV4SignerAuth(credentials, region or '', service or '') + + self.aoss_client = AsyncOpenSearch( hosts=[{'host': aoss_host, 'port': aoss_port}], - http_auth=Urllib3AWSV4SignerAuth( # type: ignore - session.get_credentials(), session.region_name, 'aoss' - ), + auth=auth, use_ssl=True, verify_certs=True, - connection_class=Urllib3HttpConnection, + connection_class=AIOHttpConnection, pool_maxsize=20, ) # type: ignore except Exception as e: diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py index cb343163..355fded1 100644 --- a/graphiti_core/driver/neptune_driver.py +++ b/graphiti_core/driver/neptune_driver.py @@ -237,12 +237,12 @@ class NeptuneDriver(GraphDriver): 'You must provide an AOSS endpoint to create an OpenSearch driver.' ) if not client.indices.exists(index=index_name): - client.indices.create(index=index_name, body=index['body']) + await client.indices.create(index=index_name, body=index['body']) alias_name = index.get('alias_name', index_name) if not client.indices.exists_alias(name=alias_name, index=index_name): - client.indices.put_alias(index=index_name, name=alias_name) + await client.indices.put_alias(index=index_name, name=alias_name) # Sleep for 1 minute to let the index creation complete await asyncio.sleep(60) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index e1c8e44e..165dee53 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -25,7 +25,7 @@ from uuid import uuid4 from pydantic import BaseModel, Field from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver, GraphProvider +from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError from graphiti_core.helpers import parse_db_date @@ -77,6 +77,13 @@ class Edge(BaseModel, ABC): uuid=self.uuid, ) + if driver.aoss_client: + await driver.aoss_client.delete( + index=ENTITY_EDGE_INDEX_NAME, + id=self.uuid, + params={'routing': self.group_id}, + ) + logger.debug(f'Deleted Edge: {self.uuid}') @classmethod @@ -108,6 +115,12 @@ class Edge(BaseModel, ABC): uuids=uuids, ) + if driver.aoss_client: + await driver.aoss_client.delete_by_query( + index=ENTITY_EDGE_INDEX_NAME, + body={'query': {'terms': {'uuid': uuids}}}, + ) + logger.debug(f'Deleted Edges: {uuids}') def __hash__(self): @@ -256,13 +269,13 @@ class EntityEdge(Edge): RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding """ elif driver.aoss_client: - resp = driver.aoss_client.search( + resp = await driver.aoss_client.search( body={ 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'size': 1, }, - index='entity_edges', - routing=self.group_id, + index=ENTITY_EDGE_INDEX_NAME, + params={'routing': self.group_id}, ) if resp['hits']['hits']: @@ -314,7 +327,7 @@ class EntityEdge(Edge): edge_data.update(self.attributes or {}) if driver.aoss_client: - driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue + await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_edge_save_query(driver.provider), @@ -351,6 +364,35 @@ class EntityEdge(Edge): raise EdgeNotFoundError(uuid) return edges[0] + @classmethod + async def get_between_nodes( + cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str + ): + match_query = """ + MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid}) + """ + if driver.provider == GraphProvider.KUZU: + match_query = """ + MATCH (n:Entity {uuid: $source_node_uuid}) + -[:RELATES_TO]->(e:RelatesToNode_) + -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid}) + """ + + records, _, _ = await driver.execute_query( + match_query + + """ + RETURN + """ + + get_entity_edge_return_query(driver.provider), + source_node_uuid=source_node_uuid, + target_node_uuid=target_node_uuid, + routing_='r', + ) + + edges = [get_entity_edge_from_record(record, driver.provider) for record in records] + + return edges + @classmethod async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): if len(uuids) == 0: diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index d217d924..bce0c326 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import ( from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_utils import ( RELEVANT_SCHEMA_LIMIT, - get_edge_invalidation_candidates, get_mentioned_nodes, - get_relevant_edges, ) from graphiti_core.telemetry import capture_event from graphiti_core.utils.bulk_utils import ( @@ -1037,10 +1035,28 @@ class Graphiti: updated_edge = resolve_edge_pointers([edge], uuid_map)[0] - related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0] + valid_edges = await EntityEdge.get_between_nodes( + self.driver, edge.source_node_uuid, edge.target_node_uuid + ) + + related_edges = ( + await search( + self.clients, + updated_edge.fact, + group_ids=[updated_edge.group_id], + config=EDGE_HYBRID_SEARCH_RRF, + search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]), + ) + ).edges existing_edges = ( - await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters()) - )[0] + await search( + self.clients, + updated_edge.fact, + group_ids=[updated_edge.group_id], + config=EDGE_HYBRID_SEARCH_RRF, + search_filter=SearchFilters(), + ) + ).edges resolved_edge, invalidated_edges, _ = await resolve_extracted_edge( self.llm_client, diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 3adef419..7fafbe4f 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -26,7 +26,14 @@ from uuid import uuid4 from pydantic import BaseModel, Field from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver, GraphProvider +from graphiti_core.driver.driver import ( + COMMUNITY_INDEX_NAME, + ENTITY_EDGE_INDEX_NAME, + ENTITY_INDEX_NAME, + EPISODE_INDEX_NAME, + GraphDriver, + GraphProvider, +) from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import NodeNotFoundError from graphiti_core.helpers import parse_db_date @@ -94,13 +101,39 @@ class Node(BaseModel, ABC): async def delete(self, driver: GraphDriver): match driver.provider: case GraphProvider.NEO4J: - await driver.execute_query( + records, _, _ = await driver.execute_query( """ - MATCH (n:Entity|Episodic|Community {uuid: $uuid}) + MATCH (n {uuid: $uuid}) + WHERE n:Entity OR n:Episodic OR n:Community + OPTIONAL MATCH (n)-[r]-() + WITH collect(r.uuid) AS edge_uuids, n DETACH DELETE n + RETURN edge_uuids """, uuid=self.uuid, ) + + edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else [] + + if driver.aoss_client: + # Delete the node from OpenSearch indices + for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): + await driver.aoss_client.delete( + index=index, + id=self.uuid, + params={'routing': self.group_id}, + ) + + # Bulk delete the detached edges + if edge_uuids: + actions = [] + for eid in edge_uuids: + actions.append( + {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}} + ) + + await driver.aoss_client.bulk(body=actions) + case GraphProvider.KUZU: for label in ['Episodic', 'Community']: await driver.execute_query( @@ -162,6 +195,32 @@ class Node(BaseModel, ABC): group_id=group_id, batch_size=batch_size, ) + + if driver.aoss_client: + await driver.aoss_client.delete_by_query( + index=EPISODE_INDEX_NAME, + body={'query': {'term': {'group_id': group_id}}}, + params={'routing': group_id}, + ) + + await driver.aoss_client.delete_by_query( + index=ENTITY_INDEX_NAME, + body={'query': {'term': {'group_id': group_id}}}, + params={'routing': group_id}, + ) + + await driver.aoss_client.delete_by_query( + index=COMMUNITY_INDEX_NAME, + body={'query': {'term': {'group_id': group_id}}}, + params={'routing': group_id}, + ) + + await driver.aoss_client.delete_by_query( + index=ENTITY_EDGE_INDEX_NAME, + body={'query': {'term': {'group_id': group_id}}}, + params={'routing': group_id}, + ) + case GraphProvider.KUZU: for label in ['Episodic', 'Community']: await driver.execute_query( @@ -240,6 +299,23 @@ class Node(BaseModel, ABC): ) case _: # Neo4J, Neptune async with driver.session() as session: + # Collect all edge UUIDs before deleting nodes + result = await session.run( + """ + MATCH (n:Entity|Episodic|Community) + WHERE n.uuid IN $uuids + MATCH (n)-[r]-() + RETURN collect(r.uuid) AS edge_uuids + """, + uuids=uuids, + ) + + record = await result.single() + edge_uuids: list[str] = ( + record['edge_uuids'] if record and record['edge_uuids'] else [] + ) + + # Now delete the nodes in batches await session.run( """ MATCH (n:Entity|Episodic|Community) @@ -253,6 +329,20 @@ class Node(BaseModel, ABC): batch_size=batch_size, ) + if driver.aoss_client: + for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): + await driver.aoss_client.delete_by_query( + index=index, + body={'query': {'terms': {'uuid': uuids}}}, + ) + + if edge_uuids: + actions = [ + {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}} + for eid in edge_uuids + ] + await driver.aoss_client.bulk(body=actions) + @classmethod async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ... @@ -286,7 +376,7 @@ class EpisodicNode(Node): } if driver.aoss_client: - driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue + await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue 'episodes', [episode_args], ) @@ -426,13 +516,13 @@ class EntityNode(Node): RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding """ elif driver.aoss_client: - resp = driver.aoss_client.search( + resp = await driver.aoss_client.search( body={ 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'size': 1, }, - index='entities', - routing=self.group_id, + index=ENTITY_INDEX_NAME, + params={'routing': self.group_id}, ) if resp['hits']['hits']: @@ -479,7 +569,7 @@ class EntityNode(Node): labels = ':'.join(self.labels + ['Entity']) if driver.aoss_client: - driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue + await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), @@ -577,7 +667,7 @@ class CommunityNode(Node): async def save(self, driver: GraphDriver): if driver.provider == GraphProvider.NEPTUNE: - driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue + await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue 'communities', [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}], ) diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index f5f2252c..37cf7e82 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -52,6 +52,7 @@ class SearchFilters(BaseModel): invalid_at: list[list[DateFilter]] | None = Field(default=None) created_at: list[list[DateFilter]] | None = Field(default=None) expired_at: list[list[DateFilter]] | None = Field(default=None) + edge_uuids: list[str] | None = Field(default=None) def cypher_to_opensearch_operator(op: ComparisonOperator) -> str: @@ -108,6 +109,10 @@ def edge_search_filter_query_constructor( filter_queries.append('e.name in $edge_types') filter_params['edge_types'] = edge_types + if filters.edge_uuids is not None: + filter_queries.append('e.uuid in $edge_uuids') + filter_params['edge_uuids'] = filters.edge_uuids + if filters.node_labels is not None: if provider == GraphProvider.KUZU: node_label_filter = ( @@ -261,6 +266,9 @@ def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) if search_filters.edge_types: filters.append({'terms': {'edge_types': search_filters.edge_types}}) + if search_filters.edge_uuids: + filters.append({'terms': {'uuid': search_filters.edge_uuids}}) + for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']: ranges = getattr(search_filters, field) if ranges: diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 40eeaa7e..7aeb5b46 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -23,7 +23,13 @@ import numpy as np from numpy._typing import NDArray from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver, GraphProvider +from graphiti_core.driver.driver import ( + ENTITY_EDGE_INDEX_NAME, + ENTITY_INDEX_NAME, + EPISODE_INDEX_NAME, + GraphDriver, + GraphProvider, +) from graphiti_core.edges import EntityEdge, get_entity_edge_from_record from graphiti_core.graph_queries import ( get_nodes_query, @@ -209,11 +215,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -248,17 +254,21 @@ async def edge_fulltext_search( elif driver.aoss_client: route = group_ids[0] if group_ids else None filters = build_aoss_edge_filters(group_ids or [], search_filter) - res = driver.aoss_client.search( - index='entity_edges', - routing=route, - _source=['uuid'], - query={ - 'bool': { - 'filter': filters, - 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}], - } + res = await driver.aoss_client.search( + index=ENTITY_EDGE_INDEX_NAME, + params={'routing': route}, + body={ + 'size': limit, + '_source': ['uuid'], + 'query': { + 'bool': { + 'filter': filters, + 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}], + } + }, }, ) + if res['hits']['total']['value'] > 0: input_uuids = {} for r in res['hits']['hits']: @@ -344,8 +354,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -406,17 +416,22 @@ async def edge_similarity_search( elif driver.aoss_client: route = group_ids[0] if group_ids else None filters = build_aoss_edge_filters(group_ids or [], search_filter) - res = driver.aoss_client.search( - index='entity_edges', - routing=route, - _source=['uuid'], - knn={ - 'field': 'fact_embedding', - 'query_vector': search_vector, - 'k': limit, - 'num_candidates': 1000, + res = await driver.aoss_client.search( + index=ENTITY_EDGE_INDEX_NAME, + params={'routing': route}, + body={ + 'size': limit, + '_source': ['uuid'], + 'query': { + 'knn': { + 'fact_embedding': { + 'vector': list(map(float, search_vector)), + 'k': limit, + 'filter': {'bool': {'filter': filters}}, + } + } + }, }, - query={'bool': {'filter': filters}}, ) if res['hits']['total']['value'] > 0: @@ -428,6 +443,7 @@ async def edge_similarity_search( entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) return entity_edges + return [] else: query = ( @@ -622,11 +638,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -646,25 +662,27 @@ async def node_fulltext_search( elif driver.aoss_client: route = group_ids[0] if group_ids else None filters = build_aoss_node_filters(group_ids or [], search_filter) - res = driver.aoss_client.search( - 'entities', - routing=route, - _source=['uuid'], - query={ - 'bool': { - 'filter': filters, - 'must': [ - { - 'multi_match': { - 'query': query, - 'field': ['name', 'summary'], - 'operator': 'or', + res = await driver.aoss_client.search( + index=ENTITY_INDEX_NAME, + params={'routing': route}, + body={ + '_source': ['uuid'], + 'size': limit, + 'query': { + 'bool': { + 'filter': filters, + 'must': [ + { + 'multi_match': { + 'query': query, + 'fields': ['name', 'summary'], + 'operator': 'or', + } } - } - ], - } + ], + } + }, }, - limit=limit, ) if res['hits']['total']['value'] > 0: @@ -734,8 +752,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -764,11 +782,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -789,17 +807,22 @@ async def node_similarity_search( elif driver.aoss_client: route = group_ids[0] if group_ids else None filters = build_aoss_node_filters(group_ids or [], search_filter) - res = driver.aoss_client.search( - index='entities', - routing=route, - _source=['uuid'], - knn={ - 'field': 'fact_embedding', - 'query_vector': search_vector, - 'k': limit, - 'num_candidates': 1000, + res = await driver.aoss_client.search( + index=ENTITY_INDEX_NAME, + params={'routing': route}, + body={ + 'size': limit, + '_source': ['uuid'], + 'query': { + 'knn': { + 'name_embedding': { + 'vector': list(map(float, search_vector)), + 'k': limit, + 'filter': {'bool': {'filter': filters}}, + } + } + }, }, - query={'bool': {'filter': filters}}, ) if res['hits']['total']['value'] > 0: @@ -811,11 +834,12 @@ async def node_similarity_search( entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) return entity_nodes + return [] else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -988,11 +1012,12 @@ async def episode_fulltext_search( return [] elif driver.aoss_client: route = group_ids[0] if group_ids else None - res = driver.aoss_client.search( - 'episodes', - routing=route, - _source=['uuid'], - query={ + res = await driver.aoss_client.search( + index=EPISODE_INDEX_NAME, + params={'routing': route}, + body={ + 'size': limit, + '_source': ['uuid'], 'bool': { 'filter': {'terms': group_ids}, 'must': [ @@ -1004,9 +1029,8 @@ async def episode_fulltext_search( } } ], - } + }, }, - limit=limit, ) if res['hits']['total']['value'] > 0: @@ -1147,8 +1171,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1207,8 +1231,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1350,9 +1374,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1397,9 +1421,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1488,9 +1512,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1560,9 +1584,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1598,9 +1622,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1673,10 +1697,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1746,10 +1770,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1785,10 +1809,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index ff76a930..78397e87 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -23,7 +23,14 @@ import numpy as np from pydantic import BaseModel, Field from typing_extensions import Any -from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.driver.driver import ( + ENTITY_EDGE_INDEX_NAME, + ENTITY_INDEX_NAME, + EPISODE_INDEX_NAME, + GraphDriver, + GraphDriverSession, + GraphProvider, +) from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings from graphiti_core.embedder import EmbedderClient from graphiti_core.graphiti_types import GraphitiClients @@ -203,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx( ) if driver.aoss_client: - driver.save_to_aoss('episodes', episodes) - driver.save_to_aoss('entities', nodes) - driver.save_to_aoss('entity_edges', edges) + await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes) + await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes) + await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges) async def extract_nodes_and_edges_bulk( diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 55cea243..5de8d22a 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_edges import EdgeDuplicate from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts +from graphiti_core.search.search import search +from graphiti_core.search.search_config import SearchResults +from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF from graphiti_core.search.search_filters import SearchFilters -from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges from graphiti_core.utils.datetime_utils import ensure_utc, utc_now logger = logging.getLogger(__name__) @@ -258,12 +260,44 @@ async def resolve_extracted_edges( embedder = clients.embedder await create_entity_edge_embeddings(embedder, extracted_edges) - search_results = await semaphore_gather( - get_relevant_edges(driver, extracted_edges, SearchFilters()), - get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2), + valid_edges_list: list[list[EntityEdge]] = await semaphore_gather( + *[ + EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid) + for edge in extracted_edges + ] ) - related_edges_lists, edge_invalidation_candidates = search_results + related_edges_results: list[SearchResults] = await semaphore_gather( + *[ + search( + clients, + extracted_edge.fact, + group_ids=[extracted_edge.group_id], + config=EDGE_HYBRID_SEARCH_RRF, + search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]), + ) + for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True) + ] + ) + + related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results] + + edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather( + *[ + search( + clients, + extracted_edge.fact, + group_ids=[extracted_edge.group_id], + config=EDGE_HYBRID_SEARCH_RRF, + search_filter=SearchFilters(), + ) + for extracted_edge in extracted_edges + ] + ) + + edge_invalidation_candidates: list[list[EntityEdge]] = [ + result.edges for result in edge_invalidation_candidate_results + ] logger.debug( f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}' diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index e8bbc86c..f62fc6a2 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -95,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None): async def delete_all(tx): await tx.run('MATCH (n) DETACH DELETE n') + if driver.aoss_client: + await driver.clear_aoss_indices() async def delete_group_ids(tx): labels = ['Entity', 'Episodic', 'Community'] @@ -151,9 +153,9 @@ async def retrieve_episodes( query: LiteralString = ( """ - MATCH (e:Episodic) - WHERE e.valid_at <= $reference_time - """ + MATCH (e:Episodic) + WHERE e.valid_at <= $reference_time + """ + query_filter + """ RETURN diff --git a/pyproject.toml b/pyproject.toml index f4c8821b..c86f90e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.21.0pre1" +version = "0.21.0pre2" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index 8228d4a8..bad253b8 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.21.0rc1" +version = "0.21.0rc2" source = { editable = "." } dependencies = [ { name = "diskcache" },