chore: tweak mapping and scoring
This commit is contained in:
parent
c1ea7a8cc2
commit
a85df53c74
2 changed files with 32 additions and 18 deletions
|
|
@ -241,12 +241,9 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
query_list_length: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Map vector distances to nodes, supporting single- and multi-query input shapes."""
|
||||
if not node_distances:
|
||||
return None
|
||||
|
||||
query_count = query_list_length or 1
|
||||
|
||||
# Reset all node distances for this search
|
||||
self.reset_distances(self.nodes.values(), query_count)
|
||||
|
||||
for collection_name, scored_results in node_distances.items():
|
||||
|
|
@ -279,19 +276,17 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
query_list_length: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Map vector distances to graph edges, supporting single- and multi-query input shapes."""
|
||||
if not edge_distances:
|
||||
return None
|
||||
|
||||
query_count = query_list_length or 1
|
||||
|
||||
# Reset all edge distances for this search
|
||||
self.reset_distances(self.edges, query_count)
|
||||
|
||||
if not edge_distances:
|
||||
return None
|
||||
|
||||
per_query_edge_lists = self._normalize_query_distance_lists(
|
||||
edge_distances, query_list_length, "edge_distances"
|
||||
)
|
||||
|
||||
# For each query, apply distances to all matching edges
|
||||
for query_index, scored_list in enumerate(per_query_edge_lists):
|
||||
for result in scored_list:
|
||||
payload = getattr(result, "payload", None)
|
||||
|
|
@ -318,13 +313,32 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
) -> List[Edge]:
|
||||
"""Calculate top k triplet importances for a specific query index."""
|
||||
|
||||
def score(edge):
|
||||
distances = [
|
||||
edge.node1.attributes.get("vector_distance"),
|
||||
edge.node2.attributes.get("vector_distance"),
|
||||
edge.attributes.get("vector_distance"),
|
||||
]
|
||||
return sum(float(d[query_index]) for d in distances)
|
||||
def score(edge: Edge) -> float:
|
||||
elements = (
|
||||
(edge.node1, f"node {edge.node1.id}"),
|
||||
(edge.node2, f"node {edge.node2.id}"),
|
||||
(edge, f"edge {edge.node1.id}->{edge.node2.id}"),
|
||||
)
|
||||
|
||||
importances = []
|
||||
for element, label in elements:
|
||||
distances = element.attributes.get("vector_distance")
|
||||
if not isinstance(distances, list) or query_index >= len(distances):
|
||||
raise ValueError(
|
||||
f"{label}: vector_distance must be a list with length > {query_index} "
|
||||
f"before scoring (got {type(distances).__name__} with length "
|
||||
f"{len(distances) if isinstance(distances, list) else 'n/a'})"
|
||||
)
|
||||
value = distances[query_index]
|
||||
try:
|
||||
importances.append(float(value))
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
f"{label}: vector_distance[{query_index}] must be float-like, "
|
||||
f"got {type(value).__name__}"
|
||||
)
|
||||
|
||||
return sum(importances)
|
||||
|
||||
return heapq.nsmallest(k, self.edges, key=score)
|
||||
|
||||
|
|
|
|||
|
|
@ -614,7 +614,7 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
|||
|
||||
# When no distances are set, calculate_top_triplet_importances should handle None
|
||||
# by either raising an error or skipping edges with None distances
|
||||
with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"):
|
||||
with pytest.raises(ValueError):
|
||||
await graph.calculate_top_triplet_importances(k=1)
|
||||
|
||||
|
||||
|
|
@ -695,7 +695,7 @@ async def test_calculate_top_triplet_importances_raises_on_short_list(setup_grap
|
|||
edge.add_attribute("vector_distance", [0.3])
|
||||
graph.add_edge(edge)
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
with pytest.raises(ValueError):
|
||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
|
||||
|
||||
|
||||
|
|
@ -716,5 +716,5 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set
|
|||
del edge.attributes["vector_distance"]
|
||||
graph.add_edge(edge)
|
||||
|
||||
with pytest.raises((KeyError, TypeError)):
|
||||
with pytest.raises(ValueError):
|
||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue