fix: improve graph distance mapping
This commit is contained in:
parent
46ff01021a
commit
c1ea7a8cc2
5 changed files with 210 additions and 95 deletions
|
|
@ -25,12 +25,14 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
|
||||
nodes: Dict[str, Node]
|
||||
edges: List[Edge]
|
||||
edges_by_distance_key: Dict[str, List[Edge]]
|
||||
directed: bool
|
||||
triplet_distance_penalty: float
|
||||
|
||||
def __init__(self, directed: bool = True):
|
||||
self.nodes = {}
|
||||
self.edges = []
|
||||
self.edges_by_distance_key = {}
|
||||
self.directed = directed
|
||||
self.triplet_distance_penalty = 3.5
|
||||
|
||||
|
|
@ -44,6 +46,12 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
self.edges.append(edge)
|
||||
edge.node1.add_skeleton_edge(edge)
|
||||
edge.node2.add_skeleton_edge(edge)
|
||||
key = edge.get_distance_key()
|
||||
if not key:
|
||||
return
|
||||
if key not in self.edges_by_distance_key:
|
||||
self.edges_by_distance_key[key] = []
|
||||
self.edges_by_distance_key[key].append(edge)
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[Node]:
    """Look up a node by its id.

    Args:
        node_id: Key to look up in ``self.nodes``.

    Returns:
        The stored node, or ``None`` when nothing is registered under
        ``node_id``.
    """
    # dict.get already yields None for a missing key.
    return self.nodes.get(node_id)
|
||||
|
|
@ -58,6 +66,29 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
def get_edges(self) -> List[Edge]:
    """Return the graph's edge list (the stored list itself, not a copy)."""
    edges = self.edges
    return edges
|
||||
|
||||
def reset_distances(self, collection: Iterable[Union[Node, Edge]], query_count: int) -> None:
    """Reset the vector distances of every node or edge in ``collection``.

    Each element's ``reset_vector_distances`` is called with ``query_count``
    and the graph's ``triplet_distance_penalty``.
    """
    penalty = self.triplet_distance_penalty
    for element in collection:
        element.reset_vector_distances(query_count, penalty)
|
||||
|
||||
def _normalize_query_distance_lists(
|
||||
self, distances: List, query_list_length: Optional[int] = None, name: str = "distances"
|
||||
) -> List:
|
||||
"""Normalize shape: flat list -> single-query; nested list -> multi-query."""
|
||||
if not distances:
|
||||
return []
|
||||
first_item = distances[0]
|
||||
if isinstance(first_item, (list, tuple)):
|
||||
per_query_lists = distances
|
||||
else:
|
||||
per_query_lists = [distances]
|
||||
if query_list_length is not None and len(per_query_lists) != query_list_length:
|
||||
raise ValueError(
|
||||
f"{name} has {len(per_query_lists)} query lists, "
|
||||
f"but query_list_length is {query_list_length}"
|
||||
)
|
||||
return per_query_lists
|
||||
|
||||
async def _get_nodeset_subgraph(
|
||||
self,
|
||||
adapter,
|
||||
|
|
@ -204,109 +235,81 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
logger.error(f"Error during graph projection: {str(e)}")
|
||||
raise
|
||||
|
||||
def _initialize_vector_distance(self, graph_elements, query_list_length=None) -> None:
|
||||
"""Initialize vector_distance as a list of default penalties for all graph elements."""
|
||||
query_count = query_list_length or 1
|
||||
for element in graph_elements:
|
||||
element.attributes["vector_distance"] = [self.triplet_distance_penalty] * query_count
|
||||
|
||||
def _normalize_query_input(self, distance_data, query_list_length=None, name="input"):
|
||||
"""Normalize single-query or multi-query input to list of lists, return empty list if empty."""
|
||||
if not distance_data:
|
||||
return []
|
||||
normalized = (
|
||||
distance_data if isinstance(distance_data[0], (list, tuple)) else [distance_data]
|
||||
)
|
||||
if query_list_length is not None and len(normalized) != query_list_length:
|
||||
raise ValueError(
|
||||
f"{name} has {len(normalized)} query lists, but query_list_length is {query_list_length}"
|
||||
)
|
||||
return normalized
|
||||
|
||||
def _apply_vector_distance_updates(
|
||||
self,
|
||||
element_distances,
|
||||
query_index: int,
|
||||
get_element: Callable[[str], Optional[Union[Node, Edge]]],
|
||||
get_id_and_score: Callable[[Any], Tuple[Optional[str], Optional[float]]],
|
||||
) -> None:
|
||||
"""Apply updates into element.attributes["vector_distance"][query_index]."""
|
||||
for res in element_distances:
|
||||
key, score = get_id_and_score(res)
|
||||
if key is None or score is None:
|
||||
continue
|
||||
element = get_element(key)
|
||||
if element is None:
|
||||
continue
|
||||
element.attributes["vector_distance"][query_index] = score
|
||||
|
||||
def _get_node_id_and_score(self, res: Any) -> Tuple[str, float]:
|
||||
"""Extract node ID and score from a scored result."""
|
||||
return str(res.id), float(res.score)
|
||||
|
||||
def _get_edge_id_and_score(self, res: Any) -> Tuple[Optional[str], Optional[float]]:
|
||||
"""Extract edge key and score from a scored result."""
|
||||
payload = getattr(res, "payload", None)
|
||||
if not payload:
|
||||
return None, None
|
||||
text = payload.get("text")
|
||||
if text is None:
|
||||
return None, None
|
||||
return str(text), float(res.score)
|
||||
|
||||
async def map_vector_distances_to_graph_nodes(
    self,
    node_distances,
    query_list_length: Optional[int] = None,
) -> None:
    """Map vector distances to nodes, supporting single- and multi-query input shapes.

    Args:
        node_distances: Mapping of collection name to scored results. Each
            value is either a flat list (single query) or a list of lists
            (one list per query).
        query_list_length: Expected number of queries; treated as 1 when
            omitted.

    Raises:
        ValueError: Propagated from ``_normalize_query_distance_lists`` when a
            collection's number of query lists differs from
            ``query_list_length``.
    """
    if not node_distances:
        # Nothing to map: node attributes are left untouched.
        return None

    query_count = query_list_length or 1
    penalty = self.triplet_distance_penalty

    # Reset all node distances for this search
    self.reset_distances(self.nodes.values(), query_count)

    for collection_name, scored_results in node_distances.items():
        if not scored_results:
            continue

        per_query_lists = self._normalize_query_distance_lists(
            scored_results, query_list_length, f"Collection '{collection_name}'"
        )

        for query_index, scored_list in enumerate(per_query_lists):
            for result in scored_list:
                raw_id = getattr(result, "id", None)
                if raw_id is None:
                    # Check before str(): str(None) is the truthy string
                    # "None" and would be looked up as a real node id.
                    continue
                node_id = str(raw_id)
                if not node_id:
                    continue
                node = self.get_node(node_id)
                if node is None:
                    continue
                node.update_distance_for_query(
                    query_index=query_index,
                    score=float(getattr(result, "score", penalty)),
                    query_count=query_count,
                    default_penalty=penalty,
                )
|
||||
|
||||
async def map_vector_distances_to_graph_edges(
    self,
    edge_distances,
    query_list_length: Optional[int] = None,
) -> None:
    """Map vector distances to graph edges, supporting single- and multi-query input shapes.

    Args:
        edge_distances: Scored results, either a flat list (single query) or
            a list of lists (one list per query). Each result is matched to
            edges through ``payload["text"]`` and ``self.edges_by_distance_key``.
        query_list_length: Expected number of queries; treated as 1 when
            omitted.

    Raises:
        ValueError: Propagated from ``_normalize_query_distance_lists`` when
            the number of query lists differs from ``query_list_length``.
    """
    if not edge_distances:
        # Nothing to map: edge attributes are left untouched.
        return None

    query_count = query_list_length or 1
    penalty = self.triplet_distance_penalty

    # Reset all edge distances for this search
    self.reset_distances(self.edges, query_count)

    per_query_edge_lists = self._normalize_query_distance_lists(
        edge_distances, query_list_length, "edge_distances"
    )

    # For each query, apply distances to all matching edges
    for query_index, scored_list in enumerate(per_query_edge_lists):
        for result in scored_list:
            payload = getattr(result, "payload", None)
            if not isinstance(payload, dict):
                continue
            text = payload.get("text")
            if not text:
                continue
            matching_edges = self.edges_by_distance_key.get(str(text))
            if not matching_edges:
                continue

            # The score is the same for every edge sharing this key, so it
            # is converted once instead of once per edge.
            score = float(getattr(result, "score", penalty))
            for edge in matching_edges:
                edge.update_distance_for_query(
                    query_index=query_index,
                    score=score,
                    query_count=query_count,
                    default_penalty=penalty,
                )
|
||||
|
||||
def _calculate_query_top_triplet_importances(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -35,6 +35,26 @@ class Node:
|
|||
self.skeleton_edges = []
|
||||
self.status = np.ones(dimension, dtype=int)
|
||||
|
||||
def reset_vector_distances(self, query_count: int, default_penalty: float) -> None:
    """Replace ``attributes["vector_distance"]`` with ``query_count`` default penalties."""
    fresh_distances = [default_penalty] * query_count
    self.attributes["vector_distance"] = fresh_distances
|
||||
|
||||
def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]:
    """Return the node's distance list, rebuilding it when it is unusable.

    The stored ``attributes["vector_distance"]`` is returned as-is when it is
    a list of exactly ``query_count`` entries. Otherwise it is replaced by a
    list of ``query_count`` default penalties, which is stored and returned.
    """
    current = self.attributes.get("vector_distance")
    if isinstance(current, list) and len(current) == query_count:
        return current

    replacement = [default_penalty] * query_count
    self.attributes["vector_distance"] = replacement
    return replacement
|
||||
|
||||
def update_distance_for_query(
    self,
    query_index: int,
    score: float,
    query_count: int,
    default_penalty: float,
) -> None:
    """Store ``score`` at position ``query_index`` of the node's distance list.

    The list is first obtained through ``ensure_vector_distance_list``, so it
    has ``query_count`` entries before the write.
    """
    self.ensure_vector_distance_list(query_count, default_penalty)[query_index] = score
|
||||
|
||||
def add_skeleton_neighbor(self, neighbor: "Node") -> None:
|
||||
if neighbor not in self.skeleton_neighbours:
|
||||
self.skeleton_neighbours.append(neighbor)
|
||||
|
|
@ -120,6 +140,32 @@ class Edge:
|
|||
self.directed = directed
|
||||
self.status = np.ones(dimension, dtype=int)
|
||||
|
||||
def get_distance_key(self) -> Optional[str]:
    """Return the key under which this edge is matched to distance results.

    The ``"edge_text"`` attribute is used when truthy, otherwise
    ``"relationship_type"``. The chosen value is returned as a string, or
    ``None`` when it is ``None``.
    """
    attrs = self.attributes
    key = attrs.get("edge_text") or attrs.get("relationship_type")
    return None if key is None else str(key)
|
||||
|
||||
def reset_vector_distances(self, query_count: int, default_penalty: float) -> None:
    """Replace ``attributes["vector_distance"]`` with ``query_count`` default penalties."""
    fresh_distances = [default_penalty] * query_count
    self.attributes["vector_distance"] = fresh_distances
|
||||
|
||||
def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]:
    """Return the edge's distance list, rebuilding it when it is unusable.

    The stored ``attributes["vector_distance"]`` is returned as-is when it is
    a list of exactly ``query_count`` entries. Otherwise it is replaced by a
    list of ``query_count`` default penalties, which is stored and returned.
    """
    current = self.attributes.get("vector_distance")
    if isinstance(current, list) and len(current) == query_count:
        return current

    replacement = [default_penalty] * query_count
    self.attributes["vector_distance"] = replacement
    return replacement
|
||||
|
||||
def update_distance_for_query(
    self,
    query_index: int,
    score: float,
    query_count: int,
    default_penalty: float,
) -> None:
    """Store ``score`` at position ``query_index`` of the edge's distance list.

    The list is first obtained through ``ensure_vector_distance_list``, so it
    has ``query_count`` entries before the write.
    """
    self.ensure_vector_distance_list(query_count, default_penalty)[query_index] = score
|
||||
|
||||
def is_edge_alive_in_dimension(self, dimension: int) -> bool:
|
||||
if dimension < 0 or dimension >= len(self.status):
|
||||
raise DimensionOutOfRangeError(dimension=dimension, max_index=len(self.status) - 1)
|
||||
|
|
|
|||
|
|
@ -350,7 +350,10 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
|
|||
assert triplets, f"{name}: Triplets list should not be empty"
|
||||
for edge in triplets:
|
||||
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
|
||||
vector_distances = edge.attributes.get("vector_distance", [])
|
||||
vector_distances = edge.attributes.get("vector_distance")
|
||||
assert vector_distances is not None, (
|
||||
f"{name}: vector_distance should be set when retrievers return results"
|
||||
)
|
||||
assert isinstance(vector_distances, list) and vector_distances, (
|
||||
f"{name}: vector_distance should be a non-empty list"
|
||||
)
|
||||
|
|
@ -360,8 +363,14 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
|
|||
)
|
||||
assert 0 <= distance <= 1
|
||||
|
||||
node1_distances = edge.node1.attributes.get("vector_distance", [])
|
||||
node2_distances = edge.node2.attributes.get("vector_distance", [])
|
||||
node1_distances = edge.node1.attributes.get("vector_distance")
|
||||
node2_distances = edge.node2.attributes.get("vector_distance")
|
||||
assert node1_distances is not None, (
|
||||
f"{name}: node1 vector_distance should be set when retrievers return results"
|
||||
)
|
||||
assert node2_distances is not None, (
|
||||
f"{name}: node2 vector_distance should be set when retrievers return results"
|
||||
)
|
||||
assert isinstance(node1_distances, list) and node1_distances, (
|
||||
f"{name}: node1 vector_distance should be a non-empty list"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -86,6 +86,46 @@ def test_node_hash():
|
|||
assert hash(node) == hash("node1")
|
||||
|
||||
|
||||
def test_node_vector_distance_stays_none():
    """A new node has no vector_distance, and adding another attribute does not create one."""
    fresh_node = Node("node1")
    assert fresh_node.attributes.get("vector_distance") is None

    # Writing an unrelated attribute must leave vector_distance unset.
    fresh_node.add_attribute("other_attr", "value")
    assert fresh_node.attributes.get("vector_distance") is None
|
||||
|
||||
|
||||
def test_node_vector_distance_with_custom_attributes():
    """A node built with custom attributes keeps them and has no vector_distance."""
    custom_node = Node("node1", {"custom": "value", "another": 42})
    attrs = custom_node.attributes

    assert attrs.get("vector_distance") is None
    assert attrs["custom"] == "value"
    assert attrs["another"] == 42
|
||||
|
||||
|
||||
def test_edge_vector_distance_stays_none():
    """A new edge has no vector_distance, and adding another attribute does not create one."""
    source, target = Node("node1"), Node("node2")
    fresh_edge = Edge(source, target)
    assert fresh_edge.attributes.get("vector_distance") is None

    # Writing an unrelated attribute must leave vector_distance unset.
    fresh_edge.add_attribute("other_attr", "value")
    assert fresh_edge.attributes.get("vector_distance") is None
|
||||
|
||||
|
||||
def test_edge_vector_distance_with_custom_attributes():
    """An edge built with custom attributes keeps them and has no vector_distance."""
    source, target = Node("node1"), Node("node2")
    custom_edge = Edge(source, target, {"weight": 5, "type": "test"})
    attrs = custom_edge.attributes

    assert attrs.get("vector_distance") is None
    assert attrs["weight"] == 5
    assert attrs["type"] == "test"
|
||||
|
||||
|
||||
### Tests for Edge ###
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -468,7 +468,7 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_none_returns_early(setup_graph):
|
||||
"""Test that edge_distances=None returns early without error."""
|
||||
"""Test that edge_distances=None returns early without error and vector_distance stays None."""
|
||||
graph = setup_graph
|
||||
graph.add_node(Node("1"))
|
||||
graph.add_node(Node("2"))
|
||||
|
|
@ -476,7 +476,22 @@ async def test_map_vector_distances_none_returns_early(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
|
||||
assert graph.edges[0].attributes.get("vector_distance") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_map_vector_distances_empty_nodes_returns_early(setup_graph):
    """Mapping an empty node_distances dict raises nothing and leaves vector_distance unset."""
    graph = setup_graph
    added_nodes = [Node("1"), Node("2")]
    for added in added_nodes:
        graph.add_node(added)

    await graph.map_vector_distances_to_graph_nodes({})

    for added in added_nodes:
        assert added.attributes.get("vector_distance") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -581,7 +596,7 @@ async def test_calculate_top_triplet_importances(setup_graph):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
||||
"""Test calculating importances when nodes/edges have default vector distances."""
|
||||
"""Test that vector_distance stays None when no distances are passed and calculate_top_triplet_importances handles it."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
|
|
@ -592,13 +607,15 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
|||
edge = Edge(node1, node2)
|
||||
graph.add_edge(edge)
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes({})
|
||||
await graph.map_vector_distances_to_graph_edges(None)
|
||||
# Verify vector_distance is None when no distances are passed
|
||||
assert node1.attributes.get("vector_distance") is None
|
||||
assert node2.attributes.get("vector_distance") is None
|
||||
assert edge.attributes.get("vector_distance") is None
|
||||
|
||||
top_triplets = await graph.calculate_top_triplet_importances(k=1)
|
||||
|
||||
assert len(top_triplets) == 1
|
||||
assert top_triplets[0] == edge
|
||||
# When no distances are set, calculate_top_triplet_importances should handle None
|
||||
# by either raising an error or skipping edges with None distances
|
||||
with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"):
|
||||
await graph.calculate_top_triplet_importances(k=1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue