From 7c15b729a97cdd04cf16c7392b2d0e0bfaea5873 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Mon, 14 Oct 2024 21:54:33 -0400 Subject: [PATCH] fix duplicate search results bug (#190) * fix bugs * format * syntax --- graphiti_core/nodes.py | 6 +++ graphiti_core/search/search.py | 54 +++++++++---------- graphiti_core/search/search_config_recipes.py | 5 +- graphiti_core/search/search_utils.py | 2 +- 4 files changed, 37 insertions(+), 30 deletions(-) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 8bd203f7..687bb465 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -258,6 +258,9 @@ class EntityNode(Node): nodes = [get_entity_node_from_record(record) for record in records] + if len(nodes) == 0: + raise NodeNotFoundError(uuid) + return nodes[0] @classmethod @@ -351,6 +354,9 @@ class CommunityNode(Node): nodes = [get_community_node_from_record(record) for record in records] + if len(nodes) == 0: + raise NodeNotFoundError(uuid) + return nodes[0] @classmethod diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index e6e81b18..3b9a47cf 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -53,12 +53,12 @@ logger = logging.getLogger(__name__) async def search( - driver: AsyncDriver, - embedder: EmbedderClient, - query: str, - group_ids: list[str] | None, - config: SearchConfig, - center_node_uuid: str | None = None, + driver: AsyncDriver, + embedder: EmbedderClient, + query: str, + group_ids: list[str] | None, + config: SearchConfig, + center_node_uuid: str | None = None, ) -> SearchResults: start = time() query = query.replace('\n', ' ') @@ -107,13 +107,13 @@ async def search( async def edge_search( - driver: AsyncDriver, - embedder: EmbedderClient, - query: str, - group_ids: list[str] | None, - config: EdgeSearchConfig | None, - center_node_uuid: str | None = None, - limit=DEFAULT_SEARCH_LIMIT, + driver: AsyncDriver, + embedder: EmbedderClient, + query: str, + group_ids: list[str] | None, + config: EdgeSearchConfig | None, + center_node_uuid: str | None = None, + limit=DEFAULT_SEARCH_LIMIT, ) -> list[EntityEdge]: if config is None: return [] @@ -160,7 +160,7 @@ async def edge_search( for edge in sorted_results: source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid) - source_uuids = [edge.source_node_uuid for edge in sorted_results] + source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map] reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid) @@ -176,13 +176,13 @@ async def edge_search( async def node_search( - driver: AsyncDriver, - embedder: EmbedderClient, - query: str, - group_ids: list[str] | None, - config: NodeSearchConfig | None, - center_node_uuid: str | None = None, - limit=DEFAULT_SEARCH_LIMIT, + driver: AsyncDriver, + embedder: EmbedderClient, + query: str, + group_ids: list[str] | None, + config: NodeSearchConfig | None, + center_node_uuid: str | None = None, + limit=DEFAULT_SEARCH_LIMIT, ) -> list[EntityNode]: if config is None: return [] @@ -230,12 +230,12 @@ async def node_search( async def community_search( - driver: AsyncDriver, - embedder: EmbedderClient, - query: str, - group_ids: list[str] | None, - config: CommunitySearchConfig | None, - limit=DEFAULT_SEARCH_LIMIT, + driver: AsyncDriver, + embedder: EmbedderClient, + query: str, + group_ids: list[str] | None, + config: CommunitySearchConfig | None, + limit=DEFAULT_SEARCH_LIMIT, ) -> list[CommunityNode]: if config is None: return [] diff --git a/graphiti_core/search/search_config_recipes.py b/graphiti_core/search/search_config_recipes.py index 2264cde0..abba4c29 100644 --- a/graphiti_core/search/search_config_recipes.py +++ b/graphiti_core/search/search_config_recipes.py @@ -68,7 +68,7 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig( ) # performs a hybrid search over edges with mmr reranking -EDGE_HYBRID_SEARCH_mmr = SearchConfig( +EDGE_HYBRID_SEARCH_MMR = SearchConfig( edge_config=EdgeSearchConfig( search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], reranker=EdgeReranker.mmr, @@ -80,7 +80,8 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig( edge_config=EdgeSearchConfig( search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], reranker=EdgeReranker.node_distance, - ) + ), + limit=30, ) # performs a hybrid search over edges with episode mention reranking diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index c75dabd8..275999cc 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -544,7 +544,7 @@ async def node_distance_reranker( # rerank on shortest distance filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) - # add back in filtered center uuids + # add back in filtered center uuid filtered_uuids = [center_node_uuid] + filtered_uuids return filtered_uuids