diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index 459b5f427..36d1c567e 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -1,5 +1,5 @@ from uuid import UUID -from typing import Optional, Union, List +from typing import Optional, Union, List, Any from datetime import datetime from pydantic import Field from fastapi import Depends, APIRouter @@ -73,7 +73,7 @@ def get_search_router() -> APIRouter: except Exception as error: return JSONResponse(status_code=500, content={"error": str(error)}) - @router.post("", response_model=Union[List[SearchResult], CombinedSearchResult]) + @router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List]) async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): """ Search for nodes in the graph database. diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 4f4af1f06..f4c37fc93 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context_text, triplets=triplets ) - return completion + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 282c6147e..f51433751 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context_text, triplets=triplets ) - return completion + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 45e7f85ff..29b1e9d19 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): question=query, answer=completion, context=context_text, triplets=triplets ) - return completion + return [completion] async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: """ diff --git a/cognee/modules/retrieval/insights_retriever.py b/cognee/modules/retrieval/insights_retriever.py index 43b77e951..0b1991e92 100644 --- a/cognee/modules/retrieval/insights_retriever.py +++ b/cognee/modules/retrieval/insights_retriever.py @@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever): unique_node_connections_map[unique_id] = True unique_node_connections.append(node_connection) - return [ - Edge( - node1=Node(node_id=connection[0]["id"], attributes=connection[0]), - node2=Node(node_id=connection[2]["id"], attributes=connection[2]), - attributes={ - **connection[1], - "relationship_type": connection[1]["relationship_name"], - }, - ) - for connection in unique_node_connections - ] + return unique_node_connections + # return [ + # Edge( + # node1=Node(node_id=connection[0]["id"], attributes=connection[0]), + # node2=Node(node_id=connection[2]["id"], attributes=connection[2]), + # attributes={ + # **connection[1], + # "relationship_type": connection[1]["relationship_name"], + # }, + # ) + # for connection in unique_node_connections + # ] async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: """ diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index 09f2980dd..36cdbd33f 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever): system_prompt_path=self.system_prompt_path, ) - return completion + return [completion] diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 0bc845a10..0c236d896 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -134,21 +134,27 @@ async def search( else: # This is for maintaining backwards compatibility if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": - return_value = {} + return_value = [] for search_result in search_results: result, context, datasets = search_result - return_value[str(datasets[0].id)] = { - "search_result": result, - "dataset_id": str(datasets[0].id), - } + return_value.append( + { + "search_result": result, + "dataset_id": datasets[0].id, + "dataset_name": datasets[0].name, + } + ) return return_value else: return_value = [] for search_result in search_results: result, context, datasets = search_result return_value.append(result) - - return return_value + # For maintaining backwards compatibility + if len(return_value) == 1 and isinstance(return_value[0], list): + return return_value[0] + else: + return return_value # return [ # SearchResult( # search_result=result,