# 481 lines, 14 KiB, Python
import pytest
|
|
from unittest.mock import AsyncMock
|
|
|
|
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
|
|
|
|
|
@pytest.fixture
def setup_graph():
    """Provide a newly constructed CogneeGraph to the requesting test."""
    fresh_graph = CogneeGraph()
    return fresh_graph
|
|
|
|
|
|
@pytest.fixture
def mock_adapter():
    """Provide an AsyncMock that stands in for the database adapter."""
    return AsyncMock()
|
|
|
|
|
|
@pytest.fixture
def mock_vector_engine():
    """Provide an AsyncMock vector engine with an explicit async ``search``."""
    search_stub = AsyncMock()
    vector_engine = AsyncMock()
    vector_engine.search = search_stub
    return vector_engine
|
|
|
|
|
|
class MockScoredResult:
    """Lightweight stand-in for a scored vector-search result.

    Exposes ``id``, ``score`` and ``payload`` attributes. A falsy ``payload``
    (including the default ``None``) is replaced by a new empty dict.
    """

    def __init__(self, id, score, payload=None):
        # The parameter is named ``id`` so existing call sites keep working.
        self.id, self.score = id, score
        self.payload = payload if payload else {}
|
|
|
|
|
|
def test_add_node_success(setup_graph):
    """A node added to the graph can be fetched back by its id."""
    added = Node("node1")
    setup_graph.add_node(added)

    fetched = setup_graph.get_node("node1")
    assert fetched == added
|
|
|
|
|
|
def test_add_duplicate_node(setup_graph):
    """Adding the same node twice raises EntityAlreadyExistsError.

    The first insertion is performed outside the ``pytest.raises`` block so
    that only the second, duplicate insertion is expected to fail.
    """
    graph = setup_graph
    node = Node("node1")
    graph.add_node(node)

    # ``match`` is applied as a regular expression, so the trailing period is
    # escaped to require a literal "." rather than any single character.
    with pytest.raises(
        EntityAlreadyExistsError, match=r"Node with id node1 already exists\."
    ):
        graph.add_node(node)
|
|
|
|
|
|
def test_add_edge_success(setup_graph):
    """An added edge shows up on the graph and on both endpoint nodes."""
    source, target = Node("node1"), Node("node2")
    setup_graph.add_node(source)
    setup_graph.add_node(target)

    link = Edge(source, target)
    setup_graph.add_edge(link)

    # The edge must be recorded globally and in each endpoint's skeleton edges.
    assert link in setup_graph.edges
    assert link in source.skeleton_edges
    assert link in target.skeleton_edges
|
|
|
|
|
|
def test_get_node_success(setup_graph):
    """``get_node`` returns the node previously stored under that id."""
    stored = Node("node1")
    setup_graph.add_node(stored)

    looked_up = setup_graph.get_node("node1")
    assert looked_up == stored
|
|
|
|
|
|
def test_get_node_nonexistent(setup_graph):
    """Looking up an id that was never added yields None."""
    missing = setup_graph.get_node("nonexistent")
    assert missing is None
|
|
|
|
|
|
def test_get_edges_success(setup_graph):
    """``get_edges_from_node`` includes an edge attached to that node."""
    origin, destination = Node("node1"), Node("node2")
    setup_graph.add_node(origin)
    setup_graph.add_node(destination)

    connection = Edge(origin, destination)
    setup_graph.add_edge(connection)

    edges_of_origin = setup_graph.get_edges_from_node("node1")
    assert connection in edges_of_origin
|
|
|
|
|
|
def test_get_edges_nonexistent_node(setup_graph):
    """Requesting edges for an unknown node id raises EntityNotFoundError.

    No node is added beforehand, so the id "nonexistent" is not present in the
    graph when ``get_edges_from_node`` is called.
    """
    graph = setup_graph

    # ``match`` is applied as a regular expression, so the trailing period is
    # escaped to require a literal "." rather than any single character.
    with pytest.raises(
        EntityNotFoundError, match=r"Node with id nonexistent does not exist\."
    ):
        graph.get_edges_from_node("nonexistent")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter):
    """Projecting without filters loads the nodes and edges from ``get_graph_data``.

    The adapter is stubbed to return two nodes joined by a single edge; the
    test checks the projected sizes and the endpoints of the edge.
    """
    stored_nodes = [
        ("1", {"name": "Node1", "description": "First node"}),
        ("2", {"name": "Node2", "description": "Second node"}),
    ]
    stored_edges = [("1", "2", "CONNECTS_TO", {"relationship_name": "connects"})]
    mock_adapter.get_graph_data = AsyncMock(return_value=(stored_nodes, stored_edges))

    await setup_graph.project_graph_from_db(
        adapter=mock_adapter,
        node_properties_to_project=["name", "description"],
        edge_properties_to_project=["relationship_name"],
    )

    assert len(setup_graph.nodes) == 2
    assert len(setup_graph.edges) == 1
    assert setup_graph.get_node("1") is not None
    assert setup_graph.get_node("2") is not None

    projected_edge = setup_graph.edges[0]
    assert projected_edge.node1.id == "1"
    assert projected_edge.node2.id == "2"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter):
    """Projection with ``relevant_ids_to_filter`` uses the id-filtered adapter call.

    ``get_id_filtered_graph_data`` is stubbed with two nodes and one edge; the
    test checks the projected sizes and that the stub was called exactly once.
    """
    stored_nodes = [("1", {"name": "Node1"}), ("2", {"name": "Node2"})]
    stored_edges = [("1", "2", "CONNECTS_TO", {"relationship_name": "connects"})]
    mock_adapter.get_id_filtered_graph_data = AsyncMock(
        return_value=(stored_nodes, stored_edges)
    )

    await setup_graph.project_graph_from_db(
        adapter=mock_adapter,
        node_properties_to_project=["name"],
        edge_properties_to_project=["relationship_name"],
        relevant_ids_to_filter=["1", "2"],
    )

    assert len(setup_graph.nodes) == 2
    assert len(setup_graph.edges) == 1
    mock_adapter.get_id_filtered_graph_data.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter):
    """Projection with ``node_type`` and ``node_name`` uses ``get_nodeset_subgraph``.

    The stub returns two "Person" nodes linked by one edge; the test checks
    the projected contents and that the stub was called exactly once.
    """
    person_nodes = [
        ("1", {"name": "Alice", "type": "Person"}),
        ("2", {"name": "Bob", "type": "Person"}),
    ]
    person_edges = [("1", "2", "KNOWS", {"relationship_name": "knows"})]
    mock_adapter.get_nodeset_subgraph = AsyncMock(
        return_value=(person_nodes, person_edges)
    )

    await setup_graph.project_graph_from_db(
        adapter=mock_adapter,
        node_properties_to_project=["name", "type"],
        edge_properties_to_project=["relationship_name"],
        node_type="Person",
        node_name=["Alice"],
    )

    assert len(setup_graph.nodes) == 2
    assert setup_graph.get_node("1") is not None
    assert len(setup_graph.edges) == 1
    mock_adapter.get_nodeset_subgraph.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
    """Projecting an empty result set raises EntityNotFoundError.

    ``get_graph_data`` is stubbed to return no nodes and no edges, and the
    projection call is expected to fail with the "empty graph" message.
    """
    graph = setup_graph

    mock_adapter.get_graph_data = AsyncMock(return_value=([], []))

    # ``match`` is applied as a regular expression, so the trailing period is
    # escaped to require a literal "." rather than any single character.
    with pytest.raises(
        EntityNotFoundError, match=r"Empty graph projected from the database\."
    ):
        await graph.project_graph_from_db(
            adapter=mock_adapter,
            node_properties_to_project=["name"],
            edge_properties_to_project=[],
        )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
    """An edge pointing at a node id absent from the node data raises an error.

    Only node "1" is returned by the stub while the edge targets "999", so the
    projection is expected to raise EntityNotFoundError.
    """
    only_node = [("1", {"name": "Node1"})]
    dangling_edge = [("1", "999", "CONNECTS_TO", {"relationship_name": "connects"})]
    mock_adapter.get_graph_data = AsyncMock(return_value=(only_node, dangling_edge))

    expected_error = pytest.raises(
        EntityNotFoundError, match="Edge references nonexistent nodes"
    )
    with expected_error:
        await setup_graph.project_graph_from_db(
            adapter=mock_adapter,
            node_properties_to_project=["name"],
            edge_properties_to_project=["relationship_name"],
        )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_nodes(setup_graph):
    """Scores from the search results land on the nodes with matching ids."""
    first, second = Node("1", {"name": "Node1"}), Node("2", {"name": "Node2"})
    setup_graph.add_node(first)
    setup_graph.add_node(second)

    entity_hits = [MockScoredResult("1", 0.95), MockScoredResult("2", 0.87)]
    await setup_graph.map_vector_distances_to_graph_nodes({"Entity_name": entity_hits})

    # Each node is expected to carry the score of the hit sharing its id.
    expected = {"1": 0.95, "2": 0.87}
    for node_id, distance in expected.items():
        assert setup_graph.get_node(node_id).attributes.get("vector_distance") == distance
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_map_vector_distances_partial_node_coverage(setup_graph):
    """Nodes without a search hit are expected to report a distance of 3.5."""
    members = (
        Node("1", {"name": "Node1"}),
        Node("2", {"name": "Node2"}),
        Node("3", {"name": "Node3"}),
    )
    for member in members:
        setup_graph.add_node(member)

    # Only nodes "1" and "2" receive a score; node "3" is left uncovered.
    entity_hits = [MockScoredResult("1", 0.95), MockScoredResult("2", 0.87)]
    await setup_graph.map_vector_distances_to_graph_nodes({"Entity_name": entity_hits})

    expected = {"1": 0.95, "2": 0.87, "3": 3.5}
    for node_id, distance in expected.items():
        assert setup_graph.get_node(node_id).attributes.get("vector_distance") == distance
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_map_vector_distances_multiple_categories(setup_graph):
    """Hits from several collections are all applied to their matching nodes.

    Node "4" appears in no collection and is expected to report 3.5.
    """
    members = (Node("1"), Node("2"), Node("3"), Node("4"))
    for member in members:
        setup_graph.add_node(member)

    hits_by_collection = {}
    hits_by_collection["Entity_name"] = [
        MockScoredResult("1", 0.95),
        MockScoredResult("2", 0.87),
    ]
    hits_by_collection["TextSummary_text"] = [MockScoredResult("3", 0.92)]

    await setup_graph.map_vector_distances_to_graph_nodes(hits_by_collection)

    expected = {"1": 0.95, "2": 0.87, "3": 0.92, "4": 3.5}
    for node_id, distance in expected.items():
        assert setup_graph.get_node(node_id).attributes.get("vector_distance") == distance
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
    """An edge whose ``edge_text`` equals a hit's payload text gets that hit's score."""
    source, target = Node("1"), Node("2")
    setup_graph.add_node(source)
    setup_graph.add_node(target)

    edge_attributes = {"edge_text": "CONNECTS_TO", "relationship_type": "connects"}
    setup_graph.add_edge(Edge(source, target, attributes=edge_attributes))

    hits = [MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"})]
    await setup_graph.map_vector_distances_to_graph_edges(edge_distances=hits)

    mapped_edge = setup_graph.edges[0]
    assert mapped_edge.attributes.get("vector_distance") == 0.92
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_map_vector_distances_partial_edge_coverage(setup_graph):
    """Only the edge with a matching hit gets a score; the other reports 3.5."""
    first, second, third = Node("1"), Node("2"), Node("3")
    for member in (first, second, third):
        setup_graph.add_node(member)

    matched = Edge(first, second, attributes={"edge_text": "CONNECTS_TO"})
    unmatched = Edge(second, third, attributes={"edge_text": "DEPENDS_ON"})
    setup_graph.add_edge(matched)
    setup_graph.add_edge(unmatched)

    # The single hit carries the text of the first edge only.
    hits = [MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"})]
    await setup_graph.map_vector_distances_to_graph_edges(edge_distances=hits)

    assert setup_graph.edges[0].attributes.get("vector_distance") == 0.92
    assert setup_graph.edges[1].attributes.get("vector_distance") == 3.5
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_graph):
    """Without ``edge_text``, the edge is matched through its ``relationship_type``."""
    source, target = Node("1"), Node("2")
    setup_graph.add_node(source)
    setup_graph.add_node(target)

    # The edge deliberately carries no "edge_text" attribute.
    typed_edge = Edge(source, target, attributes={"relationship_type": "KNOWS"})
    setup_graph.add_edge(typed_edge)

    hits = [MockScoredResult("e1", 0.85, payload={"text": "KNOWS"})]
    await setup_graph.map_vector_distances_to_graph_edges(edge_distances=hits)

    mapped_edge = setup_graph.edges[0]
    assert mapped_edge.attributes.get("vector_distance") == 0.85
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_map_vector_distances_no_edge_matches(setup_graph):
    """When no hit text matches the edge, the edge is expected to report 3.5."""
    source, target = Node("1"), Node("2")
    setup_graph.add_node(source)
    setup_graph.add_node(target)

    edge_attributes = {"edge_text": "CONNECTS_TO", "relationship_type": "connects"}
    setup_graph.add_edge(Edge(source, target, attributes=edge_attributes))

    # The hit's text corresponds to neither attribute of the edge above.
    unrelated_hits = [MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"})]
    await setup_graph.map_vector_distances_to_graph_edges(edge_distances=unrelated_hits)

    untouched_edge = setup_graph.edges[0]
    assert untouched_edge.attributes.get("vector_distance") == 3.5
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_map_vector_distances_none_returns_early(setup_graph):
    """Passing ``edge_distances=None`` raises nothing and the edge still reports 3.5."""
    for node_id in ("1", "2"):
        setup_graph.add_node(Node(node_id))

    source, target = setup_graph.get_node("1"), setup_graph.get_node("2")
    setup_graph.add_edge(Edge(source, target))

    await setup_graph.map_vector_distances_to_graph_edges(edge_distances=None)

    only_edge = setup_graph.edges[0]
    assert only_edge.attributes.get("vector_distance") == 3.5
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances(setup_graph):
    """With ``k=2`` the two expected triplets are returned in the expected order.

    Four chained nodes and three edges receive decreasing ``vector_distance``
    values; the test expects the last edge first, followed by the middle one.
    """
    chain = [Node("1"), Node("2"), Node("3"), Node("4")]
    for member, distance in zip(chain, (0.9, 0.8, 0.7, 0.6)):
        member.add_attribute("vector_distance", distance)
    for member in chain:
        setup_graph.add_node(member)

    first, second, third, fourth = chain
    links = [Edge(first, second), Edge(second, third), Edge(third, fourth)]
    for link, distance in zip(links, (0.85, 0.75, 0.65)):
        link.add_attribute("vector_distance", distance)
    for link in links:
        setup_graph.add_edge(link)

    ranked = await setup_graph.calculate_top_triplet_importances(k=2)

    assert len(ranked) == 2

    assert ranked[0] == links[2]
    assert ranked[1] == links[1]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
    """Ranking works when no ``vector_distance`` was explicitly assigned."""
    source, target = Node("1"), Node("2")
    setup_graph.add_node(source)
    setup_graph.add_node(target)

    lone_edge = Edge(source, target)
    setup_graph.add_edge(lone_edge)

    ranked = await setup_graph.calculate_top_triplet_importances(k=1)

    # The only edge in the graph must be the single returned triplet.
    assert len(ranked) == 1
    assert ranked[0] == lone_edge
|