diff --git a/cognee/modules/search/methods/no_access_control_search.py b/cognee/modules/search/methods/no_access_control_search.py index bb3eaba42..a93fce067 100644 --- a/cognee/modules/search/methods/no_access_control_search.py +++ b/cognee/modules/search/methods/no_access_control_search.py @@ -35,7 +35,7 @@ async def no_access_control_search( [get_completion, get_context] = search_tools if only_context: - return await get_context(query_text) + return None, await get_context(query_text), [] context = await get_context(query_text) result = await get_completion(query_text, context) diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 65efafb4c..8bcef815f 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -143,20 +143,35 @@ async def search( context = prepared_search_results["context"] datasets = prepared_search_results["datasets"] - return_value.append( - { - "search_result": [result] if result else None, - "dataset_id": datasets[0].id, - "dataset_name": datasets[0].name, - "graphs": graphs, - } - ) + if only_context: + return_value.append( + { + "search_result": [context] if context else None, + "dataset_id": datasets[0].id, + "dataset_name": datasets[0].name, + "graphs": graphs, + } + ) + else: + return_value.append( + { + "search_result": [result] if result else None, + "dataset_id": datasets[0].id, + "dataset_name": datasets[0].name, + "graphs": graphs, + } + ) return return_value else: return_value = [] - for search_result in search_results: - result, context, datasets = search_result - return_value.append(result) + if only_context: + for search_result in search_results: + prepared_search_results = await prepare_search_result(search_result) + return_value.append(prepared_search_results["context"]) + else: + for search_result in search_results: + result, context, datasets = search_result + return_value.append(result) # For maintaining backwards compatibility if len(return_value) == 1 and isinstance(return_value[0], list): return return_value[0]