feat: Cog-693 expand dependency graph
Expand each file node into a subgraph containing high-level code parts.
- Implemented `extract_code_parts` to parse and extract high-level components (classes, functions, imports, and top-level code) from Python source files using `parso`.
- Developed `expand_dependency_graph` to expand Python file nodes into their components.
- Included checker scripts for each stage of the pipeline.
parent d33c740dc6
commit 7ec5cffd8e
7 changed files with 308 additions and 0 deletions
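For orientation, a minimal sketch (not part of the diff) of how the new pieces compose, mirroring the checker scripts added below; the repository path is a placeholder:

import asyncio

from cognee.tasks.repo_processor.get_repo_dependency_graph import get_repo_dependency_graph
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
from cognee.tasks.repo_processor.expand_dependency_graph import expand_dependency_graph

# Build the file-level dependency graph, add topological ranks and
# 'depends_on' edges, then expand each Python file node into code-part nodes.
graph = asyncio.run(get_repo_dependency_graph("/path/to/repo"))
graph = asyncio.run(enrich_dependency_graph(graph))
graph = expand_dependency_graph(graph)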
cognee/tasks/code/enrich_dependency_graph_checker.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_dependency_graph import get_repo_dependency_graph
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("repo_path", help="Path to the repository")
    args = parser.parse_args()

    repo_path = args.repo_path
    if not os.path.exists(repo_path):
        print(f"Error: The provided repository path does not exist: {repo_path}")
        return

    graph = asyncio.run(get_repo_dependency_graph(repo_path))
    graph = asyncio.run(enrich_dependency_graph(graph))
    for node in graph.nodes:
        print(f"Node: {node}")
        for _, target, data in graph.out_edges(node, data=True):
            print(f" Edge to {target}, data: {data}")


if __name__ == "__main__":
    main()
cognee/tasks/code/expand_dependency_graph_checker.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_dependency_graph import get_repo_dependency_graph
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
from cognee.tasks.repo_processor.expand_dependency_graph import expand_dependency_graph


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("repo_path", help="Path to the repository")
    args = parser.parse_args()

    repo_path = args.repo_path
    if not os.path.exists(repo_path):
        print(f"Error: The provided repository path does not exist: {repo_path}")
        return

    graph = asyncio.run(get_repo_dependency_graph(repo_path))
    graph = asyncio.run(enrich_dependency_graph(graph))
    graph = expand_dependency_graph(graph)
    for node in graph.nodes:
        print(f"Node: {node}")
        for _, target, data in graph.out_edges(node, data=True):
            print(f" Edge to {target}, data: {data}")


if __name__ == "__main__":
    main()
cognee/tasks/code/get_repo_dependency_graph_checker.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_dependency_graph import get_repo_dependency_graph


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("repo_path", help="Path to the repository")
    args = parser.parse_args()

    repo_path = args.repo_path
    if not os.path.exists(repo_path):
        print(f"Error: The provided repository path does not exist: {repo_path}")
        return

    graph = asyncio.run(get_repo_dependency_graph(repo_path))

    for node in graph.nodes:
        print(f"Node: {node}")
        edges = graph.edges(node, data=True)
        for _, target, data in edges:
            print(f" Edge to {target}, Relation: {data.get('relation')}")


if __name__ == "__main__":
    main()
cognee/tasks/repo_processor/enrich_dependency_graph.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import networkx as nx
from typing import Dict, List


def topologically_sort_subgraph(subgraph_node_to_indegree: Dict[str, int], graph: nx.DiGraph) -> List[str]:
    """Performs a topological sort on a subgraph based on node indegrees."""
    results = []
    remaining_nodes = subgraph_node_to_indegree.copy()
    while remaining_nodes:
        next_node = min(remaining_nodes, key=remaining_nodes.get)
        results.append(next_node)
        for successor in graph.successors(next_node):
            if successor in remaining_nodes:
                remaining_nodes[successor] -= 1
        remaining_nodes.pop(next_node)
    return results


def topologically_sort(graph: nx.DiGraph) -> List[str]:
    """Performs a topological sort on the entire graph."""
    subgraphs = (graph.subgraph(c).copy() for c in nx.weakly_connected_components(graph))
    topological_order = []
    for subgraph in subgraphs:
        node_to_indegree = {
            node: len(list(subgraph.successors(node)))
            for node in subgraph.nodes
        }
        topological_order.extend(
            topologically_sort_subgraph(node_to_indegree, subgraph)
        )
    return topological_order


def node_enrich_and_connect(graph: nx.MultiDiGraph, topological_order: List[str], node: str) -> None:
    """Adds 'depends_on' edges to the graph based on topological order."""
    topological_rank = topological_order.index(node)
    graph.nodes[node]['topological_rank'] = topological_rank
    node_descendants = nx.descendants(graph, node)
    if graph.has_edge(node, node):
        node_descendants.add(node)
    for desc in node_descendants:
        if desc not in topological_order[:topological_rank + 1]:
            continue
        graph.add_edge(node, desc, relation='depends_on')


async def enrich_dependency_graph(graph: nx.DiGraph) -> nx.MultiDiGraph:
    """Enriches the graph with topological ranks and 'depends_on' edges."""
    graph = nx.MultiDiGraph(graph)
    topological_order = topologically_sort(graph)
    node_rank_map = {node: idx for idx, node in enumerate(topological_order)}
    for node in graph.nodes:
        if node not in node_rank_map:
            continue
        node_enrich_and_connect(graph, topological_order, node)
    return graph
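A minimal sketch (not part of the diff) of calling `enrich_dependency_graph` on a toy graph; the node names are illustrative:

import asyncio
import networkx as nx

from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph

# a.py depends directly on b.py, which depends directly on c.py.
graph = nx.DiGraph()
graph.add_edge("a.py", "b.py", relation="depends_directly_on")
graph.add_edge("b.py", "c.py", relation="depends_directly_on")

enriched = asyncio.run(enrich_dependency_graph(graph))
for node, data in enriched.nodes(data=True):
    print(node, data)  # node attributes include 'topological_rank'
for source, target, data in enriched.edges(data=True):
    print(source, "->", target, data)  # original edges plus any added 'depends_on' edges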
cognee/tasks/repo_processor/expand_dependency_graph.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import networkx as nx

from cognee.tasks.repo_processor.extract_code_parts import extract_code_parts
from cognee.tasks.repo_processor import logger


def _add_code_parts_nodes_and_edges(graph, parent_node_id, part_type, code_parts):
    """Add code part nodes and edges for a specific part type."""
    if not code_parts:
        logger.debug(f"No code parts to add for parent_node_id {parent_node_id} and part_type {part_type}.")
        return

    for idx, code_part in enumerate(code_parts):
        if not code_part.strip():
            logger.warning(f"Empty code part in parent_node_id {parent_node_id} and part_type {part_type}.")
            continue
        part_node_id = f"{parent_node_id}_{part_type}_{idx}"
        graph.add_node(part_node_id, source_code=code_part, node_type=part_type)
        graph.add_edge(parent_node_id, part_node_id, relation="contains")


def _process_single_node(graph, node_id, node_data):
    """Process a single Python file node."""
    graph.nodes[node_id]["node_type"] = "python_file"
    source_code = node_data.get("source_code", "")

    if not source_code.strip():
        logger.warning(f"Node {node_id} has no or empty 'source_code'. Skipping.")
        return

    try:
        code_parts_dict = extract_code_parts(source_code)
    except Exception as e:
        logger.error(f"Error processing node {node_id}: {e}")
        return

    for part_type, code_parts in code_parts_dict.items():
        _add_code_parts_nodes_and_edges(graph, node_id, part_type, code_parts)


def expand_dependency_graph(graph: nx.MultiDiGraph) -> nx.MultiDiGraph:
    """Process Python file nodes, adding code part nodes and edges."""
    expanded_graph = graph.copy()
    for node_id, node_data in graph.nodes(data=True):
        if not node_data:  # Check if node_data is empty
            logger.warning(f"Node {node_id} has no data. Skipping.")
            continue
        _process_single_node(expanded_graph, node_id, node_data)
    return expanded_graph
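A minimal sketch (not part of the diff) of expanding a single file node; the file name and source snippet are illustrative:

import networkx as nx

from cognee.tasks.repo_processor.expand_dependency_graph import expand_dependency_graph

# A single file node carrying its source code.
graph = nx.MultiDiGraph()
graph.add_node("utils.py", source_code="import os\n\ndef cwd():\n    return os.getcwd()\n")

expanded = expand_dependency_graph(graph)
# Each extracted part becomes its own node (id pattern "<file>_<part_type>_<idx>"),
# connected from "utils.py" with relation="contains".
for source, target, data in expanded.edges(data=True):
    print(source, "->", target, data)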
cognee/tasks/repo_processor/extract_code_parts.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from typing import Dict, List
import parso

from cognee.tasks.repo_processor import logger


def _extract_parts_from_module(module, parts_dict: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Extract code parts from a parsed module."""

    current_top_level_code = []
    child_to_code_type = {
        'classdef': "classes",
        'funcdef': "functions",
        'import_name': "imports",
        'import_from': "imports",
    }

    for child in module.children:
        if child.type == 'simple_stmt':
            current_top_level_code.append(child.get_code())
            continue

        if current_top_level_code:
            parts_dict["top_level_code"].append('\n'.join(current_top_level_code))
            current_top_level_code = []

        if child.type in child_to_code_type:
            code_type = child_to_code_type[child.type]
            parts_dict[code_type].append(child.get_code())

    if current_top_level_code:
        parts_dict["top_level_code"].append('\n'.join(current_top_level_code))

    if parts_dict["imports"]:
        parts_dict["imports"] = ['\n'.join(parts_dict["imports"])]

    return parts_dict


def extract_code_parts(source_code: str) -> Dict[str, List[str]]:
    """Extract high-level parts of the source code."""

    parts_dict = {"classes": [], "functions": [], "imports": [], "top_level_code": []}

    if not source_code.strip():
        logger.warning("Empty source_code provided.")
        return parts_dict

    try:
        module = parso.parse(source_code)
    except Exception as e:
        logger.error(f"Error parsing source code: {e}")
        return parts_dict

    if not module.children:
        logger.warning("Parsed module has no children (empty or invalid source code).")
        return parts_dict

    return _extract_parts_from_module(module, parts_dict)
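A minimal usage sketch (not part of the diff) of `extract_code_parts`; the sample source is illustrative:

from cognee.tasks.repo_processor.extract_code_parts import extract_code_parts

source = (
    "import os\n"
    "\n"
    "class Greeter:\n"
    "    def hello(self):\n"
    "        return 'hello'\n"
    "\n"
    "def main():\n"
    "    print(Greeter().hello())\n"
)

# Returns a dict with the keys "classes", "functions", "imports", "top_level_code",
# each mapping to a list of source snippets.
parts = extract_code_parts(source)
for part_type, code_parts in parts.items():
    print(part_type, len(code_parts))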
cognee/tasks/repo_processor/get_repo_dependency_graph.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import os
import aiofiles

import networkx as nx
from typing import Dict, List

from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies


async def get_py_path_and_source(file_path, repo_path):
    relative_path = os.path.relpath(file_path, repo_path)
    try:
        async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
            source_code = await f.read()
        return relative_path, source_code
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return relative_path, None


async def get_py_files_dict(repo_path):
    """Get .py files and their source code"""
    if not os.path.exists(repo_path):
        return {}

    py_files_paths = (
        os.path.join(root, file)
        for root, _, files in os.walk(repo_path) for file in files if file.endswith(".py")
    )

    py_files_dict = {}
    for file_path in py_files_paths:
        relative_path, source_code = await get_py_path_and_source(file_path, repo_path)
        py_files_dict[relative_path] = {"source_code": source_code}

    return py_files_dict

def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bool = True) -> tuple:
    if relative_paths:
        file_path = os.path.relpath(file_path, repo_path)
        dependency = os.path.relpath(dependency, repo_path)
    return (file_path, dependency, {"relation": "depends_directly_on"})


async def get_repo_dependency_graph(repo_path: str) -> nx.DiGraph:
    """Generate a dependency graph for Python files in the given repository path."""
    py_files_dict = await get_py_files_dict(repo_path)

    dependency_graph = nx.DiGraph()

    dependency_graph.add_nodes_from(py_files_dict.items())

    for file_path, metadata in py_files_dict.items():
        source_code = metadata.get("source_code")
        if source_code is None:
            continue

        dependencies = await get_local_script_dependencies(file_path, repo_path)
        dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies]
        dependency_graph.add_edges_from(dependency_edges)
    return dependency_graph