feat: Implements multiprocessing for get_repo_file_dependencies task (#43)
This commit is contained in:
parent
bbaf78f54e
commit
198f71b9be
1 changed files with 41 additions and 33 deletions
|
|
@@ -2,9 +2,9 @@ import os
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
import aiofiles
|
import aiofiles
|
||||||
from tqdm.asyncio import tqdm
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
|
||||||
from cognee.shared.CodeGraphEntities import CodeFile, Repository
|
from cognee.shared.CodeGraphEntities import CodeFile, Repository
|
||||||
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
|
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
|
||||||
|
|
||||||
|
|
@@ -45,46 +45,54 @@ def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bo
|
||||||
return (file_path, dependency, {"relation": "depends_directly_on"})
|
return (file_path, dependency, {"relation": "depends_directly_on"})
|
||||||
|
|
||||||
|
|
||||||
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list[DataPoint], None]:
|
def run_coroutine(coroutine_func, *args, **kwargs):
    """Run an async callable to completion on a fresh event loop.

    This is a synchronous wrapper, so it can be handed to an executor
    (which can only call plain functions) to drive a coroutine function.

    Args:
        coroutine_func: Callable that returns an awaitable when invoked
            with ``*args`` and ``**kwargs``.
        *args: Positional arguments forwarded to ``coroutine_func``.
        **kwargs: Keyword arguments forwarded to ``coroutine_func``.

    Returns:
        Whatever the awaited coroutine returns.

    Raises:
        Any exception raised by ``coroutine_func`` or its coroutine is
        propagated unchanged; the event loop is closed either way.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coroutine_func(*args, **kwargs))
    finally:
        # Close the loop even when the coroutine fails; otherwise every
        # failing call would leave an open loop (and its selector/file
        # descriptors) behind in a reused worker.
        loop.close()
|
||||||
|
|
||||||
|
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, None]:
    """Generate a dependency graph for Python files in the given repository path.

    Yields the ``Repository`` node first, then one ``CodeFile`` per Python
    file that has source code. Each ``CodeFile`` carries its direct
    dependencies (as stub ``CodeFile`` nodes) in ``depends_on``, or ``None``
    when it has none. Dependency extraction for all files is fanned out to
    a process pool and awaited together before any ``CodeFile`` is yielded.

    Args:
        repo_path: Path of the repository to scan.

    Yields:
        The ``Repository`` data point, followed by ``CodeFile`` data points.
        NOTE(review): the return annotation says ``list`` but single objects
        are yielded; annotation kept unchanged for interface compatibility.
    """
    py_files_dict = await get_py_files_dict(repo_path)

    repo = Repository(
        id=uuid5(NAMESPACE_OID, repo_path),
        path=repo_path,
    )

    yield repo

    # Filter once and reuse the same list for both task creation and result
    # pairing. Zipping the results against the unfiltered dict would shift
    # every dependency list onto the wrong file as soon as one file has no
    # source code.
    files_with_source = [
        (file_path, metadata.get("source_code"))
        for file_path, metadata in py_files_dict.items()
        if metadata.get("source_code") is not None
    ]

    with ProcessPoolExecutor(max_workers=12) as executor:
        # We are inside a coroutine, so a loop is guaranteed to be running.
        loop = asyncio.get_running_loop()

        tasks = [
            loop.run_in_executor(
                executor,
                run_coroutine,
                get_local_script_dependencies,
                os.path.join(repo_path, file_path),
                repo_path,
            )
            for file_path, _ in files_with_source
        ]

        # gather() preserves input order, so results[i] belongs to
        # files_with_source[i].
        results = await asyncio.gather(*tasks)

    # All results are in hand; yield outside the pool context so worker
    # processes are not kept alive while the consumer handles each item.
    for (file_path, source_code), dependencies in zip(files_with_source, results):
        yield CodeFile(
            id=uuid5(NAMESPACE_OID, file_path),
            source_code=source_code,
            extracted_id=file_path,
            part_of=repo,
            depends_on=[
                CodeFile(
                    id=uuid5(NAMESPACE_OID, dependency),
                    extracted_id=dependency,
                    part_of=repo,
                )
                for dependency in dependencies
            ] if dependencies else None,
        )
|
|
||||||
# dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies]
|
|
||||||
|
|
||||||
# dependency_graph.add_edges_from(dependency_edges)
|
|
||||||
|
|
||||||
# return data_points
|
|
||||||
# return dependency_graph
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue