Backwards compatible search (#1381)
<!-- .github/pull_request_template.md --> ## Description <!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. --> ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Changes Made <!-- List the specific changes made in this PR --> - - - ## Testing <!-- Describe how you tested your changes --> ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## Related Issues <!-- Link any related issues using "Fixes #issue_number" or "Relates to #issue_number" --> ## Additional Notes <!-- Add any additional notes, concerns, or context for reviewers --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
commit
8dd19decbe
9 changed files with 53 additions and 30 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from uuid import UUID
|
||||
from typing import Optional, Union, List
|
||||
from typing import Optional, Union, List, Any
|
||||
from datetime import datetime
|
||||
from pydantic import Field
|
||||
from fastapi import Depends, APIRouter
|
||||
|
|
@ -73,7 +73,7 @@ def get_search_router() -> APIRouter:
|
|||
except Exception as error:
|
||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||
|
||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult])
|
||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
|
||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Search for nodes in the graph database.
|
||||
|
|
|
|||
|
|
@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
|
|||
unique_node_connections_map[unique_id] = True
|
||||
unique_node_connections.append(node_connection)
|
||||
|
||||
return [
|
||||
Edge(
|
||||
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
||||
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
||||
attributes={
|
||||
**connection[1],
|
||||
"relationship_type": connection[1]["relationship_name"],
|
||||
},
|
||||
)
|
||||
for connection in unique_node_connections
|
||||
]
|
||||
return unique_node_connections
|
||||
# return [
|
||||
# Edge(
|
||||
# node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
||||
# node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
||||
# attributes={
|
||||
# **connection[1],
|
||||
# "relationship_type": connection[1]["relationship_name"],
|
||||
# },
|
||||
# )
|
||||
# for connection in unique_node_connections
|
||||
# ]
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -132,14 +132,37 @@ async def search(
|
|||
],
|
||||
)
|
||||
else:
|
||||
return [
|
||||
SearchResult(
|
||||
search_result=result,
|
||||
dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
|
||||
dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
|
||||
)
|
||||
for index, (result, _, datasets) in enumerate(search_results)
|
||||
]
|
||||
# This is for maintaining backwards compatibility
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(
|
||||
{
|
||||
"search_result": result,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
}
|
||||
)
|
||||
return return_value
|
||||
else:
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(result)
|
||||
# For maintaining backwards compatibility
|
||||
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||
return return_value[0]
|
||||
else:
|
||||
return return_value
|
||||
# return [
|
||||
# SearchResult(
|
||||
# search_result=result,
|
||||
# dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
|
||||
# dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
|
||||
# )
|
||||
# for index, (result, _, datasets) in enumerate(search_results)
|
||||
# ]
|
||||
|
||||
|
||||
async def authorized_search(
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ async def main():
|
|||
print("\n\nExtracted sentences are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
assert search_results[0].dataset_name == "NLP", (
|
||||
assert search_results[0]["dataset_name"] == "NLP", (
|
||||
f"Dict must contain dataset name 'NLP': {search_results[0]}"
|
||||
)
|
||||
|
||||
|
|
@ -93,7 +93,7 @@ async def main():
|
|||
print("\n\nExtracted sentences are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
assert search_results[0].dataset_name == "QUANTUM", (
|
||||
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
||||
)
|
||||
|
||||
|
|
@ -170,7 +170,7 @@ async def main():
|
|||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
assert search_results[0].dataset_name == "QUANTUM", (
|
||||
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -90,8 +90,7 @@ async def main():
|
|||
|
||||
print("Coding rules created by memify:")
|
||||
for coding_rule in coding_rules:
|
||||
for search_result in coding_rule.search_result:
|
||||
print("- " + search_result)
|
||||
print("- " + coding_rule)
|
||||
|
||||
# Visualize new graph with added memify context
|
||||
file_path = os.path.join(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue