From fca1f7342e82c2e9595fc7d21e267271f0ad649a Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Fri, 11 Oct 2024 16:51:32 -0400 Subject: [PATCH] Node group error type (#185) * add new error * update for compatibility wit hdev environment * update * fix mmr score * make mmr more readable --- graphiti_core/errors.py | 8 ++++++++ graphiti_core/search/search_utils.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/graphiti_core/errors.py b/graphiti_core/errors.py index 58a8ee3b..cdf74e95 100644 --- a/graphiti_core/errors.py +++ b/graphiti_core/errors.py @@ -35,6 +35,14 @@ class GroupsEdgesNotFoundError(GraphitiError): super().__init__(self.message) +class GroupsNodesNotFoundError(GraphitiError): + """Raised when no nodes are found for a list of group ids.""" + + def __init__(self, group_ids: list[str]): + self.message = f'no nodes found for group ids {group_ids}' + super().__init__(self.message) + + class NodeNotFoundError(GraphitiError): """Raised when a node is not found.""" diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index d818e090..c75dabd8 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -589,7 +589,7 @@ def maximal_marginal_relevance( candidates_with_mmr: list[tuple[str, float]] = [] for candidate in candidates: max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates]) - mmr = mmr_lambda * np.dot(candidate[1], query_vector) + (1 - mmr_lambda) * max_sim + mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim candidates_with_mmr.append((candidate[0], mmr)) candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])