fix: break circular data points in a graph model
parent b99b0455b0
commit 9c5f1a2686
5 changed files with 93 additions and 185 deletions
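The fix threads a visited_properties map through the recursive model-to-graph extraction: each (data point id, field name, target id) triple is recorded before the recursive call, so a property that has already been expanded is skipped, and circular references (CodeFile objects that depend on each other, or a CodePart pointing back to its CodeFile) no longer recurse forever. Below is a minimal sketch of that idea; Model, links, and extract_graph are illustrative names assumed for this sketch, not the cognee API or the real get_graph_from_model signature.

# Illustrative sketch only (Model / links / extract_graph are hypothetical names, not cognee code).
# Each parent-field-child triple is recorded in visited_properties before recursing,
# so models with circular references terminate instead of recursing forever.
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
from uuid import UUID, uuid4


@dataclass
class Model:
    id: UUID = field(default_factory=uuid4)
    links: Dict[str, "Model"] = field(default_factory=dict)


def extract_graph(
    model: Model,
    nodes: Dict[UUID, Model],
    edges: List[Tuple[UUID, str, UUID]],
    visited_properties: Dict[str, bool],
) -> None:
    nodes.setdefault(model.id, model)
    for field_name, target in model.links.items():
        property_key = str(model.id) + field_name + str(target.id)
        if property_key in visited_properties:
            continue  # this property was already expanded; the cycle is broken here
        visited_properties[property_key] = True  # mark before the recursive call
        edges.append((model.id, field_name, target.id))
        extract_graph(target, nodes, edges, visited_properties)


if __name__ == "__main__":
    a, b = Model(), Model()
    a.links["depends_on"] = b
    b.links["depends_on"] = a  # circular reference between two models

    nodes: Dict[UUID, Model] = {}
    edges: List[Tuple[UUID, str, UUID]] = []
    extract_graph(a, nodes, edges, visited_properties={})
    assert len(nodes) == 2 and len(edges) == 2

Marking the property key before recursing, rather than after the recursive call returns (the lines removed in the last hunk of the first file), is what guarantees termination when two data points reference each other.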
@@ -35,6 +35,7 @@ async def get_graph_from_model(
         if property_key in visited_properties:
             continue

+        visited_properties[property_key] = True
         properties_to_visit.add(field_name)

         continue

@@ -52,6 +53,7 @@ async def get_graph_from_model(
         if property_key in visited_properties:
             continue

+        visited_properties[property_key] = True
         properties_to_visit.add(f"{field_name}.{index}")

         continue

@@ -98,6 +100,9 @@ async def get_graph_from_model(
         if str(field_value.id) in added_nodes:
             continue

+        property_key = str(data_point.id) + field_name + str(field_value.id)
+        visited_properties[property_key] = True
+
         property_nodes, property_edges = await get_graph_from_model(
             field_value,
             include_root=True,

@@ -112,9 +117,6 @@ async def get_graph_from_model(
         for edge in property_edges:
             edges.append(edge)

-        property_key = str(data_point.id) + field_name + str(field_value.id)
-        visited_properties[property_key] = True
-
     return nodes, edges



@@ -26,7 +26,6 @@ class CodeFile(DataPoint):


-
 class CodePart(DataPoint):
     part_of: CodeFile
     source_code: str
     metadata: dict = {"index_fields": []}

@@ -67,27 +66,25 @@ async def test_circular_reference_extraction():
             source_code="source code",
             part_of=repo,
             contains=[],
-            depends_on=[
-                CodeFile(
-                    id=uuid5(NAMESPACE_OID, f"file{random_id}"),
-                    source_code="source code",
-                    part_of=repo,
-                    depends_on=[],
-                )
-                for random_id in [random.randint(0, 1499) for _ in range(random.randint(0, 5))]
-            ],
+            depends_on=[],
         )
         for file_index in range(1500)
     ]

     for code_file in code_files:
+        code_file.depends_on.extend(
+            [
+                code_files[random.randint(0, len(code_files) - 1)]
+                for _ in range(2)
+            ]
+        )
         code_file.contains.extend(
             [
                 CodePart(
                     part_of=code_file,
                     source_code=f"Part {part_index}",
                 )
-                for part_index in range(random.randint(1, 20))
+                for part_index in range(2)
             ]
         )

@@ -96,12 +93,18 @@ async def test_circular_reference_extraction():

     added_nodes = {}
     added_edges = {}
+    visited_properties = {}

     start = time.perf_counter_ns()

     results = await asyncio.gather(
         *[
-            get_graph_from_model(code_file, added_nodes=added_nodes, added_edges=added_edges)
+            get_graph_from_model(
+                code_file,
+                added_nodes=added_nodes,
+                added_edges=added_edges,
+                visited_properties=visited_properties,
+            )
             for code_file in code_files
         ]
     )

@@ -114,8 +117,12 @@ async def test_circular_reference_extraction():
         nodes.extend(result_nodes)
         edges.extend(result_edges)

-    assert len(nodes) == 1501
-    assert len(edges) == 1501 * 20 + 1500 * 5
+    code_files = [node for node in nodes if node.type == "CodeFile"]
+    code_parts = [node for node in nodes if node.type == "CodePart"]
+
+    assert len(code_files) == 1500
+    assert len(code_parts) == 3000
+    assert len(edges) == 7500


 if __name__ == "__main__":

@@ -0,0 +1,67 @@
+import random
+import pytest
+import asyncio
+from typing import List
+from uuid import NAMESPACE_OID, uuid5
+
+
+from cognee.infrastructure.engine import DataPoint
+from cognee.modules.graph.utils import get_graph_from_model
+
+random.seed(1500)
+
+
+class Repository(DataPoint):
+    path: str
+    metadata: dict = {"index_fields": []}
+
+
+class CodeFile(DataPoint):
+    part_of: Repository
+    contains: List["CodePart"] = []
+    depends_on: List["CodeFile"] = []
+    source_code: str
+    metadata: dict = {"index_fields": []}
+
+
+class CodePart(DataPoint):
+    part_of: CodeFile
+    source_code: str
+    metadata: dict = {"index_fields": []}
+
+
+CodeFile.model_rebuild()
+CodePart.model_rebuild()
+
+
+@pytest.mark.asyncio
+async def test_circular_reference_extraction():
+    repo = Repository(path="repo1")
+
+    code_file_1 = CodeFile(
+        id=uuid5(NAMESPACE_OID, f"file_0"),
+        source_code="source code",
+        part_of=repo,
+        contains=[],
+        depends_on=[],
+    )
+    code_part_1 = CodePart(source_code="part_0", part_of=code_file_1)
+    code_file_1.contains.append(code_part_1)
+
+    added_nodes = {}
+    added_edges = {}
+    visited_properties = {}
+
+    nodes, edges = await get_graph_from_model(
+        code_file_1,
+        added_nodes=added_nodes,
+        added_edges=added_edges,
+        visited_properties=visited_properties,
+    )
+
+    assert len(nodes) == 3
+    assert len(edges) == 3
+
+
+if __name__ == "__main__":
+    asyncio.run(test_circular_reference_extraction())

@@ -1,79 +0,0 @@
-import pytest
-import asyncio
-from typing import List
-from uuid import NAMESPACE_OID, uuid5
-
-
-from cognee.infrastructure.engine import DataPoint
-from cognee.modules.graph.utils import get_graph_from_model
-
-
-class Document(DataPoint):
-    path: str
-    metadata: dict = {"index_fields": []}
-
-
-class DocumentChunk(DataPoint):
-    part_of: Document
-    text: str
-    contains: List["Entity"] = None
-    metadata: dict = {"index_fields": ["text"]}
-
-
-class EntityType(DataPoint):
-    name: str
-    metadata: dict = {"index_fields": ["name"]}
-
-
-class Entity(DataPoint):
-    name: str
-    is_type: EntityType
-    metadata: dict = {"index_fields": ["name"]}
-
-
-DocumentChunk.model_rebuild()
-
-
-@pytest.mark.asyncio
-async def get_graph_from_model_test():
-    document = Document(path="file_path")
-
-    document_chunk = DocumentChunk(
-        id=uuid5(NAMESPACE_OID, "file_name"),
-        text="some text",
-        part_of=document,
-        contains=[],
-    )
-
-    document_chunk.contains.append(
-        Entity(
-            name="Entity",
-            is_type=EntityType(
-                name="Type 1",
-            ),
-        )
-    )
-
-    added_nodes = {}
-    added_edges = {}
-    visited_properties = {}
-
-    result = await get_graph_from_model(
-        document_chunk,
-        added_nodes=added_nodes,
-        added_edges=added_edges,
-        visited_properties=visited_properties,
-    )
-
-    nodes = result[0]
-    edges = result[1]
-
-    assert len(nodes) == 4
-    assert len(edges) == 3
-
-    document_chunk_node = next(filter(lambda node: node.type == "DocumentChunk", nodes))
-    assert not hasattr(document_chunk_node, "part_of"), "Expected part_of attribute to be removed"
-
-
-if __name__ == "__main__":
-    asyncio.run(get_graph_from_model_test())

@@ -1,89 +0,0 @@
-import pytest
-import asyncio
-from typing import List
-from uuid import NAMESPACE_OID, uuid5
-
-
-from cognee.infrastructure.engine import DataPoint
-from cognee.modules.graph.utils import get_graph_from_model
-
-
-class Document(DataPoint):
-    path: str
-    metadata: dict = {"index_fields": []}
-
-
-class DocumentChunk(DataPoint):
-    part_of: Document
-    text: str
-    contains: List["Entity"] = None
-    metadata: dict = {"index_fields": ["text"]}
-
-
-class EntityType(DataPoint):
-    name: str
-    metadata: dict = {"index_fields": ["name"]}
-
-
-class Entity(DataPoint):
-    name: str
-    is_type: EntityType
-    metadata: dict = {"index_fields": ["name"]}
-
-
-DocumentChunk.model_rebuild()
-
-
-@pytest.mark.asyncio
-async def get_graph_from_model_test():
-    document = Document(path="file_path")
-
-    document_chunks = [
-        DocumentChunk(
-            id=uuid5(NAMESPACE_OID, f"file{file_index}"),
-            text="some text",
-            part_of=document,
-            contains=[],
-        )
-        for file_index in range(1)
-    ]
-
-    for document_chunk in document_chunks:
-        document_chunk.contains.append(
-            Entity(
-                name="Entity",
-                is_type=EntityType(
-                    name="Type 1",
-                ),
-            )
-        )
-
-    nodes = []
-    edges = []
-
-    added_nodes = {}
-    added_edges = {}
-    visited_properties = {}
-
-    results = await asyncio.gather(
-        *[
-            get_graph_from_model(
-                document_chunk,
-                added_nodes=added_nodes,
-                added_edges=added_edges,
-                visited_properties=visited_properties,
-            )
-            for document_chunk in document_chunks
-        ]
-    )
-
-    for result_nodes, result_edges in results:
-        nodes.extend(result_nodes)
-        edges.extend(result_edges)
-
-    assert len(nodes) == 4
-    assert len(edges) == 3
-
-
-if __name__ == "__main__":
-    asyncio.run(get_graph_from_model_test())