From 68700f32c78c99610f6fa8a2a4bbe983f6cb905d Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Fri, 8 Nov 2024 15:31:02 +0100 Subject: [PATCH] fix: add code graph generation pipeline --- .../databases/graph/networkx/adapter.py | 30 ++++++++++++++----- .../graph/utils/get_graph_from_model.py | 2 +- cognee/shared/utils.py | 2 +- cognee/tasks/graph/__init__.py | 1 + cognee/tasks/storage/index_data_points.py | 2 +- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index dcb05c2ed..6c7abd498 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -30,6 +30,10 @@ class NetworkXAdapter(GraphDBInterface): def __init__(self, filename = "cognee_graph.pkl"): self.filename = filename + async def get_graph_data(self): + await self.load_graph_from_file() + return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True))) + async def query(self, query: str, params: dict): pass @@ -247,15 +251,27 @@ class NetworkXAdapter(GraphDBInterface): async with aiofiles.open(file_path, "r") as file: graph_data = json.loads(await file.read()) for node in graph_data["nodes"]: - node["id"] = UUID(node["id"]) - node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") + try: + node["id"] = UUID(node["id"]) + except (KeyError, ValueError, TypeError, AttributeError): + pass + if "updated_at" in node: + node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") for edge in graph_data["links"]: - edge["source"] = UUID(edge["source"]) - edge["target"] = UUID(edge["target"]) - edge["source_node_id"] = UUID(edge["source_node_id"]) - edge["target_node_id"] = UUID(edge["target_node_id"]) - edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") + try: + source_id = UUID(edge["source"]) + target_id = UUID(edge["target"]) + 
edge["source"] = source_id + edge["target"] = target_id + edge["source_node_id"] = source_id + edge["target_node_id"] = target_id + except (KeyError, ValueError, TypeError, AttributeError): + pass + + if "updated_at" in edge: + edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") self.graph = nx.readwrite.json_graph.node_link_graph(graph_data) else: diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 35e00fb5d..29137ddc7 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -43,7 +43,7 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes added_edges[str(edge_key)] = True continue - if isinstance(field_value, list) and isinstance(field_value[0], DataPoint): + if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): excluded_properties.add(field_name) for item in field_value: diff --git a/cognee/shared/utils.py b/cognee/shared/utils.py index e32fad15e..42a95b88b 100644 --- a/cognee/shared/utils.py +++ b/cognee/shared/utils.py @@ -115,7 +115,7 @@ def prepare_edges(graph, source, target, edge_key): source: str(edge[0]), target: str(edge[1]), edge_key: str(edge[2]), - } for edge in graph.edges] + } for edge in graph.edges(keys = True, data = True)] return pd.DataFrame(edge_list) diff --git a/cognee/tasks/graph/__init__.py b/cognee/tasks/graph/__init__.py index 94dc82f20..eafc12921 100644 --- a/cognee/tasks/graph/__init__.py +++ b/cognee/tasks/graph/__init__.py @@ -1,2 +1,3 @@ from .extract_graph_from_data import extract_graph_from_data +from .extract_graph_from_code import extract_graph_from_code from .query_graph_connections import query_graph_connections diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index 681fbaa1f..dc74d705d 100644 --- a/cognee/tasks/storage/index_data_points.py +++ 
b/cognee/tasks/storage/index_data_points.py @@ -47,7 +47,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) -> added_data_points[str(new_point.id)] = True data_points.append(new_point) - if isinstance(field_value, list) and isinstance(field_value[0], DataPoint): + if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): for field_value_item in field_value: new_data_points = get_data_points_from_model(field_value_item, added_data_points)