Merge pull request #9 from chinu0609/delete-last-acessed

fix: implementing deletion in search.py
This commit is contained in:
Chinmay Bhosale 2025-12-26 09:23:17 +05:30 committed by GitHub
commit b85a7fffe5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 64 additions and 32 deletions

View file

@ -1,5 +1,4 @@
from typing import Any, Optional from typing import Any, Optional
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
@ -49,7 +48,6 @@ class ChunksRetriever(BaseRetriever):
try: try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
logger.info(f"Found {len(found_chunks)} chunks from vector search") logger.info(f"Found {len(found_chunks)} chunks from vector search")
await update_node_access_timestamps(found_chunks)
except CollectionNotFoundError as error: except CollectionNotFoundError as error:
logger.error("DocumentChunk_text collection not found in vector database") logger.error("DocumentChunk_text collection not found in vector database")

View file

@ -8,7 +8,6 @@ from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history, save_conversation_history,
get_conversation_history, get_conversation_history,
) )
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
@ -66,7 +65,6 @@ class CompletionRetriever(BaseRetriever):
if len(found_chunks) == 0: if len(found_chunks) == 0:
return "" return ""
await update_node_access_timestamps(found_chunks)
# Combine all chunks text returned from vector search (number of chunks is determined by top_k # Combine all chunks text returned from vector search (number of chunks is determined by top_k
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks] chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
combined_context = "\n".join(chunks_payload) combined_context = "\n".join(chunks_payload)

View file

@ -16,7 +16,6 @@ from cognee.modules.retrieval.utils.session_cache import (
) )
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
from cognee.modules.retrieval.utils.models import CogneeUserInteraction from cognee.modules.retrieval.utils.models import CogneeUserInteraction
from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.models.node_set import NodeSet
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
@ -149,7 +148,6 @@ class GraphCompletionRetriever(BaseGraphRetriever):
entity_nodes = get_entity_nodes_from_triplets(triplets) entity_nodes = get_entity_nodes_from_triplets(triplets)
await update_node_access_timestamps(entity_nodes)
return triplets return triplets
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]): async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):

View file

@ -4,7 +4,6 @@ from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
logger = get_logger("SummariesRetriever") logger = get_logger("SummariesRetriever")
@ -56,7 +55,6 @@ class SummariesRetriever(BaseRetriever):
) )
logger.info(f"Found {len(summaries_results)} summaries from vector search") logger.info(f"Found {len(summaries_results)} summaries from vector search")
await update_node_access_timestamps(summaries_results)
except CollectionNotFoundError as error: except CollectionNotFoundError as error:
logger.error("TextSummary_text collection not found in vector database") logger.error("TextSummary_text collection not found in vector database")

View file

@ -25,12 +25,17 @@ async def update_node_access_timestamps(items: List[Any]):
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
timestamp_dt = datetime.now(timezone.utc) timestamp_dt = datetime.now(timezone.utc)
# Extract node IDs # Extract node IDs - updated for graph node format
node_ids = [] node_ids = []
for item in items: for item in items:
item_id = item.payload.get("id") if hasattr(item, 'payload') else item.get("id") # Handle graph nodes from prepare_search_result (direct id attribute)
if item_id: if hasattr(item, 'id'):
node_ids.append(str(item_id)) node_ids.append(str(item.id))
# Fallback for original retriever format
elif hasattr(item, 'payload') and item.payload.get("id"):
node_ids.append(str(item.payload.get("id")))
elif isinstance(item, dict) and item.get("id"):
node_ids.append(str(item.get("id")))
if not node_ids: if not node_ids:
return return

View file

