refactor: move codify pipeline out of main repo (#1738)

<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->
This PR moves codify and the code graph pipeline out of the repository. It also
introduces a Custom Pipeline interface, which can be used to define custom
pipelines in the future.
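
The diff below adds a `use_retriever` hook and a `registered_community_retrievers` registry. As a rough, hypothetical sketch of how an external package might plug into it (the module path for `use_retriever` and the exact `BaseRetriever` contract are assumptions inferred from this diff, not a documented API):

```python
# Hypothetical usage sketch, not part of this PR's diff.
from typing import Any, Optional

from cognee import SearchType
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.use_retriever import use_retriever  # assumed module path


class CommunityCodeRetriever(BaseRetriever):
    """Minimal stand-in retriever; a real one would query the vector/graph engines."""

    def __init__(self, top_k: int = 3):
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        # Replace with real retrieval logic (e.g. the externalized code retriever).
        return []

    async def get_completion(
        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
    ) -> Any:
        if context is None:
            context = await self.get_context(query)
        return context


# Once registered, get_search_type_tools instantiates the class with top_k and
# exposes its get_completion/get_context, taking precedence over the built-in
# retrievers for that search type.
use_retriever(SearchType.GRAPH_COMPLETION, CommunityCodeRetriever)
```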

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Vasilije 2025-12-04 23:10:39 -08:00 committed by GitHub
commit 9571641199
20 changed files with 26 additions and 1617 deletions

View file

@@ -197,33 +197,3 @@ jobs:
- name: Run Simple Examples
run: uv run python ./examples/python/simple_example.py
graph-tests:
name: Run Basic Graph Tests
runs-on: ubuntu-22.04
env:
ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
steps:
- name: Check out repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: ${{ inputs.python-version }}
- name: Run Graph Tests
run: uv run python ./examples/python/code_graph_example.py --repo_path ./cognee/tasks/graph

View file

@@ -90,97 +90,6 @@ async def health_check(request):
return JSONResponse({"status": "ok"})
@mcp.tool()
async def cognee_add_developer_rules(
base_path: str = ".", graph_model_file: str = None, graph_model_name: str = None
) -> list:
"""
Ingest core developer rule files into Cognee's memory layer.
This function loads a predefined set of developer-related configuration,
rule, and documentation files from the base repository and assigns them
to the special 'developer_rules' node set in Cognee. It ensures these
foundational files are always part of the structured memory graph.
Parameters
----------
base_path : str
Root path to resolve relative file paths. Defaults to current directory.
graph_model_file : str, optional
Optional path to a custom schema file for knowledge graph generation.
graph_model_name : str, optional
Optional class name to use from the graph_model_file schema.
Returns
-------
list
A message indicating how many rule files were scheduled for ingestion,
and how to check their processing status.
Notes
-----
- Each file is processed asynchronously in the background.
- Files are attached to the 'developer_rules' node set.
- Missing files are skipped with a logged warning.
"""
developer_rule_paths = [
".cursorrules",
".cursor/rules",
".same/todos.md",
".windsurfrules",
".clinerules",
"CLAUDE.md",
".sourcegraph/memory.md",
"AGENT.md",
"AGENTS.md",
]
async def cognify_task(file_path: str) -> None:
with redirect_stdout(sys.stderr):
logger.info(f"Starting cognify for: {file_path}")
try:
await cognee_client.add(file_path, node_set=["developer_rules"])
model = None
if graph_model_file and graph_model_name:
if cognee_client.use_api:
logger.warning(
"Custom graph models are not supported in API mode, ignoring."
)
else:
from cognee.shared.data_models import KnowledgeGraph
model = load_class(graph_model_file, graph_model_name)
await cognee_client.cognify(graph_model=model)
logger.info(f"Cognify finished for: {file_path}")
except Exception as e:
logger.error(f"Cognify failed for {file_path}: {str(e)}")
raise ValueError(f"Failed to cognify: {str(e)}")
tasks = []
for rel_path in developer_rule_paths:
abs_path = os.path.join(base_path, rel_path)
if os.path.isfile(abs_path):
tasks.append(asyncio.create_task(cognify_task(abs_path)))
else:
logger.warning(f"Skipped missing developer rule file: {abs_path}")
log_file = get_log_file_location()
return [
types.TextContent(
type="text",
text=(
f"Started cognify for {len(tasks)} developer rule files in background.\n"
f"All are added to the `developer_rules` node set.\n"
f"Use `cognify_status` or check logs at {log_file} to monitor progress."
),
)
]
@mcp.tool()
async def cognify(
data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
@@ -406,75 +315,6 @@ async def save_interaction(data: str) -> list:
]
@mcp.tool()
async def codify(repo_path: str) -> list:
"""
Analyze and generate a code-specific knowledge graph from a software repository.
This function launches a background task that processes the provided repository
and builds a code knowledge graph. The function returns immediately while
the processing continues in the background due to MCP timeout constraints.
Parameters
----------
repo_path : str
Path to the code repository to analyze. This can be a local file path or a
relative path to a repository. The path should point to the root of the
repository or a specific directory within it.
Returns
-------
list
A list containing a single TextContent object with information about the
background task launch and how to check its status.
Notes
-----
- The function launches a background task and returns immediately
- The code graph generation may take significant time for larger repositories
- Use the codify_status tool to check the progress of the operation
- Process results are logged to the standard Cognee log file
- All stdout is redirected to stderr to maintain MCP communication integrity
"""
if cognee_client.use_api:
error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline."
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
async def codify_task(repo_path: str):
# NOTE: MCP uses stdout to communicate, so we must redirect all output
# going to stdout (like the print function) to stderr.
with redirect_stdout(sys.stderr):
logger.info("Codify process starting.")
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
results = []
async for result in run_code_graph_pipeline(repo_path, False):
results.append(result)
logger.info(result)
if all(results):
logger.info("Codify process finished succesfully.")
else:
logger.info("Codify process failed.")
asyncio.create_task(codify_task(repo_path))
log_file = get_log_file_location()
text = (
f"Background process launched due to MCP timeout limitations.\n"
f"To check current codify status use the codify_status tool\n"
f"or you can check the log file at: {log_file}"
)
return [
types.TextContent(
type="text",
text=text,
)
]
@mcp.tool()
async def search(search_query: str, search_type: str) -> list:
"""
@@ -629,45 +469,6 @@ async def search(search_query: str, search_type: str) -> list:
return [types.TextContent(type="text", text=search_results)]
@mcp.tool()
async def get_developer_rules() -> list:
"""
Retrieve all developer rules that were generated based on previous interactions.
This tool queries the Cognee knowledge graph and returns a list of developer
rules.
Parameters
----------
None
Returns
-------
list
A list containing a single TextContent object with the retrieved developer rules.
The format is plain text containing the developer rules as bullet points.
Notes
-----
- The specific logic for fetching rules is handled internally.
- This tool does not accept any parameters and is intended for simple rule inspection use cases.
"""
async def fetch_rules_from_cognee() -> str:
"""Collect all developer rules from Cognee"""
with redirect_stdout(sys.stderr):
if cognee_client.use_api:
logger.warning("Developer rules retrieval is not available in API mode")
return "Developer rules retrieval is not available in API mode"
developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules")
return developer_rules
rules_text = await fetch_rules_from_cognee()
return [types.TextContent(type="text", text=rules_text)]
@mcp.tool()
async def list_data(dataset_id: str = None) -> list:
"""
@@ -953,48 +754,6 @@ async def cognify_status():
return [types.TextContent(type="text", text=error_msg)]
@mcp.tool()
async def codify_status():
"""
Get the current status of the codify pipeline.
This function retrieves information about current and recently completed codify operations
in the codebase dataset. It provides details on progress, success/failure status, and statistics
about the processed code repositories.
Returns
-------
list
A list containing a single TextContent object with the status information as a string.
The status includes information about active and completed jobs for the cognify_code_pipeline.
Notes
-----
- The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset
- Status information includes job progress, execution time, and completion status
- The status is returned in string format for easy reading
- This operation is not available in API mode
"""
with redirect_stdout(sys.stderr):
try:
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
from cognee.modules.users.methods import get_default_user
user = await get_default_user()
status = await cognee_client.get_pipeline_status(
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
)
return [types.TextContent(type="text", text=str(status))]
except NotImplementedError:
error_msg = "❌ Pipeline status is not available in API mode"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
except Exception as e:
error_msg = f"❌ Failed to get codify status: {str(e)}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
def node_to_string(node):
node_data = ", ".join(
[f'{key}: "{value}"' for key, value in node.items() if key in ["id", "name"]]

View file

@@ -21,7 +21,7 @@ from cognee.api.v1.notebooks.routers import get_notebooks_router
from cognee.api.v1.permissions.routers import get_permissions_router
from cognee.api.v1.settings.routers import get_settings_router
from cognee.api.v1.datasets.routers import get_datasets_router
from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
from cognee.api.v1.cognify.routers import get_cognify_router
from cognee.api.v1.search.routers import get_search_router
from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
from cognee.api.v1.memify.routers import get_memify_router
@@ -278,10 +278,6 @@ app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["re
app.include_router(get_sync_router(), prefix="/api/v1/sync", tags=["sync"])
codegraph_routes = get_code_pipeline_router()
if codegraph_routes:
app.include_router(codegraph_routes, prefix="/api/v1/code-pipeline", tags=["code-pipeline"])
app.include_router(
get_users_router(),
prefix="/api/v1/users",

View file

@@ -1,119 +0,0 @@
import os
import pathlib
import asyncio
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.modules.observability.get_observe import get_observe
from cognee.api.v1.search import SearchType, search
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.data.methods import create_dataset
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.ingestion import ingest_data
from cognee.tasks.repo_processor import get_non_py_files, get_repo_file_dependencies
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
from cognee.infrastructure.llm import get_max_chunk_tokens
from cognee.infrastructure.databases.relational import get_relational_engine
observe = get_observe()
logger = get_logger("code_graph_pipeline")
@observe
async def run_code_graph_pipeline(
repo_path,
include_docs=False,
excluded_paths: Optional[list[str]] = None,
supported_languages: Optional[list[str]] = None,
):
import cognee
from cognee.low_level import setup
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
cognee_config = get_cognify_config()
user = await get_default_user()
detailed_extraction = True
tasks = [
Task(
get_repo_file_dependencies,
detailed_extraction=detailed_extraction,
supported_languages=supported_languages,
excluded_paths=excluded_paths,
),
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
Task(add_data_points, task_config={"batch_size": 30}),
]
if include_docs:
# These tasks take a long time to complete
non_code_tasks = [
Task(get_non_py_files, task_config={"batch_size": 50}),
Task(ingest_data, dataset_name="repo_docs", user=user),
Task(classify_documents),
Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
Task(
extract_graph_from_data,
graph_model=KnowledgeGraph,
task_config={"batch_size": 50},
),
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 50},
),
]
dataset_name = "codebase"
# Save dataset to database
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
dataset = await create_dataset(dataset_name, user, session)
if include_docs:
non_code_pipeline_run = run_tasks(
non_code_tasks, dataset.id, repo_path, user, "cognify_pipeline"
)
async for run_status in non_code_pipeline_run:
yield run_status
async for run_status in run_tasks(
tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False
):
yield run_status
if __name__ == "__main__":
async def main():
async for run_status in run_code_graph_pipeline("REPO_PATH"):
print(f"{run_status.pipeline_run_id}: {run_status.status}")
file_path = os.path.join(
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
)
await visualize_graph(file_path)
search_results = await search(
query_type=SearchType.CODE,
query_text="How is Relationship weight calculated?",
)
for file in search_results:
print(file["name"])
logger = setup_logging(name="code_graph_pipeline")
asyncio.run(main())

View file

@@ -1,2 +1 @@
from .get_cognify_router import get_cognify_router
from .get_code_pipeline_router import get_code_pipeline_router

View file

@@ -1,90 +0,0 @@
import json
from cognee.shared.logging_utils import get_logger
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from cognee.api.DTO import InDTO
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.storage.utils import JSONEncoder
logger = get_logger()
class CodePipelineIndexPayloadDTO(InDTO):
repo_path: str
include_docs: bool = False
class CodePipelineRetrievePayloadDTO(InDTO):
query: str
full_input: str
def get_code_pipeline_router() -> APIRouter:
try:
import cognee.api.v1.cognify.code_graph_pipeline
except ModuleNotFoundError:
logger.error("codegraph dependencies not found. Skipping codegraph API routes.")
return None
router = APIRouter()
@router.post("/index", response_model=None)
async def code_pipeline_index(payload: CodePipelineIndexPayloadDTO):
"""
Run indexation on a code repository.
This endpoint processes a code repository to create a knowledge graph
of the codebase structure, dependencies, and relationships.
## Request Parameters
- **repo_path** (str): Path to the code repository
- **include_docs** (bool): Whether to include documentation files (default: false)
## Response
No content returned. Processing results are logged.
## Error Codes
- **409 Conflict**: Error during indexation process
"""
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
try:
async for result in run_code_graph_pipeline(payload.repo_path, payload.include_docs):
logger.info(result)
except Exception as error:
return JSONResponse(status_code=409, content={"error": str(error)})
@router.post("/retrieve", response_model=list[dict])
async def code_pipeline_retrieve(payload: CodePipelineRetrievePayloadDTO):
"""
Retrieve context from the code knowledge graph.
This endpoint searches the indexed code repository to find relevant
context based on the provided query.
## Request Parameters
- **query** (str): Search query for code context
- **full_input** (str): Full input text for processing
## Response
Returns a list of relevant code files and context as JSON.
## Error Codes
- **409 Conflict**: Error during retrieval process
"""
try:
query = (
payload.full_input.replace("cognee ", "")
if payload.full_input.startswith("cognee ")
else payload.full_input
)
retriever = CodeRetriever()
retrieved_files = await retriever.get_context(query)
return json.dumps(retrieved_files, cls=JSONEncoder)
except Exception as error:
return JSONResponse(status_code=409, content={"error": str(error)})
return router

View file

@@ -1 +1 @@
from cognee.modules.retrieval.code_retriever import CodeRetriever

View file

@@ -1,232 +0,0 @@
from typing import Any, Optional, List
import asyncio
import aiofiles
from pydantic import BaseModel
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway
logger = get_logger("CodeRetriever")
class CodeRetriever(BaseRetriever):
"""Retriever for handling code-based searches."""
class CodeQueryInfo(BaseModel):
"""
Model for representing the result of a query related to code files.
This class holds a list of filenames and the corresponding source code extracted from a
query. It is used to encapsulate response data in a structured format.
"""
filenames: List[str] = []
sourcecode: str
def __init__(self, top_k: int = 3):
"""Initialize retriever with search parameters."""
self.top_k = top_k
self.file_name_collections = ["CodeFile_name"]
self.classes_and_functions_collections = [
"ClassDefinition_source_code",
"FunctionDefinition_source_code",
]
async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
"""Process the query using LLM to extract file names and source code parts."""
logger.debug(
f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
system_prompt = read_query_prompt("codegraph_retriever_system.txt")
try:
result = await LLMGateway.acreate_structured_output(
text_input=query,
system_prompt=system_prompt,
response_model=self.CodeQueryInfo,
)
logger.info(
f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
)
return result
except Exception as e:
logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
raise RuntimeError("Failed to retrieve structured output from LLM") from e
async def get_context(self, query: str) -> Any:
"""Find relevant code files based on the query."""
logger.info(
f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
if not query or not isinstance(query, str):
logger.error("Invalid query: must be a non-empty string")
raise ValueError("The query must be a non-empty string.")
try:
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
logger.debug("Successfully initialized vector and graph engines")
except Exception as e:
logger.error(f"Database initialization error: {str(e)}")
raise RuntimeError("Database initialization error in code_graph_retriever, ") from e
files_and_codeparts = await self._process_query(query)
similar_filenames = []
similar_codepieces = []
if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
logger.info("No specific files/code extracted from query, performing general search")
for collection in self.file_name_collections:
logger.debug(f"Searching {collection} collection with general query")
search_results_file = await vector_engine.search(
collection, query, limit=self.top_k
)
logger.debug(f"Found {len(search_results_file)} results in {collection}")
for res in search_results_file:
similar_filenames.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
existing_collection = []
for collection in self.classes_and_functions_collections:
if await vector_engine.has_collection(collection):
existing_collection.append(collection)
if not existing_collection:
raise RuntimeError("No collection found for code retriever")
for collection in existing_collection:
logger.debug(f"Searching {collection} collection with general query")
search_results_code = await vector_engine.search(
collection, query, limit=self.top_k
)
logger.debug(f"Found {len(search_results_code)} results in {collection}")
for res in search_results_code:
similar_codepieces.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
else:
logger.info(
f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
)
for collection in self.file_name_collections:
for file_from_query in files_and_codeparts.filenames:
logger.debug(f"Searching {collection} for specific file: {file_from_query}")
search_results_file = await vector_engine.search(
collection, file_from_query, limit=self.top_k
)
logger.debug(
f"Found {len(search_results_file)} results for file {file_from_query}"
)
for res in search_results_file:
similar_filenames.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
for collection in self.classes_and_functions_collections:
logger.debug(f"Searching {collection} with extracted source code")
search_results_code = await vector_engine.search(
collection, files_and_codeparts.sourcecode, limit=self.top_k
)
logger.debug(f"Found {len(search_results_code)} results for source code search")
for res in search_results_code:
similar_codepieces.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
total_items = len(similar_filenames) + len(similar_codepieces)
logger.info(
f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
)
if total_items == 0:
logger.warning("No search results found, returning empty list")
return []
logger.debug("Getting graph connections for all search results")
relevant_triplets = await asyncio.gather(
*[
graph_engine.get_connections(similar_piece["id"])
for similar_piece in similar_filenames + similar_codepieces
]
)
logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")
paths = set()
for i, sublist in enumerate(relevant_triplets):
logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
for tpl in sublist:
if isinstance(tpl, tuple) and len(tpl) >= 3:
if "file_path" in tpl[0]:
paths.add(tpl[0]["file_path"])
if "file_path" in tpl[2]:
paths.add(tpl[2]["file_path"])
logger.info(f"Found {len(paths)} unique file paths to read")
retrieved_files = {}
read_tasks = []
for file_path in paths:
async def read_file(fp):
try:
logger.debug(f"Reading file: {fp}")
async with aiofiles.open(fp, "r", encoding="utf-8") as f:
content = await f.read()
retrieved_files[fp] = content
logger.debug(f"Successfully read {len(content)} characters from {fp}")
except Exception as e:
logger.error(f"Error reading {fp}: {e}")
retrieved_files[fp] = ""
read_tasks.append(read_file(file_path))
await asyncio.gather(*read_tasks)
logger.info(
f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
)
result = [
{
"name": file_path,
"description": file_path,
"content": retrieved_files[file_path],
}
for file_path in paths
]
logger.info(f"Returning {len(result)} code file contexts")
return result
async def get_completion(
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
) -> Any:
"""
Returns the code files context.
Parameters:
-----------
- query (str): The query string to retrieve code context for.
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
Returns:
--------
- Any: The code files context, either provided or retrieved.
"""
if context is None:
context = await self.get_context(query)
return context

View file

@@ -0,0 +1,10 @@
from typing import Type
from .base_retriever import BaseRetriever
from .registered_community_retrievers import registered_community_retrievers
from ..search.types import SearchType
def use_retriever(search_type: SearchType, retriever: Type[BaseRetriever]):
"""Register a retriever class for a given search type."""
registered_community_retrievers[search_type] = retriever

View file

@@ -0,0 +1 @@
registered_community_retrievers = {}

View file

@@ -23,7 +23,6 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
@@ -162,10 +161,6 @@ async def get_search_type_tools(
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.CODE: [
CodeRetriever(top_k=top_k).get_completion,
CodeRetriever(top_k=top_k).get_context,
],
SearchType.CYPHER: [
CypherSearchRetriever().get_completion,
CypherSearchRetriever().get_context,
@@ -208,6 +203,18 @@ async def get_search_type_tools(
):
raise UnsupportedSearchTypeError("Cypher query search types are disabled.")
from cognee.modules.retrieval.registered_community_retrievers import (
registered_community_retrievers,
)
if query_type in registered_community_retrievers:
retriever = registered_community_retrievers[query_type]
retriever_instance = retriever(top_k=top_k)
search_type_tools = [
retriever_instance.get_completion,
retriever_instance.get_context,
]
else:
search_type_tools = search_tasks.get(query_type)
if not search_type_tools:

View file

@@ -8,7 +8,6 @@ class SearchType(Enum):
TRIPLET_COMPLETION = "TRIPLET_COMPLETION"
GRAPH_COMPLETION = "GRAPH_COMPLETION"
GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
CODE = "CODE"
CYPHER = "CYPHER"
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"

View file

@@ -1,35 +0,0 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
def main():
"""
Execute the main logic of the dependency graph processor.
This function sets up argument parsing to retrieve the repository path, checks the
existence of the specified path, and processes the repository to produce a dependency
graph. If the repository path does not exist, it logs an error message and terminates
without further execution.
"""
parser = argparse.ArgumentParser()
parser.add_argument("repo_path", help="Path to the repository")
args = parser.parse_args()
repo_path = args.repo_path
if not os.path.exists(repo_path):
print(f"Error: The provided repository path does not exist: {repo_path}")
return
graph = asyncio.run(get_repo_file_dependencies(repo_path))
graph = asyncio.run(enrich_dependency_graph(graph))
for node in graph.nodes:
print(f"Node: {node}")
for _, target, data in graph.out_edges(node, data=True):
print(f" Edge to {target}, data: {data}")
if __name__ == "__main__":
main()

View file

@@ -1,20 +0,0 @@
import argparse
import asyncio
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Get local script dependencies.")
# Suggested path: .../cognee/examples/python/simple_example.py
parser.add_argument("script_path", type=str, help="Absolute path to the Python script file")
# Suggested path: .../cognee
parser.add_argument("repo_path", type=str, help="Absolute path to the repository root")
args = parser.parse_args()
dependencies = asyncio.run(get_local_script_dependencies(args.script_path, args.repo_path))
print("Dependencies:")
for dependency in dependencies:
print(dependency)

View file

@@ -1,35 +0,0 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
def main():
"""
Parse the command line arguments and print the repository file dependencies.
This function sets up an argument parser to retrieve the path of a repository. It checks
if the provided path exists and if it doesn't, it prints an error message and exits. If
the path is valid, it calls an asynchronous function to get the dependencies and prints
the nodes and their relations in the dependency graph.
"""
parser = argparse.ArgumentParser()
parser.add_argument("repo_path", help="Path to the repository")
args = parser.parse_args()
repo_path = args.repo_path
if not os.path.exists(repo_path):
print(f"Error: The provided repository path does not exist: {repo_path}")
return
graph = asyncio.run(get_repo_file_dependencies(repo_path))
for node in graph.nodes:
print(f"Node: {node}")
edges = graph.edges(node, data=True)
for _, target, data in edges:
print(f" Edge to {target}, Relation: {data.get('relation')}")
if __name__ == "__main__":
main()

View file

@@ -1,2 +0,0 @@
from .get_non_code_files import get_non_py_files
from .get_repo_file_dependencies import get_repo_file_dependencies

View file

@@ -1,335 +0,0 @@
import os
import aiofiles
import importlib
from typing import AsyncGenerator, Optional
from uuid import NAMESPACE_OID, uuid5
import tree_sitter_python as tspython
from tree_sitter import Language, Node, Parser, Tree
from cognee.shared.logging_utils import get_logger
from cognee.low_level import DataPoint
from cognee.shared.CodeGraphEntities import (
CodeFile,
ImportStatement,
FunctionDefinition,
ClassDefinition,
)
logger = get_logger()
class FileParser:
"""
Handles the parsing of files into source code and an abstract syntax tree
representation. Public methods include:
- parse_file: Parses a file and returns its source code and syntax tree representation.
"""
def __init__(self):
self.parsed_files = {}
async def parse_file(self, file_path: str) -> tuple[str, Tree]:
"""
Parse a file and return its source code along with its syntax tree representation.
If the file has already been parsed, retrieve the result from memory instead of reading
the file again.
Parameters:
-----------
- file_path (str): The path of the file to parse.
Returns:
--------
- tuple[str, Tree]: A tuple containing the source code of the file and its
corresponding syntax tree representation.
"""
PY_LANGUAGE = Language(tspython.language())
source_code_parser = Parser(PY_LANGUAGE)
if file_path not in self.parsed_files:
source_code = await get_source_code(file_path)
source_code_tree = source_code_parser.parse(bytes(source_code, "utf-8"))
self.parsed_files[file_path] = (source_code, source_code_tree)
return self.parsed_files[file_path]
async def get_source_code(file_path: str):
"""
Read source code from a file asynchronously.
This function attempts to open a file specified by the given file path, read its
contents, and return the source code. In case of any errors during the file reading
process, it logs an error message and returns None.
Parameters:
-----------
- file_path (str): The path to the file from which to read the source code.
Returns:
--------
Returns the contents of the file as a string if successful, or None if an error
occurs.
"""
try:
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
source_code = await f.read()
return source_code
except Exception as error:
logger.error(f"Error reading file {file_path}: {str(error)}")
return None
def resolve_module_path(module_name):
"""
Find the file path of a module.
Return the file path of the specified module if found, or return None if the module does
not exist or cannot be located.
Parameters:
-----------
- module_name: The name of the module whose file path is to be resolved.
Returns:
--------
The file path of the module as a string or None if the module is not found.
"""
try:
spec = importlib.util.find_spec(module_name)
if spec and spec.origin:
return spec.origin
except ModuleNotFoundError:
return None
return None
def find_function_location(
module_path: str, function_name: str, parser: FileParser
) -> Optional[tuple[str, str]]:
"""
Find the location of a function definition in a specified module.
Parameters:
-----------
- module_path (str): The path to the module where the function is defined.
- function_name (str): The name of the function whose location is to be found.
- parser (FileParser): An instance of FileParser used to parse the module's source
code.
Returns:
--------
- Optional[tuple[str, str]]: Returns a tuple containing the module path and the
start point of the function if found; otherwise, returns None.
"""
if not module_path or not os.path.exists(module_path):
return None
source_code, tree = parser.parse_file(module_path)
root_node: Node = tree.root_node
for node in root_node.children:
if node.type == "function_definition":
func_name_node = node.child_by_field_name("name")
if func_name_node and func_name_node.text.decode() == function_name:
return (module_path, node.start_point) # (line, column)
return None
async def get_local_script_dependencies(
repo_path: str, script_path: str, detailed_extraction: bool = False
) -> CodeFile:
"""
Retrieve local script dependencies and create a CodeFile object.
Parameters:
-----------
- repo_path (str): The path to the repository that contains the script.
- script_path (str): The path of the script for which dependencies are being
extracted.
- detailed_extraction (bool): A flag indicating whether to perform a detailed
extraction of code components.
Returns:
--------
- CodeFile: Returns a CodeFile object containing information about the script,
including its dependencies and definitions.
"""
code_file_parser = FileParser()
source_code, source_code_tree = await code_file_parser.parse_file(script_path)
file_path_relative_to_repo = script_path[len(repo_path) + 1 :]
if not detailed_extraction:
code_file_node = CodeFile(
id=uuid5(NAMESPACE_OID, script_path),
name=file_path_relative_to_repo,
source_code=source_code,
file_path=script_path,
language="python",
)
return code_file_node
code_file_node = CodeFile(
id=uuid5(NAMESPACE_OID, script_path),
name=file_path_relative_to_repo,
source_code=None,
file_path=script_path,
language="python",
)
async for part in extract_code_parts(source_code_tree.root_node, script_path=script_path):
part.file_path = script_path
if isinstance(part, FunctionDefinition):
code_file_node.provides_function_definition.append(part)
if isinstance(part, ClassDefinition):
code_file_node.provides_class_definition.append(part)
if isinstance(part, ImportStatement):
code_file_node.depends_on.append(part)
return code_file_node
def find_node(nodes: list[Node], condition: callable) -> Node:
"""
Find and return the first node that satisfies the given condition.
Iterate through the provided list of nodes and return the first node for which the
condition callable returns True. If no such node is found, return None.
Parameters:
-----------
- nodes (list[Node]): A list of Node objects to search through.
- condition (callable): A callable that takes a Node and returns a boolean
indicating if the node meets specified criteria.
Returns:
--------
- Node: The first Node that matches the condition, or None if no such node exists.
"""
for node in nodes:
if condition(node):
return node
return None
async def extract_code_parts(
tree_root: Node, script_path: str, existing_nodes: list[DataPoint] = {}
) -> AsyncGenerator[DataPoint, None]:
"""
Extract code parts from a given AST node tree asynchronously.
Iteratively yields DataPoint nodes representing import statements, function definitions,
and class definitions found in the children of the specified tree root. The function checks
if nodes are already present in the existing_nodes dictionary to prevent duplicates.
This function has to be used in an asynchronous context, and it requires a valid tree_root
and proper initialization of existing_nodes.
Parameters:
-----------
- tree_root (Node): The root node of the AST tree containing code parts to extract.
- script_path (str): The file path of the script from which the AST was generated.
- existing_nodes (list[DataPoint]): A dictionary that holds already extracted
DataPoint nodes to avoid duplicates. (default {})
Returns:
--------
Yields DataPoint nodes representing imported modules, functions, and classes.
"""
for child_node in tree_root.children:
if child_node.type == "import_statement" or child_node.type == "import_from_statement":
parts = child_node.text.decode("utf-8").split()
if parts[0] == "import":
module_name = parts[1]
function_name = None
elif parts[0] == "from":
module_name = parts[1]
function_name = parts[3]
if " as " in function_name:
function_name = function_name.split(" as ")[0]
if " as " in module_name:
module_name = module_name.split(" as ")[0]
if function_name and "import " + function_name not in existing_nodes:
import_statement_node = ImportStatement(
name=function_name,
module=module_name,
start_point=child_node.start_point,
end_point=child_node.end_point,
file_path=script_path,
source_code=child_node.text,
)
existing_nodes["import " + function_name] = import_statement_node
if function_name:
yield existing_nodes["import " + function_name]
if module_name not in existing_nodes:
import_statement_node = ImportStatement(
name=module_name,
module=module_name,
start_point=child_node.start_point,
end_point=child_node.end_point,
file_path=script_path,
source_code=child_node.text,
)
existing_nodes[module_name] = import_statement_node
yield existing_nodes[module_name]
if child_node.type == "function_definition":
function_node = find_node(child_node.children, lambda node: node.type == "identifier")
function_node_name = function_node.text
if function_node_name not in existing_nodes:
function_definition_node = FunctionDefinition(
name=function_node_name,
start_point=child_node.start_point,
end_point=child_node.end_point,
file_path=script_path,
source_code=child_node.text,
)
existing_nodes[function_node_name] = function_definition_node
yield existing_nodes[function_node_name]
if child_node.type == "class_definition":
class_name_node = find_node(child_node.children, lambda node: node.type == "identifier")
class_name_node_name = class_name_node.text
if class_name_node_name not in existing_nodes:
class_definition_node = ClassDefinition(
name=class_name_node_name,
start_point=child_node.start_point,
end_point=child_node.end_point,
file_path=script_path,
source_code=child_node.text,
)
existing_nodes[class_name_node_name] = class_definition_node
yield existing_nodes[class_name_node_name]

View file

@@ -1,158 +0,0 @@
import os
async def get_non_py_files(repo_path):
"""
Get paths of files that are not .py files.
Check if the specified repository path exists and if so, traverse the directory,
collecting the paths of files that do not have a .py extension and meet the
criteria set in the allowed and ignored patterns. Return a list of paths to
those files.
Parameters:
-----------
- repo_path: The file system path to the repository to scan for non-Python files.
Returns:
--------
A list of file paths that are not Python files and meet the specified criteria.
"""
if not os.path.exists(repo_path):
return {}
IGNORED_PATTERNS = {
".git",
"__pycache__",
"*.pyc",
"*.pyo",
"*.pyd",
"node_modules",
"*.egg-info",
}
ALLOWED_EXTENSIONS = {
".txt",
".md",
".csv",
".json",
".xml",
".yaml",
".yml",
".html",
".css",
".js",
".ts",
".jsx",
".tsx",
".sql",
".log",
".ini",
".toml",
".properties",
".sh",
".bash",
".dockerfile",
".gitignore",
".gitattributes",
".makefile",
".pyproject",
".requirements",
".env",
".pdf",
".doc",
".docx",
".dot",
".dotx",
".rtf",
".wps",
".wpd",
".odt",
".ott",
".ottx",
".txt",
".wp",
".sdw",
".sdx",
".docm",
".dotm",
# Additional extensions for other programming languages
".java",
".c",
".cpp",
".h",
".cs",
".go",
".php",
".rb",
".swift",
".pl",
".lua",
".rs",
".scala",
".kt",
".sh",
".sql",
".v",
".asm",
".pas",
".d",
".ml",
".clj",
".cljs",
".erl",
".ex",
".exs",
".f",
".fs",
".r",
".pyi",
".pdb",
".ipynb",
".rmd",
".cabal",
".hs",
".nim",
".vhdl",
".verilog",
".svelte",
".html",
".css",
".scss",
".less",
".json5",
".yaml",
".yml",
}
def should_process(path):
"""
Determine if a file should be processed based on its extension and path patterns.
This function checks if the file extension is in the allowed list and ensures that none
of the ignored patterns are present in the provided file path.
Parameters:
-----------
- path: The file path to check for processing eligibility.
Returns:
--------
Returns True if the file should be processed; otherwise, False.
"""
_, ext = os.path.splitext(path)
return ext in ALLOWED_EXTENSIONS and not any(
pattern in path for pattern in IGNORED_PATTERNS
)
non_py_files_paths = [
os.path.join(root, file)
for root, _, files in os.walk(repo_path)
for file in files
if not file.endswith(".py") and should_process(os.path.join(root, file))
]
return non_py_files_paths

View file

@@ -1,243 +0,0 @@
import asyncio
import math
import os
from pathlib import Path
from typing import Set
from typing import AsyncGenerator, Optional, List
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, Repository
# constant, declared only once
EXCLUDED_DIRS: Set[str] = {
".venv",
"venv",
"env",
".env",
"site-packages",
"node_modules",
"dist",
"build",
".git",
"tests",
"test",
}
async def get_source_code_files(
repo_path,
language_config: dict[str, list[str]] | None = None,
excluded_paths: Optional[List[str]] = None,
):
"""
Retrieve source code files for the configured languages from the specified repository path.
This function scans the given repository path for files whose extensions match the
language configuration while excluding test files and common build/venv directories. It
returns a list of (absolute_path, language) tuples for files that are not empty.
Parameters:
-----------
- repo_path: Root path of the repository to search
- language_config: dict mapping language names to file extensions, e.g.,
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
- excluded_paths: Optional list of path fragments or glob patterns to exclude
Returns:
--------
A list of (absolute_path, language) tuples for source code files.
"""
def _get_language_from_extension(file, language_config):
for lang, exts in language_config.items():
for ext in exts:
if file.endswith(ext):
return lang
return None
# Default config if not provided
if language_config is None:
language_config = {
"python": [".py"],
"javascript": [".js", ".jsx"],
"typescript": [".ts", ".tsx"],
"java": [".java"],
"csharp": [".cs"],
"go": [".go"],
"rust": [".rs"],
"cpp": [".cpp", ".c", ".h", ".hpp"],
}
if not os.path.exists(repo_path):
return []
source_code_files = set()
for root, _, files in os.walk(repo_path):
for file in files:
lang = _get_language_from_extension(file, language_config)
if lang is None:
continue
# Exclude tests, common build/venv directories and files provided in excluded_paths
excluded_dirs = EXCLUDED_DIRS
excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])} # full paths
root_path = Path(root).resolve()
root_parts = set(root_path.parts) # same as before
base_name, _ext = os.path.splitext(file)
if (
base_name.startswith("test_")
or base_name.endswith("_test")
or ".test." in file
or ".spec." in file
or (excluded_dirs & root_parts) # name match
or any(
root_path.is_relative_to(p) # full-path match
for p in excluded_paths
)
):
continue
file_path = os.path.abspath(os.path.join(root, file))
if os.path.getsize(file_path) == 0:
continue
source_code_files.add((file_path, lang))
return sorted(list(source_code_files))
def run_coroutine(coroutine_func, *args, **kwargs):
"""
Run a coroutine function until it completes.
This function creates a new asyncio event loop, sets it as the current loop, and
executes the given coroutine function with the provided arguments. Once the coroutine
completes, the loop is closed. Intended for use in environments where an existing event
loop is not available or desirable.
Parameters:
-----------
- coroutine_func: The coroutine function to be run.
- *args: Positional arguments to pass to the coroutine function.
- **kwargs: Keyword arguments to pass to the coroutine function.
Returns:
--------
The result returned by the coroutine after completion.
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(coroutine_func(*args, **kwargs))
loop.close()
return result
async def get_repo_file_dependencies(
repo_path: str,
detailed_extraction: bool = False,
supported_languages: list = None,
excluded_paths: Optional[List[str]] = None,
) -> AsyncGenerator[DataPoint, None]:
"""
Generate a dependency graph for source files (multi-language) in the given repository path.
Check the validity of the repository path and yield a repository object followed by the
dependencies of source files within that repository. Raise a FileNotFoundError if the
provided path does not exist. The extraction of detailed dependencies can be controlled
via the `detailed_extraction` argument. Languages considered can be restricted via
the `supported_languages` argument.
Parameters:
-----------
- repo_path (str): The file path to the repository to process.
- detailed_extraction (bool): Whether to perform a detailed extraction of code parts.
- supported_languages (list | None): Subset of languages to include; if None, use defaults.
"""
if isinstance(repo_path, list) and len(repo_path) == 1:
repo_path = repo_path[0]
if not os.path.exists(repo_path):
raise FileNotFoundError(f"Repository path {repo_path} does not exist.")
# Build language config from supported_languages
default_language_config = {
"python": [".py"],
"javascript": [".js", ".jsx"],
"typescript": [".ts", ".tsx"],
"java": [".java"],
"csharp": [".cs"],
"go": [".go"],
"rust": [".rs"],
"cpp": [".cpp", ".c", ".h", ".hpp"],
"c": [".c", ".h"],
}
if supported_languages is not None:
language_config = {
k: v for k, v in default_language_config.items() if k in supported_languages
}
else:
language_config = default_language_config
source_code_files = await get_source_code_files(
repo_path, language_config=language_config, excluded_paths=excluded_paths
)
repo = Repository(
id=uuid5(NAMESPACE_OID, repo_path),
path=repo_path,
)
yield repo
chunk_size = 100
number_of_chunks = math.ceil(len(source_code_files) / chunk_size)
chunk_ranges = [
(
chunk_number * chunk_size,
min((chunk_number + 1) * chunk_size, len(source_code_files)) - 1,
)
for chunk_number in range(number_of_chunks)
]
# Import dependency extractors for each language (Python for now, extend later)
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
import aiofiles
# TODO: Add other language extractors here
for start_range, end_range in chunk_ranges:
tasks = []
for file_path, lang in source_code_files[start_range : end_range + 1]:
# For now, only Python is supported; extend with other languages
if lang == "python":
tasks.append(
get_local_script_dependencies(repo_path, file_path, detailed_extraction)
)
else:
# Placeholder: create a minimal CodeFile for other languages
async def make_codefile_stub(file_path=file_path, lang=lang):
async with aiofiles.open(
file_path, "r", encoding="utf-8", errors="replace"
) as f:
source = await f.read()
return CodeFile(
id=uuid5(NAMESPACE_OID, file_path),
name=os.path.relpath(file_path, repo_path),
file_path=file_path,
language=lang,
source_code=source,
)
tasks.append(make_codefile_stub())
results: list[CodeFile] = await asyncio.gather(*tasks)
for source_code_file in results:
source_code_file.part_of = repo
if getattr(
source_code_file, "language", None
) is None and source_code_file.file_path.endswith(".py"):
source_code_file.language = "python"
yield source_code_file

View file

@@ -1,63 +0,0 @@
import argparse
import asyncio
import os
import cognee
from cognee import SearchType
from cognee.shared.logging_utils import setup_logging, ERROR
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
async def main(repo_path, include_docs):
# Disable permissions feature for this example
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
run_status = False
async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs):
run_status = run_status
# Test CODE search
search_results = await cognee.search(query_type=SearchType.CODE, query_text="test")
assert len(search_results) != 0, "The search results list is empty."
print("\n\nSearch results are:\n")
for result in search_results:
print(f"{result}\n")
return run_status
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
parser.add_argument(
"--include_docs",
type=lambda x: x.lower() in ("true", "1"),
default=False,
help="Whether or not to process non-code files",
)
parser.add_argument(
"--time",
type=lambda x: x.lower() in ("true", "1"),
default=True,
help="Whether or not to time the pipeline run",
)
return parser.parse_args()
if __name__ == "__main__":
logger = setup_logging(log_level=ERROR)
args = parse_args()
if args.time:
import time
start_time = time.time()
asyncio.run(main(args.repo_path, args.include_docs))
end_time = time.time()
print("\n" + "=" * 50)
print(f"Pipeline Execution Time: {end_time - start_time:.2f} seconds")
print("=" * 50 + "\n")
else:
asyncio.run(main(args.repo_path, args.include_docs))