fix: fixes cognify duplicated edges and resets the methods to an older version
This commit is contained in:
parent
b0eb9af9c2
commit
6841c83566
2 changed files with 83 additions and 110 deletions
|
|
@ -1,16 +1,8 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.storage.utils import copy_model
|
from cognee.modules.storage.utils import copy_model
|
||||||
|
|
||||||
|
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
|
||||||
def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None):
|
|
||||||
|
|
||||||
if not added_nodes:
|
|
||||||
added_nodes = {}
|
|
||||||
if not added_edges:
|
|
||||||
added_edges = {}
|
|
||||||
|
|
||||||
nodes = []
|
nodes = []
|
||||||
edges = []
|
edges = []
|
||||||
|
|
||||||
|
|
@ -20,92 +12,85 @@ def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=No
|
||||||
for field_name, field_value in data_point:
|
for field_name, field_value in data_point:
|
||||||
if field_name == "_metadata":
|
if field_name == "_metadata":
|
||||||
continue
|
continue
|
||||||
elif isinstance(field_value, DataPoint):
|
|
||||||
excluded_properties.add(field_name)
|
|
||||||
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
|
|
||||||
data_point,
|
|
||||||
field_name,
|
|
||||||
field_value,
|
|
||||||
nodes,
|
|
||||||
edges,
|
|
||||||
added_nodes,
|
|
||||||
added_edges,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif (
|
if isinstance(field_value, DataPoint):
|
||||||
isinstance(field_value, list)
|
|
||||||
and len(field_value) > 0
|
|
||||||
and isinstance(field_value[0], DataPoint)
|
|
||||||
):
|
|
||||||
excluded_properties.add(field_name)
|
excluded_properties.add(field_name)
|
||||||
|
|
||||||
for item in field_value:
|
property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
|
||||||
n_edges_before = len(edges)
|
|
||||||
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
|
|
||||||
data_point, field_name, item, nodes, edges, added_nodes, added_edges
|
|
||||||
)
|
|
||||||
edges = edges[:n_edges_before] + [
|
|
||||||
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
|
|
||||||
for edge in edges[n_edges_before:]
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
data_point_properties[field_name] = field_value
|
|
||||||
|
|
||||||
SimpleDataPointModel = copy_model(
|
for node in property_nodes:
|
||||||
type(data_point),
|
if str(node.id) not in added_nodes:
|
||||||
include_fields={
|
nodes.append(node)
|
||||||
"_metadata": (dict, data_point._metadata),
|
added_nodes[str(node.id)] = True
|
||||||
},
|
|
||||||
exclude_fields=excluded_properties,
|
|
||||||
)
|
|
||||||
|
|
||||||
nodes.append(SimpleDataPointModel(**data_point_properties))
|
for edge in property_edges:
|
||||||
|
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||||
|
|
||||||
return nodes, edges
|
if str(edge_key) not in added_edges:
|
||||||
|
edges.append(edge)
|
||||||
|
added_edges[str(edge_key)] = True
|
||||||
|
|
||||||
|
for property_node in get_own_properties(property_nodes, property_edges):
|
||||||
|
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||||
|
|
||||||
def add_nodes_and_edges(
|
if str(edge_key) not in added_edges:
|
||||||
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
|
edges.append((data_point.id, property_node.id, field_name, {
|
||||||
):
|
|
||||||
|
|
||||||
property_nodes, property_edges = get_graph_from_model(
|
|
||||||
field_value, dict(added_nodes), dict(added_edges)
|
|
||||||
)
|
|
||||||
|
|
||||||
for node in property_nodes:
|
|
||||||
if str(node.id) not in added_nodes:
|
|
||||||
nodes.append(node)
|
|
||||||
added_nodes[str(node.id)] = True
|
|
||||||
|
|
||||||
for edge in property_edges:
|
|
||||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
|
||||||
|
|
||||||
if str(edge_key) not in added_edges:
|
|
||||||
edges.append(edge)
|
|
||||||
added_edges[str(edge_key)] = True
|
|
||||||
|
|
||||||
for property_node in get_own_properties(property_nodes, property_edges):
|
|
||||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
|
||||||
|
|
||||||
if str(edge_key) not in added_edges:
|
|
||||||
edges.append(
|
|
||||||
(
|
|
||||||
data_point.id,
|
|
||||||
property_node.id,
|
|
||||||
field_name,
|
|
||||||
{
|
|
||||||
"source_node_id": data_point.id,
|
"source_node_id": data_point.id,
|
||||||
"target_node_id": property_node.id,
|
"target_node_id": property_node.id,
|
||||||
"relationship_name": field_name,
|
"relationship_name": field_name,
|
||||||
"updated_at": datetime.now(timezone.utc).strftime(
|
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
"%Y-%m-%d %H:%M:%S"
|
}))
|
||||||
),
|
added_edges[str(edge_key)] = True
|
||||||
},
|
continue
|
||||||
)
|
|
||||||
)
|
|
||||||
added_edges[str(edge_key)] = True
|
|
||||||
|
|
||||||
return (nodes, edges, added_nodes, added_edges)
|
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||||
|
excluded_properties.add(field_name)
|
||||||
|
|
||||||
|
for item in field_value:
|
||||||
|
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
|
||||||
|
|
||||||
|
for node in property_nodes:
|
||||||
|
if str(node.id) not in added_nodes:
|
||||||
|
nodes.append(node)
|
||||||
|
added_nodes[str(node.id)] = True
|
||||||
|
|
||||||
|
for edge in property_edges:
|
||||||
|
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||||
|
|
||||||
|
if str(edge_key) not in added_edges:
|
||||||
|
edges.append(edge)
|
||||||
|
added_edges[edge_key] = True
|
||||||
|
|
||||||
|
for property_node in get_own_properties(property_nodes, property_edges):
|
||||||
|
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||||
|
|
||||||
|
if str(edge_key) not in added_edges:
|
||||||
|
edges.append((data_point.id, property_node.id, field_name, {
|
||||||
|
"source_node_id": data_point.id,
|
||||||
|
"target_node_id": property_node.id,
|
||||||
|
"relationship_name": field_name,
|
||||||
|
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
"metadata": {
|
||||||
|
"type": "list"
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
added_edges[edge_key] = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
data_point_properties[field_name] = field_value
|
||||||
|
|
||||||
|
SimpleDataPointModel = copy_model(
|
||||||
|
type(data_point),
|
||||||
|
include_fields = {
|
||||||
|
"_metadata": (dict, data_point._metadata),
|
||||||
|
},
|
||||||
|
exclude_fields = excluded_properties,
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_root:
|
||||||
|
nodes.append(SimpleDataPointModel(**data_point_properties))
|
||||||
|
|
||||||
|
return nodes, edges
|
||||||
|
|
||||||
|
|
||||||
def get_own_properties(property_nodes, property_edges):
|
def get_own_properties(property_nodes, property_edges):
|
||||||
|
|
|
||||||
|
|
@ -1,41 +1,29 @@
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.storage.utils import copy_model
|
from cognee.modules.storage.utils import copy_model
|
||||||
|
|
||||||
|
|
||||||
def get_model_instance_from_graph(
|
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
|
||||||
nodes: list[DataPoint],
|
node_map = {}
|
||||||
edges: list[tuple[str, str, str, dict[str, str]]],
|
|
||||||
entity_id: str,
|
|
||||||
):
|
|
||||||
node_map = {node.id: node for node in nodes}
|
|
||||||
|
|
||||||
for source_node_id, target_node_id, edge_label, edge_properties in edges:
|
for node in nodes:
|
||||||
source_node = node_map[source_node_id]
|
node_map[node.id] = node
|
||||||
target_node = node_map[target_node_id]
|
|
||||||
|
for edge in edges:
|
||||||
|
source_node = node_map[edge[0]]
|
||||||
|
target_node = node_map[edge[1]]
|
||||||
|
edge_label = edge[2]
|
||||||
|
edge_properties = edge[3] if len(edge) == 4 else {}
|
||||||
edge_metadata = edge_properties.get("metadata", {})
|
edge_metadata = edge_properties.get("metadata", {})
|
||||||
edge_type = edge_metadata.get("type", "default")
|
edge_type = edge_metadata.get("type")
|
||||||
|
|
||||||
if edge_type == "list":
|
if edge_type == "list":
|
||||||
NewModel = copy_model(
|
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
|
||||||
type(source_node),
|
|
||||||
{edge_label: (list[type(target_node)], PydanticUndefined)},
|
|
||||||
)
|
|
||||||
source_node_dict = source_node.model_dump()
|
|
||||||
source_node_edge_label_values = source_node_dict.get(edge_label, [])
|
|
||||||
source_node_dict[edge_label] = source_node_edge_label_values + [target_node]
|
|
||||||
|
|
||||||
node_map[source_node_id] = NewModel(**source_node_dict)
|
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
|
||||||
else:
|
else:
|
||||||
NewModel = copy_model(
|
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
|
||||||
type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
|
|
||||||
)
|
|
||||||
|
|
||||||
node_map[target_node_id] = NewModel(
|
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
|
||||||
**source_node.model_dump(), **{edge_label: target_node}
|
|
||||||
)
|
|
||||||
|
|
||||||
return node_map[entity_id]
|
return node_map[entity_id]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue