Merge remote-tracking branch 'origin/code-graph'
This commit is contained in:
commit
925346986e
1 changed files with 41 additions and 33 deletions
|
|
@ -2,9 +2,9 @@ import os
|
|||
from typing import AsyncGenerator
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
import aiofiles
|
||||
from tqdm.asyncio import tqdm
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import asyncio
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.shared.CodeGraphEntities import CodeFile, Repository
|
||||
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
|
||||
|
||||
|
|
@ -45,46 +45,54 @@ def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bo
|
|||
return (file_path, dependency, {"relation": "depends_directly_on"})
|
||||
|
||||
|
||||
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list[DataPoint], None]:
|
||||
def run_coroutine(coroutine_func, *args, **kwargs):
    """Run a coroutine function to completion on a fresh event loop and return its result.

    A new event loop is created and installed as the current loop for the
    calling thread, ``coroutine_func(*args, **kwargs)`` is driven to completion
    on it, and the loop is closed afterwards.

    Args:
        coroutine_func: Callable returning an awaitable when called with
            ``*args`` and ``**kwargs``.
        *args: Positional arguments forwarded to ``coroutine_func``.
        **kwargs: Keyword arguments forwarded to ``coroutine_func``.

    Returns:
        Whatever the awaited coroutine returns.

    Raises:
        Any exception raised by ``coroutine_func`` or by the awaitable it
        returns is propagated unchanged.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coroutine_func(*args, **kwargs))
    finally:
        # Close the loop even when the coroutine raises; otherwise every
        # failed call would leak an open event loop in the calling process.
        loop.close()
|
||||
|
||||
async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, None]:
    """Generate a dependency graph for Python files in the given repository path.

    First yields a ``Repository`` built from ``repo_path``. Then, for every
    entry of ``get_py_files_dict(repo_path)`` whose metadata has a non-None
    ``"source_code"``, yields a ``CodeFile`` carrying that source and a
    ``depends_on`` list with one ``CodeFile`` per dependency reported by
    ``get_local_script_dependencies`` (``None`` when there are none).

    Dependency extraction is fanned out to a process pool; each worker runs
    the ``get_local_script_dependencies`` coroutine function through
    ``run_coroutine``.

    Args:
        repo_path: Path of the repository; also used as the seed for the
            repository id and joined with each file path.

    Yields:
        The ``Repository`` instance, followed by one ``CodeFile`` per file
        that has source code.
    """
    # NOTE(review): the return annotation says ``list`` but the values yielded
    # are Repository / CodeFile instances; kept as-is for interface stability.
    py_files_dict = await get_py_files_dict(repo_path)

    repo = Repository(
        id=uuid5(NAMESPACE_OID, repo_path),
        path=repo_path,
    )

    yield repo

    # Select the files to analyse ONCE, so the submitted tasks and the results
    # zipped back onto them are guaranteed to line up. Filtering the tasks but
    # zipping against the unfiltered dict would attach dependencies to the
    # wrong files as soon as one entry has no source code.
    files_with_source = [
        (file_path, metadata.get("source_code"))
        for file_path, metadata in py_files_dict.items()
        if metadata.get("source_code") is not None
    ]

    if not files_with_source:
        return

    # We are inside a coroutine, so a loop is running by definition.
    loop = asyncio.get_running_loop()

    with ProcessPoolExecutor(max_workers=12) as executor:
        tasks = [
            loop.run_in_executor(
                executor,
                run_coroutine,
                get_local_script_dependencies,
                os.path.join(repo_path, file_path),
                repo_path,
            )
            for file_path, _ in files_with_source
        ]

        # gather() preserves submission order, matching files_with_source.
        results = await asyncio.gather(*tasks)

    # Yield only after the pool has been shut down, so worker processes are
    # not kept alive while the consumer is paused between items.
    for (file_path, source_code), dependencies in zip(files_with_source, results):
        yield CodeFile(
            id=uuid5(NAMESPACE_OID, file_path),
            source_code=source_code,
            extracted_id=file_path,
            part_of=repo,
            depends_on=[
                CodeFile(
                    id=uuid5(NAMESPACE_OID, dependency),
                    extracted_id=dependency,
                    part_of=repo,
                ) for dependency in dependencies
            ] if dependencies else None,
        )
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue