feat: add multi-query support to graph distance mapping
This commit is contained in:
parent
cc7ca45e73
commit
69ab8e7ede
3 changed files with 190 additions and 43 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import time
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import List, Dict, Union, Optional, Type
|
||||
from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any
|
||||
|
||||
from cognee.modules.graph.exceptions import (
|
||||
EntityNotFoundError,
|
||||
|
|
@ -204,31 +204,105 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
logger.error(f"Error during graph projection: {str(e)}")
|
||||
raise
|
||||
|
||||
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
|
||||
mapped_nodes = 0
|
||||
for category, scored_results in node_distances.items():
|
||||
for scored_result in scored_results:
|
||||
node_id = str(scored_result.id)
|
||||
score = scored_result.score
|
||||
node = self.get_node(node_id)
|
||||
if node:
|
||||
node.add_attribute("vector_distance", score)
|
||||
mapped_nodes += 1
|
||||
def _initialize_vector_distance(self, graph_elements, query_list_length=None) -> None:
    """Reset ``vector_distance`` on every element to a per-query list of penalties.

    Args:
        graph_elements: Iterable of objects exposing an ``attributes`` mapping.
        query_list_length: Number of slots to create per element. A falsy value
            (``None`` or ``0``) yields a single slot.
    """
    slots = query_list_length if query_list_length else 1
    for item in graph_elements:
        penalty = self.triplet_distance_penalty
        # A fresh list per element, so later per-query writes never alias.
        item.attributes["vector_distance"] = [penalty for _ in range(slots)]
|
||||
|
||||
async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
|
||||
def _normalize_query_input(self, distance_data, query_list_length=None, name="input"):
    """Coerce single-query or multi-query results into a list of per-query lists.

    Args:
        distance_data: Either a flat sequence of results (one query) or a sequence
            whose first item is a list/tuple (one inner sequence per query).
        query_list_length: When given, the number of per-query lists that must be
            present after normalization.
        name: Label used in the error message.

    Returns:
        ``[]`` for falsy input; the input itself when it is already nested;
        otherwise the input wrapped in a one-element list.

    Raises:
        ValueError: If ``query_list_length`` is given and does not match the number
            of per-query lists.
    """
    if not distance_data:
        return []

    # Only the first item is inspected to decide whether the input is nested.
    first = distance_data[0]
    if isinstance(first, (list, tuple)):
        per_query = distance_data
    else:
        per_query = [distance_data]

    if query_list_length is None:
        return per_query

    found = len(per_query)
    if found != query_list_length:
        raise ValueError(
            f"{name} has {found} query lists, but query_list_length is {query_list_length}"
        )
    return per_query
|
||||
|
||||
def _apply_vector_distance_updates(
    self,
    element_distances,
    query_index: int,
    get_element: Callable[[str], Optional[Union[Node, Edge]]],
    get_id_and_score: Callable[[Any], Tuple[Optional[str], Optional[float]]],
) -> None:
    """Write each resolvable score into slot ``query_index`` of ``vector_distance``.

    Args:
        element_distances: Iterable of scored results for one query.
        query_index: Position inside each element's ``vector_distance`` list to
            overwrite.
        get_element: Lookup returning the element for a key, or ``None``.
        get_id_and_score: Extractor returning ``(key, score)`` for a result; a
            ``None`` in either position makes the result be skipped.
    """
    for scored in element_distances:
        lookup_key, distance = get_id_and_score(scored)
        if lookup_key is not None and distance is not None:
            target = get_element(lookup_key)
            if target is not None:
                target.attributes["vector_distance"][query_index] = distance
|
||||
|
||||
def _get_node_id_and_score(self, res: Any) -> Tuple[str, float]:
    """Return ``(str(res.id), float(res.score))`` for a scored node result."""
    node_id = str(res.id)
    distance = float(res.score)
    return node_id, distance
|
||||
|
||||
def _get_edge_id_and_score(self, res: Any) -> Tuple[Optional[str], Optional[float]]:
    """Return ``(payload text, score)`` for a scored edge result.

    Returns ``(None, None)`` when the result has no ``payload`` attribute, the
    payload is falsy, or the payload's ``"text"`` entry is ``None``.
    """
    missing = (None, None)
    data = getattr(res, "payload", None)
    edge_text = data.get("text") if data else None
    if edge_text is None:
        return missing
    return str(edge_text), float(res.score)
|
||||
|
||||
async def map_vector_distances_to_graph_nodes(
    self,
    node_distances,
    query_list_length: Optional[int] = None,
) -> None:
    """Map per-query vector distances from each collection onto graph nodes.

    All nodes are first reset to the default penalty list, then every collection's
    results are normalized to per-query lists and written into the matching slot
    of each node found in ``self.nodes``.

    Args:
        node_distances: Mapping of collection name to scored results, either a
            flat list (single query) or a list of lists (one per query).
        query_list_length: Expected number of queries; forwarded to the
            initialization and normalization helpers.
    """
    self._initialize_vector_distance(self.nodes.values(), query_list_length)

    for collection, results in node_distances.items():
        label = f"Collection '{collection}'"
        query_batches = self._normalize_query_input(results, query_list_length, label)
        if query_batches:
            for slot, batch in enumerate(query_batches):
                self._apply_vector_distance_updates(
                    element_distances=batch,
                    query_index=slot,
                    get_element=self.nodes.get,
                    get_id_and_score=self._get_node_id_and_score,
                )
|
||||
|
||||
async def map_vector_distances_to_graph_edges(
    self,
    edge_distances,
    query_list_length: Optional[int] = None,
) -> None:
    """Map per-query vector distances onto graph edges.

    Every edge's ``vector_distance`` is first reset to a list of default penalties
    (one slot per query). Each scored result is then matched to edges through its
    payload text, which is compared with the edge's ``edge_text`` attribute or,
    when that is missing or empty, its ``relationship_type`` attribute.

    Args:
        edge_distances: Scored results, either a flat list (single query) or a
            list of lists (one per query). Falsy input leaves all edges at the
            default penalty.
        query_list_length: Expected number of queries; forwarded to the
            initialization and normalization helpers.
    """
    try:
        self._initialize_vector_distance(self.edges, query_list_length)

        normalized_edges = self._normalize_query_input(
            edge_distances, query_list_length, "edge_distances"
        )
        if not normalized_edges:
            return

        # Several edges can share one edge_text / relationship_type. Group them so
        # every matching edge is updated; a key -> single-edge map would let the
        # last edge win and leave the others at the default penalty.
        edges_by_key: Dict[str, List[Edge]] = {}
        for edge in self.edges:
            key = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
            if key:
                edges_by_key.setdefault(str(key), []).append(edge)

        for query_index, scored_list in enumerate(normalized_edges):
            for res in scored_list:
                key, score = self._get_edge_id_and_score(res)
                if key is None or score is None:
                    continue
                for edge in edges_by_key.get(key, ()):
                    edge.attributes["vector_distance"][query_index] = score

    except Exception as ex:
        # NOTE(review): only the logging line is visible in this view; confirm
        # whether the error is re-raised after it.
        logger.error(f"Error mapping vector distances to edges: {str(ex)}")
|
||||
|
|
|
|||
|
|
@ -350,11 +350,32 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
|
|||
assert triplets, f"{name}: Triplets list should not be empty"
|
||||
for edge in triplets:
|
||||
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
|
||||
distance = edge.attributes.get("vector_distance")
|
||||
node1_distance = edge.node1.attributes.get("vector_distance")
|
||||
node2_distance = edge.node2.attributes.get("vector_distance")
|
||||
assert isinstance(distance, float), f"{name}: vector_distance should be float"
|
||||
vector_distances = edge.attributes.get("vector_distance", [])
|
||||
assert isinstance(vector_distances, list) and vector_distances, (
|
||||
f"{name}: vector_distance should be a non-empty list"
|
||||
)
|
||||
distance = vector_distances[0]
|
||||
assert isinstance(distance, float), (
|
||||
f"{name}: vector_distance[0] should be float, got {type(distance)}"
|
||||
)
|
||||
assert 0 <= distance <= 1
|
||||
|
||||
node1_distances = edge.node1.attributes.get("vector_distance", [])
|
||||
node2_distances = edge.node2.attributes.get("vector_distance", [])
|
||||
assert isinstance(node1_distances, list) and node1_distances, (
|
||||
f"{name}: node1 vector_distance should be a non-empty list"
|
||||
)
|
||||
assert isinstance(node2_distances, list) and node2_distances, (
|
||||
f"{name}: node2 vector_distance should be a non-empty list"
|
||||
)
|
||||
node1_distance = node1_distances[0]
|
||||
node2_distance = node2_distances[0]
|
||||
assert isinstance(node1_distance, float), (
|
||||
f"{name}: node1 vector_distance[0] should be float, got {type(node1_distance)}"
|
||||
)
|
||||
assert isinstance(node2_distance, float), (
|
||||
f"{name}: node2 vector_distance[0] should be float, got {type(node2_distance)}"
|
||||
)
|
||||
assert 0 <= node1_distance <= 1
|
||||
assert 0 <= node2_distance <= 1
|
||||
|
||||
|
|
|
|||
|
|
@ -241,12 +241,11 @@ async def test_map_vector_distances_to_graph_nodes(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances")
|
||||
async def test_map_vector_distances_partial_node_coverage(setup_graph):
|
||||
"""Test mapping vector distances when only some nodes have results."""
|
||||
graph = setup_graph
|
||||
|
|
@ -267,13 +266,12 @@ async def test_map_vector_distances_partial_node_coverage(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == 3.5
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances")
|
||||
async def test_map_vector_distances_multiple_categories(setup_graph):
|
||||
"""Test mapping vector distances from multiple collection categories."""
|
||||
graph = setup_graph
|
||||
|
|
@ -300,10 +298,36 @@ async def test_map_vector_distances_multiple_categories(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == 0.92
|
||||
assert graph.get_node("4").attributes.get("vector_distance") == 3.5
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == [0.92]
|
||||
assert graph.get_node("4").attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_nodes_multi_query(setup_graph):
    """Each query's scores land in that query's slot; untouched slots keep the default."""
    graph = setup_graph

    nodes = [Node(node_id) for node_id in ("1", "2", "3")]
    for node in nodes:
        graph.add_node(node)

    first_query_hits = [MockScoredResult("1", 0.95)]
    second_query_hits = [MockScoredResult("2", 0.87)]
    node_distances = {"Entity_name": [first_query_hits, second_query_hits]}

    await graph.map_vector_distances_to_graph_nodes(node_distances, query_list_length=2)

    # 3.5 is what remains in slots no result touched
    # (presumably the fixture graph's default penalty; confirm against setup_graph).
    expected = {"1": [0.95, 3.5], "2": [3.5, 0.87], "3": [3.5, 3.5]}
    for node_id, distances in expected.items():
        assert graph.get_node(node_id).attributes.get("vector_distance") == distances
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -329,11 +353,10 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [0.92]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances")
|
||||
async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
||||
"""Test mapping edge distances when only some edges have results."""
|
||||
graph = setup_graph
|
||||
|
|
@ -356,8 +379,8 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
assert graph.edges[1].attributes.get("vector_distance") == 3.5
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [0.92]
|
||||
assert graph.edges[1].attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -383,11 +406,10 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.85
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [0.85]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances")
|
||||
async def test_map_vector_distances_no_edge_matches(setup_graph):
|
||||
"""Test edge mapping when no edges match the distance results."""
|
||||
graph = setup_graph
|
||||
|
|
@ -410,7 +432,7 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -423,7 +445,37 @@ async def test_map_vector_distances_none_returns_early(setup_graph):
|
|||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
    """Edge scores from each query land in that query's slot; other slots keep the default."""
    graph = setup_graph

    node1, node2, node3 = Node("1"), Node("2"), Node("3")
    for node in (node1, node2, node3):
        graph.add_node(node)

    edges = [
        Edge(node1, node2, attributes={"edge_text": "A"}),
        Edge(node2, node3, attributes={"edge_text": "B"}),
    ]
    for edge in edges:
        graph.add_edge(edge)

    first_query_hits = [MockScoredResult("e1", 0.1, payload={"text": "A"})]
    second_query_hits = [MockScoredResult("e2", 0.2, payload={"text": "B"})]
    edge_distances = [first_query_hits, second_query_hits]

    await graph.map_vector_distances_to_graph_edges(
        edge_distances=edge_distances, query_list_length=2
    )

    # 3.5 is what remains in slots no result touched
    # (presumably the fixture graph's default penalty; confirm against setup_graph).
    expected_by_position = [[0.1, 3.5], [3.5, 0.2]]
    for position, distances in enumerate(expected_by_position):
        assert graph.edges[position].attributes.get("vector_distance") == distances
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue