feat: implement full delete feature

This commit is contained in:
Boris Arzentar 2025-10-12 22:23:07 +02:00
parent a8dab3019e
commit caf4801a6b
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
43 changed files with 783 additions and 277 deletions

View file

@ -108,6 +108,13 @@ export default function Dashboard({ accessToken }: DashboardProps) {
setDatasets(datasets);
}, []);
const [searchValue, setSearchValue] = useState<string>("");
const handleSearchDatasetInputChange = useCallback((event: React.ChangeEvent<HTMLInputElement>) => {
const newSearchValue = event.currentTarget.value;
setSearchValue(newSearchValue);
}, []);
const isCloudEnv = isCloudEnvironment();
return (
@ -129,7 +136,7 @@ export default function Dashboard({ accessToken }: DashboardProps) {
<div className="px-5 py-4 lg:w-96 bg-white rounded-xl h-[calc(100%-2.75rem)]">
<div className="relative mb-2">
<label htmlFor="search-input"><SearchIcon className="absolute left-3 top-[10px] cursor-text" /></label>
<input id="search-input" className="text-xs leading-3 w-full h-8 flex flex-row items-center gap-2.5 rounded-3xl pl-9 placeholder-gray-300 border-gray-300 border-[1px] focus:outline-indigo-600" placeholder="Search datasets..." />
<input onChange={handleSearchDatasetInputChange} id="search-input" className="text-xs leading-3 w-full h-8 flex flex-row items-center gap-2.5 rounded-3xl pl-9 placeholder-gray-300 border-gray-300 border-[1px] focus:outline-indigo-600" placeholder="Search datasets..." />
</div>
<AddDataToCognee
@ -148,6 +155,7 @@ export default function Dashboard({ accessToken }: DashboardProps) {
<div className="mt-7 mb-14">
<CogneeInstancesAccordion>
<InstanceDatasetsAccordion
searchValue={searchValue}
onDatasetsChange={handleDatasetsChange}
/>
</CogneeInstancesAccordion>

View file

@ -19,11 +19,13 @@ interface DatasetsChangePayload {
export interface DatasetsAccordionProps extends Omit<AccordionProps, "isOpen" | "openAccordion" | "closeAccordion" | "children"> {
onDatasetsChange?: (payload: DatasetsChangePayload) => void;
useCloud?: boolean;
searchValue: string;
}
export default function DatasetsAccordion({
title,
tools,
searchValue,
switchCaretPosition = false,
className,
contentClassName,
@ -43,13 +45,7 @@ export default function DatasetsAccordion({
removeDataset,
getDatasetData,
removeDatasetData,
} = useDatasets(useCloud);
useEffect(() => {
if (datasets.length === 0) {
refreshDatasets();
}
}, [datasets.length, refreshDatasets]);
} = useDatasets(useCloud, searchValue);
const [openDatasets, openDataset] = useState<Set<string>>(new Set());
@ -237,11 +233,16 @@ export default function DatasetsAccordion({
contentClassName={contentClassName}
>
<div className="flex flex-col">
{datasets.length === 0 && (
{datasets.length === 0 && !searchValue && (
<div className="flex flex-row items-baseline-last text-sm text-gray-400 mt-2 px-2">
<span>No datasets here, add one by clicking +</span>
</div>
)}
{datasets.length === 0 && searchValue && (
<div className="flex flex-row items-baseline-last text-sm text-gray-400 mt-2 px-2">
<span>No datasets found, please adjust your search term</span>
</div>
)}
{datasets.map((dataset) => {
return (
<Accordion

View file

@ -1,16 +1,18 @@
import classNames from "classnames";
import { useCallback, useEffect } from "react";
import { fetch, isCloudEnvironment, useBoolean } from "@/utils";
import { fetch, isCloudApiKeySet, isCloudEnvironment, useBoolean } from "@/utils";
import { checkCloudConnection } from "@/modules/cloud";
import { CaretIcon, CloseIcon, CloudIcon, LocalCogneeIcon } from "@/ui/Icons";
import { CTAButton, GhostButton, IconButton, Input, Modal } from "@/ui/elements";
import DatasetsAccordion, { DatasetsAccordionProps } from "./DatasetsAccordion";
type InstanceDatasetsAccordionProps = Omit<DatasetsAccordionProps, "title">;
interface InstanceDatasetsAccordionProps extends Omit<DatasetsAccordionProps, "title"> {
searchValue: string;
}
export default function InstanceDatasetsAccordion({ onDatasetsChange }: InstanceDatasetsAccordionProps) {
export default function InstanceDatasetsAccordion({ searchValue, onDatasetsChange }: InstanceDatasetsAccordionProps) {
const {
value: isLocalCogneeConnected,
setTrue: setLocalCogneeConnected,
@ -19,7 +21,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
const {
value: isCloudCogneeConnected,
setTrue: setCloudCogneeConnected,
} = useBoolean(isCloudEnvironment());
} = useBoolean(isCloudEnvironment() || isCloudApiKeySet());
const checkConnectionToCloudCognee = useCallback((apiKey?: string) => {
if (apiKey) {
@ -71,6 +73,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
</div>
</div>
)}
searchValue={searchValue}
tools={isLocalCogneeConnected ? <span className="text-xs text-indigo-600">Connected</span> : <span className="text-xs text-gray-400">Not connected</span>}
switchCaretPosition={true}
className="pt-3 pb-1.5"
@ -88,6 +91,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
</div>
</div>
)}
searchValue={searchValue}
tools={<span className="text-xs text-indigo-600">Connected</span>}
switchCaretPosition={true}
className="pt-3 pb-1.5"

View file

@ -1,4 +1,4 @@
import { useCallback, useState } from 'react';
import { useCallback, useEffect, useLayoutEffect, useRef, useState } from 'react';
import { fetch } from '@/utils';
import { DataFile } from './useData';
@ -11,7 +11,20 @@ export interface Dataset {
status: string;
}
function useDatasets(useCloud = false) {
/**
 * Returns the datasets whose name contains the search term, case-insensitively.
 * A blank (or whitespace-only) search term matches every dataset.
 */
function filterDatasets(datasets: Dataset[], searchValue: string) {
  // Whitespace-only input is treated as "no filter".
  if (!searchValue.trim()) {
    return datasets;
  }

  // Match against the untrimmed term, exactly as typed (lowercased).
  const needle = searchValue.toLowerCase();

  return datasets.filter(({ name }) => name.toLowerCase().includes(needle));
}
function useDatasets(useCloud = false, searchValue: string = "") {
const allDatasets = useRef<Dataset[]>([]);
const [datasets, setDatasets] = useState<Dataset[]>([]);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
// const statusTimeout = useRef<any>(null);
@ -57,26 +70,32 @@ function useDatasets(useCloud = false) {
// };
// }, []);
useLayoutEffect(() => {
setDatasets(filterDatasets(allDatasets.current, searchValue));
}, [searchValue]);
const addDataset = useCallback((datasetName: string) => {
return createDataset({ name: datasetName }, useCloud)
.then((dataset) => {
setDatasets((datasets) => [
...datasets,
const newDatasets = [
...allDatasets.current,
dataset,
]);
];
allDatasets.current = newDatasets;
setDatasets(filterDatasets(newDatasets, searchValue));
});
}, [useCloud]);
}, [searchValue, useCloud]);
const removeDataset = useCallback((datasetId: string) => {
return fetch(`/v1/datasets/${datasetId}`, {
method: 'DELETE',
}, useCloud)
.then(() => {
setDatasets((datasets) =>
datasets.filter((dataset) => dataset.id !== datasetId)
);
const newDatasets = allDatasets.current.filter((dataset) => dataset.id !== datasetId)
allDatasets.current = newDatasets;
setDatasets(filterDatasets(newDatasets, searchValue));
});
}, [useCloud]);
}, [searchValue, useCloud]);
const fetchDatasets = useCallback(() => {
return fetch('/v1/datasets', {
@ -86,7 +105,8 @@ function useDatasets(useCloud = false) {
}, useCloud)
.then((response) => response.json())
.then((datasets) => {
setDatasets(datasets);
allDatasets.current = datasets;
setDatasets(filterDatasets(datasets, searchValue));
// if (datasets.length > 0) {
// checkDatasetStatuses(datasets);
@ -97,28 +117,38 @@ function useDatasets(useCloud = false) {
.catch((error) => {
console.error('Error fetching datasets:', error);
});
}, [useCloud]);
}, [searchValue, useCloud]);
useEffect(() => {
if (allDatasets.current.length === 0) {
fetchDatasets();
}
}, [fetchDatasets]);
const getDatasetData = useCallback((datasetId: string) => {
return fetch(`/v1/datasets/${datasetId}/data`, {}, useCloud)
.then((response) => response.json())
.then((data) => {
const datasetIndex = datasets.findIndex((dataset) => dataset.id === datasetId);
const datasetIndex = allDatasets.current.findIndex((dataset) => dataset.id === datasetId);
if (datasetIndex >= 0) {
setDatasets((datasets) => [
...datasets.slice(0, datasetIndex),
{
...datasets[datasetIndex],
data,
},
...datasets.slice(datasetIndex + 1),
]);
const newDatasets = [
...allDatasets.current.slice(0, datasetIndex),
{
...allDatasets.current[datasetIndex],
data,
},
...allDatasets.current.slice(datasetIndex + 1),
];
allDatasets.current = newDatasets;
setDatasets(filterDatasets(newDatasets, searchValue));
}
return data;
});
}, [datasets, useCloud]);
}, [searchValue, useCloud]);
const removeDatasetData = useCallback((datasetId: string, dataId: string) => {
return fetch(`/v1/datasets/${datasetId}/data/${dataId}`, {

View file

@ -7,15 +7,19 @@ import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Iterable, List, Mapping
from typing import Any, Dict, Iterable, List, Mapping
from uuid import UUID, uuid4
from pydantic import BaseModel
from cognee import config, prune, search, SearchType, visualize_graph
from cognee.low_level import setup, DataPoint
from cognee.modules.data.models import Dataset
from cognee.modules.users.models import User
from cognee.pipelines import run_tasks, Task
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges
from cognee.modules.users.methods import get_default_user
from cognee.modules.data.methods import load_or_create_datasets
class Person(DataPoint):
@ -76,18 +80,6 @@ def remove_duplicates_preserve_order(seq: Iterable[Any]) -> list[Any]:
return out
def collect_people(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
    """Flatten the "people" entries of every payload into a single list.

    Payloads without a "people" key contribute nothing; input order is preserved.
    """
    collected: list[Mapping[str, Any]] = []
    for payload in payloads:
        collected.extend(payload.get("people", []))
    return collected
def collect_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
    """Flatten the "companies" entries of every payload into a single list.

    Payloads without a "companies" key contribute nothing; input order is preserved.
    """
    collected: list[Mapping[str, Any]] = []
    for payload in payloads:
        collected.extend(payload.get("companies", []))
    return collected
def build_people_nodes(people: Iterable[Mapping[str, Any]]) -> dict:
"""Build person nodes keyed by name."""
nodes = {p["name"]: Person(name=p["name"]) for p in people if p.get("name")}
@ -176,10 +168,10 @@ def attach_employees_to_departments(
target.employees = employees
def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]:
def build_companies(data: Data) -> list[Company]:
"""Build company nodes from payloads."""
people = collect_people(payloads)
companies = collect_companies(payloads)
people = data.people
companies = data.companies
people_nodes = build_people_nodes(people)
groups = group_people_by_department(people)
dept_names = collect_declared_departments(groups, companies)
@ -192,19 +184,29 @@ def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]:
return result
def load_default_payload() -> list[Mapping[str, Any]]:
class Data(BaseModel):
    # Pydantic payload passed through the pipeline: the raw company/people
    # records plus a unique id for the run.
    id: UUID
    # Raw company dicts (loaded from COMPANIES_JSON by default).
    companies: List[Dict[str, Any]]
    # Raw person dicts (loaded from PEOPLE_JSON by default).
    people: List[Dict[str, Any]]
def load_default_payload() -> Data:
    """Load the default payload from data files.

    Reads the bundled companies/people JSON files and wraps them in a
    freshly-identified Data payload.
    """
    return Data(
        id=uuid4(),
        companies=load_json_file(COMPANIES_JSON),
        people=load_json_file(PEOPLE_JSON),
    )
def ingest_payloads(data: List[Any] | None) -> list[Company]:
def ingest_payloads(data: List[Data]) -> list[Company]:
    """Ingest payloads and build company nodes.

    Parameters:
    -----------

    - data: A list of Data payloads; only the first entry is used.

    Returns:
    --------

    The Company nodes built from the first payload.
    """
    if not data or data == [None]:
        # BUGFIX: load_default_payload() returns a single Data object, not a
        # list — keep the list shape so data[0] below works on both paths
        # (previously the fallback made data[0] raise TypeError).
        data = [load_default_payload()]
    companies = build_companies(data[0])
    return companies
@ -221,18 +223,16 @@ async def execute_pipeline() -> None:
await setup()
# Get user and dataset
user = await get_default_user()
datasets = await load_or_create_datasets(["demo_dataset"], [], user)
dataset_id = datasets[0].id
user: User = await get_default_user() # type: ignore
dataset = Dataset(id=uuid4(), name="demo_dataset")
data = load_default_payload()
# Build and run pipeline
tasks = [Task(ingest_payloads), Task(add_data_points)]
pipeline = run_tasks(tasks, dataset_id, None, user, "demo_pipeline")
pipeline = run_tasks(tasks, dataset, [data], user, "demo_pipeline")
async for status in pipeline:
logging.info("Pipeline status: %s", status)
# Post-process: index graph edges and visualize
await index_graph_edges()
await visualize_graph(str(GRAPH_HTML))
# Run query against graph

View file

@ -85,13 +85,13 @@ async def run_code_graph_pipeline(
if include_docs:
non_code_pipeline_run = run_tasks(
non_code_tasks, dataset.id, repo_path, user, "cognify_pipeline"
non_code_tasks, dataset, repo_path, user, "cognify_pipeline"
)
async for run_status in non_code_pipeline_run:
yield run_status
async for run_status in run_tasks(
tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False
tasks, dataset, repo_path, user, "cognify_code_pipeline", incremental_loading=False
):
yield run_status

View file

@ -17,6 +17,7 @@ from cognee.modules.ontology.get_default_ontology_resolver import (
get_ontology_resolver_from_env,
)
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.tasks.documents import (
check_permissions_on_dataset,
@ -208,6 +209,9 @@ async def cognify(
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
}
if user is None:
user = await get_default_user()
if temporal_cognify:
tasks = await get_temporal_tasks(user, chunker, chunk_size)
else:

View file

@ -1,13 +1,8 @@
import inspect
from functools import wraps
from uuid import UUID
from abc import abstractmethod, ABC
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List, Tuple, Type, Union
from uuid import NAMESPACE_OID, UUID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
logger = get_logger()
@ -19,121 +14,6 @@ EdgeData = Tuple[
Node = Tuple[str, NodeData] # (node_id, properties)
def record_graph_changes(func):
    """
    Decorator to record graph changes in the relationship database.

    Parameters:
    -----------

    - func: The asynchronous function to wrap, which likely modifies graph data.

    Returns:
    --------

    Returns the wrapped function that manages database relationships.
    """

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        """
        Wraps the given asynchronous function to handle database relationships.

        Tracks the caller's function and class name for context. When the wrapped function is
        called, it manages database relationships for nodes or edges by adding entries to a
        ledger and committing the changes to the database session. Errors during relationship
        addition or session commit are logged and will not disrupt the execution of the wrapped
        function.

        Parameters:
        -----------

        - *args: Positional arguments passed to the wrapped function.
        - **kwargs: Keyword arguments passed to the wrapped function.

        Returns:
        --------

        Returns the result of the wrapped function call.
        """
        db_engine = get_relational_engine()

        # Walk up the call stack to the first frame that is not this wrapper,
        # so ledger entries record which function/class triggered the change.
        frame = inspect.currentframe()
        while frame:
            if frame.f_back and frame.f_back.f_code.co_name != "wrapper":
                caller_frame = frame.f_back
                break
            frame = frame.f_back

        # NOTE(review): if the loop above exhausts without breaking,
        # caller_frame is unbound here and this raises NameError — confirm
        # that case cannot occur.
        caller_name = caller_frame.f_code.co_name
        caller_class = (
            caller_frame.f_locals.get("self", None).__class__.__name__
            if caller_frame.f_locals.get("self", None)
            else None
        )
        creator = f"{caller_class}.{caller_name}" if caller_class else caller_name

        # Run the wrapped graph mutation first; ledger recording is best-effort
        # bookkeeping performed afterwards.
        result = await func(self, *args, **kwargs)

        async with db_engine.get_async_session() as session:
            if func.__name__ == "add_nodes":
                nodes: List[DataPoint] = args[0]
                relationship_ledgers = []
                for node in nodes:
                    node_id = UUID(str(node.id))
                    # NOTE(review): uuid5 keyed only on the current UTC
                    # timestamp can collide for entries created within the same
                    # timestamp resolution — verify id uniqueness is acceptable.
                    relationship_ledgers.append(
                        GraphRelationshipLedger(
                            id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
                            source_node_id=node_id,
                            destination_node_id=node_id,
                            creator_function=f"{creator}.node",
                            node_label=getattr(node, "name", None) or str(node.id),
                        )
                    )
                try:
                    session.add_all(relationship_ledgers)
                    await session.flush()
                except Exception as e:
                    # Ledger failures are logged but never break the graph write.
                    logger.debug(f"Error adding relationship: {e}")
                    await session.rollback()
            elif func.__name__ == "add_edges":
                edges = args[0]
                relationship_ledgers = []
                for edge in edges:
                    source_id = UUID(str(edge[0]))
                    target_id = UUID(str(edge[1]))
                    rel_type = str(edge[2])
                    relationship_ledgers.append(
                        GraphRelationshipLedger(
                            id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
                            source_node_id=source_id,
                            destination_node_id=target_id,
                            creator_function=f"{creator}.{rel_type}",
                        )
                    )
                try:
                    session.add_all(relationship_ledgers)
                    await session.flush()
                except Exception as e:
                    # Ledger failures are logged but never break the graph write.
                    logger.debug(f"Error adding relationship: {e}")
                    await session.rollback()

            try:
                await session.commit()
            except Exception as e:
                logger.debug(f"Error committing session: {e}")

        return result

    return wrapper
class GraphDBInterface(ABC):
"""
Define an interface for graph database operations to be implemented by concrete classes.
@ -189,7 +69,6 @@ class GraphDBInterface(ABC):
raise NotImplementedError
@abstractmethod
@record_graph_changes
async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
"""
Add multiple nodes to the graph in a single operation.
@ -273,7 +152,6 @@ class GraphDBInterface(ABC):
raise NotImplementedError
@abstractmethod
@record_graph_changes
async def add_edges(
self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
) -> None:

View file

@ -17,7 +17,6 @@ from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
)
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
@ -378,7 +377,6 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add node: {e}")
raise
@record_graph_changes
async def add_nodes(self, nodes: List[DataPoint]) -> None:
"""
Add multiple nodes to the graph in a batch operation.
@ -675,7 +673,6 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add edge: {e}")
raise
@record_graph_changes
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
"""
Add multiple edges in a batch operation.

View file

@ -16,7 +16,6 @@ from cognee.tasks.temporal_graph.models import Timestamp
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
)
from cognee.modules.storage.utils import JSONEncoder
@ -175,7 +174,6 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params)
@record_graph_changes
@override_distributed(queued_add_nodes)
async def add_nodes(self, nodes: list[DataPoint]) -> None:
"""
@ -446,7 +444,6 @@ class Neo4jAdapter(GraphDBInterface):
return flattened
@record_graph_changes
@override_distributed(queued_add_edges)
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
"""

View file

@ -6,7 +6,6 @@ from uuid import UUID
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
NodeData,
EdgeData,
Node,
@ -229,7 +228,6 @@ class NeptuneGraphDB(GraphDBInterface):
logger.error(f"Failed to add node {node.id}: {error_msg}")
raise Exception(f"Failed to add node: {error_msg}") from e
@record_graph_changes
async def add_nodes(self, nodes: List[DataPoint]) -> None:
"""
Add multiple nodes to the graph in a single operation.
@ -534,7 +532,6 @@ class NeptuneGraphDB(GraphDBInterface):
logger.error(f"Failed to add edge {source_id} -> {target_id}: {error_msg}")
raise Exception(f"Failed to add edge: {error_msg}") from e
@record_graph_changes
async def add_edges(self, edges: List[Tuple[str, str, str, Optional[Dict[str, Any]]]]) -> None:
"""
Add multiple edges to the graph in a single operation.

View file

@ -1,5 +1,6 @@
import asyncio
from os import path
from uuid import UUID
import lancedb
from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector
@ -282,7 +283,7 @@ class LanceDBAdapter(VectorDBInterface):
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
collection = await self.get_collection(collection_name)
# Delete one at a time to avoid commit conflicts

View file

@ -1,5 +1,6 @@
import asyncio
from typing import List, Optional, get_type_hints
from uuid import UUID
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
@ -384,7 +385,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)

View file

@ -1,7 +1,7 @@
from typing import List, Protocol, Optional, Union, Any
from typing import List, Protocol, Optional, Any
from abc import abstractmethod
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
from .models.PayloadSchema import PayloadSchema
class VectorDBInterface(Protocol):
@ -127,9 +127,7 @@ class VectorDBInterface(Protocol):
raise NotImplementedError
@abstractmethod
async def delete_data_points(
self, collection_name: str, data_point_ids: Union[List[str], list[str]]
):
async def delete_data_points(self, collection_name: str, data_point_ids: List[UUID]):
"""
Delete specified data points from a collection.

View file

@ -1,4 +1,3 @@
import pickle
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, ConfigDict
from datetime import datetime, timezone
@ -28,8 +27,6 @@ class DataPoint(BaseModel):
- update_version
- to_json
- from_json
- to_pickle
- from_pickle
- to_dict
- from_dict
"""

View file

@ -1,12 +1,16 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from cognee.modules.data.models import Dataset
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.data.models import Dataset
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
from cognee.modules.users.models import User
@with_async_session
async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -> Dataset:
owner_id = user.id

View file

@ -1,4 +1,5 @@
from .generate_node_id import generate_node_id
from .generate_edge_id import generate_edge_id
from .generate_node_name import generate_node_name
from .generate_edge_name import generate_edge_name
from .generate_event_datapoint import generate_event_datapoint

View file

@ -1 +1,8 @@
from .get_formatted_graph_data import get_formatted_graph_data
from .upsert_edges import upsert_edges
from .upsert_nodes import upsert_nodes
from .get_data_related_nodes import get_data_related_nodes
from .delete_data_related_nodes import delete_data_related_nodes
from .delete_data_related_edges import delete_data_related_edges
from .get_data_related_edges import get_data_related_edges
from .delete_data_nodes_and_edges import delete_data_nodes_and_edges

View file

@ -0,0 +1,43 @@
from uuid import UUID
from typing import Dict, List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
from cognee.modules.engine.utils import generate_edge_id
from cognee.modules.graph.methods import (
delete_data_related_edges,
delete_data_related_nodes,
get_data_related_nodes,
get_data_related_edges,
)
async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID) -> None:
    """Delete the graph nodes/edges (and their vector-index entries) that belong
    exclusively to one data item within a dataset.

    Steps:
    1. Delete the affected nodes from the graph store.
    2. Delete the nodes' embeddings from each affected vector collection.
    3. Delete the related edges' embeddings from the edge-type collection.
    4. Delete the node/edge bookkeeping rows from the relational ledger.

    Parameters:
    -----------

    - dataset_id: Dataset the data item belongs to.
    - data_id: Data item whose derived graph content is being removed.
    """
    affected_nodes = await get_data_related_nodes(dataset_id, data_id)

    graph_engine = await get_graph_engine()
    await graph_engine.delete_nodes([str(node.slug) for node in affected_nodes])

    # Group nodes by the vector collection ("<type>_<field>") each of their
    # indexed fields is embedded into.
    affected_vector_collections: Dict[str, List] = {}
    for node in affected_nodes:
        for indexed_field in node.indexed_fields:
            collection_name = f"{node.type}_{indexed_field}"
            affected_vector_collections.setdefault(collection_name, []).append(node)

    vector_engine = get_vector_engine()
    # BUGFIX: use a distinct loop variable; the original reused the name
    # "affected_nodes" here, clobbering the outer list.
    for collection_name, collection_nodes in affected_vector_collections.items():
        await vector_engine.delete_data_points(
            collection_name, [node.id for node in collection_nodes]
        )

    affected_relationships = await get_data_related_edges(dataset_id, data_id)
    await vector_engine.delete_data_points(
        "EdgeType_relationship_name",
        [generate_edge_id(edge.relationship_name) for edge in affected_relationships],
    )

    await delete_data_related_nodes(data_id)
    await delete_data_related_edges(data_id)

View file

@ -0,0 +1,13 @@
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Edge
@with_async_session
async def delete_data_related_edges(data_id: UUID, session: AsyncSession):
    """Delete every Edge ledger row recorded for the given data item.

    Issues a single DELETE instead of the previous select-then-delete round
    trip, which loaded every matching row (into a variable misleadingly named
    "nodes") only to re-filter by the ids it had just fetched.
    """
    await session.execute(delete(Edge).where(Edge.data_id == data_id))

View file

@ -0,0 +1,13 @@
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Node
@with_async_session
async def delete_data_related_nodes(data_id: UUID, session: AsyncSession):
    """Delete every Node ledger row recorded for the given data item.

    Issues a single DELETE instead of the previous select-then-delete round
    trip, which loaded every matching row only to re-filter by the ids it had
    just fetched.
    """
    await session.execute(delete(Node).where(Node.data_id == data_id))

View file

@ -0,0 +1,17 @@
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Edge
@with_async_session
async def get_data_related_edges(dataset_id: UUID, data_id: UUID, session: AsyncSession):
    """Return the Edge ledger rows for this data item within this dataset,
    de-duplicated by relationship name.

    NOTE(review): .distinct(<column>) renders as DISTINCT ON, which is
    PostgreSQL-specific — confirm this query only runs against Postgres.
    """
    return (
        await session.scalars(
            select(Edge)
            .where(Edge.data_id == data_id, Edge.dataset_id == dataset_id)
            .distinct(Edge.relationship_name)
        )
    ).all()

View file

@ -0,0 +1,26 @@
from uuid import UUID
from sqlalchemy import and_, exists, select
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Node
@with_async_session
async def get_data_related_nodes(dataset_id: UUID, data_id: UUID, session: AsyncSession):
    """Return the Node ledger rows that belong *only* to this data item.

    A node is excluded when another data item in the same dataset recorded a
    node with the same slug (the correlated NOT EXISTS below), since deleting
    a shared node would break the other data item's graph.
    """
    # Alias over the nodes table for the correlated subquery that looks for the
    # same slug recorded by a *different* data item in the same dataset.
    NodeAlias = Node.__table__.alias("n2")
    subq = select(NodeAlias.c.id).where(
        and_(
            NodeAlias.c.slug == Node.slug,
            NodeAlias.c.dataset_id == dataset_id,
            NodeAlias.c.data_id != data_id,
        )
    )
    # Nodes of this data item, in this dataset, with no sharing data item.
    query_statement = select(Node).where(
        and_(Node.data_id == data_id, Node.dataset_id == dataset_id, ~exists(subq))
    )
    data_related_nodes = await session.scalars(query_statement)
    return data_related_nodes.all()

View file

@ -0,0 +1,8 @@
import uuid
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
async def set_current_user(session: AsyncSession, user_id: uuid.UUID, local: bool = False):
    """Set the PostgreSQL session variable consulted by the row-level-security
    policies (app.current_user_id).

    Parameters:
    -----------

    - session: Active async database session.
    - user_id: Id stored into app.current_user_id.
    - local: When True the setting lasts only for the current transaction
      (equivalent to SET LOCAL); otherwise it persists for the session.

    Uses set_config() with bound parameters instead of interpolating the id
    into a raw SET statement, so no SQL is ever built from the value.
    """
    await session.execute(
        text("SELECT set_config('app.current_user_id', :user_id, :local)"),
        {"user_id": str(user_id), "local": local},
    )

View file

@ -0,0 +1,57 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from typing import Any, Dict, List, Tuple
from fastapi.encoders import jsonable_encoder
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.dialects.postgresql import insert
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models.Edge import Edge
from .set_current_user import set_current_user
@with_async_session
async def upsert_edges(
    edges: List[Tuple[UUID, UUID, str, Dict[str, Any]]],
    user_id: UUID,
    data_id: UUID,
    dataset_id: UUID,
    session: AsyncSession,
):
    """
    Adds edges to the edges table.

    Parameters:
    -----------

    - edges (list): A list of (source_id, destination_id, relationship_name,
      properties) tuples to be added to the graph.
    - user_id: Owner of the rows (also used for the RLS session variable).
    - data_id: Data item the edges were derived from.
    - dataset_id: Dataset the edges belong to.
    """
    # insert().values([]) is invalid SQL — nothing to do for an empty batch.
    if not edges:
        return

    if session.get_bind().dialect.name == "postgresql":
        # Set the session-level RLS variable
        await set_current_user(session, user_id)

    # NOTE(review): the postgresql insert() dialect is used unconditionally
    # while only the RLS step is gated on the dialect — confirm non-Postgres
    # backends never reach this code path.
    upsert_statement = (
        insert(Edge)
        .values(
            [
                {
                    # Deterministic id so re-ingesting the same edge is a no-op
                    # (paired with on_conflict_do_nothing below).
                    "id": uuid5(
                        NAMESPACE_OID,
                        str(user_id) + str(dataset_id) + str(edge[0]) + str(edge[2]) + str(edge[1]),
                    ),
                    "user_id": user_id,
                    "data_id": data_id,
                    "dataset_id": dataset_id,
                    "source_node_id": edge[0],
                    "destination_node_id": edge[1],
                    "relationship_name": edge[2],
                    "label": edge[2],
                    "props": jsonable_encoder(edge[3]),
                }
                for edge in edges
            ]
        )
        .on_conflict_do_nothing(index_elements=["id"])
    )

    await session.execute(upsert_statement)
    await session.commit()

View file

@ -0,0 +1,51 @@
from typing import List
from uuid import NAMESPACE_OID, UUID, uuid5
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.dialects.postgresql import insert
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Node
from .set_current_user import set_current_user
@with_async_session
async def upsert_nodes(
    nodes: List[DataPoint], user_id: UUID, dataset_id: UUID, data_id: UUID, session: AsyncSession
):
    """
    Adds nodes to the nodes table.

    Parameters:
    -----------

    - nodes (list): A list of DataPoint nodes to be added to the graph.
    - user_id: Owner of the rows (also used for the RLS session variable).
    - dataset_id: Dataset the nodes belong to.
    - data_id: Data item the nodes were derived from.
    """
    # insert().values([]) is invalid SQL — nothing to do for an empty batch.
    if not nodes:
        return

    if session.get_bind().dialect.name == "postgresql":
        # Set the session-level RLS variable
        await set_current_user(session, user_id)

    # NOTE(review): the postgresql insert() dialect is used unconditionally
    # while only the RLS step is gated on the dialect — confirm non-Postgres
    # backends never reach this code path.
    upsert_statement = (
        insert(Node)
        .values(
            [
                {
                    # Deterministic id so re-ingesting the same node is a no-op
                    # (paired with on_conflict_do_nothing below).
                    "id": uuid5(
                        NAMESPACE_OID, str(user_id) + str(dataset_id) + str(data_id) + str(node.id)
                    ),
                    # The node's graph-side identity, kept separately from the
                    # ledger row's composite id above.
                    "slug": node.id,
                    "user_id": user_id,
                    "data_id": data_id,
                    "dataset_id": dataset_id,
                    "type": node.type,
                    "indexed_fields": DataPoint.get_embeddable_property_names(node),
                    "label": getattr(node, "label", getattr(node, "name", str(node.id))),
                }
                for node in nodes
            ]
        )
        .on_conflict_do_nothing(index_elements=["id"])
    )

    await session.execute(upsert_statement)
    await session.commit()

View file

@ -0,0 +1,50 @@
from sqlalchemy import (
# event,
String,
JSON,
UUID,
)
# from sqlalchemy.schema import DDL
from sqlalchemy.orm import Mapped, mapped_column
from cognee.infrastructure.databases.relational import Base
class Edge(Base):
    """Relational ledger row mirroring one graph edge, used to track which
    dataset/data item produced it so deletes can be scoped precisely."""

    __tablename__ = "edges"

    id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)
    # Owner of the row; also the key the (commented-out) RLS policy filters on.
    user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    # Data item the edge was derived from; indexed for delete-by-data lookups.
    data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), index=True, nullable=False)
    dataset_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    relationship_name: Mapped[str | None] = mapped_column(String(255))
    label: Mapped[str | None] = mapped_column(String(255))
    props: Mapped[dict | None] = mapped_column(JSON)

    # Kept for a possible future HASH partitioning of the table by user.
    # __table_args__ = (
    #     {"postgresql_partition_by": "HASH (user_id)"},  # partitioning by user
    # )


# Enable row-level security (RLS) for edges — currently disabled; see
# set_current_user(), which already prepares the session variable these
# policies would consult.
# enable_edge_rls = DDL("""
#     ALTER TABLE edges ENABLE ROW LEVEL SECURITY;
# """)

# create_user_isolation_policy = DDL("""
#     CREATE POLICY user_isolation_policy
#     ON edges
#     USING (user_id = current_setting('app.current_user_id')::uuid)
#     WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
# """)

# event.listen(Edge.__table__, "after_create", enable_edge_rls)
# event.listen(Edge.__table__, "after_create", create_user_isolation_policy)

View file

@ -0,0 +1,53 @@
from sqlalchemy import (
Index,
# event,
String,
JSON,
UUID,
)
# from sqlalchemy.schema import DDL
from sqlalchemy.orm import Mapped, mapped_column
from cognee.infrastructure.databases.relational import Base
class Node(Base):
    """Relational record of a graph node, scoped to user/dataset/data.

    Written by `upsert_nodes` alongside the graph store so that all nodes
    produced from one ingested Data item can be found and deleted through SQL.
    """

    __tablename__ = "nodes"

    # Surrogate key: built as uuid5(user_id + dataset_id + data_id + node id),
    # so the same graph node ingested under different Data items yields
    # distinct rows (see the upsert's id construction).
    id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)
    # The node's own id in the graph store (`node.id` at upsert time).
    slug: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    dataset_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    # Human-readable label; populated from the node's label/name/id at upsert.
    label: Mapped[str] = mapped_column(String(255))
    type: Mapped[str] = mapped_column(String(255), nullable=False)
    # Names of the node's embeddable properties, taken from
    # DataPoint.get_embeddable_property_names at upsert time.
    indexed_fields: Mapped[list] = mapped_column(JSON)
    # props: Mapped[dict] = mapped_column(JSON)

    __table_args__ = (
        # Composite indexes — presumably serving lookups by (dataset, slug)
        # and delete-by-data via (dataset, data); confirm against the delete path.
        Index("index_node_dataset_slug", "dataset_id", "slug"),
        Index("index_node_dataset_data", "dataset_id", "data_id"),
        # {"postgresql_partition_by": "HASH (user_id)"},  # HASH partitioning on user_id
    )
# Enable row-level security (RLS) for nodes
# enable_node_rls = DDL("""
# ALTER TABLE nodes ENABLE ROW LEVEL SECURITY;
# """)
# create_user_isolation_policy = DDL("""
# CREATE POLICY user_isolation_policy
# ON nodes
# USING (user_id = current_setting('app.current_user_id')::uuid)
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
# """)
# event.listen(Node.__table__, "after_create", enable_node_rls)
# event.listen(Node.__table__, "after_create", create_user_isolation_policy)

View file

@ -0,0 +1,2 @@
from .Edge import Edge
from .Node import Node

View file

@ -134,8 +134,8 @@ def _targets_generator(
async def get_graph_from_model(
data_point: DataPoint,
added_nodes: Dict[str, bool],
added_edges: Dict[str, bool],
added_nodes: Optional[Dict[str, bool]] = None,
added_edges: Optional[Dict[str, bool]] = None,
visited_properties: Optional[Dict[str, bool]] = None,
include_root: bool = True,
) -> Tuple[List[DataPoint], List[Tuple[str, str, str, Dict[str, Any]]]]:
@ -152,6 +152,12 @@ async def get_graph_from_model(
Returns:
Tuple of (nodes, edges) extracted from the model
"""
if not added_nodes:
added_nodes = {}
if not added_edges:
added_edges = {}
if str(data_point.id) in added_nodes:
return [], []

View file

@ -1,6 +1,6 @@
import asyncio
from uuid import UUID
from typing import Union
from typing import Dict, Optional, Union
from cognee.modules.pipelines.layers.setup_and_check_environment import (
setup_and_check_environment,
@ -29,12 +29,13 @@ update_status_lock = asyncio.Lock()
async def run_pipeline(
tasks: list[Task],
data=None,
datasets: Union[str, list[str], list[UUID]] = None,
user: User = None,
datasets: Optional[Union[str, list[str], list[UUID]]] = None,
user: Optional[User] = None,
pipeline_name: str = "custom_pipeline",
vector_db_config: dict = None,
graph_db_config: dict = None,
vector_db_config: Optional[dict] = None,
graph_db_config: Optional[dict] = None,
incremental_loading: bool = False,
context: Optional[Dict] = None,
):
validate_pipeline_tasks(tasks)
await setup_and_check_environment(vector_db_config, graph_db_config)
@ -48,8 +49,8 @@ async def run_pipeline(
tasks=tasks,
data=data,
pipeline_name=pipeline_name,
context={"dataset": dataset},
incremental_loading=incremental_loading,
context=context,
):
yield run_info
@ -58,16 +59,16 @@ async def run_pipeline_per_dataset(
dataset: Dataset,
user: User,
tasks: list[Task],
data=None,
data: Optional[list[Data]] = None,
pipeline_name: str = "custom_pipeline",
context: dict = None,
incremental_loading=False,
context: Optional[Dict] = None,
):
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
await set_database_global_context_variables(dataset.id, dataset.owner_id)
if not data:
data: list[Data] = await get_dataset_data(dataset_id=dataset.id)
data = await get_dataset_data(dataset_id=dataset.id)
process_pipeline_status = await check_pipeline_run_qualification(dataset, data, pipeline_name)
if process_pipeline_status:
@ -77,7 +78,7 @@ async def run_pipeline_per_dataset(
return
pipeline_run = run_tasks(
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
tasks, dataset, data, user, pipeline_name, context, incremental_loading
)
async for pipeline_run_info in pipeline_run:

View file

@ -1,12 +1,12 @@
import os
import asyncio
from uuid import UUID
from typing import Any, List
from functools import wraps
from typing import Any, Dict, List, Optional
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Dataset
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
from cognee.modules.users.models import User
from cognee.shared.logging_utils import get_logger
@ -24,7 +24,6 @@ from cognee.modules.pipelines.operations import (
log_pipeline_run_complete,
log_pipeline_run_error,
)
from .run_tasks_with_telemetry import run_tasks_with_telemetry
from .run_tasks_data_item import run_tasks_data_item
from ..tasks.task import Task
@ -54,25 +53,18 @@ def override_run_tasks(new_gen):
@override_run_tasks(run_tasks_distributed)
async def run_tasks(
tasks: List[Task],
dataset_id: UUID,
data: List[Any] = None,
user: User = None,
dataset: Dataset,
data: Optional[List[Any]] = None,
user: Optional[User] = None,
pipeline_name: str = "unknown_pipeline",
context: dict = None,
context: Optional[Dict] = None,
incremental_loading: bool = False,
):
if not user:
user = await get_default_user()
# Get Dataset object
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
from cognee.modules.data.models import Dataset
dataset = await session.get(Dataset, dataset_id)
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset.id, data)
pipeline_run_id = pipeline_run.pipeline_run_id
yield PipelineRunStarted(
@ -99,7 +91,12 @@ async def run_tasks(
pipeline_name,
pipeline_id,
pipeline_run_id,
context,
{
**(context or {}),
"user": user,
"data": data_item,
"dataset": dataset,
},
user,
incremental_loading,
)
@ -121,7 +118,7 @@ async def run_tasks(
)
await log_pipeline_run_complete(
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
pipeline_run_id, pipeline_id, pipeline_name, dataset.id, data
)
yield PipelineRunCompleted(
@ -141,7 +138,7 @@ async def run_tasks(
except Exception as error:
await log_pipeline_run_error(
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
pipeline_run_id, pipeline_id, pipeline_name, dataset.id, data, error
)
yield PipelineRunErrored(

View file

@ -1,10 +1,9 @@
import asyncio
from typing import Type, List, Optional
from typing import Dict, Type, List, Optional
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.tasks.storage import index_graph_edges
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.get_default_ontology_resolver import (
@ -32,6 +31,7 @@ async def integrate_chunk_graphs(
chunk_graphs: list,
graph_model: Type[BaseModel],
ontology_resolver: BaseOntologyResolver,
context: Dict,
) -> List[DocumentChunk]:
"""Integrate chunk graphs with ontology validation and store in databases.
@ -85,19 +85,19 @@ async def integrate_chunk_graphs(
)
if len(graph_nodes) > 0:
await add_data_points(graph_nodes)
await add_data_points(graph_nodes, context)
if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)
await index_graph_edges(graph_edges)
return data_chunks
async def extract_graph_from_data(
data_chunks: List[DocumentChunk],
context: Dict,
graph_model: Type[BaseModel],
config: Config = None,
config: Optional[Config] = None,
custom_prompt: Optional[str] = None,
) -> List[DocumentChunk]:
"""
@ -136,16 +136,16 @@ async def extract_graph_from_data(
and ontology_config.ontology_resolver
and ontology_config.matching_strategy
):
config: Config = {
config = {
"ontology_config": {
"ontology_resolver": get_ontology_resolver_from_env(**ontology_config.to_dict())
}
}
else:
config: Config = {
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
}
config = {"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}}
ontology_resolver = config["ontology_config"]["ontology_resolver"]
return await integrate_chunk_graphs(data_chunks, chunk_graphs, graph_model, ontology_resolver)
return await integrate_chunk_graphs(
data_chunks, chunk_graphs, graph_model, ontology_resolver, context
)

View file

@ -1,8 +1,10 @@
import asyncio
from typing import List
from typing import Any, Dict, List, Optional
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.methods import upsert_edges, upsert_nodes
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
from cognee.modules.users.models import User
from .index_data_points import index_data_points
from .index_graph_edges import index_graph_edges
from cognee.tasks.storage.exceptions import (
@ -10,7 +12,9 @@ from cognee.tasks.storage.exceptions import (
)
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
async def add_data_points(
data_points: List[DataPoint], context: Optional[Dict[str, Any]] = None
) -> List[DataPoint]:
"""
Add a batch of data points to the graph database by extracting nodes and edges,
deduplicating them, and indexing them for retrieval.
@ -33,8 +37,15 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
- Deduplicates nodes and edges across all results.
- Updates the node index via `index_data_points`.
- Inserts nodes and edges into the graph engine.
- Optionally updates the edge index via `index_graph_edges`.
"""
user: Optional[User] = None
data = None
dataset = None
if context:
data = context["data"]
dataset = context["dataset"]
user = context["user"]
if not isinstance(data_points, list):
raise InvalidDataPointsInAddDataPointsError("data_points must be a list.")
@ -71,6 +82,10 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
await graph_engine.add_nodes(nodes)
await index_data_points(nodes)
if user and dataset and data:
await upsert_nodes(nodes, user_id=user.id, dataset_id=dataset.id, data_id=data.id)
await upsert_edges(edges, user_id=user.id, dataset_id=dataset.id, data_id=data.id)
await graph_engine.add_edges(edges)
await index_graph_edges(edges)

View file

@ -1,13 +1,26 @@
import os
from typing import List
from datetime import datetime
from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from cognee.infrastructure.files.storage import get_file_storage
from cognee.modules.data.models import Data
async def build_graph_with_temporal_awareness(text_list):
url = os.getenv("GRAPH_DATABASE_URL")
password = os.getenv("GRAPH_DATABASE_PASSWORD")
async def build_graph_with_temporal_awareness(data: List[Data]):
text_list: List[str] = []
for text_data in data:
file_dir = os.path.dirname(text_data.raw_data_location)
file_name = os.path.basename(text_data.raw_data_location)
file_storage = get_file_storage(file_dir)
async with file_storage.open(file_name, "r") as file:
text_list.append(file.read())
url = os.getenv("GRAPH_DATABASE_URL", "")
password = os.getenv("GRAPH_DATABASE_PASSWORD", "")
graphiti = Graphiti(url, "neo4j", password)
await graphiti.build_indices_and_constraints()
@ -22,4 +35,5 @@ async def build_graph_with_temporal_awareness(text_list):
reference_time=datetime.now(),
)
print(f"Added text: {text[:35]}...")
return graphiti

View file

@ -0,0 +1,107 @@
import os
import pathlib
from typing import List
from uuid import uuid4
import cognee
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.models import Data, Dataset
from cognee.modules.engine.operations.setup import setup
from cognee.modules.graph.methods import delete_data_nodes_and_edges
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.tasks.storage import add_data_points
logger = get_logger()
async def main():
    """Verify per-Data deletion of graph rows built from custom DataPoint models.

    Two Person data points that share an organization node are ingested under
    the same dataset but distinct Data ids; deleting each Data id must remove
    only the nodes/edges attributed to it, as checked by exact node/edge counts.
    """
    # Keep this test's data and system state in dedicated directories so it
    # cannot interfere with other tests or a local cognee installation.
    data_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_delete_custom_graph")
        ).resolve()
    )
    cognee.config.data_root_directory(data_directory_path)

    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_delete_custom_graph")
        ).resolve()
    )
    cognee.config.system_root_directory(cognee_directory_path)

    # Start from a clean slate: wipe stored data and system metadata.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    await setup()

    # Minimal custom graph schema: people linked to organizations.
    class Organization(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    class ForProfit(Organization):
        name: str = "For-Profit"
        metadata: dict = {"index_fields": ["name"]}

    class NonProfit(Organization):
        name: str = "Non-Profit"
        metadata: dict = {"index_fields": ["name"]}

    class Person(DataPoint):
        name: str
        works_for: List[Organization]
        metadata: dict = {"index_fields": ["name"]}

    companyA = ForProfit(name="Company A")
    companyB = NonProfit(name="Company B")

    # companyB is referenced by both people, so it is ingested under both
    # Data ids below — it must survive the first delete.
    person1 = Person(name="John", works_for=[companyA, companyB])
    person2 = Person(name="Jane", works_for=[companyB])

    user: User = await get_default_user()  # type: ignore

    # Same dataset, but each person is attributed to a distinct Data id so
    # their graph rows can be deleted independently.
    dataset = Dataset(id=uuid4())
    data1 = Data(id=uuid4())
    data2 = Data(id=uuid4())

    await add_data_points(
        [person1],
        context={
            "user": user,
            "dataset": dataset,
            "data": data1,
        },
    )

    await add_data_points(
        [person2],
        context={
            "user": user,
            "dataset": dataset,
            "data": data2,
        },
    )

    from cognee.infrastructure.databases.graph import get_graph_engine

    graph_engine = await get_graph_engine()

    # Expected full graph: 2 people + 2 organizations = 4 nodes, 3 works_for edges.
    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) == 4 and len(edges) == 3, (
        "Nodes and edges are not correctly added to the graph."
    )

    # Delete everything attributed to data1; per the asserted counts only
    # Jane and the shared Company B (with one edge) should remain.
    await delete_data_nodes_and_edges(dataset.id, data1.id)  # type: ignore

    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly."

    # Deleting data2 must empty the graph completely.
    await delete_data_nodes_and_edges(dataset.id, data2.id)  # type: ignore

    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) == 0 and len(edges) == 0, "Nodes and edges are not deleted."
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -0,0 +1,77 @@
import os
import pathlib
import cognee
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.data.methods import delete_data, get_dataset_data
from cognee.modules.engine.operations.setup import setup
from cognee.modules.graph.methods import (
delete_data_nodes_and_edges,
)
from cognee.shared.logging_utils import get_logger
logger = get_logger()
async def main():
    """Verify deletion of one document's graph rows in the default cognify flow.

    Ingests two short texts, cognifies them, deletes the nodes/edges of the
    first document, and checks that (a) the remaining document's graph
    survives and (b) the graph actually shrank.
    """
    # Keep this test's data and system state in dedicated directories so it
    # cannot interfere with other tests or a local cognee installation.
    data_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph")
        ).resolve()
    )
    cognee.config.data_root_directory(data_directory_path)

    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_delete_default_graph")
        ).resolve()
    )
    cognee.config.system_root_directory(cognee_directory_path)

    # Start from a clean slate: wipe stored data and system metadata.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    await setup()

    # Sanity check: prune must have removed the vector collections.
    vector_engine = get_vector_engine()
    assert not await vector_engine.has_collection("EdgeType_relationship_name")
    assert not await vector_engine.has_collection("Entity_name")

    await cognee.add(
        "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
    )
    await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.")

    cognify_result: dict = await cognee.cognify()

    dataset_id = list(cognify_result.keys())[0]
    dataset_data = await get_dataset_data(dataset_id)
    # The document whose graph rows will be deleted.
    added_data = dataset_data[0]

    file_path = os.path.join(
        pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html"
    )
    await visualize_graph(file_path)

    graph_engine = await get_graph_engine()
    nodes, edges = await graph_engine.get_graph_data()
    # This checks that the graph was BUILT (nothing has been deleted yet);
    # the previous message ("not deleted") described the wrong failure.
    assert len(nodes) >= 12 and len(edges) >= 18, (
        "Nodes and edges are not correctly added to the graph."
    )
    # Capture pre-delete sizes so we can assert the delete really removed rows:
    # a lower bound alone (>= 8) would also pass if deletion silently did nothing.
    nodes_before, edges_before = len(nodes), len(edges)

    await delete_data_nodes_and_edges(dataset_id, added_data.id)  # type: ignore
    await delete_data(added_data)

    file_path = os.path.join(
        pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html"
    )
    await visualize_graph(file_path)

    nodes, edges = await graph_engine.get_graph_data()
    # The second document's graph must survive the targeted delete...
    assert len(nodes) >= 8 and len(edges) >= 10, (
        "Remaining nodes and edges were unexpectedly removed."
    )
    # ...while the first document's nodes/edges must actually be gone.
    assert len(nodes) < nodes_before and len(edges) < edges_before, (
        "Nodes and edges are not deleted."
    )
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -100,7 +100,7 @@ async def main():
pipeline = run_tasks(
[Task(ingest_files), Task(add_data_points)],
dataset_id=datasets[0].id,
dataset=datasets[0],
data=data,
incremental_loading=False,
)

View file

@ -1,13 +1,18 @@
import os
import json
import asyncio
from typing import Dict, List
from uuid import NAMESPACE_OID, UUID, uuid4, uuid5
from neo4j import exceptions
from pydantic import BaseModel
from cognee import prune
# from cognee import visualize_graph
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.low_level import setup, DataPoint
from cognee.modules.data.models import Data, Dataset
from cognee.modules.users.methods import get_default_user
from cognee.pipelines import run_tasks, Task
from cognee.tasks.storage import add_data_points
@ -20,7 +25,6 @@ products_aggregator_node = Products()
class Product(DataPoint):
id: str
name: str
type: str
price: float
@ -36,7 +40,6 @@ preferences_aggregator_node = Preferences()
class Preference(DataPoint):
id: str
name: str
value: str
is_type: Preferences = preferences_aggregator_node
@ -50,7 +53,6 @@ customers_aggregator_node = Customers()
class Customer(DataPoint):
id: str
name: str
has_preference: list[Preference]
purchased: list[Product]
@ -58,17 +60,14 @@ class Customer(DataPoint):
is_type: Customers = customers_aggregator_node
def ingest_files():
customers_file_path = os.path.join(os.path.dirname(__file__), "customers.json")
customers = json.loads(open(customers_file_path, "r").read())
def ingest_customers(data):
customers_data_points = {}
products_data_points = {}
preferences_data_points = {}
for customer in customers:
for customer in data[0].customers:
new_customer = Customer(
id=customer["id"],
id=uuid5(NAMESPACE_OID, customer["id"]),
name=customer["name"],
liked=[],
purchased=[],
@ -79,7 +78,7 @@ def ingest_files():
for product in customer["products"]:
if product["id"] not in products_data_points:
products_data_points[product["id"]] = Product(
id=product["id"],
id=uuid5(NAMESPACE_OID, product["id"]),
type=product["type"],
name=product["name"],
price=product["price"],
@ -96,7 +95,7 @@ def ingest_files():
for preference in customer["preferences"]:
if preference["id"] not in preferences_data_points:
preferences_data_points[preference["id"]] = Preference(
id=preference["id"],
id=uuid5(NAMESPACE_OID, preference["id"]),
name=preference["name"],
value=preference["value"],
)
@ -104,7 +103,7 @@ def ingest_files():
new_preference = preferences_data_points[preference["id"]]
new_customer.has_preference.append(new_preference)
return customers_data_points.values()
return list(customers_data_points.values())
async def main():
@ -113,7 +112,28 @@ async def main():
await setup()
pipeline = run_tasks([Task(ingest_files), Task(add_data_points)])
# Get user and dataset
user: User = await get_default_user() # type: ignore
main_dataset = Dataset(id=uuid4(), name="demo_dataset")
customers_file_path = os.path.join(os.path.dirname(__file__), "customers.json")
customers = json.loads(open(customers_file_path, "r").read())
class Data(BaseModel):
id: UUID
customers: List[Dict]
data = Data(
id=uuid4(),
customers=customers,
)
pipeline = run_tasks(
[Task(ingest_customers), Task(add_data_points)],
dataset=main_dataset,
data=[data],
user=user,
)
async for status in pipeline:
print(status)

View file

@ -1,6 +1,7 @@
import asyncio
import cognee
from cognee.modules.data.methods import get_dataset_data, get_datasets
from cognee.shared.logging_utils import setup_logging, ERROR
from cognee.modules.pipelines import Task, run_tasks
from cognee.tasks.temporal_awareness import build_graph_with_temporal_awareness
@ -35,10 +36,13 @@ async def main():
await cognee.add(text)
tasks = [
Task(build_graph_with_temporal_awareness, text_list=text_list),
Task(build_graph_with_temporal_awareness),
]
pipeline = run_tasks(tasks, user=user)
datasets = await get_datasets(user.id)
dataset_data = await get_dataset_data(datasets[0].id) # type: ignore
pipeline = run_tasks(tasks, dataset=datasets[0], data=dataset_data, user=user)
async for result in pipeline:
print(result)

View file

@ -483,7 +483,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"id": "7c431fdef4921ae0",
"metadata": {
"ExecuteTime": {
@ -535,7 +535,7 @@
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
" ]\n",
"\n",
" pipeline_run = run_tasks(tasks, dataset.id, data_documents, user, \"cognify_pipeline\", context={\"dataset\": dataset})\n",
" pipeline_run = run_tasks(tasks, dataset, data_documents, user, \"cognify_pipeline\", context={\"dataset\": dataset})\n",
" pipeline_run_status = None\n",
"\n",
" async for run_status in pipeline_run:\n",
@ -1831,7 +1831,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "cognee",
"language": "python",
"name": "python3"
},
@ -1845,7 +1845,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.10.13"
}
},
"nbformat": 4,

7
poetry.lock generated
View file

@ -9310,6 +9310,13 @@ optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"},
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"},
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"},
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"},
{file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"},
{file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"},
{file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"},
{file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
{file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
{file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},

2
uv.lock generated
View file

@ -856,7 +856,7 @@ wheels = [
[[package]]
name = "cognee"
version = "0.3.4"
version = "0.3.5"
source = { editable = "." }
dependencies = [
{ name = "aiofiles" },