@ -28,6 +28,7 @@ from cognee import __version__ as cognee_version
from .get_search_type_tools import get_search_type_tools from .get_search_type_tools import get_search_type_tools
from .no_access_control_search import no_access_control_search from .no_access_control_search import no_access_control_search
from ..utils.prepare_search_result import prepare_search_result from ..utils.prepare_search_result import prepare_search_result
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps # Import your function
logger = get_logger() logger = get_logger()
@ -49,6 +50,7 @@ async def search(
session_id: Optional[str] = None, session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100, wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5, triplet_distance_penalty: Optional[float] = 3.5,
verbose: bool = False,
) -> Union[CombinedSearchResult, List[SearchResult]]: ) -> Union[CombinedSearchResult, List[SearchResult]]:
""" """
@ -75,9 +77,11 @@ async def search(
}, },
) )
actual_accessed_items = [] # Collect all accessed items here
# Use search function filtered by permissions if access control is enabled # Use search function filtered by permissions if access control is enabled
if backend_access_control_enabled(): if backend_access_control_enabled():
search_results = await authorized_search( raw_search_results = await authorized_search(
query_type=query_type, query_type=query_type,
query_text=query_text, query_text=query_text,
user=user, user=user,
@ -95,8 +99,19 @@ async def search(
wide_search_top_k=wide_search_top_k, wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty, triplet_distance_penalty=triplet_distance_penalty,
) )
if use_combined_context:
# raw_search_results is (completion, context, datasets)
_, context_data, _ = raw_search_results
if isinstance(context_data, list): # Expecting a list of Edge or similar
actual_accessed_items.extend(context_data)
# If context_data is a string, it's already textual and might not map to specific nodes for timestamp updates
else:
for result_tuple in raw_search_results:
_, context_data, _ = result_tuple
if isinstance(context_data, list): # Expecting a list of Edge or similar
actual_accessed_items.extend(context_data)
else: else:
search_results = [ raw_search_results = [
await no_access_control_search( await no_access_control_search(
query_type=query_type, query_type=query_type,
query_text=query_text, query_text=query_text,
@ -113,6 +128,15 @@ async def search(
triplet_distance_penalty=triplet_distance_penalty, triplet_distance_penalty=triplet_distance_penalty,
) )
] ]
# In this case, raw_search_results is a list containing a single tuple
if raw_search_results:
_, context_data, _ = raw_search_results[0]
if isinstance(context_data, list): # Expecting a list of Edge or similar
actual_accessed_items.extend(context_data)
# Call the update_node_access_timestamps function here
# Pass the collected actual_accessed_items
await update_node_access_timestamps(actual_accessed_items)
send_telemetry( send_telemetry(
"cognee.search EXECUTION COMPLETED", "cognee.search EXECUTION COMPLETED",
@ -123,6 +147,8 @@ async def search(
}, },
) )
search_results = raw_search_results
await log_result( await log_result(
query.id, query.id,
json.dumps( json.dumps(
@ -140,6 +166,7 @@ async def search(
) )
if use_combined_context: if use_combined_context:
# Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info
prepared_search_results = await prepare_search_result( prepared_search_results = await prepare_search_result(
search_results[0] if isinstance(search_results, list) else search_results search_results[0] if isinstance(search_results, list) else search_results
) )
@ -173,25 +200,30 @@ async def search(
datasets = prepared_search_results["datasets"] datasets = prepared_search_results["datasets"]
if only_context: if only_context:
return_value.append( search_result_dict = {
{ "search_result": [context] if context else None,
"search_result": [context] if context else None, "dataset_id": datasets[0].id,
"dataset_id": datasets[0].id, "dataset_name": datasets[0].name,
"dataset_name": datasets[0].name, "dataset_tenant_id": datasets[0].tenant_id,
"dataset_tenant_id": datasets[0].tenant_id, }
"graphs": graphs, if verbose:
} # Include graphs only in verbose mode
) search_result_dict["graphs"] = graphs
return_value.append(search_result_dict)
else: else:
return_value.append( search_result_dict = {
{ "search_result": [result] if result else None,
"search_result": [result] if result else None, "dataset_id": datasets[0].id,
"dataset_id": datasets[0].id, "dataset_name": datasets[0].name,
"dataset_name": datasets[0].name, "dataset_tenant_id": datasets[0].tenant_id,
"dataset_tenant_id": datasets[0].tenant_id, }
"graphs": graphs, if verbose:
} # Include graphs only in verbose mode
) search_result_dict["graphs"] = graphs
return_value.append(search_result_dict)
return return_value return return_value
else: else:
return_value = [] return_value = []
@ -319,6 +351,8 @@ async def authorized_search(
only_context=only_context, only_context=only_context,
session_id=session_id, session_id=session_id,
wide_search_top_k=wide_search_top_k, wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
) )
return search_results return search_results
@ -438,3 +472,4 @@ async def search_in_datasets_context(
) )
return await asyncio.gather(*tasks) return await asyncio.gather(*tasks)