add test for linter

This commit is contained in:
Vasilije 2024-05-25 22:33:12 +02:00
parent 9569441c5e
commit a3e218e5a4
3 changed files with 76 additions and 128 deletions

View file

@ -1,11 +1,16 @@
""" FastAPI server for the Cognee API. """ """ FastAPI server for the Cognee API. """
import os import os
import aiohttp import aiohttp
import uvicorn import uvicorn
import json import json
import logging import logging
from typing import Dict, Any, List, Union, Optional
from typing_extensions import Annotated
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Set up logging # Set up logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL) level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)
@ -14,15 +19,10 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from cognee.config import Config from cognee.config import Config
config = Config() config = Config()
config.load() config.load()
from typing import Dict, Any, List, Union, Annotated, Literal, Optional
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
app = FastAPI(debug=True) app = FastAPI(debug=True)
origins = [ origins = [
@ -33,19 +33,12 @@ origins = [
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins = origins, allow_origins=origins,
allow_credentials = True, allow_credentials=True,
allow_methods = ["OPTIONS", "GET", "POST", "DELETE"], allow_methods=["OPTIONS", "GET", "POST", "DELETE"],
allow_headers = ["*"], allow_headers=["*"],
) )
#
# from auth.cognito.JWTBearer import JWTBearer
# from auth.auth import jwks
#
# auth = JWTBearer(jwks)
@app.get("/") @app.get("/")
async def root(): async def root():
""" """
@ -53,7 +46,6 @@ async def root():
""" """
return {"message": "Hello, World, I am alive!"} return {"message": "Hello, World, I am alive!"}
@app.get("/health") @app.get("/health")
def health_check(): def health_check():
""" """
@ -61,11 +53,9 @@ def health_check():
""" """
return {"status": "OK"} return {"status": "OK"}
class Payload(BaseModel): class Payload(BaseModel):
payload: Dict[str, Any] payload: Dict[str, Any]
@app.get("/datasets", response_model=list) @app.get("/datasets", response_model=list)
async def get_datasets(): async def get_datasets():
from cognee import datasets from cognee import datasets
@ -74,77 +64,67 @@ async def get_datasets():
@app.delete("/datasets/{dataset_id}", response_model=dict) @app.delete("/datasets/{dataset_id}", response_model=dict)
async def delete_dataset(dataset_id: str): async def delete_dataset(dataset_id: str):
from cognee import datasets from cognee import datasets
datasets.delete_dataset(dataset_id) datasets.delete_dataset(dataset_id)
return JSONResponse( return JSONResponse(
status_code = 200, status_code=200,
content = "OK", content="OK",
) )
@app.get("/datasets/{dataset_id}/graph", response_model=list) @app.get("/datasets/{dataset_id}/graph", response_model=list)
async def get_dataset_graph(dataset_id: str): async def get_dataset_graph(dataset_id: str):
from cognee import utils from cognee import utils
from cognee.infrastructure import infrastructure_config from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
graph_engine = infrastructure_config.get_config("graph_engine") graph_engine = infrastructure_config.get_config()["graph_engine"]
graph_client = await get_graph_client(graph_engine) graph_client = await get_graph_client(graph_engine)
graph_url = await utils.render_graph(graph_client.graph) graph_url = await utils.render_graph(graph_client.graph)
return JSONResponse( return JSONResponse(
status_code = 200, status_code=200,
content = str(graph_url), content=str(graph_url),
) )
@app.get("/datasets/{dataset_id}/data", response_model=list) @app.get("/datasets/{dataset_id}/data", response_model=list)
async def get_dataset_data(dataset_id: str): async def get_dataset_data(dataset_id: str):
from cognee import datasets from cognee import datasets
dataset_data = datasets.list_data(dataset_id) dataset_data = datasets.list_data(dataset_id)
if dataset_data is None: if dataset_data is None:
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.") raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.")
return [
return [dict( dict(
id = data["id"], id=data["id"],
name = f"{data['name']}.{data['extension']}", name=f"{data['name']}.{data['extension']}",
keywords = data["keywords"].split("|"), keywords=data["keywords"].split("|"),
filePath = data["file_path"], filePath=data["file_path"],
mimeType = data["mime_type"], mimeType=data["mime_type"],
) for data in dataset_data] )
for data in dataset_data
]
@app.get("/datasets/status", response_model=dict) @app.get("/datasets/status", response_model=dict)
async def get_dataset_status(datasets: Annotated[list, Query(alias = "dataset")] = None): async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None):
from cognee import datasets as cognee_datasets from cognee import datasets as cognee_datasets
datasets_statuses = cognee_datasets.get_status(datasets) datasets_statuses = cognee_datasets.get_status(datasets)
return JSONResponse( return JSONResponse(
status_code = 200, status_code=200,
content = datasets_statuses content=datasets_statuses
) )
@app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse) @app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
async def get_raw_data(dataset_id: str, data_id: str): async def get_raw_data(dataset_id: str, data_id: str):
from cognee import datasets from cognee import datasets
dataset_data = datasets.list_data(dataset_id) dataset_data = datasets.list_data(dataset_id)
if dataset_data is None: if dataset_data is None:
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.") raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.")
data = [data for data in dataset_data if data["id"] == data_id][0] data = [data for data in dataset_data if data["id"] == data_id][0]
return data["file_path"] return data["file_path"]
class AddPayload(BaseModel): class AddPayload(BaseModel):
data: Union[str, UploadFile, List[Union[str, UploadFile]]] data: Union[str, UploadFile, List[Union[str, UploadFile]]]
dataset_id: str dataset_id: str
class Config: class Config:
arbitrary_types_allowed = True # This is required to allow the use of Union arbitrary_types_allowed = True
@app.post("/add", response_model=dict) @app.post("/add", response_model=dict)
async def add( async def add(
@ -153,7 +133,6 @@ async def add(
): ):
""" This endpoint is responsible for adding data to the graph.""" """ This endpoint is responsible for adding data to the graph."""
from cognee import add as cognee_add from cognee import add as cognee_add
try: try:
if isinstance(data, str) and data.startswith("http"): if isinstance(data, str) and data.startswith("http"):
if "github" in data: if "github" in data:
@ -182,69 +161,62 @@ async def add(
datasetId, datasetId,
) )
return JSONResponse( return JSONResponse(
status_code = 200, status_code=200,
content = "OK" content="OK"
) )
except Exception as error: except Exception as error:
return JSONResponse( return JSONResponse(
status_code = 409, status_code=409,
content = { "error": str(error) } content={"error": str(error)}
) )
class CognifyPayload(BaseModel): class CognifyPayload(BaseModel):
datasets: list[str] datasets: List[str]
@app.post("/cognify", response_model=dict) @app.post("/cognify", response_model=dict)
async def cognify(payload: CognifyPayload): async def cognify(payload: CognifyPayload):
""" This endpoint is responsible for the cognitive processing of the content.""" """ This endpoint is responsible for the cognitive processing of the content."""
from cognee import cognify as cognee_cognify from cognee import cognify as cognee_cognify
try: try:
await cognee_cognify(payload.datasets) await cognee_cognify(payload.datasets)
return JSONResponse( return JSONResponse(
status_code = 200, status_code=200,
content = "OK" content="OK"
) )
except Exception as error: except Exception as error:
return JSONResponse( return JSONResponse(
status_code = 409, status_code=409,
content = { "error": error } content={"error": str(error)}
) )
class SearchPayload(BaseModel): class SearchPayload(BaseModel):
query_params: Dict[str, Any] query_params: Dict[str, Any]
@app.post("/search", response_model=dict) @app.post("/search", response_model=dict)
async def search(payload: SearchPayload): async def search(payload: SearchPayload):
""" This endpoint is responsible for searching for nodes in the graph.""" """ This endpoint is responsible for searching for nodes in the graph."""
from cognee import search as cognee_search from cognee import search as cognee_search
try: try:
search_type = payload.query_params["searchType"] search_type = payload.query_params["searchType"]
params = { params = {
"query": payload.query_params["query"], "query": payload.query_params["query"],
} }
results = await cognee_search(search_type, params) results = await cognee_search(search_type, params)
return JSONResponse( return JSONResponse(
status_code = 200, status_code=200,
content = json.dumps(results) content=json.dumps(results)
) )
except Exception as error: except Exception as error:
return JSONResponse( return JSONResponse(
status_code = 409, status_code=409,
content = { "error": error } content={"error": str(error)}
) )
@app.get("/settings", response_model=dict) @app.get("/settings", response_model=dict)
async def get_settings(): async def get_settings():
from cognee.modules.settings import get_settings from cognee.modules.settings import get_settings
return get_settings() return get_settings()
class LLMConfig(BaseModel): class LLMConfig(BaseModel):
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]] provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
model: str model: str
@ -264,15 +236,14 @@ async def save_config(new_settings: SettingsPayload):
from cognee.modules.settings import save_llm_config, save_vector_db_config from cognee.modules.settings import save_llm_config, save_vector_db_config
if new_settings.llm is not None: if new_settings.llm is not None:
await save_llm_config(new_settings.llm) await save_llm_config(new_settings.llm)
if new_settings.vectorDB is not None: if new_settings.vectorDB is not None:
await save_vector_db_config(new_settings.vectorDB) await save_vector_db_config(new_settings.vectorDB)
return JSONResponse( return JSONResponse(
status_code = 200, status_code=200,
content = "OK", content="OK",
) )
def start_api_server(host: str = "0.0.0.0", port: int = 8000): def start_api_server(host: str = "0.0.0.0", port: int = 8000):
""" """
Start the API server using uvicorn. Start the API server using uvicorn.

View file

@ -8,7 +8,8 @@ import pandas as pd
from pydantic import BaseModel from pydantic import BaseModel
USER_ID = "default_user" USER_ID = "default_user"
async def add_topology(directory="example", model=GitHubRepositoryModel):
async def add_topology(directory: str = "example", model: BaseModel = GitHubRepositoryModel) -> Any:
graph_db_type = infrastructure_config.get_config()["graph_engine"] graph_db_type = infrastructure_config.get_config()["graph_engine"]
graph_client = await get_graph_client(graph_db_type) graph_client = await get_graph_client(graph_db_type)
@ -16,7 +17,7 @@ async def add_topology(directory="example", model=GitHubRepositoryModel):
graph_topology = infrastructure_config.get_config()["graph_topology"] graph_topology = infrastructure_config.get_config()["graph_topology"]
engine = TopologyEngine() engine = TopologyEngine()
topology = await engine.infer_from_directory_structure(node_id =USER_ID , repository = directory, model=model) topology = await engine.infer_from_directory_structure(node_id=USER_ID, repository=directory, model=model)
def flatten_model(model: BaseModel, parent_id: Optional[str] = None) -> Dict[str, Any]: def flatten_model(model: BaseModel, parent_id: Optional[str] = None) -> Dict[str, Any]:
"""Flatten a single Pydantic model to a dictionary handling nested structures.""" """Flatten a single Pydantic model to a dictionary handling nested structures."""
@ -42,17 +43,16 @@ async def add_topology(directory="example", model=GitHubRepositoryModel):
else: else:
return [] return []
def flatten_repository(repo_model): def flatten_repository(repo_model: BaseModel) -> List[Dict[str, Any]]:
""" Flatten the entire repository model, starting with the top-level model """ """ Flatten the entire repository model, starting with the top-level model """
return recursive_flatten(repo_model) return recursive_flatten(repo_model)
flt_topology = flatten_repository(topology) flt_topology = flatten_repository(topology)
df =pd.DataFrame(flt_topology) df = pd.DataFrame(flt_topology)
print(df.head(10)) print(df.head(10))
for _, row in df.iterrows(): for _, row in df.iterrows():
node_data = row.to_dict() node_data = row.to_dict()
node_id = node_data.pop('node_id') node_id = node_data.pop('node_id')
@ -65,9 +65,10 @@ async def add_topology(directory="example", model=GitHubRepositoryModel):
return graph_client.graph return graph_client.graph
if __name__ == "__main__": if __name__ == "__main__":
async def test(): async def test() -> None:
# Uncomment and modify the following lines as needed
# await prune.prune_system() # await prune.prune_system()
# # #
# from cognee.api.v1.add import add # from cognee.api.v1.add import add
# data_directory_path = os.path.abspath("../../../.data") # data_directory_path = os.path.abspath("../../../.data")
# # print(data_directory_path) # # print(data_directory_path)
@ -75,7 +76,7 @@ if __name__ == "__main__":
# # cognee_directory_path = os.path.abspath("../.cognee_system") # # cognee_directory_path = os.path.abspath("../.cognee_system")
# # config.system_root_directory(cognee_directory_path) # # config.system_root_directory(cognee_directory_path)
# #
# await add("data://" +data_directory_path, "example") # await add("data://" + data_directory_path, "example")
# graph = await add_topology() # graph = await add_topology()
@ -88,4 +89,4 @@ if __name__ == "__main__":
await render_graph(graph_client.graph, include_color=True, include_nodes=False, include_size=False) await render_graph(graph_client.graph, include_color=True, include_nodes=False, include_size=False)
import asyncio import asyncio
asyncio.run(test()) asyncio.run(test())

