fix: implementing deletion in search.py
This commit is contained in:
parent
2485c3f5f0
commit
07e67e268b
2 changed files with 95 additions and 24 deletions
|
|
@ -25,12 +25,17 @@ async def update_node_access_timestamps(items: List[Any]):
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
timestamp_dt = datetime.now(timezone.utc)
|
timestamp_dt = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# Extract node IDs
|
# Extract node IDs - updated for graph node format
|
||||||
node_ids = []
|
node_ids = []
|
||||||
for item in items:
|
for item in items:
|
||||||
item_id = item.payload.get("id") if hasattr(item, 'payload') else item.get("id")
|
# Handle graph nodes from prepare_search_result (direct id attribute)
|
||||||
if item_id:
|
if hasattr(item, 'id'):
|
||||||
node_ids.append(str(item_id))
|
node_ids.append(str(item.id))
|
||||||
|
# Fallback for original retriever format
|
||||||
|
elif hasattr(item, 'payload') and item.payload.get("id"):
|
||||||
|
node_ids.append(str(item.payload.get("id")))
|
||||||
|
elif isinstance(item, dict) and item.get("id"):
|
||||||
|
node_ids.append(str(item.get("id")))
|
||||||
|
|
||||||
if not node_ids:
|
if not node_ids:
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ from cognee import __version__ as cognee_version
|
||||||
from .get_search_type_tools import get_search_type_tools
|
from .get_search_type_tools import get_search_type_tools
|
||||||
from .no_access_control_search import no_access_control_search
|
from .no_access_control_search import no_access_control_search
|
||||||
from ..utils.prepare_search_result import prepare_search_result
|
from ..utils.prepare_search_result import prepare_search_result
|
||||||
|
from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps # Import your function
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -47,6 +48,9 @@ async def search(
|
||||||
only_context: bool = False,
|
only_context: bool = False,
|
||||||
use_combined_context: bool = False,
|
use_combined_context: bool = False,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
|
wide_search_top_k: Optional[int] = 100,
|
||||||
|
triplet_distance_penalty: Optional[float] = 3.5,
|
||||||
|
verbose: bool = False,
|
||||||
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -73,9 +77,11 @@ async def search(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual_accessed_items = [] # Collect all accessed items here
|
||||||
|
|
||||||
# Use search function filtered by permissions if access control is enabled
|
# Use search function filtered by permissions if access control is enabled
|
||||||
if backend_access_control_enabled():
|
if backend_access_control_enabled():
|
||||||
search_results = await authorized_search(
|
raw_search_results = await authorized_search(
|
||||||
query_type=query_type,
|
query_type=query_type,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
user=user,
|
user=user,
|
||||||
|
|
@ -90,9 +96,22 @@ async def search(
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
use_combined_context=use_combined_context,
|
use_combined_context=use_combined_context,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
wide_search_top_k=wide_search_top_k,
|
||||||
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
)
|
)
|
||||||
|
if use_combined_context:
|
||||||
|
# raw_search_results is (completion, context, datasets)
|
||||||
|
_, context_data, _ = raw_search_results
|
||||||
|
if isinstance(context_data, list): # Expecting a list of Edge or similar
|
||||||
|
actual_accessed_items.extend(context_data)
|
||||||
|
# If context_data is a string, it's already textual and might not map to specific nodes for timestamp updates
|
||||||
|
else:
|
||||||
|
for result_tuple in raw_search_results:
|
||||||
|
_, context_data, _ = result_tuple
|
||||||
|
if isinstance(context_data, list): # Expecting a list of Edge or similar
|
||||||
|
actual_accessed_items.extend(context_data)
|
||||||
else:
|
else:
|
||||||
search_results = [
|
raw_search_results = [
|
||||||
await no_access_control_search(
|
await no_access_control_search(
|
||||||
query_type=query_type,
|
query_type=query_type,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
|
|
@ -105,8 +124,19 @@ async def search(
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
wide_search_top_k=wide_search_top_k,
|
||||||
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
# In this case, raw_search_results is a list containing a single tuple
|
||||||
|
if raw_search_results:
|
||||||
|
_, context_data, _ = raw_search_results[0]
|
||||||
|
if isinstance(context_data, list): # Expecting a list of Edge or similar
|
||||||
|
actual_accessed_items.extend(context_data)
|
||||||
|
|
||||||
|
# Call the update_node_access_timestamps function here
|
||||||
|
# Pass the collected actual_accessed_items
|
||||||
|
await update_node_access_timestamps(actual_accessed_items)
|
||||||
|
|
||||||
send_telemetry(
|
send_telemetry(
|
||||||
"cognee.search EXECUTION COMPLETED",
|
"cognee.search EXECUTION COMPLETED",
|
||||||
|
|
@ -117,6 +147,19 @@ async def search(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The rest of the code for logging and preparing results remains largely the same
|
||||||
|
# Ensure search_results is correctly defined for the subsequent logging/preparation logic
|
||||||
|
# based on how it was processed in the if/else blocks above.
|
||||||
|
# For now, let's assume 'search_results' should refer to 'raw_search_results'
|
||||||
|
# for the purpose of this part of the code, or be re-structured to use the
|
||||||
|
# collected components for the final output.
|
||||||
|
|
||||||
|
# This part needs careful adjustment based on the exact structure expected by prepare_search_result
|
||||||
|
# and the final return type.
|
||||||
|
# For simplicity here, let's re-assign search_results to raw_search_results for the original flow
|
||||||
|
search_results = raw_search_results
|
||||||
|
# ... rest of the original function ...
|
||||||
|
|
||||||
await log_result(
|
await log_result(
|
||||||
query.id,
|
query.id,
|
||||||
json.dumps(
|
json.dumps(
|
||||||
|
|
@ -134,6 +177,7 @@ async def search(
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_combined_context:
|
if use_combined_context:
|
||||||
|
# Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info
|
||||||
prepared_search_results = await prepare_search_result(
|
prepared_search_results = await prepare_search_result(
|
||||||
search_results[0] if isinstance(search_results, list) else search_results
|
search_results[0] if isinstance(search_results, list) else search_results
|
||||||
)
|
)
|
||||||
|
|
@ -167,25 +211,30 @@ async def search(
|
||||||
datasets = prepared_search_results["datasets"]
|
datasets = prepared_search_results["datasets"]
|
||||||
|
|
||||||
if only_context:
|
if only_context:
|
||||||
return_value.append(
|
search_result_dict = {
|
||||||
{
|
"search_result": [context] if context else None,
|
||||||
"search_result": [context] if context else None,
|
"dataset_id": datasets[0].id,
|
||||||
"dataset_id": datasets[0].id,
|
"dataset_name": datasets[0].name,
|
||||||
"dataset_name": datasets[0].name,
|
"dataset_tenant_id": datasets[0].tenant_id,
|
||||||
"dataset_tenant_id": datasets[0].tenant_id,
|
}
|
||||||
"graphs": graphs,
|
if verbose:
|
||||||
}
|
# Include graphs only in verbose mode
|
||||||
)
|
search_result_dict["graphs"] = graphs
|
||||||
|
|
||||||
|
return_value.append(search_result_dict)
|
||||||
else:
|
else:
|
||||||
return_value.append(
|
search_result_dict = {
|
||||||
{
|
"search_result": [result] if result else None,
|
||||||
"search_result": [result] if result else None,
|
"dataset_id": datasets[0].id,
|
||||||
"dataset_id": datasets[0].id,
|
"dataset_name": datasets[0].name,
|
||||||
"dataset_name": datasets[0].name,
|
"dataset_tenant_id": datasets[0].tenant_id,
|
||||||
"dataset_tenant_id": datasets[0].tenant_id,
|
}
|
||||||
"graphs": graphs,
|
if verbose:
|
||||||
}
|
# Include graphs only in verbose mode
|
||||||
)
|
search_result_dict["graphs"] = graphs
|
||||||
|
|
||||||
|
return_value.append(search_result_dict)
|
||||||
|
|
||||||
return return_value
|
return return_value
|
||||||
else:
|
else:
|
||||||
return_value = []
|
return_value = []
|
||||||
|
|
@ -219,6 +268,8 @@ async def authorized_search(
|
||||||
only_context: bool = False,
|
only_context: bool = False,
|
||||||
use_combined_context: bool = False,
|
use_combined_context: bool = False,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
|
wide_search_top_k: Optional[int] = 100,
|
||||||
|
triplet_distance_penalty: Optional[float] = 3.5,
|
||||||
) -> Union[
|
) -> Union[
|
||||||
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
||||||
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
||||||
|
|
@ -246,6 +297,8 @@ async def authorized_search(
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
only_context=True,
|
only_context=True,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
wide_search_top_k=wide_search_top_k,
|
||||||
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
context = {}
|
context = {}
|
||||||
|
|
@ -267,6 +320,8 @@ async def authorized_search(
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
|
wide_search_top_k=wide_search_top_k,
|
||||||
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
)
|
)
|
||||||
search_tools = specific_search_tools
|
search_tools = specific_search_tools
|
||||||
if len(search_tools) == 2:
|
if len(search_tools) == 2:
|
||||||
|
|
@ -306,6 +361,8 @@ async def authorized_search(
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
wide_search_top_k=wide_search_top_k,
|
||||||
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
return search_results
|
return search_results
|
||||||
|
|
@ -325,6 +382,8 @@ async def search_in_datasets_context(
|
||||||
only_context: bool = False,
|
only_context: bool = False,
|
||||||
context: Optional[Any] = None,
|
context: Optional[Any] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
|
wide_search_top_k: Optional[int] = 100,
|
||||||
|
triplet_distance_penalty: Optional[float] = 3.5,
|
||||||
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
|
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
|
||||||
"""
|
"""
|
||||||
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
||||||
|
|
@ -345,6 +404,8 @@ async def search_in_datasets_context(
|
||||||
only_context: bool = False,
|
only_context: bool = False,
|
||||||
context: Optional[Any] = None,
|
context: Optional[Any] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
|
wide_search_top_k: Optional[int] = 100,
|
||||||
|
triplet_distance_penalty: Optional[float] = 3.5,
|
||||||
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||||
# Set database configuration in async context for each dataset user has access for
|
# Set database configuration in async context for each dataset user has access for
|
||||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||||
|
|
@ -378,6 +439,8 @@ async def search_in_datasets_context(
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
|
wide_search_top_k=wide_search_top_k,
|
||||||
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
)
|
)
|
||||||
search_tools = specific_search_tools
|
search_tools = specific_search_tools
|
||||||
if len(search_tools) == 2:
|
if len(search_tools) == 2:
|
||||||
|
|
@ -413,7 +476,10 @@ async def search_in_datasets_context(
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
context=context,
|
context=context,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
wide_search_top_k=wide_search_top_k,
|
||||||
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return await asyncio.gather(*tasks)
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue