fix duplicate search results bug (#190)

* fix bugs

* format

* syntax
This commit is contained in:
Preston Rasmussen 2024-10-14 21:54:33 -04:00 committed by GitHub
parent 7c5135910e
commit 7c15b729a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 37 additions and 30 deletions

View file

@ -258,6 +258,9 @@ class EntityNode(Node):
nodes = [get_entity_node_from_record(record) for record in records]
if len(nodes) == 0:
raise NodeNotFoundError(uuid)
return nodes[0]
@classmethod
@ -351,6 +354,9 @@ class CommunityNode(Node):
nodes = [get_community_node_from_record(record) for record in records]
if len(nodes) == 0:
raise NodeNotFoundError(uuid)
return nodes[0]
@classmethod

View file

@ -53,12 +53,12 @@ logger = logging.getLogger(__name__)
async def search(
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
center_node_uuid: str | None = None,
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults:
start = time()
query = query.replace('\n', ' ')
@ -107,13 +107,13 @@ async def search(
async def edge_search(
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
if config is None:
return []
@ -160,7 +160,7 @@ async def edge_search(
for edge in sorted_results:
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
source_uuids = [edge.source_node_uuid for edge in sorted_results]
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
@ -176,13 +176,13 @@ async def edge_search(
async def node_search(
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
if config is None:
return []
@ -230,12 +230,12 @@ async def node_search(
async def community_search(
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
driver: AsyncDriver,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
if config is None:
return []

View file

@ -68,7 +68,7 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
)
# performs a hybrid search over edges with mmr reranking
EDGE_HYBRID_SEARCH_mmr = SearchConfig(
EDGE_HYBRID_SEARCH_MMR = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.mmr,
@ -80,7 +80,8 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.node_distance,
)
),
limit=30,
)
# performs a hybrid search over edges with episode mention reranking

View file

@ -544,7 +544,7 @@ async def node_distance_reranker(
# rerank on shortest distance
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
# add back in filtered center uuids
# add back in filtered center uuid
filtered_uuids = [center_node_uuid] + filtered_uuids
return filtered_uuids