feat: user authentication in routes (#133)
* feat: require logged in user in routes
This commit is contained in:
parent
22c0dd5b2d
commit
e1a0b55a21
40 changed files with 526 additions and 170 deletions
16
cognee-frontend/src/app/auth/AuthPage.module.css
Normal file
16
cognee-frontend/src/app/auth/AuthPage.module.css
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
.main {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: row;
|
||||||
|
flex-direction: column;
|
||||||
|
padding: 0;
|
||||||
|
min-height: 100vh;
|
||||||
|
}
|
||||||
|
|
||||||
|
.authContainer {
|
||||||
|
flex: 1;
|
||||||
|
display: flex;
|
||||||
|
padding: 24px 0;
|
||||||
|
margin: 0 auto;
|
||||||
|
max-width: 440px;
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
29
cognee-frontend/src/app/auth/AuthPage.tsx
Normal file
29
cognee-frontend/src/app/auth/AuthPage.tsx
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
import { Spacer, Stack, Text } from 'ohmy-ui';
|
||||||
|
import { TextLogo } from '@/ui/App';
|
||||||
|
import Footer from '@/ui/Partials/Footer/Footer';
|
||||||
|
|
||||||
|
import styles from './AuthPage.module.css';
|
||||||
|
import { Divider } from '@/ui/Layout';
|
||||||
|
import SignInForm from '@/ui/Partials/SignInForm/SignInForm';
|
||||||
|
|
||||||
|
export default function AuthPage() {
|
||||||
|
return (
|
||||||
|
<main className={styles.main}>
|
||||||
|
<Spacer inset vertical="1" horizontal="2">
|
||||||
|
<Stack orientation="horizontal" gap="between" align="center">
|
||||||
|
<TextLogo width={225} height={64} />
|
||||||
|
</Stack>
|
||||||
|
</Spacer>
|
||||||
|
<Divider />
|
||||||
|
<div className={styles.authContainer}>
|
||||||
|
<Stack gap="4" style={{ width: '100%' }}>
|
||||||
|
<h1><Text size="large">Sign in</Text></h1>
|
||||||
|
<SignInForm />
|
||||||
|
</Stack>
|
||||||
|
</div>
|
||||||
|
<Spacer inset horizontal="3" wrap>
|
||||||
|
<Footer />
|
||||||
|
</Spacer>
|
||||||
|
</main>
|
||||||
|
)
|
||||||
|
}
|
||||||
1
cognee-frontend/src/app/auth/page.tsx
Normal file
1
cognee-frontend/src/app/auth/page.tsx
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
export { default } from './AuthPage';
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
import { fetch } from '@/utils';
|
||||||
|
|
||||||
export default function cognifyDataset(dataset: { id: string, name: string }) {
|
export default function cognifyDataset(dataset: { id: string, name: string }) {
|
||||||
return fetch('http://127.0.0.1:8000/cognify', {
|
return fetch('/v1/cognify', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
import { fetch } from '@/utils';
|
||||||
|
|
||||||
export default function deleteDataset(dataset: { id: string }) {
|
export default function deleteDataset(dataset: { id: string }) {
|
||||||
return fetch(`http://127.0.0.1:8000/datasets/${dataset.id}`, {
|
return fetch(`/v1/datasets/${dataset.id}`, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
|
import { fetch } from '@/utils';
|
||||||
|
|
||||||
export default function getDatasetData(dataset: { id: string }) {
|
export default function getDatasetData(dataset: { id: string }) {
|
||||||
return fetch(`http://127.0.0.1:8000/datasets/${dataset.id}/data`)
|
return fetch(`/v1/datasets/${dataset.id}/data`)
|
||||||
.then((response) => response.json());
|
.then((response) => response.json());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
import { fetch } from '@/utils';
|
||||||
|
|
||||||
export default function getExplorationGraphUrl(dataset: { id: string }) {
|
export default function getExplorationGraphUrl(dataset: { id: string }) {
|
||||||
return fetch(`http://127.0.0.1:8000/datasets/${dataset.id}/graph`)
|
return fetch(`/v1/datasets/${dataset.id}/graph`)
|
||||||
.then(async (response) => {
|
.then(async (response) => {
|
||||||
if (response.status !== 200) {
|
if (response.status !== 200) {
|
||||||
throw new Error((await response.text()).replaceAll("\"", ""));
|
throw new Error((await response.text()).replaceAll("\"", ""));
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,9 @@ import {
|
||||||
UploadInput,
|
UploadInput,
|
||||||
CloseIcon,
|
CloseIcon,
|
||||||
} from "ohmy-ui";
|
} from "ohmy-ui";
|
||||||
import styles from "./DataView.module.css";
|
import { fetch } from '@/utils';
|
||||||
import RawDataPreview from './RawDataPreview';
|
import RawDataPreview from './RawDataPreview';
|
||||||
|
import styles from "./DataView.module.css";
|
||||||
|
|
||||||
export interface Data {
|
export interface Data {
|
||||||
id: string;
|
id: string;
|
||||||
|
|
@ -37,7 +38,7 @@ export default function DataView({ datasetId, data, onClose, onDataAdd }: DataVi
|
||||||
const showRawData = useCallback((dataItem: Data) => {
|
const showRawData = useCallback((dataItem: Data) => {
|
||||||
setSelectedData(dataItem);
|
setSelectedData(dataItem);
|
||||||
|
|
||||||
fetch(`http://127.0.0.1:8000/datasets/${datasetId}/data/${dataItem.id}/raw`)
|
fetch(`/v1/datasets/${datasetId}/data/${dataItem.id}/raw`)
|
||||||
.then((response) => response.arrayBuffer())
|
.then((response) => response.arrayBuffer())
|
||||||
.then(setRawData);
|
.then(setRawData);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import { fetch } from '@/utils';
|
||||||
|
|
||||||
export default function addData(dataset: { id: string }, files: File[]) {
|
export default function addData(dataset: { id: string }, files: File[]) {
|
||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
files.forEach((file) => {
|
files.forEach((file) => {
|
||||||
|
|
@ -5,7 +7,7 @@ export default function addData(dataset: { id: string }, files: File[]) {
|
||||||
})
|
})
|
||||||
formData.append('datasetId', dataset.id);
|
formData.append('datasetId', dataset.id);
|
||||||
|
|
||||||
return fetch('http://127.0.0.1:8000/add', {
|
return fetch('/v1/add', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: formData,
|
body: formData,
|
||||||
}).then((response) => response.json());
|
}).then((response) => response.json());
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||||
import { v4 } from 'uuid';
|
import { v4 } from 'uuid';
|
||||||
import { DataFile } from './useData';
|
import { DataFile } from './useData';
|
||||||
|
import { fetch } from '@/utils';
|
||||||
|
|
||||||
export interface Dataset {
|
export interface Dataset {
|
||||||
id: string;
|
id: string;
|
||||||
|
|
@ -14,7 +15,14 @@ function useDatasets() {
|
||||||
const statusTimeout = useRef<any>(null);
|
const statusTimeout = useRef<any>(null);
|
||||||
|
|
||||||
const fetchDatasetStatuses = useCallback((datasets: Dataset[]) => {
|
const fetchDatasetStatuses = useCallback((datasets: Dataset[]) => {
|
||||||
fetch(`http://127.0.0.1:8000/datasets/status?dataset=${datasets.map(d => d.id).join('&dataset=')}`)
|
fetch(
|
||||||
|
`/v1/datasets/status?dataset=${datasets.map(d => d.id).join('&dataset=')}`,
|
||||||
|
{
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${localStorage.getItem('access_token')}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
.then((response) => response.json())
|
.then((response) => response.json())
|
||||||
.then((statuses) => setDatasets(
|
.then((statuses) => setDatasets(
|
||||||
(datasets) => (
|
(datasets) => (
|
||||||
|
|
@ -65,7 +73,11 @@ function useDatasets() {
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const fetchDatasets = useCallback(() => {
|
const fetchDatasets = useCallback(() => {
|
||||||
fetch('http://127.0.0.1:8000/datasets')
|
fetch('/v1/datasets', {
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${localStorage.getItem('access_token')}`,
|
||||||
|
},
|
||||||
|
})
|
||||||
.then((response) => response.json())
|
.then((response) => response.json())
|
||||||
.then((datasets) => {
|
.then((datasets) => {
|
||||||
setDatasets(datasets);
|
setDatasets(datasets);
|
||||||
|
|
@ -75,6 +87,9 @@ function useDatasets() {
|
||||||
} else {
|
} else {
|
||||||
window.location.href = '/wizard';
|
window.location.href = '/wizard';
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Error fetching datasets:', error);
|
||||||
});
|
});
|
||||||
}, [checkDatasetStatuses]);
|
}, [checkDatasetStatuses]);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import { v4 } from 'uuid';
|
||||||
import classNames from 'classnames';
|
import classNames from 'classnames';
|
||||||
import { useCallback, useState } from 'react';
|
import { useCallback, useState } from 'react';
|
||||||
import { CTAButton, Stack, Text, DropdownSelect, TextArea, useBoolean } from 'ohmy-ui';
|
import { CTAButton, Stack, Text, DropdownSelect, TextArea, useBoolean } from 'ohmy-ui';
|
||||||
|
import { fetch } from '@/utils';
|
||||||
import styles from './SearchView.module.css';
|
import styles from './SearchView.module.css';
|
||||||
|
|
||||||
interface Message {
|
interface Message {
|
||||||
|
|
@ -50,7 +51,7 @@ export default function SearchView() {
|
||||||
},
|
},
|
||||||
]);
|
]);
|
||||||
|
|
||||||
fetch('http://localhost:8000/search', {
|
fetch('/v1/search', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|
|
||||||
|
|
@ -5,13 +5,13 @@ import {
|
||||||
FormGroup,
|
FormGroup,
|
||||||
FormInput,
|
FormInput,
|
||||||
FormLabel,
|
FormLabel,
|
||||||
H3,
|
|
||||||
Input,
|
Input,
|
||||||
Spacer,
|
Spacer,
|
||||||
Stack,
|
Stack,
|
||||||
useBoolean,
|
useBoolean,
|
||||||
} from 'ohmy-ui';
|
} from 'ohmy-ui';
|
||||||
import { LoadingIndicator } from '@/ui/App';
|
import { LoadingIndicator } from '@/ui/App';
|
||||||
|
import { fetch } from '@/utils';
|
||||||
|
|
||||||
interface SelectOption {
|
interface SelectOption {
|
||||||
label: string;
|
label: string;
|
||||||
|
|
@ -75,7 +75,7 @@ export default function Settings({ onDone = () => {}, submitButtonText = 'Save'
|
||||||
|
|
||||||
startSaving();
|
startSaving();
|
||||||
|
|
||||||
fetch('http://127.0.0.1:8000/settings', {
|
fetch('/v1/settings', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|
@ -138,7 +138,7 @@ export default function Settings({ onDone = () => {}, submitButtonText = 'Save'
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const fetchConfig = async () => {
|
const fetchConfig = async () => {
|
||||||
const response = await fetch('http://127.0.0.1:8000/settings');
|
const response = await fetch('/v1/settings');
|
||||||
const settings = await response.json();
|
const settings = await response.json();
|
||||||
|
|
||||||
if (!settings.llm.model) {
|
if (!settings.llm.model) {
|
||||||
|
|
|
||||||
96
cognee-frontend/src/ui/Partials/SignInForm/SignInForm.tsx
Normal file
96
cognee-frontend/src/ui/Partials/SignInForm/SignInForm.tsx
Normal file
|
|
@ -0,0 +1,96 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import {
|
||||||
|
CTAButton,
|
||||||
|
FormGroup,
|
||||||
|
FormInput,
|
||||||
|
FormLabel,
|
||||||
|
Input,
|
||||||
|
Spacer,
|
||||||
|
Stack,
|
||||||
|
Text,
|
||||||
|
useBoolean,
|
||||||
|
} from 'ohmy-ui';
|
||||||
|
import { LoadingIndicator } from '@/ui/App';
|
||||||
|
import { fetch, handleServerErrors } from '@/utils';
|
||||||
|
import { useState } from 'react';
|
||||||
|
|
||||||
|
interface SignInFormPayload extends HTMLFormElement {
|
||||||
|
vectorDBUrl: HTMLInputElement;
|
||||||
|
vectorDBApiKey: HTMLInputElement;
|
||||||
|
llmApiKey: HTMLInputElement;
|
||||||
|
}
|
||||||
|
|
||||||
|
const errorsMap = {
|
||||||
|
LOGIN_BAD_CREDENTIALS: 'Invalid username or password',
|
||||||
|
};
|
||||||
|
|
||||||
|
export default function SignInForm({ onSignInSuccess = () => window.location.href = '/', submitButtonText = 'Sign in' }) {
|
||||||
|
const {
|
||||||
|
value: isSigningIn,
|
||||||
|
setTrue: disableSignIn,
|
||||||
|
setFalse: enableSignIn,
|
||||||
|
} = useBoolean(false);
|
||||||
|
|
||||||
|
const [signInError, setSignInError] = useState<string | null>(null);
|
||||||
|
|
||||||
|
const signIn = (event: React.FormEvent<SignInFormPayload>) => {
|
||||||
|
event.preventDefault();
|
||||||
|
const formElements = event.currentTarget;
|
||||||
|
|
||||||
|
const authCredentials = new FormData();
|
||||||
|
// Backend expects username and password fields
|
||||||
|
authCredentials.append("username", formElements.email.value);
|
||||||
|
authCredentials.append("password", formElements.password.value);
|
||||||
|
|
||||||
|
setSignInError(null);
|
||||||
|
disableSignIn();
|
||||||
|
|
||||||
|
fetch('/v1/auth/login', {
|
||||||
|
method: 'POST',
|
||||||
|
body: authCredentials,
|
||||||
|
})
|
||||||
|
.then(handleServerErrors)
|
||||||
|
.then(response => response.json())
|
||||||
|
.then((bearer) => {
|
||||||
|
window.localStorage.setItem('access_token', bearer.access_token);
|
||||||
|
onSignInSuccess();
|
||||||
|
})
|
||||||
|
.catch(error => setSignInError(errorsMap[error.detail as keyof typeof errorsMap]))
|
||||||
|
.finally(() => enableSignIn());
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<form onSubmit={signIn} style={{ width: '100%' }}>
|
||||||
|
<Stack gap="4" orientation="vertical">
|
||||||
|
<Stack gap="4" orientation="vertical">
|
||||||
|
<FormGroup orientation="vertical" align="center/" gap="2">
|
||||||
|
<FormLabel>Email:</FormLabel>
|
||||||
|
<FormInput>
|
||||||
|
<Input name="email" type="email" placeholder="Your email address" />
|
||||||
|
</FormInput>
|
||||||
|
</FormGroup>
|
||||||
|
<FormGroup orientation="vertical" align="center/" gap="2">
|
||||||
|
<FormLabel>Password:</FormLabel>
|
||||||
|
<FormInput>
|
||||||
|
<Input name="password" type="password" placeholder="Your password" />
|
||||||
|
</FormInput>
|
||||||
|
</FormGroup>
|
||||||
|
</Stack>
|
||||||
|
|
||||||
|
<Spacer top="2">
|
||||||
|
<CTAButton type="submit">
|
||||||
|
<Stack gap="2" orientation="horizontal" align="/center">
|
||||||
|
{submitButtonText}
|
||||||
|
{isSigningIn && <LoadingIndicator />}
|
||||||
|
</Stack>
|
||||||
|
</CTAButton>
|
||||||
|
</Spacer>
|
||||||
|
|
||||||
|
{signInError && (
|
||||||
|
<Text>{signInError}</Text>
|
||||||
|
)}
|
||||||
|
</Stack>
|
||||||
|
</form>
|
||||||
|
)
|
||||||
|
}
|
||||||
12
cognee-frontend/src/utils/fetch.ts
Normal file
12
cognee-frontend/src/utils/fetch.ts
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
import handleServerErrors from './handleServerErrors';
|
||||||
|
|
||||||
|
export default function fetch(url: string, options: RequestInit = {}): Promise<Response> {
|
||||||
|
return global.fetch('http://127.0.0.1:8000/api' + url, {
|
||||||
|
...options,
|
||||||
|
headers: {
|
||||||
|
...options.headers,
|
||||||
|
'Authorization': `Bearer ${localStorage.getItem('access_token')}`,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.then(handleServerErrors);
|
||||||
|
}
|
||||||
13
cognee-frontend/src/utils/handleServerErrors.ts
Normal file
13
cognee-frontend/src/utils/handleServerErrors.ts
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
export default function handleServerErrors(response: Response): Promise<Response> {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
if (response.status === 401) {
|
||||||
|
window.location.href = '/auth';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!response.ok) {
|
||||||
|
return response.json().then(error => reject(error));
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolve(response);
|
||||||
|
});
|
||||||
|
}
|
||||||
2
cognee-frontend/src/utils/index.ts
Normal file
2
cognee-frontend/src/utils/index.ts
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
export { default as fetch } from './fetch';
|
||||||
|
export { default as handleServerErrors } from './handleServerErrors';
|
||||||
|
|
@ -2,15 +2,17 @@
|
||||||
import os
|
import os
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
from typing import Dict, Any, List, Union, Optional, Literal
|
from typing import Dict, Any, List, Union, Optional, Literal
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query
|
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query, Depends
|
||||||
from fastapi.responses import JSONResponse, FileResponse
|
from fastapi.responses import JSONResponse, FileResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.users.methods import get_authenticated_user
|
||||||
|
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||||
|
|
||||||
|
|
@ -42,7 +44,6 @@ origins = [
|
||||||
"http://127.0.0.1:3000",
|
"http://127.0.0.1:3000",
|
||||||
"http://frontend:3000",
|
"http://frontend:3000",
|
||||||
"http://localhost:3000",
|
"http://localhost:3000",
|
||||||
"http://localhost:3001",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
|
@ -58,39 +59,57 @@ from cognee.api.v1.users.routers import get_auth_router, get_register_router,\
|
||||||
|
|
||||||
from cognee.api.v1.permissions.get_permissions_router import get_permissions_router
|
from cognee.api.v1.permissions.get_permissions_router import get_permissions_router
|
||||||
|
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def request_validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
|
if request.url.path == "/api/v1/auth/login":
|
||||||
|
return JSONResponse(
|
||||||
|
status_code = 400,
|
||||||
|
content = {"detail": "LOGIN_BAD_CREDENTIALS"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code = 400,
|
||||||
|
content = jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||||
|
)
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
get_auth_router(),
|
get_auth_router(),
|
||||||
prefix = "/auth/jwt",
|
prefix = "/api/v1/auth",
|
||||||
tags = ["auth"]
|
tags = ["auth"]
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
get_register_router(),
|
get_register_router(),
|
||||||
prefix = "/auth",
|
prefix = "/api/v1/auth",
|
||||||
tags = ["auth"],
|
tags = ["auth"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
get_reset_password_router(),
|
get_reset_password_router(),
|
||||||
prefix = "/auth",
|
prefix = "/api/v1/auth",
|
||||||
tags = ["auth"],
|
tags = ["auth"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
get_verify_router(),
|
get_verify_router(),
|
||||||
prefix = "/auth",
|
prefix = "/api/v1/auth",
|
||||||
tags = ["auth"],
|
tags = ["auth"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
get_users_router(),
|
get_users_router(),
|
||||||
prefix = "/users",
|
prefix = "/api/v1/users",
|
||||||
tags = ["users"],
|
tags = ["users"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
get_permissions_router(),
|
get_permissions_router(),
|
||||||
prefix = "/permissions",
|
prefix = "/api/v1/permissions",
|
||||||
tags = ["permissions"],
|
tags = ["permissions"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -108,31 +127,42 @@ def health_check():
|
||||||
"""
|
"""
|
||||||
return {"status": "OK"}
|
return {"status": "OK"}
|
||||||
|
|
||||||
@app.get("/datasets", response_model = list)
|
@app.get("/api/v1/datasets", response_model = list)
|
||||||
async def get_datasets():
|
async def get_datasets(user: User = Depends(get_authenticated_user)):
|
||||||
try:
|
try:
|
||||||
from cognee.api.v1.datasets.datasets import datasets
|
from cognee.modules.data.methods import get_datasets
|
||||||
datasets = await datasets.list_datasets()
|
datasets = await get_datasets(user.id)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code = 200,
|
||||||
content = [dataset.to_json() for dataset in datasets],
|
content = [dataset.to_json() for dataset in datasets],
|
||||||
)
|
)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
raise HTTPException(status_code = 500, detail=f"Error retrieving datasets: {str(error)}") from error
|
raise HTTPException(status_code = 500, detail = f"Error retrieving datasets: {str(error)}") from error
|
||||||
|
|
||||||
@app.delete("/datasets/{dataset_id}", response_model = dict)
|
@app.delete("/api/v1/datasets/{dataset_id}", response_model = dict)
|
||||||
async def delete_dataset(dataset_id: str):
|
async def delete_dataset(dataset_id: str, user: User = Depends(get_authenticated_user)):
|
||||||
from cognee.api.v1.datasets.datasets import datasets
|
from cognee.modules.data.methods import get_dataset, delete_dataset
|
||||||
await datasets.delete_dataset(dataset_id)
|
|
||||||
|
dataset = get_dataset(user.id, dataset_id)
|
||||||
|
|
||||||
|
if dataset is None:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code = 404,
|
||||||
|
content = {
|
||||||
|
"detail": f"Dataset ({dataset_id}) not found."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await delete_dataset(dataset)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code = 200,
|
||||||
content = "OK",
|
content = "OK",
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.get("/datasets/{dataset_id}/graph", response_model=list)
|
@app.get("/api/v1/datasets/{dataset_id}/graph", response_model=list)
|
||||||
async def get_dataset_graph(dataset_id: str):
|
async def get_dataset_graph(dataset_id: str, user: User = Depends(get_authenticated_user)):
|
||||||
from cognee.shared.utils import render_graph
|
from cognee.shared.utils import render_graph
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
|
||||||
|
|
@ -150,21 +180,31 @@ async def get_dataset_graph(dataset_id: str):
|
||||||
content = "Graphistry credentials are not set. Please set them in your .env file.",
|
content = "Graphistry credentials are not set. Please set them in your .env file.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.get("/datasets/{dataset_id}/data", response_model=list)
|
@app.get("/api/v1/datasets/{dataset_id}/data", response_model=list)
|
||||||
async def get_dataset_data(dataset_id: str):
|
async def get_dataset_data(dataset_id: str, user: User = Depends(get_authenticated_user)):
|
||||||
from cognee.api.v1.datasets.datasets import datasets
|
from cognee.modules.data.methods import get_dataset_data, get_dataset
|
||||||
|
|
||||||
dataset_data = await datasets.list_data(dataset_id = dataset_id)
|
dataset = await get_dataset(user.id, dataset_id)
|
||||||
|
|
||||||
|
if dataset is None:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code = 404,
|
||||||
|
content = {
|
||||||
|
"detail": f"Dataset ({dataset_id}) not found."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_data = await get_dataset_data(dataset_id = dataset.id)
|
||||||
|
|
||||||
if dataset_data is None:
|
if dataset_data is None:
|
||||||
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")
|
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset.id}) not found.")
|
||||||
|
|
||||||
return [
|
return [
|
||||||
data.to_json() for data in dataset_data
|
data.to_json() for data in dataset_data
|
||||||
]
|
]
|
||||||
|
|
||||||
@app.get("/datasets/status", response_model=dict)
|
@app.get("/api/v1/datasets/status", response_model=dict)
|
||||||
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None):
|
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None, user: User = Depends(get_authenticated_user)):
|
||||||
from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
|
from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -180,15 +220,35 @@ async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset
|
||||||
content = {"error": str(error)}
|
content = {"error": str(error)}
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
@app.get("/api/v1/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
||||||
async def get_raw_data(dataset_id: str, data_id: str):
|
async def get_raw_data(dataset_id: str, data_id: str, user: User = Depends(get_authenticated_user)):
|
||||||
from cognee.api.v1.datasets.datasets import datasets
|
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
||||||
dataset_data = await datasets.list_data(dataset_id)
|
|
||||||
|
dataset = await get_dataset(user.id, dataset_id)
|
||||||
|
|
||||||
|
if dataset is None:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code = 404,
|
||||||
|
content = {
|
||||||
|
"detail": f"Dataset ({dataset_id}) not found."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_data = await get_dataset_data(dataset.id)
|
||||||
|
|
||||||
if dataset_data is None:
|
if dataset_data is None:
|
||||||
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")
|
raise HTTPException(status_code = 404, detail = f"Dataset ({dataset_id}) not found.")
|
||||||
|
|
||||||
data = [data for data in dataset_data if str(data.id) == data_id][0]
|
data = [data for data in dataset_data if str(data.id) == data_id][0]
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code = 404,
|
||||||
|
content = {
|
||||||
|
"detail": f"Data ({data_id}) not found in dataset ({dataset_id})."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return data.raw_data_location
|
return data.raw_data_location
|
||||||
|
|
||||||
class AddPayload(BaseModel):
|
class AddPayload(BaseModel):
|
||||||
|
|
@ -197,10 +257,11 @@ class AddPayload(BaseModel):
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@app.post("/add", response_model=dict)
|
@app.post("/api/v1/add", response_model=dict)
|
||||||
async def add(
|
async def add(
|
||||||
data: List[UploadFile],
|
data: List[UploadFile],
|
||||||
datasetId: str = Form(...),
|
datasetId: str = Form(...),
|
||||||
|
user: User = Depends(get_authenticated_user),
|
||||||
):
|
):
|
||||||
""" This endpoint is responsible for adding data to the graph."""
|
""" This endpoint is responsible for adding data to the graph."""
|
||||||
from cognee.api.v1.add import add as cognee_add
|
from cognee.api.v1.add import add as cognee_add
|
||||||
|
|
@ -230,6 +291,7 @@ async def add(
|
||||||
await cognee_add(
|
await cognee_add(
|
||||||
data,
|
data,
|
||||||
datasetId,
|
datasetId,
|
||||||
|
user = user,
|
||||||
)
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code = 200,
|
||||||
|
|
@ -246,12 +308,12 @@ async def add(
|
||||||
class CognifyPayload(BaseModel):
|
class CognifyPayload(BaseModel):
|
||||||
datasets: List[str]
|
datasets: List[str]
|
||||||
|
|
||||||
@app.post("/cognify", response_model=dict)
|
@app.post("/api/v1/cognify", response_model=dict)
|
||||||
async def cognify(payload: CognifyPayload):
|
async def cognify(payload: CognifyPayload, user: User = Depends(get_authenticated_user)):
|
||||||
""" This endpoint is responsible for the cognitive processing of the content."""
|
""" This endpoint is responsible for the cognitive processing of the content."""
|
||||||
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
|
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
|
||||||
try:
|
try:
|
||||||
await cognee_cognify(payload.datasets)
|
await cognee_cognify(payload.datasets, user)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code = 200,
|
||||||
content = {
|
content = {
|
||||||
|
|
@ -267,8 +329,8 @@ async def cognify(payload: CognifyPayload):
|
||||||
class SearchPayload(BaseModel):
|
class SearchPayload(BaseModel):
|
||||||
query_params: Dict[str, Any]
|
query_params: Dict[str, Any]
|
||||||
|
|
||||||
@app.post("/search", response_model=dict)
|
@app.post("/api/v1/search", response_model=dict)
|
||||||
async def search(payload: SearchPayload):
|
async def search(payload: SearchPayload, user: User = Depends(get_authenticated_user)):
|
||||||
""" This endpoint is responsible for searching for nodes in the graph."""
|
""" This endpoint is responsible for searching for nodes in the graph."""
|
||||||
from cognee.api.v1.search import search as cognee_search
|
from cognee.api.v1.search import search as cognee_search
|
||||||
try:
|
try:
|
||||||
|
|
@ -290,8 +352,8 @@ async def search(payload: SearchPayload):
|
||||||
content = {"error": str(error)}
|
content = {"error": str(error)}
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.get("/settings", response_model=dict)
|
@app.get("/api/v1/settings", response_model=dict)
|
||||||
async def get_settings():
|
async def get_settings(user: User = Depends(get_authenticated_user)):
|
||||||
from cognee.modules.settings import get_settings as get_cognee_settings
|
from cognee.modules.settings import get_settings as get_cognee_settings
|
||||||
return get_cognee_settings()
|
return get_cognee_settings()
|
||||||
|
|
||||||
|
|
@ -309,8 +371,8 @@ class SettingsPayload(BaseModel):
|
||||||
llm: Optional[LLMConfig] = None
|
llm: Optional[LLMConfig] = None
|
||||||
vectorDB: Optional[VectorDBConfig] = None
|
vectorDB: Optional[VectorDBConfig] = None
|
||||||
|
|
||||||
@app.post("/settings", response_model=dict)
|
@app.post("/api/v1/settings", response_model=dict)
|
||||||
async def save_config(new_settings: SettingsPayload):
|
async def save_config(new_settings: SettingsPayload, user: User = Depends(get_authenticated_user)):
|
||||||
from cognee.modules.settings import save_llm_config, save_vector_db_config
|
from cognee.modules.settings import save_llm_config, save_vector_db_config
|
||||||
if new_settings.llm is not None:
|
if new_settings.llm is not None:
|
||||||
await save_llm_config(new_settings.llm)
|
await save_llm_config(new_settings.llm)
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,10 @@ from cognee.modules.ingestion import get_matched_datasets, save_data_to_file
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
from cognee.infrastructure.databases.relational import get_relational_config, get_relational_engine, create_db_and_tables
|
from cognee.infrastructure.databases.relational import get_relational_config, get_relational_engine, create_db_and_tables
|
||||||
from cognee.modules.users.methods import create_default_user, get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.users.permissions.methods import give_permission_on_document
|
from cognee.modules.users.permissions.methods import give_permission_on_document
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.data.operations.ensure_dataset_exists import ensure_dataset_exists
|
from cognee.modules.data.methods import create_dataset
|
||||||
|
|
||||||
async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
|
async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
|
||||||
await create_db_and_tables()
|
await create_db_and_tables()
|
||||||
|
|
@ -55,7 +55,10 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def add_files(file_paths: List[str], dataset_name: str, user):
|
async def add_files(file_paths: List[str], dataset_name: str, user: User = None):
|
||||||
|
if user is None:
|
||||||
|
user = await get_default_user()
|
||||||
|
|
||||||
base_config = get_base_config()
|
base_config = get_base_config()
|
||||||
data_directory_path = base_config.data_root_directory
|
data_directory_path = base_config.data_root_directory
|
||||||
|
|
||||||
|
|
@ -101,7 +104,6 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_name = dataset_name.replace(" ", "_").replace(".", "_") if dataset_name is not None else "main_dataset"
|
dataset_name = dataset_name.replace(" ", "_").replace(".", "_") if dataset_name is not None else "main_dataset"
|
||||||
dataset = await ensure_dataset_exists(dataset_name)
|
|
||||||
|
|
||||||
@dlt.resource(standalone = True, merge_key = "id")
|
@dlt.resource(standalone = True, merge_key = "id")
|
||||||
async def data_resources(file_paths: str, user: User):
|
async def data_resources(file_paths: str, user: User):
|
||||||
|
|
@ -115,8 +117,12 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from cognee.modules.data.models import Data
|
from cognee.modules.data.models import Data
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
|
dataset = await create_dataset(dataset_name, user.id, session)
|
||||||
|
|
||||||
data = (await session.execute(
|
data = (await session.execute(
|
||||||
select(Data).filter(Data.id == data_id)
|
select(Data).filter(Data.id == data_id)
|
||||||
)).scalar_one_or_none()
|
)).scalar_one_or_none()
|
||||||
|
|
@ -137,10 +143,8 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
|
||||||
extension = file_metadata["extension"],
|
extension = file_metadata["extension"],
|
||||||
mime_type = file_metadata["mime_type"],
|
mime_type = file_metadata["mime_type"],
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset.data.append(data)
|
dataset.data.append(data)
|
||||||
|
|
||||||
await session.merge(dataset)
|
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
|
|
@ -155,12 +159,6 @@ async def add_files(file_paths: List[str], dataset_name: str, user):
|
||||||
await give_permission_on_document(user, data_id, "write")
|
await give_permission_on_document(user, data_id, "write")
|
||||||
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
user = await get_default_user()
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
user = await create_default_user()
|
|
||||||
|
|
||||||
run_info = pipeline.run(
|
run_info = pipeline.run(
|
||||||
data_resources(processed_file_paths, user),
|
data_resources(processed_file_paths, user),
|
||||||
table_name = "file_metadata",
|
table_name = "file_metadata",
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,10 @@ import logging
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from cognee.modules.cognify.config import get_cognify_config
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.modules.data.models import Dataset, Data
|
from cognee.modules.data.models import Dataset, Data
|
||||||
from cognee.modules.data.operations.get_dataset_data import get_dataset_data
|
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||||
from cognee.modules.data.operations.retrieve_datasets import retrieve_datasets
|
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||||
from cognee.modules.pipelines.tasks.Task import Task
|
from cognee.modules.pipelines.tasks.Task import Task
|
||||||
from cognee.modules.pipelines import run_tasks, run_tasks_parallel
|
from cognee.modules.pipelines import run_tasks, run_tasks_parallel
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
|
@ -35,17 +34,18 @@ class PermissionDeniedException(Exception):
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|
||||||
async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
|
async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
|
||||||
db_engine = get_relational_engine()
|
|
||||||
|
|
||||||
if datasets is None or len(datasets) == 0:
|
|
||||||
return await cognify(await db_engine.get_datasets())
|
|
||||||
|
|
||||||
if type(datasets[0]) == str:
|
|
||||||
datasets = await retrieve_datasets(datasets)
|
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
|
existing_datasets = await get_datasets(user.id)
|
||||||
|
|
||||||
|
if datasets is None or len(datasets) == 0:
|
||||||
|
# If no datasets are provided, cognify all existing datasets.
|
||||||
|
datasets = existing_datasets
|
||||||
|
|
||||||
|
if type(datasets[0]) == str:
|
||||||
|
datasets = await get_datasets_by_name(datasets, user.id)
|
||||||
|
|
||||||
async def run_cognify_pipeline(dataset: Dataset):
|
async def run_cognify_pipeline(dataset: Dataset):
|
||||||
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
|
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
|
||||||
|
|
||||||
|
|
@ -112,13 +112,16 @@ async def cognify(datasets: Union[str, list[str]] = None, user: User = None):
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
|
|
||||||
existing_datasets = [dataset.name for dataset in list(await db_engine.get_datasets())]
|
existing_datasets_map = {
|
||||||
|
generate_dataset_name(dataset.name): True for dataset in existing_datasets
|
||||||
|
}
|
||||||
|
|
||||||
awaitables = []
|
awaitables = []
|
||||||
|
|
||||||
for dataset in datasets:
|
for dataset in datasets:
|
||||||
dataset_name = generate_dataset_name(dataset.name)
|
dataset_name = generate_dataset_name(dataset.name)
|
||||||
|
|
||||||
if dataset_name in existing_datasets:
|
if dataset_name in existing_datasets_map:
|
||||||
awaitables.append(run_cognify_pipeline(dataset))
|
awaitables.append(run_cognify_pipeline(dataset))
|
||||||
|
|
||||||
return await asyncio.gather(*awaitables)
|
return await asyncio.gather(*awaitables)
|
||||||
|
|
|
||||||
|
|
@ -1,37 +1,34 @@
|
||||||
from duckdb import CatalogException
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.ingestion import discover_directory_datasets
|
from cognee.modules.ingestion import discover_directory_datasets
|
||||||
from cognee.modules.data.operations.get_dataset_data import get_dataset_data
|
|
||||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
|
||||||
|
|
||||||
class datasets():
|
class datasets():
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def list_datasets():
|
async def list_datasets():
|
||||||
db = get_relational_engine()
|
from cognee.modules.data.methods import get_datasets
|
||||||
return await db.get_datasets()
|
user = await get_default_user()
|
||||||
|
return await get_datasets(user.id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def discover_datasets(directory_path: str):
|
def discover_datasets(directory_path: str):
|
||||||
return list(discover_directory_datasets(directory_path).keys())
|
return list(discover_directory_datasets(directory_path).keys())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def list_data(dataset_id: str, dataset_name: str = None):
|
async def list_data(dataset_id: str):
|
||||||
try:
|
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
||||||
return await get_dataset_data(dataset_id = dataset_id, dataset_name = dataset_name)
|
user = await get_default_user()
|
||||||
except CatalogException:
|
|
||||||
return None
|
dataset = await get_dataset(user.id, dataset_id)
|
||||||
|
|
||||||
|
return await get_dataset_data(dataset.id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_status(dataset_ids: list[str]) -> dict:
|
async def get_status(dataset_ids: list[str]) -> dict:
|
||||||
try:
|
return await get_pipeline_status(dataset_ids)
|
||||||
return await get_pipeline_status(dataset_ids)
|
|
||||||
except CatalogException:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def delete_dataset(dataset_id: str):
|
async def delete_dataset(dataset_id: str):
|
||||||
db = get_relational_engine()
|
from cognee.modules.data.methods import get_dataset, delete_dataset
|
||||||
try:
|
user = await get_default_user()
|
||||||
return await db.delete_table(dataset_id)
|
dataset = await get_dataset(user.id, dataset_id)
|
||||||
except CatalogException:
|
|
||||||
return {}
|
return await delete_dataset(dataset)
|
||||||
|
|
|
||||||
|
|
@ -228,11 +228,21 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
# Log that the file does not exist and an empty graph is initialized
|
# Log that the file does not exist and an empty graph is initialized
|
||||||
logger.warning("File %s not found. Initializing an empty graph.", file_path)
|
logger.warning("File %s not found. Initializing an empty graph.", file_path)
|
||||||
self.graph = nx.MultiDiGraph() # Use MultiDiGraph to keep it consistent with __init__
|
self.graph = nx.MultiDiGraph() # Use MultiDiGraph to keep it consistent with __init__
|
||||||
|
|
||||||
|
file_dir = os.path.dirname(file_path)
|
||||||
|
if not os.path.exists(file_dir):
|
||||||
|
os.makedirs(file_dir, exist_ok = True)
|
||||||
|
|
||||||
await self.save_graph_to_file(file_path)
|
await self.save_graph_to_file(file_path)
|
||||||
except Exception as error:
|
except Exception:
|
||||||
logger.error("Failed to load graph from file: %s", file_path)
|
logger.error("Failed to load graph from file: %s", file_path)
|
||||||
# Initialize an empty graph in case of error
|
# Initialize an empty graph in case of error
|
||||||
self.graph = nx.MultiDiGraph()
|
self.graph = nx.MultiDiGraph()
|
||||||
|
|
||||||
|
file_dir = os.path.dirname(file_path)
|
||||||
|
if not os.path.exists(file_dir):
|
||||||
|
os.makedirs(file_dir, exist_ok = True)
|
||||||
|
|
||||||
await self.save_graph_to_file(file_path)
|
await self.save_graph_to_file(file_path)
|
||||||
|
|
||||||
async def delete_graph(self, file_path: str = None):
|
async def delete_graph(self, file_path: str = None):
|
||||||
|
|
|
||||||
11
cognee/modules/data/methods/__init__.py
Normal file
11
cognee/modules/data/methods/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
# Create
|
||||||
|
from .create_dataset import create_dataset
|
||||||
|
|
||||||
|
# Get
|
||||||
|
from .get_dataset import get_dataset
|
||||||
|
from .get_datasets import get_datasets
|
||||||
|
from .get_datasets_by_name import get_datasets_by_name
|
||||||
|
from .get_dataset_data import get_dataset_data
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
from .delete_dataset import delete_dataset
|
||||||
27
cognee/modules/data/methods/create_dataset.py
Normal file
27
cognee/modules/data/methods/create_dataset.py
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
from cognee.modules.data.models import Dataset
|
||||||
|
|
||||||
|
async def create_dataset(dataset_name: str, owner_id: UUID, session: AsyncSession) -> Dataset:
|
||||||
|
dataset = (await session.scalars(
|
||||||
|
select(Dataset)\
|
||||||
|
.options(joinedload(Dataset.data))\
|
||||||
|
.filter(Dataset.name == dataset_name)
|
||||||
|
.filter(Dataset.owner_id == owner_id)
|
||||||
|
)).first()
|
||||||
|
|
||||||
|
if dataset is None:
|
||||||
|
dataset = Dataset(
|
||||||
|
id = uuid5(NAMESPACE_OID, dataset_name),
|
||||||
|
name = dataset_name,
|
||||||
|
data = []
|
||||||
|
)
|
||||||
|
dataset.owner_id = owner_id
|
||||||
|
|
||||||
|
session.add(dataset)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return dataset
|
||||||
7
cognee/modules/data/methods/delete_dataset.py
Normal file
7
cognee/modules/data/methods/delete_dataset.py
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
from cognee.modules.data.models import Dataset
|
||||||
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
||||||
|
async def delete_dataset(dataset: Dataset):
|
||||||
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
|
return await db_engine.delete_table(dataset.id)
|
||||||
14
cognee/modules/data/methods/get_dataset.py
Normal file
14
cognee/modules/data/methods/get_dataset.py
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from ..models import Dataset
|
||||||
|
|
||||||
|
async def get_dataset(user_id: UUID, dataset_id: UUID) -> Dataset:
|
||||||
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
|
dataset = await session.get(Dataset, dataset_id)
|
||||||
|
|
||||||
|
if dataset and dataset.owner_id != user_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
@ -3,16 +3,14 @@ from sqlalchemy import select
|
||||||
from cognee.modules.data.models import Data, Dataset
|
from cognee.modules.data.models import Data, Dataset
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
||||||
async def get_dataset_data(dataset_id: UUID = None, dataset_name: str = None):
|
async def get_dataset_data(dataset_id: UUID) -> list[Data]:
|
||||||
if dataset_id is None and dataset_name is None:
|
|
||||||
raise ValueError("get_dataset_data: Either dataset_id or dataset_name must be provided.")
|
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id) | (Dataset.name == dataset_name))
|
select(Data).join(Data.datasets).filter((Dataset.id == dataset_id))
|
||||||
)
|
)
|
||||||
|
|
||||||
data = result.scalars().all()
|
data = result.scalars().all()
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
|
from uuid import UUID
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from ..models import Dataset
|
from ..models import Dataset
|
||||||
|
|
||||||
async def retrieve_datasets(dataset_names: list[str]) -> list[Dataset]:
|
async def get_datasets(user_id: UUID) -> list[Dataset]:
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
datasets = (await session.scalars(
|
datasets = (await session.scalars(
|
||||||
select(Dataset).filter(Dataset.name.in_(dataset_names))
|
select(Dataset).filter(Dataset.owner_id == user_id)
|
||||||
)).all()
|
)).all()
|
||||||
|
|
||||||
return datasets
|
return datasets
|
||||||
16
cognee/modules/data/methods/get_datasets_by_name.py
Normal file
16
cognee/modules/data/methods/get_datasets_by_name.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from sqlalchemy import select
|
||||||
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from ..models import Dataset
|
||||||
|
|
||||||
|
async def get_datasets_by_name(dataset_names: list[str], user_id: UUID) -> list[Dataset]:
|
||||||
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
|
datasets = (await session.scalars(
|
||||||
|
select(Dataset)
|
||||||
|
.filter(Dataset.owner_id == user_id)
|
||||||
|
.filter(Dataset.name.in_(dataset_names))
|
||||||
|
)).all()
|
||||||
|
|
||||||
|
return datasets
|
||||||
|
|
@ -21,7 +21,8 @@ class Data(Base):
|
||||||
|
|
||||||
datasets: Mapped[List["Dataset"]] = relationship(
|
datasets: Mapped[List["Dataset"]] = relationship(
|
||||||
secondary = DatasetData.__tablename__,
|
secondary = DatasetData.__tablename__,
|
||||||
back_populates = "data"
|
back_populates = "data",
|
||||||
|
lazy = "noload",
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_json(self) -> dict:
|
def to_json(self) -> dict:
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,12 @@ class Dataset(Base):
|
||||||
created_at = Column(DateTime(timezone = True), default = lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime(timezone = True), default = lambda: datetime.now(timezone.utc))
|
||||||
updated_at = Column(DateTime(timezone = True), onupdate = lambda: datetime.now(timezone.utc))
|
updated_at = Column(DateTime(timezone = True), onupdate = lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
owner_id = Column(UUID, index = True)
|
||||||
|
|
||||||
data: Mapped[List["Data"]] = relationship(
|
data: Mapped[List["Data"]] = relationship(
|
||||||
secondary = DatasetData.__tablename__,
|
secondary = DatasetData.__tablename__,
|
||||||
back_populates = "datasets"
|
back_populates = "datasets",
|
||||||
|
lazy = "noload",
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_json(self) -> dict:
|
def to_json(self) -> dict:
|
||||||
|
|
@ -27,5 +30,6 @@ class Dataset(Base):
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"createdAt": self.created_at.isoformat(),
|
"createdAt": self.created_at.isoformat(),
|
||||||
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
|
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
|
||||||
|
"ownerId": str(self.owner_id),
|
||||||
"data": [data.to_json() for data in self.data]
|
"data": [data.to_json() for data in self.data]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
from cognee.modules.data.models import Dataset
|
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
|
||||||
|
|
||||||
async def ensure_dataset_exists(dataset_name: str) -> Dataset:
|
|
||||||
db_engine = get_relational_engine()
|
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
|
||||||
dataset = (await session.scalars(
|
|
||||||
select(Dataset)\
|
|
||||||
.options(joinedload(Dataset.data))\
|
|
||||||
.filter(Dataset.name == dataset_name)
|
|
||||||
)).first()
|
|
||||||
|
|
||||||
if dataset is None:
|
|
||||||
dataset = Dataset(
|
|
||||||
name = dataset_name,
|
|
||||||
data = []
|
|
||||||
)
|
|
||||||
|
|
||||||
session.add(dataset)
|
|
||||||
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
|
|
@ -1,9 +1,16 @@
|
||||||
# from fastapi import Depends
|
from typing import AsyncGenerator
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||||
# from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from .models.User import User
|
from .models.User import User
|
||||||
|
|
||||||
async def get_user_db(session):
|
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
db_engine = get_relational_engine()
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||||
yield SQLAlchemyUserDatabase(session, User)
|
yield SQLAlchemyUserDatabase(session, User)
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
|
||||||
|
|
@ -2,16 +2,33 @@ import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import Depends, Request
|
from fastapi import Depends, Request
|
||||||
from fastapi_users import BaseUserManager, UUIDIDMixin
|
from fastapi_users.exceptions import UserNotExists
|
||||||
|
from fastapi_users import BaseUserManager, UUIDIDMixin, models
|
||||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||||
|
|
||||||
from .get_user_db import get_user_db
|
from .get_user_db import get_user_db
|
||||||
from .models import User
|
from .models import User
|
||||||
|
from .methods import get_user
|
||||||
|
|
||||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||||
reset_password_token_secret = os.getenv("FASTAPI_USERS_RESET_PASSWORD_TOKEN_SECRET", "super_secret")
|
reset_password_token_secret = os.getenv("FASTAPI_USERS_RESET_PASSWORD_TOKEN_SECRET", "super_secret")
|
||||||
verification_token_secret = os.getenv("FASTAPI_USERS_VERIFICATION_TOKEN_SECRET", "super_secret")
|
verification_token_secret = os.getenv("FASTAPI_USERS_VERIFICATION_TOKEN_SECRET", "super_secret")
|
||||||
|
|
||||||
|
async def get(self, id: models.ID) -> models.UP:
|
||||||
|
"""
|
||||||
|
Get a user by id.
|
||||||
|
|
||||||
|
:param id: Id. of the user to retrieve.
|
||||||
|
:raises UserNotExists: The user does not exist.
|
||||||
|
:return: A user.
|
||||||
|
"""
|
||||||
|
user = await get_user(id)
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise UserNotExists()
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
||||||
print(f"User {user.id} has registered.")
|
print(f"User {user.id} has registered.")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from .get_user import get_user
|
||||||
from .create_user import create_user
|
from .create_user import create_user
|
||||||
from .get_default_user import get_default_user
|
from .get_default_user import get_default_user
|
||||||
from .create_default_user import create_default_user
|
from .create_default_user import create_default_user
|
||||||
|
from .get_authenticated_user import get_authenticated_user
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import hashlib
|
|
||||||
from .create_user import create_user
|
from .create_user import create_user
|
||||||
|
|
||||||
async def create_default_user():
|
async def create_default_user():
|
||||||
|
|
@ -7,14 +6,11 @@ async def create_default_user():
|
||||||
|
|
||||||
user = await create_user(
|
user = await create_user(
|
||||||
email = default_user_email,
|
email = default_user_email,
|
||||||
password = await hash_password(default_user_password),
|
password = default_user_password,
|
||||||
is_superuser = True,
|
is_superuser = False,
|
||||||
is_active = True,
|
is_active = True,
|
||||||
is_verified = True,
|
is_verified = True,
|
||||||
auto_login = True,
|
auto_login = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def hash_password(password: str) -> str:
|
|
||||||
return hashlib.sha256(password.encode()).hexdigest()
|
|
||||||
|
|
|
||||||
5
cognee/modules/users/methods/get_authenticated_user.py
Normal file
5
cognee/modules/users/methods/get_authenticated_user.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
from ..get_fastapi_users import get_fastapi_users
|
||||||
|
|
||||||
|
fastapi_users = get_fastapi_users()
|
||||||
|
|
||||||
|
get_authenticated_user = fastapi_users.current_user(active = True, verified = True)
|
||||||
|
|
@ -2,6 +2,7 @@ from sqlalchemy.orm import joinedload
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from .create_default_user import create_default_user
|
||||||
|
|
||||||
async def get_default_user():
|
async def get_default_user():
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
@ -13,4 +14,7 @@ async def get_default_user():
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
user = result.scalars().first()
|
user = result.scalars().first()
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
return await create_default_user()
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
|
||||||
15
cognee/modules/users/methods/get_user.py
Normal file
15
cognee/modules/users/methods/get_user.py
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from ..models import User
|
||||||
|
|
||||||
|
async def get_user(user_id: UUID):
|
||||||
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
|
user = (await session.execute(
|
||||||
|
select(User).options(joinedload(User.groups)).where(User.id == user_id)
|
||||||
|
)).scalar()
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
@ -16,25 +16,21 @@ class PermissionDeniedException(Exception):
|
||||||
|
|
||||||
|
|
||||||
async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]):
|
async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]):
|
||||||
try:
|
user_group_ids = [group.id for group in user.groups]
|
||||||
user_group_ids = [group.id for group in user.groups]
|
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(ACL)
|
select(ACL)
|
||||||
.join(ACL.permission)
|
.join(ACL.permission)
|
||||||
.options(joinedload(ACL.resources))
|
.options(joinedload(ACL.resources))
|
||||||
.where(ACL.principal_id.in_([user.id, *user_group_ids]))
|
.where(ACL.principal_id.in_([user.id, *user_group_ids]))
|
||||||
.where(ACL.permission.has(name = permission_type))
|
.where(ACL.permission.has(name = permission_type))
|
||||||
)
|
)
|
||||||
acls = result.unique().scalars().all()
|
acls = result.unique().scalars().all()
|
||||||
resource_ids = [resource.resource_id for acl in acls for resource in acl.resources]
|
resource_ids = [resource.resource_id for acl in acls for resource in acl.resources]
|
||||||
has_permissions = all(document_id in resource_ids for document_id in document_ids)
|
has_permissions = all(document_id in resource_ids for document_id in document_ids)
|
||||||
|
|
||||||
if not has_permissions:
|
if not has_permissions:
|
||||||
raise PermissionDeniedException(f"User {user.username} does not have {permission_type} permission on documents")
|
raise PermissionDeniedException(f"User {user.username} does not have {permission_type} permission on documents")
|
||||||
except Exception as error:
|
|
||||||
logger.error("Error checking permissions on documents: %s", str(error))
|
|
||||||
raise
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,6 @@
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||||
|
|
||||||
|
|
||||||
# from cognee.infrastructure.databases.vector import get_vector_engine
|
|
||||||
|
|
||||||
|
|
||||||
async def chunk_remove_disconnected(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
|
async def chunk_remove_disconnected(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue