Merge remote-tracking branch 'origin/code-graph'

This commit is contained in:
Boris Arzentar 2024-12-01 11:52:03 +01:00
commit 925346986e

View file

@ -2,9 +2,9 @@ import os
from typing import AsyncGenerator from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5 from uuid import NAMESPACE_OID, uuid5
import aiofiles import aiofiles
from tqdm.asyncio import tqdm from concurrent.futures import ProcessPoolExecutor
import asyncio
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, Repository from cognee.shared.CodeGraphEntities import CodeFile, Repository
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
@ -45,46 +45,54 @@ def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bo
return (file_path, dependency, {"relation": "depends_directly_on"}) return (file_path, dependency, {"relation": "depends_directly_on"})
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list[DataPoint], None]: def run_coroutine(coroutine_func, *args, **kwargs):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(coroutine_func(*args, **kwargs))
loop.close()
return result
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, None]:
"""Generate a dependency graph for Python files in the given repository path.""" """Generate a dependency graph for Python files in the given repository path."""
py_files_dict = await get_py_files_dict(repo_path) py_files_dict = await get_py_files_dict(repo_path)
repo = Repository( repo = Repository(
id = uuid5(NAMESPACE_OID, repo_path), id=uuid5(NAMESPACE_OID, repo_path),
path = repo_path, path=repo_path,
) )
# data_points = [repo]
yield repo yield repo
# dependency_graph = nx.DiGraph() with ProcessPoolExecutor(max_workers=12) as executor:
loop = asyncio.get_event_loop()
# dependency_graph.add_nodes_from(py_files_dict.items()) tasks = [
loop.run_in_executor(
async for file_path, metadata in tqdm(py_files_dict.items(), desc="Repo dependency graph", unit="file"): executor,
source_code = metadata.get("source_code") run_coroutine,
if source_code is None: get_local_script_dependencies,
continue os.path.join(repo_path, file_path),
repo_path
dependencies = await get_local_script_dependencies(os.path.join(repo_path, file_path), repo_path)
# data_points.append()
yield CodeFile(
id = uuid5(NAMESPACE_OID, file_path),
source_code = source_code,
extracted_id = file_path,
part_of = repo,
depends_on = [
CodeFile(
id = uuid5(NAMESPACE_OID, dependency),
extracted_id = dependency,
part_of = repo,
) for dependency in dependencies
] if len(dependencies) else None,
) )
# dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies] for file_path, metadata in py_files_dict.items()
if metadata.get("source_code") is not None
]
# dependency_graph.add_edges_from(dependency_edges) results = await asyncio.gather(*tasks)
# return data_points for (file_path, metadata), dependencies in zip(py_files_dict.items(), results):
# return dependency_graph source_code = metadata.get("source_code")
yield CodeFile(
id=uuid5(NAMESPACE_OID, file_path),
source_code=source_code,
extracted_id=file_path,
part_of=repo,
depends_on=[
CodeFile(
id=uuid5(NAMESPACE_OID, dependency),
extracted_id=dependency,
part_of=repo,
) for dependency in dependencies
] if dependencies else None,
)