From 6441be9934ece994ce1ee2b764faaca1f92c424a Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Sun, 7 Sep 2025 22:34:15 -0400 Subject: [PATCH] add search filters and similarity search --- graphiti_core/driver/driver.py | 45 ++-- graphiti_core/search/search_filters.py | 31 +++ graphiti_core/search/search_utils.py | 281 +++++++++++++++---------- 3 files changed, 211 insertions(+), 146 deletions(-) diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index d5f61913..161cc7a5 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -52,20 +52,16 @@ aoss_indices = [ 'dims': 1024, 'index': True, 'similarity': 'cosine', + 'method': { + 'engine': 'faiss', + 'space_type': 'cosinesimil', + 'name': 'hnsw', + 'parameters': {'ef_construction': 128, 'm': 16}, + }, }, } } }, - 'query': { - 'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}}, - 'size': DEFAULT_SIZE, - 'knn': { - 'field': 'name_embedding', - 'query_vector': [], - 'k': DEFAULT_SIZE, - 'num_candidates': 100, - }, - }, }, { 'index_name': 'communities', @@ -78,10 +74,6 @@ aoss_indices = [ } } }, - 'query': { - 'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}}, - 'size': DEFAULT_SIZE, - }, }, { 'index_name': 'episodes', @@ -98,15 +90,6 @@ aoss_indices = [ } } }, - 'query': { - 'query': { - 'multi_match': { - 'query': '', - 'fields': ['content', 'source', 'source_description', 'group_id'], - } - }, - 'size': DEFAULT_SIZE, - }, }, { 'index_name': 'entity_edges', @@ -126,20 +109,16 @@ aoss_indices = [ 'dims': 1024, 'index': True, 'similarity': 'cosine', + 'method': { + 'engine': 'faiss', + 'space_type': 'cosinesimil', + 'name': 'hnsw', + 'parameters': {'ef_construction': 128, 'm': 16}, + }, }, } } }, - 'query': { - 'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}}, - 'size': DEFAULT_SIZE, - 'knn': { - 'field': 'fact_embedding', - 'query_vector': [], # supply vector at runtime - 'k': 
DEFAULT_SIZE, - 'num_candidates': 100, - }, - }, }, ] diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index 93cab5ba..e354a1a3 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -234,3 +234,34 @@ def edge_search_filter_query_constructor( filter_queries.append(expired_at_filter) return filter_queries, filter_params + + +def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]: + filters = [{'terms': {'group_id': group_ids}}] # 'terms' (not 'term') accepts a list of values + + if search_filters.node_labels: + filters.append({'terms': {'node_labels': search_filters.node_labels}}) + + return filters + + +def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]: + filters = [{'terms': {'group_id': group_ids}}] # 'terms' (not 'term') accepts a list of values + + if search_filters.edge_types: + filters.append({'terms': {'edge_types': search_filters.edge_types}}) + + for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']: + ranges = getattr(search_filters, field) + if ranges: + # OR of ANDs + should_clauses = [] + for and_group in ranges: + and_filters = [] + for df in and_group: # df is a DateFilter + range_query = {'range': {field: {df.op: df.value}}} + and_filters.append(range_query) + should_clauses.append({'bool': {'filter': and_filters}}) + filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}}) + + return filters diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 611fb6d4..b2a0c5f2 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -51,6 +51,8 @@ from graphiti_core.nodes import ( ) from graphiti_core.search.search_filters import ( SearchFilters, + build_aoss_edge_filters, + build_aoss_node_filters, edge_search_filter_query_constructor, node_search_filter_query_constructor, ) @@ -200,7 +202,6 @@ async def edge_fulltext_search( if driver.provider == GraphProvider.NEPTUNE: 
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue if res['hits']['total']['value'] > 0: - # Calculate Cosine similarity then return the edge ids input_ids = [] for r in res['hits']['hits']: input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) @@ -208,11 +209,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -245,38 +246,26 @@ async def edge_fulltext_search( else: return [] elif driver.aoss_client: - res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue + filters = build_aoss_edge_filters(group_ids, search_filter) + res = driver.aoss_client.search( + index='entity_edges', + routing=group_ids, + _source=['uuid'], + query={ + 'bool': { + 'filter': filters, + 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}], + } + }, + ) if res['hits']['total']['value'] > 0: - # Calculate Cosine similarity then return the edge ids - input_uuids = [] + input_uuids = {} for r in res['hits']['hits']: - input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']}) + input_uuids[r['_source']['uuid']] = r['_score'] - # Match the edge ids and return the values - query = ( - """ - UNWIND $uuids as uuid - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND e.uuid=uuid - """ - + filter_query - + """ - AND e.uuid=uuid - WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m""" - + get_entity_edge_return_query(driver.provider) - + """ORDER BY score DESC LIMIT $limit - """ - ) - - records, _, _ = await driver.execute_query( - query, - query=fuzzy_query, - uuids=input_uuids, - limit=limit, - routing_='r', 
- **filter_params, - ) + # Get edges, ordered by descending search score + entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) + return sorted(entity_edges, key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) else: return [] else: @@ -353,8 +342,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -412,6 +401,30 @@ async def edge_similarity_search( ) else: return [] + elif driver.aoss_client: + filters = build_aoss_edge_filters(group_ids, search_filter) + res = driver.aoss_client.search( + index='entity_edges', + routing=group_ids, + _source=['uuid'], + knn={ + 'field': 'fact_embedding', + 'query_vector': search_vector, + 'k': limit, + 'num_candidates': 1000, + }, + query={'bool': {'filter': filters}}, + ) + + if res['hits']['total']['value'] > 0: + input_uuids = {r['_source']['uuid']: r['_score'] for r in res['hits']['hits']} + + # Get edges, ordered by descending search score + entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) + return sorted(entity_edges, key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + else: + return [] + else: query = ( match_query @@ -598,7 +611,6 @@ async def node_fulltext_search( if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue if res['hits']['total']['value'] > 0: - # Calculate Cosine similarity then return the edge ids input_ids = [] for r in res['hits']['hits']: input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) @@ -606,11 +618,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + 
get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -628,35 +640,36 @@ async def node_fulltext_search( else: return [] elif driver.aoss_client: - res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue - if res['hits']['total']['value'] > 0: - # Calculate Cosine similarity then return the edge ids - input_uuids = [] - for r in res['hits']['hits']: - input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']}) + filters = build_aoss_node_filters(group_ids, search_filter) + res = driver.aoss_client.search( + index='entities', + routing=group_ids, + _source=['uuid'], + query={ + 'bool': { + 'filter': filters, + 'must': [ + { + 'multi_match': { + 'query': query, + 'fields': ['name', 'summary'], + 'operator': 'or', + } + } + ], + } + }, + size=limit, + ) - # Match the edge ides and return the values - query = ( - """ - UNWIND $uuids as i - MATCH (n:Entity) - WHERE n.uuid=i.uuid - RETURN - """ - + get_entity_node_return_query(driver.provider) - + """ - ORDER BY i.score DESC - LIMIT $limit - """ - ) - records, _, _ = await driver.execute_query( - query, - uuids=input_uuids, - query=fuzzy_query, - limit=limit, - routing_='r', - **filter_params, - ) + if res['hits']['total']['value'] > 0: + input_uuids = {} + for r in res['hits']['hits']: + input_uuids[r['_source']['uuid']] = r['_score'] + + # Get nodes, ordered by descending search score + entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) + return sorted(entities, key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) else: return [] else: @@ -715,8 +728,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -745,11 +758,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - 
RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -767,11 +780,34 @@ async def node_similarity_search( ) else: return [] + elif driver.aoss_client: + filters = build_aoss_node_filters(group_ids, search_filter) + res = driver.aoss_client.search( + index='entities', + routing=group_ids, + _source=['uuid'], + knn={ + 'field': 'name_embedding', + 'query_vector': search_vector, + 'k': limit, + 'num_candidates': 1000, + }, + query={'bool': {'filter': filters}}, + ) + + if res['hits']['total']['value'] > 0: + input_uuids = {r['_source']['uuid']: r['_score'] for r in res['hits']['hits']} + + # Get nodes, ordered by descending search score + entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) + return sorted(entities, key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + else: + return [] else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -910,7 +946,6 @@ async def episode_fulltext_search( if driver.provider == GraphProvider.NEPTUNE: res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue if res['hits']['total']['value'] > 0: - # Calculate Cosine similarity then return the edge ids input_ids = [] for r in res['hits']['hits']: input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) @@ -944,7 +979,27 @@ async def episode_fulltext_search( else: return [] elif driver.aoss_client: - res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue + res = driver.aoss_client.search( + index='episodes', + routing=group_ids, + _source=['uuid'], + query={ + 'bool': { + 'filter': [{'terms': {'group_id': group_ids}}], + 'must': [ + { + 'multi_match': { + 'query': query, + 'fields': ['name', 'content'], + 'operator': 'or', + } + } + ], + } + }, + size=limit, + ) + if res['hits']['total']['value'] > 0: # Calculate Cosine 
similarity then return the edge ids input_uuids = [] @@ -1106,8 +1161,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1166,8 +1221,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1309,9 +1364,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1356,9 +1411,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1447,9 +1502,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1519,9 +1574,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ 
{group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1557,9 +1612,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1632,10 +1687,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1705,10 +1760,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1744,10 +1799,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: 
edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """