chore: tweak mapping and scoring
This commit is contained in:
parent
c1ea7a8cc2
commit
a85df53c74
2 changed files with 32 additions and 18 deletions
|
|
@ -241,12 +241,9 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
query_list_length: Optional[int] = None,
|
query_list_length: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Map vector distances to nodes, supporting single- and multi-query input shapes."""
|
"""Map vector distances to nodes, supporting single- and multi-query input shapes."""
|
||||||
if not node_distances:
|
|
||||||
return None
|
|
||||||
|
|
||||||
query_count = query_list_length or 1
|
query_count = query_list_length or 1
|
||||||
|
|
||||||
# Reset all node distances for this search
|
|
||||||
self.reset_distances(self.nodes.values(), query_count)
|
self.reset_distances(self.nodes.values(), query_count)
|
||||||
|
|
||||||
for collection_name, scored_results in node_distances.items():
|
for collection_name, scored_results in node_distances.items():
|
||||||
|
|
@ -279,19 +276,17 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
query_list_length: Optional[int] = None,
|
query_list_length: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Map vector distances to graph edges, supporting single- and multi-query input shapes."""
|
"""Map vector distances to graph edges, supporting single- and multi-query input shapes."""
|
||||||
if not edge_distances:
|
|
||||||
return None
|
|
||||||
|
|
||||||
query_count = query_list_length or 1
|
query_count = query_list_length or 1
|
||||||
|
|
||||||
# Reset all edge distances for this search
|
|
||||||
self.reset_distances(self.edges, query_count)
|
self.reset_distances(self.edges, query_count)
|
||||||
|
|
||||||
|
if not edge_distances:
|
||||||
|
return None
|
||||||
|
|
||||||
per_query_edge_lists = self._normalize_query_distance_lists(
|
per_query_edge_lists = self._normalize_query_distance_lists(
|
||||||
edge_distances, query_list_length, "edge_distances"
|
edge_distances, query_list_length, "edge_distances"
|
||||||
)
|
)
|
||||||
|
|
||||||
# For each query, apply distances to all matching edges
|
|
||||||
for query_index, scored_list in enumerate(per_query_edge_lists):
|
for query_index, scored_list in enumerate(per_query_edge_lists):
|
||||||
for result in scored_list:
|
for result in scored_list:
|
||||||
payload = getattr(result, "payload", None)
|
payload = getattr(result, "payload", None)
|
||||||
|
|
@ -318,13 +313,32 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
) -> List[Edge]:
|
) -> List[Edge]:
|
||||||
"""Calculate top k triplet importances for a specific query index."""
|
"""Calculate top k triplet importances for a specific query index."""
|
||||||
|
|
||||||
def score(edge):
|
def score(edge: Edge) -> float:
|
||||||
distances = [
|
elements = (
|
||||||
edge.node1.attributes.get("vector_distance"),
|
(edge.node1, f"node {edge.node1.id}"),
|
||||||
edge.node2.attributes.get("vector_distance"),
|
(edge.node2, f"node {edge.node2.id}"),
|
||||||
edge.attributes.get("vector_distance"),
|
(edge, f"edge {edge.node1.id}->{edge.node2.id}"),
|
||||||
]
|
)
|
||||||
return sum(float(d[query_index]) for d in distances)
|
|
||||||
|
importances = []
|
||||||
|
for element, label in elements:
|
||||||
|
distances = element.attributes.get("vector_distance")
|
||||||
|
if not isinstance(distances, list) or query_index >= len(distances):
|
||||||
|
raise ValueError(
|
||||||
|
f"{label}: vector_distance must be a list with length > {query_index} "
|
||||||
|
f"before scoring (got {type(distances).__name__} with length "
|
||||||
|
f"{len(distances) if isinstance(distances, list) else 'n/a'})"
|
||||||
|
)
|
||||||
|
value = distances[query_index]
|
||||||
|
try:
|
||||||
|
importances.append(float(value))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
raise ValueError(
|
||||||
|
f"{label}: vector_distance[{query_index}] must be float-like, "
|
||||||
|
f"got {type(value).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return sum(importances)
|
||||||
|
|
||||||
return heapq.nsmallest(k, self.edges, key=score)
|
return heapq.nsmallest(k, self.edges, key=score)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -614,7 +614,7 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
||||||
|
|
||||||
# When no distances are set, calculate_top_triplet_importances should handle None
|
# When no distances are set, calculate_top_triplet_importances should handle None
|
||||||
# by either raising an error or skipping edges with None distances
|
# by either raising an error or skipping edges with None distances
|
||||||
with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"):
|
with pytest.raises(ValueError):
|
||||||
await graph.calculate_top_triplet_importances(k=1)
|
await graph.calculate_top_triplet_importances(k=1)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -695,7 +695,7 @@ async def test_calculate_top_triplet_importances_raises_on_short_list(setup_grap
|
||||||
edge.add_attribute("vector_distance", [0.3])
|
edge.add_attribute("vector_distance", [0.3])
|
||||||
graph.add_edge(edge)
|
graph.add_edge(edge)
|
||||||
|
|
||||||
with pytest.raises(IndexError):
|
with pytest.raises(ValueError):
|
||||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
|
await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -716,5 +716,5 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set
|
||||||
del edge.attributes["vector_distance"]
|
del edge.attributes["vector_distance"]
|
||||||
graph.add_edge(edge)
|
graph.add_edge(edge)
|
||||||
|
|
||||||
with pytest.raises((KeyError, TypeError)):
|
with pytest.raises(ValueError):
|
||||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue