fix: add code graph generation pipeline
This commit is contained in:
parent
e1e5e7336a
commit
68700f32c7
5 changed files with 27 additions and 10 deletions
|
|
@ -30,6 +30,10 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
def __init__(self, filename = "cognee_graph.pkl"):
|
def __init__(self, filename = "cognee_graph.pkl"):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
|
|
||||||
|
async def get_graph_data(self):
|
||||||
|
await self.load_graph_from_file()
|
||||||
|
return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
|
||||||
|
|
||||||
async def query(self, query: str, params: dict):
|
async def query(self, query: str, params: dict):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -247,15 +251,27 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
async with aiofiles.open(file_path, "r") as file:
|
async with aiofiles.open(file_path, "r") as file:
|
||||||
graph_data = json.loads(await file.read())
|
graph_data = json.loads(await file.read())
|
||||||
for node in graph_data["nodes"]:
|
for node in graph_data["nodes"]:
|
||||||
node["id"] = UUID(node["id"])
|
try:
|
||||||
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
node["id"] = UUID(node["id"])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if "updated_at" in node:
|
||||||
|
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||||
|
|
||||||
for edge in graph_data["links"]:
|
for edge in graph_data["links"]:
|
||||||
edge["source"] = UUID(edge["source"])
|
try:
|
||||||
edge["target"] = UUID(edge["target"])
|
source_id = UUID(edge["source"])
|
||||||
edge["source_node_id"] = UUID(edge["source_node_id"])
|
target_id = UUID(edge["target"])
|
||||||
edge["target_node_id"] = UUID(edge["target_node_id"])
|
|
||||||
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
edge["source"] = source_id
|
||||||
|
edge["target"] = target_id
|
||||||
|
edge["source_node_id"] = source_id
|
||||||
|
edge["target_node_id"] = target_id
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if "updated_at" in edge:
|
||||||
|
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||||
|
|
||||||
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
|
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
|
||||||
added_edges[str(edge_key)] = True
|
added_edges[str(edge_key)] = True
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
|
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||||
excluded_properties.add(field_name)
|
excluded_properties.add(field_name)
|
||||||
|
|
||||||
for item in field_value:
|
for item in field_value:
|
||||||
|
|
|
||||||
|
|
@ -115,7 +115,7 @@ def prepare_edges(graph, source, target, edge_key):
|
||||||
source: str(edge[0]),
|
source: str(edge[0]),
|
||||||
target: str(edge[1]),
|
target: str(edge[1]),
|
||||||
edge_key: str(edge[2]),
|
edge_key: str(edge[2]),
|
||||||
} for edge in graph.edges]
|
} for edge in graph.edges(keys = True, data = True)]
|
||||||
|
|
||||||
return pd.DataFrame(edge_list)
|
return pd.DataFrame(edge_list)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
from .extract_graph_from_data import extract_graph_from_data
|
from .extract_graph_from_data import extract_graph_from_data
|
||||||
|
from .extract_graph_from_code import extract_graph_from_code
|
||||||
from .query_graph_connections import query_graph_connections
|
from .query_graph_connections import query_graph_connections
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
|
||||||
added_data_points[str(new_point.id)] = True
|
added_data_points[str(new_point.id)] = True
|
||||||
data_points.append(new_point)
|
data_points.append(new_point)
|
||||||
|
|
||||||
if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
|
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||||
for field_value_item in field_value:
|
for field_value_item in field_value:
|
||||||
new_data_points = get_data_points_from_model(field_value_item, added_data_points)
|
new_data_points = get_data_points_from_model(field_value_item, added_data_points)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue