cognee/cognee/tasks/repo_processor/get_repo_file_dependencies.py
Fardeen Malik fdb0c8292a
Update cognee/tasks/repo_processor/get_repo_file_dependencies.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-08-18 22:04:47 +05:30

200 lines
7.3 KiB
Python

import asyncio
import math
import os
# from concurrent.futures import ProcessPoolExecutor
from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, Repository
async def get_source_code_files(repo_path, language_config: dict[str, list[str]] | None = None):
"""
Retrieve source code files from the specified repository path for multiple languages.
Parameters:
-----------
- repo_path: The file path to the repository to search for source files.
- language_config: dict mapping language names to file extensions, e.g.,
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
Returns:
--------
A list of (absolute_path, language) tuples for source code files.
"""
def _get_language_from_extension(file, language_config):
for lang, exts in language_config.items():
for ext in exts:
if file.endswith(ext):
return lang
return None
# Default config if not provided
if language_config is None:
language_config = {
'python': ['.py'],
'javascript': ['.js', '.jsx'],
'typescript': ['.ts', '.tsx'],
'java': ['.java'],
'csharp': ['.cs'],
'go': ['.go'],
'rust': ['.rs'],
'cpp': ['.cpp', '.c', '.h', '.hpp'],
}
if not os.path.exists(repo_path):
return []
source_code_files = set()
for root, _, files in os.walk(repo_path):
for file in files:
lang = _get_language_from_extension(file, language_config)
if lang is None:
continue
# Exclude tests and common build/venv directories
excluded_dirs = {
".venv", "venv", "env", ".env", "site-packages",
"node_modules", "dist", "build", ".git",
"tests", "test",
}
root_parts = set(os.path.normpath(root).split(os.sep))
base_name, _ext = os.path.splitext(file)
if (
base_name.startswith("test_")
or base_name.endswith("_test") # catches Go's *_test.go and similar
or ".test." in file
or ".spec." in file
or (excluded_dirs & root_parts)
):
continue
file_path = os.path.abspath(os.path.join(root, file))
if os.path.getsize(file_path) == 0:
continue
source_code_files.add((file_path, lang))
return sorted(list(source_code_files))
def run_coroutine(coroutine_func, *args, **kwargs):
    """
    Run a coroutine function until it completes.

    This function creates a new asyncio event loop, sets it as the current loop, and
    executes the given coroutine function with the provided arguments. The loop is
    closed once the coroutine finishes, whether it returns normally or raises.
    Intended for use in environments where an existing event loop is not available
    or desirable.

    Parameters:
    -----------
    - coroutine_func: The coroutine function to be run.
    - *args: Positional arguments to pass to the coroutine function.
    - **kwargs: Keyword arguments to pass to the coroutine function.

    Returns:
    --------
    The result returned by the coroutine after completion.

    Raises:
    -------
    Any exception raised by the coroutine is propagated unchanged.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coroutine_func(*args, **kwargs))
    finally:
        # Close even on failure; otherwise every failed call leaks a loop
        # (and its selector / file descriptors).
        loop.close()
async def _make_codefile_stub(repo_path: str, file_path: str, lang: str) -> "CodeFile":
    """
    Build a minimal CodeFile for a language that has no dependency extractor yet.

    The file is decoded as UTF-8 with undecodable bytes replaced (U+FFFD), so a single
    file in a legacy encoding does not abort the whole extraction batch.
    """
    import aiofiles

    async with aiofiles.open(file_path, "r", encoding="utf-8", errors="replace") as f:
        source = await f.read()
    return CodeFile(
        id=uuid5(NAMESPACE_OID, file_path),
        name=os.path.relpath(file_path, repo_path),
        file_path=file_path,
        language=lang,
        source_code=source,
    )


async def get_repo_file_dependencies(
    repo_path: str, detailed_extraction: bool = False, supported_languages: list[str] | None = None
) -> AsyncGenerator[DataPoint, None]:
    """
    Generate a dependency graph for source files (multi-language) in the given repository path.

    First yields a Repository object, then one CodeFile per discovered source file.
    Python files go through the Python dependency extractor. Files in other languages
    currently get a stub CodeFile that only carries their source text. Files are
    processed concurrently in chunks of 100.

    Parameters:
    -----------
    - repo_path (str): The file path to the repository to process. A single-element
      list is also accepted and unwrapped.
    - detailed_extraction (bool): Whether to perform a detailed extraction of code parts
      (forwarded to the Python extractor).
    - supported_languages (list | None): Subset of languages to include; if None, use
      defaults. Names not present in the default mapping are ignored.

    Raises:
    -------
    - FileNotFoundError: if ``repo_path`` does not exist.
    """
    if isinstance(repo_path, list) and len(repo_path) == 1:
        repo_path = repo_path[0]

    if not os.path.exists(repo_path):
        raise FileNotFoundError(f"Repository path {repo_path} does not exist.")

    # Build language config from supported_languages
    default_language_config = {
        'python': ['.py'],
        'javascript': ['.js', '.jsx'],
        'typescript': ['.ts', '.tsx'],
        'java': ['.java'],
        'csharp': ['.cs'],
        'go': ['.go'],
        'rust': ['.rs'],
        'cpp': ['.cpp', '.c', '.h', '.hpp'],
    }
    if supported_languages is not None:
        language_config = {
            k: v for k, v in default_language_config.items() if k in supported_languages
        }
    else:
        language_config = default_language_config

    source_code_files = await get_source_code_files(repo_path, language_config=language_config)

    repo = Repository(
        id=uuid5(NAMESPACE_OID, repo_path),
        path=repo_path,
    )
    yield repo

    # Import dependency extractors for each language (Python for now, extend later).
    # TODO: Add other language extractors here
    from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies

    chunk_size = 100
    for start in range(0, len(source_code_files), chunk_size):
        chunk = source_code_files[start : start + chunk_size]
        tasks = [
            get_local_script_dependencies(repo_path, file_path, detailed_extraction)
            if lang == 'python'
            # Placeholder: minimal CodeFile for languages without an extractor.
            else _make_codefile_stub(repo_path, file_path, lang)
            for file_path, lang in chunk
        ]

        results: list[CodeFile] = await asyncio.gather(*tasks)
        for source_code_file in results:
            source_code_file.part_of = repo
            # Python extractor results may lack a language; backfill it from the suffix.
            if (
                getattr(source_code_file, 'language', None) is None
                and source_code_file.file_path.endswith('.py')
            ):
                source_code_file.language = 'python'
            yield source_code_file