Compare commits
1 commit
main ... pensar-aut
Commit cf1625f72e
1 changed file with 106 additions and 80 deletions
```diff
@@ -12,7 +12,7 @@ from uuid import UUID
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
-from .exceptions import DocumentNotFoundError, DatasetNotFoundError, DocumentSubgraphNotFoundError
+from .exceptions import DocumentNotFoundError, DatasetNotFoundError, DocumentSubgraphNotFoundError, AuthorizationError
 from cognee.shared.logging_utils import get_logger
 
 logger = get_logger()
```
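The new import assumes an `AuthorizationError` exception is available in the module's `.exceptions` file. A minimal sketch of such a class, in case it is not already defined there (this is an assumption, not part of the diff), could look like:

```python
# Hypothetical sketch: the diff assumes .exceptions exposes AuthorizationError;
# shown here only so the new import is resolvable in isolation.
class AuthorizationError(Exception):
    """Raised when the requesting user is not allowed to perform the operation."""

    def __init__(self, message: str = "Not authorized."):
        super().__init__(message)
        self.message = message
```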
```diff
@@ -27,6 +27,7 @@ async def delete(
     data: Union[BinaryIO, List[BinaryIO], str, List[str]],
     dataset_name: str = "main_dataset",
     mode: str = "soft",
+    user=None,  # New parameter: user object or user_id required for authz
 ):
     """Delete a document and all its related nodes from both relational and graph databases.
 
```
```diff
@@ -34,15 +35,19 @@ async def delete(
         data: The data to delete (file, URL, or text)
         dataset_name: Name of the dataset to delete from
         mode: "soft" (default) or "hard" - hard mode also deletes degree-one entity nodes
+        user: The current authenticated user (required for authorization)
     """
 
+    if user is None or not hasattr(user, "id"):
+        raise AuthorizationError("Authentication required to perform delete operation.")
+
     # Handle different input types
     if isinstance(data, str):
         if data.startswith("file://"):  # It's a file path
             with open(data.replace("file://", ""), mode="rb") as file:
                 classified_data = classify(file)
                 content_hash = classified_data.get_metadata()["content_hash"]
-                return await delete_single_document(content_hash, dataset_name, mode)
+                return await delete_single_document(content_hash, dataset_name, mode, user)
         elif data.startswith("http"):  # It's a URL
             import requests
 
```
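A minimal sketch of how a caller would exercise the new guard, assuming `delete` and `AuthorizationError` are importable from the module this diff touches (import path not shown in the diff); the `SimpleNamespace` user is only a stand-in for cognee's real authenticated user object, which is assumed to expose an `id` attribute:

```python
import asyncio
from uuid import uuid4
from types import SimpleNamespace

# Stand-in for the authenticated user; the real object comes from cognee's user management.
user = SimpleNamespace(id=uuid4())

async def demo():
    # Without a user (or with an object lacking .id) the call now fails fast:
    try:
        await delete("some text to delete", dataset_name="main_dataset", mode="soft")
    except AuthorizationError as error:
        print(f"Rejected: {error}")

    # With a user, the same call proceeds to the per-dataset/per-document checks:
    return await delete("some text to delete", dataset_name="main_dataset", mode="soft", user=user)

asyncio.run(demo())
```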
```diff
@@ -51,81 +56,122 @@ async def delete(
             file_data = BytesIO(response.content)
             classified_data = classify(file_data)
             content_hash = classified_data.get_metadata()["content_hash"]
-            return await delete_single_document(content_hash, dataset_name, mode)
+            return await delete_single_document(content_hash, dataset_name, mode, user)
         else:  # It's a text string
             content_hash = get_text_content_hash(data)
             classified_data = classify(data)
-            return await delete_single_document(content_hash, dataset_name, mode)
+            return await delete_single_document(content_hash, dataset_name, mode, user)
     elif isinstance(data, list):
         # Handle list of inputs sequentially
         results = []
         for item in data:
-            result = await delete(item, dataset_name, mode)
+            result = await delete(item, dataset_name, mode, user)
             results.append(result)
         return {"status": "success", "message": "Multiple documents deleted", "results": results}
     else:  # It's already a BinaryIO
         data.seek(0)  # Ensure we're at the start of the file
         classified_data = classify(data)
         content_hash = classified_data.get_metadata()["content_hash"]
-        return await delete_single_document(content_hash, dataset_name, mode)
+        return await delete_single_document(content_hash, dataset_name, mode, user)
 
 
-async def delete_single_document(content_hash: str, dataset_name: str, mode: str = "soft"):
-    """Delete a single document by its content hash."""
+async def delete_single_document(content_hash: str, dataset_name: str, mode: str = "soft", user=None):
+    """Delete a single document by its content hash, after authorization."""
 
-    # Delete from graph database
-    deletion_result = await delete_document_subgraph(content_hash, mode)
+    if user is None or not hasattr(user, "id"):
+        raise AuthorizationError("Authentication required to perform delete operation.")
 
-    logger.info(f"Deletion result: {deletion_result}")
-
-    # Get the deleted node IDs and convert to UUID
-    deleted_node_ids = []
-    for node_id in deletion_result["deleted_node_ids"]:
-        try:
-            # Handle both string and UUID formats
-            if isinstance(node_id, str):
-                # Remove any hyphens if present
-                node_id = node_id.replace("-", "")
-                deleted_node_ids.append(UUID(node_id))
-            else:
-                deleted_node_ids.append(node_id)
-        except Exception as e:
-            logger.error(f"Error converting node ID {node_id} to UUID: {e}")
-            continue
-
-    # Delete from vector database
-    vector_engine = get_vector_engine()
-
-    # Determine vector collections dynamically
-    subclasses = get_all_subclasses(DataPoint)
-    vector_collections = []
-
-    for subclass in subclasses:
-        index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
-        for field_name in index_fields:
-            vector_collections.append(f"{subclass.__name__}_{field_name}")
-
-    # If no collections found, use default collections
-    if not vector_collections:
-        vector_collections = [
-            "DocumentChunk_text",
-            "EdgeType_relationship_name",
-            "EntityType_name",
-            "Entity_name",
-            "TextDocument_name",
-            "TextSummary_text",
-        ]
-
-    # Delete records from each vector collection that exists
-    for collection in vector_collections:
-        if await vector_engine.has_collection(collection):
-            await vector_engine.delete_data_points(
-                collection, [str(node_id) for node_id in deleted_node_ids]
-            )
-
     # Delete from relational database
     db_engine = get_relational_engine()
     async with db_engine.get_async_session() as session:
+        # Get the data point
+        data_point = (
+            await session.execute(select(Data).filter(Data.content_hash == content_hash))
+        ).scalar_one_or_none()
+
+        if data_point is None:
+            raise DocumentNotFoundError(
+                f"Document not found in relational DB with content hash: {content_hash}"
+            )
+
+        # Get the dataset
+        dataset = (
+            await session.execute(select(Dataset).filter(Dataset.name == dataset_name))
+        ).scalar_one_or_none()
+
+        if dataset is None:
+            raise DatasetNotFoundError(f"Dataset not found: {dataset_name}")
+
+        # AUTHORIZATION CHECKS -- Begin
+        # Dataset must belong to (or be permitted for) the user
+        if hasattr(dataset, "owner_id"):
+            if dataset.owner_id != user.id:
+                raise AuthorizationError("User is not authorized to delete from this dataset.")
+        elif hasattr(dataset, "tenant_id") and hasattr(user, "tenant_id"):
+            # Optional support for multi-tenant logic
+            if dataset.tenant_id != user.tenant_id:
+                raise AuthorizationError("User does not have access to this dataset.")
+        # Data object must belong to this dataset (DatasetData) and be accessible by user
+        if hasattr(data_point, "owner_id"):
+            if data_point.owner_id != user.id:
+                raise AuthorizationError("User is not authorized to delete this document.")
+        elif hasattr(data_point, "tenant_id") and hasattr(user, "tenant_id"):
+            if data_point.tenant_id != user.tenant_id:
+                raise AuthorizationError("User does not have access to this document.")
+        # AUTHORIZATION CHECKS -- End
+
+        doc_id = data_point.id
+
+        # Proceed with deletion (after authorization)
+
+        # Delete from graph database
+        deletion_result = await delete_document_subgraph(content_hash, mode)
+
+        logger.info(f"Deletion result: {deletion_result}")
+
+        # Get the deleted node IDs and convert to UUID
+        deleted_node_ids = []
+        for node_id in deletion_result["deleted_node_ids"]:
+            try:
+                # Handle both string and UUID formats
+                if isinstance(node_id, str):
+                    node_id = node_id.replace("-", "")
+                    deleted_node_ids.append(UUID(node_id))
+                else:
+                    deleted_node_ids.append(node_id)
+            except Exception as e:
+                logger.error(f"Error converting node ID {node_id} to UUID: {e}")
+                continue
+
+        # Delete from vector database
+        vector_engine = get_vector_engine()
+
+        # Determine vector collections dynamically
+        subclasses = get_all_subclasses(DataPoint)
+        vector_collections = []
+
+        for subclass in subclasses:
+            index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
+            for field_name in index_fields:
+                vector_collections.append(f"{subclass.__name__}_{field_name}")
+
+        # If no collections found, use default collections
+        if not vector_collections:
+            vector_collections = [
+                "DocumentChunk_text",
+                "EdgeType_relationship_name",
+                "EntityType_name",
+                "Entity_name",
+                "TextDocument_name",
+                "TextSummary_text",
+            ]
+
+        # Delete records from each vector collection that exists
+        for collection in vector_collections:
+            if await vector_engine.has_collection(collection):
+                await vector_engine.delete_data_points(
+                    collection, [str(node_id) for node_id in deleted_node_ids]
+                )
+
         # Update graph_relationship_ledger with deleted_at timestamps
         from sqlalchemy import update, and_, or_
         from datetime import datetime
```
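The dataset check and the document check added above follow the same pattern: compare `owner_id` against the user's `id` when present, otherwise fall back to an optional `tenant_id` match. A small standalone sketch of that pattern, with an illustrative helper name not present in the diff, is shown below (`AuthorizationError` is the exception imported in the first hunk):

```python
# Illustrative helper mirroring the hasattr-based pattern used in the diff.
# `resource` stands in for either the Dataset or the Data row; the attribute
# names (owner_id, tenant_id, id) are the ones the diff itself relies on.
def ensure_can_delete(resource, user, what: str) -> None:
    if hasattr(resource, "owner_id"):
        if resource.owner_id != user.id:
            raise AuthorizationError(f"User is not authorized to delete this {what}.")
    elif hasattr(resource, "tenant_id") and hasattr(user, "tenant_id"):
        if resource.tenant_id != user.tenant_id:
            raise AuthorizationError(f"User does not have access to this {what}.")

# The two inline blocks in delete_single_document are equivalent to:
# ensure_can_delete(dataset, user, "dataset")
# ensure_can_delete(data_point, user, "document")
```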
```diff
@@ -143,26 +189,6 @@ async def delete_single_document(content_hash: str, dataset_name: str, mode: str
         )
         await session.execute(update_stmt)
 
-        # Get the data point
-        data_point = (
-            await session.execute(select(Data).filter(Data.content_hash == content_hash))
-        ).scalar_one_or_none()
-
-        if data_point is None:
-            raise DocumentNotFoundError(
-                f"Document not found in relational DB with content hash: {content_hash}"
-            )
-
-        doc_id = data_point.id
-
-        # Get the dataset
-        dataset = (
-            await session.execute(select(Dataset).filter(Dataset.name == dataset_name))
-        ).scalar_one_or_none()
-
-        if dataset is None:
-            raise DatasetNotFoundError(f"Dataset not found: {dataset_name}")
-
         # Delete from dataset_data table
         dataset_delete_stmt = sql_delete(DatasetData).where(
             DatasetData.data_id == doc_id, DatasetData.dataset_id == dataset.id
```
```diff
@@ -189,7 +215,7 @@ async def delete_single_document(content_hash: str, dataset_name: str, mode: str
         "dataset": dataset_name,
         "deleted_node_ids": [
             str(node_id) for node_id in deleted_node_ids
-        ],  # Convert back to strings for response
+        ],
     }
 
 
```
```diff
@@ -244,4 +270,4 @@ async def delete_document_subgraph(content_hash: str, mode: str = "soft"):
         "deleted_counts": deleted_counts,
         "content_hash": content_hash,
         "deleted_node_ids": deleted_node_ids,
     }
```
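For orientation, the mapping returned by `delete_document_subgraph` (last hunk) is what `delete_single_document` consumes when rebuilding UUIDs. A small illustrative value and round-trip, with made-up placeholder data, is sketched below:

```python
from uuid import UUID

# Illustrative (made-up) result in the shape the hunks above produce and consume.
deletion_result = {
    "deleted_counts": {"DocumentChunk": 2, "Entity": 1},  # placeholder counts
    "content_hash": "abc123",                             # placeholder hash
    "deleted_node_ids": ["550e8400-e29b-41d4-a716-446655440000"],
}

# delete_single_document strips hyphens and feeds UUID(), which accepts
# a 32-character hex string just as well as the hyphenated form:
converted = [UUID(node_id.replace("-", "")) for node_id in deletion_result["deleted_node_ids"]]
print(converted)  # [UUID('550e8400-e29b-41d4-a716-446655440000')]
```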