From 604e3199a35ea4b96aba0ef3dbd0ff8d50cd71fa Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:34:37 -0400 Subject: [PATCH] add search and graph operations interfaces (#984) * add search and graph operations interfaces * update * update * update * update * update * update --- graphiti_core/driver/driver.py | 12 +- .../driver/graph_operations/__init__.py | 0 .../graph_operations/graph_operations.py | 195 +++++++++++ .../driver/search_interface/__init__.py | 0 .../search_interface/search_interface.py | 89 +++++ graphiti_core/edges.py | 45 +-- graphiti_core/nodes.py | 119 ++----- graphiti_core/search/search_filters.py | 38 --- graphiti_core/search/search_utils.py | 304 +++++------------- graphiti_core/utils/bulk_utils.py | 44 +-- .../maintenance/graph_data_operations.py | 15 +- uv.lock | 2 +- 12 files changed, 430 insertions(+), 433 deletions(-) create mode 100644 graphiti_core/driver/graph_operations/__init__.py create mode 100644 graphiti_core/driver/graph_operations/graph_operations.py create mode 100644 graphiti_core/driver/search_interface/__init__.py create mode 100644 graphiti_core/driver/search_interface/search_interface.py diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 4605bf45..5b4e0fc3 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -24,6 +24,9 @@ from typing import Any from dotenv import load_dotenv +from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface +from graphiti_core.driver.search_interface.search_interface import SearchInterface + logger = logging.getLogger(__name__) DEFAULT_SIZE = 10 @@ -73,7 +76,8 @@ class GraphDriver(ABC): '' # Neo4j (default) syntax does not require a prefix for fulltext queries ) _database: str - aoss_client: Any # type: ignore + search_interface: SearchInterface | None = None + graph_operations_interface: GraphOperationsInterface | None = None @abstractmethod def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: @@ -109,9 +113,3 @@ class GraphDriver(ABC): Only implemented by providers that need custom fulltext query building. """ raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}') - - async def save_to_aoss(self, name: str, data: list[dict]) -> int: - return 0 - - async def clear_aoss_indices(self): - return 1 diff --git a/graphiti_core/driver/graph_operations/__init__.py b/graphiti_core/driver/graph_operations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphiti_core/driver/graph_operations/graph_operations.py b/graphiti_core/driver/graph_operations/graph_operations.py new file mode 100644 index 00000000..e4887923 --- /dev/null +++ b/graphiti_core/driver/graph_operations/graph_operations.py @@ -0,0 +1,195 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Any + +from pydantic import BaseModel + + +class GraphOperationsInterface(BaseModel): + """ + Interface for updating graph mutation behavior. + """ + + # ----------------- + # Node: Save/Delete + # ----------------- + + async def node_save(self, node: Any, driver: Any) -> None: + """Persist (create or update) a single node.""" + raise NotImplementedError + + async def node_delete(self, node: Any, driver: Any) -> None: + raise NotImplementedError + + async def node_save_bulk( + self, + _cls: Any, # kept for parity; callers won't pass it + driver: Any, + transaction: Any, + nodes: list[Any], + batch_size: int = 100, + ) -> None: + """Persist (create or update) many nodes in batches.""" + raise NotImplementedError + + async def node_delete_by_group_id( + self, + _cls: Any, + driver: Any, + group_id: str, + batch_size: int = 100, + ) -> None: + raise NotImplementedError + + async def node_delete_by_uuids( + self, + _cls: Any, + driver: Any, + uuids: list[str], + group_id: str | None = None, + batch_size: int = 100, + ) -> None: + raise NotImplementedError + + # -------------------------- + # Node: Embeddings (load) + # -------------------------- + + async def node_load_embeddings(self, node: Any, driver: Any) -> None: + """ + Load embedding vectors for a single node into the instance (e.g., set node.embedding or similar). + """ + raise NotImplementedError + + async def node_load_embeddings_bulk( + self, + _cls: Any, + driver: Any, + transaction: Any, + nodes: list[Any], + batch_size: int = 100, + ) -> None: + """ + Load embedding vectors for many nodes in batches. Mutates the provided node instances. + """ + raise NotImplementedError + + # -------------------------- + # EpisodicNode: Save/Delete + # -------------------------- + + async def episodic_node_save(self, node: Any, driver: Any) -> None: + """Persist (create or update) a single episodic node.""" + raise NotImplementedError + + async def episodic_node_delete(self, node: Any, driver: Any) -> None: + raise NotImplementedError + + async def episodic_node_save_bulk( + self, + _cls: Any, + driver: Any, + transaction: Any, + nodes: list[Any], + batch_size: int = 100, + ) -> None: + """Persist (create or update) many episodic nodes in batches.""" + raise NotImplementedError + + async def episodic_edge_save_bulk( + self, + _cls: Any, + driver: Any, + transaction: Any, + episodic_edges: list[Any], + batch_size: int = 100, + ) -> None: + """Persist (create or update) many episodic edges in batches.""" + raise NotImplementedError + + async def episodic_node_delete_by_group_id( + self, + _cls: Any, + driver: Any, + group_id: str, + batch_size: int = 100, + ) -> None: + raise NotImplementedError + + async def episodic_node_delete_by_uuids( + self, + _cls: Any, + driver: Any, + uuids: list[str], + group_id: str | None = None, + batch_size: int = 100, + ) -> None: + raise NotImplementedError + + # ----------------- + # Edge: Save/Delete + # ----------------- + + async def edge_save(self, edge: Any, driver: Any) -> None: + """Persist (create or update) a single edge.""" + raise NotImplementedError + + async def edge_delete(self, edge: Any, driver: Any) -> None: + raise NotImplementedError + + async def edge_save_bulk( + self, + _cls: Any, + driver: Any, + transaction: Any, + edges: list[Any], + batch_size: int = 100, + ) -> None: + """Persist (create or update) many edges in batches.""" + raise NotImplementedError + + async def edge_delete_by_uuids( + self, + _cls: Any, + driver: Any, + uuids: list[str], + group_id: str | None = None, + ) -> None: + raise NotImplementedError + + # ----------------- + # Edge: Embeddings (load) + # ----------------- + + async def edge_load_embeddings(self, edge: Any, driver: Any) -> None: + """ + Load embedding vectors for a single edge into the instance (e.g., set edge.embedding or similar). + """ + raise NotImplementedError + + async def edge_load_embeddings_bulk( + self, + _cls: Any, + driver: Any, + transaction: Any, + edges: list[Any], + batch_size: int = 100, + ) -> None: + """ + Load embedding vectors for many edges in batches. Mutates the provided edge instances. + """ + raise NotImplementedError diff --git a/graphiti_core/driver/search_interface/__init__.py b/graphiti_core/driver/search_interface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphiti_core/driver/search_interface/search_interface.py b/graphiti_core/driver/search_interface/search_interface.py new file mode 100644 index 00000000..0abf024d --- /dev/null +++ b/graphiti_core/driver/search_interface/search_interface.py @@ -0,0 +1,89 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Any + +from pydantic import BaseModel + + +class SearchInterface(BaseModel): + """ + This is an interface for implementing custom search logic + """ + + async def edge_fulltext_search( + self, + driver: Any, + query: str, + search_filter: Any, + group_ids: list[str] | None = None, + limit: int = 100, + ) -> list[Any]: + raise NotImplementedError + + async def edge_similarity_search( + self, + driver: Any, + search_vector: list[float], + source_node_uuid: str | None, + target_node_uuid: str | None, + search_filter: Any, + group_ids: list[str] | None = None, + limit: int = 100, + min_score: float = 0.7, + ) -> list[Any]: + raise NotImplementedError + + async def node_fulltext_search( + self, + driver: Any, + query: str, + search_filter: Any, + group_ids: list[str] | None = None, + limit: int = 100, + ) -> list[Any]: + raise NotImplementedError + + async def node_similarity_search( + self, + driver: Any, + search_vector: list[float], + search_filter: Any, + group_ids: list[str] | None = None, + limit: int = 100, + min_score: float = 0.7, + ) -> list[Any]: + raise NotImplementedError + + async def episode_fulltext_search( + self, + driver: Any, + query: str, + search_filter: Any, # kept for parity even if unused in your impl + group_ids: list[str] | None = None, + limit: int = 100, + ) -> list[Any]: + raise NotImplementedError + + # ---------- SEARCH FILTERS (sync) ---------- + def build_node_search_filters(self, search_filters: Any) -> Any: + raise NotImplementedError + + def build_edge_search_filters(self, search_filters: Any) -> Any: + raise NotImplementedError + + class Config: + arbitrary_types_allowed = True diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 88d2a472..82066e73 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -25,7 +25,7 @@ from uuid import uuid4 from pydantic import BaseModel, Field from typing_extensions import LiteralString -from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider +from graphiti_core.driver.driver import GraphDriver, GraphProvider from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError from graphiti_core.helpers import parse_db_date @@ -53,6 +53,9 @@ class Edge(BaseModel, ABC): async def save(self, driver: GraphDriver): ... async def delete(self, driver: GraphDriver): + if driver.graph_operations_interface: + return await driver.graph_operations_interface.edge_delete(self, driver) + if driver.provider == GraphProvider.KUZU: await driver.execute_query( """ @@ -77,17 +80,13 @@ class Edge(BaseModel, ABC): uuid=self.uuid, ) - if driver.aoss_client: - await driver.aoss_client.delete( - index=ENTITY_EDGE_INDEX_NAME, - id=self.uuid, - params={'routing': self.group_id}, - ) - logger.debug(f'Deleted Edge: {self.uuid}') @classmethod async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]): + if driver.graph_operations_interface: + return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids) + if driver.provider == GraphProvider.KUZU: await driver.execute_query( """ @@ -115,12 +114,6 @@ class Edge(BaseModel, ABC): uuids=uuids, ) - if driver.aoss_client: - await driver.aoss_client.delete_by_query( - index=ENTITY_EDGE_INDEX_NAME, - body={'query': {'terms': {'uuid': uuids}}}, - ) - logger.debug(f'Deleted Edges: {uuids}') def __hash__(self): @@ -258,6 +251,9 @@ class EntityEdge(Edge): return self.fact_embedding async def load_fact_embedding(self, driver: GraphDriver): + if driver.graph_operations_interface: + return await driver.graph_operations_interface.edge_load_embeddings(self, driver) + query = """ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) RETURN e.fact_embedding AS fact_embedding @@ -268,21 +264,6 @@ class EntityEdge(Edge): MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding """ - elif driver.aoss_client: - resp = await driver.aoss_client.search( - body={ - 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, - 'size': 1, - }, - index=ENTITY_EDGE_INDEX_NAME, - params={'routing': self.group_id}, - ) - - if resp['hits']['hits']: - self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding'] - return - else: - raise EdgeNotFoundError(self.uuid) if driver.provider == GraphProvider.KUZU: query = """ @@ -320,15 +301,11 @@ class EntityEdge(Edge): if driver.provider == GraphProvider.KUZU: edge_data['attributes'] = json.dumps(self.attributes) result = await driver.execute_query( - get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)), + get_entity_edge_save_query(driver.provider), **edge_data, ) else: edge_data.update(self.attributes or {}) - - if driver.aoss_client: - await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue - result = await driver.execute_query( get_entity_edge_save_query(driver.provider), edge_data=edge_data, diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 4105c88e..df764ba9 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -27,10 +27,6 @@ from pydantic import BaseModel, Field from typing_extensions import LiteralString from graphiti_core.driver.driver import ( - COMMUNITY_INDEX_NAME, - ENTITY_EDGE_INDEX_NAME, - ENTITY_INDEX_NAME, - EPISODE_INDEX_NAME, GraphDriver, GraphProvider, ) @@ -99,6 +95,9 @@ class Node(BaseModel, ABC): async def save(self, driver: GraphDriver): ... async def delete(self, driver: GraphDriver): + if driver.graph_operations_interface: + return await driver.graph_operations_interface.node_delete(self, driver) + match driver.provider: case GraphProvider.NEO4J: records, _, _ = await driver.execute_query( @@ -113,27 +112,6 @@ class Node(BaseModel, ABC): uuid=self.uuid, ) - edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else [] - - if driver.aoss_client: - # Delete the node from OpenSearch indices - for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): - await driver.aoss_client.delete( - index=index, - id=self.uuid, - params={'routing': self.group_id}, - ) - - # Bulk delete the detached edges - if edge_uuids: - actions = [] - for eid in edge_uuids: - actions.append( - {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}} - ) - - await driver.aoss_client.bulk(body=actions) - case GraphProvider.KUZU: for label in ['Episodic', 'Community']: await driver.execute_query( @@ -181,6 +159,11 @@ class Node(BaseModel, ABC): @classmethod async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100): + if driver.graph_operations_interface: + return await driver.graph_operations_interface.node_delete_by_group_id( + cls, driver, group_id, batch_size + ) + match driver.provider: case GraphProvider.NEO4J: async with driver.session() as session: @@ -196,31 +179,6 @@ class Node(BaseModel, ABC): batch_size=batch_size, ) - if driver.aoss_client: - await driver.aoss_client.delete_by_query( - index=EPISODE_INDEX_NAME, - body={'query': {'term': {'group_id': group_id}}}, - params={'routing': group_id}, - ) - - await driver.aoss_client.delete_by_query( - index=ENTITY_INDEX_NAME, - body={'query': {'term': {'group_id': group_id}}}, - params={'routing': group_id}, - ) - - await driver.aoss_client.delete_by_query( - index=COMMUNITY_INDEX_NAME, - body={'query': {'term': {'group_id': group_id}}}, - params={'routing': group_id}, - ) - - await driver.aoss_client.delete_by_query( - index=ENTITY_EDGE_INDEX_NAME, - body={'query': {'term': {'group_id': group_id}}}, - params={'routing': group_id}, - ) - case GraphProvider.KUZU: for label in ['Episodic', 'Community']: await driver.execute_query( @@ -258,6 +216,11 @@ class Node(BaseModel, ABC): @classmethod async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100): + if driver.graph_operations_interface: + return await driver.graph_operations_interface.node_delete_by_uuids( + cls, driver, uuids, group_id=None, batch_size=batch_size + ) + match driver.provider: case GraphProvider.FALKORDB: for label in ['Entity', 'Episodic', 'Community']: @@ -300,7 +263,7 @@ class Node(BaseModel, ABC): case _: # Neo4J, Neptune async with driver.session() as session: # Collect all edge UUIDs before deleting nodes - result = await session.run( + await session.run( """ MATCH (n:Entity|Episodic|Community) WHERE n.uuid IN $uuids @@ -310,11 +273,6 @@ class Node(BaseModel, ABC): uuids=uuids, ) - record = await result.single() - edge_uuids: list[str] = ( - record['edge_uuids'] if record and record['edge_uuids'] else [] - ) - # Now delete the nodes in batches await session.run( """ @@ -329,20 +287,6 @@ class Node(BaseModel, ABC): batch_size=batch_size, ) - if driver.aoss_client: - for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): - await driver.aoss_client.delete_by_query( - index=index, - body={'query': {'terms': {'uuid': uuids}}}, - ) - - if edge_uuids: - actions = [ - {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}} - for eid in edge_uuids - ] - await driver.aoss_client.bulk(body=actions) - @classmethod async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ... @@ -363,6 +307,9 @@ class EpisodicNode(Node): ) async def save(self, driver: GraphDriver): + if driver.graph_operations_interface: + return await driver.graph_operations_interface.episodic_node_save(self, driver) + episode_args = { 'uuid': self.uuid, 'name': self.name, @@ -375,12 +322,6 @@ class EpisodicNode(Node): 'source': self.source.value, } - if driver.aoss_client: - await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue - 'episodes', - [episode_args], - ) - result = await driver.execute_query( get_episode_node_save_query(driver.provider), **episode_args ) @@ -510,26 +451,14 @@ class EntityNode(Node): return self.name_embedding async def load_name_embedding(self, driver: GraphDriver): + if driver.graph_operations_interface: + return await driver.graph_operations_interface.node_load_embeddings(self, driver) + if driver.provider == GraphProvider.NEPTUNE: query: LiteralString = """ MATCH (n:Entity {uuid: $uuid}) RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding """ - elif driver.aoss_client: - resp = await driver.aoss_client.search( - body={ - 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, - 'size': 1, - }, - index=ENTITY_INDEX_NAME, - params={'routing': self.group_id}, - ) - - if resp['hits']['hits']: - self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding'] - return - else: - raise NodeNotFoundError(self.uuid) else: query: LiteralString = """ @@ -548,6 +477,9 @@ class EntityNode(Node): self.name_embedding = records[0]['name_embedding'] async def save(self, driver: GraphDriver): + if driver.graph_operations_interface: + return await driver.graph_operations_interface.node_save(self, driver) + entity_data: dict[str, Any] = { 'uuid': self.uuid, 'name': self.name, @@ -568,11 +500,8 @@ class EntityNode(Node): entity_data.update(self.attributes or {}) labels = ':'.join(self.labels + ['Entity']) - if driver.aoss_client: - await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue - result = await driver.execute_query( - get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), + get_entity_node_save_query(driver.provider, labels), entity_data=entity_data, ) diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index 37cf7e82..1534b926 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -249,41 +249,3 @@ def edge_search_filter_query_constructor( filter_queries.append(expired_at_filter) return filter_queries, filter_params - - -def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]: - filters = [{'terms': {'group_id': group_ids}}] - - if search_filters.node_labels: - filters.append({'terms': {'node_labels': search_filters.node_labels}}) - - return filters - - -def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]: - filters: list[dict] = [{'terms': {'group_id': group_ids}}] - - if search_filters.edge_types: - filters.append({'terms': {'edge_types': search_filters.edge_types}}) - - if search_filters.edge_uuids: - filters.append({'terms': {'uuid': search_filters.edge_uuids}}) - - for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']: - ranges = getattr(search_filters, field) - if ranges: - # OR of ANDs - should_clauses = [] - for and_group in ranges: - and_filters = [] - for df in and_group: # df is a DateFilter - range_query = { - 'range': { - field: {cypher_to_opensearch_operator(df.comparison_operator): df.date} - } - } - and_filters.append(range_query) - should_clauses.append({'bool': {'filter': and_filters}}) - filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}}) - - return filters diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 6d70c99c..104aede6 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -24,9 +24,6 @@ from numpy._typing import NDArray from typing_extensions import LiteralString from graphiti_core.driver.driver import ( - ENTITY_EDGE_INDEX_NAME, - ENTITY_INDEX_NAME, - EPISODE_INDEX_NAME, GraphDriver, GraphProvider, ) @@ -57,8 +54,6 @@ from graphiti_core.nodes import ( ) from graphiti_core.search.search_filters import ( SearchFilters, - build_aoss_edge_filters, - build_aoss_node_filters, edge_search_filter_query_constructor, node_search_filter_query_constructor, ) @@ -179,6 +174,11 @@ async def edge_fulltext_search( group_ids: list[str] | None = None, limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: + if driver.search_interface: + return await driver.search_interface.edge_fulltext_search( + driver, query, search_filter, group_ids, limit + ) + # fulltext search over facts fuzzy_query = fulltext_query(query, group_ids, driver) @@ -217,11 +217,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -253,35 +253,6 @@ async def edge_fulltext_search( ) else: return [] - elif driver.aoss_client: - route = group_ids[0] if group_ids else None - filters = build_aoss_edge_filters(group_ids or [], search_filter) - res = await driver.aoss_client.search( - index=ENTITY_EDGE_INDEX_NAME, - params={'routing': route}, - body={ - 'size': limit, - '_source': ['uuid'], - 'query': { - 'bool': { - 'filter': filters, - 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}], - } - }, - }, - ) - - if res['hits']['total']['value'] > 0: - input_uuids = {} - for r in res['hits']['hits']: - input_uuids[r['_source']['uuid']] = r['_score'] - - # Get edges - entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) - entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) - return entity_edges - else: - return [] else: query = ( get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider) @@ -321,6 +292,18 @@ async def edge_similarity_search( limit: int = RELEVANT_SCHEMA_LIMIT, min_score: float = DEFAULT_MIN_SCORE, ) -> list[EntityEdge]: + if driver.search_interface: + return await driver.search_interface.edge_similarity_search( + driver, + search_vector, + source_node_uuid, + target_node_uuid, + search_filter, + group_ids, + limit, + min_score, + ) + match_query = """ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) """ @@ -356,8 +339,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -415,38 +398,6 @@ async def edge_similarity_search( ) else: return [] - elif driver.aoss_client: - route = group_ids[0] if group_ids else None - filters = build_aoss_edge_filters(group_ids or [], search_filter) - res = await driver.aoss_client.search( - index=ENTITY_EDGE_INDEX_NAME, - params={'routing': route}, - body={ - 'size': limit, - '_source': ['uuid'], - 'query': { - 'knn': { - 'fact_embedding': { - 'vector': list(map(float, search_vector)), - 'k': limit, - 'filter': {'bool': {'filter': filters}}, - } - } - }, - }, - ) - - if res['hits']['total']['value'] > 0: - input_uuids = {} - for r in res['hits']['hits']: - input_uuids[r['_source']['uuid']] = r['_score'] - - # Get edges - entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) - entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) - return entity_edges - return [] - else: query = ( match_query @@ -609,6 +560,11 @@ async def node_fulltext_search( group_ids: list[str] | None = None, limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: + if driver.search_interface: + return await driver.search_interface.node_fulltext_search( + driver, query, search_filter, group_ids, limit + ) + # BM25 search to get top nodes fuzzy_query = fulltext_query(query, group_ids, driver) if fuzzy_query == '': @@ -640,11 +596,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -661,43 +617,6 @@ async def node_fulltext_search( ) else: return [] - elif driver.aoss_client: - route = group_ids[0] if group_ids else None - filters = build_aoss_node_filters(group_ids or [], search_filter) - res = await driver.aoss_client.search( - index=ENTITY_INDEX_NAME, - params={'routing': route}, - body={ - '_source': ['uuid'], - 'size': limit, - 'query': { - 'bool': { - 'filter': filters, - 'must': [ - { - 'multi_match': { - 'query': query, - 'fields': ['name', 'summary'], - 'operator': 'or', - } - } - ], - } - }, - }, - ) - - if res['hits']['total']['value'] > 0: - input_uuids = {} - for r in res['hits']['hits']: - input_uuids[r['_source']['uuid']] = r['_score'] - - # Get nodes - entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) - entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) - return entities - else: - return [] else: query = ( get_nodes_query( @@ -735,6 +654,11 @@ async def node_similarity_search( limit=RELEVANT_SCHEMA_LIMIT, min_score: float = DEFAULT_MIN_SCORE, ) -> list[EntityNode]: + if driver.search_interface: + return await driver.search_interface.node_similarity_search( + driver, search_vector, search_filter, group_ids, limit, min_score + ) + filter_queries, filter_params = node_search_filter_query_constructor( search_filter, driver.provider ) @@ -754,8 +678,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -784,11 +708,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -806,42 +730,11 @@ async def node_similarity_search( ) else: return [] - elif driver.aoss_client: - route = group_ids[0] if group_ids else None - filters = build_aoss_node_filters(group_ids or [], search_filter) - res = await driver.aoss_client.search( - index=ENTITY_INDEX_NAME, - params={'routing': route}, - body={ - 'size': limit, - '_source': ['uuid'], - 'query': { - 'knn': { - 'name_embedding': { - 'vector': list(map(float, search_vector)), - 'k': limit, - 'filter': {'bool': {'filter': filters}}, - } - } - }, - }, - ) - - if res['hits']['total']['value'] > 0: - input_uuids = {} - for r in res['hits']['hits']: - input_uuids[r['_source']['uuid']] = r['_score'] - - # Get edges - entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) - entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) - return entity_nodes - return [] else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -966,6 +859,11 @@ async def episode_fulltext_search( group_ids: list[str] | None = None, limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EpisodicNode]: + if driver.search_interface: + return await driver.search_interface.episode_fulltext_search( + driver, query, _search_filter, group_ids, limit + ) + # BM25 search to get top episodes fuzzy_query = fulltext_query(query, group_ids, driver) if fuzzy_query == '': @@ -1012,40 +910,6 @@ async def episode_fulltext_search( ) else: return [] - elif driver.aoss_client: - route = group_ids[0] if group_ids else None - res = await driver.aoss_client.search( - index=EPISODE_INDEX_NAME, - params={'routing': route}, - body={ - 'size': limit, - '_source': ['uuid'], - 'bool': { - 'filter': {'terms': group_ids}, - 'must': [ - { - 'multi_match': { - 'query': query, - 'field': ['name', 'content'], - 'operator': 'or', - } - } - ], - }, - }, - ) - - if res['hits']['total']['value'] > 0: - input_uuids = {} - for r in res['hits']['hits']: - input_uuids[r['_source']['uuid']] = r['_score'] - - # Get nodes - episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys())) - episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) - return episodes - else: - return [] else: query = ( get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider) @@ -1173,8 +1037,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1233,8 +1097,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1376,9 +1240,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1423,9 +1287,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1514,9 +1378,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1586,9 +1450,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1624,9 +1488,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1699,10 +1563,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1772,10 +1636,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1811,10 +1675,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index dfbcb109..049aa53e 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -24,9 +24,6 @@ from pydantic import BaseModel, Field from typing_extensions import Any from graphiti_core.driver.driver import ( - ENTITY_EDGE_INDEX_NAME, - ENTITY_INDEX_NAME, - EPISODE_INDEX_NAME, GraphDriver, GraphDriverSession, GraphProvider, @@ -177,12 +174,10 @@ async def add_nodes_and_edges_bulk_tx( 'group_id': node.group_id, 'summary': node.summary, 'created_at': node.created_at, + 'name_embedding': node.name_embedding, + 'labels': list(set(node.labels + ['Entity'])), } - if not bool(driver.aoss_client): - entity_data['name_embedding'] = node.name_embedding - - entity_data['labels'] = list(set(node.labels + ['Entity'])) if driver.provider == GraphProvider.KUZU: attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {} entity_data['attributes'] = json.dumps(attributes) @@ -207,11 +202,9 @@ async def add_nodes_and_edges_bulk_tx( 'expired_at': edge.expired_at, 'valid_at': edge.valid_at, 'invalid_at': edge.invalid_at, + 'fact_embedding': edge.fact_embedding, } - if not bool(driver.aoss_client): - edge_data['fact_embedding'] = edge.fact_embedding - if driver.provider == GraphProvider.KUZU: attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {} edge_data['attributes'] = json.dumps(attributes) @@ -220,7 +213,17 @@ async def add_nodes_and_edges_bulk_tx( edges.append(edge_data) - if driver.provider == GraphProvider.KUZU: + if driver.graph_operations_interface: + await driver.graph_operations_interface.episodic_node_save_bulk( + None, driver, tx, episodic_nodes + ) + await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes) + await driver.graph_operations_interface.episodic_edge_save_bulk( + None, driver, tx, episodic_edges + ) + await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges) + + elif driver.provider == GraphProvider.KUZU: # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now. episode_query = get_episode_node_save_bulk_query(driver.provider) for episode in episodes: @@ -237,9 +240,7 @@ async def add_nodes_and_edges_bulk_tx( else: await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes) await tx.run( - get_entity_node_save_bulk_query( - driver.provider, nodes, has_aoss=bool(driver.aoss_client) - ), + get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes, ) await tx.run( @@ -247,23 +248,10 @@ async def add_nodes_and_edges_bulk_tx( episodic_edges=[edge.model_dump() for edge in episodic_edges], ) await tx.run( - get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)), + get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges, ) - if bool(driver.aoss_client): - for node_data, entity_node in zip(nodes, entity_nodes, strict=True): - if node_data.get('uuid') == entity_node.uuid: - node_data['name_embedding'] = entity_node.name_embedding - - for edge_data, entity_edge in zip(edges, entity_edges, strict=True): - if edge_data.get('uuid') == entity_edge.uuid: - edge_data['fact_embedding'] = entity_edge.fact_embedding - - await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes) - await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes) - await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges) - async def extract_nodes_and_edges_bulk( clients: GraphitiClients, diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index f62fc6a2..e9aa3c8b 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -34,9 +34,6 @@ logger = logging.getLogger(__name__) async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False): - if driver.aoss_client: - await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue] - return if delete_existing: records, _, _ = await driver.execute_query( """ @@ -56,8 +53,8 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo range_indices: list[LiteralString] = get_range_indices(driver.provider) - # Don't create fulltext indices if OpenSearch is being used - if not driver.aoss_client: + # Don't create fulltext indices if search_interface is being used + if not driver.search_interface: fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider) if driver.provider == GraphProvider.KUZU: @@ -95,8 +92,6 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None): async def delete_all(tx): await tx.run('MATCH (n) DETACH DELETE n') - if driver.aoss_client: - await driver.clear_aoss_indices() async def delete_group_ids(tx): labels = ['Entity', 'Episodic', 'Community'] @@ -153,9 +148,9 @@ async def retrieve_episodes( query: LiteralString = ( """ - MATCH (e:Episodic) - WHERE e.valid_at <= $reference_time - """ + MATCH (e:Episodic) + WHERE e.valid_at <= $reference_time + """ + query_filter + """ RETURN diff --git a/uv.lock b/uv.lock index 7ba9bde7..c72b0c11 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10, <4" resolution-markers = [ "python_full_version >= '3.14'",