From 46ff01021a2ca1b3a115ac6279a36ff90925bf3c Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Wed, 17 Dec 2025 19:09:02 +0100 Subject: [PATCH] feat: add multi-query support to score calculation --- .../modules/graph/cognee_graph/CogneeGraph.py | 40 ++-- .../unit/modules/graph/cognee_graph_test.py | 182 +++++++++++++++++- 2 files changed, 200 insertions(+), 22 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 4a586e488..4838d5bc0 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -308,21 +308,33 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error mapping vector distances to edges: {str(ex)}") raise ex - def _as_distance(self, value: Union[float, List[float], None]) -> float: - """Normalize distance value to float, handling None, lists, and scalars.""" - if value is None: - return self.triplet_distance_penalty - if isinstance(value, list) and value: - return float(value[0]) - if isinstance(value, (int, float)): - return float(value) - return self.triplet_distance_penalty + def _calculate_query_top_triplet_importances( + self, + k: int, + query_index: int = 0, + ) -> List[Edge]: + """Calculate top k triplet importances for a specific query index.""" - async def calculate_top_triplet_importances(self, k: int) -> List[Edge]: def score(edge): - n1 = self._as_distance(edge.node1.attributes.get("vector_distance")) - n2 = self._as_distance(edge.node2.attributes.get("vector_distance")) - e = self._as_distance(edge.attributes.get("vector_distance")) - return n1 + n2 + e + distances = [ + edge.node1.attributes.get("vector_distance"), + edge.node2.attributes.get("vector_distance"), + edge.attributes.get("vector_distance"), + ] + return sum(float(d[query_index]) for d in distances) return heapq.nsmallest(k, self.edges, key=score) + + async def calculate_top_triplet_importances( + 
self, k: int, query_list_length: Optional[int] = None + ) -> Union[List[Edge], List[List[Edge]]]: + """Calculate top k triplet importances, supporting both single and multi-query modes.""" + query_count = query_list_length or 1 + results = [ + self._calculate_query_top_triplet_importances(k=k, query_index=i) + for i in range(query_count) + ] + + if query_list_length is None: + return results[0] + return results diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 8babdfe47..84e6411e2 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -200,6 +200,37 @@ async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter): ) +@pytest.mark.asyncio +async def test_project_graph_from_db_stores_triplet_penalty_on_graph(mock_adapter): + """Test that project_graph_from_db stores triplet_distance_penalty on the graph.""" + from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph + + nodes_data = [("1", {"name": "Node1"})] + edges_data = [("1", "1", "SELF", {})] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + graph = CogneeGraph() + custom_penalty = 5.0 + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=[], + triplet_distance_penalty=custom_penalty, + ) + + assert graph.triplet_distance_penalty == custom_penalty + + graph2 = CogneeGraph() + await graph2.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=[], + ) + + assert graph2.triplet_distance_penalty == 3.5 + + @pytest.mark.asyncio async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter): """Test that edges referencing missing nodes raise error.""" @@ -478,6 +509,36 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph): assert 
graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2] +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(setup_graph): + """Test that unmapped indices in multi-query mode stay at default penalty.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge1 = Edge(node1, node2, attributes={"edge_text": "A"}) + edge2 = Edge(node2, node3, attributes={"edge_text": "B"}) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge_distances = [ + [MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped + [], # query 1: no edges mapped + ] + + await graph.map_vector_distances_to_graph_edges( + edge_distances=edge_distances, query_list_length=2 + ) + + assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5] + assert graph.edges[1].attributes.get("vector_distance") == [3.5, 3.5] + + @pytest.mark.asyncio async def test_calculate_top_triplet_importances(setup_graph): """Test calculating top triplet importances by score.""" @@ -488,10 +549,10 @@ async def test_calculate_top_triplet_importances(setup_graph): node3 = Node("3") node4 = Node("4") - node1.add_attribute("vector_distance", 0.9) - node2.add_attribute("vector_distance", 0.8) - node3.add_attribute("vector_distance", 0.7) - node4.add_attribute("vector_distance", 0.6) + node1.add_attribute("vector_distance", [0.9]) + node2.add_attribute("vector_distance", [0.8]) + node3.add_attribute("vector_distance", [0.7]) + node4.add_attribute("vector_distance", [0.6]) graph.add_node(node1) graph.add_node(node2) @@ -502,9 +563,9 @@ async def test_calculate_top_triplet_importances(setup_graph): edge2 = Edge(node2, node3) edge3 = Edge(node3, node4) - edge1.add_attribute("vector_distance", 0.85) - edge2.add_attribute("vector_distance", 0.75) - edge3.add_attribute("vector_distance", 0.65) + edge1.add_attribute("vector_distance", [0.85]) + 
edge2.add_attribute("vector_distance", [0.75]) + edge3.add_attribute("vector_distance", [0.65]) graph.add_edge(edge1) graph.add_edge(edge2) @@ -520,7 +581,7 @@ async def test_calculate_top_triplet_importances(setup_graph): @pytest.mark.asyncio async def test_calculate_top_triplet_importances_default_distances(setup_graph): - """Test calculating importances when nodes/edges have no vector distances.""" + """Test calculating importances when nodes/edges have default vector distances.""" graph = setup_graph node1 = Node("1") @@ -531,7 +592,112 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph): edge = Edge(node1, node2) graph.add_edge(edge) + await graph.map_vector_distances_to_graph_nodes({}) + await graph.map_vector_distances_to_graph_edges(None) + top_triplets = await graph.calculate_top_triplet_importances(k=1) assert len(top_triplets) == 1 assert top_triplets[0] == edge + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_single_query_via_helper(setup_graph): + """Test calculating top triplet importances for a single query index.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + node1.add_attribute("vector_distance", [0.1]) + node2.add_attribute("vector_distance", [0.2]) + node3.add_attribute("vector_distance", [0.3]) + + edge1 = Edge(node1, node2) + edge2 = Edge(node2, node3) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge1.add_attribute("vector_distance", [0.3]) + edge2.add_attribute("vector_distance", [0.4]) + + results = await graph.calculate_top_triplet_importances(k=1, query_list_length=1) + assert len(results) == 1 + assert len(results[0]) == 1 + assert results[0][0] == edge1 + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_multi_query(setup_graph): + """Test calculating top triplet importances with multiple queries.""" + graph = setup_graph + + node1 = 
Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge_a = Edge(node1, node2) + edge_b = Edge(node2, node3) + graph.add_edge(edge_a) + graph.add_edge(edge_b) + + node1.add_attribute("vector_distance", [0.1, 0.9]) + node2.add_attribute("vector_distance", [0.1, 0.9]) + node3.add_attribute("vector_distance", [0.9, 0.1]) + edge_a.add_attribute("vector_distance", [0.1, 0.9]) + edge_b.add_attribute("vector_distance", [0.9, 0.1]) + + results = await graph.calculate_top_triplet_importances(k=1, query_list_length=2) + + assert len(results) == 2 + assert results[0][0] == edge_a + assert results[1][0] == edge_b + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_raises_on_short_list(setup_graph): + """Test that scoring raises IndexError when list is too short for query_index.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + node1.add_attribute("vector_distance", [0.1]) + node2.add_attribute("vector_distance", [0.2]) + + edge = Edge(node1, node2) + edge.add_attribute("vector_distance", [0.3]) + graph.add_edge(edge) + + with pytest.raises(IndexError): + await graph.calculate_top_triplet_importances(k=1, query_list_length=2) + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_raises_on_missing_attribute(setup_graph): + """Test that scoring raises error when vector_distance is missing.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + del node1.attributes["vector_distance"] + del node2.attributes["vector_distance"] + + edge = Edge(node1, node2) + del edge.attributes["vector_distance"] + graph.add_edge(edge) + + with pytest.raises((KeyError, TypeError)): + await graph.calculate_top_triplet_importances(k=1, query_list_length=1)