add test for linter

This commit is contained in:
Vasilije 2024-05-25 22:33:12 +02:00
parent 9569441c5e
commit a3e218e5a4
3 changed files with 76 additions and 128 deletions

View file

@ -1,11 +1,16 @@
""" FastAPI server for the Cognee API. """
import os
import aiohttp
import uvicorn
import json
import logging
from typing import Dict, Any, List, Union, Optional
from typing_extensions import Annotated
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Set up logging
logging.basicConfig(
level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)
@ -14,15 +19,10 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
from cognee.config import Config
config = Config()
config.load()
from typing import Dict, Any, List, Union, Annotated, Literal, Optional
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
app = FastAPI(debug=True)
origins = [
@ -33,19 +33,12 @@ origins = [
# Allow the browser front-end to call this API from other origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    # Only the verbs the API actually serves; PUT/PATCH are not exposed.
    allow_methods=["OPTIONS", "GET", "POST", "DELETE"],
    allow_headers=["*"],
)

#
# from auth.cognito.JWTBearer import JWTBearer
# from auth.auth import jwks
#
# auth = JWTBearer(jwks)
#
# from auth.cognito.JWTBearer import JWTBearer
# from auth.auth import jwks
#
# auth = JWTBearer(jwks)
@app.get("/")
async def root():
"""
@ -53,7 +46,6 @@ async def root():
"""
return {"message": "Hello, World, I am alive!"}
@app.get("/health")
def health_check():
"""
@ -61,11 +53,9 @@ def health_check():
"""
return {"status": "OK"}
class Payload(BaseModel):
    """Generic request body: an arbitrary JSON object."""

    # Free-form key/value data forwarded to the handler.
    payload: Dict[str, Any]
@app.get("/datasets", response_model=list)
async def get_datasets():
from cognee import datasets
@ -74,77 +64,67 @@ async def get_datasets():
@app.delete("/datasets/{dataset_id}", response_model=dict)
async def delete_dataset(dataset_id: str):
    """Delete the dataset identified by ``dataset_id``."""
    from cognee import datasets

    datasets.delete_dataset(dataset_id)
    return JSONResponse(status_code=200, content="OK")
@app.get("/datasets/{dataset_id}/graph", response_model=list)
async def get_dataset_graph(dataset_id: str):
    """Render the dataset's graph and return the rendered graph URL."""
    from cognee import utils
    from cognee.infrastructure import infrastructure_config
    from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client

    # Pick the configured graph backend, then render its current graph.
    engine_kind = infrastructure_config.get_config()["graph_engine"]
    client = await get_graph_client(engine_kind)
    url = await utils.render_graph(client.graph)
    return JSONResponse(status_code=200, content=str(url))
@app.get("/datasets/{dataset_id}/data", response_model=list)
async def get_dataset_data(dataset_id: str):
    """List metadata (name, keywords, paths) for each item in the dataset."""
    from cognee import datasets

    dataset_data = datasets.list_data(dataset_id)
    if dataset_data is None:
        raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.")
    # Keywords are stored pipe-delimited; split them out for the client.
    return [
        {
            "id": item["id"],
            "name": f"{item['name']}.{item['extension']}",
            "keywords": item["keywords"].split("|"),
            "filePath": item["file_path"],
            "mimeType": item["mime_type"],
        }
        for item in dataset_data
    ]
@app.get("/datasets/status", response_model=dict)
async def get_dataset_status(
    # Fix: the parameter defaults to None, so the annotation must be
    # Optional[List[str]], not bare List[str].
    datasets: Annotated[Optional[List[str]], Query(alias="dataset")] = None,
):
    """Report the processing status of the requested datasets.

    Datasets are selected with repeated ``?dataset=`` query parameters;
    when none are given, ``None`` is passed through to the backend.
    """
    from cognee import datasets as cognee_datasets

    datasets_statuses = cognee_datasets.get_status(datasets)
    return JSONResponse(status_code=200, content=datasets_statuses)
@app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
async def get_raw_data(dataset_id: str, data_id: str):
    """Return the raw file backing ``data_id`` within ``dataset_id``.

    Raises:
        HTTPException(404): when the dataset or the data item is unknown.
    """
    from cognee import datasets

    dataset_data = datasets.list_data(dataset_id)
    if dataset_data is None:
        raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.")
    # Fix: the original indexed [0] unconditionally, so an unknown data_id
    # raised IndexError (HTTP 500). Answer with a proper 404 instead.
    data = next((item for item in dataset_data if item["id"] == data_id), None)
    if data is None:
        raise HTTPException(status_code=404, detail=f"Data ({data_id}) not found.")
    return data["file_path"]
class AddPayload(BaseModel):
    """Request body for ``/add``: one or many strings/uploads plus a target dataset."""

    data: Union[str, UploadFile, List[Union[str, UploadFile]]]
    dataset_id: str

    class Config:
        # Required so pydantic accepts the non-pydantic UploadFile type
        # appearing inside the Union above.
        arbitrary_types_allowed = True
@app.post("/add", response_model=dict)
async def add(
@ -153,7 +133,6 @@ async def add(
):
""" This endpoint is responsible for adding data to the graph."""
from cognee import add as cognee_add
try:
if isinstance(data, str) and data.startswith("http"):
if "github" in data:
@ -182,69 +161,62 @@ async def add(
datasetId,
)
return JSONResponse(
status_code = 200,
content = "OK"
status_code=200,
content="OK"
)
except Exception as error:
return JSONResponse(
status_code = 409,
content = { "error": str(error) }
status_code=409,
content={"error": str(error)}
)
class CognifyPayload(BaseModel):
    """Request body for ``/cognify``: the dataset names to process."""

    datasets: List[str]
@app.post("/cognify", response_model=dict)
async def cognify(payload: CognifyPayload):
""" This endpoint is responsible for the cognitive processing of the content."""
from cognee import cognify as cognee_cognify
try:
await cognee_cognify(payload.datasets)
return JSONResponse(
status_code = 200,
content = "OK"
status_code=200,
content="OK"
)
except Exception as error:
return JSONResponse(
status_code = 409,
content = { "error": error }
status_code=409,
content={"error": str(error)}
)
class SearchPayload(BaseModel):
    """Request body for ``/search``: raw query parameters from the client."""

    query_params: Dict[str, Any]
@app.post("/search", response_model=dict)
async def search(payload: SearchPayload):
""" This endpoint is responsible for searching for nodes in the graph."""
from cognee import search as cognee_search
try:
search_type = payload.query_params["searchType"]
params = {
"query": payload.query_params["query"],
"query": payload.query_params["query"],
}
results = await cognee_search(search_type, params)
return JSONResponse(
status_code = 200,
content = json.dumps(results)
status_code=200,
content=json.dumps(results)
)
except Exception as error:
return JSONResponse(
status_code = 409,
content = { "error": error }
status_code=409,
content={"error": str(error)}
)
@app.get("/settings", response_model=dict)
async def get_settings():
from cognee.modules.settings import get_settings
return get_settings()
class LLMConfig(BaseModel):
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
model: str
@ -264,15 +236,14 @@ async def save_config(new_settings: SettingsPayload):
from cognee.modules.settings import save_llm_config, save_vector_db_config
if new_settings.llm is not None:
await save_llm_config(new_settings.llm)
if new_settings.vectorDB is not None:
await save_vector_db_config(new_settings.vectorDB)
return JSONResponse(
status_code = 200,
content = "OK",
status_code=200,
content="OK",
)
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
"""
Start the API server using uvicorn.

View file

@ -8,7 +8,8 @@ import pandas as pd
from pydantic import BaseModel
USER_ID = "default_user"
async def add_topology(directory="example", model=GitHubRepositoryModel):
async def add_topology(directory: str = "example", model: BaseModel = GitHubRepositoryModel) -> Any:
graph_db_type = infrastructure_config.get_config()["graph_engine"]
graph_client = await get_graph_client(graph_db_type)
@ -16,7 +17,7 @@ async def add_topology(directory="example", model=GitHubRepositoryModel):
graph_topology = infrastructure_config.get_config()["graph_topology"]
engine = TopologyEngine()
topology = await engine.infer_from_directory_structure(node_id =USER_ID , repository = directory, model=model)
topology = await engine.infer_from_directory_structure(node_id=USER_ID, repository=directory, model=model)
def flatten_model(model: BaseModel, parent_id: Optional[str] = None) -> Dict[str, Any]:
"""Flatten a single Pydantic model to a dictionary handling nested structures."""
@ -42,17 +43,16 @@ async def add_topology(directory="example", model=GitHubRepositoryModel):
else:
return []
def flatten_repository(repo_model):
def flatten_repository(repo_model: BaseModel) -> List[Dict[str, Any]]:
""" Flatten the entire repository model, starting with the top-level model """
return recursive_flatten(repo_model)
flt_topology = flatten_repository(topology)
df =pd.DataFrame(flt_topology)
df = pd.DataFrame(flt_topology)
print(df.head(10))
for _, row in df.iterrows():
node_data = row.to_dict()
node_id = node_data.pop('node_id')
@ -65,9 +65,10 @@ async def add_topology(directory="example", model=GitHubRepositoryModel):
return graph_client.graph
if __name__ == "__main__":
async def test():
async def test() -> None:
# Uncomment and modify the following lines as needed
# await prune.prune_system()
# #
#
# from cognee.api.v1.add import add
# data_directory_path = os.path.abspath("../../../.data")
# # print(data_directory_path)
@ -75,7 +76,7 @@ if __name__ == "__main__":
# # cognee_directory_path = os.path.abspath("../.cognee_system")
# # config.system_root_directory(cognee_directory_path)
#
# await add("data://" +data_directory_path, "example")
# await add("data://" + data_directory_path, "example")
# graph = await add_topology()
@ -88,4 +89,4 @@ if __name__ == "__main__":
await render_graph(graph_client.graph, include_color=True, include_nodes=False, include_size=False)
import asyncio
asyncio.run(test())
asyncio.run(test())

View file

@ -1,11 +1,7 @@
import os
import glob
from pydantic import BaseModel, create_model
from typing import Dict, Type, Any
from pydantic import BaseModel, Field
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Type, Any
from datetime import datetime
from cognee import config
@ -13,23 +9,6 @@ from cognee.infrastructure import infrastructure_config
from cognee.modules.topology.infer_data_topology import infer_data_topology
# class UserLocation(BaseModel):
# location_id: str
# description: str
# default_relationship: Relationship = Relationship(type = "located_in")
#
# class UserProperties(BaseModel):
# custom_properties: Optional[Dict[str, Any]] = None
# location: Optional[UserLocation] = None
#
# class DefaultGraphModel(BaseModel):
# node_id: str
# user_properties: UserProperties = UserProperties()
# documents: List[Document] = []
# default_fields: Optional[Dict[str, Any]] = {}
# default_relationship: Relationship = Relationship(type = "has_properties")
#
class Relationship(BaseModel):
type: str = Field(..., description="The type of relationship, e.g., 'belongs_to'.")
source: Optional[str] = Field(None, description="The identifier of the source id of in the relationship being a directory or subdirectory")
@ -37,7 +16,6 @@ class Relationship(BaseModel):
properties: Optional[Dict[str, Any]] = Field(None, description="A dictionary of additional properties and values related to the relationship.")
class Document(BaseModel):
node_id: str
title: str
@ -53,8 +31,10 @@ class DirectoryModel(BaseModel):
subdirectories: List['DirectoryModel'] = []
default_relationship: Relationship
DirectoryModel.update_forward_refs()
class DirMetadata(BaseModel):
node_id: str
summary: str
@ -64,6 +44,7 @@ class DirMetadata(BaseModel):
documents: List[Document] = []
default_relationship: Relationship
class GitHubRepositoryModel(BaseModel):
node_id: str
metadata: DirMetadata
@ -71,10 +52,10 @@ class GitHubRepositoryModel(BaseModel):
class TopologyEngine:
def __init__(self) -> None:
    """Create an engine with an empty model registry."""
    # Maps a model name to the pydantic model class registered under it.
    self.models: Dict[str, Type[BaseModel]] = {}
async def populate_model(self, directory_path, file_structure, parent_id=None):
async def populate_model(self, directory_path: str, file_structure: Dict[str, Union[Dict, Tuple[str, ...]]], parent_id: Optional[str] = None) -> DirectoryModel:
directory_id = os.path.basename(directory_path) or "root"
directory = DirectoryModel(
node_id=directory_id,
@ -100,18 +81,17 @@ class TopologyEngine:
return directory
async def infer_from_directory_structure(self, node_id:str, repository: str, model):
async def infer_from_directory_structure(self, node_id: str, repository: str, model: Type[BaseModel]) -> GitHubRepositoryModel:
""" Infer the topology of a repository from its file structure """
path = infrastructure_config.get_config()["data_root_directory"]
path = path +"/"+ str(repository)
path = path + "/" + str(repository)
print(path)
if not os.path.exists(path):
raise FileNotFoundError(f"No such directory: {path}")
root = {}
root: Dict[str, Union[Dict, Tuple[str, ...]]] = {}
for filename in glob.glob(f"{path}/**", recursive=True):
parts = os.path.relpath(filename, start=path).split(os.path.sep)
current = root
@ -128,8 +108,6 @@ class TopologyEngine:
root_directory = await self.populate_model('/', root)
# repository_metadata = await infer_data_topology(str(root), DirMetadata)
repository_metadata = DirMetadata(
node_id="repo1",
summary="Example repository",
@ -147,13 +125,10 @@ class TopologyEngine:
return active_model
# print(github_repo_model)
def load(self, model_name: str) -> Optional[Type[BaseModel]]:
    """Return the model registered under ``model_name``, or None if absent."""
    return self.models.get(model_name)
def extrapolate(self, model_name: str) -> None:
    """Placeholder — deliberately a no-op.

    What "extrapolate" should do is implementation-specific and has not
    been defined yet; the method exists only to reserve the interface.
    """
    pass
@ -164,15 +139,16 @@ if __name__ == "__main__":
config.data_root_directory(data_directory_path)
cognee_directory_path = os.path.abspath("../.cognee_system")
config.system_root_directory(cognee_directory_path)
async def main() -> None:
    """Demo driver: infer an example repository's topology and print it."""
    engine = TopologyEngine()
    # model = engine.load("GitHubRepositoryModel")
    # if model is None:
    #     raise ValueError("Model not found")
    result = await engine.infer_from_directory_structure(
        "example_node_id", "example_repo", GitHubRepositoryModel
    )
    print(result)

import asyncio
asyncio.run(main())
# result = engine.extrapolate("GitHubRepositoryModel")
# print(result)