feat: Implements multiprocessing for get_repo_file_dependencies task (#43)

This commit is contained in:
hajdul88 2024-12-01 11:51:04 +01:00 committed by GitHub
parent bbaf78f54e
commit 198f71b9be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2,9 +2,9 @@ import os
from typing import AsyncGenerator from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5 from uuid import NAMESPACE_OID, uuid5
import aiofiles import aiofiles
from tqdm.asyncio import tqdm from concurrent.futures import ProcessPoolExecutor
import asyncio
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, Repository from cognee.shared.CodeGraphEntities import CodeFile, Repository
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
@ -45,46 +45,54 @@ def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bo
return (file_path, dependency, {"relation": "depends_directly_on"}) return (file_path, dependency, {"relation": "depends_directly_on"})
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list[DataPoint], None]: def run_coroutine(coroutine_func, *args, **kwargs):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(coroutine_func(*args, **kwargs))
loop.close()
return result
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, None]:
"""Generate a dependency graph for Python files in the given repository path.""" """Generate a dependency graph for Python files in the given repository path."""
py_files_dict = await get_py_files_dict(repo_path) py_files_dict = await get_py_files_dict(repo_path)
repo = Repository( repo = Repository(
id = uuid5(NAMESPACE_OID, repo_path), id=uuid5(NAMESPACE_OID, repo_path),
path = repo_path, path=repo_path,
) )
# data_points = [repo]
yield repo yield repo
# dependency_graph = nx.DiGraph() with ProcessPoolExecutor(max_workers=12) as executor:
loop = asyncio.get_event_loop()
# dependency_graph.add_nodes_from(py_files_dict.items()) tasks = [
loop.run_in_executor(
executor,
run_coroutine,
get_local_script_dependencies,
os.path.join(repo_path, file_path),
repo_path
)
for file_path, metadata in py_files_dict.items()
if metadata.get("source_code") is not None
]
async for file_path, metadata in tqdm(py_files_dict.items(), desc="Repo dependency graph", unit="file"): results = await asyncio.gather(*tasks)
source_code = metadata.get("source_code")
if source_code is None:
continue
dependencies = await get_local_script_dependencies(os.path.join(repo_path, file_path), repo_path) for (file_path, metadata), dependencies in zip(py_files_dict.items(), results):
source_code = metadata.get("source_code")
# data_points.append() yield CodeFile(
yield CodeFile( id=uuid5(NAMESPACE_OID, file_path),
id = uuid5(NAMESPACE_OID, file_path), source_code=source_code,
source_code = source_code, extracted_id=file_path,
extracted_id = file_path, part_of=repo,
part_of = repo, depends_on=[
depends_on = [ CodeFile(
CodeFile( id=uuid5(NAMESPACE_OID, dependency),
id = uuid5(NAMESPACE_OID, dependency), extracted_id=dependency,
extracted_id = dependency, part_of=repo,
part_of = repo, ) for dependency in dependencies
) for dependency in dependencies ] if dependencies else None,
] if len(dependencies) else None, )
)
# dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies]
# dependency_graph.add_edges_from(dependency_edges)
# return data_points
# return dependency_graph