add test for linter
This commit is contained in:
parent
9569441c5e
commit
a3e218e5a4
3 changed files with 76 additions and 128 deletions
|
|
@ -1,11 +1,16 @@
|
||||||
""" FastAPI server for the Cognee API. """
|
""" FastAPI server for the Cognee API. """
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Dict, Any, List, Union, Optional
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
|
||||||
|
from fastapi.responses import JSONResponse, FileResponse
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
level=logging.INFO, # Set the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||||
|
|
@ -14,15 +19,10 @@ logging.basicConfig(
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from cognee.config import Config
|
from cognee.config import Config
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
|
|
||||||
from typing import Dict, Any, List, Union, Annotated, Literal, Optional
|
|
||||||
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
|
|
||||||
from fastapi.responses import JSONResponse, FileResponse
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
app = FastAPI(debug=True)
|
app = FastAPI(debug=True)
|
||||||
|
|
||||||
origins = [
|
origins = [
|
||||||
|
|
@ -33,19 +33,12 @@ origins = [
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins = origins,
|
allow_origins=origins,
|
||||||
allow_credentials = True,
|
allow_credentials=True,
|
||||||
allow_methods = ["OPTIONS", "GET", "POST", "DELETE"],
|
allow_methods=["OPTIONS", "GET", "POST", "DELETE"],
|
||||||
allow_headers = ["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
#
|
|
||||||
# from auth.cognito.JWTBearer import JWTBearer
|
|
||||||
# from auth.auth import jwks
|
|
||||||
#
|
|
||||||
# auth = JWTBearer(jwks)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
"""
|
"""
|
||||||
|
|
@ -53,7 +46,6 @@ async def root():
|
||||||
"""
|
"""
|
||||||
return {"message": "Hello, World, I am alive!"}
|
return {"message": "Hello, World, I am alive!"}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
def health_check():
|
def health_check():
|
||||||
"""
|
"""
|
||||||
|
|
@ -61,11 +53,9 @@ def health_check():
|
||||||
"""
|
"""
|
||||||
return {"status": "OK"}
|
return {"status": "OK"}
|
||||||
|
|
||||||
|
|
||||||
class Payload(BaseModel):
|
class Payload(BaseModel):
|
||||||
payload: Dict[str, Any]
|
payload: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@app.get("/datasets", response_model=list)
|
@app.get("/datasets", response_model=list)
|
||||||
async def get_datasets():
|
async def get_datasets():
|
||||||
from cognee import datasets
|
from cognee import datasets
|
||||||
|
|
@ -74,77 +64,67 @@ async def get_datasets():
|
||||||
@app.delete("/datasets/{dataset_id}", response_model=dict)
|
@app.delete("/datasets/{dataset_id}", response_model=dict)
|
||||||
async def delete_dataset(dataset_id: str):
|
async def delete_dataset(dataset_id: str):
|
||||||
from cognee import datasets
|
from cognee import datasets
|
||||||
|
|
||||||
datasets.delete_dataset(dataset_id)
|
datasets.delete_dataset(dataset_id)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
content = "OK",
|
content="OK",
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.get("/datasets/{dataset_id}/graph", response_model=list)
|
@app.get("/datasets/{dataset_id}/graph", response_model=list)
|
||||||
async def get_dataset_graph(dataset_id: str):
|
async def get_dataset_graph(dataset_id: str):
|
||||||
from cognee import utils
|
from cognee import utils
|
||||||
|
|
||||||
from cognee.infrastructure import infrastructure_config
|
from cognee.infrastructure import infrastructure_config
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
|
|
||||||
graph_engine = infrastructure_config.get_config("graph_engine")
|
graph_engine = infrastructure_config.get_config()["graph_engine"]
|
||||||
graph_client = await get_graph_client(graph_engine)
|
graph_client = await get_graph_client(graph_engine)
|
||||||
|
|
||||||
graph_url = await utils.render_graph(graph_client.graph)
|
graph_url = await utils.render_graph(graph_client.graph)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
content = str(graph_url),
|
content=str(graph_url),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/datasets/{dataset_id}/data", response_model=list)
|
@app.get("/datasets/{dataset_id}/data", response_model=list)
|
||||||
async def get_dataset_data(dataset_id: str):
|
async def get_dataset_data(dataset_id: str):
|
||||||
from cognee import datasets
|
from cognee import datasets
|
||||||
dataset_data = datasets.list_data(dataset_id)
|
dataset_data = datasets.list_data(dataset_id)
|
||||||
|
|
||||||
if dataset_data is None:
|
if dataset_data is None:
|
||||||
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")
|
raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.")
|
||||||
|
return [
|
||||||
return [dict(
|
dict(
|
||||||
id = data["id"],
|
id=data["id"],
|
||||||
name = f"{data['name']}.{data['extension']}",
|
name=f"{data['name']}.{data['extension']}",
|
||||||
keywords = data["keywords"].split("|"),
|
keywords=data["keywords"].split("|"),
|
||||||
filePath = data["file_path"],
|
filePath=data["file_path"],
|
||||||
mimeType = data["mime_type"],
|
mimeType=data["mime_type"],
|
||||||
) for data in dataset_data]
|
)
|
||||||
|
for data in dataset_data
|
||||||
|
]
|
||||||
|
|
||||||
@app.get("/datasets/status", response_model=dict)
|
@app.get("/datasets/status", response_model=dict)
|
||||||
async def get_dataset_status(datasets: Annotated[list, Query(alias = "dataset")] = None):
|
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None):
|
||||||
from cognee import datasets as cognee_datasets
|
from cognee import datasets as cognee_datasets
|
||||||
datasets_statuses = cognee_datasets.get_status(datasets)
|
datasets_statuses = cognee_datasets.get_status(datasets)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
content = datasets_statuses
|
content=datasets_statuses
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
@app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
||||||
async def get_raw_data(dataset_id: str, data_id: str):
|
async def get_raw_data(dataset_id: str, data_id: str):
|
||||||
from cognee import datasets
|
from cognee import datasets
|
||||||
dataset_data = datasets.list_data(dataset_id)
|
dataset_data = datasets.list_data(dataset_id)
|
||||||
|
|
||||||
if dataset_data is None:
|
if dataset_data is None:
|
||||||
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")
|
raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.")
|
||||||
|
|
||||||
data = [data for data in dataset_data if data["id"] == data_id][0]
|
data = [data for data in dataset_data if data["id"] == data_id][0]
|
||||||
|
|
||||||
return data["file_path"]
|
return data["file_path"]
|
||||||
|
|
||||||
|
|
||||||
class AddPayload(BaseModel):
|
class AddPayload(BaseModel):
|
||||||
data: Union[str, UploadFile, List[Union[str, UploadFile]]]
|
data: Union[str, UploadFile, List[Union[str, UploadFile]]]
|
||||||
dataset_id: str
|
dataset_id: str
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True # This is required to allow the use of Union
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@app.post("/add", response_model=dict)
|
@app.post("/add", response_model=dict)
|
||||||
async def add(
|
async def add(
|
||||||
|
|
@ -153,7 +133,6 @@ async def add(
|
||||||
):
|
):
|
||||||
""" This endpoint is responsible for adding data to the graph."""
|
""" This endpoint is responsible for adding data to the graph."""
|
||||||
from cognee import add as cognee_add
|
from cognee import add as cognee_add
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(data, str) and data.startswith("http"):
|
if isinstance(data, str) and data.startswith("http"):
|
||||||
if "github" in data:
|
if "github" in data:
|
||||||
|
|
@ -182,69 +161,62 @@ async def add(
|
||||||
datasetId,
|
datasetId,
|
||||||
)
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
content = "OK"
|
content="OK"
|
||||||
)
|
)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 409,
|
status_code=409,
|
||||||
content = { "error": str(error) }
|
content={"error": str(error)}
|
||||||
)
|
)
|
||||||
|
|
||||||
class CognifyPayload(BaseModel):
|
class CognifyPayload(BaseModel):
|
||||||
datasets: list[str]
|
datasets: List[str]
|
||||||
|
|
||||||
@app.post("/cognify", response_model=dict)
|
@app.post("/cognify", response_model=dict)
|
||||||
async def cognify(payload: CognifyPayload):
|
async def cognify(payload: CognifyPayload):
|
||||||
""" This endpoint is responsible for the cognitive processing of the content."""
|
""" This endpoint is responsible for the cognitive processing of the content."""
|
||||||
from cognee import cognify as cognee_cognify
|
from cognee import cognify as cognee_cognify
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await cognee_cognify(payload.datasets)
|
await cognee_cognify(payload.datasets)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
content = "OK"
|
content="OK"
|
||||||
)
|
)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 409,
|
status_code=409,
|
||||||
content = { "error": error }
|
content={"error": str(error)}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SearchPayload(BaseModel):
|
class SearchPayload(BaseModel):
|
||||||
query_params: Dict[str, Any]
|
query_params: Dict[str, Any]
|
||||||
|
|
||||||
@app.post("/search", response_model=dict)
|
@app.post("/search", response_model=dict)
|
||||||
async def search(payload: SearchPayload):
|
async def search(payload: SearchPayload):
|
||||||
""" This endpoint is responsible for searching for nodes in the graph."""
|
""" This endpoint is responsible for searching for nodes in the graph."""
|
||||||
from cognee import search as cognee_search
|
from cognee import search as cognee_search
|
||||||
|
|
||||||
try:
|
try:
|
||||||
search_type = payload.query_params["searchType"]
|
search_type = payload.query_params["searchType"]
|
||||||
params = {
|
params = {
|
||||||
"query": payload.query_params["query"],
|
"query": payload.query_params["query"],
|
||||||
}
|
}
|
||||||
results = await cognee_search(search_type, params)
|
results = await cognee_search(search_type, params)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
content = json.dumps(results)
|
content=json.dumps(results)
|
||||||
)
|
)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 409,
|
status_code=409,
|
||||||
content = { "error": error }
|
content={"error": str(error)}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/settings", response_model=dict)
|
@app.get("/settings", response_model=dict)
|
||||||
async def get_settings():
|
async def get_settings():
|
||||||
from cognee.modules.settings import get_settings
|
from cognee.modules.settings import get_settings
|
||||||
return get_settings()
|
return get_settings()
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig(BaseModel):
|
class LLMConfig(BaseModel):
|
||||||
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
|
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
|
||||||
model: str
|
model: str
|
||||||
|
|
@ -264,15 +236,14 @@ async def save_config(new_settings: SettingsPayload):
|
||||||
from cognee.modules.settings import save_llm_config, save_vector_db_config
|
from cognee.modules.settings import save_llm_config, save_vector_db_config
|
||||||
if new_settings.llm is not None:
|
if new_settings.llm is not None:
|
||||||
await save_llm_config(new_settings.llm)
|
await save_llm_config(new_settings.llm)
|
||||||
|
|
||||||
if new_settings.vectorDB is not None:
|
if new_settings.vectorDB is not None:
|
||||||
await save_vector_db_config(new_settings.vectorDB)
|
await save_vector_db_config(new_settings.vectorDB)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
content = "OK",
|
content="OK",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
||||||
"""
|
"""
|
||||||
Start the API server using uvicorn.
|
Start the API server using uvicorn.
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,8 @@ import pandas as pd
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
USER_ID = "default_user"
|
USER_ID = "default_user"
|
||||||
async def add_topology(directory="example", model=GitHubRepositoryModel):
|
|
||||||
|
async def add_topology(directory: str = "example", model: BaseModel = GitHubRepositoryModel) -> Any:
|
||||||
graph_db_type = infrastructure_config.get_config()["graph_engine"]
|
graph_db_type = infrastructure_config.get_config()["graph_engine"]
|
||||||
|
|
||||||
graph_client = await get_graph_client(graph_db_type)
|
graph_client = await get_graph_client(graph_db_type)
|
||||||
|
|
@ -16,7 +17,7 @@ async def add_topology(directory="example", model=GitHubRepositoryModel):
|
||||||
graph_topology = infrastructure_config.get_config()["graph_topology"]
|
graph_topology = infrastructure_config.get_config()["graph_topology"]
|
||||||
|
|
||||||
engine = TopologyEngine()
|
engine = TopologyEngine()
|
||||||
topology = await engine.infer_from_directory_structure(node_id =USER_ID , repository = directory, model=model)
|
topology = await engine.infer_from_directory_structure(node_id=USER_ID, repository=directory, model=model)
|
||||||
|
|
||||||
def flatten_model(model: BaseModel, parent_id: Optional[str] = None) -> Dict[str, Any]:
|
def flatten_model(model: BaseModel, parent_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
"""Flatten a single Pydantic model to a dictionary handling nested structures."""
|
"""Flatten a single Pydantic model to a dictionary handling nested structures."""
|
||||||
|
|
@ -42,17 +43,16 @@ async def add_topology(directory="example", model=GitHubRepositoryModel):
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def flatten_repository(repo_model):
|
def flatten_repository(repo_model: BaseModel) -> List[Dict[str, Any]]:
|
||||||
""" Flatten the entire repository model, starting with the top-level model """
|
""" Flatten the entire repository model, starting with the top-level model """
|
||||||
return recursive_flatten(repo_model)
|
return recursive_flatten(repo_model)
|
||||||
|
|
||||||
flt_topology = flatten_repository(topology)
|
flt_topology = flatten_repository(topology)
|
||||||
|
|
||||||
df =pd.DataFrame(flt_topology)
|
df = pd.DataFrame(flt_topology)
|
||||||
|
|
||||||
print(df.head(10))
|
print(df.head(10))
|
||||||
|
|
||||||
|
|
||||||
for _, row in df.iterrows():
|
for _, row in df.iterrows():
|
||||||
node_data = row.to_dict()
|
node_data = row.to_dict()
|
||||||
node_id = node_data.pop('node_id')
|
node_id = node_data.pop('node_id')
|
||||||
|
|
@ -65,9 +65,10 @@ async def add_topology(directory="example", model=GitHubRepositoryModel):
|
||||||
return graph_client.graph
|
return graph_client.graph
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
async def test():
|
async def test() -> None:
|
||||||
|
# Uncomment and modify the following lines as needed
|
||||||
# await prune.prune_system()
|
# await prune.prune_system()
|
||||||
# #
|
#
|
||||||
# from cognee.api.v1.add import add
|
# from cognee.api.v1.add import add
|
||||||
# data_directory_path = os.path.abspath("../../../.data")
|
# data_directory_path = os.path.abspath("../../../.data")
|
||||||
# # print(data_directory_path)
|
# # print(data_directory_path)
|
||||||
|
|
@ -75,7 +76,7 @@ if __name__ == "__main__":
|
||||||
# # cognee_directory_path = os.path.abspath("../.cognee_system")
|
# # cognee_directory_path = os.path.abspath("../.cognee_system")
|
||||||
# # config.system_root_directory(cognee_directory_path)
|
# # config.system_root_directory(cognee_directory_path)
|
||||||
#
|
#
|
||||||
# await add("data://" +data_directory_path, "example")
|
# await add("data://" + data_directory_path, "example")
|
||||||
|
|
||||||
# graph = await add_topology()
|
# graph = await add_topology()
|
||||||
|
|
||||||
|
|
@ -88,4 +89,4 @@ if __name__ == "__main__":
|
||||||
await render_graph(graph_client.graph, include_color=True, include_nodes=False, include_size=False)
|
await render_graph(graph_client.graph, include_color=True, include_nodes=False, include_size=False)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
asyncio.run(test())
|
asyncio.run(test())
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import glob
|
import glob
|
||||||
from pydantic import BaseModel, create_model
|
|
||||||
from typing import Dict, Type, Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union, Type, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from cognee import config
|
from cognee import config
|
||||||
|
|
@ -13,23 +9,6 @@ from cognee.infrastructure import infrastructure_config
|
||||||
from cognee.modules.topology.infer_data_topology import infer_data_topology
|
from cognee.modules.topology.infer_data_topology import infer_data_topology
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# class UserLocation(BaseModel):
|
|
||||||
# location_id: str
|
|
||||||
# description: str
|
|
||||||
# default_relationship: Relationship = Relationship(type = "located_in")
|
|
||||||
#
|
|
||||||
# class UserProperties(BaseModel):
|
|
||||||
# custom_properties: Optional[Dict[str, Any]] = None
|
|
||||||
# location: Optional[UserLocation] = None
|
|
||||||
#
|
|
||||||
# class DefaultGraphModel(BaseModel):
|
|
||||||
# node_id: str
|
|
||||||
# user_properties: UserProperties = UserProperties()
|
|
||||||
# documents: List[Document] = []
|
|
||||||
# default_fields: Optional[Dict[str, Any]] = {}
|
|
||||||
# default_relationship: Relationship = Relationship(type = "has_properties")
|
|
||||||
#
|
|
||||||
class Relationship(BaseModel):
|
class Relationship(BaseModel):
|
||||||
type: str = Field(..., description="The type of relationship, e.g., 'belongs_to'.")
|
type: str = Field(..., description="The type of relationship, e.g., 'belongs_to'.")
|
||||||
source: Optional[str] = Field(None, description="The identifier of the source id of in the relationship being a directory or subdirectory")
|
source: Optional[str] = Field(None, description="The identifier of the source id of in the relationship being a directory or subdirectory")
|
||||||
|
|
@ -37,7 +16,6 @@ class Relationship(BaseModel):
|
||||||
properties: Optional[Dict[str, Any]] = Field(None, description="A dictionary of additional properties and values related to the relationship.")
|
properties: Optional[Dict[str, Any]] = Field(None, description="A dictionary of additional properties and values related to the relationship.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Document(BaseModel):
|
class Document(BaseModel):
|
||||||
node_id: str
|
node_id: str
|
||||||
title: str
|
title: str
|
||||||
|
|
@ -53,8 +31,10 @@ class DirectoryModel(BaseModel):
|
||||||
subdirectories: List['DirectoryModel'] = []
|
subdirectories: List['DirectoryModel'] = []
|
||||||
default_relationship: Relationship
|
default_relationship: Relationship
|
||||||
|
|
||||||
|
|
||||||
DirectoryModel.update_forward_refs()
|
DirectoryModel.update_forward_refs()
|
||||||
|
|
||||||
|
|
||||||
class DirMetadata(BaseModel):
|
class DirMetadata(BaseModel):
|
||||||
node_id: str
|
node_id: str
|
||||||
summary: str
|
summary: str
|
||||||
|
|
@ -64,6 +44,7 @@ class DirMetadata(BaseModel):
|
||||||
documents: List[Document] = []
|
documents: List[Document] = []
|
||||||
default_relationship: Relationship
|
default_relationship: Relationship
|
||||||
|
|
||||||
|
|
||||||
class GitHubRepositoryModel(BaseModel):
|
class GitHubRepositoryModel(BaseModel):
|
||||||
node_id: str
|
node_id: str
|
||||||
metadata: DirMetadata
|
metadata: DirMetadata
|
||||||
|
|
@ -71,10 +52,10 @@ class GitHubRepositoryModel(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class TopologyEngine:
|
class TopologyEngine:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.models: Dict[str, Type[BaseModel]] = {}
|
self.models: Dict[str, Type[BaseModel]] = {}
|
||||||
|
|
||||||
async def populate_model(self, directory_path, file_structure, parent_id=None):
|
async def populate_model(self, directory_path: str, file_structure: Dict[str, Union[Dict, Tuple[str, ...]]], parent_id: Optional[str] = None) -> DirectoryModel:
|
||||||
directory_id = os.path.basename(directory_path) or "root"
|
directory_id = os.path.basename(directory_path) or "root"
|
||||||
directory = DirectoryModel(
|
directory = DirectoryModel(
|
||||||
node_id=directory_id,
|
node_id=directory_id,
|
||||||
|
|
@ -100,18 +81,17 @@ class TopologyEngine:
|
||||||
|
|
||||||
return directory
|
return directory
|
||||||
|
|
||||||
async def infer_from_directory_structure(self, node_id:str, repository: str, model):
|
async def infer_from_directory_structure(self, node_id: str, repository: str, model: Type[BaseModel]) -> GitHubRepositoryModel:
|
||||||
""" Infer the topology of a repository from its file structure """
|
""" Infer the topology of a repository from its file structure """
|
||||||
|
|
||||||
path = infrastructure_config.get_config()["data_root_directory"]
|
path = infrastructure_config.get_config()["data_root_directory"]
|
||||||
|
path = path + "/" + str(repository)
|
||||||
path = path +"/"+ str(repository)
|
|
||||||
print(path)
|
print(path)
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
raise FileNotFoundError(f"No such directory: {path}")
|
raise FileNotFoundError(f"No such directory: {path}")
|
||||||
|
|
||||||
root = {}
|
root: Dict[str, Union[Dict, Tuple[str, ...]]] = {}
|
||||||
for filename in glob.glob(f"{path}/**", recursive=True):
|
for filename in glob.glob(f"{path}/**", recursive=True):
|
||||||
parts = os.path.relpath(filename, start=path).split(os.path.sep)
|
parts = os.path.relpath(filename, start=path).split(os.path.sep)
|
||||||
current = root
|
current = root
|
||||||
|
|
@ -128,8 +108,6 @@ class TopologyEngine:
|
||||||
|
|
||||||
root_directory = await self.populate_model('/', root)
|
root_directory = await self.populate_model('/', root)
|
||||||
|
|
||||||
# repository_metadata = await infer_data_topology(str(root), DirMetadata)
|
|
||||||
|
|
||||||
repository_metadata = DirMetadata(
|
repository_metadata = DirMetadata(
|
||||||
node_id="repo1",
|
node_id="repo1",
|
||||||
summary="Example repository",
|
summary="Example repository",
|
||||||
|
|
@ -147,13 +125,10 @@ class TopologyEngine:
|
||||||
|
|
||||||
return active_model
|
return active_model
|
||||||
|
|
||||||
# print(github_repo_model)
|
def load(self, model_name: str) -> Optional[Type[BaseModel]]:
|
||||||
|
|
||||||
|
|
||||||
def load(self, model_name: str):
|
|
||||||
return self.models.get(model_name)
|
return self.models.get(model_name)
|
||||||
|
|
||||||
def extrapolate(self, model_name: str):
|
def extrapolate(self, model_name: str) -> None:
|
||||||
# This method would be implementation-specific depending on what "extrapolate" means
|
# This method would be implementation-specific depending on what "extrapolate" means
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -164,15 +139,16 @@ if __name__ == "__main__":
|
||||||
config.data_root_directory(data_directory_path)
|
config.data_root_directory(data_directory_path)
|
||||||
cognee_directory_path = os.path.abspath("../.cognee_system")
|
cognee_directory_path = os.path.abspath("../.cognee_system")
|
||||||
config.system_root_directory(cognee_directory_path)
|
config.system_root_directory(cognee_directory_path)
|
||||||
async def main():
|
|
||||||
|
async def main() -> None:
|
||||||
engine = TopologyEngine()
|
engine = TopologyEngine()
|
||||||
# model = engine.load("GitHubRepositoryModel")
|
# model = engine.load("GitHubRepositoryModel")
|
||||||
# if model is None:
|
# if model is None:
|
||||||
# raise ValueError("Model not found")
|
# raise ValueError("Model not found")
|
||||||
result = await engine.infer("example")
|
result = await engine.infer_from_directory_structure("example_node_id", "example_repo", GitHubRepositoryModel)
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
# result = engine.extrapolate("GitHubRepositoryModel")
|
# result = engine.extrapolate("GitHubRepositoryModel")
|
||||||
# print(result)
|
# print(result)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue