Backwards compatible search (#1381)
<!-- .github/pull_request_template.md --> ## Description <!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. --> ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Changes Made <!-- List the specific changes made in this PR --> - - - ## Testing <!-- Describe how you tested your changes --> ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## Related Issues <!-- Link any related issues using "Fixes #issue_number" or "Relates to #issue_number" --> ## Additional Notes <!-- Add any additional notes, concerns, or context for reviewers --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
commit
8dd19decbe
9 changed files with 53 additions and 30 deletions
|
|
@@ -1,5 +1,5 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Optional, Union, List
|
from typing import Optional, Union, List, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from fastapi import Depends, APIRouter
|
from fastapi import Depends, APIRouter
|
||||||
|
|
@@ -73,7 +73,7 @@ def get_search_router() -> APIRouter:
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||||
|
|
||||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult])
|
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
|
||||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||||
"""
|
"""
|
||||||
Search for nodes in the graph database.
|
Search for nodes in the graph database.
|
||||||
|
|
|
||||||
|
|
@@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion
|
return [completion]
|
||||||
|
|
||||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
|
||||||
unique_node_connections_map[unique_id] = True
|
unique_node_connections_map[unique_id] = True
|
||||||
unique_node_connections.append(node_connection)
|
unique_node_connections.append(node_connection)
|
||||||
|
|
||||||
return [
|
return unique_node_connections
|
||||||
Edge(
|
# return [
|
||||||
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
# Edge(
|
||||||
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
# node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
||||||
attributes={
|
# node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
||||||
**connection[1],
|
# attributes={
|
||||||
"relationship_type": connection[1]["relationship_name"],
|
# **connection[1],
|
||||||
},
|
# "relationship_type": connection[1]["relationship_name"],
|
||||||
)
|
# },
|
||||||
for connection in unique_node_connections
|
# )
|
||||||
]
|
# for connection in unique_node_connections
|
||||||
|
# ]
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever):
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@@ -132,14 +132,37 @@ async def search(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return [
|
# This is for maintaining backwards compatibility
|
||||||
SearchResult(
|
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||||
search_result=result,
|
return_value = []
|
||||||
dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
|
for search_result in search_results:
|
||||||
dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
|
result, context, datasets = search_result
|
||||||
)
|
return_value.append(
|
||||||
for index, (result, _, datasets) in enumerate(search_results)
|
{
|
||||||
]
|
"search_result": result,
|
||||||
|
"dataset_id": datasets[0].id,
|
||||||
|
"dataset_name": datasets[0].name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return return_value
|
||||||
|
else:
|
||||||
|
return_value = []
|
||||||
|
for search_result in search_results:
|
||||||
|
result, context, datasets = search_result
|
||||||
|
return_value.append(result)
|
||||||
|
# For maintaining backwards compatibility
|
||||||
|
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||||
|
return return_value[0]
|
||||||
|
else:
|
||||||
|
return return_value
|
||||||
|
# return [
|
||||||
|
# SearchResult(
|
||||||
|
# search_result=result,
|
||||||
|
# dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
|
||||||
|
# dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
|
||||||
|
# )
|
||||||
|
# for index, (result, _, datasets) in enumerate(search_results)
|
||||||
|
# ]
|
||||||
|
|
||||||
|
|
||||||
async def authorized_search(
|
async def authorized_search(
|
||||||
|
|
|
||||||
|
|
@@ -79,7 +79,7 @@ async def main():
|
||||||
print("\n\nExtracted sentences are:\n")
|
print("\n\nExtracted sentences are:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
assert search_results[0].dataset_name == "NLP", (
|
assert search_results[0]["dataset_name"] == "NLP", (
|
||||||
f"Dict must contain dataset name 'NLP': {search_results[0]}"
|
f"Dict must contain dataset name 'NLP': {search_results[0]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@@ -93,7 +93,7 @@ async def main():
|
||||||
print("\n\nExtracted sentences are:\n")
|
print("\n\nExtracted sentences are:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
assert search_results[0].dataset_name == "QUANTUM", (
|
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
||||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@@ -170,7 +170,7 @@ async def main():
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
assert search_results[0].dataset_name == "QUANTUM", (
|
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
||||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -90,8 +90,7 @@ async def main():
|
||||||
|
|
||||||
print("Coding rules created by memify:")
|
print("Coding rules created by memify:")
|
||||||
for coding_rule in coding_rules:
|
for coding_rule in coding_rules:
|
||||||
for search_result in coding_rule.search_result:
|
print("- " + coding_rule)
|
||||||
print("- " + search_result)
|
|
||||||
|
|
||||||
# Visualize new graph with added memify context
|
# Visualize new graph with added memify context
|
||||||
file_path = os.path.join(
|
file_path = os.path.join(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue