fix: Return search backward compatibility
This commit is contained in:
parent
89207780e9
commit
e5381e110f
7 changed files with 31 additions and 24 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Optional, Union, List
|
from typing import Optional, Union, List, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from fastapi import Depends, APIRouter
|
from fastapi import Depends, APIRouter
|
||||||
|
|
@ -73,7 +73,7 @@ def get_search_router() -> APIRouter:
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||||
|
|
||||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult])
|
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
|
||||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||||
"""
|
"""
|
||||||
Search for nodes in the graph database.
|
Search for nodes in the graph database.
|
||||||
|
|
|
||||||
|
|
@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion
|
return [completion]
|
||||||
|
|
||||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
|
||||||
unique_node_connections_map[unique_id] = True
|
unique_node_connections_map[unique_id] = True
|
||||||
unique_node_connections.append(node_connection)
|
unique_node_connections.append(node_connection)
|
||||||
|
|
||||||
return [
|
return unique_node_connections
|
||||||
Edge(
|
# return [
|
||||||
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
# Edge(
|
||||||
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
# node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
||||||
attributes={
|
# node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
||||||
**connection[1],
|
# attributes={
|
||||||
"relationship_type": connection[1]["relationship_name"],
|
# **connection[1],
|
||||||
},
|
# "relationship_type": connection[1]["relationship_name"],
|
||||||
)
|
# },
|
||||||
for connection in unique_node_connections
|
# )
|
||||||
]
|
# for connection in unique_node_connections
|
||||||
|
# ]
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever):
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -134,21 +134,27 @@ async def search(
|
||||||
else:
|
else:
|
||||||
# This is for maintaining backwards compatibility
|
# This is for maintaining backwards compatibility
|
||||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||||
return_value = {}
|
return_value = []
|
||||||
for search_result in search_results:
|
for search_result in search_results:
|
||||||
result, context, datasets = search_result
|
result, context, datasets = search_result
|
||||||
return_value[str(datasets[0].id)] = {
|
return_value.append(
|
||||||
"search_result": result,
|
{
|
||||||
"dataset_id": str(datasets[0].id),
|
"search_result": result,
|
||||||
}
|
"dataset_id": datasets[0].id,
|
||||||
|
"dataset_name": datasets[0].name,
|
||||||
|
}
|
||||||
|
)
|
||||||
return return_value
|
return return_value
|
||||||
else:
|
else:
|
||||||
return_value = []
|
return_value = []
|
||||||
for search_result in search_results:
|
for search_result in search_results:
|
||||||
result, context, datasets = search_result
|
result, context, datasets = search_result
|
||||||
return_value.append(result)
|
return_value.append(result)
|
||||||
|
# For maintaining backwards compatibility
|
||||||
return return_value
|
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||||
|
return return_value[0]
|
||||||
|
else:
|
||||||
|
return return_value
|
||||||
# return [
|
# return [
|
||||||
# SearchResult(
|
# SearchResult(
|
||||||
# search_result=result,
|
# search_result=result,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue