feat: add multi-query support to graph distance mapping

This commit is contained in:
lxobr 2025-12-17 18:14:57 +01:00
parent cc7ca45e73
commit 69ab8e7ede
3 changed files with 190 additions and 43 deletions

View file

@ -1,6 +1,6 @@
import time
from cognee.shared.logging_utils import get_logger
from typing import List, Dict, Union, Optional, Type
from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any
from cognee.modules.graph.exceptions import (
EntityNotFoundError,
@ -204,31 +204,105 @@ class CogneeGraph(CogneeAbstractGraph):
logger.error(f"Error during graph projection: {str(e)}")
raise
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
mapped_nodes = 0
for category, scored_results in node_distances.items():
for scored_result in scored_results:
node_id = str(scored_result.id)
score = scored_result.score
node = self.get_node(node_id)
if node:
node.add_attribute("vector_distance", score)
mapped_nodes += 1
def _initialize_vector_distance(self, graph_elements, query_list_length=None) -> None:
"""Initialize vector_distance as a list of default penalties for all graph elements."""
query_count = query_list_length or 1
for element in graph_elements:
element.attributes["vector_distance"] = [self.triplet_distance_penalty] * query_count
async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
def _normalize_query_input(self, distance_data, query_list_length=None, name="input"):
"""Normalize single-query or multi-query input to list of lists, return empty list if empty."""
if not distance_data:
return []
normalized = (
distance_data if isinstance(distance_data[0], (list, tuple)) else [distance_data]
)
if query_list_length is not None and len(normalized) != query_list_length:
raise ValueError(
f"{name} has {len(normalized)} query lists, but query_list_length is {query_list_length}"
)
return normalized
def _apply_vector_distance_updates(
self,
element_distances,
query_index: int,
get_element: Callable[[str], Optional[Union[Node, Edge]]],
get_id_and_score: Callable[[Any], Tuple[Optional[str], Optional[float]]],
) -> None:
"""Apply updates into element.attributes["vector_distance"][query_index]."""
for res in element_distances:
key, score = get_id_and_score(res)
if key is None or score is None:
continue
element = get_element(key)
if element is None:
continue
element.attributes["vector_distance"][query_index] = score
def _get_node_id_and_score(self, res: Any) -> Tuple[str, float]:
"""Extract node ID and score from a scored result."""
return str(res.id), float(res.score)
def _get_edge_id_and_score(self, res: Any) -> Tuple[Optional[str], Optional[float]]:
"""Extract edge key and score from a scored result."""
payload = getattr(res, "payload", None)
if not payload:
return None, None
text = payload.get("text")
if text is None:
return None, None
return str(text), float(res.score)
async def map_vector_distances_to_graph_nodes(
self,
node_distances,
query_list_length: Optional[int] = None,
) -> None:
self._initialize_vector_distance(self.nodes.values(), query_list_length)
for collection_name, scored_results in node_distances.items():
per_query_lists = self._normalize_query_input(
scored_results, query_list_length, f"Collection '{collection_name}'"
)
if not per_query_lists:
continue
for query_index, scored_list in enumerate(per_query_lists):
self._apply_vector_distance_updates(
element_distances=scored_list,
query_index=query_index,
get_element=self.nodes.get,
get_id_and_score=self._get_node_id_and_score,
)
async def map_vector_distances_to_graph_edges(
self,
edge_distances,
query_list_length: Optional[int] = None,
) -> None:
try:
if edge_distances is None:
self._initialize_vector_distance(self.edges, query_list_length)
normalized_edges = self._normalize_query_input(
edge_distances, query_list_length, "edge_distances"
)
if not normalized_edges:
return
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
edges_by_key: Dict[str, Edge] = {}
for edge in self.edges:
edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
"relationship_type"
key = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
if key:
edges_by_key[str(key)] = edge
for query_index, scored_list in enumerate(normalized_edges):
self._apply_vector_distance_updates(
element_distances=scored_list,
query_index=query_index,
get_element=edges_by_key.get,
get_id_and_score=self._get_edge_id_and_score,
)
distance = embedding_map.get(edge_key, None)
if distance is not None:
edge.attributes["vector_distance"] = distance
except Exception as ex:
logger.error(f"Error mapping vector distances to edges: {str(ex)}")

View file

@ -350,11 +350,32 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
assert triplets, f"{name}: Triplets list should not be empty"
for edge in triplets:
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
distance = edge.attributes.get("vector_distance")
node1_distance = edge.node1.attributes.get("vector_distance")
node2_distance = edge.node2.attributes.get("vector_distance")
assert isinstance(distance, float), f"{name}: vector_distance should be float"
vector_distances = edge.attributes.get("vector_distance", [])
assert isinstance(vector_distances, list) and vector_distances, (
f"{name}: vector_distance should be a non-empty list"
)
distance = vector_distances[0]
assert isinstance(distance, float), (
f"{name}: vector_distance[0] should be float, got {type(distance)}"
)
assert 0 <= distance <= 1
node1_distances = edge.node1.attributes.get("vector_distance", [])
node2_distances = edge.node2.attributes.get("vector_distance", [])
assert isinstance(node1_distances, list) and node1_distances, (
f"{name}: node1 vector_distance should be a non-empty list"
)
assert isinstance(node2_distances, list) and node2_distances, (
f"{name}: node2 vector_distance should be a non-empty list"
)
node1_distance = node1_distances[0]
node2_distance = node2_distances[0]
assert isinstance(node1_distance, float), (
f"{name}: node1 vector_distance[0] should be float, got {type(node1_distance)}"
)
assert isinstance(node2_distance, float), (
f"{name}: node2 vector_distance[0] should be float, got {type(node2_distance)}"
)
assert 0 <= node1_distance <= 1
assert 0 <= node2_distance <= 1

View file

@ -241,12 +241,11 @@ async def test_map_vector_distances_to_graph_nodes(setup_graph):
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
@pytest.mark.asyncio
@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances")
async def test_map_vector_distances_partial_node_coverage(setup_graph):
"""Test mapping vector distances when only some nodes have results."""
graph = setup_graph
@ -267,13 +266,12 @@ async def test_map_vector_distances_partial_node_coverage(setup_graph):
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
assert graph.get_node("3").attributes.get("vector_distance") == 3.5
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
assert graph.get_node("3").attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances")
async def test_map_vector_distances_multiple_categories(setup_graph):
"""Test mapping vector distances from multiple collection categories."""
graph = setup_graph
@ -300,10 +298,36 @@ async def test_map_vector_distances_multiple_categories(setup_graph):
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
assert graph.get_node("3").attributes.get("vector_distance") == 0.92
assert graph.get_node("4").attributes.get("vector_distance") == 3.5
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
assert graph.get_node("3").attributes.get("vector_distance") == [0.92]
assert graph.get_node("4").attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_nodes_multi_query(setup_graph):
"""Test mapping vector distances with multiple queries."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
node_distances = {
"Entity_name": [
[MockScoredResult("1", 0.95)], # query 0
[MockScoredResult("2", 0.87)], # query 1
]
}
await graph.map_vector_distances_to_graph_nodes(node_distances, query_list_length=2)
assert graph.get_node("1").attributes.get("vector_distance") == [0.95, 3.5]
assert graph.get_node("2").attributes.get("vector_distance") == [3.5, 0.87]
assert graph.get_node("3").attributes.get("vector_distance") == [3.5, 3.5]
@pytest.mark.asyncio
@ -329,11 +353,10 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 0.92
assert graph.edges[0].attributes.get("vector_distance") == [0.92]
@pytest.mark.asyncio
@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances")
async def test_map_vector_distances_partial_edge_coverage(setup_graph):
"""Test mapping edge distances when only some edges have results."""
graph = setup_graph
@ -356,8 +379,8 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 0.92
assert graph.edges[1].attributes.get("vector_distance") == 3.5
assert graph.edges[0].attributes.get("vector_distance") == [0.92]
assert graph.edges[1].attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
@ -383,11 +406,10 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 0.85
assert graph.edges[0].attributes.get("vector_distance") == [0.85]
@pytest.mark.asyncio
@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances")
async def test_map_vector_distances_no_edge_matches(setup_graph):
"""Test edge mapping when no edges match the distance results."""
graph = setup_graph
@ -410,7 +432,7 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 3.5
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
@ -423,7 +445,37 @@ async def test_map_vector_distances_none_returns_early(setup_graph):
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
assert graph.edges[0].attributes.get("vector_distance") == 3.5
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
"""Test mapping edge distances with multiple queries."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
edge1 = Edge(node1, node2, attributes={"edge_text": "A"})
edge2 = Edge(node2, node3, attributes={"edge_text": "B"})
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_distances = [
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0
[MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1
]
await graph.map_vector_distances_to_graph_edges(
edge_distances=edge_distances, query_list_length=2
)
assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5]
assert graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2]
@pytest.mark.asyncio