Node group error type (#185)

* add new error

* update for compatibility wit hdev environment

* update

* fix mmr score

* make mmr more readable
This commit is contained in:
Preston Rasmussen 2024-10-11 16:51:32 -04:00 committed by GitHub
parent 6c3b32e620
commit fca1f7342e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 1 deletions

View file

@ -35,6 +35,14 @@ class GroupsEdgesNotFoundError(GraphitiError):
super().__init__(self.message)
class GroupsNodesNotFoundError(GraphitiError):
"""Raised when no nodes are found for a list of group ids."""
def __init__(self, group_ids: list[str]):
self.message = f'no nodes found for group ids {group_ids}'
super().__init__(self.message)
class NodeNotFoundError(GraphitiError):
"""Raised when a node is not found."""

View file

@ -589,7 +589,7 @@ def maximal_marginal_relevance(
candidates_with_mmr: list[tuple[str, float]] = []
for candidate in candidates:
max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates])
mmr = mmr_lambda * np.dot(candidate[1], query_vector) + (1 - mmr_lambda) * max_sim
mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim
candidates_with_mmr.append((candidate[0], mmr))
candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])