fix: improve graph distance mapping

This commit is contained in:
lxobr 2025-12-18 14:52:35 +01:00
parent 46ff01021a
commit c1ea7a8cc2
5 changed files with 210 additions and 95 deletions

View file

@ -25,12 +25,14 @@ class CogneeGraph(CogneeAbstractGraph):
nodes: Dict[str, Node]
edges: List[Edge]
edges_by_distance_key: Dict[str, List[Edge]]
directed: bool
triplet_distance_penalty: float
def __init__(self, directed: bool = True):
self.nodes = {}
self.edges = []
self.edges_by_distance_key = {}
self.directed = directed
self.triplet_distance_penalty = 3.5
@ -44,6 +46,12 @@ class CogneeGraph(CogneeAbstractGraph):
self.edges.append(edge)
edge.node1.add_skeleton_edge(edge)
edge.node2.add_skeleton_edge(edge)
key = edge.get_distance_key()
if not key:
return
if key not in self.edges_by_distance_key:
self.edges_by_distance_key[key] = []
self.edges_by_distance_key[key].append(edge)
def get_node(self, node_id: str) -> Node:
return self.nodes.get(node_id, None)
@ -58,6 +66,29 @@ class CogneeGraph(CogneeAbstractGraph):
def get_edges(self) -> List[Edge]:
return self.edges
def reset_distances(self, collection: Iterable[Union[Node, Edge]], query_count: int) -> None:
"""Reset vector distances for a collection of nodes or edges."""
for item in collection:
item.reset_vector_distances(query_count, self.triplet_distance_penalty)
def _normalize_query_distance_lists(
self, distances: List, query_list_length: Optional[int] = None, name: str = "distances"
) -> List:
"""Normalize shape: flat list -> single-query; nested list -> multi-query."""
if not distances:
return []
first_item = distances[0]
if isinstance(first_item, (list, tuple)):
per_query_lists = distances
else:
per_query_lists = [distances]
if query_list_length is not None and len(per_query_lists) != query_list_length:
raise ValueError(
f"{name} has {len(per_query_lists)} query lists, "
f"but query_list_length is {query_list_length}"
)
return per_query_lists
async def _get_nodeset_subgraph(
self,
adapter,
@ -204,109 +235,81 @@ class CogneeGraph(CogneeAbstractGraph):
logger.error(f"Error during graph projection: {str(e)}")
raise
def _initialize_vector_distance(self, graph_elements, query_list_length=None) -> None:
"""Initialize vector_distance as a list of default penalties for all graph elements."""
query_count = query_list_length or 1
for element in graph_elements:
element.attributes["vector_distance"] = [self.triplet_distance_penalty] * query_count
def _normalize_query_input(self, distance_data, query_list_length=None, name="input"):
"""Normalize single-query or multi-query input to list of lists, return empty list if empty."""
if not distance_data:
return []
normalized = (
distance_data if isinstance(distance_data[0], (list, tuple)) else [distance_data]
)
if query_list_length is not None and len(normalized) != query_list_length:
raise ValueError(
f"{name} has {len(normalized)} query lists, but query_list_length is {query_list_length}"
)
return normalized
def _apply_vector_distance_updates(
self,
element_distances,
query_index: int,
get_element: Callable[[str], Optional[Union[Node, Edge]]],
get_id_and_score: Callable[[Any], Tuple[Optional[str], Optional[float]]],
) -> None:
"""Apply updates into element.attributes["vector_distance"][query_index]."""
for res in element_distances:
key, score = get_id_and_score(res)
if key is None or score is None:
continue
element = get_element(key)
if element is None:
continue
element.attributes["vector_distance"][query_index] = score
def _get_node_id_and_score(self, res: Any) -> Tuple[str, float]:
"""Extract node ID and score from a scored result."""
return str(res.id), float(res.score)
def _get_edge_id_and_score(self, res: Any) -> Tuple[Optional[str], Optional[float]]:
"""Extract edge key and score from a scored result."""
payload = getattr(res, "payload", None)
if not payload:
return None, None
text = payload.get("text")
if text is None:
return None, None
return str(text), float(res.score)
async def map_vector_distances_to_graph_nodes(
self,
node_distances,
query_list_length: Optional[int] = None,
) -> None:
self._initialize_vector_distance(self.nodes.values(), query_list_length)
"""Map vector distances to nodes, supporting single- and multi-query input shapes."""
if not node_distances:
return None
query_count = query_list_length or 1
# Reset all node distances for this search
self.reset_distances(self.nodes.values(), query_count)
for collection_name, scored_results in node_distances.items():
per_query_lists = self._normalize_query_input(
scored_results, query_list_length, f"Collection '{collection_name}'"
)
if not per_query_lists:
if not scored_results:
continue
per_query_lists = self._normalize_query_distance_lists(
scored_results, query_list_length, f"Collection '{collection_name}'"
)
for query_index, scored_list in enumerate(per_query_lists):
self._apply_vector_distance_updates(
element_distances=scored_list,
query_index=query_index,
get_element=self.nodes.get,
get_id_and_score=self._get_node_id_and_score,
)
for result in scored_list:
node_id = str(getattr(result, "id", None))
if not node_id:
continue
node = self.get_node(node_id)
if node is None:
continue
score = float(getattr(result, "score", self.triplet_distance_penalty))
node.update_distance_for_query(
query_index=query_index,
score=score,
query_count=query_count,
default_penalty=self.triplet_distance_penalty,
)
async def map_vector_distances_to_graph_edges(
self,
edge_distances,
query_list_length: Optional[int] = None,
) -> None:
try:
self._initialize_vector_distance(self.edges, query_list_length)
"""Map vector distances to graph edges, supporting single- and multi-query input shapes."""
if not edge_distances:
return None
normalized_edges = self._normalize_query_input(
edge_distances, query_list_length, "edge_distances"
)
if not normalized_edges:
return
query_count = query_list_length or 1
edges_by_key: Dict[str, Edge] = {}
for edge in self.edges:
key = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
if key:
edges_by_key[str(key)] = edge
# Reset all edge distances for this search
self.reset_distances(self.edges, query_count)
for query_index, scored_list in enumerate(normalized_edges):
self._apply_vector_distance_updates(
element_distances=scored_list,
query_index=query_index,
get_element=edges_by_key.get,
get_id_and_score=self._get_edge_id_and_score,
)
per_query_edge_lists = self._normalize_query_distance_lists(
edge_distances, query_list_length, "edge_distances"
)
except Exception as ex:
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
raise ex
# For each query, apply distances to all matching edges
for query_index, scored_list in enumerate(per_query_edge_lists):
for result in scored_list:
payload = getattr(result, "payload", None)
if not isinstance(payload, dict):
continue
text = payload.get("text")
if not text:
continue
matching_edges = self.edges_by_distance_key.get(str(text))
if not matching_edges:
continue
for edge in matching_edges:
edge.update_distance_for_query(
query_index=query_index,
score=float(getattr(result, "score", self.triplet_distance_penalty)),
query_count=query_count,
default_penalty=self.triplet_distance_penalty,
)
def _calculate_query_top_triplet_importances(
self,

View file

@ -35,6 +35,26 @@ class Node:
self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int)
def reset_vector_distances(self, query_count: int, default_penalty: float) -> None:
self.attributes["vector_distance"] = [default_penalty] * query_count
def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]:
distances = self.attributes.get("vector_distance")
if not isinstance(distances, list) or len(distances) != query_count:
distances = [default_penalty] * query_count
self.attributes["vector_distance"] = distances
return distances
def update_distance_for_query(
self,
query_index: int,
score: float,
query_count: int,
default_penalty: float,
) -> None:
distances = self.ensure_vector_distance_list(query_count, default_penalty)
distances[query_index] = score
def add_skeleton_neighbor(self, neighbor: "Node") -> None:
if neighbor not in self.skeleton_neighbours:
self.skeleton_neighbours.append(neighbor)
@ -120,6 +140,32 @@ class Edge:
self.directed = directed
self.status = np.ones(dimension, dtype=int)
def get_distance_key(self) -> Optional[str]:
key = self.attributes.get("edge_text") or self.attributes.get("relationship_type")
if key is None:
return None
return str(key)
def reset_vector_distances(self, query_count: int, default_penalty: float) -> None:
self.attributes["vector_distance"] = [default_penalty] * query_count
def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]:
distances = self.attributes.get("vector_distance")
if not isinstance(distances, list) or len(distances) != query_count:
distances = [default_penalty] * query_count
self.attributes["vector_distance"] = distances
return distances
def update_distance_for_query(
self,
query_index: int,
score: float,
query_count: int,
default_penalty: float,
) -> None:
distances = self.ensure_vector_distance_list(query_count, default_penalty)
distances[query_index] = score
def is_edge_alive_in_dimension(self, dimension: int) -> bool:
if dimension < 0 or dimension >= len(self.status):
raise DimensionOutOfRangeError(dimension=dimension, max_index=len(self.status) - 1)

View file

@ -350,7 +350,10 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
assert triplets, f"{name}: Triplets list should not be empty"
for edge in triplets:
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
vector_distances = edge.attributes.get("vector_distance", [])
vector_distances = edge.attributes.get("vector_distance")
assert vector_distances is not None, (
f"{name}: vector_distance should be set when retrievers return results"
)
assert isinstance(vector_distances, list) and vector_distances, (
f"{name}: vector_distance should be a non-empty list"
)
@ -360,8 +363,14 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
)
assert 0 <= distance <= 1
node1_distances = edge.node1.attributes.get("vector_distance", [])
node2_distances = edge.node2.attributes.get("vector_distance", [])
node1_distances = edge.node1.attributes.get("vector_distance")
node2_distances = edge.node2.attributes.get("vector_distance")
assert node1_distances is not None, (
f"{name}: node1 vector_distance should be set when retrievers return results"
)
assert node2_distances is not None, (
f"{name}: node2 vector_distance should be set when retrievers return results"
)
assert isinstance(node1_distances, list) and node1_distances, (
f"{name}: node1 vector_distance should be a non-empty list"
)

View file

@ -86,6 +86,46 @@ def test_node_hash():
assert hash(node) == hash("node1")
def test_node_vector_distance_stays_none():
"""Test that vector_distance remains None when no distances are passed."""
node = Node("node1")
assert node.attributes.get("vector_distance") is None
# Verify it stays None even after other operations
node.add_attribute("other_attr", "value")
assert node.attributes.get("vector_distance") is None
def test_node_vector_distance_with_custom_attributes():
"""Test that vector_distance is None even when node has custom attributes."""
node = Node("node1", {"custom": "value", "another": 42})
assert node.attributes.get("vector_distance") is None
assert node.attributes["custom"] == "value"
assert node.attributes["another"] == 42
def test_edge_vector_distance_stays_none():
"""Test that vector_distance remains None when no distances are passed."""
node1 = Node("node1")
node2 = Node("node2")
edge = Edge(node1, node2)
assert edge.attributes.get("vector_distance") is None
# Verify it stays None even after other operations
edge.add_attribute("other_attr", "value")
assert edge.attributes.get("vector_distance") is None
def test_edge_vector_distance_with_custom_attributes():
"""Test that vector_distance is None even when edge has custom attributes."""
node1 = Node("node1")
node2 = Node("node2")
edge = Edge(node1, node2, {"weight": 5, "type": "test"})
assert edge.attributes.get("vector_distance") is None
assert edge.attributes["weight"] == 5
assert edge.attributes["type"] == "test"
### Tests for Edge ###

View file

@ -468,7 +468,7 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
@pytest.mark.asyncio
async def test_map_vector_distances_none_returns_early(setup_graph):
"""Test that edge_distances=None returns early without error."""
"""Test that edge_distances=None returns early without error and vector_distance stays None."""
graph = setup_graph
graph.add_node(Node("1"))
graph.add_node(Node("2"))
@ -476,7 +476,22 @@ async def test_map_vector_distances_none_returns_early(setup_graph):
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
assert graph.edges[0].attributes.get("vector_distance") is None
@pytest.mark.asyncio
async def test_map_vector_distances_empty_nodes_returns_early(setup_graph):
"""Test that node_distances={} returns early without error and vector_distance stays None."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
await graph.map_vector_distances_to_graph_nodes({})
assert node1.attributes.get("vector_distance") is None
assert node2.attributes.get("vector_distance") is None
@pytest.mark.asyncio
@ -581,7 +596,7 @@ async def test_calculate_top_triplet_importances(setup_graph):
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
"""Test calculating importances when nodes/edges have default vector distances."""
"""Test that vector_distance stays None when no distances are passed and calculate_top_triplet_importances handles it."""
graph = setup_graph
node1 = Node("1")
@ -592,13 +607,15 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph):
edge = Edge(node1, node2)
graph.add_edge(edge)
await graph.map_vector_distances_to_graph_nodes({})
await graph.map_vector_distances_to_graph_edges(None)
# Verify vector_distance is None when no distances are passed
assert node1.attributes.get("vector_distance") is None
assert node2.attributes.get("vector_distance") is None
assert edge.attributes.get("vector_distance") is None
top_triplets = await graph.calculate_top_triplet_importances(k=1)
assert len(top_triplets) == 1
assert top_triplets[0] == edge
# When no distances are set, calculate_top_triplet_importances should handle None
# by either raising an error or skipping edges with None distances
with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"):
await graph.calculate_top_triplet_importances(k=1)
@pytest.mark.asyncio