def run_coroutine(coroutine_func, *args, **kwargs):
    """Run an async callable to completion on a private event loop.

    A process pool can only execute plain callables, so the coroutine
    function and its arguments are passed separately and awaited here.

    Args:
        coroutine_func: Callable that returns an awaitable when invoked with
            ``*args`` and ``**kwargs``.
        *args: Positional arguments forwarded to ``coroutine_func``.
        **kwargs: Keyword arguments forwarded to ``coroutine_func``.

    Returns:
        The value the awaitable resolves to.

    Raises:
        Exception: Whatever ``coroutine_func`` (or its awaitable) raises is
            propagated unchanged, after the loop has been closed.
    """
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(coroutine_func(*args, **kwargs))
    finally:
        # Release the loop even when the coroutine fails, and do not leave a
        # closed loop registered as this thread's current loop.
        asyncio.set_event_loop(None)
        loop.close()


async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, None]:
    """Generate a dependency graph for Python files in the given repository path.

    Yields the ``Repository`` node first, then one ``CodeFile`` for every
    entry of ``get_py_files_dict(repo_path)`` whose ``source_code`` is not
    ``None``. Dependencies are resolved in worker processes via
    ``get_local_script_dependencies``.

    Args:
        repo_path: Path of the repository to analyse. It is also the seed for
            the repository's ``uuid5`` identifier.

    Yields:
        The ``Repository`` instance, followed by ``CodeFile`` instances whose
        ``depends_on`` is a list of stub ``CodeFile`` objects, or ``None``
        when the file has no dependencies.
    """
    py_files_dict = await get_py_files_dict(repo_path)

    repo = Repository(
        id=uuid5(NAMESPACE_OID, repo_path),
        path=repo_path,
    )

    yield repo

    # Select the analysable files exactly once. The same list drives both the
    # task submission and the result pairing, so every result is matched with
    # the file it was computed for even when some files have no source code.
    files_with_source = [
        (file_path, metadata["source_code"])
        for file_path, metadata in py_files_dict.items()
        if metadata.get("source_code") is not None
    ]
    if not files_with_source:
        return

    loop = asyncio.get_running_loop()

    with ProcessPoolExecutor(max_workers=12) as executor:
        tasks = [
            loop.run_in_executor(
                executor,
                run_coroutine,
                get_local_script_dependencies,
                os.path.join(repo_path, file_path),
                repo_path,
            )
            for file_path, _ in files_with_source
        ]

        # gather() returns results in submission order.
        results = await asyncio.gather(*tasks)

    # The pool is already shut down here, so worker processes are not kept
    # alive while the consumer processes each yielded item.
    for (file_path, source_code), dependencies in zip(files_with_source, results):
        yield CodeFile(
            id=uuid5(NAMESPACE_OID, file_path),
            source_code=source_code,
            extracted_id=file_path,
            part_of=repo,
            depends_on=[
                CodeFile(
                    id=uuid5(NAMESPACE_OID, dependency),
                    extracted_id=dependency,
                    part_of=repo,
                )
                for dependency in dependencies
            ] if dependencies else None,
        )