feat: list vector distance in cogneegraph (#1926)

<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->

- `map_vector_distances_to_graph_nodes` and
`map_vector_distances_to_graph_edges` accept both single-query (flat
list) and multi-query (nested list) inputs.
- `query_list_length` controls the mode: omit it for single-query
behavior, or provide it to enable multi-query mode with strict length
validation and per-query results.
- `vector_distance` on `Node` and `Edge` is now a list (one distance per
query). Constructors set it to `None`, and `reset_distances` initializes
it at the start of each search.
- `Node.update_distance_for_query` and `Edge.update_distance_for_query`
are the only methods that write to `vector_distance`. They ensure the
list has enough elements and keep unmatched queries at the penalty
value.
- `triplet_distance_penalty` is the default distance value used
everywhere. Unmatched nodes/edges and missing scores all use this same
penalty for consistency.
- `edges_by_distance_key` is an index mapping edge labels to matching
edges. This lets us update all edges with the same label at once,
instead of scanning the full edge list repeatedly.
- `calculate_top_triplet_importances` returns `List[Edge]` for
single-query mode and `List[List[Edge]]` for multi-query mode.


## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to
test it locally;
* Proof that it's sufficiently tested.
-->

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Multi-query support for mapping/scoring node and edge distances and a
configurable triplet distance penalty.
* Distance-keyed edge indexing for more accurate distance-to-edge
matching.

* **Refactor**
* Vector distance metadata changed from scalars to per-query lists;
added reset/normalization and per-query update flows.
* Node/edge distance initialization now supports deferred/listed
distances.

* **Tests**
* Updated and expanded tests for multi-query flows, list-based
distances, edge-key handling, and related error cases.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Vasilije 2025-12-23 14:47:27 +01:00 committed by GitHub
commit 310e9e97ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 538 additions and 68 deletions

View file

@ -1,6 +1,6 @@
import time
from cognee.shared.logging_utils import get_logger
from typing import List, Dict, Union, Optional, Type
from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any
from cognee.modules.graph.exceptions import (
EntityNotFoundError,
@ -25,12 +25,16 @@ class CogneeGraph(CogneeAbstractGraph):
nodes: Dict[str, Node]
edges: List[Edge]
edges_by_distance_key: Dict[str, List[Edge]]
directed: bool
triplet_distance_penalty: float
def __init__(self, directed: bool = True):
self.nodes = {}
self.edges = []
self.edges_by_distance_key = {}
self.directed = directed
self.triplet_distance_penalty = 3.5
def add_node(self, node: Node) -> None:
if node.id not in self.nodes:
@ -42,6 +46,12 @@ class CogneeGraph(CogneeAbstractGraph):
self.edges.append(edge)
edge.node1.add_skeleton_edge(edge)
edge.node2.add_skeleton_edge(edge)
key = edge.get_distance_key()
if not key:
return
if key not in self.edges_by_distance_key:
self.edges_by_distance_key[key] = []
self.edges_by_distance_key[key].append(edge)
def get_node(self, node_id: str) -> Node:
return self.nodes.get(node_id, None)
@ -56,6 +66,29 @@ class CogneeGraph(CogneeAbstractGraph):
def get_edges(self) -> List[Edge]:
return self.edges
def reset_distances(self, collection: Iterable[Union[Node, Edge]], query_count: int) -> None:
"""Reset vector distances for a collection of nodes or edges."""
for item in collection:
item.reset_vector_distances(query_count, self.triplet_distance_penalty)
def _normalize_query_distance_lists(
self, distances: List, query_list_length: Optional[int] = None, name: str = "distances"
) -> List:
"""Normalize shape: flat list -> single-query; nested list -> multi-query."""
if not distances:
return []
first_item = distances[0]
if isinstance(first_item, (list, tuple)):
per_query_lists = distances
else:
per_query_lists = [distances]
if query_list_length is not None and len(per_query_lists) != query_list_length:
raise ValueError(
f"{name} has {len(per_query_lists)} query lists, "
f"but query_list_length is {query_list_length}"
)
return per_query_lists
async def _get_nodeset_subgraph(
self,
adapter,
@ -148,7 +181,7 @@ class CogneeGraph(CogneeAbstractGraph):
adapter, memory_fragment_filter
)
import time
self.triplet_distance_penalty = triplet_distance_penalty
start_time = time.time()
# Process nodes
@ -200,41 +233,123 @@ class CogneeGraph(CogneeAbstractGraph):
logger.error(f"Error during graph projection: {str(e)}")
raise
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
mapped_nodes = 0
for category, scored_results in node_distances.items():
for scored_result in scored_results:
node_id = str(scored_result.id)
score = scored_result.score
node = self.get_node(node_id)
if node:
node.add_attribute("vector_distance", score)
mapped_nodes += 1
async def map_vector_distances_to_graph_nodes(
self,
node_distances,
query_list_length: Optional[int] = None,
) -> None:
"""Map vector distances to nodes, supporting single- and multi-query input shapes."""
async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
try:
if edge_distances is None:
return
query_count = query_list_length or 1
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
self.reset_distances(self.nodes.values(), query_count)
for edge in self.edges:
edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
"relationship_type"
)
distance = embedding_map.get(edge_key, None)
if distance is not None:
edge.attributes["vector_distance"] = distance
for collection_name, scored_results in node_distances.items():
if not scored_results:
continue
except Exception as ex:
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
raise ex
per_query_scored_results = self._normalize_query_distance_lists(
scored_results, query_list_length, f"Collection '{collection_name}'"
)
async def calculate_top_triplet_importances(self, k: int) -> List[Edge]:
def score(edge):
n1 = edge.node1.attributes.get("vector_distance", 1)
n2 = edge.node2.attributes.get("vector_distance", 1)
e = edge.attributes.get("vector_distance", 1)
return n1 + n2 + e
for query_index, scored_results in enumerate(per_query_scored_results):
for result in scored_results:
node_id = str(getattr(result, "id", None))
if not node_id:
continue
node = self.get_node(node_id)
if node is None:
continue
score = float(getattr(result, "score", self.triplet_distance_penalty))
node.update_distance_for_query(
query_index=query_index,
score=score,
query_count=query_count,
default_penalty=self.triplet_distance_penalty,
)
async def map_vector_distances_to_graph_edges(
self,
edge_distances,
query_list_length: Optional[int] = None,
) -> None:
"""Map vector distances to graph edges, supporting single- and multi-query input shapes."""
query_count = query_list_length or 1
self.reset_distances(self.edges, query_count)
if not edge_distances:
return None
per_query_scored_results = self._normalize_query_distance_lists(
edge_distances, query_list_length, "edge_distances"
)
for query_index, scored_results in enumerate(per_query_scored_results):
for result in scored_results:
payload = getattr(result, "payload", None)
if not isinstance(payload, dict):
continue
text = payload.get("text")
if not text:
continue
matching_edges = self.edges_by_distance_key.get(str(text))
if not matching_edges:
continue
for edge in matching_edges:
edge.update_distance_for_query(
query_index=query_index,
score=float(getattr(result, "score", self.triplet_distance_penalty)),
query_count=query_count,
default_penalty=self.triplet_distance_penalty,
)
def _calculate_query_top_triplet_importances(
self,
k: int,
query_index: int = 0,
) -> List[Edge]:
"""Calculate top k triplet importances for a specific query index."""
def score(edge: Edge) -> float:
elements = (
(edge.node1, f"node {edge.node1.id}"),
(edge.node2, f"node {edge.node2.id}"),
(edge, f"edge {edge.node1.id}->{edge.node2.id}"),
)
importances = []
for element, label in elements:
distances = element.attributes.get("vector_distance")
if not isinstance(distances, list) or query_index >= len(distances):
raise ValueError(
f"{label}: vector_distance must be a list with length > {query_index} "
f"before scoring (got {type(distances).__name__} with length "
f"{len(distances) if isinstance(distances, list) else 'n/a'})"
)
value = distances[query_index]
try:
importances.append(float(value))
except (TypeError, ValueError):
raise ValueError(
f"{label}: vector_distance[{query_index}] must be float-like, "
f"got {type(value).__name__}"
)
return sum(importances)
return heapq.nsmallest(k, self.edges, key=score)
async def calculate_top_triplet_importances(
self, k: int, query_list_length: Optional[int] = None
) -> Union[List[Edge], List[List[Edge]]]:
"""Calculate top k triplet importances, supporting both single and multi-query modes."""
query_count = query_list_length or 1
results = [
self._calculate_query_top_triplet_importances(k=k, query_index=i)
for i in range(query_count)
]
if query_list_length is None:
return results[0]
return results

View file

@ -30,11 +30,31 @@ class Node:
raise InvalidDimensionsError()
self.id = node_id
self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = node_penalty
self.attributes["vector_distance"] = None
self.skeleton_neighbours = []
self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int)
def reset_vector_distances(self, query_count: int, default_penalty: float) -> None:
self.attributes["vector_distance"] = [default_penalty] * query_count
def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]:
distances = self.attributes.get("vector_distance")
if not isinstance(distances, list) or len(distances) != query_count:
distances = [default_penalty] * query_count
self.attributes["vector_distance"] = distances
return distances
def update_distance_for_query(
self,
query_index: int,
score: float,
query_count: int,
default_penalty: float,
) -> None:
distances = self.ensure_vector_distance_list(query_count, default_penalty)
distances[query_index] = score
def add_skeleton_neighbor(self, neighbor: "Node") -> None:
if neighbor not in self.skeleton_neighbours:
self.skeleton_neighbours.append(neighbor)
@ -116,10 +136,36 @@ class Edge:
self.node1 = node1
self.node2 = node2
self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = edge_penalty
self.attributes["vector_distance"] = None
self.directed = directed
self.status = np.ones(dimension, dtype=int)
def get_distance_key(self) -> Optional[str]:
key = self.attributes.get("edge_text") or self.attributes.get("relationship_type")
if key is None:
return None
return str(key)
def reset_vector_distances(self, query_count: int, default_penalty: float) -> None:
self.attributes["vector_distance"] = [default_penalty] * query_count
def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]:
distances = self.attributes.get("vector_distance")
if not isinstance(distances, list) or len(distances) != query_count:
distances = [default_penalty] * query_count
self.attributes["vector_distance"] = distances
return distances
def update_distance_for_query(
self,
query_index: int,
score: float,
query_count: int,
default_penalty: float,
) -> None:
distances = self.ensure_vector_distance_list(query_count, default_penalty)
distances[query_index] = score
def is_edge_alive_in_dimension(self, dimension: int) -> bool:
if dimension < 0 or dimension >= len(self.status):
raise DimensionOutOfRangeError(dimension=dimension, max_index=len(self.status) - 1)

View file

@ -350,11 +350,41 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
assert triplets, f"{name}: Triplets list should not be empty"
for edge in triplets:
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
distance = edge.attributes.get("vector_distance")
node1_distance = edge.node1.attributes.get("vector_distance")
node2_distance = edge.node2.attributes.get("vector_distance")
assert isinstance(distance, float), f"{name}: vector_distance should be float"
vector_distances = edge.attributes.get("vector_distance")
assert vector_distances is not None, (
f"{name}: vector_distance should be set when retrievers return results"
)
assert isinstance(vector_distances, list) and vector_distances, (
f"{name}: vector_distance should be a non-empty list"
)
distance = vector_distances[0]
assert isinstance(distance, float), (
f"{name}: vector_distance[0] should be float, got {type(distance)}"
)
assert 0 <= distance <= 1
node1_distances = edge.node1.attributes.get("vector_distance")
node2_distances = edge.node2.attributes.get("vector_distance")
assert node1_distances is not None, (
f"{name}: node1 vector_distance should be set when retrievers return results"
)
assert node2_distances is not None, (
f"{name}: node2 vector_distance should be set when retrievers return results"
)
assert isinstance(node1_distances, list) and node1_distances, (
f"{name}: node1 vector_distance should be a non-empty list"
)
assert isinstance(node2_distances, list) and node2_distances, (
f"{name}: node2 vector_distance should be a non-empty list"
)
node1_distance = node1_distances[0]
node2_distance = node2_distances[0]
assert isinstance(node1_distance, float), (
f"{name}: node1 vector_distance[0] should be float, got {type(node1_distance)}"
)
assert isinstance(node2_distance, float), (
f"{name}: node2 vector_distance[0] should be float, got {type(node2_distance)}"
)
assert 0 <= node1_distance <= 1
assert 0 <= node2_distance <= 1

View file

@ -9,7 +9,7 @@ def test_node_initialization():
"""Test that a Node is initialized correctly."""
node = Node("node1", {"attr1": "value1"}, dimension=2)
assert node.id == "node1"
assert node.attributes == {"attr1": "value1", "vector_distance": 3.5}
assert node.attributes == {"attr1": "value1", "vector_distance": None}
assert len(node.status) == 2
assert np.all(node.status == 1)
@ -86,6 +86,46 @@ def test_node_hash():
assert hash(node) == hash("node1")
def test_node_vector_distance_stays_none():
"""Test that vector_distance remains None when no distances are passed."""
node = Node("node1")
assert node.attributes.get("vector_distance") is None
# Verify it stays None even after other operations
node.add_attribute("other_attr", "value")
assert node.attributes.get("vector_distance") is None
def test_node_vector_distance_with_custom_attributes():
"""Test that vector_distance is None even when node has custom attributes."""
node = Node("node1", {"custom": "value", "another": 42})
assert node.attributes.get("vector_distance") is None
assert node.attributes["custom"] == "value"
assert node.attributes["another"] == 42
def test_edge_vector_distance_stays_none():
"""Test that vector_distance remains None when no distances are passed."""
node1 = Node("node1")
node2 = Node("node2")
edge = Edge(node1, node2)
assert edge.attributes.get("vector_distance") is None
# Verify it stays None even after other operations
edge.add_attribute("other_attr", "value")
assert edge.attributes.get("vector_distance") is None
def test_edge_vector_distance_with_custom_attributes():
"""Test that vector_distance is None even when edge has custom attributes."""
node1 = Node("node1")
node2 = Node("node2")
edge = Edge(node1, node2, {"weight": 5, "type": "test"})
assert edge.attributes.get("vector_distance") is None
assert edge.attributes["weight"] == 5
assert edge.attributes["type"] == "test"
### Tests for Edge ###
@ -96,7 +136,7 @@ def test_edge_initialization():
edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
assert edge.node1 == node1
assert edge.node2 == node2
assert edge.attributes == {"vector_distance": 3.5, "weight": 10}
assert edge.attributes == {"vector_distance": None, "weight": 10}
assert edge.directed is False
assert len(edge.status) == 2
assert np.all(edge.status == 1)

View file

@ -200,6 +200,37 @@ async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
)
@pytest.mark.asyncio
async def test_project_graph_from_db_stores_triplet_penalty_on_graph(mock_adapter):
"""Test that project_graph_from_db stores triplet_distance_penalty on the graph."""
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
nodes_data = [("1", {"name": "Node1"})]
edges_data = [("1", "1", "SELF", {})]
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
graph = CogneeGraph()
custom_penalty = 5.0
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name"],
edge_properties_to_project=[],
triplet_distance_penalty=custom_penalty,
)
assert graph.triplet_distance_penalty == custom_penalty
graph2 = CogneeGraph()
await graph2.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name"],
edge_properties_to_project=[],
)
assert graph2.triplet_distance_penalty == 3.5
@pytest.mark.asyncio
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
"""Test that edges referencing missing nodes raise error."""
@ -241,8 +272,8 @@ async def test_map_vector_distances_to_graph_nodes(setup_graph):
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
@pytest.mark.asyncio
@ -266,9 +297,9 @@ async def test_map_vector_distances_partial_node_coverage(setup_graph):
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
assert graph.get_node("3").attributes.get("vector_distance") == 3.5
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
assert graph.get_node("3").attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
@ -298,10 +329,36 @@ async def test_map_vector_distances_multiple_categories(setup_graph):
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
assert graph.get_node("3").attributes.get("vector_distance") == 0.92
assert graph.get_node("4").attributes.get("vector_distance") == 3.5
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
assert graph.get_node("3").attributes.get("vector_distance") == [0.92]
assert graph.get_node("4").attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_nodes_multi_query(setup_graph):
"""Test mapping vector distances with multiple queries."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
node_distances = {
"Entity_name": [
[MockScoredResult("1", 0.95)], # query 0
[MockScoredResult("2", 0.87)], # query 1
]
}
await graph.map_vector_distances_to_graph_nodes(node_distances, query_list_length=2)
assert graph.get_node("1").attributes.get("vector_distance") == [0.95, 3.5]
assert graph.get_node("2").attributes.get("vector_distance") == [3.5, 0.87]
assert graph.get_node("3").attributes.get("vector_distance") == [3.5, 3.5]
@pytest.mark.asyncio
@ -327,7 +384,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 0.92
assert graph.edges[0].attributes.get("vector_distance") == [0.92]
@pytest.mark.asyncio
@ -353,8 +410,8 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 0.92
assert graph.edges[1].attributes.get("vector_distance") == 3.5
assert graph.edges[0].attributes.get("vector_distance") == [0.92]
assert graph.edges[1].attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
@ -380,7 +437,7 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 0.85
assert graph.edges[0].attributes.get("vector_distance") == [0.85]
@pytest.mark.asyncio
@ -406,12 +463,12 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 3.5
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
async def test_map_vector_distances_none_returns_early(setup_graph):
"""Test that edge_distances=None returns early without error."""
"""Test that edge_distances=None returns early without error and vector_distance is set to default penalty."""
graph = setup_graph
graph.add_node(Node("1"))
graph.add_node(Node("2"))
@ -419,7 +476,82 @@ async def test_map_vector_distances_none_returns_early(setup_graph):
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
assert graph.edges[0].attributes.get("vector_distance") == 3.5
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
async def test_map_vector_distances_empty_nodes_returns_early(setup_graph):
"""Test that node_distances={} returns early without error and vector_distance is set to default penalty."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
await graph.map_vector_distances_to_graph_nodes({})
assert node1.attributes.get("vector_distance") == [3.5]
assert node2.attributes.get("vector_distance") == [3.5]
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
"""Test mapping edge distances with multiple queries."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
edge1 = Edge(node1, node2, attributes={"edge_text": "A"})
edge2 = Edge(node2, node3, attributes={"edge_text": "B"})
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_distances = [
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0
[MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1
]
await graph.map_vector_distances_to_graph_edges(
edge_distances=edge_distances, query_list_length=2
)
assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5]
assert graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2]
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(setup_graph):
"""Test that unmapped indices in multi-query mode stay at default penalty."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
edge1 = Edge(node1, node2, attributes={"edge_text": "A"})
edge2 = Edge(node2, node3, attributes={"edge_text": "B"})
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_distances = [
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped
[], # query 1: no edges mapped
]
await graph.map_vector_distances_to_graph_edges(
edge_distances=edge_distances, query_list_length=2
)
assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5]
assert graph.edges[1].attributes.get("vector_distance") == [3.5, 3.5]
@pytest.mark.asyncio
@ -432,10 +564,10 @@ async def test_calculate_top_triplet_importances(setup_graph):
node3 = Node("3")
node4 = Node("4")
node1.add_attribute("vector_distance", 0.9)
node2.add_attribute("vector_distance", 0.8)
node3.add_attribute("vector_distance", 0.7)
node4.add_attribute("vector_distance", 0.6)
node1.add_attribute("vector_distance", [0.9])
node2.add_attribute("vector_distance", [0.8])
node3.add_attribute("vector_distance", [0.7])
node4.add_attribute("vector_distance", [0.6])
graph.add_node(node1)
graph.add_node(node2)
@ -446,9 +578,9 @@ async def test_calculate_top_triplet_importances(setup_graph):
edge2 = Edge(node2, node3)
edge3 = Edge(node3, node4)
edge1.add_attribute("vector_distance", 0.85)
edge2.add_attribute("vector_distance", 0.75)
edge3.add_attribute("vector_distance", 0.65)
edge1.add_attribute("vector_distance", [0.85])
edge2.add_attribute("vector_distance", [0.75])
edge3.add_attribute("vector_distance", [0.65])
graph.add_edge(edge1)
graph.add_edge(edge2)
@ -464,7 +596,7 @@ async def test_calculate_top_triplet_importances(setup_graph):
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
"""Test calculating importances when nodes/edges have no vector distances."""
"""Test that vector_distance stays None when no distances are passed and calculate_top_triplet_importances handles it."""
graph = setup_graph
node1 = Node("1")
@ -475,7 +607,114 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph):
edge = Edge(node1, node2)
graph.add_edge(edge)
top_triplets = await graph.calculate_top_triplet_importances(k=1)
# Verify vector_distance is None when no distances are passed
assert node1.attributes.get("vector_distance") is None
assert node2.attributes.get("vector_distance") is None
assert edge.attributes.get("vector_distance") is None
assert len(top_triplets) == 1
assert top_triplets[0] == edge
# When no distances are set, calculate_top_triplet_importances should handle None
# by either raising an error or skipping edges with None distances
with pytest.raises(ValueError):
await graph.calculate_top_triplet_importances(k=1)
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances_single_query_via_helper(setup_graph):
"""Test calculating top triplet importances for a single query index."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
node1.add_attribute("vector_distance", [0.1])
node2.add_attribute("vector_distance", [0.2])
node3.add_attribute("vector_distance", [0.3])
edge1 = Edge(node1, node2)
edge2 = Edge(node2, node3)
graph.add_edge(edge1)
graph.add_edge(edge2)
edge1.add_attribute("vector_distance", [0.3])
edge2.add_attribute("vector_distance", [0.4])
results = await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
assert len(results) == 1
assert len(results[0]) == 1
assert results[0][0] == edge1
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances_multi_query(setup_graph):
"""Test calculating top triplet importances with multiple queries."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
edge_a = Edge(node1, node2)
edge_b = Edge(node2, node3)
graph.add_edge(edge_a)
graph.add_edge(edge_b)
node1.add_attribute("vector_distance", [0.1, 0.9])
node2.add_attribute("vector_distance", [0.1, 0.9])
node3.add_attribute("vector_distance", [0.9, 0.1])
edge_a.add_attribute("vector_distance", [0.1, 0.9])
edge_b.add_attribute("vector_distance", [0.9, 0.1])
results = await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
assert len(results) == 2
assert results[0][0] == edge_a
assert results[1][0] == edge_b
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances_raises_on_short_list(setup_graph):
"""Test that scoring raises ValueError when list is too short for query_index."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
node1.add_attribute("vector_distance", [0.1])
node2.add_attribute("vector_distance", [0.2])
edge = Edge(node1, node2)
edge.add_attribute("vector_distance", [0.3])
graph.add_edge(edge)
with pytest.raises(ValueError):
await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances_raises_on_missing_attribute(setup_graph):
"""Test that scoring raises error when vector_distance is missing."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
del node1.attributes["vector_distance"]
del node2.attributes["vector_distance"]
edge = Edge(node1, node2)
del edge.attributes["vector_distance"]
graph.add_edge(edge)
with pytest.raises(ValueError):
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)