feat: Use unwind for batch edge save and add unit tests for get_graph_from_model
* feat: adds some unit tests for get_graph_from_model * feat: updates neo4j add_edges cypher and deletes shallow get_graph_from_model * fix: fixing merge conflict false resolve * chore: deletes old only_root unit test
This commit is contained in:
parent
a79f7133fd
commit
f843c256e4
5 changed files with 137 additions and 11 deletions
|
|
@ -164,7 +164,7 @@ async def get_default_tasks(
|
|||
summarization_model=cognee_config.summarization_model,
|
||||
task_config={"batch_size": 10},
|
||||
),
|
||||
Task(add_data_points, only_root=True, task_config={"batch_size": 10}),
|
||||
Task(add_data_points, task_config={"batch_size": 10}),
|
||||
Task(store_descriptive_metrics),
|
||||
]
|
||||
except Exception as error:
|
||||
|
|
|
|||
|
|
@ -205,12 +205,20 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
query = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (from_node {id: edge.from_node})
|
||||
MATCH (to_node {id: edge.to_node})
|
||||
CALL apoc.create.relationship(from_node, edge.relationship_name, edge.properties, to_node) YIELD rel
|
||||
RETURN rel
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (from_node {id: edge.from_node})
|
||||
MATCH (to_node {id: edge.to_node})
|
||||
CALL apoc.merge.relationship(
|
||||
from_node,
|
||||
edge.relationship_name,
|
||||
{
|
||||
source_node_id: edge.from_node,
|
||||
target_node_id: edge.to_node
|
||||
},
|
||||
edge.properties,
|
||||
to_node
|
||||
) YIELD rel
|
||||
RETURN rel"""
|
||||
|
||||
edges = [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ async def get_graph_from_model(
|
|||
added_nodes: dict,
|
||||
added_edges: dict,
|
||||
visited_properties: dict = None,
|
||||
only_root=False,
|
||||
include_root=True,
|
||||
):
|
||||
if str(data_point.id) in added_nodes:
|
||||
|
|
@ -98,7 +97,7 @@ async def get_graph_from_model(
|
|||
)
|
||||
added_edges[str(edge_key)] = True
|
||||
|
||||
if str(field_value.id) in added_nodes or only_root:
|
||||
if str(field_value.id) in added_nodes:
|
||||
continue
|
||||
|
||||
property_nodes, property_edges = await get_graph_from_model(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_fr
|
|||
from .index_data_points import index_data_points
|
||||
|
||||
|
||||
async def add_data_points(data_points: list[DataPoint], only_root=False):
|
||||
async def add_data_points(data_points: list[DataPoint]):
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
|
|
@ -20,7 +20,6 @@ async def add_data_points(data_points: list[DataPoint], only_root=False):
|
|||
added_nodes=added_nodes,
|
||||
added_edges=added_edges,
|
||||
visited_properties=visited_properties,
|
||||
only_root=only_root,
|
||||
)
|
||||
for data_point in data_points
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,120 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
import random
|
||||
from typing import List
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
from uuid import uuid4
|
||||
|
||||
from IPython.utils.wildcard import is_type
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.engine.models.Entity import Entity, EntityType
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_basic_initialization():
|
||||
"""Test the basic behavior of get_graph_from_model with a simple data point - without connection."""
|
||||
data_point = DataPoint(id=uuid4(), attributes={"name": "Node1"})
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(
|
||||
data_point, added_nodes, added_edges, visited_properties
|
||||
)
|
||||
|
||||
assert len(nodes) == 1
|
||||
assert len(edges) == 0
|
||||
assert str(data_point.id) in added_nodes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_with_single_neighbor():
|
||||
"""Test the behavior of get_graph_from_model when a data point has a single DataPoint property."""
|
||||
type_node = EntityType(
|
||||
id=uuid4(),
|
||||
name="Vehicle",
|
||||
description="This is a Vehicle node",
|
||||
)
|
||||
|
||||
entity_node = Entity(
|
||||
id=uuid4(),
|
||||
name="Car",
|
||||
is_a=type_node,
|
||||
description="This is a car node",
|
||||
)
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(
|
||||
entity_node, added_nodes, added_edges, visited_properties
|
||||
)
|
||||
|
||||
assert len(nodes) == 2
|
||||
assert len(edges) == 1
|
||||
assert str(entity_node.id) in added_nodes
|
||||
assert str(type_node.id) in added_nodes
|
||||
assert (str(entity_node.id) + str(type_node.id) + "is_a") in added_edges
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_with_multiple_nested_connections():
|
||||
"""Test the behavior of get_graph_from_model when a data point has multiple nested DataPoint property."""
|
||||
type_node = EntityType(
|
||||
id=uuid4(),
|
||||
name="Transportation tool",
|
||||
description="This is a Vehicle node",
|
||||
)
|
||||
|
||||
entity_node_1 = Entity(
|
||||
id=uuid4(),
|
||||
name="Car",
|
||||
is_a=type_node,
|
||||
description="This is a car node",
|
||||
)
|
||||
|
||||
entity_node_2 = Entity(
|
||||
id=uuid4(),
|
||||
name="Bus",
|
||||
is_a=type_node,
|
||||
description="This is a bus node",
|
||||
)
|
||||
|
||||
document = Document(
|
||||
name="main_document", raw_data_location="home/", metadata_id=uuid4(), mime_type="test"
|
||||
)
|
||||
|
||||
chunk = DocumentChunk(
|
||||
id=uuid4(),
|
||||
word_count=8,
|
||||
chunk_index=0,
|
||||
cut_type="test",
|
||||
text="The car and the bus are transportation tools",
|
||||
is_part_of=document,
|
||||
contains=[entity_node_1, entity_node_2],
|
||||
)
|
||||
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(chunk, added_nodes, added_edges, visited_properties)
|
||||
|
||||
assert len(nodes) == 5
|
||||
assert len(edges) == 5
|
||||
|
||||
assert str(entity_node_1.id) in added_nodes
|
||||
assert str(entity_node_2.id) in added_nodes
|
||||
assert str(type_node.id) in added_nodes
|
||||
assert str(document.id) in added_nodes
|
||||
assert str(chunk.id) in added_nodes
|
||||
|
||||
assert (str(entity_node_1.id) + str(type_node.id) + "is_a") in added_edges
|
||||
assert (str(entity_node_2.id) + str(type_node.id) + "is_a") in added_edges
|
||||
assert (str(chunk.id) + str(document.id) + "is_part_of") in added_edges
|
||||
assert (str(chunk.id) + str(entity_node_1.id) + "contains") in added_edges
|
||||
assert (str(chunk.id) + str(entity_node_2.id) + "contains") in added_edges
|
||||
Loading…
Add table
Reference in a new issue