diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index f68df779..394fd528 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -188,7 +188,12 @@ async def edge_search( config.sim_min_score, ), edge_bfs_search( - driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit + driver, + bfs_origin_node_uuids, + config.bfs_max_depth, + search_filter, + group_ids, + 2 * limit, ), ] ) @@ -198,7 +203,12 @@ async def edge_search( source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result] search_results.append( await edge_bfs_search( - driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit + driver, + source_node_uuids, + config.bfs_max_depth, + search_filter, + group_ids, + 2 * limit, ) ) @@ -281,7 +291,12 @@ async def node_search( driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score ), node_bfs_search( - driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit + driver, + bfs_origin_node_uuids, + search_filter, + config.bfs_max_depth, + group_ids, + 2 * limit, ), ] ) @@ -291,7 +306,12 @@ async def node_search( origin_node_uuids = [node.uuid for result in search_results for node in result] search_results.append( await node_bfs_search( - driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit + driver, + origin_node_uuids, + search_filter, + config.bfs_max_depth, + group_ids, + 2 * limit, ) ) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index cc23efc1..f689dd86 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -283,7 +283,8 @@ async def edge_bfs_search( bfs_origin_node_uuids: list[str] | None, bfs_max_depth: int, search_filter: SearchFilters, - limit: int, + group_ids: list[str] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: # vector similarity search over embedded facts if bfs_origin_node_uuids is None: @@ -293,12 +294,13 @@ async def edge_bfs_search( query = ( """ - UNWIND $bfs_origin_node_uuids AS origin_uuid - MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - UNWIND relationships(path) AS rel - MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) - WHERE r.uuid = rel.uuid - """ + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) + UNWIND relationships(path) AS rel + MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) + WHERE r.uuid = rel.uuid + AND r.group_id IN $group_ids + """ + filter_query + """ RETURN DISTINCT @@ -323,6 +325,7 @@ async def edge_bfs_search( params=filter_params, bfs_origin_node_uuids=bfs_origin_node_uuids, depth=bfs_max_depth, + group_ids=group_ids, limit=limit, routing_='r', ) @@ -431,7 +434,8 @@ async def node_bfs_search( bfs_origin_node_uuids: list[str] | None, search_filter: SearchFilters, bfs_max_depth: int, - limit: int, + group_ids: list[str] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: # vector similarity search over entity names if bfs_origin_node_uuids is None: @@ -441,10 +445,11 @@ async def node_bfs_search( query = ( """ - UNWIND $bfs_origin_node_uuids AS origin_uuid - MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - WHERE n.group_id = origin.group_id - """ + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) + WHERE n.group_id = origin.group_id + AND origin.group_id IN $group_ids + """ + filter_query + ENTITY_NODE_RETURN + """ @@ -456,6 +461,7 @@ async def node_bfs_search( params=filter_params, bfs_origin_node_uuids=bfs_origin_node_uuids, depth=bfs_max_depth, + group_ids=group_ids, limit=limit, routing_='r', ) @@ -482,6 +488,7 @@ async def episode_fulltext_search( YIELD node AS episode, score MATCH (e:Episodic) WHERE e.uuid = episode.uuid + AND e.group_id IN $group_ids RETURN e.content AS content, e.created_at AS created_at, @@ -524,6 +531,7 @@ async def community_fulltext_search( get_nodes_query(driver.provider, 'community_name', '$query') + """ YIELD node AS comm, score + WHERE comm.group_id IN $group_ids RETURN comm.uuid AS uuid, comm.group_id AS group_id,