Merge remote-tracking branch 'origin/code-graph'

This commit is contained in:
Boris Arzentar 2024-12-01 11:52:03 +01:00
commit 925346986e

View file

@ -2,9 +2,9 @@ import os
from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5
import aiofiles
from tqdm.asyncio import tqdm
from concurrent.futures import ProcessPoolExecutor
import asyncio
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, Repository
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
@ -45,46 +45,54 @@ def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bo
return (file_path, dependency, {"relation": "depends_directly_on"})
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list[DataPoint], None]:
def run_coroutine(coroutine_func, *args, **kwargs):
    """Run an async callable to completion on a fresh event loop.

    Creates a new event loop, installs it as the current loop, drives
    ``coroutine_func(*args, **kwargs)`` until it finishes and returns its
    result. This gives synchronous callers (for example an executor worker)
    a plain-function entry point to a coroutine.

    Args:
        coroutine_func: Callable returning an awaitable when invoked.
        *args: Positional arguments forwarded to ``coroutine_func``.
        **kwargs: Keyword arguments forwarded to ``coroutine_func``.

    Returns:
        Whatever the awaited coroutine returns.

    Raises:
        Any exception raised by the coroutine is propagated unchanged.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coroutine_func(*args, **kwargs))
    finally:
        # Close the loop even when the coroutine raises, so its selector and
        # self-pipe file descriptors are not leaked.
        loop.close()
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, None]:
    """Generate a dependency graph for Python files in the given repository path.

    Yields the ``Repository`` node for ``repo_path`` first, then one
    ``CodeFile`` for every entry of ``get_py_files_dict(repo_path)`` whose
    metadata has a non-``None`` ``"source_code"``. Each ``CodeFile`` lists the
    files it depends on (as reported by ``get_local_script_dependencies``) in
    ``depends_on``, or ``None`` when there are no dependencies.

    Dependency extraction is fanned out to a process pool; every worker runs
    the async extractor on its own event loop through ``run_coroutine``.

    Args:
        repo_path: Path of the repository to analyse.

    Yields:
        The ``Repository`` node, followed by one ``CodeFile`` per source file.
    """
    py_files_dict = await get_py_files_dict(repo_path)

    repo = Repository(
        id=uuid5(NAMESPACE_OID, repo_path),
        path=repo_path,
    )

    yield repo

    # Filter once and reuse the same list both for submitting work and for
    # pairing results, so each dependency list stays aligned with the file it
    # was computed for even when some files have no source code.
    files_with_source = [
        (file_path, metadata.get("source_code"))
        for file_path, metadata in py_files_dict.items()
        if metadata.get("source_code") is not None
    ]

    if not files_with_source:
        return

    loop = asyncio.get_running_loop()

    with ProcessPoolExecutor(max_workers=12) as executor:
        tasks = [
            loop.run_in_executor(
                executor,
                run_coroutine,
                get_local_script_dependencies,
                os.path.join(repo_path, file_path),
                repo_path,
            )
            for file_path, _ in files_with_source
        ]

        results = await asyncio.gather(*tasks)

    # All results are in hand, so the pool is already shut down before the
    # consumer starts processing the yielded nodes.
    for (file_path, source_code), dependencies in zip(files_with_source, results):
        yield CodeFile(
            id=uuid5(NAMESPACE_OID, file_path),
            source_code=source_code,
            extracted_id=file_path,
            part_of=repo,
            depends_on=[
                CodeFile(
                    id=uuid5(NAMESPACE_OID, dependency),
                    extracted_id=dependency,
                    part_of=repo,
                ) for dependency in dependencies
            ] if dependencies else None,
        )