feat: improve API request and response models and docs (#154)

* feat: improve API request and response models and docs
This commit is contained in:
Boris 2024-10-14 13:38:36 +02:00 committed by GitHub
parent a29acb6ab7
commit 1eb4429c5c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 236 additions and 154 deletions

View file

@ -57,10 +57,10 @@ To use different LLM providers, for more info check out our <a href="https://top
If you are using Networkx, create an account on Graphistry to visualize results:
```
cognee.config.set_graphistry_config({
"username": "YOUR_USERNAME",
"password": "YOUR_PASSWORD"
})
cognee.config.set_graphistry_config({
"username": "YOUR_USERNAME",
"password": "YOUR_PASSWORD"
})
```
(Optional) To run the UI, go to cognee-frontend directory and run:

View file

@ -82,7 +82,7 @@ export default function Settings({ onDone = () => {}, submitButtonText = 'Save'
},
body: JSON.stringify({
llm: newLLMConfig,
vectorDB: newVectorConfig,
vectorDb: newVectorConfig,
}),
})
.then(() => {
@ -145,7 +145,7 @@ export default function Settings({ onDone = () => {}, submitButtonText = 'Save'
settings.llm.model = settings.llm.models[settings.llm.provider.value][0];
}
setLLMConfig(settings.llm);
setVectorDBConfig(settings.vectorDB);
setVectorDBConfig(settings.vectorDb);
};
fetchConfig();
}, []);

15
cognee/api/DTO.py Normal file
View file

@ -0,0 +1,15 @@
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel, to_snake
class OutDTO(BaseModel):
    """Base class for API response models.

    Every field gets a camelCase alias for serialization, while populating
    the model by the original snake_case field name remains allowed.
    """
    model_config = ConfigDict(populate_by_name = True, alias_generator = to_camel)
class InDTO(BaseModel):
    """Base class for API request models.

    Incoming camelCase keys are accepted via aliases, and snake_case
    field names are accepted as well (populate_by_name).
    """
    model_config = ConfigDict(populate_by_name = True, alias_generator = to_camel)

View file

@ -1,18 +1,23 @@
""" FastAPI server for the Cognee API. """
from datetime import datetime
import os
from uuid import UUID
import aiohttp
import uvicorn
import logging
import sentry_sdk
from typing import Dict, Any, List, Union, Optional, Literal
from typing import List, Union, Optional, Literal
from typing_extensions import Annotated
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query, Depends
from fastapi.responses import JSONResponse, FileResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from cognee.api.DTO import InDTO, OutDTO
from cognee.api.v1.search import SearchType
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.modules.pipelines.models import PipelineRunStatus
# Set up logging
@ -124,6 +129,7 @@ async def root():
"""
return {"message": "Hello, World, I am alive!"}
@app.get("/health")
def health_check():
"""
@ -131,41 +137,46 @@ def health_check():
"""
return Response(status_code = 200)
@app.get("/api/v1/datasets", response_model = list)
class ErrorResponseDTO(BaseModel):
    """Body of documented error responses (e.g. the 404 on dataset endpoints)."""
    # Human-readable error description.
    message: str
class DatasetDTO(OutDTO):
    """Response model for a dataset record; fields are serialized with
    camelCase aliases inherited from OutDTO."""
    id: UUID
    name: str
    created_at: datetime
    # None presumably means the dataset was never updated -- TODO confirm
    updated_at: Optional[datetime]
    owner_id: UUID
@app.get("/api/v1/datasets", response_model = list[DatasetDTO])
async def get_datasets(user: User = Depends(get_authenticated_user)):
try:
from cognee.modules.data.methods import get_datasets
datasets = await get_datasets(user.id)
return JSONResponse(
status_code = 200,
content = [dataset.to_json() for dataset in datasets],
)
return datasets
except Exception as error:
logger.error(f"Error retrieving datasets: {str(error)}")
raise HTTPException(status_code = 500, detail = f"Error retrieving datasets: {str(error)}") from error
@app.delete("/api/v1/datasets/{dataset_id}", response_model = dict)
@app.delete("/api/v1/datasets/{dataset_id}", response_model = None, responses = { 404: { "model": ErrorResponseDTO }})
async def delete_dataset(dataset_id: str, user: User = Depends(get_authenticated_user)):
from cognee.modules.data.methods import get_dataset, delete_dataset
dataset = get_dataset(user.id, dataset_id)
dataset = await get_dataset(user.id, dataset_id)
if dataset is None:
return JSONResponse(
raise HTTPException(
status_code = 404,
content = {
"detail": f"Dataset ({dataset_id}) not found."
}
detail = f"Dataset ({dataset_id}) not found."
)
await delete_dataset(dataset)
return JSONResponse(
status_code = 200,
content = "OK",
)
@app.get("/api/v1/datasets/{dataset_id}/graph", response_model=list)
@app.get("/api/v1/datasets/{dataset_id}/graph", response_model = str)
async def get_dataset_graph(dataset_id: str, user: User = Depends(get_authenticated_user)):
from cognee.shared.utils import render_graph
from cognee.infrastructure.databases.graph import get_graph_engine
@ -184,7 +195,17 @@ async def get_dataset_graph(dataset_id: str, user: User = Depends(get_authentica
content = "Graphistry credentials are not set. Please set them in your .env file.",
)
@app.get("/api/v1/datasets/{dataset_id}/data", response_model=list)
class DataDTO(OutDTO):
    """Response model for one data item inside a dataset; serialized with
    camelCase aliases inherited from OutDTO."""
    id: UUID
    name: str
    created_at: datetime
    # None presumably means the item was never updated -- TODO confirm
    updated_at: Optional[datetime]
    extension: str
    mime_type: str
    # Location of the stored raw file; served back by the /data/{data_id}/raw endpoint.
    raw_data_location: str
@app.get("/api/v1/datasets/{dataset_id}/data", response_model = list[DataDTO], responses = { 404: { "model": ErrorResponseDTO }})
async def get_dataset_data(dataset_id: str, user: User = Depends(get_authenticated_user)):
from cognee.modules.data.methods import get_dataset_data, get_dataset
@ -193,38 +214,33 @@ async def get_dataset_data(dataset_id: str, user: User = Depends(get_authenticat
if dataset is None:
return JSONResponse(
status_code = 404,
content = {
"detail": f"Dataset ({dataset_id}) not found."
}
content = ErrorResponseDTO(f"Dataset ({dataset_id}) not found."),
)
dataset_data = await get_dataset_data(dataset_id = dataset.id)
if dataset_data is None:
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset.id}) not found.")
return []
return [
data.to_json() for data in dataset_data
]
return dataset_data
@app.get("/api/v1/datasets/status", response_model=dict)
@app.get("/api/v1/datasets/status", response_model = dict[str, PipelineRunStatus])
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None, user: User = Depends(get_authenticated_user)):
from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
try:
datasets_statuses = await cognee_datasets.get_status(datasets)
return JSONResponse(
status_code = 200,
content = datasets_statuses,
)
return datasets_statuses
except Exception as error:
return JSONResponse(
status_code = 409,
content = {"error": str(error)}
)
@app.get("/api/v1/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
@app.get("/api/v1/datasets/{dataset_id}/data/{data_id}/raw", response_class = FileResponse)
async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)):
from cognee.modules.data.methods import get_dataset, get_dataset_data
@ -255,13 +271,8 @@ async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_a
return data.raw_data_location
class AddPayload(BaseModel):
data: Union[str, UploadFile, List[Union[str, UploadFile]]]
dataset_id: str
class Config:
arbitrary_types_allowed = True
@app.post("/api/v1/add", response_model=dict)
@app.post("/api/v1/add", response_model = None)
async def add(
data: List[UploadFile],
datasetId: str = Form(...),
@ -297,90 +308,89 @@ async def add(
datasetId,
user = user,
)
return JSONResponse(
status_code = 200,
content = {
"message": "OK"
}
)
except Exception as error:
return JSONResponse(
status_code = 409,
content = {"error": str(error)}
)
class CognifyPayload(BaseModel):
class CognifyPayloadDTO(BaseModel):
datasets: List[str]
@app.post("/api/v1/cognify", response_model=dict)
async def cognify(payload: CognifyPayload, user: User = Depends(get_authenticated_user)):
@app.post("/api/v1/cognify", response_model = None)
async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)):
""" This endpoint is responsible for the cognitive processing of the content."""
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
try:
await cognee_cognify(payload.datasets, user)
return JSONResponse(
status_code = 200,
content = {
"message": "OK"
}
)
except Exception as error:
return JSONResponse(
status_code = 409,
content = {"error": str(error)}
)
class SearchPayload(BaseModel):
searchType: SearchType
class SearchPayloadDTO(InDTO):
search_type: SearchType
query: str
@app.post("/api/v1/search", response_model=list)
async def search(payload: SearchPayload, user: User = Depends(get_authenticated_user)):
@app.post("/api/v1/search", response_model = list)
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
""" This endpoint is responsible for searching for nodes in the graph."""
from cognee.api.v1.search import search as cognee_search
try:
results = await cognee_search(payload.searchType, payload.query, user)
return JSONResponse(
status_code = 200,
content = results,
)
try:
results = await cognee_search(payload.search_type, payload.query, user)
return results
except Exception as error:
return JSONResponse(
status_code = 409,
content = {"error": str(error)}
)
@app.get("/api/v1/settings", response_model=dict)
from cognee.modules.settings.get_settings import LLMConfig, VectorDBConfig
class LLMConfigDTO(OutDTO, LLMConfig):
    """LLM settings shaped for API output: OutDTO's camelCase aliasing
    applied over the fields of settings.LLMConfig."""
    pass
class VectorDBConfigDTO(OutDTO, VectorDBConfig):
    """Vector DB settings shaped for API output: OutDTO's camelCase aliasing
    applied over the fields of settings.VectorDBConfig."""
    pass
class SettingsDTO(OutDTO):
    """Combined settings payload returned by GET /api/v1/settings."""
    llm: LLMConfigDTO
    # Serialized as "vectorDb" via the camelCase alias generator.
    vector_db: VectorDBConfigDTO
@app.get("/api/v1/settings", response_model = SettingsDTO)
async def get_settings(user: User = Depends(get_authenticated_user)):
from cognee.modules.settings import get_settings as get_cognee_settings
return get_cognee_settings()
class LLMConfig(BaseModel):
class LLMConfigDTO(InDTO):
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
model: str
apiKey: str
api_key: str
class VectorDBConfig(BaseModel):
class VectorDBConfigDTO(InDTO):
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
url: str
apiKey: str
api_key: str
class SettingsPayload(BaseModel):
llm: Optional[LLMConfig] = None
vectorDB: Optional[VectorDBConfig] = None
class SettingsPayloadDTO(InDTO):
    """Request body for POST /api/v1/settings.

    Both sections are optional so a client may update the LLM config,
    the vector DB config, or both in a single call.
    """
    llm: Optional[LLMConfigDTO] = None
    vector_db: Optional[VectorDBConfigDTO] = None
@app.post("/api/v1/settings", response_model=dict)
async def save_config(new_settings: SettingsPayload, user: User = Depends(get_authenticated_user)):
@app.post("/api/v1/settings", response_model = None)
async def save_settings(new_settings: SettingsPayloadDTO, user: User = Depends(get_authenticated_user)):
from cognee.modules.settings import save_llm_config, save_vector_db_config
if new_settings.llm is not None:
await save_llm_config(new_settings.llm)
if new_settings.vectorDB is not None:
await save_vector_db_config(new_settings.vectorDB)
return JSONResponse(
status_code=200,
content="OK",
)
if new_settings.vector_db is not None:
await save_vector_db_config(new_settings.vector_db)
def start_api_server(host: str = "0.0.0.0", port: int = 8000):

View file

@ -12,6 +12,7 @@ from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks, run_tasks_parallel
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
from cognee.tasks import chunk_naive_llm_classifier, \
@ -75,11 +76,11 @@ async def run_cognify_pipeline(dataset: Dataset, user: User):
async with update_status_lock:
task_status = await get_pipeline_status([dataset_id])
if dataset_id in task_status and task_status[dataset_id] == "DATASET_PROCESSING_STARTED":
if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
logger.info("Dataset %s is already being processed.", dataset_name)
return
await log_pipeline_status(dataset_id, "DATASET_PROCESSING_STARTED", {
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})
@ -120,14 +121,14 @@ async def run_cognify_pipeline(dataset: Dataset, user: User):
send_telemetry("cognee.cognify EXECUTION COMPLETED", user.id)
await log_pipeline_status(dataset_id, "DATASET_PROCESSING_COMPLETED", {
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})
except Exception as error:
send_telemetry("cognee.cognify EXECUTION ERRORED", user.id)
await log_pipeline_status(dataset_id, "DATASET_PROCESSING_ERRORED", {
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_ERRORED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})

View file

@ -18,10 +18,10 @@ class LLMConfig(BaseSettings):
"provider": self.llm_provider,
"model": self.llm_model,
"endpoint": self.llm_endpoint,
"apiKey": self.llm_api_key,
"api_key": self.llm_api_key,
"temperature": self.llm_temperature,
"streaming": self.llm_stream,
"transcriptionModel": self.transcription_model
"streaming": self.llm_streaming,
"transcription_model": self.transcription_model
}
@lru_cache

View file

@ -1,8 +1,14 @@
import enum
from uuid import uuid4
from datetime import datetime, timezone
from sqlalchemy import Column, DateTime, String, JSON
from sqlalchemy import Column, DateTime, JSON, Enum
from cognee.infrastructure.databases.relational import Base, UUID
class PipelineRunStatus(enum.Enum):
    """Lifecycle states logged for a dataset-processing pipeline run;
    stored in PipelineRun.status."""
    DATASET_PROCESSING_STARTED = "DATASET_PROCESSING_STARTED"
    DATASET_PROCESSING_COMPLETED = "DATASET_PROCESSING_COMPLETED"
    DATASET_PROCESSING_ERRORED = "DATASET_PROCESSING_ERRORED"
class PipelineRun(Base):
__tablename__ = "pipeline_runs"
@ -10,7 +16,7 @@ class PipelineRun(Base):
created_at = Column(DateTime(timezone = True), default = lambda: datetime.now(timezone.utc))
status = Column(String)
status = Column(Enum(PipelineRunStatus))
run_id = Column(UUID, index = True)
run_info = Column(JSON)

View file

@ -1 +1 @@
from .PipelineRun import PipelineRun
from .PipelineRun import PipelineRun, PipelineRunStatus

View file

@ -1,3 +1,3 @@
from .get_settings import get_settings
from .get_settings import get_settings, SettingsDict
from .save_llm_config import save_llm_config
from .save_vector_db_config import save_vector_db_config

View file

@ -1,7 +1,35 @@
from enum import Enum
from pydantic import BaseModel
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.llm import get_llm_config
def get_settings():
class ConfigChoice(BaseModel):
    """A selectable option expressed as a value/label pair
    (e.g. a provider or model choice in the settings payload)."""
    value: str
    label: str
class ModelName(Enum):
    """Known LLM provider identifiers.

    NOTE(review): not referenced elsewhere in the visible code -- confirm usage.
    """
    openai = "openai"
    ollama = "ollama"
    anthropic = "anthropic"
class LLMConfig(BaseModel):
    """LLM section of the settings structure built by get_settings()."""
    # Returned partially masked (trailing characters replaced with "*") by get_settings.
    api_key: str
    model: ConfigChoice
    provider: ConfigChoice
    # Maps a provider name to the models offered for it.
    models: dict[str, list[ConfigChoice]]
    providers: list[ConfigChoice]
class VectorDBConfig(BaseModel):
    """Vector database section of the settings structure built by get_settings()."""
    api_key: str
    url: str
    provider: ConfigChoice
    providers: list[ConfigChoice]
class SettingsDict(BaseModel):
    """Complete settings structure returned by get_settings()."""
    llm: LLMConfig
    vector_db: VectorDBConfig
def get_settings() -> SettingsDict:
llm_config = get_llm_config()
vector_dbs = [{
@ -28,9 +56,7 @@ def get_settings():
"label": "Anthropic",
}]
llm_config = get_llm_config()
return dict(
return SettingsDict.model_validate(dict(
llm = {
"provider": {
"label": llm_config.llm_provider,
@ -40,7 +66,7 @@ def get_settings():
"value": llm_config.llm_model,
"label": llm_config.llm_model,
} if llm_config.llm_model else None,
"apiKey": (llm_config.llm_api_key[:-10] + "**********") if llm_config.llm_api_key else None,
"api_key": (llm_config.llm_api_key[:-10] + "**********") if llm_config.llm_api_key else None,
"providers": llm_providers,
"models": {
"openai": [{
@ -72,13 +98,13 @@ def get_settings():
}]
},
},
vectorDB = {
vector_db = {
"provider": {
"label": vector_config.vector_engine_provider,
"value": vector_config.vector_engine_provider.lower(),
},
"url": vector_config.vector_db_url,
"apiKey": vector_config.vector_db_key,
"options": vector_dbs,
"api_key": vector_config.vector_db_key,
"providers": vector_dbs,
},
)
))

View file

@ -2,7 +2,7 @@ from pydantic import BaseModel
from cognee.infrastructure.llm import get_llm_config
class LLMConfig(BaseModel):
apiKey: str
api_key: str
model: str
provider: str
@ -12,5 +12,5 @@ async def save_llm_config(new_llm_config: LLMConfig):
llm_config.llm_provider = new_llm_config.provider
llm_config.llm_model = new_llm_config.model
if "*****" not in new_llm_config.apiKey and len(new_llm_config.apiKey.strip()) > 0:
llm_config.llm_api_key = new_llm_config.apiKey
if "*****" not in new_llm_config.api_key and len(new_llm_config.api_key.strip()) > 0:
llm_config.llm_api_key = new_llm_config.api_key

View file

@ -4,12 +4,12 @@ from cognee.infrastructure.databases.vector import get_vectordb_config
class VectorDBConfig(BaseModel):
url: str
apiKey: str
api_key: str
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
async def save_vector_db_config(vector_db_config: VectorDBConfig):
vector_config = get_vectordb_config()
vector_config.vector_db_url = vector_db_config.url
vector_config.vector_db_key = vector_db_config.apiKey
vector_config.vector_db_key = vector_db_config.api_key
vector_config.vector_engine_provider = vector_db_config.provider

View file

@ -23,7 +23,10 @@ The base URL for all API requests is determined by the server's deployment envir
**Response**:
```json
{
"message": "Hello, World, I am alive!"
"status": 200,
"body": {
"message": "Hello, World, I am alive!"
}
}
```
@ -37,7 +40,7 @@ The base URL for all API requests is determined by the server's deployment envir
**Response**:
```json
{
"status": "OK"
"status": 200
}
```
@ -50,15 +53,18 @@ The base URL for all API requests is determined by the server's deployment envir
**Response**:
```json
[
{
"id": "dataset_id_1",
"name": "Dataset Name 1",
"description": "Description of Dataset 1",
{
"status": 200,
"body": [
{
"id": "dataset_id_1",
"name": "Dataset Name 1",
"description": "Description of Dataset 1",
...
},
...
},
...
]
]
}
```
### 4. Delete Dataset
@ -74,7 +80,7 @@ The base URL for all API requests is determined by the server's deployment envir
**Response**:
```json
{
"status": "OK"
"status": 200
}
```
@ -105,14 +111,17 @@ The base URL for all API requests is determined by the server's deployment envir
**Response**:
```json
[
{
"data_id": "data_id_1",
"content": "Data content here",
{
"status": 200,
"body": [
{
"data_id": "data_id_1",
"content": "Data content here",
...
},
...
},
...
]
]
}
```
### 7. Get Dataset Status
@ -128,9 +137,12 @@ The base URL for all API requests is determined by the server's deployment envir
**Response**:
```json
{
"dataset_id_1": "Status 1",
"dataset_id_2": "Status 2",
...
"status": 200,
"body": {
"dataset_id_1": "Status 1",
"dataset_id_2": "Status 2",
...
}
}
```
@ -169,7 +181,7 @@ The base URL for all API requests is determined by the server's deployment envir
**Response**:
```json
{
"message": "OK"
"status": 200
}
```
@ -190,7 +202,7 @@ The base URL for all API requests is determined by the server's deployment envir
**Response**:
```json
{
"message": "OK"
"status": 200
}
```
@ -204,7 +216,7 @@ The base URL for all API requests is determined by the server's deployment envir
**Request Body**:
```json
{
"searchType": "INSIGHTS", # Or "SUMMARIES" or "CHUNKS"
"searchType": "INSIGHTS", // Or "SUMMARIES" or "CHUNKS"
"query": "QUERY_TO_MATCH_DATA"
}
```
@ -213,31 +225,40 @@ The base URL for all API requests is determined by the server's deployment envir
For "INSIGHTS" search type:
```json
[[
{ "name": "source_node_name" },
{ "relationship_name": "between_nodes_relationship_name" },
{ "name": "target_node_name" },
]]
{
"status": 200,
"body": [[
{ "name": "source_node_name" },
{ "relationship_name": "between_nodes_relationship_name" },
{ "name": "target_node_name" },
]]
}
```
For "SUMMARIES" search type:
```json
[
{ "text": "summary_text" },
{ "text": "summary_text" },
{ "text": "summary_text" },
...
]
{
"status": 200,
"body": [
{ "text": "summary_text" },
{ "text": "summary_text" },
{ "text": "summary_text" },
...
]
}
```
For "CHUNKS" search type:
```json
[
{ "text": "chunk_text" },
{ "text": "chunk_text" },
{ "text": "chunk_text" },
...
]
{
"status": 200,
"body": [
{ "text": "chunk_text" },
{ "text": "chunk_text" },
{ "text": "chunk_text" },
...
]
}
```
### 12. Get Settings
@ -250,9 +271,12 @@ The base URL for all API requests is determined by the server's deployment envir
**Response**:
```json
{
"llm": {...},
"vectorDB": {...},
...
"status": 200,
"body": {
"llm": {...},
"vectorDB": {...},
...
}
}
```
@ -270,6 +294,6 @@ The base URL for all API requests is determined by the server's deployment envir
**Response**:
```json
{
"status": "OK"
"status": 200
}
```