fix: Return search backward compatibility

This commit is contained in:
Igor Ilic 2025-09-11 21:08:13 +02:00
parent 89207780e9
commit e5381e110f
7 changed files with 31 additions and 24 deletions

View file

@ -1,5 +1,5 @@
from uuid import UUID
from typing import Optional, Union, List
from typing import Optional, Union, List, Any
from datetime import datetime
from pydantic import Field
from fastapi import Depends, APIRouter
@ -73,7 +73,7 @@ def get_search_router() -> APIRouter:
except Exception as error:
return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult])
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Search for nodes in the graph database.

View file

@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets
)
return completion
return [completion]

View file

@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets
)
return completion
return [completion]

View file

@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
question=query, answer=completion, context=context_text, triplets=triplets
)
return completion
return [completion]
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
"""

View file

@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
unique_node_connections_map[unique_id] = True
unique_node_connections.append(node_connection)
return [
Edge(
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
attributes={
**connection[1],
"relationship_type": connection[1]["relationship_name"],
},
)
for connection in unique_node_connections
]
return unique_node_connections
# return [
# Edge(
# node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
# node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
# attributes={
# **connection[1],
# "relationship_type": connection[1]["relationship_name"],
# },
# )
# for connection in unique_node_connections
# ]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""

View file

@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path,
)
return completion
return [completion]

View file

@ -134,21 +134,27 @@ async def search(
else:
# This is for maintaining backwards compatibility
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return_value = {}
return_value = []
for search_result in search_results:
result, context, datasets = search_result
return_value[str(datasets[0].id)] = {
"search_result": result,
"dataset_id": str(datasets[0].id),
}
return_value.append(
{
"search_result": result,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
}
)
return return_value
else:
return_value = []
for search_result in search_results:
result, context, datasets = search_result
return_value.append(result)
return return_value
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
# return [
# SearchResult(
# search_result=result,