refactor: move codify pipeline out of main repo (#1738)
<!-- .github/pull_request_template.md -->

## Description
<!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. -->

This PR moves codify and the code graph pipeline out of the repository. It also introduces a Custom Pipeline interface, which can be used in the future to define custom pipelines.

## Type of Change
<!-- Please check the relevant option -->

- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->

- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in: commit 9571641199

20 changed files with 26 additions and 1617 deletions

30  .github/workflows/basic_tests.yml (vendored)
@@ -197,33 +197,3 @@ jobs:
      - name: Run Simple Examples
        run: uv run python ./examples/python/simple_example.py

  graph-tests:
    name: Run Basic Graph Tests
    runs-on: ubuntu-22.04
    env:
      ENV: 'dev'
      LLM_PROVIDER: openai
      LLM_MODEL: ${{ secrets.LLM_MODEL }}
      LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
      LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
      LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}

      EMBEDDING_PROVIDER: openai
      EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
      EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
      EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
      EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
    steps:
      - name: Check out repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: ${{ inputs.python-version }}

      - name: Run Graph Tests
        run: uv run python ./examples/python/code_graph_example.py --repo_path ./cognee/tasks/graph
@@ -90,97 +90,6 @@ async def health_check(request):
    return JSONResponse({"status": "ok"})


@mcp.tool()
async def cognee_add_developer_rules(
    base_path: str = ".", graph_model_file: str = None, graph_model_name: str = None
) -> list:
    """
    Ingest core developer rule files into Cognee's memory layer.

    This function loads a predefined set of developer-related configuration,
    rule, and documentation files from the base repository and assigns them
    to the special 'developer_rules' node set in Cognee. It ensures these
    foundational files are always part of the structured memory graph.

    Parameters
    ----------
    base_path : str
        Root path to resolve relative file paths. Defaults to current directory.

    graph_model_file : str, optional
        Optional path to a custom schema file for knowledge graph generation.

    graph_model_name : str, optional
        Optional class name to use from the graph_model_file schema.

    Returns
    -------
    list
        A message indicating how many rule files were scheduled for ingestion,
        and how to check their processing status.

    Notes
    -----
    - Each file is processed asynchronously in the background.
    - Files are attached to the 'developer_rules' node set.
    - Missing files are skipped with a logged warning.
    """

    developer_rule_paths = [
        ".cursorrules",
        ".cursor/rules",
        ".same/todos.md",
        ".windsurfrules",
        ".clinerules",
        "CLAUDE.md",
        ".sourcegraph/memory.md",
        "AGENT.md",
        "AGENTS.md",
    ]

    async def cognify_task(file_path: str) -> None:
        with redirect_stdout(sys.stderr):
            logger.info(f"Starting cognify for: {file_path}")
            try:
                await cognee_client.add(file_path, node_set=["developer_rules"])

                model = None
                if graph_model_file and graph_model_name:
                    if cognee_client.use_api:
                        logger.warning(
                            "Custom graph models are not supported in API mode, ignoring."
                        )
                    else:
                        from cognee.shared.data_models import KnowledgeGraph

                        model = load_class(graph_model_file, graph_model_name)

                await cognee_client.cognify(graph_model=model)
                logger.info(f"Cognify finished for: {file_path}")
            except Exception as e:
                logger.error(f"Cognify failed for {file_path}: {str(e)}")
                raise ValueError(f"Failed to cognify: {str(e)}")

    tasks = []
    for rel_path in developer_rule_paths:
        abs_path = os.path.join(base_path, rel_path)
        if os.path.isfile(abs_path):
            tasks.append(asyncio.create_task(cognify_task(abs_path)))
        else:
            logger.warning(f"Skipped missing developer rule file: {abs_path}")
    log_file = get_log_file_location()
    return [
        types.TextContent(
            type="text",
            text=(
                f"Started cognify for {len(tasks)} developer rule files in background.\n"
                f"All are added to the `developer_rules` node set.\n"
                f"Use `cognify_status` or check logs at {log_file} to monitor progress."
            ),
        )
    ]


@mcp.tool()
async def cognify(
    data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
@@ -406,75 +315,6 @@ async def save_interaction(data: str) -> list:
    ]


@mcp.tool()
async def codify(repo_path: str) -> list:
    """
    Analyze and generate a code-specific knowledge graph from a software repository.

    This function launches a background task that processes the provided repository
    and builds a code knowledge graph. The function returns immediately while
    the processing continues in the background due to MCP timeout constraints.

    Parameters
    ----------
    repo_path : str
        Path to the code repository to analyze. This can be a local file path or a
        relative path to a repository. The path should point to the root of the
        repository or a specific directory within it.

    Returns
    -------
    list
        A list containing a single TextContent object with information about the
        background task launch and how to check its status.

    Notes
    -----
    - The function launches a background task and returns immediately
    - The code graph generation may take significant time for larger repositories
    - Use the codify_status tool to check the progress of the operation
    - Process results are logged to the standard Cognee log file
    - All stdout is redirected to stderr to maintain MCP communication integrity
    """

    if cognee_client.use_api:
        error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline."
        logger.error(error_msg)
        return [types.TextContent(type="text", text=error_msg)]

    async def codify_task(repo_path: str):
        # NOTE: MCP uses stdout to communicate, so we must redirect all output
        # going to stdout (like the print function) to stderr.
        with redirect_stdout(sys.stderr):
            logger.info("Codify process starting.")
            from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline

            results = []
            async for result in run_code_graph_pipeline(repo_path, False):
                results.append(result)
                logger.info(result)
            if all(results):
                logger.info("Codify process finished successfully.")
            else:
                logger.info("Codify process failed.")

    asyncio.create_task(codify_task(repo_path))

    log_file = get_log_file_location()
    text = (
        f"Background process launched due to MCP timeout limitations.\n"
        f"To check current codify status use the codify_status tool\n"
        f"or you can check the log file at: {log_file}"
    )

    return [
        types.TextContent(
            type="text",
            text=text,
        )
    ]


@mcp.tool()
async def search(search_query: str, search_type: str) -> list:
    """
@@ -629,45 +469,6 @@ async def search(search_query: str, search_type: str) -> list:
    return [types.TextContent(type="text", text=search_results)]


@mcp.tool()
async def get_developer_rules() -> list:
    """
    Retrieve all developer rules that were generated based on previous interactions.

    This tool queries the Cognee knowledge graph and returns a list of developer
    rules.

    Parameters
    ----------
    None

    Returns
    -------
    list
        A list containing a single TextContent object with the retrieved developer rules.
        The format is plain text containing the developer rules in bullet points.

    Notes
    -----
    - The specific logic for fetching rules is handled internally.
    - This tool does not accept any parameters and is intended for simple rule inspection use cases.
    """

    async def fetch_rules_from_cognee() -> str:
        """Collect all developer rules from Cognee"""
        with redirect_stdout(sys.stderr):
            if cognee_client.use_api:
                logger.warning("Developer rules retrieval is not available in API mode")
                return "Developer rules retrieval is not available in API mode"

            developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules")
            return developer_rules

    rules_text = await fetch_rules_from_cognee()

    return [types.TextContent(type="text", text=rules_text)]


@mcp.tool()
async def list_data(dataset_id: str = None) -> list:
    """
@@ -953,48 +754,6 @@ async def cognify_status():
        return [types.TextContent(type="text", text=error_msg)]


@mcp.tool()
async def codify_status():
    """
    Get the current status of the codify pipeline.

    This function retrieves information about current and recently completed codify operations
    in the codebase dataset. It provides details on progress, success/failure status, and statistics
    about the processed code repositories.

    Returns
    -------
    list
        A list containing a single TextContent object with the status information as a string.
        The status includes information about active and completed jobs for the cognify_code_pipeline.

    Notes
    -----
    - The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset
    - Status information includes job progress, execution time, and completion status
    - The status is returned in string format for easy reading
    - This operation is not available in API mode
    """
    with redirect_stdout(sys.stderr):
        try:
            from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
            from cognee.modules.users.methods import get_default_user

            user = await get_default_user()
            status = await cognee_client.get_pipeline_status(
                [await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
            )
            return [types.TextContent(type="text", text=str(status))]
        except NotImplementedError:
            error_msg = "❌ Pipeline status is not available in API mode"
            logger.error(error_msg)
            return [types.TextContent(type="text", text=error_msg)]
        except Exception as e:
            error_msg = f"❌ Failed to get codify status: {str(e)}"
            logger.error(error_msg)
            return [types.TextContent(type="text", text=error_msg)]


def node_to_string(node):
    node_data = ", ".join(
        [f'{key}: "{value}"' for key, value in node.items() if key in ["id", "name"]]
@@ -21,7 +21,7 @@ from cognee.api.v1.notebooks.routers import get_notebooks_router
 from cognee.api.v1.permissions.routers import get_permissions_router
 from cognee.api.v1.settings.routers import get_settings_router
 from cognee.api.v1.datasets.routers import get_datasets_router
-from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
+from cognee.api.v1.cognify.routers import get_cognify_router
 from cognee.api.v1.search.routers import get_search_router
 from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
 from cognee.api.v1.memify.routers import get_memify_router
@@ -278,10 +278,6 @@ app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["re

 app.include_router(get_sync_router(), prefix="/api/v1/sync", tags=["sync"])

-codegraph_routes = get_code_pipeline_router()
-if codegraph_routes:
-    app.include_router(codegraph_routes, prefix="/api/v1/code-pipeline", tags=["code-pipeline"])
-
 app.include_router(
     get_users_router(),
     prefix="/api/v1/users",
@@ -1,119 +0,0 @@
import os
import pathlib
import asyncio
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.modules.observability.get_observe import get_observe

from cognee.api.v1.search import SearchType, search
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.data.methods import create_dataset
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.ingestion import ingest_data
from cognee.tasks.repo_processor import get_non_py_files, get_repo_file_dependencies

from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
from cognee.infrastructure.llm import get_max_chunk_tokens
from cognee.infrastructure.databases.relational import get_relational_engine

observe = get_observe()

logger = get_logger("code_graph_pipeline")


@observe
async def run_code_graph_pipeline(
    repo_path,
    include_docs=False,
    excluded_paths: Optional[list[str]] = None,
    supported_languages: Optional[list[str]] = None,
):
    import cognee
    from cognee.low_level import setup

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    cognee_config = get_cognify_config()
    user = await get_default_user()
    detailed_extraction = True

    tasks = [
        Task(
            get_repo_file_dependencies,
            detailed_extraction=detailed_extraction,
            supported_languages=supported_languages,
            excluded_paths=excluded_paths,
        ),
        # Task(summarize_code, task_config={"batch_size": 500}),  # This task takes a long time to complete
        Task(add_data_points, task_config={"batch_size": 30}),
    ]

    if include_docs:
        # These tasks take a long time to complete
        non_code_tasks = [
            Task(get_non_py_files, task_config={"batch_size": 50}),
            Task(ingest_data, dataset_name="repo_docs", user=user),
            Task(classify_documents),
            Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
            Task(
                extract_graph_from_data,
                graph_model=KnowledgeGraph,
                task_config={"batch_size": 50},
            ),
            Task(
                summarize_text,
                summarization_model=cognee_config.summarization_model,
                task_config={"batch_size": 50},
            ),
        ]

    dataset_name = "codebase"

    # Save dataset to database
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        dataset = await create_dataset(dataset_name, user, session)

    if include_docs:
        non_code_pipeline_run = run_tasks(
            non_code_tasks, dataset.id, repo_path, user, "cognify_pipeline"
        )
        async for run_status in non_code_pipeline_run:
            yield run_status

    async for run_status in run_tasks(
        tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False
    ):
        yield run_status


if __name__ == "__main__":

    async def main():
        async for run_status in run_code_graph_pipeline("REPO_PATH"):
            print(f"{run_status.pipeline_run_id}: {run_status.status}")

        file_path = os.path.join(
            pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
        )
        await visualize_graph(file_path)

        search_results = await search(
            query_type=SearchType.CODE,
            query_text="How is Relationship weight calculated?",
        )

        for file in search_results:
            print(file["name"])

    logger = setup_logging(name="code_graph_pipeline")
    asyncio.run(main())
@@ -1,2 +1 @@
 from .get_cognify_router import get_cognify_router
-from .get_code_pipeline_router import get_code_pipeline_router
@@ -1,90 +0,0 @@
import json
from cognee.shared.logging_utils import get_logger
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from cognee.api.DTO import InDTO
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.storage.utils import JSONEncoder


logger = get_logger()


class CodePipelineIndexPayloadDTO(InDTO):
    repo_path: str
    include_docs: bool = False


class CodePipelineRetrievePayloadDTO(InDTO):
    query: str
    full_input: str


def get_code_pipeline_router() -> APIRouter:
    try:
        import cognee.api.v1.cognify.code_graph_pipeline
    except ModuleNotFoundError:
        logger.error("codegraph dependencies not found. Skipping codegraph API routes.")
        return None

    router = APIRouter()

    @router.post("/index", response_model=None)
    async def code_pipeline_index(payload: CodePipelineIndexPayloadDTO):
        """
        Run indexation on a code repository.

        This endpoint processes a code repository to create a knowledge graph
        of the codebase structure, dependencies, and relationships.

        ## Request Parameters
        - **repo_path** (str): Path to the code repository
        - **include_docs** (bool): Whether to include documentation files (default: false)

        ## Response
        No content returned. Processing results are logged.

        ## Error Codes
        - **409 Conflict**: Error during indexation process
        """
        from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline

        try:
            async for result in run_code_graph_pipeline(payload.repo_path, payload.include_docs):
                logger.info(result)
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    @router.post("/retrieve", response_model=list[dict])
    async def code_pipeline_retrieve(payload: CodePipelineRetrievePayloadDTO):
        """
        Retrieve context from the code knowledge graph.

        This endpoint searches the indexed code repository to find relevant
        context based on the provided query.

        ## Request Parameters
        - **query** (str): Search query for code context
        - **full_input** (str): Full input text for processing

        ## Response
        Returns a list of relevant code files and context as JSON.

        ## Error Codes
        - **409 Conflict**: Error during retrieval process
        """
        try:
            query = (
                payload.full_input.replace("cognee ", "")
                if payload.full_input.startswith("cognee ")
                else payload.full_input
            )

            retriever = CodeRetriever()
            retrieved_files = await retriever.get_context(query)

            return json.dumps(retrieved_files, cls=JSONEncoder)
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    return router
@@ -1 +1 @@
from cognee.modules.retrieval.code_retriever import CodeRetriever
@@ -1,232 +0,0 @@
from typing import Any, Optional, List
import asyncio
import aiofiles
from pydantic import BaseModel

from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway

logger = get_logger("CodeRetriever")


class CodeRetriever(BaseRetriever):
    """Retriever for handling code-based searches."""

    class CodeQueryInfo(BaseModel):
        """
        Model for representing the result of a query related to code files.

        This class holds a list of filenames and the corresponding source code extracted from a
        query. It is used to encapsulate response data in a structured format.
        """

        filenames: List[str] = []
        sourcecode: str

    def __init__(self, top_k: int = 3):
        """Initialize retriever with search parameters."""
        self.top_k = top_k
        self.file_name_collections = ["CodeFile_name"]
        self.classes_and_functions_collections = [
            "ClassDefinition_source_code",
            "FunctionDefinition_source_code",
        ]

    async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
        """Process the query using LLM to extract file names and source code parts."""
        logger.debug(
            f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
        )

        system_prompt = read_query_prompt("codegraph_retriever_system.txt")

        try:
            result = await LLMGateway.acreate_structured_output(
                text_input=query,
                system_prompt=system_prompt,
                response_model=self.CodeQueryInfo,
            )
            logger.info(
                f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
            )
            return result
        except Exception as e:
            logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
            raise RuntimeError("Failed to retrieve structured output from LLM") from e

    async def get_context(self, query: str) -> Any:
        """Find relevant code files based on the query."""
        logger.info(
            f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
        )

        if not query or not isinstance(query, str):
            logger.error("Invalid query: must be a non-empty string")
            raise ValueError("The query must be a non-empty string.")

        try:
            vector_engine = get_vector_engine()
            graph_engine = await get_graph_engine()
            logger.debug("Successfully initialized vector and graph engines")
        except Exception as e:
            logger.error(f"Database initialization error: {str(e)}")
            raise RuntimeError("Database initialization error in code_graph_retriever, ") from e

        files_and_codeparts = await self._process_query(query)

        similar_filenames = []
        similar_codepieces = []

        if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
            logger.info("No specific files/code extracted from query, performing general search")

            for collection in self.file_name_collections:
                logger.debug(f"Searching {collection} collection with general query")
                search_results_file = await vector_engine.search(
                    collection, query, limit=self.top_k
                )
                logger.debug(f"Found {len(search_results_file)} results in {collection}")
                for res in search_results_file:
                    similar_filenames.append(
                        {"id": res.id, "score": res.score, "payload": res.payload}
                    )

            existing_collection = []
            for collection in self.classes_and_functions_collections:
                if await vector_engine.has_collection(collection):
                    existing_collection.append(collection)

            if not existing_collection:
                raise RuntimeError("No collection found for code retriever")

            for collection in existing_collection:
                logger.debug(f"Searching {collection} collection with general query")
                search_results_code = await vector_engine.search(
                    collection, query, limit=self.top_k
                )
                logger.debug(f"Found {len(search_results_code)} results in {collection}")
                for res in search_results_code:
                    similar_codepieces.append(
                        {"id": res.id, "score": res.score, "payload": res.payload}
                    )
        else:
            logger.info(
                f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
            )

            for collection in self.file_name_collections:
                for file_from_query in files_and_codeparts.filenames:
                    logger.debug(f"Searching {collection} for specific file: {file_from_query}")
                    search_results_file = await vector_engine.search(
                        collection, file_from_query, limit=self.top_k
                    )
                    logger.debug(
                        f"Found {len(search_results_file)} results for file {file_from_query}"
                    )
                    for res in search_results_file:
                        similar_filenames.append(
                            {"id": res.id, "score": res.score, "payload": res.payload}
                        )

            for collection in self.classes_and_functions_collections:
                logger.debug(f"Searching {collection} with extracted source code")
                search_results_code = await vector_engine.search(
                    collection, files_and_codeparts.sourcecode, limit=self.top_k
                )
                logger.debug(f"Found {len(search_results_code)} results for source code search")
                for res in search_results_code:
                    similar_codepieces.append(
                        {"id": res.id, "score": res.score, "payload": res.payload}
                    )

        total_items = len(similar_filenames) + len(similar_codepieces)
        logger.info(
            f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
        )

        if total_items == 0:
            logger.warning("No search results found, returning empty list")
            return []

        logger.debug("Getting graph connections for all search results")
        relevant_triplets = await asyncio.gather(
            *[
                graph_engine.get_connections(similar_piece["id"])
                for similar_piece in similar_filenames + similar_codepieces
            ]
        )
        logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")

        paths = set()
        for i, sublist in enumerate(relevant_triplets):
            logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
            for tpl in sublist:
                if isinstance(tpl, tuple) and len(tpl) >= 3:
                    if "file_path" in tpl[0]:
                        paths.add(tpl[0]["file_path"])
                    if "file_path" in tpl[2]:
                        paths.add(tpl[2]["file_path"])

        logger.info(f"Found {len(paths)} unique file paths to read")

        retrieved_files = {}
        read_tasks = []
        for file_path in paths:

            async def read_file(fp):
                try:
                    logger.debug(f"Reading file: {fp}")
                    async with aiofiles.open(fp, "r", encoding="utf-8") as f:
                        content = await f.read()
                        retrieved_files[fp] = content
                        logger.debug(f"Successfully read {len(content)} characters from {fp}")
                except Exception as e:
                    logger.error(f"Error reading {fp}: {e}")
                    retrieved_files[fp] = ""

            read_tasks.append(read_file(file_path))

        await asyncio.gather(*read_tasks)
        logger.info(
            f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
        )

        result = [
            {
                "name": file_path,
                "description": file_path,
                "content": retrieved_files[file_path],
            }
            for file_path in paths
        ]

        logger.info(f"Returning {len(result)} code file contexts")
        return result

    async def get_completion(
        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
    ) -> Any:
        """
        Returns the code files context.

        Parameters:
        -----------

        - query (str): The query string to retrieve code context for.
        - context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
          the context for the query. (default None)
        - session_id (Optional[str]): Optional session identifier for caching. If None,
          defaults to 'default_session'. (default None)

        Returns:
        --------

        - Any: The code files context, either provided or retrieved.
        """
        if context is None:
            context = await self.get_context(query)
        return context
10  cognee/modules/retrieval/register_retriever.py (new file)
@@ -0,0 +1,10 @@
from typing import Type

from .base_retriever import BaseRetriever
from .registered_community_retrievers import registered_community_retrievers
from ..search.types import SearchType


def use_retriever(search_type: SearchType, retriever: Type[BaseRetriever]):
    """Register a retriever class for a given search type."""
    registered_community_retrievers[search_type] = retriever
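With CodeRetriever gone from core, the `use_retriever` hook above is how an external package can plug a retriever back in. A minimal sketch of what that could look like from the community side; `MyCodeRetriever` and its behavior are hypothetical, its `top_k` constructor argument is assumed from how `get_search_type_tools` instantiates registered retrievers, and the choice of `SearchType.NATURAL_LANGUAGE` is only illustrative (registering a search type overrides its built-in retriever, per the dispatch change further down in this diff):

```python
from typing import Any, Optional

from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.register_retriever import use_retriever
from cognee.modules.search.types import SearchType


class MyCodeRetriever(BaseRetriever):
    """Hypothetical community retriever; top_k is assumed because
    get_search_type_tools calls retriever(top_k=top_k)."""

    def __init__(self, top_k: int = 3):
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        # A real implementation would query its own index here.
        return f"context for: {query}"

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        if context is None:
            context = await self.get_context(query)
        return context


# Searches for this SearchType are now served by MyCodeRetriever, because
# registered_community_retrievers is consulted before the built-in mapping.
use_retriever(SearchType.NATURAL_LANGUAGE, MyCodeRetriever)
```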
@@ -0,0 +1 @@
registered_community_retrievers = {}
@@ -23,7 +23,6 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
 from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
     GraphCompletionContextExtensionRetriever,
 )
-from cognee.modules.retrieval.code_retriever import CodeRetriever
 from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
 from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
@@ -162,10 +161,6 @@ async def get_search_type_tools(
             triplet_distance_penalty=triplet_distance_penalty,
         ).get_context,
     ],
-    SearchType.CODE: [
-        CodeRetriever(top_k=top_k).get_completion,
-        CodeRetriever(top_k=top_k).get_context,
-    ],
     SearchType.CYPHER: [
         CypherSearchRetriever().get_completion,
         CypherSearchRetriever().get_context,
@@ -208,7 +203,19 @@ async def get_search_type_tools(
     ):
         raise UnsupportedSearchTypeError("Cypher query search types are disabled.")

-    search_type_tools = search_tasks.get(query_type)
+    from cognee.modules.retrieval.registered_community_retrievers import (
+        registered_community_retrievers,
+    )
+
+    if query_type in registered_community_retrievers:
+        retriever = registered_community_retrievers[query_type]
+        retriever_instance = retriever(top_k=top_k)
+        search_type_tools = [
+            retriever_instance.get_completion,
+            retriever_instance.get_context,
+        ]
+    else:
+        search_type_tools = search_tasks.get(query_type)

     if not search_type_tools:
         raise UnsupportedSearchTypeError(str(query_type))
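Because the community registry is consulted before the built-in `search_tasks` mapping, a registration silently takes precedence over the built-in tool pair for that search type. A quick, hypothetical check of that behavior, reusing the `MyCodeRetriever` sketch above:

```python
import asyncio

import cognee
from cognee import SearchType


async def main():
    # With use_retriever(SearchType.NATURAL_LANGUAGE, MyCodeRetriever) in
    # effect, this search is dispatched to MyCodeRetriever rather than the
    # built-in NaturalLanguageRetriever.
    results = await cognee.search(
        query_type=SearchType.NATURAL_LANGUAGE,
        query_text="How does the custom pipeline interface work?",
    )
    print(results)


asyncio.run(main())
```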
@@ -8,7 +8,6 @@ class SearchType(Enum):
     TRIPLET_COMPLETION = "TRIPLET_COMPLETION"
     GRAPH_COMPLETION = "GRAPH_COMPLETION"
     GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
-    CODE = "CODE"
     CYPHER = "CYPHER"
     NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
     GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
@@ -1,35 +0,0 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph


def main():
    """
    Execute the main logic of the dependency graph processor.

    This function sets up argument parsing to retrieve the repository path, checks the
    existence of the specified path, and processes the repository to produce a dependency
    graph. If the repository path does not exist, it logs an error message and terminates
    without further execution.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("repo_path", help="Path to the repository")
    args = parser.parse_args()

    repo_path = args.repo_path
    if not os.path.exists(repo_path):
        print(f"Error: The provided repository path does not exist: {repo_path}")
        return

    graph = asyncio.run(get_repo_file_dependencies(repo_path))
    graph = asyncio.run(enrich_dependency_graph(graph))
    for node in graph.nodes:
        print(f"Node: {node}")
        for _, target, data in graph.out_edges(node, data=True):
            print(f"  Edge to {target}, data: {data}")


if __name__ == "__main__":
    main()
@@ -1,20 +0,0 @@
import argparse
import asyncio
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Get local script dependencies.")

    # Suggested path: .../cognee/examples/python/simple_example.py
    parser.add_argument("script_path", type=str, help="Absolute path to the Python script file")

    # Suggested path: .../cognee
    parser.add_argument("repo_path", type=str, help="Absolute path to the repository root")

    args = parser.parse_args()

    dependencies = asyncio.run(get_local_script_dependencies(args.script_path, args.repo_path))

    print("Dependencies:")
    for dependency in dependencies:
        print(dependency)
@@ -1,35 +0,0 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies


def main():
    """
    Parse the command line arguments and print the repository file dependencies.

    This function sets up an argument parser to retrieve the path of a repository. It checks
    if the provided path exists and, if it doesn't, prints an error message and exits. If
    the path is valid, it calls an asynchronous function to get the dependencies and prints
    the nodes and their relations in the dependency graph.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("repo_path", help="Path to the repository")
    args = parser.parse_args()

    repo_path = args.repo_path
    if not os.path.exists(repo_path):
        print(f"Error: The provided repository path does not exist: {repo_path}")
        return

    graph = asyncio.run(get_repo_file_dependencies(repo_path))

    for node in graph.nodes:
        print(f"Node: {node}")
        edges = graph.edges(node, data=True)
        for _, target, data in edges:
            print(f"  Edge to {target}, Relation: {data.get('relation')}")


if __name__ == "__main__":
    main()
@@ -1,2 +0,0 @@
from .get_non_code_files import get_non_py_files
from .get_repo_file_dependencies import get_repo_file_dependencies
@@ -1,335 +0,0 @@
import os
import aiofiles
import importlib
from typing import AsyncGenerator, Optional
from uuid import NAMESPACE_OID, uuid5
import tree_sitter_python as tspython
from tree_sitter import Language, Node, Parser, Tree
from cognee.shared.logging_utils import get_logger

from cognee.low_level import DataPoint
from cognee.shared.CodeGraphEntities import (
    CodeFile,
    ImportStatement,
    FunctionDefinition,
    ClassDefinition,
)

logger = get_logger()


class FileParser:
    """
    Handles the parsing of files into source code and an abstract syntax tree
    representation. Public methods include:

    - parse_file: Parses a file and returns its source code and syntax tree representation.
    """

    def __init__(self):
        self.parsed_files = {}

    async def parse_file(self, file_path: str) -> tuple[str, Tree]:
        """
        Parse a file and return its source code along with its syntax tree representation.

        If the file has already been parsed, retrieve the result from memory instead of reading
        the file again.

        Parameters:
        -----------

        - file_path (str): The path of the file to parse.

        Returns:
        --------

        - tuple[str, Tree]: A tuple containing the source code of the file and its
          corresponding syntax tree representation.
        """
        PY_LANGUAGE = Language(tspython.language())
        source_code_parser = Parser(PY_LANGUAGE)

        if file_path not in self.parsed_files:
            source_code = await get_source_code(file_path)
            source_code_tree = source_code_parser.parse(bytes(source_code, "utf-8"))
            self.parsed_files[file_path] = (source_code, source_code_tree)

        return self.parsed_files[file_path]


async def get_source_code(file_path: str):
    """
    Read source code from a file asynchronously.

    This function attempts to open a file specified by the given file path, read its
    contents, and return the source code. In case of any errors during the file reading
    process, it logs an error message and returns None.

    Parameters:
    -----------

    - file_path (str): The path to the file from which to read the source code.

    Returns:
    --------

    Returns the contents of the file as a string if successful, or None if an error
    occurs.
    """
    try:
        async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
            source_code = await f.read()
        return source_code
    except Exception as error:
        logger.error(f"Error reading file {file_path}: {str(error)}")
        return None


def resolve_module_path(module_name):
    """
    Find the file path of a module.

    Return the file path of the specified module if found, or return None if the module does
    not exist or cannot be located.

    Parameters:
    -----------

    - module_name: The name of the module whose file path is to be resolved.

    Returns:
    --------

    The file path of the module as a string or None if the module is not found.
    """
    try:
        spec = importlib.util.find_spec(module_name)
        if spec and spec.origin:
            return spec.origin
    except ModuleNotFoundError:
        return None
    return None


def find_function_location(
    module_path: str, function_name: str, parser: FileParser
) -> Optional[tuple[str, str]]:
    """
    Find the location of a function definition in a specified module.

    Parameters:
    -----------

    - module_path (str): The path to the module where the function is defined.
    - function_name (str): The name of the function whose location is to be found.
    - parser (FileParser): An instance of FileParser used to parse the module's source
      code.

    Returns:
    --------

    - Optional[tuple[str, str]]: Returns a tuple containing the module path and the
      start point of the function if found; otherwise, returns None.
    """
    if not module_path or not os.path.exists(module_path):
        return None

    source_code, tree = parser.parse_file(module_path)
    root_node: Node = tree.root_node

    for node in root_node.children:
        if node.type == "function_definition":
            func_name_node = node.child_by_field_name("name")

            if func_name_node and func_name_node.text.decode() == function_name:
                return (module_path, node.start_point)  # (line, column)

    return None


async def get_local_script_dependencies(
    repo_path: str, script_path: str, detailed_extraction: bool = False
) -> CodeFile:
    """
    Retrieve local script dependencies and create a CodeFile object.

    Parameters:
    -----------

    - repo_path (str): The path to the repository that contains the script.
    - script_path (str): The path of the script for which dependencies are being
      extracted.
    - detailed_extraction (bool): A flag indicating whether to perform a detailed
      extraction of code components.

    Returns:
    --------

    - CodeFile: Returns a CodeFile object containing information about the script,
      including its dependencies and definitions.
    """
    code_file_parser = FileParser()
    source_code, source_code_tree = await code_file_parser.parse_file(script_path)

    file_path_relative_to_repo = script_path[len(repo_path) + 1 :]

    if not detailed_extraction:
        code_file_node = CodeFile(
            id=uuid5(NAMESPACE_OID, script_path),
            name=file_path_relative_to_repo,
            source_code=source_code,
            file_path=script_path,
            language="python",
        )
        return code_file_node

    code_file_node = CodeFile(
        id=uuid5(NAMESPACE_OID, script_path),
        name=file_path_relative_to_repo,
        source_code=None,
        file_path=script_path,
        language="python",
    )

    async for part in extract_code_parts(source_code_tree.root_node, script_path=script_path):
        part.file_path = script_path

        if isinstance(part, FunctionDefinition):
            code_file_node.provides_function_definition.append(part)
        if isinstance(part, ClassDefinition):
            code_file_node.provides_class_definition.append(part)
        if isinstance(part, ImportStatement):
            code_file_node.depends_on.append(part)

    return code_file_node


def find_node(nodes: list[Node], condition: callable) -> Node:
    """
    Find and return the first node that satisfies the given condition.

    Iterate through the provided list of nodes and return the first node for which the
    condition callable returns True. If no such node is found, return None.

    Parameters:
    -----------

    - nodes (list[Node]): A list of Node objects to search through.
    - condition (callable): A callable that takes a Node and returns a boolean
      indicating if the node meets specified criteria.

    Returns:
    --------

    - Node: The first Node that matches the condition, or None if no such node exists.
    """
    for node in nodes:
        if condition(node):
            return node

    return None


async def extract_code_parts(
    tree_root: Node, script_path: str, existing_nodes: list[DataPoint] = {}
) -> AsyncGenerator[DataPoint, None]:
    """
    Extract code parts from a given AST node tree asynchronously.

    Iteratively yields DataPoint nodes representing import statements, function definitions,
    and class definitions found in the children of the specified tree root. The function
    checks if nodes are already present in the existing_nodes dictionary to prevent duplicates.
    This function has to be used in an asynchronous context, and it requires a valid
    tree_root and proper initialization of existing_nodes.

    Parameters:
    -----------

    - tree_root (Node): The root node of the AST tree containing code parts to extract.
    - script_path (str): The file path of the script from which the AST was generated.
    - existing_nodes (list[DataPoint]): A dictionary that holds already extracted
      DataPoint nodes to avoid duplicates. (default {})

    Returns:
    --------

    Yields DataPoint nodes representing imported modules, functions, and classes.
    """
    for child_node in tree_root.children:
        if child_node.type == "import_statement" or child_node.type == "import_from_statement":
            parts = child_node.text.decode("utf-8").split()

            if parts[0] == "import":
                module_name = parts[1]
                function_name = None
            elif parts[0] == "from":
                module_name = parts[1]
                function_name = parts[3]

                if " as " in function_name:
                    function_name = function_name.split(" as ")[0]

            if " as " in module_name:
                module_name = module_name.split(" as ")[0]

            if function_name and "import " + function_name not in existing_nodes:
                import_statement_node = ImportStatement(
                    name=function_name,
                    module=module_name,
                    start_point=child_node.start_point,
                    end_point=child_node.end_point,
                    file_path=script_path,
                    source_code=child_node.text,
                )
                existing_nodes["import " + function_name] = import_statement_node

            if function_name:
                yield existing_nodes["import " + function_name]

            if module_name not in existing_nodes:
                import_statement_node = ImportStatement(
                    name=module_name,
                    module=module_name,
                    start_point=child_node.start_point,
                    end_point=child_node.end_point,
                    file_path=script_path,
                    source_code=child_node.text,
                )
                existing_nodes[module_name] = import_statement_node

            yield existing_nodes[module_name]

        if child_node.type == "function_definition":
            function_node = find_node(child_node.children, lambda node: node.type == "identifier")
            function_node_name = function_node.text

            if function_node_name not in existing_nodes:
                function_definition_node = FunctionDefinition(
                    name=function_node_name,
                    start_point=child_node.start_point,
                    end_point=child_node.end_point,
                    file_path=script_path,
                    source_code=child_node.text,
                )
                existing_nodes[function_node_name] = function_definition_node

            yield existing_nodes[function_node_name]

        if child_node.type == "class_definition":
            class_name_node = find_node(child_node.children, lambda node: node.type == "identifier")
            class_name_node_name = class_name_node.text

            if class_name_node_name not in existing_nodes:
                class_definition_node = ClassDefinition(
                    name=class_name_node_name,
                    start_point=child_node.start_point,
                    end_point=child_node.end_point,
                    file_path=script_path,
                    source_code=child_node.text,
                )
                existing_nodes[class_name_node_name] = class_definition_node

            yield existing_nodes[class_name_node_name]
@@ -1,158 +0,0 @@
import os


async def get_non_py_files(repo_path):
    """
    Get files that are not .py files and their contents.

    Check if the specified repository path exists and if so, traverse the directory,
    collecting the paths of files that do not have a .py extension and meet the
    criteria set in the allowed and ignored patterns. Return a list of paths to
    those files.

    Parameters:
    -----------

    - repo_path: The file system path to the repository to scan for non-Python files.

    Returns:
    --------

    A list of file paths that are not Python files and meet the specified criteria.
    """
    if not os.path.exists(repo_path):
        return {}

    IGNORED_PATTERNS = {
        ".git",
        "__pycache__",
        "*.pyc",
        "*.pyo",
        "*.pyd",
        "node_modules",
        "*.egg-info",
    }

    ALLOWED_EXTENSIONS = {
        ".txt",
        ".md",
        ".csv",
        ".json",
        ".xml",
        ".yaml",
        ".yml",
        ".html",
        ".css",
        ".js",
        ".ts",
        ".jsx",
        ".tsx",
        ".sql",
        ".log",
        ".ini",
        ".toml",
        ".properties",
        ".sh",
        ".bash",
        ".dockerfile",
        ".gitignore",
        ".gitattributes",
        ".makefile",
        ".pyproject",
        ".requirements",
        ".env",
        ".pdf",
        ".doc",
        ".docx",
        ".dot",
        ".dotx",
        ".rtf",
        ".wps",
        ".wpd",
        ".odt",
        ".ott",
        ".ottx",
        ".txt",
        ".wp",
        ".sdw",
        ".sdx",
        ".docm",
        ".dotm",
        # Additional extensions for other programming languages
        ".java",
        ".c",
        ".cpp",
        ".h",
        ".cs",
        ".go",
        ".php",
        ".rb",
        ".swift",
        ".pl",
        ".lua",
        ".rs",
        ".scala",
        ".kt",
        ".sh",
        ".sql",
        ".v",
        ".asm",
        ".pas",
        ".d",
        ".ml",
        ".clj",
        ".cljs",
        ".erl",
        ".ex",
        ".exs",
        ".f",
        ".fs",
        ".r",
        ".pyi",
        ".pdb",
        ".ipynb",
        ".rmd",
        ".cabal",
        ".hs",
        ".nim",
        ".vhdl",
        ".verilog",
        ".svelte",
        ".html",
        ".css",
        ".scss",
        ".less",
        ".json5",
        ".yaml",
        ".yml",
    }

    def should_process(path):
        """
        Determine if a file should be processed based on its extension and path patterns.

        This function checks if the file extension is in the allowed list and ensures that none
        of the ignored patterns are present in the provided file path.

        Parameters:
        -----------

        - path: The file path to check for processing eligibility.

        Returns:
        --------

        Returns True if the file should be processed; otherwise, False.
        """
        _, ext = os.path.splitext(path)
        return ext in ALLOWED_EXTENSIONS and not any(
            pattern in path for pattern in IGNORED_PATTERNS
        )

    non_py_files_paths = [
        os.path.join(root, file)
        for root, _, files in os.walk(repo_path)
        for file in files
        if not file.endswith(".py") and should_process(os.path.join(root, file))
    ]
    return non_py_files_paths
@@ -1,243 +0,0 @@
import asyncio
import math
import os
from pathlib import Path
from typing import Set
from typing import AsyncGenerator, Optional, List
from uuid import NAMESPACE_OID, uuid5

from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, Repository

# constant, declared only once
EXCLUDED_DIRS: Set[str] = {
    ".venv",
    "venv",
    "env",
    ".env",
    "site-packages",
    "node_modules",
    "dist",
    "build",
    ".git",
    "tests",
    "test",
}


async def get_source_code_files(
    repo_path,
    language_config: dict[str, list[str]] | None = None,
    excluded_paths: Optional[List[str]] = None,
):
    """
    Retrieve source code files from the specified repository path.

    This function scans the given repository path for files whose extensions match the
    configured languages, while excluding test files and files within virtual environment
    or build directories. It returns a list of tuples for source code files that are not
    empty.

    Parameters:
    -----------
    - repo_path: Root path of the repository to search
    - language_config: dict mapping language names to file extensions, e.g.,
      {'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
    - excluded_paths: Optional list of path fragments or glob patterns to exclude

    Returns:
    --------
    A list of (absolute_path, language) tuples for source code files.
    """

    def _get_language_from_extension(file, language_config):
        for lang, exts in language_config.items():
            for ext in exts:
                if file.endswith(ext):
                    return lang
        return None

    # Default config if not provided
    if language_config is None:
        language_config = {
            "python": [".py"],
            "javascript": [".js", ".jsx"],
            "typescript": [".ts", ".tsx"],
            "java": [".java"],
            "csharp": [".cs"],
            "go": [".go"],
            "rust": [".rs"],
            "cpp": [".cpp", ".c", ".h", ".hpp"],
        }

    if not os.path.exists(repo_path):
        return []

    source_code_files = set()
    for root, _, files in os.walk(repo_path):
        for file in files:
            lang = _get_language_from_extension(file, language_config)
            if lang is None:
                continue
            # Exclude tests, common build/venv directories and files provided in excluded_paths
            excluded_dirs = EXCLUDED_DIRS
            excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])}  # full paths

            root_path = Path(root).resolve()
            root_parts = set(root_path.parts)  # same as before
            base_name, _ext = os.path.splitext(file)
            if (
                base_name.startswith("test_")
                or base_name.endswith("_test")
                or ".test." in file
                or ".spec." in file
                or (excluded_dirs & root_parts)  # name match
                or any(
                    root_path.is_relative_to(p)  # full-path match
                    for p in excluded_paths
                )
            ):
                continue
            file_path = os.path.abspath(os.path.join(root, file))
            if os.path.getsize(file_path) == 0:
                continue
            source_code_files.add((file_path, lang))

    return sorted(list(source_code_files))


def run_coroutine(coroutine_func, *args, **kwargs):
    """
    Run a coroutine function until it completes.

    This function creates a new asyncio event loop, sets it as the current loop, and
    executes the given coroutine function with the provided arguments. Once the coroutine
    completes, the loop is closed. Intended for use in environments where an existing event
    loop is not available or desirable.

    Parameters:
    -----------

    - coroutine_func: The coroutine function to be run.
    - *args: Positional arguments to pass to the coroutine function.
    - **kwargs: Keyword arguments to pass to the coroutine function.

    Returns:
    --------

    The result returned by the coroutine after completion.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    result = loop.run_until_complete(coroutine_func(*args, **kwargs))
    loop.close()
    return result


async def get_repo_file_dependencies(
    repo_path: str,
    detailed_extraction: bool = False,
    supported_languages: list = None,
    excluded_paths: Optional[List[str]] = None,
) -> AsyncGenerator[DataPoint, None]:
    """
    Generate a dependency graph for source files (multi-language) in the given repository path.

    Check the validity of the repository path and yield a repository object followed by the
    dependencies of source files within that repository. Raise a FileNotFoundError if the
    provided path does not exist. The extraction of detailed dependencies can be controlled
    via the `detailed_extraction` argument. Languages considered can be restricted via
    the `supported_languages` argument.

    Parameters:
    -----------

    - repo_path (str): The file path to the repository to process.
    - detailed_extraction (bool): Whether to perform a detailed extraction of code parts.
    - supported_languages (list | None): Subset of languages to include; if None, use defaults.
    """

    if isinstance(repo_path, list) and len(repo_path) == 1:
        repo_path = repo_path[0]

    if not os.path.exists(repo_path):
        raise FileNotFoundError(f"Repository path {repo_path} does not exist.")

    # Build language config from supported_languages
    default_language_config = {
        "python": [".py"],
        "javascript": [".js", ".jsx"],
        "typescript": [".ts", ".tsx"],
        "java": [".java"],
        "csharp": [".cs"],
        "go": [".go"],
        "rust": [".rs"],
        "cpp": [".cpp", ".c", ".h", ".hpp"],
        "c": [".c", ".h"],
    }
    if supported_languages is not None:
        language_config = {
            k: v for k, v in default_language_config.items() if k in supported_languages
        }
    else:
        language_config = default_language_config

    source_code_files = await get_source_code_files(
        repo_path, language_config=language_config, excluded_paths=excluded_paths
    )

    repo = Repository(
        id=uuid5(NAMESPACE_OID, repo_path),
        path=repo_path,
    )

    yield repo

    chunk_size = 100
    number_of_chunks = math.ceil(len(source_code_files) / chunk_size)
    chunk_ranges = [
        (
            chunk_number * chunk_size,
            min((chunk_number + 1) * chunk_size, len(source_code_files)) - 1,
        )
        for chunk_number in range(number_of_chunks)
    ]

    # Import dependency extractors for each language (Python for now, extend later)
    from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
    import aiofiles
    # TODO: Add other language extractors here

    for start_range, end_range in chunk_ranges:
        tasks = []
        for file_path, lang in source_code_files[start_range : end_range + 1]:
            # For now, only Python is supported; extend with other languages
            if lang == "python":
                tasks.append(
                    get_local_script_dependencies(repo_path, file_path, detailed_extraction)
                )
            else:
                # Placeholder: create a minimal CodeFile for other languages
                async def make_codefile_stub(file_path=file_path, lang=lang):
                    async with aiofiles.open(
                        file_path, "r", encoding="utf-8", errors="replace"
                    ) as f:
                        source = await f.read()
                    return CodeFile(
                        id=uuid5(NAMESPACE_OID, file_path),
                        name=os.path.relpath(file_path, repo_path),
                        file_path=file_path,
                        language=lang,
                        source_code=source,
                    )

                tasks.append(make_codefile_stub())

        results: list[CodeFile] = await asyncio.gather(*tasks)

        for source_code_file in results:
            source_code_file.part_of = repo
            if getattr(
                source_code_file, "language", None
            ) is None and source_code_file.file_path.endswith(".py"):
                source_code_file.language = "python"
            yield source_code_file
@@ -1,63 +0,0 @@
import argparse
import asyncio
import os

import cognee
from cognee import SearchType
from cognee.shared.logging_utils import setup_logging, ERROR

from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline


async def main(repo_path, include_docs):
    # Disable permissions feature for this example
    os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"

    run_status = False
    async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs):
        run_status = run_status

    # Test CODE search
    search_results = await cognee.search(query_type=SearchType.CODE, query_text="test")
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nSearch results are:\n")
    for result in search_results:
        print(f"{result}\n")

    return run_status


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
    parser.add_argument(
        "--include_docs",
        type=lambda x: x.lower() in ("true", "1"),
        default=False,
        help="Whether or not to process non-code files",
    )
    parser.add_argument(
        "--time",
        type=lambda x: x.lower() in ("true", "1"),
        default=True,
        help="Whether or not to time the pipeline run",
    )
    return parser.parse_args()


if __name__ == "__main__":
    logger = setup_logging(log_level=ERROR)

    args = parse_args()

    if args.time:
        import time

        start_time = time.time()
        asyncio.run(main(args.repo_path, args.include_docs))
        end_time = time.time()
        print("\n" + "=" * 50)
        print(f"Pipeline Execution Time: {end_time - start_time:.2f} seconds")
        print("=" * 50 + "\n")
    else:
        asyncio.run(main(args.repo_path, args.include_docs))