feat: Use unwind for batch edge save and add unit tests for get_graph_from_model
* feat: add unit tests for get_graph_from_model
* feat: update the Neo4j add_edges Cypher query and remove the shallow (only_root) mode of get_graph_from_model
* fix: correct an incorrectly resolved merge conflict
* chore: delete the old only_root unit test
This commit is contained in:
parent
a79f7133fd
commit
f843c256e4
5 changed files with 137 additions and 11 deletions
|
|
@ -164,7 +164,7 @@ async def get_default_tasks(
|
||||||
summarization_model=cognee_config.summarization_model,
|
summarization_model=cognee_config.summarization_model,
|
||||||
task_config={"batch_size": 10},
|
task_config={"batch_size": 10},
|
||||||
),
|
),
|
||||||
Task(add_data_points, only_root=True, task_config={"batch_size": 10}),
|
Task(add_data_points, task_config={"batch_size": 10}),
|
||||||
Task(store_descriptive_metrics),
|
Task(store_descriptive_metrics),
|
||||||
]
|
]
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
|
|
|
||||||
|
|
@ -205,12 +205,20 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||||
query = """
|
query = """
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (from_node {id: edge.from_node})
|
MATCH (from_node {id: edge.from_node})
|
||||||
MATCH (to_node {id: edge.to_node})
|
MATCH (to_node {id: edge.to_node})
|
||||||
CALL apoc.create.relationship(from_node, edge.relationship_name, edge.properties, to_node) YIELD rel
|
CALL apoc.merge.relationship(
|
||||||
RETURN rel
|
from_node,
|
||||||
"""
|
edge.relationship_name,
|
||||||
|
{
|
||||||
|
source_node_id: edge.from_node,
|
||||||
|
target_node_id: edge.to_node
|
||||||
|
},
|
||||||
|
edge.properties,
|
||||||
|
to_node
|
||||||
|
) YIELD rel
|
||||||
|
RETURN rel"""
|
||||||
|
|
||||||
edges = [
|
edges = [
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ async def get_graph_from_model(
|
||||||
added_nodes: dict,
|
added_nodes: dict,
|
||||||
added_edges: dict,
|
added_edges: dict,
|
||||||
visited_properties: dict = None,
|
visited_properties: dict = None,
|
||||||
only_root=False,
|
|
||||||
include_root=True,
|
include_root=True,
|
||||||
):
|
):
|
||||||
if str(data_point.id) in added_nodes:
|
if str(data_point.id) in added_nodes:
|
||||||
|
|
@ -98,7 +97,7 @@ async def get_graph_from_model(
|
||||||
)
|
)
|
||||||
added_edges[str(edge_key)] = True
|
added_edges[str(edge_key)] = True
|
||||||
|
|
||||||
if str(field_value.id) in added_nodes or only_root:
|
if str(field_value.id) in added_nodes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
property_nodes, property_edges = await get_graph_from_model(
|
property_nodes, property_edges = await get_graph_from_model(
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_fr
|
||||||
from .index_data_points import index_data_points
|
from .index_data_points import index_data_points
|
||||||
|
|
||||||
|
|
||||||
async def add_data_points(data_points: list[DataPoint], only_root=False):
|
async def add_data_points(data_points: list[DataPoint]):
|
||||||
nodes = []
|
nodes = []
|
||||||
edges = []
|
edges = []
|
||||||
|
|
||||||
|
|
@ -20,7 +20,6 @@ async def add_data_points(data_points: list[DataPoint], only_root=False):
|
||||||
added_nodes=added_nodes,
|
added_nodes=added_nodes,
|
||||||
added_edges=added_edges,
|
added_edges=added_edges,
|
||||||
visited_properties=visited_properties,
|
visited_properties=visited_properties,
|
||||||
only_root=only_root,
|
|
||||||
)
|
)
|
||||||
for data_point in data_points
|
for data_point in data_points
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,120 @@
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
from typing import List
|
||||||
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from IPython.utils.wildcard import is_type
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.engine.models.Entity import Entity, EntityType
|
||||||
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
|
from cognee.modules.data.processing.document_types import Document
|
||||||
|
from cognee.modules.graph.utils import get_graph_from_model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_graph_from_model_basic_initialization():
|
||||||
|
"""Test the basic behavior of get_graph_from_model with a simple data point - without connection."""
|
||||||
|
data_point = DataPoint(id=uuid4(), attributes={"name": "Node1"})
|
||||||
|
added_nodes = {}
|
||||||
|
added_edges = {}
|
||||||
|
visited_properties = {}
|
||||||
|
|
||||||
|
nodes, edges = await get_graph_from_model(
|
||||||
|
data_point, added_nodes, added_edges, visited_properties
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(nodes) == 1
|
||||||
|
assert len(edges) == 0
|
||||||
|
assert str(data_point.id) in added_nodes
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_graph_from_model_with_single_neighbor():
|
||||||
|
"""Test the behavior of get_graph_from_model when a data point has a single DataPoint property."""
|
||||||
|
type_node = EntityType(
|
||||||
|
id=uuid4(),
|
||||||
|
name="Vehicle",
|
||||||
|
description="This is a Vehicle node",
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_node = Entity(
|
||||||
|
id=uuid4(),
|
||||||
|
name="Car",
|
||||||
|
is_a=type_node,
|
||||||
|
description="This is a car node",
|
||||||
|
)
|
||||||
|
added_nodes = {}
|
||||||
|
added_edges = {}
|
||||||
|
visited_properties = {}
|
||||||
|
|
||||||
|
nodes, edges = await get_graph_from_model(
|
||||||
|
entity_node, added_nodes, added_edges, visited_properties
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(nodes) == 2
|
||||||
|
assert len(edges) == 1
|
||||||
|
assert str(entity_node.id) in added_nodes
|
||||||
|
assert str(type_node.id) in added_nodes
|
||||||
|
assert (str(entity_node.id) + str(type_node.id) + "is_a") in added_edges
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_graph_from_model_with_multiple_nested_connections():
|
||||||
|
"""Test the behavior of get_graph_from_model when a data point has multiple nested DataPoint property."""
|
||||||
|
type_node = EntityType(
|
||||||
|
id=uuid4(),
|
||||||
|
name="Transportation tool",
|
||||||
|
description="This is a Vehicle node",
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_node_1 = Entity(
|
||||||
|
id=uuid4(),
|
||||||
|
name="Car",
|
||||||
|
is_a=type_node,
|
||||||
|
description="This is a car node",
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_node_2 = Entity(
|
||||||
|
id=uuid4(),
|
||||||
|
name="Bus",
|
||||||
|
is_a=type_node,
|
||||||
|
description="This is a bus node",
|
||||||
|
)
|
||||||
|
|
||||||
|
document = Document(
|
||||||
|
name="main_document", raw_data_location="home/", metadata_id=uuid4(), mime_type="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk = DocumentChunk(
|
||||||
|
id=uuid4(),
|
||||||
|
word_count=8,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="test",
|
||||||
|
text="The car and the bus are transportation tools",
|
||||||
|
is_part_of=document,
|
||||||
|
contains=[entity_node_1, entity_node_2],
|
||||||
|
)
|
||||||
|
|
||||||
|
added_nodes = {}
|
||||||
|
added_edges = {}
|
||||||
|
visited_properties = {}
|
||||||
|
|
||||||
|
nodes, edges = await get_graph_from_model(chunk, added_nodes, added_edges, visited_properties)
|
||||||
|
|
||||||
|
assert len(nodes) == 5
|
||||||
|
assert len(edges) == 5
|
||||||
|
|
||||||
|
assert str(entity_node_1.id) in added_nodes
|
||||||
|
assert str(entity_node_2.id) in added_nodes
|
||||||
|
assert str(type_node.id) in added_nodes
|
||||||
|
assert str(document.id) in added_nodes
|
||||||
|
assert str(chunk.id) in added_nodes
|
||||||
|
|
||||||
|
assert (str(entity_node_1.id) + str(type_node.id) + "is_a") in added_edges
|
||||||
|
assert (str(entity_node_2.id) + str(type_node.id) + "is_a") in added_edges
|
||||||
|
assert (str(chunk.id) + str(document.id) + "is_part_of") in added_edges
|
||||||
|
assert (str(chunk.id) + str(entity_node_1.id) + "contains") in added_edges
|
||||||
|
assert (str(chunk.id) + str(entity_node_2.id) + "contains") in added_edges
|
||||||
Loading…
Add table
Reference in a new issue