fix: Return search backward compatibility

This commit is contained in:
Igor Ilic 2025-09-11 21:08:13 +02:00
parent 89207780e9
commit e5381e110f
7 changed files with 31 additions and 24 deletions

View file

@ -1,5 +1,5 @@
from uuid import UUID from uuid import UUID
from typing import Optional, Union, List from typing import Optional, Union, List, Any
from datetime import datetime from datetime import datetime
from pydantic import Field from pydantic import Field
from fastapi import Depends, APIRouter from fastapi import Depends, APIRouter
@ -73,7 +73,7 @@ def get_search_router() -> APIRouter:
except Exception as error: except Exception as error:
return JSONResponse(status_code=500, content={"error": str(error)}) return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult]) @router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
""" """
Search for nodes in the graph database. Search for nodes in the graph database.

View file

@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
return completion return [completion]

View file

@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
return completion return [completion]

View file

@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
return completion return [completion]
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
""" """

View file

@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
unique_node_connections_map[unique_id] = True unique_node_connections_map[unique_id] = True
unique_node_connections.append(node_connection) unique_node_connections.append(node_connection)
return [ return unique_node_connections
Edge( # return [
node1=Node(node_id=connection[0]["id"], attributes=connection[0]), # Edge(
node2=Node(node_id=connection[2]["id"], attributes=connection[2]), # node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
attributes={ # node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
**connection[1], # attributes={
"relationship_type": connection[1]["relationship_name"], # **connection[1],
}, # "relationship_type": connection[1]["relationship_name"],
) # },
for connection in unique_node_connections # )
] # for connection in unique_node_connections
# ]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
""" """

View file

@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
) )
return completion return [completion]

View file

@ -134,21 +134,27 @@ async def search(
else: else:
# This is for maintaining backwards compatibility # This is for maintaining backwards compatibility
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return_value = {} return_value = []
for search_result in search_results: for search_result in search_results:
result, context, datasets = search_result result, context, datasets = search_result
return_value[str(datasets[0].id)] = { return_value.append(
"search_result": result, {
"dataset_id": str(datasets[0].id), "search_result": result,
} "dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
}
)
return return_value return return_value
else: else:
return_value = [] return_value = []
for search_result in search_results: for search_result in search_results:
result, context, datasets = search_result result, context, datasets = search_result
return_value.append(result) return_value.append(result)
# For maintaining backwards compatibility
return return_value if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
# return [ # return [
# SearchResult( # SearchResult(
# search_result=result, # search_result=result,