View file

@ -1,11 +1,7 @@
import os import os
import glob import glob
from pydantic import BaseModel, create_model
from typing import Dict, Type, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union, Type, Any
from datetime import datetime from datetime import datetime
from cognee import config from cognee import config
@ -13,23 +9,6 @@ from cognee.infrastructure import infrastructure_config
from cognee.modules.topology.infer_data_topology import infer_data_topology from cognee.modules.topology.infer_data_topology import infer_data_topology
# class UserLocation(BaseModel):
# location_id: str
# description: str
# default_relationship: Relationship = Relationship(type = "located_in")
#
# class UserProperties(BaseModel):
# custom_properties: Optional[Dict[str, Any]] = None
# location: Optional[UserLocation] = None
#
# class DefaultGraphModel(BaseModel):
# node_id: str
# user_properties: UserProperties = UserProperties()
# documents: List[Document] = []
# default_fields: Optional[Dict[str, Any]] = {}
# default_relationship: Relationship = Relationship(type = "has_properties")
#
class Relationship(BaseModel): class Relationship(BaseModel):
type: str = Field(..., description="The type of relationship, e.g., 'belongs_to'.") type: str = Field(..., description="The type of relationship, e.g., 'belongs_to'.")
source: Optional[str] = Field(None, description="The identifier of the source node in the relationship, being a directory or subdirectory") source: Optional[str] = Field(None, description="The identifier of the source node in the relationship, being a directory or subdirectory")
@ -37,7 +16,6 @@ class Relationship(BaseModel):
properties: Optional[Dict[str, Any]] = Field(None, description="A dictionary of additional properties and values related to the relationship.") properties: Optional[Dict[str, Any]] = Field(None, description="A dictionary of additional properties and values related to the relationship.")
class Document(BaseModel): class Document(BaseModel):
node_id: str node_id: str
title: str title: str
@ -53,8 +31,10 @@ class DirectoryModel(BaseModel):
subdirectories: List['DirectoryModel'] = [] subdirectories: List['DirectoryModel'] = []
default_relationship: Relationship default_relationship: Relationship
DirectoryModel.update_forward_refs() DirectoryModel.update_forward_refs()
class DirMetadata(BaseModel): class DirMetadata(BaseModel):
node_id: str node_id: str
summary: str summary: str
@ -64,6 +44,7 @@ class DirMetadata(BaseModel):
documents: List[Document] = [] documents: List[Document] = []
default_relationship: Relationship default_relationship: Relationship
class GitHubRepositoryModel(BaseModel): class GitHubRepositoryModel(BaseModel):
node_id: str node_id: str
metadata: DirMetadata metadata: DirMetadata
@ -71,10 +52,10 @@ class GitHubRepositoryModel(BaseModel):
class TopologyEngine: class TopologyEngine:
def __init__(self): def __init__(self) -> None:
self.models: Dict[str, Type[BaseModel]] = {} self.models: Dict[str, Type[BaseModel]] = {}
async def populate_model(self, directory_path, file_structure, parent_id=None): async def populate_model(self, directory_path: str, file_structure: Dict[str, Union[Dict, Tuple[str, ...]]], parent_id: Optional[str] = None) -> DirectoryModel:
directory_id = os.path.basename(directory_path) or "root" directory_id = os.path.basename(directory_path) or "root"
directory = DirectoryModel( directory = DirectoryModel(
node_id=directory_id, node_id=directory_id,
@ -100,18 +81,17 @@ class TopologyEngine:
return directory return directory
async def infer_from_directory_structure(self, node_id:str, repository: str, model): async def infer_from_directory_structure(self, node_id: str, repository: str, model: Type[BaseModel]) -> GitHubRepositoryModel:
""" Infer the topology of a repository from its file structure """ """ Infer the topology of a repository from its file structure """
path = infrastructure_config.get_config()["data_root_directory"] path = infrastructure_config.get_config()["data_root_directory"]
path = path + "/" + str(repository)
path = path +"/"+ str(repository)
print(path) print(path)
if not os.path.exists(path): if not os.path.exists(path):
raise FileNotFoundError(f"No such directory: {path}") raise FileNotFoundError(f"No such directory: {path}")
root = {} root: Dict[str, Union[Dict, Tuple[str, ...]]] = {}
for filename in glob.glob(f"{path}/**", recursive=True): for filename in glob.glob(f"{path}/**", recursive=True):
parts = os.path.relpath(filename, start=path).split(os.path.sep) parts = os.path.relpath(filename, start=path).split(os.path.sep)
current = root current = root
@ -128,8 +108,6 @@ class TopologyEngine:
root_directory = await self.populate_model('/', root) root_directory = await self.populate_model('/', root)
# repository_metadata = await infer_data_topology(str(root), DirMetadata)
repository_metadata = DirMetadata( repository_metadata = DirMetadata(
node_id="repo1", node_id="repo1",
summary="Example repository", summary="Example repository",
@ -147,13 +125,10 @@ class TopologyEngine:
return active_model return active_model
# print(github_repo_model) def load(self, model_name: str) -> Optional[Type[BaseModel]]:
def load(self, model_name: str):
return self.models.get(model_name) return self.models.get(model_name)
def extrapolate(self, model_name: str): def extrapolate(self, model_name: str) -> None:
# This method would be implementation-specific depending on what "extrapolate" means # This method would be implementation-specific depending on what "extrapolate" means
pass pass
@ -164,15 +139,16 @@ if __name__ == "__main__":
config.data_root_directory(data_directory_path) config.data_root_directory(data_directory_path)
cognee_directory_path = os.path.abspath("../.cognee_system") cognee_directory_path = os.path.abspath("../.cognee_system")
config.system_root_directory(cognee_directory_path) config.system_root_directory(cognee_directory_path)
async def main():
async def main() -> None:
engine = TopologyEngine() engine = TopologyEngine()
# model = engine.load("GitHubRepositoryModel") # model = engine.load("GitHubRepositoryModel")
# if model is None: # if model is None:
# raise ValueError("Model not found") # raise ValueError("Model not found")
result = await engine.infer("example") result = await engine.infer_from_directory_structure("example_node_id", "example_repo", GitHubRepositoryModel)
print(result) print(result)
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())
# result = engine.extrapolate("GitHubRepositoryModel") # result = engine.extrapolate("GitHubRepositoryModel")
# print(result) # print(result)