chore: tweak mapping and scoring

This commit is contained in:
lxobr 2025-12-19 11:40:56 +01:00
parent c1ea7a8cc2
commit a85df53c74
2 changed files with 32 additions and 18 deletions

View file

@ -241,12 +241,9 @@ class CogneeGraph(CogneeAbstractGraph):
query_list_length: Optional[int] = None,
) -> None:
"""Map vector distances to nodes, supporting single- and multi-query input shapes."""
if not node_distances:
return None
query_count = query_list_length or 1
# Reset all node distances for this search
self.reset_distances(self.nodes.values(), query_count)
for collection_name, scored_results in node_distances.items():
@ -279,19 +276,17 @@ class CogneeGraph(CogneeAbstractGraph):
query_list_length: Optional[int] = None,
) -> None:
"""Map vector distances to graph edges, supporting single- and multi-query input shapes."""
if not edge_distances:
return None
query_count = query_list_length or 1
# Reset all edge distances for this search
self.reset_distances(self.edges, query_count)
if not edge_distances:
return None
per_query_edge_lists = self._normalize_query_distance_lists(
edge_distances, query_list_length, "edge_distances"
)
# For each query, apply distances to all matching edges
for query_index, scored_list in enumerate(per_query_edge_lists):
for result in scored_list:
payload = getattr(result, "payload", None)
@ -318,13 +313,32 @@ class CogneeGraph(CogneeAbstractGraph):
) -> List[Edge]:
"""Calculate top k triplet importances for a specific query index."""
def score(edge):
distances = [
edge.node1.attributes.get("vector_distance"),
edge.node2.attributes.get("vector_distance"),
edge.attributes.get("vector_distance"),
]
return sum(float(d[query_index]) for d in distances)
def score(edge: Edge) -> float:
elements = (
(edge.node1, f"node {edge.node1.id}"),
(edge.node2, f"node {edge.node2.id}"),
(edge, f"edge {edge.node1.id}->{edge.node2.id}"),
)
importances = []
for element, label in elements:
distances = element.attributes.get("vector_distance")
if not isinstance(distances, list) or query_index >= len(distances):
raise ValueError(
f"{label}: vector_distance must be a list with length > {query_index} "
f"before scoring (got {type(distances).__name__} with length "
f"{len(distances) if isinstance(distances, list) else 'n/a'})"
)
value = distances[query_index]
try:
importances.append(float(value))
except (TypeError, ValueError):
raise ValueError(
f"{label}: vector_distance[{query_index}] must be float-like, "
f"got {type(value).__name__}"
)
return sum(importances)
return heapq.nsmallest(k, self.edges, key=score)

View file

@ -614,7 +614,7 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph):
# When no distances are set, calculate_top_triplet_importances should handle None
# by either raising an error or skipping edges with None distances
with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"):
with pytest.raises(ValueError):
await graph.calculate_top_triplet_importances(k=1)
@ -695,7 +695,7 @@ async def test_calculate_top_triplet_importances_raises_on_short_list(setup_grap
edge.add_attribute("vector_distance", [0.3])
graph.add_edge(edge)
with pytest.raises(IndexError):
with pytest.raises(ValueError):
await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
@ -716,5 +716,5 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set
del edge.attributes["vector_distance"]
graph.add_edge(edge)
with pytest.raises((KeyError, TypeError)):
with pytest.raises(ValueError):
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)