Group ID filtering in BFS and full-text queries (#754)

Enhance search functions to include group ID filtering in BFS and full-text queries
This commit is contained in:
Daniel Chalef 2025-07-22 17:02:46 -07:00 committed by GitHub
parent 5bbc3cf814
commit 059a64e5e9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 44 additions and 16 deletions

View file

@ -188,7 +188,12 @@ async def edge_search(
config.sim_min_score,
),
edge_bfs_search(
driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
driver,
bfs_origin_node_uuids,
config.bfs_max_depth,
search_filter,
group_ids,
2 * limit,
),
]
)
@ -198,7 +203,12 @@ async def edge_search(
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
search_results.append(
await edge_bfs_search(
driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
driver,
source_node_uuids,
config.bfs_max_depth,
search_filter,
group_ids,
2 * limit,
)
)
@ -281,7 +291,12 @@ async def node_search(
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
),
node_bfs_search(
driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
driver,
bfs_origin_node_uuids,
search_filter,
config.bfs_max_depth,
group_ids,
2 * limit,
),
]
)
@ -291,7 +306,12 @@ async def node_search(
origin_node_uuids = [node.uuid for result in search_results for node in result]
search_results.append(
await node_bfs_search(
driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
driver,
origin_node_uuids,
search_filter,
config.bfs_max_depth,
group_ids,
2 * limit,
)
)

View file

@ -283,7 +283,8 @@ async def edge_bfs_search(
bfs_origin_node_uuids: list[str] | None,
bfs_max_depth: int,
search_filter: SearchFilters,
limit: int,
group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# breadth-first traversal from the origin nodes, collecting the RELATES_TO edges along each path
if bfs_origin_node_uuids is None:
@ -293,12 +294,13 @@ async def edge_bfs_search(
query = (
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid
AND r.group_id IN $group_ids
"""
+ filter_query
+ """
RETURN DISTINCT
@ -323,6 +325,7 @@ async def edge_bfs_search(
params=filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
group_ids=group_ids,
limit=limit,
routing_='r',
)
@ -431,7 +434,8 @@ async def node_bfs_search(
bfs_origin_node_uuids: list[str] | None,
search_filter: SearchFilters,
bfs_max_depth: int,
limit: int,
group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
# breadth-first traversal from the origin nodes, collecting the Entity nodes reachable from them
if bfs_origin_node_uuids is None:
@ -441,10 +445,11 @@ async def node_bfs_search(
query = (
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id
AND origin.group_id IN $group_ids
"""
+ filter_query
+ ENTITY_NODE_RETURN
+ """
@ -456,6 +461,7 @@ async def node_bfs_search(
params=filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
group_ids=group_ids,
limit=limit,
routing_='r',
)
@ -482,6 +488,7 @@ async def episode_fulltext_search(
YIELD node AS episode, score
MATCH (e:Episodic)
WHERE e.uuid = episode.uuid
AND e.group_id IN $group_ids
RETURN
e.content AS content,
e.created_at AS created_at,
@ -524,6 +531,7 @@ async def community_fulltext_search(
get_nodes_query(driver.provider, 'community_name', '$query')
+ """
YIELD node AS comm, score
WHERE comm.group_id IN $group_ids
RETURN
comm.uuid AS uuid,
comm.group_id AS group_id,