feat: Configurable path exclusion code graph (#1218)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Vasilije 2025-08-29 17:32:40 +02:00 committed by GitHub
commit 00646f108c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 68 additions and 30 deletions

View file

@ -1,6 +1,7 @@
import os
import pathlib
import asyncio
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.modules.observability.get_observe import get_observe
@ -28,7 +29,12 @@ logger = get_logger("code_graph_pipeline")
@observe
async def run_code_graph_pipeline(repo_path, include_docs=False):
async def run_code_graph_pipeline(
repo_path,
include_docs=False,
excluded_paths: Optional[list[str]] = None,
supported_languages: Optional[list[str]] = None,
):
import cognee
from cognee.low_level import setup
@ -40,13 +46,12 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
user = await get_default_user()
detailed_extraction = True
# Multi-language support: allow passing supported_languages
supported_languages = None # defer to task defaults
tasks = [
Task(
get_repo_file_dependencies,
detailed_extraction=detailed_extraction,
supported_languages=supported_languages,
excluded_paths=excluded_paths,
),
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
Task(add_data_points, task_config={"batch_size": 30}),
@ -95,7 +100,7 @@ if __name__ == "__main__":
async def main():
async for run_status in run_code_graph_pipeline("REPO_PATH"):
print(f"{run_status.pipeline_name}: {run_status.status}")
print(f"{run_status.pipeline_run_id}: {run_status.status}")
file_path = os.path.join(
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"

View file

@ -94,7 +94,15 @@ class CodeRetriever(BaseRetriever):
{"id": res.id, "score": res.score, "payload": res.payload}
)
existing_collection = []
for collection in self.classes_and_functions_collections:
if await vector_engine.has_collection(collection):
existing_collection.append(collection)
if not existing_collection:
raise RuntimeError("No collection found for code retriever")
for collection in existing_collection:
logger.debug(f"Searching {collection} collection with general query")
search_results_code = await vector_engine.search(
collection, query, limit=self.top_k

View file

@ -1,24 +1,48 @@
import asyncio
import math
import os
# from concurrent.futures import ProcessPoolExecutor
from typing import AsyncGenerator
from pathlib import Path
from typing import Set
from typing import AsyncGenerator, Optional, List
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, Repository
# constant, declared only once
EXCLUDED_DIRS: Set[str] = {
".venv",
"venv",
"env",
".env",
"site-packages",
"node_modules",
"dist",
"build",
".git",
"tests",
"test",
}
async def get_source_code_files(repo_path, language_config: dict[str, list[str]] | None = None):
async def get_source_code_files(
repo_path,
language_config: dict[str, list[str]] | None = None,
excluded_paths: Optional[List[str]] = None,
):
"""
Retrieve source code files from the specified repository path for multiple languages.
Retrieve Python source code files from the specified repository path.
This function scans the given repository path for files that have the .py extension
while excluding test files and files within a virtual environment. It returns a list of
absolute paths to the source code files that are not empty.
Parameters:
-----------
- repo_path: The file path to the repository to search for source files.
- language_config: dict mapping language names to file extensions, e.g.,
- repo_path: Root path of the repository to search
- language_config: dict mapping language names to file extensions, e.g.,
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
- excluded_paths: Optional list of path fragments or glob patterns to exclude
Returns:
--------
@ -54,28 +78,23 @@ async def get_source_code_files(repo_path, language_config: dict[str, list[str]]
lang = _get_language_from_extension(file, language_config)
if lang is None:
continue
# Exclude tests and common build/venv directories
excluded_dirs = {
".venv",
"venv",
"env",
".env",
"site-packages",
"node_modules",
"dist",
"build",
".git",
"tests",
"test",
}
root_parts = set(os.path.normpath(root).split(os.sep))
# Exclude tests, common build/venv directories and files provided in exclude_paths
excluded_dirs = EXCLUDED_DIRS
excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])} # full paths
root_path = Path(root).resolve()
root_parts = set(root_path.parts) # same as before
base_name, _ext = os.path.splitext(file)
if (
base_name.startswith("test_")
or base_name.endswith("_test") # catches Go's *_test.go and similar
or base_name.endswith("_test")
or ".test." in file
or ".spec." in file
or (excluded_dirs & root_parts)
or (excluded_dirs & root_parts) # name match
or any(
root_path.is_relative_to(p) # full-path match
for p in excluded_paths
)
):
continue
file_path = os.path.abspath(os.path.join(root, file))
@ -115,7 +134,10 @@ def run_coroutine(coroutine_func, *args, **kwargs):
async def get_repo_file_dependencies(
repo_path: str, detailed_extraction: bool = False, supported_languages: list = None
repo_path: str,
detailed_extraction: bool = False,
supported_languages: list = None,
excluded_paths: Optional[List[str]] = None,
) -> AsyncGenerator[DataPoint, None]:
"""
Generate a dependency graph for source files (multi-language) in the given repository path.
@ -150,6 +172,7 @@ async def get_repo_file_dependencies(
"go": [".go"],
"rust": [".rs"],
"cpp": [".cpp", ".c", ".h", ".hpp"],
"c": [".c", ".h"],
}
if supported_languages is not None:
language_config = {
@ -158,7 +181,9 @@ async def get_repo_file_dependencies(
else:
language_config = default_language_config
source_code_files = await get_source_code_files(repo_path, language_config=language_config)
source_code_files = await get_source_code_files(
repo_path, language_config=language_config, excluded_paths=excluded_paths
)
repo = Repository(
id=uuid5(NAMESPACE_OID, repo_path),