Backwards compatible search (#1381)

<!-- .github/pull_request_template.md -->

## Description
<!-- 
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Changes Made
<!-- List the specific changes made in this PR -->
- 
- 
- 

## Testing
<!-- Describe how you tested your changes -->

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues
<!-- Link any related issues using "Fixes #issue_number" or "Relates to
#issue_number" -->

## Additional Notes
<!-- Add any additional notes, concerns, or context for reviewers -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Vasilije 2025-09-11 12:40:20 -07:00 committed by GitHub
commit 8dd19decbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 53 additions and 30 deletions

View file

@ -1,5 +1,5 @@
from uuid import UUID from uuid import UUID
from typing import Optional, Union, List from typing import Optional, Union, List, Any
from datetime import datetime from datetime import datetime
from pydantic import Field from pydantic import Field
from fastapi import Depends, APIRouter from fastapi import Depends, APIRouter
@ -73,7 +73,7 @@ def get_search_router() -> APIRouter:
except Exception as error: except Exception as error:
return JSONResponse(status_code=500, content={"error": str(error)}) return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult]) @router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
""" """
Search for nodes in the graph database. Search for nodes in the graph database.

View file

@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
return completion return [completion]

View file

@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
return completion return [completion]

View file

@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
return completion return [completion]
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
""" """

View file

@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
unique_node_connections_map[unique_id] = True unique_node_connections_map[unique_id] = True
unique_node_connections.append(node_connection) unique_node_connections.append(node_connection)
return [ return unique_node_connections
Edge( # return [
node1=Node(node_id=connection[0]["id"], attributes=connection[0]), # Edge(
node2=Node(node_id=connection[2]["id"], attributes=connection[2]), # node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
attributes={ # node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
**connection[1], # attributes={
"relationship_type": connection[1]["relationship_name"], # **connection[1],
}, # "relationship_type": connection[1]["relationship_name"],
) # },
for connection in unique_node_connections # )
] # for connection in unique_node_connections
# ]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
""" """

View file

@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
) )
return completion return [completion]

View file

@ -132,14 +132,37 @@ async def search(
], ],
) )
else: else:
return [ # This is for maintaining backwards compatibility
SearchResult( if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
search_result=result, return_value = []
dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None, for search_result in search_results:
dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None, result, context, datasets = search_result
) return_value.append(
for index, (result, _, datasets) in enumerate(search_results) {
] "search_result": result,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
}
)
return return_value
else:
return_value = []
for search_result in search_results:
result, context, datasets = search_result
return_value.append(result)
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
# return [
# SearchResult(
# search_result=result,
# dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
# dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
# )
# for index, (result, _, datasets) in enumerate(search_results)
# ]
async def authorized_search( async def authorized_search(

View file

@ -79,7 +79,7 @@ async def main():
print("\n\nExtracted sentences are:\n") print("\n\nExtracted sentences are:\n")
for result in search_results: for result in search_results:
print(f"{result}\n") print(f"{result}\n")
assert search_results[0].dataset_name == "NLP", ( assert search_results[0]["dataset_name"] == "NLP", (
f"Dict must contain dataset name 'NLP': {search_results[0]}" f"Dict must contain dataset name 'NLP': {search_results[0]}"
) )
@ -93,7 +93,7 @@ async def main():
print("\n\nExtracted sentences are:\n") print("\n\nExtracted sentences are:\n")
for result in search_results: for result in search_results:
print(f"{result}\n") print(f"{result}\n")
assert search_results[0].dataset_name == "QUANTUM", ( assert search_results[0]["dataset_name"] == "QUANTUM", (
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}" f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
) )
@ -170,7 +170,7 @@ async def main():
for result in search_results: for result in search_results:
print(f"{result}\n") print(f"{result}\n")
assert search_results[0].dataset_name == "QUANTUM", ( assert search_results[0]["dataset_name"] == "QUANTUM", (
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}" f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
) )

View file

@ -90,8 +90,7 @@ async def main():
print("Coding rules created by memify:") print("Coding rules created by memify:")
for coding_rule in coding_rules: for coding_rule in coding_rules:
for search_result in coding_rule.search_result: print("- " + coding_rule)
print("- " + search_result)
# Visualize new graph with added memify context # Visualize new graph with added memify context
file_path = os.path.join( file_path = os.path.join(