Group ID filtering in BFS and full-text queries (#754)
Enhance search functions to include group ID filtering in BFS and full-text queries
This commit is contained in:
parent
5bbc3cf814
commit
059a64e5e9
2 changed files with 44 additions and 16 deletions
|
|
@ -188,7 +188,12 @@ async def edge_search(
|
||||||
config.sim_min_score,
|
config.sim_min_score,
|
||||||
),
|
),
|
||||||
edge_bfs_search(
|
edge_bfs_search(
|
||||||
driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
|
driver,
|
||||||
|
bfs_origin_node_uuids,
|
||||||
|
config.bfs_max_depth,
|
||||||
|
search_filter,
|
||||||
|
group_ids,
|
||||||
|
2 * limit,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -198,7 +203,12 @@ async def edge_search(
|
||||||
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
|
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
|
||||||
search_results.append(
|
search_results.append(
|
||||||
await edge_bfs_search(
|
await edge_bfs_search(
|
||||||
driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
|
driver,
|
||||||
|
source_node_uuids,
|
||||||
|
config.bfs_max_depth,
|
||||||
|
search_filter,
|
||||||
|
group_ids,
|
||||||
|
2 * limit,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -281,7 +291,12 @@ async def node_search(
|
||||||
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
|
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
|
||||||
),
|
),
|
||||||
node_bfs_search(
|
node_bfs_search(
|
||||||
driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
|
driver,
|
||||||
|
bfs_origin_node_uuids,
|
||||||
|
search_filter,
|
||||||
|
config.bfs_max_depth,
|
||||||
|
group_ids,
|
||||||
|
2 * limit,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -291,7 +306,12 @@ async def node_search(
|
||||||
origin_node_uuids = [node.uuid for result in search_results for node in result]
|
origin_node_uuids = [node.uuid for result in search_results for node in result]
|
||||||
search_results.append(
|
search_results.append(
|
||||||
await node_bfs_search(
|
await node_bfs_search(
|
||||||
driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
|
driver,
|
||||||
|
origin_node_uuids,
|
||||||
|
search_filter,
|
||||||
|
config.bfs_max_depth,
|
||||||
|
group_ids,
|
||||||
|
2 * limit,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -283,7 +283,8 @@ async def edge_bfs_search(
|
||||||
bfs_origin_node_uuids: list[str] | None,
|
bfs_origin_node_uuids: list[str] | None,
|
||||||
bfs_max_depth: int,
|
bfs_max_depth: int,
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
limit: int,
|
group_ids: list[str] | None = None,
|
||||||
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
if bfs_origin_node_uuids is None:
|
if bfs_origin_node_uuids is None:
|
||||||
|
|
@ -293,12 +294,13 @@ async def edge_bfs_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||||
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||||
UNWIND relationships(path) AS rel
|
UNWIND relationships(path) AS rel
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
||||||
WHERE r.uuid = rel.uuid
|
WHERE r.uuid = rel.uuid
|
||||||
"""
|
AND r.group_id IN $group_ids
|
||||||
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
|
|
@ -323,6 +325,7 @@ async def edge_bfs_search(
|
||||||
params=filter_params,
|
params=filter_params,
|
||||||
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||||
depth=bfs_max_depth,
|
depth=bfs_max_depth,
|
||||||
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -431,7 +434,8 @@ async def node_bfs_search(
|
||||||
bfs_origin_node_uuids: list[str] | None,
|
bfs_origin_node_uuids: list[str] | None,
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
bfs_max_depth: int,
|
bfs_max_depth: int,
|
||||||
limit: int,
|
group_ids: list[str] | None = None,
|
||||||
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
if bfs_origin_node_uuids is None:
|
if bfs_origin_node_uuids is None:
|
||||||
|
|
@ -441,10 +445,11 @@ async def node_bfs_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||||
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||||
WHERE n.group_id = origin.group_id
|
WHERE n.group_id = origin.group_id
|
||||||
"""
|
AND origin.group_id IN $group_ids
|
||||||
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
|
|
@ -456,6 +461,7 @@ async def node_bfs_search(
|
||||||
params=filter_params,
|
params=filter_params,
|
||||||
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||||
depth=bfs_max_depth,
|
depth=bfs_max_depth,
|
||||||
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -482,6 +488,7 @@ async def episode_fulltext_search(
|
||||||
YIELD node AS episode, score
|
YIELD node AS episode, score
|
||||||
MATCH (e:Episodic)
|
MATCH (e:Episodic)
|
||||||
WHERE e.uuid = episode.uuid
|
WHERE e.uuid = episode.uuid
|
||||||
|
AND e.group_id IN $group_ids
|
||||||
RETURN
|
RETURN
|
||||||
e.content AS content,
|
e.content AS content,
|
||||||
e.created_at AS created_at,
|
e.created_at AS created_at,
|
||||||
|
|
@ -524,6 +531,7 @@ async def community_fulltext_search(
|
||||||
get_nodes_query(driver.provider, 'community_name', '$query')
|
get_nodes_query(driver.provider, 'community_name', '$query')
|
||||||
+ """
|
+ """
|
||||||
YIELD node AS comm, score
|
YIELD node AS comm, score
|
||||||
|
WHERE comm.group_id IN $group_ids
|
||||||
RETURN
|
RETURN
|
||||||
comm.uuid AS uuid,
|
comm.uuid AS uuid,
|
||||||
comm.group_id AS group_id,
|
comm.group_id AS group_id,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue