From a85df53c7479053b8490d817b08177a81555b444 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 19 Dec 2025 11:40:56 +0100 Subject: [PATCH] chore: tweak mapping and scoring --- .../modules/graph/cognee_graph/CogneeGraph.py | 44 ++++++++++++------- .../unit/modules/graph/cognee_graph_test.py | 6 +-- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index bc29bb828..d4f07d0e6 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -241,12 +241,9 @@ class CogneeGraph(CogneeAbstractGraph): query_list_length: Optional[int] = None, ) -> None: """Map vector distances to nodes, supporting single- and multi-query input shapes.""" - if not node_distances: - return None query_count = query_list_length or 1 - # Reset all node distances for this search self.reset_distances(self.nodes.values(), query_count) for collection_name, scored_results in node_distances.items(): @@ -279,19 +276,17 @@ class CogneeGraph(CogneeAbstractGraph): query_list_length: Optional[int] = None, ) -> None: """Map vector distances to graph edges, supporting single- and multi-query input shapes.""" - if not edge_distances: - return None - query_count = query_list_length or 1 - # Reset all edge distances for this search self.reset_distances(self.edges, query_count) + if not edge_distances: + return None + per_query_edge_lists = self._normalize_query_distance_lists( edge_distances, query_list_length, "edge_distances" ) - # For each query, apply distances to all matching edges for query_index, scored_list in enumerate(per_query_edge_lists): for result in scored_list: payload = getattr(result, "payload", None) @@ -318,13 +313,32 @@ class CogneeGraph(CogneeAbstractGraph): ) -> List[Edge]: """Calculate top k triplet importances for a specific query index.""" - def score(edge): - distances = [ - edge.node1.attributes.get("vector_distance"), - edge.node2.attributes.get("vector_distance"), - edge.attributes.get("vector_distance"), - ] - return sum(float(d[query_index]) for d in distances) + def score(edge: Edge) -> float: + elements = ( + (edge.node1, f"node {edge.node1.id}"), + (edge.node2, f"node {edge.node2.id}"), + (edge, f"edge {edge.node1.id}->{edge.node2.id}"), + ) + + importances = [] + for element, label in elements: + distances = element.attributes.get("vector_distance") + if not isinstance(distances, list) or query_index >= len(distances): + raise ValueError( + f"{label}: vector_distance must be a list with length > {query_index} " + f"before scoring (got {type(distances).__name__} with length " + f"{len(distances) if isinstance(distances, list) else 'n/a'})" + ) + value = distances[query_index] + try: + importances.append(float(value)) + except (TypeError, ValueError): + raise ValueError( + f"{label}: vector_distance[{query_index}] must be float-like, " + f"got {type(value).__name__}" + ) + + return sum(importances) return heapq.nsmallest(k, self.edges, key=score) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index e4ff0251e..0458fca76 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -614,7 +614,7 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph): # When no distances are set, calculate_top_triplet_importances should handle None # by either raising an error or skipping edges with None distances - with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"): + with pytest.raises(ValueError): await graph.calculate_top_triplet_importances(k=1) @@ -695,7 +695,7 @@ async def test_calculate_top_triplet_importances_raises_on_short_list(setup_grap edge.add_attribute("vector_distance", [0.3]) graph.add_edge(edge) - with pytest.raises(IndexError): + with pytest.raises(ValueError): await graph.calculate_top_triplet_importances(k=1, query_list_length=2) @@ -716,5 +716,5 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set del edge.attributes["vector_distance"] graph.add_edge(edge) - with pytest.raises((KeyError, TypeError)): + with pytest.raises(ValueError): await graph.calculate_top_triplet_importances(k=1, query_list_length=1)