feat: implement unit tests and extensive checks around the get_graph_from_model [COG-754] (#491)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **Tests** - Added comprehensive unit tests for graph model generation - Introduced new test scenarios covering various data structures and edge cases - Implemented tests for document, chunk, and entity relationships - **Chores** - Updated continuous deployment workflow to trigger only on `dev` branch The release focuses on improving test coverage and refining the deployment process. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
8879f3fbbe
commit
2fd6bfa44c
3 changed files with 151 additions and 121 deletions
1
.github/workflows/cd.yaml
vendored
1
.github/workflows/cd.yaml
vendored
|
|
@ -4,7 +4,6 @@ on:
|
|||
push:
|
||||
branches:
|
||||
- dev
|
||||
- feature/*
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'examples/**'
|
||||
|
|
|
|||
|
|
@ -0,0 +1,151 @@
|
|||
import pytest
|
||||
from typing import List, Optional
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
|
||||
|
||||
class Document(DataPoint):
|
||||
path: str
|
||||
_metadata = {"index_fields": [], "type": "Document"}
|
||||
|
||||
|
||||
class DocumentChunk(DataPoint):
|
||||
part_of: Document
|
||||
text: str
|
||||
contains: List["Entity"] = None
|
||||
_metadata = {"index_fields": ["text"], "type": "DocumentChunk"}
|
||||
|
||||
|
||||
class EntityType(DataPoint):
|
||||
name: str
|
||||
_metadata = {"index_fields": ["name"], "type": "EntityType"}
|
||||
|
||||
|
||||
class Entity(DataPoint):
|
||||
name: str
|
||||
is_type: EntityType
|
||||
_metadata = {"index_fields": ["name"], "type": "Entity"}
|
||||
|
||||
|
||||
DocumentChunk.model_rebuild()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_simple_structure():
|
||||
"""Tests simple pydantic structure for get_graph_from_model"""
|
||||
|
||||
entitytype = EntityType(
|
||||
name="TestType",
|
||||
)
|
||||
|
||||
entity = Entity(name="TestEntity", is_type=entitytype)
|
||||
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(entity, added_nodes, added_edges, visited_properties)
|
||||
|
||||
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
|
||||
assert len(edges) == 1, f"Expected 1 edges, got {len(edges)}"
|
||||
|
||||
edge_key = str(entity.id) + str(entitytype.id) + "is_type"
|
||||
assert edge_key in added_edges, f"Edge {edge_key} not found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_with_document_and_chunk():
|
||||
"""Tests multiple entities to document connection"""
|
||||
doc = Document(path="test/path")
|
||||
doc_chunk = DocumentChunk(part_of=doc, text="This is a chunk of text", contains=[])
|
||||
entity_type = EntityType(name="Person")
|
||||
entity = Entity(name="Alice", is_type=entity_type)
|
||||
entity2 = Entity(name="Alice2", is_type=entity_type)
|
||||
doc_chunk.contains.append(entity)
|
||||
doc_chunk.contains.append(entity2)
|
||||
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(
|
||||
doc_chunk, added_nodes, added_edges, visited_properties
|
||||
)
|
||||
|
||||
assert len(nodes) == 5, f"Expected 5 nodes, got {len(nodes)}"
|
||||
assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_duplicate_references():
|
||||
"""Tests duplicated objects in document list"""
|
||||
doc = Document(path="test/path")
|
||||
doc_chunk = DocumentChunk(part_of=doc, text="Chunk with duplicates", contains=[])
|
||||
|
||||
entity_type = EntityType(name="Animal")
|
||||
shared_entity = Entity(name="Cat", is_type=entity_type)
|
||||
|
||||
doc_chunk.contains.extend([shared_entity, shared_entity, shared_entity])
|
||||
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(
|
||||
doc_chunk, added_nodes, added_edges, visited_properties
|
||||
)
|
||||
|
||||
assert len(nodes) == 4, f"Expected 4 nodes, got {len(nodes)}"
|
||||
assert len(edges) == 3, f"Expected 3 edges, got {len(edges)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_multi_level_nesting():
|
||||
"""Tests multi level nested structure extraction"""
|
||||
doc = Document(path="multi-level/path")
|
||||
|
||||
chunk1 = DocumentChunk(part_of=doc, text="Chunk 1 text", contains=[])
|
||||
chunk2 = DocumentChunk(part_of=doc, text="Chunk 2 text", contains=[])
|
||||
|
||||
entity_type_vehicle = EntityType(name="Vehicle")
|
||||
entity_type_person = EntityType(name="Person")
|
||||
|
||||
entity_car = Entity(name="Car", is_type=entity_type_vehicle)
|
||||
entity_bike = Entity(name="Bike", is_type=entity_type_vehicle)
|
||||
entity_alice = Entity(name="Alice", is_type=entity_type_person)
|
||||
|
||||
chunk1.contains.extend([entity_car, entity_bike])
|
||||
chunk2.contains.append(entity_alice)
|
||||
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(chunk1, added_nodes, added_edges, visited_properties)
|
||||
|
||||
nodes2, edges2 = await get_graph_from_model(
|
||||
chunk2, added_nodes, added_edges, visited_properties
|
||||
)
|
||||
|
||||
all_nodes = nodes + nodes2
|
||||
all_edges = edges + edges2
|
||||
|
||||
assert len(all_nodes) == 8, f"Expected 8 nodes, got {len(all_nodes)}"
|
||||
assert len(all_edges) == 8, f"Expected 8 edges, got {len(all_edges)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_no_contains():
|
||||
"""Tests graph from model with empty contains element"""
|
||||
doc = Document(path="empty-contains/path")
|
||||
chunk = DocumentChunk(part_of=doc, text="A chunk with no entities", contains=[])
|
||||
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(chunk, added_nodes, added_edges, visited_properties)
|
||||
|
||||
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
|
||||
assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}"
|
||||
|
|
@ -1,120 +0,0 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
import random
|
||||
from typing import List
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
from uuid import uuid4
|
||||
|
||||
from IPython.utils.wildcard import is_type
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.engine.models.Entity import Entity, EntityType
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_basic_initialization():
|
||||
"""Test the basic behavior of get_graph_from_model with a simple data point - without connection."""
|
||||
data_point = DataPoint(id=uuid4(), attributes={"name": "Node1"})
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(
|
||||
data_point, added_nodes, added_edges, visited_properties
|
||||
)
|
||||
|
||||
assert len(nodes) == 1
|
||||
assert len(edges) == 0
|
||||
assert str(data_point.id) in added_nodes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_with_single_neighbor():
|
||||
"""Test the behavior of get_graph_from_model when a data point has a single DataPoint property."""
|
||||
type_node = EntityType(
|
||||
id=uuid4(),
|
||||
name="Vehicle",
|
||||
description="This is a Vehicle node",
|
||||
)
|
||||
|
||||
entity_node = Entity(
|
||||
id=uuid4(),
|
||||
name="Car",
|
||||
is_a=type_node,
|
||||
description="This is a car node",
|
||||
)
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(
|
||||
entity_node, added_nodes, added_edges, visited_properties
|
||||
)
|
||||
|
||||
assert len(nodes) == 2
|
||||
assert len(edges) == 1
|
||||
assert str(entity_node.id) in added_nodes
|
||||
assert str(type_node.id) in added_nodes
|
||||
assert (str(entity_node.id) + str(type_node.id) + "is_a") in added_edges
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_from_model_with_multiple_nested_connections():
|
||||
"""Test the behavior of get_graph_from_model when a data point has multiple nested DataPoint property."""
|
||||
type_node = EntityType(
|
||||
id=uuid4(),
|
||||
name="Transportation tool",
|
||||
description="This is a Vehicle node",
|
||||
)
|
||||
|
||||
entity_node_1 = Entity(
|
||||
id=uuid4(),
|
||||
name="Car",
|
||||
is_a=type_node,
|
||||
description="This is a car node",
|
||||
)
|
||||
|
||||
entity_node_2 = Entity(
|
||||
id=uuid4(),
|
||||
name="Bus",
|
||||
is_a=type_node,
|
||||
description="This is a bus node",
|
||||
)
|
||||
|
||||
document = Document(
|
||||
name="main_document", raw_data_location="home/", metadata_id=uuid4(), mime_type="test"
|
||||
)
|
||||
|
||||
chunk = DocumentChunk(
|
||||
id=uuid4(),
|
||||
word_count=8,
|
||||
chunk_index=0,
|
||||
cut_type="test",
|
||||
text="The car and the bus are transportation tools",
|
||||
is_part_of=document,
|
||||
contains=[entity_node_1, entity_node_2],
|
||||
)
|
||||
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
nodes, edges = await get_graph_from_model(chunk, added_nodes, added_edges, visited_properties)
|
||||
|
||||
assert len(nodes) == 5
|
||||
assert len(edges) == 5
|
||||
|
||||
assert str(entity_node_1.id) in added_nodes
|
||||
assert str(entity_node_2.id) in added_nodes
|
||||
assert str(type_node.id) in added_nodes
|
||||
assert str(document.id) in added_nodes
|
||||
assert str(chunk.id) in added_nodes
|
||||
|
||||
assert (str(entity_node_1.id) + str(type_node.id) + "is_a") in added_edges
|
||||
assert (str(entity_node_2.id) + str(type_node.id) + "is_a") in added_edges
|
||||
assert (str(chunk.id) + str(document.id) + "is_part_of") in added_edges
|
||||
assert (str(chunk.id) + str(entity_node_1.id) + "contains") in added_edges
|
||||
assert (str(chunk.id) + str(entity_node_2.id) + "contains") in added_edges
|
||||
Loading…
Add table
Reference in a new issue