feat: remove secondary search
This commit is contained in:
parent
7a3138edf8
commit
c04d255aca
4 changed files with 49 additions and 85 deletions
|
|
@ -211,24 +211,10 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
node.add_attribute("vector_distance", score)
|
||||
mapped_nodes += 1
|
||||
|
||||
async def map_vector_distances_to_graph_edges(
|
||||
self, vector_engine, query_vector, edge_distances
|
||||
) -> None:
|
||||
async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
|
||||
try:
|
||||
if query_vector is None or len(query_vector) == 0:
|
||||
raise ValueError("Failed to generate query embedding.")
|
||||
|
||||
if edge_distances is None:
|
||||
start_time = time.time()
|
||||
edge_distances = await vector_engine.search(
|
||||
collection_name="EdgeType_relationship_name",
|
||||
query_vector=query_vector,
|
||||
limit=None,
|
||||
)
|
||||
projection_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
|
||||
)
|
||||
return
|
||||
|
||||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||
|
||||
|
|
|
|||
|
|
@ -137,6 +137,9 @@ async def brute_force_triplet_search(
|
|||
"DocumentChunk_text",
|
||||
]
|
||||
|
||||
if "EdgeType_relationship_name" not in collections:
|
||||
collections.append("EdgeType_relationship_name")
|
||||
|
||||
try:
|
||||
vector_engine = get_vector_engine()
|
||||
except Exception as e:
|
||||
|
|
@ -197,9 +200,7 @@ async def brute_force_triplet_search(
|
|||
)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(
|
||||
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
|
||||
)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ async def test_map_vector_distances_multiple_categories(setup_graph):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
|
||||
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
||||
"""Test mapping vector distances to edges when edge_distances provided."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -325,48 +325,13 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, moc
|
|||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
|
||||
"""Test mapping edge distances when searching for them."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
mock_vector_engine.search.return_value = [
|
||||
MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=None,
|
||||
)
|
||||
|
||||
mock_vector_engine.search.assert_called_once()
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.88
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
|
||||
async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
||||
"""Test mapping edge distances when only some edges have results."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -386,20 +351,14 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vect
|
|||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
assert graph.edges[1].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_edges_fallback_to_relationship_type(
|
||||
setup_graph, mock_vector_engine
|
||||
):
|
||||
async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_graph):
|
||||
"""Test that edge mapping falls back to relationship_type when edge_text is missing."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -419,17 +378,13 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(
|
|||
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
|
||||
async def test_map_vector_distances_no_edge_matches(setup_graph):
|
||||
"""Test edge mapping when no edges match the distance results."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -449,26 +404,22 @@ async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_eng
|
|||
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
|
||||
"""Test that invalid query vector raises error."""
|
||||
async def test_map_vector_distances_none_returns_early(setup_graph):
|
||||
"""Test that edge_distances=None returns early without error."""
|
||||
graph = setup_graph
|
||||
graph.add_node(Node("1"))
|
||||
graph.add_node(Node("2"))
|
||||
graph.add_edge(Edge(graph.get_node("1"), graph.get_node("2")))
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to generate query embedding"):
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[],
|
||||
edge_distances=None,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ async def test_brute_force_triplet_search_default_collections():
|
|||
"TextSummary_text",
|
||||
"EntityType_name",
|
||||
"DocumentChunk_text",
|
||||
"EdgeType_relationship_name",
|
||||
]
|
||||
|
||||
call_collections = [
|
||||
|
|
@ -154,7 +155,32 @@ async def test_brute_force_triplet_search_custom_collections():
|
|||
call_collections = [
|
||||
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
|
||||
]
|
||||
assert call_collections == custom_collections
|
||||
assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_always_includes_edge_collection():
|
||||
"""Test that EdgeType_relationship_name is always searched even when not in collections."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
collections_without_edge = ["Entity_name", "TextSummary_text"]
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", collections=collections_without_edge)
|
||||
|
||||
call_collections = [
|
||||
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
|
||||
]
|
||||
assert "EdgeType_relationship_name" in call_collections
|
||||
assert set(call_collections) == set(collections_without_edge) | {
|
||||
"EdgeType_relationship_name"
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue