feat: Configurable path exclusion for code graph (#1218)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
commit
00646f108c
3 changed files with 68 additions and 30 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
from cognee.shared.logging_utils import get_logger, setup_logging
|
from cognee.shared.logging_utils import get_logger, setup_logging
|
||||||
from cognee.modules.observability.get_observe import get_observe
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
|
|
||||||
|
|
@ -28,7 +29,12 @@ logger = get_logger("code_graph_pipeline")
|
||||||
|
|
||||||
|
|
||||||
@observe
|
@observe
|
||||||
async def run_code_graph_pipeline(repo_path, include_docs=False):
|
async def run_code_graph_pipeline(
|
||||||
|
repo_path,
|
||||||
|
include_docs=False,
|
||||||
|
excluded_paths: Optional[list[str]] = None,
|
||||||
|
supported_languages: Optional[list[str]] = None,
|
||||||
|
):
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.low_level import setup
|
from cognee.low_level import setup
|
||||||
|
|
||||||
|
|
@ -40,13 +46,12 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
detailed_extraction = True
|
detailed_extraction = True
|
||||||
|
|
||||||
# Multi-language support: allow passing supported_languages
|
|
||||||
supported_languages = None # defer to task defaults
|
|
||||||
tasks = [
|
tasks = [
|
||||||
Task(
|
Task(
|
||||||
get_repo_file_dependencies,
|
get_repo_file_dependencies,
|
||||||
detailed_extraction=detailed_extraction,
|
detailed_extraction=detailed_extraction,
|
||||||
supported_languages=supported_languages,
|
supported_languages=supported_languages,
|
||||||
|
excluded_paths=excluded_paths,
|
||||||
),
|
),
|
||||||
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
|
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
|
||||||
Task(add_data_points, task_config={"batch_size": 30}),
|
Task(add_data_points, task_config={"batch_size": 30}),
|
||||||
|
|
@ -95,7 +100,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
async for run_status in run_code_graph_pipeline("REPO_PATH"):
|
async for run_status in run_code_graph_pipeline("REPO_PATH"):
|
||||||
print(f"{run_status.pipeline_name}: {run_status.status}")
|
print(f"{run_status.pipeline_run_id}: {run_status.status}")
|
||||||
|
|
||||||
file_path = os.path.join(
|
file_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
|
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
|
||||||
|
|
|
||||||
|
|
@ -94,7 +94,15 @@ class CodeRetriever(BaseRetriever):
|
||||||
{"id": res.id, "score": res.score, "payload": res.payload}
|
{"id": res.id, "score": res.score, "payload": res.payload}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
existing_collection = []
|
||||||
for collection in self.classes_and_functions_collections:
|
for collection in self.classes_and_functions_collections:
|
||||||
|
if await vector_engine.has_collection(collection):
|
||||||
|
existing_collection.append(collection)
|
||||||
|
|
||||||
|
if not existing_collection:
|
||||||
|
raise RuntimeError("No collection found for code retriever")
|
||||||
|
|
||||||
|
for collection in existing_collection:
|
||||||
logger.debug(f"Searching {collection} collection with general query")
|
logger.debug(f"Searching {collection} collection with general query")
|
||||||
search_results_code = await vector_engine.search(
|
search_results_code = await vector_engine.search(
|
||||||
collection, query, limit=self.top_k
|
collection, query, limit=self.top_k
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,48 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
# from concurrent.futures import ProcessPoolExecutor
|
from typing import Set
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator, Optional, List
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.shared.CodeGraphEntities import CodeFile, Repository
|
from cognee.shared.CodeGraphEntities import CodeFile, Repository
|
||||||
|
|
||||||
|
# constant, declared only once
|
||||||
|
EXCLUDED_DIRS: Set[str] = {
|
||||||
|
".venv",
|
||||||
|
"venv",
|
||||||
|
"env",
|
||||||
|
".env",
|
||||||
|
"site-packages",
|
||||||
|
"node_modules",
|
||||||
|
"dist",
|
||||||
|
"build",
|
||||||
|
".git",
|
||||||
|
"tests",
|
||||||
|
"test",
|
||||||
|
}
|
||||||
|
|
||||||
async def get_source_code_files(repo_path, language_config: dict[str, list[str]] | None = None):
|
|
||||||
|
async def get_source_code_files(
|
||||||
|
repo_path,
|
||||||
|
language_config: dict[str, list[str]] | None = None,
|
||||||
|
excluded_paths: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Retrieve source code files from the specified repository path for multiple languages.
|
Retrieve source code files from the specified repository path for multiple languages.
|
||||||
|
|
||||||
|
This function scans the given repository path for files whose extension maps to a language
|
||||||
|
in language_config, while excluding test files, common build and virtual-environment directories, and anything under excluded_paths. It returns a list of
|
||||||
|
absolute paths to the source code files that are not empty.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
- repo_path: The file path to the repository to search for source files.
|
- repo_path: Root path of the repository to search
|
||||||
- language_config: dict mapping language names to file extensions, e.g.,
|
- language_config: dict mapping language names to file extensions, e.g.,
|
||||||
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
|
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
|
||||||
|
- excluded_paths: Optional list of directory paths to exclude; each is resolved to an absolute path and files located under it are skipped
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
@ -54,28 +78,23 @@ async def get_source_code_files(repo_path, language_config: dict[str, list[str]]
|
||||||
lang = _get_language_from_extension(file, language_config)
|
lang = _get_language_from_extension(file, language_config)
|
||||||
if lang is None:
|
if lang is None:
|
||||||
continue
|
continue
|
||||||
# Exclude tests and common build/venv directories
|
# Exclude tests, common build/venv directories and files provided in excluded_paths
|
||||||
excluded_dirs = {
|
excluded_dirs = EXCLUDED_DIRS
|
||||||
".venv",
|
excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])} # full paths
|
||||||
"venv",
|
|
||||||
"env",
|
root_path = Path(root).resolve()
|
||||||
".env",
|
root_parts = set(root_path.parts) # same as before
|
||||||
"site-packages",
|
|
||||||
"node_modules",
|
|
||||||
"dist",
|
|
||||||
"build",
|
|
||||||
".git",
|
|
||||||
"tests",
|
|
||||||
"test",
|
|
||||||
}
|
|
||||||
root_parts = set(os.path.normpath(root).split(os.sep))
|
|
||||||
base_name, _ext = os.path.splitext(file)
|
base_name, _ext = os.path.splitext(file)
|
||||||
if (
|
if (
|
||||||
base_name.startswith("test_")
|
base_name.startswith("test_")
|
||||||
or base_name.endswith("_test") # catches Go's *_test.go and similar
|
or base_name.endswith("_test")
|
||||||
or ".test." in file
|
or ".test." in file
|
||||||
or ".spec." in file
|
or ".spec." in file
|
||||||
or (excluded_dirs & root_parts)
|
or (excluded_dirs & root_parts) # name match
|
||||||
|
or any(
|
||||||
|
root_path.is_relative_to(p) # full-path match
|
||||||
|
for p in excluded_paths
|
||||||
|
)
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
file_path = os.path.abspath(os.path.join(root, file))
|
file_path = os.path.abspath(os.path.join(root, file))
|
||||||
|
|
@ -115,7 +134,10 @@ def run_coroutine(coroutine_func, *args, **kwargs):
|
||||||
|
|
||||||
|
|
||||||
async def get_repo_file_dependencies(
|
async def get_repo_file_dependencies(
|
||||||
repo_path: str, detailed_extraction: bool = False, supported_languages: list = None
|
repo_path: str,
|
||||||
|
detailed_extraction: bool = False,
|
||||||
|
supported_languages: list = None,
|
||||||
|
excluded_paths: Optional[List[str]] = None,
|
||||||
) -> AsyncGenerator[DataPoint, None]:
|
) -> AsyncGenerator[DataPoint, None]:
|
||||||
"""
|
"""
|
||||||
Generate a dependency graph for source files (multi-language) in the given repository path.
|
Generate a dependency graph for source files (multi-language) in the given repository path.
|
||||||
|
|
@ -150,6 +172,7 @@ async def get_repo_file_dependencies(
|
||||||
"go": [".go"],
|
"go": [".go"],
|
||||||
"rust": [".rs"],
|
"rust": [".rs"],
|
||||||
"cpp": [".cpp", ".c", ".h", ".hpp"],
|
"cpp": [".cpp", ".c", ".h", ".hpp"],
|
||||||
|
"c": [".c", ".h"],
|
||||||
}
|
}
|
||||||
if supported_languages is not None:
|
if supported_languages is not None:
|
||||||
language_config = {
|
language_config = {
|
||||||
|
|
@ -158,7 +181,9 @@ async def get_repo_file_dependencies(
|
||||||
else:
|
else:
|
||||||
language_config = default_language_config
|
language_config = default_language_config
|
||||||
|
|
||||||
source_code_files = await get_source_code_files(repo_path, language_config=language_config)
|
source_code_files = await get_source_code_files(
|
||||||
|
repo_path, language_config=language_config, excluded_paths=excluded_paths
|
||||||
|
)
|
||||||
|
|
||||||
repo = Repository(
|
repo = Repository(
|
||||||
id=uuid5(NAMESPACE_OID, repo_path),
|
id=uuid5(NAMESPACE_OID, repo_path),
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue