Refactor get_graph_from_model
This commit is contained in:
parent
2c0fce32d3
commit
05ea357520
1 changed files with 79 additions and 57 deletions
|
|
@ -2,9 +2,70 @@ from datetime import datetime, timezone
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.storage.utils import copy_model
|
||||
|
||||
def add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added_nodes, added_edges):
|
||||
|
||||
property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
|
||||
def get_graph_from_model(
|
||||
data_point: DataPoint, include_root=True, added_nodes=None, added_edges=None
|
||||
):
|
||||
|
||||
if not added_nodes:
|
||||
added_nodes = {}
|
||||
if not added_edges:
|
||||
added_edges = {}
|
||||
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
data_point_properties = {}
|
||||
excluded_properties = set()
|
||||
|
||||
for field_name, field_value in data_point:
|
||||
if field_name == "_metadata":
|
||||
continue
|
||||
elif isinstance(field_value, DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
|
||||
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
|
||||
)
|
||||
|
||||
elif (
|
||||
isinstance(field_value, list)
|
||||
and len(field_value) > 0
|
||||
and isinstance(field_value[0], DataPoint)
|
||||
):
|
||||
excluded_properties.add(field_name)
|
||||
|
||||
for item in field_value:
|
||||
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
|
||||
data_point, field_name, item, nodes, edges, added_nodes, added_edges
|
||||
)
|
||||
edges = [
|
||||
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
|
||||
for edge in edges
|
||||
]
|
||||
else:
|
||||
data_point_properties[field_name] = field_value
|
||||
|
||||
SimpleDataPointModel = copy_model(
|
||||
type(data_point),
|
||||
include_fields={
|
||||
"_metadata": (dict, data_point._metadata),
|
||||
},
|
||||
exclude_fields=excluded_properties,
|
||||
)
|
||||
|
||||
if include_root:
|
||||
nodes.append(SimpleDataPointModel(**data_point_properties))
|
||||
|
||||
return nodes, edges
|
||||
|
||||
|
||||
def add_nodes_and_edges(
|
||||
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
|
||||
):
|
||||
|
||||
property_nodes, property_edges = get_graph_from_model(
|
||||
field_value, True, added_nodes, added_edges
|
||||
)
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) not in added_nodes:
|
||||
|
|
@ -22,63 +83,24 @@ def add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added
|
|||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append((data_point.id, property_node.id, field_name, {
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}))
|
||||
edges.append(
|
||||
(
|
||||
data_point.id,
|
||||
property_node.id,
|
||||
field_name,
|
||||
{
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
added_edges[str(edge_key)] = True
|
||||
|
||||
return(nodes, edges, added_nodes, added_edges)
|
||||
|
||||
|
||||
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = None, added_edges = None):
|
||||
|
||||
if not added_nodes:
|
||||
added_nodes = {}
|
||||
if not added_edges:
|
||||
added_edges = {}
|
||||
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
data_point_properties = {}
|
||||
excluded_properties = set()
|
||||
|
||||
for field_name, field_value in data_point:
|
||||
if field_name == "_metadata":
|
||||
continue
|
||||
|
||||
if isinstance(field_value, DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
|
||||
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(data_point, field_name, field_value, nodes, edges, added_nodes, added_edges)
|
||||
|
||||
|
||||
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
|
||||
for item in field_value:
|
||||
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(data_point, field_name, item, nodes, edges, added_nodes, added_edges)
|
||||
edges = [(*edge[:3],{**edge[3], "metadata": {"type": "list"}}) for edge in edges]
|
||||
|
||||
continue
|
||||
|
||||
data_point_properties[field_name] = field_value
|
||||
|
||||
SimpleDataPointModel = copy_model(
|
||||
type(data_point),
|
||||
include_fields = {
|
||||
"_metadata": (dict, data_point._metadata),
|
||||
},
|
||||
exclude_fields = excluded_properties,
|
||||
)
|
||||
|
||||
if include_root:
|
||||
nodes.append(SimpleDataPointModel(**data_point_properties))
|
||||
|
||||
return nodes, edges
|
||||
return (nodes, edges, added_nodes, added_edges)
|
||||
|
||||
|
||||
def get_own_properties(property_nodes, property_edges):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue