From 9c5f1a2686b91ec5589dcc774689ebc16624b0df Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Wed, 9 Jul 2025 00:33:23 +0200 Subject: [PATCH] fix: break circular data points in a graph model --- .../graph/utils/get_graph_from_model.py | 8 +- .../graph/get_graph_from_huge_model_test.py | 35 +++++--- .../get_graph_from_model_circular_test.py | 67 ++++++++++++++ .../graph/get_graph_from_model_flat_test.py | 79 ---------------- .../graph/get_graph_from_model_test.py | 89 ------------------- 5 files changed, 93 insertions(+), 185 deletions(-) create mode 100644 cognee/tests/unit/interfaces/graph/get_graph_from_model_circular_test.py delete mode 100644 cognee/tests/unit/interfaces/graph/get_graph_from_model_flat_test.py delete mode 100644 cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 5aea2de90..a6d50f41c 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -35,6 +35,7 @@ async def get_graph_from_model( if property_key in visited_properties: continue + visited_properties[property_key] = True properties_to_visit.add(field_name) continue @@ -52,6 +53,7 @@ async def get_graph_from_model( if property_key in visited_properties: continue + visited_properties[property_key] = True properties_to_visit.add(f"{field_name}.{index}") continue @@ -98,6 +100,9 @@ async def get_graph_from_model( if str(field_value.id) in added_nodes: continue + property_key = str(data_point.id) + field_name + str(field_value.id) + visited_properties[property_key] = True + property_nodes, property_edges = await get_graph_from_model( field_value, include_root=True, @@ -112,9 +117,6 @@ async def get_graph_from_model( for edge in property_edges: edges.append(edge) - property_key = str(data_point.id) + field_name + str(field_value.id) - visited_properties[property_key] = True - return nodes, edges diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py index c9bcf99dc..4a4619e19 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py @@ -26,7 +26,6 @@ class CodeFile(DataPoint): class CodePart(DataPoint): - part_of: CodeFile source_code: str metadata: dict = {"index_fields": []} @@ -67,27 +66,25 @@ async def test_circular_reference_extraction(): source_code="source code", part_of=repo, contains=[], - depends_on=[ - CodeFile( - id=uuid5(NAMESPACE_OID, f"file{random_id}"), - source_code="source code", - part_of=repo, - depends_on=[], - ) - for random_id in [random.randint(0, 1499) for _ in range(random.randint(0, 5))] - ], + depends_on=[], ) for file_index in range(1500) ] for code_file in code_files: + code_file.depends_on.extend( + [ + code_files[random.randint(0, len(code_files) - 1)] + for _ in range(2) + ] + ) code_file.contains.extend( [ CodePart( part_of=code_file, source_code=f"Part {part_index}", ) - for part_index in range(random.randint(1, 20)) + for part_index in range(2) ] ) @@ -96,12 +93,18 @@ async def test_circular_reference_extraction(): added_nodes = {} added_edges = {} + visited_properties = {} start = time.perf_counter_ns() results = await asyncio.gather( *[ - get_graph_from_model(code_file, added_nodes=added_nodes, added_edges=added_edges) + get_graph_from_model( + code_file, + added_nodes=added_nodes, + added_edges=added_edges, + visited_properties=visited_properties, + ) for code_file in code_files ] ) @@ -114,8 +117,12 @@ async def test_circular_reference_extraction(): nodes.extend(result_nodes) edges.extend(result_edges) - assert len(nodes) == 1501 - assert len(edges) == 1501 * 20 + 1500 * 5 + code_files = [node for node in nodes if node.type == "CodeFile"] + code_parts = [node for node in nodes if node.type == "CodePart"] + + assert len(code_files) == 1500 + assert len(code_parts) == 3000 + assert len(edges) == 7500 if __name__ == "__main__": diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_circular_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_circular_test.py new file mode 100644 index 000000000..8bc05f1c0 --- /dev/null +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_circular_test.py @@ -0,0 +1,67 @@ +import random +import pytest +import asyncio +from typing import List +from uuid import NAMESPACE_OID, uuid5 + + +from cognee.infrastructure.engine import DataPoint +from cognee.modules.graph.utils import get_graph_from_model + +random.seed(1500) + + +class Repository(DataPoint): + path: str + metadata: dict = {"index_fields": []} + + +class CodeFile(DataPoint): + part_of: Repository + contains: List["CodePart"] = [] + depends_on: List["CodeFile"] = [] + source_code: str + metadata: dict = {"index_fields": []} + + +class CodePart(DataPoint): + part_of: CodeFile + source_code: str + metadata: dict = {"index_fields": []} + + +CodeFile.model_rebuild() +CodePart.model_rebuild() + + +@pytest.mark.asyncio +async def test_circular_reference_extraction(): + repo = Repository(path="repo1") + + code_file_1 = CodeFile( + id=uuid5(NAMESPACE_OID, f"file_0"), + source_code="source code", + part_of=repo, + contains=[], + depends_on=[], + ) + code_part_1 = CodePart(source_code="part_0", part_of=code_file_1) + code_file_1.contains.append(code_part_1) + + added_nodes = {} + added_edges = {} + visited_properties = {} + + nodes, edges = await get_graph_from_model( + code_file_1, + added_nodes=added_nodes, + added_edges=added_edges, + visited_properties=visited_properties, + ) + + assert len(nodes) == 3 + assert len(edges) == 3 + + +if __name__ == "__main__": + asyncio.run(test_circular_reference_extraction()) diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_flat_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_flat_test.py deleted file mode 100644 index e115b8c9b..000000000 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_model_flat_test.py +++ /dev/null @@ -1,79 +0,0 @@ -import pytest -import asyncio -from typing import List -from uuid import NAMESPACE_OID, uuid5 - - -from cognee.infrastructure.engine import DataPoint -from cognee.modules.graph.utils import get_graph_from_model - - -class Document(DataPoint): - path: str - metadata: dict = {"index_fields": []} - - -class DocumentChunk(DataPoint): - part_of: Document - text: str - contains: List["Entity"] = None - metadata: dict = {"index_fields": ["text"]} - - -class EntityType(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - -class Entity(DataPoint): - name: str - is_type: EntityType - metadata: dict = {"index_fields": ["name"]} - - -DocumentChunk.model_rebuild() - - -@pytest.mark.asyncio -async def get_graph_from_model_test(): - document = Document(path="file_path") - - document_chunk = DocumentChunk( - id=uuid5(NAMESPACE_OID, "file_name"), - text="some text", - part_of=document, - contains=[], - ) - - document_chunk.contains.append( - Entity( - name="Entity", - is_type=EntityType( - name="Type 1", - ), - ) - ) - - added_nodes = {} - added_edges = {} - visited_properties = {} - - result = await get_graph_from_model( - document_chunk, - added_nodes=added_nodes, - added_edges=added_edges, - visited_properties=visited_properties, - ) - - nodes = result[0] - edges = result[1] - - assert len(nodes) == 4 - assert len(edges) == 3 - - document_chunk_node = next(filter(lambda node: node.type == "DocumentChunk", nodes)) - assert not hasattr(document_chunk_node, "part_of"), "Expected part_of attribute to be removed" - - -if __name__ == "__main__": - asyncio.run(get_graph_from_model_test()) diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py deleted file mode 100644 index 6b6837eaf..000000000 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py +++ /dev/null @@ -1,89 +0,0 @@ -import pytest -import asyncio -from typing import List -from uuid import NAMESPACE_OID, uuid5 - - -from cognee.infrastructure.engine import DataPoint -from cognee.modules.graph.utils import get_graph_from_model - - -class Document(DataPoint): - path: str - metadata: dict = {"index_fields": []} - - -class DocumentChunk(DataPoint): - part_of: Document - text: str - contains: List["Entity"] = None - metadata: dict = {"index_fields": ["text"]} - - -class EntityType(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - -class Entity(DataPoint): - name: str - is_type: EntityType - metadata: dict = {"index_fields": ["name"]} - - -DocumentChunk.model_rebuild() - - -@pytest.mark.asyncio -async def get_graph_from_model_test(): - document = Document(path="file_path") - - document_chunks = [ - DocumentChunk( - id=uuid5(NAMESPACE_OID, f"file{file_index}"), - text="some text", - part_of=document, - contains=[], - ) - for file_index in range(1) - ] - - for document_chunk in document_chunks: - document_chunk.contains.append( - Entity( - name="Entity", - is_type=EntityType( - name="Type 1", - ), - ) - ) - - nodes = [] - edges = [] - - added_nodes = {} - added_edges = {} - visited_properties = {} - - results = await asyncio.gather( - *[ - get_graph_from_model( - document_chunk, - added_nodes=added_nodes, - added_edges=added_edges, - visited_properties=visited_properties, - ) - for document_chunk in document_chunks - ] - ) - - for result_nodes, result_edges in results: - nodes.extend(result_nodes) - edges.extend(result_edges) - - assert len(nodes) == 4 - assert len(edges) == 3 - - -if __name__ == "__main__": - asyncio.run(get_graph_from_model_test())