From cc7ca45e7315300fc775509a8b3784cfae2ed99a Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:48:24 +0100 Subject: [PATCH 1/8] feat: make vector_distance list based --- .../modules/graph/cognee_graph/CogneeGraph.py | 20 ++++++++++++++++--- .../graph/cognee_graph/CogneeGraphElements.py | 4 ++-- .../graph/cognee_graph_elements_test.py | 4 ++-- .../unit/modules/graph/cognee_graph_test.py | 4 ++++ 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 6233c245f..dd05c8c4f 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -26,11 +26,13 @@ class CogneeGraph(CogneeAbstractGraph): nodes: Dict[str, Node] edges: List[Edge] directed: bool + triplet_distance_penalty: float def __init__(self, directed: bool = True): self.nodes = {} self.edges = [] self.directed = directed + self.triplet_distance_penalty = 3.5 def add_node(self, node: Node) -> None: if node.id not in self.nodes: @@ -148,6 +150,8 @@ class CogneeGraph(CogneeAbstractGraph): adapter, memory_fragment_filter ) + self.triplet_distance_penalty = triplet_distance_penalty + import time start_time = time.time() @@ -230,11 +234,21 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error mapping vector distances to edges: {str(ex)}") raise ex + def _as_distance(self, value: Union[float, List[float], None]) -> float: + """Normalize distance value to float, handling None, lists, and scalars.""" + if value is None: + return self.triplet_distance_penalty + if isinstance(value, list) and value: + return float(value[0]) + if isinstance(value, (int, float)): + return float(value) + return self.triplet_distance_penalty + async def calculate_top_triplet_importances(self, k: int) -> List[Edge]: def score(edge): - n1 = edge.node1.attributes.get("vector_distance", 1) - n2 = edge.node2.attributes.get("vector_distance", 1) - e = edge.attributes.get("vector_distance", 1) + n1 = self._as_distance(edge.node1.attributes.get("vector_distance")) + n2 = self._as_distance(edge.node2.attributes.get("vector_distance")) + e = self._as_distance(edge.attributes.get("vector_distance")) return n1 + n2 + e return heapq.nsmallest(k, self.edges, key=score) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index 62ef8d9fd..5d8e0df34 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -30,7 +30,7 @@ class Node: raise InvalidDimensionsError() self.id = node_id self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = node_penalty + self.attributes["vector_distance"] = None self.skeleton_neighbours = [] self.skeleton_edges = [] self.status = np.ones(dimension, dtype=int) @@ -116,7 +116,7 @@ class Edge: self.node1 = node1 self.node2 = node2 self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = edge_penalty + self.attributes["vector_distance"] = None self.directed = directed self.status = np.ones(dimension, dtype=int) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py index 1d2b79cf9..e59888525 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py @@ -9,7 +9,7 @@ def test_node_initialization(): """Test that a Node is initialized correctly.""" node = Node("node1", {"attr1": "value1"}, dimension=2) assert node.id == "node1" - assert node.attributes == {"attr1": "value1", "vector_distance": 3.5} + assert node.attributes == {"attr1": "value1", "vector_distance": None} assert len(node.status) == 2 assert np.all(node.status == 1) @@ -96,7 +96,7 @@ def test_edge_initialization(): edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2) assert edge.node1 == node1 assert edge.node2 == node2 - assert edge.attributes == {"vector_distance": 3.5, "weight": 10} + assert edge.attributes == {"vector_distance": None, "weight": 10} assert edge.directed is False assert len(edge.status) == 2 assert np.all(edge.status == 1) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index edbd8ef9d..d30167262 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -246,6 +246,7 @@ async def test_map_vector_distances_to_graph_nodes(setup_graph): @pytest.mark.asyncio +@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_partial_node_coverage(setup_graph): """Test mapping vector distances when only some nodes have results.""" graph = setup_graph @@ -272,6 +273,7 @@ async def test_map_vector_distances_partial_node_coverage(setup_graph): @pytest.mark.asyncio +@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_multiple_categories(setup_graph): """Test mapping vector distances from multiple collection categories.""" graph = setup_graph @@ -331,6 +333,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph): @pytest.mark.asyncio +@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_partial_edge_coverage(setup_graph): """Test mapping edge distances when only some edges have results.""" graph = setup_graph @@ -384,6 +387,7 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr @pytest.mark.asyncio +@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_no_edge_matches(setup_graph): """Test edge mapping when no edges match the distance results.""" graph = setup_graph From 69ab8e7edee9e422b33059a55f0d711fd5e4cde6 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Wed, 17 Dec 2025 18:14:57 +0100 Subject: [PATCH 2/8] feat: add multi-query support to graph distance mapping --- .../modules/graph/cognee_graph/CogneeGraph.py | 114 +++++++++++++++--- cognee/tests/test_search_db.py | 29 ++++- .../unit/modules/graph/cognee_graph_test.py | 90 +++++++++++--- 3 files changed, 190 insertions(+), 43 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index dd05c8c4f..4a586e488 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -1,6 +1,6 @@ import time from cognee.shared.logging_utils import get_logger -from typing import List, Dict, Union, Optional, Type +from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any from cognee.modules.graph.exceptions import ( EntityNotFoundError, @@ -204,31 +204,105 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error during graph projection: {str(e)}") raise - async def map_vector_distances_to_graph_nodes(self, node_distances) -> None: - mapped_nodes = 0 - for category, scored_results in node_distances.items(): - for scored_result in scored_results: - node_id = str(scored_result.id) - score = scored_result.score - node = self.get_node(node_id) - if node: - node.add_attribute("vector_distance", score) - mapped_nodes += 1 + def _initialize_vector_distance(self, graph_elements, query_list_length=None) -> None: + """Initialize vector_distance as a list of default penalties for all graph elements.""" + query_count = query_list_length or 1 + for element in graph_elements: + element.attributes["vector_distance"] = [self.triplet_distance_penalty] * query_count - async def map_vector_distances_to_graph_edges(self, edge_distances) -> None: + def _normalize_query_input(self, distance_data, query_list_length=None, name="input"): + """Normalize single-query or multi-query input to list of lists, return empty list if empty.""" + if not distance_data: + return [] + normalized = ( + distance_data if isinstance(distance_data[0], (list, tuple)) else [distance_data] + ) + if query_list_length is not None and len(normalized) != query_list_length: + raise ValueError( + f"{name} has {len(normalized)} query lists, but query_list_length is {query_list_length}" + ) + return normalized + + def _apply_vector_distance_updates( + self, + element_distances, + query_index: int, + get_element: Callable[[str], Optional[Union[Node, Edge]]], + get_id_and_score: Callable[[Any], Tuple[Optional[str], Optional[float]]], + ) -> None: + """Apply updates into element.attributes["vector_distance"][query_index].""" + for res in element_distances: + key, score = get_id_and_score(res) + if key is None or score is None: + continue + element = get_element(key) + if element is None: + continue + element.attributes["vector_distance"][query_index] = score + + def _get_node_id_and_score(self, res: Any) -> Tuple[str, float]: + """Extract node ID and score from a scored result.""" + return str(res.id), float(res.score) + + def _get_edge_id_and_score(self, res: Any) -> Tuple[Optional[str], Optional[float]]: + """Extract edge key and score from a scored result.""" + payload = getattr(res, "payload", None) + if not payload: + return None, None + text = payload.get("text") + if text is None: + return None, None + return str(text), float(res.score) + + async def map_vector_distances_to_graph_nodes( + self, + node_distances, + query_list_length: Optional[int] = None, + ) -> None: + self._initialize_vector_distance(self.nodes.values(), query_list_length) + + for collection_name, scored_results in node_distances.items(): + per_query_lists = self._normalize_query_input( + scored_results, query_list_length, f"Collection '{collection_name}'" + ) + if not per_query_lists: + continue + + for query_index, scored_list in enumerate(per_query_lists): + self._apply_vector_distance_updates( + element_distances=scored_list, + query_index=query_index, + get_element=self.nodes.get, + get_id_and_score=self._get_node_id_and_score, + ) + + async def map_vector_distances_to_graph_edges( + self, + edge_distances, + query_list_length: Optional[int] = None, + ) -> None: try: - if edge_distances is None: + self._initialize_vector_distance(self.edges, query_list_length) + + normalized_edges = self._normalize_query_input( + edge_distances, query_list_length, "edge_distances" + ) + if not normalized_edges: return - embedding_map = {result.payload["text"]: result.score for result in edge_distances} - + edges_by_key: Dict[str, Edge] = {} for edge in self.edges: - edge_key = edge.attributes.get("edge_text") or edge.attributes.get( - "relationship_type" + key = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type") + if key: + edges_by_key[str(key)] = edge + + for query_index, scored_list in enumerate(normalized_edges): + self._apply_vector_distance_updates( + element_distances=scored_list, + query_index=query_index, + get_element=edges_by_key.get, + get_id_and_score=self._get_edge_id_and_score, ) - distance = embedding_map.get(edge_key, None) - if distance is not None: - edge.attributes["vector_distance"] = distance except Exception as ex: logger.error(f"Error mapping vector distances to edges: {str(ex)}") diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 0916be322..d0b78dfcc 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -350,11 +350,32 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state): assert triplets, f"{name}: Triplets list should not be empty" for edge in triplets: assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" - distance = edge.attributes.get("vector_distance") - node1_distance = edge.node1.attributes.get("vector_distance") - node2_distance = edge.node2.attributes.get("vector_distance") - assert isinstance(distance, float), f"{name}: vector_distance should be float" + vector_distances = edge.attributes.get("vector_distance", []) + assert isinstance(vector_distances, list) and vector_distances, ( + f"{name}: vector_distance should be a non-empty list" + ) + distance = vector_distances[0] + assert isinstance(distance, float), ( + f"{name}: vector_distance[0] should be float, got {type(distance)}" + ) assert 0 <= distance <= 1 + + node1_distances = edge.node1.attributes.get("vector_distance", []) + node2_distances = edge.node2.attributes.get("vector_distance", []) + assert isinstance(node1_distances, list) and node1_distances, ( + f"{name}: node1 vector_distance should be a non-empty list" + ) + assert isinstance(node2_distances, list) and node2_distances, ( + f"{name}: node2 vector_distance should be a non-empty list" + ) + node1_distance = node1_distances[0] + node2_distance = node2_distances[0] + assert isinstance(node1_distance, float), ( + f"{name}: node1 vector_distance[0] should be float, got {type(node1_distance)}" + ) + assert isinstance(node2_distance, float), ( + f"{name}: node2 vector_distance[0] should be float, got {type(node2_distance)}" + ) assert 0 <= node1_distance <= 1 assert 0 <= node2_distance <= 1 diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index d30167262..8babdfe47 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -241,12 +241,11 @@ async def test_map_vector_distances_to_graph_nodes(setup_graph): await graph.map_vector_distances_to_graph_nodes(node_distances) - assert graph.get_node("1").attributes.get("vector_distance") == 0.95 - assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert graph.get_node("1").attributes.get("vector_distance") == [0.95] + assert graph.get_node("2").attributes.get("vector_distance") == [0.87] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_partial_node_coverage(setup_graph): """Test mapping vector distances when only some nodes have results.""" graph = setup_graph @@ -267,13 +266,12 @@ async def test_map_vector_distances_partial_node_coverage(setup_graph): await graph.map_vector_distances_to_graph_nodes(node_distances) - assert graph.get_node("1").attributes.get("vector_distance") == 0.95 - assert graph.get_node("2").attributes.get("vector_distance") == 0.87 - assert graph.get_node("3").attributes.get("vector_distance") == 3.5 + assert graph.get_node("1").attributes.get("vector_distance") == [0.95] + assert graph.get_node("2").attributes.get("vector_distance") == [0.87] + assert graph.get_node("3").attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_multiple_categories(setup_graph): """Test mapping vector distances from multiple collection categories.""" graph = setup_graph @@ -300,10 +298,36 @@ async def test_map_vector_distances_multiple_categories(setup_graph): await graph.map_vector_distances_to_graph_nodes(node_distances) - assert graph.get_node("1").attributes.get("vector_distance") == 0.95 - assert graph.get_node("2").attributes.get("vector_distance") == 0.87 - assert graph.get_node("3").attributes.get("vector_distance") == 0.92 - assert graph.get_node("4").attributes.get("vector_distance") == 3.5 + assert graph.get_node("1").attributes.get("vector_distance") == [0.95] + assert graph.get_node("2").attributes.get("vector_distance") == [0.87] + assert graph.get_node("3").attributes.get("vector_distance") == [0.92] + assert graph.get_node("4").attributes.get("vector_distance") == [3.5] + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_nodes_multi_query(setup_graph): + """Test mapping vector distances with multiple queries.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + node_distances = { + "Entity_name": [ + [MockScoredResult("1", 0.95)], # query 0 + [MockScoredResult("2", 0.87)], # query 1 + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances, query_list_length=2) + + assert graph.get_node("1").attributes.get("vector_distance") == [0.95, 3.5] + assert graph.get_node("2").attributes.get("vector_distance") == [3.5, 0.87] + assert graph.get_node("3").attributes.get("vector_distance") == [3.5, 3.5] @pytest.mark.asyncio @@ -329,11 +353,10 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 0.92 + assert graph.edges[0].attributes.get("vector_distance") == [0.92] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_partial_edge_coverage(setup_graph): """Test mapping edge distances when only some edges have results.""" graph = setup_graph @@ -356,8 +379,8 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 0.92 - assert graph.edges[1].attributes.get("vector_distance") == 3.5 + assert graph.edges[0].attributes.get("vector_distance") == [0.92] + assert graph.edges[1].attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio @@ -383,11 +406,10 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 0.85 + assert graph.edges[0].attributes.get("vector_distance") == [0.85] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_no_edge_matches(setup_graph): """Test edge mapping when no edges match the distance results.""" graph = setup_graph @@ -410,7 +432,7 @@ async def test_map_vector_distances_no_edge_matches(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 3.5 + assert graph.edges[0].attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio @@ -423,7 +445,37 @@ async def test_map_vector_distances_none_returns_early(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=None) - assert graph.edges[0].attributes.get("vector_distance") == 3.5 + assert graph.edges[0].attributes.get("vector_distance") == [3.5] + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph): + """Test mapping edge distances with multiple queries.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge1 = Edge(node1, node2, attributes={"edge_text": "A"}) + edge2 = Edge(node2, node3, attributes={"edge_text": "B"}) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge_distances = [ + [MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0 + [MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1 + ] + + await graph.map_vector_distances_to_graph_edges( + edge_distances=edge_distances, query_list_length=2 + ) + + assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5] + assert graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2] @pytest.mark.asyncio From 46ff01021a2ca1b3a115ac6279a36ff90925bf3c Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Wed, 17 Dec 2025 19:09:02 +0100 Subject: [PATCH 3/8] feat: add multi-query support to score calculation --- .../modules/graph/cognee_graph/CogneeGraph.py | 40 ++-- .../unit/modules/graph/cognee_graph_test.py | 182 +++++++++++++++++- 2 files changed, 200 insertions(+), 22 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 4a586e488..4838d5bc0 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -308,21 +308,33 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error mapping vector distances to edges: {str(ex)}") raise ex - def _as_distance(self, value: Union[float, List[float], None]) -> float: - """Normalize distance value to float, handling None, lists, and scalars.""" - if value is None: - return self.triplet_distance_penalty - if isinstance(value, list) and value: - return float(value[0]) - if isinstance(value, (int, float)): - return float(value) - return self.triplet_distance_penalty + def _calculate_query_top_triplet_importances( + self, + k: int, + query_index: int = 0, + ) -> List[Edge]: + """Calculate top k triplet importances for a specific query index.""" - async def calculate_top_triplet_importances(self, k: int) -> List[Edge]: def score(edge): - n1 = self._as_distance(edge.node1.attributes.get("vector_distance")) - n2 = self._as_distance(edge.node2.attributes.get("vector_distance")) - e = self._as_distance(edge.attributes.get("vector_distance")) - return n1 + n2 + e + distances = [ + edge.node1.attributes.get("vector_distance"), + edge.node2.attributes.get("vector_distance"), + edge.attributes.get("vector_distance"), + ] + return sum(float(d[query_index]) for d in distances) return heapq.nsmallest(k, self.edges, key=score) + + async def calculate_top_triplet_importances( + self, k: int, query_list_length: Optional[int] = None + ) -> Union[List[Edge], List[List[Edge]]]: + """Calculate top k triplet importances, supporting both single and multi-query modes.""" + query_count = query_list_length or 1 + results = [ + self._calculate_query_top_triplet_importances(k=k, query_index=i) + for i in range(query_count) + ] + + if query_list_length is None: + return results[0] + return results diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 8babdfe47..84e6411e2 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -200,6 +200,37 @@ async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter): ) +@pytest.mark.asyncio +async def test_project_graph_from_db_stores_triplet_penalty_on_graph(mock_adapter): + """Test that project_graph_from_db stores triplet_distance_penalty on the graph.""" + from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph + + nodes_data = [("1", {"name": "Node1"})] + edges_data = [("1", "1", "SELF", {})] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + graph = CogneeGraph() + custom_penalty = 5.0 + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=[], + triplet_distance_penalty=custom_penalty, + ) + + assert graph.triplet_distance_penalty == custom_penalty + + graph2 = CogneeGraph() + await graph2.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=[], + ) + + assert graph2.triplet_distance_penalty == 3.5 + + @pytest.mark.asyncio async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter): """Test that edges referencing missing nodes raise error.""" @@ -478,6 +509,36 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph): assert graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2] +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(setup_graph): + """Test that unmapped indices in multi-query mode stay at default penalty.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge1 = Edge(node1, node2, attributes={"edge_text": "A"}) + edge2 = Edge(node2, node3, attributes={"edge_text": "B"}) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge_distances = [ + [MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped + [], # query 1: no edges mapped + ] + + await graph.map_vector_distances_to_graph_edges( + edge_distances=edge_distances, query_list_length=2 + ) + + assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5] + assert graph.edges[1].attributes.get("vector_distance") == [3.5, 3.5] + + @pytest.mark.asyncio async def test_calculate_top_triplet_importances(setup_graph): """Test calculating top triplet importances by score.""" @@ -488,10 +549,10 @@ async def test_calculate_top_triplet_importances(setup_graph): node3 = Node("3") node4 = Node("4") - node1.add_attribute("vector_distance", 0.9) - node2.add_attribute("vector_distance", 0.8) - node3.add_attribute("vector_distance", 0.7) - node4.add_attribute("vector_distance", 0.6) + node1.add_attribute("vector_distance", [0.9]) + node2.add_attribute("vector_distance", [0.8]) + node3.add_attribute("vector_distance", [0.7]) + node4.add_attribute("vector_distance", [0.6]) graph.add_node(node1) graph.add_node(node2) @@ -502,9 +563,9 @@ async def test_calculate_top_triplet_importances(setup_graph): edge2 = Edge(node2, node3) edge3 = Edge(node3, node4) - edge1.add_attribute("vector_distance", 0.85) - edge2.add_attribute("vector_distance", 0.75) - edge3.add_attribute("vector_distance", 0.65) + edge1.add_attribute("vector_distance", [0.85]) + edge2.add_attribute("vector_distance", [0.75]) + edge3.add_attribute("vector_distance", [0.65]) graph.add_edge(edge1) graph.add_edge(edge2) @@ -520,7 +581,7 @@ async def test_calculate_top_triplet_importances(setup_graph): @pytest.mark.asyncio async def test_calculate_top_triplet_importances_default_distances(setup_graph): - """Test calculating importances when nodes/edges have no vector distances.""" + """Test calculating importances when nodes/edges have default vector distances.""" graph = setup_graph node1 = Node("1") @@ -531,7 +592,112 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph): edge = Edge(node1, node2) graph.add_edge(edge) + await graph.map_vector_distances_to_graph_nodes({}) + await graph.map_vector_distances_to_graph_edges(None) + top_triplets = await graph.calculate_top_triplet_importances(k=1) assert len(top_triplets) == 1 assert top_triplets[0] == edge + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_single_query_via_helper(setup_graph): + """Test calculating top triplet importances for a single query index.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + node1.add_attribute("vector_distance", [0.1]) + node2.add_attribute("vector_distance", [0.2]) + node3.add_attribute("vector_distance", [0.3]) + + edge1 = Edge(node1, node2) + edge2 = Edge(node2, node3) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge1.add_attribute("vector_distance", [0.3]) + edge2.add_attribute("vector_distance", [0.4]) + + results = await graph.calculate_top_triplet_importances(k=1, query_list_length=1) + assert len(results) == 1 + assert len(results[0]) == 1 + assert results[0][0] == edge1 + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_multi_query(setup_graph): + """Test calculating top triplet importances with multiple queries.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge_a = Edge(node1, node2) + edge_b = Edge(node2, node3) + graph.add_edge(edge_a) + graph.add_edge(edge_b) + + node1.add_attribute("vector_distance", [0.1, 0.9]) + node2.add_attribute("vector_distance", [0.1, 0.9]) + node3.add_attribute("vector_distance", [0.9, 0.1]) + edge_a.add_attribute("vector_distance", [0.1, 0.9]) + edge_b.add_attribute("vector_distance", [0.9, 0.1]) + + results = await graph.calculate_top_triplet_importances(k=1, query_list_length=2) + + assert len(results) == 2 + assert results[0][0] == edge_a + assert results[1][0] == edge_b + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_raises_on_short_list(setup_graph): + """Test that scoring raises ValueError when list is too short for query_index.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + node1.add_attribute("vector_distance", [0.1]) + node2.add_attribute("vector_distance", [0.2]) + + edge = Edge(node1, node2) + edge.add_attribute("vector_distance", [0.3]) + graph.add_edge(edge) + + with pytest.raises(IndexError): + await graph.calculate_top_triplet_importances(k=1, query_list_length=2) + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_raises_on_missing_attribute(setup_graph): + """Test that scoring raises error when vector_distance is missing.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + del node1.attributes["vector_distance"] + del node2.attributes["vector_distance"] + + edge = Edge(node1, node2) + del edge.attributes["vector_distance"] + graph.add_edge(edge) + + with pytest.raises((KeyError, TypeError)): + await graph.calculate_top_triplet_importances(k=1, query_list_length=1) From c1ea7a8cc235067358cf68006bd87b40ad7830b7 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:52:35 +0100 Subject: [PATCH 4/8] fix: improve graph distance mapping --- .../modules/graph/cognee_graph/CogneeGraph.py | 169 +++++++++--------- .../graph/cognee_graph/CogneeGraphElements.py | 46 +++++ cognee/tests/test_search_db.py | 15 +- .../graph/cognee_graph_elements_test.py | 40 +++++ .../unit/modules/graph/cognee_graph_test.py | 35 +++- 5 files changed, 210 insertions(+), 95 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 4838d5bc0..bc29bb828 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -25,12 +25,14 @@ class CogneeGraph(CogneeAbstractGraph): nodes: Dict[str, Node] edges: List[Edge] + edges_by_distance_key: Dict[str, List[Edge]] directed: bool triplet_distance_penalty: float def __init__(self, directed: bool = True): self.nodes = {} self.edges = [] + self.edges_by_distance_key = {} self.directed = directed self.triplet_distance_penalty = 3.5 @@ -44,6 +46,12 @@ class CogneeGraph(CogneeAbstractGraph): self.edges.append(edge) edge.node1.add_skeleton_edge(edge) edge.node2.add_skeleton_edge(edge) + key = edge.get_distance_key() + if not key: + return + if key not in self.edges_by_distance_key: + self.edges_by_distance_key[key] = [] + self.edges_by_distance_key[key].append(edge) def get_node(self, node_id: str) -> Node: return self.nodes.get(node_id, None) @@ -58,6 +66,29 @@ class CogneeGraph(CogneeAbstractGraph): def get_edges(self) -> List[Edge]: return self.edges + def reset_distances(self, collection: Iterable[Union[Node, Edge]], query_count: int) -> None: + """Reset vector distances for a collection of nodes or edges.""" + for item in collection: + item.reset_vector_distances(query_count, self.triplet_distance_penalty) + + def _normalize_query_distance_lists( + self, distances: List, query_list_length: Optional[int] = None, name: str = "distances" + ) -> List: + """Normalize shape: flat list -> single-query; nested list -> multi-query.""" + if not distances: + return [] + first_item = distances[0] + if isinstance(first_item, (list, tuple)): + per_query_lists = distances + else: + per_query_lists = [distances] + if query_list_length is not None and len(per_query_lists) != query_list_length: + raise ValueError( + f"{name} has {len(per_query_lists)} query lists, " + f"but query_list_length is {query_list_length}" + ) + return per_query_lists + async def _get_nodeset_subgraph( self, adapter, @@ -204,109 +235,81 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error during graph projection: {str(e)}") raise - def _initialize_vector_distance(self, graph_elements, query_list_length=None) -> None: - """Initialize vector_distance as a list of default penalties for all graph elements.""" - query_count = query_list_length or 1 - for element in graph_elements: - element.attributes["vector_distance"] = [self.triplet_distance_penalty] * query_count - - def _normalize_query_input(self, distance_data, query_list_length=None, name="input"): - """Normalize single-query or multi-query input to list of lists, return empty list if empty.""" - if not distance_data: - return [] - normalized = ( - distance_data if isinstance(distance_data[0], (list, tuple)) else [distance_data] - ) - if query_list_length is not None and len(normalized) != query_list_length: - raise ValueError( - f"{name} has {len(normalized)} query lists, but query_list_length is {query_list_length}" - ) - return normalized - - def _apply_vector_distance_updates( - self, - element_distances, - query_index: int, - get_element: Callable[[str], Optional[Union[Node, Edge]]], - get_id_and_score: Callable[[Any], Tuple[Optional[str], Optional[float]]], - ) -> None: - """Apply updates into element.attributes["vector_distance"][query_index].""" - for res in element_distances: - key, score = get_id_and_score(res) - if key is None or score is None: - continue - element = get_element(key) - if element is None: - continue - element.attributes["vector_distance"][query_index] = score - - def _get_node_id_and_score(self, res: Any) -> Tuple[str, float]: - """Extract node ID and score from a scored result.""" - return str(res.id), float(res.score) - - def _get_edge_id_and_score(self, res: Any) -> Tuple[Optional[str], Optional[float]]: - """Extract edge key and score from a scored result.""" - payload = getattr(res, "payload", None) - if not payload: - return None, None - text = payload.get("text") - if text is None: - return None, None - return str(text), float(res.score) - async def map_vector_distances_to_graph_nodes( self, node_distances, query_list_length: Optional[int] = None, ) -> None: - self._initialize_vector_distance(self.nodes.values(), query_list_length) + """Map vector distances to nodes, supporting single- and multi-query input shapes.""" + if not node_distances: + return None + + query_count = query_list_length or 1 + + # Reset all node distances for this search + self.reset_distances(self.nodes.values(), query_count) for collection_name, scored_results in node_distances.items(): - per_query_lists = self._normalize_query_input( - scored_results, query_list_length, f"Collection '{collection_name}'" - ) - if not per_query_lists: + if not scored_results: continue + per_query_lists = self._normalize_query_distance_lists( + scored_results, query_list_length, f"Collection '{collection_name}'" + ) + for query_index, scored_list in enumerate(per_query_lists): - self._apply_vector_distance_updates( - element_distances=scored_list, - query_index=query_index, - get_element=self.nodes.get, - get_id_and_score=self._get_node_id_and_score, - ) + for result in scored_list: + node_id = str(getattr(result, "id", None)) + if not node_id: + continue + node = self.get_node(node_id) + if node is None: + continue + score = float(getattr(result, "score", self.triplet_distance_penalty)) + node.update_distance_for_query( + query_index=query_index, + score=score, + query_count=query_count, + default_penalty=self.triplet_distance_penalty, + ) async def map_vector_distances_to_graph_edges( self, edge_distances, query_list_length: Optional[int] = None, ) -> None: - try: - self._initialize_vector_distance(self.edges, query_list_length) + """Map vector distances to graph edges, supporting single- and multi-query input shapes.""" + if not edge_distances: + return None - normalized_edges = self._normalize_query_input( - edge_distances, query_list_length, "edge_distances" - ) - if not normalized_edges: - return + query_count = query_list_length or 1 - edges_by_key: Dict[str, Edge] = {} - for edge in self.edges: - key = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type") - if key: - edges_by_key[str(key)] = edge + # Reset all edge distances for this search + self.reset_distances(self.edges, query_count) - for query_index, scored_list in enumerate(normalized_edges): - self._apply_vector_distance_updates( - element_distances=scored_list, - query_index=query_index, - get_element=edges_by_key.get, - get_id_and_score=self._get_edge_id_and_score, - ) + per_query_edge_lists = self._normalize_query_distance_lists( + edge_distances, query_list_length, "edge_distances" + ) - except Exception as ex: - logger.error(f"Error mapping vector distances to edges: {str(ex)}") - raise ex + # For each query, apply distances to all matching edges + for query_index, scored_list in enumerate(per_query_edge_lists): + for result in scored_list: + payload = getattr(result, "payload", None) + if not isinstance(payload, dict): + continue + text = payload.get("text") + if not text: + continue + matching_edges = self.edges_by_distance_key.get(str(text)) + if not matching_edges: + continue + for edge in matching_edges: + edge.update_distance_for_query( + query_index=query_index, + score=float(getattr(result, "score", self.triplet_distance_penalty)), + query_count=query_count, + default_penalty=self.triplet_distance_penalty, + ) def _calculate_query_top_triplet_importances( self, diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index 5d8e0df34..c9226b6a1 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -35,6 +35,26 @@ class Node: self.skeleton_edges = [] self.status = np.ones(dimension, dtype=int) + def reset_vector_distances(self, query_count: int, default_penalty: float) -> None: + self.attributes["vector_distance"] = [default_penalty] * query_count + + def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]: + distances = self.attributes.get("vector_distance") + if not isinstance(distances, list) or len(distances) != query_count: + distances = [default_penalty] * query_count + self.attributes["vector_distance"] = distances + return distances + + def update_distance_for_query( + self, + query_index: int, + score: float, + query_count: int, + default_penalty: float, + ) -> None: + distances = self.ensure_vector_distance_list(query_count, default_penalty) + distances[query_index] = score + def add_skeleton_neighbor(self, neighbor: "Node") -> None: if neighbor not in self.skeleton_neighbours: self.skeleton_neighbours.append(neighbor) @@ -120,6 +140,32 @@ class Edge: self.directed = directed self.status = np.ones(dimension, dtype=int) + def get_distance_key(self) -> Optional[str]: + key = self.attributes.get("edge_text") or self.attributes.get("relationship_type") + if key is None: + return None + return str(key) + + def reset_vector_distances(self, query_count: int, default_penalty: float) -> None: + self.attributes["vector_distance"] = [default_penalty] * query_count + + def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]: + distances = self.attributes.get("vector_distance") + if not isinstance(distances, list) or len(distances) != query_count: + distances = [default_penalty] * query_count + self.attributes["vector_distance"] = distances + return distances + + def update_distance_for_query( + self, + query_index: int, + score: float, + query_count: int, + default_penalty: float, + ) -> None: + distances = self.ensure_vector_distance_list(query_count, default_penalty) + distances[query_index] = score + def is_edge_alive_in_dimension(self, dimension: int) -> bool: if dimension < 0 or dimension >= len(self.status): raise DimensionOutOfRangeError(dimension=dimension, max_index=len(self.status) - 1) diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index d0b78dfcc..c5cd0061e 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -350,7 +350,10 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state): assert triplets, f"{name}: Triplets list should not be empty" for edge in triplets: assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" - vector_distances = edge.attributes.get("vector_distance", []) + vector_distances = edge.attributes.get("vector_distance") + assert vector_distances is not None, ( + f"{name}: vector_distance should be set when retrievers return results" + ) assert isinstance(vector_distances, list) and vector_distances, ( f"{name}: vector_distance should be a non-empty list" ) @@ -360,8 +363,14 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state): ) assert 0 <= distance <= 1 - node1_distances = edge.node1.attributes.get("vector_distance", []) - node2_distances = edge.node2.attributes.get("vector_distance", []) + node1_distances = edge.node1.attributes.get("vector_distance") + node2_distances = edge.node2.attributes.get("vector_distance") + assert node1_distances is not None, ( + f"{name}: node1 vector_distance should be set when retrievers return results" + ) + assert node2_distances is not None, ( + f"{name}: node2 vector_distance should be set when retrievers return results" + ) assert isinstance(node1_distances, list) and node1_distances, ( f"{name}: node1 vector_distance should be a non-empty list" ) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py index e59888525..809cde4cd 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py @@ -86,6 +86,46 @@ def test_node_hash(): assert hash(node) == hash("node1") +def test_node_vector_distance_stays_none(): + """Test that vector_distance remains None when no distances are passed.""" + node = Node("node1") + assert node.attributes.get("vector_distance") is None + + # Verify it stays None even after other operations + node.add_attribute("other_attr", "value") + assert node.attributes.get("vector_distance") is None + + +def test_node_vector_distance_with_custom_attributes(): + """Test that vector_distance is None even when node has custom attributes.""" + node = Node("node1", {"custom": "value", "another": 42}) + assert node.attributes.get("vector_distance") is None + assert node.attributes["custom"] == "value" + assert node.attributes["another"] == 42 + + +def test_edge_vector_distance_stays_none(): + """Test that vector_distance remains None when no distances are passed.""" + node1 = Node("node1") + node2 = Node("node2") + edge = Edge(node1, node2) + assert edge.attributes.get("vector_distance") is None + + # Verify it stays None even after other operations + edge.add_attribute("other_attr", "value") + assert edge.attributes.get("vector_distance") is None + + +def test_edge_vector_distance_with_custom_attributes(): + """Test that vector_distance is None even when edge has custom attributes.""" + node1 = Node("node1") + node2 = Node("node2") + edge = Edge(node1, node2, {"weight": 5, "type": "test"}) + assert edge.attributes.get("vector_distance") is None + assert edge.attributes["weight"] == 5 + assert edge.attributes["type"] == "test" + + ### Tests for Edge ### diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 84e6411e2..e4ff0251e 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -468,7 +468,7 @@ async def test_map_vector_distances_no_edge_matches(setup_graph): @pytest.mark.asyncio async def test_map_vector_distances_none_returns_early(setup_graph): - """Test that edge_distances=None returns early without error.""" + """Test that edge_distances=None returns early without error and vector_distance stays None.""" graph = setup_graph graph.add_node(Node("1")) graph.add_node(Node("2")) @@ -476,7 +476,22 @@ async def test_map_vector_distances_none_returns_early(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=None) - assert graph.edges[0].attributes.get("vector_distance") == [3.5] + assert graph.edges[0].attributes.get("vector_distance") is None + + +@pytest.mark.asyncio +async def test_map_vector_distances_empty_nodes_returns_early(setup_graph): + """Test that node_distances={} returns early without error and vector_distance stays None.""" + graph = setup_graph + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + await graph.map_vector_distances_to_graph_nodes({}) + + assert node1.attributes.get("vector_distance") is None + assert node2.attributes.get("vector_distance") is None @pytest.mark.asyncio @@ -581,7 +596,7 @@ async def test_calculate_top_triplet_importances(setup_graph): @pytest.mark.asyncio async def test_calculate_top_triplet_importances_default_distances(setup_graph): - """Test calculating importances when nodes/edges have default vector distances.""" + """Test that vector_distance stays None when no distances are passed and calculate_top_triplet_importances handles it.""" graph = setup_graph node1 = Node("1") @@ -592,13 +607,15 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph): edge = Edge(node1, node2) graph.add_edge(edge) - await graph.map_vector_distances_to_graph_nodes({}) - await graph.map_vector_distances_to_graph_edges(None) + # Verify vector_distance is None when no distances are passed + assert node1.attributes.get("vector_distance") is None + assert node2.attributes.get("vector_distance") is None + assert edge.attributes.get("vector_distance") is None - top_triplets = await graph.calculate_top_triplet_importances(k=1) - - assert len(top_triplets) == 1 - assert top_triplets[0] == edge + # When no distances are set, calculate_top_triplet_importances should handle None + # by either raising an error or skipping edges with None distances + with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"): + await graph.calculate_top_triplet_importances(k=1) @pytest.mark.asyncio From a85df53c7479053b8490d817b08177a81555b444 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 19 Dec 2025 11:40:56 +0100 Subject: [PATCH 5/8] chore: tweak mapping and scoring --- .../modules/graph/cognee_graph/CogneeGraph.py | 44 ++++++++++++------- .../unit/modules/graph/cognee_graph_test.py | 6 +-- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index bc29bb828..d4f07d0e6 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -241,12 +241,9 @@ class CogneeGraph(CogneeAbstractGraph): query_list_length: Optional[int] = None, ) -> None: """Map vector distances to nodes, supporting single- and multi-query input shapes.""" - if not node_distances: - return None query_count = query_list_length or 1 - # Reset all node distances for this search self.reset_distances(self.nodes.values(), query_count) for collection_name, scored_results in node_distances.items(): @@ -279,19 +276,17 @@ class CogneeGraph(CogneeAbstractGraph): query_list_length: Optional[int] = None, ) -> None: """Map vector distances to graph edges, supporting single- and multi-query input shapes.""" - if not edge_distances: - return None - query_count = query_list_length or 1 - # Reset all edge distances for this search self.reset_distances(self.edges, query_count) + if not edge_distances: + return None + per_query_edge_lists = self._normalize_query_distance_lists( edge_distances, query_list_length, "edge_distances" ) - # For each query, apply distances to all matching edges for query_index, scored_list in enumerate(per_query_edge_lists): for result in scored_list: payload = getattr(result, "payload", None) @@ -318,13 +313,32 @@ class CogneeGraph(CogneeAbstractGraph): ) -> List[Edge]: """Calculate top k triplet importances for a specific query index.""" - def score(edge): - distances = [ - edge.node1.attributes.get("vector_distance"), - edge.node2.attributes.get("vector_distance"), - edge.attributes.get("vector_distance"), - ] - return sum(float(d[query_index]) for d in distances) + def score(edge: Edge) -> float: + elements = ( + (edge.node1, f"node {edge.node1.id}"), + (edge.node2, f"node {edge.node2.id}"), + (edge, f"edge {edge.node1.id}->{edge.node2.id}"), + ) + + importances = [] + for element, label in elements: + distances = element.attributes.get("vector_distance") + if not isinstance(distances, list) or query_index >= len(distances): + raise ValueError( + f"{label}: vector_distance must be a list with length > {query_index} " + f"before scoring (got {type(distances).__name__} with length " + f"{len(distances) if isinstance(distances, list) else 'n/a'})" + ) + value = distances[query_index] + try: + importances.append(float(value)) + except (TypeError, ValueError): + raise ValueError( + f"{label}: vector_distance[{query_index}] must be float-like, " + f"got {type(value).__name__}" + ) + + return sum(importances) return heapq.nsmallest(k, self.edges, key=score) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index e4ff0251e..0458fca76 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -614,7 +614,7 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph): # When no distances are set, calculate_top_triplet_importances should handle None # by either raising an error or skipping edges with None distances - with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"): + with pytest.raises(ValueError): await graph.calculate_top_triplet_importances(k=1) @@ -695,7 +695,7 @@ async def test_calculate_top_triplet_importances_raises_on_short_list(setup_grap edge.add_attribute("vector_distance", [0.3]) graph.add_edge(edge) - with pytest.raises(IndexError): + with pytest.raises(ValueError): await graph.calculate_top_triplet_importances(k=1, query_list_length=2) @@ -716,5 +716,5 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set del edge.attributes["vector_distance"] graph.add_edge(edge) - with pytest.raises((KeyError, TypeError)): + with pytest.raises(ValueError): await graph.calculate_top_triplet_importances(k=1, query_list_length=1) From 9808077b4c9f558d951b03c50643c4b57c9b5791 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 19 Dec 2025 15:35:34 +0100 Subject: [PATCH 6/8] nit: update variable names --- cognee/modules/graph/cognee_graph/CogneeGraph.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index d4f07d0e6..da8c2254a 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -250,12 +250,12 @@ class CogneeGraph(CogneeAbstractGraph): if not scored_results: continue - per_query_lists = self._normalize_query_distance_lists( + per_query_scored_results = self._normalize_query_distance_lists( scored_results, query_list_length, f"Collection '{collection_name}'" ) - for query_index, scored_list in enumerate(per_query_lists): - for result in scored_list: + for query_index, scored_results in enumerate(per_query_scored_results): + for result in scored_results: node_id = str(getattr(result, "id", None)) if not node_id: continue @@ -283,12 +283,12 @@ class CogneeGraph(CogneeAbstractGraph): if not edge_distances: return None - per_query_edge_lists = self._normalize_query_distance_lists( + per_query_scored_results = self._normalize_query_distance_lists( edge_distances, query_list_length, "edge_distances" ) - for query_index, scored_list in enumerate(per_query_edge_lists): - for result in scored_list: + for query_index, scored_results in enumerate(per_query_scored_results): + for result in scored_results: payload = getattr(result, "payload", None) if not isinstance(payload, dict): continue From c3cec818d7509f0d6fd85593a884e070a41a49a2 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 19 Dec 2025 16:22:47 +0100 Subject: [PATCH 7/8] fix: update tests --- cognee/tests/unit/modules/graph/cognee_graph_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 0458fca76..41f12e73a 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -468,7 +468,7 @@ async def test_map_vector_distances_no_edge_matches(setup_graph): @pytest.mark.asyncio async def test_map_vector_distances_none_returns_early(setup_graph): - """Test that edge_distances=None returns early without error and vector_distance stays None.""" + """Test that edge_distances=None returns early without error and vector_distance is set to default penalty.""" graph = setup_graph graph.add_node(Node("1")) graph.add_node(Node("2")) @@ -476,12 +476,12 @@ async def test_map_vector_distances_none_returns_early(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=None) - assert graph.edges[0].attributes.get("vector_distance") is None + assert graph.edges[0].attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio async def test_map_vector_distances_empty_nodes_returns_early(setup_graph): - """Test that node_distances={} returns early without error and vector_distance stays None.""" + """Test that node_distances={} returns early without error and vector_distance is set to default penalty.""" graph = setup_graph node1 = Node("1") node2 = Node("2") @@ -490,8 +490,8 @@ async def test_map_vector_distances_empty_nodes_returns_early(setup_graph): await graph.map_vector_distances_to_graph_nodes({}) - assert node1.attributes.get("vector_distance") is None - assert node2.attributes.get("vector_distance") is None + assert node1.attributes.get("vector_distance") == [3.5] + assert node2.attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio From f6c76ce19edaf7ca2301adfec2acecc6a0ed50e6 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 19 Dec 2025 16:24:49 +0100 Subject: [PATCH 8/8] chore: remove duplicate import --- cognee/modules/graph/cognee_graph/CogneeGraph.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index da8c2254a..bec9b15fd 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -183,8 +183,6 @@ class CogneeGraph(CogneeAbstractGraph): self.triplet_distance_penalty = triplet_distance_penalty - import time - start_time = time.time() # Process nodes for node_id, properties in nodes_data: