feat: list vector distance in cogneegraph (#1926)
<!-- .github/pull_request_template.md -->
## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->
- `map_vector_distances_to_graph_nodes` and
`map_vector_distances_to_graph_edges` accept both single-query (flat
list) and multi-query (nested list) inputs.
- `query_list_length` controls the mode: omit it for single-query
behavior, or provide it to enable multi-query mode with strict length
validation and per-query results.
- `vector_distance` on `Node` and `Edge` is now a list (one distance per
query). Constructors set it to `None`, and `reset_distances` initializes
it at the start of each search.
- `Node.update_distance_for_query` and `Edge.update_distance_for_query`
are the only methods that write to `vector_distance`. They ensure the
list has enough elements and keep unmatched queries at the penalty
value.
- `triplet_distance_penalty` is the default distance value used
everywhere. Unmatched nodes/edges and missing scores all use this same
penalty for consistency.
- `edges_by_distance_key` is an index mapping edge labels to matching
edges. This lets us update all edges with the same label at once,
instead of scanning the full edge list repeatedly.
- `calculate_top_triplet_importances` returns `List[Edge]` for
single-query mode and `List[List[Edge]]` for multi-query mode.
## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to
test it locally;
* Proof that it's sufficiently tested.
-->
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Multi-query support for mapping/scoring node and edge distances and a
configurable triplet distance penalty.
* Distance-keyed edge indexing for more accurate distance-to-edge
matching.
* **Refactor**
* Vector distance metadata changed from scalars to per-query lists;
added reset/normalization and per-query update flows.
* Node/edge distance initialization now supports deferred/listed
distances.
* **Tests**
* Updated and expanded tests for multi-query flows, list-based
distances, edge-key handling, and related error cases.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
commit
310e9e97ae
5 changed files with 538 additions and 68 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import time
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import List, Dict, Union, Optional, Type
|
||||
from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any
|
||||
|
||||
from cognee.modules.graph.exceptions import (
|
||||
EntityNotFoundError,
|
||||
|
|
@ -25,12 +25,16 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
|
||||
nodes: Dict[str, Node]
|
||||
edges: List[Edge]
|
||||
edges_by_distance_key: Dict[str, List[Edge]]
|
||||
directed: bool
|
||||
triplet_distance_penalty: float
|
||||
|
||||
def __init__(self, directed: bool = True):
|
||||
self.nodes = {}
|
||||
self.edges = []
|
||||
self.edges_by_distance_key = {}
|
||||
self.directed = directed
|
||||
self.triplet_distance_penalty = 3.5
|
||||
|
||||
def add_node(self, node: Node) -> None:
|
||||
if node.id not in self.nodes:
|
||||
|
|
@ -42,6 +46,12 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
self.edges.append(edge)
|
||||
edge.node1.add_skeleton_edge(edge)
|
||||
edge.node2.add_skeleton_edge(edge)
|
||||
key = edge.get_distance_key()
|
||||
if not key:
|
||||
return
|
||||
if key not in self.edges_by_distance_key:
|
||||
self.edges_by_distance_key[key] = []
|
||||
self.edges_by_distance_key[key].append(edge)
|
||||
|
||||
def get_node(self, node_id: str) -> Node:
|
||||
return self.nodes.get(node_id, None)
|
||||
|
|
@ -56,6 +66,29 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
def get_edges(self) -> List[Edge]:
|
||||
return self.edges
|
||||
|
||||
def reset_distances(self, collection: Iterable[Union[Node, Edge]], query_count: int) -> None:
|
||||
"""Reset vector distances for a collection of nodes or edges."""
|
||||
for item in collection:
|
||||
item.reset_vector_distances(query_count, self.triplet_distance_penalty)
|
||||
|
||||
def _normalize_query_distance_lists(
|
||||
self, distances: List, query_list_length: Optional[int] = None, name: str = "distances"
|
||||
) -> List:
|
||||
"""Normalize shape: flat list -> single-query; nested list -> multi-query."""
|
||||
if not distances:
|
||||
return []
|
||||
first_item = distances[0]
|
||||
if isinstance(first_item, (list, tuple)):
|
||||
per_query_lists = distances
|
||||
else:
|
||||
per_query_lists = [distances]
|
||||
if query_list_length is not None and len(per_query_lists) != query_list_length:
|
||||
raise ValueError(
|
||||
f"{name} has {len(per_query_lists)} query lists, "
|
||||
f"but query_list_length is {query_list_length}"
|
||||
)
|
||||
return per_query_lists
|
||||
|
||||
async def _get_nodeset_subgraph(
|
||||
self,
|
||||
adapter,
|
||||
|
|
@ -148,7 +181,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
adapter, memory_fragment_filter
|
||||
)
|
||||
|
||||
import time
|
||||
self.triplet_distance_penalty = triplet_distance_penalty
|
||||
|
||||
start_time = time.time()
|
||||
# Process nodes
|
||||
|
|
@ -200,41 +233,123 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
logger.error(f"Error during graph projection: {str(e)}")
|
||||
raise
|
||||
|
||||
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
|
||||
mapped_nodes = 0
|
||||
for category, scored_results in node_distances.items():
|
||||
for scored_result in scored_results:
|
||||
node_id = str(scored_result.id)
|
||||
score = scored_result.score
|
||||
node = self.get_node(node_id)
|
||||
if node:
|
||||
node.add_attribute("vector_distance", score)
|
||||
mapped_nodes += 1
|
||||
async def map_vector_distances_to_graph_nodes(
|
||||
self,
|
||||
node_distances,
|
||||
query_list_length: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Map vector distances to nodes, supporting single- and multi-query input shapes."""
|
||||
|
||||
async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
|
||||
try:
|
||||
if edge_distances is None:
|
||||
return
|
||||
query_count = query_list_length or 1
|
||||
|
||||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||
self.reset_distances(self.nodes.values(), query_count)
|
||||
|
||||
for edge in self.edges:
|
||||
edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
|
||||
"relationship_type"
|
||||
)
|
||||
distance = embedding_map.get(edge_key, None)
|
||||
if distance is not None:
|
||||
edge.attributes["vector_distance"] = distance
|
||||
for collection_name, scored_results in node_distances.items():
|
||||
if not scored_results:
|
||||
continue
|
||||
|
||||
except Exception as ex:
|
||||
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
|
||||
raise ex
|
||||
per_query_scored_results = self._normalize_query_distance_lists(
|
||||
scored_results, query_list_length, f"Collection '{collection_name}'"
|
||||
)
|
||||
|
||||
async def calculate_top_triplet_importances(self, k: int) -> List[Edge]:
|
||||
def score(edge):
|
||||
n1 = edge.node1.attributes.get("vector_distance", 1)
|
||||
n2 = edge.node2.attributes.get("vector_distance", 1)
|
||||
e = edge.attributes.get("vector_distance", 1)
|
||||
return n1 + n2 + e
|
||||
for query_index, scored_results in enumerate(per_query_scored_results):
|
||||
for result in scored_results:
|
||||
node_id = str(getattr(result, "id", None))
|
||||
if not node_id:
|
||||
continue
|
||||
node = self.get_node(node_id)
|
||||
if node is None:
|
||||
continue
|
||||
score = float(getattr(result, "score", self.triplet_distance_penalty))
|
||||
node.update_distance_for_query(
|
||||
query_index=query_index,
|
||||
score=score,
|
||||
query_count=query_count,
|
||||
default_penalty=self.triplet_distance_penalty,
|
||||
)
|
||||
|
||||
async def map_vector_distances_to_graph_edges(
|
||||
self,
|
||||
edge_distances,
|
||||
query_list_length: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Map vector distances to graph edges, supporting single- and multi-query input shapes."""
|
||||
query_count = query_list_length or 1
|
||||
|
||||
self.reset_distances(self.edges, query_count)
|
||||
|
||||
if not edge_distances:
|
||||
return None
|
||||
|
||||
per_query_scored_results = self._normalize_query_distance_lists(
|
||||
edge_distances, query_list_length, "edge_distances"
|
||||
)
|
||||
|
||||
for query_index, scored_results in enumerate(per_query_scored_results):
|
||||
for result in scored_results:
|
||||
payload = getattr(result, "payload", None)
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
text = payload.get("text")
|
||||
if not text:
|
||||
continue
|
||||
matching_edges = self.edges_by_distance_key.get(str(text))
|
||||
if not matching_edges:
|
||||
continue
|
||||
for edge in matching_edges:
|
||||
edge.update_distance_for_query(
|
||||
query_index=query_index,
|
||||
score=float(getattr(result, "score", self.triplet_distance_penalty)),
|
||||
query_count=query_count,
|
||||
default_penalty=self.triplet_distance_penalty,
|
||||
)
|
||||
|
||||
def _calculate_query_top_triplet_importances(
|
||||
self,
|
||||
k: int,
|
||||
query_index: int = 0,
|
||||
) -> List[Edge]:
|
||||
"""Calculate top k triplet importances for a specific query index."""
|
||||
|
||||
def score(edge: Edge) -> float:
|
||||
elements = (
|
||||
(edge.node1, f"node {edge.node1.id}"),
|
||||
(edge.node2, f"node {edge.node2.id}"),
|
||||
(edge, f"edge {edge.node1.id}->{edge.node2.id}"),
|
||||
)
|
||||
|
||||
importances = []
|
||||
for element, label in elements:
|
||||
distances = element.attributes.get("vector_distance")
|
||||
if not isinstance(distances, list) or query_index >= len(distances):
|
||||
raise ValueError(
|
||||
f"{label}: vector_distance must be a list with length > {query_index} "
|
||||
f"before scoring (got {type(distances).__name__} with length "
|
||||
f"{len(distances) if isinstance(distances, list) else 'n/a'})"
|
||||
)
|
||||
value = distances[query_index]
|
||||
try:
|
||||
importances.append(float(value))
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
f"{label}: vector_distance[{query_index}] must be float-like, "
|
||||
f"got {type(value).__name__}"
|
||||
)
|
||||
|
||||
return sum(importances)
|
||||
|
||||
return heapq.nsmallest(k, self.edges, key=score)
|
||||
|
||||
async def calculate_top_triplet_importances(
|
||||
self, k: int, query_list_length: Optional[int] = None
|
||||
) -> Union[List[Edge], List[List[Edge]]]:
|
||||
"""Calculate top k triplet importances, supporting both single and multi-query modes."""
|
||||
query_count = query_list_length or 1
|
||||
results = [
|
||||
self._calculate_query_top_triplet_importances(k=k, query_index=i)
|
||||
for i in range(query_count)
|
||||
]
|
||||
|
||||
if query_list_length is None:
|
||||
return results[0]
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -30,11 +30,31 @@ class Node:
|
|||
raise InvalidDimensionsError()
|
||||
self.id = node_id
|
||||
self.attributes = attributes if attributes is not None else {}
|
||||
self.attributes["vector_distance"] = node_penalty
|
||||
self.attributes["vector_distance"] = None
|
||||
self.skeleton_neighbours = []
|
||||
self.skeleton_edges = []
|
||||
self.status = np.ones(dimension, dtype=int)
|
||||
|
||||
def reset_vector_distances(self, query_count: int, default_penalty: float) -> None:
|
||||
self.attributes["vector_distance"] = [default_penalty] * query_count
|
||||
|
||||
def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]:
|
||||
distances = self.attributes.get("vector_distance")
|
||||
if not isinstance(distances, list) or len(distances) != query_count:
|
||||
distances = [default_penalty] * query_count
|
||||
self.attributes["vector_distance"] = distances
|
||||
return distances
|
||||
|
||||
def update_distance_for_query(
|
||||
self,
|
||||
query_index: int,
|
||||
score: float,
|
||||
query_count: int,
|
||||
default_penalty: float,
|
||||
) -> None:
|
||||
distances = self.ensure_vector_distance_list(query_count, default_penalty)
|
||||
distances[query_index] = score
|
||||
|
||||
def add_skeleton_neighbor(self, neighbor: "Node") -> None:
|
||||
if neighbor not in self.skeleton_neighbours:
|
||||
self.skeleton_neighbours.append(neighbor)
|
||||
|
|
@ -116,10 +136,36 @@ class Edge:
|
|||
self.node1 = node1
|
||||
self.node2 = node2
|
||||
self.attributes = attributes if attributes is not None else {}
|
||||
self.attributes["vector_distance"] = edge_penalty
|
||||
self.attributes["vector_distance"] = None
|
||||
self.directed = directed
|
||||
self.status = np.ones(dimension, dtype=int)
|
||||
|
||||
def get_distance_key(self) -> Optional[str]:
|
||||
key = self.attributes.get("edge_text") or self.attributes.get("relationship_type")
|
||||
if key is None:
|
||||
return None
|
||||
return str(key)
|
||||
|
||||
def reset_vector_distances(self, query_count: int, default_penalty: float) -> None:
|
||||
self.attributes["vector_distance"] = [default_penalty] * query_count
|
||||
|
||||
def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]:
|
||||
distances = self.attributes.get("vector_distance")
|
||||
if not isinstance(distances, list) or len(distances) != query_count:
|
||||
distances = [default_penalty] * query_count
|
||||
self.attributes["vector_distance"] = distances
|
||||
return distances
|
||||
|
||||
def update_distance_for_query(
|
||||
self,
|
||||
query_index: int,
|
||||
score: float,
|
||||
query_count: int,
|
||||
default_penalty: float,
|
||||
) -> None:
|
||||
distances = self.ensure_vector_distance_list(query_count, default_penalty)
|
||||
distances[query_index] = score
|
||||
|
||||
def is_edge_alive_in_dimension(self, dimension: int) -> bool:
|
||||
if dimension < 0 or dimension >= len(self.status):
|
||||
raise DimensionOutOfRangeError(dimension=dimension, max_index=len(self.status) - 1)
|
||||
|
|
|
|||
|
|
@ -350,11 +350,41 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
|
|||
assert triplets, f"{name}: Triplets list should not be empty"
|
||||
for edge in triplets:
|
||||
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
|
||||
distance = edge.attributes.get("vector_distance")
|
||||
node1_distance = edge.node1.attributes.get("vector_distance")
|
||||
node2_distance = edge.node2.attributes.get("vector_distance")
|
||||
assert isinstance(distance, float), f"{name}: vector_distance should be float"
|
||||
vector_distances = edge.attributes.get("vector_distance")
|
||||
assert vector_distances is not None, (
|
||||
f"{name}: vector_distance should be set when retrievers return results"
|
||||
)
|
||||
assert isinstance(vector_distances, list) and vector_distances, (
|
||||
f"{name}: vector_distance should be a non-empty list"
|
||||
)
|
||||
distance = vector_distances[0]
|
||||
assert isinstance(distance, float), (
|
||||
f"{name}: vector_distance[0] should be float, got {type(distance)}"
|
||||
)
|
||||
assert 0 <= distance <= 1
|
||||
|
||||
node1_distances = edge.node1.attributes.get("vector_distance")
|
||||
node2_distances = edge.node2.attributes.get("vector_distance")
|
||||
assert node1_distances is not None, (
|
||||
f"{name}: node1 vector_distance should be set when retrievers return results"
|
||||
)
|
||||
assert node2_distances is not None, (
|
||||
f"{name}: node2 vector_distance should be set when retrievers return results"
|
||||
)
|
||||
assert isinstance(node1_distances, list) and node1_distances, (
|
||||
f"{name}: node1 vector_distance should be a non-empty list"
|
||||
)
|
||||
assert isinstance(node2_distances, list) and node2_distances, (
|
||||
f"{name}: node2 vector_distance should be a non-empty list"
|
||||
)
|
||||
node1_distance = node1_distances[0]
|
||||
node2_distance = node2_distances[0]
|
||||
assert isinstance(node1_distance, float), (
|
||||
f"{name}: node1 vector_distance[0] should be float, got {type(node1_distance)}"
|
||||
)
|
||||
assert isinstance(node2_distance, float), (
|
||||
f"{name}: node2 vector_distance[0] should be float, got {type(node2_distance)}"
|
||||
)
|
||||
assert 0 <= node1_distance <= 1
|
||||
assert 0 <= node2_distance <= 1
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ def test_node_initialization():
|
|||
"""Test that a Node is initialized correctly."""
|
||||
node = Node("node1", {"attr1": "value1"}, dimension=2)
|
||||
assert node.id == "node1"
|
||||
assert node.attributes == {"attr1": "value1", "vector_distance": 3.5}
|
||||
assert node.attributes == {"attr1": "value1", "vector_distance": None}
|
||||
assert len(node.status) == 2
|
||||
assert np.all(node.status == 1)
|
||||
|
||||
|
|
@ -86,6 +86,46 @@ def test_node_hash():
|
|||
assert hash(node) == hash("node1")
|
||||
|
||||
|
||||
def test_node_vector_distance_stays_none():
|
||||
"""Test that vector_distance remains None when no distances are passed."""
|
||||
node = Node("node1")
|
||||
assert node.attributes.get("vector_distance") is None
|
||||
|
||||
# Verify it stays None even after other operations
|
||||
node.add_attribute("other_attr", "value")
|
||||
assert node.attributes.get("vector_distance") is None
|
||||
|
||||
|
||||
def test_node_vector_distance_with_custom_attributes():
|
||||
"""Test that vector_distance is None even when node has custom attributes."""
|
||||
node = Node("node1", {"custom": "value", "another": 42})
|
||||
assert node.attributes.get("vector_distance") is None
|
||||
assert node.attributes["custom"] == "value"
|
||||
assert node.attributes["another"] == 42
|
||||
|
||||
|
||||
def test_edge_vector_distance_stays_none():
|
||||
"""Test that vector_distance remains None when no distances are passed."""
|
||||
node1 = Node("node1")
|
||||
node2 = Node("node2")
|
||||
edge = Edge(node1, node2)
|
||||
assert edge.attributes.get("vector_distance") is None
|
||||
|
||||
# Verify it stays None even after other operations
|
||||
edge.add_attribute("other_attr", "value")
|
||||
assert edge.attributes.get("vector_distance") is None
|
||||
|
||||
|
||||
def test_edge_vector_distance_with_custom_attributes():
|
||||
"""Test that vector_distance is None even when edge has custom attributes."""
|
||||
node1 = Node("node1")
|
||||
node2 = Node("node2")
|
||||
edge = Edge(node1, node2, {"weight": 5, "type": "test"})
|
||||
assert edge.attributes.get("vector_distance") is None
|
||||
assert edge.attributes["weight"] == 5
|
||||
assert edge.attributes["type"] == "test"
|
||||
|
||||
|
||||
### Tests for Edge ###
|
||||
|
||||
|
||||
|
|
@ -96,7 +136,7 @@ def test_edge_initialization():
|
|||
edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
|
||||
assert edge.node1 == node1
|
||||
assert edge.node2 == node2
|
||||
assert edge.attributes == {"vector_distance": 3.5, "weight": 10}
|
||||
assert edge.attributes == {"vector_distance": None, "weight": 10}
|
||||
assert edge.directed is False
|
||||
assert len(edge.status) == 2
|
||||
assert np.all(edge.status == 1)
|
||||
|
|
|
|||
|
|
@ -200,6 +200,37 @@ async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_stores_triplet_penalty_on_graph(mock_adapter):
|
||||
"""Test that project_graph_from_db stores triplet_distance_penalty on the graph."""
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
|
||||
nodes_data = [("1", {"name": "Node1"})]
|
||||
edges_data = [("1", "1", "SELF", {})]
|
||||
|
||||
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
||||
|
||||
graph = CogneeGraph()
|
||||
custom_penalty = 5.0
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name"],
|
||||
edge_properties_to_project=[],
|
||||
triplet_distance_penalty=custom_penalty,
|
||||
)
|
||||
|
||||
assert graph.triplet_distance_penalty == custom_penalty
|
||||
|
||||
graph2 = CogneeGraph()
|
||||
await graph2.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name"],
|
||||
edge_properties_to_project=[],
|
||||
)
|
||||
|
||||
assert graph2.triplet_distance_penalty == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
|
||||
"""Test that edges referencing missing nodes raise error."""
|
||||
|
|
@ -241,8 +272,8 @@ async def test_map_vector_distances_to_graph_nodes(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -266,9 +297,9 @@ async def test_map_vector_distances_partial_node_coverage(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == 3.5
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -298,10 +329,36 @@ async def test_map_vector_distances_multiple_categories(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == 0.92
|
||||
assert graph.get_node("4").attributes.get("vector_distance") == 3.5
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == [0.92]
|
||||
assert graph.get_node("4").attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_nodes_multi_query(setup_graph):
|
||||
"""Test mapping vector distances with multiple queries."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
node_distances = {
|
||||
"Entity_name": [
|
||||
[MockScoredResult("1", 0.95)], # query 0
|
||||
[MockScoredResult("2", 0.87)], # query 1
|
||||
]
|
||||
}
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances, query_list_length=2)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == [0.95, 3.5]
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == [3.5, 0.87]
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == [3.5, 3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -327,7 +384,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [0.92]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -353,8 +410,8 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
assert graph.edges[1].attributes.get("vector_distance") == 3.5
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [0.92]
|
||||
assert graph.edges[1].attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -380,7 +437,7 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.85
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [0.85]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -406,12 +463,12 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_none_returns_early(setup_graph):
|
||||
"""Test that edge_distances=None returns early without error."""
|
||||
"""Test that edge_distances=None returns early without error and vector_distance is set to default penalty."""
|
||||
graph = setup_graph
|
||||
graph.add_node(Node("1"))
|
||||
graph.add_node(Node("2"))
|
||||
|
|
@ -419,7 +476,82 @@ async def test_map_vector_distances_none_returns_early(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_empty_nodes_returns_early(setup_graph):
|
||||
"""Test that node_distances={} returns early without error and vector_distance is set to default penalty."""
|
||||
graph = setup_graph
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes({})
|
||||
|
||||
assert node1.attributes.get("vector_distance") == [3.5]
|
||||
assert node2.attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
|
||||
"""Test mapping edge distances with multiple queries."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
edge1 = Edge(node1, node2, attributes={"edge_text": "A"})
|
||||
edge2 = Edge(node2, node3, attributes={"edge_text": "B"})
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_distances = [
|
||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0
|
||||
[MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
edge_distances=edge_distances, query_list_length=2
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5]
|
||||
assert graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(setup_graph):
|
||||
"""Test that unmapped indices in multi-query mode stay at default penalty."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
edge1 = Edge(node1, node2, attributes={"edge_text": "A"})
|
||||
edge2 = Edge(node2, node3, attributes={"edge_text": "B"})
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_distances = [
|
||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped
|
||||
[], # query 1: no edges mapped
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
edge_distances=edge_distances, query_list_length=2
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5]
|
||||
assert graph.edges[1].attributes.get("vector_distance") == [3.5, 3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -432,10 +564,10 @@ async def test_calculate_top_triplet_importances(setup_graph):
|
|||
node3 = Node("3")
|
||||
node4 = Node("4")
|
||||
|
||||
node1.add_attribute("vector_distance", 0.9)
|
||||
node2.add_attribute("vector_distance", 0.8)
|
||||
node3.add_attribute("vector_distance", 0.7)
|
||||
node4.add_attribute("vector_distance", 0.6)
|
||||
node1.add_attribute("vector_distance", [0.9])
|
||||
node2.add_attribute("vector_distance", [0.8])
|
||||
node3.add_attribute("vector_distance", [0.7])
|
||||
node4.add_attribute("vector_distance", [0.6])
|
||||
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
|
@ -446,9 +578,9 @@ async def test_calculate_top_triplet_importances(setup_graph):
|
|||
edge2 = Edge(node2, node3)
|
||||
edge3 = Edge(node3, node4)
|
||||
|
||||
edge1.add_attribute("vector_distance", 0.85)
|
||||
edge2.add_attribute("vector_distance", 0.75)
|
||||
edge3.add_attribute("vector_distance", 0.65)
|
||||
edge1.add_attribute("vector_distance", [0.85])
|
||||
edge2.add_attribute("vector_distance", [0.75])
|
||||
edge3.add_attribute("vector_distance", [0.65])
|
||||
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
|
@ -464,7 +596,7 @@ async def test_calculate_top_triplet_importances(setup_graph):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
||||
"""Test calculating importances when nodes/edges have no vector distances."""
|
||||
"""Test that vector_distance stays None when no distances are passed and calculate_top_triplet_importances handles it."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
|
|
@ -475,7 +607,114 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
|||
edge = Edge(node1, node2)
|
||||
graph.add_edge(edge)
|
||||
|
||||
top_triplets = await graph.calculate_top_triplet_importances(k=1)
|
||||
# Verify vector_distance is None when no distances are passed
|
||||
assert node1.attributes.get("vector_distance") is None
|
||||
assert node2.attributes.get("vector_distance") is None
|
||||
assert edge.attributes.get("vector_distance") is None
|
||||
|
||||
assert len(top_triplets) == 1
|
||||
assert top_triplets[0] == edge
|
||||
# When no distances are set, calculate_top_triplet_importances should handle None
|
||||
# by either raising an error or skipping edges with None distances
|
||||
with pytest.raises(ValueError):
|
||||
await graph.calculate_top_triplet_importances(k=1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_single_query_via_helper(setup_graph):
|
||||
"""Test calculating top triplet importances for a single query index."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
node1.add_attribute("vector_distance", [0.1])
|
||||
node2.add_attribute("vector_distance", [0.2])
|
||||
node3.add_attribute("vector_distance", [0.3])
|
||||
|
||||
edge1 = Edge(node1, node2)
|
||||
edge2 = Edge(node2, node3)
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge1.add_attribute("vector_distance", [0.3])
|
||||
edge2.add_attribute("vector_distance", [0.4])
|
||||
|
||||
results = await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
||||
assert len(results) == 1
|
||||
assert len(results[0]) == 1
|
||||
assert results[0][0] == edge1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_multi_query(setup_graph):
|
||||
"""Test calculating top triplet importances with multiple queries."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
edge_a = Edge(node1, node2)
|
||||
edge_b = Edge(node2, node3)
|
||||
graph.add_edge(edge_a)
|
||||
graph.add_edge(edge_b)
|
||||
|
||||
node1.add_attribute("vector_distance", [0.1, 0.9])
|
||||
node2.add_attribute("vector_distance", [0.1, 0.9])
|
||||
node3.add_attribute("vector_distance", [0.9, 0.1])
|
||||
edge_a.add_attribute("vector_distance", [0.1, 0.9])
|
||||
edge_b.add_attribute("vector_distance", [0.9, 0.1])
|
||||
|
||||
results = await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0][0] == edge_a
|
||||
assert results[1][0] == edge_b
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_raises_on_short_list(setup_graph):
|
||||
"""Test that scoring raises ValueError when list is too short for query_index."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
node1.add_attribute("vector_distance", [0.1])
|
||||
node2.add_attribute("vector_distance", [0.2])
|
||||
|
||||
edge = Edge(node1, node2)
|
||||
edge.add_attribute("vector_distance", [0.3])
|
||||
graph.add_edge(edge)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_raises_on_missing_attribute(setup_graph):
|
||||
"""Test that scoring raises error when vector_distance is missing."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
del node1.attributes["vector_distance"]
|
||||
del node2.attributes["vector_distance"]
|
||||
|
||||
edge = Edge(node1, node2)
|
||||
del edge.attributes["vector_distance"]
|
||||
graph.add_edge(edge)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue