feat: add delete by document (#668)

## Description

Adds delete-by-document support: a new `cognee.delete` API removes a document and all of its related nodes from the graph, vector, and relational databases. A "soft" mode (default) deletes the document subgraph only; a "hard" mode also deletes degree-one Entity/EntityType nodes left behind.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

parent af276b8999
commit 9ba12b25ef
30 changed files with 3435 additions and 258 deletions
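For context, a minimal usage sketch of the API this PR adds (the entry point is re-exported from the package root in the `__init__.py` hunk below; the file path is hypothetical and must refer to previously ingested data):

```python
import asyncio

import cognee


async def main():
    # "soft" removes the document subgraph; "hard" also removes degree-one
    # Entity/EntityType nodes that are no longer referenced.
    result = await cognee.delete(
        "file:///tmp/example.txt",  # hypothetical path; URLs and raw text also work
        dataset_name="main_dataset",
        mode="hard",
    )
    print(result["deleted_node_ids"])


asyncio.run(main())
```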
27 .github/workflows/e2e_tests.yml (vendored)
@@ -162,6 +162,33 @@ jobs:
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
        run: poetry run python ./cognee/tests/test_deduplication.py

  run-deletion-test:
    name: Deletion Test
    runs-on: ubuntu-22.04
    steps:
      - name: Check out
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'

      - name: Run Deletion Tests
        env:
          ENV: 'dev'
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
        run: poetry run python ./cognee/tests/test_deletion.py

  run-s3-bucket-test:
    name: S3 Bucket Test
    runs-on: ubuntu-22.04
@@ -1,4 +1,5 @@
 from .api.v1.add import add
+from .api.v1.delete import delete
 from .api.v1.cognify import cognify
 from .api.v1.config.config import config
 from .api.v1.datasets.datasets import datasets
@@ -43,11 +43,7 @@ def get_add_router() -> APIRouter:

                 return await cognee_add(file_data)
             else:
-                await cognee_add(
-                    data,
-                    datasetId,
-                    user=user,
-                )
+                await cognee_add(data, datasetId, user=user)
         except Exception as error:
             return JSONResponse(status_code=409, content={"error": str(error)})
@@ -58,7 +58,9 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
         Task(classify_documents),
         Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
         Task(
-            extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
+            extract_graph_from_data,
+            graph_model=KnowledgeGraph,
+            task_config={"batch_size": 50},
         ),
         Task(
             summarize_text,
1 cognee/api/v1/delete/__init__.py (new file)

@@ -0,0 +1 @@
from .delete import delete
247 cognee/api/v1/delete/delete.py (new file)

@@ -0,0 +1,247 @@
from typing import Union, BinaryIO, List
from cognee.modules.ingestion import classify
from cognee.infrastructure.databases.relational import get_relational_engine
from sqlalchemy import select
from sqlalchemy.sql import delete as sql_delete
from cognee.modules.data.models import Data, DatasetData, Dataset
from cognee.infrastructure.databases.graph import get_graph_engine
from io import StringIO, BytesIO
import hashlib
import asyncio
from uuid import UUID
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from .exceptions import DocumentNotFoundError, DatasetNotFoundError, DocumentSubgraphNotFoundError
from cognee.shared.logging_utils import get_logger

logger = get_logger()


def get_text_content_hash(text: str) -> str:
    encoded_text = text.encode("utf-8")
    return hashlib.md5(encoded_text).hexdigest()


async def delete(
    data: Union[BinaryIO, List[BinaryIO], str, List[str]],
    dataset_name: str = "main_dataset",
    mode: str = "soft",
):
    """Delete a document and all its related nodes from both relational and graph databases.

    Args:
        data: The data to delete (file, URL, or text)
        dataset_name: Name of the dataset to delete from
        mode: "soft" (default) or "hard" - hard mode also deletes degree-one entity nodes
    """

    # Handle different input types
    if isinstance(data, str):
        if data.startswith("file://"):  # It's a file path
            with open(data.replace("file://", ""), mode="rb") as file:
                classified_data = classify(file)
                content_hash = classified_data.get_metadata()["content_hash"]
                return await delete_single_document(content_hash, dataset_name, mode)
        elif data.startswith("http"):  # It's a URL
            import requests

            response = requests.get(data)
            response.raise_for_status()
            file_data = BytesIO(response.content)
            classified_data = classify(file_data)
            content_hash = classified_data.get_metadata()["content_hash"]
            return await delete_single_document(content_hash, dataset_name, mode)
        else:  # It's a text string
            content_hash = get_text_content_hash(data)
            classified_data = classify(data)
            return await delete_single_document(content_hash, dataset_name, mode)
    elif isinstance(data, list):
        # Handle list of inputs sequentially
        results = []
        for item in data:
            result = await delete(item, dataset_name, mode)
            results.append(result)
        return {"status": "success", "message": "Multiple documents deleted", "results": results}
    else:  # It's already a BinaryIO
        data.seek(0)  # Ensure we're at the start of the file
        classified_data = classify(data)
        content_hash = classified_data.get_metadata()["content_hash"]
        return await delete_single_document(content_hash, dataset_name, mode)


async def delete_single_document(content_hash: str, dataset_name: str, mode: str = "soft"):
    """Delete a single document by its content hash."""

    # Delete from graph database
    deletion_result = await delete_document_subgraph(content_hash, mode)

    logger.info(f"Deletion result: {deletion_result}")

    # Get the deleted node IDs and convert to UUID
    deleted_node_ids = []
    for node_id in deletion_result["deleted_node_ids"]:
        try:
            # Handle both string and UUID formats
            if isinstance(node_id, str):
                # Remove any hyphens if present
                node_id = node_id.replace("-", "")
                deleted_node_ids.append(UUID(node_id))
            else:
                deleted_node_ids.append(node_id)
        except Exception as e:
            logger.error(f"Error converting node ID {node_id} to UUID: {e}")
            continue

    # Delete from vector database
    vector_engine = get_vector_engine()

    # Determine vector collections dynamically
    subclasses = get_all_subclasses(DataPoint)
    vector_collections = []

    for subclass in subclasses:
        index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
        for field_name in index_fields:
            vector_collections.append(f"{subclass.__name__}_{field_name}")

    # If no collections found, use default collections
    if not vector_collections:
        vector_collections = [
            "DocumentChunk_text",
            "EdgeType_relationship_name",
            "EntityType_name",
            "Entity_name",
            "TextDocument_name",
            "TextSummary_text",
        ]

    # Delete records from each vector collection that exists
    for collection in vector_collections:
        if await vector_engine.has_collection(collection):
            await vector_engine.delete_data_points(
                collection, [str(node_id) for node_id in deleted_node_ids]
            )

    # Delete from relational database
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        # Update graph_relationship_ledger with deleted_at timestamps
        from sqlalchemy import update, and_, or_
        from datetime import datetime
        from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger

        update_stmt = (
            update(GraphRelationshipLedger)
            .where(
                or_(
                    GraphRelationshipLedger.source_node_id.in_(deleted_node_ids),
                    GraphRelationshipLedger.destination_node_id.in_(deleted_node_ids),
                )
            )
            .values(deleted_at=datetime.now())
        )
        await session.execute(update_stmt)

        # Get the data point
        data_point = (
            await session.execute(select(Data).filter(Data.content_hash == content_hash))
        ).scalar_one_or_none()

        if data_point is None:
            raise DocumentNotFoundError(
                f"Document not found in relational DB with content hash: {content_hash}"
            )

        doc_id = data_point.id

        # Get the dataset
        dataset = (
            await session.execute(select(Dataset).filter(Dataset.name == dataset_name))
        ).scalar_one_or_none()

        if dataset is None:
            raise DatasetNotFoundError(f"Dataset not found: {dataset_name}")

        # Delete from dataset_data table
        dataset_delete_stmt = sql_delete(DatasetData).where(
            DatasetData.data_id == doc_id, DatasetData.dataset_id == dataset.id
        )
        await session.execute(dataset_delete_stmt)

        # Check if the document is in any other datasets
        remaining_datasets = (
            await session.execute(select(DatasetData).filter(DatasetData.data_id == doc_id))
        ).scalar_one_or_none()

        # If the document is not in any other datasets, delete it from the data table
        if remaining_datasets is None:
            data_delete_stmt = sql_delete(Data).where(Data.id == doc_id)
            await session.execute(data_delete_stmt)

        await session.commit()

    return {
        "status": "success",
        "message": "Document deleted from both graph and relational databases",
        "graph_deletions": deletion_result["deleted_counts"],
        "content_hash": content_hash,
        "dataset": dataset_name,
        "deleted_node_ids": [
            str(node_id) for node_id in deleted_node_ids
        ],  # Convert back to strings for response
    }


async def delete_document_subgraph(content_hash: str, mode: str = "soft"):
    """Delete a document and all its related nodes in the correct order."""
    graph_db = await get_graph_engine()
    subgraph = await graph_db.get_document_subgraph(content_hash)
    if not subgraph:
        raise DocumentSubgraphNotFoundError(f"Document not found with content hash: {content_hash}")

    # Delete in the correct order to maintain graph integrity
    deletion_order = [
        ("orphan_entities", "orphaned entities"),
        ("orphan_types", "orphaned entity types"),
        (
            "made_from_nodes",
            "made_from nodes",
        ),  # Move before chunks since summaries are connected to chunks
        ("chunks", "document chunks"),
        ("document", "document"),
    ]

    deleted_counts = {}
    deleted_node_ids = []
    for key, description in deletion_order:
        nodes = subgraph[key]
        if nodes:
            for node in nodes:
                node_id = node["id"]
                await graph_db.delete_node(node_id)
                deleted_node_ids.append(node_id)
            deleted_counts[description] = len(nodes)

    # If hard mode, also delete degree-one nodes
    if mode == "hard":
        # Get and delete degree one entity nodes
        degree_one_entity_nodes = await graph_db.get_degree_one_nodes("Entity")
        for node in degree_one_entity_nodes:
            await graph_db.delete_node(node["id"])
            deleted_node_ids.append(node["id"])
            deleted_counts["degree_one_entities"] = deleted_counts.get("degree_one_entities", 0) + 1

        # Get and delete degree one entity types
        degree_one_entity_types = await graph_db.get_degree_one_nodes("EntityType")
        for node in degree_one_entity_types:
            await graph_db.delete_node(node["id"])
            deleted_node_ids.append(node["id"])
            deleted_counts["degree_one_types"] = deleted_counts.get("degree_one_types", 0) + 1

    return {
        "status": "success",
        "deleted_counts": deleted_counts,
        "content_hash": content_hash,
        "deleted_node_ids": deleted_node_ids,
    }
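As a sanity check on the plain-text path above: for a raw string input, the lookup key is simply the MD5 hex digest of the UTF-8 bytes, so the following sketch reproduces the `content_hash` that `delete_single_document` searches for (the sample text is illustrative and must match the previously ingested text exactly):

```python
import hashlib

text = "Natural language processing is fun."  # illustrative; must equal the ingested text

content_hash = hashlib.md5(text.encode("utf-8")).hexdigest()

# delete() passes this hash down; the graph adapters then look for a document
# node whose name is f"text_{content_hash}" (see get_document_subgraph further below).
print(content_hash, f"text_{content_hash}")
```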
38 cognee/api/v1/delete/exceptions.py (new file)

@@ -0,0 +1,38 @@
from cognee.exceptions import CogneeApiError
from fastapi import status


class DocumentNotFoundError(CogneeApiError):
    """Raised when a document cannot be found in the database."""

    def __init__(
        self,
        message: str = "Document not found in database.",
        name: str = "DocumentNotFoundError",
        status_code: int = status.HTTP_404_NOT_FOUND,
    ):
        super().__init__(message, name, status_code)


class DatasetNotFoundError(CogneeApiError):
    """Raised when a dataset cannot be found."""

    def __init__(
        self,
        message: str = "Dataset not found.",
        name: str = "DatasetNotFoundError",
        status_code: int = status.HTTP_404_NOT_FOUND,
    ):
        super().__init__(message, name, status_code)


class DocumentSubgraphNotFoundError(CogneeApiError):
    """Raised when a document's subgraph cannot be found in the graph database."""

    def __init__(
        self,
        message: str = "Document subgraph not found in graph database.",
        name: str = "DocumentSubgraphNotFoundError",
        status_code: int = status.HTTP_404_NOT_FOUND,
    ):
        super().__init__(message, name, status_code)
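A short sketch of how a caller might distinguish the new error cases (exception classes as defined above; the call itself mirrors the top-level API and the helper name is illustrative):

```python
from cognee.api.v1.delete import delete
from cognee.api.v1.delete.exceptions import (
    DatasetNotFoundError,
    DocumentNotFoundError,
    DocumentSubgraphNotFoundError,
)


async def delete_if_present(text: str) -> bool:
    """Illustrative helper: delete a text document, tolerating 'not found'."""
    try:
        await delete(text, dataset_name="main_dataset")
        return True
    except (DocumentNotFoundError, DocumentSubgraphNotFoundError):
        # The document was never ingested, or its subgraph is already gone.
        return False
    except DatasetNotFoundError:
        # The dataset name does not exist in the relational store.
        raise
```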
1 cognee/api/v1/delete/routers/__init__.py (new file)

@@ -0,0 +1 @@
from .get_delete_router import get_delete_router
77 cognee/api/v1/delete/routers/get_delete_router.py (new file)

@@ -0,0 +1,77 @@
from fastapi import Form, UploadFile, Depends
from fastapi.responses import JSONResponse
from fastapi import APIRouter
from typing import List, Optional
import subprocess
from cognee.shared.logging_utils import get_logger
import requests
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user

logger = get_logger()


def get_delete_router() -> APIRouter:
    router = APIRouter()

    @router.delete("/", response_model=None)
    async def delete(
        data: List[UploadFile],
        dataset_name: str = Form("main_dataset"),
        mode: str = Form("soft"),
        user: User = Depends(get_authenticated_user),
    ):
        """This endpoint is responsible for deleting data from the graph.

        Args:
            data: The data to delete (files, URLs, or text)
            dataset_name: Name of the dataset to delete from (default: "main_dataset")
            mode: "soft" (default) or "hard" - hard mode also deletes degree-one entity nodes
            user: Authenticated user
        """
        from cognee.api.v1.delete import delete as cognee_delete

        try:
            # Handle each file in the list
            results = []
            for file in data:
                if file.filename.startswith("http"):
                    if "github" in file.filename:
                        # For GitHub repos, we need to get the content hash of each file
                        repo_name = file.filename.split("/")[-1].replace(".git", "")
                        subprocess.run(
                            ["git", "clone", file.filename, f".data/{repo_name}"], check=True
                        )
                        # Note: This would need to be implemented to get content hashes of all files
                        # For now, we'll just return an error
                        return JSONResponse(
                            status_code=400,
                            content={"error": "Deleting GitHub repositories is not yet supported"},
                        )
                    else:
                        # Fetch and delete the data from other types of URL
                        response = requests.get(file.filename)
                        response.raise_for_status()
                        file_data = response.content
                        result = await cognee_delete(
                            file_data, dataset_name=dataset_name, mode=mode
                        )
                        results.append(result)
                else:
                    # Handle uploaded file
                    result = await cognee_delete(file, dataset_name=dataset_name, mode=mode)
                    results.append(result)

            if len(results) == 1:
                return results[0]
            else:
                return {
                    "status": "success",
                    "message": "Multiple documents deleted",
                    "results": results,
                }
        except Exception as error:
            logger.error(f"Error during deletion: {str(error)}")
            return JSONResponse(status_code=409, content={"error": str(error)})

    return router
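For completeness, a hedged sketch of exercising the endpoint over HTTP with `httpx`. The base URL and the `/api/v1/delete` mount point are assumptions; the router prefix is configured wherever `get_delete_router()` is included, not in this file:

```python
import httpx

# Assumed deployment details; adjust to however get_delete_router() is mounted.
BASE_URL = "http://localhost:8000"

with open("example.txt", "rb") as f:  # hypothetical, previously ingested file
    response = httpx.request(
        "DELETE",
        f"{BASE_URL}/api/v1/delete/",
        files={"data": ("example.txt", f, "text/plain")},
        data={"dataset_name": "main_dataset", "mode": "soft"},
    )

response.raise_for_status()
print(response.json())  # status, graph_deletions, deleted_node_ids, ...
```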
@ -1,61 +1,195 @@
|
|||
from typing import Protocol, Optional, Dict, Any
|
||||
from abc import abstractmethod
|
||||
from typing import Protocol, Optional, Dict, Any, List, Tuple
|
||||
from abc import abstractmethod, ABC
|
||||
from uuid import UUID, uuid5, NAMESPACE_DNS
|
||||
from cognee.modules.graph.relationship_manager import create_relationship
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
|
||||
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from datetime import datetime, timezone
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Type aliases for better readability
|
||||
NodeData = Dict[str, Any]
|
||||
EdgeData = Tuple[
|
||||
str, str, str, Dict[str, Any]
|
||||
] # (source_id, target_id, relationship_name, properties)
|
||||
Node = Tuple[str, NodeData] # (node_id, properties)
|
||||
|
||||
|
||||
class GraphDBInterface(Protocol):
|
||||
def record_graph_changes(func):
|
||||
"""Decorator to record graph changes in the relationship database."""
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if frame.f_back and frame.f_back.f_code.co_name != "wrapper":
|
||||
caller_frame = frame.f_back
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
caller_name = caller_frame.f_code.co_name
|
||||
caller_class = (
|
||||
caller_frame.f_locals.get("self", None).__class__.__name__
|
||||
if caller_frame.f_locals.get("self", None)
|
||||
else None
|
||||
)
|
||||
creator = f"{caller_class}.{caller_name}" if caller_class else caller_name
|
||||
|
||||
result = await func(self, *args, **kwargs)
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
if func.__name__ == "add_nodes":
|
||||
nodes = args[0]
|
||||
for node in nodes:
|
||||
try:
|
||||
node_id = (
|
||||
UUID(str(node[0])) if isinstance(node, tuple) else UUID(str(node.id))
|
||||
)
|
||||
relationship = GraphRelationshipLedger(
|
||||
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
|
||||
source_node_id=node_id,
|
||||
destination_node_id=node_id,
|
||||
creator_function=f"{creator}.node",
|
||||
node_label=node[1].get("type")
|
||||
if isinstance(node, tuple)
|
||||
else type(node).__name__,
|
||||
)
|
||||
session.add(relationship)
|
||||
await session.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding relationship: {e}")
|
||||
await session.rollback()
|
||||
continue
|
||||
|
||||
elif func.__name__ == "add_edges":
|
||||
edges = args[0]
|
||||
for edge in edges:
|
||||
try:
|
||||
source_id = UUID(str(edge[0]))
|
||||
target_id = UUID(str(edge[1]))
|
||||
rel_type = str(edge[2])
|
||||
relationship = GraphRelationshipLedger(
|
||||
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
|
||||
source_node_id=source_id,
|
||||
destination_node_id=target_id,
|
||||
creator_function=f"{creator}.{rel_type}",
|
||||
)
|
||||
session.add(relationship)
|
||||
await session.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding relationship: {e}")
|
||||
await session.rollback()
|
||||
continue
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error committing session: {e}")
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class GraphDBInterface(ABC):
|
||||
"""Interface for graph database operations."""
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, query: str, params: dict):
|
||||
async def query(self, query: str, params: dict) -> List[Any]:
|
||||
"""Execute a raw query against the database."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def add_node(self, node_id: str, node_properties: dict):
|
||||
async def add_node(self, node_id: str, properties: Dict[str, Any]) -> None:
|
||||
"""Add a single node to the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def add_nodes(self, nodes: list[tuple[str, dict]]):
|
||||
@record_graph_changes
|
||||
async def add_nodes(self, nodes: List[Node]) -> None:
|
||||
"""Add multiple nodes to the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete_node(self, node_id: str):
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""Delete a node from the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete_nodes(self, node_ids: list[str]):
|
||||
async def delete_nodes(self, node_ids: List[str]) -> None:
|
||||
"""Delete multiple nodes from the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def extract_node(self, node_id: str):
|
||||
async def get_node(self, node_id: str) -> Optional[NodeData]:
|
||||
"""Get a single node by ID."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def extract_nodes(self, node_ids: list[str]):
|
||||
async def get_nodes(self, node_ids: List[str]) -> List[NodeData]:
|
||||
"""Get multiple nodes by their IDs."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def add_edge(
|
||||
self,
|
||||
from_node: str,
|
||||
to_node: str,
|
||||
source_id: str,
|
||||
target_id: str,
|
||||
relationship_name: str,
|
||||
edge_properties: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Add a single edge to the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def add_edges(self, edges: tuple[str, str, str, dict]):
|
||||
@record_graph_changes
|
||||
async def add_edges(self, edges: List[EdgeData]) -> None:
|
||||
"""Add multiple edges to the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete_graph(
|
||||
self,
|
||||
):
|
||||
async def delete_graph(self) -> None:
|
||||
"""Delete the entire graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_graph_data(self):
|
||||
async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
|
||||
"""Get all nodes and edges in the graph."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_graph_metrics(self, include_optional):
|
||||
""" "https://docs.cognee.ai/core_concepts/graph_generation/descriptive_metrics"""
|
||||
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
|
||||
"""Get graph metrics and statistics."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
|
||||
"""Check if an edge exists."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def has_edges(self, edges: List[EdgeData]) -> List[EdgeData]:
|
||||
"""Check if multiple edges exist."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_edges(self, node_id: str) -> List[EdgeData]:
|
||||
"""Get all edges connected to a node."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_neighbors(self, node_id: str) -> List[NodeData]:
|
||||
"""Get all neighboring nodes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_connections(
|
||||
self, node_id: str
|
||||
) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
|
||||
"""Get all nodes connected to a given node with their relationships."""
|
||||
raise NotImplementedError
|
||||
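A brief sketch of what the new `record_graph_changes` decorator does when a decorated `add_edges` is called: each edge is mirrored into the `GraphRelationshipLedger` table, with `creator_function` derived from the method that invoked the adapter. The values below are illustrative; only the decorator and the ledger model come from this PR:

```python
# Illustrative edge passed to an @record_graph_changes-decorated add_edges([...]):
edge = (
    "9a7b1f3e-0000-4000-8000-000000000001",  # source node id (e.g. a DocumentChunk)
    "9a7b1f3e-0000-4000-8000-000000000002",  # target node id (e.g. an Entity)
    "contains",                               # relationship name
    {},                                       # edge properties
)

# The decorator walks the call stack to find the method that invoked add_edges
# (for example a pipeline task), then writes roughly one ledger row per edge:
#
#   GraphRelationshipLedger(
#       id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
#       source_node_id=UUID(edge[0]),
#       destination_node_id=UUID(edge[1]),
#       creator_function="<CallerClass>.<caller_method>.contains",
#   )
#
# delete_single_document() later stamps matching rows with deleted_at rather
# than removing them, so the ledger keeps a history of deleted relationships.
```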
|
|
|
|||
|
|
@ -14,7 +14,10 @@ from concurrent.futures import ThreadPoolExecutor
|
|||
import kuzu
|
||||
from kuzu.database import Database
|
||||
from kuzu import Connection
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
GraphDBInterface,
|
||||
record_graph_changes,
|
||||
)
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
|
||||
|
|
@ -43,7 +46,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
self.connection.execute("""
|
||||
CREATE NODE TABLE IF NOT EXISTS Node(
|
||||
id STRING PRIMARY KEY,
|
||||
text STRING,
|
||||
name STRING,
|
||||
type STRING,
|
||||
created_at TIMESTAMP,
|
||||
updated_at TIMESTAMP,
|
||||
|
|
@ -138,12 +141,16 @@ class KuzuAdapter(GraphDBInterface):
|
|||
query = """
|
||||
MATCH (from:Node), (to:Node)
|
||||
WHERE from.id = $from_id AND to.id = $to_id
|
||||
CREATE (from)-[r:EDGE {
|
||||
relationship_name: $relationship_name,
|
||||
created_at: timestamp($created_at),
|
||||
updated_at: timestamp($updated_at),
|
||||
properties: $properties
|
||||
MERGE (from)-[r:EDGE {
|
||||
relationship_name: $relationship_name
|
||||
}]->(to)
|
||||
ON CREATE SET
|
||||
r.created_at = timestamp($created_at),
|
||||
r.updated_at = timestamp($updated_at),
|
||||
r.properties = $properties
|
||||
ON MATCH SET
|
||||
r.updated_at = timestamp($updated_at),
|
||||
r.properties = $properties
|
||||
"""
|
||||
params = {
|
||||
"from_id": from_node,
|
||||
|
|
@ -171,7 +178,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
# Extract core fields with defaults if not present
|
||||
core_properties = {
|
||||
"id": str(properties.get("id", "")),
|
||||
"text": str(properties.get("text", "")),
|
||||
"name": str(properties.get("name", "")),
|
||||
"type": str(properties.get("type", "")),
|
||||
}
|
||||
|
||||
|
|
@ -181,35 +188,33 @@ class KuzuAdapter(GraphDBInterface):
|
|||
|
||||
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
|
||||
|
||||
# Check if node exists
|
||||
exists = await self.has_node(core_properties["id"])
|
||||
# Add timestamps for new node
|
||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
fields = []
|
||||
params = {}
|
||||
for key, value in core_properties.items():
|
||||
if value is not None:
|
||||
param_name = f"param_{key}"
|
||||
fields.append(f"{key}: ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
if not exists:
|
||||
# Add timestamps for new node
|
||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
fields = []
|
||||
params = {}
|
||||
for key, value in core_properties.items():
|
||||
if value is not None:
|
||||
param_name = f"param_{key}"
|
||||
fields.append(f"{key}: ${param_name}")
|
||||
params[param_name] = value
|
||||
# Add timestamp fields
|
||||
fields.extend(
|
||||
["created_at: timestamp($created_at)", "updated_at: timestamp($updated_at)"]
|
||||
)
|
||||
params.update({"created_at": now, "updated_at": now})
|
||||
|
||||
# Add timestamp fields
|
||||
fields.extend(
|
||||
["created_at: timestamp($created_at)", "updated_at: timestamp($updated_at)"]
|
||||
)
|
||||
params.update({"created_at": now, "updated_at": now})
|
||||
|
||||
create_query = f"""
|
||||
CREATE (n:Node {{{", ".join(fields)}}})
|
||||
"""
|
||||
await self.query(create_query, params)
|
||||
merge_query = f"""
|
||||
MERGE (n:Node {{id: $param_id}})
|
||||
ON CREATE SET n += {{{", ".join(fields)}}}
|
||||
"""
|
||||
await self.query(merge_query, params)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add node: {e}")
|
||||
raise
|
||||
|
||||
@record_graph_changes
|
||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||
"""Add multiple nodes to the graph in a batch operation."""
|
||||
if not nodes:
|
||||
|
|
@ -218,15 +223,14 @@ class KuzuAdapter(GraphDBInterface):
|
|||
try:
|
||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
|
||||
# Prepare all nodes data first
|
||||
# Prepare all nodes data
|
||||
node_params = []
|
||||
for node in nodes:
|
||||
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
|
||||
|
||||
# Extract core fields
|
||||
core_properties = {
|
||||
"id": str(properties.get("id", "")),
|
||||
"text": str(properties.get("text", "")),
|
||||
"name": str(properties.get("name", "")),
|
||||
"type": str(properties.get("type", "")),
|
||||
}
|
||||
|
||||
|
|
@ -244,36 +248,24 @@ class KuzuAdapter(GraphDBInterface):
|
|||
)
|
||||
|
||||
if node_params:
|
||||
# First check which nodes don't exist yet
|
||||
check_query = """
|
||||
# Batch merge nodes
|
||||
merge_query = """
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Node)
|
||||
WHERE n.id = node.id
|
||||
RETURN n.id
|
||||
MERGE (n:Node {id: node.id})
|
||||
ON CREATE SET
|
||||
n.name = node.name,
|
||||
n.type = node.type,
|
||||
n.properties = node.properties,
|
||||
n.created_at = timestamp(node.created_at),
|
||||
n.updated_at = timestamp(node.updated_at)
|
||||
ON MATCH SET
|
||||
n.name = node.name,
|
||||
n.type = node.type,
|
||||
n.properties = node.properties,
|
||||
n.updated_at = timestamp(node.updated_at)
|
||||
"""
|
||||
existing_nodes = await self.query(check_query, {"nodes": node_params})
|
||||
existing_ids = {str(row[0]) for row in existing_nodes}
|
||||
|
||||
# Filter out existing nodes
|
||||
new_nodes = [node for node in node_params if node["id"] not in existing_ids]
|
||||
|
||||
if new_nodes:
|
||||
# Batch create new nodes
|
||||
create_query = """
|
||||
UNWIND $nodes AS node
|
||||
CREATE (n:Node {
|
||||
id: node.id,
|
||||
text: node.text,
|
||||
type: node.type,
|
||||
properties: node.properties,
|
||||
created_at: timestamp(node.created_at),
|
||||
updated_at: timestamp(node.updated_at)
|
||||
})
|
||||
"""
|
||||
await self.query(create_query, {"nodes": new_nodes})
|
||||
logger.debug(f"Added {len(new_nodes)} new nodes in batch")
|
||||
else:
|
||||
logger.debug("No new nodes to add - all nodes already exist")
|
||||
await self.query(merge_query, {"nodes": node_params})
|
||||
logger.debug(f"Processed {len(node_params)} nodes in batch")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add nodes in batch: {e}")
|
||||
|
|
@ -296,7 +288,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
WHERE n.id = $id
|
||||
RETURN {
|
||||
id: n.id,
|
||||
text: n.text,
|
||||
name: n.name,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}
|
||||
|
|
@ -318,7 +310,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
WHERE n.id IN $node_ids
|
||||
RETURN {
|
||||
id: n.id,
|
||||
text: n.text,
|
||||
name: n.name,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}
|
||||
|
|
@ -401,6 +393,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
logger.error(f"Failed to add edge: {e}")
|
||||
raise
|
||||
|
||||
@record_graph_changes
|
||||
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||
"""Add multiple edges in a batch operation."""
|
||||
if not edges:
|
||||
|
|
@ -409,7 +402,6 @@ class KuzuAdapter(GraphDBInterface):
|
|||
try:
|
||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
|
||||
# Transform edges into the format needed for batch insertion
|
||||
edge_params = [
|
||||
{
|
||||
"from_id": from_node,
|
||||
|
|
@ -422,17 +414,20 @@ class KuzuAdapter(GraphDBInterface):
|
|||
for from_node, to_node, relationship_name, properties in edges
|
||||
]
|
||||
|
||||
# Batch create query
|
||||
query = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (from:Node), (to:Node)
|
||||
WHERE from.id = edge.from_id AND to.id = edge.to_id
|
||||
CREATE (from)-[r:EDGE {
|
||||
relationship_name: edge.relationship_name,
|
||||
created_at: timestamp(edge.created_at),
|
||||
updated_at: timestamp(edge.updated_at),
|
||||
properties: edge.properties
|
||||
MERGE (from)-[r:EDGE {
|
||||
relationship_name: edge.relationship_name
|
||||
}]->(to)
|
||||
ON CREATE SET
|
||||
r.created_at = timestamp(edge.created_at),
|
||||
r.updated_at = timestamp(edge.updated_at),
|
||||
r.properties = edge.properties
|
||||
ON MATCH SET
|
||||
r.updated_at = timestamp(edge.updated_at),
|
||||
r.properties = edge.properties
|
||||
"""
|
||||
|
||||
await self.query(query, {"edges": edge_params})
|
||||
|
|
@ -454,14 +449,14 @@ class KuzuAdapter(GraphDBInterface):
|
|||
WHERE n.id = $node_id
|
||||
RETURN {
|
||||
id: n.id,
|
||||
text: n.text,
|
||||
name: n.name,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
},
|
||||
r.relationship_name,
|
||||
{
|
||||
id: m.id,
|
||||
text: m.text,
|
||||
name: m.name,
|
||||
type: m.type,
|
||||
properties: m.properties
|
||||
}
|
||||
|
|
@ -481,6 +476,50 @@ class KuzuAdapter(GraphDBInterface):
|
|||
|
||||
# Neighbor Operations
|
||||
|
||||
async def get_neighbors(self, node_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all neighboring nodes."""
|
||||
return await self.get_neighbours(node_id)
|
||||
|
||||
async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a single node by ID."""
|
||||
query_str = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.id = $id
|
||||
RETURN {
|
||||
id: n.id,
|
||||
name: n.name,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}
|
||||
"""
|
||||
try:
|
||||
result = await self.query(query_str, {"id": node_id})
|
||||
if result and result[0]:
|
||||
return self._parse_node(result[0][0])
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get node {node_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
|
||||
"""Get multiple nodes by their IDs."""
|
||||
query_str = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.id IN $node_ids
|
||||
RETURN {
|
||||
id: n.id,
|
||||
name: n.name,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}
|
||||
"""
|
||||
try:
|
||||
results = await self.query(query_str, {"node_ids": node_ids})
|
||||
return [self._parse_node(row[0]) for row in results if row[0]]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get nodes: {e}")
|
||||
return []
|
||||
|
||||
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all neighbouring nodes."""
|
||||
query_str = """
|
||||
|
|
@ -554,7 +593,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
WHERE n.id = $node_id
|
||||
RETURN {
|
||||
id: n.id,
|
||||
text: n.text,
|
||||
name: n.name,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
},
|
||||
|
|
@ -564,7 +603,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
},
|
||||
{
|
||||
id: m.id,
|
||||
text: m.text,
|
||||
name: m.name,
|
||||
type: m.type,
|
||||
properties: m.properties
|
||||
}
|
||||
|
|
@ -625,7 +664,7 @@ class KuzuAdapter(GraphDBInterface):
|
|||
nodes_query = """
|
||||
MATCH (n:Node)
|
||||
RETURN n.id, {
|
||||
text: n.text,
|
||||
name: n.name,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}
|
||||
|
|
@ -716,77 +755,36 @@ class KuzuAdapter(GraphDBInterface):
|
|||
return ([n[0] for n in nodes], [e[0] for e in edges])
|
||||
|
||||
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
|
||||
"""For the definition of these metrics, please refer to
|
||||
https://docs.cognee.ai/core_concepts/graph_generation/descriptive_metrics"""
|
||||
|
||||
try:
|
||||
# Basic metrics
|
||||
node_count = await self.query("MATCH (n:Node) RETURN COUNT(n)")
|
||||
edge_count = await self.query("MATCH ()-[r:EDGE]->() RETURN COUNT(r)")
|
||||
num_nodes = node_count[0][0] if node_count else 0
|
||||
num_edges = edge_count[0][0] if edge_count else 0
|
||||
# Get basic graph data
|
||||
nodes, edges = await self.get_model_independent_graph_data()
|
||||
num_nodes = len(nodes[0]["nodes"]) if nodes else 0
|
||||
num_edges = len(edges[0]["elements"]) if edges else 0
|
||||
|
||||
# Calculate mandatory metrics
|
||||
mandatory_metrics = {
|
||||
"num_nodes": num_nodes,
|
||||
"num_edges": num_edges,
|
||||
"mean_degree": (2 * num_edges) / num_nodes if num_nodes > 0 else 0,
|
||||
"edge_density": (num_edges) / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0,
|
||||
"mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
|
||||
"edge_density": num_edges / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0,
|
||||
"num_connected_components": await self._get_num_connected_components(),
|
||||
"sizes_of_connected_components": await self._get_size_of_connected_components(),
|
||||
}
|
||||
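As a quick sanity check on the `mean_degree` and `edge_density` formulas above:

```python
# Worked example with a tiny graph: 4 nodes, 3 edges.
num_nodes, num_edges = 4, 3

mean_degree = (2 * num_edges) / num_nodes                  # 1.5: each edge touches two endpoints
edge_density = num_edges / (num_nodes * (num_nodes - 1))   # 0.25: fraction of possible directed edges

assert mean_degree == 1.5
assert edge_density == 0.25
```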
|
||||
# Calculate connected components
|
||||
components_query = """
|
||||
MATCH (n:Node)
|
||||
WITH n.id AS node_id
|
||||
MATCH path = (n)-[:EDGE*0..]-()
|
||||
WITH COLLECT(DISTINCT node_id) AS component
|
||||
RETURN COLLECT(component) AS components
|
||||
"""
|
||||
components_result = await self.query(components_query)
|
||||
component_sizes = (
|
||||
[len(comp) for comp in components_result[0][0]] if components_result else []
|
||||
)
|
||||
|
||||
mandatory_metrics.update(
|
||||
{
|
||||
"num_connected_components": len(component_sizes),
|
||||
"sizes_of_connected_components": component_sizes,
|
||||
}
|
||||
)
|
||||
|
||||
if include_optional:
|
||||
# Self-loops
|
||||
self_loops_query = """
|
||||
MATCH (n:Node)-[r:EDGE]->(n)
|
||||
RETURN COUNT(r)
|
||||
"""
|
||||
self_loops = await self.query(self_loops_query)
|
||||
num_selfloops = self_loops[0][0] if self_loops else 0
|
||||
|
||||
# Shortest paths (simplified for Kuzu)
|
||||
paths_query = """
|
||||
MATCH (n:Node), (m:Node)
|
||||
WHERE n.id < m.id
|
||||
MATCH path = (n)-[:EDGE*]-(m)
|
||||
RETURN MIN(LENGTH(path)) AS length
|
||||
"""
|
||||
paths = await self.query(paths_query)
|
||||
path_lengths = [p[0] for p in paths if p[0] is not None]
|
||||
|
||||
# Local clustering coefficient
|
||||
clustering_query = """
|
||||
MATCH (n:Node)-[:EDGE]-(neighbor)
|
||||
WITH n, COUNT(DISTINCT neighbor) as degree
|
||||
MATCH (n)-[:EDGE]-(n1)-[:EDGE]-(n2)-[:EDGE]-(n)
|
||||
WHERE n1 <> n2
|
||||
RETURN AVG(CASE WHEN degree <= 1 THEN 0 ELSE COUNT(DISTINCT n2) / (degree * (degree-1)) END)
|
||||
"""
|
||||
clustering = await self.query(clustering_query)
|
||||
|
||||
# Calculate optional metrics
|
||||
shortest_path_lengths = await self._get_shortest_path_lengths()
|
||||
optional_metrics = {
|
||||
"num_selfloops": num_selfloops,
|
||||
"diameter": max(path_lengths) if path_lengths else -1,
|
||||
"avg_shortest_path_length": sum(path_lengths) / len(path_lengths)
|
||||
if path_lengths
|
||||
"num_selfloops": await self._count_self_loops(),
|
||||
"diameter": max(shortest_path_lengths) if shortest_path_lengths else -1,
|
||||
"avg_shortest_path_length": sum(shortest_path_lengths)
|
||||
/ len(shortest_path_lengths)
|
||||
if shortest_path_lengths
|
||||
else -1,
|
||||
"avg_clustering": clustering[0][0] if clustering and clustering[0][0] else -1,
|
||||
"avg_clustering": await self._get_avg_clustering(),
|
||||
}
|
||||
else:
|
||||
optional_metrics = {
|
||||
|
|
@ -813,6 +811,65 @@ class KuzuAdapter(GraphDBInterface):
|
|||
"avg_clustering": -1,
|
||||
}
|
||||
|
||||
async def _get_num_connected_components(self) -> int:
|
||||
"""Get the number of connected components in the graph."""
|
||||
query = """
|
||||
MATCH (n:Node)
|
||||
WITH n.id AS node_id
|
||||
MATCH path = (n)-[:EDGE*1..3]-(m)
|
||||
WITH node_id, COLLECT(DISTINCT m.id) AS connected_nodes
|
||||
WITH COLLECT(DISTINCT connected_nodes + [node_id]) AS components
|
||||
RETURN SIZE(components) AS num_components
|
||||
"""
|
||||
result = await self.query(query)
|
||||
return result[0][0] if result else 0
|
||||
|
||||
async def _get_size_of_connected_components(self) -> List[int]:
|
||||
"""Get the sizes of all connected components in the graph."""
|
||||
query = """
|
||||
MATCH (n:Node)
|
||||
WITH n.id AS node_id
|
||||
MATCH path = (n)-[:EDGE*1..3]-(m)
|
||||
WITH node_id, COLLECT(DISTINCT m.id) AS connected_nodes
|
||||
WITH COLLECT(DISTINCT connected_nodes + [node_id]) AS components
|
||||
UNWIND components AS component
|
||||
RETURN SIZE(component) AS component_size
|
||||
"""
|
||||
result = await self.query(query)
|
||||
return [row[0] for row in result] if result else []
|
||||
|
||||
async def _get_shortest_path_lengths(self) -> List[int]:
|
||||
"""Get the lengths of shortest paths between all pairs of nodes."""
|
||||
query = """
|
||||
MATCH (n:Node), (m:Node)
|
||||
WHERE n.id < m.id
|
||||
MATCH path = (n)-[:EDGE*]-(m)
|
||||
RETURN MIN(LENGTH(path)) AS length
|
||||
"""
|
||||
result = await self.query(query)
|
||||
return [row[0] for row in result if row[0] is not None] if result else []
|
||||
|
||||
async def _count_self_loops(self) -> int:
|
||||
"""Count the number of self-loops in the graph."""
|
||||
query = """
|
||||
MATCH (n:Node)-[r:EDGE]->(n)
|
||||
RETURN COUNT(r) AS count
|
||||
"""
|
||||
result = await self.query(query)
|
||||
return result[0][0] if result else 0
|
||||
|
||||
async def _get_avg_clustering(self) -> float:
|
||||
"""Calculate the average clustering coefficient of the graph."""
|
||||
query = """
|
||||
MATCH (n:Node)-[:EDGE]-(neighbor)
|
||||
WITH n, COUNT(DISTINCT neighbor) as degree
|
||||
MATCH (n)-[:EDGE]-(n1)-[:EDGE]-(n2)-[:EDGE]-(n)
|
||||
WHERE n1 <> n2
|
||||
RETURN AVG(CASE WHEN degree <= 1 THEN 0 ELSE COUNT(DISTINCT n2) / (degree * (degree-1)) END) AS avg_clustering
|
||||
"""
|
||||
result = await self.query(query)
|
||||
return result[0][0] if result and result[0][0] is not None else -1
|
||||
|
||||
async def get_disconnected_nodes(self) -> List[str]:
|
||||
"""Get nodes that are not connected to any other node."""
|
||||
query_str = """
|
||||
|
|
@ -847,10 +904,8 @@ class KuzuAdapter(GraphDBInterface):
|
|||
async def delete_graph(self) -> None:
|
||||
"""Delete all data from the graph while preserving the database structure."""
|
||||
try:
|
||||
# Delete relationships from the fixed table EDGE
|
||||
await self.query("MATCH ()-[r:EDGE]->() DELETE r")
|
||||
# Then delete nodes
|
||||
await self.query("MATCH (n:Node) DELETE n")
|
||||
# Use DETACH DELETE to remove both nodes and their relationships in one operation
|
||||
await self.query("MATCH (n:Node) DETACH DELETE n")
|
||||
logger.info("Cleared all data from graph while preserving structure")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete graph data: {e}")
|
||||
|
|
@ -922,3 +977,73 @@ class KuzuAdapter(GraphDBInterface):
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to import graph from file: {e}")
|
||||
raise
|
||||
|
||||
async def get_document_subgraph(self, content_hash: str):
|
||||
"""Get all nodes that should be deleted when removing a document."""
|
||||
query = """
|
||||
MATCH (doc:Node)
|
||||
WHERE (doc.type = 'TextDocument' OR doc.type = 'PdfDocument') AND doc.name = $content_hash
|
||||
|
||||
OPTIONAL MATCH (doc)<-[e1:EDGE]-(chunk:Node)
|
||||
WHERE e1.relationship_name = 'is_part_of' AND chunk.type = 'DocumentChunk'
|
||||
|
||||
OPTIONAL MATCH (chunk)-[e2:EDGE]->(entity:Node)
|
||||
WHERE e2.relationship_name = 'contains' AND entity.type = 'Entity'
|
||||
AND NOT EXISTS {
|
||||
MATCH (entity)<-[e3:EDGE]-(otherChunk:Node)-[e4:EDGE]->(otherDoc:Node)
|
||||
WHERE e3.relationship_name = 'contains'
|
||||
AND e4.relationship_name = 'is_part_of'
|
||||
AND (otherDoc.type = 'TextDocument' OR otherDoc.type = 'PdfDocument')
|
||||
AND otherDoc.id <> doc.id
|
||||
}
|
||||
|
||||
OPTIONAL MATCH (chunk)<-[e5:EDGE]-(made_node:Node)
|
||||
WHERE e5.relationship_name = 'made_from' AND made_node.type = 'TextSummary'
|
||||
|
||||
OPTIONAL MATCH (entity)-[e6:EDGE]->(type:Node)
|
||||
WHERE e6.relationship_name = 'is_a' AND type.type = 'EntityType'
|
||||
AND NOT EXISTS {
|
||||
MATCH (type)<-[e7:EDGE]-(otherEntity:Node)-[e8:EDGE]-(otherChunk:Node)-[e9:EDGE]-(otherDoc:Node)
|
||||
WHERE e7.relationship_name = 'is_a'
|
||||
AND e8.relationship_name = 'contains'
|
||||
AND e9.relationship_name = 'is_part_of'
|
||||
AND otherEntity.type = 'Entity'
|
||||
AND otherChunk.type = 'DocumentChunk'
|
||||
AND (otherDoc.type = 'TextDocument' OR otherDoc.type = 'PdfDocument')
|
||||
AND otherDoc.id <> doc.id
|
||||
}
|
||||
|
||||
RETURN
|
||||
COLLECT(DISTINCT doc) as document,
|
||||
COLLECT(DISTINCT chunk) as chunks,
|
||||
COLLECT(DISTINCT entity) as orphan_entities,
|
||||
COLLECT(DISTINCT made_node) as made_from_nodes,
|
||||
COLLECT(DISTINCT type) as orphan_types
|
||||
"""
|
||||
result = await self.query(query, {"content_hash": f"text_{content_hash}"})
|
||||
if not result or not result[0]:
|
||||
return None
|
||||
|
||||
# Convert tuple to dictionary
|
||||
return {
|
||||
"document": result[0][0],
|
||||
"chunks": result[0][1],
|
||||
"orphan_entities": result[0][2],
|
||||
"made_from_nodes": result[0][3],
|
||||
"orphan_types": result[0][4],
|
||||
}
|
||||
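The dictionary returned here is what `delete_document_subgraph` in `cognee/api/v1/delete/delete.py` iterates over. A sketch of the expected shape, with illustrative node contents:

```python
# Shape consumed by delete_document_subgraph(): every value is a list of node
# dicts with at least an "id" key, and an empty list when nothing matched.
subgraph = {
    "document": [{"id": "9f1c...", "type": "TextDocument", "name": "text_<content_hash>"}],
    "chunks": [{"id": "1b7a...", "type": "DocumentChunk"}],
    "orphan_entities": [],   # entities referenced only by this document
    "made_from_nodes": [],   # TextSummary nodes made_from the chunks
    "orphan_types": [],      # EntityType nodes used only by orphan entities
}

for key in ["orphan_entities", "orphan_types", "made_from_nodes", "chunks", "document"]:
    for node in subgraph[key]:
        ...  # graph_db.delete_node(node["id"]) is awaited for each, in this order
```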
|
||||
async def get_degree_one_nodes(self, node_type: str):
|
||||
"""Get all nodes that have only one connection."""
|
||||
if not node_type or node_type not in ["Entity", "EntityType"]:
|
||||
raise ValueError("node_type must be either 'Entity' or 'EntityType'")
|
||||
|
||||
query = f"""
|
||||
MATCH (n:Node)
|
||||
WHERE n.type = '{node_type}'
|
||||
WITH n, COUNT {{ MATCH (n)--() }} as degree
|
||||
WHERE degree = 1
|
||||
RETURN n
|
||||
"""
|
||||
result = await self.query(query)
|
||||
return [record[0] for record in result] if result else []
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#
|
||||
|
||||
"""Neo4j Adapter for Graph Database"""
|
||||
|
||||
import json
|
||||
|
|
@ -11,7 +13,10 @@ from neo4j import AsyncSession
|
|||
from neo4j import AsyncGraphDatabase
|
||||
from neo4j.exceptions import Neo4jError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
GraphDBInterface,
|
||||
record_graph_changes,
|
||||
)
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from .neo4j_metrics_utils import (
|
||||
get_avg_clustering,
|
||||
|
|
@ -89,6 +94,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return await self.query(query, params)
|
||||
|
||||
@record_graph_changes
|
||||
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
||||
query = """
|
||||
UNWIND $nodes AS node
|
||||
|
|
@ -130,9 +136,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
return [result["node"] for result in results]
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
node_id = id.replace(":", "_")
|
||||
|
||||
query = f"MATCH (node:`{node_id}` {{id: $node_id}}) DETACH DELETE n"
|
||||
query = "MATCH (node {id: $node_id}) DETACH DELETE node"
|
||||
params = {"node_id": node_id}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
|
@ -218,6 +222,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return await self.query(query, params)
|
||||
|
||||
@record_graph_changes
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
query = """
|
||||
UNWIND $edges AS edge
|
||||
|
|
@ -373,12 +378,28 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return [result["successor"] for result in results]
|
||||
|
||||
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
|
||||
predecessors, successors = await asyncio.gather(
|
||||
self.get_predecessors(node_id), self.get_successors(node_id)
|
||||
)
|
||||
async def get_neighbors(self, node_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all neighboring nodes."""
|
||||
return await self.get_neighbours(node_id)
|
||||
|
||||
return predecessors + successors
|
||||
async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a single node by ID."""
|
||||
query = """
|
||||
MATCH (node {id: $node_id})
|
||||
RETURN node
|
||||
"""
|
||||
results = await self.query(query, {"node_id": node_id})
|
||||
return results[0]["node"] if results else None
|
||||
|
||||
async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
|
||||
"""Get multiple nodes by their IDs."""
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node {id: id})
|
||||
RETURN node
|
||||
"""
|
||||
results = await self.query(query, {"node_ids": node_ids})
|
||||
return [result["node"] for result in results]
|
||||
|
||||
async def get_connections(self, node_id: UUID) -> list:
|
||||
predecessors_query = """
|
||||
|
|
@ -651,3 +672,46 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
}
|
||||
|
||||
return mandatory_metrics | optional_metrics
|
||||
|
||||
async def get_document_subgraph(self, content_hash: str):
|
||||
query = """
|
||||
MATCH (doc)
|
||||
WHERE (doc:TextDocument OR doc:PdfDocument)
|
||||
AND doc.name = 'text_' + $content_hash
|
||||
|
||||
OPTIONAL MATCH (doc)<-[:is_part_of]-(chunk:DocumentChunk)
|
||||
OPTIONAL MATCH (chunk)-[:contains]->(entity:Entity)
|
||||
WHERE NOT EXISTS {
|
||||
MATCH (entity)<-[:contains]-(otherChunk:DocumentChunk)-[:is_part_of]->(otherDoc)
|
||||
WHERE (otherDoc:TextDocument OR otherDoc:PdfDocument)
|
||||
AND otherDoc.id <> doc.id
|
||||
}
|
||||
OPTIONAL MATCH (chunk)<-[:made_from]-(made_node:TextSummary)
|
||||
OPTIONAL MATCH (entity)-[:is_a]->(type:EntityType)
|
||||
WHERE NOT EXISTS {
|
||||
MATCH (type)<-[:is_a]-(otherEntity:Entity)<-[:contains]-(otherChunk:DocumentChunk)-[:is_part_of]->(otherDoc)
|
||||
WHERE (otherDoc:TextDocument OR otherDoc:PdfDocument)
|
||||
AND otherDoc.id <> doc.id
|
||||
}
|
||||
|
||||
RETURN
|
||||
collect(DISTINCT doc) as document,
|
||||
collect(DISTINCT chunk) as chunks,
|
||||
collect(DISTINCT entity) as orphan_entities,
|
||||
collect(DISTINCT made_node) as made_from_nodes,
|
||||
collect(DISTINCT type) as orphan_types
|
||||
"""
|
||||
result = await self.query(query, {"content_hash": content_hash})
|
||||
return result[0] if result else None
|
||||
|
||||
async def get_degree_one_nodes(self, node_type: str):
|
||||
if not node_type or node_type not in ["Entity", "EntityType"]:
|
||||
raise ValueError("node_type must be either 'Entity' or 'EntityType'")
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
WHERE COUNT {{ MATCH (n)--() }} = 1
|
||||
RETURN n
|
||||
"""
|
||||
result = await self.query(query)
|
||||
return [record["n"] for record in result] if result else []
|
||||
|
|
|
|||
|
|
@ -10,7 +10,10 @@ from uuid import UUID
|
|||
import aiofiles
|
||||
import aiofiles.os as aiofiles_os
|
||||
import networkx as nx
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
GraphDBInterface,
|
||||
record_graph_changes,
|
||||
)
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
|
|
@ -42,20 +45,14 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
async def has_node(self, node_id: str) -> bool:
|
||||
return self.graph.has_node(node_id)
|
||||
|
||||
async def add_node(
|
||||
self,
|
||||
node: DataPoint,
|
||||
) -> None:
|
||||
async def add_node(self, node: DataPoint) -> None:
|
||||
self.graph.add_node(node.id, **node.model_dump())
|
||||
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def add_nodes(
|
||||
self,
|
||||
nodes: list[DataPoint],
|
||||
) -> None:
|
||||
@record_graph_changes
|
||||
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
||||
nodes = [(node.id, node.model_dump()) for node in nodes]
|
||||
|
||||
self.graph.add_nodes_from(nodes)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
|
|
@ -74,6 +71,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
return result
|
||||
|
||||
@record_graph_changes
|
||||
async def add_edge(
|
||||
self,
|
||||
from_node: str,
|
||||
|
|
@ -91,38 +89,76 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def add_edges(
|
||||
self,
|
||||
edges: tuple[str, str, str, dict],
|
||||
) -> None:
|
||||
edges = [
|
||||
(
|
||||
edge[0],
|
||||
edge[1],
|
||||
edge[2],
|
||||
{
|
||||
**(edge[3] if len(edge) == 4 else {}),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
for edge in edges
|
||||
]
|
||||
@record_graph_changes
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict]]) -> None:
|
||||
if not edges:
|
||||
logger.debug("No edges to add")
|
||||
return
|
||||
|
||||
self.graph.add_edges_from(edges)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
try:
|
||||
# Validate edge format and convert UUIDs to strings
|
||||
processed_edges = []
|
||||
for edge in edges:
|
||||
if len(edge) < 3 or len(edge) > 4:
|
||||
raise ValueError(
|
||||
f"Invalid edge format: {edge}. Expected (from_node, to_node, relationship_name[, properties])"
|
||||
)
|
||||
|
||||
# Convert UUIDs to strings if needed
|
||||
from_node = str(edge[0]) if isinstance(edge[0], UUID) else edge[0]
|
||||
to_node = str(edge[1]) if isinstance(edge[1], UUID) else edge[1]
|
||||
relationship_name = edge[2]
|
||||
|
||||
if not all(isinstance(x, str) for x in [from_node, to_node, relationship_name]):
|
||||
raise ValueError(
|
||||
f"First three elements of edge must be strings or UUIDs: {edge}"
|
||||
)
|
||||
|
||||
# Process edge with updated_at timestamp
|
||||
processed_edge = (
|
||||
from_node,
|
||||
to_node,
|
||||
relationship_name,
|
||||
{
|
||||
**(edge[3] if len(edge) == 4 else {}),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
processed_edges.append(processed_edge)
|
||||
|
||||
# Add edges to graph
|
||||
self.graph.add_edges_from(processed_edges)
|
||||
logger.debug(f"Added {len(processed_edges)} edges to graph")
|
||||
|
||||
# Save changes
|
||||
await self.save_graph_to_file(self.filename)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add edges: {e}")
|
||||
raise
|
||||
|
||||
async def get_edges(self, node_id: str):
|
||||
return list(self.graph.in_edges(node_id, data=True)) + list(
|
||||
self.graph.out_edges(node_id, data=True)
|
||||
)
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""Asynchronously delete a node from the graph if it exists."""
|
||||
if self.graph.has_node(node_id):
|
||||
self.graph.remove_node(node_id)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
async def delete_node(self, node_id: UUID) -> None:
|
||||
"""Asynchronously delete a node and all its relationships from the graph if it exists."""
|
||||
|
||||
async def delete_nodes(self, node_ids: List[str]) -> None:
|
||||
if self.graph.has_node(node_id):
|
||||
# First remove all edges connected to the node
|
||||
for edge in list(self.graph.edges(node_id, data=True)):
|
||||
source, target, data = edge
|
||||
self.graph.remove_edge(source, target, key=data.get("relationship_name"))
|
||||
|
||||
# Then remove the node itself
|
||||
self.graph.remove_node(node_id)
|
||||
|
||||
# Save the updated graph state
|
||||
await self.save_graph_to_file(self.filename)
|
||||
else:
|
||||
logger.error(f"Node {node_id} not found in graph")
|
||||
|
||||
async def delete_nodes(self, node_ids: List[UUID]) -> None:
|
||||
self.graph.remove_nodes_from(node_ids)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
|
|
@@ -179,7 +215,7 @@ class NetworkXAdapter(GraphDBInterface):

        return nodes

    async def get_neighbours(self, node_id: str) -> list:
    async def get_neighbors(self, node_id: str) -> list:
        if not self.graph.has_node(node_id):
            return []

@@ -188,9 +224,9 @@ class NetworkXAdapter(GraphDBInterface):
            self.get_successors(node_id),
        )

        neighbours = predecessors + successors
        neighbors = predecessors + successors

        return neighbours
        return neighbors

    async def get_connections(self, node_id: UUID) -> list:
        if not self.graph.has_node(node_id):
@@ -208,17 +244,22 @@ class NetworkXAdapter(GraphDBInterface):

        connections = []

        for neighbor in predecessors:
            if "id" in neighbor:
                edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
                for edge_properties in edge_data.values():
                    connections.append((neighbor, edge_properties, node))
        # Handle None values for predecessors and successors
        if predecessors is not None:
            for neighbor in predecessors:
                if "id" in neighbor:
                    edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
                    if edge_data is not None:
                        for edge_properties in edge_data.values():
                            connections.append((neighbor, edge_properties, node))

        for neighbor in successors:
            if "id" in neighbor:
                edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
                for edge_properties in edge_data.values():
                    connections.append((node, edge_properties, neighbor))
        if successors is not None:
            for neighbor in successors:
                if "id" in neighbor:
                    edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
                    if edge_data is not None:
                        for edge_properties in edge_data.values():
                            connections.append((node, edge_properties, neighbor))

        return connections
@@ -247,8 +288,9 @@ class NetworkXAdapter(GraphDBInterface):
    async def create_empty_graph(self, file_path: str) -> None:
        self.graph = nx.MultiDiGraph()

        # Only create directory if file_path contains a directory
        file_dir = os.path.dirname(file_path)
        if not os.path.exists(file_dir):
        if file_dir and not os.path.exists(file_dir):
            os.makedirs(file_dir, exist_ok=True)

        await self.save_graph_to_file(file_path)
@@ -266,9 +308,6 @@ class NetworkXAdapter(GraphDBInterface):

    async def load_graph_from_file(self, file_path: str = None):
        """Asynchronously load the graph from a file in JSON format."""
        if file_path == self.filename:
            return

        if not file_path:
            file_path = self.filename
        try:
@@ -460,3 +499,138 @@ class NetworkXAdapter(GraphDBInterface):
        }

        return mandatory_metrics | optional_metrics

    async def get_document_subgraph(self, content_hash: str):
        """Get all nodes that should be deleted when removing a document."""
        # Ensure graph is loaded
        if self.graph is None:
            await self.load_graph_from_file()

        # Find the document node by looking for content_hash in the name field
        document = None
        document_node_id = None
        for node_id, attrs in self.graph.nodes(data=True):
            if (
                attrs.get("type") in ["TextDocument", "PdfDocument"]
                and attrs.get("name") == f"text_{content_hash}"
            ):
                document = {"id": str(node_id), **attrs}  # Convert UUID to string for consistency
                document_node_id = node_id  # Keep the original UUID
                break

        if not document:
            return None

        # Find chunks connected via is_part_of (chunks point TO document)
        chunks = []
        for source, target, edge_data in self.graph.in_edges(document_node_id, data=True):
            if edge_data.get("relationship_name") == "is_part_of":
                chunks.append({"id": source, **self.graph.nodes[source]})  # Keep as UUID object

        # Find entities connected to chunks (chunks point TO entities via contains)
        entities = []
        for chunk in chunks:
            chunk_id = chunk["id"]  # Already a UUID object
            for source, target, edge_data in self.graph.out_edges(chunk_id, data=True):
                if edge_data.get("relationship_name") == "contains":
                    entities.append(
                        {"id": target, **self.graph.nodes[target]}
                    )  # Keep as UUID object

        # Find orphaned entities (entities only connected to chunks we're deleting)
        orphan_entities = []
        for entity in entities:
            entity_id = entity["id"]  # Already a UUID object
            # Get all chunks that contain this entity
            containing_chunks = []
            for source, target, edge_data in self.graph.in_edges(entity_id, data=True):
                if edge_data.get("relationship_name") == "contains":
                    containing_chunks.append(source)  # Keep as UUID object

            # Check if all containing chunks are in our chunks list
            chunk_ids = [chunk["id"] for chunk in chunks]
            if containing_chunks and all(c in chunk_ids for c in containing_chunks):
                orphan_entities.append(entity)

        # Find orphaned entity types
        orphan_types = []
        seen_types = set()  # Track seen types to avoid duplicates
        for entity in orphan_entities:
            entity_id = entity["id"]  # Already a UUID object
            for _, target, edge_data in self.graph.out_edges(entity_id, data=True):
                if edge_data.get("relationship_name") in ["is_a", "instance_of"]:
                    # Check if this type is only connected to entities we're deleting
                    type_node = self.graph.nodes[target]
                    if type_node.get("type") == "EntityType" and target not in seen_types:
                        is_orphaned = True
                        # Get all incoming edges to this type node
                        for source, _, edge_data in self.graph.in_edges(target, data=True):
                            if edge_data.get("relationship_name") in ["is_a", "instance_of"]:
                                # Check if the source entity is not in our orphan_entities list
                                if source not in [e["id"] for e in orphan_entities]:
                                    is_orphaned = False
                                    break
                        if is_orphaned:
                            orphan_types.append({"id": target, **type_node})  # Keep as UUID object
                            seen_types.add(target)  # Mark as seen

        # Find nodes connected via made_from (chunks point TO summaries)
        made_from_nodes = []
        for chunk in chunks:
            chunk_id = chunk["id"]  # Already a UUID object
            for source, target, edge_data in self.graph.in_edges(chunk_id, data=True):
                if edge_data.get("relationship_name") == "made_from":
                    made_from_nodes.append(
                        {"id": source, **self.graph.nodes[source]}
                    )  # Keep as UUID object

        # Return UUIDs directly without string conversion
        return {
            "document": [{"id": document["id"], **{k: v for k, v in document.items() if k != "id"}}]
            if document
            else [],
            "chunks": [
                {"id": chunk["id"], **{k: v for k, v in chunk.items() if k != "id"}}
                for chunk in chunks
            ],
            "orphan_entities": [
                {"id": entity["id"], **{k: v for k, v in entity.items() if k != "id"}}
                for entity in orphan_entities
            ],
            "made_from_nodes": [
                {"id": node["id"], **{k: v for k, v in node.items() if k != "id"}}
                for node in made_from_nodes
            ],
            "orphan_types": [
                {"id": type_node["id"], **{k: v for k, v in type_node.items() if k != "id"}}
                for type_node in orphan_types
            ],
        }

    async def get_degree_one_nodes(self, node_type: str):
        """Get all nodes that have only one connection."""
        if not node_type or node_type not in ["Entity", "EntityType"]:
            raise ValueError("node_type must be either 'Entity' or 'EntityType'")

        nodes = []
        for node_id, node_data in self.graph.nodes(data=True):
            if node_data.get("type") == node_type:
                # Count both incoming and outgoing edges
                degree = self.graph.degree(node_id)
                if degree == 1:
                    nodes.append(node_data)
        return nodes

    async def get_node(self, node_id: str) -> dict:
        if self.graph.has_node(node_id):
            return self.graph.nodes[node_id]
        return None

    async def get_nodes(self, node_ids: List[str] = None) -> List[dict]:
        if node_ids is None:
            return [{"id": node_id, **data} for node_id, data in self.graph.nodes(data=True)]
        return [
            {"id": node_id, **self.graph.nodes[node_id]}
            for node_id in node_ids
            if self.graph.has_node(node_id)
        ]

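Reviewer note: a hedged sketch, not part of this diff, of how `get_document_subgraph` could drive the delete-by-document flow on the graph side; the helper name and the way ids are collected are assumptions for illustration.

```python
# Illustrative only — wires get_document_subgraph into delete_nodes.
async def remove_document_from_graph(adapter, content_hash: str):
    subgraph = await adapter.get_document_subgraph(content_hash)
    if subgraph is None:
        return  # no matching TextDocument/PdfDocument node

    groups = ("document", "chunks", "orphan_entities", "made_from_nodes", "orphan_types")
    node_ids = [item["id"] for group in groups for item in subgraph[group]]
    await adapter.delete_nodes(node_ids)
```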
@@ -16,6 +16,8 @@ from ..models.ScoredResult import ScoredResult
from ..utils import normalize_distances
from ..vector_db_interface import VectorDBInterface

from tenacity import retry, stop_after_attempt, wait_exponential


class IndexSchema(DataPoint):
    id: str
@@ -230,14 +232,30 @@ class LanceDBAdapter(VectorDBInterface):
            ]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        connection = await self.get_connection()
        collection = await connection.open_table(collection_name)
        if len(data_point_ids) == 1:
            results = await collection.delete(f"id = '{data_point_ids[0]}'")
    def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
        async def _delete_data_points():
            connection = await self.get_connection()
            collection = await connection.open_table(collection_name)

            # Delete one at a time to avoid commit conflicts
            for data_point_id in data_point_ids:
                await collection.delete(f"id = '{data_point_id}'")

            return True

        # Check if we're in an event loop
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None

        if loop and loop.is_running():
            # If we're in a running event loop, create a new task
            return loop.create_task(_delete_data_points())
        else:
            results = await collection.delete(f"id IN {tuple(data_point_ids)}")
            return results
            # If we're not in an event loop, run it synchronously
            return asyncio.run(_delete_data_points())

    async def create_vector_index(self, index_name: str, index_property_name: str):
        await self.create_collection(
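Reviewer note: `delete_data_points` is now a synchronous wrapper that schedules a retried task when a loop is already running and falls back to `asyncio.run` otherwise, so async callers should await the returned task. A hedged sketch; the adapter variable, collection name, and ids are placeholders.

```python
# Illustrative only — placeholder adapter, collection name, and ids.
import asyncio

result = lancedb_adapter.delete_data_points("DocumentChunk_text", ["id-1", "id-2"])
if isinstance(result, asyncio.Task):
    await result  # called from inside a running event loop
```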
@@ -1,3 +1,5 @@
# PROPOSED TO BE DEPRECATED

from typing import Type, Optional, get_args, get_origin
from pydantic import BaseModel
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
41	cognee/modules/data/models/graph_relationship_ledger.py	Normal file
@@ -0,0 +1,41 @@
from datetime import datetime, timezone
from uuid import uuid5, NAMESPACE_DNS
from sqlalchemy import UUID, Column, DateTime, String, Index
from sqlalchemy.orm import relationship

from cognee.infrastructure.databases.relational import Base


class GraphRelationshipLedger(Base):
    __tablename__ = "graph_relationship_ledger"

    id = Column(
        UUID,
        primary_key=True,
        default=lambda: uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
    )
    source_node_id = Column(UUID, nullable=False)
    destination_node_id = Column(UUID, nullable=False)
    creator_function = Column(String, nullable=False)
    node_label = Column(String, nullable=True)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    deleted_at = Column(DateTime(timezone=True), nullable=True)
    user_id = Column(UUID, nullable=True)

    # Create indexes
    __table_args__ = (
        Index("idx_graph_relationship_id", "id"),
        Index("idx_graph_relationship_ledger_source_node_id", "source_node_id"),
        Index("idx_graph_relationship_ledger_destination_node_id", "destination_node_id"),
    )

    def to_json(self) -> dict:
        return {
            "id": str(self.id),
            "source_node_id": str(self.source_node_id),
            "destination_node_id": str(self.destination_node_id),
            "creator_function": self.creator_function,
            "created_at": self.created_at.isoformat(),
            "deleted_at": self.deleted_at.isoformat() if self.deleted_at else None,
            "user_id": str(self.user_id),
        }
54	cognee/modules/graph/relationship_manager.py	Normal file
@@ -0,0 +1,54 @@
from datetime import datetime, timezone
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.modules.data.models import GraphRelationshipLedger
from cognee.modules.users.models import User


async def create_relationship(
    session: AsyncSession,
    source_node_id: UUID,
    destination_node_id: UUID,
    creator_function: str,
    user: User,
) -> None:
    """Create a relationship between two nodes in the graph.

    Args:
        session: Database session
        source_node_id: ID of the source node
        destination_node_id: ID of the destination node
        creator_function: Name of the function creating the relationship
        user: User creating the relationship
    """
    relationship = GraphRelationshipLedger(
        source_node_id=source_node_id,
        destination_node_id=destination_node_id,
        creator_function=creator_function,
        user_id=user.id,
    )
    session.add(relationship)
    await session.flush()


async def delete_relationship(
    session: AsyncSession,
    source_node_id: UUID,
    destination_node_id: UUID,
    user: User,
) -> None:
    """Mark a relationship as deleted.

    Args:
        session: Database session
        source_node_id: ID of the source node
        destination_node_id: ID of the destination node
        user: User deleting the relationship
    """
    relationship = await session.get(
        GraphRelationshipLedger, (source_node_id, destination_node_id)
    )
    if relationship:
        relationship.deleted_at = datetime.now(timezone.utc)
        session.add(relationship)
        await session.flush()

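Reviewer note: a minimal sketch, not part of this diff, of recording an edge in the new relationship ledger; the `get_relational_engine().get_async_session()` helper and the explicit commit are assumptions based on existing cognee usage and may differ from the final wiring.

```python
# Illustrative only — session helper and commit handling are assumptions.
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.graph.relationship_manager import create_relationship

async def log_relationship(source_id, destination_id, user):
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        await create_relationship(
            session,
            source_node_id=source_id,
            destination_node_id=destination_id,
            creator_function="NetworkXAdapter.add_edges",
            user=user,
        )
        await session.commit()
```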
@@ -51,7 +51,6 @@ class TripletSearchContextProvider(BaseContextProvider):
        tasks = [
            brute_force_triplet_search(
                query=f"{entity_text} {query}",
                user=user,
                top_k=self.top_k,
                collections=self.collections,
                properties_to_project=self.properties_to_project,
@@ -65,7 +65,7 @@ async def code_description_to_code_part(

    try:
        if include_docs:
            search_results = await search(query_text=query, query_type="INSIGHTS", user=user)
            search_results = await search(query_text=query, query_type="INSIGHTS")

            concatenated_descriptions = " ".join(
                obj["description"]
@@ -1,3 +1,5 @@
# PROPOSED TO BE DEPRECATED

import asyncio
from uuid import uuid5, NAMESPACE_OID
from typing import Type
@@ -51,12 +51,17 @@ async def integrate_chunk_graphs(


async def extract_graph_from_data(
    data_chunks: list[DocumentChunk],
    data_chunks: List[DocumentChunk],
    graph_model: Type[BaseModel],
    ontology_adapter: OntologyResolver = OntologyResolver(),
    ontology_adapter: OntologyResolver = None,
) -> List[DocumentChunk]:
    """Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model."""
    """
    Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
    """
    chunk_graphs = await asyncio.gather(
        *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
    )
    return await integrate_chunk_graphs(data_chunks, chunk_graphs, graph_model, ontology_adapter)

    return await integrate_chunk_graphs(
        data_chunks, chunk_graphs, graph_model, ontology_adapter or OntologyResolver()
    )
@@ -3,6 +3,7 @@ from typing import List

from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
from cognee.tasks.graph.cascade_extract.utils.extract_nodes import extract_nodes
from cognee.tasks.graph.cascade_extract.utils.extract_content_nodes_and_relationship_names import (
    extract_content_nodes_and_relationship_names,
@@ -14,7 +15,9 @@ from cognee.tasks.graph.extract_graph_from_data import integrate_chunk_graphs


async def extract_graph_from_data(
    data_chunks: List[DocumentChunk], n_rounds: int = 2
    data_chunks: List[DocumentChunk],
    n_rounds: int = 2,
    ontology_adapter: OntologyResolver = None,
) -> List[DocumentChunk]:
    """Extract and update graph data from document chunks in multiple steps."""
    chunk_nodes = await asyncio.gather(
@@ -37,4 +40,9 @@ async def extract_graph_from_data(
        ]
    )

    return await integrate_chunk_graphs(data_chunks, chunk_graphs, KnowledgeGraph)
    return await integrate_chunk_graphs(
        data_chunks=data_chunks,
        chunk_graphs=chunk_graphs,
        graph_model=KnowledgeGraph,
        ontology_adapter=ontology_adapter or OntologyResolver(),
    )

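Reviewer note: replacing the eagerly constructed `OntologyResolver()` default with `None` (resolved at call time) avoids sharing one mutable resolver instance across calls. Callers can still pass a resolver explicitly; a hedged sketch, with illustrative variable names only:

```python
# Illustrative only — an explicit resolver takes precedence over the default.
resolver = OntologyResolver()
enriched_chunks = await extract_graph_from_data(
    data_chunks, KnowledgeGraph, ontology_adapter=resolver
)
```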
@@ -1,3 +1,5 @@
# PROPOSED TO BE DEPRECATED

"""This module contains the OntologyEngine class which is responsible for adding graph ontology from a JSON or CSV file."""

import csv
@@ -1,4 +1,5 @@
import asyncio
from typing import List
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
@@ -6,7 +7,7 @@ from .index_data_points import index_data_points
from .index_graph_edges import index_graph_edges


async def add_data_points(data_points: list[DataPoint]):
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
    nodes = []
    edges = []

71	cognee/tests/test_deletion.py	Normal file
@@ -0,0 +1,71 @@
import os
import shutil
import cognee
import pathlib
from cognee.shared.logging_utils import get_logger

logger = get_logger()


async def main():
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

text_1 = """
|
||||
1. Audi
|
||||
Audi is known for its modern designs and advanced technology. Founded in the early 1900s, the brand has earned a reputation for precision engineering and innovation. With features like the Quattro all-wheel-drive system, Audi offers a range of vehicles from stylish sedans to high-performance sports cars.
|
||||
|
||||
2. BMW
|
||||
BMW, short for Bayerische Motoren Werke, is celebrated for its focus on performance and driving pleasure. The company's vehicles are designed to provide a dynamic and engaging driving experience, and their slogan, "The Ultimate Driving Machine," reflects that commitment. BMW produces a variety of cars that combine luxury with sporty performance.
|
||||
|
||||
3. Mercedes-Benz
|
||||
Mercedes-Benz is synonymous with luxury and quality. With a history dating back to the early 20th century, the brand is known for its elegant designs, innovative safety features, and high-quality engineering. Mercedes-Benz manufactures not only luxury sedans but also SUVs, sports cars, and commercial vehicles, catering to a wide range of needs.
|
||||
|
||||
4. Porsche
|
||||
Porsche is a name that stands for high-performance sports cars. Founded in 1931, the brand has become famous for models like the iconic Porsche 911. Porsche cars are celebrated for their speed, precision, and distinctive design, appealing to car enthusiasts who value both performance and style.
|
||||
|
||||
5. Volkswagen
|
||||
Volkswagen, which means "people's car" in German, was established with the idea of making affordable and reliable vehicles accessible to everyone. Over the years, Volkswagen has produced several iconic models, such as the Beetle and the Golf. Today, it remains one of the largest car manufacturers in the world, offering a wide range of vehicles that balance practicality with quality.
|
||||
|
||||
Each of these car manufacturers contributes to Germany's reputation as a leader in the global automotive industry, showcasing a blend of innovation, performance, and design excellence.
|
||||
"""
|
||||
|
||||
text_2 = """
|
||||
1. Apple
|
||||
Apple is renowned for its innovative consumer electronics and software. Its product lineup includes the iPhone, iPad, Mac computers, and wearables like the Apple Watch. Known for its emphasis on sleek design and user-friendly interfaces, Apple has built a loyal customer base and created a seamless ecosystem that integrates hardware, software, and services.
|
||||
|
||||
2. Google
|
||||
Founded in 1998, Google started as a search engine and quickly became the go-to resource for finding information online. Over the years, the company has diversified its offerings to include digital advertising, cloud computing, mobile operating systems (Android), and various web services like Gmail and Google Maps. Google's innovations have played a major role in shaping the internet landscape.
|
||||
|
||||
3. Microsoft
|
||||
Microsoft Corporation has been a dominant force in software for decades. Its Windows operating system and Microsoft Office suite are staples in both business and personal computing. In recent years, Microsoft has expanded into cloud computing with Azure, gaming with the Xbox platform, and even hardware through products like the Surface line. This evolution has helped the company maintain its relevance in a rapidly changing tech world.
|
||||
|
||||
4. Amazon
|
||||
What began as an online bookstore has grown into one of the largest e-commerce platforms globally. Amazon is known for its vast online marketplace, but its influence extends far beyond retail. With Amazon Web Services (AWS), the company has become a leader in cloud computing, offering robust solutions that power websites, applications, and businesses around the world. Amazon's constant drive for innovation continues to reshape both retail and technology sectors.
|
||||
|
||||
5. Meta
|
||||
Meta, originally known as Facebook, revolutionized social media by connecting billions of people worldwide. Beyond its core social networking service, Meta is investing in the next generation of digital experiences through virtual and augmented reality technologies, with projects like Oculus. The company's efforts signal a commitment to evolving digital interaction and building the metaverse—a shared virtual space where users can connect and collaborate.
|
||||
|
||||
Each of these companies has significantly impacted the technology landscape, driving innovation and transforming everyday life through their groundbreaking products and services.
|
||||
"""

    await cognee.add([text_1, text_2])

    await cognee.cognify()

    from cognee.infrastructure.databases.graph import get_graph_engine

    graph_engine = await get_graph_engine()
    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) > 15 and len(edges) > 15, "Graph database is not loaded."

    await cognee.delete([text_1, text_2], mode="hard")
    nodes, edges = await graph_engine.get_graph_data()

    assert len(nodes) == 0 and len(edges) == 0, "Document is not deleted."


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())

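Reviewer note: the new end-to-end test only exercises `mode="hard"` on both documents at once. A hedged sketch of a finer-grained check that could be added later; the assertion threshold is an assumption, not part of this diff.

```python
# Illustrative only — delete one document and confirm the other survives.
await cognee.delete([text_1], mode="hard")
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) > 0, "Nodes from the remaining document should still be present."
```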
@@ -92,18 +92,12 @@ async def relational_db_migration():
            n_data = row[1]
            m_data = row[2]

            source_props = {}
            if "properties" in n_data and n_data["properties"]:
                source_props = json.loads(n_data["properties"])
            target_props = {}
            if "properties" in m_data and m_data["properties"]:
                target_props = json.loads(m_data["properties"])
            source_name = normalize_node_name(n_data.get("name", ""))
            target_name = normalize_node_name(m_data.get("name", ""))

            source_name = normalize_node_name(source_props.get("name", f"id:{n_data['id']}"))
            target_name = normalize_node_name(target_props.get("name", f"id:{m_data['id']}"))

            found_edges.add((source_name, target_name))
            distinct_node_names.update([source_name, target_name])
            if source_name and target_name:
                found_edges.add((source_name, target_name))
                distinct_node_names.update([source_name, target_name])

        elif graph_db_provider == "networkx":
            nodes, edges = await graph_engine.get_graph_data()
|
|||
|
|
@ -7,12 +7,13 @@ import os
|
|||
from cognee.api.v1.search import SearchType
|
||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||
|
||||
|
||||
text_1 = """
|
||||
1. Audi
|
||||
Audi is known for its modern designs and advanced technology. Founded in the early 1900s, the brand has earned a reputation for precision engineering and innovation. With features like the Quattro all-wheel-drive system, Audi offers a range of vehicles from stylish sedans to high-performance sports cars.
|
||||
|
||||
2. BMW
|
||||
BMW, short for Bayerische Motoren Werke, is celebrated for its focus on performance and driving pleasure. The company’s vehicles are designed to provide a dynamic and engaging driving experience, and their slogan, "The Ultimate Driving Machine," reflects that commitment. BMW produces a variety of cars that combine luxury with sporty performance.
|
||||
BMW, short for Bayerische Motoren Werke, is celebrated for its focus on performance and driving pleasure. The company's vehicles are designed to provide a dynamic and engaging driving experience, and their slogan, "The Ultimate Driving Machine," reflects that commitment. BMW produces a variety of cars that combine luxury with sporty performance.
|
||||
|
||||
3. Mercedes-Benz
|
||||
Mercedes-Benz is synonymous with luxury and quality. With a history dating back to the early 20th century, the brand is known for its elegant designs, innovative safety features, and high-quality engineering. Mercedes-Benz manufactures not only luxury sedans but also SUVs, sports cars, and commercial vehicles, catering to a wide range of needs.
|
||||
|
|
@ -21,7 +22,7 @@ Mercedes-Benz is synonymous with luxury and quality. With a history dating back
|
|||
Porsche is a name that stands for high-performance sports cars. Founded in 1931, the brand has become famous for models like the iconic Porsche 911. Porsche cars are celebrated for their speed, precision, and distinctive design, appealing to car enthusiasts who value both performance and style.
|
||||
|
||||
5. Volkswagen
|
||||
Volkswagen, which means “people’s car” in German, was established with the idea of making affordable and reliable vehicles accessible to everyone. Over the years, Volkswagen has produced several iconic models, such as the Beetle and the Golf. Today, it remains one of the largest car manufacturers in the world, offering a wide range of vehicles that balance practicality with quality.
|
||||
Volkswagen, which means "people's car" in German, was established with the idea of making affordable and reliable vehicles accessible to everyone. Over the years, Volkswagen has produced several iconic models, such as the Beetle and the Golf. Today, it remains one of the largest car manufacturers in the world, offering a wide range of vehicles that balance practicality with quality.
|
||||
|
||||
Each of these car manufacturers contributes to Germany's reputation as a leader in the global automotive industry, showcasing a blend of innovation, performance, and design excellence.
|
||||
"""
|
||||
|
|
@ -31,16 +32,16 @@ text_2 = """
|
|||
Apple is renowned for its innovative consumer electronics and software. Its product lineup includes the iPhone, iPad, Mac computers, and wearables like the Apple Watch. Known for its emphasis on sleek design and user-friendly interfaces, Apple has built a loyal customer base and created a seamless ecosystem that integrates hardware, software, and services.
|
||||
|
||||
2. Google
|
||||
Founded in 1998, Google started as a search engine and quickly became the go-to resource for finding information online. Over the years, the company has diversified its offerings to include digital advertising, cloud computing, mobile operating systems (Android), and various web services like Gmail and Google Maps. Google’s innovations have played a major role in shaping the internet landscape.
|
||||
Founded in 1998, Google started as a search engine and quickly became the go-to resource for finding information online. Over the years, the company has diversified its offerings to include digital advertising, cloud computing, mobile operating systems (Android), and various web services like Gmail and Google Maps. Google's innovations have played a major role in shaping the internet landscape.
|
||||
|
||||
3. Microsoft
|
||||
Microsoft Corporation has been a dominant force in software for decades. Its Windows operating system and Microsoft Office suite are staples in both business and personal computing. In recent years, Microsoft has expanded into cloud computing with Azure, gaming with the Xbox platform, and even hardware through products like the Surface line. This evolution has helped the company maintain its relevance in a rapidly changing tech world.
|
||||
|
||||
4. Amazon
|
||||
What began as an online bookstore has grown into one of the largest e-commerce platforms globally. Amazon is known for its vast online marketplace, but its influence extends far beyond retail. With Amazon Web Services (AWS), the company has become a leader in cloud computing, offering robust solutions that power websites, applications, and businesses around the world. Amazon’s constant drive for innovation continues to reshape both retail and technology sectors.
|
||||
What began as an online bookstore has grown into one of the largest e-commerce platforms globally. Amazon is known for its vast online marketplace, but its influence extends far beyond retail. With Amazon Web Services (AWS), the company has become a leader in cloud computing, offering robust solutions that power websites, applications, and businesses around the world. Amazon's constant drive for innovation continues to reshape both retail and technology sectors.
|
||||
|
||||
5. Meta
|
||||
Meta, originally known as Facebook, revolutionized social media by connecting billions of people worldwide. Beyond its core social networking service, Meta is investing in the next generation of digital experiences through virtual and augmented reality technologies, with projects like Oculus. The company’s efforts signal a commitment to evolving digital interaction and building the metaverse—a shared virtual space where users can connect and collaborate.
|
||||
Meta, originally known as Facebook, revolutionized social media by connecting billions of people worldwide. Beyond its core social networking service, Meta is investing in the next generation of digital experiences through virtual and augmented reality technologies, with projects like Oculus. The company's efforts signal a commitment to evolving digital interaction and building the metaverse—a shared virtual space where users can connect and collaborate.
|
||||
|
||||
Each of these companies has significantly impacted the technology landscape, driving innovation and transforming everyday life through their groundbreaking products and services.
|
||||
"""
|
||||
|
|
|
|||
871	notebooks/cognee_demo.ipynb	Normal file
File diff suppressed because one or more lines are too long
243	notebooks/graphrag_vs_rag.ipynb	Normal file
File diff suppressed because one or more lines are too long
978	notebooks/hr_demo.ipynb	Normal file
@@ -0,0 +1,978 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d35ac8ce-0f92-46f5-9ba4-a46970f0ce19",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Cognee - Get Started"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "074f0ea8-c659-4736-be26-be4b0e5ac665",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Demo time"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0587d91d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### First let's define some data that we will cognify and perform a search on"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "df16431d0f48b006",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:48.519686Z",
|
||||
"start_time": "2024-09-20T14:02:48.515589Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_position = \"\"\"Senior Data Scientist (Machine Learning)\n",
|
||||
"\n",
|
||||
"Company: TechNova Solutions\n",
|
||||
"Location: San Francisco, CA\n",
|
||||
"\n",
|
||||
"Job Description:\n",
|
||||
"\n",
|
||||
"TechNova Solutions is seeking a Senior Data Scientist specializing in Machine Learning to join our dynamic analytics team. The ideal candidate will have a strong background in developing and deploying machine learning models, working with large datasets, and translating complex data into actionable insights.\n",
|
||||
"\n",
|
||||
"Responsibilities:\n",
|
||||
"\n",
|
||||
"Develop and implement advanced machine learning algorithms and models.\n",
|
||||
"Analyze large, complex datasets to extract meaningful patterns and insights.\n",
|
||||
"Collaborate with cross-functional teams to integrate predictive models into products.\n",
|
||||
"Stay updated with the latest advancements in machine learning and data science.\n",
|
||||
"Mentor junior data scientists and provide technical guidance.\n",
|
||||
"Qualifications:\n",
|
||||
"\n",
|
||||
"Master’s or Ph.D. in Data Science, Computer Science, Statistics, or a related field.\n",
|
||||
"5+ years of experience in data science and machine learning.\n",
|
||||
"Proficient in Python, R, and SQL.\n",
|
||||
"Experience with deep learning frameworks (e.g., TensorFlow, PyTorch).\n",
|
||||
"Strong problem-solving skills and attention to detail.\n",
|
||||
"Candidate CVs\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9086abf3af077ab4",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:49.120838Z",
|
||||
"start_time": "2024-09-20T14:02:49.118294Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_1 = \"\"\"\n",
|
||||
"CV 1: Relevant\n",
|
||||
"Name: Dr. Emily Carter\n",
|
||||
"Contact Information:\n",
|
||||
"\n",
|
||||
"Email: emily.carter@example.com\n",
|
||||
"Phone: (555) 123-4567\n",
|
||||
"Summary:\n",
|
||||
"\n",
|
||||
"Senior Data Scientist with over 8 years of experience in machine learning and predictive analytics. Expertise in developing advanced algorithms and deploying scalable models in production environments.\n",
|
||||
"\n",
|
||||
"Education:\n",
|
||||
"\n",
|
||||
"Ph.D. in Computer Science, Stanford University (2014)\n",
|
||||
"B.S. in Mathematics, University of California, Berkeley (2010)\n",
|
||||
"Experience:\n",
|
||||
"\n",
|
||||
"Senior Data Scientist, InnovateAI Labs (2016 – Present)\n",
|
||||
"Led a team in developing machine learning models for natural language processing applications.\n",
|
||||
"Implemented deep learning algorithms that improved prediction accuracy by 25%.\n",
|
||||
"Collaborated with cross-functional teams to integrate models into cloud-based platforms.\n",
|
||||
"Data Scientist, DataWave Analytics (2014 – 2016)\n",
|
||||
"Developed predictive models for customer segmentation and churn analysis.\n",
|
||||
"Analyzed large datasets using Hadoop and Spark frameworks.\n",
|
||||
"Skills:\n",
|
||||
"\n",
|
||||
"Programming Languages: Python, R, SQL\n",
|
||||
"Machine Learning: TensorFlow, Keras, Scikit-Learn\n",
|
||||
"Big Data Technologies: Hadoop, Spark\n",
|
||||
"Data Visualization: Tableau, Matplotlib\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a9de0cc07f798b7f",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:49.675003Z",
|
||||
"start_time": "2024-09-20T14:02:49.671615Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_2 = \"\"\"\n",
|
||||
"CV 2: Relevant\n",
|
||||
"Name: Michael Rodriguez\n",
|
||||
"Contact Information:\n",
|
||||
"\n",
|
||||
"Email: michael.rodriguez@example.com\n",
|
||||
"Phone: (555) 234-5678\n",
|
||||
"Summary:\n",
|
||||
"\n",
|
||||
"Data Scientist with a strong background in machine learning and statistical modeling. Skilled in handling large datasets and translating data into actionable business insights.\n",
|
||||
"\n",
|
||||
"Education:\n",
|
||||
"\n",
|
||||
"M.S. in Data Science, Carnegie Mellon University (2013)\n",
|
||||
"B.S. in Computer Science, University of Michigan (2011)\n",
|
||||
"Experience:\n",
|
||||
"\n",
|
||||
"Senior Data Scientist, Alpha Analytics (2017 – Present)\n",
|
||||
"Developed machine learning models to optimize marketing strategies.\n",
|
||||
"Reduced customer acquisition cost by 15% through predictive modeling.\n",
|
||||
"Data Scientist, TechInsights (2013 – 2017)\n",
|
||||
"Analyzed user behavior data to improve product features.\n",
|
||||
"Implemented A/B testing frameworks to evaluate product changes.\n",
|
||||
"Skills:\n",
|
||||
"\n",
|
||||
"Programming Languages: Python, Java, SQL\n",
|
||||
"Machine Learning: Scikit-Learn, XGBoost\n",
|
||||
"Data Visualization: Seaborn, Plotly\n",
|
||||
"Databases: MySQL, MongoDB\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "185ff1c102d06111",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:50.286828Z",
|
||||
"start_time": "2024-09-20T14:02:50.284369Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_3 = \"\"\"\n",
|
||||
"CV 3: Relevant\n",
|
||||
"Name: Sarah Nguyen\n",
|
||||
"Contact Information:\n",
|
||||
"\n",
|
||||
"Email: sarah.nguyen@example.com\n",
|
||||
"Phone: (555) 345-6789\n",
|
||||
"Summary:\n",
|
||||
"\n",
|
||||
"Data Scientist specializing in machine learning with 6 years of experience. Passionate about leveraging data to drive business solutions and improve product performance.\n",
|
||||
"\n",
|
||||
"Education:\n",
|
||||
"\n",
|
||||
"M.S. in Statistics, University of Washington (2014)\n",
|
||||
"B.S. in Applied Mathematics, University of Texas at Austin (2012)\n",
|
||||
"Experience:\n",
|
||||
"\n",
|
||||
"Data Scientist, QuantumTech (2016 – Present)\n",
|
||||
"Designed and implemented machine learning algorithms for financial forecasting.\n",
|
||||
"Improved model efficiency by 20% through algorithm optimization.\n",
|
||||
"Junior Data Scientist, DataCore Solutions (2014 – 2016)\n",
|
||||
"Assisted in developing predictive models for supply chain optimization.\n",
|
||||
"Conducted data cleaning and preprocessing on large datasets.\n",
|
||||
"Skills:\n",
|
||||
"\n",
|
||||
"Programming Languages: Python, R\n",
|
||||
"Machine Learning Frameworks: PyTorch, Scikit-Learn\n",
|
||||
"Statistical Analysis: SAS, SPSS\n",
|
||||
"Cloud Platforms: AWS, Azure\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d55ce4c58f8efb67",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:50.950343Z",
|
||||
"start_time": "2024-09-20T14:02:50.946378Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_4 = \"\"\"\n",
|
||||
"CV 4: Not Relevant\n",
|
||||
"Name: David Thompson\n",
|
||||
"Contact Information:\n",
|
||||
"\n",
|
||||
"Email: david.thompson@example.com\n",
|
||||
"Phone: (555) 456-7890\n",
|
||||
"Summary:\n",
|
||||
"\n",
|
||||
"Creative Graphic Designer with over 8 years of experience in visual design and branding. Proficient in Adobe Creative Suite and passionate about creating compelling visuals.\n",
|
||||
"\n",
|
||||
"Education:\n",
|
||||
"\n",
|
||||
"B.F.A. in Graphic Design, Rhode Island School of Design (2012)\n",
|
||||
"Experience:\n",
|
||||
"\n",
|
||||
"Senior Graphic Designer, CreativeWorks Agency (2015 – Present)\n",
|
||||
"Led design projects for clients in various industries.\n",
|
||||
"Created branding materials that increased client engagement by 30%.\n",
|
||||
"Graphic Designer, Visual Innovations (2012 – 2015)\n",
|
||||
"Designed marketing collateral, including brochures, logos, and websites.\n",
|
||||
"Collaborated with the marketing team to develop cohesive brand strategies.\n",
|
||||
"Skills:\n",
|
||||
"\n",
|
||||
"Design Software: Adobe Photoshop, Illustrator, InDesign\n",
|
||||
"Web Design: HTML, CSS\n",
|
||||
"Specialties: Branding and Identity, Typography\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ca4ecc32721ad332",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:51.548191Z",
|
||||
"start_time": "2024-09-20T14:02:51.545520Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_5 = \"\"\"\n",
|
||||
"CV 5: Not Relevant\n",
|
||||
"Name: Jessica Miller\n",
|
||||
"Contact Information:\n",
|
||||
"\n",
|
||||
"Email: jessica.miller@example.com\n",
|
||||
"Phone: (555) 567-8901\n",
|
||||
"Summary:\n",
|
||||
"\n",
|
||||
"Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams. Excellent communication and leadership skills.\n",
|
||||
"\n",
|
||||
"Education:\n",
|
||||
"\n",
|
||||
"B.A. in Business Administration, University of Southern California (2010)\n",
|
||||
"Experience:\n",
|
||||
"\n",
|
||||
"Sales Manager, Global Enterprises (2015 – Present)\n",
|
||||
"Managed a sales team of 15 members, achieving a 20% increase in annual revenue.\n",
|
||||
"Developed sales strategies that expanded customer base by 25%.\n",
|
||||
"Sales Representative, Market Leaders Inc. (2010 – 2015)\n",
|
||||
"Consistently exceeded sales targets and received the 'Top Salesperson' award in 2013.\n",
|
||||
"Skills:\n",
|
||||
"\n",
|
||||
"Sales Strategy and Planning\n",
|
||||
"Team Leadership and Development\n",
|
||||
"CRM Software: Salesforce, Zoho\n",
|
||||
"Negotiation and Relationship Building\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4415446a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Please add the necessary environment information bellow:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bce39dc6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# Setting environment variables\n",
|
||||
"if \"GRAPHISTRY_USERNAME\" not in os.environ:\n",
|
||||
" os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n",
|
||||
"\n",
|
||||
"if \"GRAPHISTRY_PASSWORD\" not in os.environ:\n",
|
||||
" os.environ[\"GRAPHISTRY_PASSWORD\"] = \"\"\n",
|
||||
"\n",
|
||||
"if \"LLM_API_KEY\" not in os.environ:\n",
|
||||
" os.environ[\"LLM_API_KEY\"] = \"\"\n",
|
||||
"\n",
|
||||
"# \"neo4j\" or \"networkx\"\n",
|
||||
"os.environ[\"GRAPH_DATABASE_PROVIDER\"] = \"networkx\"\n",
|
||||
"# Not needed if using networkx\n",
|
||||
"# os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
|
||||
"# os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
|
||||
"# os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
|
||||
"\n",
|
||||
"# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
|
||||
"os.environ[\"VECTOR_DB_PROVIDER\"] = \"lancedb\"\n",
|
||||
"# Not needed if using \"lancedb\" or \"pgvector\"\n",
|
||||
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
|
||||
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
|
||||
"\n",
|
||||
"# Relational Database provider \"sqlite\" or \"postgres\"\n",
|
||||
"os.environ[\"DB_PROVIDER\"] = \"sqlite\"\n",
|
||||
"\n",
|
||||
"# Database name\n",
|
||||
"os.environ[\"DB_NAME\"] = \"cognee_db\"\n",
|
||||
"\n",
|
||||
"# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
|
||||
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
|
||||
"# os.environ[\"DB_PORT\"]=\"5432\"\n",
|
||||
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
|
||||
"# os.environ[\"DB_PASSWORD\"]=\"cognee\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9f1a1dbd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Reset the cognee system with the following command:\n",
|
||||
"\n",
|
||||
"import cognee\n",
|
||||
"\n",
|
||||
"await cognee.prune.prune_data()\n",
|
||||
"await cognee.prune.prune_system(metadata=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "383d6971",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### After we have defined and gathered our data let's add it to cognee "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "904df61ba484a8e5",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:54.243987Z",
|
||||
"start_time": "2024-09-20T14:02:52.498195Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cognee\n",
|
||||
"\n",
|
||||
"await cognee.add([job_1, job_2, job_3, job_4, job_5, job_position], \"example\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0f15c5b1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### All good, let's cognify it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7c431fdef4921ae0",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:57.925667Z",
|
||||
"start_time": "2024-09-20T14:02:57.922353Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cognee.shared.data_models import KnowledgeGraph\n",
|
||||
"from cognee.modules.data.models import Dataset, Data\n",
|
||||
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
|
||||
"from cognee.modules.cognify.config import get_cognify_config\n",
|
||||
"from cognee.modules.pipelines.tasks.Task import Task\n",
|
||||
"from cognee.modules.pipelines import run_tasks\n",
|
||||
"from cognee.modules.users.models import User\n",
|
||||
"from cognee.tasks.documents import (\n",
|
||||
" check_permissions_on_documents,\n",
|
||||
" classify_documents,\n",
|
||||
" extract_chunks_from_documents,\n",
|
||||
")\n",
|
||||
"from cognee.tasks.graph import extract_graph_from_data\n",
|
||||
"from cognee.tasks.storage import add_data_points\n",
|
||||
"from cognee.tasks.summarization import summarize_text\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def run_cognify_pipeline(dataset: Dataset, user: User = None):\n",
|
||||
" data_documents: list[Data] = await get_dataset_data(dataset_id=dataset.id)\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" cognee_config = get_cognify_config()\n",
|
||||
"\n",
|
||||
" tasks = [\n",
|
||||
" Task(classify_documents),\n",
|
||||
" Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
|
||||
" Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
|
||||
" Task(\n",
|
||||
" extract_graph_from_data, graph_model=KnowledgeGraph,\n",
|
||||
" task_config={\"batch_size\": 10}\n",
|
||||
" ), # Generate knowledge graphs from the document chunks.\n",
|
||||
" Task(\n",
|
||||
" summarize_text,\n",
|
||||
" summarization_model=cognee_config.summarization_model,\n",
|
||||
" task_config={\"batch_size\": 10},\n",
|
||||
" ),\n",
|
||||
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" pipeline = run_tasks(tasks, data_documents)\n",
|
||||
"\n",
|
||||
" async for result in pipeline:\n",
|
||||
" print(result)\n",
|
||||
" except Exception as error:\n",
|
||||
" raise error"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f0a91b99c6215e09",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:58.905774Z",
|
||||
"start_time": "2024-09-20T14:02:58.625915Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cognee.modules.users.methods import get_default_user\n",
|
||||
"from cognee.modules.data.methods import get_datasets_by_name\n",
|
||||
"\n",
|
||||
"user = await get_default_user()\n",
|
||||
"\n",
|
||||
"datasets = await get_datasets_by_name([\"example\"], user.id)\n",
|
||||
"\n",
|
||||
"await run_cognify_pipeline(datasets[0], user)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "219a6d41",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### We get the url to the graph on graphistry in the notebook cell bellow, showing nodes and connections made by the cognify process."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "080389e5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from cognee.shared.utils import render_graph\n",
|
||||
"from cognee.infrastructure.databases.graph import get_graph_engine\n",
|
||||
"import graphistry\n",
|
||||
"\n",
|
||||
"graphistry.login(\n",
|
||||
" username=os.getenv(\"GRAPHISTRY_USERNAME\"), password=os.getenv(\"GRAPHISTRY_PASSWORD\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"graph_engine = await get_graph_engine()\n",
|
||||
"\n",
|
||||
"graph_url = await render_graph(graph_engine.graph)\n",
|
||||
"print(graph_url)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "59e6c3c3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### We can also do a search on the data to explore the knowledge."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e5e7dfc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"async def search(\n",
|
||||
" vector_engine,\n",
|
||||
" collection_name: str,\n",
|
||||
" query_text: str = None,\n",
|
||||
"):\n",
|
||||
" query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0]\n",
|
||||
"\n",
|
||||
" connection = await vector_engine.get_connection()\n",
|
||||
" collection = await connection.open_table(collection_name)\n",
|
||||
"\n",
|
||||
" results = await collection.vector_search(query_vector).limit(10).to_pandas()\n",
|
||||
"\n",
|
||||
" result_values = list(results.to_dict(\"index\").values())\n",
|
||||
"\n",
|
||||
" return [\n",
|
||||
" dict(\n",
|
||||
" id=str(result[\"id\"]),\n",
|
||||
" payload=result[\"payload\"],\n",
|
||||
" score=result[\"_distance\"],\n",
|
||||
" )\n",
|
||||
" for result in result_values\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from cognee.infrastructure.databases.vector import get_vector_engine\n",
|
||||
"\n",
|
||||
"vector_engine = get_vector_engine()\n",
|
||||
"results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
|
||||
"for result in results:\n",
|
||||
" print(result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "81fa2b00",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### We normalize search output scores so the lower the score of the search result is the higher the chance that it's what you're looking for. In the example above we have searched for node entities in the knowledge graph related to \"sarah.nguyen@example.com\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1b94ff96",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### In the example bellow we'll use cognee search to summarize information regarding the node most related to \"sarah.nguyen@example.com\" in the knowledge graph"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "21a3e9a6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cognee.api.v1.search import SearchType\n",
|
||||
"\n",
|
||||
"node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
|
||||
"node_name = node.payload[\"text\"]\n",
|
||||
"\n",
|
||||
"search_results = await cognee.search(query_type=SearchType.SUMMARIES, query_text=node_name)\n",
|
||||
"print(\"\\n\\Extracted summaries are:\\n\")\n",
|
||||
"for result in search_results:\n",
|
||||
" print(f\"{result}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fd6e5fe2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### In this example we'll use cognee search to find chunks in which the node most related to \"sarah.nguyen@example.com\" is a part of"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c7a8abff",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=node_name)\n",
|
||||
"print(\"\\n\\nExtracted chunks are:\\n\")\n",
|
||||
"for result in search_results:\n",
|
||||
" print(f\"{result}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "47f0112f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### In this example we'll use cognee search to give us insights from the knowledge graph related to the node most related to \"sarah.nguyen@example.com\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "706a3954",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"search_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text=node_name)\n",
|
||||
"print(\"\\n\\nExtracted sentences are:\\n\")\n",
|
||||
"for result in search_results:\n",
|
||||
" print(f\"{result}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e519e30c0423c2a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Let's add evals"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3845443e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install \"cognee[deepeval]\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7a2c3c70",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from evals.eval_on_hotpot import deepeval_answers, answer_qa_instance\n",
|
||||
"from evals.qa_dataset_utils import load_qa_dataset\n",
|
||||
"from evals.qa_metrics_utils import get_metrics\n",
|
||||
"from evals.qa_context_provider_utils import qa_context_providers\n",
|
||||
"from pathlib import Path\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import statistics\n",
|
||||
"import random"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "53a609d8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"num_samples = 10 # With cognee, it takes ~1m10s per sample\n",
|
||||
"dataset_name_or_filename = \"hotpotqa\"\n",
|
||||
"dataset = load_qa_dataset(dataset_name_or_filename)"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "code",
"execution_count": null,
"id": "7351ab8f",
"metadata": {},
"outputs": [],
"source": [
"context_provider_name = \"cognee\"\n",
"context_provider = qa_context_providers[context_provider_name]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9346115b",
"metadata": {},
"outputs": [],
"source": [
"random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n",
"out_path = \"out\"\n",
"if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
")\n",
"\n",
"answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n",
" answer = await answer_qa_instance(instance, context_provider, contexts_filename)\n",
" answers.append(answer)"
]
},
{
"cell_type": "markdown",
"id": "1e7d872d",
"metadata": {},
"source": [
"#### Define Metrics for Evaluation and Calculate Score\n",
"**Options**: \n",
"- **Correctness**: Is the actual output factually correct based on the expected output?\n",
"- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?\n",
"- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?\n",
"- **Empowerment**: How well does the answer help the reader understand and make informed judgements about the topic?\n",
"- **Directness**: How specifically and clearly does the answer address the question?\n",
"- **F1 Score**: the harmonic mean of precision and recall, computed over word-level token overlap (see the sketch below)\n",
"- **EM Score**: the rate at which the predicted strings exactly match their references, ignoring white spaces and capitalization."
]
},
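{
"cell_type": "markdown",
"id": "f1em-sketch-md",
"metadata": {},
"source": [
"##### What word-level `\"F1\"` and `\"EM\"` measure\n",
"A minimal sketch, assuming whitespace tokenization and lowercasing; the actual scores below come from `get_metrics`, which may normalize differently. The helper names `word_f1` and `exact_match` are illustrative and not part of the evals package."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1em-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: word-level F1 and EM for a single prediction/reference pair.\n",
"def word_f1(prediction: str, reference: str) -> float:\n",
"    pred_tokens = prediction.lower().split()\n",
"    ref_tokens = reference.lower().split()\n",
"    common = sum(min(pred_tokens.count(t), ref_tokens.count(t)) for t in set(pred_tokens))\n",
"    if common == 0:\n",
"        return 0.0\n",
"    precision = common / len(pred_tokens)\n",
"    recall = common / len(ref_tokens)\n",
"    return 2 * precision * recall / (precision + recall)\n",
"\n",
"\n",
"def exact_match(prediction: str, reference: str) -> float:\n",
"    # EM ignores surrounding whitespace and capitalization, as described above.\n",
"    return float(prediction.strip().lower() == reference.strip().lower())\n",
"\n",
"\n",
"print(word_f1(\"Paris is the capital of France\", \"the capital of France is Paris\"))\n",
"print(exact_match(\" Paris \", \"paris\"))"
]
},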
{
"cell_type": "markdown",
"id": "c81e2b46",
"metadata": {},
"source": [
"##### Calculating `\"Correctness\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ae728344",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "764aac6d",
"metadata": {},
"outputs": [],
"source": [
"Correctness = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Correctness)"
]
},
{
"cell_type": "markdown",
"id": "6d3bbdc5",
"metadata": {},
"source": [
"##### Calculating `\"Comprehensiveness\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9793ef78",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9add448a",
"metadata": {},
"outputs": [],
"source": [
"Comprehensiveness = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Comprehensiveness)"
]
},
{
"cell_type": "markdown",
"id": "bce2fa25",
"metadata": {},
"source": [
"##### Calculating `\"Diversity\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f60a179e",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ccbd0ab",
"metadata": {},
"outputs": [],
"source": [
"Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Diversity)"
]
},
{
"cell_type": "markdown",
"id": "191cab63",
"metadata": {},
"source": [
"##### Calculating `\"Empowerment\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66bec0bf",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b043a8f",
"metadata": {},
"outputs": [],
"source": [
"Empowerment = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Empowerment)"
]
},
{
"cell_type": "markdown",
"id": "2cac3be9",
"metadata": {},
"source": [
"##### Calculating `\"Directness\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adaa17c0",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a8f97c9",
"metadata": {},
"outputs": [],
"source": [
"Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Directness)"
]
},
{
"cell_type": "markdown",
"id": "1ad6feb8",
"metadata": {},
"source": [
"##### Calculating `\"F1\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bdc48259",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"F1\"]\n",
"eval_metrics = get_metrics(metric_name_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c43c17c8",
"metadata": {},
"outputs": [],
"source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8bfcc46d",
"metadata": {},
"outputs": [],
"source": [
"F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(F1_score)"
]
},
{
"cell_type": "markdown",
"id": "2583f948",
"metadata": {},
"source": [
"##### Calculating `\"EM\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90a8f630",
"metadata": {},
"outputs": [],
"source": [
"metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d1b1ea1",
"metadata": {},
"outputs": [],
"source": [
"EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(EM)"
]
},
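{
"cell_type": "markdown",
"id": "score-summary-md",
"metadata": {},
"source": [
"##### Collecting the scores\n",
"A small convenience cell, assuming every metric cell above has been run, so the variables `Correctness`, `Comprehensiveness`, `Diversity`, `Empowerment`, `Directness`, `F1_score` and `EM` all exist."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "score-summary",
"metadata": {},
"outputs": [],
"source": [
"# Gather the per-metric means computed above into one overview.\n",
"scores = {\n",
"    \"Correctness\": Correctness,\n",
"    \"Comprehensiveness\": Comprehensiveness,\n",
"    \"Diversity\": Diversity,\n",
"    \"Empowerment\": Empowerment,\n",
"    \"Directness\": Directness,\n",
"    \"F1\": F1_score,\n",
"    \"EM\": EM,\n",
"}\n",
"for name, score in scores.items():\n",
"    print(f\"{name}: {score:.3f}\")"
]
},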
{
"cell_type": "markdown",
"id": "288ab570",
"metadata": {},
"source": [
"# Give us a star if you like it!\n",
"https://github.com/topoteretes/cognee"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cognee-c83GrcRT-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}