Extend CodeGraph pipeline for multi-language support (closes #1160)
This commit is contained in:
parent
2182f619df
commit
0a330683de
4 changed files with 95 additions and 40 deletions
|
|
@ -40,8 +40,13 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
detailed_extraction = True
|
detailed_extraction = True
|
||||||
|
|
||||||
|
|
||||||
|
# Multi-language support: allow passing supported_languages
|
||||||
|
supported_languages = [
|
||||||
|
'python', 'javascript', 'typescript', 'java', 'csharp', 'go', 'rust', 'cpp'
|
||||||
|
]
|
||||||
tasks = [
|
tasks = [
|
||||||
Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),
|
Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction, supported_languages=supported_languages),
|
||||||
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
|
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
|
||||||
Task(add_data_points, task_config={"batch_size": 30}),
|
Task(add_data_points, task_config={"batch_size": 30}),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ class ClassDefinition(DataPoint):
|
||||||
class CodeFile(DataPoint):
|
class CodeFile(DataPoint):
|
||||||
name: str
|
name: str
|
||||||
file_path: str
|
file_path: str
|
||||||
|
language: Optional[str] = None # e.g., 'python', 'javascript', 'java', etc.
|
||||||
source_code: Optional[str] = None
|
source_code: Optional[str] = None
|
||||||
part_of: Optional[Repository] = None
|
part_of: Optional[Repository] = None
|
||||||
depends_on: Optional[List["ImportStatement"]] = []
|
depends_on: Optional[List["ImportStatement"]] = []
|
||||||
|
|
|
||||||
|
|
@ -180,6 +180,7 @@ async def get_local_script_dependencies(
|
||||||
name=file_path_relative_to_repo,
|
name=file_path_relative_to_repo,
|
||||||
source_code=source_code,
|
source_code=source_code,
|
||||||
file_path=script_path,
|
file_path=script_path,
|
||||||
|
language="python",
|
||||||
)
|
)
|
||||||
return code_file_node
|
return code_file_node
|
||||||
|
|
||||||
|
|
@ -188,6 +189,7 @@ async def get_local_script_dependencies(
|
||||||
name=file_path_relative_to_repo,
|
name=file_path_relative_to_repo,
|
||||||
source_code=None,
|
source_code=None,
|
||||||
file_path=script_path,
|
file_path=script_path,
|
||||||
|
language="python",
|
||||||
)
|
)
|
||||||
|
|
||||||
async for part in extract_code_parts(source_code_tree.root_node, script_path=script_path):
|
async for part in extract_code_parts(source_code_tree.root_node, script_path=script_path):
|
||||||
|
|
|
||||||
|
|
@ -12,46 +12,58 @@ from cognee.shared.CodeGraphEntities import CodeFile, Repository
|
||||||
|
|
||||||
async def get_source_code_files(repo_path):
|
async def get_source_code_files(repo_path):
|
||||||
"""
|
"""
|
||||||
Retrieve Python source code files from the specified repository path.
|
Retrieve source code files from the specified repository path for multiple languages.
|
||||||
|
|
||||||
This function scans the given repository path for files that have the .py extension
|
|
||||||
while excluding test files and files within a virtual environment. It returns a list of
|
|
||||||
absolute paths to the source code files that are not empty.
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
- repo_path: The file path to the repository to search for source files.
|
||||||
- repo_path: The file path to the repository to search for Python source files.
|
- language_config: dict mapping language names to file extensions, e.g.,
|
||||||
|
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
A list of (absolute_path, language) tuples for source code files.
|
||||||
A list of absolute paths to .py files that contain source code, excluding empty
|
|
||||||
files, test files, and files from a virtual environment.
|
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(repo_path):
|
def _get_language_from_extension(file, language_config):
|
||||||
return {}
|
for lang, exts in language_config.items():
|
||||||
|
for ext in exts:
|
||||||
|
if file.endswith(ext):
|
||||||
|
return lang
|
||||||
|
return None
|
||||||
|
|
||||||
py_files_paths = (
|
# Default config if not provided
|
||||||
os.path.join(root, file)
|
import inspect
|
||||||
for root, _, files in os.walk(repo_path)
|
frame = inspect.currentframe()
|
||||||
for file in files
|
args, _, _, values = inspect.getargvalues(frame)
|
||||||
if (
|
language_config = values.get('language_config', None)
|
||||||
file.endswith(".py")
|
if language_config is None:
|
||||||
and not file.startswith("test_")
|
language_config = {
|
||||||
and not file.endswith("_test")
|
'python': ['.py'],
|
||||||
and ".venv" not in file
|
'javascript': ['.js', '.jsx'],
|
||||||
)
|
'typescript': ['.ts', '.tsx'],
|
||||||
)
|
'java': ['.java'],
|
||||||
|
'csharp': ['.cs'],
|
||||||
|
'go': ['.go'],
|
||||||
|
'rust': ['.rs'],
|
||||||
|
'cpp': ['.cpp', '.c', '.h', '.hpp'],
|
||||||
|
}
|
||||||
|
|
||||||
|
if not os.path.exists(repo_path):
|
||||||
|
return []
|
||||||
|
|
||||||
source_code_files = set()
|
source_code_files = set()
|
||||||
for file_path in py_files_paths:
|
for root, _, files in os.walk(repo_path):
|
||||||
file_path = os.path.abspath(file_path)
|
for file in files:
|
||||||
|
lang = _get_language_from_extension(file, language_config)
|
||||||
if os.path.getsize(file_path) == 0:
|
if lang is None:
|
||||||
continue
|
continue
|
||||||
|
# Exclude test files and venv for all languages
|
||||||
source_code_files.add(file_path)
|
if file.startswith("test_") or file.endswith("_test") or ".venv" in file:
|
||||||
|
continue
|
||||||
|
file_path = os.path.abspath(os.path.join(root, file))
|
||||||
|
if os.path.getsize(file_path) == 0:
|
||||||
|
continue
|
||||||
|
source_code_files.add((file_path, lang))
|
||||||
|
|
||||||
return list(source_code_files)
|
return list(source_code_files)
|
||||||
|
|
||||||
|
|
@ -85,7 +97,7 @@ def run_coroutine(coroutine_func, *args, **kwargs):
|
||||||
|
|
||||||
|
|
||||||
async def get_repo_file_dependencies(
|
async def get_repo_file_dependencies(
|
||||||
repo_path: str, detailed_extraction: bool = False
|
repo_path: str, detailed_extraction: bool = False, supported_languages: list = None
|
||||||
) -> AsyncGenerator[DataPoint, None]:
|
) -> AsyncGenerator[DataPoint, None]:
|
||||||
"""
|
"""
|
||||||
Generate a dependency graph for Python files in the given repository path.
|
Generate a dependency graph for Python files in the given repository path.
|
||||||
|
|
@ -106,7 +118,23 @@ async def get_repo_file_dependencies(
|
||||||
if not os.path.exists(repo_path):
|
if not os.path.exists(repo_path):
|
||||||
raise FileNotFoundError(f"Repository path {repo_path} does not exist.")
|
raise FileNotFoundError(f"Repository path {repo_path} does not exist.")
|
||||||
|
|
||||||
source_code_files = await get_source_code_files(repo_path)
|
# Build language config from supported_languages
|
||||||
|
default_language_config = {
|
||||||
|
'python': ['.py'],
|
||||||
|
'javascript': ['.js', '.jsx'],
|
||||||
|
'typescript': ['.ts', '.tsx'],
|
||||||
|
'java': ['.java'],
|
||||||
|
'csharp': ['.cs'],
|
||||||
|
'go': ['.go'],
|
||||||
|
'rust': ['.rs'],
|
||||||
|
'cpp': ['.cpp', '.c', '.h', '.hpp'],
|
||||||
|
}
|
||||||
|
if supported_languages is not None:
|
||||||
|
language_config = {k: v for k, v in default_language_config.items() if k in supported_languages}
|
||||||
|
else:
|
||||||
|
language_config = default_language_config
|
||||||
|
|
||||||
|
source_code_files = await get_source_code_files(repo_path, language_config=language_config)
|
||||||
|
|
||||||
repo = Repository(
|
repo = Repository(
|
||||||
id=uuid5(NAMESPACE_OID, repo_path),
|
id=uuid5(NAMESPACE_OID, repo_path),
|
||||||
|
|
@ -125,19 +153,38 @@ async def get_repo_file_dependencies(
|
||||||
for chunk_number in range(number_of_chunks)
|
for chunk_number in range(number_of_chunks)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Codegraph dependencies are not installed by default, so we import where we use them.
|
# Import dependency extractors for each language (Python for now, extend later)
|
||||||
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
|
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
|
||||||
|
# TODO: Add other language extractors here
|
||||||
|
|
||||||
for start_range, end_range in chunk_ranges:
|
for start_range, end_range in chunk_ranges:
|
||||||
# with ProcessPoolExecutor(max_workers=12) as executor:
|
tasks = []
|
||||||
tasks = [
|
for file_path, lang in source_code_files[start_range : end_range + 1]:
|
||||||
get_local_script_dependencies(repo_path, file_path, detailed_extraction)
|
# For now, only Python is supported; extend with other languages
|
||||||
for file_path in source_code_files[start_range : end_range + 1]
|
if lang == 'python':
|
||||||
]
|
tasks.append(get_local_script_dependencies(repo_path, file_path, detailed_extraction))
|
||||||
|
else:
|
||||||
|
# Placeholder: create a minimal CodeFile for other languages
|
||||||
|
from cognee.shared.CodeGraphEntities import CodeFile
|
||||||
|
import aiofiles
|
||||||
|
async def make_codefile_stub():
|
||||||
|
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
source = await f.read()
|
||||||
|
return CodeFile(
|
||||||
|
id=uuid5(NAMESPACE_OID, file_path),
|
||||||
|
name=os.path.relpath(file_path, repo_path),
|
||||||
|
file_path=file_path,
|
||||||
|
language=lang,
|
||||||
|
source_code=source,
|
||||||
|
)
|
||||||
|
tasks.append(make_codefile_stub())
|
||||||
|
|
||||||
results: list[CodeFile] = await asyncio.gather(*tasks)
|
results: list[CodeFile] = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
for source_code_file in results:
|
for source_code_file in results:
|
||||||
source_code_file.part_of = repo
|
source_code_file.part_of = repo
|
||||||
|
if not hasattr(source_code_file, 'language') or source_code_file.language is None:
|
||||||
|
# Set language for python files if not set
|
||||||
|
if source_code_file.file_path.endswith('.py'):
|
||||||
|
source_code_file.language = 'python'
|
||||||
yield source_code_file
|
yield source_code_file
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue