feat: remove secondary search

This commit is contained in:
lxobr 2025-12-08 16:43:10 +01:00
parent 7a3138edf8
commit c04d255aca
4 changed files with 49 additions and 85 deletions

View file

@ -211,24 +211,10 @@ class CogneeGraph(CogneeAbstractGraph):
node.add_attribute("vector_distance", score)
mapped_nodes += 1
async def map_vector_distances_to_graph_edges(
self, vector_engine, query_vector, edge_distances
) -> None:
async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
try:
if query_vector is None or len(query_vector) == 0:
raise ValueError("Failed to generate query embedding.")
if edge_distances is None:
start_time = time.time()
edge_distances = await vector_engine.search(
collection_name="EdgeType_relationship_name",
query_vector=query_vector,
limit=None,
)
projection_time = time.time() - start_time
logger.info(
f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
)
return
embedding_map = {result.payload["text"]: result.score for result in edge_distances}

View file

@ -137,6 +137,9 @@ async def brute_force_triplet_search(
"DocumentChunk_text",
]
if "EdgeType_relationship_name" not in collections:
collections.append("EdgeType_relationship_name")
try:
vector_engine = get_vector_engine()
except Exception as e:
@ -197,9 +200,7 @@ async def brute_force_triplet_search(
)
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
)
await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)

View file

@ -305,7 +305,7 @@ async def test_map_vector_distances_multiple_categories(setup_graph):
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
"""Test mapping vector distances to edges when edge_distances provided."""
graph = setup_graph
@ -325,48 +325,13 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, moc
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 0.92
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
"""Test mapping edge distances when searching for them."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(
node1,
node2,
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
)
graph.add_edge(edge)
mock_vector_engine.search.return_value = [
MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=None,
)
mock_vector_engine.search.assert_called_once()
assert graph.edges[0].attributes.get("vector_distance") == 0.88
@pytest.mark.asyncio
async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
async def test_map_vector_distances_partial_edge_coverage(setup_graph):
"""Test mapping edge distances when only some edges have results."""
graph = setup_graph
@ -386,20 +351,14 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vect
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 0.92
assert graph.edges[1].attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio
async def test_map_vector_distances_edges_fallback_to_relationship_type(
setup_graph, mock_vector_engine
):
async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_graph):
"""Test that edge mapping falls back to relationship_type when edge_text is missing."""
graph = setup_graph
@ -419,17 +378,13 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 0.85
@pytest.mark.asyncio
async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
async def test_map_vector_distances_no_edge_matches(setup_graph):
"""Test edge mapping when no edges match the distance results."""
graph = setup_graph
@ -449,26 +404,22 @@ async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_eng
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
assert graph.edges[0].attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio
async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
"""Test that invalid query vector raises error."""
async def test_map_vector_distances_none_returns_early(setup_graph):
"""Test that edge_distances=None returns early without error."""
graph = setup_graph
graph.add_node(Node("1"))
graph.add_node(Node("2"))
graph.add_edge(Edge(graph.get_node("1"), graph.get_node("2")))
with pytest.raises(ValueError, match="Failed to generate query embedding"):
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[],
edge_distances=None,
)
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
assert graph.edges[0].attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio

View file

@ -127,6 +127,7 @@ async def test_brute_force_triplet_search_default_collections():
"TextSummary_text",
"EntityType_name",
"DocumentChunk_text",
"EdgeType_relationship_name",
]
call_collections = [
@ -154,7 +155,32 @@ async def test_brute_force_triplet_search_custom_collections():
call_collections = [
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
]
assert call_collections == custom_collections
assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_always_includes_edge_collection():
"""Test that EdgeType_relationship_name is always searched even when not in collections."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
collections_without_edge = ["Entity_name", "TextSummary_text"]
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", collections=collections_without_edge)
call_collections = [
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
]
assert "EdgeType_relationship_name" in call_collections
assert set(call_collections) == set(collections_without_edge) | {
"EdgeType_relationship_name"
}
@pytest.mark.asyncio