feat: adds multitenant tests via pytest (#1923)
<!-- .github/pull_request_template.md -->
## Description
This PR changes the permission test in e2e tests to use pytest.
Introduces:
- fixtures for the environment setup
- one eventloop for all pytest tests
- mocking for acreate_structured_output answer generation (for search)
- Asserts in permission test (before we use the example only)
## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to
test it locally;
* Proof that it's sufficiently tested.
-->
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Entity model now includes description and metadata fields for richer
entity information and indexing.
* **Tests**
* Expanded and restructured permission tests covering multi-tenant and
role-based access flows; improved test scaffolding and stability.
* E2E test workflow now runs pytest with verbose output and INFO logs.
* **Bug Fixes**
* Access-tracking updates now commit transactions so access timestamps
persist.
* **Chores**
* General formatting, cleanup, and refactoring across modules and
maintenance scripts.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
commit
16cf955497
12 changed files with 669 additions and 702 deletions
2
.github/workflows/e2e_tests.yml
vendored
2
.github/workflows/e2e_tests.yml
vendored
|
|
@ -288,7 +288,7 @@ jobs:
|
||||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/test_permissions.py
|
run: uv run pytest cognee/tests/test_permissions.py -v --log-level=INFO
|
||||||
|
|
||||||
test-multi-tenancy:
|
test-multi-tenancy:
|
||||||
name: Test multi tenancy with different situations in Cognee
|
name: Test multi tenancy with different situations in Cognee
|
||||||
|
|
|
||||||
|
|
@ -1,52 +1,51 @@
|
||||||
"""add_last_accessed_to_data
|
"""add_last_accessed_to_data
|
||||||
|
|
||||||
Revision ID: e1ec1dcb50b6
|
Revision ID: e1ec1dcb50b6
|
||||||
Revises: 211ab850ef3d
|
Revises: 211ab850ef3d
|
||||||
Create Date: 2025-11-04 21:45:52.642322
|
Create Date: 2025-11-04 21:45:52.642322
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
from typing import Sequence, Union
|
import os
|
||||||
|
from typing import Sequence, Union
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = 'e1ec1dcb50b6'
|
# revision identifiers, used by Alembic.
|
||||||
down_revision: Union[str, None] = '211ab850ef3d'
|
revision: str = "e1ec1dcb50b6"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
down_revision: Union[str, None] = "a1b2c3d4e5f6"
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
def _get_column(inspector, table, name, schema=None):
|
|
||||||
for col in inspector.get_columns(table, schema=schema):
|
def _get_column(inspector, table, name, schema=None):
|
||||||
if col["name"] == name:
|
for col in inspector.get_columns(table, schema=schema):
|
||||||
return col
|
if col["name"] == name:
|
||||||
return None
|
return col
|
||||||
|
return None
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
conn = op.get_bind()
|
def upgrade() -> None:
|
||||||
insp = sa.inspect(conn)
|
conn = op.get_bind()
|
||||||
|
insp = sa.inspect(conn)
|
||||||
last_accessed_column = _get_column(insp, "data", "last_accessed")
|
|
||||||
if not last_accessed_column:
|
last_accessed_column = _get_column(insp, "data", "last_accessed")
|
||||||
# Always create the column for schema consistency
|
if not last_accessed_column:
|
||||||
op.add_column('data',
|
# Always create the column for schema consistency
|
||||||
sa.Column('last_accessed', sa.DateTime(timezone=True), nullable=True)
|
op.add_column("data", sa.Column("last_accessed", sa.DateTime(timezone=True), nullable=True))
|
||||||
)
|
|
||||||
|
# Only initialize existing records if feature is enabled
|
||||||
# Only initialize existing records if feature is enabled
|
enable_last_accessed = os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true"
|
||||||
enable_last_accessed = os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true"
|
if enable_last_accessed:
|
||||||
if enable_last_accessed:
|
op.execute("UPDATE data SET last_accessed = CURRENT_TIMESTAMP")
|
||||||
op.execute("UPDATE data SET last_accessed = CURRENT_TIMESTAMP")
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
def downgrade() -> None:
|
conn = op.get_bind()
|
||||||
conn = op.get_bind()
|
insp = sa.inspect(conn)
|
||||||
insp = sa.inspect(conn)
|
|
||||||
|
last_accessed_column = _get_column(insp, "data", "last_accessed")
|
||||||
last_accessed_column = _get_column(insp, "data", "last_accessed")
|
if last_accessed_column:
|
||||||
if last_accessed_column:
|
op.drop_column("data", "last_accessed")
|
||||||
op.drop_column('data', 'last_accessed')
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.engine.models.EntityType import EntityType
|
from cognee.modules.engine.models.EntityType import EntityType
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class Entity(DataPoint):
|
class Entity(DataPoint):
|
||||||
name: str
|
name: str
|
||||||
is_a: Optional[EntityType] = None
|
is_a: Optional[EntityType] = None
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
|
|
||||||
def get_entity_nodes_from_triplets(triplets):
|
def get_entity_nodes_from_triplets(triplets):
|
||||||
entity_nodes = []
|
entity_nodes = []
|
||||||
seen_ids = set()
|
seen_ids = set()
|
||||||
for triplet in triplets:
|
for triplet in triplets:
|
||||||
if hasattr(triplet, 'node1') and triplet.node1 and triplet.node1.id not in seen_ids:
|
if hasattr(triplet, "node1") and triplet.node1 and triplet.node1.id not in seen_ids:
|
||||||
entity_nodes.append({"id": str(triplet.node1.id)})
|
entity_nodes.append({"id": str(triplet.node1.id)})
|
||||||
seen_ids.add(triplet.node1.id)
|
seen_ids.add(triplet.node1.id)
|
||||||
if hasattr(triplet, 'node2') and triplet.node2 and triplet.node2.id not in seen_ids:
|
if hasattr(triplet, "node2") and triplet.node2 and triplet.node2.id not in seen_ids:
|
||||||
entity_nodes.append({"id": str(triplet.node2.id)})
|
entity_nodes.append({"id": str(triplet.node2.id)})
|
||||||
seen_ids.add(triplet.node2.id)
|
seen_ids.add(triplet.node2.id)
|
||||||
|
|
||||||
return entity_nodes
|
return entity_nodes
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
logger = get_logger("ChunksRetriever")
|
logger = get_logger("ChunksRetriever")
|
||||||
|
|
||||||
|
|
@ -28,7 +28,7 @@ class ChunksRetriever(BaseRetriever):
|
||||||
):
|
):
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> Any:
|
||||||
"""
|
"""
|
||||||
Retrieves document chunks context based on the query.
|
Retrieves document chunks context based on the query.
|
||||||
Searches for document chunks relevant to the specified query using a vector engine.
|
Searches for document chunks relevant to the specified query using a vector engine.
|
||||||
|
|
|
||||||
|
|
@ -148,8 +148,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
# context = await self.resolve_edges_to_text(triplets)
|
# context = await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
entity_nodes = get_entity_nodes_from_triplets(triplets)
|
entity_nodes = get_entity_nodes_from_triplets(triplets)
|
||||||
|
|
||||||
await update_node_access_timestamps(entity_nodes)
|
await update_node_access_timestamps(entity_nodes)
|
||||||
return triplets
|
return triplets
|
||||||
|
|
||||||
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
|
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
|
||||||
|
|
|
||||||
|
|
@ -55,9 +55,9 @@ class SummariesRetriever(BaseRetriever):
|
||||||
"TextSummary_text", query, limit=self.top_k
|
"TextSummary_text", query, limit=self.top_k
|
||||||
)
|
)
|
||||||
logger.info(f"Found {len(summaries_results)} summaries from vector search")
|
logger.info(f"Found {len(summaries_results)} summaries from vector search")
|
||||||
|
|
||||||
await update_node_access_timestamps(summaries_results)
|
await update_node_access_timestamps(summaries_results)
|
||||||
|
|
||||||
except CollectionNotFoundError as error:
|
except CollectionNotFoundError as error:
|
||||||
logger.error("TextSummary_text collection not found in vector database")
|
logger.error("TextSummary_text collection not found in vector database")
|
||||||
raise NoDataError("No data found in the system, please add data first.") from error
|
raise NoDataError("No data found in the system, please add data first.") from error
|
||||||
|
|
|
||||||
|
|
@ -1,82 +1,88 @@
|
||||||
"""Utilities for tracking data access in retrievers."""
|
"""Utilities for tracking data access in retrievers."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import List, Any
|
from typing import List, Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
import os
|
import os
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.modules.data.models import Data
|
from cognee.modules.data.models import Data
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def update_node_access_timestamps(items: List[Any]):
|
async def update_node_access_timestamps(items: List[Any]):
|
||||||
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true":
|
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true":
|
||||||
return
|
return
|
||||||
|
|
||||||
if not items:
|
if not items:
|
||||||
return
|
return
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
timestamp_dt = datetime.now(timezone.utc)
|
timestamp_dt = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# Extract node IDs
|
# Extract node IDs
|
||||||
node_ids = []
|
node_ids = []
|
||||||
for item in items:
|
for item in items:
|
||||||
item_id = item.payload.get("id") if hasattr(item, 'payload') else item.get("id")
|
item_id = item.payload.get("id") if hasattr(item, "payload") else item.get("id")
|
||||||
if item_id:
|
if item_id:
|
||||||
node_ids.append(str(item_id))
|
node_ids.append(str(item_id))
|
||||||
|
|
||||||
if not node_ids:
|
if not node_ids:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Focus on document-level tracking via projection
|
# Focus on document-level tracking via projection
|
||||||
try:
|
try:
|
||||||
doc_ids = await _find_origin_documents_via_projection(graph_engine, node_ids)
|
doc_ids = await _find_origin_documents_via_projection(graph_engine, node_ids)
|
||||||
if doc_ids:
|
if doc_ids:
|
||||||
await _update_sql_records(doc_ids, timestamp_dt)
|
await _update_sql_records(doc_ids, timestamp_dt)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to update SQL timestamps: {e}")
|
logger.error(f"Failed to update SQL timestamps: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def _find_origin_documents_via_projection(graph_engine, node_ids):
|
async def _find_origin_documents_via_projection(graph_engine, node_ids):
|
||||||
"""Find origin documents using graph projection instead of DB queries"""
|
"""Find origin documents using graph projection instead of DB queries"""
|
||||||
# Project the entire graph with necessary properties
|
# Project the entire graph with necessary properties
|
||||||
memory_fragment = CogneeGraph()
|
memory_fragment = CogneeGraph()
|
||||||
await memory_fragment.project_graph_from_db(
|
await memory_fragment.project_graph_from_db(
|
||||||
graph_engine,
|
graph_engine,
|
||||||
node_properties_to_project=["id", "type"],
|
node_properties_to_project=["id", "type"],
|
||||||
edge_properties_to_project=["relationship_name"]
|
edge_properties_to_project=["relationship_name"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find origin documents by traversing the in-memory graph
|
# Find origin documents by traversing the in-memory graph
|
||||||
doc_ids = set()
|
doc_ids = set()
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
node = memory_fragment.get_node(node_id)
|
node = memory_fragment.get_node(node_id)
|
||||||
if node and node.get_attribute("type") == "DocumentChunk":
|
if node and node.get_attribute("type") == "DocumentChunk":
|
||||||
# Traverse edges to find connected documents
|
# Traverse edges to find connected documents
|
||||||
for edge in node.get_skeleton_edges():
|
for edge in node.get_skeleton_edges():
|
||||||
# Get the neighbor node
|
# Get the neighbor node
|
||||||
neighbor = edge.get_destination_node() if edge.get_source_node().id == node_id else edge.get_source_node()
|
neighbor = (
|
||||||
if neighbor and neighbor.get_attribute("type") in ["TextDocument", "Document"]:
|
edge.get_destination_node()
|
||||||
doc_ids.add(neighbor.id)
|
if edge.get_source_node().id == node_id
|
||||||
|
else edge.get_source_node()
|
||||||
return list(doc_ids)
|
)
|
||||||
|
if neighbor and neighbor.get_attribute("type") in ["TextDocument", "Document"]:
|
||||||
|
doc_ids.add(neighbor.id)
|
||||||
async def _update_sql_records(doc_ids, timestamp_dt):
|
|
||||||
"""Update SQL Data table (same for all providers)"""
|
return list(doc_ids)
|
||||||
db_engine = get_relational_engine()
|
|
||||||
async with db_engine.get_async_session() as session:
|
|
||||||
stmt = update(Data).where(
|
async def _update_sql_records(doc_ids, timestamp_dt):
|
||||||
Data.id.in_([UUID(doc_id) for doc_id in doc_ids])
|
"""Update SQL Data table (same for all providers)"""
|
||||||
).values(last_accessed=timestamp_dt)
|
db_engine = get_relational_engine()
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
await session.execute(stmt)
|
stmt = (
|
||||||
|
update(Data)
|
||||||
|
.where(Data.id.in_([UUID(doc_id) for doc_id in doc_ids]))
|
||||||
|
.values(last_accessed=timestamp_dt)
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
|
||||||
|
|
@ -1,187 +1,172 @@
|
||||||
"""
|
"""
|
||||||
Task for automatically deleting unused data from the memify pipeline.
|
Task for automatically deleting unused data from the memify pipeline.
|
||||||
|
|
||||||
This task identifies and removes entire documents that haven't
|
This task identifies and removes entire documents that haven't
|
||||||
been accessed by retrievers for a specified period, helping maintain system
|
been accessed by retrievers for a specified period, helping maintain system
|
||||||
efficiency and storage optimization through whole-document removal.
|
efficiency and storage optimization through whole-document removal.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
import os
|
import os
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.modules.data.models import Data, DatasetData
|
from cognee.modules.data.models import Data, DatasetData
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from sqlalchemy import select, or_
|
from sqlalchemy import select, or_
|
||||||
import cognee
|
import cognee
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_unused_data(
|
async def cleanup_unused_data(
|
||||||
minutes_threshold: Optional[int],
|
minutes_threshold: Optional[int], dry_run: bool = True, user_id: Optional[UUID] = None
|
||||||
dry_run: bool = True,
|
) -> Dict[str, Any]:
|
||||||
user_id: Optional[UUID] = None
|
"""
|
||||||
) -> Dict[str, Any]:
|
Identify and remove unused data from the memify pipeline.
|
||||||
"""
|
|
||||||
Identify and remove unused data from the memify pipeline.
|
Parameters
|
||||||
|
----------
|
||||||
Parameters
|
minutes_threshold : int
|
||||||
----------
|
Minutes since last access to consider data unused
|
||||||
minutes_threshold : int
|
dry_run : bool
|
||||||
Minutes since last access to consider data unused
|
If True, only report what would be deleted without actually deleting (default: True)
|
||||||
dry_run : bool
|
user_id : UUID, optional
|
||||||
If True, only report what would be deleted without actually deleting (default: True)
|
Limit cleanup to specific user's data (default: None)
|
||||||
user_id : UUID, optional
|
|
||||||
Limit cleanup to specific user's data (default: None)
|
Returns
|
||||||
|
-------
|
||||||
Returns
|
Dict[str, Any]
|
||||||
-------
|
Cleanup results with status, counts, and timestamp
|
||||||
Dict[str, Any]
|
"""
|
||||||
Cleanup results with status, counts, and timestamp
|
# Check 1: Environment variable must be enabled
|
||||||
"""
|
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true":
|
||||||
# Check 1: Environment variable must be enabled
|
logger.warning("Cleanup skipped: ENABLE_LAST_ACCESSED is not enabled.")
|
||||||
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true":
|
return {
|
||||||
logger.warning(
|
"status": "skipped",
|
||||||
"Cleanup skipped: ENABLE_LAST_ACCESSED is not enabled."
|
"reason": "ENABLE_LAST_ACCESSED not enabled",
|
||||||
)
|
"unused_count": 0,
|
||||||
return {
|
"deleted_count": {},
|
||||||
"status": "skipped",
|
"cleanup_date": datetime.now(timezone.utc).isoformat(),
|
||||||
"reason": "ENABLE_LAST_ACCESSED not enabled",
|
}
|
||||||
"unused_count": 0,
|
|
||||||
"deleted_count": {},
|
# Check 2: Verify tracking has actually been running
|
||||||
"cleanup_date": datetime.now(timezone.utc).isoformat()
|
db_engine = get_relational_engine()
|
||||||
}
|
async with db_engine.get_async_session() as session:
|
||||||
|
# Count records with non-NULL last_accessed
|
||||||
# Check 2: Verify tracking has actually been running
|
tracked_count = await session.execute(
|
||||||
db_engine = get_relational_engine()
|
select(sa.func.count(Data.id)).where(Data.last_accessed.isnot(None))
|
||||||
async with db_engine.get_async_session() as session:
|
)
|
||||||
# Count records with non-NULL last_accessed
|
tracked_records = tracked_count.scalar()
|
||||||
tracked_count = await session.execute(
|
|
||||||
select(sa.func.count(Data.id)).where(Data.last_accessed.isnot(None))
|
if tracked_records == 0:
|
||||||
)
|
logger.warning(
|
||||||
tracked_records = tracked_count.scalar()
|
"Cleanup skipped: No records have been tracked yet. "
|
||||||
|
"ENABLE_LAST_ACCESSED may have been recently enabled. "
|
||||||
if tracked_records == 0:
|
"Wait for retrievers to update timestamps before running cleanup."
|
||||||
logger.warning(
|
)
|
||||||
"Cleanup skipped: No records have been tracked yet. "
|
return {
|
||||||
"ENABLE_LAST_ACCESSED may have been recently enabled. "
|
"status": "skipped",
|
||||||
"Wait for retrievers to update timestamps before running cleanup."
|
"reason": "No tracked records found - tracking may be newly enabled",
|
||||||
)
|
"unused_count": 0,
|
||||||
return {
|
"deleted_count": {},
|
||||||
"status": "skipped",
|
"cleanup_date": datetime.now(timezone.utc).isoformat(),
|
||||||
"reason": "No tracked records found - tracking may be newly enabled",
|
}
|
||||||
"unused_count": 0,
|
|
||||||
"deleted_count": {},
|
logger.info(
|
||||||
"cleanup_date": datetime.now(timezone.utc).isoformat()
|
"Starting cleanup task",
|
||||||
}
|
minutes_threshold=minutes_threshold,
|
||||||
|
dry_run=dry_run,
|
||||||
logger.info(
|
user_id=str(user_id) if user_id else None,
|
||||||
"Starting cleanup task",
|
)
|
||||||
minutes_threshold=minutes_threshold,
|
|
||||||
dry_run=dry_run,
|
# Calculate cutoff timestamp
|
||||||
user_id=str(user_id) if user_id else None
|
cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold)
|
||||||
)
|
|
||||||
|
# Document-level approach (recommended)
|
||||||
# Calculate cutoff timestamp
|
return await _cleanup_via_sql(cutoff_date, dry_run, user_id)
|
||||||
cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold)
|
|
||||||
|
|
||||||
# Document-level approach (recommended)
|
async def _cleanup_via_sql(
|
||||||
return await _cleanup_via_sql(cutoff_date, dry_run, user_id)
|
cutoff_date: datetime, dry_run: bool, user_id: Optional[UUID] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
async def _cleanup_via_sql(
|
SQL-based cleanup: Query Data table for unused documents and use cognee.delete().
|
||||||
cutoff_date: datetime,
|
|
||||||
dry_run: bool,
|
Parameters
|
||||||
user_id: Optional[UUID] = None
|
----------
|
||||||
) -> Dict[str, Any]:
|
cutoff_date : datetime
|
||||||
"""
|
Cutoff date for last_accessed filtering
|
||||||
SQL-based cleanup: Query Data table for unused documents and use cognee.delete().
|
dry_run : bool
|
||||||
|
If True, only report what would be deleted
|
||||||
Parameters
|
user_id : UUID, optional
|
||||||
----------
|
Filter by user ID if provided
|
||||||
cutoff_date : datetime
|
|
||||||
Cutoff date for last_accessed filtering
|
Returns
|
||||||
dry_run : bool
|
-------
|
||||||
If True, only report what would be deleted
|
Dict[str, Any]
|
||||||
user_id : UUID, optional
|
Cleanup results
|
||||||
Filter by user ID if provided
|
"""
|
||||||
|
db_engine = get_relational_engine()
|
||||||
Returns
|
|
||||||
-------
|
async with db_engine.get_async_session() as session:
|
||||||
Dict[str, Any]
|
# Query for Data records with old last_accessed timestamps
|
||||||
Cleanup results
|
query = (
|
||||||
"""
|
select(Data, DatasetData)
|
||||||
db_engine = get_relational_engine()
|
.join(DatasetData, Data.id == DatasetData.data_id)
|
||||||
|
.where(or_(Data.last_accessed < cutoff_date, Data.last_accessed.is_(None)))
|
||||||
async with db_engine.get_async_session() as session:
|
)
|
||||||
# Query for Data records with old last_accessed timestamps
|
|
||||||
query = select(Data, DatasetData).join(
|
if user_id:
|
||||||
DatasetData, Data.id == DatasetData.data_id
|
from cognee.modules.data.models import Dataset
|
||||||
).where(
|
|
||||||
or_(
|
query = query.join(Dataset, DatasetData.dataset_id == Dataset.id).where(
|
||||||
Data.last_accessed < cutoff_date,
|
Dataset.owner_id == user_id
|
||||||
Data.last_accessed.is_(None)
|
)
|
||||||
)
|
|
||||||
)
|
result = await session.execute(query)
|
||||||
|
unused_data = result.all()
|
||||||
if user_id:
|
|
||||||
from cognee.modules.data.models import Dataset
|
logger.info(f"Found {len(unused_data)} unused documents in SQL")
|
||||||
query = query.join(Dataset, DatasetData.dataset_id == Dataset.id).where(
|
|
||||||
Dataset.owner_id == user_id
|
if dry_run:
|
||||||
)
|
return {
|
||||||
|
"status": "dry_run",
|
||||||
result = await session.execute(query)
|
"unused_count": len(unused_data),
|
||||||
unused_data = result.all()
|
"deleted_count": {"data_items": 0, "documents": 0},
|
||||||
|
"cleanup_date": datetime.now(timezone.utc).isoformat(),
|
||||||
logger.info(f"Found {len(unused_data)} unused documents in SQL")
|
"preview": {"documents": len(unused_data)},
|
||||||
|
}
|
||||||
if dry_run:
|
|
||||||
return {
|
# Delete each document using cognee.delete()
|
||||||
"status": "dry_run",
|
deleted_count = 0
|
||||||
"unused_count": len(unused_data),
|
from cognee.modules.users.methods import get_default_user
|
||||||
"deleted_count": {
|
|
||||||
"data_items": 0,
|
user = await get_default_user() if user_id is None else None
|
||||||
"documents": 0
|
|
||||||
},
|
for data, dataset_data in unused_data:
|
||||||
"cleanup_date": datetime.now(timezone.utc).isoformat(),
|
try:
|
||||||
"preview": {
|
await cognee.delete(
|
||||||
"documents": len(unused_data)
|
data_id=data.id,
|
||||||
}
|
dataset_id=dataset_data.dataset_id,
|
||||||
}
|
mode="hard", # Use hard mode to also remove orphaned entities
|
||||||
|
user=user,
|
||||||
# Delete each document using cognee.delete()
|
)
|
||||||
deleted_count = 0
|
deleted_count += 1
|
||||||
from cognee.modules.users.methods import get_default_user
|
logger.info(f"Deleted document {data.id} from dataset {dataset_data.dataset_id}")
|
||||||
user = await get_default_user() if user_id is None else None
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete document {data.id}: {e}")
|
||||||
for data, dataset_data in unused_data:
|
|
||||||
try:
|
logger.info("Cleanup completed", deleted_count=deleted_count)
|
||||||
await cognee.delete(
|
|
||||||
data_id=data.id,
|
return {
|
||||||
dataset_id=dataset_data.dataset_id,
|
"status": "completed",
|
||||||
mode="hard", # Use hard mode to also remove orphaned entities
|
"unused_count": len(unused_data),
|
||||||
user=user
|
"deleted_count": {"data_items": deleted_count, "documents": deleted_count},
|
||||||
)
|
"cleanup_date": datetime.now(timezone.utc).isoformat(),
|
||||||
deleted_count += 1
|
|
||||||
logger.info(f"Deleted document {data.id} from dataset {dataset_data.dataset_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to delete document {data.id}: {e}")
|
|
||||||
|
|
||||||
logger.info("Cleanup completed", deleted_count=deleted_count)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "completed",
|
|
||||||
"unused_count": len(unused_data),
|
|
||||||
"deleted_count": {
|
|
||||||
"data_items": deleted_count,
|
|
||||||
"documents": deleted_count
|
|
||||||
},
|
|
||||||
"cleanup_date": datetime.now(timezone.utc).isoformat()
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.chunking.models import DocumentChunk
|
from cognee.modules.chunking.models import DocumentChunk
|
||||||
|
|
|
||||||
|
|
@ -1,172 +1,165 @@
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import cognee
|
import cognee
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlalchemy import select, update
|
from sqlalchemy import select, update
|
||||||
from cognee.modules.data.models import Data, DatasetData
|
from cognee.modules.data.models import Data, DatasetData
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
async def test_textdocument_cleanup_with_sql():
|
async def test_textdocument_cleanup_with_sql():
|
||||||
"""
|
"""
|
||||||
End-to-end test for TextDocument cleanup based on last_accessed timestamps.
|
End-to-end test for TextDocument cleanup based on last_accessed timestamps.
|
||||||
"""
|
"""
|
||||||
# Enable last accessed tracking BEFORE any cognee operations
|
# Enable last accessed tracking BEFORE any cognee operations
|
||||||
os.environ["ENABLE_LAST_ACCESSED"] = "true"
|
os.environ["ENABLE_LAST_ACCESSED"] = "true"
|
||||||
|
|
||||||
# Setup test directories
|
# Setup test directories
|
||||||
data_directory_path = str(
|
data_directory_path = str(
|
||||||
pathlib.Path(
|
pathlib.Path(
|
||||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_cleanup")
|
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_cleanup")
|
||||||
).resolve()
|
).resolve()
|
||||||
)
|
)
|
||||||
cognee_directory_path = str(
|
cognee_directory_path = str(
|
||||||
pathlib.Path(
|
pathlib.Path(
|
||||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_cleanup")
|
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_cleanup")
|
||||||
).resolve()
|
).resolve()
|
||||||
)
|
)
|
||||||
|
|
||||||
cognee.config.data_root_directory(data_directory_path)
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
cognee.config.system_root_directory(cognee_directory_path)
|
cognee.config.system_root_directory(cognee_directory_path)
|
||||||
|
|
||||||
# Initialize database
|
# Initialize database
|
||||||
from cognee.modules.engine.operations.setup import setup
|
from cognee.modules.engine.operations.setup import setup
|
||||||
|
|
||||||
# Clean slate
|
# Clean slate
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
logger.info("🧪 Testing TextDocument cleanup based on last_accessed")
|
logger.info("🧪 Testing TextDocument cleanup based on last_accessed")
|
||||||
|
|
||||||
# Step 1: Add and cognify a test document
|
# Step 1: Add and cognify a test document
|
||||||
dataset_name = "test_cleanup_dataset"
|
dataset_name = "test_cleanup_dataset"
|
||||||
test_text = """
|
test_text = """
|
||||||
Machine learning is a subset of artificial intelligence that enables systems to learn
|
Machine learning is a subset of artificial intelligence that enables systems to learn
|
||||||
and improve from experience without being explicitly programmed. Deep learning uses
|
and improve from experience without being explicitly programmed. Deep learning uses
|
||||||
neural networks with multiple layers to process data.
|
neural networks with multiple layers to process data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await setup()
|
await setup()
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
await cognee.add([test_text], dataset_name=dataset_name, user=user)
|
await cognee.add([test_text], dataset_name=dataset_name, user=user)
|
||||||
|
|
||||||
cognify_result = await cognee.cognify([dataset_name], user=user)
|
cognify_result = await cognee.cognify([dataset_name], user=user)
|
||||||
|
|
||||||
# Extract dataset_id from cognify result
|
# Extract dataset_id from cognify result
|
||||||
dataset_id = None
|
dataset_id = None
|
||||||
for ds_id, pipeline_result in cognify_result.items():
|
for ds_id, pipeline_result in cognify_result.items():
|
||||||
dataset_id = ds_id
|
dataset_id = ds_id
|
||||||
break
|
break
|
||||||
|
|
||||||
assert dataset_id is not None, "Failed to get dataset_id from cognify result"
|
assert dataset_id is not None, "Failed to get dataset_id from cognify result"
|
||||||
logger.info(f"✅ Document added and cognified. Dataset ID: {dataset_id}")
|
logger.info(f"✅ Document added and cognified. Dataset ID: {dataset_id}")
|
||||||
|
|
||||||
# Step 2: Perform search to trigger last_accessed update
|
# Step 2: Perform search to trigger last_accessed update
|
||||||
logger.info("Triggering search to update last_accessed...")
|
logger.info("Triggering search to update last_accessed...")
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
query_type=SearchType.CHUNKS,
|
query_type=SearchType.CHUNKS,
|
||||||
query_text="machine learning",
|
query_text="machine learning",
|
||||||
datasets=[dataset_name],
|
datasets=[dataset_name],
|
||||||
user=user
|
user=user,
|
||||||
)
|
)
|
||||||
logger.info(f"✅ Search completed, found {len(search_results)} results")
|
logger.info(f"✅ Search completed, found {len(search_results)} results")
|
||||||
assert len(search_results) > 0, "Search should return results"
|
assert len(search_results) > 0, "Search should return results"
|
||||||
|
|
||||||
# Step 3: Verify last_accessed was set and get data_id
|
# Step 3: Verify last_accessed was set and get data_id
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Data, DatasetData)
|
select(Data, DatasetData)
|
||||||
.join(DatasetData, Data.id == DatasetData.data_id)
|
.join(DatasetData, Data.id == DatasetData.data_id)
|
||||||
.where(DatasetData.dataset_id == dataset_id)
|
.where(DatasetData.dataset_id == dataset_id)
|
||||||
)
|
)
|
||||||
data_records = result.all()
|
data_records = result.all()
|
||||||
assert len(data_records) > 0, "No Data records found for the dataset"
|
assert len(data_records) > 0, "No Data records found for the dataset"
|
||||||
data_record = data_records[0][0]
|
data_record = data_records[0][0]
|
||||||
data_id = data_record.id
|
data_id = data_record.id
|
||||||
|
|
||||||
# Verify last_accessed is set
|
# Verify last_accessed is set
|
||||||
assert data_record.last_accessed is not None, (
|
assert data_record.last_accessed is not None, (
|
||||||
"last_accessed should be set after search operation"
|
"last_accessed should be set after search operation"
|
||||||
)
|
)
|
||||||
|
|
||||||
original_last_accessed = data_record.last_accessed
|
original_last_accessed = data_record.last_accessed
|
||||||
logger.info(f"✅ last_accessed verified: {original_last_accessed}")
|
logger.info(f"✅ last_accessed verified: {original_last_accessed}")
|
||||||
|
|
||||||
# Step 4: Manually age the timestamp
|
# Step 4: Manually age the timestamp
|
||||||
minutes_threshold = 30
|
minutes_threshold = 30
|
||||||
aged_timestamp = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold + 10)
|
aged_timestamp = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold + 10)
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
stmt = update(Data).where(Data.id == data_id).values(last_accessed=aged_timestamp)
|
stmt = update(Data).where(Data.id == data_id).values(last_accessed=aged_timestamp)
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# Verify timestamp was updated
|
# Verify timestamp was updated
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
result = await session.execute(select(Data).where(Data.id == data_id))
|
result = await session.execute(select(Data).where(Data.id == data_id))
|
||||||
updated_data = result.scalar_one_or_none()
|
updated_data = result.scalar_one_or_none()
|
||||||
assert updated_data is not None, "Data record should exist"
|
assert updated_data is not None, "Data record should exist"
|
||||||
retrieved_timestamp = updated_data.last_accessed
|
retrieved_timestamp = updated_data.last_accessed
|
||||||
if retrieved_timestamp.tzinfo is None:
|
if retrieved_timestamp.tzinfo is None:
|
||||||
retrieved_timestamp = retrieved_timestamp.replace(tzinfo=timezone.utc)
|
retrieved_timestamp = retrieved_timestamp.replace(tzinfo=timezone.utc)
|
||||||
assert retrieved_timestamp == aged_timestamp, (
|
assert retrieved_timestamp == aged_timestamp, "Timestamp should be updated to aged value"
|
||||||
f"Timestamp should be updated to aged value"
|
|
||||||
)
|
# Step 5: Test cleanup (document-level is now the default)
|
||||||
|
from cognee.tasks.cleanup.cleanup_unused_data import cleanup_unused_data
|
||||||
# Step 5: Test cleanup (document-level is now the default)
|
|
||||||
from cognee.tasks.cleanup.cleanup_unused_data import cleanup_unused_data
|
# First do a dry run
|
||||||
|
logger.info("Testing dry run...")
|
||||||
# First do a dry run
|
dry_run_result = await cleanup_unused_data(minutes_threshold=10, dry_run=True, user_id=user.id)
|
||||||
logger.info("Testing dry run...")
|
|
||||||
dry_run_result = await cleanup_unused_data(
|
# Debug: Print the actual result
|
||||||
minutes_threshold=10,
|
logger.info(f"Dry run result: {dry_run_result}")
|
||||||
dry_run=True,
|
|
||||||
user_id=user.id
|
assert dry_run_result["status"] == "dry_run", (
|
||||||
)
|
f"Status should be 'dry_run', got: {dry_run_result['status']}"
|
||||||
|
)
|
||||||
# Debug: Print the actual result
|
assert dry_run_result["unused_count"] > 0, "Should find at least one unused document"
|
||||||
logger.info(f"Dry run result: {dry_run_result}")
|
logger.info(f"✅ Dry run found {dry_run_result['unused_count']} unused documents")
|
||||||
|
|
||||||
assert dry_run_result['status'] == 'dry_run', f"Status should be 'dry_run', got: {dry_run_result['status']}"
|
# Now run actual cleanup
|
||||||
assert dry_run_result['unused_count'] > 0, (
|
logger.info("Executing cleanup...")
|
||||||
"Should find at least one unused document"
|
cleanup_result = await cleanup_unused_data(minutes_threshold=30, dry_run=False, user_id=user.id)
|
||||||
)
|
|
||||||
logger.info(f"✅ Dry run found {dry_run_result['unused_count']} unused documents")
|
assert cleanup_result["status"] == "completed", "Cleanup should complete successfully"
|
||||||
|
assert cleanup_result["deleted_count"]["documents"] > 0, (
|
||||||
# Now run actual cleanup
|
"At least one document should be deleted"
|
||||||
logger.info("Executing cleanup...")
|
)
|
||||||
cleanup_result = await cleanup_unused_data(
|
logger.info(
|
||||||
minutes_threshold=30,
|
f"✅ Cleanup completed. Deleted {cleanup_result['deleted_count']['documents']} documents"
|
||||||
dry_run=False,
|
)
|
||||||
user_id=user.id
|
|
||||||
)
|
# Step 6: Verify deletion
|
||||||
|
async with db_engine.get_async_session() as session:
|
||||||
assert cleanup_result["status"] == "completed", "Cleanup should complete successfully"
|
deleted_data = (
|
||||||
assert cleanup_result["deleted_count"]["documents"] > 0, (
|
await session.execute(select(Data).where(Data.id == data_id))
|
||||||
"At least one document should be deleted"
|
).scalar_one_or_none()
|
||||||
)
|
assert deleted_data is None, "Data record should be deleted"
|
||||||
logger.info(f"✅ Cleanup completed. Deleted {cleanup_result['deleted_count']['documents']} documents")
|
logger.info("✅ Confirmed: Data record was deleted")
|
||||||
|
|
||||||
# Step 6: Verify deletion
|
logger.info("🎉 All cleanup tests passed!")
|
||||||
async with db_engine.get_async_session() as session:
|
return True
|
||||||
deleted_data = (
|
|
||||||
await session.execute(select(Data).where(Data.id == data_id))
|
|
||||||
).scalar_one_or_none()
|
if __name__ == "__main__":
|
||||||
assert deleted_data is None, "Data record should be deleted"
|
import asyncio
|
||||||
logger.info("✅ Confirmed: Data record was deleted")
|
|
||||||
|
success = asyncio.run(test_textdocument_cleanup_with_sql())
|
||||||
logger.info("🎉 All cleanup tests passed!")
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
success = asyncio.run(test_textdocument_cleanup_with_sql())
|
|
||||||
exit(0 if success else 1)
|
exit(0 if success else 1)
|
||||||
|
|
|
||||||
|
|
@ -1,227 +1,212 @@
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import cognee
|
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from cognee.modules.users.exceptions import PermissionDeniedError
|
import pytest
|
||||||
from cognee.shared.logging_utils import get_logger
|
import pytest_asyncio
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
|
from cognee.modules.engine.operations.setup import setup as engine_setup
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
from cognee.modules.users.methods import get_default_user, create_user
|
from cognee.modules.users.exceptions import PermissionDeniedError
|
||||||
|
from cognee.modules.users.methods import create_user, get_user
|
||||||
from cognee.modules.users.permissions.methods import authorized_give_permission_on_datasets
|
from cognee.modules.users.permissions.methods import authorized_give_permission_on_datasets
|
||||||
from cognee.modules.data.methods import get_dataset_data
|
from cognee.modules.users.roles.methods import add_user_to_role, create_role
|
||||||
|
from cognee.modules.users.tenants.methods import (
|
||||||
|
add_user_to_tenant,
|
||||||
|
create_tenant,
|
||||||
|
select_tenant,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger()
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
def _extract_dataset_id_from_cognify(cognify_result: dict):
|
||||||
# Enable permissions feature
|
"""Extract dataset_id from cognify output dictionary."""
|
||||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "True"
|
for dataset_id, _pipeline_result in cognify_result.items():
|
||||||
|
return dataset_id
|
||||||
|
return None
|
||||||
|
|
||||||
# Clean up test directories before starting
|
|
||||||
data_directory_path = str(
|
async def _reset_engines_and_prune() -> None:
|
||||||
pathlib.Path(
|
"""Reset db engine caches and prune data/system."""
|
||||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_permissions")
|
try:
|
||||||
).resolve()
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
)
|
|
||||||
cognee_directory_path = str(
|
vector_engine = get_vector_engine()
|
||||||
pathlib.Path(
|
if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"):
|
||||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_permissions")
|
await vector_engine.engine.dispose(close=True)
|
||||||
).resolve()
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||||
|
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
||||||
|
create_relational_engine,
|
||||||
)
|
)
|
||||||
|
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
||||||
|
|
||||||
cognee.config.data_root_directory(data_directory_path)
|
create_graph_engine.cache_clear()
|
||||||
cognee.config.system_root_directory(cognee_directory_path)
|
create_vector_engine.cache_clear()
|
||||||
|
create_relational_engine.cache_clear()
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
explanation_file_path_nlp = os.path.join(
|
|
||||||
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add document for default user
|
@pytest.fixture(scope="module")
|
||||||
await cognee.add([explanation_file_path_nlp], dataset_name="NLP")
|
def event_loop():
|
||||||
default_user = await get_default_user()
|
"""Single event loop for this module (avoids cross-loop futures)."""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
explanation_file_path_quantum = os.path.join(
|
|
||||||
pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add document for test user
|
|
||||||
test_user = await create_user("user@example.com", "example")
|
|
||||||
await cognee.add([explanation_file_path_quantum], dataset_name="QUANTUM", user=test_user)
|
|
||||||
|
|
||||||
nlp_cognify_result = await cognee.cognify(["NLP"], user=default_user)
|
|
||||||
quantum_cognify_result = await cognee.cognify(["QUANTUM"], user=test_user)
|
|
||||||
|
|
||||||
# Extract dataset_ids from cognify results
|
|
||||||
def extract_dataset_id_from_cognify(cognify_result):
|
|
||||||
"""Extract dataset_id from cognify output dictionary"""
|
|
||||||
for dataset_id, pipeline_result in cognify_result.items():
|
|
||||||
return dataset_id # Return the first (and likely only) dataset_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Get dataset IDs from cognify results
|
|
||||||
default_user_dataset_id = extract_dataset_id_from_cognify(nlp_cognify_result)
|
|
||||||
print("User is", default_user_dataset_id)
|
|
||||||
test_user_dataset_id = extract_dataset_id_from_cognify(quantum_cognify_result)
|
|
||||||
|
|
||||||
# Check if default_user can only see information from the NLP dataset
|
|
||||||
search_results = await cognee.search(
|
|
||||||
query_type=SearchType.GRAPH_COMPLETION,
|
|
||||||
query_text="What is in the document?",
|
|
||||||
user=default_user,
|
|
||||||
)
|
|
||||||
assert len(search_results) == 1, "The search results list lenght is not one."
|
|
||||||
print("\n\nExtracted sentences are:\n")
|
|
||||||
for result in search_results:
|
|
||||||
print(f"{result}\n")
|
|
||||||
assert search_results[0]["dataset_name"] == "NLP", (
|
|
||||||
f"Dict must contain dataset name 'NLP': {search_results[0]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if test_user can only see information from the QUANTUM dataset
|
|
||||||
search_results = await cognee.search(
|
|
||||||
query_type=SearchType.GRAPH_COMPLETION,
|
|
||||||
query_text="What is in the document?",
|
|
||||||
user=test_user,
|
|
||||||
)
|
|
||||||
assert len(search_results) == 1, "The search results list lenght is not one."
|
|
||||||
print("\n\nExtracted sentences are:\n")
|
|
||||||
for result in search_results:
|
|
||||||
print(f"{result}\n")
|
|
||||||
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
|
||||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to add document with default_user to test_users dataset (test write permission enforcement)
|
|
||||||
add_error = False
|
|
||||||
try:
|
try:
|
||||||
await cognee.add(
|
yield loop
|
||||||
[explanation_file_path_nlp],
|
finally:
|
||||||
dataset_name="QUANTUM",
|
loop.close()
|
||||||
dataset_id=test_user_dataset_id,
|
|
||||||
user=default_user,
|
|
||||||
|
@pytest_asyncio.fixture(scope="module")
|
||||||
|
async def permissions_example_env(tmp_path_factory):
|
||||||
|
"""One-time environment setup for the permissions example test."""
|
||||||
|
# Ensure permissions feature is enabled (example requires it), but don't override if caller set it already.
|
||||||
|
os.environ.setdefault("ENABLE_BACKEND_ACCESS_CONTROL", "True")
|
||||||
|
|
||||||
|
root = tmp_path_factory.mktemp("permissions_example")
|
||||||
|
cognee.config.data_root_directory(str(root / "data"))
|
||||||
|
cognee.config.system_root_directory(str(root / "system"))
|
||||||
|
|
||||||
|
await _reset_engines_and_prune()
|
||||||
|
await engine_setup()
|
||||||
|
|
||||||
|
assert backend_access_control_enabled(), (
|
||||||
|
"Expected permissions to be enabled via ENABLE_BACKEND_ACCESS_CONTROL=True"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
await _reset_engines_and_prune()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_permissions_example_flow(permissions_example_env):
|
||||||
|
"""Pytest version of `examples/python/permissions_example.py` (same scenarios, asserts instead of prints)."""
|
||||||
|
# Patch LLM calls so GRAPH_COMPLETION can run without external API keys.
|
||||||
|
llm_patch = patch(
|
||||||
|
"cognee.infrastructure.llm.LLMGateway.LLMGateway.acreate_structured_output",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value="MOCK_ANSWER",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resolve example data file path (repo-shipped PDF).
|
||||||
|
repo_root = pathlib.Path(__file__).resolve().parent
|
||||||
|
explanation_file_path = str(repo_root / "test_data" / "artificial-intelligence.pdf")
|
||||||
|
assert pathlib.Path(explanation_file_path).exists(), (
|
||||||
|
f"Expected example PDF to exist at {explanation_file_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Same QUANTUM text as in the example.
|
||||||
|
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
|
||||||
|
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages
|
||||||
|
this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the
|
||||||
|
preparation and manipulation of quantum states.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create user_1, add AI dataset.
|
||||||
|
user_1 = await create_user("user_1@example.com", "example")
|
||||||
|
await cognee.add([explanation_file_path], dataset_name="AI", user=user_1)
|
||||||
|
|
||||||
|
# Create user_2, add QUANTUM dataset.
|
||||||
|
user_2 = await create_user("user_2@example.com", "example")
|
||||||
|
await cognee.add([text], dataset_name="QUANTUM", user=user_2)
|
||||||
|
|
||||||
|
ai_cognify_result = await cognee.cognify(["AI"], user=user_1)
|
||||||
|
quantum_cognify_result = await cognee.cognify(["QUANTUM"], user=user_2)
|
||||||
|
|
||||||
|
ai_dataset_id = _extract_dataset_id_from_cognify(ai_cognify_result)
|
||||||
|
quantum_dataset_id = _extract_dataset_id_from_cognify(quantum_cognify_result)
|
||||||
|
assert ai_dataset_id is not None
|
||||||
|
assert quantum_dataset_id is not None
|
||||||
|
|
||||||
|
with llm_patch:
|
||||||
|
# user_1 can read own dataset.
|
||||||
|
search_results = await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
|
query_text="What is in the document?",
|
||||||
|
user=user_1,
|
||||||
|
datasets=[ai_dataset_id],
|
||||||
)
|
)
|
||||||
except PermissionDeniedError:
|
assert isinstance(search_results, list) and len(search_results) == 1
|
||||||
add_error = True
|
assert search_results[0]["dataset_name"] == "AI"
|
||||||
assert add_error, "PermissionDeniedError was not raised during add as expected"
|
assert search_results[0]["search_result"] == ["MOCK_ANSWER"]
|
||||||
|
|
||||||
# Try to cognify with default_user the test_users dataset (test write permission enforcement)
|
# user_1 can't read dataset owned by user_2.
|
||||||
cognify_error = False
|
with pytest.raises(PermissionDeniedError):
|
||||||
try:
|
await cognee.search(
|
||||||
await cognee.cognify(datasets=[test_user_dataset_id], user=default_user)
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
except PermissionDeniedError:
|
query_text="What is in the document?",
|
||||||
cognify_error = True
|
user=user_1,
|
||||||
assert cognify_error, "PermissionDeniedError was not raised during cognify as expected"
|
datasets=[quantum_dataset_id],
|
||||||
|
)
|
||||||
|
|
||||||
# Try to add permission for a dataset default_user does not have share permission for
|
# user_1 can't add to user_2's dataset.
|
||||||
give_permission_error = False
|
with pytest.raises(PermissionDeniedError):
|
||||||
try:
|
await cognee.add([explanation_file_path], dataset_id=quantum_dataset_id, user=user_1)
|
||||||
|
|
||||||
|
# user_2 grants read permission to user_1 for QUANTUM dataset.
|
||||||
await authorized_give_permission_on_datasets(
|
await authorized_give_permission_on_datasets(
|
||||||
default_user.id,
|
user_1.id, [quantum_dataset_id], "read", user_2.id
|
||||||
[test_user_dataset_id],
|
|
||||||
"write",
|
|
||||||
default_user.id,
|
|
||||||
)
|
)
|
||||||
except PermissionDeniedError:
|
|
||||||
give_permission_error = True
|
|
||||||
assert give_permission_error, (
|
|
||||||
"PermissionDeniedError was not raised during assignment of permission as expected"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Actually give permission to default_user to write on test_users dataset
|
with llm_patch:
|
||||||
await authorized_give_permission_on_datasets(
|
# Now user_1 can read QUANTUM dataset via dataset_id.
|
||||||
default_user.id,
|
search_results = await cognee.search(
|
||||||
[test_user_dataset_id],
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
"write",
|
query_text="What is in the document?",
|
||||||
test_user.id,
|
user=user_1,
|
||||||
)
|
dataset_ids=[quantum_dataset_id],
|
||||||
|
)
|
||||||
|
assert isinstance(search_results, list) and len(search_results) == 1
|
||||||
|
assert search_results[0]["dataset_name"] == "QUANTUM"
|
||||||
|
assert search_results[0]["search_result"] == ["MOCK_ANSWER"]
|
||||||
|
|
||||||
# Add new data to test_users dataset from default_user
|
# Tenant + role scenario.
|
||||||
await cognee.add(
|
tenant_id = await create_tenant("CogneeLab", user_2.id)
|
||||||
[explanation_file_path_nlp],
|
await select_tenant(user_id=user_2.id, tenant_id=tenant_id)
|
||||||
dataset_name="QUANTUM",
|
role_id = await create_role(role_name="Researcher", owner_id=user_2.id)
|
||||||
dataset_id=test_user_dataset_id,
|
|
||||||
user=default_user,
|
|
||||||
)
|
|
||||||
await cognee.cognify(datasets=[test_user_dataset_id], user=default_user)
|
|
||||||
|
|
||||||
# Actually give permission to default_user to read on test_users dataset
|
user_3 = await create_user("user_3@example.com", "example")
|
||||||
await authorized_give_permission_on_datasets(
|
await add_user_to_tenant(user_id=user_3.id, tenant_id=tenant_id, owner_id=user_2.id)
|
||||||
default_user.id,
|
await add_user_to_role(user_id=user_3.id, role_id=role_id, owner_id=user_2.id)
|
||||||
[test_user_dataset_id],
|
await select_tenant(user_id=user_3.id, tenant_id=tenant_id)
|
||||||
"read",
|
|
||||||
test_user.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if default_user can see from test_users datasets now
|
# Can't grant role permission on a dataset that isn't part of the active tenant.
|
||||||
search_results = await cognee.search(
|
with pytest.raises(PermissionDeniedError):
|
||||||
query_type=SearchType.GRAPH_COMPLETION,
|
await authorized_give_permission_on_datasets(
|
||||||
query_text="What is in the document?",
|
role_id, [quantum_dataset_id], "read", user_2.id
|
||||||
user=default_user,
|
)
|
||||||
dataset_ids=[test_user_dataset_id],
|
|
||||||
)
|
|
||||||
assert len(search_results) == 1, "The search results list length is not one."
|
|
||||||
print("\n\nExtracted sentences are:\n")
|
|
||||||
for result in search_results:
|
|
||||||
print(f"{result}\n")
|
|
||||||
|
|
||||||
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
# Re-create QUANTUM dataset in CogneeLab tenant so role permissions can be assigned.
|
||||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
user_2 = await get_user(user_2.id) # refresh tenant context
|
||||||
)
|
await cognee.add([text], dataset_name="QUANTUM_COGNEE_LAB", user=user_2)
|
||||||
|
quantum_cognee_lab_cognify_result = await cognee.cognify(
|
||||||
# Check if default_user can only see information from both datasets now
|
["QUANTUM_COGNEE_LAB"], user=user_2
|
||||||
search_results = await cognee.search(
|
|
||||||
query_type=SearchType.GRAPH_COMPLETION,
|
|
||||||
query_text="What is in the document?",
|
|
||||||
user=default_user,
|
|
||||||
)
|
|
||||||
assert len(search_results) == 2, "The search results list length is not two."
|
|
||||||
print("\n\nExtracted sentences are:\n")
|
|
||||||
for result in search_results:
|
|
||||||
print(f"{result}\n")
|
|
||||||
|
|
||||||
# Try deleting data from test_user dataset with default_user without delete permission
|
|
||||||
delete_error = False
|
|
||||||
try:
|
|
||||||
# Get the dataset data to find the ID of the first data item (text)
|
|
||||||
test_user_dataset_data = await get_dataset_data(test_user_dataset_id)
|
|
||||||
text_data_id = test_user_dataset_data[0].id
|
|
||||||
|
|
||||||
await cognee.delete(
|
|
||||||
data_id=text_data_id, dataset_id=test_user_dataset_id, user=default_user
|
|
||||||
)
|
)
|
||||||
except PermissionDeniedError:
|
quantum_cognee_lab_dataset_id = _extract_dataset_id_from_cognify(
|
||||||
delete_error = True
|
quantum_cognee_lab_cognify_result
|
||||||
|
)
|
||||||
|
assert quantum_cognee_lab_dataset_id is not None
|
||||||
|
|
||||||
assert delete_error, "PermissionDeniedError was not raised during delete operation as expected"
|
await authorized_give_permission_on_datasets(
|
||||||
|
role_id, [quantum_cognee_lab_dataset_id], "read", user_2.id
|
||||||
|
)
|
||||||
|
|
||||||
# Try deleting data from test_user dataset with test_user
|
with llm_patch:
|
||||||
# Get the dataset data to find the ID of the first data item (text)
|
# user_3 can read via role permission.
|
||||||
test_user_dataset_data = await get_dataset_data(test_user_dataset_id)
|
search_results = await cognee.search(
|
||||||
text_data_id = test_user_dataset_data[0].id
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
|
query_text="What is in the document?",
|
||||||
await cognee.delete(data_id=text_data_id, dataset_id=test_user_dataset_id, user=test_user)
|
user=user_3,
|
||||||
|
dataset_ids=[quantum_cognee_lab_dataset_id],
|
||||||
# Actually give permission to default_user to delete data for test_users dataset
|
)
|
||||||
await authorized_give_permission_on_datasets(
|
assert isinstance(search_results, list) and len(search_results) == 1
|
||||||
default_user.id,
|
assert search_results[0]["dataset_name"] == "QUANTUM_COGNEE_LAB"
|
||||||
[test_user_dataset_id],
|
assert search_results[0]["search_result"] == ["MOCK_ANSWER"]
|
||||||
"delete",
|
|
||||||
test_user.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try deleting data from test_user dataset with default_user after getting delete permission
|
|
||||||
# Get the dataset data to find the ID of the remaining data item (explanation_file_path_nlp)
|
|
||||||
test_user_dataset_data = await get_dataset_data(test_user_dataset_id)
|
|
||||||
explanation_file_data_id = test_user_dataset_data[0].id
|
|
||||||
|
|
||||||
await cognee.delete(
|
|
||||||
data_id=explanation_file_data_id, dataset_id=test_user_dataset_id, user=default_user
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue