Merge branch 'dev' into COG-970-refactor-tokenizing
This commit is contained in:
commit
6f8cbdbf1c
6 changed files with 84 additions and 25 deletions
|
|
@ -10,7 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from cognee.api.v1.permissions.routers import get_permissions_router
|
||||
from cognee.api.v1.settings.routers import get_settings_router
|
||||
from cognee.api.v1.datasets.routers import get_datasets_router
|
||||
from cognee.api.v1.cognify.routers import get_cognify_router
|
||||
from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
|
||||
from cognee.api.v1.search.routers import get_search_router
|
||||
from cognee.api.v1.add.routers import get_add_router
|
||||
from fastapi import Request
|
||||
|
|
@ -169,6 +169,10 @@ app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["sett
|
|||
|
||||
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
|
||||
|
||||
app.include_router(
|
||||
get_code_pipeline_router(), prefix="/api/v1/code-pipeline", tags=["code-pipeline"]
|
||||
)
|
||||
|
||||
|
||||
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
|
|
@ -34,22 +33,9 @@ update_status_lock = asyncio.Lock()
|
|||
|
||||
@observe
|
||||
async def run_code_graph_pipeline(repo_path, include_docs=True):
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import cognee
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
file_path = Path(__file__).parent
|
||||
data_directory_path = str(
|
||||
pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve()
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve()
|
||||
)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await create_db_and_tables()
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
from .get_cognify_router import get_cognify_router
|
||||
from .get_code_pipeline_router import get_code_pipeline_router
|
||||
|
|
|
|||
57
cognee/api/v1/cognify/routers/get_code_pipeline_router.py
Normal file
57
cognee/api/v1/cognify/routers/get_code_pipeline_router.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
from cognee.modules.retrieval.description_to_codepart_search import (
|
||||
code_description_to_code_part_search,
|
||||
)
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
class CodePipelineIndexPayloadDTO(BaseModel):
    """Request body accepted by the code-pipeline ``/index`` endpoint."""

    # Passed as the first argument to run_code_graph_pipeline.
    repo_path: str
    # Forwarded to run_code_graph_pipeline; omitted in the request means False.
    include_docs: bool = False
|
||||
|
||||
|
||||
class CodePipelineRetrievePayloadDTO(BaseModel):
    """Request body accepted by the code-pipeline ``/retrieve`` endpoint."""

    # Required field of the request body.
    query: str
    # Required field; the camelCase name is part of the wire format.
    fullInput: str
|
||||
|
||||
|
||||
def get_code_pipeline_router() -> APIRouter:
    """Build the router that exposes the code-pipeline endpoints.

    Two POST routes are registered on a fresh ``APIRouter``:

    * ``/index`` runs ``run_code_graph_pipeline`` for the given repository.
    * ``/retrieve`` looks up code parts matching the submitted text.

    Both handlers convert any raised exception into a 409 JSON response of
    the form ``{"error": "<message>"}``.

    Returns:
        The router with both routes attached.
    """
    router = APIRouter()

    @router.post("/index", response_model=None)
    async def code_pipeline_index(payload: CodePipelineIndexPayloadDTO):
        """This endpoint is responsible for running the indexation on code repo."""
        try:
            # Drain the pipeline's async iterator; each yielded result is only
            # echoed to stdout, nothing is returned to the client on success.
            async for result in run_code_graph_pipeline(payload.repo_path, payload.include_docs):
                print(result)
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    @router.post("/retrieve", response_model=list[dict])
    async def code_pipeline_retrieve(payload: CodePipelineRetrievePayloadDTO):
        """This endpoint is responsible for retrieving the context."""
        try:
            # Strip only a leading "cognee " trigger word. The previous
            # str.replace() call removed every occurrence, which mangled
            # queries that mention "cognee " again later in the text.
            # NOTE(review): payload.query is never read here; only fullInput
            # is used -- confirm that this is intentional.
            query = payload.fullInput.removeprefix("cognee ")

            # The second element of the returned pair is not needed here.
            retrieved_codeparts, _unused = await code_description_to_code_part_search(
                query, include_docs=False
            )

            return [
                {
                    "name": codepart.attributes["id"],
                    # NOTE(review): "description" reuses the "id" attribute,
                    # same as "name" -- confirm whether another attribute
                    # was intended.
                    "description": codepart.attributes["id"],
                    "content": codepart.attributes["source_code"],
                }
                for codepart in retrieved_codeparts
            ]
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    return router
|
||||
|
|
@ -62,6 +62,8 @@ async def code_description_to_code_part(
|
|||
"Search initiated by user %s with query: '%s' and top_k: %d", user.id, query, top_k
|
||||
)
|
||||
|
||||
context_from_documents = ""
|
||||
|
||||
try:
|
||||
if include_docs:
|
||||
search_results = await search(SearchType.INSIGHTS, query_text=query)
|
||||
|
|
@ -131,14 +133,7 @@ async def code_description_to_code_part(
|
|||
len(code_pieces_to_return),
|
||||
)
|
||||
|
||||
context = ""
|
||||
for code_piece in code_pieces_to_return:
|
||||
context = context + code_piece.get_attribute("source_code")
|
||||
|
||||
if include_docs:
|
||||
context = context_from_documents + context
|
||||
|
||||
return context
|
||||
return code_pieces_to_return, context_from_documents
|
||||
|
||||
except Exception as exec_error:
|
||||
logging.error(
|
||||
|
|
|
|||
|
|
@ -34,6 +34,15 @@ def check_install_package(package_name):
|
|||
|
||||
|
||||
async def generate_patch_with_cognee(instance):
|
||||
import os
|
||||
from cognee import config
|
||||
|
||||
file_path = Path(__file__).parent
|
||||
data_directory_path = str(Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
|
||||
|
||||
config.data_root_directory(data_directory_path)
|
||||
config.system_root_directory(data_directory_path)
|
||||
|
||||
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")
|
||||
include_docs = True
|
||||
problem_statement = instance["problem_statement"]
|
||||
|
|
@ -42,10 +51,17 @@ async def generate_patch_with_cognee(instance):
|
|||
async for result in run_code_graph_pipeline(repo_path, include_docs=include_docs):
|
||||
print(result)
|
||||
|
||||
retrieved_codeparts = await code_description_to_code_part_search(
|
||||
retrieved_codeparts, context_from_documents = await code_description_to_code_part_search(
|
||||
problem_statement, include_docs=include_docs
|
||||
)
|
||||
|
||||
context = ""
|
||||
for code_piece in retrieved_codeparts:
|
||||
context = context + code_piece.get_attribute("source_code")
|
||||
|
||||
if include_docs:
|
||||
context = context_from_documents + context
|
||||
|
||||
prompt = "\n".join(
|
||||
[
|
||||
problem_statement,
|
||||
|
|
@ -53,7 +69,7 @@ async def generate_patch_with_cognee(instance):
|
|||
PATCH_EXAMPLE,
|
||||
"</patch>",
|
||||
"This is the additional context to solve the problem (description from documentation together with codeparts):",
|
||||
retrieved_codeparts,
|
||||
context,
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue