Merge 1509512f27 into 2ef347f8fa
This commit is contained in:
commit
cf40657a9c
8 changed files with 372 additions and 301 deletions
|
|
@ -216,6 +216,11 @@ TOKENIZERS_PARALLELISM="false"
|
|||
# LITELLM Logging Level. Set to quiet down logging
|
||||
LITELLM_LOG="ERROR"
|
||||
|
||||
|
||||
# Enable or disable the last accessed timestamp tracking and cleanup functionality.
|
||||
|
||||
ENABLE_LAST_ACCESSED="false"
|
||||
|
||||
# Set this environment variable to disable sending telemetry data
|
||||
# TELEMETRY_DISABLED=1
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Any, Optional
|
||||
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
|
|
@ -51,7 +50,6 @@ class ChunksRetriever(BaseRetriever):
|
|||
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
|
||||
)
|
||||
logger.info(f"Found {len(found_chunks)} chunks from vector search")
|
||||
await update_node_access_timestamps(found_chunks)
|
||||
|
||||
except CollectionNotFoundError as error:
|
||||
logger.error("DocumentChunk_text collection not found in vector database")
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from cognee.modules.retrieval.utils.session_cache import (
|
|||
save_conversation_history,
|
||||
get_conversation_history,
|
||||
)
|
||||
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
|
@ -68,7 +67,6 @@ class CompletionRetriever(BaseRetriever):
|
|||
|
||||
if len(found_chunks) == 0:
|
||||
return ""
|
||||
await update_node_access_timestamps(found_chunks)
|
||||
# Combine all chunks text returned from vector search (number of chunks is determined by top_k
|
||||
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
|
||||
combined_context = "\n".join(chunks_payload)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from cognee.modules.retrieval.utils.session_cache import (
|
|||
)
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
||||
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
|
||||
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
|
@ -149,7 +148,6 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
entity_nodes = get_entity_nodes_from_triplets(triplets)
|
||||
|
||||
await update_node_access_timestamps(entity_nodes)
|
||||
return triplets
|
||||
|
||||
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from cognee.shared.logging_utils import get_logger
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
|
||||
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||
|
||||
logger = get_logger("SummariesRetriever")
|
||||
|
|
@ -56,8 +55,6 @@ class SummariesRetriever(BaseRetriever):
|
|||
)
|
||||
logger.info(f"Found {len(summaries_results)} summaries from vector search")
|
||||
|
||||
await update_node_access_timestamps(summaries_results)
|
||||
|
||||
except CollectionNotFoundError as error:
|
||||
logger.error("TextSummary_text collection not found in vector database")
|
||||
raise NoDataError("No data found in the system, please add data first.") from error
|
||||
|
|
|
|||
|
|
@ -1,88 +1,87 @@
|
|||
"""Utilities for tracking data access in retrievers."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Any
|
||||
from uuid import UUID
|
||||
import os
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from sqlalchemy import update
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def update_node_access_timestamps(items: List[Any]):
|
||||
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true":
|
||||
return
|
||||
|
||||
if not items:
|
||||
return
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
timestamp_dt = datetime.now(timezone.utc)
|
||||
|
||||
# Extract node IDs
|
||||
node_ids = []
|
||||
for item in items:
|
||||
item_id = item.payload.get("id") if hasattr(item, "payload") else item.get("id")
|
||||
if item_id:
|
||||
node_ids.append(str(item_id))
|
||||
|
||||
if not node_ids:
|
||||
return
|
||||
|
||||
# Focus on document-level tracking via projection
|
||||
try:
|
||||
doc_ids = await _find_origin_documents_via_projection(graph_engine, node_ids)
|
||||
if doc_ids:
|
||||
await _update_sql_records(doc_ids, timestamp_dt)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update SQL timestamps: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _find_origin_documents_via_projection(graph_engine, node_ids):
|
||||
"""Find origin documents using graph projection instead of DB queries"""
|
||||
# Project the entire graph with necessary properties
|
||||
memory_fragment = CogneeGraph()
|
||||
await memory_fragment.project_graph_from_db(
|
||||
graph_engine,
|
||||
node_properties_to_project=["id", "type"],
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
)
|
||||
|
||||
# Find origin documents by traversing the in-memory graph
|
||||
doc_ids = set()
|
||||
for node_id in node_ids:
|
||||
node = memory_fragment.get_node(node_id)
|
||||
if node and node.get_attribute("type") == "DocumentChunk":
|
||||
# Traverse edges to find connected documents
|
||||
for edge in node.get_skeleton_edges():
|
||||
# Get the neighbor node
|
||||
neighbor = (
|
||||
edge.get_destination_node()
|
||||
if edge.get_source_node().id == node_id
|
||||
else edge.get_source_node()
|
||||
)
|
||||
if neighbor and neighbor.get_attribute("type") in ["TextDocument", "Document"]:
|
||||
doc_ids.add(neighbor.id)
|
||||
|
||||
return list(doc_ids)
|
||||
|
||||
|
||||
async def _update_sql_records(doc_ids, timestamp_dt):
|
||||
"""Update SQL Data table (same for all providers)"""
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
stmt = (
|
||||
update(Data)
|
||||
.where(Data.id.in_([UUID(doc_id) for doc_id in doc_ids]))
|
||||
.values(last_accessed=timestamp_dt)
|
||||
)
|
||||
|
||||
await session.execute(stmt)
|
||||
"""Utilities for tracking data access in retrievers."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Any
|
||||
from uuid import UUID
|
||||
import os
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from sqlalchemy import update
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def update_node_access_timestamps(items: List[Any]):
|
||||
if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true":
|
||||
return
|
||||
|
||||
if not items:
|
||||
return
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
timestamp_dt = datetime.now(timezone.utc)
|
||||
|
||||
# Extract node IDs - updated for graph node format
|
||||
node_ids = []
|
||||
for item in items:
|
||||
# Handle graph nodes from prepare_search_result (direct id attribute)
|
||||
if hasattr(item, 'id'):
|
||||
node_ids.append(str(item.id))
|
||||
# Fallback for original retriever format
|
||||
elif hasattr(item, 'payload') and item.payload.get("id"):
|
||||
node_ids.append(str(item.payload.get("id")))
|
||||
elif isinstance(item, dict) and item.get("id"):
|
||||
node_ids.append(str(item.get("id")))
|
||||
|
||||
if not node_ids:
|
||||
return
|
||||
|
||||
# Focus on document-level tracking via projection
|
||||
try:
|
||||
doc_ids = await _find_origin_documents_via_projection(graph_engine, node_ids)
|
||||
if doc_ids:
|
||||
await _update_sql_records(doc_ids, timestamp_dt)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update SQL timestamps: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _find_origin_documents_via_projection(graph_engine, node_ids):
|
||||
"""Find origin documents using graph projection instead of DB queries"""
|
||||
# Project the entire graph with necessary properties
|
||||
memory_fragment = CogneeGraph()
|
||||
await memory_fragment.project_graph_from_db(
|
||||
graph_engine,
|
||||
node_properties_to_project=["id", "type"],
|
||||
edge_properties_to_project=["relationship_name"]
|
||||
)
|
||||
|
||||
# Find origin documents by traversing the in-memory graph
|
||||
doc_ids = set()
|
||||
for node_id in node_ids:
|
||||
node = memory_fragment.get_node(node_id)
|
||||
if node and node.get_attribute("type") == "DocumentChunk":
|
||||
# Traverse edges to find connected documents
|
||||
for edge in node.get_skeleton_edges():
|
||||
# Get the neighbor node
|
||||
neighbor = edge.get_destination_node() if edge.get_source_node().id == node_id else edge.get_source_node()
|
||||
if neighbor and neighbor.get_attribute("type") in ["TextDocument", "Document"]:
|
||||
doc_ids.add(neighbor.id)
|
||||
|
||||
return list(doc_ids)
|
||||
|
||||
|
||||
async def _update_sql_records(doc_ids, timestamp_dt):
|
||||
"""Update SQL Data table (same for all providers)"""
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
stmt = update(Data).where(
|
||||
Data.id.in_([UUID(doc_id) for doc_id in doc_ids])
|
||||
).values(last_accessed=timestamp_dt)
|
||||
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ from cognee.context_global_variables import backend_access_control_enabled
|
|||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.search.types import (
|
||||
SearchResult,
|
||||
SearchType,
|
||||
SearchResultDataset,
|
||||
SearchResult,
|
||||
SearchType,
|
||||
)
|
||||
from cognee.modules.search.operations import log_query, log_result
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -26,6 +27,7 @@ from cognee import __version__ as cognee_version
|
|||
from .get_search_type_tools import get_search_type_tools
|
||||
from .no_access_control_search import no_access_control_search
|
||||
from ..utils.prepare_search_result import prepare_search_result
|
||||
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -43,10 +45,11 @@ async def search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
verbose=False,
|
||||
verbose: bool = False,
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
|
||||
|
|
@ -73,9 +76,11 @@ async def search(
|
|||
},
|
||||
)
|
||||
|
||||
actual_accessed_items = [] # Collect all accessed items here
|
||||
|
||||
# Use search function filtered by permissions if access control is enabled
|
||||
if backend_access_control_enabled():
|
||||
search_results = await authorized_search(
|
||||
raw_search_results = await authorized_search(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
|
|
@ -92,8 +97,19 @@ async def search(
|
|||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
if use_combined_context:
|
||||
# raw_search_results is (completion, context, datasets)
|
||||
_, context_data, _ = raw_search_results
|
||||
if isinstance(context_data, list): # Expecting a list of Edge or similar
|
||||
actual_accessed_items.extend(context_data)
|
||||
# If context_data is a string, it's already textual and might not map to specific nodes for timestamp updates
|
||||
else:
|
||||
for result_tuple in raw_search_results:
|
||||
_, context_data, _ = result_tuple
|
||||
if isinstance(context_data, list): # Expecting a list of Edge or similar
|
||||
actual_accessed_items.extend(context_data)
|
||||
else:
|
||||
search_results = [
|
||||
raw_search_results = [
|
||||
await no_access_control_search(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
|
|
@ -110,6 +126,15 @@ async def search(
|
|||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
]
|
||||
# In this case, raw_search_results is a list containing a single tuple
|
||||
if raw_search_results:
|
||||
_, context_data, _ = raw_search_results[0]
|
||||
if isinstance(context_data, list): # Expecting a list of Edge or similar
|
||||
actual_accessed_items.extend(context_data)
|
||||
|
||||
# Call the update_node_access_timestamps function here
|
||||
# Pass the collected actual_accessed_items
|
||||
await update_node_access_timestamps(actual_accessed_items)
|
||||
|
||||
send_telemetry(
|
||||
"cognee.search EXECUTION COMPLETED",
|
||||
|
|
@ -120,6 +145,8 @@ async def search(
|
|||
},
|
||||
)
|
||||
|
||||
search_results = raw_search_results
|
||||
|
||||
await log_result(
|
||||
query.id,
|
||||
json.dumps(
|
||||
|
|
@ -130,48 +157,65 @@ async def search(
|
|||
user.id,
|
||||
)
|
||||
|
||||
# This is for maintaining backwards compatibility
|
||||
if backend_access_control_enabled():
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
if use_combined_context:
|
||||
# Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info
|
||||
prepared_search_results = await prepare_search_result(
|
||||
search_results[0] if isinstance(search_results, list) else search_results
|
||||
)
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
if only_context:
|
||||
search_result_dict = {
|
||||
"search_result": [context] if context else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
}
|
||||
if verbose:
|
||||
# Include graphs only in verbose mode
|
||||
search_result_dict["graphs"] = graphs
|
||||
|
||||
return_value.append(search_result_dict)
|
||||
else:
|
||||
search_result_dict = {
|
||||
"search_result": [result] if result else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
}
|
||||
if verbose:
|
||||
# Include graphs only in verbose mode
|
||||
search_result_dict["graphs"] = graphs
|
||||
return_value.append(search_result_dict)
|
||||
|
||||
return return_value
|
||||
return CombinedSearchResult(
|
||||
result=result,
|
||||
graphs=graphs,
|
||||
context=context,
|
||||
datasets=[
|
||||
SearchResultDataset(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
)
|
||||
for dataset in datasets
|
||||
],
|
||||
)
|
||||
else:
|
||||
return_value = []
|
||||
if only_context:
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
return_value.append(prepared_search_results["context"])
|
||||
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
if only_context:
|
||||
search_result_dict = {
|
||||
"search_result": [context] if context else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
}
|
||||
if verbose:
|
||||
# Include graphs only in verbose mode
|
||||
search_result_dict["graphs"] = graphs
|
||||
|
||||
return_value.append(search_result_dict)
|
||||
else:
|
||||
search_result_dict = {
|
||||
"search_result": [result] if result else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
}
|
||||
if verbose:
|
||||
# Include graphs only in verbose mode
|
||||
search_result_dict["graphs"] = graphs
|
||||
|
||||
return_value.append(search_result_dict)
|
||||
|
||||
return return_value
|
||||
else:
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
|
|
|
|||
|
|
@ -1,165 +1,197 @@
|
|||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from uuid import UUID
|
||||
from sqlalchemy import select, update
|
||||
from cognee.modules.data.models import Data, DatasetData
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def test_textdocument_cleanup_with_sql():
|
||||
"""
|
||||
End-to-end test for TextDocument cleanup based on last_accessed timestamps.
|
||||
"""
|
||||
# Enable last accessed tracking BEFORE any cognee operations
|
||||
os.environ["ENABLE_LAST_ACCESSED"] = "true"
|
||||
|
||||
# Setup test directories
|
||||
data_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_cleanup")
|
||||
).resolve()
|
||||
)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_cleanup")
|
||||
).resolve()
|
||||
)
|
||||
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
# Initialize database
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
|
||||
# Clean slate
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
logger.info("🧪 Testing TextDocument cleanup based on last_accessed")
|
||||
|
||||
# Step 1: Add and cognify a test document
|
||||
dataset_name = "test_cleanup_dataset"
|
||||
test_text = """
|
||||
Machine learning is a subset of artificial intelligence that enables systems to learn
|
||||
and improve from experience without being explicitly programmed. Deep learning uses
|
||||
neural networks with multiple layers to process data.
|
||||
"""
|
||||
|
||||
await setup()
|
||||
user = await get_default_user()
|
||||
await cognee.add([test_text], dataset_name=dataset_name, user=user)
|
||||
|
||||
cognify_result = await cognee.cognify([dataset_name], user=user)
|
||||
|
||||
# Extract dataset_id from cognify result
|
||||
dataset_id = None
|
||||
for ds_id, pipeline_result in cognify_result.items():
|
||||
dataset_id = ds_id
|
||||
break
|
||||
|
||||
assert dataset_id is not None, "Failed to get dataset_id from cognify result"
|
||||
logger.info(f"✅ Document added and cognified. Dataset ID: {dataset_id}")
|
||||
|
||||
# Step 2: Perform search to trigger last_accessed update
|
||||
logger.info("Triggering search to update last_accessed...")
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.CHUNKS,
|
||||
query_text="machine learning",
|
||||
datasets=[dataset_name],
|
||||
user=user,
|
||||
)
|
||||
logger.info(f"✅ Search completed, found {len(search_results)} results")
|
||||
assert len(search_results) > 0, "Search should return results"
|
||||
|
||||
# Step 3: Verify last_accessed was set and get data_id
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(Data, DatasetData)
|
||||
.join(DatasetData, Data.id == DatasetData.data_id)
|
||||
.where(DatasetData.dataset_id == dataset_id)
|
||||
)
|
||||
data_records = result.all()
|
||||
assert len(data_records) > 0, "No Data records found for the dataset"
|
||||
data_record = data_records[0][0]
|
||||
data_id = data_record.id
|
||||
|
||||
# Verify last_accessed is set
|
||||
assert data_record.last_accessed is not None, (
|
||||
"last_accessed should be set after search operation"
|
||||
)
|
||||
|
||||
original_last_accessed = data_record.last_accessed
|
||||
logger.info(f"✅ last_accessed verified: {original_last_accessed}")
|
||||
|
||||
# Step 4: Manually age the timestamp
|
||||
minutes_threshold = 30
|
||||
aged_timestamp = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold + 10)
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
stmt = update(Data).where(Data.id == data_id).values(last_accessed=aged_timestamp)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
# Verify timestamp was updated
|
||||
async with db_engine.get_async_session() as session:
|
||||
result = await session.execute(select(Data).where(Data.id == data_id))
|
||||
updated_data = result.scalar_one_or_none()
|
||||
assert updated_data is not None, "Data record should exist"
|
||||
retrieved_timestamp = updated_data.last_accessed
|
||||
if retrieved_timestamp.tzinfo is None:
|
||||
retrieved_timestamp = retrieved_timestamp.replace(tzinfo=timezone.utc)
|
||||
assert retrieved_timestamp == aged_timestamp, "Timestamp should be updated to aged value"
|
||||
|
||||
# Step 5: Test cleanup (document-level is now the default)
|
||||
from cognee.tasks.cleanup.cleanup_unused_data import cleanup_unused_data
|
||||
|
||||
# First do a dry run
|
||||
logger.info("Testing dry run...")
|
||||
dry_run_result = await cleanup_unused_data(minutes_threshold=10, dry_run=True, user_id=user.id)
|
||||
|
||||
# Debug: Print the actual result
|
||||
logger.info(f"Dry run result: {dry_run_result}")
|
||||
|
||||
assert dry_run_result["status"] == "dry_run", (
|
||||
f"Status should be 'dry_run', got: {dry_run_result['status']}"
|
||||
)
|
||||
assert dry_run_result["unused_count"] > 0, "Should find at least one unused document"
|
||||
logger.info(f"✅ Dry run found {dry_run_result['unused_count']} unused documents")
|
||||
|
||||
# Now run actual cleanup
|
||||
logger.info("Executing cleanup...")
|
||||
cleanup_result = await cleanup_unused_data(minutes_threshold=30, dry_run=False, user_id=user.id)
|
||||
|
||||
assert cleanup_result["status"] == "completed", "Cleanup should complete successfully"
|
||||
assert cleanup_result["deleted_count"]["documents"] > 0, (
|
||||
"At least one document should be deleted"
|
||||
)
|
||||
logger.info(
|
||||
f"✅ Cleanup completed. Deleted {cleanup_result['deleted_count']['documents']} documents"
|
||||
)
|
||||
|
||||
# Step 6: Verify deletion
|
||||
async with db_engine.get_async_session() as session:
|
||||
deleted_data = (
|
||||
await session.execute(select(Data).where(Data.id == data_id))
|
||||
).scalar_one_or_none()
|
||||
assert deleted_data is None, "Data record should be deleted"
|
||||
logger.info("✅ Confirmed: Data record was deleted")
|
||||
|
||||
logger.info("🎉 All cleanup tests passed!")
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
success = asyncio.run(test_textdocument_cleanup_with_sql())
|
||||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from uuid import UUID
|
||||
from sqlalchemy import select, update
|
||||
from cognee.modules.data.models import Data, DatasetData
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def test_all_search_types_cleanup():
|
||||
"""
|
||||
End-to-end test for TextDocument cleanup based on last_accessed timestamps
|
||||
across all search types.
|
||||
"""
|
||||
# Enable last accessed tracking BEFORE any cognee operations
|
||||
os.environ["ENABLE_LAST_ACCESSED"] = "true"
|
||||
|
||||
# Setup test directories
|
||||
data_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_cleanup")
|
||||
).resolve()
|
||||
)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_cleanup")
|
||||
).resolve()
|
||||
)
|
||||
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
# Initialize database
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
|
||||
# Clean slate
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
logger.info("🧪 Testing TextDocument cleanup for all search types")
|
||||
|
||||
# Step 1: Add and cognify a test document
|
||||
dataset_name = "test_cleanup_dataset"
|
||||
test_text = """
|
||||
Machine learning is a subset of artificial intelligence that enables systems to learn
|
||||
and improve from experience without being explicitly programmed. Deep learning uses
|
||||
neural networks with multiple layers to process data.
|
||||
"""
|
||||
|
||||
await setup()
|
||||
user = await get_default_user()
|
||||
await cognee.add([test_text], dataset_name=dataset_name, user=user)
|
||||
|
||||
cognify_result = await cognee.cognify([dataset_name], user=user)
|
||||
|
||||
# Extract dataset_id from cognify result
|
||||
dataset_id = None
|
||||
for ds_id, pipeline_result in cognify_result.items():
|
||||
dataset_id = ds_id
|
||||
break
|
||||
|
||||
assert dataset_id is not None, "Failed to get dataset_id from cognify result"
|
||||
logger.info(f"✅ Document added and cognified. Dataset ID: {dataset_id}")
|
||||
|
||||
# All available search types to test (excluding CODE)
|
||||
search_types_to_test = [
|
||||
SearchType.CHUNKS,
|
||||
SearchType.SUMMARIES,
|
||||
SearchType.RAG_COMPLETION,
|
||||
SearchType.GRAPH_COMPLETION,
|
||||
SearchType.GRAPH_SUMMARY_COMPLETION,
|
||||
SearchType.GRAPH_COMPLETION_COT,
|
||||
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
|
||||
SearchType.FEELING_LUCKY,
|
||||
SearchType.CHUNKS_LEXICAL,
|
||||
]
|
||||
|
||||
# Skip search types that require special data or permissions
|
||||
skip_types = {
|
||||
SearchType.CYPHER, # Requires ALLOW_CYPHER_QUERY=true
|
||||
SearchType.NATURAL_LANGUAGE, # Requires ALLOW_CYPHER_QUERY=true
|
||||
SearchType.FEEDBACK, # Requires previous search interaction
|
||||
SearchType.TEMPORAL, # Requires temporal data
|
||||
SearchType.CODING_RULES, # Requires coding rules data
|
||||
}
|
||||
|
||||
tested_data_ids = []
|
||||
|
||||
# Test each search type
|
||||
for search_type in search_types_to_test:
|
||||
if search_type in skip_types:
|
||||
logger.info(f"⏭️ Skipping {search_type.value} (requires special setup)")
|
||||
continue
|
||||
|
||||
logger.info(f"🔍 Testing {search_type.value} search...")
|
||||
|
||||
try:
|
||||
# Perform search to trigger last_accessed update
|
||||
search_results = await cognee.search(
|
||||
query_type=search_type,
|
||||
query_text="machine learning",
|
||||
datasets=[dataset_name],
|
||||
user=user,
|
||||
)
|
||||
|
||||
logger.info(f"✅ {search_type.value} search completed, found {len(search_results)} results")
|
||||
|
||||
# Verify last_accessed was set
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(Data, DatasetData)
|
||||
.join(DatasetData, Data.id == DatasetData.data_id)
|
||||
.where(DatasetData.dataset_id == dataset_id)
|
||||
)
|
||||
data_records = result.all()
|
||||
assert len(data_records) > 0, "No Data records found for the dataset"
|
||||
data_record = data_records[0][0]
|
||||
data_id = data_record.id
|
||||
|
||||
# Verify last_accessed is set
|
||||
assert data_record.last_accessed is not None, (
|
||||
f"last_accessed should be set after {search_type.value} search operation"
|
||||
)
|
||||
|
||||
original_last_accessed = data_record.last_accessed
|
||||
logger.info(f"✅ {search_type.value} last_accessed verified: {original_last_accessed}")
|
||||
|
||||
if data_id not in tested_data_ids:
|
||||
tested_data_ids.append(data_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ {search_type.value} search failed: {str(e)}")
|
||||
continue
|
||||
|
||||
# Step 3: Test cleanup with aged timestamps
|
||||
from cognee.tasks.cleanup.cleanup_unused_data import cleanup_unused_data
|
||||
|
||||
minutes_threshold = 30
|
||||
aged_timestamp = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold + 10)
|
||||
|
||||
# Age all tested data records
|
||||
db_engine = get_relational_engine()
|
||||
for data_id in tested_data_ids:
|
||||
async with db_engine.get_async_session() as session:
|
||||
stmt = update(Data).where(Data.id == data_id).values(last_accessed=aged_timestamp)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
# First do a dry run
|
||||
logger.info("Testing dry run...")
|
||||
dry_run_result = await cleanup_unused_data(minutes_threshold=10, dry_run=True, user_id=user.id)
|
||||
|
||||
logger.info(f"Dry run result: {dry_run_result}")
|
||||
|
||||
assert dry_run_result["status"] == "dry_run", (
|
||||
f"Status should be 'dry_run', got: {dry_run_result['status']}"
|
||||
)
|
||||
assert dry_run_result["unused_count"] > 0, "Should find at least one unused document"
|
||||
logger.info(f"✅ Dry run found {dry_run_result['unused_count']} unused documents")
|
||||
|
||||
# Now run actual cleanup
|
||||
logger.info("Executing cleanup...")
|
||||
cleanup_result = await cleanup_unused_data(minutes_threshold=30, dry_run=False, user_id=user.id)
|
||||
|
||||
assert cleanup_result["status"] == "completed", "Cleanup should complete successfully"
|
||||
assert cleanup_result["deleted_count"]["documents"] > 0, (
|
||||
"At least one document should be deleted"
|
||||
)
|
||||
logger.info(
|
||||
f"✅ Cleanup completed. Deleted {cleanup_result['deleted_count']['documents']} documents"
|
||||
)
|
||||
|
||||
# Step 4: Verify deletion
|
||||
for data_id in tested_data_ids:
|
||||
async with db_engine.get_async_session() as session:
|
||||
deleted_data = (
|
||||
await session.execute(select(Data).where(Data.id == data_id))
|
||||
).scalar_one_or_none()
|
||||
assert deleted_data is None, f"Data record {data_id} should be deleted"
|
||||
|
||||
logger.info("✅ Confirmed: All tested data records were deleted")
|
||||
logger.info("🎉 All cleanup tests passed for all search types!")
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
success = asyncio.run(test_all_search_types_cleanup())
|
||||
exit(0 if success else 1)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue