feat: Cog-693 expand dependency graph

Expand each file node into a subgraph containing high-level code parts

- Implemented `extract_code_parts` to parse and extract high-level
components (classes, functions, imports, and top-level code) from Python
source files using `parso`.
- Developed `expand_dependency_graph` to expand Python file nodes into
their components.
- Included a checker script that prints the nodes and edges of the expanded graph.
This commit is contained in:
lxobr 2024-11-23 14:02:21 +01:00 committed by GitHub
parent d33c740dc6
commit 7ec5cffd8e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 308 additions and 0 deletions

View file

@ -0,0 +1,27 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_dependency_graph import get_repo_dependency_graph
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
def main():
    """Print the enriched dependency graph of a repository to stdout.

    The repository path is taken from the single positional command-line
    argument. If that path does not exist, an error line is printed and the
    function returns without building anything.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("repo_path", help="Path to the repository")
    repo_path = cli.parse_args().repo_path

    if not os.path.exists(repo_path):
        print(f"Error: The provided repository path does not exist: {repo_path}")
        return

    # Both pipeline stages are coroutines, so each is driven to completion
    # with its own asyncio.run call.
    dependency_graph = asyncio.run(get_repo_dependency_graph(repo_path))
    enriched_graph = asyncio.run(enrich_dependency_graph(dependency_graph))

    for node in enriched_graph.nodes:
        print(f"Node: {node}")
        outgoing = enriched_graph.out_edges(node, data=True)
        for _source, target, data in outgoing:
            print(f" Edge to {target}, data: {data}")


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,29 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_dependency_graph import get_repo_dependency_graph
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
from cognee.tasks.repo_processor.expand_dependency_graph import expand_dependency_graph
def main():
    """Print the expanded dependency graph of a repository to stdout.

    The repository path comes from the single positional command-line
    argument. The graph is built, enriched, and then expanded before every
    node and its outgoing edges are printed. A non-existent path produces an
    error line and nothing else.
    """
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument("repo_path", help="Path to the repository")
    repo_path = argument_parser.parse_args().repo_path

    if os.path.exists(repo_path):
        # The first two stages are coroutines; the expansion step is synchronous.
        base_graph = asyncio.run(get_repo_dependency_graph(repo_path))
        enriched_graph = asyncio.run(enrich_dependency_graph(base_graph))
        expanded_graph = expand_dependency_graph(enriched_graph)

        for node in expanded_graph.nodes:
            print(f"Node: {node}")
            for _source, target, data in expanded_graph.out_edges(node, data=True):
                print(f" Edge to {target}, data: {data}")
    else:
        print(f"Error: The provided repository path does not exist: {repo_path}")


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,27 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_dependency_graph import get_repo_dependency_graph
def main():
    """Print the direct dependency graph of a repository to stdout.

    The repository path is read from the single positional command-line
    argument. For every node, each edge starting at it is printed together
    with the edge's ``relation`` attribute. A non-existent path produces an
    error line and nothing else.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("repo_path", help="Path to the repository")
    options = cli.parse_args()
    repo_path = options.repo_path

    path_is_missing = not os.path.exists(repo_path)
    if path_is_missing:
        print(f"Error: The provided repository path does not exist: {repo_path}")
        return

    dependency_graph = asyncio.run(get_repo_dependency_graph(repo_path))

    for node in dependency_graph.nodes:
        print(f"Node: {node}")
        for _source, target, attributes in dependency_graph.edges(node, data=True):
            relation = attributes.get('relation')
            print(f" Edge to {target}, Relation: {relation}")


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,56 @@
import networkx as nx
from typing import Dict, List
def topologically_sort_subgraph(subgraph_node_to_indegree: Dict[str, int], graph: nx.DiGraph) -> List[str]:
    """Order the nodes of one subgraph by repeatedly taking the lowest count.

    Args:
        subgraph_node_to_indegree: Starting count for every node to be
            ordered. The mapping is copied, so the caller's dict is untouched.
        graph: Graph whose ``successors`` are used to update the counts.

    Returns:
        All keys of ``subgraph_node_to_indegree``, each exactly once, in the
        order they were selected.
    """
    pending = dict(subgraph_node_to_indegree)
    ordered = []

    while pending:
        # The smallest count wins even when it is not zero, so cycles cannot
        # stall the loop. Ties go to the first node in mapping order.
        chosen, _count = min(pending.items(), key=lambda item: item[1])
        ordered.append(chosen)
        del pending[chosen]

        for neighbour in graph.successors(chosen):
            if neighbour in pending:
                pending[neighbour] -= 1

    return ordered
def topologically_sort(graph: nx.DiGraph) -> List[str]:
    """Order every node of the graph, one weakly connected component at a time.

    Each component is copied into its own subgraph, ordered with
    ``topologically_sort_subgraph``, and the per-component orderings are
    concatenated in the order the components are yielded.
    """
    ordering: List[str] = []

    for component in nx.weakly_connected_components(graph):
        component_graph = graph.subgraph(component).copy()

        # NOTE(review): the helper's parameter is named "indegree", but the
        # count built here is the number of successors of each node — confirm
        # which of the two is intended.
        successor_counts = {}
        for node in component_graph.nodes:
            successor_counts[node] = sum(1 for _ in component_graph.successors(node))

        ordering += topologically_sort_subgraph(successor_counts, component_graph)

    return ordering
def node_enrich_and_connect(graph: nx.MultiDiGraph, topological_order: List[str], node: str) -> None:
    """Store the node's rank and add 'depends_on' edges to earlier-ranked descendants.

    Args:
        graph: Graph to annotate; it is modified in place.
        topological_order: Node ordering; a node's rank is its index in it.
        node: Node to process. It must appear in ``topological_order``.

    Raises:
        ValueError: If ``node`` is not present in ``topological_order``.
    """
    topological_rank = topological_order.index(node)
    graph.nodes[node]['topological_rank'] = topological_rank

    # Build the "ranked at or before this node" lookup once. Re-slicing the
    # list and scanning it for every descendant made this step quadratic.
    ranked_at_or_before = set(topological_order[:topological_rank + 1])

    node_descendants = nx.descendants(graph, node)
    # A self-loop makes the node count as its own descendant.
    if graph.has_edge(node, node):
        node_descendants.add(node)

    for desc in node_descendants:
        if desc in ranked_at_or_before:
            graph.add_edge(node, desc, relation='depends_on')
async def enrich_dependency_graph(graph: nx.DiGraph) -> nx.MultiDiGraph:
    """Return a multigraph copy of ``graph`` with ranks and 'depends_on' edges.

    The input is converted to a ``MultiDiGraph``, its nodes are ordered with
    ``topologically_sort``, and every node that appears in that ordering is
    passed to ``node_enrich_and_connect``.
    """
    multigraph = nx.MultiDiGraph(graph)
    order = topologically_sort(multigraph)
    ranked_nodes = set(order)

    for node in multigraph.nodes:
        if node in ranked_nodes:
            node_enrich_and_connect(multigraph, order, node)

    return multigraph

View file

@ -0,0 +1,49 @@
import networkx as nx
from cognee.tasks.repo_processor.extract_code_parts import extract_code_parts
from cognee.tasks.repo_processor import logger
def _add_code_parts_nodes_and_edges(graph, parent_node_id, part_type, code_parts):
    """Attach one child node per non-blank code part to the parent node.

    Each child id is ``{parent_node_id}_{part_type}_{index}``, where the index
    is the part's position in ``code_parts``; blank parts are skipped but
    still consume an index. Every child is linked from the parent by a
    ``contains`` edge.
    """
    if not code_parts:
        logger.debug(f"No code parts to add for parent_node_id {parent_node_id} and part_type {part_type}.")
        return

    for position, snippet in enumerate(code_parts):
        if snippet.strip():
            child_id = f"{parent_node_id}_{part_type}_{position}"
            graph.add_node(child_id, source_code=snippet, node_type=part_type)
            graph.add_edge(parent_node_id, child_id, relation="contains")
        else:
            logger.warning(f"Empty code part in parent_node_id {parent_node_id} and part_type {part_type}.")
def _process_single_node(graph, node_id, node_data):
    """Mark a node as a Python file and attach its code parts as child nodes.

    Args:
        graph: Graph to modify in place.
        node_id: Identifier of the file node inside ``graph``.
        node_data: Attribute mapping of the node; its ``source_code`` entry
            is the text that gets split into parts.

    Nodes whose source is missing, ``None`` or blank are only tagged with
    ``node_type`` and otherwise left alone. Extraction failures are logged
    and do not propagate.
    """
    graph.nodes[node_id]["node_type"] = "python_file"

    # The key can be present with a None value (for example when the file
    # could not be read). A `.get(key, "")` default does not cover that case
    # and `.strip()` would then raise AttributeError.
    source_code = node_data.get("source_code") or ""
    if not source_code.strip():
        logger.warning(f"Node {node_id} has no or empty 'source_code'. Skipping.")
        return

    try:
        code_parts_dict = extract_code_parts(source_code)
    except Exception as e:
        # Deliberate best effort: one bad file must not abort the expansion.
        logger.error(f"Error processing node {node_id}: {e}")
        return

    for part_type, code_parts in code_parts_dict.items():
        _add_code_parts_nodes_and_edges(graph, node_id, part_type, code_parts)
def expand_dependency_graph(graph: nx.MultiDiGraph) -> nx.MultiDiGraph:
    """Return a copy of ``graph`` in which file nodes gain code-part child nodes.

    The nodes of the original graph are walked while all changes are applied
    to the copy, so the input graph is never modified. Nodes without any
    attributes are skipped with a warning.
    """
    result = graph.copy()

    for node_id, attributes in graph.nodes(data=True):
        if attributes:
            _process_single_node(result, node_id, attributes)
        else:
            # An empty attribute dict means there is nothing to expand.
            logger.warning(f"Node {node_id} has no data. Skipping.")

    return result

View file

@ -0,0 +1,59 @@
from typing import Dict, List
import parso
from cognee.tasks.repo_processor import logger
def _resolve_code_type(node, child_to_code_type):
    """Return the ``parts_dict`` key for a top-level node, or None if it has none."""
    if node.type == 'simple_stmt':
        # A simple statement wraps one or more small statements plus ';'
        # separators and the trailing newline; only the statements matter.
        statements = [
            part for part in node.children
            if part.type not in ('newline', 'operator')
        ]
        only_imports = bool(statements) and all(
            part.type in ('import_name', 'import_from') for part in statements
        )
        return "imports" if only_imports else None

    # Decorators and `async` wrap the real definition as their last child.
    while node.type in ('decorated', 'async_stmt', 'async_funcdef'):
        node = node.children[-1]
    return child_to_code_type.get(node.type)


def _extract_parts_from_module(module, parts_dict: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Extract code parts from a parsed module.

    Args:
        module: Parsed module whose direct ``children`` are classified.
        parts_dict: Mapping with ``classes``, ``functions``, ``imports`` and
            ``top_level_code`` lists; it is filled in place.

    Returns:
        The same ``parts_dict``. Consecutive plain statements are merged into
        one ``top_level_code`` entry, and all imports are merged into a single
        ``imports`` entry.

    The parser wraps top-level statements, imports included, in
    ``simple_stmt`` nodes, so imports are recognised by looking inside that
    wrapper. Decorated and ``async`` definitions are unwrapped the same way.
    Other compound statements (``if``, ``for``, ``try``, ...) are not
    collected.
    """
    current_top_level_code = []
    child_to_code_type = {
        'classdef': "classes",
        'funcdef': "functions",
        'import_name': "imports",
        'import_from': "imports",
    }

    for child in module.children:
        code_type = _resolve_code_type(child, child_to_code_type)

        if child.type == 'simple_stmt' and code_type is None:
            current_top_level_code.append(child.get_code())
            continue

        # Anything else ends the current run of plain top-level statements.
        if current_top_level_code:
            parts_dict["top_level_code"].append('\n'.join(current_top_level_code))
            current_top_level_code = []

        if code_type is not None:
            parts_dict[code_type].append(child.get_code())

    if current_top_level_code:
        parts_dict["top_level_code"].append('\n'.join(current_top_level_code))

    if parts_dict["imports"]:
        parts_dict["imports"] = ['\n'.join(parts_dict["imports"])]

    return parts_dict
def extract_code_parts(source_code: str) -> Dict[str, List[str]]:
    """Split Python source text into classes, functions, imports and top-level code.

    Returns a dict with the keys ``classes``, ``functions``, ``imports`` and
    ``top_level_code``. All four lists stay empty when the source is blank,
    when parsing raises, or when the parsed module has no children; each of
    those cases is logged instead of raised.
    """
    categories = ("classes", "functions", "imports", "top_level_code")
    parts_dict = {category: [] for category in categories}

    if source_code.strip():
        try:
            module = parso.parse(source_code)
        except Exception as e:
            logger.error(f"Error parsing source code: {e}")
            return parts_dict

        if module.children:
            return _extract_parts_from_module(module, parts_dict)

        logger.warning("Parsed module has no children (empty or invalid source code).")
    else:
        logger.warning("Empty source_code provided.")

    return parts_dict

View file

@ -0,0 +1,61 @@
import os
import aiofiles
import networkx as nx
from typing import Dict, List
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
async def get_py_path_and_source(file_path, repo_path):
    """Read one file and return ``(path relative to repo_path, source text)``.

    The file is opened as UTF-8 text. If opening or reading fails, the error
    is printed and the source element of the returned pair is ``None``.
    """
    relative_path = os.path.relpath(file_path, repo_path)

    try:
        async with aiofiles.open(file_path, "r", encoding="utf-8") as handle:
            contents = await handle.read()
    except Exception as error:
        print(f"Error reading file {file_path}: {error}")
        contents = None

    return relative_path, contents
async def get_py_files_dict(repo_path):
    """Collect every ``.py`` file under ``repo_path`` together with its source.

    Returns a dict keyed by the file path relative to ``repo_path``; each
    value is ``{"source_code": <text or None>}``. A path that does not exist
    yields an empty dict.
    """
    collected = {}
    if not os.path.exists(repo_path):
        return collected

    for directory, _subdirs, filenames in os.walk(repo_path):
        for filename in filenames:
            if not filename.endswith(".py"):
                continue
            full_path = os.path.join(directory, filename)
            relative_path, source_code = await get_py_path_and_source(full_path, repo_path)
            collected[relative_path] = {"source_code": source_code}

    return collected
def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bool = True) -> tuple:
    """Build a ``(source, target, attributes)`` edge tuple for one dependency.

    When ``relative_paths`` is true, both endpoints are rewritten relative to
    ``repo_path``; otherwise they are used as given. The attributes dict
    always carries ``relation == "depends_directly_on"``.
    """
    source, target = file_path, dependency
    if relative_paths:
        source, target = (os.path.relpath(path, repo_path) for path in (file_path, dependency))
    return source, target, {"relation": "depends_directly_on"}
async def get_repo_dependency_graph(repo_path: str) -> nx.DiGraph:
    """Generate a dependency graph for Python files in the given repository path.

    Every ``.py`` file found under ``repo_path`` becomes a node keyed by its
    path relative to ``repo_path``, with its source stored in the
    ``source_code`` attribute. Files whose source could not be read stay in
    the graph as nodes but contribute no edges.

    Args:
        repo_path: Root directory of the repository to scan.

    Returns:
        A directed graph with one ``depends_directly_on`` edge per dependency
        reported by ``get_local_script_dependencies``.
    """
    py_files_dict = await get_py_files_dict(repo_path)

    dependency_graph = nx.DiGraph()
    dependency_graph.add_nodes_from(py_files_dict.items())

    for file_path, metadata in py_files_dict.items():
        if metadata.get("source_code") is None:
            continue

        # Keys are relative to repo_path. Passing them on unchanged made
        # get_edge re-relativise them against the current working directory,
        # so edge endpoints only matched node ids when cwd == repo_path.
        # Resolve against the repository root first.
        script_path = os.path.join(repo_path, file_path)

        # NOTE(review): assumes get_local_script_dependencies accepts a path
        # rooted at repo_path — confirm against its implementation.
        dependencies = await get_local_script_dependencies(script_path, repo_path)
        dependency_edges = [get_edge(script_path, dependency, repo_path) for dependency in dependencies]
        dependency_graph.add_edges_from(dependency_edges)

    return dependency_graph