feat: Configurable path exclusion code graph (#1218)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Vasilije 2025-08-29 17:32:40 +02:00 committed by GitHub
commit 00646f108c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 68 additions and 30 deletions

View file

@ -1,6 +1,7 @@
import os import os
import pathlib import pathlib
import asyncio import asyncio
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.modules.observability.get_observe import get_observe from cognee.modules.observability.get_observe import get_observe
@ -28,7 +29,12 @@ logger = get_logger("code_graph_pipeline")
@observe @observe
async def run_code_graph_pipeline(repo_path, include_docs=False): async def run_code_graph_pipeline(
repo_path,
include_docs=False,
excluded_paths: Optional[list[str]] = None,
supported_languages: Optional[list[str]] = None,
):
import cognee import cognee
from cognee.low_level import setup from cognee.low_level import setup
@ -40,13 +46,12 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
user = await get_default_user() user = await get_default_user()
detailed_extraction = True detailed_extraction = True
# Multi-language support: allow passing supported_languages
supported_languages = None # defer to task defaults
tasks = [ tasks = [
Task( Task(
get_repo_file_dependencies, get_repo_file_dependencies,
detailed_extraction=detailed_extraction, detailed_extraction=detailed_extraction,
supported_languages=supported_languages, supported_languages=supported_languages,
excluded_paths=excluded_paths,
), ),
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete # Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
Task(add_data_points, task_config={"batch_size": 30}), Task(add_data_points, task_config={"batch_size": 30}),
@ -95,7 +100,7 @@ if __name__ == "__main__":
async def main(): async def main():
async for run_status in run_code_graph_pipeline("REPO_PATH"): async for run_status in run_code_graph_pipeline("REPO_PATH"):
print(f"{run_status.pipeline_name}: {run_status.status}") print(f"{run_status.pipeline_run_id}: {run_status.status}")
file_path = os.path.join( file_path = os.path.join(
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html" pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"

View file

@ -94,7 +94,15 @@ class CodeRetriever(BaseRetriever):
{"id": res.id, "score": res.score, "payload": res.payload} {"id": res.id, "score": res.score, "payload": res.payload}
) )
existing_collection = []
for collection in self.classes_and_functions_collections: for collection in self.classes_and_functions_collections:
if await vector_engine.has_collection(collection):
existing_collection.append(collection)
if not existing_collection:
raise RuntimeError("No collection found for code retriever")
for collection in existing_collection:
logger.debug(f"Searching {collection} collection with general query") logger.debug(f"Searching {collection} collection with general query")
search_results_code = await vector_engine.search( search_results_code = await vector_engine.search(
collection, query, limit=self.top_k collection, query, limit=self.top_k

View file

@ -1,24 +1,48 @@
import asyncio import asyncio
import math import math
import os import os
from pathlib import Path
# from concurrent.futures import ProcessPoolExecutor from typing import Set
from typing import AsyncGenerator from typing import AsyncGenerator, Optional, List
from uuid import NAMESPACE_OID, uuid5 from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, Repository from cognee.shared.CodeGraphEntities import CodeFile, Repository
# constant, declared only once
EXCLUDED_DIRS: Set[str] = {
".venv",
"venv",
"env",
".env",
"site-packages",
"node_modules",
"dist",
"build",
".git",
"tests",
"test",
}
async def get_source_code_files(repo_path, language_config: dict[str, list[str]] | None = None):
async def get_source_code_files(
repo_path,
language_config: dict[str, list[str]] | None = None,
excluded_paths: Optional[List[str]] = None,
):
""" """
Retrieve source code files from the specified repository path for multiple languages. Retrieve Python source code files from the specified repository path.
This function scans the given repository path for files that have the .py extension
while excluding test files and files within a virtual environment. It returns a list of
absolute paths to the source code files that are not empty.
Parameters: Parameters:
----------- -----------
- repo_path: The file path to the repository to search for source files. - repo_path: Root path of the repository to search
- language_config: dict mapping language names to file extensions, e.g., - language_config: dict mapping language names to file extensions, e.g.,
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...} {'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
- excluded_paths: Optional list of path fragments or glob patterns to exclude
Returns: Returns:
-------- --------
@ -54,28 +78,23 @@ async def get_source_code_files(repo_path, language_config: dict[str, list[str]]
lang = _get_language_from_extension(file, language_config) lang = _get_language_from_extension(file, language_config)
if lang is None: if lang is None:
continue continue
# Exclude tests and common build/venv directories # Exclude tests, common build/venv directories and files provided in exclude_paths
excluded_dirs = { excluded_dirs = EXCLUDED_DIRS
".venv", excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])} # full paths
"venv",
"env", root_path = Path(root).resolve()
".env", root_parts = set(root_path.parts) # same as before
"site-packages",
"node_modules",
"dist",
"build",
".git",
"tests",
"test",
}
root_parts = set(os.path.normpath(root).split(os.sep))
base_name, _ext = os.path.splitext(file) base_name, _ext = os.path.splitext(file)
if ( if (
base_name.startswith("test_") base_name.startswith("test_")
or base_name.endswith("_test") # catches Go's *_test.go and similar or base_name.endswith("_test")
or ".test." in file or ".test." in file
or ".spec." in file or ".spec." in file
or (excluded_dirs & root_parts) or (excluded_dirs & root_parts) # name match
or any(
root_path.is_relative_to(p) # full-path match
for p in excluded_paths
)
): ):
continue continue
file_path = os.path.abspath(os.path.join(root, file)) file_path = os.path.abspath(os.path.join(root, file))
@ -115,7 +134,10 @@ def run_coroutine(coroutine_func, *args, **kwargs):
async def get_repo_file_dependencies( async def get_repo_file_dependencies(
repo_path: str, detailed_extraction: bool = False, supported_languages: list = None repo_path: str,
detailed_extraction: bool = False,
supported_languages: list = None,
excluded_paths: Optional[List[str]] = None,
) -> AsyncGenerator[DataPoint, None]: ) -> AsyncGenerator[DataPoint, None]:
""" """
Generate a dependency graph for source files (multi-language) in the given repository path. Generate a dependency graph for source files (multi-language) in the given repository path.
@ -150,6 +172,7 @@ async def get_repo_file_dependencies(
"go": [".go"], "go": [".go"],
"rust": [".rs"], "rust": [".rs"],
"cpp": [".cpp", ".c", ".h", ".hpp"], "cpp": [".cpp", ".c", ".h", ".hpp"],
"c": [".c", ".h"],
} }
if supported_languages is not None: if supported_languages is not None:
language_config = { language_config = {
@ -158,7 +181,9 @@ async def get_repo_file_dependencies(
else: else:
language_config = default_language_config language_config = default_language_config
source_code_files = await get_source_code_files(repo_path, language_config=language_config) source_code_files = await get_source_code_files(
repo_path, language_config=language_config, excluded_paths=excluded_paths
)
repo = Repository( repo = Repository(
id=uuid5(NAMESPACE_OID, repo_path), id=uuid5(NAMESPACE_OID, repo_path),