feat: implement full delete feature
This commit is contained in:
parent
a8dab3019e
commit
caf4801a6b
43 changed files with 783 additions and 277 deletions
|
|
@ -108,6 +108,13 @@ export default function Dashboard({ accessToken }: DashboardProps) {
|
||||||
setDatasets(datasets);
|
setDatasets(datasets);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
const [searchValue, setSearchValue] = useState<string>("");
|
||||||
|
|
||||||
|
const handleSearchDatasetInputChange = useCallback((event: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
|
const newSearchValue = event.currentTarget.value;
|
||||||
|
setSearchValue(newSearchValue);
|
||||||
|
}, []);
|
||||||
|
|
||||||
const isCloudEnv = isCloudEnvironment();
|
const isCloudEnv = isCloudEnvironment();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|
@ -129,7 +136,7 @@ export default function Dashboard({ accessToken }: DashboardProps) {
|
||||||
<div className="px-5 py-4 lg:w-96 bg-white rounded-xl h-[calc(100%-2.75rem)]">
|
<div className="px-5 py-4 lg:w-96 bg-white rounded-xl h-[calc(100%-2.75rem)]">
|
||||||
<div className="relative mb-2">
|
<div className="relative mb-2">
|
||||||
<label htmlFor="search-input"><SearchIcon className="absolute left-3 top-[10px] cursor-text" /></label>
|
<label htmlFor="search-input"><SearchIcon className="absolute left-3 top-[10px] cursor-text" /></label>
|
||||||
<input id="search-input" className="text-xs leading-3 w-full h-8 flex flex-row items-center gap-2.5 rounded-3xl pl-9 placeholder-gray-300 border-gray-300 border-[1px] focus:outline-indigo-600" placeholder="Search datasets..." />
|
<input onChange={handleSearchDatasetInputChange} id="search-input" className="text-xs leading-3 w-full h-8 flex flex-row items-center gap-2.5 rounded-3xl pl-9 placeholder-gray-300 border-gray-300 border-[1px] focus:outline-indigo-600" placeholder="Search datasets..." />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<AddDataToCognee
|
<AddDataToCognee
|
||||||
|
|
@ -148,6 +155,7 @@ export default function Dashboard({ accessToken }: DashboardProps) {
|
||||||
<div className="mt-7 mb-14">
|
<div className="mt-7 mb-14">
|
||||||
<CogneeInstancesAccordion>
|
<CogneeInstancesAccordion>
|
||||||
<InstanceDatasetsAccordion
|
<InstanceDatasetsAccordion
|
||||||
|
searchValue={searchValue}
|
||||||
onDatasetsChange={handleDatasetsChange}
|
onDatasetsChange={handleDatasetsChange}
|
||||||
/>
|
/>
|
||||||
</CogneeInstancesAccordion>
|
</CogneeInstancesAccordion>
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,13 @@ interface DatasetsChangePayload {
|
||||||
export interface DatasetsAccordionProps extends Omit<AccordionProps, "isOpen" | "openAccordion" | "closeAccordion" | "children"> {
|
export interface DatasetsAccordionProps extends Omit<AccordionProps, "isOpen" | "openAccordion" | "closeAccordion" | "children"> {
|
||||||
onDatasetsChange?: (payload: DatasetsChangePayload) => void;
|
onDatasetsChange?: (payload: DatasetsChangePayload) => void;
|
||||||
useCloud?: boolean;
|
useCloud?: boolean;
|
||||||
|
searchValue: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function DatasetsAccordion({
|
export default function DatasetsAccordion({
|
||||||
title,
|
title,
|
||||||
tools,
|
tools,
|
||||||
|
searchValue,
|
||||||
switchCaretPosition = false,
|
switchCaretPosition = false,
|
||||||
className,
|
className,
|
||||||
contentClassName,
|
contentClassName,
|
||||||
|
|
@ -43,13 +45,7 @@ export default function DatasetsAccordion({
|
||||||
removeDataset,
|
removeDataset,
|
||||||
getDatasetData,
|
getDatasetData,
|
||||||
removeDatasetData,
|
removeDatasetData,
|
||||||
} = useDatasets(useCloud);
|
} = useDatasets(useCloud, searchValue);
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (datasets.length === 0) {
|
|
||||||
refreshDatasets();
|
|
||||||
}
|
|
||||||
}, [datasets.length, refreshDatasets]);
|
|
||||||
|
|
||||||
const [openDatasets, openDataset] = useState<Set<string>>(new Set());
|
const [openDatasets, openDataset] = useState<Set<string>>(new Set());
|
||||||
|
|
||||||
|
|
@ -237,11 +233,16 @@ export default function DatasetsAccordion({
|
||||||
contentClassName={contentClassName}
|
contentClassName={contentClassName}
|
||||||
>
|
>
|
||||||
<div className="flex flex-col">
|
<div className="flex flex-col">
|
||||||
{datasets.length === 0 && (
|
{datasets.length === 0 && !searchValue && (
|
||||||
<div className="flex flex-row items-baseline-last text-sm text-gray-400 mt-2 px-2">
|
<div className="flex flex-row items-baseline-last text-sm text-gray-400 mt-2 px-2">
|
||||||
<span>No datasets here, add one by clicking +</span>
|
<span>No datasets here, add one by clicking +</span>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
{datasets.length === 0 && searchValue && (
|
||||||
|
<div className="flex flex-row items-baseline-last text-sm text-gray-400 mt-2 px-2">
|
||||||
|
<span>No datasets found, please adjust your search term</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
{datasets.map((dataset) => {
|
{datasets.map((dataset) => {
|
||||||
return (
|
return (
|
||||||
<Accordion
|
<Accordion
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,18 @@
|
||||||
import classNames from "classnames";
|
import classNames from "classnames";
|
||||||
import { useCallback, useEffect } from "react";
|
import { useCallback, useEffect } from "react";
|
||||||
|
|
||||||
import { fetch, isCloudEnvironment, useBoolean } from "@/utils";
|
import { fetch, isCloudApiKeySet, isCloudEnvironment, useBoolean } from "@/utils";
|
||||||
import { checkCloudConnection } from "@/modules/cloud";
|
import { checkCloudConnection } from "@/modules/cloud";
|
||||||
import { CaretIcon, CloseIcon, CloudIcon, LocalCogneeIcon } from "@/ui/Icons";
|
import { CaretIcon, CloseIcon, CloudIcon, LocalCogneeIcon } from "@/ui/Icons";
|
||||||
import { CTAButton, GhostButton, IconButton, Input, Modal } from "@/ui/elements";
|
import { CTAButton, GhostButton, IconButton, Input, Modal } from "@/ui/elements";
|
||||||
|
|
||||||
import DatasetsAccordion, { DatasetsAccordionProps } from "./DatasetsAccordion";
|
import DatasetsAccordion, { DatasetsAccordionProps } from "./DatasetsAccordion";
|
||||||
|
|
||||||
type InstanceDatasetsAccordionProps = Omit<DatasetsAccordionProps, "title">;
|
interface InstanceDatasetsAccordionProps extends Omit<DatasetsAccordionProps, "title"> {
|
||||||
|
searchValue: string;
|
||||||
|
}
|
||||||
|
|
||||||
export default function InstanceDatasetsAccordion({ onDatasetsChange }: InstanceDatasetsAccordionProps) {
|
export default function InstanceDatasetsAccordion({ searchValue, onDatasetsChange }: InstanceDatasetsAccordionProps) {
|
||||||
const {
|
const {
|
||||||
value: isLocalCogneeConnected,
|
value: isLocalCogneeConnected,
|
||||||
setTrue: setLocalCogneeConnected,
|
setTrue: setLocalCogneeConnected,
|
||||||
|
|
@ -19,7 +21,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
|
||||||
const {
|
const {
|
||||||
value: isCloudCogneeConnected,
|
value: isCloudCogneeConnected,
|
||||||
setTrue: setCloudCogneeConnected,
|
setTrue: setCloudCogneeConnected,
|
||||||
} = useBoolean(isCloudEnvironment());
|
} = useBoolean(isCloudEnvironment() || isCloudApiKeySet());
|
||||||
|
|
||||||
const checkConnectionToCloudCognee = useCallback((apiKey?: string) => {
|
const checkConnectionToCloudCognee = useCallback((apiKey?: string) => {
|
||||||
if (apiKey) {
|
if (apiKey) {
|
||||||
|
|
@ -71,6 +73,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
searchValue={searchValue}
|
||||||
tools={isLocalCogneeConnected ? <span className="text-xs text-indigo-600">Connected</span> : <span className="text-xs text-gray-400">Not connected</span>}
|
tools={isLocalCogneeConnected ? <span className="text-xs text-indigo-600">Connected</span> : <span className="text-xs text-gray-400">Not connected</span>}
|
||||||
switchCaretPosition={true}
|
switchCaretPosition={true}
|
||||||
className="pt-3 pb-1.5"
|
className="pt-3 pb-1.5"
|
||||||
|
|
@ -88,6 +91,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
searchValue={searchValue}
|
||||||
tools={<span className="text-xs text-indigo-600">Connected</span>}
|
tools={<span className="text-xs text-indigo-600">Connected</span>}
|
||||||
switchCaretPosition={true}
|
switchCaretPosition={true}
|
||||||
className="pt-3 pb-1.5"
|
className="pt-3 pb-1.5"
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import { useCallback, useState } from 'react';
|
import { useCallback, useEffect, useLayoutEffect, useRef, useState } from 'react';
|
||||||
|
|
||||||
import { fetch } from '@/utils';
|
import { fetch } from '@/utils';
|
||||||
import { DataFile } from './useData';
|
import { DataFile } from './useData';
|
||||||
|
|
@ -11,7 +11,20 @@ export interface Dataset {
|
||||||
status: string;
|
status: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
function useDatasets(useCloud = false) {
|
function filterDatasets(datasets: Dataset[], searchValue: string) {
|
||||||
|
if (searchValue.trim() === "") {
|
||||||
|
return datasets;
|
||||||
|
}
|
||||||
|
|
||||||
|
const lowercaseSearchValue = searchValue.toLowerCase();
|
||||||
|
|
||||||
|
return datasets.filter((dataset) =>
|
||||||
|
dataset.name.toLowerCase().includes(lowercaseSearchValue)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function useDatasets(useCloud = false, searchValue: string = "") {
|
||||||
|
const allDatasets = useRef<Dataset[]>([]);
|
||||||
const [datasets, setDatasets] = useState<Dataset[]>([]);
|
const [datasets, setDatasets] = useState<Dataset[]>([]);
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
// const statusTimeout = useRef<any>(null);
|
// const statusTimeout = useRef<any>(null);
|
||||||
|
|
@ -57,26 +70,32 @@ function useDatasets(useCloud = false) {
|
||||||
// };
|
// };
|
||||||
// }, []);
|
// }, []);
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
setDatasets(filterDatasets(allDatasets.current, searchValue));
|
||||||
|
}, [searchValue]);
|
||||||
|
|
||||||
const addDataset = useCallback((datasetName: string) => {
|
const addDataset = useCallback((datasetName: string) => {
|
||||||
return createDataset({ name: datasetName }, useCloud)
|
return createDataset({ name: datasetName }, useCloud)
|
||||||
.then((dataset) => {
|
.then((dataset) => {
|
||||||
setDatasets((datasets) => [
|
const newDatasets = [
|
||||||
...datasets,
|
...allDatasets.current,
|
||||||
dataset,
|
dataset,
|
||||||
]);
|
];
|
||||||
|
allDatasets.current = newDatasets;
|
||||||
|
setDatasets(filterDatasets(newDatasets, searchValue));
|
||||||
});
|
});
|
||||||
}, [useCloud]);
|
}, [searchValue, useCloud]);
|
||||||
|
|
||||||
const removeDataset = useCallback((datasetId: string) => {
|
const removeDataset = useCallback((datasetId: string) => {
|
||||||
return fetch(`/v1/datasets/${datasetId}`, {
|
return fetch(`/v1/datasets/${datasetId}`, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
}, useCloud)
|
}, useCloud)
|
||||||
.then(() => {
|
.then(() => {
|
||||||
setDatasets((datasets) =>
|
const newDatasets = allDatasets.current.filter((dataset) => dataset.id !== datasetId)
|
||||||
datasets.filter((dataset) => dataset.id !== datasetId)
|
allDatasets.current = newDatasets;
|
||||||
);
|
setDatasets(filterDatasets(newDatasets, searchValue));
|
||||||
});
|
});
|
||||||
}, [useCloud]);
|
}, [searchValue, useCloud]);
|
||||||
|
|
||||||
const fetchDatasets = useCallback(() => {
|
const fetchDatasets = useCallback(() => {
|
||||||
return fetch('/v1/datasets', {
|
return fetch('/v1/datasets', {
|
||||||
|
|
@ -86,7 +105,8 @@ function useDatasets(useCloud = false) {
|
||||||
}, useCloud)
|
}, useCloud)
|
||||||
.then((response) => response.json())
|
.then((response) => response.json())
|
||||||
.then((datasets) => {
|
.then((datasets) => {
|
||||||
setDatasets(datasets);
|
allDatasets.current = datasets;
|
||||||
|
setDatasets(filterDatasets(datasets, searchValue));
|
||||||
|
|
||||||
// if (datasets.length > 0) {
|
// if (datasets.length > 0) {
|
||||||
// checkDatasetStatuses(datasets);
|
// checkDatasetStatuses(datasets);
|
||||||
|
|
@ -97,28 +117,38 @@ function useDatasets(useCloud = false) {
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
console.error('Error fetching datasets:', error);
|
console.error('Error fetching datasets:', error);
|
||||||
});
|
});
|
||||||
}, [useCloud]);
|
}, [searchValue, useCloud]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (allDatasets.current.length === 0) {
|
||||||
|
fetchDatasets();
|
||||||
|
}
|
||||||
|
}, [fetchDatasets]);
|
||||||
|
|
||||||
const getDatasetData = useCallback((datasetId: string) => {
|
const getDatasetData = useCallback((datasetId: string) => {
|
||||||
return fetch(`/v1/datasets/${datasetId}/data`, {}, useCloud)
|
return fetch(`/v1/datasets/${datasetId}/data`, {}, useCloud)
|
||||||
.then((response) => response.json())
|
.then((response) => response.json())
|
||||||
.then((data) => {
|
.then((data) => {
|
||||||
const datasetIndex = datasets.findIndex((dataset) => dataset.id === datasetId);
|
const datasetIndex = allDatasets.current.findIndex((dataset) => dataset.id === datasetId);
|
||||||
|
|
||||||
if (datasetIndex >= 0) {
|
if (datasetIndex >= 0) {
|
||||||
setDatasets((datasets) => [
|
const newDatasets = [
|
||||||
...datasets.slice(0, datasetIndex),
|
...allDatasets.current.slice(0, datasetIndex),
|
||||||
{
|
{
|
||||||
...datasets[datasetIndex],
|
...allDatasets.current[datasetIndex],
|
||||||
data,
|
data,
|
||||||
},
|
},
|
||||||
...datasets.slice(datasetIndex + 1),
|
...allDatasets.current.slice(datasetIndex + 1),
|
||||||
]);
|
];
|
||||||
|
|
||||||
|
allDatasets.current = newDatasets;
|
||||||
|
|
||||||
|
setDatasets(filterDatasets(newDatasets, searchValue));
|
||||||
}
|
}
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
});
|
});
|
||||||
}, [datasets, useCloud]);
|
}, [searchValue, useCloud]);
|
||||||
|
|
||||||
const removeDatasetData = useCallback((datasetId: string, dataId: string) => {
|
const removeDatasetData = useCallback((datasetId: string, dataId: string) => {
|
||||||
return fetch(`/v1/datasets/${datasetId}/data/${dataId}`, {
|
return fetch(`/v1/datasets/${datasetId}/data/${dataId}`, {
|
||||||
|
|
|
||||||
|
|
@ -7,15 +7,19 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Iterable, List, Mapping
|
from typing import Any, Dict, Iterable, List, Mapping
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee import config, prune, search, SearchType, visualize_graph
|
from cognee import config, prune, search, SearchType, visualize_graph
|
||||||
from cognee.low_level import setup, DataPoint
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.modules.data.models import Dataset
|
||||||
|
from cognee.modules.users.models import User
|
||||||
from cognee.pipelines import run_tasks, Task
|
from cognee.pipelines import run_tasks, Task
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.data.methods import load_or_create_datasets
|
|
||||||
|
|
||||||
|
|
||||||
class Person(DataPoint):
|
class Person(DataPoint):
|
||||||
|
|
@ -76,18 +80,6 @@ def remove_duplicates_preserve_order(seq: Iterable[Any]) -> list[Any]:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def collect_people(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
|
|
||||||
"""Collect people from payloads."""
|
|
||||||
people = [person for payload in payloads for person in payload.get("people", [])]
|
|
||||||
return people
|
|
||||||
|
|
||||||
|
|
||||||
def collect_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
|
|
||||||
"""Collect companies from payloads."""
|
|
||||||
companies = [company for payload in payloads for company in payload.get("companies", [])]
|
|
||||||
return companies
|
|
||||||
|
|
||||||
|
|
||||||
def build_people_nodes(people: Iterable[Mapping[str, Any]]) -> dict:
|
def build_people_nodes(people: Iterable[Mapping[str, Any]]) -> dict:
|
||||||
"""Build person nodes keyed by name."""
|
"""Build person nodes keyed by name."""
|
||||||
nodes = {p["name"]: Person(name=p["name"]) for p in people if p.get("name")}
|
nodes = {p["name"]: Person(name=p["name"]) for p in people if p.get("name")}
|
||||||
|
|
@ -176,10 +168,10 @@ def attach_employees_to_departments(
|
||||||
target.employees = employees
|
target.employees = employees
|
||||||
|
|
||||||
|
|
||||||
def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]:
|
def build_companies(data: Data) -> list[Company]:
|
||||||
"""Build company nodes from payloads."""
|
"""Build company nodes from payloads."""
|
||||||
people = collect_people(payloads)
|
people = data.people
|
||||||
companies = collect_companies(payloads)
|
companies = data.companies
|
||||||
people_nodes = build_people_nodes(people)
|
people_nodes = build_people_nodes(people)
|
||||||
groups = group_people_by_department(people)
|
groups = group_people_by_department(people)
|
||||||
dept_names = collect_declared_departments(groups, companies)
|
dept_names = collect_declared_departments(groups, companies)
|
||||||
|
|
@ -192,19 +184,29 @@ def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def load_default_payload() -> list[Mapping[str, Any]]:
|
class Data(BaseModel):
|
||||||
|
id: UUID
|
||||||
|
companies: List[Dict[str, Any]]
|
||||||
|
people: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
def load_default_payload() -> Data:
|
||||||
"""Load the default payload from data files."""
|
"""Load the default payload from data files."""
|
||||||
companies = load_json_file(COMPANIES_JSON)
|
companies = load_json_file(COMPANIES_JSON)
|
||||||
people = load_json_file(PEOPLE_JSON)
|
people = load_json_file(PEOPLE_JSON)
|
||||||
payload = [{"companies": companies, "people": people}]
|
|
||||||
return payload
|
data = Data(
|
||||||
|
id=uuid4(),
|
||||||
|
companies=companies,
|
||||||
|
people=people,
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def ingest_payloads(data: List[Any] | None) -> list[Company]:
|
def ingest_payloads(data: List[Data]) -> list[Company]:
|
||||||
"""Ingest payloads and build company nodes."""
|
"""Ingest payloads and build company nodes."""
|
||||||
if not data or data == [None]:
|
companies = build_companies(data[0])
|
||||||
data = load_default_payload()
|
|
||||||
companies = build_companies(data)
|
|
||||||
return companies
|
return companies
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -221,18 +223,16 @@ async def execute_pipeline() -> None:
|
||||||
await setup()
|
await setup()
|
||||||
|
|
||||||
# Get user and dataset
|
# Get user and dataset
|
||||||
user = await get_default_user()
|
user: User = await get_default_user() # type: ignore
|
||||||
datasets = await load_or_create_datasets(["demo_dataset"], [], user)
|
dataset = Dataset(id=uuid4(), name="demo_dataset")
|
||||||
dataset_id = datasets[0].id
|
data = load_default_payload()
|
||||||
|
|
||||||
# Build and run pipeline
|
# Build and run pipeline
|
||||||
tasks = [Task(ingest_payloads), Task(add_data_points)]
|
tasks = [Task(ingest_payloads), Task(add_data_points)]
|
||||||
pipeline = run_tasks(tasks, dataset_id, None, user, "demo_pipeline")
|
pipeline = run_tasks(tasks, dataset, [data], user, "demo_pipeline")
|
||||||
async for status in pipeline:
|
async for status in pipeline:
|
||||||
logging.info("Pipeline status: %s", status)
|
logging.info("Pipeline status: %s", status)
|
||||||
|
|
||||||
# Post-process: index graph edges and visualize
|
|
||||||
await index_graph_edges()
|
|
||||||
await visualize_graph(str(GRAPH_HTML))
|
await visualize_graph(str(GRAPH_HTML))
|
||||||
|
|
||||||
# Run query against graph
|
# Run query against graph
|
||||||
|
|
|
||||||
|
|
@ -85,13 +85,13 @@ async def run_code_graph_pipeline(
|
||||||
|
|
||||||
if include_docs:
|
if include_docs:
|
||||||
non_code_pipeline_run = run_tasks(
|
non_code_pipeline_run = run_tasks(
|
||||||
non_code_tasks, dataset.id, repo_path, user, "cognify_pipeline"
|
non_code_tasks, dataset, repo_path, user, "cognify_pipeline"
|
||||||
)
|
)
|
||||||
async for run_status in non_code_pipeline_run:
|
async for run_status in non_code_pipeline_run:
|
||||||
yield run_status
|
yield run_status
|
||||||
|
|
||||||
async for run_status in run_tasks(
|
async for run_status in run_tasks(
|
||||||
tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False
|
tasks, dataset, repo_path, user, "cognify_code_pipeline", incremental_loading=False
|
||||||
):
|
):
|
||||||
yield run_status
|
yield run_status
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from cognee.modules.ontology.get_default_ontology_resolver import (
|
||||||
get_ontology_resolver_from_env,
|
get_ontology_resolver_from_env,
|
||||||
)
|
)
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
|
||||||
from cognee.tasks.documents import (
|
from cognee.tasks.documents import (
|
||||||
check_permissions_on_dataset,
|
check_permissions_on_dataset,
|
||||||
|
|
@ -208,6 +209,9 @@ async def cognify(
|
||||||
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
|
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
user = await get_default_user()
|
||||||
|
|
||||||
if temporal_cognify:
|
if temporal_cognify:
|
||||||
tasks = await get_temporal_tasks(user, chunker, chunk_size)
|
tasks = await get_temporal_tasks(user, chunker, chunk_size)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,8 @@
|
||||||
import inspect
|
from uuid import UUID
|
||||||
from functools import wraps
|
|
||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Optional, Dict, Any, List, Tuple, Type, Union
|
from typing import Optional, Dict, Any, List, Tuple, Type, Union
|
||||||
from uuid import NAMESPACE_OID, UUID, uuid5
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
|
|
||||||
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
|
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -19,121 +14,6 @@ EdgeData = Tuple[
|
||||||
Node = Tuple[str, NodeData] # (node_id, properties)
|
Node = Tuple[str, NodeData] # (node_id, properties)
|
||||||
|
|
||||||
|
|
||||||
def record_graph_changes(func):
|
|
||||||
"""
|
|
||||||
Decorator to record graph changes in the relationship database.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- func: The asynchronous function to wrap, which likely modifies graph data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns the wrapped function that manages database relationships.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@wraps(func)
|
|
||||||
async def wrapper(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Wraps the given asynchronous function to handle database relationships.
|
|
||||||
|
|
||||||
Tracks the caller's function and class name for context. When the wrapped function is
|
|
||||||
called, it manages database relationships for nodes or edges by adding entries to a
|
|
||||||
ledger and committing the changes to the database session. Errors during relationship
|
|
||||||
addition or session commit are logged and will not disrupt the execution of the wrapped
|
|
||||||
function.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- *args: Positional arguments passed to the wrapped function.
|
|
||||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns the result of the wrapped function call.
|
|
||||||
"""
|
|
||||||
db_engine = get_relational_engine()
|
|
||||||
frame = inspect.currentframe()
|
|
||||||
while frame:
|
|
||||||
if frame.f_back and frame.f_back.f_code.co_name != "wrapper":
|
|
||||||
caller_frame = frame.f_back
|
|
||||||
break
|
|
||||||
frame = frame.f_back
|
|
||||||
|
|
||||||
caller_name = caller_frame.f_code.co_name
|
|
||||||
caller_class = (
|
|
||||||
caller_frame.f_locals.get("self", None).__class__.__name__
|
|
||||||
if caller_frame.f_locals.get("self", None)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
creator = f"{caller_class}.{caller_name}" if caller_class else caller_name
|
|
||||||
|
|
||||||
result = await func(self, *args, **kwargs)
|
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
|
||||||
if func.__name__ == "add_nodes":
|
|
||||||
nodes: List[DataPoint] = args[0]
|
|
||||||
|
|
||||||
relationship_ledgers = []
|
|
||||||
|
|
||||||
for node in nodes:
|
|
||||||
node_id = UUID(str(node.id))
|
|
||||||
relationship_ledgers.append(
|
|
||||||
GraphRelationshipLedger(
|
|
||||||
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
|
||||||
source_node_id=node_id,
|
|
||||||
destination_node_id=node_id,
|
|
||||||
creator_function=f"{creator}.node",
|
|
||||||
node_label=getattr(node, "name", None) or str(node.id),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
session.add_all(relationship_ledgers)
|
|
||||||
await session.flush()
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Error adding relationship: {e}")
|
|
||||||
await session.rollback()
|
|
||||||
|
|
||||||
elif func.__name__ == "add_edges":
|
|
||||||
edges = args[0]
|
|
||||||
|
|
||||||
relationship_ledgers = []
|
|
||||||
|
|
||||||
for edge in edges:
|
|
||||||
source_id = UUID(str(edge[0]))
|
|
||||||
target_id = UUID(str(edge[1]))
|
|
||||||
rel_type = str(edge[2])
|
|
||||||
relationship_ledgers.append(
|
|
||||||
GraphRelationshipLedger(
|
|
||||||
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
|
||||||
source_node_id=source_id,
|
|
||||||
destination_node_id=target_id,
|
|
||||||
creator_function=f"{creator}.{rel_type}",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
session.add_all(relationship_ledgers)
|
|
||||||
await session.flush()
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Error adding relationship: {e}")
|
|
||||||
await session.rollback()
|
|
||||||
|
|
||||||
try:
|
|
||||||
await session.commit()
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Error committing session: {e}")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class GraphDBInterface(ABC):
|
class GraphDBInterface(ABC):
|
||||||
"""
|
"""
|
||||||
Define an interface for graph database operations to be implemented by concrete classes.
|
Define an interface for graph database operations to be implemented by concrete classes.
|
||||||
|
|
@ -189,7 +69,6 @@ class GraphDBInterface(ABC):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@record_graph_changes
|
|
||||||
async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
|
async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple nodes to the graph in a single operation.
|
Add multiple nodes to the graph in a single operation.
|
||||||
|
|
@ -273,7 +152,6 @@ class GraphDBInterface(ABC):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@record_graph_changes
|
|
||||||
async def add_edges(
|
async def add_edges(
|
||||||
self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
|
self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ from cognee.infrastructure.utils.run_sync import run_sync
|
||||||
from cognee.infrastructure.files.storage import get_file_storage
|
from cognee.infrastructure.files.storage import get_file_storage
|
||||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||||
GraphDBInterface,
|
GraphDBInterface,
|
||||||
record_graph_changes,
|
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
|
|
@ -378,7 +377,6 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
logger.error(f"Failed to add node: {e}")
|
logger.error(f"Failed to add node: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@record_graph_changes
|
|
||||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple nodes to the graph in a batch operation.
|
Add multiple nodes to the graph in a batch operation.
|
||||||
|
|
@ -675,7 +673,6 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
logger.error(f"Failed to add edge: {e}")
|
logger.error(f"Failed to add edge: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@record_graph_changes
|
|
||||||
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple edges in a batch operation.
|
Add multiple edges in a batch operation.
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ from cognee.tasks.temporal_graph.models import Timestamp
|
||||||
from cognee.shared.logging_utils import get_logger, ERROR
|
from cognee.shared.logging_utils import get_logger, ERROR
|
||||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||||
GraphDBInterface,
|
GraphDBInterface,
|
||||||
record_graph_changes,
|
|
||||||
)
|
)
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
|
|
||||||
|
|
@ -175,7 +174,6 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return await self.query(query, params)
|
return await self.query(query, params)
|
||||||
|
|
||||||
@record_graph_changes
|
|
||||||
@override_distributed(queued_add_nodes)
|
@override_distributed(queued_add_nodes)
|
||||||
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -446,7 +444,6 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return flattened
|
return flattened
|
||||||
|
|
||||||
@record_graph_changes
|
|
||||||
@override_distributed(queued_add_edges)
|
@override_distributed(queued_add_edges)
|
||||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from uuid import UUID
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||||
GraphDBInterface,
|
GraphDBInterface,
|
||||||
record_graph_changes,
|
|
||||||
NodeData,
|
NodeData,
|
||||||
EdgeData,
|
EdgeData,
|
||||||
Node,
|
Node,
|
||||||
|
|
@ -229,7 +228,6 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
logger.error(f"Failed to add node {node.id}: {error_msg}")
|
logger.error(f"Failed to add node {node.id}: {error_msg}")
|
||||||
raise Exception(f"Failed to add node: {error_msg}") from e
|
raise Exception(f"Failed to add node: {error_msg}") from e
|
||||||
|
|
||||||
@record_graph_changes
|
|
||||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple nodes to the graph in a single operation.
|
Add multiple nodes to the graph in a single operation.
|
||||||
|
|
@ -534,7 +532,6 @@ class NeptuneGraphDB(GraphDBInterface):
|
||||||
logger.error(f"Failed to add edge {source_id} -> {target_id}: {error_msg}")
|
logger.error(f"Failed to add edge {source_id} -> {target_id}: {error_msg}")
|
||||||
raise Exception(f"Failed to add edge: {error_msg}") from e
|
raise Exception(f"Failed to add edge: {error_msg}") from e
|
||||||
|
|
||||||
@record_graph_changes
|
|
||||||
async def add_edges(self, edges: List[Tuple[str, str, str, Optional[Dict[str, Any]]]]) -> None:
|
async def add_edges(self, edges: List[Tuple[str, str, str, Optional[Dict[str, Any]]]]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple edges to the graph in a single operation.
|
Add multiple edges to the graph in a single operation.
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from os import path
|
from os import path
|
||||||
|
from uuid import UUID
|
||||||
import lancedb
|
import lancedb
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
@ -282,7 +283,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
|
||||||
collection = await self.get_collection(collection_name)
|
collection = await self.get_collection(collection_name)
|
||||||
|
|
||||||
# Delete one at a time to avoid commit conflicts
|
# Delete one at a time to avoid commit conflicts
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Optional, get_type_hints
|
from typing import List, Optional, get_type_hints
|
||||||
|
from uuid import UUID
|
||||||
from sqlalchemy.inspection import inspect
|
from sqlalchemy.inspection import inspect
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
|
|
@ -384,7 +385,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
# Get PGVectorDataPoint Table from database
|
# Get PGVectorDataPoint Table from database
|
||||||
PGVectorDataPoint = await self.get_table(collection_name)
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import List, Protocol, Optional, Union, Any
|
from typing import List, Protocol, Optional, Any
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from uuid import UUID
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from .models.PayloadSchema import PayloadSchema
|
|
||||||
|
|
||||||
|
|
||||||
class VectorDBInterface(Protocol):
|
class VectorDBInterface(Protocol):
|
||||||
|
|
@ -127,9 +127,7 @@ class VectorDBInterface(Protocol):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def delete_data_points(
|
async def delete_data_points(self, collection_name: str, data_point_ids: List[UUID]):
|
||||||
self, collection_name: str, data_point_ids: Union[List[str], list[str]]
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Delete specified data points from a collection.
|
Delete specified data points from a collection.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import pickle
|
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
from pydantic import BaseModel, Field, ConfigDict
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
@ -28,8 +27,6 @@ class DataPoint(BaseModel):
|
||||||
- update_version
|
- update_version
|
||||||
- to_json
|
- to_json
|
||||||
- from_json
|
- from_json
|
||||||
- to_pickle
|
|
||||||
- from_pickle
|
|
||||||
- to_dict
|
- to_dict
|
||||||
- from_dict
|
- from_dict
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,16 @@
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from cognee.modules.data.models import Dataset
|
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
|
||||||
|
from cognee.modules.data.models import Dataset
|
||||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||||
|
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
|
||||||
async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -> Dataset:
|
async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -> Dataset:
|
||||||
owner_id = user.id
|
owner_id = user.id
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from .generate_node_id import generate_node_id
|
from .generate_node_id import generate_node_id
|
||||||
|
from .generate_edge_id import generate_edge_id
|
||||||
from .generate_node_name import generate_node_name
|
from .generate_node_name import generate_node_name
|
||||||
from .generate_edge_name import generate_edge_name
|
from .generate_edge_name import generate_edge_name
|
||||||
from .generate_event_datapoint import generate_event_datapoint
|
from .generate_event_datapoint import generate_event_datapoint
|
||||||
|
|
|
||||||
|
|
@ -1 +1,8 @@
|
||||||
from .get_formatted_graph_data import get_formatted_graph_data
|
from .get_formatted_graph_data import get_formatted_graph_data
|
||||||
|
from .upsert_edges import upsert_edges
|
||||||
|
from .upsert_nodes import upsert_nodes
|
||||||
|
from .get_data_related_nodes import get_data_related_nodes
|
||||||
|
from .delete_data_related_nodes import delete_data_related_nodes
|
||||||
|
from .delete_data_related_edges import delete_data_related_edges
|
||||||
|
from .get_data_related_edges import get_data_related_edges
|
||||||
|
from .delete_data_nodes_and_edges import delete_data_nodes_and_edges
|
||||||
|
|
|
||||||
43
cognee/modules/graph/methods/delete_data_nodes_and_edges.py
Normal file
43
cognee/modules/graph/methods/delete_data_nodes_and_edges.py
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
||||||
|
from cognee.modules.engine.utils import generate_edge_id
|
||||||
|
from cognee.modules.graph.methods import (
|
||||||
|
delete_data_related_edges,
|
||||||
|
delete_data_related_nodes,
|
||||||
|
get_data_related_nodes,
|
||||||
|
get_data_related_edges,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID) -> None:
|
||||||
|
affected_nodes = await get_data_related_nodes(dataset_id, data_id)
|
||||||
|
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
await graph_engine.delete_nodes([str(node.slug) for node in affected_nodes])
|
||||||
|
|
||||||
|
affected_vector_collections: Dict[str, List] = {}
|
||||||
|
for node in affected_nodes:
|
||||||
|
for indexed_field in node.indexed_fields:
|
||||||
|
collection_name = f"{node.type}_{indexed_field}"
|
||||||
|
if collection_name not in affected_vector_collections:
|
||||||
|
affected_vector_collections[collection_name] = []
|
||||||
|
affected_vector_collections[collection_name].append(node)
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
for affected_collection, affected_nodes in affected_vector_collections.items():
|
||||||
|
await vector_engine.delete_data_points(
|
||||||
|
affected_collection, [node.id for node in affected_nodes]
|
||||||
|
)
|
||||||
|
|
||||||
|
affected_relationships = await get_data_related_edges(dataset_id, data_id)
|
||||||
|
|
||||||
|
await vector_engine.delete_data_points(
|
||||||
|
"EdgeType_relationship_name",
|
||||||
|
[generate_edge_id(edge.relationship_name) for edge in affected_relationships],
|
||||||
|
)
|
||||||
|
|
||||||
|
await delete_data_related_nodes(data_id)
|
||||||
|
await delete_data_related_edges(data_id)
|
||||||
13
cognee/modules/graph/methods/delete_data_related_edges.py
Normal file
13
cognee/modules/graph/methods/delete_data_related_edges.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
from cognee.modules.graph.models import Edge
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
|
||||||
|
async def delete_data_related_edges(data_id: UUID, session: AsyncSession):
|
||||||
|
nodes = (await session.scalars(select(Edge).where(Edge.data_id == data_id))).all()
|
||||||
|
|
||||||
|
await session.execute(delete(Edge).where(Edge.id.in_([node.id for node in nodes])))
|
||||||
13
cognee/modules/graph/methods/delete_data_related_nodes.py
Normal file
13
cognee/modules/graph/methods/delete_data_related_nodes.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
from cognee.modules.graph.models import Node
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
|
||||||
|
async def delete_data_related_nodes(data_id: UUID, session: AsyncSession):
|
||||||
|
nodes = (await session.scalars(select(Node).where(Node.data_id == data_id))).all()
|
||||||
|
|
||||||
|
await session.execute(delete(Node).where(Node.id.in_([node.id for node in nodes])))
|
||||||
17
cognee/modules/graph/methods/get_data_related_edges.py
Normal file
17
cognee/modules/graph/methods/get_data_related_edges.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
from cognee.modules.graph.models import Edge
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
|
||||||
|
async def get_data_related_edges(dataset_id: UUID, data_id: UUID, session: AsyncSession):
|
||||||
|
return (
|
||||||
|
await session.scalars(
|
||||||
|
select(Edge)
|
||||||
|
.where(Edge.data_id == data_id, Edge.dataset_id == dataset_id)
|
||||||
|
.distinct(Edge.relationship_name)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
26
cognee/modules/graph/methods/get_data_related_nodes.py
Normal file
26
cognee/modules/graph/methods/get_data_related_nodes.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from sqlalchemy import and_, exists, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
from cognee.modules.graph.models import Node
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
|
||||||
|
async def get_data_related_nodes(dataset_id: UUID, data_id: UUID, session: AsyncSession):
|
||||||
|
NodeAlias = Node.__table__.alias("n2")
|
||||||
|
|
||||||
|
subq = select(NodeAlias.c.id).where(
|
||||||
|
and_(
|
||||||
|
NodeAlias.c.slug == Node.slug,
|
||||||
|
NodeAlias.c.dataset_id == dataset_id,
|
||||||
|
NodeAlias.c.data_id != data_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
query_statement = select(Node).where(
|
||||||
|
and_(Node.data_id == data_id, Node.dataset_id == dataset_id, ~exists(subq))
|
||||||
|
)
|
||||||
|
|
||||||
|
data_related_nodes = await session.scalars(query_statement)
|
||||||
|
return data_related_nodes.all()
|
||||||
8
cognee/modules/graph/methods/set_current_user.py
Normal file
8
cognee/modules/graph/methods/set_current_user.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
import uuid
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
async def set_current_user(session: AsyncSession, user_id: uuid.UUID, local: bool = False):
|
||||||
|
scope = "LOCAL " if local else ""
|
||||||
|
await session.execute(text(f"SET {scope}app.current_user_id = '{user_id}'"))
|
||||||
57
cognee/modules/graph/methods/upsert_edges.py
Normal file
57
cognee/modules/graph/methods/upsert_edges.py
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
from cognee.modules.graph.models.Edge import Edge
|
||||||
|
from .set_current_user import set_current_user
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
|
||||||
|
async def upsert_edges(
|
||||||
|
edges: List[Tuple[UUID, UUID, str, Dict[str, Any]]],
|
||||||
|
user_id: UUID,
|
||||||
|
data_id: UUID,
|
||||||
|
dataset_id: UUID,
|
||||||
|
session: AsyncSession,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Adds edges to the edges table.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- edges (list): A list of edges to be added to the graph.
|
||||||
|
"""
|
||||||
|
if session.get_bind().dialect.name == "postgresql":
|
||||||
|
# Set the session-level RLS variable
|
||||||
|
await set_current_user(session, user_id)
|
||||||
|
|
||||||
|
upsert_statement = (
|
||||||
|
insert(Edge)
|
||||||
|
.values(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": uuid5(
|
||||||
|
NAMESPACE_OID,
|
||||||
|
str(user_id) + str(dataset_id) + str(edge[0]) + str(edge[2]) + str(edge[1]),
|
||||||
|
),
|
||||||
|
"user_id": user_id,
|
||||||
|
"data_id": data_id,
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"source_node_id": edge[0],
|
||||||
|
"destination_node_id": edge[1],
|
||||||
|
"relationship_name": edge[2],
|
||||||
|
"label": edge[2],
|
||||||
|
"props": jsonable_encoder(edge[3]),
|
||||||
|
}
|
||||||
|
for edge in edges
|
||||||
|
]
|
||||||
|
)
|
||||||
|
.on_conflict_do_nothing(index_elements=["id"])
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.execute(upsert_statement)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
51
cognee/modules/graph/methods/upsert_nodes.py
Normal file
51
cognee/modules/graph/methods/upsert_nodes.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
from typing import List
|
||||||
|
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||||
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
from cognee.modules.graph.models import Node
|
||||||
|
from .set_current_user import set_current_user
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
|
||||||
|
async def upsert_nodes(
|
||||||
|
nodes: List[DataPoint], user_id: UUID, dataset_id: UUID, data_id: UUID, session: AsyncSession
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Adds nodes to the nodes table.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- nodes (list): A list of nodes to be added to the graph.
|
||||||
|
"""
|
||||||
|
if session.get_bind().dialect.name == "postgresql":
|
||||||
|
# Set the session-level RLS variable
|
||||||
|
await set_current_user(session, user_id)
|
||||||
|
|
||||||
|
upsert_statement = (
|
||||||
|
insert(Node)
|
||||||
|
.values(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": uuid5(
|
||||||
|
NAMESPACE_OID, str(user_id) + str(dataset_id) + str(data_id) + str(node.id)
|
||||||
|
),
|
||||||
|
"slug": node.id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"data_id": data_id,
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"type": node.type,
|
||||||
|
"indexed_fields": DataPoint.get_embeddable_property_names(node),
|
||||||
|
"label": getattr(node, "label", getattr(node, "name", str(node.id))),
|
||||||
|
}
|
||||||
|
for node in nodes
|
||||||
|
]
|
||||||
|
)
|
||||||
|
.on_conflict_do_nothing(index_elements=["id"])
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.execute(upsert_statement)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
50
cognee/modules/graph/models/Edge.py
Normal file
50
cognee/modules/graph/models/Edge.py
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
from sqlalchemy import (
|
||||||
|
# event,
|
||||||
|
String,
|
||||||
|
JSON,
|
||||||
|
UUID,
|
||||||
|
)
|
||||||
|
|
||||||
|
# from sqlalchemy.schema import DDL
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.relational import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Edge(Base):
|
||||||
|
__tablename__ = "edges"
|
||||||
|
|
||||||
|
id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)
|
||||||
|
|
||||||
|
user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||||
|
|
||||||
|
data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), index=True, nullable=False)
|
||||||
|
|
||||||
|
dataset_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||||
|
|
||||||
|
source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||||
|
destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||||
|
|
||||||
|
relationship_name: Mapped[str | None] = mapped_column(String(255))
|
||||||
|
|
||||||
|
label: Mapped[str | None] = mapped_column(String(255))
|
||||||
|
props: Mapped[dict | None] = mapped_column(JSON)
|
||||||
|
|
||||||
|
# __table_args__ = (
|
||||||
|
# {"postgresql_partition_by": "HASH (user_id)"}, # partitioning by user
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
# Enable row-level security (RLS) for edges
|
||||||
|
# enable_edge_rls = DDL("""
|
||||||
|
# ALTER TABLE edges ENABLE ROW LEVEL SECURITY;
|
||||||
|
# """)
|
||||||
|
# create_user_isolation_policy = DDL("""
|
||||||
|
# CREATE POLICY user_isolation_policy
|
||||||
|
# ON edges
|
||||||
|
# USING (user_id = current_setting('app.current_user_id')::uuid)
|
||||||
|
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
|
||||||
|
# """)
|
||||||
|
|
||||||
|
# event.listen(Edge.__table__, "after_create", enable_edge_rls)
|
||||||
|
# event.listen(Edge.__table__, "after_create", create_user_isolation_policy)
|
||||||
53
cognee/modules/graph/models/Node.py
Normal file
53
cognee/modules/graph/models/Node.py
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
from sqlalchemy import (
|
||||||
|
Index,
|
||||||
|
# event,
|
||||||
|
String,
|
||||||
|
JSON,
|
||||||
|
UUID,
|
||||||
|
)
|
||||||
|
|
||||||
|
# from sqlalchemy.schema import DDL
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.relational import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Node(Base):
|
||||||
|
__tablename__ = "nodes"
|
||||||
|
|
||||||
|
id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)
|
||||||
|
|
||||||
|
slug: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||||
|
|
||||||
|
user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||||
|
|
||||||
|
data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||||
|
|
||||||
|
dataset_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||||
|
|
||||||
|
label: Mapped[str] = mapped_column(String(255))
|
||||||
|
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
indexed_fields: Mapped[list] = mapped_column(JSON)
|
||||||
|
|
||||||
|
# props: Mapped[dict] = mapped_column(JSON)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("index_node_dataset_slug", "dataset_id", "slug"),
|
||||||
|
Index("index_node_dataset_data", "dataset_id", "data_id"),
|
||||||
|
# {"postgresql_partition_by": "HASH (user_id)"}, # HASH partitioning on user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Enable row-level security (RLS) for nodes
|
||||||
|
# enable_node_rls = DDL("""
|
||||||
|
# ALTER TABLE nodes ENABLE ROW LEVEL SECURITY;
|
||||||
|
# """)
|
||||||
|
# create_user_isolation_policy = DDL("""
|
||||||
|
# CREATE POLICY user_isolation_policy
|
||||||
|
# ON nodes
|
||||||
|
# USING (user_id = current_setting('app.current_user_id')::uuid)
|
||||||
|
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
|
||||||
|
# """)
|
||||||
|
|
||||||
|
# event.listen(Node.__table__, "after_create", enable_node_rls)
|
||||||
|
# event.listen(Node.__table__, "after_create", create_user_isolation_policy)
|
||||||
2
cognee/modules/graph/models/__init__.py
Normal file
2
cognee/modules/graph/models/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .Edge import Edge
|
||||||
|
from .Node import Node
|
||||||
|
|
@ -134,8 +134,8 @@ def _targets_generator(
|
||||||
|
|
||||||
async def get_graph_from_model(
|
async def get_graph_from_model(
|
||||||
data_point: DataPoint,
|
data_point: DataPoint,
|
||||||
added_nodes: Dict[str, bool],
|
added_nodes: Optional[Dict[str, bool]] = None,
|
||||||
added_edges: Dict[str, bool],
|
added_edges: Optional[Dict[str, bool]] = None,
|
||||||
visited_properties: Optional[Dict[str, bool]] = None,
|
visited_properties: Optional[Dict[str, bool]] = None,
|
||||||
include_root: bool = True,
|
include_root: bool = True,
|
||||||
) -> Tuple[List[DataPoint], List[Tuple[str, str, str, Dict[str, Any]]]]:
|
) -> Tuple[List[DataPoint], List[Tuple[str, str, str, Dict[str, Any]]]]:
|
||||||
|
|
@ -152,6 +152,12 @@ async def get_graph_from_model(
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (nodes, edges) extracted from the model
|
Tuple of (nodes, edges) extracted from the model
|
||||||
"""
|
"""
|
||||||
|
if not added_nodes:
|
||||||
|
added_nodes = {}
|
||||||
|
|
||||||
|
if not added_edges:
|
||||||
|
added_edges = {}
|
||||||
|
|
||||||
if str(data_point.id) in added_nodes:
|
if str(data_point.id) in added_nodes:
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from cognee.modules.pipelines.layers.setup_and_check_environment import (
|
from cognee.modules.pipelines.layers.setup_and_check_environment import (
|
||||||
setup_and_check_environment,
|
setup_and_check_environment,
|
||||||
|
|
@ -29,12 +29,13 @@ update_status_lock = asyncio.Lock()
|
||||||
async def run_pipeline(
|
async def run_pipeline(
|
||||||
tasks: list[Task],
|
tasks: list[Task],
|
||||||
data=None,
|
data=None,
|
||||||
datasets: Union[str, list[str], list[UUID]] = None,
|
datasets: Optional[Union[str, list[str], list[UUID]]] = None,
|
||||||
user: User = None,
|
user: Optional[User] = None,
|
||||||
pipeline_name: str = "custom_pipeline",
|
pipeline_name: str = "custom_pipeline",
|
||||||
vector_db_config: dict = None,
|
vector_db_config: Optional[dict] = None,
|
||||||
graph_db_config: dict = None,
|
graph_db_config: Optional[dict] = None,
|
||||||
incremental_loading: bool = False,
|
incremental_loading: bool = False,
|
||||||
|
context: Optional[Dict] = None,
|
||||||
):
|
):
|
||||||
validate_pipeline_tasks(tasks)
|
validate_pipeline_tasks(tasks)
|
||||||
await setup_and_check_environment(vector_db_config, graph_db_config)
|
await setup_and_check_environment(vector_db_config, graph_db_config)
|
||||||
|
|
@ -48,8 +49,8 @@ async def run_pipeline(
|
||||||
tasks=tasks,
|
tasks=tasks,
|
||||||
data=data,
|
data=data,
|
||||||
pipeline_name=pipeline_name,
|
pipeline_name=pipeline_name,
|
||||||
context={"dataset": dataset},
|
|
||||||
incremental_loading=incremental_loading,
|
incremental_loading=incremental_loading,
|
||||||
|
context=context,
|
||||||
):
|
):
|
||||||
yield run_info
|
yield run_info
|
||||||
|
|
||||||
|
|
@ -58,16 +59,16 @@ async def run_pipeline_per_dataset(
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
user: User,
|
user: User,
|
||||||
tasks: list[Task],
|
tasks: list[Task],
|
||||||
data=None,
|
data: Optional[list[Data]] = None,
|
||||||
pipeline_name: str = "custom_pipeline",
|
pipeline_name: str = "custom_pipeline",
|
||||||
context: dict = None,
|
|
||||||
incremental_loading=False,
|
incremental_loading=False,
|
||||||
|
context: Optional[Dict] = None,
|
||||||
):
|
):
|
||||||
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
||||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
data: list[Data] = await get_dataset_data(dataset_id=dataset.id)
|
data = await get_dataset_data(dataset_id=dataset.id)
|
||||||
|
|
||||||
process_pipeline_status = await check_pipeline_run_qualification(dataset, data, pipeline_name)
|
process_pipeline_status = await check_pipeline_run_qualification(dataset, data, pipeline_name)
|
||||||
if process_pipeline_status:
|
if process_pipeline_status:
|
||||||
|
|
@ -77,7 +78,7 @@ async def run_pipeline_per_dataset(
|
||||||
return
|
return
|
||||||
|
|
||||||
pipeline_run = run_tasks(
|
pipeline_run = run_tasks(
|
||||||
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
|
tasks, dataset, data, user, pipeline_name, context, incremental_loading
|
||||||
)
|
)
|
||||||
|
|
||||||
async for pipeline_run_info in pipeline_run:
|
async for pipeline_run_info in pipeline_run:
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from uuid import UUID
|
|
||||||
from typing import Any, List
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from cognee.modules.data.models import Dataset
|
||||||
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
|
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
@ -24,7 +24,6 @@ from cognee.modules.pipelines.operations import (
|
||||||
log_pipeline_run_complete,
|
log_pipeline_run_complete,
|
||||||
log_pipeline_run_error,
|
log_pipeline_run_error,
|
||||||
)
|
)
|
||||||
from .run_tasks_with_telemetry import run_tasks_with_telemetry
|
|
||||||
from .run_tasks_data_item import run_tasks_data_item
|
from .run_tasks_data_item import run_tasks_data_item
|
||||||
from ..tasks.task import Task
|
from ..tasks.task import Task
|
||||||
|
|
||||||
|
|
@ -54,25 +53,18 @@ def override_run_tasks(new_gen):
|
||||||
@override_run_tasks(run_tasks_distributed)
|
@override_run_tasks(run_tasks_distributed)
|
||||||
async def run_tasks(
|
async def run_tasks(
|
||||||
tasks: List[Task],
|
tasks: List[Task],
|
||||||
dataset_id: UUID,
|
dataset: Dataset,
|
||||||
data: List[Any] = None,
|
data: Optional[List[Any]] = None,
|
||||||
user: User = None,
|
user: Optional[User] = None,
|
||||||
pipeline_name: str = "unknown_pipeline",
|
pipeline_name: str = "unknown_pipeline",
|
||||||
context: dict = None,
|
context: Optional[Dict] = None,
|
||||||
incremental_loading: bool = False,
|
incremental_loading: bool = False,
|
||||||
):
|
):
|
||||||
if not user:
|
if not user:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
# Get Dataset object
|
|
||||||
db_engine = get_relational_engine()
|
|
||||||
async with db_engine.get_async_session() as session:
|
|
||||||
from cognee.modules.data.models import Dataset
|
|
||||||
|
|
||||||
dataset = await session.get(Dataset, dataset_id)
|
|
||||||
|
|
||||||
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
|
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
|
||||||
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
|
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset.id, data)
|
||||||
pipeline_run_id = pipeline_run.pipeline_run_id
|
pipeline_run_id = pipeline_run.pipeline_run_id
|
||||||
|
|
||||||
yield PipelineRunStarted(
|
yield PipelineRunStarted(
|
||||||
|
|
@ -99,7 +91,12 @@ async def run_tasks(
|
||||||
pipeline_name,
|
pipeline_name,
|
||||||
pipeline_id,
|
pipeline_id,
|
||||||
pipeline_run_id,
|
pipeline_run_id,
|
||||||
context,
|
{
|
||||||
|
**(context or {}),
|
||||||
|
"user": user,
|
||||||
|
"data": data_item,
|
||||||
|
"dataset": dataset,
|
||||||
|
},
|
||||||
user,
|
user,
|
||||||
incremental_loading,
|
incremental_loading,
|
||||||
)
|
)
|
||||||
|
|
@ -121,7 +118,7 @@ async def run_tasks(
|
||||||
)
|
)
|
||||||
|
|
||||||
await log_pipeline_run_complete(
|
await log_pipeline_run_complete(
|
||||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
|
pipeline_run_id, pipeline_id, pipeline_name, dataset.id, data
|
||||||
)
|
)
|
||||||
|
|
||||||
yield PipelineRunCompleted(
|
yield PipelineRunCompleted(
|
||||||
|
|
@ -141,7 +138,7 @@ async def run_tasks(
|
||||||
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
await log_pipeline_run_error(
|
await log_pipeline_run_error(
|
||||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
|
pipeline_run_id, pipeline_id, pipeline_name, dataset.id, data, error
|
||||||
)
|
)
|
||||||
|
|
||||||
yield PipelineRunErrored(
|
yield PipelineRunErrored(
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,9 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Type, List, Optional
|
from typing import Dict, Type, List, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
||||||
from cognee.tasks.storage import index_graph_edges
|
|
||||||
from cognee.tasks.storage.add_data_points import add_data_points
|
from cognee.tasks.storage.add_data_points import add_data_points
|
||||||
from cognee.modules.ontology.ontology_config import Config
|
from cognee.modules.ontology.ontology_config import Config
|
||||||
from cognee.modules.ontology.get_default_ontology_resolver import (
|
from cognee.modules.ontology.get_default_ontology_resolver import (
|
||||||
|
|
@ -32,6 +31,7 @@ async def integrate_chunk_graphs(
|
||||||
chunk_graphs: list,
|
chunk_graphs: list,
|
||||||
graph_model: Type[BaseModel],
|
graph_model: Type[BaseModel],
|
||||||
ontology_resolver: BaseOntologyResolver,
|
ontology_resolver: BaseOntologyResolver,
|
||||||
|
context: Dict,
|
||||||
) -> List[DocumentChunk]:
|
) -> List[DocumentChunk]:
|
||||||
"""Integrate chunk graphs with ontology validation and store in databases.
|
"""Integrate chunk graphs with ontology validation and store in databases.
|
||||||
|
|
||||||
|
|
@ -85,19 +85,19 @@ async def integrate_chunk_graphs(
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(graph_nodes) > 0:
|
if len(graph_nodes) > 0:
|
||||||
await add_data_points(graph_nodes)
|
await add_data_points(graph_nodes, context)
|
||||||
|
|
||||||
if len(graph_edges) > 0:
|
if len(graph_edges) > 0:
|
||||||
await graph_engine.add_edges(graph_edges)
|
await graph_engine.add_edges(graph_edges)
|
||||||
await index_graph_edges(graph_edges)
|
|
||||||
|
|
||||||
return data_chunks
|
return data_chunks
|
||||||
|
|
||||||
|
|
||||||
async def extract_graph_from_data(
|
async def extract_graph_from_data(
|
||||||
data_chunks: List[DocumentChunk],
|
data_chunks: List[DocumentChunk],
|
||||||
|
context: Dict,
|
||||||
graph_model: Type[BaseModel],
|
graph_model: Type[BaseModel],
|
||||||
config: Config = None,
|
config: Optional[Config] = None,
|
||||||
custom_prompt: Optional[str] = None,
|
custom_prompt: Optional[str] = None,
|
||||||
) -> List[DocumentChunk]:
|
) -> List[DocumentChunk]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -136,16 +136,16 @@ async def extract_graph_from_data(
|
||||||
and ontology_config.ontology_resolver
|
and ontology_config.ontology_resolver
|
||||||
and ontology_config.matching_strategy
|
and ontology_config.matching_strategy
|
||||||
):
|
):
|
||||||
config: Config = {
|
config = {
|
||||||
"ontology_config": {
|
"ontology_config": {
|
||||||
"ontology_resolver": get_ontology_resolver_from_env(**ontology_config.to_dict())
|
"ontology_resolver": get_ontology_resolver_from_env(**ontology_config.to_dict())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
config: Config = {
|
config = {"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}}
|
||||||
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
|
|
||||||
}
|
|
||||||
|
|
||||||
ontology_resolver = config["ontology_config"]["ontology_resolver"]
|
ontology_resolver = config["ontology_config"]["ontology_resolver"]
|
||||||
|
|
||||||
return await integrate_chunk_graphs(data_chunks, chunk_graphs, graph_model, ontology_resolver)
|
return await integrate_chunk_graphs(
|
||||||
|
data_chunks, chunk_graphs, graph_model, ontology_resolver, context
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List
|
from typing import Any, Dict, List, Optional
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.modules.graph.methods import upsert_edges, upsert_nodes
|
||||||
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
||||||
|
from cognee.modules.users.models import User
|
||||||
from .index_data_points import index_data_points
|
from .index_data_points import index_data_points
|
||||||
from .index_graph_edges import index_graph_edges
|
from .index_graph_edges import index_graph_edges
|
||||||
from cognee.tasks.storage.exceptions import (
|
from cognee.tasks.storage.exceptions import (
|
||||||
|
|
@ -10,7 +12,9 @@ from cognee.tasks.storage.exceptions import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
async def add_data_points(
|
||||||
|
data_points: List[DataPoint], context: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[DataPoint]:
|
||||||
"""
|
"""
|
||||||
Add a batch of data points to the graph database by extracting nodes and edges,
|
Add a batch of data points to the graph database by extracting nodes and edges,
|
||||||
deduplicating them, and indexing them for retrieval.
|
deduplicating them, and indexing them for retrieval.
|
||||||
|
|
@ -33,8 +37,15 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
||||||
- Deduplicates nodes and edges across all results.
|
- Deduplicates nodes and edges across all results.
|
||||||
- Updates the node index via `index_data_points`.
|
- Updates the node index via `index_data_points`.
|
||||||
- Inserts nodes and edges into the graph engine.
|
- Inserts nodes and edges into the graph engine.
|
||||||
- Optionally updates the edge index via `index_graph_edges`.
|
|
||||||
"""
|
"""
|
||||||
|
user: Optional[User] = None
|
||||||
|
data = None
|
||||||
|
dataset = None
|
||||||
|
|
||||||
|
if context:
|
||||||
|
data = context["data"]
|
||||||
|
dataset = context["dataset"]
|
||||||
|
user = context["user"]
|
||||||
|
|
||||||
if not isinstance(data_points, list):
|
if not isinstance(data_points, list):
|
||||||
raise InvalidDataPointsInAddDataPointsError("data_points must be a list.")
|
raise InvalidDataPointsInAddDataPointsError("data_points must be a list.")
|
||||||
|
|
@ -71,6 +82,10 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
||||||
await graph_engine.add_nodes(nodes)
|
await graph_engine.add_nodes(nodes)
|
||||||
await index_data_points(nodes)
|
await index_data_points(nodes)
|
||||||
|
|
||||||
|
if user and dataset and data:
|
||||||
|
await upsert_nodes(nodes, user_id=user.id, dataset_id=dataset.id, data_id=data.id)
|
||||||
|
await upsert_edges(edges, user_id=user.id, dataset_id=dataset.id, data_id=data.id)
|
||||||
|
|
||||||
await graph_engine.add_edges(edges)
|
await graph_engine.add_edges(edges)
|
||||||
await index_graph_edges(edges)
|
await index_graph_edges(edges)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,26 @@
|
||||||
import os
|
import os
|
||||||
|
from typing import List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from graphiti_core import Graphiti
|
from graphiti_core import Graphiti
|
||||||
from graphiti_core.nodes import EpisodeType
|
from graphiti_core.nodes import EpisodeType
|
||||||
|
|
||||||
|
from cognee.infrastructure.files.storage import get_file_storage
|
||||||
|
from cognee.modules.data.models import Data
|
||||||
|
|
||||||
async def build_graph_with_temporal_awareness(text_list):
|
|
||||||
url = os.getenv("GRAPH_DATABASE_URL")
|
async def build_graph_with_temporal_awareness(data: List[Data]):
|
||||||
password = os.getenv("GRAPH_DATABASE_PASSWORD")
|
text_list: List[str] = []
|
||||||
|
|
||||||
|
for text_data in data:
|
||||||
|
file_dir = os.path.dirname(text_data.raw_data_location)
|
||||||
|
file_name = os.path.basename(text_data.raw_data_location)
|
||||||
|
file_storage = get_file_storage(file_dir)
|
||||||
|
async with file_storage.open(file_name, "r") as file:
|
||||||
|
text_list.append(file.read())
|
||||||
|
|
||||||
|
url = os.getenv("GRAPH_DATABASE_URL", "")
|
||||||
|
password = os.getenv("GRAPH_DATABASE_PASSWORD", "")
|
||||||
graphiti = Graphiti(url, "neo4j", password)
|
graphiti = Graphiti(url, "neo4j", password)
|
||||||
|
|
||||||
await graphiti.build_indices_and_constraints()
|
await graphiti.build_indices_and_constraints()
|
||||||
|
|
@ -22,4 +35,5 @@ async def build_graph_with_temporal_awareness(text_list):
|
||||||
reference_time=datetime.now(),
|
reference_time=datetime.now(),
|
||||||
)
|
)
|
||||||
print(f"Added text: {text[:35]}...")
|
print(f"Added text: {text[:35]}...")
|
||||||
|
|
||||||
return graphiti
|
return graphiti
|
||||||
|
|
|
||||||
107
cognee/tests/test_delete_custom_graph.py
Normal file
107
cognee/tests/test_delete_custom_graph.py
Normal file
|
|
@ -0,0 +1,107 @@
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
from typing import List
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.data.models import Data, Dataset
|
||||||
|
from cognee.modules.engine.operations.setup import setup
|
||||||
|
from cognee.modules.graph.methods import delete_data_nodes_and_edges
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
data_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_delete_custom_graph")
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
cognee_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_delete_custom_graph")
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(cognee_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Organization(DataPoint):
|
||||||
|
name: str
|
||||||
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
|
class ForProfit(Organization):
|
||||||
|
name: str = "For-Profit"
|
||||||
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
|
class NonProfit(Organization):
|
||||||
|
name: str = "Non-Profit"
|
||||||
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: List[Organization]
|
||||||
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
|
companyA = ForProfit(name="Company A")
|
||||||
|
companyB = NonProfit(name="Company B")
|
||||||
|
|
||||||
|
person1 = Person(name="John", works_for=[companyA, companyB])
|
||||||
|
person2 = Person(name="Jane", works_for=[companyB])
|
||||||
|
|
||||||
|
user: User = await get_default_user() # type: ignore
|
||||||
|
|
||||||
|
dataset = Dataset(id=uuid4())
|
||||||
|
data1 = Data(id=uuid4())
|
||||||
|
data2 = Data(id=uuid4())
|
||||||
|
|
||||||
|
await add_data_points(
|
||||||
|
[person1],
|
||||||
|
context={
|
||||||
|
"user": user,
|
||||||
|
"dataset": dataset,
|
||||||
|
"data": data1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await add_data_points(
|
||||||
|
[person2],
|
||||||
|
context={
|
||||||
|
"user": user,
|
||||||
|
"dataset": dataset,
|
||||||
|
"data": data2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) == 4 and len(edges) == 3, (
|
||||||
|
"Nodes and edges are not correctly added to the graph."
|
||||||
|
)
|
||||||
|
|
||||||
|
await delete_data_nodes_and_edges(dataset.id, data1.id) # type: ignore
|
||||||
|
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly."
|
||||||
|
|
||||||
|
await delete_data_nodes_and_edges(dataset.id, data2.id) # type: ignore
|
||||||
|
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) == 0 and len(edges) == 0, "Nodes and edges are not deleted."
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
77
cognee/tests/test_delete_default_graph.py
Normal file
77
cognee/tests/test_delete_default_graph.py
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.modules.data.methods import delete_data, get_dataset_data
|
||||||
|
from cognee.modules.engine.operations.setup import setup
|
||||||
|
from cognee.modules.graph.methods import (
|
||||||
|
delete_data_nodes_and_edges,
|
||||||
|
)
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
data_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph")
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
cognee_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_delete_default_graph")
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(cognee_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
|
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
||||||
|
assert not await vector_engine.has_collection("Entity_name")
|
||||||
|
|
||||||
|
await cognee.add(
|
||||||
|
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||||
|
)
|
||||||
|
await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.")
|
||||||
|
|
||||||
|
cognify_result: dict = await cognee.cognify()
|
||||||
|
dataset_id = list(cognify_result.keys())[0]
|
||||||
|
|
||||||
|
dataset_data = await get_dataset_data(dataset_id)
|
||||||
|
added_data = dataset_data[0]
|
||||||
|
|
||||||
|
file_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html"
|
||||||
|
)
|
||||||
|
await visualize_graph(file_path)
|
||||||
|
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) >= 12 and len(edges) >= 18, "Nodes and edges are not deleted."
|
||||||
|
|
||||||
|
await delete_data_nodes_and_edges(dataset_id, added_data.id) # type: ignore
|
||||||
|
|
||||||
|
await delete_data(added_data)
|
||||||
|
|
||||||
|
file_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html"
|
||||||
|
)
|
||||||
|
await visualize_graph(file_path)
|
||||||
|
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) >= 8 and len(edges) >= 10, "Nodes and edges are not deleted."
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
|
@ -100,7 +100,7 @@ async def main():
|
||||||
|
|
||||||
pipeline = run_tasks(
|
pipeline = run_tasks(
|
||||||
[Task(ingest_files), Task(add_data_points)],
|
[Task(ingest_files), Task(add_data_points)],
|
||||||
dataset_id=datasets[0].id,
|
dataset=datasets[0],
|
||||||
data=data,
|
data=data,
|
||||||
incremental_loading=False,
|
incremental_loading=False,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,18 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Dict, List
|
||||||
|
from uuid import NAMESPACE_OID, UUID, uuid4, uuid5
|
||||||
from neo4j import exceptions
|
from neo4j import exceptions
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee import prune
|
from cognee import prune
|
||||||
|
|
||||||
# from cognee import visualize_graph
|
# from cognee import visualize_graph
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.low_level import setup, DataPoint
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.modules.data.models import Data, Dataset
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.pipelines import run_tasks, Task
|
from cognee.pipelines import run_tasks, Task
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
|
|
||||||
|
|
@ -20,7 +25,6 @@ products_aggregator_node = Products()
|
||||||
|
|
||||||
|
|
||||||
class Product(DataPoint):
|
class Product(DataPoint):
|
||||||
id: str
|
|
||||||
name: str
|
name: str
|
||||||
type: str
|
type: str
|
||||||
price: float
|
price: float
|
||||||
|
|
@ -36,7 +40,6 @@ preferences_aggregator_node = Preferences()
|
||||||
|
|
||||||
|
|
||||||
class Preference(DataPoint):
|
class Preference(DataPoint):
|
||||||
id: str
|
|
||||||
name: str
|
name: str
|
||||||
value: str
|
value: str
|
||||||
is_type: Preferences = preferences_aggregator_node
|
is_type: Preferences = preferences_aggregator_node
|
||||||
|
|
@ -50,7 +53,6 @@ customers_aggregator_node = Customers()
|
||||||
|
|
||||||
|
|
||||||
class Customer(DataPoint):
|
class Customer(DataPoint):
|
||||||
id: str
|
|
||||||
name: str
|
name: str
|
||||||
has_preference: list[Preference]
|
has_preference: list[Preference]
|
||||||
purchased: list[Product]
|
purchased: list[Product]
|
||||||
|
|
@ -58,17 +60,14 @@ class Customer(DataPoint):
|
||||||
is_type: Customers = customers_aggregator_node
|
is_type: Customers = customers_aggregator_node
|
||||||
|
|
||||||
|
|
||||||
def ingest_files():
|
def ingest_customers(data):
|
||||||
customers_file_path = os.path.join(os.path.dirname(__file__), "customers.json")
|
|
||||||
customers = json.loads(open(customers_file_path, "r").read())
|
|
||||||
|
|
||||||
customers_data_points = {}
|
customers_data_points = {}
|
||||||
products_data_points = {}
|
products_data_points = {}
|
||||||
preferences_data_points = {}
|
preferences_data_points = {}
|
||||||
|
|
||||||
for customer in customers:
|
for customer in data[0].customers:
|
||||||
new_customer = Customer(
|
new_customer = Customer(
|
||||||
id=customer["id"],
|
id=uuid5(NAMESPACE_OID, customer["id"]),
|
||||||
name=customer["name"],
|
name=customer["name"],
|
||||||
liked=[],
|
liked=[],
|
||||||
purchased=[],
|
purchased=[],
|
||||||
|
|
@ -79,7 +78,7 @@ def ingest_files():
|
||||||
for product in customer["products"]:
|
for product in customer["products"]:
|
||||||
if product["id"] not in products_data_points:
|
if product["id"] not in products_data_points:
|
||||||
products_data_points[product["id"]] = Product(
|
products_data_points[product["id"]] = Product(
|
||||||
id=product["id"],
|
id=uuid5(NAMESPACE_OID, product["id"]),
|
||||||
type=product["type"],
|
type=product["type"],
|
||||||
name=product["name"],
|
name=product["name"],
|
||||||
price=product["price"],
|
price=product["price"],
|
||||||
|
|
@ -96,7 +95,7 @@ def ingest_files():
|
||||||
for preference in customer["preferences"]:
|
for preference in customer["preferences"]:
|
||||||
if preference["id"] not in preferences_data_points:
|
if preference["id"] not in preferences_data_points:
|
||||||
preferences_data_points[preference["id"]] = Preference(
|
preferences_data_points[preference["id"]] = Preference(
|
||||||
id=preference["id"],
|
id=uuid5(NAMESPACE_OID, preference["id"]),
|
||||||
name=preference["name"],
|
name=preference["name"],
|
||||||
value=preference["value"],
|
value=preference["value"],
|
||||||
)
|
)
|
||||||
|
|
@ -104,7 +103,7 @@ def ingest_files():
|
||||||
new_preference = preferences_data_points[preference["id"]]
|
new_preference = preferences_data_points[preference["id"]]
|
||||||
new_customer.has_preference.append(new_preference)
|
new_customer.has_preference.append(new_preference)
|
||||||
|
|
||||||
return customers_data_points.values()
|
return list(customers_data_points.values())
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|
@ -113,7 +112,28 @@ async def main():
|
||||||
|
|
||||||
await setup()
|
await setup()
|
||||||
|
|
||||||
pipeline = run_tasks([Task(ingest_files), Task(add_data_points)])
|
# Get user and dataset
|
||||||
|
user: User = await get_default_user() # type: ignore
|
||||||
|
main_dataset = Dataset(id=uuid4(), name="demo_dataset")
|
||||||
|
|
||||||
|
customers_file_path = os.path.join(os.path.dirname(__file__), "customers.json")
|
||||||
|
customers = json.loads(open(customers_file_path, "r").read())
|
||||||
|
|
||||||
|
class Data(BaseModel):
|
||||||
|
id: UUID
|
||||||
|
customers: List[Dict]
|
||||||
|
|
||||||
|
data = Data(
|
||||||
|
id=uuid4(),
|
||||||
|
customers=customers,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = run_tasks(
|
||||||
|
[Task(ingest_customers), Task(add_data_points)],
|
||||||
|
dataset=main_dataset,
|
||||||
|
data=[data],
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
async for status in pipeline:
|
async for status in pipeline:
|
||||||
print(status)
|
print(status)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
|
from cognee.modules.data.methods import get_dataset_data, get_datasets
|
||||||
from cognee.shared.logging_utils import setup_logging, ERROR
|
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||||
from cognee.modules.pipelines import Task, run_tasks
|
from cognee.modules.pipelines import Task, run_tasks
|
||||||
from cognee.tasks.temporal_awareness import build_graph_with_temporal_awareness
|
from cognee.tasks.temporal_awareness import build_graph_with_temporal_awareness
|
||||||
|
|
@ -35,10 +36,13 @@ async def main():
|
||||||
await cognee.add(text)
|
await cognee.add(text)
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
Task(build_graph_with_temporal_awareness, text_list=text_list),
|
Task(build_graph_with_temporal_awareness),
|
||||||
]
|
]
|
||||||
|
|
||||||
pipeline = run_tasks(tasks, user=user)
|
datasets = await get_datasets(user.id)
|
||||||
|
dataset_data = await get_dataset_data(datasets[0].id) # type: ignore
|
||||||
|
|
||||||
|
pipeline = run_tasks(tasks, dataset=datasets[0], data=dataset_data, user=user)
|
||||||
|
|
||||||
async for result in pipeline:
|
async for result in pipeline:
|
||||||
print(result)
|
print(result)
|
||||||
|
|
|
||||||
8
notebooks/cognee_demo.ipynb
vendored
8
notebooks/cognee_demo.ipynb
vendored
|
|
@ -483,7 +483,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": null,
|
||||||
"id": "7c431fdef4921ae0",
|
"id": "7c431fdef4921ae0",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
|
|
@ -535,7 +535,7 @@
|
||||||
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
|
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
|
||||||
" ]\n",
|
" ]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" pipeline_run = run_tasks(tasks, dataset.id, data_documents, user, \"cognify_pipeline\", context={\"dataset\": dataset})\n",
|
" pipeline_run = run_tasks(tasks, dataset, data_documents, user, \"cognify_pipeline\", context={\"dataset\": dataset})\n",
|
||||||
" pipeline_run_status = None\n",
|
" pipeline_run_status = None\n",
|
||||||
"\n",
|
"\n",
|
||||||
" async for run_status in pipeline_run:\n",
|
" async for run_status in pipeline_run:\n",
|
||||||
|
|
@ -1831,7 +1831,7 @@
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": ".venv",
|
"display_name": "cognee",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
|
|
@ -1845,7 +1845,7 @@
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.12.7"
|
"version": "3.10.13"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|
|
||||||
7
poetry.lock
generated
7
poetry.lock
generated
|
|
@ -9310,6 +9310,13 @@ optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
|
{file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"},
|
||||||
|
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"},
|
||||||
|
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"},
|
||||||
|
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"},
|
||||||
|
{file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"},
|
||||||
|
{file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"},
|
||||||
|
{file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"},
|
||||||
{file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
|
{file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
|
||||||
{file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
|
{file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
|
||||||
{file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},
|
{file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -856,7 +856,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cognee"
|
name = "cognee"
|
||||||
version = "0.3.4"
|
version = "0.3.5"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiofiles" },
|
{ name = "aiofiles" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue