fix: break circular data points in a graph model
This commit is contained in:
parent
b99b0455b0
commit
9c5f1a2686
5 changed files with 93 additions and 185 deletions
|
|
@ -35,6 +35,7 @@ async def get_graph_from_model(
|
||||||
if property_key in visited_properties:
|
if property_key in visited_properties:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
visited_properties[property_key] = True
|
||||||
properties_to_visit.add(field_name)
|
properties_to_visit.add(field_name)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
@ -52,6 +53,7 @@ async def get_graph_from_model(
|
||||||
if property_key in visited_properties:
|
if property_key in visited_properties:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
visited_properties[property_key] = True
|
||||||
properties_to_visit.add(f"{field_name}.{index}")
|
properties_to_visit.add(f"{field_name}.{index}")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
@ -98,6 +100,9 @@ async def get_graph_from_model(
|
||||||
if str(field_value.id) in added_nodes:
|
if str(field_value.id) in added_nodes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
property_key = str(data_point.id) + field_name + str(field_value.id)
|
||||||
|
visited_properties[property_key] = True
|
||||||
|
|
||||||
property_nodes, property_edges = await get_graph_from_model(
|
property_nodes, property_edges = await get_graph_from_model(
|
||||||
field_value,
|
field_value,
|
||||||
include_root=True,
|
include_root=True,
|
||||||
|
|
@ -112,9 +117,6 @@ async def get_graph_from_model(
|
||||||
for edge in property_edges:
|
for edge in property_edges:
|
||||||
edges.append(edge)
|
edges.append(edge)
|
||||||
|
|
||||||
property_key = str(data_point.id) + field_name + str(field_value.id)
|
|
||||||
visited_properties[property_key] = True
|
|
||||||
|
|
||||||
return nodes, edges
|
return nodes, edges
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@ class CodeFile(DataPoint):
|
||||||
|
|
||||||
|
|
||||||
class CodePart(DataPoint):
|
class CodePart(DataPoint):
|
||||||
part_of: CodeFile
|
|
||||||
source_code: str
|
source_code: str
|
||||||
metadata: dict = {"index_fields": []}
|
metadata: dict = {"index_fields": []}
|
||||||
|
|
||||||
|
|
@ -67,27 +66,25 @@ async def test_circular_reference_extraction():
|
||||||
source_code="source code",
|
source_code="source code",
|
||||||
part_of=repo,
|
part_of=repo,
|
||||||
contains=[],
|
contains=[],
|
||||||
depends_on=[
|
depends_on=[],
|
||||||
CodeFile(
|
|
||||||
id=uuid5(NAMESPACE_OID, f"file{random_id}"),
|
|
||||||
source_code="source code",
|
|
||||||
part_of=repo,
|
|
||||||
depends_on=[],
|
|
||||||
)
|
|
||||||
for random_id in [random.randint(0, 1499) for _ in range(random.randint(0, 5))]
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
for file_index in range(1500)
|
for file_index in range(1500)
|
||||||
]
|
]
|
||||||
|
|
||||||
for code_file in code_files:
|
for code_file in code_files:
|
||||||
|
code_file.depends_on.extend(
|
||||||
|
[
|
||||||
|
code_files[random.randint(0, len(code_files) - 1)]
|
||||||
|
for _ in range(2)
|
||||||
|
]
|
||||||
|
)
|
||||||
code_file.contains.extend(
|
code_file.contains.extend(
|
||||||
[
|
[
|
||||||
CodePart(
|
CodePart(
|
||||||
part_of=code_file,
|
part_of=code_file,
|
||||||
source_code=f"Part {part_index}",
|
source_code=f"Part {part_index}",
|
||||||
)
|
)
|
||||||
for part_index in range(random.randint(1, 20))
|
for part_index in range(2)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -96,12 +93,18 @@ async def test_circular_reference_extraction():
|
||||||
|
|
||||||
added_nodes = {}
|
added_nodes = {}
|
||||||
added_edges = {}
|
added_edges = {}
|
||||||
|
visited_properties = {}
|
||||||
|
|
||||||
start = time.perf_counter_ns()
|
start = time.perf_counter_ns()
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
get_graph_from_model(code_file, added_nodes=added_nodes, added_edges=added_edges)
|
get_graph_from_model(
|
||||||
|
code_file,
|
||||||
|
added_nodes=added_nodes,
|
||||||
|
added_edges=added_edges,
|
||||||
|
visited_properties=visited_properties,
|
||||||
|
)
|
||||||
for code_file in code_files
|
for code_file in code_files
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -114,8 +117,12 @@ async def test_circular_reference_extraction():
|
||||||
nodes.extend(result_nodes)
|
nodes.extend(result_nodes)
|
||||||
edges.extend(result_edges)
|
edges.extend(result_edges)
|
||||||
|
|
||||||
assert len(nodes) == 1501
|
code_files = [node for node in nodes if node.type == "CodeFile"]
|
||||||
assert len(edges) == 1501 * 20 + 1500 * 5
|
code_parts = [node for node in nodes if node.type == "CodePart"]
|
||||||
|
|
||||||
|
assert len(code_files) == 1500
|
||||||
|
assert len(code_parts) == 3000
|
||||||
|
assert len(edges) == 7500
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
import random
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from typing import List
|
||||||
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.graph.utils import get_graph_from_model
|
||||||
|
|
||||||
|
random.seed(1500)
|
||||||
|
|
||||||
|
|
||||||
|
class Repository(DataPoint):
|
||||||
|
path: str
|
||||||
|
metadata: dict = {"index_fields": []}
|
||||||
|
|
||||||
|
|
||||||
|
class CodeFile(DataPoint):
|
||||||
|
part_of: Repository
|
||||||
|
contains: List["CodePart"] = []
|
||||||
|
depends_on: List["CodeFile"] = []
|
||||||
|
source_code: str
|
||||||
|
metadata: dict = {"index_fields": []}
|
||||||
|
|
||||||
|
|
||||||
|
class CodePart(DataPoint):
|
||||||
|
part_of: CodeFile
|
||||||
|
source_code: str
|
||||||
|
metadata: dict = {"index_fields": []}
|
||||||
|
|
||||||
|
|
||||||
|
CodeFile.model_rebuild()
|
||||||
|
CodePart.model_rebuild()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_circular_reference_extraction():
|
||||||
|
repo = Repository(path="repo1")
|
||||||
|
|
||||||
|
code_file_1 = CodeFile(
|
||||||
|
id=uuid5(NAMESPACE_OID, f"file_0"),
|
||||||
|
source_code="source code",
|
||||||
|
part_of=repo,
|
||||||
|
contains=[],
|
||||||
|
depends_on=[],
|
||||||
|
)
|
||||||
|
code_part_1 = CodePart(source_code="part_0", part_of=code_file_1)
|
||||||
|
code_file_1.contains.append(code_part_1)
|
||||||
|
|
||||||
|
added_nodes = {}
|
||||||
|
added_edges = {}
|
||||||
|
visited_properties = {}
|
||||||
|
|
||||||
|
nodes, edges = await get_graph_from_model(
|
||||||
|
code_file_1,
|
||||||
|
added_nodes=added_nodes,
|
||||||
|
added_edges=added_edges,
|
||||||
|
visited_properties=visited_properties,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(nodes) == 3
|
||||||
|
assert len(edges) == 3
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_circular_reference_extraction())
|
||||||
|
|
@ -1,79 +0,0 @@
|
||||||
import pytest
|
|
||||||
import asyncio
|
|
||||||
from typing import List
|
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
|
||||||
|
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
|
||||||
from cognee.modules.graph.utils import get_graph_from_model
|
|
||||||
|
|
||||||
|
|
||||||
class Document(DataPoint):
|
|
||||||
path: str
|
|
||||||
metadata: dict = {"index_fields": []}
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentChunk(DataPoint):
|
|
||||||
part_of: Document
|
|
||||||
text: str
|
|
||||||
contains: List["Entity"] = None
|
|
||||||
metadata: dict = {"index_fields": ["text"]}
|
|
||||||
|
|
||||||
|
|
||||||
class EntityType(DataPoint):
|
|
||||||
name: str
|
|
||||||
metadata: dict = {"index_fields": ["name"]}
|
|
||||||
|
|
||||||
|
|
||||||
class Entity(DataPoint):
|
|
||||||
name: str
|
|
||||||
is_type: EntityType
|
|
||||||
metadata: dict = {"index_fields": ["name"]}
|
|
||||||
|
|
||||||
|
|
||||||
DocumentChunk.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def get_graph_from_model_test():
|
|
||||||
document = Document(path="file_path")
|
|
||||||
|
|
||||||
document_chunk = DocumentChunk(
|
|
||||||
id=uuid5(NAMESPACE_OID, "file_name"),
|
|
||||||
text="some text",
|
|
||||||
part_of=document,
|
|
||||||
contains=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
document_chunk.contains.append(
|
|
||||||
Entity(
|
|
||||||
name="Entity",
|
|
||||||
is_type=EntityType(
|
|
||||||
name="Type 1",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
added_nodes = {}
|
|
||||||
added_edges = {}
|
|
||||||
visited_properties = {}
|
|
||||||
|
|
||||||
result = await get_graph_from_model(
|
|
||||||
document_chunk,
|
|
||||||
added_nodes=added_nodes,
|
|
||||||
added_edges=added_edges,
|
|
||||||
visited_properties=visited_properties,
|
|
||||||
)
|
|
||||||
|
|
||||||
nodes = result[0]
|
|
||||||
edges = result[1]
|
|
||||||
|
|
||||||
assert len(nodes) == 4
|
|
||||||
assert len(edges) == 3
|
|
||||||
|
|
||||||
document_chunk_node = next(filter(lambda node: node.type == "DocumentChunk", nodes))
|
|
||||||
assert not hasattr(document_chunk_node, "part_of"), "Expected part_of attribute to be removed"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(get_graph_from_model_test())
|
|
||||||
|
|
@ -1,89 +0,0 @@
|
||||||
import pytest
|
|
||||||
import asyncio
|
|
||||||
from typing import List
|
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
|
||||||
|
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
|
||||||
from cognee.modules.graph.utils import get_graph_from_model
|
|
||||||
|
|
||||||
|
|
||||||
class Document(DataPoint):
|
|
||||||
path: str
|
|
||||||
metadata: dict = {"index_fields": []}
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentChunk(DataPoint):
|
|
||||||
part_of: Document
|
|
||||||
text: str
|
|
||||||
contains: List["Entity"] = None
|
|
||||||
metadata: dict = {"index_fields": ["text"]}
|
|
||||||
|
|
||||||
|
|
||||||
class EntityType(DataPoint):
|
|
||||||
name: str
|
|
||||||
metadata: dict = {"index_fields": ["name"]}
|
|
||||||
|
|
||||||
|
|
||||||
class Entity(DataPoint):
|
|
||||||
name: str
|
|
||||||
is_type: EntityType
|
|
||||||
metadata: dict = {"index_fields": ["name"]}
|
|
||||||
|
|
||||||
|
|
||||||
DocumentChunk.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def get_graph_from_model_test():
|
|
||||||
document = Document(path="file_path")
|
|
||||||
|
|
||||||
document_chunks = [
|
|
||||||
DocumentChunk(
|
|
||||||
id=uuid5(NAMESPACE_OID, f"file{file_index}"),
|
|
||||||
text="some text",
|
|
||||||
part_of=document,
|
|
||||||
contains=[],
|
|
||||||
)
|
|
||||||
for file_index in range(1)
|
|
||||||
]
|
|
||||||
|
|
||||||
for document_chunk in document_chunks:
|
|
||||||
document_chunk.contains.append(
|
|
||||||
Entity(
|
|
||||||
name="Entity",
|
|
||||||
is_type=EntityType(
|
|
||||||
name="Type 1",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
nodes = []
|
|
||||||
edges = []
|
|
||||||
|
|
||||||
added_nodes = {}
|
|
||||||
added_edges = {}
|
|
||||||
visited_properties = {}
|
|
||||||
|
|
||||||
results = await asyncio.gather(
|
|
||||||
*[
|
|
||||||
get_graph_from_model(
|
|
||||||
document_chunk,
|
|
||||||
added_nodes=added_nodes,
|
|
||||||
added_edges=added_edges,
|
|
||||||
visited_properties=visited_properties,
|
|
||||||
)
|
|
||||||
for document_chunk in document_chunks
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
for result_nodes, result_edges in results:
|
|
||||||
nodes.extend(result_nodes)
|
|
||||||
edges.extend(result_edges)
|
|
||||||
|
|
||||||
assert len(nodes) == 4
|
|
||||||
assert len(edges) == 3
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(get_graph_from_model_test())
|
|
||||||
Loading…
Add table
Reference in a new issue