diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index 19d194b26..fb3612857 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -1,6 +1,7 @@ import os import pathlib import asyncio +from typing import Optional from cognee.shared.logging_utils import get_logger, setup_logging from cognee.modules.observability.get_observe import get_observe @@ -28,7 +29,12 @@ logger = get_logger("code_graph_pipeline") @observe -async def run_code_graph_pipeline(repo_path, include_docs=False): +async def run_code_graph_pipeline( + repo_path, + include_docs=False, + excluded_paths: Optional[list[str]] = None, + supported_languages: Optional[list[str]] = None, +): import cognee from cognee.low_level import setup @@ -40,13 +46,12 @@ async def run_code_graph_pipeline(repo_path, include_docs=False): user = await get_default_user() detailed_extraction = True - # Multi-language support: allow passing supported_languages - supported_languages = None # defer to task defaults tasks = [ Task( get_repo_file_dependencies, detailed_extraction=detailed_extraction, supported_languages=supported_languages, + excluded_paths=excluded_paths, ), # Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete Task(add_data_points, task_config={"batch_size": 30}), @@ -95,7 +100,7 @@ if __name__ == "__main__": async def main(): async for run_status in run_code_graph_pipeline("REPO_PATH"): - print(f"{run_status.pipeline_name}: {run_status.status}") + print(f"{run_status.pipeline_run_id}: {run_status.status}") file_path = os.path.join( pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html" diff --git a/cognee/modules/retrieval/code_retriever.py b/cognee/modules/retrieval/code_retriever.py index 6e819d8a7..76b5e758c 100644 --- a/cognee/modules/retrieval/code_retriever.py +++ b/cognee/modules/retrieval/code_retriever.py @@ -94,7 +94,15 @@ class CodeRetriever(BaseRetriever): {"id": res.id, "score": res.score, "payload": res.payload} ) + existing_collection = [] for collection in self.classes_and_functions_collections: + if await vector_engine.has_collection(collection): + existing_collection.append(collection) + + if not existing_collection: + raise RuntimeError("No collection found for code retriever") + + for collection in existing_collection: logger.debug(f"Searching {collection} collection with general query") search_results_code = await vector_engine.search( collection, query, limit=self.top_k diff --git a/cognee/tasks/repo_processor/get_repo_file_dependencies.py b/cognee/tasks/repo_processor/get_repo_file_dependencies.py index 4ff79523f..06cc3bddb 100644 --- a/cognee/tasks/repo_processor/get_repo_file_dependencies.py +++ b/cognee/tasks/repo_processor/get_repo_file_dependencies.py @@ -1,24 +1,48 @@ import asyncio import math import os - -# from concurrent.futures import ProcessPoolExecutor -from typing import AsyncGenerator +from pathlib import Path +from typing import Set +from typing import AsyncGenerator, Optional, List from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.shared.CodeGraphEntities import CodeFile, Repository +# constant, declared only once +EXCLUDED_DIRS: Set[str] = { + ".venv", + "venv", + "env", + ".env", + "site-packages", + "node_modules", + "dist", + "build", + ".git", + "tests", + "test", +} -async def get_source_code_files(repo_path, language_config: dict[str, list[str]] | None = None): + +async def get_source_code_files( + repo_path, + language_config: dict[str, list[str]] | None = None, + excluded_paths: Optional[List[str]] = None, +): """ - Retrieve source code files from the specified repository path for multiple languages. + Retrieve Python source code files from the specified repository path. + + This function scans the given repository path for files that have the .py extension + while excluding test files and files within a virtual environment. It returns a list of + absolute paths to the source code files that are not empty. Parameters: ----------- - - repo_path: The file path to the repository to search for source files. - - language_config: dict mapping language names to file extensions, e.g., + - repo_path: Root path of the repository to search + - language_config: dict mapping language names to file extensions, e.g., {'python': ['.py'], 'javascript': ['.js', '.jsx'], ...} + - excluded_paths: Optional list of path fragments or glob patterns to exclude Returns: -------- @@ -54,28 +78,23 @@ async def get_source_code_files(repo_path, language_config: dict[str, list[str]] lang = _get_language_from_extension(file, language_config) if lang is None: continue - # Exclude tests and common build/venv directories - excluded_dirs = { - ".venv", - "venv", - "env", - ".env", - "site-packages", - "node_modules", - "dist", - "build", - ".git", - "tests", - "test", - } - root_parts = set(os.path.normpath(root).split(os.sep)) + # Exclude tests, common build/venv directories and files provided in exclude_paths + excluded_dirs = EXCLUDED_DIRS + excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])} # full paths + + root_path = Path(root).resolve() + root_parts = set(root_path.parts) # same as before base_name, _ext = os.path.splitext(file) if ( base_name.startswith("test_") - or base_name.endswith("_test") # catches Go's *_test.go and similar + or base_name.endswith("_test") or ".test." in file or ".spec." in file - or (excluded_dirs & root_parts) + or (excluded_dirs & root_parts) # name match + or any( + root_path.is_relative_to(p) # full-path match + for p in excluded_paths + ) ): continue file_path = os.path.abspath(os.path.join(root, file)) @@ -115,7 +134,10 @@ def run_coroutine(coroutine_func, *args, **kwargs): async def get_repo_file_dependencies( - repo_path: str, detailed_extraction: bool = False, supported_languages: list = None + repo_path: str, + detailed_extraction: bool = False, + supported_languages: list = None, + excluded_paths: Optional[List[str]] = None, ) -> AsyncGenerator[DataPoint, None]: """ Generate a dependency graph for source files (multi-language) in the given repository path. @@ -150,6 +172,7 @@ async def get_repo_file_dependencies( "go": [".go"], "rust": [".rs"], "cpp": [".cpp", ".c", ".h", ".hpp"], + "c": [".c", ".h"], } if supported_languages is not None: language_config = { @@ -158,7 +181,9 @@ async def get_repo_file_dependencies( else: language_config = default_language_config - source_code_files = await get_source_code_files(repo_path, language_config=language_config) + source_code_files = await get_source_code_files( + repo_path, language_config=language_config, excluded_paths=excluded_paths + ) repo = Repository( id=uuid5(NAMESPACE_OID, repo_path),