fix: Return search backward compatibility
This commit is contained in:
parent
89207780e9
commit
e5381e110f
7 changed files with 31 additions and 24 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from uuid import UUID
|
||||
from typing import Optional, Union, List
|
||||
from typing import Optional, Union, List, Any
|
||||
from datetime import datetime
|
||||
from pydantic import Field
|
||||
from fastapi import Depends, APIRouter
|
||||
|
|
@ -73,7 +73,7 @@ def get_search_router() -> APIRouter:
|
|||
except Exception as error:
|
||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||
|
||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult])
|
||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
|
||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Search for nodes in the graph database.
|
||||
|
|
|
|||
|
|
@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
|
|||
unique_node_connections_map[unique_id] = True
|
||||
unique_node_connections.append(node_connection)
|
||||
|
||||
return [
|
||||
Edge(
|
||||
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
||||
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
||||
attributes={
|
||||
**connection[1],
|
||||
"relationship_type": connection[1]["relationship_name"],
|
||||
},
|
||||
)
|
||||
for connection in unique_node_connections
|
||||
]
|
||||
return unique_node_connections
|
||||
# return [
|
||||
# Edge(
|
||||
# node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
||||
# node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
||||
# attributes={
|
||||
# **connection[1],
|
||||
# "relationship_type": connection[1]["relationship_name"],
|
||||
# },
|
||||
# )
|
||||
# for connection in unique_node_connections
|
||||
# ]
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -134,21 +134,27 @@ async def search(
|
|||
else:
|
||||
# This is for maintaining backwards compatibility
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
return_value = {}
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value[str(datasets[0].id)] = {
|
||||
"search_result": result,
|
||||
"dataset_id": str(datasets[0].id),
|
||||
}
|
||||
return_value.append(
|
||||
{
|
||||
"search_result": result,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
}
|
||||
)
|
||||
return return_value
|
||||
else:
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(result)
|
||||
|
||||
return return_value
|
||||
# For maintaining backwards compatibility
|
||||
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||
return return_value[0]
|
||||
else:
|
||||
return return_value
|
||||
# return [
|
||||
# SearchResult(
|
||||
# search_result=result,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue