fix: Secure api v2 (#1060)
<!-- .github/pull_request_template.md -->

## Description

<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
8198118baa
commit
bcd418151a
5 changed files with 38 additions and 12 deletions
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import Form, UploadFile, Depends
|
from fastapi import Form, UploadFile, Depends
|
||||||
|
|
@ -31,8 +32,11 @@ def get_add_router() -> APIRouter:
|
||||||
raise ValueError("Either datasetId or datasetName must be provided.")
|
raise ValueError("Either datasetId or datasetName must be provided.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: Add check if HTTP Requests are enabled before allowing requests and git clone
|
if (
|
||||||
if isinstance(data, str) and data.startswith("http"):
|
isinstance(data, str)
|
||||||
|
and data.startswith("http")
|
||||||
|
and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true")
|
||||||
|
):
|
||||||
if "github" in data:
|
if "github" in data:
|
||||||
# Perform git clone if the URL is from GitHub
|
# Perform git clone if the URL is from GitHub
|
||||||
repo_name = data.split("/")[-1].replace(".git", "")
|
repo_name = data.split("/")[-1].replace(".git", "")
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse, FileResponse
|
||||||
|
|
||||||
from cognee.api.DTO import InDTO, OutDTO
|
from cognee.api.DTO import InDTO, OutDTO
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||||
from cognee.modules.data.methods import create_dataset, get_datasets_by_name
|
from cognee.modules.data.methods import create_dataset, get_datasets_by_name
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.api.v1.delete.exceptions import DataNotFoundError, DatasetNotFoundError
|
from cognee.api.v1.delete.exceptions import DataNotFoundError, DatasetNotFoundError
|
||||||
|
|
@ -177,7 +178,8 @@ def get_datasets_router() -> APIRouter:
|
||||||
async def get_dataset_data(dataset_id: UUID, user: User = Depends(get_authenticated_user)):
|
async def get_dataset_data(dataset_id: UUID, user: User = Depends(get_authenticated_user)):
|
||||||
from cognee.modules.data.methods import get_dataset_data, get_dataset
|
from cognee.modules.data.methods import get_dataset_data, get_dataset
|
||||||
|
|
||||||
dataset = await get_dataset(user.id, dataset_id)
|
# Verify user has permission to read dataset
|
||||||
|
dataset = await get_authorized_existing_datasets([dataset_id], "read", user)
|
||||||
|
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|
@ -185,7 +187,7 @@ def get_datasets_router() -> APIRouter:
|
||||||
content=ErrorResponseDTO(f"Dataset ({str(dataset_id)}) not found."),
|
content=ErrorResponseDTO(f"Dataset ({str(dataset_id)}) not found."),
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_data = await get_dataset_data(dataset_id=dataset.id)
|
dataset_data = await get_dataset_data(dataset_id=dataset[0].id)
|
||||||
|
|
||||||
if dataset_data is None:
|
if dataset_data is None:
|
||||||
return []
|
return []
|
||||||
|
|
@ -200,6 +202,9 @@ def get_datasets_router() -> APIRouter:
|
||||||
from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
|
from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Verify user has permission to read dataset
|
||||||
|
await get_authorized_existing_datasets(datasets, "read", user)
|
||||||
|
|
||||||
datasets_statuses = await cognee_datasets.get_status(datasets)
|
datasets_statuses = await cognee_datasets.get_status(datasets)
|
||||||
|
|
||||||
return datasets_statuses
|
return datasets_statuses
|
||||||
|
|
@ -211,16 +216,17 @@ def get_datasets_router() -> APIRouter:
|
||||||
dataset_id: UUID, data_id: UUID, user: User = Depends(get_authenticated_user)
|
dataset_id: UUID, data_id: UUID, user: User = Depends(get_authenticated_user)
|
||||||
):
|
):
|
||||||
from cognee.modules.data.methods import get_data
|
from cognee.modules.data.methods import get_data
|
||||||
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
from cognee.modules.data.methods import get_dataset_data
|
||||||
|
|
||||||
dataset = await get_dataset(user.id, dataset_id)
|
# Verify user has permission to read dataset
|
||||||
|
dataset = await get_authorized_existing_datasets([dataset_id], "read", user)
|
||||||
|
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=404, content={"detail": f"Dataset ({dataset_id}) not found."}
|
status_code=404, content={"detail": f"Dataset ({dataset_id}) not found."}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_data = await get_dataset_data(dataset.id)
|
dataset_data = await get_dataset_data(dataset[0].id)
|
||||||
|
|
||||||
if dataset_data is None:
|
if dataset_data is None:
|
||||||
raise DataNotFoundError(message=f"No data found in dataset ({dataset_id}).")
|
raise DataNotFoundError(message=f"No data found in dataset ({dataset_id}).")
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
from fastapi import Form, UploadFile, Depends
|
from fastapi import Form, UploadFile, Depends
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
@ -37,8 +38,9 @@ def get_delete_router() -> APIRouter:
|
||||||
# Handle each file in the list
|
# Handle each file in the list
|
||||||
results = []
|
results = []
|
||||||
for file in data:
|
for file in data:
|
||||||
# TODO: Add check if HTTP Requests are enabled before allowing requests and git clone
|
if file.filename.startswith("http") and (
|
||||||
if file.filename.startswith("http"):
|
os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true"
|
||||||
|
):
|
||||||
if "github" in file.filename:
|
if "github" in file.filename:
|
||||||
# For GitHub repos, we need to get the content hash of each file
|
# For GitHub repos, we need to get the content hash of each file
|
||||||
repo_name = file.filename.split("/")[-1].replace(".git", "")
|
repo_name = file.filename.split("/")[-1].replace(".git", "")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,12 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, Depends
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
from uuid import UUID
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.users.methods import get_authenticated_user
|
||||||
|
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
|
from cognee.context_global_variables import set_database_global_context_variables
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -9,11 +15,17 @@ def get_visualize_router() -> APIRouter:
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.get("", response_model=None)
|
@router.get("", response_model=None)
|
||||||
async def visualize():
|
async def visualize(dataset_id: UUID, user: User = Depends(get_authenticated_user)):
|
||||||
"""This endpoint is responsible for adding data to the graph."""
|
"""This endpoint is responsible for adding data to the graph."""
|
||||||
from cognee.api.v1.visualize import visualize_graph
|
from cognee.api.v1.visualize import visualize_graph
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Verify user has permission to read dataset
|
||||||
|
dataset = await get_authorized_existing_datasets([dataset_id], "read", user)
|
||||||
|
|
||||||
|
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
||||||
|
await set_database_global_context_variables(dataset[0].id, dataset[0].owner_id)
|
||||||
|
|
||||||
html_visualization = await visualize_graph()
|
html_visualization = await visualize_graph()
|
||||||
return HTMLResponse(html_visualization)
|
return HTMLResponse(html_visualization)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
from typing import Union, BinaryIO, Any
|
from typing import Union, BinaryIO, Any
|
||||||
|
|
||||||
from cognee.modules.ingestion.exceptions import IngestionError
|
from cognee.modules.ingestion.exceptions import IngestionError
|
||||||
|
|
@ -20,7 +21,8 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any], datase
|
||||||
file_path = data_item
|
file_path = data_item
|
||||||
# data is a file path
|
# data is a file path
|
||||||
elif data_item.startswith("file://") or data_item.startswith("/"):
|
elif data_item.startswith("file://") or data_item.startswith("/"):
|
||||||
# TODO: Add check if ACCEPT_LOCAL_FILE_PATH is enabled, if it's not raise an error
|
if os.getenv("ACCEPT_LOCAL_FILE_PATH", "true").lower() == "false":
|
||||||
|
raise IngestionError(message="Local files are not accepted.")
|
||||||
file_path = data_item.replace("file://", "")
|
file_path = data_item.replace("file://", "")
|
||||||
# data is text
|
# data is text
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue