Backwards compatible search (#1381)

<!-- .github/pull_request_template.md -->

## Description
<!-- 
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Changes Made
<!-- List the specific changes made in this PR -->
- 
- 
- 

## Testing
<!-- Describe how you tested your changes -->

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues
<!-- Link any related issues using "Fixes #issue_number" or "Relates to
#issue_number" -->

## Additional Notes
<!-- Add any additional notes, concerns, or context for reviewers -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Vasilije 2025-09-11 12:40:20 -07:00 committed by GitHub
commit 8dd19decbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 53 additions and 30 deletions

View file

@@ -1,5 +1,5 @@
from uuid import UUID
from typing import Optional, Union, List
from typing import Optional, Union, List, Any
from datetime import datetime
from pydantic import Field
from fastapi import Depends, APIRouter
@@ -73,7 +73,7 @@ def get_search_router() -> APIRouter:
except Exception as error:
return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult])
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Search for nodes in the graph database.

View file

@@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets
)
return completion
return [completion]

View file

@@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets
)
return completion
return [completion]

View file

@@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
question=query, answer=completion, context=context_text, triplets=triplets
)
return completion
return [completion]
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
"""

View file

@@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
unique_node_connections_map[unique_id] = True
unique_node_connections.append(node_connection)
return [
Edge(
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
attributes={
**connection[1],
"relationship_type": connection[1]["relationship_name"],
},
)
for connection in unique_node_connections
]
return unique_node_connections
# return [
# Edge(
# node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
# node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
# attributes={
# **connection[1],
# "relationship_type": connection[1]["relationship_name"],
# },
# )
# for connection in unique_node_connections
# ]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""

View file

@@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path,
)
return completion
return [completion]

View file

@@ -132,14 +132,37 @@ async def search(
],
)
else:
return [
SearchResult(
search_result=result,
dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
)
for index, (result, _, datasets) in enumerate(search_results)
]
# This is for maintaining backwards compatibility
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return_value = []
for search_result in search_results:
result, context, datasets = search_result
return_value.append(
{
"search_result": result,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
}
)
return return_value
else:
return_value = []
for search_result in search_results:
result, context, datasets = search_result
return_value.append(result)
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
# return [
# SearchResult(
# search_result=result,
# dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
# dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
# )
# for index, (result, _, datasets) in enumerate(search_results)
# ]
async def authorized_search(

View file

@@ -79,7 +79,7 @@ async def main():
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
assert search_results[0].dataset_name == "NLP", (
assert search_results[0]["dataset_name"] == "NLP", (
f"Dict must contain dataset name 'NLP': {search_results[0]}"
)
@@ -93,7 +93,7 @@ async def main():
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
assert search_results[0].dataset_name == "QUANTUM", (
assert search_results[0]["dataset_name"] == "QUANTUM", (
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
)
@@ -170,7 +170,7 @@ async def main():
for result in search_results:
print(f"{result}\n")
assert search_results[0].dataset_name == "QUANTUM", (
assert search_results[0]["dataset_name"] == "QUANTUM", (
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
)

View file

@@ -90,8 +90,7 @@ async def main():
print("Coding rules created by memify:")
for coding_rule in coding_rules:
for search_result in coding_rule.search_result:
print("- " + search_result)
print("- " + coding_rule)
# Visualize new graph with added memify context
file_path = os.path.join(