feat: add delete by document (#668)


## Description
Adds delete-by-document support: given a file, URL, or raw text, cognee removes the matching document and all of its related nodes from the graph, vector, and relational databases. The default `soft` mode removes the document, its chunks, summaries, and orphaned entities/entity types; `hard` mode also deletes degree-one entity nodes. A new `DELETE` API route, a graph relationship ledger table, and a deletion test job are included.
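
A minimal usage sketch of the new API, based on the `delete` signature added in this PR (the file path and text below are illustrative):

```python
import asyncio

import cognee


async def main():
    # Soft delete (default): removes the document, its chunks, summaries,
    # and any entities/entity types orphaned by the removal, across the
    # graph, vector, and relational stores.
    await cognee.delete("file:///tmp/example.txt", dataset_name="main_dataset")

    # Hard mode additionally removes degree-one Entity/EntityType nodes.
    await cognee.delete("some previously added text", mode="hard")


asyncio.run(main())
```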

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Commit 9ba12b25ef (parent af276b8999), authored by Daniel Molnar on 2025-04-17 15:42:10 +02:00, committed by GitHub.
30 changed files with 3435 additions and 258 deletions

View file

@ -162,6 +162,33 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: poetry run python ./cognee/tests/test_deduplication.py
run-deletion-test:
name: Deletion Test
runs-on: ubuntu-22.04
steps:
- name: Check out
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run Deletion Tests
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: poetry run python ./cognee/tests/test_deletion.py
run-s3-bucket-test:
name: S3 Bucket Test
runs-on: ubuntu-22.04

View file

@ -1,4 +1,5 @@
from .api.v1.add import add
from .api.v1.delete import delete
from .api.v1.cognify import cognify
from .api.v1.config.config import config
from .api.v1.datasets.datasets import datasets

View file

@ -43,11 +43,7 @@ def get_add_router() -> APIRouter:
return await cognee_add(file_data)
else:
await cognee_add(
data,
datasetId,
user=user,
)
await cognee_add(data, datasetId, user=user)
except Exception as error:
return JSONResponse(status_code=409, content={"error": str(error)})

View file

@ -58,7 +58,9 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
Task(classify_documents),
Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
Task(
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
extract_graph_from_data,
graph_model=KnowledgeGraph,
task_config={"batch_size": 50},
),
Task(
summarize_text,

View file

@ -0,0 +1 @@
from .delete import delete

View file

@ -0,0 +1,247 @@
from typing import Union, BinaryIO, List
from cognee.modules.ingestion import classify
from cognee.infrastructure.databases.relational import get_relational_engine
from sqlalchemy import select
from sqlalchemy.sql import delete as sql_delete
from cognee.modules.data.models import Data, DatasetData, Dataset
from cognee.infrastructure.databases.graph import get_graph_engine
from io import StringIO, BytesIO
import hashlib
import asyncio
from uuid import UUID
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from .exceptions import DocumentNotFoundError, DatasetNotFoundError, DocumentSubgraphNotFoundError
from cognee.shared.logging_utils import get_logger
logger = get_logger()
def get_text_content_hash(text: str) -> str:
encoded_text = text.encode("utf-8")
return hashlib.md5(encoded_text).hexdigest()
async def delete(
data: Union[BinaryIO, List[BinaryIO], str, List[str]],
dataset_name: str = "main_dataset",
mode: str = "soft",
):
"""Delete a document and all its related nodes from both relational and graph databases.
Args:
data: The data to delete (file, URL, or text)
dataset_name: Name of the dataset to delete from
mode: "soft" (default) or "hard" - hard mode also deletes degree-one entity nodes
"""
# Handle different input types
if isinstance(data, str):
if data.startswith("file://"): # It's a file path
with open(data.replace("file://", ""), mode="rb") as file:
classified_data = classify(file)
content_hash = classified_data.get_metadata()["content_hash"]
return await delete_single_document(content_hash, dataset_name, mode)
elif data.startswith("http"): # It's a URL
import requests
response = requests.get(data)
response.raise_for_status()
file_data = BytesIO(response.content)
classified_data = classify(file_data)
content_hash = classified_data.get_metadata()["content_hash"]
return await delete_single_document(content_hash, dataset_name, mode)
else: # It's a text string
content_hash = get_text_content_hash(data)
classified_data = classify(data)
return await delete_single_document(content_hash, dataset_name, mode)
elif isinstance(data, list):
# Handle list of inputs sequentially
results = []
for item in data:
result = await delete(item, dataset_name, mode)
results.append(result)
return {"status": "success", "message": "Multiple documents deleted", "results": results}
else: # It's already a BinaryIO
data.seek(0) # Ensure we're at the start of the file
classified_data = classify(data)
content_hash = classified_data.get_metadata()["content_hash"]
return await delete_single_document(content_hash, dataset_name, mode)
async def delete_single_document(content_hash: str, dataset_name: str, mode: str = "soft"):
"""Delete a single document by its content hash."""
# Delete from graph database
deletion_result = await delete_document_subgraph(content_hash, mode)
logger.info(f"Deletion result: {deletion_result}")
# Get the deleted node IDs and convert to UUID
deleted_node_ids = []
for node_id in deletion_result["deleted_node_ids"]:
try:
# Handle both string and UUID formats
if isinstance(node_id, str):
# Remove any hyphens if present
node_id = node_id.replace("-", "")
deleted_node_ids.append(UUID(node_id))
else:
deleted_node_ids.append(node_id)
except Exception as e:
logger.error(f"Error converting node ID {node_id} to UUID: {e}")
continue
# Delete from vector database
vector_engine = get_vector_engine()
# Determine vector collections dynamically
subclasses = get_all_subclasses(DataPoint)
vector_collections = []
for subclass in subclasses:
index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
for field_name in index_fields:
vector_collections.append(f"{subclass.__name__}_{field_name}")
# If no collections found, use default collections
if not vector_collections:
vector_collections = [
"DocumentChunk_text",
"EdgeType_relationship_name",
"EntityType_name",
"Entity_name",
"TextDocument_name",
"TextSummary_text",
]
# Delete records from each vector collection that exists
for collection in vector_collections:
if await vector_engine.has_collection(collection):
await vector_engine.delete_data_points(
collection, [str(node_id) for node_id in deleted_node_ids]
)
# Delete from relational database
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
# Update graph_relationship_ledger with deleted_at timestamps
from sqlalchemy import update, and_, or_
from datetime import datetime
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
update_stmt = (
update(GraphRelationshipLedger)
.where(
or_(
GraphRelationshipLedger.source_node_id.in_(deleted_node_ids),
GraphRelationshipLedger.destination_node_id.in_(deleted_node_ids),
)
)
.values(deleted_at=datetime.now())
)
await session.execute(update_stmt)
# Get the data point
data_point = (
await session.execute(select(Data).filter(Data.content_hash == content_hash))
).scalar_one_or_none()
if data_point is None:
raise DocumentNotFoundError(
f"Document not found in relational DB with content hash: {content_hash}"
)
doc_id = data_point.id
# Get the dataset
dataset = (
await session.execute(select(Dataset).filter(Dataset.name == dataset_name))
).scalar_one_or_none()
if dataset is None:
raise DatasetNotFoundError(f"Dataset not found: {dataset_name}")
# Delete from dataset_data table
dataset_delete_stmt = sql_delete(DatasetData).where(
DatasetData.data_id == doc_id, DatasetData.dataset_id == dataset.id
)
await session.execute(dataset_delete_stmt)
# Check if the document is in any other datasets
remaining_datasets = (
await session.execute(select(DatasetData).filter(DatasetData.data_id == doc_id))
).scalars().first()
# If the document is not in any other datasets, delete it from the data table
if remaining_datasets is None:
data_delete_stmt = sql_delete(Data).where(Data.id == doc_id)
await session.execute(data_delete_stmt)
await session.commit()
return {
"status": "success",
"message": "Document deleted from both graph and relational databases",
"graph_deletions": deletion_result["deleted_counts"],
"content_hash": content_hash,
"dataset": dataset_name,
"deleted_node_ids": [
str(node_id) for node_id in deleted_node_ids
], # Convert back to strings for response
}
async def delete_document_subgraph(content_hash: str, mode: str = "soft"):
"""Delete a document and all its related nodes in the correct order."""
graph_db = await get_graph_engine()
subgraph = await graph_db.get_document_subgraph(content_hash)
if not subgraph:
raise DocumentSubgraphNotFoundError(f"Document not found with content hash: {content_hash}")
# Delete in the correct order to maintain graph integrity
deletion_order = [
("orphan_entities", "orphaned entities"),
("orphan_types", "orphaned entity types"),
(
"made_from_nodes",
"made_from nodes",
), # Move before chunks since summaries are connected to chunks
("chunks", "document chunks"),
("document", "document"),
]
deleted_counts = {}
deleted_node_ids = []
for key, description in deletion_order:
nodes = subgraph[key]
if nodes:
for node in nodes:
node_id = node["id"]
await graph_db.delete_node(node_id)
deleted_node_ids.append(node_id)
deleted_counts[description] = len(nodes)
# If hard mode, also delete degree-one nodes
if mode == "hard":
# Get and delete degree one entity nodes
degree_one_entity_nodes = await graph_db.get_degree_one_nodes("Entity")
for node in degree_one_entity_nodes:
await graph_db.delete_node(node["id"])
deleted_node_ids.append(node["id"])
deleted_counts["degree_one_entities"] = deleted_counts.get("degree_one_entities", 0) + 1
# Get and delete degree one entity types
degree_one_entity_types = await graph_db.get_degree_one_nodes("EntityType")
for node in degree_one_entity_types:
await graph_db.delete_node(node["id"])
deleted_node_ids.append(node["id"])
deleted_counts["degree_one_types"] = deleted_counts.get("degree_one_types", 0) + 1
return {
"status": "success",
"deleted_counts": deleted_counts,
"content_hash": content_hash,
"deleted_node_ids": deleted_node_ids,
}
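
For reference, raw text inputs are matched by content hash; a minimal sketch of the convention used above (MD5 over the UTF-8 text, with graph document nodes named `text_<hash>`):

```python
import hashlib


def get_text_content_hash(text: str) -> str:
    # Same hashing delete() applies to raw text input.
    return hashlib.md5(text.encode("utf-8")).hexdigest()


# Graph adapters look the document up by name == f"text_{content_hash}".
print(f"text_{get_text_content_hash('some previously added text')}")
```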

View file

@ -0,0 +1,38 @@
from cognee.exceptions import CogneeApiError
from fastapi import status
class DocumentNotFoundError(CogneeApiError):
"""Raised when a document cannot be found in the database."""
def __init__(
self,
message: str = "Document not found in database.",
name: str = "DocumentNotFoundError",
status_code: int = status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class DatasetNotFoundError(CogneeApiError):
"""Raised when a dataset cannot be found."""
def __init__(
self,
message: str = "Dataset not found.",
name: str = "DatasetNotFoundError",
status_code: int = status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
class DocumentSubgraphNotFoundError(CogneeApiError):
"""Raised when a document's subgraph cannot be found in the graph database."""
def __init__(
self,
message: str = "Document subgraph not found in graph database.",
name: str = "DocumentSubgraphNotFoundError",
status_code: int = status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)
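
A hedged sketch of handling these errors from application code; the import path follows the relative `.exceptions` import in the new delete module, and the example text is illustrative:

```python
import asyncio

import cognee
from cognee.api.v1.delete.exceptions import (
    DatasetNotFoundError,
    DocumentNotFoundError,
    DocumentSubgraphNotFoundError,
)


async def remove(text: str) -> None:
    try:
        result = await cognee.delete(text, dataset_name="main_dataset")
        print(result["graph_deletions"])
    except (DocumentNotFoundError, DocumentSubgraphNotFoundError):
        # Raised (404 semantics) when no document matches the content hash.
        print("nothing to delete for this content")
    except DatasetNotFoundError:
        print("dataset does not exist")


asyncio.run(remove("some previously added text"))
```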

View file

@ -0,0 +1 @@
from .get_delete_router import get_delete_router

View file

@ -0,0 +1,77 @@
from fastapi import Form, UploadFile, Depends
from fastapi.responses import JSONResponse
from fastapi import APIRouter
from typing import List, Optional
import subprocess
from cognee.shared.logging_utils import get_logger
import requests
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
logger = get_logger()
def get_delete_router() -> APIRouter:
router = APIRouter()
@router.delete("/", response_model=None)
async def delete(
data: List[UploadFile],
dataset_name: str = Form("main_dataset"),
mode: str = Form("soft"),
user: User = Depends(get_authenticated_user),
):
"""This endpoint is responsible for deleting data from the graph.
Args:
data: The data to delete (files, URLs, or text)
dataset_name: Name of the dataset to delete from (default: "main_dataset")
mode: "soft" (default) or "hard" - hard mode also deletes degree-one entity nodes
user: Authenticated user
"""
from cognee.api.v1.delete import delete as cognee_delete
try:
# Handle each file in the list
results = []
for file in data:
if file.filename.startswith("http"):
if "github" in file.filename:
# For GitHub repos, we need to get the content hash of each file
repo_name = file.filename.split("/")[-1].replace(".git", "")
subprocess.run(
["git", "clone", file.filename, f".data/{repo_name}"], check=True
)
# Note: This would need to be implemented to get content hashes of all files
# For now, we'll just return an error
return JSONResponse(
status_code=400,
content={"error": "Deleting GitHub repositories is not yet supported"},
)
else:
# Fetch and delete the data from other types of URL
response = requests.get(file.filename)
response.raise_for_status()
file_data = response.content
result = await cognee_delete(
file_data, dataset_name=dataset_name, mode=mode
)
results.append(result)
else:
# Handle uploaded file
result = await cognee_delete(file, dataset_name=dataset_name, mode=mode)
results.append(result)
if len(results) == 1:
return results[0]
else:
return {
"status": "success",
"message": "Multiple documents deleted",
"results": results,
}
except Exception as error:
logger.error(f"Error during deletion: {str(error)}")
return JSONResponse(status_code=409, content={"error": str(error)})
return router
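
For completeness, a sketch of exercising the endpoint over HTTP. The base URL and mount prefix are assumptions and authentication is omitted; the router itself only defines `DELETE /` taking multipart `data` files plus `dataset_name` and `mode` form fields:

```python
import requests

# Assumed base URL and prefix; adjust to wherever get_delete_router() is mounted,
# and add the auth header required by get_authenticated_user.
with open("notes.txt", "rb") as f:
    response = requests.delete(
        "http://localhost:8000/api/v1/delete/",
        files=[("data", ("notes.txt", f, "text/plain"))],
        data={"dataset_name": "main_dataset", "mode": "soft"},
    )

print(response.status_code, response.json())
```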

View file

@ -1,61 +1,195 @@
from typing import Protocol, Optional, Dict, Any
from abc import abstractmethod
from typing import Protocol, Optional, Dict, Any, List, Tuple
from abc import abstractmethod, ABC
from uuid import UUID, uuid5, NAMESPACE_DNS
from cognee.modules.graph.relationship_manager import create_relationship
from functools import wraps
import inspect
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
from cognee.shared.logging_utils import get_logger
from datetime import datetime, timezone
logger = get_logger()
# Type aliases for better readability
NodeData = Dict[str, Any]
EdgeData = Tuple[
str, str, str, Dict[str, Any]
] # (source_id, target_id, relationship_name, properties)
Node = Tuple[str, NodeData] # (node_id, properties)
class GraphDBInterface(Protocol):
def record_graph_changes(func):
"""Decorator to record graph changes in the relationship database."""
db_engine = get_relational_engine()
@wraps(func)
async def wrapper(self, *args, **kwargs):
frame = inspect.currentframe()
while frame:
if frame.f_back and frame.f_back.f_code.co_name != "wrapper":
caller_frame = frame.f_back
break
frame = frame.f_back
caller_name = caller_frame.f_code.co_name
caller_class = (
caller_frame.f_locals.get("self", None).__class__.__name__
if caller_frame.f_locals.get("self", None)
else None
)
creator = f"{caller_class}.{caller_name}" if caller_class else caller_name
result = await func(self, *args, **kwargs)
async with db_engine.get_async_session() as session:
if func.__name__ == "add_nodes":
nodes = args[0]
for node in nodes:
try:
node_id = (
UUID(str(node[0])) if isinstance(node, tuple) else UUID(str(node.id))
)
relationship = GraphRelationshipLedger(
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
source_node_id=node_id,
destination_node_id=node_id,
creator_function=f"{creator}.node",
node_label=node[1].get("type")
if isinstance(node, tuple)
else type(node).__name__,
)
session.add(relationship)
await session.flush()
except Exception as e:
logger.error(f"Error adding relationship: {e}")
await session.rollback()
continue
elif func.__name__ == "add_edges":
edges = args[0]
for edge in edges:
try:
source_id = UUID(str(edge[0]))
target_id = UUID(str(edge[1]))
rel_type = str(edge[2])
relationship = GraphRelationshipLedger(
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
source_node_id=source_id,
destination_node_id=target_id,
creator_function=f"{creator}.{rel_type}",
)
session.add(relationship)
await session.flush()
except Exception as e:
logger.error(f"Error adding relationship: {e}")
await session.rollback()
continue
try:
await session.commit()
except Exception as e:
logger.error(f"Error committing session: {e}")
return result
return wrapper
class GraphDBInterface(ABC):
"""Interface for graph database operations."""
@abstractmethod
async def query(self, query: str, params: dict):
async def query(self, query: str, params: dict) -> List[Any]:
"""Execute a raw query against the database."""
raise NotImplementedError
@abstractmethod
async def add_node(self, node_id: str, node_properties: dict):
async def add_node(self, node_id: str, properties: Dict[str, Any]) -> None:
"""Add a single node to the graph."""
raise NotImplementedError
@abstractmethod
async def add_nodes(self, nodes: list[tuple[str, dict]]):
@record_graph_changes
async def add_nodes(self, nodes: List[Node]) -> None:
"""Add multiple nodes to the graph."""
raise NotImplementedError
@abstractmethod
async def delete_node(self, node_id: str):
async def delete_node(self, node_id: str) -> None:
"""Delete a node from the graph."""
raise NotImplementedError
@abstractmethod
async def delete_nodes(self, node_ids: list[str]):
async def delete_nodes(self, node_ids: List[str]) -> None:
"""Delete multiple nodes from the graph."""
raise NotImplementedError
@abstractmethod
async def extract_node(self, node_id: str):
async def get_node(self, node_id: str) -> Optional[NodeData]:
"""Get a single node by ID."""
raise NotImplementedError
@abstractmethod
async def extract_nodes(self, node_ids: list[str]):
async def get_nodes(self, node_ids: List[str]) -> List[NodeData]:
"""Get multiple nodes by their IDs."""
raise NotImplementedError
@abstractmethod
async def add_edge(
self,
from_node: str,
to_node: str,
source_id: str,
target_id: str,
relationship_name: str,
edge_properties: Optional[Dict[str, Any]] = None,
):
properties: Optional[Dict[str, Any]] = None,
) -> None:
"""Add a single edge to the graph."""
raise NotImplementedError
@abstractmethod
async def add_edges(self, edges: tuple[str, str, str, dict]):
@record_graph_changes
async def add_edges(self, edges: List[EdgeData]) -> None:
"""Add multiple edges to the graph."""
raise NotImplementedError
@abstractmethod
async def delete_graph(
self,
):
async def delete_graph(self) -> None:
"""Delete the entire graph."""
raise NotImplementedError
@abstractmethod
async def get_graph_data(self):
async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
"""Get all nodes and edges in the graph."""
raise NotImplementedError
@abstractmethod
async def get_graph_metrics(self, include_optional):
""" "https://docs.cognee.ai/core_concepts/graph_generation/descriptive_metrics"""
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
"""Get graph metrics and statistics."""
raise NotImplementedError
@abstractmethod
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
"""Check if an edge exists."""
raise NotImplementedError
@abstractmethod
async def has_edges(self, edges: List[EdgeData]) -> List[EdgeData]:
"""Check if multiple edges exist."""
raise NotImplementedError
@abstractmethod
async def get_edges(self, node_id: str) -> List[EdgeData]:
"""Get all edges connected to a node."""
raise NotImplementedError
@abstractmethod
async def get_neighbors(self, node_id: str) -> List[NodeData]:
"""Get all neighboring nodes."""
raise NotImplementedError
@abstractmethod
async def get_connections(
self, node_id: str
) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
"""Get all nodes connected to a given node with their relationships."""
raise NotImplementedError

View file

@ -14,7 +14,10 @@ from concurrent.futures import ThreadPoolExecutor
import kuzu
from kuzu.database import Database
from kuzu import Connection
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
)
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
@ -43,7 +46,7 @@ class KuzuAdapter(GraphDBInterface):
self.connection.execute("""
CREATE NODE TABLE IF NOT EXISTS Node(
id STRING PRIMARY KEY,
text STRING,
name STRING,
type STRING,
created_at TIMESTAMP,
updated_at TIMESTAMP,
@ -138,12 +141,16 @@ class KuzuAdapter(GraphDBInterface):
query = """
MATCH (from:Node), (to:Node)
WHERE from.id = $from_id AND to.id = $to_id
CREATE (from)-[r:EDGE {
relationship_name: $relationship_name,
created_at: timestamp($created_at),
updated_at: timestamp($updated_at),
properties: $properties
MERGE (from)-[r:EDGE {
relationship_name: $relationship_name
}]->(to)
ON CREATE SET
r.created_at = timestamp($created_at),
r.updated_at = timestamp($updated_at),
r.properties = $properties
ON MATCH SET
r.updated_at = timestamp($updated_at),
r.properties = $properties
"""
params = {
"from_id": from_node,
@ -171,7 +178,7 @@ class KuzuAdapter(GraphDBInterface):
# Extract core fields with defaults if not present
core_properties = {
"id": str(properties.get("id", "")),
"text": str(properties.get("text", "")),
"name": str(properties.get("name", "")),
"type": str(properties.get("type", "")),
}
@ -181,35 +188,33 @@ class KuzuAdapter(GraphDBInterface):
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
# Check if node exists
exists = await self.has_node(core_properties["id"])
# Add timestamps for new node
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
fields = []
params = {}
for key, value in core_properties.items():
if value is not None:
param_name = f"param_{key}"
fields.append(f"{key}: ${param_name}")
params[param_name] = value
if not exists:
# Add timestamps for new node
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
fields = []
params = {}
for key, value in core_properties.items():
if value is not None:
param_name = f"param_{key}"
fields.append(f"{key}: ${param_name}")
params[param_name] = value
# Add timestamp fields
fields.extend(
["created_at: timestamp($created_at)", "updated_at: timestamp($updated_at)"]
)
params.update({"created_at": now, "updated_at": now})
# Add timestamp fields
fields.extend(
["created_at: timestamp($created_at)", "updated_at: timestamp($updated_at)"]
)
params.update({"created_at": now, "updated_at": now})
create_query = f"""
CREATE (n:Node {{{", ".join(fields)}}})
"""
await self.query(create_query, params)
merge_query = f"""
MERGE (n:Node {{id: $param_id}})
ON CREATE SET n += {{{", ".join(fields)}}}
"""
await self.query(merge_query, params)
except Exception as e:
logger.error(f"Failed to add node: {e}")
raise
@record_graph_changes
async def add_nodes(self, nodes: List[DataPoint]) -> None:
"""Add multiple nodes to the graph in a batch operation."""
if not nodes:
@ -218,15 +223,14 @@ class KuzuAdapter(GraphDBInterface):
try:
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
# Prepare all nodes data first
# Prepare all nodes data
node_params = []
for node in nodes:
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
# Extract core fields
core_properties = {
"id": str(properties.get("id", "")),
"text": str(properties.get("text", "")),
"name": str(properties.get("name", "")),
"type": str(properties.get("type", "")),
}
@ -244,36 +248,24 @@ class KuzuAdapter(GraphDBInterface):
)
if node_params:
# First check which nodes don't exist yet
check_query = """
# Batch merge nodes
merge_query = """
UNWIND $nodes AS node
MATCH (n:Node)
WHERE n.id = node.id
RETURN n.id
MERGE (n:Node {id: node.id})
ON CREATE SET
n.name = node.name,
n.type = node.type,
n.properties = node.properties,
n.created_at = timestamp(node.created_at),
n.updated_at = timestamp(node.updated_at)
ON MATCH SET
n.name = node.name,
n.type = node.type,
n.properties = node.properties,
n.updated_at = timestamp(node.updated_at)
"""
existing_nodes = await self.query(check_query, {"nodes": node_params})
existing_ids = {str(row[0]) for row in existing_nodes}
# Filter out existing nodes
new_nodes = [node for node in node_params if node["id"] not in existing_ids]
if new_nodes:
# Batch create new nodes
create_query = """
UNWIND $nodes AS node
CREATE (n:Node {
id: node.id,
text: node.text,
type: node.type,
properties: node.properties,
created_at: timestamp(node.created_at),
updated_at: timestamp(node.updated_at)
})
"""
await self.query(create_query, {"nodes": new_nodes})
logger.debug(f"Added {len(new_nodes)} new nodes in batch")
else:
logger.debug("No new nodes to add - all nodes already exist")
await self.query(merge_query, {"nodes": node_params})
logger.debug(f"Processed {len(node_params)} nodes in batch")
except Exception as e:
logger.error(f"Failed to add nodes in batch: {e}")
@ -296,7 +288,7 @@ class KuzuAdapter(GraphDBInterface):
WHERE n.id = $id
RETURN {
id: n.id,
text: n.text,
name: n.name,
type: n.type,
properties: n.properties
}
@ -318,7 +310,7 @@ class KuzuAdapter(GraphDBInterface):
WHERE n.id IN $node_ids
RETURN {
id: n.id,
text: n.text,
name: n.name,
type: n.type,
properties: n.properties
}
@ -401,6 +393,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add edge: {e}")
raise
@record_graph_changes
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
"""Add multiple edges in a batch operation."""
if not edges:
@ -409,7 +402,6 @@ class KuzuAdapter(GraphDBInterface):
try:
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
# Transform edges into the format needed for batch insertion
edge_params = [
{
"from_id": from_node,
@ -422,17 +414,20 @@ class KuzuAdapter(GraphDBInterface):
for from_node, to_node, relationship_name, properties in edges
]
# Batch create query
query = """
UNWIND $edges AS edge
MATCH (from:Node), (to:Node)
WHERE from.id = edge.from_id AND to.id = edge.to_id
CREATE (from)-[r:EDGE {
relationship_name: edge.relationship_name,
created_at: timestamp(edge.created_at),
updated_at: timestamp(edge.updated_at),
properties: edge.properties
MERGE (from)-[r:EDGE {
relationship_name: edge.relationship_name
}]->(to)
ON CREATE SET
r.created_at = timestamp(edge.created_at),
r.updated_at = timestamp(edge.updated_at),
r.properties = edge.properties
ON MATCH SET
r.updated_at = timestamp(edge.updated_at),
r.properties = edge.properties
"""
await self.query(query, {"edges": edge_params})
@ -454,14 +449,14 @@ class KuzuAdapter(GraphDBInterface):
WHERE n.id = $node_id
RETURN {
id: n.id,
text: n.text,
name: n.name,
type: n.type,
properties: n.properties
},
r.relationship_name,
{
id: m.id,
text: m.text,
name: m.name,
type: m.type,
properties: m.properties
}
@ -481,6 +476,50 @@ class KuzuAdapter(GraphDBInterface):
# Neighbor Operations
async def get_neighbors(self, node_id: str) -> List[Dict[str, Any]]:
"""Get all neighboring nodes."""
return await self.get_neighbours(node_id)
async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
"""Get a single node by ID."""
query_str = """
MATCH (n:Node)
WHERE n.id = $id
RETURN {
id: n.id,
name: n.name,
type: n.type,
properties: n.properties
}
"""
try:
result = await self.query(query_str, {"id": node_id})
if result and result[0]:
return self._parse_node(result[0][0])
return None
except Exception as e:
logger.error(f"Failed to get node {node_id}: {e}")
return None
async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
"""Get multiple nodes by their IDs."""
query_str = """
MATCH (n:Node)
WHERE n.id IN $node_ids
RETURN {
id: n.id,
name: n.name,
type: n.type,
properties: n.properties
}
"""
try:
results = await self.query(query_str, {"node_ids": node_ids})
return [self._parse_node(row[0]) for row in results if row[0]]
except Exception as e:
logger.error(f"Failed to get nodes: {e}")
return []
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
"""Get all neighbouring nodes."""
query_str = """
@ -554,7 +593,7 @@ class KuzuAdapter(GraphDBInterface):
WHERE n.id = $node_id
RETURN {
id: n.id,
text: n.text,
name: n.name,
type: n.type,
properties: n.properties
},
@ -564,7 +603,7 @@ class KuzuAdapter(GraphDBInterface):
},
{
id: m.id,
text: m.text,
name: m.name,
type: m.type,
properties: m.properties
}
@ -625,7 +664,7 @@ class KuzuAdapter(GraphDBInterface):
nodes_query = """
MATCH (n:Node)
RETURN n.id, {
text: n.text,
name: n.name,
type: n.type,
properties: n.properties
}
@ -716,77 +755,36 @@ class KuzuAdapter(GraphDBInterface):
return ([n[0] for n in nodes], [e[0] for e in edges])
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
"""For the definition of these metrics, please refer to
https://docs.cognee.ai/core_concepts/graph_generation/descriptive_metrics"""
try:
# Basic metrics
node_count = await self.query("MATCH (n:Node) RETURN COUNT(n)")
edge_count = await self.query("MATCH ()-[r:EDGE]->() RETURN COUNT(r)")
num_nodes = node_count[0][0] if node_count else 0
num_edges = edge_count[0][0] if edge_count else 0
# Get basic graph data
nodes, edges = await self.get_model_independent_graph_data()
num_nodes = len(nodes[0]["nodes"]) if nodes else 0
num_edges = len(edges[0]["elements"]) if edges else 0
# Calculate mandatory metrics
mandatory_metrics = {
"num_nodes": num_nodes,
"num_edges": num_edges,
"mean_degree": (2 * num_edges) / num_nodes if num_nodes > 0 else 0,
"edge_density": (num_edges) / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0,
"mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
"edge_density": num_edges / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0,
"num_connected_components": await self._get_num_connected_components(),
"sizes_of_connected_components": await self._get_size_of_connected_components(),
}
# Calculate connected components
components_query = """
MATCH (n:Node)
WITH n.id AS node_id
MATCH path = (n)-[:EDGE*0..]-()
WITH COLLECT(DISTINCT node_id) AS component
RETURN COLLECT(component) AS components
"""
components_result = await self.query(components_query)
component_sizes = (
[len(comp) for comp in components_result[0][0]] if components_result else []
)
mandatory_metrics.update(
{
"num_connected_components": len(component_sizes),
"sizes_of_connected_components": component_sizes,
}
)
if include_optional:
# Self-loops
self_loops_query = """
MATCH (n:Node)-[r:EDGE]->(n)
RETURN COUNT(r)
"""
self_loops = await self.query(self_loops_query)
num_selfloops = self_loops[0][0] if self_loops else 0
# Shortest paths (simplified for Kuzu)
paths_query = """
MATCH (n:Node), (m:Node)
WHERE n.id < m.id
MATCH path = (n)-[:EDGE*]-(m)
RETURN MIN(LENGTH(path)) AS length
"""
paths = await self.query(paths_query)
path_lengths = [p[0] for p in paths if p[0] is not None]
# Local clustering coefficient
clustering_query = """
MATCH (n:Node)-[:EDGE]-(neighbor)
WITH n, COUNT(DISTINCT neighbor) as degree
MATCH (n)-[:EDGE]-(n1)-[:EDGE]-(n2)-[:EDGE]-(n)
WHERE n1 <> n2
RETURN AVG(CASE WHEN degree <= 1 THEN 0 ELSE COUNT(DISTINCT n2) / (degree * (degree-1)) END)
"""
clustering = await self.query(clustering_query)
# Calculate optional metrics
shortest_path_lengths = await self._get_shortest_path_lengths()
optional_metrics = {
"num_selfloops": num_selfloops,
"diameter": max(path_lengths) if path_lengths else -1,
"avg_shortest_path_length": sum(path_lengths) / len(path_lengths)
if path_lengths
"num_selfloops": await self._count_self_loops(),
"diameter": max(shortest_path_lengths) if shortest_path_lengths else -1,
"avg_shortest_path_length": sum(shortest_path_lengths)
/ len(shortest_path_lengths)
if shortest_path_lengths
else -1,
"avg_clustering": clustering[0][0] if clustering and clustering[0][0] else -1,
"avg_clustering": await self._get_avg_clustering(),
}
else:
optional_metrics = {
@ -813,6 +811,65 @@ class KuzuAdapter(GraphDBInterface):
"avg_clustering": -1,
}
async def _get_num_connected_components(self) -> int:
"""Get the number of connected components in the graph."""
query = """
MATCH (n:Node)
WITH n.id AS node_id
MATCH path = (n)-[:EDGE*1..3]-(m)
WITH node_id, COLLECT(DISTINCT m.id) AS connected_nodes
WITH COLLECT(DISTINCT connected_nodes + [node_id]) AS components
RETURN SIZE(components) AS num_components
"""
result = await self.query(query)
return result[0][0] if result else 0
async def _get_size_of_connected_components(self) -> List[int]:
"""Get the sizes of all connected components in the graph."""
query = """
MATCH (n:Node)
WITH n.id AS node_id
MATCH path = (n)-[:EDGE*1..3]-(m)
WITH node_id, COLLECT(DISTINCT m.id) AS connected_nodes
WITH COLLECT(DISTINCT connected_nodes + [node_id]) AS components
UNWIND components AS component
RETURN SIZE(component) AS component_size
"""
result = await self.query(query)
return [row[0] for row in result] if result else []
async def _get_shortest_path_lengths(self) -> List[int]:
"""Get the lengths of shortest paths between all pairs of nodes."""
query = """
MATCH (n:Node), (m:Node)
WHERE n.id < m.id
MATCH path = (n)-[:EDGE*]-(m)
RETURN MIN(LENGTH(path)) AS length
"""
result = await self.query(query)
return [row[0] for row in result if row[0] is not None] if result else []
async def _count_self_loops(self) -> int:
"""Count the number of self-loops in the graph."""
query = """
MATCH (n:Node)-[r:EDGE]->(n)
RETURN COUNT(r) AS count
"""
result = await self.query(query)
return result[0][0] if result else 0
async def _get_avg_clustering(self) -> float:
"""Calculate the average clustering coefficient of the graph."""
query = """
MATCH (n:Node)-[:EDGE]-(neighbor)
WITH n, COUNT(DISTINCT neighbor) as degree
MATCH (n)-[:EDGE]-(n1)-[:EDGE]-(n2)-[:EDGE]-(n)
WHERE n1 <> n2
RETURN AVG(CASE WHEN degree <= 1 THEN 0 ELSE COUNT(DISTINCT n2) / (degree * (degree-1)) END) AS avg_clustering
"""
result = await self.query(query)
return result[0][0] if result and result[0][0] is not None else -1
async def get_disconnected_nodes(self) -> List[str]:
"""Get nodes that are not connected to any other node."""
query_str = """
@ -847,10 +904,8 @@ class KuzuAdapter(GraphDBInterface):
async def delete_graph(self) -> None:
"""Delete all data from the graph while preserving the database structure."""
try:
# Delete relationships from the fixed table EDGE
await self.query("MATCH ()-[r:EDGE]->() DELETE r")
# Then delete nodes
await self.query("MATCH (n:Node) DELETE n")
# Use DETACH DELETE to remove both nodes and their relationships in one operation
await self.query("MATCH (n:Node) DETACH DELETE n")
logger.info("Cleared all data from graph while preserving structure")
except Exception as e:
logger.error(f"Failed to delete graph data: {e}")
@ -922,3 +977,73 @@ class KuzuAdapter(GraphDBInterface):
except Exception as e:
logger.error(f"Failed to import graph from file: {e}")
raise
async def get_document_subgraph(self, content_hash: str):
"""Get all nodes that should be deleted when removing a document."""
query = """
MATCH (doc:Node)
WHERE (doc.type = 'TextDocument' OR doc.type = 'PdfDocument') AND doc.name = $content_hash
OPTIONAL MATCH (doc)<-[e1:EDGE]-(chunk:Node)
WHERE e1.relationship_name = 'is_part_of' AND chunk.type = 'DocumentChunk'
OPTIONAL MATCH (chunk)-[e2:EDGE]->(entity:Node)
WHERE e2.relationship_name = 'contains' AND entity.type = 'Entity'
AND NOT EXISTS {
MATCH (entity)<-[e3:EDGE]-(otherChunk:Node)-[e4:EDGE]->(otherDoc:Node)
WHERE e3.relationship_name = 'contains'
AND e4.relationship_name = 'is_part_of'
AND (otherDoc.type = 'TextDocument' OR otherDoc.type = 'PdfDocument')
AND otherDoc.id <> doc.id
}
OPTIONAL MATCH (chunk)<-[e5:EDGE]-(made_node:Node)
WHERE e5.relationship_name = 'made_from' AND made_node.type = 'TextSummary'
OPTIONAL MATCH (entity)-[e6:EDGE]->(type:Node)
WHERE e6.relationship_name = 'is_a' AND type.type = 'EntityType'
AND NOT EXISTS {
MATCH (type)<-[e7:EDGE]-(otherEntity:Node)-[e8:EDGE]-(otherChunk:Node)-[e9:EDGE]-(otherDoc:Node)
WHERE e7.relationship_name = 'is_a'
AND e8.relationship_name = 'contains'
AND e9.relationship_name = 'is_part_of'
AND otherEntity.type = 'Entity'
AND otherChunk.type = 'DocumentChunk'
AND (otherDoc.type = 'TextDocument' OR otherDoc.type = 'PdfDocument')
AND otherDoc.id <> doc.id
}
RETURN
COLLECT(DISTINCT doc) as document,
COLLECT(DISTINCT chunk) as chunks,
COLLECT(DISTINCT entity) as orphan_entities,
COLLECT(DISTINCT made_node) as made_from_nodes,
COLLECT(DISTINCT type) as orphan_types
"""
result = await self.query(query, {"content_hash": f"text_{content_hash}"})
if not result or not result[0]:
return None
# Convert tuple to dictionary
return {
"document": result[0][0],
"chunks": result[0][1],
"orphan_entities": result[0][2],
"made_from_nodes": result[0][3],
"orphan_types": result[0][4],
}
async def get_degree_one_nodes(self, node_type: str):
"""Get all nodes that have only one connection."""
if not node_type or node_type not in ["Entity", "EntityType"]:
raise ValueError("node_type must be either 'Entity' or 'EntityType'")
query = f"""
MATCH (n:Node)
WHERE n.type = '{node_type}'
WITH n, COUNT {{ MATCH (n)--() }} as degree
WHERE degree = 1
RETURN n
"""
result = await self.query(query)
return [record[0] for record in result] if result else []

View file

@ -1,3 +1,5 @@
#
"""Neo4j Adapter for Graph Database"""
import json
@ -11,7 +13,10 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
)
from cognee.modules.storage.utils import JSONEncoder
from .neo4j_metrics_utils import (
get_avg_clustering,
@ -89,6 +94,7 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params)
@record_graph_changes
async def add_nodes(self, nodes: list[DataPoint]) -> None:
query = """
UNWIND $nodes AS node
@ -130,9 +136,7 @@ class Neo4jAdapter(GraphDBInterface):
return [result["node"] for result in results]
async def delete_node(self, node_id: str):
node_id = id.replace(":", "_")
query = f"MATCH (node:`{node_id}` {{id: $node_id}}) DETACH DELETE n"
query = "MATCH (node {id: $node_id}) DETACH DELETE node"
params = {"node_id": node_id}
return await self.query(query, params)
@ -218,6 +222,7 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params)
@record_graph_changes
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
query = """
UNWIND $edges AS edge
@ -373,12 +378,28 @@ class Neo4jAdapter(GraphDBInterface):
return [result["successor"] for result in results]
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
predecessors, successors = await asyncio.gather(
self.get_predecessors(node_id), self.get_successors(node_id)
)
async def get_neighbors(self, node_id: str) -> List[Dict[str, Any]]:
"""Get all neighboring nodes."""
return await self.get_neighbours(node_id)
return predecessors + successors
async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
"""Get a single node by ID."""
query = """
MATCH (node {id: $node_id})
RETURN node
"""
results = await self.query(query, {"node_id": node_id})
return results[0]["node"] if results else None
async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
"""Get multiple nodes by their IDs."""
query = """
UNWIND $node_ids AS id
MATCH (node {id: id})
RETURN node
"""
results = await self.query(query, {"node_ids": node_ids})
return [result["node"] for result in results]
async def get_connections(self, node_id: UUID) -> list:
predecessors_query = """
@ -651,3 +672,46 @@ class Neo4jAdapter(GraphDBInterface):
}
return mandatory_metrics | optional_metrics
async def get_document_subgraph(self, content_hash: str):
query = """
MATCH (doc)
WHERE (doc:TextDocument OR doc:PdfDocument)
AND doc.name = 'text_' + $content_hash
OPTIONAL MATCH (doc)<-[:is_part_of]-(chunk:DocumentChunk)
OPTIONAL MATCH (chunk)-[:contains]->(entity:Entity)
WHERE NOT EXISTS {
MATCH (entity)<-[:contains]-(otherChunk:DocumentChunk)-[:is_part_of]->(otherDoc)
WHERE (otherDoc:TextDocument OR otherDoc:PdfDocument)
AND otherDoc.id <> doc.id
}
OPTIONAL MATCH (chunk)<-[:made_from]-(made_node:TextSummary)
OPTIONAL MATCH (entity)-[:is_a]->(type:EntityType)
WHERE NOT EXISTS {
MATCH (type)<-[:is_a]-(otherEntity:Entity)<-[:contains]-(otherChunk:DocumentChunk)-[:is_part_of]->(otherDoc)
WHERE (otherDoc:TextDocument OR otherDoc:PdfDocument)
AND otherDoc.id <> doc.id
}
RETURN
collect(DISTINCT doc) as document,
collect(DISTINCT chunk) as chunks,
collect(DISTINCT entity) as orphan_entities,
collect(DISTINCT made_node) as made_from_nodes,
collect(DISTINCT type) as orphan_types
"""
result = await self.query(query, {"content_hash": content_hash})
return result[0] if result else None
async def get_degree_one_nodes(self, node_type: str):
if not node_type or node_type not in ["Entity", "EntityType"]:
raise ValueError("node_type must be either 'Entity' or 'EntityType'")
query = f"""
MATCH (n:{node_type})
WHERE COUNT {{ MATCH (n)--() }} = 1
RETURN n
"""
result = await self.query(query)
return [record["n"] for record in result] if result else []

View file

@ -10,7 +10,10 @@ from uuid import UUID
import aiofiles
import aiofiles.os as aiofiles_os
import networkx as nx
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
)
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.modules.storage.utils import JSONEncoder
@ -42,20 +45,14 @@ class NetworkXAdapter(GraphDBInterface):
async def has_node(self, node_id: str) -> bool:
return self.graph.has_node(node_id)
async def add_node(
self,
node: DataPoint,
) -> None:
async def add_node(self, node: DataPoint) -> None:
self.graph.add_node(node.id, **node.model_dump())
await self.save_graph_to_file(self.filename)
async def add_nodes(
self,
nodes: list[DataPoint],
) -> None:
@record_graph_changes
async def add_nodes(self, nodes: list[DataPoint]) -> None:
nodes = [(node.id, node.model_dump()) for node in nodes]
self.graph.add_nodes_from(nodes)
await self.save_graph_to_file(self.filename)
@ -74,6 +71,7 @@ class NetworkXAdapter(GraphDBInterface):
return result
@record_graph_changes
async def add_edge(
self,
from_node: str,
@ -91,38 +89,76 @@ class NetworkXAdapter(GraphDBInterface):
await self.save_graph_to_file(self.filename)
async def add_edges(
self,
edges: tuple[str, str, str, dict],
) -> None:
edges = [
(
edge[0],
edge[1],
edge[2],
{
**(edge[3] if len(edge) == 4 else {}),
"updated_at": datetime.now(timezone.utc),
},
)
for edge in edges
]
@record_graph_changes
async def add_edges(self, edges: list[tuple[str, str, str, dict]]) -> None:
if not edges:
logger.debug("No edges to add")
return
self.graph.add_edges_from(edges)
await self.save_graph_to_file(self.filename)
try:
# Validate edge format and convert UUIDs to strings
processed_edges = []
for edge in edges:
if len(edge) < 3 or len(edge) > 4:
raise ValueError(
f"Invalid edge format: {edge}. Expected (from_node, to_node, relationship_name[, properties])"
)
# Convert UUIDs to strings if needed
from_node = str(edge[0]) if isinstance(edge[0], UUID) else edge[0]
to_node = str(edge[1]) if isinstance(edge[1], UUID) else edge[1]
relationship_name = edge[2]
if not all(isinstance(x, str) for x in [from_node, to_node, relationship_name]):
raise ValueError(
f"First three elements of edge must be strings or UUIDs: {edge}"
)
# Process edge with updated_at timestamp
processed_edge = (
from_node,
to_node,
relationship_name,
{
**(edge[3] if len(edge) == 4 else {}),
"updated_at": datetime.now(timezone.utc),
},
)
processed_edges.append(processed_edge)
# Add edges to graph
self.graph.add_edges_from(processed_edges)
logger.debug(f"Added {len(processed_edges)} edges to graph")
# Save changes
await self.save_graph_to_file(self.filename)
except Exception as e:
logger.error(f"Failed to add edges: {e}")
raise
async def get_edges(self, node_id: str):
return list(self.graph.in_edges(node_id, data=True)) + list(
self.graph.out_edges(node_id, data=True)
)
async def delete_node(self, node_id: str) -> None:
"""Asynchronously delete a node from the graph if it exists."""
if self.graph.has_node(node_id):
self.graph.remove_node(node_id)
await self.save_graph_to_file(self.filename)
async def delete_node(self, node_id: UUID) -> None:
"""Asynchronously delete a node and all its relationships from the graph if it exists."""
async def delete_nodes(self, node_ids: List[str]) -> None:
if self.graph.has_node(node_id):
# First remove all edges connected to the node
for edge in list(self.graph.edges(node_id, data=True)):
source, target, data = edge
self.graph.remove_edge(source, target, key=data.get("relationship_name"))
# Then remove the node itself
self.graph.remove_node(node_id)
# Save the updated graph state
await self.save_graph_to_file(self.filename)
else:
logger.error(f"Node {node_id} not found in graph")
async def delete_nodes(self, node_ids: List[UUID]) -> None:
self.graph.remove_nodes_from(node_ids)
await self.save_graph_to_file(self.filename)
@ -179,7 +215,7 @@ class NetworkXAdapter(GraphDBInterface):
return nodes
async def get_neighbours(self, node_id: str) -> list:
async def get_neighbors(self, node_id: str) -> list:
if not self.graph.has_node(node_id):
return []
@ -188,9 +224,9 @@ class NetworkXAdapter(GraphDBInterface):
self.get_successors(node_id),
)
neighbours = predecessors + successors
neighbors = predecessors + successors
return neighbours
return neighbors
async def get_connections(self, node_id: UUID) -> list:
if not self.graph.has_node(node_id):
@ -208,17 +244,22 @@ class NetworkXAdapter(GraphDBInterface):
connections = []
for neighbor in predecessors:
if "id" in neighbor:
edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
for edge_properties in edge_data.values():
connections.append((neighbor, edge_properties, node))
# Handle None values for predecessors and successors
if predecessors is not None:
for neighbor in predecessors:
if "id" in neighbor:
edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
if edge_data is not None:
for edge_properties in edge_data.values():
connections.append((neighbor, edge_properties, node))
for neighbor in successors:
if "id" in neighbor:
edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
for edge_properties in edge_data.values():
connections.append((node, edge_properties, neighbor))
if successors is not None:
for neighbor in successors:
if "id" in neighbor:
edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
if edge_data is not None:
for edge_properties in edge_data.values():
connections.append((node, edge_properties, neighbor))
return connections
@ -247,8 +288,9 @@ class NetworkXAdapter(GraphDBInterface):
async def create_empty_graph(self, file_path: str) -> None:
self.graph = nx.MultiDiGraph()
# Only create directory if file_path contains a directory
file_dir = os.path.dirname(file_path)
if not os.path.exists(file_dir):
if file_dir and not os.path.exists(file_dir):
os.makedirs(file_dir, exist_ok=True)
await self.save_graph_to_file(file_path)
@ -266,9 +308,6 @@ class NetworkXAdapter(GraphDBInterface):
async def load_graph_from_file(self, file_path: str = None):
"""Asynchronously load the graph from a file in JSON format."""
if file_path == self.filename:
return
if not file_path:
file_path = self.filename
try:
@ -460,3 +499,138 @@ class NetworkXAdapter(GraphDBInterface):
}
return mandatory_metrics | optional_metrics
async def get_document_subgraph(self, content_hash: str):
"""Get all nodes that should be deleted when removing a document."""
# Ensure graph is loaded
if self.graph is None:
await self.load_graph_from_file()
# Find the document node by looking for content_hash in the name field
document = None
document_node_id = None
for node_id, attrs in self.graph.nodes(data=True):
if (
attrs.get("type") in ["TextDocument", "PdfDocument"]
and attrs.get("name") == f"text_{content_hash}"
):
document = {"id": str(node_id), **attrs} # Convert UUID to string for consistency
document_node_id = node_id # Keep the original UUID
break
if not document:
return None
# Find chunks connected via is_part_of (chunks point TO document)
chunks = []
for source, target, edge_data in self.graph.in_edges(document_node_id, data=True):
if edge_data.get("relationship_name") == "is_part_of":
chunks.append({"id": source, **self.graph.nodes[source]}) # Keep as UUID object
# Find entities connected to chunks (chunks point TO entities via contains)
entities = []
for chunk in chunks:
chunk_id = chunk["id"] # Already a UUID object
for source, target, edge_data in self.graph.out_edges(chunk_id, data=True):
if edge_data.get("relationship_name") == "contains":
entities.append(
{"id": target, **self.graph.nodes[target]}
) # Keep as UUID object
# Find orphaned entities (entities only connected to chunks we're deleting)
orphan_entities = []
for entity in entities:
entity_id = entity["id"] # Already a UUID object
# Get all chunks that contain this entity
containing_chunks = []
for source, target, edge_data in self.graph.in_edges(entity_id, data=True):
if edge_data.get("relationship_name") == "contains":
containing_chunks.append(source) # Keep as UUID object
# Check if all containing chunks are in our chunks list
chunk_ids = [chunk["id"] for chunk in chunks]
if containing_chunks and all(c in chunk_ids for c in containing_chunks):
orphan_entities.append(entity)
# Find orphaned entity types
orphan_types = []
seen_types = set() # Track seen types to avoid duplicates
for entity in orphan_entities:
entity_id = entity["id"] # Already a UUID object
for _, target, edge_data in self.graph.out_edges(entity_id, data=True):
if edge_data.get("relationship_name") in ["is_a", "instance_of"]:
# Check if this type is only connected to entities we're deleting
type_node = self.graph.nodes[target]
if type_node.get("type") == "EntityType" and target not in seen_types:
is_orphaned = True
# Get all incoming edges to this type node
for source, _, edge_data in self.graph.in_edges(target, data=True):
if edge_data.get("relationship_name") in ["is_a", "instance_of"]:
# Check if the source entity is not in our orphan_entities list
if source not in [e["id"] for e in orphan_entities]:
is_orphaned = False
break
if is_orphaned:
orphan_types.append({"id": target, **type_node}) # Keep as UUID object
seen_types.add(target) # Mark as seen
# Find nodes connected via made_from (chunks point TO summaries)
made_from_nodes = []
for chunk in chunks:
chunk_id = chunk["id"] # Already a UUID object
for source, target, edge_data in self.graph.in_edges(chunk_id, data=True):
if edge_data.get("relationship_name") == "made_from":
made_from_nodes.append(
{"id": source, **self.graph.nodes[source]}
) # Keep as UUID object
# Return UUIDs directly without string conversion
return {
"document": [{"id": document["id"], **{k: v for k, v in document.items() if k != "id"}}]
if document
else [],
"chunks": [
{"id": chunk["id"], **{k: v for k, v in chunk.items() if k != "id"}}
for chunk in chunks
],
"orphan_entities": [
{"id": entity["id"], **{k: v for k, v in entity.items() if k != "id"}}
for entity in orphan_entities
],
"made_from_nodes": [
{"id": node["id"], **{k: v for k, v in node.items() if k != "id"}}
for node in made_from_nodes
],
"orphan_types": [
{"id": type_node["id"], **{k: v for k, v in type_node.items() if k != "id"}}
for type_node in orphan_types
],
}
async def get_degree_one_nodes(self, node_type: str):
"""Get all nodes that have only one connection."""
if not node_type or node_type not in ["Entity", "EntityType"]:
raise ValueError("node_type must be either 'Entity' or 'EntityType'")
nodes = []
for node_id, node_data in self.graph.nodes(data=True):
if node_data.get("type") == node_type:
# Count both incoming and outgoing edges
degree = self.graph.degree(node_id)
if degree == 1:
nodes.append(node_data)
return nodes
async def get_node(self, node_id: str) -> dict:
if self.graph.has_node(node_id):
return self.graph.nodes[node_id]
return None
async def get_nodes(self, node_ids: List[str] = None) -> List[dict]:
if node_ids is None:
return [{"id": node_id, **data} for node_id, data in self.graph.nodes(data=True)]
return [
{"id": node_id, **self.graph.nodes[node_id]}
for node_id in node_ids
if self.graph.has_node(node_id)
]

View file

@ -16,6 +16,8 @@ from ..models.ScoredResult import ScoredResult
from ..utils import normalize_distances
from ..vector_db_interface import VectorDBInterface
from tenacity import retry, stop_after_attempt, wait_exponential
class IndexSchema(DataPoint):
id: str
@ -230,14 +232,30 @@ class LanceDBAdapter(VectorDBInterface):
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
connection = await self.get_connection()
collection = await connection.open_table(collection_name)
if len(data_point_ids) == 1:
results = await collection.delete(f"id = '{data_point_ids[0]}'")
def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
async def _delete_data_points():
connection = await self.get_connection()
collection = await connection.open_table(collection_name)
# Delete one at a time to avoid commit conflicts
for data_point_id in data_point_ids:
await collection.delete(f"id = '{data_point_id}'")
return True
# Check if we're in an event loop
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
# If we're in a running event loop, create a new task
return loop.create_task(_delete_data_points())
else:
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
return results
# If we're not in an event loop, run it synchronously
return asyncio.run(_delete_data_points())
async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(

View file

@ -1,3 +1,5 @@
# PROPOSED TO BE DEPRECATED
from typing import Type, Optional, get_args, get_origin
from pydantic import BaseModel
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface

View file

@ -0,0 +1,41 @@
from datetime import datetime, timezone
from uuid import uuid5, NAMESPACE_DNS
from sqlalchemy import UUID, Column, DateTime, String, Index
from sqlalchemy.orm import relationship
from cognee.infrastructure.databases.relational import Base
class GraphRelationshipLedger(Base):
__tablename__ = "graph_relationship_ledger"
id = Column(
UUID,
primary_key=True,
default=lambda: uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
)
source_node_id = Column(UUID, nullable=False)
destination_node_id = Column(UUID, nullable=False)
creator_function = Column(String, nullable=False)
node_label = Column(String, nullable=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
deleted_at = Column(DateTime(timezone=True), nullable=True)
user_id = Column(UUID, nullable=True)
# Create indexes
__table_args__ = (
Index("idx_graph_relationship_id", "id"),
Index("idx_graph_relationship_ledger_source_node_id", "source_node_id"),
Index("idx_graph_relationship_ledger_destination_node_id", "destination_node_id"),
)
def to_json(self) -> dict:
return {
"id": str(self.id),
"source_node_id": str(self.parent_id),
"destination_node_id": str(self.child_id),
"creator_function": self.creator_function,
"created_at": self.created_at.isoformat(),
"deleted_at": self.deleted_at.isoformat() if self.deleted_at else None,
"user_id": str(self.user_id),
}
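
The ledger written by `record_graph_changes` (and stamped with `deleted_at` by the delete flow) can be inspected directly; a hedged sketch, with the helper name being illustrative:

```python
import asyncio

from sqlalchemy import select

from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger


async def list_live_relationships() -> list[dict]:
    # Rows soft-deleted by the delete flow carry a deleted_at timestamp; the rest are live.
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        result = await session.execute(
            select(GraphRelationshipLedger).where(GraphRelationshipLedger.deleted_at.is_(None))
        )
        return [row.to_json() for row in result.scalars().all()]


# Example: asyncio.run(list_live_relationships())
```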

View file

@ -0,0 +1,54 @@
from datetime import datetime, timezone
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
from cognee.modules.users.models import User
async def create_relationship(
session: AsyncSession,
source_node_id: UUID,
destination_node_id: UUID,
creator_function: str,
user: User,
) -> None:
"""Create a relationship between two nodes in the graph.
Args:
session: Database session
source_node_id: ID of the source node
destination_node_id: ID of the destination node
creator_function: Name of the function creating the relationship
user: User creating the relationship
"""
relationship = GraphRelationshipLedger(
source_node_id=source_node_id,
destination_node_id=destination_node_id,
creator_function=creator_function,
user_id=user.id,
)
session.add(relationship)
await session.flush()
async def delete_relationship(
session: AsyncSession,
source_node_id: UUID,
destination_node_id: UUID,
user: User,
) -> None:
"""Mark a relationship as deleted.
Args:
session: Database session
source_node_id: ID of the source node
destination_node_id: ID of the destination node
user: User deleting the relationship
"""
# The ledger's primary key is its own id, so look the row up by its endpoints instead.
result = await session.execute(
select(GraphRelationshipLedger).where(
GraphRelationshipLedger.source_node_id == source_node_id,
GraphRelationshipLedger.destination_node_id == destination_node_id,
)
)
relationship = result.scalars().first()
if relationship:
relationship.deleted_at = datetime.now(timezone.utc)
session.add(relationship)
await session.flush()
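A hypothetical end-to-end sketch of the two helpers above. It assumes they are importable alongside the GraphRelationshipLedger model, that the relational engine exposes an async session via get_async_session(), and that the node UUIDs are placeholders:

from uuid import uuid4

from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.methods import get_default_user


async def record_and_soft_delete_edge():
    source_id, destination_id = uuid4(), uuid4()
    user = await get_default_user()
    relational_engine = get_relational_engine()

    async with relational_engine.get_async_session() as session:
        # Record that a task created an edge between the two nodes.
        await create_relationship(
            session, source_id, destination_id, creator_function="add_data_points", user=user
        )
        # Later, mark the same edge as deleted without dropping the ledger row.
        await delete_relationship(session, source_id, destination_id, user=user)
        await session.commit()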

View file

@ -51,7 +51,6 @@ class TripletSearchContextProvider(BaseContextProvider):
tasks = [
brute_force_triplet_search(
query=f"{entity_text} {query}",
user=user,
top_k=self.top_k,
collections=self.collections,
properties_to_project=self.properties_to_project,

View file

@ -65,7 +65,7 @@ async def code_description_to_code_part(
try:
if include_docs:
search_results = await search(query_text=query, query_type="INSIGHTS", user=user)
search_results = await search(query_text=query, query_type="INSIGHTS")
concatenated_descriptions = " ".join(
obj["description"]

View file

@ -1,3 +1,5 @@
# PROPOSED TO BE DEPRECATED
import asyncio
from uuid import uuid5, NAMESPACE_OID
from typing import Type

View file

@ -51,12 +51,17 @@ async def integrate_chunk_graphs(
async def extract_graph_from_data(
data_chunks: list[DocumentChunk],
data_chunks: List[DocumentChunk],
graph_model: Type[BaseModel],
ontology_adapter: OntologyResolver = OntologyResolver(),
ontology_adapter: OntologyResolver = None,
) -> List[DocumentChunk]:
"""Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model."""
"""
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
"""
chunk_graphs = await asyncio.gather(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
return await integrate_chunk_graphs(data_chunks, chunk_graphs, graph_model, ontology_adapter)
return await integrate_chunk_graphs(
data_chunks, chunk_graphs, graph_model, ontology_adapter or OntologyResolver()
)
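A brief calling sketch for the updated signature: with the new default of None the task builds a fresh OntologyResolver internally, while callers can still supply their own resolver instance explicitly (the wrapper function below is illustrative, not part of this PR):

from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
from cognee.shared.data_models import KnowledgeGraph
from cognee.tasks.graph import extract_graph_from_data


async def enrich_chunks(data_chunks):
    # Default behaviour: the task falls back to a fresh OntologyResolver.
    await extract_graph_from_data(data_chunks, graph_model=KnowledgeGraph)

    # Or pass a resolver explicitly, e.g. one preloaded with a domain ontology.
    resolver = OntologyResolver()
    return await extract_graph_from_data(
        data_chunks, graph_model=KnowledgeGraph, ontology_adapter=resolver
    )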

View file

@ -3,6 +3,7 @@ from typing import List
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
from cognee.tasks.graph.cascade_extract.utils.extract_nodes import extract_nodes
from cognee.tasks.graph.cascade_extract.utils.extract_content_nodes_and_relationship_names import (
extract_content_nodes_and_relationship_names,
@ -14,7 +15,9 @@ from cognee.tasks.graph.extract_graph_from_data import integrate_chunk_graphs
async def extract_graph_from_data(
data_chunks: List[DocumentChunk], n_rounds: int = 2
data_chunks: List[DocumentChunk],
n_rounds: int = 2,
ontology_adapter: OntologyResolver = None,
) -> List[DocumentChunk]:
"""Extract and update graph data from document chunks in multiple steps."""
chunk_nodes = await asyncio.gather(
@ -37,4 +40,9 @@ async def extract_graph_from_data(
]
)
return await integrate_chunk_graphs(data_chunks, chunk_graphs, KnowledgeGraph)
return await integrate_chunk_graphs(
data_chunks=data_chunks,
chunk_graphs=chunk_graphs,
graph_model=KnowledgeGraph,
ontology_adapter=ontology_adapter or OntologyResolver(),
)

View file

@ -1,3 +1,5 @@
# PROPOSED TO BE DEPRECATED
"""This module contains the OntologyEngine class which is responsible for adding graph ontology from a JSON or CSV file."""
import csv

View file

@ -1,4 +1,5 @@
import asyncio
from typing import List
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
@ -6,7 +7,7 @@ from .index_data_points import index_data_points
from .index_graph_edges import index_graph_edges
async def add_data_points(data_points: list[DataPoint]):
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
nodes = []
edges = []

View file

@ -0,0 +1,71 @@
import os
import shutil
import cognee
import pathlib
from cognee.shared.logging_utils import get_logger
logger = get_logger()
async def main():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
text_1 = """
1. Audi
Audi is known for its modern designs and advanced technology. Founded in the early 1900s, the brand has earned a reputation for precision engineering and innovation. With features like the Quattro all-wheel-drive system, Audi offers a range of vehicles from stylish sedans to high-performance sports cars.
2. BMW
BMW, short for Bayerische Motoren Werke, is celebrated for its focus on performance and driving pleasure. The company's vehicles are designed to provide a dynamic and engaging driving experience, and their slogan, "The Ultimate Driving Machine," reflects that commitment. BMW produces a variety of cars that combine luxury with sporty performance.
3. Mercedes-Benz
Mercedes-Benz is synonymous with luxury and quality. With a history dating back to the early 20th century, the brand is known for its elegant designs, innovative safety features, and high-quality engineering. Mercedes-Benz manufactures not only luxury sedans but also SUVs, sports cars, and commercial vehicles, catering to a wide range of needs.
4. Porsche
Porsche is a name that stands for high-performance sports cars. Founded in 1931, the brand has become famous for models like the iconic Porsche 911. Porsche cars are celebrated for their speed, precision, and distinctive design, appealing to car enthusiasts who value both performance and style.
5. Volkswagen
Volkswagen, which means "people's car" in German, was established with the idea of making affordable and reliable vehicles accessible to everyone. Over the years, Volkswagen has produced several iconic models, such as the Beetle and the Golf. Today, it remains one of the largest car manufacturers in the world, offering a wide range of vehicles that balance practicality with quality.
Each of these car manufacturers contributes to Germany's reputation as a leader in the global automotive industry, showcasing a blend of innovation, performance, and design excellence.
"""
text_2 = """
1. Apple
Apple is renowned for its innovative consumer electronics and software. Its product lineup includes the iPhone, iPad, Mac computers, and wearables like the Apple Watch. Known for its emphasis on sleek design and user-friendly interfaces, Apple has built a loyal customer base and created a seamless ecosystem that integrates hardware, software, and services.
2. Google
Founded in 1998, Google started as a search engine and quickly became the go-to resource for finding information online. Over the years, the company has diversified its offerings to include digital advertising, cloud computing, mobile operating systems (Android), and various web services like Gmail and Google Maps. Google's innovations have played a major role in shaping the internet landscape.
3. Microsoft
Microsoft Corporation has been a dominant force in software for decades. Its Windows operating system and Microsoft Office suite are staples in both business and personal computing. In recent years, Microsoft has expanded into cloud computing with Azure, gaming with the Xbox platform, and even hardware through products like the Surface line. This evolution has helped the company maintain its relevance in a rapidly changing tech world.
4. Amazon
What began as an online bookstore has grown into one of the largest e-commerce platforms globally. Amazon is known for its vast online marketplace, but its influence extends far beyond retail. With Amazon Web Services (AWS), the company has become a leader in cloud computing, offering robust solutions that power websites, applications, and businesses around the world. Amazon's constant drive for innovation continues to reshape both retail and technology sectors.
5. Meta
Meta, originally known as Facebook, revolutionized social media by connecting billions of people worldwide. Beyond its core social networking service, Meta is investing in the next generation of digital experiences through virtual and augmented reality technologies, with projects like Oculus. The company's efforts signal a commitment to evolving digital interaction and building the metaverse—a shared virtual space where users can connect and collaborate.
Each of these companies has significantly impacted the technology landscape, driving innovation and transforming everyday life through their groundbreaking products and services.
"""
await cognee.add([text_1, text_2])
await cognee.cognify()
from cognee.infrastructure.databases.graph import get_graph_engine
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) > 15 and len(edges) > 15, "Graph database is not loaded."
await cognee.delete([text_1, text_2], mode="hard")
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 0 and len(edges) == 0, "Document is not deleted."
if __name__ == "__main__":
import asyncio
asyncio.run(main())
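The test above uses mode="hard". As a rough companion sketch based on the delete signature introduced earlier (which defaults to dataset_name="main_dataset" and mode="soft"); the precise cleanup each mode performs is not spelled out in this diff, so the comments are assumptions:

import cognee


async def delete_examples(text_1: str, text_2: str):
    # Soft delete (the default): removes the document from the dataset.
    await cognee.delete([text_1])

    # Hard delete, as exercised by the test above: also removes the related graph content.
    await cognee.delete([text_2], dataset_name="main_dataset", mode="hard")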

View file

@ -92,18 +92,12 @@ async def relational_db_migration():
n_data = row[1]
m_data = row[2]
source_props = {}
if "properties" in n_data and n_data["properties"]:
source_props = json.loads(n_data["properties"])
target_props = {}
if "properties" in m_data and m_data["properties"]:
target_props = json.loads(m_data["properties"])
source_name = normalize_node_name(n_data.get("name", ""))
target_name = normalize_node_name(m_data.get("name", ""))
source_name = normalize_node_name(source_props.get("name", f"id:{n_data['id']}"))
target_name = normalize_node_name(target_props.get("name", f"id:{m_data['id']}"))
found_edges.add((source_name, target_name))
distinct_node_names.update([source_name, target_name])
if source_name and target_name:
found_edges.add((source_name, target_name))
distinct_node_names.update([source_name, target_name])
elif graph_db_provider == "networkx":
nodes, edges = await graph_engine.get_graph_data()

View file

@ -7,12 +7,13 @@ import os
from cognee.api.v1.search import SearchType
from cognee.api.v1.visualize.visualize import visualize_graph
text_1 = """
1. Audi
Audi is known for its modern designs and advanced technology. Founded in the early 1900s, the brand has earned a reputation for precision engineering and innovation. With features like the Quattro all-wheel-drive system, Audi offers a range of vehicles from stylish sedans to high-performance sports cars.
2. BMW
BMW, short for Bayerische Motoren Werke, is celebrated for its focus on performance and driving pleasure. The company’s vehicles are designed to provide a dynamic and engaging driving experience, and their slogan, "The Ultimate Driving Machine," reflects that commitment. BMW produces a variety of cars that combine luxury with sporty performance.
BMW, short for Bayerische Motoren Werke, is celebrated for its focus on performance and driving pleasure. The company's vehicles are designed to provide a dynamic and engaging driving experience, and their slogan, "The Ultimate Driving Machine," reflects that commitment. BMW produces a variety of cars that combine luxury with sporty performance.
3. Mercedes-Benz
Mercedes-Benz is synonymous with luxury and quality. With a history dating back to the early 20th century, the brand is known for its elegant designs, innovative safety features, and high-quality engineering. Mercedes-Benz manufactures not only luxury sedans but also SUVs, sports cars, and commercial vehicles, catering to a wide range of needs.
@ -21,7 +22,7 @@ Mercedes-Benz is synonymous with luxury and quality. With a history dating back
Porsche is a name that stands for high-performance sports cars. Founded in 1931, the brand has become famous for models like the iconic Porsche 911. Porsche cars are celebrated for their speed, precision, and distinctive design, appealing to car enthusiasts who value both performance and style.
5. Volkswagen
Volkswagen, which means “people’s car” in German, was established with the idea of making affordable and reliable vehicles accessible to everyone. Over the years, Volkswagen has produced several iconic models, such as the Beetle and the Golf. Today, it remains one of the largest car manufacturers in the world, offering a wide range of vehicles that balance practicality with quality.
Volkswagen, which means "people's car" in German, was established with the idea of making affordable and reliable vehicles accessible to everyone. Over the years, Volkswagen has produced several iconic models, such as the Beetle and the Golf. Today, it remains one of the largest car manufacturers in the world, offering a wide range of vehicles that balance practicality with quality.
Each of these car manufacturers contributes to Germany's reputation as a leader in the global automotive industry, showcasing a blend of innovation, performance, and design excellence.
"""
@ -31,16 +32,16 @@ text_2 = """
Apple is renowned for its innovative consumer electronics and software. Its product lineup includes the iPhone, iPad, Mac computers, and wearables like the Apple Watch. Known for its emphasis on sleek design and user-friendly interfaces, Apple has built a loyal customer base and created a seamless ecosystem that integrates hardware, software, and services.
2. Google
Founded in 1998, Google started as a search engine and quickly became the go-to resource for finding information online. Over the years, the company has diversified its offerings to include digital advertising, cloud computing, mobile operating systems (Android), and various web services like Gmail and Google Maps. Google’s innovations have played a major role in shaping the internet landscape.
Founded in 1998, Google started as a search engine and quickly became the go-to resource for finding information online. Over the years, the company has diversified its offerings to include digital advertising, cloud computing, mobile operating systems (Android), and various web services like Gmail and Google Maps. Google's innovations have played a major role in shaping the internet landscape.
3. Microsoft
Microsoft Corporation has been a dominant force in software for decades. Its Windows operating system and Microsoft Office suite are staples in both business and personal computing. In recent years, Microsoft has expanded into cloud computing with Azure, gaming with the Xbox platform, and even hardware through products like the Surface line. This evolution has helped the company maintain its relevance in a rapidly changing tech world.
4. Amazon
What began as an online bookstore has grown into one of the largest e-commerce platforms globally. Amazon is known for its vast online marketplace, but its influence extends far beyond retail. With Amazon Web Services (AWS), the company has become a leader in cloud computing, offering robust solutions that power websites, applications, and businesses around the world. Amazon’s constant drive for innovation continues to reshape both retail and technology sectors.
What began as an online bookstore has grown into one of the largest e-commerce platforms globally. Amazon is known for its vast online marketplace, but its influence extends far beyond retail. With Amazon Web Services (AWS), the company has become a leader in cloud computing, offering robust solutions that power websites, applications, and businesses around the world. Amazon's constant drive for innovation continues to reshape both retail and technology sectors.
5. Meta
Meta, originally known as Facebook, revolutionized social media by connecting billions of people worldwide. Beyond its core social networking service, Meta is investing in the next generation of digital experiences through virtual and augmented reality technologies, with projects like Oculus. The company’s efforts signal a commitment to evolving digital interaction and building the metaverse—a shared virtual space where users can connect and collaborate.
Meta, originally known as Facebook, revolutionized social media by connecting billions of people worldwide. Beyond its core social networking service, Meta is investing in the next generation of digital experiences through virtual and augmented reality technologies, with projects like Oculus. The company's efforts signal a commitment to evolving digital interaction and building the metaverse—a shared virtual space where users can connect and collaborate.
Each of these companies has significantly impacted the technology landscape, driving innovation and transforming everyday life through their groundbreaking products and services.
"""

871
notebooks/cognee_demo.ipynb Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

978
notebooks/hr_demo.ipynb Normal file
View file

@ -0,0 +1,978 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d35ac8ce-0f92-46f5-9ba4-a46970f0ce19",
"metadata": {},
"source": [
"# Cognee - Get Started"
]
},
{
"cell_type": "markdown",
"id": "074f0ea8-c659-4736-be26-be4b0e5ac665",
"metadata": {},
"source": [
"# Demo time"
]
},
{
"cell_type": "markdown",
"id": "0587d91d",
"metadata": {},
"source": [
"#### First let's define some data that we will cognify and perform a search on"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df16431d0f48b006",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-20T14:02:48.519686Z",
"start_time": "2024-09-20T14:02:48.515589Z"
}
},
"outputs": [],
"source": [
"job_position = \"\"\"Senior Data Scientist (Machine Learning)\n",
"\n",
"Company: TechNova Solutions\n",
"Location: San Francisco, CA\n",
"\n",
"Job Description:\n",
"\n",
"TechNova Solutions is seeking a Senior Data Scientist specializing in Machine Learning to join our dynamic analytics team. The ideal candidate will have a strong background in developing and deploying machine learning models, working with large datasets, and translating complex data into actionable insights.\n",
"\n",
"Responsibilities:\n",
"\n",
"Develop and implement advanced machine learning algorithms and models.\n",
"Analyze large, complex datasets to extract meaningful patterns and insights.\n",
"Collaborate with cross-functional teams to integrate predictive models into products.\n",
"Stay updated with the latest advancements in machine learning and data science.\n",
"Mentor junior data scientists and provide technical guidance.\n",
"Qualifications:\n",
"\n",
"Masters or Ph.D. in Data Science, Computer Science, Statistics, or a related field.\n",
"5+ years of experience in data science and machine learning.\n",
"Proficient in Python, R, and SQL.\n",
"Experience with deep learning frameworks (e.g., TensorFlow, PyTorch).\n",
"Strong problem-solving skills and attention to detail.\n",
"Candidate CVs\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9086abf3af077ab4",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-20T14:02:49.120838Z",
"start_time": "2024-09-20T14:02:49.118294Z"
}
},
"outputs": [],
"source": [
"job_1 = \"\"\"\n",
"CV 1: Relevant\n",
"Name: Dr. Emily Carter\n",
"Contact Information:\n",
"\n",
"Email: emily.carter@example.com\n",
"Phone: (555) 123-4567\n",
"Summary:\n",
"\n",
"Senior Data Scientist with over 8 years of experience in machine learning and predictive analytics. Expertise in developing advanced algorithms and deploying scalable models in production environments.\n",
"\n",
"Education:\n",
"\n",
"Ph.D. in Computer Science, Stanford University (2014)\n",
"B.S. in Mathematics, University of California, Berkeley (2010)\n",
"Experience:\n",
"\n",
"Senior Data Scientist, InnovateAI Labs (2016 Present)\n",
"Led a team in developing machine learning models for natural language processing applications.\n",
"Implemented deep learning algorithms that improved prediction accuracy by 25%.\n",
"Collaborated with cross-functional teams to integrate models into cloud-based platforms.\n",
"Data Scientist, DataWave Analytics (2014 2016)\n",
"Developed predictive models for customer segmentation and churn analysis.\n",
"Analyzed large datasets using Hadoop and Spark frameworks.\n",
"Skills:\n",
"\n",
"Programming Languages: Python, R, SQL\n",
"Machine Learning: TensorFlow, Keras, Scikit-Learn\n",
"Big Data Technologies: Hadoop, Spark\n",
"Data Visualization: Tableau, Matplotlib\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9de0cc07f798b7f",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-20T14:02:49.675003Z",
"start_time": "2024-09-20T14:02:49.671615Z"
}
},
"outputs": [],
"source": [
"job_2 = \"\"\"\n",
"CV 2: Relevant\n",
"Name: Michael Rodriguez\n",
"Contact Information:\n",
"\n",
"Email: michael.rodriguez@example.com\n",
"Phone: (555) 234-5678\n",
"Summary:\n",
"\n",
"Data Scientist with a strong background in machine learning and statistical modeling. Skilled in handling large datasets and translating data into actionable business insights.\n",
"\n",
"Education:\n",
"\n",
"M.S. in Data Science, Carnegie Mellon University (2013)\n",
"B.S. in Computer Science, University of Michigan (2011)\n",
"Experience:\n",
"\n",
"Senior Data Scientist, Alpha Analytics (2017 Present)\n",
"Developed machine learning models to optimize marketing strategies.\n",
"Reduced customer acquisition cost by 15% through predictive modeling.\n",
"Data Scientist, TechInsights (2013 2017)\n",
"Analyzed user behavior data to improve product features.\n",
"Implemented A/B testing frameworks to evaluate product changes.\n",
"Skills:\n",
"\n",
"Programming Languages: Python, Java, SQL\n",
"Machine Learning: Scikit-Learn, XGBoost\n",
"Data Visualization: Seaborn, Plotly\n",
"Databases: MySQL, MongoDB\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "185ff1c102d06111",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-20T14:02:50.286828Z",
"start_time": "2024-09-20T14:02:50.284369Z"
}
},
"outputs": [],
"source": [
"job_3 = \"\"\"\n",
"CV 3: Relevant\n",
"Name: Sarah Nguyen\n",
"Contact Information:\n",
"\n",
"Email: sarah.nguyen@example.com\n",
"Phone: (555) 345-6789\n",
"Summary:\n",
"\n",
"Data Scientist specializing in machine learning with 6 years of experience. Passionate about leveraging data to drive business solutions and improve product performance.\n",
"\n",
"Education:\n",
"\n",
"M.S. in Statistics, University of Washington (2014)\n",
"B.S. in Applied Mathematics, University of Texas at Austin (2012)\n",
"Experience:\n",
"\n",
"Data Scientist, QuantumTech (2016 Present)\n",
"Designed and implemented machine learning algorithms for financial forecasting.\n",
"Improved model efficiency by 20% through algorithm optimization.\n",
"Junior Data Scientist, DataCore Solutions (2014 2016)\n",
"Assisted in developing predictive models for supply chain optimization.\n",
"Conducted data cleaning and preprocessing on large datasets.\n",
"Skills:\n",
"\n",
"Programming Languages: Python, R\n",
"Machine Learning Frameworks: PyTorch, Scikit-Learn\n",
"Statistical Analysis: SAS, SPSS\n",
"Cloud Platforms: AWS, Azure\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d55ce4c58f8efb67",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-20T14:02:50.950343Z",
"start_time": "2024-09-20T14:02:50.946378Z"
}
},
"outputs": [],
"source": [
"job_4 = \"\"\"\n",
"CV 4: Not Relevant\n",
"Name: David Thompson\n",
"Contact Information:\n",
"\n",
"Email: david.thompson@example.com\n",
"Phone: (555) 456-7890\n",
"Summary:\n",
"\n",
"Creative Graphic Designer with over 8 years of experience in visual design and branding. Proficient in Adobe Creative Suite and passionate about creating compelling visuals.\n",
"\n",
"Education:\n",
"\n",
"B.F.A. in Graphic Design, Rhode Island School of Design (2012)\n",
"Experience:\n",
"\n",
"Senior Graphic Designer, CreativeWorks Agency (2015 Present)\n",
"Led design projects for clients in various industries.\n",
"Created branding materials that increased client engagement by 30%.\n",
"Graphic Designer, Visual Innovations (2012 2015)\n",
"Designed marketing collateral, including brochures, logos, and websites.\n",
"Collaborated with the marketing team to develop cohesive brand strategies.\n",
"Skills:\n",
"\n",
"Design Software: Adobe Photoshop, Illustrator, InDesign\n",
"Web Design: HTML, CSS\n",
"Specialties: Branding and Identity, Typography\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca4ecc32721ad332",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-20T14:02:51.548191Z",
"start_time": "2024-09-20T14:02:51.545520Z"
}
},
"outputs": [],
"source": [
"job_5 = \"\"\"\n",
"CV 5: Not Relevant\n",
"Name: Jessica Miller\n",
"Contact Information:\n",
"\n",
"Email: jessica.miller@example.com\n",
"Phone: (555) 567-8901\n",
"Summary:\n",
"\n",
"Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams. Excellent communication and leadership skills.\n",
"\n",
"Education:\n",
"\n",
"B.A. in Business Administration, University of Southern California (2010)\n",
"Experience:\n",
"\n",
"Sales Manager, Global Enterprises (2015 Present)\n",
"Managed a sales team of 15 members, achieving a 20% increase in annual revenue.\n",
"Developed sales strategies that expanded customer base by 25%.\n",
"Sales Representative, Market Leaders Inc. (2010 2015)\n",
"Consistently exceeded sales targets and received the 'Top Salesperson' award in 2013.\n",
"Skills:\n",
"\n",
"Sales Strategy and Planning\n",
"Team Leadership and Development\n",
"CRM Software: Salesforce, Zoho\n",
"Negotiation and Relationship Building\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"id": "4415446a",
"metadata": {},
"source": [
"#### Please add the necessary environment information bellow:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bce39dc6",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Setting environment variables\n",
"if \"GRAPHISTRY_USERNAME\" not in os.environ:\n",
" os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n",
"\n",
"if \"GRAPHISTRY_PASSWORD\" not in os.environ:\n",
" os.environ[\"GRAPHISTRY_PASSWORD\"] = \"\"\n",
"\n",
"if \"LLM_API_KEY\" not in os.environ:\n",
" os.environ[\"LLM_API_KEY\"] = \"\"\n",
"\n",
"# \"neo4j\" or \"networkx\"\n",
"os.environ[\"GRAPH_DATABASE_PROVIDER\"] = \"networkx\"\n",
"# Not needed if using networkx\n",
"# os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
"# os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
"# os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
"\n",
"# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
"os.environ[\"VECTOR_DB_PROVIDER\"] = \"lancedb\"\n",
"# Not needed if using \"lancedb\" or \"pgvector\"\n",
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
"\n",
"# Relational Database provider \"sqlite\" or \"postgres\"\n",
"os.environ[\"DB_PROVIDER\"] = \"sqlite\"\n",
"\n",
"# Database name\n",
"os.environ[\"DB_NAME\"] = \"cognee_db\"\n",
"\n",
"# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
"# os.environ[\"DB_PORT\"]=\"5432\"\n",
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
"# os.environ[\"DB_PASSWORD\"]=\"cognee\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f1a1dbd",
"metadata": {},
"outputs": [],
"source": [
"# Reset the cognee system with the following command:\n",
"\n",
"import cognee\n",
"\n",
"await cognee.prune.prune_data()\n",
"await cognee.prune.prune_system(metadata=True)"
]
},
{
"cell_type": "markdown",
"id": "383d6971",
"metadata": {},
"source": [
"#### After we have defined and gathered our data let's add it to cognee "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "904df61ba484a8e5",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-20T14:02:54.243987Z",
"start_time": "2024-09-20T14:02:52.498195Z"
}
},
"outputs": [],
"source": [
"import cognee\n",
"\n",
"await cognee.add([job_1, job_2, job_3, job_4, job_5, job_position], \"example\")"
]
},
{
"cell_type": "markdown",
"id": "0f15c5b1",
"metadata": {},
"source": [
"#### All good, let's cognify it."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c431fdef4921ae0",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-20T14:02:57.925667Z",
"start_time": "2024-09-20T14:02:57.922353Z"
}
},
"outputs": [],
"source": [
"from cognee.shared.data_models import KnowledgeGraph\n",
"from cognee.modules.data.models import Dataset, Data\n",
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
"from cognee.modules.cognify.config import get_cognify_config\n",
"from cognee.modules.pipelines.tasks.Task import Task\n",
"from cognee.modules.pipelines import run_tasks\n",
"from cognee.modules.users.models import User\n",
"from cognee.tasks.documents import (\n",
" check_permissions_on_documents,\n",
" classify_documents,\n",
" extract_chunks_from_documents,\n",
")\n",
"from cognee.tasks.graph import extract_graph_from_data\n",
"from cognee.tasks.storage import add_data_points\n",
"from cognee.tasks.summarization import summarize_text\n",
"\n",
"\n",
"async def run_cognify_pipeline(dataset: Dataset, user: User = None):\n",
" data_documents: list[Data] = await get_dataset_data(dataset_id=dataset.id)\n",
"\n",
" try:\n",
" cognee_config = get_cognify_config()\n",
"\n",
" tasks = [\n",
" Task(classify_documents),\n",
" Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
" Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
" Task(\n",
" extract_graph_from_data, graph_model=KnowledgeGraph,\n",
" task_config={\"batch_size\": 10}\n",
" ), # Generate knowledge graphs from the document chunks.\n",
" Task(\n",
" summarize_text,\n",
" summarization_model=cognee_config.summarization_model,\n",
" task_config={\"batch_size\": 10},\n",
" ),\n",
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
" ]\n",
"\n",
" pipeline = run_tasks(tasks, data_documents)\n",
"\n",
" async for result in pipeline:\n",
" print(result)\n",
" except Exception as error:\n",
" raise error"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0a91b99c6215e09",
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-20T14:02:58.905774Z",
"start_time": "2024-09-20T14:02:58.625915Z"
}
},
"outputs": [],
"source": [
"from cognee.modules.users.methods import get_default_user\n",
"from cognee.modules.data.methods import get_datasets_by_name\n",
"\n",
"user = await get_default_user()\n",
"\n",
"datasets = await get_datasets_by_name([\"example\"], user.id)\n",
"\n",
"await run_cognify_pipeline(datasets[0], user)"
]
},
{
"cell_type": "markdown",
"id": "219a6d41",
"metadata": {},
"source": [
"#### We get the url to the graph on graphistry in the notebook cell bellow, showing nodes and connections made by the cognify process."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "080389e5",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from cognee.shared.utils import render_graph\n",
"from cognee.infrastructure.databases.graph import get_graph_engine\n",
"import graphistry\n",
"\n",
"graphistry.login(\n",
" username=os.getenv(\"GRAPHISTRY_USERNAME\"), password=os.getenv(\"GRAPHISTRY_PASSWORD\")\n",
")\n",
"\n",
"graph_engine = await get_graph_engine()\n",
"\n",
"graph_url = await render_graph(graph_engine.graph)\n",
"print(graph_url)"
]
},
{
"cell_type": "markdown",
"id": "59e6c3c3",
"metadata": {},
"source": [
"#### We can also do a search on the data to explore the knowledge."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5e7dfc8",
"metadata": {},
"outputs": [],
"source": [
"async def search(\n",
" vector_engine,\n",
" collection_name: str,\n",
" query_text: str = None,\n",
"):\n",
" query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0]\n",
"\n",
" connection = await vector_engine.get_connection()\n",
" collection = await connection.open_table(collection_name)\n",
"\n",
" results = await collection.vector_search(query_vector).limit(10).to_pandas()\n",
"\n",
" result_values = list(results.to_dict(\"index\").values())\n",
"\n",
" return [\n",
" dict(\n",
" id=str(result[\"id\"]),\n",
" payload=result[\"payload\"],\n",
" score=result[\"_distance\"],\n",
" )\n",
" for result in result_values\n",
" ]\n",
"\n",
"\n",
"from cognee.infrastructure.databases.vector import get_vector_engine\n",
"\n",
"vector_engine = get_vector_engine()\n",
"results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
"for result in results:\n",
" print(result)"
]
},
{
"cell_type": "markdown",
"id": "81fa2b00",
"metadata": {},
"source": [
"#### We normalize search output scores so the lower the score of the search result is the higher the chance that it's what you're looking for. In the example above we have searched for node entities in the knowledge graph related to \"sarah.nguyen@example.com\""
]
},
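{
"cell_type": "code",
"execution_count": null,
"id": "score-sort-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (not part of the original notebook): because lower scores mean\n",
"# closer matches, sorting the results from the previous cell ascending by score\n",
"# surfaces the best candidates first.\n",
"best_matches = sorted(results, key=lambda result: result[\"score\"])[:3]\n",
"for match in best_matches:\n",
"    print(match[\"id\"], match[\"score\"])"
]
},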
{
"cell_type": "markdown",
"id": "1b94ff96",
"metadata": {},
"source": [
"#### In the example bellow we'll use cognee search to summarize information regarding the node most related to \"sarah.nguyen@example.com\" in the knowledge graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21a3e9a6",
"metadata": {},
"outputs": [],
"source": [
"from cognee.api.v1.search import SearchType\n",
"\n",
"node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
"node_name = node.payload[\"text\"]\n",
"\n",
"search_results = await cognee.search(query_type=SearchType.SUMMARIES, query_text=node_name)\n",
"print(\"\\n\\Extracted summaries are:\\n\")\n",
"for result in search_results:\n",
" print(f\"{result}\\n\")"
]
},
{
"cell_type": "markdown",
"id": "fd6e5fe2",
"metadata": {},
"source": [
"#### In this example we'll use cognee search to find chunks in which the node most related to \"sarah.nguyen@example.com\" is a part of"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7a8abff",
"metadata": {},
"outputs": [],
"source": [
"search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=node_name)\n",
"print(\"\\n\\nExtracted chunks are:\\n\")\n",
"for result in search_results:\n",
" print(f\"{result}\\n\")"
]
},
{
"cell_type": "markdown",
"id": "47f0112f",
"metadata": {},
"source": [
"#### In this example we'll use cognee search to give us insights from the knowledge graph related to the node most related to \"sarah.nguyen@example.com\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "706a3954",
"metadata": {},
"outputs": [],
"source": [
"search_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text=node_name)\n",
"print(\"\\n\\nExtracted sentences are:\\n\")\n",
"for result in search_results:\n",
" print(f\"{result}\\n\")"
]
},
{
"cell_type": "markdown",
"id": "e519e30c0423c2a",
"metadata": {},
"source": [
"## Let's add evals"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3845443e",
"metadata": {},
"outputs": [],
"source": [
"!pip install \"cognee[deepeval]\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a2c3c70",
"metadata": {},
"outputs": [],
"source": [
"from evals.eval_on_hotpot import deepeval_answers, answer_qa_instance\n",
"from evals.qa_dataset_utils import load_qa_dataset\n",
"from evals.qa_metrics_utils import get_metrics\n",
"from evals.qa_context_provider_utils import qa_context_providers\n",
"from pathlib import Path\n",
"from tqdm import tqdm\n",
"import statistics\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53a609d8",
"metadata": {},
"outputs": [],
"source": [
"num_samples = 10 # With cognee, it takes ~1m10s per sample\n",
"dataset_name_or_filename = \"hotpotqa\"\n",
"dataset = load_qa_dataset(dataset_name_or_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7351ab8f",
"metadata": {},
"outputs": [],
"source": [
"context_provider_name = \"cognee\"\n",
"context_provider = qa_context_providers[context_provider_name]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9346115b",
"metadata": {},
"outputs": [],
"source": [
"random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n",
"out_path = \"out\"\n",
"if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
")\n",
"\n",
"answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n",
" answer = await answer_qa_instance(instance, context_provider, contexts_filename)\n",
" answers.append(answer)"
]
},
{
"cell_type": "markdown",
"id": "1e7d872d",
"metadata": {},
"source": [
"#### Define Metrics for Evaluation and Calculate Score\n",
"**Options**: \n",
"- **Correctness**: Is the actual output factually correct based on the expected output?\n",
"- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?\n",
"- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?\n",
"- **Empowerment**: How well does the answer help the reader understand and make informed judgements about the topic?\n",
"- **Directness**: How specifically and clearly does the answer address the question?\n",
"- **F1 Score**: the harmonic mean of the precision and recall, using word-level Exact Match\n",
"- **EM Score**: the rate at which the predicted strings exactly match their references, ignoring white spaces and capitalization."
]
},
{
"cell_type": "markdown",
"id": "c81e2b46",
"metadata": {},
"source": [
"##### Calculate `\"Correctness\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ae728344",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "764aac6d",
"metadata": {},
"outputs": [],
"source": [
"Correctness = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Correctness)"
]
},
{
"cell_type": "markdown",
"id": "6d3bbdc5",
"metadata": {},
"source": [
"##### Calculating `\"Comprehensiveness\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9793ef78",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9add448a",
"metadata": {},
"outputs": [],
"source": [
"Comprehensiveness = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Comprehensiveness)"
]
},
{
"cell_type": "markdown",
"id": "bce2fa25",
"metadata": {},
"source": [
"##### Calculating `\"Diversity\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f60a179e",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ccbd0ab",
"metadata": {},
"outputs": [],
"source": [
"Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Diversity)"
]
},
{
"cell_type": "markdown",
"id": "191cab63",
"metadata": {},
"source": [
"##### Calculating`\"Empowerment\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66bec0bf",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b043a8f",
"metadata": {},
"outputs": [],
"source": [
"Empowerment = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Empowerment)"
]
},
{
"cell_type": "markdown",
"id": "2cac3be9",
"metadata": {},
"source": [
"##### Calculating `\"Directness\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adaa17c0",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a8f97c9",
"metadata": {},
"outputs": [],
"source": [
"Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Directness)"
]
},
{
"cell_type": "markdown",
"id": "1ad6feb8",
"metadata": {},
"source": [
"##### Calculating `\"F1\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bdc48259",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"F1\"]\n",
"eval_metrics = get_metrics(metric_name_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c43c17c8",
"metadata": {},
"outputs": [],
"source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8bfcc46d",
"metadata": {},
"outputs": [],
"source": [
"F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(F1_score)"
]
},
{
"cell_type": "markdown",
"id": "2583f948",
"metadata": {},
"source": [
"##### Calculating `\"EM\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90a8f630",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d1b1ea1",
"metadata": {},
"outputs": [],
"source": [
"EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(EM)"
]
},
{
"cell_type": "markdown",
"id": "288ab570",
"metadata": {},
"source": [
"# Give us a star if you like it!\n",
"https://github.com/topoteretes/cognee"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cognee-c83GrcRT-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}