feat: user authentication in routes (#133)
* feat: require logged in user in routes
This commit is contained in:
parent
22c0dd5b2d
commit
e1a0b55a21
40 changed files with 526 additions and 170 deletions
16
cognee-frontend/src/app/auth/AuthPage.module.css
Normal file
16
cognee-frontend/src/app/auth/AuthPage.module.css
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
.main {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
flex-direction: column;
|
||||
padding: 0;
|
||||
min-height: 100vh;
|
||||
}
|
||||
|
||||
.authContainer {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
padding: 24px 0;
|
||||
margin: 0 auto;
|
||||
max-width: 440px;
|
||||
width: 100%;
|
||||
}
|
||||
29
cognee-frontend/src/app/auth/AuthPage.tsx
Normal file
29
cognee-frontend/src/app/auth/AuthPage.tsx
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
import { Spacer, Stack, Text } from 'ohmy-ui';
|
||||
import { TextLogo } from '@/ui/App';
|
||||
import Footer from '@/ui/Partials/Footer/Footer';
|
||||
|
||||
import styles from './AuthPage.module.css';
|
||||
import { Divider } from '@/ui/Layout';
|
||||
import SignInForm from '@/ui/Partials/SignInForm/SignInForm';
|
||||
|
||||
export default function AuthPage() {
|
||||
return (
|
||||
<main className={styles.main}>
|
||||
<Spacer inset vertical="1" horizontal="2">
|
||||
<Stack orientation="horizontal" gap="between" align="center">
|
||||
<TextLogo width={225} height={64} />
|
||||
</Stack>
|
||||
</Spacer>
|
||||
<Divider />
|
||||
<div className={styles.authContainer}>
|
||||
<Stack gap="4" style={{ width: '100%' }}>
|
||||
<h1><Text size="large">Sign in</Text></h1>
|
||||
<SignInForm />
|
||||
</Stack>
|
||||
</div>
|
||||
<Spacer inset horizontal="3" wrap>
|
||||
<Footer />
|
||||
</Spacer>
|
||||
</main>
|
||||
)
|
||||
}
|
||||
1
cognee-frontend/src/app/auth/page.tsx
Normal file
1
cognee-frontend/src/app/auth/page.tsx
Normal file
|
|
@ -0,0 +1 @@
|
|||
export { default } from './AuthPage';
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
import { fetch } from '@/utils';
|
||||
|
||||
export default function cognifyDataset(dataset: { id: string, name: string }) {
|
||||
return fetch('http://127.0.0.1:8000/cognify', {
|
||||
return fetch('/v1/cognify', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import { fetch } from '@/utils';
|
||||
|
||||
export default function deleteDataset(dataset: { id: string }) {
|
||||
return fetch(`http://127.0.0.1:8000/datasets/${dataset.id}`, {
|
||||
return fetch(`/v1/datasets/${dataset.id}`, {
|
||||
method: 'DELETE',
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import { fetch } from '@/utils';
|
||||
|
||||
export default function getDatasetData(dataset: { id: string }) {
|
||||
return fetch(`http://127.0.0.1:8000/datasets/${dataset.id}/data`)
|
||||
return fetch(`/v1/datasets/${dataset.id}/data`)
|
||||
.then((response) => response.json());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import { fetch } from '@/utils';
|
||||
|
||||
export default function getExplorationGraphUrl(dataset: { id: string }) {
|
||||
return fetch(`http://127.0.0.1:8000/datasets/${dataset.id}/graph`)
|
||||
return fetch(`/v1/datasets/${dataset.id}/graph`)
|
||||
.then(async (response) => {
|
||||
if (response.status !== 200) {
|
||||
throw new Error((await response.text()).replaceAll("\"", ""));
|
||||
|
|
|
|||
|
|
@ -7,8 +7,9 @@ import {
|
|||
UploadInput,
|
||||
CloseIcon,
|
||||
} from "ohmy-ui";
|
||||
import styles from "./DataView.module.css";
|
||||
import { fetch } from '@/utils';
|
||||
import RawDataPreview from './RawDataPreview';
|
||||
import styles from "./DataView.module.css";
|
||||
|
||||
export interface Data {
|
||||
id: string;
|
||||
|
|
@ -37,7 +38,7 @@ export default function DataView({ datasetId, data, onClose, onDataAdd }: DataVi
|
|||
const showRawData = useCallback((dataItem: Data) => {
|
||||
setSelectedData(dataItem);
|
||||
|
||||
fetch(`http://127.0.0.1:8000/datasets/${datasetId}/data/${dataItem.id}/raw`)
|
||||
fetch(`/v1/datasets/${datasetId}/data/${dataItem.id}/raw`)
|
||||
.then((response) => response.arrayBuffer())
|
||||
.then(setRawData);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import { fetch } from '@/utils';
|
||||
|
||||
export default function addData(dataset: { id: string }, files: File[]) {
|
||||
const formData = new FormData();
|
||||
files.forEach((file) => {
|
||||
|
|
@ -5,7 +7,7 @@ export default function addData(dataset: { id: string }, files: File[]) {
|
|||
})
|
||||
formData.append('datasetId', dataset.id);
|
||||
|
||||
return fetch('http://127.0.0.1:8000/add', {
|
||||
return fetch('/v1/add', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
}).then((response) => response.json());
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { v4 } from 'uuid';
|
||||
import { DataFile } from './useData';
|
||||
import { fetch } from '@/utils';
|
||||
|
||||
export interface Dataset {
|
||||
id: string;
|
||||
|
|
@ -14,7 +15,14 @@ function useDatasets() {
|
|||
const statusTimeout = useRef<any>(null);
|
||||
|
||||
const fetchDatasetStatuses = useCallback((datasets: Dataset[]) => {
|
||||
fetch(`http://127.0.0.1:8000/datasets/status?dataset=${datasets.map(d => d.id).join('&dataset=')}`)
|
||||
fetch(
|
||||
`/v1/datasets/status?dataset=${datasets.map(d => d.id).join('&dataset=')}`,
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${localStorage.getItem('access_token')}`,
|
||||
},
|
||||
},
|
||||
)
|
||||
.then((response) => response.json())
|
||||
.then((statuses) => setDatasets(
|
||||
(datasets) => (
|
||||
|
|
@ -65,7 +73,11 @@ function useDatasets() {
|
|||
}, []);
|
||||
|
||||
const fetchDatasets = useCallback(() => {
|
||||
fetch('http://127.0.0.1:8000/datasets')
|
||||
fetch('/v1/datasets', {
|
||||
headers: {
|
||||
Authorization: `Bearer ${localStorage.getItem('access_token')}`,
|
||||
},
|
||||
})
|
||||
.then((response) => response.json())
|
||||
.then((datasets) => {
|
||||
setDatasets(datasets);
|
||||
|
|
@ -75,6 +87,9 @@ function useDatasets() {
|
|||
} else {
|
||||
window.location.href = '/wizard';
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Error fetching datasets:', error);
|
||||
});
|
||||
}, [checkDatasetStatuses]);
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import { v4 } from 'uuid';
|
|||
import classNames from 'classnames';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { CTAButton, Stack, Text, DropdownSelect, TextArea, useBoolean } from 'ohmy-ui';
|
||||
import { fetch } from '@/utils';
|
||||
import styles from './SearchView.module.css';
|
||||
|
||||
interface Message {
|
||||
|
|
@ -50,7 +51,7 @@ export default function SearchView() {
|
|||
},
|
||||
]);
|
||||
|
||||
fetch('http://localhost:8000/search', {
|
||||
fetch('/v1/search', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
|
|
|||
|
|
@ -5,13 +5,13 @@ import {
|
|||
FormGroup,
|
||||
FormInput,
|
||||
FormLabel,
|
||||
H3,
|
||||
Input,
|
||||
Spacer,
|
||||
Stack,
|
||||
useBoolean,
|
||||
} from 'ohmy-ui';
|
||||
import { LoadingIndicator } from '@/ui/App';
|
||||
import { fetch } from '@/utils';
|
||||
|
||||
interface SelectOption {
|
||||
label: string;
|
||||
|
|
@ -75,7 +75,7 @@ export default function Settings({ onDone = () => {}, submitButtonText = 'Save'
|
|||
|
||||
startSaving();
|
||||
|
||||
fetch('http://127.0.0.1:8000/settings', {
|
||||
fetch('/v1/settings', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
|
@ -138,7 +138,7 @@ export default function Settings({ onDone = () => {}, submitButtonText = 'Save'
|
|||
|
||||
useEffect(() => {
|
||||
const fetchConfig = async () => {
|
||||
const response = await fetch('http://127.0.0.1:8000/settings');
|
||||
const response = await fetch('/v1/settings');
|
||||
const settings = await response.json();
|
||||
|
||||
if (!settings.llm.model) {
|
||||
|
|
|
|||
96
cognee-frontend/src/ui/Partials/SignInForm/SignInForm.tsx
Normal file
96
cognee-frontend/src/ui/Partials/SignInForm/SignInForm.tsx
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"use client";
|
||||
|
||||
import {
|
||||
CTAButton,
|
||||
FormGroup,
|
||||
FormInput,
|
||||
FormLabel,
|
||||
Input,
|
||||
Spacer,
|
||||
Stack,
|
||||
Text,
|
||||
useBoolean,
|
||||
} from 'ohmy-ui';
|
||||
import { LoadingIndicator } from '@/ui/App';
|
||||
import { fetch, handleServerErrors } from '@/utils';
|
||||
import { useState } from 'react';
|
||||
|
||||
interface SignInFormPayload extends HTMLFormElement {
|
||||
vectorDBUrl: HTMLInputElement;
|
||||
vectorDBApiKey: HTMLInputElement;
|
||||
llmApiKey: HTMLInputElement;
|
||||
}
|
||||
|
||||
const errorsMap = {
|
||||
LOGIN_BAD_CREDENTIALS: 'Invalid username or password',
|
||||
};
|
||||
|
||||
export default function SignInForm({ onSignInSuccess = () => window.location.href = '/', submitButtonText = 'Sign in' }) {
|
||||
const {
|
||||
value: isSigningIn,
|
||||
setTrue: disableSignIn,
|
||||
setFalse: enableSignIn,
|
||||
} = useBoolean(false);
|
||||
|
||||
const [signInError, setSignInError] = useState<string | null>(null);
|
||||
|
||||
const signIn = (event: React.FormEvent<SignInFormPayload>) => {
|
||||
event.preventDefault();
|
||||
const formElements = event.currentTarget;
|
||||
|
||||
const authCredentials = new FormData();
|
||||
// Backend expects username and password fields
|
||||
authCredentials.append("username", formElements.email.value);
|
||||
authCredentials.append("password", formElements.password.value);
|
||||
|
||||
setSignInError(null);
|
||||
disableSignIn();
|
||||
|
||||
fetch('/v1/auth/login', {
|
||||
method: 'POST',
|
||||
body: authCredentials,
|
||||
})
|
||||
.then(handleServerErrors)
|
||||
.then(response => response.json())
|
||||
.then((bearer) => {
|
||||
window.localStorage.setItem('access_token', bearer.access_token);
|
||||
onSignInSuccess();
|
||||
})
|
||||
.catch(error => setSignInError(errorsMap[error.detail as keyof typeof errorsMap]))
|
||||
.finally(() => enableSignIn());
|
||||
};
|
||||
|
||||
return (
|
||||
<form onSubmit={signIn} style={{ width: '100%' }}>
|
||||
<Stack gap="4" orientation="vertical">
|
||||
<Stack gap="4" orientation="vertical">
|
||||
<FormGroup orientation="vertical" align="center/" gap="2">
|
||||
<FormLabel>Email:</FormLabel>
|
||||
<FormInput>
|
||||
<Input name="email" type="email" placeholder="Your email address" />
|
||||
</FormInput>
|
||||
</FormGroup>
|
||||
<FormGroup orientation="vertical" align="center/" gap="2">
|
||||
<FormLabel>Password:</FormLabel>
|
||||
<FormInput>
|
||||
<Input name="password" type="password" placeholder="Your password" />
|
||||
</FormInput>
|
||||
</FormGroup>
|
||||
</Stack>
|
||||
|
||||
<Spacer top="2">
|
||||
<CTAButton type="submit">
|
||||
<Stack gap="2" orientation="horizontal" align="/center">
|
||||
{submitButtonText}
|
||||
{isSigningIn && <LoadingIndicator />}
|
||||
</Stack>
|
||||
</CTAButton>
|
||||
</Spacer>
|
||||
|
||||
{signInError && (
|
||||
<Text>{signInError}</Text>
|
||||
)}
|
||||
</Stack>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
12
cognee-frontend/src/utils/fetch.ts
Normal file
12
cognee-frontend/src/utils/fetch.ts
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
import handleServerErrors from './handleServerErrors';
|
||||
|
||||
export default function fetch(url: string, options: RequestInit = {}): Promise<Response> {
|
||||
return global.fetch('http://127.0.0.1:8000/api' + url, {
|
||||
...options,
|
||||
headers: {
|
||||
...options.headers,
|
||||
'Authorization': `Bearer ${localStorage.getItem('access_token')}`,
|
||||
},
|
||||
})
|
||||
.then(handleServerErrors);
|
||||
}
|
||||
13
cognee-frontend/src/utils/handleServerErrors.ts
Normal file
13
cognee-frontend/src/utils/handleServerErrors.ts
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
export default function handleServerErrors(response: Response): Promise<Response> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (response.status === 401) {
|
||||
window.location.href = '/auth';
|
||||
return;
|
||||
}
|
||||
if (!response.ok) {
|
||||
return response.json().then(error => reject(error));
|
||||
}
|
||||
|
||||
return resolve(response);
|
||||
});
|
||||
}
|
||||
2
cognee-frontend/src/utils/index.ts
Normal file
2
cognee-frontend/src/utils/index.ts
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
export { default as fetch } from './fetch';
|
||||
export { default as handleServerErrors } from './handleServerErrors';
|
||||
|
|
@ -2,15 +2,17 @@
|
|||
import os
|
||||
import aiohttp
|
||||
import uvicorn
|
||||
import json
|
||||
import logging
|
||||
import sentry_sdk
|
||||
from typing import Dict, Any, List, Union, Optional, Literal
|
||||
from typing_extensions import Annotated
|
||||
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query
|
||||
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query, Depends
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
|
||||
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
|
|
@ -42,7 +44,6 @@ origins = [
|
|||
"http://127.0.0.1:3000",
|
||||
"http://frontend:3000",
|
||||
"http://localhost:3000",
|
||||
"http://localhost:3001",
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
|
|
@ -58,39 +59,57 @@ from cognee.api.v1.users.routers import get_auth_router, get_register_router,\
|
|||
|
||||
from cognee.api.v1.permissions.get_permissions_router import get_permissions_router
|
||||
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def request_validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
if request.url.path == "/api/v1/auth/login":
|
||||
return JSONResponse(
|
||||
status_code = 400,
|
||||
content = {"detail": "LOGIN_BAD_CREDENTIALS"},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code = 400,
|
||||
content = jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_auth_router(),
|
||||
prefix = "/auth/jwt",
|
||||
prefix = "/api/v1/auth",
|
||||
tags = ["auth"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_register_router(),
|
||||
prefix = "/auth",
|
||||
prefix = "/api/v1/auth",
|
||||
tags = ["auth"],
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_reset_password_router(),
|
||||
prefix = "/auth",
|
||||
prefix = "/api/v1/auth",
|
||||
tags = ["auth"],
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_verify_router(),
|
||||
prefix = "/auth",
|
||||
prefix = "/api/v1/auth",
|
||||
tags = ["auth"],
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_users_router(),
|
||||
prefix = "/users",
|
||||
prefix = "/api/v1/users",
|
||||
tags = ["users"],
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
get_permissions_router(),
|
||||
prefix = "/permissions",
|
||||
prefix = "/api/v1/permissions",
|
||||
tags = ["permissions"],
|
||||
)
|
||||
|
||||
|
|
@ -108,31 +127,42 @@ def health_check():
|
|||
"""
|
||||
return {"status": "OK"}
|
||||
|
||||
@app.get("/datasets", response_model = list)
|
||||
async def get_datasets():
|
||||
@app.get("/api/v1/datasets", response_model = list)
|
||||
async def get_datasets(user: User = Depends(get_authenticated_user)):
|
||||
try:
|
||||
from cognee.api.v1.datasets.datasets import datasets
|
||||
datasets = await datasets.list_datasets()
|
||||
from cognee.modules.data.methods import get_datasets
|
||||
datasets = await get_datasets(user.id)
|
||||
|
||||
return JSONResponse(
|
||||
status_code = 200,
|
||||
content = [dataset.to_json() for dataset in datasets],
|
||||
)
|
||||
except Exception as error:
|
||||
raise HTTPException(status_code = 500, detail=f"Error retrieving datasets: {str(error)}") from error
|
||||
raise HTTPException(status_code = 500, detail = f"Error retrieving datasets: {str(error)}") from error
|
||||
|
||||
@app.delete("/datasets/{dataset_id}", response_model = dict)
|
||||
async def delete_dataset(dataset_id: str):
|
||||
from cognee.api.v1.datasets.datasets import datasets
|
||||
await datasets.delete_dataset(dataset_id)
|
||||
@app.delete("/api/v1/datasets/{dataset_id}", response_model = dict)
|
||||
async def delete_dataset(dataset_id: str, user: User = Depends(get_authenticated_user)):
|
||||
from cognee.modules.data.methods import get_dataset, delete_dataset
|
||||
|
||||
dataset = get_dataset(user.id, dataset_id)
|
||||
|
||||
if dataset is None:
|
||||
return JSONResponse(
|
||||
status_code = 404,
|
||||
content = {
|
||||
"detail": f"Dataset ({dataset_id}) not found."
|
||||
}
|
||||
)
|
||||
|
||||
await delete_dataset(dataset)
|
||||
|
||||
return JSONResponse(
|
||||
status_code = 200,
|
||||
content = "OK",
|
||||
)
|
||||
|
||||
@app.get("/datasets/{dataset_id}/graph", response_model=list)
|
||||
async def get_dataset_graph(dataset_id: str):
|
||||
@app.get("/api/v1/datasets/{dataset_id}/graph", response_model=list)
|
||||
async def get_dataset_graph(dataset_id: str, user: User = Depends(get_authenticated_user)):
|
||||
from cognee.shared.utils import render_graph
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
|
|
@ -150,21 +180,31 @@ async def get_dataset_graph(dataset_id: str):
|
|||
content = "Graphistry credentials are not set. Please set them in your .env file.",
|
||||
)
|
||||
|
||||
@app.get("/datasets/{dataset_id}/data", response_model=list)
|
||||
async def get_dataset_data(dataset_id: str):
|
||||
from cognee.api.v1.datasets.datasets import datasets
|
||||
@app.get("/api/v1/datasets/{dataset_id}/data", response_model=list)
|
||||
async def get_dataset_data(dataset_id: str, user: User = Depends(get_authenticated_user)):
|
||||
from cognee.modules.data.methods import get_dataset_data, get_dataset
|
||||
|
||||
dataset_data = await datasets.list_data(dataset_id = dataset_id)
|
||||
dataset = await get_dataset(user.id, dataset_id)
|
||||
|
||||
if dataset is None:
|
||||
return JSONResponse(
|
||||
status_code = 404,
|
||||
content = {
|
||||
"detail": f"Dataset ({dataset_id}) not found."
|
||||
}
|
||||
)
|
||||
|
||||
dataset_data = await get_dataset_data(dataset_id = dataset.id)
|
||||
|
||||
if dataset_data is None:
|
||||
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")
|
||||
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset.id}) not found.")
|
||||
|
||||
return [
|
||||
data.to_json() for data in dataset_data
|
||||
]
|
||||
|
||||
@app.get("/datasets/status", response_model=dict)
|
||||
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None):
|
||||
@app.get("/api/v1/datasets/status", response_model=dict)
|
||||
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None, user: User = Depends(get_authenticated_user)):
|
||||
from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
|
||||
|
||||
try:
|
||||
|
|
@ -180,15 +220,35 @@ async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset
|
|||
content = {"error": str(error)}
|
||||
)
|
||||
|
||||
@app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
||||
async def get_raw_data(dataset_id: str, data_id: str):
|
||||
from cognee.api.v1.datasets.datasets import datasets
|
||||
dataset_data = await datasets.list_data(dataset_id)
|
||||
@app.get("/api/v1/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
||||
async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)):
|
||||
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
||||
|
||||
dataset = await get_dataset(user.id, dataset_id)
|
||||
|
||||
if dataset is None:
|
||||
return JSONResponse(
|
||||
status_code = 404,
|
||||
content = {
|
||||
"detail": f"Dataset ({dataset_id}) not found."
|
||||
}
|
||||
)
|
||||
|
||||
dataset_data = await get_dataset_data(dataset.id)
|
||||
|
||||
if dataset_data is None:
|
||||
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")
|
||||
|
||||
data = [data for data in dataset_data if str(data.id) == data_id][0]
|
||||
|
||||
if data is None:
|
||||
return JSONResponse(
|
||||
status_code = 404,
|
||||
content = {
|
||||
"detail": f"Data ({data_id}) not found in dataset ({dataset_id})."
|
||||
}
|
||||
)
|
||||
|
||||
return data.raw_data_location
|
||||
|
||||
class AddPayload(BaseModel):
|
||||
|
|
@ -197,10 +257,11 @@ class AddPayload(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@app.post("/add", response_model=dict)
|
||||
@app.post("/api/v1/add", response_model=dict)
|
||||
async def add(
|
||||
data: List[UploadFile],
|
||||
datasetId: str = Form(...),
|
||||
user: User = Depends(get_authenticated_user),
|
||||
):
|
||||
""" This endpoint is responsible for adding data to the graph."""
|
||||
from cognee.api.v1.add import add as cognee_add
|
||||
|
|
@ -230,6 +291,7 @@ async def add(
|
|||
await cognee_add(
|
||||
data,
|
||||
datasetId,
|
||||
user = user,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code = 200,
|
||||
|
|
@ -246,12 +308,12 @@ async def add(
|
|||
class CognifyPayload(BaseModel):
|
||||
datasets: List[str]
|
||||
|
||||
@app.post("/cognify", response_model=dict)
|
||||
async def cognify(payload: CognifyPayload):
|
||||
@app.post("/api/v1/cognify", response_model=dict)
|
||||
async def cognify(payload: CognifyPayload, user: User = Depends(get_authenticated_user)):
|
||||
""" This endpoint is responsible for the cognitive processing of the content."""
|
||||
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
|
||||
try:
|
||||
await cognee_cognify(payload.datasets)
|
||||
await cognee_cognify(payload.datasets, user)
|
||||
return JSONResponse(
|
||||
status_code = 200,
|
||||
content = {
|
||||
|
|
@ -267,8 +329,8 @@ async def cognify(payload: CognifyPayload):
|
|||
class SearchPayload(BaseModel):
|
||||
query_params: Dict[str, Any]
|
||||
|
||||
@app.post("/search", response_model=dict)
|
||||
async def search(payload: SearchPayload):
|
||||
@app.post("/api/v1/search", response_model=dict)
|
||||
async def search(payload: SearchPayload, user: User = Depends(get_authenticated_user)):
|
||||
""" This endpoint is responsible for searching for nodes in the graph."""
|
||||
from cognee.api.v1.search import search as cognee_search
|
||||
try:
|
||||
|
|
@ -290,8 +352,8 @@ async def search(payload: SearchPayload):
|
|||
content = {"error": str(error)}
|
||||
)
|
||||
|
||||
@app.get("/settings", response_model=dict)
|
||||
async def get_settings():
|
||||
@app.get("/api/v1/settings", response_model=dict)
|
||||
async def get_settings(user: User = Depends(get_authenticated_user)):
|
||||
from cognee.modules.settings import get_settings as get_cognee_settings
|
||||
return get_cognee_settings()
|
||||
|
||||
|
|
@ -309,8 +371,8 @@ class SettingsPayload(BaseModel):
|
|||
llm: Optional[LLMConfig] = None
|
||||
vectorDB: Optional[VectorDBConfig] = None
|
||||
|
||||
@app.post("/settings", response_model=dict)
|
||||
async def save_config(new_settings: SettingsPayload):
|
||||
@app.post("/api/v1/settings", response_model=dict)
|
||||
async def save_config(new_settings: SettingsPayload, user: User = Depends(get_authenticated_user)):
|
||||
from cognee.modules.settings import save_llm_config, save_vector_db_config
|
||||
if new_settings.llm is not None:
|
||||
await save_llm_config(new_settings.llm)
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ from cognee.modules.ingestion import get_matched_datasets, save_data_to_file
|
|||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.relational import get_relational_config, get_relational_engine, create_db_and_tables
|
||||
from cognee.modules.users.methods import create_default_user, get_default_user
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.permissions.methods import give_permission_on_document
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.data.operations.ensure_dataset_exists import ensure_dataset_exists
|
||||
from cognee.modules.data.methods import create_dataset
|
||||
|
||||
async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
|
||||
await create_db_and_tables()
|
||||
|
|
@ -55,7 +55,10 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
|
|||
|
||||
return []
|
||||
|
||||
async def add_files(file_paths: List[str], dataset_name: str, user):
|
||||
async def add_files(file_paths: List[str], dataset_name: str, user: User = None):
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
base_config = get_base_config()
|
||||
data_directory_path = base_config.data_root_directory
|
||||
|
||||
|
|
@ -101,7 +104,6 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
|
|||
)
|
||||
|
||||
dataset_name = dataset_name.replace(" ", "_").replace(".", "_") if dataset_name is not None else "main_dataset"
|
||||
dataset = await ensure_dataset_exists(dataset_name)
|
||||
|
||||
@dlt.resource(standalone = True, merge_key = "id")
|
||||
async def data_resources(file_paths: str, user: User):
|
||||
|
|
@ -115,8 +117,12 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
|
|||
|
||||
from sqlalchemy import select
|
||||
from cognee.modules.data.models import Data
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
dataset = await create_dataset(dataset_name, user.id, session)
|
||||
|
||||
data = (await session.execute(
|
||||
select(Data).filter(Data.id == data_id)
|
||||
)).scalar_one_or_none()
|
||||
|
|
@ -137,10 +143,8 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
|
|||
extension = file_metadata["extension"],
|
||||
mime_type = file_metadata["mime_type"],
|
||||
)
|
||||
|
||||
dataset.data.append(data)
|
||||
|
||||
await session.merge(dataset)
|
||||
|
||||
await session.commit()
|
||||
|
||||
yield {
|
||||
|
|
@ -155,12 +159,6 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
|
|||
await give_permission_on_document(user, data_id, "write")
|
||||
|
||||
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
if user is None:
|
||||
user = await create_default_user()
|
||||
|
||||
run_info = pipeline.run(
|
||||
data_resources(processed_file_paths, user),
|
||||
table_name = "file_metadata",
|
||||
|
|
|
|||
|
|
@ -3,11 +3,10 @@ import logging
|
|||
from typing import Union
|
||||
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.modules.data.models import Dataset, Data
|
||||
from cognee.modules.data.operations.get_dataset_data import get_dataset_data
|
||||
from cognee.modules.data.operations.retrieve_datasets import retrieve_datasets
|
||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.pipelines import run_tasks, run_tasks_parallel
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -35,17 +34,18 @@ class PermissionDeniedException(Exception):
|
|||
super().__init__(self.message)
|
||||
|
||||
async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
if datasets is None or len(datasets) == 0:
|
||||
return await cognify(await db_engine.get_datasets())
|
||||
|
||||
if type(datasets[0]) == str:
|
||||
datasets = await retrieve_datasets(datasets)
|
||||
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
existing_datasets = await get_datasets(user.id)
|
||||
|
||||
if datasets is None or len(datasets) == 0:
|
||||
# If no datasets are provided, cognify all existing datasets.
|
||||
datasets = existing_datasets
|
||||
|
||||
if type(datasets[0]) == str:
|
||||
datasets = await get_datasets_by_name(datasets, user.id)
|
||||
|
||||
async def run_cognify_pipeline(dataset: Dataset):
|
||||
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
|
||||
|
||||
|
|
@ -112,13 +112,16 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
|
|||
raise error
|
||||
|
||||
|
||||
existing_datasets = [dataset.name for dataset in list(await db_engine.get_datasets())]
|
||||
existing_datasets_map = {
|
||||
generate_dataset_name(dataset.name): True for dataset in existing_datasets
|
||||
}
|
||||
|
||||
awaitables = []
|
||||
|
||||
for dataset in datasets:
|
||||
dataset_name = generate_dataset_name(dataset.name)
|
||||
|
||||
if dataset_name in existing_datasets:
|
||||
if dataset_name in existing_datasets_map:
|
||||
awaitables.append(run_cognify_pipeline(dataset))
|
||||
|
||||
return await asyncio.gather(*awaitables)
|
||||
|
|
|
|||
|
|
@ -1,37 +1,34 @@
|
|||
from duckdb import CatalogException
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.ingestion import discover_directory_datasets
|
||||
from cognee.modules.data.operations.get_dataset_data import get_dataset_data
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
class datasets():
|
||||
@staticmethod
|
||||
async def list_datasets():
|
||||
db = get_relational_engine()
|
||||
return await db.get_datasets()
|
||||
from cognee.modules.data.methods import get_datasets
|
||||
user = await get_default_user()
|
||||
return await get_datasets(user.id)
|
||||
|
||||
@staticmethod
|
||||
def discover_datasets(directory_path: str):
|
||||
return list(discover_directory_datasets(directory_path).keys())
|
||||
|
||||
@staticmethod
|
||||
async def list_data(dataset_id: str, dataset_name: str = None):
|
||||
try:
|
||||
return await get_dataset_data(dataset_id = dataset_id, dataset_name = dataset_name)
|
||||
except CatalogException:
|
||||
return None
|
||||
async def list_data(dataset_id: str):
|
||||
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
||||
user = await get_default_user()
|
||||
|
||||
dataset = await get_dataset(user.id, dataset_id)
|
||||
|
||||
return await get_dataset_data(dataset.id)
|
||||
|
||||
@staticmethod
|
||||
async def get_status(dataset_ids: list[str]) -> dict:
|
||||
try:
|
||||
return await get_pipeline_status(dataset_ids)
|
||||
except CatalogException:
|
||||
return {}
|
||||
return await get_pipeline_status(dataset_ids)
|
||||
|
||||
@staticmethod
|
||||
async def delete_dataset(dataset_id: str):
|
||||
db = get_relational_engine()
|
||||
try:
|
||||
return await db.delete_table(dataset_id)
|
||||
except CatalogException:
|
||||
return {}
|
||||
from cognee.modules.data.methods import get_dataset, delete_dataset
|
||||
user = await get_default_user()
|
||||
dataset = await get_dataset(user.id, dataset_id)
|
||||
|
||||
return await delete_dataset(dataset)
|
||||
|
|
|
|||
|
|
@ -228,11 +228,21 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
# Log that the file does not exist and an empty graph is initialized
|
||||
logger.warning("File %s not found. Initializing an empty graph.", file_path)
|
||||
self.graph = nx.MultiDiGraph() # Use MultiDiGraph to keep it consistent with __init__
|
||||
|
||||
file_dir = os.path.dirname(file_path)
|
||||
if not os.path.exists(file_dir):
|
||||
os.makedirs(file_dir, exist_ok = True)
|
||||
|
||||
await self.save_graph_to_file(file_path)
|
||||
except Exception as error:
|
||||
except Exception:
|
||||
logger.error("Failed to load graph from file: %s", file_path)
|
||||
# Initialize an empty graph in case of error
|
||||
self.graph = nx.MultiDiGraph()
|
||||
|
||||
file_dir = os.path.dirname(file_path)
|
||||
if not os.path.exists(file_dir):
|
||||
os.makedirs(file_dir, exist_ok = True)
|
||||
|
||||
await self.save_graph_to_file(file_path)
|
||||
|
||||
async def delete_graph(self, file_path: str = None):
|
||||
|
|
|
|||
11
cognee/modules/data/methods/__init__.py
Normal file
11
cognee/modules/data/methods/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
# Create
|
||||
from .create_dataset import create_dataset
|
||||
|
||||
# Get
|
||||
from .get_dataset import get_dataset
|
||||
from .get_datasets import get_datasets
|
||||
from .get_datasets_by_name import get_datasets_by_name
|
||||
from .get_dataset_data import get_dataset_data
|
||||
|
||||
# Delete
|
||||
from .delete_dataset import delete_dataset
|
||||
27
cognee/modules/data/methods/create_dataset.py
Normal file
27
cognee/modules/data/methods/create_dataset.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from cognee.modules.data.models import Dataset
|
||||
|
||||
async def create_dataset(dataset_name: str, owner_id: UUID, session: AsyncSession) -> Dataset:
|
||||
dataset = (await session.scalars(
|
||||
select(Dataset)\
|
||||
.options(joinedload(Dataset.data))\
|
||||
.filter(Dataset.name == dataset_name)
|
||||
.filter(Dataset.owner_id == owner_id)
|
||||
)).first()
|
||||
|
||||
if dataset is None:
|
||||
dataset = Dataset(
|
||||
id = uuid5(NAMESPACE_OID, dataset_name),
|
||||
name = dataset_name,
|
||||
data = []
|
||||
)
|
||||
dataset.owner_id = owner_id
|
||||
|
||||
session.add(dataset)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return dataset
|
||||
7
cognee/modules/data/methods/delete_dataset.py
Normal file
7
cognee/modules/data/methods/delete_dataset.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from cognee.modules.data.models import Dataset
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
async def delete_dataset(dataset: Dataset):
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
return await db_engine.delete_table(dataset.id)
|
||||
14
cognee/modules/data/methods/get_dataset.py
Normal file
14
cognee/modules/data/methods/get_dataset.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from uuid import UUID
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from ..models import Dataset
|
||||
|
||||
async def get_dataset(user_id: UUID, dataset_id: UUID) -> Dataset:
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
dataset = await session.get(Dataset, dataset_id)
|
||||
|
||||
if dataset and dataset.owner_id != user_id:
|
||||
return None
|
||||
|
||||
return dataset
|
||||
|
|
@ -3,16 +3,14 @@ from sqlalchemy import select
|
|||
from cognee.modules.data.models import Data, Dataset
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
async def get_dataset_data(dataset_id: UUID = None, dataset_name: str = None):
|
||||
if dataset_id is None and dataset_name is None:
|
||||
raise ValueError("get_dataset_data: Either dataset_id or dataset_name must be provided.")
|
||||
|
||||
async def get_dataset_data(dataset_id: UUID) -> list[Data]:
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id) | (Dataset.name == dataset_name))
|
||||
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id))
|
||||
)
|
||||
|
||||
data = result.scalars().all()
|
||||
|
||||
return data
|
||||
|
|
@ -1,13 +1,14 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from ..models import Dataset
|
||||
|
||||
async def retrieve_datasets(dataset_names: list[str]) -> list[Dataset]:
|
||||
async def get_datasets(user_id: UUID) -> list[Dataset]:
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
datasets = (await session.scalars(
|
||||
select(Dataset).filter(Dataset.name.in_(dataset_names))
|
||||
select(Dataset).filter(Dataset.owner_id == user_id)
|
||||
)).all()
|
||||
|
||||
return datasets
|
||||
16
cognee/modules/data/methods/get_datasets_by_name.py
Normal file
16
cognee/modules/data/methods/get_datasets_by_name.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from ..models import Dataset
|
||||
|
||||
async def get_datasets_by_name(dataset_names: list[str], user_id: UUID) -> list[Dataset]:
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
datasets = (await session.scalars(
|
||||
select(Dataset)
|
||||
.filter(Dataset.owner_id == user_id)
|
||||
.filter(Dataset.name.in_(dataset_names))
|
||||
)).all()
|
||||
|
||||
return datasets
|
||||
|
|
@ -21,7 +21,8 @@ class Data(Base):
|
|||
|
||||
datasets: Mapped[List["Dataset"]] = relationship(
|
||||
secondary = DatasetData.__tablename__,
|
||||
back_populates = "data"
|
||||
back_populates = "data",
|
||||
lazy = "noload",
|
||||
)
|
||||
|
||||
def to_json(self) -> dict:
|
||||
|
|
|
|||
|
|
@ -16,9 +16,12 @@ class Dataset(Base):
|
|||
created_at = Column(DateTime(timezone = True), default = lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone = True), onupdate = lambda: datetime.now(timezone.utc))
|
||||
|
||||
owner_id = Column(UUID, index = True)
|
||||
|
||||
data: Mapped[List["Data"]] = relationship(
|
||||
secondary = DatasetData.__tablename__,
|
||||
back_populates = "datasets"
|
||||
back_populates = "datasets",
|
||||
lazy = "noload",
|
||||
)
|
||||
|
||||
def to_json(self) -> dict:
|
||||
|
|
@ -27,5 +30,6 @@ class Dataset(Base):
|
|||
"name": self.name,
|
||||
"createdAt": self.created_at.isoformat(),
|
||||
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"ownerId": str(self.owner_id),
|
||||
"data": [data.to_json() for data in self.data]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,26 +0,0 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from cognee.modules.data.models import Dataset
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
async def ensure_dataset_exists(dataset_name: str) -> Dataset:
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
dataset = (await session.scalars(
|
||||
select(Dataset)\
|
||||
.options(joinedload(Dataset.data))\
|
||||
.filter(Dataset.name == dataset_name)
|
||||
)).first()
|
||||
|
||||
if dataset is None:
|
||||
dataset = Dataset(
|
||||
name = dataset_name,
|
||||
data = []
|
||||
)
|
||||
|
||||
session.add(dataset)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return dataset
|
||||
|
|
@ -1,9 +1,16 @@
|
|||
# from fastapi import Depends
|
||||
from typing import AsyncGenerator
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
# from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from .models.User import User
|
||||
|
||||
async def get_user_db(session):
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
yield session
|
||||
|
||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||
yield SQLAlchemyUserDatabase(session, User)
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
|
|||
|
|
@ -2,16 +2,33 @@ import os
|
|||
import uuid
|
||||
from typing import Optional
|
||||
from fastapi import Depends, Request
|
||||
from fastapi_users import BaseUserManager, UUIDIDMixin
|
||||
from fastapi_users.exceptions import UserNotExists
|
||||
from fastapi_users import BaseUserManager, UUIDIDMixin, models
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
|
||||
from .get_user_db import get_user_db
|
||||
from .models import User
|
||||
from .methods import get_user
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = os.getenv("FASTAPI_USERS_RESET_PASSWORD_TOKEN_SECRET", "super_secret")
|
||||
verification_token_secret = os.getenv("FASTAPI_USERS_VERIFICATION_TOKEN_SECRET", "super_secret")
|
||||
|
||||
async def get(self, id: models.ID) -> models.UP:
|
||||
"""
|
||||
Get a user by id.
|
||||
|
||||
:param id: Id. of the user to retrieve.
|
||||
:raises UserNotExists: The user does not exist.
|
||||
:return: A user.
|
||||
"""
|
||||
user = await get_user(id)
|
||||
|
||||
if user is None:
|
||||
raise UserNotExists()
|
||||
|
||||
return user
|
||||
|
||||
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
||||
print(f"User {user.id} has registered.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from .get_user import get_user
|
||||
from .create_user import create_user
|
||||
from .get_default_user import get_default_user
|
||||
from .create_default_user import create_default_user
|
||||
from .get_authenticated_user import get_authenticated_user
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import hashlib
|
||||
from .create_user import create_user
|
||||
|
||||
async def create_default_user():
|
||||
|
|
@ -7,14 +6,11 @@ async def create_default_user():
|
|||
|
||||
user = await create_user(
|
||||
email = default_user_email,
|
||||
password = await hash_password(default_user_password),
|
||||
is_superuser = True,
|
||||
password = default_user_password,
|
||||
is_superuser = False,
|
||||
is_active = True,
|
||||
is_verified = True,
|
||||
auto_login = True,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
async def hash_password(password: str) -> str:
|
||||
return hashlib.sha256(password.encode()).hexdigest()
|
||||
|
|
|
|||
5
cognee/modules/users/methods/get_authenticated_user.py
Normal file
5
cognee/modules/users/methods/get_authenticated_user.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from ..get_fastapi_users import get_fastapi_users
|
||||
|
||||
fastapi_users = get_fastapi_users()
|
||||
|
||||
get_authenticated_user = fastapi_users.current_user(active = True, verified = True)
|
||||
|
|
@ -2,6 +2,7 @@ from sqlalchemy.orm import joinedload
|
|||
from sqlalchemy.future import select
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from .create_default_user import create_default_user
|
||||
|
||||
async def get_default_user():
|
||||
db_engine = get_relational_engine()
|
||||
|
|
@ -13,4 +14,7 @@ async def get_default_user():
|
|||
result = await session.execute(query)
|
||||
user = result.scalars().first()
|
||||
|
||||
if user is None:
|
||||
return await create_default_user()
|
||||
|
||||
return user
|
||||
|
|
|
|||
15
cognee/modules/users/methods/get_user.py
Normal file
15
cognee/modules/users/methods/get_user.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from ..models import User
|
||||
|
||||
async def get_user(user_id: UUID):
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
user = (await session.execute(
|
||||
select(User).options(joinedload(User.groups)).where(User.id == user_id)
|
||||
)).scalar()
|
||||
|
||||
return user
|
||||
|
|
@ -16,25 +16,21 @@ class PermissionDeniedException(Exception):
|
|||
|
||||
|
||||
async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]):
|
||||
try:
|
||||
user_group_ids = [group.id for group in user.groups]
|
||||
user_group_ids = [group.id for group in user.groups]
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(ACL)
|
||||
.join(ACL.permission)
|
||||
.options(joinedload(ACL.resources))
|
||||
.where(ACL.principal_id.in_([user.id, *user_group_ids]))
|
||||
.where(ACL.permission.has(name = permission_type))
|
||||
)
|
||||
acls = result.unique().scalars().all()
|
||||
resource_ids = [resource.resource_id for acl in acls for resource in acl.resources]
|
||||
has_permissions = all(document_id in resource_ids for document_id in document_ids)
|
||||
async with db_engine.get_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(ACL)
|
||||
.join(ACL.permission)
|
||||
.options(joinedload(ACL.resources))
|
||||
.where(ACL.principal_id.in_([user.id, *user_group_ids]))
|
||||
.where(ACL.permission.has(name = permission_type))
|
||||
)
|
||||
acls = result.unique().scalars().all()
|
||||
resource_ids = [resource.resource_id for acl in acls for resource in acl.resources]
|
||||
has_permissions = all(document_id in resource_ids for document_id in document_ids)
|
||||
|
||||
if not has_permissions:
|
||||
raise PermissionDeniedException(f"User {user.username} does not have {permission_type} permission on documents")
|
||||
except Exception as error:
|
||||
logger.error("Error checking permissions on documents: %s", str(error))
|
||||
raise
|
||||
if not has_permissions:
|
||||
raise PermissionDeniedException(f"User {user.username} does not have {permission_type} permission on documents")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,6 @@
|
|||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
|
||||
|
||||
# from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
|
||||
async def chunk_remove_disconnected(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue