feat: implement unit tests and extensive checks around the get_graph_from_model [COG-754] (#491)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **Tests**
  - Added comprehensive unit tests for graph model generation
- Introduced new test scenarios covering various data structures and
edge cases
  - Implemented tests for document, chunk, and entity relationships

- **Chores**
- Updated continuous deployment workflow to trigger only on `dev` branch

The release focuses on improving test coverage and refining the
deployment process.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
hajdul88 2025-01-31 18:17:23 +01:00 committed by GitHub
parent 8879f3fbbe
commit 2fd6bfa44c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 151 additions and 121 deletions

View file

@ -4,7 +4,6 @@ on:
push:
branches:
- dev
- feature/*
paths-ignore:
- '**.md'
- 'examples/**'

View file

@ -0,0 +1,151 @@
import pytest
from typing import List, Optional
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model
class Document(DataPoint):
path: str
_metadata = {"index_fields": [], "type": "Document"}
class DocumentChunk(DataPoint):
part_of: Document
text: str
contains: List["Entity"] = None
_metadata = {"index_fields": ["text"], "type": "DocumentChunk"}
class EntityType(DataPoint):
name: str
_metadata = {"index_fields": ["name"], "type": "EntityType"}
class Entity(DataPoint):
name: str
is_type: EntityType
_metadata = {"index_fields": ["name"], "type": "Entity"}
DocumentChunk.model_rebuild()
@pytest.mark.asyncio
async def test_get_graph_from_model_simple_structure():
"""Tests simple pydantic structure for get_graph_from_model"""
entitytype = EntityType(
name="TestType",
)
entity = Entity(name="TestEntity", is_type=entitytype)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(entity, added_nodes, added_edges, visited_properties)
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
assert len(edges) == 1, f"Expected 1 edges, got {len(edges)}"
edge_key = str(entity.id) + str(entitytype.id) + "is_type"
assert edge_key in added_edges, f"Edge {edge_key} not found"
@pytest.mark.asyncio
async def test_get_graph_from_model_with_document_and_chunk():
"""Tests multiple entities to document connection"""
doc = Document(path="test/path")
doc_chunk = DocumentChunk(part_of=doc, text="This is a chunk of text", contains=[])
entity_type = EntityType(name="Person")
entity = Entity(name="Alice", is_type=entity_type)
entity2 = Entity(name="Alice2", is_type=entity_type)
doc_chunk.contains.append(entity)
doc_chunk.contains.append(entity2)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(
doc_chunk, added_nodes, added_edges, visited_properties
)
assert len(nodes) == 5, f"Expected 5 nodes, got {len(nodes)}"
assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}"
@pytest.mark.asyncio
async def test_get_graph_from_model_duplicate_references():
"""Tests duplicated objects in document list"""
doc = Document(path="test/path")
doc_chunk = DocumentChunk(part_of=doc, text="Chunk with duplicates", contains=[])
entity_type = EntityType(name="Animal")
shared_entity = Entity(name="Cat", is_type=entity_type)
doc_chunk.contains.extend([shared_entity, shared_entity, shared_entity])
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(
doc_chunk, added_nodes, added_edges, visited_properties
)
assert len(nodes) == 4, f"Expected 4 nodes, got {len(nodes)}"
assert len(edges) == 3, f"Expected 3 edges, got {len(edges)}"
@pytest.mark.asyncio
async def test_get_graph_from_model_multi_level_nesting():
"""Tests multi level nested structure extraction"""
doc = Document(path="multi-level/path")
chunk1 = DocumentChunk(part_of=doc, text="Chunk 1 text", contains=[])
chunk2 = DocumentChunk(part_of=doc, text="Chunk 2 text", contains=[])
entity_type_vehicle = EntityType(name="Vehicle")
entity_type_person = EntityType(name="Person")
entity_car = Entity(name="Car", is_type=entity_type_vehicle)
entity_bike = Entity(name="Bike", is_type=entity_type_vehicle)
entity_alice = Entity(name="Alice", is_type=entity_type_person)
chunk1.contains.extend([entity_car, entity_bike])
chunk2.contains.append(entity_alice)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(chunk1, added_nodes, added_edges, visited_properties)
nodes2, edges2 = await get_graph_from_model(
chunk2, added_nodes, added_edges, visited_properties
)
all_nodes = nodes + nodes2
all_edges = edges + edges2
assert len(all_nodes) == 8, f"Expected 8 nodes, got {len(all_nodes)}"
assert len(all_edges) == 8, f"Expected 8 edges, got {len(all_edges)}"
@pytest.mark.asyncio
async def test_get_graph_from_model_no_contains():
"""Tests graph from model with empty contains element"""
doc = Document(path="empty-contains/path")
chunk = DocumentChunk(part_of=doc, text="A chunk with no entities", contains=[])
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(chunk, added_nodes, added_edges, visited_properties)
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}"

View file

@ -1,120 +0,0 @@
import pytest
import asyncio
import random
from typing import List
from uuid import NAMESPACE_OID, uuid5
from uuid import uuid4
from IPython.utils.wildcard import is_type
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models.Entity import Entity, EntityType
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types import Document
from cognee.modules.graph.utils import get_graph_from_model
@pytest.mark.asyncio
async def test_get_graph_from_model_basic_initialization():
"""Test the basic behavior of get_graph_from_model with a simple data point - without connection."""
data_point = DataPoint(id=uuid4(), attributes={"name": "Node1"})
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(
data_point, added_nodes, added_edges, visited_properties
)
assert len(nodes) == 1
assert len(edges) == 0
assert str(data_point.id) in added_nodes
@pytest.mark.asyncio
async def test_get_graph_from_model_with_single_neighbor():
"""Test the behavior of get_graph_from_model when a data point has a single DataPoint property."""
type_node = EntityType(
id=uuid4(),
name="Vehicle",
description="This is a Vehicle node",
)
entity_node = Entity(
id=uuid4(),
name="Car",
is_a=type_node,
description="This is a car node",
)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(
entity_node, added_nodes, added_edges, visited_properties
)
assert len(nodes) == 2
assert len(edges) == 1
assert str(entity_node.id) in added_nodes
assert str(type_node.id) in added_nodes
assert (str(entity_node.id) + str(type_node.id) + "is_a") in added_edges
@pytest.mark.asyncio
async def test_get_graph_from_model_with_multiple_nested_connections():
"""Test the behavior of get_graph_from_model when a data point has multiple nested DataPoint property."""
type_node = EntityType(
id=uuid4(),
name="Transportation tool",
description="This is a Vehicle node",
)
entity_node_1 = Entity(
id=uuid4(),
name="Car",
is_a=type_node,
description="This is a car node",
)
entity_node_2 = Entity(
id=uuid4(),
name="Bus",
is_a=type_node,
description="This is a bus node",
)
document = Document(
name="main_document", raw_data_location="home/", metadata_id=uuid4(), mime_type="test"
)
chunk = DocumentChunk(
id=uuid4(),
word_count=8,
chunk_index=0,
cut_type="test",
text="The car and the bus are transportation tools",
is_part_of=document,
contains=[entity_node_1, entity_node_2],
)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(chunk, added_nodes, added_edges, visited_properties)
assert len(nodes) == 5
assert len(edges) == 5
assert str(entity_node_1.id) in added_nodes
assert str(entity_node_2.id) in added_nodes
assert str(type_node.id) in added_nodes
assert str(document.id) in added_nodes
assert str(chunk.id) in added_nodes
assert (str(entity_node_1.id) + str(type_node.id) + "is_a") in added_edges
assert (str(entity_node_2.id) + str(type_node.id) + "is_a") in added_edges
assert (str(chunk.id) + str(document.id) + "is_part_of") in added_edges
assert (str(chunk.id) + str(entity_node_1.id) + "contains") in added_edges
assert (str(chunk.id) + str(entity_node_2.id) + "contains") in added_edges