diff --git a/graphiti_core/errors.py b/graphiti_core/errors.py index 58a8ee3b..cdf74e95 100644 --- a/graphiti_core/errors.py +++ b/graphiti_core/errors.py @@ -35,6 +35,14 @@ class GroupsEdgesNotFoundError(GraphitiError): super().__init__(self.message) +class GroupsNodesNotFoundError(GraphitiError): + """Raised when no nodes are found for a list of group ids.""" + + def __init__(self, group_ids: list[str]): + self.message = f'no nodes found for group ids {group_ids}' + super().__init__(self.message) + + class NodeNotFoundError(GraphitiError): """Raised when a node is not found.""" diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index d818e090..c75dabd8 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -589,7 +589,7 @@ def maximal_marginal_relevance( candidates_with_mmr: list[tuple[str, float]] = [] for candidate in candidates: max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates]) - mmr = mmr_lambda * np.dot(candidate[1], query_vector) + (1 - mmr_lambda) * max_sim + mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim candidates_with_mmr.append((candidate[0], mmr)) candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])