feat: add multi-query support to score calculation
This commit is contained in:
parent
69ab8e7ede
commit
46ff01021a
2 changed files with 200 additions and 22 deletions
|
|
@ -308,21 +308,33 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
|
||||
raise ex
|
||||
|
||||
def _as_distance(self, value: Union[float, List[float], None]) -> float:
|
||||
"""Normalize distance value to float, handling None, lists, and scalars."""
|
||||
if value is None:
|
||||
return self.triplet_distance_penalty
|
||||
if isinstance(value, list) and value:
|
||||
return float(value[0])
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
return self.triplet_distance_penalty
|
||||
def _calculate_query_top_triplet_importances(
|
||||
self,
|
||||
k: int,
|
||||
query_index: int = 0,
|
||||
) -> List[Edge]:
|
||||
"""Calculate top k triplet importances for a specific query index."""
|
||||
|
||||
async def calculate_top_triplet_importances(self, k: int) -> List[Edge]:
|
||||
def score(edge):
|
||||
n1 = self._as_distance(edge.node1.attributes.get("vector_distance"))
|
||||
n2 = self._as_distance(edge.node2.attributes.get("vector_distance"))
|
||||
e = self._as_distance(edge.attributes.get("vector_distance"))
|
||||
return n1 + n2 + e
|
||||
distances = [
|
||||
edge.node1.attributes.get("vector_distance"),
|
||||
edge.node2.attributes.get("vector_distance"),
|
||||
edge.attributes.get("vector_distance"),
|
||||
]
|
||||
return sum(float(d[query_index]) for d in distances)
|
||||
|
||||
return heapq.nsmallest(k, self.edges, key=score)
|
||||
|
||||
async def calculate_top_triplet_importances(
|
||||
self, k: int, query_list_length: Optional[int] = None
|
||||
) -> Union[List[Edge], List[List[Edge]]]:
|
||||
"""Calculate top k triplet importances, supporting both single and multi-query modes."""
|
||||
query_count = query_list_length or 1
|
||||
results = [
|
||||
self._calculate_query_top_triplet_importances(k=k, query_index=i)
|
||||
for i in range(query_count)
|
||||
]
|
||||
|
||||
if query_list_length is None:
|
||||
return results[0]
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -200,6 +200,37 @@ async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_stores_triplet_penalty_on_graph(mock_adapter):
|
||||
"""Test that project_graph_from_db stores triplet_distance_penalty on the graph."""
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
|
||||
nodes_data = [("1", {"name": "Node1"})]
|
||||
edges_data = [("1", "1", "SELF", {})]
|
||||
|
||||
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
||||
|
||||
graph = CogneeGraph()
|
||||
custom_penalty = 5.0
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name"],
|
||||
edge_properties_to_project=[],
|
||||
triplet_distance_penalty=custom_penalty,
|
||||
)
|
||||
|
||||
assert graph.triplet_distance_penalty == custom_penalty
|
||||
|
||||
graph2 = CogneeGraph()
|
||||
await graph2.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name"],
|
||||
edge_properties_to_project=[],
|
||||
)
|
||||
|
||||
assert graph2.triplet_distance_penalty == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
|
||||
"""Test that edges referencing missing nodes raise error."""
|
||||
|
|
@ -478,6 +509,36 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
|
|||
assert graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(setup_graph):
|
||||
"""Test that unmapped indices in multi-query mode stay at default penalty."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
edge1 = Edge(node1, node2, attributes={"edge_text": "A"})
|
||||
edge2 = Edge(node2, node3, attributes={"edge_text": "B"})
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_distances = [
|
||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped
|
||||
[], # query 1: no edges mapped
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
edge_distances=edge_distances, query_list_length=2
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5]
|
||||
assert graph.edges[1].attributes.get("vector_distance") == [3.5, 3.5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances(setup_graph):
|
||||
"""Test calculating top triplet importances by score."""
|
||||
|
|
@ -488,10 +549,10 @@ async def test_calculate_top_triplet_importances(setup_graph):
|
|||
node3 = Node("3")
|
||||
node4 = Node("4")
|
||||
|
||||
node1.add_attribute("vector_distance", 0.9)
|
||||
node2.add_attribute("vector_distance", 0.8)
|
||||
node3.add_attribute("vector_distance", 0.7)
|
||||
node4.add_attribute("vector_distance", 0.6)
|
||||
node1.add_attribute("vector_distance", [0.9])
|
||||
node2.add_attribute("vector_distance", [0.8])
|
||||
node3.add_attribute("vector_distance", [0.7])
|
||||
node4.add_attribute("vector_distance", [0.6])
|
||||
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
|
@ -502,9 +563,9 @@ async def test_calculate_top_triplet_importances(setup_graph):
|
|||
edge2 = Edge(node2, node3)
|
||||
edge3 = Edge(node3, node4)
|
||||
|
||||
edge1.add_attribute("vector_distance", 0.85)
|
||||
edge2.add_attribute("vector_distance", 0.75)
|
||||
edge3.add_attribute("vector_distance", 0.65)
|
||||
edge1.add_attribute("vector_distance", [0.85])
|
||||
edge2.add_attribute("vector_distance", [0.75])
|
||||
edge3.add_attribute("vector_distance", [0.65])
|
||||
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
|
@ -520,7 +581,7 @@ async def test_calculate_top_triplet_importances(setup_graph):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
||||
"""Test calculating importances when nodes/edges have no vector distances."""
|
||||
"""Test calculating importances when nodes/edges have default vector distances."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
|
|
@ -531,7 +592,112 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
|||
edge = Edge(node1, node2)
|
||||
graph.add_edge(edge)
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes({})
|
||||
await graph.map_vector_distances_to_graph_edges(None)
|
||||
|
||||
top_triplets = await graph.calculate_top_triplet_importances(k=1)
|
||||
|
||||
assert len(top_triplets) == 1
|
||||
assert top_triplets[0] == edge
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_single_query_via_helper(setup_graph):
|
||||
"""Test calculating top triplet importances for a single query index."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
node1.add_attribute("vector_distance", [0.1])
|
||||
node2.add_attribute("vector_distance", [0.2])
|
||||
node3.add_attribute("vector_distance", [0.3])
|
||||
|
||||
edge1 = Edge(node1, node2)
|
||||
edge2 = Edge(node2, node3)
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge1.add_attribute("vector_distance", [0.3])
|
||||
edge2.add_attribute("vector_distance", [0.4])
|
||||
|
||||
results = await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
||||
assert len(results) == 1
|
||||
assert len(results[0]) == 1
|
||||
assert results[0][0] == edge1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_multi_query(setup_graph):
|
||||
"""Test calculating top triplet importances with multiple queries."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
edge_a = Edge(node1, node2)
|
||||
edge_b = Edge(node2, node3)
|
||||
graph.add_edge(edge_a)
|
||||
graph.add_edge(edge_b)
|
||||
|
||||
node1.add_attribute("vector_distance", [0.1, 0.9])
|
||||
node2.add_attribute("vector_distance", [0.1, 0.9])
|
||||
node3.add_attribute("vector_distance", [0.9, 0.1])
|
||||
edge_a.add_attribute("vector_distance", [0.1, 0.9])
|
||||
edge_b.add_attribute("vector_distance", [0.9, 0.1])
|
||||
|
||||
results = await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0][0] == edge_a
|
||||
assert results[1][0] == edge_b
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_raises_on_short_list(setup_graph):
|
||||
"""Test that scoring raises ValueError when list is too short for query_index."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
node1.add_attribute("vector_distance", [0.1])
|
||||
node2.add_attribute("vector_distance", [0.2])
|
||||
|
||||
edge = Edge(node1, node2)
|
||||
edge.add_attribute("vector_distance", [0.3])
|
||||
graph.add_edge(edge)
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_raises_on_missing_attribute(setup_graph):
|
||||
"""Test that scoring raises error when vector_distance is missing."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
del node1.attributes["vector_distance"]
|
||||
del node2.attributes["vector_distance"]
|
||||
|
||||
edge = Edge(node1, node2)
|
||||
del edge.attributes["vector_distance"]
|
||||
graph.add_edge(edge)
|
||||
|
||||
with pytest.raises((KeyError, TypeError)):
|
||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue