parent
7c5135910e
commit
7c15b729a9
4 changed files with 37 additions and 30 deletions
|
|
@ -258,6 +258,9 @@ class EntityNode(Node):
|
|||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
if len(nodes) == 0:
|
||||
raise NodeNotFoundError(uuid)
|
||||
|
||||
return nodes[0]
|
||||
|
||||
@classmethod
|
||||
|
|
@ -351,6 +354,9 @@ class CommunityNode(Node):
|
|||
|
||||
nodes = [get_community_node_from_record(record) for record in records]
|
||||
|
||||
if len(nodes) == 0:
|
||||
raise NodeNotFoundError(uuid)
|
||||
|
||||
return nodes[0]
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -53,12 +53,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
async def search(
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> SearchResults:
|
||||
start = time()
|
||||
query = query.replace('\n', ' ')
|
||||
|
|
@ -107,13 +107,13 @@ async def search(
|
|||
|
||||
|
||||
async def edge_search(
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: EdgeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: EdgeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
if config is None:
|
||||
return []
|
||||
|
|
@ -160,7 +160,7 @@ async def edge_search(
|
|||
for edge in sorted_results:
|
||||
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
|
||||
|
||||
source_uuids = [edge.source_node_uuid for edge in sorted_results]
|
||||
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
|
||||
|
||||
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
|
||||
|
||||
|
|
@ -176,13 +176,13 @@ async def edge_search(
|
|||
|
||||
|
||||
async def node_search(
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: NodeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: NodeSearchConfig | None,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
|
@ -230,12 +230,12 @@ async def node_search(
|
|||
|
||||
|
||||
async def community_search(
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: CommunitySearchConfig | None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
driver: AsyncDriver,
|
||||
embedder: EmbedderClient,
|
||||
query: str,
|
||||
group_ids: list[str] | None,
|
||||
config: CommunitySearchConfig | None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
) -> list[CommunityNode]:
|
||||
if config is None:
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
|
|||
)
|
||||
|
||||
# performs a hybrid search over edges with mmr reranking
|
||||
EDGE_HYBRID_SEARCH_mmr = SearchConfig(
|
||||
EDGE_HYBRID_SEARCH_MMR = SearchConfig(
|
||||
edge_config=EdgeSearchConfig(
|
||||
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
||||
reranker=EdgeReranker.mmr,
|
||||
|
|
@ -80,7 +80,8 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
|||
edge_config=EdgeSearchConfig(
|
||||
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
||||
reranker=EdgeReranker.node_distance,
|
||||
)
|
||||
),
|
||||
limit=30,
|
||||
)
|
||||
|
||||
# performs a hybrid search over edges with episode mention reranking
|
||||
|
|
|
|||
|
|
@ -544,7 +544,7 @@ async def node_distance_reranker(
|
|||
# rerank on shortest distance
|
||||
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||
|
||||
# add back in filtered center uuids
|
||||
# add back in filtered center uuid
|
||||
filtered_uuids = [center_node_uuid] + filtered_uuids
|
||||
|
||||
return filtered_uuids
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue