From 69ab8e7edee9e422b33059a55f0d711fd5e4cde6 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Wed, 17 Dec 2025 18:14:57 +0100 Subject: [PATCH] feat: add multi-query support to graph distance mapping --- .../modules/graph/cognee_graph/CogneeGraph.py | 114 +++++++++++++++--- cognee/tests/test_search_db.py | 29 ++++- .../unit/modules/graph/cognee_graph_test.py | 90 +++++++++++--- 3 files changed, 190 insertions(+), 43 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index dd05c8c4f..4a586e488 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -1,6 +1,6 @@ import time from cognee.shared.logging_utils import get_logger -from typing import List, Dict, Union, Optional, Type +from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any from cognee.modules.graph.exceptions import ( EntityNotFoundError, @@ -204,31 +204,105 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error during graph projection: {str(e)}") raise - async def map_vector_distances_to_graph_nodes(self, node_distances) -> None: - mapped_nodes = 0 - for category, scored_results in node_distances.items(): - for scored_result in scored_results: - node_id = str(scored_result.id) - score = scored_result.score - node = self.get_node(node_id) - if node: - node.add_attribute("vector_distance", score) - mapped_nodes += 1 + def _initialize_vector_distance(self, graph_elements, query_list_length=None) -> None: + """Initialize vector_distance as a list of default penalties for all graph elements.""" + query_count = query_list_length or 1 + for element in graph_elements: + element.attributes["vector_distance"] = [self.triplet_distance_penalty] * query_count - async def map_vector_distances_to_graph_edges(self, edge_distances) -> None: + def _normalize_query_input(self, distance_data, query_list_length=None, name="input"): + """Normalize single-query or multi-query input to list of lists, return empty list if empty.""" + if not distance_data: + return [] + normalized = ( + distance_data if isinstance(distance_data[0], (list, tuple)) else [distance_data] + ) + if query_list_length is not None and len(normalized) != query_list_length: + raise ValueError( + f"{name} has {len(normalized)} query lists, but query_list_length is {query_list_length}" + ) + return normalized + + def _apply_vector_distance_updates( + self, + element_distances, + query_index: int, + get_element: Callable[[str], Optional[Union[Node, Edge]]], + get_id_and_score: Callable[[Any], Tuple[Optional[str], Optional[float]]], + ) -> None: + """Apply updates into element.attributes["vector_distance"][query_index].""" + for res in element_distances: + key, score = get_id_and_score(res) + if key is None or score is None: + continue + element = get_element(key) + if element is None: + continue + element.attributes["vector_distance"][query_index] = score + + def _get_node_id_and_score(self, res: Any) -> Tuple[str, float]: + """Extract node ID and score from a scored result.""" + return str(res.id), float(res.score) + + def _get_edge_id_and_score(self, res: Any) -> Tuple[Optional[str], Optional[float]]: + """Extract edge key and score from a scored result.""" + payload = getattr(res, "payload", None) + if not payload: + return None, None + text = payload.get("text") + if text is None: + return None, None + return str(text), float(res.score) + + async def map_vector_distances_to_graph_nodes( + self, + node_distances, + query_list_length: Optional[int] = None, + ) -> None: + self._initialize_vector_distance(self.nodes.values(), query_list_length) + + for collection_name, scored_results in node_distances.items(): + per_query_lists = self._normalize_query_input( + scored_results, query_list_length, f"Collection '{collection_name}'" + ) + if not per_query_lists: + continue + + for query_index, scored_list in enumerate(per_query_lists): + self._apply_vector_distance_updates( + element_distances=scored_list, + query_index=query_index, + get_element=self.nodes.get, + get_id_and_score=self._get_node_id_and_score, + ) + + async def map_vector_distances_to_graph_edges( + self, + edge_distances, + query_list_length: Optional[int] = None, + ) -> None: try: - if edge_distances is None: + self._initialize_vector_distance(self.edges, query_list_length) + + normalized_edges = self._normalize_query_input( + edge_distances, query_list_length, "edge_distances" + ) + if not normalized_edges: return - embedding_map = {result.payload["text"]: result.score for result in edge_distances} - + edges_by_key: Dict[str, Edge] = {} for edge in self.edges: - edge_key = edge.attributes.get("edge_text") or edge.attributes.get( - "relationship_type" + key = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type") + if key: + edges_by_key[str(key)] = edge + + for query_index, scored_list in enumerate(normalized_edges): + self._apply_vector_distance_updates( + element_distances=scored_list, + query_index=query_index, + get_element=edges_by_key.get, + get_id_and_score=self._get_edge_id_and_score, ) - distance = embedding_map.get(edge_key, None) - if distance is not None: - edge.attributes["vector_distance"] = distance except Exception as ex: logger.error(f"Error mapping vector distances to edges: {str(ex)}") diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 0916be322..d0b78dfcc 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -350,11 +350,32 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state): assert triplets, f"{name}: Triplets list should not be empty" for edge in triplets: assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" - distance = edge.attributes.get("vector_distance") - node1_distance = edge.node1.attributes.get("vector_distance") - node2_distance = edge.node2.attributes.get("vector_distance") - assert isinstance(distance, float), f"{name}: vector_distance should be float" + vector_distances = edge.attributes.get("vector_distance", []) + assert isinstance(vector_distances, list) and vector_distances, ( + f"{name}: vector_distance should be a non-empty list" + ) + distance = vector_distances[0] + assert isinstance(distance, float), ( + f"{name}: vector_distance[0] should be float, got {type(distance)}" + ) assert 0 <= distance <= 1 + + node1_distances = edge.node1.attributes.get("vector_distance", []) + node2_distances = edge.node2.attributes.get("vector_distance", []) + assert isinstance(node1_distances, list) and node1_distances, ( + f"{name}: node1 vector_distance should be a non-empty list" + ) + assert isinstance(node2_distances, list) and node2_distances, ( + f"{name}: node2 vector_distance should be a non-empty list" + ) + node1_distance = node1_distances[0] + node2_distance = node2_distances[0] + assert isinstance(node1_distance, float), ( + f"{name}: node1 vector_distance[0] should be float, got {type(node1_distance)}" + ) + assert isinstance(node2_distance, float), ( + f"{name}: node2 vector_distance[0] should be float, got {type(node2_distance)}" + ) assert 0 <= node1_distance <= 1 assert 0 <= node2_distance <= 1 diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index d30167262..8babdfe47 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -241,12 +241,11 @@ async def test_map_vector_distances_to_graph_nodes(setup_graph): await graph.map_vector_distances_to_graph_nodes(node_distances) - assert graph.get_node("1").attributes.get("vector_distance") == 0.95 - assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert graph.get_node("1").attributes.get("vector_distance") == [0.95] + assert graph.get_node("2").attributes.get("vector_distance") == [0.87] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_partial_node_coverage(setup_graph): """Test mapping vector distances when only some nodes have results.""" graph = setup_graph @@ -267,13 +266,12 @@ async def test_map_vector_distances_partial_node_coverage(setup_graph): await graph.map_vector_distances_to_graph_nodes(node_distances) - assert graph.get_node("1").attributes.get("vector_distance") == 0.95 - assert graph.get_node("2").attributes.get("vector_distance") == 0.87 - assert graph.get_node("3").attributes.get("vector_distance") == 3.5 + assert graph.get_node("1").attributes.get("vector_distance") == [0.95] + assert graph.get_node("2").attributes.get("vector_distance") == [0.87] + assert graph.get_node("3").attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_multiple_categories(setup_graph): """Test mapping vector distances from multiple collection categories.""" graph = setup_graph @@ -300,10 +298,36 @@ async def test_map_vector_distances_multiple_categories(setup_graph): await graph.map_vector_distances_to_graph_nodes(node_distances) - assert graph.get_node("1").attributes.get("vector_distance") == 0.95 - assert graph.get_node("2").attributes.get("vector_distance") == 0.87 - assert graph.get_node("3").attributes.get("vector_distance") == 0.92 - assert graph.get_node("4").attributes.get("vector_distance") == 3.5 + assert graph.get_node("1").attributes.get("vector_distance") == [0.95] + assert graph.get_node("2").attributes.get("vector_distance") == [0.87] + assert graph.get_node("3").attributes.get("vector_distance") == [0.92] + assert graph.get_node("4").attributes.get("vector_distance") == [3.5] + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_nodes_multi_query(setup_graph): + """Test mapping vector distances with multiple queries.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + node_distances = { + "Entity_name": [ + [MockScoredResult("1", 0.95)], # query 0 + [MockScoredResult("2", 0.87)], # query 1 + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances, query_list_length=2) + + assert graph.get_node("1").attributes.get("vector_distance") == [0.95, 3.5] + assert graph.get_node("2").attributes.get("vector_distance") == [3.5, 0.87] + assert graph.get_node("3").attributes.get("vector_distance") == [3.5, 3.5] @pytest.mark.asyncio @@ -329,11 +353,10 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 0.92 + assert graph.edges[0].attributes.get("vector_distance") == [0.92] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_partial_edge_coverage(setup_graph): """Test mapping edge distances when only some edges have results.""" graph = setup_graph @@ -356,8 +379,8 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 0.92 - assert graph.edges[1].attributes.get("vector_distance") == 3.5 + assert graph.edges[0].attributes.get("vector_distance") == [0.92] + assert graph.edges[1].attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio @@ -383,11 +406,10 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 0.85 + assert graph.edges[0].attributes.get("vector_distance") == [0.85] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_no_edge_matches(setup_graph): """Test edge mapping when no edges match the distance results.""" graph = setup_graph @@ -410,7 +432,7 @@ async def test_map_vector_distances_no_edge_matches(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 3.5 + assert graph.edges[0].attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio @@ -423,7 +445,37 @@ async def test_map_vector_distances_none_returns_early(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=None) - assert graph.edges[0].attributes.get("vector_distance") == 3.5 + assert graph.edges[0].attributes.get("vector_distance") == [3.5] + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph): + """Test mapping edge distances with multiple queries.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge1 = Edge(node1, node2, attributes={"edge_text": "A"}) + edge2 = Edge(node2, node3, attributes={"edge_text": "B"}) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge_distances = [ + [MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0 + [MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1 + ] + + await graph.map_vector_distances_to_graph_edges( + edge_distances=edge_distances, query_list_length=2 + ) + + assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5] + assert graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2] @pytest.mark.asyncio