diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 47d72343..fd87234e 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -959,7 +959,7 @@ class Graphiti: nodes = await get_mentioned_nodes(self.driver, episodes) - return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[]) + return SearchResults(edges=edges, nodes=nodes) async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode): if source_node.name_embedding is None: diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 394fd528..1458def7 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -80,12 +80,7 @@ async def search( cross_encoder = clients.cross_encoder if query.strip() == '': - return SearchResults( - edges=[], - nodes=[], - episodes=[], - communities=[], - ) + return SearchResults() query_vector = ( query_vector if query_vector is not None @@ -94,7 +89,12 @@ async def search( # if group_ids is empty, set it to None group_ids = group_ids if group_ids and group_ids != [''] else None - edges, nodes, episodes, communities = await semaphore_gather( + ( + (edges, edge_reranker_scores), + (nodes, node_reranker_scores), + (episodes, episode_reranker_scores), + (communities, community_reranker_scores), + ) = await semaphore_gather( edge_search( driver, cross_encoder, @@ -146,9 +146,13 @@ async def search( results = SearchResults( edges=edges, + edge_reranker_scores=edge_reranker_scores, nodes=nodes, + node_reranker_scores=node_reranker_scores, episodes=episodes, + episode_reranker_scores=episode_reranker_scores, communities=communities, + community_reranker_scores=community_reranker_scores, ) latency = (time() - start) * 1000 @@ -170,9 +174,9 @@ async def edge_search( bfs_origin_node_uuids: list[str] | None = None, limit=DEFAULT_SEARCH_LIMIT, reranker_min_score: float = 0, -) -> list[EntityEdge]: +) -> tuple[list[EntityEdge], list[float]]: if config is None: - return [] + return [], [] search_results: list[list[EntityEdge]] = list( await semaphore_gather( *[ @@ -215,15 +219,16 @@ async def edge_search( edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result} reranked_uuids: list[str] = [] + edge_scores: list[float] = [] if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions: search_result_uuids = [[edge.uuid for edge in result] for result in search_results] - reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) + reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score) elif config.reranker == EdgeReranker.mmr: search_result_uuids_and_vectors = await get_embeddings_for_edges( driver, list(edge_uuid_map.values()) ) - reranked_uuids = maximal_marginal_relevance( + reranked_uuids, edge_scores = maximal_marginal_relevance( query_vector, search_result_uuids_and_vectors, config.mmr_lambda, @@ -235,12 +240,13 @@ async def edge_search( reranked_uuids = [ fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score ] + edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score] elif config.reranker == EdgeReranker.node_distance: if center_node_uuid is None: raise SearchRerankerError('No center node provided for Node Distance reranker') # use rrf as a preliminary sort - sorted_result_uuids = rrf( + sorted_result_uuids, node_scores = rrf( [[edge.uuid for edge in result] for result in search_results], min_score=reranker_min_score, ) @@ -253,7 +259,7 @@ async def edge_search( source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map] - reranked_node_uuids = await node_distance_reranker( + reranked_node_uuids, edge_scores = await node_distance_reranker( driver, source_uuids, center_node_uuid, min_score=reranker_min_score ) @@ -265,7 +271,7 @@ async def edge_search( if config.reranker == EdgeReranker.episode_mentions: reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes)) - return reranked_edges[:limit] + return reranked_edges[:limit], edge_scores[:limit] async def node_search( @@ -280,9 +286,9 @@ async def node_search( bfs_origin_node_uuids: list[str] | None = None, limit=DEFAULT_SEARCH_LIMIT, reranker_min_score: float = 0, -) -> list[EntityNode]: +) -> tuple[list[EntityNode], list[float]]: if config is None: - return [] + return [], [] search_results: list[list[EntityNode]] = list( await semaphore_gather( *[ @@ -319,14 +325,15 @@ async def node_search( node_uuid_map = {node.uuid: node for result in search_results for node in result} reranked_uuids: list[str] = [] + node_scores: list[float] = [] if config.reranker == NodeReranker.rrf: - reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) + reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score) elif config.reranker == NodeReranker.mmr: search_result_uuids_and_vectors = await get_embeddings_for_nodes( driver, list(node_uuid_map.values()) ) - reranked_uuids = maximal_marginal_relevance( + reranked_uuids, node_scores = maximal_marginal_relevance( query_vector, search_result_uuids_and_vectors, config.mmr_lambda, @@ -341,23 +348,24 @@ async def node_search( for name, score in reranked_node_names if score >= reranker_min_score ] + node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score] elif config.reranker == NodeReranker.episode_mentions: - reranked_uuids = await episode_mentions_reranker( + reranked_uuids, node_scores = await episode_mentions_reranker( driver, search_result_uuids, min_score=reranker_min_score ) elif config.reranker == NodeReranker.node_distance: if center_node_uuid is None: raise SearchRerankerError('No center node provided for Node Distance reranker') - reranked_uuids = await node_distance_reranker( + reranked_uuids, node_scores = await node_distance_reranker( driver, - rrf(search_result_uuids, min_score=reranker_min_score), + rrf(search_result_uuids, min_score=reranker_min_score)[0], center_node_uuid, min_score=reranker_min_score, ) reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids] - return reranked_nodes[:limit] + return reranked_nodes[:limit], node_scores[:limit] async def episode_search( @@ -370,9 +378,9 @@ async def episode_search( search_filter: SearchFilters, limit=DEFAULT_SEARCH_LIMIT, reranker_min_score: float = 0, -) -> list[EpisodicNode]: +) -> tuple[list[EpisodicNode], list[float]]: if config is None: - return [] + return [], [] search_results: list[list[EpisodicNode]] = list( await semaphore_gather( *[ @@ -385,12 +393,13 @@ async def episode_search( episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result} reranked_uuids: list[str] = [] + episode_scores: list[float] = [] if config.reranker == EpisodeReranker.rrf: - reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) + reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score) elif config.reranker == EpisodeReranker.cross_encoder: # use rrf as a preliminary reranker - rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score) + rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score) rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit] content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results} @@ -401,10 +410,11 @@ async def episode_search( for content, score in reranked_contents if score >= reranker_min_score ] + episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score] reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids] - return reranked_episodes[:limit] + return reranked_episodes[:limit], episode_scores[:limit] async def community_search( @@ -416,9 +426,9 @@ async def community_search( config: CommunitySearchConfig | None, limit=DEFAULT_SEARCH_LIMIT, reranker_min_score: float = 0, -) -> list[CommunityNode]: +) -> tuple[list[CommunityNode], list[float]]: if config is None: - return [] + return [], [] search_results: list[list[CommunityNode]] = list( await semaphore_gather( @@ -437,14 +447,15 @@ async def community_search( } reranked_uuids: list[str] = [] + community_scores: list[float] = [] if config.reranker == CommunityReranker.rrf: - reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) + reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score) elif config.reranker == CommunityReranker.mmr: search_result_uuids_and_vectors = await get_embeddings_for_communities( driver, list(community_uuid_map.values()) ) - reranked_uuids = maximal_marginal_relevance( + reranked_uuids, community_scores = maximal_marginal_relevance( query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score ) elif config.reranker == CommunityReranker.cross_encoder: @@ -453,7 +464,8 @@ async def community_search( reranked_uuids = [ name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score ] + community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score] reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids] - return reranked_communities[:limit] + return reranked_communities[:limit], community_scores[:limit] diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py index 63b1a114..f24a3f3e 100644 --- a/graphiti_core/search/search_config.py +++ b/graphiti_core/search/search_config.py @@ -119,7 +119,11 @@ class SearchConfig(BaseModel): class SearchResults(BaseModel): - edges: list[EntityEdge] - nodes: list[EntityNode] - episodes: list[EpisodicNode] - communities: list[CommunityNode] + edges: list[EntityEdge] = Field(default_factory=list) + edge_reranker_scores: list[float] = Field(default_factory=list) + nodes: list[EntityNode] = Field(default_factory=list) + node_reranker_scores: list[float] = Field(default_factory=list) + episodes: list[EpisodicNode] = Field(default_factory=list) + episode_reranker_scores: list[float] = Field(default_factory=list) + communities: list[CommunityNode] = Field(default_factory=list) + community_reranker_scores: list[float] = Field(default_factory=list) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index f689dd86..5d6828f7 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -294,13 +294,13 @@ async def edge_bfs_search( query = ( """ - UNWIND $bfs_origin_node_uuids AS origin_uuid - MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - UNWIND relationships(path) AS rel - MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) - WHERE r.uuid = rel.uuid - AND r.group_id IN $group_ids - """ + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) + UNWIND relationships(path) AS rel + MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) + WHERE r.uuid = rel.uuid + AND r.group_id IN $group_ids + """ + filter_query + """ RETURN DISTINCT @@ -445,11 +445,11 @@ async def node_bfs_search( query = ( """ - UNWIND $bfs_origin_node_uuids AS origin_uuid - MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - WHERE n.group_id = origin.group_id - AND origin.group_id IN $group_ids - """ + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) + WHERE n.group_id = origin.group_id + AND origin.group_id IN $group_ids + """ + filter_query + ENTITY_NODE_RETURN + """ @@ -672,7 +672,7 @@ async def hybrid_node_search( } result_uuids = [[node.uuid for node in result] for result in results] - ranked_uuids = rrf(result_uuids) + ranked_uuids, _ = rrf(result_uuids) relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids] @@ -914,7 +914,9 @@ async def get_edge_invalidation_candidates( # takes in a list of rankings of uuids -def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]: +def rrf( + results: list[list[str]], rank_const=1, min_score: float = 0 +) -> tuple[list[str], list[float]]: scores: dict[str, float] = defaultdict(float) for result in results: for i, uuid in enumerate(result): @@ -925,7 +927,9 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st sorted_uuids = [term[0] for term in scored_uuids] - return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score] + return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [ + scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score + ] async def node_distance_reranker( @@ -933,7 +937,7 @@ async def node_distance_reranker( node_uuids: list[str], center_node_uuid: str, min_score: float = 0, -) -> list[str]: +) -> tuple[list[str], list[float]]: # filter out node_uuid center node node uuid filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids)) scores: dict[str, float] = {center_node_uuid: 0.0} @@ -970,14 +974,16 @@ async def node_distance_reranker( scores[center_node_uuid] = 0.1 filtered_uuids = [center_node_uuid] + filtered_uuids - return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score] + return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [ + 1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score + ] async def episode_mentions_reranker( driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0 -) -> list[str]: +) -> tuple[list[str], list[float]]: # use rrf as a preliminary ranker - sorted_uuids = rrf(node_uuids) + sorted_uuids, _ = rrf(node_uuids) scores: dict[str, float] = {} # Find the shortest path to center node @@ -998,7 +1004,9 @@ async def episode_mentions_reranker( # rerank on shortest distance sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) - return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score] + return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [ + scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score + ] def maximal_marginal_relevance( @@ -1006,7 +1014,7 @@ def maximal_marginal_relevance( candidates: dict[str, list[float]], mmr_lambda: float = DEFAULT_MMR_LAMBDA, min_score: float = -2.0, -) -> list[str]: +) -> tuple[list[str], list[float]]: start = time() query_array = np.array(query_vector) candidate_arrays: dict[str, NDArray] = {} @@ -1037,7 +1045,9 @@ def maximal_marginal_relevance( end = time() logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms') - return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score] + return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [ + mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score + ] async def get_embeddings_for_nodes( diff --git a/pyproject.toml b/pyproject.toml index af6661f7..d2e49acc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.17.11" +version = "0.18.0" authors = [ { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" }, diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index a9007b03..acec74cb 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -71,7 +71,7 @@ async def test_graphiti_init(): ) results = await graphiti.search_( - query='What is the hall of portrait?', + query='Who is Tania', search_filter=search_filter, ) diff --git a/uv.lock b/uv.lock index d19c1e74..832bec3b 100644 --- a/uv.lock +++ b/uv.lock @@ -746,7 +746,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.17.10" +version = "0.18.0" source = { editable = "." } dependencies = [ { name = "diskcache" },