fix: break circular data points in a graph model

This commit is contained in:
Boris Arzentar 2025-07-09 00:33:23 +02:00
parent b99b0455b0
commit 9c5f1a2686
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
5 changed files with 93 additions and 185 deletions

View file

@ -35,6 +35,7 @@ async def get_graph_from_model(
if property_key in visited_properties:
continue
visited_properties[property_key] = True
properties_to_visit.add(field_name)
continue
@ -52,6 +53,7 @@ async def get_graph_from_model(
if property_key in visited_properties:
continue
visited_properties[property_key] = True
properties_to_visit.add(f"{field_name}.{index}")
continue
@ -98,6 +100,9 @@ async def get_graph_from_model(
if str(field_value.id) in added_nodes:
continue
property_key = str(data_point.id) + field_name + str(field_value.id)
visited_properties[property_key] = True
property_nodes, property_edges = await get_graph_from_model(
field_value,
include_root=True,
@ -112,9 +117,6 @@ async def get_graph_from_model(
for edge in property_edges:
edges.append(edge)
property_key = str(data_point.id) + field_name + str(field_value.id)
visited_properties[property_key] = True
return nodes, edges

View file

@ -26,7 +26,6 @@ class CodeFile(DataPoint):
class CodePart(DataPoint):
part_of: CodeFile
source_code: str
metadata: dict = {"index_fields": []}
@ -67,27 +66,25 @@ async def test_circular_reference_extraction():
source_code="source code",
part_of=repo,
contains=[],
depends_on=[
CodeFile(
id=uuid5(NAMESPACE_OID, f"file{random_id}"),
source_code="source code",
part_of=repo,
depends_on=[],
)
for random_id in [random.randint(0, 1499) for _ in range(random.randint(0, 5))]
],
depends_on=[],
)
for file_index in range(1500)
]
for code_file in code_files:
code_file.depends_on.extend(
[
code_files[random.randint(0, len(code_files) - 1)]
for _ in range(2)
]
)
code_file.contains.extend(
[
CodePart(
part_of=code_file,
source_code=f"Part {part_index}",
)
for part_index in range(random.randint(1, 20))
for part_index in range(2)
]
)
@ -96,12 +93,18 @@ async def test_circular_reference_extraction():
added_nodes = {}
added_edges = {}
visited_properties = {}
start = time.perf_counter_ns()
results = await asyncio.gather(
*[
get_graph_from_model(code_file, added_nodes=added_nodes, added_edges=added_edges)
get_graph_from_model(
code_file,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
for code_file in code_files
]
)
@ -114,8 +117,12 @@ async def test_circular_reference_extraction():
nodes.extend(result_nodes)
edges.extend(result_edges)
assert len(nodes) == 1501
assert len(edges) == 1501 * 20 + 1500 * 5
code_files = [node for node in nodes if node.type == "CodeFile"]
code_parts = [node for node in nodes if node.type == "CodePart"]
assert len(code_files) == 1500
assert len(code_parts) == 3000
assert len(edges) == 7500
if __name__ == "__main__":

View file

@ -0,0 +1,67 @@
import random
import pytest
import asyncio
from typing import List
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model
random.seed(1500)
class Repository(DataPoint):
path: str
metadata: dict = {"index_fields": []}
class CodeFile(DataPoint):
part_of: Repository
contains: List["CodePart"] = []
depends_on: List["CodeFile"] = []
source_code: str
metadata: dict = {"index_fields": []}
class CodePart(DataPoint):
part_of: CodeFile
source_code: str
metadata: dict = {"index_fields": []}
CodeFile.model_rebuild()
CodePart.model_rebuild()
@pytest.mark.asyncio
async def test_circular_reference_extraction():
repo = Repository(path="repo1")
code_file_1 = CodeFile(
id=uuid5(NAMESPACE_OID, f"file_0"),
source_code="source code",
part_of=repo,
contains=[],
depends_on=[],
)
code_part_1 = CodePart(source_code="part_0", part_of=code_file_1)
code_file_1.contains.append(code_part_1)
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = await get_graph_from_model(
code_file_1,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
assert len(nodes) == 3
assert len(edges) == 3
if __name__ == "__main__":
asyncio.run(test_circular_reference_extraction())

View file

@ -1,79 +0,0 @@
import pytest
import asyncio
from typing import List
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model
class Document(DataPoint):
path: str
metadata: dict = {"index_fields": []}
class DocumentChunk(DataPoint):
part_of: Document
text: str
contains: List["Entity"] = None
metadata: dict = {"index_fields": ["text"]}
class EntityType(DataPoint):
name: str
metadata: dict = {"index_fields": ["name"]}
class Entity(DataPoint):
name: str
is_type: EntityType
metadata: dict = {"index_fields": ["name"]}
DocumentChunk.model_rebuild()
@pytest.mark.asyncio
async def get_graph_from_model_test():
document = Document(path="file_path")
document_chunk = DocumentChunk(
id=uuid5(NAMESPACE_OID, "file_name"),
text="some text",
part_of=document,
contains=[],
)
document_chunk.contains.append(
Entity(
name="Entity",
is_type=EntityType(
name="Type 1",
),
)
)
added_nodes = {}
added_edges = {}
visited_properties = {}
result = await get_graph_from_model(
document_chunk,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
nodes = result[0]
edges = result[1]
assert len(nodes) == 4
assert len(edges) == 3
document_chunk_node = next(filter(lambda node: node.type == "DocumentChunk", nodes))
assert not hasattr(document_chunk_node, "part_of"), "Expected part_of attribute to be removed"
if __name__ == "__main__":
asyncio.run(get_graph_from_model_test())

View file

@ -1,89 +0,0 @@
import pytest
import asyncio
from typing import List
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model
class Document(DataPoint):
path: str
metadata: dict = {"index_fields": []}
class DocumentChunk(DataPoint):
part_of: Document
text: str
contains: List["Entity"] = None
metadata: dict = {"index_fields": ["text"]}
class EntityType(DataPoint):
name: str
metadata: dict = {"index_fields": ["name"]}
class Entity(DataPoint):
name: str
is_type: EntityType
metadata: dict = {"index_fields": ["name"]}
DocumentChunk.model_rebuild()
@pytest.mark.asyncio
async def get_graph_from_model_test():
document = Document(path="file_path")
document_chunks = [
DocumentChunk(
id=uuid5(NAMESPACE_OID, f"file{file_index}"),
text="some text",
part_of=document,
contains=[],
)
for file_index in range(1)
]
for document_chunk in document_chunks:
document_chunk.contains.append(
Entity(
name="Entity",
is_type=EntityType(
name="Type 1",
),
)
)
nodes = []
edges = []
added_nodes = {}
added_edges = {}
visited_properties = {}
results = await asyncio.gather(
*[
get_graph_from_model(
document_chunk,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
for document_chunk in document_chunks
]
)
for result_nodes, result_edges in results:
nodes.extend(result_nodes)
edges.extend(result_edges)
assert len(nodes) == 4
assert len(edges) == 3
if __name__ == "__main__":
asyncio.run(get_graph_from_model_test())