fix duplicate search results bug (#190)

* fix bugs

* format

* syntax
This commit is contained in:
Preston Rasmussen 2024-10-14 21:54:33 -04:00 committed by GitHub
parent 7c5135910e
commit 7c15b729a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 37 additions and 30 deletions

View file

@ -258,6 +258,9 @@ class EntityNode(Node):
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
if len(nodes) == 0:
raise NodeNotFoundError(uuid)
return nodes[0] return nodes[0]
@classmethod @classmethod
@ -351,6 +354,9 @@ class CommunityNode(Node):
nodes = [get_community_node_from_record(record) for record in records] nodes = [get_community_node_from_record(record) for record in records]
if len(nodes) == 0:
raise NodeNotFoundError(uuid)
return nodes[0] return nodes[0]
@classmethod @classmethod

View file

@ -53,12 +53,12 @@ logger = logging.getLogger(__name__)
async def search( async def search(
driver: AsyncDriver, driver: AsyncDriver,
embedder: EmbedderClient, embedder: EmbedderClient,
query: str, query: str,
group_ids: list[str] | None, group_ids: list[str] | None,
config: SearchConfig, config: SearchConfig,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
) -> SearchResults: ) -> SearchResults:
start = time() start = time()
query = query.replace('\n', ' ') query = query.replace('\n', ' ')
@ -107,13 +107,13 @@ async def search(
async def edge_search( async def edge_search(
driver: AsyncDriver, driver: AsyncDriver,
embedder: EmbedderClient, embedder: EmbedderClient,
query: str, query: str,
group_ids: list[str] | None, group_ids: list[str] | None,
config: EdgeSearchConfig | None, config: EdgeSearchConfig | None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
if config is None: if config is None:
return [] return []
@ -160,7 +160,7 @@ async def edge_search(
for edge in sorted_results: for edge in sorted_results:
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid) source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
source_uuids = [edge.source_node_uuid for edge in sorted_results] source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid) reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
@ -176,13 +176,13 @@ async def edge_search(
async def node_search( async def node_search(
driver: AsyncDriver, driver: AsyncDriver,
embedder: EmbedderClient, embedder: EmbedderClient,
query: str, query: str,
group_ids: list[str] | None, group_ids: list[str] | None,
config: NodeSearchConfig | None, config: NodeSearchConfig | None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
if config is None: if config is None:
return [] return []
@ -230,12 +230,12 @@ async def node_search(
async def community_search( async def community_search(
driver: AsyncDriver, driver: AsyncDriver,
embedder: EmbedderClient, embedder: EmbedderClient,
query: str, query: str,
group_ids: list[str] | None, group_ids: list[str] | None,
config: CommunitySearchConfig | None, config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
if config is None: if config is None:
return [] return []

View file

@ -68,7 +68,7 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
) )
# performs a hybrid search over edges with mmr reranking # performs a hybrid search over edges with mmr reranking
EDGE_HYBRID_SEARCH_mmr = SearchConfig( EDGE_HYBRID_SEARCH_MMR = SearchConfig(
edge_config=EdgeSearchConfig( edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.mmr, reranker=EdgeReranker.mmr,
@ -80,7 +80,8 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
edge_config=EdgeSearchConfig( edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.node_distance, reranker=EdgeReranker.node_distance,
) ),
limit=30,
) )
# performs a hybrid search over edges with episode mention reranking # performs a hybrid search over edges with episode mention reranking

View file

@ -544,7 +544,7 @@ async def node_distance_reranker(
# rerank on shortest distance # rerank on shortest distance
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
# add back in filtered center uuids # add back in filtered center uuid
filtered_uuids = [center_node_uuid] + filtered_uuids filtered_uuids = [center_node_uuid] + filtered_uuids
return filtered_uuids return filtered_uuids