parent
7c5135910e
commit
7c15b729a9
4 changed files with 37 additions and 30 deletions
|
|
@ -258,6 +258,9 @@ class EntityNode(Node):
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
if len(nodes) == 0:
|
||||||
|
raise NodeNotFoundError(uuid)
|
||||||
|
|
||||||
return nodes[0]
|
return nodes[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -351,6 +354,9 @@ class CommunityNode(Node):
|
||||||
|
|
||||||
nodes = [get_community_node_from_record(record) for record in records]
|
nodes = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
if len(nodes) == 0:
|
||||||
|
raise NodeNotFoundError(uuid)
|
||||||
|
|
||||||
return nodes[0]
|
return nodes[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -53,12 +53,12 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder: EmbedderClient,
|
embedder: EmbedderClient,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
start = time()
|
start = time()
|
||||||
query = query.replace('\n', ' ')
|
query = query.replace('\n', ' ')
|
||||||
|
|
@ -107,13 +107,13 @@ async def search(
|
||||||
|
|
||||||
|
|
||||||
async def edge_search(
|
async def edge_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder: EmbedderClient,
|
embedder: EmbedderClient,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: EdgeSearchConfig | None,
|
config: EdgeSearchConfig | None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
@ -160,7 +160,7 @@ async def edge_search(
|
||||||
for edge in sorted_results:
|
for edge in sorted_results:
|
||||||
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
|
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
|
||||||
|
|
||||||
source_uuids = [edge.source_node_uuid for edge in sorted_results]
|
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
|
||||||
|
|
||||||
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
|
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
|
||||||
|
|
||||||
|
|
@ -176,13 +176,13 @@ async def edge_search(
|
||||||
|
|
||||||
|
|
||||||
async def node_search(
|
async def node_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder: EmbedderClient,
|
embedder: EmbedderClient,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: NodeSearchConfig | None,
|
config: NodeSearchConfig | None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
@ -230,12 +230,12 @@ async def node_search(
|
||||||
|
|
||||||
|
|
||||||
async def community_search(
|
async def community_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder: EmbedderClient,
|
embedder: EmbedderClient,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: CommunitySearchConfig | None,
|
config: CommunitySearchConfig | None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
)
|
)
|
||||||
|
|
||||||
# performs a hybrid search over edges with mmr reranking
|
# performs a hybrid search over edges with mmr reranking
|
||||||
EDGE_HYBRID_SEARCH_mmr = SearchConfig(
|
EDGE_HYBRID_SEARCH_MMR = SearchConfig(
|
||||||
edge_config=EdgeSearchConfig(
|
edge_config=EdgeSearchConfig(
|
||||||
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
||||||
reranker=EdgeReranker.mmr,
|
reranker=EdgeReranker.mmr,
|
||||||
|
|
@ -80,7 +80,8 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
||||||
edge_config=EdgeSearchConfig(
|
edge_config=EdgeSearchConfig(
|
||||||
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
||||||
reranker=EdgeReranker.node_distance,
|
reranker=EdgeReranker.node_distance,
|
||||||
)
|
),
|
||||||
|
limit=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
# performs a hybrid search over edges with episode mention reranking
|
# performs a hybrid search over edges with episode mention reranking
|
||||||
|
|
|
||||||
|
|
@ -544,7 +544,7 @@ async def node_distance_reranker(
|
||||||
# rerank on shortest distance
|
# rerank on shortest distance
|
||||||
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||||
|
|
||||||
# add back in filtered center uuids
|
# add back in filtered center uuid
|
||||||
filtered_uuids = [center_node_uuid] + filtered_uuids
|
filtered_uuids = [center_node_uuid] + filtered_uuids
|
||||||
|
|
||||||
return filtered_uuids
|
return filtered_uuids
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue