diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index 48d46417f..978205d2f 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -164,7 +164,7 @@ async def get_default_tasks( summarization_model=cognee_config.summarization_model, task_config={"batch_size": 10}, ), - Task(add_data_points, only_root=True, task_config={"batch_size": 10}), + Task(add_data_points, task_config={"batch_size": 10}), Task(store_descriptive_metrics), ] except Exception as error: diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 41bfb891d..a5c1f3eb3 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -205,12 +205,20 @@ class Neo4jAdapter(GraphDBInterface): async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None: query = """ - UNWIND $edges AS edge - MATCH (from_node {id: edge.from_node}) - MATCH (to_node {id: edge.to_node}) - CALL apoc.create.relationship(from_node, edge.relationship_name, edge.properties, to_node) YIELD rel - RETURN rel - """ + UNWIND $edges AS edge + MATCH (from_node {id: edge.from_node}) + MATCH (to_node {id: edge.to_node}) + CALL apoc.merge.relationship( + from_node, + edge.relationship_name, + { + source_node_id: edge.from_node, + target_node_id: edge.to_node + }, + edge.properties, + to_node + ) YIELD rel + RETURN rel""" edges = [ { diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index f4d2ed77a..2798b48b6 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -8,7 +8,6 @@ async def get_graph_from_model( added_nodes: dict, added_edges: dict, visited_properties: dict = None, - only_root=False, include_root=True, ): if str(data_point.id) in added_nodes: @@ -98,7 +97,7 @@ async def get_graph_from_model( ) added_edges[str(edge_key)] = True - if str(field_value.id) in added_nodes or only_root: + if str(field_value.id) in added_nodes: continue property_nodes, property_edges = await get_graph_from_model( diff --git a/cognee/tasks/storage/add_data_points.py b/cognee/tasks/storage/add_data_points.py index 21cc5a3c2..540575b5b 100644 --- a/cognee/tasks/storage/add_data_points.py +++ b/cognee/tasks/storage/add_data_points.py @@ -5,7 +5,7 @@ from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_fr from .index_data_points import index_data_points -async def add_data_points(data_points: list[DataPoint], only_root=False): +async def add_data_points(data_points: list[DataPoint]): nodes = [] edges = [] @@ -20,7 +20,6 @@ async def add_data_points(data_points: list[DataPoint], only_root=False): added_nodes=added_nodes, added_edges=added_edges, visited_properties=visited_properties, - only_root=only_root, ) for data_point in data_points ] diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_tests.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_tests.py new file mode 100644 index 000000000..91132ac45 --- /dev/null +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_tests.py @@ -0,0 +1,120 @@ +import pytest +import asyncio +import random +from typing import List +from uuid import NAMESPACE_OID, uuid5 +from uuid import uuid4 + +from IPython.utils.wildcard import is_type + +from cognee.infrastructure.engine import DataPoint +from cognee.modules.engine.models.Entity import Entity, EntityType +from cognee.modules.chunking.models.DocumentChunk import DocumentChunk +from cognee.modules.data.processing.document_types import Document +from cognee.modules.graph.utils import get_graph_from_model + + +@pytest.mark.asyncio +async def test_get_graph_from_model_basic_initialization(): + """Test the basic behavior of get_graph_from_model with a simple data point - without connection.""" + data_point = DataPoint(id=uuid4(), attributes={"name": "Node1"}) + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model( + data_point, added_nodes, added_edges, visited_properties + ) + + assert len(nodes) == 1 + assert len(edges) == 0 + assert str(data_point.id) in added_nodes + + +@pytest.mark.asyncio +async def test_get_graph_from_model_with_single_neighbor(): + """Test the behavior of get_graph_from_model when a data point has a single DataPoint property.""" + type_node = EntityType( + id=uuid4(), + name="Vehicle", + description="This is a Vehicle node", + ) + + entity_node = Entity( + id=uuid4(), + name="Car", + is_a=type_node, + description="This is a car node", + ) + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model( + entity_node, added_nodes, added_edges, visited_properties + ) + + assert len(nodes) == 2 + assert len(edges) == 1 + assert str(entity_node.id) in added_nodes + assert str(type_node.id) in added_nodes + assert (str(entity_node.id) + str(type_node.id) + "is_a") in added_edges + + +@pytest.mark.asyncio +async def test_get_graph_from_model_with_multiple_nested_connections(): + """Test the behavior of get_graph_from_model when a data point has multiple nested DataPoint property.""" + type_node = EntityType( + id=uuid4(), + name="Transportation tool", + description="This is a Vehicle node", + ) + + entity_node_1 = Entity( + id=uuid4(), + name="Car", + is_a=type_node, + description="This is a car node", + ) + + entity_node_2 = Entity( + id=uuid4(), + name="Bus", + is_a=type_node, + description="This is a bus node", + ) + + document = Document( + name="main_document", raw_data_location="home/", metadata_id=uuid4(), mime_type="test" + ) + + chunk = DocumentChunk( + id=uuid4(), + word_count=8, + chunk_index=0, + cut_type="test", + text="The car and the bus are transportation tools", + is_part_of=document, + contains=[entity_node_1, entity_node_2], + ) + + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model(chunk, added_nodes, added_edges, visited_properties) + + assert len(nodes) == 5 + assert len(edges) == 5 + + assert str(entity_node_1.id) in added_nodes + assert str(entity_node_2.id) in added_nodes + assert str(type_node.id) in added_nodes + assert str(document.id) in added_nodes + assert str(chunk.id) in added_nodes + + assert (str(entity_node_1.id) + str(type_node.id) + "is_a") in added_edges + assert (str(entity_node_2.id) + str(type_node.id) + "is_a") in added_edges + assert (str(chunk.id) + str(document.id) + "is_part_of") in added_edges + assert (str(chunk.id) + str(entity_node_1.id) + "contains") in added_edges + assert (str(chunk.id) + str(entity_node_2.id) + "contains") in added_edges