From 07e67e268b6b8c383ad1184e49c49c9ad5d93e29 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Fri, 26 Dec 2025 09:04:39 +0530 Subject: [PATCH] fix: implementing deletion in search.py --- .../retrieval/utils/access_tracking.py | 13 ++- cognee/modules/search/methods/search.py | 106 ++++++++++++++---- 2 files changed, 95 insertions(+), 24 deletions(-) diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index 54fd043b9..b2d98924a 100644 --- a/cognee/modules/retrieval/utils/access_tracking.py +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -25,12 +25,17 @@ async def update_node_access_timestamps(items: List[Any]): graph_engine = await get_graph_engine() timestamp_dt = datetime.now(timezone.utc) - # Extract node IDs + # Extract node IDs - updated for graph node format node_ids = [] for item in items: - item_id = item.payload.get("id") if hasattr(item, 'payload') else item.get("id") - if item_id: - node_ids.append(str(item_id)) + # Handle graph nodes from prepare_search_result (direct id attribute) + if hasattr(item, 'id'): + node_ids.append(str(item.id)) + # Fallback for original retriever format + elif hasattr(item, 'payload') and item.payload.get("id"): + node_ids.append(str(item.payload.get("id"))) + elif isinstance(item, dict) and item.get("id"): + node_ids.append(str(item.get("id"))) if not node_ids: return diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index b4278424b..8709a34a1 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -28,6 +28,7 @@ from cognee import __version__ as cognee_version from .get_search_type_tools import get_search_type_tools from .no_access_control_search import no_access_control_search from ..utils.prepare_search_result import prepare_search_result +from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps # Import your function logger = get_logger() @@ -47,6 +48,9 @@ async def search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, + verbose: bool = False, ) -> Union[CombinedSearchResult, List[SearchResult]]: """ @@ -73,9 +77,11 @@ async def search( }, ) + actual_accessed_items = [] # Collect all accessed items here + # Use search function filtered by permissions if access control is enabled if backend_access_control_enabled(): - search_results = await authorized_search( + raw_search_results = await authorized_search( query_type=query_type, query_text=query_text, user=user, @@ -90,9 +96,22 @@ async def search( only_context=only_context, use_combined_context=use_combined_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) + if use_combined_context: + # raw_search_results is (completion, context, datasets) + _, context_data, _ = raw_search_results + if isinstance(context_data, list): # Expecting a list of Edge or similar + actual_accessed_items.extend(context_data) + # If context_data is a string, it's already textual and might not map to specific nodes for timestamp updates + else: + for result_tuple in raw_search_results: + _, context_data, _ = result_tuple + if isinstance(context_data, list): # Expecting a list of Edge or similar + actual_accessed_items.extend(context_data) else: - search_results = [ + raw_search_results = [ await no_access_control_search( query_type=query_type, query_text=query_text, @@ -105,8 +124,19 @@ async def search( last_k=last_k, only_context=only_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) ] + # In this case, raw_search_results is a list containing a single tuple + if raw_search_results: + _, context_data, _ = raw_search_results[0] + if isinstance(context_data, list): # Expecting a list of Edge or similar + actual_accessed_items.extend(context_data) + + # Call the update_node_access_timestamps function here + # Pass the collected actual_accessed_items + await update_node_access_timestamps(actual_accessed_items) send_telemetry( "cognee.search EXECUTION COMPLETED", @@ -117,6 +147,19 @@ async def search( }, ) + # The rest of the code for logging and preparing results remains largely the same + # Ensure search_results is correctly defined for the subsequent logging/preparation logic + # based on how it was processed in the if/else blocks above. + # For now, let's assume 'search_results' should refer to 'raw_search_results' + # for the purpose of this part of the code, or be re-structured to use the + # collected components for the final output. + + # This part needs careful adjustment based on the exact structure expected by prepare_search_result + # and the final return type. + # For simplicity here, let's re-assign search_results to raw_search_results for the original flow + search_results = raw_search_results + # ... rest of the original function ... + await log_result( query.id, json.dumps( @@ -134,6 +177,7 @@ async def search( ) if use_combined_context: + # Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info prepared_search_results = await prepare_search_result( search_results[0] if isinstance(search_results, list) else search_results ) @@ -167,25 +211,30 @@ async def search( datasets = prepared_search_results["datasets"] if only_context: - return_value.append( - { - "search_result": [context] if context else None, - "dataset_id": datasets[0].id, - "dataset_name": datasets[0].name, - "dataset_tenant_id": datasets[0].tenant_id, - "graphs": graphs, - } - ) + search_result_dict = { + "search_result": [context] if context else None, + "dataset_id": datasets[0].id, + "dataset_name": datasets[0].name, + "dataset_tenant_id": datasets[0].tenant_id, + } + if verbose: + # Include graphs only in verbose mode + search_result_dict["graphs"] = graphs + + return_value.append(search_result_dict) else: - return_value.append( - { - "search_result": [result] if result else None, - "dataset_id": datasets[0].id, - "dataset_name": datasets[0].name, - "dataset_tenant_id": datasets[0].tenant_id, - "graphs": graphs, - } - ) + search_result_dict = { + "search_result": [result] if result else None, + "dataset_id": datasets[0].id, + "dataset_name": datasets[0].name, + "dataset_tenant_id": datasets[0].tenant_id, + } + if verbose: + # Include graphs only in verbose mode + search_result_dict["graphs"] = graphs + + return_value.append(search_result_dict) + return return_value else: return_value = [] @@ -219,6 +268,8 @@ async def authorized_search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[ Tuple[Any, Union[List[Edge], str], List[Dataset]], List[Tuple[Any, Union[List[Edge], str], List[Dataset]]], @@ -246,6 +297,8 @@ async def authorized_search( last_k=last_k, only_context=True, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) context = {} @@ -267,6 +320,8 @@ async def authorized_search( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) search_tools = specific_search_tools if len(search_tools) == 2: @@ -306,6 +361,8 @@ async def authorized_search( last_k=last_k, only_context=only_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) return search_results @@ -325,6 +382,8 @@ async def search_in_datasets_context( only_context: bool = False, context: Optional[Any] = None, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]: """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. @@ -345,6 +404,8 @@ async def search_in_datasets_context( only_context: bool = False, context: Optional[Any] = None, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) @@ -378,6 +439,8 @@ async def search_in_datasets_context( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) search_tools = specific_search_tools if len(search_tools) == 2: @@ -413,7 +476,10 @@ async def search_in_datasets_context( only_context=only_context, context=context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) ) return await asyncio.gather(*tasks) +