version 0.3.4 (#1433)
<!-- .github/pull_request_template.md --> ## Description <!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. --> ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Changes Made <!-- List the specific changes made in this PR --> - - - ## Testing <!-- Describe how you tested your changes --> ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## Related Issues <!-- Link any related issues using "Fixes #issue_number" or "Relates to #issue_number" --> ## Additional Notes <!-- Add any additional notes, concerns, or context for reviewers --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
commit
b7cec1b77d
64 changed files with 2397 additions and 1288 deletions
|
|
@ -47,6 +47,28 @@ BAML_LLM_API_VERSION=""
|
||||||
# DATA_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_data/'
|
# DATA_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_data/'
|
||||||
# SYSTEM_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_system/'
|
# SYSTEM_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_system/'
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# ☁️ Storage Backend Settings
|
||||||
|
################################################################################
|
||||||
|
# Configure storage backend (local filesystem or S3)
|
||||||
|
# STORAGE_BACKEND="local" # Default: uses local filesystem
|
||||||
|
#
|
||||||
|
# -- To switch to S3 storage, uncomment and fill these: ---------------------
|
||||||
|
# STORAGE_BACKEND="s3"
|
||||||
|
# STORAGE_BUCKET_NAME="your-bucket-name"
|
||||||
|
# AWS_REGION="us-east-1"
|
||||||
|
# AWS_ACCESS_KEY_ID="your-access-key"
|
||||||
|
# AWS_SECRET_ACCESS_KEY="your-secret-key"
|
||||||
|
#
|
||||||
|
# -- S3 Root Directories (optional) -----------------------------------------
|
||||||
|
# DATA_ROOT_DIRECTORY="s3://your-bucket/cognee/data"
|
||||||
|
# SYSTEM_ROOT_DIRECTORY="s3://your-bucket/cognee/system"
|
||||||
|
#
|
||||||
|
# -- Cache Directory (auto-configured for S3) -------------------------------
|
||||||
|
# When STORAGE_BACKEND=s3, cache automatically uses S3: s3://BUCKET/cognee/cache
|
||||||
|
# To override the automatic S3 cache location, uncomment:
|
||||||
|
# CACHE_ROOT_DIRECTORY="s3://your-bucket/cognee/cache"
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# 🗄️ Relational database settings
|
# 🗄️ Relational database settings
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
@ -94,7 +116,15 @@ VECTOR_DB_PROVIDER="lancedb"
|
||||||
VECTOR_DB_URL=
|
VECTOR_DB_URL=
|
||||||
VECTOR_DB_KEY=
|
VECTOR_DB_KEY=
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# 🧩 Ontology resolver settings
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
# -- Ontology resolver params --------------------------------------
|
||||||
|
# ONTOLOGY_RESOLVER=rdflib # Default: uses rdflib and owl file to read ontology structures
|
||||||
|
# MATCHING_STRATEGY=fuzzy # Default: uses fuzzy matching with 80% similarity threshold
|
||||||
|
# ONTOLOGY_FILE_PATH=YOUR_FULL_FULE_PATH # Default: empty
|
||||||
|
# To add ontology resolvers, either set them as it is set in ontology_example or add full_path and settings as envs.
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# 🔄 MIGRATION (RELATIONAL → GRAPH) SETTINGS
|
# 🔄 MIGRATION (RELATIONAL → GRAPH) SETTINGS
|
||||||
|
|
@ -121,6 +151,9 @@ ACCEPT_LOCAL_FILE_PATH=True
|
||||||
# This protects against Server Side Request Forgery when proper infrastructure is not in place.
|
# This protects against Server Side Request Forgery when proper infrastructure is not in place.
|
||||||
ALLOW_HTTP_REQUESTS=True
|
ALLOW_HTTP_REQUESTS=True
|
||||||
|
|
||||||
|
# When set to false don't allow cypher search to be used in Cognee.
|
||||||
|
ALLOW_CYPHER_QUERY=True
|
||||||
|
|
||||||
# When set to False errors during data processing will be returned as info but not raised to allow handling of faulty documents
|
# When set to False errors during data processing will be returned as info but not raised to allow handling of faulty documents
|
||||||
RAISE_INCREMENTAL_LOADING_ERRORS=True
|
RAISE_INCREMENTAL_LOADING_ERRORS=True
|
||||||
|
|
||||||
|
|
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -186,6 +186,7 @@ cognee/cache/
|
||||||
# Default cognee system directory, used in development
|
# Default cognee system directory, used in development
|
||||||
.cognee_system/
|
.cognee_system/
|
||||||
.data_storage/
|
.data_storage/
|
||||||
|
.cognee_cache/
|
||||||
.artifacts/
|
.artifacts/
|
||||||
.anon_id
|
.anon_id
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,8 @@ export default function Account() {
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="bg-gray-200 h-full max-w-[1920px] mx-auto">
|
<div className="h-full max-w-[1920px] mx-auto">
|
||||||
<video
|
{/* <video
|
||||||
autoPlay
|
autoPlay
|
||||||
loop
|
loop
|
||||||
muted
|
muted
|
||||||
|
|
@ -23,9 +23,9 @@ export default function Account() {
|
||||||
>
|
>
|
||||||
<source src="/videos/background-video-blur.mp4" type="video/mp4" />
|
<source src="/videos/background-video-blur.mp4" type="video/mp4" />
|
||||||
Your browser does not support the video tag.
|
Your browser does not support the video tag.
|
||||||
</video>
|
</video> */}
|
||||||
|
|
||||||
<Header />
|
<Header user={user} />
|
||||||
|
|
||||||
<div className="relative flex flex-row items-start gap-2.5">
|
<div className="relative flex flex-row items-start gap-2.5">
|
||||||
<Link href="/dashboard" className="flex-1/5 py-4 px-5 flex flex-row items-center gap-5">
|
<Link href="/dashboard" className="flex-1/5 py-4 px-5 flex flex-row items-center gap-5">
|
||||||
|
|
@ -42,7 +42,7 @@ export default function Account() {
|
||||||
<div>Plan</div>
|
<div>Plan</div>
|
||||||
<div className="text-sm text-gray-400 mb-8">You are using open-source version. Subscribe to get access to hosted cognee with your data!</div>
|
<div className="text-sm text-gray-400 mb-8">You are using open-source version. Subscribe to get access to hosted cognee with your data!</div>
|
||||||
<Link href="/plan">
|
<Link href="/plan">
|
||||||
<CTAButton><span className="">Select a plan</span></CTAButton>
|
<CTAButton className="w-full"><span className="">Select a plan</span></CTAButton>
|
||||||
</Link>
|
</Link>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import { FormEvent, useCallback, useState } from "react";
|
||||||
|
|
||||||
import { LoadingIndicator } from "@/ui/App";
|
import { LoadingIndicator } from "@/ui/App";
|
||||||
import { useModal } from "@/ui/elements/Modal";
|
import { useModal } from "@/ui/elements/Modal";
|
||||||
import { CloseIcon, PlusIcon } from "@/ui/Icons";
|
import { CloseIcon, MinusIcon, PlusIcon } from "@/ui/Icons";
|
||||||
import { CTAButton, GhostButton, IconButton, Modal, NeutralButton, Select } from "@/ui/elements";
|
import { CTAButton, GhostButton, IconButton, Modal, NeutralButton, Select } from "@/ui/elements";
|
||||||
|
|
||||||
import addData from "@/modules/ingestion/addData";
|
import addData from "@/modules/ingestion/addData";
|
||||||
|
|
@ -16,16 +16,22 @@ interface AddDataToCogneeProps {
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function AddDataToCognee({ datasets, refreshDatasets, useCloud = false }: AddDataToCogneeProps) {
|
export default function AddDataToCognee({ datasets, refreshDatasets, useCloud = false }: AddDataToCogneeProps) {
|
||||||
const [filesForUpload, setFilesForUpload] = useState<FileList | null>(null);
|
const [filesForUpload, setFilesForUpload] = useState<File[]>([]);
|
||||||
|
|
||||||
const prepareFiles = useCallback((event: FormEvent<HTMLInputElement>) => {
|
const addFiles = useCallback((event: FormEvent<HTMLInputElement>) => {
|
||||||
const formElements = event.currentTarget;
|
const formElements = event.currentTarget;
|
||||||
const files = formElements.files;
|
const newFiles = formElements.files;
|
||||||
|
|
||||||
setFilesForUpload(files);
|
if (newFiles?.length) {
|
||||||
|
setFilesForUpload((oldFiles) => [...oldFiles, ...Array.from(newFiles)]);
|
||||||
|
}
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const processDataWithCognee = useCallback((state: object, event?: FormEvent<HTMLFormElement>) => {
|
const removeFile = useCallback((file: File) => {
|
||||||
|
setFilesForUpload((oldFiles) => oldFiles.filter((f) => f !== file));
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const processDataWithCognee = useCallback((state?: object, event?: FormEvent<HTMLFormElement>) => {
|
||||||
event!.preventDefault();
|
event!.preventDefault();
|
||||||
|
|
||||||
if (!filesForUpload) {
|
if (!filesForUpload) {
|
||||||
|
|
@ -41,7 +47,7 @@ export default function AddDataToCognee({ datasets, refreshDatasets, useCloud =
|
||||||
} : {
|
} : {
|
||||||
name: "main_dataset",
|
name: "main_dataset",
|
||||||
},
|
},
|
||||||
Array.from(filesForUpload),
|
filesForUpload,
|
||||||
useCloud
|
useCloud
|
||||||
)
|
)
|
||||||
.then(({ dataset_id, dataset_name }) => {
|
.then(({ dataset_id, dataset_name }) => {
|
||||||
|
|
@ -57,7 +63,7 @@ export default function AddDataToCognee({ datasets, refreshDatasets, useCloud =
|
||||||
useCloud,
|
useCloud,
|
||||||
)
|
)
|
||||||
.then(() => {
|
.then(() => {
|
||||||
setFilesForUpload(null);
|
setFilesForUpload([]);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}, [filesForUpload, refreshDatasets, useCloud]);
|
}, [filesForUpload, refreshDatasets, useCloud]);
|
||||||
|
|
@ -86,24 +92,25 @@ export default function AddDataToCognee({ datasets, refreshDatasets, useCloud =
|
||||||
<div className="mt-8 mb-6">Please select a {useCloud ? "cloud" : "local"} dataset to add data in.<br/> If you don't have any, don't worry, we will create one for you.</div>
|
<div className="mt-8 mb-6">Please select a {useCloud ? "cloud" : "local"} dataset to add data in.<br/> If you don't have any, don't worry, we will create one for you.</div>
|
||||||
<form onSubmit={submitDataToCognee}>
|
<form onSubmit={submitDataToCognee}>
|
||||||
<div className="max-w-md flex flex-col gap-4">
|
<div className="max-w-md flex flex-col gap-4">
|
||||||
<Select name="datasetName">
|
<Select defaultValue={datasets.length ? datasets[0].id : ""} name="datasetName">
|
||||||
{!datasets.length && <option value="">main_dataset</option>}
|
{!datasets.length && <option value="">main_dataset</option>}
|
||||||
{datasets.map((dataset: Dataset, index) => (
|
{datasets.map((dataset: Dataset) => (
|
||||||
<option selected={index===0} key={dataset.id} value={dataset.id}>{dataset.name}</option>
|
<option key={dataset.id} value={dataset.id}>{dataset.name}</option>
|
||||||
))}
|
))}
|
||||||
</Select>
|
</Select>
|
||||||
|
|
||||||
<NeutralButton className="w-full relative justify-start pl-4">
|
<NeutralButton className="w-full relative justify-start pl-4">
|
||||||
<input onChange={prepareFiles} required name="files" tabIndex={-1} type="file" multiple className="absolute w-full h-full cursor-pointer opacity-0" />
|
<input onChange={addFiles} required name="files" tabIndex={-1} type="file" multiple className="absolute w-full h-full cursor-pointer opacity-0" />
|
||||||
<span>select files</span>
|
<span>select files</span>
|
||||||
</NeutralButton>
|
</NeutralButton>
|
||||||
|
|
||||||
{filesForUpload?.length && (
|
{!!filesForUpload.length && (
|
||||||
<div className="pt-4 mt-4 border-t-1 border-t-gray-100">
|
<div className="pt-4 mt-4 border-t-1 border-t-gray-100">
|
||||||
<div className="mb-1.5">selected files:</div>
|
<div className="mb-1.5">selected files:</div>
|
||||||
{Array.from(filesForUpload || []).map((file) => (
|
{filesForUpload.map((file) => (
|
||||||
<div key={file.name} className="py-1.5 pl-2">
|
<div key={file.name} className="py-1.5 pl-2 flex flex-row items-center justify-between w-full">
|
||||||
<span className="text-sm">{file.name}</span>
|
<span className="text-sm">{file.name}</span>
|
||||||
|
<IconButton onClick={removeFile.bind(null, file)}><MinusIcon /></IconButton>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
import { Header } from "@/ui/Layout";
|
import { Header } from "@/ui/Layout";
|
||||||
import { SearchIcon } from "@/ui/Icons";
|
import { SearchIcon } from "@/ui/Icons";
|
||||||
import { Notebook } from "@/ui/elements";
|
import { CTAButton, Notebook } from "@/ui/elements";
|
||||||
import { fetch, isCloudEnvironment } from "@/utils";
|
import { fetch, isCloudEnvironment } from "@/utils";
|
||||||
import { Notebook as NotebookType } from "@/ui/elements/Notebook/types";
|
import { Notebook as NotebookType } from "@/ui/elements/Notebook/types";
|
||||||
import { useAuthenticatedUser } from "@/modules/auth";
|
import { useAuthenticatedUser } from "@/modules/auth";
|
||||||
|
|
@ -111,8 +111,8 @@ export default function Dashboard({ accessToken }: DashboardProps) {
|
||||||
const isCloudEnv = isCloudEnvironment();
|
const isCloudEnv = isCloudEnvironment();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="h-full flex flex-col bg-gray-200">
|
<div className="h-full flex flex-col">
|
||||||
<video
|
{/* <video
|
||||||
autoPlay
|
autoPlay
|
||||||
loop
|
loop
|
||||||
muted
|
muted
|
||||||
|
|
@ -121,12 +121,12 @@ export default function Dashboard({ accessToken }: DashboardProps) {
|
||||||
>
|
>
|
||||||
<source src="/videos/background-video-blur.mp4" type="video/mp4" />
|
<source src="/videos/background-video-blur.mp4" type="video/mp4" />
|
||||||
Your browser does not support the video tag.
|
Your browser does not support the video tag.
|
||||||
</video>
|
</video> */}
|
||||||
|
|
||||||
<Header user={user} />
|
<Header user={user} />
|
||||||
|
|
||||||
<div className="relative flex-1 flex flex-row gap-2.5 items-start w-full max-w-[1920px] max-h-[calc(100% - 3.5rem)] overflow-hidden mx-auto px-2.5 py-2.5">
|
<div className="relative flex-1 flex flex-row gap-2.5 items-start w-full max-w-[1920px] max-h-[calc(100% - 3.5rem)] overflow-hidden mx-auto px-2.5 pb-2.5">
|
||||||
<div className="px-5 py-4 lg:w-96 bg-white rounded-xl min-h-full">
|
<div className="px-5 py-4 lg:w-96 bg-white rounded-xl h-[calc(100%-2.75rem)]">
|
||||||
<div className="relative mb-2">
|
<div className="relative mb-2">
|
||||||
<label htmlFor="search-input"><SearchIcon className="absolute left-3 top-[10px] cursor-text" /></label>
|
<label htmlFor="search-input"><SearchIcon className="absolute left-3 top-[10px] cursor-text" /></label>
|
||||||
<input id="search-input" className="text-xs leading-3 w-full h-8 flex flex-row items-center gap-2.5 rounded-3xl pl-9 placeholder-gray-300 border-gray-300 border-[1px] focus:outline-indigo-600" placeholder="Search datasets..." />
|
<input id="search-input" className="text-xs leading-3 w-full h-8 flex flex-row items-center gap-2.5 rounded-3xl pl-9 placeholder-gray-300 border-gray-300 border-[1px] focus:outline-indigo-600" placeholder="Search datasets..." />
|
||||||
|
|
@ -152,6 +152,12 @@ export default function Dashboard({ accessToken }: DashboardProps) {
|
||||||
/>
|
/>
|
||||||
</CogneeInstancesAccordion>
|
</CogneeInstancesAccordion>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div className="fixed bottom-2.5 w-[calc(min(1920px,100%)/5)] lg:w-96 ml-[-1.25rem] mx-auto">
|
||||||
|
<a href="/plan">
|
||||||
|
<CTAButton className="w-full">Select a plan</CTAButton>
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex-1 flex flex-col justify-between h-full overflow-y-auto">
|
<div className="flex-1 flex flex-col justify-between h-full overflow-y-auto">
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,8 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
|
||||||
};
|
};
|
||||||
|
|
||||||
checkConnectionToLocalCognee();
|
checkConnectionToLocalCognee();
|
||||||
}, [setCloudCogneeConnected, setLocalCogneeConnected]);
|
checkConnectionToCloudCognee();
|
||||||
|
}, [checkConnectionToCloudCognee, setCloudCogneeConnected, setLocalCogneeConnected]);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
value: isCloudConnectedModalOpen,
|
value: isCloudConnectedModalOpen,
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,8 @@ import { useBoolean } from "@/utils";
|
||||||
import { Accordion, CTAButton, GhostButton, IconButton, Input, Modal } from "@/ui/elements";
|
import { Accordion, CTAButton, GhostButton, IconButton, Input, Modal } from "@/ui/elements";
|
||||||
import { CloseIcon, MinusIcon, NotebookIcon, PlusIcon } from "@/ui/Icons";
|
import { CloseIcon, MinusIcon, NotebookIcon, PlusIcon } from "@/ui/Icons";
|
||||||
import { Notebook } from "@/ui/elements/Notebook/types";
|
import { Notebook } from "@/ui/elements/Notebook/types";
|
||||||
import { LoadingIndicator } from "@/ui/App";
|
|
||||||
import { useModal } from "@/ui/elements/Modal";
|
import { useModal } from "@/ui/elements/Modal";
|
||||||
|
import { LoadingIndicator } from "@/ui/App";
|
||||||
|
|
||||||
interface NotebooksAccordionProps {
|
interface NotebooksAccordionProps {
|
||||||
notebooks: Notebook[];
|
notebooks: Notebook[];
|
||||||
|
|
@ -60,7 +60,7 @@ export default function NotebooksAccordion({
|
||||||
.finally(() => setNotebookToRemove(null));
|
.finally(() => setNotebookToRemove(null));
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleNotebookAdd = useCallback((_: object, formEvent?: FormEvent<HTMLFormElement>) => {
|
const handleNotebookAdd = useCallback((_: Notebook, formEvent?: FormEvent<HTMLFormElement>) => {
|
||||||
if (!formEvent) {
|
if (!formEvent) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
@ -71,6 +71,7 @@ export default function NotebooksAccordion({
|
||||||
const notebookName = formElements.notebookName.value.trim();
|
const notebookName = formElements.notebookName.value.trim();
|
||||||
|
|
||||||
return addNotebook(notebookName)
|
return addNotebook(notebookName)
|
||||||
|
.then(() => {});
|
||||||
}, [addNotebook]);
|
}, [addNotebook]);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
|
|
@ -79,7 +80,7 @@ export default function NotebooksAccordion({
|
||||||
closeModal: closeNewNotebookModal,
|
closeModal: closeNewNotebookModal,
|
||||||
confirmAction: handleNewNotebookSubmit,
|
confirmAction: handleNewNotebookSubmit,
|
||||||
isActionLoading: isNewDatasetLoading,
|
isActionLoading: isNewDatasetLoading,
|
||||||
} = useModal<Notebook | void>(false, handleNotebookAdd);
|
} = useModal<Notebook>(false, handleNotebookAdd);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
|
|
@ -91,7 +92,7 @@ export default function NotebooksAccordion({
|
||||||
tools={isNewDatasetLoading ? (
|
tools={isNewDatasetLoading ? (
|
||||||
<LoadingIndicator />
|
<LoadingIndicator />
|
||||||
) : (
|
) : (
|
||||||
<IconButton onClick={openNewNotebookModal}><PlusIcon /></IconButton>
|
<IconButton onClick={() => openNewNotebookModal()}><PlusIcon /></IconButton>
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{notebooks.length === 0 && (
|
{notebooks.length === 0 && (
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@
|
||||||
--global-color-primary-active: #500cc5 !important;
|
--global-color-primary-active: #500cc5 !important;
|
||||||
--global-color-primary-text: white !important;
|
--global-color-primary-text: white !important;
|
||||||
--global-color-secondary: #0DFF00 !important;
|
--global-color-secondary: #0DFF00 !important;
|
||||||
--global-background-default: #0D051C;
|
--global-background-default: #F4F4F4;
|
||||||
--textarea-default-color: #0D051C !important;
|
--textarea-default-color: #0D051C !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -20,6 +20,7 @@ body {
|
||||||
height: 100%;
|
height: 100%;
|
||||||
max-width: 100vw;
|
max-width: 100vw;
|
||||||
overflow-x: hidden;
|
overflow-x: hidden;
|
||||||
|
background-color: var(--global-background-default);
|
||||||
}
|
}
|
||||||
|
|
||||||
a {
|
a {
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,17 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { BackIcon, CheckIcon } from "@/ui/Icons";
|
import { BackIcon, CheckIcon } from "@/ui/Icons";
|
||||||
import { CTAButton, NeutralButton } from "@/ui/elements";
|
import { CTAButton, NeutralButton } from "@/ui/elements";
|
||||||
import Header from "@/ui/Layout/Header";
|
import Header from "@/ui/Layout/Header";
|
||||||
|
import { useAuthenticatedUser } from "@/modules/auth";
|
||||||
|
|
||||||
export default function Plan() {
|
export default function Plan() {
|
||||||
|
const { user } = useAuthenticatedUser();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="bg-gray-200 h-full max-w-[1920px] mx-auto">
|
<div className="h-full max-w-[1920px] mx-auto">
|
||||||
<video
|
{/* <video
|
||||||
autoPlay
|
autoPlay
|
||||||
loop
|
loop
|
||||||
muted
|
muted
|
||||||
|
|
@ -15,88 +20,104 @@ export default function Plan() {
|
||||||
>
|
>
|
||||||
<source src="/videos/background-video-blur.mp4" type="video/mp4" />
|
<source src="/videos/background-video-blur.mp4" type="video/mp4" />
|
||||||
Your browser does not support the video tag.
|
Your browser does not support the video tag.
|
||||||
</video>
|
</video> */}
|
||||||
|
|
||||||
<Header />
|
<Header user={user} />
|
||||||
|
|
||||||
<div className="relative flex flex-row items-start justify-stretch gap-2.5">
|
<div className="relative flex flex-row items-start justify-stretch gap-2.5">
|
||||||
<div className="flex-1/5 h-full">
|
<div className="flex-1/5 h-full">
|
||||||
<Link href="/dashboard" className="py-4 px-5 flex flex-row items-center gap-5">
|
<div className="flex flex-col justify-between">
|
||||||
<BackIcon />
|
<Link href="/dashboard" className="py-4 px-5 flex flex-row items-center gap-5">
|
||||||
<span>back</span>
|
<BackIcon />
|
||||||
</Link>
|
<span>back</span>
|
||||||
|
</Link>
|
||||||
|
|
||||||
|
{/* <div className="fixed bottom-6 w-[calc(min(1920px,100%)/5)] mx-auto">
|
||||||
|
<div className="text-sm mb-2"></div>
|
||||||
|
<a href="/plan">
|
||||||
|
<CTAButton className="w-full">Select a plan</CTAButton>
|
||||||
|
</a>
|
||||||
|
</div> */}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex-3/5">
|
<div className="flex-3/5">
|
||||||
<div className="bg-[rgba(255,255,255,0.7)] rounded-xl px-5 py-4 mb-2">
|
<div className="bg-white rounded-xl px-5 py-5 mb-2">
|
||||||
Affordable and transparent pricing
|
Affordable and transparent pricing
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="grid grid-cols-3 gap-x-2.5">
|
<div className="grid grid-cols-3 gap-x-2.5">
|
||||||
<div className="pt-13 py-4 px-5 mb-2.5 rounded-tl-xl rounded-tr-xl bg-[rgba(255,255,255,0.7)] h-full">
|
<div className="pt-13 py-4 px-5 mb-2.5 rounded-tl-xl rounded-tr-xl bg-white h-full">
|
||||||
<div>Basic</div>
|
<div>Basic</div>
|
||||||
<div className="text-3xl mb-4 font-bold">Free</div>
|
<div className="text-[1.75rem] mb-4 font-bold">Free</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="pt-13 py-4 px-5 mb-2.5 rounded-tl-xl rounded-tr-xl bg-[rgba(255,255,255,0.7)] h-full">
|
<div className="pt-5 py-4 px-5 mb-2.5 rounded-tl-xl rounded-tr-xl bg-white h-full border-indigo-600 border-1 border-b-0">
|
||||||
|
<div className="text-indigo-600 mb-5 text-xs font-black">Most Popular</div>
|
||||||
<div>On-prem Subscription</div>
|
<div>On-prem Subscription</div>
|
||||||
<div className="mb-4"><span className="text-3xl font-bold">$2470</span><span className="text-gray-400"> /per month</span></div>
|
<div className="mb-2"><span className="text-[1.75rem] font-bold">$2470</span><span className="text-gray-400"> /per month</span></div>
|
||||||
<div className="mb-9"><span className="font-bold">Save 20% </span>yearly</div>
|
<div className=""><span className="font-black">Save 20% </span>yearly</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="pt-13 py-4 px-5 mb-2.5 rounded-tl-xl rounded-tr-xl bg-[rgba(255,255,255,0.7)] h-full">
|
<div className="pt-13 py-4 px-5 mb-2.5 rounded-tl-xl rounded-tr-xl bg-white h-full">
|
||||||
<div>Cloud Subscription</div>
|
<div>Cloud Subscription</div>
|
||||||
<div className="mb-4"><span className="text-3xl font-bold">$25</span><span className="text-gray-400"> /per month</span></div>
|
<div className="mb-2"><span className="text-[1.75rem] font-bold">$25</span><span className="text-gray-400"> /per month</span></div>
|
||||||
<div className="mb-9 text-gray-400">(beta pricing)</div>
|
<div className=" text-gray-400">(beta pricing)</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="bg-[rgba(255,255,255,0.7)] rounded-bl-xl rounded-br-xl h-full py-4 px-5">
|
<div className="bg-white rounded-bl-xl rounded-br-xl h-full py-4 px-5">
|
||||||
<div className="mb-1 invisible">Everything in the free plan, plus...</div>
|
<div className="mb-2 invisible">Everything in the free plan, plus...</div>
|
||||||
<div className="flex flex-col gap-3 mb-28">
|
<div className="flex flex-col gap-3 mb-28">
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />License to use Cognee open source</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />License to use Cognee open source</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Cognee tasks and pipelines</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Cognee tasks and pipelines</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Custom schema and ontology generation</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Custom schema and ontology generation</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Integrated evaluations</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Integrated evaluations</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />More than 28 data sources supported</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />More than 28 data sources supported</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="bg-[rgba(255,255,255,0.7)] rounded-bl-xl rounded-br-xl h-full py-4 px-5">
|
<div className="bg-white rounded-bl-xl rounded-br-xl border-indigo-600 border-1 border-t-0 h-full py-4 px-5">
|
||||||
<div className="mb-1 text-gray-400">Everything in the free plan, plus...</div>
|
<div className="mb-2 text-gray-400">Everything in the free plan, plus...</div>
|
||||||
<div className="flex flex-col gap-3 mb-10">
|
<div className="flex flex-col gap-3 mb-4">
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />License to use Cognee open source and Cognee Platform</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />License to use Cognee open source and Cognee Platform</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />1 day SLA</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />1 day SLA</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />On-prem deployment</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />On-prem deployment</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Hands-on support</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Hands-on support</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Architecture review</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Architecture review</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Roadmap prioritization</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Roadmap prioritization</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Knowledge transfer</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Knowledge transfer</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="bg-[rgba(255,255,255,0.7)] rounded-bl-xl rounded-br-xl h-full py-4 px-5">
|
<div className="bg-white rounded-bl-xl rounded-br-xl h-full py-4 px-5">
|
||||||
<div className="mb-1 text-gray-400">Everything in the free plan, plus...</div>
|
<div className="mb-2 text-gray-400">Everything in the free plan, plus...</div>
|
||||||
<div className="flex flex-col gap-3 mb-10">
|
<div className="flex flex-col gap-3 mb-4">
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Fully hosted cloud platform</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Fully hosted cloud platform</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Multi-tenant architecture</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Multi-tenant architecture</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Comprehensive API endpoints</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Comprehensive API endpoints</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Automated scaling and parallel processing</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Automated scaling and parallel processing</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Ability to group memories per user and domain</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Ability to group memories per user and domain</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />Automatic updates and priority support</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />Automatic updates and priority support</div>
|
||||||
<div className="flex flex-row gap-2"><CheckIcon className="mt-1 shrink-0" />1 GB ingestion + 10,000 API calls</div>
|
<div className="flex flex-row gap-2 leading-5"><CheckIcon className="mt-1 shrink-0" />1 GB ingestion + 10,000 API calls</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="pt-4 pb-14 mb-2.5">
|
<div className="pt-4 pb-14 mb-2.5">
|
||||||
<NeutralButton className="w-full">Try for free</NeutralButton>
|
<a href="https://www.github.com/topoteretes/cognee" target="_blank">
|
||||||
|
<NeutralButton className="w-full">Try for free</NeutralButton>
|
||||||
|
</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="pt-4 pb-14 mb-2.5">
|
<div className="pt-4 pb-14 mb-2.5">
|
||||||
<CTAButton className="w-full">Talk to us</CTAButton>
|
<a href="https://www.cognee.ai/contact-us" target="_blank">
|
||||||
|
<CTAButton className="w-full">Talk to us</CTAButton>
|
||||||
|
</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="pt-4 pb-14 mb-2.5">
|
<div className="pt-4 pb-14 mb-2.5">
|
||||||
<NeutralButton className="w-full">Sign up for Cogwit Beta</NeutralButton>
|
<a href="https://platform.cognee.ai" target="_blank">
|
||||||
|
<NeutralButton className="w-full">Sign up for Cogwit Beta</NeutralButton>
|
||||||
|
</a>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
@ -106,7 +127,7 @@ export default function Plan() {
|
||||||
<div className="text-center">On-prem</div>
|
<div className="text-center">On-prem</div>
|
||||||
<div className="text-center">Cloud</div>
|
<div className="text-center">Cloud</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="grid grid-cols-4 py-1 px-5 mb-12 bg-[rgba(255,255,255,0.7)] rounded-xl">
|
<div className="grid grid-cols-4 py-1 px-5 mb-12 bg-white rounded-xl leading-[1]">
|
||||||
<div className="border-b-[1px] border-b-gray-100 py-3">Data Sources</div>
|
<div className="border-b-[1px] border-b-gray-100 py-3">Data Sources</div>
|
||||||
<div className="text-center border-b-[1px] border-b-gray-100 py-3">28+</div>
|
<div className="text-center border-b-[1px] border-b-gray-100 py-3">28+</div>
|
||||||
<div className="text-center border-b-[1px] border-b-gray-100 py-3">28+</div>
|
<div className="text-center border-b-[1px] border-b-gray-100 py-3">28+</div>
|
||||||
|
|
@ -134,19 +155,19 @@ export default function Plan() {
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="grid grid-cols-2 gap-x-2.5 gap-y-2.5 mb-12">
|
<div className="grid grid-cols-2 gap-x-2.5 gap-y-2.5 mb-12">
|
||||||
<div className="bg-[rgba(255,255,255,0.5)] py-4 px-5 rounded-xl">
|
<div className="bg-white py-4 px-5 rounded-xl">
|
||||||
<div>Can I change my plan anytime?</div>
|
<div>Can I change my plan anytime?</div>
|
||||||
<div className="text-gray-500 mt-6">Yes, you can upgrade or downgrade your plan at any time. Changes take effect immediately.</div>
|
<div className="text-gray-500 mt-6">Yes, you can upgrade or downgrade your plan at any time. Changes take effect immediately.</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="bg-[rgba(255,255,255,0.5)] py-4 px-5 rounded-xl">
|
<div className="bg-white py-4 px-5 rounded-xl">
|
||||||
<div>What happens to my data if I downgrade?</div>
|
<div>What happens to my data if I downgrade?</div>
|
||||||
<div className="text-gray-500 mt-6">Your data is preserved, but features may be limited based on your new plan constraints.</div>
|
<div className="text-gray-500 mt-6">Your data is preserved, but features may be limited based on your new plan constraints.</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="bg-[rgba(255,255,255,0.5)] py-4 px-5 rounded-xl">
|
<div className="bg-white py-4 px-5 rounded-xl">
|
||||||
<div>Do you offer educational discounts?</div>
|
<div>Do you offer educational discounts?</div>
|
||||||
<div className="text-gray-500 mt-6">Yes, we offer special pricing for educational institutions and students. Contact us for details.</div>
|
<div className="text-gray-500 mt-6">Yes, we offer special pricing for educational institutions and students. Contact us for details.</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="bg-[rgba(255,255,255,0.5)] py-4 px-5 rounded-xl">
|
<div className="bg-white py-4 px-5 rounded-xl">
|
||||||
<div>Is there a free trial for paid plans?</div>
|
<div>Is there a free trial for paid plans?</div>
|
||||||
<div className="text-gray-500 mt-6">All new accounts start with a 14-day free trial of our Pro plan features.</div>
|
<div className="text-gray-500 mt-6">All new accounts start with a 14-day free trial of our Pro plan features.</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -32,20 +32,21 @@ export default function Header({ user }: HeaderProps) {
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<header className="relative bg-[rgba(244,244,244,0.3)] flex flex-row h-14 min-h-14 px-5 items-center justify-between w-full max-w-[1920px] mx-auto">
|
<header className="relative flex flex-row h-14 min-h-14 px-5 items-center justify-between w-full max-w-[1920px] mx-auto">
|
||||||
<div className="flex flex-row gap-4 items-center">
|
<div className="flex flex-row gap-4 items-center">
|
||||||
<CogneeIcon />
|
<CogneeIcon />
|
||||||
<div className="text-lg">Cognee Local</div>
|
<div className="text-lg">Cognee Local</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex flex-row items-center gap-2.5">
|
<div className="flex flex-row items-center gap-2.5">
|
||||||
<GhostButton onClick={openSyncModal} className="text-indigo-700 gap-3 pl-4 pr-4">
|
<GhostButton onClick={openSyncModal} className="text-indigo-600 gap-3 pl-4 pr-4">
|
||||||
<CloudIcon />
|
<CloudIcon />
|
||||||
<div>Sync</div>
|
<div>Sync</div>
|
||||||
</GhostButton>
|
</GhostButton>
|
||||||
<a href="/plan">
|
<a href="/plan" className="!text-indigo-600 pl-4 pr-4">
|
||||||
<GhostButton className="text-indigo-700 pl-4 pr-4">Premium</GhostButton>
|
Premium
|
||||||
</a>
|
</a>
|
||||||
|
<a href="https://platform.cognee.ai" className="!text-indigo-600 pl-4 pr-4">API keys</a>
|
||||||
{/* <div className="px-2 py-2 mr-3">
|
{/* <div className="px-2 py-2 mr-3">
|
||||||
<SettingsIcon />
|
<SettingsIcon />
|
||||||
</div> */}
|
</div> */}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import classNames from "classnames";
|
import classNames from "classnames";
|
||||||
import { ButtonHTMLAttributes } from "react";
|
import { ButtonHTMLAttributes } from "react";
|
||||||
|
|
||||||
export default function CTAButton({ children, className, ...props }: ButtonHTMLAttributes<HTMLButtonElement>) {
|
export default function GhostButton({ children, className, ...props }: ButtonHTMLAttributes<HTMLButtonElement>) {
|
||||||
return (
|
return (
|
||||||
<button className={classNames("flex flex-row justify-center items-center gap-2 cursor-pointer rounded-3xl bg-transparent px-10 h-8 text-black hover:bg-gray-200 focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-indigo-600", className)} {...props}>{children}</button>
|
<button className={classNames("flex flex-row justify-center items-center gap-2 cursor-pointer rounded-3xl bg-transparent px-10 h-8 text-black hover:bg-gray-200 focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-indigo-600", className)} {...props}>{children}</button>
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import { FormEvent, useCallback, useState } from "react";
|
import { FormEvent, useCallback, useState } from "react";
|
||||||
import { useBoolean } from "@/utils";
|
import { useBoolean } from "@/utils";
|
||||||
|
|
||||||
export default function useModal<ConfirmActionReturnType = void>(initiallyOpen?: boolean, confirmCallback?: (state: object, event?: FormEvent<HTMLFormElement>) => Promise<ConfirmActionReturnType> | ConfirmActionReturnType) {
|
export default function useModal<ModalState extends object, ConfirmActionEvent = FormEvent<HTMLFormElement>>(initiallyOpen?: boolean, confirmCallback?: (state: ModalState, event?: ConfirmActionEvent) => Promise<void> | void) {
|
||||||
const [modalState, setModalState] = useState<object>({});
|
const [modalState, setModalState] = useState<ModalState>();
|
||||||
const [isActionLoading, setLoading] = useState(false);
|
const [isActionLoading, setLoading] = useState(false);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
|
|
@ -11,7 +11,7 @@ export default function useModal<ConfirmActionReturnType = void>(initiallyOpen?:
|
||||||
setFalse: closeModalInternal,
|
setFalse: closeModalInternal,
|
||||||
} = useBoolean(initiallyOpen || false);
|
} = useBoolean(initiallyOpen || false);
|
||||||
|
|
||||||
const openModal = useCallback((state?: object) => {
|
const openModal = useCallback((state?: ModalState) => {
|
||||||
if (state) {
|
if (state) {
|
||||||
setModalState(state);
|
setModalState(state);
|
||||||
}
|
}
|
||||||
|
|
@ -20,20 +20,21 @@ export default function useModal<ConfirmActionReturnType = void>(initiallyOpen?:
|
||||||
|
|
||||||
const closeModal = useCallback(() => {
|
const closeModal = useCallback(() => {
|
||||||
closeModalInternal();
|
closeModalInternal();
|
||||||
setModalState({});
|
setModalState({} as ModalState);
|
||||||
}, [closeModalInternal]);
|
}, [closeModalInternal]);
|
||||||
|
|
||||||
const confirmAction = useCallback((event?: FormEvent<HTMLFormElement>) => {
|
const confirmAction = useCallback((event?: ConfirmActionEvent) => {
|
||||||
if (confirmCallback) {
|
if (confirmCallback) {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
|
|
||||||
const maybePromise = confirmCallback(modalState, event);
|
const maybePromise = confirmCallback(modalState as ModalState, event);
|
||||||
|
|
||||||
if (maybePromise instanceof Promise) {
|
if (maybePromise instanceof Promise) {
|
||||||
return maybePromise
|
return maybePromise
|
||||||
.finally(closeModal)
|
.finally(closeModal)
|
||||||
.finally(() => setLoading(false));
|
.finally(() => setLoading(false));
|
||||||
} else {
|
} else {
|
||||||
|
closeModal();
|
||||||
return maybePromise; // Not a promise.
|
return maybePromise; // Not a promise.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,11 @@
|
||||||
|
|
||||||
import { v4 as uuid4 } from "uuid";
|
import { v4 as uuid4 } from "uuid";
|
||||||
import classNames from "classnames";
|
import classNames from "classnames";
|
||||||
import { Fragment, MutableRefObject, useCallback, useEffect, useRef, useState } from "react";
|
import { Fragment, MouseEvent, MutableRefObject, useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
import { CaretIcon, PlusIcon } from "@/ui/Icons";
|
import { useModal } from "@/ui/elements/Modal";
|
||||||
import { IconButton, PopupMenu, TextArea } from "@/ui/elements";
|
import { CaretIcon, CloseIcon, PlusIcon } from "@/ui/Icons";
|
||||||
|
import { IconButton, PopupMenu, TextArea, Modal, GhostButton, CTAButton } from "@/ui/elements";
|
||||||
import { GraphControlsAPI } from "@/app/(graph)/GraphControls";
|
import { GraphControlsAPI } from "@/app/(graph)/GraphControls";
|
||||||
import GraphVisualization, { GraphVisualizationAPI } from "@/app/(graph)/GraphVisualization";
|
import GraphVisualization, { GraphVisualizationAPI } from "@/app/(graph)/GraphVisualization";
|
||||||
|
|
||||||
|
|
@ -60,13 +61,26 @@ export default function Notebook({ notebook, updateNotebook, runCell }: Notebook
|
||||||
updateNotebook(newNotebook);
|
updateNotebook(newNotebook);
|
||||||
}, [notebook, updateNotebook]);
|
}, [notebook, updateNotebook]);
|
||||||
|
|
||||||
const handleCellRemove = useCallback((cell: Cell) => {
|
const removeCell = useCallback((cell: Cell, event?: MouseEvent) => {
|
||||||
|
event?.preventDefault();
|
||||||
|
|
||||||
updateNotebook({
|
updateNotebook({
|
||||||
...notebook,
|
...notebook,
|
||||||
cells: notebook.cells.filter((c: Cell) => c.id !== cell.id),
|
cells: notebook.cells.filter((c: Cell) => c.id !== cell.id),
|
||||||
});
|
});
|
||||||
}, [notebook, updateNotebook]);
|
}, [notebook, updateNotebook]);
|
||||||
|
|
||||||
|
const {
|
||||||
|
isModalOpen: isRemoveCellConfirmModalOpen,
|
||||||
|
openModal: openCellRemoveConfirmModal,
|
||||||
|
closeModal: closeCellRemoveConfirmModal,
|
||||||
|
confirmAction: handleCellRemoveConfirm,
|
||||||
|
} = useModal<Cell, MouseEvent>(false, removeCell);
|
||||||
|
|
||||||
|
const handleCellRemove = useCallback((cell: Cell) => {
|
||||||
|
openCellRemoveConfirmModal(cell);
|
||||||
|
}, [openCellRemoveConfirmModal]);
|
||||||
|
|
||||||
const handleCellInputChange = useCallback((notebook: NotebookType, cell: Cell, value: string) => {
|
const handleCellInputChange = useCallback((notebook: NotebookType, cell: Cell, value: string) => {
|
||||||
const newCell = {...cell, content: value };
|
const newCell = {...cell, content: value };
|
||||||
|
|
||||||
|
|
@ -134,100 +148,133 @@ export default function Notebook({ notebook, updateNotebook, runCell }: Notebook
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="bg-white rounded-xl flex flex-col gap-0.5 px-7 py-5 flex-1">
|
<>
|
||||||
<div className="mb-5">{notebook.name}</div>
|
<div className="bg-white rounded-xl flex flex-col gap-0.5 px-7 py-5 flex-1">
|
||||||
|
<div className="mb-5">{notebook.name}</div>
|
||||||
|
|
||||||
{notebook.cells.map((cell: Cell, index) => (
|
{notebook.cells.map((cell: Cell, index) => (
|
||||||
<Fragment key={cell.id}>
|
<Fragment key={cell.id}>
|
||||||
<div key={cell.id} className="flex flex-row rounded-xl border-1 border-gray-100">
|
<div key={cell.id} className="flex flex-row rounded-xl border-1 border-gray-100">
|
||||||
<div className="flex flex-col flex-1 relative">
|
<div className="flex flex-col flex-1 relative">
|
||||||
{cell.type === "code" ? (
|
{cell.type === "code" ? (
|
||||||
<>
|
<>
|
||||||
<div className="absolute left-[-1.35rem] top-2.5">
|
<div className="absolute left-[-1.35rem] top-2.5">
|
||||||
<IconButton className="p-[0.25rem] m-[-0.25rem]" onClick={toggleCellOpen.bind(null, cell.id)}>
|
<IconButton className="p-[0.25rem] m-[-0.25rem]" onClick={toggleCellOpen.bind(null, cell.id)}>
|
||||||
<CaretIcon className={classNames("transition-transform", openCells.has(cell.id) ? "rotate-0" : "rotate-180")} />
|
<CaretIcon className={classNames("transition-transform", openCells.has(cell.id) ? "rotate-0" : "rotate-180")} />
|
||||||
</IconButton>
|
</IconButton>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<NotebookCellHeader
|
<NotebookCellHeader
|
||||||
cell={cell}
|
cell={cell}
|
||||||
runCell={handleCellRun}
|
runCell={handleCellRun}
|
||||||
renameCell={handleCellRename}
|
renameCell={handleCellRename}
|
||||||
removeCell={handleCellRemove}
|
removeCell={handleCellRemove}
|
||||||
moveCellUp={handleCellUp}
|
moveCellUp={handleCellUp}
|
||||||
moveCellDown={handleCellDown}
|
moveCellDown={handleCellDown}
|
||||||
className="rounded-tl-xl rounded-tr-xl"
|
className="rounded-tl-xl rounded-tr-xl"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{openCells.has(cell.id) && (
|
{openCells.has(cell.id) && (
|
||||||
<>
|
<>
|
||||||
|
<TextArea
|
||||||
|
value={cell.content}
|
||||||
|
onChange={handleCellInputChange.bind(null, notebook, cell)}
|
||||||
|
// onKeyUp={handleCellRunOnEnter}
|
||||||
|
isAutoExpanding
|
||||||
|
name="cellInput"
|
||||||
|
placeholder="Type your code here..."
|
||||||
|
contentEditable={true}
|
||||||
|
className="resize-none min-h-36 max-h-96 overflow-y-auto rounded-tl-none rounded-tr-none rounded-bl-xl rounded-br-xl border-0 !outline-0"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<div className="flex flex-col bg-gray-100 overflow-x-auto max-w-full">
|
||||||
|
{cell.result && (
|
||||||
|
<div className="px-2 py-2">
|
||||||
|
output: <CellResult content={cell.result} />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{!!cell.error?.length && (
|
||||||
|
<div className="px-2 py-2">
|
||||||
|
error: {cell.error}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<div className="absolute left-[-1.35rem] top-2.5">
|
||||||
|
<IconButton className="p-[0.25rem] m-[-0.25rem]" onClick={toggleCellOpen.bind(null, cell.id)}>
|
||||||
|
<CaretIcon className={classNames("transition-transform", openCells.has(cell.id) ? "rotate-0" : "rotate-180")} />
|
||||||
|
</IconButton>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<NotebookCellHeader
|
||||||
|
cell={cell}
|
||||||
|
renameCell={handleCellRename}
|
||||||
|
removeCell={handleCellRemove}
|
||||||
|
moveCellUp={handleCellUp}
|
||||||
|
moveCellDown={handleCellDown}
|
||||||
|
className="rounded-tl-xl rounded-tr-xl"
|
||||||
|
/>
|
||||||
|
|
||||||
|
{openCells.has(cell.id) && (
|
||||||
<TextArea
|
<TextArea
|
||||||
value={cell.content}
|
value={cell.content}
|
||||||
onChange={handleCellInputChange.bind(null, notebook, cell)}
|
onChange={handleCellInputChange.bind(null, notebook, cell)}
|
||||||
// onKeyUp={handleCellRunOnEnter}
|
// onKeyUp={handleCellRunOnEnter}
|
||||||
isAutoExpanding
|
isAutoExpanding
|
||||||
name="cellInput"
|
name="cellInput"
|
||||||
placeholder="Type your code here..."
|
placeholder="Type your text here..."
|
||||||
contentEditable={true}
|
contentEditable={true}
|
||||||
className="resize-none min-h-36 max-h-96 overflow-y-auto rounded-tl-none rounded-tr-none rounded-bl-xl rounded-br-xl border-0 !outline-0"
|
className="resize-none min-h-24 max-h-96 overflow-y-auto rounded-tl-none rounded-tr-none rounded-bl-xl rounded-br-xl border-0 !outline-0"
|
||||||
/>
|
/>
|
||||||
|
)}
|
||||||
<div className="flex flex-col bg-gray-100 overflow-x-auto max-w-full">
|
</>
|
||||||
{cell.result && (
|
)}
|
||||||
<div className="px-2 py-2">
|
</div>
|
||||||
output: <CellResult content={cell.result} />
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{cell.error && (
|
|
||||||
<div className="px-2 py-2">
|
|
||||||
error: {cell.error}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
openCells.has(cell.id) && (
|
|
||||||
<TextArea
|
|
||||||
value={cell.content}
|
|
||||||
onChange={handleCellInputChange.bind(null, notebook, cell)}
|
|
||||||
// onKeyUp={handleCellRunOnEnter}
|
|
||||||
isAutoExpanding
|
|
||||||
name="cellInput"
|
|
||||||
placeholder="Type your text here..."
|
|
||||||
contentEditable={true}
|
|
||||||
className="resize-none min-h-24 max-h-96 overflow-y-auto rounded-tl-none rounded-tr-none rounded-bl-xl rounded-br-xl border-0 !outline-0"
|
|
||||||
/>
|
|
||||||
)
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
<div className="ml-[-1.35rem]">
|
||||||
<div className="ml-[-1.35rem]">
|
<PopupMenu
|
||||||
<PopupMenu
|
openToRight={true}
|
||||||
openToRight={true}
|
triggerElement={<PlusIcon />}
|
||||||
triggerElement={<PlusIcon />}
|
triggerClassName="p-[0.25rem] m-[-0.25rem]"
|
||||||
triggerClassName="p-[0.25rem] m-[-0.25rem]"
|
>
|
||||||
>
|
<div className="flex flex-col gap-0.5">
|
||||||
<div className="flex flex-col gap-0.5">
|
<button
|
||||||
<button
|
onClick={() => handleCellAdd(index, "markdown")}
|
||||||
onClick={() => handleCellAdd(index, "markdown")}
|
className="hover:bg-gray-100 w-full text-left px-2 cursor-pointer"
|
||||||
|
>
|
||||||
|
<span>text</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
onClick={() => handleCellAdd(index, "code")}
|
||||||
className="hover:bg-gray-100 w-full text-left px-2 cursor-pointer"
|
className="hover:bg-gray-100 w-full text-left px-2 cursor-pointer"
|
||||||
>
|
>
|
||||||
<span>text</span>
|
<span>code</span>
|
||||||
</button>
|
</div>
|
||||||
</div>
|
</PopupMenu>
|
||||||
<div
|
</div>
|
||||||
onClick={() => handleCellAdd(index, "code")}
|
</Fragment>
|
||||||
className="hover:bg-gray-100 w-full text-left px-2 cursor-pointer"
|
))}
|
||||||
>
|
</div>
|
||||||
<span>code</span>
|
|
||||||
</div>
|
<Modal isOpen={isRemoveCellConfirmModalOpen}>
|
||||||
</PopupMenu>
|
<div className="w-full max-w-2xl">
|
||||||
|
<div className="flex flex-row items-center justify-between">
|
||||||
|
<span className="text-2xl">Delete notebook cell?</span>
|
||||||
|
<IconButton onClick={closeCellRemoveConfirmModal}><CloseIcon /></IconButton>
|
||||||
</div>
|
</div>
|
||||||
</Fragment>
|
<div className="mt-8 mb-6">Are you sure you want to delete a notebook cell? This action cannot be undone.</div>
|
||||||
))}
|
<div className="flex flex-row gap-4 mt-4 justify-end">
|
||||||
</div>
|
<GhostButton type="button" onClick={closeCellRemoveConfirmModal}>cancel</GhostButton>
|
||||||
|
<CTAButton onClick={handleCellRemoveConfirm} type="submit">delete</CTAButton>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Modal>
|
||||||
|
</>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -253,7 +300,7 @@ function CellResult({ content }: { content: [] }) {
|
||||||
data={transformInsightsGraphData(line)}
|
data={transformInsightsGraphData(line)}
|
||||||
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
|
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
|
||||||
graphControls={graphControls}
|
graphControls={graphControls}
|
||||||
className="min-h-48"
|
className="min-h-80"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import { Cell } from "./types";
|
||||||
|
|
||||||
interface NotebookCellHeaderProps {
|
interface NotebookCellHeaderProps {
|
||||||
cell: Cell;
|
cell: Cell;
|
||||||
runCell: (cell: Cell, cogneeInstance: string) => Promise<void>;
|
runCell?: (cell: Cell, cogneeInstance: string) => Promise<void>;
|
||||||
renameCell: (cell: Cell) => void;
|
renameCell: (cell: Cell) => void;
|
||||||
removeCell: (cell: Cell) => void;
|
removeCell: (cell: Cell) => void;
|
||||||
moveCellUp: (cell: Cell) => void;
|
moveCellUp: (cell: Cell) => void;
|
||||||
|
|
@ -36,28 +36,36 @@ export default function NotebookCellHeader({
|
||||||
const [runInstance, setRunInstance] = useState<string>(isCloudEnvironment() ? "cloud" : "local");
|
const [runInstance, setRunInstance] = useState<string>(isCloudEnvironment() ? "cloud" : "local");
|
||||||
|
|
||||||
const handleCellRun = () => {
|
const handleCellRun = () => {
|
||||||
setIsRunningCell();
|
if (runCell) {
|
||||||
runCell(cell, runInstance)
|
setIsRunningCell();
|
||||||
.then(() => {
|
runCell(cell, runInstance)
|
||||||
setIsNotRunningCell();
|
.then(() => {
|
||||||
});
|
setIsNotRunningCell();
|
||||||
|
});
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={classNames("flex flex-row justify-between items-center h-9 bg-gray-100", className)}>
|
<div className={classNames("flex flex-row justify-between items-center h-9 bg-gray-100", className)}>
|
||||||
<div className="flex flex-row items-center px-3.5">
|
<div className="flex flex-row items-center px-3.5">
|
||||||
{isRunningCell ? <LoadingIndicator /> : <IconButton onClick={handleCellRun}><PlayIcon /></IconButton>}
|
{runCell && (
|
||||||
|
<>
|
||||||
|
{isRunningCell ? <LoadingIndicator /> : <IconButton onClick={handleCellRun}><PlayIcon /></IconButton>}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
<span className="ml-4">{cell.name}</span>
|
<span className="ml-4">{cell.name}</span>
|
||||||
</div>
|
</div>
|
||||||
<div className="pr-4 flex flex-row items-center gap-8">
|
<div className="pr-4 flex flex-row items-center gap-8">
|
||||||
{isCloudEnvironment() ? (
|
{runCell && (
|
||||||
<div>
|
isCloudEnvironment() ? (
|
||||||
cloud cognee
|
<div>
|
||||||
</div>
|
cloud cognee
|
||||||
) : (
|
</div>
|
||||||
<div>
|
) : (
|
||||||
local cognee
|
<div>
|
||||||
</div>
|
local cognee
|
||||||
|
</div>
|
||||||
|
)
|
||||||
)}
|
)}
|
||||||
{/* <Select name="cogneeInstance" onChange={(event) => setRunInstance(event.currentTarget.value)} className="!bg-transparent outline-none cursor-pointer !hover:bg-gray-50">
|
{/* <Select name="cogneeInstance" onChange={(event) => setRunInstance(event.currentTarget.value)} className="!bg-transparent outline-none cursor-pointer !hover:bg-gray-50">
|
||||||
<option value="local" className="flex flex-row items-center gap-2">
|
<option value="local" className="flex flex-row items-center gap-2">
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,11 @@ let numberOfRetries = 0;
|
||||||
|
|
||||||
const isAuth0Enabled = process.env.USE_AUTH0_AUTHORIZATION?.toLowerCase() === "true";
|
const isAuth0Enabled = process.env.USE_AUTH0_AUTHORIZATION?.toLowerCase() === "true";
|
||||||
|
|
||||||
const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL || "http://localhost:8000/api";
|
const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL || "http://localhost:8000";
|
||||||
|
|
||||||
const cloudApiUrl = process.env.NEXT_PUBLIC_CLOUD_API_URL || "http://localhost:8001/api";
|
const cloudApiUrl = process.env.NEXT_PUBLIC_CLOUD_API_URL || "http://localhost:8001";
|
||||||
|
|
||||||
let apiKey: string | null = null;
|
let apiKey: string | null = process.env.NEXT_PUBLIC_COGWIT_API_KEY || null;
|
||||||
let accessToken: string | null = null;
|
let accessToken: string | null = null;
|
||||||
|
|
||||||
export default async function fetch(url: string, options: RequestInit = {}, useCloud = false): Promise<Response> {
|
export default async function fetch(url: string, options: RequestInit = {}, useCloud = false): Promise<Response> {
|
||||||
|
|
@ -30,26 +30,24 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const authHeaders = useCloud && (!isCloudEnvironment() || !accessToken) ? {
|
||||||
|
"X-Api-Key": apiKey,
|
||||||
|
} : {
|
||||||
|
"Authorization": `Bearer ${accessToken}`,
|
||||||
|
}
|
||||||
|
|
||||||
return global.fetch(
|
return global.fetch(
|
||||||
(useCloud ? cloudApiUrl : backendApiUrl) + (useCloud ? url.replace("/v1", "") : url),
|
(useCloud ? cloudApiUrl : backendApiUrl) + "/api" + (useCloud ? url.replace("/v1", "") : url),
|
||||||
{
|
{
|
||||||
...options,
|
...options,
|
||||||
headers: {
|
headers: {
|
||||||
...options.headers,
|
...options.headers,
|
||||||
...(useCloud && !isCloudEnvironment()
|
...authHeaders,
|
||||||
? {"X-Api-Key": apiKey!}
|
} as HeadersInit,
|
||||||
: {"Authorization": `Bearer ${accessToken}`}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
credentials: "include",
|
credentials: "include",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.then((response) => handleServerErrors(response, retry))
|
.then((response) => handleServerErrors(response, retry, useCloud))
|
||||||
.then((response) => {
|
|
||||||
numberOfRetries = 0;
|
|
||||||
|
|
||||||
return response;
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
if (error.detail === undefined) {
|
if (error.detail === undefined) {
|
||||||
return Promise.reject(
|
return Promise.reject(
|
||||||
|
|
@ -57,10 +55,10 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (error.status === 401) {
|
|
||||||
return retry(error);
|
|
||||||
}
|
|
||||||
return Promise.reject(error);
|
return Promise.reject(error);
|
||||||
|
})
|
||||||
|
.finally(() => {
|
||||||
|
numberOfRetries = 0;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import { redirect } from "next/navigation";
|
import { redirect } from "next/navigation";
|
||||||
|
|
||||||
export default function handleServerErrors(response: Response, retry?: (response: Response) => Promise<Response>): Promise<Response> {
|
export default function handleServerErrors(response: Response, retry?: (response: Response) => Promise<Response>, useCloud?: boolean): Promise<Response> {
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
if (response.status === 401) {
|
if (response.status === 401 && !useCloud) {
|
||||||
if (retry) {
|
if (retry) {
|
||||||
return retry(response)
|
return retry(response)
|
||||||
.catch(() => {
|
.catch(() => {
|
||||||
|
|
@ -13,7 +13,10 @@ export default function handleServerErrors(response: Response, retry?: (response
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
return response.json().then(error => reject(error));
|
return response.json().then(error => {
|
||||||
|
error.status = response.status;
|
||||||
|
reject(error);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (response.status >= 200 && response.status < 300) {
|
if (response.status >= 200 && response.status < 300) {
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ def get_checks_router():
|
||||||
api_token = request.headers.get("X-Api-Key")
|
api_token = request.headers.get("X-Api-Key")
|
||||||
|
|
||||||
if api_token is None:
|
if api_token is None:
|
||||||
return CloudApiKeyMissingError()
|
raise CloudApiKeyMissingError()
|
||||||
|
|
||||||
return await check_api_key(api_token)
|
return await check_api_key(api_token)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from pydantic import BaseModel
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.infrastructure.llm import get_max_chunk_tokens
|
from cognee.infrastructure.llm import get_max_chunk_tokens
|
||||||
|
|
@ -10,7 +11,11 @@ from cognee.infrastructure.llm import get_max_chunk_tokens
|
||||||
from cognee.modules.pipelines import run_pipeline
|
from cognee.modules.pipelines import run_pipeline
|
||||||
from cognee.modules.pipelines.tasks.task import Task
|
from cognee.modules.pipelines.tasks.task import Task
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
|
from cognee.modules.ontology.ontology_config import Config
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import (
|
||||||
|
get_default_ontology_resolver,
|
||||||
|
get_ontology_resolver_from_env,
|
||||||
|
)
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
from cognee.tasks.documents import (
|
from cognee.tasks.documents import (
|
||||||
|
|
@ -39,7 +44,7 @@ async def cognify(
|
||||||
graph_model: BaseModel = KnowledgeGraph,
|
graph_model: BaseModel = KnowledgeGraph,
|
||||||
chunker=TextChunker,
|
chunker=TextChunker,
|
||||||
chunk_size: int = None,
|
chunk_size: int = None,
|
||||||
ontology_file_path: Optional[str] = None,
|
config: Config = None,
|
||||||
vector_db_config: dict = None,
|
vector_db_config: dict = None,
|
||||||
graph_db_config: dict = None,
|
graph_db_config: dict = None,
|
||||||
run_in_background: bool = False,
|
run_in_background: bool = False,
|
||||||
|
|
@ -100,8 +105,6 @@ async def cognify(
|
||||||
Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
|
Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
|
||||||
Default limits: ~512-8192 tokens depending on models.
|
Default limits: ~512-8192 tokens depending on models.
|
||||||
Smaller chunks = more granular but potentially fragmented knowledge.
|
Smaller chunks = more granular but potentially fragmented knowledge.
|
||||||
ontology_file_path: Path to RDF/OWL ontology file for domain-specific entity types.
|
|
||||||
Useful for specialized fields like medical or legal documents.
|
|
||||||
vector_db_config: Custom vector database configuration for embeddings storage.
|
vector_db_config: Custom vector database configuration for embeddings storage.
|
||||||
graph_db_config: Custom graph database configuration for relationship storage.
|
graph_db_config: Custom graph database configuration for relationship storage.
|
||||||
run_in_background: If True, starts processing asynchronously and returns immediately.
|
run_in_background: If True, starts processing asynchronously and returns immediately.
|
||||||
|
|
@ -188,11 +191,28 @@ async def cognify(
|
||||||
- LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
|
- LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
|
||||||
- LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
|
- LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
|
||||||
"""
|
"""
|
||||||
|
if config is None:
|
||||||
|
ontology_config = get_ontology_env_config()
|
||||||
|
if (
|
||||||
|
ontology_config.ontology_file_path
|
||||||
|
and ontology_config.ontology_resolver
|
||||||
|
and ontology_config.matching_strategy
|
||||||
|
):
|
||||||
|
config: Config = {
|
||||||
|
"ontology_config": {
|
||||||
|
"ontology_resolver": get_ontology_resolver_from_env(**ontology_config.to_dict())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config: Config = {
|
||||||
|
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
|
||||||
|
}
|
||||||
|
|
||||||
if temporal_cognify:
|
if temporal_cognify:
|
||||||
tasks = await get_temporal_tasks(user, chunker, chunk_size)
|
tasks = await get_temporal_tasks(user, chunker, chunk_size)
|
||||||
else:
|
else:
|
||||||
tasks = await get_default_tasks(
|
tasks = await get_default_tasks(
|
||||||
user, graph_model, chunker, chunk_size, ontology_file_path, custom_prompt
|
user, graph_model, chunker, chunk_size, config, custom_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
||||||
|
|
@ -216,9 +236,26 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
graph_model: BaseModel = KnowledgeGraph,
|
graph_model: BaseModel = KnowledgeGraph,
|
||||||
chunker=TextChunker,
|
chunker=TextChunker,
|
||||||
chunk_size: int = None,
|
chunk_size: int = None,
|
||||||
ontology_file_path: Optional[str] = None,
|
config: Config = None,
|
||||||
custom_prompt: Optional[str] = None,
|
custom_prompt: Optional[str] = None,
|
||||||
) -> list[Task]:
|
) -> list[Task]:
|
||||||
|
if config is None:
|
||||||
|
ontology_config = get_ontology_env_config()
|
||||||
|
if (
|
||||||
|
ontology_config.ontology_file_path
|
||||||
|
and ontology_config.ontology_resolver
|
||||||
|
and ontology_config.matching_strategy
|
||||||
|
):
|
||||||
|
config: Config = {
|
||||||
|
"ontology_config": {
|
||||||
|
"ontology_resolver": get_ontology_resolver_from_env(**ontology_config.to_dict())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config: Config = {
|
||||||
|
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
|
||||||
|
}
|
||||||
|
|
||||||
default_tasks = [
|
default_tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
|
||||||
|
|
@ -230,7 +267,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
Task(
|
Task(
|
||||||
extract_graph_from_data,
|
extract_graph_from_data,
|
||||||
graph_model=graph_model,
|
graph_model=graph_model,
|
||||||
ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
|
config=config,
|
||||||
custom_prompt=custom_prompt,
|
custom_prompt=custom_prompt,
|
||||||
task_config={"batch_size": 10},
|
task_config={"batch_size": 10},
|
||||||
), # Generate knowledge graphs from the document chunks.
|
), # Generate knowledge graphs from the document chunks.
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi import APIRouter, WebSocket, Depends, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, Depends, WebSocketDisconnect
|
||||||
from starlette.status import WS_1000_NORMAL_CLOSURE, WS_1008_POLICY_VIOLATION
|
from starlette.status import WS_1000_NORMAL_CLOSURE, WS_1008_POLICY_VIOLATION
|
||||||
|
|
@ -119,7 +120,7 @@ def get_cognify_router() -> APIRouter:
|
||||||
|
|
||||||
# If any cognify run errored return JSONResponse with proper error status code
|
# If any cognify run errored return JSONResponse with proper error status code
|
||||||
if any(isinstance(v, PipelineRunErrored) for v in cognify_run.values()):
|
if any(isinstance(v, PipelineRunErrored) for v in cognify_run.values()):
|
||||||
return JSONResponse(status_code=420, content=cognify_run)
|
return JSONResponse(status_code=420, content=jsonable_encoder(cognify_run))
|
||||||
return cognify_run
|
return cognify_run
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,8 @@ class prune:
|
||||||
await _prune_data()
|
await _prune_data()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def prune_system(graph=True, vector=True, metadata=False):
|
async def prune_system(graph=True, vector=True, metadata=False, cache=True):
|
||||||
await _prune_system(graph, vector, metadata)
|
await _prune_system(graph, vector, metadata, cache)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ from cognee.modules.sync.methods import (
|
||||||
mark_sync_completed,
|
mark_sync_completed,
|
||||||
mark_sync_failed,
|
mark_sync_failed,
|
||||||
)
|
)
|
||||||
|
from cognee.shared.utils import create_secure_ssl_context
|
||||||
|
|
||||||
logger = get_logger("sync")
|
logger = get_logger("sync")
|
||||||
|
|
||||||
|
|
@ -583,7 +584,9 @@ async def _check_hashes_diff(
|
||||||
logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}")
|
logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
ssl_context = create_secure_ssl_context()
|
||||||
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
async with session.post(url, json=payload.dict(), headers=headers) as response:
|
async with session.post(url, json=payload.dict(), headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
|
|
@ -630,7 +633,9 @@ async def _download_missing_files(
|
||||||
|
|
||||||
headers = {"X-Api-Key": auth_token}
|
headers = {"X-Api-Key": auth_token}
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
ssl_context = create_secure_ssl_context()
|
||||||
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
for file_hash in hashes_missing_on_local:
|
for file_hash in hashes_missing_on_local:
|
||||||
try:
|
try:
|
||||||
# Download file from cloud by hash
|
# Download file from cloud by hash
|
||||||
|
|
@ -749,7 +754,9 @@ async def _upload_missing_files(
|
||||||
|
|
||||||
headers = {"X-Api-Key": auth_token}
|
headers = {"X-Api-Key": auth_token}
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
ssl_context = create_secure_ssl_context()
|
||||||
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
for file_info in files_to_upload:
|
for file_info in files_to_upload:
|
||||||
try:
|
try:
|
||||||
file_dir = os.path.dirname(file_info.raw_data_location)
|
file_dir = os.path.dirname(file_info.raw_data_location)
|
||||||
|
|
@ -809,7 +816,9 @@ async def _prune_cloud_dataset(
|
||||||
logger.info("Pruning cloud dataset to match local state")
|
logger.info("Pruning cloud dataset to match local state")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
ssl_context = create_secure_ssl_context()
|
||||||
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
async with session.put(url, json=payload.dict(), headers=headers) as response:
|
async with session.put(url, json=payload.dict(), headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
|
|
@ -852,7 +861,9 @@ async def _trigger_remote_cognify(
|
||||||
logger.info(f"Triggering cognify processing for dataset {dataset_id}")
|
logger.info(f"Triggering cognify processing for dataset {dataset_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
ssl_context = create_secure_ssl_context()
|
||||||
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
async with session.post(url, json=payload, headers=headers) as response:
|
async with session.post(url, json=payload, headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
|
|
|
||||||
|
|
@ -10,13 +10,30 @@ import pydantic
|
||||||
class BaseConfig(BaseSettings):
|
class BaseConfig(BaseSettings):
|
||||||
data_root_directory: str = get_absolute_path(".data_storage")
|
data_root_directory: str = get_absolute_path(".data_storage")
|
||||||
system_root_directory: str = get_absolute_path(".cognee_system")
|
system_root_directory: str = get_absolute_path(".cognee_system")
|
||||||
monitoring_tool: object = Observer.LANGFUSE
|
cache_root_directory: str = get_absolute_path(".cognee_cache")
|
||||||
|
monitoring_tool: object = Observer.NONE
|
||||||
|
|
||||||
@pydantic.model_validator(mode="after")
|
@pydantic.model_validator(mode="after")
|
||||||
def validate_paths(self):
|
def validate_paths(self):
|
||||||
|
# Adding this here temporarily to ensure that the cache root directory is set correctly for S3 storage automatically
|
||||||
|
# I'll remove this after we update documentation for S3 storage
|
||||||
|
# Auto-configure cache root directory for S3 storage if not explicitly set
|
||||||
|
storage_backend = os.getenv("STORAGE_BACKEND", "").lower()
|
||||||
|
cache_root_env = os.getenv("CACHE_ROOT_DIRECTORY")
|
||||||
|
|
||||||
|
if storage_backend == "s3" and not cache_root_env:
|
||||||
|
# Auto-generate S3 cache path when using S3 storage
|
||||||
|
bucket_name = os.getenv("STORAGE_BUCKET_NAME")
|
||||||
|
if bucket_name:
|
||||||
|
self.cache_root_directory = f"s3://{bucket_name}/cognee/cache"
|
||||||
|
|
||||||
# Require absolute paths for root directories
|
# Require absolute paths for root directories
|
||||||
self.data_root_directory = ensure_absolute_path(self.data_root_directory)
|
self.data_root_directory = ensure_absolute_path(self.data_root_directory)
|
||||||
self.system_root_directory = ensure_absolute_path(self.system_root_directory)
|
self.system_root_directory = ensure_absolute_path(self.system_root_directory)
|
||||||
|
# Set monitoring tool based on available keys
|
||||||
|
if self.langfuse_public_key and self.langfuse_secret_key:
|
||||||
|
self.monitoring_tool = Observer.LANGFUSE
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
|
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||||
|
|
@ -31,6 +48,7 @@ class BaseConfig(BaseSettings):
|
||||||
"data_root_directory": self.data_root_directory,
|
"data_root_directory": self.data_root_directory,
|
||||||
"system_root_directory": self.system_root_directory,
|
"system_root_directory": self.system_root_directory,
|
||||||
"monitoring_tool": self.monitoring_tool,
|
"monitoring_tool": self.monitoring_tool,
|
||||||
|
"cache_root_directory": self.cache_root_directory,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from cognee.tasks.graph import extract_graph_from_data
|
from cognee.tasks.graph import extract_graph_from_data
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
|
||||||
|
|
||||||
async def get_default_tasks_by_indices(
|
async def get_default_tasks_by_indices(
|
||||||
|
|
@ -33,7 +33,7 @@ async def get_no_summary_tasks(
|
||||||
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
|
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
|
||||||
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
|
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
|
||||||
|
|
||||||
ontology_adapter = OntologyResolver(ontology_file=ontology_file_path)
|
ontology_adapter = RDFLibOntologyResolver(ontology_file=ontology_file_path)
|
||||||
|
|
||||||
graph_task = Task(
|
graph_task = Task(
|
||||||
extract_graph_from_data,
|
extract_graph_from_data,
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import aiohttp
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter
|
from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter
|
||||||
|
from cognee.shared.utils import create_secure_ssl_context
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -42,7 +43,9 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
||||||
async def _get_session(self) -> aiohttp.ClientSession:
|
async def _get_session(self) -> aiohttp.ClientSession:
|
||||||
"""Get or create an aiohttp session."""
|
"""Get or create an aiohttp session."""
|
||||||
if self._session is None or self._session.closed:
|
if self._session is None or self._session.closed:
|
||||||
self._session = aiohttp.ClientSession()
|
ssl_context = create_secure_ssl_context()
|
||||||
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
self._session = aiohttp.ClientSession(connector=connector)
|
||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(AsyncAttrs, DeclarativeBase):
|
||||||
"""
|
"""
|
||||||
Represents a base class for declarative models using SQLAlchemy.
|
Represents a base class for declarative models using SQLAlchemy.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter im
|
||||||
embedding_rate_limit_async,
|
embedding_rate_limit_async,
|
||||||
embedding_sleep_and_retry_async,
|
embedding_sleep_and_retry_async,
|
||||||
)
|
)
|
||||||
|
from cognee.shared.utils import create_secure_ssl_context
|
||||||
|
|
||||||
logger = get_logger("OllamaEmbeddingEngine")
|
logger = get_logger("OllamaEmbeddingEngine")
|
||||||
|
|
||||||
|
|
@ -101,7 +102,9 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
ssl_context = create_secure_ssl_context()
|
||||||
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
self.endpoint, json=payload, headers=headers, timeout=60.0
|
self.endpoint, json=payload, headers=headers, timeout=60.0
|
||||||
) as response:
|
) as response:
|
||||||
|
|
|
||||||
|
|
@ -253,6 +253,56 @@ class LocalFileStorage(Storage):
|
||||||
if os.path.exists(full_file_path):
|
if os.path.exists(full_file_path):
|
||||||
os.remove(full_file_path)
|
os.remove(full_file_path)
|
||||||
|
|
||||||
|
def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
|
||||||
|
"""
|
||||||
|
List all files in the specified directory.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- directory_path (str): The directory path to list files from
|
||||||
|
- recursive (bool): If True, list files recursively in subdirectories
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- list[str]: List of file paths relative to the storage root
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
parsed_storage_path = get_parsed_path(self.storage_path)
|
||||||
|
|
||||||
|
if directory_path:
|
||||||
|
full_directory_path = os.path.join(parsed_storage_path, directory_path)
|
||||||
|
else:
|
||||||
|
full_directory_path = parsed_storage_path
|
||||||
|
|
||||||
|
directory_pathlib = Path(full_directory_path)
|
||||||
|
|
||||||
|
if not directory_pathlib.exists() or not directory_pathlib.is_dir():
|
||||||
|
return []
|
||||||
|
|
||||||
|
files = []
|
||||||
|
|
||||||
|
if recursive:
|
||||||
|
# Use rglob for recursive search
|
||||||
|
for file_path in directory_pathlib.rglob("*"):
|
||||||
|
if file_path.is_file():
|
||||||
|
# Get relative path from storage root
|
||||||
|
relative_path = os.path.relpath(str(file_path), parsed_storage_path)
|
||||||
|
# Normalize path separators for consistency
|
||||||
|
relative_path = relative_path.replace(os.sep, "/")
|
||||||
|
files.append(relative_path)
|
||||||
|
else:
|
||||||
|
# Use iterdir for just immediate directory
|
||||||
|
for file_path in directory_pathlib.iterdir():
|
||||||
|
if file_path.is_file():
|
||||||
|
# Get relative path from storage root
|
||||||
|
relative_path = os.path.relpath(str(file_path), parsed_storage_path)
|
||||||
|
# Normalize path separators for consistency
|
||||||
|
relative_path = relative_path.replace(os.sep, "/")
|
||||||
|
files.append(relative_path)
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
def remove_all(self, tree_path: str = None):
|
def remove_all(self, tree_path: str = None):
|
||||||
"""
|
"""
|
||||||
Remove an entire directory tree at the specified path, including all files and
|
Remove an entire directory tree at the specified path, including all files and
|
||||||
|
|
|
||||||
|
|
@ -155,21 +155,19 @@ class S3FileStorage(Storage):
|
||||||
"""
|
"""
|
||||||
Ensure that the specified directory exists, creating it if necessary.
|
Ensure that the specified directory exists, creating it if necessary.
|
||||||
|
|
||||||
If the directory already exists, no action is taken.
|
For S3 storage, this is a no-op since directories are created implicitly
|
||||||
|
when files are written to paths. S3 doesn't have actual directories,
|
||||||
|
just object keys with prefixes that appear as directories.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- directory_path (str): The path of the directory to check or create.
|
- directory_path (str): The path of the directory to check or create.
|
||||||
"""
|
"""
|
||||||
if not directory_path.strip():
|
# In S3, directories don't exist as separate entities - they're just prefixes
|
||||||
directory_path = self.storage_path.replace("s3://", "")
|
# When you write a file to s3://bucket/path/to/file.txt, the "directories"
|
||||||
|
# path/ and path/to/ are implicitly created. No explicit action needed.
|
||||||
def ensure_directory():
|
pass
|
||||||
if not self.s3.exists(directory_path):
|
|
||||||
self.s3.makedirs(directory_path, exist_ok=True)
|
|
||||||
|
|
||||||
await run_async(ensure_directory)
|
|
||||||
|
|
||||||
async def copy_file(self, source_file_path: str, destination_file_path: str):
|
async def copy_file(self, source_file_path: str, destination_file_path: str):
|
||||||
"""
|
"""
|
||||||
|
|
@ -213,6 +211,55 @@ class S3FileStorage(Storage):
|
||||||
|
|
||||||
await run_async(remove_file)
|
await run_async(remove_file)
|
||||||
|
|
||||||
|
async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
|
||||||
|
"""
|
||||||
|
List all files in the specified directory.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- directory_path (str): The directory path to list files from
|
||||||
|
- recursive (bool): If True, list files recursively in subdirectories
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- list[str]: List of file paths relative to the storage root
|
||||||
|
"""
|
||||||
|
|
||||||
|
def list_files_sync():
|
||||||
|
if directory_path:
|
||||||
|
# Combine storage path with directory path
|
||||||
|
full_path = os.path.join(self.storage_path.replace("s3://", ""), directory_path)
|
||||||
|
else:
|
||||||
|
full_path = self.storage_path.replace("s3://", "")
|
||||||
|
|
||||||
|
if recursive:
|
||||||
|
# Use ** for recursive search
|
||||||
|
pattern = f"{full_path}/**"
|
||||||
|
else:
|
||||||
|
# Just files in the immediate directory
|
||||||
|
pattern = f"{full_path}/*"
|
||||||
|
|
||||||
|
# Use s3fs glob to find files
|
||||||
|
try:
|
||||||
|
all_paths = self.s3.glob(pattern)
|
||||||
|
# Filter to only files (not directories)
|
||||||
|
files = [path for path in all_paths if self.s3.isfile(path)]
|
||||||
|
|
||||||
|
# Convert back to relative paths from storage root
|
||||||
|
storage_prefix = self.storage_path.replace("s3://", "")
|
||||||
|
relative_files = []
|
||||||
|
for file_path in files:
|
||||||
|
if file_path.startswith(storage_prefix):
|
||||||
|
relative_path = file_path[len(storage_prefix) :].lstrip("/")
|
||||||
|
relative_files.append(relative_path)
|
||||||
|
|
||||||
|
return relative_files
|
||||||
|
except Exception:
|
||||||
|
# If directory doesn't exist or other error, return empty list
|
||||||
|
return []
|
||||||
|
|
||||||
|
return await run_async(list_files_sync)
|
||||||
|
|
||||||
async def remove_all(self, tree_path: str):
|
async def remove_all(self, tree_path: str):
|
||||||
"""
|
"""
|
||||||
Remove an entire directory tree at the specified path, including all files and
|
Remove an entire directory tree at the specified path, including all files and
|
||||||
|
|
|
||||||
|
|
@ -135,6 +135,24 @@ class StorageManager:
|
||||||
else:
|
else:
|
||||||
return self.storage.remove(file_path)
|
return self.storage.remove(file_path)
|
||||||
|
|
||||||
|
async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
|
||||||
|
"""
|
||||||
|
List all files in the specified directory.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- directory_path (str): The directory path to list files from
|
||||||
|
- recursive (bool): If True, list files recursively in subdirectories
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- list[str]: List of file paths relative to the storage root
|
||||||
|
"""
|
||||||
|
if inspect.iscoroutinefunction(self.storage.list_files):
|
||||||
|
return await self.storage.list_files(directory_path, recursive)
|
||||||
|
else:
|
||||||
|
return self.storage.list_files(directory_path, recursive)
|
||||||
|
|
||||||
async def remove_all(self, tree_path: str = None):
|
async def remove_all(self, tree_path: str = None):
|
||||||
"""
|
"""
|
||||||
Remove an entire directory tree at the specified path, including all files and
|
Remove an entire directory tree at the specified path, including all files and
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,12 @@ async def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
||||||
file_type = guess_file_type(file)
|
file_type = guess_file_type(file)
|
||||||
|
|
||||||
file_path = getattr(file, "name", None) or getattr(file, "full_name", None)
|
file_path = getattr(file, "name", None) or getattr(file, "full_name", None)
|
||||||
file_name = Path(file_path).stem if file_path else None
|
|
||||||
|
if isinstance(file_path, str):
|
||||||
|
file_name = Path(file_path).stem if file_path else None
|
||||||
|
else:
|
||||||
|
# In case file_path does not exist or is a integer return None
|
||||||
|
file_name = None
|
||||||
|
|
||||||
# Get file size
|
# Get file size
|
||||||
pos = file.tell() # remember current pointer
|
pos = file.tell() # remember current pointer
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from cognee.modules.cloud.exceptions import CloudConnectionError
|
from cognee.modules.cloud.exceptions import CloudConnectionError
|
||||||
|
from cognee.shared.utils import create_secure_ssl_context
|
||||||
|
|
||||||
|
|
||||||
async def check_api_key(auth_token: str):
|
async def check_api_key(auth_token: str):
|
||||||
|
|
@ -10,7 +11,9 @@ async def check_api_key(auth_token: str):
|
||||||
headers = {"X-Api-Key": auth_token}
|
headers = {"X-Api-Key": auth_token}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
ssl_context = create_secure_ssl_context()
|
||||||
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
async with session.post(url, headers=headers) as response:
|
async with session.post(url, headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from cognee.shared.cache import delete_cache
|
||||||
|
|
||||||
|
|
||||||
async def prune_system(graph=True, vector=True, metadata=True):
|
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
|
||||||
if graph:
|
if graph:
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
await graph_engine.delete_graph()
|
await graph_engine.delete_graph()
|
||||||
|
|
@ -15,3 +16,6 @@ async def prune_system(graph=True, vector=True, metadata=True):
|
||||||
if metadata:
|
if metadata:
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
await db_engine.delete_database()
|
await db_engine.delete_database()
|
||||||
|
|
||||||
|
if cache:
|
||||||
|
await delete_cache()
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,14 @@ from cognee.modules.engine.utils import (
|
||||||
generate_node_id,
|
generate_node_id,
|
||||||
generate_node_name,
|
generate_node_name,
|
||||||
)
|
)
|
||||||
|
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
|
||||||
|
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import (
|
||||||
|
get_default_ontology_resolver,
|
||||||
|
get_ontology_resolver_from_env,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _create_node_key(node_id: str, category: str) -> str:
|
def _create_node_key(node_id: str, category: str) -> str:
|
||||||
|
|
@ -83,7 +89,7 @@ def _process_ontology_edges(
|
||||||
|
|
||||||
def _create_type_node(
|
def _create_type_node(
|
||||||
node_type: str,
|
node_type: str,
|
||||||
ontology_resolver: OntologyResolver,
|
ontology_resolver: RDFLibOntologyResolver,
|
||||||
added_nodes_map: dict,
|
added_nodes_map: dict,
|
||||||
added_ontology_nodes_map: dict,
|
added_ontology_nodes_map: dict,
|
||||||
name_mapping: dict,
|
name_mapping: dict,
|
||||||
|
|
@ -141,7 +147,7 @@ def _create_entity_node(
|
||||||
node_name: str,
|
node_name: str,
|
||||||
node_description: str,
|
node_description: str,
|
||||||
type_node: EntityType,
|
type_node: EntityType,
|
||||||
ontology_resolver: OntologyResolver,
|
ontology_resolver: RDFLibOntologyResolver,
|
||||||
added_nodes_map: dict,
|
added_nodes_map: dict,
|
||||||
added_ontology_nodes_map: dict,
|
added_ontology_nodes_map: dict,
|
||||||
name_mapping: dict,
|
name_mapping: dict,
|
||||||
|
|
@ -198,7 +204,7 @@ def _create_entity_node(
|
||||||
def _process_graph_nodes(
|
def _process_graph_nodes(
|
||||||
data_chunk: DocumentChunk,
|
data_chunk: DocumentChunk,
|
||||||
graph: KnowledgeGraph,
|
graph: KnowledgeGraph,
|
||||||
ontology_resolver: OntologyResolver,
|
ontology_resolver: RDFLibOntologyResolver,
|
||||||
added_nodes_map: dict,
|
added_nodes_map: dict,
|
||||||
added_ontology_nodes_map: dict,
|
added_ontology_nodes_map: dict,
|
||||||
name_mapping: dict,
|
name_mapping: dict,
|
||||||
|
|
@ -277,7 +283,7 @@ def _process_graph_edges(
|
||||||
def expand_with_nodes_and_edges(
|
def expand_with_nodes_and_edges(
|
||||||
data_chunks: list[DocumentChunk],
|
data_chunks: list[DocumentChunk],
|
||||||
chunk_graphs: list[KnowledgeGraph],
|
chunk_graphs: list[KnowledgeGraph],
|
||||||
ontology_resolver: OntologyResolver = None,
|
ontology_resolver: BaseOntologyResolver = None,
|
||||||
existing_edges_map: Optional[dict[str, bool]] = None,
|
existing_edges_map: Optional[dict[str, bool]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -296,8 +302,8 @@ def expand_with_nodes_and_edges(
|
||||||
chunk_graphs (list[KnowledgeGraph]): List of knowledge graphs corresponding to each
|
chunk_graphs (list[KnowledgeGraph]): List of knowledge graphs corresponding to each
|
||||||
data chunk. Each graph contains nodes (entities) and edges (relationships) extracted
|
data chunk. Each graph contains nodes (entities) and edges (relationships) extracted
|
||||||
from the chunk content.
|
from the chunk content.
|
||||||
ontology_resolver (OntologyResolver, optional): Resolver for validating entities and
|
ontology_resolver (BaseOntologyResolver, optional): Resolver for validating entities and
|
||||||
types against an ontology. If None, a default OntologyResolver is created.
|
types against an ontology. If None, a default RDFLibOntologyResolver is created.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
existing_edges_map (dict[str, bool], optional): Mapping of existing edge keys to prevent
|
existing_edges_map (dict[str, bool], optional): Mapping of existing edge keys to prevent
|
||||||
duplicate edge creation. Keys are formatted as "{source_id}_{target_id}_{relation}".
|
duplicate edge creation. Keys are formatted as "{source_id}_{target_id}_{relation}".
|
||||||
|
|
@ -320,7 +326,15 @@ def expand_with_nodes_and_edges(
|
||||||
existing_edges_map = {}
|
existing_edges_map = {}
|
||||||
|
|
||||||
if ontology_resolver is None:
|
if ontology_resolver is None:
|
||||||
ontology_resolver = OntologyResolver()
|
ontology_config = get_ontology_env_config()
|
||||||
|
if (
|
||||||
|
ontology_config.ontology_file_path
|
||||||
|
and ontology_config.ontology_resolver
|
||||||
|
and ontology_config.matching_strategy
|
||||||
|
):
|
||||||
|
ontology_resolver = get_ontology_resolver_from_env(**ontology_config.to_dict())
|
||||||
|
else:
|
||||||
|
ontology_resolver = get_default_ontology_resolver()
|
||||||
|
|
||||||
added_nodes_map = {}
|
added_nodes_map = {}
|
||||||
added_ontology_nodes_map = {}
|
added_ontology_nodes_map = {}
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,40 @@ from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
|
||||||
from ..models.Notebook import Notebook, NotebookCell
|
from ..models.Notebook import Notebook, NotebookCell
|
||||||
|
|
||||||
|
TUTORIAL_NOTEBOOK_NAME = "Python Development with Cognee Tutorial 🧠"
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_tutorial_notebook(
|
||||||
|
user_id: UUID, session: AsyncSession, force_refresh: bool = False
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create the default tutorial notebook for new users.
|
||||||
|
Dynamically fetches from: https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/starter_tutorial.zip
|
||||||
|
"""
|
||||||
|
TUTORIAL_ZIP_URL = (
|
||||||
|
"https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create notebook from remote zip file (includes notebook + data files)
|
||||||
|
notebook = await Notebook.from_ipynb_zip_url(
|
||||||
|
zip_url=TUTORIAL_ZIP_URL,
|
||||||
|
owner_id=user_id,
|
||||||
|
notebook_filename="tutorial.ipynb",
|
||||||
|
name=TUTORIAL_NOTEBOOK_NAME,
|
||||||
|
deletable=False,
|
||||||
|
force=force_refresh,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to session and commit
|
||||||
|
session.add(notebook)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to fetch tutorial notebook from {TUTORIAL_ZIP_URL}: {e}")
|
||||||
|
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
@with_async_session
|
@with_async_session
|
||||||
async def create_notebook(
|
async def create_notebook(
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,16 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List
|
from typing import List
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select, and_
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import with_async_session
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
|
||||||
from ..models.Notebook import Notebook
|
from ..models.Notebook import Notebook
|
||||||
|
from .create_notebook import _create_tutorial_notebook, TUTORIAL_NOTEBOOK_NAME
|
||||||
|
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
@with_async_session
async def get_notebooks(
    user_id: UUID,
    session: AsyncSession,
) -> List[Notebook]:
    """
    Return all notebooks owned by the user, lazily creating the tutorial notebook.

    If the non-deletable tutorial notebook is missing for this user it is
    created first (best-effort: creation failures are logged, not raised),
    then every notebook belonging to the user is returned.

    Args:
        user_id: Owner whose notebooks are listed.
        session: Async database session (injected by @with_async_session).

    Returns:
        List of the user's Notebook rows, including the tutorial notebook
        when it exists or was just created successfully.
    """
    # Check if tutorial notebook already exists for this user.
    # The tutorial is identified by owner + fixed name + deletable == False.
    tutorial_query = select(Notebook).where(
        and_(
            Notebook.owner_id == user_id,
            Notebook.name == TUTORIAL_NOTEBOOK_NAME,
            ~Notebook.deletable,
        )
    )
    tutorial_result = await session.execute(tutorial_query)
    tutorial_notebook = tutorial_result.scalar_one_or_none()

    # If tutorial notebook doesn't exist, create it
    if tutorial_notebook is None:
        logger.info(f"Tutorial notebook not found for user {user_id}, creating it")
        try:
            await _create_tutorial_notebook(user_id, session, force_refresh=False)
        except Exception as e:
            # Log the error but continue to return existing notebooks
            logger.error(f"Failed to create tutorial notebook for user {user_id}: {e}")

    # Get all notebooks for the user
    result = await session.execute(select(Notebook).where(Notebook.owner_id == user_id))

    return list(result.scalars().all())
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,24 @@
|
||||||
import json
|
import json
|
||||||
from typing import List, Literal
|
import nbformat
|
||||||
|
import asyncio
|
||||||
|
from nbformat.notebooknode import NotebookNode
|
||||||
|
from typing import List, Literal, Optional, cast, Tuple
|
||||||
from uuid import uuid4, UUID as UUID_t
|
from uuid import uuid4, UUID as UUID_t
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from sqlalchemy import Boolean, Column, DateTime, JSON, UUID, String, TypeDecorator
|
from sqlalchemy import Boolean, Column, DateTime, JSON, UUID, String, TypeDecorator
|
||||||
from sqlalchemy.orm import mapped_column, Mapped
|
from sqlalchemy.orm import mapped_column, Mapped
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import Base
|
from cognee.infrastructure.databases.relational import Base
|
||||||
|
from cognee.shared.cache import (
|
||||||
|
download_and_extract_zip,
|
||||||
|
get_tutorial_data_dir,
|
||||||
|
generate_content_hash,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
|
||||||
|
|
||||||
class NotebookCell(BaseModel):
|
class NotebookCell(BaseModel):
|
||||||
|
|
@ -51,3 +62,197 @@ class Notebook(Base):
|
||||||
deletable: Mapped[bool] = mapped_column(Boolean, default=True)
|
deletable: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
|
||||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
    @classmethod
    async def from_ipynb_zip_url(
        cls,
        zip_url: str,
        owner_id: UUID_t,
        notebook_filename: str = "tutorial.ipynb",
        name: Optional[str] = None,
        deletable: bool = True,
        force: bool = False,
    ) -> "Notebook":
        """
        Create a Notebook instance from a remote zip file containing notebook + data files.

        Args:
            zip_url: Remote URL to fetch the .zip file from
            owner_id: UUID of the notebook owner
            notebook_filename: Name of the .ipynb file within the zip
            name: Optional custom name for the notebook
            deletable: Whether the notebook can be deleted
            force: If True, re-download even if already cached

        Returns:
            Notebook instance

        Raises:
            RuntimeError: If downloading or extracting the zip fails.
            FileNotFoundError: If `notebook_filename` is not present in the zip.
        """
        # Generate a cache key based on the zip URL
        content_hash = generate_content_hash(zip_url, notebook_filename)

        # Download and extract the zip file to tutorial_data/{content_hash}
        try:
            extracted_cache_dir = await download_and_extract_zip(
                url=zip_url,
                cache_dir_name=f"tutorial_data/{content_hash}",
                version_or_hash=content_hash,
                force=force,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to download tutorial zip from {zip_url}") from e

        # Use cache system to access the notebook file (local import avoids
        # paying for the cache module unless this path is exercised)
        from cognee.shared.cache import cache_file_exists, read_cache_file

        notebook_file_path = f"{extracted_cache_dir}/{notebook_filename}"

        # Check if the notebook file exists in cache
        if not await cache_file_exists(notebook_file_path):
            raise FileNotFoundError(f"Notebook file '{notebook_filename}' not found in zip")

        # Read and parse the notebook using cache system; the blocking read is
        # pushed onto a thread so the event loop stays responsive.
        async with await read_cache_file(notebook_file_path, encoding="utf-8") as f:
            notebook_content = await asyncio.to_thread(f.read)
        notebook = cls.from_ipynb_string(notebook_content, owner_id, name, deletable)

        # Update file paths in notebook cells to point to actual cached data files
        await cls._update_file_paths_in_cells(notebook, extracted_cache_dir)

        return notebook
||||||
|
|
||||||
|
    @staticmethod
    async def _update_file_paths_in_cells(notebook: "Notebook", cache_dir: str) -> None:
        """
        Update file paths in code cells to use actual cached data files.

        Rewrites string literals of the form "file://data/<name>" inside code
        cells so they point at the extracted cache location. Works with both
        local filesystem and S3 storage.

        Args:
            notebook: Parsed Notebook instance with cells to update
            cache_dir: Path to the cached tutorial directory containing data files
        """
        import re
        from cognee.shared.cache import list_cache_files, cache_file_exists
        from cognee.shared.logging_utils import get_logger

        logger = get_logger()

        # Look for data files in the data subdirectory
        data_dir = f"{cache_dir}/data"

        try:
            # Get all data files in the cache directory using cache system;
            # maps bare filename -> full cached path.
            data_files = {}
            if await cache_file_exists(data_dir):
                file_list = await list_cache_files(data_dir)
            else:
                file_list = []

            for file_path in file_list:
                # Extract just the filename
                filename = file_path.split("/")[-1]
                # Use the file path as provided by cache system
                data_files[filename] = file_path

        except Exception as e:
            # If we can't list files, skip updating paths (best-effort:
            # the notebook is still usable, just with stale paths).
            logger.error(f"Error listing data files in {data_dir}: {e}")
            return

        # Pattern to match quoted file://data/filename literals in code cells
        file_pattern = r'"file://data/([^"]+)"'

        def replace_path(match):
            # re.sub callback: swap the placeholder for the cached location.
            filename = match.group(1)
            if filename in data_files:
                file_path = data_files[filename]
                # For local filesystem, preserve file:// prefix
                if not file_path.startswith("s3://"):
                    return f'"file://{file_path}"'
                else:
                    # For S3, return the S3 URL as-is
                    return f'"{file_path}"'
            return match.group(0)  # Keep original if file not found

        # Update only code cells
        updated_cells = 0
        for cell in notebook.cells:
            if cell.type == "code":
                original_content = cell.content
                # Update file paths in the cell content
                cell.content = re.sub(file_pattern, replace_path, cell.content)
                if original_content != cell.content:
                    updated_cells += 1

        # Log summary of updates (useful for monitoring)
        if updated_cells > 0:
            logger.info(f"Updated file paths in {updated_cells} notebook cells")
||||||
|
|
||||||
|
    @classmethod
    def from_ipynb_string(
        cls,
        notebook_content: str,
        owner_id: UUID_t,
        name: Optional[str] = None,
        deletable: bool = True,
    ) -> "Notebook":
        """
        Create a Notebook instance from Jupyter notebook string content.

        Args:
            notebook_content: Raw Jupyter notebook content as string
            owner_id: UUID of the notebook owner
            name: Optional custom name for the notebook; when None, the name
                is taken from the notebook's kernelspec metadata.
            deletable: Whether the notebook can be deleted

        Returns:
            Notebook instance ready to be saved to database
        """
        # Parse and validate the Jupyter notebook using nbformat
        # Note: nbformat.reads() has loose typing, so we cast to NotebookNode
        jupyter_nb = cast(
            NotebookNode, nbformat.reads(notebook_content, as_version=nbformat.NO_CONVERT)
        )

        # Convert Jupyter cells to NotebookCell objects
        cells = []
        for jupyter_cell in jupyter_nb.cells:
            # Each cell is also a NotebookNode with dynamic attributes
            cell = cast(NotebookNode, jupyter_cell)
            # Skip raw cells as they're not supported in our model
            if cell.cell_type == "raw":
                continue

            # Get the source content
            content = cell.source

            # Generate a name based on content or cell index
            cell_name = cls._generate_cell_name(cell)

            # Map cell types (jupyter uses "code"/"markdown", we use same)
            cell_type = "code" if cell.cell_type == "code" else "markdown"

            cells.append(NotebookCell(id=uuid4(), type=cell_type, name=cell_name, content=content))

        # Extract notebook name from metadata if not provided
        if name is None:
            kernelspec = jupyter_nb.metadata.get("kernelspec", {})
            name = kernelspec.get("display_name") or kernelspec.get("name", "Imported Notebook")

        return cls(id=uuid4(), owner_id=owner_id, name=name, cells=cells, deletable=deletable)
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_cell_name(jupyter_cell: NotebookNode) -> str:
|
||||||
|
"""Generate a meaningful name for a notebook cell using nbformat cell."""
|
||||||
|
if jupyter_cell.cell_type == "markdown":
|
||||||
|
# Try to extract a title from markdown headers
|
||||||
|
content = jupyter_cell.source
|
||||||
|
|
||||||
|
lines = content.strip().split("\n")
|
||||||
|
if lines and lines[0].startswith("#"):
|
||||||
|
# Extract header text, clean it up
|
||||||
|
header = lines[0].lstrip("#").strip()
|
||||||
|
return header[:50] if len(header) > 50 else header
|
||||||
|
else:
|
||||||
|
return "Markdown Cell"
|
||||||
|
else:
|
||||||
|
return "Code Cell"
|
||||||
|
|
|
||||||
|
|
@ -9,3 +9,17 @@ def get_observe():
|
||||||
from langfuse.decorators import observe
|
from langfuse.decorators import observe
|
||||||
|
|
||||||
return observe
|
return observe
|
||||||
|
elif monitoring == Observer.NONE:
|
||||||
|
# Return a no-op decorator that handles keyword arguments
|
||||||
|
def no_op_decorator(*args, **kwargs):
|
||||||
|
if len(args) == 1 and callable(args[0]) and not kwargs:
|
||||||
|
# Direct decoration: @observe
|
||||||
|
return args[0]
|
||||||
|
else:
|
||||||
|
# Parameterized decoration: @observe(as_type="generation")
|
||||||
|
def decorator(func):
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
return no_op_decorator
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from enum import Enum
|
||||||
class Observer(str, Enum):
    """Monitoring tools"""

    # Disables observability instrumentation entirely.
    NONE = "none"
    LANGFUSE = "langfuse"
    LLMLITE = "llmlite"
    LANGSMITH = "langsmith"
|
||||||
|
|
|
||||||
42
cognee/modules/ontology/base_ontology_resolver.py
Normal file
42
cognee/modules/ontology/base_ontology_resolver.py
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
from cognee.modules.ontology.models import AttachedOntologyNode
|
||||||
|
from cognee.modules.ontology.matching_strategies import MatchingStrategy, FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOntologyResolver(ABC):
    """Abstract base class for ontology resolvers.

    Concrete resolvers parse an ontology source, maintain a lookup of its
    entities, and answer fuzzy-match / subgraph queries against it.
    """

    def __init__(self, matching_strategy: Optional[MatchingStrategy] = None):
        """Initialize the ontology resolver with a matching strategy.

        Args:
            matching_strategy: The strategy to use for entity matching.
                Defaults to FuzzyMatchingStrategy if None.
        """
        self.matching_strategy = matching_strategy or FuzzyMatchingStrategy()

    @abstractmethod
    def build_lookup(self) -> None:
        """Build the lookup dictionary for ontology entities."""
        pass

    @abstractmethod
    def refresh_lookup(self) -> None:
        """Refresh (rebuild) the lookup dictionary."""
        pass

    @abstractmethod
    def find_closest_match(self, name: str, category: str) -> Optional[str]:
        """Find the closest match for a given name in the specified category."""
        pass

    @abstractmethod
    def get_subgraph(
        self, node_name: str, node_type: str = "individuals", directed: bool = True
    ) -> Tuple[
        List[AttachedOntologyNode], List[Tuple[str, str, str]], Optional[AttachedOntologyNode]
    ]:
        """Get a subgraph for the given node.

        Returns a (nodes, edges, matched_node) triple; edges are
        (source, relation, target) string tuples.
        """
        pass
||||||
41
cognee/modules/ontology/get_default_ontology_resolver.py
Normal file
41
cognee/modules/ontology/get_default_ontology_resolver.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
|
||||||
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_ontology_resolver() -> BaseOntologyResolver:
    """Return the default resolver: RDFLib-backed, fuzzy matching, no ontology file."""
    return RDFLibOntologyResolver(ontology_file=None, matching_strategy=FuzzyMatchingStrategy())
||||||
|
|
||||||
|
|
||||||
|
def get_ontology_resolver_from_env(
    ontology_resolver: str = "", matching_strategy: str = "", ontology_file_path: str = ""
) -> "BaseOntologyResolver":
    """
    Create and return an ontology resolver instance based on environment parameters.

    Currently, this function supports only the RDFLib-based ontology resolver
    with a fuzzy matching strategy.

    Args:
        ontology_resolver (str): The ontology resolver type to use.
            Supported value: "rdflib".
        matching_strategy (str): The matching strategy to apply.
            Supported value: "fuzzy".
        ontology_file_path (str): Path to the ontology file required for the resolver.

    Returns:
        BaseOntologyResolver: An instance of the requested ontology resolver.

    Raises:
        EnvironmentError: If the provided resolver or strategy is unsupported,
            or if required parameters are missing.
    """
    if ontology_resolver == "rdflib" and matching_strategy == "fuzzy" and ontology_file_path:
        return RDFLibOntologyResolver(
            matching_strategy=FuzzyMatchingStrategy(), ontology_file=ontology_file_path
        )

    # Name every unmet requirement so a misconfiguration is diagnosable from
    # the message alone (previously only the resolver name was echoed, even
    # when the actual problem was the strategy or a missing file path).
    problems = []
    if ontology_resolver != "rdflib":
        problems.append(f"unsupported ontology resolver: {ontology_resolver!r}")
    if matching_strategy != "fuzzy":
        problems.append(f"unsupported matching strategy: {matching_strategy!r}")
    if not ontology_file_path:
        problems.append("ontology_file_path is not set")
    raise EnvironmentError(
        f"Cannot build ontology resolver ({'; '.join(problems)}). "
        "Supported configuration: resolver 'rdflib' with matching strategy 'fuzzy' "
        "and a non-empty ontology file path."
    )
|
||||||
53
cognee/modules/ontology/matching_strategies.py
Normal file
53
cognee/modules/ontology/matching_strategies.py
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
import difflib
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class MatchingStrategy(ABC):
    """Interface for ontology entity matching strategies."""

    @abstractmethod
    def find_match(self, name: str, candidates: List[str]) -> Optional[str]:
        """Return the best candidate for *name*, or None when nothing matches.

        Args:
            name: The name to match.
            candidates: Candidate names to match against.

        Returns:
            The best matching candidate name, or None if no match found.
        """


class FuzzyMatchingStrategy(MatchingStrategy):
    """Approximate string matching backed by difflib.

    An exact hit always wins; otherwise the single closest candidate whose
    similarity ratio reaches the configured cutoff is returned.
    """

    def __init__(self, cutoff: float = 0.8):
        """Store the minimum similarity score (0.0 to 1.0) a match must reach."""
        self.cutoff = cutoff

    def find_match(self, name: str, candidates: List[str]) -> Optional[str]:
        """Return the closest fuzzy match for *name*, or None below the cutoff.

        Args:
            name: The normalized name to match.
            candidates: List of normalized candidate names.

        Returns:
            The best matching candidate name, or None if no match meets the cutoff.
        """
        if not candidates:
            return None

        # Exact matches short-circuit the fuzzy search.
        if name in candidates:
            return name

        close = difflib.get_close_matches(name, candidates, n=1, cutoff=self.cutoff)
        return close[0] if close else None
|
||||||
20
cognee/modules/ontology/models.py
Normal file
20
cognee/modules/ontology/models.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class AttachedOntologyNode:
    """Minimal adapter that normalizes nodes from any ontology backend.

    Wraps an opaque URI-like value with a short human-readable name and the
    ontology category it belongs to, giving cognee a uniform node interface.
    """

    def __init__(self, uri: Any, category: str):
        self.uri = uri
        self.name = self._extract_name(uri)
        self.category = category

    @staticmethod
    def _extract_name(uri: Any) -> str:
        """Return the final fragment or path segment of the URI as a name."""
        text = str(uri)
        _, hash_sep, fragment = text.rpartition("#")
        if hash_sep:
            # Fragment identifiers ("...#Name") take precedence over the path.
            return fragment
        return text.rstrip("/").rpartition("/")[2]

    def __repr__(self):
        return f"AttachedOntologyNode(name={self.name}, category={self.category})"
||||||
24
cognee/modules/ontology/ontology_config.py
Normal file
24
cognee/modules/ontology/ontology_config.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
from typing import TypedDict, Optional
|
||||||
|
|
||||||
|
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
|
||||||
|
from cognee.modules.ontology.matching_strategies import MatchingStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyConfig(TypedDict, total=False):
    """Configuration containing ontology resolver.

    Attributes:
        ontology_resolver: The ontology resolver instance to use
    """

    # Optional: absent or None means "fall back to the default resolver".
    ontology_resolver: Optional[BaseOntologyResolver]


class Config(TypedDict, total=False):
    """Top-level configuration dictionary.

    Attributes:
        ontology_config: Configuration containing ontology resolver
    """

    # Optional nested section; total=False means callers may omit it.
    ontology_config: Optional[OntologyConfig]
|
||||||
45
cognee/modules/ontology/ontology_env_config.py
Normal file
45
cognee/modules/ontology/ontology_env_config.py
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
"""This module contains the configuration for ontology handling."""
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyEnvConfig(BaseSettings):
    """
    Represents the configuration for ontology handling, including parameters for
    ontology file storage and resolution/matching strategies.

    Public methods:
    - to_dict

    Instance variables:
    - ontology_resolver
    - matching_strategy
    - ontology_file_path
    - model_config
    """

    # Resolver backend identifier; only "rdflib" is currently supported.
    ontology_resolver: str = "rdflib"
    # Entity matching strategy identifier; only "fuzzy" is currently supported.
    matching_strategy: str = "fuzzy"
    # Path to the ontology file; empty string means "not configured".
    ontology_file_path: str = ""

    # Values are read from .env; unknown keys are tolerated (extra="allow").
    model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)

    def to_dict(self) -> dict:
        """
        Return the configuration as a dictionary.
        """
        return {
            "ontology_resolver": self.ontology_resolver,
            "matching_strategy": self.matching_strategy,
            "ontology_file_path": self.ontology_file_path,
        }
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
def get_ontology_env_config():
    """
    Retrieve the ontology configuration. This function utilizes caching to return a
    singleton instance of the OntologyEnvConfig class for efficiency.
    """
    return OntologyEnvConfig()
|
||||||
|
|
@ -10,31 +10,26 @@ from cognee.modules.ontology.exceptions import (
|
||||||
FindClosestMatchError,
|
FindClosestMatchError,
|
||||||
GetSubgraphError,
|
GetSubgraphError,
|
||||||
)
|
)
|
||||||
|
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
|
||||||
|
from cognee.modules.ontology.models import AttachedOntologyNode
|
||||||
|
from cognee.modules.ontology.matching_strategies import MatchingStrategy, FuzzyMatchingStrategy
|
||||||
|
|
||||||
logger = get_logger("OntologyAdapter")
|
logger = get_logger("OntologyAdapter")
|
||||||
|
|
||||||
|
|
||||||
class AttachedOntologyNode:
|
class RDFLibOntologyResolver(BaseOntologyResolver):
|
||||||
"""Lightweight wrapper to be able to parse any ontology solution and generalize cognee interface."""
|
"""RDFLib-based ontology resolver implementation.
|
||||||
|
|
||||||
def __init__(self, uri: URIRef, category: str):
|
This implementation uses RDFLib to parse and work with RDF/OWL ontology files.
|
||||||
self.uri = uri
|
It provides fuzzy matching and subgraph extraction capabilities for ontology entities.
|
||||||
self.name = self._extract_name(uri)
|
"""
|
||||||
self.category = category
|
|
||||||
|
|
||||||
@staticmethod
|
def __init__(
|
||||||
def _extract_name(uri: URIRef) -> str:
|
self,
|
||||||
uri_str = str(uri)
|
ontology_file: Optional[str] = None,
|
||||||
if "#" in uri_str:
|
matching_strategy: Optional[MatchingStrategy] = None,
|
||||||
return uri_str.split("#")[-1]
|
) -> None:
|
||||||
return uri_str.rstrip("/").split("/")[-1]
|
super().__init__(matching_strategy)
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"AttachedOntologyNode(name={self.name}, category={self.category})"
|
|
||||||
|
|
||||||
|
|
||||||
class OntologyResolver:
|
|
||||||
def __init__(self, ontology_file: Optional[str] = None):
|
|
||||||
self.ontology_file = ontology_file
|
self.ontology_file = ontology_file
|
||||||
try:
|
try:
|
||||||
if ontology_file and os.path.exists(ontology_file):
|
if ontology_file and os.path.exists(ontology_file):
|
||||||
|
|
@ -60,7 +55,7 @@ class OntologyResolver:
|
||||||
name = uri_str.rstrip("/").split("/")[-1]
|
name = uri_str.rstrip("/").split("/")[-1]
|
||||||
return name.lower().replace(" ", "_").strip()
|
return name.lower().replace(" ", "_").strip()
|
||||||
|
|
||||||
def build_lookup(self):
|
def build_lookup(self) -> None:
|
||||||
try:
|
try:
|
||||||
classes: Dict[str, URIRef] = {}
|
classes: Dict[str, URIRef] = {}
|
||||||
individuals: Dict[str, URIRef] = {}
|
individuals: Dict[str, URIRef] = {}
|
||||||
|
|
@ -97,7 +92,7 @@ class OntologyResolver:
|
||||||
logger.error("Failed to build lookup dictionary: %s", str(e))
|
logger.error("Failed to build lookup dictionary: %s", str(e))
|
||||||
raise RuntimeError("Lookup build failed") from e
|
raise RuntimeError("Lookup build failed") from e
|
||||||
|
|
||||||
def refresh_lookup(self):
|
def refresh_lookup(self) -> None:
|
||||||
self.build_lookup()
|
self.build_lookup()
|
||||||
logger.info("Ontology lookup refreshed.")
|
logger.info("Ontology lookup refreshed.")
|
||||||
|
|
||||||
|
|
@ -105,13 +100,8 @@ class OntologyResolver:
|
||||||
try:
|
try:
|
||||||
normalized_name = name.lower().replace(" ", "_").strip()
|
normalized_name = name.lower().replace(" ", "_").strip()
|
||||||
possible_matches = list(self.lookup.get(category, {}).keys())
|
possible_matches = list(self.lookup.get(category, {}).keys())
|
||||||
if normalized_name in possible_matches:
|
|
||||||
return normalized_name
|
|
||||||
|
|
||||||
best_match = difflib.get_close_matches(
|
return self.matching_strategy.find_match(normalized_name, possible_matches)
|
||||||
normalized_name, possible_matches, n=1, cutoff=0.8
|
|
||||||
)
|
|
||||||
return best_match[0] if best_match else None
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error in find_closest_match: %s", str(e))
|
logger.error("Error in find_closest_match: %s", str(e))
|
||||||
raise FindClosestMatchError() from e
|
raise FindClosestMatchError() from e
|
||||||
|
|
@ -125,7 +115,9 @@ class OntologyResolver:
|
||||||
|
|
||||||
def get_subgraph(
|
def get_subgraph(
|
||||||
self, node_name: str, node_type: str = "individuals", directed: bool = True
|
self, node_name: str, node_type: str = "individuals", directed: bool = True
|
||||||
) -> Tuple[List[Any], List[Tuple[str, str, str]], Optional[Any]]:
|
) -> Tuple[
|
||||||
|
List[AttachedOntologyNode], List[Tuple[str, str, str]], Optional[AttachedOntologyNode]
|
||||||
|
]:
|
||||||
nodes_set = set()
|
nodes_set = set()
|
||||||
edges: List[Tuple[str, str, str]] = []
|
edges: List[Tuple[str, str, str]] = []
|
||||||
visited = set()
|
visited = set()
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, List, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from cognee.modules.data.models.Data import Data
|
||||||
|
|
||||||
|
|
||||||
class PipelineRunInfo(BaseModel):
|
class PipelineRunInfo(BaseModel):
|
||||||
|
|
@ -8,11 +9,15 @@ class PipelineRunInfo(BaseModel):
|
||||||
pipeline_run_id: UUID
|
pipeline_run_id: UUID
|
||||||
dataset_id: UUID
|
dataset_id: UUID
|
||||||
dataset_name: str
|
dataset_name: str
|
||||||
payload: Optional[Any] = None
|
# Data must be mentioned in typing to allow custom encoders for Data to be activated
|
||||||
|
payload: Optional[Union[Any, List[Data]]] = None
|
||||||
data_ingestion_info: Optional[list] = None
|
data_ingestion_info: Optional[list] = None
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"arbitrary_types_allowed": True,
|
"arbitrary_types_allowed": True,
|
||||||
|
"from_attributes": True,
|
||||||
|
# Add custom encoding handler for Data ORM model
|
||||||
|
"json_encoders": {Data: lambda d: d.to_json()},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -113,7 +113,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
||||||
logger.info(
|
logger.info(
|
||||||
"No timestamps identified based on the query, performing retrieval using triplet search on events and entities."
|
"No timestamps identified based on the query, performing retrieval using triplet search on events and entities."
|
||||||
)
|
)
|
||||||
triplets = await self.get_context(query)
|
triplets = await self.get_triplets(query)
|
||||||
return await self.resolve_edges_to_text(triplets)
|
return await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
if ids:
|
if ids:
|
||||||
|
|
@ -122,7 +122,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
||||||
logger.info(
|
logger.info(
|
||||||
"No events identified based on timestamp filtering, performing retrieval using triplet search on events and entities."
|
"No events identified based on timestamp filtering, performing retrieval using triplet search on events and entities."
|
||||||
)
|
)
|
||||||
triplets = await self.get_context(query)
|
triplets = await self.get_triplets(query)
|
||||||
return await self.resolve_edges_to_text(triplets)
|
return await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
from typing import Callable, List, Optional, Type
|
from typing import Callable, List, Optional, Type
|
||||||
|
|
||||||
from cognee.modules.engine.models.node_set import NodeSet
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
|
|
@ -160,6 +161,12 @@ async def get_search_type_tools(
|
||||||
if query_type is SearchType.FEELING_LUCKY:
|
if query_type is SearchType.FEELING_LUCKY:
|
||||||
query_type = await select_search_type(query_text)
|
query_type = await select_search_type(query_text)
|
||||||
|
|
||||||
|
if (
|
||||||
|
query_type in [SearchType.CYPHER, SearchType.NATURAL_LANGUAGE]
|
||||||
|
and os.getenv("ALLOW_CYPHER_QUERY", "true").lower() == "false"
|
||||||
|
):
|
||||||
|
raise UnsupportedSearchTypeError("Cypher query search types are disabled.")
|
||||||
|
|
||||||
search_type_tools = search_tasks.get(query_type)
|
search_type_tools = search_tasks.get(query_type)
|
||||||
|
|
||||||
if not search_type_tools:
|
if not search_type_tools:
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
from uuid import uuid4
|
from uuid import UUID, uuid4
|
||||||
from fastapi_users.exceptions import UserAlreadyExists
|
from fastapi_users.exceptions import UserAlreadyExists
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.modules.notebooks.methods import create_notebook
|
from cognee.modules.notebooks.models.Notebook import Notebook
|
||||||
from cognee.modules.notebooks.models.Notebook import NotebookCell
|
from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook
|
||||||
from cognee.modules.users.exceptions import TenantNotFoundError
|
from cognee.modules.users.exceptions import TenantNotFoundError
|
||||||
from cognee.modules.users.get_user_manager import get_user_manager_context
|
from cognee.modules.users.get_user_manager import get_user_manager_context
|
||||||
from cognee.modules.users.get_user_db import get_user_db_context
|
from cognee.modules.users.get_user_db import get_user_db_context
|
||||||
|
|
@ -60,27 +61,6 @@ async def create_user(
|
||||||
if auto_login:
|
if auto_login:
|
||||||
await session.refresh(user)
|
await session.refresh(user)
|
||||||
|
|
||||||
await create_notebook(
|
|
||||||
user_id=user.id,
|
|
||||||
notebook_name="Welcome to cognee 🧠",
|
|
||||||
cells=[
|
|
||||||
NotebookCell(
|
|
||||||
id=uuid4(),
|
|
||||||
name="Welcome",
|
|
||||||
content="Cognee is your toolkit for turning text into a structured knowledge graph, optionally enhanced by ontologies, and then querying it with advanced retrieval techniques. This notebook will guide you through a simple example.",
|
|
||||||
type="markdown",
|
|
||||||
),
|
|
||||||
NotebookCell(
|
|
||||||
id=uuid4(),
|
|
||||||
name="Example",
|
|
||||||
content="",
|
|
||||||
type="code",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
deletable=False,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
return user
|
return user
|
||||||
except UserAlreadyExists as error:
|
except UserAlreadyExists as error:
|
||||||
print(f"User {email} already exists")
|
print(f"User {email} already exists")
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
from ...models.User import User
|
from ...models.User import User
|
||||||
|
|
@ -27,9 +29,14 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
|
||||||
# Get all datasets all tenants have access to
|
# Get all datasets all tenants have access to
|
||||||
tenant = await get_tenant(user.tenant_id)
|
tenant = await get_tenant(user.tenant_id)
|
||||||
datasets.extend(await get_principal_datasets(tenant, permission_type))
|
datasets.extend(await get_principal_datasets(tenant, permission_type))
|
||||||
|
|
||||||
# Get all datasets Users roles have access to
|
# Get all datasets Users roles have access to
|
||||||
for role_name in user.roles:
|
if isinstance(user, SimpleNamespace):
|
||||||
role = await get_role(user.tenant_id, role_name)
|
# If simple namespace use roles defined in user
|
||||||
|
roles = user.roles
|
||||||
|
else:
|
||||||
|
roles = await user.awaitable_attrs.roles
|
||||||
|
for role in roles:
|
||||||
datasets.extend(await get_principal_datasets(role, permission_type))
|
datasets.extend(await get_principal_datasets(role, permission_type))
|
||||||
|
|
||||||
# Deduplicate datasets with same ID
|
# Deduplicate datasets with same ID
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,11 @@ def ensure_absolute_path(path: str) -> str:
|
||||||
"""
|
"""
|
||||||
if path is None:
|
if path is None:
|
||||||
raise ValueError("Path cannot be None")
|
raise ValueError("Path cannot be None")
|
||||||
|
|
||||||
|
# Check if it's an S3 URL - S3 URLs are absolute by definition
|
||||||
|
if path.startswith("s3://"):
|
||||||
|
return path
|
||||||
|
|
||||||
path_obj = Path(path).expanduser()
|
path_obj = Path(path).expanduser()
|
||||||
if path_obj.is_absolute():
|
if path_obj.is_absolute():
|
||||||
return str(path_obj.resolve())
|
return str(path_obj.resolve())
|
||||||
|
|
|
||||||
346
cognee/shared/cache.py
Normal file
346
cognee/shared/cache.py
Normal file
|
|
@ -0,0 +1,346 @@
|
||||||
|
"""
|
||||||
|
Storage-aware cache management utilities for Cognee.
|
||||||
|
|
||||||
|
This module provides cache functionality that works with both local and cloud storage
|
||||||
|
backends (like S3) through the StorageManager abstraction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import zipfile
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import aiohttp
|
||||||
|
import logging
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
|
||||||
|
from cognee.infrastructure.files.storage.StorageManager import StorageManager
|
||||||
|
from cognee.shared.utils import create_secure_ssl_context
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StorageAwareCache:
|
||||||
|
"""
|
||||||
|
A cache manager that works with different storage backends (local, S3, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cache_subdir: str = "cache"):
|
||||||
|
"""
|
||||||
|
Initialize the cache manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_subdir: Subdirectory name within the system root for caching
|
||||||
|
"""
|
||||||
|
self.base_config = get_base_config()
|
||||||
|
# Since we're using cache_root_directory, don't add extra cache prefix
|
||||||
|
self.cache_base_path = ""
|
||||||
|
self.storage_manager: StorageManager = get_file_storage(
|
||||||
|
self.base_config.cache_root_directory
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print absolute path
|
||||||
|
storage_path = self.storage_manager.storage.storage_path
|
||||||
|
if storage_path.startswith("s3://"):
|
||||||
|
absolute_path = storage_path # S3 paths are already absolute
|
||||||
|
else:
|
||||||
|
import os
|
||||||
|
|
||||||
|
absolute_path = os.path.abspath(storage_path)
|
||||||
|
logger.info(f"Storage manager absolute path: {absolute_path}")
|
||||||
|
|
||||||
|
async def get_cache_dir(self) -> str:
|
||||||
|
"""Get the base cache directory path."""
|
||||||
|
cache_path = self.cache_base_path or "." # Use "." for root when cache_base_path is empty
|
||||||
|
await self.storage_manager.ensure_directory_exists(cache_path)
|
||||||
|
return cache_path
|
||||||
|
|
||||||
|
async def get_cache_subdir(self, name: str) -> str:
|
||||||
|
"""Get a specific cache subdirectory."""
|
||||||
|
if self.cache_base_path:
|
||||||
|
cache_path = f"{self.cache_base_path}/{name}"
|
||||||
|
else:
|
||||||
|
cache_path = name
|
||||||
|
await self.storage_manager.ensure_directory_exists(cache_path)
|
||||||
|
|
||||||
|
# Return the absolute path based on storage system
|
||||||
|
if self.storage_manager.storage.storage_path.startswith("s3://"):
|
||||||
|
return cache_path
|
||||||
|
elif hasattr(self.storage_manager.storage, "storage_path"):
|
||||||
|
return f"{self.storage_manager.storage.storage_path}/{cache_path}"
|
||||||
|
else:
|
||||||
|
# Fallback for other storage types
|
||||||
|
return cache_path
|
||||||
|
|
||||||
|
async def delete_cache(self):
|
||||||
|
"""Delete the entire cache directory."""
|
||||||
|
logger.info("Deleting cache...")
|
||||||
|
try:
|
||||||
|
await self.storage_manager.remove_all(self.cache_base_path)
|
||||||
|
logger.info("✓ Cache deleted successfully!")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting cache: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _is_cache_valid(self, cache_dir: str, version_or_hash: str) -> bool:
|
||||||
|
"""Check if cached content is valid for the given version/hash."""
|
||||||
|
version_file = f"{cache_dir}/version.txt"
|
||||||
|
|
||||||
|
if not await self.storage_manager.file_exists(version_file):
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.storage_manager.open(version_file, "r") as f:
|
||||||
|
cached_version = (await asyncio.to_thread(f.read)).strip()
|
||||||
|
return cached_version == version_or_hash
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error checking cache validity: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _clear_cache(self, cache_dir: str) -> None:
|
||||||
|
"""Clear a cache directory."""
|
||||||
|
try:
|
||||||
|
await self.storage_manager.remove_all(cache_dir)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error clearing cache directory {cache_dir}: {e}")
|
||||||
|
|
||||||
|
async def _check_remote_content_freshness(
|
||||||
|
self, url: str, cache_dir: str
|
||||||
|
) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Check if remote content is fresher than cached version using HTTP headers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_fresh: bool, new_identifier: Optional[str])
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Make a HEAD request to check headers without downloading
|
||||||
|
ssl_context = create_secure_ssl_context()
|
||||||
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
|
async with session.head(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Try ETag first (most reliable)
|
||||||
|
etag = response.headers.get("ETag", "").strip('"')
|
||||||
|
last_modified = response.headers.get("Last-Modified", "")
|
||||||
|
|
||||||
|
# Use ETag if available, otherwise Last-Modified
|
||||||
|
remote_identifier = etag if etag else last_modified
|
||||||
|
|
||||||
|
if not remote_identifier:
|
||||||
|
logger.debug("No freshness headers available, cannot check for updates")
|
||||||
|
return True, None # Assume fresh if no headers
|
||||||
|
|
||||||
|
# Check cached identifier
|
||||||
|
identifier_file = f"{cache_dir}/content_id.txt"
|
||||||
|
if await self.storage_manager.file_exists(identifier_file):
|
||||||
|
async with self.storage_manager.open(identifier_file, "r") as f:
|
||||||
|
cached_identifier = (await asyncio.to_thread(f.read)).strip()
|
||||||
|
if cached_identifier == remote_identifier:
|
||||||
|
logger.debug(f"Content is fresh (identifier: {remote_identifier[:20]}...)")
|
||||||
|
return True, None
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Content has changed (old: {cached_identifier[:20]}..., new: {remote_identifier[:20]}...)"
|
||||||
|
)
|
||||||
|
return False, remote_identifier
|
||||||
|
else:
|
||||||
|
# No cached identifier, treat as stale
|
||||||
|
return False, remote_identifier
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not check remote freshness: {e}")
|
||||||
|
return True, None # Assume fresh if we can't check
|
||||||
|
|
||||||
|
async def download_and_extract_zip(
|
||||||
|
self, url: str, cache_subdir_name: str, version_or_hash: str, force: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Download a zip file and extract it to cache directory with content freshness checking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to download zip file from
|
||||||
|
cache_subdir_name: Name of the cache subdirectory
|
||||||
|
version_or_hash: Version string or content hash for cache validation
|
||||||
|
force: If True, re-download even if already cached
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the cached directory
|
||||||
|
"""
|
||||||
|
cache_dir = await self.get_cache_subdir(cache_subdir_name)
|
||||||
|
|
||||||
|
# Check if already cached and valid
|
||||||
|
if not force and await self._is_cache_valid(cache_dir, version_or_hash):
|
||||||
|
# Also check if remote content has changed
|
||||||
|
is_fresh, new_identifier = await self._check_remote_content_freshness(url, cache_dir)
|
||||||
|
if is_fresh:
|
||||||
|
logger.debug(f"Content already cached and fresh for version {version_or_hash}")
|
||||||
|
return cache_dir
|
||||||
|
else:
|
||||||
|
logger.info("Cached content is stale, updating...")
|
||||||
|
|
||||||
|
# Clear old cache if it exists
|
||||||
|
await self._clear_cache(cache_dir)
|
||||||
|
|
||||||
|
logger.info(f"Downloading content from {url}...")
|
||||||
|
|
||||||
|
# Download the zip file
|
||||||
|
zip_content = BytesIO()
|
||||||
|
etag = ""
|
||||||
|
last_modified = ""
|
||||||
|
ssl_context = create_secure_ssl_context()
|
||||||
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
|
async with session.get(url, timeout=aiohttp.ClientTimeout(total=60)) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Extract headers before consuming response
|
||||||
|
etag = response.headers.get("ETag", "").strip('"')
|
||||||
|
last_modified = response.headers.get("Last-Modified", "")
|
||||||
|
|
||||||
|
# Read the response content
|
||||||
|
async for chunk in response.content.iter_chunked(8192):
|
||||||
|
zip_content.write(chunk)
|
||||||
|
zip_content.seek(0)
|
||||||
|
|
||||||
|
# Extract the archive
|
||||||
|
await self.storage_manager.ensure_directory_exists(cache_dir)
|
||||||
|
|
||||||
|
# Extract files and store them using StorageManager
|
||||||
|
with zipfile.ZipFile(zip_content, "r") as zip_file:
|
||||||
|
for file_info in zip_file.infolist():
|
||||||
|
if file_info.is_dir():
|
||||||
|
# Create directory
|
||||||
|
dir_path = f"{cache_dir}/{file_info.filename}"
|
||||||
|
await self.storage_manager.ensure_directory_exists(dir_path)
|
||||||
|
else:
|
||||||
|
# Extract and store file
|
||||||
|
file_data = zip_file.read(file_info.filename)
|
||||||
|
file_path = f"{cache_dir}/{file_info.filename}"
|
||||||
|
await self.storage_manager.store(file_path, BytesIO(file_data), overwrite=True)
|
||||||
|
|
||||||
|
# Write version info for future cache validation
|
||||||
|
version_file = f"{cache_dir}/version.txt"
|
||||||
|
await self.storage_manager.store(version_file, version_or_hash, overwrite=True)
|
||||||
|
|
||||||
|
# Store content identifier from response headers for freshness checking
|
||||||
|
content_identifier = etag if etag else last_modified
|
||||||
|
|
||||||
|
if content_identifier:
|
||||||
|
identifier_file = f"{cache_dir}/content_id.txt"
|
||||||
|
await self.storage_manager.store(identifier_file, content_identifier, overwrite=True)
|
||||||
|
logger.debug(f"Stored content identifier: {content_identifier[:20]}...")
|
||||||
|
|
||||||
|
logger.info("✓ Content downloaded and cached successfully!")
|
||||||
|
return cache_dir
|
||||||
|
|
||||||
|
async def file_exists(self, file_path: str) -> bool:
|
||||||
|
"""Check if a file exists in cache storage."""
|
||||||
|
return await self.storage_manager.file_exists(file_path)
|
||||||
|
|
||||||
|
async def read_file(self, file_path: str, encoding: str = "utf-8"):
|
||||||
|
"""Read a file from cache storage."""
|
||||||
|
return self.storage_manager.open(file_path, encoding=encoding)
|
||||||
|
|
||||||
|
async def list_files(self, directory_path: str):
|
||||||
|
"""List files in a cache directory."""
|
||||||
|
try:
|
||||||
|
file_list = await self.storage_manager.list_files(directory_path)
|
||||||
|
|
||||||
|
# For S3 storage, convert relative paths to full S3 URLs
|
||||||
|
if self.storage_manager.storage.storage_path.startswith("s3://"):
|
||||||
|
full_paths = []
|
||||||
|
for file_path in file_list:
|
||||||
|
full_s3_path = f"{self.storage_manager.storage.storage_path}/{file_path}"
|
||||||
|
full_paths.append(full_s3_path)
|
||||||
|
return full_paths
|
||||||
|
else:
|
||||||
|
# For local storage, return absolute paths
|
||||||
|
storage_path = self.storage_manager.storage.storage_path
|
||||||
|
if not storage_path.startswith("/"):
|
||||||
|
import os
|
||||||
|
|
||||||
|
storage_path = os.path.abspath(storage_path)
|
||||||
|
|
||||||
|
full_paths = []
|
||||||
|
for file_path in file_list:
|
||||||
|
if file_path.startswith("/"):
|
||||||
|
full_paths.append(file_path) # Already absolute
|
||||||
|
else:
|
||||||
|
full_paths.append(f"{storage_path}/{file_path}")
|
||||||
|
return full_paths
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error listing files in {directory_path}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience functions that maintain API compatibility
|
||||||
|
_cache_manager = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_manager() -> StorageAwareCache:
|
||||||
|
"""Get a singleton cache manager instance."""
|
||||||
|
global _cache_manager
|
||||||
|
if _cache_manager is None:
|
||||||
|
_cache_manager = StorageAwareCache()
|
||||||
|
return _cache_manager
|
||||||
|
|
||||||
|
|
||||||
|
def generate_content_hash(url: str, additional_data: str = "") -> str:
|
||||||
|
"""Generate a content hash from URL and optional additional data."""
|
||||||
|
content = f"{url}:{additional_data}"
|
||||||
|
return hashlib.md5(content.encode()).hexdigest()[:12] # Short hash for readability
|
||||||
|
|
||||||
|
|
||||||
|
# Async wrapper functions for backward compatibility
|
||||||
|
async def delete_cache():
|
||||||
|
"""Delete the Cognee cache directory."""
|
||||||
|
cache_manager = get_cache_manager()
|
||||||
|
await cache_manager.delete_cache()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_cognee_cache_dir() -> str:
|
||||||
|
"""Get the base Cognee cache directory."""
|
||||||
|
cache_manager = get_cache_manager()
|
||||||
|
return await cache_manager.get_cache_dir()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_cache_subdir(name: str) -> str:
|
||||||
|
"""Get a specific cache subdirectory."""
|
||||||
|
cache_manager = get_cache_manager()
|
||||||
|
return await cache_manager.get_cache_subdir(name)
|
||||||
|
|
||||||
|
|
||||||
|
async def download_and_extract_zip(
|
||||||
|
url: str, cache_dir_name: str, version_or_hash: str, force: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""Download a zip file and extract it to cache directory."""
|
||||||
|
cache_manager = get_cache_manager()
|
||||||
|
return await cache_manager.download_and_extract_zip(url, cache_dir_name, version_or_hash, force)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_tutorial_data_dir() -> str:
|
||||||
|
"""Get the tutorial data cache directory."""
|
||||||
|
return await get_cache_subdir("tutorial_data")
|
||||||
|
|
||||||
|
|
||||||
|
# Cache file operations
|
||||||
|
async def cache_file_exists(file_path: str) -> bool:
|
||||||
|
"""Check if a file exists in cache storage."""
|
||||||
|
cache_manager = get_cache_manager()
|
||||||
|
return await cache_manager.file_exists(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
async def read_cache_file(file_path: str, encoding: str = "utf-8"):
|
||||||
|
"""Read a file from cache storage."""
|
||||||
|
cache_manager = get_cache_manager()
|
||||||
|
return await cache_manager.read_file(file_path, encoding)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_cache_files(directory_path: str):
|
||||||
|
"""List files in a cache directory."""
|
||||||
|
cache_manager = get_cache_manager()
|
||||||
|
return await cache_manager.list_files(directory_path)
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""This module contains utility functions for the cognee."""
|
"""This module contains utility functions for the cognee."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import ssl
|
||||||
import requests
|
import requests
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
@ -18,6 +19,17 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
proxy_url = "https://test.prometh.ai"
|
proxy_url = "https://test.prometh.ai"
|
||||||
|
|
||||||
|
|
||||||
|
def create_secure_ssl_context() -> ssl.SSLContext:
|
||||||
|
"""
|
||||||
|
Create a secure SSL context.
|
||||||
|
|
||||||
|
By default, use the system's certificate store.
|
||||||
|
If users report SSL issues, I'm keeping this open in case we need to switch to:
|
||||||
|
ssl.create_default_context(cafile=certifi.where())
|
||||||
|
"""
|
||||||
|
return ssl.create_default_context()
|
||||||
|
|
||||||
|
|
||||||
def get_entities(tagged_tokens):
|
def get_entities(tagged_tokens):
|
||||||
import nltk
|
import nltk
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,14 @@ from typing import Type, List, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
||||||
from cognee.tasks.storage.add_data_points import add_data_points
|
from cognee.tasks.storage.add_data_points import add_data_points
|
||||||
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
|
from cognee.modules.ontology.ontology_config import Config
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import (
|
||||||
|
get_default_ontology_resolver,
|
||||||
|
get_ontology_resolver_from_env,
|
||||||
|
)
|
||||||
|
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
|
||||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
from cognee.modules.graph.utils import (
|
from cognee.modules.graph.utils import (
|
||||||
expand_with_nodes_and_edges,
|
expand_with_nodes_and_edges,
|
||||||
|
|
@ -24,9 +30,28 @@ async def integrate_chunk_graphs(
|
||||||
data_chunks: list[DocumentChunk],
|
data_chunks: list[DocumentChunk],
|
||||||
chunk_graphs: list,
|
chunk_graphs: list,
|
||||||
graph_model: Type[BaseModel],
|
graph_model: Type[BaseModel],
|
||||||
ontology_adapter: OntologyResolver,
|
ontology_resolver: BaseOntologyResolver,
|
||||||
) -> List[DocumentChunk]:
|
) -> List[DocumentChunk]:
|
||||||
"""Updates DocumentChunk objects, integrates data points and edges into databases."""
|
"""Integrate chunk graphs with ontology validation and store in databases.
|
||||||
|
|
||||||
|
This function processes document chunks and their associated knowledge graphs,
|
||||||
|
validates entities against an ontology resolver, and stores the integrated
|
||||||
|
data points and edges in the configured databases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_chunks: List of document chunks containing source data
|
||||||
|
chunk_graphs: List of knowledge graphs corresponding to each chunk
|
||||||
|
graph_model: Pydantic model class for graph data validation
|
||||||
|
ontology_resolver: Resolver for validating entities against ontology
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of updated DocumentChunk objects with integrated data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidChunkGraphInputError: If input validation fails
|
||||||
|
InvalidGraphModelError: If graph model validation fails
|
||||||
|
InvalidOntologyAdapterError: If ontology resolver validation fails
|
||||||
|
"""
|
||||||
|
|
||||||
if not isinstance(data_chunks, list) or not isinstance(chunk_graphs, list):
|
if not isinstance(data_chunks, list) or not isinstance(chunk_graphs, list):
|
||||||
raise InvalidChunkGraphInputError("data_chunks and chunk_graphs must be lists.")
|
raise InvalidChunkGraphInputError("data_chunks and chunk_graphs must be lists.")
|
||||||
|
|
@ -36,9 +61,9 @@ async def integrate_chunk_graphs(
|
||||||
)
|
)
|
||||||
if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
|
if not isinstance(graph_model, type) or not issubclass(graph_model, BaseModel):
|
||||||
raise InvalidGraphModelError(graph_model)
|
raise InvalidGraphModelError(graph_model)
|
||||||
if ontology_adapter is None or not hasattr(ontology_adapter, "get_subgraph"):
|
if ontology_resolver is None or not hasattr(ontology_resolver, "get_subgraph"):
|
||||||
raise InvalidOntologyAdapterError(
|
raise InvalidOntologyAdapterError(
|
||||||
type(ontology_adapter).__name__ if ontology_adapter else "None"
|
type(ontology_resolver).__name__ if ontology_resolver else "None"
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
@ -55,7 +80,7 @@ async def integrate_chunk_graphs(
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_nodes, graph_edges = expand_with_nodes_and_edges(
|
graph_nodes, graph_edges = expand_with_nodes_and_edges(
|
||||||
data_chunks, chunk_graphs, ontology_adapter, existing_edges_map
|
data_chunks, chunk_graphs, ontology_resolver, existing_edges_map
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(graph_nodes) > 0:
|
if len(graph_nodes) > 0:
|
||||||
|
|
@ -70,7 +95,7 @@ async def integrate_chunk_graphs(
|
||||||
async def extract_graph_from_data(
|
async def extract_graph_from_data(
|
||||||
data_chunks: List[DocumentChunk],
|
data_chunks: List[DocumentChunk],
|
||||||
graph_model: Type[BaseModel],
|
graph_model: Type[BaseModel],
|
||||||
ontology_adapter: OntologyResolver = None,
|
config: Config = None,
|
||||||
custom_prompt: Optional[str] = None,
|
custom_prompt: Optional[str] = None,
|
||||||
) -> List[DocumentChunk]:
|
) -> List[DocumentChunk]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -101,6 +126,24 @@ async def extract_graph_from_data(
|
||||||
if edge.source_node_id in valid_node_ids and edge.target_node_id in valid_node_ids
|
if edge.source_node_id in valid_node_ids and edge.target_node_id in valid_node_ids
|
||||||
]
|
]
|
||||||
|
|
||||||
return await integrate_chunk_graphs(
|
# Extract resolver from config if provided, otherwise get default
|
||||||
data_chunks, chunk_graphs, graph_model, ontology_adapter or OntologyResolver()
|
if config is None:
|
||||||
)
|
ontology_config = get_ontology_env_config()
|
||||||
|
if (
|
||||||
|
ontology_config.ontology_file_path
|
||||||
|
and ontology_config.ontology_resolver
|
||||||
|
and ontology_config.matching_strategy
|
||||||
|
):
|
||||||
|
config: Config = {
|
||||||
|
"ontology_config": {
|
||||||
|
"ontology_resolver": get_ontology_resolver_from_env(**ontology_config.to_dict())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config: Config = {
|
||||||
|
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
|
||||||
|
}
|
||||||
|
|
||||||
|
ontology_resolver = config["ontology_config"]["ontology_resolver"]
|
||||||
|
|
||||||
|
return await integrate_chunk_graphs(data_chunks, chunk_graphs, graph_model, ontology_resolver)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import List
|
||||||
|
|
||||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
|
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
|
||||||
from cognee.tasks.graph.cascade_extract.utils.extract_nodes import extract_nodes
|
from cognee.tasks.graph.cascade_extract.utils.extract_nodes import extract_nodes
|
||||||
from cognee.tasks.graph.cascade_extract.utils.extract_content_nodes_and_relationship_names import (
|
from cognee.tasks.graph.cascade_extract.utils.extract_content_nodes_and_relationship_names import (
|
||||||
extract_content_nodes_and_relationship_names,
|
extract_content_nodes_and_relationship_names,
|
||||||
|
|
@ -17,9 +17,21 @@ from cognee.tasks.graph.extract_graph_from_data import integrate_chunk_graphs
|
||||||
async def extract_graph_from_data(
|
async def extract_graph_from_data(
|
||||||
data_chunks: List[DocumentChunk],
|
data_chunks: List[DocumentChunk],
|
||||||
n_rounds: int = 2,
|
n_rounds: int = 2,
|
||||||
ontology_adapter: OntologyResolver = None,
|
ontology_adapter: BaseOntologyResolver = None,
|
||||||
) -> List[DocumentChunk]:
|
) -> List[DocumentChunk]:
|
||||||
"""Extract and update graph data from document chunks in multiple steps."""
|
"""Extract and update graph data from document chunks using cascade extraction.
|
||||||
|
|
||||||
|
This function performs multi-step graph extraction from document chunks,
|
||||||
|
using cascade extraction techniques to build comprehensive knowledge graphs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_chunks: List of document chunks to process
|
||||||
|
n_rounds: Number of extraction rounds to perform (default: 2)
|
||||||
|
ontology_adapter: Resolver for validating entities against ontology
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of updated DocumentChunk objects with extracted graph data
|
||||||
|
"""
|
||||||
chunk_nodes = await asyncio.gather(
|
chunk_nodes = await asyncio.gather(
|
||||||
*[extract_nodes(chunk.text, n_rounds) for chunk in data_chunks]
|
*[extract_nodes(chunk.text, n_rounds) for chunk in data_chunks]
|
||||||
)
|
)
|
||||||
|
|
@ -44,5 +56,5 @@ async def extract_graph_from_data(
|
||||||
data_chunks=data_chunks,
|
data_chunks=data_chunks,
|
||||||
chunk_graphs=chunk_graphs,
|
chunk_graphs=chunk_graphs,
|
||||||
graph_model=KnowledgeGraph,
|
graph_model=KnowledgeGraph,
|
||||||
ontology_adapter=ontology_adapter or OntologyResolver(),
|
ontology_adapter=ontology_adapter,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str
|
||||||
abs_path.is_file()
|
abs_path.is_file()
|
||||||
except (OSError, ValueError):
|
except (OSError, ValueError):
|
||||||
# In case file path is too long it's most likely not a relative path
|
# In case file path is too long it's most likely not a relative path
|
||||||
|
abs_path = data_item
|
||||||
logger.debug(f"Data item was too long to be a possible file path: {abs_path}")
|
logger.debug(f"Data item was too long to be a possible file path: {abs_path}")
|
||||||
abs_path = Path("")
|
abs_path = Path("")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
"""Test that the server is running and can accept connections."""
|
"""Test that the server is running and can accept connections."""
|
||||||
# Test health endpoint
|
# Test health endpoint
|
||||||
health_response = requests.get("http://localhost:8000/health", timeout=15)
|
health_response = requests.get("http://localhost:8000/health", timeout=15)
|
||||||
self.assertIn(health_response.status_code, [200, 503])
|
self.assertIn(health_response.status_code, [200])
|
||||||
|
|
||||||
# Test root endpoint
|
# Test root endpoint
|
||||||
root_response = requests.get("http://localhost:8000/", timeout=15)
|
root_response = requests.get("http://localhost:8000/", timeout=15)
|
||||||
|
|
@ -88,7 +88,7 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
payload = {"datasets": [dataset_name]}
|
payload = {"datasets": [dataset_name]}
|
||||||
|
|
||||||
add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50)
|
add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50)
|
||||||
if add_response.status_code not in [200, 201, 409]:
|
if add_response.status_code not in [200, 201]:
|
||||||
add_response.raise_for_status()
|
add_response.raise_for_status()
|
||||||
|
|
||||||
# Cognify request
|
# Cognify request
|
||||||
|
|
@ -99,7 +99,7 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
cognify_response = requests.post(url, headers=headers, json=payload, timeout=150)
|
cognify_response = requests.post(url, headers=headers, json=payload, timeout=150)
|
||||||
if cognify_response.status_code not in [200, 201, 409]:
|
if cognify_response.status_code not in [200, 201]:
|
||||||
cognify_response.raise_for_status()
|
cognify_response.raise_for_status()
|
||||||
|
|
||||||
# TODO: Add test to verify cognify pipeline is complete before testing search
|
# TODO: Add test to verify cognify pipeline is complete before testing search
|
||||||
|
|
@ -115,7 +115,7 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
payload = {"searchType": "GRAPH_COMPLETION", "query": "What's in the document?"}
|
payload = {"searchType": "GRAPH_COMPLETION", "query": "What's in the document?"}
|
||||||
|
|
||||||
search_response = requests.post(url, headers=headers, json=payload, timeout=50)
|
search_response = requests.post(url, headers=headers, json=payload, timeout=50)
|
||||||
if search_response.status_code not in [200, 201, 409]:
|
if search_response.status_code not in [200, 201]:
|
||||||
search_response.raise_for_status()
|
search_response.raise_for_status()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
import pytest
|
import pytest
|
||||||
from rdflib import Graph, Namespace, RDF, OWL, RDFS
|
from rdflib import Graph, Namespace, RDF, OWL, RDFS
|
||||||
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver, AttachedOntologyNode
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.models import AttachedOntologyNode
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_default_ontology_resolver
|
||||||
|
|
||||||
|
|
||||||
def test_ontology_adapter_initialization_success():
|
def test_ontology_adapter_initialization_success():
|
||||||
"""Test successful initialization of OntologyAdapter."""
|
"""Test successful initialization of RDFLibOntologyResolver from get_default_ontology_resolver."""
|
||||||
|
|
||||||
adapter = OntologyResolver()
|
adapter = get_default_ontology_resolver()
|
||||||
adapter.build_lookup()
|
adapter.build_lookup()
|
||||||
|
|
||||||
assert isinstance(adapter.lookup, dict)
|
assert isinstance(adapter.lookup, dict)
|
||||||
|
|
@ -14,7 +16,7 @@ def test_ontology_adapter_initialization_success():
|
||||||
|
|
||||||
def test_ontology_adapter_initialization_file_not_found():
|
def test_ontology_adapter_initialization_file_not_found():
|
||||||
"""Test OntologyAdapter initialization with nonexistent file."""
|
"""Test OntologyAdapter initialization with nonexistent file."""
|
||||||
adapter = OntologyResolver(ontology_file="nonexistent.owl")
|
adapter = RDFLibOntologyResolver(ontology_file="nonexistent.owl")
|
||||||
assert adapter.graph is None
|
assert adapter.graph is None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -27,7 +29,7 @@ def test_build_lookup():
|
||||||
|
|
||||||
g.add((ns.Audi, RDF.type, ns.Car))
|
g.add((ns.Audi, RDF.type, ns.Car))
|
||||||
|
|
||||||
resolver = OntologyResolver()
|
resolver = RDFLibOntologyResolver()
|
||||||
resolver.graph = g
|
resolver.graph = g
|
||||||
resolver.build_lookup()
|
resolver.build_lookup()
|
||||||
|
|
||||||
|
|
@ -50,7 +52,7 @@ def test_find_closest_match_exact():
|
||||||
g.add((ns.Car, RDF.type, OWL.Class))
|
g.add((ns.Car, RDF.type, OWL.Class))
|
||||||
g.add((ns.Audi, RDF.type, ns.Car))
|
g.add((ns.Audi, RDF.type, ns.Car))
|
||||||
|
|
||||||
resolver = OntologyResolver()
|
resolver = RDFLibOntologyResolver()
|
||||||
resolver.graph = g
|
resolver.graph = g
|
||||||
resolver.build_lookup()
|
resolver.build_lookup()
|
||||||
|
|
||||||
|
|
@ -71,7 +73,7 @@ def test_find_closest_match_fuzzy():
|
||||||
g.add((ns.Audi, RDF.type, ns.Car))
|
g.add((ns.Audi, RDF.type, ns.Car))
|
||||||
g.add((ns.BMW, RDF.type, ns.Car))
|
g.add((ns.BMW, RDF.type, ns.Car))
|
||||||
|
|
||||||
resolver = OntologyResolver()
|
resolver = RDFLibOntologyResolver()
|
||||||
resolver.graph = g
|
resolver.graph = g
|
||||||
resolver.build_lookup()
|
resolver.build_lookup()
|
||||||
|
|
||||||
|
|
@ -92,7 +94,7 @@ def test_find_closest_match_no_match():
|
||||||
g.add((ns.Audi, RDF.type, ns.Car))
|
g.add((ns.Audi, RDF.type, ns.Car))
|
||||||
g.add((ns.BMW, RDF.type, ns.Car))
|
g.add((ns.BMW, RDF.type, ns.Car))
|
||||||
|
|
||||||
resolver = OntologyResolver()
|
resolver = RDFLibOntologyResolver()
|
||||||
resolver.graph = g
|
resolver.graph = g
|
||||||
resolver.build_lookup()
|
resolver.build_lookup()
|
||||||
|
|
||||||
|
|
@ -102,10 +104,10 @@ def test_find_closest_match_no_match():
|
||||||
|
|
||||||
|
|
||||||
def test_get_subgraph_no_match_rdflib():
|
def test_get_subgraph_no_match_rdflib():
|
||||||
"""Test get_subgraph returns empty results for a non-existent node."""
|
"""Test get_subgraph returns empty results for a non-existent node using RDFLibOntologyResolver."""
|
||||||
g = Graph()
|
g = Graph()
|
||||||
|
|
||||||
resolver = OntologyResolver()
|
resolver = get_default_ontology_resolver()
|
||||||
resolver.graph = g
|
resolver.graph = g
|
||||||
resolver.build_lookup()
|
resolver.build_lookup()
|
||||||
|
|
||||||
|
|
@ -138,7 +140,7 @@ def test_get_subgraph_success_rdflib():
|
||||||
g.add((ns.VW, owns, ns.Audi))
|
g.add((ns.VW, owns, ns.Audi))
|
||||||
g.add((ns.VW, owns, ns.Porsche))
|
g.add((ns.VW, owns, ns.Porsche))
|
||||||
|
|
||||||
resolver = OntologyResolver()
|
resolver = RDFLibOntologyResolver()
|
||||||
resolver.graph = g
|
resolver.graph = g
|
||||||
resolver.build_lookup()
|
resolver.build_lookup()
|
||||||
|
|
||||||
|
|
@ -160,10 +162,10 @@ def test_get_subgraph_success_rdflib():
|
||||||
|
|
||||||
|
|
||||||
def test_refresh_lookup_rdflib():
|
def test_refresh_lookup_rdflib():
|
||||||
"""Test that refresh_lookup rebuilds the lookup dict into a new object."""
|
"""Test that refresh_lookup rebuilds the lookup dict into a new object using RDFLibOntologyResolver."""
|
||||||
g = Graph()
|
g = Graph()
|
||||||
|
|
||||||
resolver = OntologyResolver()
|
resolver = get_default_ontology_resolver()
|
||||||
resolver.graph = g
|
resolver.graph = g
|
||||||
resolver.build_lookup()
|
resolver.build_lookup()
|
||||||
|
|
||||||
|
|
@ -172,3 +174,318 @@ def test_refresh_lookup_rdflib():
|
||||||
resolver.refresh_lookup()
|
resolver.refresh_lookup()
|
||||||
|
|
||||||
assert resolver.lookup is not original_lookup
|
assert resolver.lookup is not original_lookup
|
||||||
|
|
||||||
|
|
||||||
|
def test_fuzzy_matching_strategy_exact_match():
|
||||||
|
"""Test FuzzyMatchingStrategy finds exact matches."""
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
strategy = FuzzyMatchingStrategy()
|
||||||
|
candidates = ["audi", "bmw", "mercedes"]
|
||||||
|
|
||||||
|
result = strategy.find_match("audi", candidates)
|
||||||
|
assert result == "audi"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fuzzy_matching_strategy_fuzzy_match():
|
||||||
|
"""Test FuzzyMatchingStrategy finds fuzzy matches."""
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
strategy = FuzzyMatchingStrategy(cutoff=0.6)
|
||||||
|
candidates = ["audi", "bmw", "mercedes"]
|
||||||
|
|
||||||
|
result = strategy.find_match("audii", candidates)
|
||||||
|
assert result == "audi"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fuzzy_matching_strategy_no_match():
|
||||||
|
"""Test FuzzyMatchingStrategy returns None when no match meets cutoff."""
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
strategy = FuzzyMatchingStrategy(cutoff=0.9)
|
||||||
|
candidates = ["audi", "bmw", "mercedes"]
|
||||||
|
|
||||||
|
result = strategy.find_match("completely_different", candidates)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_fuzzy_matching_strategy_empty_candidates():
|
||||||
|
"""Test FuzzyMatchingStrategy handles empty candidates list."""
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
strategy = FuzzyMatchingStrategy()
|
||||||
|
|
||||||
|
result = strategy.find_match("audi", [])
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_ontology_resolver_initialization():
|
||||||
|
"""Test BaseOntologyResolver initialization with default matching strategy."""
|
||||||
|
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
class TestOntologyResolver(BaseOntologyResolver):
|
||||||
|
def build_lookup(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def refresh_lookup(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def find_closest_match(self, name, category):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_subgraph(self, node_name, node_type="individuals", directed=True):
|
||||||
|
return [], [], None
|
||||||
|
|
||||||
|
resolver = TestOntologyResolver()
|
||||||
|
assert isinstance(resolver.matching_strategy, FuzzyMatchingStrategy)
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_ontology_resolver_custom_matching_strategy():
|
||||||
|
"""Test BaseOntologyResolver initialization with custom matching strategy."""
|
||||||
|
from cognee.modules.ontology.base_ontology_resolver import BaseOntologyResolver
|
||||||
|
from cognee.modules.ontology.matching_strategies import MatchingStrategy
|
||||||
|
|
||||||
|
class CustomMatchingStrategy(MatchingStrategy):
|
||||||
|
def find_match(self, name, candidates):
|
||||||
|
return "custom_match"
|
||||||
|
|
||||||
|
class TestOntologyResolver(BaseOntologyResolver):
|
||||||
|
def build_lookup(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def refresh_lookup(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def find_closest_match(self, name, category):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_subgraph(self, node_name, node_type="individuals", directed=True):
|
||||||
|
return [], [], None
|
||||||
|
|
||||||
|
custom_strategy = CustomMatchingStrategy()
|
||||||
|
resolver = TestOntologyResolver(matching_strategy=custom_strategy)
|
||||||
|
assert resolver.matching_strategy == custom_strategy
|
||||||
|
|
||||||
|
|
||||||
|
def test_ontology_config_structure():
|
||||||
|
"""Test TypedDict structure for ontology configuration."""
|
||||||
|
from cognee.modules.ontology.ontology_config import Config
|
||||||
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
matching_strategy = FuzzyMatchingStrategy()
|
||||||
|
resolver = RDFLibOntologyResolver(matching_strategy=matching_strategy)
|
||||||
|
|
||||||
|
config: Config = {"ontology_config": {"ontology_resolver": resolver}}
|
||||||
|
|
||||||
|
assert config["ontology_config"]["ontology_resolver"] == resolver
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_default():
|
||||||
|
"""Test get_default_ontology_resolver returns a properly configured RDFLibOntologyResolver with FuzzyMatchingStrategy."""
|
||||||
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
resolver = get_default_ontology_resolver()
|
||||||
|
|
||||||
|
assert isinstance(resolver, RDFLibOntologyResolver)
|
||||||
|
assert isinstance(resolver.matching_strategy, FuzzyMatchingStrategy)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_default_ontology_resolver():
|
||||||
|
"""Test get_default_ontology_resolver returns a properly configured RDFLibOntologyResolver with FuzzyMatchingStrategy."""
|
||||||
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
resolver = get_default_ontology_resolver()
|
||||||
|
|
||||||
|
assert isinstance(resolver, RDFLibOntologyResolver)
|
||||||
|
assert isinstance(resolver.matching_strategy, FuzzyMatchingStrategy)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rdflib_ontology_resolver_uses_matching_strategy():
|
||||||
|
"""Test that RDFLibOntologyResolver uses the provided matching strategy."""
|
||||||
|
from cognee.modules.ontology.matching_strategies import MatchingStrategy
|
||||||
|
|
||||||
|
class TestMatchingStrategy(MatchingStrategy):
|
||||||
|
def find_match(self, name, candidates):
|
||||||
|
return "test_match" if candidates else None
|
||||||
|
|
||||||
|
ns = Namespace("http://example.org/test#")
|
||||||
|
g = Graph()
|
||||||
|
g.add((ns.Car, RDF.type, OWL.Class))
|
||||||
|
g.add((ns.Audi, RDF.type, ns.Car))
|
||||||
|
|
||||||
|
resolver = RDFLibOntologyResolver(matching_strategy=TestMatchingStrategy())
|
||||||
|
resolver.graph = g
|
||||||
|
resolver.build_lookup()
|
||||||
|
|
||||||
|
result = resolver.find_closest_match("Audi", "individuals")
|
||||||
|
assert result == "test_match"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rdflib_ontology_resolver_default_matching_strategy():
|
||||||
|
"""Test that RDFLibOntologyResolver uses FuzzyMatchingStrategy by default."""
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
resolver = RDFLibOntologyResolver()
|
||||||
|
assert isinstance(resolver.matching_strategy, FuzzyMatchingStrategy)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_success():
|
||||||
|
"""Test get_ontology_resolver_from_env returns correct resolver with valid parameters."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
resolver = get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="rdflib", matching_strategy="fuzzy", ontology_file_path="/test/path.owl"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(resolver, RDFLibOntologyResolver)
|
||||||
|
assert isinstance(resolver.matching_strategy, FuzzyMatchingStrategy)
|
||||||
|
assert resolver.ontology_file == "/test/path.owl"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_unsupported_resolver():
|
||||||
|
"""Test get_ontology_resolver_from_env raises EnvironmentError for unsupported resolver."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
|
||||||
|
with pytest.raises(EnvironmentError) as exc_info:
|
||||||
|
get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="unsupported",
|
||||||
|
matching_strategy="fuzzy",
|
||||||
|
ontology_file_path="/test/path.owl",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Unsupported ontology resolver: unsupported" in str(exc_info.value)
|
||||||
|
assert "Supported resolvers are: RdfLib with FuzzyMatchingStrategy" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_unsupported_strategy():
|
||||||
|
"""Test get_ontology_resolver_from_env raises EnvironmentError for unsupported strategy."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
|
||||||
|
with pytest.raises(EnvironmentError) as exc_info:
|
||||||
|
get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="rdflib",
|
||||||
|
matching_strategy="unsupported",
|
||||||
|
ontology_file_path="/test/path.owl",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Unsupported ontology resolver: rdflib" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_empty_file_path():
|
||||||
|
"""Test get_ontology_resolver_from_env raises EnvironmentError for empty file path."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
|
||||||
|
with pytest.raises(EnvironmentError) as exc_info:
|
||||||
|
get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="rdflib", matching_strategy="fuzzy", ontology_file_path=""
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Unsupported ontology resolver: rdflib" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_none_file_path():
|
||||||
|
"""Test get_ontology_resolver_from_env raises EnvironmentError for None file path."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
|
||||||
|
with pytest.raises(EnvironmentError) as exc_info:
|
||||||
|
get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="rdflib", matching_strategy="fuzzy", ontology_file_path=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Unsupported ontology resolver: rdflib" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_empty_resolver():
|
||||||
|
"""Test get_ontology_resolver_from_env raises EnvironmentError for empty resolver."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
|
||||||
|
with pytest.raises(EnvironmentError) as exc_info:
|
||||||
|
get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="", matching_strategy="fuzzy", ontology_file_path="/test/path.owl"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Unsupported ontology resolver:" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_empty_strategy():
|
||||||
|
"""Test get_ontology_resolver_from_env raises EnvironmentError for empty strategy."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
|
||||||
|
with pytest.raises(EnvironmentError) as exc_info:
|
||||||
|
get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="rdflib", matching_strategy="", ontology_file_path="/test/path.owl"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Unsupported ontology resolver: rdflib" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_default_parameters():
|
||||||
|
"""Test get_ontology_resolver_from_env with default empty parameters raises EnvironmentError."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
|
||||||
|
with pytest.raises(EnvironmentError) as exc_info:
|
||||||
|
get_ontology_resolver_from_env()
|
||||||
|
|
||||||
|
assert "Unsupported ontology resolver:" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_case_sensitivity():
|
||||||
|
"""Test get_ontology_resolver_from_env is case sensitive."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
|
||||||
|
with pytest.raises(EnvironmentError):
|
||||||
|
get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="RDFLIB",
|
||||||
|
matching_strategy="fuzzy",
|
||||||
|
ontology_file_path="/test/path.owl",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(EnvironmentError):
|
||||||
|
get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="RdfLib",
|
||||||
|
matching_strategy="fuzzy",
|
||||||
|
ontology_file_path="/test/path.owl",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_with_actual_file():
|
||||||
|
"""Test get_ontology_resolver_from_env works with actual file path."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.matching_strategies import FuzzyMatchingStrategy
|
||||||
|
|
||||||
|
resolver = get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="rdflib",
|
||||||
|
matching_strategy="fuzzy",
|
||||||
|
ontology_file_path="/path/to/ontology.owl",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(resolver, RDFLibOntologyResolver)
|
||||||
|
assert isinstance(resolver.matching_strategy, FuzzyMatchingStrategy)
|
||||||
|
assert resolver.ontology_file == "/path/to/ontology.owl"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_ontology_resolver_from_env_resolver_functionality():
|
||||||
|
"""Test that resolver created from env function works correctly."""
|
||||||
|
from cognee.modules.ontology.get_default_ontology_resolver import get_ontology_resolver_from_env
|
||||||
|
|
||||||
|
resolver = get_ontology_resolver_from_env(
|
||||||
|
ontology_resolver="rdflib", matching_strategy="fuzzy", ontology_file_path="/test/path.owl"
|
||||||
|
)
|
||||||
|
|
||||||
|
resolver.build_lookup()
|
||||||
|
assert isinstance(resolver.lookup, dict)
|
||||||
|
|
||||||
|
result = resolver.find_closest_match("test", "individuals")
|
||||||
|
assert result is None # Should return None for non-existent entity
|
||||||
|
|
||||||
|
nodes, relationships, start_node = resolver.get_subgraph("test", "individuals")
|
||||||
|
assert nodes == []
|
||||||
|
assert relationships == []
|
||||||
|
assert start_node is None
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,399 @@
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
from uuid import uuid4
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from pathlib import Path
|
||||||
|
import zipfile
|
||||||
|
from cognee.shared.cache import get_tutorial_data_dir
|
||||||
|
|
||||||
|
from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook
|
||||||
|
from cognee.modules.notebooks.models.Notebook import Notebook
|
||||||
|
import cognee
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level fixtures available to all test classes
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session():
|
||||||
|
"""Mock database session."""
|
||||||
|
session = AsyncMock(spec=AsyncSession)
|
||||||
|
session.add = MagicMock()
|
||||||
|
session.commit = AsyncMock()
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_jupyter_notebook():
|
||||||
|
"""Sample Jupyter notebook content for testing."""
|
||||||
|
return {
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": ["# Tutorial Introduction\n", "\n", "This is a tutorial notebook."],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": None,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": ["import cognee\n", "print('Hello, Cognee!')"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": ["## Step 1: Data Ingestion\n", "\n", "Let's add some data."],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": None,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": ["# Add your data here\n", "# await cognee.add('data.txt')"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "raw",
|
||||||
|
"metadata": {},
|
||||||
|
"source": ["This is a raw cell that should be skipped"],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestTutorialNotebookCreation:
|
||||||
|
"""Test cases for tutorial notebook creation functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notebook_from_ipynb_string_success(self, sample_jupyter_notebook):
|
||||||
|
"""Test successful creation of notebook from JSON string."""
|
||||||
|
notebook_json = json.dumps(sample_jupyter_notebook)
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
notebook = Notebook.from_ipynb_string(
|
||||||
|
notebook_content=notebook_json, owner_id=user_id, name="String Test Notebook"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert notebook.owner_id == user_id
|
||||||
|
assert notebook.name == "String Test Notebook"
|
||||||
|
assert len(notebook.cells) == 4 # Should skip the raw cell
|
||||||
|
assert notebook.cells[0].type == "markdown"
|
||||||
|
assert notebook.cells[1].type == "code"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notebook_cell_name_generation(self, sample_jupyter_notebook):
|
||||||
|
"""Test that cell names are generated correctly from markdown headers."""
|
||||||
|
user_id = uuid4()
|
||||||
|
notebook_json = json.dumps(sample_jupyter_notebook)
|
||||||
|
|
||||||
|
notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)
|
||||||
|
|
||||||
|
# Check markdown header extraction
|
||||||
|
assert notebook.cells[0].name == "Tutorial Introduction"
|
||||||
|
assert notebook.cells[2].name == "Step 1: Data Ingestion"
|
||||||
|
|
||||||
|
# Check code cell naming
|
||||||
|
assert notebook.cells[1].name == "Code Cell"
|
||||||
|
assert notebook.cells[3].name == "Code Cell"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notebook_from_ipynb_string_with_default_name(self, sample_jupyter_notebook):
|
||||||
|
"""Test notebook creation uses kernelspec display_name when no name provided."""
|
||||||
|
user_id = uuid4()
|
||||||
|
notebook_json = json.dumps(sample_jupyter_notebook)
|
||||||
|
|
||||||
|
notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)
|
||||||
|
|
||||||
|
assert notebook.name == "Python 3" # From kernelspec.display_name
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notebook_from_ipynb_string_fallback_name(self):
|
||||||
|
"""Test fallback naming when kernelspec is missing."""
|
||||||
|
minimal_notebook = {
|
||||||
|
"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# Test"]}],
|
||||||
|
"metadata": {}, # No kernelspec
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
notebook_json = json.dumps(minimal_notebook)
|
||||||
|
|
||||||
|
notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)
|
||||||
|
|
||||||
|
assert notebook.name == "Imported Notebook" # Fallback name
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notebook_from_ipynb_string_invalid_json(self):
|
||||||
|
"""Test error handling for invalid JSON."""
|
||||||
|
user_id = uuid4()
|
||||||
|
invalid_json = "{ invalid json content"
|
||||||
|
|
||||||
|
from nbformat.reader import NotJSONError
|
||||||
|
|
||||||
|
with pytest.raises(NotJSONError):
|
||||||
|
Notebook.from_ipynb_string(notebook_content=invalid_json, owner_id=user_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.object(Notebook, "from_ipynb_zip_url")
|
||||||
|
async def test_create_tutorial_notebook_error_propagated(self, mock_from_zip_url, mock_session):
|
||||||
|
"""Test that errors are propagated when zip fetch fails."""
|
||||||
|
user_id = uuid4()
|
||||||
|
mock_from_zip_url.side_effect = Exception("Network error")
|
||||||
|
|
||||||
|
# Should raise the exception (not catch it)
|
||||||
|
with pytest.raises(Exception, match="Network error"):
|
||||||
|
await _create_tutorial_notebook(user_id, mock_session)
|
||||||
|
|
||||||
|
# Verify error handling path was taken
|
||||||
|
mock_from_zip_url.assert_called_once()
|
||||||
|
mock_session.add.assert_not_called()
|
||||||
|
mock_session.commit.assert_not_called()
|
||||||
|
|
||||||
|
def test_generate_cell_name_code_cell(self):
|
||||||
|
"""Test cell name generation for code cells."""
|
||||||
|
from nbformat.notebooknode import NotebookNode
|
||||||
|
|
||||||
|
mock_cell = NotebookNode(
|
||||||
|
{"cell_type": "code", "source": 'import pandas as pd\nprint("Hello world")'}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = Notebook._generate_cell_name(mock_cell)
|
||||||
|
assert result == "Code Cell"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTutorialNotebookZipFunctionality:
|
||||||
|
"""Test cases for zip-based tutorial functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notebook_from_ipynb_zip_url_missing_notebook(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
"""Test error handling when notebook file is missing from zip."""
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
FileNotFoundError,
|
||||||
|
match="Notebook file 'super_random_tutorial_name.ipynb' not found in zip",
|
||||||
|
):
|
||||||
|
await Notebook.from_ipynb_zip_url(
|
||||||
|
zip_url="https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip",
|
||||||
|
owner_id=user_id,
|
||||||
|
notebook_filename="super_random_tutorial_name.ipynb",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notebook_from_ipynb_zip_url_download_failure(self):
|
||||||
|
"""Test error handling when zip download fails."""
|
||||||
|
user_id = uuid4()
|
||||||
|
with pytest.raises(RuntimeError, match="Failed to download tutorial zip"):
|
||||||
|
await Notebook.from_ipynb_zip_url(
|
||||||
|
zip_url="https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/nonexistent_tutorial_name.zip",
|
||||||
|
owner_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_tutorial_notebook_zip_success(self, mock_session):
|
||||||
|
"""Test successful tutorial notebook creation with zip."""
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
# Check that tutorial data directory is empty using storage-aware method
|
||||||
|
tutorial_data_dir_path = await get_tutorial_data_dir()
|
||||||
|
tutorial_data_dir = Path(tutorial_data_dir_path)
|
||||||
|
if tutorial_data_dir.exists():
|
||||||
|
assert not any(tutorial_data_dir.iterdir()), "Tutorial data directory should be empty"
|
||||||
|
|
||||||
|
await _create_tutorial_notebook(user_id, mock_session)
|
||||||
|
|
||||||
|
items = list(tutorial_data_dir.iterdir())
|
||||||
|
assert len(items) == 1, "Tutorial data directory should contain exactly one item"
|
||||||
|
assert items[0].is_dir(), "Tutorial data directory item should be a directory"
|
||||||
|
|
||||||
|
# Verify the structure inside the tutorial directory
|
||||||
|
tutorial_dir = items[0]
|
||||||
|
|
||||||
|
# Check for tutorial.ipynb file
|
||||||
|
notebook_file = tutorial_dir / "tutorial.ipynb"
|
||||||
|
assert notebook_file.exists(), f"tutorial.ipynb should exist in {tutorial_dir}"
|
||||||
|
assert notebook_file.is_file(), "tutorial.ipynb should be a file"
|
||||||
|
|
||||||
|
# Check for data subfolder with contents
|
||||||
|
data_folder = tutorial_dir / "data"
|
||||||
|
assert data_folder.exists(), f"data subfolder should exist in {tutorial_dir}"
|
||||||
|
assert data_folder.is_dir(), "data should be a directory"
|
||||||
|
|
||||||
|
data_items = list(data_folder.iterdir())
|
||||||
|
assert len(data_items) > 0, (
|
||||||
|
f"data folder should contain files, but found {len(data_items)} items"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_tutorial_notebook_with_force_refresh(self, mock_session):
    """Test tutorial notebook creation with force refresh."""

    def _file_fingerprint(path):
        # Snapshot used to detect whether a file was rewritten: modification
        # time, size, content hash, and the first 100 bytes as a fingerprint.
        stat = path.stat()
        with open(path, "rb") as f:
            content = f.read()
        return {
            "mtime": stat.st_mtime,
            "size": stat.st_size,
            "hash": hashlib.md5(content).hexdigest(),
            "first_bytes": content[:100] if content else b"",
        }

    # Start clean so the directory assertions below are meaningful.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    user_id = uuid4()

    # Resolve the tutorial data directory via the storage-aware helper and
    # require it to be empty (when it already exists) before creation runs.
    tutorial_data_dir = Path(await get_tutorial_data_dir())
    if tutorial_data_dir.exists():
        assert not any(tutorial_data_dir.iterdir()), "Tutorial data directory should be empty"

    # First creation (without force refresh).
    await _create_tutorial_notebook(user_id, mock_session, force_refresh=False)

    items_first = list(tutorial_data_dir.iterdir())
    assert len(items_first) == 1, (
        "Tutorial data directory should contain exactly one item after first creation"
    )
    first_dir = items_first[0]
    assert first_dir.is_dir(), "Tutorial data directory item should be a directory"

    # Verify the structure inside the tutorial directory (first creation).
    notebook_file = first_dir / "tutorial.ipynb"
    assert notebook_file.exists(), f"tutorial.ipynb should exist in {first_dir}"
    assert notebook_file.is_file(), "tutorial.ipynb should be a file"

    data_folder = first_dir / "data"
    assert data_folder.exists(), f"data subfolder should exist in {first_dir}"
    assert data_folder.is_dir(), "data should be a directory"

    data_items = list(data_folder.iterdir())
    assert len(data_items) > 0, (
        f"data folder should contain files, but found {len(data_items)} items"
    )

    # Capture a per-file fingerprint of everything the first creation wrote.
    first_creation_metadata = {
        str(file_path.relative_to(first_dir)): _file_fingerprint(file_path)
        for file_path in first_dir.rglob("*")
        if file_path.is_file()
    }

    # Wait a moment so a refreshed file can receive a strictly newer mtime.
    time.sleep(0.1)

    # Force refresh - should create new files with different metadata.
    await _create_tutorial_notebook(user_id, mock_session, force_refresh=True)

    items_second = list(tutorial_data_dir.iterdir())
    assert len(items_second) == 1, (
        "Tutorial data directory should contain exactly one item after force refresh"
    )
    second_dir = items_second[0]

    # The structure must be maintained after the force refresh.
    notebook_file_second = second_dir / "tutorial.ipynb"
    assert notebook_file_second.exists(), (
        f"tutorial.ipynb should exist in {second_dir} after force refresh"
    )
    assert notebook_file_second.is_file(), "tutorial.ipynb should be a file after force refresh"

    data_folder_second = second_dir / "data"
    assert data_folder_second.exists(), (
        f"data subfolder should exist in {second_dir} after force refresh"
    )
    assert data_folder_second.is_dir(), "data should be a directory after force refresh"

    data_items_second = list(data_folder_second.iterdir())
    assert len(data_items_second) > 0, (
        f"data folder should still contain files after force refresh, but found {len(data_items_second)} items"
    )

    # Every file must already exist from the first creation; count how many
    # look rewritten (newer mtime, or different hash/size/leading bytes).
    files_with_changed_metadata = 0
    for file_path in second_dir.rglob("*"):
        if not file_path.is_file():
            continue

        relative_path_str = str(file_path.relative_to(second_dir))
        assert relative_path_str in first_creation_metadata, (
            f"File {relative_path_str} missing from first creation"
        )

        old_metadata = first_creation_metadata[relative_path_str]
        new_metadata = _file_fingerprint(file_path)

        metadata_changed = (
            new_metadata["mtime"] > old_metadata["mtime"]
            or new_metadata["hash"] != old_metadata["hash"]
            or new_metadata["size"] != old_metadata["size"]
            or new_metadata["first_bytes"] != old_metadata["first_bytes"]
        )
        if metadata_changed:
            files_with_changed_metadata += 1

    # Assert that force refresh actually updated files.
    assert files_with_changed_metadata > 0, (
        f"Force refresh should have updated at least some files, but all {len(first_creation_metadata)} "
        f"files appear to have identical metadata. This suggests force refresh didn't work."
    )

    mock_session.commit.assert_called()
||||||
|
@pytest.mark.asyncio
async def test_tutorial_zip_url_accessibility(self):
    """Test that the actual tutorial zip URL is accessible (integration test).

    Skips when ``requests`` is unavailable or the network call fails; content
    checks run outside the try block so a genuine assertion failure is
    reported as a failure rather than being swallowed and turned into a skip.
    """
    # Skip cleanly when the optional ``requests`` dependency is not installed.
    requests = pytest.importorskip("requests")

    try:
        response = requests.get(
            "https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip",
            timeout=10,
        )
        response.raise_for_status()
    except requests.RequestException:
        # Only network/HTTP errors justify skipping. The previous broad
        # ``except Exception`` also caught AssertionError from the content
        # checks below, silently skipping instead of failing the test.
        pytest.skip("Network request failed or zip not available - skipping integration test")

    # Verify it's a valid zip file by checking headers (or the PK signature).
    assert response.headers.get("content-type") in [
        "application/zip",
        "application/octet-stream",
        "application/x-zip-compressed",
    ] or response.content.startswith(b"PK")  # Zip file signature
|
||||||
|
|
@ -5,6 +5,8 @@ import cognee
|
||||||
from cognee.api.v1.search import SearchType
|
from cognee.api.v1.search import SearchType
|
||||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||||
from cognee.shared.logging_utils import setup_logging
|
from cognee.shared.logging_utils import setup_logging
|
||||||
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.ontology_config import Config
|
||||||
|
|
||||||
text_1 = """
|
text_1 = """
|
||||||
1. Audi
|
1. Audi
|
||||||
|
|
@ -60,7 +62,14 @@ async def main():
|
||||||
os.path.dirname(os.path.abspath(__file__)), "ontology_input_example/basic_ontology.owl"
|
os.path.dirname(os.path.abspath(__file__)), "ontology_input_example/basic_ontology.owl"
|
||||||
)
|
)
|
||||||
|
|
||||||
await cognee.cognify(ontology_file_path=ontology_path)
|
# Create full config structure manually
|
||||||
|
config: Config = {
|
||||||
|
"ontology_config": {
|
||||||
|
"ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await cognee.cognify(config=config)
|
||||||
print("Knowledge with ontology created.")
|
print("Knowledge with ontology created.")
|
||||||
|
|
||||||
# Step 4: Query insights
|
# Step 4: Query insights
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ import os
|
||||||
import textwrap
|
import textwrap
|
||||||
from cognee.api.v1.search import SearchType
|
from cognee.api.v1.search import SearchType
|
||||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||||
|
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
|
||||||
|
from cognee.modules.ontology.ontology_config import Config
|
||||||
|
|
||||||
|
|
||||||
async def run_pipeline(ontology_path=None):
|
async def run_pipeline(ontology_path=None):
|
||||||
|
|
@ -17,7 +19,13 @@ async def run_pipeline(ontology_path=None):
|
||||||
|
|
||||||
await cognee.add(scientific_papers_dir)
|
await cognee.add(scientific_papers_dir)
|
||||||
|
|
||||||
pipeline_run = await cognee.cognify(ontology_file_path=ontology_path)
|
config: Config = {
|
||||||
|
"ontology_config": {
|
||||||
|
"ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pipeline_run = await cognee.cognify(config=config)
|
||||||
|
|
||||||
return pipeline_run
|
return pipeline_run
|
||||||
|
|
||||||
|
|
|
||||||
986
notebooks/ontology_demo.ipynb
vendored
986
notebooks/ontology_demo.ipynb
vendored
File diff suppressed because it is too large
Load diff
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "cognee"
|
name = "cognee"
|
||||||
|
|
||||||
version = "0.3.4.dev0"
|
version = "0.3.4"
|
||||||
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
|
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Vasilije Markovic" },
|
{ name = "Vasilije Markovic" },
|
||||||
|
|
@ -46,6 +46,7 @@ dependencies = [
|
||||||
"matplotlib>=3.8.3,<4",
|
"matplotlib>=3.8.3,<4",
|
||||||
"networkx>=3.4.2,<4",
|
"networkx>=3.4.2,<4",
|
||||||
"lancedb>=0.24.0,<1.0.0",
|
"lancedb>=0.24.0,<1.0.0",
|
||||||
|
"nbformat>=5.7.0,<6.0.0",
|
||||||
"alembic>=1.13.3,<2",
|
"alembic>=1.13.3,<2",
|
||||||
"pre-commit>=4.0.1,<5",
|
"pre-commit>=4.0.1,<5",
|
||||||
"scikit-learn>=1.6.1,<2",
|
"scikit-learn>=1.6.1,<2",
|
||||||
|
|
|
||||||
4
uv.lock
generated
4
uv.lock
generated
|
|
@ -811,7 +811,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cognee"
|
name = "cognee"
|
||||||
version = "0.3.4.dev0"
|
version = "0.3.4"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiofiles" },
|
{ name = "aiofiles" },
|
||||||
|
|
@ -832,6 +832,7 @@ dependencies = [
|
||||||
{ name = "limits" },
|
{ name = "limits" },
|
||||||
{ name = "litellm" },
|
{ name = "litellm" },
|
||||||
{ name = "matplotlib" },
|
{ name = "matplotlib" },
|
||||||
|
{ name = "nbformat" },
|
||||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||||
{ name = "nltk" },
|
{ name = "nltk" },
|
||||||
|
|
@ -1010,6 +1011,7 @@ requires-dist = [
|
||||||
{ name = "mkdocstrings", extras = ["python"], marker = "extra == 'dev'", specifier = ">=0.26.2,<0.27" },
|
{ name = "mkdocstrings", extras = ["python"], marker = "extra == 'dev'", specifier = ">=0.26.2,<0.27" },
|
||||||
{ name = "modal", marker = "extra == 'distributed'", specifier = ">=1.0.5,<2.0.0" },
|
{ name = "modal", marker = "extra == 'distributed'", specifier = ">=1.0.5,<2.0.0" },
|
||||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.7.1,<2" },
|
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.7.1,<2" },
|
||||||
|
{ name = "nbformat", specifier = ">=5.7.0,<6.0.0" },
|
||||||
{ name = "neo4j", marker = "extra == 'neo4j'", specifier = ">=5.28.0,<6" },
|
{ name = "neo4j", marker = "extra == 'neo4j'", specifier = ">=5.28.0,<6" },
|
||||||
{ name = "networkx", specifier = ">=3.4.2,<4" },
|
{ name = "networkx", specifier = ">=3.4.2,<4" },
|
||||||
{ name = "nltk", specifier = ">=3.9.1,<4.0.0" },
|
{ name = "nltk", specifier = ">=3.9.1,<4.0.0" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue