feat: user authentication in routes (#133)

* feat: require logged in user in routes
This commit is contained in:
Boris 2024-09-08 21:12:49 +02:00 committed by GitHub
parent 22c0dd5b2d
commit e1a0b55a21
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 526 additions and 170 deletions

View file

@ -0,0 +1,16 @@
/* Page shell: a full-viewport-height flex column with no padding. */
.main {
  display: flex;
  /* The earlier `flex-direction: row` was dead: it was overridden by this declaration. */
  flex-direction: column;
  padding: 0;
  min-height: 100vh;
}
/* Centered column (max 440px wide) that grows to fill the free space of its flex parent. */
.authContainer {
  /* Box sizing and placement. */
  width: 100%;
  max-width: 440px;
  margin: 0 auto;
  padding: 24px 0;

  /* Flex behaviour: stretch inside the parent, lay out own children with flex. */
  display: flex;
  flex: 1;
}

View file

@ -0,0 +1,29 @@
import { Spacer, Stack, Text } from 'ohmy-ui';
import { TextLogo } from '@/ui/App';
import Footer from '@/ui/Partials/Footer/Footer';
import styles from './AuthPage.module.css';
import { Divider } from '@/ui/Layout';
import SignInForm from '@/ui/Partials/SignInForm/SignInForm';
/**
 * Sign-in page: logo header, divider, the sign-in form panel and the footer,
 * all rendered inside a single `<main>` element.
 */
export default function AuthPage() {
  // Top bar with the product logo.
  const logoHeader = (
    <Spacer inset vertical="1" horizontal="2">
      <Stack orientation="horizontal" gap="between" align="center">
        <TextLogo width={225} height={64} />
      </Stack>
    </Spacer>
  );

  // Panel holding the page title and the sign-in form.
  const signInPanel = (
    <div className={styles.authContainer}>
      <Stack gap="4" style={{ width: '100%' }}>
        <h1><Text size="large">Sign in</Text></h1>
        <SignInForm />
      </Stack>
    </div>
  );

  const pageFooter = (
    <Spacer inset horizontal="3" wrap>
      <Footer />
    </Spacer>
  );

  return (
    <main className={styles.main}>
      {logoHeader}
      <Divider />
      {signInPanel}
      {pageFooter}
    </main>
  );
}

View file

@ -0,0 +1 @@
export { default } from './AuthPage';

View file

@ -1,5 +1,7 @@
import { fetch } from '@/utils';
export default function cognifyDataset(dataset: { id: string, name: string }) {
return fetch('http://127.0.0.1:8000/cognify', {
return fetch('/v1/cognify', {
method: 'POST',
headers: {
'Content-Type': 'application/json',

View file

@ -1,5 +1,7 @@
import { fetch } from '@/utils';
export default function deleteDataset(dataset: { id: string }) {
return fetch(`http://127.0.0.1:8000/datasets/${dataset.id}`, {
return fetch(`/v1/datasets/${dataset.id}`, {
method: 'DELETE',
})
}

View file

@ -1,4 +1,6 @@
import { fetch } from '@/utils';
export default function getDatasetData(dataset: { id: string }) {
return fetch(`http://127.0.0.1:8000/datasets/${dataset.id}/data`)
return fetch(`/v1/datasets/${dataset.id}/data`)
.then((response) => response.json());
}

View file

@ -1,5 +1,7 @@
import { fetch } from '@/utils';
export default function getExplorationGraphUrl(dataset: { id: string }) {
return fetch(`http://127.0.0.1:8000/datasets/${dataset.id}/graph`)
return fetch(`/v1/datasets/${dataset.id}/graph`)
.then(async (response) => {
if (response.status !== 200) {
throw new Error((await response.text()).replaceAll("\"", ""));

View file

@ -7,8 +7,9 @@ import {
UploadInput,
CloseIcon,
} from "ohmy-ui";
import styles from "./DataView.module.css";
import { fetch } from '@/utils';
import RawDataPreview from './RawDataPreview';
import styles from "./DataView.module.css";
export interface Data {
id: string;
@ -37,7 +38,7 @@ export default function DataView({ datasetId, data, onClose, onDataAdd }: DataVi
const showRawData = useCallback((dataItem: Data) => {
setSelectedData(dataItem);
fetch(`http://127.0.0.1:8000/datasets/${datasetId}/data/${dataItem.id}/raw`)
fetch(`/v1/datasets/${datasetId}/data/${dataItem.id}/raw`)
.then((response) => response.arrayBuffer())
.then(setRawData);

View file

@ -1,3 +1,5 @@
import { fetch } from '@/utils';
export default function addData(dataset: { id: string }, files: File[]) {
const formData = new FormData();
files.forEach((file) => {
@ -5,7 +7,7 @@ export default function addData(dataset: { id: string }, files: File[]) {
})
formData.append('datasetId', dataset.id);
return fetch('http://127.0.0.1:8000/add', {
return fetch('/v1/add', {
method: 'POST',
body: formData,
}).then((response) => response.json());

View file

@ -1,6 +1,7 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import { v4 } from 'uuid';
import { DataFile } from './useData';
import { fetch } from '@/utils';
export interface Dataset {
id: string;
@ -14,7 +15,14 @@ function useDatasets() {
const statusTimeout = useRef<any>(null);
const fetchDatasetStatuses = useCallback((datasets: Dataset[]) => {
fetch(`http://127.0.0.1:8000/datasets/status?dataset=${datasets.map(d => d.id).join('&dataset=')}`)
fetch(
`/v1/datasets/status?dataset=${datasets.map(d => d.id).join('&dataset=')}`,
{
headers: {
Authorization: `Bearer ${localStorage.getItem('access_token')}`,
},
},
)
.then((response) => response.json())
.then((statuses) => setDatasets(
(datasets) => (
@ -65,7 +73,11 @@ function useDatasets() {
}, []);
const fetchDatasets = useCallback(() => {
fetch('http://127.0.0.1:8000/datasets')
fetch('/v1/datasets', {
headers: {
Authorization: `Bearer ${localStorage.getItem('access_token')}`,
},
})
.then((response) => response.json())
.then((datasets) => {
setDatasets(datasets);
@ -75,6 +87,9 @@ function useDatasets() {
} else {
window.location.href = '/wizard';
}
})
.catch((error) => {
console.error('Error fetching datasets:', error);
});
}, [checkDatasetStatuses]);

View file

@ -2,6 +2,7 @@ import { v4 } from 'uuid';
import classNames from 'classnames';
import { useCallback, useState } from 'react';
import { CTAButton, Stack, Text, DropdownSelect, TextArea, useBoolean } from 'ohmy-ui';
import { fetch } from '@/utils';
import styles from './SearchView.module.css';
interface Message {
@ -50,7 +51,7 @@ export default function SearchView() {
},
]);
fetch('http://localhost:8000/search', {
fetch('/v1/search', {
method: 'POST',
headers: {
'Content-Type': 'application/json',

View file

@ -5,13 +5,13 @@ import {
FormGroup,
FormInput,
FormLabel,
H3,
Input,
Spacer,
Stack,
useBoolean,
} from 'ohmy-ui';
import { LoadingIndicator } from '@/ui/App';
import { fetch } from '@/utils';
interface SelectOption {
label: string;
@ -75,7 +75,7 @@ export default function Settings({ onDone = () => {}, submitButtonText = 'Save'
startSaving();
fetch('http://127.0.0.1:8000/settings', {
fetch('/v1/settings', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@ -138,7 +138,7 @@ export default function Settings({ onDone = () => {}, submitButtonText = 'Save'
useEffect(() => {
const fetchConfig = async () => {
const response = await fetch('http://127.0.0.1:8000/settings');
const response = await fetch('/v1/settings');
const settings = await response.json();
if (!settings.llm.model) {

View file

@ -0,0 +1,96 @@
"use client";
import {
CTAButton,
FormGroup,
FormInput,
FormLabel,
Input,
Spacer,
Stack,
Text,
useBoolean,
} from 'ohmy-ui';
import { LoadingIndicator } from '@/ui/App';
import { fetch, handleServerErrors } from '@/utils';
import { useState } from 'react';
/**
 * The sign-in `<form>` element with typed access to its named inputs.
 *
 * The field names match the `name` attributes of the inputs rendered by
 * `SignInForm` ("email" and "password"), which are the ones the submit
 * handler reads.
 */
interface SignInFormPayload extends HTMLFormElement {
  email: HTMLInputElement;
  password: HTMLInputElement;
}
/** Maps backend `detail` error codes to user-facing messages. */
const errorsMap = {
  LOGIN_BAD_CREDENTIALS: 'Invalid username or password',
};

/** Message shown when the backend reports an error code that `errorsMap` does not cover. */
const defaultSignInError = 'Unable to sign in. Please try again.';

/**
 * Email/password sign-in form.
 *
 * On submit it posts the credentials as form data to `/v1/auth/login`, stores the
 * returned `access_token` in `localStorage` and then calls `onSignInSuccess`.
 * On failure an error message is rendered below the submit button.
 *
 * @param onSignInSuccess - Called after the token has been stored; defaults to navigating to `/`.
 * @param submitButtonText - Label of the submit button.
 */
export default function SignInForm({ onSignInSuccess = () => window.location.href = '/', submitButtonText = 'Sign in' }) {
  const {
    value: isSigningIn,
    setTrue: disableSignIn,
    setFalse: enableSignIn,
  } = useBoolean(false);

  const [signInError, setSignInError] = useState<string | null>(null);

  const signIn = (event: React.FormEvent<SignInFormPayload>) => {
    event.preventDefault();
    const formElements = event.currentTarget;

    const authCredentials = new FormData();
    // Backend expects username and password fields
    authCredentials.append("username", formElements.email.value);
    authCredentials.append("password", formElements.password.value);

    setSignInError(null);
    disableSignIn();

    fetch('/v1/auth/login', {
      method: 'POST',
      body: authCredentials,
    })
      // The fetch wrapper already applies handleServerErrors; for a successful
      // response this second pass simply hands the response through.
      .then(handleServerErrors)
      .then(response => response.json())
      .then((bearer) => {
        window.localStorage.setItem('access_token', bearer.access_token);
        onSignInSuccess();
      })
      .catch((error: unknown) => {
        const detail = (error as { detail?: unknown } | null)?.detail;
        // Own-property check so that codes such as "toString" cannot resolve to
        // inherited members, and unknown codes still produce a visible message
        // instead of silently clearing the error.
        const isKnownCode = typeof detail === 'string'
          && Object.prototype.hasOwnProperty.call(errorsMap, detail);

        setSignInError(
          isKnownCode
            ? errorsMap[detail as keyof typeof errorsMap]
            : defaultSignInError,
        );
      })
      .finally(() => enableSignIn());
  };

  return (
    <form onSubmit={signIn} style={{ width: '100%' }}>
      <Stack gap="4" orientation="vertical">
        <Stack gap="4" orientation="vertical">
          <FormGroup orientation="vertical" align="center/" gap="2">
            <FormLabel>Email:</FormLabel>
            <FormInput>
              <Input name="email" type="email" placeholder="Your email address" />
            </FormInput>
          </FormGroup>
          <FormGroup orientation="vertical" align="center/" gap="2">
            <FormLabel>Password:</FormLabel>
            <FormInput>
              <Input name="password" type="password" placeholder="Your password" />
            </FormInput>
          </FormGroup>
        </Stack>

        <Spacer top="2">
          <CTAButton type="submit">
            <Stack gap="2" orientation="horizontal" align="/center">
              {submitButtonText}
              {isSigningIn && <LoadingIndicator />}
            </Stack>
          </CTAButton>
        </Spacer>

        {signInError && (
          <Text>{signInError}</Text>
        )}
      </Stack>
    </form>
  )
}

View file

@ -0,0 +1,12 @@
import handleServerErrors from './handleServerErrors';
/**
 * Project-wide fetch wrapper.
 *
 * Prefixes `url` with the backend API base, attaches the stored access token as a
 * bearer `Authorization` header and pipes the response through `handleServerErrors`.
 *
 * @param url - Path relative to the API base, e.g. `/v1/datasets`.
 * @param options - Standard `fetch` options; any `headers` shape accepted by `fetch` is preserved.
 * @returns The response as settled by `handleServerErrors`.
 */
export default function fetch(url: string, options: RequestInit = {}): Promise<Response> {
  // `Headers` normalises every HeadersInit shape (plain object, tuple array or
  // Headers instance); object spread would drop entries from the latter two.
  const headers = new Headers(options.headers);

  // Only attach the header when a token exists, so "Bearer null" is never sent.
  const accessToken = localStorage.getItem('access_token');
  if (accessToken) {
    headers.set('Authorization', `Bearer ${accessToken}`);
  }

  // `globalThis.fetch` is the platform fetch; this module's own `fetch` shadows the bare name.
  return globalThis.fetch('http://127.0.0.1:8000/api' + url, {
    ...options,
    headers,
  })
    .then(handleServerErrors);
}

View file

@ -0,0 +1,13 @@
/**
 * Turns unsuccessful HTTP responses into promise rejections.
 *
 * - 401: navigates to `/auth` and rejects, so callers' `.catch`/`.finally` handlers still run.
 * - Other non-OK statuses: rejects with the parsed JSON error body; if the body is not
 *   valid JSON, rejects with an `Error` carrying a `detail` string instead.
 * - OK responses: resolves with the untouched response.
 *
 * @param response - The response to inspect.
 * @returns A promise resolving to `response` when it is OK, rejecting otherwise.
 */
export default function handleServerErrors(response: Response): Promise<Response> {
  if (response.status === 401) {
    window.location.href = '/auth';
    // Settle the promise instead of leaving it pending forever.
    return Promise.reject(
      Object.assign(new Error('Unauthorized'), { detail: 'UNAUTHORIZED' }),
    );
  }

  if (!response.ok) {
    return response.json().then(
      (error) => Promise.reject(error),
      // The error body was not JSON: still reject with something that exposes `detail`.
      () => {
        const detail = response.statusText || `HTTP ${response.status}`;
        return Promise.reject(Object.assign(new Error(detail), { detail }));
      },
    );
  }

  return Promise.resolve(response);
}

View file

@ -0,0 +1,2 @@
export { default as fetch } from './fetch';
export { default as handleServerErrors } from './handleServerErrors';

View file

@ -2,15 +2,17 @@
import os
import aiohttp
import uvicorn
import json
import logging
import sentry_sdk
from typing import Dict, Any, List, Union, Optional, Literal
from typing_extensions import Annotated
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query, Depends
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.infrastructure.databases.relational import create_db_and_tables
@ -42,7 +44,6 @@ origins = [
"http://127.0.0.1:3000",
"http://frontend:3000",
"http://localhost:3000",
"http://localhost:3001",
]
app.add_middleware(
@ -58,39 +59,57 @@ from cognee.api.v1.users.routers import get_auth_router, get_register_router,\
from cognee.api.v1.permissions.get_permissions_router import get_permissions_router
from fastapi import Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
@app.exception_handler(RequestValidationError)
async def request_validation_exception_handler(request: Request, exc: RequestValidationError):
if request.url.path == "/api/v1/auth/login":
return JSONResponse(
status_code = 400,
content = {"detail": "LOGIN_BAD_CREDENTIALS"},
)
return JSONResponse(
status_code = 400,
content = jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
)
app.include_router(
get_auth_router(),
prefix = "/auth/jwt",
prefix = "/api/v1/auth",
tags = ["auth"]
)
app.include_router(
get_register_router(),
prefix = "/auth",
prefix = "/api/v1/auth",
tags = ["auth"],
)
app.include_router(
get_reset_password_router(),
prefix = "/auth",
prefix = "/api/v1/auth",
tags = ["auth"],
)
app.include_router(
get_verify_router(),
prefix = "/auth",
prefix = "/api/v1/auth",
tags = ["auth"],
)
app.include_router(
get_users_router(),
prefix = "/users",
prefix = "/api/v1/users",
tags = ["users"],
)
app.include_router(
get_permissions_router(),
prefix = "/permissions",
prefix = "/api/v1/permissions",
tags = ["permissions"],
)
@ -108,31 +127,42 @@ def health_check():
"""
return {"status": "OK"}
@app.get("/datasets", response_model = list)
async def get_datasets():
@app.get("/api/v1/datasets", response_model = list)
async def get_datasets(user: User = Depends(get_authenticated_user)):
try:
from cognee.api.v1.datasets.datasets import datasets
datasets = await datasets.list_datasets()
from cognee.modules.data.methods import get_datasets
datasets = await get_datasets(user.id)
return JSONResponse(
status_code = 200,
content = [dataset.to_json() for dataset in datasets],
)
except Exception as error:
raise HTTPException(status_code = 500, detail=f"Error retrieving datasets: {str(error)}") from error
raise HTTPException(status_code = 500, detail = f"Error retrieving datasets: {str(error)}") from error
@app.delete("/datasets/{dataset_id}", response_model = dict)
async def delete_dataset(dataset_id: str):
from cognee.api.v1.datasets.datasets import datasets
await datasets.delete_dataset(dataset_id)
@app.delete("/api/v1/datasets/{dataset_id}", response_model = dict)
async def delete_dataset(dataset_id: str, user: User = Depends(get_authenticated_user)):
    """Delete a dataset owned by the authenticated user.

    Responds with 404 when `get_dataset` returns no dataset for this user and id,
    otherwise deletes it and responds with 200 "OK".
    """
    # Aliased so the helper does not share a name with this route handler.
    from cognee.modules.data.methods import get_dataset, delete_dataset as delete_dataset_method

    # get_dataset is a coroutine function; without `await` the result is a coroutine
    # object, which is never None, so the 404 branch below could not trigger.
    dataset = await get_dataset(user.id, dataset_id)

    if dataset is None:
        return JSONResponse(
            status_code = 404,
            content = {
                "detail": f"Dataset ({dataset_id}) not found."
            }
        )

    await delete_dataset_method(dataset)

    return JSONResponse(
        status_code = 200,
        content = "OK",
    )
@app.get("/datasets/{dataset_id}/graph", response_model=list)
async def get_dataset_graph(dataset_id: str):
@app.get("/api/v1/datasets/{dataset_id}/graph", response_model=list)
async def get_dataset_graph(dataset_id: str, user: User = Depends(get_authenticated_user)):
from cognee.shared.utils import render_graph
from cognee.infrastructure.databases.graph import get_graph_engine
@ -150,21 +180,31 @@ async def get_dataset_graph(dataset_id: str):
content = "Graphistry credentials are not set. Please set them in your .env file.",
)
@app.get("/datasets/{dataset_id}/data", response_model=list)
async def get_dataset_data(dataset_id: str):
from cognee.api.v1.datasets.datasets import datasets
@app.get("/api/v1/datasets/{dataset_id}/data", response_model=list)
async def get_dataset_data(dataset_id: str, user: User = Depends(get_authenticated_user)):
from cognee.modules.data.methods import get_dataset_data, get_dataset
dataset_data = await datasets.list_data(dataset_id = dataset_id)
dataset = await get_dataset(user.id, dataset_id)
if dataset is None:
return JSONResponse(
status_code = 404,
content = {
"detail": f"Dataset ({dataset_id}) not found."
}
)
dataset_data = await get_dataset_data(dataset_id = dataset.id)
if dataset_data is None:
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset.id}) not found.")
return [
data.to_json() for data in dataset_data
]
@app.get("/datasets/status", response_model=dict)
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None):
@app.get("/api/v1/datasets/status", response_model=dict)
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None, user: User = Depends(get_authenticated_user)):
from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
try:
@ -180,15 +220,35 @@ async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset
content = {"error": str(error)}
)
@app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
async def get_raw_data(dataset_id: str, data_id: str):
from cognee.api.v1.datasets.datasets import datasets
dataset_data = await datasets.list_data(dataset_id)
@app.get("/api/v1/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)):
    """Return the raw data location of one data item in a dataset of the authenticated user.

    Responds with 404 when the dataset is not found for this user, or when no item
    with `data_id` exists in the dataset.
    """
    from cognee.modules.data.methods import get_dataset, get_dataset_data

    dataset = await get_dataset(user.id, dataset_id)

    if dataset is None:
        return JSONResponse(
            status_code = 404,
            content = {
                "detail": f"Dataset ({dataset_id}) not found."
            }
        )

    dataset_data = await get_dataset_data(dataset.id)

    if dataset_data is None:
        raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")

    # `next` with a default yields None for a missing id; indexing `[0]` on the filtered
    # list raised IndexError instead, which made the 404 branch below unreachable.
    data = next((item for item in dataset_data if str(item.id) == data_id), None)

    if data is None:
        return JSONResponse(
            status_code = 404,
            content = {
                "detail": f"Data ({data_id}) not found in dataset ({dataset_id})."
            }
        )

    return data.raw_data_location
class AddPayload(BaseModel):
@ -197,10 +257,11 @@ class AddPayload(BaseModel):
class Config:
arbitrary_types_allowed = True
@app.post("/add", response_model=dict)
@app.post("/api/v1/add", response_model=dict)
async def add(
data: List[UploadFile],
datasetId: str = Form(...),
user: User = Depends(get_authenticated_user),
):
""" This endpoint is responsible for adding data to the graph."""
from cognee.api.v1.add import add as cognee_add
@ -230,6 +291,7 @@ async def add(
await cognee_add(
data,
datasetId,
user = user,
)
return JSONResponse(
status_code = 200,
@ -246,12 +308,12 @@ async def add(
class CognifyPayload(BaseModel):
datasets: List[str]
@app.post("/cognify", response_model=dict)
async def cognify(payload: CognifyPayload):
@app.post("/api/v1/cognify", response_model=dict)
async def cognify(payload: CognifyPayload, user: User = Depends(get_authenticated_user)):
""" This endpoint is responsible for the cognitive processing of the content."""
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
try:
await cognee_cognify(payload.datasets)
await cognee_cognify(payload.datasets, user)
return JSONResponse(
status_code = 200,
content = {
@ -267,8 +329,8 @@ async def cognify(payload: CognifyPayload):
class SearchPayload(BaseModel):
query_params: Dict[str, Any]
@app.post("/search", response_model=dict)
async def search(payload: SearchPayload):
@app.post("/api/v1/search", response_model=dict)
async def search(payload: SearchPayload, user: User = Depends(get_authenticated_user)):
""" This endpoint is responsible for searching for nodes in the graph."""
from cognee.api.v1.search import search as cognee_search
try:
@ -290,8 +352,8 @@ async def search(payload: SearchPayload):
content = {"error": str(error)}
)
@app.get("/settings", response_model=dict)
async def get_settings():
@app.get("/api/v1/settings", response_model=dict)
async def get_settings(user: User = Depends(get_authenticated_user)):
from cognee.modules.settings import get_settings as get_cognee_settings
return get_cognee_settings()
@ -309,8 +371,8 @@ class SettingsPayload(BaseModel):
llm: Optional[LLMConfig] = None
vectorDB: Optional[VectorDBConfig] = None
@app.post("/settings", response_model=dict)
async def save_config(new_settings: SettingsPayload):
@app.post("/api/v1/settings", response_model=dict)
async def save_config(new_settings: SettingsPayload, user: User = Depends(get_authenticated_user)):
from cognee.modules.settings import save_llm_config, save_vector_db_config
if new_settings.llm is not None:
await save_llm_config(new_settings.llm)

View file

@ -10,10 +10,10 @@ from cognee.modules.ingestion import get_matched_datasets, save_data_to_file
from cognee.shared.utils import send_telemetry
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.relational import get_relational_config, get_relational_engine, create_db_and_tables
from cognee.modules.users.methods import create_default_user, get_default_user
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import give_permission_on_document
from cognee.modules.users.models import User
from cognee.modules.data.operations.ensure_dataset_exists import ensure_dataset_exists
from cognee.modules.data.methods import create_dataset
async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
await create_db_and_tables()
@ -55,7 +55,10 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
return []
async def add_files(file_paths: List[str], dataset_name: str, user):
async def add_files(file_paths: List[str], dataset_name: str, user: User = None):
if user is None:
user = await get_default_user()
base_config = get_base_config()
data_directory_path = base_config.data_root_directory
@ -101,7 +104,6 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
)
dataset_name = dataset_name.replace(" ", "_").replace(".", "_") if dataset_name is not None else "main_dataset"
dataset = await ensure_dataset_exists(dataset_name)
@dlt.resource(standalone = True, merge_key = "id")
async def data_resources(file_paths: str, user: User):
@ -115,8 +117,12 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
from sqlalchemy import select
from cognee.modules.data.models import Data
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
dataset = await create_dataset(dataset_name, user.id, session)
data = (await session.execute(
select(Data).filter(Data.id == data_id)
)).scalar_one_or_none()
@ -137,10 +143,8 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
extension = file_metadata["extension"],
mime_type = file_metadata["mime_type"],
)
dataset.data.append(data)
await session.merge(dataset)
await session.commit()
yield {
@ -155,12 +159,6 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
await give_permission_on_document(user, data_id, "write")
if user is None:
user = await get_default_user()
if user is None:
user = await create_default_user()
run_info = pipeline.run(
data_resources(processed_file_paths, user),
table_name = "file_metadata",

View file

@ -3,11 +3,10 @@ import logging
from typing import Union
from cognee.modules.cognify.config import get_cognify_config
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.data.models import Dataset, Data
from cognee.modules.data.operations.get_dataset_data import get_dataset_data
from cognee.modules.data.operations.retrieve_datasets import retrieve_datasets
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks, run_tasks_parallel
from cognee.modules.users.models import User
@ -35,17 +34,18 @@ class PermissionDeniedException(Exception):
super().__init__(self.message)
async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
db_engine = get_relational_engine()
if datasets is None or len(datasets) == 0:
return await cognify(await db_engine.get_datasets())
if type(datasets[0]) == str:
datasets = await retrieve_datasets(datasets)
if user is None:
user = await get_default_user()
existing_datasets = await get_datasets(user.id)
if datasets is None or len(datasets) == 0:
# If no datasets are provided, cognify all existing datasets.
datasets = existing_datasets
if type(datasets[0]) == str:
datasets = await get_datasets_by_name(datasets, user.id)
async def run_cognify_pipeline(dataset: Dataset):
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
@ -112,13 +112,16 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
raise error
existing_datasets = [dataset.name for dataset in list(await db_engine.get_datasets())]
existing_datasets_map = {
generate_dataset_name(dataset.name): True for dataset in existing_datasets
}
awaitables = []
for dataset in datasets:
dataset_name = generate_dataset_name(dataset.name)
if dataset_name in existing_datasets:
if dataset_name in existing_datasets_map:
awaitables.append(run_cognify_pipeline(dataset))
return await asyncio.gather(*awaitables)

View file

@ -1,37 +1,34 @@
from duckdb import CatalogException
from cognee.modules.users.methods import get_default_user
from cognee.modules.ingestion import discover_directory_datasets
from cognee.modules.data.operations.get_dataset_data import get_dataset_data
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.infrastructure.databases.relational import get_relational_engine
class datasets():
@staticmethod
async def list_datasets():
db = get_relational_engine()
return await db.get_datasets()
from cognee.modules.data.methods import get_datasets
user = await get_default_user()
return await get_datasets(user.id)
@staticmethod
def discover_datasets(directory_path: str):
return list(discover_directory_datasets(directory_path).keys())
@staticmethod
async def list_data(dataset_id: str, dataset_name: str = None):
try:
return await get_dataset_data(dataset_id = dataset_id, dataset_name = dataset_name)
except CatalogException:
return None
async def list_data(dataset_id: str):
from cognee.modules.data.methods import get_dataset, get_dataset_data
user = await get_default_user()
dataset = await get_dataset(user.id, dataset_id)
return await get_dataset_data(dataset.id)
@staticmethod
async def get_status(dataset_ids: list[str]) -> dict:
try:
return await get_pipeline_status(dataset_ids)
except CatalogException:
return {}
return await get_pipeline_status(dataset_ids)
@staticmethod
async def delete_dataset(dataset_id: str):
db = get_relational_engine()
try:
return await db.delete_table(dataset_id)
except CatalogException:
return {}
from cognee.modules.data.methods import get_dataset, delete_dataset
user = await get_default_user()
dataset = await get_dataset(user.id, dataset_id)
return await delete_dataset(dataset)

View file

@ -228,11 +228,21 @@ class NetworkXAdapter(GraphDBInterface):
# Log that the file does not exist and an empty graph is initialized
logger.warning("File %s not found. Initializing an empty graph.", file_path)
self.graph = nx.MultiDiGraph() # Use MultiDiGraph to keep it consistent with __init__
file_dir = os.path.dirname(file_path)
if not os.path.exists(file_dir):
os.makedirs(file_dir, exist_ok = True)
await self.save_graph_to_file(file_path)
except Exception as error:
except Exception:
logger.error("Failed to load graph from file: %s", file_path)
# Initialize an empty graph in case of error
self.graph = nx.MultiDiGraph()
file_dir = os.path.dirname(file_path)
if not os.path.exists(file_dir):
os.makedirs(file_dir, exist_ok = True)
await self.save_graph_to_file(file_path)
async def delete_graph(self, file_path: str = None):

View file

@ -0,0 +1,11 @@
# Create
from .create_dataset import create_dataset
# Get
from .get_dataset import get_dataset
from .get_datasets import get_datasets
from .get_datasets_by_name import get_datasets_by_name
from .get_dataset_data import get_dataset_data
# Delete
from .delete_dataset import delete_dataset

View file

@ -0,0 +1,27 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from cognee.modules.data.models import Dataset
async def create_dataset(dataset_name: str, owner_id: UUID, session: AsyncSession) -> Dataset:
    """Return the dataset named `dataset_name` owned by `owner_id`, creating it if missing.

    Args:
        dataset_name: Name of the dataset to look up or create.
        owner_id: Id of the owning user; the lookup is scoped to this owner.
        session: Open async session used for the lookup and, when a dataset is created,
            for the insert and commit.

    Returns:
        The existing or newly created dataset.
    """
    dataset = (await session.scalars(
        select(Dataset)
        .options(joinedload(Dataset.data))
        .filter(Dataset.name == dataset_name)
        .filter(Dataset.owner_id == owner_id)
    )).first()

    if dataset is None:
        dataset = Dataset(
            # The id is derived from name AND owner. Deriving it from the name alone gave
            # two users with same-named datasets the same primary key, although the
            # lookup above treats them as distinct datasets.
            id = uuid5(NAMESPACE_OID, f"{dataset_name}{owner_id}"),
            name = dataset_name,
            data = [],
        )
        dataset.owner_id = owner_id

        session.add(dataset)

        await session.commit()

    return dataset

View file

@ -0,0 +1,7 @@
from cognee.modules.data.models import Dataset
from cognee.infrastructure.databases.relational import get_relational_engine
async def delete_dataset(dataset: Dataset):
    """Call the relational engine's `delete_table` with the dataset's id and return its result."""
    return await get_relational_engine().delete_table(dataset.id)

View file

@ -0,0 +1,14 @@
from typing import Optional
from uuid import UUID
from cognee.infrastructure.databases.relational import get_relational_engine
from ..models import Dataset
async def get_dataset(user_id: UUID, dataset_id: UUID) -> Optional[Dataset]:
    """Fetch a dataset by id, but only if it belongs to `user_id`.

    Args:
        user_id: Id of the user who must own the dataset.
        dataset_id: Id of the dataset. A string is also accepted and parsed as a UUID,
            because the API routes pass the raw path parameter.

    Returns:
        The dataset, or None when it does not exist, is owned by another user,
        or `dataset_id` is not a valid UUID string.
    """
    if isinstance(dataset_id, str):
        try:
            dataset_id = UUID(dataset_id)
        except ValueError:
            # A malformed id cannot match any dataset.
            return None

    db_engine = get_relational_engine()

    async with db_engine.get_async_session() as session:
        dataset = await session.get(Dataset, dataset_id)

        # Hide datasets owned by someone else behind the same "not found" result.
        if dataset is None or dataset.owner_id != user_id:
            return None

        return dataset

View file

@ -3,16 +3,14 @@ from sqlalchemy import select
from cognee.modules.data.models import Data, Dataset
from cognee.infrastructure.databases.relational import get_relational_engine
async def get_dataset_data(dataset_id: UUID = None, dataset_name: str = None):
if dataset_id is None and dataset_name is None:
raise ValueError("get_dataset_data: Either dataset_id or dataset_name must be provided.")
async def get_dataset_data(dataset_id: UUID) -> list[Data]:
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
result = await session.execute(
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id) | (Dataset.name == dataset_name))
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id))
)
data = result.scalars().all()
return data

View file

@ -1,13 +1,14 @@
from uuid import UUID
from sqlalchemy import select
from cognee.infrastructure.databases.relational import get_relational_engine
from ..models import Dataset
async def retrieve_datasets(dataset_names: list[str]) -> list[Dataset]:
async def get_datasets(user_id: UUID) -> list[Dataset]:
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
datasets = (await session.scalars(
select(Dataset).filter(Dataset.name.in_(dataset_names))
select(Dataset).filter(Dataset.owner_id == user_id)
)).all()
return datasets

View file

@ -0,0 +1,16 @@
from uuid import UUID
from sqlalchemy import select
from cognee.infrastructure.databases.relational import get_relational_engine
from ..models import Dataset
async def get_datasets_by_name(dataset_names: list[str], user_id: UUID) -> list[Dataset]:
    """Return the datasets owned by `user_id` whose name is one of `dataset_names`."""
    # Build the statement up front; it does not depend on the session.
    owned_and_named = (
        select(Dataset)
        .filter(Dataset.owner_id == user_id)
        .filter(Dataset.name.in_(dataset_names))
    )

    async with get_relational_engine().get_async_session() as session:
        matches = await session.scalars(owned_and_named)
        return matches.all()

View file

@ -21,7 +21,8 @@ class Data(Base):
datasets: Mapped[List["Dataset"]] = relationship(
secondary = DatasetData.__tablename__,
back_populates = "data"
back_populates = "data",
lazy = "noload",
)
def to_json(self) -> dict:

View file

@ -16,9 +16,12 @@ class Dataset(Base):
created_at = Column(DateTime(timezone = True), default = lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone = True), onupdate = lambda: datetime.now(timezone.utc))
owner_id = Column(UUID, index = True)
data: Mapped[List["Data"]] = relationship(
secondary = DatasetData.__tablename__,
back_populates = "datasets"
back_populates = "datasets",
lazy = "noload",
)
def to_json(self) -> dict:
@ -27,5 +30,6 @@ class Dataset(Base):
"name": self.name,
"createdAt": self.created_at.isoformat(),
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
"ownerId": str(self.owner_id),
"data": [data.to_json() for data in self.data]
}

View file

@ -1,26 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from cognee.modules.data.models import Dataset
from cognee.infrastructure.databases.relational import get_relational_engine
async def ensure_dataset_exists(dataset_name: str) -> Dataset:
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
dataset = (await session.scalars(
select(Dataset)\
.options(joinedload(Dataset.data))\
.filter(Dataset.name == dataset_name)
)).first()
if dataset is None:
dataset = Dataset(
name = dataset_name,
data = []
)
session.add(dataset)
await session.commit()
return dataset

View file

@ -1,9 +1,16 @@
# from fastapi import Depends
from typing import AsyncGenerator
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_users.db import SQLAlchemyUserDatabase
# from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from .models.User import User
async def get_user_db(session):
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Yield a session opened through the relational engine.

    The session comes from the engine's ``get_async_session()`` context
    manager, so the context is exited once the consumer is done with the
    generator.
    """
    async with get_relational_engine().get_async_session() as db_session:
        yield db_session
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
    """
    Yield a ``SQLAlchemyUserDatabase`` bound to ``session`` and the ``User`` model.

    :param session: Async session; defaults to the ``get_async_session`` dependency.
    """
    user_db = SQLAlchemyUserDatabase(session, User)
    yield user_db
from contextlib import asynccontextmanager

View file

@ -2,16 +2,33 @@ import os
import uuid
from typing import Optional
from fastapi import Depends, Request
from fastapi_users import BaseUserManager, UUIDIDMixin
from fastapi_users.exceptions import UserNotExists
from fastapi_users import BaseUserManager, UUIDIDMixin, models
from fastapi_users.db import SQLAlchemyUserDatabase
from .get_user_db import get_user_db
from .models import User
from .methods import get_user
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = os.getenv("FASTAPI_USERS_RESET_PASSWORD_TOKEN_SECRET", "super_secret")
verification_token_secret = os.getenv("FASTAPI_USERS_VERIFICATION_TOKEN_SECRET", "super_secret")
async def get(self, id: models.ID) -> models.UP:
    """
    Fetch a user by id through the project's ``get_user`` helper.

    :param id: Identifier of the user to retrieve.
    :raises UserNotExists: ``get_user`` returned ``None`` for this id.
    :return: The user returned by ``get_user``.
    """
    found = await get_user(id)
    if found is not None:
        return found

    raise UserNotExists()
async def on_after_register(self, user: User, request: Optional[Request] = None):
print(f"User {user.id} has registered.")

View file

@ -1,3 +1,5 @@
from .get_user import get_user
from .create_user import create_user
from .get_default_user import get_default_user
from .create_default_user import create_default_user
from .get_authenticated_user import get_authenticated_user

View file

@ -1,4 +1,3 @@
import hashlib
from .create_user import create_user
async def create_default_user():
@ -7,14 +6,11 @@ async def create_default_user():
user = await create_user(
email = default_user_email,
password = await hash_password(default_user_password),
is_superuser = True,
password = default_user_password,
is_superuser = False,
is_active = True,
is_verified = True,
auto_login = True,
)
return user
async def hash_password(password: str) -> str:
    """
    Return the hexadecimal SHA-256 digest of ``password``.

    The string is encoded with ``str.encode()``'s default codec before hashing.

    NOTE(review): a single unsalted SHA-256 pass is a plain digest, not a
    password-hashing scheme; do not rely on it for credential storage.
    """
    hasher = hashlib.sha256()
    hasher.update(password.encode())
    return hasher.hexdigest()

View file

@ -0,0 +1,5 @@
from ..get_fastapi_users import get_fastapi_users
# Module-level instance returned by the project's get_fastapi_users() factory.
fastapi_users = get_fastapi_users()
# Callable built by current_user() with active = True and verified = True.
# NOTE(review): presumably injected into routes to require a logged-in, active,
# verified user — confirm against the fastapi-users documentation.
get_authenticated_user = fastapi_users.current_user(active = True, verified = True)

View file

@ -2,6 +2,7 @@ from sqlalchemy.orm import joinedload
from sqlalchemy.future import select
from cognee.modules.users.models import User
from cognee.infrastructure.databases.relational import get_relational_engine
from .create_default_user import create_default_user
async def get_default_user():
db_engine = get_relational_engine()
@ -13,4 +14,7 @@ async def get_default_user():
result = await session.execute(query)
user = result.scalars().first()
if user is None:
return await create_default_user()
return user

View file

@ -0,0 +1,15 @@
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from cognee.infrastructure.databases.relational import get_relational_engine
from ..models import User
async def get_user(user_id: UUID):
    """
    Load the user whose id equals ``user_id``, eagerly joining ``User.groups``.

    :param user_id: Id to match against ``User.id``.
    :return: Whatever ``Result.scalar()`` produces for the query; callers in
        this codebase treat ``None`` as "no such user".
    """
    query = (
        select(User)
        .options(joinedload(User.groups))
        .where(User.id == user_id)
    )

    async with get_relational_engine().get_async_session() as session:
        result = await session.execute(query)
        return result.scalar()

View file

@ -16,25 +16,21 @@ class PermissionDeniedException(Exception):
async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]):
try:
user_group_ids = [group.id for group in user.groups]
user_group_ids = [group.id for group in user.groups]
db_engine = get_relational_engine()
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
result = await session.execute(
select(ACL)
.join(ACL.permission)
.options(joinedload(ACL.resources))
.where(ACL.principal_id.in_([user.id, *user_group_ids]))
.where(ACL.permission.has(name = permission_type))
)
acls = result.unique().scalars().all()
resource_ids = [resource.resource_id for acl in acls for resource in acl.resources]
has_permissions = all(document_id in resource_ids for document_id in document_ids)
async with db_engine.get_async_session() as session:
result = await session.execute(
select(ACL)
.join(ACL.permission)
.options(joinedload(ACL.resources))
.where(ACL.principal_id.in_([user.id, *user_group_ids]))
.where(ACL.permission.has(name = permission_type))
)
acls = result.unique().scalars().all()
resource_ids = [resource.resource_id for acl in acls for resource in acl.resources]
has_permissions = all(document_id in resource_ids for document_id in document_ids)
if not has_permissions:
raise PermissionDeniedException(f"User {user.username} does not have {permission_type} permission on documents")
except Exception as error:
logger.error("Error checking permissions on documents: %s", str(error))
raise
if not has_permissions:
raise PermissionDeniedException(f"User {user.username} does not have {permission_type} permission on documents")

View file

@ -1,11 +1,6 @@
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
# from cognee.infrastructure.databases.vector import get_vector_engine
async def chunk_remove_disconnected(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
graph_engine = await get_graph_engine()