feat: avoid double edge vector search in triplet search (#1877)
<!-- .github/pull_request_template.md -->
## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->
Eliminates double vector search for edges by ensuring all edge lookups
happen once in the retrieval layer.
- `brute_force_triplet_search`: Always includes
"EdgeType_relationship_name" in collections
- `CogneeGraph.map_vector_distances_to_graph_edges`: Removed internal
vector search fallback; only maps provided distances.
- Tests updated to reflect the new behavior.
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Bug Fixes**
* Ensured relationship edges are automatically included in search
collections, improving search completeness and accuracy.
* **Refactor**
* Simplified graph edge distance mapping logic by removing unnecessary
external dependencies, resulting in more efficient edge processing
during retrieval operations.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
commit
49f7c5188c
4 changed files with 49 additions and 85 deletions
|
|
@ -211,24 +211,10 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
node.add_attribute("vector_distance", score)
|
||||
mapped_nodes += 1
|
||||
|
||||
async def map_vector_distances_to_graph_edges(
|
||||
self, vector_engine, query_vector, edge_distances
|
||||
) -> None:
|
||||
async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
|
||||
try:
|
||||
if query_vector is None or len(query_vector) == 0:
|
||||
raise ValueError("Failed to generate query embedding.")
|
||||
|
||||
if edge_distances is None:
|
||||
start_time = time.time()
|
||||
edge_distances = await vector_engine.search(
|
||||
collection_name="EdgeType_relationship_name",
|
||||
query_vector=query_vector,
|
||||
limit=None,
|
||||
)
|
||||
projection_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
|
||||
)
|
||||
return
|
||||
|
||||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||
|
||||
|
|
|
|||
|
|
@ -137,6 +137,9 @@ async def brute_force_triplet_search(
|
|||
"DocumentChunk_text",
|
||||
]
|
||||
|
||||
if "EdgeType_relationship_name" not in collections:
|
||||
collections.append("EdgeType_relationship_name")
|
||||
|
||||
try:
|
||||
vector_engine = get_vector_engine()
|
||||
except Exception as e:
|
||||
|
|
@ -197,9 +200,7 @@ async def brute_force_triplet_search(
|
|||
)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(
|
||||
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
|
||||
)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ async def test_map_vector_distances_multiple_categories(setup_graph):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
|
||||
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
||||
"""Test mapping vector distances to edges when edge_distances provided."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -325,48 +325,13 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, moc
|
|||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
|
||||
"""Test mapping edge distances when searching for them."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
mock_vector_engine.search.return_value = [
|
||||
MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=None,
|
||||
)
|
||||
|
||||
mock_vector_engine.search.assert_called_once()
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.88
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
|
||||
async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
||||
"""Test mapping edge distances when only some edges have results."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -386,20 +351,14 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vect
|
|||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
assert graph.edges[1].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_edges_fallback_to_relationship_type(
|
||||
setup_graph, mock_vector_engine
|
||||
):
|
||||
async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_graph):
|
||||
"""Test that edge mapping falls back to relationship_type when edge_text is missing."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -419,17 +378,13 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(
|
|||
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
|
||||
async def test_map_vector_distances_no_edge_matches(setup_graph):
|
||||
"""Test edge mapping when no edges match the distance results."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -449,26 +404,22 @@ async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_eng
|
|||
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
|
||||
"""Test that invalid query vector raises error."""
|
||||
async def test_map_vector_distances_none_returns_early(setup_graph):
|
||||
"""Test that edge_distances=None returns early without error."""
|
||||
graph = setup_graph
|
||||
graph.add_node(Node("1"))
|
||||
graph.add_node(Node("2"))
|
||||
graph.add_edge(Edge(graph.get_node("1"), graph.get_node("2")))
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to generate query embedding"):
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[],
|
||||
edge_distances=None,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ async def test_brute_force_triplet_search_default_collections():
|
|||
"TextSummary_text",
|
||||
"EntityType_name",
|
||||
"DocumentChunk_text",
|
||||
"EdgeType_relationship_name",
|
||||
]
|
||||
|
||||
call_collections = [
|
||||
|
|
@ -154,7 +155,32 @@ async def test_brute_force_triplet_search_custom_collections():
|
|||
call_collections = [
|
||||
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
|
||||
]
|
||||
assert call_collections == custom_collections
|
||||
assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_always_includes_edge_collection():
|
||||
"""Test that EdgeType_relationship_name is always searched even when not in collections."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
collections_without_edge = ["Entity_name", "TextSummary_text"]
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", collections=collections_without_edge)
|
||||
|
||||
call_collections = [
|
||||
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
|
||||
]
|
||||
assert "EdgeType_relationship_name" in call_collections
|
||||
assert set(call_collections) == set(collections_without_edge) | {
|
||||
"EdgeType_relationship_name"
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue