feat: add codegraph related API endpoints
This commit is contained in:
parent
0c2c5870df
commit
3320bc8f2c
6 changed files with 84 additions and 25 deletions
|
|
@ -10,7 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
from cognee.api.v1.permissions.routers import get_permissions_router
|
from cognee.api.v1.permissions.routers import get_permissions_router
|
||||||
from cognee.api.v1.settings.routers import get_settings_router
|
from cognee.api.v1.settings.routers import get_settings_router
|
||||||
from cognee.api.v1.datasets.routers import get_datasets_router
|
from cognee.api.v1.datasets.routers import get_datasets_router
|
||||||
from cognee.api.v1.cognify.routers import get_cognify_router
|
from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
|
||||||
from cognee.api.v1.search.routers import get_search_router
|
from cognee.api.v1.search.routers import get_search_router
|
||||||
from cognee.api.v1.add.routers import get_add_router
|
from cognee.api.v1.add.routers import get_add_router
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
@ -169,6 +169,10 @@ app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["sett
|
||||||
|
|
||||||
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
|
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
|
||||||
|
|
||||||
|
app.include_router(
|
||||||
|
get_code_pipeline_router(), prefix="/api/v1/code-pipeline", tags=["code-pipeline"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
from cognee.modules.cognify.config import get_cognify_config
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
|
|
@ -33,22 +32,9 @@ update_status_lock = asyncio.Lock()
|
||||||
|
|
||||||
@observe
|
@observe
|
||||||
async def run_code_graph_pipeline(repo_path, include_docs=True):
|
async def run_code_graph_pipeline(repo_path, include_docs=True):
|
||||||
import os
|
|
||||||
import pathlib
|
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||||
|
|
||||||
file_path = Path(__file__).parent
|
|
||||||
data_directory_path = str(
|
|
||||||
pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve()
|
|
||||||
)
|
|
||||||
cognee.config.data_root_directory(data_directory_path)
|
|
||||||
cognee_directory_path = str(
|
|
||||||
pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve()
|
|
||||||
)
|
|
||||||
cognee.config.system_root_directory(cognee_directory_path)
|
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
await create_db_and_tables()
|
await create_db_and_tables()
|
||||||
|
|
|
||||||
|
|
@ -1 +1,2 @@
|
||||||
from .get_cognify_router import get_cognify_router
|
from .get_cognify_router import get_cognify_router
|
||||||
|
from .get_code_pipeline_router import get_code_pipeline_router
|
||||||
|
|
|
||||||
57
cognee/api/v1/cognify/routers/get_code_pipeline_router.py
Normal file
57
cognee/api/v1/cognify/routers/get_code_pipeline_router.py
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||||
|
from cognee.modules.retrieval.description_to_codepart_search import (
|
||||||
|
code_description_to_code_part_search,
|
||||||
|
)
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
class CodePipelineIndexPayloadDTO(BaseModel):
|
||||||
|
repo_path: str
|
||||||
|
include_docs: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class CodePipelineRetrievePayloadDTO(BaseModel):
|
||||||
|
query: str
|
||||||
|
fullInput: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_code_pipeline_router() -> APIRouter:
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.post("/index", response_model=None)
|
||||||
|
async def code_pipeline_index(payload: CodePipelineIndexPayloadDTO):
|
||||||
|
"""This endpoint is responsible for running the indexation on code repo."""
|
||||||
|
try:
|
||||||
|
async for result in run_code_graph_pipeline(payload.repo_path, payload.include_docs):
|
||||||
|
print(result)
|
||||||
|
except Exception as error:
|
||||||
|
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||||
|
|
||||||
|
@router.post("/retrieve", response_model=list[dict])
|
||||||
|
async def code_pipeline_retrieve(payload: CodePipelineRetrievePayloadDTO):
|
||||||
|
"""This endpoint is responsible for retrieving the context."""
|
||||||
|
try:
|
||||||
|
query = (
|
||||||
|
payload.fullInput.replace("cognee ", "")
|
||||||
|
if payload.fullInput.startswith("cognee ")
|
||||||
|
else payload.fullInput
|
||||||
|
)
|
||||||
|
|
||||||
|
retrieved_codeparts, __ = await code_description_to_code_part_search(
|
||||||
|
query, include_docs=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": codepart.attributes["id"],
|
||||||
|
"description": codepart.attributes["id"],
|
||||||
|
"content": codepart.attributes["source_code"],
|
||||||
|
}
|
||||||
|
for codepart in retrieved_codeparts
|
||||||
|
]
|
||||||
|
except Exception as error:
|
||||||
|
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||||
|
|
||||||
|
return router
|
||||||
|
|
@ -62,6 +62,8 @@ async def code_description_to_code_part(
|
||||||
"Search initiated by user %s with query: '%s' and top_k: %d", user.id, query, top_k
|
"Search initiated by user %s with query: '%s' and top_k: %d", user.id, query, top_k
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context_from_documents = ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if include_docs:
|
if include_docs:
|
||||||
search_results = await search(SearchType.INSIGHTS, query_text=query)
|
search_results = await search(SearchType.INSIGHTS, query_text=query)
|
||||||
|
|
@ -131,14 +133,7 @@ async def code_description_to_code_part(
|
||||||
len(code_pieces_to_return),
|
len(code_pieces_to_return),
|
||||||
)
|
)
|
||||||
|
|
||||||
context = ""
|
return code_pieces_to_return, context_from_documents
|
||||||
for code_piece in code_pieces_to_return:
|
|
||||||
context = context + code_piece.get_attribute("source_code")
|
|
||||||
|
|
||||||
if include_docs:
|
|
||||||
context = context_from_documents + context
|
|
||||||
|
|
||||||
return context
|
|
||||||
|
|
||||||
except Exception as exec_error:
|
except Exception as exec_error:
|
||||||
logging.error(
|
logging.error(
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,15 @@ def check_install_package(package_name):
|
||||||
|
|
||||||
|
|
||||||
async def generate_patch_with_cognee(instance):
|
async def generate_patch_with_cognee(instance):
|
||||||
|
import os
|
||||||
|
from cognee import config
|
||||||
|
|
||||||
|
file_path = Path(__file__).parent
|
||||||
|
data_directory_path = str(Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
|
||||||
|
|
||||||
|
config.data_root_directory(data_directory_path)
|
||||||
|
config.system_root_directory(data_directory_path)
|
||||||
|
|
||||||
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")
|
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")
|
||||||
include_docs = True
|
include_docs = True
|
||||||
problem_statement = instance["problem_statement"]
|
problem_statement = instance["problem_statement"]
|
||||||
|
|
@ -42,10 +51,17 @@ async def generate_patch_with_cognee(instance):
|
||||||
async for result in run_code_graph_pipeline(repo_path, include_docs=include_docs):
|
async for result in run_code_graph_pipeline(repo_path, include_docs=include_docs):
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
retrieved_codeparts = await code_description_to_code_part_search(
|
retrieved_codeparts, context_from_documents = await code_description_to_code_part_search(
|
||||||
problem_statement, include_docs=include_docs
|
problem_statement, include_docs=include_docs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context = ""
|
||||||
|
for code_piece in retrieved_codeparts:
|
||||||
|
context = context + code_piece.get_attribute("source_code")
|
||||||
|
|
||||||
|
if include_docs:
|
||||||
|
context = context_from_documents + context
|
||||||
|
|
||||||
prompt = "\n".join(
|
prompt = "\n".join(
|
||||||
[
|
[
|
||||||
problem_statement,
|
problem_statement,
|
||||||
|
|
@ -53,7 +69,7 @@ async def generate_patch_with_cognee(instance):
|
||||||
PATCH_EXAMPLE,
|
PATCH_EXAMPLE,
|
||||||
"</patch>",
|
"</patch>",
|
||||||
"This is the additional context to solve the problem (description from documentation together with codeparts):",
|
"This is the additional context to solve the problem (description from documentation together with codeparts):",
|
||||||
retrieved_codeparts,
|
context,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue