diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py index b7a90238a..4ddc841fd 100644 --- a/cognee/modules/retrieval/chunks_retriever.py +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -1,5 +1,4 @@ from typing import Any, Optional -from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.base_retriever import BaseRetriever @@ -49,7 +48,6 @@ class ChunksRetriever(BaseRetriever): try: found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) logger.info(f"Found {len(found_chunks)} chunks from vector search") - await update_node_access_timestamps(found_chunks) except CollectionNotFoundError as error: logger.error("DocumentChunk_text collection not found in vector database") diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 0e9a4167c..58e4b03eb 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -8,7 +8,6 @@ from cognee.modules.retrieval.utils.session_cache import ( save_conversation_history, get_conversation_history, ) -from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError @@ -66,7 +65,6 @@ class CompletionRetriever(BaseRetriever): if len(found_chunks) == 0: return "" - await update_node_access_timestamps(found_chunks) # Combine all chunks text returned from vector search (number of chunks is determined by top_k chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks] combined_context = 
"\n".join(chunks_payload) diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 317d7cd9a..ede8386d4 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -16,7 +16,6 @@ from cognee.modules.retrieval.utils.session_cache import ( ) from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node -from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps from cognee.modules.retrieval.utils.models import CogneeUserInteraction from cognee.modules.engine.models.node_set import NodeSet from cognee.infrastructure.databases.graph import get_graph_engine @@ -149,7 +148,6 @@ class GraphCompletionRetriever(BaseGraphRetriever): entity_nodes = get_entity_nodes_from_triplets(triplets) - await update_node_access_timestamps(entity_nodes) return triplets async def convert_retrieved_objects_to_context(self, triplets: List[Edge]): diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 0df750d22..e24f22e82 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -4,7 +4,6 @@ from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError logger = get_logger("SummariesRetriever") @@ -56,7 +55,6 @@ class SummariesRetriever(BaseRetriever): ) logger.info(f"Found {len(summaries_results)} summaries from vector search") - await 
update_node_access_timestamps(summaries_results) except CollectionNotFoundError as error: logger.error("TextSummary_text collection not found in vector database") diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index 54fd043b9..b2d98924a 100644 --- a/cognee/modules/retrieval/utils/access_tracking.py +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -25,12 +25,17 @@ async def update_node_access_timestamps(items: List[Any]): graph_engine = await get_graph_engine() timestamp_dt = datetime.now(timezone.utc) - # Extract node IDs + # Extract node IDs - updated for graph node format node_ids = [] for item in items: - item_id = item.payload.get("id") if hasattr(item, 'payload') else item.get("id") - if item_id: - node_ids.append(str(item_id)) + # Handle graph nodes from prepare_search_result (direct id attribute) + if hasattr(item, 'id'): + node_ids.append(str(item.id)) + # Fallback for original retriever format + elif hasattr(item, 'payload') and item.payload.get("id"): + node_ids.append(str(item.payload.get("id"))) + elif isinstance(item, dict) and item.get("id"): + node_ids.append(str(item.get("id"))) if not node_ids: return diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 9f180d607..5b436d8e0 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -28,6 +28,7 @@ from cognee import __version__ as cognee_version from .get_search_type_tools import get_search_type_tools from .no_access_control_search import no_access_control_search from ..utils.prepare_search_result import prepare_search_result +from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps logger = get_logger() @@ -49,6 +50,7 @@ async def search( session_id: Optional[str] = None, wide_search_top_k: Optional[int] = 100, triplet_distance_penalty: Optional[float] = 3.5, + verbose: bool = 
False, ) -> Union[CombinedSearchResult, List[SearchResult]]: """ @@ -75,9 +77,11 @@ async def search( }, ) + actual_accessed_items = [] # Nodes surfaced by this search; access timestamps are updated once below + # Use search function filtered by permissions if access control is enabled if backend_access_control_enabled(): - search_results = await authorized_search( + raw_search_results = await authorized_search( query_type=query_type, query_text=query_text, user=user, @@ -95,8 +99,19 @@ async def search( wide_search_top_k=wide_search_top_k, triplet_distance_penalty=triplet_distance_penalty, ) + if use_combined_context: + # raw_search_results is (completion, context, datasets) + _, context_data, _ = raw_search_results + if isinstance(context_data, list): # Expecting a list of Edge or similar + actual_accessed_items.extend(context_data) + # If context_data is a string, it's already textual and might not map to specific nodes for timestamp updates + else: + for result_tuple in raw_search_results: + _, context_data, _ = result_tuple + if isinstance(context_data, list): # Expecting a list of Edge or similar + actual_accessed_items.extend(context_data) else: - search_results = [ + raw_search_results = [ await no_access_control_search( query_type=query_type, query_text=query_text, @@ -113,6 +128,15 @@ async def search( triplet_distance_penalty=triplet_distance_penalty, ) ] + # In this case, raw_search_results is a list containing a single tuple + if raw_search_results: + _, context_data, _ = raw_search_results[0] + if isinstance(context_data, list): # Expecting a list of Edge or similar + actual_accessed_items.extend(context_data) + + # Update last-accessed timestamps for every node gathered above + # (update_node_access_timestamps returns early when the list yields no node IDs) + await update_node_access_timestamps(actual_accessed_items) send_telemetry( "cognee.search EXECUTION COMPLETED", @@ -123,6 +147,8 @@ async def search( }, ) + search_results = raw_search_results + await log_result( query.id, json.dumps( @@ -140,6 +166,7 @@ async def search( ) if 
use_combined_context: + # Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info prepared_search_results = await prepare_search_result( search_results[0] if isinstance(search_results, list) else search_results ) @@ -173,25 +200,30 @@ async def search( datasets = prepared_search_results["datasets"] if only_context: - return_value.append( - { - "search_result": [context] if context else None, - "dataset_id": datasets[0].id, - "dataset_name": datasets[0].name, - "dataset_tenant_id": datasets[0].tenant_id, - "graphs": graphs, - } - ) + search_result_dict = { + "search_result": [context] if context else None, + "dataset_id": datasets[0].id, + "dataset_name": datasets[0].name, + "dataset_tenant_id": datasets[0].tenant_id, + } + if verbose: + # Include graphs only in verbose mode + search_result_dict["graphs"] = graphs + + return_value.append(search_result_dict) else: - return_value.append( - { - "search_result": [result] if result else None, - "dataset_id": datasets[0].id, - "dataset_name": datasets[0].name, - "dataset_tenant_id": datasets[0].tenant_id, - "graphs": graphs, - } - ) + search_result_dict = { + "search_result": [result] if result else None, + "dataset_id": datasets[0].id, + "dataset_name": datasets[0].name, + "dataset_tenant_id": datasets[0].tenant_id, + } + if verbose: + # Include graphs only in verbose mode + search_result_dict["graphs"] = graphs + + return_value.append(search_result_dict) + return return_value else: return_value = [] @@ -319,6 +351,8 @@ async def authorized_search( only_context=only_context, session_id=session_id, wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, + ) return search_results @@ -438,3 +472,4 @@ async def search_in_datasets_context( ) return await asyncio.gather(*tasks) +