diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 2e0b82e8d..6233c245f 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -211,24 +211,10 @@ class CogneeGraph(CogneeAbstractGraph): node.add_attribute("vector_distance", score) mapped_nodes += 1 - async def map_vector_distances_to_graph_edges( - self, vector_engine, query_vector, edge_distances - ) -> None: + async def map_vector_distances_to_graph_edges(self, edge_distances) -> None: try: - if query_vector is None or len(query_vector) == 0: - raise ValueError("Failed to generate query embedding.") - if edge_distances is None: - start_time = time.time() - edge_distances = await vector_engine.search( - collection_name="EdgeType_relationship_name", - query_vector=query_vector, - limit=None, - ) - projection_time = time.time() - start_time - logger.info( - f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s" - ) + return embedding_map = {result.payload["text"]: result.score for result in edge_distances} diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 2f8a545f7..bd412e0ca 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -137,6 +137,9 @@ async def brute_force_triplet_search( "DocumentChunk_text", ] + if "EdgeType_relationship_name" not in collections: + collections.append("EdgeType_relationship_name") + try: vector_engine = get_vector_engine() except Exception as e: @@ -197,9 +200,7 @@ async def brute_force_triplet_search( ) await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) - await memory_fragment.map_vector_distances_to_graph_edges( - vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances - ) + await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances) results = await memory_fragment.calculate_top_triplet_importances(k=top_k) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 711479387..edbd8ef9d 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -305,7 +305,7 @@ async def test_map_vector_distances_multiple_categories(setup_graph): @pytest.mark.asyncio -async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine): +async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph): """Test mapping vector distances to edges when edge_distances provided.""" graph = setup_graph @@ -325,48 +325,13 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, moc MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), ] - await graph.map_vector_distances_to_graph_edges( - vector_engine=mock_vector_engine, - query_vector=[0.1, 0.2, 0.3], - edge_distances=edge_distances, - ) + await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) assert graph.edges[0].attributes.get("vector_distance") == 0.92 @pytest.mark.asyncio -async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine): - """Test mapping edge distances when searching for them.""" - graph = setup_graph - - node1 = Node("1") - node2 = Node("2") - graph.add_node(node1) - graph.add_node(node2) - - edge = Edge( - node1, - node2, - attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, - ) - graph.add_edge(edge) - - mock_vector_engine.search.return_value = [ - MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}), - ] - - await graph.map_vector_distances_to_graph_edges( - vector_engine=mock_vector_engine, - query_vector=[0.1, 0.2, 0.3], - edge_distances=None, - ) - - mock_vector_engine.search.assert_called_once() - assert graph.edges[0].attributes.get("vector_distance") == 0.88 - - -@pytest.mark.asyncio -async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine): +async def test_map_vector_distances_partial_edge_coverage(setup_graph): """Test mapping edge distances when only some edges have results.""" graph = setup_graph @@ -386,20 +351,14 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vect MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), ] - await graph.map_vector_distances_to_graph_edges( - vector_engine=mock_vector_engine, - query_vector=[0.1, 0.2, 0.3], - edge_distances=edge_distances, - ) + await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) assert graph.edges[0].attributes.get("vector_distance") == 0.92 assert graph.edges[1].attributes.get("vector_distance") == 3.5 @pytest.mark.asyncio -async def test_map_vector_distances_edges_fallback_to_relationship_type( - setup_graph, mock_vector_engine -): +async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_graph): """Test that edge mapping falls back to relationship_type when edge_text is missing.""" graph = setup_graph @@ -419,17 +378,13 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type( MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}), ] - await graph.map_vector_distances_to_graph_edges( - vector_engine=mock_vector_engine, - query_vector=[0.1, 0.2, 0.3], - edge_distances=edge_distances, - ) + await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) assert graph.edges[0].attributes.get("vector_distance") == 0.85 @pytest.mark.asyncio -async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine): +async def test_map_vector_distances_no_edge_matches(setup_graph): """Test edge mapping when no edges match the distance results.""" graph = setup_graph @@ -449,26 +404,22 @@ async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_eng MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}), ] - await graph.map_vector_distances_to_graph_edges( - vector_engine=mock_vector_engine, - query_vector=[0.1, 0.2, 0.3], - edge_distances=edge_distances, - ) + await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) assert graph.edges[0].attributes.get("vector_distance") == 3.5 @pytest.mark.asyncio -async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine): - """Test that invalid query vector raises error.""" +async def test_map_vector_distances_none_returns_early(setup_graph): + """Test that edge_distances=None returns early without error.""" graph = setup_graph + graph.add_node(Node("1")) + graph.add_node(Node("2")) + graph.add_edge(Edge(graph.get_node("1"), graph.get_node("2"))) - with pytest.raises(ValueError, match="Failed to generate query embedding"): - await graph.map_vector_distances_to_graph_edges( - vector_engine=mock_vector_engine, - query_vector=[], - edge_distances=None, - ) + await graph.map_vector_distances_to_graph_edges(edge_distances=None) + + assert graph.edges[0].attributes.get("vector_distance") == 3.5 @pytest.mark.asyncio diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py index 5eb6fb105..3dc9f38d9 100644 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -127,6 +127,7 @@ async def test_brute_force_triplet_search_default_collections(): "TextSummary_text", "EntityType_name", "DocumentChunk_text", + "EdgeType_relationship_name", ] call_collections = [ @@ -154,7 +155,32 @@ async def test_brute_force_triplet_search_custom_collections(): call_collections = [ call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list ] - assert call_collections == custom_collections + assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_always_includes_edge_collection(): + """Test that EdgeType_relationship_name is always searched even when not in collections.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + collections_without_edge = ["Entity_name", "TextSummary_text"] + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", collections=collections_without_edge) + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert "EdgeType_relationship_name" in call_collections + assert set(call_collections) == set(collections_without_edge) | { + "EdgeType_relationship_name" + } @pytest.mark.asyncio