Group ID filtering in BFS and full-text queries (#754)

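Enhance the search functions to filter by group ID: the BFS edge and node searches now accept a `group_ids` argument, and the episode and community full-text queries restrict results to `$group_ids`.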
This commit is contained in:
Daniel Chalef 2025-07-22 17:02:46 -07:00 committed by GitHub
parent 5bbc3cf814
commit 059a64e5e9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 44 additions and 16 deletions

View file

@ -188,7 +188,12 @@ async def edge_search(
config.sim_min_score, config.sim_min_score,
), ),
edge_bfs_search( edge_bfs_search(
driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit driver,
bfs_origin_node_uuids,
config.bfs_max_depth,
search_filter,
group_ids,
2 * limit,
), ),
] ]
) )
@ -198,7 +203,12 @@ async def edge_search(
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result] source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
search_results.append( search_results.append(
await edge_bfs_search( await edge_bfs_search(
driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit driver,
source_node_uuids,
config.bfs_max_depth,
search_filter,
group_ids,
2 * limit,
) )
) )
@ -281,7 +291,12 @@ async def node_search(
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
), ),
node_bfs_search( node_bfs_search(
driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit driver,
bfs_origin_node_uuids,
search_filter,
config.bfs_max_depth,
group_ids,
2 * limit,
), ),
] ]
) )
@ -291,7 +306,12 @@ async def node_search(
origin_node_uuids = [node.uuid for result in search_results for node in result] origin_node_uuids = [node.uuid for result in search_results for node in result]
search_results.append( search_results.append(
await node_bfs_search( await node_bfs_search(
driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit driver,
origin_node_uuids,
search_filter,
config.bfs_max_depth,
group_ids,
2 * limit,
) )
) )

View file

@ -283,7 +283,8 @@ async def edge_bfs_search(
bfs_origin_node_uuids: list[str] | None, bfs_origin_node_uuids: list[str] | None,
bfs_max_depth: int, bfs_max_depth: int,
search_filter: SearchFilters, search_filter: SearchFilters,
limit: int, group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# vector similarity search over embedded facts # vector similarity search over embedded facts
if bfs_origin_node_uuids is None: if bfs_origin_node_uuids is None:
@ -293,12 +294,13 @@ async def edge_bfs_search(
query = ( query = (
""" """
UNWIND $bfs_origin_node_uuids AS origin_uuid UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid WHERE r.uuid = rel.uuid
""" AND r.group_id IN $group_ids
"""
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT RETURN DISTINCT
@ -323,6 +325,7 @@ async def edge_bfs_search(
params=filter_params, params=filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids, bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth, depth=bfs_max_depth,
group_ids=group_ids,
limit=limit, limit=limit,
routing_='r', routing_='r',
) )
@ -431,7 +434,8 @@ async def node_bfs_search(
bfs_origin_node_uuids: list[str] | None, bfs_origin_node_uuids: list[str] | None,
search_filter: SearchFilters, search_filter: SearchFilters,
bfs_max_depth: int, bfs_max_depth: int,
limit: int, group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
# vector similarity search over entity names # vector similarity search over entity names
if bfs_origin_node_uuids is None: if bfs_origin_node_uuids is None:
@ -441,10 +445,11 @@ async def node_bfs_search(
query = ( query = (
""" """
UNWIND $bfs_origin_node_uuids AS origin_uuid UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id WHERE n.group_id = origin.group_id
""" AND origin.group_id IN $group_ids
"""
+ filter_query + filter_query
+ ENTITY_NODE_RETURN + ENTITY_NODE_RETURN
+ """ + """
@ -456,6 +461,7 @@ async def node_bfs_search(
params=filter_params, params=filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids, bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth, depth=bfs_max_depth,
group_ids=group_ids,
limit=limit, limit=limit,
routing_='r', routing_='r',
) )
@ -482,6 +488,7 @@ async def episode_fulltext_search(
YIELD node AS episode, score YIELD node AS episode, score
MATCH (e:Episodic) MATCH (e:Episodic)
WHERE e.uuid = episode.uuid WHERE e.uuid = episode.uuid
AND e.group_id IN $group_ids
RETURN RETURN
e.content AS content, e.content AS content,
e.created_at AS created_at, e.created_at AS created_at,
@ -524,6 +531,7 @@ async def community_fulltext_search(
get_nodes_query(driver.provider, 'community_name', '$query') get_nodes_query(driver.provider, 'community_name', '$query')
+ """ + """
YIELD node AS comm, score YIELD node AS comm, score
WHERE comm.group_id IN $group_ids
RETURN RETURN
comm.uuid AS uuid, comm.uuid AS uuid,
comm.group_id AS group_id, comm.group_id AS group_id,